occ-stack / rl /reward.py
narcolepticchicken's picture
Upload rl/reward.py
b81e833 verified
"""
Reward module - GRPO-compatible reward hook using Impact Oracle.
"""
import math
from typing import Any, Dict, List, Optional
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from oracle.oracle import ImpactOracle
class RewardHook:
    """
    Convert Impact Oracle scores into RL rewards.

    Every scored completion is also appended to ``trajectory_history`` as a
    dict with the keys ``prompt`` (first 100 items of the prompt), ``reward``,
    ``raw_score`` and ``failure_tags``.

    Intended for use with TRL's GRPOTrainer through its ``reward_funcs``
    parameter. NOTE(review): ``compute_rewards`` needs per-sample ``answers``,
    ``gold_answers``, ``confidences`` and ``compute_costs``; confirm the
    trainer/dataset actually forwards them as keyword arguments.
    """

    def __init__(
        self,
        oracle: Optional["ImpactOracle"] = None,
        mode: str = "retrieval_qa",
        compute_budget: float = 10000.0,
        target_accuracy: float = 0.8,
    ):
        """
        Args:
            oracle: Scorer to use; a default ``ImpactOracle()`` is built when
                omitted.
            mode: Passed unchanged as ``mode`` to every ``oracle.score`` call.
            compute_budget: Stored on the instance; not read by the methods of
                this class.
            target_accuracy: Stored on the instance; not read by the methods
                of this class.
        """
        # Compare against None so a caller-supplied oracle is never swapped
        # out merely because it evaluates as falsy.
        self.oracle = oracle if oracle is not None else ImpactOracle()
        self.mode = mode
        self.compute_budget = compute_budget
        self.target_accuracy = target_accuracy
        self.trajectory_history: List[Dict[str, Any]] = []

    def _score_and_record(
        self,
        prompt: str,
        answer: Optional[str],
        gold_answer: str,
        confidence: float,
        compute_cost: float,
        agent_id: str,
        evidence: Optional[Dict[str, Any]],
    ) -> float:
        """Score one completion with the oracle, log it, and return its reward."""
        oracle_res = self.oracle.score(
            mode=self.mode,
            # A missing answer is reported to the oracle as an abstention.
            action={"abstained": answer is None},
            context={"gold_answer": gold_answer},
            result={
                "answer": answer,
                "confidence": confidence,
                # Fresh dict per call so items never share mutable state.
                "evidence": evidence or {},
                "compute_cost": compute_cost,
            },
            agent_id=agent_id,
        )
        reward = oracle_res.reward_value
        self.trajectory_history.append({
            "prompt": prompt[:100],
            "reward": reward,
            "raw_score": oracle_res.raw_score,
            "failure_tags": oracle_res.failure_tags,
        })
        return reward

    def compute_rewards(
        self,
        prompts: List[str],
        completions: List[str],
        answers: List[Optional[str]],
        gold_answers: List[str],
        confidences: List[float],
        compute_costs: List[float],
        agent_ids: Optional[List[str]] = None,
        **kwargs,
    ) -> List[float]:
        """
        Compute rewards for a batch of completions.

        Args:
            prompts: One prompt per item; its length defines the batch size.
            completions: Accepted for signature compatibility; not read here.
            answers: Extracted answer per item, ``None`` meaning the agent
                abstained.
            gold_answers: Reference answer per item.
            confidences: Confidence per item, forwarded to the oracle.
            compute_costs: Compute cost per item, forwarded to the oracle.
            agent_ids: Agent id per item; ``"agent_default"`` is used for
                every item when omitted or empty.
            **kwargs: ``evidences`` (optional) supplies one evidence mapping
                per item; missing or ``None`` entries become ``{}``. Other
                keyword arguments are ignored.

        Returns:
            One float reward per prompt, in input order.

        Raises:
            IndexError: If any per-item sequence is shorter than ``prompts``.
                Raised before anything is scored, so ``trajectory_history``
                is left untouched in that case.
        """
        batch_size = len(prompts)
        agent_ids = agent_ids or ["agent_default"] * batch_size
        # Looked up once for the whole batch rather than once per item.
        evidences = kwargs.get("evidences")

        per_item = {
            "answers": answers,
            "gold_answers": gold_answers,
            "confidences": confidences,
            "compute_costs": compute_costs,
            "agent_ids": agent_ids,
        }
        if evidences is not None:
            per_item["evidences"] = evidences
        # Fail fast instead of part-way through the batch, which would leave
        # a partially updated trajectory_history behind.
        for name, values in per_item.items():
            if len(values) < batch_size:
                raise IndexError(
                    f"{name} has {len(values)} item(s) but "
                    f"{batch_size} prompt(s) were given"
                )

        rewards = []
        for i, prompt in enumerate(prompts):
            rewards.append(
                self._score_and_record(
                    prompt=prompt,
                    answer=answers[i],
                    gold_answer=gold_answers[i],
                    confidence=confidences[i],
                    compute_cost=compute_costs[i],
                    agent_id=agent_ids[i],
                    evidence=None if evidences is None else evidences[i],
                )
            )
        return rewards

    def compute_reward_single(
        self,
        prompt: str,
        completion: str,
        answer: Optional[str],
        gold_answer: str,
        confidence: float,
        compute_cost: float,
        agent_id: str = "agent_default",
        evidence: Optional[Dict[str, Any]] = None,
    ) -> float:
        """
        Compute the reward for a single completion.

        Takes the same per-item inputs as ``compute_rewards`` (``completion``
        is likewise not read) and appends one entry to ``trajectory_history``.

        Returns:
            The oracle's ``reward_value`` for this completion.
        """
        return self._score_and_record(
            prompt=prompt,
            answer=answer,
            gold_answer=gold_answer,
            confidence=confidence,
            compute_cost=compute_cost,
            agent_id=agent_id,
            evidence=evidence,
        )
class OfflinePolicyComparator:
    """
    Compare two policies using offline trajectory data.

    Useful when full GRPO training is not feasible.
    """

    def __init__(self, reward_hook: "RewardHook"):
        """Store the reward hook; ``compare`` itself does not read it."""
        self.reward_hook = reward_hook

    @staticmethod
    def _mean(values: List[float]) -> float:
        """Return the arithmetic mean of *values*, or 0.0 when it is empty."""
        return sum(values) / len(values) if values else 0.0

    def compare(
        self,
        policy_a_trajectories: List[Dict[str, Any]],
        policy_b_trajectories: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """
        Compare two policies evaluated on the same test set.

        Args:
            policy_a_trajectories: Trajectory dicts for policy A. Each must
                have a ``"reward"`` key; ``"failure_tags"`` is optional.
            policy_b_trajectories: Trajectory dicts for policy B, in the same
                order as policy A's so that items pair up by position.

        Returns:
            A dict with:
              * ``mean_reward_a`` / ``mean_reward_b``: mean reward per policy
                (0.0 for an empty list).
              * ``win_rate``: fraction of position-wise pairs where A's reward
                is strictly greater than B's (ties are not wins); 0.0 when
                there are no pairs.
              * ``improvement``: ``(total_a - total_b) / max(|total_b|, 1e-6)``.
              * ``policy_a_failures`` / ``policy_b_failures``: number of
                trajectories with a non-empty ``failure_tags`` value.

        Raises:
            KeyError: If a trajectory lacks the ``"reward"`` key.
        """
        rewards_a = [t["reward"] for t in policy_a_trajectories]
        rewards_b = [t["reward"] for t in policy_b_trajectories]
        total_a = sum(rewards_a)
        total_b = sum(rewards_b)

        # zip() stops at the shorter list, so normalise by the number of pairs
        # actually compared rather than by len(rewards_a).
        paired = min(len(rewards_a), len(rewards_b))
        wins = sum(1 for a, b in zip(rewards_a, rewards_b) if a > b)

        return {
            "mean_reward_a": self._mean(rewards_a),
            "mean_reward_b": self._mean(rewards_b),
            "win_rate": wins / paired if paired else 0.0,
            # The 1e-6 floor keeps the ratio finite when B's total is zero.
            "improvement": (total_a - total_b) / max(abs(total_b), 1e-6),
            "policy_a_failures": sum(
                1 for t in policy_a_trajectories if t.get("failure_tags")
            ),
            "policy_b_failures": sum(
                1 for t in policy_b_trajectories if t.get("failure_tags")
            ),
        }