diff --git a/.gitignore b/.gitignore
index 64510a07ab488778f7867d2fd0e8e47830a8fcf3..7a60b85e148f80966a550e5ab6a762a907c69ca6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
-/build/temp*
+__pycache__/
+*.pyc
diff --git a/build.toml b/build.toml
new file mode 100644
index 0000000000000000000000000000000000000000..52b9c248d90541699bf87ca7d204bdccdf70913a
--- /dev/null
+++ b/build.toml
@@ -0,0 +1,94 @@
+[general]
+name = "vllm_flash_attn3"
+universal = false
+
+[torch]
+src = [
+  "torch-ext/torch_binding.cpp",
+  "torch-ext/torch_binding.h",
+]
+
+[kernel.layer-norm]
+depends = ["torch"]
+backend = "cuda"
+include = ["."]
+src = [
+  "layer-norm/ln.h",
+  "layer-norm/ln_api.cpp",
+  "layer-norm/ln_bwd_1024.cu",
+  "layer-norm/ln_bwd_1280.cu",
+  "layer-norm/ln_bwd_1536.cu",
+  "layer-norm/ln_bwd_2048.cu",
+  "layer-norm/ln_bwd_256.cu",
+  "layer-norm/ln_bwd_2560.cu",
+  "layer-norm/ln_bwd_3072.cu",
+  "layer-norm/ln_bwd_4096.cu",
+  "layer-norm/ln_bwd_512.cu",
+  "layer-norm/ln_bwd_5120.cu",
+  "layer-norm/ln_bwd_6144.cu",
+  "layer-norm/ln_bwd_7168.cu",
+  "layer-norm/ln_bwd_768.cu",
+  "layer-norm/ln_bwd_8192.cu",
+  "layer-norm/ln_bwd_kernels.cuh",
+  "layer-norm/ln_fwd_1024.cu",
+  "layer-norm/ln_fwd_1280.cu",
+  "layer-norm/ln_fwd_1536.cu",
+  "layer-norm/ln_fwd_2048.cu",
+  "layer-norm/ln_fwd_256.cu",
+  "layer-norm/ln_fwd_2560.cu",
+  "layer-norm/ln_fwd_3072.cu",
+  "layer-norm/ln_fwd_4096.cu",
+  "layer-norm/ln_fwd_512.cu",
+  "layer-norm/ln_fwd_5120.cu",
+  "layer-norm/ln_fwd_6144.cu",
+  "layer-norm/ln_fwd_7168.cu",
+  "layer-norm/ln_fwd_768.cu",
+  "layer-norm/ln_fwd_8192.cu",
+  "layer-norm/ln_fwd_kernels.cuh",
+  "layer-norm/ln_kernel_traits.h",
+  "layer-norm/ln_parallel_bwd_1024.cu",
+  "layer-norm/ln_parallel_bwd_1280.cu",
+  "layer-norm/ln_parallel_bwd_1536.cu",
+  "layer-norm/ln_parallel_bwd_2048.cu",
+  "layer-norm/ln_parallel_bwd_256.cu",
+  "layer-norm/ln_parallel_bwd_2560.cu",
+  "layer-norm/ln_parallel_bwd_3072.cu",
+  "layer-norm/ln_parallel_bwd_4096.cu",
+  "layer-norm/ln_parallel_bwd_512.cu",
+  "layer-norm/ln_parallel_bwd_5120.cu",
+  "layer-norm/ln_parallel_bwd_6144.cu",
+  "layer-norm/ln_parallel_bwd_7168.cu",
+  "layer-norm/ln_parallel_bwd_768.cu",
+  "layer-norm/ln_parallel_bwd_8192.cu",
+  "layer-norm/ln_parallel_fwd_1024.cu",
+  "layer-norm/ln_parallel_fwd_1280.cu",
+  "layer-norm/ln_parallel_fwd_1536.cu",
+  "layer-norm/ln_parallel_fwd_2048.cu",
+  "layer-norm/ln_parallel_fwd_256.cu",
+  "layer-norm/ln_parallel_fwd_2560.cu",
+  "layer-norm/ln_parallel_fwd_3072.cu",
+  "layer-norm/ln_parallel_fwd_4096.cu",
+  "layer-norm/ln_parallel_fwd_512.cu",
+  "layer-norm/ln_parallel_fwd_5120.cu",
+  "layer-norm/ln_parallel_fwd_6144.cu",
+  "layer-norm/ln_parallel_fwd_7168.cu",
+  "layer-norm/ln_parallel_fwd_768.cu",
+  "layer-norm/ln_parallel_fwd_8192.cu",
+  "layer-norm/ln_parallel_residual_bwd_kernels.cuh",
+  "layer-norm/ln_parallel_residual_fwd_kernels.cuh",
+  "layer-norm/ln_utils.cuh",
+  "layer-norm/static_switch.h"
+]
+cuda-flags = [
+  "-O3",
+  "-U__CUDA_NO_HALF_OPERATORS__",
+  "-U__CUDA_NO_HALF_CONVERSIONS__",
+  "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+  "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+  "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+  "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+  "--expt-relaxed-constexpr",
+  "--expt-extended-lambda",
+  "--use_fast_math",
+]
+
diff --git a/build/2.5.1+cu124/dropout_layer_norm.cpython-310-x86_64-linux-gnu.so b/build/2.5.1+cu124/dropout_layer_norm.cpython-310-x86_64-linux-gnu.so
deleted file mode 100755
index 63357fa5fbd0e7831cfc25afc8d11f34c1727e7c..0000000000000000000000000000000000000000
--- a/build/2.5.1+cu124/dropout_layer_norm.cpython-310-x86_64-linux-gnu.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5a2a2612ebb3261909ec7ca8b85213e19f5ea128aeee93eca5640222d299969c
-size 734115376
diff --git a/flake.nix b/flake.nix
new file mode 100644
index 0000000000000000000000000000000000000000..ee6d68ec2b4e0ae4c7a9139538f9706cb31b1d2e
--- /dev/null
+++ b/flake.nix
@@ -0,0 +1,13 @@
+{
+  description = "Flake for Torch kernel extension";
+
+  inputs = {
+    kernel-builder.url = "github:huggingface/kernel-builder";
+  };
+
+  outputs = { self, kernel-builder, }:
+    kernel-builder.lib.genFlakeOutputs {
+      path = ./.;
+      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+    };
+}
diff --git a/ln.h b/layer-norm/ln.h
similarity index 100%
rename from ln.h
rename to layer-norm/ln.h
diff --git a/ln_api.cpp b/layer-norm/ln_api.cpp
similarity index 100%
rename from ln_api.cpp
rename to layer-norm/ln_api.cpp
diff --git a/ln_bwd_1024.cu b/layer-norm/ln_bwd_1024.cu
similarity index 100%
rename from ln_bwd_1024.cu
rename to layer-norm/ln_bwd_1024.cu
diff --git a/ln_bwd_1280.cu b/layer-norm/ln_bwd_1280.cu
similarity index 100%
rename from ln_bwd_1280.cu
rename to layer-norm/ln_bwd_1280.cu
diff --git a/ln_bwd_1536.cu b/layer-norm/ln_bwd_1536.cu
similarity index 100%
rename from ln_bwd_1536.cu
rename to layer-norm/ln_bwd_1536.cu
diff --git a/ln_bwd_2048.cu b/layer-norm/ln_bwd_2048.cu
similarity index 100%
rename from ln_bwd_2048.cu
rename to layer-norm/ln_bwd_2048.cu
diff --git a/ln_bwd_256.cu b/layer-norm/ln_bwd_256.cu
similarity index 100%
rename from ln_bwd_256.cu
rename to layer-norm/ln_bwd_256.cu
diff --git a/ln_bwd_2560.cu b/layer-norm/ln_bwd_2560.cu
similarity index 100%
rename from ln_bwd_2560.cu
rename to layer-norm/ln_bwd_2560.cu
diff --git a/ln_bwd_3072.cu b/layer-norm/ln_bwd_3072.cu
similarity index 100%
rename from ln_bwd_3072.cu
rename to layer-norm/ln_bwd_3072.cu
diff --git a/ln_bwd_4096.cu b/layer-norm/ln_bwd_4096.cu
similarity index 100%
rename from ln_bwd_4096.cu
rename to layer-norm/ln_bwd_4096.cu
diff --git a/ln_bwd_512.cu b/layer-norm/ln_bwd_512.cu
similarity index 100%
rename from ln_bwd_512.cu
rename to layer-norm/ln_bwd_512.cu
diff --git a/ln_bwd_5120.cu b/layer-norm/ln_bwd_5120.cu
similarity index 100%
rename from ln_bwd_5120.cu
rename to layer-norm/ln_bwd_5120.cu
diff --git a/ln_bwd_6144.cu b/layer-norm/ln_bwd_6144.cu
similarity index 100%
rename from ln_bwd_6144.cu
rename to layer-norm/ln_bwd_6144.cu
diff --git a/ln_bwd_7168.cu b/layer-norm/ln_bwd_7168.cu
similarity index 100%
rename from ln_bwd_7168.cu
rename to layer-norm/ln_bwd_7168.cu
diff --git a/ln_bwd_768.cu b/layer-norm/ln_bwd_768.cu
similarity index 100%
rename from ln_bwd_768.cu
rename to layer-norm/ln_bwd_768.cu
diff --git a/ln_bwd_8192.cu b/layer-norm/ln_bwd_8192.cu
similarity index 100%
rename from ln_bwd_8192.cu
rename to layer-norm/ln_bwd_8192.cu
diff --git a/ln_bwd_kernels.cuh b/layer-norm/ln_bwd_kernels.cuh
similarity index 100%
rename from ln_bwd_kernels.cuh
rename to layer-norm/ln_bwd_kernels.cuh
diff --git a/ln_fwd_1024.cu b/layer-norm/ln_fwd_1024.cu
similarity index 100%
rename from ln_fwd_1024.cu
rename to layer-norm/ln_fwd_1024.cu
diff --git a/ln_fwd_1280.cu b/layer-norm/ln_fwd_1280.cu
similarity index 100%
rename from ln_fwd_1280.cu
rename to layer-norm/ln_fwd_1280.cu
diff --git a/ln_fwd_1536.cu b/layer-norm/ln_fwd_1536.cu
similarity index 100%
rename from ln_fwd_1536.cu
rename to layer-norm/ln_fwd_1536.cu
diff --git a/ln_fwd_2048.cu b/layer-norm/ln_fwd_2048.cu
similarity index 100%
rename from ln_fwd_2048.cu
rename to layer-norm/ln_fwd_2048.cu
diff --git a/ln_fwd_256.cu b/layer-norm/ln_fwd_256.cu
similarity index 100%
rename from ln_fwd_256.cu
rename to layer-norm/ln_fwd_256.cu
diff --git a/ln_fwd_2560.cu b/layer-norm/ln_fwd_2560.cu
similarity index 100%
rename from ln_fwd_2560.cu
rename to layer-norm/ln_fwd_2560.cu
diff --git a/ln_fwd_3072.cu b/layer-norm/ln_fwd_3072.cu
similarity index 100%
rename from ln_fwd_3072.cu
rename to layer-norm/ln_fwd_3072.cu
diff --git a/ln_fwd_4096.cu b/layer-norm/ln_fwd_4096.cu
similarity index 100%
rename from ln_fwd_4096.cu
rename to layer-norm/ln_fwd_4096.cu
diff --git a/ln_fwd_512.cu b/layer-norm/ln_fwd_512.cu
similarity index 100%
rename from ln_fwd_512.cu
rename to layer-norm/ln_fwd_512.cu
diff --git a/ln_fwd_5120.cu b/layer-norm/ln_fwd_5120.cu
similarity index 100%
rename from ln_fwd_5120.cu
rename to layer-norm/ln_fwd_5120.cu
diff --git a/ln_fwd_6144.cu b/layer-norm/ln_fwd_6144.cu
similarity index 100%
rename from ln_fwd_6144.cu
rename to layer-norm/ln_fwd_6144.cu
diff --git a/ln_fwd_7168.cu b/layer-norm/ln_fwd_7168.cu
similarity index 100%
rename from ln_fwd_7168.cu
rename to layer-norm/ln_fwd_7168.cu
diff --git a/ln_fwd_768.cu b/layer-norm/ln_fwd_768.cu
similarity index 100%
rename from ln_fwd_768.cu
rename to layer-norm/ln_fwd_768.cu
diff --git a/ln_fwd_8192.cu b/layer-norm/ln_fwd_8192.cu
similarity index 100%
rename from ln_fwd_8192.cu
rename to layer-norm/ln_fwd_8192.cu
diff --git a/ln_fwd_kernels.cuh b/layer-norm/ln_fwd_kernels.cuh
similarity index 100%
rename from ln_fwd_kernels.cuh
rename to layer-norm/ln_fwd_kernels.cuh
diff --git a/ln_kernel_traits.h b/layer-norm/ln_kernel_traits.h
similarity index 100%
rename from ln_kernel_traits.h
rename to layer-norm/ln_kernel_traits.h
diff --git a/ln_parallel_bwd_1024.cu b/layer-norm/ln_parallel_bwd_1024.cu
similarity index 100%
rename from ln_parallel_bwd_1024.cu
rename to layer-norm/ln_parallel_bwd_1024.cu
diff --git a/ln_parallel_bwd_1280.cu b/layer-norm/ln_parallel_bwd_1280.cu
similarity index 100%
rename from ln_parallel_bwd_1280.cu
rename to layer-norm/ln_parallel_bwd_1280.cu
diff --git a/ln_parallel_bwd_1536.cu b/layer-norm/ln_parallel_bwd_1536.cu
similarity index 100%
rename from ln_parallel_bwd_1536.cu
rename to layer-norm/ln_parallel_bwd_1536.cu
diff --git a/ln_parallel_bwd_2048.cu b/layer-norm/ln_parallel_bwd_2048.cu
similarity index 100%
rename from ln_parallel_bwd_2048.cu
rename to layer-norm/ln_parallel_bwd_2048.cu
diff --git a/ln_parallel_bwd_256.cu b/layer-norm/ln_parallel_bwd_256.cu
similarity index 100%
rename from ln_parallel_bwd_256.cu
rename to layer-norm/ln_parallel_bwd_256.cu
diff --git a/ln_parallel_bwd_2560.cu b/layer-norm/ln_parallel_bwd_2560.cu
similarity index 100%
rename from ln_parallel_bwd_2560.cu
rename to layer-norm/ln_parallel_bwd_2560.cu
diff --git a/ln_parallel_bwd_3072.cu b/layer-norm/ln_parallel_bwd_3072.cu
similarity index 100%
rename from ln_parallel_bwd_3072.cu
rename to layer-norm/ln_parallel_bwd_3072.cu
diff --git a/ln_parallel_bwd_4096.cu b/layer-norm/ln_parallel_bwd_4096.cu
similarity index 100%
rename from ln_parallel_bwd_4096.cu
rename to layer-norm/ln_parallel_bwd_4096.cu
diff --git a/ln_parallel_bwd_512.cu b/layer-norm/ln_parallel_bwd_512.cu
similarity index 100%
rename from ln_parallel_bwd_512.cu
rename to layer-norm/ln_parallel_bwd_512.cu
diff --git a/ln_parallel_bwd_5120.cu b/layer-norm/ln_parallel_bwd_5120.cu
similarity index 100%
rename from ln_parallel_bwd_5120.cu
rename to layer-norm/ln_parallel_bwd_5120.cu
diff --git a/ln_parallel_bwd_6144.cu b/layer-norm/ln_parallel_bwd_6144.cu
similarity index 100%
rename from ln_parallel_bwd_6144.cu
rename to layer-norm/ln_parallel_bwd_6144.cu
diff --git a/ln_parallel_bwd_7168.cu b/layer-norm/ln_parallel_bwd_7168.cu
similarity index 100%
rename from ln_parallel_bwd_7168.cu
rename to layer-norm/ln_parallel_bwd_7168.cu
diff --git a/ln_parallel_bwd_768.cu b/layer-norm/ln_parallel_bwd_768.cu
similarity index 100%
rename from ln_parallel_bwd_768.cu
rename to layer-norm/ln_parallel_bwd_768.cu
diff --git a/ln_parallel_bwd_8192.cu b/layer-norm/ln_parallel_bwd_8192.cu
similarity index 100%
rename from ln_parallel_bwd_8192.cu
rename to layer-norm/ln_parallel_bwd_8192.cu
diff --git a/ln_parallel_fwd_1024.cu b/layer-norm/ln_parallel_fwd_1024.cu
similarity index 100%
rename from ln_parallel_fwd_1024.cu
rename to layer-norm/ln_parallel_fwd_1024.cu
diff --git a/ln_parallel_fwd_1280.cu b/layer-norm/ln_parallel_fwd_1280.cu
similarity index 100%
rename from ln_parallel_fwd_1280.cu
rename to layer-norm/ln_parallel_fwd_1280.cu
diff --git a/ln_parallel_fwd_1536.cu b/layer-norm/ln_parallel_fwd_1536.cu
similarity index 100%
rename from ln_parallel_fwd_1536.cu
rename to layer-norm/ln_parallel_fwd_1536.cu
diff --git a/ln_parallel_fwd_2048.cu b/layer-norm/ln_parallel_fwd_2048.cu
similarity index 100%
rename from ln_parallel_fwd_2048.cu
rename to layer-norm/ln_parallel_fwd_2048.cu
diff --git a/ln_parallel_fwd_256.cu b/layer-norm/ln_parallel_fwd_256.cu
similarity index 100%
rename from ln_parallel_fwd_256.cu
rename to layer-norm/ln_parallel_fwd_256.cu
diff --git a/ln_parallel_fwd_2560.cu b/layer-norm/ln_parallel_fwd_2560.cu
similarity index 100%
rename from ln_parallel_fwd_2560.cu
rename to layer-norm/ln_parallel_fwd_2560.cu
diff --git a/ln_parallel_fwd_3072.cu b/layer-norm/ln_parallel_fwd_3072.cu
similarity index 100%
rename from ln_parallel_fwd_3072.cu
rename to layer-norm/ln_parallel_fwd_3072.cu
diff --git a/ln_parallel_fwd_4096.cu b/layer-norm/ln_parallel_fwd_4096.cu
similarity index 100%
rename from ln_parallel_fwd_4096.cu
rename to layer-norm/ln_parallel_fwd_4096.cu
diff --git a/ln_parallel_fwd_512.cu b/layer-norm/ln_parallel_fwd_512.cu
similarity index 100%
rename from ln_parallel_fwd_512.cu
rename to layer-norm/ln_parallel_fwd_512.cu
diff --git a/ln_parallel_fwd_5120.cu b/layer-norm/ln_parallel_fwd_5120.cu
similarity index 100%
rename from ln_parallel_fwd_5120.cu
rename to layer-norm/ln_parallel_fwd_5120.cu
diff --git a/ln_parallel_fwd_6144.cu b/layer-norm/ln_parallel_fwd_6144.cu
similarity index 100%
rename from ln_parallel_fwd_6144.cu
rename to layer-norm/ln_parallel_fwd_6144.cu
diff --git a/ln_parallel_fwd_7168.cu b/layer-norm/ln_parallel_fwd_7168.cu
similarity index 100%
rename from ln_parallel_fwd_7168.cu
rename to layer-norm/ln_parallel_fwd_7168.cu
diff --git a/ln_parallel_fwd_768.cu b/layer-norm/ln_parallel_fwd_768.cu
similarity index 100%
rename from ln_parallel_fwd_768.cu
rename to layer-norm/ln_parallel_fwd_768.cu
diff --git a/ln_parallel_fwd_8192.cu b/layer-norm/ln_parallel_fwd_8192.cu
similarity index 100%
rename from ln_parallel_fwd_8192.cu
rename to layer-norm/ln_parallel_fwd_8192.cu
diff --git a/ln_parallel_residual_bwd_kernels.cuh b/layer-norm/ln_parallel_residual_bwd_kernels.cuh
similarity index 100%
rename from ln_parallel_residual_bwd_kernels.cuh
rename to layer-norm/ln_parallel_residual_bwd_kernels.cuh
diff --git a/ln_parallel_residual_fwd_kernels.cuh b/layer-norm/ln_parallel_residual_fwd_kernels.cuh
similarity index 100%
rename from ln_parallel_residual_fwd_kernels.cuh
rename to layer-norm/ln_parallel_residual_fwd_kernels.cuh
diff --git a/ln_utils.cuh b/layer-norm/ln_utils.cuh
similarity index 100%
rename from ln_utils.cuh
rename to layer-norm/ln_utils.cuh
diff --git a/static_switch.h b/layer-norm/static_switch.h
similarity index 100%
rename from static_switch.h
rename to layer-norm/static_switch.h
diff --git a/torch-ext/layer-norm/__init__.py b/torch-ext/layer-norm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..946160f16b9dc91fefeea037fb7ac84fd6afd802
--- /dev/null
+++ b/torch-ext/layer-norm/__init__.py
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+
+from ._ops import ops
+
+from . import layers
+
+def dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm):
+    return ops.dropout_add_ln_fwd(input, gamma, beta, rowscale, colscale, x0_subset, z_subset, dropout_p, epsilon, rowscale_const, z_numrows, gen, residual_in_fp32, is_rms_norm)
+
+def dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm):
+    return ops.dropout_add_ln_bwd(dz, dx, x, mu, rsigma, gamma, rowscale, colscale, x0_subset, z_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm)
+
+def dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm):
+    return ops.dropout_add_ln_parallel_residual_fwd(input, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, gen, residual_in_fp32, is_rms_norm)
+
+def dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm):
+    return ops.dropout_add_ln_parallel_residual_bwd(dz0, dz1, dx, x, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm)
+
+__all__ = [
+    "layers",
+    "dropout_add_ln_fwd",
+    "dropout_add_ln_bwd",
+    "dropout_add_ln_parallel_residual_fwd",
+    "dropout_add_ln_parallel_residual_bwd",
+]
\ No newline at end of file
diff --git a/torch-ext/layer-norm/layers.py b/torch-ext/layer-norm/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f313316614a9b35a8337b40345e2f6ecef7ae20
--- /dev/null
+++ b/torch-ext/layer-norm/layers.py
@@ -0,0 +1,49 @@
+import torch
+import torch.nn as nn
+
+from ._ops import ops
+
+
+class LayerNorm(nn.Module):
+    weight: torch.Tensor
+    variance_epsilon: float
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return ops.dropout_add_ln_fwd(
+            hidden_states,
+            gamma=self.weight,
+            beta=None,
+            rowscale=None,
+            colscale=None,
+            x0_subset=None,
+            z_subset=None,
+            dropout_p=0.0,
+            epsilon=self.variance_epsilon,
+            rowscale_const=1.0,
+            z_numrows=hidden_states.shape[1],
+            gen=None,
+            residual_in_fp32=False,
+            is_rms_norm=False,
+        )
+
+class LlamaRMSNorm(nn.Module):
+    weight: torch.Tensor
+    variance_epsilon: float
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return ops.dropout_add_ln_fwd(
+            hidden_states,
+            gamma=self.weight,
+            beta=None,
+            rowscale=None,
+            colscale=None,
+            x0_subset=None,
+            z_subset=None,
+            dropout_p=0.0,
+            epsilon=self.variance_epsilon,
+            rowscale_const=1.0,
+            z_numrows=hidden_states.shape[1],
+            gen=None,
+            residual_in_fp32=False,
+            is_rms_norm=True,
+        )
\ No newline at end of file
diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ba23184ed6f2f7358cb7fb6b80b8d54d43cbcf94
--- /dev/null
+++ b/torch-ext/torch_binding.cpp
@@ -0,0 +1,17 @@
+#include <torch/library.h>
+
+#include "registration.h"
+#include "torch_binding.h"
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  ops.def("dropout_add_ln_fwd(Tensor input, Tensor gamma, Tensor beta, Tensor rowscale, Tensor colscale, Tensor x0_subset, Tensor z_subset, float dropout_p, float epsilon, float rowscale_const, int z_numrows, Generator gen, bool residual_in_fp32, bool is_rms_norm) -> Tensor");
+  ops.impl("dropout_add_ln_fwd", torch::kCUDA, &dropout_add_ln_fwd);
+  ops.def("dropout_add_ln_bwd(Tensor dz, Tensor dx, Tensor x, Tensor mu, Tensor rsigma, Tensor gamma, Tensor rowscale, Tensor colscale, Tensor x0_subset, Tensor z_subset, float dropout_p, float rowscale_const, int x0_numrows, bool has_residual, bool is_rms_norm) -> Tensor");
+  ops.impl("dropout_add_ln_bwd", torch::kCUDA, &dropout_add_ln_bwd);
+  ops.def("dropout_add_ln_parallel_residual_fwd(Tensor input, Tensor gamma0, Tensor beta0, Tensor gamma1, Tensor beta1, float dropout_p, float epsilon, Generator gen, bool residual_in_fp32, bool is_rms_norm) -> Tensor");
+  ops.impl("dropout_add_ln_parallel_residual_fwd", torch::kCUDA, &dropout_add_ln_parallel_residual_fwd);
+  ops.def("dropout_add_ln_parallel_residual_bwd(Tensor dz0, Tensor dz1, Tensor dx, Tensor x, Tensor mu, Tensor rsigma, Tensor gamma0, Tensor gamma1, float dropout_p, bool has_x1, bool has_residual, bool is_rms_norm) -> Tensor");
+  ops.impl("dropout_add_ln_parallel_residual_bwd", torch::kCUDA, &dropout_add_ln_parallel_residual_bwd);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
diff --git a/torch-ext/torch_binding.h b/torch-ext/torch_binding.h
new file mode 100644
index 0000000000000000000000000000000000000000..9d3a385ded4e53c1414532154fe00b3f2e35defb
--- /dev/null
+++ b/torch-ext/torch_binding.h
@@ -0,0 +1,8 @@
+#pragma once
+
+#include <torch/torch.h>
+
+torch::Tensor dropout_add_ln_fwd(torch::Tensor &input, torch::Tensor &gamma, torch::Tensor &beta, torch::Tensor &rowscale, torch::Tensor &colscale, torch::Tensor &x0_subset, torch::Tensor &z_subset, float dropout_p, float epsilon, float rowscale_const, int64_t z_numrows, torch::Generator &gen, bool residual_in_fp32, bool is_rms_norm);
+torch::Tensor dropout_add_ln_bwd(torch::Tensor &dz, torch::Tensor &dx, torch::Tensor &x, torch::Tensor &mu, torch::Tensor &rsigma, torch::Tensor &gamma, torch::Tensor &rowscale, torch::Tensor &colscale, torch::Tensor &x0_subset, torch::Tensor &z_subset, float dropout_p, float rowscale_const, int64_t x0_numrows, bool has_residual, bool is_rms_norm);
+torch::Tensor dropout_add_ln_parallel_residual_fwd(torch::Tensor &input, torch::Tensor &gamma0, torch::Tensor &beta0, torch::Tensor &gamma1, torch::Tensor &beta1, float dropout_p, float epsilon, torch::Generator &gen, bool residual_in_fp32, bool is_rms_norm);
+torch::Tensor dropout_add_ln_parallel_residual_bwd(torch::Tensor &dz0, torch::Tensor &dz1, torch::Tensor &dx, torch::Tensor &x, torch::Tensor &mu, torch::Tensor &rsigma, torch::Tensor &gamma0, torch::Tensor &gamma1, float dropout_p, bool has_x1, bool has_residual, bool is_rms_norm);
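
For context, a minimal usage sketch of the Python surface this patch adds (not part of the diff itself). The import name `layer_norm`, the device, shapes, and dtype are all assumptions: the hyphenated `torch-ext/layer-norm` directory would have to be mapped to a valid, importable module name when the extension is built.

```python
# Illustrative sketch only: exercises the LlamaRMSNorm wrapper from
# torch-ext/layer-norm/layers.py after the kernel has been built.
# `layer_norm` is a hypothetical import name; a CUDA device is assumed.
import torch
import layer_norm

norm = layer_norm.layers.LlamaRMSNorm()
# The module declares `weight` and `variance_epsilon` as annotations only,
# so the caller supplies them (here: identity scale, typical epsilon).
norm.weight = torch.ones(4096, device="cuda", dtype=torch.float16)
norm.variance_epsilon = 1e-6

x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
out = norm(x)  # dispatches to ops.dropout_add_ln_fwd with is_rms_norm=True
print(out.shape)
```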