mitudrudutta commited on
Commit
02a6a9f
·
1 Parent(s): 243aa68

feat(training): SFT dataset + stall detection in eval rollout

Browse files

Two permanent fixes for the failure mode that produced eval=0 across
all GRPO checkpoints in the prior Colab run:

1. SFT dataset module (training/sft_dataset.py) — heuristic rollouts
captured as (prompt, oracle_completion) pairs for supervised
pre-training. Two-phase RLHF (SFT then GRPO) is the proven recipe
for teaching a base model the JSON schema and per-state action
variety before sparse RL reward kicks in. Pure GRPO from a base
model collapses to 'always emit select_case' because the model
never sees varied action_types in the prompt-only training data.

2. Stall detection in run_episode_with_text_policy — the dominant
failure was a checkpoint emitting select_case at every state. The
env silently no-ops the duplicate select_case (returns -0.02
reward, advances step_count), so the rollout burns its entire
step budget without flipping done. New code:
- Hard-coded predicate _predicted_noop catches the dominant case
(select_case when a case is already selected) before env.step,
so the model's wasted action doesn't consume an env step.
- Per-state tried_at_state cache catches less-common no-ops
post-hoc (model picks an already-attempted action_key at the
same state -> force fallback).

Tests:
- tests/test_sft_dataset.py (5 tests, action variety, JSON
round-trip, monotonic state_step, multi-task)
- tests/test_training_adapter.py: new test_run_episode_breaks_select_case_loop
pins the regression — degenerate model that always emits
select_case must reach terminal grading with score > 0.

105/105 tests pass.

tests/test_sft_dataset.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for the SFT dataset builder.
2
+
3
+ The supervised pre-training stage feeds (prompt, oracle_completion)
4
+ pairs into the base model so it learns the JSON schema and per-state
5
+ action variety *before* GRPO. These tests pin the contract so the
6
+ notebook's SFT cell stays stable.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import json
12
+
13
+ from training.env_adapter import parse_completion
14
+ from training.sft_dataset import (
15
+ action_to_completion,
16
+ build_sft_dataset,
17
+ )
18
+
19
+
20
+ def test_action_to_completion_round_trips_through_parser():
21
+ """Oracle completion must parse back into the same action dict."""
22
+
23
+ from runners.benchmark_runner import heuristic_policy
24
+ from server.chargeback_ops_environment import ChargebackOpsEnvironment
25
+
26
+ env = ChargebackOpsEnvironment()
27
+ obs = env.reset(task_id="goods_not_received_easy")
28
+ action = heuristic_policy(obs.model_dump())
29
+ completion = action_to_completion(action)
30
+
31
+ parsed = parse_completion(completion)
32
+ assert parsed is not None
33
+ assert parsed["action_type"] == action.action_type
34
+ if action.case_id:
35
+ assert parsed["case_id"] == action.case_id
36
+
37
+
38
+ def test_build_sft_dataset_has_action_variety():
39
+ """SFT dataset must include >1 distinct action_type per task.
40
+
41
+ The whole point of SFT is to teach the model that different states
42
+ require different action_types. If the heuristic only ever emits
43
+ ``select_case`` we have no variety to teach and SFT is useless.
44
+ """
45
+
46
+ samples = build_sft_dataset(
47
+ ["goods_not_received_easy"], max_states_per_task=24
48
+ )
49
+ assert len(samples) >= 4
50
+ action_types = {s["action_type"] for s in samples}
51
+ assert len(action_types) >= 3, f"only saw {action_types}"
52
+
53
+
54
+ def test_build_sft_dataset_completion_is_valid_json():
55
+ samples = build_sft_dataset(
56
+ ["goods_not_received_easy"], max_states_per_task=10
57
+ )
58
+ for s in samples:
59
+ decoded = json.loads(s["completion"])
60
+ assert decoded["action_type"] == s["action_type"]
61
+
62
+
63
+ def test_build_sft_dataset_state_steps_monotonic():
64
+ samples = build_sft_dataset(
65
+ ["goods_not_received_easy"], max_states_per_task=10
66
+ )
67
+ state_steps = [s["state_step"] for s in samples]
68
+ assert state_steps == sorted(state_steps)
69
+ assert state_steps[0] == 0
70
+
71
+
72
+ def test_build_sft_dataset_handles_multiple_tasks():
73
+ samples = build_sft_dataset(
74
+ ["goods_not_received_easy", "queue_optimization_hard"],
75
+ max_states_per_task=6,
76
+ )
77
+ task_ids = {s["task_id"] for s in samples}
78
+ assert task_ids == {"goods_not_received_easy", "queue_optimization_hard"}
tests/test_training_adapter.py CHANGED
@@ -224,3 +224,31 @@ def test_compute_reward_rejects_mismatched_lengths():
224
 
225
  with pytest.raises(ValueError):
226
  compute_reward(["a"], ["b", "c"], task_ids=["goods_not_received_easy"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
  with pytest.raises(ValueError):
226
  compute_reward(["a"], ["b", "c"], task_ids=["goods_not_received_easy"])
227
+
228
+
229
+ def test_run_episode_breaks_select_case_loop():
230
+ """Degenerate model that always emits select_case must not deadlock.
231
+
232
+ Real failure mode observed in Colab eval: a Qwen3.5 checkpoint
233
+ after 300 GRPO steps emitted ``select_case`` at every state. The
234
+ env silently no-ops the second ``select_case``, the prompt stays
235
+ identical, the model emits the same string, score stays 0 because
236
+ ``done`` never flips. Stall detection must force-fallback to the
237
+ heuristic so the episode reaches grading.
238
+ """
239
+
240
+ import json
241
+
242
+ select_case_payload = json.dumps(
243
+ {"action_type": "select_case", "case_id": "CB-E1"}
244
+ )
245
+
246
+ result = run_episode_with_text_policy(
247
+ "goods_not_received_easy",
248
+ text_policy=lambda _prompt: select_case_payload,
249
+ )
250
+ assert result.steps_used > 0
251
+ assert result.score > 0.0, (
252
+ f"stall detection failed: score={result.score} "
253
+ f"means episode never reached terminal grading"
254
+ )
training/__init__.py CHANGED
@@ -25,13 +25,21 @@ from .reward_adapter import (
25
  compute_reward,
26
  run_episode_with_text_policy,
27
  )
 
 
 
 
 
28
 
29
  __all__ = [
30
  "CheckpointEval",
31
  "EpisodeResult",
 
32
  "TaskOutcome",
33
  "action_from_completion",
 
34
  "build_prompt",
 
35
  "compute_reward",
36
  "evaluate_checkpoint",
37
  "evaluate_policy_across_tasks",
 
25
  compute_reward,
26
  run_episode_with_text_policy,
27
  )
28
+ from .sft_dataset import (
29
+ SFTSample,
30
+ action_to_completion,
31
+ build_sft_dataset,
32
+ )
33
 
34
  __all__ = [
35
  "CheckpointEval",
36
  "EpisodeResult",
37
+ "SFTSample",
38
  "TaskOutcome",
39
  "action_from_completion",
40
+ "action_to_completion",
41
  "build_prompt",
42
+ "build_sft_dataset",
43
  "compute_reward",
44
  "evaluate_checkpoint",
45
  "evaluate_policy_across_tasks",
training/reward_adapter.py CHANGED
@@ -76,6 +76,74 @@ def _fallback_action(
76
  return _heuristic_policy(observation.model_dump())
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def run_episode_with_text_policy(
80
  task_id: str,
81
  text_policy: TextPolicyFn,
@@ -86,8 +154,12 @@ def run_episode_with_text_policy(
86
  """Roll one episode forward under a text-in / text-out policy.
87
 
88
  Used for evaluation and debugging only. Falls back to the scripted
89
- heuristic when the policy returns unparseable output, so the episode
90
- always reaches a terminal state. **Not** used for training reward.
 
 
 
 
91
  """
92
 
93
  task = get_task(task_id)
@@ -96,6 +168,7 @@ def run_episode_with_text_policy(
96
  step_budget = (max_steps if max_steps is not None else task.max_steps) + 5
97
  steps = 0
98
  invalid = 0
 
99
  prompts: list[str] = []
100
  completions: list[str] = []
101
 
@@ -104,16 +177,42 @@ def run_episode_with_text_policy(
104
  prompt = build_prompt(obs_dict)
105
  completion = text_policy(prompt)
106
  action = action_from_completion(completion)
 
107
  if action is None:
108
  invalid += 1
109
  action = _fallback_action(observation)
 
110
  if action is None:
111
  break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  observation = env.step(action)
113
  steps += 1
114
  if capture_trace:
115
  prompts.append(prompt)
116
- completions.append(completion)
 
117
 
118
  report = env.state.grader_report
119
  score = float(report.normalized_score) if report is not None else 0.0
 
76
  return _heuristic_policy(observation.model_dump())
77
 
78
 
79
+ def _state_signature(observation: ChargebackOpsObservation) -> tuple:
80
+ """Stable hashable snapshot of the env state visible to the policy.
81
+
82
+ Used to detect rollout stalls — if step() leaves this signature
83
+ unchanged, the model picked an action the env silently no-op'd
84
+ (e.g. ``select_case`` when a case is already selected) and is
85
+ about to loop forever on the same prompt. ``steps_remaining`` is
86
+ deliberately excluded: it decrements on every step regardless of
87
+ whether the env actually progressed, so including it would mask
88
+ every real stall.
89
+ """
90
+
91
+ visible = observation.visible_case
92
+ visible_sig: tuple
93
+ if visible is None:
94
+ visible_sig = ()
95
+ else:
96
+ visible_sig = (
97
+ visible.case_id,
98
+ visible.status,
99
+ visible.current_strategy,
100
+ len(visible.attached_evidence),
101
+ len(visible.retrieved_evidence),
102
+ )
103
+ return (
104
+ observation.selected_case_id,
105
+ tuple(sorted(observation.available_actions)),
106
+ observation.done,
107
+ visible_sig,
108
+ )
109
+
110
+
111
+ def _action_key(action: ChargebackOpsAction) -> tuple:
112
+ """Hashable identity for "have we tried this exact action here?" check."""
113
+
114
+ return (
115
+ action.action_type,
116
+ action.case_id,
117
+ action.system_name,
118
+ tuple(action.evidence_ids),
119
+ action.strategy,
120
+ )
121
+
122
+
123
+ def _predicted_noop(
124
+ action: ChargebackOpsAction,
125
+ observation: ChargebackOpsObservation,
126
+ ) -> bool:
127
+ """Cheap upfront check that the env will silently no-op this action.
128
+
129
+ Catches the dominant Qwen failure mode (always emit ``select_case``
130
+ even after a case is already selected). Without this check the
131
+ model burns an env step per state on the duplicate ``select_case``,
132
+ blowing the per-task step budget before the heuristic fallback can
133
+ finish the episode. We only hard-code rules we *know* the env
134
+ treats as no-ops; everything else flows through the env and the
135
+ post-hoc ``tried_at_state`` cache.
136
+ """
137
+
138
+ if (
139
+ action.action_type == "select_case"
140
+ and observation.selected_case_id is not None
141
+ and action.case_id == observation.selected_case_id
142
+ ):
143
+ return True
144
+ return False
145
+
146
+
147
  def run_episode_with_text_policy(
148
  task_id: str,
149
  text_policy: TextPolicyFn,
 
154
  """Roll one episode forward under a text-in / text-out policy.
155
 
156
  Used for evaluation and debugging only. Falls back to the scripted
157
+ heuristic when the policy returns unparseable output **or** when
158
+ the model picks an action it has already tried from the current
159
+ state (the env silently no-ops the duplicate, ``done`` never flips,
160
+ score stays 0). The repeat-action guard catches the dominant Qwen
161
+ failure mode where a checkpoint always emits ``select_case`` and
162
+ the episode loops forever. **Not** used for training reward.
163
  """
164
 
165
  task = get_task(task_id)
 
168
  step_budget = (max_steps if max_steps is not None else task.max_steps) + 5
169
  steps = 0
170
  invalid = 0
171
+ tried_at_state: dict[tuple, set[tuple]] = {}
172
  prompts: list[str] = []
173
  completions: list[str] = []
174
 
 
177
  prompt = build_prompt(obs_dict)
178
  completion = text_policy(prompt)
179
  action = action_from_completion(completion)
180
+ used_fallback = False
181
  if action is None:
182
  invalid += 1
183
  action = _fallback_action(observation)
184
+ used_fallback = True
185
  if action is None:
186
  break
187
+
188
+ if not used_fallback and _predicted_noop(action, observation):
189
+ fallback = _fallback_action(observation)
190
+ if fallback is not None:
191
+ action = fallback
192
+ used_fallback = True
193
+
194
+ state_sig = _state_signature(observation)
195
+ attempted = tried_at_state.setdefault(state_sig, set())
196
+ action_key = _action_key(action)
197
+ if action_key in attempted and not used_fallback:
198
+ fallback = _fallback_action(observation)
199
+ if fallback is None:
200
+ break
201
+ fallback_key = _action_key(fallback)
202
+ if fallback_key in attempted:
203
+ # Heuristic also stuck — bail out, score whatever we have.
204
+ break
205
+ action = fallback
206
+ action_key = fallback_key
207
+ used_fallback = True
208
+ attempted.add(action_key)
209
+
210
  observation = env.step(action)
211
  steps += 1
212
  if capture_trace:
213
  prompts.append(prompt)
214
+ tag = "<<fallback>> " if used_fallback else ""
215
+ completions.append(f"{tag}{completion}")
216
 
217
  report = env.state.grader_report
218
  score = float(report.normalized_score) if report is not None else 0.0
training/sft_dataset.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Supervised fine-tuning dataset builder for ChargebackOps.
2
+
3
+ Rolls the scripted heuristic across each task and captures every
4
+ ``(observation_prompt, oracle_completion)`` pair as a single-turn
5
+ training sample. The completion is the JSON serialisation of the
6
+ heuristic action, matching the format the merchant policy must emit
7
+ at inference time.
8
+
9
+ SFT before GRPO is the standard RLHF pattern. It teaches the base
10
+ model two things GRPO struggles to learn from sparse reward alone:
11
+
12
+ * The output schema (valid JSON, the right ``action_type`` strings,
13
+ no extra prose).
14
+ * Per-state action variety — the heuristic emits a *different*
15
+ action_type at each state, so an SFT-trained model stops
16
+ collapsing to ``select_case`` at every step.
17
+
18
+ The module returns plain dicts so the notebook can wrap them in any
19
+ trainer's expected dataset format (TRL ``SFTTrainer``, HF
20
+ ``Dataset.from_list``, etc.) without pulling those deps into the
21
+ package.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import json
27
+ from dataclasses import dataclass
28
+ from typing import Any, Sequence
29
+
30
+ try:
31
+ from ..core.models import ChargebackOpsAction
32
+ from ..server.chargeback_ops_environment import ChargebackOpsEnvironment
33
+ from .env_adapter import build_prompt
34
+ except ImportError: # pragma: no cover
35
+ from core.models import ChargebackOpsAction
36
+ from server.chargeback_ops_environment import ChargebackOpsEnvironment
37
+ from training.env_adapter import build_prompt
38
+
39
+
40
+ def _heuristic_policy(observation_dict: dict[str, Any]) -> ChargebackOpsAction | None:
41
+ try:
42
+ from ..runners.benchmark_runner import heuristic_policy
43
+ except ImportError: # pragma: no cover
44
+ from runners.benchmark_runner import heuristic_policy
45
+ return heuristic_policy(observation_dict)
46
+
47
+
48
+ def action_to_completion(action: ChargebackOpsAction) -> str:
49
+ """Serialise an action as the canonical JSON completion string."""
50
+
51
+ payload = action.model_dump(exclude_none=True)
52
+ payload = {k: v for k, v in payload.items() if v not in ([], "")}
53
+ return json.dumps(payload, separators=(",", ":"), sort_keys=True)
54
+
55
+
56
+ @dataclass(frozen=True)
57
+ class SFTSample:
58
+ """One supervised training pair."""
59
+
60
+ task_id: str
61
+ state_step: int
62
+ prompt: str
63
+ completion: str
64
+ action_type: str
65
+
66
+
67
+ def build_sft_dataset(
68
+ task_ids: Sequence[str],
69
+ *,
70
+ max_states_per_task: int = 24,
71
+ ) -> list[dict[str, Any]]:
72
+ """Roll heuristic on each task; capture (prompt, oracle_completion) pairs.
73
+
74
+ Goes deeper than :func:`training.reward_adapter.build_state_action_dataset`
75
+ (default 24 vs 12 states per task) because SFT benefits from seeing
76
+ the full trajectory — including terminal-resolution actions which
77
+ are rare in the early-state distribution.
78
+ """
79
+
80
+ samples: list[dict[str, Any]] = []
81
+ for task_id in task_ids:
82
+ env = ChargebackOpsEnvironment()
83
+ obs = env.reset(task_id=task_id)
84
+ for state_step in range(max_states_per_task):
85
+ if obs.done:
86
+ break
87
+ heur = _heuristic_policy(obs.model_dump())
88
+ if heur is None:
89
+ break
90
+ samples.append(
91
+ {
92
+ "task_id": task_id,
93
+ "state_step": state_step,
94
+ "prompt": build_prompt(obs.model_dump()),
95
+ "completion": action_to_completion(heur),
96
+ "action_type": heur.action_type,
97
+ }
98
+ )
99
+ obs = env.step(heur)
100
+ return samples
101
+
102
+
103
+ __all__ = [
104
+ "SFTSample",
105
+ "action_to_completion",
106
+ "build_sft_dataset",
107
+ ]