feat: make rms_norm out-of-place
Browse files- activation/rms_norm.cu +16 -11
- tests/test_rms_norm.py +1 -2
- torch-ext/activation/rms_norm.py +3 -10
- torch-ext/torch_binding.cpp +3 -4
- torch-ext/torch_binding.h +4 -5
activation/rms_norm.cu
CHANGED
|
@@ -271,14 +271,13 @@ rms_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
|
|
| 271 |
weight.data_ptr<scalar_t>(), eps, d); \
|
| 272 |
});
|
| 273 |
|
| 274 |
-
void rms_norm(torch::Tensor &out,          // [..., d]
|
| 275 |
-
const torch::Tensor &input,  // [..., d]
|
| 276 |
-
const torch::Tensor &weight, // [d]
|
| 277 |
-
double eps) {
|
| 278 |
-
AssertTensorShapeEqual(input, out, "input", "out");
|
| 279 |
AssertTensorNotNull(weight, "weight");
|
| 280 |
// TODO shape check
|
| 281 |
|
|
|
|
| 282 |
int d = input.size(-1);
|
| 283 |
int64_t num_tokens = input.numel() / input.size(-1);
|
| 284 |
dim3 grid(num_tokens);
|
|
@@ -292,6 +291,8 @@ void rms_norm(torch::Tensor &out, // [..., d]
|
|
| 292 |
} else {
|
| 293 |
LAUNCH_RMS_NORM(0);
|
| 294 |
}
|
|
|
|
|
|
|
| 295 |
}
|
| 296 |
|
| 297 |
#define LAUNCH_RMS_NORM_BWD(width) \
|
|
@@ -305,12 +306,14 @@ void rms_norm(torch::Tensor &out, // [..., d]
|
|
| 305 |
weight.data_ptr<scalar_t>(), eps, d); \
|
| 306 |
});
|
| 307 |
|
| 308 |
-
void rms_norm_backward(torch::Tensor &input_grad,  // [..., d]
|
| 309 |
-
torch::Tensor &weight_grad, // [d]
|
| 310 |
-
const torch::Tensor &output_grad, // [..., d]
|
| 311 |
-
const torch::Tensor &input,  // [..., d]
|
| 312 |
-
const torch::Tensor &weight, // [d]
|
| 313 |
-
double eps) {
|
|
|
|
|
|
|
| 314 |
AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
|
| 315 |
AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
|
| 316 |
AssertTensorNotNull(weight, "weight");
|
|
@@ -340,4 +343,6 @@ void rms_norm_backward(torch::Tensor &input_grad, // [..., d]
|
|
| 340 |
at::sum_out(acc, temp_weight_grad, {0});
|
| 341 |
weight_grad.copy_(acc);
|
| 342 |
}
|
|
|
|
|
|
|
| 343 |
}
|
|
|
|
| 271 |
weight.data_ptr<scalar_t>(), eps, d); \
|
| 272 |
});
|
| 273 |
|
| 274 |
+
torch::Tensor rms_norm(const torch::Tensor &input, // [..., d]
|
| 275 |
+
const torch::Tensor &weight, // [d]
|
| 276 |
+
double eps) {
|
|
|
|
|
|
|
| 277 |
AssertTensorNotNull(weight, "weight");
|
| 278 |
// TODO shape check
|
| 279 |
|
| 280 |
+
torch::Tensor out = torch::empty_like(input);
|
| 281 |
int d = input.size(-1);
|
| 282 |
int64_t num_tokens = input.numel() / input.size(-1);
|
| 283 |
dim3 grid(num_tokens);
|
|
|
|
| 291 |
} else {
|
| 292 |
LAUNCH_RMS_NORM(0);
|
| 293 |
}
|
| 294 |
+
|
| 295 |
+
return out;
|
| 296 |
}
|
| 297 |
|
| 298 |
#define LAUNCH_RMS_NORM_BWD(width) \
|
|
|
|
| 306 |
weight.data_ptr<scalar_t>(), eps, d); \
|
| 307 |
});
|
| 308 |
|
| 309 |
+
std::tuple<torch::Tensor, torch::Tensor>
|
| 310 |
+
rms_norm_backward(const torch::Tensor &output_grad, // [..., d]
|
| 311 |
+
const torch::Tensor &input, // [..., d]
|
| 312 |
+
const torch::Tensor &weight, // [d]
|
| 313 |
+
double eps) {
|
| 314 |
+
torch::Tensor input_grad = torch::empty_like(input);
|
| 315 |
+
torch::Tensor weight_grad = torch::empty_like(weight);
|
| 316 |
+
|
| 317 |
AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
|
| 318 |
AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
|
| 319 |
AssertTensorNotNull(weight, "weight");
|
|
|
|
| 343 |
at::sum_out(acc, temp_weight_grad, {0});
|
| 344 |
weight_grad.copy_(acc);
|
| 345 |
}
|
| 346 |
+
|
| 347 |
+
return {input_grad, weight_grad};
|
| 348 |
}
|
tests/test_rms_norm.py
CHANGED
|
@@ -51,8 +51,7 @@ def test_rms_norm(
|
|
| 51 |
layer = activation.layers.RMSNorm(d, eps=eps, dtype=dtype)
|
| 52 |
layer.weight = torch.nn.Parameter(weight)
|
| 53 |
|
| 54 |
-
out = torch.empty_like(x)
|
| 55 |
-
opcheck(op, (out, x, weight, eps))
|
| 56 |
|
| 57 |
out = fn(x, weight, eps)
|
| 58 |
mod_out = layer(x)
|
|
|
|
| 51 |
layer = activation.layers.RMSNorm(d, eps=eps, dtype=dtype)
|
| 52 |
layer.weight = torch.nn.Parameter(weight)
|
| 53 |
|
| 54 |
+
opcheck(op, (x, weight, eps))
|
|
|
|
| 55 |
|
| 56 |
out = fn(x, weight, eps)
|
| 57 |
mod_out = layer(x)
|
torch-ext/activation/rms_norm.py
CHANGED
|
@@ -8,9 +8,7 @@ class RMSNormFunction(torch.autograd.Function):
|
|
| 8 |
# Note that forward, setup_context, and backward are @staticmethods
|
| 9 |
@staticmethod
|
| 10 |
def forward(input, weight, eps):
|
| 11 |
-
output = torch.empty_like(input)
|
| 12 |
-
ops.rms_norm(output, input, weight, eps)
|
| 13 |
-
return output
|
| 14 |
|
| 15 |
@staticmethod
|
| 16 |
# inputs is a Tuple of all of the inputs passed to forward.
|
|
@@ -26,13 +24,8 @@ class RMSNormFunction(torch.autograd.Function):
|
|
| 26 |
input, weight = ctx.saved_tensors
|
| 27 |
eps = ctx.eps
|
| 28 |
|
| 29 |
-
input_grad = torch.empty_like(
|
| 30 |
-
input) if ctx.needs_input_grad[0] else None
|
| 31 |
-
weight_grad = torch.empty_like(
|
| 32 |
-
weight) if ctx.needs_input_grad[1] else None
|
| 33 |
-
|
| 34 |
-
ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
|
| 35 |
-
weight, eps)
|
| 36 |
|
| 37 |
return input_grad, weight_grad, None
|
| 38 |
|
|
|
|
| 8 |
# Note that forward, setup_context, and backward are @staticmethods
|
| 9 |
@staticmethod
|
| 10 |
def forward(input, weight, eps):
|
| 11 |
+
return ops.rms_norm(input, weight, eps)
|
|
|
|
|
|
|
| 12 |
|
| 13 |
@staticmethod
|
| 14 |
# inputs is a Tuple of all of the inputs passed to forward.
|
|
|
|
| 24 |
input, weight = ctx.saved_tensors
|
| 25 |
eps = ctx.eps
|
| 26 |
|
| 27 |
+
input_grad, weight_grad = ops.rms_norm_backward(
|
| 28 |
+
output_grad, input, weight, eps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
return input_grad, weight_grad, None
|
| 31 |
|
torch-ext/torch_binding.cpp
CHANGED
|
@@ -16,12 +16,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|
| 16 |
ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward);
|
| 17 |
|
| 18 |
// rms_norm
|
| 19 |
-
ops.def(
|
| 20 |
-
"rms_norm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()");
|
| 21 |
ops.impl("rms_norm", torch::kCUDA, &rms_norm);
|
| 22 |
|
| 23 |
-
ops.def("rms_norm_backward(Tensor
|
| 24 |
-
"
|
| 25 |
ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward);
|
| 26 |
|
| 27 |
// fused_mul_poly_norm
|
|
|
|
| 16 |
ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward);
|
| 17 |
|
| 18 |
// rms_norm
|
| 19 |
+
ops.def("rms_norm(Tensor input, Tensor weight, float eps) -> Tensor");
|
|
|
|
| 20 |
ops.impl("rms_norm", torch::kCUDA, &rms_norm);
|
| 21 |
|
| 22 |
+
ops.def("rms_norm_backward(Tensor output_grad, Tensor input, Tensor weight, "
|
| 23 |
+
"float eps) -> (Tensor, Tensor)");
|
| 24 |
ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward);
|
| 25 |
|
| 26 |
// fused_mul_poly_norm
|
torch-ext/torch_binding.h
CHANGED
|
@@ -11,12 +11,11 @@ void poly_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
|
|
| 11 |
const torch::Tensor &input, const torch::Tensor &weight,
|
| 12 |
double eps);
|
| 13 |
|
| 14 |
-
void rms_norm(torch::Tensor &out, const torch::Tensor &input,
|
| 15 |
-
const torch::Tensor &weights, double eps);
|
| 16 |
-
void rms_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
|
| 17 |
-
const torch::Tensor &output_grad,
|
| 18 |
-
const torch::Tensor &input, const torch::Tensor &weight,
|
| 19 |
double eps);
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
void fused_mul_poly_norm(torch::Tensor &out, const torch::Tensor &input,
|
| 22 |
const torch::Tensor &mul, const torch::Tensor &weights,
|
|
|
|
| 11 |
const torch::Tensor &input, const torch::Tensor &weight,
|
| 12 |
double eps);
|
| 13 |
|
| 14 |
+
torch::Tensor rms_norm(const torch::Tensor &input, const torch::Tensor &weights,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
double eps);
|
| 16 |
+
std::tuple<torch::Tensor, torch::Tensor>
|
| 17 |
+
rms_norm_backward(const torch::Tensor &output_grad, const torch::Tensor &input,
|
| 18 |
+
const torch::Tensor &weight, double eps);
|
| 19 |
|
| 20 |
void fused_mul_poly_norm(torch::Tensor &out, const torch::Tensor &input,
|
| 21 |
const torch::Tensor &mul, const torch::Tensor &weights,
|