// layer_norm/torch-ext/torch_binding.cpp
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
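  // Fused dropout + residual-add + layer norm forward pass. rowscale/colscale
  // rescale input rows/columns, x0_subset/z_subset select row subsets, and
  // is_rms_norm switches between LayerNorm and RMSNorm.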
ops.def("dropout_add_ln_fwd(Tensor input, Tensor gamma, Tensor beta, Tensor rowscale, Tensor colscale, Tensor x0_subset, Tensor z_subset, float dropout_p, float epsilon, float rowscale_const, int64_t z_numrows, Generator gen, bool residual_in_fp32, bool is_rms_norm) -> Tensor");
ops.impl("dropout_add_ln_fwd", torch::kCUDA, &dropout_add_ln_fwd);
ops.def("dropout_add_ln_bwd(Tensor dz, Tensor dx, Tensor x, Tensor mu, Tensor rsigma, Tensor gamma, Tensor rowscale, Tensor colscale, Tensor x0_subset, Tensor z_subset, float dropout_p, float rowscale_const, int64_t x0_numrows, bool has_residual, bool is_rms_norm) -> Tensor");
ops.impl("dropout_add_ln_bwd", torch::kCUDA, &dropout_add_ln_bwd);
ops.def("dropout_add_ln_parallel_residual_fwd(Tensor input, Tensor gamma0, Tensor beta0, Tensor gamma1, Tensor beta1, float dropout_p, float epsilon, Generator gen, bool residual_in_fp32, bool is_rms_norm) -> Tensor");
ops.impl("dropout_add_ln_parallel_residual_fwd", torch::kCUDA, &dropout_add_ln_parallel_residual_fwd);
ops.def("dropout_add_ln_parallel_residual_bwd(Tensor dz0, Tensor dz1, Tensor dx, Tensor x, Tensor mu, Tensor rsigma, Tensor gamma0, Tensor gamma1, float dropout_p, bool has_x1, bool has_residual, bool is_rms_norm) -> Tensor");
ops.impl("dropout_add_ln_parallel_residual_bwd", torch::kCUDA, &dropout_add_ln_parallel_residual_bwd);
}
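
// Minimal usage sketch (assumption: the exact op namespace comes from
// TORCH_EXTENSION_NAME, which is set by the build, shown here as <ext_name>):
// once the extension is loaded in Python, the forward op is callable as
//
//   out = torch.ops.<ext_name>.dropout_add_ln_fwd(
//       x0, gamma, beta, rowscale, colscale, x0_subset, z_subset,
//       dropout_p, epsilon, rowscale_const, z_numrows, gen,
//       residual_in_fp32, is_rms_norm)
//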
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)