Spaces:
Sleeping
Sleeping
| """Supervised fine-tuning dataset builder for ChargebackOps. | |
| Rolls the scripted heuristic across each task and captures every | |
| ``(observation_prompt, oracle_completion)`` pair as a single-turn | |
| training sample. The completion is the JSON serialisation of the | |
| heuristic action, matching the format the merchant policy must emit | |
| at inference time. | |
| SFT before GRPO is the standard RLHF pattern. It teaches the base | |
| model two things GRPO struggles to learn from sparse reward alone: | |
| * The output schema (valid JSON, the right ``action_type`` strings, | |
| no extra prose). | |
| * Per-state action variety — the heuristic emits a *different* | |
| action_type at each state, so an SFT-trained model stops | |
| collapsing to ``select_case`` at every step. | |
| The module returns plain dicts so the notebook can wrap them in any | |
| trainer's expected dataset format (TRL ``SFTTrainer``, HF | |
| ``Dataset.from_list``, etc.) without pulling those deps into the | |
| package. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from dataclasses import dataclass | |
| from typing import Any, Sequence | |
| try: | |
| from ..core.models import ChargebackOpsAction | |
| from ..server.chargeback_ops_environment import ChargebackOpsEnvironment | |
| from .env_adapter import build_prompt | |
| except ImportError: # pragma: no cover | |
| from core.models import ChargebackOpsAction | |
| from server.chargeback_ops_environment import ChargebackOpsEnvironment | |
| from training.env_adapter import build_prompt | |
def _heuristic_policy(observation_dict: dict[str, Any]) -> ChargebackOpsAction | None:
    """Delegate to the scripted heuristic from the benchmark runner.

    The runner is imported at call time rather than at module import
    (presumably to avoid an import cycle -- TODO confirm). The package
    relative import is tried first, with a flat-layout import as fallback.

    Args:
        observation_dict: Observation in dict form, passed to the
            heuristic unchanged.

    Returns:
        Whatever the heuristic returns: an action, or ``None``.
    """
    try:
        from ..runners.benchmark_runner import heuristic_policy as scripted_policy
    except ImportError:  # pragma: no cover
        from runners.benchmark_runner import heuristic_policy as scripted_policy
    chosen_action = scripted_policy(observation_dict)
    return chosen_action
def action_to_completion(action: ChargebackOpsAction) -> str:
    """Render an action as its canonical JSON completion string.

    Fields dumped as ``None`` are excluded by ``model_dump``; top-level
    fields equal to an empty list or an empty string are dropped here.
    The result is compact (no whitespace after separators) with keys in
    sorted order, so equal actions always serialise identically.

    Args:
        action: The action to serialise.

    Returns:
        The compact, key-sorted JSON text.
    """
    compact = {}
    for field_name, field_value in action.model_dump(exclude_none=True).items():
        # Skip empty placeholders; other falsy values (0, False) are kept.
        if field_value in ([], ""):
            continue
        compact[field_name] = field_value
    return json.dumps(compact, sort_keys=True, separators=(",", ":"))
@dataclass
class SFTSample:
    """One supervised training pair.

    Typed record with the same fields as the dicts emitted by
    ``build_sft_dataset``. The ``@dataclass`` decorator generates the
    constructor, ``__repr__`` and ``__eq__``; without it the annotations
    below declare no instance attributes and the class cannot be
    constructed with field values.
    """

    # Identifier of the task the sample was captured from.
    task_id: str
    # Zero-based index of the state within that task's rollout.
    state_step: int
    # Prompt text built from the observation.
    prompt: str
    # Canonical JSON serialisation of the heuristic action.
    completion: str
    # The ``action_type`` of the heuristic action.
    action_type: str
def build_sft_dataset(
    task_ids: Sequence[str],
    *,
    max_states_per_task: int = 24,
) -> list[dict[str, Any]]:
    """Roll the heuristic on each task and record one sample per visited state.

    Goes deeper than :func:`training.reward_adapter.build_state_action_dataset`
    (default 24 vs 12 states per task) because SFT benefits from seeing
    the full trajectory, including terminal-resolution actions, which
    are rare in the early-state distribution.

    Args:
        task_ids: Tasks to roll out; a fresh environment is created and
            reset for each one.
        max_states_per_task: Upper bound on samples captured per task.

    Returns:
        A list of plain dicts with keys ``task_id``, ``state_step``,
        ``prompt``, ``completion`` and ``action_type``, ordered by task
        and then by step. A task's rollout ends when the observation is
        done, the heuristic returns ``None``, or the cap is reached.
    """
    dataset: list[dict[str, Any]] = []
    for task_id in task_ids:
        environment = ChargebackOpsEnvironment()
        observation = environment.reset(task_id=task_id)
        for step_index in range(max_states_per_task):
            if not observation.done:
                oracle_action = _heuristic_policy(observation.model_dump())
            else:
                break
            if oracle_action is None:
                # The heuristic has no action for this state; end this task.
                break
            # The observation is dumped a second time for the prompt rather
            # than reusing the dict already handed to the heuristic.
            record = {
                "task_id": task_id,
                "state_step": step_index,
                "prompt": build_prompt(observation.model_dump()),
                "completion": action_to_completion(oracle_action),
                "action_type": oracle_action.action_type,
            }
            dataset.append(record)
            observation = environment.step(oracle_action)
    return dataset
# Public API of this module; ``_heuristic_policy`` stays private.
__all__ = ["SFTSample", "action_to_completion", "build_sft_dataset"]