modelcourt / scripts /benchmark.py
existcode's picture
reupload
f983375 verified
Raw
History Blame Contribute Delete
2.46 kB
"""Model Court benchmark: sequential vs. parallel first-wave Role Agent execution."""
from __future__ import annotations
import argparse
import statistics
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from core.court import run_model_court_benchmark
from core.court_client import MockCourtClient, VLLMCourtClient
# Fixed claim narrative passed as ``case_text`` to every benchmark run, so all
# runs are timed against the same input. The backslash after the opening quotes
# keeps the literal from starting with a blank line.
CASE = """\
Claim title: Escalation case with disputed slip-and-fall evidence
Claim amount: $50,000
The claimant slipped on a wet floor. The policy was active. No witnesses.
Store manager says no camera. Prior claim two years ago. Medical bills submitted.
"""
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--mock", action="store_true")
parser.add_argument("--endpoint", default="http://localhost:8000/v1")
parser.add_argument("--model", default="Qwen/Qwen2.5-14B-Instruct")
parser.add_argument("--runs", type=int, default=1)
args = parser.parse_args()
mode = "mock" if args.mock else "vllm"
first_wave_speedups: list[float] = []
full_speedups: list[float] = []
print("\nModel Court Benchmark")
print("=" * 60)
print(f"Model: {args.model}")
print(f"Endpoint mode: {mode}")
for index in range(args.runs):
client = MockCourtClient() if args.mock else VLLMCourtClient(args.endpoint, args.model)
result = run_model_court_benchmark(
case_text=CASE,
client=client,
model_name=args.model,
endpoint_mode=mode,
case_title="Benchmark case",
)
benchmark = result.benchmark
if benchmark is None:
raise RuntimeError("Benchmark artifact was not produced.")
first_wave_speedups.append(benchmark.first_wave_speedup)
full_speedups.append(benchmark.full_tribunal_speedup)
print(
f"Run {index + 1}: sequential={benchmark.sequential.total_seconds:.3f}s "
f"parallel={benchmark.parallel.total_seconds:.3f}s "
f"first-wave={benchmark.first_wave_speedup:.2f}x "
f"full={benchmark.full_tribunal_speedup:.2f}x"
)
print("-" * 60)
print(f"Mean first-wave speedup: {statistics.mean(first_wave_speedups):.2f}x")
print(f"Mean full Tribunal speedup: {statistics.mean(full_speedups):.2f}x")
if args.mock:
print("Mock mode measures orchestration overhead only, not GPU batching.")
# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()