File size: 6,672 Bytes
dbc69f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from __future__ import annotations

from .registry import register_objective


@register_objective("token_grpo")
class TokenGRPOObjective:
    """Token-level GRPO objective built from format, correctness and length rewards.

    ``combine_rewards`` supports two modes, selected by ``reward_mode``:

    * ``"additive"``: every reward stream in ``reward_outputs`` is summed
      element-wise.
    * ``"weighted_length_penalty"``: for each sample::

          R_total = correctness_weight * R_correctness
                    + length_penalty_lambda * interaction
                    + R_format
                    + R_token_utilisation

      where ``interaction`` is ``R_correctness * R_length`` by default, or
      ``R_correctness * R_length * R_format`` when
      ``length_penalty_interaction == "correctness_length_format"``.

    In both modes, when ``strict_format_gate`` is enabled, a sample whose
    ``format_tag_reward`` is not strictly greater than 0.5 (NaN included)
    receives ``non_strict_penalty`` instead of its combined reward.
    """

    name = "token_grpo"

    _REWARD_MODES = frozenset({"additive", "weighted_length_penalty"})
    _LENGTH_INTERACTIONS = frozenset(
        {"correctness_length", "correctness_length_format"}
    )
    # A sample passes the strict format gate only if its format score exceeds this.
    _FORMAT_GATE_THRESHOLD = 0.5

    def __init__(
        self,
        enable_length_penalty: bool = True,
        enable_token_utilisation_reward: bool = False,
        reward_mode: str = "additive",
        correctness_weight: float = 1.0,
        length_penalty_lambda: float = 1.0,
        stage2_length_penalty_lambda: float | None = None,
        stage2_start_epoch: float = 1.0,
        length_penalty_interaction: str = "correctness_length",
        strict_format_gate: bool = True,
        non_strict_penalty: float = -1.0,
        **kwargs,
    ):
        """Store and validate the objective's configuration.

        Args:
            enable_length_penalty: Request ``length_penalty_reward`` and, in
                weighted mode, use its scores (they are treated as 0 otherwise).
            enable_token_utilisation_reward: Request ``token_utilisation_reward``.
            reward_mode: ``"additive"`` or ``"weighted_length_penalty"``.
            correctness_weight: Multiplier on the correctness reward (weighted
                mode only).
            length_penalty_lambda: Multiplier on the interaction term
                (weighted mode only).
            stage2_length_penalty_lambda: Optional second-stage lambda. Stored
                only; not read by this class (presumably consumed by a
                trainer-side schedule -- TODO confirm).
            stage2_start_epoch: Start epoch of that second stage. Stored only;
                not read by this class.
            length_penalty_interaction: ``"correctness_length"`` or
                ``"correctness_length_format"``.
            strict_format_gate: Replace the reward of samples failing the
                format gate with ``non_strict_penalty``.
            non_strict_penalty: Reward assigned to samples failing the gate.
            **kwargs: Accepted and silently discarded.

        Raises:
            ValueError: If ``reward_mode`` or ``length_penalty_interaction``
                is not one of the allowed values.
        """
        del kwargs
        self.enable_length_penalty = bool(enable_length_penalty)
        self.enable_token_utilisation_reward = bool(enable_token_utilisation_reward)
        self.reward_mode = reward_mode
        self.correctness_weight = float(correctness_weight)
        self.length_penalty_lambda = float(length_penalty_lambda)
        self.stage2_length_penalty_lambda = (
            None
            if stage2_length_penalty_lambda is None
            else float(stage2_length_penalty_lambda)
        )
        self.stage2_start_epoch = float(stage2_start_epoch)
        self.length_penalty_interaction = length_penalty_interaction
        self.strict_format_gate = bool(strict_format_gate)
        self.non_strict_penalty = float(non_strict_penalty)
        if self.reward_mode not in self._REWARD_MODES:
            allowed = ", ".join(sorted(self._REWARD_MODES))
            raise ValueError(
                f"Unknown reward_mode '{self.reward_mode}'. Allowed: {allowed}"
            )
        if self.length_penalty_interaction not in self._LENGTH_INTERACTIONS:
            allowed = ", ".join(sorted(self._LENGTH_INTERACTIONS))
            raise ValueError(
                "Unknown length_penalty_interaction "
                f"'{self.length_penalty_interaction}'. Allowed: {allowed}"
            )

    def reward_names(self) -> list[str]:
        """Return the names of the reward functions this objective needs.

        Format and correctness rewards are always requested; length-penalty
        and token-utilisation rewards are added when enabled.
        """
        names = [
            "format_tag_reward",
            "gsm8k_correctness_reward",
        ]
        if self.enable_length_penalty:
            names.append("length_penalty_reward")
        if self.enable_token_utilisation_reward:
            names.append("token_utilisation_reward")
        return names

    def combine_rewards(
        self, reward_outputs: dict[str, list[float]], objective_extra: list[float]
    ) -> list[float]:
        """Combine per-reward score lists into one total reward per sample.

        Args:
            reward_outputs: Mapping from reward-function name to its per-sample
                scores; the lists must all have the same length.
            objective_extra: Ignored by this objective.

        Returns:
            One combined reward per sample (see the class docstring for the
            formula of each ``reward_mode``).

        Raises:
            ValueError: If score lists have inconsistent lengths, or if additive
                mode with ``strict_format_gate`` receives a non-empty batch
                without ``format_tag_reward``.
        """
        del objective_extra  # not used by this objective
        if self.reward_mode == "additive":
            return self._combine_additive(reward_outputs)
        return self._combine_weighted(reward_outputs)

    def _passes_format_gate(self, format_score: float) -> bool:
        """Return True when *format_score* clears the strict format gate."""
        # Written as "> threshold" so a NaN score fails the gate in every mode.
        return float(format_score) > self._FORMAT_GATE_THRESHOLD

    def _combine_additive(self, reward_outputs: dict[str, list[float]]) -> list[float]:
        """Sum every reward stream element-wise, then apply the format gate."""
        sources = list(reward_outputs.values())
        if not sources:
            return []
        width = len(sources[0])
        if any(len(scores) != width for scores in sources):
            raise ValueError("Reward functions returned inconsistent lengths.")
        # Start from 0.0 and add sources in mapping order, as a running sum would.
        totals = [float(sum(column, 0.0)) for column in zip(*sources)]
        if not self.strict_format_gate or not totals:
            return totals

        format_scores = reward_outputs.get("format_tag_reward")
        if format_scores is None:
            # Previously surfaced as an opaque IndexError.
            raise ValueError(
                "strict_format_gate requires 'format_tag_reward' in reward_outputs."
            )
        return [
            total if self._passes_format_gate(fmt) else self.non_strict_penalty
            for total, fmt in zip(totals, format_scores, strict=True)
        ]

    def _combine_weighted(self, reward_outputs: dict[str, list[float]]) -> list[float]:
        """Apply the ``weighted_length_penalty`` formula to each sample."""
        format_scores = reward_outputs.get("format_tag_reward", [])
        correctness_scores = reward_outputs.get("gsm8k_correctness_reward", [])
        if len(format_scores) != len(correctness_scores):
            raise ValueError("format_tag_reward and gsm8k_correctness_reward must align.")
        num_samples = len(correctness_scores)
        zeros = [0.0] * num_samples  # read-only default for absent optional rewards

        if self.enable_length_penalty:
            length_scores = reward_outputs.get("length_penalty_reward", zeros)
            if len(length_scores) != num_samples:
                raise ValueError("length_penalty_reward must align with core rewards.")
        else:
            # Length penalty disabled: any provided length scores are ignored.
            length_scores = zeros
        # Token-utilisation scores are added whenever present in reward_outputs,
        # independent of enable_token_utilisation_reward.
        token_utilisation_scores = reward_outputs.get("token_utilisation_reward", zeros)
        if len(token_utilisation_scores) != num_samples:
            raise ValueError("token_utilisation_reward must align with core rewards.")

        include_format = self.length_penalty_interaction == "correctness_length_format"
        totals: list[float] = []
        for correctness, fmt, length, utilisation in zip(
            correctness_scores,
            format_scores,
            length_scores,
            token_utilisation_scores,
            strict=True,
        ):
            r_correctness = float(correctness)
            r_format = float(fmt)
            r_length = float(length)
            if self.strict_format_gate and not self._passes_format_gate(r_format):
                totals.append(self.non_strict_penalty)
                continue
            interaction = r_correctness * r_length
            if include_format:
                interaction *= r_format
            totals.append(
                self.correctness_weight * r_correctness
                + self.length_penalty_lambda * interaction
                + r_format
                + float(utilisation)
            )
        return totals

    def extra_reward(
        self,
        prompts: list[str],
        completions: list[str],
        references: list[str],
        metadata: list[dict],
    ) -> list[float]:
        """Return an objective-specific extra reward per completion (always 0.0)."""
        del prompts, references, metadata
        return [0.0 for _ in completions]


@register_objective("latent_neuralese")
class LatentNeuraleseObjective:
    """Placeholder objective intended for latent-space reward shaping.

    For now it only requests the token-level format reward, and its
    ``extra_reward`` contributes 0.0 for every completion.
    """

    name = "latent_neuralese"

    def __init__(self, representation_key: str = "hidden_state", **kwargs):
        """Record the configuration for later latent scoring.

        Args:
            representation_key: Stored as-is; presumably names which model
                representation to score -- not read elsewhere in this class.
            **kwargs: Any additional configuration, kept verbatim on
                ``self.kwargs``.
        """
        self.kwargs = kwargs
        self.representation_key = representation_key

    def reward_names(self) -> list[str]:
        """Return the reward functions this objective requests.

        Token-level rewards stay enabled while latent objectives are bootstrapped.
        """
        token_level_rewards = ("format_tag_reward",)
        return list(token_level_rewards)

    def extra_reward(
        self,
        prompts: list[str],
        completions: list[str],
        references: list[str],
        metadata: list[dict],
    ) -> list[float]:
        """Return one latent-space score per completion (currently always 0.0).

        Only ``completions`` is consulted, and only to size the result.
        """
        del prompts, references, metadata  # unused by the baseline
        # Baseline no-op: replace with scoring computed from activations. It is
        # isolated here so it can evolve without touching trainer internals.
        return [0.0 for _completion in completions]