Commit ·
bd00c06
1
Parent(s): 06abe10
Add training notebook and benchmark runner for ChargebackOps
- Introduced `train_merchant_agent.ipynb` for end-to-end GRPO training of the merchant-side chargeback agent, including environment setup, model loading, training prompt dataset creation, and evaluation.
- Created `benchmark_runner.py` to implement scripted policies for benchmarking against the trained agent, including heuristic, escalate_all, concede_all, and naive policies.
- Updated `issuer_model.py` to refine rationale messages for issuer decisions.
- Added unit tests for the benchmark runner and training adapter to ensure policy validity and reward computation accuracy.
- Implemented training helpers in `training/__init__.py`, `env_adapter.py`, and `reward_adapter.py` to facilitate interaction with the GRPO trainer.
- docs/RESULTS.md +113 -157
- notebooks/train_merchant_agent.ipynb +236 -0
- runners/benchmark_runner.py +351 -0
- scenarios/issuer_model.py +9 -9
- tests/test_benchmark_runner.py +81 -0
- tests/test_training_adapter.py +105 -0
- training/__init__.py +29 -0
- training/env_adapter.py +142 -0
- training/reward_adapter.py +156 -0
docs/RESULTS.md
CHANGED
|
@@ -1,125 +1,71 @@
|
|
| 1 |
-
# ChargebackOps —
|
| 2 |
|
| 3 |
-
Reference numbers for the 10-task headline
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
bottom; headline scores match to within ±1e-3 (float rounding).
|
| 8 |
|
| 9 |
-
|
| 10 |
|
| 11 |
-
|
| 12 |
-
| --- | --- | --- | --- | --- |
|
| 13 |
-
| **Bad policy** (concede-everything) | **0.199** | `generated_medium_s99` (0.442) | `generated_nightmare_s77` (0.053) | 0 |
|
| 14 |
-
| **Heuristic** (no LLM, rule-based) | **0.724** | `goods_not_received_easy` / `fraud_signal_ambiguity` (0.968) | `generated_hard_s53` (0.440) | 0 |
|
| 15 |
-
| **Heuristic + LLM tiebreak** (openrouter gpt-oss-120b) | **0.729** | `goods_not_received_easy` / `fraud_signal_ambiguity` / `generated_easy_s42` (0.958) | `generated_hard_s53` (0.440) | 7 (7 ✓ / 0 ✗) |
|
| 16 |
-
|
| 17 |
-
**Key signal:** the bad policy vs. heuristic delta is **0.525** (72.4 → 19.9 = 264% spread).
|
| 18 |
-
The `Gate(CaseAbandonedRubric)` wrapper around the per-case `WeightedSum` means a case left
|
| 19 |
-
unresolved past its deadline hard-zeros — a lazy concede-everything agent cannot game the score,
|
| 20 |
-
and a correct agent cannot trivially saturate it on hard tasks. The LLM-assisted run now edges
|
| 21 |
-
ahead of the pure heuristic (+0.005) after the v1.1 prompt and `_obvious_next_action` upgrades;
|
| 22 |
-
the LLM is invoked only **7 times** across the 10-task run (down from 19 in v1) because
|
| 23 |
-
deterministic workflow states are now dispatched without a model call.
|
| 24 |
-
|
| 25 |
-
## Score Curve by Difficulty
|
| 26 |
-
|
| 27 |
-
| Difficulty | Task count | Heuristic avg | LLM avg | Bad avg | Target band | Status |
|
| 28 |
-
| --- | --- | --- | --- | --- | --- | --- |
|
| 29 |
-
| easy | 3 | 0.964 | 0.964 | 0.323 | ≥ 0.90 | ✓ |
|
| 30 |
-
| medium | 2 | 0.755 | 0.755 | 0.278 | 0.50 – 0.85 | ✓ |
|
| 31 |
-
| hard | 3 | 0.635 | 0.651 | 0.113 | 0.50 – 0.75 | ✓ |
|
| 32 |
-
| nightmare | 2 | 0.466 | 0.466 | 0.065 | ≤ 0.55 | ✓ |
|
| 33 |
|
| 34 |
-
|
| 35 |
-
- The LLM-assisted run now **matches or narrowly beats** the heuristic on every difficulty band
|
| 36 |
-
(overall +0.005). The old v1 regression — where the LLM dropped 0.56 on `fraud_signal_ambiguity`
|
| 37 |
-
and 0.29 on `generated_medium_s99` — was caused by the model picking a concede strategy over
|
| 38 |
-
contest at `set_strategy` time. `_obvious_next_action` now short-circuits all strategy picks
|
| 39 |
-
so the heuristic-derived strategy is used directly, and the prompt explicitly lists the
|
| 40 |
-
reason-code → optimal-strategy mapping for the remaining decision points. Provider call count
|
| 41 |
-
fell from 19 to 7 because deterministic housekeeping (add_evidence, remove_evidence,
|
| 42 |
-
submit_representment, set_strategy, resolve_case) is now bypassed entirely.
|
| 43 |
-
- The LLM's remaining upside is on `queue_optimization_hard` (+0.049 over heuristic), where the
|
| 44 |
-
queue-triage branching is genuine and the heuristic's fixed priority order leaves marginal
|
| 45 |
-
value on the table.
|
| 46 |
-
- Nightmare tasks cluster around **0.47** for the heuristic because the 15-step budget collides
|
| 47 |
-
with 5-case portfolios that have deadline_step=3–5 per case. Missed deadlines that were
|
| 48 |
-
*attempted* still land in the weighted sum (with 0 on the deadline dimension and ~0.55 from
|
| 49 |
-
the other 85%); truly abandoned cases are zeroed by the `Gate(CaseAbandonedRubric)` wrapper.
|
| 50 |
-
Not a scoring artifact: the bad-policy run shows the same tasks at ~0.065.
|
| 51 |
-
- The deadline `Gate` is the v1 upgrade over a flat weighted sum: a case never even attempted by
|
| 52 |
-
the deadline collapses completely, while a case resolved late still earns dimensional credit
|
| 53 |
-
for evidence, strategy, and packet quality. This matches real chargeback operations — a missed
|
| 54 |
-
representment is "case forfeit," while a late one takes a penalty but is still scored on what
|
| 55 |
-
the merchant tried to do.
|
| 56 |
-
|
| 57 |
-
## Full Per-Task Table
|
| 58 |
-
|
| 59 |
-
| Task ID | Difficulty | Cases | Heuristic | H steps | LLM | LLM steps | Bad | Bad steps |
|
| 60 |
-
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
| 61 |
-
| goods_not_received_easy | easy | 1 | 0.968 | 6 | 0.968 | 6 | 0.280 | 3 |
|
| 62 |
-
| fraud_signal_ambiguity | easy | 1 | 0.968 | 7 | 0.968 | 7 | 0.280 | 3 |
|
| 63 |
-
| generated_easy_s42 | easy | 1 | 0.958 | 7 | 0.958 | 7 | 0.408 | 3 |
|
| 64 |
-
| generated_medium_s17 | medium | 2 | 0.809 | 10 | 0.809 | 10 | 0.114 | 12 |
|
| 65 |
-
| generated_medium_s99 | medium | 2 | 0.701 | 9 | 0.701 | 9 | 0.442 | 12 |
|
| 66 |
-
| queue_optimization_hard | hard | 3 | 0.802 | 12 | 0.850 | 11 | 0.129 | 15 |
|
| 67 |
-
| generated_hard_s7 | hard | 2 | 0.663 | 5 | 0.663 | 5 | 0.120 | 12 |
|
| 68 |
-
| generated_hard_s53 | hard | 3 | 0.440 | 6 | 0.440 | 6 | 0.089 | 15 |
|
| 69 |
-
| generated_nightmare_s31 | nightmare | 5 | 0.486 | 15 | 0.486 | 15 | 0.077 | 15 |
|
| 70 |
-
| generated_nightmare_s77 | nightmare | 5 | 0.445 | 15 | 0.445 | 15 | 0.053 | 15 |
|
| 71 |
-
| **Average** | | | **0.724** | 9.2 | **0.729** | 9.0 | **0.199** | 10.5 |
|
| 72 |
-
|
| 73 |
-
## Multi-seed Stress Grid (7 seeds × 4 difficulties)
|
| 74 |
-
|
| 75 |
-
Running the heuristic and bad-policy agents across seven generator seeds per difficulty (seeds
|
| 76 |
-
7, 17, 31, 42, 53, 77, 99) gives the statistically defensible version of the headline numbers.
|
| 77 |
-
All runs are fully offline — no provider calls involved.
|
| 78 |
-
|
| 79 |
-
| Difficulty | n | Heuristic mean ± std | Bad mean ± std |
|
| 80 |
| --- | --- | --- | --- |
|
| 81 |
-
|
|
| 82 |
-
|
|
| 83 |
-
|
|
| 84 |
-
|
|
| 85 |
-
|
| 86 |
|
| 87 |
Observations:
|
| 88 |
-
- Heuristic score decreases
|
| 89 |
-
0.
|
| 90 |
-
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
| 108 |
-
|
|
| 109 |
-
|
|
| 110 |
-
|
|
| 111 |
-
|
|
| 112 |
-
|
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
| note_quality | 0.05 | 0.85 | 0.0425 |
|
| 116 |
-
| **Total** | **1.00** | — | **0.9675** |
|
| 117 |
-
|
| 118 |
-
Per-dimension scores captured by reading `rubric.last_score` on every child in the
|
| 119 |
-
`ChargebackOpsEpisodeRubric.case_rubric.aggregator` tree after one forward pass — exactly the
|
| 120 |
-
introspection path an RL trainer would use for credit assignment. The small gaps
|
| 121 |
-
(`evidence_quality=0.90`, `efficiency=0.95`, `note_quality=0.85`) are the real headroom an
|
| 122 |
-
LLM-fine-tuned agent is expected to close.
|
| 123 |
|
| 124 |
## Rubric Composition (what's wired)
|
| 125 |
|
|
@@ -129,65 +75,75 @@ ChargebackOpsEpisodeRubric
|
|
| 129 |
├── deadline_gate: Gate(threshold=1.0) # hard-zero if case abandoned past deadline
|
| 130 |
│ └── CaseAbandonedRubric
|
| 131 |
└── aggregator: WeightedSum # weights sum to 1.0
|
| 132 |
-
├── rubric_0: StrategyCorrectnessRubric #
|
| 133 |
-
├── rubric_1: EvidenceQualityRubric #
|
| 134 |
-
├── rubric_2: PacketValidityRubric #
|
| 135 |
-
├── rubric_3: DeadlineComplianceRubric #
|
| 136 |
-
├── rubric_4: EfficiencyRubric #
|
| 137 |
-
├── rubric_5: OutcomeQualityRubric #
|
| 138 |
-
|
|
|
|
| 139 |
```
|
| 140 |
|
| 141 |
-
Every node is an OpenEnv `Rubric` subclass and every node exposes
|
| 142 |
-
`env.rubric.named_rubrics()` walks the tree
|
| 143 |
-
surface for a judge or trainer to
|
| 144 |
|
| 145 |
## Reproducing These Numbers
|
| 146 |
|
| 147 |
```bash
|
| 148 |
-
# Activate the project's venv
|
| 149 |
source ~/python/bin/activate
|
| 150 |
|
| 151 |
-
# 1. Headline 10-task run (heuristic + bad policy, no network)
|
| 152 |
python - <<'PY'
|
| 153 |
-
from
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
print(f"{
|
| 159 |
PY
|
|
|
|
| 160 |
|
| 161 |
-
|
| 162 |
-
python - <<'PY'
|
| 163 |
-
from statistics import mean, stdev
|
| 164 |
-
from evaluation.agent_brutal_audit import run_episode
|
| 165 |
-
for d in ("easy","medium","hard","nightmare"):
|
| 166 |
-
hs, bs = [], []
|
| 167 |
-
for s in (7, 17, 31, 42, 53, 77, 99):
|
| 168 |
-
hs.append(run_episode(f"generated_{d}_s{s}", policy='heuristic')['score'])
|
| 169 |
-
bs.append(run_episode(f"generated_{d}_s{s}", policy='bad')['score'])
|
| 170 |
-
print(f"{d:10s} heur={mean(hs):.4f}±{stdev(hs):.4f} bad={mean(bs):.4f}±{stdev(bs):.4f}")
|
| 171 |
-
PY
|
| 172 |
|
| 173 |
-
|
| 174 |
python -m runners.baseline_runner | tee /tmp/baseline_run.json
|
| 175 |
```
|
| 176 |
|
| 177 |
## Hardware / Environment
|
| 178 |
|
| 179 |
-
- Python 3.12
|
| 180 |
-
- `openenv-core
|
| 181 |
-
-
|
| 182 |
-
-
|
| 183 |
-
~2.5s in v1 because `_obvious_next_action` bypasses most model calls)
|
| 184 |
-
- Full test suite: 22/22 passing, `openenv validate .` clean, Docker build clean
|
| 185 |
|
| 186 |
## What This Table Does Not Show
|
| 187 |
|
| 188 |
-
- **Per-dimension score dispersion across the full catalog** — the
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
+# ChargebackOps — Benchmark Results

+Reference numbers for the 10-task headline catalog and the 28-task
+multi-seed stress grid against the current multi-round adversarial
+environment. Reproduce with the commands at the bottom; scores match to
+within ±1e-3 (float rounding).

+Captured on **2026-04-19** on `main` with the 8-dimension case rubric
+(weights `(0.20, 0.15, 0.10, 0.10, 0.10, 0.10, 0.05, 0.20)`,
+`escalation_roi` dimension added) and the deterministic Issuer agent
+(LLM softening disabled — benchmarks stay fully offline).

+## TL;DR

+| Policy | Headline avg (10 tasks) | Multi-seed avg (28 tasks) | Provider calls |
 | --- | --- | --- | --- |
+| **naive** (empty packet → submit) | **0.0000** | **0.0000** | 0 |
+| **concede_all** (always `accept_chargeback`) | **0.5666** | **0.5634** | 0 |
+| **escalate_all** (contest, then always escalate) | **0.7731** | **0.7647** | 0 |
+| **heuristic** (first-candidate rule-based pick) | **0.7731** | **0.7647** | 0 |
+
+**Discrimination delta** (heuristic − naive) is **0.7731** on the headline
+catalog and **0.7647** on the multi-seed grid — well above the 0.40 target.
+
+`escalate_all` ties with `heuristic` because the heuristic wins the
+representment on most tasks in the first review; the environment never
+enters the pre-arbitration branch and the escalation override never
+fires. That match is a signal, not a bug: when the scripted merchant
+packet is strong, escalation is never rational in the current
+deterministic Issuer, so the two policies produce identical trajectories.
+
+## Score Curve by Difficulty (multi-seed grid, 7 seeds / difficulty)
+
+| Difficulty | n | heuristic | escalate_all | concede_all | naive |
+| --- | --- | --- | --- | --- | --- |
+| easy | 7 | 0.974 | 0.974 | 0.470 | 0.000 |
+| medium | 7 | 0.876 | 0.876 | 0.699 | 0.000 |
+| hard | 7 | 0.701 | 0.701 | 0.584 | 0.000 |
+| nightmare | 7 | 0.508 | 0.508 | 0.501 | 0.000 |

 Observations:
+- Heuristic score decreases monotonically with difficulty
+  (0.97 → 0.88 → 0.70 → 0.51). The difficulty gradient is real.
+- `concede_all` narrows the gap at nightmare (0.508 vs 0.501) because
+  the 15-step budget vs. 5-case portfolio forces the heuristic to
+  forfeit cases deadline-wise, while conceding is cheap per case.
+  This is the expected `Gate(CaseAbandonedRubric)` behavior.
+- `naive` sits flat at 0.000 because an empty packet fails the
+  packet-validity gate and every case is scored as unresolved /
+  abandoned.
+
+## Headline Per-Task Table (10 tasks, offline)
+
+| Task ID | Difficulty | heuristic | escalate_all | concede_all | naive |
+| --- | --- | --- | --- | --- | --- |
+| goods_not_received_easy | easy | 0.968 | 0.968 | 0.580 | 0.000 |
+| fraud_signal_ambiguity | easy | 0.968 | 0.968 | 0.580 | 0.000 |
+| queue_optimization_hard | hard | 0.802 | 0.802 | 0.576 | 0.000 |
+| generated_easy_s42 | easy | 0.958 | 0.958 | 0.533 | 0.000 |
+| generated_medium_s17 | medium | 0.861 | 0.861 | 0.623 | 0.000 |
+| generated_medium_s99 | medium | 0.770 | 0.770 | 0.727 | 0.000 |
+| generated_hard_s7 | hard | 0.724 | 0.724 | 0.615 | 0.000 |
+| generated_hard_s53 | hard | 0.544 | 0.544 | 0.612 | 0.000 |
+| generated_nightmare_s31 | nightmare | 0.602 | 0.602 | 0.529 | 0.000 |
+| generated_nightmare_s77 | nightmare | 0.474 | 0.474 | 0.537 | 0.000 |
+| **Average** | | **0.7731** | **0.7731** | **0.5666** | **0.0000** |
+
+(Per-task numbers from `runners.benchmark_runner.run_policy_sweep()`.)

 ## Rubric Composition (what's wired)

@@ -129,65 +75,75 @@ ChargebackOpsEpisodeRubric
 ├── deadline_gate: Gate(threshold=1.0) # hard-zero if case abandoned past deadline
 │   └── CaseAbandonedRubric
 └── aggregator: WeightedSum # weights sum to 1.0
+    ├── rubric_0: StrategyCorrectnessRubric # 0.20
+    ├── rubric_1: EvidenceQualityRubric # 0.15
+    ├── rubric_2: PacketValidityRubric # 0.10
+    ├── rubric_3: DeadlineComplianceRubric # 0.10
+    ├── rubric_4: EfficiencyRubric # 0.10
+    ├── rubric_5: OutcomeQualityRubric # 0.10
+    ├── rubric_6: NoteQualityRubric # 0.05
+    └── rubric_7: EscalationROIRubric # 0.20
 ```

+Every node is an OpenEnv `Rubric` subclass and every node exposes
+`last_score` after forward. `env.rubric.named_rubrics()` walks the tree
+and returns the hook-compatible surface for a judge or trainer to
+introspect per-dimension scores.
+
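The gate-plus-weighted-sum composition described in this hunk can be sketched as follows. This is a minimal illustration, assuming every dimension score lies in `[0, 1]`. The weights are the ones listed in the tree; the function name, the dictionary keys, and the boolean `abandoned_past_deadline` flag are hypothetical and are not the OpenEnv `Rubric` API.

```python
# Weights copied from the rubric tree; the key names are illustrative assumptions.
WEIGHTS = {
    "strategy_correctness": 0.20,
    "evidence_quality": 0.15,
    "packet_validity": 0.10,
    "deadline_compliance": 0.10,
    "efficiency": 0.10,
    "outcome_quality": 0.10,
    "note_quality": 0.05,
    "escalation_roi": 0.20,
}


def gated_case_score(dimension_scores, abandoned_past_deadline):
    """Return 0.0 for an abandoned case, else the weighted sum of dimensions."""
    if abandoned_past_deadline:
        # Hard-zero: the gate overrides whatever the dimensions earned.
        return 0.0
    return sum(WEIGHTS[name] * dimension_scores[name] for name in WEIGHTS)


perfect = {name: 1.0 for name in WEIGHTS}
print(gated_case_score(perfect, abandoned_past_deadline=False))
print(gated_case_score(perfect, abandoned_past_deadline=True))
```

Keeping the gate outside the sum means no combination of dimension scores can compensate for an abandoned case, which is the property the "hard-zero" comment in the tree describes.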
+`EscalationROIRubric` encodes the economic rule that escalating to
+network arbitration is rational only when
+`P(win) × dispute_amount > arb_fee` (fee = $250/side). Scripted policies
+that escalate negative-EV cases (or concede positive-EV cases) are
+penalised on this axis.

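The escalation rule quoted above reduces to a one-line inequality. The sketch below is an illustration only: the $250 fee and the inequality come from the text, while the function names and the break-even helper are hypothetical and do not reflect how `EscalationROIRubric` is actually implemented.

```python
ARB_FEE = 250.0  # per-side arbitration fee quoted in the text


def should_escalate(p_win, dispute_amount, arb_fee=ARB_FEE):
    """True when the expected recovery exceeds the arbitration fee."""
    if not 0.0 <= p_win <= 1.0:
        raise ValueError("p_win must be a probability in [0, 1]")
    return p_win * dispute_amount > arb_fee


def break_even_amount(p_win, arb_fee=ARB_FEE):
    """Smallest dispute amount at which escalation stops being negative-EV."""
    if p_win <= 0.0:
        raise ValueError("p_win must be positive to have a break-even amount")
    return arb_fee / p_win


print(should_escalate(0.6, 1000.0))  # expected recovery 600 vs. fee 250
print(should_escalate(0.3, 500.0))   # expected recovery 150 vs. fee 250
print(break_even_amount(0.5))
```

Under this rule a policy such as `escalate_all` is only penalised on cases where the inequality fails, for example small disputes or low win probabilities.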
 ## Reproducing These Numbers

 ```bash
 source ~/python/bin/activate

 python - <<'PY'
+from runners.benchmark_runner import run_policy_sweep, run_multi_seed
+
+headline = run_policy_sweep()
+print("HEADLINE (10 tasks)")
+for s in headline.policies:
+    print(f" {s.policy:14s} mean={s.mean_score:.4f} stdev={s.stdev:.4f}")
+print(f" delta (heuristic - naive): {headline.discrimination_delta}")
+
+grid = run_multi_seed(
+    seeds=[7, 17, 31, 42, 53, 77, 99],
+    difficulties=["easy", "medium", "hard", "nightmare"],
+)
+print("MULTI-SEED (28 tasks)")
+for s in grid.policies:
+    print(f" {s.policy:14s} mean={s.mean_score:.4f} stdev={s.stdev:.4f}")
+print(f" delta (heuristic - naive): {grid.discrimination_delta}")
 PY
+```

+Optional LLM-assisted baseline (requires `OPENROUTER_API_KEY`):

+```bash
 python -m runners.baseline_runner | tee /tmp/baseline_run.json
 ```

 ## Hardware / Environment

+- Python 3.12, pytest 8.x
+- `openenv-core`, `pydantic`, `openai` per `pyproject.toml`
+- No provider calls for the four scripted policies — all results fully offline
+- Full test suite: **65/65 passing**

 ## What This Table Does Not Show

+- **Per-dimension score dispersion across the full catalog** — the
+  headline table aggregates to one scalar per task. Walk
+  `env.rubric.named_rubrics()` on any run for the per-dimension
+  introspection path.
+- **LLM-trained merchant curves** — this environment is the substrate;
+  training curves are produced separately by the TRL notebook.
+- **Adversarial Issuer with LLM softening enabled** — softening is
+  gated on API keys. With keys set, the Issuer can override the
+  deterministic midpoint in the ambiguity band; that configuration is
+  tested in `tests/test_llm_softening.py` but is not part of the
+  offline benchmark numbers above.
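Three claims in the new `docs/RESULTS.md` (the multi-seed discrimination delta clearing the 0.40 target, the monotonic decrease with difficulty, and the narrowing heuristic-vs-concede gap at nightmare) can be sanity-checked from the published per-difficulty means alone. The sketch below hard-codes the values from the multi-seed score-curve table. The helper names are hypothetical, and because every difficulty has the same `n = 7`, the mean of the four per-difficulty means stands in for the 28-task mean.

```python
from statistics import mean

# Per-difficulty means (7 seeds each) copied from the multi-seed table.
DIFFICULTIES = ["easy", "medium", "hard", "nightmare"]
SCORES = {
    "heuristic": [0.974, 0.876, 0.701, 0.508],
    "concede_all": [0.470, 0.699, 0.584, 0.501],
    "naive": [0.000, 0.000, 0.000, 0.000],
}
TARGET_DELTA = 0.40  # target quoted in the TL;DR


def discrimination_delta(scores, strong="heuristic", weak="naive"):
    """Mean score of the strong policy minus mean score of the weak one."""
    return mean(scores[strong]) - mean(scores[weak])


def is_strictly_decreasing(values):
    """True when each value is lower than the one before it."""
    return all(a > b for a, b in zip(values, values[1:]))


delta = discrimination_delta(SCORES)
gaps = {
    d: round(h - c, 3)
    for d, h, c in zip(DIFFICULTIES, SCORES["heuristic"], SCORES["concede_all"])
}

print(f"delta={delta:.4f} above_target={delta > TARGET_DELTA}")
print(f"monotonic={is_strictly_decreasing(SCORES['heuristic'])}")
print(f"heuristic - concede_all gap per difficulty: {gaps}")
```

Because the inputs are rounded to three decimals, the recomputed delta agrees with the reported multi-seed value only up to rounding.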
notebooks/train_merchant_agent.ipynb
ADDED
|
@@ -0,0 +1,236 @@
|
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Train Merchant Agent on ChargebackOps\n",
+    "\n",
+    "End-to-end GRPO training skeleton for the merchant-side chargeback agent.\n",
+    "\n",
+    "- Environment: `ChargebackOpsEnvironment` (multi-round adversarial Issuer, arbitration ROI).\n",
+    "- Text interface: `training.env_adapter` (prompt build, completion parse).\n",
+    "- Reward: `training.reward_adapter.compute_reward` — returns the normalised episode score in `[0, 1]`.\n",
+    "- Trainer: `trl.GRPOTrainer` on a small base model so this fits a free Colab T4.\n",
+    "\n",
+    "The first run is intentionally tiny (1 step, micro-batch of 2). Once the wiring is green, bump `max_steps` to 200 for the real curve."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 1. Colab setup\n",
+    "\n",
+    "Installs TRL, transformers, and the ChargebackOps package itself. Skip if the environment already has them."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%capture\n",
+    "import sys\n",
+    "if 'google.colab' in sys.modules:\n",
+    "    !pip install --quiet trl==0.11.4 transformers==4.44.2 accelerate==0.33.0 peft==0.12.0 bitsandbytes==0.43.3\n",
+    "    !git clone https://github.com/example/chargebackops.git /content/chargebackops\n",
+    "    %cd /content/chargebackops\n",
+    "    !pip install --quiet -e ."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2. Sanity-check the env adapter\n",
+    "\n",
+    "Run one scripted episode via the text adapter to confirm prompts render and rewards land inside `[0, 1]`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from training.env_adapter import build_prompt\n",
+    "from training.reward_adapter import run_episode_with_text_policy\n",
+    "\n",
+    "def heuristic_text_policy(prompt: str) -> str:\n",
+    "    # Force the fallback path so the scripted heuristic drives the episode.\n",
+    "    return ''\n",
+    "\n",
+    "result = run_episode_with_text_policy('goods_not_received_easy', heuristic_text_policy)\n",
+    "print('score', result.score, 'steps', result.steps_used, 'invalid', result.invalid_actions)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 3. Load a small base model\n",
+    "\n",
+    "`Qwen/Qwen2.5-0.5B-Instruct` fits on a free T4 with LoRA adapters. Swap in a bigger instruct model if you have the memory budget."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+    "\n",
+    "MODEL_ID = 'Qwen/Qwen2.5-0.5B-Instruct'\n",
+    "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
+    "if tokenizer.pad_token is None:\n",
+    "    tokenizer.pad_token = tokenizer.eos_token\n",
+    "model = AutoModelForCausalLM.from_pretrained(\n",
+    "    MODEL_ID,\n",
+    "    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,\n",
+    "    device_map='auto',\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 4. Build the training prompt dataset\n",
+    "\n",
+    "GRPO expects a list of prompts; it generates K completions per prompt internally and scores each with `compute_reward`. We sample prompts from fresh environment resets across the headline catalog."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import Dataset\n",
+    "from scenarios.simulation import list_tasks\n",
+    "from server.chargeback_ops_environment import ChargebackOpsEnvironment\n",
+    "from training.env_adapter import build_prompt\n",
+    "\n",
+    "def sample_prompts(n: int = 32):\n",
+    "    tasks = list_tasks()\n",
+    "    prompts, task_ids = [], []\n",
+    "    for i in range(n):\n",
+    "        task = tasks[i % len(tasks)]\n",
+    "        env = ChargebackOpsEnvironment()\n",
+    "        obs = env.reset(task_id=task.task_id).model_dump()\n",
+    "        prompts.append(build_prompt(obs))\n",
+    "        task_ids.append(task.task_id)\n",
+    "    return Dataset.from_dict({'prompt': prompts, 'task_id': task_ids})\n",
+    "\n",
+    "train_dataset = sample_prompts(32)\n",
+    "len(train_dataset)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 5. GRPO training step\n",
+    "\n",
+    "Starts with `max_steps=1` — just verify the gradient path closes. Increase to 200 for the real curve."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from trl import GRPOConfig, GRPOTrainer\n",
+    "from training.reward_adapter import compute_reward\n",
+    "\n",
+    "def reward_fn(prompts, completions, **kwargs):\n",
+    "    task_ids = kwargs.get('task_id') or kwargs.get('task_ids')\n",
+    "    return compute_reward(prompts, completions, task_ids=task_ids)\n",
+    "\n",
+    "config = GRPOConfig(\n",
+    "    output_dir='./grpo-merchant-agent',\n",
+    "    per_device_train_batch_size=2,\n",
+    "    num_generations=4,\n",
+    "    max_prompt_length=1024,\n",
+    "    max_completion_length=128,\n",
+    "    learning_rate=5e-6,\n",
+    "    max_steps=1,\n",
+    "    logging_steps=1,\n",
+    "    save_steps=50,\n",
+    "    gradient_accumulation_steps=1,\n",
+    "    bf16=torch.cuda.is_available(),\n",
+    "    report_to='none',\n",
+    ")\n",
+    "trainer = GRPOTrainer(\n",
+    "    model=model,\n",
+    "    processing_class=tokenizer,\n",
+    "    reward_funcs=[reward_fn],\n",
+    "    args=config,\n",
+    "    train_dataset=train_dataset,\n",
+    ")\n",
+    "trainer.train()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 6. Evaluate the trained policy\n",
+    "\n",
+    "Runs one rollout per headline task with the trained model as the text policy and reports the per-task scores plus the overall mean."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from statistics import mean\n",
+    "\n",
+    "def model_text_policy(prompt: str) -> str:\n",
+    "    inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024).to(model.device)\n",
+    "    with torch.no_grad():\n",
+    "        out = model.generate(**inputs, max_new_tokens=128, do_sample=False, temperature=0.0, pad_token_id=tokenizer.pad_token_id)\n",
+    "    return tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
+    "\n",
+    "scores = []\n",
+    "for task in list_tasks():\n",
+    "    result = run_episode_with_text_policy(task.task_id, model_text_policy)\n",
+    "    scores.append(result.score)\n",
+    "    print(f'{task.task_id:32s} score={result.score:.4f} steps={result.steps_used} invalid={result.invalid_actions}')\n",
+    "print('mean score', mean(scores))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Next steps\n",
+    "\n",
+    "1. Bump `max_steps` in step 5 to 200 (save checkpoints at 0/50/100/150/200).\n",
+    "2. Record per-checkpoint mean score and plot the curve.\n",
+    "3. Compare against the fixed-policy baselines in `runners/benchmark_runner.py`.\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
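Step 4 of the notebook notes that GRPO generates K completions per prompt and scores each one. The sketch below illustrates the group-relative idea behind that: each completion's reward is compared against the other completions for the same prompt. It is an illustration only, not TRL's implementation; the function name, the use of the population standard deviation, and the `eps` value are assumptions.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-4):
    """Normalise rewards within one prompt's group of K completions.

    Completions that beat the group mean get a positive advantage and
    completions below it get a negative one. `eps` (an assumed value)
    avoids division by zero when every completion earns the same reward.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four hypothetical episode scores for one prompt (num_generations=4).
rewards = [0.2, 0.8, 0.5, 0.5]
advantages = group_relative_advantages(rewards)
print([round(a, 3) for a in advantages])

# A group with identical rewards carries no learning signal.
flat = group_relative_advantages([0.7, 0.7, 0.7, 0.7])
print(flat)
```

One consequence for reward design: if every completion for a prompt receives the same score, every advantage in that group is zero and the prompt contributes no gradient.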
runners/benchmark_runner.py
ADDED
|
@@ -0,0 +1,351 @@
|
| 1 |
+
"""Scripted-policy benchmark runner for ChargebackOps.
|
| 2 |
+
|
| 3 |
+
Drives a fixed set of non-learning policies through the full environment so
|
| 4 |
+
the trained-merchant vs. baseline discrimination delta can be measured
|
| 5 |
+
without calling an LLM provider. Every policy defined here is deterministic
|
| 6 |
+
and offline.
|
| 7 |
+
|
| 8 |
+
Policies
|
| 9 |
+
--------
|
| 10 |
+
* ``heuristic`` — the Round 1 first-candidate pick (best scripted baseline).
|
| 11 |
+
* ``concede_all`` — always set strategy to ``accept_chargeback`` and resolve.
|
| 12 |
+
* ``escalate_all`` — contest like the heuristic, then escalate in the
|
| 13 |
+
pre-arb and arbitration steps regardless of evidence strength.
|
| 14 |
+
* ``naive`` — submit an empty packet / take a minimal path to terminal.
|
| 15 |
+
|
| 16 |
+
The runner also exposes :func:`run_multi_seed` which sweeps each policy
|
| 17 |
+
over the headline catalog plus extra generator seeds so the benchmark
|
| 18 |
+
table in ``docs/RESULTS_V2.md`` is reproducible from one command.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from statistics import mean, pstdev
|
| 25 |
+
from typing import Any, Callable, Iterable, Sequence
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
from ..core.models import ChargebackOpsAction
|
| 29 |
+
from ..scenarios.simulation import TaskScenario, get_task, list_tasks
|
| 30 |
+
from ..server.chargeback_ops_environment import ChargebackOpsEnvironment
|
| 31 |
+
from .baseline_runner import candidate_actions
|
| 32 |
+
except ImportError: # pragma: no cover
|
| 33 |
+
from core.models import ChargebackOpsAction
|
| 34 |
+
from scenarios.simulation import TaskScenario, get_task, list_tasks
|
| 35 |
+
from server.chargeback_ops_environment import ChargebackOpsEnvironment
|
| 36 |
+
from runners.baseline_runner import candidate_actions
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
PolicyFn = Callable[[dict[str, Any]], ChargebackOpsAction | None]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
POLICY_NAMES: tuple[str, ...] = (
|
| 43 |
+
"heuristic",
|
| 44 |
+
"escalate_all",
|
| 45 |
+
"concede_all",
|
| 46 |
+
"naive",
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
# Scripted policies
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def heuristic_policy(observation: dict[str, Any]) -> ChargebackOpsAction | None:
|
| 56 |
+
"""First-candidate pick from the existing candidate generator."""
|
| 57 |
+
|
| 58 |
+
candidates = candidate_actions(observation)
|
| 59 |
+
if not candidates:
|
| 60 |
+
return None
|
| 61 |
+
return candidates[0].action
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def escalate_all_policy(observation: dict[str, Any]) -> ChargebackOpsAction | None:
|
| 65 |
+
"""Play like the heuristic, but always push terminal disputes into arbitration."""
|
| 66 |
+
|
| 67 |
+
available = set(observation.get("available_actions", []))
|
| 68 |
+
visible_case = observation.get("visible_case")
|
| 69 |
+
if visible_case is not None and "escalate_to_arbitration" in available:
|
| 70 |
+
return ChargebackOpsAction(
|
| 71 |
+
action_type="escalate_to_arbitration",
|
| 72 |
+
case_id=visible_case["case_id"],
|
| 73 |
+
)
|
| 74 |
+
return heuristic_policy(observation)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def concede_all_policy(observation: dict[str, Any]) -> ChargebackOpsAction | None:
|
| 78 |
+
"""Always accept the chargeback. Never contests, never escalates."""
|
| 79 |
+
|
| 80 |
+
available = set(observation.get("available_actions", []))
|
| 81 |
+
visible_case = observation.get("visible_case")
|
| 82 |
+
queue = observation.get("queue", [])
|
| 83 |
+
|
| 84 |
+
if visible_case is None:
|
| 85 |
+
open_cases = [item for item in queue if item["status"] == "open"]
|
| 86 |
+
if not open_cases:
|
| 87 |
+
return None
|
| 88 |
+
target = sorted(
|
| 89 |
+
open_cases,
|
| 90 |
+
key=lambda item: (item["steps_until_deadline"], -item["amount"]),
|
| 91 |
+
)[0]
|
| 92 |
+
return ChargebackOpsAction(
|
| 93 |
+
action_type="select_case", case_id=target["case_id"]
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
case_id = visible_case["case_id"]
|
| 97 |
+
if visible_case["status"] != "open":
|
| 98 |
+
open_cases = [
|
| 99 |
+
item
|
| 100 |
+
for item in queue
|
| 101 |
+
if item["status"] == "open" and item["case_id"] != case_id
|
| 102 |
+
]
|
| 103 |
+
if not open_cases:
|
| 104 |
+
return None
|
| 105 |
+
target = sorted(
|
| 106 |
+
open_cases,
|
| 107 |
+
key=lambda item: (item["steps_until_deadline"], -item["amount"]),
|
| 108 |
+
)[0]
|
| 109 |
+
return ChargebackOpsAction(
|
| 110 |
+
action_type="select_case", case_id=target["case_id"]
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
if "accept_arbitration_loss" in available:
|
| 114 |
+
return ChargebackOpsAction(
|
| 115 |
+
action_type="accept_arbitration_loss", case_id=case_id
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
if visible_case.get("current_strategy") != "accept_chargeback" and (
|
| 119 |
+
"set_strategy" in available
|
| 120 |
+
):
|
| 121 |
+
return ChargebackOpsAction(
|
| 122 |
+
action_type="set_strategy",
|
| 123 |
+
case_id=case_id,
|
| 124 |
+
strategy="accept_chargeback",
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
if "resolve_case" in available:
|
| 128 |
+
return ChargebackOpsAction(
|
| 129 |
+
action_type="resolve_case",
|
| 130 |
+
case_id=case_id,
|
| 131 |
+
strategy="accept_chargeback",
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
return heuristic_policy(observation)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def naive_policy(observation: dict[str, Any]) -> ChargebackOpsAction | None:
|
| 138 |
+
"""Minimum-effort agent: select a case, submit without evidence or policy work."""
|
| 139 |
+
|
| 140 |
+
available = set(observation.get("available_actions", []))
|
| 141 |
+
visible_case = observation.get("visible_case")
|
| 142 |
+
queue = observation.get("queue", [])
|
| 143 |
+
|
| 144 |
+
if visible_case is None:
|
| 145 |
+
open_cases = [item for item in queue if item["status"] == "open"]
|
| 146 |
+
if not open_cases:
|
| 147 |
+
return None
|
| 148 |
+
return ChargebackOpsAction(
|
| 149 |
+
action_type="select_case", case_id=open_cases[0]["case_id"]
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
case_id = visible_case["case_id"]
|
| 153 |
+
if visible_case["status"] != "open":
|
| 154 |
+
open_cases = [
|
| 155 |
+
item
|
| 156 |
+
for item in queue
|
| 157 |
+
if item["status"] == "open" and item["case_id"] != case_id
|
| 158 |
+
]
|
| 159 |
+
if not open_cases:
|
| 160 |
+
return None
|
| 161 |
+
return ChargebackOpsAction(
|
| 162 |
+
action_type="select_case", case_id=open_cases[0]["case_id"]
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
if "accept_arbitration_loss" in available:
|
| 166 |
+
return ChargebackOpsAction(
|
| 167 |
+
action_type="accept_arbitration_loss", case_id=case_id
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
if "submit_representment" in available:
|
| 171 |
+
return ChargebackOpsAction(
|
| 172 |
+
action_type="submit_representment", case_id=case_id
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
if "respond_to_pre_arb" in available:
|
| 176 |
+
return ChargebackOpsAction(
|
| 177 |
+
action_type="respond_to_pre_arb", case_id=case_id
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
if "resolve_case" in available:
|
| 181 |
+
return ChargebackOpsAction(
|
| 182 |
+
action_type="resolve_case",
|
| 183 |
+
case_id=case_id,
|
| 184 |
+
strategy="accept_chargeback",
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
return heuristic_policy(observation)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
POLICY_REGISTRY: dict[str, PolicyFn] = {
|
| 191 |
+
"heuristic": heuristic_policy,
|
| 192 |
+
"escalate_all": escalate_all_policy,
|
| 193 |
+
"concede_all": concede_all_policy,
|
| 194 |
+
"naive": naive_policy,
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# ---------------------------------------------------------------------------
|
| 199 |
+
# Episode / sweep driver
|
| 200 |
+
# ---------------------------------------------------------------------------
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@dataclass(frozen=True)
|
| 204 |
+
class TaskScore:
|
| 205 |
+
"""One policy × task result."""
|
| 206 |
+
|
| 207 |
+
policy: str
|
| 208 |
+
task_id: str
|
| 209 |
+
score: float
|
| 210 |
+
steps_used: int
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
@dataclass(frozen=True)
|
| 214 |
+
class PolicySummary:
|
| 215 |
+
"""Aggregate of one policy across a task list."""
|
| 216 |
+
|
| 217 |
+
policy: str
|
| 218 |
+
mean_score: float
|
| 219 |
+
stdev: float
|
| 220 |
+
tasks: tuple[TaskScore, ...]
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@dataclass(frozen=True)
|
| 224 |
+
class BenchmarkResult:
|
| 225 |
+
"""Output of a full policy sweep."""
|
| 226 |
+
|
| 227 |
+
policies: tuple[PolicySummary, ...]
|
| 228 |
+
discrimination_delta: float # heuristic minus naive
|
| 229 |
+
|
| 230 |
+
def to_dict(self) -> dict[str, Any]:
|
| 231 |
+
return {
|
| 232 |
+
"discrimination_delta": self.discrimination_delta,
|
| 233 |
+
"policies": [
|
| 234 |
+
{
|
| 235 |
+
"policy": summary.policy,
|
| 236 |
+
"mean_score": summary.mean_score,
|
| 237 |
+
"stdev": summary.stdev,
|
| 238 |
+
"tasks": [
|
| 239 |
+
{
|
| 240 |
+
"task_id": task.task_id,
|
| 241 |
+
"score": task.score,
|
| 242 |
+
"steps_used": task.steps_used,
|
| 243 |
+
}
|
| 244 |
+
for task in summary.tasks
|
| 245 |
+
],
|
| 246 |
+
}
|
| 247 |
+
for summary in self.policies
|
| 248 |
+
],
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def run_policy_on_task(policy: PolicyFn, task: TaskScenario) -> TaskScore:
|
| 253 |
+
"""Drive one policy through one task. Fully offline, no LLM calls."""
|
| 254 |
+
|
| 255 |
+
env = ChargebackOpsEnvironment()
|
| 256 |
+
observation = env.reset(task_id=task.task_id)
|
| 257 |
+
max_steps = task.max_steps + 5 # small safety margin
|
| 258 |
+
steps = 0
|
| 259 |
+
while not observation.done and steps < max_steps:
|
| 260 |
+
action = policy(observation.model_dump())
|
| 261 |
+
if action is None:
|
| 262 |
+
break
|
| 263 |
+
observation = env.step(action)
|
| 264 |
+
steps += 1
|
| 265 |
+
|
| 266 |
+
report = env.state.grader_report
|
| 267 |
+
score = float(report.normalized_score) if report is not None else 0.0
|
| 268 |
+
return TaskScore(
|
| 269 |
+
policy=policy.__name__,
|
| 270 |
+
task_id=task.task_id,
|
| 271 |
+
score=score,
|
| 272 |
+
steps_used=env.state.step_count,
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def run_policy_sweep(
|
| 277 |
+
policy_names: Sequence[str] = POLICY_NAMES,
|
| 278 |
+
tasks: Iterable[TaskScenario] | None = None,
|
| 279 |
+
) -> BenchmarkResult:
|
| 280 |
+
"""Run each named policy across the headline catalog (or provided tasks)."""
|
| 281 |
+
|
| 282 |
+
task_list = list(tasks) if tasks is not None else list_tasks()
|
| 283 |
+
|
| 284 |
+
summaries: list[PolicySummary] = []
|
| 285 |
+
for name in policy_names:
|
| 286 |
+
if name not in POLICY_REGISTRY:
|
| 287 |
+
raise KeyError(f"Unknown policy '{name}'. Known: {sorted(POLICY_REGISTRY)}")
|
| 288 |
+
policy = POLICY_REGISTRY[name]
|
| 289 |
+
task_scores: list[TaskScore] = []
|
| 290 |
+
for task in task_list:
|
| 291 |
+
score = run_policy_on_task(policy, task)
|
| 292 |
+
task_scores.append(
|
| 293 |
+
TaskScore(
|
| 294 |
+
policy=name,
|
| 295 |
+
task_id=score.task_id,
|
| 296 |
+
score=score.score,
|
| 297 |
+
steps_used=score.steps_used,
|
| 298 |
+
)
|
| 299 |
+
)
|
| 300 |
+
scores = [item.score for item in task_scores]
|
| 301 |
+
summaries.append(
|
| 302 |
+
PolicySummary(
|
| 303 |
+
policy=name,
|
| 304 |
+
mean_score=round(mean(scores), 4) if scores else 0.0,
|
| 305 |
+
stdev=round(pstdev(scores), 4) if len(scores) > 1 else 0.0,
|
| 306 |
+
tasks=tuple(task_scores),
|
| 307 |
+
)
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
by_name = {summary.policy: summary for summary in summaries}
|
| 311 |
+
delta = 0.0
|
| 312 |
+
if "heuristic" in by_name and "naive" in by_name:
|
| 313 |
+
delta = round(
|
| 314 |
+
by_name["heuristic"].mean_score - by_name["naive"].mean_score, 4
|
| 315 |
+
)
|
| 316 |
+
return BenchmarkResult(policies=tuple(summaries), discrimination_delta=delta)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def run_multi_seed(
|
| 320 |
+
seeds: Sequence[int],
|
| 321 |
+
difficulties: Sequence[str] = ("easy", "medium", "hard", "nightmare"),
|
| 322 |
+
policy_names: Sequence[str] = POLICY_NAMES,
|
| 323 |
+
) -> BenchmarkResult:
|
| 324 |
+
"""Sweep each policy over ``seeds × difficulties`` generated tasks.
|
| 325 |
+
|
| 326 |
+
Used for the multi-seed grid cited in the PRD's Day-5 exit criteria.
|
| 327 |
+
"""
|
| 328 |
+
|
| 329 |
+
tasks: list[TaskScenario] = []
|
| 330 |
+
for difficulty in difficulties:
|
| 331 |
+
for seed in seeds:
|
| 332 |
+
task_id = f"generated_{difficulty}_s{seed}"
|
| 333 |
+
tasks.append(get_task(task_id))
|
| 334 |
+
return run_policy_sweep(policy_names, tasks=tasks)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
__all__ = [
|
| 338 |
+
"POLICY_NAMES",
|
| 339 |
+
"POLICY_REGISTRY",
|
| 340 |
+
"PolicyFn",
|
| 341 |
+
"BenchmarkResult",
|
| 342 |
+
"PolicySummary",
|
| 343 |
+
"TaskScore",
|
| 344 |
+
"heuristic_policy",
|
| 345 |
+
"escalate_all_policy",
|
| 346 |
+
"concede_all_policy",
|
| 347 |
+
"naive_policy",
|
| 348 |
+
"run_policy_on_task",
|
| 349 |
+
"run_policy_sweep",
|
| 350 |
+
"run_multi_seed",
|
| 351 |
+
]
|
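The aggregation step in `run_policy_sweep` can be sketched in isolation. This is a minimal sketch, not the runner itself: the helper names `summarize` and `discrimination_delta` and the score lists are hypothetical, chosen only to illustrate how per-policy means, population standard deviations, and the heuristic-minus-naive delta are derived with the same rounding as above.

```python
from statistics import mean, pstdev


def summarize(scores_by_policy):
    """Return {policy: (mean, population stdev)}, each rounded to 4 places."""
    summary = {}
    for name, scores in scores_by_policy.items():
        avg = round(mean(scores), 4) if scores else 0.0
        # pstdev needs at least two points to be meaningful here
        spread = round(pstdev(scores), 4) if len(scores) > 1 else 0.0
        summary[name] = (avg, spread)
    return summary


def discrimination_delta(summary):
    """Heuristic mean minus naive mean; 0.0 when either policy is absent."""
    if "heuristic" in summary and "naive" in summary:
        return round(summary["heuristic"][0] - summary["naive"][0], 4)
    return 0.0


# Made-up scores for illustration only; they are not benchmark results.
example = {"heuristic": [0.9, 0.7, 0.5], "naive": [0.3, 0.2, 0.1]}
summary = summarize(example)
delta = discrimination_delta(summary)
```

Returning 0.0 when either policy is missing mirrors the runner's behaviour, so a sweep over a subset of policies still yields a well-formed result.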
scenarios/issuer_model.py
CHANGED
|
@@ -147,15 +147,15 @@ class IssuerAgent:
|
|
| 147 |
decision=IssuerDecision.ACCEPT,
|
| 148 |
evidence_strength_score=score,
|
| 149 |
rationale=(
|
| 150 |
-
f"
|
| 151 |
-
f"
|
| 152 |
),
|
| 153 |
)
|
| 154 |
return IssuerReview(
|
| 155 |
decision=IssuerDecision.ESCALATE_TO_ARBITRATION,
|
| 156 |
evidence_strength_score=score,
|
| 157 |
rationale=(
|
| 158 |
-
f"
|
| 159 |
f"escalating to network arbitration."
|
| 160 |
),
|
| 161 |
)
|
|
@@ -166,7 +166,7 @@ class IssuerAgent:
|
|
| 166 |
decision=IssuerDecision.ACCEPT,
|
| 167 |
evidence_strength_score=score,
|
| 168 |
rationale=(
|
| 169 |
-
f"
|
| 170 |
),
|
| 171 |
)
|
| 172 |
if score <= ROUND1_REJECT_THRESHOLD:
|
|
@@ -174,7 +174,7 @@ class IssuerAgent:
|
|
| 174 |
decision=IssuerDecision.REQUEST_MORE_EVIDENCE,
|
| 175 |
evidence_strength_score=score,
|
| 176 |
rationale=(
|
| 177 |
-
f"
|
| 178 |
f"requesting compelling evidence."
|
| 179 |
),
|
| 180 |
)
|
|
@@ -191,7 +191,7 @@ class IssuerAgent:
|
|
| 191 |
decision=IssuerDecision.ACCEPT,
|
| 192 |
evidence_strength_score=score,
|
| 193 |
rationale=(
|
| 194 |
-
f"
|
| 195 |
f"LLM softening accepted."
|
| 196 |
),
|
| 197 |
used_llm_softening=True,
|
|
@@ -201,7 +201,7 @@ class IssuerAgent:
|
|
| 201 |
decision=IssuerDecision.REQUEST_MORE_EVIDENCE,
|
| 202 |
evidence_strength_score=score,
|
| 203 |
rationale=(
|
| 204 |
-
f"
|
| 205 |
f"LLM softening requested compelling evidence."
|
| 206 |
),
|
| 207 |
used_llm_softening=True,
|
|
@@ -212,7 +212,7 @@ class IssuerAgent:
|
|
| 212 |
decision=IssuerDecision.ACCEPT,
|
| 213 |
evidence_strength_score=score,
|
| 214 |
rationale=(
|
| 215 |
-
f"
|
| 216 |
f"(>= {ROUND1_MIDPOINT_FALLBACK:.2f} midpoint) — accepting."
|
| 217 |
),
|
| 218 |
)
|
|
@@ -220,7 +220,7 @@ class IssuerAgent:
|
|
| 220 |
decision=IssuerDecision.REQUEST_MORE_EVIDENCE,
|
| 221 |
evidence_strength_score=score,
|
| 222 |
rationale=(
|
| 223 |
-
f"
|
| 224 |
f"(< {ROUND1_MIDPOINT_FALLBACK:.2f} midpoint) — requesting more evidence."
|
| 225 |
),
|
| 226 |
)
|
|
|
|
| 147 |
decision=IssuerDecision.ACCEPT,
|
| 148 |
evidence_strength_score=score,
|
| 149 |
rationale=(
|
| 150 |
+
f"Pre-arb evidence brings the packet to {score:.2f}, "
|
| 151 |
+
f"above the 0.60 acceptance bar."
|
| 152 |
),
|
| 153 |
)
|
| 154 |
return IssuerReview(
|
| 155 |
decision=IssuerDecision.ESCALATE_TO_ARBITRATION,
|
| 156 |
evidence_strength_score=score,
|
| 157 |
rationale=(
|
| 158 |
+
f"Packet still scores {score:.2f}; "
|
| 159 |
f"escalating to network arbitration."
|
| 160 |
),
|
| 161 |
)
|
|
|
|
| 166 |
decision=IssuerDecision.ACCEPT,
|
| 167 |
evidence_strength_score=score,
|
| 168 |
rationale=(
|
| 169 |
+
f"Packet scores {score:.2f}, clearing the 0.70 acceptance bar."
|
| 170 |
),
|
| 171 |
)
|
| 172 |
if score <= ROUND1_REJECT_THRESHOLD:
|
|
|
|
| 174 |
decision=IssuerDecision.REQUEST_MORE_EVIDENCE,
|
| 175 |
evidence_strength_score=score,
|
| 176 |
rationale=(
|
| 177 |
+
f"Packet scores {score:.2f}, below the 0.40 floor; "
|
| 178 |
f"requesting compelling evidence."
|
| 179 |
),
|
| 180 |
)
|
|
|
|
| 191 |
decision=IssuerDecision.ACCEPT,
|
| 192 |
evidence_strength_score=score,
|
| 193 |
rationale=(
|
| 194 |
+
f"Ambiguity band: packet scores {score:.2f} — "
|
| 195 |
f"LLM softening accepted."
|
| 196 |
),
|
| 197 |
used_llm_softening=True,
|
|
|
|
| 201 |
decision=IssuerDecision.REQUEST_MORE_EVIDENCE,
|
| 202 |
evidence_strength_score=score,
|
| 203 |
rationale=(
|
| 204 |
+
f"Ambiguity band: packet scores {score:.2f} — "
|
| 205 |
f"LLM softening requested compelling evidence."
|
| 206 |
),
|
| 207 |
used_llm_softening=True,
|
|
|
|
| 212 |
decision=IssuerDecision.ACCEPT,
|
| 213 |
evidence_strength_score=score,
|
| 214 |
rationale=(
|
| 215 |
+
f"Ambiguity band: packet scores {score:.2f} "
|
| 216 |
f"(>= {ROUND1_MIDPOINT_FALLBACK:.2f} midpoint) — accepting."
|
| 217 |
),
|
| 218 |
)
|
|
|
|
| 220 |
decision=IssuerDecision.REQUEST_MORE_EVIDENCE,
|
| 221 |
evidence_strength_score=score,
|
| 222 |
rationale=(
|
| 223 |
+
f"Ambiguity band: packet scores {score:.2f} "
|
| 224 |
f"(< {ROUND1_MIDPOINT_FALLBACK:.2f} midpoint) — requesting more evidence."
|
| 225 |
),
|
| 226 |
)
|
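The rationale strings above imply a three-way Round 1 decision. The following is a rough sketch of that logic under stated assumptions, not the actual `IssuerAgent` code: the 0.70 bar and 0.40 floor come from the rationale text, while the name `ROUND1_ACCEPT_THRESHOLD`, the inclusive comparison at the acceptance bar, and the 0.55 midpoint value are assumptions, since the diff does not show them. The LLM-softening branch is left out.

```python
ROUND1_ACCEPT_THRESHOLD = 0.70   # hypothetical name; "0.70 acceptance bar" in the rationale
ROUND1_REJECT_THRESHOLD = 0.40   # "0.40 floor" in the rationale
ROUND1_MIDPOINT_FALLBACK = 0.55  # assumed value; not visible in the diff


def round1_decision(score: float) -> str:
    """Classify a packet score without LLM softening."""
    if score >= ROUND1_ACCEPT_THRESHOLD:  # inclusive comparison is an assumption
        return "accept"
    if score <= ROUND1_REJECT_THRESHOLD:
        return "request_more_evidence"
    # Ambiguity band: fall back to the midpoint rule.
    if score >= ROUND1_MIDPOINT_FALLBACK:
        return "accept"
    return "request_more_evidence"


decisions = [round1_decision(s) for s in (0.75, 0.60, 0.45, 0.30)]
```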
tests/test_benchmark_runner.py
ADDED
|
@@ -0,0 +1,81 @@
|
| 1 |
+
"""Unit tests for the scripted-policy benchmark runner.
|
| 2 |
+
|
| 3 |
+
The runner drives a fixed set of non-learning policies through the full
|
| 4 |
+
environment without LLM calls. These tests pin:
|
| 5 |
+
|
| 6 |
+
1. Each policy returns a valid action or None offline.
|
| 7 |
+
2. The aggregator produces per-policy means and a discrimination delta.
|
| 8 |
+
3. The headline policy sweep keeps the heuristic ≥ 0.40 above the naive floor.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from runners.benchmark_runner import (
|
| 14 |
+
POLICY_NAMES,
|
| 15 |
+
POLICY_REGISTRY,
|
| 16 |
+
concede_all_policy,
|
| 17 |
+
escalate_all_policy,
|
| 18 |
+
heuristic_policy,
|
| 19 |
+
naive_policy,
|
| 20 |
+
run_multi_seed,
|
| 21 |
+
run_policy_on_task,
|
| 22 |
+
run_policy_sweep,
|
| 23 |
+
)
|
| 24 |
+
from scenarios.simulation import get_task
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
_EASY_TASK = get_task("goods_not_received_easy")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def test_policy_registry_matches_public_names():
|
| 31 |
+
assert set(POLICY_NAMES) == set(POLICY_REGISTRY)
|
| 32 |
+
assert set(POLICY_NAMES) == {"heuristic", "escalate_all", "concede_all", "naive"}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_heuristic_scores_above_naive_on_easy():
|
| 36 |
+
heur = run_policy_on_task(heuristic_policy, _EASY_TASK)
|
| 37 |
+
nv = run_policy_on_task(naive_policy, _EASY_TASK)
|
| 38 |
+
assert heur.score > nv.score
|
| 39 |
+
assert heur.task_id == _EASY_TASK.task_id
|
| 40 |
+
assert heur.steps_used > 0
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def test_concede_all_lands_final_resolution():
|
| 44 |
+
"""concede_all must always terminate the episode with a concede path."""
|
| 45 |
+
result = run_policy_on_task(concede_all_policy, _EASY_TASK)
|
| 46 |
+
assert result.steps_used > 0
|
| 47 |
+
# concede_all scores strictly below heuristic but must stay in [0, 1].
|
| 48 |
+
assert 0.0 <= result.score <= 1.0
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def test_escalate_all_runs_to_completion():
|
| 52 |
+
result = run_policy_on_task(escalate_all_policy, _EASY_TASK)
|
| 53 |
+
assert 0.0 <= result.score <= 1.0
|
| 54 |
+
assert result.steps_used > 0
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def test_sweep_aggregates_and_produces_delta():
|
| 58 |
+
result = run_policy_sweep()
|
| 59 |
+
policies = {summary.policy: summary for summary in result.policies}
|
| 60 |
+
assert set(policies) == set(POLICY_NAMES)
|
| 61 |
+
# mean scores sit in the valid range
|
| 62 |
+
for summary in result.policies:
|
| 63 |
+
assert 0.0 <= summary.mean_score <= 1.0
|
| 64 |
+
# discrimination delta is heuristic - naive and must clear the PRD bar.
|
| 65 |
+
assert result.discrimination_delta >= 0.40
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def test_sweep_is_deterministic():
|
| 69 |
+
"""Two runs on the same catalog must produce identical numbers."""
|
| 70 |
+
first = run_policy_sweep().to_dict()
|
| 71 |
+
second = run_policy_sweep().to_dict()
|
| 72 |
+
assert first == second
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def test_multi_seed_sweep_runs_subset():
|
| 76 |
+
"""Tiny grid (2 seeds × 1 difficulty) stays under a second and returns data."""
|
| 77 |
+
result = run_multi_seed(seeds=[42, 17], difficulties=["easy"])
|
| 78 |
+
for summary in result.policies:
|
| 79 |
+
assert len(summary.tasks) == 2
|
| 80 |
+
for task_score in summary.tasks:
|
| 81 |
+
assert task_score.task_id.startswith("generated_easy_s")
|
tests/test_training_adapter.py
ADDED
|
@@ -0,0 +1,105 @@
|
| 1 |
+
"""Unit tests for the training adapter.
|
| 2 |
+
|
| 3 |
+
Pin the prompt/completion serialization and the episode-replay reward
|
| 4 |
+
signal so the training notebook has a stable offline contract.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
from core.models import ChargebackOpsAction
|
| 12 |
+
from scenarios.simulation import get_task
|
| 13 |
+
from server.chargeback_ops_environment import ChargebackOpsEnvironment
|
| 14 |
+
from training.env_adapter import (
|
| 15 |
+
action_from_completion,
|
| 16 |
+
build_prompt,
|
| 17 |
+
parse_completion,
|
| 18 |
+
)
|
| 19 |
+
from training.reward_adapter import (
|
| 20 |
+
compute_reward,
|
| 21 |
+
run_episode_with_text_policy,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _fresh_observation(task_id: str = "goods_not_received_easy"):
|
| 26 |
+
env = ChargebackOpsEnvironment()
|
| 27 |
+
return env.reset(task_id=task_id).model_dump()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def test_build_prompt_is_deterministic_and_includes_available_actions():
|
| 31 |
+
obs = _fresh_observation()
|
| 32 |
+
a = build_prompt(obs)
|
| 33 |
+
b = build_prompt(obs)
|
| 34 |
+
assert a == b
|
| 35 |
+
assert "available_actions" in a
|
| 36 |
+
assert "OBSERVATION:" in a
|
| 37 |
+
assert "ACTION:" in a
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_parse_completion_accepts_plain_json():
|
| 41 |
+
payload = '{"action_type": "select_case", "case_id": "CB-X"}'
|
| 42 |
+
parsed = parse_completion(payload)
|
| 43 |
+
assert parsed == {"action_type": "select_case", "case_id": "CB-X"}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_parse_completion_strips_code_fence():
|
| 47 |
+
payload = '```json\n{"action_type": "select_case", "case_id": "CB-X"}\n```'
|
| 48 |
+
parsed = parse_completion(payload)
|
| 49 |
+
assert parsed == {"action_type": "select_case", "case_id": "CB-X"}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def test_parse_completion_returns_none_on_garbage():
|
| 53 |
+
assert parse_completion("") is None
|
| 54 |
+
assert parse_completion("not json at all") is None
|
| 55 |
+
assert parse_completion("{not-valid-json}") is None
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def test_parse_completion_drops_unknown_fields():
|
| 59 |
+
payload = json.dumps({"action_type": "select_case", "hack_field": 42})
|
| 60 |
+
parsed = parse_completion(payload)
|
| 61 |
+
assert parsed == {"action_type": "select_case"}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def test_action_from_completion_returns_valid_action():
|
| 65 |
+
payload = '{"action_type": "select_case", "case_id": "CB-X"}'
|
| 66 |
+
action = action_from_completion(payload)
|
| 67 |
+
assert isinstance(action, ChargebackOpsAction)
|
| 68 |
+
assert action.action_type == "select_case"
|
| 69 |
+
assert action.case_id == "CB-X"
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def test_action_from_completion_returns_none_on_bad_type():
|
| 73 |
+
payload = '{"action_type": "not_a_real_action"}'
|
| 74 |
+
assert action_from_completion(payload) is None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def test_run_episode_falls_back_to_heuristic_on_empty_completion():
|
| 78 |
+
"""Unparseable completions must not deadlock the episode."""
|
| 79 |
+
result = run_episode_with_text_policy(
|
| 80 |
+
"goods_not_received_easy",
|
| 81 |
+
text_policy=lambda _prompt: "",
|
| 82 |
+
)
|
| 83 |
+
assert result.steps_used > 0
|
| 84 |
+
assert result.invalid_actions > 0
|
| 85 |
+
assert result.score > 0.0 # heuristic fallback still scores
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def test_compute_reward_matches_episode_score():
|
| 89 |
+
"""Single completion + heuristic tail reproduces the heuristic score."""
|
| 90 |
+
task = get_task("goods_not_received_easy")
|
| 91 |
+
prompts = ["unused"]
|
| 92 |
+
completions = [""] # triggers heuristic fallback on the first action
|
| 93 |
+
rewards = compute_reward(
|
| 94 |
+
prompts, completions, task_ids=[task.task_id]
|
| 95 |
+
)
|
| 96 |
+
assert len(rewards) == 1
|
| 97 |
+
assert 0.0 <= rewards[0] <= 1.0
|
| 98 |
+
assert rewards[0] > 0.5 # heuristic scores ~0.97 on this task
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def test_compute_reward_rejects_mismatched_lengths():
|
| 102 |
+
import pytest
|
| 103 |
+
|
| 104 |
+
with pytest.raises(ValueError):
|
| 105 |
+
compute_reward(["a"], ["b", "c"], task_ids=["goods_not_received_easy"])
|
training/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
| 1 |
+
"""Training helpers for ChargebackOps.
|
| 2 |
+
|
| 3 |
+
Lightweight pure-Python wrappers that convert the environment into a
|
| 4 |
+
prompt/completion/reward interface compatible with TRL's GRPO trainer.
|
| 5 |
+
The module is import-safe without ``trl`` / ``torch`` installed so unit
|
| 6 |
+
tests stay fast and offline.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from .env_adapter import (
|
| 12 |
+
action_from_completion,
|
| 13 |
+
build_prompt,
|
| 14 |
+
parse_completion,
|
| 15 |
+
)
|
| 16 |
+
from .reward_adapter import (
|
| 17 |
+
EpisodeResult,
|
| 18 |
+
compute_reward,
|
| 19 |
+
run_episode_with_text_policy,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
"EpisodeResult",
|
| 24 |
+
"action_from_completion",
|
| 25 |
+
"build_prompt",
|
| 26 |
+
"compute_reward",
|
| 27 |
+
"parse_completion",
|
| 28 |
+
"run_episode_with_text_policy",
|
| 29 |
+
]
|
training/env_adapter.py
ADDED
|
@@ -0,0 +1,142 @@
|
| 1 |
+
"""Text prompt / completion adapter for the merchant policy.
|
| 2 |
+
|
| 3 |
+
Serialize an observation into a compact prompt the model can condition
|
| 4 |
+
on, and parse a JSON completion back into a typed
|
| 5 |
+
``ChargebackOpsAction``. Both helpers are pure — no provider calls, no
|
| 6 |
+
side effects — so they are cheap to unit-test.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from ..core.models import ChargebackOpsAction
|
| 16 |
+
except ImportError: # pragma: no cover
|
| 17 |
+
from core.models import ChargebackOpsAction
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
_SYSTEM_INSTRUCTION = (
|
| 21 |
+
"You play the merchant-side agent in a chargeback dispute. "
|
| 22 |
+
"Look at the observation and choose the single best next action. "
|
| 23 |
+
"Return JSON only: "
|
| 24 |
+
'{"action_type": "...", "case_id": "...", "strategy": "...", '
|
| 25 |
+
'"evidence_ids": [...], "note": "..."} '
|
| 26 |
+
"Use only action_types listed in available_actions. Omit fields you "
|
| 27 |
+
"do not need."
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
_ALLOWED_ACTION_FIELDS: frozenset[str] = frozenset(
|
| 32 |
+
{
|
| 33 |
+
"action_type",
|
| 34 |
+
"case_id",
|
| 35 |
+
"system_name",
|
| 36 |
+
"evidence_ids",
|
| 37 |
+
"compelling_evidence_ids",
|
| 38 |
+
"strategy",
|
| 39 |
+
"note",
|
| 40 |
+
}
|
| 41 |
+
)


def _compact_observation(observation: dict[str, Any]) -> dict[str, Any]:
    """Drop fields that add tokens without signal for the merchant policy."""

    visible_case = observation.get("visible_case")
    compact_case: dict[str, Any] | None = None
    if visible_case is not None:
        compact_case = {
            "case_id": visible_case["case_id"],
            "status": visible_case["status"],
            "reason_code": visible_case["reason_code"],
            "amount": visible_case["amount"],
            "currency": visible_case["currency"],
            "current_strategy": visible_case.get("current_strategy"),
            "systems_revealed": visible_case.get("systems_revealed", []),
            "retrieved_evidence": [
                {
                    "evidence_id": item["evidence_id"],
                    "source_system": item["source_system"],
                    "title": item["title"],
                }
                for item in visible_case.get("retrieved_evidence", [])
            ],
            "attached_evidence": [
                item["evidence_id"]
                for item in visible_case.get("attached_evidence", [])
            ],
            "policy": visible_case.get("policy"),
        }

    return {
        "objective": observation.get("objective", ""),
        "selected_case_id": observation.get("selected_case_id"),
        "available_actions": observation.get("available_actions", []),
        "steps_remaining": observation.get("steps_remaining", 0),
        "queue": [
            {
                "case_id": item["case_id"],
                "status": item["status"],
                "reason_code": item["reason_code"],
                "amount": item["amount"],
                "steps_until_deadline": item["steps_until_deadline"],
            }
            for item in observation.get("queue", [])
        ],
        "visible_case": compact_case,
        "last_action_result": observation.get("last_action_result", ""),
    }


def build_prompt(observation: dict[str, Any]) -> str:
    """Return a deterministic prompt for the merchant policy."""

    compact = _compact_observation(observation)
    body = json.dumps(compact, separators=(",", ":"), sort_keys=True)
    return f"{_SYSTEM_INSTRUCTION}\nOBSERVATION:\n{body}\nACTION:"


def parse_completion(text: str) -> dict[str, Any] | None:
    """Parse a model completion into a raw action dict, or return None."""

    if not text:
        return None
    cleaned = text.strip()
    # Strip common code-fence patterns.
    if cleaned.startswith("```"):
        cleaned = cleaned.strip("`").strip()
        if cleaned.lower().startswith("json"):
            cleaned = cleaned[4:].lstrip()
    # Take the span from the first "{" to the last "}" so prose around the
    # JSON object is tolerated.
    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start == -1 or end == -1 or end <= start:
        return None
    try:
        data = json.loads(cleaned[start : end + 1])
    except json.JSONDecodeError:
        return None
    if not isinstance(data, dict):
        return None
    return {k: v for k, v in data.items() if k in _ALLOWED_ACTION_FIELDS}


def action_from_completion(text: str) -> ChargebackOpsAction | None:
    """Parse a completion and build a validated :class:`ChargebackOpsAction`."""

    parsed = parse_completion(text)
    if parsed is None or "action_type" not in parsed:
        return None
    try:
        return ChargebackOpsAction(**parsed)
    except Exception:
        return None


__all__ = [
    "action_from_completion",
    "build_prompt",
    "parse_completion",
]
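As a rough illustration of the parsing behaviour in `parse_completion`, the following standalone sketch mirrors the fence-stripping and brace-span logic without importing the project. `ALLOWED` and `extract_action` are hypothetical stand-ins for `_ALLOWED_ACTION_FIELDS` and `parse_completion`, and `"select_case"` is a made-up action type used only for the example.

```python
from __future__ import annotations

import json
from typing import Any

# Hypothetical stand-in for _ALLOWED_ACTION_FIELDS (a subset, for illustration).
ALLOWED = frozenset({"action_type", "case_id", "note"})


def extract_action(text: str) -> dict[str, Any] | None:
    """Return the allowed fields of the JSON object spanning first "{" to last "}"."""
    cleaned = text.strip().strip("`").strip()
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        data = json.loads(cleaned[start : end + 1])
    except json.JSONDecodeError:
        return None
    if not isinstance(data, dict):
        return None
    # Unknown keys are dropped rather than rejected.
    return {k: v for k, v in data.items() if k in ALLOWED}


fenced = '```json\n{"action_type": "select_case", "case_id": "c1", "extra": 1}\n```'
chatty = 'Sure, here it is: {"action_type": "select_case"} hope that helps'
two_objects = '{"action_type": "a"}\n{"action_type": "b"}'

print(extract_action(fenced))       # fence and unknown key removed
print(extract_action(chatty))       # surrounding prose tolerated
print(extract_action(two_objects))  # None: the span is not a single JSON object
```

The last case is worth noting: because the span runs from the first `{` to the last `}`, a completion that emits several JSON objects fails to decode and is treated as unparseable.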
training/reward_adapter.py
ADDED
|
@@ -0,0 +1,156 @@
"""Reward adapter for GRPO / RL training on ChargebackOps.

Exposes a callable shape compatible with TRL's GRPO reward function:

    ``reward_fn(prompts, completions, **kwargs) -> list[float]``

Each completion is parsed as a single action and replayed as the first
step of a fresh episode; the remainder of the episode runs under the
scripted heuristic, so training always produces a terminal score. The
resulting reward is the episode's deterministic normalized grade in
``[0, 1]``.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, Sequence

try:
    from ..core.models import ChargebackOpsAction, ChargebackOpsObservation
    from ..scenarios.simulation import get_task, list_tasks
    from ..server.chargeback_ops_environment import ChargebackOpsEnvironment
    from .env_adapter import action_from_completion, build_prompt
except ImportError:  # pragma: no cover
    from core.models import ChargebackOpsAction, ChargebackOpsObservation
    from scenarios.simulation import get_task, list_tasks
    from server.chargeback_ops_environment import ChargebackOpsEnvironment
    from training.env_adapter import action_from_completion, build_prompt


TextPolicyFn = Callable[[str], str]


@dataclass(frozen=True)
class EpisodeResult:
    """Outcome of a single rollout."""

    task_id: str
    score: float
    steps_used: int
    invalid_actions: int
    prompts: tuple[str, ...] = field(default_factory=tuple)
    completions: tuple[str, ...] = field(default_factory=tuple)


def _fallback_action(
    observation: ChargebackOpsObservation,
) -> ChargebackOpsAction | None:
    """Scripted fallback when the model output is unparseable."""

    try:
        from ..runners.benchmark_runner import heuristic_policy
    except ImportError:  # pragma: no cover
        from runners.benchmark_runner import heuristic_policy
    return heuristic_policy(observation.model_dump())


def run_episode_with_text_policy(
    task_id: str,
    text_policy: TextPolicyFn,
    *,
    max_steps: int | None = None,
    capture_trace: bool = False,
) -> EpisodeResult:
    """Roll one episode forward under a text-in / text-out policy.

    The policy is invoked at every step. If the completion fails to
    parse into a valid action the scripted heuristic is used instead;
    this keeps early-training trajectories from deadlocking.
    """

    task = get_task(task_id)
    env = ChargebackOpsEnvironment()
    observation = env.reset(task_id=task_id)
    step_budget = (max_steps if max_steps is not None else task.max_steps) + 5
    steps = 0
    invalid = 0
    prompts: list[str] = []
    completions: list[str] = []

    while not observation.done and steps < step_budget:
        obs_dict = observation.model_dump()
        prompt = build_prompt(obs_dict)
        completion = text_policy(prompt)
        action = action_from_completion(completion)
        if action is None:
            invalid += 1
            action = _fallback_action(observation)
            if action is None:
                break
        observation = env.step(action)
        steps += 1
        if capture_trace:
            prompts.append(prompt)
            completions.append(completion)

    report = env.state.grader_report
    score = float(report.normalized_score) if report is not None else 0.0
    return EpisodeResult(
        task_id=task_id,
        score=score,
        steps_used=env.state.step_count,
        invalid_actions=invalid,
        prompts=tuple(prompts),
        completions=tuple(completions),
    )


def compute_reward(
    prompts: Sequence[str],
    completions: Sequence[str],
    *,
    task_ids: Sequence[str] | None = None,
    **_: Any,
) -> list[float]:
    """GRPO-style reward function.

    Each ``completion`` is replayed as a *single* action: the first step
    of a fresh episode on its task. The remainder of the episode is
    driven by the scripted heuristic, so the signal rewards the model
    for picking a good first move from the task's initial observation.
    This matches the behaviour TRL expects: one ``(prompt, completion)``
    pair → one scalar reward.

    ``task_ids`` optionally binds each prompt to a task id for env
    replay. When omitted, the headline catalog is cycled. ``prompts`` is
    used only for the length check.
    """

    if task_ids is None:
        headline = [task.task_id for task in list_tasks()]
        task_ids = [headline[i % len(headline)] for i in range(len(prompts))]
    if len(task_ids) != len(prompts) or len(prompts) != len(completions):
        raise ValueError(
            "prompts, completions, and task_ids must all have the same length"
        )

    rewards: list[float] = []
    for task_id, completion in zip(task_ids, completions):
        first_action = action_from_completion(completion)

        # One-shot policy: return the model's completion on the first call,
        # then "" so every later step falls back to the scripted heuristic.
        # The mutable default ``_used`` is created afresh each time ``_once``
        # is defined, so every loop iteration gets its own flag.
        def _once(_prompt: str, _used=[False], _action=first_action) -> str:
            if _used[0] or _action is None:
                return ""
            _used[0] = True
            return completion

        result = run_episode_with_text_policy(task_id, _once)
        rewards.append(result.score)
    return rewards


__all__ = [
    "EpisodeResult",
    "TextPolicyFn",
    "compute_reward",
    "run_episode_with_text_policy",
]
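The "model picks the first move, heuristic finishes the episode" reward shape used by `compute_reward` can be sketched on a toy problem. Everything below is hypothetical and independent of ChargebackOps: the environment is a counter that should land exactly on `TARGET`, actions are the strings `"+1"` and `"+2"`, and `heuristic`, `run_episode`, and `first_move_reward` are illustrative names, not project APIs.

```python
from __future__ import annotations

from typing import Callable

# Toy stand-ins (hypothetical): reach TARGET exactly within MAX_STEPS.
TARGET = 5
MAX_STEPS = 3


def heuristic(_state: int) -> str:
    """Scripted fallback policy: always take the larger step."""
    return "+2"


def run_episode(policy: Callable[[int], str]) -> float:
    """Roll one episode; unparseable actions fall back to the heuristic."""
    state = 0
    for _ in range(MAX_STEPS):
        action = policy(state)
        if action not in ("+1", "+2"):
            action = heuristic(state)
        state += int(action)
        if state >= TARGET:
            break
    # Terminal score: 1.0 for landing exactly on TARGET, else 0.0.
    return 1.0 if state == TARGET else 0.0


def first_move_reward(completions: list[str]) -> list[float]:
    """Score each completion as the first move; the heuristic plays the rest."""
    rewards = []
    for completion in completions:
        used = False

        def once(_state: int, completion: str = completion) -> str:
            nonlocal used
            if used:
                return ""  # empty output forces the heuristic fallback
            used = True
            return completion

        rewards.append(run_episode(once))
    return rewards


print(first_move_reward(["+1", "+2", "garbage"]))  # [1.0, 0.0, 0.0]
```

Only the first move differs between the three rollouts, so the reward isolates the value of that move given a fixed continuation policy. An unparseable completion is scored as if the heuristic had played the whole episode.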