Spaces:
Sleeping
Sleeping
| import argparse | |
| import asyncio | |
| import statistics | |
| import time | |
| from typing import Any, Dict, List, Tuple | |
| import httpx | |
| def parse_args() -> argparse.Namespace: | |
| parser = argparse.ArgumentParser( | |
| description="Simple asyncio load tester for the /chat endpoint." | |
| ) | |
| parser.add_argument( | |
| "--backend-url", | |
| type=str, | |
| default="http://localhost:8000", | |
| help="Base URL of the backend (default: http://localhost:8000)", | |
| ) | |
| parser.add_argument( | |
| "--namespace", | |
| type=str, | |
| default="dev", | |
| help="Pinecone namespace to use for queries (default: dev)", | |
| ) | |
| parser.add_argument( | |
| "--concurrency", | |
| type=int, | |
| default=10, | |
| help="Number of concurrent requests (default: 10)", | |
| ) | |
| parser.add_argument( | |
| "--requests", | |
| type=int, | |
| default=50, | |
| help="Total number of /chat requests to issue (default: 50)", | |
| ) | |
| parser.add_argument( | |
| "--include-search", | |
| action="store_true", | |
| help="Also benchmark /search with the same concurrency and request count.", | |
| ) | |
| parser.add_argument( | |
| "--api-key", | |
| type=str, | |
| default=None, | |
| help="Optional API key to send as X-API-Key header.", | |
| ) | |
| return parser.parse_args() | |
async def _run_one_request(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    json_body: Dict[str, Any],
    headers: Dict[str, str],
    semaphore: asyncio.Semaphore,
) -> Tuple[float, bool]:
    """Send one HTTP request under the concurrency limit and time it.

    Args:
        client: Shared async HTTP client used to issue the request.
        method: HTTP method name passed to ``client.request``.
        url: Target URL.
        json_body: Object sent as the JSON request body.
        headers: Headers sent with the request.
        semaphore: Limits how many requests are in flight at once.

    Returns:
        ``(elapsed_ms, error)`` where ``elapsed_ms`` is the duration of the
        request in milliseconds and ``error`` is ``True`` when the request
        raised or the response status code was 400 or above.
    """
    async with semaphore:
        # Start the clock only once a slot is held; otherwise the time spent
        # queueing behind other tasks would be reported as request latency.
        started = time.perf_counter()
        try:
            resp = await client.request(
                method, url, json=json_body, headers=headers
            )
        except Exception:
            # Deliberately broad: any failed request is counted as an error
            # instead of aborting the whole benchmark run.
            failed = True
        else:
            failed = resp.status_code >= 400
        elapsed_ms = (time.perf_counter() - started) * 1000.0
    return elapsed_ms, failed
| async def _run_load_test( | |
| base_url: str, | |
| namespace: str, | |
| concurrency: int, | |
| total_requests: int, | |
| api_key: str | None, | |
| ) -> Dict[str, Any]: | |
| url = f"{base_url.rstrip('/')}/chat" | |
| payload: Dict[str, Any] = { | |
| "query": "Briefly explain retrieval-augmented generation.", | |
| "namespace": namespace, | |
| "top_k": 5, | |
| "use_web_fallback": True, | |
| } | |
| headers: Dict[str, str] = {"Content-Type": "application/json"} | |
| if api_key: | |
| headers["X-API-Key"] = api_key | |
| semaphore = asyncio.Semaphore(concurrency) | |
| latencies: List[float] = [] | |
| errors = 0 | |
| async with httpx.AsyncClient(timeout=30.0) as client: | |
| tasks = [ | |
| _run_one_request(client, "POST", url, payload, headers, semaphore) | |
| for _ in range(total_requests) | |
| ] | |
| for coro in asyncio.as_completed(tasks): | |
| elapsed_ms, is_error = await coro | |
| latencies.append(elapsed_ms) | |
| if is_error: | |
| errors += 1 | |
| return { | |
| "latencies_ms": latencies, | |
| "errors": errors, | |
| "total": total_requests, | |
| } | |
| async def _run_search_test( | |
| base_url: str, | |
| namespace: str, | |
| concurrency: int, | |
| total_requests: int, | |
| api_key: str | None, | |
| ) -> Dict[str, Any]: | |
| url = f"{base_url.rstrip('/')}/search" | |
| payload: Dict[str, Any] = { | |
| "query": "retrieval-augmented generation", | |
| "top_k": 5, | |
| "namespace": namespace, | |
| } | |
| headers: Dict[str, str] = {"Content-Type": "application/json"} | |
| if api_key: | |
| headers["X-API-Key"] = api_key | |
| semaphore = asyncio.Semaphore(concurrency) | |
| latencies: List[float] = [] | |
| errors = 0 | |
| async with httpx.AsyncClient(timeout=30.0) as client: | |
| tasks = [ | |
| _run_one_request(client, "POST", url, payload, headers, semaphore) | |
| for _ in range(total_requests) | |
| ] | |
| for coro in asyncio.as_completed(tasks): | |
| elapsed_ms, is_error = await coro | |
| latencies.append(elapsed_ms) | |
| if is_error: | |
| errors += 1 | |
| return { | |
| "latencies_ms": latencies, | |
| "errors": errors, | |
| "total": total_requests, | |
| } | |
def _summarise(result: Dict[str, Any], label: str) -> None:
    """Print a latency and error report for one benchmark run.

    Args:
        result: Dict holding ``latencies_ms`` (list of floats), ``errors``
            (int) and ``total`` (int).
        label: Heading shown in the first line of the report.
    """
    samples = result["latencies_ms"]
    failed = result["errors"]
    issued = result["total"]

    if samples:
        ordered = sorted(samples)
        mean_ms = sum(ordered) / len(ordered)
        median_ms = statistics.median(ordered)
        # Index-based p95: pick the element nearest to 95% of the way
        # through the sorted sample, which also works for tiny samples.
        rank = max(0, round(0.95 * (len(ordered) - 1)))
        p95_ms = ordered[rank]
    else:
        mean_ms = median_ms = p95_ms = 0.0

    # Guard against division by zero when no requests were issued.
    failed_pct = failed / issued * 100.0 if issued else 0.0

    report = [
        f"=== {label} ===",
        f"Total requests: {issued}",
        f"Successful: {issued - failed}",
        f"Errors: {failed} ({failed_pct:.1f}%)",
        f"Average latency: {mean_ms:.2f} ms",
        f"p50 latency: {median_ms:.2f} ms",
        f"p95 latency: {p95_ms:.2f} ms",
    ]
    print("\n".join(report))
    print()
async def main_async() -> None:
    """Run the /chat benchmark and, when requested, the /search benchmark.

    Options come from ``parse_args()``. Each benchmark prints a banner
    before it starts and a summary once it has finished.
    """
    args = parse_args()

    # Both benchmarks are driven with exactly the same settings.
    shared = {
        "base_url": args.backend_url,
        "namespace": args.namespace,
        "concurrency": args.concurrency,
        "total_requests": args.requests,
        "api_key": args.api_key,
    }

    print(
        f"Running /chat benchmark against {args.backend_url} "
        f"namespace='{args.namespace}' concurrency={args.concurrency} "
        f"requests={args.requests}"
    )
    _summarise(await _run_load_test(**shared), "/chat")

    if not args.include_search:
        return

    print(
        f"Running /search benchmark against {args.backend_url} "
        f"namespace='{args.namespace}' concurrency={args.concurrency} "
        f"requests={args.requests}"
    )
    _summarise(await _run_search_test(**shared), "/search")
def main() -> None:
    """Synchronous entry point: run ``main_async`` on a new event loop."""
    benchmark = main_async()
    asyncio.run(benchmark)


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()