File size: 6,984 Bytes
a402b9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
# Benchmarks SGLang kernels versus vLLM across
# (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
import argparse
import itertools
import os
import re
from typing import List, Tuple

import sgl_kernel
import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

# vLLM provides the reference kernels but is optional: when it cannot be
# imported, vllm_ops is None and VLLM_AVAILABLE is False.
try:
    from vllm import _custom_ops as vllm_ops
except ImportError:
    vllm_ops = None
    VLLM_AVAILABLE = False
else:
    VLLM_AVAILABLE = True

# The run counts as CI when either the CI or the GITHUB_ACTIONS environment
# variable is set to "true" (compared case-insensitively).
IS_CI = any(
    os.getenv(env_name, "false").lower() == "true"
    for env_name in ("CI", "GITHUB_ACTIONS")
)

# gelu_quick is only available on HIP/ROCm platforms
try:
    from sgl_kernel import gelu_quick
except ImportError:
    gelu_quick = None
    GELU_QUICK_AVAILABLE = False
else:
    GELU_QUICK_AVAILABLE = True

# If vllm._custom_ops does not expose silu_and_mul, resolve the reference
# kernels through torch.ops._C instead.
if VLLM_AVAILABLE and not hasattr(vllm_ops, "silu_and_mul"):
    vllm_ops = torch.ops._C


def str2int_list(arg: str) -> List[int]:
    """Parse a comma-separated list of non-negative integers.

    Intended as an ``argparse`` ``type=`` callable. Surrounding whitespace and
    whitespace around the commas are ignored, so ``"1,2"``, ``" 1,2 "`` and
    ``"1, 2"`` all parse to ``[1, 2]``.

    Args:
        arg: Text such as ``"1,4,16"``. ``None``, the empty string and a
            whitespace-only string all yield an empty list.

    Returns:
        The parsed integers in their original order.

    Raises:
        argparse.ArgumentTypeError: If the text is not a comma-separated list
            of decimal digit groups (e.g. ``"1,,2"``, ``"a,b"``, ``"-1"``).
    """
    if arg is None:
        return []
    # Normalise once so validation and splitting see the same text.
    text = arg.strip()
    if not text:
        return []
    if re.fullmatch(r"\d+(\s*,\s*\d+)*", text) is None:
        raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
    # int() tolerates the whitespace left around each token by split(",").
    return [int(token) for token in text.split(",")]


def calculate_diff(
    kernel: str, dtype: torch.dtype, batch_size: int, seq_len: int, dim: int
) -> bool:
    """Check one shape of an SGLang kernel against the vLLM reference.

    Prints a one-line status for the configuration and returns ``True`` when
    both outputs agree under ``torch.allclose(rtol=1e-3, atol=1e-5)``. The
    comparison is skipped (and ``True`` is returned) when vLLM is not
    available, or when ``gelu_quick`` is requested but could not be imported.
    """
    device = torch.device("cuda")

    def report(message: str) -> None:
        # All status lines share the same configuration prefix.
        print(
            f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
            f"L={seq_len:3d} | D={dim:5d}] {message}"
        )

    if not VLLM_AVAILABLE:
        report("⚠️  vLLM not available, skipping comparison")
        return True

    is_quick_gelu = kernel == "gelu_quick"
    if is_quick_gelu and not GELU_QUICK_AVAILABLE:
        report("⚠️  not available on this platform")
        return True

    # gelu_quick is activation-only, so input and output share the last
    # dimension; the fused activation x mul kernels take 2 * dim and emit dim.
    in_width = dim if is_quick_gelu else 2 * dim
    x = torch.randn(batch_size, seq_len, in_width, dtype=dtype, device=device)
    ref_out = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
    # The vLLM op is called with a preallocated output buffer as first argument.
    getattr(vllm_ops, kernel)(ref_out, x)
    test_out = getattr(sgl_kernel, kernel)(x)

    matched = torch.allclose(ref_out, test_out, rtol=1e-3, atol=1e-5)
    report("✅ match" if matched else "❌ mismatch")
    return matched


# Start from the minimal CI matrix (one kernel, one dtype) and widen it for
# full, non-CI runs.
kernels = ["silu_and_mul"]
dtypes = [torch.float16]
if not IS_CI:
    kernels += ["gelu_and_mul", "gelu_tanh_and_mul"]
    if GELU_QUICK_AVAILABLE:
        kernels.append("gelu_quick")
    dtypes.append(torch.bfloat16)


def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
    """Build the benchmark grid as ``(kernel, dtype, batch, seq_len, dim)`` tuples.

    The module-level ``kernels`` and ``dtypes`` lists are crossed with the
    given shape lists; the rightmost field (``dim``) varies fastest.
    """
    return [
        (kernel, dtype, bsize, slen, dim)
        for kernel in kernels
        for dtype in dtypes
        for bsize in bsizes
        for slen in slens
        for dim in dims_
    ]


# Default shape sweep; CI collapses every axis to a single value to stay fast.
if IS_CI:
    default_batch_sizes, default_seq_lens, default_dims = [1], [1], [1024]
else:
    default_batch_sizes = [1 << shift for shift in range(0, 5, 2)]  # 1, 4, 16
    default_seq_lens = [1 << shift for shift in range(0, 8, 2)]  # 1, 4, 16, 64
    default_dims = [1 << shift for shift in range(10, 15)]  # 1024 ... 16384


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["kernel", "dtype", "batch_size", "seq_len", "dim"],
        x_vals=[],
        line_arg="provider",
        line_vals=["vllm", "sglang", "speedup"],
        line_names=["vLLM", "SGL Kernel", "Speed-up (x)"],
        styles=[("blue", "-"), ("green", "-"), ("red", "--")],
        ylabel="µs (median)  or  × (speed-up)",
        plot_name="activation-performance",
        args={},
    )
)
def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
    """Time one (kernel, dtype, shape) configuration for a single provider.

    Args:
        kernel: Name of the activation kernel, looked up on ``sgl_kernel`` and
            (for the reference) on ``vllm_ops``.
        dtype: Torch dtype of the input and output tensors.
        batch_size, seq_len, dim: Shape parameters; the input's last dimension
            is ``dim`` for ``gelu_quick`` and ``2 * dim`` otherwise.
        provider: ``"vllm"``, ``"sglang"`` or ``"speedup"``.

    Returns:
        A 3-tuple ``(value, lower, upper)``. For ``"vllm"`` and ``"sglang"``
        these are the median, 20th and 80th percentile run times in
        microseconds. For ``"speedup"`` all three entries are the ratio of the
        vLLM median to the SGLang median. ``(0, 0, 0)`` is returned when the
        configuration is skipped because a required kernel is unavailable.
    """
    # Decide whether to skip before touching the GPU, so skipped
    # configurations allocate nothing.
    needs_vllm = provider in ("vllm", "speedup")
    if needs_vllm and not VLLM_AVAILABLE:
        # Skip vLLM-related benchmarks if vLLM is not available
        return (0, 0, 0)
    if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE:
        # Skip benchmark for gelu_quick if not available
        return (0, 0, 0)

    device = torch.device("cuda")
    in_mult = 1 if kernel == "gelu_quick" else 2
    x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
    y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)

    sglang_kernel = getattr(sgl_kernel, kernel)
    # Only resolve the reference kernel when this provider actually runs it.
    vllm_kernel = getattr(vllm_ops, kernel) if needs_vllm else None

    def baseline():
        # The vLLM op is called with an output buffer; use a fresh copy per call.
        tmp = y0.clone()
        vllm_kernel(tmp, x)
        return tmp

    def sglang():
        return sglang_kernel(x)

    # timing helper
    def timed(fn):
        # Warm up outside the measured region.
        for _ in range(5):
            fn()
        torch.cuda.synchronize()
        ms, q_low, q_high = triton.testing.do_bench_cudagraph(
            fn, quantiles=[0.5, 0.2, 0.8]
        )
        # perf_report unpacks the result as (y, y_min, y_max); for a time
        # metric the lower quantile therefore goes second. Convert ms -> µs.
        return 1000 * ms, 1000 * q_low, 1000 * q_high

    if provider == "vllm":
        return timed(baseline)
    if provider == "sglang":
        return timed(sglang)

    # provider == "speedup"
    t_ref, _, _ = timed(baseline)
    t_sgl, _, _ = timed(sglang)
    # Fall back to a neutral 1.0 rather than dividing by (or reporting) zero.
    spd = t_ref / t_sgl if t_ref > 0 and t_sgl > 0 else 1.0
    return (spd, spd, spd)


if __name__ == "__main__":
    # Command-line entry point: either run a single correctness check
    # (--verify_only) or run the full benchmark grid and print the results.
    p = argparse.ArgumentParser("Activation kernel benchmark")
    p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes)
    p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens)
    p.add_argument("--dims", type=str2int_list, default=default_dims)
    p.add_argument("--verify_only", action="store_true")
    args = p.parse_args()

    # Defensive coercion: if any list option is still a raw string, parse it.
    for list_arg in ("batch_sizes", "seq_lens", "dims"):
        raw_value = getattr(args, list_arg)
        if isinstance(raw_value, str):
            setattr(args, list_arg, str2int_list(raw_value))

    # patch perf_report grid
    benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.dims)
    # The decorated object exposes its Benchmark under either attribute name.
    bench_spec = (
        benchmark.benchmarks
        if hasattr(benchmark, "benchmarks")
        else benchmark.benchmark
    )
    bench_spec.x_vals = benchmark_grid

    if args.verify_only:
        # An empty --dims (e.g. --dims "") would otherwise fail with an
        # IndexError below; report it as a usage error instead.
        if not args.dims:
            p.error("--dims must contain at least one value with --verify_only")
        # Test with the first available kernel
        test_kernel = kernels[0]
        ok = calculate_diff(test_kernel, torch.float16, 1, 1, args.dims[0])
        print("✅ sanity pass" if ok else "❌ mismatch")
    else:
        benchmark.run(print_data=True)