"""Shared test helpers for comparator tests."""

from __future__ import annotations

from typing import Optional

from sglang.test.ci.ci_register import register_cpu_ci

register_cpu_ci(
    est_time=0, suite="default", nightly=True, disabled="helper module, no tests"
)

from sglang.srt.debug_utils.comparator.tensor_comparator.types import (
    DiffInfo,
    TensorInfo,
    TensorStats,
)

DEFAULT_PERCENTILES: dict[int, float] = {
    1: -1.8,
    5: -1.5,
    50: 0.0,
    95: 1.5,
    99: 1.8,
}

DEFAULT_ABS_DIFF_PERCENTILES: dict[int, float] = {
    1: 0.0001,
    5: 0.0001,
    50: 0.0002,
    95: 0.0004,
    99: 0.0005,
}


def make_stats(
    mean: float = 0.0,
    abs_mean: float = 0.8,
    std: float = 1.0,
    min: float = -2.0,
    max: float = 2.0,
    percentiles: Optional[dict[int, float]] = None,
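) -> TensorStats:
    """Build a ``TensorStats`` fixture for a roughly zero-mean, unit-std tensor.

    Every field can be overridden; ``percentiles`` defaults to
    ``DEFAULT_PERCENTILES`` when omitted.
    """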
    return TensorStats(
        mean=mean,
        abs_mean=abs_mean,
        std=std,
        min=min,
        max=max,
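        # Copy the default so a test that mutates the result cannot leak into
        # other tests through the shared module-level dict.
        percentiles=(
            percentiles if percentiles is not None else dict(DEFAULT_PERCENTILES)
        ),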
    )


def make_diff(
    rel_diff: float = 0.0001,
    max_abs_diff: float = 0.0005,
    mean_abs_diff: float = 0.0002,
    abs_diff_percentiles: Optional[dict[int, float]] = None,
    diff_threshold: float = 1e-3,
    passed: bool = True,
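) -> DiffInfo:
    """Build a ``DiffInfo`` fixture that describes a small, passing diff by default.

    ``max_diff_coord``, ``baseline_at_max`` and ``target_at_max`` are fixed.
    They agree with the default ``max_abs_diff`` (1.0005 - 1.0 = 0.0005) but
    are not recomputed when ``max_abs_diff`` is overridden.
    """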
    return DiffInfo(
        rel_diff=rel_diff,
        max_abs_diff=max_abs_diff,
        mean_abs_diff=mean_abs_diff,
        # Copy the default so mutations in one test cannot leak into others.
        abs_diff_percentiles=(
            abs_diff_percentiles
            if abs_diff_percentiles is not None
            else dict(DEFAULT_ABS_DIFF_PERCENTILES)
        ),
        max_diff_coord=[2, 3],
        baseline_at_max=1.0,
        target_at_max=1.0005,
        diff_threshold=diff_threshold,
        passed=passed,
    )


def make_tensor_info(
    shape: Optional[list[int]] = None,
    dtype: str = "torch.float32",
    stats: Optional[TensorStats] = None,
    sample: Optional[str] = None,
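) -> TensorInfo:
    """Build a ``TensorInfo`` fixture, by default a ``[4, 8]`` float32 tensor.

    When ``stats`` is omitted, ``make_stats()`` supplies the default statistics.
    """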
    return TensorInfo(
        shape=shape if shape is not None else [4, 8],
        dtype=dtype,
        stats=stats if stats is not None else make_stats(),
        sample=sample,
    )