import subprocess
import sys
import textwrap
from argparse import Namespace
from pathlib import Path

import pytest
import torch

import sglang.srt.debug_utils.comparator.entrypoint as _entrypoint_module
import sglang.srt.debug_utils.dumper as _dumper_module
from sglang.srt.debug_utils.comparator.entrypoint import (
    parse_args,
    run,
)
from sglang.srt.debug_utils.comparator.output_types import (
    AnyRecord,
    ComparisonErrorRecord,
    ComparisonNonTensorRecord,
    ComparisonSkipRecord,
    ComparisonTensorRecord,
    ConfigRecord,
    InfoLog,
    LogRecord,
    ReplicatedCheckResult,
    SummaryRecord,
    _OutputRecord,
    parse_record_json,
)
from sglang.srt.debug_utils.dumper import DumperConfig, _Dumper, _RecomputeStatus
from sglang.test.ci.ci_register import register_cpu_ci

register_cpu_ci(est_time=30, suite="default", nightly=True)

# Experiment name shared by all dump helpers; dumps land under <dir>/<exp_name>.
_FIXED_EXP_NAME = "my_exp_name"


# Each test has a one-line docstring describing the scenario it covers.


class TestEntrypointGroupingRaw:
    """Test `--grouping-skip-keys` empty (raw) scenarios"""

    def test_run_basic(self, tmp_path, capsys):
        """Two matching tensors produce ConfigRecord, 2 ComparisonTensorRecords, and SummaryRecord."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a", "tensor_b"])
        argv = _make_argv(baseline_path, target_path, preset="raw")
        records, _ = _run_and_parse(argv, capsys)
        assert isinstance(records[0], ConfigRecord)
        assert len(_get_comparisons(records)) == 2
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.total == 2
        assert summary.skipped == 0

    def test_filter(self, tmp_path, capsys):
        """--filter selects only the matching tensor, producing 1 ComparisonTensorRecord."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a", "tensor_b"])
        argv = _make_argv(baseline_path, target_path, filter="tensor_a", preset="raw")
        records, _ = _run_and_parse(argv, capsys)
        assert len(_get_comparisons(records)) == 1

    def test_no_baseline_skip(self, tmp_path, capsys):
        """Target tensor missing from baseline emits a ComparisonSkipRecord with reason baseline_load_failed."""
        baseline_path, target_path = _create_dumps(
            tmp_path,
            tensor_names=["tensor_a", "tensor_extra"],
            baseline_names=["tensor_a"],
        )
        argv = _make_argv(baseline_path, target_path, preset="raw")
        records, _ = _run_and_parse(argv, capsys)
        skips = [r for r in records if isinstance(r, ComparisonSkipRecord)]
        assert len(skips) == 1
        assert skips[0].reason == "baseline_load_failed"
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.skipped == 1

    def test_step_range(self, tmp_path, capsys):
        """--start_step/--end_step restricts comparison to a single step out of three."""
        baseline_path, target_path = _create_dumps(tmp_path, ["t"], num_steps=3)
        argv = _make_argv(
            baseline_path, target_path, start_step=1, end_step=1, preset="raw"
        )
        records, _ = _run_and_parse(argv, capsys)
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.total == 1

    def test_all_valid_records(self, tmp_path, capsys):
        """Every emitted JSON record is a valid _OutputRecord subclass."""
        baseline_path, target_path = _create_dumps(tmp_path, ["t"], num_steps=2)
        argv = _make_argv(baseline_path, target_path, preset="raw")
        records, _ = _run_and_parse(argv, capsys)
        assert all(isinstance(r, _OutputRecord) for r in records)

    def test_comparison_failed(self, tmp_path, capsys):
        """Completely different tensors produce a failed ComparisonTensorRecord."""
        torch.manual_seed(42)
        baseline_path = _create_rank_dump(
            tmp_path / "baseline", rank=0, name="tensor_a", tensor=torch.randn(10, 10)
        )
        target_path = _create_rank_dump(
            tmp_path / "target",
            rank=0,
            name="tensor_a",
            tensor=torch.randn(10, 10) * 100,
        )
        argv = _make_argv(baseline_path, target_path, preset="raw", diff_threshold=1e-3)
        records, _ = _run_and_parse(argv, capsys)
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 1
        assert comparisons[0].diff is not None
        assert not comparisons[0].diff.passed
        assert comparisons[0].category == "failed"
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.failed == 1

    def test_shape_mismatch(self, tmp_path, capsys):
        """Different shapes produce shape_mismatch=True and category='failed'."""
        torch.manual_seed(42)
        baseline_path = _create_rank_dump(
            tmp_path / "baseline", rank=0, name="tensor_a", tensor=torch.randn(4, 8)
        )
        target_path = _create_rank_dump(
            tmp_path / "target", rank=0, name="tensor_a", tensor=torch.randn(4, 10)
        )
        argv = _make_argv(baseline_path, target_path, preset="raw")
        records, _ = _run_and_parse(argv, capsys)
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 1
        assert comparisons[0].shape_mismatch is True
        # Shapes differ, so no numeric diff is computed at all.
        assert comparisons[0].diff is None
        assert comparisons[0].category == "failed"
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.failed == 1

    def test_unify_shape_leading_dims(self, tmp_path, capsys):
        """Leading singleton dims on baseline are squeezed to match target shape."""
        torch.manual_seed(42)
        base_tensor = torch.randn(4, 8)
        baseline_tensor = base_tensor.unsqueeze(0)  # (1, 4, 8)
        target_tensor = base_tensor + torch.randn(4, 8) * 0.0001  # (4, 8)
        baseline_path = _create_rank_dump(
            tmp_path / "baseline", rank=0, name="tensor_a", tensor=baseline_tensor
        )
        target_path = _create_rank_dump(
            tmp_path / "target", rank=0, name="tensor_a", tensor=target_tensor
        )
        argv = _make_argv(baseline_path, target_path, preset="raw")
        records, _ = _run_and_parse(argv, capsys)
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 1
        comp = comparisons[0]
        assert comp.shape_mismatch is False
        assert comp.baseline.shape == [1, 4, 8]
        assert comp.target.shape == [4, 8]
        assert comp.unified_shape == [4, 8]
        assert comp.diff is not None
        assert comp.diff.passed

    def test_dtype_mismatch_downcast(self, tmp_path, capsys):
        """Baseline float32 vs target bfloat16 produces diff_downcast."""
        torch.manual_seed(42)
        baseline_tensor = torch.randn(4, 8, dtype=torch.float32)
        target_tensor = (baseline_tensor + torch.randn(4, 8) * 0.0001).to(
            torch.bfloat16
        )
        baseline_path = _create_rank_dump(
            tmp_path / "baseline", rank=0, name="tensor_a", tensor=baseline_tensor
        )
        target_path = _create_rank_dump(
            tmp_path / "target", rank=0, name="tensor_a", tensor=target_tensor
        )
        argv = _make_argv(baseline_path, target_path, preset="raw", diff_threshold=0.01)
        records, _ = _run_and_parse(argv, capsys)
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 1
        assert comparisons[0].diff_downcast is not None
        assert comparisons[0].downcast_dtype is not None

    def test_mixed_summary(self, tmp_path, capsys):
        """One passed, one failed, one skipped tensor in a single run."""
        torch.manual_seed(42)
        similar_tensor = torch.randn(4, 4)
        different_baseline = torch.randn(4, 4)
        different_target = torch.randn(4, 4) * 100
        extra_tensor = torch.randn(4, 4)
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        _create_rank_dump(baseline_dir, rank=0, name="similar", tensor=similar_tensor)
        _create_rank_dump(
            baseline_dir, rank=0, name="different", tensor=different_baseline
        )
        _create_rank_dump(
            target_dir,
            rank=0,
            name="similar",
            tensor=similar_tensor + torch.randn(4, 4) * 0.0001,
        )
        _create_rank_dump(target_dir, rank=0, name="different", tensor=different_target)
        # "extra" exists only on the target side -> skipped.
        _create_rank_dump(target_dir, rank=0, name="extra", tensor=extra_tensor)
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            preset="raw",
            diff_threshold=1e-3,
        )
        records, _ = _run_and_parse(argv, capsys)
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.passed == 1
        assert summary.failed == 1
        assert summary.skipped == 1
        assert summary.total == 3

    def test_filter_empty_result(self, tmp_path, capsys):
        """--filter matching nothing produces summary with total=0."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a"])
        argv = _make_argv(
            baseline_path,
            target_path,
            filter="nonexistent_pattern",
            preset="raw",
        )
        records, _ = _run_and_parse(argv, capsys)
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.total == 0

    def test_raw_multi_rank(self, tmp_path, capsys):
        """Two ranks in raw grouping produce two ComparisonTensorRecords (one per rank)."""
        torch.manual_seed(42)
        tensor = torch.randn(4, 4)
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        for rank in range(2):
            _create_rank_dump(baseline_dir, rank=rank, name="hidden", tensor=tensor)
            _create_rank_dump(
                target_dir,
                rank=rank,
                name="hidden",
                tensor=tensor + torch.randn(4, 4) * 0.0001,
            )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            preset="raw",
            diff_threshold=0.01,
        )
        records, _ = _run_and_parse(argv, capsys)
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 2
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.total == 2
        assert summary.passed == 2

    def test_text_output_smoke(self, tmp_path, capsys):
        """Text output format renders without errors and contains Config/Summary sections."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a"])
        argv = _make_argv(
            baseline_path, target_path, output_format="text", preset="raw"
        )
        # Drain anything captured so far; we only want `run`'s own output.
        capsys.readouterr()
        run(parse_args(argv))
        output = capsys.readouterr().out
        assert "Comparator Config" in output
        assert "SUMMARY" in output

    def test_text_output_with_failure(self, tmp_path, capsys):
        """Text output with a failed comparison renders failure info."""
        torch.manual_seed(42)
        baseline_path = _create_rank_dump(
            tmp_path / "baseline", rank=0, name="tensor_a", tensor=torch.randn(10, 10)
        )
        target_path = _create_rank_dump(
            tmp_path / "target",
            rank=0,
            name="tensor_a",
            tensor=torch.randn(10, 10) * 100,
        )
        argv = _make_argv(
            baseline_path, target_path, output_format="text", preset="raw"
        )
        capsys.readouterr()
        run(parse_args(argv))
        output = capsys.readouterr().out
        assert "SUMMARY" in output
        assert "failed" in output.lower()

    def test_duplicate_dump_pairing(self, tmp_path, capsys):
        """Same name dumped twice (different values) pairs by duplicate_index: 0th↔0th, 1st↔1st."""
        torch.manual_seed(42)
        tensor_v0 = torch.randn(4, 4)
        tensor_v1 = torch.randn(4, 4)
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        for side_dir in [baseline_dir, target_dir]:
            with pytest.MonkeyPatch.context() as mp:
                # Force a deterministic rank for the real dumper object.
                mp.setattr(_dumper_module, "_get_rank", lambda: 0)
                dumper = _Dumper(
                    config=DumperConfig(
                        enable=True,
                        dir=str(side_dir),
                        exp_name=_FIXED_EXP_NAME,
                    )
                )
                # Bypass normal (distributed) metadata collection.
                dumper.__dict__["_static_meta"] = {"world_rank": 0, "world_size": 1}
                dumper.dump("tensor_a", tensor_v0)
                dumper.dump("tensor_a", tensor_v1)
                dumper.step()
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            preset="raw",
        )
        records, _ = _run_and_parse(argv, capsys)
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 2
        assert all(c.diff is not None and c.diff.passed for c in comparisons)
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.total == 2
        assert summary.passed == 2


class TestEntrypointGroupingLogical:
    """Test `--grouping-skip-keys rank` (logical) scenarios"""

    def test_no_dims_single_rank(self, tmp_path, capsys):
        """Single-rank dumps without dims fall back to raw loading."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a", "tensor_b"])
        argv = _make_argv(baseline_path, target_path)
        records, _ = _run_and_parse(argv, capsys)
        assert len(_get_comparisons(records)) == 2
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.total == 2
        assert summary.skipped == 0

    def test_tp_unshard_same_size(self, tmp_path, capsys):
        """Both sides TP=2: shards are concatenated before comparison."""
        torch.manual_seed(42)
        full_baseline = torch.randn(4, 8)
        full_target = full_baseline + torch.randn(4, 8) * 0.001
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        baseline_path = _create_tp_sharded_dumps(
            baseline_dir,
            full_tensor=full_baseline,
            name="hidden",
            tp_size=2,
            shard_dim=1,
            dims_str="b h[tp]",
        )
        target_path = _create_tp_sharded_dumps(
            target_dir,
            full_tensor=full_target,
            name="hidden",
            tp_size=2,
            shard_dim=1,
            dims_str="b h[tp]",
        )
        argv = _make_argv(baseline_path, target_path, diff_threshold=0.01)
        records, _ = _run_and_parse(argv, capsys)
        comp = _assert_single_comparison_passed(records)
        assert comp.name == "hidden"
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.total == 1
        assert summary.passed == 1

    def test_tp_unshard_different_sizes(self, tmp_path, capsys):
        """Baseline TP=4 vs target TP=2: different shard counts are handled correctly."""
        torch.manual_seed(42)
        full_baseline = torch.randn(4, 8)
        full_target = full_baseline + torch.randn(4, 8) * 0.001
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        baseline_path = _create_tp_sharded_dumps(
            baseline_dir,
            full_tensor=full_baseline,
            name="hidden",
            tp_size=4,
            shard_dim=1,
            dims_str="b h[tp]",
        )
        target_path = _create_tp_sharded_dumps(
            target_dir,
            full_tensor=full_target,
            name="hidden",
            tp_size=2,
            shard_dim=1,
            dims_str="b h[tp]",
        )
        argv = _make_argv(baseline_path, target_path, diff_threshold=0.01)
        records, _ = _run_and_parse(argv, capsys)
        _assert_single_comparison_passed(records)

    def test_one_side_dims_single_baseline(self, tmp_path, capsys):
        """Baseline has no dims (single rank), target has TP shards: unshard target only."""
        torch.manual_seed(42)
        full_tensor = torch.randn(4, 8)
        target_full = full_tensor + torch.randn(4, 8) * 0.001
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        baseline_path = _create_rank_dump(
            baseline_dir, rank=0, name="hidden", tensor=full_tensor
        )
        target_path = _create_tp_sharded_dumps(
            target_dir,
            full_tensor=target_full,
            name="hidden",
            tp_size=2,
            shard_dim=1,
            dims_str="b h[tp]",
        )
        argv = _make_argv(baseline_path, target_path, diff_threshold=0.01)
        records, _ = _run_and_parse(argv, capsys)
        _assert_single_comparison_passed(records)

    @pytest.mark.parametrize(
        "bad_side, expected_reason",
        [
            ("baseline", "baseline_load_failed"),
            ("target", "target_load_failed"),
        ],
    )
    def test_ambiguous_no_dims_skip(self, tmp_path, capsys, bad_side, expected_reason):
        """Multi-rank without dims on one side produces a ComparisonSkipRecord with the appropriate reason."""
        torch.manual_seed(42)
        tensor = torch.randn(4, 8)
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        good_dir = target_dir if bad_side == "baseline" else baseline_dir
        bad_dir = baseline_dir if bad_side == "baseline" else target_dir
        _create_rank_dump(good_dir, rank=0, name="hidden", tensor=tensor)
        # Two ranks with no dims annotation: the loader cannot tell how to merge.
        for rank, shard in [(0, tensor[:, :4]), (1, tensor[:, 4:])]:
            _create_rank_dump(bad_dir, rank=rank, name="hidden", tensor=shard)
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
        )
        records, _ = _run_and_parse(argv, capsys)
        skips = [r for r in records if isinstance(r, ComparisonSkipRecord)]
        assert len(skips) == 1
        assert skips[0].reason == expected_reason
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.skipped == 1

    def test_summary_counts_unshard(self, tmp_path, capsys):
        """Two TP-sharded tensors: summary counts total=2, passed=2, skipped=0."""
        torch.manual_seed(42)
        full_a = torch.randn(4, 8)
        full_b = torch.randn(4, 8)
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        for tensor_name, tensor in [("t_a", full_a), ("t_b", full_b)]:
            baseline_path = _create_tp_sharded_dumps(
                baseline_dir,
                full_tensor=tensor,
                name=tensor_name,
                tp_size=2,
                shard_dim=1,
                dims_str="b h[tp]",
            )
            target_tensor = tensor + torch.randn_like(tensor) * 0.0001
            target_path = _create_tp_sharded_dumps(
                target_dir,
                full_tensor=target_tensor,
                name=tensor_name,
                tp_size=2,
                shard_dim=1,
                dims_str="b h[tp]",
            )
        # Both tensors share the same exp dir, so the last paths cover both.
        argv = _make_argv(baseline_path, target_path, diff_threshold=0.01)
        records, _ = _run_and_parse(argv, capsys)
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.total == 2
        assert summary.passed == 2
        assert summary.failed == 0
        assert summary.skipped == 0

    def test_multi_step_tp(self, tmp_path, capsys):
        """Two steps with TP=2 shards: concat mode merges into one comparison."""
        torch.manual_seed(42)
        full_tensor = torch.randn(4, 8)
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        baseline_path = _create_tp_sharded_dumps(
            baseline_dir,
            full_tensor=full_tensor,
            name="hidden",
            tp_size=2,
            shard_dim=1,
            dims_str="b h[tp]",
            num_steps=2,
        )
        target_path = _create_tp_sharded_dumps(
            target_dir,
            full_tensor=full_tensor + torch.randn(4, 8) * 0.0001,
            name="hidden",
            tp_size=2,
            shard_dim=1,
            dims_str="b h[tp]",
            num_steps=2,
        )
        argv = _make_argv(
            baseline_path,
            target_path,
            diff_threshold=0.01,
            preset="sglang_megatron",
        )
        records, _ = _run_and_parse(argv, capsys)
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 1
        # concat along dim 0 (fallback, no token dim) → 2 steps × [4, 8] = [8, 8]
        assert comparisons[0].baseline.shape == [8, 8]
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.total == 1
        assert summary.passed == 1

    def test_cp_axis_unshard(self, tmp_path, capsys):
        """CP-sharded tensors are correctly concatenated along the sequence dim."""
        torch.manual_seed(42)
        full_baseline = torch.randn(4, 8, 6)
        full_target = full_baseline + torch.randn(4, 8, 6) * 0.001
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        for side_dir, full_tensor in [
            (baseline_dir, full_baseline),
            (target_dir, full_target),
        ]:
            shards = list(full_tensor.chunk(2, dim=1))
            for cp_rank in range(2):
                _create_rank_dump(
                    side_dir,
                    rank=cp_rank,
                    name="attn_out",
                    tensor=shards[cp_rank],
                    dims="b s[cp] h",
                    parallel_info={"cp_rank": cp_rank, "cp_size": 2},
                )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            diff_threshold=0.01,
        )
        records, _ = _run_and_parse(argv, capsys)
        comp = _assert_single_comparison_passed(records)
        assert comp.name == "attn_out"

    def test_filter_logical(self, tmp_path, capsys):
        """--filter in logical grouping selects only matching tensor bundles."""
        torch.manual_seed(42)
        full_a = torch.randn(4, 8)
        full_b = torch.randn(4, 8)
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        for tensor_name, tensor in [("t_a", full_a), ("t_b", full_b)]:
            _create_tp_sharded_dumps(
                baseline_dir,
                full_tensor=tensor,
                name=tensor_name,
                tp_size=2,
                shard_dim=1,
                dims_str="b h[tp]",
            )
            _create_tp_sharded_dumps(
                target_dir,
                full_tensor=tensor + torch.randn_like(tensor) * 0.0001,
                name=tensor_name,
                tp_size=2,
                shard_dim=1,
                dims_str="b h[tp]",
            )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            filter="t_a",
            diff_threshold=0.01,
        )
        records, _ = _run_and_parse(argv, capsys)
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 1
        assert comparisons[0].name == "t_a"

    def test_mixed_dims_logical(self, tmp_path, capsys):
        """TP-sharded and single-rank tensors in the same logical run both compare successfully."""
        torch.manual_seed(42)
        full_tp_tensor = torch.randn(4, 8)
        single_tensor = torch.randn(4, 4)
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        _create_tp_sharded_dumps(
            baseline_dir,
            full_tensor=full_tp_tensor,
            name="tensor_a",
            tp_size=2,
            shard_dim=1,
            dims_str="b h[tp]",
        )
        _create_tp_sharded_dumps(
            target_dir,
            full_tensor=full_tp_tensor + torch.randn(4, 8) * 0.0001,
            name="tensor_a",
            tp_size=2,
            shard_dim=1,
            dims_str="b h[tp]",
        )
        _create_rank_dump(baseline_dir, rank=0, name="tensor_b", tensor=single_tensor)
        _create_rank_dump(
            target_dir,
            rank=0,
            name="tensor_b",
            tensor=single_tensor + torch.randn(4, 4) * 0.0001,
        )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            diff_threshold=0.01,
        )
        records, _ = _run_and_parse(argv, capsys)
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 2
        assert all(c.diff is not None and c.diff.passed for c in comparisons)
        assert {c.name for c in comparisons} == {"tensor_a", "tensor_b"}
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.total == 2
        assert summary.passed == 2

    def test_cp_tp_unshard(self, tmp_path, capsys):
        """CP=2 + TP=2: multi-axis shards are unsharded before comparison."""
        torch.manual_seed(42)
        full_baseline = torch.randn(4, 8, 16)
        full_target = full_baseline + torch.randn(4, 8, 16) * 0.001
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        for side_dir, full_tensor in [
            (baseline_dir, full_baseline),
            (target_dir, full_target),
        ]:
            _create_cp_tp_sharded_dumps(
                side_dir,
                full_tensor=full_tensor,
                name="hidden",
                cp_size=2,
                tp_size=2,
                seq_dim=1,
                head_dim=2,
                dims_str="b s[cp] h[tp]",
            )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            diff_threshold=0.01,
        )
        records, _ = _run_and_parse(argv, capsys)
        comp = _assert_single_comparison_passed(records)
        assert comp.name == "hidden"

    def test_cp_tp_different_sizes(self, tmp_path, capsys):
        """Baseline CP=2+TP=2 vs target CP=1+TP=4: both sides independently unsharded."""
        torch.manual_seed(42)
        full_baseline = torch.randn(4, 8, 16)
        full_target = full_baseline + torch.randn(4, 8, 16) * 0.001
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        _create_cp_tp_sharded_dumps(
            baseline_dir,
            full_tensor=full_baseline,
            name="hidden",
            cp_size=2,
            tp_size=2,
            seq_dim=1,
            head_dim=2,
            dims_str="b s[cp] h[tp]",
        )
        _create_tp_sharded_dumps(
            target_dir,
            full_tensor=full_target,
            name="hidden",
            tp_size=4,
            shard_dim=2,
            dims_str="b s h[tp]",
        )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            diff_threshold=0.01,
        )
        records, _ = _run_and_parse(argv, capsys)
        _assert_single_comparison_passed(records)

    def test_ep_cp_tp_three_axis_unshard(self, tmp_path, capsys):
        """EP=2 + CP=2 + TP=2: three-axis shards are unsharded before comparison."""
        torch.manual_seed(42)
        full_baseline = torch.randn(4, 8, 16, 32)
        full_target = full_baseline + torch.randn(4, 8, 16, 32) * 0.001
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        for side_dir, full_tensor in [
            (baseline_dir, full_baseline),
            (target_dir, full_target),
        ]:
            _create_ep_cp_tp_sharded_dumps(
                side_dir,
                full_tensor=full_tensor,
                name="hidden",
                ep_size=2,
                cp_size=2,
                tp_size=2,
                expert_dim=1,
                seq_dim=2,
                head_dim=3,
                dims_str="b e[ep] s[cp] h[tp]",
            )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            diff_threshold=0.01,
        )
        records, _ = _run_and_parse(argv, capsys)
        comp = _assert_single_comparison_passed(records)
        assert comp.name == "hidden"

    def test_cp_zigzag_unshard(self, tmp_path, capsys):
        """CP=2 zigzag reorder is correctly undone through the full pipeline."""
        torch.manual_seed(42)
        full_baseline = torch.randn(4, 8, 6)
        full_target = full_baseline + torch.randn(4, 8, 6) * 0.001
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        for side_dir, full_tensor in [
            (baseline_dir, full_baseline),
            (target_dir, full_target),
        ]:
            _create_cp_zigzag_tp_sharded_dumps(
                side_dir,
                full_tensor=full_tensor,
                name="attn_out",
                cp_size=2,
                tp_size=1,
                seq_dim=1,
                head_dim=2,
                dims_str="b s[cp:zigzag] h",
            )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            diff_threshold=0.01,
        )
        records, _ = _run_and_parse(argv, capsys)
        comp = _assert_single_comparison_passed(records)
        assert comp.name == "attn_out"

    def test_cp_zigzag_tp_unshard(self, tmp_path, capsys):
        """CP=2 zigzag + TP=2: multi-axis unshard with reorder through full pipeline."""
        torch.manual_seed(42)
        full_baseline = torch.randn(4, 8, 16)
        full_target = full_baseline + torch.randn(4, 8, 16) * 0.001
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        for side_dir, full_tensor in [
            (baseline_dir, full_baseline),
            (target_dir, full_target),
        ]:
            _create_cp_zigzag_tp_sharded_dumps(
                side_dir,
                full_tensor=full_tensor,
                name="hidden",
                cp_size=2,
                tp_size=2,
                seq_dim=1,
                head_dim=2,
                dims_str="b s[cp:zigzag] h[tp]",
            )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            diff_threshold=0.01,
        )
        records, _ = _run_and_parse(argv, capsys)
        comp = _assert_single_comparison_passed(records)
        assert comp.name == "hidden"

    def test_recompute_pseudo_replicated_verification(self, tmp_path, capsys):
        """Recompute pseudo-axis with identical original/recompute tensors → passed."""
        torch.manual_seed(42)
        tensor = torch.randn(4, 8)
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        for side_dir in [baseline_dir, target_dir]:
            _create_recompute_rank_dump(
                side_dir,
                rank=0,
                name="hidden",
                original_tensor=tensor,
                recompute_tensor=tensor.clone(),
            )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            grouping_skip_keys=["rank", "recompute_status"],
            diff_threshold=0.01,
        )
        records, _ = _run_and_parse(argv, capsys)
        comp = _assert_single_comparison_passed(records)
        assert comp.name == "hidden"

    def test_recompute_pseudo_mismatch_warning(self, tmp_path, capsys):
        """Recompute pseudo-axis with differing original/recompute → failed replicated_checks."""
        torch.manual_seed(42)
        tensor = torch.randn(4, 8)
        mismatched_tensor = tensor + torch.randn(4, 8) * 10.0
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        for side_dir in [baseline_dir, target_dir]:
            _create_recompute_rank_dump(
                side_dir,
                rank=0,
                name="hidden",
                original_tensor=tensor,
                recompute_tensor=mismatched_tensor,
            )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            grouping_skip_keys=["rank", "recompute_status"],
            diff_threshold=0.01,
        )
        records, _ = _run_and_parse(argv, capsys)
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 1
        recompute_checks: list[ReplicatedCheckResult] = [
            c
            for c in comparisons[0].replicated_checks
            if c.axis == "recompute_pseudo"
        ]
        assert len(recompute_checks) > 0
        assert any(not c.passed for c in recompute_checks)

    def test_tp_partial_reduction_unshard(self, tmp_path, capsys):
        """TP=2 with partial reduction: element-wise sum reconstructs full tensor."""
        torch.manual_seed(42)
        full_baseline = torch.randn(4, 8)
        full_target = full_baseline + torch.randn(4, 8) * 0.001
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        baseline_path = _create_tp_partial_dumps(
            baseline_dir,
            full_tensor=full_baseline,
            name="attn_out",
            tp_size=2,
            dims_str="b h[tp:partial]",
        )
        target_path = _create_tp_partial_dumps(
            target_dir,
            full_tensor=full_target,
            name="attn_out",
            tp_size=2,
            dims_str="b h[tp:partial]",
        )
        argv = _make_argv(baseline_path, target_path, diff_threshold=0.01)
        records, _ = _run_and_parse(argv, capsys)
        comp = _assert_single_comparison_passed(records)
        assert comp.name == "attn_out"
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.total == 1
        assert summary.passed == 1

    def test_tp_partial_vs_single_rank(self, tmp_path, capsys):
        """Baseline single rank vs target TP=2 partial: unshard target then compare."""
        torch.manual_seed(42)
        full_tensor = torch.randn(4, 8)
        target_full = full_tensor + torch.randn(4, 8) * 0.001
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        baseline_path = _create_rank_dump(
            baseline_dir, rank=0, name="attn_out", tensor=full_tensor
        )
        target_path = _create_tp_partial_dumps(
            target_dir,
            full_tensor=target_full,
            name="attn_out",
            tp_size=2,
            dims_str="b h[tp:partial]",
        )
        argv = _make_argv(baseline_path, target_path, diff_threshold=0.01)
        records, _ = _run_and_parse(argv, capsys)
        comp = _assert_single_comparison_passed(records)
        assert comp.name == "attn_out"

    def test_cp_concat_tp_partial_reduction(self, tmp_path, capsys):
        """CP=2 concat + TP=2 partial reduction: multi-axis unshard."""
        torch.manual_seed(42)
        full_baseline = torch.randn(4, 8, 16)
        full_target = full_baseline + torch.randn(4, 8, 16) * 0.001
        for side_dir, full_tensor in [
            (tmp_path / "baseline", full_baseline),
            (tmp_path / "target", full_target),
        ]:
            side_dir.mkdir()
            cp_chunks = list(full_tensor.chunk(2, dim=1))
            rank = 0
            # Each TP rank holds half of the value (partial sum); each CP rank
            # holds a sequence chunk. 2×2 = 4 distinct global ranks.
            for cp_rank in range(2):
                for tp_rank in range(2):
                    _create_rank_dump(
                        side_dir,
                        rank=rank,
                        name="hidden",
                        tensor=cp_chunks[cp_rank] / 2,
                        dims="b s[cp] h[tp:partial]",
                        parallel_info={
                            "cp_rank": cp_rank,
                            "cp_size": 2,
                            "tp_rank": tp_rank,
                            "tp_size": 2,
                        },
                    )
                    rank += 1
        argv = _make_argv(
            tmp_path / "baseline" / _FIXED_EXP_NAME,
            tmp_path / "target" / _FIXED_EXP_NAME,
            diff_threshold=0.01,
        )
        records, _ = _run_and_parse(argv, capsys)
        comp = _assert_single_comparison_passed(records)
        assert comp.name == "hidden"

    def test_cp_zigzag_sp_same_dim_unshard(self, tmp_path, capsys):
        """CP=2 zigzag + SP=2 on same seq dim: multi-axis unshard + reorder."""
        torch.manual_seed(42)
        full_baseline = torch.randn(4, 8, 6)
        full_target = full_baseline + torch.randn(4, 8, 6) * 0.001
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        for side_dir, full_tensor in [
            (baseline_dir, full_baseline),
            (target_dir, full_target),
        ]:
            _create_cp_zigzag_sp_sharded_dumps(
                side_dir,
                full_tensor=full_tensor,
                name="hidden",
                cp_size=2,
                sp_size=2,
                dims_str="b s[cp:zigzag,sp] h",
            )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            diff_threshold=0.01,
        )
        records, _ = _run_and_parse(argv, capsys)
        comp = _assert_single_comparison_passed(records)
        assert comp.name == "hidden"


class TestEntrypointPerStepMode:
    """Test per-step comparison mode (sglang_dev preset behavior)."""

    def test_multi_step_per_step_comparison(self, tmp_path, capsys):
        """Multiple steps produce one ComparisonTensorRecord per step with step field set."""
        torch.manual_seed(42)
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a"], num_steps=3)
        argv = _make_argv(baseline_path, target_path, diff_threshold=0.1)
        records, _ = _run_and_parse(argv, capsys)
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 3
        steps: list[int] = sorted(c.location.step for c in comparisons)
        assert steps == [0, 1, 2]
        assert all(c.diff is not None and c.diff.passed for c in comparisons)
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.total == 3
        assert summary.passed == 3

    def test_per_step_with_tp_unshard(self, tmp_path, capsys):
        """Per-step mode with TP=2: each step independently unsharded and compared."""
        torch.manual_seed(42)
        full_tensor = torch.randn(4, 8)
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        baseline_path = _create_tp_sharded_dumps(
            baseline_dir,
            full_tensor=full_tensor,
            name="hidden",
            tp_size=2,
            shard_dim=1,
            dims_str="b h[tp]",
            num_steps=2,
        )
        target_path = _create_tp_sharded_dumps(
            target_dir,
            full_tensor=full_tensor + torch.randn(4, 8) * 0.0001,
            name="hidden",
            tp_size=2,
            shard_dim=1,
            dims_str="b h[tp]",
            num_steps=2,
        )
        argv = _make_argv(baseline_path, target_path, diff_threshold=0.01)
        records, _ = _run_and_parse(argv, capsys)
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 2
        steps: list[int] = sorted(c.location.step for c in comparisons)
        assert steps == [0, 1]
        assert all(c.diff is not None and c.diff.passed for c in comparisons)
        # Per-step mode: shapes stay per-step ([4, 8]), no cross-step concat.
        assert all(c.baseline.shape == [4, 8] for c in comparisons)

    def test_single_step_has_step_field(self, tmp_path, capsys):
        """Single step produces ComparisonTensorRecord with location.step=0."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a"], num_steps=1)
        argv = _make_argv(baseline_path, target_path)
        records, _ = _run_and_parse(argv, capsys)
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 1
        assert comparisons[0].location.step == 0


class TestEntrypointConcatMode:
    """Test concat token-aligner mode through the full entrypoint pipeline."""

    @staticmethod
    def _make_dirs(tmp_path: Path) -> tuple[Path, Path]:
        # Create fresh baseline/target roots under the test's tmp dir.
        baseline_dir: Path = tmp_path / "baseline"
        target_dir: Path = tmp_path / "target"
        baseline_dir.mkdir()
        target_dir.mkdir()
        return baseline_dir, target_dir

    @staticmethod
    def _create_both_sides(
        tmp_path: Path,
        *,
        baseline_steps: list[torch.Tensor],
        target_steps: list[torch.Tensor],
        name: str = "hidden",
        dims: str | None = None,
    ) -> tuple[Path, Path]:
        """Create multi-step rank-0 dumps for both sides and return exp paths."""
        baseline_dir, target_dir = TestEntrypointConcatMode._make_dirs(tmp_path)
        for side_dir, steps in [
            (baseline_dir, baseline_steps),
            (target_dir, target_steps),
        ]:
            _create_multi_step_rank_dump(
                side_dir,
                rank=0,
                name=name,
                tensors_per_step=steps,
                dims=dims,
            )
        return baseline_dir / _FIXED_EXP_NAME, target_dir / _FIXED_EXP_NAME

    @staticmethod
    def _run_concat(
        tmp_path: Path,
        capsys: pytest.CaptureFixture,
        *,
        baseline_steps: list[torch.Tensor],
        target_steps: list[torch.Tensor],
        name: str = "hidden",
        dims: str | None = None,
        diff_threshold: float = 0.01,
    ) -> list[AnyRecord]:
        """Create both-side dumps, run comparator, return parsed records."""
        baseline_path, target_path = TestEntrypointConcatMode._create_both_sides(
            tmp_path,
            baseline_steps=baseline_steps,
            target_steps=target_steps,
            name=name,
            dims=dims,
        )
        argv: list[str] = _make_argv(
            baseline_path,
            target_path,
            diff_threshold=diff_threshold,
            preset="sglang_megatron",
        )
        records, _ = _run_and_parse(argv, capsys)
        return records

    def test_concat_multi_step_different_data(self, tmp_path, capsys):
        """Multi-step concat with different data per step + truncation."""
        torch.manual_seed(42)
        # baseline: 2 steps [5,4] + [3,4] → concat → [8,4]
        baseline_step0 = torch.randn(5, 4)
        baseline_step1 = torch.randn(3, 4)
        baseline_concat = torch.cat([baseline_step0, baseline_step1], dim=0)
        # target: 1 step [6,4] — will be truncated to min(8,6)=6
        target_step0 = baseline_concat[:6] + torch.randn(6, 4) * 0.0001
        records = self._run_concat(
            tmp_path,
            capsys,
            baseline_steps=[baseline_step0, baseline_step1],
            target_steps=[target_step0],
        )
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 1
        # truncated to min(8,6) = 6 along concat dim
        assert comparisons[0].baseline.shape == [6, 4]
        assert comparisons[0].target.shape == [6, 4]

    def test_concat_multi_step_tp_unshard(self, tmp_path, capsys):
        """Multi-step different data + TP=2 unshard + concat."""
        torch.manual_seed(42)
        baseline_dir = tmp_path / "baseline"
        target_dir = tmp_path / "target"
        # 2 steps: [4,8] each → concat → [8,8]
        full_step0 = torch.randn(4, 8)
        full_step1 = torch.randn(4, 8)
        _create_multi_step_tp_sharded_dumps(
            baseline_dir,
            full_tensors_per_step=[full_step0, full_step1],
            name="hidden",
            tp_size=2,
            shard_dim=1,
            dims_str="b h[tp]",
        )
        _create_multi_step_tp_sharded_dumps(
            target_dir,
            full_tensors_per_step=[
                full_step0 + torch.randn(4, 8) * 0.0001,
                full_step1 + torch.randn(4, 8) * 0.0001,
            ],
            name="hidden",
            tp_size=2,
            shard_dim=1,
            dims_str="b h[tp]",
        )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            diff_threshold=0.01,
            preset="sglang_megatron",
        )
        records, _ = _run_and_parse(argv, capsys)
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 1
        # 2 steps × [4, 8] concat along dim 0 (fallback) → [8, 8]
        assert comparisons[0].baseline.shape == [8, 8]
        assert comparisons[0].diff is not None
        assert comparisons[0].diff.passed

    def test_concat_unequal_step_counts(self, tmp_path, capsys):
        """Baseline 3 steps vs target 2 steps with truncation."""
        torch.manual_seed(42)
        # baseline: 3 steps [3]+[4]+[2] = 9 tokens along dim 0
        b_step0 = torch.randn(3, 4)
        b_step1 = torch.randn(4, 4)
        b_step2 = torch.randn(2, 4)
        b_concat = torch.cat([b_step0, b_step1, b_step2], dim=0)
        # target: 2 steps [5]+[3] = 8 tokens along dim 0
        t_step0 = b_concat[:5] + torch.randn(5, 4) * 0.0001
        t_step1 = b_concat[5:8] + torch.randn(3, 4) * 0.0001
        records = self._run_concat(
            tmp_path,
            capsys,
            baseline_steps=[b_step0, b_step1, b_step2],
            target_steps=[t_step0, t_step1],
        )
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 1
        # truncated to min(9,8) = 8
        assert comparisons[0].baseline.shape == [8, 4]
        assert comparisons[0].target.shape == [8, 4]
        assert comparisons[0].diff is not None
        assert comparisons[0].diff.passed

    def test_concat_token_dim_nonzero(self, tmp_path, capsys):
        """Token dim at dim=1 (dims='b t h') — concat along dim 1."""
        torch.manual_seed(42)
        # 2 steps: [2,5,4] + [2,3,4] → concat along dim 1 → [2,8,4]
        b_step0 = torch.randn(2, 5, 4)
        b_step1 = torch.randn(2, 3, 4)
        b_concat = torch.cat([b_step0, b_step1], dim=1)
        t_step0 = b_concat[:, :5, :] + torch.randn(2, 5, 4) * 0.0001
        t_step1 = b_concat[:, 5:, :] + torch.randn(2, 3, 4) * 0.0001
        records = self._run_concat(
            tmp_path,
            capsys,
            baseline_steps=[b_step0, b_step1],
            target_steps=[t_step0, t_step1],
            dims="b t h",
        )
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 1
        assert comparisons[0].baseline.shape == [2, 8, 4]
        assert comparisons[0].diff is not None
        assert comparisons[0].diff.passed

    def test_concat_seq_dim_fallback(self, tmp_path, capsys):
        """No 't' dim but 's' dim present (dims='b s h') → concat along s."""
        torch.manual_seed(42)
        # 2 steps: [2,5,4] + [2,3,4] → concat along dim 1 (s) → [2,8,4]
        b_step0 = torch.randn(2, 5, 4)
        b_step1 = torch.randn(2, 3, 4)
        b_concat = torch.cat([b_step0, b_step1], dim=1)
        t_step0 = b_concat[:, :5, :] + torch.randn(2, 5, 4) * 0.0001
        t_step1 = b_concat[:, 5:, :] + torch.randn(2, 3, 4) * 0.0001
        records = self._run_concat(
            tmp_path,
            capsys,
            baseline_steps=[b_step0, b_step1],
            target_steps=[t_step0, t_step1],
            dims="b s h",
        )
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 1
        assert comparisons[0].baseline.shape == [2, 8, 4]
        assert comparisons[0].diff is not None
        assert comparisons[0].diff.passed

    def test_concat_no_dims_fallback(self, tmp_path, capsys):
        """No dims annotation → fallback to concat along dim 0."""
        torch.manual_seed(42)
        # 2 steps: [5,4] + [3,4] → concat along dim 0 → [8,4]
        b_step0 = torch.randn(5, 4)
        b_step1 = torch.randn(3, 4)
        b_concat = torch.cat([b_step0, b_step1], dim=0)
        t_step0 = b_concat[:5] + torch.randn(5, 4) * 0.0001
        t_step1 = b_concat[5:] + torch.randn(3, 4) * 0.0001
        records = self._run_concat(
            tmp_path,
            capsys,
            baseline_steps=[b_step0, b_step1],
            target_steps=[t_step0, t_step1],
        )
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 1
        assert comparisons[0].baseline.shape == [8, 4]
        assert comparisons[0].diff is not None
        assert comparisons[0].diff.passed

    # NOTE(review): the following definition continues beyond this chunk of the
    # file; its name and body are not visible here and are preserved as-is.
    def
test_concat_preserves_step_order(self, tmp_path, capsys): """Verify step0 data precedes step1 data in the concatenated result.""" # deterministic integer data: step0=[1,2,3], step1=[4,5] b_step0 = torch.tensor([[1.0], [2.0], [3.0]]) b_step1 = torch.tensor([[4.0], [5.0]]) # target: same data, single step [1,2,3,4,5] t_full = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]]) records = self._run_concat( tmp_path, capsys, baseline_steps=[b_step0, b_step1], target_steps=[t_full], ) comp = _assert_single_comparison_passed(records) # if order were wrong, diff would not pass with exact integer data assert comp.baseline.shape == [5, 1] assert comp.diff is not None assert comp.diff.max_abs_diff == 0.0 def test_concat_aux_tensors_not_filtered(self, tmp_path, capsys): """Concat mode does not filter aux tensors — all participate in comparison.""" torch.manual_seed(42) baseline_dir, target_dir = self._make_dirs(tmp_path) hidden = torch.randn(4, 8) input_ids = torch.randint(0, 100, (4,)) positions = torch.arange(4) _create_rank_dump( baseline_dir, rank=0, name="hidden_states", tensor=hidden, extra_dumps=[("input_ids", input_ids), ("positions", positions)], ) _create_rank_dump( target_dir, rank=0, name="hidden_states", tensor=hidden + torch.randn(4, 8) * 0.0001, extra_dumps=[("input_ids", input_ids), ("positions", positions)], ) argv = _make_argv( baseline_dir / _FIXED_EXP_NAME, target_dir / _FIXED_EXP_NAME, diff_threshold=0.01, ) records, _ = _run_and_parse(argv, capsys) comparisons = _get_comparisons(records) # all 3 tensors should be compared (not filtered out) names = {c.name for c in comparisons} assert "hidden_states" in names assert "input_ids" in names assert "positions" in names assert len(comparisons) == 3 def test_concat_aligner_plan_fields(self, tmp_path, capsys): """ComparisonTensorRecord.traced_plan reports mode='concat' with plan=None.""" torch.manual_seed(42) records = self._run_concat( tmp_path, capsys, baseline_steps=[torch.randn(3, 4), torch.randn(2, 4)], 
target_steps=[torch.randn(3, 4), torch.randn(2, 4)], diff_threshold=100.0, ) comparisons = _get_comparisons(records) assert len(comparisons) == 1 traced_plan = comparisons[0].traced_plan assert traced_plan is not None plan = traced_plan.plan assert plan.token_aligner_mode == "concat_steps" assert plan.token_aligner_plan is None def test_concat_comparison_fails(self, tmp_path, capsys): """Completely different data → comparison fails.""" torch.manual_seed(42) b_step0 = torch.randn(4, 4) b_step1 = torch.randn(3, 4) # target: completely different random data torch.manual_seed(99) t_step0 = torch.randn(4, 4) * 100 t_step1 = torch.randn(3, 4) * 100 records = self._run_concat( tmp_path, capsys, baseline_steps=[b_step0, b_step1], target_steps=[t_step0, t_step1], diff_threshold=1e-6, ) comparisons = _get_comparisons(records) assert len(comparisons) == 1 assert comparisons[0].diff is not None assert not comparisons[0].diff.passed summary = records[-1] assert isinstance(summary, SummaryRecord) assert summary.failed == 1 assert summary.passed == 0 def test_concat_multi_step_cp_unshard(self, tmp_path, capsys): """Multi-step different data + CP=2 unshard along seq dim + concat.""" torch.manual_seed(42) baseline_dir = tmp_path / "baseline" target_dir = tmp_path / "target" # 2 steps: [4,8,6] each → concat along seq dim (dim 1) → [4,16,6] full_step0 = torch.randn(4, 8, 6) full_step1 = torch.randn(4, 8, 6) for side_dir, steps in [ (baseline_dir, [full_step0, full_step1]), ( target_dir, [ full_step0 + torch.randn(4, 8, 6) * 0.0001, full_step1 + torch.randn(4, 8, 6) * 0.0001, ], ), ]: for cp_rank in range(2): per_step_shards: list[torch.Tensor] = [ t.chunk(2, dim=1)[cp_rank] for t in steps ] _create_multi_step_rank_dump( side_dir, rank=cp_rank, name="attn_out", tensors_per_step=per_step_shards, dims="b s[cp] h", parallel_info={"cp_rank": cp_rank, "cp_size": 2}, ) argv = _make_argv( baseline_dir / _FIXED_EXP_NAME, target_dir / _FIXED_EXP_NAME, diff_threshold=0.01, 
preset="sglang_megatron", ) records, _ = _run_and_parse(argv, capsys) comparisons = _get_comparisons(records) assert len(comparisons) == 1 # CP unshard: [4,4,6] × 2 ranks → [4,8,6] per step # concat along seq dim (dim 1): 2 steps × [4,8,6] → [4,16,6] assert comparisons[0].baseline.shape == [4, 16, 6] assert comparisons[0].diff is not None assert comparisons[0].diff.passed def test_concat_thd_cp_zigzag(self, tmp_path: Path, capsys) -> None: """Concat mode with THD CP=2 zigzag (Megatron format) — unshard + reorder works.""" torch.manual_seed(42) cp_size: int = 2 seq_lens: list[int] = [100, 64] total_tokens: int = sum(seq_lens) total_per_rank: int = 128 num_steps: int = 2 full_tensor: torch.Tensor = torch.randn(total_tokens + 92) baseline_dir: Path = tmp_path / "baseline" target_dir: Path = tmp_path / "target" baseline_dir.mkdir() target_dir.mkdir() baseline_path: Path = _create_thd_cp_zigzag_dumps( baseline_dir, full_tensor=full_tensor, name="hidden_states", seq_lens=seq_lens, cp_size=cp_size, total_per_rank=total_per_rank, num_steps=num_steps, ) target_tensor: torch.Tensor = full_tensor + torch.randn_like(full_tensor) * 1e-5 target_path: Path = _create_thd_cp_zigzag_dumps( target_dir, full_tensor=target_tensor, name="hidden_states", seq_lens=seq_lens, cp_size=cp_size, total_per_rank=total_per_rank, num_steps=num_steps, ) argv: list[str] = _make_argv( baseline_path, target_path, preset="sglang_megatron", diff_threshold=1e-3, ) records, _ = _run_and_parse(argv, capsys) comparisons: list[ComparisonTensorRecord] = _get_comparisons(records) hidden_comparisons: list[ComparisonTensorRecord] = [ c for c in comparisons if c.name == "hidden_states" ] assert len(hidden_comparisons) >= 1 assert all(c.diff is not None and c.diff.passed for c in hidden_comparisons) class TestEntrypointAxisAligner: """Test cross-framework dim reordering through the full entrypoint pipeline.""" def test_axis_swap_different_dim_order(self, tmp_path, capsys): """Baseline dims 'b h d' vs target dims 
'b d h': axis swapper rearranges baseline to match.""" torch.manual_seed(42) full_tensor = torch.randn(4, 8, 16) baseline_dir = tmp_path / "baseline" target_dir = tmp_path / "target" _create_rank_dump( baseline_dir, rank=0, name="hidden", tensor=full_tensor, dims="b h d", ) _create_rank_dump( target_dir, rank=0, name="hidden", tensor=full_tensor.permute(0, 2, 1).contiguous(), dims="b d h", ) argv = _make_argv( baseline_dir / _FIXED_EXP_NAME, target_dir / _FIXED_EXP_NAME, diff_threshold=1e-3, ) records, _ = _run_and_parse(argv, capsys) comp = _assert_single_comparison_passed(records) assert comp.name == "hidden" assert comp.baseline.shape == [4, 16, 8] assert comp.target.shape == [4, 16, 8] def test_axis_swap_with_tp_unshard(self, tmp_path, capsys): """Baseline TP=2 with dims 'b h[tp] d' vs target TP=2 with dims 'b d h[tp]': unshard + axis swap.""" torch.manual_seed(42) full_tensor = torch.randn(4, 8, 16) baseline_dir = tmp_path / "baseline" target_dir = tmp_path / "target" _create_tp_sharded_dumps( baseline_dir, full_tensor=full_tensor, name="hidden", tp_size=2, shard_dim=1, dims_str="b h[tp] d", ) _create_tp_sharded_dumps( target_dir, full_tensor=full_tensor.permute(0, 2, 1).contiguous(), name="hidden", tp_size=2, shard_dim=2, dims_str="b d h[tp]", ) argv = _make_argv( baseline_dir / _FIXED_EXP_NAME, target_dir / _FIXED_EXP_NAME, diff_threshold=1e-3, ) records, _ = _run_and_parse(argv, capsys) comp = _assert_single_comparison_passed(records) assert comp.name == "hidden" def test_squeeze_dim_one_side(self, tmp_path, capsys): """SGLang dims 't h' vs Megatron dims 't 1 h': axis aligner squeezes the singleton dim.""" torch.manual_seed(42) full_tensor = torch.randn(4, 8) baseline_dir = tmp_path / "baseline" target_dir = tmp_path / "target" _create_rank_dump( baseline_dir, rank=0, name="hidden", tensor=full_tensor, dims="t h", ) _create_rank_dump( target_dir, rank=0, name="hidden", tensor=full_tensor.unsqueeze(1), dims="t 1 h", ) argv = _make_argv( baseline_dir / 
_FIXED_EXP_NAME, target_dir / _FIXED_EXP_NAME, diff_threshold=1e-3, ) records, _ = _run_and_parse(argv, capsys) comp = _assert_single_comparison_passed(records) assert comp.name == "hidden" assert comp.baseline.shape == [4, 8] assert comp.target.shape == [4, 8] class TestEntrypointReplicatedAxis: """Test replicated-axis scenarios through the full entrypoint pipeline.""" def test_replicated_axis_identical_replicas_passed(self, tmp_path, capsys): """CP2 TP2, TP replicated and identical → passed, replicated_checks all passed.""" torch.manual_seed(42) full_baseline = torch.randn(4, 8, 6) full_target = full_baseline + torch.randn(4, 8, 6) * 0.0001 baseline_dir = tmp_path / "baseline" target_dir = tmp_path / "target" for side_dir, full_tensor in [ (baseline_dir, full_baseline), (target_dir, full_target), ]: _create_replicated_tp_sharded_cp_dumps( side_dir, full_tensor=full_tensor, name="attn_out", cp_size=2, tp_size=2, seq_dim=1, dims_str="b s[cp] d # tp:replicated", ) argv = _make_argv( baseline_dir / _FIXED_EXP_NAME, target_dir / _FIXED_EXP_NAME, diff_threshold=0.01, ) records, _ = _run_and_parse(argv, capsys) comp = _assert_single_comparison_passed(records) assert comp.errors == [] assert comp.infos == [] assert all(c.passed for c in comp.replicated_checks) summary = records[-1] assert isinstance(summary, SummaryRecord) assert summary.passed == 1 def test_replicated_mismatch_fails(self, tmp_path, capsys): """CP2 TP2, TP replicas differ (> atol) → failed with replicated_checks.""" torch.manual_seed(42) full_baseline = torch.randn(4, 8, 6) full_target = full_baseline + torch.randn(4, 8, 6) * 0.0001 baseline_dir = tmp_path / "baseline" target_dir = tmp_path / "target" for side_dir, full_tensor in [ (baseline_dir, full_baseline), (target_dir, full_target), ]: _create_replicated_tp_sharded_cp_dumps( side_dir, full_tensor=full_tensor, name="attn_out", cp_size=2, tp_size=2, seq_dim=1, dims_str="b s[cp] d # tp:replicated", tp_noise=0.5, ) argv = _make_argv( baseline_dir / 
_FIXED_EXP_NAME, target_dir / _FIXED_EXP_NAME, diff_threshold=0.01, ) records, _ = _run_and_parse(argv, capsys) comparisons = _get_comparisons(records) assert len(comparisons) == 1 assert comparisons[0].category == "failed" assert any(not c.passed for c in comparisons[0].replicated_checks) summary = records[-1] assert isinstance(summary, SummaryRecord) assert summary.failed == 1 def test_summary_counts_failed_from_replicated_checks_only(self, tmp_path, capsys): """Diff itself passes but TP replicas differ → summary.failed=1 from replicated_checks.""" torch.manual_seed(42) full_baseline = torch.randn(4, 8, 6) full_target = full_baseline + torch.randn(4, 8, 6) * 0.0001 baseline_dir = tmp_path / "baseline" target_dir = tmp_path / "target" _create_replicated_tp_sharded_cp_dumps( baseline_dir, full_tensor=full_baseline, name="attn_out", cp_size=2, tp_size=2, seq_dim=1, dims_str="b s[cp] d # tp:replicated", tp_noise=0.5, ) _create_replicated_tp_sharded_cp_dumps( target_dir, full_tensor=full_target, name="attn_out", cp_size=2, tp_size=2, seq_dim=1, dims_str="b s[cp] d # tp:replicated", tp_noise=0.5, ) argv = _make_argv( baseline_dir / _FIXED_EXP_NAME, target_dir / _FIXED_EXP_NAME, diff_threshold=0.5, ) records, _ = _run_and_parse(argv, capsys) comparisons = _get_comparisons(records) assert len(comparisons) == 1 comp = comparisons[0] assert comp.diff is not None assert comp.diff.passed assert any(not c.passed for c in comp.replicated_checks) assert comp.category == "failed" summary = records[-1] assert isinstance(summary, SummaryRecord) assert summary.failed == 1 assert summary.passed == 0 def test_replicated_shape_mismatch(self, tmp_path, capsys): """TP replicated tensors with different shapes → failed, replicated diff=None.""" torch.manual_seed(42) baseline_dir = tmp_path / "baseline" target_dir = tmp_path / "target" for side_dir in [baseline_dir, target_dir]: # rank 0 (cp=0, tp=0): shape (4, 4, 6) _create_rank_dump( side_dir, rank=0, name="attn_out", 
tensor=torch.randn(4, 4, 6), dims="b s[cp] d # tp:replicated", parallel_info={ "cp_rank": 0, "cp_size": 2, "tp_rank": 0, "tp_size": 2, }, ) # rank 1 (cp=0, tp=1): shape (4, 4, 3) — different last dim _create_rank_dump( side_dir, rank=1, name="attn_out", tensor=torch.randn(4, 4, 3), dims="b s[cp] d # tp:replicated", parallel_info={ "cp_rank": 0, "cp_size": 2, "tp_rank": 1, "tp_size": 2, }, ) # rank 2 (cp=1, tp=0): shape (4, 4, 6) _create_rank_dump( side_dir, rank=2, name="attn_out", tensor=torch.randn(4, 4, 6), dims="b s[cp] d # tp:replicated", parallel_info={ "cp_rank": 1, "cp_size": 2, "tp_rank": 0, "tp_size": 2, }, ) # rank 3 (cp=1, tp=1): shape (4, 4, 3) — different last dim _create_rank_dump( side_dir, rank=3, name="attn_out", tensor=torch.randn(4, 4, 3), dims="b s[cp] d # tp:replicated", parallel_info={ "cp_rank": 1, "cp_size": 2, "tp_rank": 1, "tp_size": 2, }, ) argv = _make_argv( baseline_dir / _FIXED_EXP_NAME, target_dir / _FIXED_EXP_NAME, diff_threshold=0.01, ) records, _ = _run_and_parse(argv, capsys) comparisons = _get_comparisons(records) assert len(comparisons) == 1 assert comparisons[0].category == "failed" failed_checks = [c for c in comparisons[0].replicated_checks if not c.passed] assert len(failed_checks) >= 1 assert all(c.diff is None for c in failed_checks) summary = records[-1] assert isinstance(summary, SummaryRecord) assert summary.failed == 1 class TestEntrypointAlignment: """Test smart token alignment with aux tensors.""" def test_sglang_multi_step_alignment(self, tmp_path, capsys): """SGLang multi-step dumps with aux tensors auto-trigger alignment.""" torch.manual_seed(42) hidden_dim = 8 hidden_step0 = torch.randn(5, hidden_dim) hidden_step1 = torch.randn(2, hidden_dim) exp_paths: list[Path] = [] for side_dir in ["baseline", "target"]: d = tmp_path / side_dir d.mkdir() dumper = _Dumper( config=DumperConfig( enable=True, dir=str(d), exp_name=_FIXED_EXP_NAME, ) ) # Step 0: prefill with 2 sequences (3+2 tokens) dumper.dump("input_ids", 
torch.tensor([10, 20, 30, 40, 50])) dumper.dump("positions", torch.tensor([0, 1, 2, 0, 1])) dumper.dump("seq_lens", torch.tensor([3, 2])) dumper.dump("req_pool_indices", torch.tensor([7, 3])) dumper.dump("rids", ["A", "B"]) dumper.dump("hidden_states", hidden_step0) dumper.step() # Step 1: decode (1 token per sequence) dumper.dump("input_ids", torch.tensor([31, 51])) dumper.dump("positions", torch.tensor([3, 2])) dumper.dump("seq_lens", torch.tensor([1, 1])) dumper.dump("req_pool_indices", torch.tensor([7, 3])) dumper.dump("rids", ["A", "B"]) dumper.dump("hidden_states", hidden_step1) dumper.step() exp_paths.append(d / _FIXED_EXP_NAME) argv = _make_argv( exp_paths[0], exp_paths[1], grouping_skip_keys=["rank", "step"], token_aligner="smart", ) records, _ = _run_and_parse(argv, capsys) comparisons = _get_comparisons(records) # AUX_NAMES are filtered out after plan computation → only hidden_states remains assert len(comparisons) == 1 assert comparisons[0].name == "hidden_states" assert comparisons[0].diff is not None assert comparisons[0].diff.passed summary = records[-1] assert isinstance(summary, SummaryRecord) assert summary.passed == 1 assert summary.failed == 0 assert summary.skipped == 0 def test_sglang_vs_megatron_cross_framework(self, tmp_path, capsys): """SGLang 4-step thd baseline vs Megatron 1-step thd target align correctly.""" torch.manual_seed(42) hidden_dim: int = 8 all_hiddens: torch.Tensor = torch.randn(11, hidden_dim) seq_a_hiddens: torch.Tensor = all_hiddens[:6] seq_b_hiddens: torch.Tensor = all_hiddens[6:] # --- SGLang baseline: 1 prefill + 3 decode --- sglang_dir: Path = tmp_path / "baseline" sglang_dir.mkdir() sglang_dumper = _Dumper( config=DumperConfig( enable=True, dir=str(sglang_dir), exp_name=_FIXED_EXP_NAME, ) ) # Step 0: prefill — seq A (3 tokens) + seq B (2 tokens) sglang_dumper.dump("input_ids", torch.tensor([10, 20, 30, 40, 50])) sglang_dumper.dump("positions", torch.tensor([0, 1, 2, 0, 1])) sglang_dumper.dump("seq_lens", 
torch.tensor([3, 2])) sglang_dumper.dump("req_pool_indices", torch.tensor([7, 3])) sglang_dumper.dump("rids", ["A", "B"]) sglang_dumper.dump( "hidden_states", torch.stack( [ seq_a_hiddens[0], seq_a_hiddens[1], seq_a_hiddens[2], seq_b_hiddens[0], seq_b_hiddens[1], ] ), ) sglang_dumper.step() # Steps 1-3: decode — 1 token per sequence decode_data: list[dict[str, object]] = [ { "input_ids": torch.tensor([31, 51]), "positions": torch.tensor([3, 2]), "hidden": torch.stack([seq_a_hiddens[3], seq_b_hiddens[2]]), }, { "input_ids": torch.tensor([32, 52]), "positions": torch.tensor([4, 3]), "hidden": torch.stack([seq_a_hiddens[4], seq_b_hiddens[3]]), }, { "input_ids": torch.tensor([33, 53]), "positions": torch.tensor([5, 4]), "hidden": torch.stack([seq_a_hiddens[5], seq_b_hiddens[4]]), }, ] for step_data in decode_data: sglang_dumper.dump("input_ids", step_data["input_ids"]) sglang_dumper.dump("positions", step_data["positions"]) sglang_dumper.dump("seq_lens", torch.tensor([1, 1])) sglang_dumper.dump("req_pool_indices", torch.tensor([7, 3])) sglang_dumper.dump("rids", ["A", "B"]) sglang_dumper.dump("hidden_states", step_data["hidden"]) sglang_dumper.step() # --- Megatron target: 1 step, thd [T, H] --- megatron_dir: Path = tmp_path / "target" megatron_dir.mkdir() megatron_dumper = _Dumper( config=DumperConfig( enable=True, dir=str(megatron_dir), exp_name=_FIXED_EXP_NAME, ) ) # THD flat: seq A (6 tokens) + seq B (5 tokens) = 11 tokens total megatron_input_ids: torch.Tensor = torch.tensor( [10, 20, 30, 31, 32, 33, 40, 50, 51, 52, 53] ) megatron_cu_seqlens: torch.Tensor = torch.tensor([0, 6, 11]) megatron_hidden: torch.Tensor = torch.cat([seq_a_hiddens, seq_b_hiddens], dim=0) megatron_dumper.dump("input_ids", megatron_input_ids) megatron_dumper.dump("cu_seqlens_q", megatron_cu_seqlens) megatron_dumper.dump("hidden_states", megatron_hidden) megatron_dumper.step() # --- Run comparison --- argv = _make_argv( sglang_dir / _FIXED_EXP_NAME, megatron_dir / _FIXED_EXP_NAME, 
grouping_skip_keys=["rank", "step"], token_aligner="smart", ) records, _ = _run_and_parse(argv, capsys) log_records = [r for r in records if isinstance(r, LogRecord)] layout_infos = [ i for lr in log_records for i in lr.infos if isinstance(i, InfoLog) and i.category == "layout_detection_fallback" ] assert len(layout_infos) == 1 comparisons = _get_comparisons(records) # AUX_NAMES filtered out → only hidden_states remains assert len(comparisons) == 1 assert comparisons[0].name == "hidden_states" assert comparisons[0].diff is not None assert comparisons[0].diff.passed summary = records[-1] assert isinstance(summary, SummaryRecord) assert summary.passed == 1 assert summary.failed == 0 assert summary.skipped == 0 def test_alignment_fallback_when_no_aux(self, tmp_path, capsys): """Without aux tensors, smart alignment falls back to per-step comparison.""" baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a"], num_steps=2) argv = _make_argv( baseline_path, target_path, token_aligner="smart", diff_threshold=0.1, ) capsys.readouterr() run(parse_args(argv)) captured = capsys.readouterr() records = _parse_jsonl(captured.out) log_records = [r for r in records if isinstance(r, LogRecord)] aux_missing_infos = [ i for lr in log_records for i in lr.infos if isinstance(i, InfoLog) and i.category == "aux_tensors_missing" ] assert len(aux_missing_infos) == 1 comparisons = _get_comparisons(records) assert len(comparisons) == 2 summary = records[-1] assert isinstance(summary, SummaryRecord) assert summary.total == 2 assert summary.passed == 2 class TestEntrypointNonTensorValues: """Test non-tensor value comparison through the full entrypoint pipeline.""" def test_non_tensor_float_same_value(self, tmp_path: Path, capsys) -> None: """Two sides dump the same float → ComparisonNonTensorRecord with values_equal=True, category=passed.""" baseline_path, target_path = _create_non_tensor_dumps( tmp_path, name="sm_scale", baseline_value=0.125, target_value=0.125 ) argv = 
_make_argv(baseline_path, target_path, preset="raw") records, _ = _run_and_parse(argv, capsys) non_tensors = _get_non_tensors(records) assert len(non_tensors) == 1 assert non_tensors[0].name == "sm_scale" assert non_tensors[0].values_equal is True assert non_tensors[0].category == "passed" summary = records[-1] assert isinstance(summary, SummaryRecord) assert summary.passed == 1 assert summary.failed == 0 def test_non_tensor_float_different_value(self, tmp_path: Path, capsys) -> None: """Two sides dump different floats → ComparisonNonTensorRecord with values_equal=False, category=failed.""" baseline_path, target_path = _create_non_tensor_dumps( tmp_path, name="sm_scale", baseline_value=0.125, target_value=0.25 ) argv = _make_argv(baseline_path, target_path, preset="raw") records, _ = _run_and_parse(argv, capsys) non_tensors = _get_non_tensors(records) assert len(non_tensors) == 1 assert non_tensors[0].values_equal is False assert non_tensors[0].category == "failed" summary = records[-1] assert isinstance(summary, SummaryRecord) assert summary.failed == 1 def test_non_tensor_string_value(self, tmp_path: Path, capsys) -> None: """String non-tensor values are compared and displayed correctly.""" baseline_path, target_path = _create_non_tensor_dumps( tmp_path, name="attn_backend", baseline_value="flash_attn", target_value="flash_attn", ) argv = _make_argv(baseline_path, target_path, preset="raw") records, _ = _run_and_parse(argv, capsys) non_tensors = _get_non_tensors(records) assert len(non_tensors) == 1 assert non_tensors[0].values_equal is True assert non_tensors[0].baseline_type == "str" assert non_tensors[0].target_type == "str" def test_non_tensor_mixed_with_tensor(self, tmp_path: Path, capsys) -> None: """Tensors and non_tensors in the same dump are each handled correctly.""" torch.manual_seed(42) tensor = torch.randn(4, 4) baseline_dir = tmp_path / "baseline" target_dir = tmp_path / "target" for side_dir in [baseline_dir, target_dir]: 
_create_non_tensor_rank_dump( side_dir, rank=0, name="sm_scale", value=0.125, extra_tensor_dumps=[("hidden", tensor)], ) argv = _make_argv( baseline_dir / _FIXED_EXP_NAME, target_dir / _FIXED_EXP_NAME, preset="raw", ) records, _ = _run_and_parse(argv, capsys) comparisons = _get_comparisons(records) non_tensors = _get_non_tensors(records) assert len(comparisons) == 1 assert comparisons[0].name == "hidden" assert len(non_tensors) == 1 assert non_tensors[0].name == "sm_scale" assert non_tensors[0].values_equal is True summary = records[-1] assert isinstance(summary, SummaryRecord) assert summary.passed == 2 def test_non_tensor_complex_object(self, tmp_path: Path, capsys) -> None: """Complex objects (e.g. dict containing a tensor) are displayed via repr, not skipped.""" value = {"a": 1, "b": "hello", "c": torch.tensor([1.0, 2.0])} baseline_path, target_path = _create_non_tensor_dumps( tmp_path, name="debug_info", baseline_value=value, target_value=value ) argv = _make_argv(baseline_path, target_path, preset="raw") records, _ = _run_and_parse(argv, capsys) non_tensors = _get_non_tensors(records) assert len(non_tensors) == 1 assert non_tensors[0].name == "debug_info" assert non_tensors[0].baseline_type == "dict" assert non_tensors[0].target_type == "dict" def test_non_tensor_none_value(self, tmp_path: Path, capsys) -> None: """Dumping None is displayed as ComparisonNonTensorRecord, not skipped as load failure.""" baseline_path, target_path = _create_non_tensor_dumps( tmp_path, name="optional_param", baseline_value=None, target_value=None ) argv = _make_argv(baseline_path, target_path, preset="raw") records, _ = _run_and_parse(argv, capsys) non_tensors = _get_non_tensors(records) assert len(non_tensors) == 1 assert non_tensors[0].name == "optional_param" assert non_tensors[0].values_equal is True assert non_tensors[0].baseline_value == "None" assert non_tensors[0].baseline_type == "NoneType" assert non_tensors[0].category == "passed" def 
test_non_tensor_json_roundtrip(self, tmp_path: Path, capsys) -> None: """ComparisonNonTensorRecord JSON output can be parsed back correctly.""" baseline_path, target_path = _create_non_tensor_dumps( tmp_path, name="sm_scale", baseline_value=0.125, target_value=0.125 ) argv = _make_argv(baseline_path, target_path, preset="raw") records, _ = _run_and_parse(argv, capsys) non_tensors = _get_non_tensors(records) assert len(non_tensors) == 1 json_str: str = non_tensors[0].model_dump_json() roundtripped = parse_record_json(json_str) assert isinstance(roundtripped, ComparisonNonTensorRecord) assert roundtripped.name == "sm_scale" assert roundtripped.values_equal is True # ───────────────────── Visualization integration tests ───────────────────── class TestEntrypointVisualize: """Test --visualize-bundle-details integration.""" @pytest.fixture(autouse=True) def _skip_if_no_matplotlib(self) -> None: pytest.importorskip("matplotlib") def test_visualize_creates_pngs(self, tmp_path, capsys): """--visualize-bundle-details with --filter produces PNG files.""" baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a", "tensor_b"]) viz_dir = tmp_path / "viz_out" argv = _make_argv( baseline_path, target_path, preset="raw", filter="tensor_a", viz_bundle_details=True, viz_output_dir=str(viz_dir), ) records, _ = _run_and_parse(argv, capsys) assert len(_get_comparisons(records)) == 1 png_files = list(viz_dir.glob("*.png")) assert len(png_files) == 1 assert png_files[0].stat().st_size > 0 def test_no_visualize_no_png(self, tmp_path, capsys): """Without --visualize-bundle-details, no PNGs are created.""" baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a"]) viz_dir = tmp_path / "viz_out" argv = _make_argv( baseline_path, target_path, preset="raw", viz_bundle_details=False, viz_output_dir=str(viz_dir), ) _run_and_parse(argv, capsys) assert not viz_dir.exists() or len(list(viz_dir.glob("*.png"))) == 0 # --------------------------- Assertion helpers ------------------- 
def _get_comparisons(records: list[AnyRecord]) -> list[ComparisonTensorRecord]:
    """Filter the parsed output records down to tensor-comparison records."""
    return [r for r in records if isinstance(r, ComparisonTensorRecord)]


def _get_non_tensors(records: list[AnyRecord]) -> list[ComparisonNonTensorRecord]:
    """Filter the parsed output records down to non-tensor comparison records."""
    return [r for r in records if isinstance(r, ComparisonNonTensorRecord)]


def _assert_single_comparison_passed(
    records: list[AnyRecord],
) -> ComparisonTensorRecord:
    """Assert exactly one tensor comparison exists and passed; return it."""
    comparisons = _get_comparisons(records)
    assert len(comparisons) == 1
    assert comparisons[0].diff is not None
    assert comparisons[0].diff.passed
    return comparisons[0]


# --------------------------- Utils ------------------------------


def _make_dumper(directory: Path) -> _Dumper:
    """Build an enabled `_Dumper` writing into `directory` (auto exp_name)."""
    return _Dumper(config=DumperConfig(enable=True, dir=str(directory)))


def _create_dumps(
    tmp_path: Path,
    tensor_names: list[str],
    *,
    baseline_names: list[str] | None = None,
    num_steps: int = 1,
) -> tuple[Path, Path]:
    """Create baseline and target dump directories with given tensor names.

    If baseline_names is None, uses the same names as tensor_names.
    Each step dumps all names with the same tensor (different per baseline/target).
    """
    if baseline_names is None:
        baseline_names = tensor_names
    d_baseline = tmp_path / "baseline"
    d_target = tmp_path / "target"
    d_baseline.mkdir()
    d_target.mkdir()
    torch.manual_seed(42)
    baseline_tensor = torch.randn(10, 10)
    # Target is the baseline plus small noise so default comparisons pass.
    target_tensor = baseline_tensor + torch.randn(10, 10) * 0.01
    exp_paths: list[Path] = []
    for d, names, tensor in [
        (d_baseline, baseline_names, baseline_tensor),
        (d_target, tensor_names, target_tensor),
    ]:
        dumper = _make_dumper(d)
        for _ in range(num_steps):
            for name in names:
                dumper.dump(name, tensor)
            dumper.step()
        # Dump files land under <dir>/<exp_name>; return those paths.
        exp_paths.append(d / dumper._config.exp_name)
    return exp_paths[0], exp_paths[1]


def _create_non_tensor_rank_dump(
    directory: Path,
    *,
    rank: int,
    name: str,
    value: object,
    extra_tensor_dumps: list[tuple[str, torch.Tensor]] | None = None,
) -> Path:
    """Dump a single non-tensor value (plus optional tensors) as the given rank.

    Monkeypatches the dumper module's rank lookup so the dump appears to come
    from `rank`, and injects static metadata directly (bypassing init I/O).
    """
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(_dumper_module, "_get_rank", lambda: rank)
        dumper = _Dumper(
            config=DumperConfig(
                enable=True,
                dir=str(directory),
                exp_name=_FIXED_EXP_NAME,
            )
        )
        # Inject static metadata without triggering any dist/world discovery.
        dumper.__dict__["_static_meta"] = {"world_rank": rank, "world_size": 1}
        dumper.dump(name, value)
        for extra_name, extra_tensor in extra_tensor_dumps or []:
            dumper.dump(extra_name, extra_tensor)
        dumper.step()
    return directory / _FIXED_EXP_NAME


def _create_non_tensor_dumps(
    tmp_path: Path,
    *,
    name: str,
    baseline_value: object,
    target_value: object,
) -> tuple[Path, Path]:
    """Create single-rank baseline/target dumps each holding one non-tensor value."""
    baseline_dir = tmp_path / "baseline"
    target_dir = tmp_path / "target"
    baseline_dir.mkdir()
    target_dir.mkdir()
    baseline_path = _create_non_tensor_rank_dump(
        baseline_dir, rank=0, name=name, value=baseline_value
    )
    target_path = _create_non_tensor_rank_dump(
        target_dir, rank=0, name=name, value=target_value
    )
    return baseline_path, target_path


def _make_argv(
    baseline_path: Path,
    target_path: Path,
    *,
    preset: str | None = None,
    grouping_skip_keys: list[str] | None = None,
    token_aligner: str | None = None,
    diff_threshold: float = 1e-3,
    output_format: str = "json",
    start_step: int | None = None,
    end_step: int | None = None,
    filter: str |
None = None,
    override_dims: list[str] | None = None,
    override_baseline_dims: list[str] | None = None,
    override_target_dims: list[str] | None = None,
    override_config: str | None = None,
    allow_skipped_pattern: str | None = None,
    allow_failed_pattern: str | None = None,
    # NOTE(review): default "" means "--report-path " is always passed unless
    # callers explicitly pass report_path=None — presumably intentional (empty
    # path disables report writing); confirm against the entrypoint parser.
    report_path: str | None = "",
    viz_bundle_details: bool = False,
    viz_output_dir: str | None = None,
    visualize_per_token: str | None = None,
) -> list[str]:
    # Translate keyword arguments into the comparator CLI argv. The `filter`
    # parameter shadows the builtin but mirrors the CLI flag name; kept for
    # interface compatibility with existing callers.
    argv: list[str] = [
        "--baseline-path",
        str(baseline_path),
        "--target-path",
        str(target_path),
        "--diff-threshold",
        str(diff_threshold),
        "--output-format",
        output_format,
    ]
    if preset is not None:
        argv += ["--preset", preset]
    if grouping_skip_keys is not None:
        argv += ["--grouping-skip-keys"] + grouping_skip_keys
    if token_aligner is not None:
        argv += ["--token-aligner", token_aligner]
    if start_step is not None:
        argv += ["--start-step", str(start_step)]
    if end_step is not None:
        argv += ["--end-step", str(end_step)]
    if filter is not None:
        argv += ["--filter", filter]
    # Repeatable flags: one occurrence per override entry.
    for dim in override_dims or []:
        argv += ["--override-dims", dim]
    for dim in override_baseline_dims or []:
        argv += ["--override-baseline-dims", dim]
    for dim in override_target_dims or []:
        argv += ["--override-target-dims", dim]
    if override_config is not None:
        argv += ["--override-config", override_config]
    if allow_skipped_pattern is not None:
        argv += ["--allow-skipped-pattern", allow_skipped_pattern]
    if allow_failed_pattern is not None:
        argv += ["--allow-failed-pattern", allow_failed_pattern]
    if report_path is not None:
        argv += ["--report-path", report_path]
    if viz_bundle_details:
        argv += ["--viz-bundle-details"]
    if viz_output_dir is not None:
        argv += ["--viz-output-dir", viz_output_dir]
    if visualize_per_token is not None:
        argv += ["--visualize-per-token", visualize_per_token]
    return argv


def _run_and_parse(
    argv: list[str], capsys: pytest.CaptureFixture
) -> tuple[list[AnyRecord], int]:
    """Run the comparator entrypoint and parse its JSONL stdout into records."""
    args: Namespace = parse_args(argv)
    # Drain anything already captured so we only parse the entrypoint's output.
    capsys.readouterr()
    exit_code: int = run(args)
    return _parse_jsonl(capsys.readouterr().out), exit_code


def _parse_jsonl(output: str) -> list[AnyRecord]:
    """Parse one output record per non-empty stdout line."""
    return [parse_record_json(line) for line in output.strip().splitlines()]


def _create_rank_dump(
    directory: Path,
    *,
    rank: int,
    name: str,
    tensor: torch.Tensor,
    dims: str | None = None,
    parallel_info: dict | None = None,
    framework: str = "sglang",
    num_steps: int = 1,
    extra_dumps: list[tuple[str, object]] | None = None,
) -> Path:
    """Create a dump file via the real dumper, as if running on the given rank.

    extra_dumps: additional (name, value) pairs to dump alongside the main tensor each step.
    """
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(_dumper_module, "_get_rank", lambda: rank)
        dumper = _Dumper(
            config=DumperConfig(
                enable=True,
                dir=str(directory),
                exp_name=_FIXED_EXP_NAME,
            )
        )
        static_meta: dict = {"world_rank": rank, "world_size": 1}
        if parallel_info is not None:
            # Keyed per-framework, e.g. "sglang_parallel_info" / "megatron_parallel_info".
            static_meta[f"{framework}_parallel_info"] = parallel_info
        dumper.__dict__["_static_meta"] = static_meta
        for _ in range(num_steps):
            dumper.dump(name, tensor, dims=dims)
            for extra_name, extra_value in extra_dumps or []:
                dumper.dump(extra_name, extra_value)
            dumper.step()
    return directory / _FIXED_EXP_NAME


def _create_multi_step_rank_dump(
    directory: Path,
    *,
    rank: int,
    name: str,
    tensors_per_step: list[torch.Tensor],
    dims: str | None = None,
    parallel_info: dict | None = None,
    framework: str = "sglang",
) -> Path:
    """Create a dump file with *different* tensors per step.

    Unlike ``_create_rank_dump`` (which repeats the same tensor), this helper
    accepts a list of tensors — one per step.
""" with pytest.MonkeyPatch.context() as mp: mp.setattr(_dumper_module, "_get_rank", lambda: rank) dumper = _Dumper( config=DumperConfig( enable=True, dir=str(directory), exp_name=_FIXED_EXP_NAME, ) ) static_meta: dict = {"world_rank": rank, "world_size": 1} if parallel_info is not None: static_meta[f"{framework}_parallel_info"] = parallel_info dumper.__dict__["_static_meta"] = static_meta for tensor in tensors_per_step: dumper.dump(name, tensor, dims=dims) dumper.step() return directory / _FIXED_EXP_NAME def _create_cp_tp_sharded_dumps( directory: Path, *, full_tensor: torch.Tensor, name: str, cp_size: int, tp_size: int, seq_dim: int, head_dim: int, dims_str: str, num_steps: int = 1, ) -> Path: """Create CP+TP multi-axis sharded dump files from a full tensor.""" cp_chunks = list(full_tensor.chunk(cp_size, dim=seq_dim)) rank = 0 for cp_rank in range(cp_size): tp_chunks = list(cp_chunks[cp_rank].chunk(tp_size, dim=head_dim)) for tp_rank in range(tp_size): _create_rank_dump( directory, rank=rank, name=name, tensor=tp_chunks[tp_rank], dims=dims_str, parallel_info={ "cp_rank": cp_rank, "cp_size": cp_size, "tp_rank": tp_rank, "tp_size": tp_size, }, num_steps=num_steps, ) rank += 1 return directory / _FIXED_EXP_NAME def _create_ep_cp_tp_sharded_dumps( directory: Path, *, full_tensor: torch.Tensor, name: str, ep_size: int, cp_size: int, tp_size: int, expert_dim: int, seq_dim: int, head_dim: int, dims_str: str, num_steps: int = 1, ) -> Path: """Create EP+CP+TP three-axis sharded dump files from a full tensor.""" ep_chunks = list(full_tensor.chunk(ep_size, dim=expert_dim)) rank = 0 for ep_rank in range(ep_size): cp_chunks = list(ep_chunks[ep_rank].chunk(cp_size, dim=seq_dim)) for cp_rank in range(cp_size): tp_chunks = list(cp_chunks[cp_rank].chunk(tp_size, dim=head_dim)) for tp_rank in range(tp_size): _create_rank_dump( directory, rank=rank, name=name, tensor=tp_chunks[tp_rank], dims=dims_str, parallel_info={ "ep_rank": ep_rank, "ep_size": ep_size, "cp_rank": cp_rank, 
"cp_size": cp_size, "tp_rank": tp_rank, "tp_size": tp_size, }, num_steps=num_steps, ) rank += 1 return directory / _FIXED_EXP_NAME def _create_cp_zigzag_tp_sharded_dumps( directory: Path, *, full_tensor: torch.Tensor, name: str, cp_size: int, tp_size: int, seq_dim: int, head_dim: int, dims_str: str, num_steps: int = 1, ) -> Path: """Create CP-zigzag (+optional TP) sharded dump files from a full tensor.""" num_chunks: int = cp_size * 2 natural_chunks: list[torch.Tensor] = list( full_tensor.chunk(num_chunks, dim=seq_dim) ) zigzag_order: list[int] = [] for i in range(cp_size): zigzag_order.append(i) zigzag_order.append(num_chunks - 1 - i) zigzagged: torch.Tensor = torch.cat( [natural_chunks[idx] for idx in zigzag_order], dim=seq_dim ) cp_chunks: list[torch.Tensor] = list(zigzagged.chunk(cp_size, dim=seq_dim)) rank: int = 0 for cp_rank in range(cp_size): tp_chunks: list[torch.Tensor] = ( list(cp_chunks[cp_rank].chunk(tp_size, dim=head_dim)) if tp_size > 1 else [cp_chunks[cp_rank]] ) for tp_rank in range(tp_size): parallel_info: dict[str, int] = { "cp_rank": cp_rank, "cp_size": cp_size, } if tp_size > 1: parallel_info["tp_rank"] = tp_rank parallel_info["tp_size"] = tp_size _create_rank_dump( directory, rank=rank, name=name, tensor=tp_chunks[tp_rank], dims=dims_str, parallel_info=parallel_info, num_steps=num_steps, ) rank += 1 return directory / _FIXED_EXP_NAME def _create_cp_zigzag_sp_sharded_dumps( directory: Path, *, full_tensor: torch.Tensor, name: str, cp_size: int, sp_size: int, dims_str: str, seq_dim: int = 1, num_steps: int = 1, ) -> Path: """Create CP-zigzag + SP sharded dump files for a seq dim (b s h format). Shard order (outer to inner, matching left-to-right in dims annotation): 1. CP zigzag splits seq dim into cp_size chunks (zigzag order) 2. 
    SP splits each CP chunk into sp_size chunks
    """
    # Same zigzag construction as _create_cp_zigzag_tp_sharded_dumps.
    num_chunks: int = cp_size * 2
    natural_chunks: list[torch.Tensor] = list(
        full_tensor.chunk(num_chunks, dim=seq_dim)
    )
    zigzag_order: list[int] = []
    for i in range(cp_size):
        zigzag_order.append(i)
        zigzag_order.append(num_chunks - 1 - i)
    zigzagged: torch.Tensor = torch.cat(
        [natural_chunks[idx] for idx in zigzag_order], dim=seq_dim
    )
    cp_chunks: list[torch.Tensor] = list(zigzagged.chunk(cp_size, dim=seq_dim))
    rank: int = 0
    for cp_rank in range(cp_size):
        # SP sub-splits each CP chunk along the same seq dim.
        sp_chunks: list[torch.Tensor] = list(
            cp_chunks[cp_rank].chunk(sp_size, dim=seq_dim)
        )
        for sp_rank in range(sp_size):
            _create_rank_dump(
                directory,
                rank=rank,
                name=name,
                tensor=sp_chunks[sp_rank],
                dims=dims_str,
                parallel_info={
                    "cp_rank": cp_rank,
                    "cp_size": cp_size,
                    "sp_rank": sp_rank,
                    "sp_size": sp_size,
                },
                num_steps=num_steps,
            )
            rank += 1
    return directory / _FIXED_EXP_NAME


def _create_replicated_tp_sharded_cp_dumps(
    directory: Path,
    *,
    full_tensor: torch.Tensor,
    name: str,
    cp_size: int,
    tp_size: int,
    seq_dim: int,
    dims_str: str,
    tp_noise: float = 0.0,
) -> Path:
    """Create CP-sharded + TP-replicated dump files from a full tensor.

    CP direction: chunks along seq_dim (sharded).
    TP direction: clones (replicated), with optional noise to simulate mismatch.
    """
    cp_chunks: list[torch.Tensor] = list(full_tensor.chunk(cp_size, dim=seq_dim))
    rank: int = 0
    for cp_rank in range(cp_size):
        for tp_rank in range(tp_size):
            shard = cp_chunks[cp_rank].clone()
            # Only non-zero noise on tp_rank > 0 so rank 0 stays the reference.
            if tp_noise > 0 and tp_rank > 0:
                shard = shard + torch.randn_like(shard) * tp_noise
            _create_rank_dump(
                directory,
                rank=rank,
                name=name,
                tensor=shard,
                dims=dims_str,
                parallel_info={
                    "cp_rank": cp_rank,
                    "cp_size": cp_size,
                    "tp_rank": tp_rank,
                    "tp_size": tp_size,
                },
            )
            rank += 1
    return directory / _FIXED_EXP_NAME


def _create_tp_sharded_dumps(
    directory: Path,
    *,
    full_tensor: torch.Tensor,
    name: str,
    tp_size: int,
    shard_dim: int,
    dims_str: str,
    num_steps: int = 1,
) -> Path:
    """Create TP-sharded dump files from a full tensor via the real dumper."""
    shards = list(full_tensor.chunk(tp_size, dim=shard_dim))
    for tp_rank in range(tp_size):
        _create_rank_dump(
            directory,
            rank=tp_rank,
            name=name,
            tensor=shards[tp_rank],
            dims=dims_str,
            parallel_info={"tp_rank": tp_rank, "tp_size": tp_size},
            num_steps=num_steps,
        )
    return directory / _FIXED_EXP_NAME


def _create_multi_step_tp_sharded_dumps(
    directory: Path,
    *,
    full_tensors_per_step: list[torch.Tensor],
    name: str,
    tp_size: int,
    shard_dim: int,
    dims_str: str,
) -> Path:
    """Create TP-sharded dump files with *different* tensors per step.

    Each step's full tensor is chunked across TP ranks, then
    ``_create_multi_step_rank_dump`` writes one file per rank.
""" shards_per_rank: list[list[torch.Tensor]] = [[] for _ in range(tp_size)] for full_tensor in full_tensors_per_step: shards = list(full_tensor.chunk(tp_size, dim=shard_dim)) for tp_rank in range(tp_size): shards_per_rank[tp_rank].append(shards[tp_rank]) for tp_rank in range(tp_size): _create_multi_step_rank_dump( directory, rank=tp_rank, name=name, tensors_per_step=shards_per_rank[tp_rank], dims=dims_str, parallel_info={"tp_rank": tp_rank, "tp_size": tp_size}, ) return directory / _FIXED_EXP_NAME def _create_tp_partial_dumps( directory: Path, *, full_tensor: torch.Tensor, name: str, tp_size: int, dims_str: str, num_steps: int = 1, ) -> Path: """Create TP-partial dump files where each rank holds full_tensor / tp_size. Each rank stores an equal fraction of the full tensor so that element-wise summation across ranks reconstructs the original. """ for tp_rank in range(tp_size): _create_rank_dump( directory, rank=tp_rank, name=name, tensor=full_tensor / tp_size, dims=dims_str, parallel_info={"tp_rank": tp_rank, "tp_size": tp_size}, num_steps=num_steps, ) return directory / _FIXED_EXP_NAME def _create_recompute_rank_dump( directory: Path, *, rank: int, name: str, original_tensor: torch.Tensor, recompute_tensor: torch.Tensor, dims: str = "h d", ) -> Path: """Create a dump with both original and recompute forward passes via monkeypatched dumper. The dumper naturally produces recompute_pseudo_rank=0 for original and =1 for recompute, plus recompute_pseudo_size=2. 
""" with pytest.MonkeyPatch.context() as mp: mp.setattr(_dumper_module, "_get_rank", lambda: rank) dumper = _Dumper( config=DumperConfig( enable=True, dir=str(directory), exp_name=_FIXED_EXP_NAME, ) ) dumper.__dict__["_static_meta"] = {"world_rank": rank, "world_size": 1} # dump original forward mp.setattr( _dumper_module, "_detect_recompute_status", lambda: _RecomputeStatus.ORIGINAL, ) dumper.dump(name, original_tensor, dims=dims) # dump recompute forward mp.setattr( _dumper_module, "_detect_recompute_status", lambda: _RecomputeStatus.RECOMPUTE, ) dumper.dump(name, recompute_tensor, dims=dims) dumper.step() return directory / _FIXED_EXP_NAME def _zigzag_split_seq(seq_natural: torch.Tensor, *, cp_size: int) -> list[torch.Tensor]: """Split a natural-order seq into per-rank zigzag segments.""" num_chunks: int = cp_size * 2 chunks: list[torch.Tensor] = list(seq_natural.chunk(num_chunks, dim=0)) order: list[int] = [] for i in range(cp_size): order.append(i) order.append(num_chunks - 1 - i) zigzagged: torch.Tensor = torch.cat([chunks[i] for i in order], dim=0) return list(zigzagged.chunk(cp_size, dim=0)) def _create_thd_cp_zigzag_dumps( directory: Path, *, full_tensor: torch.Tensor, name: str, seq_lens: list[int], cp_size: int, total_per_rank: int, dims_str: str = "t[cp:zigzag]", num_steps: int = 1, ) -> Path: """Create THD CP-zigzag sharded dump files simulating Megatron forward. Args: full_tensor: 1D tensor of shape [T] in natural order. seq_lens: per-seq token counts in natural order (e.g. [100, 64]). cp_size: context parallelism size. total_per_rank: total tokens per rank (including padding). dims_str: dims annotation for the main tensor. 
""" # Build per-rank tensors from natural-order full_tensor offset: int = 0 rank_segments: list[list[torch.Tensor]] = [[] for _ in range(cp_size)] for seq_len in seq_lens: seq_natural: torch.Tensor = full_tensor[offset : offset + seq_len] seq_ranks: list[torch.Tensor] = _zigzag_split_seq(seq_natural, cp_size=cp_size) for rank_idx in range(cp_size): rank_segments[rank_idx].append(seq_ranks[rank_idx]) offset += seq_len # Build cu_seqlens from seq_lens (global, replicated across ranks) cu_seqlens_values: list[int] = [0] for slen in seq_lens: cu_seqlens_values.append(cu_seqlens_values[-1] + slen) # Pad to total_per_rank per rank (global pad = last cu_seqlens entry to total_per_rank * cp_size) total_global: int = total_per_rank * cp_size if cu_seqlens_values[-1] < total_global: pad_global: int = total_global - cu_seqlens_values[-1] cu_seqlens_values.append(total_global) pad_per_rank: int = pad_global // cp_size for rank_idx in range(cp_size): rank_segments[rank_idx].append(torch.zeros(pad_per_rank)) cu_seqlens_q: torch.Tensor = torch.tensor(cu_seqlens_values, dtype=torch.int64) # Dump each rank for cp_rank in range(cp_size): rank_tensor: torch.Tensor = torch.cat(rank_segments[cp_rank], dim=0) assert ( rank_tensor.shape[0] == total_per_rank ), f"rank {cp_rank}: expected {total_per_rank} tokens, got {rank_tensor.shape[0]}" _create_rank_dump( directory, rank=cp_rank, name=name, tensor=rank_tensor, dims=dims_str, parallel_info={ "cp_rank": cp_rank, "cp_size": cp_size, }, framework="megatron", num_steps=num_steps, extra_dumps=[ ("cu_seqlens_q", cu_seqlens_q), ("input_ids", rank_tensor.to(torch.int64)), ], ) return directory / _FIXED_EXP_NAME class TestEntrypointPerTokenVisualization: """Test --visualize-per-token CLI flag integration.""" def test_visualize_per_token_creates_png(self, tmp_path: Path, capsys) -> None: """--visualize-per-token with dims metadata produces per-token data in records.""" pytest.importorskip("matplotlib") torch.manual_seed(42) baseline_dir: Path = 
tmp_path / "baseline" target_dir: Path = tmp_path / "target" baseline_dir.mkdir() target_dir.mkdir() baseline_tensor: torch.Tensor = torch.randn(10, 10) target_tensor: torch.Tensor = baseline_tensor + torch.randn(10, 10) * 0.01 for name in ["tensor_a", "tensor_b"]: _create_rank_dump( baseline_dir, rank=0, name=name, tensor=baseline_tensor, dims="t h", ) _create_rank_dump( target_dir, rank=0, name=name, tensor=target_tensor, dims="t h", ) baseline_path: Path = baseline_dir / _FIXED_EXP_NAME target_path: Path = target_dir / _FIXED_EXP_NAME output_png: Path = tmp_path / "per_token.png" argv = _make_argv( baseline_path, target_path, preset="raw", visualize_per_token=str(output_png), ) records, _ = _run_and_parse(argv, capsys) comparisons = _get_comparisons(records) assert len(comparisons) == 2 # per_token_rel_diff should be populated for comp in comparisons: assert comp.diff is not None assert comp.diff.per_token_rel_diff is not None assert isinstance(comp.diff.per_token_rel_diff, list) assert len(comp.diff.per_token_rel_diff) == 10 def test_no_visualize_no_per_token(self, tmp_path: Path, capsys) -> None: """Without --visualize-per-token, per_token_rel_diff is None.""" baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a"]) argv = _make_argv(baseline_path, target_path, preset="raw") records, _ = _run_and_parse(argv, capsys) comparisons = _get_comparisons(records) assert len(comparisons) == 1 assert comparisons[0].diff is not None assert comparisons[0].diff.per_token_rel_diff is None class TestEntrypointThdCpZigzag: """E2E entrypoint tests for THD CP zigzag format. Tests the full pipeline: dump creation → metadata loading → aligner plan → unshard + reorder → tensor comparison. 
""" def test_sglang_vs_megatron_zigzag_cp(self, tmp_path: Path, capsys) -> None: """SGLang single-rank THD baseline vs Megatron CP=2 zigzag target.""" torch.manual_seed(42) hidden_dim: int = 8 cp_size: int = 2 # Two sequences: 8 and 4 tokens (divisible by cp_size*2=4 for clean zigzag) seq_a_ids: list[int] = [10, 20, 30, 40, 50, 60, 70, 80] seq_b_ids: list[int] = [100, 200, 300, 400] all_ids: list[int] = seq_a_ids + seq_b_ids total_tokens: int = len(all_ids) seq_lens: list[int] = [len(seq_a_ids), len(seq_b_ids)] hidden_states: torch.Tensor = torch.randn(total_tokens, hidden_dim) # --- SGLang baseline: single rank, 1 step --- sglang_dir: Path = tmp_path / "baseline" sglang_dir.mkdir() sglang_dumper = _Dumper( config=DumperConfig( enable=True, dir=str(sglang_dir), exp_name=_FIXED_EXP_NAME, ) ) positions: list[int] = list(range(seq_lens[0])) + list(range(seq_lens[1])) sglang_dumper.dump("input_ids", torch.tensor(all_ids)) sglang_dumper.dump("positions", torch.tensor(positions)) sglang_dumper.dump("seq_lens", torch.tensor(seq_lens)) sglang_dumper.dump("rids", ["A", "B"]) sglang_dumper.dump("hidden_states", hidden_states) sglang_dumper.step() # --- Megatron target: CP=2, zigzag, 1 step --- megatron_dir: Path = tmp_path / "target" megatron_dir.mkdir() # Zigzag-split input_ids and hidden_states per sequence, then concat ids_tensor: torch.Tensor = torch.tensor(all_ids, dtype=torch.int64) offset: int = 0 rank_id_segments: list[list[torch.Tensor]] = [[] for _ in range(cp_size)] rank_hidden_segments: list[list[torch.Tensor]] = [[] for _ in range(cp_size)] for slen in seq_lens: seq_ids: torch.Tensor = ids_tensor[offset : offset + slen] seq_hidden: torch.Tensor = hidden_states[offset : offset + slen] zigzag_ids: list[torch.Tensor] = _zigzag_split_seq(seq_ids, cp_size=cp_size) zigzag_hidden: list[torch.Tensor] = _zigzag_split_seq( seq_hidden, cp_size=cp_size ) for rank_idx in range(cp_size): rank_id_segments[rank_idx].append(zigzag_ids[rank_idx]) 
rank_hidden_segments[rank_idx].append(zigzag_hidden[rank_idx]) offset += slen cu_seqlens_q: torch.Tensor = torch.tensor( [0] + [sum(seq_lens[: i + 1]) for i in range(len(seq_lens))], dtype=torch.int64, ) for cp_rank in range(cp_size): rank_ids: torch.Tensor = torch.cat(rank_id_segments[cp_rank]) rank_hidden: torch.Tensor = torch.cat(rank_hidden_segments[cp_rank]) _create_rank_dump( megatron_dir, rank=cp_rank, name="hidden_states", tensor=rank_hidden, dims="t[cp:zigzag] h", parallel_info={"cp_rank": cp_rank, "cp_size": cp_size}, framework="megatron", extra_dumps=[ ("cu_seqlens_q", cu_seqlens_q), ("input_ids", rank_ids), ], ) # --- Run comparison --- argv: list[str] = _make_argv( sglang_dir / _FIXED_EXP_NAME, megatron_dir / _FIXED_EXP_NAME, grouping_skip_keys=["rank", "step"], token_aligner="smart", diff_threshold=1e-3, ) records, _ = _run_and_parse(argv, capsys) comparisons: list[ComparisonTensorRecord] = _get_comparisons(records) hidden_comparisons: list[ComparisonTensorRecord] = [ c for c in comparisons if c.name == "hidden_states" ] assert len(hidden_comparisons) >= 1 assert all(c.diff is not None and c.diff.passed for c in hidden_comparisons) def test_thd_cp_zigzag_unshard(self, tmp_path: Path, capsys) -> None: """Both sides THD CP=2 zigzag, comparison should pass.""" torch.manual_seed(42) cp_size: int = 2 seq_lens: list[int] = [100, 64] total_tokens: int = sum(seq_lens) total_per_rank: int = 128 full_tensor: torch.Tensor = torch.randn(total_tokens + 92) baseline_dir: Path = tmp_path / "baseline" target_dir: Path = tmp_path / "target" baseline_dir.mkdir() target_dir.mkdir() baseline_path: Path = _create_thd_cp_zigzag_dumps( baseline_dir, full_tensor=full_tensor, name="hidden_states", seq_lens=seq_lens, cp_size=cp_size, total_per_rank=total_per_rank, ) # Target: same data with small noise target_tensor: torch.Tensor = full_tensor + torch.randn_like(full_tensor) * 1e-5 target_path: Path = _create_thd_cp_zigzag_dumps( target_dir, full_tensor=target_tensor, 
name="hidden_states", seq_lens=seq_lens, cp_size=cp_size, total_per_rank=total_per_rank, ) argv: list[str] = _make_argv( baseline_path, target_path, grouping_skip_keys=["rank", "step"], token_aligner="smart", diff_threshold=1e-3, ) records, _ = _run_and_parse(argv, capsys) # hidden_states should pass comparison (after unshard + reorder) comparisons: list[ComparisonTensorRecord] = _get_comparisons(records) hidden_comparisons: list[ComparisonTensorRecord] = [ c for c in comparisons if c.name == "hidden_states" ] assert len(hidden_comparisons) >= 1 assert all(c.diff is not None and c.diff.passed for c in hidden_comparisons) class TestEntrypointDpFilter: """E2E tests for DP (data parallel) filtering. When DP > 1, only one dp_rank has non-empty tensors; the others dump empty (numel=0) tensors. The comparator should filter out the empty dp_rank items and produce correct comparison results. """ def test_dp2_sglang_both_sides(self, tmp_path: Path, capsys) -> None: """DP=2 sglang: both baseline and target have 1 non-empty + 1 empty dp_rank.""" torch.manual_seed(42) tensor_data: torch.Tensor = torch.randn(10, 8) target_data: torch.Tensor = tensor_data + torch.randn(10, 8) * 0.001 for side, side_dir_name, data in [ ("baseline", "baseline", tensor_data), ("target", "target", target_data), ]: side_dir: Path = tmp_path / side_dir_name side_dir.mkdir() # dp_rank=0: non-empty tensor _create_rank_dump( side_dir, rank=0, name="hidden", tensor=data, dims="t h", parallel_info={ "tp_rank": 0, "tp_size": 1, "dp_rank": 0, "dp_size": 2, }, framework="sglang", ) # dp_rank=1: empty tensor _create_rank_dump( side_dir, rank=1, name="hidden", tensor=torch.empty(0, 8), dims="t h", parallel_info={ "tp_rank": 0, "tp_size": 1, "dp_rank": 1, "dp_size": 2, }, framework="sglang", ) argv: list[str] = _make_argv( tmp_path / "baseline" / _FIXED_EXP_NAME, tmp_path / "target" / _FIXED_EXP_NAME, diff_threshold=1e-3, ) records, _ = _run_and_parse(argv, capsys) comparison: ComparisonTensorRecord = 
_assert_single_comparison_passed(records) assert comparison.name == "hidden" def test_dp2_megatron_both_sides(self, tmp_path: Path, capsys) -> None: """DP=2 megatron: both baseline and target have 1 non-empty + 1 empty dp_rank.""" torch.manual_seed(42) tensor_data: torch.Tensor = torch.randn(10, 8) target_data: torch.Tensor = tensor_data + torch.randn(10, 8) * 0.001 for side, side_dir_name, data in [ ("baseline", "baseline", tensor_data), ("target", "target", target_data), ]: side_dir: Path = tmp_path / side_dir_name side_dir.mkdir() # dp_rank=0: non-empty tensor _create_rank_dump( side_dir, rank=0, name="hidden", tensor=data, dims="t h", parallel_info={ "tp_rank": 0, "tp_size": 1, "dp_rank": 0, "dp_size": 2, }, framework="megatron", ) # dp_rank=1: empty tensor _create_rank_dump( side_dir, rank=1, name="hidden", tensor=torch.empty(0, 8), dims="t h", parallel_info={ "tp_rank": 0, "tp_size": 1, "dp_rank": 1, "dp_size": 2, }, framework="megatron", ) argv: list[str] = _make_argv( tmp_path / "baseline" / _FIXED_EXP_NAME, tmp_path / "target" / _FIXED_EXP_NAME, diff_threshold=1e-3, ) records, _ = _run_and_parse(argv, capsys) comparison: ComparisonTensorRecord = _assert_single_comparison_passed(records) assert comparison.name == "hidden" def test_dp2_tp2_sglang(self, tmp_path: Path, capsys) -> None: """DP=2 x TP=2 sglang: 4 ranks, dp_rank=0 has data, dp_rank=1 empty.""" torch.manual_seed(42) full_tensor: torch.Tensor = torch.randn(10, 8) tp_chunks: list[torch.Tensor] = list(full_tensor.chunk(2, dim=1)) target_full: torch.Tensor = full_tensor + torch.randn(10, 8) * 0.001 target_tp_chunks: list[torch.Tensor] = list(target_full.chunk(2, dim=1)) for side, side_dir_name, chunks in [ ("baseline", "baseline", tp_chunks), ("target", "target", target_tp_chunks), ]: side_dir: Path = tmp_path / side_dir_name side_dir.mkdir() rank: int = 0 for dp_rank in range(2): for tp_rank in range(2): tensor: torch.Tensor = ( chunks[tp_rank] if dp_rank == 0 else torch.empty(0, 4) ) 
                    _create_rank_dump(
                        side_dir,
                        rank=rank,
                        name="hidden",
                        tensor=tensor,
                        dims="t h[tp]",
                        parallel_info={
                            "tp_rank": tp_rank,
                            "tp_size": 2,
                            "dp_rank": dp_rank,
                            "dp_size": 2,
                        },
                        framework="sglang",
                    )
                    rank += 1
        argv: list[str] = _make_argv(
            tmp_path / "baseline" / _FIXED_EXP_NAME,
            tmp_path / "target" / _FIXED_EXP_NAME,
            diff_threshold=1e-3,
        )
        records, _ = _run_and_parse(argv, capsys)
        comparison: ComparisonTensorRecord = _assert_single_comparison_passed(records)
        assert comparison.name == "hidden"

    def test_dp2_both_nonempty_raises(self, tmp_path: Path, capsys) -> None:
        """DP=2 sglang: both dp_rank=0 and dp_rank=1 have non-empty tensors => AssertionError."""
        torch.manual_seed(42)
        tensor_data: torch.Tensor = torch.randn(10, 8)
        target_data: torch.Tensor = tensor_data + torch.randn(10, 8) * 0.001
        for side, side_dir_name, data in [
            ("baseline", "baseline", tensor_data),
            ("target", "target", target_data),
        ]:
            side_dir: Path = tmp_path / side_dir_name
            side_dir.mkdir()
            # Both dp ranks dump real data — violates the "exactly one
            # non-empty dp_rank" invariant the comparator asserts on.
            for dp_rank in range(2):
                _create_rank_dump(
                    side_dir,
                    rank=dp_rank,
                    name="hidden",
                    tensor=data,
                    dims="t h",
                    parallel_info={
                        "tp_rank": 0,
                        "tp_size": 1,
                        "dp_rank": dp_rank,
                        "dp_size": 2,
                    },
                    framework="sglang",
                )
        argv: list[str] = _make_argv(
            tmp_path / "baseline" / _FIXED_EXP_NAME,
            tmp_path / "target" / _FIXED_EXP_NAME,
            diff_threshold=1e-3,
        )
        records, exit_code = _run_and_parse(argv, capsys)
        errors = [r for r in records if isinstance(r, ComparisonErrorRecord)]
        assert len(errors) == 1
        assert errors[0].exception_type == "AssertionError"
        assert "Expected exactly 1 non-empty dp_rank" in errors[0].traceback_str
        assert exit_code == 1


class TestEntrypointDpGroupAlias:
    """E2E tests for the ``# dp:=`` dp group alias feature.

    In dp_attn mode, dp_size > 1 but MLP tensors after dp_gather have data on
    all ranks. With ``# dp:=moe_dp`` in dims, the dp filter uses
    ``moe_dp_rank/moe_dp_size`` instead of ``dp_rank/dp_size``.
    """

    def test_dp_alias_absent_group_noop(self, tmp_path: Path, capsys) -> None:
        """Single rank with ``# dp:=moe_dp`` in dims → parse_dims strips ``#``, comparison OK."""
        torch.manual_seed(42)
        tensor_data: torch.Tensor = torch.randn(10, 8)
        target_data: torch.Tensor = tensor_data + torch.randn(10, 8) * 0.001
        for side_dir_name, data in [("baseline", tensor_data), ("target", target_data)]:
            side_dir: Path = tmp_path / side_dir_name
            side_dir.mkdir()
            _create_rank_dump(
                side_dir,
                rank=0,
                name="hidden",
                tensor=data,
                dims="t h # dp:=moe_dp",
                parallel_info={
                    "tp_rank": 0,
                    "tp_size": 1,
                    "dp_rank": 0,
                    "dp_size": 1,
                },
                framework="sglang",
            )
        argv: list[str] = _make_argv(
            tmp_path / "baseline" / _FIXED_EXP_NAME,
            tmp_path / "target" / _FIXED_EXP_NAME,
            diff_threshold=1e-3,
        )
        records, _ = _run_and_parse(argv, capsys)
        comparison: ComparisonTensorRecord = _assert_single_comparison_passed(records)
        assert comparison.name == "hidden"

    def test_dp_alias_via_override_dims(self, tmp_path: Path, capsys) -> None:
        """--override-dims adds ``# dp:=moe_dp`` → dp filter uses alias, filters correctly."""
        torch.manual_seed(42)
        tensor_data: torch.Tensor = torch.randn(10, 8)
        target_data: torch.Tensor = tensor_data + torch.randn(10, 8) * 0.001
        for side_dir_name, data in [("baseline", tensor_data), ("target", target_data)]:
            side_dir: Path = tmp_path / side_dir_name
            side_dir.mkdir()
            # moe_dp_rank=0: non-empty
            _create_rank_dump(
                side_dir,
                rank=0,
                name="hidden",
                tensor=data,
                dims="t h",
                parallel_info={
                    "tp_rank": 0,
                    "tp_size": 1,
                    "dp_rank": 0,
                    "dp_size": 1,
                    "moe_dp_rank": 0,
                    "moe_dp_size": 2,
                },
                framework="sglang",
            )
            # moe_dp_rank=1: empty
            _create_rank_dump(
                side_dir,
                rank=1,
                name="hidden",
                tensor=torch.empty(0, 8),
                dims="t h",
                parallel_info={
                    "tp_rank": 0,
                    "tp_size": 1,
                    "dp_rank": 0,
                    "dp_size": 1,
                    "moe_dp_rank": 1,
                    "moe_dp_size": 2,
                },
                framework="sglang",
            )
        argv: list[str] = _make_argv(
            tmp_path / "baseline" / _FIXED_EXP_NAME,
            tmp_path / "target" / _FIXED_EXP_NAME,
            diff_threshold=1e-3,
override_dims=["hidden:t h # dp:=moe_dp"], ) records, _ = _run_and_parse(argv, capsys) comparison: ComparisonTensorRecord = _assert_single_comparison_passed(records) assert comparison.name == "hidden" def test_dp_alias_with_real_alias_group_filters( self, tmp_path: Path, capsys ) -> None: """Alias group present with moe_dp_size=2, one empty rank → filters correctly.""" torch.manual_seed(42) tensor_data: torch.Tensor = torch.randn(10, 8) target_data: torch.Tensor = tensor_data + torch.randn(10, 8) * 0.001 for side_dir_name, data in [("baseline", tensor_data), ("target", target_data)]: side_dir: Path = tmp_path / side_dir_name side_dir.mkdir() for moe_dp_rank in range(2): tensor: torch.Tensor = data if moe_dp_rank == 0 else torch.empty(0, 8) _create_rank_dump( side_dir, rank=moe_dp_rank, name="hidden", tensor=tensor, dims="t h # dp:=moe_dp", parallel_info={ "tp_rank": 0, "tp_size": 1, "dp_rank": 0, "dp_size": 1, "moe_dp_rank": moe_dp_rank, "moe_dp_size": 2, }, framework="sglang", ) argv: list[str] = _make_argv( tmp_path / "baseline" / _FIXED_EXP_NAME, tmp_path / "target" / _FIXED_EXP_NAME, diff_threshold=1e-3, ) records, _ = _run_and_parse(argv, capsys) comparison: ComparisonTensorRecord = _assert_single_comparison_passed(records) assert comparison.name == "hidden" class TestEntrypointMetaOverride: """E2E: dump with wrong dims → --override-dims / --override-config corrects at comparison time.""" @staticmethod def _create_single_rank_pair( tmp_path: Path, *, name: str = "hidden", baseline_dims: str | None = "x y", target_dims: str | None = "x y", ) -> tuple[Path, Path]: """Create single-rank baseline+target dumps with a close tensor pair.""" torch.manual_seed(42) tensor: torch.Tensor = torch.randn(10, 8) target: torch.Tensor = tensor + torch.randn(10, 8) * 0.001 baseline_dir: Path = tmp_path / "baseline" target_dir: Path = tmp_path / "target" baseline_dir.mkdir() target_dir.mkdir() _create_rank_dump( baseline_dir, rank=0, name=name, tensor=tensor, dims=baseline_dims ) 
        _create_rank_dump(
            target_dir, rank=0, name=name, tensor=target, dims=target_dims
        )
        return baseline_dir / _FIXED_EXP_NAME, target_dir / _FIXED_EXP_NAME

    @staticmethod
    def _assert_all_passed(
        records: list[AnyRecord], *, expected_count: int = 1
    ) -> None:
        """Assert that exactly expected_count comparisons exist and all passed."""
        comparisons: list[ComparisonTensorRecord] = _get_comparisons(records)
        assert len(comparisons) == expected_count
        assert all(c.diff is not None and c.diff.passed for c in comparisons)

    def test_override_dims_fixes_wrong_dims(self, tmp_path: Path, capsys) -> None:
        """Tensor dumped with wrong dims='h d' is fixed by --override-dims to 't h[tp]'."""
        torch.manual_seed(42)
        full_tensor: torch.Tensor = torch.randn(10, 8)
        # Shard along dim 1 to simulate TP-sharded dumps (2 ranks).
        tp_chunks: list[torch.Tensor] = list(full_tensor.chunk(2, dim=1))
        target_full: torch.Tensor = full_tensor + torch.randn(10, 8) * 0.001
        target_tp_chunks: list[torch.Tensor] = list(target_full.chunk(2, dim=1))
        baseline_dir: Path = tmp_path / "baseline"
        target_dir: Path = tmp_path / "target"
        baseline_dir.mkdir()
        target_dir.mkdir()
        # Dump with WRONG dims "h d" instead of correct "t h[tp]"
        for tp_rank in range(2):
            _create_rank_dump(
                baseline_dir,
                rank=tp_rank,
                name="hidden",
                tensor=tp_chunks[tp_rank],
                dims="h d",
                parallel_info={"tp_rank": tp_rank, "tp_size": 2},
            )
            _create_rank_dump(
                target_dir,
                rank=tp_rank,
                name="hidden",
                tensor=target_tp_chunks[tp_rank],
                dims="h d",
                parallel_info={"tp_rank": tp_rank, "tp_size": 2},
            )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            override_dims=["hidden:t h[tp]"],
        )
        self._assert_all_passed(_run_and_parse(argv, capsys)[0])

    @pytest.mark.parametrize(
        "baseline_dims, target_dims, override_kwarg",
        [
            ("x y", "t h", {"override_baseline_dims": ["hidden:t h"]}),
            ("t h", "x y", {"override_target_dims": ["hidden:t h"]}),
            ("x y", "x y", {"override_dims": ["hidden:t h"]}),
        ],
        ids=["baseline_only", "target_only", "both_via_override_dims"],
    )
    def test_single_side_override(
        self,
        tmp_path:
        Path,
        capsys,
        baseline_dims: str,
        target_dims: str,
        override_kwarg: dict,
    ) -> None:
        """Per-side override fixes the wrong dims on one or both sides."""
        baseline_path, target_path = self._create_single_rank_pair(
            tmp_path,
            baseline_dims=baseline_dims,
            target_dims=target_dims,
        )
        argv = _make_argv(baseline_path, target_path, preset="raw", **override_kwarg)
        self._assert_all_passed(_run_and_parse(argv, capsys)[0])

    def test_override_config_yaml(self, tmp_path: Path, capsys) -> None:
        """--override-config YAML overrides dims."""
        baseline_path, target_path = self._create_single_rank_pair(tmp_path)
        yaml_path: Path = tmp_path / "override.yaml"
        yaml_path.write_text(textwrap.dedent("""\
            overrides:
            - match: "hidden"
              dims: "t h"
            """))
        argv = _make_argv(
            baseline_path,
            target_path,
            preset="raw",
            override_config=str(yaml_path),
        )
        self._assert_all_passed(_run_and_parse(argv, capsys)[0])

    def test_no_match_uses_original_dims(self, tmp_path: Path, capsys) -> None:
        """When override regex doesn't match, original dims from dump are used."""
        baseline_path, target_path = self._create_single_rank_pair(
            tmp_path,
            baseline_dims="t h",
            target_dims="t h",
        )
        argv = _make_argv(
            baseline_path,
            target_path,
            preset="raw",
            override_dims=["no_match_pattern:b s d"],
        )
        self._assert_all_passed(_run_and_parse(argv, capsys)[0])

    def test_selective_match_multi_tensor(self, tmp_path: Path, capsys) -> None:
        """Override matches only 'logits'; 'hidden' uses original dims."""
        torch.manual_seed(42)
        baseline_dir: Path = tmp_path / "baseline"
        target_dir: Path = tmp_path / "target"
        baseline_dir.mkdir()
        target_dir.mkdir()
        hidden_b: torch.Tensor = torch.randn(10, 8)
        hidden_t: torch.Tensor = hidden_b + torch.randn(10, 8) * 0.001
        logits_b: torch.Tensor = torch.randn(10, 4)
        logits_t: torch.Tensor = logits_b + torch.randn(10, 4) * 0.001
        # "hidden" keeps its correct dims; "logits" is dumped with wrong dims "x y".
        for name, b_tensor, t_tensor, dims in [
            ("hidden", hidden_b, hidden_t, "t h"),
            ("logits", logits_b, logits_t, "x y"),
        ]:
            _create_rank_dump(
                baseline_dir,
                rank=0,
                name=name,
                tensor=b_tensor,
                dims=dims
            )
            _create_rank_dump(target_dir, rank=0, name=name, tensor=t_tensor, dims=dims)
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            preset="raw",
            override_dims=["logits:t v"],
        )
        self._assert_all_passed(_run_and_parse(argv, capsys)[0], expected_count=2)

    def test_multiple_cli_override_dims(self, tmp_path: Path, capsys) -> None:
        """Multiple --override-dims for different tensors."""
        torch.manual_seed(42)
        baseline_dir: Path = tmp_path / "baseline"
        target_dir: Path = tmp_path / "target"
        baseline_dir.mkdir()
        target_dir.mkdir()
        hidden_b: torch.Tensor = torch.randn(10, 8)
        hidden_t: torch.Tensor = hidden_b + torch.randn(10, 8) * 0.001
        logits_b: torch.Tensor = torch.randn(10, 4)
        logits_t: torch.Tensor = logits_b + torch.randn(10, 4) * 0.001
        # Both tensors are dumped with the same wrong dims "x y".
        for name, b_tensor, t_tensor in [
            ("hidden", hidden_b, hidden_t),
            ("logits", logits_b, logits_t),
        ]:
            _create_rank_dump(
                baseline_dir, rank=0, name=name, tensor=b_tensor, dims="x y"
            )
            _create_rank_dump(
                target_dir, rank=0, name=name, tensor=t_tensor, dims="x y"
            )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            preset="raw",
            override_dims=["hidden:t h", "logits:t v"],
        )
        self._assert_all_passed(_run_and_parse(argv, capsys)[0], expected_count=2)

    def test_per_side_dims_different_parallelism(self, tmp_path: Path, capsys) -> None:
        """baseline TP-sharded, target EP-sharded — per-side override fixes both."""
        torch.manual_seed(42)
        full_tensor: torch.Tensor = torch.randn(10, 8)
        target_full: torch.Tensor = full_tensor + torch.randn(10, 8) * 0.001
        baseline_dir: Path = tmp_path / "baseline"
        target_dir: Path = tmp_path / "target"
        baseline_dir.mkdir()
        target_dir.mkdir()
        # Baseline: 2-way TP sharding along dim 1.
        b_chunks: list[torch.Tensor] = list(full_tensor.chunk(2, dim=1))
        for tp_rank in range(2):
            _create_rank_dump(
                baseline_dir,
                rank=tp_rank,
                name="hidden",
                tensor=b_chunks[tp_rank],
                dims="x y",
                parallel_info={"tp_rank": tp_rank, "tp_size": 2},
            )
        # Target: same data but 2-way EP sharding instead.
        t_chunks: list[torch.Tensor] = list(target_full.chunk(2, dim=1))
        for ep_rank in range(2):
            _create_rank_dump(
                target_dir,
                rank=ep_rank,
                name="hidden",
                tensor=t_chunks[ep_rank],
                dims="x y",
                parallel_info={"ep_rank": ep_rank, "ep_size": 2},
            )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            override_baseline_dims=["hidden:t h[tp]"],
            override_target_dims=["hidden:t h[ep]"],
        )
        self._assert_all_passed(_run_and_parse(argv, capsys)[0])

    def test_yaml_first_match_wins_e2e(self, tmp_path: Path, capsys) -> None:
        """YAML with two matching rules: first rule wins in real pipeline."""
        baseline_path, target_path = self._create_single_rank_pair(tmp_path)
        yaml_path: Path = tmp_path / "override.yaml"
        yaml_path.write_text(textwrap.dedent("""\
            overrides:
            - match: "hidden"
              dims: "t h"
            - match: "hidden"
              dims: "a b"
            """))
        argv = _make_argv(
            baseline_path,
            target_path,
            preset="raw",
            override_config=str(yaml_path),
        )
        self._assert_all_passed(_run_and_parse(argv, capsys)[0])

    def test_cli_overrides_yaml_e2e(self, tmp_path: Path, capsys) -> None:
        """CLI --override-dims wins over YAML rule for the same tensor."""
        baseline_path, target_path = self._create_single_rank_pair(tmp_path)
        yaml_path: Path = tmp_path / "override.yaml"
        yaml_path.write_text(textwrap.dedent("""\
            overrides:
            - match: "hidden"
              dims: "a b"
            """))
        argv = _make_argv(
            baseline_path,
            target_path,
            preset="raw",
            override_dims=["hidden:t h"],
            override_config=str(yaml_path),
        )
        self._assert_all_passed(_run_and_parse(argv, capsys)[0])

    def test_override_injects_dims_when_absent(self, tmp_path: Path, capsys) -> None:
        """Override injects dims into meta even when dump had no dims annotation."""
        baseline_path, target_path = self._create_single_rank_pair(
            tmp_path,
            baseline_dims=None,
            target_dims=None,
        )
        argv = _make_argv(
            baseline_path,
            target_path,
            preset="raw",
            override_dims=["hidden:t h"],
        )
        self._assert_all_passed(_run_and_parse(argv, capsys)[0])

    def test_non_tensor_unaffected_by_override(self, tmp_path: Path, capsys) -> None:
        """Non-tensor values pass through without error even with
        active override."""
        torch.manual_seed(42)
        tensor: torch.Tensor = torch.randn(4, 4)
        baseline_dir: Path = tmp_path / "baseline"
        target_dir: Path = tmp_path / "target"
        baseline_dir.mkdir()
        target_dir.mkdir()
        # Dump one scalar (non-tensor) value alongside one tensor on each side.
        for side_dir in [baseline_dir, target_dir]:
            _create_non_tensor_rank_dump(
                side_dir,
                rank=0,
                name="sm_scale",
                value=0.125,
                extra_tensor_dumps=[("hidden", tensor)],
            )
        argv = _make_argv(
            baseline_dir / _FIXED_EXP_NAME,
            target_dir / _FIXED_EXP_NAME,
            preset="raw",
            override_dims=["hidden:x y"],
        )
        records, _ = _run_and_parse(argv, capsys)
        non_tensors: list[ComparisonNonTensorRecord] = [
            r for r in records if isinstance(r, ComparisonNonTensorRecord)
        ]
        assert len(non_tensors) == 1
        assert non_tensors[0].name == "sm_scale"
        assert non_tensors[0].values_equal
        comparisons: list[ComparisonTensorRecord] = _get_comparisons(records)
        assert len(comparisons) == 1
        assert comparisons[0].name == "hidden"
        summary: SummaryRecord = [r for r in records if isinstance(r, SummaryRecord)][0]
        assert summary.failed == 0


class TestExitCode:
    """E2E tests for exit code behavior based on comparison results."""

    def test_e2e_all_passed_exit_zero(self, tmp_path, capsys):
        """Integration: all comparisons pass → run() returns 0."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a", "tensor_b"])
        argv = _make_argv(baseline_path, target_path, preset="raw")
        records, exit_code = _run_and_parse(argv, capsys)
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.passed == 2
        assert summary.failed == 0
        assert exit_code == 0

    def test_e2e_has_failed_exit_nonzero(self, tmp_path, capsys):
        """Integration: a failed comparison → run() returns 1."""
        torch.manual_seed(42)
        # Scaling by 100 guarantees the diff exceeds the 1e-3 threshold.
        baseline_path = _create_rank_dump(
            tmp_path / "baseline", rank=0, name="tensor_a", tensor=torch.randn(10, 10)
        )
        target_path = _create_rank_dump(
            tmp_path / "target",
            rank=0,
            name="tensor_a",
            tensor=torch.randn(10, 10) * 100,
        )
        argv = _make_argv(baseline_path, target_path, preset="raw", diff_threshold=1e-3)
        records, exit_code = _run_and_parse(argv, capsys)
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.failed == 1
        assert exit_code == 1

    def test_e2e_allow_failed_pattern_exit_zero(self, tmp_path, capsys):
        """E2E: failed tensor matched by allow_failed_pattern + a passing tensor → exit 0."""
        torch.manual_seed(42)
        # "tensor_good" is byte-identical on both sides; "tensor_bad" differs wildly.
        shared_tensor = torch.randn(10, 10)
        baseline_path = _create_rank_dump(
            tmp_path / "baseline",
            rank=0,
            name="tensor_bad",
            tensor=torch.randn(10, 10),
            extra_dumps=[("tensor_good", shared_tensor)],
        )
        target_path = _create_rank_dump(
            tmp_path / "target",
            rank=0,
            name="tensor_bad",
            tensor=torch.randn(10, 10) * 100,
            extra_dumps=[("tensor_good", shared_tensor)],
        )
        argv = _make_argv(
            baseline_path,
            target_path,
            preset="raw",
            diff_threshold=1e-3,
            allow_failed_pattern="tensor_bad",
        )
        records, exit_code = _run_and_parse(argv, capsys)
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.passed == 1
        assert summary.failed == 1
        assert exit_code == 0

    def test_e2e_allow_failed_pattern_no_match_exit_one(self, tmp_path, capsys):
        """E2E: failed tensor NOT matched by allow_failed_pattern → exit 1."""
        torch.manual_seed(42)
        shared_tensor = torch.randn(10, 10)
        baseline_path = _create_rank_dump(
            tmp_path / "baseline",
            rank=0,
            name="tensor_bad",
            tensor=torch.randn(10, 10),
            extra_dumps=[("tensor_good", shared_tensor)],
        )
        target_path = _create_rank_dump(
            tmp_path / "target",
            rank=0,
            name="tensor_bad",
            tensor=torch.randn(10, 10) * 100,
            extra_dumps=[("tensor_good", shared_tensor)],
        )
        argv = _make_argv(
            baseline_path,
            target_path,
            preset="raw",
            diff_threshold=1e-3,
            allow_failed_pattern="other_tensor",
        )
        records, exit_code = _run_and_parse(argv, capsys)
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.passed == 1
        assert summary.failed == 1
        assert exit_code == 1


class TestExitCodeSubprocess:
    """E2E subprocess tests: invoke comparator as a child process and verify exit code."""

    @staticmethod
    def _run_comparator(
        baseline_path: Path,
        target_path: Path,
        *,
        preset: str = "raw",
        allow_skipped_pattern: str = ".*",
    ) -> subprocess.CompletedProcess[str]:
        # Run the comparator module as a child process so the real CLI exit
        # code is observed (not just run()'s return value).
        cmd: list[str] = [
            sys.executable,
            "-m",
            "sglang.srt.debug_utils.comparator",
            "--baseline-path",
            str(baseline_path),
            "--target-path",
            str(target_path),
            "--preset",
            preset,
            "--output-format",
            "json",
            "--allow-skipped-pattern",
            allow_skipped_pattern,
        ]
        return subprocess.run(cmd, capture_output=True, text=True)

    def test_all_passed_exit_zero(self, tmp_path):
        """Subprocess: all comparisons pass → exit 0."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a"])
        result = self._run_comparator(baseline_path, target_path)
        assert result.returncode == 0

    def test_failed_exit_nonzero(self, tmp_path):
        """Subprocess: failed comparison → exit 1."""
        torch.manual_seed(42)
        baseline_path = _create_rank_dump(
            tmp_path / "baseline", rank=0, name="t", tensor=torch.randn(10, 10)
        )
        target_path = _create_rank_dump(
            tmp_path / "target", rank=0, name="t", tensor=torch.randn(10, 10) * 100
        )
        result = self._run_comparator(baseline_path, target_path)
        assert result.returncode == 1

    def test_skipped_allow_all_exit_zero(self, tmp_path):
        """Subprocess: skipped comparison with allow_skipped_pattern='.*' → exit 0."""
        baseline_path, target_path = _create_dumps(
            tmp_path,
            tensor_names=["tensor_a", "tensor_extra"],
            baseline_names=["tensor_a"],
        )
        result = self._run_comparator(
            baseline_path, target_path, allow_skipped_pattern=".*"
        )
        assert result.returncode == 0

    def test_skipped_forbid_all_exit_nonzero(self, tmp_path):
        """Subprocess: skipped comparison with allow_skipped_pattern='^$' → exit 1."""
        baseline_path, target_path = _create_dumps(
            tmp_path,
            tensor_names=["tensor_a", "tensor_extra"],
            baseline_names=["tensor_a"],
        )
        result = self._run_comparator(
            baseline_path, target_path, allow_skipped_pattern="^$"
        )
        assert result.returncode == 1


class TestReportOutput:
    """Test JSONL report file output via ReportSink."""

    def test_default_report_path(self, tmp_path,
            capsys):
        """Default writes to /comparator_report.jsonl with ConfigRecord + SummaryRecord."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a"])
        # report_path=None selects the default location under target_path.
        argv = _make_argv(baseline_path, target_path, preset="raw", report_path=None)
        exit_code: int = run(parse_args(argv))
        report_file: Path = target_path / "comparator_report.jsonl"
        assert report_file.exists()
        report_records: list[AnyRecord] = _parse_jsonl(report_file.read_text())
        assert isinstance(report_records[0], ConfigRecord)
        assert isinstance(report_records[-1], SummaryRecord)
        assert exit_code == 0

    def test_custom_report_path(self, tmp_path, capsys):
        """--report-path writes to the specified location."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a"])
        # Parent directory does not exist yet; the sink is expected to create it.
        custom_path: Path = tmp_path / "custom" / "report.jsonl"
        argv = _make_argv(
            baseline_path,
            target_path,
            preset="raw",
            report_path=str(custom_path),
        )
        run(parse_args(argv))
        assert custom_path.exists()
        report_records: list[AnyRecord] = _parse_jsonl(custom_path.read_text())
        assert isinstance(report_records[0], ConfigRecord)
        assert isinstance(report_records[-1], SummaryRecord)

    def test_disabled_report(self, tmp_path, capsys):
        """--report-path '' disables file generation."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a"])
        argv = _make_argv(baseline_path, target_path, preset="raw", report_path="")
        run(parse_args(argv))
        report_file: Path = target_path / "comparator_report.jsonl"
        assert not report_file.exists()

    def test_report_matches_stdout_json(self, tmp_path, capsys):
        """In json mode, report content matches stdout output."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a"])
        report_file: Path = tmp_path / "report.jsonl"
        argv = _make_argv(
            baseline_path,
            target_path,
            preset="raw",
            output_format="json",
            report_path=str(report_file),
        )
        # Drain any output captured so far so the next read is run()'s alone.
        capsys.readouterr()
        run(parse_args(argv))
        stdout_lines: list[str] = capsys.readouterr().out.strip().splitlines()
        report_lines: list[str] = report_file.read_text().strip().splitlines()
        assert stdout_lines == report_lines

    def test_text_mode_also_writes_report(self, tmp_path, capsys):
        """Text stdout mode still writes JSONL report."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a"])
        report_file: Path = tmp_path / "report.jsonl"
        argv = _make_argv(
            baseline_path,
            target_path,
            preset="raw",
            output_format="text",
            report_path=str(report_file),
        )
        run(parse_args(argv))
        assert report_file.exists()
        report_records: list[AnyRecord] = _parse_jsonl(report_file.read_text())
        assert isinstance(report_records[0], ConfigRecord)
        assert isinstance(report_records[-1], SummaryRecord)

    def test_streaming_flush(self, tmp_path, capsys):
        """Report file is flushed after each record (readable before close)."""
        from sglang.srt.debug_utils.comparator.report_sink import report_sink

        report_file: Path = tmp_path / "stream_report.jsonl"
        report_sink.configure(
            output_format="json",
            report_path=report_file,
        )
        # Read the file immediately after add() — before any explicit close —
        # to verify the record was flushed to disk.
        report_sink.add(ConfigRecord(config={"test": True}))
        content: str = report_file.read_text()
        assert len(content.strip().splitlines()) == 1
        parsed: AnyRecord = parse_record_json(content.strip())
        assert isinstance(parsed, ConfigRecord)


class TestEntrypointDpAttentionMissingAlias:
    """Regression: dp-attention without ``# dp:=attn_dp`` → shape mismatch failure.

    In dp-attention mode (tp_size=2, attn_dp_size=2), layer_input is dumped
    after prepare_attn which DP-distributes tokens. One rank gets 0 tokens
    (shape [0, H]), the other gets all tokens (shape [T, H]). Without
    ``# dp:=attn_dp`` in dims, the comparator has no dp_rank/dp_size to filter
    on, so it picks one rank via TP pick — potentially the empty one — causing
    a shape mismatch with the baseline.
    """

    @staticmethod
    def _sglang_dp_attn_parallel_info(*, tp_rank: int) -> dict:
        # Full sglang parallel-info dict for a 2-way dp-attention setup; the
        # attn_dp_* fields mirror tp_rank because attn-DP reuses the TP group.
        return {
            "tp_rank": tp_rank,
            "tp_size": 2,
            "pp_rank": 0,
            "pp_size": 1,
            "moe_ep_rank": 0,
            "moe_ep_size": 1,
            "moe_tp_rank": tp_rank,
            "moe_tp_size": 2,
            "moe_dp_rank": 0,
            "moe_dp_size": 1,
            "enable_dp_attention": True,
            "attn_tp_rank": 0,
            "attn_tp_size": 1,
            "attn_dp_rank": tp_rank,
            "attn_dp_size": 2,
            "local_attn_dp_rank": tp_rank,
            "local_attn_dp_size": 2,
            "attn_cp_rank": 0,
            "attn_cp_size": 1,
        }

    def test_missing_dp_alias_causes_shape_mismatch(
        self, tmp_path: Path, capsys
    ) -> None:
        """dims='t h' (no dp:=attn_dp) → comparator picks empty rank → shape_mismatch failure."""
        torch.manual_seed(42)
        tensor_data: torch.Tensor = torch.randn(5, 8)
        target_data: torch.Tensor = tensor_data + torch.randn(5, 8) * 0.001
        for side_name, data in [("baseline", tensor_data), ("target", target_data)]:
            side_dir: Path = tmp_path / side_name
            side_dir.mkdir()
            # Baseline: single rank, no DP attention
            if side_name == "baseline":
                _create_rank_dump(
                    side_dir,
                    rank=0,
                    name="layer_input",
                    tensor=data,
                    dims="t h",
                    parallel_info={"tp_rank": 0, "tp_size": 1},
                    framework="sglang",
                )
            else:
                # Target: dp-attention, tp_rank=0 gets 0 tokens, tp_rank=1 gets all
                _create_rank_dump(
                    side_dir,
                    rank=0,
                    name="layer_input",
                    tensor=torch.empty(0, 8),
                    dims="t h",
                    parallel_info=self._sglang_dp_attn_parallel_info(tp_rank=0),
                    framework="sglang",
                )
                _create_rank_dump(
                    side_dir,
                    rank=1,
                    name="layer_input",
                    tensor=data,
                    dims="t h",
                    parallel_info=self._sglang_dp_attn_parallel_info(tp_rank=1),
                    framework="sglang",
                )
        argv: list[str] = _make_argv(
            tmp_path / "baseline" / _FIXED_EXP_NAME,
            tmp_path / "target" / _FIXED_EXP_NAME,
            diff_threshold=1e-3,
        )
        records, exit_code = _run_and_parse(argv, capsys)
        assert exit_code == 1
        errors = [r for r in records if isinstance(r, ComparisonErrorRecord)]
        assert len(errors) == 1
        assert errors[0].category == "errored"


class TestEntrypointAutoDescend:
    """Test auto-descend: --baseline-path / --target-path pointing to a parent
    directory that contains a single subdirectory with .pt files."""

    def test_auto_descend_single_engine(self, tmp_path: Path, capsys) -> None:
        """Parent dir wrapping a single engine subdir is auto-descended and comparison succeeds."""
        baseline_exp, target_exp = _create_dumps(tmp_path, ["tensor_a"])
        # Wrap each experiment dir one level deeper (simulates engine_0 layout).
        baseline_wrapper: Path = tmp_path / "baseline_wrap"
        target_wrapper: Path = tmp_path / "target_wrap"
        baseline_wrapper.mkdir()
        target_wrapper.mkdir()
        baseline_exp.rename(baseline_wrapper / "engine_0")
        target_exp.rename(target_wrapper / "engine_0")
        argv = _make_argv(baseline_wrapper, target_wrapper, preset="raw")
        records, exit_code = _run_and_parse(argv, capsys)
        assert exit_code == 0
        _assert_single_comparison_passed(records)

    def test_no_descend_when_pt_at_root(self, tmp_path: Path, capsys) -> None:
        """Direct .pt files — no descend needed, comparison still works."""
        baseline_exp, target_exp = _create_dumps(tmp_path, ["tensor_a"])
        argv = _make_argv(baseline_exp, target_exp, preset="raw")
        records, exit_code = _run_and_parse(argv, capsys)
        assert exit_code == 0
        _assert_single_comparison_passed(records)

    def test_auto_descend_emits_log_record(self, tmp_path: Path, capsys) -> None:
        """Auto-descend emits a LogRecord with the info message."""
        baseline_exp, target_exp = _create_dumps(tmp_path, ["tensor_a"])
        wrapper: Path = tmp_path / "target_wrap"
        wrapper.mkdir()
        target_exp.rename(wrapper / "engine_0")
        argv = _make_argv(baseline_exp, wrapper, preset="raw")
        records, _ = _run_and_parse(argv, capsys)
        log_records: list[LogRecord] = [r for r in records if isinstance(r, LogRecord)]
        # Collect only the info messages that mention the auto-descend step.
        auto_descend_msgs: list[str] = [
            info.message
            for lr in log_records
            for info in lr.infos
            if "auto-descend" in info.message
        ]
        assert any("target_path" in m for m in auto_descend_msgs)

    def test_auto_descend_single_nonempty_among_empty(
        self, tmp_path: Path, capsys
    ) -> None:
        """Two subdirs but only one has .pt — auto-descend picks the non-empty one."""
        baseline_exp, target_exp = _create_dumps(tmp_path, ["tensor_a"])
wrapper: Path = tmp_path / "target_wrap" wrapper.mkdir() target_exp.rename(wrapper / "engine_0") (wrapper / "empty_subdir").mkdir() argv = _make_argv(baseline_exp, wrapper, preset="raw") records, exit_code = _run_and_parse(argv, capsys) assert exit_code == 0 _assert_single_comparison_passed(records) def test_error_multiple_nonempty_subdirs(self, tmp_path: Path) -> None: """Two subdirs both with .pt — raises ValueError with clear message.""" baseline_exp, target_exp = _create_dumps(tmp_path, ["tensor_a"]) wrapper: Path = tmp_path / "target_wrap" wrapper.mkdir() target_exp.rename(wrapper / "engine_0") engine_1: Path = wrapper / "engine_1" engine_1.mkdir() torch.save(torch.tensor([1.0]), engine_1 / "dummy.pt") argv: list[str] = _make_argv(baseline_exp, wrapper, preset="raw") with pytest.raises(ValueError, match="multiple subdirectories contain data"): run(parse_args(argv)) def test_error_no_data_found(self, tmp_path: Path) -> None: """No .pt files anywhere — raises ValueError.""" baseline_exp, _ = _create_dumps(tmp_path, ["tensor_a"]) empty_dir: Path = tmp_path / "empty_target" empty_dir.mkdir() (empty_dir / "subdir").mkdir() argv: list[str] = _make_argv(baseline_exp, empty_dir, preset="raw") with pytest.raises(ValueError, match="no .pt files found"): run(parse_args(argv)) class TestErrorResilience: """Bundle comparison exception → continue with remaining bundles.""" def test_one_bundle_errors_others_continue(self, tmp_path, capsys, monkeypatch): """One bundle raises exception → other bundles still compared, summary correct.""" baseline_path, target_path = _create_dumps( tmp_path, ["tensor_a", "tensor_b", "tensor_c"] ) argv = _make_argv(baseline_path, target_path, preset="raw") original = _entrypoint_module.compare_bundle_pair def _patched(**kwargs): if kwargs["name"] == "tensor_b": raise RuntimeError("intentional test error") return original(**kwargs) monkeypatch.setattr(_entrypoint_module, "compare_bundle_pair", _patched) records, exit_code = _run_and_parse(argv, 
            capsys)
        comparisons = _get_comparisons(records)
        assert len(comparisons) == 2
        errors = [r for r in records if isinstance(r, ComparisonErrorRecord)]
        assert len(errors) == 1
        assert errors[0].name == "tensor_b"
        assert errors[0].exception_type == "RuntimeError"
        assert "intentional test error" in errors[0].traceback_str
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        # The errored bundle still counts toward the total.
        assert summary.errored == 1
        assert summary.passed == 2
        assert summary.total == 3
        assert exit_code == 1

    def test_all_bundles_error_exits_one(self, tmp_path, capsys, monkeypatch):
        """All bundles error → exit 1, summary all errored."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a"])
        argv = _make_argv(baseline_path, target_path, preset="raw")

        def _always_raise(**kwargs):
            raise ValueError("always fail")

        monkeypatch.setattr(_entrypoint_module, "compare_bundle_pair", _always_raise)
        records, exit_code = _run_and_parse(argv, capsys)
        summary = records[-1]
        assert isinstance(summary, SummaryRecord)
        assert summary.errored == 1
        assert summary.passed == 0
        assert exit_code == 1

    def test_error_record_json_roundtrip_in_output(self, tmp_path, capsys, monkeypatch):
        """ComparisonErrorRecord correctly serializes and deserializes in output."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a"])
        argv = _make_argv(baseline_path, target_path, preset="raw")

        def _raise(**kwargs):
            raise TypeError("bad type")

        monkeypatch.setattr(_entrypoint_module, "compare_bundle_pair", _raise)
        records, _ = _run_and_parse(argv, capsys)
        errors = [r for r in records if isinstance(r, ComparisonErrorRecord)]
        assert len(errors) == 1
        assert errors[0].exception_type == "TypeError"


if __name__ == "__main__":
    sys.exit(pytest.main([__file__]))