diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py index 938feeff791794d011fec65cf86df957e2c4da2f..0f6f29ac2c688bd09afa41c5d1abd9942c4456d8 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py @@ -1,6 +1,6 @@ import torch -from . import layers +from . import layers, parallel_style from ._ops import ops from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -48,5 +48,6 @@ __all__ = [ "rms_norm", "fused_add_rms_norm", "layers", + "parallel_style", "ops", ] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_20250907180255.abi3.so deleted file mode 100644 index 1b3674d54c044dddf2d037c1d3bac522bc19440c..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_20250907180255.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d21a85bf21aa74f1281541e658acfd4f4326d902efe3578b059eccf054443284 -size 8089696 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..3a932d27ddea2e33fe525e6b4967c495185ed4e6 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:80267a0391fa4cb22aa3eb04b05d8214c2bfaed968b714185bc20214596072e3 +size 8618232 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so deleted file mode 100644 index 5a1e5a3587679a157ba7b067d28d762c6577fb8f..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ec9ea7edc8b27f7983e20d615ab470cef6b82975afc214becfddfd05a867a839 -size 8600336 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so deleted file mode 100644 index f3a874e78aac8a38f35e3d3aa4d26c892c9a0d66..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:bd84c828d4c15e96d65d6c8f0eb7a945ee8167d92e978b2ebce03eeaf41e7fce -size 4405112 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py index fa68616c13166de47619ed052ed1eba664998b82..3725c2b21e803832098265d4704e789c837084ef 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_e5e2eeb_dirty -ops = torch.ops._activation_e5e2eeb_dirty +from . import _activation_53ed492_dirty +ops = torch.ops._activation_53ed492_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file + return f"_activation_53ed492_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/fused_add_rms_norm_meta.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/fused_add_rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a472844644bb93a27ae962cbc0fdc50c27ec780a --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/fused_add_rms_norm_meta.py @@ -0,0 +1,199 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_fused_add_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.fused_add_rms_norm.default, + schema_info=RuntimeSchemaInfo(1)) +def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + residual_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(residual_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "input": len(input_strategy.strategies), + "residual": len(residual_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, residual, weight in zip(input_strategy.strategies, + residual_strategy.strategies, + weight_strategy.strategies): + + input_src = input.output_spec + residual_src = residual.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(residual_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Residual add must have the same sharding as input. + residual_tgt = input_tgt + redistribute_costs.append( + generate_redistribute_costs(residual_strategy, residual_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, input_tgt], + input_specs=[input_tgt, residual_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.fused_add_rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(2)) +def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 6 + ( + output_grad_strategy, + add_output_grad_strategy, + add_output_strategy, + weight_strategy, + _, # eps + need_input_grad, # need_input_grad + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(add_output_grad_strategy, OpStrategy) + assert isinstance(add_output_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "add_output_grad": len(add_output_grad_strategy.strategies), + "add_output": len(add_output_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + add_output_grad_strategy.strategies, + add_output_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = output_grad_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, add_output_grad, add_output, weight in zipped: + output_grad_src = output_grad.output_spec + add_output_grad_src = add_output_grad.output_spec + add_output_src = add_output.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(add_output_grad_src, DTensorSpec) + assert isinstance(add_output_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # add_output_grad must have the same sharding as output_grad. + add_output_grad_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_grad_strategy, + add_output_grad_tgt)) + + # add_output must have the same sharding as output_grad. + add_output_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_strategy, add_output_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[ + output_grad_tgt if need_input_grad else None, weight_tgt + ], + input_specs=[ + output_grad_tgt, add_output_grad_tgt, add_output_tgt, + weight_tgt + ], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/parallel_style.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/parallel_style.py new file mode 100644 index 0000000000000000000000000000000000000000..470ab69d9889284f0be5cb075d5211eab30eb755 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/parallel_style.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard, + distribute_module, distribute_tensor) +from torch.distributed.tensor.parallel import SequenceParallel +from torch.distributed.tensor.placement_types import Placement + + +class ResidualSequenceParallel(SequenceParallel): + """ Consider the case where we have a residual connection across a sequence parallel layer.""" + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + residual_tensor = inputs[1] + + assert isinstance(input_tensor, + DTensor) == isinstance(residual_tensor, DTensor) + assert isinstance(input_tensor, + torch.Tensor) == isinstance(residual_tensor, + torch.Tensor) + + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True) + if residual_tensor.placements != sequence_sharding: + residual_tensor = residual_tensor.redistribute( + placements=sequence_sharding, async_op=True) + return input_tensor, residual_tensor + + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local(input_tensor, + device_mesh, + sequence_sharding, + run_check=False), DTensor.from_local( + residual_tensor, + device_mesh, + sequence_sharding, + run_check=False) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py index 0e2c29e955b87025e63f4795d58a14104318f736..2b3ab7e1476aba5d7799ff888449470e23665676 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py @@ -1,4 +1,7 @@ +from collections.abc import Sequence + import torch +from packaging import version from ._ops import ops @@ -8,9 +11,7 @@ class RMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, weight, eps): - output = torch.empty_like(input) - ops.rms_norm(output, input, weight, eps) - return output + return ops.rms_norm(input, weight, eps) @staticmethod # inputs is a Tuple of all of the inputs passed to forward. @@ -26,13 +27,8 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like( - input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[1] else None - - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, - weight, eps) + input_grad, weight_grad = ops.rms_norm_backward( + output_grad, input, weight, eps) return input_grad, weight_grad, None @@ -42,10 +38,8 @@ class FusedAddRMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, residual, weight, eps): - output = torch.empty_like(input) - add_output = torch.empty_like(input) - ops.fused_add_rms_norm(output, add_output, input, residual, weight, - eps) + output, add_output = ops.fused_add_rms_norm(input, residual, weight, + eps) return output, add_output @staticmethod @@ -65,14 +59,47 @@ class FusedAddRMSNormFunction(torch.autograd.Function): need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] - grad = torch.empty_like(output_grad) if need_in or need_res else None + # TODO(ai-system): kernels currently do not support no input gradients + assert need_in or need_res, "Not implemented for no input gradients yet" - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[2] else None - - ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, - weight, eps) + grad, weight_grad = ops.fused_add_rms_norm_backward( + output_grad, + add_output_grad, + add_output, + weight, + eps, + need_input_grad=need_in or need_res) input_grad = grad if need_in else None residual_grad = grad if need_res else None return input_grad, residual_grad, weight_grad, None + + +@torch.library.register_fake(ops.rms_norm.default) +def rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + + +@torch.library.register_fake(ops.rms_norm_backward.default) +def rms_norm_backward_abstract(output_grad, x, weight, eps): + return torch.empty_like(x), torch.empty_like(weight) + + +@torch.library.register_fake(ops.fused_add_rms_norm.default) +def fused_add_rms_norm_abstract(x, residual, weight, eps): + return torch.empty_like(x), torch.empty_like(x) + + +@torch.library.register_fake(ops.fused_add_rms_norm_backward.default) +def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad, + add_output, weight, eps, + need_input_grad: bool): + return torch.empty_like( + output_grad) if need_input_grad else None, torch.empty_like(weight) + + +if version.parse(torch.__version__) >= version.parse("2.8"): + from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta + from .rms_norm_meta import register_rms_norm_meta + register_fused_add_rms_norm_meta() + register_rms_norm_meta() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm_meta.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..12527aef0e055c0836752a9dda814c4ce6f24832 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm_meta.py @@ -0,0 +1,164 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1)) +def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 3 + ( + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + assert len(input_strategy.strategies) == len(weight_strategy.strategies) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, weight in zip(input_strategy.strategies, + weight_strategy.strategies): + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=input_tgt, + input_specs=[input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(1)) +def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + output_grad_strategy, + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "input": len(input_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + + assert len(set( + lengths.values())) == 1, f"Strategies length mismatch {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + input_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, input, weight in zipped: + output_grad_src = output_grad.output_spec + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # Input must have the same sharding as output grad. + input_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, weight_tgt], + input_specs=[output_grad_tgt, input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py index 938feeff791794d011fec65cf86df957e2c4da2f..0f6f29ac2c688bd09afa41c5d1abd9942c4456d8 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py @@ -1,6 +1,6 @@ import torch -from . import layers +from . import layers, parallel_style from ._ops import ops from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -48,5 +48,6 @@ __all__ = [ "rms_norm", "fused_add_rms_norm", "layers", + "parallel_style", "ops", ] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so deleted file mode 100644 index df3c3ae7785a3c30c36d900923c1dd7a349448db..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:74d4955271509451b946495da75f69a0f978e7258b8303fe3c077e585c0d3e6a -size 8272456 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..78a31eaa048eb33b56dd4c5c506be5821382d51e --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef6e4eb51daac20f0d7ed9825052ecca9d8451825784c87d58fa69092c145f35 +size 8793008 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so deleted file mode 100644 index 30ab86df7c79038bc40bcd1292a2fa606b44ebc1..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5d3511410cdc288d2fafc500223ed2e625e360f50fa341809cf892fb2c822924 -size 8779000 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so deleted file mode 100644 index 689760116de97c954865cd824732f04d2f746728..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:caffcadbb99fbaa27e8a81d5ef508f2e1a798e7626d618c3cf5b0d387d2c8686 -size 4618624 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py index fa68616c13166de47619ed052ed1eba664998b82..3725c2b21e803832098265d4704e789c837084ef 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_e5e2eeb_dirty -ops = torch.ops._activation_e5e2eeb_dirty +from . import _activation_53ed492_dirty +ops = torch.ops._activation_53ed492_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file + return f"_activation_53ed492_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/fused_add_rms_norm_meta.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/fused_add_rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a472844644bb93a27ae962cbc0fdc50c27ec780a --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/fused_add_rms_norm_meta.py @@ -0,0 +1,199 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_fused_add_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.fused_add_rms_norm.default, + schema_info=RuntimeSchemaInfo(1)) +def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + residual_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(residual_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "input": len(input_strategy.strategies), + "residual": len(residual_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, residual, weight in zip(input_strategy.strategies, + residual_strategy.strategies, + weight_strategy.strategies): + + input_src = input.output_spec + residual_src = residual.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(residual_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Residual add must have the same sharding as input. + residual_tgt = input_tgt + redistribute_costs.append( + generate_redistribute_costs(residual_strategy, residual_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, input_tgt], + input_specs=[input_tgt, residual_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.fused_add_rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(2)) +def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 6 + ( + output_grad_strategy, + add_output_grad_strategy, + add_output_strategy, + weight_strategy, + _, # eps + need_input_grad, # need_input_grad + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(add_output_grad_strategy, OpStrategy) + assert isinstance(add_output_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "add_output_grad": len(add_output_grad_strategy.strategies), + "add_output": len(add_output_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + add_output_grad_strategy.strategies, + add_output_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = output_grad_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, add_output_grad, add_output, weight in zipped: + output_grad_src = output_grad.output_spec + add_output_grad_src = add_output_grad.output_spec + add_output_src = add_output.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(add_output_grad_src, DTensorSpec) + assert isinstance(add_output_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # add_output_grad must have the same sharding as output_grad. + add_output_grad_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_grad_strategy, + add_output_grad_tgt)) + + # add_output must have the same sharding as output_grad. + add_output_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_strategy, add_output_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[ + output_grad_tgt if need_input_grad else None, weight_tgt + ], + input_specs=[ + output_grad_tgt, add_output_grad_tgt, add_output_tgt, + weight_tgt + ], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/parallel_style.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/parallel_style.py new file mode 100644 index 0000000000000000000000000000000000000000..470ab69d9889284f0be5cb075d5211eab30eb755 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/parallel_style.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard, + distribute_module, distribute_tensor) +from torch.distributed.tensor.parallel import SequenceParallel +from torch.distributed.tensor.placement_types import Placement + + +class ResidualSequenceParallel(SequenceParallel): + """ Consider the case where we have a residual connection across a sequence parallel layer.""" + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + residual_tensor = inputs[1] + + assert isinstance(input_tensor, + DTensor) == isinstance(residual_tensor, DTensor) + assert isinstance(input_tensor, + torch.Tensor) == isinstance(residual_tensor, + torch.Tensor) + + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True) + if residual_tensor.placements != sequence_sharding: + residual_tensor = residual_tensor.redistribute( + placements=sequence_sharding, async_op=True) + return input_tensor, residual_tensor + + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local(input_tensor, + device_mesh, + sequence_sharding, + run_check=False), DTensor.from_local( + residual_tensor, + device_mesh, + sequence_sharding, + run_check=False) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py index 0e2c29e955b87025e63f4795d58a14104318f736..2b3ab7e1476aba5d7799ff888449470e23665676 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py @@ -1,4 +1,7 @@ +from collections.abc import Sequence + import torch +from packaging import version from ._ops import ops @@ -8,9 +11,7 @@ class RMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, weight, eps): - output = torch.empty_like(input) - ops.rms_norm(output, input, weight, eps) - return output + return ops.rms_norm(input, weight, eps) @staticmethod # inputs is a Tuple of all of the inputs passed to forward. @@ -26,13 +27,8 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like( - input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[1] else None - - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, - weight, eps) + input_grad, weight_grad = ops.rms_norm_backward( + output_grad, input, weight, eps) return input_grad, weight_grad, None @@ -42,10 +38,8 @@ class FusedAddRMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, residual, weight, eps): - output = torch.empty_like(input) - add_output = torch.empty_like(input) - ops.fused_add_rms_norm(output, add_output, input, residual, weight, - eps) + output, add_output = ops.fused_add_rms_norm(input, residual, weight, + eps) return output, add_output @staticmethod @@ -65,14 +59,47 @@ class FusedAddRMSNormFunction(torch.autograd.Function): need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] - grad = torch.empty_like(output_grad) if need_in or need_res else None + # TODO(ai-system): kernels currently do not support no input gradients + assert need_in or need_res, "Not implemented for no input gradients yet" - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[2] else None - - ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, - weight, eps) + grad, weight_grad = ops.fused_add_rms_norm_backward( + output_grad, + add_output_grad, + add_output, + weight, + eps, + need_input_grad=need_in or need_res) input_grad = grad if need_in else None residual_grad = grad if need_res else None return input_grad, residual_grad, weight_grad, None + + +@torch.library.register_fake(ops.rms_norm.default) +def rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + + +@torch.library.register_fake(ops.rms_norm_backward.default) +def rms_norm_backward_abstract(output_grad, x, weight, eps): + return torch.empty_like(x), torch.empty_like(weight) + + +@torch.library.register_fake(ops.fused_add_rms_norm.default) +def fused_add_rms_norm_abstract(x, residual, weight, eps): + return torch.empty_like(x), torch.empty_like(x) + + +@torch.library.register_fake(ops.fused_add_rms_norm_backward.default) +def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad, + add_output, weight, eps, + need_input_grad: bool): + return torch.empty_like( + output_grad) if need_input_grad else None, torch.empty_like(weight) + + +if version.parse(torch.__version__) >= version.parse("2.8"): + from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta + from .rms_norm_meta import register_rms_norm_meta + register_fused_add_rms_norm_meta() + register_rms_norm_meta() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm_meta.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..12527aef0e055c0836752a9dda814c4ce6f24832 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm_meta.py @@ -0,0 +1,164 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1)) +def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 3 + ( + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + assert len(input_strategy.strategies) == len(weight_strategy.strategies) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, weight in zip(input_strategy.strategies, + weight_strategy.strategies): + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=input_tgt, + input_specs=[input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(1)) +def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + output_grad_strategy, + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "input": len(input_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + + assert len(set( + lengths.values())) == 1, f"Strategies length mismatch {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + input_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, input, weight in zipped: + output_grad_src = output_grad.output_spec + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # Input must have the same sharding as output grad. + input_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, weight_tgt], + input_specs=[output_grad_tgt, input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py index 938feeff791794d011fec65cf86df957e2c4da2f..0f6f29ac2c688bd09afa41c5d1abd9942c4456d8 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py @@ -1,6 +1,6 @@ import torch -from . import layers +from . import layers, parallel_style from ._ops import ops from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -48,5 +48,6 @@ __all__ = [ "rms_norm", "fused_add_rms_norm", "layers", + "parallel_style", "ops", ] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so deleted file mode 100644 index 0de3488964fc7207148b7b9b62cc4db838e64c7b..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0bf0d2ab5ff5520704e0b0c959b61d0043d360cfd4335950e69677873a87e436 -size 12792112 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..652a75128a964627956e4ddad2c408a156e7ad3d --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a0699647f4c0bfc57711e8488dfa3864e7cfdf9119fb743fdaafcb2cbd2cea2c +size 13836872 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so deleted file mode 100644 index b57174622d44e91556d4646cc225ce02ae186236..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:25efc9c32e4bd6609a8326025aad861cbf79b544893755fe44519c9df7224c40 -size 13818872 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so deleted file mode 100644 index 45881f2bf18843120634173e5a0974ebdcbe07c6..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:3b7c6ece8e8d316c4cc5fe46b1cec4422b2f61e9bb7240af71a2b4a35975d8e6 -size 6676528 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py index fa68616c13166de47619ed052ed1eba664998b82..3725c2b21e803832098265d4704e789c837084ef 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_e5e2eeb_dirty -ops = torch.ops._activation_e5e2eeb_dirty +from . import _activation_53ed492_dirty +ops = torch.ops._activation_53ed492_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file + return f"_activation_53ed492_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/fused_add_rms_norm_meta.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/fused_add_rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a472844644bb93a27ae962cbc0fdc50c27ec780a --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/fused_add_rms_norm_meta.py @@ -0,0 +1,199 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_fused_add_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.fused_add_rms_norm.default, + schema_info=RuntimeSchemaInfo(1)) +def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + residual_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(residual_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "input": len(input_strategy.strategies), + "residual": len(residual_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, residual, weight in zip(input_strategy.strategies, + residual_strategy.strategies, + weight_strategy.strategies): + + input_src = input.output_spec + residual_src = residual.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(residual_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Residual add must have the same sharding as input. + residual_tgt = input_tgt + redistribute_costs.append( + generate_redistribute_costs(residual_strategy, residual_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, input_tgt], + input_specs=[input_tgt, residual_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.fused_add_rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(2)) +def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 6 + ( + output_grad_strategy, + add_output_grad_strategy, + add_output_strategy, + weight_strategy, + _, # eps + need_input_grad, # need_input_grad + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(add_output_grad_strategy, OpStrategy) + assert isinstance(add_output_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "add_output_grad": len(add_output_grad_strategy.strategies), + "add_output": len(add_output_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + add_output_grad_strategy.strategies, + add_output_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = output_grad_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, add_output_grad, add_output, weight in zipped: + output_grad_src = output_grad.output_spec + add_output_grad_src = add_output_grad.output_spec + add_output_src = add_output.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(add_output_grad_src, DTensorSpec) + assert isinstance(add_output_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # add_output_grad must have the same sharding as output_grad. + add_output_grad_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_grad_strategy, + add_output_grad_tgt)) + + # add_output must have the same sharding as output_grad. + add_output_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_strategy, add_output_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[ + output_grad_tgt if need_input_grad else None, weight_tgt + ], + input_specs=[ + output_grad_tgt, add_output_grad_tgt, add_output_tgt, + weight_tgt + ], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/parallel_style.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/parallel_style.py new file mode 100644 index 0000000000000000000000000000000000000000..470ab69d9889284f0be5cb075d5211eab30eb755 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/parallel_style.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard, + distribute_module, distribute_tensor) +from torch.distributed.tensor.parallel import SequenceParallel +from torch.distributed.tensor.placement_types import Placement + + +class ResidualSequenceParallel(SequenceParallel): + """ Consider the case where we have a residual connection across a sequence parallel layer.""" + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + residual_tensor = inputs[1] + + assert isinstance(input_tensor, + DTensor) == isinstance(residual_tensor, DTensor) + assert isinstance(input_tensor, + torch.Tensor) == isinstance(residual_tensor, + torch.Tensor) + + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True) + if residual_tensor.placements != sequence_sharding: + residual_tensor = residual_tensor.redistribute( + placements=sequence_sharding, async_op=True) + return input_tensor, residual_tensor + + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local(input_tensor, + device_mesh, + sequence_sharding, + run_check=False), DTensor.from_local( + residual_tensor, + device_mesh, + sequence_sharding, + run_check=False) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py index 0e2c29e955b87025e63f4795d58a14104318f736..2b3ab7e1476aba5d7799ff888449470e23665676 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py @@ -1,4 +1,7 @@ +from collections.abc import Sequence + import torch +from packaging import version from ._ops import ops @@ -8,9 +11,7 @@ class RMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, weight, eps): - output = torch.empty_like(input) - ops.rms_norm(output, input, weight, eps) - return output + return ops.rms_norm(input, weight, eps) @staticmethod # inputs is a Tuple of all of the inputs passed to forward. @@ -26,13 +27,8 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like( - input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[1] else None - - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, - weight, eps) + input_grad, weight_grad = ops.rms_norm_backward( + output_grad, input, weight, eps) return input_grad, weight_grad, None @@ -42,10 +38,8 @@ class FusedAddRMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, residual, weight, eps): - output = torch.empty_like(input) - add_output = torch.empty_like(input) - ops.fused_add_rms_norm(output, add_output, input, residual, weight, - eps) + output, add_output = ops.fused_add_rms_norm(input, residual, weight, + eps) return output, add_output @staticmethod @@ -65,14 +59,47 @@ class FusedAddRMSNormFunction(torch.autograd.Function): need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] - grad = torch.empty_like(output_grad) if need_in or need_res else None + # TODO(ai-system): kernels currently do not support no input gradients + assert need_in or need_res, "Not implemented for no input gradients yet" - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[2] else None - - ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, - weight, eps) + grad, weight_grad = ops.fused_add_rms_norm_backward( + output_grad, + add_output_grad, + add_output, + weight, + eps, + need_input_grad=need_in or need_res) input_grad = grad if need_in else None residual_grad = grad if need_res else None return input_grad, residual_grad, weight_grad, None + + +@torch.library.register_fake(ops.rms_norm.default) +def rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + + +@torch.library.register_fake(ops.rms_norm_backward.default) +def rms_norm_backward_abstract(output_grad, x, weight, eps): + return torch.empty_like(x), torch.empty_like(weight) + + +@torch.library.register_fake(ops.fused_add_rms_norm.default) +def fused_add_rms_norm_abstract(x, residual, weight, eps): + return torch.empty_like(x), torch.empty_like(x) + + +@torch.library.register_fake(ops.fused_add_rms_norm_backward.default) +def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad, + add_output, weight, eps, + need_input_grad: bool): + return torch.empty_like( + output_grad) if need_input_grad else None, torch.empty_like(weight) + + +if version.parse(torch.__version__) >= version.parse("2.8"): + from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta + from .rms_norm_meta import register_rms_norm_meta + register_fused_add_rms_norm_meta() + register_rms_norm_meta() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm_meta.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..12527aef0e055c0836752a9dda814c4ce6f24832 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm_meta.py @@ -0,0 +1,164 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1)) +def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 3 + ( + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + assert len(input_strategy.strategies) == len(weight_strategy.strategies) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, weight in zip(input_strategy.strategies, + weight_strategy.strategies): + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=input_tgt, + input_specs=[input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(1)) +def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + output_grad_strategy, + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "input": len(input_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + + assert len(set( + lengths.values())) == 1, f"Strategies length mismatch {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + input_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, input, weight in zipped: + output_grad_src = output_grad.output_spec + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # Input must have the same sharding as output grad. + input_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, weight_tgt], + input_specs=[output_grad_tgt, input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py index 938feeff791794d011fec65cf86df957e2c4da2f..0f6f29ac2c688bd09afa41c5d1abd9942c4456d8 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py @@ -1,6 +1,6 @@ import torch -from . import layers +from . import layers, parallel_style from ._ops import ops from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -48,5 +48,6 @@ __all__ = [ "rms_norm", "fused_add_rms_norm", "layers", + "parallel_style", "ops", ] diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so deleted file mode 100644 index 57361102c13046a6a1aab2f7125193ece35b21da..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:640322a8fac8fd9d8e9f195a3034c4ee0f81ee1acf897fd7c482a84ce47a1bec -size 4160688 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..53cb15c8d46b535435939e427efa04eab7480e38 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d973bad96565705f9e27514a9dbfb37343d0220da4a3ae7156b1cf6a27813643 +size 2773952 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so deleted file mode 100644 index c0069ea9e4f962208b869f671b23aa15f728cb92..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c80d05690547f2842d416ebb85c9f830370373bc7e6c54ba08eec61b3690280f -size 4386744 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so deleted file mode 100644 index 6e05f5b3045576c970e67481e0182f9aaf5a88d2..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4be173820e2a4bf4b6b8de6b63faf6544b599d9b0583f650a940adaef4a048b3 -size 2899184 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py index fa68616c13166de47619ed052ed1eba664998b82..3725c2b21e803832098265d4704e789c837084ef 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_e5e2eeb_dirty -ops = torch.ops._activation_e5e2eeb_dirty +from . import _activation_53ed492_dirty +ops = torch.ops._activation_53ed492_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file + return f"_activation_53ed492_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/fused_add_rms_norm_meta.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/fused_add_rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a472844644bb93a27ae962cbc0fdc50c27ec780a --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/fused_add_rms_norm_meta.py @@ -0,0 +1,199 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_fused_add_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.fused_add_rms_norm.default, + schema_info=RuntimeSchemaInfo(1)) +def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + residual_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(residual_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "input": len(input_strategy.strategies), + "residual": len(residual_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, residual, weight in zip(input_strategy.strategies, + residual_strategy.strategies, + weight_strategy.strategies): + + input_src = input.output_spec + residual_src = residual.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(residual_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Residual add must have the same sharding as input. + residual_tgt = input_tgt + redistribute_costs.append( + generate_redistribute_costs(residual_strategy, residual_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, input_tgt], + input_specs=[input_tgt, residual_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.fused_add_rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(2)) +def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 6 + ( + output_grad_strategy, + add_output_grad_strategy, + add_output_strategy, + weight_strategy, + _, # eps + need_input_grad, # need_input_grad + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(add_output_grad_strategy, OpStrategy) + assert isinstance(add_output_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "add_output_grad": len(add_output_grad_strategy.strategies), + "add_output": len(add_output_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + add_output_grad_strategy.strategies, + add_output_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = output_grad_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, add_output_grad, add_output, weight in zipped: + output_grad_src = output_grad.output_spec + add_output_grad_src = add_output_grad.output_spec + add_output_src = add_output.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(add_output_grad_src, DTensorSpec) + assert isinstance(add_output_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # add_output_grad must have the same sharding as output_grad. + add_output_grad_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_grad_strategy, + add_output_grad_tgt)) + + # add_output must have the same sharding as output_grad. + add_output_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_strategy, add_output_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[ + output_grad_tgt if need_input_grad else None, weight_tgt + ], + input_specs=[ + output_grad_tgt, add_output_grad_tgt, add_output_tgt, + weight_tgt + ], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/parallel_style.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/parallel_style.py new file mode 100644 index 0000000000000000000000000000000000000000..470ab69d9889284f0be5cb075d5211eab30eb755 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/parallel_style.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard, + distribute_module, distribute_tensor) +from torch.distributed.tensor.parallel import SequenceParallel +from torch.distributed.tensor.placement_types import Placement + + +class ResidualSequenceParallel(SequenceParallel): + """ Consider the case where we have a residual connection across a sequence parallel layer.""" + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + residual_tensor = inputs[1] + + assert isinstance(input_tensor, + DTensor) == isinstance(residual_tensor, DTensor) + assert isinstance(input_tensor, + torch.Tensor) == isinstance(residual_tensor, + torch.Tensor) + + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True) + if residual_tensor.placements != sequence_sharding: + residual_tensor = residual_tensor.redistribute( + placements=sequence_sharding, async_op=True) + return input_tensor, residual_tensor + + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local(input_tensor, + device_mesh, + sequence_sharding, + run_check=False), DTensor.from_local( + residual_tensor, + device_mesh, + sequence_sharding, + run_check=False) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py index 0e2c29e955b87025e63f4795d58a14104318f736..2b3ab7e1476aba5d7799ff888449470e23665676 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py @@ -1,4 +1,7 @@ +from collections.abc import Sequence + import torch +from packaging import version from ._ops import ops @@ -8,9 +11,7 @@ class RMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, weight, eps): - output = torch.empty_like(input) - ops.rms_norm(output, input, weight, eps) - return output + return ops.rms_norm(input, weight, eps) @staticmethod # inputs is a Tuple of all of the inputs passed to forward. @@ -26,13 +27,8 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like( - input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[1] else None - - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, - weight, eps) + input_grad, weight_grad = ops.rms_norm_backward( + output_grad, input, weight, eps) return input_grad, weight_grad, None @@ -42,10 +38,8 @@ class FusedAddRMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, residual, weight, eps): - output = torch.empty_like(input) - add_output = torch.empty_like(input) - ops.fused_add_rms_norm(output, add_output, input, residual, weight, - eps) + output, add_output = ops.fused_add_rms_norm(input, residual, weight, + eps) return output, add_output @staticmethod @@ -65,14 +59,47 @@ class FusedAddRMSNormFunction(torch.autograd.Function): need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] - grad = torch.empty_like(output_grad) if need_in or need_res else None + # TODO(ai-system): kernels currently do not support no input gradients + assert need_in or need_res, "Not implemented for no input gradients yet" - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[2] else None - - ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, - weight, eps) + grad, weight_grad = ops.fused_add_rms_norm_backward( + output_grad, + add_output_grad, + add_output, + weight, + eps, + need_input_grad=need_in or need_res) input_grad = grad if need_in else None residual_grad = grad if need_res else None return input_grad, residual_grad, weight_grad, None + + +@torch.library.register_fake(ops.rms_norm.default) +def rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + + +@torch.library.register_fake(ops.rms_norm_backward.default) +def rms_norm_backward_abstract(output_grad, x, weight, eps): + return torch.empty_like(x), torch.empty_like(weight) + + +@torch.library.register_fake(ops.fused_add_rms_norm.default) +def fused_add_rms_norm_abstract(x, residual, weight, eps): + return torch.empty_like(x), torch.empty_like(x) + + +@torch.library.register_fake(ops.fused_add_rms_norm_backward.default) +def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad, + add_output, weight, eps, + need_input_grad: bool): + return torch.empty_like( + output_grad) if need_input_grad else None, torch.empty_like(weight) + + +if version.parse(torch.__version__) >= version.parse("2.8"): + from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta + from .rms_norm_meta import register_rms_norm_meta + register_fused_add_rms_norm_meta() + register_rms_norm_meta() diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm_meta.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..12527aef0e055c0836752a9dda814c4ce6f24832 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm_meta.py @@ -0,0 +1,164 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1)) +def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 3 + ( + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + assert len(input_strategy.strategies) == len(weight_strategy.strategies) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, weight in zip(input_strategy.strategies, + weight_strategy.strategies): + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=input_tgt, + input_specs=[input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(1)) +def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + output_grad_strategy, + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "input": len(input_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + + assert len(set( + lengths.values())) == 1, f"Strategies length mismatch {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + input_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, input, weight in zipped: + output_grad_src = output_grad.output_spec + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # Input must have the same sharding as output grad. + input_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, weight_tgt], + input_specs=[output_grad_tgt, input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py index 938feeff791794d011fec65cf86df957e2c4da2f..0f6f29ac2c688bd09afa41c5d1abd9942c4456d8 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py @@ -1,6 +1,6 @@ import torch -from . import layers +from . import layers, parallel_style from ._ops import ops from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -48,5 +48,6 @@ __all__ = [ "rms_norm", "fused_add_rms_norm", "layers", + "parallel_style", "ops", ] diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so deleted file mode 100644 index c703b3b19594e8b20ee5b4dc7692fbdad8079365..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:1768d8d5072ac06d937cb5332988c6b3bfaa191f72d1369a22d2c577e9a3bca2 -size 8215280 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..9dd321cefe0de18549ef04943e3a0114bf147423 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c301db3d37625ebf0cecf016948ec18fbeddb497acca8c870d2d8eff0a1d1203 +size 8735952 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so deleted file mode 100644 index a50764fa05ea1e21294f84d922050f5d70f7db93..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:440f5c17a7ddaf73c506bbc84fd1405e2e188b8ceaf4977910608be6b91e89bf -size 8730200 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so deleted file mode 100644 index 6c12e8b587a01fe10f4e73cca22a5a27fd2e794a..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:cb222449350310f90f7271f34fcf9052c9eec28021fee0348130a8f239a97bf4 -size 4571976 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py index fa68616c13166de47619ed052ed1eba664998b82..3725c2b21e803832098265d4704e789c837084ef 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_e5e2eeb_dirty -ops = torch.ops._activation_e5e2eeb_dirty +from . import _activation_53ed492_dirty +ops = torch.ops._activation_53ed492_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file + return f"_activation_53ed492_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/fused_add_rms_norm_meta.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/fused_add_rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a472844644bb93a27ae962cbc0fdc50c27ec780a --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/fused_add_rms_norm_meta.py @@ -0,0 +1,199 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_fused_add_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.fused_add_rms_norm.default, + schema_info=RuntimeSchemaInfo(1)) +def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + residual_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(residual_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "input": len(input_strategy.strategies), + "residual": len(residual_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, residual, weight in zip(input_strategy.strategies, + residual_strategy.strategies, + weight_strategy.strategies): + + input_src = input.output_spec + residual_src = residual.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(residual_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Residual add must have the same sharding as input. + residual_tgt = input_tgt + redistribute_costs.append( + generate_redistribute_costs(residual_strategy, residual_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, input_tgt], + input_specs=[input_tgt, residual_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.fused_add_rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(2)) +def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 6 + ( + output_grad_strategy, + add_output_grad_strategy, + add_output_strategy, + weight_strategy, + _, # eps + need_input_grad, # need_input_grad + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(add_output_grad_strategy, OpStrategy) + assert isinstance(add_output_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "add_output_grad": len(add_output_grad_strategy.strategies), + "add_output": len(add_output_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + add_output_grad_strategy.strategies, + add_output_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = output_grad_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, add_output_grad, add_output, weight in zipped: + output_grad_src = output_grad.output_spec + add_output_grad_src = add_output_grad.output_spec + add_output_src = add_output.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(add_output_grad_src, DTensorSpec) + assert isinstance(add_output_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # add_output_grad must have the same sharding as output_grad. + add_output_grad_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_grad_strategy, + add_output_grad_tgt)) + + # add_output must have the same sharding as output_grad. + add_output_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_strategy, add_output_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[ + output_grad_tgt if need_input_grad else None, weight_tgt + ], + input_specs=[ + output_grad_tgt, add_output_grad_tgt, add_output_tgt, + weight_tgt + ], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/parallel_style.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/parallel_style.py new file mode 100644 index 0000000000000000000000000000000000000000..470ab69d9889284f0be5cb075d5211eab30eb755 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/parallel_style.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard, + distribute_module, distribute_tensor) +from torch.distributed.tensor.parallel import SequenceParallel +from torch.distributed.tensor.placement_types import Placement + + +class ResidualSequenceParallel(SequenceParallel): + """ Consider the case where we have a residual connection across a sequence parallel layer.""" + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + residual_tensor = inputs[1] + + assert isinstance(input_tensor, + DTensor) == isinstance(residual_tensor, DTensor) + assert isinstance(input_tensor, + torch.Tensor) == isinstance(residual_tensor, + torch.Tensor) + + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True) + if residual_tensor.placements != sequence_sharding: + residual_tensor = residual_tensor.redistribute( + placements=sequence_sharding, async_op=True) + return input_tensor, residual_tensor + + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local(input_tensor, + device_mesh, + sequence_sharding, + run_check=False), DTensor.from_local( + residual_tensor, + device_mesh, + sequence_sharding, + run_check=False) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py index 0e2c29e955b87025e63f4795d58a14104318f736..2b3ab7e1476aba5d7799ff888449470e23665676 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py @@ -1,4 +1,7 @@ +from collections.abc import Sequence + import torch +from packaging import version from ._ops import ops @@ -8,9 +11,7 @@ class RMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, weight, eps): - output = torch.empty_like(input) - ops.rms_norm(output, input, weight, eps) - return output + return ops.rms_norm(input, weight, eps) @staticmethod # inputs is a Tuple of all of the inputs passed to forward. @@ -26,13 +27,8 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like( - input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[1] else None - - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, - weight, eps) + input_grad, weight_grad = ops.rms_norm_backward( + output_grad, input, weight, eps) return input_grad, weight_grad, None @@ -42,10 +38,8 @@ class FusedAddRMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, residual, weight, eps): - output = torch.empty_like(input) - add_output = torch.empty_like(input) - ops.fused_add_rms_norm(output, add_output, input, residual, weight, - eps) + output, add_output = ops.fused_add_rms_norm(input, residual, weight, + eps) return output, add_output @staticmethod @@ -65,14 +59,47 @@ class FusedAddRMSNormFunction(torch.autograd.Function): need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] - grad = torch.empty_like(output_grad) if need_in or need_res else None + # TODO(ai-system): kernels currently do not support no input gradients + assert need_in or need_res, "Not implemented for no input gradients yet" - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[2] else None - - ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, - weight, eps) + grad, weight_grad = ops.fused_add_rms_norm_backward( + output_grad, + add_output_grad, + add_output, + weight, + eps, + need_input_grad=need_in or need_res) input_grad = grad if need_in else None residual_grad = grad if need_res else None return input_grad, residual_grad, weight_grad, None + + +@torch.library.register_fake(ops.rms_norm.default) +def rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + + +@torch.library.register_fake(ops.rms_norm_backward.default) +def rms_norm_backward_abstract(output_grad, x, weight, eps): + return torch.empty_like(x), torch.empty_like(weight) + + +@torch.library.register_fake(ops.fused_add_rms_norm.default) +def fused_add_rms_norm_abstract(x, residual, weight, eps): + return torch.empty_like(x), torch.empty_like(x) + + +@torch.library.register_fake(ops.fused_add_rms_norm_backward.default) +def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad, + add_output, weight, eps, + need_input_grad: bool): + return torch.empty_like( + output_grad) if need_input_grad else None, torch.empty_like(weight) + + +if version.parse(torch.__version__) >= version.parse("2.8"): + from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta + from .rms_norm_meta import register_rms_norm_meta + register_fused_add_rms_norm_meta() + register_rms_norm_meta() diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm_meta.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..12527aef0e055c0836752a9dda814c4ce6f24832 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm_meta.py @@ -0,0 +1,164 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1)) +def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 3 + ( + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + assert len(input_strategy.strategies) == len(weight_strategy.strategies) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, weight in zip(input_strategy.strategies, + weight_strategy.strategies): + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=input_tgt, + input_specs=[input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(1)) +def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + output_grad_strategy, + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "input": len(input_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + + assert len(set( + lengths.values())) == 1, f"Strategies length mismatch {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + input_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, input, weight in zipped: + output_grad_src = output_grad.output_spec + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # Input must have the same sharding as output grad. + input_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, weight_tgt], + input_specs=[output_grad_tgt, input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py index 938feeff791794d011fec65cf86df957e2c4da2f..0f6f29ac2c688bd09afa41c5d1abd9942c4456d8 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py @@ -1,6 +1,6 @@ import torch -from . import layers +from . import layers, parallel_style from ._ops import ops from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -48,5 +48,6 @@ __all__ = [ "rms_norm", "fused_add_rms_norm", "layers", + "parallel_style", "ops", ] diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so deleted file mode 100644 index ecdc467a674247fe3898453418ce88a9983d08c5..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:37a572bd877980ab8c0331ca5682191cb5a2b1f05bc69ea493a9e24f7728ba3f -size 12730840 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..1708e32f91b7337cd1151eacf86c8f12a534e889 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f7879c74d91f2412bbf5524cd107dea64edeeeabf1dd496eeefa627d2e7143c +size 13775752 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so deleted file mode 100644 index d3e4416a52e04ff527f48c721c6c4f1fa16059ed..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:1dfb6d468f9cef0239d4ea47f0a247fa721befc5b8db86e1cddfc25f1814b67a -size 13770064 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so deleted file mode 100644 index ff5ceef3b840a9957dab36434074fa21417f6711..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:79be6527f579de1133e50a66310d7d0690649dcac63009a54b5e68809408f12a -size 6634208 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py index fa68616c13166de47619ed052ed1eba664998b82..3725c2b21e803832098265d4704e789c837084ef 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_e5e2eeb_dirty -ops = torch.ops._activation_e5e2eeb_dirty +from . import _activation_53ed492_dirty +ops = torch.ops._activation_53ed492_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file + return f"_activation_53ed492_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/fused_add_rms_norm_meta.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/fused_add_rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a472844644bb93a27ae962cbc0fdc50c27ec780a --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/fused_add_rms_norm_meta.py @@ -0,0 +1,199 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_fused_add_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.fused_add_rms_norm.default, + schema_info=RuntimeSchemaInfo(1)) +def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + residual_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(residual_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "input": len(input_strategy.strategies), + "residual": len(residual_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, residual, weight in zip(input_strategy.strategies, + residual_strategy.strategies, + weight_strategy.strategies): + + input_src = input.output_spec + residual_src = residual.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(residual_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Residual add must have the same sharding as input. + residual_tgt = input_tgt + redistribute_costs.append( + generate_redistribute_costs(residual_strategy, residual_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, input_tgt], + input_specs=[input_tgt, residual_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.fused_add_rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(2)) +def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 6 + ( + output_grad_strategy, + add_output_grad_strategy, + add_output_strategy, + weight_strategy, + _, # eps + need_input_grad, # need_input_grad + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(add_output_grad_strategy, OpStrategy) + assert isinstance(add_output_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "add_output_grad": len(add_output_grad_strategy.strategies), + "add_output": len(add_output_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + add_output_grad_strategy.strategies, + add_output_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = output_grad_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, add_output_grad, add_output, weight in zipped: + output_grad_src = output_grad.output_spec + add_output_grad_src = add_output_grad.output_spec + add_output_src = add_output.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(add_output_grad_src, DTensorSpec) + assert isinstance(add_output_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # add_output_grad must have the same sharding as output_grad. + add_output_grad_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_grad_strategy, + add_output_grad_tgt)) + + # add_output must have the same sharding as output_grad. + add_output_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_strategy, add_output_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[ + output_grad_tgt if need_input_grad else None, weight_tgt + ], + input_specs=[ + output_grad_tgt, add_output_grad_tgt, add_output_tgt, + weight_tgt + ], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/parallel_style.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/parallel_style.py new file mode 100644 index 0000000000000000000000000000000000000000..470ab69d9889284f0be5cb075d5211eab30eb755 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/parallel_style.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard, + distribute_module, distribute_tensor) +from torch.distributed.tensor.parallel import SequenceParallel +from torch.distributed.tensor.placement_types import Placement + + +class ResidualSequenceParallel(SequenceParallel): + """ Consider the case where we have a residual connection across a sequence parallel layer.""" + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + residual_tensor = inputs[1] + + assert isinstance(input_tensor, + DTensor) == isinstance(residual_tensor, DTensor) + assert isinstance(input_tensor, + torch.Tensor) == isinstance(residual_tensor, + torch.Tensor) + + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True) + if residual_tensor.placements != sequence_sharding: + residual_tensor = residual_tensor.redistribute( + placements=sequence_sharding, async_op=True) + return input_tensor, residual_tensor + + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local(input_tensor, + device_mesh, + sequence_sharding, + run_check=False), DTensor.from_local( + residual_tensor, + device_mesh, + sequence_sharding, + run_check=False) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py index 0e2c29e955b87025e63f4795d58a14104318f736..2b3ab7e1476aba5d7799ff888449470e23665676 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py @@ -1,4 +1,7 @@ +from collections.abc import Sequence + import torch +from packaging import version from ._ops import ops @@ -8,9 +11,7 @@ class RMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, weight, eps): - output = torch.empty_like(input) - ops.rms_norm(output, input, weight, eps) - return output + return ops.rms_norm(input, weight, eps) @staticmethod # inputs is a Tuple of all of the inputs passed to forward. @@ -26,13 +27,8 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like( - input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[1] else None - - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, - weight, eps) + input_grad, weight_grad = ops.rms_norm_backward( + output_grad, input, weight, eps) return input_grad, weight_grad, None @@ -42,10 +38,8 @@ class FusedAddRMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, residual, weight, eps): - output = torch.empty_like(input) - add_output = torch.empty_like(input) - ops.fused_add_rms_norm(output, add_output, input, residual, weight, - eps) + output, add_output = ops.fused_add_rms_norm(input, residual, weight, + eps) return output, add_output @staticmethod @@ -65,14 +59,47 @@ class FusedAddRMSNormFunction(torch.autograd.Function): need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] - grad = torch.empty_like(output_grad) if need_in or need_res else None + # TODO(ai-system): kernels currently do not support no input gradients + assert need_in or need_res, "Not implemented for no input gradients yet" - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[2] else None - - ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, - weight, eps) + grad, weight_grad = ops.fused_add_rms_norm_backward( + output_grad, + add_output_grad, + add_output, + weight, + eps, + need_input_grad=need_in or need_res) input_grad = grad if need_in else None residual_grad = grad if need_res else None return input_grad, residual_grad, weight_grad, None + + +@torch.library.register_fake(ops.rms_norm.default) +def rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + + +@torch.library.register_fake(ops.rms_norm_backward.default) +def rms_norm_backward_abstract(output_grad, x, weight, eps): + return torch.empty_like(x), torch.empty_like(weight) + + +@torch.library.register_fake(ops.fused_add_rms_norm.default) +def fused_add_rms_norm_abstract(x, residual, weight, eps): + return torch.empty_like(x), torch.empty_like(x) + + +@torch.library.register_fake(ops.fused_add_rms_norm_backward.default) +def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad, + add_output, weight, eps, + need_input_grad: bool): + return torch.empty_like( + output_grad) if need_input_grad else None, torch.empty_like(weight) + + +if version.parse(torch.__version__) >= version.parse("2.8"): + from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta + from .rms_norm_meta import register_rms_norm_meta + register_fused_add_rms_norm_meta() + register_rms_norm_meta() diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm_meta.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..12527aef0e055c0836752a9dda814c4ce6f24832 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm_meta.py @@ -0,0 +1,164 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1)) +def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 3 + ( + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + assert len(input_strategy.strategies) == len(weight_strategy.strategies) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, weight in zip(input_strategy.strategies, + weight_strategy.strategies): + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=input_tgt, + input_specs=[input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(1)) +def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + output_grad_strategy, + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "input": len(input_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + + assert len(set( + lengths.values())) == 1, f"Strategies length mismatch {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + input_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, input, weight in zipped: + output_grad_src = output_grad.output_spec + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # Input must have the same sharding as output grad. + input_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, weight_tgt], + input_specs=[output_grad_tgt, input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py index 938feeff791794d011fec65cf86df957e2c4da2f..0f6f29ac2c688bd09afa41c5d1abd9942c4456d8 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py @@ -1,6 +1,6 @@ import torch -from . import layers +from . import layers, parallel_style from ._ops import ops from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -48,5 +48,6 @@ __all__ = [ "rms_norm", "fused_add_rms_norm", "layers", + "parallel_style", "ops", ] diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_20250907180255.abi3.so deleted file mode 100644 index d6c8a74ea050b78cf9dcd4c43ac618094b0ca303..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_20250907180255.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:3f15919c4cac697cde550af16256e338472400e50df751e93622350c7f626bc8 -size 12726208 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..7c6bd75e264546f1b0995013e4c660dc5975641b --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:88f3763ca4b2daa7bf72027ca1a190c63ff63e78d9d1a52b6b274de304a757db +size 13762936 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so deleted file mode 100644 index ebdc9108aad1a1dfd16dc0d8baebf827bc0476f4..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0815a50e61497b357b2b90fc28602b3f53a25da1161edd2cb0b0fbebc7c62bf6 -size 13757248 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so deleted file mode 100644 index f7ab393218a3d825e10b9e1e838440d8a543ce19..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8d95e4491d35cb022a6eaa2febbc555f203893f989a4fb1cc483b2632f141869 -size 6687456 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py index fa68616c13166de47619ed052ed1eba664998b82..3725c2b21e803832098265d4704e789c837084ef 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_e5e2eeb_dirty -ops = torch.ops._activation_e5e2eeb_dirty +from . import _activation_53ed492_dirty +ops = torch.ops._activation_53ed492_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file + return f"_activation_53ed492_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/fused_add_rms_norm_meta.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/fused_add_rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a472844644bb93a27ae962cbc0fdc50c27ec780a --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/fused_add_rms_norm_meta.py @@ -0,0 +1,199 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_fused_add_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.fused_add_rms_norm.default, + schema_info=RuntimeSchemaInfo(1)) +def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + residual_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(residual_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "input": len(input_strategy.strategies), + "residual": len(residual_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, residual, weight in zip(input_strategy.strategies, + residual_strategy.strategies, + weight_strategy.strategies): + + input_src = input.output_spec + residual_src = residual.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(residual_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Residual add must have the same sharding as input. + residual_tgt = input_tgt + redistribute_costs.append( + generate_redistribute_costs(residual_strategy, residual_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, input_tgt], + input_specs=[input_tgt, residual_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.fused_add_rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(2)) +def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 6 + ( + output_grad_strategy, + add_output_grad_strategy, + add_output_strategy, + weight_strategy, + _, # eps + need_input_grad, # need_input_grad + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(add_output_grad_strategy, OpStrategy) + assert isinstance(add_output_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "add_output_grad": len(add_output_grad_strategy.strategies), + "add_output": len(add_output_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + add_output_grad_strategy.strategies, + add_output_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = output_grad_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, add_output_grad, add_output, weight in zipped: + output_grad_src = output_grad.output_spec + add_output_grad_src = add_output_grad.output_spec + add_output_src = add_output.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(add_output_grad_src, DTensorSpec) + assert isinstance(add_output_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # add_output_grad must have the same sharding as output_grad. + add_output_grad_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_grad_strategy, + add_output_grad_tgt)) + + # add_output must have the same sharding as output_grad. + add_output_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_strategy, add_output_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[ + output_grad_tgt if need_input_grad else None, weight_tgt + ], + input_specs=[ + output_grad_tgt, add_output_grad_tgt, add_output_tgt, + weight_tgt + ], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/parallel_style.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/parallel_style.py new file mode 100644 index 0000000000000000000000000000000000000000..470ab69d9889284f0be5cb075d5211eab30eb755 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/parallel_style.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard, + distribute_module, distribute_tensor) +from torch.distributed.tensor.parallel import SequenceParallel +from torch.distributed.tensor.placement_types import Placement + + +class ResidualSequenceParallel(SequenceParallel): + """ Consider the case where we have a residual connection across a sequence parallel layer.""" + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + residual_tensor = inputs[1] + + assert isinstance(input_tensor, + DTensor) == isinstance(residual_tensor, DTensor) + assert isinstance(input_tensor, + torch.Tensor) == isinstance(residual_tensor, + torch.Tensor) + + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True) + if residual_tensor.placements != sequence_sharding: + residual_tensor = residual_tensor.redistribute( + placements=sequence_sharding, async_op=True) + return input_tensor, residual_tensor + + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local(input_tensor, + device_mesh, + sequence_sharding, + run_check=False), DTensor.from_local( + residual_tensor, + device_mesh, + sequence_sharding, + run_check=False) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py index 0e2c29e955b87025e63f4795d58a14104318f736..2b3ab7e1476aba5d7799ff888449470e23665676 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py @@ -1,4 +1,7 @@ +from collections.abc import Sequence + import torch +from packaging import version from ._ops import ops @@ -8,9 +11,7 @@ class RMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, weight, eps): - output = torch.empty_like(input) - ops.rms_norm(output, input, weight, eps) - return output + return ops.rms_norm(input, weight, eps) @staticmethod # inputs is a Tuple of all of the inputs passed to forward. @@ -26,13 +27,8 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like( - input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[1] else None - - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, - weight, eps) + input_grad, weight_grad = ops.rms_norm_backward( + output_grad, input, weight, eps) return input_grad, weight_grad, None @@ -42,10 +38,8 @@ class FusedAddRMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, residual, weight, eps): - output = torch.empty_like(input) - add_output = torch.empty_like(input) - ops.fused_add_rms_norm(output, add_output, input, residual, weight, - eps) + output, add_output = ops.fused_add_rms_norm(input, residual, weight, + eps) return output, add_output @staticmethod @@ -65,14 +59,47 @@ class FusedAddRMSNormFunction(torch.autograd.Function): need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] - grad = torch.empty_like(output_grad) if need_in or need_res else None + # TODO(ai-system): kernels currently do not support no input gradients + assert need_in or need_res, "Not implemented for no input gradients yet" - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[2] else None - - ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, - weight, eps) + grad, weight_grad = ops.fused_add_rms_norm_backward( + output_grad, + add_output_grad, + add_output, + weight, + eps, + need_input_grad=need_in or need_res) input_grad = grad if need_in else None residual_grad = grad if need_res else None return input_grad, residual_grad, weight_grad, None + + +@torch.library.register_fake(ops.rms_norm.default) +def rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + + +@torch.library.register_fake(ops.rms_norm_backward.default) +def rms_norm_backward_abstract(output_grad, x, weight, eps): + return torch.empty_like(x), torch.empty_like(weight) + + +@torch.library.register_fake(ops.fused_add_rms_norm.default) +def fused_add_rms_norm_abstract(x, residual, weight, eps): + return torch.empty_like(x), torch.empty_like(x) + + +@torch.library.register_fake(ops.fused_add_rms_norm_backward.default) +def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad, + add_output, weight, eps, + need_input_grad: bool): + return torch.empty_like( + output_grad) if need_input_grad else None, torch.empty_like(weight) + + +if version.parse(torch.__version__) >= version.parse("2.8"): + from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta + from .rms_norm_meta import register_rms_norm_meta + register_fused_add_rms_norm_meta() + register_rms_norm_meta() diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm_meta.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..12527aef0e055c0836752a9dda814c4ce6f24832 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm_meta.py @@ -0,0 +1,164 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1)) +def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 3 + ( + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + assert len(input_strategy.strategies) == len(weight_strategy.strategies) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, weight in zip(input_strategy.strategies, + weight_strategy.strategies): + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=input_tgt, + input_specs=[input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(1)) +def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + output_grad_strategy, + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "input": len(input_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + + assert len(set( + lengths.values())) == 1, f"Strategies length mismatch {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + input_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, input, weight in zipped: + output_grad_src = output_grad.output_spec + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # Input must have the same sharding as output grad. + input_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, weight_tgt], + input_specs=[output_grad_tgt, input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py index 938feeff791794d011fec65cf86df957e2c4da2f..0f6f29ac2c688bd09afa41c5d1abd9942c4456d8 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py @@ -1,6 +1,6 @@ import torch -from . import layers +from . import layers, parallel_style from ._ops import ops from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -48,5 +48,6 @@ __all__ = [ "rms_norm", "fused_add_rms_norm", "layers", + "parallel_style", "ops", ] diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so deleted file mode 100644 index 670a8291fdc208c690447600ee77449e1fac9929..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e72d4bb4459a5da96ca5eda1d305237a361140f0e25360e3d20326a22f1b6d47 -size 4165584 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..24bccb66acb500ca181fdb99103935662014d339 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a8daaa2a9b0e307b0740b2744cf759b6d1b4c229c6030458fe9600b3f70a28f +size 2774744 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so deleted file mode 100644 index a7e8ec3a1957ec7fa888600e141e2d6acdb1d4be..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4d404c88b72f1b6da551a64b3373395e80403a52ccff14fc401be3e8ee184d83 -size 4387536 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so deleted file mode 100644 index 1843d54d5917206c0947de8effc1cf347ea9e853..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:58116124bb2b5d11de2753dd0c30a1e4c84759f18599da7016c791bad37528e9 -size 2899984 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py index fa68616c13166de47619ed052ed1eba664998b82..3725c2b21e803832098265d4704e789c837084ef 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_e5e2eeb_dirty -ops = torch.ops._activation_e5e2eeb_dirty +from . import _activation_53ed492_dirty +ops = torch.ops._activation_53ed492_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file + return f"_activation_53ed492_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/fused_add_rms_norm_meta.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/fused_add_rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a472844644bb93a27ae962cbc0fdc50c27ec780a --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/fused_add_rms_norm_meta.py @@ -0,0 +1,199 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_fused_add_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.fused_add_rms_norm.default, + schema_info=RuntimeSchemaInfo(1)) +def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + residual_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(residual_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "input": len(input_strategy.strategies), + "residual": len(residual_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, residual, weight in zip(input_strategy.strategies, + residual_strategy.strategies, + weight_strategy.strategies): + + input_src = input.output_spec + residual_src = residual.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(residual_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Residual add must have the same sharding as input. + residual_tgt = input_tgt + redistribute_costs.append( + generate_redistribute_costs(residual_strategy, residual_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, input_tgt], + input_specs=[input_tgt, residual_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.fused_add_rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(2)) +def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 6 + ( + output_grad_strategy, + add_output_grad_strategy, + add_output_strategy, + weight_strategy, + _, # eps + need_input_grad, # need_input_grad + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(add_output_grad_strategy, OpStrategy) + assert isinstance(add_output_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "add_output_grad": len(add_output_grad_strategy.strategies), + "add_output": len(add_output_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + add_output_grad_strategy.strategies, + add_output_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = output_grad_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, add_output_grad, add_output, weight in zipped: + output_grad_src = output_grad.output_spec + add_output_grad_src = add_output_grad.output_spec + add_output_src = add_output.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(add_output_grad_src, DTensorSpec) + assert isinstance(add_output_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # add_output_grad must have the same sharding as output_grad. + add_output_grad_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_grad_strategy, + add_output_grad_tgt)) + + # add_output must have the same sharding as output_grad. + add_output_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_strategy, add_output_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[ + output_grad_tgt if need_input_grad else None, weight_tgt + ], + input_specs=[ + output_grad_tgt, add_output_grad_tgt, add_output_tgt, + weight_tgt + ], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/parallel_style.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/parallel_style.py new file mode 100644 index 0000000000000000000000000000000000000000..470ab69d9889284f0be5cb075d5211eab30eb755 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/parallel_style.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard, + distribute_module, distribute_tensor) +from torch.distributed.tensor.parallel import SequenceParallel +from torch.distributed.tensor.placement_types import Placement + + +class ResidualSequenceParallel(SequenceParallel): + """ Consider the case where we have a residual connection across a sequence parallel layer.""" + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + residual_tensor = inputs[1] + + assert isinstance(input_tensor, + DTensor) == isinstance(residual_tensor, DTensor) + assert isinstance(input_tensor, + torch.Tensor) == isinstance(residual_tensor, + torch.Tensor) + + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True) + if residual_tensor.placements != sequence_sharding: + residual_tensor = residual_tensor.redistribute( + placements=sequence_sharding, async_op=True) + return input_tensor, residual_tensor + + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local(input_tensor, + device_mesh, + sequence_sharding, + run_check=False), DTensor.from_local( + residual_tensor, + device_mesh, + sequence_sharding, + run_check=False) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py index 0e2c29e955b87025e63f4795d58a14104318f736..2b3ab7e1476aba5d7799ff888449470e23665676 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py @@ -1,4 +1,7 @@ +from collections.abc import Sequence + import torch +from packaging import version from ._ops import ops @@ -8,9 +11,7 @@ class RMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, weight, eps): - output = torch.empty_like(input) - ops.rms_norm(output, input, weight, eps) - return output + return ops.rms_norm(input, weight, eps) @staticmethod # inputs is a Tuple of all of the inputs passed to forward. @@ -26,13 +27,8 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like( - input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[1] else None - - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, - weight, eps) + input_grad, weight_grad = ops.rms_norm_backward( + output_grad, input, weight, eps) return input_grad, weight_grad, None @@ -42,10 +38,8 @@ class FusedAddRMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, residual, weight, eps): - output = torch.empty_like(input) - add_output = torch.empty_like(input) - ops.fused_add_rms_norm(output, add_output, input, residual, weight, - eps) + output, add_output = ops.fused_add_rms_norm(input, residual, weight, + eps) return output, add_output @staticmethod @@ -65,14 +59,47 @@ class FusedAddRMSNormFunction(torch.autograd.Function): need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] - grad = torch.empty_like(output_grad) if need_in or need_res else None + # TODO(ai-system): kernels currently do not support no input gradients + assert need_in or need_res, "Not implemented for no input gradients yet" - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[2] else None - - ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, - weight, eps) + grad, weight_grad = ops.fused_add_rms_norm_backward( + output_grad, + add_output_grad, + add_output, + weight, + eps, + need_input_grad=need_in or need_res) input_grad = grad if need_in else None residual_grad = grad if need_res else None return input_grad, residual_grad, weight_grad, None + + +@torch.library.register_fake(ops.rms_norm.default) +def rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + + +@torch.library.register_fake(ops.rms_norm_backward.default) +def rms_norm_backward_abstract(output_grad, x, weight, eps): + return torch.empty_like(x), torch.empty_like(weight) + + +@torch.library.register_fake(ops.fused_add_rms_norm.default) +def fused_add_rms_norm_abstract(x, residual, weight, eps): + return torch.empty_like(x), torch.empty_like(x) + + +@torch.library.register_fake(ops.fused_add_rms_norm_backward.default) +def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad, + add_output, weight, eps, + need_input_grad: bool): + return torch.empty_like( + output_grad) if need_input_grad else None, torch.empty_like(weight) + + +if version.parse(torch.__version__) >= version.parse("2.8"): + from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta + from .rms_norm_meta import register_rms_norm_meta + register_fused_add_rms_norm_meta() + register_rms_norm_meta() diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm_meta.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..12527aef0e055c0836752a9dda814c4ce6f24832 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm_meta.py @@ -0,0 +1,164 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1)) +def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 3 + ( + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + assert len(input_strategy.strategies) == len(weight_strategy.strategies) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, weight in zip(input_strategy.strategies, + weight_strategy.strategies): + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=input_tgt, + input_specs=[input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(1)) +def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + output_grad_strategy, + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "input": len(input_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + + assert len(set( + lengths.values())) == 1, f"Strategies length mismatch {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + input_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, input, weight in zipped: + output_grad_src = output_grad.output_spec + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # Input must have the same sharding as output grad. + input_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, weight_tgt], + input_specs=[output_grad_tgt, input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py index 938feeff791794d011fec65cf86df957e2c4da2f..0f6f29ac2c688bd09afa41c5d1abd9942c4456d8 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py @@ -1,6 +1,6 @@ import torch -from . import layers +from . import layers, parallel_style from ._ops import ops from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -48,5 +48,6 @@ __all__ = [ "rms_norm", "fused_add_rms_norm", "layers", + "parallel_style", "ops", ] diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_20250907180255.abi3.so deleted file mode 100644 index c8f702b9ecfdc1c01dcdd2880d088458c4f11c2d..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_20250907180255.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c3325c2748cf7a070383068995078f93f440cc95fbed491d00bd414cdd851376 -size 4171472 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..74dac254474746ee831611a77b8abf6015c4c63e --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9086fc0bb8cf675c3cd54d976117342203535d17d5fb29ab34ca72661cc6cc6 +size 2780440 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so deleted file mode 100644 index dafb119147ed94f04203dd8c8a366ef9a6ed7680..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b8d52dee20ba3c4619f7c614984f656f34f32dd74ba6cf866cf80f32245117cf -size 4393240 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so deleted file mode 100644 index 86ae5f11c05134ad7347aca293b13aeff2caf4c1..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:65319d3d93ac3bf0f2939fa4e53ddfc8cd633b9e396cde3a97d63b9041ba03a7 -size 2885344 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py index fa68616c13166de47619ed052ed1eba664998b82..3725c2b21e803832098265d4704e789c837084ef 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_e5e2eeb_dirty -ops = torch.ops._activation_e5e2eeb_dirty +from . import _activation_53ed492_dirty +ops = torch.ops._activation_53ed492_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file + return f"_activation_53ed492_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/fused_add_rms_norm_meta.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/fused_add_rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a472844644bb93a27ae962cbc0fdc50c27ec780a --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/fused_add_rms_norm_meta.py @@ -0,0 +1,199 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_fused_add_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.fused_add_rms_norm.default, + schema_info=RuntimeSchemaInfo(1)) +def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + residual_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(residual_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "input": len(input_strategy.strategies), + "residual": len(residual_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, residual, weight in zip(input_strategy.strategies, + residual_strategy.strategies, + weight_strategy.strategies): + + input_src = input.output_spec + residual_src = residual.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(residual_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Residual add must have the same sharding as input. + residual_tgt = input_tgt + redistribute_costs.append( + generate_redistribute_costs(residual_strategy, residual_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, input_tgt], + input_specs=[input_tgt, residual_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.fused_add_rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(2)) +def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 6 + ( + output_grad_strategy, + add_output_grad_strategy, + add_output_strategy, + weight_strategy, + _, # eps + need_input_grad, # need_input_grad + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(add_output_grad_strategy, OpStrategy) + assert isinstance(add_output_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "add_output_grad": len(add_output_grad_strategy.strategies), + "add_output": len(add_output_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + add_output_grad_strategy.strategies, + add_output_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = output_grad_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, add_output_grad, add_output, weight in zipped: + output_grad_src = output_grad.output_spec + add_output_grad_src = add_output_grad.output_spec + add_output_src = add_output.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(add_output_grad_src, DTensorSpec) + assert isinstance(add_output_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # add_output_grad must have the same sharding as output_grad. + add_output_grad_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_grad_strategy, + add_output_grad_tgt)) + + # add_output must have the same sharding as output_grad. + add_output_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_strategy, add_output_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[ + output_grad_tgt if need_input_grad else None, weight_tgt + ], + input_specs=[ + output_grad_tgt, add_output_grad_tgt, add_output_tgt, + weight_tgt + ], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/parallel_style.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/parallel_style.py new file mode 100644 index 0000000000000000000000000000000000000000..470ab69d9889284f0be5cb075d5211eab30eb755 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/parallel_style.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard, + distribute_module, distribute_tensor) +from torch.distributed.tensor.parallel import SequenceParallel +from torch.distributed.tensor.placement_types import Placement + + +class ResidualSequenceParallel(SequenceParallel): + """ Consider the case where we have a residual connection across a sequence parallel layer.""" + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + residual_tensor = inputs[1] + + assert isinstance(input_tensor, + DTensor) == isinstance(residual_tensor, DTensor) + assert isinstance(input_tensor, + torch.Tensor) == isinstance(residual_tensor, + torch.Tensor) + + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True) + if residual_tensor.placements != sequence_sharding: + residual_tensor = residual_tensor.redistribute( + placements=sequence_sharding, async_op=True) + return input_tensor, residual_tensor + + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local(input_tensor, + device_mesh, + sequence_sharding, + run_check=False), DTensor.from_local( + residual_tensor, + device_mesh, + sequence_sharding, + run_check=False) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py index 0e2c29e955b87025e63f4795d58a14104318f736..2b3ab7e1476aba5d7799ff888449470e23665676 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py @@ -1,4 +1,7 @@ +from collections.abc import Sequence + import torch +from packaging import version from ._ops import ops @@ -8,9 +11,7 @@ class RMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, weight, eps): - output = torch.empty_like(input) - ops.rms_norm(output, input, weight, eps) - return output + return ops.rms_norm(input, weight, eps) @staticmethod # inputs is a Tuple of all of the inputs passed to forward. @@ -26,13 +27,8 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like( - input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[1] else None - - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, - weight, eps) + input_grad, weight_grad = ops.rms_norm_backward( + output_grad, input, weight, eps) return input_grad, weight_grad, None @@ -42,10 +38,8 @@ class FusedAddRMSNormFunction(torch.autograd.Function): # Note that forward, setup_context, and backward are @staticmethods @staticmethod def forward(input, residual, weight, eps): - output = torch.empty_like(input) - add_output = torch.empty_like(input) - ops.fused_add_rms_norm(output, add_output, input, residual, weight, - eps) + output, add_output = ops.fused_add_rms_norm(input, residual, weight, + eps) return output, add_output @staticmethod @@ -65,14 +59,47 @@ class FusedAddRMSNormFunction(torch.autograd.Function): need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] - grad = torch.empty_like(output_grad) if need_in or need_res else None + # TODO(ai-system): kernels currently do not support no input gradients + assert need_in or need_res, "Not implemented for no input gradients yet" - weight_grad = torch.empty_like( - weight) if ctx.needs_input_grad[2] else None - - ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, - weight, eps) + grad, weight_grad = ops.fused_add_rms_norm_backward( + output_grad, + add_output_grad, + add_output, + weight, + eps, + need_input_grad=need_in or need_res) input_grad = grad if need_in else None residual_grad = grad if need_res else None return input_grad, residual_grad, weight_grad, None + + +@torch.library.register_fake(ops.rms_norm.default) +def rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + + +@torch.library.register_fake(ops.rms_norm_backward.default) +def rms_norm_backward_abstract(output_grad, x, weight, eps): + return torch.empty_like(x), torch.empty_like(weight) + + +@torch.library.register_fake(ops.fused_add_rms_norm.default) +def fused_add_rms_norm_abstract(x, residual, weight, eps): + return torch.empty_like(x), torch.empty_like(x) + + +@torch.library.register_fake(ops.fused_add_rms_norm_backward.default) +def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad, + add_output, weight, eps, + need_input_grad: bool): + return torch.empty_like( + output_grad) if need_input_grad else None, torch.empty_like(weight) + + +if version.parse(torch.__version__) >= version.parse("2.8"): + from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta + from .rms_norm_meta import register_rms_norm_meta + register_fused_add_rms_norm_meta() + register_rms_norm_meta() diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm_meta.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..12527aef0e055c0836752a9dda814c4ce6f24832 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm_meta.py @@ -0,0 +1,164 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1)) +def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 3 + ( + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + assert len(input_strategy.strategies) == len(weight_strategy.strategies) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, weight in zip(input_strategy.strategies, + weight_strategy.strategies): + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=input_tgt, + input_specs=[input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(1)) +def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + output_grad_strategy, + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "input": len(input_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + + assert len(set( + lengths.values())) == 1, f"Strategies length mismatch {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + input_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, input, weight in zipped: + output_grad_src = output_grad.output_spec + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # Input must have the same sharding as output grad. + input_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, weight_tgt], + input_specs=[output_grad_tgt, input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch29-cxx11-cu126-x86_64-linux/activation/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/activation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6f29ac2c688bd09afa41c5d1abd9942c4456d8 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/activation/__init__.py @@ -0,0 +1,53 @@ +import torch + +from . import layers, parallel_style +from ._ops import ops +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction + + +def poly_norm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return PolyNormFunction.apply(x, weight, bias, eps) + + +def fused_mul_poly_norm( + x: torch.Tensor, + mul: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps) + + +def rms_norm( + x: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return RMSNormFunction.apply(x, weight, eps) + + +def fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedAddRMSNormFunction.apply(x, residual, weight, eps) + + +__all__ = [ + "poly_norm", + "fused_mul_poly_norm", + "rms_norm", + "fused_add_rms_norm", + "layers", + "parallel_style", + "ops", +] diff --git a/build/torch29-cxx11-cu126-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..2fa0da1af20f226e2b409389eb673e44b5b39682 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8d4e3841c349a5e51bd9be3fca4e2763b26e7cedff51869898fab4bf7da1f0a +size 8740024 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/activation/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/activation/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..3725c2b21e803832098265d4704e789c837084ef --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/activation/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _activation_53ed492_dirty +ops = torch.ops._activation_53ed492_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_activation_53ed492_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/activation/fused_add_rms_norm_meta.py b/build/torch29-cxx11-cu126-x86_64-linux/activation/fused_add_rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a472844644bb93a27ae962cbc0fdc50c27ec780a --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/activation/fused_add_rms_norm_meta.py @@ -0,0 +1,199 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_fused_add_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.fused_add_rms_norm.default, + schema_info=RuntimeSchemaInfo(1)) +def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + residual_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(residual_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "input": len(input_strategy.strategies), + "residual": len(residual_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, residual, weight in zip(input_strategy.strategies, + residual_strategy.strategies, + weight_strategy.strategies): + + input_src = input.output_spec + residual_src = residual.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(residual_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Residual add must have the same sharding as input. + residual_tgt = input_tgt + redistribute_costs.append( + generate_redistribute_costs(residual_strategy, residual_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, input_tgt], + input_specs=[input_tgt, residual_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.fused_add_rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(2)) +def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 6 + ( + output_grad_strategy, + add_output_grad_strategy, + add_output_strategy, + weight_strategy, + _, # eps + need_input_grad, # need_input_grad + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(add_output_grad_strategy, OpStrategy) + assert isinstance(add_output_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "add_output_grad": len(add_output_grad_strategy.strategies), + "add_output": len(add_output_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + add_output_grad_strategy.strategies, + add_output_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = output_grad_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, add_output_grad, add_output, weight in zipped: + output_grad_src = output_grad.output_spec + add_output_grad_src = add_output_grad.output_spec + add_output_src = add_output.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(add_output_grad_src, DTensorSpec) + assert isinstance(add_output_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # add_output_grad must have the same sharding as output_grad. + add_output_grad_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_grad_strategy, + add_output_grad_tgt)) + + # add_output must have the same sharding as output_grad. + add_output_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_strategy, add_output_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[ + output_grad_tgt if need_input_grad else None, weight_tgt + ], + input_specs=[ + output_grad_tgt, add_output_grad_tgt, add_output_tgt, + weight_tgt + ], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch29-cxx11-cu126-x86_64-linux/activation/layers.py b/build/torch29-cxx11-cu126-x86_64-linux/activation/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..b1880bdbe8dd73ac76d7d4561cf60f9765097ca9 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/activation/layers.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn +from torch.nn import init + +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction + + +class PolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + ): + return PolyNormFunction.apply(x, self.weight, self.bias, self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + +class FusedMulPolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + mul: torch.Tensor, + ): + return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + +class RMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + ): + return RMSNormFunction.apply(x, self.weight, self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + + +class FusedAddRMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ): + return FusedAddRMSNormFunction.apply(x, residual, self.weight, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/activation/parallel_style.py b/build/torch29-cxx11-cu126-x86_64-linux/activation/parallel_style.py new file mode 100644 index 0000000000000000000000000000000000000000..470ab69d9889284f0be5cb075d5211eab30eb755 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/activation/parallel_style.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard, + distribute_module, distribute_tensor) +from torch.distributed.tensor.parallel import SequenceParallel +from torch.distributed.tensor.placement_types import Placement + + +class ResidualSequenceParallel(SequenceParallel): + """ Consider the case where we have a residual connection across a sequence parallel layer.""" + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + residual_tensor = inputs[1] + + assert isinstance(input_tensor, + DTensor) == isinstance(residual_tensor, DTensor) + assert isinstance(input_tensor, + torch.Tensor) == isinstance(residual_tensor, + torch.Tensor) + + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True) + if residual_tensor.placements != sequence_sharding: + residual_tensor = residual_tensor.redistribute( + placements=sequence_sharding, async_op=True) + return input_tensor, residual_tensor + + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local(input_tensor, + device_mesh, + sequence_sharding, + run_check=False), DTensor.from_local( + residual_tensor, + device_mesh, + sequence_sharding, + run_check=False) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/activation/poly_norm.py b/build/torch29-cxx11-cu126-x86_64-linux/activation/poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..8a0fd85f1835e02a36eb9184874d77dcad8221f9 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/activation/poly_norm.py @@ -0,0 +1,76 @@ +import torch + +from ._ops import ops + + +# Inherit from Function +class PolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, weight, bias, eps): + output = torch.empty_like(input) + ops.poly_norm(output, input, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, weight, bias, eps = inputs + ctx.save_for_backward(input, weight) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, weight = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[2] else None) + + ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad, + input, weight, eps) + + return input_grad, weight_grad, bias_grad, None + + +class FusedMulPolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, mul, weight, bias, eps): + output = torch.empty_like(input) + ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, mul, weight, bias, eps = inputs + ctx.save_for_backward(input, mul, weight, bias) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, mul, weight, bias = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[3] else None) + + ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad, + bias_grad, output_grad, input, mul, + weight, bias, eps) + + return input_grad, mul_grad, weight_grad, bias_grad, None diff --git a/build/torch29-cxx11-cu126-x86_64-linux/activation/rms_norm.py b/build/torch29-cxx11-cu126-x86_64-linux/activation/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..2b3ab7e1476aba5d7799ff888449470e23665676 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/activation/rms_norm.py @@ -0,0 +1,105 @@ +from collections.abc import Sequence + +import torch +from packaging import version + +from ._ops import ops + + +# Inherit from Function +class RMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, weight, eps): + return ops.rms_norm(input, weight, eps) + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, weight, eps = inputs + ctx.save_for_backward(input, weight) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, weight = ctx.saved_tensors + eps = ctx.eps + + input_grad, weight_grad = ops.rms_norm_backward( + output_grad, input, weight, eps) + + return input_grad, weight_grad, None + + +# Inherit from Function +class FusedAddRMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, residual, weight, eps): + output, add_output = ops.fused_add_rms_norm(input, residual, weight, + eps) + return output, add_output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, outputs): + _, _, weight, eps = inputs + _, add_output = outputs + ctx.save_for_backward(weight, add_output) + ctx.eps = eps + + @staticmethod + def backward(ctx, output_grad, add_output_grad): + weight, add_output = ctx.saved_tensors + eps = ctx.eps + + need_in = ctx.needs_input_grad[0] + need_res = ctx.needs_input_grad[1] + + # TODO(ai-system): kernels currently do not support no input gradients + assert need_in or need_res, "Not implemented for no input gradients yet" + + grad, weight_grad = ops.fused_add_rms_norm_backward( + output_grad, + add_output_grad, + add_output, + weight, + eps, + need_input_grad=need_in or need_res) + input_grad = grad if need_in else None + residual_grad = grad if need_res else None + + return input_grad, residual_grad, weight_grad, None + + +@torch.library.register_fake(ops.rms_norm.default) +def rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + + +@torch.library.register_fake(ops.rms_norm_backward.default) +def rms_norm_backward_abstract(output_grad, x, weight, eps): + return torch.empty_like(x), torch.empty_like(weight) + + +@torch.library.register_fake(ops.fused_add_rms_norm.default) +def fused_add_rms_norm_abstract(x, residual, weight, eps): + return torch.empty_like(x), torch.empty_like(x) + + +@torch.library.register_fake(ops.fused_add_rms_norm_backward.default) +def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad, + add_output, weight, eps, + need_input_grad: bool): + return torch.empty_like( + output_grad) if need_input_grad else None, torch.empty_like(weight) + + +if version.parse(torch.__version__) >= version.parse("2.8"): + from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta + from .rms_norm_meta import register_rms_norm_meta + register_fused_add_rms_norm_meta() + register_rms_norm_meta() diff --git a/build/torch29-cxx11-cu126-x86_64-linux/activation/rms_norm_meta.py b/build/torch29-cxx11-cu126-x86_64-linux/activation/rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..12527aef0e055c0836752a9dda814c4ce6f24832 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/activation/rms_norm_meta.py @@ -0,0 +1,164 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1)) +def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 3 + ( + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + assert len(input_strategy.strategies) == len(weight_strategy.strategies) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, weight in zip(input_strategy.strategies, + weight_strategy.strategies): + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=input_tgt, + input_specs=[input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(1)) +def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + output_grad_strategy, + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "input": len(input_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + + assert len(set( + lengths.values())) == 1, f"Strategies length mismatch {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + input_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, input, weight in zipped: + output_grad_src = output_grad.output_spec + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # Input must have the same sharding as output grad. + input_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, weight_tgt], + input_specs=[output_grad_tgt, input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch29-cxx11-cu128-x86_64-linux/activation/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/activation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6f29ac2c688bd09afa41c5d1abd9942c4456d8 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/activation/__init__.py @@ -0,0 +1,53 @@ +import torch + +from . import layers, parallel_style +from ._ops import ops +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction + + +def poly_norm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return PolyNormFunction.apply(x, weight, bias, eps) + + +def fused_mul_poly_norm( + x: torch.Tensor, + mul: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps) + + +def rms_norm( + x: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return RMSNormFunction.apply(x, weight, eps) + + +def fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedAddRMSNormFunction.apply(x, residual, weight, eps) + + +__all__ = [ + "poly_norm", + "fused_mul_poly_norm", + "rms_norm", + "fused_add_rms_norm", + "layers", + "parallel_style", + "ops", +] diff --git a/build/torch29-cxx11-cu128-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..1b4b266909da3429a6eccd61eaeec4f37ebb7aca --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5945c030dbc49cdedb421b3ba12c63b0798ab6e7ed4f494996071748800ae3f6 +size 13783920 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/activation/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/activation/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..3725c2b21e803832098265d4704e789c837084ef --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/activation/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _activation_53ed492_dirty +ops = torch.ops._activation_53ed492_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_activation_53ed492_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/activation/fused_add_rms_norm_meta.py b/build/torch29-cxx11-cu128-x86_64-linux/activation/fused_add_rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a472844644bb93a27ae962cbc0fdc50c27ec780a --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/activation/fused_add_rms_norm_meta.py @@ -0,0 +1,199 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_fused_add_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.fused_add_rms_norm.default, + schema_info=RuntimeSchemaInfo(1)) +def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + residual_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(residual_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "input": len(input_strategy.strategies), + "residual": len(residual_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, residual, weight in zip(input_strategy.strategies, + residual_strategy.strategies, + weight_strategy.strategies): + + input_src = input.output_spec + residual_src = residual.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(residual_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Residual add must have the same sharding as input. + residual_tgt = input_tgt + redistribute_costs.append( + generate_redistribute_costs(residual_strategy, residual_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, input_tgt], + input_specs=[input_tgt, residual_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.fused_add_rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(2)) +def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 6 + ( + output_grad_strategy, + add_output_grad_strategy, + add_output_strategy, + weight_strategy, + _, # eps + need_input_grad, # need_input_grad + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(add_output_grad_strategy, OpStrategy) + assert isinstance(add_output_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "add_output_grad": len(add_output_grad_strategy.strategies), + "add_output": len(add_output_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + add_output_grad_strategy.strategies, + add_output_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = output_grad_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, add_output_grad, add_output, weight in zipped: + output_grad_src = output_grad.output_spec + add_output_grad_src = add_output_grad.output_spec + add_output_src = add_output.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(add_output_grad_src, DTensorSpec) + assert isinstance(add_output_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # add_output_grad must have the same sharding as output_grad. + add_output_grad_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_grad_strategy, + add_output_grad_tgt)) + + # add_output must have the same sharding as output_grad. + add_output_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_strategy, add_output_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[ + output_grad_tgt if need_input_grad else None, weight_tgt + ], + input_specs=[ + output_grad_tgt, add_output_grad_tgt, add_output_tgt, + weight_tgt + ], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch29-cxx11-cu128-x86_64-linux/activation/layers.py b/build/torch29-cxx11-cu128-x86_64-linux/activation/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..b1880bdbe8dd73ac76d7d4561cf60f9765097ca9 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/activation/layers.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn +from torch.nn import init + +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction + + +class PolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + ): + return PolyNormFunction.apply(x, self.weight, self.bias, self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + +class FusedMulPolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + mul: torch.Tensor, + ): + return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + +class RMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + ): + return RMSNormFunction.apply(x, self.weight, self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + + +class FusedAddRMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ): + return FusedAddRMSNormFunction.apply(x, residual, self.weight, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/activation/parallel_style.py b/build/torch29-cxx11-cu128-x86_64-linux/activation/parallel_style.py new file mode 100644 index 0000000000000000000000000000000000000000..470ab69d9889284f0be5cb075d5211eab30eb755 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/activation/parallel_style.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard, + distribute_module, distribute_tensor) +from torch.distributed.tensor.parallel import SequenceParallel +from torch.distributed.tensor.placement_types import Placement + + +class ResidualSequenceParallel(SequenceParallel): + """ Consider the case where we have a residual connection across a sequence parallel layer.""" + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + residual_tensor = inputs[1] + + assert isinstance(input_tensor, + DTensor) == isinstance(residual_tensor, DTensor) + assert isinstance(input_tensor, + torch.Tensor) == isinstance(residual_tensor, + torch.Tensor) + + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True) + if residual_tensor.placements != sequence_sharding: + residual_tensor = residual_tensor.redistribute( + placements=sequence_sharding, async_op=True) + return input_tensor, residual_tensor + + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local(input_tensor, + device_mesh, + sequence_sharding, + run_check=False), DTensor.from_local( + residual_tensor, + device_mesh, + sequence_sharding, + run_check=False) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/activation/poly_norm.py b/build/torch29-cxx11-cu128-x86_64-linux/activation/poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..8a0fd85f1835e02a36eb9184874d77dcad8221f9 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/activation/poly_norm.py @@ -0,0 +1,76 @@ +import torch + +from ._ops import ops + + +# Inherit from Function +class PolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, weight, bias, eps): + output = torch.empty_like(input) + ops.poly_norm(output, input, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, weight, bias, eps = inputs + ctx.save_for_backward(input, weight) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, weight = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[2] else None) + + ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad, + input, weight, eps) + + return input_grad, weight_grad, bias_grad, None + + +class FusedMulPolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, mul, weight, bias, eps): + output = torch.empty_like(input) + ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, mul, weight, bias, eps = inputs + ctx.save_for_backward(input, mul, weight, bias) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, mul, weight, bias = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[3] else None) + + ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad, + bias_grad, output_grad, input, mul, + weight, bias, eps) + + return input_grad, mul_grad, weight_grad, bias_grad, None diff --git a/build/torch29-cxx11-cu128-x86_64-linux/activation/rms_norm.py b/build/torch29-cxx11-cu128-x86_64-linux/activation/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..2b3ab7e1476aba5d7799ff888449470e23665676 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/activation/rms_norm.py @@ -0,0 +1,105 @@ +from collections.abc import Sequence + +import torch +from packaging import version + +from ._ops import ops + + +# Inherit from Function +class RMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, weight, eps): + return ops.rms_norm(input, weight, eps) + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, weight, eps = inputs + ctx.save_for_backward(input, weight) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, weight = ctx.saved_tensors + eps = ctx.eps + + input_grad, weight_grad = ops.rms_norm_backward( + output_grad, input, weight, eps) + + return input_grad, weight_grad, None + + +# Inherit from Function +class FusedAddRMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, residual, weight, eps): + output, add_output = ops.fused_add_rms_norm(input, residual, weight, + eps) + return output, add_output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, outputs): + _, _, weight, eps = inputs + _, add_output = outputs + ctx.save_for_backward(weight, add_output) + ctx.eps = eps + + @staticmethod + def backward(ctx, output_grad, add_output_grad): + weight, add_output = ctx.saved_tensors + eps = ctx.eps + + need_in = ctx.needs_input_grad[0] + need_res = ctx.needs_input_grad[1] + + # TODO(ai-system): kernels currently do not support no input gradients + assert need_in or need_res, "Not implemented for no input gradients yet" + + grad, weight_grad = ops.fused_add_rms_norm_backward( + output_grad, + add_output_grad, + add_output, + weight, + eps, + need_input_grad=need_in or need_res) + input_grad = grad if need_in else None + residual_grad = grad if need_res else None + + return input_grad, residual_grad, weight_grad, None + + +@torch.library.register_fake(ops.rms_norm.default) +def rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + + +@torch.library.register_fake(ops.rms_norm_backward.default) +def rms_norm_backward_abstract(output_grad, x, weight, eps): + return torch.empty_like(x), torch.empty_like(weight) + + +@torch.library.register_fake(ops.fused_add_rms_norm.default) +def fused_add_rms_norm_abstract(x, residual, weight, eps): + return torch.empty_like(x), torch.empty_like(x) + + +@torch.library.register_fake(ops.fused_add_rms_norm_backward.default) +def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad, + add_output, weight, eps, + need_input_grad: bool): + return torch.empty_like( + output_grad) if need_input_grad else None, torch.empty_like(weight) + + +if version.parse(torch.__version__) >= version.parse("2.8"): + from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta + from .rms_norm_meta import register_rms_norm_meta + register_fused_add_rms_norm_meta() + register_rms_norm_meta() diff --git a/build/torch29-cxx11-cu128-x86_64-linux/activation/rms_norm_meta.py b/build/torch29-cxx11-cu128-x86_64-linux/activation/rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..12527aef0e055c0836752a9dda814c4ce6f24832 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/activation/rms_norm_meta.py @@ -0,0 +1,164 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1)) +def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 3 + ( + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + assert len(input_strategy.strategies) == len(weight_strategy.strategies) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, weight in zip(input_strategy.strategies, + weight_strategy.strategies): + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=input_tgt, + input_specs=[input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(1)) +def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + output_grad_strategy, + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "input": len(input_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + + assert len(set( + lengths.values())) == 1, f"Strategies length mismatch {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + input_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, input, weight in zipped: + output_grad_src = output_grad.output_spec + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # Input must have the same sharding as output grad. + input_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, weight_tgt], + input_specs=[output_grad_tgt, input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch29-cxx11-cu130-x86_64-linux/activation/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/activation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6f29ac2c688bd09afa41c5d1abd9942c4456d8 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/activation/__init__.py @@ -0,0 +1,53 @@ +import torch + +from . import layers, parallel_style +from ._ops import ops +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction + + +def poly_norm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return PolyNormFunction.apply(x, weight, bias, eps) + + +def fused_mul_poly_norm( + x: torch.Tensor, + mul: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps) + + +def rms_norm( + x: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return RMSNormFunction.apply(x, weight, eps) + + +def fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedAddRMSNormFunction.apply(x, residual, weight, eps) + + +__all__ = [ + "poly_norm", + "fused_mul_poly_norm", + "rms_norm", + "fused_add_rms_norm", + "layers", + "parallel_style", + "ops", +] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..55fc72322664a4eac9e1b744578f3398e3602745 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74ca4a2bd28731463bd92faccb0c9077324f02cfbfe5e7c815bceb215a9fd68e +size 12493376 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/activation/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/activation/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..3725c2b21e803832098265d4704e789c837084ef --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/activation/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _activation_53ed492_dirty +ops = torch.ops._activation_53ed492_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_activation_53ed492_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/activation/fused_add_rms_norm_meta.py b/build/torch29-cxx11-cu130-x86_64-linux/activation/fused_add_rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a472844644bb93a27ae962cbc0fdc50c27ec780a --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/activation/fused_add_rms_norm_meta.py @@ -0,0 +1,199 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_fused_add_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.fused_add_rms_norm.default, + schema_info=RuntimeSchemaInfo(1)) +def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + residual_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(residual_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "input": len(input_strategy.strategies), + "residual": len(residual_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, residual, weight in zip(input_strategy.strategies, + residual_strategy.strategies, + weight_strategy.strategies): + + input_src = input.output_spec + residual_src = residual.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(residual_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Residual add must have the same sharding as input. + residual_tgt = input_tgt + redistribute_costs.append( + generate_redistribute_costs(residual_strategy, residual_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, input_tgt], + input_specs=[input_tgt, residual_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.fused_add_rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(2)) +def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 6 + ( + output_grad_strategy, + add_output_grad_strategy, + add_output_strategy, + weight_strategy, + _, # eps + need_input_grad, # need_input_grad + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(add_output_grad_strategy, OpStrategy) + assert isinstance(add_output_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "add_output_grad": len(add_output_grad_strategy.strategies), + "add_output": len(add_output_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + add_output_grad_strategy.strategies, + add_output_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = output_grad_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, add_output_grad, add_output, weight in zipped: + output_grad_src = output_grad.output_spec + add_output_grad_src = add_output_grad.output_spec + add_output_src = add_output.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(add_output_grad_src, DTensorSpec) + assert isinstance(add_output_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # add_output_grad must have the same sharding as output_grad. + add_output_grad_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_grad_strategy, + add_output_grad_tgt)) + + # add_output must have the same sharding as output_grad. + add_output_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_strategy, add_output_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[ + output_grad_tgt if need_input_grad else None, weight_tgt + ], + input_specs=[ + output_grad_tgt, add_output_grad_tgt, add_output_tgt, + weight_tgt + ], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch29-cxx11-cu130-x86_64-linux/activation/layers.py b/build/torch29-cxx11-cu130-x86_64-linux/activation/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..b1880bdbe8dd73ac76d7d4561cf60f9765097ca9 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/activation/layers.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn +from torch.nn import init + +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction + + +class PolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + ): + return PolyNormFunction.apply(x, self.weight, self.bias, self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + +class FusedMulPolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + mul: torch.Tensor, + ): + return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + +class RMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + ): + return RMSNormFunction.apply(x, self.weight, self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + + +class FusedAddRMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ): + return FusedAddRMSNormFunction.apply(x, residual, self.weight, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/activation/parallel_style.py b/build/torch29-cxx11-cu130-x86_64-linux/activation/parallel_style.py new file mode 100644 index 0000000000000000000000000000000000000000..470ab69d9889284f0be5cb075d5211eab30eb755 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/activation/parallel_style.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard, + distribute_module, distribute_tensor) +from torch.distributed.tensor.parallel import SequenceParallel +from torch.distributed.tensor.placement_types import Placement + + +class ResidualSequenceParallel(SequenceParallel): + """ Consider the case where we have a residual connection across a sequence parallel layer.""" + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + residual_tensor = inputs[1] + + assert isinstance(input_tensor, + DTensor) == isinstance(residual_tensor, DTensor) + assert isinstance(input_tensor, + torch.Tensor) == isinstance(residual_tensor, + torch.Tensor) + + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True) + if residual_tensor.placements != sequence_sharding: + residual_tensor = residual_tensor.redistribute( + placements=sequence_sharding, async_op=True) + return input_tensor, residual_tensor + + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local(input_tensor, + device_mesh, + sequence_sharding, + run_check=False), DTensor.from_local( + residual_tensor, + device_mesh, + sequence_sharding, + run_check=False) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/activation/poly_norm.py b/build/torch29-cxx11-cu130-x86_64-linux/activation/poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..8a0fd85f1835e02a36eb9184874d77dcad8221f9 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/activation/poly_norm.py @@ -0,0 +1,76 @@ +import torch + +from ._ops import ops + + +# Inherit from Function +class PolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, weight, bias, eps): + output = torch.empty_like(input) + ops.poly_norm(output, input, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, weight, bias, eps = inputs + ctx.save_for_backward(input, weight) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, weight = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[2] else None) + + ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad, + input, weight, eps) + + return input_grad, weight_grad, bias_grad, None + + +class FusedMulPolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, mul, weight, bias, eps): + output = torch.empty_like(input) + ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, mul, weight, bias, eps = inputs + ctx.save_for_backward(input, mul, weight, bias) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, mul, weight, bias = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[3] else None) + + ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad, + bias_grad, output_grad, input, mul, + weight, bias, eps) + + return input_grad, mul_grad, weight_grad, bias_grad, None diff --git a/build/torch29-cxx11-cu130-x86_64-linux/activation/rms_norm.py b/build/torch29-cxx11-cu130-x86_64-linux/activation/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..2b3ab7e1476aba5d7799ff888449470e23665676 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/activation/rms_norm.py @@ -0,0 +1,105 @@ +from collections.abc import Sequence + +import torch +from packaging import version + +from ._ops import ops + + +# Inherit from Function +class RMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, weight, eps): + return ops.rms_norm(input, weight, eps) + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, weight, eps = inputs + ctx.save_for_backward(input, weight) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, weight = ctx.saved_tensors + eps = ctx.eps + + input_grad, weight_grad = ops.rms_norm_backward( + output_grad, input, weight, eps) + + return input_grad, weight_grad, None + + +# Inherit from Function +class FusedAddRMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, residual, weight, eps): + output, add_output = ops.fused_add_rms_norm(input, residual, weight, + eps) + return output, add_output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, outputs): + _, _, weight, eps = inputs + _, add_output = outputs + ctx.save_for_backward(weight, add_output) + ctx.eps = eps + + @staticmethod + def backward(ctx, output_grad, add_output_grad): + weight, add_output = ctx.saved_tensors + eps = ctx.eps + + need_in = ctx.needs_input_grad[0] + need_res = ctx.needs_input_grad[1] + + # TODO(ai-system): kernels currently do not support no input gradients + assert need_in or need_res, "Not implemented for no input gradients yet" + + grad, weight_grad = ops.fused_add_rms_norm_backward( + output_grad, + add_output_grad, + add_output, + weight, + eps, + need_input_grad=need_in or need_res) + input_grad = grad if need_in else None + residual_grad = grad if need_res else None + + return input_grad, residual_grad, weight_grad, None + + +@torch.library.register_fake(ops.rms_norm.default) +def rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + + +@torch.library.register_fake(ops.rms_norm_backward.default) +def rms_norm_backward_abstract(output_grad, x, weight, eps): + return torch.empty_like(x), torch.empty_like(weight) + + +@torch.library.register_fake(ops.fused_add_rms_norm.default) +def fused_add_rms_norm_abstract(x, residual, weight, eps): + return torch.empty_like(x), torch.empty_like(x) + + +@torch.library.register_fake(ops.fused_add_rms_norm_backward.default) +def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad, + add_output, weight, eps, + need_input_grad: bool): + return torch.empty_like( + output_grad) if need_input_grad else None, torch.empty_like(weight) + + +if version.parse(torch.__version__) >= version.parse("2.8"): + from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta + from .rms_norm_meta import register_rms_norm_meta + register_fused_add_rms_norm_meta() + register_rms_norm_meta() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/activation/rms_norm_meta.py b/build/torch29-cxx11-cu130-x86_64-linux/activation/rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..12527aef0e055c0836752a9dda814c4ce6f24832 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/activation/rms_norm_meta.py @@ -0,0 +1,164 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1)) +def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 3 + ( + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + assert len(input_strategy.strategies) == len(weight_strategy.strategies) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, weight in zip(input_strategy.strategies, + weight_strategy.strategies): + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=input_tgt, + input_specs=[input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(1)) +def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + output_grad_strategy, + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "input": len(input_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + + assert len(set( + lengths.values())) == 1, f"Strategies length mismatch {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + input_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, input, weight in zipped: + output_grad_src = output_grad.output_spec + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # Input must have the same sharding as output grad. + input_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, weight_tgt], + input_specs=[output_grad_tgt, input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/activation/__init__.py b/build/torch29-cxx11-rocm63-x86_64-linux/activation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6f29ac2c688bd09afa41c5d1abd9942c4456d8 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/activation/__init__.py @@ -0,0 +1,53 @@ +import torch + +from . import layers, parallel_style +from ._ops import ops +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction + + +def poly_norm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return PolyNormFunction.apply(x, weight, bias, eps) + + +def fused_mul_poly_norm( + x: torch.Tensor, + mul: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps) + + +def rms_norm( + x: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return RMSNormFunction.apply(x, weight, eps) + + +def fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedAddRMSNormFunction.apply(x, residual, weight, eps) + + +__all__ = [ + "poly_norm", + "fused_mul_poly_norm", + "rms_norm", + "fused_add_rms_norm", + "layers", + "parallel_style", + "ops", +] diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..715147c9d7a888621b5cd472a2875eee4c4b9896 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2fc29c93c326e89e5f1c17fdf5c0df7992e56d5bfde50fe7f1f45bad491d9173 +size 2774936 diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/activation/_ops.py b/build/torch29-cxx11-rocm63-x86_64-linux/activation/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..3725c2b21e803832098265d4704e789c837084ef --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/activation/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _activation_53ed492_dirty +ops = torch.ops._activation_53ed492_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_activation_53ed492_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/activation/fused_add_rms_norm_meta.py b/build/torch29-cxx11-rocm63-x86_64-linux/activation/fused_add_rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a472844644bb93a27ae962cbc0fdc50c27ec780a --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/activation/fused_add_rms_norm_meta.py @@ -0,0 +1,199 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_fused_add_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.fused_add_rms_norm.default, + schema_info=RuntimeSchemaInfo(1)) +def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + residual_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(residual_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "input": len(input_strategy.strategies), + "residual": len(residual_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, residual, weight in zip(input_strategy.strategies, + residual_strategy.strategies, + weight_strategy.strategies): + + input_src = input.output_spec + residual_src = residual.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(residual_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Residual add must have the same sharding as input. + residual_tgt = input_tgt + redistribute_costs.append( + generate_redistribute_costs(residual_strategy, residual_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, input_tgt], + input_specs=[input_tgt, residual_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.fused_add_rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(2)) +def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 6 + ( + output_grad_strategy, + add_output_grad_strategy, + add_output_strategy, + weight_strategy, + _, # eps + need_input_grad, # need_input_grad + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(add_output_grad_strategy, OpStrategy) + assert isinstance(add_output_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "add_output_grad": len(add_output_grad_strategy.strategies), + "add_output": len(add_output_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + add_output_grad_strategy.strategies, + add_output_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = output_grad_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, add_output_grad, add_output, weight in zipped: + output_grad_src = output_grad.output_spec + add_output_grad_src = add_output_grad.output_spec + add_output_src = add_output.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(add_output_grad_src, DTensorSpec) + assert isinstance(add_output_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # add_output_grad must have the same sharding as output_grad. + add_output_grad_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_grad_strategy, + add_output_grad_tgt)) + + # add_output must have the same sharding as output_grad. + add_output_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_strategy, add_output_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[ + output_grad_tgt if need_input_grad else None, weight_tgt + ], + input_specs=[ + output_grad_tgt, add_output_grad_tgt, add_output_tgt, + weight_tgt + ], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/activation/layers.py b/build/torch29-cxx11-rocm63-x86_64-linux/activation/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..b1880bdbe8dd73ac76d7d4561cf60f9765097ca9 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/activation/layers.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn +from torch.nn import init + +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction + + +class PolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + ): + return PolyNormFunction.apply(x, self.weight, self.bias, self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + +class FusedMulPolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + mul: torch.Tensor, + ): + return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + +class RMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + ): + return RMSNormFunction.apply(x, self.weight, self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + + +class FusedAddRMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ): + return FusedAddRMSNormFunction.apply(x, residual, self.weight, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/activation/parallel_style.py b/build/torch29-cxx11-rocm63-x86_64-linux/activation/parallel_style.py new file mode 100644 index 0000000000000000000000000000000000000000..470ab69d9889284f0be5cb075d5211eab30eb755 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/activation/parallel_style.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard, + distribute_module, distribute_tensor) +from torch.distributed.tensor.parallel import SequenceParallel +from torch.distributed.tensor.placement_types import Placement + + +class ResidualSequenceParallel(SequenceParallel): + """ Consider the case where we have a residual connection across a sequence parallel layer.""" + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + residual_tensor = inputs[1] + + assert isinstance(input_tensor, + DTensor) == isinstance(residual_tensor, DTensor) + assert isinstance(input_tensor, + torch.Tensor) == isinstance(residual_tensor, + torch.Tensor) + + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True) + if residual_tensor.placements != sequence_sharding: + residual_tensor = residual_tensor.redistribute( + placements=sequence_sharding, async_op=True) + return input_tensor, residual_tensor + + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local(input_tensor, + device_mesh, + sequence_sharding, + run_check=False), DTensor.from_local( + residual_tensor, + device_mesh, + sequence_sharding, + run_check=False) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/activation/poly_norm.py b/build/torch29-cxx11-rocm63-x86_64-linux/activation/poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..8a0fd85f1835e02a36eb9184874d77dcad8221f9 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/activation/poly_norm.py @@ -0,0 +1,76 @@ +import torch + +from ._ops import ops + + +# Inherit from Function +class PolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, weight, bias, eps): + output = torch.empty_like(input) + ops.poly_norm(output, input, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, weight, bias, eps = inputs + ctx.save_for_backward(input, weight) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, weight = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[2] else None) + + ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad, + input, weight, eps) + + return input_grad, weight_grad, bias_grad, None + + +class FusedMulPolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, mul, weight, bias, eps): + output = torch.empty_like(input) + ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, mul, weight, bias, eps = inputs + ctx.save_for_backward(input, mul, weight, bias) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, mul, weight, bias = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[3] else None) + + ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad, + bias_grad, output_grad, input, mul, + weight, bias, eps) + + return input_grad, mul_grad, weight_grad, bias_grad, None diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/activation/rms_norm.py b/build/torch29-cxx11-rocm63-x86_64-linux/activation/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..2b3ab7e1476aba5d7799ff888449470e23665676 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/activation/rms_norm.py @@ -0,0 +1,105 @@ +from collections.abc import Sequence + +import torch +from packaging import version + +from ._ops import ops + + +# Inherit from Function +class RMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, weight, eps): + return ops.rms_norm(input, weight, eps) + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, weight, eps = inputs + ctx.save_for_backward(input, weight) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, weight = ctx.saved_tensors + eps = ctx.eps + + input_grad, weight_grad = ops.rms_norm_backward( + output_grad, input, weight, eps) + + return input_grad, weight_grad, None + + +# Inherit from Function +class FusedAddRMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, residual, weight, eps): + output, add_output = ops.fused_add_rms_norm(input, residual, weight, + eps) + return output, add_output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, outputs): + _, _, weight, eps = inputs + _, add_output = outputs + ctx.save_for_backward(weight, add_output) + ctx.eps = eps + + @staticmethod + def backward(ctx, output_grad, add_output_grad): + weight, add_output = ctx.saved_tensors + eps = ctx.eps + + need_in = ctx.needs_input_grad[0] + need_res = ctx.needs_input_grad[1] + + # TODO(ai-system): kernels currently do not support no input gradients + assert need_in or need_res, "Not implemented for no input gradients yet" + + grad, weight_grad = ops.fused_add_rms_norm_backward( + output_grad, + add_output_grad, + add_output, + weight, + eps, + need_input_grad=need_in or need_res) + input_grad = grad if need_in else None + residual_grad = grad if need_res else None + + return input_grad, residual_grad, weight_grad, None + + +@torch.library.register_fake(ops.rms_norm.default) +def rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + + +@torch.library.register_fake(ops.rms_norm_backward.default) +def rms_norm_backward_abstract(output_grad, x, weight, eps): + return torch.empty_like(x), torch.empty_like(weight) + + +@torch.library.register_fake(ops.fused_add_rms_norm.default) +def fused_add_rms_norm_abstract(x, residual, weight, eps): + return torch.empty_like(x), torch.empty_like(x) + + +@torch.library.register_fake(ops.fused_add_rms_norm_backward.default) +def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad, + add_output, weight, eps, + need_input_grad: bool): + return torch.empty_like( + output_grad) if need_input_grad else None, torch.empty_like(weight) + + +if version.parse(torch.__version__) >= version.parse("2.8"): + from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta + from .rms_norm_meta import register_rms_norm_meta + register_fused_add_rms_norm_meta() + register_rms_norm_meta() diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/activation/rms_norm_meta.py b/build/torch29-cxx11-rocm63-x86_64-linux/activation/rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..12527aef0e055c0836752a9dda814c4ce6f24832 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/activation/rms_norm_meta.py @@ -0,0 +1,164 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1)) +def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 3 + ( + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + assert len(input_strategy.strategies) == len(weight_strategy.strategies) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, weight in zip(input_strategy.strategies, + weight_strategy.strategies): + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=input_tgt, + input_specs=[input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(1)) +def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + output_grad_strategy, + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "input": len(input_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + + assert len(set( + lengths.values())) == 1, f"Strategies length mismatch {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + input_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, input, weight in zipped: + output_grad_src = output_grad.output_spec + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # Input must have the same sharding as output grad. + input_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, weight_tgt], + input_specs=[output_grad_tgt, input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/activation/__init__.py b/build/torch29-cxx11-rocm64-x86_64-linux/activation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6f29ac2c688bd09afa41c5d1abd9942c4456d8 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/activation/__init__.py @@ -0,0 +1,53 @@ +import torch + +from . import layers, parallel_style +from ._ops import ops +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction + + +def poly_norm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return PolyNormFunction.apply(x, weight, bias, eps) + + +def fused_mul_poly_norm( + x: torch.Tensor, + mul: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps) + + +def rms_norm( + x: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return RMSNormFunction.apply(x, weight, eps) + + +def fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedAddRMSNormFunction.apply(x, residual, weight, eps) + + +__all__ = [ + "poly_norm", + "fused_mul_poly_norm", + "rms_norm", + "fused_add_rms_norm", + "layers", + "parallel_style", + "ops", +] diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d19ede12d734026716e7bfe4fe93591834cf5e7f --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/activation/_activation_53ed492_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33be6e5daa4a8369a3931ffce7492c506eb7a6536ce786590f0f0359670502f4 +size 2784728 diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/activation/_ops.py b/build/torch29-cxx11-rocm64-x86_64-linux/activation/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..3725c2b21e803832098265d4704e789c837084ef --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/activation/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _activation_53ed492_dirty +ops = torch.ops._activation_53ed492_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_activation_53ed492_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/activation/fused_add_rms_norm_meta.py b/build/torch29-cxx11-rocm64-x86_64-linux/activation/fused_add_rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..a472844644bb93a27ae962cbc0fdc50c27ec780a --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/activation/fused_add_rms_norm_meta.py @@ -0,0 +1,199 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_fused_add_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.fused_add_rms_norm.default, + schema_info=RuntimeSchemaInfo(1)) +def fused_add_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + input_strategy, + residual_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(residual_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "input": len(input_strategy.strategies), + "residual": len(residual_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, residual, weight in zip(input_strategy.strategies, + residual_strategy.strategies, + weight_strategy.strategies): + + input_src = input.output_spec + residual_src = residual.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(residual_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Residual add must have the same sharding as input. + residual_tgt = input_tgt + redistribute_costs.append( + generate_redistribute_costs(residual_strategy, residual_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, input_tgt], + input_specs=[input_tgt, residual_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.fused_add_rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(2)) +def fused_add_rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 6 + ( + output_grad_strategy, + add_output_grad_strategy, + add_output_strategy, + weight_strategy, + _, # eps + need_input_grad, # need_input_grad + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(add_output_grad_strategy, OpStrategy) + assert isinstance(add_output_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "add_output_grad": len(add_output_grad_strategy.strategies), + "add_output": len(add_output_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + assert len(set( + lengths.values())) == 1, f"Strategy length mismatch: {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + add_output_grad_strategy.strategies, + add_output_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = output_grad_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, add_output_grad, add_output, weight in zipped: + output_grad_src = output_grad.output_spec + add_output_grad_src = add_output_grad.output_spec + add_output_src = add_output.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(add_output_grad_src, DTensorSpec) + assert isinstance(add_output_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # add_output_grad must have the same sharding as output_grad. + add_output_grad_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_grad_strategy, + add_output_grad_tgt)) + + # add_output must have the same sharding as output_grad. + add_output_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(add_output_strategy, add_output_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[ + output_grad_tgt if need_input_grad else None, weight_tgt + ], + input_specs=[ + output_grad_tgt, add_output_grad_tgt, add_output_tgt, + weight_tgt + ], + redistribute_cost=redistribute_costs, + )) + return strategy diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/activation/layers.py b/build/torch29-cxx11-rocm64-x86_64-linux/activation/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..b1880bdbe8dd73ac76d7d4561cf60f9765097ca9 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/activation/layers.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn +from torch.nn import init + +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction + + +class PolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + ): + return PolyNormFunction.apply(x, self.weight, self.bias, self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + +class FusedMulPolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + mul: torch.Tensor, + ): + return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + +class RMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + ): + return RMSNormFunction.apply(x, self.weight, self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + + +class FusedAddRMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ): + return FusedAddRMSNormFunction.apply(x, residual, self.weight, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/activation/parallel_style.py b/build/torch29-cxx11-rocm64-x86_64-linux/activation/parallel_style.py new file mode 100644 index 0000000000000000000000000000000000000000..470ab69d9889284f0be5cb075d5211eab30eb755 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/activation/parallel_style.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.tensor import (DeviceMesh, DTensor, Replicate, Shard, + distribute_module, distribute_tensor) +from torch.distributed.tensor.parallel import SequenceParallel +from torch.distributed.tensor.placement_types import Placement + + +class ResidualSequenceParallel(SequenceParallel): + """ Consider the case where we have a residual connection across a sequence parallel layer.""" + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + residual_tensor = inputs[1] + + assert isinstance(input_tensor, + DTensor) == isinstance(residual_tensor, DTensor) + assert isinstance(input_tensor, + torch.Tensor) == isinstance(residual_tensor, + torch.Tensor) + + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True) + if residual_tensor.placements != sequence_sharding: + residual_tensor = residual_tensor.redistribute( + placements=sequence_sharding, async_op=True) + return input_tensor, residual_tensor + + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local(input_tensor, + device_mesh, + sequence_sharding, + run_check=False), DTensor.from_local( + residual_tensor, + device_mesh, + sequence_sharding, + run_check=False) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/activation/poly_norm.py b/build/torch29-cxx11-rocm64-x86_64-linux/activation/poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..8a0fd85f1835e02a36eb9184874d77dcad8221f9 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/activation/poly_norm.py @@ -0,0 +1,76 @@ +import torch + +from ._ops import ops + + +# Inherit from Function +class PolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, weight, bias, eps): + output = torch.empty_like(input) + ops.poly_norm(output, input, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, weight, bias, eps = inputs + ctx.save_for_backward(input, weight) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, weight = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[2] else None) + + ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad, + input, weight, eps) + + return input_grad, weight_grad, bias_grad, None + + +class FusedMulPolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, mul, weight, bias, eps): + output = torch.empty_like(input) + ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, mul, weight, bias, eps = inputs + ctx.save_for_backward(input, mul, weight, bias) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, mul, weight, bias = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[3] else None) + + ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad, + bias_grad, output_grad, input, mul, + weight, bias, eps) + + return input_grad, mul_grad, weight_grad, bias_grad, None diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/activation/rms_norm.py b/build/torch29-cxx11-rocm64-x86_64-linux/activation/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..2b3ab7e1476aba5d7799ff888449470e23665676 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/activation/rms_norm.py @@ -0,0 +1,105 @@ +from collections.abc import Sequence + +import torch +from packaging import version + +from ._ops import ops + + +# Inherit from Function +class RMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, weight, eps): + return ops.rms_norm(input, weight, eps) + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, weight, eps = inputs + ctx.save_for_backward(input, weight) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, weight = ctx.saved_tensors + eps = ctx.eps + + input_grad, weight_grad = ops.rms_norm_backward( + output_grad, input, weight, eps) + + return input_grad, weight_grad, None + + +# Inherit from Function +class FusedAddRMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, residual, weight, eps): + output, add_output = ops.fused_add_rms_norm(input, residual, weight, + eps) + return output, add_output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, outputs): + _, _, weight, eps = inputs + _, add_output = outputs + ctx.save_for_backward(weight, add_output) + ctx.eps = eps + + @staticmethod + def backward(ctx, output_grad, add_output_grad): + weight, add_output = ctx.saved_tensors + eps = ctx.eps + + need_in = ctx.needs_input_grad[0] + need_res = ctx.needs_input_grad[1] + + # TODO(ai-system): kernels currently do not support no input gradients + assert need_in or need_res, "Not implemented for no input gradients yet" + + grad, weight_grad = ops.fused_add_rms_norm_backward( + output_grad, + add_output_grad, + add_output, + weight, + eps, + need_input_grad=need_in or need_res) + input_grad = grad if need_in else None + residual_grad = grad if need_res else None + + return input_grad, residual_grad, weight_grad, None + + +@torch.library.register_fake(ops.rms_norm.default) +def rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + + +@torch.library.register_fake(ops.rms_norm_backward.default) +def rms_norm_backward_abstract(output_grad, x, weight, eps): + return torch.empty_like(x), torch.empty_like(weight) + + +@torch.library.register_fake(ops.fused_add_rms_norm.default) +def fused_add_rms_norm_abstract(x, residual, weight, eps): + return torch.empty_like(x), torch.empty_like(x) + + +@torch.library.register_fake(ops.fused_add_rms_norm_backward.default) +def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad, + add_output, weight, eps, + need_input_grad: bool): + return torch.empty_like( + output_grad) if need_input_grad else None, torch.empty_like(weight) + + +if version.parse(torch.__version__) >= version.parse("2.8"): + from .fused_add_rms_norm_meta import register_fused_add_rms_norm_meta + from .rms_norm_meta import register_rms_norm_meta + register_fused_add_rms_norm_meta() + register_rms_norm_meta() diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/activation/rms_norm_meta.py b/build/torch29-cxx11-rocm64-x86_64-linux/activation/rms_norm_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..12527aef0e055c0836752a9dda814c4ce6f24832 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/activation/rms_norm_meta.py @@ -0,0 +1,164 @@ +from collections.abc import Sequence + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy, + RuntimeSchemaInfo) +from torch.distributed.tensor._ops.utils import (generate_redistribute_costs, + register_op_strategy) +from torch.distributed.tensor.placement_types import (Placement, Replicate, + Shard) + +from ._ops import ops + + +def register_rms_norm_meta(): + """Dummy function to register the meta functions. + Registration happens at import time by the decorators below. + """ + pass + + +def _replicate_dims_start_at(placements: Sequence[Placement], + start_dim: int = 0) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +@register_op_strategy(ops.rms_norm.default, schema_info=RuntimeSchemaInfo(1)) +def rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 3 + ( + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + assert len(input_strategy.strategies) == len(weight_strategy.strategies) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for input, weight in zip(input_strategy.strategies, + weight_strategy.strategies): + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Input can be sharded in any dim except the last dim. + input_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src.placements, + last_dim), + tensor_meta=input_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=input_tgt, + input_specs=[input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy + + +@register_op_strategy(ops.rms_norm_backward.default, + schema_info=RuntimeSchemaInfo(1)) +def rms_norm_backward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 4 + ( + output_grad_strategy, + input_strategy, + weight_strategy, + _, # eps + ) = op_schema.args_schema + + assert isinstance(output_grad_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(weight_strategy, OpStrategy) + + lengths = { + "output_grad": len(output_grad_strategy.strategies), + "input": len(input_strategy.strategies), + "weight": len(weight_strategy.strategies), + } + + assert len(set( + lengths.values())) == 1, f"Strategies length mismatch {lengths}" + + zipped = zip( + output_grad_strategy.strategies, + input_strategy.strategies, + weight_strategy.strategies, + ) + + last_dim = input_strategy.ndim - 1 + strategy = OpStrategy([]) + for output_grad, input, weight in zipped: + output_grad_src = output_grad.output_spec + input_src = input.output_spec + weight_src = weight.output_spec + + assert isinstance(output_grad_src, DTensorSpec) + assert isinstance(input_src, DTensorSpec) + assert isinstance(weight_src, DTensorSpec) + + redistribute_costs = [] + + # Output grad can be sharded in any dim except the last dim. + output_grad_tgt = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(output_grad_src.placements, + last_dim), + tensor_meta=output_grad_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(output_grad_strategy, output_grad_tgt)) + + # Input must have the same sharding as output grad. + input_tgt = output_grad_tgt + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_tgt)) + + # Weight cannot be sharded, so always replicate it. + weight_tgt = DTensorSpec( + mesh=mesh, + placements=(Replicate(), ), + tensor_meta=weight_src.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_tgt)) + + strategy.strategies.append( + OpSpec( + output_specs=[input_tgt, weight_tgt], + input_specs=[output_grad_tgt, input_tgt, weight_tgt], + redistribute_cost=redistribute_costs, + )) + return strategy