#include #include "registration.h" #include "torch_binding.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("dropout_add_ln_fwd(Tensor input, Tensor gamma, Tensor beta, Tensor rowscale, Tensor colscale, Tensor x0_subset, Tensor z_subset, float dropout_p, float epsilon, float rowscale_const, int64_t z_numrows, Generator gen, bool residual_in_fp32, bool is_rms_norm) -> Tensor"); ops.impl("dropout_add_ln_fwd", torch::kCUDA, &dropout_add_ln_fwd); ops.def("dropout_add_ln_bwd(Tensor dz, Tensor dx, Tensor x, Tensor mu, Tensor rsigma, Tensor gamma, Tensor rowscale, Tensor colscale, Tensor x0_subset, Tensor z_subset, float dropout_p, float rowscale_const, int64_t x0_numrows, bool has_residual, bool is_rms_norm) -> Tensor"); ops.impl("dropout_add_ln_bwd", torch::kCUDA, &dropout_add_ln_bwd); ops.def("dropout_add_ln_parallel_residual_fwd(Tensor input, Tensor gamma0, Tensor beta0, Tensor gamma1, Tensor beta1, float dropout_p, float epsilon, Generator gen, bool residual_in_fp32, bool is_rms_norm) -> Tensor"); ops.impl("dropout_add_ln_parallel_residual_fwd", torch::kCUDA, &dropout_add_ln_parallel_residual_fwd); ops.def("dropout_add_ln_parallel_residual_bwd(Tensor dz0, Tensor dz1, Tensor dx, Tensor x, Tensor mu, Tensor rsigma, Tensor gamma0, Tensor gamma1, float dropout_p, bool has_x1, bool has_residual, bool is_rms_norm) -> Tensor"); ops.impl("dropout_add_ln_parallel_residual_bwd", torch::kCUDA, &dropout_add_ln_parallel_residual_bwd); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME)