| import sys |
|
|
| import polars as pl |
| import pytest |
| import torch |
|
|
| from sglang.srt.debug_utils.dump_loader import ( |
| LOAD_FAILED, |
| ValueWithMeta, |
| _add_duplicate_index, |
| _cast_to_polars_dtype, |
| find_row, |
| parse_meta_from_filename, |
| read_meta, |
| ) |
| from sglang.test.ci.ci_register import register_cpu_ci |
|
|
| register_cpu_ci(est_time=30, suite="default", nightly=True) |
|
|
|
|
| class TestReadMeta: |
| def test_basic(self, tmp_path): |
| for fn in [ |
| "step=1___rank=0___dump_index=1___name=a.pt", |
| "step=2___rank=0___dump_index=2___name=b.pt", |
| ]: |
| torch.save(torch.randn(5), tmp_path / fn) |
|
|
| df = read_meta(str(tmp_path)) |
| assert len(df) == 2 |
| assert all(c in df.columns for c in ["step", "rank", "name"]) |
|
|
|
|
| class TestFindRow: |
| def test_single_match(self): |
| df = pl.DataFrame({"id": [1, 2], "name": ["a", "b"], "file": ["f1", "f2"]}) |
| assert find_row(df, {"id": 2})["file"] == "f2" |
|
|
| def test_no_match(self): |
| df = pl.DataFrame({"id": [1, 2], "name": ["a", "b"], "file": ["f1", "f2"]}) |
| assert find_row(df, {"id": 999}) is None |
|
|
| def test_ambiguous(self): |
| df = pl.DataFrame({"id": [1, 1], "file": ["f1", "f2"]}) |
| assert find_row(df, {"id": 1}) is None |
|
|
|
|
| class TestCastToPolars: |
| def test_int(self): |
| assert _cast_to_polars_dtype("42", pl.Int64) == 42 |
|
|
| def test_float(self): |
| assert _cast_to_polars_dtype("3.14", pl.Float64) == pytest.approx(3.14) |
|
|
|
|
| class TestAddDuplicateIndex: |
| def test_basic(self): |
| df = pl.DataFrame( |
| { |
| "name": ["a", "a", "b"], |
| "dump_index": [1, 2, 3], |
| "filename": ["f1", "f2", "f3"], |
| } |
| ) |
| result = _add_duplicate_index(df) |
| assert result.filter(pl.col("name") == "a").sort("dump_index")[ |
| "duplicate_index" |
| ].to_list() == [0, 1] |
|
|
|
|
| class TestValueWithMeta: |
| def test_load_dict_format(self, tmp_path) -> None: |
| path = tmp_path / "step=0___rank=0___dump_index=1___name=hidden.pt" |
| tensor = torch.randn(4, 8) |
| torch.save({"value": tensor, "meta": {"custom": "field"}}, path) |
|
|
| loaded = ValueWithMeta.load(path) |
| assert torch.allclose(loaded.value, tensor) |
| assert loaded.meta["custom"] == "field" |
| assert loaded.meta["name"] == "hidden" |
| assert loaded.meta["rank"] == 0 |
|
|
| def test_load_bare_tensor(self, tmp_path) -> None: |
| path = tmp_path / "step=0___rank=0___dump_index=1___name=bare.pt" |
| tensor = torch.randn(3, 3) |
| torch.save(tensor, path) |
|
|
| loaded = ValueWithMeta.load(path) |
| assert torch.allclose(loaded.value, tensor) |
| assert loaded.meta["name"] == "bare" |
|
|
| def test_load_corrupted_file(self, tmp_path) -> None: |
| path = tmp_path / "step=0___rank=0___dump_index=1___name=bad.pt" |
| path.write_text("not a valid pt file") |
|
|
| loaded = ValueWithMeta.load(path) |
| assert loaded.value is LOAD_FAILED |
| assert loaded.meta["name"] == "bad" |
|
|
|
|
| class TestRecomputeStatusParsing: |
| def test_parse_recompute_status_from_filename(self) -> None: |
| from pathlib import Path |
|
|
| meta_disabled = parse_meta_from_filename( |
| Path( |
| "step=0___rank=0___dump_index=1___name=x___recompute_status=disabled.pt" |
| ) |
| ) |
| assert meta_disabled["recompute_status"] == "disabled" |
|
|
| meta_recompute = parse_meta_from_filename( |
| Path( |
| "step=0___rank=0___dump_index=1___name=x___recompute_status=recompute.pt" |
| ) |
| ) |
| assert meta_recompute["recompute_status"] == "recompute" |
|
|
| meta_original = parse_meta_from_filename( |
| Path( |
| "step=0___rank=0___dump_index=1___name=x___recompute_status=original.pt" |
| ) |
| ) |
| assert meta_original["recompute_status"] == "original" |
|
|
|
|
| if __name__ == "__main__": |
| sys.exit(pytest.main([__file__])) |
|
|