#pragma once

#include <tuple>

#include <torch/torch.h>
|
/// PolyNorm forward pass.
///
/// Writes the normalized result into `out` (out-parameter; presumably
/// pre-allocated by the caller — confirm against the implementation).
///
/// @param out    Destination tensor, written in place.
/// @param input  Input activations.
/// @param weight Learned weight tensor. // NOTE(review): assumed learned per-term scale — confirm
/// @param bias   Learned bias tensor.
/// @param eps    Small constant added for numerical stability.
void poly_norm(torch::Tensor &out, const torch::Tensor &input,
               const torch::Tensor &weight, const torch::Tensor &bias,
               double eps);

/// PolyNorm backward pass.
///
/// Computes gradients w.r.t. input, weight and bias into the three
/// out-parameters, given the upstream `output_grad` and the saved
/// forward `input`/`weight`.
///
/// @param input_grad  Gradient w.r.t. `input`, written in place.
/// @param weight_grad Gradient w.r.t. `weight`, written in place.
/// @param bias_grad   Gradient w.r.t. `bias`, written in place.
/// @param output_grad Upstream gradient of the forward output.
/// @param input       Input saved from the forward pass.
/// @param weight      Weight saved from the forward pass.
/// @param eps         Same epsilon as used in the forward pass.
void poly_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
                        torch::Tensor &bias_grad,
                        const torch::Tensor &output_grad,
                        const torch::Tensor &input, const torch::Tensor &weight,
                        double eps);
| |
|
/// RMSNorm forward pass.
///
/// @param input  Input activations.
/// @param weight Learned scale tensor. // NOTE(review): assumed elementwise gain — confirm
/// @param eps    Small constant added for numerical stability.
/// @return The normalized tensor.
[[nodiscard]] torch::Tensor rms_norm(const torch::Tensor &input,
                                     const torch::Tensor &weight, double eps);

/// RMSNorm backward pass.
///
/// @param output_grad Upstream gradient of the forward output.
/// @param input       Input saved from the forward pass.
/// @param weight      Weight saved from the forward pass.
/// @param eps         Same epsilon as used in the forward pass.
/// @return Tuple of (input_grad, weight_grad). // NOTE(review): order inferred from params — confirm
[[nodiscard]] std::tuple<torch::Tensor, torch::Tensor>
rms_norm_backward(const torch::Tensor &output_grad, const torch::Tensor &input,
                  const torch::Tensor &weight, double eps);
| |
|
/// Fused elementwise-multiply + PolyNorm forward pass.
///
/// Applies PolyNorm to `input` combined with `mul` in a single fused
/// kernel (exact fusion order not visible here — confirm against the
/// implementation), writing the result into `out`.
///
/// @param out    Destination tensor, written in place.
/// @param input  Input activations.
/// @param mul    Tensor multiplied elementwise with the input/output.
/// @param weight Learned weight tensor.
/// @param bias   Learned bias tensor.
/// @param eps    Small constant added for numerical stability.
void fused_mul_poly_norm(torch::Tensor &out, const torch::Tensor &input,
                         const torch::Tensor &mul, const torch::Tensor &weight,
                         const torch::Tensor &bias, double eps);

/// Fused elementwise-multiply + PolyNorm backward pass.
///
/// Computes gradients w.r.t. input, mul, weight and bias into the four
/// out-parameters.
///
/// @param input_grad  Gradient w.r.t. `input`, written in place.
/// @param mul_grad    Gradient w.r.t. `mul`, written in place.
/// @param weight_grad Gradient w.r.t. `weight`, written in place.
/// @param bias_grad   Gradient w.r.t. `bias`, written in place.
/// @param output_grad Upstream gradient of the forward output.
/// @param input       Input saved from the forward pass.
/// @param mul         Multiplier saved from the forward pass.
/// @param weight      Weight saved from the forward pass.
/// @param bias        Bias saved from the forward pass.
/// @param eps         Same epsilon as used in the forward pass.
void fused_mul_poly_norm_backward(
    torch::Tensor &input_grad, torch::Tensor &mul_grad,
    torch::Tensor &weight_grad, torch::Tensor &bias_grad,
    const torch::Tensor &output_grad, const torch::Tensor &input,
    const torch::Tensor &mul, const torch::Tensor &weight,
    const torch::Tensor &bias, double eps);
| |
|
/// Fused residual-add + RMSNorm forward pass.
///
/// @param input    Input activations.
/// @param residual Residual tensor added to the input before normalization
///                 (fusion order assumed from the name — confirm against impl).
/// @param weight   Learned scale tensor.
/// @param eps      Small constant added for numerical stability.
/// @return Tuple of (normalized_output, add_output). // NOTE(review): second element inferred from backward's `add_output_grad` — confirm
[[nodiscard]] std::tuple<torch::Tensor, torch::Tensor>
fused_add_rms_norm(const torch::Tensor &input, const torch::Tensor &residual,
                   const torch::Tensor &weight, double eps);

/// Fused residual-add + RMSNorm backward pass.
///
/// @param output_grad     Upstream gradient of the normalized output.
/// @param add_output_grad Upstream gradient of the residual-add output.
/// @param input           Input saved from the forward pass.
/// @param weight          Weight saved from the forward pass.
/// @param eps             Same epsilon as used in the forward pass.
/// @param need_input_grad Whether the input gradient must be computed
///                        (allows the kernel to skip work when false —
///                        exact skip semantics not visible here).
/// @return Tuple of (input_grad, weight_grad). // NOTE(review): order inferred from params — confirm
[[nodiscard]] std::tuple<torch::Tensor, torch::Tensor>
fused_add_rms_norm_backward(
    const torch::Tensor &output_grad, const torch::Tensor &add_output_grad,
    const torch::Tensor &input, const torch::Tensor &weight, double eps,
    bool need_input_grad);
| |
|