File size: 4,790 Bytes
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Reward callables for TRL GRPO training.

These helpers consume rollout metadata and return one float reward per
completion, matching TRL reward function expectations.
"""

import math
from typing import Any


def _coerce_bool(value: Any) -> bool:
    """Convert common truthy/falsey values to bool."""

    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return value != 0
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in {"true", "1", "yes", "y"}:
            return True
        if normalized in {"false", "0", "no", "n", ""}:
            return False
    return bool(value)


def _coerce_float(value: Any, default: float = 0.0) -> float:
    """Convert numeric-like values to float with fallback."""

    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _clamp(value: float, low: float, high: float) -> float:
    """Clamp value to the closed interval [low, high]."""

    return max(low, min(high, value))


def _extract_metadata_rows(
    completions: list[list[dict[str, str]]],
    **kwargs: Any,
) -> list[dict[str, Any]]:
    """Resolve one metadata dict per completion.

    TRL can pass rollout metadata in different shapes depending on wrapper code.
    We support the common variants:
    - ``kwargs['metadata']`` as list[dict]
    - ``kwargs['metadata']`` as dict containing list-valued keys
    - flattened keys like ``correct``, ``progress``, ``operational``
    - fallback to empty dict when metadata is unavailable
    """

    batch_size = len(completions)

    metadata_kw = kwargs.get("metadata")
    if isinstance(metadata_kw, list):
        rows: list[dict[str, Any]] = []
        for idx in range(batch_size):
            entry = metadata_kw[idx] if idx < len(metadata_kw) else {}
            rows.append(entry if isinstance(entry, dict) else {})
        return rows

    if isinstance(metadata_kw, dict):
        rows = []
        for idx in range(batch_size):
            row: dict[str, Any] = {}
            for key, value in metadata_kw.items():
                if isinstance(value, list):
                    row[key] = value[idx] if idx < len(value) else None
                else:
                    row[key] = value
            rows.append(row)
        return rows

    rows = []
    for idx in range(batch_size):
        row = {}
        for key in (
            "answer_correct",
            "correct",
            "cumulative_progress",
            "progress",
            "operational_signals",
            "operational",
        ):
            value = kwargs.get(key)
            if isinstance(value, list):
                row[key] = value[idx] if idx < len(value) else None
            elif value is not None:
                row[key] = value
        rows.append(row)
    return rows


def reward_correctness(
    completions: list[list[dict[str, str]]],
    **kwargs: Any,
) -> list[float]:
    """Binary reward: 1.0 for correct terminal answer, else 0.0.

    Args:
        completions: Batch of completions; only its length is used here.
        **kwargs: Rollout metadata in any shape accepted by
            ``_extract_metadata_rows``.

    Returns:
        One float per completion. The flag is read from ``answer_correct``
        and falls back to ``correct`` when the former is absent or ``None``;
        when neither is available the reward is 0.0.
    """

    rewards: list[float] = []
    for row in _extract_metadata_rows(completions, **kwargs):
        flag = row.get("answer_correct")
        if flag is None:
            # Metadata extraction pads short columns with None, so a present
            # but empty ``answer_correct`` must not hide the ``correct`` alias.
            flag = row.get("correct", False)
        rewards.append(1.0 if _coerce_bool(flag) else 0.0)
    return rewards


def reward_progress(
    completions: list[list[dict[str, str]]],
    **kwargs: Any,
) -> list[float]:
    """Progress reward normalized to [0.0, 1.0].

    Args:
        completions: Batch of completions; only its length is used here.
        **kwargs: Rollout metadata in any shape accepted by
            ``_extract_metadata_rows``.

    Returns:
        One float per completion. The value is read from
        ``cumulative_progress`` and falls back to ``progress`` when the
        former is absent or ``None``. Values that cannot be converted to
        float count as 0.0, and the result is clamped to [0.0, 1.0].
    """

    rewards: list[float] = []
    for row in _extract_metadata_rows(completions, **kwargs):
        raw = row.get("cumulative_progress")
        if raw is None:
            # Metadata extraction pads short columns with None, so a present
            # but empty ``cumulative_progress`` must not hide ``progress``.
            raw = row.get("progress", 0.0)
        rewards.append(_clamp(_coerce_float(raw, default=0.0), 0.0, 1.0))
    return rewards


def reward_operational(
    completions: list[list[dict[str, str]]],
    **kwargs: Any,
) -> list[float]:
    """Operational reward from per-step rollout signals.

    For each completion whose ``operational_signals`` is a non-empty list,
    every dict entry contributes +1.0 for a truthy ``exec_ok``, +1.0 for a
    truthy ``new_info`` and -1.0 for a truthy ``repeat``; non-dict entries
    are ignored. Otherwise the reward is the ``operational`` value
    converted to float, or 0.0 when it is missing or not convertible.
    """

    # Signal name -> contribution to the score when the signal is truthy.
    weights = (("exec_ok", 1.0), ("new_info", 1.0), ("repeat", -1.0))

    scores: list[float] = []
    for row in _extract_metadata_rows(completions, **kwargs):
        steps = row.get("operational_signals")
        if not (isinstance(steps, list) and steps):
            # No per-step signals: use the precomputed scalar instead.
            scores.append(_coerce_float(row.get("operational", 0.0), default=0.0))
            continue

        total = 0.0
        for step in steps:
            if isinstance(step, dict):
                total += sum(
                    weight
                    for name, weight in weights
                    if _coerce_bool(step.get(name, False))
                )
        scores.append(total)
    return scores