feat: add assert is_contiguous

Changed files:
- activation/assert_utils.h (+12 −0)
- activation/fused_add_rms_norm.cu (+14 −0)
- activation/fused_mul_poly_norm.cu (+17 −0)
- activation/poly_norm.cu (+13 −0)
- activation/rms_norm.cu (+6 −0)
activation/assert_utils.h
CHANGED
|
@@ -8,6 +8,18 @@ inline void AssertTensorNotNull(const torch::Tensor &tensor,
|
|
| 8 |
TORCH_INTERNAL_ASSERT(tensor.defined(), name + " tensor should not be null.");
|
| 9 |
}
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a,
|
| 12 |
const torch::Tensor &tensor_b,
|
| 13 |
const std::string &name_a,
|
|
|
|
| 8 |
TORCH_INTERNAL_ASSERT(tensor.defined(), name + " tensor should not be null.");
|
| 9 |
}
|
| 10 |
|
| 11 |
+
// Verifies that `tensor` is laid out contiguously in memory, failing with
// TORCH_INTERNAL_ASSERT otherwise.
//
// If `nullable` is true, an undefined (null) tensor is accepted and the
// contiguity check is skipped entirely; otherwise an undefined tensor is
// rejected by the null check before contiguity is tested.
inline void AssertTensorContiguous(const torch::Tensor &tensor,
                                   const std::string &name,
                                   bool nullable = false) {
  const bool skip_optional_tensor = nullable && !tensor.defined();
  if (skip_optional_tensor) {
    return;
  }

  AssertTensorNotNull(tensor, name);
  TORCH_INTERNAL_ASSERT(tensor.is_contiguous(),
                        name + " tensor should be contiguous.");
}
|
| 22 |
+
|
| 23 |
inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a,
|
| 24 |
const torch::Tensor &tensor_b,
|
| 25 |
const std::string &name_a,
|
activation/fused_add_rms_norm.cu
CHANGED
|
@@ -307,6 +307,12 @@ void fused_add_rms_norm(torch::Tensor &out, // [..., d]
|
|
| 307 |
AssertTensorNotNull(weight, "weight");
|
| 308 |
// TODO shape check
|
| 309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
int d = input.size(-1);
|
| 311 |
int64_t num_tokens = input.numel() / input.size(-1);
|
| 312 |
dim3 grid(num_tokens);
|
|
@@ -346,6 +352,14 @@ void fused_add_rms_norm_backward(
|
|
| 346 |
AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
|
| 347 |
AssertTensorShapeEqual(input, add_output_grad, "input", "add_output_grad");
|
| 348 |
AssertTensorNotNull(weight, "weight");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
// TODO shape check
|
| 350 |
// weight_grad, input_grad can be nullable
|
| 351 |
|
|
|
|
| 307 |
AssertTensorNotNull(weight, "weight");
|
| 308 |
// TODO shape check
|
| 309 |
|
| 310 |
+
AssertTensorContiguous(out, "out");
|
| 311 |
+
AssertTensorContiguous(add_out, "add_out");
|
| 312 |
+
AssertTensorContiguous(input, "input");
|
| 313 |
+
AssertTensorContiguous(residual, "residual");
|
| 314 |
+
AssertTensorContiguous(weight, "weight");
|
| 315 |
+
|
| 316 |
int d = input.size(-1);
|
| 317 |
int64_t num_tokens = input.numel() / input.size(-1);
|
| 318 |
dim3 grid(num_tokens);
|
|
|
|
| 352 |
AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
|
| 353 |
AssertTensorShapeEqual(input, add_output_grad, "input", "add_output_grad");
|
| 354 |
AssertTensorNotNull(weight, "weight");
|
| 355 |
+
|
| 356 |
+
constexpr bool ALLOW_NULL = true;
|
| 357 |
+
AssertTensorContiguous(input_grad, "input_grad", ALLOW_NULL);
|
| 358 |
+
AssertTensorContiguous(weight_grad, "weight_grad", ALLOW_NULL);
|
| 359 |
+
AssertTensorContiguous(output_grad, "output_grad");
|
| 360 |
+
AssertTensorContiguous(add_output_grad, "add_output_grad");
|
| 361 |
+
AssertTensorContiguous(input, "input");
|
| 362 |
+
AssertTensorContiguous(weight, "weight");
|
| 363 |
// TODO shape check
|
| 364 |
// weight_grad, input_grad can be nullable
|
| 365 |
|
activation/fused_mul_poly_norm.cu
CHANGED
|
@@ -556,6 +556,12 @@ void fused_mul_poly_norm(torch::Tensor &out, // [..., d]
|
|
| 556 |
AssertTensorShapeEqual(input, mul, "input", "mul");
|
| 557 |
AssertTensorNotNull(weight, "weight");
|
| 558 |
AssertTensorNotNull(bias, "bias");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 559 |
// TODO shape check
|
| 560 |
|
| 561 |
int d = input.size(-1);
|
|
@@ -602,6 +608,17 @@ void fused_mul_poly_norm_backward(torch::Tensor &input_grad, // [..., d]
|
|
| 602 |
AssertTensorShapeEqual(input, mul_grad, "input", "mul_grad");
|
| 603 |
AssertTensorShapeEqual(input, mul, "input", "mul");
|
| 604 |
AssertTensorNotNull(weight, "weight");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
// TODO shape check
|
| 606 |
// weight_grad, bias_grad, mul_grad and input_grad can be nullable
|
| 607 |
|
|
|
|
| 556 |
AssertTensorShapeEqual(input, mul, "input", "mul");
|
| 557 |
AssertTensorNotNull(weight, "weight");
|
| 558 |
AssertTensorNotNull(bias, "bias");
|
| 559 |
+
|
| 560 |
+
AssertTensorContiguous(out, "out");
|
| 561 |
+
AssertTensorContiguous(input, "input");
|
| 562 |
+
AssertTensorContiguous(mul, "mul");
|
| 563 |
+
AssertTensorContiguous(weight, "weight");
|
| 564 |
+
AssertTensorContiguous(bias, "bias");
|
| 565 |
// TODO shape check
|
| 566 |
|
| 567 |
int d = input.size(-1);
|
|
|
|
| 608 |
AssertTensorShapeEqual(input, mul_grad, "input", "mul_grad");
|
| 609 |
AssertTensorShapeEqual(input, mul, "input", "mul");
|
| 610 |
AssertTensorNotNull(weight, "weight");
|
| 611 |
+
|
| 612 |
+
constexpr bool ALLOW_NULL = true;
|
| 613 |
+
AssertTensorContiguous(input_grad, "input_grad", ALLOW_NULL);
|
| 614 |
+
AssertTensorContiguous(mul_grad, "mul_grad", ALLOW_NULL);
|
| 615 |
+
AssertTensorContiguous(weight_grad, "weight_grad", ALLOW_NULL);
|
| 616 |
+
AssertTensorContiguous(bias_grad, "bias_grad", ALLOW_NULL);
|
| 617 |
+
AssertTensorContiguous(output_grad, "output_grad");
|
| 618 |
+
AssertTensorContiguous(input, "input");
|
| 619 |
+
AssertTensorContiguous(mul, "mul");
|
| 620 |
+
AssertTensorContiguous(weight, "weight");
|
| 621 |
+
AssertTensorContiguous(bias, "bias");
|
| 622 |
// TODO shape check
|
| 623 |
// weight_grad, bias_grad, mul_grad and input_grad can be nullable
|
| 624 |
|
activation/poly_norm.cu
CHANGED
|
@@ -508,6 +508,11 @@ void poly_norm(torch::Tensor &out, // [..., d]
|
|
| 508 |
AssertTensorNotNull(bias, "bias");
|
| 509 |
// TODO shape check
|
| 510 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 511 |
int d = input.size(-1);
|
| 512 |
int64_t num_tokens = input.numel() / d;
|
| 513 |
dim3 grid(num_tokens);
|
|
@@ -548,6 +553,14 @@ void poly_norm_backward(torch::Tensor &input_grad, // [..., d]
|
|
| 548 |
// TODO shape check
|
| 549 |
// weight_grad, bias_grad and input_grad can be nullable
|
| 550 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
int d = input.size(-1);
|
| 552 |
int64_t num_tokens = input.numel() / d;
|
| 553 |
dim3 grid(num_tokens);
|
|
|
|
| 508 |
AssertTensorNotNull(bias, "bias");
|
| 509 |
// TODO shape check
|
| 510 |
|
| 511 |
+
AssertTensorContiguous(out, "out");
|
| 512 |
+
AssertTensorContiguous(input, "input");
|
| 513 |
+
AssertTensorContiguous(weight, "weight");
|
| 514 |
+
AssertTensorContiguous(bias, "bias");
|
| 515 |
+
|
| 516 |
int d = input.size(-1);
|
| 517 |
int64_t num_tokens = input.numel() / d;
|
| 518 |
dim3 grid(num_tokens);
|
|
|
|
| 553 |
// TODO shape check
|
| 554 |
// weight_grad, bias_grad and input_grad can be nullable
|
| 555 |
|
| 556 |
+
constexpr bool ALLOW_NULL = true;
|
| 557 |
+
AssertTensorContiguous(input_grad, "input_grad", ALLOW_NULL);
|
| 558 |
+
AssertTensorContiguous(weight_grad, "weight_grad", ALLOW_NULL);
|
| 559 |
+
AssertTensorContiguous(bias_grad, "bias_grad", ALLOW_NULL);
|
| 560 |
+
AssertTensorContiguous(output_grad, "output_grad");
|
| 561 |
+
AssertTensorContiguous(input, "input");
|
| 562 |
+
AssertTensorContiguous(weight, "weight");
|
| 563 |
+
|
| 564 |
int d = input.size(-1);
|
| 565 |
int64_t num_tokens = input.numel() / d;
|
| 566 |
dim3 grid(num_tokens);
|
activation/rms_norm.cu
CHANGED
|
@@ -276,6 +276,8 @@ torch::Tensor rms_norm(const torch::Tensor &input, // [..., d]
|
|
| 276 |
double eps) {
|
| 277 |
AssertTensorNotNull(weight, "weight");
|
| 278 |
// TODO shape check
|
|
|
|
|
|
|
| 279 |
|
| 280 |
torch::Tensor out = torch::empty_like(input);
|
| 281 |
int d = input.size(-1);
|
|
@@ -314,6 +316,10 @@ rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
|
|
| 314 |
torch::Tensor input_grad = torch::empty_like(input);
|
| 315 |
torch::Tensor weight_grad = torch::empty_like(weight);
|
| 316 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
|
| 318 |
AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
|
| 319 |
AssertTensorNotNull(weight, "weight");
|
|
|
|
| 276 |
double eps) {
|
| 277 |
AssertTensorNotNull(weight, "weight");
|
| 278 |
// TODO shape check
|
| 279 |
+
AssertTensorContiguous(input, "input");
|
| 280 |
+
AssertTensorContiguous(weight, "weight");
|
| 281 |
|
| 282 |
torch::Tensor out = torch::empty_like(input);
|
| 283 |
int d = input.size(-1);
|
|
|
|
| 316 |
torch::Tensor input_grad = torch::empty_like(input);
|
| 317 |
torch::Tensor weight_grad = torch::empty_like(weight);
|
| 318 |
|
| 319 |
+
AssertTensorContiguous(output_grad, "output_grad");
|
| 320 |
+
AssertTensorContiguous(input, "input");
|
| 321 |
+
AssertTensorContiguous(weight, "weight");
|
| 322 |
+
|
| 323 |
AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
|
| 324 |
AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
|
| 325 |
AssertTensorNotNull(weight, "weight");
|