File size: 2,198 Bytes
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Fire concurrent prompts at a running tiny_vllm server.

Run the server first:
    python -m tiny_vllm.server --model Qwen/Qwen2.5-0.5B-Instruct

Then in another shell:
    python examples/smoke_client.py
"""
from __future__ import annotations

import argparse
import asyncio
import json
import time

import httpx


# Prompts sent by main(): every entry is sent concurrently once per round,
# and PROMPTS[0] alone is reused for --prefix-demo.
PROMPTS = [
    "Write a haiku about paged attention.",
    "Explain GQA in one paragraph.",
    "What is continuous batching, briefly?",
    "List three uses of prefix caching.",
]


async def one(
    client: httpx.AsyncClient,
    prompt: str,
    idx: int,
    *,
    max_tokens: int = 48,
    temperature: float = 0.7,
    top_p: float = 0.9,
) -> tuple[str, float]:
    """Stream one completion from ``POST /generate`` and print it.

    The request body carries ``prompt``, the sampling parameters and
    ``"stream": True``.  The response is read line by line as server-sent
    events.  Every ``data:`` line holds a JSON chunk, and a non-empty
    ``"text"`` field in that chunk is appended to the output.  Reading stops
    in any of three cases:

    - the payload is the ``[DONE]`` sentinel;
    - a chunk has a truthy ``"finished"`` field;
    - the server closes the stream.

    Args:
        client: Async client whose ``base_url`` points at the server.
        prompt: Prompt text to send.
        idx: Label used only to tag the two printed lines.
        max_tokens: Forwarded as ``"max_tokens"`` in the request body.
        temperature: Forwarded as ``"temperature"`` in the request body.
        top_p: Forwarded as ``"top_p"`` in the request body.

    Returns:
        ``(text, seconds)``.  ``text`` is the concatenation of all streamed
        text pieces.  ``seconds`` is the wall-clock time from just before
        the request was sent until streaming stopped.

    Raises:
        httpx.HTTPStatusError: If the server replies with an error status.
        json.JSONDecodeError: If a non-empty ``data:`` payload is not valid
            JSON.
    """
    t0 = time.monotonic()
    print(f"[{idx}] >> {prompt!r}")
    text_parts: list[str] = []
    payload = {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": True,
    }
    # timeout=None: a long generation under load must not be cut off
    # client-side.
    async with client.stream("POST", "/generate", json=payload, timeout=None) as resp:
        resp.raise_for_status()
        async for raw in resp.aiter_lines():
            # Lines that are not data lines are ignored.  This covers
            # comments (":..."), "event:" and "id:" lines, and blank
            # separator lines.
            if not raw.startswith("data:"):
                continue
            # The SSE format makes the single space after "data:" optional.
            # Strip it when present, but never require it.
            data = raw[5:]
            if data.startswith(" "):
                data = data[1:]
            if not data:
                continue
            if data == "[DONE]":
                break
            chunk = json.loads(data)
            if chunk.get("text"):
                text_parts.append(chunk["text"])
            if chunk.get("finished"):
                break
    dt = time.monotonic() - t0
    text = "".join(text_parts)
    print(f"[{idx}] << ({dt:.2f}s) {text}")
    return text, dt


async def main() -> None:
    """Parse command-line flags and send prompts to the server.

    Default mode: for each of ``--rounds`` rounds, send every prompt in
    ``PROMPTS`` concurrently.  Each round waits for all of its requests to
    finish before the next round starts.

    ``--prefix-demo`` mode: send ``PROMPTS[0]`` three times, one after
    another and never concurrently, then return.  ``--rounds`` is ignored in
    this mode.
    """
    # Named "parser" (not "p") so it is not confused with the per-prompt
    # loop variable below.
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-url", default="http://127.0.0.1:8000")
    parser.add_argument("--rounds", type=int, default=1)
    parser.add_argument("--prefix-demo", action="store_true",
                        help="send same prompt 3x to show prefix cache speedup")
    args = parser.parse_args()

    async with httpx.AsyncClient(base_url=args.base_url) as client:
        if args.prefix_demo:
            # Sequential on purpose.  Presumably later repeats can then hit
            # the server's prefix cache, and the printed latencies show the
            # speedup.
            prompt = PROMPTS[0]
            for i in range(3):
                await one(client, prompt, i)
            return
        for r in range(args.rounds):
            # Offset idx by the round number so every printed label is unique
            # across rounds.
            base = r * len(PROMPTS)
            await asyncio.gather(
                *(one(client, prompt, base + i) for i, prompt in enumerate(PROMPTS))
            )


if __name__ == "__main__":
    # Script entry point: asyncio.run creates a fresh event loop, runs the
    # async main() to completion, then closes the loop.
    asyncio.run(main())