Spaces:
Sleeping
Sleeping
File size: 4,461 Bytes
af1040f d674874 160b869 af1040f 160b869 d674874 af1040f d674874 af1040f d674874 160b869 af1040f 160b869 af1040f d674874 af1040f d674874 af1040f d674874 160b869 d674874 af1040f d674874 af1040f d674874 af1040f d674874 af1040f d674874 af1040f d674874 af1040f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | """HF Space probe: Gradio SDK + ZeroGPU.
Exposes a /rewrite Gradio API endpoint that returns:
rewritten: str
examples: list[str]
latency_s: float
wall_latency_s: float
input_tokens: int
output_tokens: int
cold_load_s: float
uptime_s: float
Plus a /healthz API endpoint (a hidden Gradio button registered with api_name="healthz") for the probe to poll.
"""
from __future__ import annotations
import threading
import time
from pathlib import Path
import gradio as gr
try:
import spaces
_HAS_SPACES = True
except ImportError:
_HAS_SPACES = False
from rewriter import Rewriter
# Directory handed to Rewriter(...); presumably holds the retrieval/model
# assets — confirm against rewriter.py.
DATA_DIR = Path(__file__).parent.joinpath("data")

# Wall-clock time of module import; every reported uptime_s is measured from here.
_BOOT_AT = time.time()

# Process-wide Rewriter singleton, created lazily by get_rewriter() under _LOAD_LOCK.
_REWRITER: Rewriter | None = None
_LOAD_LOCK = threading.Lock()

# Seconds it took to construct the singleton; None until construction finishes.
_LOAD_LATENCY: float | None = None
# Wall-clock time of the first inference recorded by rewrite_api; None until then.
_FIRST_INFERENCE_AT: float | None = None
def get_rewriter() -> Rewriter:
    """Return the shared Rewriter, building it on the first call.

    The whole check-and-build runs while holding ``_LOAD_LOCK``, so concurrent
    first callers construct ``Rewriter(DATA_DIR)`` only once. The construction
    time is stored in ``_LOAD_LATENCY`` and printed together with the device.
    """
    global _REWRITER, _LOAD_LATENCY
    with _LOAD_LOCK:
        # Already built: hand back the existing instance.
        if _REWRITER is not None:
            return _REWRITER
        load_started = time.time()
        _REWRITER = Rewriter(DATA_DIR)
        _LOAD_LATENCY = time.time() - load_started
        print(f"[rewriter] singleton ready in {_LOAD_LATENCY:.1f}s on {_REWRITER.device}", flush=True)
        return _REWRITER
def _do_rewrite(prompt: str) -> dict:
    """Rewrite *prompt* using the shared Rewriter singleton.

    Returns whatever ``Rewriter.rewrite`` returns; callers in this file read
    its ``rewritten``, ``examples`` and ``latency_s`` entries.
    """
    rewriter = get_rewriter()
    return rewriter.rewrite(prompt)


# When the `spaces` package imported successfully (ZeroGPU), rebind the name to
# the spaces.GPU-wrapped function (duration=30) so every call made through
# _do_rewrite requests a GPU allocation from HF.
if _HAS_SPACES:
    _do_rewrite = spaces.GPU(duration=30)(_do_rewrite)
def rewrite_api(prompt: str) -> dict:
    """JSON endpoint registered below with ``api_name="rewrite"``.

    Runs one rewrite and returns the Rewriter's output dict augmented with
    probe metadata:

    * ``wall_latency_s`` — end-to-end seconds for this call, measured around
      ``_do_rewrite`` (so it includes the ``spaces.GPU`` wrapper when active).
    * ``cold_load_s`` — seconds the singleton took to load (0.0 if unknown).
    * ``uptime_s`` — seconds since module import.

    A missing, empty, or whitespace-only prompt short-circuits without calling
    the model and returns the same keys with zeroed values plus ``error``.
    """
    global _FIRST_INFERENCE_AT
    if not prompt or not prompt.strip():
        # Keep the key set identical to the success path so clients can rely on it.
        return {
            "rewritten": "",
            "examples": [],
            "latency_s": 0.0,
            "input_tokens": 0,
            "output_tokens": 0,
            "wall_latency_s": 0.0,
            "cold_load_s": _LOAD_LATENCY or 0.0,
            "uptime_s": time.time() - _BOOT_AT,
            "error": "empty prompt",
        }

    # perf_counter is monotonic, so the measured duration cannot go negative or
    # jump if the wall clock is adjusted mid-request.
    started = time.perf_counter()
    out = _do_rewrite(prompt)
    wall_latency = time.perf_counter() - started

    # One wall-clock sample for both the first-inference stamp and uptime_s.
    now = time.time()
    if _FIRST_INFERENCE_AT is None:
        _FIRST_INFERENCE_AT = now
    return {
        **out,
        "wall_latency_s": wall_latency,
        "cold_load_s": _LOAD_LATENCY or 0.0,
        "uptime_s": now - _BOOT_AT,
    }
def healthz_api() -> dict:
    """Status snapshot registered below with ``api_name="healthz"``.

    ``ready`` is True once the Rewriter singleton exists. ``load_latency_s``,
    ``first_inference_at_uptime_s``, ``device`` and ``model_id`` are ``None``
    until the corresponding event has happened.
    """
    # Read each global once so ready/device/model_id all describe the same object.
    rewriter = _REWRITER
    first_inference_at = _FIRST_INFERENCE_AT
    return {
        "ready": rewriter is not None,
        "uptime_s": time.time() - _BOOT_AT,
        "load_latency_s": _LOAD_LATENCY,
        # Explicit None checks: truthiness would consult Rewriter.__bool__/__len__
        # if defined, and would misreport a legitimate 0.0 timestamp.
        "first_inference_at_uptime_s": (
            first_inference_at - _BOOT_AT if first_inference_at is not None else None
        ),
        "device": rewriter.device if rewriter is not None else None,
        "model_id": rewriter.model_id if rewriter is not None else None,
    }
def gradio_rewrite(prompt: str):
    """Click handler for the interactive UI (also exposed as ``api_name="rewrite_ui"``).

    Returns a ``(rewritten, examples_text, latency_s)`` triple matching the
    three output components. A missing, empty, or whitespace-only prompt
    returns placeholder values without calling the model.
    """
    # Because this handler is exposed as an API endpoint, a client can send
    # null; guard it the same way rewrite_api does instead of crashing on .strip().
    if not prompt or not prompt.strip():
        return "(empty)", "(empty)", 0.0
    out = _do_rewrite(prompt)
    # One bulleted line per retrieved exemplar.
    examples_str = "\n".join(f" · {e}" for e in out["examples"])
    return out["rewritten"], examples_str, out["latency_s"]
# Build the Rewriter singleton at import time so the first request does not pay
# the load cost; the time taken is stored in _LOAD_LATENCY and reported as
# cold_load_s / load_latency_s.
# This runs after the optional spaces.GPU wrapping of _do_rewrite above. Note that
# get_rewriter() itself is not wrapped, so this load happens outside any GPU call.
# NOTE(review): an earlier comment said this ordering is needed so ZeroGPU is
# initialised — confirm whether that requirement still holds.
get_rewriter()
# UI and API wiring. Components appear in the order they are created inside these
# context managers, so statement order here defines the layout.
with gr.Blocks(title="AnimoFlow Rewriter Probe") as demo:
    gr.Markdown("# AnimoFlow Rewriter Probe — Qwen2.5-1.5B-Instruct + RAFSL on ZeroGPU")
    gr.Markdown(
        "Type a motion prompt in any language. The rewriter normalises it to a "
        "HumanML3D-style English caption. Powered by Qwen2.5-1.5B-Instruct + a multilingual "
        "MiniLM retriever over the 52K HumanML3D caption corpus."
    )
    # Interactive UI: input column on the left, results column on the right.
    with gr.Row():
        with gr.Column():
            # The placeholder is Chinese for "a person walks forward".
            inp = gr.Textbox(label="Your prompt (any language)", placeholder="一个人向前走", lines=2)
            btn = gr.Button("Rewrite", variant="primary")
        with gr.Column():
            out_text = gr.Textbox(label="Rewritten (HumanML3D-style English)", lines=2)
            out_examples = gr.Textbox(label="Retrieved exemplars", lines=4)
            out_latency = gr.Number(label="Latency (s)", precision=3)
    # gradio_rewrite returns a 3-tuple mapped positionally onto these outputs.
    btn.click(gradio_rewrite, inputs=[inp], outputs=[out_text, out_examples, out_latency], api_name="rewrite_ui")
    # Pure API endpoint with JSON dict response — what the probe + clients should hit.
    # The row is hidden: its components exist only so the click event can be
    # registered as the named API endpoint "rewrite".
    with gr.Row(visible=False):
        api_in = gr.Textbox()
        api_out = gr.JSON()
        api_btn = gr.Button()
    api_btn.click(rewrite_api, inputs=[api_in], outputs=[api_out], api_name="rewrite")
    # Health endpoint, registered the same way (hidden row + api_name="healthz").
    with gr.Row(visible=False):
        hz_btn = gr.Button()
        hz_out = gr.JSON()
    hz_btn.click(healthz_api, inputs=[], outputs=[hz_out], api_name="healthz")
if __name__ == "__main__":
    # Turn on the request queue (queue() returns the Blocks app), then serve
    # on all interfaces at port 7860.
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0", server_port=7860)
|