narcolepticchicken commited on
Commit
f4a0835
·
verified ·
1 Parent(s): 69dc3e0

Upload grpo_hook.py

Browse files
Files changed (1) hide show
  1. grpo_hook.py +101 -90
grpo_hook.py CHANGED
@@ -1,113 +1,124 @@
1
  """
2
- Minimal GRPO-compatible reward hook demonstration.
3
 
4
- If full GRPO training is feasible, use this with TRL GRPOTrainer.
5
- If not, use OfflineComparator for policy evaluation.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  """
7
 
8
  import json
9
  from pathlib import Path
10
  from typing import Dict, List, Optional
11
 
12
- import numpy as np
13
-
14
  from oracle.oracle import ImpactOracle
15
  from ledger.ledger import CreditLedger
16
  from broker.broker import ResourceBroker
17
- from rl.reward import RewardHook, OfflineComparator
18
 
19
 
20
- def demo_grpo_hook():
 
 
 
 
 
 
21
  """
22
- Demonstrate the reward hook with synthetic completions.
23
- This is a toy loop showing how GRPO reward computation would work.
 
24
  """
25
- oracle = ImpactOracle(compute_budget=1e5)
26
- ledger = CreditLedger(decay_lambda=0.05)
27
- broker = ResourceBroker()
28
- hook = RewardHook(oracle, ledger, broker, mode="code", agent_id="demo_agent")
29
-
30
- # Simulate a group of completions (as in GRPO)
31
- prompts = [
32
- "def add(a, b):\n return",
33
- "def add(a, b):\n return",
34
- "def add(a, b):\n return",
35
- ]
36
- completions = [
37
- "a + b",
38
- "a * b",
39
- "a + b + 0",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  ]
41
- oracle_inputs = [
42
- {
43
- "action": {"text": c},
44
- "context": {"previous_passed": False},
45
- "result": {"passed": True, "hidden_passed": True, "compute_cost": 5.0},
46
- "task_id": "task_1",
47
- "action_id": f"comp_{i}",
48
- }
49
- for i, c in enumerate(completions)
50
  ]
51
- # Fix the wrong one
52
- oracle_inputs[1]["result"]["passed"] = False
53
- oracle_inputs[1]["result"]["hidden_passed"] = False
54
 
55
- rewards = hook.compute_rewards(prompts, completions, oracle_inputs)
56
- print("GRPO Hook Demo")
57
- print("Prompts:", prompts)
58
- print("Completions:", completions)
59
- print("Rewards:", rewards)
60
-
61
- # Save trajectories for offline comparison
62
- hook.save_trajectories("/app/occ/reports/demo_trajectories.jsonl")
63
- print("Saved trajectories to reports/demo_trajectories.jsonl")
64
-
65
- return hook
66
-
67
-
68
- def demo_offline_comparison():
69
- """
70
- Compare two policies using offline trajectory comparison.
71
- """
72
- # Create baseline policy trajectories
73
- baseline_trajs = []
74
- for i in range(10):
75
- t = type("T", (), {
76
- "prompt": f"prompt_{i}",
77
- "completion": f"baseline_completion_{i}",
78
- "reward": 0.5 + np.random.rand() * 0.3,
79
- "compute_cost": 100.0,
80
- "mode": "code",
81
- "metadata": {},
82
- })()
83
- baseline_trajs.append(t)
84
-
85
- # Create candidate policy trajectories
86
- candidate_trajs = []
87
- for i in range(10):
88
- t = type("T", (), {
89
- "prompt": f"prompt_{i}",
90
- "completion": f"candidate_completion_{i}",
91
- "reward": 0.6 + np.random.rand() * 0.3,
92
- "compute_cost": 70.0,
93
- "mode": "code",
94
- "metadata": {},
95
- })()
96
- candidate_trajs.append(t)
97
-
98
- comparator = OfflineComparator()
99
- comparator.save_baseline(baseline_trajs, "/app/occ/reports/baseline_trajectories.jsonl")
100
-
101
- comparator2 = OfflineComparator("/app/occ/reports/baseline_trajectories.jsonl")
102
- result = comparator2.compare(candidate_trajs)
103
-
104
- print("\nOffline Comparison Demo")
105
  print(json.dumps(result, indent=2, default=str))
106
  return result
107
 
108
 
109
  if __name__ == "__main__":
110
- Path("/app/occ/reports").mkdir(parents=True, exist_ok=True)
111
- demo_grpo_hook()
112
- print()
113
- demo_offline_comparison()
 
1
  """
2
+ GRPO-compatible reward hook for TRL.
3
 
4
+ This module provides a reward function factory that wraps the OCC
5
+ ImpactOracle into a TRL GRPOTrainer-compatible callable.
6
+
7
+ Usage with TRL::
8
+
9
+ from grpo_hook import make_occ_reward_func
10
+ from trl import GRPOTrainer
11
+
12
+ reward_fn = make_occ_reward_func(mode="code", compute_budget=1e5)
13
+ trainer = GRPOTrainer(
14
+ model="Qwen/Qwen2.5-0.5B-Instruct",
15
+ reward_funcs=reward_fn,
16
+ train_dataset=ds, # must have a "prompt" column
17
+ )
18
+
19
+ The reward function signature follows TRL conventions:
20
+ def reward_fn(completions, **kwargs) -> list[float]
21
  """
22
 
23
  import json
24
  from pathlib import Path
25
  from typing import Dict, List, Optional
26
 
 
 
27
  from oracle.oracle import ImpactOracle
28
  from ledger.ledger import CreditLedger
29
  from broker.broker import ResourceBroker
30
+ from rl.reward import RewardHook, OfflinePolicyComparator
31
 
32
 
33
+ def make_occ_reward_func(
34
+ mode: str = "retrieval_qa",
35
+ compute_budget: float = 1e5,
36
+ qa_weights: Optional[Dict] = None,
37
+ code_weights: Optional[Dict] = None,
38
+ debate_weights: Optional[Dict] = None,
39
+ ) -> callable:
40
  """
41
+ Factory for a TRL-compatible reward function.
42
+
43
+ Returns a function with signature (completions, **kwargs) -> list[float].
44
  """
45
+ oracle = ImpactOracle(
46
+ compute_budget=compute_budget,
47
+ qa_weights=qa_weights,
48
+ code_weights=code_weights,
49
+ debate_weights=debate_weights,
50
+ )
51
+ hook = RewardHook(oracle=oracle, mode=mode)
52
+
53
+ def _reward_fn(completions, **kwargs):
54
+ """
55
+ TRL calls this with completions as list[str] (standard format)
56
+ or list[list[dict]] (conversational format).
57
+ We extract text and look for answer tags.
58
+ """
59
+ texts = []
60
+ for comp in completions:
61
+ if isinstance(comp, list) and len(comp) > 0 and isinstance(comp[0], dict):
62
+ # Conversational format: [{"role":"assistant","content":"..."}]
63
+ texts.append(comp[0].get("content", ""))
64
+ elif isinstance(comp, str):
65
+ texts.append(comp)
66
+ else:
67
+ texts.append(str(comp))
68
+
69
+ answers = []
70
+ confidences = []
71
+ compute_costs = []
72
+
73
+ for txt in texts:
74
+ if "<answer>" in txt and "</answer>" in txt:
75
+ start = txt.find("<answer>") + len("<answer>")
76
+ end = txt.find("</answer>")
77
+ ans = txt[start:end].strip()
78
+ else:
79
+ # Fallback: last token or empty
80
+ parts = txt.strip().split()
81
+ ans = parts[-1] if parts else ""
82
+ answers.append(ans)
83
+ confidences.append(0.7 if len(ans) > 0 else 0.3)
84
+ compute_costs.append(len(txt.split()))
85
+
86
+ gold_answers = kwargs.get("answers", [""] * len(texts))
87
+ if not gold_answers:
88
+ gold_answers = [""] * len(texts)
89
+
90
+ rewards = hook.compute_rewards(
91
+ prompts=kwargs.get("prompts", [""] * len(texts)),
92
+ completions=texts,
93
+ answers=answers,
94
+ gold_answers=gold_answers,
95
+ confidences=confidences,
96
+ compute_costs=compute_costs,
97
+ agent_ids=kwargs.get("agent_ids", None),
98
+ )
99
+ return rewards
100
+
101
+ return _reward_fn
102
+
103
+
104
+ def demo_offline():
105
+ """Offline comparison of two policies using the reward hook."""
106
+ hook = RewardHook(oracle=ImpactOracle(compute_budget=1e5), mode="retrieval_qa")
107
+ comparator = OfflinePolicyComparator(reward_hook=hook)
108
+
109
+ policy_a = [
110
+ {"reward": 0.5 + i * 0.02, "failure_tags": []}
111
+ for i in range(10)
112
  ]
113
+ policy_b = [
114
+ {"reward": 0.4 + i * 0.01, "failure_tags": []}
115
+ for i in range(10)
 
 
 
 
 
 
116
  ]
 
 
 
117
 
118
+ result = comparator.compare(policy_a, policy_b)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  print(json.dumps(result, indent=2, default=str))
120
  return result
121
 
122
 
123
  if __name__ == "__main__":
124
+ demo_offline()