test_521 / trackio /remote_client.py
abidlabs's picture
abidlabs HF Staff
Upload folder using huggingface_hub
32a953c verified
from __future__ import annotations
from pathlib import Path
from typing import Any
from urllib.parse import urljoin, urlparse
import httpx
from gradio_client import Client as GradioClient
from huggingface_hub.utils import build_hf_headers
from trackio.utils import parse_trackio_server_url
# API version this client expects; compared against the "api_version" field
# of the JSON returned by the server's "version" endpoint.
HTTP_API_VERSION = 1
# Minimum read timeout applied to "force_sync" requests (passed to
# httpx.Timeout as the `read` value).
FORCE_SYNC_TIMEOUT = 180.0
# Name of the HTTP header used to send the Trackio write token.
WRITE_TOKEN_HEADER = "x-trackio-write-token"
def _normalize_src(src: str) -> str:
return src if src.endswith("/") else src + "/"
def _space_id_to_url(space_id: str) -> str:
namespace, name = space_id.split("/", 1)
subdomain = f"{namespace}-{name}".lower().replace("_", "-").replace(".", "-")
return f"https://{subdomain}.hf.space/"
def _host_is_hf_space(url: str) -> bool:
p = urlparse(url)
h = (p.hostname or "").lower()
return h.endswith(".hf.space")
def _resolve_src_url(src: str) -> str:
if src.startswith(("http://", "https://")):
base, _ = parse_trackio_server_url(src)
return _normalize_src(base)
if "/" in src:
return _space_id_to_url(src)
raise ValueError(
f"Could not resolve Trackio remote source '{src}'. "
"Pass a full Space id like 'user/space' or a URL."
)
def _is_local_file_data(value: Any) -> bool:
return (
isinstance(value, dict)
and "path" in value
and isinstance(value["path"], str)
and value.get("meta", {}).get("_type") == "gradio.FileData"
and Path(value["path"]).exists()
)
def _merge_client_headers(
hf_token: str | None, write_token: str | None
) -> dict[str, str]:
headers: dict[str, str] = {}
if hf_token:
headers.update(build_hf_headers(token=hf_token))
if write_token:
headers[WRITE_TOKEN_HEADER] = write_token
return headers
def _request_timeout_for_api(
timeout: httpx.Timeout | float | int | None, api_name: str
) -> httpx.Timeout | float | int | None:
if api_name != "force_sync":
return timeout
normalized = httpx.Timeout(timeout)
read_timeout = normalized.read if normalized.read is not None else 0.0
if read_timeout >= FORCE_SYNC_TIMEOUT:
return timeout
return httpx.Timeout(
connect=normalized.connect,
read=FORCE_SYNC_TIMEOUT,
write=normalized.write,
pool=normalized.pool,
)
class _TrackioHTTPClient:
    """Client that calls a Trackio server through its HTTP JSON endpoints.

    Calls are POSTed to ``api/<api_name>`` under the resolved base URL with a
    ``{"args": [...], "kwargs": {...}}`` JSON body. Local Gradio FileData
    values found in the arguments are first uploaded via ``api/upload``.
    """

    def __init__(
        self,
        src: str,
        hf_token: str | None = None,
        write_token: str | None = None,
        httpx_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Resolve *src* and prepare per-request httpx options and headers.

        Args:
            src: URL or Space id, resolved with ``_resolve_src_url``.
            hf_token: Optional Hugging Face token for the request headers.
            write_token: Optional Trackio write token for the request headers.
            httpx_kwargs: Extra keyword arguments for every httpx call. A
                ``"headers"`` dict inside it is merged on top of the token
                headers; ``"timeout"`` defaults to 60 when absent.
        """
        self.src = _resolve_src_url(src)

        options = dict(httpx_kwargs) if httpx_kwargs else {}
        if "timeout" not in options:
            options["timeout"] = 60
        self.httpx_kwargs = options
        # Headers are passed explicitly on each request, so they must not
        # also remain inside the keyword arguments.
        caller_headers = options.pop("headers", None)

        merged = _merge_client_headers(hf_token, write_token)
        if isinstance(caller_headers, dict):
            stringified = {str(key): str(val) for key, val in caller_headers.items()}
            merged.update(stringified)
        self.headers = merged

    def _upload_file(self, file_data: dict[str, Any]) -> dict[str, Any]:
        """Upload the file named by ``file_data["path"]`` and return updated FileData.

        Returns:
            A copy of *file_data* whose ``"path"`` is the first entry of the
            server's ``"paths"`` response and whose ``"orig_name"`` falls back
            to the local file name when not already present.

        Raises:
            httpx.HTTPStatusError: If the upload response has an error status.
        """
        local_path = Path(file_data["path"])
        with local_path.open("rb") as handle:
            response = httpx.post(
                urljoin(self.src, "api/upload"),
                headers=self.headers,
                files={"files": (local_path.name, handle)},
                **self.httpx_kwargs,
            )
        response.raise_for_status()
        remote_path = response.json()["paths"][0]

        updated = dict(file_data)
        updated["path"] = remote_path
        updated["orig_name"] = file_data.get("orig_name", local_path.name)
        return updated

    def _prepare_value(self, value: Any) -> Any:
        """Recursively replace local FileData values with their uploaded form.

        Lists and tuples both come back as lists; dicts keep their keys; any
        other value is returned as-is.
        """
        if _is_local_file_data(value):
            return self._upload_file(value)
        if isinstance(value, (list, tuple)):
            return [self._prepare_value(element) for element in value]
        if isinstance(value, dict):
            return {
                key: self._prepare_value(element) for key, element in value.items()
            }
        return value

    def predict(self, *args, api_name: str, **kwargs) -> Any:
        """Call the remote endpoint *api_name* and return its ``"data"`` field.

        Args:
            *args: Positional arguments forwarded in the JSON body.
            api_name: Endpoint name; leading slashes are stripped.
            **kwargs: Keyword arguments forwarded in the JSON body.

        Raises:
            RuntimeError: If the server answers 404 or the response body
                carries a non-null ``"error"``.
            httpx.HTTPStatusError: For any other error status.
        """
        endpoint = api_name.lstrip("/")
        request_body = {
            "args": self._prepare_value(list(args)),
            "kwargs": self._prepare_value(kwargs),
        }

        options = {**self.httpx_kwargs}
        options["timeout"] = _request_timeout_for_api(
            options.get("timeout"), endpoint
        )

        response = httpx.post(
            urljoin(self.src, f"api/{endpoint}"),
            headers=self.headers,
            json=request_body,
            **options,
        )
        if response.status_code == 404:
            raise RuntimeError(
                f"Space '{self.src}' does not support '/{endpoint}'. Redeploy with `trackio sync`."
            )
        response.raise_for_status()

        result = response.json()
        error = result.get("error")
        if error is not None:
            raise RuntimeError(error)
        return result.get("data")
class _TrackioGradioCompatClient:
def __init__(
self,
src: str,
hf_token: str | None = None,
write_token: str | None = None,
httpx_kwargs: dict[str, Any] | None = None,
verbose: bool = False,
) -> None:
kwargs: dict[str, Any] = {"verbose": verbose}
if hf_token:
kwargs["hf_token"] = hf_token
merged = dict(httpx_kwargs or {})
h = _merge_client_headers(
hf_token if hf_token else None,
write_token,
)
extra = merged.pop("headers", None)
if isinstance(extra, dict):
h.update({str(k): str(v) for k, v in extra.items()})
if h:
merged["headers"] = h
if merged:
kwargs["httpx_kwargs"] = merged
self._client = GradioClient(src, **kwargs)
def predict(self, *args, api_name: str, **kwargs) -> Any:
try:
return self._client.predict(*args, api_name=api_name, **kwargs)
except Exception as e:
if "API Not Found" in str(e) or "api_name" in str(e):
raise RuntimeError(
f"Space '{self._client.src}' does not support '{api_name}'. "
"Redeploy with `trackio sync`."
) from e
raise
def _supports_http_api(
    src: str,
    hf_token: str | None = None,
    write_token: str | None = None,
    httpx_kwargs: dict[str, Any] | None = None,
) -> bool:
    """Probe whether the server at *src* speaks the expected HTTP API version.

    Sends ``GET <base>/version`` and checks that the JSON response's
    ``"api_version"`` equals ``HTTP_API_VERSION``.

    Args:
        src: URL or Space id, resolved with ``_resolve_src_url``.
        hf_token: Optional Hugging Face token for the request headers.
        write_token: Optional Trackio write token for the request headers.
        httpx_kwargs: Extra httpx options; a ``"headers"`` dict inside it is
            merged on top of the token headers. ``"timeout"`` defaults to 10.

    Returns:
        True on a successful response with a matching version; False on any
        request/parse failure, non-success status, or version mismatch.

    Raises:
        ValueError: If *src* cannot be resolved (raised before the probe).
    """
    url = _resolve_src_url(src)
    headers = _merge_client_headers(hf_token, write_token)
    kwargs = dict(httpx_kwargs or {})
    kwargs.setdefault("timeout", 10)
    # Caller-supplied headers must be merged rather than left in `kwargs`:
    # passing `headers=` twice raises TypeError, which the best-effort
    # `except` below would swallow, making the probe always report False.
    extra = kwargs.pop("headers", None)
    if isinstance(extra, dict):
        headers.update({str(k): str(v) for k, v in extra.items()})
    try:
        resp = httpx.get(urljoin(url, "version"), headers=headers, **kwargs)
        if not resp.is_success:
            return False
        data = resp.json()
        return data.get("api_version") == HTTP_API_VERSION
    except Exception:
        # Deliberate best-effort probe: any failure means "no HTTP API".
        return False
class RemoteClient:
    """Entry point that picks the HTTP client or the Gradio fallback for a remote."""

    def __init__(
        self,
        space: str,
        hf_token: str | None = None,
        write_token: str | None = None,
        httpx_kwargs: dict[str, Any] | None = None,
        verbose: bool = False,
    ) -> None:
        """Connect to *space*, choosing the transport by probing the server.

        Args:
            space: Space id or ``http(s)://`` URL. For URLs, the second value
                returned by ``parse_trackio_server_url`` is used as the write
                token when *write_token* is None.
            hf_token: Hugging Face token; dropped when *space* is a URL whose
                host is not under ``.hf.space``.
            write_token: Trackio write token.
            httpx_kwargs: Extra httpx options forwarded to the chosen client.
            verbose: Forwarded to the Gradio fallback client only.

        Raises:
            ValueError: Propagated unchanged from client construction.
            ConnectionError: For any other failure while probing/connecting.
        """
        self._space = space

        target = space
        effective_hf = hf_token
        effective_write = write_token
        if space.startswith(("http://", "https://")):
            target, token_from_url = parse_trackio_server_url(space)
            if effective_write is None:
                effective_write = token_from_url
            # The Hugging Face token is only sent to *.hf.space hosts.
            if not _host_is_hf_space(_normalize_src(target)):
                effective_hf = None

        shared = {
            "hf_token": effective_hf,
            "write_token": effective_write,
            "httpx_kwargs": httpx_kwargs,
        }
        try:
            if _supports_http_api(target, **shared):
                self._client = _TrackioHTTPClient(target, **shared)
            else:
                self._client = _TrackioGradioCompatClient(
                    target, verbose=verbose, **shared
                )
        except ValueError:
            raise
        except Exception as exc:
            raise ConnectionError(
                f"Could not connect to Space '{space}'. Is it running?\n{exc}"
            ) from exc

    def predict(self, *args, api_name: str, **kwargs) -> Any:
        """Forward the call to whichever client was selected at construction."""
        return self._client.predict(*args, api_name=api_name, **kwargs)