permanence / tests /test_trl_integration.py
chane35's picture
PERMANENCE training: 4-stage SFT -> gate -> GRPO -> eval pipeline
21c24ae verified
"""Mock-TRL integration tests for the GRPO reward pipeline.
A TRL calling-convention bug crashed training with:
``reward_environmental() got multiple values for argument 'prompts'``
That bug was invisible to unit tests because no test ever invoked the reward
functions the way TRL's GRPOTrainer actually invokes them:
fn(prompts=[...], completions=[...], task_id=[...], seed=[...])
These tests simulate that calling convention. If any reward function in the
full pack (pure-text + env-wrapped) chokes on TRL-style kwargs, the test
fails before push — not after 40 minutes of GPU time.
This file runs on CPU only. No unsloth, no trl dependency.
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Any, Dict, List
# Put the repository root (two directory levels above this file) at the
# front of sys.path so the ``training.*`` imports below can be resolved.
_ROOT = Path(__file__).resolve().parent.parent
if not any(entry == str(_ROOT) for entry in sys.path):
    sys.path.insert(0, str(_ROOT))
from training.rewards import build_reward_pack, weighted_environmental_reward
from training.stages.stage_3_grpo import _build_prompt_records, _make_task_reward
class FakeGRPOTrainer:
    """Simulates the TRL GRPOTrainer's reward-function calling convention.

    Real TRL calls::

        for fn in reward_funcs:
            fn(prompts=prompts, completions=completions, **extra_columns)

    GRPO samples ``num_generations`` completions per prompt, so each dataset
    row is repeated ``num_generations`` times, and ``extra_columns`` holds
    every dataset column other than ``prompt`` as a list aligned with the
    completions. We mirror that here: a reward function that survives a call
    from this fake trainer survives that calling convention. (Some TRL
    versions pass further keyword arguments; those are not simulated.)
    """

    def __init__(self, reward_funcs: List, dataset_rows: List[Dict[str, Any]], num_generations: int = 2):
        """
        Args:
            reward_funcs: callables invoked TRL-style by ``simulate_one_step``.
            dataset_rows: prompt records (dicts with at least a ``prompt`` key).
            num_generations: completions sampled per prompt; must be >= 1.

        Raises:
            ValueError: if ``num_generations`` is smaller than 1.
        """
        if num_generations < 1:
            raise ValueError(f"num_generations must be >= 1, got {num_generations}")
        self.reward_funcs = reward_funcs
        self.dataset_rows = dataset_rows
        self.num_generations = num_generations

    def simulate_one_step(self, completions: List[str]) -> List[List[float]]:
        """Invoke every reward function with realistic TRL-style kwargs.

        Completion ``i`` is paired with dataset row ``i // num_generations``,
        i.e. each prompt row is repeated once per sampled completion.

        Returns:
            One list of scores per reward function, in ``reward_funcs`` order.

        Raises:
            ValueError: if ``dataset_rows`` has too few rows to cover
                ``completions`` (otherwise the columns would be shorter than
                ``completions`` and the reward function would be wrongly blamed).
            AssertionError: if a reward function returns a non-list or the
                wrong number of scores.
        """
        n = len(completions)
        group = self.num_generations
        rows_needed = (n + group - 1) // group  # ceil(n / group)
        if rows_needed > len(self.dataset_rows):
            raise ValueError(
                f"need {rows_needed} dataset rows to cover {n} completions at "
                f"num_generations={group}, but only {len(self.dataset_rows)} are available"
            )
        # One row per completion: each prompt row repeated `group` times.
        expanded = [self.dataset_rows[i // group] for i in range(n)]
        prompts = [row["prompt"] for row in expanded]
        # Like TRL, forward every non-prompt column (task_id, seed, episode, ...)
        # so a reward fn that rejects any real dataset column fails here too.
        # Keys come from the dataset itself so an empty batch still gets them.
        column_names = [k for k in self.dataset_rows[0] if k != "prompt"] if self.dataset_rows else []
        extra_columns = {k: [row[k] for row in expanded] for k in column_names}

        all_rewards = []
        for fn in self.reward_funcs:
            # functools.partial and other callables may lack __name__.
            name = getattr(fn, "__name__", repr(fn))
            rewards = fn(prompts=prompts, completions=completions, **extra_columns)
            assert isinstance(rewards, list), f"{name} returned {type(rewards)}"
            assert len(rewards) == n, f"{name} returned {len(rewards)} scores for {n} completions"
            all_rewards.append(rewards)
        return all_rewards
# ─────────────────────────────────────────────────────────────────────────────
# The test that catches TRL keyword-collision bugs
# ─────────────────────────────────────────────────────────────────────────────
def test_full_reward_pack_survives_trl_calling_convention(tmp_path):
    """End-to-end regression: the EXACT reward list stage 3 hands to TRL
    must survive a simulated TRL-style call. This is the test that would
    have caught the duplicate-prompts bug locally."""
    pack = build_reward_pack(total_episodes=50)
    # Build the same env reward that stage 3 builds (its training log is not
    # inspected in this test).
    task_reward, _training_log = _make_task_reward(tmp_path / "grpo_artifacts")
    all_reward_funcs = pack.funcs + [weighted_environmental_reward(task_reward, pack)]

    # Generate a real prompt dataset (no GPU needed — uses PermanenceEnv)
    dataset_rows = _build_prompt_records(total_episodes=8, domain="devtools")

    # Realistic completions the model might produce
    completions = [
        '<thinking>list first</thinking><action id="fs_ls" path="/var/log"/><reversibility level="R1" confidence="0.99"/>',
        '<thinking>snapshot</thinking><action id="fs_snapshot"/><reversibility level="R2" confidence="0.95"/>',
    ]
    trainer = FakeGRPOTrainer(all_reward_funcs, dataset_rows, num_generations=2)

    # If any reward function raises on the TRL calling convention, this
    # fails. This is the regression test for TRL keyword-collision bugs.
    all_rewards = trainer.simulate_one_step(completions)

    # Exactly one score list per reward function; without this check the
    # per-list loop below would pass vacuously on an empty result.
    assert len(all_rewards) == len(all_reward_funcs)
    # Every reward function returned the right number of scores
    for scores in all_rewards:
        assert len(scores) == len(completions)
def test_env_wrapper_does_not_double_pass_prompts(tmp_path):
    """Narrower regression test for the TRL keyword-collision bug.

    Calls the env-wrapped reward with nothing but keyword arguments, exactly
    as TRL does, and expects back a list with one score per completion.
    """
    reward_pack = build_reward_pack(total_episodes=10)
    env_reward, _log = _make_task_reward(tmp_path / "grpo")
    reward_fn = weighted_environmental_reward(env_reward, reward_pack)

    # The same keyword-only call shape TRL uses.
    trl_kwargs = {
        "prompts": ["some prompt"],
        "completions": ['<action id="fs_ls"/><reversibility level="R1"/>'],
        "task_id": ["task_log_cleanup"],
        "seed": [0],
    }
    scores = reward_fn(**trl_kwargs)

    assert isinstance(scores, list)
    assert len(scores) == len(trl_kwargs["completions"])
def test_text_reward_accepts_trl_kwargs_without_positional_completions():
    """Make sure every make_weighted-wrapped text reward survives keyword-only calls.

    TRL passes ``prompts``, ``completions`` and the extra dataset columns by
    keyword, never positionally, so each function in ``pack.funcs`` must
    accept that call shape and return one score per completion.
    """
    pack = build_reward_pack(total_episodes=10)
    # Guard against a vacuous pass: an empty pack would skip the loop entirely.
    assert pack.funcs, "build_reward_pack returned no reward functions"
    for fn in pack.funcs:
        name = getattr(fn, "__name__", repr(fn))
        # Exercise the keyword-only path explicitly, as TRL calls it.
        result = fn(
            prompts=["p1", "p2"],
            completions=["c1", "c2"],
            task_id=["t1", "t2"],
            seed=[0, 1],
        )
        assert isinstance(result, list), f"{name} returned {type(result)}"
        assert len(result) == 2, f"{name} returned {len(result)} scores for 2 completions"
def test_build_prompt_records_returns_usable_dataset_shape():
    """Check the records Stage 3 feeds to ``Dataset.from_list``.

    ``_build_prompt_records`` must return one dict per requested episode,
    each carrying at least ``prompt``, ``episode``, ``task_id`` and ``seed``,
    with a non-empty string prompt and a ``task_``-prefixed task id.
    """
    records = _build_prompt_records(total_episodes=5, domain="devtools")
    assert len(records) == 5

    expected_keys = {"prompt", "episode", "task_id", "seed"}
    for record in records:
        assert not expected_keys.difference(record.keys())
        prompt = record["prompt"]
        assert isinstance(prompt, str)
        assert prompt  # must not be empty
        assert record["task_id"].startswith("task_")
def test_task_reward_writes_training_log_entries(tmp_path):
    """Stage 3's env reward appends to ``training_log``. Verify that one
    TRL-style call adds at least one entry and that every entry it adds has
    the structured fields the dashboard and eval rely on."""
    task_reward, training_log = _make_task_reward(tmp_path / "grpo")
    # Measure against the pre-call length so we only judge entries this call wrote.
    entries_before = len(training_log)

    completions = ['<action id="fs_ls" path="/var/log"/><reversibility level="R1"/>']
    task_reward(
        prompts=["p"],
        completions=completions,
        task_id=["task_log_cleanup"],
        seed=[0],
    )

    # list(...) so this also works if the log is a non-sliceable sequence.
    new_entries = list(training_log)[entries_before:]
    assert new_entries, "task_reward did not append to training_log"
    for entry in new_entries:
        for k in ("task_id", "seed", "reward", "completion_length"):
            assert k in entry, f"missing key {k} in training_log entry"