#include #include "registration.h" #include "torch_binding.h" extern "C" void higgs_dequantize_2_256_ptr_cuda_portable(uint64_t x_ptr, uint64_t grid_ptr, uint64_t out_ptr, int64_t out_dim); extern "C" void higgs_quantize_2_256_ptr_f16_cuda_portable( uint64_t x_ptr, uint64_t grid_ptr, uint64_t grid_norms_ptr, uint64_t out_ptr, int64_t out_dim); extern "C" void higgs_quantize_2_256_ptr_bf16_cuda_portable( uint64_t x_ptr, uint64_t grid_ptr, uint64_t grid_norms_ptr, uint64_t out_ptr, int64_t out_dim); void higgs_dequantize_2_256(torch::Tensor x, torch::Tensor grid, torch::Tensor out) { int64_t out_dim = x.size(0); higgs_dequantize_2_256_ptr_cuda_portable( reinterpret_cast(x.data_ptr()), reinterpret_cast(grid.data_ptr()), reinterpret_cast(out.data_ptr()), out_dim); } void higgs_quantize_2_256_f16(torch::Tensor x, torch::Tensor grid, torch::Tensor grid_norms, torch::Tensor out) { int64_t out_dim = x.size(0); higgs_quantize_2_256_ptr_f16_cuda_portable( reinterpret_cast(x.data_ptr()), reinterpret_cast(grid.data_ptr()), reinterpret_cast(grid_norms.data_ptr()), reinterpret_cast(out.data_ptr()), out_dim); } void higgs_quantize_2_256_bf16(torch::Tensor x, torch::Tensor grid, torch::Tensor grid_norms, torch::Tensor out) { int64_t out_dim = x.size(0); higgs_quantize_2_256_ptr_bf16_cuda_portable( reinterpret_cast(x.data_ptr()), reinterpret_cast(grid.data_ptr()), reinterpret_cast(grid_norms.data_ptr()), reinterpret_cast(out.data_ptr()), out_dim); } TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "higgs_dequantize_2_256(Tensor x, Tensor grid, Tensor! out) -> ()"); ops.impl("higgs_dequantize_2_256", torch::kCUDA, &higgs_dequantize_2_256); ops.def("higgs_quantize_2_256_f16(Tensor x, Tensor grid, Tensor " "grid_norms, Tensor! out) -> ()"); ops.impl("higgs_quantize_2_256_f16", torch::kCUDA, &higgs_quantize_2_256_f16); ops.def("higgs_quantize_2_256_bf16(Tensor x, Tensor grid, Tensor " "grid_norms, Tensor! out) -> ()"); ops.impl("higgs_quantize_2_256_bf16", torch::kCUDA, &higgs_quantize_2_256_bf16); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME)