"""
Reward module: a GRPO-compatible reward hook built on the Impact Oracle.
"""
import math
from typing import Any, Dict, List, Optional
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from oracle.oracle import ImpactOracle
class RewardHook:
    """
    Converts Impact Oracle scores into RL rewards.

    Compatible with TRL GRPOTrainer via reward_funcs parameter.

    Every scored completion is also appended to ``trajectory_history`` as a
    dict with the keys ``prompt`` (first 100 characters), ``reward``,
    ``raw_score`` and ``failure_tags``.
    """

    def __init__(
        self,
        oracle: Optional["ImpactOracle"] = None,
        mode: str = "retrieval_qa",
        compute_budget: float = 10000.0,
        target_accuracy: float = 0.8,
    ):
        """
        Args:
            oracle: Scorer exposing ``score(mode=, action=, context=,
                result=, agent_id=)``. A default ``ImpactOracle`` is built
                only when this is ``None``.
            mode: Forwarded unchanged to every ``oracle.score`` call.
            compute_budget: Stored on the instance; not read by the methods
                of this class.
            target_accuracy: Stored on the instance; not read by the methods
                of this class.
        """
        # Explicit None check: a caller-supplied oracle must not be replaced
        # by the default just because it happens to evaluate as falsy.
        self.oracle = ImpactOracle() if oracle is None else oracle
        self.mode = mode
        self.compute_budget = compute_budget
        self.target_accuracy = target_accuracy
        self.trajectory_history: List[Dict[str, Any]] = []

    def _score_and_record(
        self,
        prompt: str,
        answer: Optional[str],
        gold_answer: str,
        confidence: float,
        compute_cost: float,
        agent_id: str,
        evidence: Any,
    ) -> float:
        """Score one completion with the oracle, log it, and return its reward."""
        oracle_res = self.oracle.score(
            mode=self.mode,
            # A missing answer is reported to the oracle as an abstention.
            action={"abstained": answer is None},
            context={"gold_answer": gold_answer},
            result={
                "answer": answer,
                "confidence": confidence,
                "evidence": evidence,
                "compute_cost": compute_cost,
            },
            agent_id=agent_id,
        )
        self.trajectory_history.append({
            # Keep the log compact: only the first 100 characters are stored.
            "prompt": prompt[:100],
            "reward": oracle_res.reward_value,
            "raw_score": oracle_res.raw_score,
            "failure_tags": oracle_res.failure_tags,
        })
        return oracle_res.reward_value

    def compute_rewards(
        self,
        prompts: List[str],
        completions: List[str],
        answers: List[Optional[str]],
        gold_answers: List[str],
        confidences: List[float],
        compute_costs: List[float],
        agent_ids: Optional[List[str]] = None,
        **kwargs,
    ) -> List[float]:
        """
        Compute rewards for a batch of completions.

        The batch size is ``len(prompts)``; the other per-item lists are
        indexed in parallel. ``completions`` is accepted for signature
        compatibility but is not read here.

        Args:
            prompts: Prompt text for each item.
            completions: Raw completions (unused).
            answers: Extracted answers; ``None`` marks an abstention.
            gold_answers: Reference answer for each item.
            confidences: Confidence value for each item.
            compute_costs: Compute cost for each item.
            agent_ids: Per-item agent ids; ``"agent_default"`` is used for
                every item when this is ``None`` or empty.
            **kwargs: Only ``evidences`` is read: an optional per-item list
                of evidence objects. When absent or ``None`` each item gets
                its own empty dict.

        Returns:
            List of float rewards (one per completion).
        """
        if not agent_ids:
            agent_ids = ["agent_default"] * len(prompts)
        # Looked up once instead of rebuilding a default list per iteration.
        evidences = kwargs.get("evidences")

        rewards = []
        for i, prompt in enumerate(prompts):
            # A fresh dict per item so items never alias one shared default.
            evidence = {} if evidences is None else evidences[i]
            rewards.append(
                self._score_and_record(
                    prompt=prompt,
                    answer=answers[i],
                    gold_answer=gold_answers[i],
                    confidence=confidences[i],
                    compute_cost=compute_costs[i],
                    agent_id=agent_ids[i],
                    evidence=evidence,
                )
            )
        return rewards

    def compute_reward_single(
        self,
        prompt: str,
        completion: str,
        answer: Optional[str],
        gold_answer: str,
        confidence: float,
        compute_cost: float,
        agent_id: str = "agent_default",
        evidence: Optional[Dict[str, Any]] = None,
    ) -> float:
        """
        Compute reward for a single completion.

        Takes the same inputs as one item of ``compute_rewards``;
        ``completion`` is accepted but not read. A missing or empty
        ``evidence`` is sent to the oracle as an empty dict.
        """
        return self._score_and_record(
            prompt=prompt,
            answer=answer,
            gold_answer=gold_answer,
            confidence=confidence,
            compute_cost=compute_cost,
            agent_id=agent_id,
            evidence=evidence or {},
        )
class OfflinePolicyComparator:
    """
    Compare two policies using offline trajectory data.

    Useful when full GRPO training is not feasible.
    """

    def __init__(self, reward_hook: "RewardHook"):
        """Store the reward hook; ``compare`` reads only the trajectory dicts."""
        self.reward_hook = reward_hook

    def compare(
        self,
        policy_a_trajectories: List[Dict[str, Any]],
        policy_b_trajectories: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """
        Compare two policies on same test set.

        Args:
            policy_a_trajectories: Trajectory dicts for policy A. Each must
                have a ``"reward"`` key; ``"failure_tags"`` is optional.
            policy_b_trajectories: Trajectory dicts for policy B, same shape.

        Returns:
            Dict with:
              - ``mean_reward_a`` / ``mean_reward_b``: mean reward per
                policy, ``0.0`` for an empty trajectory list.
              - ``win_rate``: fraction of position-wise pairs where A's
                reward is strictly greater than B's, ``0.0`` when there
                are no pairs.
              - ``improvement``: ``(sum_a - sum_b) / max(|sum_b|, 1e-6)``.
              - ``policy_a_failures`` / ``policy_b_failures``: number of
                trajectories with a non-empty ``failure_tags`` value.

        Raises:
            KeyError: If a trajectory dict has no ``"reward"`` key.
        """
        rewards_a = [t["reward"] for t in policy_a_trajectories]
        rewards_b = [t["reward"] for t in policy_b_trajectories]
        total_a = sum(rewards_a)
        total_b = sum(rewards_b)

        # zip() only pairs the overlapping prefix, so the win rate is
        # normalised by the number of pairs actually compared.
        paired = min(len(rewards_a), len(rewards_b))
        wins = sum(1 for a, b in zip(rewards_a, rewards_b) if a > b)

        return {
            # Empty inputs report 0.0 instead of raising ZeroDivisionError.
            "mean_reward_a": total_a / len(rewards_a) if rewards_a else 0.0,
            "mean_reward_b": total_b / len(rewards_b) if rewards_b else 0.0,
            "win_rate": wins / paired if paired else 0.0,
            # The 1e-6 floor guards against a zero total for policy B.
            "improvement": (total_a - total_b) / max(abs(total_b), 1e-6),
            "policy_a_failures": sum(
                1 for t in policy_a_trajectories if t.get("failure_tags")
            ),
            "policy_b_failures": sum(
                1 for t in policy_b_trajectories if t.get("failure_tags")
            ),
        }
|