File size: 4,965 Bytes
cc49c15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14c441f
 
 
 
 
 
 
 
 
 
 
 
 
cc49c15
 
 
 
14c441f
cc49c15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14c441f
cc49c15
 
 
 
 
 
 
 
 
 
14c441f
cc49c15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14c441f
cc49c15
 
 
 
 
 
 
 
 
 
14c441f
 
 
cc49c15
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Any
from urllib.error import HTTPError, URLError
from urllib.parse import urlencode
from urllib.request import Request, urlopen

# Fallback cap on list responses, used when HF_MAX_RESULTS is unset or invalid.
DEFAULT_MAX_RESULTS = 20
# Timeout, in seconds, passed to urlopen for each API request.
DEFAULT_TIMEOUT_SEC = 30


def _load_token() -> str | None:
    """Resolve the Hugging Face access token to attach to a request.

    Sources are consulted in order and the first non-empty value wins:

    1. The request-scoped bearer token exposed by
       ``fast_agent.mcp.auth.context.request_bearer_token`` (skipped when that
       module cannot be imported).
    2. The ``HF_TOKEN`` environment variable.
    3. The token file at ``~/.cache/huggingface/token``.

    Values read from the environment and from the file are stripped of
    surrounding whitespace; a value that is empty after stripping counts as
    missing.

    Returns:
        The token string, or ``None`` when no source yields a usable token.
    """
    # Check for request-scoped token first (when running as MCP server)
    # This allows clients to pass their own HF token via Authorization header
    try:
        from fast_agent.mcp.auth.context import request_bearer_token
    except ImportError:
        # fast_agent.mcp.auth.context not available
        pass
    else:
        ctx_token = request_bearer_token.get()
        if ctx_token:
            return ctx_token

    # Fall back to HF_TOKEN environment variable. Strip it so a stray newline
    # or padding never ends up inside the Authorization header value.
    env_token = os.getenv("HF_TOKEN", "").strip()
    if env_token:
        return env_token

    # Fall back to cached huggingface token file. Read it directly instead of
    # checking exists() first: a missing, unreadable, or non-file path (or an
    # undeterminable home directory) simply means "no token available".
    try:
        token_path = Path.home() / ".cache" / "huggingface" / "token"
        file_token = token_path.read_text(encoding="utf-8").strip()
    except (OSError, RuntimeError, UnicodeDecodeError):
        return None
    return file_token or None


def _max_results_from_env() -> int:
    """Return the list-size limit configured through ``HF_MAX_RESULTS``.

    ``DEFAULT_MAX_RESULTS`` is returned instead when the variable is unset,
    empty, not parseable as an integer, or not strictly positive.
    """
    configured = os.environ.get("HF_MAX_RESULTS")
    if configured:
        try:
            parsed = int(configured)
        except ValueError:
            # Unparseable text is handled like a non-positive number below.
            parsed = 0
        if parsed > 0:
            return parsed
    return DEFAULT_MAX_RESULTS


def _normalize_endpoint(endpoint: str) -> str:
    """Validate an API endpoint path and return it with a leading slash.

    Args:
        endpoint: Path relative to ``/api``, with or without a leading ``/``.
            Surrounding whitespace is ignored.

    Returns:
        The stripped endpoint, guaranteed to start with ``/``.

    Raises:
        ValueError: If the endpoint is empty or whitespace-only, or if it is
            a full ``http://`` / ``https://`` URL (in any letter case).
    """
    # Strip before validating so leading whitespace cannot hide a full URL
    # from the scheme check below.
    endpoint = endpoint.strip()
    if not endpoint:
        raise ValueError("Endpoint must be a non-empty string.")
    # URL schemes are case-insensitive, so compare on a lowered copy.
    if endpoint.lower().startswith(("http://", "https://")):
        raise ValueError("Endpoint must be a path relative to /api, not a full URL.")
    if not endpoint.startswith("/"):
        endpoint = f"/{endpoint}"
    return endpoint


def _normalize_params(params: dict[str, Any] | None) -> dict[str, Any]:
    """Convert query parameters into a string-only mapping.

    Entries whose value is ``None`` are dropped. List and tuple values become
    lists of strings; every other value is converted with ``str()``. A missing
    or empty *params* yields an empty dict.
    """

    def _as_text(value: Any) -> Any:
        # Sequences stay multi-valued; everything else becomes a single string.
        if isinstance(value, (list, tuple)):
            return [str(item) for item in value]
        return str(value)

    return {
        name: _as_text(value)
        for name, value in (params or {}).items()
        if value is not None
    }


def _build_url(endpoint: str, params: dict[str, Any] | None) -> str:
    """Assemble the absolute API URL for *endpoint* plus optional query string.

    The host comes from the ``HF_ENDPOINT`` environment variable (default
    ``https://huggingface.co``) with trailing slashes removed, followed by
    ``/api`` and the normalized endpoint. Normalized *params*, when non-empty,
    are URL-encoded with repeated keys for sequence values.
    """
    host = os.getenv("HF_ENDPOINT", "https://huggingface.co").rstrip("/")
    path = _normalize_endpoint(endpoint)
    target = f"{host}/api{path}"
    query = _normalize_params(params)
    if not query:
        return target
    return f"{target}?{urlencode(query, doseq=True)}"


def hf_api_request(
    endpoint: str,
    method: str = "GET",
    params: dict[str, Any] | None = None,
    json_body: dict[str, Any] | None = None,
    max_results: int | None = None,
    offset: int | None = None,
) -> dict[str, Any]:
    """
    Call the Hugging Face Hub API (GET/POST only).

    Args:
        endpoint: API endpoint relative to /api (e.g. "/whoami-v2").
        method: HTTP method (GET or POST), case-insensitive.
        params: Optional query parameters.
        json_body: Optional JSON payload for POST requests.
        max_results: Max results when response is a list (defaults to HF_MAX_RESULTS).
        offset: Client-side offset when response is a list (defaults to 0).

    Returns:
        A dict with keys ``"url"`` (the requested URL), ``"status"`` (the HTTP
        status code) and ``"data"`` (the decoded JSON payload, or the body as
        text when it is not valid JSON). List payloads are sliced client-side
        according to ``offset`` and ``max_results``.

    Raises:
        ValueError: If the method is not GET or POST, or a GET request is
            given a ``json_body``.
        RuntimeError: If the server answers with an HTTP error status or the
            request fails at the transport level (unreachable host, timeout,
            connection reset).
    """
    # Validate arguments before touching the environment or the network.
    method_upper = method.upper()
    if method_upper not in {"GET", "POST"}:
        raise ValueError("Only GET and POST are allowed for hf_api_request.")

    if method_upper == "GET" and json_body is not None:
        raise ValueError("GET requests do not accept json_body.")

    url = _build_url(endpoint, params)

    headers = {
        "Accept": "application/json",
    }
    token = _load_token()
    if token:
        headers["Authorization"] = f"Bearer {token}"

    data = None
    if method_upper == "POST":
        headers["Content-Type"] = "application/json"
        # A POST without an explicit body still sends an empty JSON object.
        data = json.dumps(json_body or {}).encode("utf-8")

    request = Request(url, headers=headers, data=data, method=method_upper)

    try:
        with urlopen(request, timeout=DEFAULT_TIMEOUT_SEC) as response:
            raw = response.read()
            status_code = response.status
    except HTTPError as exc:
        # HTTPError must be handled first: it is a subclass of URLError.
        error_body = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HF API error {exc.code} for {url}: {error_body}") from exc
    except OSError as exc:
        # URLError is an OSError, and so are timeouts / connection resets that
        # surface while reading the body without being wrapped in URLError.
        raise RuntimeError(f"HF API request failed for {url}: {exc}") from exc

    try:
        payload = json.loads(raw)
    except ValueError:
        # Covers JSONDecodeError and the UnicodeDecodeError that json.loads
        # raises for bytes in an undecodable encoding; fall back to text.
        payload = raw.decode("utf-8", errors="replace")

    if isinstance(payload, list):
        limit = max_results if max_results is not None else _max_results_from_env()
        # Negative offsets and limits are clamped to zero rather than being
        # interpreted as from-the-end slice indices.
        start = max(offset or 0, 0)
        end = start + max(limit, 0)
        payload = payload[start:end]

    return {
        "url": url,
        "status": status_code,
        "data": payload,
    }