mitudrudutta committed on
Commit
243aa68
·
1 Parent(s): 71f1fe0

fix(training): per-action reward scoring vs heuristic oracle

Browse files

Prior compute_reward fell back to the scripted heuristic on parse-fail,
which trained the model that emitting garbage was optimal: the heuristic
played the whole episode and earned ~0.96 reward regardless of model
output. Group reward variance collapsed to std~=0.005 across 6 wildly
different completions, GRPO advantage went to 0, loss collapsed to 0.

This rewrites compute_reward as a pure per-action scorer against the
heuristic oracle at the dataset's recorded env state:
- parse-fail -> 0.0
- action_type not in available_actions at this state -> 0.1
- valid action_type, different than oracle -> 0.4
- right action_type, wrong target (case_id/system) -> 0.7
- exact match on action_type + targets -> 1.0

Variance across the same 6 diverse completions: std=0.39 (70x lift),
distinct values {0.0, 0.1, 0.7, 1.0}. GRPO now has real gradient.

Adds build_state_action_dataset() to roll the heuristic and capture
(state, oracle_action) pairs, so training prompts cover mid-episode
states (otherwise the model only learns first-action policy).

run_episode_with_text_policy keeps its heuristic fallback because it is
used for evaluation/debug rollouts, not training reward.

tests/test_training_adapter.py CHANGED
@@ -17,6 +17,7 @@ from training.env_adapter import (
17
  parse_completion,
18
  )
19
  from training.reward_adapter import (
 
20
  compute_reward,
21
  run_episode_with_text_policy,
22
  )
@@ -115,17 +116,107 @@ def test_run_episode_falls_back_to_heuristic_on_empty_completion():
115
  assert result.score > 0.0 # heuristic fallback still scores
116
 
117
 
118
- def test_compute_reward_matches_episode_score():
119
- """Single completion + heuristic tail reproduces the heuristic score."""
120
- task = get_task("goods_not_received_easy")
121
- prompts = ["unused"]
122
- completions = [""] # triggers heuristic fallback on the first action
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  rewards = compute_reward(
124
- prompts, completions, task_ids=[task.task_id]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  )
126
- assert len(rewards) == 1
127
- assert 0.0 <= rewards[0] <= 1.0
128
- assert rewards[0] > 0.5 # heuristic scores ~0.97 on this task
 
 
 
 
129
 
130
 
131
  def test_compute_reward_rejects_mismatched_lengths():
 
17
  parse_completion,
18
  )
19
  from training.reward_adapter import (
20
+ build_state_action_dataset,
21
  compute_reward,
22
  run_episode_with_text_policy,
23
  )
 
116
  assert result.score > 0.0 # heuristic fallback still scores
117
 
118
 
119
def test_compute_reward_unparseable_returns_zero():
    """An unparseable completion must score exactly 0.0.

    Regression guard for the earlier fallback design, in which a
    parse-failure handed the episode to the scripted heuristic. Garbage
    completions then earned ~0.96 reward, group reward variance collapsed
    to ~0.005, and GRPO was left without a gradient.
    """

    task_id = "goods_not_received_easy"
    empty_completion = ""

    scores = compute_reward(["unused"], [empty_completion], task_ids=[task_id])

    assert scores == [0.0]
132
+
133
+
134
def test_compute_reward_exact_match_scores_one():
    """Echoing the heuristic's own first action back must earn the full 1.0."""

    import json

    from runners.benchmark_runner import heuristic_policy
    from server.chargeback_ops_environment import ChargebackOpsEnvironment

    task_id = "goods_not_received_easy"
    environment = ChargebackOpsEnvironment()
    initial_obs = environment.reset(task_id=task_id)

    # Serialise the oracle's action as JSON and feed it back as the completion.
    oracle_action = heuristic_policy(initial_obs.model_dump())
    oracle_json = json.dumps(oracle_action.model_dump(exclude_none=True))

    scores = compute_reward(["unused"], [oracle_json], task_ids=[task_id])

    assert scores == [1.0]
151
+
152
+
153
def test_compute_reward_unavailable_action_scores_low():
    """A well-formed action whose type is not allowed at this state → 0.1."""

    task_id = "goods_not_received_easy"
    # At the first state of this task only ``select_case`` is available, so
    # a representment submission parses fine but is not a legal move.
    disallowed = '{"action_type": "submit_representment", "case_id": "CB-E1"}'

    scores = compute_reward(["unused"], [disallowed], task_ids=[task_id])

    assert scores == [0.1]
162
+
163
+
164
def test_compute_reward_has_real_variance_across_diverse_completions():
    """A mixed bag of completions must spread across several reward tiers.

    Under the earlier design the heuristic dominated every episode and
    the reward std across wildly different completions was ≈ 0.005. The
    per-action scorer should yield at least three distinct values for
    the set below, spanning at least 0.5.
    """

    import json

    from runners.benchmark_runner import heuristic_policy
    from server.chargeback_ops_environment import ChargebackOpsEnvironment

    task_id = "goods_not_received_easy"
    environment = ChargebackOpsEnvironment()
    first_obs = environment.reset(task_id=task_id)
    oracle_action = heuristic_policy(first_obs.model_dump())

    parse_failures = ["", "garbage no json"]  # each → 0.0
    unavailable = (
        '{"action_type": "submit_representment", "case_id": "CB-E1"}'  # → 0.1
    )
    exact = json.dumps(oracle_action.model_dump(exclude_none=True))  # → 1.0
    candidates = [*parse_failures, unavailable, exact]

    scores = compute_reward(
        ["x"] * len(candidates),
        candidates,
        task_ids=[task_id] * len(candidates),
    )

    assert len(set(scores)) >= 3
    assert max(scores) - min(scores) >= 0.5
192
+
193
+
194
def test_compute_reward_state_steps_advance_env():
    """``state_steps`` must place the scorer on the matching mid-episode state.

    The earlier version of this test passed only empty completions.
    ``compute_reward`` returns 0.0 for a parse-failure *before* it replays
    the heuristic, so the ``state_steps`` path was never executed. Here a
    parseable completion (the oracle's own action one step into the
    episode) is scored at ``state_step=1`` and must be an exact match,
    which can only be judged after the env has been advanced.
    """

    import json

    from runners.benchmark_runner import heuristic_policy
    from server.chargeback_ops_environment import ChargebackOpsEnvironment

    task_id = "goods_not_received_easy"

    # Reproduce the state one heuristic step after reset.
    env = ChargebackOpsEnvironment()
    obs = env.reset(task_id=task_id)
    obs = env.step(heuristic_policy(obs.model_dump()))
    # Precondition: the episode is still running after one heuristic step.
    assert not obs.done

    oracle = heuristic_policy(obs.model_dump())
    assert oracle is not None
    completion = json.dumps(oracle.model_dump(exclude_none=True))

    rewards = compute_reward(
        ["x", "x"],
        ["", completion],
        task_ids=[task_id, task_id],
        state_steps=[0, 1],
    )
    # Unparseable → 0.0 regardless of state; the step-1 oracle action scored
    # at step 1 → exact match.
    assert rewards == [0.0, 1.0]
205
+
206
+
207
def test_build_state_action_dataset_covers_multiple_states():
    """A heuristic rollout must contribute more than just the initial state."""

    task_id = "goods_not_received_easy"

    rows = build_state_action_dataset([task_id], max_states_per_task=8)

    assert len(rows) >= 2

    # Steps are recorded in rollout order, starting from the reset state.
    steps = [row["state_step"] for row in rows]
    assert steps == sorted(steps)
    assert steps[0] == 0

    for row in rows:
        assert row["task_id"] == task_id
        assert "OBSERVATION:" in row["prompt"]
220
 
221
 
222
  def test_compute_reward_rejects_mismatched_lengths():
training/reward_adapter.py CHANGED
@@ -4,11 +4,14 @@ Exposes a callable shape compatible with TRL's GRPO reward function:
4
 
5
  ``reward_fn(prompts, completions, **kwargs) -> list[float]``
6
 
7
- Each completion is parsed into an action sequence (one action per line
8
- is the simplest case; the helper also accepts a single-action
9
- completion and runs the remainder of the episode under the scripted
10
- heuristic so training always produces a terminal score). The resulting
11
- reward is the episode's deterministic normalized grade in ``[0, 1]``.
 
 
 
12
  """
13
 
14
  from __future__ import annotations
@@ -43,16 +46,34 @@ class EpisodeResult:
43
  completions: tuple[str, ...] = field(default_factory=tuple)
44
 
45
 
46
- def _fallback_action(
47
- observation: ChargebackOpsObservation,
48
- ) -> ChargebackOpsAction | None:
49
- """Scripted fallback when the model output is unparseable."""
 
 
 
 
50
 
 
 
51
  try:
52
  from ..runners.benchmark_runner import heuristic_policy
53
  except ImportError: # pragma: no cover
54
  from runners.benchmark_runner import heuristic_policy
55
- return heuristic_policy(observation.model_dump())
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
 
58
  def run_episode_with_text_policy(
@@ -64,9 +85,9 @@ def run_episode_with_text_policy(
64
  ) -> EpisodeResult:
65
  """Roll one episode forward under a text-in / text-out policy.
66
 
67
- The policy is invoked at every step. If the completion fails to
68
- parse into a valid action the scripted heuristic is used instead;
69
- this keeps early-training trajectories from deadlocking.
70
  """
71
 
72
  task = get_task(task_id)
@@ -106,23 +127,87 @@ def run_episode_with_text_policy(
106
  )
107
 
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  def compute_reward(
110
  prompts: Sequence[str],
111
  completions: Sequence[str],
112
  *,
113
  task_ids: Sequence[str] | None = None,
 
114
  **_: Any,
115
  ) -> list[float]:
116
- """GRPO-style reward function.
 
 
117
 
118
- Each ``completion`` is replayed as a *single* action. The remainder
119
- of the episode is driven by the scripted heuristic, so the reward
120
- signal rewards the model for picking a good first move from a
121
- given observation. This matches the behaviour TRL expects: one
122
- ``(prompt, completion)`` pair → one scalar reward.
123
 
124
- ``task_ids`` optionally binds each prompt to a task id for env
125
- replay. When omitted, the headline catalog is cycled.
 
 
 
 
126
  """
127
 
128
  if task_ids is None:
@@ -132,25 +217,74 @@ def compute_reward(
132
  raise ValueError(
133
  "prompts, completions, and task_ids must all have the same length"
134
  )
 
 
 
 
135
 
136
  rewards: list[float] = []
137
- for task_id, completion in zip(task_ids, completions):
138
- first_action = action_from_completion(completion)
 
 
 
 
 
 
 
 
 
139
 
140
- def _once(_prompt: str, _used=[False], _action=first_action) -> str:
141
- if _used[0] or _action is None:
142
- return ""
143
- _used[0] = True
144
- return completion
145
 
146
- result = run_episode_with_text_policy(task_id, _once)
147
- rewards.append(result.score)
 
148
  return rewards
149
 
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  __all__ = [
152
  "EpisodeResult",
 
153
  "TextPolicyFn",
 
154
  "compute_reward",
155
  "run_episode_with_text_policy",
156
  ]
 
4
 
5
  ``reward_fn(prompts, completions, **kwargs) -> list[float]``
6
 
7
+ The reward is a *per-action* match score against the scripted heuristic
8
+ oracle at the dataset's recorded environment state. Episode replay was
9
+ removed deliberately: previously every parse-failure fell back to the
10
+ heuristic and earned ~0.96 reward, which trained the model that emitting
11
+ garbage was optimal (group reward variance → 0 → GRPO advantage = 0 →
12
+ loss collapsed). Per-action scoring against the oracle gives high
13
+ variance even for an untrained model: parse-fails earn 0.0, valid-but-
14
+ wrong actions earn 0.1-0.7, exact matches earn 1.0.
15
  """
16
 
17
  from __future__ import annotations
 
46
  completions: tuple[str, ...] = field(default_factory=tuple)
47
 
48
 
49
@dataclass(frozen=True)
class StateActionSample:
    """One (env_state, oracle_action) pair captured from a heuristic rollout.

    NOTE(review): nothing in the visible part of this module constructs
    this class; ``build_state_action_dataset`` returns plain dicts that
    carry ``task_id``, ``state_step`` and ``prompt`` but no
    ``oracle_action_type``. Confirm whether this type is meant to replace
    those dicts or is only exported for callers.
    """

    # Identifier of the task the rollout was run on.
    task_id: str
    # Presumably the number of heuristic steps taken from reset to reach
    # this state (mirrors the ``state_step`` key of the dataset rows).
    state_step: int
    # Prompt text shown to the model for this state.
    prompt: str
    # Action type the heuristic oracle picks at this state.
    oracle_action_type: str
57
 
58
+
59
def _heuristic_policy(observation_dict: dict[str, Any]) -> ChargebackOpsAction | None:
    """Run the benchmark runner's scripted heuristic on a dumped observation.

    ``heuristic_policy`` is imported at call time: the package-relative
    path is tried first and the top-level ``runners`` package is used when
    that import fails. The heuristic's return value is passed through
    unchanged, so callers must handle ``None``.
    """

    try:
        from ..runners.benchmark_runner import heuristic_policy as scripted_oracle
    except ImportError:  # pragma: no cover
        from runners.benchmark_runner import heuristic_policy as scripted_oracle

    oracle_action = scripted_oracle(observation_dict)
    return oracle_action
65
+
66
+
67
def _fallback_action(
    observation: ChargebackOpsObservation,
) -> ChargebackOpsAction | None:
    """Return the scripted heuristic's action for ``observation``.

    Intended for the debug/eval rollout helper only. It is deliberately
    kept out of :func:`compute_reward`: substituting the heuristic for an
    unparseable completion rewards the model for emitting garbage.
    """

    observation_dict = observation.model_dump()
    return _heuristic_policy(observation_dict)
77
 
78
 
79
  def run_episode_with_text_policy(
 
85
  ) -> EpisodeResult:
86
  """Roll one episode forward under a text-in / text-out policy.
87
 
88
+ Used for evaluation and debugging only. Falls back to the scripted
89
+ heuristic when the policy returns unparseable output, so the episode
90
+ always reaches a terminal state. **Not** used for training reward.
91
  """
92
 
93
  task = get_task(task_id)
 
127
  )
128
 
129
 
130
def _advance_to_state(
    task_id: str, state_step: int
) -> tuple[ChargebackOpsEnvironment, ChargebackOpsObservation] | None:
    """Create a fresh env for ``task_id`` and replay the heuristic into it.

    The environment is reset and then stepped ``state_step`` times, each
    time with the action the scripted heuristic picks for the current
    observation.

    Returns the ``(environment, observation)`` pair at the requested
    state, or ``None`` when that state cannot be reached as a live state:
    the episode is already done, the heuristic yields no action, or a
    replayed step ends the episode (e.g. the dataset went stale).
    """

    environment = ChargebackOpsEnvironment()
    observation = environment.reset(task_id=task_id)

    for _ in range(state_step):
        # A finished episode and a heuristic with no move are both dead ends.
        oracle_action = (
            None if observation.done else _heuristic_policy(observation.model_dump())
        )
        if oracle_action is None:
            return None

        observation = environment.step(oracle_action)
        if observation.done:
            return None

    return environment, observation
151
+
152
+
153
def _score_action_match(
    action: ChargebackOpsAction,
    heuristic: ChargebackOpsAction,
    available_actions: list[str],
) -> float:
    """Grade a parsed model action against the heuristic oracle's action.

    The tiers are spaced out so that GRPO sampling sees non-degenerate
    reward variance:

    * 0.0 — parse-fail. Never returned here; the caller handles it before
      calling this function.
    * 0.1 — ``action.action_type`` is not in ``available_actions``: valid
      JSON, but a move the env does not allow at this state.
    * 0.4 — an allowed ``action_type`` that differs from the oracle's.
      Valid exploration, low credit.
    * 0.7 — the oracle's ``action_type`` but a different ``case_id`` or
      ``system_name``. Right idea, wrong object.
    * 1.0 — same ``action_type`` and matching targets.

    A target field is only compared when the oracle sets it to a truthy
    value; fields the oracle leaves empty are ignored.
    """

    chosen_type = action.action_type

    if chosen_type not in available_actions:
        return 0.1
    if chosen_type != heuristic.action_type:
        return 0.4

    # Checked in this order: the case first, then the system.
    for target in ("case_id", "system_name"):
        expected = getattr(heuristic, target)
        if expected and getattr(action, target) != expected:
            return 0.7

    return 1.0
185
+
186
+
187
  def compute_reward(
188
  prompts: Sequence[str],
189
  completions: Sequence[str],
190
  *,
191
  task_ids: Sequence[str] | None = None,
192
+ state_steps: Sequence[int] | None = None,
193
  **_: Any,
194
  ) -> list[float]:
195
+ """GRPO-style per-action reward.
196
+
197
+ For each ``(task_id, state_step, completion)`` triple:
198
 
199
+ 1. Reset env to ``task_id`` and replay the heuristic for
200
+ ``state_step`` steps to land on the dataset state.
201
+ 2. Parse the completion into an action.
202
+ 3. Score the action against the heuristic oracle at that state via
203
+ :func:`_score_action_match`.
204
 
205
+ No fallback to the heuristic on parse-fail (the prior design did
206
+ this; it created a reward floor that flattened group variance and
207
+ starved GRPO of gradient signal).
208
+
209
+ ``state_steps`` defaults to all-zero (initial state) when omitted, so
210
+ legacy callers that only pass ``task_ids`` still work.
211
  """
212
 
213
  if task_ids is None:
 
217
  raise ValueError(
218
  "prompts, completions, and task_ids must all have the same length"
219
  )
220
+ if state_steps is None:
221
+ state_steps = [0] * len(prompts)
222
+ if len(state_steps) != len(prompts):
223
+ raise ValueError("state_steps must have the same length as prompts")
224
 
225
  rewards: list[float] = []
226
+ for task_id, state_step, completion in zip(task_ids, state_steps, completions):
227
+ action = action_from_completion(completion)
228
+ if action is None:
229
+ rewards.append(0.0)
230
+ continue
231
+
232
+ advanced = _advance_to_state(task_id, int(state_step))
233
+ if advanced is None:
234
+ rewards.append(0.0)
235
+ continue
236
+ _env, obs = advanced
237
 
238
+ heur = _heuristic_policy(obs.model_dump())
239
+ if heur is None:
240
+ rewards.append(0.0)
241
+ continue
 
242
 
243
+ rewards.append(
244
+ _score_action_match(action, heur, list(obs.available_actions))
245
+ )
246
  return rewards
247
 
248
 
249
def build_state_action_dataset(
    task_ids: Sequence[str],
    *,
    max_states_per_task: int = 12,
) -> list[dict[str, Any]]:
    """Roll the heuristic on each task and record one row per visited state.

    For every task the environment is reset and stepped forward with the
    scripted heuristic. A row is recorded at each state until the episode
    ends, the heuristic has no action to offer, or ``max_states_per_task``
    rows have been collected for that task.

    Each row carries ``prompt``, ``task_id`` and ``state_step``. The
    oracle action itself is not stored: ``compute_reward`` recomputes it
    by replaying the heuristic for ``state_step`` steps.

    Args:
        task_ids: Tasks to roll out, in output order.
        max_states_per_task: Upper bound on rows recorded per task.

    Returns:
        Rows grouped by task, with ``state_step`` counting up from 0
        within each task.
    """

    samples: list[dict[str, Any]] = []
    for task_id in task_ids:
        env = ChargebackOpsEnvironment()
        obs = env.reset(task_id=task_id)
        for state_step in range(max_states_per_task):
            if obs.done:
                break
            heur = _heuristic_policy(obs.model_dump())
            if heur is None:
                # Do not record a state that has no oracle action:
                # compute_reward scores every completion 0.0 there, so the
                # prompt would have zero group variance and no gradient.
                break
            samples.append(
                {
                    "task_id": task_id,
                    "state_step": state_step,
                    "prompt": build_prompt(obs.model_dump()),
                }
            )
            obs = env.step(heur)
    return samples
281
+
282
+
283
  __all__ = [
284
  "EpisodeResult",
285
+ "StateActionSample",
286
  "TextPolicyFn",
287
+ "build_state_action_dataset",
288
  "compute_reward",
289
  "run_episode_with_text_policy",
290
  ]