#!/usr/bin/env python3 """ Benchmark GGUF models via llama-cpp-python. Compares stock vs no-CoT fine-tuned Arch-Router-1.5B in Q8_0 GGUF format. """ import json import time from pathlib import Path from llama_cpp import Llama EVAL_PATH = Path(__file__).parent / "grpo_eval_data.json" STOCK_GGUF = Path(__file__).parent / "stock_arch_router.Q8_0.gguf" NOCOT_GGUF = Path(__file__).parent / "nocot_arch_router.Q8_0.gguf" ROUTE_POLICIES = [ {"name": "simple", "description": "Simple factual questions, greetings, basic lookups, yes/no answers, FAQ-style queries, single-step tasks, status checks, straightforward requests"}, {"name": "medium", "description": "Multi-step reasoning, summarization of moderate-length text, data extraction, moderate analysis, comparison tasks, troubleshooting, explanations requiring some depth"}, {"name": "complex", "description": "Complex multi-document reasoning, deep analysis, legal or financial interpretation, creative writing, code generation, multi-constraint problem solving, liability assessment, comprehensive evaluation"}, ] def find_gguf(directory: Path) -> Path: """Find the .gguf file in a directory.""" for f in directory.iterdir(): if f.suffix == ".gguf": return f raise FileNotFoundError(f"No .gguf file found in {directory}") def build_prompt(user_prompt: str) -> str: policies_json = json.dumps(ROUTE_POLICIES) conversation_json = json.dumps([{"role": "user", "content": user_prompt}]) return f"""You are a routing assistant. Given the route policies and user message, select the best matching route. {policies_json} {conversation_json} Select the best route for this user message. Respond with ONLY valid JSON: {{"route": "route_name"}}""" def extract_route(text: str) -> str | None: try: parsed = json.loads(text.strip()) route = parsed.get("route") if route in ("simple", "medium", "complex"): return route except (json.JSONDecodeError, TypeError): pass for tier in ("simple", "medium", "complex"): if tier in text.lower(): return tier return None def run_benchmark(model: Llama, data: list[dict], label: str) -> dict: results = {"correct": 0, "total": 0, "latencies_ms": [], "by_tier": {}, "misclassifications": []} # Warmup model.create_chat_completion( messages=[{"role": "user", "content": "Hello"}], max_tokens=10, temperature=0, ) for i, item in enumerate(data): prompt = build_prompt(item["prompt"]) start = time.perf_counter() output = model.create_chat_completion( messages=[{"role": "user", "content": prompt}], max_tokens=30, temperature=0, ) elapsed_ms = (time.perf_counter() - start) * 1000 response = output["choices"][0]["message"]["content"] predicted = extract_route(response) expected = item["expected_route"] correct = predicted == expected results["total"] += 1 results["latencies_ms"].append(elapsed_ms) if correct: results["correct"] += 1 if expected not in results["by_tier"]: results["by_tier"][expected] = {"correct": 0, "total": 0, "latencies": []} results["by_tier"][expected]["total"] += 1 results["by_tier"][expected]["latencies"].append(elapsed_ms) if correct: results["by_tier"][expected]["correct"] += 1 else: results["misclassifications"].append({ "prompt": item["prompt"][:80], "expected": expected, "predicted": predicted, }) status = "✓" if correct else "✗" print(f" [{i+1:2d}/{len(data)}] {status} [{expected:>7s}→{str(predicted):<7s}] {elapsed_ms:6.1f}ms | {item['prompt'][:55]}") return results def print_results(results, label): print(f"\n{'='*65}") print(f" {label}") print(f"{'='*65}") for tier in ("simple", "medium", "complex"): if tier in results["by_tier"]: t = results["by_tier"][tier] pct = t["correct"] / t["total"] * 100 if t["total"] else 0 avg_lat = sum(t["latencies"]) / len(t["latencies"]) bar = "█" * int(pct / 5) + "░" * (20 - int(pct / 5)) print(f" {tier:<10s} {t['correct']:>2d}/{t['total']:<2d} ({pct:5.1f}%) {bar} avg {avg_lat:6.1f}ms") total_pct = results["correct"] / results["total"] * 100 avg_latency = sum(results["latencies_ms"]) / len(results["latencies_ms"]) p50 = sorted(results["latencies_ms"])[len(results["latencies_ms"]) // 2] p95 = sorted(results["latencies_ms"])[int(len(results["latencies_ms"]) * 0.95)] print(f"\n OVERALL: {results['correct']}/{results['total']} ({total_pct:.1f}%)") print(f" Latency: avg {avg_latency:.1f}ms | p50 {p50:.1f}ms | p95 {p95:.1f}ms") if results["misclassifications"]: print(f"\n MISCLASSIFICATIONS ({len(results['misclassifications'])}):") for m in results["misclassifications"][:10]: print(f" {m['expected']}→{m['predicted']}: {m['prompt']}") print(f"{'='*65}\n") return {"accuracy": total_pct, "avg_ms": avg_latency, "p50_ms": p50, "p95_ms": p95} def main(): with open(EVAL_PATH) as f: data = json.load(f) print(f"Loaded {len(data)} eval prompts\n") # ── Stock GGUF ── print(f"Loading stock GGUF: {STOCK_GGUF.name}") stock_model = Llama( model_path=str(STOCK_GGUF), n_ctx=512, n_gpu_layers=-1, # All layers on GPU verbose=False, ) print("Stock GGUF loaded\n") print("Running stock GGUF benchmark...") stock_results = run_benchmark(stock_model, data, "Stock GGUF") stock_stats = print_results(stock_results, "STOCK Arch-Router-1.5B (GGUF Q8_0)") del stock_model # ── No-CoT GGUF ── print(f"Loading no-CoT GGUF: {NOCOT_GGUF.name}") nocot_model = Llama( model_path=str(NOCOT_GGUF), n_ctx=512, n_gpu_layers=-1, verbose=False, ) print("No-CoT GGUF loaded\n") print("Running no-CoT GGUF benchmark...") nocot_results = run_benchmark(nocot_model, data, "No-CoT GGUF") nocot_stats = print_results(nocot_results, "GRPO No-CoT Fine-Tuned (GGUF Q8_0)") # ── Comparison ── print(f"{'='*65}") print(f" GGUF Q8_0 COMPARISON: Stock vs No-CoT Fine-Tuned") print(f"{'='*65}") print(f" {'':20s} {'Stock':>12s} {'No-CoT FT':>12s} {'Delta':>10s}") print(f" {'Accuracy':20s} {stock_stats['accuracy']:>11.1f}% {nocot_stats['accuracy']:>11.1f}% {nocot_stats['accuracy']-stock_stats['accuracy']:>+9.1f}%") print(f" {'Avg latency':20s} {stock_stats['avg_ms']:>10.1f}ms {nocot_stats['avg_ms']:>10.1f}ms {nocot_stats['avg_ms']-stock_stats['avg_ms']:>+8.1f}ms") print(f" {'P50 latency':20s} {stock_stats['p50_ms']:>10.1f}ms {nocot_stats['p50_ms']:>10.1f}ms {nocot_stats['p50_ms']-stock_stats['p50_ms']:>+8.1f}ms") print(f" {'P95 latency':20s} {stock_stats['p95_ms']:>10.1f}ms {nocot_stats['p95_ms']:>10.1f}ms {nocot_stats['p95_ms']-stock_stats['p95_ms']:>+8.1f}ms") print(f"{'='*65}") if __name__ == "__main__": main()