| """Helpers for constructing Hugging Face inference clients in Space runtimes.""" |
|
|
| from __future__ import annotations |
|
|
| from collections.abc import Callable |
| from typing import Any |
|
|
| from maris_core.utils.env import get_env_any, get_env_any_or_default, get_hf_token |
|
|
# Provider name used by build_hf_inference_client_kwargs when neither
# MARIS_INFERENCE_PROVIDER nor HF_INFERENCE_PROVIDER supplies one.
HF_INFERENCE_PROVIDER_DEFAULT: str = "hf-inference"

# Base URL used by build_hf_inference_client_kwargs when none of
# MARIS_INFERENCE_BASE_URL, HF_INFERENCE_API_URL or
# HUGGINGFACE_INFERENCE_BASE_URL supplies one.
# NOTE(review): this looks like the legacy serverless Inference API host;
# newer huggingface_hub releases route "hf-inference" through
# router.huggingface.co -- confirm this is still the intended target.
HF_INFERENCE_BASE_URL_DEFAULT: str = "https://api-inference.huggingface.co"
|
|
|
|
def build_hf_inference_client_kwargs(*, token: str | None = None) -> dict[str, str]:
    """Assemble keyword arguments for a Hugging Face inference client.

    By default the kwargs target the public HF inference endpoint
    (``HF_INFERENCE_PROVIDER_DEFAULT`` / ``HF_INFERENCE_BASE_URL_DEFAULT``);
    environment variables may override both the provider and the base URL.

    Token resolution, first truthy value wins: the explicit ``token``
    argument, then the ``HF_INFERENCE_API_KEY`` environment variable, then
    ``get_hf_token()``.

    Args:
        token: Optional explicit API token; ignored when falsy.

    Returns:
        A dict that always contains ``"provider"`` and ``"base_url"`` and
        contains ``"token"`` only when a truthy token was resolved.
    """
    # Resolve the token before reading the provider/base_url settings.
    api_token = token
    if not api_token:
        api_token = get_env_any("HF_INFERENCE_API_KEY")
    if not api_token:
        api_token = get_hf_token()

    # Env var names are passed to get_env_any_or_default in priority order
    # (presumably the first one that is set wins -- see maris_core.utils.env).
    provider_vars = ("MARIS_INFERENCE_PROVIDER", "HF_INFERENCE_PROVIDER")
    base_url_vars = (
        "MARIS_INFERENCE_BASE_URL",
        "HF_INFERENCE_API_URL",
        "HUGGINGFACE_INFERENCE_BASE_URL",
    )

    client_kwargs: dict[str, str] = {}
    client_kwargs["provider"] = get_env_any_or_default(
        *provider_vars, default=HF_INFERENCE_PROVIDER_DEFAULT
    )
    client_kwargs["base_url"] = get_env_any_or_default(
        *base_url_vars, default=HF_INFERENCE_BASE_URL_DEFAULT
    )
    if api_token:
        client_kwargs["token"] = api_token
    return client_kwargs
|
|
|
|
def create_hf_inference_client(
    client_factory: Callable[..., Any], *, token: str | None = None
) -> Any:
    """Create an inference client, degrading gracefully for simpler factories.

    The factory is first called with the kwargs from
    ``build_hf_inference_client_kwargs`` (``provider``, ``base_url`` and,
    when one was resolved, ``token``). Factories that do not accept all of
    them -- e.g. older client versions or minimal test doubles -- are called
    with progressively fewer arguments: ``token`` only (when a token was
    resolved), then no arguments at all.

    Args:
        client_factory: Callable that constructs the client, e.g. a client
            class.
        token: Optional explicit API token forwarded to
            ``build_hf_inference_client_kwargs``.

    Returns:
        Whatever ``client_factory`` returns.

    Raises:
        TypeError: If the factory accepts none of the argument sets. A
            ``TypeError`` raised *inside* a factory whose signature does
            accept the arguments is propagated rather than triggering a
            silent retry with fewer arguments.
    """
    import inspect  # only needed here, for signature probing

    kwargs = build_hf_inference_client_kwargs(token=token)

    # Argument sets to try, from most to least specific.
    attempts: list[dict[str, str]] = [kwargs]
    if "token" in kwargs:
        attempts.append({"token": kwargs["token"]})
    attempts.append({})

    signature: inspect.Signature | None
    try:
        signature = inspect.signature(client_factory)
    except (TypeError, ValueError):
        # Some callables (e.g. certain builtins) expose no signature.
        signature = None

    if signature is not None:
        # Pick the first argument set the signature accepts, so TypeErrors
        # raised inside the factory are not mistaken for a signature mismatch.
        for attempt in attempts:
            try:
                signature.bind(**attempt)
            except TypeError:
                continue
            return client_factory(**attempt)
        # Nothing fits: let the factory raise its own TypeError.
        return client_factory(**kwargs)

    # No signature available: fall back to trial calls. Each failure moves on
    # to the next argument set; the final no-argument call is left uncaught
    # so its TypeError reaches the caller.
    for attempt in attempts[:-1]:
        try:
            return client_factory(**attempt)
        except TypeError:
            continue
    return client_factory()
|
|