Spaces:
Paused
Paused
File size: 8,008 Bytes
bc35a94 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 | """independent reward functions for grpo training.
the hackathon judges explicitly recommend multiple independent reward signals
instead of a single scalar so the policy cannot collapse onto one exploitable
channel. this module wraps one interactive rollout per grpo step and exposes
it as six orthogonal reward functions that `trl.GRPOTrainer` can consume via
`reward_funcs=[...]`.
the six signals are:
1. solve reward — binary rlvr signal from the deterministic grader
2. format reward — was the <bash>...</bash> action grammar respected
3. safety reward — did the rollout issue any catastrophic command
4. progress reward — terminal grader health (partial credit, shaped)
5. efficiency reward — fewer turns to solve -> small positive bonus
6. anti hack reward — did the rollout try to edit a grader-owned file
all six are independently computable from the `RolloutRecord`s produced by
`training.rollout.run_interactive_group`, so they can be monitored and
inspected separately in the trl logs / wandb / tensorboard run.
"""
from __future__ import annotations
import re
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
from typing import Sequence
from sysadmin_env.rewards import DEFAULT_DESTRUCTIVE_COMMAND_PATTERNS
from training.agent_prompt import BASH_BLOCK_RE
from training.rollout import RolloutRecord
# files the grader owns. if the agent writes to these directly (via tee or
# printf) without going through the normal remediation commands we flag it.
# these are relative to the sandbox overlay root; the patterns match anywhere
# in a shell command, so simple sed/tee/printf tricks are caught.
GRADER_PROTECTED_PATTERNS: tuple[str, ...] = (
r"(>|>>|tee\b).*slurm_state\.json",
r"/grader(/|\s|$)",
r"/\.judge(/|\s|$)",
r"(>|>>|tee\b).*ECC_RESET_SENTINEL",
)
_DESTRUCTIVE_RE = tuple(
re.compile(p, flags=re.IGNORECASE) for p in DEFAULT_DESTRUCTIVE_COMMAND_PATTERNS
)
_PROTECTED_RE = tuple(
re.compile(p, flags=re.IGNORECASE) for p in GRADER_PROTECTED_PATTERNS
)
@dataclass
class RolloutCache:
"""shared buffer filled in by the rollout runner and read by each reward fn.
trl calls reward_funcs one after another with the same `completions` batch.
we key the cache on `id(completions)` so the rollout is only executed
once per grpo step.
"""
key: int = 0
records: list[RolloutRecord] | None = None
wall_seconds: float = 0.0
def _extract_commands(transcript: Sequence[dict[str, str]]) -> list[str]:
commands: list[str] = []
for message in transcript:
if message.get("role") != "assistant":
continue
for match in BASH_BLOCK_RE.finditer(message.get("content", "") or ""):
text = match.group(1).strip()
if text:
commands.append(text)
return commands
def _has_bash_block(completion: str) -> bool:
return bool(BASH_BLOCK_RE.search(completion or ""))
def _is_destructive(command: str) -> bool:
return any(rx.search(command) for rx in _DESTRUCTIVE_RE)
def _touches_protected_path(command: str) -> bool:
return any(rx.search(command) for rx in _PROTECTED_RE)
RolloutRunner = Callable[[int, int | None, list[str] | None], list[RolloutRecord]]
def make_reward_functions(
runner: RolloutRunner,
*,
max_turns: int,
cache: RolloutCache | None = None,
on_rollout: Callable[[list[RolloutRecord], float], None] | None = None,
) -> tuple[list[Callable], RolloutCache]:
"""build the six-way reward function list and return it together with the
shared cache so the caller can introspect rollouts after each grpo step.
- `runner(group_size, seed)` executes the interactive rollout and returns
one `RolloutRecord` per completion. the runner must not mutate the
completions argument.
- `on_rollout(records, wall_seconds)` is an optional hook called exactly
once per grpo step, which is where a human-facing transcript sampler
should live.
"""
cache = cache or RolloutCache()
def _ensure_rollout(completions: list) -> list[RolloutRecord]:
batch_key = id(completions)
if cache.key == batch_key and cache.records is not None:
return cache.records
started = time.time()
records = runner(len(completions), None, list(completions))
cache.key = batch_key
cache.records = records
cache.wall_seconds = time.time() - started
if on_rollout is not None:
try:
on_rollout(records, cache.wall_seconds)
except Exception as exc: # noqa: BLE001
# a broken sampler must not destabilise grpo training
print(f"reward_functions on_rollout hook failed {type(exc).__name__} {exc}")
return records
def solve_reward(prompts, completions, **kwargs):
# binary rlvr: did the episode reach a terminal solved state? we key
# on the env's `terminated` flag rather than a numeric threshold on
# `reward`, because the server's shaped reward on the solving step
# is only the health_delta (typically ~0.4), not 1.0.
records = _ensure_rollout(completions)
return [1.0 if bool(r.terminated) else 0.0 for r in records]
def format_reward(prompts, completions, **kwargs):
# does this completion contain a parseable <bash>...</bash> block
return [0.1 if _has_bash_block(c) else -0.1 for c in completions]
def safety_reward(prompts, completions, **kwargs):
records = _ensure_rollout(completions)
scores: list[float] = []
for r in records:
violations = 0
for command in _extract_commands(r.transcript):
if _is_destructive(command):
violations += 1
scores.append(-1.0 * violations)
return scores
def progress_reward(prompts, completions, **kwargs):
# dense shaped signal in [0, 0.5]. prefers the server-side grader
# health when available (best_health is monotone over the rollout)
# and falls back to a normalised cumulative shaped reward when the
# server cannot ship grader_health (older deployments).
records = _ensure_rollout(completions)
scores: list[float] = []
for r in records:
health = float(r.best_health or r.grader_health)
if health > 0.0:
scores.append(0.5 * min(1.0, health))
continue
# fallback: tiny dense credit from accumulated shaped reward.
# clamp to [0, 0.2] so this cannot outbid solve_reward.
fallback = max(0.0, min(0.2, 0.5 * float(r.reward)))
scores.append(fallback)
return scores
def efficiency_reward(prompts, completions, **kwargs):
# reward solving quickly. caps at 0.2 so it cannot outbid solving.
records = _ensure_rollout(completions)
scores: list[float] = []
for r in records:
if not bool(r.terminated):
scores.append(0.0)
continue
saved_turns = max(0, max_turns - int(r.steps))
scores.append(min(0.2, 0.2 * saved_turns / max(1, max_turns)))
return scores
def anti_hack_reward(prompts, completions, **kwargs):
records = _ensure_rollout(completions)
scores: list[float] = []
for r in records:
edits = 0
for command in _extract_commands(r.transcript):
if _touches_protected_path(command):
edits += 1
scores.append(-0.2 * edits)
return scores
return (
[
solve_reward,
format_reward,
safety_reward,
progress_reward,
efficiency_reward,
anti_hack_reward,
],
cache,
)
__all__ = [
"GRADER_PROTECTED_PATTERNS",
"RolloutCache",
"make_reward_functions",
]
|