cahlen commited on
Commit
3df8753
·
verified ·
1 Parent(s): aee284e

Add torch library bindings (guard main, add torch wrapper functions)

Browse files
minkowski_spectrum/minkowski_spectrum.cu CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  /*
2
  * Multifractal Singularity Spectrum of the Minkowski Question Mark Function
3
  *
@@ -166,6 +170,8 @@ __global__ void compute_tau(int num_q, double q_min, double q_step,
166
 
167
  /* ---- Host ---- */
168
 
 
 
169
  int main(int argc, char **argv) {
170
  int A_max = argc > 1 ? atoi(argv[1]) : 50;
171
  int N = argc > 2 ? atoi(argv[2]) : 40;
@@ -318,3 +324,47 @@ int main(int argc, char **argv) {
318
  free(h_tau); free(h_q); free(h_alpha); free(h_f);
319
  return 0;
320
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ifdef TORCH_EXTENSION_NAME
2
+ #include <torch/torch.h>
3
+ #endif
4
+
5
  /*
6
  * Multifractal Singularity Spectrum of the Minkowski Question Mark Function
7
  *
 
170
 
171
  /* ---- Host ---- */
172
 
173
+ #ifndef TORCH_EXTENSION_NAME
174
+
175
  int main(int argc, char **argv) {
176
  int A_max = argc > 1 ? atoi(argv[1]) : 50;
177
  int N = argc > 2 ? atoi(argv[2]) : 40;
 
324
  free(h_tau); free(h_q); free(h_alpha); free(h_f);
325
  return 0;
326
  }
327
+ #endif
328
+
329
+
330
+ #ifdef TORCH_EXTENSION_NAME
331
+ std::vector<torch::Tensor> compute_spectrum(int64_t A_max, int64_t N) {
332
+ TORCH_CHECK(A_max >= 1 && A_max <= 100, "A_max must be 1-100");
333
+ TORCH_CHECK(N >= 1 && N <= 48, "N must be 1-48");
334
+
335
+ // Q_MIN, Q_MAX, Q_STEP, Q_COUNT are already defined as macros above
336
+ double *d_tau;
337
+ cudaMalloc(&d_tau, Q_COUNT * sizeof(double));
338
+
339
+ int threads = 32;
340
+ int blocks = (Q_COUNT + threads - 1) / threads;
341
+ compute_tau<<<blocks, threads>>>(Q_COUNT, Q_MIN, Q_STEP, (int)A_max, (int)N, d_tau);
342
+ cudaDeviceSynchronize();
343
+
344
+ auto tau = torch::zeros({Q_COUNT}, torch::dtype(torch::kFloat64));
345
+ cudaMemcpy(tau.data_ptr<double>(), d_tau, Q_COUNT * sizeof(double), cudaMemcpyDeviceToHost);
346
+ cudaFree(d_tau);
347
+
348
+ auto q = torch::zeros({Q_COUNT}, torch::dtype(torch::kFloat64));
349
+ auto alpha = torch::zeros({Q_COUNT}, torch::dtype(torch::kFloat64));
350
+ auto f_alpha = torch::zeros({Q_COUNT}, torch::dtype(torch::kFloat64));
351
+
352
+ double *pq = q.data_ptr<double>();
353
+ double *pt = tau.data_ptr<double>();
354
+ double *pa = alpha.data_ptr<double>();
355
+ double *pf = f_alpha.data_ptr<double>();
356
+
357
+ for (int i = 0; i < Q_COUNT; i++) pq[i] = Q_MIN + i * Q_STEP;
358
+
359
+ pa[0] = -(pt[1] - pt[0]) / Q_STEP;
360
+ for (int i = 1; i < Q_COUNT - 1; i++)
361
+ pa[i] = -(pt[i+1] - pt[i-1]) / (2.0 * Q_STEP);
362
+ pa[Q_COUNT-1] = -(pt[Q_COUNT-1] - pt[Q_COUNT-2]) / Q_STEP;
363
+
364
+ for (int i = 0; i < Q_COUNT; i++)
365
+ pf[i] = pq[i] * pa[i] + pt[i];
366
+
367
+ return {q, tau, alpha, f_alpha};
368
+ }
369
+ #endif
370
+
torch-ext/torch_binding.cpp CHANGED
@@ -3,4 +3,7 @@
3
 
4
  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
5
  m.doc() = "Minkowski Question-Mark Singularity Spectrum CUDA kernel";
 
 
 
6
  }
 
3
 
4
  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
5
  m.doc() = "Minkowski Question-Mark Singularity Spectrum CUDA kernel";
6
+ m.def("compute_spectrum", &compute_spectrum,
7
+ "Compute multifractal spectrum tau(q), alpha(q), f(alpha)",
8
+ py::arg("A_max"), py::arg("N"));
9
  }
torch-ext/torch_binding.h CHANGED
@@ -1,3 +1,4 @@
1
  #pragma once
2
  #include <torch/torch.h>
3
- // See minkowski_spectrum/minkowski_spectrum.cu for kernel API
 
 
1
  #pragma once
2
  #include <torch/torch.h>
3
+
4
+ std::vector<torch::Tensor> compute_spectrum(int64_t A_max, int64_t N);