Kernels
File size: 2,051 Bytes
e5e2eeb
 
44e9845
 
 
 
 
e5e2eeb
97825b8
 
e5e2eeb
 
97825b8
 
 
44e9845
f3b99fb
e5e2eeb
9d0a235
e5e2eeb
 
9d0a235
 
f3b99fb
e5e2eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e4334d
 
e5e2eeb
a1e5ca8
 
7e4334d
 
 
a1e5ca8
 
44e9845
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include "torch_binding.h"

#include <torch/library.h>

#include "registration.h"

// Registers the custom normalization ops with the PyTorch dispatcher and
// binds each schema to its CUDA implementation.
//
// Schema conventions used below:
//   - "Tensor!" marks an argument that is aliased and mutated in place;
//     such out-variant ops return "()" and write results into the provided
//     buffers.
//   - Ops without "Tensor!" arguments are functional and return fresh
//     tensors.
// All impls are registered under the torch::kCUDA dispatch key only.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // poly_norm: out-variant forward; writes the normalized result into `out`.
  ops.def("poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, "
          "float eps) -> ()");
  ops.impl("poly_norm", torch::kCUDA, &poly_norm);

  // poly_norm backward: fills the three gradient buffers in place.
  ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! "
          "bias_grad, Tensor output_grad, Tensor input, Tensor weight, float "
          "eps) -> ()");
  ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward);

  // rms_norm: functional forward; returns the normalized tensor.
  ops.def("rms_norm(Tensor input, Tensor weight, float eps) -> Tensor");
  ops.impl("rms_norm", torch::kCUDA, &rms_norm);

  // rms_norm backward: returns (input_grad, weight_grad).
  ops.def("rms_norm_backward(Tensor output_grad, Tensor input, Tensor weight, "
          "float eps) -> (Tensor, Tensor)");
  ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward);

  // fused_mul_poly_norm: elementwise multiply fused with poly_norm
  // (out-variant forward).
  ops.def("fused_mul_poly_norm(Tensor! out, Tensor input, Tensor mul, Tensor "
          "weight, Tensor bias, "
          "float eps) -> ()");
  ops.impl("fused_mul_poly_norm", torch::kCUDA, &fused_mul_poly_norm);

  // fused_mul_poly_norm backward: fills the four gradient buffers in place.
  ops.def("fused_mul_poly_norm_backward(Tensor! input_grad, Tensor! mul_grad, "
          "Tensor! weight_grad, Tensor! "
          "bias_grad, Tensor output_grad, Tensor input, Tensor mul, Tensor "
          "weight, Tensor "
          "bias, float eps) -> ()");
  ops.impl("fused_mul_poly_norm_backward", torch::kCUDA,
           &fused_mul_poly_norm_backward);

  // fused_add_rms_norm: residual add fused with rms_norm; functional,
  // returns two tensors (presumably the normed output and the add result —
  // confirm against the kernel).
  ops.def("fused_add_rms_norm(Tensor input, Tensor residual, Tensor "
          "weight, float eps) -> (Tensor, Tensor)");
  ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);

  // fused_add_rms_norm backward: returns (Tensor, Tensor).
  // NOTE: the original literal split dropped the space after
  // "add_output_grad," — restored here so the schema text matches the
  // formatting of every other schema in this block.
  ops.def(
      "fused_add_rms_norm_backward(Tensor output_grad, Tensor add_output_grad, "
      "Tensor input, Tensor weight, float eps, bool need_input_grad) -> "
      "(Tensor, Tensor)");
  ops.impl("fused_add_rms_norm_backward", torch::kCUDA,
           &fused_add_rms_norm_backward);
}

// Emits the Python module init entry point for this extension — presumably
// expands to a PyMODINIT_FUNC definition; see registration.h for the macro.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)