narcolepticchicken commited on
Commit
b81e833
·
verified ·
1 Parent(s): 0da095b

Upload rl/reward.py

Browse files
Files changed (1) hide show
  1. rl/reward.py +92 -180
rl/reward.py CHANGED
@@ -1,219 +1,131 @@
1
  """
2
- GRPO-compatible reward hook using Impact Oracle as reward source.
3
- Includes an offline policy comparator for when training is infeasible.
4
  """
5
-
6
- import json
7
  import math
8
- from dataclasses import dataclass, field
9
- from pathlib import Path
10
- from typing import Any, Dict, List, Optional, Tuple
11
-
12
- import numpy as np
13
-
14
 
15
- @dataclass
16
- class Trajectory:
17
- prompt: str = ""
18
- completion: str = ""
19
- oracle_result: Dict = field(default_factory=dict)
20
- reward: float = 0.0
21
- compute_cost: float = 0.0
22
- mode: str = "code"
23
- metadata: Dict = field(default_factory=dict)
24
 
25
 
26
  class RewardHook:
27
  """
28
- Wraps Impact Oracle + Ledger + Broker into a reward function
29
- compatible with TRL GRPOTrainer.
30
-
31
- Usage with GRPOTrainer:
32
- reward_fn = RewardHook(oracle, ledger, broker).compute_rewards
33
- trainer = GRPOTrainer(..., reward_func=reward_fn)
34
  """
35
 
36
  def __init__(
37
  self,
38
- oracle,
39
- ledger,
40
- broker,
41
- mode: str = "code",
42
- agent_id: str = "default_agent",
43
  ):
44
- self.oracle = oracle
45
- self.ledger = ledger
46
- self.broker = broker
47
  self.mode = mode
48
- self.agent_id = agent_id
49
- self._trajectories: List[Trajectory] = []
 
50
 
51
  def compute_rewards(
52
  self,
53
  prompts: List[str],
54
  completions: List[str],
55
- oracle_inputs: Optional[List[Dict]] = None,
 
 
 
 
56
  **kwargs,
57
  ) -> List[float]:
58
  """
59
  Compute rewards for a batch of completions.
60
-
61
- Args:
62
- prompts: list of prompt strings
63
- completions: list of completion strings
64
- oracle_inputs: optional list of dicts with keys:
65
- {"action": ..., "context": ..., "result": ...}
66
-
67
- Returns:
68
- list of float rewards (same length as prompts/completions)
69
  """
70
  rewards = []
71
- oracle_inputs = oracle_inputs or [{} for _ in prompts]
72
-
73
- for prompt, completion, oin in zip(prompts, completions, oracle_inputs):
74
- action = oin.get("action", {"text": completion})
75
- context = oin.get("context", {})
76
- result = oin.get("result", {})
77
- result.setdefault("compute_cost", len(completion.split()))
78
 
 
79
  oracle_res = self.oracle.score(
80
  mode=self.mode,
81
- action=action,
82
- context=context,
83
- result=result,
84
- agent_id=self.agent_id,
85
- )
86
-
87
- reward = oracle_res.reward_value
88
- rewards.append(reward)
89
-
90
- # Ledger update
91
- self.ledger.earn(
92
- agent_id=self.agent_id,
93
- task_id=oin.get("task_id", "default_task"),
94
- action_id=oin.get("action_id", "default_action"),
95
- amount=max(0.0, reward),
96
- oracle_score=oracle_res.raw_score,
97
- compute_cost=result["compute_cost"],
98
- reason=oracle_res.reason,
99
- capability_scope=oin.get("capability_scope", "general"),
100
- task_scope=oin.get("task_scope", "global"),
101
- )
102
-
103
- self._trajectories.append(
104
- Trajectory(
105
- prompt=prompt,
106
- completion=completion,
107
- oracle_result={
108
- "raw_score": oracle_res.raw_score,
109
- "cost_adjusted_score": oracle_res.cost_adjusted_score,
110
- "confidence": oracle_res.confidence,
111
- "reason": oracle_res.reason,
112
- "failure_tags": oracle_res.failure_tags,
113
- },
114
- reward=reward,
115
- compute_cost=result["compute_cost"],
116
- mode=self.mode,
117
- metadata=oin,
118
- )
119
  )
 
 
 
 
 
 
 
120
 
121
  return rewards
122
 
123
- def get_trajectories(self) -> List[Trajectory]:
124
- return self._trajectories
125
-
126
- def save_trajectories(self, path: str):
127
- Path(path).parent.mkdir(parents=True, exist_ok=True)
128
- with open(path, "w") as f:
129
- for t in self._trajectories:
130
- d = {
131
- "prompt": t.prompt,
132
- "completion": t.completion,
133
- "reward": t.reward,
134
- "compute_cost": t.compute_cost,
135
- "mode": t.mode,
136
- "metadata": t.metadata,
137
- }
138
- f.write(json.dumps(d, default=str) + "\n")
139
-
140
-
141
- class OfflineComparator:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  """
143
- Compare two policies using saved trajectories when online GRPO training
144
- is infeasible due to compute constraints.
145
  """
146
 
147
- def __init__(self, baseline_path: Optional[str] = None):
148
- self.baseline_path = baseline_path
149
- self.baseline: List[Trajectory] = []
150
- if baseline_path and Path(baseline_path).exists():
151
- self._load(baseline_path)
152
-
153
- def _load(self, path: str):
154
- with open(path, "r") as f:
155
- for line in f:
156
- d = json.loads(line)
157
- self.baseline.append(Trajectory(**d))
158
-
159
- def compare(self, candidate_trajectories: List[Trajectory]) -> Dict:
160
- """
161
- Return comparative metrics between candidate and baseline.
162
- """
163
- if not self.baseline:
164
- return self._summarize(candidate_trajectories, label="candidate")
165
-
166
- base = self._summarize(self.baseline, label="baseline")
167
- cand = self._summarize(candidate_trajectories, label="candidate")
168
 
169
- # Paired comparison on common prompts if available
170
- base_by_prompt = {t.prompt: t for t in self.baseline}
171
- cand_by_prompt = {t.prompt: t for t in candidate_trajectories}
172
- common = set(base_by_prompt.keys()) & set(cand_by_prompt.keys())
173
-
174
- reward_diffs = []
175
- cost_diffs = []
176
- for p in common:
177
- reward_diffs.append(cand_by_prompt[p].reward - base_by_prompt[p].reward)
178
- cost_diffs.append(
179
- cand_by_prompt[p].compute_cost - base_by_prompt[p].compute_cost
180
- )
181
-
182
- return {
183
- "baseline": base,
184
- "candidate": cand,
185
- "common_prompts": len(common),
186
- "mean_reward_diff": float(np.mean(reward_diffs)) if reward_diffs else None,
187
- "mean_cost_diff": float(np.mean(cost_diffs)) if cost_diffs else None,
188
- "reward_p_value": None, # placeholder for t-test
189
- "cost_p_value": None,
190
- }
191
 
192
- @staticmethod
193
- def _summarize(trajectories: List[Trajectory], label: str) -> Dict:
194
- rewards = [t.reward for t in trajectories]
195
- costs = [t.compute_cost for t in trajectories]
196
  return {
197
- "label": label,
198
- "n": len(trajectories),
199
- "mean_reward": float(np.mean(rewards)) if rewards else 0.0,
200
- "std_reward": float(np.std(rewards)) if rewards else 0.0,
201
- "mean_cost": float(np.mean(costs)) if costs else 0.0,
202
- "std_cost": float(np.std(costs)) if costs else 0.0,
203
- "total_cost": float(np.sum(costs)) if costs else 0.0,
204
- "success_rate": float(np.mean([r > 0.5 for r in rewards])) if rewards else 0.0,
205
  }
206
-
207
- def save_baseline(self, trajectories: List[Trajectory], path: str):
208
- Path(path).parent.mkdir(parents=True, exist_ok=True)
209
- with open(path, "w") as f:
210
- for t in trajectories:
211
- d = {
212
- "prompt": t.prompt,
213
- "completion": t.completion,
214
- "reward": t.reward,
215
- "compute_cost": t.compute_cost,
216
- "mode": t.mode,
217
- "metadata": t.metadata,
218
- }
219
- f.write(json.dumps(d, default=str) + "\n")
 
1
  """
2
+ Reward module - GRPO-compatible reward hook using Impact Oracle.
 
3
  """
 
 
4
  import math
5
+ from typing import Any, Dict, List, Optional
 
 
 
 
 
6
 
7
+ import sys
8
+ from pathlib import Path
9
+ sys.path.insert(0, str(Path(__file__).parent.parent))
10
+ from oracle.oracle import ImpactOracle
 
 
 
 
 
11
 
12
 
13
  class RewardHook:
14
  """
15
+ Converts Impact Oracle scores into RL rewards.
16
+ Compatible with TRL GRPOTrainer via reward_funcs parameter.
 
 
 
 
17
  """
18
 
19
  def __init__(
20
  self,
21
+ oracle: Optional[ImpactOracle] = None,
22
+ mode: str = "retrieval_qa",
23
+ compute_budget: float = 10000.0,
24
+ target_accuracy: float = 0.8,
 
25
  ):
26
+ self.oracle = oracle or ImpactOracle()
 
 
27
  self.mode = mode
28
+ self.compute_budget = compute_budget
29
+ self.target_accuracy = target_accuracy
30
+ self.trajectory_history: List[Dict[str, Any]] = []
31
 
32
  def compute_rewards(
33
  self,
34
  prompts: List[str],
35
  completions: List[str],
36
+ answers: List[Optional[str]],
37
+ gold_answers: List[str],
38
+ confidences: List[float],
39
+ compute_costs: List[float],
40
+ agent_ids: Optional[List[str]] = None,
41
  **kwargs,
42
  ) -> List[float]:
43
  """
44
  Compute rewards for a batch of completions.
45
+ Returns list of float rewards (one per completion).
 
 
 
 
 
 
 
 
46
  """
47
  rewards = []
48
+ agent_ids = agent_ids or ["agent_default"] * len(prompts)
 
 
 
 
 
 
49
 
50
+ for i in range(len(prompts)):
51
  oracle_res = self.oracle.score(
52
  mode=self.mode,
53
+ action={"abstained": answers[i] is None},
54
+ context={"gold_answer": gold_answers[i]},
55
+ result={
56
+ "answer": answers[i],
57
+ "confidence": confidences[i],
58
+ "evidence": kwargs.get("evidences", [{}] * len(prompts))[i],
59
+ "compute_cost": compute_costs[i],
60
+ },
61
+ agent_id=agent_ids[i],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  )
63
+ rewards.append(oracle_res.reward_value)
64
+ self.trajectory_history.append({
65
+ "prompt": prompts[i][:100],
66
+ "reward": oracle_res.reward_value,
67
+ "raw_score": oracle_res.raw_score,
68
+ "failure_tags": oracle_res.failure_tags,
69
+ })
70
 
71
  return rewards
72
 
73
+ def compute_reward_single(
74
+ self,
75
+ prompt: str,
76
+ completion: str,
77
+ answer: Optional[str],
78
+ gold_answer: str,
79
+ confidence: float,
80
+ compute_cost: float,
81
+ agent_id: str = "agent_default",
82
+ evidence: Optional[Dict[str, Any]] = None,
83
+ ) -> float:
84
+ """Compute reward for a single completion."""
85
+ oracle_res = self.oracle.score(
86
+ mode=self.mode,
87
+ action={"abstained": answer is None},
88
+ context={"gold_answer": gold_answer},
89
+ result={
90
+ "answer": answer,
91
+ "confidence": confidence,
92
+ "evidence": evidence or {},
93
+ "compute_cost": compute_cost,
94
+ },
95
+ agent_id=agent_id,
96
+ )
97
+ self.trajectory_history.append({
98
+ "prompt": prompt[:100],
99
+ "reward": oracle_res.reward_value,
100
+ "raw_score": oracle_res.raw_score,
101
+ "failure_tags": oracle_res.failure_tags,
102
+ })
103
+ return oracle_res.reward_value
104
+
105
+
106
+ class OfflinePolicyComparator:
107
  """
108
+ Compare two policies using offline trajectory data.
109
+ Useful when full GRPO training is not feasible.
110
  """
111
 
112
+ def __init__(self, reward_hook: RewardHook):
113
+ self.reward_hook = reward_hook
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
+ def compare(
116
+ self,
117
+ policy_a_trajectories: List[Dict[str, Any]],
118
+ policy_b_trajectories: List[Dict[str, Any]],
119
+ ) -> Dict[str, Any]:
120
+ """Compare two policies on same test set."""
121
+ rewards_a = [t["reward"] for t in policy_a_trajectories]
122
+ rewards_b = [t["reward"] for t in policy_b_trajectories]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
 
 
 
 
124
  return {
125
+ "mean_reward_a": sum(rewards_a) / len(rewards_a),
126
+ "mean_reward_b": sum(rewards_b) / len(rewards_b),
127
+ "win_rate": sum(1 for a, b in zip(rewards_a, rewards_b) if a > b) / len(rewards_a),
128
+ "improvement": (sum(rewards_a) - sum(rewards_b)) / max(abs(sum(rewards_b)), 1e-6),
129
+ "policy_a_failures": sum(1 for t in policy_a_trajectories if t.get("failure_tags")),
130
+ "policy_b_failures": sum(1 for t in policy_b_trajectories if t.get("failure_tags")),
 
 
131
  }