cahlen committed on
Commit
ad8dab0
·
verified ·
1 Parent(s): bee95b7

Add torch library bindings (guard main, add torch wrapper functions)

Browse files
kronecker/kronecker_gpu.cu CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  #include <stdio.h>
2
  #include <stdlib.h>
3
  #include <stdint.h>
@@ -36,6 +40,8 @@ __global__ void reduce_stats(const int64_t *slab, int P, int j,
36
  }
37
  }
38
 
 
 
39
  int main(int argc, char **argv) {
40
  int n = atoi(argv[1]);
41
  int gpu = argc > 2 ? atoi(argv[2]) : 0;
@@ -115,3 +121,36 @@ int main(int argc, char **argv) {
115
  free(h_ct); free(h_z);
116
  cudaFree(d_ct); cudaFree(d_z); cudaFree(d_out); cudaFree(d_nz); cudaFree(d_mx);
117
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ifdef TORCH_EXTENSION_NAME
2
+ #include <torch/torch.h>
3
+ #endif
4
+
5
  #include <stdio.h>
6
  #include <stdlib.h>
7
  #include <stdint.h>
 
40
  }
41
  }
42
 
43
+ #ifndef TORCH_EXTENSION_NAME
44
+
45
  int main(int argc, char **argv) {
46
  int n = atoi(argv[1]);
47
  int gpu = argc > 2 ? atoi(argv[2]) : 0;
 
121
  free(h_ct); free(h_z);
122
  cudaFree(d_ct); cudaFree(d_z); cudaFree(d_out); cudaFree(d_nz); cudaFree(d_mx);
123
  }
124
+ #endif
125
+
126
+
127
+ #ifdef TORCH_EXTENSION_NAME
128
// Turns a CUDA runtime status into a c10::Error (via TORCH_CHECK) so that
// failures reach Python as exceptions instead of being silently dropped.
static void kronecker_check_cuda(cudaError_t status, const char *what) {
    TORCH_CHECK(status == cudaSuccess,
                "compute_kronecker: ", what, " failed: ", cudaGetErrorString(status));
}

// Torch entry point: for every slab index j in [0, P) launches kronecker_slab
// into a P x P int64 scratch buffer, then reduce_stats over that buffer, and
// accumulates the two per-slab counters on the host.
//
// Args:
//   ct    - int64 CUDA tensor, 2-D with shape (P, C). Non-contiguous input is
//           copied to a contiguous buffer before its raw pointer is used.
//   z_inv - float64 CUDA tensor on the same device as ct.
//           NOTE(review): kronecker_slab presumably reads C entries of z_inv;
//           its length is not validated here -- confirm against the kernel.
//
// Returns:
//   Two one-element int64 CPU tensors: {sum over j of the nz counter,
//   maximum over j of the mx counter}.
//
// Throws (TORCH_CHECK) on invalid arguments, on sizes whose launch arithmetic
// would not fit in a 32-bit int, and on any CUDA launch/execution error.
std::vector<torch::Tensor> compute_kronecker(torch::Tensor ct, torch::Tensor z_inv) {
    TORCH_CHECK(ct.is_cuda() && ct.dtype() == torch::kInt64, "ct must be int64 CUDA tensor");
    TORCH_CHECK(ct.dim() == 2, "ct must be 2-D (P x C)");
    TORCH_CHECK(z_inv.is_cuda() && z_inv.dtype() == torch::kFloat64, "z_inv must be float64 CUDA");
    TORCH_CHECK(z_inv.device() == ct.device(), "ct and z_inv must be on the same CUDA device");

    // The kernels take int sizes; reject shapes that would be narrowed or
    // would overflow the (P * P + 255) / 256 grid computation.
    const int64_t P64 = ct.size(0);
    const int64_t C64 = ct.size(1);
    TORCH_CHECK(C64 <= INT32_MAX, "ct has too many columns for the int-indexed kernels");
    TORCH_CHECK(P64 <= INT32_MAX && P64 * P64 + 255 <= INT32_MAX,
                "ct has too many rows: P * P must fit in a 32-bit int");
    const int P = static_cast<int>(P64);
    const int C = static_cast<int>(C64);

    // Raw pointers below assume dense row-major storage.
    ct = ct.contiguous();
    z_inv = z_inv.contiguous();

    // Raw <<<...>>> launches target the *current* device; make it ct's device.
    const c10::DeviceGuard device_guard(ct.device());

    const auto counter_opts = torch::dtype(torch::kInt64).device(ct.device());
    auto out = torch::empty({P64, P64}, counter_opts);  // zeroed at the top of every iteration
    auto nz_dev = torch::empty({1}, counter_opts);
    auto mx_dev = torch::empty({1}, counter_opts);

    const int64_t *ct_ptr = ct.data_ptr<int64_t>();
    const double *z_ptr = z_inv.data_ptr<double>();
    int64_t *out_ptr = out.data_ptr<int64_t>();
    // reduce_stats takes unsigned long long counters; the int64 storage is
    // reinterpreted exactly as before.
    unsigned long long *nz_ptr = (unsigned long long *)nz_dev.data_ptr<int64_t>();
    unsigned long long *mx_ptr = (unsigned long long *)mx_dev.data_ptr<int64_t>();

    constexpr int kThreads = 256;
    const int nblocks = static_cast<int>((P64 * P64 + kThreads - 1) / kThreads);

    int64_t total_nz = 0, global_max = 0;

    for (int j = 0; j < P; j++) {
        out.zero_(); nz_dev.zero_(); mx_dev.zero_();
        // zero_() runs on torch's current stream while the kernels below use
        // the default stream; wait so the launches never race the zeroing.
        kronecker_check_cuda(cudaDeviceSynchronize(), "pre-launch synchronize");

        kronecker_slab<<<nblocks, kThreads>>>(ct_ptr, z_ptr, P, C, j, out_ptr);
        kronecker_check_cuda(cudaGetLastError(), "kronecker_slab launch");

        // Same stream as kronecker_slab, so it starts only after the slab is complete.
        reduce_stats<<<nblocks, kThreads>>>(out_ptr, P, j, nz_ptr, mx_ptr);
        kronecker_check_cuda(cudaGetLastError(), "reduce_stats launch");

        // Surfaces asynchronous kernel faults and makes the counters readable.
        kronecker_check_cuda(cudaDeviceSynchronize(), "kernel execution");

        total_nz += nz_dev.cpu().item<int64_t>();
        const int64_t slab_max = mx_dev.cpu().item<int64_t>();
        if (slab_max > global_max) global_max = slab_max;
    }
    return {torch::tensor({total_nz}, torch::kInt64), torch::tensor({global_max}, torch::kInt64)};
}
155
+ #endif
156
+
torch-ext/torch_binding.cpp CHANGED
@@ -3,4 +3,5 @@
3
 
4
  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
5
  m.doc() = "Kronecker Coefficients (Symmetric Group) CUDA kernel";
 
6
  }
 
3
 
4
  // Python module definition: sets the module docstring and exposes
  // compute_kronecker to Python with keyword arguments "ct" and "z_inv".
  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
5
  m.doc() = "Kronecker Coefficients (Symmetric Group) CUDA kernel";
6
+ m.def("compute_kronecker", &compute_kronecker, py::arg("ct"), py::arg("z_inv"));
7
  }
torch-ext/torch_binding.h CHANGED
@@ -1,3 +1,4 @@
1
  #pragma once
2
  #include <torch/torch.h>
3
- // See kronecker/kronecker_gpu.cu for kernel API
 
 
1
  #pragma once
2
  #include <torch/torch.h>
3
+
4
// Torch-facing entry point, defined in kronecker/kronecker_gpu.cu when
// TORCH_EXTENSION_NAME is defined.
//   ct    - 2-D (P x C) int64 CUDA tensor (enforced with TORCH_CHECK in the definition).
//   z_inv - float64 CUDA tensor (enforced with TORCH_CHECK in the definition).
// Returns two one-element int64 tensors: the accumulated nz total and the
// maximum mx value over all P slabs.
+ std::vector<torch::Tensor> compute_kronecker(torch::Tensor ct, torch::Tensor z_inv);