ChargeBackOps / training /sft_dataset.py
mitudrudutta's picture
feat(training): SFT dataset + stall detection in eval rollout
02a6a9f
"""Supervised fine-tuning dataset builder for ChargebackOps.
Rolls the scripted heuristic across each task and captures every
``(observation_prompt, oracle_completion)`` pair as a single-turn
training sample. The completion is the JSON serialisation of the
heuristic action, matching the format the merchant policy must emit
at inference time.
SFT before GRPO is the standard RLHF pattern. It teaches the base
model two things GRPO struggles to learn from sparse reward alone:
* The output schema (valid JSON, the right ``action_type`` strings,
no extra prose).
* Per-state action variety — the heuristic emits a *different*
action_type at each state, so an SFT-trained model stops
collapsing to ``select_case`` at every step.
The module returns plain dicts so the notebook can wrap them in any
trainer's expected dataset format (TRL ``SFTTrainer``, HF
``Dataset.from_list``, etc.) without pulling those deps into the
package.
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any, Sequence
try:
from ..core.models import ChargebackOpsAction
from ..server.chargeback_ops_environment import ChargebackOpsEnvironment
from .env_adapter import build_prompt
except ImportError: # pragma: no cover
from core.models import ChargebackOpsAction
from server.chargeback_ops_environment import ChargebackOpsEnvironment
from training.env_adapter import build_prompt
def _heuristic_policy(observation_dict: dict[str, Any]) -> ChargebackOpsAction | None:
    """Return the scripted heuristic's action for a dumped observation.

    The heuristic lives in ``runners.benchmark_runner`` and is imported at
    call time rather than at module import time.
    NOTE(review): presumably deferred to avoid an import cycle — confirm.

    Args:
        observation_dict: The observation as a plain dict (callers in this
            module pass ``obs.model_dump()``).

    Returns:
        Whatever ``heuristic_policy`` returns for the observation; per the
        annotation, an action or ``None``.
    """
    try:
        # Package-relative import when loaded as part of the package.
        from ..runners.benchmark_runner import heuristic_policy as scripted_policy
    except ImportError:  # pragma: no cover
        # Fallback for flat, top-level execution.
        from runners.benchmark_runner import heuristic_policy as scripted_policy

    chosen_action = scripted_policy(observation_dict)
    return chosen_action
def action_to_completion(action: ChargebackOpsAction) -> str:
    """Render an action as its canonical JSON completion string.

    The action is dumped with ``None`` fields excluded, then any field whose
    value is an empty list or an empty string is dropped. The remainder is
    encoded as compact JSON (no whitespace after separators) with keys in
    sorted order, so equal actions always produce the same string.

    Args:
        action: The action to serialise; must expose ``model_dump``.

    Returns:
        The compact, key-sorted JSON text for the action.
    """
    # Values treated as "not set" in addition to None.
    blank_values = ([], "")
    compact_fields = {
        field_name: field_value
        for field_name, field_value in action.model_dump(exclude_none=True).items()
        if field_value not in blank_values
    }
    return json.dumps(compact_fields, sort_keys=True, separators=(",", ":"))
@dataclass(frozen=True)
class SFTSample:
    """Immutable record describing a single supervised training pair.

    The field names match the keys of the dicts assembled by
    ``build_sft_dataset`` in this module (which returns plain dicts rather
    than instances of this class).

    Attributes:
        task_id: Identifier of the task the sample belongs to.
        state_step: Step index of the sample within its task.
        prompt: Prompt text for the sample.
        completion: Target completion text for the sample.
        action_type: The ``action_type`` associated with the completion.
    """

    task_id: str
    state_step: int
    prompt: str
    completion: str
    action_type: str
def build_sft_dataset(
    task_ids: Sequence[str],
    *,
    max_states_per_task: int = 24,
) -> list[dict[str, Any]]:
    """Collect heuristic ``(prompt, completion)`` training pairs per task.

    For every task a fresh environment is reset and driven by the scripted
    heuristic. At each visited state one record is captured *before* the
    heuristic's action is applied. A task's rollout ends when the
    observation reports ``done``, the heuristic returns ``None``, or
    ``max_states_per_task`` states have been recorded.

    The default depth of 24 is deliberately larger than the 12 used by
    :func:`training.reward_adapter.build_state_action_dataset`: SFT benefits
    from seeing whole trajectories, including the terminal-resolution
    actions that are rare among early states.

    Args:
        task_ids: Tasks to roll out, processed in the given order.
        max_states_per_task: Upper bound on records captured per task.

    Returns:
        A flat list of dicts with keys ``task_id``, ``state_step``,
        ``prompt``, ``completion`` and ``action_type``, ordered by task and
        then by step.
    """

    def _episode_records(task_id: str):
        # One environment instance per task, reset to that task.
        environment = ChargebackOpsEnvironment()
        observation = environment.reset(task_id=task_id)
        for step_index in range(max_states_per_task):
            if observation.done:
                return
            oracle_action = _heuristic_policy(observation.model_dump())
            if oracle_action is None:
                return
            # The observation is dumped again for the prompt so the prompt
            # builder receives its own dict, separate from the one handed
            # to the heuristic.
            yield {
                "task_id": task_id,
                "state_step": step_index,
                "prompt": build_prompt(observation.model_dump()),
                "completion": action_to_completion(oracle_action),
                "action_type": oracle_action.action_type,
            }
            # Advance the environment only after the record is captured.
            observation = environment.step(oracle_action)

    dataset: list[dict[str, Any]] = []
    for task_id in task_ids:
        dataset.extend(_episode_records(task_id))
    return dataset
# Public API of this module; the leading-underscore helper stays private.
__all__ = ["SFTSample", "action_to_completion", "build_sft_dataset"]