from __future__ import annotations

import argparse
import contextlib
import os
import tempfile
from pathlib import Path
from urllib.parse import urlparse

import requests
from gradio_client import Client
|
|
|
|
def main() -> int:
    """Send a downloaded image plus a text prompt to a Gradio Space and print the result.

    Command-line arguments:
        --space      Space base URL or repo id, passed to ``gradio_client.Client``.
        --image-url  URL of the image to download and forward to the Space.
        --prompt     Text prompt forwarded to the Space.

    An access token is read from ``HF_TOKEN`` or, failing that,
    ``HUGGINGFACE_HUB_TOKEN``; empty values are treated as "no token".

    The image is written to a temporary file, which is passed by path to the
    Space's ``/run_text_prompt`` endpoint. The file is then removed
    (best effort) whether or not the call succeeded.

    Returns:
        Process exit status; always 0 (failures propagate as exceptions).

    Raises:
        requests.HTTPError: if the image download returns an error status.
        ValueError: if the endpoint result does not unpack into exactly two values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--space", required=True, help="Space base URL or repo, e.g. https://<name>.hf.space or user/space")
    parser.add_argument("--image-url", required=True)
    parser.add_argument("--prompt", required=True)
    args = parser.parse_args()

    # `or None` maps an env var that is set but empty to "no token".
    token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or None
    client = Client(args.space, hf_token=token)

    resp = requests.get(args.image_url, timeout=60)
    resp.raise_for_status()

    # Take the extension from the URL *path* only, so query strings and
    # fragments (e.g. "img.png?w=512") don't end up in the temp-file suffix.
    suffix = Path(urlparse(args.image_url).path).suffix or ".jpg"

    tmp_path: str | None = None
    try:
        # delete=False keeps the file after this handle closes so it can be
        # handed to client.predict by path; removal happens in `finally`.
        with tempfile.NamedTemporaryFile(prefix="gradio_", suffix=suffix, delete=False) as f:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_path = f.name
            f.write(resp.content)

        # NOTE(review): the numeric positional arguments are forwarded as-is;
        # their meaning is defined by the Space's /run_text_prompt signature
        # and is not visible here — confirm against the Space's API page.
        # NOTE(review): newer gradio_client releases may expect file inputs
        # wrapped as gradio_client.handle_file(tmp_path) — verify against the
        # installed version.
        result = client.predict(
            tmp_path,
            args.prompt,
            1008,
            0.3,
            0.5,
            8,
            8,
            0.5,
            api_name="/run_text_prompt",
        )
    finally:
        if tmp_path:
            # Best-effort cleanup: a leftover temp file must not mask the
            # outcome of the request.
            with contextlib.suppress(OSError):
                os.unlink(tmp_path)

    overlay, detections = result
    print("overlay:", overlay)
    print("detections:", detections)
    return 0
|
|
|
|
if __name__ == "__main__":
    # Run the CLI and use main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)
|
|