#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

| | TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { |
| | |
| | ops.def( |
| | "bnb_quantize_4bit(Tensor input, int blocksize, int quant_type) " |
| | "-> (Tensor, Tensor)"); |
| |
|
| | |
| | ops.def( |
| | "bnb_dequantize_4bit(Tensor packed, Tensor absmax, int blocksize, " |
| | "int quant_type, int numel, ScalarType output_dtype) -> Tensor"); |
| |
|
| | |
| | ops.def( |
| | "bnb_gemv_4bit(Tensor x, Tensor w, Tensor absmax, int blocksize, " |
| | "int quant_type, int output_features) -> Tensor"); |
| |
|
| | |
| | ops.def( |
| | "bnb_gemm_4bit(Tensor x, Tensor w, Tensor absmax, int blocksize, " |
| | "int quant_type, int output_features) -> Tensor"); |
| | } |
| |
|
| | TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, MPS, ops) { |
| | ops.impl("bnb_quantize_4bit", bnb_quantize_4bit); |
| | ops.impl("bnb_dequantize_4bit", bnb_dequantize_4bit); |
| | ops.impl("bnb_gemv_4bit", bnb_gemv_4bit); |
| | ops.impl("bnb_gemm_4bit", bnb_gemm_4bit); |
| | } |
| |
|
| | REGISTER_EXTENSION(TORCH_EXTENSION_NAME) |
| |
|