#include "torch_binding.h" #include #include "registration.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // poly_norm ops.def("poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, " "float eps) -> ()"); ops.impl("poly_norm", torch::kCUDA, &poly_norm); ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! " "bias_grad, Tensor output_grad, Tensor input, Tensor weight, float " "eps) -> ()"); ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward); // rms_norm ops.def("rms_norm(Tensor input, Tensor weight, float eps) -> Tensor"); ops.impl("rms_norm", torch::kCUDA, &rms_norm); ops.def("rms_norm_backward(Tensor output_grad, Tensor input, Tensor weight, " "float eps) -> (Tensor, Tensor)"); ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward); // fused_mul_poly_norm ops.def("fused_mul_poly_norm(Tensor! out, Tensor input, Tensor mul, Tensor " "weight, Tensor bias, " "float eps) -> ()"); ops.impl("fused_mul_poly_norm", torch::kCUDA, &fused_mul_poly_norm); ops.def("fused_mul_poly_norm_backward(Tensor! input_grad, Tensor! mul_grad, " "Tensor! weight_grad, Tensor! " "bias_grad, Tensor output_grad, Tensor input, Tensor mul, Tensor " "weight, Tensor " "bias, float eps) -> ()"); ops.impl("fused_mul_poly_norm_backward", torch::kCUDA, &fused_mul_poly_norm_backward); // fused_add_rms_norm ops.def("fused_add_rms_norm(Tensor input, Tensor residual, Tensor " "weight, float eps) -> (Tensor, Tensor)"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); ops.def( "fused_add_rms_norm_backward(Tensor output_grad, Tensor add_output_grad," "Tensor input, Tensor weight, float eps, bool need_input_grad) -> " "(Tensor, Tensor)"); ops.impl("fused_add_rms_norm_backward", torch::kCUDA, &fused_add_rms_norm_backward); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME)