File size: 1,622 Bytes
f440f03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""Helpers for constructing Hugging Face inference clients in Space runtimes."""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

from maris_core.utils.env import get_env_any, get_env_any_or_default, get_hf_token

# Fallback values used by build_hf_inference_client_kwargs() when none of the
# corresponding environment variables supply an override.
# Default value for the "provider" client kwarg.
HF_INFERENCE_PROVIDER_DEFAULT = "hf-inference"
# Default value for the "base_url" client kwarg.
# NOTE(review): confirm this host is still the endpoint the client should target.
HF_INFERENCE_BASE_URL_DEFAULT = "https://api-inference.huggingface.co"


def build_hf_inference_client_kwargs(*, token: str | None = None) -> dict[str, str]:
    """Assemble keyword arguments for constructing an inference client.

    The result always carries ``"provider"`` and ``"base_url"``; ``"token"``
    is included only when a truthy token could be resolved.

    Args:
        token: Explicit token. When falsy, the ``HF_INFERENCE_API_KEY``
            lookup is tried next and ``get_hf_token()`` last.

    Returns:
        A new dict with keys ``"provider"``, ``"base_url"`` and, optionally,
        ``"token"``.
    """
    # Token resolution: explicit argument, then dedicated env key, then the
    # generic helper. Each later source is consulted only if the earlier one
    # produced a falsy value.
    auth_token = token
    if not auth_token:
        auth_token = get_env_any("HF_INFERENCE_API_KEY")
    if not auth_token:
        auth_token = get_hf_token()

    # Presumably the first populated variable name wins and ``default`` is
    # used when none is set -- confirm against maris_core.utils.env.
    client_kwargs: dict[str, str] = {
        "provider": get_env_any_or_default(
            "MARIS_INFERENCE_PROVIDER",
            "HF_INFERENCE_PROVIDER",
            default=HF_INFERENCE_PROVIDER_DEFAULT,
        ),
        "base_url": get_env_any_or_default(
            "MARIS_INFERENCE_BASE_URL",
            "HF_INFERENCE_API_URL",
            "HUGGINGFACE_INFERENCE_BASE_URL",
            default=HF_INFERENCE_BASE_URL_DEFAULT,
        ),
    }

    # Omit the key entirely rather than passing an empty/None token.
    if not auth_token:
        return client_kwargs
    client_kwargs["token"] = auth_token
    return client_kwargs


def create_hf_inference_client(
    client_factory: Callable[..., Any], *, token: str | None = None
) -> Any:
    """Instantiate an inference client via ``client_factory``.

    The factory is first called with every kwarg produced by
    ``build_hf_inference_client_kwargs``. If that call raises ``TypeError``
    (e.g. an older or simplified factory that rejects some keywords), it is
    retried with only ``token`` when one was resolved, otherwise with no
    arguments at all.

    Args:
        client_factory: Callable that builds and returns the client.
        token: Optional explicit token forwarded to the kwargs builder.

    Returns:
        Whatever ``client_factory`` returns.

    Raises:
        Any exception raised by the fallback call propagates unchanged.
    """
    full_kwargs = build_hf_inference_client_kwargs(token=token)
    try:
        client = client_factory(**full_kwargs)
    except TypeError:
        # Any TypeError from the first attempt triggers the reduced call;
        # only the token (when present) survives into the retry.
        reduced_kwargs = (
            {"token": full_kwargs["token"]} if "token" in full_kwargs else {}
        )
        return client_factory(**reduced_kwargs)
    else:
        return client