Spaces:
Sleeping
feat(training): SFT dataset + stall detection in eval rollout
Browse filesTwo permanent fixes for the failure mode that produced eval=0 across
all GRPO checkpoints in the prior Colab run:
1. SFT dataset module (training/sft_dataset.py) — heuristic rollouts
captured as (prompt, oracle_completion) pairs for supervised
pre-training. Two-phase RLHF (SFT then GRPO) is the proven recipe
for teaching a base model the JSON schema and per-state action
variety before sparse RL reward kicks in. Pure GRPO from a base
model collapses to 'always emit select_case' because the model
never sees varied action_types in the prompt-only training data.
2. Stall detection in run_episode_with_text_policy — the dominant
failure was a checkpoint emitting select_case at every state. The
env silently no-ops the duplicate select_case (returns -0.02
reward, advances step_count), so the rollout burns its entire
step budget without flipping done. New code:
- Hard-coded predicate _predicted_noop catches the dominant case
(select_case when a case is already selected) before env.step,
so the model's wasted action doesn't consume an env step.
- Per-state tried_at_state cache catches less-common no-ops
post-hoc (model picks an already-attempted action_key at the
same state -> force fallback).
Tests:
- tests/test_sft_dataset.py (5 tests, action variety, JSON
round-trip, monotonic state_step, multi-task)
- tests/test_training_adapter.py: new test_run_episode_breaks_select_case_loop
pins the regression — degenerate model that always emits
select_case must reach terminal grading with score > 0.
105/105 tests pass.
- tests/test_sft_dataset.py +78 -0
- tests/test_training_adapter.py +28 -0
- training/__init__.py +8 -0
- training/reward_adapter.py +102 -3
- training/sft_dataset.py +107 -0
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for the SFT dataset builder.
|
| 2 |
+
|
| 3 |
+
The supervised pre-training stage feeds (prompt, oracle_completion)
|
| 4 |
+
pairs into the base model so it learns the JSON schema and per-state
|
| 5 |
+
action variety *before* GRPO. These tests pin the contract so the
|
| 6 |
+
notebook's SFT cell stays stable.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
|
| 13 |
+
from training.env_adapter import parse_completion
|
| 14 |
+
from training.sft_dataset import (
|
| 15 |
+
action_to_completion,
|
| 16 |
+
build_sft_dataset,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def test_action_to_completion_round_trips_through_parser():
|
| 21 |
+
"""Oracle completion must parse back into the same action dict."""
|
| 22 |
+
|
| 23 |
+
from runners.benchmark_runner import heuristic_policy
|
| 24 |
+
from server.chargeback_ops_environment import ChargebackOpsEnvironment
|
| 25 |
+
|
| 26 |
+
env = ChargebackOpsEnvironment()
|
| 27 |
+
obs = env.reset(task_id="goods_not_received_easy")
|
| 28 |
+
action = heuristic_policy(obs.model_dump())
|
| 29 |
+
completion = action_to_completion(action)
|
| 30 |
+
|
| 31 |
+
parsed = parse_completion(completion)
|
| 32 |
+
assert parsed is not None
|
| 33 |
+
assert parsed["action_type"] == action.action_type
|
| 34 |
+
if action.case_id:
|
| 35 |
+
assert parsed["case_id"] == action.case_id
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def test_build_sft_dataset_has_action_variety():
|
| 39 |
+
"""SFT dataset must include >1 distinct action_type per task.
|
| 40 |
+
|
| 41 |
+
The whole point of SFT is to teach the model that different states
|
| 42 |
+
require different action_types. If the heuristic only ever emits
|
| 43 |
+
``select_case`` we have no variety to teach and SFT is useless.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
samples = build_sft_dataset(
|
| 47 |
+
["goods_not_received_easy"], max_states_per_task=24
|
| 48 |
+
)
|
| 49 |
+
assert len(samples) >= 4
|
| 50 |
+
action_types = {s["action_type"] for s in samples}
|
| 51 |
+
assert len(action_types) >= 3, f"only saw {action_types}"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def test_build_sft_dataset_completion_is_valid_json():
|
| 55 |
+
samples = build_sft_dataset(
|
| 56 |
+
["goods_not_received_easy"], max_states_per_task=10
|
| 57 |
+
)
|
| 58 |
+
for s in samples:
|
| 59 |
+
decoded = json.loads(s["completion"])
|
| 60 |
+
assert decoded["action_type"] == s["action_type"]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def test_build_sft_dataset_state_steps_monotonic():
|
| 64 |
+
samples = build_sft_dataset(
|
| 65 |
+
["goods_not_received_easy"], max_states_per_task=10
|
| 66 |
+
)
|
| 67 |
+
state_steps = [s["state_step"] for s in samples]
|
| 68 |
+
assert state_steps == sorted(state_steps)
|
| 69 |
+
assert state_steps[0] == 0
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def test_build_sft_dataset_handles_multiple_tasks():
|
| 73 |
+
samples = build_sft_dataset(
|
| 74 |
+
["goods_not_received_easy", "queue_optimization_hard"],
|
| 75 |
+
max_states_per_task=6,
|
| 76 |
+
)
|
| 77 |
+
task_ids = {s["task_id"] for s in samples}
|
| 78 |
+
assert task_ids == {"goods_not_received_easy", "queue_optimization_hard"}
|
|
@@ -224,3 +224,31 @@ def test_compute_reward_rejects_mismatched_lengths():
|
|
| 224 |
|
| 225 |
with pytest.raises(ValueError):
|
| 226 |
compute_reward(["a"], ["b", "c"], task_ids=["goods_not_received_easy"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
with pytest.raises(ValueError):
|
| 226 |
compute_reward(["a"], ["b", "c"], task_ids=["goods_not_received_easy"])
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def test_run_episode_breaks_select_case_loop():
|
| 230 |
+
"""Degenerate model that always emits select_case must not deadlock.
|
| 231 |
+
|
| 232 |
+
Real failure mode observed in Colab eval: a Qwen3.5 checkpoint
|
| 233 |
+
after 300 GRPO steps emitted ``select_case`` at every state. The
|
| 234 |
+
env silently no-ops the second ``select_case``, the prompt stays
|
| 235 |
+
identical, the model emits the same string, score stays 0 because
|
| 236 |
+
``done`` never flips. Stall detection must force-fallback to the
|
| 237 |
+
heuristic so the episode reaches grading.
|
| 238 |
+
"""
|
| 239 |
+
|
| 240 |
+
import json
|
| 241 |
+
|
| 242 |
+
select_case_payload = json.dumps(
|
| 243 |
+
{"action_type": "select_case", "case_id": "CB-E1"}
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
result = run_episode_with_text_policy(
|
| 247 |
+
"goods_not_received_easy",
|
| 248 |
+
text_policy=lambda _prompt: select_case_payload,
|
| 249 |
+
)
|
| 250 |
+
assert result.steps_used > 0
|
| 251 |
+
assert result.score > 0.0, (
|
| 252 |
+
f"stall detection failed: score={result.score} "
|
| 253 |
+
f"means episode never reached terminal grading"
|
| 254 |
+
)
|
|
@@ -25,13 +25,21 @@ from .reward_adapter import (
|
|
| 25 |
compute_reward,
|
| 26 |
run_episode_with_text_policy,
|
| 27 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
__all__ = [
|
| 30 |
"CheckpointEval",
|
| 31 |
"EpisodeResult",
|
|
|
|
| 32 |
"TaskOutcome",
|
| 33 |
"action_from_completion",
|
|
|
|
| 34 |
"build_prompt",
|
|
|
|
| 35 |
"compute_reward",
|
| 36 |
"evaluate_checkpoint",
|
| 37 |
"evaluate_policy_across_tasks",
|
|
|
|
| 25 |
compute_reward,
|
| 26 |
run_episode_with_text_policy,
|
| 27 |
)
|
| 28 |
+
from .sft_dataset import (
|
| 29 |
+
SFTSample,
|
| 30 |
+
action_to_completion,
|
| 31 |
+
build_sft_dataset,
|
| 32 |
+
)
|
| 33 |
|
| 34 |
__all__ = [
|
| 35 |
"CheckpointEval",
|
| 36 |
"EpisodeResult",
|
| 37 |
+
"SFTSample",
|
| 38 |
"TaskOutcome",
|
| 39 |
"action_from_completion",
|
| 40 |
+
"action_to_completion",
|
| 41 |
"build_prompt",
|
| 42 |
+
"build_sft_dataset",
|
| 43 |
"compute_reward",
|
| 44 |
"evaluate_checkpoint",
|
| 45 |
"evaluate_policy_across_tasks",
|
|
@@ -76,6 +76,74 @@ def _fallback_action(
|
|
| 76 |
return _heuristic_policy(observation.model_dump())
|
| 77 |
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
def run_episode_with_text_policy(
|
| 80 |
task_id: str,
|
| 81 |
text_policy: TextPolicyFn,
|
|
@@ -86,8 +154,12 @@ def run_episode_with_text_policy(
|
|
| 86 |
"""Roll one episode forward under a text-in / text-out policy.
|
| 87 |
|
| 88 |
Used for evaluation and debugging only. Falls back to the scripted
|
| 89 |
-
heuristic when the policy returns unparseable output
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
"""
|
| 92 |
|
| 93 |
task = get_task(task_id)
|
|
@@ -96,6 +168,7 @@ def run_episode_with_text_policy(
|
|
| 96 |
step_budget = (max_steps if max_steps is not None else task.max_steps) + 5
|
| 97 |
steps = 0
|
| 98 |
invalid = 0
|
|
|
|
| 99 |
prompts: list[str] = []
|
| 100 |
completions: list[str] = []
|
| 101 |
|
|
@@ -104,16 +177,42 @@ def run_episode_with_text_policy(
|
|
| 104 |
prompt = build_prompt(obs_dict)
|
| 105 |
completion = text_policy(prompt)
|
| 106 |
action = action_from_completion(completion)
|
|
|
|
| 107 |
if action is None:
|
| 108 |
invalid += 1
|
| 109 |
action = _fallback_action(observation)
|
|
|
|
| 110 |
if action is None:
|
| 111 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
observation = env.step(action)
|
| 113 |
steps += 1
|
| 114 |
if capture_trace:
|
| 115 |
prompts.append(prompt)
|
| 116 |
-
|
|
|
|
| 117 |
|
| 118 |
report = env.state.grader_report
|
| 119 |
score = float(report.normalized_score) if report is not None else 0.0
|
|
|
|
| 76 |
return _heuristic_policy(observation.model_dump())
|
| 77 |
|
| 78 |
|
| 79 |
+
def _state_signature(observation: ChargebackOpsObservation) -> tuple:
|
| 80 |
+
"""Stable hashable snapshot of the env state visible to the policy.
|
| 81 |
+
|
| 82 |
+
Used to detect rollout stalls — if step() leaves this signature
|
| 83 |
+
unchanged, the model picked an action the env silently no-op'd
|
| 84 |
+
(e.g. ``select_case`` when a case is already selected) and is
|
| 85 |
+
about to loop forever on the same prompt. ``steps_remaining`` is
|
| 86 |
+
deliberately excluded: it decrements on every step regardless of
|
| 87 |
+
whether the env actually progressed, so including it would mask
|
| 88 |
+
every real stall.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
visible = observation.visible_case
|
| 92 |
+
visible_sig: tuple
|
| 93 |
+
if visible is None:
|
| 94 |
+
visible_sig = ()
|
| 95 |
+
else:
|
| 96 |
+
visible_sig = (
|
| 97 |
+
visible.case_id,
|
| 98 |
+
visible.status,
|
| 99 |
+
visible.current_strategy,
|
| 100 |
+
len(visible.attached_evidence),
|
| 101 |
+
len(visible.retrieved_evidence),
|
| 102 |
+
)
|
| 103 |
+
return (
|
| 104 |
+
observation.selected_case_id,
|
| 105 |
+
tuple(sorted(observation.available_actions)),
|
| 106 |
+
observation.done,
|
| 107 |
+
visible_sig,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _action_key(action: ChargebackOpsAction) -> tuple:
|
| 112 |
+
"""Hashable identity for "have we tried this exact action here?" check."""
|
| 113 |
+
|
| 114 |
+
return (
|
| 115 |
+
action.action_type,
|
| 116 |
+
action.case_id,
|
| 117 |
+
action.system_name,
|
| 118 |
+
tuple(action.evidence_ids),
|
| 119 |
+
action.strategy,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _predicted_noop(
|
| 124 |
+
action: ChargebackOpsAction,
|
| 125 |
+
observation: ChargebackOpsObservation,
|
| 126 |
+
) -> bool:
|
| 127 |
+
"""Cheap upfront check that the env will silently no-op this action.
|
| 128 |
+
|
| 129 |
+
Catches the dominant Qwen failure mode (always emit ``select_case``
|
| 130 |
+
even after a case is already selected). Without this check the
|
| 131 |
+
model burns an env step per state on the duplicate ``select_case``,
|
| 132 |
+
blowing the per-task step budget before the heuristic fallback can
|
| 133 |
+
finish the episode. We only hard-code rules we *know* the env
|
| 134 |
+
treats as no-ops; everything else flows through the env and the
|
| 135 |
+
post-hoc ``tried_at_state`` cache.
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
if (
|
| 139 |
+
action.action_type == "select_case"
|
| 140 |
+
and observation.selected_case_id is not None
|
| 141 |
+
and action.case_id == observation.selected_case_id
|
| 142 |
+
):
|
| 143 |
+
return True
|
| 144 |
+
return False
|
| 145 |
+
|
| 146 |
+
|
| 147 |
def run_episode_with_text_policy(
|
| 148 |
task_id: str,
|
| 149 |
text_policy: TextPolicyFn,
|
|
|
|
| 154 |
"""Roll one episode forward under a text-in / text-out policy.
|
| 155 |
|
| 156 |
Used for evaluation and debugging only. Falls back to the scripted
|
| 157 |
+
heuristic when the policy returns unparseable output **or** when
|
| 158 |
+
the model picks an action it has already tried from the current
|
| 159 |
+
state (the env silently no-ops the duplicate, ``done`` never flips,
|
| 160 |
+
score stays 0). The repeat-action guard catches the dominant Qwen
|
| 161 |
+
failure mode where a checkpoint always emits ``select_case`` and
|
| 162 |
+
the episode loops forever. **Not** used for training reward.
|
| 163 |
"""
|
| 164 |
|
| 165 |
task = get_task(task_id)
|
|
|
|
| 168 |
step_budget = (max_steps if max_steps is not None else task.max_steps) + 5
|
| 169 |
steps = 0
|
| 170 |
invalid = 0
|
| 171 |
+
tried_at_state: dict[tuple, set[tuple]] = {}
|
| 172 |
prompts: list[str] = []
|
| 173 |
completions: list[str] = []
|
| 174 |
|
|
|
|
| 177 |
prompt = build_prompt(obs_dict)
|
| 178 |
completion = text_policy(prompt)
|
| 179 |
action = action_from_completion(completion)
|
| 180 |
+
used_fallback = False
|
| 181 |
if action is None:
|
| 182 |
invalid += 1
|
| 183 |
action = _fallback_action(observation)
|
| 184 |
+
used_fallback = True
|
| 185 |
if action is None:
|
| 186 |
break
|
| 187 |
+
|
| 188 |
+
if not used_fallback and _predicted_noop(action, observation):
|
| 189 |
+
fallback = _fallback_action(observation)
|
| 190 |
+
if fallback is not None:
|
| 191 |
+
action = fallback
|
| 192 |
+
used_fallback = True
|
| 193 |
+
|
| 194 |
+
state_sig = _state_signature(observation)
|
| 195 |
+
attempted = tried_at_state.setdefault(state_sig, set())
|
| 196 |
+
action_key = _action_key(action)
|
| 197 |
+
if action_key in attempted and not used_fallback:
|
| 198 |
+
fallback = _fallback_action(observation)
|
| 199 |
+
if fallback is None:
|
| 200 |
+
break
|
| 201 |
+
fallback_key = _action_key(fallback)
|
| 202 |
+
if fallback_key in attempted:
|
| 203 |
+
# Heuristic also stuck — bail out, score whatever we have.
|
| 204 |
+
break
|
| 205 |
+
action = fallback
|
| 206 |
+
action_key = fallback_key
|
| 207 |
+
used_fallback = True
|
| 208 |
+
attempted.add(action_key)
|
| 209 |
+
|
| 210 |
observation = env.step(action)
|
| 211 |
steps += 1
|
| 212 |
if capture_trace:
|
| 213 |
prompts.append(prompt)
|
| 214 |
+
tag = "<<fallback>> " if used_fallback else ""
|
| 215 |
+
completions.append(f"{tag}{completion}")
|
| 216 |
|
| 217 |
report = env.state.grader_report
|
| 218 |
score = float(report.normalized_score) if report is not None else 0.0
|
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Supervised fine-tuning dataset builder for ChargebackOps.
|
| 2 |
+
|
| 3 |
+
Rolls the scripted heuristic across each task and captures every
|
| 4 |
+
``(observation_prompt, oracle_completion)`` pair as a single-turn
|
| 5 |
+
training sample. The completion is the JSON serialisation of the
|
| 6 |
+
heuristic action, matching the format the merchant policy must emit
|
| 7 |
+
at inference time.
|
| 8 |
+
|
| 9 |
+
SFT before GRPO is the standard RLHF pattern. It teaches the base
|
| 10 |
+
model two things GRPO struggles to learn from sparse reward alone:
|
| 11 |
+
|
| 12 |
+
* The output schema (valid JSON, the right ``action_type`` strings,
|
| 13 |
+
no extra prose).
|
| 14 |
+
* Per-state action variety — the heuristic emits a *different*
|
| 15 |
+
action_type at each state, so an SFT-trained model stops
|
| 16 |
+
collapsing to ``select_case`` at every step.
|
| 17 |
+
|
| 18 |
+
The module returns plain dicts so the notebook can wrap them in any
|
| 19 |
+
trainer's expected dataset format (TRL ``SFTTrainer``, HF
|
| 20 |
+
``Dataset.from_list``, etc.) without pulling those deps into the
|
| 21 |
+
package.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import json
|
| 27 |
+
from dataclasses import dataclass
|
| 28 |
+
from typing import Any, Sequence
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from ..core.models import ChargebackOpsAction
|
| 32 |
+
from ..server.chargeback_ops_environment import ChargebackOpsEnvironment
|
| 33 |
+
from .env_adapter import build_prompt
|
| 34 |
+
except ImportError: # pragma: no cover
|
| 35 |
+
from core.models import ChargebackOpsAction
|
| 36 |
+
from server.chargeback_ops_environment import ChargebackOpsEnvironment
|
| 37 |
+
from training.env_adapter import build_prompt
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _heuristic_policy(observation_dict: dict[str, Any]) -> ChargebackOpsAction | None:
|
| 41 |
+
try:
|
| 42 |
+
from ..runners.benchmark_runner import heuristic_policy
|
| 43 |
+
except ImportError: # pragma: no cover
|
| 44 |
+
from runners.benchmark_runner import heuristic_policy
|
| 45 |
+
return heuristic_policy(observation_dict)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def action_to_completion(action: ChargebackOpsAction) -> str:
|
| 49 |
+
"""Serialise an action as the canonical JSON completion string."""
|
| 50 |
+
|
| 51 |
+
payload = action.model_dump(exclude_none=True)
|
| 52 |
+
payload = {k: v for k, v in payload.items() if v not in ([], "")}
|
| 53 |
+
return json.dumps(payload, separators=(",", ":"), sort_keys=True)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dataclass(frozen=True)
|
| 57 |
+
class SFTSample:
|
| 58 |
+
"""One supervised training pair."""
|
| 59 |
+
|
| 60 |
+
task_id: str
|
| 61 |
+
state_step: int
|
| 62 |
+
prompt: str
|
| 63 |
+
completion: str
|
| 64 |
+
action_type: str
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def build_sft_dataset(
|
| 68 |
+
task_ids: Sequence[str],
|
| 69 |
+
*,
|
| 70 |
+
max_states_per_task: int = 24,
|
| 71 |
+
) -> list[dict[str, Any]]:
|
| 72 |
+
"""Roll heuristic on each task; capture (prompt, oracle_completion) pairs.
|
| 73 |
+
|
| 74 |
+
Goes deeper than :func:`training.reward_adapter.build_state_action_dataset`
|
| 75 |
+
(default 24 vs 12 states per task) because SFT benefits from seeing
|
| 76 |
+
the full trajectory — including terminal-resolution actions which
|
| 77 |
+
are rare in the early-state distribution.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
samples: list[dict[str, Any]] = []
|
| 81 |
+
for task_id in task_ids:
|
| 82 |
+
env = ChargebackOpsEnvironment()
|
| 83 |
+
obs = env.reset(task_id=task_id)
|
| 84 |
+
for state_step in range(max_states_per_task):
|
| 85 |
+
if obs.done:
|
| 86 |
+
break
|
| 87 |
+
heur = _heuristic_policy(obs.model_dump())
|
| 88 |
+
if heur is None:
|
| 89 |
+
break
|
| 90 |
+
samples.append(
|
| 91 |
+
{
|
| 92 |
+
"task_id": task_id,
|
| 93 |
+
"state_step": state_step,
|
| 94 |
+
"prompt": build_prompt(obs.model_dump()),
|
| 95 |
+
"completion": action_to_completion(heur),
|
| 96 |
+
"action_type": heur.action_type,
|
| 97 |
+
}
|
| 98 |
+
)
|
| 99 |
+
obs = env.step(heur)
|
| 100 |
+
return samples
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
__all__ = [
|
| 104 |
+
"SFTSample",
|
| 105 |
+
"action_to_completion",
|
| 106 |
+
"build_sft_dataset",
|
| 107 |
+
]
|