#!/usr/bin/env python3
"""
Benchmark GGUF models via llama-cpp-python.
Compares stock vs no-CoT fine-tuned Arch-Router-1.5B in Q8_0 GGUF format.
"""
import json
import time
from pathlib import Path
from llama_cpp import Llama
EVAL_PATH = Path(__file__).parent / "grpo_eval_data.json"
STOCK_GGUF = Path(__file__).parent / "stock_arch_router.Q8_0.gguf"
NOCOT_GGUF = Path(__file__).parent / "nocot_arch_router.Q8_0.gguf"
ROUTE_POLICIES = [
{"name": "simple", "description": "Simple factual questions, greetings, basic lookups, yes/no answers, FAQ-style queries, single-step tasks, status checks, straightforward requests"},
{"name": "medium", "description": "Multi-step reasoning, summarization of moderate-length text, data extraction, moderate analysis, comparison tasks, troubleshooting, explanations requiring some depth"},
{"name": "complex", "description": "Complex multi-document reasoning, deep analysis, legal or financial interpretation, creative writing, code generation, multi-constraint problem solving, liability assessment, comprehensive evaluation"},
]
def find_gguf(directory: Path) -> Path:
    """Return the first ``.gguf`` file inside *directory*.

    Raises:
        FileNotFoundError: if the directory contains no ``.gguf`` file.
    """
    candidate = next(
        (entry for entry in directory.iterdir() if entry.suffix == ".gguf"),
        None,
    )
    if candidate is None:
        raise FileNotFoundError(f"No .gguf file found in {directory}")
    return candidate
def build_prompt(user_prompt: str) -> str:
    """Render the routing prompt for a single-turn conversation.

    Embeds the module-level ROUTE_POLICIES and a one-message conversation
    (the raw *user_prompt*) as JSON inside tagged sections, then asks the
    model to answer with bare ``{"route": ...}`` JSON.
    """
    conversation = [{"role": "user", "content": user_prompt}]
    return (
        "You are a routing assistant. Given the route policies and user message, select the best matching route.\n"
        "<route_policies>\n"
        f"{json.dumps(ROUTE_POLICIES)}\n"
        "</route_policies>\n"
        "<conversation>\n"
        f"{json.dumps(conversation)}\n"
        "</conversation>\n"
        'Select the best route for this user message. Respond with ONLY valid JSON: {"route": "route_name"}'
    )
def extract_route(text: str) -> str | None:
try:
parsed = json.loads(text.strip())
route = parsed.get("route")
if route in ("simple", "medium", "complex"):
return route
except (json.JSONDecodeError, TypeError):
pass
for tier in ("simple", "medium", "complex"):
if tier in text.lower():
return tier
return None
def run_benchmark(model: Llama, data: list[dict], label: str) -> dict:
    """Run every eval prompt through *model* and collect accuracy/latency data.

    Args:
        model: loaded llama-cpp model to query.
        data: eval items, each with "prompt" and "expected_route" keys.
        label: accepted for caller symmetry; not used internally.

    Returns:
        dict with overall counts, per-tier stats, raw latencies (ms),
        and a list of misclassified prompts.
    """
    results = {"correct": 0, "total": 0, "latencies_ms": [], "by_tier": {}, "misclassifications": []}
    # Warmup generation so first-call overhead doesn't pollute the timings.
    model.create_chat_completion(
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=10, temperature=0,
    )
    item_count = len(data)
    for position, item in enumerate(data, start=1):
        routing_prompt = build_prompt(item["prompt"])
        t0 = time.perf_counter()
        output = model.create_chat_completion(
            messages=[{"role": "user", "content": routing_prompt}],
            max_tokens=30,
            temperature=0,
        )
        elapsed_ms = (time.perf_counter() - t0) * 1000
        reply = output["choices"][0]["message"]["content"]
        predicted = extract_route(reply)
        expected = item["expected_route"]
        hit = predicted == expected
        results["total"] += 1
        results["latencies_ms"].append(elapsed_ms)
        if hit:
            results["correct"] += 1
        tier_stats = results["by_tier"].setdefault(
            expected, {"correct": 0, "total": 0, "latencies": []}
        )
        tier_stats["total"] += 1
        tier_stats["latencies"].append(elapsed_ms)
        if hit:
            tier_stats["correct"] += 1
        else:
            results["misclassifications"].append({
                "prompt": item["prompt"][:80],
                "expected": expected,
                "predicted": predicted,
            })
        status = "β" if hit else "β"
        print(f" [{position:2d}/{item_count}] {status} [{expected:>7s}β{str(predicted):<7s}] {elapsed_ms:6.1f}ms | {item['prompt'][:55]}")
    return results
def print_results(results, label):
    """Pretty-print per-tier and overall benchmark stats under *label*.

    Args:
        results: dict produced by run_benchmark (correct/total counts,
            latencies_ms list, by_tier buckets, misclassifications).
        label: heading printed above the report.

    Returns:
        Summary dict with "accuracy" (percent), "avg_ms", "p50_ms", "p95_ms"
        for later cross-model comparison.
    """
    print(f"\n{'='*65}")
    print(f" {label}")
    print(f"{'='*65}")
    for tier in ("simple", "medium", "complex"):
        if tier in results["by_tier"]:
            t = results["by_tier"][tier]
            pct = t["correct"] / t["total"] * 100 if t["total"] else 0
            avg_lat = sum(t["latencies"]) / len(t["latencies"])
            bar = "β" * int(pct / 5) + "β" * (20 - int(pct / 5))
            print(f" {tier:<10s} {t['correct']:>2d}/{t['total']:<2d} ({pct:5.1f}%) {bar} avg {avg_lat:6.1f}ms")
    total_pct = results["correct"] / results["total"] * 100
    # Sort once and reuse for both percentiles (previously sorted twice).
    ordered = sorted(results["latencies_ms"])
    avg_latency = sum(ordered) / len(ordered)
    p50 = ordered[len(ordered) // 2]
    p95 = ordered[int(len(ordered) * 0.95)]
    print(f"\n OVERALL: {results['correct']}/{results['total']} ({total_pct:.1f}%)")
    print(f" Latency: avg {avg_latency:.1f}ms | p50 {p50:.1f}ms | p95 {p95:.1f}ms")
    if results["misclassifications"]:
        print(f"\n MISCLASSIFICATIONS ({len(results['misclassifications'])}):")
        # Cap at 10 rows to keep the report readable.
        for m in results["misclassifications"][:10]:
            print(f" {m['expected']}β{m['predicted']}: {m['prompt']}")
    print(f"{'='*65}\n")
    return {"accuracy": total_pct, "avg_ms": avg_latency, "p50_ms": p50, "p95_ms": p95}
def _benchmark_one(gguf_path: Path, data: list[dict], label: str,
                   loaded_msg: str, run_label: str, title: str) -> dict:
    """Load one GGUF model, benchmark it, print its report, and free it.

    Args:
        gguf_path: path to the .gguf file to load.
        data: eval items forwarded to run_benchmark.
        label: short name used in "Loading ..." / "Running ..." messages.
        loaded_msg: message printed once the model is loaded.
        run_label: label forwarded to run_benchmark.
        title: heading forwarded to print_results.

    Returns:
        Summary stats dict from print_results.
    """
    print(f"Loading {label}: {gguf_path.name}")
    model = Llama(
        model_path=str(gguf_path),
        n_ctx=512,
        n_gpu_layers=-1,  # All layers on GPU
        verbose=False,
    )
    print(f"{loaded_msg}\n")
    print(f"Running {label} benchmark...")
    results = run_benchmark(model, data, run_label)
    stats = print_results(results, title)
    del model  # release VRAM before the next model is loaded
    return stats

def main():
    """Benchmark stock vs no-CoT fine-tuned GGUF models and print a comparison."""
    with open(EVAL_PATH) as f:
        data = json.load(f)
    print(f"Loaded {len(data)} eval prompts\n")
    stock_stats = _benchmark_one(
        STOCK_GGUF, data,
        label="stock GGUF",
        loaded_msg="Stock GGUF loaded",
        run_label="Stock GGUF",
        title="STOCK Arch-Router-1.5B (GGUF Q8_0)",
    )
    nocot_stats = _benchmark_one(
        NOCOT_GGUF, data,
        label="no-CoT GGUF",
        loaded_msg="No-CoT GGUF loaded",
        run_label="No-CoT GGUF",
        title="GRPO No-CoT Fine-Tuned (GGUF Q8_0)",
    )
    # Side-by-side comparison table; column widths shrink by the unit's
    # length so the printed layout matches exactly per row.
    print(f"{'='*65}")
    print(" GGUF Q8_0 COMPARISON: Stock vs No-CoT Fine-Tuned")
    print(f"{'='*65}")
    print(f" {'':20s} {'Stock':>12s} {'No-CoT FT':>12s} {'Delta':>10s}")
    rows = (
        ("Accuracy", "accuracy", "%"),
        ("Avg latency", "avg_ms", "ms"),
        ("P50 latency", "p50_ms", "ms"),
        ("P95 latency", "p95_ms", "ms"),
    )
    for row_name, key, unit in rows:
        s, c = stock_stats[key], nocot_stats[key]
        w, dw = 12 - len(unit), 10 - len(unit)
        print(f" {row_name:20s} {s:>{w}.1f}{unit} {c:>{w}.1f}{unit} {c - s:>+{dw}.1f}{unit}")
    print(f"{'='*65}")

if __name__ == "__main__":
    main()