Hanrui / sglang /test /registered /debug_utils /comparator /test_visualizer.py
Lekr0's picture
Add files using upload-large-folder tool
a402b9b verified
import sys
from pathlib import Path
import pytest
import torch
from sglang.srt.debug_utils.comparator.visualizer.preprocessing import (
_preprocess_tensor,
_reshape_to_balanced_aspect,
)
from sglang.test.ci.ci_register import register_cpu_ci
# Register this test module with the CPU CI via the project's ci_register helper.
# NOTE(review): est_time=30 is presumably an estimated runtime in seconds, and
# nightly=True presumably schedules the file in the nightly run of the "default"
# suite — confirm against sglang.test.ci.ci_register.
register_cpu_ci(est_time=30, suite="default", nightly=True)
class TestPreprocessTensor:
    """Checks that ``_preprocess_tensor`` yields a 2-D tensor for inputs of any rank."""

    def test_1d_becomes_2d(self) -> None:
        """A 100-element vector comes back with two dimensions."""
        vector = torch.randn(100)
        processed = _preprocess_tensor(vector)
        assert processed.ndim == 2

    def test_3d_becomes_2d(self) -> None:
        """A rank-3 input is returned as 2-D with the same element count."""
        cube = torch.randn(2, 3, 4)
        processed = _preprocess_tensor(cube)
        assert processed.ndim == 2
        assert processed.numel() == cube.numel()

    def test_high_dim_becomes_2d(self) -> None:
        """A rank-4 input is returned as 2-D with the same element count."""
        hypercube = torch.randn(2, 3, 4, 5)
        processed = _preprocess_tensor(hypercube)
        assert processed.ndim == 2
        assert processed.numel() == hypercube.numel()

    def test_scalar_becomes_2d(self) -> None:
        """A 0-d tensor is returned as a 2-D tensor holding a single element."""
        scalar = torch.tensor(3.14)
        processed = _preprocess_tensor(scalar)
        assert processed.ndim == 2
        assert processed.numel() == 1

    def test_already_2d_preserves_elements(self) -> None:
        """A 10x20 matrix stays 2-D and keeps all 200 elements."""
        matrix = torch.randn(10, 20)
        processed = _preprocess_tensor(matrix)
        assert processed.ndim == 2
        assert processed.numel() == 200
class TestReshapeToBalancedAspect:
    """Checks the shapes returned by ``_reshape_to_balanced_aspect``."""

    @staticmethod
    def _long_to_short_ratio(tensor: torch.Tensor) -> float:
        """Return longer side / shorter side of a 2-D tensor (shorter side floored at 1)."""
        rows, cols = tensor.shape
        longer = max(rows, cols)
        # Floor the divisor at 1 so a zero-length side cannot divide by zero.
        shorter = max(min(rows, cols), 1)
        return longer / shorter

    def test_extreme_wide_gets_fixed(self) -> None:
        """A 1x10000 row is reshaped so its aspect ratio is at most 5."""
        wide = torch.randn(1, 10000)
        reshaped = _reshape_to_balanced_aspect(wide)
        assert self._long_to_short_ratio(reshaped) <= 5.0

    def test_extreme_tall_gets_fixed(self) -> None:
        """A 10000x1 column is reshaped so its aspect ratio is at most 5."""
        tall = torch.randn(10000, 1)
        reshaped = _reshape_to_balanced_aspect(tall)
        assert self._long_to_short_ratio(reshaped) <= 5.0

    def test_already_balanced_unchanged(self) -> None:
        """A square 100x100 input keeps its shape."""
        square = torch.randn(100, 100)
        reshaped = _reshape_to_balanced_aspect(square)
        assert tuple(reshaped.shape) == (100, 100)

    def test_preserves_numel(self) -> None:
        """The element count is unchanged for a 1x7919 input."""
        row = torch.randn(1, 7919)
        reshaped = _reshape_to_balanced_aspect(row)
        assert reshaped.numel() == row.numel()
class TestGenerateComparisonFigure:
    """Checks that ``generate_comparison_figure`` writes its output file."""

    @pytest.fixture(autouse=True)
    def _skip_if_no_matplotlib(self) -> None:
        """Skip every test in this class when matplotlib cannot be imported."""
        pytest.importorskip("matplotlib")

    def test_nested_output_dir(self, tmp_path: Path) -> None:
        """The figure file exists afterwards even when its parent directories did not."""
        # Imported here, after the matplotlib skip check above has run.
        from sglang.srt.debug_utils.comparator.visualizer import (
            generate_comparison_figure,
        )

        # None of the intermediate directories a/b/c exist under tmp_path yet.
        figure_file = tmp_path.joinpath("a", "b", "c", "nested.png")
        baseline_tensor = torch.randn(10, 10)
        target_tensor = torch.randn(10, 10)

        generate_comparison_figure(
            baseline=baseline_tensor,
            target=target_tensor,
            name="nested",
            output_path=figure_file,
        )

        assert figure_file.exists()
if __name__ == "__main__":
    # Run only this file's tests and hand pytest's exit code back to the shell.
    exit_code = pytest.main([__file__])
    sys.exit(exit_code)