Kernels
wyldecat committed on
Commit
06d6367
·
1 Parent(s): a2a2501

feat: support sequence parallel with rms_norm

Browse files
tests/test_rms_norm_sequence_parallel.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import sys
3
+ from collections.abc import Sequence
4
+
5
+ import pytest
6
+ import torch
7
+ import torch.distributed as dist
8
+ from packaging import version
9
+ from torch.distributed.tensor.placement_types import (Partial, Placement,
10
+ Replicate, Shard)
11
+
12
+ import activation
13
+
14
+ from .utils import assert_close, opcheck
15
+
16
# Parametrization grids for the test below (kept tiny so the
# multi-rank test stays fast).
DTYPES = [torch.float32]
NUM_TOKENS = [512]  # Arbitrary values for testing
SEQUENCE_DIMS = [0, 1]  # 0 is for [T, D] (packed), 1 is for [B, S, D]
D = [16]  # Arbitrary values for testing
SEEDS = [0]
21
+
22
+ from torch.distributed._tensor import DTensor
23
+ from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
24
+ from torch.distributed.tensor.parallel import (SequenceParallel,
25
+ parallelize_module)
26
+
27
+
28
@pytest.fixture(scope="session", autouse=True)
def init_dist(request):
    """Initialize a NCCL process group once for the whole test session.

    Skips the session when torch is older than 2.8, when distributed
    initialization fails (e.g. the tests were not launched via torchrun),
    or when fewer than 2 ranks are available. The process group is torn
    down after the session.
    """
    if version.parse(torch.__version__) < version.parse("2.8"):
        # pytest.skip raises, so no `return` is needed after it.
        pytest.skip("torch>=2.8.0 is required for sequence parallel")

    try:
        dist.init_process_group(backend="nccl")
        # Bind each rank to a GPU, round-robin over visible devices.
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    except Exception as e:
        print(f"Failed to initialize torch.distributed: {e}")
        pytest.skip("Failed to initialize torch.distributed")

    if dist.get_world_size() < 2:
        pytest.skip("Need at least 2 processes in dist group. "
                    "You can run with `torchrun --nproc-per-node=2 "
                    "--local-ranks-filter 0 -m pytest "
                    "test_rms_norm_sequence_parallel.py`")

    yield
    dist.destroy_process_group()
49
+
50
+
51
class Model(torch.nn.Module):
    """Minimal module wrapping a single RMSNorm layer, used as the target
    of `parallelize_module` in the sequence-parallel test."""

    def __init__(self, num_tokens, d) -> None:
        # `num_tokens` is accepted for signature parity with the test
        # call sites but is not used by the layer itself.
        super().__init__()
        self.rms_norm = activation.layers.RMSNorm(d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.rms_norm(x)
        return out
59
+
60
+
61
+ @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
62
+ @pytest.mark.parametrize("d", D)
63
+ @pytest.mark.parametrize("dtype", DTYPES)
64
+ @pytest.mark.parametrize("seed", SEEDS)
65
+ @pytest.mark.parametrize("sequence_dim", SEQUENCE_DIMS)
66
+ def test_rms_norm(
67
+ num_tokens: int,
68
+ d: int,
69
+ dtype: torch.dtype,
70
+ seed: int,
71
+ sequence_dim: int,
72
+ ) -> None:
73
+ if num_tokens % dist.get_world_size() != 0:
74
+ # It hangs at `y.full_tensor()` if not divisible
75
+ pytest.skip("num_tokens must be divisible by world_size for sharding")
76
+
77
+ random.seed(seed)
78
+ torch.manual_seed(seed)
79
+
80
+ num_ranks = dist.get_world_size()
81
+ rank = dist.get_rank()
82
+ mesh = init_device_mesh("cuda", (num_ranks, ), mesh_dim_names=("shard", ))
83
+
84
+ match sequence_dim:
85
+ case 0:
86
+ x_shape = (num_tokens, d)
87
+ case 1:
88
+ BATCH_SIZE = 2
89
+ x_shape = (BATCH_SIZE, num_tokens, d)
90
+ case _:
91
+ raise ValueError(f"Invalid sequence_dim: {sequence_dim}")
92
+
93
+ x = torch.randn(x_shape, dtype=dtype, requires_grad=True).cuda()
94
+ weight = torch.ones(d, dtype=dtype, requires_grad=True).cuda()
95
+ eps = 1e-05
96
+
97
+ x.retain_grad()
98
+ weight.retain_grad()
99
+
100
+ # Copy x, weight for reference
101
+ x_ref = x.detach().clone().requires_grad_(True)
102
+ weight_ref = weight.detach().clone().requires_grad_(True)
103
+
104
+ model_sharded = Model(num_tokens, d).to(dtype=dtype).cuda()
105
+ model_sharded.rms_norm.weight = torch.nn.Parameter(weight)
106
+ parallelize_module(
107
+ model_sharded, mesh,
108
+ {"rms_norm": SequenceParallel(sequence_dim=sequence_dim)})
109
+ x_sharded = DTensor.from_local(
110
+ x.chunk(num_ranks, dim=sequence_dim)[rank].contiguous(),
111
+ placements=(Shard(sequence_dim), ),
112
+ device_mesh=mesh,
113
+ )
114
+ y = model_sharded(x_sharded)
115
+ y_from_sharded = y.full_tensor()
116
+
117
+ model_unsharded = Model(num_tokens, d).to(dtype=dtype).cuda()
118
+ model_unsharded.rms_norm.weight = torch.nn.Parameter(weight_ref)
119
+
120
+ y_from_unsharded = model_unsharded(x_ref)
121
+
122
+ assert_close(y_from_sharded, y_from_unsharded)
123
+
124
+ # Backward
125
+ y_grad = torch.randn_like(y_from_unsharded)
126
+ y_from_sharded.backward(y_grad)
127
+ y_from_unsharded.backward(y_grad)
128
+
129
+ weight_grad_from_sharded = model_sharded.rms_norm.weight.grad._local_tensor
130
+ weight_grad_from_unsharded = model_unsharded.rms_norm.weight.grad
131
+
132
+ torch.distributed.all_reduce(x.grad, op=torch.distributed.ReduceOp.SUM)
133
+ torch.distributed.all_reduce(weight_grad_from_sharded,
134
+ op=torch.distributed.ReduceOp.SUM)
135
+
136
+ assert_close(x.grad, x_ref.grad)
137
+ assert_close(weight_grad_from_sharded, weight_grad_from_unsharded)
torch-ext/activation/rms_norm.py CHANGED
@@ -1,4 +1,7 @@
 
 
1
  import torch
 
2
 
3
  from ._ops import ops
4
 
@@ -70,3 +73,159 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
70
  residual_grad = grad if need_res else None
71
 
72
  return input_grad, residual_grad, weight_grad, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Sequence
2
+
3
  import torch
4
+ from packaging import version
5
 
6
  from ._ops import ops
7
 
 
73
  residual_grad = grad if need_res else None
74
 
75
  return input_grad, residual_grad, weight_grad, None
76
+
77
+
78
# DTensor integration for the custom rms_norm ops. Only registered on
# torch >= 2.8, where the (private) strategy-registration APIs used
# below are available.
if version.parse(torch.__version__) >= version.parse("2.8"):
    from torch.distributed.tensor._dtensor_spec import DTensorSpec
    from torch.distributed.tensor._op_schema import (OpSchema, OpSpec,
                                                     OpStrategy,
                                                     RuntimeSchemaInfo)
    from torch.distributed.tensor._ops.utils import (
        generate_redistribute_costs, register_op_strategy)
    from torch.distributed.tensor.placement_types import (Placement, Replicate,
                                                          Shard)

    @torch.library.register_fake(ops.rms_norm.default)
    def rms_norm_abstract(x, weight, eps):
        # Fake (meta) kernel: the output mirrors the input's shape/dtype.
        return torch.empty_like(x)

    @torch.library.register_fake(ops.rms_norm_backward.default)
    def rms_norm_backward_abstract(output_grad, x, weight, eps):
        # Fake (meta) kernel: grads mirror the input and weight shapes.
        return torch.empty_like(x), torch.empty_like(weight)

    def _replicate_dims_start_at(placements: Sequence[Placement],
                                 start_dim: int = 0) -> tuple[Placement, ...]:
        """Return `placements` with every Partial and every Shard on a dim
        >= `start_dim` replaced by Replicate; earlier shardings are kept."""
        new_placements: list[Placement] = []
        for p in placements:
            if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
                new_placements.append(Replicate())  # make it replicate
            else:
                new_placements.append(p)  # keep the placement
        return tuple(new_placements)

    @register_op_strategy(ops.rms_norm.default,
                          schema_info=RuntimeSchemaInfo(1))
    def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
        """Sharding strategy for rms_norm forward: the input may remain
        sharded on any dim except the last (normalized) one; the weight
        is always replicated. Output placement follows the input."""
        mesh = op_schema.get_mesh_from_args()

        assert len(op_schema.args_schema) == 3
        (
            input_strategy,
            weight_strategy,
            _,  # eps
        ) = op_schema.args_schema

        assert isinstance(input_strategy, OpStrategy)
        assert isinstance(weight_strategy, OpStrategy)

        assert len(input_strategy.strategies) == len(
            weight_strategy.strategies)

        last_dim = input_strategy.ndim - 1
        strategy = OpStrategy([])
        # Mirror every candidate placement of the inputs with one OpSpec.
        for idx in range(len(input_strategy.strategies)):
            input_src = input_strategy.strategies[idx].output_spec
            weight_src = weight_strategy.strategies[idx].output_spec

            assert isinstance(input_src, DTensorSpec)
            assert isinstance(weight_src, DTensorSpec)

            redistribute_costs = []

            # Input can be sharded in any dim except the last dim.
            input_tgt = DTensorSpec(
                mesh=mesh,
                placements=_replicate_dims_start_at(input_src.placements,
                                                    last_dim),
                tensor_meta=input_src.tensor_meta,
            )
            redistribute_costs.append(
                generate_redistribute_costs(input_strategy, input_tgt))

            # Weight cannot be sharded, so always replicate it.
            weight_tgt = DTensorSpec(
                mesh=mesh,
                placements=(Replicate(), ),
                tensor_meta=weight_src.tensor_meta,
            )
            redistribute_costs.append(
                generate_redistribute_costs(weight_strategy, weight_tgt))

            strategy.strategies.append(
                OpSpec(
                    output_specs=input_tgt,
                    input_specs=[input_tgt, weight_tgt],
                    redistribute_cost=redistribute_costs,
                ))
        return strategy

    @register_op_strategy(ops.rms_norm_backward.default,
                          schema_info=RuntimeSchemaInfo(1))
    def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
        """Sharding strategy for rms_norm backward: output_grad and input
        may remain sharded on any dim except the last; weight is
        replicated. Returns (input_grad, weight_grad) specs."""
        mesh = op_schema.get_mesh_from_args()

        assert len(op_schema.args_schema) == 4
        (
            output_grad_strategy,
            input_strategy,
            weight_strategy,
            _,  # eps
        ) = op_schema.args_schema

        assert isinstance(output_grad_strategy, OpStrategy)
        assert isinstance(input_strategy, OpStrategy)
        assert isinstance(weight_strategy, OpStrategy)

        assert len(input_strategy.strategies) == len(
            weight_strategy.strategies)
        assert len(input_strategy.strategies) == len(
            output_grad_strategy.strategies)

        last_dim = input_strategy.ndim - 1
        strategy = OpStrategy([])
        for idx in range(len(input_strategy.strategies)):
            output_grad_src = output_grad_strategy.strategies[idx].output_spec
            input_src = input_strategy.strategies[idx].output_spec
            weight_src = weight_strategy.strategies[idx].output_spec

            assert isinstance(output_grad_src, DTensorSpec)
            assert isinstance(input_src, DTensorSpec)
            assert isinstance(weight_src, DTensorSpec)

            redistribute_costs = []

            # Output grad and input can be sharded in any dim except the last dim.
            output_grad_tgt = DTensorSpec(
                mesh=mesh,
                placements=_replicate_dims_start_at(output_grad_src.placements,
                                                    last_dim),
                tensor_meta=output_grad_src.tensor_meta,
            )
            redistribute_costs.append(
                generate_redistribute_costs(output_grad_strategy,
                                            output_grad_tgt))
            input_tgt = DTensorSpec(
                mesh=mesh,
                placements=_replicate_dims_start_at(input_src.placements,
                                                    last_dim),
                tensor_meta=input_src.tensor_meta,
            )
            redistribute_costs.append(
                generate_redistribute_costs(input_strategy, input_tgt))

            # Weight cannot be sharded, so always replicate it.
            # NOTE(review): the weight grad is declared with the weight's
            # Replicate() placement below, yet the accompanying test
            # all-reduces it manually — it looks like the local weight
            # grad is actually rank-partial; confirm whether Partial()
            # would describe it more accurately.
            weight_tgt = DTensorSpec(
                mesh=mesh,
                placements=(Replicate(), ),
                tensor_meta=weight_src.tensor_meta,
            )
            redistribute_costs.append(
                generate_redistribute_costs(weight_strategy, weight_tgt))

            strategy.strategies.append(
                OpSpec(
                    output_specs=[input_tgt, weight_tgt],
                    input_specs=[output_grad_tgt, input_tgt, weight_tgt],
                    redistribute_cost=redistribute_costs,
                ))
        return strategy