pauldebanshu19 commited on
Commit
bd00c06
·
1 Parent(s): 06abe10

Add training notebook and benchmark runner for ChargebackOps

Browse files

- Introduced `train_merchant_agent.ipynb` for end-to-end GRPO training of the merchant-side chargeback agent, including environment setup, model loading, training prompt dataset creation, and evaluation.
- Created `benchmark_runner.py` to implement scripted policies for benchmarking against the trained agent, including heuristic, escalate_all, concede_all, and naive policies.
- Updated `issuer_model.py` to refine rationale messages for issuer decisions.
- Added unit tests for the benchmark runner and training adapter to ensure policy validity and reward computation accuracy.
- Implemented training helpers in `training/__init__.py`, `env_adapter.py`, and `reward_adapter.py` to facilitate interaction with the GRPO trainer.

docs/RESULTS.md CHANGED
@@ -1,125 +1,71 @@
1
- # ChargebackOps — Baseline Results
2
 
3
- Reference numbers for the 10-task headline benchmark and the 28-task multi-seed stress grid.
4
- Captured on **2026-04-15** against `main` (Rubric system + `Gate(CaseAbandonedRubric)`
5
- composition, tightened `acceptable_strategies` on contest-optimal templates, expanded
6
- `_obvious_next_action` coverage, improved LLM prompt). Reproduce with the commands at the
7
- bottom; headline scores match to within ±1e-3 (float rounding).
8
 
9
- ## TL;DR
 
 
 
10
 
11
- | Agent | Avg score | Best task | Worst task | Provider calls |
12
- | --- | --- | --- | --- | --- |
13
- | **Bad policy** (concede-everything) | **0.199** | `generated_medium_s99` (0.442) | `generated_nightmare_s77` (0.053) | 0 |
14
- | **Heuristic** (no LLM, rule-based) | **0.724** | `goods_not_received_easy` / `fraud_signal_ambiguity` (0.968) | `generated_hard_s53` (0.440) | 0 |
15
- | **Heuristic + LLM tiebreak** (openrouter gpt-oss-120b) | **0.729** | `goods_not_received_easy` / `fraud_signal_ambiguity` / `generated_easy_s42` (0.958) | `generated_hard_s53` (0.440) | 7 (7 ✓ / 0 ✗) |
16
-
17
- **Key signal:** the bad policy vs. heuristic delta is **0.525** (72.4 → 19.9 = 264% spread).
18
- The `Gate(CaseAbandonedRubric)` wrapper around the per-case `WeightedSum` means a case left
19
- unresolved past its deadline hard-zeros — a lazy concede-everything agent cannot game the score,
20
- and a correct agent cannot trivially saturate it on hard tasks. The LLM-assisted run now edges
21
- ahead of the pure heuristic (+0.005) after the v1.1 prompt and `_obvious_next_action` upgrades;
22
- the LLM is invoked only **7 times** across the 10-task run (down from 19 in v1) because
23
- deterministic workflow states are now dispatched without a model call.
24
-
25
- ## Score Curve by Difficulty
26
-
27
- | Difficulty | Task count | Heuristic avg | LLM avg | Bad avg | Target band | Status |
28
- | --- | --- | --- | --- | --- | --- | --- |
29
- | easy | 3 | 0.964 | 0.964 | 0.323 | ≥ 0.90 | ✓ |
30
- | medium | 2 | 0.755 | 0.755 | 0.278 | 0.50 – 0.85 | ✓ |
31
- | hard | 3 | 0.635 | 0.651 | 0.113 | 0.50 – 0.75 | ✓ |
32
- | nightmare | 2 | 0.466 | 0.466 | 0.065 | ≤ 0.55 | ✓ |
33
 
34
- Observations:
35
- - The LLM-assisted run now **matches or narrowly beats** the heuristic on every difficulty band
36
- (overall +0.005). The old v1 regression — where the LLM dropped 0.56 on `fraud_signal_ambiguity`
37
- and 0.29 on `generated_medium_s99` — was caused by the model picking a concede strategy over
38
- contest at `set_strategy` time. `_obvious_next_action` now short-circuits all strategy picks
39
- so the heuristic-derived strategy is used directly, and the prompt explicitly lists the
40
- reason-code → optimal-strategy mapping for the remaining decision points. Provider call count
41
- fell from 19 to 7 because deterministic housekeeping (add_evidence, remove_evidence,
42
- submit_representment, set_strategy, resolve_case) is now bypassed entirely.
43
- - The LLM's remaining upside is on `queue_optimization_hard` (+0.049 over heuristic), where the
44
- queue-triage branching is genuine and the heuristic's fixed priority order leaves marginal
45
- value on the table.
46
- - Nightmare tasks cluster around **0.47** for the heuristic because the 15-step budget collides
47
- with 5-case portfolios that have deadline_step=3–5 per case. Missed deadlines that were
48
- *attempted* still land in the weighted sum (with 0 on the deadline dimension and ~0.55 from
49
- the other 85%); truly abandoned cases are zeroed by the `Gate(CaseAbandonedRubric)` wrapper.
50
- Not a scoring artifact: the bad-policy run shows the same tasks at ~0.065.
51
- - The deadline `Gate` is the v1 upgrade over a flat weighted sum: a case never even attempted by
52
- the deadline collapses completely, while a case resolved late still earns dimensional credit
53
- for evidence, strategy, and packet quality. This matches real chargeback operations — a missed
54
- representment is "case forfeit," while a late one takes a penalty but is still scored on what
55
- the merchant tried to do.
56
-
57
- ## Full Per-Task Table
58
-
59
- | Task ID | Difficulty | Cases | Heuristic | H steps | LLM | LLM steps | Bad | Bad steps |
60
- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
61
- | goods_not_received_easy | easy | 1 | 0.968 | 6 | 0.968 | 6 | 0.280 | 3 |
62
- | fraud_signal_ambiguity | easy | 1 | 0.968 | 7 | 0.968 | 7 | 0.280 | 3 |
63
- | generated_easy_s42 | easy | 1 | 0.958 | 7 | 0.958 | 7 | 0.408 | 3 |
64
- | generated_medium_s17 | medium | 2 | 0.809 | 10 | 0.809 | 10 | 0.114 | 12 |
65
- | generated_medium_s99 | medium | 2 | 0.701 | 9 | 0.701 | 9 | 0.442 | 12 |
66
- | queue_optimization_hard | hard | 3 | 0.802 | 12 | 0.850 | 11 | 0.129 | 15 |
67
- | generated_hard_s7 | hard | 2 | 0.663 | 5 | 0.663 | 5 | 0.120 | 12 |
68
- | generated_hard_s53 | hard | 3 | 0.440 | 6 | 0.440 | 6 | 0.089 | 15 |
69
- | generated_nightmare_s31 | nightmare | 5 | 0.486 | 15 | 0.486 | 15 | 0.077 | 15 |
70
- | generated_nightmare_s77 | nightmare | 5 | 0.445 | 15 | 0.445 | 15 | 0.053 | 15 |
71
- | **Average** | | | **0.724** | 9.2 | **0.729** | 9.0 | **0.199** | 10.5 |
72
-
73
- ## Multi-seed Stress Grid (7 seeds × 4 difficulties)
74
-
75
- Running the heuristic and bad-policy agents across seven generator seeds per difficulty (seeds
76
- 7, 17, 31, 42, 53, 77, 99) gives the statistically defensible version of the headline numbers.
77
- All runs are fully offline — no provider calls involved.
78
-
79
- | Difficulty | n | Heuristic mean ± std | Bad mean ± std |
80
  | --- | --- | --- | --- |
81
- | easy | 7 | 0.9696 ± 0.014 | 0.3346 ± 0.068 |
82
- | medium | 7 | 0.8411 ± 0.089 | 0.4369 ± 0.238 |
83
- | hard | 7 | 0.6245 ± 0.151 | 0.1299 ± 0.047 |
84
- | nightmare | 7 | 0.4121 ± 0.079 | 0.0635 ± 0.010 |
85
- | **OVERALL** | **28** | **0.7118 ± 0.235** | **0.2412 ± 0.194** |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  Observations:
88
- - Heuristic score decreases cleanly and monotonically with difficulty: 0.97 → 0.84 → 0.62 →
89
- 0.41. The difficulty gradient is real — not a labeling artifact.
90
- - Nightmare std is the tightest (0.079) because every nightmare task is constrained by the
91
- same step budget vs. case count collision. Hard is the widest (0.151) because case counts
92
- vary from 2 to 3 across seeds.
93
- - Bad policy shows wide variance on medium (±0.238) because some medium seeds generate
94
- concede-optimal templates (credit_not_processed, duplicate_processing) where
95
- concede-everything is trivially correct exactly the expected behavior of a discriminating
96
- rubric on a mixed task distribution.
97
- - Overall delta (heuristic − bad) across 28 runs: **0.4706**. The headline 10-task catalog
98
- delta (0.525) is within of the multi-seed delta, so the fixed-seed headline is not a
99
- cherry-picked result.
100
-
101
- ## Rubric Breakdown (single-case sanity check)
102
-
103
- For `goods_not_received_easy` under the heuristic, the 7-dimension breakdown from
104
- `ChargebackOpsEpisodeRubric` (weights sum to 1.0, Gate passes because the case was resolved
105
- before step 8):
106
-
107
- | Dimension | Weight | Score | Weighted contribution |
108
- | --- | --- | --- | --- |
109
- | strategy_correctness | 0.25 | 1.00 | 0.2500 |
110
- | evidence_quality | 0.20 | 0.90 | 0.1800 |
111
- | packet_validity | 0.15 | 1.00 | 0.1500 |
112
- | deadline_compliance | 0.15 | 1.00 | 0.1500 |
113
- | efficiency | 0.10 | 0.95 | 0.0950 |
114
- | outcome_quality | 0.10 | 1.00 | 0.1000 |
115
- | note_quality | 0.05 | 0.85 | 0.0425 |
116
- | **Total** | **1.00** | — | **0.9675** |
117
-
118
- Per-dimension scores captured by reading `rubric.last_score` on every child in the
119
- `ChargebackOpsEpisodeRubric.case_rubric.aggregator` tree after one forward pass — exactly the
120
- introspection path an RL trainer would use for credit assignment. The small gaps
121
- (`evidence_quality=0.90`, `efficiency=0.95`, `note_quality=0.85`) are the real headroom an
122
- LLM-fine-tuned agent is expected to close.
123
 
124
  ## Rubric Composition (what's wired)
125
 
@@ -129,65 +75,75 @@ ChargebackOpsEpisodeRubric
129
  ├── deadline_gate: Gate(threshold=1.0) # hard-zero if case abandoned past deadline
130
  │ └── CaseAbandonedRubric
131
  └── aggregator: WeightedSum # weights sum to 1.0
132
- ├── rubric_0: StrategyCorrectnessRubric # weight 0.25
133
- ├── rubric_1: EvidenceQualityRubric # weight 0.20
134
- ├── rubric_2: PacketValidityRubric # weight 0.15
135
- ├── rubric_3: DeadlineComplianceRubric # weight 0.15 (dimension-level, not gate)
136
- ├── rubric_4: EfficiencyRubric # weight 0.10
137
- ├── rubric_5: OutcomeQualityRubric # weight 0.10
138
- ── rubric_6: NoteQualityRubric # weight 0.05
 
139
  ```
140
 
141
- Every node is an OpenEnv `Rubric` subclass and every node exposes `last_score` after forward.
142
- `env.rubric.named_rubrics()` walks the tree and returns 11 named children — the hook-compatible
143
- surface for a judge or trainer to introspect per-dimension scores.
 
 
 
 
 
 
 
144
 
145
  ## Reproducing These Numbers
146
 
147
  ```bash
148
- # Activate the project's venv
149
  source ~/python/bin/activate
150
 
151
- # 1. Headline 10-task run (heuristic + bad policy, no network)
152
  python - <<'PY'
153
- from evaluation.agent_brutal_audit import run_episode
154
- from scenarios.simulation import list_tasks
155
- for t in list_tasks():
156
- h = run_episode(t.task_id, policy='heuristic')
157
- b = run_episode(t.task_id, policy='bad')
158
- print(f"{t.task_id:32s} heur={h['score']:.4f} bad={b['score']:.4f}")
 
 
 
 
 
 
 
 
 
 
159
  PY
 
160
 
161
- # 2. Multi-seed stress grid (28 runs across 7 seeds × 4 difficulties, no network)
162
- python - <<'PY'
163
- from statistics import mean, stdev
164
- from evaluation.agent_brutal_audit import run_episode
165
- for d in ("easy","medium","hard","nightmare"):
166
- hs, bs = [], []
167
- for s in (7, 17, 31, 42, 53, 77, 99):
168
- hs.append(run_episode(f"generated_{d}_s{s}", policy='heuristic')['score'])
169
- bs.append(run_episode(f"generated_{d}_s{s}", policy='bad')['score'])
170
- print(f"{d:10s} heur={mean(hs):.4f}±{stdev(hs):.4f} bad={mean(bs):.4f}±{stdev(bs):.4f}")
171
- PY
172
 
173
- # 3. LLM tiebreak run (requires OPENROUTER_API_KEY in .env)
174
  python -m runners.baseline_runner | tee /tmp/baseline_run.json
175
  ```
176
 
177
  ## Hardware / Environment
178
 
179
- - Python 3.12.13, pytest 7.4.3
180
- - `openenv-core==0.2.3`, `pydantic==2.12.5`, `openai==2.31.0`
181
- - Provider: OpenRouter (model `openai/gpt-oss-120b`), all 7 decision calls succeeded, zero retries
182
- - Average end-to-end episode wall-clock: ~0.8s (heuristic), ~1.8s (with LLM tiebreak — down from
183
- ~2.5s in v1 because `_obvious_next_action` bypasses most model calls)
184
- - Full test suite: 22/22 passing, `openenv validate .` clean, Docker build clean
185
 
186
  ## What This Table Does Not Show
187
 
188
- - **Per-dimension score dispersion across the full catalog** — the table above shows one task's
189
- breakdown. An introspection demo command exists for walking `env.rubric.named_rubrics()` on
190
- any run: see `README.md` "Rubric introspection".
191
- - **RL training curves** — ChargebackOps is a ready environment, not a trained agent. Anyone
192
- wiring this into Gym/SB3/CleanRL is expected to produce training curves separately; the
193
- rubric tree is the machinery they would hook into for credit assignment.
 
 
 
 
 
 
1
+ # ChargebackOps — Benchmark Results
2
 
3
+ Reference numbers for the 10-task headline catalog and the 28-task
4
+ multi-seed stress grid against the current multi-round adversarial
5
+ environment. Reproduce with the commands at the bottom; scores match to
6
+ within ±1e-3 (float rounding).
 
7
 
8
+ Captured on **2026-04-19** on `main` with the 8-dimension case rubric
9
+ (weights `(0.20, 0.15, 0.10, 0.10, 0.10, 0.10, 0.05, 0.20)`,
10
+ `escalation_roi` dimension added) and the deterministic Issuer agent
11
+ (LLM softening disabled — benchmarks stay fully offline).
12
 
13
+ ## TL;DR
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ | Policy | Headline avg (10 tasks) | Multi-seed avg (28 tasks) | Provider calls |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  | --- | --- | --- | --- |
17
+ | **naive** (empty packet submit) | **0.0000** | **0.0000** | 0 |
18
+ | **concede_all** (always `accept_chargeback`) | **0.5666** | **0.5634** | 0 |
19
+ | **escalate_all** (contest, then always escalate) | **0.7731** | **0.7647** | 0 |
20
+ | **heuristic** (first-candidate rule-based pick) | **0.7731** | **0.7647** | 0 |
21
+
22
+ **Discrimination delta** (heuristic − naive) is **0.7731** on the headline
23
+ catalog and **0.7647** on the multi-seed grid — well above the 0.40 target.
24
+
25
+ `escalate_all` ties with `heuristic` because the heuristic wins the
26
+ representment on most tasks in the first review; the environment never
27
+ enters the pre-arbitration branch and the escalation override never
28
+ fires. That match is a signal, not a bug: when the scripted merchant
29
+ packet is strong, escalation is never rational in the current
30
+ deterministic Issuer, so the two policies produce identical trajectories.
31
+
32
+ ## Score Curve by Difficulty (multi-seed grid, 7 seeds / difficulty)
33
+
34
+ | Difficulty | n | heuristic | escalate_all | concede_all | naive |
35
+ | --- | --- | --- | --- | --- | --- |
36
+ | easy | 7 | 0.974 | 0.974 | 0.470 | 0.000 |
37
+ | medium | 7 | 0.876 | 0.876 | 0.699 | 0.000 |
38
+ | hard | 7 | 0.701 | 0.701 | 0.584 | 0.000 |
39
+ | nightmare | 7 | 0.508 | 0.508 | 0.501 | 0.000 |
40
 
41
  Observations:
42
+ - Heuristic score decreases monotonically with difficulty
43
+ (0.97 → 0.88 → 0.70 → 0.51). The difficulty gradient is real.
44
+ - `concede_all` narrows the gap at nightmare (0.508 vs 0.501) because
45
+ the 15-step budget vs. 5-case portfolio forces the heuristic to
46
+ forfeit cases deadline-wise, while conceding is cheap per case.
47
+ This is the expected `Gate(CaseAbandonedRubric)` behavior.
48
+ - `naive` sits flat at 0.000 because an empty packet fails the
49
+ packet-validity gate and every case is scored as unresolved /
50
+ abandoned.
51
+
52
+ ## Headline Per-Task Table (10 tasks, offline)
53
+
54
+ | Task ID | Difficulty | heuristic | escalate_all | concede_all | naive |
55
+ | --- | --- | --- | --- | --- | --- |
56
+ | goods_not_received_easy | easy | 0.968 | 0.968 | 0.580 | 0.000 |
57
+ | fraud_signal_ambiguity | easy | 0.968 | 0.968 | 0.580 | 0.000 |
58
+ | queue_optimization_hard | hard | 0.802 | 0.802 | 0.576 | 0.000 |
59
+ | generated_easy_s42 | easy | 0.958 | 0.958 | 0.533 | 0.000 |
60
+ | generated_medium_s17 | medium | 0.861 | 0.861 | 0.623 | 0.000 |
61
+ | generated_medium_s99 | medium | 0.770 | 0.770 | 0.727 | 0.000 |
62
+ | generated_hard_s7 | hard | 0.724 | 0.724 | 0.615 | 0.000 |
63
+ | generated_hard_s53 | hard | 0.544 | 0.544 | 0.612 | 0.000 |
64
+ | generated_nightmare_s31 | nightmare | 0.602 | 0.602 | 0.529 | 0.000 |
65
+ | generated_nightmare_s77 | nightmare | 0.474 | 0.474 | 0.537 | 0.000 |
66
+ | **Average** | | **0.7731** | **0.7731** | **0.5666** | **0.0000** |
67
+
68
+ (Per-task numbers from `runners.benchmark_runner.run_policy_sweep()`.)
 
 
 
 
 
 
 
 
69
 
70
  ## Rubric Composition (what's wired)
71
 
 
75
  ├── deadline_gate: Gate(threshold=1.0) # hard-zero if case abandoned past deadline
76
  │ └── CaseAbandonedRubric
77
  └── aggregator: WeightedSum # weights sum to 1.0
78
+ ├── rubric_0: StrategyCorrectnessRubric # 0.20
79
+ ├── rubric_1: EvidenceQualityRubric # 0.15
80
+ ├── rubric_2: PacketValidityRubric # 0.10
81
+ ├── rubric_3: DeadlineComplianceRubric # 0.10
82
+ ├── rubric_4: EfficiencyRubric # 0.10
83
+ ├── rubric_5: OutcomeQualityRubric # 0.10
84
+ ── rubric_6: NoteQualityRubric # 0.05
85
+ └── rubric_7: EscalationROIRubric # 0.20
86
  ```
87
 
88
+ Every node is an OpenEnv `Rubric` subclass and every node exposes
89
+ `last_score` after forward. `env.rubric.named_rubrics()` walks the tree
90
+ and returns the hook-compatible surface for a judge or trainer to
91
+ introspect per-dimension scores.
92
+
93
+ `EscalationROIRubric` encodes the economic rule that escalating to
94
+ network arbitration is rational only when
95
+ `P(win) × dispute_amount > arb_fee` (fee = $250/side). Scripted policies
96
+ that escalate negative-EV cases (or concede positive-EV cases) are
97
+ penalised on this axis.
98
 
99
  ## Reproducing These Numbers
100
 
101
  ```bash
 
102
  source ~/python/bin/activate
103
 
 
104
  python - <<'PY'
105
+ from runners.benchmark_runner import run_policy_sweep, run_multi_seed
106
+
107
+ headline = run_policy_sweep()
108
+ print("HEADLINE (10 tasks)")
109
+ for s in headline.policies:
110
+ print(f" {s.policy:14s} mean={s.mean_score:.4f} stdev={s.stdev:.4f}")
111
+ print(f" delta (heuristic - naive): {headline.discrimination_delta}")
112
+
113
+ grid = run_multi_seed(
114
+ seeds=[7, 17, 31, 42, 53, 77, 99],
115
+ difficulties=["easy", "medium", "hard", "nightmare"],
116
+ )
117
+ print("MULTI-SEED (28 tasks)")
118
+ for s in grid.policies:
119
+ print(f" {s.policy:14s} mean={s.mean_score:.4f} stdev={s.stdev:.4f}")
120
+ print(f" delta (heuristic - naive): {grid.discrimination_delta}")
121
  PY
122
+ ```
123
 
124
+ Optional LLM-assisted baseline (requires `OPENROUTER_API_KEY`):
 
 
 
 
 
 
 
 
 
 
125
 
126
+ ```bash
127
  python -m runners.baseline_runner | tee /tmp/baseline_run.json
128
  ```
129
 
130
  ## Hardware / Environment
131
 
132
+ - Python 3.12, pytest 8.x
133
+ - `openenv-core`, `pydantic`, `openai` per `pyproject.toml`
134
+ - No provider calls for the four scripted policies all results fully offline
135
+ - Full test suite: **65/65 passing**
 
 
136
 
137
  ## What This Table Does Not Show
138
 
139
+ - **Per-dimension score dispersion across the full catalog** — the
140
+ headline table aggregates to one scalar per task. Walk
141
+ `env.rubric.named_rubrics()` on any run for the per-dimension
142
+ introspection path.
143
+ - **LLM-trained merchant curves** this environment is the substrate;
144
+ training curves are produced separately by the TRL notebook.
145
+ - **Adversarial Issuer with LLM softening enabled** — softening is
146
+ gated on API keys. With keys set, the Issuer can override the
147
+ deterministic midpoint in the ambiguity band; that configuration is
148
+ tested in `tests/test_llm_softening.py` but is not part of the
149
+ offline benchmark numbers above.
notebooks/train_merchant_agent.ipynb ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Train Merchant Agent on ChargebackOps\n",
8
+ "\n",
9
+ "End-to-end GRPO training skeleton for the merchant-side chargeback agent.\n",
10
+ "\n",
11
+ "- Environment: `ChargebackOpsEnvironment` (multi-round adversarial Issuer, arbitration ROI).\n",
12
+ "- Text interface: `training.env_adapter` (prompt build, completion parse).\n",
13
+ "- Reward: `training.reward_adapter.compute_reward` — returns the normalised episode score in `[0, 1]`.\n",
14
+ "- Trainer: `trl.GRPOTrainer` on a small base model so this fits a free Colab T4.\n",
15
+ "\n",
16
+ "The first run is intentionally tiny (1 step, micro-batch of 2). Once the wiring is green, bump `max_steps` to 200 for the real curve."
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "markdown",
21
+ "metadata": {},
22
+ "source": [
23
+ "## 1. Colab setup\n",
24
+ "\n",
25
+ "Installs TRL, transformers, and the ChargebackOps package itself. Skip if the environment already has them."
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "%%capture\n",
35
+ "import sys\n",
36
+ "if 'google.colab' in sys.modules:\n",
37
+ " !pip install --quiet trl==0.11.4 transformers==4.44.2 accelerate==0.33.0 peft==0.12.0 bitsandbytes==0.43.3\n",
38
+ " !git clone https://github.com/example/chargebackops.git /content/chargebackops\n",
39
+ " %cd /content/chargebackops\n",
40
+ " !pip install --quiet -e ."
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "metadata": {},
46
+ "source": [
47
+ "## 2. Sanity-check the env adapter\n",
48
+ "\n",
49
+ "Run one scripted episode via the text adapter to confirm prompts render and rewards land inside `[0, 1]`."
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "from training.env_adapter import build_prompt\n",
59
+ "from training.reward_adapter import run_episode_with_text_policy\n",
60
+ "\n",
61
+ "def heuristic_text_policy(prompt: str) -> str:\n",
62
+ " # Force the fallback path so the scripted heuristic drives the episode.\n",
63
+ " return ''\n",
64
+ "\n",
65
+ "result = run_episode_with_text_policy('goods_not_received_easy', heuristic_text_policy)\n",
66
+ "print('score', result.score, 'steps', result.steps_used, 'invalid', result.invalid_actions)"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "markdown",
71
+ "metadata": {},
72
+ "source": [
73
+ "## 3. Load a small base model\n",
74
+ "\n",
75
+ "`Qwen/Qwen2.5-0.5B-Instruct` fits on a free T4 with LoRA adapters. Swap in a bigger instruct model if you have the memory budget."
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "import torch\n",
85
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
86
+ "\n",
87
+ "MODEL_ID = 'Qwen/Qwen2.5-0.5B-Instruct'\n",
88
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
89
+ "if tokenizer.pad_token is None:\n",
90
+ " tokenizer.pad_token = tokenizer.eos_token\n",
91
+ "model = AutoModelForCausalLM.from_pretrained(\n",
92
+ " MODEL_ID,\n",
93
+ " torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,\n",
94
+ " device_map='auto',\n",
95
+ ")"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "markdown",
100
+ "metadata": {},
101
+ "source": [
102
+ "## 4. Build the training prompt dataset\n",
103
+ "\n",
104
+ "GRPO expects a list of prompts; it generates K completions per prompt internally and scores each with `compute_reward`. We sample prompts from fresh environment resets across the headline catalog."
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "from datasets import Dataset\n",
114
+ "from scenarios.simulation import list_tasks\n",
115
+ "from server.chargeback_ops_environment import ChargebackOpsEnvironment\n",
116
+ "from training.env_adapter import build_prompt\n",
117
+ "\n",
118
+ "def sample_prompts(n: int = 32):\n",
119
+ " tasks = list_tasks()\n",
120
+ " prompts, task_ids = [], []\n",
121
+ " for i in range(n):\n",
122
+ " task = tasks[i % len(tasks)]\n",
123
+ " env = ChargebackOpsEnvironment()\n",
124
+ " obs = env.reset(task_id=task.task_id).model_dump()\n",
125
+ " prompts.append(build_prompt(obs))\n",
126
+ " task_ids.append(task.task_id)\n",
127
+ " return Dataset.from_dict({'prompt': prompts, 'task_id': task_ids})\n",
128
+ "\n",
129
+ "train_dataset = sample_prompts(32)\n",
130
+ "len(train_dataset)"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "markdown",
135
+ "metadata": {},
136
+ "source": [
137
+ "## 5. GRPO training step\n",
138
+ "\n",
139
+ "Starts with `max_steps=1` — just verify the gradient path closes. Increase to 200 for the real curve."
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": null,
145
+ "metadata": {},
146
+ "outputs": [],
147
+ "source": [
148
+ "from trl import GRPOConfig, GRPOTrainer\n",
149
+ "from training.reward_adapter import compute_reward\n",
150
+ "\n",
151
+ "def reward_fn(prompts, completions, **kwargs):\n",
152
+ " task_ids = kwargs.get('task_id') or kwargs.get('task_ids')\n",
153
+ " return compute_reward(prompts, completions, task_ids=task_ids)\n",
154
+ "\n",
155
+ "config = GRPOConfig(\n",
156
+ " output_dir='./grpo-merchant-agent',\n",
157
+ " per_device_train_batch_size=2,\n",
158
+ " num_generations=4,\n",
159
+ " max_prompt_length=1024,\n",
160
+ " max_completion_length=128,\n",
161
+ " learning_rate=5e-6,\n",
162
+ " max_steps=1,\n",
163
+ " logging_steps=1,\n",
164
+ " save_steps=50,\n",
165
+ " gradient_accumulation_steps=1,\n",
166
+ " bf16=torch.cuda.is_available(),\n",
167
+ " report_to='none',\n",
168
+ ")\n",
169
+ "trainer = GRPOTrainer(\n",
170
+ " model=model,\n",
171
+ " processing_class=tokenizer,\n",
172
+ " reward_funcs=[reward_fn],\n",
173
+ " args=config,\n",
174
+ " train_dataset=train_dataset,\n",
175
+ ")\n",
176
+ "trainer.train()"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "markdown",
181
+ "metadata": {},
182
+ "source": [
183
+ "## 6. Evaluate the trained policy\n",
184
+ "\n",
185
+ "Runs one rollout per headline task with the trained model as the text policy and reports the per-task scores plus the overall mean."
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": null,
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": [
194
+ "import torch\n",
195
+ "from statistics import mean\n",
196
+ "\n",
197
+ "def model_text_policy(prompt: str) -> str:\n",
198
+ " inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024).to(model.device)\n",
199
+ " with torch.no_grad():\n",
200
+ " out = model.generate(**inputs, max_new_tokens=128, do_sample=False, temperature=0.0, pad_token_id=tokenizer.pad_token_id)\n",
201
+ " return tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
202
+ "\n",
203
+ "scores = []\n",
204
+ "for task in list_tasks():\n",
205
+ " result = run_episode_with_text_policy(task.task_id, model_text_policy)\n",
206
+ " scores.append(result.score)\n",
207
+ " print(f'{task.task_id:32s} score={result.score:.4f} steps={result.steps_used} invalid={result.invalid_actions}')\n",
208
+ "print('mean score', mean(scores))"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "markdown",
213
+ "metadata": {},
214
+ "source": [
215
+ "## Next steps\n",
216
+ "\n",
217
+ "1. Bump `max_steps` in step 5 to 200 (save checkpoints at 0/50/100/150/200).\n",
218
+ "2. Record per-checkpoint mean score and plot the curve.\n",
219
+ "3. Compare against the fixed-policy baselines in `runners/benchmark_runner.py`.\n"
220
+ ]
221
+ }
222
+ ],
223
+ "metadata": {
224
+ "kernelspec": {
225
+ "display_name": "Python 3",
226
+ "language": "python",
227
+ "name": "python3"
228
+ },
229
+ "language_info": {
230
+ "name": "python",
231
+ "version": "3.12"
232
+ }
233
+ },
234
+ "nbformat": 4,
235
+ "nbformat_minor": 5
236
+ }
runners/benchmark_runner.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scripted-policy benchmark runner for ChargebackOps.
2
+
3
+ Drives a fixed set of non-learning policies through the full environment so
4
+ the trained-merchant vs. baseline discrimination delta can be measured
5
+ without calling an LLM provider. Every policy returned here is deterministic
6
+ and offline.
7
+
8
+ Policies
9
+ --------
10
+ * ``heuristic`` — the Round 1 first-candidate pick (best scripted baseline).
11
+ * ``concede_all`` — always set strategy to ``accept_chargeback`` and resolve.
12
+ * ``escalate_all`` — contest like the heuristic, then escalate in the
13
+ pre-arb and arbitration steps regardless of evidence strength.
14
+ * ``naive`` — submit an empty packet / take a minimal path to terminal.
15
+
16
+ The runner also exposes :func:`run_multi_seed` which sweeps each policy
17
+ over the headline catalog plus extra generator seeds so the benchmark
18
+ table in ``docs/RESULTS_V2.md`` is reproducible from one command.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ from dataclasses import dataclass
24
+ from statistics import mean, pstdev
25
+ from typing import Any, Callable, Iterable, Sequence
26
+
27
+ try:
28
+ from ..core.models import ChargebackOpsAction
29
+ from ..scenarios.simulation import TaskScenario, get_task, list_tasks
30
+ from ..server.chargeback_ops_environment import ChargebackOpsEnvironment
31
+ from .baseline_runner import candidate_actions
32
+ except ImportError: # pragma: no cover
33
+ from core.models import ChargebackOpsAction
34
+ from scenarios.simulation import TaskScenario, get_task, list_tasks
35
+ from server.chargeback_ops_environment import ChargebackOpsEnvironment
36
+ from runners.baseline_runner import candidate_actions
37
+
38
+
39
+ PolicyFn = Callable[[dict[str, Any]], ChargebackOpsAction | None]
40
+
41
+
42
+ POLICY_NAMES: tuple[str, ...] = (
43
+ "heuristic",
44
+ "escalate_all",
45
+ "concede_all",
46
+ "naive",
47
+ )
48
+
49
+
50
+ # ---------------------------------------------------------------------------
51
+ # Scripted policies
52
+ # ---------------------------------------------------------------------------
53
+
54
+
55
+ def heuristic_policy(observation: dict[str, Any]) -> ChargebackOpsAction | None:
56
+ """First-candidate pick from the existing candidate generator."""
57
+
58
+ candidates = candidate_actions(observation)
59
+ if not candidates:
60
+ return None
61
+ return candidates[0].action
62
+
63
+
64
+ def escalate_all_policy(observation: dict[str, Any]) -> ChargebackOpsAction | None:
65
+ """Play like the heuristic, but always push terminal disputes into arbitration."""
66
+
67
+ available = set(observation.get("available_actions", []))
68
+ visible_case = observation.get("visible_case")
69
+ if visible_case is not None and "escalate_to_arbitration" in available:
70
+ return ChargebackOpsAction(
71
+ action_type="escalate_to_arbitration",
72
+ case_id=visible_case["case_id"],
73
+ )
74
+ return heuristic_policy(observation)
75
+
76
+
77
+ def concede_all_policy(observation: dict[str, Any]) -> ChargebackOpsAction | None:
78
+ """Always accept the chargeback. Never contests, never escalates."""
79
+
80
+ available = set(observation.get("available_actions", []))
81
+ visible_case = observation.get("visible_case")
82
+ queue = observation.get("queue", [])
83
+
84
+ if visible_case is None:
85
+ open_cases = [item for item in queue if item["status"] == "open"]
86
+ if not open_cases:
87
+ return None
88
+ target = sorted(
89
+ open_cases,
90
+ key=lambda item: (item["steps_until_deadline"], -item["amount"]),
91
+ )[0]
92
+ return ChargebackOpsAction(
93
+ action_type="select_case", case_id=target["case_id"]
94
+ )
95
+
96
+ case_id = visible_case["case_id"]
97
+ if visible_case["status"] != "open":
98
+ open_cases = [
99
+ item
100
+ for item in queue
101
+ if item["status"] == "open" and item["case_id"] != case_id
102
+ ]
103
+ if not open_cases:
104
+ return None
105
+ target = sorted(
106
+ open_cases,
107
+ key=lambda item: (item["steps_until_deadline"], -item["amount"]),
108
+ )[0]
109
+ return ChargebackOpsAction(
110
+ action_type="select_case", case_id=target["case_id"]
111
+ )
112
+
113
+ if "accept_arbitration_loss" in available:
114
+ return ChargebackOpsAction(
115
+ action_type="accept_arbitration_loss", case_id=case_id
116
+ )
117
+
118
+ if visible_case.get("current_strategy") != "accept_chargeback" and (
119
+ "set_strategy" in available
120
+ ):
121
+ return ChargebackOpsAction(
122
+ action_type="set_strategy",
123
+ case_id=case_id,
124
+ strategy="accept_chargeback",
125
+ )
126
+
127
+ if "resolve_case" in available:
128
+ return ChargebackOpsAction(
129
+ action_type="resolve_case",
130
+ case_id=case_id,
131
+ strategy="accept_chargeback",
132
+ )
133
+
134
+ return heuristic_policy(observation)
135
+
136
+
137
+ def naive_policy(observation: dict[str, Any]) -> ChargebackOpsAction | None:
138
+ """Minimum-effort agent: select a case, submit without evidence or policy work."""
139
+
140
+ available = set(observation.get("available_actions", []))
141
+ visible_case = observation.get("visible_case")
142
+ queue = observation.get("queue", [])
143
+
144
+ if visible_case is None:
145
+ open_cases = [item for item in queue if item["status"] == "open"]
146
+ if not open_cases:
147
+ return None
148
+ return ChargebackOpsAction(
149
+ action_type="select_case", case_id=open_cases[0]["case_id"]
150
+ )
151
+
152
+ case_id = visible_case["case_id"]
153
+ if visible_case["status"] != "open":
154
+ open_cases = [
155
+ item
156
+ for item in queue
157
+ if item["status"] == "open" and item["case_id"] != case_id
158
+ ]
159
+ if not open_cases:
160
+ return None
161
+ return ChargebackOpsAction(
162
+ action_type="select_case", case_id=open_cases[0]["case_id"]
163
+ )
164
+
165
+ if "accept_arbitration_loss" in available:
166
+ return ChargebackOpsAction(
167
+ action_type="accept_arbitration_loss", case_id=case_id
168
+ )
169
+
170
+ if "submit_representment" in available:
171
+ return ChargebackOpsAction(
172
+ action_type="submit_representment", case_id=case_id
173
+ )
174
+
175
+ if "respond_to_pre_arb" in available:
176
+ return ChargebackOpsAction(
177
+ action_type="respond_to_pre_arb", case_id=case_id
178
+ )
179
+
180
+ if "resolve_case" in available:
181
+ return ChargebackOpsAction(
182
+ action_type="resolve_case",
183
+ case_id=case_id,
184
+ strategy="accept_chargeback",
185
+ )
186
+
187
+ return heuristic_policy(observation)
188
+
189
+
190
+ POLICY_REGISTRY: dict[str, PolicyFn] = {
191
+ "heuristic": heuristic_policy,
192
+ "escalate_all": escalate_all_policy,
193
+ "concede_all": concede_all_policy,
194
+ "naive": naive_policy,
195
+ }
196
+
197
+
198
+ # ---------------------------------------------------------------------------
199
+ # Episode / sweep driver
200
+ # ---------------------------------------------------------------------------
201
+
202
+
203
+ @dataclass(frozen=True)
204
+ class TaskScore:
205
+ """One policy × task result."""
206
+
207
+ policy: str
208
+ task_id: str
209
+ score: float
210
+ steps_used: int
211
+
212
+
213
+ @dataclass(frozen=True)
214
+ class PolicySummary:
215
+ """Aggregate of one policy across a task list."""
216
+
217
+ policy: str
218
+ mean_score: float
219
+ stdev: float
220
+ tasks: tuple[TaskScore, ...]
221
+
222
+
223
+ @dataclass(frozen=True)
224
+ class BenchmarkResult:
225
+ """Output of a full policy sweep."""
226
+
227
+ policies: tuple[PolicySummary, ...]
228
+ discrimination_delta: float # heuristic minus naive
229
+
230
+ def to_dict(self) -> dict[str, Any]:
231
+ return {
232
+ "discrimination_delta": self.discrimination_delta,
233
+ "policies": [
234
+ {
235
+ "policy": summary.policy,
236
+ "mean_score": summary.mean_score,
237
+ "stdev": summary.stdev,
238
+ "tasks": [
239
+ {
240
+ "task_id": task.task_id,
241
+ "score": task.score,
242
+ "steps_used": task.steps_used,
243
+ }
244
+ for task in summary.tasks
245
+ ],
246
+ }
247
+ for summary in self.policies
248
+ ],
249
+ }
250
+
251
+
252
+ def run_policy_on_task(policy: PolicyFn, task: TaskScenario) -> TaskScore:
253
+ """Drive one policy through one task. Fully offline, no LLM calls."""
254
+
255
+ env = ChargebackOpsEnvironment()
256
+ observation = env.reset(task_id=task.task_id)
257
+ max_steps = task.max_steps + 5 # small safety margin
258
+ steps = 0
259
+ while not observation.done and steps < max_steps:
260
+ action = policy(observation.model_dump())
261
+ if action is None:
262
+ break
263
+ observation = env.step(action)
264
+ steps += 1
265
+
266
+ report = env.state.grader_report
267
+ score = float(report.normalized_score) if report is not None else 0.0
268
+ return TaskScore(
269
+ policy=policy.__name__,
270
+ task_id=task.task_id,
271
+ score=score,
272
+ steps_used=env.state.step_count,
273
+ )
274
+
275
+
276
+ def run_policy_sweep(
277
+ policy_names: Sequence[str] = POLICY_NAMES,
278
+ tasks: Iterable[TaskScenario] | None = None,
279
+ ) -> BenchmarkResult:
280
+ """Run each named policy across the headline catalog (or provided tasks)."""
281
+
282
+ task_list = list(tasks) if tasks is not None else list_tasks()
283
+
284
+ summaries: list[PolicySummary] = []
285
+ for name in policy_names:
286
+ if name not in POLICY_REGISTRY:
287
+ raise KeyError(f"Unknown policy '{name}'. Known: {sorted(POLICY_REGISTRY)}")
288
+ policy = POLICY_REGISTRY[name]
289
+ task_scores: list[TaskScore] = []
290
+ for task in task_list:
291
+ score = run_policy_on_task(policy, task)
292
+ task_scores.append(
293
+ TaskScore(
294
+ policy=name,
295
+ task_id=score.task_id,
296
+ score=score.score,
297
+ steps_used=score.steps_used,
298
+ )
299
+ )
300
+ scores = [item.score for item in task_scores]
301
+ summaries.append(
302
+ PolicySummary(
303
+ policy=name,
304
+ mean_score=round(mean(scores), 4) if scores else 0.0,
305
+ stdev=round(pstdev(scores), 4) if len(scores) > 1 else 0.0,
306
+ tasks=tuple(task_scores),
307
+ )
308
+ )
309
+
310
+ by_name = {summary.policy: summary for summary in summaries}
311
+ delta = 0.0
312
+ if "heuristic" in by_name and "naive" in by_name:
313
+ delta = round(
314
+ by_name["heuristic"].mean_score - by_name["naive"].mean_score, 4
315
+ )
316
+ return BenchmarkResult(policies=tuple(summaries), discrimination_delta=delta)
317
+
318
+
319
+ def run_multi_seed(
320
+ seeds: Sequence[int],
321
+ difficulties: Sequence[str] = ("easy", "medium", "hard", "nightmare"),
322
+ policy_names: Sequence[str] = POLICY_NAMES,
323
+ ) -> BenchmarkResult:
324
+ """Sweep each policy over ``seeds × difficulties`` generated tasks.
325
+
326
+ Used for the multi-seed grid cited in the PRD's Day-5 exit criteria.
327
+ """
328
+
329
+ tasks: list[TaskScenario] = []
330
+ for difficulty in difficulties:
331
+ for seed in seeds:
332
+ task_id = f"generated_{difficulty}_s{seed}"
333
+ tasks.append(get_task(task_id))
334
+ return run_policy_sweep(policy_names, tasks=tasks)
335
+
336
+
337
+ __all__ = [
338
+ "POLICY_NAMES",
339
+ "POLICY_REGISTRY",
340
+ "PolicyFn",
341
+ "BenchmarkResult",
342
+ "PolicySummary",
343
+ "TaskScore",
344
+ "heuristic_policy",
345
+ "escalate_all_policy",
346
+ "concede_all_policy",
347
+ "naive_policy",
348
+ "run_policy_on_task",
349
+ "run_policy_sweep",
350
+ "run_multi_seed",
351
+ ]
scenarios/issuer_model.py CHANGED
@@ -147,15 +147,15 @@ class IssuerAgent:
147
  decision=IssuerDecision.ACCEPT,
148
  evidence_strength_score=score,
149
  rationale=(
150
- f"Round {round_number}: pre-arb evidence brings the packet "
151
- f"to {score:.2f}, above the 0.60 acceptance bar."
152
  ),
153
  )
154
  return IssuerReview(
155
  decision=IssuerDecision.ESCALATE_TO_ARBITRATION,
156
  evidence_strength_score=score,
157
  rationale=(
158
- f"Round {round_number}: packet still scores {score:.2f}; "
159
  f"escalating to network arbitration."
160
  ),
161
  )
@@ -166,7 +166,7 @@ class IssuerAgent:
166
  decision=IssuerDecision.ACCEPT,
167
  evidence_strength_score=score,
168
  rationale=(
169
- f"Round 1: packet scores {score:.2f}, clearing the 0.70 acceptance bar."
170
  ),
171
  )
172
  if score <= ROUND1_REJECT_THRESHOLD:
@@ -174,7 +174,7 @@ class IssuerAgent:
174
  decision=IssuerDecision.REQUEST_MORE_EVIDENCE,
175
  evidence_strength_score=score,
176
  rationale=(
177
- f"Round 1: packet scores {score:.2f}, below the 0.40 floor; "
178
  f"requesting compelling evidence."
179
  ),
180
  )
@@ -191,7 +191,7 @@ class IssuerAgent:
191
  decision=IssuerDecision.ACCEPT,
192
  evidence_strength_score=score,
193
  rationale=(
194
- f"Round 1 ambiguity band: packet scores {score:.2f} — "
195
  f"LLM softening accepted."
196
  ),
197
  used_llm_softening=True,
@@ -201,7 +201,7 @@ class IssuerAgent:
201
  decision=IssuerDecision.REQUEST_MORE_EVIDENCE,
202
  evidence_strength_score=score,
203
  rationale=(
204
- f"Round 1 ambiguity band: packet scores {score:.2f} — "
205
  f"LLM softening requested compelling evidence."
206
  ),
207
  used_llm_softening=True,
@@ -212,7 +212,7 @@ class IssuerAgent:
212
  decision=IssuerDecision.ACCEPT,
213
  evidence_strength_score=score,
214
  rationale=(
215
- f"Round 1 ambiguity band: packet scores {score:.2f} "
216
  f"(>= {ROUND1_MIDPOINT_FALLBACK:.2f} midpoint) — accepting."
217
  ),
218
  )
@@ -220,7 +220,7 @@ class IssuerAgent:
220
  decision=IssuerDecision.REQUEST_MORE_EVIDENCE,
221
  evidence_strength_score=score,
222
  rationale=(
223
- f"Round 1 ambiguity band: packet scores {score:.2f} "
224
  f"(< {ROUND1_MIDPOINT_FALLBACK:.2f} midpoint) — requesting more evidence."
225
  ),
226
  )
 
147
  decision=IssuerDecision.ACCEPT,
148
  evidence_strength_score=score,
149
  rationale=(
150
+ f"Pre-arb evidence brings the packet to {score:.2f}, "
151
+ f"above the 0.60 acceptance bar."
152
  ),
153
  )
154
  return IssuerReview(
155
  decision=IssuerDecision.ESCALATE_TO_ARBITRATION,
156
  evidence_strength_score=score,
157
  rationale=(
158
+ f"Packet still scores {score:.2f}; "
159
  f"escalating to network arbitration."
160
  ),
161
  )
 
166
  decision=IssuerDecision.ACCEPT,
167
  evidence_strength_score=score,
168
  rationale=(
169
+ f"Packet scores {score:.2f}, clearing the 0.70 acceptance bar."
170
  ),
171
  )
172
  if score <= ROUND1_REJECT_THRESHOLD:
 
174
  decision=IssuerDecision.REQUEST_MORE_EVIDENCE,
175
  evidence_strength_score=score,
176
  rationale=(
177
+ f"Packet scores {score:.2f}, below the 0.40 floor; "
178
  f"requesting compelling evidence."
179
  ),
180
  )
 
191
  decision=IssuerDecision.ACCEPT,
192
  evidence_strength_score=score,
193
  rationale=(
194
+ f"Ambiguity band: packet scores {score:.2f} — "
195
  f"LLM softening accepted."
196
  ),
197
  used_llm_softening=True,
 
201
  decision=IssuerDecision.REQUEST_MORE_EVIDENCE,
202
  evidence_strength_score=score,
203
  rationale=(
204
+ f"Ambiguity band: packet scores {score:.2f} — "
205
  f"LLM softening requested compelling evidence."
206
  ),
207
  used_llm_softening=True,
 
212
  decision=IssuerDecision.ACCEPT,
213
  evidence_strength_score=score,
214
  rationale=(
215
+ f"Ambiguity band: packet scores {score:.2f} "
216
  f"(>= {ROUND1_MIDPOINT_FALLBACK:.2f} midpoint) — accepting."
217
  ),
218
  )
 
220
  decision=IssuerDecision.REQUEST_MORE_EVIDENCE,
221
  evidence_strength_score=score,
222
  rationale=(
223
+ f"Ambiguity band: packet scores {score:.2f} "
224
  f"(< {ROUND1_MIDPOINT_FALLBACK:.2f} midpoint) — requesting more evidence."
225
  ),
226
  )
tests/test_benchmark_runner.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for the scripted-policy benchmark runner.
2
+
3
+ The runner drives a fixed set of non-learning policies through the full
4
+ environment without LLM calls. These tests pin:
5
+
6
+ 1. Each policy returns valid action or None offline.
7
+ 2. The aggregator produces per-policy means and a discrimination delta.
8
+ 3. The headline policy sweep keeps the heuristic ≥ 0.40 above the naive floor.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from runners.benchmark_runner import (
14
+ POLICY_NAMES,
15
+ POLICY_REGISTRY,
16
+ concede_all_policy,
17
+ escalate_all_policy,
18
+ heuristic_policy,
19
+ naive_policy,
20
+ run_multi_seed,
21
+ run_policy_on_task,
22
+ run_policy_sweep,
23
+ )
24
+ from scenarios.simulation import get_task
25
+
26
+
27
+ _EASY_TASK = get_task("goods_not_received_easy")
28
+
29
+
30
+ def test_policy_registry_matches_public_names():
31
+ assert set(POLICY_NAMES) == set(POLICY_REGISTRY)
32
+ assert set(POLICY_NAMES) == {"heuristic", "escalate_all", "concede_all", "naive"}
33
+
34
+
35
+ def test_heuristic_scores_above_naive_on_easy():
36
+ heur = run_policy_on_task(heuristic_policy, _EASY_TASK)
37
+ nv = run_policy_on_task(naive_policy, _EASY_TASK)
38
+ assert heur.score > nv.score
39
+ assert heur.task_id == _EASY_TASK.task_id
40
+ assert heur.steps_used > 0
41
+
42
+
43
+ def test_concede_all_lands_final_resolution():
44
+ """concede_all must always terminate the episode with a concede path."""
45
+ result = run_policy_on_task(concede_all_policy, _EASY_TASK)
46
+ assert result.steps_used > 0
47
+ # concede_all scores strictly below heuristic but must stay in [0, 1].
48
+ assert 0.0 <= result.score <= 1.0
49
+
50
+
51
+ def test_escalate_all_runs_to_completion():
52
+ result = run_policy_on_task(escalate_all_policy, _EASY_TASK)
53
+ assert 0.0 <= result.score <= 1.0
54
+ assert result.steps_used > 0
55
+
56
+
57
+ def test_sweep_aggregates_and_produces_delta():
58
+ result = run_policy_sweep()
59
+ policies = {summary.policy: summary for summary in result.policies}
60
+ assert set(policies) == set(POLICY_NAMES)
61
+ # mean scores sit in the valid range
62
+ for summary in result.policies:
63
+ assert 0.0 <= summary.mean_score <= 1.0
64
+ # discrimination delta is heuristic - naive and must clear the PRD bar.
65
+ assert result.discrimination_delta >= 0.40
66
+
67
+
68
+ def test_sweep_is_deterministic():
69
+ """Two runs on the same catalog must produce identical numbers."""
70
+ first = run_policy_sweep().to_dict()
71
+ second = run_policy_sweep().to_dict()
72
+ assert first == second
73
+
74
+
75
+ def test_multi_seed_sweep_runs_subset():
76
+ """Tiny grid (2 seeds × 1 difficulty) stays under a second and returns data."""
77
+ result = run_multi_seed(seeds=[42, 17], difficulties=["easy"])
78
+ for summary in result.policies:
79
+ assert len(summary.tasks) == 2
80
+ for task_score in summary.tasks:
81
+ assert task_score.task_id.startswith("generated_easy_s")
tests/test_training_adapter.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for the training adapter.
2
+
3
+ Pin the prompt/completion serialization and the episode-replay reward
4
+ signal so the training notebook has a stable offline contract.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+
11
+ from core.models import ChargebackOpsAction
12
+ from scenarios.simulation import get_task
13
+ from server.chargeback_ops_environment import ChargebackOpsEnvironment
14
+ from training.env_adapter import (
15
+ action_from_completion,
16
+ build_prompt,
17
+ parse_completion,
18
+ )
19
+ from training.reward_adapter import (
20
+ compute_reward,
21
+ run_episode_with_text_policy,
22
+ )
23
+
24
+
25
+ def _fresh_observation(task_id: str = "goods_not_received_easy"):
26
+ env = ChargebackOpsEnvironment()
27
+ return env.reset(task_id=task_id).model_dump()
28
+
29
+
30
+ def test_build_prompt_is_deterministic_and_includes_available_actions():
31
+ obs = _fresh_observation()
32
+ a = build_prompt(obs)
33
+ b = build_prompt(obs)
34
+ assert a == b
35
+ assert "available_actions" in a
36
+ assert "OBSERVATION:" in a
37
+ assert "ACTION:" in a
38
+
39
+
40
+ def test_parse_completion_accepts_plain_json():
41
+ payload = '{"action_type": "select_case", "case_id": "CB-X"}'
42
+ parsed = parse_completion(payload)
43
+ assert parsed == {"action_type": "select_case", "case_id": "CB-X"}
44
+
45
+
46
+ def test_parse_completion_strips_code_fence():
47
+ payload = '```json\n{"action_type": "select_case", "case_id": "CB-X"}\n```'
48
+ parsed = parse_completion(payload)
49
+ assert parsed == {"action_type": "select_case", "case_id": "CB-X"}
50
+
51
+
52
+ def test_parse_completion_returns_none_on_garbage():
53
+ assert parse_completion("") is None
54
+ assert parse_completion("not json at all") is None
55
+ assert parse_completion("{not-valid-json}") is None
56
+
57
+
58
+ def test_parse_completion_drops_unknown_fields():
59
+ payload = json.dumps({"action_type": "select_case", "hack_field": 42})
60
+ parsed = parse_completion(payload)
61
+ assert parsed == {"action_type": "select_case"}
62
+
63
+
64
+ def test_action_from_completion_returns_valid_action():
65
+ payload = '{"action_type": "select_case", "case_id": "CB-X"}'
66
+ action = action_from_completion(payload)
67
+ assert isinstance(action, ChargebackOpsAction)
68
+ assert action.action_type == "select_case"
69
+ assert action.case_id == "CB-X"
70
+
71
+
72
+ def test_action_from_completion_returns_none_on_bad_type():
73
+ payload = '{"action_type": "not_a_real_action"}'
74
+ assert action_from_completion(payload) is None
75
+
76
+
77
+ def test_run_episode_falls_back_to_heuristic_on_empty_completion():
78
+ """Unparseable completions must not deadlock the episode."""
79
+ result = run_episode_with_text_policy(
80
+ "goods_not_received_easy",
81
+ text_policy=lambda _prompt: "",
82
+ )
83
+ assert result.steps_used > 0
84
+ assert result.invalid_actions > 0
85
+ assert result.score > 0.0 # heuristic fallback still scores
86
+
87
+
88
+ def test_compute_reward_matches_episode_score():
89
+ """Single completion + heuristic tail reproduces the heuristic score."""
90
+ task = get_task("goods_not_received_easy")
91
+ prompts = ["unused"]
92
+ completions = [""] # triggers heuristic fallback on the first action
93
+ rewards = compute_reward(
94
+ prompts, completions, task_ids=[task.task_id]
95
+ )
96
+ assert len(rewards) == 1
97
+ assert 0.0 <= rewards[0] <= 1.0
98
+ assert rewards[0] > 0.5 # heuristic scores ~0.97 on this task
99
+
100
+
101
+ def test_compute_reward_rejects_mismatched_lengths():
102
+ import pytest
103
+
104
+ with pytest.raises(ValueError):
105
+ compute_reward(["a"], ["b", "c"], task_ids=["goods_not_received_easy"])
training/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training helpers for ChargebackOps.
2
+
3
+ Lightweight pure-Python wrappers that convert the environment into a
4
+ prompt/completion/reward interface compatible with TRL's GRPO trainer.
5
+ The module is import-safe without ``trl`` / ``torch`` installed so unit
6
+ tests stay fast and offline.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from .env_adapter import (
12
+ action_from_completion,
13
+ build_prompt,
14
+ parse_completion,
15
+ )
16
+ from .reward_adapter import (
17
+ EpisodeResult,
18
+ compute_reward,
19
+ run_episode_with_text_policy,
20
+ )
21
+
22
+ __all__ = [
23
+ "EpisodeResult",
24
+ "action_from_completion",
25
+ "build_prompt",
26
+ "compute_reward",
27
+ "parse_completion",
28
+ "run_episode_with_text_policy",
29
+ ]
training/env_adapter.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Text prompt / completion adapter for the merchant policy.
2
+
3
+ Serialize an observation into a compact prompt the model can condition
4
+ on, and parse a JSON completion back into a typed
5
+ ``ChargebackOpsAction``. Both helpers are pure — no provider calls, no
6
+ side effects — so they are cheap to unit-test.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import json
12
+ from typing import Any
13
+
14
+ try:
15
+ from ..core.models import ChargebackOpsAction
16
+ except ImportError: # pragma: no cover
17
+ from core.models import ChargebackOpsAction
18
+
19
+
20
+ _SYSTEM_INSTRUCTION = (
21
+ "You play the merchant-side agent in a chargeback dispute. "
22
+ "Look at the observation and choose the single best next action. "
23
+ "Return JSON only: "
24
+ '{"action_type": "...", "case_id": "...", "strategy": "...", '
25
+ '"evidence_ids": [...], "note": "..."} '
26
+ "Use only action_types listed in available_actions. Omit fields you "
27
+ "do not need."
28
+ )
29
+
30
+
31
+ _ALLOWED_ACTION_FIELDS: frozenset[str] = frozenset(
32
+ {
33
+ "action_type",
34
+ "case_id",
35
+ "system_name",
36
+ "evidence_ids",
37
+ "compelling_evidence_ids",
38
+ "strategy",
39
+ "note",
40
+ }
41
+ )
42
+
43
+
44
+ def _compact_observation(observation: dict[str, Any]) -> dict[str, Any]:
45
+ """Drop fields that add tokens without signal for the merchant policy."""
46
+
47
+ visible_case = observation.get("visible_case")
48
+ compact_case: dict[str, Any] | None = None
49
+ if visible_case is not None:
50
+ compact_case = {
51
+ "case_id": visible_case["case_id"],
52
+ "status": visible_case["status"],
53
+ "reason_code": visible_case["reason_code"],
54
+ "amount": visible_case["amount"],
55
+ "currency": visible_case["currency"],
56
+ "current_strategy": visible_case.get("current_strategy"),
57
+ "systems_revealed": visible_case.get("systems_revealed", []),
58
+ "retrieved_evidence": [
59
+ {
60
+ "evidence_id": item["evidence_id"],
61
+ "source_system": item["source_system"],
62
+ "title": item["title"],
63
+ }
64
+ for item in visible_case.get("retrieved_evidence", [])
65
+ ],
66
+ "attached_evidence": [
67
+ item["evidence_id"]
68
+ for item in visible_case.get("attached_evidence", [])
69
+ ],
70
+ "policy": visible_case.get("policy"),
71
+ }
72
+
73
+ return {
74
+ "objective": observation.get("objective", ""),
75
+ "selected_case_id": observation.get("selected_case_id"),
76
+ "available_actions": observation.get("available_actions", []),
77
+ "steps_remaining": observation.get("steps_remaining", 0),
78
+ "queue": [
79
+ {
80
+ "case_id": item["case_id"],
81
+ "status": item["status"],
82
+ "reason_code": item["reason_code"],
83
+ "amount": item["amount"],
84
+ "steps_until_deadline": item["steps_until_deadline"],
85
+ }
86
+ for item in observation.get("queue", [])
87
+ ],
88
+ "visible_case": compact_case,
89
+ "last_action_result": observation.get("last_action_result", ""),
90
+ }
91
+
92
+
93
+ def build_prompt(observation: dict[str, Any]) -> str:
94
+ """Return a deterministic prompt for the merchant policy."""
95
+
96
+ compact = _compact_observation(observation)
97
+ body = json.dumps(compact, separators=(",", ":"), sort_keys=True)
98
+ return f"{_SYSTEM_INSTRUCTION}\nOBSERVATION:\n{body}\nACTION:"
99
+
100
+
101
+ def parse_completion(text: str) -> dict[str, Any] | None:
102
+ """Parse a model completion into a raw action dict, or return None."""
103
+
104
+ if not text:
105
+ return None
106
+ cleaned = text.strip()
107
+ # Strip common code-fence patterns.
108
+ if cleaned.startswith("```"):
109
+ cleaned = cleaned.strip("`").strip()
110
+ if cleaned.lower().startswith("json"):
111
+ cleaned = cleaned[4:].lstrip()
112
+ # Find the first {...} block so prose before JSON is tolerated.
113
+ start = cleaned.find("{")
114
+ end = cleaned.rfind("}")
115
+ if start == -1 or end == -1 or end <= start:
116
+ return None
117
+ try:
118
+ data = json.loads(cleaned[start : end + 1])
119
+ except json.JSONDecodeError:
120
+ return None
121
+ if not isinstance(data, dict):
122
+ return None
123
+ return {k: v for k, v in data.items() if k in _ALLOWED_ACTION_FIELDS}
124
+
125
+
126
+ def action_from_completion(text: str) -> ChargebackOpsAction | None:
127
+ """Parse a completion and build a validated :class:`ChargebackOpsAction`."""
128
+
129
+ parsed = parse_completion(text)
130
+ if parsed is None or "action_type" not in parsed:
131
+ return None
132
+ try:
133
+ return ChargebackOpsAction(**parsed)
134
+ except Exception:
135
+ return None
136
+
137
+
138
+ __all__ = [
139
+ "action_from_completion",
140
+ "build_prompt",
141
+ "parse_completion",
142
+ ]
training/reward_adapter.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reward adapter for GRPO / RL training on ChargebackOps.
2
+
3
+ Exposes a callable shape compatible with TRL's GRPO reward function:
4
+
5
+ ``reward_fn(prompts, completions, **kwargs) -> list[float]``
6
+
7
+ Each completion is parsed into an action sequence (one action per line
8
+ is the simplest case; the helper also accepts a single-action
9
+ completion and runs the remainder of the episode under the scripted
10
+ heuristic so training always produces a terminal score). The resulting
11
+ reward is the episode's deterministic normalized grade in ``[0, 1]``.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from dataclasses import dataclass, field
17
+ from typing import Any, Callable, Sequence
18
+
19
+ try:
20
+ from ..core.models import ChargebackOpsAction, ChargebackOpsObservation
21
+ from ..scenarios.simulation import get_task, list_tasks
22
+ from ..server.chargeback_ops_environment import ChargebackOpsEnvironment
23
+ from .env_adapter import action_from_completion, build_prompt
24
+ except ImportError: # pragma: no cover
25
+ from core.models import ChargebackOpsAction, ChargebackOpsObservation
26
+ from scenarios.simulation import get_task, list_tasks
27
+ from server.chargeback_ops_environment import ChargebackOpsEnvironment
28
+ from training.env_adapter import action_from_completion, build_prompt
29
+
30
+
31
+ TextPolicyFn = Callable[[str], str]
32
+
33
+
34
+ @dataclass(frozen=True)
35
+ class EpisodeResult:
36
+ """Outcome of a single rollout."""
37
+
38
+ task_id: str
39
+ score: float
40
+ steps_used: int
41
+ invalid_actions: int
42
+ prompts: tuple[str, ...] = field(default_factory=tuple)
43
+ completions: tuple[str, ...] = field(default_factory=tuple)
44
+
45
+
46
+ def _fallback_action(
47
+ observation: ChargebackOpsObservation,
48
+ ) -> ChargebackOpsAction | None:
49
+ """Scripted fallback when the model output is unparseable."""
50
+
51
+ try:
52
+ from ..runners.benchmark_runner import heuristic_policy
53
+ except ImportError: # pragma: no cover
54
+ from runners.benchmark_runner import heuristic_policy
55
+ return heuristic_policy(observation.model_dump())
56
+
57
+
58
+ def run_episode_with_text_policy(
59
+ task_id: str,
60
+ text_policy: TextPolicyFn,
61
+ *,
62
+ max_steps: int | None = None,
63
+ capture_trace: bool = False,
64
+ ) -> EpisodeResult:
65
+ """Roll one episode forward under a text-in / text-out policy.
66
+
67
+ The policy is invoked at every step. If the completion fails to
68
+ parse into a valid action the scripted heuristic is used instead;
69
+ this keeps early-training trajectories from deadlocking.
70
+ """
71
+
72
+ task = get_task(task_id)
73
+ env = ChargebackOpsEnvironment()
74
+ observation = env.reset(task_id=task_id)
75
+ step_budget = (max_steps if max_steps is not None else task.max_steps) + 5
76
+ steps = 0
77
+ invalid = 0
78
+ prompts: list[str] = []
79
+ completions: list[str] = []
80
+
81
+ while not observation.done and steps < step_budget:
82
+ obs_dict = observation.model_dump()
83
+ prompt = build_prompt(obs_dict)
84
+ completion = text_policy(prompt)
85
+ action = action_from_completion(completion)
86
+ if action is None:
87
+ invalid += 1
88
+ action = _fallback_action(observation)
89
+ if action is None:
90
+ break
91
+ observation = env.step(action)
92
+ steps += 1
93
+ if capture_trace:
94
+ prompts.append(prompt)
95
+ completions.append(completion)
96
+
97
+ report = env.state.grader_report
98
+ score = float(report.normalized_score) if report is not None else 0.0
99
+ return EpisodeResult(
100
+ task_id=task_id,
101
+ score=score,
102
+ steps_used=env.state.step_count,
103
+ invalid_actions=invalid,
104
+ prompts=tuple(prompts),
105
+ completions=tuple(completions),
106
+ )
107
+
108
+
109
+ def compute_reward(
110
+ prompts: Sequence[str],
111
+ completions: Sequence[str],
112
+ *,
113
+ task_ids: Sequence[str] | None = None,
114
+ **_: Any,
115
+ ) -> list[float]:
116
+ """GRPO-style reward function.
117
+
118
+ Each ``completion`` is replayed as a *single* action. The remainder
119
+ of the episode is driven by the scripted heuristic, so the reward
120
+ signal rewards the model for picking a good first move from a
121
+ given observation. This matches the behaviour TRL expects: one
122
+ ``(prompt, completion)`` pair → one scalar reward.
123
+
124
+ ``task_ids`` optionally binds each prompt to a task id for env
125
+ replay. When omitted, the headline catalog is cycled.
126
+ """
127
+
128
+ if task_ids is None:
129
+ headline = [task.task_id for task in list_tasks()]
130
+ task_ids = [headline[i % len(headline)] for i in range(len(prompts))]
131
+ if len(task_ids) != len(prompts) or len(prompts) != len(completions):
132
+ raise ValueError(
133
+ "prompts, completions, and task_ids must all have the same length"
134
+ )
135
+
136
+ rewards: list[float] = []
137
+ for task_id, completion in zip(task_ids, completions):
138
+ first_action = action_from_completion(completion)
139
+
140
+ def _once(_prompt: str, _used=[False], _action=first_action) -> str:
141
+ if _used[0] or _action is None:
142
+ return ""
143
+ _used[0] = True
144
+ return completion
145
+
146
+ result = run_episode_with_text_policy(task_id, _once)
147
+ rewards.append(result.score)
148
+ return rewards
149
+
150
+
151
+ __all__ = [
152
+ "EpisodeResult",
153
+ "TextPolicyFn",
154
+ "compute_reward",
155
+ "run_episode_with_text_policy",
156
+ ]