File size: 16,312 Bytes
a402b9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
import sys
from pathlib import Path

import pytest
import torch

from sglang.srt.debug_utils.comparator.output_types import SummaryRecord
from sglang.srt.debug_utils.comparator.utils import (
    Pair,
    argmax_coord,
    auto_descend_dir,
    calc_per_token_rel_diff,
    calc_rel_diff,
    compute_exit_code,
    compute_smaller_dtype,
    try_unify_shape,
)
from sglang.test.ci.ci_register import register_cpu_ci

# Module-level CI registration for this test file (suite "default", nightly=True).
# NOTE(review): presumably consumed by the CPU CI harness to schedule this file;
# the unit of est_time is not visible here - confirm in sglang.test.ci.ci_register.
register_cpu_ci(est_time=10, suite="default", nightly=True)


class TestCalcRelDiff:
    """Checks calc_rel_diff against a few hand-picked tensor pairs."""

    def test_identical_tensors(self):
        """A tensor compared with itself gives a diff of ~0."""
        tensor = torch.randn(10, 10)
        diff = calc_rel_diff(tensor, tensor).item()
        assert diff == pytest.approx(0.0, abs=1e-5)

    def test_orthogonal_tensors(self):
        """Two orthogonal unit vectors give a diff of ~1."""
        first = torch.tensor([1.0, 0.0])
        second = torch.tensor([0.0, 1.0])
        diff = calc_rel_diff(first, second).item()
        assert diff == pytest.approx(1.0, abs=1e-5)

    def test_similar_tensors(self):
        """A small uniform offset gives a diff strictly between 0 and 0.01."""
        base = torch.tensor([1.0, 2.0, 3.0])
        nudged = torch.tensor([1.01, 2.01, 3.01])
        diff = calc_rel_diff(base, nudged).item()
        assert diff > 0.0
        assert diff < 0.01

    def test_negated_tensors(self):
        """A tensor compared with its negation gives a diff of ~2."""
        tensor = torch.tensor([1.0, 2.0])
        diff = calc_rel_diff(tensor, -tensor).item()
        assert diff == pytest.approx(2.0, abs=1e-5)


class TestCalcPerTokenRelDiff:
    """Checks the shape and values returned by calc_per_token_rel_diff."""

    def test_identical_tensors(self) -> None:
        """Identical tensors -> per-token diff is all zero."""
        num_tokens = 8
        tensor: torch.Tensor = torch.randn(num_tokens, 16)

        diffs: torch.Tensor = calc_per_token_rel_diff(tensor, tensor, seq_dim=0)

        assert diffs.shape == (num_tokens,)
        assert torch.allclose(diffs, torch.zeros(num_tokens), atol=1e-6)

    def test_different_tensors(self) -> None:
        """Single token position differs -> that position has the higher diff."""
        torch.manual_seed(42)
        num_tokens = 8
        changed_row = 3
        baseline: torch.Tensor = torch.randn(num_tokens, 16)
        target: torch.Tensor = baseline.clone()
        target[changed_row, :] += 10.0

        diffs: torch.Tensor = calc_per_token_rel_diff(baseline, target, seq_dim=0)

        assert diffs.shape == (num_tokens,)
        assert diffs[changed_row] > diffs[0]
        assert diffs[changed_row] > diffs[7]
        # Every row that was left untouched must report an (almost) zero diff.
        untouched_rows = [row for row in range(num_tokens) if row != changed_row]
        for row in untouched_rows:
            assert diffs[row] < 1e-6

    def test_seq_dim_selection(self) -> None:
        """Each seq_dim value yields an output sized like that dimension."""
        dims = (4, 8, 16)
        baseline: torch.Tensor = torch.randn(*dims)
        target: torch.Tensor = baseline + torch.randn_like(baseline) * 0.01

        for seq_dim, expected_len in enumerate(dims):
            diffs = calc_per_token_rel_diff(baseline, target, seq_dim=seq_dim)
            assert diffs.shape == (expected_len,)

    def test_1d_tensor(self) -> None:
        """1D tensor with seq_dim=0 returns a per-element diff."""
        baseline: torch.Tensor = torch.tensor([1.0, 2.0, 3.0])
        target: torch.Tensor = torch.tensor([1.0, 2.0, 4.0])

        diffs: torch.Tensor = calc_per_token_rel_diff(baseline, target, seq_dim=0)

        assert diffs.shape == (3,)
        # Only the last element differs between the two inputs.
        assert diffs[0] < 1e-6
        assert diffs[1] < 1e-6
        assert diffs[2] > 0.01


class TestArgmaxCoord:
    """Checks that argmax_coord returns the coordinate tuple of the largest entry."""

    def test_1d_tensor(self):
        """The peak of a 1D tensor is reported as a 1-tuple."""
        values = torch.tensor([0.0, 0.0, 5.0, 0.0])
        coord = argmax_coord(values)
        assert coord == (2,)

    def test_2d_tensor(self):
        """The peak of a 2D tensor is reported as (row, col)."""
        peak = (1, 2)
        grid = torch.zeros(3, 4)
        grid[peak] = 10.0
        assert argmax_coord(grid) == peak

    def test_3d_tensor(self):
        """The peak of a 3D tensor is reported as a 3-tuple."""
        peak = (1, 2, 3)
        volume = torch.zeros(2, 3, 4)
        volume[peak] = 10.0
        assert argmax_coord(volume) == peak


class TestTryUnifyShape:
    """Checks when try_unify_shape does and does not reshape its input."""

    def test_squeeze_leading_ones(self):
        """Leading size-1 dims are dropped when the remainder matches the target."""
        wanted = torch.Size([3, 4])
        unified = try_unify_shape(torch.randn(1, 1, 3, 4), wanted)
        assert unified.shape == wanted

    def test_no_squeeze_when_leading_dim_not_one(self):
        """A leading dim other than 1 leaves the shape unchanged."""
        wanted = torch.Size([3, 4])
        unified = try_unify_shape(torch.randn(2, 3, 4), wanted)
        assert unified.shape == (2, 3, 4)

    def test_same_shape_noop(self):
        """An already-matching tensor keeps its shape and its storage pointer."""
        wanted = torch.Size([3, 4])
        tensor = torch.randn(3, 4)
        unified = try_unify_shape(tensor, wanted)
        assert unified.shape == wanted
        assert unified.data_ptr() == tensor.data_ptr()

    def test_trailing_dims_mismatch(self):
        """Mismatching trailing dims leave the shape unchanged."""
        wanted = torch.Size([5, 6])
        tensor = torch.randn(1, 3, 4)
        unified = try_unify_shape(tensor, wanted)
        assert unified.shape == (1, 3, 4)


class TestComputeSmallerDtype:
    """Checks compute_smaller_dtype for a few dtype pairs."""

    def test_float32_bfloat16(self):
        """(float32, bfloat16) resolves to bfloat16."""
        dtypes = Pair(x=torch.float32, y=torch.bfloat16)
        assert compute_smaller_dtype(dtypes) == torch.bfloat16

    def test_reverse_order(self):
        """(bfloat16, float32) also resolves to bfloat16."""
        dtypes = Pair(x=torch.bfloat16, y=torch.float32)
        assert compute_smaller_dtype(dtypes) == torch.bfloat16

    def test_same_dtype_returns_none(self):
        """Two equal dtypes give None."""
        dtypes = Pair(x=torch.float32, y=torch.float32)
        assert compute_smaller_dtype(dtypes) is None

    def test_unknown_pair_returns_none(self):
        """The (int32, int64) pair gives None."""
        dtypes = Pair(x=torch.int32, y=torch.int64)
        assert compute_smaller_dtype(dtypes) is None


class TestPairMap:
    """Checks that Pair.map applies a callable to both x and y."""

    def test_map_basic(self):
        """Summing each side gives the per-side totals."""

        def total(values):
            return sum(values)

        mapped = Pair(x=[1, 2, 3], y=[4, 5, 6]).map(total)
        assert (mapped.x, mapped.y) == (6, 15)

    def test_map_type_change(self):
        """The mapped values may have a different type than the inputs."""
        mapped = Pair(x=[1, 2, 3], y=[10, 20]).map(len)
        assert mapped.x == 3
        assert mapped.y == 2

    def test_map_returns_new_pair(self):
        """map returns an object distinct from the pair it was called on."""
        source = Pair(x="hello", y="world")
        mapped = source.map(str.upper)
        assert mapped.x == "HELLO"
        assert mapped.y == "WORLD"
        assert mapped is not source


class TestComputeExitCode:
    """Unit tests for compute_exit_code logic."""

    # Patterns shared by most cases below.
    _ALLOW_ALL = ".*"
    _FORBID_ALL = "^$"

    @staticmethod
    def _exit_code(summary, *, skip_pattern, skipped, fail_pattern, failed, **extra):
        """Forward to compute_exit_code; ``extra`` carries optional keywords such as errored_names."""
        return compute_exit_code(
            summary,
            allow_skipped_pattern=skip_pattern,
            skipped_names=skipped,
            allow_failed_pattern=fail_pattern,
            failed_names=failed,
            **extra,
        )

    def test_all_passed(self):
        """All passed -> exit 0."""
        code = self._exit_code(
            SummaryRecord(total=3, passed=3, failed=0, skipped=0),
            skip_pattern=self._ALLOW_ALL,
            skipped=[],
            fail_pattern=None,
            failed=[],
        )
        assert code == 0

    def test_has_failed_and_passed(self):
        """Has failed and passed -> exit 1."""
        code = self._exit_code(
            SummaryRecord(total=4, passed=2, failed=2, skipped=0),
            skip_pattern=self._ALLOW_ALL,
            skipped=[],
            fail_pattern=None,
            failed=["a", "b"],
        )
        assert code == 1

    def test_all_failed(self):
        """All failed (0 passed) -> exit 1."""
        code = self._exit_code(
            SummaryRecord(total=3, passed=0, failed=3, skipped=0),
            skip_pattern=self._ALLOW_ALL,
            skipped=[],
            fail_pattern=None,
            failed=["a", "b", "c"],
        )
        assert code == 1

    def test_all_skipped_allow_all(self):
        """All skipped + allow_skipped_pattern='.*' -> exit 1 (nothing passed)."""
        code = self._exit_code(
            SummaryRecord(total=2, passed=0, failed=0, skipped=2),
            skip_pattern=self._ALLOW_ALL,
            skipped=["a", "b"],
            fail_pattern=None,
            failed=[],
        )
        assert code == 1

    def test_all_skipped_forbid_all(self):
        """All skipped + allow_skipped_pattern='^$' -> exit 1."""
        code = self._exit_code(
            SummaryRecord(total=2, passed=0, failed=0, skipped=2),
            skip_pattern=self._FORBID_ALL,
            skipped=["a", "b"],
            fail_pattern=None,
            failed=[],
        )
        assert code == 1

    def test_passed_and_skipped_allow_all(self):
        """Passed + skipped, allow all -> exit 0."""
        code = self._exit_code(
            SummaryRecord(total=3, passed=2, failed=0, skipped=1),
            skip_pattern=self._ALLOW_ALL,
            skipped=["a"],
            fail_pattern=None,
            failed=[],
        )
        assert code == 0

    def test_passed_and_skipped_forbid_all(self):
        """Passed + skipped + forbid all -> exit 1."""
        code = self._exit_code(
            SummaryRecord(total=3, passed=2, failed=0, skipped=1),
            skip_pattern=self._FORBID_ALL,
            skipped=["a"],
            fail_pattern=None,
            failed=[],
        )
        assert code == 1

    def test_skip_pattern_matches_specific_name(self):
        """Skip pattern naming every skipped entry -> exit 0."""
        code = self._exit_code(
            SummaryRecord(total=4, passed=2, failed=0, skipped=2),
            skip_pattern="positions|seq_lens",
            skipped=["positions", "seq_lens"],
            fail_pattern=None,
            failed=[],
        )
        assert code == 0

    def test_skip_pattern_partial_match_forbidden(self):
        """Pattern matches some skips but not all -> exit 1."""
        code = self._exit_code(
            SummaryRecord(total=4, passed=1, failed=0, skipped=3),
            skip_pattern="positions|seq_lens",
            skipped=["positions", "seq_lens", "hidden_states"],
            fail_pattern=None,
            failed=[],
        )
        assert code == 1

    def test_allow_failed_pattern_matches_all(self):
        """allow_failed_pattern='.*' tolerates all failures -> exit 0."""
        code = self._exit_code(
            SummaryRecord(total=3, passed=1, failed=2, skipped=0),
            skip_pattern=self._ALLOW_ALL,
            skipped=[],
            fail_pattern=".*",
            failed=["a", "b"],
        )
        assert code == 0

    def test_allow_failed_pattern_matches_specific(self):
        """Pattern matches all failed names -> exit 0."""
        code = self._exit_code(
            SummaryRecord(total=3, passed=1, failed=2, skipped=0),
            skip_pattern=self._ALLOW_ALL,
            skipped=[],
            fail_pattern="hidden_states|logits",
            failed=["hidden_states", "logits"],
        )
        assert code == 0

    def test_allow_failed_pattern_partial_match(self):
        """Pattern matches some but not all failures -> exit 1."""
        code = self._exit_code(
            SummaryRecord(total=3, passed=0, failed=3, skipped=0),
            skip_pattern=self._ALLOW_ALL,
            skipped=[],
            fail_pattern="hidden_states",
            failed=["hidden_states", "logits", "attn"],
        )
        assert code == 1

    def test_allow_failed_pattern_no_failures(self):
        """Pattern set but no failures -> exit 0."""
        code = self._exit_code(
            SummaryRecord(total=2, passed=2, failed=0, skipped=0),
            skip_pattern=self._ALLOW_ALL,
            skipped=[],
            fail_pattern=".*",
            failed=[],
        )
        assert code == 0

    def test_both_failed_and_skipped_patterns(self):
        """Both patterns set, both satisfied -> exit 0."""
        code = self._exit_code(
            SummaryRecord(total=4, passed=1, failed=1, skipped=2),
            skip_pattern="positions|seq_lens",
            skipped=["positions", "seq_lens"],
            fail_pattern="logits",
            failed=["logits"],
        )
        assert code == 0

    def test_failed_pattern_satisfied_but_skipped_not(self):
        """Failed pattern OK but skipped pattern fails -> exit 1."""
        code = self._exit_code(
            SummaryRecord(total=3, passed=1, failed=1, skipped=1),
            skip_pattern=self._FORBID_ALL,
            skipped=["a"],
            fail_pattern=".*",
            failed=["b"],
        )
        assert code == 1

    def test_zero_passed_exits_one(self):
        """No tensors passed -> exit 1, even when all failures are allowed."""
        code = self._exit_code(
            SummaryRecord(total=2, passed=0, failed=2, skipped=0),
            skip_pattern=self._ALLOW_ALL,
            skipped=[],
            fail_pattern=".*",
            failed=["a", "b"],
        )
        assert code == 1

    def test_zero_passed_all_skipped_exits_one(self):
        """All skipped, nothing passed -> exit 1."""
        code = self._exit_code(
            SummaryRecord(total=3, passed=0, failed=0, skipped=3),
            skip_pattern=self._ALLOW_ALL,
            skipped=["a", "b", "c"],
            fail_pattern=None,
            failed=[],
        )
        assert code == 1

    def test_errored_with_passed_exits_one(self):
        """Has errored bundle even with passed -> exit 1."""
        code = self._exit_code(
            SummaryRecord(total=3, passed=2, failed=0, skipped=0, errored=1),
            skip_pattern=self._ALLOW_ALL,
            skipped=[],
            fail_pattern=None,
            failed=[],
            errored_names=["broken_tensor"],
        )
        assert code == 1

    def test_errored_only_exits_one(self):
        """Only an errored entry, nothing passed -> exit 1."""
        code = self._exit_code(
            SummaryRecord(total=1, passed=0, failed=0, skipped=0, errored=1),
            skip_pattern=self._ALLOW_ALL,
            skipped=[],
            fail_pattern=None,
            failed=[],
            errored_names=["broken_tensor"],
        )
        assert code == 1


def _make_pt(directory: Path) -> None:
    directory.mkdir(parents=True, exist_ok=True)
    torch.save(torch.tensor([1.0]), directory / "dummy.pt")


class TestAutoDescendDir:
    """Checks which directory auto_descend_dir picks for several on-disk layouts."""

    def test_no_descend_when_pt_at_root(self, tmp_path: Path) -> None:
        """A directory that holds .pt files directly is returned as-is."""
        _make_pt(tmp_path)
        _make_pt(tmp_path / "child_a")
        chosen = auto_descend_dir(tmp_path, label="test")
        assert chosen == tmp_path

    def test_descend_into_single_child(self, tmp_path: Path) -> None:
        """A single child holding a .pt file is the directory returned."""
        only_child: Path = tmp_path / "engine_0"
        _make_pt(only_child)
        chosen = auto_descend_dir(tmp_path, label="test")
        assert chosen == only_child

    def test_descend_single_nonempty_child_among_empty(self, tmp_path: Path) -> None:
        """Two subdirs but only one has a .pt file: that one is returned."""
        with_data: Path = tmp_path / "engine_0"
        without_data: Path = tmp_path / "empty_child"
        _make_pt(with_data)
        without_data.mkdir()
        chosen = auto_descend_dir(tmp_path, label="test")
        assert chosen == with_data

    def test_error_with_multiple_nonempty_children(self, tmp_path: Path) -> None:
        """Two children with .pt files is ambiguous and raises ValueError."""
        for child_name in ("engine_0", "engine_1"):
            _make_pt(tmp_path / child_name)
        with pytest.raises(ValueError, match="multiple subdirectories contain data"):
            auto_descend_dir(tmp_path, label="test")

    def test_error_when_no_data_found(self, tmp_path: Path) -> None:
        """No .pt files anywhere raises ValueError."""
        empty_child: Path = tmp_path / "empty_child"
        empty_child.mkdir()
        with pytest.raises(ValueError, match="no .pt files found"):
            auto_descend_dir(tmp_path, label="test")


if __name__ == "__main__":
    # When executed directly, run pytest on this file and exit with its return value.
    exit_status = pytest.main([__file__])
    sys.exit(exit_status)