| """ |
| Benchmark 2: Retrieval QA / Legal-Factual QA |
| |
| Compares: |
| A. direct answer |
| B. RAG baseline |
| C. RAG + verifier |
| D. RAG + abstention rule |
| E. OCC resource allocation |
| F. OCC + verifier + abstention reward |
| |
| Uses synthetic grounded QA with adversarial evidence. |
| """ |
|
|
| import json |
| import random |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import numpy as np |
|
|
| import sys |
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
| from oracle.oracle import ImpactOracle, OracleResult |
| from ledger.ledger import CreditLedger |
| from broker.broker import ResourceBroker, Decision |
| from rl.reward import RewardHook |
|
|
|
|
@dataclass
class Question:
    """One synthetic QA item together with its supporting and adversarial passages."""

    # Question text presented to the agent.
    question: str
    # Gold answer; None for the unanswerable items built by the benchmark.
    answer: Optional[str]
    # Passages returned by a regular retrieval call.
    evidence: List[str]
    # Misleading passages that later retrieval calls may return instead
    # (may be empty).
    adversarial: List[str]
    # True when the question has no gold answer and abstaining is the
    # correct behaviour.
    is_unanswerable: bool = False
|
|
|
|
class SimulatedRetrievalAgent:
    """Simulated RAG agent with configurable accuracy, hallucination, and calibration.

    No real model is involved: retrieval outcomes, abstention, correctness and
    confidence are all sampled from the module-level ``random`` generator, so
    runs are reproducible only when the caller seeds it.

    Attributes:
        retrieval_calls: Running total of retrieval calls over all questions.
        answers_given: Running total of responses; abstentions are counted too.
    """

    # Substrings that mark a retrieved passage as authoritative / adversarial
    # in the synthetic corpus.
    _STRONG_MARKERS = ("legal text", "According to")
    _ADVERSARIAL_MARKERS = ("unknown", "blog")

    def __init__(
        self,
        agent_id: str,
        accuracy: float = 0.6,
        hallucination_rate: float = 0.15,
        calibration_error: float = 0.2,
        abstention_rate: float = 0.1,
        cost_per_retrieval: float = 10.0,
        cost_per_answer: float = 5.0,
        gaming_mode: bool = False,
    ):
        """Store the simulation parameters and reset the usage counters.

        Args:
            agent_id: Identifier passed to the oracle, broker and ledger.
            accuracy: Base probability of a correct answer before evidence
                adjustments.
            hallucination_rate: Probability used for the hallucination draws
                on wrong answers.
            calibration_error: Half-width of the uniform noise added to the
                reported confidence.
            abstention_rate: Base probability of abstaining.
            cost_per_retrieval: Compute cost charged per retrieval call.
            cost_per_answer: Compute cost charged per response.
            gaming_mode: Stored on the instance; not read by this class.
        """
        self.agent_id = agent_id
        self.accuracy = accuracy
        self.hallucination_rate = hallucination_rate
        self.calibration_error = calibration_error
        self.abstention_rate = abstention_rate
        self.cost_per_retrieval = cost_per_retrieval
        self.cost_per_answer = cost_per_answer
        self.gaming_mode = gaming_mode
        self.retrieval_calls = 0
        self.answers_given = 0

    def answer(
        self,
        question: Question,
        oracle: ImpactOracle,
        max_retrievals: int = 3,
        use_occ: bool = False,
        broker: Optional[ResourceBroker] = None,
        ledger: Optional[CreditLedger] = None,
    ) -> Dict:
        """Answer a question, optionally with OCC-managed retrievals.

        Args:
            question: Item to answer.
            oracle: Scorer; ``oracle.score(mode="retrieval_qa", ...)`` is
                called once and its ``raw_score`` / ``reward_value`` are
                copied into the result.
            max_retrievals: Upper bound on retrieval calls for this question.
            use_occ: Enable broker gating and evidence-based early stopping.
            broker: Resource broker consulted before retrievals when
                ``use_occ`` is set.
            ledger: Credit ledger whose balance is handed to the broker.

        Returns:
            Dict with ``answer``, ``abstained``, ``correct``, ``confidence``,
            ``oracle_score``, ``reward``, ``compute_cost``,
            ``retrieval_calls`` (calls actually made for this question) and
            ``hallucination``.
        """
        retrieved: List[str] = []
        compute_cost = 0.0
        # Calls made for *this* question. The passage count is not a valid
        # proxy: a call may return an empty adversarial list.
        calls_made = 0

        for i in range(max_retrievals):
            if use_occ and broker and ledger and self._retrieval_denied(
                broker, ledger, len(retrieved) / max_retrievals
            ):
                break

            self.retrieval_calls += 1
            calls_made += 1
            compute_cost += self.cost_per_retrieval

            # The first call always returns genuine evidence; later calls
            # return adversarial passages 30% of the time.
            if i == 0 or random.random() >= 0.3:
                retrieved.extend(question.evidence)
            else:
                retrieved.extend(question.adversarial)

            if use_occ and i >= 1:
                strong, contradicted = self._evidence_signals(retrieved)
                # Stop once the evidence is decisive either way.
                if strong or contradicted:
                    break
                # NOTE(review): an undecided step consults the broker here and
                # again at the top of the next iteration - confirm the double
                # request is intended.
                if use_occ and broker and ledger and self._retrieval_denied(
                    broker, ledger, len(retrieved) / max_retrievals
                ):
                    break

        # Unanswerable questions get a 0.3 boost to the abstention probability.
        abstain_probability = self.abstention_rate
        if question.is_unanswerable:
            abstain_probability += 0.3
        abstained = random.random() < abstain_probability

        # Abstentions are counted and charged like answers.
        self.answers_given += 1
        compute_cost += self.cost_per_answer

        if abstained:
            noise = random.uniform(-self.calibration_error, self.calibration_error)
            confidence = max(0.0, min(1.0, 0.5 + noise))
            evidence = {
                "entailment_score": 0.0,
                "contradiction_score": 0.0,
            }
            oracle_res = oracle.score(
                mode="retrieval_qa",
                action={"abstained": True},
                context={"gold_answer": question.answer},
                result={
                    "answer": None,
                    "confidence": confidence,
                    "evidence": evidence,
                    "compute_cost": compute_cost,
                },
                agent_id=self.agent_id,
            )
            return {
                "answer": None,
                "abstained": True,
                # Abstaining is right exactly when no answer exists.
                "correct": question.is_unanswerable,
                "confidence": confidence,
                "oracle_score": oracle_res.raw_score,
                "reward": oracle_res.reward_value,
                "compute_cost": compute_cost,
                "retrieval_calls": calls_made,
                # Same schema as the answering path; nothing was asserted.
                "hallucination": False,
            }

        if question.is_unanswerable:
            # Any committed answer to an unanswerable question is wrong.
            correct = False
            answer_text = self._generate_fake_answer(question)
        else:
            strong, adversarial = self._evidence_signals(retrieved)
            if strong and not adversarial:
                effective_accuracy = min(0.95, self.accuracy + 0.25)
            elif adversarial:
                effective_accuracy = max(0.3, self.accuracy - 0.15)
            else:
                effective_accuracy = self.accuracy

            correct = random.random() < effective_accuracy
            if not correct and random.random() < self.hallucination_rate:
                answer_text = self._generate_hallucinated_answer(question)
            elif correct:
                answer_text = question.answer
            else:
                answer_text = self._generate_wrong_answer(question)

        confidence = self._calibrate_confidence(correct)

        # NOTE(review): the evidence scores (and so the "hallucination" flag)
        # come from a separate draw, independent of whether the hallucinated
        # answer text was chosen above - confirm this decoupling is intended.
        if correct:
            entailment = 0.8 + random.random() * 0.2
            contradiction = 0.0
        elif random.random() < self.hallucination_rate:
            entailment = 0.2
            contradiction = 0.7 + random.random() * 0.3
        else:
            entailment = 0.4
            contradiction = 0.1

        evidence = {
            "entailment_score": entailment,
            "contradiction_score": contradiction,
        }
        oracle_res = oracle.score(
            mode="retrieval_qa",
            action={"abstained": False},
            context={"gold_answer": question.answer},
            result={
                "answer": answer_text,
                "confidence": confidence,
                "evidence": evidence,
                "compute_cost": compute_cost,
            },
            agent_id=self.agent_id,
        )

        return {
            "answer": answer_text,
            "abstained": False,
            "correct": correct,
            "confidence": confidence,
            "oracle_score": oracle_res.raw_score,
            "reward": oracle_res.reward_value,
            "compute_cost": compute_cost,
            "retrieval_calls": calls_made,
            "hallucination": contradiction > 0.5,
        }

    def _retrieval_denied(
        self,
        broker: ResourceBroker,
        ledger: CreditLedger,
        progress: float,
    ) -> bool:
        """Ask the broker for one more retrieval; return True if it is denied."""
        balance = ledger.balance(self.agent_id, "retrieval", "global")
        dec = broker.request(
            "retrieval_call",
            self.agent_id,
            balance,
            task_state={"progress": progress},
        )
        return dec.decision == Decision.DENY

    @classmethod
    def _evidence_signals(cls, retrieved: List[str]) -> Tuple[bool, bool]:
        """Return (has_strong_evidence, has_adversarial_evidence) for the passages."""
        strong = any(
            marker in ev for ev in retrieved for marker in cls._STRONG_MARKERS
        )
        adversarial = any(
            marker in ev for ev in retrieved for marker in cls._ADVERSARIAL_MARKERS
        )
        return strong, adversarial

    def _calibrate_confidence(self, correct: bool) -> float:
        """Generate confidence with controlled miscalibration.

        Correct answers start in [0.8, 1.0), wrong ones in [0.3, 0.8); uniform
        noise of +/- ``calibration_error`` is added and the result is clamped
        to [0, 1].
        """
        if correct:
            base = 0.8 + random.random() * 0.2
        else:
            base = 0.3 + random.random() * 0.5

        error = random.uniform(-self.calibration_error, self.calibration_error)
        return max(0.0, min(1.0, base + error))

    def _generate_fake_answer(self, question: Question) -> str:
        """Return the canned text used when answering an unanswerable question."""
        return "I cannot answer based on the available evidence."

    def _generate_hallucinated_answer(self, question: Question) -> str:
        """Return the hallucination text, or "Unknown." when there is no gold answer."""
        if question.answer:
            return f"The answer is {question.answer} according to source X."
        return "Unknown."

    def _generate_wrong_answer(self, question: Question) -> str:
        """Return a fixed, obviously wrong answer."""
        return "42"
|
|
|
|
class RetrievalQABenchmark:
    """Benchmark retrieval QA with abstention and calibration under budgets.

    Builds a synthetic question set, runs it through several agent
    configurations, and summarises accuracy, abstention, hallucination,
    calibration and cost for each.
    """

    def __init__(
        self,
        n_questions: int = 100,
        unanswerable_ratio: float = 0.2,
        adversarial_ratio: float = 0.3,
        seed: int = 42,
    ):
        """Store the dataset parameters and create the shared oracle.

        Args:
            n_questions: Total number of questions to generate.
            unanswerable_ratio: Fraction of questions that have no gold answer.
            adversarial_ratio: Probability that an answerable question also
                carries an adversarial passage.
            seed: Seed applied to ``random`` and ``numpy.random`` when the
                questions are generated.
        """
        self.n_questions = n_questions
        self.unanswerable_ratio = unanswerable_ratio
        self.adversarial_ratio = adversarial_ratio
        self.seed = seed
        self.questions: List[Question] = []
        self.oracle = ImpactOracle(compute_budget=1e4)

    def generate_questions(self) -> None:
        """(Re)build ``self.questions`` deterministically from ``self.seed``.

        Reseeds the global ``random`` and ``numpy.random`` generators. Any
        previously generated questions are replaced, so calling this twice
        does not duplicate the dataset.
        """
        random.seed(self.seed)
        np.random.seed(self.seed)

        topics = [
            ("What is the statute of limitations for contract disputes?", "6 years"),
            ("Who authored the Copyright Act of 1976?", "United States Congress"),
            ("What is the maximum penalty under GDPR Article 83?", "20 million EUR"),
            ("Which amendment protects against unreasonable search and seizure?", "Fourth Amendment"),
            ("What is the burden of proof in criminal cases?", "beyond reasonable doubt"),
            ("What is the definition of negligence?", "breach of duty causing harm"),
            ("When was the Paris Agreement signed?", "2015"),
            ("What is the legal drinking age in the US?", "21"),
            ("Which court handles patent appeals?", "Federal Circuit"),
            ("What is the Dodd-Frank Act primarily about?", "financial regulation"),
        ]

        # The unanswerable items occupy the first slots of the list.
        n_unanswerable = int(self.n_questions * self.unanswerable_ratio)
        generated: List[Question] = []
        for i in range(self.n_questions):
            if i < n_unanswerable:
                q = Question(
                    question=f"Unanswerable question {i}: What is the secret code of Atlantis?",
                    answer=None,
                    evidence=["No reliable source mentions Atlantis codes."],
                    adversarial=["Some blogs claim Atlantis code is 1234."],
                    is_unanswerable=True,
                )
            else:
                text, gold = topics[i % len(topics)]
                has_adv = random.random() < self.adversarial_ratio
                q = Question(
                    question=text,
                    answer=gold,
                    evidence=[f"According to legal text X, {gold}."],
                    adversarial=(
                        [f"Some sources claim the answer is 'unknown' for {text}."]
                        if has_adv
                        else []
                    ),
                    is_unanswerable=False,
                )
            generated.append(q)
        self.questions = generated

    def run_direct_answer(self, agent: SimulatedRetrievalAgent) -> Dict:
        """Baseline A: direct answer, no retrieval."""
        results = []
        for q in self.questions:
            agent.retrieval_calls = 0
            results.append(agent.answer(q, self.oracle, max_retrievals=0))
        return self._summarize(results, "direct_answer")

    def run_rag_baseline(self, agent: SimulatedRetrievalAgent) -> Dict:
        """Baseline B: RAG with fixed retrievals."""
        results = [
            agent.answer(q, self.oracle, max_retrievals=2, use_occ=False)
            for q in self.questions
        ]
        return self._summarize(results, "rag_baseline")

    def run_rag_verifier(self, agent: SimulatedRetrievalAgent) -> Dict:
        """Baseline C: RAG + verifier (extra check).

        When the first attempt is flagged as a hallucination the question is
        asked once more with a single retrieval; the retry's result is kept
        and the cost and retrieval count of both attempts are summed.
        """
        results = []
        for q in self.questions:
            r = agent.answer(q, self.oracle, max_retrievals=2, use_occ=False)
            if r.get("hallucination", False):
                retry = agent.answer(q, self.oracle, max_retrievals=1, use_occ=False)
                retry["compute_cost"] += r["compute_cost"]
                retry["retrieval_calls"] += r["retrieval_calls"]
                r = retry
            results.append(r)
        return self._summarize(results, "rag_verifier")

    def run_occ(self, agent: SimulatedRetrievalAgent) -> Dict:
        """Baseline E/F: OCC resource allocation for retrievals.

        The agent starts with a seed credit of 10.0 in the "retrieval" scope;
        after every question its ledger is credited or debited from the
        oracle reward.
        """
        ledger = CreditLedger(decay_lambda=0.05)
        broker = ResourceBroker()
        results = []

        ledger.earn(
            agent_id=agent.agent_id,
            task_id="seed",
            action_id="seed",
            amount=10.0,
            oracle_score=0.0,
            compute_cost=0.0,
            reason="initial_trial_credit",
            capability_scope="retrieval",
        )

        for q in self.questions:
            r = agent.answer(
                q,
                self.oracle,
                max_retrievals=5,
                use_occ=True,
                broker=broker,
                ledger=ledger,
            )
            self._settle_credit(ledger, agent, q, r)
            results.append(r)

        return self._summarize(results, "occ_allocation")

    def _settle_credit(
        self,
        ledger: CreditLedger,
        agent: SimulatedRetrievalAgent,
        q: Question,
        r: Dict,
    ) -> None:
        """Credit three times a positive reward, otherwise debit a penalty."""
        # NOTE(review): task ids use the first 30 characters of the question
        # text, so questions sharing a topic produce the same id - confirm
        # the ledger tolerates repeated task ids.
        task_id = f"q_{q.question[:30]}"
        earn_amount = max(0.0, r["reward"] * 3.0)
        if earn_amount > 0:
            ledger.earn(
                agent_id=agent.agent_id,
                task_id=task_id,
                action_id="answer",
                amount=earn_amount,
                oracle_score=r["oracle_score"],
                compute_cost=r["compute_cost"],
                reason="correct_answer",
                capability_scope="retrieval",
            )
            return

        # The penalty is at least 0.5 but never more than the current balance.
        bal = ledger.balance(agent.agent_id, "retrieval", "global")
        penalty = min(bal, max(0.5, abs(r["reward"])))
        if penalty > 0:
            ledger.spend(
                agent_id=agent.agent_id,
                task_id=task_id,
                action_id="answer",
                amount=penalty,
                capability_scope="retrieval",
                reason="wrong_answer_penalty",
            )

    def _summarize(self, results: List[Dict], label: str) -> Dict:
        """Aggregate per-question results into benchmark metrics.

        ``results[i]`` must correspond to ``self.questions[i]``. With no
        results every rate, and the ECE, is reported as 0.0.

        Args:
            results: Per-question dicts as returned by the agent's ``answer``.
            label: Name of the configuration, copied into the summary.

        Returns:
            Dict of counts and rates plus the raw ``results`` list.
        """
        n = len(results)
        correct = abstained = hallucinations = confident_wrong = 0
        correct_abstentions = wrong_abstentions = 0

        for i, r in enumerate(results):
            if r["correct"]:
                correct += 1
            elif r["confidence"] > 0.8:
                confident_wrong += 1
            if r.get("hallucination", False):
                hallucinations += 1
            if r.get("abstained", False):
                abstained += 1
                if self.questions[i].is_unanswerable:
                    correct_abstentions += 1
                else:
                    wrong_abstentions += 1

        confidences = [r["confidence"] for r in results]
        correct_flags = [r["correct"] for r in results]
        # Skip the oracle for an empty run: there is nothing to bin.
        ece = self.oracle.compute_ece(confidences, correct_flags, n_bins=5) if n else 0.0

        total_compute = sum(r["compute_cost"] for r in results)
        total_retrievals = sum(r["retrieval_calls"] for r in results)

        return {
            "label": label,
            "n": n,
            "accuracy": correct / n if n else 0.0,
            "abstention_rate": abstained / n if n else 0.0,
            "correct_abstentions": correct_abstentions,
            "wrong_abstentions": wrong_abstentions,
            "hallucination_rate": hallucinations / n if n else 0.0,
            "confident_wrong_rate": confident_wrong / n if n else 0.0,
            "ece": float(ece),
            "total_compute": float(total_compute),
            "total_retrievals": total_retrievals,
            "results": results,
        }

    def _make_agent(self, agent_id: str = "rag_agent") -> SimulatedRetrievalAgent:
        """Create a fresh agent for fair comparison."""
        return SimulatedRetrievalAgent(
            agent_id=agent_id,
            accuracy=0.65,
            hallucination_rate=0.12,
            calibration_error=0.15,
            abstention_rate=0.1,
        )

    def run_all(self) -> Dict[str, Dict]:
        """Run every configuration, generating the questions first if needed.

        Each configuration gets its own freshly created agent; the global
        random generator is not reseeded between configurations.
        """
        if not self.questions:
            self.generate_questions()

        return {
            "direct_answer": self.run_direct_answer(self._make_agent("direct_agent")),
            "rag_baseline": self.run_rag_baseline(self._make_agent("rag_agent")),
            "rag_verifier": self.run_rag_verifier(self._make_agent("verifier_agent")),
            "occ_allocation": self.run_occ(self._make_agent("occ_agent")),
        }
|
|
|
|
def main(output_dir: str = "/app/occ/reports") -> None:
    """Run the full benchmark, print a summary, and save the results as JSON.

    Args:
        output_dir: Directory that receives
            ``benchmark_retrieval_qa_results.json``; created if missing.
            Defaults to the original hard-coded report location.
    """
    bench = RetrievalQABenchmark(n_questions=100, seed=42)
    bench.generate_questions()
    results = bench.run_all()

    print("=" * 60)
    print("RETRIEVAL QA BENCHMARK")
    print("=" * 60)
    for label, res in results.items():
        print(f"\n{label}")
        print(f" accuracy: {res['accuracy']:.3f}")
        print(f" abstention_rate: {res['abstention_rate']:.3f}")
        print(f" correct_abstentions: {res['correct_abstentions']}")
        print(f" wrong_abstentions: {res['wrong_abstentions']}")
        print(f" hallucination_rate: {res['hallucination_rate']:.3f}")
        print(f" confident_wrong_rate: {res['confident_wrong_rate']:.3f}")
        print(f" ECE: {res['ece']:.3f}")
        print(f" total_compute: {res['total_compute']:.0f}")
        print(f" total_retrievals: {res['total_retrievals']}")

    report_dir = Path(output_dir)
    report_dir.mkdir(parents=True, exist_ok=True)
    report_path = report_dir / "benchmark_retrieval_qa_results.json"
    with open(report_path, "w", encoding="utf-8") as f:
        # default=str stringifies any value json cannot encode natively.
        json.dump(results, f, indent=2, default=str)
    # Report the path that was actually written, not a relative alias.
    print(f"\nSaved to {report_path}")
|
|
|
|
# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|