import torch
def es_fp8_blockwise_scaled_grouped_mm(
    output,
    a,
    b,
    scales_a,
    scales_b,
    stride_a,
    stride_b,
    stride_d,
    problem_sizes,
    expert_offsets,
    workspace,
):
    """Dispatch the FP8 blockwise-scaled grouped GEMM custom op.

    Thin pass-through to ``torch.ops.sgl_kernel`` — all arguments are
    forwarded positionally, in the order given, to the compiled kernel.
    Presumably results are written into ``output`` in place (the op
    returns nothing here) — confirm against the kernel's schema.
    """
    # Bind the op once, then forward every argument unchanged.
    kernel = torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default
    kernel(
        output,
        a,
        b,
        scales_a,
        scales_b,
        stride_a,
        stride_b,
        stride_d,
        problem_sizes,
        expert_offsets,
        workspace,
    )
def es_sm100_mxfp8_blockscaled_grouped_mm(
    output, a, b, sfa, sfb, problem_sizes, expert_offsets, blockscale_offsets
):
    """Dispatch the SM100 MXFP8 block-scaled grouped GEMM custom op.

    NOTE(review): the underlying op takes ``output`` *after* the scale
    factors, while this wrapper exposes it first — the reordering below
    is intentional and must be kept in sync with the kernel's schema.
    Presumably the op writes its result into ``output`` in place.
    """
    kernel = torch.ops.sgl_kernel.es_sm100_mxfp8_blockscaled_grouped_mm.default
    kernel(a, b, sfa, sfb, output, problem_sizes, expert_offsets, blockscale_offsets)
def es_sm100_mxfp8_blockscaled_grouped_quant(
    input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, scale_factor
):
    """Dispatch the SM100 MXFP8 block-scaled grouped quantization custom op.

    All arguments are forwarded positionally, in wrapper order, to the
    compiled kernel. Presumably quantized values land in ``quant_output``
    and scales in ``scale_factor`` (in-place outputs) — confirm against
    the kernel's schema.
    """
    # `input` shadows the builtin but is part of the public signature;
    # renaming it would break keyword callers.
    kernel = torch.ops.sgl_kernel.es_sm100_mxfp8_blockscaled_grouped_quant.default
    kernel(
        input,
        problem_sizes,
        expert_offsets,
        blockscale_offsets,
        quant_output,
        scale_factor,
    )
|