Codeseys's picture
feat(trainer): policy-optimization objective MENU (ADR-014)
aae66fa
Raw
History Blame Contribute Delete
3.08 kB
"""Tests for the policy-optimization objective menu (make_po_config, ADR-014).

These tests build real trl GRPOConfigs, so they require trl to be installed
(the framework's .venv has trl==1.5.0). The whole module is skipped cleanly if
trl is absent.
"""
from __future__ import annotations
import pytest
trl = pytest.importorskip("trl")
from composer_replication.trainer.composer_trainer import ( # noqa: E402
PO_OBJECTIVES,
make_po_config,
)
def test_menu_lists_expected_objectives():
    """PO_OBJECTIVES contains exactly the six expected objective names."""
    expected_names = {"grpo", "dr_grpo", "bnpo", "dapo", "gspo", "cispo"}
    assert set(PO_OBJECTIVES) == expected_names
def test_unknown_objective_raises_with_menu(tmp_path):
    """An unrecognized objective raises ValueError and the message names menu entries."""
    with pytest.raises(ValueError) as excinfo:
        make_po_config("nope", output_dir=str(tmp_path))

    # Inspect the message only after the context manager has captured the error.
    message = str(excinfo.value)
    assert "Unknown PO objective" in message
    assert "dapo" in message
    assert "gspo" in message
def test_grpo_preset(tmp_path):
    """The "grpo" preset selects the grpo loss with a token-level importance ratio."""
    config = make_po_config("grpo", output_dir=str(tmp_path))

    assert str(config.loss_type) == "grpo"
    assert str(config.importance_sampling_level) == "token"

    # Vanilla GRPO uses group scaling, i.e. a std-normalized advantage.
    # Either spelling of the setting ("group" or "true") is accepted.
    reward_scaling = str(config.scale_rewards).lower()
    assert reward_scaling in ("group", "true")
def test_dr_grpo_preset_matches_legacy(tmp_path):
    """The "dr_grpo" preset selects the dr_grpo loss with reward scaling disabled."""
    config = make_po_config("dr_grpo", output_dir=str(tmp_path))

    assert str(config.loss_type) == "dr_grpo"

    # The Dr.GRPO fix: no std-normalization of rewards.
    # Either spelling of the setting ("none" or "false") is accepted.
    reward_scaling = str(config.scale_rewards).lower()
    assert reward_scaling in ("none", "false")
def test_dapo_preset_sets_decoupled_clip(tmp_path):
    """The "dapo" preset sets clip-higher bounds, truncation masking and zero KL."""
    config = make_po_config("dapo", output_dir=str(tmp_path))

    assert str(config.loss_type) == "dapo"

    # Clip-higher: the upper bound must be set and strictly above epsilon.
    upper_clip = config.epsilon_high
    assert upper_clip is not None
    assert float(upper_clip) > float(config.epsilon)

    assert bool(config.mask_truncated_completions) is True

    # DAPO removes the KL term, so beta must be exactly zero.
    assert float(config.beta) == 0.0
def test_gspo_is_sequence_level(tmp_path):
    """The "gspo" preset is the grpo loss with a sequence-level importance ratio."""
    config = make_po_config("gspo", output_dir=str(tmp_path))

    # GSPO keeps the grpo loss; only the importance-ratio level differs.
    assert str(config.loss_type) == "grpo"
    assert str(config.importance_sampling_level) == "sequence"
def test_gspo_guard_rejects_token_override(tmp_path):
    """Forcing token-level importance sampling on "gspo" raises AssertionError."""
    # A token-level override would silently degrade GSPO to GRPO, so the
    # factory is expected to guard against it instead of accepting it.
    conflicting_override = {"importance_sampling_level": "token"}

    with pytest.raises(AssertionError):
        make_po_config("gspo", output_dir=str(tmp_path), **conflicting_override)
def test_cispo_preset(tmp_path):
    """The "cispo" preset selects the cispo loss and sets epsilon_high to at least 5.0."""
    config = make_po_config("cispo", output_dir=str(tmp_path))

    assert str(config.loss_type) == "cispo"

    # eps_max (ScaleRL recommends 5.0) is carried via epsilon_high.
    eps_max = config.epsilon_high
    assert eps_max is not None
    assert float(eps_max) >= 5.0
def test_overrides_apply_on_top(tmp_path):
    """Keyword overrides reach the config without disturbing the preset's loss type."""
    overrides = {"beta": 0.05, "num_generations": 4}
    config = make_po_config("dr_grpo", output_dir=str(tmp_path), **overrides)

    assert float(config.beta) == 0.05
    assert int(config.num_generations) == 4

    # The preset itself must be preserved under the overrides.
    assert str(config.loss_type) == "dr_grpo"
def test_default_objective_is_dr_grpo(tmp_path):
    """Calling make_po_config without an objective yields the dr_grpo loss."""
    default_config = make_po_config(output_dir=str(tmp_path))
    assert str(default_config.loss_type) == "dr_grpo"