| import logging |
| import time |
| import warnings |
| from dataclasses import asdict |
| from typing import ( |
| TYPE_CHECKING, |
| Any, |
| Dict, |
| Iterable, |
| List, |
| Literal, |
| Optional, |
| Union, |
| overload, |
| ) |
|
|
| from requests import HTTPError |
| from requests.structures import CaseInsensitiveDict |
|
|
| from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS |
| from huggingface_hub.inference._common import ( |
| TASKS_EXPECTING_IMAGES, |
| ContentT, |
| InferenceTimeoutError, |
| ModelStatus, |
| _b64_encode, |
| _b64_to_image, |
| _bytes_to_dict, |
| _bytes_to_image, |
| _bytes_to_list, |
| _fetch_recommended_models, |
| _import_numpy, |
| _is_tgi_server, |
| _open_as_binary, |
| _set_as_non_tgi, |
| _stream_text_generation_response, |
| ) |
| from huggingface_hub.inference._text_generation import ( |
| TextGenerationParameters, |
| TextGenerationRequest, |
| TextGenerationResponse, |
| TextGenerationStreamResponse, |
| raise_text_generation_error, |
| ) |
| from huggingface_hub.inference._types import ( |
| ClassificationOutput, |
| ConversationalOutput, |
| FillMaskOutput, |
| ImageSegmentationOutput, |
| ObjectDetectionOutput, |
| QuestionAnsweringOutput, |
| TableQuestionAnsweringOutput, |
| TokenClassificationOutput, |
| ) |
| from huggingface_hub.utils import ( |
| BadRequestError, |
| build_hf_headers, |
| get_session, |
| hf_raise_for_status, |
| ) |
|
|
|
|
| if TYPE_CHECKING: |
| import numpy as np |
| from PIL import Image |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| class InferenceClient: |
| """ |
| Initialize a new Inference Client. |
| |
| [`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used |
| seamlessly with either the (free) Inference API or self-hosted Inference Endpoints. |
| |
| Args: |
| model (`str`, *optional*): |
| The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `bigcode/starcoder` |
| or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is |
| automatically selected for the task. |
| token (`str`, *optional*): |
| Hugging Face token. Will default to the locally saved token. Pass `token=False` if you don't want to send |
| your token to the server. |
| timeout (`float`, *optional*): |
| The maximum number of seconds to wait for a response from the server. Loading a new model in Inference |
| API can take up to several minutes. Defaults to None, meaning it will loop until the server is available. |
| headers (`Dict[str, str]`, *optional*): |
| Additional headers to send to the server. By default only the authorization and user-agent headers are sent. |
| Values in this dictionary will override the default values. |
| cookies (`Dict[str, str]`, *optional*): |
| Additional cookies to send to the server. |
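|
| Example (a minimal usage sketch; the same call is documented in [`~InferenceClient.text_to_image`] below): |
|
| ```python |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> image = client.text_to_image("An astronaut riding a horse on the moon.") |
| >>> image.save("astronaut.png") |
| ``` |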
| """ |
|
|
| def __init__( |
| self, |
| model: Optional[str] = None, |
| token: Union[str, bool, None] = None, |
| timeout: Optional[float] = None, |
| headers: Optional[Dict[str, str]] = None, |
| cookies: Optional[Dict[str, str]] = None, |
| ) -> None: |
| self.model: Optional[str] = model |
| self.headers = CaseInsensitiveDict(build_hf_headers(token=token)) |
| if headers is not None: |
| self.headers.update(headers) |
| self.cookies = cookies |
| self.timeout = timeout |
|
|
| def __repr__(self): |
| return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>" |
|
|
| @overload |
| def post( |
| self, |
| *, |
| json: Optional[Union[str, Dict, List]] = None, |
| data: Optional[ContentT] = None, |
| model: Optional[str] = None, |
| task: Optional[str] = None, |
| stream: Literal[False] = ..., |
| ) -> bytes: |
| pass |
|
|
| @overload |
| def post( |
| self, |
| *, |
| json: Optional[Union[str, Dict, List]] = None, |
| data: Optional[ContentT] = None, |
| model: Optional[str] = None, |
| task: Optional[str] = None, |
| stream: Literal[True] = ..., |
| ) -> Iterable[bytes]: |
| pass |
|
|
| def post( |
| self, |
| *, |
| json: Optional[Union[str, Dict, List]] = None, |
| data: Optional[ContentT] = None, |
| model: Optional[str] = None, |
| task: Optional[str] = None, |
| stream: bool = False, |
| ) -> Union[bytes, Iterable[bytes]]: |
| """ |
| Make a POST request to the inference server. |
| |
| Args: |
| json (`Union[str, Dict, List]`, *optional*): |
| The JSON data to send in the request body, specific to each task. Defaults to None. |
| data (`Union[str, Path, bytes, BinaryIO]`, *optional*): |
| The content to send in the request body, specific to each task. |
| It can be raw bytes, a pointer to an opened file, a local file path, |
| or a URL to an online resource (image, audio file,...). If both `json` and `data` are passed, |
| `data` will take precedence. At least `json` or `data` must be provided. Defaults to None. |
| model (`str`, *optional*): |
| The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed |
| Inference Endpoint. Will override the model defined at the instance level. Defaults to None. |
| task (`str`, *optional*): |
| The task to perform. All available tasks can be found |
| [here](https://huggingface.co/tasks). Used only to default to a recommended model if `model` is not |
| provided. At least `model` or `task` must be provided. Defaults to None. |
| stream (`bool`, *optional*): |
| Whether to iterate over streaming APIs. |
| |
| Returns: |
| bytes: The raw bytes returned by the server. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
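|
| Example (a minimal sketch; the exact payload format is task-specific, see the task helpers below): |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> raw = client.post(json={"inputs": "The goal of life is <mask>."}, task="fill-mask") |
| >>> type(raw) |
| <class 'bytes'> |
| ``` |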
| """ |
| url = self._resolve_url(model, task) |
|
|
| if data is not None and json is not None: |
| warnings.warn("Ignoring `json` as `data` is passed as binary.") |
|
|
| |
| headers = self.headers.copy() |
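| # Request PNG output by default for image-producing tasks, unless the caller already set an Accept header. |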
| if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers: |
| headers["Accept"] = "image/png" |
|
|
| t0 = time.time() |
| timeout = self.timeout |
| while True: |
| with _open_as_binary(data) as data_as_binary: |
| try: |
| response = get_session().post( |
| url, |
| json=json, |
| data=data_as_binary, |
| headers=headers, |
| cookies=self.cookies, |
| timeout=self.timeout, |
| stream=stream, |
| ) |
| except TimeoutError as error: |
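| # Convert a TimeoutError raised by the HTTP call into an `InferenceTimeoutError`. |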
| |
| raise InferenceTimeoutError(f"Inference call timed out: {url}") from error |
|
|
| try: |
| hf_raise_for_status(response) |
| return response.iter_lines() if stream else response.content |
| except HTTPError as error: |
| if error.response.status_code == 422 and task is not None: |
| error.args = ( |
| f"{error.args[0]}\nMake sure '{task}' task is supported by the model.", |
| ) + error.args[1:] |
| if error.response.status_code == 503: |
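| # 503 means the model is not loaded yet on the server: wait and retry until `self.timeout` is exceeded. |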
| |
| if timeout is not None and time.time() - t0 > timeout: |
| raise InferenceTimeoutError( |
| f"Model not loaded on the server: {url}. Please retry with a higher timeout (current:" |
| f" {self.timeout}).", |
| request=error.request, |
| response=error.response, |
| ) from error |
| |
| logger.info(f"Waiting for model to be loaded on the server: {error}") |
| time.sleep(1) |
| if timeout is not None: |
| timeout = max(self.timeout - (time.time() - t0), 1) |
| continue |
| raise |
|
|
| def audio_classification( |
| self, |
| audio: ContentT, |
| *, |
| model: Optional[str] = None, |
| ) -> List[ClassificationOutput]: |
| """ |
| Perform audio classification on the provided audio content. |
| |
| Args: |
| audio (Union[str, Path, bytes, BinaryIO]): |
| The audio content to classify. It can be raw audio bytes, a local audio file, or a URL pointing to an |
| audio file. |
| model (`str`, *optional*): |
| The model to use for audio classification. Can be a model ID hosted on the Hugging Face Hub |
| or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for |
| audio classification will be used. |
| |
| Returns: |
| `List[Dict]`: The classification output containing the predicted label and its confidence. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.audio_classification("audio.flac") |
| [{'score': 0.4976358711719513, 'label': 'hap'}, {'score': 0.3677836060523987, 'label': 'neu'},...] |
| ``` |
| """ |
| response = self.post(data=audio, model=model, task="audio-classification") |
| return _bytes_to_list(response) |
|
|
| def automatic_speech_recognition( |
| self, |
| audio: ContentT, |
| *, |
| model: Optional[str] = None, |
| ) -> str: |
| """ |
| Perform automatic speech recognition (ASR or audio-to-text) on the given audio content. |
| |
| Args: |
| audio (Union[str, Path, bytes, BinaryIO]): |
| The content to transcribe. It can be raw audio bytes, a local audio file, or a URL to an audio file. |
| model (`str`, *optional*): |
| The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed |
| Inference Endpoint. If not provided, the default recommended model for ASR will be used. |
| |
| Returns: |
| str: The transcribed text. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.automatic_speech_recognition("hello_world.flac") |
| "hello world" |
| ``` |
| """ |
| response = self.post(data=audio, model=model, task="automatic-speech-recognition") |
| return _bytes_to_dict(response)["text"] |
|
|
| def conversational( |
| self, |
| text: str, |
| generated_responses: Optional[List[str]] = None, |
| past_user_inputs: Optional[List[str]] = None, |
| *, |
| parameters: Optional[Dict[str, Any]] = None, |
| model: Optional[str] = None, |
| ) -> ConversationalOutput: |
| """ |
| Generate conversational responses based on the given input text (i.e. chat with the API). |
| |
| Args: |
| text (`str`): |
| The last input from the user in the conversation. |
| generated_responses (`List[str]`, *optional*): |
| A list of strings corresponding to the earlier replies from the model. Defaults to None. |
| past_user_inputs (`List[str]`, *optional*): |
| A list of strings corresponding to the earlier replies from the user. Should be the same length as |
| `generated_responses`. Defaults to None. |
| parameters (`Dict[str, Any]`, *optional*): |
| Additional parameters for the conversational task. Defaults to None. For more details about the available |
| parameters, please refer to [this page](https://huggingface.co/docs/api-inference/detailed_parameters#conversational-task) |
| model (`str`, *optional*): |
| The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to |
| a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used. |
| Defaults to None. |
| |
| Returns: |
| `Dict`: The generated conversational output. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> output = client.conversational("Hi, who are you?") |
| >>> output |
| {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.']} |
| >>> client.conversational( |
| ... "Wow, that's scary!", |
| ... generated_responses=output["conversation"]["generated_responses"], |
| ... past_user_inputs=output["conversation"]["past_user_inputs"], |
| ... ) |
| ``` |
| """ |
| payload: Dict[str, Any] = {"inputs": {"text": text}} |
| if generated_responses is not None: |
| payload["inputs"]["generated_responses"] = generated_responses |
| if past_user_inputs is not None: |
| payload["inputs"]["past_user_inputs"] = past_user_inputs |
| if parameters is not None: |
| payload["parameters"] = parameters |
| response = self.post(json=payload, model=model, task="conversational") |
| return _bytes_to_dict(response) |
|
|
| def visual_question_answering( |
| self, |
| image: ContentT, |
| question: str, |
| *, |
| model: Optional[str] = None, |
| ) -> List[str]: |
| """ |
| Answer open-ended questions based on an image. |
| |
| Args: |
| image (`Union[str, Path, bytes, BinaryIO]`): |
| The input image for the context. It can be raw bytes, an image file, or a URL to an online image. |
| question (`str`): |
| Question to be answered. |
| model (`str`, *optional*): |
| The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to |
| a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used. |
| Defaults to None. |
| |
| Returns: |
| `List[Dict]`: a list of dictionaries containing the predicted label and associated probability. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.visual_question_answering( |
| ... image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg", |
| ... question="What is the animal doing?" |
| ... ) |
| [{'score': 0.778609573841095, 'answer': 'laying down'},{'score': 0.6957435607910156, 'answer': 'sitting'}, ...] |
| ``` |
| """ |
| payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} |
| response = self.post(json=payload, model=model, task="visual-question-answering") |
| return _bytes_to_list(response) |
|
|
| def document_question_answering( |
| self, |
| image: ContentT, |
| question: str, |
| *, |
| model: Optional[str] = None, |
| ) -> List[QuestionAnsweringOutput]: |
| """ |
| Answer questions on document images. |
| |
| Args: |
| image (`Union[str, Path, bytes, BinaryIO]`): |
| The input image for the context. It can be raw bytes, an image file, or a URL to an online image. |
| question (`str`): |
| Question to be answered. |
| model (`str`, *optional*): |
| The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to |
| a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used. |
| Defaults to None. |
| |
| Returns: |
| `List[Dict]`: a list of dictionaries containing the predicted label, associated probability, word ids, and page number. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?") |
| [{'score': 0.42515629529953003, 'answer': 'us-001', 'start': 16, 'end': 16}] |
| ``` |
| """ |
| payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} |
| response = self.post(json=payload, model=model, task="document-question-answering") |
| return _bytes_to_list(response) |
|
|
| def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray": |
| """ |
| Generate embeddings for a given text. |
| |
| Args: |
| text (`str`): |
| The text to embed. |
| model (`str`, *optional*): |
| The model to use for the feature extraction task. Can be a model ID hosted on the Hugging Face Hub or a URL to |
| a deployed Inference Endpoint. If not provided, the default recommended feature extraction model will be used. |
| Defaults to None. |
| |
| Returns: |
| `np.ndarray`: The embedding representing the input text as a float32 numpy array. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.feature_extraction("Hi, who are you?") |
| array([[ 2.424802 , 2.93384 , 1.1750331 , ..., 1.240499, -0.13776633, -0.7889173 ], |
| [-0.42943227, -0.6364878 , -1.693462 , ..., 0.41978157, -2.4336355 , 0.6162071 ], |
| ..., |
| [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32) |
| ``` |
| """ |
| response = self.post(json={"inputs": text}, model=model, task="feature-extraction") |
| np = _import_numpy() |
| return np.array(_bytes_to_dict(response), dtype="float32") |
|
|
| def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskOutput]: |
| """ |
| Fill in a hole with a missing word (token to be precise). |
| |
| Args: |
| text (`str`): |
| A string to be filled in; it must contain the [MASK] token (check the model card for the exact name of the mask token). |
| model (`str`, *optional*): |
| The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to |
| a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used. |
| Defaults to None. |
| |
| Returns: |
| `List[Dict]`: a list of fill mask output dictionaries containing the predicted label, associated |
| probability, token reference, and completed text. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.fill_mask("The goal of life is <mask>.") |
| [{'score': 0.06897063553333282, |
| 'token': 11098, |
| 'token_str': ' happiness', |
| 'sequence': 'The goal of life is happiness.'}, |
| {'score': 0.06554922461509705, |
| 'token': 45075, |
| 'token_str': ' immortality', |
| 'sequence': 'The goal of life is immortality.'}] |
| ``` |
| """ |
| response = self.post(json={"inputs": text}, model=model, task="fill-mask") |
| return _bytes_to_list(response) |
|
|
| def image_classification( |
| self, |
| image: ContentT, |
| *, |
| model: Optional[str] = None, |
| ) -> List[ClassificationOutput]: |
| """ |
| Perform image classification on the given image using the specified model. |
| |
| Args: |
| image (`Union[str, Path, bytes, BinaryIO]`): |
| The image to classify. It can be raw bytes, an image file, or a URL to an online image. |
| model (`str`, *optional*): |
| The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a |
| deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used. |
| |
| Returns: |
| `List[Dict]`: a list of dictionaries containing the predicted label and associated probability. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") |
| [{'score': 0.9779096841812134, 'label': 'Blenheim spaniel'}, ...] |
| ``` |
| """ |
| response = self.post(data=image, model=model, task="image-classification") |
| return _bytes_to_list(response) |
|
|
| def image_segmentation( |
| self, |
| image: ContentT, |
| *, |
| model: Optional[str] = None, |
| ) -> List[ImageSegmentationOutput]: |
| """ |
| Perform image segmentation on the given image using the specified model. |
| |
| <Tip warning={true}> |
| |
| You must have `PIL` installed if you want to work with images (`pip install Pillow`). |
| |
| </Tip> |
| |
| Args: |
| image (`Union[str, Path, bytes, BinaryIO]`): |
| The image to segment. It can be raw bytes, an image file, or a URL to an online image. |
| model (`str`, *optional*): |
| The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a |
| deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used. |
| |
| Returns: |
| `List[Dict]`: A list of dictionaries containing the segmented masks and associated attributes. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.image_segmentation("cat.jpg") |
| [{'score': 0.989008, 'label': 'LABEL_184', 'mask': <PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>}, ...] |
| ``` |
| """ |
|
|
| |
| response = self.post(data=image, model=model, task="image-segmentation") |
| output = _bytes_to_dict(response) |
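| # The server returns a list of segments; validate the output and decode each base64-encoded mask into a `PIL.Image`. |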
|
|
| |
| if not isinstance(output, list): |
| raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...") |
| for item in output: |
| item["mask"] = _b64_to_image(item["mask"]) |
| return output |
|
|
| def image_to_image( |
| self, |
| image: ContentT, |
| prompt: Optional[str] = None, |
| *, |
| negative_prompt: Optional[str] = None, |
| height: Optional[int] = None, |
| width: Optional[int] = None, |
| num_inference_steps: Optional[int] = None, |
| guidance_scale: Optional[float] = None, |
| model: Optional[str] = None, |
| **kwargs, |
| ) -> "Image": |
| """ |
| Perform image-to-image translation using a specified model. |
| |
| <Tip warning={true}> |
| |
| You must have `PIL` installed if you want to work with images (`pip install Pillow`). |
| |
| </Tip> |
| |
| Args: |
| image (`Union[str, Path, bytes, BinaryIO]`): |
| The input image for translation. It can be raw bytes, an image file, or a URL to an online image. |
| prompt (`str`, *optional*): |
| The text prompt to guide the image generation. |
| negative_prompt (`str`, *optional*): |
| A negative prompt to guide the translation process. |
| height (`int`, *optional*): |
| The height in pixels of the generated image. |
| width (`int`, *optional*): |
| The width in pixels of the generated image. |
| num_inference_steps (`int`, *optional*): |
| The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
| expense of slower inference. |
| guidance_scale (`float`, *optional*): |
| A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, |
| usually at the expense of lower image quality. |
| model (`str`, *optional*): |
| The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed |
| Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. |
| |
| Returns: |
| `Image`: The translated image. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> image = client.image_to_image("cat.jpg", prompt="turn the cat into a tiger") |
| >>> image.save("tiger.jpg") |
| ``` |
| """ |
| parameters = { |
| "prompt": prompt, |
| "negative_prompt": negative_prompt, |
| "height": height, |
| "width": width, |
| "num_inference_steps": num_inference_steps, |
| "guidance_scale": guidance_scale, |
| **kwargs, |
| } |
| if all(parameter is None for parameter in parameters.values()): |
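| # No extra parameters: send the image as raw binary data. |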
| |
| data = image |
| payload: Optional[Dict[str, Any]] = None |
| else: |
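| # Parameters are set: the image must be base64-encoded inside a JSON payload. |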
| |
| data = None |
| payload = {"inputs": _b64_encode(image)} |
| for key, value in parameters.items(): |
| if value is not None: |
| payload.setdefault("parameters", {})[key] = value |
|
|
| response = self.post(json=payload, data=data, model=model, task="image-to-image") |
| return _bytes_to_image(response) |
|
|
| def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> str: |
| """ |
| Takes an input image and returns text. |
| |
| Models can have very different outputs depending on your use case (image captioning, optical character recognition |
| (OCR), Pix2Struct, etc.). Please have a look at the model card to learn more about a model's specificities. |
| |
| Args: |
| image (`Union[str, Path, bytes, BinaryIO]`): |
| The input image to caption. It can be raw bytes, an image file, or a URL to an online image. |
| model (`str`, *optional*): |
| The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed |
| Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. |
| |
| Returns: |
| `str`: The generated text. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.image_to_text("cat.jpg") |
| 'a cat standing in a grassy field ' |
| >>> client.image_to_text("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") |
| 'a dog laying on the grass next to a flower pot ' |
| ``` |
| """ |
| response = self.post(data=image, model=model, task="image-to-text") |
| return _bytes_to_dict(response)[0]["generated_text"] |
|
|
| def list_deployed_models( |
| self, frameworks: Union[None, str, Literal["all"], List[str]] = None |
| ) -> Dict[str, List[str]]: |
| """ |
| List models currently deployed on the Inference API service. |
| |
| This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that |
| are supported and account for 95% of the hosted models. However, if you want a complete list of models, you can |
| specify `frameworks="all"` as input. Alternatively, if you know beforehand which framework you are interested |
| in, you can also restrict the search to that framework (e.g. `frameworks="text-generation-inference"`). The more |
| frameworks that are checked, the longer it will take. |
| |
| <Tip> |
| |
| This endpoint is mostly useful for discoverability. If you already know which model you want to use and want to |
| check its availability, you can directly use [`~InferenceClient.get_model_status`]. |
| |
| </Tip> |
| |
| Args: |
| frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*): |
| The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to |
| "all", all available frameworks will be tested. It is also possible to provide a single framework or a |
| custom set of frameworks to check. |
| |
| Returns: |
| `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs. |
| |
| Example: |
| ```python |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| |
| # Discover zero-shot-classification models currently deployed |
| >>> models = client.list_deployed_models() |
| >>> models["zero-shot-classification"] |
| ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...] |
| |
| # List from only 1 framework |
| >>> client.list_deployed_models("text-generation-inference") |
| {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...} |
| ``` |
| """ |
| |
| if frameworks is None: |
| frameworks = MAIN_INFERENCE_API_FRAMEWORKS |
| elif frameworks == "all": |
| frameworks = ALL_INFERENCE_API_FRAMEWORKS |
| elif isinstance(frameworks, str): |
| frameworks = [frameworks] |
| frameworks = list(set(frameworks)) |
|
|
| |
| models_by_task: Dict[str, List[str]] = {} |
|
|
| def _unpack_response(framework: str, items: List[Dict]) -> None: |
| for model in items: |
| if framework == "sentence-transformers": |
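| # `sentence-transformers` models are not tied to a single task: list them under both |
| # feature-extraction and sentence-similarity. |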
| |
| |
| models_by_task.setdefault("feature-extraction", []).append(model["model_id"]) |
| models_by_task.setdefault("sentence-similarity", []).append(model["model_id"]) |
| else: |
| models_by_task.setdefault(model["task"], []).append(model["model_id"]) |
|
|
| for framework in frameworks: |
| response = get_session().get(f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=self.headers) |
| hf_raise_for_status(response) |
| _unpack_response(framework, response.json()) |
|
|
| |
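| # Deduplicate and sort model ids alphabetically (case-insensitive) for readability. |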
| for task, models in models_by_task.items(): |
| models_by_task[task] = sorted(set(models), key=lambda x: x.lower()) |
| return models_by_task |
|
|
| def object_detection( |
| self, |
| image: ContentT, |
| *, |
| model: Optional[str] = None, |
| ) -> List[ObjectDetectionOutput]: |
| """ |
| Perform object detection on the given image using the specified model. |
| |
| <Tip warning={true}> |
| |
| You must have `PIL` installed if you want to work with images (`pip install Pillow`). |
| |
| </Tip> |
| |
| Args: |
| image (`Union[str, Path, bytes, BinaryIO]`): |
| The image to detect objects on. It can be raw bytes, an image file, or a URL to an online image. |
| model (`str`, *optional*): |
| The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a |
| deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used. |
| |
| Returns: |
| `List[ObjectDetectionOutput]`: A list of dictionaries containing the bounding boxes and associated attributes. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| `ValueError`: |
| If the request output is not a List. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.object_detection("people.jpg") |
| [{"score":0.9486683011054993,"label":"person","box":{"xmin":59,"ymin":39,"xmax":420,"ymax":510}}, ... ] |
| ``` |
| """ |
| |
| response = self.post(data=image, model=model, task="object-detection") |
| output = _bytes_to_dict(response) |
| if not isinstance(output, list): |
| raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...") |
| return output |
|
|
| def question_answering( |
| self, question: str, context: str, *, model: Optional[str] = None |
| ) -> QuestionAnsweringOutput: |
| """ |
| Retrieve the answer to a question from a given text. |
| |
| Args: |
| question (`str`): |
| Question to be answered. |
| context (`str`): |
| The context of the question. |
| model (`str`, *optional*): |
| The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to |
| a deployed Inference Endpoint. If not provided, the default recommended question answering model will be used. |
| |
| Returns: |
| `Dict`: a dictionary of question answering output containing the score, start index, end index, and answer. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.") |
| {'score': 0.9326562285423279, 'start': 11, 'end': 16, 'answer': 'Clara'} |
| ``` |
| """ |
|
|
| payload: Dict[str, Any] = {"question": question, "context": context} |
| response = self.post( |
| json=payload, |
| model=model, |
| task="question-answering", |
| ) |
| return _bytes_to_dict(response) |
|
|
| def sentence_similarity( |
| self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None |
| ) -> List[float]: |
| """ |
| Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings. |
| |
| Args: |
| sentence (`str`): |
| The main sentence to compare to others. |
| other_sentences (`List[str]`): |
| The list of sentences to compare to. |
| model (`str`, *optional*): |
| The model to use for the sentence similarity task. Can be a model ID hosted on the Hugging Face Hub or a URL to |
| a deployed Inference Endpoint. If not provided, the default recommended sentence similarity model will be used. |
| Defaults to None. |
| |
| Returns: |
| `List[float]`: The similarity scores computed between the source sentence and each sentence in `other_sentences`. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.sentence_similarity( |
| ... "Machine learning is so easy.", |
| ... other_sentences=[ |
| ... "Deep learning is so straightforward.", |
| ... "This is so difficult, like rocket science.", |
| ... "I can't believe how much I struggled with this.", |
| ... ], |
| ... ) |
| [0.7785726189613342, 0.45876261591911316, 0.2906220555305481] |
| ``` |
| """ |
| response = self.post( |
| json={"inputs": {"source_sentence": sentence, "sentences": other_sentences}}, |
| model=model, |
| task="sentence-similarity", |
| ) |
| return _bytes_to_list(response) |
|
|
| def summarization( |
| self, |
| text: str, |
| *, |
| parameters: Optional[Dict[str, Any]] = None, |
| model: Optional[str] = None, |
| ) -> str: |
| """ |
| Generate a summary of a given text using a specified model. |
| |
| Args: |
| text (`str`): |
| The input text to summarize. |
| parameters (`Dict[str, Any]`, *optional*): |
| Additional parameters for summarization. Check out this [page](https://huggingface.co/docs/api-inference/detailed_parameters#summarization-task) |
| for more details. |
| model (`str`, *optional*): |
| The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed |
| Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. |
| |
| Returns: |
| `str`: The generated summary text. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.summarization("The Eiffel tower...") |
| 'The Eiffel tower is one of the most famous landmarks in the world....' |
| ``` |
| """ |
| payload: Dict[str, Any] = {"inputs": text} |
| if parameters is not None: |
| payload["parameters"] = parameters |
| response = self.post(json=payload, model=model, task="summarization") |
| return _bytes_to_dict(response)[0]["summary_text"] |
|
|
| def table_question_answering( |
| self, table: Dict[str, Any], query: str, *, model: Optional[str] = None |
| ) -> TableQuestionAnsweringOutput: |
| """ |
| Retrieve the answer to a question from information given in a table. |
| |
| Args: |
| table (`Dict[str, Any]`): |
| A table of data represented as a dict of lists, where keys are the column headers and the lists hold the |
| column values; all lists must have the same size. |
| query (`str`): |
| The query in plain text that you want to ask the table. |
| model (`str`, *optional*): |
| The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face |
| Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended table question answering model will be used. |
| |
| Returns: |
| `Dict`: a dictionary of table question answering output containing the answer, coordinates, cells and the aggregator used. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> query = "How many stars does the transformers repository have?" |
| >>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]} |
| >>> client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq") |
| {'answer': 'AVERAGE > 36542', 'coordinates': [[0, 1]], 'cells': ['36542'], 'aggregator': 'AVERAGE'} |
| ``` |
| """ |
| response = self.post( |
| json={ |
| "query": query, |
| "table": table, |
| }, |
| model=model, |
| task="table-question-answering", |
| ) |
| return _bytes_to_dict(response) |
|
|
| def tabular_classification(self, table: Dict[str, Any], *, model: str) -> List[str]: |
| """ |
| Classify a target category (a group) based on a set of attributes. |
| |
| Args: |
| table (`Dict[str, Any]`): |
| Set of attributes to classify. |
| model (`str`): |
| The model to use for the tabular-classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to |
| a deployed Inference Endpoint. |
| |
| Returns: |
| `List`: a list of labels, one per row in the initial table. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> table = { |
| ... "fixed_acidity": ["7.4", "7.8", "10.3"], |
| ... "volatile_acidity": ["0.7", "0.88", "0.32"], |
| ... "citric_acid": ["0", "0", "0.45"], |
| ... "residual_sugar": ["1.9", "2.6", "6.4"], |
| ... "chlorides": ["0.076", "0.098", "0.073"], |
| ... "free_sulfur_dioxide": ["11", "25", "5"], |
| ... "total_sulfur_dioxide": ["34", "67", "13"], |
| ... "density": ["0.9978", "0.9968", "0.9976"], |
| ... "pH": ["3.51", "3.2", "3.23"], |
| ... "sulphates": ["0.56", "0.68", "0.82"], |
| ... "alcohol": ["9.4", "9.8", "12.6"], |
| ... } |
| >>> client.tabular_classification(table=table, model="julien-c/wine-quality") |
| ["5", "5", "5"] |
| ``` |
| """ |
| response = self.post(json={"table": table}, model=model, task="tabular-classification") |
| return _bytes_to_list(response) |
|
|
| def tabular_regression(self, table: Dict[str, Any], *, model: str) -> List[float]: |
| """ |
| Predict a numerical target value given a set of attributes/features in a table. |
| |
| Args: |
| table (`Dict[str, Any]`): |
| Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical. |
| model (`str`): |
| The model to use for the tabular-regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to |
| a deployed Inference Endpoint. |
| |
| Returns: |
| `List`: a list of predicted numerical target values. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> table = { |
| ... "Height": ["11.52", "12.48", "12.3778"], |
| ... "Length1": ["23.2", "24", "23.9"], |
| ... "Length2": ["25.4", "26.3", "26.5"], |
| ... "Length3": ["30", "31.2", "31.1"], |
| ... "Species": ["Bream", "Bream", "Bream"], |
| ... "Width": ["4.02", "4.3056", "4.6961"], |
| ... } |
| >>> client.tabular_regression(table, model="scikit-learn/Fish-Weight") |
| [110, 120, 130] |
| ``` |
| """ |
| response = self.post(json={"table": table}, model=model, task="tabular-regression") |
| return _bytes_to_list(response) |
|
|
| def text_classification(self, text: str, *, model: Optional[str] = None) -> List[ClassificationOutput]: |
| """ |
| Perform text classification (e.g. sentiment-analysis) on the given text. |
| |
| Args: |
| text (`str`): |
| A string to be classified. |
| model (`str`, *optional*): |
| The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to |
| a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used. |
| Defaults to None. |
| |
| Returns: |
| `List[Dict]`: a list of dictionaries containing the predicted label and associated probability. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.text_classification("I like you") |
| [{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}] |
| ``` |
| """ |
| response = self.post(json={"inputs": text}, model=model, task="text-classification") |
| return _bytes_to_list(response)[0] |
|
|
| @overload |
| def text_generation( |
| self, |
| prompt: str, |
| *, |
| details: Literal[False] = ..., |
| stream: Literal[False] = ..., |
| model: Optional[str] = None, |
| do_sample: bool = False, |
| max_new_tokens: int = 20, |
| best_of: Optional[int] = None, |
| repetition_penalty: Optional[float] = None, |
| return_full_text: bool = False, |
| seed: Optional[int] = None, |
| stop_sequences: Optional[List[str]] = None, |
| temperature: Optional[float] = None, |
| top_k: Optional[int] = None, |
| top_p: Optional[float] = None, |
| truncate: Optional[int] = None, |
| typical_p: Optional[float] = None, |
| watermark: bool = False, |
| ) -> str: |
| ... |
|
|
| @overload |
| def text_generation( |
| self, |
| prompt: str, |
| *, |
| details: Literal[True] = ..., |
| stream: Literal[False] = ..., |
| model: Optional[str] = None, |
| do_sample: bool = False, |
| max_new_tokens: int = 20, |
| best_of: Optional[int] = None, |
| repetition_penalty: Optional[float] = None, |
| return_full_text: bool = False, |
| seed: Optional[int] = None, |
| stop_sequences: Optional[List[str]] = None, |
| temperature: Optional[float] = None, |
| top_k: Optional[int] = None, |
| top_p: Optional[float] = None, |
| truncate: Optional[int] = None, |
| typical_p: Optional[float] = None, |
| watermark: bool = False, |
| ) -> TextGenerationResponse: |
| ... |
|
|
| @overload |
| def text_generation( |
| self, |
| prompt: str, |
| *, |
| details: Literal[False] = ..., |
| stream: Literal[True] = ..., |
| model: Optional[str] = None, |
| do_sample: bool = False, |
| max_new_tokens: int = 20, |
| best_of: Optional[int] = None, |
| repetition_penalty: Optional[float] = None, |
| return_full_text: bool = False, |
| seed: Optional[int] = None, |
| stop_sequences: Optional[List[str]] = None, |
| temperature: Optional[float] = None, |
| top_k: Optional[int] = None, |
| top_p: Optional[float] = None, |
| truncate: Optional[int] = None, |
| typical_p: Optional[float] = None, |
| watermark: bool = False, |
| ) -> Iterable[str]: |
| ... |
|
|
| @overload |
| def text_generation( |
| self, |
| prompt: str, |
| *, |
| details: Literal[True] = ..., |
| stream: Literal[True] = ..., |
| model: Optional[str] = None, |
| do_sample: bool = False, |
| max_new_tokens: int = 20, |
| best_of: Optional[int] = None, |
| repetition_penalty: Optional[float] = None, |
| return_full_text: bool = False, |
| seed: Optional[int] = None, |
| stop_sequences: Optional[List[str]] = None, |
| temperature: Optional[float] = None, |
| top_k: Optional[int] = None, |
| top_p: Optional[float] = None, |
| truncate: Optional[int] = None, |
| typical_p: Optional[float] = None, |
| watermark: bool = False, |
| ) -> Iterable[TextGenerationStreamResponse]: |
| ... |
|
|
| def text_generation( |
| self, |
| prompt: str, |
| *, |
| details: bool = False, |
| stream: bool = False, |
| model: Optional[str] = None, |
| do_sample: bool = False, |
| max_new_tokens: int = 20, |
| best_of: Optional[int] = None, |
| repetition_penalty: Optional[float] = None, |
| return_full_text: bool = False, |
| seed: Optional[int] = None, |
| stop_sequences: Optional[List[str]] = None, |
| temperature: Optional[float] = None, |
| top_k: Optional[int] = None, |
| top_p: Optional[float] = None, |
| truncate: Optional[int] = None, |
| typical_p: Optional[float] = None, |
| watermark: bool = False, |
| decoder_input_details: bool = False, |
| ) -> Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]: |
| """ |
| Given a prompt, generate the following text. |
| |
| It is recommended to have Pydantic installed in order to get inputs validated. This is preferable as it allows |
| early failures. |
| |
| The API endpoint is expected to be served with the `text-generation-inference` backend (TGI). This backend is the |
| go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the |
| default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but |
| not exactly the same. This method is compatible with both approaches but some parameters are only available for |
| `text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process |
| continues correctly. |
| |
| To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference. |
| |
| Args: |
| prompt (`str`): |
| Input text. |
| details (`bool`, *optional*): |
| By default, text_generation returns a string. Pass `details=True` if you want a detailed output (tokens, |
| probabilities, seed, finish reason, etc.). Only available for models running with the |
| `text-generation-inference` backend. |
| stream (`bool`, *optional*): |
| By default, text_generation returns the full generated text. Pass `stream=True` if you want a stream of |
| tokens to be returned. Only available for models running with the `text-generation-inference` |
| backend. |
| model (`str`, *optional*): |
| The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed |
| Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. |
| do_sample (`bool`): |
| Activate logits sampling |
| max_new_tokens (`int`): |
| Maximum number of generated tokens |
| best_of (`int`): |
| Generate best_of sequences and return the one with the highest token logprobs |
| repetition_penalty (`float`): |
| The parameter for repetition penalty. 1.0 means no penalty. See [this |
| paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. |
| return_full_text (`bool`): |
| Whether to prepend the prompt to the generated text |
| seed (`int`): |
| Random sampling seed |
| stop_sequences (`List[str]`): |
| Stop generating tokens if a member of `stop_sequences` is generated |
| temperature (`float`): |
| The value used to modulate the logits distribution. |
| top_k (`int`): |
| The number of highest probability vocabulary tokens to keep for top-k-filtering. |
| top_p (`float`): |
| If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or |
| higher are kept for generation. |
| truncate (`int`): |
| Truncate input tokens to the given size |
| typical_p (`float`): |
| Typical Decoding mass |
| See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information |
| watermark (`bool`): |
| Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) |
| decoder_input_details (`bool`): |
| Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken |
| into account. Defaults to `False`. |
| |
| Returns: |
| `Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]`: |
| Generated text returned from the server: |
| - if `stream=False` and `details=False`, the generated text is returned as a `str` (default) |
| - if `stream=True` and `details=False`, the generated text is returned token by token as an `Iterable[str]` |
| - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.inference._text_generation.TextGenerationResponse`] |
| - if `details=True` and `stream=True`, the generated text is returned token by token as an iterable of [`~huggingface_hub.inference._text_generation.TextGenerationStreamResponse`] |
| |
| Raises: |
| `ValidationError`: |
| If input values are not valid. No HTTP call is made to the server. |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| |
| # Case 1: generate text |
| >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12) |
| '100% open source and built to be easy to use.' |
| |
| # Case 2: iterate over the generated tokens. Useful for large generation. |
| >>> for token in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True): |
| ... print(token) |
| 100 |
| % |
| open |
| source |
| and |
| built |
| to |
| be |
| easy |
| to |
| use |
| . |
| |
| # Case 3: get more details about the generation process. |
| >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True) |
| TextGenerationResponse( |
| generated_text='100% open source and built to be easy to use.', |
| details=Details( |
| finish_reason=<FinishReason.Length: 'length'>, |
| generated_tokens=12, |
| seed=None, |
| prefill=[ |
| InputToken(id=487, text='The', logprob=None), |
| InputToken(id=53789, text=' hugging', logprob=-13.171875), |
| (...) |
| InputToken(id=204, text=' ', logprob=-7.0390625) |
| ], |
| tokens=[ |
| Token(id=1425, text='100', logprob=-1.0175781, special=False), |
| Token(id=16, text='%', logprob=-0.0463562, special=False), |
| (...) |
| Token(id=25, text='.', logprob=-0.5703125, special=False) |
| ], |
| best_of_sequences=None |
| ) |
| ) |
| |
| # Case 4: iterate over the generated tokens with more details. |
| # Last object is more complete, containing the full generated text and the finish reason. |
| >>> for details in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True): |
| ... print(details) |
| ... |
| TextGenerationStreamResponse(token=Token(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None) |
| TextGenerationStreamResponse(token=Token(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None) |
| TextGenerationStreamResponse(token=Token(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None) |
| TextGenerationStreamResponse(token=Token(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None) |
| TextGenerationStreamResponse(token=Token(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None) |
| TextGenerationStreamResponse(token=Token(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None) |
| TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None) |
| TextGenerationStreamResponse(token=Token(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None) |
| TextGenerationStreamResponse(token=Token(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None) |
| TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None) |
| TextGenerationStreamResponse(token=Token(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None) |
| TextGenerationStreamResponse(token=Token( |
| id=25, |
| text='.', |
| logprob=-0.5703125, |
| special=False), |
| generated_text='100% open source and built to be easy to use.', |
| details=StreamDetails(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=12, seed=None) |
| ) |
| ``` |
| """ |
| |
| |
|
|
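| # `decoder_input_details` is only returned as part of the detailed output: warn and drop it if `details=False`. |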
| if decoder_input_details and not details: |
| warnings.warn( |
| "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that" |
| " the output from the server will be truncated." |
| ) |
| decoder_input_details = False |
|
|
| |
| parameters = TextGenerationParameters( |
| best_of=best_of, |
| details=details, |
| do_sample=do_sample, |
| max_new_tokens=max_new_tokens, |
| repetition_penalty=repetition_penalty, |
| return_full_text=return_full_text, |
| seed=seed, |
| stop=stop_sequences if stop_sequences is not None else [], |
| temperature=temperature, |
| top_k=top_k, |
| top_p=top_p, |
| truncate=truncate, |
| typical_p=typical_p, |
| watermark=watermark, |
| decoder_input_details=decoder_input_details, |
| ) |
| request = TextGenerationRequest(inputs=prompt, stream=stream, parameters=parameters) |
| payload = asdict(request) |
|
|
| |
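| # Non-TGI servers do not support TGI-specific parameters: remove them from the payload and warn the user. |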
| if not _is_tgi_server(model): |
| ignored_parameters = [] |
| for key in "watermark", "stop", "details", "decoder_input_details": |
| if payload["parameters"][key] is not None: |
| ignored_parameters.append(key) |
| del payload["parameters"][key] |
| if len(ignored_parameters) > 0: |
| warnings.warn( |
| "API endpoint/model for text-generation is not served via TGI. Ignoring parameters" |
| f" {ignored_parameters}.", |
| UserWarning, |
| ) |
| if details: |
| warnings.warn( |
| "API endpoint/model for text-generation is not served via TGI. Parameter `details=True` will" |
| " be ignored meaning only the generated text will be returned.", |
| UserWarning, |
| ) |
| details = False |
| if stream: |
| raise ValueError( |
| "API endpoint/model for text-generation is not served via TGI. Cannot return output as a stream." |
| " Please pass `stream=False` as input." |
| ) |
|
|
| |
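| # Call the API. If the server rejects TGI-specific kwargs, mark the model as non-TGI and retry without them. |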
| try: |
| bytes_output = self.post(json=payload, model=model, task="text-generation", stream=stream) |
| except HTTPError as e: |
| if isinstance(e, BadRequestError) and "The following `model_kwargs` are not used by the model" in str(e): |
| _set_as_non_tgi(model) |
| return self.text_generation( |
| prompt=prompt, |
| details=details, |
| stream=stream, |
| model=model, |
| do_sample=do_sample, |
| max_new_tokens=max_new_tokens, |
| best_of=best_of, |
| repetition_penalty=repetition_penalty, |
| return_full_text=return_full_text, |
| seed=seed, |
| stop_sequences=stop_sequences, |
| temperature=temperature, |
| top_k=top_k, |
| top_p=top_p, |
| truncate=truncate, |
| typical_p=typical_p, |
| watermark=watermark, |
| decoder_input_details=decoder_input_details, |
| ) |
| raise_text_generation_error(e) |
|
|
| |
| if stream: |
| return _stream_text_generation_response(bytes_output, details) |
|
|
| data = _bytes_to_dict(bytes_output)[0] |
| return TextGenerationResponse(**data) if details else data["generated_text"] |
|
|
| def text_to_image( |
| self, |
| prompt: str, |
| *, |
| negative_prompt: Optional[str] = None, |
| height: Optional[float] = None, |
| width: Optional[float] = None, |
| num_inference_steps: Optional[float] = None, |
| guidance_scale: Optional[float] = None, |
| model: Optional[str] = None, |
| **kwargs, |
| ) -> "Image": |
| """ |
| Generate an image based on a given text using a specified model. |
| |
| <Tip warning={true}> |
| |
| You must have `PIL` installed if you want to work with images (`pip install Pillow`). |
| |
| </Tip> |
| |
| Args: |
| prompt (`str`): |
| The prompt to generate an image from. |
| negative_prompt (`str`, *optional*): |
| An optional negative prompt for the image generation. |
| height (`float`, *optional*): |
| The height in pixels of the image to generate. |
| width (`float`, *optional*): |
| The width in pixels of the image to generate. |
| num_inference_steps (`int`, *optional*): |
| The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
| expense of slower inference. |
| guidance_scale (`float`, *optional*): |
| A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, |
| usually at the expense of lower image quality. |
| model (`str`, *optional*): |
| The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed |
| Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. |
| |
| Returns: |
| `Image`: The generated image. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| |
| >>> image = client.text_to_image("An astronaut riding a horse on the moon.") |
| >>> image.save("astronaut.png") |
| |
| >>> image = client.text_to_image( |
| ... "An astronaut riding a horse on the moon.", |
| ... negative_prompt="low resolution, blurry", |
| ... model="stabilityai/stable-diffusion-2-1", |
| ... ) |
| >>> image.save("better_astronaut.png") |
| ``` |
| """ |
| payload = {"inputs": prompt} |
| parameters = { |
| "negative_prompt": negative_prompt, |
| "height": height, |
| "width": width, |
| "num_inference_steps": num_inference_steps, |
| "guidance_scale": guidance_scale, |
| **kwargs, |
| } |
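| # Forward only the parameters that were explicitly set |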
| for key, value in parameters.items(): |
| if value is not None: |
| payload.setdefault("parameters", {})[key] = value |
| response = self.post(json=payload, model=model, task="text-to-image") |
| return _bytes_to_image(response) |
|
|
| def text_to_speech(self, text: str, *, model: Optional[str] = None) -> bytes: |
| """ |
| Synthesize audio of a voice pronouncing a given text. |
| |
| Args: |
| text (`str`): |
| The text to synthesize. |
| model (`str`, *optional*): |
| The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed |
| Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. |
| |
| Returns: |
| `bytes`: The generated audio. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from pathlib import Path |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| |
| >>> audio = client.text_to_speech("Hello world") |
| >>> Path("hello_world.flac").write_bytes(audio) |
| ``` |
| """ |
| return self.post(json={"inputs": text}, model=model, task="text-to-speech") |
|
|
| def token_classification(self, text: str, *, model: Optional[str] = None) -> List[TokenClassificationOutput]: |
| """ |
| Perform token classification on the given text. |
| Usually used for sentence parsing, either grammatical or Named Entity Recognition (NER), to understand keywords contained within the text. |
| |
| Args: |
| text (`str`): |
| A string to be classified. |
| model (`str`, *optional*): |
| The model to use for the token classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to |
| a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used. |
| Defaults to None. |
| |
| Returns: |
| `List[Dict]`: List of token classification outputs containing the entity group, confidence score, word, start and end index. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica") |
| [{'entity_group': 'PER', |
| 'score': 0.9971321225166321, |
| 'word': 'Sarah Jessica Parker', |
| 'start': 11, |
| 'end': 31}, |
| {'entity_group': 'PER', |
| 'score': 0.9773476123809814, |
| 'word': 'Jessica', |
| 'start': 52, |
| 'end': 59}] |
| ``` |
| """ |
| payload: Dict[str, Any] = {"inputs": text} |
| response = self.post( |
| json=payload, |
| model=model, |
| task="token-classification", |
| ) |
| return _bytes_to_list(response) |
|
|
| def translation( |
| self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None |
| ) -> str: |
| """ |
| Convert text from one language to another. |
| |
| Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for |
| your specific use case. Source and target languages usually depend on the model. |
| However, certain models allow you to specify the source and target languages. If you are working with one of these models, |
| you can use the `src_lang` and `tgt_lang` arguments to pass the relevant information. |
| You can find this information in the model card. |
| |
| Args: |
| text (`str`): |
| A string to be translated. |
| model (`str`, *optional*): |
| The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to |
| a deployed Inference Endpoint. If not provided, the default recommended translation model will be used. |
| Defaults to None. |
| src_lang (`str`, *optional*): |
| Source language of the translation task, i.e. input language. Cannot be passed without `tgt_lang`. |
| tgt_lang (`str`, *optional*): |
| Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`. |
| |
| Returns: |
| `str`: The generated translated text. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| `ValueError`: |
| If only one of the `src_lang` and `tgt_lang` arguments is provided. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.translation("My name is Wolfgang and I live in Berlin") |
| 'Mein Name ist Wolfgang und ich lebe in Berlin.' |
| >>> client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr") |
| "Je m'appelle Wolfgang et je vis à Berlin." |
| ``` |
| |
| Specifying languages: |
| ```py |
| >>> client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX") |
| "Mon nom est Sarah Jessica Parker mais vous pouvez m\'appeler Jessica" |
| ``` |
| """ |
| |
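| # `src_lang` and `tgt_lang` must be provided together (or not at all) |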
| if src_lang is not None and tgt_lang is None: |
| raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.") |
| |
| if src_lang is None and tgt_lang is not None: |
| raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.") |
| |
| payload: Dict = {"inputs": text} |
| if src_lang and tgt_lang: |
| payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang} |
| response = self.post(json=payload, model=model, task="translation") |
| return _bytes_to_dict(response)[0]["translation_text"] |
|
|
| def zero_shot_classification( |
| self, text: str, labels: List[str], *, multi_label: bool = False, model: Optional[str] = None |
| ) -> List[ClassificationOutput]: |
| """ |
| Provide a text and a set of candidate labels to classify the input text. |
| |
| Args: |
| text (`str`): |
| The input text to classify. |
| labels (`List[str]`): |
| List of candidate labels (as strings). There must be at least 2 labels. |
| multi_label (`bool`, *optional*): |
| Whether classes can overlap, i.e. whether more than one label can apply to the input text. Defaults to False. |
| model (`str`, *optional*): |
| The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed |
| Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. |
| |
| Returns: |
| `List[Dict]`: List of classification outputs containing the predicted labels and their confidence. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> text = ( |
| ... "A new model offers an explanation for how the Galilean satellites formed around the solar system's" |
| ... "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling" |
| ... " mysteries when he went for a run up a hill in Nice, France." |
| ... ) |
| >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"] |
| >>> client.zero_shot_classification(text, labels) |
| [ |
| {"label": "scientific discovery", "score": 0.7961668968200684}, |
| {"label": "space & cosmos", "score": 0.18570658564567566}, |
| {"label": "microbiology", "score": 0.00730885099619627}, |
| {"label": "archeology", "score": 0.006258360575884581}, |
| {"label": "robots", "score": 0.004559356719255447}, |
| ] |
| >>> client.zero_shot_classification(text, labels, multi_label=True) |
| [ |
| {"label": "scientific discovery", "score": 0.9829297661781311}, |
| {"label": "space & cosmos", "score": 0.755190908908844}, |
| {"label": "microbiology", "score": 0.0005462635890580714}, |
| {"label": "archeology", "score": 0.00047131875180639327}, |
| {"label": "robots", "score": 0.00030448526376858354}, |
| ] |
| ``` |
| """ |
| |
| if len(labels) < 2: |
| raise ValueError("You must specify at least 2 classes to compare.") |
|
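| # Candidate labels are sent as a single comma-separated string |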
| response = self.post( |
| json={ |
| "inputs": text, |
| "parameters": { |
| "candidate_labels": ",".join(labels), |
| "multi_label": multi_label, |
| }, |
| }, |
| model=model, |
| task="zero-shot-classification", |
| ) |
| output = _bytes_to_dict(response) |
| return [{"label": label, "score": score} for label, score in zip(output["labels"], output["scores"])] |
|
|
| def zero_shot_image_classification( |
| self, image: ContentT, labels: List[str], *, model: Optional[str] = None |
| ) -> List[ClassificationOutput]: |
| """ |
| Provide an input image and a set of candidate text labels to classify the image. |
| |
| Args: |
| image (`Union[str, Path, bytes, BinaryIO]`): |
| The input image to classify. It can be raw bytes, an image file, or a URL to an online image. |
| labels (`List[str]`): |
| List of candidate labels (as strings). There must be at least 2 labels. |
| model (`str`, *optional*): |
| The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed |
| Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. |
| |
| Returns: |
| `List[Dict]`: List of classification outputs containing the predicted labels and their confidence. |
| |
| Raises: |
| [`InferenceTimeoutError`]: |
| If the model is unavailable or the request times out. |
| `HTTPError`: |
| If the request fails with an HTTP error status code other than HTTP 503. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| |
| >>> client.zero_shot_image_classification( |
| ... "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg", |
| ... labels=["dog", "cat", "horse"], |
| ... ) |
| [{"label": "dog", "score": 0.956}, ...] |
| ``` |
| """ |
| |
| if len(labels) < 2: |
| raise ValueError("You must specify at least 2 classes to compare.") |
|
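| # The image is base64-encoded and sent along with the comma-separated candidate labels |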
| response = self.post( |
| json={"image": _b64_encode(image), "parameters": {"candidate_labels": ",".join(labels)}}, |
| model=model, |
| task="zero-shot-image-classification", |
| ) |
| return _bytes_to_list(response) |
|
|
| def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str: |
| model = model or self.model |
| |
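| # If a URL is passed (e.g. a deployed Inference Endpoint), use it as-is |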
| if model is not None and (model.startswith("http://") or model.startswith("https://")): |
| return model |
| |
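| # If no model is set, fall back to the recommended model for the task (a task is then required) |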
| if model is None: |
| if task is None: |
| raise ValueError( |
| "You must specify at least a model (repo_id or URL) or a task, either when instantiating" |
| " `InferenceClient` or when making a request." |
| ) |
| model = self.get_recommended_model(task) |
| logger.info( |
| f"Using recommended model {model} for task {task}. Note that it is" |
| f" encouraged to explicitly set `model='{model}'` as the recommended" |
| " models list might get updated without prior notice." |
| ) |
| |
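| # Feature-extraction and sentence-similarity are the only tasks routed to a dedicated pipeline endpoint |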
| return ( |
| f"{INFERENCE_ENDPOINT}/pipeline/{task}/{model}" |
| if task in ("feature-extraction", "sentence-similarity") |
| else f"{INFERENCE_ENDPOINT}/models/{model}" |
| ) |
|
|
| @staticmethod |
| def get_recommended_model(task: str) -> str: |
| """ |
| Get the model Hugging Face recommends for the input task. |
| |
| Args: |
| task (`str`): |
| The Hugging Face task to get a recommended model for. |
| All available tasks can be found [here](https://huggingface.co/tasks). |
| |
| Returns: |
| `str`: Name of the model recommended for the input task. |
| |
| Raises: |
| `ValueError`: If Hugging Face has no recommendation for the input task. |
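| |
| Example (the model id shown is indicative only, as the list of recommended models may change over time): |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> InferenceClient.get_recommended_model("translation") |
| 't5-base' |
| ``` |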
| """ |
| model = _fetch_recommended_models().get(task) |
| if model is None: |
| raise ValueError( |
| f"Task {task} has no recommended model. Please specify a model" |
| " explicitly. Visit https://huggingface.co/tasks for more info." |
| ) |
| return model |
|
|
| def get_model_status(self, model: Optional[str] = None) -> ModelStatus: |
| """ |
| Get the status of a model hosted on the Inference API. |
| |
| <Tip> |
| |
| This endpoint is mostly useful when you already know which model you want to use and want to check its |
| availability. If you want to discover already deployed models, use [`~InferenceClient.list_deployed_models`] instead. |
| |
| </Tip> |
| |
| Args: |
| model (`str`, *optional*): |
| Identifier of the model for which the status will be checked. If not provided, the model associated |
| with this instance of [`InferenceClient`] will be used. Only models hosted on the Inference API can |
| be checked, so the identifier cannot be a URL. |
| |
| Returns: |
| [`ModelStatus`]: An instance of the `ModelStatus` dataclass, containing information about the state |
| of the model: load status, state, compute type and framework. |
| |
| Example: |
| ```py |
| >>> from huggingface_hub import InferenceClient |
| >>> client = InferenceClient() |
| >>> client.get_model_status("bigcode/starcoder") |
| ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference') |
| ``` |
| """ |
| model = model or self.model |
| if model is None: |
| raise ValueError("Model id not provided.") |
| if model.startswith("https://"): |
| raise NotImplementedError("Model status is only available for Inference API endpoints.") |
| url = f"{INFERENCE_ENDPOINT}/status/{model}" |
|
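| # Query the status endpoint of the Inference API |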
| response = get_session().get(url, headers=self.headers) |
| hf_raise_for_status(response) |
| response_data = response.json() |
|
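| # An "error" field in the response means the status could not be retrieved for this model |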
| if "error" in response_data: |
| raise ValueError(response_data["error"]) |
| |
| return ModelStatus( |
| loaded=response_data["loaded"], |
| state=response_data["state"], |
| compute_type=response_data["compute_type"], |
| framework=response_data["framework"], |
| ) |
|
|