File size: 3,090 Bytes
a402b9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import sys
from pathlib import Path

import pytest
import torch

from sglang.srt.debug_utils.comparator.visualizer.preprocessing import (
    _preprocess_tensor,
    _reshape_to_balanced_aspect,
)
from sglang.test.ci.ci_register import register_cpu_ci

# Register this test module with the CPU CI harness; the arguments are passed
# straight through (NOTE(review): est_time presumably an estimated runtime in
# seconds — confirm against register_cpu_ci).
register_cpu_ci(
    suite="default",
    est_time=30,
    nightly=True,
)


class TestPreprocessTensor:
    """Tests for ``_preprocess_tensor``.

    Every case feeds a tensor of some rank (0-D up to 4-D) and asserts that the
    result is 2-D and holds the same number of elements as the input.
    """

    def test_1d_becomes_2d(self) -> None:
        t: torch.Tensor = torch.randn(100)
        result: torch.Tensor = _preprocess_tensor(t)
        assert result.ndim == 2
        # Check element count too, consistent with the other rank cases.
        assert result.numel() == t.numel()

    def test_3d_becomes_2d(self) -> None:
        t: torch.Tensor = torch.randn(2, 3, 4)
        result: torch.Tensor = _preprocess_tensor(t)
        assert result.ndim == 2
        assert result.numel() == t.numel()

    def test_high_dim_becomes_2d(self) -> None:
        t: torch.Tensor = torch.randn(2, 3, 4, 5)
        result: torch.Tensor = _preprocess_tensor(t)
        assert result.ndim == 2
        assert result.numel() == t.numel()

    def test_scalar_becomes_2d(self) -> None:
        t: torch.Tensor = torch.tensor(3.14)
        result: torch.Tensor = _preprocess_tensor(t)
        assert result.ndim == 2
        # A 0-D tensor holds exactly one element.
        assert result.numel() == 1

    def test_already_2d_preserves_elements(self) -> None:
        t: torch.Tensor = torch.randn(10, 20)
        result: torch.Tensor = _preprocess_tensor(t)
        assert result.ndim == 2
        assert result.numel() == 200


class TestReshapeToBalancedAspect:
    """Tests for ``_reshape_to_balanced_aspect`` on 2-D inputs.

    Extreme row/column vectors must come back with a bounded aspect ratio, and
    every case must keep the input's element count.
    """

    # Largest long-side / short-side ratio accepted after rebalancing.
    MAX_ASPECT_RATIO: float = 5.0

    @staticmethod
    def _aspect_ratio(t: torch.Tensor) -> float:
        """Return the long-side / short-side ratio of a 2-D tensor.

        The short side is clamped to at least 1 so an empty dimension cannot
        divide by zero. Unpacking ``t.shape`` into two names also makes the
        check fail loudly if the result is not 2-D.
        """
        h, w = t.shape
        return max(h, w) / max(min(h, w), 1)

    def test_extreme_wide_gets_fixed(self) -> None:
        t: torch.Tensor = torch.randn(1, 10000)
        result: torch.Tensor = _reshape_to_balanced_aspect(t)
        # Rebalancing must not change how many elements there are.
        assert result.numel() == t.numel()
        assert self._aspect_ratio(result) <= self.MAX_ASPECT_RATIO

    def test_extreme_tall_gets_fixed(self) -> None:
        t: torch.Tensor = torch.randn(10000, 1)
        result: torch.Tensor = _reshape_to_balanced_aspect(t)
        assert result.numel() == t.numel()
        assert self._aspect_ratio(result) <= self.MAX_ASPECT_RATIO

    def test_already_balanced_unchanged(self) -> None:
        t: torch.Tensor = torch.randn(100, 100)
        result: torch.Tensor = _reshape_to_balanced_aspect(t)
        assert result.shape == (100, 100)
        # "Unchanged" covers contents as well as shape.
        assert torch.equal(result, t)

    def test_preserves_numel(self) -> None:
        # 7919 is prime, so no exact 2-D factorization other than 1 x 7919 exists.
        t: torch.Tensor = torch.randn(1, 7919)
        result: torch.Tensor = _reshape_to_balanced_aspect(t)
        assert result.numel() == t.numel()


class TestGenerateComparisonFigure:
    """End-to-end check that ``generate_comparison_figure`` writes its output file."""

    @pytest.fixture(autouse=True)
    def _skip_if_no_matplotlib(self) -> None:
        # Applied to every test in this class: skip when matplotlib is missing.
        pytest.importorskip("matplotlib")

    def test_nested_output_dir(self, tmp_path: Path) -> None:
        """Writing under a/b/c inside a fresh tmp_path must still produce the file."""
        # Imported inside the test (after the matplotlib skip fixture has run);
        # presumably so collection does not require matplotlib — confirm.
        from sglang.srt.debug_utils.comparator.visualizer import (
            generate_comparison_figure,
        )

        destination: Path = tmp_path.joinpath("a", "b", "c", "nested.png")
        baseline_tensor: torch.Tensor = torch.randn(10, 10)
        target_tensor: torch.Tensor = torch.randn(10, 10)

        generate_comparison_figure(
            baseline=baseline_tensor,
            target=target_tensor,
            name="nested",
            output_path=destination,
        )

        assert destination.exists()


if __name__ == "__main__":
    # Running the file directly hands it to pytest and exits with pytest's code.
    exit_code = pytest.main([__file__])
    sys.exit(exit_code)