refactor(rms_norm): move RMS normalization logic to a new module for better organization and maintainability
Browse files
torch-ext/activation/rms_norm.py
CHANGED
|
@@ -76,156 +76,5 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
|
|
| 76 |
|
| 77 |
|
| 78 |
if version.parse(torch.__version__) >= version.parse("2.8"):
|
| 79 |
-
from
|
| 80 |
-
|
| 81 |
-
OpStrategy,
|
| 82 |
-
RuntimeSchemaInfo)
|
| 83 |
-
from torch.distributed.tensor._ops.utils import (
|
| 84 |
-
generate_redistribute_costs, register_op_strategy)
|
| 85 |
-
from torch.distributed.tensor.placement_types import (Placement, Replicate,
|
| 86 |
-
Shard)
|
| 87 |
-
|
| 88 |
-
@torch.library.register_fake(ops.rms_norm.default)
|
| 89 |
-
def rms_norm_abstract(x, weight, eps):
|
| 90 |
-
return torch.empty_like(x)
|
| 91 |
-
|
| 92 |
-
@torch.library.register_fake(ops.rms_norm_backward.default)
|
| 93 |
-
def rms_norm_backward_abstract(output_grad, x, weight, eps):
|
| 94 |
-
return torch.empty_like(x), torch.empty_like(weight)
|
| 95 |
-
|
| 96 |
-
def _replicate_dims_start_at(placements: Sequence[Placement],
|
| 97 |
-
start_dim: int = 0) -> tuple[Placement, ...]:
|
| 98 |
-
new_placements: list[Placement] = []
|
| 99 |
-
for p in placements:
|
| 100 |
-
if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
|
| 101 |
-
new_placements.append(Replicate()) # make it replicate
|
| 102 |
-
else:
|
| 103 |
-
new_placements.append(p) # keep the placement
|
| 104 |
-
return tuple(new_placements)
|
| 105 |
-
|
| 106 |
-
@register_op_strategy(ops.rms_norm.default,
|
| 107 |
-
schema_info=RuntimeSchemaInfo(1))
|
| 108 |
-
def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
|
| 109 |
-
mesh = op_schema.get_mesh_from_args()
|
| 110 |
-
|
| 111 |
-
assert len(op_schema.args_schema) == 3
|
| 112 |
-
(
|
| 113 |
-
input_strategy,
|
| 114 |
-
weight_strategy,
|
| 115 |
-
_, # eps
|
| 116 |
-
) = op_schema.args_schema
|
| 117 |
-
|
| 118 |
-
assert isinstance(input_strategy, OpStrategy)
|
| 119 |
-
assert isinstance(weight_strategy, OpStrategy)
|
| 120 |
-
|
| 121 |
-
assert len(input_strategy.strategies) == len(
|
| 122 |
-
weight_strategy.strategies)
|
| 123 |
-
|
| 124 |
-
last_dim = input_strategy.ndim - 1
|
| 125 |
-
strategy = OpStrategy([])
|
| 126 |
-
for idx in range(len(input_strategy.strategies)):
|
| 127 |
-
input_src = input_strategy.strategies[idx].output_spec
|
| 128 |
-
weight_src = weight_strategy.strategies[idx].output_spec
|
| 129 |
-
|
| 130 |
-
assert isinstance(input_src, DTensorSpec)
|
| 131 |
-
assert isinstance(weight_src, DTensorSpec)
|
| 132 |
-
|
| 133 |
-
redistribute_costs = []
|
| 134 |
-
|
| 135 |
-
# Input can be sharded in any dim except the last dim.
|
| 136 |
-
input_tgt = DTensorSpec(
|
| 137 |
-
mesh=mesh,
|
| 138 |
-
placements=_replicate_dims_start_at(input_src.placements,
|
| 139 |
-
last_dim),
|
| 140 |
-
tensor_meta=input_src.tensor_meta,
|
| 141 |
-
)
|
| 142 |
-
redistribute_costs.append(
|
| 143 |
-
generate_redistribute_costs(input_strategy, input_tgt))
|
| 144 |
-
|
| 145 |
-
# Weight cannot be sharded, so always replicate it.
|
| 146 |
-
weight_tgt = DTensorSpec(
|
| 147 |
-
mesh=mesh,
|
| 148 |
-
placements=(Replicate(), ),
|
| 149 |
-
tensor_meta=weight_src.tensor_meta,
|
| 150 |
-
)
|
| 151 |
-
redistribute_costs.append(
|
| 152 |
-
generate_redistribute_costs(weight_strategy, weight_tgt))
|
| 153 |
-
|
| 154 |
-
strategy.strategies.append(
|
| 155 |
-
OpSpec(
|
| 156 |
-
output_specs=input_tgt,
|
| 157 |
-
input_specs=[input_tgt, weight_tgt],
|
| 158 |
-
redistribute_cost=redistribute_costs,
|
| 159 |
-
))
|
| 160 |
-
return strategy
|
| 161 |
-
|
| 162 |
-
@register_op_strategy(ops.rms_norm_backward.default,
|
| 163 |
-
schema_info=RuntimeSchemaInfo(1))
|
| 164 |
-
def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
|
| 165 |
-
mesh = op_schema.get_mesh_from_args()
|
| 166 |
-
|
| 167 |
-
assert len(op_schema.args_schema) == 4
|
| 168 |
-
(
|
| 169 |
-
output_grad_strategy,
|
| 170 |
-
input_strategy,
|
| 171 |
-
weight_strategy,
|
| 172 |
-
_, # eps
|
| 173 |
-
) = op_schema.args_schema
|
| 174 |
-
|
| 175 |
-
assert isinstance(output_grad_strategy, OpStrategy)
|
| 176 |
-
assert isinstance(input_strategy, OpStrategy)
|
| 177 |
-
assert isinstance(weight_strategy, OpStrategy)
|
| 178 |
-
|
| 179 |
-
assert len(input_strategy.strategies) == len(
|
| 180 |
-
weight_strategy.strategies)
|
| 181 |
-
assert len(input_strategy.strategies) == len(
|
| 182 |
-
output_grad_strategy.strategies)
|
| 183 |
-
|
| 184 |
-
last_dim = input_strategy.ndim - 1
|
| 185 |
-
strategy = OpStrategy([])
|
| 186 |
-
for idx in range(len(input_strategy.strategies)):
|
| 187 |
-
output_grad_src = output_grad_strategy.strategies[idx].output_spec
|
| 188 |
-
input_src = input_strategy.strategies[idx].output_spec
|
| 189 |
-
weight_src = weight_strategy.strategies[idx].output_spec
|
| 190 |
-
|
| 191 |
-
assert isinstance(output_grad_src, DTensorSpec)
|
| 192 |
-
assert isinstance(input_src, DTensorSpec)
|
| 193 |
-
assert isinstance(weight_src, DTensorSpec)
|
| 194 |
-
|
| 195 |
-
redistribute_costs = []
|
| 196 |
-
|
| 197 |
-
# Output grad and input can be sharded in any dim except the last dim.
|
| 198 |
-
output_grad_tgt = DTensorSpec(
|
| 199 |
-
mesh=mesh,
|
| 200 |
-
placements=_replicate_dims_start_at(output_grad_src.placements,
|
| 201 |
-
last_dim),
|
| 202 |
-
tensor_meta=output_grad_src.tensor_meta,
|
| 203 |
-
)
|
| 204 |
-
redistribute_costs.append(
|
| 205 |
-
generate_redistribute_costs(output_grad_strategy,
|
| 206 |
-
output_grad_tgt))
|
| 207 |
-
input_tgt = DTensorSpec(
|
| 208 |
-
mesh=mesh,
|
| 209 |
-
placements=_replicate_dims_start_at(input_src.placements,
|
| 210 |
-
last_dim),
|
| 211 |
-
tensor_meta=input_src.tensor_meta,
|
| 212 |
-
)
|
| 213 |
-
redistribute_costs.append(
|
| 214 |
-
generate_redistribute_costs(input_strategy, input_tgt))
|
| 215 |
-
|
| 216 |
-
# Weight cannot be sharded, so always replicate it.
|
| 217 |
-
weight_tgt = DTensorSpec(
|
| 218 |
-
mesh=mesh,
|
| 219 |
-
placements=(Replicate(), ),
|
| 220 |
-
tensor_meta=weight_src.tensor_meta,
|
| 221 |
-
)
|
| 222 |
-
redistribute_costs.append(
|
| 223 |
-
generate_redistribute_costs(weight_strategy, weight_tgt))
|
| 224 |
-
|
| 225 |
-
strategy.strategies.append(
|
| 226 |
-
OpSpec(
|
| 227 |
-
output_specs=[input_tgt, weight_tgt],
|
| 228 |
-
input_specs=[output_grad_tgt, input_tgt, weight_tgt],
|
| 229 |
-
redistribute_cost=redistribute_costs,
|
| 230 |
-
))
|
| 231 |
-
return strategy
|
|
|
|
| 76 |
|
| 77 |
|
| 78 |
if version.parse(torch.__version__) >= version.parse("2.8"):
|
| 79 |
+
from .rms_norm_meta import register_rms_norm_meta
|
| 80 |
+
register_rms_norm_meta()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch-ext/activation/rms_norm_meta.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.distributed.tensor._dtensor_spec import DTensorSpec
|
| 3 |
+
from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
|
| 4 |
+
RuntimeSchemaInfo)
|
| 5 |
+
from torch.distributed.tensor._ops.utils import (generate_redistribute_costs,
|
| 6 |
+
register_op_strategy)
|
| 7 |
+
from torch.distributed.tensor.placement_types import (Placement, Replicate,
|
| 8 |
+
Shard)
|
| 9 |
+
|
| 10 |
+
from ._ops import ops
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def register_rms_norm_meta():
|
| 14 |
+
"""Dummy function to register the meta functions.
|
| 15 |
+
Registration happens at import time by the decorators below.
|
| 16 |
+
"""
|
| 17 |
+
pass
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@torch.library.register_fake(ops.rms_norm.default)
|
| 21 |
+
def rms_norm_abstract(x, weight, eps):
|
| 22 |
+
return torch.empty_like(x)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@torch.library.register_fake(ops.rms_norm_backward.default)
|
| 26 |
+
def rms_norm_backward_abstract(output_grad, x, weight, eps):
|
| 27 |
+
return torch.empty_like(x), torch.empty_like(weight)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _replicate_dims_start_at(placements: Sequence[Placement],
|
| 31 |
+
start_dim: int = 0) -> tuple[Placement, ...]:
|
| 32 |
+
new_placements: list[Placement] = []
|
| 33 |
+
for p in placements:
|
| 34 |
+
if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
|
| 35 |
+
new_placements.append(Replicate()) # make it replicate
|
| 36 |
+
else:
|
| 37 |
+
new_placements.append(p) # keep the placement
|
| 38 |
+
return tuple(new_placements)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1))
|
| 42 |
+
def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
|
| 43 |
+
mesh = op_schema.get_mesh_from_args()
|
| 44 |
+
|
| 45 |
+
assert len(op_schema.args_schema) == 3
|
| 46 |
+
(
|
| 47 |
+
input_strategy,
|
| 48 |
+
weight_strategy,
|
| 49 |
+
_, # eps
|
| 50 |
+
) = op_schema.args_schema
|
| 51 |
+
|
| 52 |
+
assert isinstance(input_strategy, OpStrategy)
|
| 53 |
+
assert isinstance(weight_strategy, OpStrategy)
|
| 54 |
+
|
| 55 |
+
assert len(input_strategy.strategies) == len(weight_strategy.strategies)
|
| 56 |
+
|
| 57 |
+
last_dim = input_strategy.ndim - 1
|
| 58 |
+
strategy = OpStrategy([])
|
| 59 |
+
for idx in range(len(input_strategy.strategies)):
|
| 60 |
+
input_src = input_strategy.strategies[idx].output_spec
|
| 61 |
+
weight_src = weight_strategy.strategies[idx].output_spec
|
| 62 |
+
|
| 63 |
+
assert isinstance(input_src, DTensorSpec)
|
| 64 |
+
assert isinstance(weight_src, DTensorSpec)
|
| 65 |
+
|
| 66 |
+
redistribute_costs = []
|
| 67 |
+
|
| 68 |
+
# Input can be sharded in any dim except the last dim.
|
| 69 |
+
input_tgt = DTensorSpec(
|
| 70 |
+
mesh=mesh,
|
| 71 |
+
placements=_replicate_dims_start_at(input_src.placements,
|
| 72 |
+
last_dim),
|
| 73 |
+
tensor_meta=input_src.tensor_meta,
|
| 74 |
+
)
|
| 75 |
+
redistribute_costs.append(
|
| 76 |
+
generate_redistribute_costs(input_strategy, input_tgt))
|
| 77 |
+
|
| 78 |
+
# Weight cannot be sharded, so always replicate it.
|
| 79 |
+
weight_tgt = DTensorSpec(
|
| 80 |
+
mesh=mesh,
|
| 81 |
+
placements=(Replicate(), ),
|
| 82 |
+
tensor_meta=weight_src.tensor_meta,
|
| 83 |
+
)
|
| 84 |
+
redistribute_costs.append(
|
| 85 |
+
generate_redistribute_costs(weight_strategy, weight_tgt))
|
| 86 |
+
|
| 87 |
+
strategy.strategies.append(
|
| 88 |
+
OpSpec(
|
| 89 |
+
output_specs=input_tgt,
|
| 90 |
+
input_specs=[input_tgt, weight_tgt],
|
| 91 |
+
redistribute_cost=redistribute_costs,
|
| 92 |
+
))
|
| 93 |
+
return strategy
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@register_op_strategy(ops.rms_norm_backward.default,
|
| 97 |
+
schema_info=RuntimeSchemaInfo(1))
|
| 98 |
+
def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy:
|
| 99 |
+
mesh = op_schema.get_mesh_from_args()
|
| 100 |
+
|
| 101 |
+
assert len(op_schema.args_schema) == 4
|
| 102 |
+
(
|
| 103 |
+
output_grad_strategy,
|
| 104 |
+
input_strategy,
|
| 105 |
+
weight_strategy,
|
| 106 |
+
_, # eps
|
| 107 |
+
) = op_schema.args_schema
|
| 108 |
+
|
| 109 |
+
assert isinstance(output_grad_strategy, OpStrategy)
|
| 110 |
+
assert isinstance(input_strategy, OpStrategy)
|
| 111 |
+
assert isinstance(weight_strategy, OpStrategy)
|
| 112 |
+
|
| 113 |
+
assert len(input_strategy.strategies) == len(weight_strategy.strategies)
|
| 114 |
+
assert len(input_strategy.strategies) == len(
|
| 115 |
+
output_grad_strategy.strategies)
|
| 116 |
+
|
| 117 |
+
last_dim = input_strategy.ndim - 1
|
| 118 |
+
strategy = OpStrategy([])
|
| 119 |
+
for idx in range(len(input_strategy.strategies)):
|
| 120 |
+
output_grad_src = output_grad_strategy.strategies[idx].output_spec
|
| 121 |
+
input_src = input_strategy.strategies[idx].output_spec
|
| 122 |
+
weight_src = weight_strategy.strategies[idx].output_spec
|
| 123 |
+
|
| 124 |
+
assert isinstance(output_grad_src, DTensorSpec)
|
| 125 |
+
assert isinstance(input_src, DTensorSpec)
|
| 126 |
+
assert isinstance(weight_src, DTensorSpec)
|
| 127 |
+
|
| 128 |
+
redistribute_costs = []
|
| 129 |
+
|
| 130 |
+
# Output grad and input can be sharded in any dim except the last dim.
|
| 131 |
+
output_grad_tgt = DTensorSpec(
|
| 132 |
+
mesh=mesh,
|
| 133 |
+
placements=_replicate_dims_start_at(output_grad_src.placements,
|
| 134 |
+
last_dim),
|
| 135 |
+
tensor_meta=output_grad_src.tensor_meta,
|
| 136 |
+
)
|
| 137 |
+
redistribute_costs.append(
|
| 138 |
+
generate_redistribute_costs(output_grad_strategy, output_grad_tgt))
|
| 139 |
+
input_tgt = DTensorSpec(
|
| 140 |
+
mesh=mesh,
|
| 141 |
+
placements=_replicate_dims_start_at(input_src.placements,
|
| 142 |
+
last_dim),
|
| 143 |
+
tensor_meta=input_src.tensor_meta,
|
| 144 |
+
)
|
| 145 |
+
redistribute_costs.append(
|
| 146 |
+
generate_redistribute_costs(input_strategy, input_tgt))
|
| 147 |
+
|
| 148 |
+
# Weight cannot be sharded, so always replicate it.
|
| 149 |
+
weight_tgt = DTensorSpec(
|
| 150 |
+
mesh=mesh,
|
| 151 |
+
placements=(Replicate(), ),
|
| 152 |
+
tensor_meta=weight_src.tensor_meta,
|
| 153 |
+
)
|
| 154 |
+
redistribute_costs.append(
|
| 155 |
+
generate_redistribute_costs(weight_strategy, weight_tgt))
|
| 156 |
+
|
| 157 |
+
strategy.strategies.append(
|
| 158 |
+
OpSpec(
|
| 159 |
+
output_specs=[input_tgt, weight_tgt],
|
| 160 |
+
input_specs=[output_grad_tgt, input_tgt, weight_tgt],
|
| 161 |
+
redistribute_cost=redistribute_costs,
|
| 162 |
+
))
|
| 163 |
+
return strategy
|