File size: 2,096 Bytes
55ce07b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include <torch/torch.h>
#include <ATen/ATen.h>
#include <limits>
#include <vector>
#include <torch/library.h>

#include "registration.h"
#include "cuda_launch.h"

/// Op entry point for `index_max`: validates the block counts, narrows them to
/// `int`, and forwards everything to `index_max_kernel` (declared outside this
/// file).
///
/// The block counts arrive as `int64_t` but are handed to the launcher as
/// `int`. A plain `static_cast<int>` would silently wrap an out-of-range value,
/// so each count is range-checked first and rejected with an error instead.
///
/// @param index_vals  Forwarded unchanged to the launcher.
/// @param indices     Forwarded unchanged to the launcher.
/// @param A_num_block Block count; must be representable as `int`.
/// @param B_num_block Block count; must be representable as `int`.
/// @return The vector of tensors produced by `index_max_kernel`.
/// @throws c10::Error (via TORCH_CHECK) if a block count does not fit in `int`.
std::vector<at::Tensor> index_max(
  at::Tensor index_vals,
  at::Tensor indices,
  int64_t A_num_block,
  int64_t B_num_block
) {
  constexpr int64_t kIntMin = std::numeric_limits<int>::min();
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();

  // Refuse values that the narrowing casts below would otherwise truncate.
  TORCH_CHECK(
    A_num_block >= kIntMin && A_num_block <= kIntMax,
    "index_max: A_num_block (", A_num_block, ") does not fit in int");
  TORCH_CHECK(
    B_num_block >= kIntMin && B_num_block <= kIntMax,
    "index_max: B_num_block (", B_num_block, ") does not fit in int");

  return index_max_kernel(
    index_vals,
    indices,
    static_cast<int>(A_num_block),
    static_cast<int>(B_num_block)
  );
}

/// Op entry point for `mm_to_sparse`.
///
/// A pure pass-through: the three tensors are handed, unmodified and in the
/// same order, to `mm_to_sparse_kernel` (declared outside this file), and the
/// tensor it produces is returned to the caller.
///
/// @param dense_A Forwarded unchanged as the first launcher argument.
/// @param dense_B Forwarded unchanged as the second launcher argument.
/// @param indices Forwarded unchanged as the third launcher argument.
/// @return The tensor returned by `mm_to_sparse_kernel`.
at::Tensor mm_to_sparse(
  at::Tensor dense_A,
  at::Tensor dense_B,
  at::Tensor indices
) {
  at::Tensor result = mm_to_sparse_kernel(dense_A, dense_B, indices);
  return result;
}

/// Op entry point for `sparse_dense_mm`: validates the block count, narrows it
/// to `int`, and forwards everything to `sparse_dense_mm_kernel` (declared
/// outside this file).
///
/// `A_num_block` arrives as `int64_t` but is handed to the launcher as `int`.
/// A plain `static_cast<int>` would silently wrap an out-of-range value, so it
/// is range-checked first and rejected with an error instead.
///
/// @param sparse_A    Forwarded unchanged to the launcher.
/// @param indices     Forwarded unchanged to the launcher.
/// @param dense_B     Forwarded unchanged to the launcher.
/// @param A_num_block Block count; must be representable as `int`.
/// @return The tensor returned by `sparse_dense_mm_kernel`.
/// @throws c10::Error (via TORCH_CHECK) if `A_num_block` does not fit in `int`.
at::Tensor sparse_dense_mm(
  at::Tensor sparse_A,
  at::Tensor indices,
  at::Tensor dense_B,
  int64_t A_num_block
) {
  constexpr int64_t kIntMin = std::numeric_limits<int>::min();
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();

  // Refuse a value that the narrowing cast below would otherwise truncate.
  TORCH_CHECK(
    A_num_block >= kIntMin && A_num_block <= kIntMax,
    "sparse_dense_mm: A_num_block (", A_num_block, ") does not fit in int");

  return sparse_dense_mm_kernel(
    sparse_A,
    indices,
    dense_B,
    static_cast<int>(A_num_block)
  );
}

/// Op entry point for `reduce_sum`: validates the block counts, narrows them to
/// `int`, and forwards everything to `reduce_sum_kernel` (declared outside this
/// file).
///
/// The block counts arrive as `int64_t` but are handed to the launcher as
/// `int`. A plain `static_cast<int>` would silently wrap an out-of-range value,
/// so each count is range-checked first and rejected with an error instead.
///
/// @param sparse_A    Forwarded unchanged to the launcher.
/// @param indices     Forwarded unchanged to the launcher.
/// @param A_num_block Block count; must be representable as `int`.
/// @param B_num_block Block count; must be representable as `int`.
/// @return The tensor returned by `reduce_sum_kernel`.
/// @throws c10::Error (via TORCH_CHECK) if a block count does not fit in `int`.
at::Tensor reduce_sum(
  at::Tensor sparse_A,
  at::Tensor indices,
  int64_t A_num_block,
  int64_t B_num_block
) {
  constexpr int64_t kIntMin = std::numeric_limits<int>::min();
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();

  // Refuse values that the narrowing casts below would otherwise truncate.
  TORCH_CHECK(
    A_num_block >= kIntMin && A_num_block <= kIntMax,
    "reduce_sum: A_num_block (", A_num_block, ") does not fit in int");
  TORCH_CHECK(
    B_num_block >= kIntMin && B_num_block <= kIntMax,
    "reduce_sum: B_num_block (", B_num_block, ") does not fit in int");

  return reduce_sum_kernel(
    sparse_A,
    indices,
    static_cast<int>(A_num_block),
    static_cast<int>(B_num_block)
  );
}

/// Op entry point for `scatter`: validates the block count, narrows it to
/// `int`, and forwards everything to `scatter_kernel` (declared outside this
/// file).
///
/// `B_num_block` arrives as `int64_t` but is handed to the launcher as `int`.
/// A plain `static_cast<int>` would silently wrap an out-of-range value, so it
/// is range-checked first and rejected with an error instead.
///
/// @param dense_A     Forwarded unchanged to the launcher.
/// @param indices     Forwarded unchanged to the launcher.
/// @param B_num_block Block count; must be representable as `int`.
/// @return The tensor returned by `scatter_kernel`.
/// @throws c10::Error (via TORCH_CHECK) if `B_num_block` does not fit in `int`.
at::Tensor scatter(
  at::Tensor dense_A,
  at::Tensor indices,
  int64_t B_num_block
) {
  constexpr int64_t kIntMin = std::numeric_limits<int>::min();
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();

  // Refuse a value that the narrowing cast below would otherwise truncate.
  TORCH_CHECK(
    B_num_block >= kIntMin && B_num_block <= kIntMax,
    "scatter: B_num_block (", B_num_block, ") does not fit in int");

  return scatter_kernel(
    dense_A,
    indices,
    static_cast<int>(B_num_block)
  );
}

// Operator registration for this extension. For each of the five operators,
// `def` declares the schema string and the `impl` that follows binds the
// C++ wrapper of the same name to that operator under the CUDA dispatch key.
// No other backend is registered in this block.
// NOTE(review): TORCH_LIBRARY_EXPAND is defined outside this file (presumably
// in registration.h) and presumably expands TORCH_EXTENSION_NAME before
// delegating to TORCH_LIBRARY — confirm.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Schema `int` arguments are received by the wrappers as int64_t;
  // `Tensor[]` corresponds to the std::vector<at::Tensor> return type.
  ops.def("index_max(Tensor index_vals, Tensor indices, int A_num_block, int B_num_block) -> Tensor[]");
  ops.impl("index_max", torch::kCUDA, &index_max);

  // Three tensors in, one tensor out; no integer arguments.
  ops.def("mm_to_sparse(Tensor dense_A, Tensor dense_B, Tensor indices) -> Tensor");
  ops.impl("mm_to_sparse", torch::kCUDA, &mm_to_sparse);

  // Argument order in the schema matches the wrapper's parameter order.
  ops.def("sparse_dense_mm(Tensor sparse_A, Tensor indices, Tensor dense_B, int A_num_block) -> Tensor");
  ops.impl("sparse_dense_mm", torch::kCUDA, &sparse_dense_mm);

  // Takes both block counts, like index_max, but returns a single tensor.
  ops.def("reduce_sum(Tensor sparse_A, Tensor indices, int A_num_block, int B_num_block) -> Tensor");
  ops.impl("reduce_sum", torch::kCUDA, &reduce_sum);

  // Takes only B_num_block.
  ops.def("scatter(Tensor dense_A, Tensor indices, int B_num_block) -> Tensor");
  ops.impl("scatter", torch::kCUDA, &scatter);
}

// NOTE(review): REGISTER_EXTENSION is defined outside this file (presumably in
// registration.h). It presumably emits the Python module init hook for
// TORCH_EXTENSION_NAME so the built shared library can be imported — confirm.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME);