Spaces:
Running
Running
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Any | |
| from urllib.parse import urljoin, urlparse | |
| import httpx | |
| from gradio_client import Client as GradioClient | |
| from huggingface_hub.utils import build_hf_headers | |
| from trackio.utils import parse_trackio_server_url | |
# Value the server's ``/version`` endpoint must report as ``api_version`` for
# the JSON-over-HTTP client to be used.
HTTP_API_VERSION: int = 1
# Minimum read timeout, in seconds, applied to ``force_sync`` requests.
FORCE_SYNC_TIMEOUT: float = 180.0
# Request header that carries the Trackio write token.
WRITE_TOKEN_HEADER: str = "x-trackio-write-token"
| def _normalize_src(src: str) -> str: | |
| return src if src.endswith("/") else src + "/" | |
| def _space_id_to_url(space_id: str) -> str: | |
| namespace, name = space_id.split("/", 1) | |
| subdomain = f"{namespace}-{name}".lower().replace("_", "-").replace(".", "-") | |
| return f"https://{subdomain}.hf.space/" | |
| def _host_is_hf_space(url: str) -> bool: | |
| p = urlparse(url) | |
| h = (p.hostname or "").lower() | |
| return h.endswith(".hf.space") | |
| def _resolve_src_url(src: str) -> str: | |
| if src.startswith(("http://", "https://")): | |
| base, _ = parse_trackio_server_url(src) | |
| return _normalize_src(base) | |
| if "/" in src: | |
| return _space_id_to_url(src) | |
| raise ValueError( | |
| f"Could not resolve Trackio remote source '{src}'. " | |
| "Pass a full Space id like 'user/space' or a URL." | |
| ) | |
| def _is_local_file_data(value: Any) -> bool: | |
| return ( | |
| isinstance(value, dict) | |
| and "path" in value | |
| and isinstance(value["path"], str) | |
| and value.get("meta", {}).get("_type") == "gradio.FileData" | |
| and Path(value["path"]).exists() | |
| ) | |
| def _merge_client_headers( | |
| hf_token: str | None, write_token: str | None | |
| ) -> dict[str, str]: | |
| headers: dict[str, str] = {} | |
| if hf_token: | |
| headers.update(build_hf_headers(token=hf_token)) | |
| if write_token: | |
| headers[WRITE_TOKEN_HEADER] = write_token | |
| return headers | |
| def _request_timeout_for_api( | |
| timeout: httpx.Timeout | float | int | None, api_name: str | |
| ) -> httpx.Timeout | float | int | None: | |
| if api_name != "force_sync": | |
| return timeout | |
| normalized = httpx.Timeout(timeout) | |
| read_timeout = normalized.read if normalized.read is not None else 0.0 | |
| if read_timeout >= FORCE_SYNC_TIMEOUT: | |
| return timeout | |
| return httpx.Timeout( | |
| connect=normalized.connect, | |
| read=FORCE_SYNC_TIMEOUT, | |
| write=normalized.write, | |
| pool=normalized.pool, | |
| ) | |
class _TrackioHTTPClient:
    """Client for Trackio servers that expose the JSON-over-HTTP API.

    Each call POSTs ``{"args": [...], "kwargs": {...}}`` to ``<src>api/<name>``
    and returns the ``data`` field of the JSON reply. Local
    ``gradio.FileData`` values in the payload are first uploaded to
    ``<src>api/upload`` and replaced by the server-side path.
    """

    def __init__(
        self,
        src: str,
        hf_token: str | None = None,
        write_token: str | None = None,
        httpx_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Resolve ``src`` and prepare the headers and httpx options for every request.

        ``httpx_kwargs`` is copied and given a default ``timeout`` of 60. A
        ``headers`` dict inside it is removed and merged (stringified) over the
        token-derived headers; a non-dict ``headers`` value is ignored.
        """
        self.src = _resolve_src_url(src)
        options = dict(httpx_kwargs or {})
        options.setdefault("timeout", 60)
        user_headers = options.pop("headers", None)
        merged = _merge_client_headers(hf_token, write_token)
        if isinstance(user_headers, dict):
            for key, val in user_headers.items():
                merged[str(key)] = str(val)
        self.httpx_kwargs = options
        self.headers = merged

    def _upload_file(self, file_data: dict[str, Any]) -> dict[str, Any]:
        """Upload the local file named by ``file_data["path"]`` and return a rewritten copy.

        The copy keeps every original key, but ``path`` becomes the first entry
        of the server's ``paths`` reply. ``orig_name`` defaults to the local
        file name when absent. Raises ``httpx.HTTPStatusError`` on a non-2xx
        upload response.
        """
        local = Path(file_data["path"])
        with local.open("rb") as handle:
            response = httpx.post(
                urljoin(self.src, "api/upload"),
                headers=self.headers,
                files={"files": (local.name, handle)},
                **self.httpx_kwargs,
            )
        response.raise_for_status()
        remote_path = response.json()["paths"][0]
        rewritten = dict(file_data)
        rewritten["path"] = remote_path
        rewritten["orig_name"] = file_data.get("orig_name", local.name)
        return rewritten

    def _prepare_value(self, value: Any) -> Any:
        """Recursively make ``value`` JSON-ready, uploading any local file data.

        Lists and tuples both become lists, and dicts are rebuilt. Other values
        are returned unchanged.
        """
        if _is_local_file_data(value):
            return self._upload_file(value)
        if isinstance(value, (list, tuple)):
            return [self._prepare_value(item) for item in value]
        if isinstance(value, dict):
            return {key: self._prepare_value(item) for key, item in value.items()}
        return value

    def predict(self, *args, api_name: str, **kwargs) -> Any:
        """Call endpoint ``api_name`` (a leading ``/`` is optional) and return its ``data``.

        Raises:
            RuntimeError: the server returns 404 for the endpoint, or the
                reply's ``error`` field is not None.
            httpx.HTTPStatusError: any other non-2xx response.
        """
        endpoint = api_name.lstrip("/")
        payload = {
            "args": self._prepare_value(list(args)),
            "kwargs": self._prepare_value(kwargs),
        }
        options = dict(self.httpx_kwargs)
        options["timeout"] = _request_timeout_for_api(options.get("timeout"), endpoint)
        response = httpx.post(
            urljoin(self.src, f"api/{endpoint}"),
            headers=self.headers,
            json=payload,
            **options,
        )
        if response.status_code == 404:
            raise RuntimeError(
                f"Space '{self.src}' does not support '/{endpoint}'. Redeploy with `trackio sync`."
            )
        response.raise_for_status()
        body = response.json()
        error = body.get("error")
        if error is not None:
            raise RuntimeError(error)
        return body.get("data")
| class _TrackioGradioCompatClient: | |
| def __init__( | |
| self, | |
| src: str, | |
| hf_token: str | None = None, | |
| write_token: str | None = None, | |
| httpx_kwargs: dict[str, Any] | None = None, | |
| verbose: bool = False, | |
| ) -> None: | |
| kwargs: dict[str, Any] = {"verbose": verbose} | |
| if hf_token: | |
| kwargs["hf_token"] = hf_token | |
| merged = dict(httpx_kwargs or {}) | |
| h = _merge_client_headers( | |
| hf_token if hf_token else None, | |
| write_token, | |
| ) | |
| extra = merged.pop("headers", None) | |
| if isinstance(extra, dict): | |
| h.update({str(k): str(v) for k, v in extra.items()}) | |
| if h: | |
| merged["headers"] = h | |
| if merged: | |
| kwargs["httpx_kwargs"] = merged | |
| self._client = GradioClient(src, **kwargs) | |
| def predict(self, *args, api_name: str, **kwargs) -> Any: | |
| try: | |
| return self._client.predict(*args, api_name=api_name, **kwargs) | |
| except Exception as e: | |
| if "API Not Found" in str(e) or "api_name" in str(e): | |
| raise RuntimeError( | |
| f"Space '{self._client.src}' does not support '{api_name}'. " | |
| "Redeploy with `trackio sync`." | |
| ) from e | |
| raise | |
def _supports_http_api(
    src: str,
    hf_token: str | None = None,
    write_token: str | None = None,
    httpx_kwargs: dict[str, Any] | None = None,
) -> bool:
    """Probe ``<src>version`` to check whether the server speaks Trackio's HTTP API.

    Returns True only when the endpoint answers with a 2xx JSON object whose
    ``api_version`` equals ``HTTP_API_VERSION``. Any request, decoding or
    shape failure returns False so the caller can fall back. The probe uses a
    default ``timeout`` of 10 unless ``httpx_kwargs`` sets one.

    Raises:
        ValueError: from ``_resolve_src_url`` when ``src`` cannot be resolved.
    """
    url = _resolve_src_url(src)
    headers = _merge_client_headers(hf_token, write_token)
    kwargs = dict(httpx_kwargs or {})
    kwargs.setdefault("timeout", 10)
    # Merge user headers the same way the clients do. Leaving them in kwargs
    # would collide with the explicit ``headers=`` below, and the resulting
    # TypeError would be swallowed, making the probe always return False.
    extra = kwargs.pop("headers", None)
    if isinstance(extra, dict):
        headers.update({str(k): str(v) for k, v in extra.items()})
    try:
        resp = httpx.get(urljoin(url, "version"), headers=headers, **kwargs)
        if not resp.is_success:
            return False
        data = resp.json()
        return isinstance(data, dict) and data.get("api_version") == HTTP_API_VERSION
    except Exception:
        # Best-effort probe: any failure means "use the Gradio fallback".
        return False
class RemoteClient:
    """Call a remote Trackio server or Hugging Face Space.

    On construction the server is probed. If it reports a matching
    ``HTTP_API_VERSION``, the JSON-over-HTTP client is used; otherwise the
    ``gradio_client`` compatibility client is used.
    """

    def __init__(
        self,
        space: str,
        hf_token: str | None = None,
        write_token: str | None = None,
        httpx_kwargs: dict[str, Any] | None = None,
        verbose: bool = False,
    ) -> None:
        """Pick and construct the underlying client for ``space``.

        Args:
            space: Space id (``user/space``) or ``http(s)://`` server URL.
            hf_token: Hugging Face token. It is dropped when ``space`` is a URL
                whose host is not under ``.hf.space``.
            write_token: Trackio write token. When omitted and ``space`` is a
                URL, the token parsed from the URL is used instead.
            httpx_kwargs: Extra httpx options passed to the probe and client.
            verbose: Forwarded to the Gradio compatibility client only.

        Raises:
            ValueError: re-raised unchanged (e.g. an unresolvable ``space``).
            ConnectionError: wraps any other failure while probing or
                constructing the client.
        """
        self._space = space
        target = space
        auth_token = hf_token
        trackio_token = write_token
        if space.startswith(("http://", "https://")):
            base_url, url_token = parse_trackio_server_url(space)
            target = base_url
            if trackio_token is None:
                trackio_token = url_token
            # Only *.hf.space hosts receive the HF token; presumably to keep it
            # away from self-hosted servers.
            if not _host_is_hf_space(_normalize_src(base_url)):
                auth_token = None
        shared = {
            "hf_token": auth_token,
            "write_token": trackio_token,
            "httpx_kwargs": httpx_kwargs,
        }
        try:
            if _supports_http_api(target, **shared):
                self._client = _TrackioHTTPClient(target, **shared)
            else:
                self._client = _TrackioGradioCompatClient(target, verbose=verbose, **shared)
        except ValueError:
            raise
        except Exception as e:
            raise ConnectionError(
                f"Could not connect to Space '{space}'. Is it running?\n{e}"
            ) from e

    def predict(self, *args, api_name: str, **kwargs) -> Any:
        """Delegate the call to the selected underlying client and return its result."""
        return self._client.predict(*args, api_name=api_name, **kwargs)