| import sys |
| from pathlib import Path |
|
|
| import pytest |
| import torch |
|
|
| from sglang.srt.debug_utils.comparator.output_types import SummaryRecord |
| from sglang.srt.debug_utils.comparator.utils import ( |
| Pair, |
| argmax_coord, |
| auto_descend_dir, |
| calc_per_token_rel_diff, |
| calc_rel_diff, |
| compute_exit_code, |
| compute_smaller_dtype, |
| try_unify_shape, |
| ) |
| from sglang.test.ci.ci_register import register_cpu_ci |
|
|
# Register this test module with the CPU CI runner.
# NOTE(review): est_time is presumably an estimated runtime in seconds — confirm
# against sglang.test.ci.ci_register.
register_cpu_ci(
    suite="default",
    nightly=True,
    est_time=10,
)
|
|
|
|
class TestCalcRelDiff:
    """Checks ``calc_rel_diff`` on a few hand-picked tensor pairs."""

    def test_identical_tensors(self):
        """A tensor compared with itself has (near-)zero relative diff."""
        t = torch.randn(10, 10)
        diff = calc_rel_diff(t, t).item()
        assert diff == pytest.approx(0.0, abs=1e-5)

    def test_orthogonal_tensors(self):
        """Orthogonal unit vectors give a relative diff of 1."""
        a = torch.tensor([1.0, 0.0])
        b = torch.tensor([0.0, 1.0])
        diff = calc_rel_diff(a, b).item()
        assert diff == pytest.approx(1.0, abs=1e-5)

    def test_similar_tensors(self):
        """Slightly shifted tensors give a small but strictly positive diff."""
        base = torch.tensor([1.0, 2.0, 3.0])
        shifted = torch.tensor([1.01, 2.01, 3.01])
        diff = calc_rel_diff(base, shifted).item()
        assert diff > 0.0 and diff < 0.01

    def test_negated_tensors(self):
        """A tensor compared with its negation gives a relative diff of 2."""
        base = torch.tensor([1.0, 2.0])
        diff = calc_rel_diff(base, -base).item()
        assert diff == pytest.approx(2.0, abs=1e-5)
|
|
|
|
class TestCalcPerTokenRelDiff:
    """Checks ``calc_per_token_rel_diff``, which yields one diff per position along ``seq_dim``."""

    def test_identical_tensors(self) -> None:
        """Identical tensors -> every per-token diff is zero."""
        data: torch.Tensor = torch.randn(8, 16)
        per_token: torch.Tensor = calc_per_token_rel_diff(data, data, seq_dim=0)

        assert per_token.shape == (8,)
        assert torch.allclose(per_token, torch.zeros(8), atol=1e-6)

    def test_different_tensors(self) -> None:
        """Only token 3 is perturbed -> only position 3 shows a non-trivial diff."""
        torch.manual_seed(42)
        ref: torch.Tensor = torch.randn(8, 16)
        perturbed: torch.Tensor = ref.clone()
        perturbed[3, :] += 10.0

        per_token: torch.Tensor = calc_per_token_rel_diff(ref, perturbed, seq_dim=0)

        assert per_token.shape == (8,)
        assert per_token[3] > per_token[0]
        assert per_token[3] > per_token[7]
        untouched = [idx for idx in range(8) if idx != 3]
        for idx in untouched:
            assert per_token[idx] < 1e-6

    def test_seq_dim_selection(self) -> None:
        """The output length follows the size of whichever dim is ``seq_dim``."""
        lhs: torch.Tensor = torch.randn(4, 8, 16)
        rhs: torch.Tensor = lhs + torch.randn_like(lhs) * 0.01

        for dim, expected_len in enumerate((4, 8, 16)):
            assert calc_per_token_rel_diff(lhs, rhs, seq_dim=dim).shape == (
                expected_len,
            )

    def test_1d_tensor(self) -> None:
        """A 1D input with ``seq_dim=0`` gets one diff per element."""
        lhs: torch.Tensor = torch.tensor([1.0, 2.0, 3.0])
        rhs: torch.Tensor = torch.tensor([1.0, 2.0, 4.0])

        per_elem: torch.Tensor = calc_per_token_rel_diff(lhs, rhs, seq_dim=0)

        assert per_elem.shape == (3,)
        for idx in (0, 1):
            assert per_elem[idx] < 1e-6
        assert per_elem[2] > 0.01
|
|
|
|
class TestArgmaxCoord:
    """``argmax_coord`` returns the coordinate tuple of the largest element."""

    def test_1d_tensor(self):
        """1D input -> a single-element coordinate tuple."""
        vec = torch.tensor([0.0, 0.0, 5.0, 0.0])
        assert argmax_coord(vec) == (2,)

    def test_2d_tensor(self):
        """2D input -> (row, col) of the peak."""
        mat = torch.zeros(3, 4)
        peak = (1, 2)
        mat[peak] = 10.0
        assert argmax_coord(mat) == peak

    def test_3d_tensor(self):
        """3D input -> full 3-element coordinate of the peak."""
        cube = torch.zeros(2, 3, 4)
        peak = (1, 2, 3)
        cube[peak] = 10.0
        assert argmax_coord(cube) == peak
|
|
|
|
class TestTryUnifyShape:
    """``try_unify_shape`` behavior when reconciling a tensor with a target shape."""

    def test_squeeze_leading_ones(self):
        """Leading size-1 dims are dropped to match the target shape."""
        target = torch.Size([3, 4])
        unified = try_unify_shape(torch.randn(1, 1, 3, 4), target)
        assert unified.shape == target

    def test_no_squeeze_when_leading_dim_not_one(self):
        """A leading dim larger than 1 is left in place."""
        target = torch.Size([3, 4])
        unified = try_unify_shape(torch.randn(2, 3, 4), target)
        assert unified.shape == (2, 3, 4)

    def test_same_shape_noop(self):
        """A tensor already of the target shape comes back sharing the same storage."""
        target = torch.Size([3, 4])
        src = torch.randn(3, 4)
        unified = try_unify_shape(src, target)
        assert unified.shape == target
        assert unified.data_ptr() == src.data_ptr()

    def test_trailing_dims_mismatch(self):
        """When trailing dims disagree with the target, the input shape is kept."""
        target = torch.Size([5, 6])
        src = torch.randn(1, 3, 4)
        unified = try_unify_shape(src, target)
        assert unified.shape == (1, 3, 4)
|
|
|
|
class TestComputeSmallerDtype:
    """``compute_smaller_dtype`` picks bfloat16 from a float32/bfloat16 pair, else None."""

    def test_float32_bfloat16(self):
        """(float32, bfloat16) -> bfloat16."""
        dtypes = Pair(x=torch.float32, y=torch.bfloat16)
        assert compute_smaller_dtype(dtypes) == torch.bfloat16

    def test_reverse_order(self):
        """Argument order does not matter: (bfloat16, float32) -> bfloat16."""
        dtypes = Pair(x=torch.bfloat16, y=torch.float32)
        assert compute_smaller_dtype(dtypes) == torch.bfloat16

    def test_same_dtype_returns_none(self):
        """Equal dtypes -> None."""
        dtypes = Pair(x=torch.float32, y=torch.float32)
        assert compute_smaller_dtype(dtypes) is None

    def test_unknown_pair_returns_none(self):
        """An unrecognized pair such as (int32, int64) -> None."""
        dtypes = Pair(x=torch.int32, y=torch.int64)
        assert compute_smaller_dtype(dtypes) is None
|
|
|
|
class TestPairMap:
    """``Pair.map`` applies a function to both sides and returns a new Pair."""

    def test_map_basic(self):
        """Summing each side's list."""
        mapped = Pair(x=[1, 2, 3], y=[4, 5, 6]).map(sum)
        assert (mapped.x, mapped.y) == (6, 15)

    def test_map_type_change(self):
        """The mapped values may have a different type than the inputs."""
        mapped = Pair(x=[1, 2, 3], y=[10, 20]).map(len)
        assert (mapped.x, mapped.y) == (3, 2)

    def test_map_returns_new_pair(self):
        """The result is a distinct Pair object."""
        original = Pair(x="hello", y="world")
        mapped = original.map(str.upper)
        assert mapped.x == "HELLO"
        assert mapped.y == "WORLD"
        assert mapped is not original
|
|
|
|
class TestComputeExitCode:
    """Unit tests for the ``compute_exit_code`` decision logic."""

    @staticmethod
    def _assert_exit_code(summary, expected, **kwargs):
        """Call ``compute_exit_code(summary, **kwargs)`` and check it returns ``expected``."""
        assert compute_exit_code(summary, **kwargs) == expected

    def test_all_passed(self):
        """All passed -> exit 0."""
        self._assert_exit_code(
            SummaryRecord(total=3, passed=3, failed=0, skipped=0),
            0,
            allow_skipped_pattern=".*",
            skipped_names=[],
            allow_failed_pattern=None,
            failed_names=[],
        )

    def test_has_failed_and_passed(self):
        """Some failed, some passed -> exit 1."""
        self._assert_exit_code(
            SummaryRecord(total=4, passed=2, failed=2, skipped=0),
            1,
            allow_skipped_pattern=".*",
            skipped_names=[],
            allow_failed_pattern=None,
            failed_names=["a", "b"],
        )

    def test_all_failed(self):
        """All failed (0 passed) -> exit 1."""
        self._assert_exit_code(
            SummaryRecord(total=3, passed=0, failed=3, skipped=0),
            1,
            allow_skipped_pattern=".*",
            skipped_names=[],
            allow_failed_pattern=None,
            failed_names=["a", "b", "c"],
        )

    def test_all_skipped_allow_all(self):
        """All skipped + allow_skipped_pattern='.*' -> exit 1 (nothing passed)."""
        self._assert_exit_code(
            SummaryRecord(total=2, passed=0, failed=0, skipped=2),
            1,
            allow_skipped_pattern=".*",
            skipped_names=["a", "b"],
            allow_failed_pattern=None,
            failed_names=[],
        )

    def test_all_skipped_forbid_all(self):
        """All skipped + allow_skipped_pattern='^$' -> exit 1."""
        self._assert_exit_code(
            SummaryRecord(total=2, passed=0, failed=0, skipped=2),
            1,
            allow_skipped_pattern="^$",
            skipped_names=["a", "b"],
            allow_failed_pattern=None,
            failed_names=[],
        )

    def test_passed_and_skipped_allow_all(self):
        """Passed + skipped, all skips allowed -> exit 0."""
        self._assert_exit_code(
            SummaryRecord(total=3, passed=2, failed=0, skipped=1),
            0,
            allow_skipped_pattern=".*",
            skipped_names=["a"],
            allow_failed_pattern=None,
            failed_names=[],
        )

    def test_passed_and_skipped_forbid_all(self):
        """Passed + skipped, no skips allowed -> exit 1."""
        self._assert_exit_code(
            SummaryRecord(total=3, passed=2, failed=0, skipped=1),
            1,
            allow_skipped_pattern="^$",
            skipped_names=["a"],
            allow_failed_pattern=None,
            failed_names=[],
        )

    def test_skip_pattern_matches_specific_name(self):
        """Skip pattern naming every skipped tensor -> exit 0."""
        self._assert_exit_code(
            SummaryRecord(total=4, passed=2, failed=0, skipped=2),
            0,
            allow_skipped_pattern="positions|seq_lens",
            skipped_names=["positions", "seq_lens"],
            allow_failed_pattern=None,
            failed_names=[],
        )

    def test_skip_pattern_partial_match_forbidden(self):
        """Skip pattern covers only some skipped tensors -> exit 1."""
        self._assert_exit_code(
            SummaryRecord(total=4, passed=1, failed=0, skipped=3),
            1,
            allow_skipped_pattern="positions|seq_lens",
            skipped_names=["positions", "seq_lens", "hidden_states"],
            allow_failed_pattern=None,
            failed_names=[],
        )

    def test_allow_failed_pattern_matches_all(self):
        """allow_failed_pattern='.*' tolerates every failure -> exit 0."""
        self._assert_exit_code(
            SummaryRecord(total=3, passed=1, failed=2, skipped=0),
            0,
            allow_skipped_pattern=".*",
            skipped_names=[],
            allow_failed_pattern=".*",
            failed_names=["a", "b"],
        )

    def test_allow_failed_pattern_matches_specific(self):
        """Failure pattern naming every failed tensor -> exit 0."""
        self._assert_exit_code(
            SummaryRecord(total=3, passed=1, failed=2, skipped=0),
            0,
            allow_skipped_pattern=".*",
            skipped_names=[],
            allow_failed_pattern="hidden_states|logits",
            failed_names=["hidden_states", "logits"],
        )

    def test_allow_failed_pattern_partial_match(self):
        """Failure pattern covers only some failures -> exit 1."""
        self._assert_exit_code(
            SummaryRecord(total=3, passed=0, failed=3, skipped=0),
            1,
            allow_skipped_pattern=".*",
            skipped_names=[],
            allow_failed_pattern="hidden_states",
            failed_names=["hidden_states", "logits", "attn"],
        )

    def test_allow_failed_pattern_no_failures(self):
        """Failure pattern set but nothing failed -> exit 0."""
        self._assert_exit_code(
            SummaryRecord(total=2, passed=2, failed=0, skipped=0),
            0,
            allow_skipped_pattern=".*",
            skipped_names=[],
            allow_failed_pattern=".*",
            failed_names=[],
        )

    def test_both_failed_and_skipped_patterns(self):
        """Both patterns set and both satisfied -> exit 0."""
        self._assert_exit_code(
            SummaryRecord(total=4, passed=1, failed=1, skipped=2),
            0,
            allow_skipped_pattern="positions|seq_lens",
            skipped_names=["positions", "seq_lens"],
            allow_failed_pattern="logits",
            failed_names=["logits"],
        )

    def test_failed_pattern_satisfied_but_skipped_not(self):
        """Failure pattern satisfied but skip pattern violated -> exit 1."""
        self._assert_exit_code(
            SummaryRecord(total=3, passed=1, failed=1, skipped=1),
            1,
            allow_skipped_pattern="^$",
            skipped_names=["a"],
            allow_failed_pattern=".*",
            failed_names=["b"],
        )

    def test_zero_passed_exits_one(self):
        """Nothing passed -> exit 1, even when every failure is allowed."""
        self._assert_exit_code(
            SummaryRecord(total=2, passed=0, failed=2, skipped=0),
            1,
            allow_skipped_pattern=".*",
            skipped_names=[],
            allow_failed_pattern=".*",
            failed_names=["a", "b"],
        )

    def test_zero_passed_all_skipped_exits_one(self):
        """Everything skipped, nothing passed -> exit 1."""
        self._assert_exit_code(
            SummaryRecord(total=3, passed=0, failed=0, skipped=3),
            1,
            allow_skipped_pattern=".*",
            skipped_names=["a", "b", "c"],
            allow_failed_pattern=None,
            failed_names=[],
        )

    def test_errored_with_passed_exits_one(self):
        """An errored bundle -> exit 1, even when other tensors passed."""
        self._assert_exit_code(
            SummaryRecord(total=3, passed=2, failed=0, skipped=0, errored=1),
            1,
            allow_skipped_pattern=".*",
            skipped_names=[],
            allow_failed_pattern=None,
            failed_names=[],
            errored_names=["broken_tensor"],
        )

    def test_errored_only_exits_one(self):
        """Only errored bundles -> exit 1 (passed == 0 also implies exit 1 on its own)."""
        self._assert_exit_code(
            SummaryRecord(total=1, passed=0, failed=0, skipped=0, errored=1),
            1,
            allow_skipped_pattern=".*",
            skipped_names=[],
            allow_failed_pattern=None,
            failed_names=[],
            errored_names=["broken_tensor"],
        )
|
|
|
|
def _make_pt(directory: Path) -> None:
    """Create ``directory`` (parents included, existing is OK) and save a one-element tensor to ``dummy.pt`` in it."""
    target_file = directory / "dummy.pt"
    directory.mkdir(exist_ok=True, parents=True)
    torch.save(torch.tensor([1.0]), target_file)
|
|
|
|
class TestAutoDescendDir:
    """How ``auto_descend_dir`` locates the directory that actually holds ``.pt`` files."""

    def test_no_descend_when_pt_at_root(self, tmp_path: Path) -> None:
        """A root that already has ``.pt`` files is returned as-is, even when a child has them too."""
        for where in (tmp_path, tmp_path / "child_a"):
            _make_pt(where)
        assert auto_descend_dir(tmp_path, label="test") == tmp_path

    def test_descend_into_single_child(self, tmp_path: Path) -> None:
        """A single child directory holding ``.pt`` files is descended into."""
        engine_dir: Path = tmp_path / "engine_0"
        _make_pt(engine_dir)
        resolved = auto_descend_dir(tmp_path, label="test")
        assert resolved == engine_dir

    def test_descend_single_nonempty_child_among_empty(self, tmp_path: Path) -> None:
        """Two subdirs, only one with ``.pt`` files -> descend into that one."""
        populated: Path = tmp_path / "engine_0"
        _make_pt(populated)
        (tmp_path / "empty_child").mkdir()
        assert auto_descend_dir(tmp_path, label="test") == populated

    def test_error_with_multiple_nonempty_children(self, tmp_path: Path) -> None:
        """Several children with ``.pt`` files are ambiguous -> ValueError."""
        for child_name in ("engine_0", "engine_1"):
            _make_pt(tmp_path / child_name)
        with pytest.raises(ValueError, match="multiple subdirectories contain data"):
            auto_descend_dir(tmp_path, label="test")

    def test_error_when_no_data_found(self, tmp_path: Path) -> None:
        """No ``.pt`` files anywhere -> ValueError."""
        (tmp_path / "empty_child").mkdir()
        with pytest.raises(ValueError, match="no .pt files found"):
            auto_descend_dir(tmp_path, label="test")
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly; exit with pytest's return status.
    exit_status = pytest.main([__file__])
    sys.exit(exit_status)
|
|