Kernels
wyldecat committed on
Commit
138159c
·
1 Parent(s): 1da8432

fix: fix rms norm sharding strategy

Browse files
tests/test_rms_norm_sequence_parallel.py CHANGED
@@ -6,6 +6,10 @@ import pytest
6
  import torch
7
  import torch.distributed as dist
8
  from packaging import version
 
 
 
 
9
  from torch.distributed.tensor.placement_types import (Partial, Placement,
10
  Replicate, Shard)
11
 
@@ -13,17 +17,6 @@ import activation
13
 
14
  from .utils import assert_close, opcheck
15
 
16
- DTYPES = [torch.float32]
17
- NUM_TOKENS = [512] # Arbitrary values for testing
18
- SEQUENCE_DIMS = [0, 1] # 0 is for [T, D] (packed), 1 is for [B, S, D]
19
- D = [16] # Arbitrary values for testing
20
- SEEDS = [0]
21
-
22
- from torch.distributed._tensor import DTensor
23
- from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
24
- from torch.distributed.tensor.parallel import (SequenceParallel,
25
- parallelize_module)
26
-
27
 
28
  @pytest.fixture(scope="session", autouse=True)
29
  def init_dist(request):
@@ -58,6 +51,13 @@ class Model(torch.nn.Module):
58
  return self.rms_norm(x)
59
 
60
 
 
 
 
 
 
 
 
61
  @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
62
  @pytest.mark.parametrize("d", D)
63
  @pytest.mark.parametrize("dtype", DTYPES)
@@ -106,12 +106,16 @@ def test_rms_norm_sequence_parallel(
106
  parallelize_module(
107
  model_sharded, mesh,
108
  {"rms_norm": SequenceParallel(sequence_dim=sequence_dim)})
109
- x_sharded = DTensor.from_local(
110
- x.chunk(num_ranks, dim=sequence_dim)[rank].contiguous(),
111
- placements=(Shard(sequence_dim), ),
 
112
  device_mesh=mesh,
113
  )
114
- y = model_sharded(x_sharded)
 
 
 
115
  y_from_sharded = y.full_tensor()
116
 
117
  model_unsharded = Model(num_tokens, d).to(dtype=dtype).cuda()
@@ -123,15 +127,11 @@ def test_rms_norm_sequence_parallel(
123
 
124
  # Backward
125
  y_grad = torch.randn_like(y_from_unsharded)
126
- y_from_sharded.backward(y_grad)
127
  y_from_unsharded.backward(y_grad)
 
128
 
129
- weight_grad_from_sharded = model_sharded.rms_norm.weight.grad._local_tensor
130
  weight_grad_from_unsharded = model_unsharded.rms_norm.weight.grad
131
 
132
- torch.distributed.all_reduce(x.grad, op=torch.distributed.ReduceOp.SUM)
133
- torch.distributed.all_reduce(weight_grad_from_sharded,
134
- op=torch.distributed.ReduceOp.SUM)
135
-
136
  assert_close(x.grad, x_ref.grad)
137
  assert_close(weight_grad_from_sharded, weight_grad_from_unsharded)
 
6
  import torch
7
  import torch.distributed as dist
8
  from packaging import version
9
+ from torch.distributed._tensor import DTensor
10
+ from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
11
+ from torch.distributed.tensor.parallel import (SequenceParallel,
12
+ parallelize_module)
13
  from torch.distributed.tensor.placement_types import (Partial, Placement,
14
  Replicate, Shard)
15
 
 
17
 
18
  from .utils import assert_close, opcheck
19
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  @pytest.fixture(scope="session", autouse=True)
22
  def init_dist(request):
 
51
  return self.rms_norm(x)
52
 
53
 
54
+ DTYPES = [torch.float32]
55
+ NUM_TOKENS = [512] # Arbitrary values for testing
56
+ SEQUENCE_DIMS = [0, 1] # 0 is for [T, D] (packed), 1 is for [B, S, D]
57
+ D = [16] # Arbitrary values for testing
58
+ SEEDS = [0]
59
+
60
+
61
  @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
62
  @pytest.mark.parametrize("d", D)
63
  @pytest.mark.parametrize("dtype", DTYPES)
 
106
  parallelize_module(
107
  model_sharded, mesh,
108
  {"rms_norm": SequenceParallel(sequence_dim=sequence_dim)})
109
+
110
+ x_replicate = DTensor.from_local(
111
+ x,
112
+ placements=(Replicate(), ),
113
  device_mesh=mesh,
114
  )
115
+
116
+ # Input will be redistributed in SequenceParallel
117
+ y = model_sharded(x_replicate)
118
+
119
  y_from_sharded = y.full_tensor()
120
 
121
  model_unsharded = Model(num_tokens, d).to(dtype=dtype).cuda()
 
127
 
128
  # Backward
129
  y_grad = torch.randn_like(y_from_unsharded)
 
130
  y_from_unsharded.backward(y_grad)
131
+ y_from_sharded.backward(y_grad)
132
 
133
+ weight_grad_from_sharded = model_sharded.rms_norm.weight.grad.full_tensor()
134
  weight_grad_from_unsharded = model_unsharded.rms_norm.weight.grad
135
 
 
 
 
 
136
  assert_close(x.grad, x_ref.grad)
137
  assert_close(weight_grad_from_sharded, weight_grad_from_unsharded)
torch-ext/activation/rms_norm.py CHANGED
@@ -29,7 +29,6 @@ class RMSNormFunction(torch.autograd.Function):
29
 
30
  input_grad, weight_grad = ops.rms_norm_backward(
31
  output_grad, input, weight, eps)
32
-
33
  return input_grad, weight_grad, None
34
 
35
 
 
29
 
30
  input_grad, weight_grad = ops.rms_norm_backward(
31
  output_grad, input, weight, eps)
 
32
  return input_grad, weight_grad, None
33
 
34
 
torch-ext/activation/rms_norm_meta.py CHANGED
@@ -4,6 +4,9 @@ import torch
4
  from torch.distributed.tensor._dtensor_spec import DTensorSpec
5
  from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
6
  RuntimeSchemaInfo)
 
 
 
7
  from torch.distributed.tensor._ops.utils import (generate_redistribute_costs,
8
  register_op_strategy)
9
  from torch.distributed.tensor.placement_types import (Placement, Replicate,
@@ -19,17 +22,6 @@ def register_rms_norm_meta():
19
  pass
20
 
21
 
22
- def _replicate_dims_start_at(placements: Sequence[Placement],
23
- start_dim: int = 0) -> tuple[Placement, ...]:
24
- new_placements: list[Placement] = []
25
- for p in placements:
26
- if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
27
- new_placements.append(Replicate()) # make it replicate
28
- else:
29
- new_placements.append(p) # keep the placement
30
- return tuple(new_placements)
31
-
32
-
33
  @register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1))
34
  def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
35
  mesh = op_schema.get_mesh_from_args()
@@ -71,7 +63,7 @@ def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
71
  # Weight cannot be sharded, so always replicate it.
72
  weight_tgt = DTensorSpec(
73
  mesh=mesh,
74
- placements=(Replicate(), ),
75
  tensor_meta=weight_src.tensor_meta,
76
  )
77
  redistribute_costs.append(
@@ -119,6 +111,8 @@ def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
119
  )
120
 
121
  last_dim = input_strategy.ndim - 1
 
 
122
  strategy = OpStrategy([])
123
  for output_grad, input, weight in zipped:
124
  output_grad_src = output_grad.output_spec
@@ -134,7 +128,7 @@ def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
134
  # Output grad can be sharded in any dim except the last dim.
135
  output_grad_tgt = DTensorSpec(
136
  mesh=mesh,
137
- placements=_replicate_dims_start_at(output_grad_src.placements,
138
  last_dim),
139
  tensor_meta=output_grad_src.tensor_meta,
140
  )
@@ -142,22 +136,48 @@ def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
142
  generate_redistribute_costs(output_grad_strategy, output_grad_tgt))
143
 
144
  # Input must have the same sharding as output grad.
145
- input_tgt = output_grad_tgt
 
 
 
 
 
 
146
  redistribute_costs.append(
147
  generate_redistribute_costs(input_strategy, input_tgt))
148
 
149
  # Weight cannot be sharded, so always replicate it.
150
  weight_tgt = DTensorSpec(
151
  mesh=mesh,
152
- placements=(Replicate(), ),
153
  tensor_meta=weight_src.tensor_meta,
154
  )
155
  redistribute_costs.append(
156
  generate_redistribute_costs(weight_strategy, weight_tgt))
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  strategy.strategies.append(
159
  OpSpec(
160
- output_specs=[input_tgt, weight_tgt],
161
  input_specs=[output_grad_tgt, input_tgt, weight_tgt],
162
  redistribute_cost=redistribute_costs,
163
  ))
 
4
  from torch.distributed.tensor._dtensor_spec import DTensorSpec
5
  from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
6
  RuntimeSchemaInfo)
7
+ from torch.distributed.tensor._ops._math_ops import (
8
+ _infer_reduce_dims_map, _replicate_dims_start_at,
9
+ map_placements_after_reduction)
10
  from torch.distributed.tensor._ops.utils import (generate_redistribute_costs,
11
  register_op_strategy)
12
  from torch.distributed.tensor.placement_types import (Placement, Replicate,
 
22
  pass
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
25
  @register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1))
26
  def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
27
  mesh = op_schema.get_mesh_from_args()
 
63
  # Weight cannot be sharded, so always replicate it.
64
  weight_tgt = DTensorSpec(
65
  mesh=mesh,
66
+ placements=_replicate_dims_start_at(weight_src.placements),
67
  tensor_meta=weight_src.tensor_meta,
68
  )
69
  redistribute_costs.append(
 
111
  )
112
 
113
  last_dim = input_strategy.ndim - 1
114
+ outer_dims = list(range(last_dim))
115
+
116
  strategy = OpStrategy([])
117
  for output_grad, input, weight in zipped:
118
  output_grad_src = output_grad.output_spec
 
128
  # Output grad can be sharded in any dim except the last dim.
129
  output_grad_tgt = DTensorSpec(
130
  mesh=mesh,
131
+ placements=_replicate_dims_start_at(input_src.placements,
132
  last_dim),
133
  tensor_meta=output_grad_src.tensor_meta,
134
  )
 
136
  generate_redistribute_costs(output_grad_strategy, output_grad_tgt))
137
 
138
  # Input must have the same sharding as output grad.
139
+ input_tgt = DTensorSpec(
140
+ mesh=mesh,
141
+ placements=_replicate_dims_start_at(input_src.placements,
142
+ last_dim),
143
+ tensor_meta=input_src.tensor_meta,
144
+ )
145
+
146
  redistribute_costs.append(
147
  generate_redistribute_costs(input_strategy, input_tgt))
148
 
149
  # Weight cannot be sharded, so always replicate it.
150
  weight_tgt = DTensorSpec(
151
  mesh=mesh,
152
+ placements=_replicate_dims_start_at(weight_src.placements),
153
  tensor_meta=weight_src.tensor_meta,
154
  )
155
  redistribute_costs.append(
156
  generate_redistribute_costs(weight_strategy, weight_tgt))
157
 
158
+ # from torch/distributed/tensor/_ops/_math_ops.py::layer_norm_bwd_strategy()
159
+
160
+ # Weight cannot be sharded, so always replicate it.
161
+ # TODO: now d_weight spec follows input spec w/ a reduction.
162
+ # we may need to change to a pointwise rule over grad_out and
163
+ # input, then apply a reduction.
164
+ inp_placements = _replicate_dims_start_at(input_src.placements,
165
+ last_dim)
166
+ reduce_dims_map = _infer_reduce_dims_map(outer_dims, input_src.ndim,
167
+ False)
168
+ out_placements = map_placements_after_reduction(
169
+ inp_placements, outer_dims, reduce_dims_map, "sum")
170
+ weight_grad_tgt = DTensorSpec(
171
+ mesh=mesh,
172
+ placements=out_placements,
173
+ tensor_meta=weight_src.tensor_meta,
174
+ )
175
+
176
+ input_grad_tgt = output_grad_tgt
177
+
178
  strategy.strategies.append(
179
  OpSpec(
180
+ output_specs=[input_grad_tgt, weight_grad_tgt],
181
  input_specs=[output_grad_tgt, input_tgt, weight_tgt],
182
  redistribute_cost=redistribute_costs,
183
  ))