fix: fix fused add rms norm sharding strategy
Browse files
tests/test_fused_add_rms_norm_sequence_parallel.py
CHANGED
|
@@ -55,7 +55,7 @@ class Model(torch.nn.Module):
|
|
| 55 |
self.fused_add_rms_norm = activation.layers.FusedAddRMSNorm(d)
|
| 56 |
|
| 57 |
def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
|
| 58 |
-
return self.fused_add_rms_norm(x, residual
|
| 59 |
|
| 60 |
|
| 61 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
|
@@ -122,18 +122,18 @@ def test_fused_add_rms_norm_sequence_parallel(
|
|
| 122 |
ResidualSequenceParallel(sequence_dim=sequence_dim)
|
| 123 |
})
|
| 124 |
|
| 125 |
-
|
| 126 |
-
x
|
| 127 |
-
placements=(
|
| 128 |
device_mesh=mesh,
|
| 129 |
)
|
| 130 |
-
|
| 131 |
-
residual
|
| 132 |
-
placements=(
|
| 133 |
device_mesh=mesh,
|
| 134 |
)
|
| 135 |
|
| 136 |
-
y, add_output = model_sharded(
|
| 137 |
|
| 138 |
y_from_sharded = y.full_tensor()
|
| 139 |
add_output_from_sharded = add_output.full_tensor()
|
|
@@ -156,21 +156,16 @@ def test_fused_add_rms_norm_sequence_parallel(
|
|
| 156 |
(y_grad * y_from_unsharded +
|
| 157 |
add_output_grad * add_output_from_unsharded).sum().backward()
|
| 158 |
|
| 159 |
-
weight_grad_from_sharded = model_sharded.fused_add_rms_norm.weight.grad.
|
|
|
|
| 160 |
weight_grad_from_unsharded = model_unsharded.fused_add_rms_norm.weight.grad
|
| 161 |
|
| 162 |
assert (x.grad is None) ^ x_requires_grad
|
| 163 |
assert (residual.grad is None) ^ residual_requires_grad
|
| 164 |
|
| 165 |
-
|
| 166 |
-
op=torch.distributed.ReduceOp.SUM)
|
| 167 |
-
|
| 168 |
-
if x.grad is not None:
|
| 169 |
-
torch.distributed.all_reduce(x.grad, op=torch.distributed.ReduceOp.SUM)
|
| 170 |
assert_close(x.grad, x_ref.grad)
|
| 171 |
-
if
|
| 172 |
-
torch.distributed.all_reduce(residual.grad,
|
| 173 |
-
op=torch.distributed.ReduceOp.SUM)
|
| 174 |
assert_close(residual.grad, residual_ref.grad)
|
| 175 |
|
| 176 |
assert_close(weight_grad_from_sharded, weight_grad_from_unsharded)
|
|
|
|
| 55 |
self.fused_add_rms_norm = activation.layers.FusedAddRMSNorm(d)
|
| 56 |
|
| 57 |
def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
|
| 58 |
+
return self.fused_add_rms_norm(x, residual)
|
| 59 |
|
| 60 |
|
| 61 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
|
|
|
| 122 |
ResidualSequenceParallel(sequence_dim=sequence_dim)
|
| 123 |
})
|
| 124 |
|
| 125 |
+
x_replicate = DTensor.from_local(
|
| 126 |
+
x,
|
| 127 |
+
placements=(Replicate(), ),
|
| 128 |
device_mesh=mesh,
|
| 129 |
)
|
| 130 |
+
residual_replicate = DTensor.from_local(
|
| 131 |
+
residual,
|
| 132 |
+
placements=(Replicate(), ),
|
| 133 |
device_mesh=mesh,
|
| 134 |
)
|
| 135 |
|
| 136 |
+
y, add_output = model_sharded(x_replicate, residual_replicate)
|
| 137 |
|
| 138 |
y_from_sharded = y.full_tensor()
|
| 139 |
add_output_from_sharded = add_output.full_tensor()
|
|
|
|
| 156 |
(y_grad * y_from_unsharded +
|
| 157 |
add_output_grad * add_output_from_unsharded).sum().backward()
|
| 158 |
|
| 159 |
+
weight_grad_from_sharded = model_sharded.fused_add_rms_norm.weight.grad.full_tensor(
|
| 160 |
+
)
|
| 161 |
weight_grad_from_unsharded = model_unsharded.fused_add_rms_norm.weight.grad
|
| 162 |
|
| 163 |
assert (x.grad is None) ^ x_requires_grad
|
| 164 |
assert (residual.grad is None) ^ residual_requires_grad
|
| 165 |
|
| 166 |
+
if x_requires_grad:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
assert_close(x.grad, x_ref.grad)
|
| 168 |
+
if residual_requires_grad:
|
|
|
|
|
|
|
| 169 |
assert_close(residual.grad, residual_ref.grad)
|
| 170 |
|
| 171 |
assert_close(weight_grad_from_sharded, weight_grad_from_unsharded)
|
torch-ext/activation/fused_add_rms_norm_meta.py
CHANGED
|
@@ -4,6 +4,9 @@ import torch
|
|
| 4 |
from torch.distributed.tensor._dtensor_spec import DTensorSpec
|
| 5 |
from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
|
| 6 |
RuntimeSchemaInfo)
|
|
|
|
|
|
|
|
|
|
| 7 |
from torch.distributed.tensor._ops.utils import (generate_redistribute_costs,
|
| 8 |
register_op_strategy)
|
| 9 |
from torch.distributed.tensor.placement_types import (Placement, Replicate,
|
|
@@ -19,17 +22,6 @@ def register_fused_add_rms_norm_meta():
|
|
| 19 |
pass
|
| 20 |
|
| 21 |
|
| 22 |
-
def _replicate_dims_start_at(placements: Sequence[Placement],
|
| 23 |
-
start_dim: int = 0) -> tuple[Placement, ...]:
|
| 24 |
-
new_placements: list[Placement] = []
|
| 25 |
-
for p in placements:
|
| 26 |
-
if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
|
| 27 |
-
new_placements.append(Replicate()) # make it replicate
|
| 28 |
-
else:
|
| 29 |
-
new_placements.append(p) # keep the placement
|
| 30 |
-
return tuple(new_placements)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
@register_op_strategy(ops.fused_add_rms_norm.default,
|
| 34 |
schema_info=RuntimeSchemaInfo(1))
|
| 35 |
def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
|
|
@@ -89,7 +81,7 @@ def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
|
|
| 89 |
# Weight cannot be sharded, so always replicate it.
|
| 90 |
weight_tgt = DTensorSpec(
|
| 91 |
mesh=mesh,
|
| 92 |
-
placements=(
|
| 93 |
tensor_meta=weight_src.tensor_meta,
|
| 94 |
)
|
| 95 |
redistribute_costs.append(
|
|
@@ -141,6 +133,8 @@ def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
|
|
| 141 |
)
|
| 142 |
|
| 143 |
last_dim = output_grad_strategy.ndim - 1
|
|
|
|
|
|
|
| 144 |
strategy = OpStrategy([])
|
| 145 |
for output_grad, add_output_grad, add_output, weight in zipped:
|
| 146 |
output_grad_src = output_grad.output_spec
|
|
@@ -179,16 +173,35 @@ def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
|
|
| 179 |
# Weight cannot be sharded, so always replicate it.
|
| 180 |
weight_tgt = DTensorSpec(
|
| 181 |
mesh=mesh,
|
| 182 |
-
placements=(
|
| 183 |
tensor_meta=weight_src.tensor_meta,
|
| 184 |
)
|
| 185 |
redistribute_costs.append(
|
| 186 |
generate_redistribute_costs(weight_strategy, weight_tgt))
|
| 187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
strategy.strategies.append(
|
| 189 |
OpSpec(
|
| 190 |
output_specs=[
|
| 191 |
-
output_grad_tgt if need_input_grad else None,
|
|
|
|
| 192 |
],
|
| 193 |
input_specs=[
|
| 194 |
output_grad_tgt, add_output_grad_tgt, add_output_tgt,
|
|
|
|
| 4 |
from torch.distributed.tensor._dtensor_spec import DTensorSpec
|
| 5 |
from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
|
| 6 |
RuntimeSchemaInfo)
|
| 7 |
+
from torch.distributed.tensor._ops._math_ops import (
|
| 8 |
+
_infer_reduce_dims_map, _replicate_dims_start_at,
|
| 9 |
+
map_placements_after_reduction)
|
| 10 |
from torch.distributed.tensor._ops.utils import (generate_redistribute_costs,
|
| 11 |
register_op_strategy)
|
| 12 |
from torch.distributed.tensor.placement_types import (Placement, Replicate,
|
|
|
|
| 22 |
pass
|
| 23 |
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
@register_op_strategy(ops.fused_add_rms_norm.default,
|
| 26 |
schema_info=RuntimeSchemaInfo(1))
|
| 27 |
def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
|
|
|
|
| 81 |
# Weight cannot be sharded, so always replicate it.
|
| 82 |
weight_tgt = DTensorSpec(
|
| 83 |
mesh=mesh,
|
| 84 |
+
placements=_replicate_dims_start_at(weight_src.placements),
|
| 85 |
tensor_meta=weight_src.tensor_meta,
|
| 86 |
)
|
| 87 |
redistribute_costs.append(
|
|
|
|
| 133 |
)
|
| 134 |
|
| 135 |
last_dim = output_grad_strategy.ndim - 1
|
| 136 |
+
outer_dims = list(range(last_dim))
|
| 137 |
+
|
| 138 |
strategy = OpStrategy([])
|
| 139 |
for output_grad, add_output_grad, add_output, weight in zipped:
|
| 140 |
output_grad_src = output_grad.output_spec
|
|
|
|
| 173 |
# Weight cannot be sharded, so always replicate it.
|
| 174 |
weight_tgt = DTensorSpec(
|
| 175 |
mesh=mesh,
|
| 176 |
+
placements=_replicate_dims_start_at(weight_src.placements),
|
| 177 |
tensor_meta=weight_src.tensor_meta,
|
| 178 |
)
|
| 179 |
redistribute_costs.append(
|
| 180 |
generate_redistribute_costs(weight_strategy, weight_tgt))
|
| 181 |
|
| 182 |
+
# from torch/distributed/tensor/_ops/_math_ops.py::layer_norm_bwd_strategy()
|
| 183 |
+
|
| 184 |
+
# Weight cannot be sharded, so always replicate it.
|
| 185 |
+
# TODO: now d_weight spec follows input spec w/ a reduction.
|
| 186 |
+
# we may need to change to a pointwise rule over grad_out and
|
| 187 |
+
# input, then apply a reduction.
|
| 188 |
+
inp_placements = _replicate_dims_start_at(output_grad_src.placements,
|
| 189 |
+
last_dim)
|
| 190 |
+
reduce_dims_map = _infer_reduce_dims_map(outer_dims,
|
| 191 |
+
output_grad_src.ndim, False)
|
| 192 |
+
out_placements = map_placements_after_reduction(
|
| 193 |
+
inp_placements, outer_dims, reduce_dims_map, "sum")
|
| 194 |
+
weight_grad_tgt = DTensorSpec(
|
| 195 |
+
mesh=mesh,
|
| 196 |
+
placements=out_placements,
|
| 197 |
+
tensor_meta=weight_src.tensor_meta,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
strategy.strategies.append(
|
| 201 |
OpSpec(
|
| 202 |
output_specs=[
|
| 203 |
+
output_grad_tgt if need_input_grad else None,
|
| 204 |
+
weight_grad_tgt
|
| 205 |
],
|
| 206 |
input_specs=[
|
| 207 |
output_grad_tgt, add_output_grad_tgt, add_output_tgt,
|