AaryanK commited on
Commit
e892559
·
verified ·
1 Parent(s): e6aab9e

Upload bench_gguf.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. bench_gguf.py +188 -0
bench_gguf.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/usr/bin/env python3
"""
Benchmark GGUF models via llama-cpp-python.
Compares stock vs no-CoT fine-tuned Arch-Router-1.5B in Q8_0 GGUF format.
"""

import json
import time
from pathlib import Path
from llama_cpp import Llama

# All files are resolved relative to this script so it can be run from any CWD.
EVAL_PATH = Path(__file__).parent / "grpo_eval_data.json"
STOCK_GGUF = Path(__file__).parent / "stock_arch_router.Q8_0.gguf"
NOCOT_GGUF = Path(__file__).parent / "nocot_arch_router.Q8_0.gguf"

# The three routing tiers the model must choose between. These dicts are
# serialized verbatim (json.dumps) into every prompt by build_prompt().
ROUTE_POLICIES = [
    {"name": "simple", "description": "Simple factual questions, greetings, basic lookups, yes/no answers, FAQ-style queries, single-step tasks, status checks, straightforward requests"},
    {"name": "medium", "description": "Multi-step reasoning, summarization of moderate-length text, data extraction, moderate analysis, comparison tasks, troubleshooting, explanations requiring some depth"},
    {"name": "complex", "description": "Complex multi-document reasoning, deep analysis, legal or financial interpretation, creative writing, code generation, multi-constraint problem solving, liability assessment, comprehensive evaluation"},
]
def find_gguf(directory: Path) -> Path:
    """Return the first entry of *directory* whose suffix is ``.gguf``.

    Raises:
        FileNotFoundError: when the directory contains no ``.gguf`` file.
    """
    candidates = (entry for entry in directory.iterdir() if entry.suffix == ".gguf")
    found = next(candidates, None)
    if found is None:
        raise FileNotFoundError(f"No .gguf file found in {directory}")
    return found
def build_prompt(user_prompt: str) -> str:
    """Render the routing prompt: policies block, one-turn conversation, instruction."""
    policies = json.dumps(ROUTE_POLICIES)
    conversation = json.dumps([{"role": "user", "content": user_prompt}])
    # Assembled line-by-line; joined with newlines to match the template exactly.
    lines = [
        "You are a routing assistant. Given the route policies and user message, select the best matching route.",
        "",
        "<route_policies>",
        policies,
        "</route_policies>",
        "",
        "<conversation>",
        conversation,
        "</conversation>",
        "",
        'Select the best route for this user message. Respond with ONLY valid JSON: {"route": "route_name"}',
    ]
    return "\n".join(lines)
46
+
47
+ def extract_route(text: str) -> str | None:
48
+ try:
49
+ parsed = json.loads(text.strip())
50
+ route = parsed.get("route")
51
+ if route in ("simple", "medium", "complex"):
52
+ return route
53
+ except (json.JSONDecodeError, TypeError):
54
+ pass
55
+ for tier in ("simple", "medium", "complex"):
56
+ if tier in text.lower():
57
+ return tier
58
+ return None
59
def run_benchmark(model: Llama, data: list[dict], label: str) -> dict:
    """Run the routing eval against *model*, printing per-item progress.

    Args:
        model: loaded llama-cpp model (chat-completion capable).
        data: eval items, each with "prompt" and "expected_route" keys.
        label: human-readable run name (kept for callers; not printed here).

    Returns:
        dict with keys: correct, total, latencies_ms, by_tier,
        misclassifications.
    """
    results = {"correct": 0, "total": 0, "latencies_ms": [], "by_tier": {}, "misclassifications": []}

    # Warmup: one throwaway generation so first-call overhead does not
    # pollute the first measured latency.
    model.create_chat_completion(
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=10, temperature=0,
    )

    for i, item in enumerate(data):
        prompt = build_prompt(item["prompt"])

        # Time only the generation call, not prompt construction.
        start = time.perf_counter()
        output = model.create_chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=30,   # route JSON is tiny; cap generation hard
            temperature=0,   # deterministic decoding for reproducible runs
        )
        elapsed_ms = (time.perf_counter() - start) * 1000

        response = output["choices"][0]["message"]["content"]
        predicted = extract_route(response)
        expected = item["expected_route"]
        correct = predicted == expected

        results["total"] += 1
        results["latencies_ms"].append(elapsed_ms)
        if correct:
            results["correct"] += 1

        if expected not in results["by_tier"]:
            results["by_tier"][expected] = {"correct": 0, "total": 0, "latencies": []}
        results["by_tier"][expected]["total"] += 1
        results["by_tier"][expected]["latencies"].append(elapsed_ms)
        if correct:
            results["by_tier"][expected]["correct"] += 1
        else:
            # Keep a truncated prompt for the misclassification report.
            results["misclassifications"].append({
                "prompt": item["prompt"][:80],
                "expected": expected,
                "predicted": predicted,
            })

        # FIX: status/arrow glyphs were mis-encoded UTF-8 mojibake
        # ("βœ“"/"βœ—"/"β†’"); restored to the intended characters.
        status = "✓" if correct else "✗"
        print(f"  [{i+1:2d}/{len(data)}] {status} [{expected:>7s}→{str(predicted):<7s}] {elapsed_ms:6.1f}ms | {item['prompt'][:55]}")

    return results
def print_results(results, label):
    """Pretty-print per-tier and overall accuracy/latency; return summary stats.

    Args:
        results: dict produced by run_benchmark.
        label: heading printed above the table.

    Returns:
        {"accuracy": percent, "avg_ms": ..., "p50_ms": ..., "p95_ms": ...}
    """
    print(f"\n{'='*65}")
    print(f"  {label}")
    print(f"{'='*65}")

    for tier in ("simple", "medium", "complex"):
        if tier in results["by_tier"]:
            t = results["by_tier"][tier]
            pct = t["correct"] / t["total"] * 100 if t["total"] else 0
            avg_lat = sum(t["latencies"]) / len(t["latencies"])
            # FIX: bar glyphs were mojibake ("β–ˆ"/"β–‘"); restored to █/░.
            # 20-char bar, one block per 5 percentage points.
            bar = "█" * int(pct / 5) + "░" * (20 - int(pct / 5))
            print(f"  {tier:<10s} {t['correct']:>2d}/{t['total']:<2d} ({pct:5.1f}%) {bar} avg {avg_lat:6.1f}ms")

    total_pct = results["correct"] / results["total"] * 100
    avg_latency = sum(results["latencies_ms"]) / len(results["latencies_ms"])
    # FIX: the original sorted the latency list twice (once per percentile);
    # sort once and index into it. Same index arithmetic, same values.
    ordered = sorted(results["latencies_ms"])
    p50 = ordered[len(ordered) // 2]
    p95 = ordered[int(len(ordered) * 0.95)]

    print(f"\n  OVERALL: {results['correct']}/{results['total']} ({total_pct:.1f}%)")
    print(f"  Latency: avg {avg_latency:.1f}ms | p50 {p50:.1f}ms | p95 {p95:.1f}ms")

    if results["misclassifications"]:
        print(f"\n  MISCLASSIFICATIONS ({len(results['misclassifications'])}):")
        for m in results["misclassifications"][:10]:
            print(f"    {m['expected']}→{m['predicted']}: {m['prompt']}")
    print(f"{'='*65}\n")

    return {"accuracy": total_pct, "avg_ms": avg_latency, "p50_ms": p50, "p95_ms": p95}
def main():
    """Load eval data, benchmark stock vs no-CoT GGUF, print a comparison table."""
    with open(EVAL_PATH) as f:
        data = json.load(f)
    print(f"Loaded {len(data)} eval prompts\n")

    # ── Stock GGUF ──
    print(f"Loading stock GGUF: {STOCK_GGUF.name}")
    stock_model = Llama(
        model_path=str(STOCK_GGUF),
        n_ctx=512,        # routing prompts are short; small context saves memory
        n_gpu_layers=-1,  # All layers on GPU
        verbose=False,
    )
    print("Stock GGUF loaded\n")

    print("Running stock GGUF benchmark...")
    stock_results = run_benchmark(stock_model, data, "Stock GGUF")
    stock_stats = print_results(stock_results, "STOCK Arch-Router-1.5B (GGUF Q8_0)")

    # Free the first model before loading the second so both sets of Q8_0
    # weights are never resident at the same time.
    del stock_model

    # ── No-CoT GGUF ──
    print(f"Loading no-CoT GGUF: {NOCOT_GGUF.name}")
    nocot_model = Llama(
        model_path=str(NOCOT_GGUF),
        n_ctx=512,
        n_gpu_layers=-1,
        verbose=False,
    )
    print("No-CoT GGUF loaded\n")

    print("Running no-CoT GGUF benchmark...")
    nocot_results = run_benchmark(nocot_model, data, "No-CoT GGUF")
    nocot_stats = print_results(nocot_results, "GRPO No-CoT Fine-Tuned (GGUF Q8_0)")

    # ── Comparison ──
    # Side-by-side table: accuracy plus avg/p50/p95 latency, with deltas
    # computed as fine-tuned minus stock.
    print(f"{'='*65}")
    print(f"  GGUF Q8_0 COMPARISON: Stock vs No-CoT Fine-Tuned")
    print(f"{'='*65}")
    print(f"  {'':20s} {'Stock':>12s} {'No-CoT FT':>12s} {'Delta':>10s}")
    print(f"  {'Accuracy':20s} {stock_stats['accuracy']:>11.1f}% {nocot_stats['accuracy']:>11.1f}% {nocot_stats['accuracy']-stock_stats['accuracy']:>+9.1f}%")
    print(f"  {'Avg latency':20s} {stock_stats['avg_ms']:>10.1f}ms {nocot_stats['avg_ms']:>10.1f}ms {nocot_stats['avg_ms']-stock_stats['avg_ms']:>+8.1f}ms")
    print(f"  {'P50 latency':20s} {stock_stats['p50_ms']:>10.1f}ms {nocot_stats['p50_ms']:>10.1f}ms {nocot_stats['p50_ms']-stock_stats['p50_ms']:>+8.1f}ms")
    print(f"  {'P95 latency':20s} {stock_stats['p95_ms']:>10.1f}ms {nocot_stats['p95_ms']:>10.1f}ms {nocot_stats['p95_ms']-stock_stats['p95_ms']:>+8.1f}ms")
    print(f"{'='*65}")


if __name__ == "__main__":
    main()