| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { /* Python extension module entry point; m is the module handle. The module name comes from the TORCH_EXTENSION_NAME macro, presumably defined by the PyTorch extension build - TODO confirm */ | |
| m.doc() = "Class Numbers of Real Quadratic Fields CUDA kernel"; /* module-level Python docstring */ | |
| m.def("compute_class_numbers", &compute_class_numbers_torch, /* exposes the C++ function compute_class_numbers_torch (defined outside this view) to Python under the name compute_class_numbers */ | |
| "Compute class numbers h(d) for fundamental discriminants", /* docstring of the bound Python function */ | |
| py::arg("discriminants")); /* names the bound parameter so Python callers may pass it by keyword; its expected type and dtype (presumably a torch tensor) are not visible here - TODO confirm */ | |
| } | |