Kernels
wyldecat committed on
Commit
9d0a235
·
1 Parent(s): ab05e35

feat: make rms_norm out-of-place

Browse files
activation/rms_norm.cu CHANGED
@@ -271,14 +271,13 @@ rms_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
271
  weight.data_ptr<scalar_t>(), eps, d); \
272
  });
273
 
274
- void rms_norm(torch::Tensor &out, // [..., d]
275
- const torch::Tensor &input, // [..., d]
276
- const torch::Tensor &weight, // [d]
277
- double eps) {
278
- AssertTensorShapeEqual(input, out, "input", "out");
279
  AssertTensorNotNull(weight, "weight");
280
  // TODO shape check
281
 
 
282
  int d = input.size(-1);
283
  int64_t num_tokens = input.numel() / input.size(-1);
284
  dim3 grid(num_tokens);
@@ -292,6 +291,8 @@ void rms_norm(torch::Tensor &out, // [..., d]
292
  } else {
293
  LAUNCH_RMS_NORM(0);
294
  }
 
 
295
  }
296
 
297
  #define LAUNCH_RMS_NORM_BWD(width) \
@@ -305,12 +306,14 @@ void rms_norm(torch::Tensor &out, // [..., d]
305
  weight.data_ptr<scalar_t>(), eps, d); \
306
  });
307
 
308
- void rms_norm_backward(torch::Tensor &input_grad, // [..., d]
309
- torch::Tensor &weight_grad, // [d]
310
- const torch::Tensor &output_grad, // [..., d]
311
- const torch::Tensor &input, // [..., d]
312
- const torch::Tensor &weight, // [d]
313
- double eps) {
 
 
314
  AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
315
  AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
316
  AssertTensorNotNull(weight, "weight");
@@ -340,4 +343,6 @@ void rms_norm_backward(torch::Tensor &input_grad, // [..., d]
340
  at::sum_out(acc, temp_weight_grad, {0});
341
  weight_grad.copy_(acc);
342
  }
 
 
343
  }
 
271
  weight.data_ptr<scalar_t>(), eps, d); \
272
  });
273
 
274
+ torch::Tensor rms_norm(const torch::Tensor &input, // [..., d]
275
+ const torch::Tensor &weight, // [d]
276
+ double eps) {
 
 
277
  AssertTensorNotNull(weight, "weight");
278
  // TODO shape check
279
 
280
+ torch::Tensor out = torch::empty_like(input);
281
  int d = input.size(-1);
282
  int64_t num_tokens = input.numel() / input.size(-1);
283
  dim3 grid(num_tokens);
 
291
  } else {
292
  LAUNCH_RMS_NORM(0);
293
  }
294
+
295
+ return out;
296
  }
297
 
298
  #define LAUNCH_RMS_NORM_BWD(width) \
 
306
  weight.data_ptr<scalar_t>(), eps, d); \
307
  });
308
 
309
+ std::tuple<torch::Tensor, torch::Tensor>
310
+ rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
311
+ const torch::Tensor &input, // [..., d]
312
+ const torch::Tensor &weight, // [d]
313
+ double eps) {
314
+ torch::Tensor input_grad = torch::empty_like(input);
315
+ torch::Tensor weight_grad = torch::empty_like(weight);
316
+
317
  AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
318
  AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
319
  AssertTensorNotNull(weight, "weight");
 
343
  at::sum_out(acc, temp_weight_grad, {0});
344
  weight_grad.copy_(acc);
345
  }
346
+
347
+ return {input_grad, weight_grad};
348
  }
tests/test_rms_norm.py CHANGED
@@ -51,8 +51,7 @@ def test_rms_norm(
51
  layer = activation.layers.RMSNorm(d, eps=eps, dtype=dtype)
52
  layer.weight = torch.nn.Parameter(weight)
53
 
54
- out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
55
- opcheck(op, (out, x, weight, eps))
56
 
57
  out = fn(x, weight, eps)
58
  mod_out = layer(x)
 
51
  layer = activation.layers.RMSNorm(d, eps=eps, dtype=dtype)
52
  layer.weight = torch.nn.Parameter(weight)
53
 
54
+ opcheck(op, (x, weight, eps))
 
55
 
56
  out = fn(x, weight, eps)
57
  mod_out = layer(x)
torch-ext/activation/rms_norm.py CHANGED
@@ -8,9 +8,7 @@ class RMSNormFunction(torch.autograd.Function):
8
  # Note that forward, setup_context, and backward are @staticmethods
9
  @staticmethod
10
  def forward(input, weight, eps):
11
- output = torch.empty_like(input)
12
- ops.rms_norm(output, input, weight, eps)
13
- return output
14
 
15
  @staticmethod
16
  # inputs is a Tuple of all of the inputs passed to forward.
@@ -26,13 +24,8 @@ class RMSNormFunction(torch.autograd.Function):
26
  input, weight = ctx.saved_tensors
27
  eps = ctx.eps
28
 
29
- input_grad = torch.empty_like(
30
- input) if ctx.needs_input_grad[0] else None
31
- weight_grad = torch.empty_like(
32
- weight) if ctx.needs_input_grad[1] else None
33
-
34
- ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
35
- weight, eps)
36
 
37
  return input_grad, weight_grad, None
38
 
 
8
  # Note that forward, setup_context, and backward are @staticmethods
9
  @staticmethod
10
  def forward(input, weight, eps):
11
+ return ops.rms_norm(input, weight, eps)
 
 
12
 
13
  @staticmethod
14
  # inputs is a Tuple of all of the inputs passed to forward.
 
24
  input, weight = ctx.saved_tensors
25
  eps = ctx.eps
26
 
27
+ input_grad, weight_grad = ops.rms_norm_backward(
28
+ output_grad, input, weight, eps)
 
 
 
 
 
29
 
30
  return input_grad, weight_grad, None
31
 
torch-ext/torch_binding.cpp CHANGED
@@ -16,12 +16,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
16
  ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward);
17
 
18
  // rms_norm
19
- ops.def(
20
- "rms_norm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()");
21
  ops.impl("rms_norm", torch::kCUDA, &rms_norm);
22
 
23
- ops.def("rms_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor "
24
- "output_grad, Tensor input, Tensor weight, float eps) -> ()");
25
  ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward);
26
 
27
  // fused_mul_poly_norm
 
16
  ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward);
17
 
18
  // rms_norm
19
+ ops.def("rms_norm(Tensor input, Tensor weight, float eps) -> Tensor");
 
20
  ops.impl("rms_norm", torch::kCUDA, &rms_norm);
21
 
22
+ ops.def("rms_norm_backward(Tensor output_grad, Tensor input, Tensor weight, "
23
+ "float eps) -> (Tensor, Tensor)");
24
  ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward);
25
 
26
  // fused_mul_poly_norm
torch-ext/torch_binding.h CHANGED
@@ -11,12 +11,11 @@ void poly_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
11
  const torch::Tensor &input, const torch::Tensor &weight,
12
  double eps);
13
 
14
- void rms_norm(torch::Tensor &out, const torch::Tensor &input,
15
- const torch::Tensor &weights, double eps);
16
- void rms_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
17
- const torch::Tensor &output_grad,
18
- const torch::Tensor &input, const torch::Tensor &weight,
19
  double eps);
 
 
 
20
 
21
  void fused_mul_poly_norm(torch::Tensor &out, const torch::Tensor &input,
22
  const torch::Tensor &mul, const torch::Tensor &weights,
 
11
  const torch::Tensor &input, const torch::Tensor &weight,
12
  double eps);
13
 
14
+ torch::Tensor rms_norm(const torch::Tensor &input, const torch::Tensor &weights,
 
 
 
 
15
  double eps);
16
+ std::tuple<torch::Tensor, torch::Tensor>
17
+ rms_norm_backward(const torch::Tensor &output_grad, const torch::Tensor &input,
18
+ const torch::Tensor &weight, double eps);
19
 
20
  void fused_mul_poly_norm(torch::Tensor &out, const torch::Tensor &input,
21
  const torch::Tensor &mul, const torch::Tensor &weights,