File size: 334 Bytes
7ec6a18
 
 
 
 
a51a509
 
 
7ec6a18
1
2
3
4
5
6
7
8
9
10
#include <torch/extension.h>
#include "torch_binding.h"

// Python entry point for this extension.
//
// The module name is injected at build time through the TORCH_EXTENSION_NAME
// macro. The module exposes one function, `compute_class_numbers`, which
// forwards to compute_class_numbers_torch (declared in torch_binding.h). It
// takes a single argument that Python callers may also pass by keyword as
// `discriminants`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  // Docstring attached to the function object (visible via help()/__doc__).
  const char* const kComputeClassNumbersDoc =
      "Compute class numbers h(d) for fundamental discriminants";

  // Module-level docstring.
  mod.doc() = "Class Numbers of Real Quadratic Fields CUDA kernel";

  mod.def("compute_class_numbers",
          &compute_class_numbers_torch,
          kComputeClassNumbersDoc,
          py::arg("discriminants"));
}