| | |
| | |
| |
|
| | #include <torch/extension.h> |
| |
|
| | |
// Quantize an FP32 master tensor into packed uint8 payload plus E8M0
// scales, writing into the caller-provided output tensors in place
// (per the Python-facing docstring registered in PYBIND11_MODULE).
//
// @param master      FP32 source weights to quantize.
// @param tier        tier tensor; exact semantics live in the CUDA
//                    implementation — NOTE(review): confirm against .cu file.
// @param packed      pre-allocated output buffer for packed uint8 data.
// @param scales      pre-allocated output buffer for E8M0 scale factors.
// @param stochastic  enable stochastic rounding (vs. deterministic rounding).
// @param seed        RNG seed consumed by the stochastic-rounding path.
void femx_quantize_impl(
    torch::Tensor master,
    torch::Tensor tier,
    torch::Tensor packed,
    torch::Tensor scales,
    bool stochastic,
    int64_t seed
);
| |
|
// Dequantize packed uint8 data + E8M0 scales back to FP32 (per the
// Python-facing docstring registered in PYBIND11_MODULE). Returns a
// newly allocated FP32 tensor rather than writing in place.
//
// @param packed      packed uint8 payload produced by the quantizer.
// @param scales      E8M0 scale factors matching `packed`.
// @param tier        tier tensor; exact semantics live in the CUDA
//                    implementation — NOTE(review): confirm against .cu file.
// @param block_size  elements per scale block — presumably the quantization
//                    block granularity; verify against the kernel.
// @return            dequantized FP32 tensor.
torch::Tensor femx_dequantize_impl(
    torch::Tensor packed,
    torch::Tensor scales,
    torch::Tensor tier,
    int64_t block_size
);
| |
|
// Fused quantize + dequantize: takes the FP32 master, refreshes the packed
// representation, and materializes a BF16 `fast_weight` copy in one pass
// (per the Python-facing docstring registered in PYBIND11_MODULE).
// All outputs are written into caller-provided tensors in place.
//
// @param master       FP32 source weights.
// @param tier         tier tensor; exact semantics live in the CUDA
//                     implementation — NOTE(review): confirm against .cu file.
// @param packed       output buffer for the packed uint8 payload.
// @param scales       output buffer for the E8M0 scale factors.
// @param fast_weight  output buffer receiving the BF16 dequantized weights.
// @param seed         RNG seed — presumably for stochastic rounding, matching
//                     femx_quantize_impl; verify against the kernel.
void femx_sync_impl(
    torch::Tensor master,
    torch::Tensor tier,
    torch::Tensor packed,
    torch::Tensor scales,
    torch::Tensor fast_weight,
    int64_t seed
);
| |
|
| | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { |
| | m.doc() = "FE-MX CUDA kernels: fused quantize/dequantize for Hebbian memory"; |
| | m.def("femx_quantize", &femx_quantize_impl, |
| | "Quantize FP32 master to packed uint8 + E8M0 scales (stochastic rounding)", |
| | py::arg("master"), py::arg("tier"), |
| | py::arg("packed"), py::arg("scales"), |
| | py::arg("stochastic"), py::arg("seed")); |
| | m.def("femx_dequantize", &femx_dequantize_impl, |
| | "Dequantize packed uint8 + E8M0 scales to FP32", |
| | py::arg("packed"), py::arg("scales"), |
| | py::arg("tier"), py::arg("block_size")); |
| | m.def("femx_sync", &femx_sync_impl, |
| | "Fused quantize + dequantize: master FP32 -> packed + BF16 fast_weight", |
| | py::arg("master"), py::arg("tier"), |
| | py::arg("packed"), py::arg("scales"), |
| | py::arg("fast_weight"), py::arg("seed")); |
| | } |
| |
|