MarisUK's picture
Maris AI model sync
f440f03 verified
"""Helpers for constructing Hugging Face inference clients in Space runtimes."""
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from maris_core.utils.env import get_env_any, get_env_any_or_default, get_hf_token
# Fallback provider name, passed as ``default=`` to the provider lookup over
# MARIS_INFERENCE_PROVIDER / HF_INFERENCE_PROVIDER in
# build_hf_inference_client_kwargs.
HF_INFERENCE_PROVIDER_DEFAULT = "hf-inference"
# Fallback base URL, passed as ``default=`` to the base-URL lookup over
# MARIS_INFERENCE_BASE_URL / HF_INFERENCE_API_URL /
# HUGGINGFACE_INFERENCE_BASE_URL in build_hf_inference_client_kwargs.
# NOTE(review): confirm this legacy endpoint is still the intended target for
# the installed huggingface_hub version.
HF_INFERENCE_BASE_URL_DEFAULT = "https://api-inference.huggingface.co"
def build_hf_inference_client_kwargs(*, token: str | None = None) -> dict[str, str]:
    """Assemble keyword arguments for constructing a Hugging Face inference client.

    Provider and base URL are resolved through ``get_env_any_or_default`` over
    the environment-variable names listed below, with
    ``HF_INFERENCE_PROVIDER_DEFAULT`` / ``HF_INFERENCE_BASE_URL_DEFAULT`` as the
    fallbacks.

    Args:
        token: Explicit API token. When falsy, ``HF_INFERENCE_API_KEY`` is
            consulted via ``get_env_any``, then ``get_hf_token()``.

    Returns:
        A dict that always has ``"provider"`` and ``"base_url"``; ``"token"``
        is present only when a truthy token was resolved.
    """
    api_token = token
    if not api_token:
        # No usable explicit token: try the dedicated env var, then the
        # generic HF token helper.
        api_token = get_env_any("HF_INFERENCE_API_KEY") or get_hf_token()

    provider_name = get_env_any_or_default(
        "MARIS_INFERENCE_PROVIDER",
        "HF_INFERENCE_PROVIDER",
        default=HF_INFERENCE_PROVIDER_DEFAULT,
    )
    endpoint = get_env_any_or_default(
        "MARIS_INFERENCE_BASE_URL",
        "HF_INFERENCE_API_URL",
        "HUGGINGFACE_INFERENCE_BASE_URL",
        default=HF_INFERENCE_BASE_URL_DEFAULT,
    )

    client_kwargs: dict[str, str] = {"provider": provider_name, "base_url": endpoint}
    if not api_token:
        # Omit the key entirely rather than passing an empty/None token.
        return client_kwargs
    client_kwargs["token"] = api_token
    return client_kwargs
def create_hf_inference_client(
    client_factory: Callable[..., Any], *, token: str | None = None
) -> Any:
    """Create an inference client, degrading gracefully for simpler factories.

    The factory is first called with the full kwargs from
    ``build_hf_inference_client_kwargs`` (``provider``, ``base_url`` and, when
    one was resolved, ``token``). If that raises ``TypeError`` (e.g. an older
    client or a test double that does not accept those keywords), it is
    retried with ``token`` only (when a token was resolved), and finally with
    no arguments at all. The factory may therefore be invoked up to three
    times.

    Args:
        client_factory: Callable that builds the client (a client class or a
            test double).
        token: Explicit API token, forwarded to
            ``build_hf_inference_client_kwargs``.

    Returns:
        Whatever ``client_factory`` returns.

    Raises:
        TypeError: If the factory rejects every call shape; the error from the
            final no-argument attempt propagates.
    """
    kwargs = build_hf_inference_client_kwargs(token=token)
    try:
        return client_factory(**kwargs)
    except TypeError:
        # NOTE: this also catches TypeErrors raised *inside* the factory, not
        # only signature mismatches; those are retried with fewer arguments.
        pass
    if "token" in kwargs:
        try:
            return client_factory(token=kwargs["token"])
        except TypeError:
            # Factory does not accept ``token`` either (e.g. a zero-arg test
            # double while HF_TOKEN is set); fall through to a bare call.
            pass
    return client_factory()