Instructions to use cahlen/kronecker-cuda with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use cahlen/kronecker-cuda with Kernels:
```python
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("cahlen/kronecker-cuda")
```
- Notebooks
- Google Colab
- Kaggle
Add torch library bindings (guard main, add torch wrapper functions)
Browse files
- kronecker/kronecker_gpu.cu +39 -0
- torch-ext/torch_binding.cpp +1 -0
- torch-ext/torch_binding.h +2 -1
kronecker/kronecker_gpu.cu
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
#include <stdio.h>
|
| 2 |
#include <stdlib.h>
|
| 3 |
#include <stdint.h>
|
|
@@ -36,6 +40,8 @@ __global__ void reduce_stats(const int64_t *slab, int P, int j,
|
|
| 36 |
}
|
| 37 |
}
|
| 38 |
|
|
|
|
|
|
|
| 39 |
int main(int argc, char **argv) {
|
| 40 |
int n = atoi(argv[1]);
|
| 41 |
int gpu = argc > 2 ? atoi(argv[2]) : 0;
|
|
@@ -115,3 +121,36 @@ int main(int argc, char **argv) {
|
|
| 115 |
free(h_ct); free(h_z);
|
| 116 |
cudaFree(d_ct); cudaFree(d_z); cudaFree(d_out); cudaFree(d_nz); cudaFree(d_mx);
|
| 117 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifdef TORCH_EXTENSION_NAME
|
| 2 |
+
#include <torch/torch.h>
|
| 3 |
+
#endif
|
| 4 |
+
|
| 5 |
#include <stdio.h>
|
| 6 |
#include <stdlib.h>
|
| 7 |
#include <stdint.h>
|
|
|
|
| 40 |
}
|
| 41 |
}
|
| 42 |
|
| 43 |
+
#ifndef TORCH_EXTENSION_NAME
|
| 44 |
+
|
| 45 |
int main(int argc, char **argv) {
|
| 46 |
int n = atoi(argv[1]);
|
| 47 |
int gpu = argc > 2 ? atoi(argv[2]) : 0;
|
|
|
|
| 121 |
free(h_ct); free(h_z);
|
| 122 |
cudaFree(d_ct); cudaFree(d_z); cudaFree(d_out); cudaFree(d_nz); cudaFree(d_mx);
|
| 123 |
}
|
| 124 |
+
#endif
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
#ifdef TORCH_EXTENSION_NAME
|
| 128 |
+
/*
 * Torch entry point: runs the kronecker_slab + reduce_stats kernel pair once
 * per slab index j in [0, P) and aggregates the per-slab statistics on the host.
 *
 * Parameters
 *   ct     2-D int64 CUDA tensor of shape (P, C).
 *   z_inv  float64 CUDA tensor. It is handed to kronecker_slab together with C.
 *          NOTE(review): presumably it must hold at least C elements -- the
 *          kernel body is not visible here, so this is not enforced; confirm.
 *
 * Returns
 *   Two 1-element int64 CPU tensors {total_nz, global_max}:
 *     total_nz   = sum over all slabs of the value reduce_stats left in nz,
 *     global_max = maximum over all slabs of the value reduce_stats left in mx.
 *
 * Raises (c10::Error via TORCH_CHECK)
 *   - on dtype / device / rank violations of the inputs,
 *   - when P or C does not fit the kernels' `int` parameters, or P*P + 255
 *     would overflow `int` in the grid-size computation,
 *   - when a kernel launch or the subsequent synchronisation reports a CUDA error.
 */
std::vector<torch::Tensor> compute_kronecker(torch::Tensor ct, torch::Tensor z_inv) {
    TORCH_CHECK(ct.is_cuda() && ct.dtype() == torch::kInt64, "ct must be int64 CUDA tensor");
    TORCH_CHECK(ct.dim() == 2, "ct must be 2-D (P x C)");
    TORCH_CHECK(z_inv.is_cuda() && z_inv.dtype() == torch::kFloat64, "z_inv must be float64 CUDA");

    // The kernels are launched with raw pointers on whatever device is current,
    // so make ct's device current for the duration of this call.
    c10::DeviceGuard device_guard(ct.device());

    // Raw-pointer access assumes dense row-major storage on a single device;
    // normalise the inputs instead of silently reading strided/remote memory.
    // Both calls are no-ops (no copy) when the input already qualifies.
    ct = ct.contiguous();
    z_inv = z_inv.to(ct.device()).contiguous();

    // Tensor sizes are int64; the kernels take int. Validate before narrowing.
    const int64_t P64 = ct.size(0);
    const int64_t C64 = ct.size(1);
    TORCH_CHECK(C64 <= INT32_MAX, "ct has too many columns (C must fit in int)");
    TORCH_CHECK(P64 <= INT32_MAX && P64 * P64 <= (int64_t)INT32_MAX - 255,
                "ct has too many rows (P * P + 255 must fit in int)");
    const int P = (int)P64;
    const int C = (int)C64;

    const auto i64_on_dev = torch::dtype(torch::kInt64).device(ct.device());
    auto out = torch::zeros({P64, P64}, i64_on_dev);
    // stats[0] is the nz counter, stats[1] is the mx slot; keeping them in one
    // tensor lets each iteration read both back with a single device->host copy.
    auto stats = torch::zeros({2}, i64_on_dev);

    const int64_t *ct_ptr = ct.data_ptr<int64_t>();
    const double *z_ptr = z_inv.data_ptr<double>();
    int64_t *out_ptr = out.data_ptr<int64_t>();
    unsigned long long *nz_ptr = reinterpret_cast<unsigned long long *>(stats.data_ptr<int64_t>());
    unsigned long long *mx_ptr = nz_ptr + 1;

    auto check_cuda = [](cudaError_t err, const char *what) {
        TORCH_CHECK(err == cudaSuccess, what, ": ", cudaGetErrorString(err));
    };

    int64_t total_nz = 0, global_max = 0;
    const int threads = 256;
    const int nblocks = (int)((P64 * P64 + threads - 1) / threads);  // ceil-div, one thread per slab cell

    // NOTE(review): the kernels run on the legacy default stream while the
    // zero_() calls run on torch's current stream; this is ordered as long as
    // the current stream is the default or a blocking stream -- confirm if
    // callers use non-blocking streams.
    for (int j = 0; j < P; j++) {
        out.zero_();
        stats.zero_();

        kronecker_slab<<<nblocks, threads>>>(ct_ptr, z_ptr, P, C, j, out_ptr);
        check_cuda(cudaGetLastError(), "kronecker_slab launch failed");

        // Same stream as kronecker_slab, so it is ordered after it without an
        // intermediate device-wide synchronisation.
        reduce_stats<<<nblocks, threads>>>(out_ptr, P, j, nz_ptr, mx_ptr);
        check_cuda(cudaGetLastError(), "reduce_stats launch failed");

        // Surfaces asynchronous execution faults from either kernel and makes
        // the results visible before the host readback.
        check_cuda(cudaDeviceSynchronize(), "kronecker kernels failed");

        auto host_stats = stats.cpu();
        const int64_t *hs = host_stats.data_ptr<int64_t>();
        total_nz += hs[0];
        if (hs[1] > global_max) global_max = hs[1];
    }

    return {torch::tensor({total_nz}, torch::kInt64), torch::tensor({global_max}, torch::kInt64)};
}
|
| 155 |
+
#endif
|
| 156 |
+
|
torch-ext/torch_binding.cpp
CHANGED
|
@@ -3,4 +3,5 @@
|
|
| 3 |
|
| 4 |
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 5 |
m.doc() = "Kronecker Coefficients (Symmetric Group) CUDA kernel";
|
|
|
|
| 6 |
}
|
|
|
|
| 3 |
|
| 4 |
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 5 |
m.doc() = "Kronecker Coefficients (Symmetric Group) CUDA kernel";
|
| 6 |
+
m.def("compute_kronecker", &compute_kronecker, py::arg("ct"), py::arg("z_inv"));
|
| 7 |
}
|
torch-ext/torch_binding.h
CHANGED
|
@@ -1,3 +1,4 @@
|
|
| 1 |
#pragma once
|
| 2 |
#include <torch/torch.h>
|
| 3 |
-
|
|
|
|
|
|
| 1 |
#pragma once
|
| 2 |
#include <torch/torch.h>
|
| 3 |
+
|
| 4 |
+
std::vector<torch::Tensor> compute_kronecker(torch::Tensor ct, torch::Tensor z_inv);
|