# driftcall / cells / step_14_custom_trainer.py
# Uploaded by: saumilyajj
# Commit b43d8da (verified): "Upload folder using huggingface_hub"
"""Custom trainer + dataset adapter (docs/modules/training.md §2.2, §3.2.3).
Two public types:
- :class:`EpisodeDatasetAdapter` — stateless iterable feeding
``GRPOTrainer.train_dataset``. Each ``__iter__`` tick yields
``{"prompt": str, "_meta": {...}}`` where ``_meta`` carries the
``GoalSpec``, the monotonically-derived ``episode_seed``, the curriculum
``stage``, and the ``language_weights``. One call to
``task_generator.generate`` per step; one call to
``tokenizer.apply_chat_template(messages, tokenize=False,
add_generation_prompt=True)`` to render the prompt.
- :class:`DriftCallGRPOTrainer` — ``GRPOTrainer`` subclass whose
``_generate_and_score_completions`` override runs G multi-turn episodes
via a caller-provided ``RolloutGroupFn`` and plumbs the resulting
frozen ``Episode`` tuple into ``reward_fn`` (step_13) before handing the
G reward scalars + padded completions back to the inherited GRPO
advantage / KL / optimizer step path. **The inherited code path is
untouched** (training.md §3.2.3).
``trl`` and ``torch`` are imported lazily. Pure-Python fallbacks for
``_generate_and_score_completions`` are provided so the class shape
can be verified on CPU-only CI.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Protocol, cast
if TYPE_CHECKING: # pragma: no cover - typing only
from collections.abc import Callable, Iterator
from cells.step_13_grpo_config import BETA_KL
# System message placed first in every rendered turn-0 prompt.
PINNED_SYSTEM_PROMPT: str = "".join(
    (
        "You are a concierge assistant. Use the provided tools. ",
        "Respond in the caller's language. Submit with calibrated confidence.",
    )
)
# Language codes accepted as keys of ``language_weights``.
LanguageCode = Literal["hi", "ta", "kn", "en", "hinglish"]
class EpisodeSampler(Protocol):
    """Callable protocol: pick the ``GoalSpec`` for a single prompt slot.

    Receives the training ``step`` and returns an opaque goal object
    (training.md §2.2).
    """

    def __call__(self, step: int) -> Any:
        ...
class EnvFactory(Protocol):
    """Callable protocol: build a brand-new ``DriftCallEnv`` for one rollout.

    Takes no arguments (training.md §3.2).
    """

    def __call__(self) -> Any:
        ...
class RolloutGroupFn(Protocol):
    """Callable protocol: play ``num_generations`` multi-turn rollouts of one goal.

    Every argument is keyword-only. The result is the pair
    ``(episodes, completions)``; each tuple is expected to hold exactly
    ``num_generations`` items.
    """

    def __call__(
        self,
        *,
        model: Any,
        tokenizer: Any,
        goal: Any,
        episode_seed: int,
        num_generations: int,
        env_factory: EnvFactory,
    ) -> tuple[tuple[Any, ...], tuple[str, ...]]:
        ...
@dataclass(frozen=True)
class AdapterRecord:
    """Immutable snapshot of one record produced by the dataset adapter.

    Tests read this typed view instead of indexing the raw ``_meta`` dict.
    """

    # Rendered turn-0 chat prompt.
    prompt: str
    # Goal object returned by the adapter's task generator.
    goal: Any
    # Seed for this step (the adapter derives it as stage_base_seed + step).
    episode_seed: int
    # Curriculum stage the record was drawn for.
    stage: Literal[1, 2, 3]
    # Per-language sampling weights attached to the record.
    language_weights: dict[LanguageCode, float]
def render_initial_prompt(tokenizer: Any, goal: Any) -> str:
    """Render the turn-0 chat prompt for ``goal`` (training.md §3.2.1).

    The conversation has exactly two messages:

    1. the pinned system prompt;
    2. the goal's ``seed_utterance``, as the user turn. An empty string is
       used when the goal has no such attribute.

    Tool schemas only appear in later turns, so they are not part of this
    prompt. The template is rendered untokenized, with
    ``add_generation_prompt=True`` so it asks the model for an assistant
    turn.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` produced, coerced to ``str``.
    """
    user_text = getattr(goal, "seed_utterance", "")
    conversation: list[dict[str, str]] = []
    conversation.append({"role": "system", "content": PINNED_SYSTEM_PROMPT})
    conversation.append({"role": "user", "content": user_text})
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    return str(rendered)
class EpisodeDatasetAdapter:
    """Stateless streaming dataset (training.md §2.2).

    Constructor arguments (training.md §2.2):

    - ``task_gen``: callable called with the keywords ``seed``, ``stage``
      and ``language_weights``.
    - ``env_factory``: produces fresh envs.
    - ``stage``: the curriculum stage.
    - ``stage_base_seed``: used to derive the per-step ``episode_seed``.
    - ``language_weights``: per-language sampling weights.
    - ``tokenizer``: used to render prompts.

    Each record is a function of its step index only. ``episode_seed`` is
    ``stage_base_seed + step``, the goal comes from one ``task_gen`` call
    with that seed, and the prompt is rendered from the goal. A given step
    therefore maps to the same ``episode_seed`` whether it is reached via
    ``__iter__``, ``__getitem__`` or :meth:`peek`. The goal itself is only
    deterministic if ``task_gen`` is.

    ``__iter__`` is infinite and always starts at step 0, because the
    adapter keeps no step counter. NOTE(review): resuming from a checkpoint
    therefore depends on how the trainer drives ``__getitem__`` /
    ``__iter__``; confirm against the TRL version in use.
    """

    # Finite length reported to samplers that call ``len()``. The dataset is
    # logically unbounded; the run length is set by ``GRPOConfig.max_steps``.
    _NOMINAL_LENGTH: int = 1_000_000

    def __init__(
        self,
        *,
        task_gen: Callable[..., Any],
        env_factory: EnvFactory,
        stage: Literal[1, 2, 3],
        stage_base_seed: int,
        language_weights: dict[LanguageCode, float],
        tokenizer: Any,
    ) -> None:
        self.task_gen = task_gen
        self.env_factory = env_factory
        self.stage: Literal[1, 2, 3] = stage
        self.stage_base_seed = stage_base_seed
        # Private copy: later edits to the caller's dict must not change records.
        self.language_weights = dict(language_weights)
        self.tokenizer = tokenizer

    def _build_record(self, step: int) -> dict[str, Any]:
        """Return ``{"prompt": str, "_meta": {...}}`` for ``step``."""
        episode_seed = self.stage_base_seed + step
        # Give task_gen its own copy so an in-place edit (e.g. normalising the
        # weights) cannot leak into later steps or into ``_meta``.
        goal = self.task_gen(
            seed=episode_seed,
            stage=self.stage,
            language_weights=dict(self.language_weights),
        )
        prompt = render_initial_prompt(self.tokenizer, goal)
        return {
            "prompt": prompt,
            "_meta": {
                "goal": goal,
                "episode_seed": episode_seed,
                "stage": self.stage,
                "language_weights": dict(self.language_weights),
            },
        }

    def __iter__(self) -> Iterator[dict[str, Any]]:
        """Yield records for steps 0, 1, 2, ... without end."""
        step = 0
        while True:
            yield self._build_record(step)
            step += 1

    def __len__(self) -> int:
        """Length sentinel for TRL 0.24+ ``RepeatSampler``.

        The dataset is logically infinite (one record per GRPO step), but
        TRL 0.24's ``RepeatSampler`` calls ``len(data_source)`` to size the
        sampler. Returning a large finite number lets training proceed; the
        actual step count is bounded by ``GRPOConfig.max_steps``.
        """
        return self._NOMINAL_LENGTH

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Map-style indexing for the TRL 0.24+ DataLoader.

        The record is derived purely from ``idx`` (treated as the step), so
        any index yields the ``(prompt, _meta)`` pair for that step.
        """
        return self._build_record(int(idx))

    def peek(self, step: int) -> AdapterRecord:
        """Build the record at ``step`` as an :class:`AdapterRecord`.

        Used by tests (§1.2 U14-U18) to check record shape at arbitrary
        steps without consuming a generator.
        """
        rec = self._build_record(step)
        meta = rec["_meta"]
        return AdapterRecord(
            prompt=rec["prompt"],
            goal=meta["goal"],
            episode_seed=meta["episode_seed"],
            stage=meta["stage"],
            language_weights=meta["language_weights"],
        )
def _import_grpo_trainer() -> type[Any]:
    """Import ``trl.GRPOTrainer`` on demand and return it.

    The import is deferred so this module loads without ``trl`` installed,
    and kept in its own helper so tests can patch it.
    """
    from trl import GRPOTrainer as _GRPOTrainer

    trainer_cls = _GRPOTrainer
    return cast("type[Any]", trainer_cls)
def _make_driftcall_init(
    base_cls: type[Any],
) -> Callable[..., None]:
    """Build an ``__init__`` bound to ``base_cls``.

    The base initializer is captured in a closure and called explicitly
    (``base_cls.__init__``) rather than via ``super()``. This avoids
    recursion when the returned class is itself further subclassed.

    DriftCall-specific kwargs added on top of ``GRPOTrainer.__init__``:

    - ``rollout_group_fn``, ``env_factory``, ``reward_fn_driftcall``: the
      multi-turn rollout override surface. Stored on the instance under the
      same names.
    - ``enable_adaptive_kl`` (default ``True``): auto-attach an
      :class:`AdaptiveKLCallback` so β retargets to the measured KL each
      logging tick (training.md §3.3.1). Set ``False`` to disable;
      ``self.adaptive_kl_callback`` is then ``None``.
    - ``adaptive_kl_target``: override the default ``target_kl=BETA_KL``.
    - ``adaptive_kl_kp``: override the proportional gain.
    - ``adaptive_kl_beta_min`` / ``adaptive_kl_beta_max``: override the
      clamp bounds.

    If the caller binds ``reward_funcs`` (by keyword or positionally), it is
    passed through unchanged. Otherwise a zero-reward placeholder is
    supplied.
    """
    import inspect

    def _reward_funcs_bound(
        trainer: Any, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> bool:
        """Report whether the caller already bound ``reward_funcs``."""
        if "reward_funcs" in kwargs:
            return True
        try:
            bound = inspect.signature(base_cls.__init__).bind_partial(
                trainer, *args, **kwargs
            )
        except (TypeError, ValueError):
            # No introspectable signature, or the arguments don't bind:
            # fall back to the keyword check above. Any genuine argument
            # error is then reported by base_cls.__init__ itself.
            return False
        return "reward_funcs" in bound.arguments

    def _init(
        self: Any,
        *args: Any,
        rollout_group_fn: RolloutGroupFn,
        env_factory: EnvFactory,
        reward_fn_driftcall: Callable[..., list[float]],
        enable_adaptive_kl: bool = True,
        adaptive_kl_target: float | None = None,
        adaptive_kl_kp: float = DEFAULT_KP,
        adaptive_kl_beta_min: float = DEFAULT_BETA_MIN,
        adaptive_kl_beta_max: float = DEFAULT_BETA_MAX,
        **kwargs: Any,
    ) -> None:
        # TRL 0.24 made ``reward_funcs`` a required GRPOTrainer argument. The
        # overridden ``_generate_and_score_completions`` calls
        # ``reward_fn_driftcall`` directly, so the base reward path is never
        # used. A placeholder is supplied only when the caller has not bound
        # ``reward_funcs``. Checking ``kwargs`` alone would also inject when
        # it was passed positionally, and the base ``__init__`` would then
        # raise "got multiple values for argument".
        if not _reward_funcs_bound(self, args, kwargs):

            def _placeholder_reward(
                completions: Any = None,
                **_unused: Any,
            ) -> list[float]:
                n = len(completions) if completions is not None else 0
                return [0.0] * n

            kwargs["reward_funcs"] = [_placeholder_reward]

        base_cls.__init__(self, *args, **kwargs)
        self.rollout_group_fn = rollout_group_fn
        self.env_factory = env_factory
        self.reward_fn_driftcall = reward_fn_driftcall

        if not enable_adaptive_kl:
            self.adaptive_kl_callback = None
            return

        target = adaptive_kl_target if adaptive_kl_target is not None else BETA_KL
        callback = AdaptiveKLCallback(
            target_kl=target,
            kp=adaptive_kl_kp,
            beta_min=adaptive_kl_beta_min,
            beta_max=adaptive_kl_beta_max,
        )
        self.adaptive_kl_callback = callback

        # Look the hook up on the instance, the same object it is called on,
        # so an ``add_callback`` defined by a further subclass is honoured.
        register = getattr(self, "add_callback", None)
        if callable(register):
            # Production path (TRL >= 0.23): register through the TRL
            # callback handler so ``on_log`` fires alongside the default
            # loggers with the correct ``args``/``state``/``control``.
            register(callback)
            return
        # Fallback: minimal test bases lack ``add_callback``. Keep a private
        # list so callers can still invoke the hook.
        if not hasattr(self, "_driftcall_callbacks"):
            self._driftcall_callbacks = []
        self._driftcall_callbacks.append(callback)

    return _init
def _driftcall_generate_and_score_completions(
    self: Any, inputs: list[dict[str, Any]]
) -> dict[str, Any]:
    """Run one group of G multi-turn rollouts and score it with ``reward_fn``.

    Only ``inputs[0]`` is read. Its ``prompt`` and its ``_meta`` dict (as
    produced by :class:`EpisodeDatasetAdapter`: ``goal``, ``episode_seed``)
    drive ``self.rollout_group_fn``. G is ``self.args.num_generations``, or 8
    when that attribute is absent.

    NOTE(review): any further rows are assumed to repeat the same prompt.
    Confirm the sampler never packs several distinct prompts into one call,
    since they would be ignored.

    Returns:
        Dict with keys ``episodes``, ``completions``, ``rewards`` and
        ``prompts``, each of length G.

    Raises:
        ValueError: if ``inputs`` is empty, or if ``rollout_group_fn`` or
            ``reward_fn_driftcall`` returns a number of items other than G.
    """
    if not inputs:
        raise ValueError("inputs must be a non-empty list")
    row = inputs[0]
    meta = row["_meta"]
    prompt = row["prompt"]
    goal = meta["goal"]
    episode_seed = meta["episode_seed"]
    num_generations = int(getattr(self.args, "num_generations", 8))
    episodes, completions = self.rollout_group_fn(
        model=self.model,
        tokenizer=self.processing_class,
        goal=goal,
        episode_seed=episode_seed,
        num_generations=num_generations,
        env_factory=self.env_factory,
    )
    if len(episodes) != num_generations or len(completions) != num_generations:
        raise ValueError(
            f"rollout_group_fn produced {len(episodes)} episodes and "
            f"{len(completions)} completions; expected {num_generations} each"
        )
    prompts = [prompt] * num_generations
    # One independent shallow copy of ``_meta`` per completion, so the
    # reward function can annotate a row without aliasing the others.
    metas = [dict(meta) for _ in range(num_generations)]
    rewards = self.reward_fn_driftcall(
        prompts=prompts,
        completions=list(completions),
        _meta=metas,
        episodes=list(episodes),
    )
    # Enforce the same length contract as for episodes/completions, so a
    # short or long reward vector is not passed downstream unchecked.
    if len(rewards) != num_generations:
        raise ValueError(
            f"reward_fn_driftcall produced {len(rewards)} rewards; "
            f"expected {num_generations}"
        )
    return {
        "episodes": episodes,
        "completions": completions,
        "rewards": rewards,
        "prompts": prompts,
    }
def make_driftcall_grpo_trainer_cls(base_cls: type[Any] | None = None) -> type[Any]:
    """Create the ``DriftCallGRPOTrainer`` class on top of ``base_cls``.

    When ``base_cls`` is omitted, ``trl.GRPOTrainer`` is imported lazily.
    Tests can pass a stub base instead, so the override path is exercised
    without TRL installed.

    The generated class (training.md §3.2.3) installs exactly two methods:

    - ``__init__``: forwards to the base initializer and additionally
      requires the keywords ``rollout_group_fn`` (:class:`RolloutGroupFn`),
      ``env_factory`` (:class:`EnvFactory`) and ``reward_fn_driftcall`` (the
      step_13 ``reward_fn``). It also accepts optional adaptive-KL keywords:
      ``enable_adaptive_kl``, ``adaptive_kl_target``, ``adaptive_kl_kp``,
      ``adaptive_kl_beta_min`` and ``adaptive_kl_beta_max``.
    - ``_generate_and_score_completions``: replaces the TRL default with the
      multi-turn rollout and calls ``reward_fn_driftcall`` on the resulting
      episodes.

    Everything else, including the advantage, KL and optimizer-step paths,
    is inherited from the base class unchanged.
    """
    if base_cls is None:
        base_cls = _import_grpo_trainer()
    namespace: dict[str, Any] = {}
    namespace["__init__"] = _make_driftcall_init(base_cls)
    namespace["_generate_and_score_completions"] = (
        _driftcall_generate_and_score_completions
    )
    namespace["__doc__"] = "GRPOTrainer subclass with multi-turn rollout override."
    return type("DriftCallGRPOTrainer", (base_cls,), namespace)
def driftcall_grpo_trainer_methods() -> tuple[str, ...]:
    """Name the methods that ``DriftCallGRPOTrainer`` overrides.

    Introspection helper for the shape test (U in §1.x), which uses it to
    check the override surface.
    """
    constructor = "__init__"
    rollout_override = "_generate_and_score_completions"
    return (constructor, rollout_override)
# ---------------------------------------------------------------------------
# Adaptive KL controller (training.md §3.3 — retarget β from measured KL)
# ---------------------------------------------------------------------------
# Default lower clamp for the adaptive KL coefficient β.
DEFAULT_BETA_MIN: float = 1e-3
# Default upper clamp for the adaptive KL coefficient β.
DEFAULT_BETA_MAX: float = 1.0
# Default proportional gain of the log-space β update.
DEFAULT_KP: float = 2.0
def _trainer_callback_base() -> type:
    """Pick the base class for :class:`AdaptiveKLCallback`.

    Returns ``transformers.trainer_callback.TrainerCallback`` when it can be
    imported, and plain ``object`` otherwise. This keeps step_14 importable
    on CPU-only CI runners without transformers.
    """
    try:
        from transformers.trainer_callback import TrainerCallback
    except Exception:  # best-effort: any import failure degrades to ``object``
        return object
    return TrainerCallback
class AdaptiveKLCallback(_trainer_callback_base()):  # type: ignore[misc]
    """Retarget β each step based on the ratio of measured KL to ``target_kl``.

    Proportional controller with a symmetric log-space update::

        err = (kl - target_kl) / target_kl
        new_beta = beta * exp(kp * err)
        new_beta = clamp(new_beta, beta_min, beta_max)

    When ``kl`` matches ``target_kl``, ``err == 0`` and β is left unchanged.
    Missing, NaN or non-numeric KL signals are a no-op and raise nothing.

    Constructor settings are validated eagerly:

    - ``target_kl`` must be finite and > 0.
    - ``kp`` must be finite and >= 0.
    - The β bounds must be > 0, not NaN, and ordered.

    Without these checks, a NaN/inf setting would make every update jump
    silently to a clamp bound (NaN exponent), or skip clamping altogether
    (NaN bounds).

    Inherits from :class:`transformers.trainer_callback.TrainerCallback` when
    available (production path), so all the no-op callback events
    (``on_train_begin``, ``on_step_begin``, etc.) come for free. Falls back
    to ``object`` on CPU-only CI when transformers is not installed.

    NOTE(review): ``on_log`` writes the new coefficient to ``args.beta``.
    ``GRPOTrainer`` appears to copy ``args.beta`` onto the trainer at
    construction. If so, the loss may keep using the original β; confirm
    against the pinned TRL version.
    """

    def __init__(
        self,
        target_kl: float = BETA_KL,
        *,
        kp: float = DEFAULT_KP,
        beta_min: float = DEFAULT_BETA_MIN,
        beta_max: float = DEFAULT_BETA_MAX,
    ) -> None:
        if target_kl <= 0.0:
            raise ValueError(f"target_kl must be > 0; got {target_kl}")
        # NaN compares False against everything, so it slips past the check
        # above; inf makes ``err`` NaN. Reject both explicitly.
        if not math.isfinite(target_kl):
            raise ValueError(f"target_kl must be finite; got {target_kl}")
        if beta_min <= 0.0 or beta_max <= 0.0:
            raise ValueError(
                f"beta bounds must be > 0; got min={beta_min}, max={beta_max}"
            )
        if math.isnan(beta_min) or math.isnan(beta_max):
            raise ValueError(
                f"beta bounds must not be NaN; got min={beta_min}, max={beta_max}"
            )
        if beta_min > beta_max:
            raise ValueError(
                f"beta_min ({beta_min}) must be <= beta_max ({beta_max})"
            )
        kp_value = float(kp)
        # A negative gain would push β away from the target (positive
        # feedback); a non-finite gain yields NaN/inf exponents.
        if not math.isfinite(kp_value) or kp_value < 0.0:
            raise ValueError(f"kp must be finite and >= 0; got {kp}")
        self.target_kl = float(target_kl)
        self.kp = kp_value
        self.beta_min = float(beta_min)
        self.beta_max = float(beta_max)

    def _coerce_kl(self, raw: Any) -> float | None:
        """Return ``raw`` as a finite float, or ``None`` so the update is skipped."""
        try:
            value = float(raw)
        except (TypeError, ValueError):
            return None
        if not math.isfinite(value):
            return None
        return value

    def _next_beta(self, beta: float, kl: float) -> tuple[float, bool, bool]:
        """Return ``(new_beta, clamped_to_min, clamped_to_max)``."""
        err = (kl - self.target_kl) / self.target_kl
        # Clamp the exponent so extreme KL spikes don't overflow math.exp;
        # the result is clamped anyway and exp(±50) easily saturates either bound.
        exponent = max(-50.0, min(50.0, self.kp * err))
        scaled = beta * math.exp(exponent)
        if scaled <= self.beta_min:
            return self.beta_min, True, False
        if scaled >= self.beta_max:
            return self.beta_max, False, True
        return scaled, False, False

    def on_log(
        self,
        args: Any,
        state: Any,
        control: Any,
        *,
        logs: dict[str, Any] | None = None,
        **_kwargs: Any,
    ) -> Any:
        """TRL hook, called with every ``trainer.log(...)`` dict.

        On a well-formed ``logs["kl"]`` value, sets ``args.beta`` to the new
        coefficient. It also writes five diagnostic fields back into
        ``logs``, so TRL's default reporter forwards them to wandb / CSV /
        etc.:

        - ``train/beta_adaptive``: the new KL coefficient.
        - ``train/kl_measured``: the sanitised KL input.
        - ``train/kl_target``: constant; aids chart-by-reference.
        - ``train/beta_clamped_to_min``: 0/1, fires on collapse.
        - ``train/beta_clamped_to_max``: 0/1, fires on runaway divergence.

        Always returns ``control`` unchanged.
        """
        if logs is None:
            return control
        if "kl" not in logs:
            return control
        kl = self._coerce_kl(logs["kl"])
        if kl is None:
            return control
        beta = float(getattr(args, "beta", BETA_KL))
        new_beta, clamped_lo, clamped_hi = self._next_beta(beta, kl)
        args.beta = new_beta
        logs["train/beta_adaptive"] = new_beta
        logs["train/kl_measured"] = kl
        logs["train/kl_target"] = self.target_kl
        logs["train/beta_clamped_to_min"] = 1 if clamped_lo else 0
        logs["train/beta_clamped_to_max"] = 1 if clamped_hi else 0
        return control
# Public API of this module. Underscore-prefixed helpers (e.g. the lazy TRL
# import) are intentionally not exported. ``DriftCallGRPOTrainer`` is not a
# module-level name; it is built on demand by
# ``make_driftcall_grpo_trainer_cls``. Kept as a plain list literal so static
# analysers can read it.
__all__ = [
    "AdapterRecord",
    "AdaptiveKLCallback",
    "DEFAULT_BETA_MAX",
    "DEFAULT_BETA_MIN",
    "DEFAULT_KP",
    "EnvFactory",
    "EpisodeDatasetAdapter",
    "EpisodeSampler",
    "LanguageCode",
    "PINNED_SYSTEM_PROMPT",
    "RolloutGroupFn",
    "driftcall_grpo_trainer_methods",
    "make_driftcall_grpo_trainer_cls",
    "render_initial_prompt",
]