narcolepticchicken committed on
Commit
b8754a6
·
verified ·
1 Parent(s): df7e938

Upload oracle/oracle.py

Browse files
Files changed (1) hide show
  1. oracle/oracle.py +398 -0
oracle/oracle.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Impact Oracle: scores whether an agent action produced measurable marginal value.
3
+ Supports code tasks, retrieval QA tasks, and multi-agent debate tasks.
4
+ """
5
+
6
+ import json
7
+ import math
8
+ from dataclasses import dataclass, field
9
+ from typing import Any, Dict, List, Optional, Tuple
10
+ import numpy as np
11
+
12
+
13
@dataclass
class OracleResult:
    """Structured verdict produced by :class:`ImpactOracle` for one action."""

    # Mode-specific quality score before any cost or gaming adjustment.
    raw_score: float = 0.0
    # ``raw_score`` after the compute-cost penalty has been applied.
    cost_adjusted_score: float = 0.0
    # Confidence attached to the verdict (mode scorers choose how it is set).
    confidence: float = 0.0
    # Free-form supporting data emitted by the mode scorer.
    evidence: Dict[str, Any] = field(default_factory=dict)
    # Human-readable summary of how the score was derived.
    reason: str = ""
    # Labels for detected failure / gaming patterns.
    failure_tags: List[str] = field(default_factory=list)
    # Final reward: shaped reward minus compute and gaming penalties.
    reward_value: float = 0.0


class ImpactOracle:
    """
    Multi-mode impact oracle with structured output.

    Supported modes are ``"code"``, ``"retrieval_qa"`` and ``"debate"``.
    The oracle keeps a small per-agent history so that repeated low-value
    behaviour can be penalised (see ``_detect_gaming``).
    """

    def __init__(
        self,
        compute_budget: float = 1e6,  # tokens or FLOPs budget reference
        decay_lambda: float = 0.1,
        calibration_weight: float = 0.2,
        hallucination_weight: float = 0.5,
        confident_wrong_weight: float = 0.3,
        compute_cost_weight: float = 0.2,
        gaming_weight: float = 0.4,
    ):
        """
        Store the scoring weights.

        Args:
            compute_budget: Reference budget; ``compute_cost / compute_budget``
                is the fraction of the compute penalty that is charged.
            decay_lambda: Stored on the instance; not read by any method of
                this class.
            calibration_weight: Scale of the Brier-based calibration bonus.
            hallucination_weight: Scale of the contradiction penalty.
            confident_wrong_weight: Scale of the confident-but-wrong penalty.
            compute_cost_weight: Maximum compute penalty (at or above budget).
            gaming_weight: Multiplier applied to the summed gaming penalties.
        """
        self.compute_budget = compute_budget
        self.decay_lambda = decay_lambda
        self.calibration_weight = calibration_weight
        self.hallucination_weight = hallucination_weight
        self.confident_wrong_weight = confident_wrong_weight
        self.compute_cost_weight = compute_cost_weight
        self.gaming_weight = gaming_weight

        # Gaming detection state (lightweight per-agent history)
        self._agent_history: Dict[str, List[Dict]] = {}

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def score(
        self,
        mode: str,
        action: Dict[str, Any],
        context: Dict[str, Any],
        result: Dict[str, Any],
        agent_id: Optional[str] = None,
    ) -> OracleResult:
        """
        Score an action based on mode and context.

        Args:
            mode: ``"code"``, ``"retrieval_qa"`` or ``"debate"``; anything
                else yields a zero-score result whose reason names the mode.
            action: Description of the action (read by gaming detection).
            context: Prior state the action is judged against.
            result: Observed outcome; ``result["compute_cost"]`` (default 0)
                drives the compute penalty.
            agent_id: When given, gaming detection runs and the agent's
                history is updated.

        Returns:
            The populated :class:`OracleResult`.
        """
        if mode == "code":
            raw = self._score_code(action, context, result)
        elif mode == "retrieval_qa":
            raw = self._score_retrieval_qa(action, context, result)
        elif mode == "debate":
            raw = self._score_debate(action, context, result)
        else:
            raw = OracleResult(raw_score=0.0, reason=f"Unknown mode: {mode}")

        # Apply cost adjustment to the headline score.
        compute_cost = result.get("compute_cost", 0.0)
        raw.cost_adjusted_score = self.cost_adjusted_score(raw.raw_score, compute_cost)

        # The mode scorer leaves its shaped reward in ``reward_value``.  Charge
        # compute against that shaped reward instead of discarding it, and
        # publish it before gaming detection so the agent's history records
        # what was actually earned on this step.
        base_reward = self.cost_adjusted_score(raw.reward_value, compute_cost)
        raw.reward_value = base_reward

        # Detect gaming patterns
        gaming_penalty = 0.0
        if agent_id is not None:
            gaming_penalty = self._detect_gaming(agent_id, action, raw)
        raw.reward_value = base_reward - gaming_penalty
        if gaming_penalty > 0:
            raw.failure_tags.append("gaming_detected")

        return raw

    def marginal_impact(self, before: OracleResult, after: OracleResult) -> float:
        """Return ``after.reward_value - before.reward_value``."""
        return after.reward_value - before.reward_value

    def proper_score(self, prediction: float, outcome: int) -> float:
        """
        Brier score (proper scoring rule).

        Args:
            prediction: Probability assigned to ``outcome == 1``.
            outcome: 0 or 1.

        Returns:
            ``1 - (prediction - outcome) ** 2`` (higher is better).
        """
        brier = (prediction - outcome) ** 2
        return 1.0 - brier

    def abstention_score(
        self,
        answer: Optional[str],
        confidence: float,
        evidence: Dict[str, Any],
        outcome: Optional[str],
    ) -> float:
        """
        Reward correct abstention, penalize incorrect abstention.

        ``answer`` of ``None`` / blank means the agent abstained; ``outcome``
        of ``None`` / blank means the question is unanswerable.  ``evidence``
        is accepted for interface compatibility and is not read.

        Returns:
            ``+1.0`` for a correct abstention, ``-1.0`` for abstaining on an
            answerable question, ``-0.5`` for answering an unanswerable one,
            otherwise the Brier-based ``proper_score`` of ``confidence``.
        """
        abstained = answer is None or answer.strip() == ""
        unanswerable = outcome is None or outcome.strip() == ""

        if abstained:
            # correct abstention vs. dodging a valid question
            return 1.0 if unanswerable else -1.0
        if unanswerable:
            return -0.5  # should have abstained

        # Both sides present — use proper score on confidence
        correct = self._fuzzy_match(answer, outcome)
        return self.proper_score(confidence, int(correct))

    def cost_adjusted_score(self, raw_score: float, compute_cost: float) -> float:
        """
        Penalize excessive compute usage.

        The penalty grows linearly with ``compute_cost / compute_budget`` and
        saturates at ``compute_cost_weight``.  It can reduce a positive score
        to zero but never below; a score that is already negative is returned
        unchanged so deliberate penalties from the mode scorers survive.
        """
        if self.compute_budget > 0:
            usage = compute_cost / self.compute_budget
        else:
            # No usable budget: any positive cost counts as fully over budget.
            usage = 1.0 if compute_cost > 0 else 0.0
        usage = max(0.0, min(1.0, usage))
        penalty = usage * self.compute_cost_weight
        floor = min(raw_score, 0.0)
        return max(floor, raw_score - penalty)

    # ------------------------------------------------------------------
    # Mode-specific scorers
    # ------------------------------------------------------------------

    def _score_code(
        self, action: Dict, context: Dict, result: Dict
    ) -> OracleResult:
        """
        Code scoring.

        Reads ``passed``, ``hidden_passed`` (defaults to ``passed``), ``k``
        and ``passes`` from ``result`` and ``previous_passed`` from
        ``context``.  Hidden pass scores 1.0, public-only pass 0.3, failure
        0.0; a pass that was already passing is halved, and passing public
        tests while failing hidden ones is scored -0.5.
        """
        passed = result.get("passed", False)
        hidden_passed = result.get("hidden_passed", passed)
        k = result.get("k", 1)
        passes = result.get("passes", 1 if passed else 0)
        prev_passed = context.get("previous_passed", False)

        # "regression" here means: passes now but also passed before, i.e. no
        # marginal value.  NOTE(review): a true regression (passed before,
        # fails now) is not tagged — confirm whether that is intended.
        regression = prev_passed and passed

        # Hidden-test gaming detection: passes public but fails hidden
        hidden_gaming = passed and not hidden_passed

        if hidden_gaming:
            raw_score = -0.5  # strong penalty for gaming public tests
        else:
            raw_score = 1.0 if hidden_passed else (0.3 if passed else 0.0)
            if regression:
                raw_score *= 0.5  # diminished marginal value

        pass_at_k = passes / k if k > 0 else 0.0

        failure_tags = []
        if hidden_gaming:
            failure_tags.append("hidden_test_gaming")
        if regression:
            failure_tags.append("regression")

        reason = (
            f"public={'pass' if passed else 'fail'}, "
            f"hidden={'pass' if hidden_passed else 'fail'}, "
            f"pass@{k}={pass_at_k:.2f}, "
            f"regression={regression}"
        )

        return OracleResult(
            raw_score=raw_score,
            confidence=0.9 if hidden_passed else (0.5 if passed else 0.1),
            evidence={"pass_at_k": pass_at_k, "regression": regression},
            reason=reason,
            failure_tags=failure_tags,
            # Code mode has no extra shaping: the reward is the raw score.
            reward_value=raw_score,
        )

    def _score_retrieval_qa(
        self, action: Dict, context: Dict, result: Dict
    ) -> OracleResult:
        """
        Retrieval QA scoring.

        Combines answer correctness, abstention utility, a Brier-based
        calibration bonus, a contradiction (hallucination) penalty and a
        confident-wrong penalty into a shaped reward clamped to ``[-1, 1]``.
        A ``None`` or blank answer is treated as an abstention throughout.
        """
        answer = result.get("answer")
        gold = context.get("gold_answer")
        confidence = result.get("confidence", 0.5)
        evidence = result.get("evidence") or {}

        # Blank answers are abstentions, consistent with abstention_score.
        answered = answer is not None and answer.strip() != ""

        # Correctness
        correct = answered and self._fuzzy_match(answer, gold)
        raw_score = 1.0 if correct else 0.0

        # Evidence support: entailment / contradiction scores
        entailment = evidence.get("entailment_score", 0.0)
        contradiction = evidence.get("contradiction_score", 0.0)

        # Hallucination penalty
        hallucination_penalty = contradiction * self.hallucination_weight

        # Abstention utility
        abstention = self.abstention_score(answer, confidence, evidence, gold)

        # Calibration bonus via Brier (only meaningful for a real answer
        # to a question that has a gold answer)
        calibration_bonus = 0.0
        if answered and gold is not None:
            brier = (confidence - int(correct)) ** 2
            calibration_bonus = (1.0 - brier) * self.calibration_weight

        # Confident-wrong penalty
        confident_wrong_penalty = 0.0
        if answered and not correct:
            confident_wrong_penalty = confidence * self.confident_wrong_weight

        reward = (
            raw_score
            + abstention * 0.3
            + calibration_bonus
            - hallucination_penalty
            - confident_wrong_penalty
        )

        failure_tags = []
        if contradiction > 0.5:
            failure_tags.append("hallucination")
        if answered and not correct and confidence > 0.8:
            failure_tags.append("confident_wrong")

        reason = (
            f"correct={correct}, confidence={confidence:.2f}, "
            f"entailment={entailment:.2f}, contradiction={contradiction:.2f}, "
            f"abstention={abstention:.2f}, calib_bonus={calibration_bonus:.3f}"
        )

        return OracleResult(
            raw_score=raw_score,
            confidence=confidence,
            evidence=evidence,
            reason=reason,
            failure_tags=failure_tags,
            reward_value=max(-1.0, min(1.0, reward)),
        )

    def _score_debate(
        self, action: Dict, context: Dict, result: Dict
    ) -> OracleResult:
        """
        Multi-agent debate scoring.

        Reads ``final_correct``, ``tokens_used`` and ``compute_cost`` from
        ``result`` and ``previous_correct`` from ``context``.  The shaped
        reward is the decision quality plus ten times the marginal
        contribution per token, clamped to ``[-1, 1]``.
        """
        final_correct = result.get("final_correct", False)
        prev_correct = context.get("previous_correct", False)
        compute_cost = result.get("compute_cost", 0.0)
        tokens_used = result.get("tokens_used", 0)

        # Decision quality
        raw_score = 1.0 if final_correct else 0.0

        # Marginal contribution: did this agent/action improve the decision?
        if final_correct and not prev_correct:
            marginal = 1.0
        elif prev_correct and not final_correct:
            marginal = -1.0
        else:
            marginal = 0.0

        # Influence efficiency: marginal contribution per token
        efficiency = marginal / max(1, tokens_used)

        # Throughput: decisions per unit compute
        throughput = 1.0 / max(1, compute_cost)

        reward = raw_score + efficiency * 10.0  # scale efficiency

        reason = (
            f"final_correct={final_correct}, marginal={marginal:.2f}, "
            f"efficiency={efficiency:.4f}, throughput={throughput:.4f}"
        )

        return OracleResult(
            raw_score=raw_score,
            confidence=0.8 if final_correct else 0.2,
            evidence={"marginal": marginal, "efficiency": efficiency, "throughput": throughput},
            reason=reason,
            reward_value=max(-1.0, min(1.0, reward)),
        )

    # ------------------------------------------------------------------
    # Gaming detection
    # ------------------------------------------------------------------

    def _detect_gaming(
        self, agent_id: str, action: Dict, result: OracleResult
    ) -> float:
        """
        Return the penalty for detected gaming patterns.

        Checks run against the agent's history *before* the current step is
        appended.  Detected pattern names are added to
        ``result.failure_tags``; the current step (including
        ``result.reward_value`` as set by the caller) is then recorded.

        Returns:
            Sum of the triggered penalties multiplied by ``gaming_weight``.
        """
        history = self._agent_history.setdefault(agent_id, [])
        now = len(history)  # simplistic step index
        recent = history[-20:]

        penalty = 0.0
        tags = []

        # 1. Spam: repeated low-value actions within the last 10 steps
        low_value_count = sum(1 for h in history[-10:] if h["score"] < 0.2)
        if low_value_count >= 7:
            penalty += 0.3
            tags.append("spam")

        # 2. Lightweight hoarding signal: the agent keeps submitting
        #    without earning over a full 20-step window
        if len(recent) >= 20 and sum(h["earned"] for h in recent) < 1.0:
            penalty += 0.2
            tags.append("low_earning_pattern")

        # 3. Verbose padding: many tokens for little impact
        tokens = action.get("tokens_used", 0)
        if tokens > 500 and result.raw_score < 0.3:
            penalty += 0.15
            tags.append("verbose_padding")

        # 4. Over-abstention
        abstained = action.get("abstained", False)
        if abstained:
            abstention_rate = sum(
                1 for h in recent if h.get("abstained", False)
            ) / max(1, len(recent))
            if abstention_rate > 0.7:
                penalty += 0.25
                tags.append("over_abstention")

        # 5. Confidence manipulation: very high confidence on wrong answers
        if result.confidence > 0.8 and result.raw_score < 0.3:
            penalty += 0.2
            tags.append("confidence_manipulation")

        history.append({
            "step": now,
            "score": result.raw_score,
            "earned": result.reward_value,
            "abstained": abstained,
            "tokens": tokens,
        })

        # Keep history bounded
        if len(history) > 1000:
            self._agent_history[agent_id] = history[-500:]

        result.failure_tags.extend(tags)
        return penalty * self.gaming_weight

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    @staticmethod
    def _fuzzy_match(a: Optional[str], b: Optional[str]) -> bool:
        """
        Case/whitespace-insensitive match allowing substring containment.

        ``None`` never matches.  A blank string only matches another blank
        string: the empty string is a substring of everything, so containment
        is only tried when both sides have content.
        """
        if a is None or b is None:
            return False
        a_norm = a.strip().lower()
        b_norm = b.strip().lower()
        if not a_norm or not b_norm:
            return a_norm == b_norm
        return a_norm == b_norm or a_norm in b_norm or b_norm in a_norm

    def compute_ece(
        self, confidences: List[float], accuracies: List[bool], n_bins: int = 10
    ) -> float:
        """
        Compute Expected Calibration Error.

        Confidences are bucketed into ``n_bins`` equal-width bins over
        ``[0, 1]`` (the last bin includes 1.0); the result is the
        sample-weighted mean of ``|mean confidence - mean accuracy|`` per bin.
        Returns 0.0 for empty input.
        """
        conf = np.array(confidences)
        acc = np.array(accuracies, dtype=float)
        if conf.size == 0:
            return 0.0

        edges = np.linspace(0.0, 1.0, n_bins + 1)
        ece = 0.0
        for i in range(n_bins):
            lower, upper = edges[i], edges[i + 1]
            if i == n_bins - 1:
                # Close the final bin on the right so confidence == 1.0 counts.
                mask = (conf >= lower) & (conf <= upper)
            else:
                mask = (conf >= lower) & (conf < upper)
            count = mask.sum()
            if count == 0:
                continue
            gap = abs(conf[mask].mean() - acc[mask].mean())
            ece += (count / len(conf)) * gap
        return float(ece)