File size: 3,994 Bytes
61ba51e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 | import sys
import polars as pl
import pytest
import torch
from sglang.srt.debug_utils.dump_loader import (
LOAD_FAILED,
ValueWithMeta,
_add_duplicate_index,
_cast_to_polars_dtype,
find_row,
parse_meta_from_filename,
read_meta,
)
from sglang.test.ci.ci_register import register_cpu_ci
# Register this test module with the CPU CI runner ("default" suite, nightly).
# NOTE(review): est_time is presumably an expected runtime budget; units not shown here — confirm.
register_cpu_ci(
    est_time=30,
    suite="default",
    nightly=True,
)
class TestReadMeta:
    """Tests for ``read_meta`` over a directory of saved dump files."""

    def test_basic(self, tmp_path):
        # Each filename encodes its metadata as ``key=value`` pairs joined by "___".
        dump_filenames = (
            "step=1___rank=0___dump_index=1___name=a.pt",
            "step=2___rank=0___dump_index=2___name=b.pt",
        )
        for dump_filename in dump_filenames:
            torch.save(torch.randn(5), tmp_path / dump_filename)

        meta_df = read_meta(str(tmp_path))

        # One row per dump file.
        assert len(meta_df) == 2
        # Fields parsed from the filenames must surface as columns.
        missing_columns = [
            column for column in ("step", "rank", "name") if column not in meta_df.columns
        ]
        assert not missing_columns
class TestFindRow:
    """Tests for ``find_row`` lookups against a small polars DataFrame."""

    @staticmethod
    def _make_unique_id_df():
        # Two rows with distinct ``id`` values, so any ``id`` lookup matches at most one row.
        return pl.DataFrame({"id": [1, 2], "name": ["a", "b"], "file": ["f1", "f2"]})

    def test_single_match(self):
        matched = find_row(self._make_unique_id_df(), {"id": 2})
        assert matched["file"] == "f2"

    def test_no_match(self):
        matched = find_row(self._make_unique_id_df(), {"id": 999})
        assert matched is None

    def test_ambiguous(self):
        # Both rows share id=1; the lookup is ambiguous and is expected to yield None.
        duplicated_ids = pl.DataFrame({"id": [1, 1], "file": ["f1", "f2"]})
        matched = find_row(duplicated_ids, {"id": 1})
        assert matched is None
class TestCastToPolars:
    """Tests for ``_cast_to_polars_dtype`` converting string values to a target polars dtype."""

    def test_int(self):
        casted = _cast_to_polars_dtype("42", pl.Int64)
        assert casted == 42

    def test_float(self):
        casted = _cast_to_polars_dtype("3.14", pl.Float64)
        # Compare approximately to avoid binary floating-point representation issues.
        assert casted == pytest.approx(3.14)
class TestAddDuplicateIndex:
    """Tests for ``_add_duplicate_index`` adding a ``duplicate_index`` column."""

    def test_basic(self):
        # "a" appears twice; "b" appears once.
        source_df = pl.DataFrame(
            {
                "name": ["a", "a", "b"],
                "dump_index": [1, 2, 3],
                "filename": ["f1", "f2", "f3"],
            }
        )

        indexed_df = _add_duplicate_index(source_df)

        # When ordered by dump_index, the repeated "a" rows are numbered 0 then 1.
        rows_named_a = indexed_df.filter(pl.col("name") == "a").sort("dump_index")
        assert rows_named_a["duplicate_index"].to_list() == [0, 1]
class TestValueWithMeta:
    """Tests for ``ValueWithMeta.load`` with different on-disk payloads."""

    def test_load_dict_format(self, tmp_path) -> None:
        # Payload is a {"value", "meta"} dict. The custom meta must survive, and
        # fields parsed from the filename (name, rank) must also appear in meta.
        dump_file = tmp_path / "step=0___rank=0___dump_index=1___name=hidden.pt"
        original = torch.randn(4, 8)
        torch.save({"value": original, "meta": {"custom": "field"}}, dump_file)

        result = ValueWithMeta.load(dump_file)

        assert torch.allclose(result.value, original)
        assert result.meta["custom"] == "field"
        assert result.meta["name"] == "hidden"
        assert result.meta["rank"] == 0

    def test_load_bare_tensor(self, tmp_path) -> None:
        # Payload is a plain tensor with no wrapping dict; meta comes from the filename.
        dump_file = tmp_path / "step=0___rank=0___dump_index=1___name=bare.pt"
        original = torch.randn(3, 3)
        torch.save(original, dump_file)

        result = ValueWithMeta.load(dump_file)

        assert torch.allclose(result.value, original)
        assert result.meta["name"] == "bare"

    def test_load_corrupted_file(self, tmp_path) -> None:
        # An unreadable file yields the LOAD_FAILED sentinel instead of raising,
        # while filename-derived meta is still populated.
        dump_file = tmp_path / "step=0___rank=0___dump_index=1___name=bad.pt"
        dump_file.write_text("not a valid pt file")

        result = ValueWithMeta.load(dump_file)

        assert result.value is LOAD_FAILED
        assert result.meta["name"] == "bad"
class TestRecomputeStatusParsing:
    """Tests that ``parse_meta_from_filename`` extracts the ``recompute_status`` field."""

    def test_parse_recompute_status_from_filename(self) -> None:
        from pathlib import Path

        # (filename, expected recompute_status) pairs, checked in order.
        cases = (
            (
                "step=0___rank=0___dump_index=1___name=x___recompute_status=disabled.pt",
                "disabled",
            ),
            (
                "step=0___rank=0___dump_index=1___name=x___recompute_status=recompute.pt",
                "recompute",
            ),
            (
                "step=0___rank=0___dump_index=1___name=x___recompute_status=original.pt",
                "original",
            ),
        )
        for filename, expected_status in cases:
            parsed = parse_meta_from_filename(Path(filename))
            assert parsed["recompute_status"] == expected_status
if __name__ == "__main__":
    # Allow running this file directly; forward pytest's exit code to the shell.
    exit_code = pytest.main([__file__])
    sys.exit(exit_code)
|