cahlen commited on
Commit
21969c8
·
verified ·
1 Parent(s): c96ca24

Add torch library bindings (guard main, add torch wrapper functions)

Browse files
hausdorff_spectrum/hausdorff_spectrum.cu CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  /*
2
  * Hausdorff Dimension Spectrum of Continued Fraction Cantor Sets
3
  *
@@ -200,6 +204,8 @@ void format_subset(uint32_t mask, int max_d, char *buf, int buflen) {
200
  * Host: main
201
  * ============================================================ */
202
 
 
 
203
  int main(int argc, char **argv) {
204
  int max_d = argc > 1 ? atoi(argv[1]) : 10;
205
  int N = argc > 2 ? atoi(argv[2]) : 40;
@@ -384,3 +390,29 @@ int main(int argc, char **argv) {
384
 
385
  return 0;
386
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ifdef TORCH_EXTENSION_NAME
2
+ #include <torch/torch.h>
3
+ #endif
4
+
5
  /*
6
  * Hausdorff Dimension Spectrum of Continued Fraction Cantor Sets
7
  *
 
204
  * Host: main
205
  * ============================================================ */
206
 
207
+ #ifndef TORCH_EXTENSION_NAME
208
+
209
  int main(int argc, char **argv) {
210
  int max_d = argc > 1 ? atoi(argv[1]) : 10;
211
  int N = argc > 2 ? atoi(argv[2]) : 40;
 
390
 
391
  return 0;
392
  }
393
+ #endif
394
+
395
+
396
+ #ifdef TORCH_EXTENSION_NAME
397
+ torch::Tensor hausdorff_all(int64_t max_d, int64_t N) {
398
+ TORCH_CHECK(max_d >= 1 && max_d <= 24, "max_d must be 1-24");
399
+ TORCH_CHECK(N >= 1 && N <= 48, "N must be 1-48");
400
+
401
+ int total = (1 << max_d) - 1;
402
+ auto results = torch::zeros({total}, torch::dtype(torch::kFloat64));
403
+ double *d_results;
404
+ cudaMalloc(&d_results, 1024 * sizeof(double));
405
+
406
+ for (int offset = 0; offset < total; offset += 1024) {
407
+ int batch = std::min(1024, total - offset);
408
+ uint32_t start_mask = offset + 1;
409
+ batch_hausdorff<<<batch, 1>>>(start_mask, batch, (int)max_d, (int)N, d_results);
410
+ cudaDeviceSynchronize();
411
+ cudaMemcpy(results.data_ptr<double>() + offset, d_results,
412
+ batch * sizeof(double), cudaMemcpyDeviceToHost);
413
+ }
414
+ cudaFree(d_results);
415
+ return results;
416
+ }
417
+ #endif
418
+
torch-ext/torch_binding.cpp CHANGED
@@ -3,4 +3,7 @@
3
 
4
  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
5
  m.doc() = "Hausdorff Dimension Spectrum CUDA kernel";
 
 
 
6
  }
 
3
 
4
  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
5
  m.doc() = "Hausdorff Dimension Spectrum CUDA kernel";
6
+ m.def("hausdorff_all", &hausdorff_all,
7
+ "Compute Hausdorff dimensions for all subsets of {1,...,max_d}",
8
+ py::arg("max_d"), py::arg("N"));
9
  }
torch-ext/torch_binding.h CHANGED
@@ -1,3 +1,4 @@
1
  #pragma once
2
  #include <torch/torch.h>
3
- // See hausdorff_spectrum/hausdorff_spectrum.cu for kernel API
 
 
1
  #pragma once
2
  #include <torch/torch.h>
3
+
4
+ torch::Tensor hausdorff_all(int64_t max_d, int64_t N);