# Tests: sglang/test/registered/debug_utils/comparator/test_model_validation.py
# (Added via upload-large-folder tool; revision a402b9b, verified.)
import json
import sys
import pytest
from pydantic import ValidationError
from sglang.srt.debug_utils.comparator.aligner.entrypoint.traced_types import (
TracedAlignerPlan,
TracedSidePlan,
TracedStepPlan,
TracedSubPlan,
)
from sglang.srt.debug_utils.comparator.aligner.entrypoint.types import (
AlignerPerStepPlan,
AlignerPlan,
)
from sglang.srt.debug_utils.comparator.aligner.token_aligner.smart.types import (
PositionalSeqId,
TokenAlignerPlan,
TokenAlignerSeqInfo,
TokenAlignerStepAux,
TokenLocator,
)
from sglang.srt.debug_utils.comparator.aligner.unsharder.types import (
AxisInfo,
ConcatParams,
UnsharderPlan,
)
from sglang.srt.debug_utils.comparator.dims_spec import ParallelAxis, TokenLayout
from sglang.srt.debug_utils.comparator.output_types import (
ComparisonErrorRecord,
ComparisonNonTensorRecord,
ComparisonSkipRecord,
ComparisonTensorRecord,
ErrorLog,
SummaryRecord,
parse_record_json,
)
from sglang.srt.debug_utils.comparator.tensor_comparator.types import (
DiffInfo,
TensorInfo,
TensorStats,
)
from sglang.srt.debug_utils.comparator.utils import Pair, _check_equal_lengths
from sglang.test.ci.ci_register import register_cpu_ci
# Register this module with the CPU CI harness: ~10s estimated runtime,
# part of the "default" suite, also included in nightly runs.
register_cpu_ci(est_time=10, suite="default", nightly=True)
class TestCheckEqualLengths:
    """Tests for the `_check_equal_lengths` keyword-argument length guard."""

    def test_all_equal(self):
        # Matching lengths: the guard returns without raising.
        left, right = [1, 2], [3, 4]
        _check_equal_lengths(a=left, b=right)

    def test_empty_lists(self):
        # Two empty sequences have equal length; no error expected.
        _check_equal_lengths(a=[], b=[])

    def test_mismatch_raises(self):
        # A 2-vs-1 mismatch must raise ValueError mentioning "Length mismatch".
        with pytest.raises(ValueError, match="Length mismatch"):
            _check_equal_lengths(a=[1, 2], b=[3])
class TestTokenAlignerStepAux:
    """Cross-field validation of `TokenAlignerStepAux`.

    The cases below probe three invariants implied by the failure modes:
    per-token lists (input_ids, positions) must share one length, per-sequence
    lists (seq_lens, seq_ids) must share another, and sum(seq_lens) must equal
    the token count.
    """

    def test_valid(self):
        # 3 tokens across 2 sequences (seq_lens 2 + 1 == 3): accepted.
        aux = TokenAlignerStepAux(
            input_ids=[10, 20, 30],
            positions=[0, 1, 2],
            seq_lens=[2, 1],
            seq_ids=[
                PositionalSeqId(step=0, seq_index=0),
                PositionalSeqId(step=0, seq_index=1),
            ],
        )
        assert len(aux.input_ids) == 3

    def test_token_length_mismatch(self):
        # positions (len 2) disagrees with input_ids (len 3).
        with pytest.raises(ValueError, match="Length mismatch"):
            TokenAlignerStepAux(
                input_ids=[10, 20, 30],
                positions=[0, 1],
                seq_lens=[2, 1],
                seq_ids=[
                    PositionalSeqId(step=0, seq_index=0),
                    PositionalSeqId(step=0, seq_index=1),
                ],
            )

    def test_seq_length_mismatch(self):
        # seq_ids (len 1) disagrees with seq_lens (len 2).
        with pytest.raises(ValueError, match="Length mismatch"):
            TokenAlignerStepAux(
                input_ids=[10, 20, 30],
                positions=[0, 1, 2],
                seq_lens=[2, 1],
                seq_ids=[PositionalSeqId(step=0, seq_index=0)],
            )

    def test_sum_seq_lens_mismatch(self):
        # sum(seq_lens) == 2 but there are 3 tokens; error names sum(seq_lens).
        with pytest.raises(ValueError, match="sum\\(seq_lens\\)"):
            TokenAlignerStepAux(
                input_ids=[10, 20, 30],
                positions=[0, 1, 2],
                seq_lens=[1, 1],
                seq_ids=[
                    PositionalSeqId(step=0, seq_index=0),
                    PositionalSeqId(step=0, seq_index=1),
                ],
            )
class TestTokenAlignerSeqInfo:
    """Validation of `TokenAlignerSeqInfo` per-sequence token bookkeeping."""

    def test_valid(self):
        # input_ids, positions, and both locator lists all have length 3,
        # and positions run 0, 1, 2: accepted.
        info = TokenAlignerSeqInfo(
            input_ids=[10, 20, 30],
            positions=[0, 1, 2],
            locator=TokenLocator(steps=[0, 0, 1], token_index_in_step=[0, 1, 0]),
        )
        assert len(info.input_ids) == 3

    def test_length_mismatch(self):
        # locator.steps has only 2 entries for 3 tokens -> rejected.
        with pytest.raises(ValidationError):
            TokenAlignerSeqInfo(
                input_ids=[10, 20, 30],
                positions=[0, 1, 2],
                locator=TokenLocator(steps=[0, 0], token_index_in_step=[0, 1, 0]),
            )

    def test_positions_not_sequential(self):
        # [0, 2, 1] violates the expected position ordering (validator message
        # starts with "positions must be"); [0, 1, 2] above is accepted.
        with pytest.raises(ValidationError, match="positions must be"):
            TokenAlignerSeqInfo(
                input_ids=[10, 20, 30],
                positions=[0, 2, 1],
                locator=TokenLocator(steps=[0, 0, 1], token_index_in_step=[0, 1, 0]),
            )
class TestTokenAlignerPlan:
    """Validation of `TokenAlignerPlan` pairing the x/y token locators."""

    def test_valid(self):
        # Both sides locate 3 tokens; layouts are TokenLayout.T on each side.
        plan = TokenAlignerPlan(
            locators=Pair(
                x=TokenLocator(steps=[0, 0, 1], token_index_in_step=[0, 1, 0]),
                y=TokenLocator(steps=[0, 1, 1], token_index_in_step=[0, 0, 1]),
            ),
            layouts=Pair(x=TokenLayout.T, y=TokenLayout.T),
        )
        assert len(plan.locators.x.steps) == 3

    def test_length_mismatch(self):
        # x side locates 2 tokens vs y's 3 -> cross-side length guard fires.
        with pytest.raises(ValidationError, match="Length mismatch"):
            TokenAlignerPlan(
                locators=Pair(
                    x=TokenLocator(steps=[0, 0], token_index_in_step=[0, 1]),
                    y=TokenLocator(steps=[0, 1, 1], token_index_in_step=[0, 0, 1]),
                ),
                layouts=Pair(x=TokenLayout.T, y=TokenLayout.T),
            )
class TestSummaryRecord:
    """Tests for the `SummaryRecord` total-consistency validator."""

    def test_valid(self):
        # 7 + 2 + 1 == 10: construction succeeds.
        rec = SummaryRecord(total=10, passed=7, failed=2, skipped=1)
        assert rec.total == 10

    def test_total_mismatch(self):
        # 5 + 2 + 1 != 10: rejected with a message that includes the total.
        with pytest.raises(ValidationError, match="total=10"):
            SummaryRecord(total=10, passed=5, failed=2, skipped=1)

    def test_valid_with_errored(self):
        # errored counts toward the total: 6 + 2 + 1 + 1 == 10.
        rec = SummaryRecord(total=10, passed=6, failed=2, skipped=1, errored=1)
        assert rec.errored == 1

    def test_total_mismatch_with_errored(self):
        # Explicit errored=0 leaves the component sum at 9 != 10.
        with pytest.raises(ValidationError, match="total=10"):
            SummaryRecord(total=10, passed=6, failed=2, skipped=1, errored=0)
class TestAxisInfo:
    """Boundary-condition tests for `AxisInfo` rank/size validation."""

    def test_valid(self):
        # Rank 0 of a size-4 axis is the smallest legal rank.
        assert AxisInfo(axis_rank=0, axis_size=4).axis_rank == 0

    def test_axis_size_zero(self):
        # A size of zero is rejected.
        with pytest.raises(ValidationError, match="axis_size must be > 0"):
            AxisInfo(axis_rank=0, axis_size=0)

    def test_axis_size_negative(self):
        # Negative sizes are rejected with the same message.
        with pytest.raises(ValidationError, match="axis_size must be > 0"):
            AxisInfo(axis_rank=0, axis_size=-1)

    def test_axis_rank_negative(self):
        # Ranks below zero fall outside the valid range.
        with pytest.raises(ValidationError, match="axis_rank must be in"):
            AxisInfo(axis_rank=-1, axis_size=4)

    def test_axis_rank_too_large(self):
        # rank == size is one past the end of the valid range.
        with pytest.raises(ValidationError, match="axis_rank must be in"):
            AxisInfo(axis_rank=4, axis_size=4)

    def test_axis_rank_equals_size_minus_one(self):
        # The largest legal rank is axis_size - 1.
        assert AxisInfo(axis_rank=3, axis_size=4).axis_rank == 3
def _make_tensor_info() -> TensorInfo:
    """Build a small 4x4 float32 `TensorInfo` fixture with fixed statistics."""
    stats = TensorStats(mean=0.0, abs_mean=0.8, std=1.0, min=-2.0, max=2.0)
    return TensorInfo(shape=[4, 4], dtype="float32", stats=stats)
def _make_diff_info(*, passed: bool) -> DiffInfo:
    """Build a `DiffInfo` fixture with small fixed diff metrics.

    Args:
        passed: verdict to embed in the record (keyword-only).
    """
    metrics = {
        "rel_diff": 0.001,
        "max_abs_diff": 0.01,
        "mean_abs_diff": 0.005,
        "max_diff_coord": [0, 0],
        "baseline_at_max": 1.0,
        "target_at_max": 1.01,
        "diff_threshold": 1e-3,
    }
    return DiffInfo(passed=passed, **metrics)
def _make_comparison_record(
    *,
    diff: DiffInfo | None,
    errors: list | None = None,
) -> ComparisonTensorRecord:
    """Build a `ComparisonTensorRecord` fixture named "t".

    Baseline and target share one `TensorInfo`; shapes match by construction.

    Args:
        diff: diff info to attach, or None (keyword-only).
        errors: error logs to attach; defaults to an empty list.
    """
    if errors is None:
        errors = []
    shared_info: TensorInfo = _make_tensor_info()
    return ComparisonTensorRecord(
        name="t",
        baseline=shared_info,
        target=shared_info,
        unified_shape=[4, 4],
        shape_mismatch=False,
        diff=diff,
        errors=errors,
    )
class TestOutputRecordCategories:
    """Category resolution and serialization of comparator output records.

    Covers skip / tensor / non-tensor / error record types. The pattern shown
    by these cases: any record carrying entries in `errors` reports category
    "failed" regardless of its own outcome; otherwise the category follows
    the record type and result.
    """

    def test_skip_record_with_errors_is_failed(self) -> None:
        # Errors override the skip outcome.
        record = ComparisonSkipRecord(
            name="t",
            reason="test",
            errors=[ErrorLog(category="c", message="m")],
        )
        assert record.category == "failed"

    def test_skip_record_no_warnings_is_skipped(self) -> None:
        record = ComparisonSkipRecord(name="t", reason="test")
        assert record.category == "skipped"

    def test_comparison_record_diff_none_is_failed(self) -> None:
        # A tensor record with no diff info cannot have passed.
        record: ComparisonTensorRecord = _make_comparison_record(diff=None)
        assert record.category == "failed"

    def test_comparison_record_passed_with_errors_is_failed(self) -> None:
        # Errors override even a passing diff.
        record: ComparisonTensorRecord = _make_comparison_record(
            diff=_make_diff_info(passed=True),
            errors=[ErrorLog(category="c", message="m")],
        )
        assert record.category == "failed"

    def test_comparison_record_passed_no_warnings_is_passed(self) -> None:
        record: ComparisonTensorRecord = _make_comparison_record(
            diff=_make_diff_info(passed=True),
        )
        assert record.category == "passed"

    def test_non_tensor_record_equal_is_passed(self) -> None:
        record = ComparisonNonTensorRecord(
            name="sm_scale",
            baseline_value="0.125",
            target_value="0.125",
            baseline_type="float",
            target_type="float",
            values_equal=True,
        )
        assert record.category == "passed"

    def test_non_tensor_record_different_is_failed(self) -> None:
        record = ComparisonNonTensorRecord(
            name="sm_scale",
            baseline_value="0.125",
            target_value="0.25",
            baseline_type="float",
            target_type="float",
            values_equal=False,
        )
        assert record.category == "failed"

    def test_non_tensor_record_with_errors_is_failed(self) -> None:
        # values_equal=True, but attached errors still force "failed".
        record = ComparisonNonTensorRecord(
            name="sm_scale",
            baseline_value="0.125",
            target_value="0.125",
            baseline_type="float",
            target_type="float",
            values_equal=True,
            errors=[ErrorLog(category="c", message="m")],
        )
        assert record.category == "failed"

    def test_non_tensor_record_json_roundtrip(self) -> None:
        # model_dump_json -> parse_record_json must reconstruct the same
        # record subtype with all fields intact.
        record = ComparisonNonTensorRecord(
            name="sm_scale",
            baseline_value="0.125",
            target_value="0.25",
            baseline_type="float",
            target_type="float",
            values_equal=False,
        )
        json_str: str = record.model_dump_json()
        roundtripped = parse_record_json(json_str)
        assert isinstance(roundtripped, ComparisonNonTensorRecord)
        assert roundtripped.name == "sm_scale"
        assert roundtripped.values_equal is False
        assert roundtripped.baseline_value == "0.125"
        assert roundtripped.target_value == "0.25"

    def test_non_tensor_record_text_format_equal(self) -> None:
        # Equal values render the name plus an "[equal]" marker.
        record = ComparisonNonTensorRecord(
            name="sm_scale",
            baseline_value="0.125",
            target_value="0.125",
            baseline_type="float",
            target_type="float",
            values_equal=True,
        )
        text: str = record.to_text()
        assert "sm_scale" in text
        assert "[equal]" in text

    def test_non_tensor_record_text_format_different(self) -> None:
        # Differing values render both sides explicitly.
        record = ComparisonNonTensorRecord(
            name="sm_scale",
            baseline_value="0.125",
            target_value="0.25",
            baseline_type="float",
            target_type="float",
            values_equal=False,
        )
        text: str = record.to_text()
        assert "baseline" in text
        assert "target" in text

    def test_error_record_category_is_errored(self) -> None:
        record = ComparisonErrorRecord(
            name="t", exception_type="ValueError", traceback_str="..."
        )
        assert record.category == "errored"

    def test_error_record_json_roundtrip(self) -> None:
        record = ComparisonErrorRecord(
            name="t", exception_type="ValueError", traceback_str="traceback..."
        )
        json_str: str = record.model_dump_json()
        roundtripped = parse_record_json(json_str)
        assert isinstance(roundtripped, ComparisonErrorRecord)
        assert roundtripped.name == "t"
        assert roundtripped.exception_type == "ValueError"

    def test_error_record_text_format(self) -> None:
        # The text form surfaces the exception type and the traceback body.
        record = ComparisonErrorRecord(
            name="t", exception_type="RuntimeError", traceback_str="Traceback..."
        )
        text: str = record.to_text()
        assert "RuntimeError" in text
        assert "Traceback" in text
def _make_traced_aligner_plan() -> TracedAlignerPlan:
    """Build a minimal `TracedAlignerPlan` fixture.

    A single TP concat unsharder (group [0, 1]) is reused as the sole sub-plan
    of one step-0 plan on both the x and y sides; the traced sub-plan carries
    no snapshot.
    """
    unsharder = UnsharderPlan(
        axis=ParallelAxis.TP,
        params=ConcatParams(dim_name="h"),
        groups=[[0, 1]],
    )
    # Untraced plan: identical step-0 plans on both sides of the Pair.
    plan = AlignerPlan(
        per_step_plans=Pair(
            x=[
                AlignerPerStepPlan(
                    step=0, input_object_indices=[0, 1], sub_plans=[unsharder]
                )
            ],
            y=[
                AlignerPerStepPlan(
                    step=0, input_object_indices=[0, 1], sub_plans=[unsharder]
                )
            ],
        ),
    )
    # Traced counterpart wrapping the same unsharder (snapshot omitted).
    traced_sub = TracedSubPlan(plan=unsharder, snapshot=None)
    traced_step = TracedStepPlan(
        step=0, input_object_indices=[0, 1], sub_plans=[traced_sub]
    )
    return TracedAlignerPlan(
        plan=plan,
        per_side=Pair(
            x=TracedSidePlan(step_plans=[traced_step]),
            y=TracedSidePlan(step_plans=[traced_step]),
        ),
    )
class TestAlignerPlanInComparisonTensorRecord:
    """Attaching a `TracedAlignerPlan` to `ComparisonTensorRecord`.

    The plan is attached via `model_copy(update=...)` rather than at
    construction; tests cover attachment, JSON round-tripping (including the
    discriminator "type" field of the sub-plan), the default None case, and
    the text rendering.
    """

    def test_comparison_record_with_traced_plan(self) -> None:
        traced_plan: TracedAlignerPlan = _make_traced_aligner_plan()
        record: ComparisonTensorRecord = _make_comparison_record(
            diff=_make_diff_info(passed=True),
        )
        record_with_plan = record.model_copy(update={"traced_plan": traced_plan})
        assert record_with_plan.traced_plan is not None
        assert record_with_plan.traced_plan.per_side.x.step_plans[0].step == 0

    def test_traced_plan_json_roundtrip(self) -> None:
        traced_plan: TracedAlignerPlan = _make_traced_aligner_plan()
        record: ComparisonTensorRecord = _make_comparison_record(
            diff=_make_diff_info(passed=True),
        )
        record_with_plan = record.model_copy(update={"traced_plan": traced_plan})
        json_str: str = record_with_plan.model_dump_json()
        # Raw JSON view: the sub-plan serializes with its "unsharder" type tag.
        parsed = json.loads(json_str)
        assert "traced_plan" in parsed
        assert (
            parsed["traced_plan"]["per_side"]["x"]["step_plans"][0]["sub_plans"][0][
                "plan"
            ]["type"]
            == "unsharder"
        )
        # Parsed view: the same tag survives deserialization into models.
        roundtripped: ComparisonTensorRecord = parse_record_json(json_str)
        assert roundtripped.traced_plan is not None
        assert (
            roundtripped.traced_plan.per_side.x.step_plans[0].sub_plans[0].plan.type
            == "unsharder"
        )

    def test_comparison_record_without_traced_plan(self) -> None:
        # Records built without a plan round-trip with traced_plan == None.
        record: ComparisonTensorRecord = _make_comparison_record(
            diff=_make_diff_info(passed=True),
        )
        json_str: str = record.model_dump_json()
        roundtripped: ComparisonTensorRecord = parse_record_json(json_str)
        assert roundtripped.traced_plan is None

    def test_traced_plan_text_format(self) -> None:
        # The text rendering includes a plan section and the sub-plan kind.
        traced_plan: TracedAlignerPlan = _make_traced_aligner_plan()
        record: ComparisonTensorRecord = _make_comparison_record(
            diff=_make_diff_info(passed=True),
        )
        record_with_plan = record.model_copy(update={"traced_plan": traced_plan})
        text: str = record_with_plan.to_text()
        assert "Aligner Plan:" in text
        assert "unsharder" in text
# Allow running this test module directly; propagate pytest's exit status.
if __name__ == "__main__":
    raise SystemExit(pytest.main([__file__]))