Sage / scripts /load_test.py
vxa8502's picture
Add .hfignore to exclude binaries from HuggingFace deployment
df69e5d
#!/usr/bin/env python3
"""
Load test script for Sage API.
Runs sequential requests and reports p50, p95, p99 latency.
Usage:
# Start the API first:
python -m sage.api.run
# Then run the load test:
python scripts/load_test.py --requests 100 --url http://localhost:8000
# Test without explanations (faster):
python scripts/load_test.py --no-explain
# Save results to JSON (for reproducibility):
python scripts/load_test.py --save
David's target: p99 <= 500ms (default; override with --target-p99)
"""
import argparse
import statistics
import sys
import time
from datetime import datetime
import httpx
from sage.config import RESULTS_DIR, save_results
# Fixed set of shopping-style queries (different product categories, one with
# a price constraint, several with feature requirements). The load test cycles
# through them in order, so each query is repeated every len(QUERIES) requests.
QUERIES: list[str] = [
    "wireless headphones for working out",
    "laptop for video editing under $1500",
    "best phone case for iPhone",
    "comfortable running shoes",
    "noise canceling earbuds",
    "gaming keyboard mechanical",
    "portable charger high capacity",
    "bluetooth speaker waterproof",
    "monitor for programming",
    "ergonomic office chair",
]
def percentile(data: list[float], p: float) -> float:
    """Return the p-th percentile (p on a 0-100 scale) of *data*.

    Uses linear interpolation: the fractional rank ``(len(data) - 1) * p / 100``
    is split into an integer index and a remainder, and the result lies that
    fraction of the way between the two neighbouring sorted values.

    Returns 0.0 for empty input, and the largest value whenever the rank has
    no upper neighbour (e.g. p == 100 or a single-element list).
    """
    if not data:
        return 0.0

    ordered = sorted(data)
    last = len(ordered) - 1
    rank = last * (p / 100)
    lo = int(rank)

    # No element above `lo` to interpolate towards: clamp to the maximum.
    if lo + 1 > last:
        return ordered[-1]

    weight = rank - lo
    return ordered[lo] + (ordered[lo + 1] - ordered[lo]) * weight
def _summarize_latencies(latencies: list[float]) -> dict:
    """Return latency statistics in ms, each rounded to 0.1, for a non-empty sample list."""
    return {
        "min_ms": round(min(latencies), 1),
        "max_ms": round(max(latencies), 1),
        "mean_ms": round(statistics.mean(latencies), 1),
        "median_ms": round(statistics.median(latencies), 1),
        "p50_ms": round(percentile(latencies, 50), 1),
        "p95_ms": round(percentile(latencies, 95), 1),
        "p99_ms": round(percentile(latencies, 99), 1),
        # statistics.stdev requires at least two data points.
        "stdev_ms": round(statistics.stdev(latencies), 1)
        if len(latencies) > 1
        else 0.0,
    }


def run_load_test(
    base_url: str,
    num_requests: int,
    explain: bool,
    timeout: float,
    cache_hit_threshold_ms: float = 100.0,
) -> dict:
    """Send sequential POST requests to ``{base_url}/recommend`` and return metrics.

    Queries are taken round-robin from ``QUERIES``, each with ``k=3`` and the
    given ``explain`` flag. Only HTTP 200 responses contribute latency
    samples; non-200 responses and exceptions raised while sending are
    counted as errors and the run continues.

    Args:
        base_url: Base URL of the API (the ``/recommend`` path is appended).
        num_requests: Number of requests to send, one after another.
        explain: Value of the ``explain`` field in every request payload.
        timeout: Per-request timeout in seconds, passed to ``httpx.Client``.
        cache_hit_threshold_ms: Successful responses faster than this are
            counted as probable cache hits. This is a timing heuristic only;
            the response body is not inspected.

    Returns:
        A dict with ``timestamp``, ``config``, ``total_requests``,
        ``successful``, ``errors`` and ``cache_hits``. When at least one
        request succeeded it also contains ``min_ms``, ``max_ms``,
        ``mean_ms``, ``median_ms``, ``p50_ms``, ``p95_ms``, ``p99_ms`` and
        ``stdev_ms`` (milliseconds, rounded to 0.1).
    """
    latencies: list[float] = []
    errors = 0
    cache_hits = 0
    endpoint = f"{base_url}/recommend"
    cache_hit_threshold_s = cache_hit_threshold_ms / 1000

    print(f"\nRunning {num_requests} requests to {endpoint}")
    print(f" explain={explain}, timeout={timeout}s")
    print("-" * 50)

    # The context manager closes the connection pool even if the loop raises
    # or is interrupted (e.g. Ctrl-C), which a trailing close() call would skip.
    with httpx.Client(timeout=timeout) as client:
        for i in range(num_requests):
            payload = {
                "query": QUERIES[i % len(QUERIES)],  # cycle through the query set
                "k": 3,
                "explain": explain,
            }
            try:
                start = time.perf_counter()
                resp = client.post(endpoint, json=payload)
                elapsed = time.perf_counter() - start
                if resp.status_code == 200:
                    latencies.append(elapsed * 1000)  # seconds -> ms
                    # Heuristic: a very fast reply is assumed to be cache-served.
                    if elapsed < cache_hit_threshold_s:
                        cache_hits += 1
                else:
                    errors += 1
                    print(f" [{i + 1}] Error: {resp.status_code} - {resp.text[:100]}")
            except Exception as e:
                # Deliberately broad: any failure on one request is counted as
                # an error so a single bad request does not abort the whole run.
                errors += 1
                print(f" [{i + 1}] Exception: {e}")

            # Progress indicator
            if (i + 1) % 10 == 0:
                print(f" Completed {i + 1}/{num_requests} requests...")

    results = {
        "timestamp": datetime.now().isoformat(),
        "config": {
            "url": base_url,
            "num_requests": num_requests,
            "explain": explain,
            "timeout_s": timeout,
        },
        "total_requests": num_requests,
        "successful": len(latencies),
        "errors": errors,
        "cache_hits": cache_hits,
    }
    if latencies:
        results.update(_summarize_latencies(latencies))
    return results
def print_results(results: dict, target_p99_ms: float = 500.0) -> None:
    """Print a formatted report of load-test metrics to stdout.

    Args:
        results: Metrics dict. ``successful``, ``total_requests`` and
            ``errors`` are always read, and ``cache_hits`` is read when
            present. The latency keys (``min_ms``, ``max_ms``, ``mean_ms``,
            ``median_ms``, ``stdev_ms``, ``p50_ms``, ``p95_ms``, ``p99_ms``)
            are read only when ``successful > 0``. An optional
            ``config["explain"]`` flag controls the bottleneck hint.
        target_p99_ms: p99 latency budget in ms. The run passes when
            ``p99_ms <= target_p99_ms``.
    """
    print("\n" + "=" * 50)
    print("LOAD TEST RESULTS")
    print("=" * 50)
    print(f"\nRequests: {results['successful']}/{results['total_requests']} successful")
    print(f"Errors: {results['errors']}")
    print(f"Cache hits: {results.get('cache_hits', 0)}")

    if results["successful"] > 0:
        print("\nLatency (ms):")
        print(f" Min: {results['min_ms']:.1f}")
        print(f" Max: {results['max_ms']:.1f}")
        print(f" Mean: {results['mean_ms']:.1f}")
        print(f" Median: {results['median_ms']:.1f}")
        print(f" StdDev: {results['stdev_ms']:.1f}")
        print("\nPercentiles (ms):")
        print(f" p50: {results['p50_ms']:.1f}")
        print(f" p95: {results['p95_ms']:.1f}")
        print(f" p99: {results['p99_ms']:.1f}")

        # Target check. The message uses "<=" to match the comparison performed.
        p99 = results["p99_ms"]
        verdict = "PASS" if p99 <= target_p99_ms else "FAIL"
        print(f"\n Target p99 <= {target_p99_ms}ms: {verdict} ({p99:.1f}ms)")

        # The CLI documents --no-explain as "tests retrieval only", so blaming
        # LLM generation is misleading when explain was off. Missing config
        # keeps the previous behaviour (hint shown).
        explain = results.get("config", {}).get("explain", True)
        if verdict == "FAIL" and explain:
            print(
                " Bottleneck: Likely LLM generation (check sage_llm_duration_seconds)"
            )

    print("\n" + "=" * 50)
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Load test Sage API")
    parser.add_argument(
        "--url",
        default="http://localhost:8000",
        help="Base URL of the API (default: http://localhost:8000)",
    )
    parser.add_argument(
        "--requests",
        type=int,
        default=100,
        help="Number of requests to send (default: 100)",
    )
    parser.add_argument(
        "--no-explain",
        action="store_true",
        help="Disable explanations (faster, tests retrieval only)",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=30.0,
        help="Request timeout in seconds (default: 30)",
    )
    parser.add_argument(
        "--target-p99",
        type=float,
        default=500.0,
        help="Target p99 latency in ms (default: 500)",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Save results to data/eval_results/load_test_*.json",
    )
    return parser.parse_args()


def _check_health(url: str) -> None:
    """GET ``{url}/health``, print its status fields, and exit(1) if unreachable or not HTTP 200."""
    try:
        resp = httpx.get(f"{url}/health", timeout=5.0)
        if resp.status_code != 200:
            print(f"API health check failed: {resp.status_code}")
            sys.exit(1)  # SystemExit is not an Exception subclass, so it is not caught below
        health = resp.json()
        print(f"API Status: {health.get('status', 'unknown')}")
        print(
            f"Qdrant: {'connected' if health.get('qdrant_connected') else 'disconnected'}"
        )
        print(f"LLM: {'available' if health.get('llm_reachable') else 'unavailable'}")
    except Exception as e:
        print(f"Cannot connect to API at {url}: {e}")
        sys.exit(1)


def main():
    """CLI entry point: health-check the API, run the load test, report, optionally save."""
    args = _parse_args()

    # Drop trailing slashes so "--url http://host:8000/" does not yield
    # "http://host:8000//health" and "//recommend".
    base_url = args.url.rstrip("/")
    _check_health(base_url)

    results = run_load_test(
        base_url=base_url,
        num_requests=args.requests,
        explain=not args.no_explain,
        timeout=args.timeout,
    )

    # Always record the verdict. A run with no successful requests has no
    # latency data (no "p99_ms" key) and is recorded as a failure instead of
    # silently omitting "pass"; the short-circuit avoids reading "p99_ms".
    results["target_p99_ms"] = args.target_p99
    results["pass"] = (
        results["successful"] > 0 and results["p99_ms"] <= args.target_p99
    )

    print_results(results, target_p99_ms=args.target_p99)

    if args.save:
        RESULTS_DIR.mkdir(parents=True, exist_ok=True)
        saved_path = save_results(results, "load_test")
        print(f"\nResults saved: {saved_path}")
# Script entry point. main() returns None, so SystemExit(None) exits with status 0
# unless main() itself exits earlier with a non-zero code.
if __name__ == "__main__":
    raise SystemExit(main())