# Source: Hanrui / sglang / test / registered / debug_utils / test_dump_comparator.py
# Uploaded by: Lekr0
# Commit message: Add files using upload-large-folder tool
# Commit: 61ba51e (verified)
from argparse import Namespace
from pathlib import Path
import pytest
import torch
from sglang.srt.debug_utils.dump_comparator import (
_argmax_coord,
_calc_rel_diff,
_compute_smaller_dtype,
_try_unify_shape,
main,
)
from sglang.srt.debug_utils.dumper import DumperConfig, _Dumper
from sglang.test.ci.ci_register import register_cpu_ci
# Register this test module with the CPU CI registry under the "default" suite,
# flagged as nightly.
# NOTE(review): est_time is presumably an estimated runtime in seconds — confirm
# against sglang.test.ci.ci_register.
register_cpu_ci(est_time=30, suite="default", nightly=True)
# ----------------------------- Unit tests -----------------------------
class TestCalcRelDiff:
    """Unit tests for ``_calc_rel_diff``."""

    def test_identical_vectors(self) -> None:
        """A tensor compared with itself yields a relative difference of ~0."""
        x: torch.Tensor = torch.randn(10, 10)
        assert _calc_rel_diff(x, x).item() == pytest.approx(0.0, abs=1e-5)

    def test_zero_vectors(self) -> None:
        """All-zero inputs must be handled without raising.

        This is a smoke test. The previous assertion ended in ``or True`` and
        therefore could never fail; it is replaced by explicit checks of the
        only properties the old expression actually depended on.
        """
        z: torch.Tensor = torch.zeros(5)
        result = _calc_rel_diff(z, z)
        # The numeric value for all-zero inputs is intentionally not pinned
        # (the earlier check tolerated NaN); only the result's form is verified.
        assert isinstance(result, torch.Tensor)
        assert result.numel() == 1
class TestArgmaxCoord:
    """Unit tests for ``_argmax_coord``."""

    def test_known_position(self) -> None:
        """The coordinate of the single non-zero entry is reported."""
        peak = (1, 2, 3)
        x: torch.Tensor = torch.zeros(2, 3, 4)
        # Plant one clear maximum at a known index.
        x[peak] = 10.0
        assert _argmax_coord(x) == peak
class TestTryUnifyShape:
    """Unit tests for ``_try_unify_shape``."""

    def test_squeeze_leading_ones(self) -> None:
        """A (1, 1, 3, 4) input comes back with the (3, 4) target shape."""
        wanted: torch.Size = torch.Size([3, 4])
        padded: torch.Tensor = torch.randn(1, 1, 3, 4)
        unified: torch.Tensor = _try_unify_shape(padded, wanted)
        assert unified.shape == wanted

    def test_no_op_when_no_leading_ones(self) -> None:
        """A (2, 3, 4) input keeps its shape when the target is (3, 4)."""
        wanted: torch.Size = torch.Size([3, 4])
        batched: torch.Tensor = torch.randn(2, 3, 4)
        unified: torch.Tensor = _try_unify_shape(batched, wanted)
        assert unified.shape == (2, 3, 4)
class TestComputeSmallerDtype:
    """Unit tests for ``_compute_smaller_dtype``."""

    def test_known_pair(self) -> None:
        """float32 paired with bfloat16 gives bfloat16 in either argument order."""
        ordered_pairs = (
            (torch.float32, torch.bfloat16),
            (torch.bfloat16, torch.float32),
        )
        for first, second in ordered_pairs:
            assert _compute_smaller_dtype(first, second) == torch.bfloat16

    def test_none_for_same_dtype(self) -> None:
        """Two identical dtypes give ``None``."""
        same = torch.float32
        assert _compute_smaller_dtype(same, same) is None
# ----------------------------- Integration tests -----------------------------
def _make_dumper(directory: Path) -> _Dumper:
    """Create an enabled ``_Dumper`` whose config points at *directory*."""
    dumper_config = DumperConfig(enable=True, dir=str(directory))
    return _Dumper(config=dumper_config)
def _create_dumps(
    tmp_path: Path,
    tensor_names: list[str],
    *,
    baseline_names: list[str] | None = None,
) -> tuple[Path, Path]:
    """Write a baseline dump and a slightly perturbed target dump under *tmp_path*.

    Args:
        tmp_path: Parent directory; ``baseline`` and ``target`` are created in it.
        tensor_names: Names dumped on the target side.
        baseline_names: Names dumped on the baseline side; falls back to
            *tensor_names* when ``None``.

    Returns:
        ``(baseline_path, target_path)``, each being ``<dir> / <exp_name>`` as
        read from the corresponding dumper's config.
    """
    names_for_baseline = tensor_names if baseline_names is None else baseline_names

    baseline_dir: Path = tmp_path / "baseline"
    target_dir: Path = tmp_path / "target"
    for directory in (baseline_dir, target_dir):
        directory.mkdir()

    # Fixed seed so both tensors are reproducible across runs.
    torch.manual_seed(42)
    reference: torch.Tensor = torch.randn(10, 10)
    perturbed: torch.Tensor = reference + torch.randn(10, 10) * 0.01

    def _write_dump(directory: Path, names: list[str], tensor: torch.Tensor) -> Path:
        # Dump the same tensor under every name, then advance one step.
        dumper: _Dumper = _make_dumper(directory)
        for name in names:
            dumper.dump(name, tensor)
        dumper.step()
        return directory / dumper._config.exp_name

    baseline_exp = _write_dump(baseline_dir, names_for_baseline, reference)
    target_exp = _write_dump(target_dir, tensor_names, perturbed)
    return baseline_exp, target_exp
def _make_args(
baseline_path: Path,
target_path: Path,
*,
filter_pattern: str | None = None,
) -> Namespace:
return Namespace(
baseline_path=str(baseline_path),
target_path=str(target_path),
start_step=0,
end_step=1000000,
diff_threshold=1e-3,
filter=filter_pattern,
)
class TestMainBasic:
    """Integration tests that run ``main`` on freshly written dumps and inspect stdout."""

    def test_matching_tensors(
        self, tmp_path: Path, capsys: pytest.CaptureFixture
    ) -> None:
        """Dumps sharing the same names produce a check-mark in the output."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a", "tensor_b"])
        main(_make_args(baseline_path, target_path))
        stdout: str = capsys.readouterr().out
        assert "✅" in stdout

    def test_with_filter(self, tmp_path: Path, capsys: pytest.CaptureFixture) -> None:
        """With a ``tensor_a`` filter, the output names it and contains a check line."""
        baseline_path, target_path = _create_dumps(tmp_path, ["tensor_a", "tensor_b"])
        main(_make_args(baseline_path, target_path, filter_pattern="tensor_a"))
        stdout: str = capsys.readouterr().out
        for expected in ("tensor_a", "Check:"):
            assert expected in stdout

    def test_no_match_skips(
        self, tmp_path: Path, capsys: pytest.CaptureFixture
    ) -> None:
        """Disjoint baseline/target names lead to a "Skip" message."""
        baseline_path, target_path = _create_dumps(
            tmp_path,
            ["only_in_target"],
            baseline_names=["only_in_baseline"],
        )
        main(_make_args(baseline_path, target_path))
        stdout: str = capsys.readouterr().out
        assert "Skip" in stdout