"""Tests for the comparator display helpers and the RankInfo/InputIds output records."""

import sys
from io import StringIO
from pathlib import Path
from typing import Any, Optional

import polars as pl
import pytest
import torch
from rich.console import Console
from rich.table import Table

from sglang.srt.debug_utils.comparator.display import (
    _collect_input_ids_and_positions,
    _collect_rank_info,
    _render_polars_as_text,
    extract_parallel_info,
)
from sglang.srt.debug_utils.comparator.output_types import (
    InputIdsRecord,
    RankInfoRecord,
)
from sglang.test.ci.ci_register import register_cpu_ci

register_cpu_ci(est_time=10, suite="default", nightly=True)


def _render_rich(renderable: object) -> str:
    """Render a rich renderable into a 120-column string with trailing newlines stripped."""
    buf: StringIO = StringIO()
    # force_terminal=False: the in-memory buffer is not treated as a terminal.
    Console(file=buf, force_terminal=False, width=120).print(renderable)
    return buf.getvalue().rstrip("\n")


def _save_dump_file(
    directory: Path,
    *,
    name: str,
    step: int,
    rank: int,
    dump_index: int,
    value: torch.Tensor,
    meta: dict,
) -> str:
    """Save ``{"value": value, "meta": meta}`` into ``directory`` and return the filename.

    The filename is ``name=<name>___step=<step>___rank=<rank>___dump_index=<dump_index>.pt``.
    """
    filename = f"name={name}___step={step}___rank={rank}___dump_index={dump_index}.pt"
    torch.save({"value": value, "meta": meta}, directory / filename)
    return filename


def _dump_row(
    filename: str, *, name: str, step: int, rank: int, dump_index: int
) -> dict[str, Any]:
    """Return one dump-listing row (filename, name, step, rank, dump_index) for ``_make_df``."""
    return {
        "filename": filename,
        "name": name,
        "step": step,
        "rank": rank,
        "dump_index": dump_index,
    }


def _make_df(rows: list[dict]) -> pl.DataFrame:
    """Build a DataFrame from ``rows`` with ``step``/``rank``/``dump_index`` cast to int."""
    df = pl.DataFrame(rows)
    df = df.with_columns(
        pl.col("step").cast(int),
        pl.col("rank").cast(int),
        pl.col("dump_index").cast(int),
    )
    return df


def _sample_rank_info_rows() -> list[dict[str, Any]]:
    """Return fresh two-rank (tp 0/2 and 1/2, pp 0/1) rows for RankInfoRecord tests."""
    return [
        {"rank": 0, "tp": "0/2", "pp": "0/1"},
        {"rank": 1, "tp": "1/2", "pp": "0/1"},
    ]


def _sample_input_ids_rows() -> list[dict[str, Any]]:
    """Return a fresh single three-token row for InputIdsRecord rendering tests."""
    return [
        {
            "step": 0,
            "rank": 0,
            "num_tokens": 3,
            "input_ids": "[10, 20, 30]",
            "positions": "[0, 1, 2]",
        },
    ]


class TestRenderPolarsAsText:
    """Tests for ``_render_polars_as_text``."""

    def test_renders_table(self) -> None:
        df = pl.DataFrame({"col_a": [1, 2], "col_b": ["x", "y"]})
        text: str = _render_polars_as_text(df, title="test table")

        assert "test table" in text
        assert "col_a" in text
        assert "col_b" in text

    def test_renders_empty_dataframe(self) -> None:
        df = pl.DataFrame({"a": [], "b": []})
        text: str = _render_polars_as_text(df, title="empty")

        assert "empty" in text


class TestCollectRankInfo:
    """Tests for ``_collect_rank_info`` over ``input_ids`` dump files."""

    def test_collects_rank_info(self, tmp_path: Path) -> None:
        sglang_info = {
            "tp_rank": 0,
            "tp_size": 2,
            "pp_rank": 0,
            "pp_size": 1,
        }
        filename: str = _save_dump_file(
            tmp_path,
            name="input_ids",
            step=0,
            rank=0,
            dump_index=0,
            value=torch.tensor([1, 2, 3]),
            meta={"sglang_parallel_info": sglang_info},
        )
        df = _make_df(
            [_dump_row(filename, name="input_ids", step=0, rank=0, dump_index=0)]
        )

        rows: Optional[list[dict[str, Any]]] = _collect_rank_info(
            df, dump_dir=tmp_path
        )

        assert rows is not None
        assert len(rows) == 1
        assert rows[0]["rank"] == 0
        assert rows[0]["tp"] == "0/2"
        assert rows[0]["pp"] == "0/1"

    def test_returns_none_when_no_input_ids(self, tmp_path: Path) -> None:
        df = _make_df(
            [_dump_row("f.pt", name="some_other", step=0, rank=0, dump_index=0)]
        )

        result = _collect_rank_info(df, dump_dir=tmp_path)

        assert result is None

    def test_deduplicates_ranks(self, tmp_path: Path) -> None:
        meta = {"sglang_parallel_info": {"tp_rank": 0, "tp_size": 1}}
        f1: str = _save_dump_file(
            tmp_path,
            name="input_ids",
            step=0,
            rank=0,
            dump_index=0,
            value=torch.tensor([1]),
            meta=meta,
        )
        f2: str = _save_dump_file(
            tmp_path,
            name="input_ids",
            step=1,
            rank=0,
            dump_index=1,
            value=torch.tensor([2]),
            meta=meta,
        )
        df = _make_df(
            [
                _dump_row(f1, name="input_ids", step=0, rank=0, dump_index=0),
                _dump_row(f2, name="input_ids", step=1, rank=0, dump_index=1),
            ]
        )

        rows = _collect_rank_info(df, dump_dir=tmp_path)

        # Two dumps from the same rank should collapse into a single row.
        assert rows is not None
        assert len(rows) == 1


class TestCollectInputIdsAndPositions:
    """Tests for ``_collect_input_ids_and_positions``."""

    def test_collects_ids_and_positions(self, tmp_path: Path) -> None:
        f_ids: str = _save_dump_file(
            tmp_path,
            name="input_ids",
            step=0,
            rank=0,
            dump_index=0,
            value=torch.tensor([10, 20, 30]),
            meta={},
        )
        f_pos: str = _save_dump_file(
            tmp_path,
            name="positions",
            step=0,
            rank=0,
            dump_index=1,
            value=torch.tensor([0, 1, 2]),
            meta={},
        )
        df = _make_df(
            [
                _dump_row(f_ids, name="input_ids", step=0, rank=0, dump_index=0),
                _dump_row(f_pos, name="positions", step=0, rank=0, dump_index=1),
            ]
        )

        rows = _collect_input_ids_and_positions(df, dump_dir=tmp_path)

        assert rows is not None
        assert len(rows) == 1
        assert rows[0]["step"] == 0
        assert rows[0]["rank"] == 0
        assert rows[0]["num_tokens"] == 3
        assert "10" in rows[0]["input_ids"]
        assert "0" in rows[0]["positions"]

    def test_returns_none_when_empty(self, tmp_path: Path) -> None:
        df = _make_df([_dump_row("f.pt", name="weight", step=0, rank=0, dump_index=0)])

        result = _collect_input_ids_and_positions(df, dump_dir=tmp_path)

        assert result is None

    def test_with_mock_tokenizer(self, tmp_path: Path) -> None:
        f_ids: str = _save_dump_file(
            tmp_path,
            name="input_ids",
            step=0,
            rank=0,
            dump_index=0,
            value=torch.tensor([1, 2]),
            meta={},
        )
        df = _make_df(
            [_dump_row(f_ids, name="input_ids", step=0, rank=0, dump_index=0)]
        )

        class _MockTokenizer:
            # Minimal stand-in exposing only the ``decode`` signature used here.
            def decode(self, ids: list[int], skip_special_tokens: bool = False) -> str:
                return f"decoded:{ids}"

        rows = _collect_input_ids_and_positions(
            df, dump_dir=tmp_path, tokenizer=_MockTokenizer()
        )

        assert rows is not None
        assert "decoded_text" in rows[0]
        assert "decoded:" in rows[0]["decoded_text"]


class TestRankInfoRecordSnapshot:
    """Text, rich and JSON rendering checks for ``RankInfoRecord``."""

    def test_to_text_snapshot(self) -> None:
        record = RankInfoRecord(label="baseline", rows=_sample_rank_info_rows())

        text: str = record.to_text()

        assert "baseline ranks" in text
        assert "rank" in text
        assert "tp" in text
        assert "pp" in text
        assert "0/2" in text
        assert "1/2" in text
        assert "0/1" in text

    def test_to_rich_snapshot(self) -> None:
        record = RankInfoRecord(label="baseline", rows=_sample_rank_info_rows())

        body = record._format_rich_body()
        assert isinstance(body, Table)

        rendered: str = _render_rich(body)
        assert "baseline ranks" in rendered
        assert "0/2" in rendered
        assert "1/2" in rendered

    def test_json_roundtrip(self) -> None:
        record = RankInfoRecord(
            label="target",
            rows=[{"rank": 0, "tp": "0/4"}],
        )

        json_str: str = record.model_dump_json()

        assert '"type":"rank_info"' in json_str
        assert '"label":"target"' in json_str
        assert '"tp":"0/4"' in json_str


class TestInputIdsRecordSnapshot:
    """Text, rich and JSON rendering checks for ``InputIdsRecord``."""

    def test_to_text_snapshot(self) -> None:
        record = InputIdsRecord(label="target", rows=_sample_input_ids_rows())

        text: str = record.to_text()

        assert "target input_ids & positions" in text
        assert "step" in text
        assert "num_tokens" in text
        assert "10, 20, 30" in text
        assert "0, 1, 2" in text

    def test_to_rich_snapshot(self) -> None:
        record = InputIdsRecord(label="target", rows=_sample_input_ids_rows())

        body = record._format_rich_body()
        assert isinstance(body, Table)

        rendered: str = _render_rich(body)
        assert "target input_ids & positions" in rendered
        assert "10, 20, 30" in rendered
        assert "0, 1, 2" in rendered

    def test_json_roundtrip(self) -> None:
        record = InputIdsRecord(
            label="baseline",
            rows=[
                {
                    "step": 0,
                    "rank": 0,
                    "num_tokens": 2,
                    "input_ids": "[1, 2]",
                    "positions": "[0, 1]",
                    "decoded_text": "'hello'",
                },
            ],
        )

        json_str: str = record.model_dump_json()

        assert '"type":"input_ids"' in json_str
        assert '"label":"baseline"' in json_str
        assert '"decoded_text"' in json_str

    def test_to_text_with_decoded(self) -> None:
        record = InputIdsRecord(
            label="test",
            rows=[
                {
                    "step": 0,
                    "rank": 0,
                    "num_tokens": 2,
                    "input_ids": "[1, 2]",
                    "positions": "[0, 1]",
                    "decoded_text": "'hello world'",
                },
            ],
        )

        text: str = record.to_text()

        assert "decoded_text" in text
        assert "hello world" in text


class TestExtractParallelInfo:
    """Tests for ``extract_parallel_info`` writing ``<dim>: "rank/size"`` entries."""

    def test_extracts_rank_size_pairs(self) -> None:
        info: dict = {
            "tp_rank": 1,
            "tp_size": 4,
            "pp_rank": 0,
            "pp_size": 2,
        }
        row_data: dict = {}

        extract_parallel_info(row_data=row_data, info=info)

        assert row_data["tp"] == "1/4"
        assert row_data["pp"] == "0/2"

    def test_skips_error_info(self) -> None:
        row_data: dict = {}

        extract_parallel_info(
            row_data=row_data, info={"error": True, "tp_rank": 0, "tp_size": 1}
        )

        assert row_data == {}

    def test_skips_empty_info(self) -> None:
        row_data: dict = {}

        extract_parallel_info(row_data=row_data, info={})

        assert row_data == {}

    def test_ignores_rank_without_size(self) -> None:
        row_data: dict = {}

        extract_parallel_info(row_data=row_data, info={"tp_rank": 0})

        assert "tp" not in row_data


if __name__ == "__main__":
    sys.exit(pytest.main([__file__]))