Kernels
wyldecat committed on
Commit
a2a2501
·
1 Parent(s): 9d0a235

feat: add assert is_contiguous

Browse files
activation/assert_utils.h CHANGED
@@ -8,6 +8,18 @@ inline void AssertTensorNotNull(const torch::Tensor &tensor,
8
  TORCH_INTERNAL_ASSERT(tensor.defined(), name + " tensor should not be null.");
9
  }
10
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a,
12
  const torch::Tensor &tensor_b,
13
  const std::string &name_a,
 
8
  TORCH_INTERNAL_ASSERT(tensor.defined(), name + " tensor should not be null.");
9
  }
10
 
11
+ inline void AssertTensorContiguous(const torch::Tensor &tensor,
12
+ const std::string &name,
13
+ bool nullable = false) {
14
+ if (nullable && !tensor.defined()) {
15
+ return;
16
+ }
17
+
18
+ AssertTensorNotNull(tensor, name);
19
+ TORCH_INTERNAL_ASSERT(tensor.is_contiguous(),
20
+ name + " tensor should be contiguous.");
21
+ }
22
+
23
  inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a,
24
  const torch::Tensor &tensor_b,
25
  const std::string &name_a,
activation/fused_add_rms_norm.cu CHANGED
@@ -307,6 +307,12 @@ void fused_add_rms_norm(torch::Tensor &out, // [..., d]
307
  AssertTensorNotNull(weight, "weight");
308
  // TODO shape check
309
 
 
 
 
 
 
 
310
  int d = input.size(-1);
311
  int64_t num_tokens = input.numel() / input.size(-1);
312
  dim3 grid(num_tokens);
@@ -346,6 +352,14 @@ void fused_add_rms_norm_backward(
346
  AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
347
  AssertTensorShapeEqual(input, output_grad, "input", "add_output_grad");
348
  AssertTensorNotNull(weight, "weight");
 
 
 
 
 
 
 
 
349
  // TODO shape check
350
  // weight_grad, input_grad can be nullable
351
 
 
307
  AssertTensorNotNull(weight, "weight");
308
  // TODO shape check
309
 
310
+ AssertTensorContiguous(out, "out");
311
+ AssertTensorContiguous(add_out, "add_out");
312
+ AssertTensorContiguous(input, "input");
313
+ AssertTensorContiguous(residual, "residual");
314
+ AssertTensorContiguous(weight, "weight");
315
+
316
  int d = input.size(-1);
317
  int64_t num_tokens = input.numel() / input.size(-1);
318
  dim3 grid(num_tokens);
 
352
  AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
353
  AssertTensorShapeEqual(input, output_grad, "input", "add_output_grad");
354
  AssertTensorNotNull(weight, "weight");
355
+
356
+ constexpr bool ALLOW_NULL = true;
357
+ AssertTensorContiguous(input_grad, "input_grad", ALLOW_NULL);
358
+ AssertTensorContiguous(weight_grad, "weight_grad", ALLOW_NULL);
359
+ AssertTensorContiguous(output_grad, "output_grad");
360
+ AssertTensorContiguous(add_output_grad, "add_output_grad");
361
+ AssertTensorContiguous(input, "input");
362
+ AssertTensorContiguous(weight, "weight");
363
  // TODO shape check
364
  // weight_grad, input_grad can be nullable
365
 
activation/fused_mul_poly_norm.cu CHANGED
@@ -556,6 +556,12 @@ void fused_mul_poly_norm(torch::Tensor &out, // [..., d]
556
  AssertTensorShapeEqual(input, mul, "input", "mul");
557
  AssertTensorNotNull(weight, "weight");
558
  AssertTensorNotNull(bias, "bias");
 
 
 
 
 
 
559
  // TODO shape check
560
 
561
  int d = input.size(-1);
@@ -602,6 +608,17 @@ void fused_mul_poly_norm_backward(torch::Tensor &input_grad, // [..., d]
602
  AssertTensorShapeEqual(input, mul_grad, "input", "mul_grad");
603
  AssertTensorShapeEqual(input, mul, "input", "mul");
604
  AssertTensorNotNull(weight, "weight");
 
 
 
 
 
 
 
 
 
 
 
605
  // TODO shape check
606
  // weight_grad, bias_grad, mul_grad and input_grad can be nullable
607
 
 
556
  AssertTensorShapeEqual(input, mul, "input", "mul");
557
  AssertTensorNotNull(weight, "weight");
558
  AssertTensorNotNull(bias, "bias");
559
+
560
+ AssertTensorContiguous(out, "out");
561
+ AssertTensorContiguous(input, "input");
562
+ AssertTensorContiguous(mul, "mul");
563
+ AssertTensorContiguous(weight, "weight");
564
+ AssertTensorContiguous(bias, "bias");
565
  // TODO shape check
566
 
567
  int d = input.size(-1);
 
608
  AssertTensorShapeEqual(input, mul_grad, "input", "mul_grad");
609
  AssertTensorShapeEqual(input, mul, "input", "mul");
610
  AssertTensorNotNull(weight, "weight");
611
+
612
+ constexpr bool ALLOW_NULL = true;
613
+ AssertTensorContiguous(input_grad, "input_grad", ALLOW_NULL);
614
+ AssertTensorContiguous(mul_grad, "mul_grad", ALLOW_NULL);
615
+ AssertTensorContiguous(weight_grad, "weight_grad", ALLOW_NULL);
616
+ AssertTensorContiguous(bias_grad, "bias_grad", ALLOW_NULL);
617
+ AssertTensorContiguous(output_grad, "output_grad");
618
+ AssertTensorContiguous(input, "input");
619
+ AssertTensorContiguous(mul, "mul");
620
+ AssertTensorContiguous(weight, "weight");
621
+ AssertTensorContiguous(bias, "bias");
622
  // TODO shape check
623
  // weight_grad, bias_grad, mul_grad and input_grad can be nullable
624
 
activation/poly_norm.cu CHANGED
@@ -508,6 +508,11 @@ void poly_norm(torch::Tensor &out, // [..., d]
508
  AssertTensorNotNull(bias, "bias");
509
  // TODO shape check
510
 
 
 
 
 
 
511
  int d = input.size(-1);
512
  int64_t num_tokens = input.numel() / d;
513
  dim3 grid(num_tokens);
@@ -548,6 +553,14 @@ void poly_norm_backward(torch::Tensor &input_grad, // [..., d]
548
  // TODO shape check
549
  // weight_grad, bias_grad and input_grad can be nullable
550
 
 
 
 
 
 
 
 
 
551
  int d = input.size(-1);
552
  int64_t num_tokens = input.numel() / d;
553
  dim3 grid(num_tokens);
 
508
  AssertTensorNotNull(bias, "bias");
509
  // TODO shape check
510
 
511
+ AssertTensorContiguous(out, "out");
512
+ AssertTensorContiguous(input, "input");
513
+ AssertTensorContiguous(weight, "weight");
514
+ AssertTensorContiguous(bias, "bias");
515
+
516
  int d = input.size(-1);
517
  int64_t num_tokens = input.numel() / d;
518
  dim3 grid(num_tokens);
 
553
  // TODO shape check
554
  // weight_grad, bias_grad and input_grad can be nullable
555
 
556
+ constexpr bool ALLOW_NULL = true;
557
+ AssertTensorContiguous(input_grad, "input_grad", ALLOW_NULL);
558
+ AssertTensorContiguous(weight_grad, "weight_grad", ALLOW_NULL);
559
+ AssertTensorContiguous(bias_grad, "bias_grad", ALLOW_NULL);
560
+ AssertTensorContiguous(output_grad, "output_grad");
561
+ AssertTensorContiguous(input, "input");
562
+ AssertTensorContiguous(weight, "weight");
563
+
564
  int d = input.size(-1);
565
  int64_t num_tokens = input.numel() / d;
566
  dim3 grid(num_tokens);
activation/rms_norm.cu CHANGED
@@ -276,6 +276,8 @@ torch::Tensor rms_norm(const torch::Tensor &input, // [..., d]
276
  double eps) {
277
  AssertTensorNotNull(weight, "weight");
278
  // TODO shape check
 
 
279
 
280
  torch::Tensor out = torch::empty_like(input);
281
  int d = input.size(-1);
@@ -314,6 +316,10 @@ rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
314
  torch::Tensor input_grad = torch::empty_like(input);
315
  torch::Tensor weight_grad = torch::empty_like(weight);
316
 
 
 
 
 
317
  AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
318
  AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
319
  AssertTensorNotNull(weight, "weight");
 
276
  double eps) {
277
  AssertTensorNotNull(weight, "weight");
278
  // TODO shape check
279
+ AssertTensorContiguous(input, "input");
280
+ AssertTensorContiguous(weight, "weight");
281
 
282
  torch::Tensor out = torch::empty_like(input);
283
  int d = input.size(-1);
 
316
  torch::Tensor input_grad = torch::empty_like(input);
317
  torch::Tensor weight_grad = torch::empty_like(weight);
318
 
319
+ AssertTensorContiguous(output_grad, "output_grad");
320
+ AssertTensorContiguous(input, "input");
321
+ AssertTensorContiguous(weight, "weight");
322
+
323
  AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
324
  AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
325
  AssertTensorNotNull(weight, "weight");