narcolepticchicken committed on
Commit
427ee84
·
verified ·
1 Parent(s): 3b67aea

Upload rl/reward.py

Browse files
Files changed (1) hide show
  1. rl/reward.py +219 -0
rl/reward.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GRPO-compatible reward hook using Impact Oracle as reward source.
3
+ Includes an offline policy comparator for when training is infeasible.
4
+ """
5
+
6
+ import json
7
+ import math
8
+ from dataclasses import dataclass, field
9
+ from pathlib import Path
10
+ from typing import Any, Dict, List, Optional, Tuple
11
+
12
+ import numpy as np
13
+
14
+
15
@dataclass
class Trajectory:
    """One scored (prompt, completion) pair.

    Every field has a default, so an instance can be built from a partial
    keyword dict (for example one parsed from a JSONL line). Field order is
    part of the positional constructor and must not be changed.
    """

    # Prompt text the completion was generated for.
    prompt: str = ""
    # Generated completion text.
    completion: str = ""
    # Free-form dict of oracle outputs recorded for this sample.
    oracle_result: Dict[str, Any] = field(default_factory=dict)
    # Scalar reward assigned to the completion.
    reward: float = 0.0
    # Cost attributed to producing the completion (unit is caller-defined).
    compute_cost: float = 0.0
    # Scoring mode label; defaults to "code".
    mode: str = "code"
    # Arbitrary per-sample metadata supplied by the caller.
    metadata: Dict[str, Any] = field(default_factory=dict)
24
+
25
+
26
class RewardHook:
    """
    Wraps Impact Oracle + Ledger + Broker into a reward function
    compatible with TRL GRPOTrainer.

    Usage with GRPOTrainer:
        reward_fn = RewardHook(oracle, ledger, broker).compute_rewards
        trainer = GRPOTrainer(..., reward_funcs=reward_fn)

    Every scored sample is also appended to an in-memory trajectory log that
    can be retrieved with ``get_trajectories`` or written to disk with
    ``save_trajectories``.
    """

    def __init__(
        self,
        oracle,
        ledger,
        broker,
        mode: str = "code",
        agent_id: str = "default_agent",
    ):
        """
        Args:
            oracle: object exposing ``score(mode=, action=, context=, result=,
                agent_id=)``; the returned object must provide
                ``reward_value``, ``raw_score``, ``cost_adjusted_score``,
                ``confidence``, ``reason`` and ``failure_tags``.
            ledger: object exposing ``earn(...)``, called once per sample.
            broker: stored on the instance; not used by this class itself.
            mode: scoring mode forwarded to the oracle on every call.
            agent_id: agent identifier forwarded to the oracle and ledger.
        """
        self.oracle = oracle
        self.ledger = ledger
        self.broker = broker
        self.mode = mode
        self.agent_id = agent_id
        self._trajectories: List[Trajectory] = []

    def compute_rewards(
        self,
        prompts: List[str],
        completions: List[str],
        oracle_inputs: Optional[List[Dict]] = None,
        **kwargs,
    ) -> List[float]:
        """
        Compute rewards for a batch of completions.

        Args:
            prompts: list of prompt strings
            completions: list of completion strings
            oracle_inputs: optional list of dicts with keys:
                {"action": ..., "context": ..., "result": ...}
                Optional ledger keys are "task_id", "action_id",
                "capability_scope" and "task_scope". ``None`` or an empty
                list means "no extra inputs for any sample".
            **kwargs: accepted and ignored, so the trainer may pass extra
                keyword arguments.

        Returns:
            list of float rewards (same length as prompts/completions)

        Raises:
            ValueError: if ``prompts``, ``completions`` and (when given)
                ``oracle_inputs`` do not all have the same length.
        """
        # zip() would silently truncate to the shortest input and hand the
        # trainer fewer rewards than completions, so fail loudly instead.
        if len(prompts) != len(completions):
            raise ValueError(
                f"prompts ({len(prompts)}) and completions "
                f"({len(completions)}) must have the same length"
            )
        if not oracle_inputs:
            oracle_inputs = [{} for _ in prompts]
        elif len(oracle_inputs) != len(prompts):
            raise ValueError(
                f"oracle_inputs ({len(oracle_inputs)}) must have the same "
                f"length as prompts ({len(prompts)})"
            )

        rewards: List[float] = []
        for prompt, completion, oin in zip(prompts, completions, oracle_inputs):
            action = oin.get("action", {"text": completion})
            context = oin.get("context", {})
            # Work on a copy so the caller's "result" dict is not mutated by
            # the compute_cost default below.
            result = dict(oin.get("result") or {})
            # Fall back to the whitespace-token count as a cost proxy.
            result.setdefault("compute_cost", len(completion.split()))

            oracle_res = self.oracle.score(
                mode=self.mode,
                action=action,
                context=context,
                result=result,
                agent_id=self.agent_id,
            )

            reward = oracle_res.reward_value
            rewards.append(reward)

            # Ledger update: earnings are floored at zero, while the reward
            # returned to the trainer keeps its sign.
            self.ledger.earn(
                agent_id=self.agent_id,
                task_id=oin.get("task_id", "default_task"),
                action_id=oin.get("action_id", "default_action"),
                amount=max(0.0, reward),
                oracle_score=oracle_res.raw_score,
                compute_cost=result["compute_cost"],
                reason=oracle_res.reason,
                capability_scope=oin.get("capability_scope", "general"),
                task_scope=oin.get("task_scope", "global"),
            )

            self._trajectories.append(
                Trajectory(
                    prompt=prompt,
                    completion=completion,
                    oracle_result={
                        "raw_score": oracle_res.raw_score,
                        "cost_adjusted_score": oracle_res.cost_adjusted_score,
                        "confidence": oracle_res.confidence,
                        "reason": oracle_res.reason,
                        "failure_tags": oracle_res.failure_tags,
                    },
                    reward=reward,
                    compute_cost=result["compute_cost"],
                    mode=self.mode,
                    metadata=oin,
                )
            )

        return rewards

    def get_trajectories(self) -> List["Trajectory"]:
        """Return the internal list of trajectories recorded so far (not a copy)."""
        return self._trajectories

    def save_trajectories(self, path: str):
        """
        Write all recorded trajectories to ``path`` as JSON Lines.

        Parent directories are created if missing and an existing file is
        overwritten. Values that are not JSON-serializable are stringified
        via ``default=str``.
        """
        out_path = Path(path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w", encoding="utf-8") as f:
            for t in self._trajectories:
                d = {
                    "prompt": t.prompt,
                    "completion": t.completion,
                    # Persist oracle details too, so a saved file keeps
                    # everything the in-memory trajectory holds.
                    "oracle_result": t.oracle_result,
                    "reward": t.reward,
                    "compute_cost": t.compute_cost,
                    "mode": t.mode,
                    "metadata": t.metadata,
                }
                f.write(json.dumps(d, default=str) + "\n")
139
+
140
+
141
class OfflineComparator:
    """
    Compare two policies using saved trajectories when online GRPO training
    is infeasible due to compute constraints.
    """

    def __init__(self, baseline_path: Optional[str] = None):
        """
        Args:
            baseline_path: optional JSONL file of baseline trajectories. It is
                loaded only if the path is given and exists; otherwise the
                baseline stays empty.
        """
        self.baseline_path = baseline_path
        self.baseline: List[Trajectory] = []
        if baseline_path and Path(baseline_path).exists():
            self._load(baseline_path)

    def _load(self, path: str):
        """Append one Trajectory per non-blank JSON line of ``path`` to the baseline."""
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                # Tolerate blank or trailing empty lines instead of letting
                # json.loads raise on them.
                if not line:
                    continue
                d = json.loads(line)
                self.baseline.append(Trajectory(**d))

    def compare(self, candidate_trajectories: List["Trajectory"]) -> Dict:
        """
        Return comparative metrics between candidate and baseline.

        With an empty baseline only the candidate summary (the
        ``_summarize`` dict) is returned. Otherwise the result holds both
        summaries plus paired statistics over prompts present in both sets.
        When a prompt occurs several times in a set, its mean reward and mean
        cost are used for pairing. The p-value entries are placeholders and
        are always ``None``.
        """
        if not self.baseline:
            return self._summarize(candidate_trajectories, label="candidate")

        base = self._summarize(self.baseline, label="baseline")
        cand = self._summarize(candidate_trajectories, label="candidate")

        # Paired comparison on common prompts if available
        base_by_prompt = self._mean_by_prompt(self.baseline)
        cand_by_prompt = self._mean_by_prompt(candidate_trajectories)
        common = base_by_prompt.keys() & cand_by_prompt.keys()

        reward_diffs = []
        cost_diffs = []
        for p in common:
            cand_reward, cand_cost = cand_by_prompt[p]
            base_reward, base_cost = base_by_prompt[p]
            reward_diffs.append(cand_reward - base_reward)
            cost_diffs.append(cand_cost - base_cost)

        return {
            "baseline": base,
            "candidate": cand,
            "common_prompts": len(common),
            "mean_reward_diff": float(np.mean(reward_diffs)) if reward_diffs else None,
            "mean_cost_diff": float(np.mean(cost_diffs)) if cost_diffs else None,
            "reward_p_value": None,  # placeholder for t-test
            "cost_p_value": None,
        }

    @staticmethod
    def _mean_by_prompt(trajectories: List["Trajectory"]) -> Dict[str, Tuple[float, float]]:
        """Map each prompt to its (mean reward, mean compute_cost)."""
        # Group first so repeated prompts (several completions per prompt)
        # all contribute, rather than only the last one seen.
        grouped: Dict[str, List[Tuple[float, float]]] = {}
        for t in trajectories:
            grouped.setdefault(t.prompt, []).append((t.reward, t.compute_cost))
        return {
            prompt: (
                float(np.mean([reward for reward, _ in pairs])),
                float(np.mean([cost for _, cost in pairs])),
            )
            for prompt, pairs in grouped.items()
        }

    @staticmethod
    def _summarize(trajectories: List["Trajectory"], label: str) -> Dict:
        """
        Aggregate reward and cost statistics for a trajectory list.

        All statistics are 0.0 for an empty list. ``success_rate`` is the
        fraction of trajectories whose reward is strictly greater than 0.5.
        """
        rewards = [t.reward for t in trajectories]
        costs = [t.compute_cost for t in trajectories]
        return {
            "label": label,
            "n": len(trajectories),
            "mean_reward": float(np.mean(rewards)) if rewards else 0.0,
            "std_reward": float(np.std(rewards)) if rewards else 0.0,
            "mean_cost": float(np.mean(costs)) if costs else 0.0,
            "std_cost": float(np.std(costs)) if costs else 0.0,
            "total_cost": float(np.sum(costs)) if costs else 0.0,
            "success_rate": float(np.mean([r > 0.5 for r in rewards])) if rewards else 0.0,
        }

    def save_baseline(self, trajectories: List["Trajectory"], path: str):
        """
        Write ``trajectories`` to ``path`` as JSON Lines.

        Parent directories are created if missing and an existing file is
        overwritten. Values that are not JSON-serializable are stringified
        via ``default=str``. This does not modify ``self.baseline``.
        """
        out_path = Path(path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w", encoding="utf-8") as f:
            for t in trajectories:
                d = {
                    "prompt": t.prompt,
                    "completion": t.completion,
                    # Keep oracle details so a reloaded baseline is lossless.
                    "oracle_result": t.oracle_result,
                    "reward": t.reward,
                    "compute_cost": t.compute_cost,
                    "mode": t.mode,
                    "metadata": t.metadata,
                }
                f.write(json.dumps(d, default=str) + "\n")