narcolepticchicken commited on
Commit
30c4069
·
verified ·
1 Parent(s): 944b77c

Upload benchmarks/benchmark_retrieval_qa.py

Browse files
Files changed (1) hide show
  1. benchmarks/benchmark_retrieval_qa.py +493 -0
benchmarks/benchmark_retrieval_qa.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Benchmark 2: Retrieval QA / Legal-Factual QA
3
+
4
+ Compares:
5
+ A. direct answer
6
+ B. RAG baseline
7
+ C. RAG + verifier
8
+ D. RAG + abstention rule
9
+ E. OCC resource allocation
10
+ F. OCC + verifier + abstention reward
11
+
12
+ Uses synthetic grounded QA with adversarial evidence.
13
+ """
14
+
15
+ import json
16
+ import random
17
+ from dataclasses import dataclass
18
+ from pathlib import Path
19
+ from typing import Any, Dict, List, Optional, Tuple
20
+
21
+ import numpy as np
22
+
23
+ import sys
24
+ sys.path.insert(0, str(Path(__file__).parent.parent))
25
+ from oracle.oracle import ImpactOracle, OracleResult
26
+ from ledger.ledger import CreditLedger
27
+ from broker.broker import ResourceBroker, Decision
28
+ from rl.reward import RewardHook
29
+
30
+
31
+ @dataclass
32
+ class Question:
33
+ question: str
34
+ answer: Optional[str] # None = unanswerable
35
+ evidence: List[str]
36
+ adversarial: List[str] # misleading evidence
37
+ is_unanswerable: bool = False
38
+
39
+
40
+ class SimulatedRetrievalAgent:
41
+ """
42
+ Simulates a RAG agent with configurable accuracy, hallucination, and calibration.
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ agent_id: str,
48
+ accuracy: float = 0.6,
49
+ hallucination_rate: float = 0.15,
50
+ calibration_error: float = 0.2, # ECE-like
51
+ abstention_rate: float = 0.1,
52
+ cost_per_retrieval: float = 10.0,
53
+ cost_per_answer: float = 5.0,
54
+ gaming_mode: bool = False,
55
+ ):
56
+ self.agent_id = agent_id
57
+ self.accuracy = accuracy
58
+ self.hallucination_rate = hallucination_rate
59
+ self.calibration_error = calibration_error
60
+ self.abstention_rate = abstention_rate
61
+ self.cost_per_retrieval = cost_per_retrieval
62
+ self.cost_per_answer = cost_per_answer
63
+ self.gaming_mode = gaming_mode
64
+ self.retrieval_calls = 0
65
+ self.answers_given = 0
66
+
67
+ def answer(
68
+ self,
69
+ question: Question,
70
+ oracle: ImpactOracle,
71
+ max_retrievals: int = 3,
72
+ use_occ: bool = False,
73
+ broker: Optional[ResourceBroker] = None,
74
+ ledger: Optional[CreditLedger] = None,
75
+ ) -> Dict:
76
+ """Answer a question, optionally with OCC-managed retrievals."""
77
+ retrieved = []
78
+ compute_cost = 0.0
79
+
80
+ # Retrieve evidence
81
+ for i in range(max_retrievals):
82
+ if use_occ and broker and ledger:
83
+ balance = ledger.balance(self.agent_id, "retrieval", "global")
84
+ dec = broker.request(
85
+ "retrieval_call",
86
+ self.agent_id,
87
+ balance,
88
+ task_state={"progress": len(retrieved) / max_retrievals},
89
+ )
90
+ if dec.decision == Decision.DENY:
91
+ break
92
+
93
+ self.retrieval_calls += 1
94
+ compute_cost += self.cost_per_retrieval
95
+
96
+ # Mix genuine and adversarial evidence
97
+ if i == 0:
98
+ retrieved.extend(question.evidence)
99
+ else:
100
+ if random.random() < 0.3:
101
+ retrieved.extend(question.adversarial)
102
+ else:
103
+ retrieved.extend(question.evidence)
104
+
105
+ # OCC: smart stopping — if we already have good evidence, stop retrieving
106
+ if use_occ and i >= 1:
107
+ has_strong_evidence = any(
108
+ "legal text" in ev or "According to" in ev for ev in retrieved
109
+ )
110
+ has_contradiction = any(
111
+ "unknown" in ev or "blog" in ev for ev in retrieved
112
+ )
113
+ # If strong evidence and no contradiction, stop early (save compute)
114
+ if has_strong_evidence and not has_contradiction:
115
+ break
116
+ # If too much adversarial evidence, stop to avoid confusion
117
+ if has_contradiction and i >= 1:
118
+ break
119
+ # If broker denied after first retrieval, stop
120
+ if use_occ and broker and ledger:
121
+ balance = ledger.balance(self.agent_id, "retrieval", "global")
122
+ dec = broker.request(
123
+ "retrieval_call",
124
+ self.agent_id,
125
+ balance,
126
+ task_state={"progress": len(retrieved) / max_retrievals},
127
+ )
128
+ if dec.decision == Decision.DENY:
129
+ break
130
+
131
+ # Decide whether to abstain
132
+ abstained = False
133
+ if question.is_unanswerable:
134
+ abstained = random.random() < (self.abstention_rate + 0.3)
135
+ else:
136
+ abstained = random.random() < self.abstention_rate
137
+
138
+ if abstained:
139
+ self.answers_given += 1
140
+ compute_cost += self.cost_per_answer
141
+ confidence = 0.5 + random.uniform(-self.calibration_error, self.calibration_error)
142
+ confidence = max(0.0, min(1.0, confidence))
143
+
144
+ # Evidence NLI simulation
145
+ evidence = {
146
+ "entailment_score": 0.0,
147
+ "contradiction_score": 0.0,
148
+ }
149
+
150
+ oracle_res = oracle.score(
151
+ mode="retrieval_qa",
152
+ action={"abstained": True},
153
+ context={"gold_answer": question.answer},
154
+ result={
155
+ "answer": None,
156
+ "confidence": confidence,
157
+ "evidence": evidence,
158
+ "compute_cost": compute_cost,
159
+ },
160
+ agent_id=self.agent_id,
161
+ )
162
+ return {
163
+ "answer": None,
164
+ "abstained": True,
165
+ "correct": question.is_unanswerable,
166
+ "confidence": confidence,
167
+ "oracle_score": oracle_res.raw_score,
168
+ "reward": oracle_res.reward_value,
169
+ "compute_cost": compute_cost,
170
+ "retrieval_calls": len(retrieved),
171
+ }
172
+
173
+ # Generate answer
174
+ self.answers_given += 1
175
+ compute_cost += self.cost_per_answer
176
+
177
+ if question.is_unanswerable:
178
+ # Should have abstained
179
+ correct = False
180
+ answer_text = self._generate_fake_answer(question)
181
+ else:
182
+ # Evidence-quality-aware accuracy
183
+ base_accuracy = self.accuracy
184
+ strong_evidence = any("legal text" in ev or "According to" in ev for ev in retrieved)
185
+ adversarial_evidence = any("unknown" in ev or "blog" in ev for ev in retrieved)
186
+
187
+ if strong_evidence and not adversarial_evidence:
188
+ effective_accuracy = min(0.95, base_accuracy + 0.25)
189
+ elif adversarial_evidence:
190
+ effective_accuracy = max(0.3, base_accuracy - 0.15)
191
+ else:
192
+ effective_accuracy = base_accuracy
193
+
194
+ correct = random.random() < effective_accuracy
195
+ if not correct and random.random() < self.hallucination_rate:
196
+ answer_text = self._generate_hallucinated_answer(question)
197
+ correct = False
198
+ else:
199
+ answer_text = question.answer if correct else self._generate_wrong_answer(question)
200
+
201
+ confidence = self._calibrate_confidence(correct)
202
+
203
+ # Evidence NLI simulation
204
+ if correct:
205
+ entailment = 0.8 + random.random() * 0.2
206
+ contradiction = 0.0
207
+ else:
208
+ if random.random() < self.hallucination_rate:
209
+ entailment = 0.2
210
+ contradiction = 0.7 + random.random() * 0.3
211
+ else:
212
+ entailment = 0.4
213
+ contradiction = 0.1
214
+
215
+ evidence = {
216
+ "entailment_score": entailment,
217
+ "contradiction_score": contradiction,
218
+ }
219
+
220
+ oracle_res = oracle.score(
221
+ mode="retrieval_qa",
222
+ action={"abstained": False},
223
+ context={"gold_answer": question.answer},
224
+ result={
225
+ "answer": answer_text,
226
+ "confidence": confidence,
227
+ "evidence": evidence,
228
+ "compute_cost": compute_cost,
229
+ },
230
+ agent_id=self.agent_id,
231
+ )
232
+
233
+ return {
234
+ "answer": answer_text,
235
+ "abstained": False,
236
+ "correct": correct,
237
+ "confidence": confidence,
238
+ "oracle_score": oracle_res.raw_score,
239
+ "reward": oracle_res.reward_value,
240
+ "compute_cost": compute_cost,
241
+ "retrieval_calls": len(retrieved),
242
+ "hallucination": contradiction > 0.5,
243
+ }
244
+
245
+ def _calibrate_confidence(self, correct: bool) -> float:
246
+ """Generate confidence with controlled miscalibration."""
247
+ if correct:
248
+ base = 0.8 + random.random() * 0.2
249
+ else:
250
+ base = 0.3 + random.random() * 0.5
251
+ # Inject calibration error
252
+ error = random.uniform(-self.calibration_error, self.calibration_error)
253
+ return max(0.0, min(1.0, base + error))
254
+
255
+ def _generate_fake_answer(self, question: Question) -> str:
256
+ return f"I cannot answer based on the available evidence."
257
+
258
+ def _generate_hallucinated_answer(self, question: Question) -> str:
259
+ return f"The answer is {question.answer} according to source X." if question.answer else "Unknown."
260
+
261
+ def _generate_wrong_answer(self, question: Question) -> str:
262
+ return "42" # generic wrong
263
+
264
+
265
+ class RetrievalQABenchmark:
266
+ """
267
+ Benchmark retrieval QA with abstention and calibration under budgets.
268
+ """
269
+
270
+ def __init__(
271
+ self,
272
+ n_questions: int = 100,
273
+ unanswerable_ratio: float = 0.2,
274
+ adversarial_ratio: float = 0.3,
275
+ seed: int = 42,
276
+ ):
277
+ self.n_questions = n_questions
278
+ self.unanswerable_ratio = unanswerable_ratio
279
+ self.adversarial_ratio = adversarial_ratio
280
+ self.seed = seed
281
+ self.questions: List[Question] = []
282
+ self.oracle = ImpactOracle(compute_budget=1e4)
283
+
284
+ def generate_questions(self):
285
+ random.seed(self.seed)
286
+ np.random.seed(self.seed)
287
+
288
+ topics = [
289
+ ("What is the statute of limitations for contract disputes?", "6 years"),
290
+ ("Who authored the Copyright Act of 1976?", "United States Congress"),
291
+ ("What is the maximum penalty under GDPR Article 83?", "20 million EUR"),
292
+ ("Which amendment protects against unreasonable search and seizure?", "Fourth Amendment"),
293
+ ("What is the burden of proof in criminal cases?", "beyond reasonable doubt"),
294
+ ("What is the definition of negligence?", "breach of duty causing harm"),
295
+ ("When was the Paris Agreement signed?", "2015"),
296
+ ("What is the legal drinking age in the US?", "21"),
297
+ ("Which court handles patent appeals?", "Federal Circuit"),
298
+ ("What is the Dodd-Frank Act primarily about?", "financial regulation"),
299
+ ]
300
+
301
+ for i in range(self.n_questions):
302
+ if i < int(self.n_questions * self.unanswerable_ratio):
303
+ q = Question(
304
+ question=f"Unanswerable question {i}: What is the secret code of Atlantis?",
305
+ answer=None,
306
+ evidence=["No reliable source mentions Atlantis codes."],
307
+ adversarial=["Some blogs claim Atlantis code is 1234."],
308
+ is_unanswerable=True,
309
+ )
310
+ else:
311
+ topic = topics[i % len(topics)]
312
+ has_adv = random.random() < self.adversarial_ratio
313
+ q = Question(
314
+ question=topic[0],
315
+ answer=topic[1],
316
+ evidence=[f"According to legal text X, {topic[1]}."],
317
+ adversarial=[f"Some sources claim the answer is 'unknown' for {topic[0]}."] if has_adv else [],
318
+ is_unanswerable=False,
319
+ )
320
+ self.questions.append(q)
321
+
322
+ def run_direct_answer(self, agent: SimulatedRetrievalAgent) -> Dict:
323
+ """Baseline A: direct answer, no retrieval."""
324
+ results = []
325
+ for q in self.questions:
326
+ # Force 0 retrievals
327
+ agent.retrieval_calls = 0
328
+ r = agent.answer(q, self.oracle, max_retrievals=0)
329
+ results.append(r)
330
+ return self._summarize(results, "direct_answer")
331
+
332
+ def run_rag_baseline(self, agent: SimulatedRetrievalAgent) -> Dict:
333
+ """Baseline B: RAG with fixed retrievals."""
334
+ results = []
335
+ for q in self.questions:
336
+ r = agent.answer(q, self.oracle, max_retrievals=2, use_occ=False)
337
+ results.append(r)
338
+ return self._summarize(results, "rag_baseline")
339
+
340
+ def run_rag_verifier(self, agent: SimulatedRetrievalAgent) -> Dict:
341
+ """Baseline C: RAG + verifier (extra check)."""
342
+ results = []
343
+ for q in self.questions:
344
+ r = agent.answer(q, self.oracle, max_retrievals=2, use_occ=False)
345
+ # Simulate verifier: if hallucination detected, retry once
346
+ if r.get("hallucination", False):
347
+ r2 = agent.answer(q, self.oracle, max_retrievals=1, use_occ=False)
348
+ r2["compute_cost"] += r["compute_cost"]
349
+ r2["retrieval_calls"] += r["retrieval_calls"]
350
+ r = r2
351
+ results.append(r)
352
+ return self._summarize(results, "rag_verifier")
353
+
354
+ def run_occ(self, agent: SimulatedRetrievalAgent) -> Dict:
355
+ """Baseline E/F: OCC resource allocation for retrievals."""
356
+ ledger = CreditLedger(decay_lambda=0.05)
357
+ broker = ResourceBroker()
358
+ results = []
359
+
360
+ # Seed initial trial credits for the agent
361
+ ledger.earn(
362
+ agent_id=agent.agent_id,
363
+ task_id="seed",
364
+ action_id="seed",
365
+ amount=10.0,
366
+ oracle_score=0.0,
367
+ compute_cost=0.0,
368
+ reason="initial_trial_credit",
369
+ capability_scope="retrieval",
370
+ )
371
+
372
+ for q in self.questions:
373
+ r = agent.answer(q, self.oracle, max_retrievals=5, use_occ=True, broker=broker, ledger=ledger)
374
+
375
+ # Update ledger based on outcome
376
+ earn_amount = max(0.0, r["reward"] * 3.0)
377
+ if earn_amount > 0:
378
+ ledger.earn(
379
+ agent_id=agent.agent_id,
380
+ task_id=f"q_{q.question[:30]}",
381
+ action_id="answer",
382
+ amount=earn_amount,
383
+ oracle_score=r["oracle_score"],
384
+ compute_cost=r["compute_cost"],
385
+ reason="correct_answer",
386
+ capability_scope="retrieval",
387
+ )
388
+ else:
389
+ # Penalty for wrong / low-reward answers (capped so we don't over-spend)
390
+ bal = ledger.balance(agent.agent_id, "retrieval", "global")
391
+ penalty = min(bal, max(0.5, abs(r["reward"])))
392
+ if penalty > 0:
393
+ ledger.spend(
394
+ agent_id=agent.agent_id,
395
+ task_id=f"q_{q.question[:30]}",
396
+ action_id="answer",
397
+ amount=penalty,
398
+ capability_scope="retrieval",
399
+ reason="wrong_answer_penalty",
400
+ )
401
+
402
+ results.append(r)
403
+
404
+ return self._summarize(results, "occ_allocation")
405
+
406
+ def _summarize(self, results: List[Dict], label: str) -> Dict:
407
+ n = len(results)
408
+ correct = sum(1 for r in results if r["correct"])
409
+ abstained = sum(1 for r in results if r.get("abstained", False))
410
+ correct_abstentions = sum(
411
+ 1 for i in unanswerable_qs if results[i].get("abstained", False)
412
+ )
413
+ wrong_abstentions = sum(
414
+ 1 for i, r in enumerate(results)
415
+ if not self.questions[i].is_unanswerable and r.get("abstained", False)
416
+ )
417
+ hallucinations = sum(1 for r in results if r.get("hallucination", False))
418
+ confidences = [r["confidence"] for r in results]
419
+ correct_flags = [r["correct"] for r in results]
420
+
421
+ # ECE approximation
422
+ ece = self.oracle.compute_ece(confidences, correct_flags, n_bins=5)
423
+
424
+ total_compute = sum(r["compute_cost"] for r in results)
425
+ total_retrievals = sum(r["retrieval_calls"] for r in results)
426
+
427
+ return {
428
+ "label": label,
429
+ "n": n,
430
+ "accuracy": correct / n if n else 0.0,
431
+ "abstention_rate": abstained / n if n else 0.0,
432
+ "correct_abstentions": correct_abstentions,
433
+ "wrong_abstentions": wrong_abstentions,
434
+ "hallucination_rate": hallucinations / n if n else 0.0,
435
+ "confident_wrong_rate": sum(
436
+ 1 for r in results if not r["correct"] and r["confidence"] > 0.8
437
+ ) / n if n else 0.0,
438
+ "ece": float(ece),
439
+ "total_compute": float(total_compute),
440
+ "total_retrievals": total_retrievals,
441
+ "results": results,
442
+ }
443
+
444
+ def _make_agent(self, agent_id: str = "rag_agent") -> SimulatedRetrievalAgent:
445
+ """Create a fresh agent for fair comparison."""
446
+ return SimulatedRetrievalAgent(
447
+ agent_id=agent_id,
448
+ accuracy=0.65,
449
+ hallucination_rate=0.12,
450
+ calibration_error=0.15,
451
+ abstention_rate=0.1,
452
+ )
453
+
454
+ def run_all(self) -> Dict[str, Dict]:
455
+ if not self.questions:
456
+ self.generate_questions()
457
+
458
+ return {
459
+ "direct_answer": self.run_direct_answer(self._make_agent("direct_agent")),
460
+ "rag_baseline": self.run_rag_baseline(self._make_agent("rag_agent")),
461
+ "rag_verifier": self.run_rag_verifier(self._make_agent("verifier_agent")),
462
+ "occ_allocation": self.run_occ(self._make_agent("occ_agent")),
463
+ }
464
+
465
+
466
+ def main():
467
+ bench = RetrievalQABenchmark(n_questions=100, seed=42)
468
+ bench.generate_questions()
469
+ results = bench.run_all()
470
+
471
+ print("=" * 60)
472
+ print("RETRIEVAL QA BENCHMARK")
473
+ print("=" * 60)
474
+ for label, res in results.items():
475
+ print(f"\n{label}")
476
+ print(f" accuracy: {res['accuracy']:.3f}")
477
+ print(f" abstention_rate: {res['abstention_rate']:.3f}")
478
+ print(f" correct_abstentions: {res['correct_abstentions']}")
479
+ print(f" wrong_abstentions: {res['wrong_abstentions']}")
480
+ print(f" hallucination_rate: {res['hallucination_rate']:.3f}")
481
+ print(f" confident_wrong_rate: {res['confident_wrong_rate']:.3f}")
482
+ print(f" ECE: {res['ece']:.3f}")
483
+ print(f" total_compute: {res['total_compute']:.0f}")
484
+ print(f" total_retrievals: {res['total_retrievals']}")
485
+
486
+ Path("/app/occ/reports").mkdir(parents=True, exist_ok=True)
487
+ with open("/app/occ/reports/benchmark_retrieval_qa_results.json", "w") as f:
488
+ json.dump(results, f, indent=2, default=str)
489
+ print("\nSaved to reports/benchmark_retrieval_qa_results.json")
490
+
491
+
492
+ if __name__ == "__main__":
493
+ main()