|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Contains utilities used by both the sync and async inference clients.""" |
|
|
|
|
|
import base64 |
|
|
import io |
|
|
import json |
|
|
import logging |
|
|
import mimetypes |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import ( |
|
|
TYPE_CHECKING, |
|
|
Any, |
|
|
AsyncIterable, |
|
|
BinaryIO, |
|
|
Dict, |
|
|
Iterable, |
|
|
List, |
|
|
Literal, |
|
|
NoReturn, |
|
|
Optional, |
|
|
Union, |
|
|
overload, |
|
|
) |
|
|
|
|
|
from requests import HTTPError |
|
|
|
|
|
from huggingface_hub.errors import ( |
|
|
GenerationError, |
|
|
IncompleteGenerationError, |
|
|
OverloadedError, |
|
|
TextGenerationError, |
|
|
UnknownError, |
|
|
ValidationError, |
|
|
) |
|
|
|
|
|
from ..utils import get_session, is_aiohttp_available, is_numpy_available, is_pillow_available |
|
|
from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from aiohttp import ClientResponse, ClientSession |
|
|
from PIL.Image import Image |
|
|
|
|
|
|
|
|
# Type aliases used by both the sync and async inference clients.
UrlT = str  # an http(s) URL, passed through or downloaded depending on context
PathT = Union[str, Path]  # a local filesystem path
# Any value accepted as binary content: bytes-like objects, an open binary
# stream, a local path, a URL, or a PIL Image.
ContentT = Union[bytes, BinaryIO, PathT, UrlT, "Image", bytearray, memoryview]


# Tasks whose server response is expected to be raw image bytes.
TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}

logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@dataclass
class RequestParameters:
    """Fully-resolved parameters of a single inference HTTP request."""

    # Target endpoint URL.
    url: str
    # Name of the inference task (e.g. "text-to-image").
    task: str
    # Model identifier, if any.
    model: Optional[str]
    # JSON-serializable request body (presumably exclusive with `data` — confirm at call sites).
    json: Optional[Union[str, Dict, List]]
    # Raw binary request body.
    data: Optional[bytes]
    # HTTP headers to send with the request.
    headers: Dict[str, Any]
|
|
|
|
|
|
|
|
class MimeBytes(bytes):
    """
    A bytes object with a mime type.

    To be returned by `_prepare_payload_open_as_mime_bytes` in subclasses.

    Example:
    ```python
    >>> b = MimeBytes(b"hello", "text/plain")
    >>> isinstance(b, bytes)
    True
    >>> b.mime_type
    'text/plain'
    ```
    """

    mime_type: Optional[str]

    def __new__(cls, data: bytes, mime_type: Optional[str] = None):
        instance = super().__new__(cls, data)
        # When copying from another MimeBytes without an explicit mime type,
        # carry the source's mime type over to the new object.
        if mime_type is None and isinstance(data, MimeBytes):
            instance.mime_type = data.mime_type
        else:
            instance.mime_type = mime_type
        return instance
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _import_aiohttp():
    """Make sure `aiohttp` is installed on the machine and return the module."""
    if not is_aiohttp_available():
        raise ImportError("Please install aiohttp to use `AsyncInferenceClient` (`pip install aiohttp`).")
    # Imported lazily so the sync client works without aiohttp installed.
    import aiohttp

    return aiohttp
|
|
|
|
|
|
|
|
def _import_numpy():
    """Make sure `numpy` is installed on the machine and return the module.

    Raises:
        ImportError: if numpy is not installed.
    """
    if not is_numpy_available():
        # Fixed grammar in the user-facing message ("to use deal with" -> "to deal with").
        raise ImportError("Please install numpy to deal with embeddings (`pip install numpy`).")
    # Imported lazily so the clients work without numpy installed.
    import numpy

    return numpy
|
|
|
|
|
|
|
|
def _import_pil_image():
    """Make sure `PIL` is installed on the machine and return the `PIL.Image` module.

    Raises:
        ImportError: if Pillow is not installed.
    """
    if not is_pillow_available():
        raise ImportError(
            # Fixed grammar in the user-facing message ("to use deal with" -> "to deal with").
            "Please install Pillow to deal with images (`pip install Pillow`). If you don't want the image to be"
            " post-processed, use `client.post(...)` and get the raw response from the server."
        )
    # Imported lazily so the clients work without Pillow installed.
    from PIL import Image

    return Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@overload
def _open_as_mime_bytes(content: ContentT) -> MimeBytes: ...


@overload
def _open_as_mime_bytes(content: Literal[None]) -> Literal[None]: ...


def _open_as_mime_bytes(content: Optional[ContentT]) -> Optional[MimeBytes]:
    """Open `content` as a binary file, either from a URL, a local path, raw bytes, or a PIL Image.

    Do nothing if `content` is None.

    Returns a `MimeBytes` whose `mime_type` is set whenever it can be determined
    (Content-Type header, file extension, or PIL image format), and None otherwise.

    Raises:
        TypeError: if `content` is a text stream or an unsupported type.
        FileNotFoundError: if `content` is a string that is neither a URL nor an existing path.
    """
    if content is None:
        return None

    # NOTE: MimeBytes is a bytes subclass, so an existing MimeBytes passes through
    # here and keeps its mime type (handled in MimeBytes.__new__).
    if isinstance(content, bytes):
        return MimeBytes(content)

    if isinstance(content, (bytearray, memoryview)):
        return MimeBytes(bytes(content))

    # File-like object (duck-typed on `.read`).
    if hasattr(content, "read"):
        logger.debug("Reading content from BinaryIO")
        data = content.read()
        # Guess the mime type from the stream's filename when one is attached.
        # NOTE(review): `content.name` can be an int fd for os-level files — guess_type would then fail; confirm callers only pass named files.
        mime_type = mimetypes.guess_type(content.name)[0] if hasattr(content, "name") else None
        if isinstance(data, str):
            raise TypeError("Expected binary stream (bytes), but got text stream")
        return MimeBytes(data, mime_type=mime_type)

    if isinstance(content, str):
        # URL: download it. Prefer the server-advertised Content-Type, fall back to the URL extension.
        if content.startswith("https://") or content.startswith("http://"):
            logger.debug(f"Downloading content from {content}")
            response = get_session().get(content)
            mime_type = response.headers.get("Content-Type")
            if mime_type is None:
                mime_type = mimetypes.guess_type(content)[0]
            return MimeBytes(response.content, mime_type=mime_type)

        # Not a URL: treat the string as a local path and fall through to the Path branch below.
        content = Path(content)
        if not content.exists():
            raise FileNotFoundError(
                f"File not found at {content}. If `data` is a string, it must either be a URL or a path to a local"
                " file. To pass raw content, please encode it as bytes first."
            )

    # Local file (either passed as Path or converted from str above).
    if isinstance(content, Path):
        logger.debug(f"Opening content from {content}")
        return MimeBytes(content.read_bytes(), mime_type=mimetypes.guess_type(content)[0])

    # PIL Image: serialize it, preserving its original format when known (PNG otherwise).
    if is_pillow_available():
        from PIL import Image

        if isinstance(content, Image.Image):
            logger.debug("Converting PIL Image to bytes")
            buffer = io.BytesIO()
            format = content.format or "PNG"
            content.save(buffer, format=format)
            return MimeBytes(buffer.getvalue(), mime_type=f"image/{format.lower()}")

    raise TypeError(
        f"Unsupported content type: {type(content)}. "
        "Expected one of: bytes, bytearray, BinaryIO, memoryview, Path, str (URL or file path), or PIL.Image.Image."
    )
|
|
|
|
|
|
|
|
def _b64_encode(content: ContentT) -> str:
    """Encode a raw file (image, audio) into base64. Can be bytes, an opened file, a path or a URL."""
    return base64.b64encode(_open_as_mime_bytes(content)).decode()
|
|
|
|
|
|
|
|
def _as_url(content: ContentT, default_mime_type: str) -> str:
    """Return `content` as a URL: http(s)/data URLs pass through, anything else becomes a base64 data URL.

    `default_mime_type` is used when no mime type could be inferred from the content itself.
    """
    # Already a URL (remote or data URI): nothing to do.
    if isinstance(content, str) and content.startswith(("http://", "https://", "data:")):
        return content

    as_bytes = _open_as_mime_bytes(content)
    effective_mime = as_bytes.mime_type or default_mime_type
    b64_payload = base64.b64encode(as_bytes).decode()
    return f"data:{effective_mime};base64,{b64_payload}"
|
|
|
|
|
|
|
|
def _b64_to_image(encoded_image: str) -> "Image":
    """Decode a base64-encoded string and load it as a PIL Image."""
    Image = _import_pil_image()
    raw = base64.b64decode(encoded_image)
    return Image.open(io.BytesIO(raw))
|
|
|
|
|
|
|
|
def _bytes_to_list(content: bytes) -> List: |
|
|
"""Parse bytes from a Response object into a Python list. |
|
|
|
|
|
Expects the response body to be JSON-encoded data. |
|
|
|
|
|
NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a |
|
|
dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect. |
|
|
""" |
|
|
return json.loads(content.decode()) |
|
|
|
|
|
|
|
|
def _bytes_to_dict(content: bytes) -> Dict: |
|
|
"""Parse bytes from a Response object into a Python dictionary. |
|
|
|
|
|
Expects the response body to be JSON-encoded data. |
|
|
|
|
|
NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a |
|
|
list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect. |
|
|
""" |
|
|
return json.loads(content.decode()) |
|
|
|
|
|
|
|
|
def _bytes_to_image(content: bytes) -> "Image":
    """Parse bytes from a Response object into a PIL Image.

    Expects the response body to be raw bytes. To deal with b64 encoded images, use `_b64_to_image` instead.
    """
    pil_image_module = _import_pil_image()
    return pil_image_module.open(io.BytesIO(content))
|
|
|
|
|
|
|
|
def _as_dict(response: Union[bytes, Dict]) -> Dict: |
|
|
return json.loads(response) if isinstance(response, bytes) else response |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _stream_text_generation_response(
    bytes_output_as_lines: Iterable[bytes], details: bool
) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]:
    """Used in `InferenceClient.text_generation`."""
    for raw_line in bytes_output_as_lines:
        try:
            parsed = _format_text_generation_stream_output(raw_line, details)
        except StopIteration:
            # Server signalled end-of-stream ("[DONE]"): stop yielding.
            return
        # Non-data SSE lines parse to None and are skipped.
        if parsed is not None:
            yield parsed
|
|
|
|
|
|
|
|
async def _async_stream_text_generation_response(
    bytes_output_as_lines: AsyncIterable[bytes], details: bool
) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
    """Used in `AsyncInferenceClient.text_generation`."""
    async for raw_line in bytes_output_as_lines:
        try:
            parsed = _format_text_generation_stream_output(raw_line, details)
        except StopIteration:
            # Server signalled end-of-stream ("[DONE]"): stop yielding.
            return
        # Non-data SSE lines parse to None and are skipped.
        if parsed is not None:
            yield parsed
|
|
|
|
|
|
|
|
def _format_text_generation_stream_output(
    byte_payload: bytes, details: bool
) -> Optional[Union[str, TextGenerationStreamOutput]]:
    """Parse one server-sent-event line from a text-generation stream.

    Returns None for lines that carry no data (comments, keep-alives), and otherwise
    either the generated token text (`details=False`) or the full
    `TextGenerationStreamOutput` (`details=True`).

    Raises:
        StopIteration: when the server sends the "[DONE]" end-of-stream sentinel.
        TextGenerationError: when the payload reports a server-side error.
    """
    # Skip anything that is not an SSE data line.
    if not byte_payload.startswith(b"data:"):
        return None

    # End-of-stream sentinel.
    if byte_payload.strip() == b"data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Strip the exact "data:" prefix and surrounding whitespace before decoding JSON.
    # Fix: the previous `lstrip("data:").rstrip("/n")` stripped *character sets*
    # ({'d','a','t',':'} and {'/','n'}) rather than the literal prefix and a
    # trailing newline ("/n" was a typo for "\n").
    payload = byte_payload.decode("utf-8")
    json_payload = json.loads(payload[len("data:") :].strip())

    # Either an error is being returned...
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # ...or a token payload to parse.
    output = TextGenerationStreamOutput.parse_obj_as_instance(json_payload)
    return output.token.text if not details else output
|
|
|
|
|
|
|
|
def _stream_chat_completion_response(
    bytes_lines: Iterable[bytes],
) -> Iterable[ChatCompletionStreamOutput]:
    """Used in `InferenceClient.chat_completion` if model is served with TGI."""
    for raw_line in bytes_lines:
        try:
            chunk = _format_chat_completion_stream_output(raw_line)
        except StopIteration:
            # Server signalled end-of-stream ("[DONE]"): stop yielding.
            return
        # Non-data SSE lines parse to None and are skipped.
        if chunk is not None:
            yield chunk
|
|
|
|
|
|
|
|
async def _async_stream_chat_completion_response(
    bytes_lines: AsyncIterable[bytes],
) -> AsyncIterable[ChatCompletionStreamOutput]:
    """Used in `AsyncInferenceClient.chat_completion`."""
    async for raw_line in bytes_lines:
        try:
            chunk = _format_chat_completion_stream_output(raw_line)
        except StopIteration:
            # Server signalled end-of-stream ("[DONE]"): stop yielding.
            return
        # Non-data SSE lines parse to None and are skipped.
        if chunk is not None:
            yield chunk
|
|
|
|
|
|
|
|
def _format_chat_completion_stream_output(
    byte_payload: bytes,
) -> Optional[ChatCompletionStreamOutput]:
    """Parse one server-sent-event line from a chat-completion stream.

    Returns None for lines that carry no data (comments, keep-alives), otherwise the
    parsed `ChatCompletionStreamOutput` chunk.

    Raises:
        StopIteration: when the server sends the "[DONE]" end-of-stream sentinel.
        TextGenerationError: when the payload reports a server-side error.
    """
    # Skip anything that is not an SSE data line.
    if not byte_payload.startswith(b"data:"):
        return None

    # End-of-stream sentinel.
    if byte_payload.strip() == b"data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Strip the exact "data:" prefix and surrounding whitespace before decoding JSON.
    # Fix: the previous `lstrip("data:").rstrip("/n")` stripped *character sets*
    # ({'d','a','t',':'} and {'/','n'}) rather than the literal prefix and a
    # trailing newline ("/n" was a typo for "\n").
    payload = byte_payload.decode("utf-8")
    json_payload = json.loads(payload[len("data:") :].strip())

    # Either an error is being returned...
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # ...or a chat-completion chunk to parse.
    return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)
|
|
|
|
|
|
|
|
async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]:
    """Yield stripped lines from an aiohttp response body, closing the session afterwards.

    The `finally` guarantees the session is closed whether the stream is exhausted,
    raises, or the consumer stops iterating early (generator close).
    """
    try:
        async for byte_payload in response.content:
            yield byte_payload.strip()
    finally:
        # Always release the underlying connection.
        await client.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_UNSUPPORTED_TEXT_GENERATION_KWARGS: Dict[Optional[str], List[str]] = {} |
|
|
|
|
|
|
|
|
def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: List[str]) -> None: |
|
|
_UNSUPPORTED_TEXT_GENERATION_KWARGS.setdefault(model, []).extend(unsupported_kwargs) |
|
|
|
|
|
|
|
|
def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]: |
|
|
return _UNSUPPORTED_TEXT_GENERATION_KWARGS.get(model, []) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
    """
    Try to parse text-generation-inference error message and raise HTTPError in any case.

    Args:
        http_error (`HTTPError`):
            The HTTPError that has been raised.

    Raises:
        A `TextGenerationError` subclass when the response carries a typed error payload,
        otherwise the original `http_error`.
    """
    # Try to extract a TGI-style error payload: prefer a pre-parsed payload attached
    # to the exception, fall back to decoding the response body as JSON.
    try:
        payload = getattr(http_error, "response_error_payload", None) or http_error.response.json()
        error = payload.get("error")
        error_type = payload.get("error_type")
    except Exception:
        # No parseable payload: nothing smarter to raise than the original error.
        raise http_error

    # Typed error reported by the server: raise the matching exception, chaining
    # the HTTP error as the cause for debuggability.
    if error_type is not None:
        exception = _parse_text_generation_error(error, error_type)
        raise exception from http_error

    # Otherwise, re-raise the original HTTP error as-is.
    raise http_error
|
|
|
|
|
|
|
|
def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
    """Map a server-reported error type to the matching `TextGenerationError` subclass instance."""
    exception_classes = {
        "generation": GenerationError,
        "incomplete_generation": IncompleteGenerationError,
        "overloaded": OverloadedError,
        "validation": ValidationError,
    }
    # Unknown or missing error types fall back to UnknownError.
    return exception_classes.get(error_type, UnknownError)(error)
|
|
|