File size: 3,882 Bytes
f824006
f4a0835
f824006
f4a0835
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f824006
 
 
 
 
 
 
 
 
f4a0835
f824006
 
f4a0835
 
 
 
 
 
 
f824006
f4a0835
 
 
f824006
f4a0835
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f824006
f4a0835
 
 
f824006
 
f4a0835
f824006
 
 
 
 
f4a0835
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
GRPO-compatible reward hook for TRL.

This module provides a reward function factory that wraps the OCC
ImpactOracle into a TRL GRPOTrainer-compatible callable.

Usage with TRL::

    from grpo_hook import make_occ_reward_func
    from trl import GRPOTrainer

    reward_fn = make_occ_reward_func(mode="code", compute_budget=1e5)
    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        reward_funcs=reward_fn,
        train_dataset=ds,   # must have a "prompt" column
    )

The reward function signature follows TRL conventions:
    def reward_fn(completions, **kwargs) -> list[float]
"""

import json
from pathlib import Path
from typing import Callable, Dict, List, Optional

from oracle.oracle import ImpactOracle
from ledger.ledger import CreditLedger
from broker.broker import ResourceBroker
from rl.reward import RewardHook, OfflinePolicyComparator


def _completion_text(completion) -> str:
    """Return the plain text of a single TRL completion.

    A completion is either a string (standard format) or a list of message
    dicts such as ``[{"role": "assistant", "content": "..."}]``
    (conversational format). For the latter, the first message's
    ``content`` is used. Any other value is converted with ``str()``.
    """
    if isinstance(completion, str):
        return completion
    if isinstance(completion, list):
        if not completion:
            # An empty message list carries no text; str([]) would be "[]".
            return ""
        first = completion[0]
        if isinstance(first, dict):
            content = first.get("content")
            # A missing/None content would otherwise break the tag search.
            return "" if content is None else str(content)
    return str(completion)


def _extract_answer(text: str) -> str:
    """Extract the model's answer from *text*.

    Returns the stripped content between the first ``<answer>`` tag and the
    first ``</answer>`` tag that follows it. If there is no such well-formed
    pair, falls back to the last whitespace-separated token, or ``""`` when
    *text* is blank.
    """
    open_tag, close_tag = "<answer>", "</answer>"
    start = text.find(open_tag)
    if start != -1:
        start += len(open_tag)
        # Look for the closing tag only *after* the opening one, so a stray
        # "</answer>" earlier in the text cannot produce an empty slice.
        end = text.find(close_tag, start)
        if end != -1:
            return text[start:end].strip()
    tokens = text.split()
    return tokens[-1] if tokens else ""


def make_occ_reward_func(
    mode: str = "retrieval_qa",
    compute_budget: float = 1e5,
    qa_weights: Optional[Dict] = None,
    code_weights: Optional[Dict] = None,
    debate_weights: Optional[Dict] = None,
) -> Callable[..., List[float]]:
    """
    Factory for a TRL-compatible reward function.

    Builds one ImpactOracle (with the given budget and per-mode weights) and
    wraps it in a RewardHook for *mode*. Both are shared by every call of the
    returned function.

    Returns a function with signature (completions, **kwargs) -> list[float].
    """
    oracle = ImpactOracle(
        compute_budget=compute_budget,
        qa_weights=qa_weights,
        code_weights=code_weights,
        debate_weights=debate_weights,
    )
    hook = RewardHook(oracle=oracle, mode=mode)

    def _reward_fn(completions, **kwargs) -> List[float]:
        """
        Score a batch of completions with the OCC reward hook.

        TRL calls this with completions as list[str] (standard format)
        or list[list[dict]] (conversational format). The text of each
        completion is scanned for ``<answer>...</answer>`` tags.

        Recognised kwargs: ``prompts``, ``answers`` (gold answers) and
        ``agent_ids``. Missing or empty ``prompts``/``answers`` default to
        one empty string per completion.
        """
        texts = [_completion_text(comp) for comp in completions]
        answers = [_extract_answer(txt) for txt in texts]
        # Heuristic confidence: higher when a non-empty answer was extracted.
        confidences = [0.7 if ans else 0.3 for ans in answers]
        # The word count of the whole completion is used as its compute cost.
        compute_costs = [len(txt.split()) for txt in texts]

        gold_answers = kwargs.get("answers") or [""] * len(texts)

        return hook.compute_rewards(
            prompts=kwargs.get("prompts", [""] * len(texts)),
            completions=texts,
            answers=answers,
            gold_answers=gold_answers,
            confidences=confidences,
            compute_costs=compute_costs,
            agent_ids=kwargs.get("agent_ids", None),
        )

    return _reward_fn


def demo_offline():
    """Compare two synthetic policies offline, print the result, and return it.

    A RewardHook in "retrieval_qa" mode (backed by an ImpactOracle with a
    compute budget of 1e5) is handed to an OfflinePolicyComparator. Two
    ten-episode reward ramps are then compared. The comparison is printed as
    indented JSON, with values json cannot encode rendered via ``str``.
    """
    reward_hook = RewardHook(
        oracle=ImpactOracle(compute_budget=1e5),
        mode="retrieval_qa",
    )
    comparator = OfflinePolicyComparator(reward_hook=reward_hook)

    def reward_ramp(base, step):
        # Ten episodes whose reward rises linearly from `base` by `step`.
        return [
            {"reward": base + k * step, "failure_tags": []}
            for k in range(10)
        ]

    outcome = comparator.compare(reward_ramp(0.5, 0.02), reward_ramp(0.4, 0.01))
    report = json.dumps(outcome, indent=2, default=str)
    print(report)
    return outcome


if __name__ == "__main__":
    # Only runs when the file is executed as a script, not on import.
    demo_offline()