| import json |
| import sys |
|
|
| import pytest |
| from pydantic import ValidationError |
|
|
| from sglang.srt.debug_utils.comparator.aligner.entrypoint.traced_types import ( |
| TracedAlignerPlan, |
| TracedSidePlan, |
| TracedStepPlan, |
| TracedSubPlan, |
| ) |
| from sglang.srt.debug_utils.comparator.aligner.entrypoint.types import ( |
| AlignerPerStepPlan, |
| AlignerPlan, |
| ) |
| from sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.types import ( |
| PositionalSeqId, |
| TokenAlignerPlan, |
| TokenAlignerSeqInfo, |
| TokenAlignerStepAux, |
| TokenLocator, |
| ) |
| from sglang.srt.debug_utils.comparator.aligner.unsharder.types import ( |
| AxisInfo, |
| ConcatParams, |
| UnsharderPlan, |
| ) |
| from sglang.srt.debug_utils.comparator.dims_spec import ParallelAxis, TokenLayout |
| from sglang.srt.debug_utils.comparator.output_types import ( |
| ComparisonErrorRecord, |
| ComparisonNonTensorRecord, |
| ComparisonSkipRecord, |
| ComparisonTensorRecord, |
| ErrorLog, |
| SummaryRecord, |
| parse_record_json, |
| ) |
| from sglang.srt.debug_utils.comparator.tensor_comparator.types import ( |
| DiffInfo, |
| TensorInfo, |
| TensorStats, |
| ) |
| from sglang.srt.debug_utils.comparator.utils import Pair, _check_equal_lengths |
| from sglang.test.ci.ci_register import register_cpu_ci |
|
|
# Register this module with the CPU CI runner in the "default" suite, nightly.
# NOTE(review): est_time is presumably an estimated runtime in seconds — confirm.
register_cpu_ci(suite="default", nightly=True, est_time=10)
|
|
|
|
class TestCheckEqualLengths:
    """Unit tests for the keyword-argument length validator ``_check_equal_lengths``."""

    def test_all_equal(self):
        # Two sequences of the same length must be accepted silently.
        left, right = [1, 2], [3, 4]
        _check_equal_lengths(a=left, b=right)

    def test_empty_lists(self):
        # Zero-length inputs are trivially equal in length.
        _check_equal_lengths(**{"a": [], "b": []})

    def test_mismatch_raises(self):
        # Differing lengths must raise a ValueError mentioning the mismatch.
        longer, shorter = [1, 2], [3]
        with pytest.raises(ValueError, match="Length mismatch"):
            _check_equal_lengths(a=longer, b=shorter)
|
|
|
|
class TestTokenAlignerStepAux:
    """Validation tests for ``TokenAlignerStepAux`` field consistency."""

    @staticmethod
    def _seq_ids(count: int) -> list:
        """Return ``count`` step-0 PositionalSeqIds with consecutive seq indices."""
        return [PositionalSeqId(step=0, seq_index=idx) for idx in range(count)]

    def test_valid(self):
        # Three tokens split into two sequences of lengths 2 and 1.
        aux = TokenAlignerStepAux(
            input_ids=[10, 20, 30],
            positions=[0, 1, 2],
            seq_lens=[2, 1],
            seq_ids=self._seq_ids(2),
        )
        assert len(aux.input_ids) == 3

    def test_token_length_mismatch(self):
        # positions has one entry fewer than input_ids.
        with pytest.raises(ValueError, match="Length mismatch"):
            TokenAlignerStepAux(
                input_ids=[10, 20, 30],
                positions=[0, 1],
                seq_lens=[2, 1],
                seq_ids=self._seq_ids(2),
            )

    def test_seq_length_mismatch(self):
        # Two seq_lens but only one seq_id.
        with pytest.raises(ValueError, match="Length mismatch"):
            TokenAlignerStepAux(
                input_ids=[10, 20, 30],
                positions=[0, 1, 2],
                seq_lens=[2, 1],
                seq_ids=self._seq_ids(1),
            )

    def test_sum_seq_lens_mismatch(self):
        # seq_lens sum to 2 while there are 3 tokens.
        with pytest.raises(ValueError, match="sum\\(seq_lens\\)"):
            TokenAlignerStepAux(
                input_ids=[10, 20, 30],
                positions=[0, 1, 2],
                seq_lens=[1, 1],
                seq_ids=self._seq_ids(2),
            )
|
|
|
|
class TestTokenAlignerSeqInfo:
    """Validation tests for ``TokenAlignerSeqInfo``."""

    def test_valid(self):
        info = TokenAlignerSeqInfo(
            input_ids=[10, 20, 30],
            positions=[0, 1, 2],
            locator=TokenLocator(steps=[0, 0, 1], token_index_in_step=[0, 1, 0]),
        )
        assert len(info.input_ids) == 3

    def test_length_mismatch(self):
        # The locator is internally consistent (2 steps, 2 token indices) and is
        # built outside the raises-block, so a failure in TokenLocator itself
        # surfaces as a test error. The expected ValidationError must then come
        # from TokenAlignerSeqInfo: the locator covers 2 tokens, input_ids has 3.
        locator = TokenLocator(steps=[0, 0], token_index_in_step=[0, 1])
        with pytest.raises(ValidationError):
            TokenAlignerSeqInfo(
                input_ids=[10, 20, 30],
                positions=[0, 1, 2],
                locator=locator,
            )

    def test_positions_not_sequential(self):
        # Built outside the raises-block so only TokenAlignerSeqInfo can satisfy it.
        locator = TokenLocator(steps=[0, 0, 1], token_index_in_step=[0, 1, 0])
        with pytest.raises(ValidationError, match="positions must be"):
            TokenAlignerSeqInfo(
                input_ids=[10, 20, 30],
                positions=[0, 2, 1],
                locator=locator,
            )
|
|
|
|
class TestTokenAlignerPlan:
    """Validation tests for ``TokenAlignerPlan``."""

    @staticmethod
    def _token_layouts() -> Pair:
        """Both sides use ``TokenLayout.T``."""
        return Pair(x=TokenLayout.T, y=TokenLayout.T)

    def test_valid(self):
        # Both locators address three tokens.
        x_locator = TokenLocator(steps=[0, 0, 1], token_index_in_step=[0, 1, 0])
        y_locator = TokenLocator(steps=[0, 1, 1], token_index_in_step=[0, 0, 1])
        plan = TokenAlignerPlan(
            locators=Pair(x=x_locator, y=y_locator),
            layouts=self._token_layouts(),
        )
        assert len(plan.locators.x.steps) == 3

    def test_length_mismatch(self):
        # x addresses two tokens, y addresses three.
        with pytest.raises(ValidationError, match="Length mismatch"):
            TokenAlignerPlan(
                locators=Pair(
                    x=TokenLocator(steps=[0, 0], token_index_in_step=[0, 1]),
                    y=TokenLocator(steps=[0, 1, 1], token_index_in_step=[0, 0, 1]),
                ),
                layouts=self._token_layouts(),
            )
|
|
|
|
class TestSummaryRecord:
    """``SummaryRecord`` must reject category counts that do not add up to ``total``."""

    def test_valid(self):
        # 7 passed + 2 failed + 1 skipped == 10
        counts = {"total": 10, "passed": 7, "failed": 2, "skipped": 1}
        assert SummaryRecord(**counts).total == 10

    def test_total_mismatch(self):
        # 5 + 2 + 1 == 8, which disagrees with total=10.
        expected_msg = "total=10"
        with pytest.raises(ValidationError, match=expected_msg):
            SummaryRecord(total=10, passed=5, failed=2, skipped=1)

    def test_valid_with_errored(self):
        # errored is part of the sum: 6 + 2 + 1 + 1 == 10
        summary = SummaryRecord(total=10, passed=6, failed=2, skipped=1, errored=1)
        assert summary.errored == 1

    def test_total_mismatch_with_errored(self):
        # 6 + 2 + 1 + 0 == 9, which disagrees with total=10.
        expected_msg = "total=10"
        with pytest.raises(ValidationError, match=expected_msg):
            SummaryRecord(total=10, passed=6, failed=2, skipped=1, errored=0)
|
|
|
|
class TestAxisInfo:
    """Range checks on ``AxisInfo``: axis_size > 0 and 0 <= axis_rank < axis_size."""

    def test_valid(self):
        assert AxisInfo(axis_rank=0, axis_size=4).axis_rank == 0

    def test_axis_size_zero(self):
        size_error = "axis_size must be > 0"
        with pytest.raises(ValidationError, match=size_error):
            AxisInfo(axis_rank=0, axis_size=0)

    def test_axis_size_negative(self):
        size_error = "axis_size must be > 0"
        with pytest.raises(ValidationError, match=size_error):
            AxisInfo(axis_rank=0, axis_size=-1)

    def test_axis_rank_negative(self):
        rank_error = "axis_rank must be in"
        with pytest.raises(ValidationError, match=rank_error):
            AxisInfo(axis_rank=-1, axis_size=4)

    def test_axis_rank_too_large(self):
        # axis_rank == axis_size is one past the last valid rank.
        rank_error = "axis_rank must be in"
        with pytest.raises(ValidationError, match=rank_error):
            AxisInfo(axis_rank=4, axis_size=4)

    def test_axis_rank_equals_size_minus_one(self):
        # The largest valid rank is axis_size - 1.
        size = 4
        info = AxisInfo(axis_rank=size - 1, axis_size=size)
        assert info.axis_rank == size - 1
|
|
|
|
def _make_tensor_info() -> TensorInfo:
    """Return a fixed ``[4, 4]`` float32 TensorInfo fixture with constant stats."""
    stats = TensorStats(mean=0.0, abs_mean=0.8, std=1.0, min=-2.0, max=2.0)
    return TensorInfo(shape=[4, 4], dtype="float32", stats=stats)
|
|
|
|
def _make_diff_info(*, passed: bool) -> DiffInfo:
    """Return a DiffInfo fixture with constant numeric fields.

    Only ``passed`` is caller-controlled, so tests can drive record
    categorisation without changing the diff values.
    """
    numeric_fields = {
        "rel_diff": 0.001,
        "max_abs_diff": 0.01,
        "mean_abs_diff": 0.005,
        "max_diff_coord": [0, 0],
        "baseline_at_max": 1.0,
        "target_at_max": 1.01,
        "diff_threshold": 1e-3,
    }
    return DiffInfo(**numeric_fields, passed=passed)
|
|
|
|
def _make_comparison_record(
    *,
    diff: DiffInfo | None,
    errors: list | None = None,
) -> ComparisonTensorRecord:
    """Build a ComparisonTensorRecord named "t" whose baseline and target share one TensorInfo.

    Args:
        diff: Diff result to attach; may be ``None``.
        errors: Optional list of ErrorLog entries; a falsy value becomes ``[]``.
    """
    shared_info: TensorInfo = _make_tensor_info()
    error_list = errors if errors else []
    return ComparisonTensorRecord(
        name="t",
        baseline=shared_info,
        target=shared_info,
        unified_shape=[4, 4],
        shape_mismatch=False,
        diff=diff,
        errors=error_list,
    )
|
|
|
|
class TestOutputRecordCategories:
    """Tests for ``category`` classification, JSON round-trip and text output of records."""

    @staticmethod
    def _sm_scale_record(
        target_value: str, values_equal: bool, **extra
    ) -> ComparisonNonTensorRecord:
        """Build an ``sm_scale`` non-tensor record with a float baseline of "0.125"."""
        return ComparisonNonTensorRecord(
            name="sm_scale",
            baseline_value="0.125",
            target_value=target_value,
            baseline_type="float",
            target_type="float",
            values_equal=values_equal,
            **extra,
        )

    # --- skip records -------------------------------------------------------

    def test_skip_record_with_errors_is_failed(self) -> None:
        error_logs = [ErrorLog(category="c", message="m")]
        skip = ComparisonSkipRecord(name="t", reason="test", errors=error_logs)
        assert skip.category == "failed"

    def test_skip_record_no_warnings_is_skipped(self) -> None:
        skip = ComparisonSkipRecord(name="t", reason="test")
        assert skip.category == "skipped"

    # --- tensor records -----------------------------------------------------

    def test_comparison_record_diff_none_is_failed(self) -> None:
        assert _make_comparison_record(diff=None).category == "failed"

    def test_comparison_record_passed_with_errors_is_failed(self) -> None:
        # Errors override a passing diff.
        tensor_record: ComparisonTensorRecord = _make_comparison_record(
            diff=_make_diff_info(passed=True),
            errors=[ErrorLog(category="c", message="m")],
        )
        assert tensor_record.category == "failed"

    def test_comparison_record_passed_no_warnings_is_passed(self) -> None:
        passing_diff = _make_diff_info(passed=True)
        tensor_record: ComparisonTensorRecord = _make_comparison_record(
            diff=passing_diff
        )
        assert tensor_record.category == "passed"

    # --- non-tensor records -------------------------------------------------

    def test_non_tensor_record_equal_is_passed(self) -> None:
        assert self._sm_scale_record("0.125", True).category == "passed"

    def test_non_tensor_record_different_is_failed(self) -> None:
        assert self._sm_scale_record("0.25", False).category == "failed"

    def test_non_tensor_record_with_errors_is_failed(self) -> None:
        # Equal values still fail when errors are attached.
        flagged = self._sm_scale_record(
            "0.125", True, errors=[ErrorLog(category="c", message="m")]
        )
        assert flagged.category == "failed"

    def test_non_tensor_record_json_roundtrip(self) -> None:
        original = self._sm_scale_record("0.25", False)
        restored = parse_record_json(original.model_dump_json())
        assert isinstance(restored, ComparisonNonTensorRecord)
        assert restored.name == "sm_scale"
        assert restored.values_equal is False
        assert restored.baseline_value == "0.125"
        assert restored.target_value == "0.25"

    def test_non_tensor_record_text_format_equal(self) -> None:
        rendered: str = self._sm_scale_record("0.125", True).to_text()
        for expected in ("sm_scale", "[equal]"):
            assert expected in rendered

    def test_non_tensor_record_text_format_different(self) -> None:
        rendered: str = self._sm_scale_record("0.25", False).to_text()
        for expected in ("baseline", "target"):
            assert expected in rendered

    # --- error records ------------------------------------------------------

    def test_error_record_category_is_errored(self) -> None:
        err = ComparisonErrorRecord(
            name="t", exception_type="ValueError", traceback_str="..."
        )
        assert err.category == "errored"

    def test_error_record_json_roundtrip(self) -> None:
        err = ComparisonErrorRecord(
            name="t", exception_type="ValueError", traceback_str="traceback..."
        )
        restored = parse_record_json(err.model_dump_json())
        assert isinstance(restored, ComparisonErrorRecord)
        assert restored.name == "t"
        assert restored.exception_type == "ValueError"

    def test_error_record_text_format(self) -> None:
        err = ComparisonErrorRecord(
            name="t", exception_type="RuntimeError", traceback_str="Traceback..."
        )
        rendered: str = err.to_text()
        for expected in ("RuntimeError", "Traceback"):
            assert expected in rendered
|
|
|
|
def _make_traced_aligner_plan() -> TracedAlignerPlan:
    """Build a minimal TracedAlignerPlan fixture with identical x and y sides.

    Each side has a single step 0 over input objects [0, 1]. Its only sub-plan
    is a TP unsharder with ConcatParams on dim "h". The traced copy of that
    sub-plan has ``snapshot=None``.
    """
    tp_unsharder = UnsharderPlan(
        axis=ParallelAxis.TP,
        params=ConcatParams(dim_name="h"),
        groups=[[0, 1]],
    )

    def side_step_plans() -> list[AlignerPerStepPlan]:
        # A fresh list per side, so x and y are distinct objects.
        return [
            AlignerPerStepPlan(
                step=0, input_object_indices=[0, 1], sub_plans=[tp_unsharder]
            )
        ]

    aligner_plan = AlignerPlan(
        per_step_plans=Pair(x=side_step_plans(), y=side_step_plans()),
    )

    traced_step = TracedStepPlan(
        step=0,
        input_object_indices=[0, 1],
        sub_plans=[TracedSubPlan(plan=tp_unsharder, snapshot=None)],
    )
    return TracedAlignerPlan(
        plan=aligner_plan,
        per_side=Pair(
            x=TracedSidePlan(step_plans=[traced_step]),
            y=TracedSidePlan(step_plans=[traced_step]),
        ),
    )
|
|
|
|
class TestAlignerPlanInComparisonTensorRecord:
    """Tests for attaching a TracedAlignerPlan to a ComparisonTensorRecord."""

    @staticmethod
    def _record_with_traced_plan() -> ComparisonTensorRecord:
        """Return a passing comparison record with the fixture traced plan attached.

        The plan is attached via ``model_copy(update=...)``. Per pydantic's
        documentation, this does not re-validate the updated fields.
        """
        traced_plan: TracedAlignerPlan = _make_traced_aligner_plan()
        record: ComparisonTensorRecord = _make_comparison_record(
            diff=_make_diff_info(passed=True),
        )
        return record.model_copy(update={"traced_plan": traced_plan})

    def test_comparison_record_with_traced_plan(self) -> None:
        record_with_plan = self._record_with_traced_plan()
        assert record_with_plan.traced_plan is not None
        assert record_with_plan.traced_plan.per_side.x.step_plans[0].step == 0

    def test_traced_plan_json_roundtrip(self) -> None:
        json_str: str = self._record_with_traced_plan().model_dump_json()

        # Raw JSON: each serialized sub-plan carries a "type" field.
        parsed = json.loads(json_str)
        assert "traced_plan" in parsed
        first_sub_plan = parsed["traced_plan"]["per_side"]["x"]["step_plans"][0][
            "sub_plans"
        ][0]
        assert first_sub_plan["plan"]["type"] == "unsharder"

        # parse_record_json can return several record types; pin the concrete
        # one before using tensor-record attributes.
        roundtripped = parse_record_json(json_str)
        assert isinstance(roundtripped, ComparisonTensorRecord)
        assert roundtripped.traced_plan is not None
        assert (
            roundtripped.traced_plan.per_side.x.step_plans[0].sub_plans[0].plan.type
            == "unsharder"
        )

    def test_comparison_record_without_traced_plan(self) -> None:
        record: ComparisonTensorRecord = _make_comparison_record(
            diff=_make_diff_info(passed=True),
        )
        roundtripped = parse_record_json(record.model_dump_json())
        assert isinstance(roundtripped, ComparisonTensorRecord)
        assert roundtripped.traced_plan is None

    def test_traced_plan_text_format(self) -> None:
        text: str = self._record_with_traced_plan().to_text()
        assert "Aligner Plan:" in text
        assert "unsharder" in text
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly and propagate pytest's exit code.
    exit_code = pytest.main([__file__])
    sys.exit(exit_code)
|
|