# sam3-cpu / examples /call_space_requests.py
# Benji Peng
# update files
# 2e2cccd
from __future__ import annotations
import argparse
import base64
import json
import os
import urllib.parse
from typing import Any, Dict, List, Tuple
import requests
def _image_data_from_http_url(url: str) -> Dict[str, Any]:
    """
    Download an image over HTTP and wrap it as a Gradio file payload.

    Gradio's backend image preprocessing accepts `url` only for base64 data
    URLs, so the bytes are fetched client-side and re-packaged as a
    `data:<mime>;base64,...` URL.

    The MIME type is taken from the response's Content-Type header (any
    parameters after `;` are dropped); `image/jpeg` is used when the header
    is missing or empty.

    Raises:
        requests.HTTPError: if the download returns an error status.
    """
    response = requests.get(url, timeout=60)
    response.raise_for_status()

    content_type = response.headers.get("content-type") or "image/jpeg"
    # Keep only the media type, e.g. "image/png; charset=binary" -> "image/png".
    media_type, _, _ = content_type.partition(";")
    media_type = media_type.strip()

    encoded = base64.b64encode(response.content).decode("utf-8")
    data_url = f"data:{media_type};base64,{encoded}"
    return {"url": data_url, "meta": {"_type": "gradio.FileData"}}
def _call_gradio_call_api(
    *,
    base_url: str,
    api_name: str,
    data: List[Any],
    headers: Dict[str, str] | None = None,
    timeout_s: int = 600,
) -> Tuple[Any, str]:
    """
    Call a Gradio Space endpoint through the two-step "call" API.

        POST {base}/gradio_api/call/{api_name}            -> {"event_id": "..."}
        GET  {base}/gradio_api/call/{api_name}/{event_id} (SSE stream)

    Args:
        base_url: Space base URL without a trailing slash.
        api_name: Endpoint name without a leading slash.
        data: Positional inputs for the endpoint, sent as `{"data": data}`.
        headers: Optional HTTP headers sent with both requests.
        timeout_s: Value passed as `timeout` to the streaming GET request.

    Returns:
        A `(result, event_id)` tuple, where `result` is the JSON-decoded
        payload of the `complete` event.

    Raises:
        requests.HTTPError: if either request returns an error status.
        RuntimeError: if the stream delivers an `error` event.
        TimeoutError: if the stream ends without a `complete` event.
    """
    headers = headers or {}

    join = requests.post(
        f"{base_url}/gradio_api/call/{api_name}",
        json={"data": data},
        headers=headers,
        timeout=60,
    )
    join.raise_for_status()
    event_id = join.json()["event_id"]

    current_event: str | None = None
    with requests.get(
        f"{base_url}/gradio_api/call/{api_name}/{event_id}",
        headers=headers,
        stream=True,
        timeout=timeout_s,
    ) as r:
        r.raise_for_status()
        # Iterate raw bytes and decode explicitly as UTF-8. Relying on
        # `decode_unicode=True` uses whatever encoding requests guessed from
        # the headers, which can be Latin-1 (mangling non-ASCII JSON) or
        # absent (yielding bytes, which would break the str prefix checks).
        for raw_line in r.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode("utf-8", errors="replace")

            if line.startswith("event:"):
                current_event = line.split(":", 1)[1].strip()
                continue
            if not line.startswith("data:"):
                continue

            data_str = line.split(":", 1)[1].strip()
            has_payload = data_str not in ("", "null", "[DONE]")

            # Check for errors BEFORE discarding empty payloads: an error
            # event may carry `data: null`, and skipping it would surface
            # later as a misleading "timed out" failure.
            if current_event == "error":
                if has_payload:
                    raise RuntimeError(data_str)
                raise RuntimeError(
                    f"Gradio returned an error event without details (event_id={event_id})."
                )

            if not has_payload:
                continue
            if current_event == "complete":
                return json.loads(data_str), event_id

    # The stream closed without ever delivering a `complete` event.
    raise TimeoutError(f"Timed out waiting for Gradio result (event_id={event_id}).")
def main() -> int:
    """
    Command-line entry point: run a text-prompted request against a Space.

    Reads `--base`, `--image-url`, `--prompt` and optionally `--api-name`
    from the command line, sends the image (as a base64 data URL) together
    with the prompt and a fixed set of tuning values to the endpoint, then
    prints the event id, the overlay image URL and the detections as JSON.

    A bearer token is taken from `HF_TOKEN` or `HUGGINGFACE_HUB_TOKEN` when
    either environment variable is set.

    Returns:
        0 on success (failures propagate as exceptions).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--base", required=True, help="Space base URL, e.g. https://<name>.hf.space")
    cli.add_argument("--image-url", required=True, help="Public image URL (jpg/png)")
    cli.add_argument("--prompt", required=True, help="Text prompt, e.g. 'cat'")
    cli.add_argument("--api-name", default="run_text_prompt", help="Gradio api_name (no leading slash)")
    args = cli.parse_args()

    token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or ""
    auth_headers: Dict[str, str] = {}
    if token:
        auth_headers["Authorization"] = f"Bearer {token}"

    space_url = args.base.rstrip("/")
    endpoint = args.api_name.strip("/")

    # Positional inputs, in the order the endpoint expects them.
    payload: List[Any] = [
        _image_data_from_http_url(args.image_url),
        args.prompt,
        1008,  # max_long_side
        0.3,  # score_threshold
        0.5,  # mask_threshold
        8,  # max_instances
        8,  # num_threads
        0.5,  # overlay_alpha
    ]

    result, event_id = _call_gradio_call_api(
        base_url=space_url,
        api_name=endpoint,
        data=payload,
        headers=auth_headers,
    )
    overlay, detections = result

    # The overlay is only inspected when it is a dict; anything else yields
    # no path and no URL.
    if isinstance(overlay, dict):
        overlay_path = overlay.get("path")
        overlay_url = overlay.get("url")
    else:
        overlay_path = None
        overlay_url = None

    # Fall back to building a file URL from the server-side path.
    if not overlay_url and overlay_path:
        overlay_url = f"{args.base.rstrip('/')}/gradio_api/file={urllib.parse.quote(overlay_path, safe='')}"

    print("event_id:", event_id)
    print("overlay_url:", overlay_url)
    print("detections:", json.dumps(detections, indent=2))
    return 0
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)