"""
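GRPO-compatible reward hook for TRL.

This module provides a reward-function factory that wraps the OCC
ImpactOracle in a callable compatible with TRL's GRPOTrainer.

Usage with TRL::

    from grpo_hook import make_occ_reward_func
    from trl import GRPOTrainer

    reward_fn = make_occ_reward_func(mode="code", compute_budget=1e5)
    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        reward_funcs=reward_fn,
        train_dataset=ds,  # must have a "prompt" column
    )

Answers are taken from ``<answer>...</answer>`` tags in each completion,
falling back to the last word of the completion. Gold answers, if available,
are read from an optional "answers" dataset column (TRL forwards extra
dataset columns to reward functions as keyword arguments).

The reward function signature follows TRL conventions::

    def reward_fn(completions, **kwargs) -> list[float]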
"""
import json
from pathlib import Path
from typing import Callable, Dict, List, Optional

from oracle.oracle import ImpactOracle
from ledger.ledger import CreditLedger
from broker.broker import ResourceBroker
from rl.reward import RewardHook, OfflinePolicyComparator
def _completion_to_text(completion) -> str:
    """Return the plain text of a single TRL completion.

    TRL passes completions either as plain strings (standard format) or as a
    list of message dicts such as ``[{"role": "assistant", "content": "..."}]``
    (conversational format). For the latter, the first message's ``content``
    is used. An empty message list, or a missing/``None`` content, yields
    ``""`` so it is scored as an empty answer. Any other object falls back to
    ``str(completion)``.
    """
    if isinstance(completion, str):
        return completion
    if isinstance(completion, list):
        if not completion:
            # str([]) would be "[]", which the fallback extractor below would
            # otherwise report as a real (non-empty) answer.
            return ""
        first = completion[0]
        if isinstance(first, dict):
            content = first.get("content")
            return "" if content is None else str(content)
    return str(completion)


def _extract_answer(text: str) -> str:
    """Return the answer contained in ``text``.

    Prefers the whitespace-stripped span between the first ``<answer>`` tag
    and the first ``</answer>`` that follows it. If no such closed pair
    exists, falls back to the last whitespace-separated token, or ``""`` when
    the text is blank.
    """
    open_tag, close_tag = "<answer>", "</answer>"
    start = text.find(open_tag)
    if start != -1:
        start += len(open_tag)
        # Look for the closing tag only after the opening one, so a stray
        # earlier "</answer>" cannot produce an inverted (empty) slice.
        end = text.find(close_tag, start)
        if end != -1:
            return text[start:end].strip()
    tokens = text.split()
    return tokens[-1] if tokens else ""


def make_occ_reward_func(
    mode: str = "retrieval_qa",
    compute_budget: float = 1e5,
    qa_weights: Optional[Dict] = None,
    code_weights: Optional[Dict] = None,
    debate_weights: Optional[Dict] = None,
) -> Callable[..., List[float]]:
    """Build a TRL-compatible reward function backed by an ``ImpactOracle``.

    One ``ImpactOracle`` and one ``RewardHook`` are created here and reused by
    every call of the returned function.

    Args:
        mode: Task mode forwarded to ``RewardHook`` (e.g. ``"retrieval_qa"``
            or ``"code"``).
        compute_budget: Forwarded to ``ImpactOracle``.
        qa_weights: Optional weights forwarded to ``ImpactOracle``.
        code_weights: Optional weights forwarded to ``ImpactOracle``.
        debate_weights: Optional weights forwarded to ``ImpactOracle``.

    Returns:
        A function with signature ``(completions, **kwargs)`` that returns the
        result of ``RewardHook.compute_rewards`` for the batch (TRL expects one
        float per completion).
    """
    oracle = ImpactOracle(
        compute_budget=compute_budget,
        qa_weights=qa_weights,
        code_weights=code_weights,
        debate_weights=debate_weights,
    )
    hook = RewardHook(oracle=oracle, mode=mode)

    def _reward_fn(completions, **kwargs) -> List[float]:
        """Score a batch of TRL completions.

        Each completion is reduced to text, an answer is extracted from it,
        and a heuristic confidence (0.7 for a non-empty answer, else 0.3) and
        compute cost (whitespace-separated word count) are attached before
        delegating to ``RewardHook.compute_rewards``.

        Keyword arguments read (TRL forwards dataset columns as kwargs):
            prompts: Prompts for the batch; empty strings if missing/empty.
            answers: Gold answers; empty strings if missing/empty.
            agent_ids: Passed through unchanged (``None`` if absent).
        """
        texts = [_completion_to_text(comp) for comp in completions]
        answers = [_extract_answer(txt) for txt in texts]
        # Heuristic: an extracted answer is trusted more than no answer.
        confidences = [0.7 if ans else 0.3 for ans in answers]
        # Word count serves as a cheap stand-in for compute cost.
        compute_costs = [len(txt.split()) for txt in texts]

        n = len(texts)
        # A missing, None, or empty column means "no data": pad with "".
        gold_answers = kwargs.get("answers") or [""] * n
        prompts = kwargs.get("prompts") or [""] * n

        return hook.compute_rewards(
            prompts=prompts,
            completions=texts,
            answers=answers,
            gold_answers=gold_answers,
            confidences=confidences,
            compute_costs=compute_costs,
            agent_ids=kwargs.get("agent_ids", None),
        )

    return _reward_fn
def demo_offline():
    """Run a toy offline comparison between two synthetic policies.

    Each policy is ten episodes whose rewards rise linearly: policy A from
    0.5 in steps of 0.02, policy B from 0.4 in steps of 0.01, with no
    failure tags. The comparator's result is printed as indented JSON
    (non-JSON values are stringified) and also returned.
    """
    reward_hook = RewardHook(
        oracle=ImpactOracle(compute_budget=1e5),
        mode="retrieval_qa",
    )
    comparator = OfflinePolicyComparator(reward_hook=reward_hook)

    def _episodes(base, step):
        # Ten synthetic episodes with linearly increasing reward.
        return [{"reward": base + k * step, "failure_tags": []} for k in range(10)]

    baseline_a = _episodes(0.5, 0.02)
    baseline_b = _episodes(0.4, 0.01)
    outcome = comparator.compare(baseline_a, baseline_b)
    print(json.dumps(outcome, indent=2, default=str))
    return outcome
if __name__ == "__main__":
    # Executing the module directly runs the offline policy-comparison demo.
    demo_offline()