File size: 1,121 Bytes
20347e1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | #include <torch/library.h>
#include "registration.h"
#include "torch_binding.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Operator schemas for this extension's namespace, registered in the order
  // listed. Each table entry is one complete schema string: the compiler
  // concatenates the adjacent literals. Schema `int` arguments correspond to
  // int64_t in the C++ kernel signatures.
  static constexpr const char* kOpSchemas[] = {
      // 4-bit quantization. Returns two tensors (presumably the packed data
      // and the per-block absmax that bnb_dequantize_4bit consumes; confirm).
      "bnb_quantize_4bit("
      "Tensor input, int blocksize, int quant_type"
      ") -> (Tensor, Tensor)",

      // 4-bit dequantization (`output_dtype` presumably selects the dtype of
      // the returned tensor; confirm against the kernel).
      "bnb_dequantize_4bit("
      "Tensor packed, Tensor absmax, int blocksize, int quant_type, "
      "int numel, ScalarType output_dtype"
      ") -> Tensor",

      // Fused GEMV with 4-bit weights.
      "bnb_gemv_4bit("
      "Tensor x, Tensor w, Tensor absmax, int blocksize, int quant_type, "
      "int output_features"
      ") -> Tensor",

      // Fused GEMM with 4-bit transposed weights. Same argument list as the
      // GEMV schema above.
      "bnb_gemm_4bit("
      "Tensor x, Tensor w, Tensor absmax, int blocksize, int quant_type, "
      "int output_features"
      ") -> Tensor",
  };

  for (const char* op_schema : kOpSchemas) {
    ops.def(op_schema);
  }
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, MPS, ops) {
  // Bind each operator schema to its C++ kernel under the MPS dispatch key.
  // Only MPS implementations are registered in this file.
  // NOTE(review): confirm whether other backends register kernels elsewhere;
  // if not, calls on non-MPS tensors will fail to dispatch.
  //
  // torch::Library::impl returns Library&, so the registrations chain. They
  // run in the same order as four separate statements would.
  ops.impl("bnb_quantize_4bit", bnb_quantize_4bit)
      .impl("bnb_dequantize_4bit", bnb_dequantize_4bit)
      .impl("bnb_gemv_4bit", bnb_gemv_4bit)
      .impl("bnb_gemm_4bit", bnb_gemm_4bit);
}
// Module entry point for the extension named TORCH_EXTENSION_NAME.
// REGISTER_EXTENSION comes from one of the included headers, presumably
// registration.h. It likely generates the Python module init function, so
// that importing the built extension runs the operator registrations in this
// file. TODO confirm against the macro's definition.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|