// Kernels — activation / torch-ext / torch_binding.h
// Author: wyldecat
// Last change: refactor(activation): change fused_add_rms_norm and
// fused_add_rms_norm_backward to out-of-place operations (commit 7e4334d)
#pragma once
#include <torch/torch.h>
/// Applies PolyNorm to `input`, writing the result into `out` in place.
/// NOTE(review): normalization semantics live in the kernel implementation;
/// this header only fixes the calling contract.
/// @param out    Output tensor, written in place.
/// @param input  Input tensor.
/// @param weight Learnable scale parameter(s). Renamed from `weights` so the
///               forward and backward declarations agree on the name.
/// @param bias   Learnable bias parameter.
/// @param eps    Epsilon for numerical stability (presumably added inside a
///               rsqrt/denominator — confirm in the kernel).
void poly_norm(torch::Tensor &out, const torch::Tensor &input,
               const torch::Tensor &weight, const torch::Tensor &bias,
               double eps);
/// Backward pass for poly_norm. Writes the gradients w.r.t. input, weight,
/// and bias into the first three (output) arguments.
/// @param input_grad  Out: gradient w.r.t. `input`.
/// @param weight_grad Out: gradient w.r.t. `weight`.
/// @param bias_grad   Out: gradient w.r.t. `bias`.
/// @param output_grad Incoming gradient of the forward output.
/// @param input       Forward-pass input (recomputation instead of saved
///                    activations — see kernel for details).
/// @param weight      Forward-pass weight.
/// @param eps         Must match the epsilon used in the forward pass.
void poly_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
                        torch::Tensor &bias_grad,
                        const torch::Tensor &output_grad,
                        const torch::Tensor &input, const torch::Tensor &weight,
                        double eps);
/// Applies RMSNorm to `input` and returns the normalized tensor
/// (out-of-place, unlike poly_norm which writes into an out-parameter).
/// @param input  Input tensor.
/// @param weight Learnable scale parameter. Renamed from `weights` so the
///               forward and backward declarations agree on the name.
/// @param eps    Epsilon for numerical stability.
/// @return Normalized tensor.
torch::Tensor rms_norm(const torch::Tensor &input, const torch::Tensor &weight,
                       double eps);
/// Backward pass for rms_norm.
/// @param output_grad Incoming gradient of the forward output.
/// @param input       Forward-pass input.
/// @param weight      Forward-pass weight.
/// @param eps         Must match the epsilon used in the forward pass.
/// @return Tuple of gradients — presumably {input_grad, weight_grad} in that
///         order; confirm against the kernel implementation.
std::tuple<torch::Tensor, torch::Tensor>
rms_norm_backward(const torch::Tensor &output_grad, const torch::Tensor &input,
                  const torch::Tensor &weight, double eps);
/// Fused elementwise-multiply + PolyNorm, writing the result into `out`
/// in place.
/// @param out    Output tensor, written in place.
/// @param input  Input tensor.
/// @param mul    Tensor multiplied with the input (fusion operand).
/// @param weight Learnable scale parameter(s). Renamed from `weights` so the
///               forward and backward declarations agree on the name.
/// @param bias   Learnable bias parameter.
/// @param eps    Epsilon for numerical stability.
void fused_mul_poly_norm(torch::Tensor &out, const torch::Tensor &input,
                         const torch::Tensor &mul, const torch::Tensor &weight,
                         const torch::Tensor &bias, double eps);
/// Backward pass for fused_mul_poly_norm. Writes the gradients w.r.t. input,
/// mul, weight, and bias into the first four (output) arguments.
/// @param input_grad  Out: gradient w.r.t. `input`.
/// @param mul_grad    Out: gradient w.r.t. `mul`.
/// @param weight_grad Out: gradient w.r.t. `weight`.
/// @param bias_grad   Out: gradient w.r.t. `bias`.
/// @param output_grad Incoming gradient of the forward output.
/// @param input       Forward-pass input.
/// @param mul         Forward-pass multiply operand.
/// @param weight      Forward-pass weight.
/// @param bias        Forward-pass bias (needed here, unlike poly_norm_backward
///                    — see kernel for why).
/// @param eps         Must match the epsilon used in the forward pass.
void fused_mul_poly_norm_backward(
    torch::Tensor &input_grad, torch::Tensor &mul_grad,
    torch::Tensor &weight_grad, torch::Tensor &bias_grad,
    const torch::Tensor &output_grad, const torch::Tensor &input,
    const torch::Tensor &mul, const torch::Tensor &weight,
    const torch::Tensor &bias, double eps);
/// Fused residual-add + RMSNorm. Out-of-place: returns new tensors rather
/// than mutating `input`/`residual` (refactored from an in-place variant).
/// @param input    Input tensor.
/// @param residual Residual tensor added to the input before normalization.
/// @param weight   Learnable scale parameter.
/// @param eps      Epsilon for numerical stability.
/// @return Tuple of two tensors — presumably {normalized_output, add_output}
///         (the latter matching `add_output_grad` in the backward below);
///         confirm against the kernel implementation.
std::tuple<torch::Tensor, torch::Tensor>
fused_add_rms_norm(const torch::Tensor &input, const torch::Tensor &residual,
                   const torch::Tensor &weight, double eps);
/// Backward pass for fused_add_rms_norm (out-of-place).
/// @param output_grad     Gradient of the normalized output.
/// @param add_output_grad Gradient of the residual-add output.
/// @param input           Forward-pass input.
/// @param weight          Forward-pass weight.
/// @param eps             Must match the epsilon used in the forward pass.
/// @param need_input_grad When false the input gradient can be skipped —
///                        presumably an optimization when the caller does not
///                        require it; confirm what the returned tensor holds
///                        in that case.
/// @return Tuple of gradients — presumably {input_grad, weight_grad}; confirm
///         against the kernel implementation.
std::tuple<torch::Tensor, torch::Tensor> fused_add_rms_norm_backward(
    const torch::Tensor &output_grad, const torch::Tensor &add_output_grad,
    const torch::Tensor &input, const torch::Tensor &weight, double eps,
    bool need_input_grad);