// FE-MX CUDA Kernels — pybind11 bindings // JIT-compiled via torch.utils.cpp_extension.load() #include // Forward declarations from femx_kernels.cu void femx_quantize_impl( torch::Tensor master, torch::Tensor tier, torch::Tensor packed, torch::Tensor scales, bool stochastic, int64_t seed ); torch::Tensor femx_dequantize_impl( torch::Tensor packed, torch::Tensor scales, torch::Tensor tier, int64_t block_size ); void femx_sync_impl( torch::Tensor master, torch::Tensor tier, torch::Tensor packed, torch::Tensor scales, torch::Tensor fast_weight, int64_t seed ); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "FE-MX CUDA kernels: fused quantize/dequantize for Hebbian memory"; m.def("femx_quantize", &femx_quantize_impl, "Quantize FP32 master to packed uint8 + E8M0 scales (stochastic rounding)", py::arg("master"), py::arg("tier"), py::arg("packed"), py::arg("scales"), py::arg("stochastic"), py::arg("seed")); m.def("femx_dequantize", &femx_dequantize_impl, "Dequantize packed uint8 + E8M0 scales to FP32", py::arg("packed"), py::arg("scales"), py::arg("tier"), py::arg("block_size")); m.def("femx_sync", &femx_sync_impl, "Fused quantize + dequantize: master FP32 -> packed + BF16 fast_weight", py::arg("master"), py::arg("tier"), py::arg("packed"), py::arg("scales"), py::arg("fast_weight"), py::arg("seed")); }