Add torch library bindings (guard main, add torch wrapper functions)
Browse files
minkowski_spectrum/minkowski_spectrum.cu
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
/*
|
| 2 |
* Multifractal Singularity Spectrum of the Minkowski Question Mark Function
|
| 3 |
*
|
|
@@ -166,6 +170,8 @@ __global__ void compute_tau(int num_q, double q_min, double q_step,
|
|
| 166 |
|
| 167 |
/* ---- Host ---- */
|
| 168 |
|
|
|
|
|
|
|
| 169 |
int main(int argc, char **argv) {
|
| 170 |
int A_max = argc > 1 ? atoi(argv[1]) : 50;
|
| 171 |
int N = argc > 2 ? atoi(argv[2]) : 40;
|
|
@@ -318,3 +324,47 @@ int main(int argc, char **argv) {
|
|
| 318 |
free(h_tau); free(h_q); free(h_alpha); free(h_f);
|
| 319 |
return 0;
|
| 320 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifdef TORCH_EXTENSION_NAME
|
| 2 |
+
#include <torch/torch.h>
|
| 3 |
+
#endif
|
| 4 |
+
|
| 5 |
/*
|
| 6 |
* Multifractal Singularity Spectrum of the Minkowski Question Mark Function
|
| 7 |
*
|
|
|
|
| 170 |
|
| 171 |
/* ---- Host ---- */
|
| 172 |
|
| 173 |
+
#ifndef TORCH_EXTENSION_NAME
|
| 174 |
+
|
| 175 |
int main(int argc, char **argv) {
|
| 176 |
int A_max = argc > 1 ? atoi(argv[1]) : 50;
|
| 177 |
int N = argc > 2 ? atoi(argv[2]) : 40;
|
|
|
|
| 324 |
free(h_tau); free(h_q); free(h_alpha); free(h_f);
|
| 325 |
return 0;
|
| 326 |
}
|
| 327 |
+
#endif
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
#ifdef TORCH_EXTENSION_NAME
|
| 331 |
+
/*
 * Torch entry point for the multifractal spectrum.
 *
 * Launches compute_tau to evaluate tau(q) on the GPU, copies the result to
 * the host, and derives alpha(q) = -d tau / d q by finite differences and
 * f(alpha) = q * alpha + tau on the CPU.
 *
 * Arguments:
 *   A_max  forwarded to compute_tau as an int; must lie in [1, 100].
 *   N      forwarded to compute_tau as an int; must lie in [1, 48].
 *
 * Returns {q, tau, alpha, f_alpha}: four 1-D float64 host tensors of length
 * Q_COUNT, where q[i] = Q_MIN + i * Q_STEP.
 *
 * Raises (through TORCH_CHECK) when an argument is out of range, when
 * Q_COUNT is too small for the finite-difference stencil, or when any CUDA
 * runtime call or the kernel itself reports an error.
 */
std::vector<torch::Tensor> compute_spectrum(int64_t A_max, int64_t N) {
    TORCH_CHECK(A_max >= 1 && A_max <= 100, "A_max must be 1-100");
    TORCH_CHECK(N >= 1 && N <= 48, "N must be 1-48");

    // Q_MIN, Q_MAX, Q_STEP, Q_COUNT are macros defined earlier in this file.
    const int64_t num_q = Q_COUNT;
    // The one-sided differences at both ends read a neighbouring sample.
    TORCH_CHECK(num_q >= 2, "Q_COUNT must be at least 2 to differentiate tau(q)");
    const size_t tau_bytes = static_cast<size_t>(num_q) * sizeof(double);

    // Allocate every host tensor before touching the device so that a failed
    // host allocation cannot leak device memory. Each element is overwritten
    // below, so zero-filling is unnecessary.
    const auto f64 = torch::dtype(torch::kFloat64);
    auto q       = torch::empty({num_q}, f64);
    auto tau     = torch::empty({num_q}, f64);
    auto alpha   = torch::empty({num_q}, f64);
    auto f_alpha = torch::empty({num_q}, f64);

    double *d_tau = nullptr;
    cudaError_t err = cudaMalloc(&d_tau, tau_bytes);
    TORCH_CHECK(err == cudaSuccess,
                "compute_spectrum: cudaMalloc failed: ", cudaGetErrorString(err));

    const int threads = 32;
    const int blocks = static_cast<int>((num_q + threads - 1) / threads);
    compute_tau<<<blocks, threads>>>(static_cast<int>(num_q), Q_MIN, Q_STEP,
                                     (int)A_max, (int)N, d_tau);

    // Launch-configuration errors surface via cudaGetLastError(); faults that
    // happen while the kernel runs surface at the synchronize.
    err = cudaGetLastError();
    if (err == cudaSuccess) err = cudaDeviceSynchronize();
    if (err == cudaSuccess)
        err = cudaMemcpy(tau.data_ptr<double>(), d_tau, tau_bytes,
                         cudaMemcpyDeviceToHost);

    // Release the device buffer before raising so no error path leaks it.
    cudaFree(d_tau);
    TORCH_CHECK(err == cudaSuccess,
                "compute_spectrum: CUDA failure: ", cudaGetErrorString(err));

    double *pq = q.data_ptr<double>();
    double *pt = tau.data_ptr<double>();
    double *pa = alpha.data_ptr<double>();
    double *pf = f_alpha.data_ptr<double>();

    for (int64_t i = 0; i < num_q; i++) pq[i] = Q_MIN + i * Q_STEP;

    // alpha = -d tau / d q: one-sided differences at the two ends, central
    // differences in the interior.
    pa[0] = -(pt[1] - pt[0]) / Q_STEP;
    for (int64_t i = 1; i < num_q - 1; i++)
        pa[i] = -(pt[i + 1] - pt[i - 1]) / (2.0 * Q_STEP);
    pa[num_q - 1] = -(pt[num_q - 1] - pt[num_q - 2]) / Q_STEP;

    // f(alpha) = q * alpha + tau.
    for (int64_t i = 0; i < num_q; i++)
        pf[i] = pq[i] * pa[i] + pt[i];

    return {q, tau, alpha, f_alpha};
}
|
| 369 |
+
#endif
|
| 370 |
+
|
torch-ext/torch_binding.cpp
CHANGED
|
@@ -3,4 +3,7 @@
|
|
| 3 |
|
| 4 |
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 5 |
m.doc() = "Minkowski Question-Mark Singularity Spectrum CUDA kernel";
|
|
|
|
|
|
|
|
|
|
| 6 |
}
|
|
|
|
| 3 |
|
| 4 |
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Module docstring.
    m.doc() = "Minkowski Question-Mark Singularity Spectrum CUDA kernel";

    // Bind compute_spectrum with named arguments A_max and N.
    const char *spectrum_doc =
        "Compute multifractal spectrum tau(q), alpha(q), f(alpha)";
    m.def("compute_spectrum", &compute_spectrum, spectrum_doc,
          py::arg("A_max"), py::arg("N"));
}
|
torch-ext/torch_binding.h
CHANGED
|
@@ -1,3 +1,4 @@
|
|
| 1 |
#pragma once
|
| 2 |
#include <torch/torch.h>
|
| 3 |
-
|
|
|
|
|
|
| 1 |
#pragma once
|
| 2 |
#include <torch/torch.h>
|
| 3 |
+
|
| 4 |
+
// Computes the multifractal spectrum and returns {q, tau, alpha, f_alpha}
// as float64 tensors. The definition rejects A_max outside 1-100 and N
// outside 1-48 via TORCH_CHECK.
std::vector<torch::Tensor> compute_spectrum(int64_t A_max, int64_t N);
|