|
|
#include <torch/library.h> |
|
|
|
|
|
#include "registration.h" |
|
|
#include "torch_binding.h" |
|
|
|
|
|
// Registers the fused dropout + residual-add + LayerNorm/RMSNorm kernels as
// custom PyTorch operators under this extension's namespace.
//
// Each op gets a schema via ops.def(...) and a CUDA implementation via
// ops.impl(...). Implementations are declared in torch_binding.h.
//
// NOTE: PyTorch's function-schema language only knows `int` (which maps to
// int64_t on the C++ side); spelling the type `int64_t` in the schema string
// makes the schema parser throw at library-load time. The schemas below use
// `int` for that reason.
//
// NOTE(review): several tensor arguments (rowscale, colscale, x0_subset,
// z_subset) and `gen` are typically optional (`Tensor?` / `Generator?`) in
// implementations of these kernels, and the forward/backward entry points
// often return multiple tensors (`Tensor[]`). Confirm against the signatures
// in torch_binding.h before changing the schemas further.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Forward: dropout(input) + residual add, then LayerNorm (or RMSNorm when
  // is_rms_norm is true). Row/column scaling and row subsetting are supported.
  ops.def("dropout_add_ln_fwd(Tensor input, Tensor gamma, Tensor beta, Tensor rowscale, Tensor colscale, Tensor x0_subset, Tensor z_subset, float dropout_p, float epsilon, float rowscale_const, int z_numrows, Generator gen, bool residual_in_fp32, bool is_rms_norm) -> Tensor");
  ops.impl("dropout_add_ln_fwd", torch::kCUDA, &dropout_add_ln_fwd);

  // Backward pass for dropout_add_ln_fwd; consumes the saved mean (mu) and
  // reciprocal standard deviation (rsigma) from the forward pass.
  ops.def("dropout_add_ln_bwd(Tensor dz, Tensor dx, Tensor x, Tensor mu, Tensor rsigma, Tensor gamma, Tensor rowscale, Tensor colscale, Tensor x0_subset, Tensor z_subset, float dropout_p, float rowscale_const, int x0_numrows, bool has_residual, bool is_rms_norm) -> Tensor");
  ops.impl("dropout_add_ln_bwd", torch::kCUDA, &dropout_add_ln_bwd);

  // Parallel-residual variant: two normalizations (gamma0/beta0, gamma1/beta1)
  // over a shared dropout + residual-add, as used by parallel-block
  // transformer architectures.
  ops.def("dropout_add_ln_parallel_residual_fwd(Tensor input, Tensor gamma0, Tensor beta0, Tensor gamma1, Tensor beta1, float dropout_p, float epsilon, Generator gen, bool residual_in_fp32, bool is_rms_norm) -> Tensor");
  ops.impl("dropout_add_ln_parallel_residual_fwd", torch::kCUDA, &dropout_add_ln_parallel_residual_fwd);

  // Backward pass for the parallel-residual forward.
  ops.def("dropout_add_ln_parallel_residual_bwd(Tensor dz0, Tensor dz1, Tensor dx, Tensor x, Tensor mu, Tensor rsigma, Tensor gamma0, Tensor gamma1, float dropout_p, bool has_x1, bool has_residual, bool is_rms_norm) -> Tensor");
  ops.impl("dropout_add_ln_parallel_residual_bwd", torch::kCUDA, &dropout_add_ln_parallel_residual_bwd);
}
|
|
|
|
|
// Defines the Python module init entry point (from registration.h) so this
// shared library is importable as a Python extension module named
// TORCH_EXTENSION_NAME.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
|