Spaces:
Sleeping
Sleeping
| """HF Space probe: Gradio SDK + ZeroGPU. | |
| Exposes a /rewrite Gradio API endpoint that returns: | |
| rewritten: str | |
| examples: list[str] | |
| latency_s: float | |
| input_tokens: int | |
| output_tokens: int | |
| cold_load_s: float | |
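| uptime_s: float | |
| wall_latency_s: float (time measured around the rewrite call; successful responses only) | |
| Plus a /healthz Gradio API endpoint (a hidden button registered with api_name="healthz") for the probe to poll. | |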
| """ | |
| from __future__ import annotations | |
| import threading | |
| import time | |
| from pathlib import Path | |
| import gradio as gr | |
| try: | |
| import spaces | |
| _HAS_SPACES = True | |
| except ImportError: | |
| _HAS_SPACES = False | |
| from rewriter import Rewriter | |
# Directory handed to the Rewriter constructor; a "data" folder next to this file.
DATA_DIR = Path(__file__).with_name("data")

# Wall-clock timestamp taken at import; every reported uptime is relative to it.
_BOOT_AT = time.time()

# Lazily-built singleton state. _LOAD_LOCK guards construction of _REWRITER,
# and _LOAD_LATENCY holds how long that construction took (None until loaded).
_LOAD_LOCK = threading.Lock()
_REWRITER: Rewriter | None = None
_LOAD_LATENCY: float | None = None

# Wall-clock timestamp of the first completed inference (None until one has run).
_FIRST_INFERENCE_AT: float | None = None
def get_rewriter() -> Rewriter:
    """Return the shared ``Rewriter``, building it on the first call.

    The check-and-construct sequence runs while holding ``_LOAD_LOCK`` so that
    concurrent callers cannot each build their own instance. The call that
    actually constructs the rewriter also stores the construction time in
    ``_LOAD_LATENCY`` and prints a one-line readiness message.

    Returns:
        The module-wide ``Rewriter`` instance.
    """
    global _REWRITER, _LOAD_LATENCY
    with _LOAD_LOCK:
        already_loaded = _REWRITER is not None
        if not already_loaded:
            started = time.time()
            _REWRITER = Rewriter(DATA_DIR)
            _LOAD_LATENCY = time.time() - started
            print(
                f"[rewriter] singleton ready in {_LOAD_LATENCY:.1f}s on {_REWRITER.device}",
                flush=True,
            )
        return _REWRITER
def _do_rewrite(prompt: str) -> dict:
    """Run one rewrite of ``prompt`` on the shared rewriter and return its result."""
    rewriter = get_rewriter()
    return rewriter.rewrite(prompt)


# When the optional ``spaces`` package imported successfully (ZeroGPU), wrap
# inference in its GPU decorator so HF allocates a GPU burst for each call.
if _HAS_SPACES:
    _do_rewrite = spaces.GPU(duration=30)(_do_rewrite)
def rewrite_api(prompt: str) -> dict:
    """Rewrite ``prompt`` and return the result as a JSON-friendly dict.

    This is the handler wired to the hidden API button under
    ``api_name="rewrite"``.

    Args:
        prompt: Raw user prompt. ``None``, empty and whitespace-only values are
            rejected without running inference.

    Returns:
        On success: every key produced by the rewriter, plus ``wall_latency_s``
        (seconds spent in this call around inference), ``cold_load_s`` (the
        recorded rewriter load time, ``0.0`` if none is recorded) and
        ``uptime_s`` (seconds since this module was imported).
        For an empty prompt: a zeroed payload carrying the same timing keys
        plus ``"error": "empty prompt"``.
    """
    global _FIRST_INFERENCE_AT
    if not prompt or not prompt.strip():
        return {
            "rewritten": "",
            "examples": [],
            "latency_s": 0.0,
            "input_tokens": 0,
            "output_tokens": 0,
            # Present on success too; keep the schema uniform for clients.
            "wall_latency_s": 0.0,
            "cold_load_s": _LOAD_LATENCY or 0.0,
            "uptime_s": time.time() - _BOOT_AT,
            "error": "empty prompt",
        }

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments that happen while inference is running.
    started = time.perf_counter()
    out = _do_rewrite(prompt)
    wall_latency = time.perf_counter() - started

    # Stays on the wall clock because it is later compared against _BOOT_AT.
    if _FIRST_INFERENCE_AT is None:
        _FIRST_INFERENCE_AT = time.time()

    return {
        **out,
        "wall_latency_s": wall_latency,
        "cold_load_s": _LOAD_LATENCY or 0.0,
        "uptime_s": time.time() - _BOOT_AT,
    }
def healthz_api() -> dict:
    """Report readiness and timing information for the probe to poll.

    Returns:
        A dict with ``ready`` (whether the rewriter singleton exists),
        ``uptime_s`` (seconds since import), ``load_latency_s`` (recorded load
        time or ``None``), ``first_inference_at_uptime_s`` (uptime at which the
        first inference finished, or ``None``), and the rewriter's ``device``
        and ``model_id`` (both ``None`` while no rewriter is loaded).
    """
    if _FIRST_INFERENCE_AT:
        first_inference_offset = _FIRST_INFERENCE_AT - _BOOT_AT
    else:
        first_inference_offset = None

    if _REWRITER:
        device, model_id = _REWRITER.device, _REWRITER.model_id
    else:
        device = model_id = None

    return {
        "ready": _REWRITER is not None,
        "uptime_s": time.time() - _BOOT_AT,
        "load_latency_s": _LOAD_LATENCY,
        "first_inference_at_uptime_s": first_inference_offset,
        "device": device,
        "model_id": model_id,
    }
def gradio_rewrite(prompt: str):
    """Click handler for the visible UI: rewrite ``prompt`` for display.

    Args:
        prompt: Text from the prompt box. ``None``, empty and whitespace-only
            values short-circuit without running inference.

    Returns:
        A 3-tuple matching the UI outputs: the rewritten caption, the retrieved
        exemplars joined one per line (each prefixed with `` · ``), and the
        latency in seconds reported by the rewriter. For an empty prompt the
        tuple is ``("(empty)", "(empty)", 0.0)``.
    """
    # Guard against None as well as blank text, mirroring rewrite_api;
    # calling .strip() on None would raise AttributeError.
    if not prompt or not prompt.strip():
        return "(empty)", "(empty)", 0.0
    out = _do_rewrite(prompt)
    examples_str = "\n".join(f" · {e}" for e in out["examples"])
    return out["rewritten"], examples_str, out["latency_s"]
# Eager-load the rewriter at module import so the first Gradio call doesn't pay
# the load cost; get_rewriter() caches the instance for every later call.
# This runs AFTER the spaces decorator has been applied to _do_rewrite, with the
# intent that ZeroGPU is already initialised
# (NOTE(review): confirm the Rewriter's device placement works at import time).
get_rewriter()
# Gradio app: one visible form for humans plus two hidden rows whose only job
# is to register named API endpoints ("rewrite" and "healthz").
with gr.Blocks(title="AnimoFlow Rewriter Probe") as demo:
    gr.Markdown("# AnimoFlow Rewriter Probe — Qwen2.5-1.5B-Instruct + RAFSL on ZeroGPU")
    gr.Markdown(
        "Type a motion prompt in any language. The rewriter normalises it to a "
        "HumanML3D-style English caption. Powered by Qwen2.5-1.5B-Instruct + a multilingual "
        "MiniLM retriever over the 52K HumanML3D caption corpus."
    )
    # Visible UI: prompt and button on the left, results on the right.
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="Your prompt (any language)", placeholder="一个人向前走", lines=2)
            btn = gr.Button("Rewrite", variant="primary")
        with gr.Column():
            out_text = gr.Textbox(label="Rewritten (HumanML3D-style English)", lines=2)
            out_examples = gr.Textbox(label="Retrieved exemplars", lines=4)
            out_latency = gr.Number(label="Latency (s)", precision=3)
    # The UI handler returns a 3-tuple mapped positionally onto the three outputs.
    btn.click(gradio_rewrite, inputs=[inp], outputs=[out_text, out_examples, out_latency], api_name="rewrite_ui")
    # Pure API endpoint with JSON dict response — what the probe + clients should hit.
    with gr.Row(visible=False):
        api_in = gr.Textbox()
        api_out = gr.JSON()
        api_btn = gr.Button()
    api_btn.click(rewrite_api, inputs=[api_in], outputs=[api_out], api_name="rewrite")
    # Health endpoint: takes no inputs and returns the healthz_api() dict as JSON.
    with gr.Row(visible=False):
        hz_btn = gr.Button()
        hz_out = gr.JSON()
    hz_btn.click(healthz_api, inputs=[], outputs=[hz_out], api_name="healthz")
if __name__ == "__main__":
    # Enable Gradio's request queue, then serve on all interfaces at port 7860.
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0", server_port=7860)