// FireEcho Engine — csrc/femx_bindings.cpp
// (origin: "Upload 3258 files", commit b5bff9c)
// FE-MX CUDA Kernels — pybind11 bindings
// JIT-compiled via torch.utils.cpp_extension.load()
#include <torch/extension.h>
// Forward declarations from femx_kernels.cu — these signatures must stay
// in sync with the definitions in that .cu file. torch::Tensor arguments
// are handles, so passing them by value here is the usual torch idiom.

// Quantize FP32 `master` to packed uint8 + E8M0 `scales` (per the binding
// docstring below). Returns void, so presumably writes into the caller's
// pre-allocated `packed`/`scales` buffers — TODO confirm against
// femx_kernels.cu. `stochastic` toggles stochastic rounding; `seed` seeds
// it. Role of `tier` is not visible from this file — verify in the kernel.
void femx_quantize_impl(
torch::Tensor master,
torch::Tensor tier,
torch::Tensor packed,
torch::Tensor scales,
bool stochastic,
int64_t seed
);
// Dequantize packed uint8 + E8M0 `scales` back to FP32 (per the binding
// docstring below) and return the result as a new tensor. `block_size` is
// presumably the per-scale quantization block length — confirm in the .cu.
torch::Tensor femx_dequantize_impl(
torch::Tensor packed,
torch::Tensor scales,
torch::Tensor tier,
int64_t block_size
);
// Fused quantize + dequantize (per the binding docstring below): FP32
// `master` -> `packed` + a BF16 `fast_weight`. Void return, so all outputs
// are presumably written in place into the passed buffers — TODO confirm.
void femx_sync_impl(
torch::Tensor master,
torch::Tensor tier,
torch::Tensor packed,
torch::Tensor scales,
torch::Tensor fast_weight,
int64_t seed
);
// Module entry point: exposes the three FE-MX kernel wrappers to Python.
// Keyword-argument names are declared with pybind11's ""_a user-defined
// literal (equivalent to py::arg("...")).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  using namespace py::literals;  // brings the "name"_a shorthand into scope

  m.doc() = "FE-MX CUDA kernels: fused quantize/dequantize for Hebbian memory";

  m.def("femx_quantize", &femx_quantize_impl,
        "Quantize FP32 master to packed uint8 + E8M0 scales (stochastic rounding)",
        "master"_a, "tier"_a, "packed"_a, "scales"_a,
        "stochastic"_a, "seed"_a);

  m.def("femx_dequantize", &femx_dequantize_impl,
        "Dequantize packed uint8 + E8M0 scales to FP32",
        "packed"_a, "scales"_a, "tier"_a, "block_size"_a);

  m.def("femx_sync", &femx_sync_impl,
        "Fused quantize + dequantize: master FP32 -> packed + BF16 fast_weight",
        "master"_a, "tier"_a, "packed"_a, "scales"_a,
        "fast_weight"_a, "seed"_a);
}