File size: 3,082 Bytes
aae66fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Tests for the policy-optimization objective menu (make_po_config, ADR-014).

These tests build real trl GRPOConfig objects, so they require trl to be
installed (the framework's .venv has trl==1.5.0). The whole module is skipped
cleanly if trl is absent.
"""
from __future__ import annotations

import pytest

trl = pytest.importorskip("trl")

from composer_replication.trainer.composer_trainer import (  # noqa: E402
    PO_OBJECTIVES,
    make_po_config,
)


def test_menu_lists_expected_objectives():
    """The objective menu exposes exactly the six supported PO presets."""
    expected = {"grpo", "dr_grpo", "bnpo", "dapo", "gspo", "cispo"}
    assert set(PO_OBJECTIVES) == expected


def test_unknown_objective_raises_with_menu(tmp_path):
    """An unrecognised objective name raises ValueError that mentions the menu."""
    out_dir = str(tmp_path)
    with pytest.raises(ValueError) as excinfo:
        make_po_config("nope", output_dir=out_dir)
    message = str(excinfo.value)
    # Each fragment is checked separately so a failure names the missing piece.
    for fragment in ("Unknown PO objective", "dapo", "gspo"):
        assert fragment in message


def test_grpo_preset(tmp_path):
    """'grpo' yields grpo loss, token-level importance sampling and group scaling."""
    config = make_po_config("grpo", output_dir=str(tmp_path))
    assert str(config.loss_type) == "grpo"
    assert str(config.importance_sampling_level) == "token"
    # Vanilla GRPO std-normalizes advantages per group; accept either the
    # string form "group" or a boolean True rendered as "true".
    scaling = str(config.scale_rewards).lower()
    assert scaling in {"group", "true"}


def test_dr_grpo_preset_matches_legacy(tmp_path):
    """'dr_grpo' selects the dr_grpo loss and disables reward std-scaling."""
    config = make_po_config("dr_grpo", output_dir=str(tmp_path))
    assert str(config.loss_type) == "dr_grpo"
    # The Dr.GRPO fix drops std-normalization; accept either the string
    # "none" or a boolean False rendered as "false".
    scaling = str(config.scale_rewards).lower()
    assert scaling in {"none", "false"}


def test_dapo_preset_sets_decoupled_clip(tmp_path):
    """'dapo' enables clip-higher, truncated-completion masking and no KL term."""
    config = make_po_config("dapo", output_dir=str(tmp_path))
    assert str(config.loss_type) == "dapo"

    # Clip-higher: the upper clip bound must be set and exceed the lower one.
    upper = config.epsilon_high
    assert upper is not None
    assert float(upper) > float(config.epsilon)

    assert bool(config.mask_truncated_completions)
    # DAPO removes the KL penalty, so beta must be exactly zero.
    assert float(config.beta) == 0.0


def test_gspo_is_sequence_level(tmp_path):
    """'gspo' reuses the grpo loss but with a sequence-level importance ratio."""
    config = make_po_config("gspo", output_dir=str(tmp_path))
    expected = {"loss_type": "grpo", "importance_sampling_level": "sequence"}
    for field, value in expected.items():
        assert str(getattr(config, field)) == value


def test_gspo_guard_rejects_token_override(tmp_path):
    """Forcing token-level importance sampling onto 'gspo' must be refused."""
    # A token-level override would silently degrade GSPO back to plain GRPO.
    # NOTE(review): the guard is expected to raise AssertionError; if it is an
    # ``assert`` statement it is stripped under ``python -O``. Consider raising
    # ValueError in make_po_config and updating this expectation to match.
    overrides = {"importance_sampling_level": "token"}
    with pytest.raises(AssertionError):
        make_po_config("gspo", output_dir=str(tmp_path), **overrides)


def test_cispo_preset(tmp_path):
    """'cispo' selects the cispo loss with a large eps_max in epsilon_high."""
    config = make_po_config("cispo", output_dir=str(tmp_path))
    assert str(config.loss_type) == "cispo"
    # eps_max (ScaleRL recommends 5.0) is carried via epsilon_high.
    eps_max = config.epsilon_high
    assert eps_max is not None
    assert float(eps_max) >= 5.0


def test_overrides_apply_on_top(tmp_path):
    """Keyword overrides are applied while the preset's loss_type is kept."""
    config = make_po_config(
        "dr_grpo",
        output_dir=str(tmp_path),
        beta=0.05,
        num_generations=4,
    )
    assert float(config.beta) == 0.05
    assert int(config.num_generations) == 4
    # The objective-defining field must survive the extra overrides.
    assert str(config.loss_type) == "dr_grpo"


def test_default_objective_is_dr_grpo(tmp_path):
    """Calling make_po_config without an objective yields the 'dr_grpo' preset."""
    loss_type = make_po_config(output_dir=str(tmp_path)).loss_type
    assert str(loss_type) == "dr_grpo"