| import torch | |
def es_fp8_blockwise_scaled_grouped_mm(
    output,
    a,
    b,
    scales_a,
    scales_b,
    stride_a,
    stride_b,
    stride_d,
    problem_sizes,
    expert_offsets,
    workspace,
) -> None:
    """Dispatch to the ``sgl_kernel`` op ``es_fp8_blockwise_scaled_grouped_mm``.

    All eleven arguments are forwarded positionally, in the order they are
    received, to the ``default`` overload of the op registered under
    ``torch.ops.sgl_kernel``. This wrapper performs no validation or
    conversion of its own.

    Returns:
        None. Whatever the op itself returns is discarded.

    NOTE(review): presumably the op writes its result into ``output`` in
    place -- confirm against the kernel's schema.
    """
    # Resolve the op overload once, then invoke it with the untouched arguments.
    kernel = torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default
    kernel(
        output,
        a,
        b,
        scales_a,
        scales_b,
        stride_a,
        stride_b,
        stride_d,
        problem_sizes,
        expert_offsets,
        workspace,
    )
def es_sm100_mxfp8_blockscaled_grouped_mm(
    output,
    a,
    b,
    sfa,
    sfb,
    problem_sizes,
    expert_offsets,
    blockscale_offsets,
) -> None:
    """Dispatch to the ``sgl_kernel`` op ``es_sm100_mxfp8_blockscaled_grouped_mm``.

    The arguments are passed positionally to the ``default`` overload of the
    op registered under ``torch.ops.sgl_kernel``, but NOT in the order of
    this signature: ``output`` comes first here and is handed to the op as
    its fifth argument, after ``a``, ``b``, ``sfa`` and ``sfb``.

    Returns:
        None. Whatever the op itself returns is discarded.

    NOTE(review): the reordering is assumed to mirror the op's registered
    schema, and the op presumably writes into ``output`` in place --
    confirm both against the kernel.
    """
    kernel = torch.ops.sgl_kernel.es_sm100_mxfp8_blockscaled_grouped_mm.default
    # Op-side order: inputs and their scale factors first, then the output.
    kernel(
        a,
        b,
        sfa,
        sfb,
        output,
        problem_sizes,
        expert_offsets,
        blockscale_offsets,
    )
def es_sm100_mxfp8_blockscaled_grouped_quant(
    input,
    problem_sizes,
    expert_offsets,
    blockscale_offsets,
    quant_output,
    scale_factor,
) -> None:
    """Dispatch to the ``sgl_kernel`` op ``es_sm100_mxfp8_blockscaled_grouped_quant``.

    All six arguments are forwarded positionally, in the order they are
    received, to the ``default`` overload of the op registered under
    ``torch.ops.sgl_kernel``. This wrapper performs no validation or
    conversion of its own.

    Returns:
        None. Whatever the op itself returns is discarded.

    NOTE(review): presumably the op fills ``quant_output`` and
    ``scale_factor`` in place -- confirm against the kernel's schema.
    """
    # Resolve the op overload once, then invoke it with the untouched arguments.
    kernel = torch.ops.sgl_kernel.es_sm100_mxfp8_blockscaled_grouped_quant.default
    kernel(
        input,
        problem_sizes,
        expert_offsets,
        blockscale_offsets,
        quant_output,
        scale_factor,
    )