| | import io |
| | from typing import Any, Dict, List, Optional, Union |
| |
|
| | from . import constants |
| | from .hf_api import HfApi |
| | from .utils import build_hf_headers, get_session, is_pillow_available, logging, validate_hf_hub_args |
| | from .utils._deprecation import _deprecate_method |
| |
|
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| |
|
| | ALL_TASKS = [ |
| | |
| | "text-classification", |
| | "token-classification", |
| | "table-question-answering", |
| | "question-answering", |
| | "zero-shot-classification", |
| | "translation", |
| | "summarization", |
| | "conversational", |
| | "feature-extraction", |
| | "text-generation", |
| | "text2text-generation", |
| | "fill-mask", |
| | "sentence-similarity", |
| | |
| | "text-to-speech", |
| | "automatic-speech-recognition", |
| | "audio-to-audio", |
| | "audio-classification", |
| | "voice-activity-detection", |
| | |
| | "image-classification", |
| | "object-detection", |
| | "image-segmentation", |
| | "text-to-image", |
| | "image-to-image", |
| | |
| | "tabular-classification", |
| | "tabular-regression", |
| | ] |
| |
|
| |
|
| | class InferenceApi: |
| | """Client to configure requests and make calls to the HuggingFace Inference API. |
| | |
| | Example: |
| | |
| | ```python |
| | >>> from huggingface_hub.inference_api import InferenceApi |
| | |
| | >>> # Mask-fill example |
| | >>> inference = InferenceApi("bert-base-uncased") |
| | >>> inference(inputs="The goal of life is [MASK].") |
| | [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}] |
| | |
| | >>> # Question Answering example |
| | >>> inference = InferenceApi("deepset/roberta-base-squad2") |
| | >>> inputs = { |
| | ... "question": "What's my name?", |
| | ... "context": "My name is Clara and I live in Berkeley.", |
| | ... } |
| | >>> inference(inputs) |
| | {'score': 0.9326569437980652, 'start': 11, 'end': 16, 'answer': 'Clara'} |
| | |
| | >>> # Zero-shot example |
| | >>> inference = InferenceApi("typeform/distilbert-base-uncased-mnli") |
| | >>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!" |
| | >>> params = {"candidate_labels": ["refund", "legal", "faq"]} |
| | >>> inference(inputs, params) |
| | {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]} |
| | |
| | >>> # Overriding configured task |
| | >>> inference = InferenceApi("bert-base-uncased", task="feature-extraction") |
| | |
| | >>> # Text-to-image |
| | >>> inference = InferenceApi("stabilityai/stable-diffusion-2-1") |
| | >>> inference("cat") |
| | <PIL.PngImagePlugin.PngImageFile image (...)> |
| | |
| | >>> # Return as raw response to parse the output yourself |
| | >>> inference = InferenceApi("mio/amadeus") |
| | >>> response = inference("hello world", raw_response=True) |
| | >>> response.headers |
| | {"Content-Type": "audio/flac", ...} |
| | >>> response.content # raw bytes from server |
| | b'(...)' |
| | ``` |
| | """ |
| |
|
| | @validate_hf_hub_args |
| | @_deprecate_method( |
| | version="1.0", |
| | message=( |
| | "`InferenceApi` client is deprecated in favor of the more feature-complete `InferenceClient`. Check out" |
| | " this guide to learn how to convert your script to use it:" |
| | " https://huggingface.co/docs/huggingface_hub/guides/inference#legacy-inferenceapi-client." |
| | ), |
| | ) |
| | def __init__( |
| | self, |
| | repo_id: str, |
| | task: Optional[str] = None, |
| | token: Optional[str] = None, |
| | gpu: bool = False, |
| | ): |
| | """Inits headers and API call information. |
| | |
| | Args: |
| | repo_id (``str``): |
| | Id of repository (e.g. `user/bert-base-uncased`). |
| | task (``str``, `optional`, defaults ``None``): |
| | Whether to force a task instead of using task specified in the |
| | repository. |
| | token (`str`, `optional`): |
| | The API token to use as HTTP bearer authorization. This is not |
| | the authentication token. You can find the token in |
| | https://huggingface.co/settings/token. Alternatively, you can |
| | find both your organizations and personal API tokens using |
| | `HfApi().whoami(token)`. |
| | gpu (`bool`, `optional`, defaults `False`): |
| | Whether to use GPU instead of CPU for inference(requires Startup |
| | plan at least). |
| | """ |
| | self.options = {"wait_for_model": True, "use_gpu": gpu} |
| | self.headers = build_hf_headers(token=token) |
| |
|
| | |
| | model_info = HfApi(token=token).model_info(repo_id=repo_id) |
| | if not model_info.pipeline_tag and not task: |
| | raise ValueError( |
| | "Task not specified in the repository. Please add it to the model card" |
| | " using pipeline_tag" |
| | " (https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined)" |
| | ) |
| |
|
| | if task and task != model_info.pipeline_tag: |
| | if task not in ALL_TASKS: |
| | raise ValueError(f"Invalid task {task}. Make sure it's valid.") |
| |
|
| | logger.warning( |
| | "You're using a different task than the one specified in the" |
| | " repository. Be sure to know what you're doing :)" |
| | ) |
| | self.task = task |
| | else: |
| | assert model_info.pipeline_tag is not None, "Pipeline tag cannot be None" |
| | self.task = model_info.pipeline_tag |
| |
|
| | self.api_url = f"{constants.INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}" |
| |
|
| | def __repr__(self): |
| | |
| | return f"InferenceAPI(api_url='{self.api_url}', task='{self.task}', options={self.options})" |
| |
|
| | def __call__( |
| | self, |
| | inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None, |
| | params: Optional[Dict] = None, |
| | data: Optional[bytes] = None, |
| | raw_response: bool = False, |
| | ) -> Any: |
| | """Make a call to the Inference API. |
| | |
| | Args: |
| | inputs (`str` or `Dict` or `List[str]` or `List[List[str]]`, *optional*): |
| | Inputs for the prediction. |
| | params (`Dict`, *optional*): |
| | Additional parameters for the models. Will be sent as `parameters` in the |
| | payload. |
| | data (`bytes`, *optional*): |
| | Bytes content of the request. In this case, leave `inputs` and `params` empty. |
| | raw_response (`bool`, defaults to `False`): |
| | If `True`, the raw `Response` object is returned. You can parse its content |
| | as preferred. By default, the content is parsed into a more practical format |
| | (json dictionary or PIL Image for example). |
| | """ |
| | |
| | payload: Dict[str, Any] = { |
| | "options": self.options, |
| | } |
| | if inputs: |
| | payload["inputs"] = inputs |
| | if params: |
| | payload["parameters"] = params |
| |
|
| | |
| | response = get_session().post(self.api_url, headers=self.headers, json=payload, data=data) |
| |
|
| | |
| | if raw_response: |
| | return response |
| |
|
| | |
| | content_type = response.headers.get("Content-Type") or "" |
| | if content_type.startswith("image"): |
| | if not is_pillow_available(): |
| | raise ImportError( |
| | f"Task '{self.task}' returned as image but Pillow is not installed." |
| | " Please install it (`pip install Pillow`) or pass" |
| | " `raw_response=True` to get the raw `Response` object and parse" |
| | " the image by yourself." |
| | ) |
| |
|
| | from PIL import Image |
| |
|
| | return Image.open(io.BytesIO(response.content)) |
| | elif content_type == "application/json": |
| | return response.json() |
| | else: |
| | raise NotImplementedError( |
| | f"{content_type} output type is not implemented yet. You can pass" |
| | " `raw_response=True` to get the raw `Response` object and parse the" |
| | " output by yourself." |
| | ) |
| |
|