// kronecker-cuda / torch-ext / torch_binding.cpp
// Commit ad8dab0 (verified) by cahlen:
//   Add torch library bindings (guard main, add torch wrapper functions)
#include <torch/extension.h>
#include "torch_binding.h"
// Python entry point for this extension. The module name comes from
// TORCH_EXTENSION_NAME, which is expected to be supplied by the PyTorch
// extension build (TODO confirm against the build configuration).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  // Module-level docstring, shown by help() on the imported module.
  mod.doc() = "Kronecker Coefficients (Symmetric Group) CUDA kernel";

  // Expose compute_kronecker (presumably declared in torch_binding.h) with
  // named arguments, so Python callers may pass ct= and z_inv= as keywords.
  mod.def("compute_kronecker",
          &compute_kronecker,
          py::arg("ct"),
          py::arg("z_inv"));
}