// ramanujan-machine-cuda / torch-ext/torch_binding.cpp
// Author: cahlen
// Commit 1839df4 (verified): "Add torch library bindings (guard main, add torch wrapper functions)"
#include <torch/extension.h>
#include "torch_binding.h"
// Python module entry point for this torch extension. The module name is
// supplied at build time through the TORCH_EXTENSION_NAME macro.
//
// Exposes a single function, `search`, whose C++ declaration is presumably
// provided by "torch_binding.h" (TODO confirm). All five parameters are
// bindable by keyword from Python; only `cf_depth` is optional (default 300).
// NOTE(review): the meaning and valid ranges of deg_a/deg_b/range_a/range_b
// and cf_depth are defined by `search` itself, not visible here.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  // Enables the "name"_a shorthand, equivalent to py::arg("name").
  using namespace pybind11::literals;

  mod.doc() = "Ramanujan Machine v2 (Asymmetric-Degree CF Search) CUDA kernel";

  mod.def("search", &search,
          "deg_a"_a,
          "deg_b"_a,
          "range_a"_a,
          "range_b"_a,
          "cf_depth"_a = 300);
}