Kernels
wyldecat committed on
Commit
a35a092
·
1 Parent(s): 138159c

fix: fix fused add rms norm sharding strategy

Browse files
tests/test_fused_add_rms_norm_sequence_parallel.py CHANGED
@@ -55,7 +55,7 @@ class Model(torch.nn.Module):
55
  self.fused_add_rms_norm = activation.layers.FusedAddRMSNorm(d)
56
 
57
  def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
58
- return self.fused_add_rms_norm(x, residual=residual)
59
 
60
 
61
  @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -122,18 +122,18 @@ def test_fused_add_rms_norm_sequence_parallel(
122
  ResidualSequenceParallel(sequence_dim=sequence_dim)
123
  })
124
 
125
- x_sharded = DTensor.from_local(
126
- x.chunk(num_ranks, dim=sequence_dim)[rank].contiguous(),
127
- placements=(Shard(sequence_dim), ),
128
  device_mesh=mesh,
129
  )
130
- residual_sharded = DTensor.from_local(
131
- residual.chunk(num_ranks, dim=sequence_dim)[rank].contiguous(),
132
- placements=(Shard(sequence_dim), ),
133
  device_mesh=mesh,
134
  )
135
 
136
- y, add_output = model_sharded(x_sharded, residual_sharded)
137
 
138
  y_from_sharded = y.full_tensor()
139
  add_output_from_sharded = add_output.full_tensor()
@@ -156,21 +156,16 @@ def test_fused_add_rms_norm_sequence_parallel(
156
  (y_grad * y_from_unsharded +
157
  add_output_grad * add_output_from_unsharded).sum().backward()
158
 
159
- weight_grad_from_sharded = model_sharded.fused_add_rms_norm.weight.grad._local_tensor
 
160
  weight_grad_from_unsharded = model_unsharded.fused_add_rms_norm.weight.grad
161
 
162
  assert (x.grad is None) ^ x_requires_grad
163
  assert (residual.grad is None) ^ residual_requires_grad
164
 
165
- torch.distributed.all_reduce(weight_grad_from_sharded,
166
- op=torch.distributed.ReduceOp.SUM)
167
-
168
- if x.grad is not None:
169
- torch.distributed.all_reduce(x.grad, op=torch.distributed.ReduceOp.SUM)
170
  assert_close(x.grad, x_ref.grad)
171
- if residual.grad is not None:
172
- torch.distributed.all_reduce(residual.grad,
173
- op=torch.distributed.ReduceOp.SUM)
174
  assert_close(residual.grad, residual_ref.grad)
175
 
176
  assert_close(weight_grad_from_sharded, weight_grad_from_unsharded)
 
55
  self.fused_add_rms_norm = activation.layers.FusedAddRMSNorm(d)
56
 
57
  def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
58
+ return self.fused_add_rms_norm(x, residual)
59
 
60
 
61
  @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 
122
  ResidualSequenceParallel(sequence_dim=sequence_dim)
123
  })
124
 
125
+ x_replicate = DTensor.from_local(
126
+ x,
127
+ placements=(Replicate(), ),
128
  device_mesh=mesh,
129
  )
130
+ residual_replicate = DTensor.from_local(
131
+ residual,
132
+ placements=(Replicate(), ),
133
  device_mesh=mesh,
134
  )
135
 
136
+ y, add_output = model_sharded(x_replicate, residual_replicate)
137
 
138
  y_from_sharded = y.full_tensor()
139
  add_output_from_sharded = add_output.full_tensor()
 
156
  (y_grad * y_from_unsharded +
157
  add_output_grad * add_output_from_unsharded).sum().backward()
158
 
159
+ weight_grad_from_sharded = model_sharded.fused_add_rms_norm.weight.grad.full_tensor(
160
+ )
161
  weight_grad_from_unsharded = model_unsharded.fused_add_rms_norm.weight.grad
162
 
163
  assert (x.grad is None) ^ x_requires_grad
164
  assert (residual.grad is None) ^ residual_requires_grad
165
 
166
+ if x_requires_grad:
 
 
 
 
167
  assert_close(x.grad, x_ref.grad)
168
+ if residual_requires_grad:
 
 
169
  assert_close(residual.grad, residual_ref.grad)
170
 
171
  assert_close(weight_grad_from_sharded, weight_grad_from_unsharded)
torch-ext/activation/fused_add_rms_norm_meta.py CHANGED
@@ -4,6 +4,9 @@ import torch
4
  from torch.distributed.tensor._dtensor_spec import DTensorSpec
5
  from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
6
  RuntimeSchemaInfo)
 
 
 
7
  from torch.distributed.tensor._ops.utils import (generate_redistribute_costs,
8
  register_op_strategy)
9
  from torch.distributed.tensor.placement_types import (Placement, Replicate,
@@ -19,17 +22,6 @@ def register_fused_add_rms_norm_meta():
19
  pass
20
 
21
 
22
- def _replicate_dims_start_at(placements: Sequence[Placement],
23
- start_dim: int = 0) -> tuple[Placement, ...]:
24
- new_placements: list[Placement] = []
25
- for p in placements:
26
- if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
27
- new_placements.append(Replicate()) # make it replicate
28
- else:
29
- new_placements.append(p) # keep the placement
30
- return tuple(new_placements)
31
-
32
-
33
  @register_op_strategy(ops.fused_add_rms_norm.default,
34
  schema_info=RuntimeSchemaInfo(1))
35
  def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
@@ -89,7 +81,7 @@ def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
89
  # Weight cannot be sharded, so always replicate it.
90
  weight_tgt = DTensorSpec(
91
  mesh=mesh,
92
- placements=(Replicate(), ),
93
  tensor_meta=weight_src.tensor_meta,
94
  )
95
  redistribute_costs.append(
@@ -141,6 +133,8 @@ def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
141
  )
142
 
143
  last_dim = output_grad_strategy.ndim - 1
 
 
144
  strategy = OpStrategy([])
145
  for output_grad, add_output_grad, add_output, weight in zipped:
146
  output_grad_src = output_grad.output_spec
@@ -179,16 +173,35 @@ def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
179
  # Weight cannot be sharded, so always replicate it.
180
  weight_tgt = DTensorSpec(
181
  mesh=mesh,
182
- placements=(Replicate(), ),
183
  tensor_meta=weight_src.tensor_meta,
184
  )
185
  redistribute_costs.append(
186
  generate_redistribute_costs(weight_strategy, weight_tgt))
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  strategy.strategies.append(
189
  OpSpec(
190
  output_specs=[
191
- output_grad_tgt if need_input_grad else None, weight_tgt
 
192
  ],
193
  input_specs=[
194
  output_grad_tgt, add_output_grad_tgt, add_output_tgt,
 
4
  from torch.distributed.tensor._dtensor_spec import DTensorSpec
5
  from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
6
  RuntimeSchemaInfo)
7
+ from torch.distributed.tensor._ops._math_ops import (
8
+ _infer_reduce_dims_map, _replicate_dims_start_at,
9
+ map_placements_after_reduction)
10
  from torch.distributed.tensor._ops.utils import (generate_redistribute_costs,
11
  register_op_strategy)
12
  from torch.distributed.tensor.placement_types import (Placement, Replicate,
 
22
  pass
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
25
  @register_op_strategy(ops.fused_add_rms_norm.default,
26
  schema_info=RuntimeSchemaInfo(1))
27
  def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
 
81
  # Weight cannot be sharded, so always replicate it.
82
  weight_tgt = DTensorSpec(
83
  mesh=mesh,
84
+ placements=_replicate_dims_start_at(weight_src.placements),
85
  tensor_meta=weight_src.tensor_meta,
86
  )
87
  redistribute_costs.append(
 
133
  )
134
 
135
  last_dim = output_grad_strategy.ndim - 1
136
+ outer_dims = list(range(last_dim))
137
+
138
  strategy = OpStrategy([])
139
  for output_grad, add_output_grad, add_output, weight in zipped:
140
  output_grad_src = output_grad.output_spec
 
173
  # Weight cannot be sharded, so always replicate it.
174
  weight_tgt = DTensorSpec(
175
  mesh=mesh,
176
+ placements=_replicate_dims_start_at(weight_src.placements),
177
  tensor_meta=weight_src.tensor_meta,
178
  )
179
  redistribute_costs.append(
180
  generate_redistribute_costs(weight_strategy, weight_tgt))
181
 
182
+ # from torch/distributed/tensor/_ops/_math_ops.py::layer_norm_bwd_strategy()
183
+
184
+ # d_weight follows the input spec with a sum reduction over the outer dims.
185
+ # TODO: now d_weight spec follows input spec w/ a reduction.
186
+ # we may need to change to a pointwise rule over grad_out and
187
+ # input, then apply a reduction.
188
+ inp_placements = _replicate_dims_start_at(output_grad_src.placements,
189
+ last_dim)
190
+ reduce_dims_map = _infer_reduce_dims_map(outer_dims,
191
+ output_grad_src.ndim, False)
192
+ out_placements = map_placements_after_reduction(
193
+ inp_placements, outer_dims, reduce_dims_map, "sum")
194
+ weight_grad_tgt = DTensorSpec(
195
+ mesh=mesh,
196
+ placements=out_placements,
197
+ tensor_meta=weight_src.tensor_meta,
198
+ )
199
+
200
  strategy.strategies.append(
201
  OpSpec(
202
  output_specs=[
203
+ output_grad_tgt if need_input_grad else None,
204
+ weight_grad_tgt
205
  ],
206
  input_specs=[
207
  output_grad_tgt, add_output_grad_tgt, add_output_tgt,