Add torch library bindings (guard main, add torch wrapper functions)
Browse files
hausdorff_spectrum/hausdorff_spectrum.cu
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
/*
|
| 2 |
* Hausdorff Dimension Spectrum of Continued Fraction Cantor Sets
|
| 3 |
*
|
|
@@ -200,6 +204,8 @@ void format_subset(uint32_t mask, int max_d, char *buf, int buflen) {
|
|
| 200 |
* Host: main
|
| 201 |
* ============================================================ */
|
| 202 |
|
|
|
|
|
|
|
| 203 |
int main(int argc, char **argv) {
|
| 204 |
int max_d = argc > 1 ? atoi(argv[1]) : 10;
|
| 205 |
int N = argc > 2 ? atoi(argv[2]) : 40;
|
|
@@ -384,3 +390,29 @@ int main(int argc, char **argv) {
|
|
| 384 |
|
| 385 |
return 0;
|
| 386 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifdef TORCH_EXTENSION_NAME
|
| 2 |
+
#include <torch/torch.h>
|
| 3 |
+
#endif
|
| 4 |
+
|
| 5 |
/*
|
| 6 |
* Hausdorff Dimension Spectrum of Continued Fraction Cantor Sets
|
| 7 |
*
|
|
|
|
| 204 |
* Host: main
|
| 205 |
* ============================================================ */
|
| 206 |
|
| 207 |
+
#ifndef TORCH_EXTENSION_NAME
|
| 208 |
+
|
| 209 |
int main(int argc, char **argv) {
|
| 210 |
int max_d = argc > 1 ? atoi(argv[1]) : 10;
|
| 211 |
int N = argc > 2 ? atoi(argv[2]) : 40;
|
|
|
|
| 390 |
|
| 391 |
return 0;
|
| 392 |
}
|
| 393 |
+
#endif
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
#ifdef TORCH_EXTENSION_NAME
|
| 397 |
+
torch::Tensor hausdorff_all(int64_t max_d, int64_t N) {
|
| 398 |
+
TORCH_CHECK(max_d >= 1 && max_d <= 24, "max_d must be 1-24");
|
| 399 |
+
TORCH_CHECK(N >= 1 && N <= 48, "N must be 1-48");
|
| 400 |
+
|
| 401 |
+
int total = (1 << max_d) - 1;
|
| 402 |
+
auto results = torch::zeros({total}, torch::dtype(torch::kFloat64));
|
| 403 |
+
double *d_results;
|
| 404 |
+
cudaMalloc(&d_results, 1024 * sizeof(double));
|
| 405 |
+
|
| 406 |
+
for (int offset = 0; offset < total; offset += 1024) {
|
| 407 |
+
int batch = std::min(1024, total - offset);
|
| 408 |
+
uint32_t start_mask = offset + 1;
|
| 409 |
+
batch_hausdorff<<<batch, 1>>>(start_mask, batch, (int)max_d, (int)N, d_results);
|
| 410 |
+
cudaDeviceSynchronize();
|
| 411 |
+
cudaMemcpy(results.data_ptr<double>() + offset, d_results,
|
| 412 |
+
batch * sizeof(double), cudaMemcpyDeviceToHost);
|
| 413 |
+
}
|
| 414 |
+
cudaFree(d_results);
|
| 415 |
+
return results;
|
| 416 |
+
}
|
| 417 |
+
#endif
|
| 418 |
+
|
torch-ext/torch_binding.cpp
CHANGED
|
@@ -3,4 +3,7 @@
|
|
| 3 |
|
| 4 |
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 5 |
m.doc() = "Hausdorff Dimension Spectrum CUDA kernel";
|
|
|
|
|
|
|
|
|
|
| 6 |
}
|
|
|
|
| 3 |
|
| 4 |
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 5 |
m.doc() = "Hausdorff Dimension Spectrum CUDA kernel";
|
| 6 |
+
m.def("hausdorff_all", &hausdorff_all,
|
| 7 |
+
"Compute Hausdorff dimensions for all subsets of {1,...,max_d}",
|
| 8 |
+
py::arg("max_d"), py::arg("N"));
|
| 9 |
}
|
torch-ext/torch_binding.h
CHANGED
|
@@ -1,3 +1,4 @@
|
|
| 1 |
#pragma once
|
| 2 |
#include <torch/torch.h>
|
| 3 |
-
|
|
|
|
|
|
| 1 |
#pragma once
|
| 2 |
#include <torch/torch.h>
|
| 3 |
+
|
| 4 |
+
torch::Tensor hausdorff_all(int64_t max_d, int64_t N);
|