Kernels
wyldecat commited on
Commit
66b3c5e
·
1 Parent(s): 06d6367

refactor(rms_norm): move RMS normalization logic to a new module for better organization and maintainability

Browse files
torch-ext/activation/rms_norm.py CHANGED
@@ -76,156 +76,5 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
76
 
77
 
78
  if version.parse(torch.__version__) >= version.parse("2.8"):
79
- from torch.distributed.tensor._dtensor_spec import DTensorSpec
80
- from torch.distributed.tensor._op_schema import (OpSchema, OpSpec,
81
- OpStrategy,
82
- RuntimeSchemaInfo)
83
- from torch.distributed.tensor._ops.utils import (
84
- generate_redistribute_costs, register_op_strategy)
85
- from torch.distributed.tensor.placement_types import (Placement, Replicate,
86
- Shard)
87
-
88
- @torch.library.register_fake(ops.rms_norm.default)
89
- def rms_norm_abstract(x, weight, eps):
90
- return torch.empty_like(x)
91
-
92
- @torch.library.register_fake(ops.rms_norm_backward.default)
93
- def rms_norm_backward_abstract(output_grad, x, weight, eps):
94
- return torch.empty_like(x), torch.empty_like(weight)
95
-
96
- def _replicate_dims_start_at(placements: Sequence[Placement],
97
- start_dim: int = 0) -> tuple[Placement, ...]:
98
- new_placements: list[Placement] = []
99
- for p in placements:
100
- if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
101
- new_placements.append(Replicate()) # make it replicate
102
- else:
103
- new_placements.append(p) # keep the placement
104
- return tuple(new_placements)
105
-
106
- @register_op_strategy(ops.rms_norm.default,
107
- schema_info=RuntimeSchemaInfo(1))
108
- def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
109
- mesh = op_schema.get_mesh_from_args()
110
-
111
- assert len(op_schema.args_schema) == 3
112
- (
113
- input_strategy,
114
- weight_strategy,
115
- _, # eps
116
- ) = op_schema.args_schema
117
-
118
- assert isinstance(input_strategy, OpStrategy)
119
- assert isinstance(weight_strategy, OpStrategy)
120
-
121
- assert len(input_strategy.strategies) == len(
122
- weight_strategy.strategies)
123
-
124
- last_dim = input_strategy.ndim - 1
125
- strategy = OpStrategy([])
126
- for idx in range(len(input_strategy.strategies)):
127
- input_src = input_strategy.strategies[idx].output_spec
128
- weight_src = weight_strategy.strategies[idx].output_spec
129
-
130
- assert isinstance(input_src, DTensorSpec)
131
- assert isinstance(weight_src, DTensorSpec)
132
-
133
- redistribute_costs = []
134
-
135
- # Input can be sharded in any dim except the last dim.
136
- input_tgt = DTensorSpec(
137
- mesh=mesh,
138
- placements=_replicate_dims_start_at(input_src.placements,
139
- last_dim),
140
- tensor_meta=input_src.tensor_meta,
141
- )
142
- redistribute_costs.append(
143
- generate_redistribute_costs(input_strategy, input_tgt))
144
-
145
- # Weight cannot be sharded, so always replicate it.
146
- weight_tgt = DTensorSpec(
147
- mesh=mesh,
148
- placements=(Replicate(), ),
149
- tensor_meta=weight_src.tensor_meta,
150
- )
151
- redistribute_costs.append(
152
- generate_redistribute_costs(weight_strategy, weight_tgt))
153
-
154
- strategy.strategies.append(
155
- OpSpec(
156
- output_specs=input_tgt,
157
- input_specs=[input_tgt, weight_tgt],
158
- redistribute_cost=redistribute_costs,
159
- ))
160
- return strategy
161
-
162
- @register_op_strategy(ops.rms_norm_backward.default,
163
- schema_info=RuntimeSchemaInfo(1))
164
- def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
165
- mesh = op_schema.get_mesh_from_args()
166
-
167
- assert len(op_schema.args_schema) == 4
168
- (
169
- output_grad_strategy,
170
- input_strategy,
171
- weight_strategy,
172
- _, # eps
173
- ) = op_schema.args_schema
174
-
175
- assert isinstance(output_grad_strategy, OpStrategy)
176
- assert isinstance(input_strategy, OpStrategy)
177
- assert isinstance(weight_strategy, OpStrategy)
178
-
179
- assert len(input_strategy.strategies) == len(
180
- weight_strategy.strategies)
181
- assert len(input_strategy.strategies) == len(
182
- output_grad_strategy.strategies)
183
-
184
- last_dim = input_strategy.ndim - 1
185
- strategy = OpStrategy([])
186
- for idx in range(len(input_strategy.strategies)):
187
- output_grad_src = output_grad_strategy.strategies[idx].output_spec
188
- input_src = input_strategy.strategies[idx].output_spec
189
- weight_src = weight_strategy.strategies[idx].output_spec
190
-
191
- assert isinstance(output_grad_src, DTensorSpec)
192
- assert isinstance(input_src, DTensorSpec)
193
- assert isinstance(weight_src, DTensorSpec)
194
-
195
- redistribute_costs = []
196
-
197
- # Output grad and input can be sharded in any dim except the last dim.
198
- output_grad_tgt = DTensorSpec(
199
- mesh=mesh,
200
- placements=_replicate_dims_start_at(output_grad_src.placements,
201
- last_dim),
202
- tensor_meta=output_grad_src.tensor_meta,
203
- )
204
- redistribute_costs.append(
205
- generate_redistribute_costs(output_grad_strategy,
206
- output_grad_tgt))
207
- input_tgt = DTensorSpec(
208
- mesh=mesh,
209
- placements=_replicate_dims_start_at(input_src.placements,
210
- last_dim),
211
- tensor_meta=input_src.tensor_meta,
212
- )
213
- redistribute_costs.append(
214
- generate_redistribute_costs(input_strategy, input_tgt))
215
-
216
- # Weight cannot be sharded, so always replicate it.
217
- weight_tgt = DTensorSpec(
218
- mesh=mesh,
219
- placements=(Replicate(), ),
220
- tensor_meta=weight_src.tensor_meta,
221
- )
222
- redistribute_costs.append(
223
- generate_redistribute_costs(weight_strategy, weight_tgt))
224
-
225
- strategy.strategies.append(
226
- OpSpec(
227
- output_specs=[input_tgt, weight_tgt],
228
- input_specs=[output_grad_tgt, input_tgt, weight_tgt],
229
- redistribute_cost=redistribute_costs,
230
- ))
231
- return strategy
 
76
 
77
 
78
  if version.parse(torch.__version__) >= version.parse("2.8"):
79
+ from .rms_norm_meta import register_rms_norm_meta
80
+ register_rms_norm_meta()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch-ext/activation/rms_norm_meta.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.distributed.tensor._dtensor_spec import DTensorSpec
3
+ from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
4
+ RuntimeSchemaInfo)
5
+ from torch.distributed.tensor._ops.utils import (generate_redistribute_costs,
6
+ register_op_strategy)
7
+ from torch.distributed.tensor.placement_types import (Placement, Replicate,
8
+ Shard)
9
+
10
+ from ._ops import ops
11
+
12
+
13
+ def register_rms_norm_meta():
14
+ """Dummy function to register the meta functions.
15
+ Registration happens at import time by the decorators below.
16
+ """
17
+ pass
18
+
19
+
20
+ @torch.library.register_fake(ops.rms_norm.default)
21
+ def rms_norm_abstract(x, weight, eps):
22
+ return torch.empty_like(x)
23
+
24
+
25
+ @torch.library.register_fake(ops.rms_norm_backward.default)
26
+ def rms_norm_backward_abstract(output_grad, x, weight, eps):
27
+ return torch.empty_like(x), torch.empty_like(weight)
28
+
29
+
30
+ def _replicate_dims_start_at(placements: Sequence[Placement],
31
+ start_dim: int = 0) -> tuple[Placement, ...]:
32
+ new_placements: list[Placement] = []
33
+ for p in placements:
34
+ if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
35
+ new_placements.append(Replicate()) # make it replicate
36
+ else:
37
+ new_placements.append(p) # keep the placement
38
+ return tuple(new_placements)
39
+
40
+
41
+ @register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1))
42
+ def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
43
+ mesh = op_schema.get_mesh_from_args()
44
+
45
+ assert len(op_schema.args_schema) == 3
46
+ (
47
+ input_strategy,
48
+ weight_strategy,
49
+ _, # eps
50
+ ) = op_schema.args_schema
51
+
52
+ assert isinstance(input_strategy, OpStrategy)
53
+ assert isinstance(weight_strategy, OpStrategy)
54
+
55
+ assert len(input_strategy.strategies) == len(weight_strategy.strategies)
56
+
57
+ last_dim = input_strategy.ndim - 1
58
+ strategy = OpStrategy([])
59
+ for idx in range(len(input_strategy.strategies)):
60
+ input_src = input_strategy.strategies[idx].output_spec
61
+ weight_src = weight_strategy.strategies[idx].output_spec
62
+
63
+ assert isinstance(input_src, DTensorSpec)
64
+ assert isinstance(weight_src, DTensorSpec)
65
+
66
+ redistribute_costs = []
67
+
68
+ # Input can be sharded in any dim except the last dim.
69
+ input_tgt = DTensorSpec(
70
+ mesh=mesh,
71
+ placements=_replicate_dims_start_at(input_src.placements,
72
+ last_dim),
73
+ tensor_meta=input_src.tensor_meta,
74
+ )
75
+ redistribute_costs.append(
76
+ generate_redistribute_costs(input_strategy, input_tgt))
77
+
78
+ # Weight cannot be sharded, so always replicate it.
79
+ weight_tgt = DTensorSpec(
80
+ mesh=mesh,
81
+ placements=(Replicate(), ),
82
+ tensor_meta=weight_src.tensor_meta,
83
+ )
84
+ redistribute_costs.append(
85
+ generate_redistribute_costs(weight_strategy, weight_tgt))
86
+
87
+ strategy.strategies.append(
88
+ OpSpec(
89
+ output_specs=input_tgt,
90
+ input_specs=[input_tgt, weight_tgt],
91
+ redistribute_cost=redistribute_costs,
92
+ ))
93
+ return strategy
94
+
95
+
96
+ @register_op_strategy(ops.rms_norm_backward.default,
97
+ schema_info=RuntimeSchemaInfo(1))
98
+ def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
99
+ mesh = op_schema.get_mesh_from_args()
100
+
101
+ assert len(op_schema.args_schema) == 4
102
+ (
103
+ output_grad_strategy,
104
+ input_strategy,
105
+ weight_strategy,
106
+ _, # eps
107
+ ) = op_schema.args_schema
108
+
109
+ assert isinstance(output_grad_strategy, OpStrategy)
110
+ assert isinstance(input_strategy, OpStrategy)
111
+ assert isinstance(weight_strategy, OpStrategy)
112
+
113
+ assert len(input_strategy.strategies) == len(weight_strategy.strategies)
114
+ assert len(input_strategy.strategies) == len(
115
+ output_grad_strategy.strategies)
116
+
117
+ last_dim = input_strategy.ndim - 1
118
+ strategy = OpStrategy([])
119
+ for idx in range(len(input_strategy.strategies)):
120
+ output_grad_src = output_grad_strategy.strategies[idx].output_spec
121
+ input_src = input_strategy.strategies[idx].output_spec
122
+ weight_src = weight_strategy.strategies[idx].output_spec
123
+
124
+ assert isinstance(output_grad_src, DTensorSpec)
125
+ assert isinstance(input_src, DTensorSpec)
126
+ assert isinstance(weight_src, DTensorSpec)
127
+
128
+ redistribute_costs = []
129
+
130
+ # Output grad and input can be sharded in any dim except the last dim.
131
+ output_grad_tgt = DTensorSpec(
132
+ mesh=mesh,
133
+ placements=_replicate_dims_start_at(output_grad_src.placements,
134
+ last_dim),
135
+ tensor_meta=output_grad_src.tensor_meta,
136
+ )
137
+ redistribute_costs.append(
138
+ generate_redistribute_costs(output_grad_strategy, output_grad_tgt))
139
+ input_tgt = DTensorSpec(
140
+ mesh=mesh,
141
+ placements=_replicate_dims_start_at(input_src.placements,
142
+ last_dim),
143
+ tensor_meta=input_src.tensor_meta,
144
+ )
145
+ redistribute_costs.append(
146
+ generate_redistribute_costs(input_strategy, input_tgt))
147
+
148
+ # Weight cannot be sharded, so always replicate it.
149
+ weight_tgt = DTensorSpec(
150
+ mesh=mesh,
151
+ placements=(Replicate(), ),
152
+ tensor_meta=weight_src.tensor_meta,
153
+ )
154
+ redistribute_costs.append(
155
+ generate_redistribute_costs(weight_strategy, weight_tgt))
156
+
157
+ strategy.strategies.append(
158
+ OpSpec(
159
+ output_specs=[input_tgt, weight_tgt],
160
+ input_specs=[output_grad_tgt, input_tgt, weight_tgt],
161
+ redistribute_cost=redistribute_costs,
162
+ ))
163
+ return strategy