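"""Helpers for constructing Hugging Face inference clients in Space runtimes."""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

from maris_core.utils.env import get_env_any, get_env_any_or_default, get_hf_token

HF_INFERENCE_PROVIDER_DEFAULT = "hf-inference"
HF_INFERENCE_BASE_URL_DEFAULT = "https://api-inference.huggingface.co"


def build_hf_inference_client_kwargs(*, token: str | None = None) -> dict[str, str]:
    """Return kwargs that keep Spaces pinned to the public HF inference endpoint.

    The token is resolved in order: the explicit ``token`` argument, then the
    ``HF_INFERENCE_API_KEY`` environment variable, then ``get_hf_token()``.
    ``provider`` and ``base_url`` come from the listed environment variables,
    falling back to the module defaults. ``token`` is omitted from the result
    when none could be resolved.
    """
    resolved_token = token or get_env_any("HF_INFERENCE_API_KEY") or get_hf_token()

    # Environment overrides for the provider, checked in the order listed.
    provider = get_env_any_or_default(
        "MARIS_INFERENCE_PROVIDER",
        "HF_INFERENCE_PROVIDER",
        default=HF_INFERENCE_PROVIDER_DEFAULT,
    )

    # Environment overrides for the endpoint URL, checked in the order listed.
    base_url = get_env_any_or_default(
        "MARIS_INFERENCE_BASE_URL",
        "HF_INFERENCE_API_URL",
        "HUGGINGFACE_INFERENCE_BASE_URL",
        default=HF_INFERENCE_BASE_URL_DEFAULT,
    )

    kwargs: dict[str, str] = {
        "provider": provider,
        "base_url": base_url,
    }
    if resolved_token:
        kwargs["token"] = resolved_token
    return kwargs


def create_hf_inference_client(
    client_factory: Callable[..., Any], *, token: str | None = None
) -> Any:
    """Create an inference client, falling back for older or simple test doubles.

    If ``client_factory`` rejects the full kwargs with ``TypeError`` (for example,
    because it does not accept ``provider`` or ``base_url``), retry with only the
    token when one was resolved, otherwise with no arguments.
    """
    kwargs = build_hf_inference_client_kwargs(token=token)
    try:
        return client_factory(**kwargs)
    except TypeError:
        # Note: a TypeError raised inside the factory itself also lands here.
        if "token" in kwargs:
            return client_factory(token=kwargs["token"])
        return client_factory()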