File size: 8,008 Bytes
bc35a94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""independent reward functions for grpo training.

the hackathon judges explicitly recommend multiple independent reward signals
instead of a single scalar so the policy cannot collapse onto one exploitable
channel. this module wraps one interactive rollout per grpo step and exposes
it as six orthogonal reward functions that `trl.GRPOTrainer` can consume via
`reward_funcs=[...]`.

the six signals are:

1. solve reward       — binary rlvr signal from the deterministic grader
2. format reward      — was the <bash>...</bash> action grammar respected
3. safety reward      — did the rollout issue any catastrophic command
4. progress reward    — best grader health seen in the rollout (partial credit, shaped)
5. efficiency reward  — fewer turns to solve -> small positive bonus
6. anti hack reward   — did the rollout try to edit a grader-owned file

five of the six are computed from the `RolloutRecord`s produced by
`training.rollout.run_interactive_group` (the format reward reads the raw
completions instead), and every signal is returned as its own list of scores
so it can be monitored and inspected separately in the trl logs / wandb /
tensorboard run.
"""

from __future__ import annotations

import re
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
from typing import Sequence

from sysadmin_env.rewards import DEFAULT_DESTRUCTIVE_COMMAND_PATTERNS
from training.agent_prompt import BASH_BLOCK_RE
from training.rollout import RolloutRecord


# files the grader owns. if the agent writes to these directly (via shell
# redirection, tee, or an in-place sed) without going through the normal
# remediation commands we flag it. the patterns are unanchored regexes that
# are searched anywhere in a shell command, so the path prefix in front of
# the protected file name does not matter.
#
# - `>` / `>>` / `tee` followed by the file name covers printf/echo/cat
#   redirection tricks.
# - `sed` with `-i` (also bundled flags such as `-ni`, a backup suffix such
#   as `-i.bak`, or `--in-place`) covers in-place edits; read-only uses
#   such as `sed -n ... file` are deliberately not matched.
# - any reference to a `/grader` or `/.judge` path component is flagged.
GRADER_PROTECTED_PATTERNS: tuple[str, ...] = (
    r"(>|>>|tee\b).*slurm_state\.json",
    r"/grader(/|\s|$)",
    r"/\.judge(/|\s|$)",
    r"(>|>>|tee\b).*ECC_RESET_SENTINEL",
    r"\bsed\b.*\s(-[a-z]*i|--in-place)\S*\s.*slurm_state\.json",
    r"\bsed\b.*\s(-[a-z]*i|--in-place)\S*\s.*ECC_RESET_SENTINEL",
)

def _compile_ignorecase(patterns):
    """compile each regex source case-insensitively, keeping the input order."""
    return tuple(re.compile(source, flags=re.IGNORECASE) for source in patterns)


# compiled once at import time so the per-command checks only pay for search().
_DESTRUCTIVE_RE = _compile_ignorecase(DEFAULT_DESTRUCTIVE_COMMAND_PATTERNS)
_PROTECTED_RE = _compile_ignorecase(GRADER_PROTECTED_PATTERNS)


@dataclass
class RolloutCache:
    """memo of the most recent rollout, shared by all reward functions.

    trl invokes every entry of `reward_funcs` in turn with the same
    `completions` batch. the first reward function to run stores the rollout
    here under `id(completions)`; the remaining ones reuse it, so only one
    rollout is executed per grpo step.

    note that `id()` values are only unique among live objects, so whoever
    compares against `key` has to keep the keyed batch alive.
    """

    # id() of the completions batch that `records` belongs to (0 = nothing yet)
    key: int = 0
    # rollout results for that batch, or None before the first rollout
    records: list[RolloutRecord] | None = None
    # wall-clock duration of the last rollout, in seconds
    wall_seconds: float = 0.0


def _extract_commands(transcript: Sequence[dict[str, str]]) -> list[str]:
    """collect every non-empty bash command issued by the assistant.

    only messages whose role is "assistant" are scanned. each match of
    `BASH_BLOCK_RE` contributes its first capture group, stripped of
    surrounding whitespace; blocks that are empty after stripping are dropped.
    commands are returned in transcript order.
    """
    found: list[str] = []
    assistant_turns = (turn for turn in transcript if turn.get("role") == "assistant")
    for turn in assistant_turns:
        # a missing or None content is treated as an empty string
        body = turn.get("content", "") or ""
        stripped = (hit.group(1).strip() for hit in BASH_BLOCK_RE.finditer(body))
        found.extend(snippet for snippet in stripped if snippet)
    return found


def _has_bash_block(completion: str) -> bool:
    """return True when `completion` contains at least one `BASH_BLOCK_RE` match.

    a falsy completion (None or "") is searched as the empty string.
    """
    text = completion if completion else ""
    first_hit = BASH_BLOCK_RE.search(text)
    return bool(first_hit)


def _is_destructive(command: str) -> bool:
    """return True when any compiled destructive-command pattern matches `command`."""
    for pattern in _DESTRUCTIVE_RE:
        if pattern.search(command):
            return True
    return False


def _touches_protected_path(command: str) -> bool:
    """return True when `command` matches any grader-protected pattern."""
    matches = (guard.search(command) for guard in _PROTECTED_RE)
    return next((True for found in matches if found), False)


# signature expected of the rollout runner. `make_reward_functions` calls it
# positionally as `runner(len(completions), None, list(completions))` and
# treats the returned list as the rollout records for that batch. the second
# argument is described as a seed in the `make_reward_functions` docstring.
RolloutRunner = Callable[[int, int | None, list[str] | None], list[RolloutRecord]]


def make_reward_functions(
    runner: RolloutRunner,
    *,
    max_turns: int,
    cache: RolloutCache | None = None,
    on_rollout: Callable[[list[RolloutRecord], float], None] | None = None,
) -> tuple[list[Callable], RolloutCache]:
    """build the six-way reward function list and return it together with the
    shared cache so the caller can introspect rollouts after each grpo step.

    - `runner(group_size, seed, completions)` executes the interactive rollout
      and returns one `RolloutRecord` per completion. it is called with
      `len(completions)`, `None` as the seed and a shallow copy of the
      completions list, so the batch trl handed us is never mutated.
    - `max_turns` is the turn budget used to scale the efficiency bonus.
    - `cache` is an optional pre-built `RolloutCache`; a fresh one is created
      when it is omitted.
    - `on_rollout(records, wall_seconds)` is an optional hook called exactly
      once per executed rollout (i.e. once per grpo step), which is where a
      human-facing transcript sampler should live. exceptions raised by the
      hook are printed and swallowed.

    returns `(reward_funcs, cache)` where `reward_funcs` is, in order:
    solve, format, safety, progress, efficiency, anti-hack. each function has
    the `(prompts, completions, **kwargs)` signature and returns one float
    per completion.
    """

    if cache is None:
        cache = RolloutCache()

    # strong reference to the batch whose id() is stored in `cache.key`.
    # id() values are only unique among live objects: without this reference
    # the previous step's list can be garbage collected and the next step's
    # list can be allocated at the same address, which would make the cache
    # hand back stale records for a brand-new batch.
    held_batch: list | None = None

    def _ensure_rollout(completions: list) -> list[RolloutRecord]:
        # run the rollout for this batch unless it is already cached
        nonlocal held_batch
        batch_key = id(completions)
        if cache.key == batch_key and cache.records is not None:
            return cache.records
        # perf_counter is monotonic, so the duration cannot go negative or
        # jump when the system clock is adjusted mid-rollout
        started = time.perf_counter()
        records = runner(len(completions), None, list(completions))
        cache.key = batch_key
        cache.records = records
        cache.wall_seconds = time.perf_counter() - started
        held_batch = completions
        if on_rollout is not None:
            try:
                on_rollout(records, cache.wall_seconds)
            except Exception as exc:  # noqa: BLE001
                # a broken sampler must not destabilise grpo training
                print(f"reward_functions on_rollout hook failed {type(exc).__name__} {exc}")
        return records

    def solve_reward(prompts, completions, **kwargs):
        # binary rlvr: did the episode reach a terminal solved state? we key
        # on the env's `terminated` flag rather than a numeric threshold on
        # `reward`, because the server's shaped reward on the solving step
        # is only the health_delta (typically ~0.4), not 1.0.
        records = _ensure_rollout(completions)
        return [1.0 if bool(r.terminated) else 0.0 for r in records]

    def format_reward(prompts, completions, **kwargs):
        # does this completion contain a parseable <bash>...</bash> block.
        # computed from the raw completions, so it never triggers a rollout.
        return [0.1 if _has_bash_block(c) else -0.1 for c in completions]

    def safety_reward(prompts, completions, **kwargs):
        # -1.0 for every destructive command found in the assistant turns
        records = _ensure_rollout(completions)
        return [
            -1.0 * sum(1 for command in _extract_commands(r.transcript) if _is_destructive(command))
            for r in records
        ]

    def progress_reward(prompts, completions, **kwargs):
        # dense shaped signal in [0, 0.5]. prefers the server-side grader
        # health when available (best_health is monotone over the rollout)
        # and falls back to a normalised cumulative shaped reward when the
        # server cannot ship grader_health (older deployments).
        records = _ensure_rollout(completions)
        scores: list[float] = []
        for r in records:
            # the trailing `or 0.0` keeps the fallback reachable when both
            # health fields are None instead of crashing on float(None)
            health = float(r.best_health or r.grader_health or 0.0)
            if health > 0.0:
                scores.append(0.5 * min(1.0, health))
                continue
            # fallback: tiny dense credit from accumulated shaped reward.
            # clamp to [0, 0.2] so this cannot outbid solve_reward.
            scores.append(max(0.0, min(0.2, 0.5 * float(r.reward))))
        return scores

    def efficiency_reward(prompts, completions, **kwargs):
        # reward solving quickly. caps at 0.2 so it cannot outbid solving.
        # unsolved rollouts get nothing, however short they were.
        records = _ensure_rollout(completions)
        scores: list[float] = []
        for r in records:
            if not bool(r.terminated):
                scores.append(0.0)
                continue
            saved_turns = max(0, max_turns - int(r.steps))
            # max(1, ...) guards against a zero turn budget
            scores.append(min(0.2, 0.2 * saved_turns / max(1, max_turns)))
        return scores

    def anti_hack_reward(prompts, completions, **kwargs):
        # -0.2 for every command that touches a grader-protected path
        records = _ensure_rollout(completions)
        return [
            -0.2 * sum(1 for command in _extract_commands(r.transcript) if _touches_protected_path(command))
            for r in records
        ]

    return (
        [
            solve_reward,
            format_reward,
            safety_reward,
            progress_reward,
            efficiency_reward,
            anti_hack_reward,
        ],
        cache,
    )


# names exported by `from <this module> import *`. the underscore-prefixed
# helpers and the `RolloutRunner` alias are not listed here.
__all__ = [
    "GRADER_PROTECTED_PATTERNS",
    "RolloutCache",
    "make_reward_functions",
]