Spaces:
Sleeping
Sleeping
| """Mock-TRL integration tests for the GRPO reward pipeline. | |
| A TRL calling-convention bug crashed training with: | |
| ``reward_environmental() got multiple values for argument 'prompts'`` | |
| That bug was invisible to unit tests because no test ever invoked the reward | |
| functions the way TRL's GRPOTrainer actually invokes them: | |
| fn(prompts=[...], completions=[...], task_id=[...], seed=[...]) | |
| These tests simulate that calling convention. If any reward function in the | |
| full pack (pure-text + env-wrapped) chokes on TRL-style kwargs, the test | |
| fails before push — not after 40 minutes of GPU time. | |
| This file runs on CPU only. No unsloth, no trl dependency. | |
| """ | |
| from __future__ import annotations | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
# Make the project root (the parent of this test file's directory) importable,
# so ``training.*`` resolves no matter which directory pytest is launched from.
_ROOT = Path(__file__).resolve().parents[1]
if all(entry != str(_ROOT) for entry in sys.path):
    sys.path.insert(0, str(_ROOT))
| from training.rewards import build_reward_pack, weighted_environmental_reward | |
| from training.stages.stage_3_grpo import _build_prompt_records, _make_task_reward | |
class FakeGRPOTrainer:
    """Simulates the TRL GRPOTrainer's reward-function calling convention.

    Real TRL calls, for every reward function::

        fn(prompts=prompts, completions=completions, **extra_columns)

    Here ``extra_columns`` maps every dataset column other than the prompt to
    a per-completion list. Each prompt row is repeated ``num_generations``
    times so it lines up with its group of sampled completions.

    This class mirrors that behaviour. Every dataset column (e.g. ``task_id``,
    ``seed``, ``episode``) is forwarded as a keyword argument. A reward
    function whose signature collides with any of those keywords therefore
    fails here, before it can fail inside TRL.
    """

    # Dataset columns that are not forwarded as extra keyword arguments.
    _RESERVED_COLUMNS = ("prompt", "completion")

    def __init__(self, reward_funcs: List, dataset_rows: List[Dict[str, Any]], num_generations: int = 2):
        """Store the reward functions and the dataset rows to draw prompts from.

        Args:
            reward_funcs: Callables invoked TRL-style by ``simulate_one_step``.
            dataset_rows: Dataset records; each must contain a ``"prompt"`` key.
            num_generations: Completions sampled per prompt (the GRPO group size).

        Raises:
            ValueError: If ``num_generations`` is less than 1.
        """
        if num_generations < 1:
            raise ValueError(f"num_generations must be >= 1, got {num_generations}")
        self.reward_funcs = reward_funcs
        self.dataset_rows = dataset_rows
        self.num_generations = num_generations

    def simulate_one_step(self, completions: List[str]) -> List[List[float]]:
        """Invoke every reward function with realistic TRL-style kwargs.

        ``completions`` is treated as consecutive groups of ``num_generations``
        completions, one group per dataset row, taken in dataset order.

        Returns:
            One list of scores per reward function, each ``len(completions)`` long.

        Raises:
            ValueError: If the completion count is not a multiple of
                ``num_generations``, or needs more prompts than the dataset has.
            AssertionError: If a reward function returns a non-list or the
                wrong number of scores.
        """
        n = len(completions)
        if n % self.num_generations:
            raise ValueError(
                f"{n} completions is not a multiple of num_generations={self.num_generations}"
            )
        num_prompts = n // self.num_generations
        if num_prompts > len(self.dataset_rows):
            raise ValueError(
                f"need {num_prompts} dataset rows for {n} completions, "
                f"have {len(self.dataset_rows)}"
            )

        # Repeat each prompt row once per generation so that batch[i] is the
        # row that completions[i] was sampled from.
        batch = [
            row
            for row in self.dataset_rows[:num_prompts]
            for _ in range(self.num_generations)
        ]
        prompts = [row["prompt"] for row in batch]

        # Forward every non-prompt column; column names come from the first
        # dataset row so that an empty batch still passes the same keywords.
        column_names = (
            [key for key in self.dataset_rows[0] if key not in self._RESERVED_COLUMNS]
            if self.dataset_rows
            else []
        )
        extra_columns = {key: [row[key] for row in batch] for key in column_names}

        all_rewards = []
        for fn in self.reward_funcs:
            # Not every callable (e.g. functools.partial) has __name__.
            name = getattr(fn, "__name__", repr(fn))
            rewards = fn(prompts=prompts, completions=completions, **extra_columns)
            assert isinstance(rewards, list), f"{name} returned {type(rewards)}"
            assert len(rewards) == n, f"{name} returned {len(rewards)} scores for {n} completions"
            all_rewards.append(rewards)
        return all_rewards
| # ───────────────────────────────────────────────────────────────────────────── | |
| # The test that catches TRL keyword-collision bugs | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def test_full_reward_pack_survives_trl_calling_convention(tmp_path):
    """End-to-end regression: the EXACT reward list stage 3 hands to TRL
    must survive a simulated TRL-style call. This is the test that would
    have caught the duplicate-prompts bug locally."""
    pack = build_reward_pack(total_episodes=50)
    # Build the same env reward that stage 3 builds. The training log it
    # returns is not inspected in this test.
    task_reward, _training_log = _make_task_reward(tmp_path / "grpo_artifacts")
    all_reward_funcs = pack.funcs + [weighted_environmental_reward(task_reward, pack)]
    # Generate a real prompt dataset (no GPU needed — uses PermanenceEnv)
    dataset_rows = _build_prompt_records(total_episodes=8, domain="devtools")
    # Realistic completions the model might produce
    completions = [
        '<thinking>list first</thinking><action id="fs_ls" path="/var/log"/><reversibility level="R1" confidence="0.99"/>',
        '<thinking>snapshot</thinking><action id="fs_snapshot"/><reversibility level="R2" confidence="0.95"/>',
    ]
    trainer = FakeGRPOTrainer(all_reward_funcs, dataset_rows, num_generations=2)
    # If any reward function raises on the TRL calling convention, this
    # fails. This is the regression test for TRL keyword-collision bugs.
    all_rewards = trainer.simulate_one_step(completions)
    # Every reward function was invoked and produced one score list...
    assert len(all_rewards) == len(all_reward_funcs)
    # ...with exactly one score per completion.
    for scores in all_rewards:
        assert len(scores) == len(completions)
def test_env_wrapper_does_not_double_pass_prompts(tmp_path):
    """Narrow regression test for the TRL keyword-collision bug.

    Calls the env-wrapped reward exactly the way TRL does, with every
    argument passed by keyword. Checks that it returns a list holding one
    score for the single completion.
    """
    reward_pack = build_reward_pack(total_episodes=10)
    env_reward, _log = _make_task_reward(tmp_path / "grpo")
    wrapped = weighted_environmental_reward(env_reward, reward_pack)

    trl_kwargs = dict(
        prompts=["some prompt"],
        completions=['<action id="fs_ls"/><reversibility level="R1"/>'],
        task_id=["task_log_cleanup"],
        seed=[0],
    )
    scores = wrapped(**trl_kwargs)

    assert isinstance(scores, list)
    assert len(scores) == 1
def test_text_reward_accepts_trl_kwargs_without_positional_completions():
    """Every text reward in the pack must accept a keyword-only, TRL-style call.

    TRL passes ``prompts``, ``completions`` and each extra dataset column as
    keyword arguments. The weighted wrappers in ``pack.funcs`` must therefore
    work when nothing is passed positionally, and must return one score per
    completion.
    """
    pack = build_reward_pack(total_episodes=10)
    # Without this, an empty pack would make the loop below pass vacuously.
    assert pack.funcs, "build_reward_pack returned no reward functions"
    for fn in pack.funcs:
        name = getattr(fn, "__name__", repr(fn))
        # Exercise the keyword-only path explicitly, mirroring TRL.
        result = fn(
            prompts=["p1", "p2"],
            completions=["c1", "c2"],
            task_id=["t1", "t2"],
            seed=[0, 1],
        )
        assert isinstance(result, list), f"{name} returned {type(result)}"
        assert len(result) == 2, f"{name} returned {len(result)} scores for 2 completions"
def test_build_prompt_records_returns_usable_dataset_shape():
    """Stage 3 feeds ``_build_prompt_records(...)`` into ``Dataset.from_list``.

    Each record must therefore carry the keys stage 3 relies on, a non-empty
    string prompt, and a ``task_``-prefixed task id.
    """
    records = _build_prompt_records(total_episodes=5, domain="devtools")
    assert len(records) == 5

    expected_keys = {"prompt", "episode", "task_id", "seed"}
    for record in records:
        assert expected_keys.issubset(record.keys())
        prompt_text = record["prompt"]
        assert isinstance(prompt_text, str)
        assert prompt_text  # must not be empty
        assert record["task_id"].startswith("task_")
def test_task_reward_writes_training_log_entries(tmp_path):
    """Stage 3's env reward appends to ``training_log``. Verify the log
    accumulates entries in the right shape."""
    task_reward, training_log = _make_task_reward(tmp_path / "grpo")
    # Measure growth relative to the starting size, so a pre-populated log
    # cannot make this pass without the call appending anything.
    entries_before = len(training_log)
    completions = ['<action id="fs_ls" path="/var/log"/><reversibility level="R1"/>']
    task_reward(
        prompts=["p"],
        completions=completions,
        task_id=["task_log_cleanup"],
        seed=[0],
    )
    assert len(training_log) > entries_before, "task_reward did not append to training_log"
    # Each entry has the structured fields the dashboard and eval rely on
    last = training_log[-1]
    for k in ("task_id", "seed", "reward", "completion_length"):
        assert k in last, f"missing key {k} in training_log entry"