blackboxai-api / scripts /load_test_chat.py
teletubbies's picture
Add multi-user mode to load test script
2b14a4b
Raw
History Blame Contribute Delete
5.76 kB
import asyncio
import os
import statistics
import time
from urllib.parse import urlparse
import httpx
# --- Configuration (all read from environment variables at import time) ---
# Server root; trailing slashes are stripped so request paths can be appended as "/...".
BASE_URL = os.getenv("LOADTEST_BASE_URL", "http://localhost:8000").rstrip("/")
# Pre-issued token placed in each client's "access_token" cookie. LOADTEST_BEARER_TOKEN
# is consulted only when LOADTEST_ACCESS_TOKEN is unset.
# NOTE(review): an exported-but-empty LOADTEST_ACCESS_TOKEN yields "" and suppresses the fallback.
TOKEN = os.getenv("LOADTEST_ACCESS_TOKEN", os.getenv("LOADTEST_BEARER_TOKEN", ""))
# Single-user login credentials; only used when LOADTEST_USER_CREDENTIALS is blank.
LOGIN_EMAIL = os.getenv("LOADTEST_LOGIN_EMAIL", "")
LOGIN_PASSWORD = os.getenv("LOADTEST_LOGIN_PASSWORD", "")
# Chat session id sent with every /chat/send request. int() raises ValueError on non-numeric input.
SESSION_ID = int(os.getenv("LOADTEST_SESSION_ID", "1"))
# Number of concurrent workers (each gets its own HTTP client) and total requests to issue.
CONCURRENCY = int(os.getenv("LOADTEST_CONCURRENCY", "20"))
TOTAL_REQUESTS = int(os.getenv("LOADTEST_TOTAL_REQUESTS", "100"))
# Multi-user mode: comma-separated "email:password" pairs, assigned to workers round-robin.
USER_CREDENTIALS_RAW = os.getenv("LOADTEST_USER_CREDENTIALS", "")
def _parse_user_credentials() -> list[tuple[str, str]]:
    """Collect the (email, password) pairs that workers log in with.

    When LOADTEST_USER_CREDENTIALS is non-blank, it is read as a
    comma-separated list of ``email:password`` entries. Only the first colon
    separates the two fields, so passwords may themselves contain ``:``.
    Entries without a colon, or whose email or password is blank after
    trimming, are silently skipped.

    When that variable is blank, the single LOADTEST_LOGIN_EMAIL /
    LOADTEST_LOGIN_PASSWORD pair is used if both are set.

    Returns:
        A fresh list of pairs, which may be empty.
    """
    raw = USER_CREDENTIALS_RAW.strip()
    if not raw:
        # Single-user fallback applies only when no multi-user list was given.
        if LOGIN_EMAIL and LOGIN_PASSWORD:
            return [(LOGIN_EMAIL, LOGIN_PASSWORD)]
        return []

    pairs: list[tuple[str, str]] = []
    for entry in map(str.strip, raw.split(",")):
        user, sep, secret = entry.partition(":")
        user, secret = user.strip(), secret.strip()
        # An empty separator means the entry had no colon at all.
        if sep and user and secret:
            pairs.append((user, secret))
    return pairs
async def _one_request(client: httpx.AsyncClient, index: int) -> float:
    """POST one chat prompt and return its latency in milliseconds.

    The latency covers the full round-trip of ``client.post`` plus the status
    check.

    Raises:
        httpx.HTTPStatusError: for error status codes, via ``raise_for_status``.
        Transport errors from ``client.post`` propagate unchanged.
    """
    t0 = time.perf_counter()
    url = f"{BASE_URL}/chat/send"
    payload = {"prompt": f"Load test prompt #{index}", "session_id": SESSION_ID}
    response = await client.post(url, json=payload, timeout=65.0)
    response.raise_for_status()
    elapsed_s = time.perf_counter() - t0
    return elapsed_s * 1000.0
async def _worker(
    client: httpx.AsyncClient,
    queue: asyncio.Queue[int],
    out: list[float],
    status_counts: dict[int, int],
) -> None:
    """Drain request indices from ``queue`` until a negative sentinel arrives.

    Each successful request appends its latency (ms) to ``out``. Each failure
    increments ``status_counts`` under its HTTP status code, or under -1 for
    any other exception. Exactly one ``task_done()`` is issued per item taken,
    including the sentinel, so ``queue.join()`` can complete.
    """
    while (job := await queue.get()) >= 0:
        try:
            out.append(await _one_request(client, job))
        except httpx.HTTPStatusError as exc:
            key = exc.response.status_code
            status_counts[key] = status_counts.get(key, 0) + 1
        except Exception:
            # Best effort: one auth/network failure must not kill the worker and stall the run.
            status_counts[-1] = status_counts.get(-1, 0) + 1
        finally:
            queue.task_done()
    # Acknowledge the sentinel that ended the loop.
    queue.task_done()
def _percentile(sorted_values: list[float], pct: int) -> float:
    """Return the nearest-rank ``pct``-th percentile (1..100) of an ascending list.

    The rank ``ceil(n * pct / 100)`` is computed with integer arithmetic, so
    float rounding of e.g. ``0.95 * n`` can never shift it. For small samples
    the rank saturates at ``n``, i.e. the maximum is reported. Returns 0.0 for
    an empty list.
    """
    n = len(sorted_values)
    if n == 0:
        return 0.0
    rank = (n * pct + 99) // 100  # integer ceil(n * pct / 100)
    return sorted_values[max(1, min(rank, n)) - 1]


async def _login(client: "httpx.AsyncClient", email: str, password: str) -> str:
    """Log in via POST {BASE_URL}/auth/login and return the ``access_token`` from the JSON reply.

    Raises:
        httpx.HTTPStatusError: if the login request returns an error status.
        RuntimeError: if the reply has no (or an empty) ``access_token``.
    """
    login = await client.post(
        f"{BASE_URL}/auth/login",
        data={"username": email, "password": password},
        timeout=30.0,
    )
    login.raise_for_status()
    access_token = login.json().get("access_token")
    if not access_token:
        raise RuntimeError(f"Login succeeded but no access_token found for {email}")
    return access_token


def _print_report(elapsed: float, latencies: list[float], status_counts: dict[int, int]) -> None:
    """Print the run summary to stdout.

    ``latencies`` holds per-request milliseconds for successful requests only.
    Every other request is counted as failed. ``status_counts`` maps HTTP
    status codes, or -1 for network/other errors, to failure counts.
    """
    ordered = sorted(latencies)
    succeeded = len(ordered)
    failed = TOTAL_REQUESTS - succeeded
    p50 = statistics.median(ordered) if ordered else 0.0
    p95 = _percentile(ordered, 95)
    p99 = _percentile(ordered, 99)
    print(f"Base URL: {BASE_URL}")
    print(f"Total requests: {TOTAL_REQUESTS}")
    print(f"Succeeded: {succeeded}")
    print(f"Failed: {failed}")
    print(f"Concurrency: {CONCURRENCY}")
    print(f"Elapsed: {elapsed:.2f}s")
    print(f"Throughput (success): {succeeded/elapsed:.2f} req/s")
    print(f"Latency p50: {p50:.1f}ms")
    print(f"Latency p95: {p95:.1f}ms")
    print(f"Latency p99: {p99:.1f}ms")
    if status_counts:
        print("Failure codes:")
        for code in sorted(status_counts):
            label = "network_or_other" if code == -1 else str(code)
            print(f" {label}: {status_counts[code]}")


async def main() -> None:
    """Run the /chat/send load test and print a latency/throughput summary.

    One authenticated client is created per worker. Each client uses either
    the shared TOKEN, or a login with credentials assigned round-robin. All
    clients get an ``access_token`` cookie. Then CONCURRENCY workers drain
    TOTAL_REQUESTS jobs from a shared queue. Elapsed time, and therefore
    throughput, is measured only over the request phase, excluding login.

    Raises:
        RuntimeError: on invalid counts, an invalid base URL, missing auth
            configuration, or a login reply without an access_token.
        httpx.HTTPStatusError: if a login request fails.
    """
    if TOTAL_REQUESTS <= 0 or CONCURRENCY <= 0:
        raise RuntimeError("LOADTEST_TOTAL_REQUESTS and LOADTEST_CONCURRENCY must be > 0")

    queue: asyncio.Queue[int] = asyncio.Queue()
    for i in range(TOTAL_REQUESTS):
        queue.put_nowait(i + 1)
    # One negative sentinel per worker tells it to exit.
    for _ in range(CONCURRENCY):
        queue.put_nowait(-1)

    latencies: list[float] = []
    status_counts: dict[int, int] = {}
    limits = httpx.Limits(max_connections=CONCURRENCY * 2, max_keepalive_connections=CONCURRENCY)
    host = urlparse(BASE_URL).hostname
    if not host:
        raise RuntimeError("Invalid LOADTEST_BASE_URL host")
    credentials = _parse_user_credentials()
    clients: list[httpx.AsyncClient] = []
    workers: list[asyncio.Task] = []
    try:
        if not TOKEN and not credentials:
            raise RuntimeError(
                "Set LOADTEST_ACCESS_TOKEN (or LOADTEST_BEARER_TOKEN), "
                "or LOADTEST_LOGIN_EMAIL + LOADTEST_LOGIN_PASSWORD, "
                "or LOADTEST_USER_CREDENTIALS=email1:pass1,email2:pass2"
            )
        for worker_index in range(CONCURRENCY):
            client = httpx.AsyncClient(limits=limits, http2=True)
            # Register before logging in so the client is still closed if login fails.
            clients.append(client)
            if TOKEN:
                access_token = TOKEN
            else:
                email, password = credentials[worker_index % len(credentials)]
                access_token = await _login(client, email, password)
            client.cookies.set("access_token", access_token, domain=host)

        # Start the clock only after every client is authenticated, so the
        # sequential logins do not inflate elapsed time or deflate throughput.
        started = time.perf_counter()
        workers = [
            asyncio.create_task(_worker(client, queue, latencies, status_counts))
            for client in clients
        ]
        await queue.join()
        for worker in workers:
            await worker
        elapsed = time.perf_counter() - started
    finally:
        # If the run was interrupted, stop workers before closing the clients they use.
        pending = [worker for worker in workers if not worker.done()]
        for worker in pending:
            worker.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
        for client in clients:
            await client.aclose()

    _print_report(elapsed, latencies, status_counts)
if __name__ == "__main__":
    # Script entry point: run the async load test on a fresh event loop (not on import).
    asyncio.run(main())