File size: 3,994 Bytes
61ba51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import sys

import polars as pl
import pytest
import torch

from sglang.srt.debug_utils.dump_loader import (
    LOAD_FAILED,
    ValueWithMeta,
    _add_duplicate_index,
    _cast_to_polars_dtype,
    find_row,
    parse_meta_from_filename,
    read_meta,
)
from sglang.test.ci.ci_register import register_cpu_ci

# Register this test module with the CPU CI harness (default suite, nightly run).
register_cpu_ci(suite="default", nightly=True, est_time=30)


class TestReadMeta:
    """Tests for ``read_meta`` over a directory of dump files."""

    def test_basic(self, tmp_path):
        # Two dumps whose filenames encode step/rank/dump_index/name fields.
        filenames = (
            "step=1___rank=0___dump_index=1___name=a.pt",
            "step=2___rank=0___dump_index=2___name=b.pt",
        )
        for filename in filenames:
            torch.save(torch.randn(5), tmp_path / filename)

        meta_df = read_meta(str(tmp_path))

        # One row per dump file, with the filename fields exposed as columns.
        assert len(meta_df) == 2
        expected_columns = ("step", "rank", "name")
        assert all(col in meta_df.columns for col in expected_columns)


class TestFindRow:
    """Tests for ``find_row``: a unique match yields that row, otherwise ``None``."""

    @staticmethod
    def _two_row_frame():
        # Shared fixture frame with distinct ids 1 and 2.
        return pl.DataFrame({"id": [1, 2], "name": ["a", "b"], "file": ["f1", "f2"]})

    def test_single_match(self):
        row = find_row(self._two_row_frame(), {"id": 2})
        assert row["file"] == "f2"

    def test_no_match(self):
        assert find_row(self._two_row_frame(), {"id": 999}) is None

    def test_ambiguous(self):
        # Both rows share id=1, so the lookup matches more than one row.
        duplicated = pl.DataFrame({"id": [1, 1], "file": ["f1", "f2"]})
        assert find_row(duplicated, {"id": 1}) is None


class TestCastToPolars:
    """Tests for ``_cast_to_polars_dtype`` parsing string values per a polars dtype."""

    def test_int(self):
        converted = _cast_to_polars_dtype("42", pl.Int64)
        assert converted == 42

    def test_float(self):
        converted = _cast_to_polars_dtype("3.14", pl.Float64)
        # Float parsing is compared approximately.
        assert converted == pytest.approx(3.14)


class TestAddDuplicateIndex:
    """Tests for ``_add_duplicate_index``, which adds a ``duplicate_index`` column."""

    def test_basic(self):
        df = pl.DataFrame(
            {
                "name": ["a", "a", "b"],
                "dump_index": [1, 2, 3],
                "filename": ["f1", "f2", "f3"],
            }
        )
        result = _add_duplicate_index(df)

        # No rows may be dropped or duplicated by the annotation.
        assert len(result) == len(df)

        # Repeated name "a": ordered by dump_index, its copies are numbered 0, 1.
        dup_a = (
            result.filter(pl.col("name") == "a")
            .sort("dump_index")["duplicate_index"]
            .to_list()
        )
        assert dup_a == [0, 1]

        # A name that appears only once is the first (0th) of its group.
        dup_b = result.filter(pl.col("name") == "b")["duplicate_index"].to_list()
        assert dup_b == [0]


class TestValueWithMeta:
    """Tests for ``ValueWithMeta.load`` across the on-disk formats it accepts."""

    def test_load_dict_format(self, tmp_path) -> None:
        # A {"value", "meta"} dict: stored meta survives and filename fields are added.
        dump_path = tmp_path / "step=0___rank=0___dump_index=1___name=hidden.pt"
        original = torch.randn(4, 8)
        torch.save({"value": original, "meta": {"custom": "field"}}, dump_path)

        result = ValueWithMeta.load(dump_path)

        assert torch.allclose(result.value, original)
        assert result.meta["custom"] == "field"
        assert result.meta["name"] == "hidden"
        assert result.meta["rank"] == 0

    def test_load_bare_tensor(self, tmp_path) -> None:
        # A plain tensor with no wrapper dict: meta is still populated from the filename.
        dump_path = tmp_path / "step=0___rank=0___dump_index=1___name=bare.pt"
        original = torch.randn(3, 3)
        torch.save(original, dump_path)

        result = ValueWithMeta.load(dump_path)

        assert torch.allclose(result.value, original)
        assert result.meta["name"] == "bare"

    def test_load_corrupted_file(self, tmp_path) -> None:
        # Unreadable contents: value is the LOAD_FAILED sentinel, filename meta still parsed.
        dump_path = tmp_path / "step=0___rank=0___dump_index=1___name=bad.pt"
        dump_path.write_text("not a valid pt file")

        result = ValueWithMeta.load(dump_path)

        assert result.value is LOAD_FAILED
        assert result.meta["name"] == "bad"


class TestRecomputeStatusParsing:
    """Tests that ``recompute_status`` is parsed out of dump filenames."""

    def test_parse_recompute_status_from_filename(self) -> None:
        from pathlib import Path

        # Each status value must round-trip from the filename into the meta dict.
        for status in ("disabled", "recompute", "original"):
            filename = (
                "step=0___rank=0___dump_index=1___name=x"
                f"___recompute_status={status}.pt"
            )
            meta = parse_meta_from_filename(Path(filename))
            assert meta["recompute_status"] == status


if __name__ == "__main__":
    # Allow running this file directly; propagate pytest's exit code to the shell.
    exit_code = pytest.main([__file__])
    sys.exit(exit_code)