# occ-stack / oracle / oracle.py
# Uploaded by narcolepticchicken: "Upload oracle/oracle.py" (commit 8f9dd28, verified)
"""
Impact Oracle: scores whether an agent action produced measurable marginal value.
Scoring is rule-based rather than delegated to a learned (neural) reward model,
which narrows the surface an agent can exploit for reward hacking.
"""
import math
import random
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
@dataclass
class OracleResult:
    """Outcome of scoring one agent action.

    Positional field order: raw_score, cost_adjusted_score, confidence,
    evidence, reason, failure_tags, reward_value. The last two are optional.
    """

    # Score produced by the mode-specific scoring rule.
    raw_score: float
    # Score after the compute-cost adjustment has been applied.
    cost_adjusted_score: float
    # Confidence value attached to this result.
    confidence: float
    # Supporting signals recorded alongside the score.
    evidence: Dict[str, Any]
    # Human-readable explanation of the score.
    reason: str
    # Names of detected failure modes; every instance gets its own fresh list.
    failure_tags: List[str] = field(default_factory=list)
    # Scalar reward value; 0.0 unless the producer sets it.
    reward_value: float = field(default=0.0)
class ImpactOracle:
    """
    The Impact Oracle scores agent actions on verified impact.
    It supports: code tasks, retrieval QA, multi-agent debate.

    Every scorer returns an ``OracleResult`` whose ``raw_score`` reflects task
    quality (including rule-based penalties) but NOT compute cost. The compute
    penalty is subtracted exactly once, to produce ``cost_adjusted_score``,
    which is also reported as ``reward_value``.
    """

    def __init__(
        self,
        code_weights: Optional[Dict[str, float]] = None,
        qa_weights: Optional[Dict[str, float]] = None,
        debate_weights: Optional[Dict[str, float]] = None,
        compute_penalty_rate: float = 0.0001,
        calibration_weight: float = 0.2,
        abstention_bonus: float = 1.0,
        hallucination_penalty: float = 2.0,
        confident_wrong_penalty: float = 3.0,
        gaming_penalty: float = 2.0,
    ):
        """Configure scoring weights and penalties.

        Args:
            code_weights: Code-mode weights. Must contain "correctness",
                "pass_at_k" and "regression". "compute_penalty" is the
                per-unit compute cost for code mode; if absent,
                ``compute_penalty_rate`` is used. A falsy value selects the
                built-in defaults.
            qa_weights: Retrieval-QA weights. "correctness" is required;
                "evidence_support" defaults to 0.5. A falsy value selects the
                built-in defaults.
            debate_weights: Debate weights. "decision_quality",
                "marginal_contribution" and "influence_efficiency" are
                required. A falsy value selects the built-in defaults.
            compute_penalty_rate: Per-unit compute penalty for QA and debate
                mode, and for ``cost_adjusted_score``.
            calibration_weight: Scale of the Brier-based calibration bonus in
                QA mode.
            abstention_bonus: Reward (or, if negated, penalty) for a correct
                (or incorrect) abstention in QA mode.
            hallucination_penalty: Subtracted when a QA answer is flagged as
                a hallucination.
            confident_wrong_penalty: Subtracted when a QA answer is confident
                (> 0.8) but incorrect.
            gaming_penalty: Base penalty for detected gaming behaviour
                (hidden-test gaming, spam, collusion, low-value compute).
        """
        self.code_weights = code_weights or {
            "correctness": 1.0,
            "pass_at_k": 0.3,
            "regression": -0.5,
            "compute_penalty": 0.001,
        }
        # NOTE(review): "calibration", "abstention_utility",
        # "hallucination_penalty" and "confident_wrong_penalty" are not read
        # by any scorer below. The dedicated constructor arguments are used
        # instead.
        self.qa_weights = qa_weights or {
            "correctness": 1.0,
            "evidence_support": 0.5,
            "calibration": 0.2,
            "abstention_utility": 1.0,
            "hallucination_penalty": 2.0,
            "confident_wrong_penalty": 3.0,
        }
        # NOTE(review): "throughput" is not read by _score_debate.
        self.debate_weights = debate_weights or {
            "decision_quality": 1.0,
            "influence_efficiency": 0.5,
            "throughput": 0.3,
            "marginal_contribution": 0.5,
        }
        self.compute_penalty_rate = compute_penalty_rate
        self.calibration_weight = calibration_weight
        self.abstention_bonus = abstention_bonus
        self.hallucination_penalty = hallucination_penalty
        self.confident_wrong_penalty = confident_wrong_penalty
        self.gaming_penalty = gaming_penalty

    def score(
        self,
        mode: str,
        action: Dict[str, Any],
        context: Dict[str, Any],
        result: Dict[str, Any],
        agent_id: str = "",
    ) -> OracleResult:
        """Dispatch to the scorer for ``mode``.

        Supported modes are "code", "retrieval_qa" and "debate". Any other
        mode yields a zero-score result tagged "unknown_mode" rather than
        raising.
        """
        if mode == "code":
            return self._score_code(action, context, result, agent_id)
        elif mode == "retrieval_qa":
            return self._score_retrieval_qa(action, context, result, agent_id)
        elif mode == "debate":
            return self._score_debate(action, context, result, agent_id)
        else:
            return OracleResult(
                raw_score=0.0,
                cost_adjusted_score=0.0,
                confidence=0.0,
                evidence={},
                reason=f"Unknown mode: {mode}",
                failure_tags=["unknown_mode"],
                reward_value=0.0,
            )

    def _score_code(
        self,
        action: Dict[str, Any],
        context: Dict[str, Any],
        result: Dict[str, Any],
        agent_id: str,
    ) -> OracleResult:
        """Score a code-task action.

        Reads from ``result``:
            "correctness", "pass_at_k", "regression", "compute_cost",
            "hidden_tests_pass" / "public_pass" (both default to
            "correctness"), and "confidence" (defaults to "correctness").

        Passing public tests while failing hidden ones is tagged
        "gaming_hidden_tests" and costs ``gaming_penalty``.
        """
        correctness = result.get("correctness", 0.0)
        pass_at_k = result.get("pass_at_k", 0.0)
        regression = result.get("regression", False)
        compute_cost = result.get("compute_cost", 0.0)
        hidden_tests_pass = result.get("hidden_tests_pass", correctness)
        public_pass = result.get("public_pass", correctness)

        failure_tags = []
        if public_pass and not hidden_tests_pass:
            failure_tags.append("gaming_hidden_tests")

        # raw is quality only: compute cost is charged exactly once, below.
        raw = (
            correctness * self.code_weights["correctness"]
            + pass_at_k * self.code_weights["pass_at_k"]
            # The regression weight is negative by default, so this subtracts.
            + (self.code_weights["regression"] if regression else 0.0)
        )
        if "gaming_hidden_tests" in failure_tags:
            raw -= self.gaming_penalty

        # Code mode has its own per-unit cost weight; fall back to the global rate.
        rate = self.code_weights.get("compute_penalty", self.compute_penalty_rate)
        cost_adj = raw - compute_cost * rate

        confidence = result.get("confidence", correctness)
        reason = f"correctness={correctness:.2f}, pass@k={pass_at_k:.2f}, cost={compute_cost}"
        if failure_tags:
            reason += f", failures={failure_tags}"
        return OracleResult(
            raw_score=raw,
            cost_adjusted_score=cost_adj,
            confidence=confidence,
            evidence={"correctness": correctness, "pass_at_k": pass_at_k, "regression": regression},
            reason=reason,
            failure_tags=failure_tags,
            reward_value=cost_adj,
        )

    def _score_retrieval_qa(
        self,
        action: Dict[str, Any],
        context: Dict[str, Any],
        result: Dict[str, Any],
        agent_id: str,
    ) -> OracleResult:
        """Score a retrieval-QA action.

        Abstaining (truthy ``action["abstained"]``) scores +abstention_bonus
        when ``context["is_unanswerable"]`` is truthy and -abstention_bonus
        otherwise.

        An answer is compared against ``context["gold_answer"]`` and combined
        with:
            - evidence support (``evidence["entailment_score"]``),
            - hallucination and confident-wrong penalties,
            - a Brier-based calibration bonus.

        When compute_cost > 100 and the raw (quality) score is below 0.5, the
        cost-adjusted score loses an extra 0.5 * gaming_penalty.
        """
        gold_answer = context.get("gold_answer", "")
        answer = result.get("answer", "")
        confidence = result.get("confidence", 0.5)
        # An explicit None is treated like a missing evidence dict.
        evidence = result.get("evidence") or {}
        compute_cost = result.get("compute_cost", 0.0)
        abstained = action.get("abstained", False)
        failure_tags = []

        if abstained:
            correct_abstention = context.get("is_unanswerable", False)
            raw = self.abstention_bonus if correct_abstention else -self.abstention_bonus
            if not correct_abstention:
                failure_tags.append("wrong_abstention")
            else:
                failure_tags.append("correct_abstention")
            reason = f"abstained, correct={correct_abstention}"
        else:
            correctness = self._answer_correctness(answer, gold_answer)
            entailment = evidence.get("entailment_score", 0.0)
            contradiction = evidence.get("contradiction_score", 0.0)
            hallucination = evidence.get("hallucination", False) or contradiction > 0.5
            confident_wrong = (confidence > 0.8) and (correctness < 0.5)
            compute_waste = compute_cost > 500 and correctness < 0.5
            if hallucination:
                failure_tags.append("hallucination")
            if confident_wrong:
                failure_tags.append("confident_wrong")
            if compute_waste:
                failure_tags.append("compute_waste")
            if compute_cost > 2000:
                failure_tags.append("excessive_compute")

            # raw is quality only: compute cost is charged exactly once, below.
            raw = (
                correctness * self.qa_weights["correctness"]
                + entailment * self.qa_weights.get("evidence_support", 0.5)
                - (self.hallucination_penalty if hallucination else 0.0)
                - (self.confident_wrong_penalty if confident_wrong else 0.0)
            )
            # Reward calibration: a lower Brier score gives a larger bonus.
            brier = (confidence - correctness) ** 2
            raw += (1 - brier) * self.calibration_weight
            reason = f"correctness={correctness:.2f}, entailment={entailment:.2f}, conf={confidence:.2f}"

        cost_adj = self.cost_adjusted_score(raw, compute_cost)
        # Extra penalty for spending noticeable compute on a low-value outcome.
        if compute_cost > 100 and raw < 0.5:
            cost_adj -= self.gaming_penalty * 0.5
        return OracleResult(
            raw_score=raw,
            cost_adjusted_score=cost_adj,
            confidence=confidence,
            evidence=evidence,
            reason=reason,
            failure_tags=failure_tags,
            reward_value=cost_adj,
        )

    def _score_debate(
        self,
        action: Dict[str, Any],
        context: Dict[str, Any],
        result: Dict[str, Any],
        agent_id: str,
    ) -> OracleResult:
        """Score a multi-agent debate contribution.

        Reads from ``result``:
            "decision_quality", "marginal_contribution", "tokens",
            "compute_cost" (defaults to "tokens"), and the flags "spam" and
            "collusion".

        Spam costs ``gaming_penalty``; collusion costs twice that. More than
        5000 tokens is tagged "verbose_waste", with no extra penalty.
        """
        decision_quality = result.get("decision_quality", 0.0)
        marginal = result.get("marginal_contribution", 0.0)
        tokens = result.get("tokens", 0)
        n_agents = context.get("n_agents", 1)
        compute_cost = result.get("compute_cost", tokens)
        spam = result.get("spam", False)
        collusion = result.get("collusion", False)

        failure_tags = []
        if spam:
            failure_tags.append("spam")
        if collusion:
            failure_tags.append("collusion")
        if tokens > 5000:
            failure_tags.append("verbose_waste")

        # raw is quality only: compute cost is charged exactly once, below.
        # NOTE(review): the influence-efficiency term is 1/tokens regardless of
        # content, so an empty (0- or 1-token) message earns its maximum.
        # Confirm this is intended.
        raw = (
            decision_quality * self.debate_weights["decision_quality"]
            + marginal * self.debate_weights["marginal_contribution"]
            + (1.0 / max(tokens, 1)) * self.debate_weights["influence_efficiency"]
        )
        if spam:
            raw -= self.gaming_penalty
        if collusion:
            raw -= self.gaming_penalty * 2

        cost_adj = self.cost_adjusted_score(raw, compute_cost)
        return OracleResult(
            raw_score=raw,
            cost_adjusted_score=cost_adj,
            confidence=result.get("confidence", 0.5),
            evidence={"marginal": marginal, "tokens": tokens, "n_agents": n_agents},
            reason=f"decision_quality={decision_quality:.2f}, marginal={marginal:.2f}, tokens={tokens}",
            failure_tags=failure_tags,
            reward_value=cost_adj,
        )

    def _answer_correctness(self, answer: str, gold: str) -> float:
        """Grade ``answer`` against ``gold``.

        Returns 1.0 for an exact match and 0.5 for a whole-word phrase match;
        anything else, including a blank (empty or whitespace-only) input,
        returns 0.0. The comparison is case-insensitive and ignores
        surrounding whitespace.

        Partial credit requires one string to appear inside the other on word
        boundaries, so fragments such as "a" or "par" earn nothing against
        "paris", while "paris" still matches "Paris, France".
        """
        ans = (answer or "").strip().lower()
        gld = (gold or "").strip().lower()
        # Check emptiness after stripping: "" is a substring of every string.
        if not ans or not gld:
            return 0.0
        if ans == gld:
            return 1.0
        if self._contains_phrase(ans, gld) or self._contains_phrase(gld, ans):
            return 0.5
        return 0.0

    @staticmethod
    def _contains_phrase(haystack: str, needle: str) -> bool:
        """Return True if ``needle`` occurs in ``haystack`` with no word character on either side."""
        pattern = rf"(?<!\w){re.escape(needle)}(?!\w)"
        return re.search(pattern, haystack) is not None

    def proper_score(self, prediction: float, outcome: float) -> float:
        """Return the negative squared error (a negated Brier-style score)."""
        return -((prediction - outcome) ** 2)

    def abstention_score(
        self,
        answer: Optional[str],
        confidence: float,
        evidence: Dict[str, Any],
        outcome: float,
    ) -> float:
        """Score an abstention decision.

        When ``answer is None`` (abstention), returns +abstention_bonus if
        ``outcome < 0.5`` and -abstention_bonus otherwise. Returns 0.0 when
        an answer was given. ``confidence`` and ``evidence`` are accepted but
        not used.
        """
        if answer is None:
            return self.abstention_bonus if outcome < 0.5 else -self.abstention_bonus
        return 0.0

    def marginal_impact(self, before: OracleResult, after: OracleResult) -> float:
        """Return the change in cost-adjusted score from ``before`` to ``after``."""
        return after.cost_adjusted_score - before.cost_adjusted_score

    def cost_adjusted_score(self, raw_score: float, compute_cost: float) -> float:
        """Subtract ``compute_cost * compute_penalty_rate`` from ``raw_score``."""
        return raw_score - compute_cost * self.compute_penalty_rate