MekkCyber commited on
Commit
d180d2d
·
1 Parent(s): 63184de

rm bindings

Browse files
Files changed (1) hide show
  1. layer-norm/ln_api.cpp +1 -23
layer-norm/ln_api.cpp CHANGED
@@ -825,26 +825,4 @@ std::vector<at::Tensor> dropout_add_ln_parallel_residual_bwd(
825
  return result;
826
  }
827
 
828
- ////////////////////////////////////////////////////////////////////////////////////////////////////
829
-
830
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
831
- m.doc() = "CUDA DropoutAddLayerNorm";
832
- m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
833
- py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta_"),
834
- py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
835
- py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
836
- py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
837
- m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel",
838
- py::arg("dz"), py::arg("dx_"), py::arg("x"), py::arg("x0_"), py::arg("dmask_"), py::arg("mu"),
839
- py::arg("rsigma"), py::arg("gamma"), py::arg("rowscale_"), py::arg("colscale_"),
840
- py::arg("x0_subset_"), py::arg("z_subset_"), py::arg("dropout_p"), py::arg("rowscale_const"),
841
- py::arg("x0_numrows"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
842
- m.def("dropout_add_ln_parallel_residual_fwd", &dropout_add_ln_parallel_residual_fwd, "Run Dropout + Add + LayerNorm parallel residual forward kernel",
843
- py::arg("x0"), py::arg("x1_"), py::arg("residual"), py::arg("gamma0"), py::arg("beta0_"),
844
- py::arg("gamma1_"), py::arg("beta1_"), py::arg("dropout_p"), py::arg("epsilon"),
845
- py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
846
- m.def("dropout_add_ln_parallel_residual_bwd", &dropout_add_ln_parallel_residual_bwd, "Run Dropout + Add + LayerNorm parallel residual backward kernel",
847
- py::arg("dz0"), py::arg("dz1_"), py::arg("dx_"), py::arg("x"), py::arg("dmask0_"),
848
- py::arg("dmask1_"), py::arg("mu"), py::arg("rsigma"), py::arg("gamma0"), py::arg("gamma1_"),
849
- py::arg("dropout_p"), py::arg("has_x1"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
850
- }
 
825
  return result;
826
  }
827
 
828
+ ////////////////////////////////////////////////////////////////////////////////////////////////////