# rewriter-probe / app.py  (copied from the Hugging Face Space file viewer)
# Last commit: af1040f (verified) by AnimoFlow — "Initial probe deploy 2026-06-14 18:03:58"
# File size: 4.46 kB
"""HF Space probe: Gradio SDK + ZeroGPU.
Exposes a /rewrite Gradio API endpoint that returns:
rewritten: str
examples: list[str]
latency_s: float
input_tokens: int
output_tokens: int
cold_load_s: float
uptime_s: float
Plus a /healthz endpoint via gr.routes for the probe to poll.
"""
from __future__ import annotations
import threading
import time
from pathlib import Path
import gradio as gr
# ``spaces`` supplies the ``spaces.GPU`` wrapper used below for ZeroGPU; it is
# optional, and _HAS_SPACES records whether it could be imported.
try:
    import spaces
except ImportError:
    _HAS_SPACES = False
else:
    _HAS_SPACES = True
from rewriter import Rewriter
# Sibling "data" directory next to this file; handed to Rewriter() on first load.
DATA_DIR = Path(__file__).with_name("data")
# Wall-clock import time; every uptime_s value is measured against it.
_BOOT_AT = time.time()

# Lazily-populated singleton state. _REWRITER and _LOAD_LATENCY are written by
# get_rewriter() while holding _LOAD_LOCK; _FIRST_INFERENCE_AT is stamped once,
# by the first completed API rewrite.
_REWRITER: Rewriter | None = None
_LOAD_LOCK = threading.Lock()
_LOAD_LATENCY: float | None = None
_FIRST_INFERENCE_AT: float | None = None
def get_rewriter() -> Rewriter:
    """Return the process-wide Rewriter, constructing it on first use.

    The first call builds ``Rewriter(DATA_DIR)`` while holding ``_LOAD_LOCK``,
    stores the construction time in ``_LOAD_LATENCY`` and prints the device the
    rewriter reports. Later calls return the cached instance.

    Returns:
        The shared ``Rewriter`` instance.

    Raises:
        Whatever ``Rewriter(DATA_DIR)`` raises. In that case ``_REWRITER`` stays
        ``None``, so the next call attempts construction again.
    """
    global _REWRITER, _LOAD_LATENCY
    # Fast path: once the singleton has been assigned, no locking is needed.
    # _REWRITER is only bound after Rewriter() has fully returned.
    if _REWRITER is not None:
        return _REWRITER
    with _LOAD_LOCK:
        # Re-check under the lock: another caller may have finished loading
        # while this one was waiting.
        if _REWRITER is None:
            # perf_counter is monotonic, so a wall-clock adjustment during the
            # (potentially long) model load cannot distort the measurement.
            t0 = time.perf_counter()
            _REWRITER = Rewriter(DATA_DIR)
            _LOAD_LATENCY = time.perf_counter() - t0
            print(f"[rewriter] singleton ready in {_LOAD_LATENCY:.1f}s on {_REWRITER.device}", flush=True)
        return _REWRITER
def _do_rewrite(prompt: str) -> dict:
    """Rewrite ``prompt`` with the shared singleton and return its result dict."""
    rewriter = get_rewriter()
    return rewriter.rewrite(prompt)


# On ZeroGPU, wrap inference so HF allocates a GPU burst: when ``spaces`` was
# importable, rebind _do_rewrite to its spaces.GPU(duration=30) wrapper;
# otherwise the plain function is kept.
_do_rewrite = spaces.GPU(duration=30)(_do_rewrite) if _HAS_SPACES else _do_rewrite
def rewrite_api(prompt: str) -> dict:
    """JSON entry point for programmatic clients (registered with ``api_name="rewrite"``).

    Args:
        prompt: Free-text motion prompt. ``None``, empty or whitespace-only input
            is treated as empty and never reaches the model.

    Returns:
        For a non-empty prompt: every key returned by the rewriter, plus
        ``wall_latency_s`` (elapsed time of the ``_do_rewrite`` call, including
        whatever the ZeroGPU wrapper adds when it is active), ``cold_load_s`` and
        ``uptime_s``.
        For an empty prompt: a zero-filled dict carrying the same timing keys and
        ``"error": "empty prompt"``.
    """
    global _FIRST_INFERENCE_AT
    if not prompt or not prompt.strip():
        return {
            "rewritten": "",
            "examples": [],
            "latency_s": 0.0,
            # Included so the empty-prompt response has the same timing keys as
            # the success path below.
            "wall_latency_s": 0.0,
            "input_tokens": 0,
            "output_tokens": 0,
            "cold_load_s": _LOAD_LATENCY or 0.0,
            "uptime_s": time.time() - _BOOT_AT,
            "error": "empty prompt",
        }
    # Monotonic clock for the elapsed-time measurement.
    t0 = time.perf_counter()
    out = _do_rewrite(prompt)
    wall_latency = time.perf_counter() - t0
    # Wall-clock snapshot (comparable with _BOOT_AT), shared by both fields below.
    now = time.time()
    if _FIRST_INFERENCE_AT is None:
        _FIRST_INFERENCE_AT = now
    return {
        **out,
        "wall_latency_s": wall_latency,
        "cold_load_s": _LOAD_LATENCY or 0.0,
        "uptime_s": now - _BOOT_AT,
    }
def healthz_api() -> dict:
    """Report readiness and timing counters for the probe to poll.

    Returns:
        A dict with:
        - ``ready``: whether the rewriter singleton has been constructed;
        - ``uptime_s``: seconds since module import;
        - ``load_latency_s``: model construction time, ``None`` until loaded;
        - ``first_inference_at_uptime_s``: uptime at which the first completed
          ``rewrite_api`` call finished, ``None`` until then;
        - ``device`` / ``model_id``: taken from the rewriter, ``None`` until loaded.
    """
    # Read each global once so that ``ready`` and the rewriter-derived fields
    # cannot disagree if loading completes mid-call.
    rewriter = _REWRITER
    first_at = _FIRST_INFERENCE_AT
    loaded = rewriter is not None
    return {
        "ready": loaded,
        "uptime_s": time.time() - _BOOT_AT,
        "load_latency_s": _LOAD_LATENCY,
        "first_inference_at_uptime_s": (first_at - _BOOT_AT) if first_at is not None else None,
        "device": rewriter.device if loaded else None,
        "model_id": rewriter.model_id if loaded else None,
    }
def gradio_rewrite(prompt: str):
    """Callback for the visible "Rewrite" button.

    Args:
        prompt: Text from the input box. ``None``, empty or whitespace-only input
            is treated as empty, matching ``rewrite_api``.

    Returns:
        ``(rewritten, examples_text, latency_s)`` for the three output widgets.
        ``examples_text`` lists one retrieved exemplar per line. Empty input
        yields ``("(empty)", "(empty)", 0.0)`` without touching the model.
    """
    # Check for a falsy value first: prompt.strip() alone raises AttributeError
    # when the API is called with a null prompt.
    if not prompt or not prompt.strip():
        return "(empty)", "(empty)", 0.0
    out = _do_rewrite(prompt)
    examples_str = "\n".join(f" · {e}" for e in out["examples"])
    return out["rewritten"], examples_str, out["latency_s"]
# Eager-load the rewriter at module import so the first Gradio call doesn't pay the load cost;
# get_rewriter() records that cost in _LOAD_LATENCY, which the API responses report as cold_load_s.
# This runs after _do_rewrite has (optionally) been wrapped with spaces.GPU above.
# NOTE(review): the original rationale was that this ordering ensures ZeroGPU is initialised —
# confirm against the `spaces` package documentation.
get_rewriter()
# UI definition: a visible demo panel plus two hidden, API-only endpoints.
with gr.Blocks(title="AnimoFlow Rewriter Probe") as demo:
    gr.Markdown("# AnimoFlow Rewriter Probe — Qwen2.5-1.5B-Instruct + RAFSL on ZeroGPU")
    gr.Markdown(
        "Type a motion prompt in any language. The rewriter normalises it to a "
        "HumanML3D-style English caption. Powered by Qwen2.5-1.5B-Instruct + a multilingual "
        "MiniLM retriever over the 52K HumanML3D caption corpus."
    )
    # Visible demo: input column on the left, output column on the right.
    with gr.Row():
        with gr.Column():
            # Placeholder is Chinese for "a person walks forward".
            inp = gr.Textbox(label="Your prompt (any language)", placeholder="一个人向前走", lines=2)
            btn = gr.Button("Rewrite", variant="primary")
        with gr.Column():
            out_text = gr.Textbox(label="Rewritten (HumanML3D-style English)", lines=2)
            out_examples = gr.Textbox(label="Retrieved exemplars", lines=4)
            out_latency = gr.Number(label="Latency (s)", precision=3)
    # gradio_rewrite returns a 3-tuple that maps positionally onto these outputs.
    btn.click(gradio_rewrite, inputs=[inp], outputs=[out_text, out_examples, out_latency], api_name="rewrite_ui")
    # Pure API endpoint with JSON dict response — what the probe + clients should hit.
    # The row is hidden; its components exist only so the event can be registered
    # under api_name="rewrite".
    with gr.Row(visible=False):
        api_in = gr.Textbox()
        api_out = gr.JSON()
        api_btn = gr.Button()
    api_btn.click(rewrite_api, inputs=[api_in], outputs=[api_out], api_name="rewrite")
    # Hidden, input-less endpoint registered as api_name="healthz" for readiness polling.
    with gr.Row(visible=False):
        hz_btn = gr.Button()
        hz_out = gr.JSON()
    hz_btn.click(healthz_api, inputs=[], outputs=[hz_out], api_name="healthz")
if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces at port 7860.
    app = demo.queue()
    app.launch(server_name="0.0.0.0", server_port=7860)