import json
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional, Union
from urllib.parse import urlparse, urlunparse

from huggingface_hub import constants
from huggingface_hub.hf_api import InferenceProviderMapping
from huggingface_hub.inference._common import RequestParameters, _b64_encode, _bytes_to_dict, _open_as_binary
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import build_hf_headers, get_session, get_token, hf_raise_for_status


class HFInferenceTask(TaskProviderHelper):
    """Base class for HF Inference API tasks."""

    def __init__(self, task: str):
        super().__init__(
            provider="hf-inference",
            base_url=constants.INFERENCE_PROXY_TEMPLATE.format(provider="hf-inference"),
            task=task,
        )

    def _prepare_api_key(self, api_key: Optional[str]) -> str:
        # special case: for HF Inference we allow not providing an API key
        return api_key or get_token()  # type: ignore[return-value]

    def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping:
        if model is not None and model.startswith(("http://", "https://")):
            return InferenceProviderMapping(
                provider="hf-inference", providerId=model, hf_model_id=model, task=self.task, status="live"
            )
        model_id = model if model is not None else _fetch_recommended_models().get(self.task)
        if model_id is None:
            raise ValueError(
                f"Task {self.task} has no recommended model for HF Inference. Please specify a model"
                " explicitly. Visit https://huggingface.co/tasks for more info."
            )
        _check_supported_task(model_id, self.task)
        return InferenceProviderMapping(
            provider="hf-inference", providerId=model_id, hf_model_id=model_id, task=self.task, status="live"
        )

    def _prepare_url(self, api_key: str, mapped_model: str) -> str:
        # hf-inference provider can handle URLs (e.g. Inference Endpoints or TGI deployment)
        if mapped_model.startswith(("http://", "https://")):
            return mapped_model
        return (
            # Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
            f"{self.base_url}/models/{mapped_model}/pipeline/{self.task}"
            if self.task in ("feature-extraction", "sentence-similarity")
            # Otherwise, we use the default endpoint
            else f"{self.base_url}/models/{mapped_model}"
        )
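
    # Illustrative URL routing (model ids below are assumptions, for clarity only):
    #   task="text-classification", model="my-org/my-model"
    #       -> f"{self.base_url}/models/my-org/my-model"
    #   task="feature-extraction", model="my-org/my-embedder"
    #       -> f"{self.base_url}/models/my-org/my-embedder/pipeline/feature-extraction"
    #   model="https://my-endpoint.example" (Inference Endpoint / TGI deployment) -> used as-is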

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        if isinstance(inputs, bytes):
            raise ValueError(f"Unexpected binary input for task {self.task}.")
        if isinstance(inputs, Path):
            raise ValueError(f"Unexpected path input for task {self.task} (got {inputs})")
        return filter_none({"inputs": inputs, "parameters": parameters})


class HFInferenceBinaryInputTask(HFInferenceTask):
    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        return None

    def _prepare_payload_as_bytes(
        self,
        inputs: Any,
        parameters: Dict,
        provider_mapping_info: InferenceProviderMapping,
        extra_payload: Optional[Dict],
    ) -> Optional[bytes]:
        parameters = filter_none(parameters)
        extra_payload = extra_payload or {}
        has_parameters = len(parameters) > 0 or len(extra_payload) > 0

        # Raise if not a binary object or a local path or a URL.
        if not isinstance(inputs, (bytes, Path)) and not isinstance(inputs, str):
            raise ValueError(f"Expected binary inputs or a local path or a URL. Got {inputs}")

        # Send inputs as raw content when no parameters are provided
        if not has_parameters:
            with _open_as_binary(inputs) as data:
                data_as_bytes = data if isinstance(data, bytes) else data.read()
                return data_as_bytes

        # Otherwise encode as b64
        return json.dumps({"inputs": _b64_encode(inputs), "parameters": parameters, **extra_payload}).encode("utf-8")
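
# Illustrative serialization by HFInferenceBinaryInputTask (inputs below are assumptions, for clarity only):
#   raw audio bytes, no parameters   -> bytes sent unchanged as the request body
#   "sample.flac", {"top_k": 3}      -> JSON body {"inputs": "<base64 of file>", "parameters": {"top_k": 3}}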


class HFInferenceConversational(HFInferenceTask):
    def __init__(self):
        super().__init__("conversational")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        payload = filter_none(parameters)
        mapped_model = provider_mapping_info.provider_id
        payload_model = parameters.get("model") or mapped_model
        if payload_model is None or payload_model.startswith(("http://", "https://")):
            payload_model = "dummy"
        response_format = parameters.get("response_format")
        if isinstance(response_format, dict) and response_format.get("type") == "json_schema":
            payload["response_format"] = {
                "type": "json_object",
                "value": response_format["json_schema"]["schema"],
            }
        return {**payload, "model": payload_model, "messages": inputs}
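
    # Illustrative payload (messages and model id below are assumptions, for clarity only):
    #   inputs=[{"role": "user", "content": "Hello"}], provider_id="my-org/my-chat-model"
    #     -> {"model": "my-org/my-chat-model", "messages": [...], plus any non-None parameters}
    #   If the resolved model is a deployment URL, "model" is set to the placeholder "dummy".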

    def _prepare_url(self, api_key: str, mapped_model: str) -> str:
        base_url = (
            mapped_model
            if mapped_model.startswith(("http://", "https://"))
            else f"{constants.INFERENCE_PROXY_TEMPLATE.format(provider='hf-inference')}/models/{mapped_model}"
        )
        return _build_chat_completion_url(base_url)


def _build_chat_completion_url(model_url: str) -> str:
    parsed = urlparse(model_url)
    path = parsed.path.rstrip("/")

    # If the path already ends with /chat/completions, we're done!
    if path.endswith("/chat/completions"):
        return model_url

    # Append /chat/completions if not already present
    if path.endswith("/v1"):
        new_path = path + "/chat/completions"
    # If path was empty or just "/", set the full path
    elif not path:
        new_path = "/v1/chat/completions"
    # Append /v1/chat/completions if not already present
    else:
        new_path = path + "/v1/chat/completions"

    # Reconstruct the URL with the new path and original query parameters.
    return urlunparse(parsed._replace(path=new_path))
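
# Illustrative normalization by _build_chat_completion_url (hosts below are assumptions, for clarity only):
#   "https://host/models/my-model"    -> "https://host/models/my-model/v1/chat/completions"
#   "https://my-endpoint.example/v1"  -> "https://my-endpoint.example/v1/chat/completions"
#   "https://my-endpoint.example"     -> "https://my-endpoint.example/v1/chat/completions"
#   ".../v1/chat/completions"         -> returned unchanged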


@lru_cache(maxsize=1)
def _fetch_recommended_models() -> Dict[str, Optional[str]]:
    response = get_session().get(f"{constants.ENDPOINT}/api/tasks", headers=build_hf_headers())
    hf_raise_for_status(response)
    return {task: next(iter(details["widgetModels"]), None) for task, details in response.json().items()}
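
# Illustrative return value of _fetch_recommended_models (model ids are assumptions, not guaranteed):
#   {"text-classification": "some-org/some-model", "summarization": None, ...}
# Each task maps to the first widget model advertised by the Hub, or None if there is none.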


@lru_cache(maxsize=None)
def _check_supported_task(model: str, task: str) -> None:
    from huggingface_hub.hf_api import HfApi

    model_info = HfApi().model_info(model)
    pipeline_tag = model_info.pipeline_tag
    tags = model_info.tags or []
    is_conversational = "conversational" in tags

    if task in ("text-generation", "conversational"):
        if pipeline_tag == "text-generation":
            # text-generation + conversational tag -> both tasks allowed
            if is_conversational:
                return
            # text-generation without conversational tag -> only text-generation allowed
            if task == "text-generation":
                return
            raise ValueError(f"Model '{model}' doesn't support task '{task}'.")
        if pipeline_tag == "text2text-generation":
            if task == "text-generation":
                return
            raise ValueError(f"Model '{model}' doesn't support task '{task}'.")
        if pipeline_tag == "image-text-to-text":
            if is_conversational and task == "conversational":
                return  # Only conversational allowed if tagged as conversational
            raise ValueError("Non-conversational image-text-to-text task is not supported.")

    if (
        task in ("feature-extraction", "sentence-similarity")
        and pipeline_tag in ("feature-extraction", "sentence-similarity")
        and task in tags
    ):
        # feature-extraction and sentence-similarity are interchangeable for HF Inference
        return

    # For all other tasks, just check pipeline tag
    if pipeline_tag != task:
        raise ValueError(
            f"Model '{model}' doesn't support task '{task}'. Supported tasks: '{pipeline_tag}', got: '{task}'"
        )
    return
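
# Illustrative outcomes of _check_supported_task (pipeline tags and tasks below are assumed examples):
#   pipeline_tag="text-generation" + "conversational" tag       -> "text-generation" and "conversational" both pass
#   pipeline_tag="text2text-generation", task="text-generation" -> passes
#   pipeline_tag="fill-mask", task="fill-mask"                  -> passes (exact pipeline-tag match)
#   pipeline_tag="fill-mask", task="summarization"              -> raises ValueError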


class HFInferenceFeatureExtractionTask(HFInferenceTask):
    def __init__(self):
        super().__init__("feature-extraction")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        if isinstance(inputs, bytes):
            raise ValueError(f"Unexpected binary input for task {self.task}.")
        if isinstance(inputs, Path):
            raise ValueError(f"Unexpected path input for task {self.task} (got {inputs})")
        # Parameters are sent at root-level for feature-extraction task
        # See specs: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/feature-extraction/spec/input.json
        return {"inputs": inputs, **filter_none(parameters)}

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        if isinstance(response, bytes):
            return _bytes_to_dict(response)
        return response
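

# Illustrative usage sketch (not part of this module's public API; model id and token are assumptions):
#
#     helper = HFInferenceTask("text-classification")
#     mapping = helper._prepare_mapping_info("my-org/my-model")   # validates the task via _check_supported_task
#     url = helper._prepare_url(api_key="hf_xxx", mapped_model=mapping.provider_id)
#     payload = helper._prepare_payload_as_dict("I love this!", {}, mapping)
#
# In practice these helpers are invoked by `InferenceClient`, which selects the appropriate
# `TaskProviderHelper` subclass for the requested provider and task, rather than being called directly.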