ChargeBackOps / tests /test_training_adapter.py
mitudrudutta's picture
feat(training): SFT dataset + stall detection in eval rollout
02a6a9f
"""Unit tests for the training adapter.
Pin the prompt/completion serialization and the episode-replay reward
signal so the training notebook has a stable offline contract.
"""
from __future__ import annotations
import json
from core.models import ChargebackOpsAction
from scenarios.simulation import get_task
from server.chargeback_ops_environment import ChargebackOpsEnvironment
from training.env_adapter import (
action_from_completion,
build_prompt,
parse_completion,
)
from training.reward_adapter import (
build_state_action_dataset,
compute_reward,
run_episode_with_text_policy,
)
def _fresh_observation(task_id: str = "goods_not_received_easy"):
    """Reset a brand-new environment on *task_id* and return the observation dump."""
    observation = ChargebackOpsEnvironment().reset(task_id=task_id)
    return observation.model_dump()
def test_build_prompt_is_deterministic_and_includes_available_actions():
    """build_prompt returns the same text for the same observation and exposes
    the sections the policy reads (available actions, observation, action slot).
    """
    observation = _fresh_observation()
    first, second = build_prompt(observation), build_prompt(observation)
    assert first == second
    for marker in ("available_actions", "OBSERVATION:", "ACTION:"):
        assert marker in first
def test_parse_completion_accepts_plain_json():
    """A bare JSON object is parsed verbatim into a dict."""
    completion = '{"action_type": "select_case", "case_id": "CB-X"}'
    expected = {"action_type": "select_case", "case_id": "CB-X"}
    assert parse_completion(completion) == expected
def test_parse_completion_strips_code_fence():
    """A ```json fenced block is unwrapped before the JSON is parsed."""
    fenced = '```json\n{"action_type": "select_case", "case_id": "CB-X"}\n```'
    expected = {"action_type": "select_case", "case_id": "CB-X"}
    assert parse_completion(fenced) == expected
def test_parse_completion_returns_none_on_garbage():
    """Empty text, prose, and malformed JSON all parse to None."""
    for garbage in ("", "not json at all", "{not-valid-json}"):
        assert parse_completion(garbage) is None, garbage
def test_parse_completion_drops_unknown_fields():
    """Unrecognised keys such as ``hack_field`` are removed from the parsed dict."""
    completion = json.dumps({"action_type": "select_case", "hack_field": 42})
    result = parse_completion(completion)
    assert result == {"action_type": "select_case"}
def test_action_from_completion_returns_valid_action():
    """A well-formed completion becomes a ChargebackOpsAction carrying its fields."""
    action = action_from_completion(
        '{"action_type": "select_case", "case_id": "CB-X"}'
    )
    assert isinstance(action, ChargebackOpsAction)
    assert (action.action_type, action.case_id) == ("select_case", "CB-X")
def test_action_from_completion_returns_none_on_bad_type():
    """An unrecognised action_type yields None instead of raising."""
    bogus = '{"action_type": "not_a_real_action"}'
    result = action_from_completion(bogus)
    assert result is None
def test_parse_completion_handles_truncated_json():
    """A completion cut off mid-string (unclosed fence, unterminated value)
    still yields the fields that were complete before the cut."""
    truncated = (
        '```json\n{"action_type": "select_case", "case_id": "CB-E1", '
        '"strategy": "Select the case ID to procee'
    )
    result = parse_completion(truncated)
    assert result is not None
    recovered = {key: result[key] for key in ("action_type", "case_id")}
    assert recovered == {"action_type": "select_case", "case_id": "CB-E1"}
def test_parse_completion_strips_think_block():
    """A leading <think>...</think> reasoning block is ignored when parsing."""
    reasoning = '<think>\nlet me think about this\n</think>\n'
    body = '{"action_type": "select_case", "case_id": "CB-1"}'
    expected = {"action_type": "select_case", "case_id": "CB-1"}
    assert parse_completion(reasoning + body) == expected
def test_parse_completion_infers_action_type_from_prefix():
    """When the action name precedes a JSON object lacking ``action_type``,
    the prefix supplies the action type."""
    completion = ' select_case\n{"case_id": "CB-X", "strategy": "go"}'
    result = parse_completion(completion)
    assert result is not None
    assert (result["action_type"], result["case_id"]) == ("select_case", "CB-X")
def test_run_episode_falls_back_to_heuristic_on_empty_completion():
    """A policy that only emits empty text must not deadlock the episode:
    invalid actions are recorded and the heuristic fallback still scores."""

    def always_empty(_prompt):
        return ""

    result = run_episode_with_text_policy(
        "goods_not_received_easy", text_policy=always_empty
    )
    assert result.steps_used > 0
    assert result.invalid_actions > 0
    # The fallback plays the episode through, so a positive score is expected.
    assert result.score > 0.0
def test_compute_reward_unparseable_returns_zero():
    """Per-action scorer must NOT fall back to the heuristic when parsing fails.

    An earlier fallback design poisoned the GRPO signal: garbage completions
    earned ~0.96 reward (the heuristic played the episode), so the model
    learned that emitting garbage was optimal and group reward variance
    collapsed to ~0.005, killing the gradient.
    """
    task_ids = ["goods_not_received_easy"]
    assert compute_reward(["unused"], [""], task_ids=task_ids) == [0.0]
def test_compute_reward_exact_match_scores_one():
    """A completion identical to the heuristic's action at the initial state gets 1.0."""
    # Only heuristic_policy needs a local import; json and the environment are
    # already imported at module level, and _fresh_observation performs the
    # same reset-and-dump the test previously spelled out by hand.
    from runners.benchmark_runner import heuristic_policy

    oracle = heuristic_policy(_fresh_observation("goods_not_received_easy"))
    completion = json.dumps(oracle.model_dump(exclude_none=True))
    rewards = compute_reward(
        ["unused"], [completion], task_ids=["goods_not_received_easy"]
    )
    assert rewards == [1.0]
def test_compute_reward_unavailable_action_scores_low():
    """Valid JSON whose action_type is not allowed in the current state scores 0.1."""
    # At the first state of goods_not_received_easy only ``select_case`` is
    # allowed, so submitting a representment is out of turn.
    task_ids = ["goods_not_received_easy"]
    premature = '{"action_type": "submit_representment", "case_id": "CB-E1"}'
    assert compute_reward(["unused"], [premature], task_ids=task_ids) == [0.1]
def test_compute_reward_has_real_variance_across_diverse_completions():
    """Diverse completions must produce distinct rewards (the whole point).

    The prior design produced std ~0.005 across 6 wildly different
    completions because the heuristic dominated the episode. The new design
    should give >= 3 distinct reward values across the same set.
    """
    # json and the environment are module-level imports; only the heuristic
    # needs importing here.
    from runners.benchmark_runner import heuristic_policy

    task_id = "goods_not_received_easy"
    oracle = heuristic_policy(_fresh_observation(task_id))
    completions = [
        "",  # parse-fail -> 0.0
        "garbage no json",  # parse-fail -> 0.0
        '{"action_type": "submit_representment", "case_id": "CB-E1"}',  # unavailable -> 0.1
        json.dumps(oracle.model_dump(exclude_none=True)),  # exact -> 1.0
    ]
    # Size the parallel prompt/task lists from ``completions`` so adding a case
    # cannot trip compute_reward's length check (ValueError on mismatch).
    count = len(completions)
    rewards = compute_reward(["x"] * count, completions, task_ids=[task_id] * count)
    assert len(set(rewards)) >= 3
    assert max(rewards) - min(rewards) >= 0.5
def test_compute_reward_state_steps_advance_env():
    """state_steps makes compute_reward replay the heuristic to a mid-episode
    state before scoring the completion."""
    task_id = "goods_not_received_easy"
    rewards = compute_reward(
        ["x"] * 2,
        [""] * 2,
        task_ids=[task_id] * 2,
        state_steps=[0, 2],
    )
    # Empty completions fail to parse, so both score 0.0 whatever the state.
    # NOTE(review): this only shows non-zero state_steps are accepted; it does
    # not verify that the environment actually advanced to step 2.
    assert rewards == [0.0, 0.0]
def test_build_state_action_dataset_covers_multiple_states():
    """A heuristic rollout yields several (state, oracle) samples for the task,
    ordered by step and starting from the initial state."""
    task_id = "goods_not_received_easy"
    samples = build_state_action_dataset([task_id], max_states_per_task=8)
    assert len(samples) >= 2
    steps = [sample["state_step"] for sample in samples]
    assert steps == sorted(steps)
    assert steps[0] == 0
    for sample in samples:
        assert sample["task_id"] == task_id
        assert "OBSERVATION:" in sample["prompt"]
def test_compute_reward_rejects_mismatched_lengths():
    """Prompts and completions of different lengths raise ValueError."""
    import pytest

    prompts, completions = ["a"], ["b", "c"]
    with pytest.raises(ValueError):
        compute_reward(prompts, completions, task_ids=["goods_not_received_easy"])
def test_run_episode_breaks_select_case_loop():
    """Degenerate model that always emits select_case must not deadlock.

    Real failure mode observed in Colab eval: a Qwen3.5 checkpoint
    after 300 GRPO steps emitted ``select_case`` at every state. The
    env silently no-ops the second ``select_case``, the prompt stays
    identical, the model emits the same string, score stays 0 because
    ``done`` never flips. Stall detection must force-fallback to the
    heuristic so the episode reaches grading.
    """
    # json comes from the module-level import; no local re-import needed.
    select_case_payload = json.dumps(
        {"action_type": "select_case", "case_id": "CB-E1"}
    )
    result = run_episode_with_text_policy(
        "goods_not_received_easy",
        text_policy=lambda _prompt: select_case_payload,
    )
    assert result.steps_used > 0
    assert result.score > 0.0, (
        f"stall detection failed: score={result.score} "
        f"means episode never reached terminal grading"
    )