#!/usr/bin/env python3
"""
Load test script for Sage API.

Runs sequential requests and reports p50, p95, p99 latency.

Usage:
    # Start the API first:
    python -m sage.api.run

    # Then run the load test:
    python scripts/load_test.py --requests 100 --url http://localhost:8000

    # Test without explanations (faster):
    python scripts/load_test.py --no-explain

    # Save results to JSON (for reproducibility):
    python scripts/load_test.py --save

David's target: p99 < 500ms
"""

import argparse
import statistics
import sys
import time
from datetime import datetime

import httpx

from sage.config import RESULTS_DIR, save_results

# Test queries covering different scenarios
QUERIES = [
    "wireless headphones for working out",
    "laptop for video editing under $1500",
    "best phone case for iPhone",
    "comfortable running shoes",
    "noise canceling earbuds",
    "gaming keyboard mechanical",
    "portable charger high capacity",
    "bluetooth speaker waterproof",
    "monitor for programming",
    "ergonomic office chair",
]

# Successful responses faster than this (seconds) are counted as "cache hits".
# This is a latency heuristic only: the response body is never inspected, so a
# genuinely fast uncached response is counted too.
CACHE_HIT_THRESHOLD_S = 0.1


def percentile(data: list[float], p: float) -> float:
    """Return the p-th percentile of ``data`` using linear interpolation.

    The fractional rank is ``(len(data) - 1) * p / 100``; values between two
    adjacent ranks are linearly interpolated (the same definition as NumPy's
    default "linear" method). Returns 0.0 for empty ``data``.

    Raises:
        ValueError: if ``p`` is outside [0, 100] (a negative ``p`` would
            otherwise wrap around via negative indexing).
    """
    if not 0 <= p <= 100:
        raise ValueError(f"percentile p must be in [0, 100], got {p}")
    if not data:
        return 0.0
    sorted_data = sorted(data)
    k = (len(sorted_data) - 1) * (p / 100)
    f = int(k)
    c = f + 1
    if c >= len(sorted_data):
        return sorted_data[-1]
    return sorted_data[f] + (sorted_data[c] - sorted_data[f]) * (k - f)


def _latency_stats(latencies_ms: list[float]) -> dict:
    """Summarize a non-empty list of latencies (ms), rounding to 0.1 ms."""
    return {
        "min_ms": round(min(latencies_ms), 1),
        "max_ms": round(max(latencies_ms), 1),
        "mean_ms": round(statistics.mean(latencies_ms), 1),
        "median_ms": round(statistics.median(latencies_ms), 1),
        "p50_ms": round(percentile(latencies_ms, 50), 1),
        "p95_ms": round(percentile(latencies_ms, 95), 1),
        "p99_ms": round(percentile(latencies_ms, 99), 1),
        # statistics.stdev needs at least two samples.
        "stdev_ms": (
            round(statistics.stdev(latencies_ms), 1)
            if len(latencies_ms) > 1
            else 0.0
        ),
    }


def run_load_test(
    base_url: str,
    num_requests: int,
    explain: bool,
    timeout: float,
) -> dict:
    """Send ``num_requests`` sequential POSTs to ``{base_url}/recommend``.

    Queries cycle through ``QUERIES``. Only HTTP 200 responses contribute to
    the latency statistics; any other status code, or an httpx transport or
    timeout error, is counted in ``errors``.

    Returns:
        A results dict with ``timestamp``, ``config``, ``total_requests``,
        ``successful``, ``errors`` and ``cache_hits``. When at least one
        request succeeded it also holds ``min_ms``, ``max_ms``, ``mean_ms``,
        ``median_ms``, ``p50_ms``, ``p95_ms``, ``p99_ms`` and ``stdev_ms``.
    """
    latencies: list[float] = []
    errors = 0
    cache_hits = 0
    endpoint = f"{base_url}/recommend"

    print(f"\nRunning {num_requests} requests to {endpoint}")
    print(f" explain={explain}, timeout={timeout}s")
    print("-" * 50)

    # The context manager closes the connection pool even if the run is
    # interrupted partway through (e.g. Ctrl+C).
    with httpx.Client(timeout=timeout) as client:
        for i in range(num_requests):
            payload = {
                "query": QUERIES[i % len(QUERIES)],
                "k": 3,
                "explain": explain,
            }
            start = time.perf_counter()
            try:
                resp = client.post(endpoint, json=payload)
            except httpx.HTTPError as e:
                # Network failures and timeouts count as errors; anything else
                # is a bug in this script and should surface, not be counted.
                errors += 1
                print(f" [{i + 1}] Exception: {e}")
            else:
                elapsed = time.perf_counter() - start
                if resp.status_code == 200:
                    latencies.append(elapsed * 1000)  # Convert to ms
                    if elapsed < CACHE_HIT_THRESHOLD_S:
                        cache_hits += 1
                else:
                    errors += 1
                    print(
                        f" [{i + 1}] Error: {resp.status_code} - {resp.text[:100]}"
                    )

            # Progress indicator
            if (i + 1) % 10 == 0:
                print(f" Completed {i + 1}/{num_requests} requests...")

    results = {
        "timestamp": datetime.now().isoformat(),
        "config": {
            "url": base_url,
            "num_requests": num_requests,
            "explain": explain,
            "timeout_s": timeout,
        },
        "total_requests": num_requests,
        "successful": len(latencies),
        "errors": errors,
        "cache_hits": cache_hits,
    }
    if latencies:
        results.update(_latency_stats(latencies))
    return results


def _meets_target(p99_ms: float, target_p99_ms: float) -> bool:
    """Return True when the measured p99 is within the target (inclusive)."""
    return p99_ms <= target_p99_ms


def print_results(results: dict, target_p99_ms: float = 500.0) -> None:
    """Print a human-readable summary of ``run_load_test`` results.

    Latency and percentile sections, plus the p99 pass/fail verdict, are
    printed only when at least one request succeeded.
    """
    print("\n" + "=" * 50)
    print("LOAD TEST RESULTS")
    print("=" * 50)
    print(f"\nRequests: {results['successful']}/{results['total_requests']} successful")
    print(f"Errors: {results['errors']}")
    print(f"Cache hits: {results.get('cache_hits', 0)}")

    if results["successful"] > 0:
        print("\nLatency (ms):")
        print(f" Min: {results['min_ms']:.1f}")
        print(f" Max: {results['max_ms']:.1f}")
        print(f" Mean: {results['mean_ms']:.1f}")
        print(f" Median: {results['median_ms']:.1f}")
        print(f" StdDev: {results['stdev_ms']:.1f}")
        print("\nPercentiles (ms):")
        print(f" p50: {results['p50_ms']:.1f}")
        print(f" p95: {results['p95_ms']:.1f}")
        print(f" p99: {results['p99_ms']:.1f}")

        # Target check (inclusive, matching the "pass" flag set in main()).
        p99 = results["p99_ms"]
        if _meets_target(p99, target_p99_ms):
            print(f"\n Target p99 <= {target_p99_ms}ms: PASS ({p99:.1f}ms)")
        else:
            print(f"\n Target p99 <= {target_p99_ms}ms: FAIL ({p99:.1f}ms)")
            print(
                " Bottleneck: Likely LLM generation (check sage_llm_duration_seconds)"
            )

    print("\n" + "=" * 50)


def main():
    """Parse CLI args, health-check the API, run the load test and report."""
    parser = argparse.ArgumentParser(description="Load test Sage API")
    parser.add_argument(
        "--url",
        default="http://localhost:8000",
        help="Base URL of the API (default: http://localhost:8000)",
    )
    parser.add_argument(
        "--requests",
        type=int,
        default=100,
        help="Number of requests to send (default: 100)",
    )
    parser.add_argument(
        "--no-explain",
        action="store_true",
        help="Disable explanations (faster, tests retrieval only)",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=30.0,
        help="Request timeout in seconds (default: 30)",
    )
    parser.add_argument(
        "--target-p99",
        type=float,
        default=500.0,
        help="Target p99 latency in ms (default: 500)",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Save results to data/eval_results/load_test_*.json",
    )
    args = parser.parse_args()

    # Reject values that would silently yield an empty or degenerate run.
    if args.requests < 1:
        parser.error("--requests must be at least 1")
    if args.timeout <= 0:
        parser.error("--timeout must be positive")

    # Quick health check
    try:
        resp = httpx.get(f"{args.url}/health", timeout=5.0)
        if resp.status_code != 200:
            print(f"API health check failed: {resp.status_code}")
            sys.exit(1)
        health = resp.json()
        print(f"API Status: {health.get('status', 'unknown')}")
        print(
            f"Qdrant: {'connected' if health.get('qdrant_connected') else 'disconnected'}"
        )
        print(f"LLM: {'available' if health.get('llm_reachable') else 'unavailable'}")
    except Exception as e:
        # Top-level boundary: any failure here means the API is not usable.
        print(f"Cannot connect to API at {args.url}: {e}")
        sys.exit(1)

    results = run_load_test(
        base_url=args.url,
        num_requests=args.requests,
        explain=not args.no_explain,
        timeout=args.timeout,
    )

    # Add pass/fail status
    if results["successful"] > 0:
        results["target_p99_ms"] = args.target_p99
        results["pass"] = _meets_target(results["p99_ms"], args.target_p99)

    print_results(results, target_p99_ms=args.target_p99)

    if args.save:
        RESULTS_DIR.mkdir(parents=True, exist_ok=True)
        saved_path = save_results(results, "load_test")
        print(f"\nResults saved: {saved_path}")


if __name__ == "__main__":
    main()