feat: support sequence parallel with rms_norm
Browse files
tests/test_rms_norm_sequence_parallel.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import sys
|
| 3 |
+
from collections.abc import Sequence
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
from packaging import version
|
| 9 |
+
from torch.distributed.tensor.placement_types import (Partial, Placement,
|
| 10 |
+
Replicate, Shard)
|
| 11 |
+
|
| 12 |
+
import activation
|
| 13 |
+
|
| 14 |
+
from .utils import assert_close, opcheck
|
| 15 |
+
|
| 16 |
+
DTYPES = [torch.float32]
|
| 17 |
+
NUM_TOKENS = [512] # Arbitrary values for testing
|
| 18 |
+
SEQUENCE_DIMS = [0, 1] # 0 is for [T, D] (packed), 1 is for [B, S, D]
|
| 19 |
+
D = [16] # Arbitrary values for testing
|
| 20 |
+
SEEDS = [0]
|
| 21 |
+
|
| 22 |
+
from torch.distributed._tensor import DTensor
|
| 23 |
+
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
| 24 |
+
from torch.distributed.tensor.parallel import (SequenceParallel,
|
| 25 |
+
parallelize_module)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@pytest.fixture(scope="session", autouse=True)
|
| 29 |
+
def init_dist(request):
|
| 30 |
+
if version.parse(torch.__version__) < version.parse("2.8"):
|
| 31 |
+
pytest.skip("torch>=2.8.0 is required for sequence parallel")
|
| 32 |
+
return
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
dist.init_process_group(backend="nccl")
|
| 36 |
+
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
|
| 37 |
+
except Exception as e:
|
| 38 |
+
print(f"Failed to initialize torch.distributed: {e}")
|
| 39 |
+
pytest.skip("Failed to initialize torch.distributed")
|
| 40 |
+
|
| 41 |
+
if dist.get_world_size() < 2:
|
| 42 |
+
pytest.skip("Need at least 2 processes in dist group. "
|
| 43 |
+
"You can run with `torchrun --nproc-per-node=2 "
|
| 44 |
+
"--local-ranks-filter 0 -m pytest "
|
| 45 |
+
"test_rms_norm_sequence_parallel.py`")
|
| 46 |
+
|
| 47 |
+
yield
|
| 48 |
+
dist.destroy_process_group()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class Model(torch.nn.Module):
|
| 52 |
+
|
| 53 |
+
def __init__(self, num_tokens, d) -> None:
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.rms_norm = activation.layers.RMSNorm(d)
|
| 56 |
+
|
| 57 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 58 |
+
return self.rms_norm(x)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
| 62 |
+
@pytest.mark.parametrize("d", D)
|
| 63 |
+
@pytest.mark.parametrize("dtype", DTYPES)
|
| 64 |
+
@pytest.mark.parametrize("seed", SEEDS)
|
| 65 |
+
@pytest.mark.parametrize("sequence_dim", SEQUENCE_DIMS)
|
| 66 |
+
def test_rms_norm(
|
| 67 |
+
num_tokens: int,
|
| 68 |
+
d: int,
|
| 69 |
+
dtype: torch.dtype,
|
| 70 |
+
seed: int,
|
| 71 |
+
sequence_dim: int,
|
| 72 |
+
) -> None:
|
| 73 |
+
if num_tokens % dist.get_world_size() != 0:
|
| 74 |
+
# It hangs at `y.full_tensor()` if not divisible
|
| 75 |
+
pytest.skip("num_tokens must be divisible by world_size for sharding")
|
| 76 |
+
|
| 77 |
+
random.seed(seed)
|
| 78 |
+
torch.manual_seed(seed)
|
| 79 |
+
|
| 80 |
+
num_ranks = dist.get_world_size()
|
| 81 |
+
rank = dist.get_rank()
|
| 82 |
+
mesh = init_device_mesh("cuda", (num_ranks, ), mesh_dim_names=("shard", ))
|
| 83 |
+
|
| 84 |
+
match sequence_dim:
|
| 85 |
+
case 0:
|
| 86 |
+
x_shape = (num_tokens, d)
|
| 87 |
+
case 1:
|
| 88 |
+
BATCH_SIZE = 2
|
| 89 |
+
x_shape = (BATCH_SIZE, num_tokens, d)
|
| 90 |
+
case _:
|
| 91 |
+
raise ValueError(f"Invalid sequence_dim: {sequence_dim}")
|
| 92 |
+
|
| 93 |
+
x = torch.randn(x_shape, dtype=dtype, requires_grad=True).cuda()
|
| 94 |
+
weight = torch.ones(d, dtype=dtype, requires_grad=True).cuda()
|
| 95 |
+
eps = 1e-05
|
| 96 |
+
|
| 97 |
+
x.retain_grad()
|
| 98 |
+
weight.retain_grad()
|
| 99 |
+
|
| 100 |
+
# Copy x, weight for reference
|
| 101 |
+
x_ref = x.detach().clone().requires_grad_(True)
|
| 102 |
+
weight_ref = weight.detach().clone().requires_grad_(True)
|
| 103 |
+
|
| 104 |
+
model_sharded = Model(num_tokens, d).to(dtype=dtype).cuda()
|
| 105 |
+
model_sharded.rms_norm.weight = torch.nn.Parameter(weight)
|
| 106 |
+
parallelize_module(
|
| 107 |
+
model_sharded, mesh,
|
| 108 |
+
{"rms_norm": SequenceParallel(sequence_dim=sequence_dim)})
|
| 109 |
+
x_sharded = DTensor.from_local(
|
| 110 |
+
x.chunk(num_ranks, dim=sequence_dim)[rank].contiguous(),
|
| 111 |
+
placements=(Shard(sequence_dim), ),
|
| 112 |
+
device_mesh=mesh,
|
| 113 |
+
)
|
| 114 |
+
y = model_sharded(x_sharded)
|
| 115 |
+
y_from_sharded = y.full_tensor()
|
| 116 |
+
|
| 117 |
+
model_unsharded = Model(num_tokens, d).to(dtype=dtype).cuda()
|
| 118 |
+
model_unsharded.rms_norm.weight = torch.nn.Parameter(weight_ref)
|
| 119 |
+
|
| 120 |
+
y_from_unsharded = model_unsharded(x_ref)
|
| 121 |
+
|
| 122 |
+
assert_close(y_from_sharded, y_from_unsharded)
|
| 123 |
+
|
| 124 |
+
# Backward
|
| 125 |
+
y_grad = torch.randn_like(y_from_unsharded)
|
| 126 |
+
y_from_sharded.backward(y_grad)
|
| 127 |
+
y_from_unsharded.backward(y_grad)
|
| 128 |
+
|
| 129 |
+
weight_grad_from_sharded = model_sharded.rms_norm.weight.grad._local_tensor
|
| 130 |
+
weight_grad_from_unsharded = model_unsharded.rms_norm.weight.grad
|
| 131 |
+
|
| 132 |
+
torch.distributed.all_reduce(x.grad, op=torch.distributed.ReduceOp.SUM)
|
| 133 |
+
torch.distributed.all_reduce(weight_grad_from_sharded,
|
| 134 |
+
op=torch.distributed.ReduceOp.SUM)
|
| 135 |
+
|
| 136 |
+
assert_close(x.grad, x_ref.grad)
|
| 137 |
+
assert_close(weight_grad_from_sharded, weight_grad_from_unsharded)
|
torch-ext/activation/rms_norm.py
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
|
|
|
|
|
| 1 |
import torch
|
|
|
|
| 2 |
|
| 3 |
from ._ops import ops
|
| 4 |
|
|
@@ -70,3 +73,159 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
|
|
| 70 |
residual_grad = grad if need_res else None
|
| 71 |
|
| 72 |
return input_grad, residual_grad, weight_grad, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections.abc import Sequence
|
| 2 |
+
|
| 3 |
import torch
|
| 4 |
+
from packaging import version
|
| 5 |
|
| 6 |
from ._ops import ops
|
| 7 |
|
|
|
|
| 73 |
residual_grad = grad if need_res else None
|
| 74 |
|
| 75 |
return input_grad, residual_grad, weight_grad, None
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
if version.parse(torch.__version__) >= version.parse("2.8"):
|
| 79 |
+
from torch.distributed.tensor._dtensor_spec import DTensorSpec
|
| 80 |
+
from torch.distributed.tensor._op_schema import (OpSchema, OpSpec,
|
| 81 |
+
OpStrategy,
|
| 82 |
+
RuntimeSchemaInfo)
|
| 83 |
+
from torch.distributed.tensor._ops.utils import (
|
| 84 |
+
generate_redistribute_costs, register_op_strategy)
|
| 85 |
+
from torch.distributed.tensor.placement_types import (Placement, Replicate,
|
| 86 |
+
Shard)
|
| 87 |
+
|
| 88 |
+
@torch.library.register_fake(ops.rms_norm.default)
|
| 89 |
+
def rms_norm_abstract(x, weight, eps):
|
| 90 |
+
return torch.empty_like(x)
|
| 91 |
+
|
| 92 |
+
@torch.library.register_fake(ops.rms_norm_backward.default)
|
| 93 |
+
def rms_norm_backward_abstract(output_grad, x, weight, eps):
|
| 94 |
+
return torch.empty_like(x), torch.empty_like(weight)
|
| 95 |
+
|
| 96 |
+
def _replicate_dims_start_at(placements: Sequence[Placement],
|
| 97 |
+
start_dim: int = 0) -> tuple[Placement, ...]:
|
| 98 |
+
new_placements: list[Placement] = []
|
| 99 |
+
for p in placements:
|
| 100 |
+
if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
|
| 101 |
+
new_placements.append(Replicate()) # make it replicate
|
| 102 |
+
else:
|
| 103 |
+
new_placements.append(p) # keep the placement
|
| 104 |
+
return tuple(new_placements)
|
| 105 |
+
|
| 106 |
+
@register_op_strategy(ops.rms_norm.default,
|
| 107 |
+
schema_info=RuntimeSchemaInfo(1))
|
| 108 |
+
def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
|
| 109 |
+
mesh = op_schema.get_mesh_from_args()
|
| 110 |
+
|
| 111 |
+
assert len(op_schema.args_schema) == 3
|
| 112 |
+
(
|
| 113 |
+
input_strategy,
|
| 114 |
+
weight_strategy,
|
| 115 |
+
_, # eps
|
| 116 |
+
) = op_schema.args_schema
|
| 117 |
+
|
| 118 |
+
assert isinstance(input_strategy, OpStrategy)
|
| 119 |
+
assert isinstance(weight_strategy, OpStrategy)
|
| 120 |
+
|
| 121 |
+
assert len(input_strategy.strategies) == len(
|
| 122 |
+
weight_strategy.strategies)
|
| 123 |
+
|
| 124 |
+
last_dim = input_strategy.ndim - 1
|
| 125 |
+
strategy = OpStrategy([])
|
| 126 |
+
for idx in range(len(input_strategy.strategies)):
|
| 127 |
+
input_src = input_strategy.strategies[idx].output_spec
|
| 128 |
+
weight_src = weight_strategy.strategies[idx].output_spec
|
| 129 |
+
|
| 130 |
+
assert isinstance(input_src, DTensorSpec)
|
| 131 |
+
assert isinstance(weight_src, DTensorSpec)
|
| 132 |
+
|
| 133 |
+
redistribute_costs = []
|
| 134 |
+
|
| 135 |
+
# Input can be sharded in any dim except the last dim.
|
| 136 |
+
input_tgt = DTensorSpec(
|
| 137 |
+
mesh=mesh,
|
| 138 |
+
placements=_replicate_dims_start_at(input_src.placements,
|
| 139 |
+
last_dim),
|
| 140 |
+
tensor_meta=input_src.tensor_meta,
|
| 141 |
+
)
|
| 142 |
+
redistribute_costs.append(
|
| 143 |
+
generate_redistribute_costs(input_strategy, input_tgt))
|
| 144 |
+
|
| 145 |
+
# Weight cannot be sharded, so always replicate it.
|
| 146 |
+
weight_tgt = DTensorSpec(
|
| 147 |
+
mesh=mesh,
|
| 148 |
+
placements=(Replicate(), ),
|
| 149 |
+
tensor_meta=weight_src.tensor_meta,
|
| 150 |
+
)
|
| 151 |
+
redistribute_costs.append(
|
| 152 |
+
generate_redistribute_costs(weight_strategy, weight_tgt))
|
| 153 |
+
|
| 154 |
+
strategy.strategies.append(
|
| 155 |
+
OpSpec(
|
| 156 |
+
output_specs=input_tgt,
|
| 157 |
+
input_specs=[input_tgt, weight_tgt],
|
| 158 |
+
redistribute_cost=redistribute_costs,
|
| 159 |
+
))
|
| 160 |
+
return strategy
|
| 161 |
+
|
| 162 |
+
@register_op_strategy(ops.rms_norm_backward.default,
|
| 163 |
+
schema_info=RuntimeSchemaInfo(1))
|
| 164 |
+
def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
|
| 165 |
+
mesh = op_schema.get_mesh_from_args()
|
| 166 |
+
|
| 167 |
+
assert len(op_schema.args_schema) == 4
|
| 168 |
+
(
|
| 169 |
+
output_grad_strategy,
|
| 170 |
+
input_strategy,
|
| 171 |
+
weight_strategy,
|
| 172 |
+
_, # eps
|
| 173 |
+
) = op_schema.args_schema
|
| 174 |
+
|
| 175 |
+
assert isinstance(output_grad_strategy, OpStrategy)
|
| 176 |
+
assert isinstance(input_strategy, OpStrategy)
|
| 177 |
+
assert isinstance(weight_strategy, OpStrategy)
|
| 178 |
+
|
| 179 |
+
assert len(input_strategy.strategies) == len(
|
| 180 |
+
weight_strategy.strategies)
|
| 181 |
+
assert len(input_strategy.strategies) == len(
|
| 182 |
+
output_grad_strategy.strategies)
|
| 183 |
+
|
| 184 |
+
last_dim = input_strategy.ndim - 1
|
| 185 |
+
strategy = OpStrategy([])
|
| 186 |
+
for idx in range(len(input_strategy.strategies)):
|
| 187 |
+
output_grad_src = output_grad_strategy.strategies[idx].output_spec
|
| 188 |
+
input_src = input_strategy.strategies[idx].output_spec
|
| 189 |
+
weight_src = weight_strategy.strategies[idx].output_spec
|
| 190 |
+
|
| 191 |
+
assert isinstance(output_grad_src, DTensorSpec)
|
| 192 |
+
assert isinstance(input_src, DTensorSpec)
|
| 193 |
+
assert isinstance(weight_src, DTensorSpec)
|
| 194 |
+
|
| 195 |
+
redistribute_costs = []
|
| 196 |
+
|
| 197 |
+
# Output grad and input can be sharded in any dim except the last dim.
|
| 198 |
+
output_grad_tgt = DTensorSpec(
|
| 199 |
+
mesh=mesh,
|
| 200 |
+
placements=_replicate_dims_start_at(output_grad_src.placements,
|
| 201 |
+
last_dim),
|
| 202 |
+
tensor_meta=output_grad_src.tensor_meta,
|
| 203 |
+
)
|
| 204 |
+
redistribute_costs.append(
|
| 205 |
+
generate_redistribute_costs(output_grad_strategy,
|
| 206 |
+
output_grad_tgt))
|
| 207 |
+
input_tgt = DTensorSpec(
|
| 208 |
+
mesh=mesh,
|
| 209 |
+
placements=_replicate_dims_start_at(input_src.placements,
|
| 210 |
+
last_dim),
|
| 211 |
+
tensor_meta=input_src.tensor_meta,
|
| 212 |
+
)
|
| 213 |
+
redistribute_costs.append(
|
| 214 |
+
generate_redistribute_costs(input_strategy, input_tgt))
|
| 215 |
+
|
| 216 |
+
# Weight cannot be sharded, so always replicate it.
|
| 217 |
+
weight_tgt = DTensorSpec(
|
| 218 |
+
mesh=mesh,
|
| 219 |
+
placements=(Replicate(), ),
|
| 220 |
+
tensor_meta=weight_src.tensor_meta,
|
| 221 |
+
)
|
| 222 |
+
redistribute_costs.append(
|
| 223 |
+
generate_redistribute_costs(weight_strategy, weight_tgt))
|
| 224 |
+
|
| 225 |
+
strategy.strategies.append(
|
| 226 |
+
OpSpec(
|
| 227 |
+
output_specs=[input_tgt, weight_tgt],
|
| 228 |
+
input_specs=[output_grad_tgt, input_tgt, weight_tgt],
|
| 229 |
+
redistribute_cost=redistribute_costs,
|
| 230 |
+
))
|
| 231 |
+
return strategy
|