File size: 4,461 Bytes
af1040f
 
 
 
 
 
 
 
 
 
 
 
 
d674874
 
 
 
 
 
 
 
160b869
af1040f
160b869
 
 
 
d674874
 
 
 
af1040f
d674874
 
 
af1040f
d674874
 
 
 
 
 
 
 
 
 
 
 
 
160b869
 
 
 
af1040f
160b869
 
 
 
af1040f
 
d674874
af1040f
 
 
 
 
d674874
 
af1040f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d674874
 
 
 
 
160b869
d674874
 
 
 
af1040f
 
 
 
 
d674874
af1040f
d674874
 
af1040f
 
d674874
 
 
 
 
 
 
 
 
af1040f
 
 
 
 
 
 
 
d674874
af1040f
 
 
 
d674874
 
 
af1040f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
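"""HF Space probe: Gradio SDK + ZeroGPU.

Exposes a /rewrite Gradio API endpoint that returns a JSON dict with:
  rewritten: str
  examples: list[str]
  latency_s: float
  input_tokens: int
  output_tokens: int
  cold_load_s: float
  uptime_s: float
plus wall_latency_s: float on a successful rewrite, and error: str when the
prompt is empty.

Also exposes a /healthz API endpoint (a hidden Gradio event registered with
api_name="healthz") for the probe to poll.
"""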
from __future__ import annotations

import threading
import time
from pathlib import Path

import gradio as gr

try:
    import spaces
    _HAS_SPACES = True
except ImportError:
    _HAS_SPACES = False

from rewriter import Rewriter

# Directory handed to Rewriter(...) when the singleton is built.
DATA_DIR = Path(__file__).parent.joinpath("data")

# Process-level timing + singleton state shared by the loader and the API handlers.
_BOOT_AT = time.time()  # module import time; all uptime values are measured from here
_LOAD_LOCK = threading.Lock()  # serialises first-time construction of the Rewriter
_REWRITER: Rewriter | None = None  # set once by get_rewriter()
_LOAD_LATENCY: float | None = None  # seconds spent constructing the Rewriter
_FIRST_INFERENCE_AT: float | None = None  # time.time() of the first completed inference


def get_rewriter() -> Rewriter:
    """Return the process-wide Rewriter, building it on the first call.

    Construction happens while holding ``_LOAD_LOCK``, so concurrent first
    callers create only one instance. The time spent constructing it is
    recorded in ``_LOAD_LATENCY`` and logged once.
    """
    global _REWRITER, _LOAD_LATENCY
    with _LOAD_LOCK:
        if _REWRITER is not None:
            # Already built: hand back the cached singleton.
            return _REWRITER
        started = time.time()
        _REWRITER = Rewriter(DATA_DIR)
        _LOAD_LATENCY = time.time() - started
        print(f"[rewriter] singleton ready in {_LOAD_LATENCY:.1f}s on {_REWRITER.device}", flush=True)
    return _REWRITER


def _do_rewrite(prompt: str) -> dict:
    """Run a single rewrite of ``prompt`` through the shared Rewriter singleton."""
    rewriter = get_rewriter()
    result = rewriter.rewrite(prompt)
    return result


# On ZeroGPU, wrap inference so HF allocates a GPU burst (duration=30 is passed
# straight to spaces.GPU). Without the `spaces` package the plain function is kept;
# the conditional expression never touches the `spaces` name in that case.
_do_rewrite = spaces.GPU(duration=30)(_do_rewrite) if _HAS_SPACES else _do_rewrite


def rewrite_api(prompt: str) -> dict:
    """JSON API handler, registered as the hidden Gradio event with api_name="rewrite".

    Returns the Rewriter output dict extended with timing metadata:

    - ``wall_latency_s``: end-to-end seconds spent in the ``_do_rewrite`` call.
    - ``cold_load_s``: Rewriter construction time (0.0 if not recorded).
    - ``uptime_s``: seconds since module import.

    An empty, whitespace-only or missing prompt short-circuits without inference.
    It returns zeroed fields, the same timing keys and ``error: "empty prompt"``.
    """
    global _FIRST_INFERENCE_AT
    if not prompt or not prompt.strip():
        return {
            "rewritten": "",
            "examples": [],
            "latency_s": 0.0,
            "input_tokens": 0,
            "output_tokens": 0,
            # Same timing keys as the success path, so clients see one schema.
            "wall_latency_s": 0.0,
            "cold_load_s": _LOAD_LATENCY or 0.0,
            "uptime_s": time.time() - _BOOT_AT,
            "error": "empty prompt",
        }
    # perf_counter is monotonic, so the interval can't be skewed by wall-clock
    # adjustments; uptime keeps using time.time() to stay comparable with _BOOT_AT.
    t0 = time.perf_counter()
    out = _do_rewrite(prompt)
    wall_latency = time.perf_counter() - t0
    now = time.time()
    if _FIRST_INFERENCE_AT is None:
        _FIRST_INFERENCE_AT = now
    return {
        **out,
        "wall_latency_s": wall_latency,
        "cold_load_s": _LOAD_LATENCY or 0.0,
        "uptime_s": now - _BOOT_AT,
    }


def healthz_api() -> dict:
    """Readiness/timing snapshot, served on the hidden Gradio event with api_name="healthz".

    Keys:

    - ``ready``: whether the Rewriter singleton exists.
    - ``uptime_s``: seconds since module import.
    - ``load_latency_s``: Rewriter construction time, or None.
    - ``first_inference_at_uptime_s``: uptime at the first recorded inference, or None.
    - ``device`` / ``model_id``: taken from the Rewriter, or None before it is built.
    """
    # Read the globals once so every field comes from the same snapshot.
    rewriter = _REWRITER
    first_inference = _FIRST_INFERENCE_AT
    return {
        "ready": rewriter is not None,
        "uptime_s": time.time() - _BOOT_AT,
        "load_latency_s": _LOAD_LATENCY,
        # Explicit None check: a timestamp should not be tested by truthiness.
        "first_inference_at_uptime_s": (
            first_inference - _BOOT_AT if first_inference is not None else None
        ),
        "device": rewriter.device if rewriter is not None else None,
        "model_id": rewriter.model_id if rewriter is not None else None,
    }


def gradio_rewrite(prompt: str):
    """UI click handler: rewrite ``prompt`` and format the result for the three output widgets.

    Returns ``(rewritten, examples_str, latency_s)``. ``examples_str`` has one
    retrieved exemplar per line. An empty, whitespace-only or missing prompt
    returns ``("(empty)", "(empty)", 0.0)`` without running inference.
    """
    global _FIRST_INFERENCE_AT
    # `not prompt` also covers None (e.g. an API call to rewrite_ui with no data).
    # Previously None raised AttributeError on .strip(); this mirrors rewrite_api's guard.
    if not prompt or not prompt.strip():
        return "(empty)", "(empty)", 0.0
    out = _do_rewrite(prompt)
    # Record UI-driven inferences too, so healthz's first_inference_at_uptime_s
    # reflects the first inference from either endpoint.
    if _FIRST_INFERENCE_AT is None:
        _FIRST_INFERENCE_AT = time.time()
    examples_str = "\n".join(f"  · {e}" for e in out["examples"])
    return out["rewritten"], examples_str, out["latency_s"]


# Eager-load the rewriter at module import so the first Gradio call doesn't pay the load cost.
# This calls get_rewriter() directly, not through the spaces.GPU-wrapped _do_rewrite, so the
# model load itself runs outside any ZeroGPU-decorated call; the wrapping above only affects
# inference. NOTE(review): an earlier comment claimed this must come after the decorator
# "so ZeroGPU is initialised"; nothing in this file depends on that ordering — confirm
# against the `spaces` docs before relying on it.
get_rewriter()


# Gradio app: an interactive UI plus two hidden, API-only events.
# Component placement follows statement order inside the context managers.
with gr.Blocks(title="AnimoFlow Rewriter Probe") as demo:
    gr.Markdown("# AnimoFlow Rewriter Probe — Qwen2.5-1.5B-Instruct + RAFSL on ZeroGPU")
    gr.Markdown(
        "Type a motion prompt in any language. The rewriter normalises it to a "
        "HumanML3D-style English caption. Powered by Qwen2.5-1.5B-Instruct + a multilingual "
        "MiniLM retriever over the 52K HumanML3D caption corpus."
    )
    # Interactive UI: prompt + button on the left; rewritten text, exemplars and latency on the right.
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="Your prompt (any language)", placeholder="一个人向前走", lines=2)
            btn = gr.Button("Rewrite", variant="primary")
        with gr.Column():
            out_text = gr.Textbox(label="Rewritten (HumanML3D-style English)", lines=2)
            out_examples = gr.Textbox(label="Retrieved exemplars", lines=4)
            out_latency = gr.Number(label="Latency (s)", precision=3)
    # The UI handler is also reachable programmatically under api_name="rewrite_ui".
    btn.click(gradio_rewrite, inputs=[inp], outputs=[out_text, out_examples, out_latency], api_name="rewrite_ui")

    # Pure API endpoint with JSON dict response — what the probe + clients should hit.
    # The row is hidden; its components exist only so the event can be registered.
    with gr.Row(visible=False):
        api_in = gr.Textbox()
        api_out = gr.JSON()
        api_btn = gr.Button()
        api_btn.click(rewrite_api, inputs=[api_in], outputs=[api_out], api_name="rewrite")

    # Hidden, input-less event that exposes healthz_api under api_name="healthz".
    with gr.Row(visible=False):
        hz_btn = gr.Button()
        hz_out = gr.JSON()
        hz_btn.click(healthz_api, inputs=[], outputs=[hz_out], api_name="healthz")


if __name__ == "__main__":
    # Enable the request queue, then serve on every interface at port 7860.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
    )