Spaces:
Runtime error
Runtime error
| """Custom trainer + dataset adapter (docs/modules/training.md §2.2, §3.2.3). | |
| Two public types: | |
| - :class:`EpisodeDatasetAdapter` — stateless iterable feeding | |
| ``GRPOTrainer.train_dataset``. Each ``__iter__`` tick yields | |
| ``{"prompt": str, "_meta": {...}}`` where ``_meta`` carries the | |
| ``GoalSpec``, the monotonically-derived ``episode_seed``, the curriculum | |
| ``stage``, and the ``language_weights``. One call to | |
| ``task_generator.generate`` per step; one call to | |
| ``tokenizer.apply_chat_template(messages, tokenize=False, | |
| add_generation_prompt=True)`` to render the prompt. | |
| - :class:`DriftCallGRPOTrainer` — ``GRPOTrainer`` subclass whose | |
| ``_generate_and_score_completions`` override runs G multi-turn episodes | |
| via a caller-provided ``RolloutGroupFn`` and plumbs the resulting | |
| frozen ``Episode`` tuple into ``reward_fn`` (step_13) before handing the | |
| G reward scalars + padded completions back to the inherited GRPO | |
| advantage / KL / optimizer step path. **The inherited code path is | |
| untouched** (training.md §3.2.3). | |
| ``trl`` and ``torch`` are imported lazily. Pure-Python fallbacks for | |
| ``_generate_and_score_completions`` are provided so the class shape | |
| can be verified on CPU-only CI. | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from dataclasses import dataclass | |
| from typing import TYPE_CHECKING, Any, Literal, Protocol, cast | |
| if TYPE_CHECKING: # pragma: no cover - typing only | |
| from collections.abc import Callable, Iterator | |
| from cells.step_13_grpo_config import BETA_KL | |
# System turn placed at the head of every turn-0 prompt.
PINNED_SYSTEM_PROMPT: str = (
    "You are a concierge assistant. "
    "Use the provided tools. "
    "Respond in the caller's language. "
    "Submit with calibrated confidence."
)
# Language codes accepted as keys of ``language_weights``.
LanguageCode = Literal["hi", "ta", "kn", "en", "hinglish"]
class EpisodeSampler(Protocol):
    """Structural type for a goal sampler (training.md §2.2).

    Implementations are called with the current training ``step`` and hand
    back the ``GoalSpec`` for that prompt slot.
    """

    def __call__(self, step: int) -> Any:
        """Return the goal to use at ``step``."""
        ...
class EnvFactory(Protocol):
    """Structural type for an environment constructor (training.md §3.2).

    Each call is expected to produce a new ``DriftCallEnv`` — one per rollout.
    """

    def __call__(self) -> Any:
        """Create and return a fresh environment."""
        ...
class RolloutGroupFn(Protocol):
    """Structural type for the group-rollout runner.

    An implementation plays out ``num_generations`` (G) multi-turn rollouts
    that all share one ``goal`` and returns ``(episodes, completions)``,
    each a tuple of length G. All arguments are keyword-only.
    """

    def __call__(
        self,
        *,
        model: Any,
        tokenizer: Any,
        goal: Any,
        episode_seed: int,
        num_generations: int,
        env_factory: EnvFactory,
    ) -> tuple[tuple[Any, ...], tuple[str, ...]]:
        """Run the G rollouts and return ``(episodes, completions)``."""
        ...
@dataclass(frozen=True)
class AdapterRecord:
    """Frozen view of one :class:`EpisodeDatasetAdapter` yield.

    Tests consume this view rather than dict-typing ``_meta`` inline. The
    fields mirror one adapter record: ``prompt`` plus the four ``_meta``
    entries. ``frozen=True`` makes instances immutable (assignment raises
    ``dataclasses.FrozenInstanceError``) and gives a keyword constructor,
    which ``EpisodeDatasetAdapter.peek`` relies on.
    """

    prompt: str  # rendered turn-0 chat prompt
    goal: Any  # goal object produced by the adapter's task generator
    episode_seed: int  # per-step seed recorded by the adapter
    stage: Literal[1, 2, 3]  # curriculum stage
    language_weights: dict[LanguageCode, float]  # per-language sampling weights
def render_initial_prompt(tokenizer: Any, goal: Any) -> str:
    """Build the turn-0 prompt string for ``goal`` (training.md §3.2.1).

    The conversation has exactly two messages: the pinned system prompt and
    a user turn carrying ``goal.seed_utterance`` (an empty string when the
    goal has no such attribute). Tool schemas belong to later turns, so they
    are not part of this step. The template is rendered untokenized with
    ``add_generation_prompt=True`` so it ends ready for an assistant turn.

    Returns:
        ``str()`` of the value returned by ``tokenizer.apply_chat_template``.
    """
    system_turn = {"role": "system", "content": PINNED_SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": getattr(goal, "seed_utterance", "")}
    conversation: list[dict[str, str]] = [system_turn, user_turn]
    rendered = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    return str(rendered)
class EpisodeDatasetAdapter:
    """Infinite, stateless prompt stream for GRPO training (training.md §2.2).

    Every record is a pure function of its integer step:

    * ``episode_seed = stage_base_seed + step``;
    * ``task_gen(seed=episode_seed, stage=stage, language_weights=...)``
      produces the goal;
    * the goal is rendered with :func:`render_initial_prompt`.

    Records have the shape ``{"prompt": str, "_meta": {"goal", "episode_seed",
    "stage", "language_weights"}}``.

    Keyword-only constructor arguments:
        task_gen: goal generator called as described above.
        env_factory: stored on the instance; this class does not call it.
        stage: curriculum stage copied into every record.
        stage_base_seed: offset added to the step to form ``episode_seed``.
        language_weights: per-language weights; a copy is stored.
        tokenizer: passed to :func:`render_initial_prompt`.

    Both iteration (each ``__iter__`` call starts again at step 0) and
    map-style access (``__len__`` / ``__getitem__``) are supported.
    """

    def __init__(
        self,
        *,
        task_gen: Callable[..., Any],
        env_factory: EnvFactory,
        stage: Literal[1, 2, 3],
        stage_base_seed: int,
        language_weights: dict[LanguageCode, float],
        tokenizer: Any,
    ) -> None:
        self.task_gen = task_gen
        self.env_factory = env_factory
        self.stage: Literal[1, 2, 3] = stage
        self.stage_base_seed = stage_base_seed
        # Copy so later edits to the caller's dict do not leak in.
        self.language_weights = dict(language_weights)
        self.tokenizer = tokenizer

    def _build_record(self, step: int) -> dict[str, Any]:
        """Materialize the ``{"prompt", "_meta"}`` record for ``step``."""
        seed = self.stage_base_seed + step
        weights = self.language_weights
        goal = self.task_gen(seed=seed, stage=self.stage, language_weights=weights)
        prompt = render_initial_prompt(self.tokenizer, goal)
        meta = {
            "goal": goal,
            "episode_seed": seed,
            "stage": self.stage,
            # Fresh copy per record so consumers cannot mutate shared state.
            "language_weights": dict(weights),
        }
        return {"prompt": prompt, "_meta": meta}

    def __iter__(self) -> Iterator[dict[str, Any]]:
        # The cursor is local, so every new iterator restarts at step 0.
        cursor = 0
        while True:
            record = self._build_record(cursor)
            cursor += 1
            yield record

    def __len__(self) -> int:
        """Nominal size reported to TRL 0.24+ ``RepeatSampler``.

        The stream is logically unbounded, but the sampler sizes itself with
        ``len(data_source)``; a large finite value lets it proceed while
        ``GRPOConfig.max_steps`` bounds the real step count.
        """
        return 1_000_000

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Map-style lookup: the record for step ``int(idx)``.

        Records depend only on the index, so any index yields a deterministic
        ``(prompt, _meta)`` pair.
        """
        step = int(idx)
        return self._build_record(step)

    def peek(self, step: int) -> AdapterRecord:
        """Return the record at ``step`` as an :class:`AdapterRecord`.

        Does not touch any iterator; used by tests (§1.2 U14–U18) to inspect
        arbitrary steps.
        """
        record = self._build_record(step)
        meta = record["_meta"]
        return AdapterRecord(
            prompt=record["prompt"],
            goal=meta["goal"],
            episode_seed=meta["episode_seed"],
            stage=meta["stage"],
            language_weights=meta["language_weights"],
        )
| def _import_grpo_trainer() -> type[Any]: | |
| """Lazy import of ``trl.GRPOTrainer``; isolated for mocking in tests.""" | |
| from trl import GRPOTrainer | |
| return cast("type[Any]", GRPOTrainer) | |
def _reward_funcs_supplied(
    base_cls: type[Any],
    instance: Any,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> bool:
    """Return True if the caller already passed ``reward_funcs`` to ``base_cls``.

    Checks the keyword form first, then binds the positional arguments
    against ``base_cls.__init__``'s signature (in ``trl.GRPOTrainer``,
    ``reward_funcs`` is the parameter right after ``model``). If the
    signature cannot be inspected or the positional args do not bind,
    returns ``False`` and leaves error reporting to the base initializer.
    """
    if "reward_funcs" in kwargs:
        return True
    import inspect

    try:
        bound = inspect.signature(base_cls.__init__).bind_partial(instance, *args)
    except (TypeError, ValueError):
        return False
    return "reward_funcs" in bound.arguments


def _make_driftcall_init(
    base_cls: type[Any],
) -> Callable[..., None]:
    """Build an ``__init__`` bound to ``base_cls``; avoids super() recursion
    when the returned class is itself further subclassed.

    DriftCall-specific keyword-only kwargs added on top of
    ``GRPOTrainer.__init__`` (everything else is forwarded unchanged to
    ``base_cls.__init__``):

    - ``rollout_group_fn``, ``env_factory``, ``reward_fn_driftcall`` — stored
      on the instance for the multi-turn rollout override.
    - ``enable_adaptive_kl`` (default ``True``) — attach an
      :class:`AdaptiveKLCallback` (training.md §3.3.1). When ``False``,
      ``self.adaptive_kl_callback`` is set to ``None``.
    - ``adaptive_kl_target`` — override the default ``target_kl=BETA_KL``.
    - ``adaptive_kl_kp`` — override the proportional gain.
    - ``adaptive_kl_beta_min`` / ``adaptive_kl_beta_max`` — override clamp
      bounds.
    """

    def _init(
        self: Any,
        *args: Any,
        rollout_group_fn: RolloutGroupFn,
        env_factory: EnvFactory,
        reward_fn_driftcall: Callable[..., list[float]],
        enable_adaptive_kl: bool = True,
        adaptive_kl_target: float | None = None,
        adaptive_kl_kp: float = DEFAULT_KP,
        adaptive_kl_beta_min: float = DEFAULT_BETA_MIN,
        adaptive_kl_beta_max: float = DEFAULT_BETA_MAX,
        **kwargs: Any,
    ) -> None:
        # TRL 0.24 made ``reward_funcs`` a required arg on GRPOTrainer.
        # Our ``_generate_and_score_completions`` override calls
        # ``reward_fn_driftcall`` directly, so the base ``reward_funcs`` is not
        # used on that path; a zero-reward placeholder satisfies the
        # signature. Inject it only when the caller did not already supply
        # ``reward_funcs`` (by keyword OR positionally) — supplying it twice
        # would make the base initializer raise TypeError.
        if not _reward_funcs_supplied(base_cls, self, args, kwargs):

            def _placeholder_reward(
                completions: Any = None,
                **_unused: Any,
            ) -> list[float]:
                # One 0.0 per completion; empty list when none are passed.
                n = len(completions) if completions is not None else 0
                return [0.0] * n

            kwargs["reward_funcs"] = [_placeholder_reward]
        base_cls.__init__(self, *args, **kwargs)
        self.rollout_group_fn = rollout_group_fn
        self.env_factory = env_factory
        self.reward_fn_driftcall = reward_fn_driftcall
        if enable_adaptive_kl:
            target = (
                adaptive_kl_target if adaptive_kl_target is not None else BETA_KL
            )
            callback = AdaptiveKLCallback(
                target_kl=target,
                kp=adaptive_kl_kp,
                beta_min=adaptive_kl_beta_min,
                beta_max=adaptive_kl_beta_max,
            )
            self.adaptive_kl_callback = callback
            add_callback = getattr(base_cls, "add_callback", None)
            if callable(add_callback):
                # Production path (TRL >= 0.23): register through the trainer's
                # callback handler so ``on_log`` receives ``args``/``state``/
                # ``control``. NOTE(review): ``add_callback`` appends, so
                # callbacks registered earlier (e.g. reporting integrations)
                # may see ``logs`` before this one writes its diagnostic
                # keys — confirm those keys reach the loggers.
                self.add_callback(callback)
            else:
                # Fallback: minimal bases in tests lack ``add_callback``.
                # Keep a private list so callers can still invoke the hook.
                if not hasattr(self, "_driftcall_callbacks"):
                    self._driftcall_callbacks = []
                self._driftcall_callbacks.append(callback)
        else:
            self.adaptive_kl_callback = None

    return _init
| def _driftcall_generate_and_score_completions( | |
| self: Any, inputs: list[dict[str, Any]] | |
| ) -> dict[str, Any]: | |
| """Run the multi-turn rollout, then call ``reward_fn``. | |
| Expects ``inputs`` to carry one row per prompt slot with the | |
| ``_meta`` dict produced by :class:`EpisodeDatasetAdapter`. | |
| Returns a dict with keys ``episodes``, ``completions``, ``rewards``, | |
| ``prompts`` — each length G (num_generations). | |
| """ | |
| if not inputs: | |
| raise ValueError("inputs must be a non-empty list") | |
| row = inputs[0] | |
| meta = row["_meta"] | |
| prompt = row["prompt"] | |
| goal = meta["goal"] | |
| episode_seed = meta["episode_seed"] | |
| num_generations = int(getattr(self.args, "num_generations", 8)) | |
| episodes, completions = self.rollout_group_fn( | |
| model=self.model, | |
| tokenizer=self.processing_class, | |
| goal=goal, | |
| episode_seed=episode_seed, | |
| num_generations=num_generations, | |
| env_factory=self.env_factory, | |
| ) | |
| if len(episodes) != num_generations or len(completions) != num_generations: | |
| raise ValueError( | |
| f"rollout_group_fn produced {len(episodes)} episodes and " | |
| f"{len(completions)} completions; expected {num_generations} each" | |
| ) | |
| prompts = [prompt] * num_generations | |
| metas = [dict(meta) for _ in range(num_generations)] | |
| rewards = self.reward_fn_driftcall( | |
| prompts=prompts, | |
| completions=list(completions), | |
| _meta=metas, | |
| episodes=list(episodes), | |
| ) | |
| return { | |
| "episodes": episodes, | |
| "completions": completions, | |
| "rewards": rewards, | |
| "prompts": prompts, | |
| } | |
def make_driftcall_grpo_trainer_cls(base_cls: type[Any] | None = None) -> type[Any]:
    """Create the ``DriftCallGRPOTrainer`` class on top of ``base_cls``.

    With ``base_cls=None`` the parent is ``trl.GRPOTrainer``, imported
    lazily; tests pass a stub parent to exercise the override path without
    TRL installed.

    The generated subclass (training.md §3.2.3) overrides exactly two
    attributes; everything else is inherited from the parent:

    - ``__init__`` — accepts the extra keyword-only ``rollout_group_fn``
      (runs G multi-turn episodes, returns ``(episodes, completions)``),
      ``env_factory`` (fresh ``DriftCallEnv`` per rollout) and
      ``reward_fn_driftcall`` (the step_13 ``reward_fn``, called with the
      ``Episode`` tuple), plus adaptive-KL options, and forwards the rest to
      the parent initializer.
    - ``_generate_and_score_completions`` — replaces the TRL default with the
      multi-turn rollout + reward call.
    """
    parent: type[Any] = _import_grpo_trainer() if base_cls is None else base_cls
    namespace: dict[str, Any] = {
        "__init__": _make_driftcall_init(parent),
        "_generate_and_score_completions": _driftcall_generate_and_score_completions,
        "__doc__": "GRPOTrainer subclass with multi-turn rollout override.",
    }
    return type("DriftCallGRPOTrainer", (parent,), namespace)
def driftcall_grpo_trainer_methods() -> tuple[str, ...]:
    """Name the attributes the DriftCall trainer subclass overrides.

    Introspection helper for the shape test (U in §1.x).
    """
    init_name = "__init__"
    rollout_name = "_generate_and_score_completions"
    return (init_name, rollout_name)
# ---------------------------------------------------------------------------
# Adaptive KL controller (training.md §3.3): retarget β from the measured KL.
# ---------------------------------------------------------------------------
# Default clamp range applied to the adaptive β coefficient.
DEFAULT_BETA_MIN: float = 1e-3
DEFAULT_BETA_MAX: float = 1.0
# Default proportional gain applied to the relative KL error.
DEFAULT_KP: float = 2.0
| def _trainer_callback_base() -> type: | |
| """Return ``transformers.TrainerCallback`` if importable, else ``object``. | |
| Importing transformers lazily keeps step_14 importable on CPU-only CI | |
| runners that don't have transformers installed. | |
| """ | |
| try: | |
| from transformers.trainer_callback import TrainerCallback | |
| return TrainerCallback | |
| except Exception: | |
| return object | |
class AdaptiveKLCallback(_trainer_callback_base()):  # type: ignore[misc]
    """Retarget β each logging tick from the ratio of measured KL to ``target_kl``.

    Proportional controller with a log-space update::

        err      = (kl - target_kl) / target_kl
        new_beta = beta * exp(clamp(kp * err, -50, 50))
        new_beta = clamp(new_beta, beta_min, beta_max)

    When ``kl`` matches ``target_kl``, ``err == 0`` and β is unchanged apart
    from the bound clamp. A missing, non-numeric, NaN or infinite KL signal,
    or a non-numeric / NaN current β, makes ``on_log`` a no-op (no exception).

    Inherits from :class:`transformers.trainer_callback.TrainerCallback` when
    available (production path), so the other callback events
    (``on_train_begin``, ``on_step_begin``, ...) are inherited no-ops; falls
    back to ``object`` on CPU-only CI when transformers is not installed.

    Raises:
        ValueError: from ``__init__`` when ``target_kl`` is not a finite
            positive number, ``kp`` is not finite, a β bound is non-positive
            or NaN, ``beta_min`` is infinite, or ``beta_min > beta_max``.
            ``beta_max=inf`` is accepted and disables the upper clamp.
    """

    def __init__(
        self,
        target_kl: float = BETA_KL,
        *,
        kp: float = DEFAULT_KP,
        beta_min: float = DEFAULT_BETA_MIN,
        beta_max: float = DEFAULT_BETA_MAX,
    ) -> None:
        if target_kl <= 0.0:
            raise ValueError(f"target_kl must be > 0; got {target_kl}")
        if not math.isfinite(target_kl):
            # NaN slips past ``<= 0.0`` and inf makes ``err`` NaN; either would
            # silently push β to the top of its range on every update.
            raise ValueError(f"target_kl must be finite; got {target_kl}")
        if not math.isfinite(kp):
            # A non-finite gain yields a NaN exponent (e.g. inf * 0).
            raise ValueError(f"kp must be finite; got {kp}")
        if beta_min <= 0.0 or beta_max <= 0.0:
            raise ValueError(
                f"beta bounds must be > 0; got min={beta_min}, max={beta_max}"
            )
        if math.isnan(beta_min) or math.isnan(beta_max) or math.isinf(beta_min):
            # A NaN bound silently disables clamping on that side; an infinite
            # lower bound would pin β to infinity.
            raise ValueError(
                "beta bounds must not be NaN and beta_min must be finite; "
                f"got min={beta_min}, max={beta_max}"
            )
        if beta_min > beta_max:
            raise ValueError(
                f"beta_min ({beta_min}) must be <= beta_max ({beta_max})"
            )
        self.target_kl = float(target_kl)
        self.kp = float(kp)
        self.beta_min = float(beta_min)
        self.beta_max = float(beta_max)

    @staticmethod
    def _to_float(raw: Any) -> float | None:
        """Return ``float(raw)``, or ``None`` when the conversion fails."""
        try:
            return float(raw)
        except (TypeError, ValueError, OverflowError):
            return None

    def _coerce_kl(self, raw: Any) -> float | None:
        """Return a finite float or ``None`` — propagates no-op on bad input."""
        value = self._to_float(raw)
        if value is None or math.isnan(value) or math.isinf(value):
            return None
        return value

    def _next_beta(self, beta: float, kl: float) -> tuple[float, bool, bool]:
        """Return ``(new_beta, clamped_to_min, clamped_to_max)``."""
        err = (kl - self.target_kl) / self.target_kl
        # Clamp the exponent so extreme KL spikes don't overflow math.exp;
        # the result is clamped anyway and exp(±50) easily saturates either bound.
        exponent = max(-50.0, min(50.0, self.kp * err))
        scaled = beta * math.exp(exponent)
        if scaled <= self.beta_min:
            return self.beta_min, True, False
        if scaled >= self.beta_max:
            return self.beta_max, False, True
        return scaled, False, False

    def on_log(
        self,
        args: Any,
        state: Any,
        control: Any,
        *,
        logs: dict[str, Any] | None = None,
        **_kwargs: Any,
    ) -> Any:
        """Trainer hook — called with every ``trainer.log(...)`` dict.

        On a well-formed ``logs["kl"]`` value and a numeric current
        ``args.beta`` (``BETA_KL`` if absent), sets ``args.beta`` to the new
        coefficient and writes five diagnostic fields into ``logs``:

        - ``train/beta_adaptive``        new KL coefficient
        - ``train/kl_measured``          sanitised KL input
        - ``train/kl_target``            constant — aids chart-by-reference
        - ``train/beta_clamped_to_min``  0/1 — fires on collapse
        - ``train/beta_clamped_to_max``  0/1 — fires on runaway divergence

        Otherwise ``args`` and ``logs`` are left untouched. Always returns
        ``control``.

        NOTE(review): whether the new ``logs`` keys reach wandb/CSV depends on
        this callback running before the reporting callbacks — confirm. Also
        confirm the trainer's loss reads ``args.beta`` rather than a value
        cached at construction; otherwise the retargeted β has no effect.
        """
        if logs is None or "kl" not in logs:
            return control
        kl = self._coerce_kl(logs["kl"])
        if kl is None:
            return control
        beta = self._to_float(getattr(args, "beta", BETA_KL))
        if beta is None or math.isnan(beta):
            # A NaN β would propagate NaN into ``args.beta``; skip this tick.
            return control
        new_beta, clamped_lo, clamped_hi = self._next_beta(beta, kl)
        args.beta = new_beta
        logs["train/beta_adaptive"] = new_beta
        logs["train/kl_measured"] = kl
        logs["train/kl_target"] = self.target_kl
        logs["train/beta_clamped_to_min"] = 1 if clamped_lo else 0
        logs["train/beta_clamped_to_max"] = 1 if clamped_hi else 0
        return control
# Names exported by ``from <this module> import *``. The ``_``-prefixed
# helpers and the imported ``BETA_KL`` are not listed.
__all__ = [
    "AdapterRecord",
    "AdaptiveKLCallback",
    "DEFAULT_BETA_MAX",
    "DEFAULT_BETA_MIN",
    "DEFAULT_KP",
    "EnvFactory",
    "EpisodeDatasetAdapter",
    "EpisodeSampler",
    "LanguageCode",
    "PINNED_SYSTEM_PROMPT",
    "RolloutGroupFn",
    "driftcall_grpo_trainer_methods",
    "make_driftcall_grpo_trainer_cls",
    "render_initial_prompt",
]