| import sys |
| from pathlib import Path |
|
|
| import pytest |
| import torch |
|
|
| from sglang.srt.debug_utils.comparator.visualizer.preprocessing import ( |
| _preprocess_tensor, |
| _reshape_to_balanced_aspect, |
| ) |
| from sglang.test.ci.ci_register import register_cpu_ci |
|
|
# Register this test module for CPU CI: "default" suite, nightly schedule,
# with an estimated runtime of 30 (units not shown here — presumably seconds).
register_cpu_ci(
    est_time=30,
    suite="default",
    nightly=True,
)
|
|
|
|
class TestPreprocessTensor:
    """Tests for ``_preprocess_tensor``.

    Whatever the input rank (0-D scalar through 4-D), the result must be a
    2-D tensor that holds exactly as many elements as the input.
    """

    @staticmethod
    def _assert_2d_preserving_numel(t: torch.Tensor) -> None:
        """Run ``_preprocess_tensor`` on *t*; check the result is 2-D with ``t.numel()`` elements."""
        result: torch.Tensor = _preprocess_tensor(t)
        assert result.ndim == 2
        assert result.numel() == t.numel()

    def test_1d_becomes_2d(self) -> None:
        # The element-count check applies here too, so every rank gets the
        # same contract.
        self._assert_2d_preserving_numel(torch.randn(100))

    def test_3d_becomes_2d(self) -> None:
        self._assert_2d_preserving_numel(torch.randn(2, 3, 4))

    def test_high_dim_becomes_2d(self) -> None:
        self._assert_2d_preserving_numel(torch.randn(2, 3, 4, 5))

    def test_scalar_becomes_2d(self) -> None:
        # A 0-D tensor has numel() == 1, so this still requires a single element.
        self._assert_2d_preserving_numel(torch.tensor(3.14))

    def test_already_2d_preserves_elements(self) -> None:
        # 10 x 20 -> 200 elements must survive preprocessing.
        self._assert_2d_preserving_numel(torch.randn(10, 20))
|
|
|
|
class TestReshapeToBalancedAspect:
    """Tests for ``_reshape_to_balanced_aspect`` on 2-D inputs."""

    # Largest long-side / short-side ratio the extreme-shape tests accept.
    # NOTE(review): presumably mirrors the implementation's target ratio — confirm.
    _MAX_ASPECT_RATIO: float = 5.0

    @staticmethod
    def _aspect_ratio(result: torch.Tensor) -> float:
        """Return max(h, w) / min(h, w) for a 2-D tensor, clamping the short side to >= 1."""
        # Fail with a clear assertion rather than an unpacking ValueError if
        # the result is not 2-D.
        assert result.ndim == 2
        h, w = result.shape
        # max(..., 1) guards against division by zero for an empty dimension.
        return max(h, w) / max(min(h, w), 1)

    def test_extreme_wide_gets_fixed(self) -> None:
        t: torch.Tensor = torch.randn(1, 10000)
        result: torch.Tensor = _reshape_to_balanced_aspect(t)
        assert result.numel() == t.numel()
        assert self._aspect_ratio(result) <= self._MAX_ASPECT_RATIO

    def test_extreme_tall_gets_fixed(self) -> None:
        t: torch.Tensor = torch.randn(10000, 1)
        result: torch.Tensor = _reshape_to_balanced_aspect(t)
        assert result.numel() == t.numel()
        assert self._aspect_ratio(result) <= self._MAX_ASPECT_RATIO

    def test_already_balanced_unchanged(self) -> None:
        t: torch.Tensor = torch.randn(100, 100)
        result: torch.Tensor = _reshape_to_balanced_aspect(t)
        assert result.shape == (100, 100)

    def test_preserves_numel(self) -> None:
        # 7919 is prime, so no exact balanced factorization exists; only the
        # element count is asserted, not the resulting shape.
        t: torch.Tensor = torch.randn(1, 7919)
        result: torch.Tensor = _reshape_to_balanced_aspect(t)
        assert result.numel() == t.numel()
|
|
|
|
class TestGenerateComparisonFigure:
    """Tests for ``generate_comparison_figure``; skipped when matplotlib is unavailable."""

    @pytest.fixture(autouse=True)
    def _skip_if_no_matplotlib(self) -> None:
        # Autouse: runs before every test in this class and skips it when
        # matplotlib cannot be imported.
        pytest.importorskip("matplotlib")

    def test_nested_output_dir(self, tmp_path: Path) -> None:
        """Writing to a path whose parent directories do not yet exist must succeed."""
        # Imported inside the test so the skip fixture above has already run;
        # presumably the visualizer package needs matplotlib at import time.
        from sglang.srt.debug_utils.comparator.visualizer import (
            generate_comparison_figure,
        )

        # a/b/c do not exist under tmp_path before the call.
        output_path: Path = tmp_path / "a" / "b" / "c" / "nested.png"

        generate_comparison_figure(
            baseline=torch.randn(10, 10),
            target=torch.randn(10, 10),
            name="nested",
            output_path=output_path,
        )

        # A regular, non-empty file must be written: exists() alone would also
        # accept a directory or an empty/truncated output.
        assert output_path.is_file()
        assert output_path.stat().st_size > 0
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly; pytest's status becomes the process exit code.
    exit_code = pytest.main([__file__])
    sys.exit(exit_code)
|
|