| from __future__ import annotations |
|
|
| import argparse |
| import base64 |
| import json |
| import os |
| import urllib.parse |
| from typing import Any, Dict, List, Tuple |
|
|
| import requests |
|
|
|
|
def _image_data_from_http_url(url: str, *, timeout: float = 60) -> Dict[str, Any]:
    """
    Download an image and wrap it as a Gradio ``FileData`` payload.

    Gradio's backend image preprocessing accepts ``url`` only for base64 data
    URLs, so the bytes are fetched client-side and sent as a
    ``data:<mime>;base64,...`` URL.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        ``{"url": "data:<mime>;base64,<payload>", "meta": {"_type": "gradio.FileData"}}``.

    Raises:
        requests.HTTPError: If the download returns a 4xx/5xx status.
    """
    import mimetypes

    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()

    # MIME types are case-insensitive; drop parameters such as "; charset=...".
    mime = (resp.headers.get("content-type") or "").split(";", 1)[0].strip().lower()
    if not mime or mime in ("application/octet-stream", "binary/octet-stream"):
        # Missing or generic header: guess from the URL path's extension, and
        # keep the previous JPEG default when that guess is not an image type.
        guessed, _ = mimetypes.guess_type(urllib.parse.urlsplit(url).path)
        mime = guessed if guessed and guessed.startswith("image/") else "image/jpeg"

    b64 = base64.b64encode(resp.content).decode("ascii")
    return {"url": f"data:{mime};base64,{b64}", "meta": {"_type": "gradio.FileData"}}
|
|
|
|
def _call_gradio_call_api(
    *,
    base_url: str,
    api_name: str,
    data: List[Any],
    headers: Dict[str, str] | None = None,
    timeout_s: int = 600,
) -> Tuple[Any, str]:
    """
    Call a Gradio Space endpoint and return its final result.

    Uses Gradio's two-step HTTP API:
        POST {base}/gradio_api/call/{api_name}            -> {"event_id": "..."}
        GET  {base}/gradio_api/call/{api_name}/{event_id}    (SSE stream)

    Args:
        base_url: Space base URL without a trailing slash.
        api_name: Endpoint name without a leading slash.
        data: Positional inputs for the endpoint, sent as ``{"data": data}``.
        headers: Extra HTTP headers (e.g. ``Authorization``) for both requests.
        timeout_s: Overall deadline, in seconds, for receiving the result. It is
            also used as the per-read socket timeout of the SSE request.

    Returns:
        ``(decoded payload of the "complete" event, event_id)``.

    Raises:
        requests.HTTPError: If either request returns a 4xx/5xx status.
        RuntimeError: If the stream reports an ``error`` event.
        TimeoutError: If no result arrives within ``timeout_s``, or the stream
            ends without a ``complete`` event.
    """
    import time

    headers = headers or {}

    join = requests.post(
        f"{base_url}/gradio_api/call/{api_name}",
        json={"data": data},
        headers=headers,
        timeout=60,
    )
    join.raise_for_status()
    event_id = join.json()["event_id"]

    deadline = time.monotonic() + timeout_s
    timed_out = False
    current_event: str | None = None
    with requests.get(
        f"{base_url}/gradio_api/call/{api_name}/{event_id}",
        headers=headers,
        stream=True,
        timeout=timeout_s,
    ) as r:
        r.raise_for_status()
        # Server-sent event streams are UTF-8 by specification; don't let
        # requests guess another text encoding.
        r.encoding = "utf-8"
        for line in r.iter_lines(decode_unicode=True):
            # Heartbeat events keep the socket busy, so the per-read timeout
            # alone never bounds the total wait; enforce the deadline here.
            if time.monotonic() > deadline:
                timed_out = True
                break
            if not line:
                continue

            field, _, value = line.partition(":")
            value = value.strip()
            if field == "event":
                current_event = value
                continue
            if field != "data":
                continue

            # Check for errors BEFORE skipping empty payloads: Gradio sends
            # `event: error` with `data: null` unless the app sets show_error.
            if current_event == "error":
                if value in ("", "null"):
                    raise RuntimeError(f"Gradio reported an error (event_id={event_id}).")
                raise RuntimeError(value)
            if value in ("", "null", "[DONE]"):
                continue
            if current_event == "complete":
                return json.loads(value), event_id

    # TimeoutError is kept for both cases so existing callers' handlers still apply.
    if timed_out:
        raise TimeoutError(f"Timed out waiting for Gradio result (event_id={event_id}).")
    raise TimeoutError(f"Gradio stream ended without a result (event_id={event_id}).")
|
|
|
|
def main(argv: List[str] | None = None) -> int:
    """
    Send one image URL and a text prompt to a Gradio Space and print the result.

    An optional Hugging Face token is read from ``HF_TOKEN`` or
    ``HUGGINGFACE_HUB_TOKEN`` and sent as a bearer token. Prints the event id,
    a URL for the returned overlay image (``None`` if none was returned), and
    the returned detections as indented JSON.

    Args:
        argv: Arguments to parse; ``None`` lets argparse use ``sys.argv[1:]``.

    Returns:
        Exit status: 0 on success, 1 if the Space's result is not an
        ``[overlay, detections]`` pair.
    """
    import sys

    def _json_array(text: str) -> List[Any]:
        # argparse `type=` converter: parse the text as JSON and require a list.
        try:
            value = json.loads(text)
        except json.JSONDecodeError as exc:
            raise argparse.ArgumentTypeError(f"invalid JSON: {exc}") from exc
        if not isinstance(value, list):
            raise argparse.ArgumentTypeError("expected a JSON array")
        return value

    parser = argparse.ArgumentParser()
    parser.add_argument("--base", required=True, help="Space base URL, e.g. https://<name>.hf.space")
    parser.add_argument("--image-url", required=True, help="Public image URL (jpg/png)")
    parser.add_argument("--prompt", required=True, help="Text prompt, e.g. 'cat'")
    parser.add_argument("--api-name", default="run_text_prompt", help="Gradio api_name (no leading slash)")
    parser.add_argument(
        "--extra-inputs",
        type=_json_array,
        # Space-specific inputs that follow the image and prompt. Their meaning
        # is defined by the endpoint's signature (see the Space's API page).
        default=[1008, 0.3, 0.5, 8, 8, 0.5],
        help="JSON array of the endpoint's remaining positional inputs (default: %(default)s)",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=600,
        help="Seconds to wait for the result (default: %(default)s)",
    )
    args = parser.parse_args(argv)

    base_url = args.base.rstrip("/")
    token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or ""
    headers = {"Authorization": f"Bearer {token}"} if token else {}

    result, event_id = _call_gradio_call_api(
        base_url=base_url,
        api_name=args.api_name.strip("/"),
        data=[_image_data_from_http_url(args.image_url), args.prompt, *args.extra_inputs],
        headers=headers,
        timeout_s=args.timeout,
    )

    # The endpoint is expected to return exactly [overlay_image, detections].
    if not isinstance(result, list) or len(result) != 2:
        print(
            f"unexpected result (expected [overlay, detections]): {json.dumps(result)}",
            file=sys.stderr,
        )
        return 1
    overlay, detections = result

    overlay_url = overlay_path = None
    if isinstance(overlay, dict):
        overlay_url = overlay.get("url")
        overlay_path = overlay.get("path")
    if not overlay_url and overlay_path:
        # Only a server-side path was returned: build a URL on Gradio's file route.
        overlay_url = f"{base_url}/gradio_api/file={urllib.parse.quote(overlay_path, safe='')}"

    print("event_id:", event_id)
    print("overlay_url:", overlay_url)
    print("detections:", json.dumps(detections, indent=2))
    return 0
|
|
|
|
if __name__ == "__main__":
    # Run the CLI and hand main()'s integer result to the interpreter as the exit code.
    status = main()
    raise SystemExit(status)
|
|