"""Tests for the policy-optimization objective menu (``make_po_config``, ADR-014).

Each test builds a genuine trl ``GRPOConfig`` through ``make_po_config``, so
trl has to be importable (the framework's .venv has trl==1.5.0). When trl is
missing, ``pytest.importorskip`` skips this whole module cleanly.
"""

from __future__ import annotations

import pytest

trl = pytest.importorskip("trl")

from composer_replication.trainer.composer_trainer import (  # noqa: E402
    PO_OBJECTIVES,
    make_po_config,
)


def _make(objective, tmp_path, **overrides):
    """Return ``make_po_config(objective, ...)`` with output_dir set to *tmp_path*.

    Any extra keyword arguments are forwarded unchanged as preset overrides.
    """
    return make_po_config(objective, output_dir=str(tmp_path), **overrides)


def test_menu_lists_expected_objectives():
    """The menu exposes exactly this set of objective names."""
    expected = {"grpo", "dr_grpo", "bnpo", "dapo", "gspo", "cispo"}
    assert set(PO_OBJECTIVES) == expected


def test_unknown_objective_raises_with_menu(tmp_path):
    """An unrecognised objective raises ValueError mentioning the menu."""
    with pytest.raises(ValueError) as excinfo:
        _make("nope", tmp_path)
    message = str(excinfo.value)
    assert "Unknown PO objective" in message
    assert "dapo" in message
    assert "gspo" in message


def test_grpo_preset(tmp_path):
    """Vanilla GRPO: grpo loss, token-level ratio, group reward scaling."""
    config = _make("grpo", tmp_path)
    assert str(config.loss_type) == "grpo"
    assert str(config.importance_sampling_level) == "token"
    # Group scaling = std-normalized advantage (vanilla GRPO). Both a string
    # ("group") and a boolean (True) representation are accepted.
    assert str(config.scale_rewards).lower() in {"group", "true"}


def test_dr_grpo_preset_matches_legacy(tmp_path):
    """Dr.GRPO: dr_grpo loss with reward std-normalization switched off."""
    config = _make("dr_grpo", tmp_path)
    assert str(config.loss_type) == "dr_grpo"
    # Dropping std-normalization is the Dr.GRPO fix; accept "none" or False.
    assert str(config.scale_rewards).lower() in {"none", "false"}


def test_dapo_preset_sets_decoupled_clip(tmp_path):
    """DAPO: clip-higher bounds, truncated-completion masking, and no KL term."""
    config = _make("dapo", tmp_path)
    assert str(config.loss_type) == "dapo"
    # Clip-higher: an explicit upper bound strictly above epsilon.
    assert config.epsilon_high is not None
    assert float(config.epsilon_high) > float(config.epsilon)
    assert bool(config.mask_truncated_completions)
    assert float(config.beta) == 0.0  # DAPO removes the KL penalty


def test_gspo_is_sequence_level(tmp_path):
    """GSPO keeps the grpo loss but switches to a sequence-level ratio."""
    config = _make("gspo", tmp_path)
    assert str(config.loss_type) == "grpo"
    assert str(config.importance_sampling_level) == "sequence"


def test_gspo_guard_rejects_token_override(tmp_path):
    """Forcing GSPO back to token-level ratios must be refused.

    Without the guard, GSPO would silently degrade into plain GRPO.
    """
    # NOTE(review): if the guard in make_po_config is a bare `assert`, it is
    # stripped under `python -O` and this test would fail there — confirm.
    with pytest.raises(AssertionError):
        _make("gspo", tmp_path, importance_sampling_level="token")


def test_cispo_preset(tmp_path):
    """CISPO: cispo loss, with eps_max carried via ``epsilon_high``."""
    config = _make("cispo", tmp_path)
    assert str(config.loss_type) == "cispo"
    # ScaleRL's recommended eps_max is 5.0; the preset must reach at least that.
    assert config.epsilon_high is not None
    assert float(config.epsilon_high) >= 5.0


def test_overrides_apply_on_top(tmp_path):
    """Keyword overrides layer onto a preset without replacing its loss type."""
    config = _make("dr_grpo", tmp_path, beta=0.05, num_generations=4)
    assert float(config.beta) == 0.05
    assert int(config.num_generations) == 4
    assert str(config.loss_type) == "dr_grpo"  # preset survives the overrides


def test_default_objective_is_dr_grpo(tmp_path):
    """Omitting the objective yields a config with the dr_grpo loss."""
    config = make_po_config(output_dir=str(tmp_path))
    assert str(config.loss_type) == "dr_grpo"