"""Mock-TRL integration tests for the GRPO reward pipeline. A TRL calling-convention bug crashed training with: ``reward_environmental() got multiple values for argument 'prompts'`` That bug was invisible to unit tests because no test ever invoked the reward functions the way TRL's GRPOTrainer actually invokes them: fn(prompts=[...], completions=[...], task_id=[...], seed=[...]) These tests simulate that calling convention. If any reward function in the full pack (pure-text + env-wrapped) chokes on TRL-style kwargs, the test fails before push — not after 40 minutes of GPU time. This file runs on CPU only. No unsloth, no trl dependency. """ from __future__ import annotations import sys from pathlib import Path from typing import Any, Dict, List # Ensure project root on sys.path _ROOT = Path(__file__).resolve().parent.parent if str(_ROOT) not in sys.path: sys.path.insert(0, str(_ROOT)) from training.rewards import build_reward_pack, weighted_environmental_reward from training.stages.stage_3_grpo import _build_prompt_records, _make_task_reward class FakeGRPOTrainer: """Simulates the TRL GRPOTrainer's reward-function calling convention. Real TRL calls: for fn in reward_funcs: fn(prompts=prompts, completions=completions, **extra_columns) We mirror that exactly. Every reward function that survives a call from this fake trainer is guaranteed to survive TRL. """ def __init__(self, reward_funcs: List, dataset_rows: List[Dict[str, Any]], num_generations: int = 2): self.reward_funcs = reward_funcs self.dataset_rows = dataset_rows self.num_generations = num_generations def simulate_one_step(self, completions: List[str]) -> List[List[float]]: """Invoke every reward function with realistic TRL-style kwargs.""" n = len(completions) batch = self.dataset_rows[:n] prompts = [r["prompt"] for r in batch] task_ids = [r["task_id"] for r in batch] seeds = [r["seed"] for r in batch] all_rewards = [] for fn in self.reward_funcs: rewards = fn( prompts=prompts, completions=completions, task_id=task_ids, seed=seeds, ) assert isinstance(rewards, list), f"{fn.__name__} returned {type(rewards)}" assert len(rewards) == n, f"{fn.__name__} returned {len(rewards)} scores for {n} completions" all_rewards.append(rewards) return all_rewards # ───────────────────────────────────────────────────────────────────────────── # The test that catches TRL keyword-collision bugs # ───────────────────────────────────────────────────────────────────────────── def test_full_reward_pack_survives_trl_calling_convention(tmp_path): """End-to-end regression: the EXACT reward list stage 3 hands to TRL must survive a simulated TRL-style call. This is the test that would have caught the duplicate-prompts bug locally.""" pack = build_reward_pack(total_episodes=50) # Build the same env reward that stage 3 builds task_reward, training_log = _make_task_reward(tmp_path / "grpo_artifacts") all_reward_funcs = pack.funcs + [weighted_environmental_reward(task_reward, pack)] # Generate a real prompt dataset (no GPU needed — uses PermanenceEnv) dataset_rows = _build_prompt_records(total_episodes=8, domain="devtools") # Realistic completions the model might produce completions = [ 'list first', 'snapshot', ] trainer = FakeGRPOTrainer(all_reward_funcs, dataset_rows, num_generations=2) # If any reward function raises on the TRL calling convention, this # fails. This is the regression test for TRL keyword-collision bugs. all_rewards = trainer.simulate_one_step(completions) # Every reward function returned the right number of scores for scores in all_rewards: assert len(scores) == len(completions) def test_env_wrapper_does_not_double_pass_prompts(tmp_path): """Narrower regression test for the TRL keyword-collision bug.""" pack = build_reward_pack(total_episodes=10) task_reward, _ = _make_task_reward(tmp_path / "grpo") wrapped = weighted_environmental_reward(task_reward, pack) # Invoke with the exact kwargs TRL passes completions = [''] result = wrapped( prompts=["some prompt"], completions=completions, task_id=["task_log_cleanup"], seed=[0], ) assert isinstance(result, list) assert len(result) == 1 def test_text_reward_accepts_trl_kwargs_without_positional_completions(): """Make sure make_weighted wrapper also survives keyword-only calls.""" pack = build_reward_pack(total_episodes=10) for fn in pack.funcs: # TRL doesn't always pass completions positionally — test the # keyword path explicitly. result = fn( prompts=["p1", "p2"], completions=["c1", "c2"], task_id=["t1", "t2"], seed=[0, 1], ) assert len(result) == 2 def test_build_prompt_records_returns_usable_dataset_shape(): """Stage 3 calls ``Dataset.from_list(_build_prompt_records(...))``. The records must be a list of dicts with the required keys.""" rows = _build_prompt_records(total_episodes=5, domain="devtools") assert len(rows) == 5 required_keys = {"prompt", "episode", "task_id", "seed"} for r in rows: assert required_keys.issubset(r.keys()) assert isinstance(r["prompt"], str) assert r["prompt"] # non-empty assert r["task_id"].startswith("task_") def test_task_reward_writes_training_log_entries(tmp_path): """Stage 3's env reward appends to ``training_log``. Verify the log accumulates entries in the right shape.""" pack = build_reward_pack(total_episodes=10) task_reward, training_log = _make_task_reward(tmp_path / "grpo") completions = [''] task_reward( prompts=["p"], completions=completions, task_id=["task_log_cleanup"], seed=[0], ) assert len(training_log) >= 1 # Each entry has the structured fields the dashboard and eval rely on last = training_log[-1] for k in ("task_id", "seed", "reward", "completion_length"): assert k in last, f"missing key {k} in training_log entry"