File size: 1,118 Bytes
d02d576
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch


def es_fp8_blockwise_scaled_grouped_mm(
    output,
    a,
    b,
    scales_a,
    scales_b,
    stride_a,
    stride_b,
    stride_d,
    problem_sizes,
    expert_offsets,
    workspace,
):
    """FP8 blockwise-scaled grouped matmul via the ``sgl_kernel`` custom op.

    Thin wrapper: forwards every argument, in signature order, to
    ``torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default``.
    ``output`` is presumably written in place by the kernel — nothing is
    returned here (TODO confirm against the op's schema).
    """
    kernel = torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default
    kernel(
        output,
        a,
        b,
        scales_a,
        scales_b,
        stride_a,
        stride_b,
        stride_d,
        problem_sizes,
        expert_offsets,
        workspace,
    )


def es_sm100_mxfp8_blockscaled_grouped_mm(
    output, a, b, sfa, sfb, problem_sizes, expert_offsets, blockscale_offsets
):
    """SM100 MXFP8 block-scaled grouped matmul via the ``sgl_kernel`` custom op.

    NOTE: the underlying op's positional order differs from this wrapper's
    signature — the kernel receives ``(a, b, sfa, sfb, output,
    problem_sizes, expert_offsets, blockscale_offsets)``, with ``output``
    fifth rather than first.  ``output`` is presumably filled in place and
    nothing is returned — TODO confirm against the op's schema.
    """
    kernel = torch.ops.sgl_kernel.es_sm100_mxfp8_blockscaled_grouped_mm.default
    kernel(a, b, sfa, sfb, output, problem_sizes, expert_offsets, blockscale_offsets)


def es_sm100_mxfp8_blockscaled_grouped_quant(
    input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, scale_factor
):
    """SM100 MXFP8 block-scaled grouped quantization via the ``sgl_kernel`` op.

    Thin wrapper: forwards every argument, in signature order, to
    ``torch.ops.sgl_kernel.es_sm100_mxfp8_blockscaled_grouped_quant.default``.
    The name ``input`` shadows the builtin, but it is part of the public
    keyword interface and must stay.  ``quant_output`` and ``scale_factor``
    are presumably written in place — TODO confirm against the op's schema.
    """
    kernel = torch.ops.sgl_kernel.es_sm100_mxfp8_blockscaled_grouped_quant.default
    kernel(
        input,
        problem_sizes,
        expert_offsets,
        blockscale_offsets,
        quant_output,
        scale_factor,
    )