File size: 4,462 Bytes
427ee84
b81e833
427ee84
 
b81e833
427ee84
b81e833
 
 
 
427ee84
 
 
 
b81e833
 
427ee84
 
 
 
b81e833
 
 
 
427ee84
b81e833
427ee84
b81e833
 
 
427ee84
 
 
 
 
b81e833
 
 
 
 
427ee84
 
 
 
b81e833
427ee84
 
b81e833
427ee84
b81e833
427ee84
 
b81e833
 
 
 
 
 
 
 
 
427ee84
b81e833
 
 
 
 
 
 
427ee84
 
 
b81e833
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427ee84
b81e833
 
427ee84
 
b81e833
 
427ee84
b81e833
 
 
 
 
 
 
 
427ee84
 
b81e833
 
 
 
 
 
427ee84
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
Reward module - GRPO-compatible reward hook using Impact Oracle.
"""
import math
from typing import Any, Dict, List, Optional

import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from oracle.oracle import ImpactOracle


class RewardHook:
    """
    Converts Impact Oracle scores into RL rewards.

    Compatible with TRL GRPOTrainer via the ``reward_funcs`` parameter: pass
    the bound method ``compute_rewards`` as the reward function.

    Every scored completion is appended to ``self.trajectory_history``. Nothing
    in this class clears that list, so it grows for the lifetime of the hook.
    """

    def __init__(
        self,
        oracle: "Optional[ImpactOracle]" = None,
        mode: str = "retrieval_qa",
        compute_budget: float = 10000.0,
        target_accuracy: float = 0.8,
    ):
        """
        Args:
            oracle: Scorer used for every completion. A default
                ``ImpactOracle()`` is constructed only when this is ``None``.
            mode: Scoring mode forwarded verbatim to ``oracle.score``.
            compute_budget: Stored on the instance; not read by this class.
            target_accuracy: Stored on the instance; not read by this class.
        """
        # ``is None`` rather than ``or`` so an oracle object that happens to be
        # falsy (e.g. defines __len__/__bool__) is still honoured.
        self.oracle = oracle if oracle is not None else ImpactOracle()
        self.mode = mode
        self.compute_budget = compute_budget
        self.target_accuracy = target_accuracy
        self.trajectory_history: List[Dict[str, Any]] = []

    def _score_one(
        self,
        prompt: str,
        answer: Optional[str],
        gold_answer: str,
        confidence: float,
        compute_cost: float,
        agent_id: str,
        evidence: Any,
    ) -> float:
        """Score one completion with the oracle, record it, and return its reward."""
        oracle_res = self.oracle.score(
            mode=self.mode,
            # A ``None`` answer is treated as the agent abstaining.
            action={"abstained": answer is None},
            context={"gold_answer": gold_answer},
            result={
                "answer": answer,
                "confidence": confidence,
                "evidence": evidence,
                "compute_cost": compute_cost,
            },
            agent_id=agent_id,
        )
        self.trajectory_history.append({
            # Only a 100-element prefix is kept to bound per-entry size.
            "prompt": prompt[:100],
            "reward": oracle_res.reward_value,
            "raw_score": oracle_res.raw_score,
            "failure_tags": oracle_res.failure_tags,
        })
        return oracle_res.reward_value

    def compute_rewards(
        self,
        prompts: List[str],
        completions: List[str],
        answers: List[Optional[str]],
        gold_answers: List[str],
        confidences: List[float],
        compute_costs: List[float],
        agent_ids: Optional[List[str]] = None,
        **kwargs,
    ) -> List[float]:
        """
        Compute rewards for a batch of completions.

        Args:
            prompts: One prompt per item; its first 100 elements are recorded.
            completions: Accepted for reward-function signature compatibility;
                not read by this method.
            answers: Extracted answer per item; ``None`` means "abstained".
            gold_answers: Reference answer per item.
            confidences: Model confidence per item.
            compute_costs: Compute spent per item.
            agent_ids: Agent id per item; defaults to ``"agent_default"``
                for every item when omitted or empty.
            **kwargs: An optional ``evidences`` list (one entry per item) is
                forwarded to the oracle; each item defaults to ``{}``.

        Returns:
            List of float rewards, one per prompt, in input order.

        Raises:
            ValueError: If any per-item list does not have one entry per
                prompt. This is checked before anything is scored, so no
                partial history is recorded.
        """
        n = len(prompts)
        agent_ids = agent_ids or ["agent_default"] * n
        evidences = kwargs.get("evidences")
        if evidences is None:
            # A fresh dict per item, so nothing is shared across the batch.
            evidences = [{} for _ in range(n)]

        columns = {
            "answers": answers,
            "gold_answers": gold_answers,
            "confidences": confidences,
            "compute_costs": compute_costs,
            "agent_ids": agent_ids,
            "evidences": evidences,
        }
        for name, column in columns.items():
            if len(column) != n:
                raise ValueError(
                    f"{name} has {len(column)} entries; expected {n} (one per prompt)"
                )

        rewards: List[float] = []
        for prompt, answer, gold, conf, cost, agent_id, evidence in zip(
            prompts, answers, gold_answers, confidences, compute_costs,
            agent_ids, evidences,
        ):
            rewards.append(
                self._score_one(prompt, answer, gold, conf, cost, agent_id, evidence)
            )
        return rewards

    def compute_reward_single(
        self,
        prompt: str,
        completion: str,
        answer: Optional[str],
        gold_answer: str,
        confidence: float,
        compute_cost: float,
        agent_id: str = "agent_default",
        evidence: Optional[Dict[str, Any]] = None,
    ) -> float:
        """
        Compute the reward for a single completion.

        ``completion`` is accepted for signature parity with
        ``compute_rewards`` and is not read. A falsy ``evidence`` is sent to
        the oracle as ``{}``. The result is also appended to
        ``trajectory_history``.
        """
        return self._score_one(
            prompt, answer, gold_answer, confidence, compute_cost, agent_id,
            evidence or {},
        )


class OfflinePolicyComparator:
    """
    Compare two policies using offline trajectory data.
    Useful when full GRPO training is not feasible.
    """

    def __init__(self, reward_hook: "RewardHook"):
        # Stored for callers; ``compare`` only reads the precomputed
        # ``"reward"`` values already present in each trajectory dict.
        self.reward_hook = reward_hook

    def compare(
        self,
        policy_a_trajectories: List[Dict[str, Any]],
        policy_b_trajectories: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """
        Compare two policies on the same test set.

        Trajectories are paired by position: entry ``i`` of each list is
        treated as the same test case. Every entry must have a ``"reward"``
        key. An entry with a truthy ``"failure_tags"`` counts as a failure.

        Returns:
            Dict with:
              - ``mean_reward_a`` / ``mean_reward_b``: average reward per policy.
              - ``win_rate``: fraction of pairs where A's reward is strictly
                greater than B's (ties are not wins).
              - ``improvement``: ``(sum_a - sum_b) / max(|sum_b|, 1e-6)``.
              - ``policy_a_failures`` / ``policy_b_failures``: entries with
                non-empty failure tags.

        Raises:
            ValueError: If either list is empty, or if the lists have
                different lengths (so they cannot be paired).
        """
        n = len(policy_a_trajectories)
        if n == 0 or not policy_b_trajectories:
            raise ValueError("both policies need at least one trajectory to compare")
        if len(policy_b_trajectories) != n:
            raise ValueError(
                f"trajectory counts differ ({n} vs {len(policy_b_trajectories)}); "
                "policies must be evaluated on the same test set"
            )

        rewards_a = [t["reward"] for t in policy_a_trajectories]
        rewards_b = [t["reward"] for t in policy_b_trajectories]
        total_a = sum(rewards_a)
        total_b = sum(rewards_b)
        wins = sum(1 for ra, rb in zip(rewards_a, rewards_b) if ra > rb)

        return {
            "mean_reward_a": total_a / n,
            "mean_reward_b": total_b / n,
            "win_rate": wins / n,
            # The 1e-6 floor avoids dividing by zero when B's total reward is 0.
            "improvement": (total_a - total_b) / max(abs(total_b), 1e-6),
            "policy_a_failures": sum(1 for t in policy_a_trajectories if t.get("failure_tags")),
            "policy_b_failures": sum(1 for t in policy_b_trajectories if t.get("failure_tags")),
        }