diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..2813b21a07ee150f9d176a08e4b960f07a612a99 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +build/torch27-cxx11-cu118-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch27-cxx11-cu126-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch27-cxx11-cu128-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu126-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu128-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu129-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a1c6980a905e7227d4460dd86ead0425b8eba9f2 --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +MRA kernels for transformers \ No newline at end of file diff --git a/build.toml b/build.toml new file mode 100644 index 0000000000000000000000000000000000000000..8f377c9a82e3b5d560661cf52e019d45793a83e2 --- /dev/null +++ b/build.toml @@ -0,0 +1,20 @@ +[general] +name = "mra" +universal = false + +[torch] +src = [ + "torch-ext/torch_binding.cpp", + "torch-ext/cuda_launch.h", +] + + +[kernel.mra] +backend = "cuda" +depends = ["torch"] +src = [ + "mra/cuda_kernel.cu", + "mra/cuda_kernel.h", + "mra/cuda_launch.cu", + "mra/cuda_launch.h", +] \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/mra/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/mra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86daaf874c808b8b828b43b4ee8a6b292323d336 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/mra/__init__.py @@ -0,0 +1,25 @@ +from ._ops import ops +import torch + +def index_max(index_vals: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.index_max(index_vals, indices, A_num_block, B_num_block) + +def mm_to_sparse(dense_A: torch.Tensor, dense_B: torch.Tensor, indices: torch.Tensor): + return ops.mm_to_sparse(dense_A, dense_B, indices) + +def sparse_dense_mm(sparse_A: torch.Tensor, indices: torch.Tensor, dense_B: torch.Tensor, A_num_block: int): + return ops.sparse_dense_mm(sparse_A, indices, dense_B, A_num_block) + +def reduce_sum(sparse_A: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.reduce_sum(sparse_A, indices, A_num_block, B_num_block) + +def scatter(dense_A: torch.Tensor, indices: torch.Tensor, B_num_block: int): + return ops.scatter(dense_A, indices, B_num_block) + +__all__ = [ + "index_max", + "mm_to_sparse", + "sparse_dense_mm", + "reduce_sum", + "scatter", +] \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8812e2e4685709c071b2e36defd21fecbfcbf84f Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16e775a2c4d8ced4d3614aa11aa5fa773b737e88 Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..ef71e524d352f13941ccf8ab823f7ad8c1f14641 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e07be154aab143447264cbd25ba8987760af84f50304cc0940419cae754d8fc2 +size 2289096 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/mra/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/mra/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6761818e9f178bb9db57d95925c0cb21cdd8abc2 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/mra/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _mra_e8307c7_dirty +ops = torch.ops._mra_e8307c7_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_mra_e8307c7_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/mra/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/mra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86daaf874c808b8b828b43b4ee8a6b292323d336 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/mra/__init__.py @@ -0,0 +1,25 @@ +from ._ops import ops +import torch + +def index_max(index_vals: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.index_max(index_vals, indices, A_num_block, B_num_block) + +def mm_to_sparse(dense_A: torch.Tensor, dense_B: torch.Tensor, indices: torch.Tensor): + return ops.mm_to_sparse(dense_A, dense_B, indices) + +def sparse_dense_mm(sparse_A: torch.Tensor, indices: torch.Tensor, dense_B: torch.Tensor, A_num_block: int): + return ops.sparse_dense_mm(sparse_A, indices, dense_B, A_num_block) + +def reduce_sum(sparse_A: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.reduce_sum(sparse_A, indices, A_num_block, B_num_block) + +def scatter(dense_A: torch.Tensor, indices: torch.Tensor, B_num_block: int): + return ops.scatter(dense_A, indices, B_num_block) + +__all__ = [ + "index_max", + "mm_to_sparse", + "sparse_dense_mm", + "reduce_sum", + "scatter", +] \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b20b524cc06016e5025a13a73fd1e70da25a1204 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b6703a820302d3d1173abb45d402b032ac9c2e4 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..44ff4b6b559f48ecc51a578054632f215ff50251 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:579d1b3e91773c7802fc4c5b58b5fac62235b4555c7a836af0306e34f7bb0719 +size 2334496 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/mra/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/mra/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6761818e9f178bb9db57d95925c0cb21cdd8abc2 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/mra/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _mra_e8307c7_dirty +ops = torch.ops._mra_e8307c7_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_mra_e8307c7_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/mra/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/mra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86daaf874c808b8b828b43b4ee8a6b292323d336 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/mra/__init__.py @@ -0,0 +1,25 @@ +from ._ops import ops +import torch + +def index_max(index_vals: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.index_max(index_vals, indices, A_num_block, B_num_block) + +def mm_to_sparse(dense_A: torch.Tensor, dense_B: torch.Tensor, indices: torch.Tensor): + return ops.mm_to_sparse(dense_A, dense_B, indices) + +def sparse_dense_mm(sparse_A: torch.Tensor, indices: torch.Tensor, dense_B: torch.Tensor, A_num_block: int): + return ops.sparse_dense_mm(sparse_A, indices, dense_B, A_num_block) + +def reduce_sum(sparse_A: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.reduce_sum(sparse_A, indices, A_num_block, B_num_block) + +def scatter(dense_A: torch.Tensor, indices: torch.Tensor, B_num_block: int): + return ops.scatter(dense_A, indices, B_num_block) + +__all__ = [ + "index_max", + "mm_to_sparse", + "sparse_dense_mm", + "reduce_sum", + "scatter", +] \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6775174f47011764e4b2ae30bf9e0462761b53b Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..466db5067e88082bb8f77736eafc40e1a026352b Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..b18acf089b226ee82ca0a3167a3a4ba10023c3f2 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40e6e3cc4433f6333afc56fdce4dd0e5aaaf007d701b7a0582d46234d93d57ec +size 2602656 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/mra/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/mra/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6761818e9f178bb9db57d95925c0cb21cdd8abc2 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/mra/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _mra_e8307c7_dirty +ops = torch.ops._mra_e8307c7_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_mra_e8307c7_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/mra/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/mra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86daaf874c808b8b828b43b4ee8a6b292323d336 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/mra/__init__.py @@ -0,0 +1,25 @@ +from ._ops import ops +import torch + +def index_max(index_vals: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.index_max(index_vals, indices, A_num_block, B_num_block) + +def mm_to_sparse(dense_A: torch.Tensor, dense_B: torch.Tensor, indices: torch.Tensor): + return ops.mm_to_sparse(dense_A, dense_B, indices) + +def sparse_dense_mm(sparse_A: torch.Tensor, indices: torch.Tensor, dense_B: torch.Tensor, A_num_block: int): + return ops.sparse_dense_mm(sparse_A, indices, dense_B, A_num_block) + +def reduce_sum(sparse_A: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.reduce_sum(sparse_A, indices, A_num_block, B_num_block) + +def scatter(dense_A: torch.Tensor, indices: torch.Tensor, B_num_block: int): + return ops.scatter(dense_A, indices, B_num_block) + +__all__ = [ + "index_max", + "mm_to_sparse", + "sparse_dense_mm", + "reduce_sum", + "scatter", +] \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3c416b7108689c3d7f6de23c94ddc07616f46b9 Binary files /dev/null and b/build/torch28-cxx11-cu126-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c569dc1f555d607de9c50530a6b3df9c7f2c46d Binary files /dev/null and b/build/torch28-cxx11-cu126-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..0bd1c5a8e019b15f3629be8ac476fae5e1acc7c0 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4b972c43e1a8b2a6941a3ab44b99a638d253f9ef1e67cb973fff0abd2664926 +size 2334520 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/mra/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/mra/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6761818e9f178bb9db57d95925c0cb21cdd8abc2 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/mra/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _mra_e8307c7_dirty +ops = torch.ops._mra_e8307c7_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_mra_e8307c7_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/mra/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/mra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86daaf874c808b8b828b43b4ee8a6b292323d336 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/mra/__init__.py @@ -0,0 +1,25 @@ +from ._ops import ops +import torch + +def index_max(index_vals: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.index_max(index_vals, indices, A_num_block, B_num_block) + +def mm_to_sparse(dense_A: torch.Tensor, dense_B: torch.Tensor, indices: torch.Tensor): + return ops.mm_to_sparse(dense_A, dense_B, indices) + +def sparse_dense_mm(sparse_A: torch.Tensor, indices: torch.Tensor, dense_B: torch.Tensor, A_num_block: int): + return ops.sparse_dense_mm(sparse_A, indices, dense_B, A_num_block) + +def reduce_sum(sparse_A: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.reduce_sum(sparse_A, indices, A_num_block, B_num_block) + +def scatter(dense_A: torch.Tensor, indices: torch.Tensor, B_num_block: int): + return ops.scatter(dense_A, indices, B_num_block) + +__all__ = [ + "index_max", + "mm_to_sparse", + "sparse_dense_mm", + "reduce_sum", + "scatter", +] \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d6f1aaff835a26255fbd02bba44189113bdf121 Binary files /dev/null and b/build/torch28-cxx11-cu128-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2040f0669f73cea3f323540adc4de5b045048287 Binary files /dev/null and b/build/torch28-cxx11-cu128-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..8e710a19d3b1be6b3177d34c900cd3772097511b --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1ab899e2e33ce5edf4b75e5b4138d4ca30e1f62a91ecef4111b4121b408dd5a +size 2602880 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/mra/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/mra/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6761818e9f178bb9db57d95925c0cb21cdd8abc2 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/mra/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _mra_e8307c7_dirty +ops = torch.ops._mra_e8307c7_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_mra_e8307c7_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/mra/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/mra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86daaf874c808b8b828b43b4ee8a6b292323d336 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/mra/__init__.py @@ -0,0 +1,25 @@ +from ._ops import ops +import torch + +def index_max(index_vals: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.index_max(index_vals, indices, A_num_block, B_num_block) + +def mm_to_sparse(dense_A: torch.Tensor, dense_B: torch.Tensor, indices: torch.Tensor): + return ops.mm_to_sparse(dense_A, dense_B, indices) + +def sparse_dense_mm(sparse_A: torch.Tensor, indices: torch.Tensor, dense_B: torch.Tensor, A_num_block: int): + return ops.sparse_dense_mm(sparse_A, indices, dense_B, A_num_block) + +def reduce_sum(sparse_A: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.reduce_sum(sparse_A, indices, A_num_block, B_num_block) + +def scatter(dense_A: torch.Tensor, indices: torch.Tensor, B_num_block: int): + return ops.scatter(dense_A, indices, B_num_block) + +__all__ = [ + "index_max", + "mm_to_sparse", + "sparse_dense_mm", + "reduce_sum", + "scatter", +] \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b8bd021e61bb7318a628a6c9df25f94351519b2 Binary files /dev/null and b/build/torch28-cxx11-cu129-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a86a5a29d21755ab1864d786afe4d0112ff44d4 Binary files /dev/null and b/build/torch28-cxx11-cu129-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..113b995c436455b85d328889759c6f02dc18069c --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c10f2bd0af7f0564de136062d345541ed2bb493e21de5fc5cfc30942342abf22 +size 2632568 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/mra/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/mra/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6761818e9f178bb9db57d95925c0cb21cdd8abc2 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/mra/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _mra_e8307c7_dirty +ops = torch.ops._mra_e8307c7_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_mra_e8307c7_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/mra/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/mra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86daaf874c808b8b828b43b4ee8a6b292323d336 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/mra/__init__.py @@ -0,0 +1,25 @@ +from ._ops import ops +import torch + +def index_max(index_vals: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.index_max(index_vals, indices, A_num_block, B_num_block) + +def mm_to_sparse(dense_A: torch.Tensor, dense_B: torch.Tensor, indices: torch.Tensor): + return ops.mm_to_sparse(dense_A, dense_B, indices) + +def sparse_dense_mm(sparse_A: torch.Tensor, indices: torch.Tensor, dense_B: torch.Tensor, A_num_block: int): + return ops.sparse_dense_mm(sparse_A, indices, dense_B, A_num_block) + +def reduce_sum(sparse_A: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.reduce_sum(sparse_A, indices, A_num_block, B_num_block) + +def scatter(dense_A: torch.Tensor, indices: torch.Tensor, B_num_block: int): + return ops.scatter(dense_A, indices, B_num_block) + +__all__ = [ + "index_max", + "mm_to_sparse", + "sparse_dense_mm", + "reduce_sum", + "scatter", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc b/build/torch29-cxx11-cu126-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..847e98e0d36de14edfd5494ba497ca1b5292185a Binary files /dev/null and b/build/torch29-cxx11-cu126-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu126-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc b/build/torch29-cxx11-cu126-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28ea9c79a0bf8ad48baf34598de1ba68e4a9c77e Binary files /dev/null and b/build/torch29-cxx11-cu126-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu126-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..20d9f09fbc2daa930b6a9a0760ec2c7caad43c48 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac28155160d68c77778beccc3e0fa7041e5e7d2b96822322a75e0d09eaf452f5 +size 2334496 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/mra/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/mra/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6761818e9f178bb9db57d95925c0cb21cdd8abc2 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/mra/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _mra_e8307c7_dirty +ops = torch.ops._mra_e8307c7_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_mra_e8307c7_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/mra/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/mra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86daaf874c808b8b828b43b4ee8a6b292323d336 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/mra/__init__.py @@ -0,0 +1,25 @@ +from ._ops import ops +import torch + +def index_max(index_vals: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.index_max(index_vals, indices, A_num_block, B_num_block) + +def mm_to_sparse(dense_A: torch.Tensor, dense_B: torch.Tensor, indices: torch.Tensor): + return ops.mm_to_sparse(dense_A, dense_B, indices) + +def sparse_dense_mm(sparse_A: torch.Tensor, indices: torch.Tensor, dense_B: torch.Tensor, A_num_block: int): + return ops.sparse_dense_mm(sparse_A, indices, dense_B, A_num_block) + +def reduce_sum(sparse_A: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.reduce_sum(sparse_A, indices, A_num_block, B_num_block) + +def scatter(dense_A: torch.Tensor, indices: torch.Tensor, B_num_block: int): + return ops.scatter(dense_A, indices, B_num_block) + +__all__ = [ + "index_max", + "mm_to_sparse", + "sparse_dense_mm", + "reduce_sum", + "scatter", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc b/build/torch29-cxx11-cu128-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7b51a4851e913d4e6bd4d892316af56dde2b7b3 Binary files /dev/null and b/build/torch29-cxx11-cu128-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu128-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc b/build/torch29-cxx11-cu128-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab43cb75bc690d1adec720cad48ced52bb16ca98 Binary files /dev/null and b/build/torch29-cxx11-cu128-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu128-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..b101e516c1578158829015767fe3fd3124d85b55 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f668c4d60ff23112e2b6c62a2271bb2dd8812a8ed5c51a845ce1248a9e13cbf +size 2606944 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/mra/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/mra/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6761818e9f178bb9db57d95925c0cb21cdd8abc2 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/mra/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _mra_e8307c7_dirty +ops = torch.ops._mra_e8307c7_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_mra_e8307c7_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/mra/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/mra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86daaf874c808b8b828b43b4ee8a6b292323d336 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/mra/__init__.py @@ -0,0 +1,25 @@ +from ._ops import ops +import torch + +def index_max(index_vals: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.index_max(index_vals, indices, A_num_block, B_num_block) + +def mm_to_sparse(dense_A: torch.Tensor, dense_B: torch.Tensor, indices: torch.Tensor): + return ops.mm_to_sparse(dense_A, dense_B, indices) + +def sparse_dense_mm(sparse_A: torch.Tensor, indices: torch.Tensor, dense_B: torch.Tensor, A_num_block: int): + return ops.sparse_dense_mm(sparse_A, indices, dense_B, A_num_block) + +def reduce_sum(sparse_A: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.reduce_sum(sparse_A, indices, A_num_block, B_num_block) + +def scatter(dense_A: torch.Tensor, indices: torch.Tensor, B_num_block: int): + return ops.scatter(dense_A, indices, B_num_block) + +__all__ = [ + "index_max", + "mm_to_sparse", + "sparse_dense_mm", + "reduce_sum", + "scatter", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc b/build/torch29-cxx11-cu130-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7c6e7dea72c3a87da5c15e8518a522ad054565a Binary files /dev/null and b/build/torch29-cxx11-cu130-x86_64-linux/mra/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu130-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc b/build/torch29-cxx11-cu130-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b8965f5f261c901fb63a87401df7ee28d77b5f7 Binary files /dev/null and b/build/torch29-cxx11-cu130-x86_64-linux/mra/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu130-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..cc47f4b68099046bc660f61c2e6473d1bd8a48cb --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/mra/_mra_e8307c7_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a37b346ec5d1df0c97fdf3c9cfdf9eaead44ba9a8e162cd5f00d2a73ecf3e4b +size 2569704 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/mra/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/mra/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6761818e9f178bb9db57d95925c0cb21cdd8abc2 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/mra/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _mra_e8307c7_dirty +ops = torch.ops._mra_e8307c7_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_mra_e8307c7_dirty::{op_name}" \ No newline at end of file diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000000000000000000000000000000000000..85b5d60a855bf4c19555cc9b8de8ca88d6fd3ae9 --- /dev/null +++ b/flake.lock @@ -0,0 +1,168 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "hf-nix": { + "inputs": { + "flake-compat": "flake-compat_2", + "flake-utils": "flake-utils_2", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1759851564, + "narHash": "sha256-Xybkhm0FM/VzlZ5WndTYq/X/9MAeddd4EQ2Vz8GdkOA=", + "owner": "huggingface", + "repo": "hf-nix", + "rev": "351655d9f124805ed7c1193aa61550ce245f4570", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "hf-nix", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "hf-nix": "hf-nix", + "nixpkgs": [ + "kernel-builder", + "hf-nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760035358, + "narHash": "sha256-N5vmCrgwcIluPclf/hmnofLK77EJJYh5PR8SRvw++es=", + "owner": "huggingface", + "repo": "kernel-builder", + "rev": "a48cbd19ae7e425dfc1865188ef06dac43ab9244", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "kernel-builder", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1755963616, + "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable-small", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000000000000000000000000000000000000..9d5ee16d31cb9da4c4174b81c095f11905027def --- /dev/null +++ b/flake.nix @@ -0,0 +1,17 @@ +{ + description = "Flake for mra kernels"; + + inputs = { + kernel-builder.url = "github:huggingface/kernel-builder"; + }; + + outputs = + { + self, + kernel-builder, + }: + kernel-builder.lib.genFlakeOutputs { + path = ./.; + rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; + }; +} diff --git a/mra/cuda_kernel.cu b/mra/cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..87ed89052873813153786bd416a981d3e5279af9 --- /dev/null +++ b/mra/cuda_kernel.cu @@ -0,0 +1,383 @@ +#include "cuda_kernel.h" + +////////////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////////////// + +__global__ void index_max_cuda_kernel( + float *index_vals, // [batch_size, 32, num_block] + int *indices, // [batch_size, num_block] + float *max_vals, // [batch_size, A_num_block * 32] + float *max_vals_scatter, // [batch_size, 32, num_block] + long batch_size, + long A_num_block, + long B_num_block, + long num_block +) { + + long batch_idx = blockIdx.x; + + long thread_idx = threadIdx.x; + long num_thread = blockDim.x; + + extern __shared__ float buffer[]; + int *max_buffer = (int*)buffer; + + for (int i = 0; i < A_num_block * 32; i = i + num_thread) { + int idx = i + thread_idx; + if (idx < A_num_block * 32) { + max_buffer[idx] = -1e8; + } + } + __syncthreads(); + + int *indices_pt = &indices[batch_idx * num_block]; + float *index_vals_pt = &index_vals[batch_idx * num_block * 32]; + + for (int idx_start = 0; idx_start < 32 * num_block; idx_start = idx_start + num_thread) { + int idx = idx_start + thread_idx; + int A_block_idx = indices_pt[idx % num_block] / B_num_block; + atomicMax(&max_buffer[A_block_idx * 32 + idx / num_block], (int)(index_vals_pt[idx] * 1000)); + } + __syncthreads(); + + float *max_vals_pt = &max_vals[batch_idx * A_num_block * 32]; + for (int i = 0; i < A_num_block * 32; i = i + num_thread) { + int idx = i + thread_idx; + if (idx < A_num_block * 32) { + max_vals_pt[idx] = (float)max_buffer[idx] / 1000.; + } + } + + float *max_vals_scatter_pt = &max_vals_scatter[batch_idx * num_block * 32]; + for (int idx_start = 0; idx_start < 32 * num_block; idx_start = idx_start + num_thread) { + int idx = idx_start + thread_idx; + int A_block_idx = indices_pt[idx % num_block] / B_num_block; + max_vals_scatter_pt[idx] = (float)max_buffer[A_block_idx * 32 + idx / num_block] / 1000.; + } + +} + +__global__ void mm_to_sparse_cuda_kernel( + float *dense_A, // [batch_size, A_num_block, dim, 32] + float *dense_B, // [batch_size, B_num_block, dim, 32] + int *indices, // [batch_size, num_block] + float *sparse_C, // [batch_size, num_block, 32, 32] + long batch_size, + long A_num_block, + long B_num_block, + long dim, + long num_block +) { + + long batch_idx = blockIdx.y; + long block_idx = blockIdx.x * blockDim.y + threadIdx.y; + + long thread_idx = threadIdx.x; + + __shared__ float buffer[4096]; + float *A_buffer = &buffer[threadIdx.y * 1024]; // [2, 8, 32] + float *B_buffer = &buffer[threadIdx.y * 1024 + 512]; // [2, 8, 32] + + long batch_idx__block_idx = batch_idx * num_block + block_idx; + + long AB_block_idx = indices[batch_idx__block_idx]; + float *dense_A_pt = &dense_A[(batch_idx * A_num_block + AB_block_idx / B_num_block) * dim * 32]; + float *dense_B_pt = &dense_B[(batch_idx * B_num_block + AB_block_idx % B_num_block) * dim * 32]; + + int reg_1_idx = thread_idx / 8; // [0000000011111111222222223333333344444444555555556666666677777777] + int reg_2_idx = thread_idx % 8; // [0123456701234567012345670123456701234567012345670123456701234567] + + float reg_1[8]; + float reg_2[8]; + + float reg_array[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + #pragma unroll + for (int i = 0; i < 4; i++) { + A_buffer[i * 64 + thread_idx] = dense_A_pt[i * 64 + thread_idx]; + B_buffer[i * 64 + thread_idx] = dense_B_pt[i * 64 + thread_idx]; + } + + __syncthreads(); + + #pragma unroll + for (int i = 0; i < 4; i++) { + reg_1[i] = A_buffer[reg_1_idx * 4 + i]; + reg_2[i] = B_buffer[reg_2_idx * 4 + i]; + } + + for (int dim_stride = 1; dim_stride < (dim / 8); dim_stride++) { + + #pragma unroll + for (int i = 0; i < 4; i++) { + A_buffer[(dim_stride % 2) * 256 + i * 64 + thread_idx] = dense_A_pt[dim_stride * 256 + i * 64 + thread_idx]; + B_buffer[(dim_stride % 2) * 256 + i * 64 + thread_idx] = dense_B_pt[dim_stride * 256 + i * 64 + thread_idx]; + } + + #pragma unroll + for (int mini_dim_idx = 1; mini_dim_idx < 8; mini_dim_idx++) { + #pragma unroll + for (int i = 0; i < 4; i++) { + reg_1[(mini_dim_idx % 2) * 4 + i] = A_buffer[((dim_stride - 1) % 2) * 256 + mini_dim_idx * 32 + reg_1_idx * 4 + i]; + reg_2[(mini_dim_idx % 2) * 4 + i] = B_buffer[((dim_stride - 1) % 2) * 256 + mini_dim_idx * 32 + reg_2_idx * 4 + i]; + } + #pragma unroll + for (int i = 0; i < 4; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j]; + } + } + } + + __syncthreads(); + + #pragma unroll + for (int i = 0; i < 4; i++) { + reg_1[i] = A_buffer[(dim_stride % 2) * 256 + reg_1_idx * 4 + i]; + reg_2[i] = B_buffer[(dim_stride % 2) * 256 + reg_2_idx * 4 + i]; + } + + #pragma unroll + for (int i = 0; i < 4; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j]; + } + } + + } + + #pragma unroll + for (int mini_dim_idx = 1; mini_dim_idx < 8; mini_dim_idx++) { + #pragma unroll + for (int i = 0; i < 4; i++) { + reg_1[(mini_dim_idx % 2) * 4 + i] = A_buffer[256 + mini_dim_idx * 32 + reg_1_idx * 4 + i]; + reg_2[(mini_dim_idx % 2) * 4 + i] = B_buffer[256 + mini_dim_idx * 32 + reg_2_idx * 4 + i]; + } + #pragma unroll + for (int i = 0; i < 4; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j]; + } + } + } + #pragma unroll + for (int i = 0; i < 4; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j]; + } + } + __syncthreads(); + + float *C_buffer = &buffer[threadIdx.y * 1024]; // [32, 32] + + #pragma unroll + for (int i = 0; i < 4; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + C_buffer[(reg_2_idx * 4 + j) * 32 + reg_1_idx * 4 + i] = reg_array[i * 4 + j]; + } + } + __syncthreads(); + + float *sparse_C_pt = &sparse_C[batch_idx__block_idx * 1024]; + + #pragma unroll + for (int i = 0; i < 16; i++) { + sparse_C_pt[i * 64 + thread_idx] = C_buffer[i * 64 + thread_idx]; + } + +} + +__global__ void sparse_dense_mm_cuda_kernel( + float *sparse_A, // [batch_size, num_block, 32, 32] + int *indices, // [batch_size, num_block] + float *dense_B, // [batch_size, B_num_block, dim, 32] + float *dense_C, // [batch_size, A_num_block, dim, 32] + long batch_size, + long A_num_block, + long B_num_block, + long dim, + long num_block +) { + + long batch_idx = blockIdx.y; + long block_idx = blockIdx.x * blockDim.y + threadIdx.y; + + long thread_idx = threadIdx.x; + + __shared__ float buffer[6144]; + float *A_buffer = &buffer[threadIdx.y * 3072]; // [32, 32] + float *B_buffer = &buffer[threadIdx.y * 3072 + 1024]; // [32, 64] + + long batch_idx__block_idx = batch_idx * num_block + block_idx; + + float *sparse_A_pt = &sparse_A[batch_idx__block_idx * 1024]; + #pragma unroll + for (int i = 0; i < 8; i++) { + A_buffer[i * 128 + thread_idx] = sparse_A_pt[i * 128 + thread_idx]; + } + + long AB_block_idx = indices[batch_idx__block_idx]; + float *dense_B_pt = &dense_B[(batch_idx * B_num_block + AB_block_idx % B_num_block) * 32 * dim]; + float *dense_C_pt = &dense_C[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32 * dim]; + + // [0000000011111111222222223333333344444444555555556666666677777777] + // [0123456701234567012345670123456701234567012345670123456701234567] + int reg_1_idx = thread_idx / 8; + int reg_2_idx = thread_idx % 8; + + float reg_1[8]; + float reg_2[8]; + + float reg_array[16]; + + for (int dim_stride = 0; dim_stride < dim; dim_stride = dim_stride + 64) { + + #pragma unroll + for (int i = 0; i < 16; i++) { + B_buffer[i * 128 + thread_idx] = dense_B_pt[dim_stride * 32 + i * 128 + thread_idx]; + } + + #pragma unroll + for (int i = 0; i < 16; i++) { + reg_array[i] = 0; + } + + __syncthreads(); + + #pragma unroll + for (int i = 0; i < 4; i++) { + reg_1[i] = B_buffer[(reg_1_idx * 4 + i) * 32]; + reg_2[i] = A_buffer[reg_2_idx * 4 + i]; + } + + #pragma unroll + for (int mini_dim_idx = 1; mini_dim_idx < 32; mini_dim_idx++) { + #pragma unroll + for (int i = 0; i < 4; i++) { + reg_1[(mini_dim_idx % 2) * 4 + i] = B_buffer[(reg_1_idx * 4 + i) * 32 + mini_dim_idx]; + reg_2[(mini_dim_idx % 2) * 4 + i] = A_buffer[mini_dim_idx * 32 + reg_2_idx * 4 + i]; + } + #pragma unroll + for (int i = 0; i < 4; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j]; + } + } + } + + #pragma unroll + for (int i = 0; i < 4; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j]; + } + } + + __syncthreads(); + + float *C_buffer = &buffer[threadIdx.y * 3072 + 1024]; // [64, 32] + + #pragma unroll + for (int i = 0; i < 4; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + C_buffer[(reg_1_idx * 4 + i) * 32 + reg_2_idx * 4 + j] = reg_array[i * 4 + j]; + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; i < 16; i++) { + atomicAdd(&dense_C_pt[dim_stride * 32 + i * 128 + thread_idx], C_buffer[i * 128 + thread_idx]); + } + __syncthreads(); + + } + +} + + +__global__ void reduce_sum_cuda_kernel( + float *sparse_A, // [batch_size, num_block, 32, 32] + int *indices, // [batch_size, num_block] + float *dense_C, // [batch_size, A_num_block, 32] + long batch_size, + long A_num_block, + long B_num_block, + long num_block +) { + + long batch_idx = blockIdx.y; + long block_idx = blockIdx.x * blockDim.y + threadIdx.y; + + long thread_idx = threadIdx.x; + + long batch_idx__block_idx = batch_idx * num_block + block_idx; + + long AB_block_idx = indices[batch_idx__block_idx]; + float *sparse_A_pt = &sparse_A[batch_idx__block_idx * 1024]; + + float reg_array[16]; + float value = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) { + reg_array[i] = sparse_A_pt[i * 32 + thread_idx]; + } + #pragma unroll + for (int stride = 8; stride < 32; stride = stride + 8) { + #pragma unroll + for (int i = 0; i < 8; i++) { + reg_array[(stride + i) % 16] = sparse_A_pt[(stride + i) * 32 + thread_idx]; + } + #pragma unroll + for (int i = 0; i < 8; i++) { + value = value + reg_array[(stride - 8 + i) % 16]; + } + } + #pragma unroll + for (int i = 0; i < 8; i++) { + value = value + reg_array[8 + i]; + } + + float *dense_C_pt = &dense_C[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32]; + + atomicAdd(&dense_C_pt[thread_idx], value); + +} + +__global__ void scatter_cuda_kernel( + float *dense_A, // [batch_size, A_num_block, 32] + int *indices, // [batch_size, num_block] + float *sparse_C, // [batch_size, num_block, 32, 32] + long batch_size, + long A_num_block, + long B_num_block, + long num_block +) { + + long batch_idx = blockIdx.y; + long block_idx = blockIdx.x * blockDim.y + threadIdx.y; + + long thread_idx = threadIdx.x; + + long batch_idx__block_idx = batch_idx * num_block + block_idx; + + long AB_block_idx = indices[batch_idx__block_idx]; + float *dense_A_pt = &dense_A[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32]; + float *sparse_C_pt = &sparse_C[(batch_idx * num_block + block_idx) * 1024]; + + float value = dense_A_pt[thread_idx]; + + #pragma unroll + for (int i = 0; i < 32; i++) { + sparse_C_pt[i * 32 + thread_idx] = value; + } + +} diff --git a/mra/cuda_kernel.h b/mra/cuda_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..a95b46f7d159b11851143710034cf80c20aa6bf8 --- /dev/null +++ b/mra/cuda_kernel.h @@ -0,0 +1,59 @@ + +#define WARP_SIZE 32 +#define FULL_MASK 0xffffffff +#define OPTIMAL_THREADS 256 + +__global__ void index_max_cuda_kernel( + float *index_vals, // [batch_size, 32, num_block] + int *indices, // [batch_size, num_block] + float *max_vals, // [batch_size, A_num_block * 32] + float *max_vals_scatter, // [batch_size, 32, num_block] + long batch_size, + long A_num_block, + long B_num_block, + long num_block +); + +__global__ void mm_to_sparse_cuda_kernel( + float *dense_A, // [batch_size, A_num_block, dim, 32] + float *dense_B, // [batch_size, B_num_block, dim, 32] + int *indices, // [batch_size, num_block] + float *sparse_C, // [batch_size, num_block, 32, 32] + long batch_size, + long A_num_block, + long B_num_block, + long dim, + long num_block +); + +__global__ void sparse_dense_mm_cuda_kernel( + float *sparse_A, // [batch_size, num_block, 32, 32] + int *indices, // [batch_size, num_block] + float *dense_B, // [batch_size, B_num_block, dim, 32] + float *dense_C, // [batch_size, A_num_block, dim, 32] + long batch_size, + long A_num_block, + long B_num_block, + long dim, + long num_block +); + +__global__ void reduce_sum_cuda_kernel( + float *sparse_A, // [batch_size, num_block, 32, 32] + int *indices, // [batch_size, num_block] + float *dense_C, // [batch_size, A_num_block, 32] + long batch_size, + long A_num_block, + long B_num_block, + long num_block +); + +__global__ void scatter_cuda_kernel( + float *dense_A, // [batch_size, A_num_block, 32] + int *indices, // [batch_size, num_block] + float *sparse_C, // [batch_size, num_block, 32, 32] + long batch_size, + long A_num_block, + long B_num_block, + long num_block +); diff --git a/mra/cuda_launch.cu b/mra/cuda_launch.cu new file mode 100644 index 0000000000000000000000000000000000000000..fd9565875380d3267191e102ca9f4ff5dc381a0e --- /dev/null +++ b/mra/cuda_launch.cu @@ -0,0 +1,154 @@ +#include +#include +#include "cuda_launch.h" +#include "cuda_kernel.h" +#include + +////////////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////////////// + +std::vector index_max_kernel( + at::Tensor index_vals, // [batch_size, 32, num_block] + at::Tensor indices, // [batch_size, num_block], + int A_num_block, + int B_num_block +) { + int batch_size = indices.size(0); + int num_block = indices.size(1); + + at::Tensor max_vals = at::zeros({batch_size, A_num_block * 32}, index_vals.options()); + at::Tensor max_vals_scatter = at::zeros({batch_size, 32, num_block}, index_vals.options()); + + dim3 threads(256); + dim3 blocks(batch_size); + int shared_mem = A_num_block * 32 * sizeof(float); + + index_max_cuda_kernel<<>>( + index_vals.data_ptr(), + indices.data_ptr(), + max_vals.data_ptr(), + max_vals_scatter.data_ptr(), + batch_size, + A_num_block, + B_num_block, + num_block + ); + + return {max_vals, max_vals_scatter}; +} + +at::Tensor mm_to_sparse_kernel( + at::Tensor dense_A, // [batch_size, A_num_block, dim, 32] + at::Tensor dense_B, // [batch_size, B_num_block, dim, 32] + at::Tensor indices // [batch_size, num_block] +) { + int batch_size = dense_A.size(0); + int A_num_block = dense_A.size(1); + int B_num_block = dense_B.size(1); + int dim = dense_A.size(2); + int num_block = indices.size(1); + + at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options()); + + dim3 threads(64, 4); + dim3 blocks(num_block / 4, batch_size); + + mm_to_sparse_cuda_kernel<<>>( + dense_A.data_ptr(), + dense_B.data_ptr(), + indices.data_ptr(), + sparse_C.data_ptr(), + batch_size, + A_num_block, + B_num_block, + dim, + num_block + ); + + return sparse_C; +} + +at::Tensor sparse_dense_mm_kernel( + at::Tensor sparse_A, // [batch_size, num_block, 32, 32] + at::Tensor indices, // [batch_size, num_block] + at::Tensor dense_B, // [batch_size, B_num_block, dim, 32] + int A_num_block +) { + int batch_size = sparse_A.size(0); + int num_block = sparse_A.size(1); + int B_num_block = dense_B.size(1); + int dim = dense_B.size(2); + + at::Tensor dense_C = at::zeros({batch_size, A_num_block, dim, 32}, dense_B.options()); + + dim3 threads(128, 2); + dim3 blocks(num_block / 2, batch_size); + + sparse_dense_mm_cuda_kernel<<>>( + sparse_A.data_ptr(), + indices.data_ptr(), + dense_B.data_ptr(), + dense_C.data_ptr(), + batch_size, + A_num_block, + B_num_block, + dim, + num_block + ); + + return dense_C; +} + +at::Tensor reduce_sum_kernel( + at::Tensor sparse_A, // [batch_size, num_block, 32, 32] + at::Tensor indices, // [batch_size, num_block] + int A_num_block, + int B_num_block +) { + int batch_size = sparse_A.size(0); + int num_block = sparse_A.size(1); + + at::Tensor dense_C = at::zeros({batch_size, A_num_block, 32}, sparse_A.options()); + + dim3 threads(32, 4); + dim3 blocks(num_block / 4, batch_size); + + reduce_sum_cuda_kernel<<>>( + sparse_A.data_ptr(), + indices.data_ptr(), + dense_C.data_ptr(), + batch_size, + A_num_block, + B_num_block, + num_block + ); + + return dense_C; +} + +at::Tensor scatter_kernel( + at::Tensor dense_A, // [batch_size, A_num_block, 32] + at::Tensor indices, // [batch_size, num_block] + int B_num_block +) { + int batch_size = dense_A.size(0); + int A_num_block = dense_A.size(1); + int num_block = indices.size(1); + + at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options()); + + dim3 threads(32, 4); + dim3 blocks(num_block / 4, batch_size); + + scatter_cuda_kernel<<>>( + dense_A.data_ptr(), + indices.data_ptr(), + sparse_C.data_ptr(), + batch_size, + A_num_block, + B_num_block, + num_block + ); + + return sparse_C; +} diff --git a/mra/cuda_launch.h b/mra/cuda_launch.h new file mode 100644 index 0000000000000000000000000000000000000000..9a8950a657c50ff70351eb43e4862f30e49f36e4 --- /dev/null +++ b/mra/cuda_launch.h @@ -0,0 +1,39 @@ +#include +#include +#include + +#define min(a, b) ((a)<(b)?(a):(b)) +#define max(a, b) ((a)>(b)?(a):(b)) + +std::vector index_max_kernel( + at::Tensor index_vals, + at::Tensor indices, + int A_num_block, + int B_num_block +); + +at::Tensor mm_to_sparse_kernel( + at::Tensor dense_A, + at::Tensor dense_B, + at::Tensor indices +); + +at::Tensor sparse_dense_mm_kernel( + at::Tensor sparse_A, + at::Tensor indices, + at::Tensor dense_B, + int A_num_block +); + +at::Tensor reduce_sum_kernel( + at::Tensor sparse_A, + at::Tensor indices, + int A_num_block, + int B_num_block +); + +at::Tensor scatter_kernel( + at::Tensor dense_A, + at::Tensor indices, + int B_num_block +); diff --git a/torch-ext/cuda_launch.h b/torch-ext/cuda_launch.h new file mode 100644 index 0000000000000000000000000000000000000000..9a8950a657c50ff70351eb43e4862f30e49f36e4 --- /dev/null +++ b/torch-ext/cuda_launch.h @@ -0,0 +1,39 @@ +#include +#include +#include + +#define min(a, b) ((a)<(b)?(a):(b)) +#define max(a, b) ((a)>(b)?(a):(b)) + +std::vector index_max_kernel( + at::Tensor index_vals, + at::Tensor indices, + int A_num_block, + int B_num_block +); + +at::Tensor mm_to_sparse_kernel( + at::Tensor dense_A, + at::Tensor dense_B, + at::Tensor indices +); + +at::Tensor sparse_dense_mm_kernel( + at::Tensor sparse_A, + at::Tensor indices, + at::Tensor dense_B, + int A_num_block +); + +at::Tensor reduce_sum_kernel( + at::Tensor sparse_A, + at::Tensor indices, + int A_num_block, + int B_num_block +); + +at::Tensor scatter_kernel( + at::Tensor dense_A, + at::Tensor indices, + int B_num_block +); diff --git a/torch-ext/mra/__init__.py b/torch-ext/mra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86daaf874c808b8b828b43b4ee8a6b292323d336 --- /dev/null +++ b/torch-ext/mra/__init__.py @@ -0,0 +1,25 @@ +from ._ops import ops +import torch + +def index_max(index_vals: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.index_max(index_vals, indices, A_num_block, B_num_block) + +def mm_to_sparse(dense_A: torch.Tensor, dense_B: torch.Tensor, indices: torch.Tensor): + return ops.mm_to_sparse(dense_A, dense_B, indices) + +def sparse_dense_mm(sparse_A: torch.Tensor, indices: torch.Tensor, dense_B: torch.Tensor, A_num_block: int): + return ops.sparse_dense_mm(sparse_A, indices, dense_B, A_num_block) + +def reduce_sum(sparse_A: torch.Tensor, indices: torch.Tensor, A_num_block: int, B_num_block: int): + return ops.reduce_sum(sparse_A, indices, A_num_block, B_num_block) + +def scatter(dense_A: torch.Tensor, indices: torch.Tensor, B_num_block: int): + return ops.scatter(dense_A, indices, B_num_block) + +__all__ = [ + "index_max", + "mm_to_sparse", + "sparse_dense_mm", + "reduce_sum", + "scatter", +] \ No newline at end of file diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bc885ced71edcbe2f3bbc900f4d56b11ee423841 --- /dev/null +++ b/torch-ext/torch_binding.cpp @@ -0,0 +1,92 @@ +#include +#include +#include +#include + +#include "registration.h" +#include "cuda_launch.h" + +std::vector index_max( + at::Tensor index_vals, + at::Tensor indices, + int64_t A_num_block, + int64_t B_num_block +) { + return index_max_kernel( + index_vals, + indices, + static_cast(A_num_block), + static_cast(B_num_block) + ); +} + +at::Tensor mm_to_sparse( + at::Tensor dense_A, + at::Tensor dense_B, + at::Tensor indices +) { + return mm_to_sparse_kernel( + dense_A, + dense_B, + indices + ); +} + +at::Tensor sparse_dense_mm( + at::Tensor sparse_A, + at::Tensor indices, + at::Tensor dense_B, + int64_t A_num_block +) { + return sparse_dense_mm_kernel( + sparse_A, + indices, + dense_B, + static_cast(A_num_block) + ); +} + +at::Tensor reduce_sum( + at::Tensor sparse_A, + at::Tensor indices, + int64_t A_num_block, + int64_t B_num_block +) { + return reduce_sum_kernel( + sparse_A, + indices, + static_cast(A_num_block), + static_cast(B_num_block) + ); +} + +at::Tensor scatter( + at::Tensor dense_A, + at::Tensor indices, + int64_t B_num_block +) { + return scatter_kernel( + dense_A, + indices, + static_cast(B_num_block) + ); +} + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + ops.def("index_max(Tensor index_vals, Tensor indices, int A_num_block, int B_num_block) -> Tensor[]"); + ops.impl("index_max", torch::kCUDA, &index_max); + + ops.def("mm_to_sparse(Tensor dense_A, Tensor dense_B, Tensor indices) -> Tensor"); + ops.impl("mm_to_sparse", torch::kCUDA, &mm_to_sparse); + + ops.def("sparse_dense_mm(Tensor sparse_A, Tensor indices, Tensor dense_B, int A_num_block) -> Tensor"); + ops.impl("sparse_dense_mm", torch::kCUDA, &sparse_dense_mm); + + ops.def("reduce_sum(Tensor sparse_A, Tensor indices, int A_num_block, int B_num_block) -> Tensor"); + ops.impl("reduce_sum", torch::kCUDA, &reduce_sum); + + ops.def("scatter(Tensor dense_A, Tensor indices, int B_num_block) -> Tensor"); + ops.impl("scatter", torch::kCUDA, &scatter); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME); \ No newline at end of file