"""Simple asyncio load tester for the backend's /chat (and optionally /search) endpoint."""
import argparse
import asyncio
import statistics
import time
from typing import Any, Dict, List, Tuple
import httpx
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the load tester.

    Returns:
        An ``argparse.Namespace`` with ``backend_url``, ``namespace``,
        ``concurrency``, ``requests``, ``include_search`` and ``api_key``.
    """
    cli = argparse.ArgumentParser(
        description="Simple asyncio load tester for the /chat endpoint."
    )
    cli.add_argument(
        "--backend-url",
        type=str,
        default="http://localhost:8000",
        help="Base URL of the backend (default: http://localhost:8000)",
    )
    cli.add_argument(
        "--namespace",
        type=str,
        default="dev",
        help="Pinecone namespace to use for queries (default: dev)",
    )
    cli.add_argument(
        "--concurrency",
        type=int,
        default=10,
        help="Number of concurrent requests (default: 10)",
    )
    cli.add_argument(
        "--requests",
        type=int,
        default=50,
        help="Total number of /chat requests to issue (default: 50)",
    )
    cli.add_argument(
        "--include-search",
        action="store_true",
        help="Also benchmark /search with the same concurrency and request count.",
    )
    cli.add_argument(
        "--api-key",
        type=str,
        default=None,
        help="Optional API key to send as X-API-Key header.",
    )
    return cli.parse_args()
async def _run_one_request(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    json_body: Dict[str, Any],
    headers: Dict[str, str],
    semaphore: asyncio.Semaphore,
) -> Tuple[float, bool]:
    """Issue one HTTP request, bounded by *semaphore*, and time it.

    Args:
        client: Shared async HTTP client.
        method: HTTP method (e.g. "POST").
        url: Fully-qualified endpoint URL.
        json_body: JSON payload to send.
        headers: Request headers.
        semaphore: Concurrency gate shared by all in-flight requests.

    Returns:
        ``(elapsed_ms, error)`` — wall-clock latency in milliseconds and a
        flag that is True on any 4xx/5xx response or raised exception.
    """
    error = False
    async with semaphore:
        # BUGFIX: start the clock only after acquiring a concurrency slot.
        # Previously the timer started before the semaphore, so the reported
        # latency included client-side queue wait, inflating every percentile
        # whenever total requests exceeded the concurrency limit.
        start = time.perf_counter()
        try:
            resp = await client.request(method, url, json=json_body, headers=headers)
            # Any 4xx/5xx status counts as a failed request.
            if resp.status_code >= 400:
                error = True
        except Exception:
            # Transport errors / timeouts count as failures too; the caller
            # only aggregates counts, so details are intentionally dropped.
            error = True
        finally:
            elapsed = (time.perf_counter() - start) * 1000.0
    return elapsed, error
async def _run_load_test(
    base_url: str,
    namespace: str,
    concurrency: int,
    total_requests: int,
    api_key: str | None,
) -> Dict[str, Any]:
    """Benchmark the /chat endpoint.

    Fires ``total_requests`` POSTs at ``{base_url}/chat``, at most
    ``concurrency`` in flight at once, and returns per-request latencies
    plus an error count.

    Returns:
        ``{"latencies_ms": [...], "errors": int, "total": int}``.
    """
    endpoint = f"{base_url.rstrip('/')}/chat"
    body: Dict[str, Any] = {
        "query": "Briefly explain retrieval-augmented generation.",
        "namespace": namespace,
        "top_k": 5,
        "use_web_fallback": True,
    }
    request_headers: Dict[str, str] = {"Content-Type": "application/json"}
    if api_key:
        request_headers["X-API-Key"] = api_key
    gate = asyncio.Semaphore(concurrency)
    async with httpx.AsyncClient(timeout=30.0) as client:
        outcomes = await asyncio.gather(
            *(
                _run_one_request(client, "POST", endpoint, body, request_headers, gate)
                for _ in range(total_requests)
            )
        )
    return {
        "latencies_ms": [elapsed for elapsed, _ in outcomes],
        "errors": sum(1 for _, failed in outcomes if failed),
        "total": total_requests,
    }
async def _run_search_test(
    base_url: str,
    namespace: str,
    concurrency: int,
    total_requests: int,
    api_key: str | None,
) -> Dict[str, Any]:
    """Benchmark the /search endpoint.

    Mirrors the /chat benchmark: ``total_requests`` POSTs against
    ``{base_url}/search`` with at most ``concurrency`` in flight.

    Returns:
        ``{"latencies_ms": [...], "errors": int, "total": int}``.
    """
    endpoint = f"{base_url.rstrip('/')}/search"
    body: Dict[str, Any] = {
        "query": "retrieval-augmented generation",
        "top_k": 5,
        "namespace": namespace,
    }
    request_headers: Dict[str, str] = {"Content-Type": "application/json"}
    if api_key:
        request_headers["X-API-Key"] = api_key
    gate = asyncio.Semaphore(concurrency)
    observed: List[float] = []
    failure_count = 0
    async with httpx.AsyncClient(timeout=30.0) as client:
        pending = [
            _run_one_request(client, "POST", endpoint, body, request_headers, gate)
            for _ in range(total_requests)
        ]
        for finished in asyncio.as_completed(pending):
            elapsed_ms, failed = await finished
            observed.append(elapsed_ms)
            failure_count += int(failed)
    return {
        "latencies_ms": observed,
        "errors": failure_count,
        "total": total_requests,
    }
def _summarise(result: Dict[str, Any], label: str) -> None:
latencies = result["latencies_ms"]
errors = result["errors"]
total = result["total"]
successes = total - errors
error_rate = (errors / total * 100.0) if total else 0.0
if latencies:
values = sorted(latencies)
avg = sum(values) / len(values)
p50 = statistics.median(values)
# Simple index-based p95 that works for small samples.
idx95 = max(0, int(round(0.95 * (len(values) - 1))))
p95 = values[idx95]
else:
avg = p50 = p95 = 0.0
print(f"=== {label} ===")
print(f"Total requests: {total}")
print(f"Successful: {successes}")
print(f"Errors: {errors} ({error_rate:.1f}%)")
print(f"Average latency: {avg:.2f} ms")
print(f"p50 latency: {p50:.2f} ms")
print(f"p95 latency: {p95:.2f} ms")
print()
async def main_async() -> None:
    """Run the /chat benchmark, then /search if --include-search was given."""
    opts = parse_args()
    print(
        f"Running /chat benchmark against {opts.backend_url} "
        f"namespace='{opts.namespace}' concurrency={opts.concurrency} "
        f"requests={opts.requests}"
    )
    chat_stats = await _run_load_test(
        base_url=opts.backend_url,
        namespace=opts.namespace,
        concurrency=opts.concurrency,
        total_requests=opts.requests,
        api_key=opts.api_key,
    )
    _summarise(chat_stats, "/chat")
    # Guard clause: skip the optional /search leg unless explicitly requested.
    if not opts.include_search:
        return
    print(
        f"Running /search benchmark against {opts.backend_url} "
        f"namespace='{opts.namespace}' concurrency={opts.concurrency} "
        f"requests={opts.requests}"
    )
    search_stats = await _run_search_test(
        base_url=opts.backend_url,
        namespace=opts.namespace,
        concurrency=opts.concurrency,
        total_requests=opts.requests,
        api_key=opts.api_key,
    )
    _summarise(search_stats, "/search")
def main() -> None:
    """Synchronous wrapper so the script works as a console entry point."""
    asyncio.run(main_async())


if __name__ == "__main__":
    main()