| """Shared test helpers for comparator tests.""" |
|
|
| from __future__ import annotations |
|
|
| from typing import Optional |
|
|
| from sglang.test.ci.ci_register import register_cpu_ci |
|
|
# Register this module with the CPU CI registry. The `disabled` reason string
# records that this is a helper module with no tests of its own.
# NOTE(review): this call sits before the project imports below; keep it at the
# top of the module in case CI tooling scans for it — confirm before moving.
register_cpu_ci(
    est_time=0, suite="default", nightly=True, disabled="helper module, no tests"
)
|
|
| from sglang.srt.debug_utils.comparator.tensor_comparator.types import ( |
| DiffInfo, |
| TensorInfo, |
| TensorStats, |
| ) |
|
|
# Percentile ranks used by both default fixtures below (insertion order matters
# for anything that iterates the dicts, so keep it ascending).
_PERCENTILE_RANKS = (1, 5, 50, 95, 99)

# Default percentile rank -> value mapping for `TensorStats` fixtures; the
# values stay inside the default min/max of -2.0 / 2.0 used by `make_stats`.
DEFAULT_PERCENTILES: dict[int, float] = dict(
    zip(_PERCENTILE_RANKS, (-1.8, -1.5, 0.0, 1.5, 1.8))
)

# Default percentile rank -> absolute-difference mapping for `DiffInfo`
# fixtures; the largest entry equals the default `max_abs_diff` of `make_diff`.
DEFAULT_ABS_DIFF_PERCENTILES: dict[int, float] = dict(
    zip(_PERCENTILE_RANKS, (0.0001, 0.0001, 0.0002, 0.0004, 0.0005))
)
|
|
|
|
def make_stats(
    mean: float = 0.0,
    abs_mean: float = 0.8,
    std: float = 1.0,
    min: float = -2.0,
    max: float = 2.0,
    percentiles: Optional[dict[int, float]] = None,
) -> TensorStats:
    """Build a ``TensorStats`` fixture; every field can be overridden.

    ``min`` and ``max`` shadow the builtins inside this function. The names are
    kept because callers pass them as keyword arguments.

    Args:
        mean: Value for ``TensorStats.mean``.
        abs_mean: Value for ``TensorStats.abs_mean``.
        std: Value for ``TensorStats.std``.
        min: Value for ``TensorStats.min``.
        max: Value for ``TensorStats.max``.
        percentiles: Percentile rank -> value mapping. When ``None``, a fresh
            copy of ``DEFAULT_PERCENTILES`` is used. A caller-supplied dict is
            passed through as-is.

    Returns:
        A new ``TensorStats`` instance.
    """
    if percentiles is None:
        # Copy so that a test mutating one fixture's percentiles cannot
        # silently change the shared module-level default for later tests.
        percentiles = dict(DEFAULT_PERCENTILES)
    return TensorStats(
        mean=mean,
        abs_mean=abs_mean,
        std=std,
        min=min,
        max=max,
        percentiles=percentiles,
    )
|
|
|
|
def make_diff(
    rel_diff: float = 0.0001,
    max_abs_diff: float = 0.0005,
    mean_abs_diff: float = 0.0002,
    abs_diff_percentiles: Optional[dict[int, float]] = None,
    diff_threshold: float = 1e-3,
    passed: bool = True,
    max_diff_coord: Optional[list[int]] = None,
    baseline_at_max: float = 1.0,
    target_at_max: float = 1.0005,
) -> DiffInfo:
    """Build a ``DiffInfo`` fixture; every field can be overridden.

    The defaults describe a small, passing difference:
    ``target_at_max - baseline_at_max`` (1.0005 - 1.0) matches the default
    ``max_abs_diff`` of 0.0005. If you override ``max_abs_diff``, you may want
    to override ``baseline_at_max`` / ``target_at_max`` too so the fixture
    stays self-consistent.

    Args:
        rel_diff: Value for ``DiffInfo.rel_diff``.
        max_abs_diff: Value for ``DiffInfo.max_abs_diff``.
        mean_abs_diff: Value for ``DiffInfo.mean_abs_diff``.
        abs_diff_percentiles: Percentile rank -> absolute diff mapping. When
            ``None``, a fresh copy of ``DEFAULT_ABS_DIFF_PERCENTILES`` is used.
        diff_threshold: Value for ``DiffInfo.diff_threshold``.
        passed: Value for ``DiffInfo.passed``.
        max_diff_coord: Coordinate of the largest difference; defaults to a new
            ``[2, 3]`` list on each call.
        baseline_at_max: Baseline value at ``max_diff_coord``.
        target_at_max: Target value at ``max_diff_coord``.

    Returns:
        A new ``DiffInfo`` instance.
    """
    if abs_diff_percentiles is None:
        # Copy so mutations on one fixture never leak into the shared default.
        abs_diff_percentiles = dict(DEFAULT_ABS_DIFF_PERCENTILES)
    if max_diff_coord is None:
        max_diff_coord = [2, 3]
    return DiffInfo(
        rel_diff=rel_diff,
        max_abs_diff=max_abs_diff,
        mean_abs_diff=mean_abs_diff,
        abs_diff_percentiles=abs_diff_percentiles,
        max_diff_coord=max_diff_coord,
        baseline_at_max=baseline_at_max,
        target_at_max=target_at_max,
        diff_threshold=diff_threshold,
        passed=passed,
    )
|
|
|
|
def make_tensor_info(
    shape: Optional[list[int]] = None,
    dtype: str = "torch.float32",
    stats: Optional[TensorStats] = None,
    sample: Optional[str] = None,
) -> TensorInfo:
    """Build a ``TensorInfo`` fixture with overridable fields.

    Args:
        shape: Tensor shape; a new ``[4, 8]`` list is used when ``None``.
        dtype: Dtype string stored on the fixture.
        stats: Stats to attach; ``make_stats()`` is called when ``None``.
        sample: Optional sample string, passed through unchanged.

    Returns:
        A new ``TensorInfo`` instance.
    """
    # Resolve defaults one at a time, in the same order as the fields below.
    if shape is None:
        shape = [4, 8]
    if stats is None:
        stats = make_stats()
    return TensorInfo(shape=shape, dtype=dtype, stats=stats, sample=sample)
|
|