narcolepticchicken committed on
Commit
f824006
·
verified ·
1 Parent(s): bc02d39

Upload grpo_hook.py

Browse files
Files changed (1) hide show
  1. grpo_hook.py +113 -0
grpo_hook.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Minimal GRPO-compatible reward hook demonstration.
3
+
4
+ If full GRPO training is feasible, use this with TRL GRPOTrainer.
5
+ If not, use OfflineComparator for policy evaluation.
6
+ """
7
+
8
+ import json
9
+ from pathlib import Path
10
+ from typing import Dict, List, Optional
11
+
12
+ import numpy as np
13
+
14
+ from oracle.oracle import ImpactOracle
15
+ from ledger.ledger import CreditLedger
16
+ from broker.broker import ResourceBroker
17
+ from rl.reward import RewardHook, OfflineComparator
18
+
19
+
20
def demo_grpo_hook(output_path: str = "/app/occ/reports/demo_trajectories.jsonl"):
    """
    Demonstrate the reward hook with synthetic completions.

    This is a toy loop showing how GRPO reward computation would work: one
    prompt is repeated for every completion in the group, each completion is
    paired with a hand-written oracle input, and the hook scores the group.

    Args:
        output_path: Path handed to ``hook.save_trajectories``. Defaults to
            the location the demo has always written to. Its parent
            directory is created if it does not exist yet, so the function
            also works when called outside the ``__main__`` guard.

    Returns:
        The ``RewardHook`` instance that was used, so callers can inspect it.
    """
    oracle = ImpactOracle(compute_budget=1e5)
    ledger = CreditLedger(decay_lambda=0.05)
    broker = ResourceBroker()
    hook = RewardHook(oracle, ledger, broker, mode="code", agent_id="demo_agent")

    # Simulate a group of completions (as in GRPO): several samples that all
    # answer the same prompt.
    completions = [
        "a + b",
        "a * b",
        "a + b + 0",
    ]
    prompts = ["def add(a, b):\n return"] * len(completions)

    # "a * b" is the deliberately wrong completion: it is the only one whose
    # visible and hidden checks are reported as failing.
    wrong_index = 1
    oracle_inputs = [
        {
            "action": {"text": completion},
            "context": {"previous_passed": False},
            "result": {
                "passed": index != wrong_index,
                "hidden_passed": index != wrong_index,
                "compute_cost": 5.0,
            },
            "task_id": "task_1",
            "action_id": f"comp_{index}",
        }
        for index, completion in enumerate(completions)
    ]

    rewards = hook.compute_rewards(prompts, completions, oracle_inputs)
    print("GRPO Hook Demo")
    print("Prompts:", prompts)
    print("Completions:", completions)
    print("Rewards:", rewards)

    # Save trajectories for offline comparison. Create the target directory
    # here rather than relying on the script entry point having done it.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    hook.save_trajectories(output_path)
    # Report the path that was actually written, not a shortened variant.
    print("Saved trajectories to", output_path)

    return hook
66
+
67
+
68
def _make_synthetic_trajectories(label, reward_floor, compute_cost, draw, count=10):
    """Build ``count`` stand-in trajectory objects for one synthetic policy.

    ``draw`` is a zero-argument callable returning a float; each reward is
    ``reward_floor + draw() * 0.3``.
    """
    trajectories = []
    for i in range(count):
        attrs = {
            "prompt": f"prompt_{i}",
            "completion": f"{label}_completion_{i}",
            "reward": reward_floor + draw() * 0.3,
            "compute_cost": compute_cost,
            "mode": "code",
            "metadata": {},
        }
        # Throwaway attribute holder: the values live on a one-off class and
        # are read through an empty instance of it.
        # NOTE(review): how OfflineComparator reads these objects is not
        # visible here, so the construction is intentionally left unchanged.
        trajectories.append(type("T", (), attrs)())
    return trajectories


def demo_offline_comparison(
    baseline_path: str = "/app/occ/reports/baseline_trajectories.jsonl",
    seed: Optional[int] = None,
):
    """
    Compare two policies using offline trajectory comparison.

    Ten synthetic baseline trajectories are saved through one
    ``OfflineComparator``; a second comparator is then created from that file
    and asked to compare ten synthetic candidate trajectories against it.

    Args:
        baseline_path: File the baseline trajectories are saved to and then
            loaded from. Its parent directory is created if missing.
        seed: Optional seed for the synthetic rewards. ``None`` (the default)
            keeps drawing from NumPy's global random state; an integer makes
            the generated rewards reproducible.

    Returns:
        Whatever ``OfflineComparator.compare`` returns; it is also printed
        as indented JSON.
    """
    # One shared source of randomness, consumed baseline-first, so the draw
    # order is the same as generating the two lists one after the other.
    draw = np.random.rand if seed is None else np.random.RandomState(seed).rand

    # Create baseline policy trajectories
    baseline_trajs = _make_synthetic_trajectories("baseline", 0.5, 100.0, draw)

    # Create candidate policy trajectories
    candidate_trajs = _make_synthetic_trajectories("candidate", 0.6, 70.0, draw)

    # Make sure the target directory exists even when this function is
    # called outside the ``__main__`` guard.
    Path(baseline_path).parent.mkdir(parents=True, exist_ok=True)

    comparator = OfflineComparator()
    comparator.save_baseline(baseline_trajs, baseline_path)

    # A fresh comparator is built from the saved file, so the comparison
    # goes through the on-disk baseline rather than the in-memory list.
    comparator2 = OfflineComparator(baseline_path)
    result = comparator2.compare(candidate_trajs)

    print("\nOffline Comparison Demo")
    print(json.dumps(result, indent=2, default=str))
    return result
107
+
108
+
109
if __name__ == "__main__":
    # Both demos write their files under this directory; create it up front.
    reports_dir = Path("/app/occ/reports")
    reports_dir.mkdir(parents=True, exist_ok=True)

    demo_grpo_hook()
    print()  # blank separator line between the two demo outputs
    demo_offline_comparison()