layer_norm / torch-ext /torch_binding.h
MekkCyber
add kernel structure
8b60ef7
raw
history blame
1.27 kB
#pragma once
#include <torch/torch.h>
/// Forward entry point. Judging by its name, this is a fused
/// dropout + residual-add + LayerNorm (RMSNorm when `is_rms_norm` is true);
/// the implementation is not in this header — TODO confirm semantics there.
///
/// @param input            Input activations (shape/dtype not stated here).
/// @param gamma            Presumably the normalization scale.
/// @param beta             Presumably the normalization shift.
/// @param rowscale         Presumably a per-row scale factor.
/// @param colscale         Presumably a per-column scale factor.
/// @param x0_subset        Presumably row indices selecting a subset of inputs.
/// @param z_subset         Presumably row indices selecting a subset of outputs.
/// @param dropout_p        Dropout probability.
/// @param epsilon          Numerical-stability epsilon for the normalization.
/// @param rowscale_const   Scalar row-scale constant (exact use not visible here).
/// @param z_numrows        Row count of the output (assumed; used with z_subset).
/// @param gen              RNG used for the dropout mask (assumed).
/// @param residual_in_fp32 Presumably keeps the residual stream in fp32.
/// @param is_rms_norm      Selects RMSNorm instead of LayerNorm (assumed).
/// @return A single tensor.
///
/// NOTE(review): dropout_add_ln_bwd takes `x`, `mu` and `rsigma`, yet this
/// function returns only one tensor — confirm how those values reach callers.
/// NOTE(review): if this is registered through the PyTorch dispatcher
/// (ops.impl), c10 rejects `float` kernel arguments and requires `double`;
/// verify against the registration site. All tensors are taken by non-const
/// reference, which suggests they may be mutated — confirm.
torch::Tensor dropout_add_ln_fwd(
    torch::Tensor &input,
    torch::Tensor &gamma,
    torch::Tensor &beta,
    torch::Tensor &rowscale,
    torch::Tensor &colscale,
    torch::Tensor &x0_subset,
    torch::Tensor &z_subset,
    float dropout_p,
    float epsilon,
    float rowscale_const,
    int64_t z_numrows,
    torch::Generator &gen,
    bool residual_in_fp32,
    bool is_rms_norm);
/// Backward entry point. Presumably the gradient of dropout_add_ln_fwd; the
/// implementation is not in this header — TODO confirm there.
///
/// @param dz             Presumably the gradient w.r.t. the normalized output.
/// @param dx             Presumably the gradient w.r.t. the residual output.
/// @param x              Presumably the pre-normalization activations saved
///                       from the forward pass.
/// @param mu             Presumably the saved per-row mean.
/// @param rsigma         Presumably the saved per-row reciprocal std-dev.
/// @param gamma          Presumably the normalization scale.
/// @param rowscale       Presumably a per-row scale factor.
/// @param colscale       Presumably a per-column scale factor.
/// @param x0_subset      Presumably row indices of the input subset.
/// @param z_subset       Presumably row indices of the output subset.
/// @param dropout_p      Dropout probability used in the forward pass.
/// @param rowscale_const Scalar row-scale constant (exact use not visible here).
/// @param x0_numrows     Row count of the original input (assumed).
/// @param has_residual   Whether a residual input was present (assumed).
/// @param is_rms_norm    Selects RMSNorm instead of LayerNorm (assumed).
/// @return A single tensor.
///
/// NOTE(review): a norm backward typically produces several gradients; only
/// one tensor is returned here — confirm how the rest are exposed.
/// NOTE(review): `float` arguments are not accepted by c10 kernel
/// registration (it requires `double`) — verify the registration path.
torch::Tensor dropout_add_ln_bwd(
    torch::Tensor &dz,
    torch::Tensor &dx,
    torch::Tensor &x,
    torch::Tensor &mu,
    torch::Tensor &rsigma,
    torch::Tensor &gamma,
    torch::Tensor &rowscale,
    torch::Tensor &colscale,
    torch::Tensor &x0_subset,
    torch::Tensor &z_subset,
    float dropout_p,
    float rowscale_const,
    int64_t x0_numrows,
    bool has_residual,
    bool is_rms_norm);
/// Forward entry point for the "parallel residual" variant. Judging by the
/// two (gamma, beta) pairs, it presumably applies two normalizations to the
/// same dropout+residual sum — TODO confirm against the implementation.
///
/// @param input            Input activations (shape/dtype not stated here).
/// @param gamma0           Presumably the scale of the first normalization.
/// @param beta0            Presumably the shift of the first normalization.
/// @param gamma1           Presumably the scale of the second normalization.
/// @param beta1            Presumably the shift of the second normalization.
/// @param dropout_p        Dropout probability.
/// @param epsilon          Numerical-stability epsilon for the normalization.
/// @param gen              RNG used for the dropout mask (assumed).
/// @param residual_in_fp32 Presumably keeps the residual stream in fp32.
/// @param is_rms_norm      Selects RMSNorm instead of LayerNorm (assumed).
/// @return A single tensor.
///
/// NOTE(review): the matching backward takes `x`, `mu`, `rsigma` and two
/// output gradients, but this returns one tensor — confirm.
/// NOTE(review): `float` arguments are not accepted by c10 kernel
/// registration (it requires `double`) — verify the registration path.
torch::Tensor dropout_add_ln_parallel_residual_fwd(
    torch::Tensor &input,
    torch::Tensor &gamma0,
    torch::Tensor &beta0,
    torch::Tensor &gamma1,
    torch::Tensor &beta1,
    float dropout_p,
    float epsilon,
    torch::Generator &gen,
    bool residual_in_fp32,
    bool is_rms_norm);
/// Backward entry point for the "parallel residual" variant. Presumably the
/// gradient of dropout_add_ln_parallel_residual_fwd — TODO confirm.
///
/// @param dz0          Presumably the gradient w.r.t. the first normalized output.
/// @param dz1          Presumably the gradient w.r.t. the second normalized output.
/// @param dx           Presumably the gradient w.r.t. the residual output.
/// @param x            Presumably the saved pre-normalization activations.
/// @param mu           Presumably the saved per-row mean.
/// @param rsigma       Presumably the saved per-row reciprocal std-dev.
/// @param gamma0       Presumably the scale of the first normalization.
/// @param gamma1       Presumably the scale of the second normalization.
/// @param dropout_p    Dropout probability used in the forward pass.
/// @param has_x1       Whether a second input branch was present (assumed).
/// @param has_residual Whether a residual input was present (assumed).
/// @param is_rms_norm  Selects RMSNorm instead of LayerNorm (assumed).
/// @return A single tensor.
///
/// NOTE(review): `float` arguments are not accepted by c10 kernel
/// registration (it requires `double`) — verify the registration path.
torch::Tensor dropout_add_ln_parallel_residual_bwd(
    torch::Tensor &dz0,
    torch::Tensor &dz1,
    torch::Tensor &dx,
    torch::Tensor &x,
    torch::Tensor &mu,
    torch::Tensor &rsigma,
    torch::Tensor &gamma0,
    torch::Tensor &gamma1,
    float dropout_p,
    bool has_x1,
    bool has_residual,
    bool is_rms_norm);