# File size: 3,090 Bytes | a402b9b
import sys
from pathlib import Path
import pytest
import torch
from sglang.srt.debug_utils.comparator.visualizer.preprocessing import (
_preprocess_tensor,
_reshape_to_balanced_aspect,
)
from sglang.test.ci.ci_register import register_cpu_ci
# Module-level call to the CI registration helper imported from
# sglang.test.ci.ci_register, passing est_time=30, suite="default", nightly=True.
# NOTE(review): presumably est_time is an estimated runtime in seconds and
# nightly=True restricts this file to the nightly CPU run -- confirm against
# ci_register before relying on it.
register_cpu_ci(est_time=30, suite="default", nightly=True)
class TestPreprocessTensor:
    """Check that ``_preprocess_tensor`` yields a 2-D tensor for several input ranks."""

    @staticmethod
    def _to_2d(source: torch.Tensor) -> torch.Tensor:
        """Run ``_preprocess_tensor`` on *source* and assert the result has two dims."""
        processed = _preprocess_tensor(source)
        assert processed.ndim == 2
        return processed

    def test_1d_becomes_2d(self) -> None:
        """A 100-element vector comes back with two dimensions."""
        self._to_2d(torch.randn(100))

    def test_3d_becomes_2d(self) -> None:
        """A (2, 3, 4) tensor becomes 2-D with the same element count."""
        source = torch.randn(2, 3, 4)
        processed = self._to_2d(source)
        assert processed.numel() == source.numel()

    def test_high_dim_becomes_2d(self) -> None:
        """A (2, 3, 4, 5) tensor becomes 2-D with the same element count."""
        source = torch.randn(2, 3, 4, 5)
        processed = self._to_2d(source)
        assert processed.numel() == source.numel()

    def test_scalar_becomes_2d(self) -> None:
        """A 0-dim tensor becomes 2-D and still holds exactly one element."""
        processed = self._to_2d(torch.tensor(3.14))
        assert processed.numel() == 1

    def test_already_2d_preserves_elements(self) -> None:
        """A (10, 20) input stays 2-D and keeps all 200 elements."""
        processed = self._to_2d(torch.randn(10, 20))
        assert processed.numel() == 200
class TestReshapeToBalancedAspect:
    """Check the output shape and element count of ``_reshape_to_balanced_aspect``."""

    @staticmethod
    def _aspect_ratio(tensor: torch.Tensor) -> float:
        """Return longer side / shorter side of a 2-D tensor.

        The shorter side is clamped to at least 1 so the division is always defined.
        """
        rows, cols = tensor.shape
        longer = max(rows, cols)
        shorter = max(min(rows, cols), 1)
        return longer / shorter

    def test_extreme_wide_gets_fixed(self) -> None:
        """A (1, 10000) input is reshaped to an aspect ratio of at most 5."""
        reshaped = _reshape_to_balanced_aspect(torch.randn(1, 10000))
        assert self._aspect_ratio(reshaped) <= 5.0

    def test_extreme_tall_gets_fixed(self) -> None:
        """A (10000, 1) input is reshaped to an aspect ratio of at most 5."""
        reshaped = _reshape_to_balanced_aspect(torch.randn(10000, 1))
        assert self._aspect_ratio(reshaped) <= 5.0

    def test_already_balanced_unchanged(self) -> None:
        """A square (100, 100) input keeps its shape."""
        reshaped = _reshape_to_balanced_aspect(torch.randn(100, 100))
        assert reshaped.shape == (100, 100)

    def test_preserves_numel(self) -> None:
        """The element count is unchanged for a (1, 7919) input (7919 is prime)."""
        source = torch.randn(1, 7919)
        reshaped = _reshape_to_balanced_aspect(source)
        assert reshaped.numel() == source.numel()
class TestGenerateComparisonFigure:
    """Check that ``generate_comparison_figure`` writes its output file."""

    @pytest.fixture(autouse=True)
    def _skip_if_no_matplotlib(self) -> None:
        """Skip each test in this class when matplotlib cannot be imported."""
        pytest.importorskip("matplotlib")

    def test_nested_output_dir(self, tmp_path: Path) -> None:
        """The figure file exists after targeting a path nested three directories deep."""
        # Imported inside the test so the matplotlib skip fixture runs before this import.
        from sglang.srt.debug_utils.comparator.visualizer import (
            generate_comparison_figure,
        )

        figure_file = tmp_path.joinpath("a", "b", "c", "nested.png")
        baseline_tensor = torch.randn(10, 10)
        target_tensor = torch.randn(10, 10)

        generate_comparison_figure(
            baseline=baseline_tensor,
            target=target_tensor,
            name="nested",
            output_path=figure_file,
        )

        assert figure_file.exists()
if __name__ == "__main__":
    # Run this file's tests under pytest and use pytest's return value as the exit status.
    exit_status = pytest.main([__file__])
    sys.exit(exit_status)
|