diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_IR.py b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_IR.py
new file mode 100644
index 0000000000000000000000000000000000000000..036ad6cf8621eb4e0473f5ba6494f347d90d97d4
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_IR.py
@@ -0,0 +1,1243 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import copy
+import logging
+import operator
+from collections import defaultdict
+from enum import Enum
+from inspect import Parameter, Signature, signature
+from types import MethodType
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.fx as fx
+from torch.distributed import ProcessGroup
+from torch.export import ExportedProgram
+from torch.export.unflatten import (
+    _assign_attr,
+    _AttrKind,
+    _sink_params,
+    InterpreterModule,
+)
+from torch.fx.node import map_aggregate
+from torch.fx.passes.split_module import split_module
+
+from ._backward import _null_coalesce_accumulate, stage_backward
+from ._unflatten import _outline_submodules
+from ._utils import PipeInfo
+from .stage import _PipelineStage
+
+
+logger = logging.getLogger(__name__)
+
+# TODO:
+# 1. investigate gradient sync for shared parameters. how does DDP do it?
+# 2. Add parameter movement to split_module
+
+
+def _find_loss_from_output_and_spec(output_val, spec_val):
+    if spec_val is False:
+        return None
+    if spec_val is True:
+        if not isinstance(output_val, fx.Node):
+            raise RuntimeError(
+                f"Loss spec must specify a dynamic value but got {output_val}"
+            )
+        return output_val
+
+    if isinstance(spec_val, (tuple, list)):
+        if not isinstance(output_val, (tuple, list)):
+            raise RuntimeError(
+                f"Output value {output_val} must match type of loss specification "
+                f"{spec_val}"
+            )
+        if len(output_val) != len(spec_val):
+            raise RuntimeError(
+                f"Output value {output_val} must match length of loss specification "
+                f"{spec_val}"
+            )
+        for out, spec in zip(output_val, spec_val):
+            loss_val = _find_loss_from_output_and_spec(out, spec)
+            if loss_val is not None:
+                return loss_val
+        raise RuntimeError(f"Did not find loss value in specification {spec_val}")
+
+    if isinstance(spec_val, dict):
+        if not isinstance(output_val, dict):
+            raise RuntimeError(
+                f"Output value {output_val} must match type of loss specification "
+                f"{spec_val}"
+            )
+        if set(output_val.keys()) != set(spec_val.keys()):
+            raise RuntimeError(
+                f"Output value {output_val} must match keys of loss specification "
+                f"{spec_val}"
+            )
+        for k in spec_val:
+            loss_val = _find_loss_from_output_and_spec(output_val[k], spec_val[k])
+            if loss_val is not None:
+                return loss_val
+        raise RuntimeError(f"Did not find loss value in specification {spec_val}")
+
+    raise RuntimeError(f"Unsupported type {type(spec_val)} in loss specification")
+
+
+def _find_loss_output(mod: torch.nn.Module, g: fx.Graph, output_loss_value_spec):
+    output_nodes = [n for n in g.nodes if n.op == "output"]
+    assert len(output_nodes) == 1
+    output_node = output_nodes[0]
+    output_val = output_node.args[0]
+    generated_spec: Any = None
+
+    if isinstance(mod, TrivialLossWrapper):
+        # TrivialLossWrapper is pre-defined by PiPPy.
+        # It has loss as the only output so we can safely assume the first output arg is the loss.
+        assert len(output_node.args) == 1
+        loss_node = output_val
+        generated_spec = TrivialLossWrapper.loss_spec
+    elif output_loss_value_spec is None:
+        # Use default spec, i.e. search for "loss" in output values
+        if isinstance(output_val, dict) and "loss" in output_val.keys():
+            loss_node = output_val["loss"]
+            generated_spec = {k: k == "loss" for k in output_val}
+        else:
+            loss_node = None
+            generated_spec = None
+    else:
+        loss_node = _find_loss_from_output_and_spec(output_val, output_loss_value_spec)
+        generated_spec = output_loss_value_spec
+
+    return loss_node, output_node, generated_spec
+
+
+def _insert_stage_symbolic_backward(
+    g: fx.Graph,
+    loss_node: fx.Node,
+    output_node: fx.Node,
+):
+    # Collect metadata about tuple output values. TODO: move this to split_module or FX IR
+    tuples: Dict[fx.Node, Tuple] = {}
+    for node in reversed(g.nodes):
+        if node.op == "call_function":
+            # In the forward pass, only emit placeholder, module calls, and
+            # getitem calls. If we have a target other than getitem in this
+            # (forward-only) code, there is a bug.
+            assert node.target == operator.getitem, (
+                "Found non-getitem call in forward pass. "
+                "Please report a bug to PiPPy"
+            )
+            assert (
+                len(node.args) == 2
+            ), "Found malformed getitem call. Please report a bug to PiPPy"
+            indexed_value, node_idx = tuple(node.args)
+
+            # indexed_value is a collection that we are indexing into. It could
+            # exist in the tuples map if we've processed another `getitem`
+            # already.
+            existing_list_size = (
+                len(tuples[indexed_value]) if indexed_value in tuples else -1
+            )
+            new_list_size = max(node_idx + 1, existing_list_size)
+
+            reconstructed_list = [None for _ in range(new_list_size)]
+
+            # Copy over existing elements if present
+            if indexed_value in tuples:
+                for i, val in enumerate(tuples[indexed_value]):
+                    reconstructed_list[i] = val
+
+            # Populate value represented by this node
+            reconstructed_list[node_idx] = node
+
+            tuples[indexed_value] = tuple(reconstructed_list)
+
+    # Keep track of nodes that dominate the loss node.
+    # We will only emit backward operations for nodes that can contribute
+    # to the specified loss value.
+    live_nodes = {loss_node: None}
+    val_to_grad: Dict[fx.Node, Optional[fx.Node]] = {loss_node: None}
+
+    def assign_or_accumulate_grad(forward_node, grad_value):
+        if forward_node in val_to_grad and forward_node.op != "placeholder":
+            grad_value = g.call_function(
+                _null_coalesce_accumulate,
+                (val_to_grad[forward_node], grad_value),
+            )
+        val_to_grad[forward_node] = grad_value
+
+    with g.inserting_before(output_node):
+        for node in reversed(g.nodes):
+            if node not in live_nodes:
+                continue
+
+            def add_to_live_nodes(n):
+                live_nodes.setdefault(n, None)
+
+            fx.node.map_arg(node.args, add_to_live_nodes)
+            fx.node.map_arg(node.kwargs, add_to_live_nodes)
+            if node.op == "call_module":
+                output_grads: Union[Tuple[Optional[fx.Node], ...], Optional[fx.Node]]
+                if node in tuples:
+                    stage_output = tuples[node]
+                    output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node])
+                    outputs_with_grads_idxs = [
+                        i for i, n in enumerate(tuples[node]) if n in live_nodes
+                    ]
+                else:
+                    stage_output = (node,)
+                    output_grads = val_to_grad[node]
+                    outputs_with_grads_idxs = [0]
+
+                output_grads = (
+                    (output_grads,)
+                    if not isinstance(output_grads, tuple)
+                    else output_grads
+                )
+
+                grad_call = g.call_function(
+                    stage_backward,
+                    kwargs={
+                        "stage_output": stage_output,
+                        "output_grads": output_grads,
+                        "input_values": list(node.all_input_nodes),
+                        "outputs_with_grads_idxs": outputs_with_grads_idxs,
+                    },
+                )
+                # Insert backward stage debug info
+                kwargs_copy = dict(grad_call.kwargs)
+                grad_call.kwargs = kwargs_copy
+
+                grad_call_proxy = fx.Proxy(grad_call)
+                grads = grad_call_proxy.node
+
+                input_nodes = list(node.all_input_nodes)
+                grads_proxy = fx.Proxy(grads)
+                for i, input_node in enumerate(input_nodes):
+                    assign_or_accumulate_grad(input_node, grads_proxy[i].node)  # type: ignore[index]
+
+    return g
+
+
+class PipeSequential(torch.nn.Sequential):
+    @staticmethod
+    def from_sequential(sequential_instance: torch.nn.Sequential):
+        return PipeSequential(*[copy.copy(m) for m in sequential_instance])
+
+    def forward(self, input):
+        for i, module in enumerate(self):
+            input = module(input)
+            if i != len(self) - 1:
+                pipe_split()
+        return input
+
+
+class LossWrapper(torch.nn.Module):
+    """
+    LossWrapper is a convenient abstract class that allows you to wrap up both
+    your model as well as its loss function and specify the connectivity between
+    the inputs, model, loss function, and output value. Example::
+
+        class MyModelWrapper(LossWrapper):
+            def forward(self, x, targets):
+                model_out = self.module(x)
+                loss_value = self.loss_fn(model_out, targets)
+                return loss_value
+
+    The above example defines a connectivity where we expect the forward/loss/backward
+    training procedure to take two arguments (x and targets), pass x into the module
+    to get the output of the feedforward computation, pass the model output and the
+    targets value into the loss function, and get and return the loss value, which will
+    be backpropagated by PiPPy. The above class would then be instantiated like::
+
+        model = ...  # instantiate the model
+        loss_fn = torch.nn.MSELoss()  # for the sake of demonstration
+
+        wrapper = MyModelWrapper(model, loss_fn)
+        pipe = Pipe.from_tracing(wrapper, ...)
+
+    """
+
+    def __init__(self, module, loss_fn):
+        super().__init__()
+        self.module = module
+        self.loss_fn = loss_fn
+
+    def forward(self, *args, **kwargs):
+        raise NotImplementedError(
+            "This instance of LossWrapper does not have an overridden "
+            "forward(). Please implement forward() to specify the arguments, "
+            "connection between the module and loss, and loss output "
+            "value."
+        )
+
+
+class TrivialLossWrapper(LossWrapper):
+    def forward(self, x, targets):
+        model_out = self.module(x)
+        return self.loss_fn(model_out, targets)
+
+    loss_spec = True
+
+
+# Pipe model representation
+#
+# Pipe can be thought of as an `nn.Sequential++`. That is to say: it specifies
+# a single topological ordering of pipeline "stages" that, when run in series,
+# constitutes all of the operations of the program. However, unlike `nn.Sequential`,
+# Pipe allows non-local usages of values, so long as those uses still respect
+# topological ordering. In particular:
+#
+# 1. Non-local activations. This type of usage can appear in, for example, skip
+#    connections. These values will be directly transmitted from the "def" stage
+#    to all stages that use them, skipping intermediate stages. During autograd,
+#    gradients will be propagated back through this skip connection in reverse
+#    of how the activations propagated in the forward pass.
+# 2. Non-local parameter/module invocations. This occurs when a parameter is used
+#    in a stage downstream of where it is resident. These values can be carried
+#    forward similarly to (1), but in addition one might want to replicate the
+#    value on multiple stages. Gradients for these shared parameters will be
+#    accumulated separately on each stage, but there will be an additional
+#    gradient accumulation before the optimizer step.
+
+
+# Register `_pipe_split()` as an ATen operator. This is required for Export to
+# preserve this marker in the graph.
+torch.library.define("pippy::_pipe_split", "() -> ()")
+
+
+@torch.library.impl("pippy::_pipe_split", "BackendSelect")
+def _pipe_split():
+    return None
+
+
+@torch.library.register_fake("pippy::_pipe_split")  # type: ignore[no-redef]
+def _pipe_split():  # noqa: F811
+    return None
+
+
+# Add an alias for convenience
+aten_pipe_split_alias = torch.ops.pippy._pipe_split.default
+
+# Ask Export to preserve the `_pipe_split` op.
+# See examples in pytorch/torch/fx/node.py
+fx.node._side_effectful_functions.add(aten_pipe_split_alias)
+
+
+# User facing API
+def pipe_split():
+    """
+    pipe_split is a special operator that is used to mark the boundary between
+    stages in a module. It is used to split the module into stages. It is a
+    no-op if your annotated module is run eagerly.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> def forward(self, x):
+        >>>     x = torch.mm(x, self.mm_param)
+        >>>     x = torch.relu(x)
+        >>>     pipe_split()
+        >>>     x = self.lin(x)
+        >>>     return x
+
+    The above example will be split into two stages.
+    """
+    return torch.ops.pippy._pipe_split()
+
+
+class MultiUseParameterConfig(Enum):
+    TRANSMIT = 1
+    REPLICATE = 2
+
+
+MultiUseParamSpec = Union[MultiUseParameterConfig, Dict[str, MultiUseParameterConfig]]
+
+
+class DetachExecutor(fx.Interpreter):
+    """
+    Special interpreter to run the split_gm in testing, detaching all inputs to
+    a module invocation. This is needed so that the values at the boundary are
+    leaf tensors in autograd execution.
+ """ + + def __init__(self, module, garbage_collect_values=True): + garbage_collect_values = False + super().__init__(module, garbage_collect_values) + self.value_remap = {} + + def run(self, *args, initial_env=None): + self.value_remap = {} + return super().run(*args, initial_env=initial_env) + + def call_module(self, target, args, kwargs): + def detach_tensors(a): + if isinstance(a, torch.Tensor) and a.requires_grad: + if a not in self.value_remap: + new_val = a.detach().requires_grad_(True) + self.value_remap[a] = new_val + return self.value_remap[a] + else: + return a + + """ + def dont_traverse_size(a): + return type(a) != torch.Size + """ + + args = map_aggregate( + args, + detach_tensors, # dont_traverse_size + ) + kwargs = map_aggregate( + kwargs, + detach_tensors, # dont_traverse_size + ) + + return super().call_module(target, args, kwargs) + + def call_function(self, target, args, kwargs): + # HACK to reroute saved input tensors to point to the detach()ed version + if target == stage_backward: + kwargs = dict(kwargs) + kwargs["input_values"] = [ + self.value_remap.get(v, v) for v in kwargs["input_values"] + ] + return super().call_function(target, args, kwargs) + + +class _NodeReference: + def __init__(self, name): + self.name = name + + name: str + + +class _LinearNodeList: + def __init__(self, node_list): + self.serialize_node_list = [] + for node in node_list: + node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value] + node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value] + serialize_node = fx.Node( + graph=None, # type: ignore[arg-type] + name=node.name, + op=node.op, + target=node.target, + args=node_args, # type: ignore[arg-type] + kwargs=node_kwargs, # type: ignore[arg-type] + return_type=node.type, + ) + serialize_node.meta = copy.copy(node.meta) + self.serialize_node_list.append(serialize_node) + + def to_graph(self): + graph = fx.Graph() + + ref_str_to_node: Dict[str, fx.Node] = {} + + def ref_to_node(arg): + if isinstance(arg, _NodeReference): + return ref_str_to_node[arg.name] + else: + return arg + + for node in self.serialize_node_list: + node_args = map_aggregate(node.args, ref_to_node) + node_kwargs = map_aggregate(node.kwargs, ref_to_node) + deser_node = graph.create_node( + op=node.op, + target=node.target, + args=node_args, # type: ignore[arg-type] + kwargs=node_kwargs, # type: ignore[arg-type] + name=node.name, + type_expr=node.type, + ) + ref_str_to_node[node.name] = deser_node + + return graph + + +def _direct_serialization_deserialize(body, nodes): + """ + Custom `__reduce__` method for serialization. + DO AS I SAY -- NOT AS I DO. This violates the principle that + GraphModules serialize via code export & re-tracing. We allow + for this here because **PIPE STAGES SHOULD NOT BE PERSISTED + TO DISK -- THIS IS ONLY FOR TRANSMISSION VIA RPC**. Persisting + these instances to disk will expose internal implementation + details of `fx.Graph` and related data structures and is + NOT advised. 
+ """ + + class DummyModule(torch.nn.Module): + def __init__(self, body): + super().__init__() + self.__dict__.update(body) + + dummy = DummyModule(body) + + return fx.GraphModule(dummy, nodes.to_graph()) + + +def _direct_serialization_reduce(self): + serialization_dict = dict(self.__dict__) + serialization_dict.pop("_graph") + return ( + _direct_serialization_deserialize, + (serialization_dict, _LinearNodeList(self.graph.nodes)), + ) + + +def _modify_graph_op_device( + gm: torch.fx.GraphModule, + new_device: torch.device, +): + """ + Modify the device argument of all "call_function" nodes in the graph. This + is useful for moving the graph to a different device. In particular for + generator ops, like torch.ones. + """ + modified = False + for node in gm.graph.nodes: + if node.op == "call_function": + if "device" in node.kwargs and node.kwargs["device"] != new_device: + logger.debug( + f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" # noqa: G004 + ) + node.update_kwarg("device", new_device) + modified = True + elif node.op == "call_module": + # Recursively modify "device" in submodules + submod = gm.get_submodule(node.target) + if isinstance(submod, torch.fx.GraphModule): + _modify_graph_op_device(submod, new_device) + elif isinstance(submod, InterpreterModule): + # If unflattening has been performed, we need to access its graph module by `.graph_module` + _modify_graph_op_device(submod.graph_module, new_device) + else: + logger.warning( + f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" # noqa: G004 + ) + + if modified: + gm.recompile() + + +class Pipe(torch.nn.Module): + def __init__( + self, + split_gm: fx.GraphModule, + num_stages: int, + has_loss_and_backward: bool, + loss_spec, + ): + # TODO: is there a way not to hard wire init? + torch.nn.Module.__init__(self) + self.split_gm: fx.GraphModule = split_gm + self.executor: DetachExecutor = DetachExecutor(self.split_gm) + self.num_stages: int = num_stages + self.has_loss_and_backward = has_loss_and_backward + self.loss_spec = loss_spec + + for node in split_gm.graph.nodes: + assert ( + node.op in {"call_module", "placeholder", "output"} + or (node.op, node.target) == ("call_function", operator.getitem) + or (node.op, node.target) == ("call_method", "backward") + or (node.op, node.target) == ("call_function", stage_backward) + or (node.op, node.target) + == ("call_function", _null_coalesce_accumulate) + ), node + + # Detect replicated parameters so we know that we have to do an additional allreduce + # before applying the optimizer + # + # Note that this also handles the case where there were multiple calls to a single + # module from different stages, regardless of whether that module invocation + # was handled by the logic above. + + # Map parameter value to a dictionary that maps the user pipeline module + # to the local qualname within that module + params_to_users: Dict[torch.nn.Parameter, Dict[str, str]] = {} + + for m_qualname, mod in self.split_gm.named_children(): + for p_qualname, param in mod.named_parameters(): + params_to_users.setdefault(param, {}) + params_to_users[param][m_qualname] = p_qualname + + self.replicated_params: List[Dict[str, str]] = [ + use_mapping + for _, use_mapping in params_to_users.items() + if len(use_mapping) > 1 + ] + + # We must break the aliasing relationship between the replicated parameters for correct + # numerics in reference runs. 
+        # will have a reference to the same tensor value and will erroneously apply gradient
+        # updates multiple times. Therefore, for each replicated parameter set, we deepcopy the
+        # values so that we have separate instances.
+        for param_mapping in self.replicated_params:
+            for submod_name, param_qualname in param_mapping.items():
+                submod = getattr(self.split_gm, submod_name)
+                atoms = param_qualname.split(".")
+                for atom in atoms[:-1]:
+                    submod = getattr(submod, atom)
+                setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1])))
+
+        def throw(self, *args, **kwargs):
+            raise RuntimeError(
+                "To run pipeline locally, invoke the Pipe object directly, not `split_gm`"
+            )
+
+        self.split_gm.forward = throw
+
+        # Make submodules use custom direct-serialized GraphModule
+        i = 0
+        while True:
+            try:
+                name = f"submod_{i}"
+                submod = getattr(self.split_gm, name)
+                submod.__class__.__reduce__ = _direct_serialization_reduce
+                i += 1
+            except AttributeError:
+                break
+
+    def forward(self, *args, **kwargs):
+        executor_args = args
+        if len(kwargs) > 0:
+            parameters = []
+            for node in self.split_gm.graph.nodes:
+                if node.op == "placeholder":
+                    if node.args and len(node.args) > 0:
+                        parameters.append(
+                            Parameter(
+                                node.target,
+                                Parameter.POSITIONAL_OR_KEYWORD,
+                                default=node.args[0],
+                            )
+                        )
+                    else:
+                        parameter_kind = Parameter.POSITIONAL_OR_KEYWORD
+                        param_name = node.target
+                        if node.target.startswith("**"):
+                            parameter_kind = Parameter.VAR_KEYWORD  # type: ignore[assignment]
+                            param_name = param_name[2:]
+                        elif node.target.startswith("*"):
+                            parameter_kind = Parameter.VAR_POSITIONAL  # type: ignore[assignment]
+                            param_name = param_name[1:]
+                        parameters.append(Parameter(param_name, parameter_kind))
+            signature = Signature(parameters)
+            ba = signature.bind(*args, **kwargs)
+            ba.apply_defaults()
+            executor_args = ba.arguments.values()  # type: ignore[assignment]
+
+        res = self.executor.run(*executor_args)
+
+        return res
+
+    def get_stage_module(self, stage_idx: int) -> torch.nn.Module:
+        """
+        Return a stage module corresponding to `stage_idx` of the `pipe`.
+        """
+        if stage_idx < 0 or stage_idx >= self.num_stages:
+            raise ValueError(f"Invalid stage index {stage_idx}!")
+        return getattr(self.split_gm, f"submod_{stage_idx}")
+
+    @staticmethod
+    def _number_and_count_forward_stages(gm: fx.GraphModule):
+        num_stages = 0
+        found_idxs: Dict[int, None] = {}
+        for node in gm.graph.nodes:
+            if node.op == "call_module" and node.target.startswith("submod_"):
+                node.meta["stage_idx"] = int(node.target[len("submod_") :])
+                found_idxs.setdefault(node.meta["stage_idx"])
+                num_stages += 1
+
+        # this assert will fail if a split point is inserted before the first layer, which creates an empty first submodule
+        # Update: the following assert may fail against some torch versions >=
+        # 2.2.0, as:
+        # submod_0, submod_1, submod_2, ...
+        # may be named as
+        # submod_0, submod_2, submod_4, ...
+        # TODO: investigate
+        # assert all(i in found_idxs for i in range(num_stages))
+
+        return num_stages
+
+    @staticmethod
+    def _from_traced(
+        mod: torch.nn.Module,
+        exported_program: ExportedProgram,
+        multi_use_param_spec: Optional[MultiUseParamSpec] = None,
+        output_loss_value_spec=None,
+        split_policy: Optional[
+            Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
+        ] = None,
+    ):
+        """
+        Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate
+        which value in the output of `forward` is the loss value on which PiPPy should apply
+        backpropagation. For example, if your ``forward`` returns a tuple ``(loss, model_out)``,
+        you can specify ``output_loss_value_spec=(True, False)``. Or, if your ``forward`` returns
+        a dict ``{'loss': loss_value, 'model_out': model_out}``, you can specify
+        ``output_loss_value_spec={'loss': True, 'model_out': False}``
+        """
+
+        traced = exported_program.module()
+
+        if split_policy is not None:
+            logger.info("Auto-splitting model")
+            traced = split_policy(traced)  # type: ignore[arg-type]
+
+        logger.debug(traced.print_readable(print_output=False))
+
+        # Deduplicate `get_attr` nodes that refer to the same parameter. Downstream code for moving
+        # parameters relies on the invariant that parameter accesses happen once. This is not necessarily
+        # the case (especially with custom tracers), so fix that up here.
+        get_attr_nodes: Dict[str, fx.Node] = {}
+        for node in traced.graph.nodes:
+            if node.op == "get_attr":
+                get_attr_nodes.setdefault(node.target, node)
+
+                if get_attr_nodes[node.target] != node:
+                    node.replace_all_uses_with(get_attr_nodes[node.target])
+                    traced.graph.erase_node(node)
+
+        # avoid looking at next node by keeping track of previous pipe_split
+        prev_pipe_split_idx = -1
+        pipe_split_nodes_to_erase = set()
+        for i, node in enumerate(traced.graph.nodes):
+            if (node.op, node.target) == ("call_function", pipe_split):
+                if prev_pipe_split_idx == i - 1:
+                    pipe_split_nodes_to_erase.add(node)
+                prev_pipe_split_idx = i
+
+        for node in pipe_split_nodes_to_erase:
+            traced.graph.erase_node(node)
+
+        traced.recompile()
+
+        part_idx = 0
+
+        def split_callback(n: fx.Node):
+            nonlocal part_idx
+            if (n.op, n.target) == (
+                "call_function",
+                aten_pipe_split_alias,
+            ):
+                logger.debug(f"Found pipe_split {part_idx}")  # noqa: G004
+                part_idx += 1
+            return part_idx
+
+        # TODO: what does split do with module invocations? does it move the modules
+        # into the submodules?
+        split = split_module(traced, mod, split_callback)  # type: ignore[arg-type]
+        # a (custom) tracer can produce dead code like orphan get_attr nodes
+        split.graph.eliminate_dead_code()
+
+        # peephole to remove pipe_split
+        for submodule in split.modules():
+            if isinstance(submodule, fx.GraphModule):
+                for node in submodule.graph.nodes:
+                    if (node.op, node.target) == (
+                        "call_function",
+                        aten_pipe_split_alias,
+                    ):
+                        submodule.graph.erase_node(node)
+                submodule.recompile()
+
+        for name, submodule in split.named_children():
+            if isinstance(submodule, fx.GraphModule):
+                new_submod = _outline_submodules(submodule.graph)
+                # Replace old submod
+                split.register_module(name, new_submod)
+
+        # TODO: backport this into split_module
+        def delete_user_reference(node, user):
+            """
+            Delete reference of `node` from `user`'s arg list.
+            Args:
+                - node: a `get_attr` node at root.
+                - user: a submodule node that uses `node`.
+ """ + assert len(user.kwargs) == 0 + use_idxs = [i for i, arg in enumerate(user.args) if arg == node] + assert len(use_idxs) == 1 + args_copy = list(user.args) + args_copy.pop(use_idxs[0]) + user.args = tuple(args_copy) + logger.debug( + f"Deleted {node} from user {user}, arg index = {use_idxs[0]}" # noqa: G004 + ) + + # A list of param referrals for deferred deletion. + # To be accumulated in `move_param_to_callee`. + to_delete = [] + + def _recursive_getattr_with_parent(mod, fqn): + # Returns getattr call given a nested FQN, and the last parent + atoms = fqn.split(".") + for atom in atoms[:-1]: + if not hasattr(mod, atom): + return None, None + mod = getattr(mod, atom) + if not hasattr(mod, atoms[-1]): + return mod, None + attr = getattr(mod, atoms[-1]) + return mod, attr + + def move_param_to_callee( + root, + callee_name, + param_fqn, + ): + """ + Move a parameter from the root module to a submodule. + Args: + root: The root module. + callee_name: The name of the submodule to move the parameter to. + param_fqn: The fully qualified name of the parameter to move. + """ + # `atoms` is a list of strings representing the path to the + # parameter in the original model + atoms = param_fqn.split(".") + mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn) + # Check whether the parameter is a buffer or a parameter + is_buffer = atoms[-1] in mod_itr._buffers + + # Check whether the parameter is a tensor + assert isinstance(param_val, torch.Tensor), ( + f"Expected '{param_fqn}' to be {torch.Tensor} but got {type(param_val)}." + + ( + f" It might happen if module '{param_fqn}' was passed to some 'leaf function'" + f"(see https://pytorch.org/docs/stable/fx.html#fx.wrap). Please inspect " + f"usages of '{param_fqn}' in the traced graph." + if isinstance(param_val, torch.nn.Module) + else "" + ) + ) + + # Get submodule + callee = root.get_submodule(callee_name) + assert not hasattr( + callee, param_fqn + ), f"Module {callee_name} already has a parameter named {param_fqn}" + + # Assign the parameter to the submodule + if is_buffer: + _assign_attr( + param_val, + callee, + param_fqn, + attr_kind=_AttrKind.BUFFER, + persistent=True, # TODO: handle non-persistent buffer + ) + else: + _assign_attr( + param_val, + callee, + param_fqn, + attr_kind=_AttrKind.PARAMETER, + ) + logger.debug(f"Moved parameter {param_fqn} to {callee_name}") # noqa: G004 + + # Next step is to replace placeholder of submodule with a get_attr. + # Those placeholders are created by `split_module` inside each + # submodule. + # Update: this step is now moved to `_sink_params` because + # `_sink_params` can do it recursively (i.e. for modules inside + # submodule) + + to_delete.append((mod_itr, atoms[-1])) + + # Get the list of all parameters in the root module + attr_nodes = list(filter(lambda n: n.op == "get_attr", split.graph.nodes)) + for node in attr_nodes: + # Check whether the parameter is used in only one submodule + if len(node.users) > 1: + logger.info( + f"Parameter {node.target} used in multiple stages: {node.users}." 
# noqa: G004 + ) + for user in node.users: + assert user.op == "call_module" + # Move parameter into submodule + move_param_to_callee( + split, + user.target, + node.target, + ) + + # [aliasing] store tensor id -> list of FQNs, built from state dict + # Also assign non-persistent buffers + id_to_fqns: Dict[int, Set[str]] = defaultdict(set) + for fqn, tensor in mod.state_dict(keep_vars=True).items(): + id_to_fqns[id(tensor)].add(fqn) + for fqn, tensor in mod.named_buffers(): + id_to_fqns[id(tensor)].add(fqn) + + # After moving the params to their corresponding hierarchies, we also + # need to move the `get_attr` nodes from the root of the graph to those + # hierarchies. + # [aliasing] use id -> fqn mapping to list out all valid FQNs + inputs_to_state: Dict[str, List[str]] = {} + for attr in attr_nodes: + _, tensor = _recursive_getattr_with_parent(mod, attr.target) + fqns = list(id_to_fqns[id(tensor)]) + if fqns: + inputs_to_state[attr.name] = fqns + elif attr.target in exported_program.constants: # lifted constants + inputs_to_state[attr.name] = [attr.target] + + # [aliasing] for each submodule split, assign attributes on FQNs that may be used. + # We determine this based on whether or not the FQN attribute parent exists. + # i.e. if the last submodule exists, assign the attribute. + added_attributes: Dict[str, List[str]] = defaultdict(list) + for fqn, tensor in mod.state_dict(keep_vars=True).items(): + for name, submod in split.named_children(): + if isinstance(submod, fx.GraphModule): + parent, child = _recursive_getattr_with_parent(submod, fqn) + if ( + parent and child is None + ): # parent exists, attribute doesn't -> assign + added_attributes[name].append(fqn) + setattr(parent, fqn.split(".")[-1], tensor) + + # Deferral deletion: Remove the original attributes (to params) from the + # root GraphModule + for mod_itr, last_atom in to_delete: + try: + delattr(mod_itr, last_atom) + except AttributeError: + # This is expected if the parameter is used in multiple stages + pass + + # This is done by (1) `_sink_params` at each submodule; + for name, submod in split.named_children(): + if isinstance(submod, fx.GraphModule): + _sink_params(submod, inputs_to_state, []) + submod.graph.lint() + submod.recompile() + + # [aliasing] This step is not super necessary, but helps reduce parameter usage/memory. + # After _sink_params() routine has run, clean up unused attributes that we previously added. + # Determine this based on the get_attr nodes - if not used, remove it. + for name, attributes in added_attributes.items(): + submod = getattr(split, name) + unused_attributes = set(attributes) + # track used attributes in the submodule, running DFS on subgraph hierarchy + stack = [("", submod)] # (scope, submodule) + while stack: + scope, _mod = stack.pop() + if isinstance(_mod, (fx.GraphModule, InterpreterModule)): + for node in _mod.graph.nodes: + if node.op == "get_attr": + # get_attr might get access deeper level attribute + fqn = scope + "." + node.target if scope else node.target + if fqn in unused_attributes: # used, remove it + unused_attributes.remove(fqn) + for _name, _submod in _mod.named_children(): + stack.append((scope + "." 
+            # delete unused attributes
+            for attr in unused_attributes:
+                mod_itr, atoms = submod, attr.split(".")
+                for atom in atoms[:-1]:
+                    mod_itr = getattr(mod_itr, atom)
+                delattr(mod_itr, atoms[-1])
+
+        for node in attr_nodes:
+            # And (2): remove `get_attr` node from submod's arg list
+            for user in copy.copy(node.users):
+                assert user.op == "call_module"
+                delete_user_reference(node, user)
+            # And (3): remove the `get_attr` node from the root graph.
+            split.graph.erase_node(node)
+
+        split.delete_all_unused_submodules()
+        split.graph.lint()
+        split.recompile()
+
+        num_stages = Pipe._number_and_count_forward_stages(split)
+
+        has_loss_and_backward = False
+        generated_loss_spec = output_loss_value_spec
+
+        if output_loss_value_spec is not None:
+            loss_node, output_node, generated_loss_spec = _find_loss_output(
+                mod, split.graph, output_loss_value_spec
+            )
+            if loss_node is not None:
+                _insert_stage_symbolic_backward(
+                    split.graph,
+                    loss_node,
+                    output_node,
+                )
+                split.recompile()
+                has_loss_and_backward = True
+                logger.debug("Pipeline is in training mode, backward pass generated")
+            else:
+                raise RuntimeError(
+                    f"Did not find any loss value according to {output_loss_value_spec=}"
+                )
+        else:
+            logger.debug("Pipeline is in inference mode, backward pass not generated")
+
+        logger.debug("Full pipe model:\n" f"{split}")  # noqa: G004
+
+        return Pipe(
+            split,
+            num_stages,
+            has_loss_and_backward,
+            generated_loss_spec,
+        )
+
+    def print_readable(self):
+        """
+        Print the pipe in a human-readable format.
+        This will print both the root pipe and each stage module.
+        """
+        self.split_gm.print_readable()
+
+    @staticmethod
+    def _trace_with_export(
+        mod: torch.nn.Module,
+        example_args: Tuple[Any, ...],
+        example_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> ExportedProgram:
+        logger.info("Tracing model ...")
+        try:
+            ep = torch.export.export(
+                mod,
+                example_args,
+                example_kwargs,
+            )
+        except Exception as e:
+            raise RuntimeError(
+                "It seems that we cannot capture your model as a full graph. "
+                "Typical reasons include graph breaks, data/shape-dependent "
+                "control flow, or missing meta kernels for custom operators. "
+                "You can use our manual pipeline interfaces, or try to fix the "
+                "graph breaks, see https://pytorch.org/docs/stable/export.html"
+            ) from e
+
+        return ep
+
+    @staticmethod
+    def from_tracing(
+        mod: torch.nn.Module,
+        example_args: Tuple[Any, ...],
+        example_kwargs: Optional[Dict[str, Any]] = None,
+        split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
+    ):
+        # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across
+        # stages instead of TRANSMIT'ting it
+        multi_use_param_spec = MultiUseParameterConfig.REPLICATE
+
+        # Figure out which output is loss from output_chunk_spec
+        output_loss_value_spec: Any = None
+        # Deprecated
+        """
+        if output_chunk_spec is not None:
+            output_loss_value_spec = map_aggregate(
+                output_chunk_spec, lambda v: isinstance(v, _LossReducer)
+            )
+        """
+
+        # Trace with export
+        exported_program = Pipe._trace_with_export(
+            mod,
+            example_args,
+            example_kwargs,
+        )
+
+        pipe = Pipe._from_traced(
+            mod,
+            exported_program,
+            multi_use_param_spec,
+            output_loss_value_spec=output_loss_value_spec,
+            split_policy=split_policy,
+        )
+
+        # Users want the first pipeline stage to accept kwargs if the original
+        # program does. This is controlled by the `_codegen` field of the graph,
+        # so we make a copy here. Note: we only want the input spec and not the
+        # output spec, because the output spec is for the last stage. Maybe a
+        # TODO? Not sure yet.
+        split = pipe.split_gm
+        traced = exported_program.module()
+        submod0 = next(iter(split.children()))
+        submod0_sign = signature(submod0.forward)
+        model_sign = signature(traced.forward)
+        if len(model_sign.parameters) != len(submod0_sign.parameters):
+            # We don't change the signature of the first stage if it takes a
+            # different number of args than the original model
+            logger.info(
+                f"Original model takes {len(model_sign.parameters)} args but the "  # noqa: G004
+                f"first pipeline stage takes {len(submod0_sign.parameters)}. "
+                "Please provide args to respective pipeline stages."
+            )
+        else:
+            # Support kwargs for the first stage
+            submod0.graph._codegen = copy.deepcopy(traced.graph._codegen)
+            # `_replace` is actually not "private" or internal. based on this doc:
+            # To prevent conflicts with field names, the method and attribute names
+            # start with an underscore
+            submod0.graph._codegen.pytree_info = (
+                submod0.graph._codegen.pytree_info._replace(out_spec=None)
+            )
+            submod0.recompile()
+
+        return pipe
+
+    def __str__(self):
+        return self.split_gm.__str__()
+
+    def __repr__(self):
+        return self.split_gm.__repr__()
+
+    def info(self) -> PipeInfo:
+        """
+        Get information about the pipe.
+
+        Returns
+        -------
+        PipeInfo
+            A dataclass containing information about the pipe.
+        """
+        return PipeInfo(
+            graph=self.split_gm.graph,
+            num_stages=self.num_stages,
+            has_loss_and_backward=self.has_loss_and_backward,
+        )
+
+    def build_stage(
+        self,
+        stage_index: int,
+        device: torch.device,
+        group: Optional[ProcessGroup] = None,
+    ) -> _PipelineStage:
+        """
+        Create a `PipelineStage` given a stage index and distributed group.
+        The `PipelineStage` can run with `PipelineSchedule`s.
+        """
+        # Find stage module
+        stage_module = self.get_stage_module(stage_index)
+
+        # Move ops argument to device
+        # Today the PT2 tracer does not treat `x.device` as a symbolic device;
+        # instead, the device at tracing time gets burned into the generated
+        # code. Here we provide a workaround for users to manually modify the
+        # "device" kwarg of operations. Such operations may include:
+        # `torch.ones`, `torch.zeros`, `torch.rand`, etc.
+        if isinstance(stage_module, torch.fx.GraphModule):
+            _modify_graph_op_device(stage_module, device)
+        else:
+            logger.warning(
+                f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}"  # noqa: G004
+            )
+
+        # Detach pipe info
+        # Note: be careful what's included in `pipe_info`. We don't want to keep
+        # a reference to `Pipe` or `Pipe.split_gm`, which would stop Python from
+        # recycling them. When Python recycles them, other stage modules (which
+        # are irrelevant to the current rank) can be automatically freed.
+        pipe_info = self.info()
+        return _PipelineStage(stage_module, stage_index, pipe_info, device, group)
+
+
+class SplitPoint(Enum):
+    BEGINNING = 1
+    END = 2
+
+
+# For backward compatibility, we kept the PipeSplitWrapper class because `class
+# SplitPoint` used to be defined in this class.
+class PipeSplitWrapper:
+    # Create a class alias for BC
+    SplitPoint = SplitPoint
+
+
+def _split_before_forward(self, *args, **kwargs):
+    pipe_split()
+    return self._orig_forward(*args, **kwargs)
+
+
+def _split_after_forward(self, *args, **kwargs):
+    try:
+        return self._orig_forward(*args, **kwargs)
+    finally:
+        pipe_split()
+
+
+def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):
+    # TODO: make this implementation out-of-place?
+    for qualname, split_type in spec.items():
+        atoms = qualname.split(".")
+        predecessor_module = mod
+        for i, atom in enumerate(atoms[:-1]):
+            try:
+                predecessor_module = getattr(predecessor_module, atom)
+            except AttributeError as e:
+                raise AttributeError(
+                    f"Specified target {qualname} referenced "
+                    f'nonexistent module {".".join(atoms[: i + 1])}'
+                ) from e
+
+        mod_to_wrap = getattr(predecessor_module, atoms[-1])
+        mod_to_wrap._orig_forward = mod_to_wrap.forward
+        if split_type == SplitPoint.BEGINNING:
+            mod_to_wrap.forward = MethodType(_split_before_forward, mod_to_wrap)
+        elif split_type == SplitPoint.END:
+            mod_to_wrap.forward = MethodType(_split_after_forward, mod_to_wrap)
+        else:
+            raise ValueError("Unknown split point type.")
+
+
+def pipeline(
+    module: torch.nn.Module,
+    mb_args: Tuple[Any, ...],
+    mb_kwargs: Optional[Dict[str, Any]] = None,
+    split_spec: Optional[Dict[str, SplitPoint]] = None,
+    split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
+) -> Pipe:
+    """
+    Split a module based on a specification.
+
+    See `Pipe` for more details.
+
+    Arguments
+    ---------
+    module:
+        The module to be split.
+    mb_args:
+        Example positional inputs, in micro-batch form.
+    mb_kwargs:
+        Example keyword inputs, in micro-batch form. (default: `None`)
+    split_spec:
+        A dictionary using submodule names as split markers. (default: `None`)
+    split_policy:
+        The policy to use for splitting the module. (default: `None`)
+
+    Returns
+    -------
+    A pipeline representation of class `Pipe`.
+    """
+    if split_spec is not None and split_policy is not None:
+        raise ValueError(
+            "Cannot specify both `split_spec` and `split_policy`. Please use only one of them."
+        )
+
+    if split_spec is not None:
+        # Annotate split points in the module based on user spec
+        annotate_split_points(module, split_spec)
+        return Pipe.from_tracing(
+            mod=module,
+            example_args=mb_args,
+            example_kwargs=mb_kwargs,
+        )
+    else:
+        # Use split policy
+        return Pipe.from_tracing(
+            mod=module,
+            example_args=mb_args,
+            example_kwargs=mb_kwargs,
+            split_policy=split_policy,
+        )
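For orientation, here is a minimal sketch of how the user-facing API added in `_IR.py` above is meant to be driven: `pipe_split()` marks a stage boundary inside `forward`, and `pipeline()` export-traces and splits the module. This sketch is editorial, not part of the diff; `ToyModel` and its shapes are hypothetical, and it assumes a torch build that ships `torch.distributed.pipelining`.

import torch
from torch.distributed.pipelining import pipe_split, pipeline

class ToyModel(torch.nn.Module):  # hypothetical two-layer model
    def __init__(self):
        super().__init__()
        self.lin1 = torch.nn.Linear(16, 16)
        self.lin2 = torch.nn.Linear(16, 16)

    def forward(self, x):
        x = torch.relu(self.lin1(x))
        pipe_split()  # stage boundary; a no-op when the module runs eagerly
        return self.lin2(x)

x = torch.randn(4, 16)  # one example micro-batch
pipe = pipeline(ToyModel(), mb_args=(x,))
assert pipe.num_stages == 2  # one split point -> two stages
stage0 = pipe.get_stage_module(0)  # fx-split submodule for stage 0
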
diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__init__.py b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..476bf6a18a08713aa2d01c2cc8b0e9087269a7ef
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from ._IR import Pipe, pipe_split, pipeline, SplitPoint
+from .schedules import (
+    _ScheduleForwardOnly,
+    Schedule1F1B,
+    ScheduleFlexibleInterleaved1F1B,
+    ScheduleGPipe,
+    ScheduleInterleaved1F1B,
+    ScheduleInterleavedZeroBubble,
+    ScheduleLoopedBFS,
+)
+from .stage import build_stage, PipelineStage
+
+
+__all__ = [
+    "Pipe",
+    "pipe_split",
+    "SplitPoint",
+    "pipeline",
+    "PipelineStage",
+    "build_stage",
+    "Schedule1F1B",
+    "ScheduleFlexibleInterleaved1F1B",
+    "ScheduleGPipe",
+    "ScheduleInterleaved1F1B",
+    "ScheduleLoopedBFS",
+    "ScheduleInterleavedZeroBubble",
+]
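The `SplitPoint` re-exported above enables a non-invasive alternative to editing `forward`: pass a `split_spec` keyed by submodule name, and `annotate_split_points` wraps that submodule's `forward` with a `pipe_split()`. A hedged illustration, reusing the hypothetical `ToyModel` and `x` from the sketch above:

from torch.distributed.pipelining import SplitPoint, pipeline

pipe = pipeline(
    ToyModel(),
    mb_args=(x,),
    split_spec={"lin2": SplitPoint.BEGINNING},  # split just before lin2 runs
)
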
diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_backward.py b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4c5516e83da0175f9fb7cc1c836b562ac10c1fd
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_backward.py
@@ -0,0 +1,370 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import collections
+import logging
+import weakref
+from typing import Any, cast, Deque, Dict, Iterator, List, Optional, Set, Tuple, Union
+
+import torch
+from torch.autograd.graph import GradientEdge, Node
+from torch.nn import Parameter
+
+from ._debug import map_debug_info
+
+
+logger = logging.getLogger(__name__)
+
+
+def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Union[Node, None]:
+    """
+    Get the grad function or grad accumulator for a tensor.
+
+    Accumulate grad nodes are lazily created, so we need to create a
+    dummy view in order to trigger their creation.
+    """
+    if t.requires_grad and t.grad_fn is None:
+        # if no grad function (leaf tensors) we use view
+        viewed_t = t.view_as(t)
+        grad_fn = viewed_t.grad_fn
+        if grad_fn is not None:
+            return grad_fn.next_functions[0][0]
+        else:
+            raise RuntimeError(
+                "Attempted to get grad_fn, but got None. "
+                "Is this being created in a no-grad context?"
+            )
+    else:
+        return t.grad_fn
+
+
+def reverse_closure(
+    roots: List[Node], target_nodes: Set[Node]
+) -> Tuple[Set[Node], Set[Node]]:
+    """
+    This function returns the reverse closure of the given roots,
+    i.e. the set of nodes that can be reached from the roots by following the
+    reverse edges of the graph. The target_nodes are the nodes that we want to
+    include in the closure.
+    """
+    # Recurse until we reach a target node
+    closure: Set[Node] = set()
+    visited_target_nodes = set()
+    q: Deque[Node] = collections.deque()
+    for node in roots:
+        if node is not None and node not in closure:
+            closure.add(node)
+            q.append(node)
+    while q:
+        node = q.popleft()
+        metadata = cast(Dict[str, List], node.metadata)
+        reverse_edges = metadata.get("reverse_edges", [])
+        for holder_ref, idx in reverse_edges:
+            ref = holder_ref()
+            if ref is None:
+                # this reverse graph is no longer alive
+                # raise RuntimeError("Reverse graph is no longer alive")
+                continue
+            fn = ref.node
+            if fn in closure or fn is None:
+                continue
+            if fn in target_nodes:
+                visited_target_nodes.add(fn)
+                continue
+            closure.add(fn)
+            q.append(fn)
+    return closure, visited_target_nodes
+
+
+# Enable weak pointer
+class Holder:
+    def __init__(self, node: Node):
+        self.node = node
+
+
+def construct_reverse_graph(roots: List[Node]) -> List[Holder]:
+    q: Deque[Node] = collections.deque()
+    root_seen: Set[Node] = set()
+    reverse_graph_refs: List[Holder] = []
+    for node in roots:
+        if node is not None and node not in root_seen:
+            q.append(node)
+            root_seen.add(node)
+    while q:
+        node = q.popleft()
+        for fn, idx in node.next_functions:
+            if fn is not None:
+                # Don't necessarily need to store on the graph
+                metadata = cast(Dict[str, List], fn.metadata)
+                reverse_edges = metadata.get("reverse_edges", [])
+                if len(reverse_edges) == 0:
+                    q.append(fn)
+                holder = Holder(node)
+                holder_ref = weakref.ref(holder)
+                reverse_graph_refs.append(holder)
+                reverse_edges.append((holder_ref, idx))
+                metadata["reverse_edges"] = reverse_edges
+    return reverse_graph_refs
+
+
+def get_param_groups(inputs: List[Node], params: List[Node]) -> List[Dict[str, Any]]:
+    """
+    Given a list of inputs and a list of parameters, return a list of parameter
+    groups, where each group contains the parameters and the intermediates that
+    are connected to the parameters.
+ + The returned list of parameter groups is a list of dictionaries, where each + dictionary contains the following keys: + - "params": a set of parameters + - "intermediates": a set of intermediates + + The returned list of parameter groups is a list of dictionaries, + """ + # reverse graph that starts with inputs, and goes up to the dOutput or the loss, + # but omits weights and any subgraphs connecting weights to this closure + inputs_closure, _ = reverse_closure(inputs, set()) + param_groups: Dict[Node, Dict[str, Set]] = dict() # keyed on intermediates + for i, param in enumerate(params): + closure, intersected = reverse_closure([param], inputs_closure) + param_group: Dict[str, Set] = { + "params": {param}, + "intermediates": intersected, + } + for input_node in intersected: + existing = param_groups.get(input_node, None) + if existing is not None: + existing["params"] = existing["params"].union(param_group["params"]) + existing["intermediates"] = existing["intermediates"].union( + param_group["intermediates"] + ) + param_group = existing + else: + param_groups[input_node] = param_group + + # Sanity check: union of all param_groups params should be equal to all params + union_params: Set[Node] = set() + seen_ids: Set[int] = set() + unique_param_groups = [] + for param_group in param_groups.values(): + if id(param_group) not in seen_ids: + seen_ids.add(id(param_group)) + unique_param_groups.append(param_group) + union_params = union_params.union(param_group["params"]) + + # The assert will only be true if the input tensor requires gradients, + # otherwise the autograd graph will miss the first layer of inputs + # assert union_params == set(params) + return unique_param_groups + + +def stage_backward_input( + stage_outputs: List[torch.Tensor], + output_grads: Optional[List[torch.Tensor]], + input_values: List[torch.Tensor], + weights: Iterator[Parameter], +): + """ + compute the gradients for only the stage inputs with respect to the stage outputs + """ + stage_output_grad_fns: List[Node] = list( + filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs)) + ) + stage_input_grad_fns: List[Node] = list( + filter(None, map(_get_grad_fn_or_grad_acc, input_values)) + ) + weight_grad_fns: List[Node] = list( + filter(None, map(_get_grad_fn_or_grad_acc, weights)) + ) + + reverse_graph_refs = construct_reverse_graph(stage_output_grad_fns) + param_groups = get_param_groups(stage_input_grad_fns, weight_grad_fns) + del reverse_graph_refs + + for param_group in param_groups: + for i, intermediate in enumerate(param_group["intermediates"]): + + def get_hook(param_group, i): + def hook(grad_inputs): + if param_group.get("grads", None) is None: + param_group["grads"] = [None] * len( + param_group["intermediates"] + ) + param_group["grads"][i] = grad_inputs + + return hook + + # These are always "split" nodes that we need to recompute, so + # save their inputs. + intermediate.register_prehook(get_hook(param_group, i)) + + # Stage 0 inputs do not require grads? Should we skip in that case? 
+ if all(tensor.requires_grad for tensor in input_values): + if output_grads is None: + # In case this is the loss and there are no output_grads, then we just use 1s + output_grads = [ + torch.ones_like(stage_output) for stage_output in stage_outputs + ] + + dinputs = torch.autograd.grad( + stage_outputs, + inputs=input_values, + grad_outputs=output_grads, + retain_graph=True, + ) + + # update the gradients for inputs + for i, inp in enumerate(input_values): + if inp.grad is None: + inp.grad = dinputs[i] + else: + inp.grad += dinputs[i] + else: + dinputs = None + return dinputs, param_groups + + +def stage_backward_weight( + weights: Iterator[Parameter], param_groups: List[Dict[str, Any]] +): + # map weights to param_group_weights + grad_acc_to_weight = {} + weight_grads = [] + for index, weight in enumerate(weights): + grad_acc = _get_grad_fn_or_grad_acc(weight) + grad_acc_to_weight[grad_acc] = weight, index + weight_grads.append(weight.grad) + + for param_group in param_groups: + # TODO: Handle case where intermediate can have multiple outputs + intermediate_edges = tuple( + GradientEdge(i, 0) for i in param_group["intermediates"] + ) + weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"]) + + assert all(len(g) == 1 for g in param_group["grads"]) + # [NEW!] Able to pass a GradientEdge to autograd.grad as output + # We do not need to retain_graph because... guarantee no overlap? + # print("trying to execute: ", intermediate_edges, weights_edges) + dweights = torch.autograd.grad( + intermediate_edges, + weights_edges, + grad_outputs=sum(param_group["grads"], tuple()), + ) + for grad_acc, dw in zip(param_group["params"], dweights): + weight, index = grad_acc_to_weight[grad_acc] + if weight.grad is None: + weight.grad = dw + else: + weight.grad += dw + # return grads in the original order weights were provided in + return weight_grads + + +def stage_backward( + stage_output, + output_grads, + input_values, + outputs_with_grads_idxs: Optional[List[int]] = None, # deprecated, not used +): + """ + This is a helper function to: + 1. compute the gradients for the stage inputs, and + 2. accumulate gradients for the stage module's parameters. + + Given the input value(s) and the corresponding gradient for the output + value(s), compute and accumulate gradients for all parameter values (leaves + in the autograd trace) as well as return a list of the gradients for the + input values + """ + if outputs_with_grads_idxs is not None: + # Deprecated, not used in runtime calls, only exists in compiler + stage_output = [stage_output[i] for i in outputs_with_grads_idxs] + output_grads = [output_grads[i] for i in outputs_with_grads_idxs] + + try: + # stage_output may be a composite datatype like dict. 
Extract all individual + # tensor values here + stage_output_tensors = [] + output_grad_tensors = [] + + def extract_tensors_with_grads(output_val, grad_val): + if isinstance(output_val, torch.Tensor): + if not output_val.requires_grad and output_val.grad_fn is None: + return + assert isinstance( + grad_val, (torch.Tensor, type(None)) + ), f"Expected Tensor or None gradient but got {type(grad_val)}" + stage_output_tensors.append(output_val) + output_grad_tensors.append(grad_val) + elif isinstance(output_val, (tuple, list)): + if grad_val is None: + return + assert isinstance( + grad_val, (tuple, list) + ), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}" + assert len(output_val) == len(grad_val) + for ov, gv in zip(output_val, grad_val): + extract_tensors_with_grads(ov, gv) + elif isinstance(output_val, dict): + if grad_val is None: + return + assert isinstance(grad_val, dict) + assert set(output_val.keys()) == set(grad_val.keys()) + for k in output_val.keys(): + extract_tensors_with_grads(output_val[k], grad_val[k]) + else: + # Output is a non-tensor type; just ignore it + pass + + extract_tensors_with_grads(stage_output, output_grads) + + torch.autograd.backward( + stage_output_tensors, grad_tensors=output_grad_tensors # type: ignore[arg-type] + ) + + # Extract gradients wrt the input values + grad_inputs = [] + for val in input_values: + if isinstance(val, torch.Tensor): + grad_inputs.append(val.grad) + else: + grad_inputs.append(None) + + # Alternative impl: `torch.autograd.grad`. + # Note that `torch.autograd.grad` will not accumulate gradients into the + # model's parameters. + """ + inputs_with_grad = [] + for val in input_values: + if isinstance(val, torch.Tensor) and val.requires_grad: + inputs_with_grad.append(val) + + grad_inputs = torch.autograd.grad( + stage_output_tensors, inputs_with_grad, output_grad_tensors, # type: ignore[arg-type] + ) + """ + + except Exception as e: + exc_msg = f""" + Failed to run stage backward: + Stage output: {map_debug_info(stage_output)} + Output gradient: {map_debug_info(output_grads)} + Input: {map_debug_info(input_values)} + """ + raise RuntimeError(exc_msg) from e + + return grad_inputs + + +# TODO: handling requires_grad=False dynamically. Can we analyze this during initial +# IR emission? +def _null_coalesce_accumulate(lhs, rhs): + """ + Coalesce two values, even if one of them is null, returning the non-null + value. + """ + if lhs is None: + return rhs + elif rhs is None: + return lhs + else: + return torch.add(lhs, rhs) diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_debug.py b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_debug.py new file mode 100644 index 0000000000000000000000000000000000000000..6b153ec78d8902bb2187ffc25eb902593a30db52 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_debug.py @@ -0,0 +1,21 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import torch + + +def friendly_debug_info(v): + """ + Helper function to print out debug info in a friendly way. + """ + if isinstance(v, torch.Tensor): + return f"Tensor({v.shape}, grad={v.requires_grad}, dtype={v.dtype})" + else: + return str(v) + + +def map_debug_info(a): + """ + Helper function to apply `friendly_debug_info` to items in `a`. + `a` may be a list, tuple, or dict. 
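+
+    Illustrative example (a sketch; the printed value mirrors what
+    `friendly_debug_info` produces for a float32 tensor, not verified output):
+    >>> # xdoctest: +SKIP
+    >>> map_debug_info({"x": torch.zeros(2, 3), "n": 5})
+    {'x': 'Tensor(torch.Size([2, 3]), grad=False, dtype=torch.float32)', 'n': '5'}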
+ """ + return torch.fx.node.map_aggregate(a, friendly_debug_info) diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_unflatten.py b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_unflatten.py new file mode 100644 index 0000000000000000000000000000000000000000..659c9804a96691516d2458e1d6fac3cee2d5ccb4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_unflatten.py @@ -0,0 +1,27 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from typing import Dict + +import torch +from torch.export.unflatten import _ModuleFrame + + +def _outline_submodules(orig_graph: torch.fx.Graph): + # Create an empty GraphModule to hold the outlined modules + new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + seen_nodes: Dict[str, torch.fx.Node] = {} + seen_modules: Dict[int, torch.nn.Module] = {} + _ModuleFrame( + orig_graph, + tuple(orig_graph.nodes), + seen_nodes, + seen_modules, + None, + [""], + "", + {}, + module=new_module, + ).run_outer() + new_module.graph.lint() + new_module.recompile() + return new_module diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_utils.py b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf7097795868684c5e51b55e345ddbc66fe76325 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_utils.py @@ -0,0 +1,99 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +from dataclasses import dataclass +from typing import List, Tuple, Union + +import torch +from torch import fx + + +logger = logging.getLogger(__name__) + + +def flatten_args_detach(args): + """ + Flatten the args into a list form and detach the tensors from computational graph. + """ + flat_detached_args = [] + + def extract_tensor_args(a): + nonlocal flat_detached_args + if isinstance(a, torch.Tensor): + val = a.detach().requires_grad_(a.requires_grad) + flat_detached_args.append(val) + return val + else: + flat_detached_args.append(a) + return a + + new_args = fx.node.map_aggregate( + args, + extract_tensor_args, + ) + + return new_args, flat_detached_args + + +def flatten_args(args): + """ + Flatten the args into a list form. 
+ """ + flat_args = [] + + def extract_tensor_args(a): + nonlocal flat_args + flat_args.append(a) + return a + + fx.node.map_aggregate( + args, + extract_tensor_args, + ) + + return flat_args + + +class PipeliningShapeError(RuntimeError): + """Shape mismatch between configured and runtime values.""" + + +def validate_tensor_metadata(desc, expected, given): + if not expected.shape == given.shape: + raise PipeliningShapeError( + f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}" + ) + if not expected.dtype == given.dtype: + raise PipeliningShapeError( + f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}" + ) + if not expected.stride() == given.stride(): + raise PipeliningShapeError( + f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}" + ) + + +def validate_tensors_metadata( + desc, + expected_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]], + actual_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]], +): + if len(expected_tensors) != len(actual_tensors): + raise PipeliningShapeError( + f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})" + ) + for i in range(len(expected_tensors)): + validate_tensor_metadata( + f"{desc}: value {i}", expected_tensors[i], actual_tensors[i] + ) + + +@dataclass +class PipeInfo: + """ + Captures information for a pipeline (`Pipe` object). + """ + + graph: fx.Graph + num_stages: int + has_loss_and_backward: bool diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/microbatch.py b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/microbatch.py new file mode 100644 index 0000000000000000000000000000000000000000..2c276b7d6a557cf671fc0a3fc253383a121cbe17 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/microbatch.py @@ -0,0 +1,469 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch.fx.node import map_aggregate +from torch.utils._pytree import tree_flatten, tree_unflatten + + +__all__ = [ + "TensorChunkSpec", + "split_args_kwargs_into_chunks", + "merge_chunks", +] + +logger = logging.getLogger(__name__) + +""" +_debug_mask_minibatches specifies to send masked versions of the mini-batch +through instead of micro-batch slices--this can be used for more stable +numerical testing (see [A Note About Correctness Testing]) +""" +_debug_mask_minibatches = False + + +class _CustomReducer: + """ + Custom reducer class that can be used to specify a custom operation that + reduces losses of multiple microbatches into one value. + + Example: + >>> # xdoctest: +SKIP + >>> sum_reducer = _CustomReducer( + >>> torch.tensor(0.0), + >>> lambda a, b: a + b + >>> ) + """ + + def __init__(self, init_value, reduce_fn): + self.init_value = init_value + self.reduce_fn = reduce_fn + + +class _LossReducer(_CustomReducer): + pass + + +sum_reducer = _LossReducer(torch.tensor(0.0), lambda a, b: a + b) + +# Default chunking dimension is 0. This is used for the case where the user did +# not specify a chunking dimension. 
+DEFAULT_CHUNK_DIM = 0
+
+
+class TensorChunkSpec:
+    """
+    Class used to specify chunking of inputs
+    """
+
+    def __init__(self, split_dim):
+        self.split_dim = split_dim
+
+    split_dim: int
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})"
+        )
+
+    def __str__(self):
+        return f"TensorChunkSpec({self.split_dim})"
+
+    @staticmethod
+    def from_tuple(
+        chunk_dims: Tuple[int, ...],
+    ):
+        """
+        A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk
+        dimensions (ints).
+        Example:
+        >>> # xdoctest: +SKIP
+        >>> # There are three positional arguments to the model, and
+        >>> # we are chunking them along dimension 0, 0 and 1, respectively
+        >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
+        """
+        args_chunk_spec = map_aggregate(
+            chunk_dims,
+            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
+        )
+        return args_chunk_spec
+
+    @staticmethod
+    def from_dict(
+        chunk_dims: Dict[str, int],
+    ):
+        """
+        A helper for creating a dictionary of `TensorChunkSpec` from a
+        dictionary of chunk dimensions (ints).
+        Example:
+        >>> # xdoctest: +SKIP
+        >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
+        >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
+        """
+        kwargs_chunk_spec = map_aggregate(
+            chunk_dims,
+            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
+        )
+        return kwargs_chunk_spec
+
+
+# Class used to specify replication of inputs
+class _Replicate:
+    pass
+
+
+def _shard_dict_of_args(
+    args_dict,
+    args_chunk_spec,
+    num_chunks,
+):
+    """
+    Given a dictionary of args, and a dictionary of chunking specs, shard the
+    args according to the chunking specs.
+
+    Args:
+        args_dict: Dictionary of args
+        args_chunk_spec: Dictionary of chunking specs
+        num_chunks: Number of chunks to shard the args into
+
+    Returns:
+        args_split: List of sharded args
+    """
+    # Stage 1+2: flatten and shard/replicate
+
+    # args_sharded_replicated : [num args, num flat values, num chunks]
+    args_sharded_replicated = {}
+    arg_specs = []
+
+    real_num_chunks = num_chunks
+    first_tensor = True
+
+    assert len(args_dict) == len(
+        args_chunk_spec
+    ), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
+
+    for arg_key, arg in args_dict.items():
+        flat, spec = tree_flatten(arg)
+        arg_specs.append(spec)
+
+        chunk_spec = args_chunk_spec[arg_key]
+        assert chunk_spec is not None  # Should have been set by caller
+        chunk_spec_flat, _ = tree_flatten(chunk_spec)
+        if len(flat) != len(chunk_spec_flat):
+            raise ValueError(
+                f"Argument value {arg} did not have the same number of "
+                f"values as the chunk spec {chunk_spec}"
+            )
+
+        sharded_arg_flat = []
+
+        for v, chunk_v in zip(flat, chunk_spec_flat):
+            if chunk_v is _Replicate or not isinstance(v, torch.Tensor):
+                sharded_arg_flat.append([v] * real_num_chunks)
+            elif isinstance(chunk_v, TensorChunkSpec):
+                # TODO: check type of v. If it's a tensor, use chunk (or debug mask).
+                # If it's a collection type, split it as you would expect.
+                # Otherwise, throw an error
+                assert isinstance(v, torch.Tensor), f"{v} is not a tensor"
+
+                v_split_dim_size = v.size(chunk_v.split_dim)
+                if v_split_dim_size < real_num_chunks:
+                    if first_tensor:
+                        # We can only adjust the number of chunks when we hit this
+                        # issue at the first tensor encountered
+                        logger.warning(
+                            f"Tensor size on chunking dimension is {v_split_dim_size}, "  # noqa: G004
+                            f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}."
+                        )
+                        real_num_chunks = v_split_dim_size
+                    else:
+                        raise RuntimeError(
+                            f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, "
+                            f"smaller than the number of chunks {num_chunks}. "
+                            "PiPPy cannot reduce the number of chunks because "
+                            "other arguments have bigger chunk-dimension sizes. "
+                            "Please adjust your num_chunks setting."
+                        )
+
+                chunk_tensors = torch.tensor_split(
+                    v, real_num_chunks, chunk_v.split_dim
+                )
+
+                if _debug_mask_minibatches:
+                    expanded_chunks = []
+
+                    split_dim_idx = 0
+                    for chunk_tensor in chunk_tensors:
+                        new_val = torch.zeros_like(v)
+                        upper_idx = split_dim_idx + chunk_tensor.size(chunk_v.split_dim)
+
+                        slice_indices = [slice(None, None, None)] * new_val.ndim
+                        slice_indices[chunk_v.split_dim] = slice(
+                            split_dim_idx, upper_idx
+                        )
+                        new_val[slice_indices] = chunk_tensor
+
+                        expanded_chunks.append(new_val)
+
+                        split_dim_idx += chunk_tensor.size(chunk_v.split_dim)
+
+                    sharded_arg_flat.append(expanded_chunks)
+                else:
+                    sharded_arg_flat.append(chunk_tensors)  # type: ignore[arg-type]
+
+                first_tensor = False
+            else:
+                raise TypeError(f"Unrecognized chunk spec: {chunk_v}")
+
+        args_sharded_replicated[arg_key] = sharded_arg_flat
+
+    # chunks_flat : [num chunks, num args, num flat values]
+    chunks_flat = []
+    for chunk_idx in range(real_num_chunks):
+        chunk_args = {}
+        for key, arg in args_sharded_replicated.items():
+            arg_single_chunk = []
+            for v_flat in arg:
+                arg_single_chunk.append(v_flat[chunk_idx])
+            chunk_args[key] = arg_single_chunk
+        chunks_flat.append(chunk_args)
+
+    # args_split : [num chunks, num args]
+    args_split = []
+
+    for chunk in chunks_flat:
+        per_chunk_args = {}
+        assert len(arg_specs) == len(chunk)
+        for (key, arg), arg_spec in zip(chunk.items(), arg_specs):
+            per_chunk_args[key] = tree_unflatten(arg, arg_spec)
+        args_split.append(per_chunk_args)
+
+    return args_split
+
+
+def split_args_kwargs_into_chunks(
+    args: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]],
+    chunks: int,
+    args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
+    kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
+) -> Tuple[List[Tuple], List[Dict]]:
+    """
+    Given a sequence of args and kwargs, split them into a number of chunks
+    according to their respective chunking specs.
+
+    Args:
+        args: Tuple of args
+        kwargs: Dict of kwargs
+        chunks: Number of chunks to split the args and kwargs into
+        args_chunk_spec: chunking specs for args, in same shape as args
+        kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs
+
+    Returns:
+        args_split: List of sharded args
+        kwargs_split: List of sharded kwargs
+    """
+    # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that
+    # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec`
+    # and `kwargs_chunk_spec` specifications. The steps are as follows:
+    #
+    # 1. Use pytree.tree_flatten to flatten each arg and its spec into a 1d array of values.
+    #    To use a running example: suppose our inputs look like
+    #
+    #    args = ([A, [B, C]], D)    args_spec = ([None, [None, TensorChunkSpec]], None)
+    #    (kwargs not shown but it's a similar process)
+    #
+    #    Then for this step we would end up with
+    #
+    #    args = ([A, B, C], D)    args_spec = ([None, None, TensorChunkSpec], None)
+    #
+    # 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2
+    #
+    #    args = ([[A, A], [B, B], [C_1, C_2]], [D, D])
+    #
+    # 3. Rotate the nesting order such that chunks are the outer dimension
+    #
+    #    args_chunks = [
+    #        ([A, B, C_1], D),
+    #        ([A, B, C_2], D),
+    #    ]
+    #
+    # 4. Unflatten each chunk according to the spec
+    #
+    #    args_chunks = [
+    #        ([A, [B, C_1]], D),
+    #        ([A, [B, C_2]], D),
+    #    ]
+
+    # TODO: _debug_mask_minibatches
+    # Handle the case where kwargs is None
+    if kwargs is None:
+        kwargs = {}
+
+    # If user did not provide args_chunk_spec or kwargs_chunk_spec, we construct
+    # default specs of matching shape that chunk along dim 0
+    if args_chunk_spec is None:
+        args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)
+
+    if kwargs_chunk_spec is None:
+        kwargs_chunk_spec = dict.fromkeys(kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM))
+
+    args_split_dict = _shard_dict_of_args(
+        dict(enumerate(args)),
+        dict(enumerate(args_chunk_spec)),
+        chunks,
+    )
+    real_num_chunks = len(args_split_dict)
+
+    kwargs_split = _shard_dict_of_args(
+        kwargs,
+        kwargs_chunk_spec,
+        real_num_chunks,
+    )
+
+    if len(kwargs_split) < real_num_chunks:
+        # In case kwargs are sharded into fewer chunks,
+        # e.g. when `args` has no tensor, just values
+        real_num_chunks = len(kwargs_split)
+        # Re-shard args
+        args_split_dict = _shard_dict_of_args(
+            dict(enumerate(args)),
+            dict(enumerate(args_chunk_spec)),
+            real_num_chunks,
+        )
+
+    if len(args_split_dict) != len(kwargs_split):
+        raise RuntimeError(
+            "args and kwargs are split into different number of chunks: "
+            f"{len(args_split_dict)}, {len(kwargs_split)}"
+        )
+
+    args_split = []
+    for chunk_args in args_split_dict:
+        args_split.append(tuple(chunk_args[i] for i in range(len(chunk_args))))
+
+    return args_split, kwargs_split
+
+
+def merge_chunks(
+    chunks: List[Any],
+    chunk_spec,
+):
+    """
+    Given a list of chunks, merge them into a single value according to
+    the chunk spec.
+
+    Args:
+        chunks: list of chunks
+        chunk_spec: Chunking spec for the chunks
+
+    Returns:
+        value: Merged value
+    """
+    # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
+    # steps are similar to the steps in that function but in reverse. Given the
+    # input values:
+    #
+    # chunks = [
+    #     ([A, [B, C_1]], D),
+    #     ([A, [B, C_2]], D),
+    # ]
+    # args_spec = ([None, [None, TensorChunkSpec]], None)
+    #
+    # 1. Flatten the chunks according to the chunk_spec
+    #
+    #    chunks_flat = [
+    #        ([A, B, C_1], D),
+    #        ([A, B, C_2], D),
+    #    ]
+    #
+    # 2. Rotate the nesting order such that chunks are the inner dimension
+    #
+    #    value_inner = ([A, B, [C_1, C_2]], D)
+    #
+    # 3. Concatenate sharded arguments
+    #
+    #    value_combined = ([A, B, C], D)
+    #
+    # 4.
Unflatten the combined args given the spec + # + # value = ([A, [B, C]], D) + + # Preliminary: flatten the chunk spec + if chunk_spec is not None: + spec_flattened, flatten_spec = tree_flatten(chunk_spec) + else: + # If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields + # We obtain the output structure by flattening chunk 0 and generate the chunk_spec + chunk0_flat, flatten_spec = tree_flatten(chunks[0]) + spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat) + + # Stage 1: flatten chunks + # chunks_flattened : [num chunks, num args] + chunks_flattened = [] + + for chunk in chunks: + chunk_flattened, _ = tree_flatten(chunk) + if len(chunk_flattened) != len(spec_flattened): + raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}") + + chunks_flattened.append(chunk_flattened) + + # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and + # concatenate sharded operands + # args_flattened : [num args] + args_flattened = [] + for arg_idx, arg in enumerate(spec_flattened): + if isinstance(arg, TensorChunkSpec): + partial_values = [ + chunks_flattened[chunk_idx][arg_idx] + for chunk_idx in range(len(chunks_flattened)) + ] + + if _debug_mask_minibatches: + # Infer size of individual chunks by running `tensor_split` again + overall_shape = partial_values[0].shape + for val in partial_values[1:]: + assert val.shape == overall_shape + meta_chunks = torch.tensor_split( + torch.empty(*overall_shape, device="meta"), + sections=len(partial_values), + dim=arg.split_dim, + ) + + values_to_cat = [] + chunk_start_idx = 0 + assert len(partial_values) == len(meta_chunks) + for partial_value, meta_chunk in zip(partial_values, meta_chunks): + chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim) + + slice_indices = [slice(None, None, None)] * partial_value.ndim + slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx) + sliced = partial_value[slice_indices] + values_to_cat.append(sliced) + + chunk_start_idx = chunk_end_idx + + else: + values_to_cat = partial_values + + args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim)) + elif isinstance(arg, _CustomReducer): + reduced_val = arg.init_value + + for chunk_idx in range(len(chunks_flattened)): + reduced_val = arg.reduce_fn( + reduced_val, chunks_flattened[chunk_idx][arg_idx] + ) + + args_flattened.append(reduced_val) + else: + value = chunks_flattened[0][arg_idx] + for chunk_idx in range(1, len(chunks_flattened)): + assert chunks_flattened[chunk_idx][arg_idx] == value + args_flattened.append(value) + + # Stage 4: Unflatten combined args + return tree_unflatten(args_flattened, flatten_spec) diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/schedules.py b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/schedules.py new file mode 100644 index 0000000000000000000000000000000000000000..cd02e0e9042ce4fce06b853f485dbb3e0726c921 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/schedules.py @@ -0,0 +1,2162 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. 
and affiliates
+
+import csv
+import itertools
+import logging
+import re
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from enum import Enum
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Set,
+    Tuple,
+    TYPE_CHECKING,
+    Union,
+)
+
+import torch
+import torch.distributed as dist
+from torch.distributed._composable.fsdp.fully_shard import FSDPModule, UnshardHandle
+from torch.profiler import record_function
+
+from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
+from .stage import _PipelineStageBase
+
+
+if TYPE_CHECKING:
+    from torch.distributed import Work
+
+__all__ = [
+    "get_schedule_class",
+    "PipelineScheduleSingle",
+    "PipelineScheduleMulti",
+    "Schedule1F1B",
+    "ScheduleFlexibleInterleaved1F1B",
+    "ScheduleGPipe",
+    "ScheduleInterleaved1F1B",
+    "ScheduleLoopedBFS",
+    "ScheduleInterleavedZeroBubble",
+]
+
+logger = logging.getLogger(__name__)
+
+
+class _ComputationType(Enum):
+    # TODO(whc) rename to _ActType?
+    FORWARD = 1
+    BACKWARD = 2
+    WEIGHT = 3
+    UNSHARD = 4
+    RESHARD = 5
+    SEND_F = 6
+    RECV_F = 7
+    SEND_B = 8
+    RECV_B = 9
+
+    def __str__(self):
+        str_map = {
+            _ComputationType.FORWARD: "F",
+            _ComputationType.BACKWARD: "B",
+            _ComputationType.WEIGHT: "W",
+            _ComputationType.UNSHARD: "UNSHARD",
+            _ComputationType.RESHARD: "RESHARD",
+            _ComputationType.SEND_F: "SEND_F",
+            _ComputationType.RECV_F: "RECV_F",
+            _ComputationType.SEND_B: "SEND_B",
+            _ComputationType.RECV_B: "RECV_B",
+        }
+        return str_map[self]
+
+    @staticmethod
+    def from_str(action):
+        if action == "F":
+            return _ComputationType.FORWARD
+        elif action == "B":
+            return _ComputationType.BACKWARD
+        elif action == "W":
+            return _ComputationType.WEIGHT
+        elif action == "UNSHARD":
+            return _ComputationType.UNSHARD
+        elif action == "RESHARD":
+            return _ComputationType.RESHARD
+        elif action == "SEND_F":
+            return _ComputationType.SEND_F
+        elif action == "RECV_F":
+            return _ComputationType.RECV_F
+        elif action == "SEND_B":
+            return _ComputationType.SEND_B
+        elif action == "RECV_B":
+            return _ComputationType.RECV_B
+        else:
+            raise RuntimeError(f"Invalid computation type {action}")
+
+
+FORWARD = _ComputationType.FORWARD
+BACKWARD = _ComputationType.BACKWARD
+WEIGHT = _ComputationType.WEIGHT
+UNSHARD = _ComputationType.UNSHARD
+RESHARD = _ComputationType.RESHARD
+SEND_F = _ComputationType.SEND_F
+RECV_F = _ComputationType.RECV_F
+SEND_B = _ComputationType.SEND_B
+RECV_B = _ComputationType.RECV_B
+
+# Convenience shorthand for compute actions only since they are used in 'simple schedule format'
+F = FORWARD
+B = BACKWARD
+W = WEIGHT
+
+# Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index)
+_action_regex = re.compile(
+    r"(\d+)(F|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)"
+)
+
+
+class _Action(NamedTuple):
+    stage_index: int
+    computation_type: _ComputationType
+    microbatch_index: Optional[int] = None
+
+    def __repr__(self):
+        repr = str(self.stage_index)
+        repr += str(self.computation_type)
+        if self.microbatch_index is not None:
+            repr += str(self.microbatch_index)
+        return repr
+
+    @staticmethod
+    def from_str(str):
+        """
+        Reverse of __repr__
+
+        String should be formatted as [stage][action type][(microbatch)]
+        e.g.
`2F0`, `1UNSHARD`, `3SEND_F1` + """ + if match := _action_regex.match(str): + stage_index, computation_type, microbatch_index = match.groups() + return _Action( + int(stage_index), + _ComputationType.from_str(computation_type), + int(microbatch_index) if len(microbatch_index) else None, + ) + elif str == "" or str.isspace(): + return None + raise RuntimeError( + f"Invalid action string: {str}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0" + ) + + +def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -> str: + """ + Formats the pipeline order in a timestep (row) x rank (column) grid of actions + and returns the formatted string + """ + # Calculate the maximum number of steps across all ranks + num_steps = max(len(actions) for actions in pipeline_order.values()) + step_labels = [ + "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps) + ] + # Sorting the dictionary by keys and retrieving values in that order + rank_actions = [ + pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order) + ] + # Transpose the list of lists (rows to columns) + transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue="")) + # Generate column labels for ranks + num_ranks = len(pipeline_order) + rank_labels = ["Rank " + str(i) for i in range(num_ranks)] + # Calculate the maximum length of each column, considering labels + max_lengths = [ + max(len(str(item)) if item is not None else 0 for item in col) + for col in zip(step_labels, *transposed_actions) + ] + # Format the header row with rank labels + header_row = " " * (len(step_labels[0]) + 2) + " ".join( + f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels) + ) + # Format each row with its corresponding label + formatted_rows = [ + f"{label}: " + + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row)) + for label, row in zip(step_labels, transposed_actions) + ] + # Join the rows into a single string + formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n" + return formatted_table + + +def _validate_pipeline_order( + pipeline_order: Dict[int, List[Optional[_Action]]], + num_microbatches: int, + num_stages: int, + enable_zero_bubble: bool = False, +): + """ + pipeline_order[rank] = [(computation_type, microbatch_index, stage_index), ...] + Validating that the pipeline order follows the rules: + 1. Forward action for a microbatch must be before the Backward action for that microbatch + 2. Recv for a microbatch must be before the send for that microbatch + 3. Microbatch index is handled in sequential order for each stage + 4. A later stage cannot operate on a microbatch before any of the previous stages have operated on it + 5. 
Same microbatch cannot be handled in the same time step across ranks + """ + # microbatch_index: (current computation type, current stage) + microbatch_process_info: Dict[int, Tuple[_ComputationType, int]] = {} + max_timestep = max(len(rank_list) for rank_list in pipeline_order.values()) + for timestep in range(max_timestep): + error_msg: List[str] = [] + current_timestep_actions = [] + for rank in range(len(pipeline_order)): + action = ( + pipeline_order[rank][timestep] + if timestep < len(pipeline_order[rank]) + else None + ) + + if action is not None: + computation_type = action.computation_type + if computation_type != _ComputationType.WEIGHT: + current_timestep_actions.append(action) + + # TODO: enable this + # if len(current_timestep_actions) == 0: + # error_msg.append( + # "All actions were None, there is an unnecessary gap in the schedule" + # ) + + # Ensure that no microbatch is operated on twice in current_timestep_actions + unique_microbatch_indices = { + action.microbatch_index for action in current_timestep_actions + } + if len(unique_microbatch_indices) != len(current_timestep_actions): + error_msg.append( + "Duplicate microbatch index found in current_timestep_actions" + ) + + for action in current_timestep_actions: + stage_index = action.stage_index + computation_type = action.computation_type + mb_index = action.microbatch_index + assert ( + mb_index is not None + ), "All currently supported action types require valid microbatch_index" + if mb_index >= num_microbatches: + error_msg.append(f"Microbatch index {mb_index} out of range") + + # first microbatch + if mb_index not in microbatch_process_info: + if computation_type != _ComputationType.FORWARD or stage_index != 0: + error_msg.append(f"Incorrect start for microbatch {mb_index}") + microbatch_process_info[mb_index] = (computation_type, stage_index) + else: + # if the microbatch is included, check that the current stage is right after prev + prev_computation, prev_stage = microbatch_process_info[mb_index] + + if prev_computation == _ComputationType.FORWARD: + if prev_stage == num_stages - 1: + expected_stage = num_stages - 1 + expected_computation = _ComputationType.BACKWARD + else: + expected_stage = prev_stage + 1 + expected_computation = _ComputationType.FORWARD + elif prev_computation == _ComputationType.BACKWARD: + if prev_stage == 0: + error_msg.append( + f"[{mb_index=}] already finished backward computation" + ) + break + else: + expected_stage = prev_stage - 1 + expected_computation = _ComputationType.BACKWARD + else: + raise ValueError( + f"Computation type {prev_computation} not supported" + ) + + if expected_computation is not None: + if expected_computation != computation_type: + error_msg.append( + f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}" + ) + + if expected_stage != stage_index: + error_msg.append( + f"[{mb_index=}] {expected_stage=} VS. 
actual {stage_index=}" + ) + + microbatch_process_info[mb_index] = ( + expected_computation, + expected_stage, + ) + + if not enable_zero_bubble: + if len(error_msg) != 0: + raise RuntimeError( + f"Error at timestep {timestep}: " + ",".join(error_msg) + ) + return + + for rank in range(len(pipeline_order)): + backward_steps: Set[Tuple[int, int]] = set() + weight_steps: Set[Tuple[int, int]] = set() + + for action in pipeline_order[rank]: + if action is None: + continue + + stage_index = action.stage_index + computation_type = action.computation_type + mb_index = action.microbatch_index + if computation_type == _ComputationType.BACKWARD: + if mb_index is not None: + backward_steps.add((mb_index, stage_index)) + elif computation_type == _ComputationType.WEIGHT: + if (mb_index, stage_index) not in backward_steps: + error_msg.append( + f"{mb_index=}, {stage_index=} Weight happened before bwd" + ) + if (mb_index, stage_index) in weight_steps: + error_msg.append( + f"{mb_index=}, {stage_index=} Duplicated weight step" + ) + if mb_index is not None: + weight_steps.add((mb_index, stage_index)) + + if len(backward_steps) != len(weight_steps): + error_msg.append("Length weight steps != Length bwd steps") + + if len(error_msg) != 0: + raise RuntimeError(f"Error at timestep {timestep}: " + ",".join(error_msg)) + + +class _PipelineSchedule(ABC): + def __init__( + self, + n_microbatches: int, + loss_fn: Optional[Callable[..., torch.Tensor]] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + # From arguments + self._n_microbatches = n_microbatches + self._loss_fn = loss_fn + # Chunking specification for positional inputs. (default: `None`) + self._args_chunk_spec = args_chunk_spec + # Chunking specification for keyword inputs. (default: `None`) + self._kwargs_chunk_spec = kwargs_chunk_spec + self._output_merge_spec = output_merge_spec + """ + # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs. + # They are used to convert batch to microbatches in `step(x)`. See + # `TensorChunkSpec` for helper methods for creating them. + """ + + # Derived + self._has_backward = self._loss_fn is not None + + # Holds the losses for each microbatch. + self._internal_losses: List[torch.Tensor] = [] + logger.info("Using %s", self.__class__.__name__) + + def _maybe_compute_loss(self, stage, output, target_mbs, mb_index): + if stage.is_last and self._has_backward: + loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index] + self._internal_losses.append(loss) + + def _maybe_get_loss(self, stage, mb_index): + valid_index = 0 <= mb_index < len(self._internal_losses) + if stage.is_last and self._has_backward and valid_index: + return self._internal_losses[mb_index] + elif len(self._internal_losses) != 0 and not valid_index: + raise RuntimeError( + f"Loss for microbatch {mb_index} is not available. 
" + f"Available losses for microbatches: {self._internal_losses}" + ) + else: + return None + + def _update_losses(self, stages, losses): + """ + Update the losses to those in the internal state + """ + # if stages not a list turn into a list + if not isinstance(stages, list): + stages = [stages] + contains_last_stage = any(stage.is_last for stage in stages) + + # Return losses if there is a container passed in + if contains_last_stage and losses is not None: + if len(self._internal_losses) != self._n_microbatches: + raise RuntimeError( + f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}" + ) + + # Clean external container first + losses.clear() + # Copy internal losses to external container + losses.extend(self._internal_losses) + + self._internal_losses.clear() + + @abstractmethod + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the schedule + implementation. + + Args: + microbatches: list of microbatch args. + """ + raise NotImplementedError + + @abstractmethod + def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + """ + raise NotImplementedError + + def _check_inputs( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Pre-process/check inputs + """ + + def check_type_and_len(mbs, name: str): + if not isinstance(mbs, list): + raise TypeError(f"{name} must be a list but got a {type(mbs)}") + if len(mbs) != self._n_microbatches: + raise ValueError( + f"Expecting {self._n_microbatches} {name} but got {len(mbs)}" + ) + + if arg_mbs is not None: + check_type_and_len(arg_mbs, "arg_mbs") + else: + arg_mbs = [()] * self._n_microbatches + + if kwarg_mbs is not None: + check_type_and_len(kwarg_mbs, "kwarg_mbs") + else: + kwarg_mbs = [{}] * self._n_microbatches + + if target_mbs is not None: + check_type_and_len(target_mbs, "target_mbs") + + if losses is not None: + if not isinstance(losses, list): + raise TypeError(f"losses must be a list but got a {type(losses)}") + + return arg_mbs, kwarg_mbs + + def _compute_loss(self, output, target): + return self._loss_fn(output, target) # type: ignore[misc] + + def _split_inputs( + self, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Splits a full-batch input into chunks (i.e. microbatches) and returns + the chunks + """ + if args or kwargs: + args_split, kwargs_split = split_args_kwargs_into_chunks( + args, + kwargs, + self._n_microbatches, + self._args_chunk_spec, + self._kwargs_chunk_spec, + ) + return args_split, kwargs_split + else: + # Empty inputs (e.g. 
when called on middle stages) + # Return a list of empty tuples/dicts with matching length as chunks + return [()] * self._n_microbatches, [{}] * self._n_microbatches + + def _merge_outputs(self, output_chunks: List[Any]) -> Any: + """ + Merge output chunks back to a batch state. + If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim). + """ + return merge_chunks( + output_chunks, + self._output_merge_spec, + ) + + +def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None): + """ + Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top. + """ + if len(p2p_ops) == 0: + return None + desc_str = f"{desc}, " if desc else "" + logger.debug("batch_p2p %s%s", desc_str, p2p_ops) + return dist.batch_isend_irecv(p2p_ops).pop() + + +def _sorted_batch_p2p( + p2p_ops: List[dist.P2POp], desc: Optional[str] = None +) -> Dict[int, dist.Work]: + """ + Sorts the list of P2P ops by the peer rank, and then calls + batch_isend_irecv. Return a dictionary of works by peer rank. This function + helps us avoid hangs in case of skip connections. + """ + # Arrange p2p_ops by peer rank: + # int is the peer rank; + # List is the list of ops towards the peer + ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list) + work_by_peer: Dict[int, dist.Work] = {} + if len(p2p_ops) == 0: + return work_by_peer + + # Classify the ops by peer rank + for op in p2p_ops: + ops_by_peer[op.peer].append(op) + + # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs) + for peer, ops in sorted(ops_by_peer.items()): + work_by_peer[peer] = _batch_p2p(ops, desc=desc) + + return work_by_peer + + +class PipelineScheduleSingle(_PipelineSchedule): + """ + Base class for single-stage schedules. + Implements the `step` method. + Derived classes should implement `_step_microbatches`. + """ + + def __init__( + self, + stage: _PipelineStageBase, + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + # Init parent + super().__init__( + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + ) + # Self attributes + self._stage = stage + self._num_stages = stage.num_stages + # Set the same has_backward flag for stage object + self._stage.has_backward = self._has_backward + + # TODO: later replace this with lazy shape inference during forward + # Prepare forward send/recv infrastructure for stage + stage._prepare_forward_infra(n_microbatches) + if self._has_backward: + stage._prepare_backward_infra(n_microbatches) + + def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. 
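+
+        Illustrative usage (a sketch; assumes `stage`, `loss_fn`, `x`, and `y`
+        were built elsewhere, e.g. via `PipelineStage`):
+        >>> # xdoctest: +SKIP
+        >>> schedule = ScheduleGPipe(stage, n_microbatches=4, loss_fn=loss_fn)
+        >>> losses = []
+        >>> out = schedule.step(x, target=y, losses=losses)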
+ """ + + # Clean per iteration + self._stage.clear_runtime_states() + + # Split inputs into microbatches + args_split, kwargs_split = self._split_inputs(args, kwargs) + + # Split target into microbatches + if target is not None: + targets_split = list(torch.tensor_split(target, self._n_microbatches)) + else: + targets_split = None + + # Run microbatches + self._step_microbatches(args_split, kwargs_split, targets_split, losses) + + # Return merged results per original format + if self._stage.is_last: + return self._merge_outputs(self._stage.output_chunks) + else: + return None + + +class _ScheduleForwardOnly(PipelineScheduleSingle): + """ + The forward-only schedule. + Will go through all the microbatches and perform only the forward pass + """ + + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Run one iteration of the pipeline schedule + """ + if target_mbs is not None or losses is not None: + raise RuntimeError( + "Forward-only schedule does not support loss computation" + ) + + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + # Delay send waits + fwd_sends_to_wait: List[dist.Work] = [] + + # Run microbatches + for i in range(self._n_microbatches): + with record_function(f"Forward {i}"): + ops = self._stage.get_fwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_recv") + for work in works.values(): + work.wait() + + self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] + + ops = self._stage.get_fwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_send") + fwd_sends_to_wait.extend(works.values()) + + logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i) + + # Wait for all forward sends to finish + # This should not have performance impact because by the time the first + # backward arrives all the forward sends should have been finished. + for work in fwd_sends_to_wait: + work.wait() + + +class ScheduleGPipe(PipelineScheduleSingle): + """ + The GPipe schedule. + Will go through all the microbatches in a fill-drain manner. + """ + + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the GPipe schedule. + + Args: + microbatches: list of microbatch args. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + # Delay send waits + fwd_sends_to_wait: List[dist.Work] = [] + + # Run microbatches + for i in range(self._n_microbatches): + with record_function(f"Forward {i}"): + ops = self._stage.get_fwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_recv") + for work in works.values(): + work.wait() + + output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] + + ops = self._stage.get_fwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_send") + fwd_sends_to_wait.extend(works.values()) + + logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i) + + self._maybe_compute_loss(self._stage, output, target_mbs, i) + + # Wait for all forward sends to finish + # This should not have performance impact because by the time the first + # backward arrives all the forward sends should have been finished. 
+        for work in fwd_sends_to_wait:
+            work.wait()
+
+        # No loss function, no need to run backward
+        if not self._has_backward:
+            return
+
+        # Run backward
+        # Delay send waits
+        bwd_sends_to_wait: List[dist.Work] = []
+        for i in range(self._n_microbatches):
+            with record_function(f"Backward {i}"):
+                ops = self._stage.get_bwd_recv_ops(i)
+                works = _sorted_batch_p2p(ops, desc="bwd_recv")
+                for work in works.values():
+                    work.wait()
+
+                loss = self._maybe_get_loss(self._stage, i)
+                self._stage.backward_one_chunk(i, loss=loss)
+
+                ops = self._stage.get_bwd_send_ops(i)
+                works = _sorted_batch_p2p(ops, desc="bwd_send")
+                bwd_sends_to_wait.extend(works.values())
+
+            logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i)
+
+        # Return losses if there is a container passed in
+        self._update_losses(self._stage, losses)
+
+        # Wait for all backward sends to finish
+        for work in bwd_sends_to_wait:
+            work.wait()
+
+
+class Schedule1F1B(PipelineScheduleSingle):
+    """
+    The 1F1B schedule.
+    Will perform one forward and one backward on the microbatches in steady state.
+    """
+
+    def _step_microbatches(
+        self,
+        arg_mbs: Optional[List] = None,
+        kwarg_mbs: Optional[List] = None,
+        target_mbs: Optional[List] = None,
+        losses: Optional[List] = None,
+    ):
+        """
+        Run one iteration of the pipeline schedule with list of microbatches.
+        Will go through all the microbatches according to the 1F1B schedule.
+
+        Args:
+            microbatches: list of microbatch args.
+        """
+        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
+
+        # Last stage has 1 warmup, second-to-last 2 warmups, ...
+        # first stage `num_stages` warmups
+        warmup_chunks = min(
+            self._n_microbatches,
+            self._num_stages - self._stage.stage_index,
+        )
+
+        # Chunk counters
+        fwd_mb_index = 0
+        bwd_mb_index = 0
+        weight_stage_mb_index = 0
+
+        # Warmup phase
+        send_work = None
+        fwd_sends = []
+        for _ in range(warmup_chunks):
+            # Receive activations
+            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
+            if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"):
+                recv_work.wait()
+
+            # Compute
+            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]
+
+            # Wait on the previous chunk's forward sends (ideally they have
+            # finished by now; otherwise we are heavily communication-bound,
+            # and eagerly computing the next chunk would not help much either)
+            if send_work:
+                send_work.wait()
+
+            # Send activations
+            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
+            if fwd_mb_index != warmup_chunks - 1:
+                # Safe to fire
+                send_work = _batch_p2p(fwd_sends, desc="fwd_send")
+            # otherwise:
+            #   The last forward send is left to be fused with the first 1B in the 1B1F phase below
+
+            # Compute loss
+            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
+            fwd_mb_index += 1
+
+        # Now we should have send ops left over, to be fused with the first 1B of the 1B1F phase below.
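+
+        # Schematic for one rank (time flows right; w = warmup_chunks):
+        #   warmup:   F0 F1 ... F(w-1)
+        #   1B1F:     B0 F(w) B1 F(w+1) ...
+        #   cooldown: remaining backwards
+        # Each fwd_send below is batched with the matching bwd_recv (and each
+        # bwd_send with the next fwd_recv) in a single batch_isend_irecv call.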
+
+        # 1B1F phase
+        while True:  # the loop exits via the break below
+            # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops
+            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
+
+            # Now, we need to fire the fwd_sends and bwd_recvs together
+            if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"):
+                fuse_work.wait()
+
+            # Backward one chunk
+            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
+            self._stage.backward_one_chunk(bwd_mb_index, loss=loss)
+
+            # Get the bwd send ops, but don't fire, to be fused with the 1F below
+            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
+            bwd_mb_index += 1
+
+            if fwd_mb_index == self._n_microbatches:
+                # We are done with 1B1F, so break with some left-over bwd_sends
+                break
+
+            # We prepare the 1F of the `1B1F`
+            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
+
+            # Fuse it with the bwd_sends above
+            if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"):
+                fuse_work.wait()
+
+            # Now do the fwd
+            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]
+
+            # Compute loss
+            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
+
+            # Get the fwd send ops, but don't fire, leave them for the next iter (wrap-around)
+            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
+            fwd_mb_index += 1
+
+        # Fire the bwd_sends left over from the break above
+        send_work = _batch_p2p(bwd_sends, desc="bwd_send")
+
+        # Cooldown
+        while bwd_mb_index < self._n_microbatches:
+            # prepare bwd recv ops
+            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
+            if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"):
+                recv_work.wait()
+
+            # Backward one chunk
+            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
+            self._stage.backward_one_chunk(bwd_mb_index, loss=loss)
+
+            # Wait on the previous chunk's backward sends (ideally already finished)
+            if send_work:
+                send_work.wait()
+
+            # Get the bwd send ops, fire them
+            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
+            send_work = _batch_p2p(bwd_sends, desc="bwd_send")
+            bwd_mb_index += 1
+
+        # Wait for the last backward send to finish
+        if send_work:
+            send_work.wait()
+
+        # Return losses if there is a container passed in
+        self._update_losses(self._stage, losses)
+
+
+def _add_unshard_reshard(
+    compute_actions: List[Optional[_Action]],
+    max_active_stages: int = 3,
+) -> List[_Action]:
+    """Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP.
+
+    UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation.
+    RESHARD does the opposite, releasing memory (but doing no communication).
+
+    We abandon the "timestep lock" during this lowering step.
+
+    max_active_stages controls how many stages may be unsharded (prefetched) at once. It should arguably be
+    measured in memory rather than stage count, and be tunable, but in practice 3 stages is a reasonable
+    default (one stage running forward, one running backward, and one prefetching).
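+
+    Illustrative sketch (uses the `_Action` shorthand; hand-traced, not a
+    verified doctest):
+    >>> # xdoctest: +SKIP
+    >>> acts = [_Action.from_str(s) for s in ["0F0", "0F1", "1F0"]]
+    >>> _add_unshard_reshard(acts, max_active_stages=2)
+    [0UNSHARD, 1UNSHARD, 0F0, 0F1, 0RESHARD, 1F0]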
+ """ + + def next_stage_indices( + count: int, next_actions: List[Optional[_Action]] + ) -> List[int]: + """Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute.""" + seen: Set[int] = set() + ret: List[int] = [] + + for a in next_actions: + if a is not None and a.stage_index not in seen: + seen.add(a.stage_index) + ret.append(a.stage_index) + if len(ret) == count: + break + return ret + + active_stages: Set[int] = set() + fsdp_aware_actions: List[_Action] = [] + + def _unshard(stage_index: int): + active_stages.add(stage_index) + fsdp_aware_actions.append(_Action(stage_index, UNSHARD, None)) + + def _reshard(stage_index: int): + active_stages.remove(stage_index) + fsdp_aware_actions.append(_Action(stage_index, RESHARD, None)) + + for i, action in enumerate(compute_actions): + if action is None: + continue + + # We prefetch the next N stages we'll see, dropping existing stages to make room + next_n = next_stage_indices(max_active_stages, compute_actions[i:]) + # Fetch needs to be ordered correctly, so don't use a set + fetch = list(filter(lambda s: s not in active_stages, next_n)) + # Unclear what the best policy is for eviction, but we can maintain order so we do + evict = list(filter(lambda s: s not in next_n, active_stages)) + + # logger.debug( + # "_add_unshard_reshard Step %d active: %s fetch %s, evict %s", + # i, + # active_stages, + # fetch, + # evict, + # ) + + for stage in evict: + _reshard(stage) + for stage in fetch: + _unshard(stage) + fsdp_aware_actions.append(action) + + return fsdp_aware_actions + + +def _add_send_recv( + compute_actions: Dict[int, List[_Action]], + stage_to_rank: Callable[[int], int], + num_stages: int, +) -> Dict[int, List[_Action]]: + comm_actions: Dict[int, List[_Action]] = {rank: [] for rank in compute_actions} + + def _has_comms(action: _Action) -> bool: + if action.computation_type == F: + return action.stage_index != num_stages - 1 + elif action.computation_type == B: + return action.stage_index != 0 + return False + + def _get_comms(action: _Action) -> Tuple[_Action, _Action]: + assert _has_comms(action), f"{action} is not a valid comm action" + stage_idx = action.stage_index + ctype = action.computation_type + mb_idx = action.microbatch_index + send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx) + recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1 + recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx) + return send, recv + + def _ready_to_schedule( + action: Optional[_Action], prev_actions: List[_Action] + ) -> bool: + """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place. + This helps ensure a sane (non-hanging) ordering of sends and recvs. + But it also means we might not be able to schedule our next compute action yet. 
+ """ + if action is None: + return True + elif action.computation_type == F and not action.stage_index == 0: + expected_recv = _Action( + action.stage_index, + RECV_F if action.computation_type == F else RECV_B, + action.microbatch_index, + ) + return expected_recv in prev_actions + elif action.computation_type == B and not action.stage_index == num_stages - 1: + expected_recv = _Action( + action.stage_index, + RECV_F if action.computation_type == F else RECV_B, + action.microbatch_index, + ) + return expected_recv in prev_actions + else: + return True + + while compute_actions: + progress = False + # go in order of ranks even if dict keys aren't ordered + for rank in range(len(compute_actions)): + assert len(compute_actions[rank]) > 0 + action = compute_actions[rank][0] + + if not _ready_to_schedule(action, comm_actions[rank]): + continue + + if action is not None: + comm_actions[rank].append(action) + if _has_comms(action): + send, recv = _get_comms(action) + # TODO we can avoid send/recv if the 2 stages are on the same rank. + # should we avoid that in the runtime or here? + comm_actions[rank].append(send) + comm_actions[stage_to_rank(recv.stage_index)].append(recv) + + compute_actions[rank].pop(0) + if len(compute_actions[rank]) == 0: + del compute_actions[rank] + progress = True + assert progress, "Malformed compute schedule, can't schedule sends/recvs" + return comm_actions + + +class PipelineScheduleMulti(_PipelineSchedule): + """ + Base class for multi-stage schedules. + Implements the `step` method. + """ + + def __init__( + self, + stages: List[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + stage_index_to_group_rank: Optional[Dict[int, int]] = None, + use_full_backward: bool = True, + ): + if len(stages) <= 1: + raise ValueError( + f"Multi-stage schedule expects at least two stages but got {len(stages)}" + ) + # Init parent + super().__init__( + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + ) + # Self attributes + self._stages = stages + self._num_stages = stages[0].num_stages + self.pp_group_size = stages[0].group_size + self.rank = stages[0].group_rank + # Set the pipeline stage states + if stage_index_to_group_rank is not None: + for stage in self._stages: + stage.stage_index_to_group_rank = stage_index_to_group_rank + self.stage_index_to_group_rank = stages[0].stage_index_to_group_rank + + # Set the same has_backward flag for stage object + for stage in self._stages: + stage.has_backward = self._has_backward + + self._should_compute_loss = ( + lambda stage: stage.is_last and self._loss_fn is not None + ) + + # This will be set during init of derived schedules + self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + self.use_full_backward = use_full_backward + + # TODO: later replace this with lazy shape inference during forward + # Prepare forward send/recv infrastructure for stage + for stage in self._stages: + stage._prepare_forward_infra(n_microbatches) + if self._has_backward: + stage._prepare_backward_infra(n_microbatches) + + def _dump_csv(self, filename): + """Dump a CSV representation of the schedule into a file with the provided filename.""" + with open(filename, "w", newline="") as csvfile: + writer = 
csv.writer(csvfile) + for rank in self.pipeline_order: + writer.writerow(self.pipeline_order[rank]) + + def _validate_schedule(self): + # TODO(whc) this should be merged with the logic in test_schedule.py#L453-L554 + def _validate_rank_actions( + actions: Dict[int, List[_Action | None]], + num_stages: int, + num_microbatches: int, + ): + # We will count all the actions per stage and ensure they happen in a valid order + # (e.g. F before B before W for a given microbatch) + stage_actions: Dict[int, Dict[_ComputationType, Set]] = { + stage_id: { + F: set(), + B: set(), + W: set(), + } + for stage_id in range(num_stages) + } + for rank in actions: + for action in actions[rank]: + if action is None: + continue + assert isinstance( + action, _Action + ), f"Got an invalid action: {action}, expected instance of _Action" + s_id = action.stage_index + ctype = action.computation_type + mb_id = action.microbatch_index + if ctype == F: + stage_actions[s_id][F].add(mb_id) + elif ctype == B: + assert ( + mb_id in stage_actions[s_id][F] + ), f"Running Backward for stage {s_id}, microbatch {mb_id} without first running Forward" + stage_actions[s_id][B].add(mb_id) + elif ctype == W: + assert ( + not self.use_full_backward + ), "Schedule contains 'W' actions, but is configured to use full backward" + assert ( + mb_id in stage_actions[s_id][B] + ), f"Running Weight for stage {s_id}, microbatch {mb_id} without first running Backward" + stage_actions[s_id][W].add(mb_id) + + for s_id in stage_actions: + for ctype in (F, B, W): + stage_mb = len(stage_actions[s_id][ctype]) + assert ( + stage_mb == num_microbatches + ), f"Got {stage_mb} {ctype} microbatches for stage {s_id}, expected {num_microbatches}" + + assert ( + len(self.pipeline_order) == self.pp_group_size + ), f"Schedule has incorrect number of ranks - expected {self.pp_group_size}, actual {len(self.pipeline_order)}" + for rank in range(self.pp_group_size): + assert ( + rank in self.pipeline_order + ), f"Schedule is missing actions for rank {rank}" + _validate_rank_actions( + self.pipeline_order, + self._num_stages, + self._n_microbatches, + ) + + def _load_csv(self, filename, format="compute_only"): + """Load a CSV representation of the schedule from a file with the provided filename. + This API will most likely get renamed/refactored so is marked as internal for now. + + format must be "compute_only" for PipelineScheduleMulti + """ + assert format == "compute_only" + with open(filename, newline="") as csvfile: + reader = csv.reader(csvfile) + for rank, row in enumerate(reader): + self.pipeline_order[rank] = [_Action.from_str(s) for s in row] + self._validate_schedule() + + def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. 
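+ + Example (illustrative; assumes a constructed multi-stage schedule and whole-batch tensors x and y): + losses: List[torch.Tensor] = [] + output = schedule.step(x, target=y, losses=losses)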
+ """ + + # Clean per iteration + for stage in self._stages: + stage.clear_runtime_states() + + # Split inputs into microbatches + args_split, kwargs_split = self._split_inputs(args, kwargs) + + # Split target into microbatches + if target is not None: + targets_split = list(torch.tensor_split(target, self._n_microbatches)) + else: + targets_split = None + + # Run microbatches + self._step_microbatches(args_split, kwargs_split, targets_split, losses) + + # Return merged results per original format + for stage in self._stages: + if stage.is_last: + return self._merge_outputs(stage.output_chunks) + # Does not contain the last stage + return None + + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Operate on the microbatches for looped schedules (multiple stages on each rank). + + TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does + not support models with skip connections. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + # Based on the plan in Step 1 created in __init__: + # 2. Perform communication based on the pipeline_order + stage_index_to_stage: Dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in self._stages + } + + # determine prev_rank and next_rank based on which ranks are next to + # the stages in the pipeline_order + all_prev_ranks: Set[int] = set() + all_next_ranks: Set[int] = set() + for stage_index in stage_index_to_stage.keys(): + # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections) + if stage_index > 0: + all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1]) + if stage_index < self._num_stages - 1: + all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1]) + + for time_step, action in enumerate(self.pipeline_order[self.rank]): + try: + ops: List[dist.P2POp] = [] + if action is not None: + computation_type = action.computation_type + mb_index = action.microbatch_index + stage_index = action.stage_index + assert ( + mb_index is not None + ), "All currently supported action types require valid microbatch_index" + if computation_type == _ComputationType.FORWARD: + # perform forward computation + stage = stage_index_to_stage[stage_index] + output = stage.forward_one_chunk( + mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index] + ) + self._maybe_compute_loss(stage, output, target_mbs, mb_index) + ops.extend(stage.get_fwd_send_ops(mb_index)) + elif computation_type == _ComputationType.BACKWARD: + # perform backward computation + stage = stage_index_to_stage[stage_index] + loss = self._maybe_get_loss(stage, mb_index) + stage.backward_one_chunk( + mb_index, loss=loss, full_backward=self.use_full_backward + ) + ops.extend(stage.get_bwd_send_ops(mb_index)) + elif computation_type == _ComputationType.WEIGHT: + # perform weight update + if self.use_full_backward: + raise ValueError( + f"We detected a weight update in the pipeline schedule, but \ + {self.use_full_backward=}" + ) + stage = stage_index_to_stage[stage_index] + stage.backward_weight_one_chunk(mb_index) + else: + raise ValueError(f"Unknown computation type {computation_type}") + + # Look at the neighboring ranks for this current timestep and determine whether + # this current rank needs to do any recv communication + for prev_rank in all_prev_ranks: + prev_rank_ops = self.pipeline_order[prev_rank] + prev_rank_action = None + if time_step < 
len(prev_rank_ops): + prev_rank_action = prev_rank_ops[time_step] + if prev_rank_action is not None: + computation_type = prev_rank_action.computation_type + mb_index = prev_rank_action.microbatch_index + stage_index = prev_rank_action.stage_index + assert ( + mb_index is not None + ), "All currently supported action types require valid microbatch_index" + # Only handle recvs for forwards sent from a previous rank + if computation_type == _ComputationType.FORWARD: + # If not the last stage, then receive fwd activations + if stage_index + 1 in stage_index_to_stage: + # TODO: We are assuming that stage will always receive from stage-1 + # however that is not necessarily true of get_fwd_recv_ops + stage = stage_index_to_stage[stage_index + 1] + ops.extend(stage.get_fwd_recv_ops(mb_index)) + elif ( + computation_type == _ComputationType.BACKWARD + or computation_type == _ComputationType.WEIGHT + ): + # A previous rank doing backward or weight update has no influence on the current rank's forward recv + pass + else: + raise ValueError( + f"Unknown computation type {computation_type}" + ) + for next_rank in all_next_ranks: + next_rank_ops = self.pipeline_order[next_rank] + next_rank_action = None + if time_step < len(next_rank_ops): + next_rank_action = next_rank_ops[time_step] + if next_rank_action is not None: + computation_type = next_rank_action.computation_type + mb_index = next_rank_action.microbatch_index + stage_index = next_rank_action.stage_index + assert ( + mb_index is not None + ), "All currently supported action types require valid microbatch_index" + # Only handle recvs for backwards sent from a next rank + if ( + computation_type == _ComputationType.FORWARD + or computation_type == _ComputationType.WEIGHT + ): + # A next rank doing forward or weight update has no influence on the current rank's backward recv + pass + elif computation_type == _ComputationType.BACKWARD: + # If not the first stage, then receive bwd gradients + if stage_index - 1 in stage_index_to_stage: + # TODO: We are assuming that stage will always receive from stage+1 + # however that is not necessarily true of get_bwd_recv_ops + stage = stage_index_to_stage[stage_index - 1] + ops.extend(stage.get_bwd_recv_ops(mb_index)) + else: + raise ValueError( + f"Unknown computation type {computation_type}" + ) + + # do the communication + if ops: + _batch_p2p(ops).wait() + except Exception as e: + logger.error( + "[Rank %s] pipeline schedule %s caught the following exception \ + at time_step %s when running action %s", + self.rank, + self.__class__.__name__, + time_step, + action, + ) + logger.error("%s", _format_pipeline_order(self.pipeline_order)) + raise e + # Return losses if there is a container passed in + self._update_losses(self._stages, losses) + + +class _PipelineScheduleRuntime(PipelineScheduleMulti): + """ + Provides a simple runtime that requires a 'schedule IR' including specified communication operations. + + Can be instantiated directly by creating a _PipelineScheduleRuntime and calling load_csv, or it can be + subclassed, with the subclass responsible for creating a schedule IR. + """ + + def _load_actions( + self, + actions: Dict[int, List[Optional[_Action]]], + format: str = "compute_only", + ): + """ + Given an in-memory representation for a simple compute-only schedule, lower it to a complex schedule including + communication actions.
Stores the schedule in self, and must be called before running `_step_microbatches()`. + """ + assert ( + self.stage_index_to_group_rank is not None + ), "stage_index_to_group_rank is required for PipelineScheduleRuntime" + self.pipeline_order_with_comms: Dict[int, List[_Action]] = {} + if format == "compute_comms": + for rank in actions: + self.pipeline_order_with_comms[rank] = [] + for action in actions[rank]: + assert action is not None + self.pipeline_order_with_comms[rank].append(action) + # TODO what level of validation should we offer for compute+comms schedule? + elif format == "compute_only": + # Perform schedule lowering + for rank in actions: + self.pipeline_order_with_comms[rank] = _add_unshard_reshard( + actions[rank] + ) + + self.pipeline_order_with_comms = _add_send_recv( + self.pipeline_order_with_comms, + stage_to_rank=lambda s: self.stage_index_to_group_rank[s], + num_stages=self._num_stages, + ) + else: + raise NotImplementedError(f"{format=} is not implemented") + + def _load_csv(self, filename: str, format: str = "compute_only"): + """Loads a csv in simple format and then lowers it to include communication actions + + format must be either "compute_only" or "compute_comms". If compute_only, the lowering passes + will automatically be run to generate a compute_comms schedule. + """ + if format == "compute_only": + # this will populate self.pipeline_order + super()._load_csv(filename) + # this will populate self.pipeline_order_with_comms + self._load_actions(self.pipeline_order) + elif format == "compute_comms": + actions = {} + with open(filename, newline="") as csvfile: + reader = csv.reader(csvfile) + for rank, row in enumerate(reader): + actions[rank] = [_Action.from_str(s) for s in row] + self._load_actions(actions, format=format) + else: + raise NotImplementedError(f"{format=} is not implemented") + + def _dump_csv(self, filename: str): + """Dump a CSV representation of the compute + comms schedule into a file with the provided filename.""" + # TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible + # that it does not exist if it was created from a compute_comms schedule. + assert ( + self.pipeline_order_with_comms is not None + ), "Must initialize compute_comms schedule before dump_csv" + with open(filename, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + for rank in self.pipeline_order_with_comms: + writer.writerow(self.pipeline_order_with_comms[rank]) + + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Operate on the microbatches for looped schedules (multiple stages on each rank). + + TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does + not support models with skip connections. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + # Based on the plan in Step 1 created in __init__: + # 2.
Perform communication based on the pipeline_order + stage_index_to_stage: Dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in self._stages + } + + assert ( + self.pipeline_order_with_comms is not None + ), "Must call _load_actions() before calling _step_microbatches()" + + # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use + bwd_recv_ops: Dict[Tuple[int, int], Work] = {} + fwd_recv_ops: Dict[Tuple[int, int], Work] = {} + + # send ops should be waited on before step() exits, mainly for hygiene + send_ops: List[Work] = [] + + # we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages + unshard_ops: Dict[int, UnshardHandle] = {} + unsharded_stages = set() + + def _assert_unsharded(stage_idx: int): + """If an unshard is active for `stage_idx`, wait() it and mark `stage_idx` unsharded.""" + if stage_idx in unshard_ops: + unshard_ops[stage_idx].wait() + del unshard_ops[stage_idx] + unsharded_stages.add(stage_idx) + assert ( + stage_idx in unsharded_stages + ), f"Attempted to compute on sharded {stage_idx=}" + + for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]): + try: + comp_type = action.computation_type + mb_index: int = ( + action.microbatch_index + if action.microbatch_index is not None + else -1 + ) + assert mb_index >= 0 or comp_type in ( + UNSHARD, + RESHARD, + ), f"{action=} missing mb_index" + stage_idx = action.stage_index + stage = stage_index_to_stage[stage_idx] + stage_uses_fsdp = isinstance(stage.submod, FSDPModule) + + logger.debug( + "_PipelineScheduleRuntime running time_step %d, action %s", + time_step, + action, + ) + + # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections, + # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be + # safe to use instead. + # However, I was wondering if I should avoid calling batched operators at all in the case that there is + # only one operator per batch. I could iterate through the 'fwd_send_ops' one by one and run them.
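+ # Dispatch on the action type below: SEND_*/RECV_* enqueue async P2P work, + # UNSHARD/RESHARD manage FSDP parameter state, and compute actions (F/B/W) + # first wait on any matching recv (and unshard) before running.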
+ if comp_type == SEND_F: + send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index))) + elif comp_type == SEND_B: + send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index))) + elif comp_type == RECV_F: + assert ( + stage_idx, + mb_index, + ) not in fwd_recv_ops, f"Recv twice for {stage_idx=} {mb_index=} without executing forward" + fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( + stage.get_fwd_recv_ops(mb_index) + ) + elif comp_type == RECV_B: + assert ( + stage_idx, + mb_index, + ) not in bwd_recv_ops, f"Recv twice for {stage_idx=} {mb_index=} without executing backward" + bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( + stage.get_bwd_recv_ops(mb_index) + ) + elif comp_type == UNSHARD: + if stage_uses_fsdp: + assert ( + stage_idx not in unsharded_stages + and stage_idx not in unshard_ops + ), f"Unsharding the same {stage_idx=} twice" + unshard_ops[stage_idx] = stage.submod.unshard(async_op=True) + elif comp_type == RESHARD: + if stage_uses_fsdp: + assert ( + stage_idx in unsharded_stages + ), f"Resharding {stage_idx=} without unsharding" + assert ( + stage_idx not in unshard_ops + ), f"Resharding {stage_idx=} before finishing unshard" + stage.submod.reshard() + elif comp_type == FORWARD: + if stage_uses_fsdp: + _assert_unsharded(stage_idx) + + if not stage.is_first: + assert ( + stage_idx, + mb_index, + ) in fwd_recv_ops, f"Computing {action=} before receiving input" + fwd_recv_ops.pop((stage_idx, mb_index)).wait() + output = stage.forward_one_chunk( + mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index] + ) + self._maybe_compute_loss(stage, output, target_mbs, mb_index) + elif comp_type == BACKWARD: + if stage_uses_fsdp: + _assert_unsharded(stage_idx) + + if not stage.is_last: + assert ( + stage_idx, + mb_index, + ) in bwd_recv_ops, ( + f"Attempted to run compute {action=} before receiving input" + ) + bwd_recv_ops.pop((stage_idx, mb_index)).wait() + loss = self._maybe_get_loss(stage, mb_index) + stage.backward_one_chunk( + mb_index, loss=loss, full_backward=self.use_full_backward + ) + elif comp_type == WEIGHT: + if stage_uses_fsdp: + _assert_unsharded(stage_idx) + + if self.use_full_backward: + raise ValueError( + f"We detected a weight update in the pipeline schedule, but \ + {self.use_full_backward=}" + ) + stage.backward_weight_one_chunk(mb_index) + else: + raise ValueError(f"{action=} is unknown or unsupported") + except Exception as e: + logger.error( + "_PipelineScheduleRuntime caught exception at step %s when running action %s. Full Schedule:", + time_step, + action, + ) + # TODO(whc) what is the best practice for printing a multiline log? + # logger will split it into multiple log lines, but this makes it hard to read (too wide) + print(_format_pipeline_order(self.pipeline_order_with_comms)) # type: ignore[arg-type] + raise e + + # Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them + while len(send_ops): + send_ops.pop().wait() + + assert len(unshard_ops) == 0, "Unused unshard operations" + + # Return losses if there is a container passed in + self._update_losses(self._stages, losses) + + +class ScheduleLoopedBFS(PipelineScheduleMulti): + """ + Breadth-First Pipeline Parallelism. + See https://arxiv.org/abs/2211.05953 for details. + Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank. + The difference is that when microbatches are ready for multiple local + stages, Looped BFS prioritizes the earlier stage, running all available + microbatches at once.
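+ + Example (illustrative; assumes stages holds this rank's local PipelineStage objects): + schedule = ScheduleLoopedBFS(stages, n_microbatches=8, loss_fn=loss_fn) + schedule.step(x, target=y)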
+ """ + + def __init__( + self, + stages: List[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + output_merge_spec=output_merge_spec, + ) + + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + # ======================================================================== + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + def _calculate_single_rank_operations(self, rank): + n_local_stages = len(self._stages) + stage_indices = range( + rank, self.pp_group_size * n_local_stages, self.pp_group_size + ) + + # Store the list of operations used for that rank + rank_ops: List[Optional[_Action]] = [] + # Pre-padding, rank starts with no-ops based on the warmup. + for _ in range(rank): + rank_ops.append(None) + + for stage_index in stage_indices: + for mb_index in range(self._n_microbatches): + rank_ops.append( + _Action(stage_index, _ComputationType.FORWARD, mb_index) + ) + + # wait for the first backward to trickle up + # which is 2 for every hop away + post_warmup_ops = 2 * (self.pp_group_size - 1 - rank) + rank_ops.extend([None] * post_warmup_ops) + + for stage_index in reversed(stage_indices): + for mb_index in reversed(range(self._n_microbatches)): + rank_ops.append( + _Action(stage_index, _ComputationType.BACKWARD, mb_index) + ) + return rank_ops + + +def _get_1f1b_rank_ops( + n_local_stages, + pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + num_1f1b_microbatches=0, + enable_zero_bubble=False, +): + # All stages start with handling microbatch 0 + fwd_stage_mb_index: Dict[int, int] = defaultdict(int) + bwd_stage_mb_index: Dict[int, int] = defaultdict(int) + weight_stage_mb_index: Dict[int, int] = defaultdict(int) + + # Store the list of operations used for that rank + rank_ops: List[Optional[_Action]] = [] + # Pre-padding, rank starts with no-ops based on the warmup. + for _ in range(rank): + rank_ops.append(None) + # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup + # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks. 
+ # Formula: + # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward + # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding) + # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)] + # warmup_ops = calculated above + post_warmup_ops = ( + n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank) + ) - (warmup_ops + rank) + + if enable_zero_bubble: + post_warmup_ops = pp_group_size - rank - 1 + + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + + backward_op_ids = [] + weight_op_count = 0 + + for op in range(total_ops): + # Warmup phase + if op < warmup_ops: + fwd_stage_index = forward_stage_index(op) + # This will assign the current microbatch index and update it as well + fwd_stage_mb_index[fwd_stage_index] = ( + mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(fwd_stage_index, _ComputationType.FORWARD, mb_index) + ) + if op == warmup_ops - 1: + # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up + rank_ops.extend([None] * post_warmup_ops) + # 1F1B Phase (forward and backward) + elif warmup_ops <= op < warmup_ops + fwd_bwd_ops: + fwd_stage_index = forward_stage_index(op) + fwd_stage_mb_index[fwd_stage_index] = ( + fwd_mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index) + ) + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index) + ) + backward_op_ids.append(op) + + if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches: + weight_stage_index = backward_stage_index( + backward_op_ids[weight_op_count] + ) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, _ComputationType.WEIGHT, weight_mb_index + ) + ) + weight_op_count += 1 + # Cooldown phase + else: + # During cooldown phase, we need steps to align with 1f1b happening in other ranks + # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None + if not enable_zero_bubble: + rank_ops.append(None) + + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index) + ) + backward_op_ids.append(op) + + if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches: + weight_stage_index = backward_stage_index( + backward_op_ids[weight_op_count] + ) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, _ComputationType.WEIGHT, weight_mb_index + ) + ) + weight_op_count += 1 + + while enable_zero_bubble and weight_op_count < len(backward_op_ids): + weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count]) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action(weight_stage_index, _ComputationType.WEIGHT, weight_mb_index) + ) + weight_op_count += 1 + + return rank_ops + + +class ScheduleInterleaved1F1B(PipelineScheduleMulti): 
+ """ + The Interleaved 1F1B schedule. + See https://arxiv.org/pdf/2104.04473 for details. + Will perform one forward and one backward on the microbatches in steady + state and supports multiple stages per rank. When microbatches are ready for + multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch + (also called "depth first"). + """ + + def __init__( + self, + stages: List[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + self.pp_group_size = stages[0].group_size + # TODO: is this limitation a must? + if n_microbatches % self.pp_group_size != 0: + raise ValueError( + f"Interleaved 1F1B schedule requires the number of microbatches ({n_microbatches}) \ + to be a multiple of the number of pipeline ranks ({self.pp_group_size})." + ) + + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + ) + + self.n_local_stages = len(stages) + self.rank = stages[0].group_rank + self.group = stages[0].group + + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]: + def get_rank_warmup_ops(rank): + # Warms up operations for last stage + warmups_ops_last_stage = (self.n_local_stages - 1) * self.pp_group_size + # Increment warmup operations by 2 for each hop away from the last stage + warmup_ops = warmups_ops_last_stage + 2 * ((self.pp_group_size - 1) - rank) + # We cannot have more warmup operations than there are number of microbatches, so cap it there + return min(warmup_ops, self._n_microbatches * self.n_local_stages) + + warmup_ops = get_rank_warmup_ops(rank) + microbatch_ops = self.n_local_stages * self._n_microbatches + # fwd_bwd_ops should encompass the remaining forwards + fwd_bwd_ops = microbatch_ops - warmup_ops + # cooldown_ops should encompass the remaining backwards + cooldown_ops = microbatch_ops - fwd_bwd_ops + # total ops encompass both forward and backward ops + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2 + + logger.debug( + "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s", + rank, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + total_ops, + ) + + # Calculates the stage index based on step and pp_group_size + def forward_stage_index(step): + # Get the local index from 0 to n_local_stages-1 + local_index = (step // self.pp_group_size) % self.n_local_stages + return (local_index * self.pp_group_size) + rank + + def backward_stage_index(step): + local_index = ( + self.n_local_stages + - 1 + - ((step - warmup_ops) // self.pp_group_size) % self.n_local_stages + ) + return (local_index * self.pp_group_size) + rank + + return _get_1f1b_rank_ops( + self.n_local_stages, + self.pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + 
forward_stage_index, + backward_stage_index, + ) + + +class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti): + """ + The Flexible Interleaved 1F1B schedule. + + This schedule is mostly similar to the interleaved 1F1B schedule. + It differs by relaxing the requirement that n_microbatches % pp_size == 0. + Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and + it works as long as n_microbatches % num_rounds is 0. A few supported examples: + + 1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0. + 2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0. + + When enable_zero_bubble is True, we will use the ZB1P schedule in https://openreview.net/pdf?id=tuzTN0eIO5 + """ + + def __init__( + self, + stages: List[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + enable_zero_bubble: bool = False, + ): + self.pp_group_size = stages[0].group_size + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + use_full_backward=not enable_zero_bubble, + ) + self.n_local_stages = len(stages) + self.rank = stages[0].group_rank + self.number_of_rounds = max(1, n_microbatches // self.pp_group_size) + self.microbatches_per_round = n_microbatches // self.number_of_rounds + self.enable_zero_bubble = enable_zero_bubble + if n_microbatches % self.number_of_rounds != 0: + raise ValueError( + "Flexible Interleaved 1F1B requires the number of microbatches to be a " + f"multiple of the number of rounds ({self.number_of_rounds}), " + f"but got {n_microbatches}." + ) + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [_Action(stage_index, computation_type, microbatch_index), ...]
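+ # Each row is indexed by time step; a None entry means the rank idles (a + # pipeline bubble) at that step.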
+ self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + # This function adds bubbles to the generated schedule based on dependencies between actions + # Note that the ZB1P schedule does not require manually added bubbles, so this is + # only useful when n_microbatches <= microbatches_per_round + self.pipeline_order = self._add_bubbles_to_actions( + self.n_local_stages * self.pp_group_size, + ) + + def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]: + def get_rank_warmup_ops(rank): + # Warmup operations for the last stage + warmups_ops_last_stage = ( + self.n_local_stages - 1 + ) * self.microbatches_per_round + # Increment warmup operations by 2 for each hop away from the last stage + multiply_factor = 1 if self.enable_zero_bubble else 2 + warmup_ops = warmups_ops_last_stage + multiply_factor * ( + (self.pp_group_size - 1) - rank + ) + + # We cannot have more warmup operations than there are microbatches, so cap it there + return min(warmup_ops, self._n_microbatches * self.n_local_stages) + + warmup_ops = get_rank_warmup_ops(rank) + microbatch_ops = self.n_local_stages * self._n_microbatches + # fwd_bwd_ops should encompass the remaining forwards + fwd_bwd_ops = microbatch_ops - warmup_ops + # cooldown_ops should encompass the remaining backwards + cooldown_ops = microbatch_ops - fwd_bwd_ops + # total ops encompass both forward and backward ops + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2 + logger.debug( + "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s", + rank, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + total_ops, + ) + + # Calculates the stage index based on step and pp_group_size + + def forward_stage_index(step): + # Get the local index from 0 to n_local_stages-1 + local_index = (step // self.microbatches_per_round) % self.n_local_stages + return (local_index * self.pp_group_size) + rank + + def backward_stage_index(step): + local_index = ( + self.n_local_stages + - 1 + - ((step - warmup_ops) // self.microbatches_per_round) + % self.n_local_stages + ) + return (local_index * self.pp_group_size) + rank + + if self.enable_zero_bubble: + num_1f1b_microbatches = rank + + return _get_1f1b_rank_ops( + self.n_local_stages, + self.pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + num_1f1b_microbatches, + enable_zero_bubble=True, + ) + + return _get_1f1b_rank_ops( + self.n_local_stages, + self.pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + ) + + def _add_bubbles_to_actions(self, num_stages_global): + actions = self.pipeline_order + if not self.enable_zero_bubble: + return actions + + def need_bubble(stage, op, microbatch, num_stages_global, seen_ops): + if op == _ComputationType.FORWARD: + if stage != 0 and (stage - 1, op, microbatch) not in seen_ops: + return True + elif op == _ComputationType.BACKWARD: + if stage == num_stages_global - 1: + return (stage, _ComputationType.FORWARD, microbatch) not in seen_ops + return (stage + 1, op, microbatch) not in seen_ops + return False + + seen_ops: Set[Tuple[int, _ComputationType, int]] = set() + result: Dict[int, List[Optional[_Action]]] = {} + next_pointer: Dict[int, int] = {} + bubbles_added: Dict[int, int] = {} + total_bubbles_added = 0 + +
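+ # Walk all ranks' schedules in lock step, one timestamp at a time: an action is + # emitted only once its dependency (per need_bubble) has already been seen; + # otherwise a bubble (None) is inserted for that rank at that timestamp.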
for rank in range(self.pp_group_size): + result[rank] = [] + next_pointer[rank] = 0 + bubbles_added[rank] = 0 + + while True: + should_stop = True + + temp_seen_ops: Set[Tuple[int, _ComputationType, int]] = set() + + for rank in range(self.pp_group_size): + timestamp = next_pointer[rank] + if timestamp >= len(actions[rank]): + continue + + should_stop = False + + if actions[rank][timestamp] is not None: + temp_action = actions[rank][timestamp] + assert temp_action is not None + stage_index, op, microbatch = temp_action + if not need_bubble( + stage_index, op, microbatch, num_stages_global, seen_ops + ): + result[rank].append(actions[rank][timestamp]) + if microbatch is not None: + temp_seen_ops.add((stage_index, op, microbatch)) + next_pointer[rank] += 1 + else: + result[rank].append(None) + bubbles_added[rank] += 1 + total_bubbles_added += 1 + else: + next_pointer[rank] += 1 + result[rank].append(None) + + seen_ops.update(temp_seen_ops) + if should_stop: + break + + if total_bubbles_added > 0: + logger.warning( + "Non-zero bubbles added: total_bubbles_added=%s bubbles_added=%s", + total_bubbles_added, + bubbles_added, + ) + return result + + +class ScheduleInterleavedZeroBubble(ScheduleFlexibleInterleaved1F1B): + """ + The Interleaved Zero Bubble schedule. + See https://arxiv.org/pdf/2401.10241 for details. + Will perform one forward and one backward on inputs for the microbatches in steady + state and supports multiple stages per rank. Uses the backward for weights to fill in + the pipeline bubble. + """ + + def __init__( + self, + stages: List[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + enable_zero_bubble=True, + ) + + +def get_schedule_class(schedule_name: str): + """ + Maps a schedule name to its corresponding class object. + + Args: + schedule_name (str): The name of the schedule. + """ + schedule_map = { + "1F1B": Schedule1F1B, + "Interleaved1F1B": ScheduleInterleaved1F1B, + "GPipe": ScheduleGPipe, + "FlexibleInterleaved1F1B": ScheduleFlexibleInterleaved1F1B, + "LoopedBFS": ScheduleLoopedBFS, + "InterleavedZeroBubble": ScheduleInterleavedZeroBubble, + "PipelineScheduleSingle": PipelineScheduleSingle, + "PipelineScheduleMulti": PipelineScheduleMulti, + } + if schedule_name not in schedule_map: + raise ValueError(f"Unknown schedule name: {schedule_name}") + return schedule_map[schedule_name] diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/stage.py b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/stage.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4abfb7b0b35271e36703aca6fdff46db0cf4d7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/pipelining/stage.py @@ -0,0 +1,1468 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc.
and affiliates +import logging +import operator +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.fx as fx +import torch.nn as nn +from torch._subclasses.fake_tensor import FakeTensor +from torch.distributed._composable.fsdp.fully_shard import FSDPModule, fully_shard +from torch.fx.node import map_aggregate +from torch.nn.parallel import DistributedDataParallel + +from ._backward import stage_backward, stage_backward_input, stage_backward_weight +from ._debug import map_debug_info +from ._utils import flatten_args, PipeInfo, validate_tensors_metadata + + +__all__ = [ + "PipelineStage", + "build_stage", +] + +logger = logging.getLogger(__name__) + + +class _RootArgPlaceholder: + """ + Placeholder for model-level inputs. + """ + + def __init__(self, tensor): + self.meta = tensor.to("meta") + + +class _RecvInfo: + """ + Represents a stage input. + """ + + def __init__( + self, + input_name: str, + source: int, + buffer: torch.Tensor, + ): + # Name of this input + self.input_name = input_name + # Stage index of the source of this input + self.source = source + # Buffer to receive the input into. + self.buffer = buffer + + def __repr__(self): + return f"_RecvInfo(input={self.input_name}, source={self.source}, shape={self.buffer.size()})" + + +# An input can be either a received activation or a model input +InputInfo = Union[_RecvInfo, _RootArgPlaceholder] + + +def _make_tensor_from_meta( + example: Union[torch.Tensor, FakeTensor], + device: torch.device, +) -> torch.Tensor: + """ + Create a real tensor from a meta or fake tensor. + """ + return torch.empty( + example.size(), + dtype=example.dtype, + layout=example.layout, + device=device, + ) + + +class _PipelineStageBase(ABC): + """ + Base class for pipeline stages. + Defines or implements common methods used by the `_PipelineStage` used by + the tracing frontend and the `PipelineStage` used by the manual frontend. + """ + + def __init__( + self, + submodule: torch.nn.Module, + stage_index: int, + num_stages: int, + device: torch.device, + group: Optional[dist.ProcessGroup] = None, + dw_builder: Optional[Callable[[], Callable[..., None]]] = None, + ): + """ + Args: + submodule (torch.nn.Module): The module to be executed in this stage. + stage_index (int): The index of this stage. + num_stages (int): The total number of stages in this pipeline. + device (torch.device): The device to run this stage on. + group (Optional[dist.ProcessGroup]): The process group to use for communication. + If `None`, the default process group will be used. + Default: `None`. + dw_builder (Optional[Callable[[], Callable[..., None]]]): If provided, dw_builder is a builder function + that will build a new dw_runner function that will run parts of module backward that were intentionally + skipped during the module's actual backward pass. The builder must be invoked by the stage after the stage runs + the model backwards, and the stage should save the latest dw_runner to run during the weight pass. + If not provided, a dw_runner will be generated automatically by traversing the autograd graph. + When used with schedules that only have F and B steps, the fresh dw_runner function will be called as + part of B. + When used with F,B,W schedules, the dw_runner function implements 'W'.
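+ + Example (illustrative only; compute_deferred_weight_grads is a hypothetical user helper): + def dw_builder() -> Callable[..., None]: + def dw_runner() -> None: + compute_deferred_weight_grads() # the backward work skipped earlier + return dw_runner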
+ """ + super().__init__() + if stage_index >= num_stages: + raise ValueError( + f"Stage index {stage_index} is out of range of {num_stages}" + ) + + self.submod = submodule + self.stage_index = stage_index + self.num_stages = num_stages + self.device = device + self.group = group + + self.dw_builder = dw_builder + + # backward state + self.backward_state: Dict[int, Tuple[Any, ...]] = {} + + # store dw_runner per microbatch_id + self.dw_runner: Dict[int, Callable[..., None]] = {} + + # `group_rank` is rank in process group `group`. + self.group_rank = dist.get_rank(self.group) + self.group_size = dist.get_world_size(self.group) + if self.group_size > self.num_stages: + raise RuntimeError( + f"Pipeline group size {self.group_size} cannot be larger than number of stages {self.num_stages}" + ) + + # Run time states + self._outputs_meta: Optional[Tuple[torch.Tensor, ...]] = None + # map microbatch ID to list of forward tensor args + self.fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {} + # Caching chunk outputs for final output merge or reduction + self.output_chunks: List[Any] = [] + + # Initialize has_backward to false; this will be set to true if loss + # function is passed to pipeline schedule + self.has_backward = False + # Log prefix + self.log_prefix = f"[Stage {self.stage_index}]" + + # Forward infra + self.args_recv_info: Dict[int, Tuple[InputInfo, ...]] = {} + self.set_requires_grad: Dict[int, bool] = {} + self.act_send_info: Dict[int, List] = {} + + # Backward infra will created lazily + self.grad_recv_info: Dict = {} + self.grad_send_info: Optional[List] = None + + # Number of backward chunks seen. This is used to determine when to do + # grad reduction in DDP or FSDP. + self._seen_bwd_chunks = 0 + + # To be populated later by the Schedule + self.chunks: Optional[int] = None + self.stage_index_to_group_rank: Dict[int, int] = { + i: i % self.group_size for i in range(self.num_stages) + } + + @property + def has_backward(self) -> bool: + """ + Returns true if this stage has a backward pass. + """ + return self._has_backward + + @has_backward.setter + def has_backward(self, has_backward: bool): + self._has_backward = has_backward + + @property + def is_first(self): + """ + Returns true if this stage is the first stage in the pipeline. + """ + return self.stage_index == 0 + + @property + def is_last(self): + """ + Returns true if this stage is the last stage in the pipeline. + """ + return self.stage_index == self.num_stages - 1 + + def _check_chunk_id(self, chunk_id: int): + if self.chunks is None: + raise RuntimeError( + "Attempted to access chunk_id before chunks have been configured." + ) + if chunk_id >= self.chunks: + raise RuntimeError( + f"Chunk id {chunk_id} is out of range [0, {self.chunks})" + ) + + def _configure_outputs_meta(self, outputs_meta: Tuple[torch.Tensor, ...]): + """ + Track the output shapes/dtype of this stage since they determine the send operation(s) which must match + recv operations of the next stage. The next stage _will_ be freezing its recv buffers based on its initial + configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches + which could show up as hangs, silent corruption, or other errors. 
+ """ + assert ( + self._outputs_meta is None + ), "Attempting to reconfigure output_meta, which is not supported" + self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment] + + def get_outputs_meta(self) -> Tuple[torch.Tensor, ...]: + """Get the output metadata (meta tensors) reprensenting the outputs of this stage""" + assert ( + self._outputs_meta is not None + ), "Attempted to get_outputs_meta() without configuring output meta" + return self._outputs_meta + + def _create_grad_send_info( + self, + args_recv_info: Tuple, + ) -> List[Optional[int]]: + """ + Create a list of stage indices to send gradients to. + """ + grad_send_info: List[Optional[int]] = [] + + def map_recv_to_send(a): + # Note: we send gradients back to previous stage as long as in + # forward it is a received input, regardless of whether it requires + # grad. It is up to the previous stage to disgard this gradient. + if isinstance(a, _RecvInfo): + grad_send_info.append(a.source) + return a.source + else: + grad_send_info.append(None) + return None + + map_aggregate(args_recv_info, map_recv_to_send) + + logger.debug("%s Grad send info: %s", self.log_prefix, grad_send_info) + return grad_send_info + + @abstractmethod + def _prepare_forward_infra(self, num_microbatches: int): + raise NotImplementedError + + def _prepare_backward_infra(self, num_microbatches: int): + # TODO: this is needed for backward_maybe_with_nosync + self.chunks = num_microbatches + + for mb_index in range(num_microbatches): + # `grad_recv_info` is a mirror of `act_send_info` + self.grad_recv_info[mb_index] = self._create_grad_recv_info( + self.act_send_info + ) + + @abstractmethod + def _create_grad_recv_info( + self, + act_send_info: Dict, + ) -> Tuple[_RecvInfo, ...]: + raise NotImplementedError + + def _get_recv_ops( + self, + recv_infos: Tuple[InputInfo, ...], + ) -> List[dist.P2POp]: + """ + Helper function shared by `get_fwd_recv_ops` and `get_bwd_recv_ops`. + Returns a list of ops that correspond to the recv infos. + """ + ops: List[dist.P2POp] = [] + for info in recv_infos: + if not isinstance(info, _RecvInfo): + continue + + peer_rank = self.stage_index_to_group_rank[info.source] + peer_global_rank = ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) # TODO + ops.append( + dist.P2POp(dist.irecv, info.buffer, peer_global_rank, self.group) + ) + + return ops + + def get_fwd_recv_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]: + """ + Returns a list of ops that are needed to receive the input arguments + for this stage. + """ + recv_infos: Tuple[InputInfo, ...] = self.args_recv_info[fwd_chunk_id] + + # In case there is backward pass, set requires_grad for receive buffers + # before first forward + if self.has_backward and not self.set_requires_grad[fwd_chunk_id]: + for a in recv_infos: + if isinstance(a, _RecvInfo): + a.buffer.requires_grad_(True) + + return self._get_recv_ops(recv_infos) + + def get_bwd_recv_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]: + """ + Returns a list of ops that are needed to receive the gradients + for this stage. + """ + if not self.has_backward or self.is_last: + return [] + + recv_infos = self.grad_recv_info[bwd_chunk_id] + return self._get_recv_ops(recv_infos) + + def get_fwd_send_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]: + """ + Get the activation send ops for current stage's forward. 
+ """ + output = self.output_chunks[fwd_chunk_id] + # Unify output form to tuple for easy correspondance with + # `act_send_info` + output_tuple = output if type(output) is tuple else (output,) + + ops: List[dist.P2POp] = [] + + for idx, out in enumerate(output_tuple): + dst_stages = self.act_send_info[idx] + for dst in dst_stages: + if dst is None: + continue + logger.debug( + "%s Sending tensor to Stage %s: %s", + self.log_prefix, + dst, + out.size(), + ) + peer_rank = self.stage_index_to_group_rank[dst] + peer_global_rank = ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) # TODO + ops.append(dist.P2POp(dist.isend, out, peer_global_rank, self.group)) + + return ops + + def get_bwd_send_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]: + """ + Get the gradient send ops for current stage's backward. + """ + self._check_chunk_id(bwd_chunk_id) + + if not self.has_backward or self.is_first: + return [] + + # Create bwd send infra lazily + if self.grad_send_info is None: + # Send info for input grads during backward: + # List of destinations corresponding to input grads + # Can be None if an input has no grad + # `grad_send_info` is a mirror of `args_recv_info` + self.grad_send_info = self._create_grad_send_info(self.args_recv_info[0]) + + ops: List[dist.P2POp] = [] + for grad, grad_recv_stage in zip(self.grads_input, self.grad_send_info): + if isinstance(grad, torch.Tensor) and grad_recv_stage is not None: + logger.debug( + "%s Sending gradient to Stage %s: %s", + self.log_prefix, + grad_recv_stage, + grad.size(), + ) + peer_rank = self.stage_index_to_group_rank[grad_recv_stage] + peer_global_rank = ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) # TODO + ops.append(dist.P2POp(dist.isend, grad, peer_global_rank, self.group)) + else: + if not (grad is None and grad_recv_stage is None): + raise RuntimeError( + f"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} " + f"and is expecting to send gradients to stage {grad_recv_stage}" + ) + return ops + + def clear_runtime_states(self) -> None: + """ + Clear runtime states of the stage. + """ + # map microbatch ID to list of forward tensor args + self.fwd_cache.clear() + # Caching chunk outputs for final output merge or reduction + self.output_chunks.clear() + # Reset bwd chunk counter + self._seen_bwd_chunks = 0 + + # Clear grad of input buffers in between schedule steps. This is because + # `torch.autograd.backward()` will accumulate gradients into leaf + # tensors by default. For gradients to pass back to previous stages, we + # don't want such accumulation. + for recv_tuple in self.args_recv_info.values(): # iterate over all chunks + for a in recv_tuple: # iterate over all input args + if isinstance(a, _RecvInfo): + # Set to None is the newer and recommended way to clear grads, compared to `zero_()`. + # See https://github.com/pytorch/pytorch/pull/92731 + a.buffer.grad = None + + def _map_tensor_from_recv_info( + self, + recv_infos: Tuple[InputInfo, ...], + ): + """ + Map tensors from recv infos to a list. + """ + + def get_recv_tensor(info): + if isinstance(info, _RecvInfo): + return info.buffer + else: + raise AssertionError(f"Expected _RecvInfo but got {type(info)}") + + tensors = map_aggregate( + recv_infos, + get_recv_tensor, + ) + + return tensors + + def _retrieve_recv_activations(self, fwd_chunk_id: int): + """ + Retrieve the activations received for the current stage during forward. 
+ """ + recv_infos = self.args_recv_info[fwd_chunk_id] + activations = self._map_tensor_from_recv_info(recv_infos) + return activations + + def _retrieve_recv_grads( + self, + bwd_chunk_id: int, + ): + """ + Retrieve the gradients received for the current stage during backward. + """ + recv_infos = self.grad_recv_info[bwd_chunk_id] + grads = self._map_tensor_from_recv_info(recv_infos) + return grads + + def forward_maybe_with_nosync(self, *args, **kwargs): + # If submod is wrapped with DDP, we use the `no_sync` context manager to + # avoid gradient all-reduce per microbatch + if isinstance(self.submod, DistributedDataParallel): + with self.submod.no_sync(): # type: ignore[operator] + out_val = self.submod(*args, **kwargs) + else: + out_val = self.submod(*args, **kwargs) + return out_val + + def backward_maybe_with_nosync(self, backward_type, bwd_kwargs: Dict): + """ + Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the + other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but + there are additional state-variables and performance considerations depending on the data parallelism used. + This helper should adapt any pipeline parallel schedule to work with common/supported data parallel libraries. + """ + full_backward = bwd_kwargs["full_backward"] + if full_backward: + last_backward = self._seen_bwd_chunks == self.chunks - 1 # type: ignore[operator] + else: + # For backwards are split into weight and input, we will see twice as many bwd_chunks + last_backward = self._seen_bwd_chunks == 2 * self.chunks - 1 # type: ignore[operator] + + def perform_backward(backward_type): + if backward_type == "full": + return lambda: stage_backward( + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + bwd_kwargs["input_values"], + ) + elif backward_type == "input": + return lambda: stage_backward_input( + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + bwd_kwargs["input_values"], + self.submod.parameters(), + ) + elif backward_type == "weight": + return lambda: stage_backward_weight( + self.submod.parameters(), bwd_kwargs["param_groups"] + ) + else: + raise RuntimeError(f"Unknown backward type: {backward_type}") + + # If submod is wrapped by DDP + if isinstance(self.submod, DistributedDataParallel): + if last_backward: + # Last chunk, prepare for gradient reduction + # HACK: reaching into DDP implementation details here. Is there a better way? 
+ self.submod.reducer.prepare_for_backward( # type: ignore[union-attr, operator] + list( + torch.nn.parallel.distributed._find_tensors( # type: ignore[attr-defined] + bwd_kwargs["stage_output"] + ) + ) + ) + result = perform_backward(backward_type)() + else: + with self.submod.no_sync(): # type: ignore[operator] + result = perform_backward(backward_type)() + # If submod is a FSDP module + elif isinstance(self.submod, FSDPModule): + self.submod.set_is_last_backward(False) + self.submod.set_reshard_after_backward(False) + self.submod.set_requires_gradient_sync(False) + result = perform_backward(backward_type)() + if last_backward: + # Manually call post backward for FSDP + def run_post_backward(fsdp_module: FSDPModule) -> None: + fsdp_module.set_is_last_backward(True) + fsdp_module.set_reshard_after_backward(True) + fsdp_module.set_requires_gradient_sync(True) + fsdp_state = fully_shard.state(fsdp_module) + for state in fsdp_state._state_ctx.all_states: + if state._fsdp_param_group: + state._fsdp_param_group.post_backward() + + run_post_backward(self.submod) + else: + # Non-DP submodule, regular backward + result = perform_backward(backward_type)() + + self._seen_bwd_chunks += 1 + + if isinstance(result, tuple) and len(result) == 2: + # for stage_backward_input() + grads, param_groups = result + else: + grads, param_groups = result, None + + return grads, param_groups + + def forward_one_chunk( + self, + fwd_chunk_id: int, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Perform forward pass on the stage with one microbatch. + `args` and `kwargs` are the inputs from *external* to this stage. They + apply only to the first stage in most cases. + """ + + if self.is_first: + # First stage doesn't need to receive anything + composite_args = args + composite_kwargs = kwargs or {} + else: + # Receive activations for this chunk + # Activations only come in args form + composite_args = self._retrieve_recv_activations(fwd_chunk_id) + composite_kwargs = {} + + self._validate_fwd_input(args, kwargs) + + # Compute forward + try: + output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs) + + except Exception as e: + exc_msg = f""" + {self.log_prefix} failed to run forward: + args: {map_debug_info(composite_args)} + kwargs: {map_debug_info(composite_kwargs)} + """ + raise RuntimeError(exc_msg) from e + + if type(output) is list: + # HACK: this is a hacky workaround for the fact that export creates + # output in list format + output = tuple(output) + + # Unify output form to tuple for easy correspondence with + # `act_send_info` + output_tuple = output if type(output) is tuple else (output,) + # Prepare for final output merge or reduction + self.output_chunks.append(output) + + # Save activations and inputs for backward + flat_args = flatten_args(composite_args) + flat_kwargs = flatten_args(composite_kwargs) + flatten_input_tensors = flat_args + flat_kwargs + self.fwd_cache[fwd_chunk_id] = ( + output_tuple, # stage_output + flatten_input_tensors, # input_values + ) + + logger.debug( + "%s Forwarded chunk %s, outputs: %s", + self.log_prefix, + fwd_chunk_id, + map_debug_info(output), + ) + self._validate_fwd_outputs(output_tuple) + return output + + def backward_one_chunk( + self, bwd_chunk_id: int, loss=None, full_backward: bool = True + ): + """ + Perform backward pass on the module. + This should only be called once per microbatch.
+
+ If full_backward is True (the default), the full backward pass including weight and input gradients will be run,
+ and it is an error to call `backward_weight_one_chunk` for this bwd_chunk_id.
+
+ If full_backward is False, providing `dw_builder` to the PipelineStage at __init__ time is optional,
+ and a subsequent call to `backward_weight_one_chunk` is required to invoke dw_runner and complete the backward.
+ """
+ self._check_chunk_id(bwd_chunk_id)
+
+ (
+ stage_output,
+ input_values,
+ ) = self.fwd_cache.pop(bwd_chunk_id)
+
+ # Compute backward
+ if self.is_last:
+ # Last stage computes gradients from loss and has no gradients from
+ # next stage
+ bwd_kwargs = {
+ "stage_output": loss,
+ "output_grads": None,
+ "input_values": input_values,
+ }
+ else:
+ # Otherwise, receive gradients from next stage
+ grads_output = self._retrieve_recv_grads(bwd_chunk_id)
+ # If an input to the pipeline requires gradient,
+ # `torch.autograd.backward` will accumulate the gradient into the
+ # `.grad` field of such input
+ bwd_kwargs = {
+ "stage_output": stage_output,
+ "output_grads": grads_output,
+ "input_values": input_values,
+ }
+
+ # Save full_backward
+ bwd_kwargs["full_backward"] = full_backward
+
+ # Custom backward function
+ if self.dw_builder:
+ # TODO: We may want to change our semantics so we are allowed to ignore
+ # the 'dw_builder' and call full_backward directly when it is a full_backward op.
+ self.grads_input, _ = self.backward_maybe_with_nosync("full", bwd_kwargs)
+ if full_backward:
+ self.dw_builder()()
+ else:
+ self.dw_runner[bwd_chunk_id] = self.dw_builder()
+ else:
+ if full_backward:
+ self.grads_input, _ = self.backward_maybe_with_nosync(
+ "full", bwd_kwargs
+ )
+ else:
+ # perform the partial backward for the inputs with a custom backward function
+ # when "stage_output" is a loss it is a tensor, otherwise it is a tuple of tensors
+ if isinstance(bwd_kwargs["stage_output"], torch.Tensor):
+ bwd_kwargs["stage_output"] = (bwd_kwargs["stage_output"],)
+
+ grads_input, param_groups = self.backward_maybe_with_nosync(
+ "input", bwd_kwargs
+ )
+
+ # TODO: we don't need to save this, add to dw_runner?
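+ # Stash everything `backward_weight_one_chunk` will need later to run the
+ # deferred weight backward for this chunk.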
+ self.backward_state[bwd_chunk_id] = (
+ input_values,
+ param_groups,
+ bwd_kwargs["stage_output"],
+ bwd_kwargs["output_grads"],
+ )
+ self.grads_input = grads_input
+ # Save a placeholder for the dw_runner
+ self.dw_runner[bwd_chunk_id] = lambda: None
+ logger.debug("%s Backwarded chunk %s", self.log_prefix, bwd_chunk_id)
+
+ def backward_weight_one_chunk(self, bwd_chunk_id: int):
+ assert bwd_chunk_id in self.dw_runner, (
+ f"{self.log_prefix} Attempted to run backward_weight_one_chunk for chunk {bwd_chunk_id}"
+ " without first calling `backward_one_chunk(full_backward=False)`"
+ )
+
+ if self.dw_builder is not None:
+ self.dw_runner.pop(bwd_chunk_id)()
+ else:
+ (
+ input_values,
+ param_groups,
+ stage_output,
+ output_grads,
+ ) = self.backward_state.pop(bwd_chunk_id)
+
+ if self.stage_index != 0:
+ bwd_kwargs = {
+ "stage_output": stage_output,
+ "param_groups": param_groups,
+ "full_backward": False,
+ }
+ weight_grads, _ = self.backward_maybe_with_nosync("weight", bwd_kwargs)
+ else:
+ # TODO: figure out a better way to do this:
+ # if the inputs do not require gradient,
+ # then the parameter group will not be fully captured during stage_backward_input
+ # in this case, we need to call grad directly on the parameters
+ # To solve: make input fn do the intersect compute and then finish it off during W
+ bwd_kwargs = {
+ "stage_output": stage_output,
+ "output_grads": output_grads,
+ "input_values": input_values,
+ "full_backward": False,
+ }
+ self.backward_maybe_with_nosync("full", bwd_kwargs)
+
+ def _validate_fwd_input(self, args, kwargs):
+ """Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage."""
+
+ if self.is_first:
+ # TODO why is there a separate recv_info for each pipeline chunk?
+ # kwen2501: to avoid passing a `fwd_chunk_id` to this function, we
+ # check all chunks against args_recv_info[0]
+ expected_args = self.args_recv_info[0]
+ else:
+ # We don't check inputs for non-0 stages assuming they don't accept
+ # user inputs in canonical pipeline scenarios
+ return
+
+ if len(kwargs):
+ # TODO- need a mapping of kwarg to position in self.args_recv_info
+ # without it, we just validate shapes for args and ignore kwargs
+ expected_args = expected_args[: len(expected_args) - len(kwargs)]
+
+ # TODO- need a mapping of kwarg to position in self.args_recv_info
+ # maybe it's impossible to tell whether the len mismatches because
+ # (a) the user passed an extra arg or missed an arg
+ # (b) the user did not pass a kwarg, which has a default value baked into expected_args
+ expected_tensors_meta = [
+ e.meta if isinstance(e, _RootArgPlaceholder) else e.buffer
+ for e in expected_args
+ ]
+ validate_tensors_metadata(
+ f"Stage {self.stage_index} forward inputs", expected_tensors_meta, args
+ )
+
+ def _validate_fwd_outputs(self, outputs: Tuple[torch.Tensor, ...]):
+ """Raises a RuntimeError if this stage produces an output of unexpected shape/dtype.
+ Most likely, this could be caused either by incorrect user specification of output shapes, or because
+ shape inference was done on the original model but then at runtime the model is wrapped with something like
+ mixed precision which changes output dtype.
+ """ + expected_tensors_meta = self.get_outputs_meta() + validate_tensors_metadata( + f"Stage {self.stage_index} forward outputs", expected_tensors_meta, outputs + ) + + +class _PipelineStage(_PipelineStageBase): + def __init__( + self, + stage_module: torch.nn.Module, + stage_index: int, + pipe_info: PipeInfo, + device: torch.device, + group: Optional[dist.ProcessGroup] = None, + ): + """ + Create a pipeline stage given a stage_module to be wrapped by this stage + and a `pipe_info` describing the stage relationship of the pipeline. + + Args: + stage_module (torch.nn.Module): the module to be wrapped by this stage + stage_index (int): the index of this stage in the pipeline + pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()` + device (torch.device): the device to be used by this stage + group (Optional[dist.ProcessGroup]): the process group to be used by this stage + """ + _PipelineStageBase.__init__( + self, + stage_module, + stage_index, + pipe_info.num_stages, + device, + group, + ) + self.pipe_info = pipe_info + + # Find stage nodes in graph + submod_nodes = [ + node for node in pipe_info.graph.nodes if node.op == "call_module" + ] + if len(submod_nodes) != self.num_stages: + raise AssertionError( + f"Number of submodules in pipe graph {len(submod_nodes)} does not match number of stages {self.num_stages}" + ) + + # Find my stage node in graph + self.node = submod_nodes[self.stage_index] + self.name = self.node.name + logger.info( + "[%s] Creating PipelineStage %s for %s", + self.group_rank, + stage_index, + self.name, + ) + + # Create mapping from stage name to stage index + self.submod_to_stage_index: Dict[str, int] = {} + for i, node in enumerate(submod_nodes): + self.submod_to_stage_index.setdefault(node.name, i) + + # Cast submodule to device + self._move_submod_to_device() + + def _move_submod_to_device(self): + # Move submodule to indicated device if possible + # Note: we cannot move meta module to real devices because meta tensors + # do not support to() method. One needs to do an in-place tensor swap in + # that case. + has_meta_param = any( + isinstance(p, FakeTensor) or p.is_meta for p in self.submod.parameters() + ) + if has_meta_param: + logger.debug("%s Found meta parameters!", self.log_prefix) + else: + self.submod.to(self.device) + + def _prepare_forward_infra(self, num_microbatches: int): + """ + Create send/recv infrastructures for activations (during forward) + """ + # Flag per chunk to keep track of whether we have set `requires_grad` + # for receive buffers. Format: {chunk : Boolean} + for chunk in range(num_microbatches): + self.args_recv_info[chunk] = self._create_act_recv_info() + self.set_requires_grad[chunk] = False + + # Send info during forward for each activation + self.act_send_info = self._create_act_send_info() + + def get_stage_index_of_submod( + self, + submod_name: str, + ): + """ + Given a submodule name, return the stage index of the submodule. + """ + if submod_name not in self.submod_to_stage_index: + raise AssertionError(f"Stage id of {submod_name} not found") + + return self.submod_to_stage_index[submod_name] + + def _create_act_recv_info( + self, + ): + """ + Create a tuple of `_RecvInfo` for inputs to the stage. + """ + + def create_recv_tensor(placeholder, arg_node): + """ + Create a receive buffer for a placeholder. + """ + example_value = placeholder.meta["val"] + if arg_node.op == "placeholder": + # This is a root level placeholder, thus an input argument to the entire model. 
+ # We are likely at stage 0, hence no need to create a receive buffer.
+ return _RootArgPlaceholder(example_value)
+
+ # Figure out the source stage of this input
+ while arg_node.target is operator.getitem:
+ # If the input is a getitem, we need to go deeper
+ arg_node = arg_node.args[0]
+
+ assert (
+ arg_node.op == "call_module"
+ ), f"Expecting call_module, got {arg_node.op}"
+ src_stage = self.get_stage_index_of_submod(arg_node.name)
+
+ # Create a receive buffer for this placeholder
+ logger.debug(
+ "%s Creating recv buffer for input '%s' : %s, %s",
+ self.log_prefix,
+ placeholder.name,
+ example_value.shape,
+ example_value.dtype,
+ )
+ buffer = _make_tensor_from_meta(example_value, self.device)
+
+ return _RecvInfo(
+ arg_node.name,
+ src_stage,
+ buffer,
+ )
+
+ args_recv_info: List[InputInfo] = []
+ # Filter out placeholder nodes from `self.submod` (a GraphModule)
+ placeholders = filter(
+ lambda node: node.op == "placeholder", self.submod.graph.nodes
+ )
+ # `placeholders` are nodes internal to submod.
+ # `self.node.args` are dependency nodes in the outer graph.
+ # The two are 1:1.
+ for placeholder, arg_node in zip(placeholders, self.node.args):
+ # Create a receive buffer for this placeholder
+ recv_info = create_recv_tensor(placeholder, arg_node)
+ args_recv_info.append(recv_info)
+
+ logger.debug(
+ "%s Activation recv / args info: %s", self.log_prefix, args_recv_info
+ )
+ # `args` is a Tuple, hence we will return a Tuple[InputInfo]
+ return tuple(args_recv_info)
+
+ def find_dst_rank(
+ self,
+ user: fx.Node,
+ ) -> Optional[int]:
+ """
+ Find the destination rank of a `user` node.
+ If the `user` is not a submod, `None` may be returned.
+ """
+ if user.op == "call_module":
+ # User is a stage (`call_module`)
+ return self.get_stage_index_of_submod(user.name)
+ else:
+ # - If user.op == "output":
+ # No need to send back to rank 0
+ # - If user.target is stage_backward:
+ # No need to send assuming submod output is stored locally or
+ # should be re-calculated in case of activation checkpointing
+ return None
+
+ def _create_act_send_info(self):
+ """
+ Create a dict of send info for activations.
+ The dict is of the form:
+ {
+ output_index: [dst_rank_0, dst_rank_1, ...],
+ ...
+ }
+ where the list of `dst_rank`s covers the case where an output value may
+ be consumed by multiple stages.
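+
+ For example (hypothetical stage indices): a stage whose first output is
+ consumed by stage 2 and whose second output is consumed by both stages 2
+ and 3 would produce {0: [2], 1: [2, 3]}.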
+ """ + # Output index: List of receiver ranks + act_send_info: Dict[int, List] = {} + out_idx = 0 + + for user in self.node.users: + if user.target is operator.getitem: + # Recursively find the real destination + gi_dsts = act_send_info.setdefault(out_idx, []) + for gi_user in user.users: + dst_rank = self.find_dst_rank(gi_user) + if dst_rank is not None: + gi_dsts.append(dst_rank) + # Next `getitem` will point to the next output index + out_idx += 1 + else: + # In case of single output value, `out_idx` will not increase + dsts = act_send_info.setdefault(out_idx, []) + dst_rank = self.find_dst_rank(user) + if dst_rank is not None: + dsts.append(dst_rank) + + output_node = self._get_output_node() + output_vals: Tuple[torch.Tensor] = tuple( + v.meta["val"] for v in flatten_args(output_node.args) + ) + self._configure_outputs_meta(output_vals) + + logger.debug("%s Send info: %s", self.log_prefix, act_send_info) + return act_send_info + + def _get_output_node(self): + output_nodes = [node for node in self.submod.graph.nodes if node.op == "output"] + assert len(output_nodes) == 1 + output_node = output_nodes[0] + return output_node + + def _create_grad_recv_info( + self, + act_send_info: Dict, + ) -> Tuple[_RecvInfo, ...]: + """ + Create a tuple of `_RecvInfo` for gradients. + """ + # Dict[output_index, _RecvInfo] + grad_recv_info: Dict[int, _RecvInfo] = {} + output_node = self._get_output_node() + + # The output node may take multiple args, meaning the submod having multiple output values. + output_vals = flatten_args(output_node.args) + + for out_idx, dst_list in act_send_info.items(): + if not dst_list: + # No actual receiver for activation so no grad coming back + continue + + output = output_vals[out_idx] + example_value = output.meta["val"] + logger.debug( + f"{self.log_prefix} Creating grad recv buffer for output {output.name} " # noqa: G004 + f": {example_value.shape}, {example_value.dtype}" + ) + + # TODO: otherwise needs grad accumulation + assert len(dst_list) == 1, "Backward of skip connections not supported yet" + grad_src = dst_list[0] + grad_recv_info[out_idx] = _RecvInfo( + f"{grad_src}", # noqa: G004 + grad_src, + _make_tensor_from_meta(example_value, self.device), + ) + + # Convert to tuple for convenience in get_ops and retrieve tensor + grad_recv_info_tuple = tuple(grad_recv_info.values()) + logger.debug("%s Grad recv info: %s", self.log_prefix, grad_recv_info_tuple) + return grad_recv_info_tuple + + +# A helper function to create a pipeline stage based on traced pipeline information +def build_stage( + stage_module: torch.nn.Module, + stage_index: int, + pipe_info: PipeInfo, + device: torch.device, + group: Optional[dist.ProcessGroup] = None, +) -> _PipelineStage: + """ + Create a pipeline stage given a stage_module to be wrapped by this stage + and pipeline information. + + Args: + stage_module (torch.nn.Module): the module to be wrapped by this stage + stage_index (int): the index of this stage in the pipeline + pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()` + device (torch.device): the device to be used by this stage + group (Optional[dist.ProcessGroup]): the process group to be used by this stage + + Returns: + _PipelineStage: a pipeline stage that can run with `PipelineSchedules`. 
+ """ + return _PipelineStage( + stage_module, + stage_index, + pipe_info, + device, + group, + ) + + +# Manual PipelineStage functions and definition + +METADATA_TENSOR_LEN = 100 +PLACEHOLDER_VAL = -1 + + +def _create_empty_tensors( + tensor: Union[torch.Tensor, Iterable[torch.Tensor]], device: torch.device +) -> List[torch.Tensor]: + """ + Creates a list of empty tensors with the same properties (like shape and dtype) as the input tensor(s), + and places them on the specified device. + Args: + tensor (Union[torch.Tensor, List[torch.tensor]]): The input tensor(s). + device (torch.device): The device where the new tensors will be placed. + Returns: + List[torch.Tensor]: A list of empty tensors with the same properties as the input tensor(s). + """ + if isinstance(tensor, torch.Tensor): + return [torch.empty_like(tensor, device=device)] + elif isinstance(tensor, (list, tuple)): + return [torch.empty_like(t, device=device) for t in tensor] + raise TypeError(f"Unsupported type {type(tensor)} cannot create empty tensors") + + +def _create_metadata_tensor( + tensors: Optional[List[torch.Tensor]] = None, + device: Optional[torch.device] = torch.device("cpu"), +) -> torch.Tensor: + """ + Create a metadata tensor that can be sent over the wire. + This tensor contains the number of dimensions and the shape of each tensor being sent. + + The data is of format [num_dims, dim1, dim2, ...]. + If the tensor is None, a tensor of only placeholder values will be returned. + + Inputs: + tensors: A list of tensors, the tensors will converted into its shape dimensions and + these dimensions will be concatenated. + device: The device where the metadata tensor will be created. + If the tensor is None, then this tensor will contain PLACEHOLDER_VALs. + + """ + metadata_tensor = torch.full( + (METADATA_TENSOR_LEN,), + PLACEHOLDER_VAL, + dtype=torch.int32, + device=device, + ) + if tensors: + # Create a list of tensors containing the number of dimensions and the shape of each tensor + data = [ + # data is of format [num_dims, dim1, dim2, ...] + torch.tensor( + [len(tensor.shape)] + list(tensor.shape), + dtype=torch.int32, + device=device, + ) + for tensor in tensors + ] + # Concatenate the data into a single tensor + data_tensor = torch.cat(data) + dt_shape = data_tensor.shape[0] + if dt_shape > METADATA_TENSOR_LEN: + raise ValueError( + f"Metadata tensor size ({dt_shape}) exceeds maximum allowed length ({METADATA_TENSOR_LEN})." + ) + metadata_tensor[:dt_shape] = data_tensor + return metadata_tensor + + +def _extract_metadata_from_tensor(tensor: torch.Tensor) -> List[torch.Size]: + """ + Extract the number of dimensions and the shape of each tensor from a metadata tensor. + """ + metadata: List[torch.Size] = [] + i = 0 + while i < len(tensor) and tensor[i] != PLACEHOLDER_VAL: + num_dims = int(tensor[i].item()) + shape = torch.Size(tensor[i + 1 : i + 1 + num_dims].tolist()) + metadata.append(shape) + i += num_dims + 1 + return metadata + + +def _get_stage_shapes( + stage_modules: List[nn.Module], + stage_ids: List[int], + num_stages: int, + rank: int, + world_size: int, + device: torch.device, + microbatch: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, +): + """ + Performs a dry run through all the pipeline stages (a rank can have multiple pipeline stages in the case of + virtual pipelining) and returns the shape of the inputs and outputs of the module. + Only the first stage must pass in a microbatch. + + Each rank must call _get_stage_shapes or the program will hang. 
+
+ Args:
+ stage_modules: The chunks assigned to this rank. The length should be 1 for any
+ non-interleaved schedules and >1 for any interleaved schedules.
+ stage_ids: The ids of the stages assigned to this rank.
+ num_stages: Total number of stages.
+ rank: Rank of the current process.
+ world_size: Number of processes participating in the pipeline.
+ device: Device where the tensors are allocated.
+
+ Returns a dictionary containing the following keys:
+ "inputs": Shape of the inputs to the module
+ "outputs": Shape of the outputs of the module
+ """
+
+ stage_id_to_shapes: Dict[int, Dict[str, list[torch.Size]]] = {}
+ for stage_id, model in zip(stage_ids, stage_modules):
+ input_shape_metadata_tensor = _create_metadata_tensor(device=device)
+ # TODO: Assumes prev_stage == rank - 1 and next_stage == rank + 1
+ prev_rank = (rank - 1) % world_size
+ next_rank = (rank + 1) % world_size
+ shapes = {}
+
+ # first stage doesn't receive anything and uses a microbatch
+ if stage_id == 0:
+ if microbatch is None:
+ raise RuntimeError("Microbatch is required for first stage")
+ example_fwd_inputs = microbatch
+ if isinstance(example_fwd_inputs, torch.Tensor):
+ example_fwd_inputs = [example_fwd_inputs]
+ else:
+ # other stages must receive shape information
+ # TODO: send/recv should take a group, rather than use the default group
+ dist.recv(input_shape_metadata_tensor, prev_rank)
+ metadata = _extract_metadata_from_tensor(input_shape_metadata_tensor)
+ example_fwd_inputs = [
+ torch.empty(shape_list, device=device) for shape_list in metadata
+ ]
+ shapes["inputs"] = [fwd_input.shape for fwd_input in example_fwd_inputs]
+
+ # perform forward
+ # TODO: if forward fails raise a more descriptive error explaining which stage failed
+ fwd_outputs = model(*example_fwd_inputs)
+ fwd_outputs = _create_empty_tensors(fwd_outputs, device)
+ shapes["outputs"] = [fwd_output.shape for fwd_output in fwd_outputs]
+
+ # send shape dims
+ if stage_id != num_stages - 1:
+ output_shape_metadata_tensor = _create_metadata_tensor(
+ fwd_outputs, device=device
+ )
+ dist.send(output_shape_metadata_tensor, next_rank)
+ stage_id_to_shapes[stage_id] = shapes
+ logger.info(stage_id_to_shapes)
+ return stage_id_to_shapes
+
+
+class PipelineStage(_PipelineStageBase):
+ """
+ A class representing a pipeline stage in a pipeline parallelism setup.
+ This class is created manually by providing an example input (and optionally output)
+ as opposed to the PipelineStage class that is output by pipeline().
+ This class extends the `_PipelineStageBase` class and can similarly be used
+ in `PipelineSchedule`.
+
+ Args:
+ submodule (nn.Module): The PyTorch module wrapped by this stage.
+ stage_index (int): The ID of this stage.
+ num_stages (int): The total number of stages.
+ device (torch.device): The device where this stage is located.
+ input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The input arguments for the submodule.
+ output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The output arguments for the submodule.
+ group (dist.ProcessGroup, optional): The process group for distributed training. If None, default group.
+ dw_builder: TODO clean up comments + """ + + def __init__( + self, + submodule: nn.Module, + stage_index: int, + num_stages: int, + device: torch.device, + input_args: Union[torch.Tensor, Tuple[torch.Tensor, ...]], + output_args: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = None, + group: Optional[dist.ProcessGroup] = None, + dw_builder: Optional[Callable[[], Callable[..., None]]] = None, + ): + super().__init__(submodule, stage_index, num_stages, device, group, dw_builder) + self.submod.to(self.device) + # When we materialize the model partition on cuda, we call reset_parameters() if it is available + self.inputs: List[torch.Tensor] = [] + self.outputs: List[torch.Tensor] = [] + + self.inputs = _create_empty_tensors(input_args, device) + + if output_args is None: + logger.info("output_args not provided, performing forward using input_args") + self.outputs = self.submod(*self.inputs) + # create buffers for the output so that the data is in the correct + # shape in order to use in p2p op (send) + self.outputs = _create_empty_tensors(self.outputs, device) + else: + self.outputs = _create_empty_tensors(output_args, device) + + self._configure_outputs_meta(tuple(self.outputs)) + + # these are the buffers used in backwards send/recv, they are allocated later + self.outputs_grad: List[torch.Tensor] = [] + + def stage_global_rank(peer_rank): + return ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) + + self.prev_stage = stage_global_rank((self.group_rank - 1) % self.group_size) + self.next_stage = stage_global_rank((self.group_rank + 1) % self.group_size) + + logger.debug( + f"finished pipeline stage init, {self.stage_index=}, {self.is_first=}, " # noqa: G004 + f"{self.is_last=}, {self.num_stages=}, " + f"inputs: {[inp.shape for inp in self.inputs]}, " + f"output: {[output.shape for output in self.outputs]}" + ) + + def _prepare_forward_infra(self, num_microbatches: int) -> None: + # Receive info during forward + # TODO: create args_recv_info lazily? (same needed for PipelineStage) + for chunk_id in range(num_microbatches): + self.set_requires_grad[chunk_id] = False + if not self.is_first: + # We assume that we always receive from stage - 1 + recv_infos = tuple( + [ + _RecvInfo( + f"recv_for_{self.stage_index}_from_{self.stage_index - 1}", + self.stage_index - 1, + _make_tensor_from_meta(inp, self.device), + ) + for inp in self.inputs + ] + ) + + self.args_recv_info[chunk_id] = recv_infos + else: + self.args_recv_info[chunk_id] = tuple( + [_RootArgPlaceholder(i) for i in self.inputs] + ) + + # Send info during forward for each activation + # only need the rank that is being sent to + self.act_send_info: Dict[int, List] = {} + for idx in range(len(self.outputs)): + # We assume we always send to stage + 1 + if not self.is_last: + self.act_send_info[idx] = [self.stage_index + 1] + else: + self.act_send_info[idx] = [] + + def _create_grad_recv_info( + self, + act_send_info: Dict, + ) -> Tuple[_RecvInfo, ...]: + grad_recv_info: Tuple[_RecvInfo, ...] 
= ()
+ if not self.is_last:
+ # Receiving gradients from multiple sources is not supported
+ # hence we only take the first destination
+ grad_recv_info = tuple(
+ [
+ _RecvInfo(
+ f"recv_grad_for_{self.stage_index}_from_{dst_list[0]}",
+ dst_list[0],
+ _make_tensor_from_meta(self.outputs[idx], self.device),
+ )
+ for idx, dst_list in act_send_info.items()
+ ]
+ )
+ return grad_recv_info
+
+ def _init_p2p_neighbors(self):
+ """
+ Set up p2p communicators between previous and next stages
+ by sending a dummy tensor.
+
+ If this is used, it must be called for all pipeline stages.
+ """
+ ops = []
+ recv_tensor = torch.zeros(1, device="cuda")
+ send_tensor = torch.ones(1, device="cuda")
+ # forward
+ if not self.is_first:
+ ops.append(dist.P2POp(dist.irecv, recv_tensor, self.prev_stage, self.group))
+ if not self.is_last:
+ ops.append(dist.P2POp(dist.isend, send_tensor, self.next_stage, self.group))
+
+ # backward
+ if not self.is_first:
+ ops.append(dist.P2POp(dist.isend, send_tensor, self.prev_stage, self.group))
+ if not self.is_last:
+ ops.append(dist.P2POp(dist.irecv, recv_tensor, self.next_stage, self.group))
+
+ return True
+
+
+def _validate_stage_shapes(pipeline_stages: List[PipelineStage]):
+ """
+ Check that the buffer shapes match between stages as expected by performing an all_gather between
+ all stages.
+ """
+ if len(pipeline_stages) == 0:
+ raise ValueError("No pipeline stages provided.")
+
+ virtual_pipeline_size = len(pipeline_stages)
+ all_inputs = []
+ all_outputs = []
+ world_size = pipeline_stages[0].group_size
+ num_stages = pipeline_stages[0].num_stages
+
+ # perform all gathers between all stages
+ for virtual_id, stage in enumerate(pipeline_stages):
+ world_size = stage.group_size
+ stage_id: int = stage.stage_index
+ rank = stage.group_rank
+ # check that world_size and num_stages are consistent across all stages
+ if stage.group_size != world_size:
+ raise ValueError(
+ f"Stage id {stage_id} has world size ({stage.group_size}) \
+ which does not match world size ({world_size}) of other stages."
+ )
+ if stage.num_stages != num_stages:
+ raise ValueError(
+ f"Stage id {stage_id} has num stages ({stage.num_stages}) \
+ which does not match num stages ({num_stages}) of other stages."
+ ) + + pg_rank = dist.get_rank(stage.group) + if rank != pg_rank: + raise ValueError( + f"Rank {rank} is not equal to process group rank {pg_rank}" + ) + + if (num_stages := stage.num_stages) % world_size != 0: + raise ValueError( + f"Number of stages ({num_stages}) must be a multiple of the world_size ({world_size})" + ) + + # all gather each ranks inputs + tensor_list = [ + _create_metadata_tensor(device=stage.device) + for _ in range(stage.group_size) + ] + expected_inputs = stage.inputs + stage_input = _create_metadata_tensor(expected_inputs, device=stage.device) + dist.all_gather(tensor_list, stage_input) + stage_input_shapes = [ + _extract_metadata_from_tensor(tensor) for tensor in tensor_list + ] + + # all gather each ranks outputs + tensor_list = [ + _create_metadata_tensor(device=stage.device) + for _ in range(stage.group_size) + ] + expected_outputs = stage.outputs + stage_output = _create_metadata_tensor(expected_outputs, device=stage.device) + dist.all_gather(tensor_list, stage_output) + stage_output_shapes = [ + _extract_metadata_from_tensor(tensor) for tensor in tensor_list + ] + + logger.debug( + f"Rank: {pg_rank}" # noqa: G004 + f"Stage id: {stage_id}" + f"Stage num stages: {stage.num_stages}" + f"Stage rank: {rank}" + f"Stage world size: {world_size}" + f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} input shapes: {stage_input_shapes}" # noqa: G003 + f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} output shapes: {stage_output_shapes}" # noqa: G003 + ) + + all_inputs.extend(stage_input_shapes) + all_outputs.extend(stage_output_shapes) + + # log only rank 0's view, they will all be equivalent + if pg_rank == 0: + logger.info( + "all stage inputs: %s \n all stage outputs: %s", all_inputs, all_outputs + ) + + # Check if the output for stage 0 matches the input at stage 1, and so forth + for i in range(virtual_pipeline_size * world_size - 1): + if (out := all_outputs[i]) != (inp := all_inputs[i + 1]): + raise ValueError( + f"Stage_id {i} output shape {out} at does not match stage_id {i + 1} input shape {inp}." + ) diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/__init__.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2746f60ba8720fa05fb06fa6cd60d7a57558417 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/__init__.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates + +import torch +import torch.distributed.tensor._ops # force import all built-in dtensor ops +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh # noqa: F401 +from torch.distributed.tensor._api import ( + distribute_module, + distribute_tensor, + DTensor, + empty, + full, + ones, + rand, + randn, + zeros, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) +from torch.optim.optimizer import ( + _foreach_supported_types as _optim_foreach_supported_types, +) +from torch.utils._foreach_utils import ( + _foreach_supported_types as _util_foreach_supported_types, +) + + +# All public APIs from dtensor package +__all__ = [ + "DTensor", + "distribute_tensor", + "distribute_module", + "Shard", + "Replicate", + "Partial", + "Placement", + "ones", + "empty", + "full", + "rand", + "randn", + "zeros", +] + + +# Append DTensor to the list of supported types for foreach implementation for optimizer +# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA. +if DTensor not in _optim_foreach_supported_types: + _optim_foreach_supported_types.append(DTensor) + +if DTensor not in _util_foreach_supported_types: + _util_foreach_supported_types.append(DTensor) + + +# Set namespace for exposed private names +DTensor.__module__ = "torch.distributed.tensor" +distribute_tensor.__module__ = "torch.distributed.tensor" +distribute_module.__module__ = "torch.distributed.tensor" +ones.__module__ = "torch.distributed.tensor" +empty.__module__ = "torch.distributed.tensor" +full.__module__ = "torch.distributed.tensor" +rand.__module__ = "torch.distributed.tensor" +randn.__module__ = "torch.distributed.tensor" +zeros.__module__ = "torch.distributed.tensor" diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_api.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_api.py new file mode 100644 index 0000000000000000000000000000000000000000..ec2993690e74e20c6cbd024a760baeb090541ef3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_api.py @@ -0,0 +1,1231 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import inspect +import warnings +from typing import Any, Callable, cast, Optional, Sequence, Tuple + +import torch +import torch.distributed.tensor._dispatch as op_dispatch +import torch.distributed.tensor._random as random +import torch.nn as nn +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +from torch.distributed.tensor._collective_utils import check_tensor_meta, mesh_broadcast +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._random import ( + is_rng_supported_mesh, + OffsetBasedRNGTracker, +) +from torch.distributed.tensor._redistribute import ( + Redistribute, + redistribute_local_tensor, +) +from torch.distributed.tensor._utils import ( + compute_global_tensor_info, + compute_local_shape, + normalize_to_torch_size, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +__all__ = [ + "DTensor", + "distribute_tensor", + "distribute_module", + "ones", + "empty", + "full", + "rand", + "randn", + "zeros", +] + +aten = torch.ops.aten + + +# NOTE [Autograd interaction between torch.Tensor] +# +# The autograd functions defined below are being used by the public +# facing APIs (i.e. 
from_local, to_local) to ensure DTensor to work +# together with torch.Tensor within the autograd engine. This +# allows DTensor to only exist on part of the module hierarchy. +# +# As an example, we have the a module that consists of submodules +# A, B, and C, the execution flow would be like: +# input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor) +# +# Suppose I only want to make Module B be a sharded module with +# DTensor params, the following forward/backward should work: +# +# input(torch.Tensor) -> Module A +# -> DTensor input (from_local) -> Sharded Module B -> DTensor output +# -> torch.Tensor output (to_local) -> Module C +# +# So from_local/to_local must be Autograd functions. +# +class _ToTorchTensor(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, + input: "DTensor", + grad_placements: Optional[Sequence[Placement]], + ): + ctx.dtensor_spec = input._spec + ctx.grad_placements = grad_placements + local_tensor = input._local_tensor + + # We need to return a fresh Tensor object there as autograd metadata + # will be inplaced into it. So we don't want to pollute the Tensor + # object stored in the _local_tensor of this DTensor. + return local_tensor.view_as(local_tensor) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] + dtensor_spec = ctx.dtensor_spec + mesh = dtensor_spec.mesh + grad_placements = ctx.grad_placements + dtensor_meta = dtensor_spec.tensor_meta + + _, tensor_stride = compute_global_tensor_info( + grad_output, mesh, dtensor_spec.placements + ) + tensor_stride = tuple(tensor_stride) + grad_placements = grad_placements or dtensor_spec.placements + grad_spec = DTensorSpec( + mesh, + grad_placements, + tensor_meta=TensorMeta( + shape=dtensor_meta.shape, + stride=tensor_stride, + dtype=dtensor_meta.dtype, + ), + ) + + return ( + DTensor( + grad_output, + grad_spec, + requires_grad=grad_output.requires_grad, + ), + None, + ) + + +class _FromTorchTensor(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, # pyre-ignore[2]: Parameter must be annotated. 
+ input: torch.Tensor, + device_mesh: DeviceMesh, + placements: Tuple[Placement, ...], + run_check: bool, + shape: Optional[torch.Size] = None, + stride: Optional[Tuple[int, ...]] = None, + ) -> "DTensor": + ctx.previous_placement = placements + ctx.previous_device_mesh = device_mesh + + if shape and stride: + tensor_shape, tensor_stride = shape, stride + elif not shape and not stride: + # if it's not by default run_check, we assume user is certain that each + # rank has the same tensor shape, and we just use that to calculate the + # global shape + global_shape, global_stride = compute_global_tensor_info( + input, device_mesh, placements + ) + tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride) + else: + raise RuntimeError( + f"Found shape:{shape}, stride:{stride}.", + "Please pass both shape and stride at the same time.", + ) + + if device_mesh.get_coordinate() is None: + # if the global rank is not participating in the device mesh, we + # simply set the local tensor to an empty tensor + input = input.new_empty(0, requires_grad=input.requires_grad) + elif run_check: + # TODO: support uneven sharding when global shape/stride not passed, by + # building the global TensorMeta during check_tensor_meta + check_shape_stride = not shape and not stride + check_tensor_meta(input, check_shape_stride=check_shape_stride) + # TODO: See if we need to make this run_check logic + # have a corresponding backward. + for idx, placement in enumerate(placements): + if placement.is_replicate(): + # broadcast rank 0 tensor to all ranks + # only broadcast if run_check is True + input = input.contiguous() + mesh_broadcast(input, device_mesh, mesh_dim=idx) + + dist_spec = DTensorSpec( + device_mesh, + placements, + tensor_meta=TensorMeta( + tensor_shape, + tensor_stride, + input.dtype, + ), + ) + + # We want a fresh Tensor object that shares memory with the input tensor + dist_tensor = DTensor( + input.view_as(input), + dist_spec, + # requires_grad of the dist tensor depends on if input + # requires_grad or not + requires_grad=input.requires_grad, + ) + return dist_tensor + + @staticmethod + def backward(ctx, grad_output: "DTensor"): # type: ignore[override] + previous_placement = ctx.previous_placement + previous_device_mesh = ctx.previous_device_mesh + + # reshard to the placement when creating DistributedTensor + # so that the gradient layout matches, and we could return + # local gradients directly + if grad_output.placements != previous_placement: + current_spec = grad_output._spec + target_spec = DTensorSpec( + previous_device_mesh, + previous_placement, + tensor_meta=grad_output._spec.tensor_meta, + ) + local_tensor = grad_output._local_tensor + output = redistribute_local_tensor( + local_tensor, current_spec, target_spec, is_backward=True + ) + # TODO: return the redistributed local tensor directly without + # differentiable backward. see if this make sense for all cases. + return output, None, None, None, None, None + + # TODO: backward is also differentiable now, add a test + # to test higher level gradients. + return grad_output.to_local(), None, None, None, None, None + + +class DTensor(torch.Tensor): + """ + ``DTensor`` (Distributed Tensor) is a subclass of ``torch.Tensor`` that provides single-device like + abstraction to program with multi-device ``torch.Tensor``. 
It describes the distributed tensor sharding + layout (DTensor Layout) through the :class:`DeviceMesh` and following types of :class:`Placement`: + + * :class:`Shard`: Tensor sharded on the tensor dimension ``dim`` on the devices of the ``DeviceMesh`` dimension + * :class:`Replicate`: Tensor replicated on the devices of the ``DeviceMesh`` dimension + * :class:`Partial`: Tensor is pending reduction on the devices of the ``DeviceMesh`` dimension + + When calling PyTorch operators, ``DTensor`` overrides the PyTorch operators to perform sharded computation and issue + communications whenever necessary. Along with the operator computation, ``DTensor`` will transform or propagate the + placements (DTensor Layout) properly (based on the operator semantic itself) and generate new ``DTensor`` outputs. + + To ensure numerical correctness of the ``DTensor`` sharded computation when calling PyTorch operators, ``DTensor`` + requires every Tensor argument of the operator be DTensor. + + """ + + _local_tensor: torch.Tensor + _spec: DTensorSpec + __slots__ = ["_local_tensor", "_spec"] + + # _op_dispatcher instance as a class attribute to handle runtime dispatching logic + _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher() + + @staticmethod + @torch._disable_dynamo + def __new__( + cls, + local_tensor: torch.Tensor, + spec: DTensorSpec, + *, + requires_grad: bool, + ) -> "DTensor": + """ + Construct a DTensor from a local tensor, device mesh, and placement and + other tensor properties (i.e. shape, requires_grad, strides, etc). + + .. note:: This is not a public API and it's only supposed to be used by the + operator implementations and internals. If you want to construct a + DTensor from a local tensor, consider using ``DTensor.from_local``, if + you want to construct a DTensor from a "global" tensor (where you + already have tensor initialized and want to shard this tensor), + consider using ``distribute_tensor``. + """ + if local_tensor.requires_grad and not requires_grad: + warnings.warn( + "To construct DTensor from torch.Tensor, it's recommended to " + "use local_tensor.detach() and make requires_grad consistent." + ) + + # new method instruct wrapper tensor from local_tensor and add + # placement spec, it does not do actual distribution + assert spec.tensor_meta is not None, "TensorMeta should not be None!" + r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + spec.tensor_meta.shape, + strides=spec.tensor_meta.stride, + dtype=local_tensor.dtype, + device=local_tensor.device, + layout=local_tensor.layout, + requires_grad=requires_grad, + ) + + r._spec = spec + r._local_tensor = local_tensor + return r + + # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently. + # pyre-fixme[3]: Return type must be annotated. + def __repr__(self): + # TODO: consider all_gather the local tensors for better debugging + return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" + + def __tensor_flatten__(self): + """ + protocol to inform how to flatten a DTensor to local tensor + for PT2 tracing + """ + return ["_local_tensor"], (self._spec, self.requires_grad) + + @staticmethod + def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): + assert ( + flatten_spec is not None + ), "Expecting spec to be not None from `__tensor_flatten__` return value!" 
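+ # Rebuild the DTensor wrapper: mesh and placements come from the flattened
+ # spec, while shape/stride come from the outer context, which may differ
+ # (e.g. symbolic shapes under torch.compile).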
+ local_tensor = inner_tensors["_local_tensor"] + spec, requires_grad = flatten_spec + unflatten_tensor_meta = TensorMeta( + shape=outer_size, + stride=outer_stride, + dtype=spec.tensor_meta.dtype, + ) + unflatten_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=unflatten_tensor_meta, + ) + return DTensor( + local_tensor, + unflatten_spec, + requires_grad=requires_grad, + ) + + def __coerce_tangent_metadata__(self): + if not any(isinstance(p, Partial) for p in self.placements): + return self + placements = [ + Replicate() if isinstance(p, Partial) else p for p in self.placements + ] + return self.redistribute(device_mesh=self.device_mesh, placements=placements) + + def __coerce_same_metadata_as_tangent__(self, flatten_spec): + (spec, _) = flatten_spec # Result of tensor_flatten() + return self.redistribute( + device_mesh=self.device_mesh, + placements=spec.placements, + ) + + @classmethod + @torch._disable_dynamo + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + return DTensor._op_dispatcher.dispatch( + func, + args, + kwargs or {}, + ) + + @staticmethod + def from_local( + local_tensor: torch.Tensor, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + *, + run_check: bool = False, + shape: Optional[torch.Size] = None, + stride: Optional[Tuple[int, ...]] = None, + ) -> "DTensor": + """ + Create a :class:`DTensor` from a local torch.Tensor on each rank + according to the ``device_mesh`` and ``placements`` specified. + + Args: + local_tensor (torch.Tensor): local torch.Tensor on each rank. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + tensor, if not specified, must be called under a DeviceMesh + context manager, default: None + placements (List[:class:`Placement`], optional): the placements that + describes how to place the local torch.Tensor on DeviceMesh, must + have the same number of elements as ``device_mesh.ndim``. + + Keyword args: + run_check (bool, optional): at a cost of extra communications, perform + sanity check across ranks to check each local tensor's meta information + to ensure correctness. If have :class:`Replicate` in ``placements``, the + data on first rank of the device mesh dimension will be broadcasted + to other ranks. default: False + shape (torch.Size, optional): A List of int which specifies the size of + DTensor which build on top of `local_tensor`. Note this needs to be + provided if the shape of ``local_tensor`` are different across the ranks. + If not provided, ``shape`` will be computed assuming the given distributed + tensor is evenly sharded across ranks. default: None + stride (tuple, optional): A List of int which specifies the stride of DTensor. + If not provided, ``stride`` will be computed assuming the given distributed + tensor is evenly sharded across ranks. default: None + + Returns: + A :class:`DTensor` object + + .. note:: When ``run_check=False``, it is the user's responsibility to ensure the + local tensor passed in is correct across ranks (i.e. the tensor is sharded for + the ``Shard(dim)`` placement or replicated for the ``Replicate()`` placement). + If not, the behavior of the created DTensor is undefined. + + .. note:: ``from_local`` is differentiable, the `requires_grad` of the created + `DTensor` object will depend on if `local_tensor` requires_grad or not. 
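+
+ Example (a minimal sketch; assumes 4 ranks and a "cuda" backend)::
+
+ from torch.distributed.device_mesh import init_device_mesh
+ from torch.distributed.tensor import DTensor, Shard
+
+ mesh = init_device_mesh("cuda", (4,))
+ local_tensor = torch.randn(8, 16) # this rank's shard of the data
+ dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
+ # the global logical shape (32, 16) is inferred from the even sharding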
+ """ + # if same shape/dtype, no need to run_check, if not, must allgather + # the metadatas to check the size/dtype across ranks + # There should be no data communication unless there's replication + # strategy, where we broadcast the replication from the first rank + # in the mesh dimension + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + + # convert the local tensor to desired device base on device mesh's device_type + if device_type != local_tensor.device.type and not local_tensor.is_meta: + local_tensor = local_tensor.to(device_type) + + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(device_mesh.ndim)] + else: + placements = list(placements) + for idx, placement in enumerate(placements): + # normalize shard dim to be positive + if placement.is_shard(): + placement = cast(Shard, placement) + if placement.dim < 0: + placements[idx] = Shard(placement.dim + local_tensor.ndim) + + # `from_local` is differentiable, and the gradient of the dist tensor this function + # created should flow back the gradients to the local_tensor, so we call an autograd + # function to construct the dist tensor instead. + return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func + local_tensor, + device_mesh, + tuple(placements), + run_check, + shape, + stride, + ) + + def to_local( + self, *, grad_placements: Optional[Sequence[Placement]] = None + ) -> torch.Tensor: + """ + Get the local tensor of this DTensor on its current rank. For sharding it returns + a local shard of the logical tensor view, for replication it returns the replica on + its current rank. + + Keyword args: + grad_placements (List[:class:`Placement`], optional): the placements describes + the future layout of any gradient layout of the Tensor returned from this + function. + `to_local` converts DTensor to local tensor and the returned local tensor + might not be used as the original DTensor layout later in the code. This + argument is the hint that user can give to autograd in case the gradient + layout of the returned tensor does not match the original DTensor layout. + If not specified, we will assume the gradient layout remains the same + as the original DTensor and use that for gradient computation. + + Returns: + A :class:`torch.Tensor` or ``AsyncCollectiveTensor`` object. it represents the + local tensor on its current rank. When an ``AsyncCollectiveTensor`` object is returned, + it means the local tensor is not ready yet (i.e. communication is not finished). In this + case, user needs to call ``wait`` to wait the local tensor to be ready. + + .. note:: ``to_local`` is differentiable, the ``requires_grad`` of the local tensor returned + will depend on if the `DTensor` requires_grad or not. + """ + if not torch.is_grad_enabled(): + return self._local_tensor + + if grad_placements is not None and not isinstance(grad_placements, tuple): + grad_placements = tuple(grad_placements) + return _ToTorchTensor.apply( + self, grad_placements + ) # pyre-ignore[16]: autograd func + + def redistribute( + self, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + *, + async_op: bool = False, + ) -> "DTensor": + """ + ``redistribute`` performs necessary collective operations that redistribute the current + DTensor from its current placements to a new placements, or from is current DeviceMesh + to a new DeviceMesh. i.e. 
we can turn a Sharded DTensor to a Replicated DTensor by
+ specifying a Replicate placement for each dimension of the DeviceMesh.
+
+ When redistributing from the current to the new placements on one device mesh dimension, we
+ will perform the following operations, including a communication collective or a local operation:
+
+ 1. ``Shard(dim)`` -> ``Replicate()``: ``all_gather``
+ 2. ``Shard(src_dim)`` -> ``Shard(dst_dim)``: ``all_to_all``
+ 3. ``Replicate()`` -> ``Shard(dim)``: local chunking (i.e. ``torch.chunk``)
+ 4. ``Partial()`` -> ``Replicate()``: ``all_reduce``
+ 5. ``Partial()`` -> ``Shard(dim)``: ``reduce_scatter``
+
+
+ ``redistribute`` would correctly figure out the necessary redistribute steps for DTensors
+ that are created either on a 1-D or an N-D DeviceMesh.
+
+ Args:
+ device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
+ DTensor. If not specified, it would use the current DTensor's DeviceMesh.
+ default: None
+ placements (List[:class:`Placement`], optional): the new placements that
+ describe how to place the DTensor into the DeviceMesh, must
+ have the same number of elements as ``device_mesh.ndim``.
+ default: replicate on all mesh dimensions
+
+ Keyword args:
+ async_op (bool, optional): whether to perform the DTensor redistribute operation
+ asynchronously or not. Default: False
+
+ Returns:
+ A :class:`DTensor` object
+
+ .. note:: ``redistribute`` is differentiable, which means users do not need to worry about
+ the backward formula of the redistribute operation.
+
+ .. note:: ``redistribute`` currently only supports redistributing DTensor on the same DeviceMesh.
+ Please file an issue if you need to redistribute DTensor to a different DeviceMesh.
+ """
+ # NOTE: This redistribute API currently only supports out
+ # of place redistribution, i.e. it always creates a new
+ # DTensor object and leaves the original one unchanged.
+
+ # if device_mesh is not specified, use the current device_mesh
+ device_mesh = device_mesh or self.device_mesh
+ # raise error if new placements not specified
+ if placements is None:
+ raise RuntimeError("placements is needed for redistribute!")
+
+ placements = list(placements)
+ for i, placement in enumerate(placements):
+ if placement.is_partial():
+ raise RuntimeError(
+ "Can not redistribute to Partial, redistributing to Partial is for internal use only!"
+ )
+ elif isinstance(placement, Shard) and placement.dim < 0:
+ # normalize shard dim to be positive
+ placements[i] = Shard(placement.dim + self.ndim)
+ placements = tuple(placements)
+
+ # pyre-fixme[16]: `Redistribute` has no attribute `apply`.
+ return Redistribute.apply(self, device_mesh, placements, async_op)
+
+ def full_tensor(
+ self, *, grad_placements: Optional[Sequence[Placement]] = None
+ ) -> torch.Tensor:
+ """
+ Return the full tensor of this DTensor. It will perform necessary collectives
+ to gather the local tensors from other ranks in its DeviceMesh and concatenate
+ them together. It's syntactic sugar for the following code:
+
+ ``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()``
+
+ Keyword args:
+ grad_placements (List[:class:`Placement`], optional): the placements that describe
+ the future layout of any gradient of the full Tensor returned from this
+ function.
+ `full_tensor` converts DTensor to a full torch.Tensor and the returned torch.Tensor
+ might not be used as the original replicated DTensor layout later in the code.
This
+ argument is the hint that the user can give to autograd in case the gradient
+ layout of the returned tensor does not match the original replicated DTensor layout.
+ If not specified, we will assume the gradient layout of the full tensor to be replicated.
+
+ Returns:
+ A :class:`torch.Tensor` object that represents the full tensor of this DTensor.
+
+ .. note:: ``full_tensor`` is differentiable.
+ """
+
+ redist_res = self.redistribute(
+ placements=[Replicate()] * self.device_mesh.ndim, async_op=False
+ )
+ return _ToTorchTensor.apply(redist_res, grad_placements)
+
+ @property
+ def device_mesh(self) -> DeviceMesh:
+ """
+ The :class:`DeviceMesh` attribute that associates with this DTensor object.
+
+ .. note:: ``device_mesh`` is a read-only property, it can not be set.
+ """
+ return self._spec.mesh
+
+ @property
+ def placements(self) -> Tuple[Placement, ...]:
+ """
+ The placements attribute of this DTensor that describes the layout of this
+ DTensor on its DeviceMesh.
+
+ .. note:: ``placements`` is a read-only property, it can not be set.
+ """
+ return self._spec.placements
+
+ def __create_write_items__(self, fqn: str, object: Any):
+ from torch.distributed.checkpoint.planner_helpers import (
+ _create_write_items_for_dtensor,
+ )
+
+ if hasattr(self._local_tensor, "__create_write_items__"):
+ return self._local_tensor.__create_write_items__(fqn, object) # type: ignore[attr-defined]
+ elif isinstance(self._local_tensor, torch.Tensor):
+ return [_create_write_items_for_dtensor(fqn, object)]
+ else:
+ raise RuntimeError("Unsupported tensor type!")
+
+ def __create_chunk_list__(self):
+ from torch.distributed.checkpoint.planner_helpers import (
+ _create_chunk_from_dtensor,
+ )
+
+ if hasattr(self._local_tensor, "__create_chunk_list__"):
+ return self._local_tensor.__create_chunk_list__() # type: ignore[attr-defined]
+ elif isinstance(self._local_tensor, torch.Tensor):
+ return [_create_chunk_from_dtensor(self)]
+ else:
+ raise RuntimeError("Unsupported tensor type!")
+
+ def __get_tensor_shard__(self, index):
+ if hasattr(self._local_tensor, "__get_tensor_shard__"):
+ return self._local_tensor.__get_tensor_shard__(index) # type: ignore[attr-defined]
+ elif isinstance(self._local_tensor, torch.Tensor):
+ return self.to_local()
+ else:
+ raise RuntimeError("Unsupported tensor type!")
+
+
+def distribute_tensor(
+ tensor: torch.Tensor,
+ device_mesh: Optional[DeviceMesh] = None,
+ placements: Optional[Sequence[Placement]] = None,
+) -> DTensor:
+ """
+ Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) to the ``device_mesh`` according
+ to the ``placements`` specified. The rank of ``device_mesh`` and ``placements`` must be the
+ same. The ``tensor`` to distribute is the logical or "global" tensor, and the API would use
+ the ``tensor`` from the first rank of the DeviceMesh dimension as the source of truth to preserve
+ the single-device semantic. If you want to construct a DTensor in the middle of the Autograd
+ computation, please use :meth:`DTensor.from_local` instead.
+
+ Args:
+ tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you
+ want to shard a tensor on a dimension that is not evenly divisible by
+ the number of devices in that mesh dimension, we use ``torch.chunk``
+ semantic to shard the tensor and scatter the shards. The uneven sharding
+ behavior is experimental and subject to change.
+ device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the
+ tensor, if not specified, must be called under a DeviceMesh context
+ manager, default: None
+ placements (List[:class:`Placement`], optional): the placements that
+ describe how to place the tensor on DeviceMesh, must have the same
+ number of elements as ``device_mesh.ndim``. If not specified, we will
+ by default replicate the tensor across the ``device_mesh`` from the
+ first rank of each dimension of the `device_mesh`.
+
+ Returns:
+ A :class:`DTensor` or ``XLAShardedTensor`` object.
+
+ .. note::
+ When initializing the DeviceMesh with the ``xla`` device_type, ``distribute_tensor``
+ returns `XLAShardedTensor` instead. See `this issue `__
+ for more details. The XLA integration is experimental and subject to change.
+ """
+
+ torch._C._log_api_usage_once("torch.dtensor.distribute_tensor")
+
+ # get default device mesh if there's nothing specified
+ device_mesh = device_mesh or _mesh_resources.get_current_mesh()
+ device_type = device_mesh.device_type
+ if device_type == "xla":
+ try:
+ # call PyTorch/XLA SPMD for `xla` backend type device mesh.
+ # This returns XLAShardedTensor
+ from torch_xla.distributed.spmd import ( # type:ignore[import]
+ xla_distribute_tensor,
+ )
+
+ return xla_distribute_tensor(
+ tensor, device_mesh, placements
+ ) # type:ignore[return-value]
+ except ImportError as e:
+ msg = "To use DTensor API with xla, you must install the torch_xla package!"
+ raise ImportError(msg) from e
+
+ # instantiate a RNG tracker if we haven't. By default DTensor uses an
+ # OffsetBasedRNGTracker to perform random operators.
+ # TODO: the value assignment to global variable is not the ideal solution
+ # we can replace it in future.
+ if not random._rng_tracker and is_rng_supported_mesh(device_mesh):
+ random._rng_tracker = OffsetBasedRNGTracker(device_type)
+
+ if not tensor.is_leaf:
+ raise RuntimeError(
+ "`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!"
+ )
+
+ # convert tensor to the corresponding device type if it's not in that device type
+ if device_type != tensor.device.type and not tensor.is_meta:
+ tensor = tensor.to(device_type)
+
+ # set default placements to replicated if not specified
+ if placements is None:
+ placements = [Replicate() for _ in range(device_mesh.ndim)]
+
+ if len(placements) != device_mesh.ndim:
+ raise ValueError(
+ f"`placements` must have the same length as `device_mesh.ndim`! "
+ f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}."
+ )
+ if isinstance(tensor, DTensor):
+ # if the tensor is already a DTensor, we need to check:
+ # 1. if we can further shard this DTensor, if the two device meshes belong to
+ # the same parent mesh and further sharding is possible.
+ # 2. check if device mesh and placements are the same
+ if tensor.device_mesh != device_mesh:
+ raise ValueError(
+ f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} "
+ f"to a different device mesh {device_mesh}."
+ )
+ if tensor.placements != tuple(placements):
+ raise ValueError(
+ f"Cannot distribute a DTensor with placements {tensor.placements} "
+ f"to a different placements {placements}. Do you want to call "
+ f"`redistribute` instead?"
+ )
+ return tensor
+
+ local_tensor = tensor.detach()
+
+ # TODO(xilun): address sharding order
+ # distribute the tensor according to the placements.
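+ # Placements are applied one mesh dimension at a time: each iteration below
+ # shards or replicates the current local tensor along mesh dim `idx`, so on
+ # an N-D mesh the local shard is narrowed dimension by dimension.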
+def distribute_module(
+    module: nn.Module,
+    device_mesh: Optional[DeviceMesh] = None,
+    partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None,
+    input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
+    output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
+) -> nn.Module:
+    """
+    This function exposes three hooks (``partition_fn``, ``input_fn``, ``output_fn``) to
+    control the parameters/inputs/outputs of the module:
+
+    1. To perform sharding on the module before runtime execution by specifying the
+       ``partition_fn`` (i.e. allow the user to convert Module parameters to :class:`DTensor`
+       parameters according to the ``partition_fn`` specified).
+    2. To control the inputs or outputs of the module during runtime execution by
+       specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to
+       :class:`DTensor`, convert the output back to ``torch.Tensor``)
+
+    Args:
+        module (:class:`nn.Module`): user module to be partitioned.
+        device_mesh (:class:`DeviceMesh`): the device mesh to place the module.
+        partition_fn (Callable): the function to partition parameters (i.e. shard certain
+            parameters across the ``device_mesh``). If ``partition_fn`` is not specified,
+            by default we replicate all module parameters of ``module`` across the mesh.
+        input_fn (Callable): specify the input distribution, i.e. could control how the
+            input of the module is sharded. ``input_fn`` will be installed as a module
+            ``forward_pre_hook`` (pre forward hook).
+        output_fn (Callable): specify the output distribution, i.e. could control how the
+            output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be
+            installed as a module ``forward_hook`` (post forward hook).
+
+    Returns:
+        A module that contains parameters/buffers that are all ``DTensor`` s.
+
+    .. note::
+        When initializing the DeviceMesh with the ``xla`` device_type, ``distribute_module``
+        returns an nn.Module with PyTorch/XLA SPMD annotated parameters. See
+        `this issue `__
+        for more details. The XLA integration is experimental and subject to change.
+    """
+
+    torch._C._log_api_usage_once("torch.dtensor.distribute_module")
+
+    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
+    device_type = device_mesh.device_type
+    if device_type == "xla":
+        try:
+            # This function annotates all module parameters for auto-partitioning with
+            # PyTorch/XLA SPMD or explicitly partitions to :class:`XLAShardedTensor` parameters
+            # according to the `partition_fn` specified.
+            from torch_xla.distributed.spmd import (  # type:ignore[import]
+                xla_distribute_module,
+            )
+
+            return xla_distribute_module(
+                module, device_mesh, partition_fn, input_fn, output_fn
+            )  # type:ignore[return-value]
+        except ImportError as e:
+            msg = "To use DTensor API with xla, you must install the torch_xla package!"
+            raise ImportError(msg) from e
+
+    def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None:
+        # This function loops over the immediate module parameters and
+        # buffers, replicating all non-DTensor params/buffers to DTensor
+        # parameters/buffers, if they have not been partitioned in the
+        # partition_fn. We can't easily use `module._apply` here
+        # because we don't know what happened inside partition_fn as
+        # the user could do anything, i.e. install hooks, and we want to
+        # preserve those.
+        full_replicate = [Replicate()] * mesh.ndim
+        for key, param in m._parameters.items():
+            if param is not None and not isinstance(param, DTensor):
+                m.register_parameter(
+                    key,
+                    nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)),
+                )
+        for key, buffer in m._buffers.items():
+            if buffer is not None and not isinstance(buffer, DTensor):
+                m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate)
+
+    if partition_fn is None:
+        # if partition_fn is not specified, we by default replicate
+        # all module params/buffers
+        for name, submod in module.named_modules():
+            replicate_module_params_buffers(submod, device_mesh)
+    else:
+        # apply partition_fn to submodules
+        for name, submod in module.named_modules():
+            partition_fn(name, submod, device_mesh)
+            replicate_module_params_buffers(submod, device_mesh)
+
+    # register input_fn as module forward pre hook
+    if input_fn is not None:
+        # check the input_fn signature
+        num_args = len(inspect.signature(input_fn).parameters)
+        if num_args == 2:
+            # input_fn only takes in inputs and device mesh
+            warnings.warn(
+                "Deprecating input_fn that takes two arguments (inputs, device_mesh), "
+                "please use input_fn that takes in (module, inputs, device_mesh) instead!",
+                FutureWarning,
+                stacklevel=2,
+            )
+            module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh))  # type: ignore[call-arg]
+        elif num_args == 3:
+            # input_fn takes in module, inputs, device mesh
+            module.register_forward_pre_hook(
+                lambda mod, inputs: input_fn(mod, inputs, device_mesh)
+            )
+        else:
+            raise ValueError(
+                f"input_fn should take in 3 arguments, but got {num_args} arguments!"
+            )
+    # register output_fn as module forward hook
+    if output_fn is not None:
+        num_args = len(inspect.signature(output_fn).parameters)
+        if num_args == 2:
+            # output_fn only takes in outputs and device mesh
+            warnings.warn(
+                "Deprecating output_fn that takes two arguments (outputs, device_mesh), "
+                "please use output_fn that takes in (module, outputs, device_mesh) instead!",
+                FutureWarning,
+                stacklevel=2,
+            )
+            module.register_forward_hook(
+                lambda mod, inputs, outputs: output_fn(outputs, device_mesh)  # type: ignore[call-arg]
+            )
+        elif num_args == 3:
+            module.register_forward_hook(
+                lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)
+            )
+        else:
+            raise ValueError(
+                f"output_fn should take in 3 arguments, but got {num_args} arguments!"
+            )
+
+    return module
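+# Editor's sketch (not part of the upstream file): partitioning a toy module with a
+# hypothetical partition_fn that shards every Linear parameter on its first dim.
+#
+#   def shard_params(name: str, module: nn.Module, mesh: DeviceMesh) -> None:
+#       if isinstance(module, nn.Linear):
+#           for pname, param in module.named_parameters(recurse=False):
+#               dist_param = nn.Parameter(distribute_tensor(param, mesh, [Shard(0)]))
+#               module.register_parameter(pname, dist_param)
+#
+#   mesh = init_device_mesh("cuda", (4,))
+#   sharded = distribute_module(nn.Linear(16, 16), mesh, partition_fn=shard_params)
+#   # any params/buffers the partition_fn skipped are then replicated across the mesh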
+ + """ + + torch._C._log_api_usage_once("torch.dtensor.distribute_module") + + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + if device_type == "xla": + try: + # This function annotates all module parameters for auto-partitioning with + # PyTorch/XLA SPMD or explicitly partition to :class:`XLAShardedTensor` parameters + # according to the `partition_fn` specified. + from torch_xla.distributed.spmd import ( # type:ignore[import] + xla_distribute_module, + ) + + return xla_distribute_module( + module, device_mesh, partition_fn, input_fn, output_fn + ) # type:ignore[return-value] + except ImportError as e: + msg = "To use DTensor API with xla, you must install the torch_xla package!" + raise ImportError(msg) from e + + def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: + # This function loop over the immediate module parameters and + # buffers, replicate all non DTensor params/buffers to DTensor + # parameters/buffers, if they have not been partitioned in the + # partition_fn, we can't easily use `module._apply` here + # because we don't know what happened inside partition_fn as + # user could do anything, i.e. install hooks, and we want to + # preserve those. + full_replicate = [Replicate()] * mesh.ndim + for key, param in m._parameters.items(): + if param is not None and not isinstance(param, DTensor): + m.register_parameter( + key, + nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)), + ) + for key, buffer in m._buffers.items(): + if buffer is not None and not isinstance(buffer, DTensor): + m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate) + + if partition_fn is None: + # if partition_fn not specified, we by default replicate + # all module params/buffers + for name, submod in module.named_modules(): + replicate_module_params_buffers(submod, device_mesh) + else: + # apply partition_fun to submodules + for name, submod in module.named_modules(): + partition_fn(name, submod, device_mesh) + replicate_module_params_buffers(submod, device_mesh) + + # register input_fn as module forward pre hook + if input_fn is not None: + # check the input_fn signature + num_args = len(inspect.signature(input_fn).parameters) + if num_args == 2: + # input_fn only takes in inputs and device mesh + warnings.warn( + "Deprecating input_fn that takes two arguments (inputs, device_mesh), " + "please use input_fn that takes in (module, inputs, device_mesh) instead!", + FutureWarning, + stacklevel=2, + ) + module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg] + elif num_args == 3: + # input_fn takes in module, inputs, device mesh + module.register_forward_pre_hook( + lambda mod, inputs: input_fn(mod, inputs, device_mesh) + ) + else: + raise ValueError( + f"input_fn should take in 3 arguments, but got {num_args} arguments!" 
+
+
+def ones(  # type: ignore[no-untyped-def]
+    *size,
+    dtype: Optional[torch.dtype] = None,
+    layout: torch.layout = torch.strided,
+    requires_grad: bool = False,
+    device_mesh: Optional[DeviceMesh] = None,
+    placements: Optional[Sequence[Placement]] = None,
+) -> DTensor:
+    """
+    Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined
+    by the variable argument ``size``.
+ + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.ones, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def empty( # type: ignore[no-untyped-def] + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor` + is defined by the variable argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).\ + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.empty, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def full( # type: ignore[no-untyped-def] + size, + fill_value, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh`` and + ``placements``, with the shape defined by the argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + fill_value(Scalar): the value to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. 
+ Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.full, + torch_size, + fill_value=fill_value, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def rand( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a uniform distribution + on the interval ``[0, 1)``. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.rand, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def randn( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a normal distribution + with mean 0 and variance 1. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. 
+ device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.randn, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def zeros( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with the scalar value 0. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..)) + Keyword args: + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.zeros, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_collective_utils.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_collective_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..34795858356ab88e973b228998aa8b51d7e7868b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_collective_utils.py @@ -0,0 +1,373 @@ +# mypy: allow-untyped-defs +import logging +import math +from dataclasses import dataclass +from functools import lru_cache +from typing import List, Optional + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor._dtensor_spec as dtensor_spec +from torch._C._distributed_c10d import _resolve_process_group +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +from torch.distributed.distributed_c10d import ( + _get_group_size_by_name, + broadcast, + get_global_rank, + get_group_rank, + get_rank, + GroupMember, + ProcessGroup, + scatter, + Work, +) + + +logger = logging.getLogger(__name__) + + +if not torch._running_with_deploy(): + + @torch.library.register_fake("_dtensor::shard_dim_alltoall") + def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): + group_size = _get_group_size_by_name(group_name) + stacked_list = [torch.empty_like(input) for _ in range(group_size)] + group = _resolve_process_group(group_name) + group_rank = get_group_rank(group, get_rank()) + + return torch.cat(stacked_list, dim=gather_dim).chunk(group_size, dim=shard_dim)[ + group_rank + ] + +else: + import warnings + + warnings.warn( + "PyTorch Distributed 
functional collectives do not work with torch::deploy." + ) + + +def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim): + if mesh.device_type == "cpu": + # Gloo does not support alltoall, so falling back to allgather + chunk + + # TODO: This logs way too much + logger.warning( + "CPU process group does not support alltoall yet, falling back with allgather + chunk!" + ) + out = funcol.all_gather_tensor(input, gather_dim, (mesh, mesh_dim)) + if isinstance(out, funcol.AsyncCollectiveTensor): + # stick to the same behavior for the alltoall case, remove this once we enable alltoall async + out = out.wait() + out = torch.chunk(out, mesh.size(mesh_dim), dim=shard_dim)[ + mesh.get_local_rank(mesh_dim) + ] + return out.contiguous() if not out.is_contiguous() else out + + group_name = funcol._resolve_group_name((mesh, mesh_dim)) + # TODO: enable async op for shard_dim_alltoall + return torch.ops._dtensor.shard_dim_alltoall( + input, gather_dim, shard_dim, group_name + ) + + +def mesh_scatter( + output: torch.Tensor, + scatter_list: List[torch.Tensor], + mesh: DeviceMesh, + mesh_dim: int = 0, + async_op: bool = False, +) -> Optional[Work]: + """ + scatter a list of tensors to a device mesh dimension. We by default + use the first rank of the mesh dimension as the source of truth, i.e + for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will + scatter the tensor list on rank 0 to rank 0/1, and tensor list on rank + 2 to rank 2/3. + + Args: + output (torch.Tensor): the tensor to receive the scattered list. + scatter_list (List[torch.Tensor]): the tensor list to be scattered. + mesh_dim (int, optional): indicate which mesh dimension we want + to scatter on, we by default choose the first rank on the + mesh dimension as source of truth. + + Returns: + A :class:`Work` object + """ + # TODO: Ideally we should use the meta tensor way + # (to register a meta kernel for the collective op) + # so that it would avoid the communication. Need to + # remove the check below once that is done. + if output.is_meta: + return None + dim_group = mesh.get_group(mesh_dim) + assert isinstance(dim_group, ProcessGroup) + # src need to be global rank + src_for_dim = 0 + + if dim_group is not GroupMember.WORLD: + src_for_dim = get_global_rank(dim_group, 0) + + if src_for_dim == get_rank(): + fut = scatter( + output, + scatter_list=scatter_list, + src=src_for_dim, + group=dim_group, + async_op=async_op, + ) + else: + fut = scatter( + output, + scatter_list=None, + src=src_for_dim, + group=dim_group, + async_op=async_op, + ) + + return fut + + +def mesh_broadcast( + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int = 0, + async_op: bool = False, +) -> Optional[Work]: + """ + broadcast the tensor to a device mesh dimension. We by default + use the first rank of the mesh dimension as the source of truth, i.e + for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will + broadcast the tensor on rank 0 to rank 0/1, and tensor on rank 2 + to rank 2/3. + + Args: + tensor (torch.Tensor): tensor to broadcast. + mesh_dim (int, optional): indicate which mesh dimension we want + to scatter on, we by default choose the first rank on the + mesh dimension as source of truth. + + Returns: + A :class:`Work` object + """ + # TODO: Ideally we should use the meta tensor way + # (to register a meta kernel for the collective op) + # so that it would avoid the communication. Need to + # remove the check below once that is done. 
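+    # Editor's sketch (not part of the upstream file): on a DeviceMesh laid out as
+    # [[0, 1], [2, 3]], mesh_broadcast(t, mesh, mesh_dim=1) uses ranks 0 and 2 as the
+    # sources, e.g.:
+    #
+    #   t = torch.full((4,), float(get_rank()))
+    #   mesh_broadcast(t, mesh, mesh_dim=1)  # t on rank 1 becomes 0.0s, on rank 3 becomes 2.0s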
+ if tensor.is_meta: + return None + dim_group = mesh.get_group(mesh_dim) + assert isinstance(dim_group, ProcessGroup) + # src need to be global rank + src_for_dim = 0 + if dim_group is not GroupMember.WORLD: + src_for_dim = get_global_rank(dim_group, 0) + + return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op) + + +def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor: + if pad_size == 0: + return tensor + pad = [0, 0] * (tensor.ndim - pad_dim) + pad[-1] = pad_size + return torch.nn.functional.pad(tensor, pad) + + +def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor: + if pad_size == 0: + return tensor + return tensor.narrow( + pad_dim, + start=0, + length=tensor.size(pad_dim) - pad_size, + ) + + +def fill_empty_tensor_to_shards( + shards: List[torch.Tensor], shard_dim: int, num_empty_tensors: int +) -> List[torch.Tensor]: + if num_empty_tensors == 0: + return shards + tensor_size = list(shards[0].size()) + tensor_size = [ + size if idx != shard_dim else 0 for idx, size in enumerate(tensor_size) + ] + tensor = shards[0].new_zeros(tensor_size) + for _ in range(num_empty_tensors): + shards.append(tensor) + return shards + + +def check_tensor_meta( + local_tensor, check_shape_stride=False +) -> Optional["dtensor_spec.TensorMeta"]: + local_metadata = { + "dtype": local_tensor.dtype, + "requires_grad": local_tensor.requires_grad, + } + + if check_shape_stride: + local_metadata.update( + {"shape": local_tensor.shape, "stride": local_tensor.stride()} + ) + + gathered_metadata = [None for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather_object(gathered_metadata, local_metadata) + + # Check if metadata is consistent across ranks + if not all(meta == local_metadata for meta in gathered_metadata): + raise ValueError( + "Inconsistent tensor metadata (including shape and stride) across ranks." + ) + return None + + +def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int: + assert spec.tensor_meta is not None, "spec should have tensor meta defined!" + return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape) + + +@dataclass +class MeshTopoInfo: + """ + Mesh information for collective cost estimation + """ + + mesh: DeviceMesh + mesh_dim_devices: List[int] + mesh_dim_bandwidth: List[float] + mesh_dim_latency: List[float] + + @staticmethod + @lru_cache(None) + def build_from_mesh(mesh: DeviceMesh) -> "MeshTopoInfo": + # Generate mesh topology info for intra-host/inter-host communication pattern + # Note that we made bunch of assumptions for simplicity: + # 1. we assume the mesh is homogeneous, and it's gpu/nccl model + # 2. we assume gpu arch is Ampere or Hopper + # 3. we assume collectives are all ring base algo for now + num_devices_per_host = _mesh_resources.num_devices_per_host(mesh.device_type) + # the base bw number (intra-node), GB/s + base_bw = 87.7 + mesh_dim_bandwidth = [base_bw] * mesh.ndim + # the latency in terms of us (intra-node, nv-link) + mesh_dim_latency = [0.6] * mesh.ndim + mesh_dim_devices = [1] * mesh.ndim + + total_num_devices = 1 + for mesh_dim in reversed(range(mesh.ndim)): + num_devices = mesh.size(mesh_dim) + mesh_dim_devices[mesh_dim] = num_devices + total_num_devices *= num_devices + if total_num_devices > num_devices_per_host: + # magic number for inter-host communication bandwidth/latency factor + # This number assumes latest GPU arch, i.e. 
Ampere or Hopper
+                # TODO: see if we need to tweak this or offer a way for the user
+                # to specify the bandwidths/latency
+                mesh_dim_bandwidth[mesh_dim] *= 0.22
+                # set to ethernet latency for inter-host
+                mesh_dim_latency[mesh_dim] = 2.7
+
+        return MeshTopoInfo(
+            mesh, mesh_dim_devices, mesh_dim_bandwidth, mesh_dim_latency
+        )
+
+
+def allgather_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
+    num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
+    mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
+    num_hops = num_devices_on_mesh_dim - 1
+    # base latency + comm latency
+    latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]  # us
+    bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth  # s
+    return latency + bw * 1e6  # rescale to us
+
+
+def allreduce_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
+    num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
+    mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
+    # allreduce has almost 2x the comm bytes compared to allgather/reduce_scatter
+    num_hops = 2 * num_devices_on_mesh_dim - 1
+
+    latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]
+    bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth
+    return latency + bw * 1e6
+
+
+def reduce_scatter_cost(
+    bytes_gb: float,
+    mesh_topo: MeshTopoInfo,
+    mesh_dim: int,
+) -> float:
+    num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
+    mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
+    num_hops = num_devices_on_mesh_dim - 1
+    # base latency + comm latency
+    latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]
+    bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth
+    return latency + bw * 1e6
+
+
+def redistribute_cost(
+    current_spec: "dtensor_spec.DTensorSpec",
+    target_spec: "dtensor_spec.DTensorSpec",
+) -> float:
+    """
+    This function returns the cost of redistributing from the current to the target DTensorSpec.
+
+    NOTE:
+    1. Only consider communication cost here, since computation costs for redistribute
+       are quite trivial (i.e. we only need narrowing or simple division)
+    2. Only consider redistribute cost on the same mesh; cross-mesh communication cost is
+       not needed for operator strategy estimation/selection.
+    """
+    if current_spec.mesh != target_spec.mesh:
+        # make the cost infinite if the meshes are not the same
+        # TODO: see if we want to support this once there's cross mesh communication
+        return float("inf")
+
+    if current_spec.is_replicated():
+        # short-cut:
+        # comm cost is 0 if the current spec is already fully replicated
+        return 0.0
+
+    mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
+    cost = 0.0
+    comm_bytes_gb = (
+        spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
+    )
+    # Transformations considered for redistribute cost:
+    # 1. allgather 2. alltoall
+    # 3. allreduce 4. 
reduce_scatter + for i, (current, target) in enumerate( + zip(current_spec.placements, target_spec.placements) + ): + if current == target: + continue + + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i] + if current.is_shard() and target.is_replicate(): + # allgather gives larger comm bytes + comm_bytes_gb *= num_devices_on_mesh_dim + # add up allgather comm cost + cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + elif current.is_shard() and target.is_shard(): + # should be alltoall comm, since we haven't implement it yet, add penalty + # to favor allgather instead + cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + 1.0 + elif current.is_partial() and target.is_replicate(): + # add up allreduce comm cost + cost += allreduce_cost(comm_bytes_gb, mesh_topo, i) + elif current.is_partial() and target.is_shard(): + # add up reduce_scatter comm cost + cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i) + # after reduce_scatter the comm bytes for further collectives halved. + comm_bytes_gb /= num_devices_on_mesh_dim + elif current.is_shard() and target.is_partial(): + # ban shard -> partial as it does not make sense to perform + # this redistribute + return float("inf") + + return cost diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..4579a16826d0fbfadb8bc4ccc75c352cd01581fa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py @@ -0,0 +1,510 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import contextlib +import functools +import logging +import operator +import warnings +from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING + +import torch +import torch.distributed as dist +import torch.distributed.tensor._api as dtensor +import torch.distributed.tensor._random as random +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( + _is_inplace_op, + _is_out_variant_op, + OpInfo, + OpSchema, + OutputSpecType, +) +from torch.distributed.tensor._random import is_rng_supported_mesh +from torch.distributed.tensor._redistribute import redistribute_local_tensor +from torch.distributed.tensor._sharding_prop import ShardingPropagator +from torch.distributed.tensor._tp_conv import ( + convolution_backward_handler, + convolution_handler, +) +from torch.distributed.tensor._utils import try_find_mesh_from_args +from torch.distributed.tensor.placement_types import Partial, Placement, Replicate + + +if TYPE_CHECKING: + from torch.distributed.device_mesh import DeviceMesh + +try: + from torch.utils import _cxx_pytree as pytree +except ImportError: + from torch.utils import _pytree as pytree # type: ignore[no-redef] + +aten = torch.ops.aten +logger = logging.getLogger(__name__) + + +def decompose_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> object: + """ + Decomposes a op to core ATen op, this handler is mostly here + for inference mode usage where the ops are not core aten ops. 
+ """ + r = op_call.decompose(*args, **kwargs) + if r is not NotImplemented: + return r + else: + raise RuntimeError("Decomposition failed") + + +def is_same_size_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> bool: + lhs = cast(torch.Tensor, args[0]) + rhs = cast(torch.Tensor, args[1]) + return lhs.shape == rhs.shape + + +def found_inf_reduce_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> None: + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + local_tensor_args = pytree.tree_unflatten( + cast(List[object], op_info.local_args), op_info.args_tree_spec + ) + local_tensor_args = cast(Tuple[object, ...], local_tensor_args) + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + + grad_dtensor = cast(list[dtensor.DTensor], args[0])[0] + grad_placements = grad_dtensor.placements + mesh = grad_dtensor.device_mesh + + found_inf_placements: list[Placement] = [] + for placement in grad_placements: + if isinstance(placement, Replicate): + found_inf_placements.append(placement) + else: + found_inf_placements.append(Partial("max")) + + target_tensor = cast(torch.Tensor, args[1]) + spec = DTensorSpec( + mesh=mesh, + placements=tuple(found_inf_placements), + tensor_meta=TensorMeta( + shape=target_tensor.size(), + stride=target_tensor.stride(), + dtype=target_tensor.dtype, + ), + ) + found_inf_dtensor = dtensor.DTensor( + local_tensor=target_tensor, spec=spec, requires_grad=False + ) + found_inf = found_inf_dtensor.full_tensor() + target_tensor.copy_(found_inf) + + +class OpDispatcher: + """ + Op dispatching class instance to handle args/kwargs pre-processing (un-wrapping), sharding + propagation, redistribute local args, local compute, and post-processing (re-wrapping). It + also handles any op specific logic if necessary. + + NOTE: Given the runtime overhead of Tensor subclass (__torch_dispatch__), the OpDispatcher + is designed to minimize the CPU overhead by using the tricks of proper unflattening, faster + pytree if needed, and leveraging various caching mechanisms implemented in the sharding + propagation and redistribute modules. The CPU overhead is critical to eager mode performance, + one need to carefully measure the CPU overhead when making significant changes to the + OpDispatcher and ShardingPropagator. + """ + + def __init__(self) -> None: + self.sharding_propagator = ShardingPropagator() + self._random_ops = { + aten.native_dropout.default, + aten.normal_.default, + aten.rand_like.default, + aten.randn_like.default, + aten.randint_like.default, + aten.randint_like.low_dtype, + aten.randint_like.low_dtype_out, + aten.uniform_.default, + aten.bernoulli.default, + aten.bernoulli_.float, + } + self._custom_op_handlers = { + aten.linear.default: decompose_handler, + aten.is_same_size.default: is_same_size_handler, + aten.convolution.default: convolution_handler, + aten.convolution_backward.default: convolution_backward_handler, + aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler, + } + + # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor) + # as implicitly replicated or we throw error to user. + # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave + # it as False by default. 
+ self._allow_implicit_replication = False + + def dispatch( + self, + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], + ) -> object: + """ + Main dispatching logic + """ + # operators that does not need to go through sharding propagation + if op_call in self._custom_op_handlers: + return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator] + + # extract local tensor and sharding infos to a OpInfo + op_info = self.unwrap_to_op_info(op_call, args, kwargs) + logger.debug("Dispatching op_call: %s", op_info.schema) + + self.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + logger.debug("output_sharding for %s: %s", op_call, output_sharding) + assert output_sharding is not None, "output sharding should not be None" + + mesh = op_info.mesh + if mesh.get_coordinate() is not None: + # computation that happens in the current rank of the mesh, normal case + if output_sharding.needs_redistribute: + # If sharding propagation decision needs redistribute, perform redistribute + # on args first, which could potentially modify args (i.e. allgather certain arg) + assert output_sharding.redistribute_schema is not None + self.redistribute_local_args( + op_info, output_sharding.redistribute_schema + ) + + local_tensor_args = ( + pytree.tree_unflatten( + cast(List[object], op_info.local_args), op_info.args_tree_spec + ) + if op_info.args_tree_spec + else op_info.local_args + ) + + # run local op computation with potentially modified args/kwargs + local_tensor_args = cast(Tuple[object, ...], local_tensor_args) + if op_call in self._random_ops: + if not random._rng_tracker and is_rng_supported_mesh(mesh): + # Default to `OffsetBasedRNGTracker` if the parallelism API + # did not already construct one + random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type) + + first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast( + torch.Tensor, local_tensor_args[0] + ) + rng_context = ( + random._rng_tracker._distribute_region(first_arg._spec) + if random._rng_tracker and not first_local_arg.is_meta + else contextlib.nullcontext() + ) + # For DTensor random operator, run it within a RNGTracker context to + # ensure the random number generator is properly distributed. + with rng_context: + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + else: + # normal case, run local sharded op computation + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + + else: + # For a non-participating device (happens on rank that does not belong to + # the device mesh), we do: + # 1. if the return type is scalar, set the local result to None. + # 2. if the return type is Tensor or List[Tensor], return empty + # tensor(s) with correct dtype. 
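+            # Editor's note (not part of the upstream file): "non-participating" means
+            # mesh.get_coordinate() returned None, i.e. this global rank is not part of
+            # the device mesh, so no local computation runs and only a correctly-typed
+            # placeholder result is produced below.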
+ spec = output_sharding.output_spec + ret_list = op_info.schema.op._schema.returns + + if spec is None: + # For a scalar return type, the non-participating device has None + # as its local result + local_results = None + else: + + def default_tensor(spec: DTensorSpec) -> torch.Tensor: + if spec.tensor_meta is not None: + shape = spec.tensor_meta.shape + dtype = spec.tensor_meta.dtype + if len(shape) == 0: + # scalar tensor + return torch.zeros((), dtype=dtype) + else: + # non-scalar tensor + return torch.tensor([], dtype=dtype) + else: + raise RuntimeError(f"{spec} has no tensor metadata.") + + if isinstance(spec, DTensorSpec): + # return a Tensor value + local_results = default_tensor(spec) + elif isinstance(spec, Sequence): + # return a List[Tensor] value + local_results = [ + default_tensor(s) if s is not None else None for s in spec + ] + assert isinstance(local_results, List) + if None in local_results: + ret_type = str(ret_list[0].type) + raise NotImplementedError( + f"return type {ret_type} in DTensor op is not supported" + ) + + if output_sharding.output_spec is None: + if op_call == aten.equal.default: + # For equal operator, The local results from all devices should be all-gathered + # and a reduce op (AND) will be performed on the list of results to ensure SPMD + # execution. We can extend this for more ops if necessary. + obj_list = [None for _ in range(dist.get_world_size())] + dist.all_gather_object(obj_list, local_results) # type: ignore[possibly-undefined] + obj_list = list(filter(lambda x: x is not None, obj_list)) + # perform reduce on the collection with AND op + local_results = functools.reduce(operator.and_, obj_list, True) + + if _is_inplace_op(op_call): + # inplace op should return self instead of re-wrapping + if output_sharding.output_spec is not None: + return args[0] + else: + return None + elif _is_out_variant_op(op_call): + # out variant could possibly have multiple out args (i.e. 
lu_unpack.out) + output_specs = ( + (output_sharding.output_spec,) + if not isinstance(output_sharding.output_spec, tuple) + else output_sharding.output_spec + ) + out_dts = [] + spec_idx = 0 + for argument in op_call._schema.arguments: + if argument.is_out: + out_dt = cast(dtensor.DTensor, kwargs[argument.name]) + out_dt._spec = cast(DTensorSpec, output_specs[spec_idx]) + out_dts.append(out_dt) + spec_idx += 1 + + assert len(out_dts) >= 1, "out variant should have at least one out arg" + return tuple(out_dts) if len(out_dts) > 1 else out_dts[0] + else: + return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] + + @staticmethod + def redistribute_local_args( + op_info: OpInfo, + suggested_input_schema: OpSchema, + ) -> None: + # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it + if op_info.args_tree_spec is not None: + flatten_args_schema_to_reshard = tuple( + pytree.tree_leaves(suggested_input_schema.args_schema) + ) + else: + flatten_args_schema_to_reshard = suggested_input_schema.args_schema + + new_local_args: List[object] = [] + for i, arg_spec in enumerate(op_info.flat_args_schema): + reshard_arg_spec = flatten_args_schema_to_reshard[i] + if isinstance(arg_spec, DTensorSpec): + local_tensor = cast(torch.Tensor, op_info.local_args[i]) + if arg_spec != reshard_arg_spec: + resharded_local_tensor = redistribute_local_tensor( + local_tensor, arg_spec, reshard_arg_spec + ) + new_local_args.append(resharded_local_tensor) + else: + new_local_args.append(local_tensor) + else: + new_local_args.append(reshard_arg_spec) + + op_info.local_args = tuple(new_local_args) + + def unwrap_to_op_info( + self, + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], + ) -> OpInfo: + # get runtime schema info to determine whether to use pytree to flatten inputs + runtime_schema_info = self.sharding_propagator.op_to_schema_info.get( + op_call, None + ) + + if runtime_schema_info is not None and runtime_schema_info.needs_pytree: + # flatten args/kwargs when op says necessary + tree_args, args_spec = pytree.tree_flatten(args) + args_list: Sequence[object] = tree_args + else: + args_list, args_spec = args, None + + args_schema: List[object] = [] + kwargs_schema: Dict[str, object] = {} + local_args: List[object] = [] + local_kwargs: Dict[str, object] = {} + mesh: Optional[DeviceMesh] = None + + for arg in args_list: + if isinstance(arg, dtensor.DTensor): + local_args.append(arg._local_tensor) + if mesh is not None and mesh != arg.device_mesh: + # TODO: try replicate dtensor spec in missing dimension would work + # for most cases for foreach case except when the first DTensor in + # the list is one that also need to be replicated. We need to revisit + # how we want to handle this corner case. For now, this case would hit + # the cross mesh error even if implicit replication is turned on. 
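+                    # Editor's sketch (hypothetical scenario, not part of the upstream
+                    # file): a foreach optimizer step over DTensors living on a 1-D
+                    # sub-mesh of a 2-D root mesh is the case this path patches up, by
+                    # padding the missing mesh dim with Replicate() in
+                    # _try_replicate_dtensor_spec_in_missing_dim.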
+ spec = self._try_replicate_dtensor_spec_in_missing_dim( + op_call, arg, mesh + ) + args_schema.append(spec) + else: + mesh = arg.device_mesh + args_schema.append(arg._spec) + elif isinstance(arg, torch.Tensor): + mesh = mesh or try_find_mesh_from_args(op_call, args_list) + args_schema.append( + self._try_replicate_spec_for_scalar_tensor(op_call, arg, mesh) + ) + local_args.append(arg) + else: + args_schema.append(arg) + local_args.append(arg) + + for k, v in kwargs.items(): + if isinstance(v, dtensor.DTensor): + local_kwargs[k] = v._local_tensor + if mesh is not None and mesh != v.device_mesh: + spec = self._try_replicate_dtensor_spec_in_missing_dim( + op_call, v, mesh + ) + kwargs_schema[k] = spec + else: + mesh = v.device_mesh + kwargs_schema[k] = v._spec + elif isinstance(v, torch.Tensor): + mesh = mesh or try_find_mesh_from_args(op_call, args_list) + kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor( + op_call, v, mesh + ) + local_kwargs[k] = v + else: + kwargs_schema[k] = v + local_kwargs[k] = v + + assert mesh is not None, f"found no DeviceMesh from dtensor args for {op_call}!" + op_info = OpInfo( + mesh, + OpSchema( + op_call, + pytree.tree_unflatten(args_schema, args_spec) + if args_spec + else tuple(args_schema), + kwargs_schema, + schema_info=runtime_schema_info, + ), + args_schema, + tuple(local_args), + local_kwargs, + args_spec, + ) + return op_info + + @staticmethod + def wrap(res: object, spec: OutputSpecType) -> object: + if isinstance(res, torch.Tensor): + if spec is not None: + assert isinstance( + spec, DTensorSpec + ), f"output spec does not match with output! Expected DTensorSpec, got {spec}." + return dtensor.DTensor(res, spec, requires_grad=res.requires_grad) + else: + # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor + assert res.ndim == 0, "output tensor should be scalar!" + return res + elif isinstance(res, (list, tuple)): + assert spec is not None and isinstance( + spec, (list, tuple) + ), f"output spec does not match with output! Expected list/tuple, got {spec}." + res_list = [] + for e, s in zip(res, spec): + res_list.append(OpDispatcher.wrap(e, s)) + + return tuple(res_list) if isinstance(res, tuple) else res_list + else: + # if the res contains only non tensor values (i.e. int/float/none), we simply return it + # without rewrapping to DTensor. + return res + + def _try_replicate_spec_for_scalar_tensor( + self, + op_call: torch._ops.OpOverload, + tensor_arg: torch.Tensor, + mesh: "DeviceMesh", + ) -> DTensorSpec: + # util function to produce a replicate spec for a scalar tensor arg/kwarg + if tensor_arg.numel() == 1 and tensor_arg.ndim == 1: + warnings.warn( + "Found a non-scalar tensor with numel=1 and ndim!=0, " + "we are implicitly creating a replicated DTensor for it. " + "However, please consider changing it to a scalar tensor " + "or explicitly create a DTensor under distributed enviroment." + ) + + if tensor_arg.numel() == 1 or self._allow_implicit_replication: + # scalar tensor can be safely treated as replicated + replication_spec = DTensorSpec( + mesh, + (Replicate(),) * mesh.ndim, + tensor_meta=TensorMeta( + shape=tensor_arg.shape, + stride=tensor_arg.stride(), + dtype=tensor_arg.dtype, + ), + ) + else: + raise RuntimeError( + f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all" + " torch.Tensor to DTensor before calling distributed operators!" 
+ ) + return replication_spec + + def _try_replicate_dtensor_spec_in_missing_dim( + self, + op_call: torch._ops.OpOverload, + dtensor_arg: "dtensor.DTensor", + mesh: "DeviceMesh", + ) -> DTensorSpec: + # util function to produce a new spec for a DTensor arg/kwarg + # that puts Replicate() placement in the missing dimension for foreach ops + from torch.distributed.device_mesh import _mesh_resources + + cur_mesh = dtensor_arg.device_mesh + root_mesh = _mesh_resources.get_root_mesh(cur_mesh) + if ( + self._allow_implicit_replication + and "foreach" in op_call.__name__ + and root_mesh == mesh + ): + placements = [Replicate() for _ in range(root_mesh.ndim)] + cur_mesh_root_idx = _mesh_resources.get_root_mesh_dim(cur_mesh) + placements[cur_mesh_root_idx] = dtensor_arg.placements[0] # type: ignore[call-overload] + replicate_spec = DTensorSpec( + root_mesh, + tuple(placements), + tensor_meta=TensorMeta( + shape=dtensor_arg.shape, + stride=dtensor_arg.stride(), + dtype=dtensor_arg.dtype, + ), + ) + else: + raise NotImplementedError( + f"{op_call}: DTensor does not support cross-mesh operation yet! " + f"Got meshes: {mesh} {cur_mesh}" + ) + return replicate_spec diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_dtensor_spec.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_dtensor_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..e80729c7b628692cb4f23d1d067c471d8be48938 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_dtensor_spec.py @@ -0,0 +1,276 @@ +from dataclasses import dataclass +from typing import Any, cast, List, NamedTuple, Optional, Tuple + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +class TensorMeta(NamedTuple): + # simple named tuple to represent tensor metadata + # intentionally to stay simple only for sharding + # propagation purposes. + shape: torch.Size + stride: Tuple[int, ...] + dtype: torch.dtype + + +# used internally to propagate the placements +@dataclass +class DTensorSpec: + mesh: DeviceMesh + placements: Tuple[Placement, ...] + + # tensor meta will only be set during sharding propagation + tensor_meta: Optional[TensorMeta] = None + + def __post_init__(self) -> None: + if not isinstance(self.placements, tuple): + self.placements = tuple(self.placements) + self._hash: Optional[int] = None + + def __setattr__(self, attr: str, value: Any) -> None: + super().__setattr__(attr, value) + # Make sure to recompute the hash in case any of the hashed attributes + # change (though we do not expect `mesh` or `placements` to change) + if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"): + self._hash = None + + def _hash_impl(self) -> int: + # hashing and equality check for DTensorSpec are used to cache the sharding + # propagation results. We only need to consider the mesh, placements, shape + # dtype and stride. + # Caveat: we need to keep this in mind and sync hash and eq if we add more + # fields to them. + if self.tensor_meta is not None: + return hash( + ( + self.mesh, + self.placements, + self.tensor_meta.shape, + self.tensor_meta.stride, + self.tensor_meta.dtype, + ) + ) + return hash((self.mesh, self.placements)) + + def __hash__(self) -> int: + # We lazily cache the spec to avoid recomputing the hash upon each + # use, where we make sure to update the hash when the `tensor_meta` + # changes by overriding `__setattr__`. 
This must be lazy so that Dynamo + # does not try to hash non-singleton `SymInt`s for the stride. + if self._hash is None: + self._hash = self._hash_impl() + return self._hash + + def __eq__(self, __o: object) -> bool: + if not ( + isinstance(__o, DTensorSpec) + and self.mesh == __o.mesh + and self.placements == __o.placements + ): + return False + if self.tensor_meta is None or __o.tensor_meta is None: + return self.tensor_meta == __o.tensor_meta + + return ( + self.tensor_meta.shape == __o.tensor_meta.shape # type: ignore[union-attr] + and self.tensor_meta.stride == __o.tensor_meta.stride # type: ignore[union-attr] + and self.tensor_meta.dtype == __o.tensor_meta.dtype # type: ignore[union-attr] + ) + + def __str__(self) -> str: + """ + human readable representation of the DTensorSpec + """ + if len(self.placements) == 1: + placement_str = str(self.placements[0]) + else: + placement_str = str(self.placements) + + if self.tensor_meta is not None: + tensor_shape = str(tuple(self.tensor_meta.shape)) + else: + tensor_shape = "unknown shape" + + return f"Spec({placement_str} on {tensor_shape})" + + @property + def shape(self) -> torch.Size: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return self.tensor_meta.shape + + @property + def stride(self) -> Tuple[int, ...]: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return self.tensor_meta.stride + + @property + def ndim(self) -> int: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return len(self.tensor_meta.shape) + + @property + def num_shards(self) -> int: + num_shards = 1 + for i, placement in enumerate(self.placements): + if placement.is_shard(): + num_shards *= self.mesh.size(i) + return num_shards + + @property + def device_mesh(self) -> DeviceMesh: + # simple aliasing for the mesh field, make some + # checks that mixes DTensor/DTensorSpec easier + return self.mesh + + @property + def dim_map(self) -> List[int]: + """ + dim_map is a property we derive from `placements` of + the distributed tensor. It simply return a list of ints + where dim_map[i] denotes the sharding mapping to the mesh + dimension, and len(dim_map) == dist_tensor.ndim + dim_map[i] = -1: means tensor dim i replicate on mesh + dim_map[i] = j: means tensor dim i shard on mesh dim j + + For example, we have a dist tensor that have the shape of + [18, 20, 30], and device_mesh([0, 1, 2, 3]), placements: + [Shard(1)], the dim_map of this placement would be: + [-1, 0, -1]. This representation is pretty helpful during + sharding propagation where we could know exactly each + tensor dimension is sharded or not. + + Note that if placements contains `_Partial`, we have to + explicitly deal with it, so that when we create a DTensorSpec + with dim_map, we could properly record the pending sums. + """ + # dims mapping of dist tensor sharding + # return size of tensor ndim, -1 represent replicate + # and int >=0 represent shard on that device mesh dim + r = [-1] * self.ndim + for i, placement in enumerate(self.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + if r[shard_dim] > -1: + raise ValueError( + f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]}," + " DTensor operator implementation does not support things like hybrid" + " sharding strategies yet (i.e. 
[Shard(0), Shard(0)])"
+                    )
+                r[shard_dim] = i
+        return r
+
+    @property
+    def num_shards_map(self) -> List[int]:
+        """
+        num_shards_map is a property we derive from `placements` of
+        the distributed tensor. Unlike `dim_map`, `num_shards_map`
+        denotes how many shards each tensor dim has. Like `dim_map`:
+            len(num_shards_map) == dist_tensor.ndim
+            num_shards_map[i] = 1: means tensor dim i is not sharded
+            num_shards_map[i] = j: means tensor dim i has j shards in total
+
+        For example, we have a dist tensor of shape [18, 20, 30],
+        a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements
+        ([Shard(1), Shard(0)]), the num_shards_map of this distributed tensor
+        would be: [4, 2, 1].
+        """
+        r = [1] * self.ndim
+        for i, placement in enumerate(self.placements):
+            if placement.is_shard():
+                shard_dim = cast(Shard, placement).dim
+                r[shard_dim] *= self.mesh.size(i)
+
+        return r
+
+    @property
+    def sums(self) -> List[int]:
+        """
+        sums is a property we derive from `placements` of the
+        distributed tensor. It simply returns a list of ints where
+        sums[i] denotes the pending sum (partial) on mesh dim i
+        """
+        return [
+            idx
+            for idx, placement in enumerate(self.placements)
+            if placement.is_partial()
+        ]
+
+    @classmethod
+    def from_dim_map(
+        cls,
+        mesh: DeviceMesh,
+        dim_map: List[int],
+        sums: List[int],
+        tensor_meta: Optional[TensorMeta] = None,
+    ) -> "DTensorSpec":
+        """
+        Construct a DTensorSpec from dim_map list and pending sum.
+
+        Args:
+            mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec
+            dim_map (List[int]): a list of integers that represents sharding on each
+                tensor dimension, see `dim_map` property doc for details
+            sums (List[int]): a list of integers that represents which device mesh
+                dimensions the dist tensor has pending sums on.
+            tensor_meta (TensorMeta): DTensor metadata
+
+        Return:
+            a class:`DTensorSpec` object
+        """
+        # by default replicate on device mesh dims
+        placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)]
+
+        # find all mesh dims that need pending reductions
+        for s in sums:
+            placements[s] = Partial()
+
+        for i, m in enumerate(dim_map):
+            if m >= 0:
+                placement = placements[m]
+                if placement.is_shard():
+                    placement = cast(Shard, placement)
+                    raise RuntimeError(
+                        f"DeviceMesh dimension cannot be mapped to two dimensions of the same tensor: {i} and {placement.dim}"
+                    )
+                elif placement.is_partial():
+                    raise RuntimeError(
+                        f"DeviceMesh dimension {m} cannot be both shard and partial!"
+                    )
+                placements[m] = Shard(i)
+
+        return cls(mesh, tuple(placements), tensor_meta=tensor_meta)
+
+    def is_replicated(self) -> bool:
+        """
+        return True if the current DTensorSpec replicates on all mesh dims (devices)
+        """
+        return all(placement.is_replicate() for placement in self.placements)
+
+    def is_sharded(self) -> bool:
+        """
+        return True if the current DTensorSpec is sharded on any mesh dims (devices)
+        """
+        return any(placement.is_shard() for placement in self.placements)
+
+    def shallow_copy_with_tensor_meta(
+        self, tensor_meta: Optional[TensorMeta]
+    ) -> "DTensorSpec":
+        """
+        Shallow copy the DTensorSpec with a new tensor_meta.
+        """
+        assert tensor_meta is not None, "shallow copy with no tensor_meta!"
+
+    def shallow_copy_with_tensor_meta(
+        self, tensor_meta: Optional[TensorMeta]
+    ) -> "DTensorSpec":
+        """
+        Shallow copy the DTensorSpec with a new tensor_meta.
+        """
+        assert tensor_meta is not None, "shallow copy with no tensor_meta!"
+        return DTensorSpec(
+            self.mesh,
+            self.placements,
+            tensor_meta=tensor_meta,
+        )
diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_op_schema.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_op_schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..190886da21fd220e76adbd6e10a527dbed425f13
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_op_schema.py
@@ -0,0 +1,457 @@
+# mypy: allow-untyped-defs
+from dataclasses import dataclass
+from functools import cached_property
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+import torch
+from torch._ops import OpOverload
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor.placement_types import Placement
+
+
+try:
+    from torch.utils._cxx_pytree import tree_leaves, tree_map_only, TreeSpec
+except ImportError:
+    from torch.utils._pytree import (  # type: ignore[no-redef, assignment]
+        tree_leaves,
+        tree_map_only,
+        TreeSpec,
+    )
+
+
+# Common type aliases
+ArgsType = Tuple[object, ...]
+KwargsType = Dict[str, object]
+
+PlacementList = List[Optional[Placement]]
+
+# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so the output type should
+# be the same set of possibilities.
+OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]
+
+
+def _rebuild_tensor_from_dtensor_meta(arg) -> object:
+    """
+    This is used to propagate tensor metadata, must be under fake mode
+    """
+    assert arg.tensor_meta is not None, "DTensorSpec does not contain tensor_meta."
+    return torch.empty_strided(
+        arg.tensor_meta.shape,
+        arg.tensor_meta.stride,
+        dtype=arg.tensor_meta.dtype,
+    )
+
+
+def _is_inplace_op(op: OpOverload):
+    # simple analysis of the function schema to determine
+    # if this is an inplace variant; it might not
+    # be entirely correct, but it's good enough for now.
+    return op._schema.name[-1] == "_"
+
+
+def _is_out_variant_op(op: OpOverload):
+    # simple analysis of the function schema to determine
+    # if this is an out variant; it might not
+    # be entirely correct, but it's good enough for now.
+    return "out" in op._schema.overload_name
+
+
+def _pretty_print_spec(spec: object) -> str:
+    if spec is None:
+        return "None"
+    elif isinstance(spec, DTensorSpec):
+        return "".join([str(p) for p in spec.placements])
+    elif isinstance(spec, Sequence):
+        return "(" + ", ".join([_pretty_print_spec(s) for s in spec]) + ")"
+    else:
+        raise RuntimeError(f"Unknown spec type to print: spec={spec}")
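+
+
+# [editor's note] Illustrative sketch, not part of the original file: what the two
+# schema heuristics above key off. ATen inplace overloads end the op *name* with
+# "_", while out variants carry "out" in the *overload* name:
+#
+#   _is_inplace_op(torch.ops.aten.add_.Tensor)   # True  (schema name "aten::add_")
+#   _is_inplace_op(torch.ops.aten.add.Tensor)    # False
+#   _is_out_variant_op(torch.ops.aten.add.out)   # True  (overload name "out")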
+
+
+@dataclass
+class PlacementStrategy:
+    """
+    A placement strategy describes acceptable sharding placements of the output
+    and the tensor arguments of an operation.
+
+    note: when the op return value is a single DTensor object, output_specs is
+    DTensorSpec; when the return value is a tuple of Optional[DTensor],
+    output_specs is a tuple of Optional[DTensorSpec].
+    """
+
+    output_specs: Union[DTensorSpec, Tuple[Optional[DTensorSpec], ...]]
+    input_specs: Optional[Sequence[DTensorSpec]] = None
+
+    # redistribute costs for this op placement strategy
+    # we need a nested list to record the cost for each
+    # operand of this operator, and for each operand of
+    # this operator it might have multiple placement strategies
+    redistribute_cost: Optional[List[List[float]]] = None
+
+    @cached_property
+    def output_spec(self) -> DTensorSpec:
+        """
+        This function requires that the strategy has exactly one DTensorSpec as the
+        output spec. If the output_specs is a tuple, we throw an exception.
+        """
+        if isinstance(self.output_specs, DTensorSpec):
+            return self.output_specs
+        else:
+            raise ValueError(
+                f"function output_spec expects a single DTensorSpec but got: {self.output_specs}"
+            )
+
+    def input_spec(self, index: int = 0) -> DTensorSpec:
+        assert self.input_specs is not None, "input_specs of PlacementStrategy is None!"
+        assert len(self.input_specs) > index, (
+            f"Invalid index {index} for input_specs of length "
+            f"{len(self.input_specs)}: {self.input_specs}"
+        )
+        return self.input_specs[index]
+
+    def __str__(self) -> str:
+        if self.input_specs is not None:
+            input_specs_str = f"{_pretty_print_spec(self.input_specs)} -> "
+        else:
+            input_specs_str = ""
+        output_spec_str = _pretty_print_spec(self.output_specs)
+        return f"{input_specs_str}{output_spec_str}"
+
+
+class StrategyType:
+    """
+    Base class for op strategy types. We have two StrategyTypes:
+    OpStrategy and TupleStrategy
+    """
+
+
+class OpStrategy(StrategyType):
+    """
+    OpStrategy that consists of a list of placement strategies associated with the op
+    """
+
+    def __init__(self, strategies: List[PlacementStrategy]) -> None:
+        super().__init__()
+        self.strategies: List[PlacementStrategy] = strategies
+
+    def __str__(self) -> str:
+        strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
+        mesh_shape = self.mesh_shape
+        return f"[{strategy_list_str}] @ mesh: {mesh_shape}"
+
+    def max_num_shards(self) -> int:
+        """
+        Returns the max number of shards across all placement strategies
+        """
+        return max(strategy.output_spec.num_shards for strategy in self.strategies)
+
+    @property
+    def mesh_shape(self):
+        output_spec = self.strategies[0].output_specs
+        if isinstance(output_spec, DTensorSpec):
+            return output_spec.mesh.shape
+        else:
+            assert isinstance(
+                output_spec, tuple
+            ), "found no DTensorSpec in the OpStrategy!"
+            assert output_spec[0] is not None
+            return output_spec[0].mesh.shape
+
+    @property
+    def ndim(self):
+        return self.strategies[0].output_spec.ndim
+
+    @property
+    def shape(self):
+        return self.strategies[0].output_spec.shape
+
+
+class TupleStrategy(StrategyType):
+    """
+    TupleStrategy represents the case where the output strategy of an op is a
+    tuple of strategies, i.e. if the output of the op is a tuple of tensors or a
+    list of tensors with possibly different placement strategies, we should
+    return a TupleStrategy that contains a tuple of OpStrategy, where each child
+    represents the sharding strategy of "each element" of the tuple/list of
+    tensors the op returns.
+
+    NOTE: if the output of the op is a List[Tensor] and they share the same placement
+    strategy, then we should return a single OpStrategy instead of a TupleStrategy
+    """
+
+    def __init__(self, childs: Sequence[StrategyType]) -> None:
+        super().__init__()
+        self.childs: Sequence[StrategyType] = childs
+
+    def __str__(self) -> str:
+        child_strategies_str = ", ".join(
+            [f"{str(strat)}" for idx, strat in enumerate(self.childs)]
+        )
+        return f"TupleStrategy({child_strategies_str})"
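+
+
+# [editor's note] Illustrative sketch, not part of the original file: the shape of
+# the two strategy containers. A single-output op gets one OpStrategy; an op whose
+# outputs may carry different shardings gets a TupleStrategy with one child
+# strategy per output element (specs here are hypothetical):
+#
+#   single = OpStrategy([PlacementStrategy(output_specs=some_spec)])
+#   multi = TupleStrategy(childs=[single, single])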
+
+
+@dataclass
+class RuntimeSchemaInfo:
+    """
+    RuntimeSchemaInfo stores the operator schema related information for runtime (eager)
+    execution. This is mainly used in two ways: 1. to generate a hash for the args to
+    determine whether to re-run sharding prop or not 2. to determine if we need pytree
+    """
+
+    # This static_argnum records the static arg "starting index" for ops that have
+    # non-tensor args/kwargs which would affect sharding propagation results. All args
+    # starting from this index would be hashed into our sharding cache.
+    # Note that only a few ops need this information, e.g. view, transpose, var.dim, etc.
+    static_argnum: int = 100
+    # This static_kwargkey records static kwarg names which would affect sharding prop
+    static_kwargkey: Optional[List[str]] = None
+    # each op can decide if it wants to use pytree flatten/unflatten during operator
+    # eager execution; by default we don't need to do flatten/unflatten, only if the
+    # op indicates it needs to. This is to accelerate eager performance.
+    needs_pytree: bool = False
+
+
+@dataclass
+class OpSchema:
+    """
+    OpSchema is a data class that describes an operator's input schema, it includes
+    DTensorSpecs (instead of DTensor) and non-tensor args/kwargs (positional order
+    preserved). It is mainly used by DTensor's dispatching logic to perform various
+    actions (i.e. sharding propagation, caching sharding decisions, redistribute, etc.)
+
+    NOTE: this should be used as a read only data class
+    TODO: make this a frozen dataclass
+
+    Args:
+        op: the operator overload we are intercepting
+        args_schema: contains args except that the DTensor args have been replaced
+            with their DTensorSpec or OpStrategy
+        kwargs_schema: contains kwargs except that the DTensor kwargs have been replaced
+            with their DTensorSpec or OpStrategy
+    """
+
+    op: OpOverload
+    args_schema: ArgsType
+    kwargs_schema: KwargsType
+
+    schema_info: Optional[RuntimeSchemaInfo] = None
+
+    @property
+    def args_spec(self) -> Tuple[DTensorSpec, ...]:
+        """
+        args_spec: Tuple[DTensorSpec, ...]: contains a clean tuple of args specs
+            with NO non-DTensor positional arguments (i.e. int/float/tuple, etc),
+            mainly used by sharding propagation to propagate the output spec
+        """
+        args = (
+            tree_leaves(self.args_schema)
+            if self.schema_info is not None and self.schema_info.needs_pytree
+            else self.args_schema
+        )
+        return tuple(item for item in args if isinstance(item, DTensorSpec))
+
+    @property
+    def args_strategy(self) -> Tuple[OpStrategy, ...]:
+        # filter out non-relevant values from the args schema to get a clean OpStrategy list
+        # kept separate from args_spec for the ease of type annotation
+        # TODO: see if we should merge this with args_spec
+        args = (
+            tree_leaves(self.args_schema)
+            if self.schema_info is not None and self.schema_info.needs_pytree
+            else self.args_schema
+        )
+        return tuple(item for item in args if isinstance(item, OpStrategy))
+
+    def __repr__(self) -> str:
+        args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema])
+        return (
+            f"OpSchema(op={self.op},"
+            f" args_schema=({args_schema}),"
+            f" kwargs_schema={self.kwargs_schema})"
+        )
+
+    def __str__(self) -> str:
+        args_schema: List[str] = []
+        mesh_shape = None
+        for arg in self.args_schema:
+            if isinstance(arg, DTensorSpec):
+                args_schema.append(str(arg))
+                mesh_shape = arg.mesh.shape
+            elif isinstance(arg, OpStrategy):
+                assert len(arg.strategies) == 1
+                args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs))
+                mesh_shape = arg.mesh_shape
+            elif isinstance(arg, TupleStrategy):
+                first_op_strtgy = arg.childs[0]
+                assert isinstance(first_op_strtgy, OpStrategy)
+                mesh_shape = first_op_strtgy.mesh_shape
+                args_schema.append(str(arg))
+            else:
+                args_schema.append(str(arg))
+        return f"Op(op={self.op}, args_schema={', '.join(args_schema)} @ mesh: {mesh_shape})"
+
+    def __post_init__(self) -> None:
+        has_symints = False
+        for a in self.args_schema:
+            if isinstance(a, DTensorSpec) and a.tensor_meta is not None:
+                if
any(isinstance(s, torch.SymInt) for s in a.tensor_meta.shape): + has_symints = True + break + self.has_symints = has_symints + + def arg_type_tensor_or_tensor_list_like(self, arg_idx: int) -> bool: + arg = self.args_schema[arg_idx] + is_tensor = isinstance(arg, DTensorSpec) + if is_tensor: + return True + + if not isinstance(arg, list): + return False + + return all(isinstance(e, DTensorSpec) or e is None for e in arg) + + def return_type_tuple_tensor_like(self) -> bool: + # all dispatch ops could only return Tuple[Tensor] or have None/ints/floats + # in the tuple, but the first element must be a Tensor, so this check is enough + return_types = self.op._schema.returns + return len(return_types) > 1 and isinstance( + return_types[0].type, torch.TensorType + ) + + def return_type_tensor(self) -> bool: + return_types = self.op._schema.returns + # all dispatch ops only return Tensor or Tuple[Tensor] for tensor like + # return types, so this check is enough for tensor like types + return isinstance(return_types[0].type, torch.TensorType) + + def __hash__(self) -> int: + # Only hash args and kwargs that op indicates to hash + if not self.schema_info: + static_argnum = len(self.args_schema) + static_kwargkey = None + else: + static_argnum = self.schema_info.static_argnum + static_kwargkey = self.schema_info.static_kwargkey + + args_to_hash = tuple( + tuple(e) if isinstance(e, list) else e + for i, e in enumerate(self.args_schema) + if self.arg_type_tensor_or_tensor_list_like(i) or i >= static_argnum + ) + if static_kwargkey is not None: + kwargs_to_hash = tuple( + self.kwargs_schema.get(k, None) for k in static_kwargkey + ) + return hash((self.op, args_to_hash, kwargs_to_hash)) + else: + return hash((self.op, args_to_hash)) + + def __eq__(self, other: object) -> bool: + # early return checks + if not isinstance(other, OpSchema): + return False + + if self.op != other.op: + return False + + if len(self.args_schema) != len(other.args_schema): + return False + + # compare each element and early return if any of them is different + if not self.schema_info: + static_argnum = len(self.args_schema) + static_kwargkey = None + else: + static_argnum = self.schema_info.static_argnum + static_kwargkey = self.schema_info.static_kwargkey + + for i, (self_arg, other_arg) in enumerate( + zip(self.args_schema, other.args_schema) + ): + if isinstance(self_arg, DTensorSpec) and self_arg != other_arg: + return False + elif i >= static_argnum and self_arg != other_arg: + return False + + # check kwarg equality when there's a static kwarg key + if static_kwargkey: + for key in static_kwargkey: + if self.kwargs_schema.get(key, None) != other.kwargs_schema.get( + key, None + ): + return False + + return True + + def gen_fake_args(self) -> ArgsType: + """ + gen_fake_args: generate fake args for the operator, this is mainly used + by sharding propagation rules to generate fake args for the operator + to run the local tensor operator and get the output spec. + """ + return tree_map_only( + DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.args_schema + ) + + def gen_fake_kwargs(self) -> KwargsType: + """ + gen_fake_kwargs: generate fake kwargs for the operator, this is mainly used + by sharding propagation rules to generate fake kwargs for the operator + to run the local tensor operator and get the output spec. 
+ """ + return tree_map_only( + DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.kwargs_schema + ) + + def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None: + suggestion_args_spec = self.args_spec + new_arg_schema: List[object] = [] + idx_of_args_spec = 0 + if ( + origin_schema.schema_info is not None + and origin_schema.schema_info.needs_pytree + ): + args_schema: Sequence[Any] = tree_leaves(origin_schema.args_schema) + else: + args_schema = origin_schema.args_schema + for arg in args_schema: + if isinstance(arg, DTensorSpec): + new_arg_schema.append(suggestion_args_spec[idx_of_args_spec]) + idx_of_args_spec += 1 + else: + new_arg_schema.append(arg) + self.args_schema = tuple(new_arg_schema) + self.kwargs_schema = origin_schema.kwargs_schema + + +@dataclass +class OutputSharding: + """ + OutputSharding is a data class that is used by the sharding propagation, + it could set the output_spec upon successful propagation. If needs_redistribute + is set to True, a redistribute_schema would be returned together to indicate + the input arguments needs to be redistributed before the op execution. + + NOTE: the redistribute_schema generated by sharding propagation should be + exactly the same as the operator OpSchema, except the DTensorSpecs + """ + + output_spec: OutputSpecType + redistribute_schema: Optional[OpSchema] = None + needs_redistribute: bool = False + + +@dataclass +class OpInfo: + """ + All Runtime Op execution info are packed here + """ + + mesh: DeviceMesh + schema: OpSchema + flat_args_schema: List[object] + local_args: Sequence[object] + local_kwargs: Dict[str, object] + args_tree_spec: Optional[TreeSpec] = None + + # the output sharding info + output_sharding: Optional[OutputSharding] = None diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__init__.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dec4665b1c8b957daa4a5f5d15d988ce00cdc79d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +from ._conv_ops import * # noqa: F403 +from ._embedding_ops import * # noqa: F403 +from ._experimental_ops import * # noqa: F403 +from ._math_ops import * # noqa: F403 +from ._matrix_ops import * # noqa: F403 +from ._pointwise_ops import * # noqa: F403 +from ._random_ops import * # noqa: F403 +from ._tensor_ops import * # noqa: F403 +from ._view_ops import * # noqa: F403 diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26622dd24b6b1ffebfc00fbb482fde502f31fe13 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_common_rules.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_common_rules.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8718dcfae89a2e98aa42e0ab7c51bd33f96c2ffc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_common_rules.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_conv_ops.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_conv_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d88988385df0e68e6c1aef300515d0cd9d07db55 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_conv_ops.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_einsum_strategy.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_einsum_strategy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..405d9ab461ca6285c3b987fc0a2138c0556b00a5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_einsum_strategy.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_embedding_ops.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_embedding_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..500f44f38755caaceae1a0a48f47254c9fb56581 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_embedding_ops.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_experimental_ops.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_experimental_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..689100dfab34f4435f46c9811df105478dc79d60 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_experimental_ops.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_math_ops.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_math_ops.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..ee9a752f11e5ade0461e517251570542e4b3ebb8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_math_ops.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_matrix_ops.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_matrix_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29f46413cae416f7ba08d35c223f359cb49e1845 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_matrix_ops.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_pointwise_ops.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_pointwise_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50e94fe01929f24d86b6224cab17282c2d482ba8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_pointwise_ops.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_random_ops.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_random_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da02d5b43073a70e41b812261170739bfe628c57 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_random_ops.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_view_ops.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_view_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e524cb4cedf87d43928414920f5b25151d4f69f6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_view_ops.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02e07554443e5e14ac66f79b2e9067faf0fbfbdf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_common_rules.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_common_rules.py new file mode 100644 index 0000000000000000000000000000000000000000..2a41252be40e76e4cd4896d8e4a2fb84d01d0655 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_common_rules.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates
+from typing import cast, Dict, List, Optional, Tuple
+
+import torch
+from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
+from torch.distributed.tensor._op_schema import (
+    _is_inplace_op,
+    _is_out_variant_op,
+    OpSchema,
+    OutputSharding,
+)
+from torch.distributed.tensor._ops.utils import prod
+from torch.distributed.tensor._utils import compute_local_shape
+
+
+def _replace_char_in_str(string: str, new_char: str, idx: int) -> str:
+    return string[:idx] + new_char + string[idx + 1 :]
+
+
+def _gen_reshard_suggestions(
+    op_schema: OpSchema,
+    input_dims: List[str],
+    input_specs: Tuple[DTensorSpec, ...],
+    dim_to_sharding: Dict[str, int],
+    pending_sum: List[int],
+) -> OutputSharding:
+    suggested_arg_specs: List[DTensorSpec] = []
+    for input_dim, input_spec in zip(input_dims, input_specs):
+        dim_map = [dim_to_sharding[dim] for dim in input_dim]
+        suggested_arg_specs.append(
+            DTensorSpec.from_dim_map(
+                mesh=input_spec.mesh,
+                dim_map=dim_map,
+                sums=pending_sum,
+                tensor_meta=input_spec.tensor_meta,
+            )
+        )
+    suggested_schema = OpSchema(op_schema.op, tuple(suggested_arg_specs), {})
+    suggested_schema._inplace_rewrap_schema_suggestion(op_schema)
+    return OutputSharding(
+        None,
+        redistribute_schema=suggested_schema,
+    )
+
+
+def einop_rule(
+    equation: str,
+    op_schema: OpSchema,
+    *,
+    linearity: bool = False,
+    enforce_sharding: Optional[Dict[str, int]] = None,
+) -> OutputSharding:
+    """
+    Propagate the sharding of inputs to output for ops whose data moves according to einsum notation.
+
+    This is mostly borrowed from @zdevito's sharding simulator. Examples:
+        mk,kn->mn - einsum
+        ij,ij->ij - addition
+        ij,j->ij - broadcasted addition
+        ij->i - reduction
+    Other ops could use this propagation algorithm when applicable; note
+    that einsum propagation only deals with lists of specs (DTensor specs)
+    as it only works on lists of tensors!
+
+    linearity in einop_rule means that the calling op `f` follows this rule:
+        f(a + b) = f(a) + f(b)
+
+    In this case we can propagate the partial sum; note that linearity in einop
+    only applies to partial sum, not other operations like min/max (which are
+    associative but not linear).
+    """
+    # parse einop equation and extract arg specs
+    inputs, outputs = equation.split("->")
+    input_dims, output_dims = inputs.split(","), outputs.split(",")
+    input_specs = op_schema.args_spec
+    # NOTE: only support single output unless needed in future
+    output_dim = output_dims[0]
+
+    dim_to_sharding: Dict[str, int] = {}
+    dim_to_size: Dict[str, int] = {}
+    # record pending sum, key is mesh dimension, value is pending sum
+    # counter across input specs
+    pending_sums_counter: Dict[int, int] = {}
+    seen_shardings: Dict[int, str] = {}
+    needs_reshard = False
+
+    def merge_sharding(dim: str, a: int, b: int) -> int:
+        # merge the sharding of inputs if possible, i.e. we can merge
+        # replicate and shard to shard, but this will trigger a reshard operation
+        if a != b:
+            if a == -1 or b == -1:
+                # reshard the replicate to match the sharded one
+                nonlocal needs_reshard
+                needs_reshard = True
+                return a if a != -1 else b
+            else:
+                # TODO: further merge the sharding properly (i.e. reshard one input to replicate)
+                raise RuntimeError(
+                    f"{equation}: dim {dim} sharded two different ways: {a} and {b}"
+                )
+        else:
+            return a
+
+    for input_dim, input_spec in zip(input_dims, input_specs):
+        # deal with partial sums
+        input_sums = input_spec.sums
+        for sum_dim in input_sums:
+            if sum_dim not in pending_sums_counter:
+                seen_shardings[sum_dim] = "+"
+            # update pending sum counter for pending sum mesh
+            # dimension with the occurrence from each input
+            pending_sums_counter[sum_dim] = pending_sums_counter.get(sum_dim, 0) + 1
+
+        for idx, (dim, mesh_dim) in enumerate(zip(input_dim, input_spec.dim_map)):
+            if enforce_sharding and dim in enforce_sharding:
+                if enforce_sharding[dim] != mesh_dim:
+                    needs_reshard = True
+                dim_to_sharding[dim] = enforce_sharding[dim]
+                dim_to_size[dim] = input_spec.shape[idx]
+            elif dim not in dim_to_sharding:
+                dim_to_sharding[dim] = mesh_dim
+                dim_to_size[dim] = input_spec.shape[idx]
+            else:
+                dim_to_sharding[dim] = merge_sharding(
+                    dim, dim_to_sharding[dim], mesh_dim
+                )
+                assert dim_to_size[dim] == input_spec.shape[idx]
+
+            # after merging the sharding, we check if there are multiple
+            # tensor dims sharded on the same mesh dim.
+            merged_sharding_for_dim = dim_to_sharding[dim]
+            if merged_sharding_for_dim != -1:
+                if (
+                    merged_sharding_for_dim in seen_shardings
+                    and dim != seen_shardings[merged_sharding_for_dim]
+                ):
+                    needs_reshard = True
+                    seen_shardings[merged_sharding_for_dim] += dim
+                else:
+                    seen_shardings[merged_sharding_for_dim] = dim
+
+    if pending_sums_counter and not linearity:
+        # return a reshard suggestion with no pending sum; because we already properly
+        # merged the sharding, this reshard suggestion is legit to use
+        return _gen_reshard_suggestions(
+            op_schema, input_dims, input_specs, dim_to_sharding, []
+        )
+    else:
+        # The op supports linearity, but if not all input arguments are partial,
+        # we fail the sharding propagation with a suggestion to make all inputs
+        # partial on the corresponding mesh dim (all inputs should be partial for
+        # the mesh dims in order to execute locally and delay the sum reduction)
+        for value in pending_sums_counter.values():
+            if value != len(input_specs):
+                needs_reshard = True
+
+    for mesh_dim, dims in seen_shardings.items():
+        if len(dims) > 1:
+            # we found different input dims sharded on the same mesh dim;
+            # in order to perform local op computation, we need to reshard inputs
+            # based on some simple heuristics: for now we simply pick the one with
+            # the least comm volume (i.e. the input with the least size)
+            # TODO: consider a more advanced heuristic to pick the best sharding
+            costs = []
+            for d in dims:
+                cost = 0
+                for input_dim, input_spec in zip(input_dims, input_specs):
+                    if (
+                        d in input_dim
+                        and input_spec.dim_map[input_dim.index(d)] == mesh_dim
+                    ):
+                        assert input_spec.tensor_meta is not None
+                        global_shape = input_spec.tensor_meta.shape
+                        local_shape = compute_local_shape(
+                            global_shape, input_spec.mesh, input_spec.placements
+                        )
+                        cost += prod(local_shape) * input_spec.mesh.size(mesh_dim)
+                costs.append(cost)
+            d_to_keep_sharding = dims[costs.index(max(costs))]
+            for d in dims:
+                # update dim_to_sharding to keep the sharding of the dim with the
+                # highest comm cost and make the rest of the dims replicate
+                if d != d_to_keep_sharding:
+                    dim_to_sharding[d] = -1
+
+    pending_sums = list(pending_sums_counter.keys())
+    if needs_reshard:
+        return _gen_reshard_suggestions(
+            op_schema, input_dims, input_specs, dim_to_sharding, pending_sums
+        )
+
+    # generate output pending sum if a dim is sharded, and it appears in input
+    # but not output
+    for dim, shard_on_mesh in dim_to_sharding.items():
+        if dim not in output_dims[0] and shard_on_mesh != -1:
+            pending_sums.append(shard_on_mesh)
+
+    # if no need to reshard, we directly generate the output sharding
+    output_dim_map = []
+    output_shape = []
+    for dim in output_dim:
+        if dim == "1":
+            # the output dim is a singleton dimension, mark sharding and shape
+            output_dim_map.append(-1)
+            output_shape.append(1)
+        else:
+            output_dim_map.append(dim_to_sharding[dim])
+            output_shape.append(dim_to_size[dim])
+
+    # XXX: since we still need to have intermediate shape calculation, we need
+    # to pass in the shape here. We should remove this once sharding decomp works
+    # for ops like addmm
+    assert input_specs[0].tensor_meta is not None
+    tensor_meta = TensorMeta(
+        torch.Size(output_shape),
+        input_specs[0].tensor_meta.stride,
+        input_specs[0].tensor_meta.dtype,
+    )
+    return OutputSharding(
+        DTensorSpec.from_dim_map(
+            input_specs[0].mesh,
+            output_dim_map,
+            pending_sums,
+            tensor_meta=tensor_meta,
+        )
+    )
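+
+
+# [editor's note] Illustrative sketch, not part of the original file: how
+# einop_rule reads for a matmul-like equation "mk,kn->mn" on a 1-D mesh, with
+# both inputs sharded on the contracting dim "k" (lhs dim_map [-1, 0],
+# rhs dim_map [0, -1]). The shardings merge cleanly, and since "k" does not
+# appear in the output, the rule emits an output that is replicated in dim_map
+# terms but carries a pending sum on mesh dim 0:
+#
+#   output dim_map == [-1, -1], sums == [0]   # i.e. a Partial() result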
+
+
+def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputSharding:
+    """
+    Propagate the sharding for pointwise operations.
+
+    Examples:
+        ij,ij->ij - addition/mul
+        ij,j->ij - broadcasted addition
+    """
+    alphabet = "abcdefghijklmnopqrstuvwxyz"
+    # find the max_dim first in case we need broadcasting
+    input_specs = op_schema.args_spec
+    max_dim = max(input.ndim for input in input_specs)
+    dimchars = []
+    singleton_counter: List[int] = [0] * max_dim
+    for input in input_specs:
+        start_dim = max_dim - input.ndim
+        p = alphabet[start_dim:max_dim]
+        # handle the "broadcasting to a common shape" case
+        # see https://pytorch.org/docs/stable/notes/broadcasting.html
+        # If any of the dimensions is a singleton dimension (i.e. 1),
+        # we mark the dim char as a special "1" to distinguish it from
+        # the non-singleton dimensions, so that sharding propagation
+        # can just ignore the singleton dimension.
+        if len(input_specs) > 1:
+            for i in range(max_dim):
+                if i < start_dim:
+                    # treat the leading missing dim chars as singleton
+                    singleton_counter[i] += 1
+                elif input.shape[i - start_dim] == 1:
+                    # mark singleton dim char as a special "1" in einop rule
+                    singleton_counter[i] += 1
+                    p = _replace_char_in_str(p, "1", (i - start_dim))
+
+        dimchars.append(p)
+    out_dimchars = alphabet[:max_dim]
+    # check if we replaced a dim char with the singleton "1" for all inputs;
+    # if we replaced it for all inputs, we also need to replace it in the output
+    for output_dim_idx in range(len(out_dimchars)):
+        out_dimchar = out_dimchars[output_dim_idx]
+        if singleton_counter[output_dim_idx] == len(input_specs):
+            out_dimchars = _replace_char_in_str(out_dimchars, "1", output_dim_idx)
+
+    fmt = f"{','.join(p for p in dimchars)}->{out_dimchars}"
+
+    enforce_sharding: Dict[str, int] = {}
+    if _is_inplace_op(op_schema.op):
+        # inplace op should keep the input sharding it writes to
+        for out_dimchar, mesh_dim in zip(out_dimchars, input_specs[0].dim_map):
+            enforce_sharding[out_dimchar] = mesh_dim
+    elif _is_out_variant_op(op_schema.op):
+        out_spec = cast(DTensorSpec, op_schema.kwargs_schema["out"])
+        for out_dimchar, mesh_dim in zip(out_dimchars, out_spec.dim_map):
+            enforce_sharding[out_dimchar] = mesh_dim
+
+    return einop_rule(
+        fmt,
+        op_schema,
+        linearity=linearity,
+        enforce_sharding=enforce_sharding,
+    )
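+
+
+# [editor's note] Illustrative sketch, not part of the original file: the equation
+# pointwise_rule builds for a broadcasted add of a [8, 1, 4] input and a [4]
+# input (max_dim == 3). The singleton middle dim of the first input is rewritten
+# to the special char "1", and since dim "b" is singleton-or-missing in every
+# input, it is also rewritten in the output:
+#
+#   dimchars == ["a1c", "c"]  ->  fmt == "a1c,c->a1c"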
diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_conv_ops.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_conv_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..db2a8136e14da0aab9969f79329fec49b42a813a
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_conv_ops.py
@@ -0,0 +1,110 @@
+# mypy: allow-untyped-decorators
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# implement convolution related ops for distributed tensor
+from typing import List
+
+import torch
+from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
+from torch.distributed.tensor._op_schema import OpSchema, OutputSharding
+from torch.distributed.tensor._ops.utils import register_prop_rule
+
+
+aten = torch.ops.aten
+
+
+@register_prop_rule(aten.convolution.default)
+def convolution_rules(op_schema: OpSchema) -> OutputSharding:
+    (
+        input_spec,
+        weight_spec,
+        bias_spec,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+    ) = op_schema.args_schema
+
+    assert isinstance(input_spec, DTensorSpec)
+    assert isinstance(weight_spec, DTensorSpec)
+    assert isinstance(bias_spec, DTensorSpec)
+    assert input_spec.tensor_meta is not None
+    assert weight_spec.tensor_meta is not None
+    in_shape = input_spec.tensor_meta.shape
+    weight_shape = weight_spec.tensor_meta.shape
+    assert isinstance(stride, List)
+    assert isinstance(padding, List)
+    assert isinstance(dilation, List)
+    assert isinstance(weight_shape, torch.Size)
+    N, C_in, H_in, W_in = in_shape[0], in_shape[1], in_shape[2], in_shape[3]
+    C_out = weight_shape[0]
+    H_out = (H_in + 2 * padding[0] - dilation[0] * (weight_shape[2] - 1) - 1) // stride[
+        0
+    ] + 1
+    W_out = (W_in + 2 * padding[1] - dilation[1] * (weight_shape[3] - 1) - 1) // stride[
+        1
+    ] + 1
+    output_shape = [N, C_out, H_out, W_out]
+    output_stride = (C_out * H_out * W_out, H_out * W_out, W_out, 1)
+    output_dim_map = input_spec.dim_map
+    pending_sums = input_spec.sums
+
+    tensor_meta = TensorMeta(
+        torch.Size(output_shape),
+        output_stride,
+        input_spec.tensor_meta.dtype,
+    )
+    return OutputSharding(
+        DTensorSpec.from_dim_map(
+            input_spec.mesh,
+            output_dim_map,
+            pending_sums,
+            tensor_meta=tensor_meta,
+        )
+    )
+
+
+@register_prop_rule(aten.convolution_backward.default)
+def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
+    input_spec = op_schema.args_schema[0]
+    (
+        grad_output_spec,
+        input_spec,
+        weight_spec,
+        bias_shape_opt,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+        output_mask,
+    ) = op_schema.args_schema
+
+    assert isinstance(grad_output_spec, DTensorSpec)
+    assert isinstance(input_spec, DTensorSpec)
+    assert isinstance(weight_spec, DTensorSpec)
+    assert isinstance(bias_shape_opt, List)
+    assert input_spec.tensor_meta is not None
+    weight_tensor_meta = weight_spec.tensor_meta
+    bias_tensor_meta = TensorMeta(
+        torch.Size(bias_shape_opt),
+        (1,),
+        input_spec.tensor_meta.dtype,
+    )
+
+    grad_input_spec = input_spec
+    grad_weight_spec = DTensorSpec.from_dim_map(
+        input_spec.mesh,
+        [-1, -1, -1, -1],
+        [0],
+        tensor_meta=weight_tensor_meta,
+    )
+    grad_bias_spec = DTensorSpec.from_dim_map(
+        input_spec.mesh,
+        [-1],
+        [0],
+        tensor_meta=bias_tensor_meta,
+    )
+    return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec])
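+
+
+# [editor's note] Illustrative sketch, not part of the original file: the output
+# spatial size computed in convolution_rules above is the standard conv formula
+#
+#   H_out = (H_in + 2 * padding - dilation * (kH - 1) - 1) // stride + 1
+#
+# e.g. H_in = 32, padding = 1, dilation = 1, kH = 3, stride = 1 -> H_out = 32.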
diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_einsum_strategy.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_einsum_strategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc3227600b35d27011b0cce4291434a633b3bd40
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_einsum_strategy.py
@@ -0,0 +1,181 @@
+import itertools
+from dataclasses import dataclass
+from typing import List, Set, Tuple
+
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy
+from torch.distributed.tensor.placement_types import (
+    Partial,
+    Placement,
+    Replicate,
+    Shard,
+)
+
+
+@dataclass
+class EinsumDims:
+    contracting_dims: List[str]
+    batch_dims: List[str]
+    lhs_out_only_dims: List[str]
+    rhs_out_only_dims: List[str]
+
+    @classmethod
+    def parse_equation(cls, equation: str) -> Tuple[List[str], str]:
+        """
+        Parse the einsum equation str to input dim chars and output dim char
+        """
+        # parse einop equation and extract arg specs
+        inputs, outputs = equation.split("->")
+        input_dims, output_dims = inputs.split(","), outputs.split(",")
+
+        # NOTE: only support at most two inputs, and single output
+        # extend to support more inputs if needed in future
+        assert len(input_dims) <= 2, "Only support at most two inputs"
+        assert len(output_dims) == 1, "Only support single output"
+        output_dim = output_dims[0]
+        return input_dims, output_dim
+
+    @classmethod
+    def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims":
+        """
+        Parse the dims and extract the contracting, batch, and free dimensions
+        for the left and right hand sides.
+        """
+        dim_char_set: Set[str] = set()
+        for input_dim in input_dims:
+            dim_char_set.update(input_dim)
+
+        # get a deterministic order of all dim chars
+        all_dim_chars = sorted(dim_char_set)
+
+        # parse input and output dimensions
+        lhs_out_only_dims, rhs_out_only_dims = [], []
+        batch_dims, contracting_dims = [], []
+
+        for dim_char in all_dim_chars:
+            if dim_char not in output_dim:
+                contracting_dims.append(dim_char)
+            else:
+                is_batch_dim = True
+                for input_dim in input_dims:
+                    is_batch_dim = is_batch_dim and dim_char in input_dim
+
+                if is_batch_dim:
+                    batch_dims.append(dim_char)
+                else:
+                    assert (
+                        len(input_dims) == 2
+                    ), "free dimension only supported for two inputs!"
+                    lhs, rhs = input_dims
+                    if dim_char in lhs:
+                        lhs_out_only_dims.append(dim_char)
+                    elif dim_char in rhs:
+                        rhs_out_only_dims.append(dim_char)
+                    else:
+                        raise RuntimeError("Invalid dimension character")
+
+        return cls(
+            contracting_dims=contracting_dims,
+            batch_dims=batch_dims,
+            lhs_out_only_dims=lhs_out_only_dims,
+            rhs_out_only_dims=rhs_out_only_dims,
+        )
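+
+
+# [editor's note] Illustrative sketch, not part of the original file: EinsumDims
+# classification for the matmul equation "mk,kn->mn":
+#
+#   input_dims, output_dim = EinsumDims.parse_equation("mk,kn->mn")
+#   edims = EinsumDims.parse_dims(input_dims, output_dim)
+#   # edims.contracting_dims  == ["k"]  (appears in inputs but not the output)
+#   # edims.batch_dims        == []    (would have to appear in every input)
+#   # edims.lhs_out_only_dims == ["m"], edims.rhs_out_only_dims == ["n"]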
+
+
+def gen_einsum_strategies(
+    equation: str,
+    mesh: DeviceMesh,
+    *,
+    linearity: bool = False,
+) -> OpStrategy:
+    """
+    Generate a strategy list for the ops that follow einsum style notation.
+    """
+    # parse einop equation and extract dims
+    input_dims, output_dim = EinsumDims.parse_equation(equation)
+    edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+    all_mesh_dim_strategies = []
+
+    # generate strategies for each mesh dim
+    for mesh_dim in range(mesh.ndim):
+        mesh_dim_strategies = []
+
+        # placement list stores placements of [output, input1, input2, ...]
+        # first we always have replicate all for inputs and output
+        placement_list: List[Placement] = [Replicate()] * (len(input_dims) + 1)
+        mesh_dim_strategies.append(placement_list)
+
+        if mesh.size(mesh_dim) <= 1:
+            # only replicate strategy for mesh dim with size 1
+            # TODO: see if this is valid for the submesh case
+            continue
+
+        # split batch dim
+        for batch_dim in edims.batch_dims:
+            output_batch_dim = output_dim.index(batch_dim)
+            placement_list = [Shard(output_batch_dim)]
+            for input_dim in input_dims:
+                input_batch_dim = input_dim.index(batch_dim)
+                placement_list.append(Shard(input_batch_dim))
+
+            mesh_dim_strategies.append(placement_list)
+
+        # split contracting dim
+        for contracting_dim in edims.contracting_dims:
+            placement_list = [Partial()]
+            for input_dim in input_dims:
+                input_contracting_dim = input_dim.index(contracting_dim)
+                placement_list.append(Shard(input_contracting_dim))
+
+            mesh_dim_strategies.append(placement_list)
+
+        # split lhs free dim
+        for lhs_dim in edims.lhs_out_only_dims:
+            lhs_free_dim = output_dim.index(lhs_dim)
+            # this means split the lhs input and output
+            # i.e. S(0), R -> S(0)
+            lhs_placement_list: List[Placement] = [
+                Shard(lhs_free_dim),
+                Shard(lhs_free_dim),
+                Replicate(),
+            ]
+            mesh_dim_strategies.append(lhs_placement_list)
+
+        # split rhs free dim
+        for rhs_dim in edims.rhs_out_only_dims:
+            rhs_free_dim = output_dim.index(rhs_dim)
+            rhs_placement_list: List[Placement] = [
+                Shard(rhs_free_dim),
+                Replicate(),
+                Shard(rhs_free_dim),
+            ]
+            mesh_dim_strategies.append(rhs_placement_list)
+
+        # linearity strategy
+        if linearity:
+            linearity_placement_list: List[Placement] = [Partial()]
+            for input_dim in input_dims:
+                linearity_placement_list.append(Partial())
+            mesh_dim_strategies.append(linearity_placement_list)
+
+        all_mesh_dim_strategies.append(mesh_dim_strategies)
+
+    # generate strategies for entire mesh
+    strategy_combs = itertools.product(*all_mesh_dim_strategies)
+
+    # TODO: filter out invalid strategies; at this point we generate
+    # all possible strategies without considering whether the tensor
+    # dim could be sharded or not. We would need to filter out invalid
+    # strategies based on the actual tensor shape
+    # (i.e. for Shard, tensor dim size must > mesh size)
+    all_strategies = []
+    for strategy_comb in strategy_combs:
+        spec_list = []
+        for specs in zip(*strategy_comb):
+            spec_list.append(DTensorSpec(mesh, tuple(specs)))
+        strat = PlacementStrategy(output_specs=spec_list[0], input_specs=spec_list[1:])
+        all_strategies.append(strat)
+
+    return OpStrategy(all_strategies)
diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_embedding_ops.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_embedding_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae333b800ffcb8dd3b3a827aedb7efcb427be61e
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_embedding_ops.py
@@ -0,0 +1,274 @@
+# mypy: allow-untyped-decorators
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# implement embedding related ops for distributed tensor
+from dataclasses import dataclass, field
+from typing import cast, Optional
+
+import torch
+import torch.distributed._functional_collectives as funcol
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpStrategy,
+    PlacementList,
+    StrategyType,
+)
+from torch.distributed.tensor._ops.utils import (
+    expand_to_full_mesh_op_strategy,
+    register_op_strategy,
+)
+from torch.distributed.tensor.placement_types import (
+    Partial,
+    Placement,
+    Replicate,
+    Shard,
+)
+
+
+aten = torch.ops.aten
+
+
+@dataclass
+class MaskBuffer:
+    data: Optional[torch.Tensor] = None
+    # refcount allows shared usage of the MaskBuffer, as long as all users have the same data
+    refcount: int = 0
+
+    def materialize_mask(self, mask):
+        if self.refcount == 0:
+            self.data = mask
+        else:
+            assert self.data is not None
+            if not torch.equal(self.data, mask):
+                raise RuntimeError(
+                    "MaskBuffer has been materialized with conflicting data"
+                )
+        self.refcount += 1
+
+    def release_mask(self):
+        if self.refcount == 0 or self.data is None:
+            raise RuntimeError("MaskBuffer has not been materialized")
+        self.refcount -= 1
+        if self.refcount == 0:
+            self.data = None
+
+    def apply_mask(self, tensor):
+        if self.refcount == 0 or self.data is None:
+            raise RuntimeError("MaskBuffer has not been materialized")
+
+        # NOTE: _MaskPartial is being used by the embedding op and the gather op.
+        # For gather, the mask has the same dimension as the output tensor, whereas
+        # the output of the embedding op has an additional dimension compared to the input,
+        # hence the output masking logic below has two different cases.
+        if tensor.ndim == self.data.ndim:
+            tensor[self.data] = 0.0
+        else:
+            tensor[self.data, :] = 0.0
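+
+
+# [editor's note] Illustrative sketch, not part of the original file: the
+# MaskBuffer refcounting contract. Every materialize must be paired with a
+# release, and concurrent users must hand in identical masks:
+#
+#   buf = MaskBuffer()
+#   buf.materialize_mask(mask)   # refcount 0 -> 1, stores the mask
+#   buf.materialize_mask(mask)   # refcount 1 -> 2, verifies identical data
+#   buf.apply_mask(partial_out)  # zeroes the positions selected by the mask
+#   buf.release_mask()           # refcount 2 -> 1
+#   buf.release_mask()           # refcount 1 -> 0, drops the stored mask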
+ """ + + mask_buffer: MaskBuffer = field(default_factory=MaskBuffer) + + # required fields for computing the local offset and deriving the mask + offset_shape: Optional[torch.Size] = None + offset_dim: int = 0 + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # override parent logic to perform partial mask for embedding + num_chunks = mesh.size(mesh_dim) + # get local shard size and offset on the embedding_dim + assert ( + self.offset_shape is not None + ), "offset_shape needs to be set for _MaskPartial" + local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim( + self.offset_shape[self.offset_dim], + num_chunks, + mesh.get_local_rank(mesh_dim), + return_offset=True, + ) + # Build the input mask and save it for the current partial placement + # this is so that the output of embedding op can reuse the same partial + # placement saved mask to perform mask + reduction + mask = (tensor < local_offset_on_dim) | ( + tensor >= local_offset_on_dim + local_shard_size + ) + # mask the input tensor + masked_tensor = tensor.clone() - local_offset_on_dim + masked_tensor[mask] = 0 + # materialize the mask buffer to be used for reduction + self.mask_buffer.materialize_mask(mask) + return masked_tensor + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # by the time we ned reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # perform sum reduction + return funcol.all_reduce( + tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) + ) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # by the time we ned reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # call reduce_shard_tensor of the shard_spec. + shard_spec = cast(Shard, shard_spec) + return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _MaskPartial): + return False + + # if either data is not None, we invalidate the sharding cache, as this indicates + # the current MaskPartial placement is still in use and should not be used for cache hit. + if self.mask_buffer.data is not None or other.mask_buffer.data is not None: + return False + + return ( + self.reduce_op == other.reduce_op + and self.offset_shape == other.offset_shape + and self.offset_dim == other.offset_dim + ) + + def __hash__(self) -> int: + return 1 + hash( + ( + self.reduce_op, + self.offset_shape, + self.offset_dim, + ) + ) + + def __repr__(self) -> str: + """ + machine readable representation of the MaskPartial placement + """ + return f"_MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})" + + def __str__(self) -> str: + """ + human readable representation of the MaskPartial placement + """ + return "MaskP" + + +@register_op_strategy(aten.embedding.default) +def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + """ + This strategy handles embedding op. 
+
+
+@register_op_strategy(aten.embedding.default)
+def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    """
+    This strategy handles the embedding op. We have two possible embedding shardings:
+    rowwise and colwise
+    """
+    weight_strategy = cast(OpStrategy, op_schema.args_schema[0])
+    indices_strategy = cast(OpStrategy, op_schema.args_schema[1])
+
+    weight_shape = weight_strategy.shape
+    indices_shape = indices_strategy.shape
+    output_emd_dim = len(indices_shape)
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [output, weight, input_indices]
+    # first we always have replicate all for inputs and output
+    all_replicate: PlacementList = [Replicate()] * 3
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # colwise sharding, output shard on last dim, weight shard on dim 1, input replicate
+    colwise_sharding: PlacementList = [Shard(output_emd_dim), Shard(1), Replicate()]
+    single_mesh_dim_strategies.append(colwise_sharding)
+
+    # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial
+    embedding_partial_placement = _MaskPartial(offset_shape=weight_shape, offset_dim=0)
+
+    # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask
+    # that is generated from the input indices and use it for the output reduction
+    rowwise_sharding: PlacementList = [
+        embedding_partial_placement,
+        Shard(0),
+        embedding_partial_placement,
+    ]
+    single_mesh_dim_strategies.append(rowwise_sharding)
+
+    # batch dim sharding, weight replicated, input can shard on any dim, output follows input
+    for input_dim in range(len(indices_shape)):
+        batch_sharding: PlacementList = [
+            Shard(input_dim),
+            Replicate(),
+            Shard(input_dim),
+        ]
+        single_mesh_dim_strategies.append(batch_sharding)
+
+    return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies)
+
+
+@register_op_strategy(aten.embedding_dense_backward.default)
+def embedding_dense_backward_strategy(
+    mesh: DeviceMesh, op_schema: OpSchema
+) -> StrategyType:
+    """
+    This strategy handles the embedding dense backward op.
We have two possible embedding shardings: + rowwise and colwise + """ + grad_out_strategy = cast(OpStrategy, op_schema.args_schema[0]) + indices_strategy = cast(OpStrategy, op_schema.args_schema[1]) + + grad_out_shape = grad_out_strategy.shape + indices_shape = indices_strategy.shape + grad_out_ndim = len(grad_out_shape) + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, weight, input_indices] + # first we always have replicate all for inputs and output + all_replicate: PlacementList = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # colwise sharding backward, grad_out shard on last dim, input replicate, + # weight grad shard colwise + colwise_sharding: PlacementList = [Shard(1), Shard(grad_out_ndim - 1), Replicate()] + single_mesh_dim_strategies.append(colwise_sharding) + + # batch dim sharding, weight replicated, grad_out/input have same sharding + # that can shard on any dim, weight grad partial + for input_dim in range(len(indices_shape)): + batch_sharding: PlacementList = [Partial(), Shard(input_dim), Shard(input_dim)] + single_mesh_dim_strategies.append(batch_sharding) + + # grad_out partial, input replicate, weight grad keep partial + partial_sharding: PlacementList = [Partial(), Partial(), Replicate()] + single_mesh_dim_strategies.append(partial_sharding) + + return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies) diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_experimental_ops.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_experimental_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ab5cc3aeeb020dbee5b71ae76c9d977fcdb0a2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_experimental_ops.py @@ -0,0 +1,28 @@ +# mypy: allow-untyped-decorators +# Copyright (c) Meta Platforms, Inc. and affiliates +# implement matrix related ops for distributed tensor + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementStrategy, + StrategyType, +) +from torch.distributed.tensor._ops.utils import register_op_strategy +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import Replicate + + +aten = torch.ops.aten + + +@register_op_strategy(aten.slice_backward.default) +def slice_backward_rules(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + """ + slice_backward is a new_zeros + slice_scatter, we only allow replication + on the input/output for now since new_zeros would produce replication + """ + replicate_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) + return OpStrategy([PlacementStrategy(replicate_spec)]) diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_math_ops.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_math_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4905c3389185954ffbe1caf67d2c8d1803a9b749 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_math_ops.py @@ -0,0 +1,1058 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. 
and affiliates
+import math
+from dataclasses import dataclass
+from enum import Enum
+from typing import cast, List, Optional, Sequence, Tuple, Union
+
+import torch
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpStrategy,
+    PlacementList,
+    PlacementStrategy,
+    RuntimeSchemaInfo,
+    TupleStrategy,
+)
+from torch.distributed.tensor._ops.utils import (
+    as_list,
+    expand_to_full_mesh_op_strategy,
+    generate_redistribute_costs,
+    is_tensor_evenly_shardable,
+    normalize_dim,
+    normalize_dims,
+    register_op_strategy,
+)
+from torch.distributed.tensor._utils import normalize_to_torch_size
+from torch.distributed.tensor.placement_types import (
+    Partial,
+    Placement,
+    Replicate,
+    Shard,
+)
+
+
+aten = torch.ops.aten
+
+
+class Reduction(Enum):
+    NONE = 0
+    MEAN = 1
+    SUM = 2
+
+
+@dataclass(frozen=True)
+class NormReduction:
+    norm_type: Union[int, float, str]
+
+
+ReductionOpType = Union[NormReduction, str]
+
+
+@dataclass(frozen=True)
+class _NormPartial(Partial):
+    """
+    This placement is used for partial vector norm.
+
+    For p-norms (where p is not inf or -inf), the p-norm over n elements computes
+        (sum_i x_i^p)^(1/p)
+    where the sum is from i=1 to n. The reduction op is the p-norm itself.
+    For example, consider 2 ranks, a (4,) tensor sharded on dim-0, and 2-norm:
+        Rank 0: [t1, t2] | Rank 1: [t3, t4]
+    After computing 2-norm per gradient (partial placement):
+        Rank 0: [sqrt(t1^2 + t2^2)] | Rank 1: [sqrt(t3^2 + t4^2)]
+    Converting from partial to replicate wants to ultimately get:
+        Rank 0/1: [sqrt(t1^2 + t2^2 + t3^2 + t4^2)]
+    This can be achieved by computing 2-norm on each rank's result. This holds
+    similarly for inf and -inf norm. For 0-norm, the reduction op is sum.
+    """
+
+    norm_type: Union[int, float, str] = 2
+
+    def __post_init__(self):
+        """Set the appropriate reduce op based on the norm type."""
+        # Use `object.__setattr__` to bypass frozen checks
+        if self.norm_type in (float("inf"), "inf"):
+            object.__setattr__(self, "reduce_op", "max")
+        elif self.norm_type in (float("-inf"), "-inf"):
+            object.__setattr__(self, "reduce_op", "min")
+        elif isinstance(self.norm_type, (int, float)):
+            object.__setattr__(self, "reduce_op", "sum")
+        else:
+            raise NotImplementedError(f"Unsupported norm type: {self.norm_type}")
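+
+    # [editor's note] Illustrative numeric check, not part of the original file,
+    # of the replicate -> partial conversion derived in _partition_value below,
+    # for the 2-norm on a mesh dim of size d = 4, i.e. f(x) = x / 4**(1/2):
+    #
+    #   each rank holds f(t) = t / 2
+    #   sum-reduce of f(t)**2 over 4 ranks == 4 * t**2 / 4 == t**2
+    #   post-transform: (t**2) ** 0.5 == t   (norm values are non-negative)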
+ """ + if self.reduce_op in ("max", "min"): + return tensor + elif self.reduce_op == "sum": + if self.norm_type == 0: + raise NotImplementedError(f"Unsupported norm type:: {self.norm_type}") + elif self.norm_type == 1: + return tensor / mesh.size(mesh_dim) + assert isinstance(self.norm_type, (int, float)) + return tensor / math.pow(mesh.size(mesh_dim), 1 / self.norm_type) + raise NotImplementedError(self.reduce_op) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + assert isinstance(shard_spec, Shard), f"{shard_spec}" + tensor = self._pre_reduce_transform(tensor) + reduced_tensor = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec) + return self._post_reduce_transform(reduced_tensor) + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + tensor = self._pre_reduce_transform(tensor) + reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim) + return self._post_reduce_transform(reduced_tensor) + + def _pre_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor: + if self.reduce_op == "sum": + assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}" + if self.norm_type != 0 and self.norm_type != 1: + return tensor**self.norm_type + return tensor + + def _post_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor: + if self.reduce_op == "sum": + assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}" + if self.norm_type != 0 and self.norm_type != 1: + return tensor ** (1.0 / self.norm_type) + return tensor + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _NormPartial): + return False + return self.norm_type == other.norm_type + + def __hash__(self) -> int: + return 1 + hash(self.norm_type) + + +def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[List[int]]: + if dims_arg is None: + return None + dims = cast(List[int], as_list(dims_arg)) + dims = cast(List[int], normalize_dims(dims, ndim)) + empty_dims = [[0], [-1], []] + if ndim == 0 and dims_arg in empty_dims: + return None + return dims + + +def _infer_reduce_dims_map( + reduction_dims: List[int], input_ndim: int, keep_dim=False +) -> List[int]: + reduction_dims_map = [] + new_dim_count = 0 + for input_dim in range(input_ndim): + if input_dim in reduction_dims and not keep_dim: + # if input dim in reduction dims, mark it as -1 + reduction_dims_map.append(-1) + else: + # otherwise mark it as the new dim + reduction_dims_map.append(new_dim_count) + new_dim_count += 1 + + return reduction_dims_map + + +def _replicate_dims_start_at( + placements: Sequence[Placement], start_dim: int = 0 +) -> Tuple[Placement, ...]: + new_placements: List[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +# return new_placements which align with placements but skip the skipped_dim +def _skip_dim( + placements: Tuple[Placement, ...], skipped_dim: int +) -> Tuple[Placement, ...]: + new_placements: List[Placement] = [] + for p in placements: + if isinstance(p, Shard) and p.dim >= skipped_dim: + new_placements.append(Shard(p.dim - 1)) + else: + new_placements.append(p) + return tuple(new_placements) + + +def replicate_reduction_dims( + placements: Tuple[Placement, ...], reduction_dims: List[int] +) -> Tuple[Placement, ...]: + 
+    # replicate the reduction dims if not reduction_linear
+    new_placements: List[Placement] = []
+
+    for p in placements:
+        if p.is_partial():
+            new_placements.append(Replicate())
+        elif isinstance(p, Shard) and p.dim in reduction_dims:
+            new_placements.append(Replicate())
+        else:
+            new_placements.append(p)
+
+    return tuple(new_placements)
+
+
+def map_placements_after_reduction(
+    placements: Tuple[Placement, ...],
+    reduction_dims: List[int],
+    reduction_dims_map: List[int],
+    reduction_op: ReductionOpType,
+) -> Tuple[Placement, ...]:
+    """
+    Map each placement based on the output shape after reduction.
+    """
+    new_placements: List[Placement] = []
+    for placement in placements:
+        if isinstance(placement, (Replicate, Partial)):
+            new_placements.append(placement)
+        else:
+            assert isinstance(placement, Shard)
+            shard_dim = placement.dim
+            new_shard_dim = reduction_dims_map[shard_dim]
+            if new_shard_dim == -1 or shard_dim in reduction_dims:
+                # if the new shard dim collapsed or it's in the reduction dims
+                # (i.e. the keep_dim=True case), we generate a partial placement
+                new_placements.append(get_placement_from_reduction_op(reduction_op))
+            else:
+                new_placements.append(Shard(new_shard_dim))
+    return tuple(new_placements)
+
+
+def get_placement_from_reduction_op(reduction_op: ReductionOpType) -> Placement:
+    if isinstance(reduction_op, NormReduction):
+        return _NormPartial(norm_type=reduction_op.norm_type)
+    return Partial(reduction_op)
+
+
+def common_reduction_strategy(
+    mesh: DeviceMesh,
+    input_strategy: OpStrategy,
+    reduce_dims: List[int],
+    keep_dim: bool = False,
+    reduction_linear: bool = True,
+    reduction_op: ReductionOpType = "sum",
+) -> OpStrategy:
+    """
+    reduction_linear means that the reduction `f` follows this rule:
+    f([f(a), f(b)]) = f([a, b])
+
+    Reduction linearity is a superset of ordinary linearity.
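+
+    For example (hypothetical tensors, illustration only):
+
+        a, b = torch.arange(2.0), torch.arange(3.0)
+        # sum is reduction linear: local sums can be combined afterwards
+        assert torch.stack([a.sum(), b.sum()]).sum() == torch.cat([a, b]).sum()
+
+    var is not reduction linear, which is why var_reduction_strategy below
+    passes reduction_linear=False and first replicates the reduction dims.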
+ """ + # by default follow reduction input strategy + reduction_strategy = OpStrategy([]) + + for strtg in input_strategy.strategies: + if not reduction_linear: + # input placements for this strategy should clear out pending sum and sharding + # on the reduction dimension + input_placements = replicate_reduction_dims( + strtg.output_spec.placements, reduce_dims + ) + else: + input_placements = strtg.output_spec.placements + + input_spec = DTensorSpec( + mesh=mesh, + placements=input_placements, + tensor_meta=strtg.output_spec.tensor_meta, + ) + + reduce_dims_map = _infer_reduce_dims_map(reduce_dims, input_spec.ndim, keep_dim) + out_placements = map_placements_after_reduction( + input_spec.placements, reduce_dims, reduce_dims_map, reduction_op + ) + redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)] + reduction_strategy.strategies.append( + PlacementStrategy( + output_specs=DTensorSpec( + mesh=mesh, + placements=out_placements, + ), + input_specs=(input_spec,), + redistribute_cost=redistribute_cost, + ) + ) + + return reduction_strategy + + +LINEAR_REDUCTION_OP_MAP = { + aten.all.default: "sum", + aten.all.dim: "sum", + aten.sum.default: "sum", + aten.sum.dim_IntList: "sum", + aten.prod.default: "product", + aten.prod.dim_int: "product", + aten.prod.int_out: "product", + aten.mean.default: "avg", + aten.mean.dim: "avg", + aten.mean.out: "avg", + aten.max.default: "max", + aten.max.dim: "max", + aten.max.out: "max", + aten.min.default: "min", + aten.min.dim: "min", + aten.min.out: "min", + aten.any.default: "sum", + aten.any.dim: "sum", + aten.any.out: "sum", +} + + +@register_op_strategy( + list(LINEAR_REDUCTION_OP_MAP.keys()), schema_info=RuntimeSchemaInfo(1) +) +def linear_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + assert isinstance(input_strategy, OpStrategy) + dims = None + if len(op_schema.args_schema) > 1: + dims = _infer_reduction_dims(args_schema[1], input_strategy.ndim) + + reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims + + keep_dim = len(op_schema.args_schema) > 2 and bool(op_schema.args_schema[2]) + reduction_op = LINEAR_REDUCTION_OP_MAP[op_schema.op] + return common_reduction_strategy( + mesh, + input_strategy, + reduce_dims, + keep_dim=keep_dim, + reduction_linear=True, + reduction_op=reduction_op, + ) + + +@register_op_strategy( + [aten.var.correction, aten.var.correction_out], + schema_info=RuntimeSchemaInfo(1, ["keepdim"]), +) +def var_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + assert isinstance(input_strategy, OpStrategy) + dims = None + if len(op_schema.args_schema) > 1: + dims = _infer_reduction_dims(args_schema[1], input_strategy.ndim) + + reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims + + keep_dim = cast(bool, op_schema.kwargs_schema.get("keepdim", False)) + return common_reduction_strategy( + mesh, input_strategy, reduce_dims, keep_dim=keep_dim, reduction_linear=False + ) + + +@register_op_strategy( + [aten.linalg_vector_norm.default], schema_info=RuntimeSchemaInfo(1) +) +def vector_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + assert isinstance(input_strategy, OpStrategy) + norm_type = args_schema[1] if len(args_schema) > 1 else 2 + assert isinstance(norm_type, (int, float, str)), 
f"{norm_type}" + dim = args_schema[2] if len(args_schema) > 2 else None + keepdim = args_schema[3] if len(args_schema) > 3 else False + dims = _infer_reduction_dims(dim, input_strategy.ndim) + reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims + return common_reduction_strategy( + mesh, + input_strategy, + reduce_dims, + keep_dim=cast(bool, keepdim), + reduction_linear=True, + reduction_op=NormReduction(norm_type), + ) + + +@register_op_strategy( + [aten._foreach_norm.Scalar], schema_info=RuntimeSchemaInfo(1, needs_pytree=True) +) +def foreach_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> TupleStrategy: + args_schema = op_schema.args_schema + input_tuple_strategy = args_schema[0] + assert isinstance(input_tuple_strategy, TupleStrategy) + norm_type = args_schema[1] if len(args_schema) > 1 else 2 + assert isinstance(norm_type, (int, float, str)), f"{norm_type}" + output_tuple_strategy_childs: List[OpStrategy] = [] + for op_strategy in input_tuple_strategy.childs: + assert isinstance(op_strategy, OpStrategy), f"{op_strategy}" + reduce_dims = list(range(op_strategy.ndim)) + output_strategy = common_reduction_strategy( + mesh, + op_strategy, + reduce_dims, + reduction_linear=True, + reduction_op=NormReduction(norm_type), + ) + output_tuple_strategy_childs.append(output_strategy) + return TupleStrategy(output_tuple_strategy_childs) + + +@register_op_strategy( + [ + aten._linalg_svd.default, + aten.linalg_qr.default, + # TODO: The diagonal ops can have an improved sharding strategy for + # shard placements that does not require redistributing to replicate. + aten.diagonal_copy.default, + aten.diag_embed.default, + aten.diag.default, + aten.diagonal.default, + aten.tril.default, + aten.triu.default, + aten._linalg_eigh.default, + ], + schema_info=RuntimeSchemaInfo(1), +) +def linalg_replicate_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + """ + Since we do not have a simple way to compute some linear algebra operations + like SVD or QR decomposition, always fall back to replicate. 
+ """ + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + assert isinstance(input_strategy, OpStrategy), f"{input_strategy}" + output_strategies: List[PlacementStrategy] = [] + for placement_strategy in input_strategy.strategies: + replicate_placements = tuple(Replicate() for _ in range(mesh.ndim)) + replicate_spec = DTensorSpec( + mesh=mesh, + placements=replicate_placements, + tensor_meta=placement_strategy.output_spec.tensor_meta, + ) + redistribute_cost = [ + generate_redistribute_costs(input_strategy, replicate_spec) + ] + replicate_strategy = PlacementStrategy( + output_specs=replicate_spec, + input_specs=(replicate_spec,), + redistribute_cost=redistribute_cost, + ) + output_strategies.append(replicate_strategy) + return OpStrategy(output_strategies) + + +@register_op_strategy( + [aten._log_softmax.default, aten._softmax.default, aten._safe_softmax.default], + schema_info=RuntimeSchemaInfo(1), +) +def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + input_strategy, softmax_dim, *_ = op_schema.args_schema + input_strategy = cast(OpStrategy, input_strategy) + softmax_dim = cast(int, softmax_dim) + softmax_dim = normalize_dim(softmax_dim, input_strategy.ndim) + + output_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + redistribute_costs = [] + input_src_spec = input_placement_strategy.output_spec + + # make sure input is replicated along the softmax dim + input_target_spec = DTensorSpec( + mesh=mesh, + placements=replicate_reduction_dims( + input_src_spec.placements, [softmax_dim] + ), + tensor_meta=input_src_spec.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_target_spec) + ) + output_target_spec = input_target_spec + output_strategy.strategies.append( + PlacementStrategy( + output_specs=output_target_spec, + input_specs=[input_target_spec], + redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +@register_op_strategy( + [ + aten._log_softmax_backward_data.default, + aten._softmax_backward_data.default, + ], + schema_info=RuntimeSchemaInfo(2), +) +def softmax_backward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + grad_out_strategy, out_strategy, softmax_dim, _ = op_schema.args_schema + grad_out_strategy = cast(OpStrategy, grad_out_strategy) + out_strategy = cast(OpStrategy, out_strategy) + softmax_dim = cast(int, softmax_dim) + softmax_dim = normalize_dim(softmax_dim, grad_out_strategy.ndim) + + grad_in_strategy = OpStrategy([]) + for grad_out_placement_strat, out_placement_strat in zip( + grad_out_strategy.strategies, out_strategy.strategies + ): + # follow the sharding of the grad_out or out depending on which has more shards + grad_out_src_spec = grad_out_placement_strat.output_spec + out_src_spec = out_placement_strat.output_spec + src_spec = ( + grad_out_src_spec + if grad_out_src_spec.num_shards >= out_src_spec.num_shards + else out_src_spec + ) + + # make sure inputs are replicated along the softmax dim + tgt_spec = DTensorSpec( + mesh=mesh, + placements=replicate_reduction_dims(src_spec.placements, [softmax_dim]), + ) + redist_grad_out_cost = generate_redistribute_costs(grad_out_strategy, tgt_spec) + redist_out_cost = generate_redistribute_costs(out_strategy, tgt_spec) + grad_in_strategy.strategies.append( + PlacementStrategy( + output_specs=tgt_spec, + redistribute_cost=[redist_grad_out_cost, redist_out_cost], + ) + ) + + return grad_in_strategy + + +@register_op_strategy( + 
+    [aten.nll_loss_forward.default, aten.nll_loss2d_forward.default],
+    schema_info=RuntimeSchemaInfo(3),
+)
+def nll_loss_forward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    assert len(op_schema.args_schema) == 5
+    (
+        input_strategy,
+        target_strategy,
+        weight_strategy,
+        reduction,
+        _,
+    ) = op_schema.args_schema
+    input_strategy = cast(OpStrategy, input_strategy)
+    target_strategy = cast(OpStrategy, target_strategy)
+    reduction = cast(int, reduction)
+
+    input_shape = input_strategy.shape
+    channel_dim = 1 if len(input_shape) >= 2 else 0
+
+    output_strategy = OpStrategy([])
+    for idx, input_placement_strategy in enumerate(input_strategy.strategies):
+        op_args_target_specs = []
+        redistribute_costs = []
+
+        # make sure input is replicated along the channel dim
+        input_src_spec = input_placement_strategy.output_spec
+        input_expected_spec = DTensorSpec(
+            mesh=mesh,
+            placements=replicate_reduction_dims(
+                input_src_spec.placements, [channel_dim]
+            ),
+            tensor_meta=input_src_spec.tensor_meta,
+        )
+        op_args_target_specs.append(input_expected_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(input_strategy, input_expected_spec)
+        )
+
+        # target doesn't have the channel dim, and it follows input on other dims
+        target_src_spec = target_strategy.strategies[idx].output_spec
+        target_expected_spec = DTensorSpec(
+            mesh=mesh,
+            placements=_skip_dim(input_expected_spec.placements, channel_dim),
+            tensor_meta=target_src_spec.tensor_meta,
+        )
+        op_args_target_specs.append(target_expected_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(target_strategy, target_expected_spec)
+        )
+
+        # the weight tensor, if given, has to be a Tensor of size
+        # input_shape[channel_dim]; make sure it is replicated
+        if weight_strategy is not None:
+            assert isinstance(weight_strategy, OpStrategy)
+            weight_src_spec = weight_strategy.strategies[idx].output_spec
+            weight_expected_spec = DTensorSpec(
+                mesh=mesh,
+                placements=_replicate_dims_start_at(weight_src_spec.placements),
+                tensor_meta=weight_src_spec.tensor_meta,
+            )
+            op_args_target_specs.append(weight_expected_spec)
+            redistribute_costs.append(
+                generate_redistribute_costs(weight_strategy, weight_expected_spec)
+            )
+
+        if reduction == Reduction.NONE.value:
+            output_expected_spec = target_expected_spec
+            total_weight_expected_spec = DTensorSpec(
+                mesh=mesh, placements=tuple([Replicate()] * mesh.ndim)
+            )
+        else:
+            if reduction == Reduction.MEAN.value:
+                reduction_op = "avg"
+                if not is_tensor_evenly_shardable(
+                    target_expected_spec.shape, target_expected_spec
+                ):
+                    raise ValueError(
+                        "The intermediate results of nll_loss cannot be evenly "
+                        "sharded, resulting in a biased mean result."
+ ) + else: # reduction == Reduction.SUM.value: + reduction_op = "sum" + reduce_dims = list(range(target_expected_spec.ndim)) + reduce_dims_map = _infer_reduce_dims_map( + reduce_dims, target_expected_spec.ndim, keep_dim=False + ) + out_placements = map_placements_after_reduction( + target_expected_spec.placements, + reduce_dims, + reduce_dims_map, + reduction_op, + ) + output_expected_spec = DTensorSpec( + mesh=mesh, + placements=out_placements, + ) + + # whether reduction is sum or mean, the total weight has to be summed up if not replicated + total_weight_placements = map_placements_after_reduction( + target_expected_spec.placements, + reduce_dims, + reduce_dims_map, + "sum", + ) + total_weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=total_weight_placements, + ) + + output_strategy.strategies.append( + PlacementStrategy( + output_specs=(output_expected_spec, total_weight_expected_spec), + input_specs=op_args_target_specs, + redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +@register_op_strategy( + [aten.nll_loss_backward.default, aten.nll_loss2d_backward.default], + schema_info=RuntimeSchemaInfo(4), +) +def nll_loss_backward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + assert len(op_schema.args_schema) == 7 + ( + grad_out_strategy, + input_strategy, + target_strategy, + weight_strategy, + reduction, + _, + total_weight_strategy, + ) = op_schema.args_schema + grad_out_strategy = cast(OpStrategy, grad_out_strategy) + input_strategy = cast(OpStrategy, input_strategy) + target_strategy = cast(OpStrategy, target_strategy) + reduction = cast(int, reduction) + total_weight_strategy = cast(OpStrategy, total_weight_strategy) + + input_shape = input_strategy.shape + channel_dim = 1 if len(input_shape) >= 2 else 0 + + grad_in_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + op_args_target_specs = [] + redistribute_costs = [] + + # make sure input is replicated along the channel dim + input_src_spec = input_placement_strategy.output_spec + input_expected_spec = DTensorSpec( + mesh=mesh, + placements=replicate_reduction_dims( + input_src_spec.placements, [channel_dim] + ), + tensor_meta=input_src_spec.tensor_meta, + ) + op_args_target_specs.append(input_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_expected_spec) + ) + + # target doesn't have channel dim, and it follows input on other dims + target_src_spec = target_strategy.strategies[idx].output_spec + target_expected_spec = DTensorSpec( + mesh=mesh, + placements=_skip_dim(input_expected_spec.placements, channel_dim), + tensor_meta=target_src_spec.tensor_meta, + ) + op_args_target_specs.append(target_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(target_strategy, target_expected_spec) + ) + + # grad_out follows target if there is no reduction; + # otherwise, it should be a replicated scalar. 
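+        # For illustration (hypothetical shapes, not from the original source):
+        # with reduction="none", grad_out has target's per-element shape, e.g.
+        # (N,) for an (N, C) input, so it can share target's sharding; with
+        # "mean"/"sum" it is a 0-dim scalar, so only Replicate() is sensible.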
+ grad_out_src_spec = grad_out_strategy.strategies[idx].output_spec + if reduction == Reduction.NONE.value: + grad_out_expected_spec = target_expected_spec + else: + grad_out_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(grad_out_src_spec.placements), + tensor_meta=grad_out_src_spec.tensor_meta, + ) + op_args_target_specs.insert(0, grad_out_expected_spec) + redistribute_costs.insert( + 0, generate_redistribute_costs(grad_out_strategy, grad_out_expected_spec) + ) + + # weight tensor, if given, has to be a Tensor of size input_shape[channel_dim] + # make sure it is replicated + if weight_strategy is not None: + assert isinstance(weight_strategy, OpStrategy) + weight_src_spec = weight_strategy.strategies[idx].output_spec + weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(weight_src_spec.placements), + tensor_meta=weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(weight_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_expected_spec) + ) + + # total_weight should always be replicated + total_weight_src_spec = total_weight_strategy.strategies[idx].output_spec + total_weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(total_weight_src_spec.placements), + tensor_meta=total_weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(total_weight_expected_spec) + redistribute_costs.append( + generate_redistribute_costs( + total_weight_strategy, total_weight_expected_spec + ) + ) + + grad_in_expected_spec = input_expected_spec + grad_in_strategy.strategies.append( + PlacementStrategy( + output_specs=grad_in_expected_spec, + input_specs=op_args_target_specs, + redistribute_cost=redistribute_costs, + ) + ) + + return grad_in_strategy + + +@register_op_strategy( + [aten.native_layer_norm.default], + schema_info=RuntimeSchemaInfo(1), +) +def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + # args must be: input, normalized_shape, weight, bias, eps + # for None weight and bias, their corresponding objects will + # be None as well. layer_norm_strategy returns one OpStrategy + # for the triple return values (out, mean, rstd). 
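+    # e.g. (hypothetical shapes, illustration only): for an input of shape
+    # (B, S, H) with normalized_shape=(H,), axis = 3 - 1 = 2, so only the
+    # inner dim H is forced to Replicate() while B and S may stay sharded.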
+ assert len(op_schema.args_schema) == 5 + ( + input_strategy, + normalized_shape, + weight_strategy, + bias_strategy, + _, + ) = op_schema.args_schema + + # the current layer norm implementation requires that all + # input DTensor's sharding must be in form of OpStrategy + assert isinstance(input_strategy, OpStrategy) + assert isinstance(normalized_shape, (int, Sequence, torch.Size)) + normalized_size = normalize_to_torch_size(normalized_shape) + + input_ndim = input_strategy.ndim + axis = input_ndim - len(normalized_size) + + # we use OpStrategy because the output (out, mean, rstd) + # should have the same placements + output_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + op_args_target_specs = [] + redistribute_costs = [] + input_src_spec = input_placement_strategy.output_spec + + # for the input tensor, we replicate it on the inner dims if necessary + # TODO: we can avoid forcing the redistribution once we figure out + # how to decompose layer norm + input_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src_spec.placements, axis), + tensor_meta=input_src_spec.tensor_meta, + ) + op_args_target_specs.append(input_target_spec) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_target_spec) + ) + + if weight_strategy is not None: + assert isinstance(weight_strategy, OpStrategy) + weight_src_spec = weight_strategy.strategies[idx].output_spec + + # for the weight tensor, we replicate it on all dims if necessary + # TODO: we can avoid forcing the redistribution once we figure out + # how to decompose layer norm + weight_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(weight_src_spec.placements), + tensor_meta=weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(weight_target_spec) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_target_spec) + ) + + if bias_strategy is not None: + assert isinstance(bias_strategy, OpStrategy) + bias_src_spec = bias_strategy.strategies[idx].output_spec + + # for the bias tensor, we replicate it on all dims if necessary + # TODO: we can avoid forcing the redistribution once we figure out + # how to decompose layer norm + bias_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(bias_src_spec.placements), + tensor_meta=bias_src_spec.tensor_meta, + ) + op_args_target_specs.append(bias_target_spec) + redistribute_costs.append( + generate_redistribute_costs(bias_strategy, bias_target_spec) + ) + + # the output spec is the same as input spec + output_target_spec = input_target_spec + output_strategy.strategies.append( + PlacementStrategy( + output_specs=output_target_spec, + input_specs=op_args_target_specs, + redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +@register_op_strategy( + [aten.native_layer_norm_backward.default], + schema_info=RuntimeSchemaInfo(2), +) +def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + # args must be: grad_out, input, normalized_shape, mean, rstd, + # weight, bias, output_mask. For None weight and bias, their + # corresponding objects will be None as well. 
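+    # output_mask selects which of (d_input, d_weight, d_bias) to compute,
+    # e.g. [True, True, False] for a layer norm without bias (hypothetical
+    # example, illustration only).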
+ assert len(op_schema.args_schema) == 8 + ( + grad_out_strategy, + input_strategy, + normalized_shape, + mean_strategy, + rstd_strategy, + weight_strategy, + bias_strategy, + output_mask, + ) = op_schema.args_schema + + assert isinstance(grad_out_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(mean_strategy, OpStrategy) + assert isinstance(rstd_strategy, OpStrategy) + + assert isinstance(normalized_shape, (int, Sequence, torch.Size)) + normalized_size = normalize_to_torch_size(normalized_shape) + input_ndim = input_strategy.ndim + axis = input_ndim - len(normalized_size) + outer_dims = list(range(axis)) + + assert isinstance(output_mask, List) and len(output_mask) == 3 + + # output triple: (d_input, d_weight, d_bias) + out_tuple_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + # args for PlacementStrategy + output_specs_list: List[Optional[DTensorSpec]] = [] + input_specs_list: List[DTensorSpec] = [] + redistribute_costs = [] + + input_src_spec = input_placement_strategy.output_spec + # arg: grad_out + # TODO: change the strategy to the following rule. + # d_input is basically a product of element-wise mul of + # grad_out, rstd, and normalized input, among which rstd + # and normalized input (x_hat) should have the same sharding + # placements, and grad_out's sharding is determined by the + # pointwise result of x_hat and weight/bias. + # TODO: now grad_out spec follows input spec. we may need + # to change it to apply a pointwise rule over grad_out, + # input, and weight. + grad_out_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src_spec.placements, axis), + tensor_meta=input_src_spec.tensor_meta, + ) + input_specs_list.append(grad_out_target_spec) + redistribute_costs.append( + generate_redistribute_costs(grad_out_strategy, grad_out_target_spec) + ) + output_specs_list.append(grad_out_target_spec if output_mask[0] else None) + + # arg: input + input_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src_spec.placements, axis), + tensor_meta=input_src_spec.tensor_meta, + ) + input_specs_list.append(input_target_spec) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_target_spec) + ) + + # arg: mean, rstd + mean_src_spec = mean_strategy.strategies[idx].output_spec + input_specs_list.append(mean_src_spec) + redistribute_costs.append([0.0 for _ in mean_strategy.strategies]) + rstd_src_spec = rstd_strategy.strategies[idx].output_spec + input_specs_list.append(rstd_src_spec) + redistribute_costs.append([0.0 for _ in rstd_strategy.strategies]) + + def _add_target_input_spec(strategy) -> DTensorSpec: + # shared logic for setting the weight and bias target input specs + assert isinstance(strategy, OpStrategy) + src_spec = strategy.strategies[idx].output_spec + # no need to redistribute since they should be replicated in forward pass + input_specs_list.append(src_spec) + redistribute_costs.append([0.0 for _ in strategy.strategies]) + return src_spec + + # arg: weight + # d_weight = sum(grad_out * (input - mean) / rstd, outer_dim, keepdim=False) + if weight_strategy is not None: + weight_src_spec = _add_target_input_spec(weight_strategy) + # TODO: now d_weight spec follows input spec w/ a reduction. + # we may need to change to a pointwise rule over grad_out and + # input, then apply a reduction. 
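+            # e.g. (hypothetical 3-D input with axis=2, illustration only):
+            # outer_dims = [0, 1], so an input placement (Shard(0),) maps to
+            # (Partial("sum"),) on d_weight after the outer-dim reduction below.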
+            inp_placements = _replicate_dims_start_at(input_src_spec.placements, axis)
+            reduce_dims_map = _infer_reduce_dims_map(
+                outer_dims, input_src_spec.ndim, False
+            )
+            out_placements = map_placements_after_reduction(
+                inp_placements, outer_dims, reduce_dims_map, "sum"
+            )
+            weight_out_spec = DTensorSpec(
+                mesh=mesh,
+                placements=out_placements,
+                tensor_meta=weight_src_spec.tensor_meta,
+            )
+            output_specs_list.append(weight_out_spec if output_mask[1] else None)
+        else:
+            assert (
+                output_mask[1] is False
+            ), "output_mask[1] should not be `True` while the weight argument is `None` in native_layer_norm_backward."
+            output_specs_list.append(None)
+
+        # arg: bias
+        # d_bias = sum(grad_out, outer_dim, keepdim=False)
+        if bias_strategy is not None:
+            bias_src_spec = _add_target_input_spec(bias_strategy)
+            # d_bias spec follows a reduction over grad_out
+            inp_placements = _replicate_dims_start_at(
+                grad_out_target_spec.placements, axis
+            )
+            reduce_dims_map = _infer_reduce_dims_map(
+                outer_dims, grad_out_target_spec.ndim, False
+            )
+            out_placements = map_placements_after_reduction(
+                inp_placements, outer_dims, reduce_dims_map, "sum"
+            )
+            bias_out_spec = DTensorSpec(
+                mesh=mesh,
+                placements=out_placements,
+                tensor_meta=bias_src_spec.tensor_meta,
+            )
+            output_specs_list.append(bias_out_spec if output_mask[2] else None)
+        else:
+            assert (
+                output_mask[2] is False
+            ), "output_mask[2] should not be `True` while the bias argument is `None` in native_layer_norm_backward."
+            output_specs_list.append(None)
+
+        out_tuple_strategy.strategies.append(
+            PlacementStrategy(
+                output_specs=tuple(output_specs_list),
+                input_specs=input_specs_list,
+                redistribute_cost=redistribute_costs,
+            )
+        )
+
+    return out_tuple_strategy
+
+
+@register_op_strategy(
+    [aten.topk.default],
+    schema_info=RuntimeSchemaInfo(2),
+)
+def topk_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    input_strategy = cast(OpStrategy, op_schema.args_schema[0])
+    k = cast(int, op_schema.args_schema[1])
+    input_shape = input_strategy.shape
+    topk_dim = (
+        cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else -1
+    )
+    topk_dim = normalize_dim(topk_dim, input_strategy.ndim)
+
+    single_mesh_dim_strategies = []
+
+    # two outputs (values, indices), 1 input
+    # replicate always works
+    all_replicate: PlacementList = [Replicate()] * 3
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # every dim except the topk dim should work
+    for dim in range(input_strategy.ndim):
+        if dim != topk_dim:
+            dim_shardings: PlacementList = [Shard(dim)] * 3
+            single_mesh_dim_strategies.append(dim_shardings)
+    # TODO: topk on a sharded dim requires a non-trivial reduction, address it later
+
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=2
+    )
diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_matrix_ops.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd9a7a430a70eb2d2fb795ebc402effa832e5058
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_matrix_ops.py
@@ -0,0 +1,500 @@
+# mypy: allow-untyped-decorators
+# Copyright (c) Meta Platforms, Inc.
and affiliates
+# implement matrix related ops for distributed tensor
+
+from typing import List
+
+import torch
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
+from torch.distributed.tensor._op_schema import (
+    OpSchema,
+    OpStrategy,
+    PlacementList,
+    PlacementStrategy,
+    RuntimeSchemaInfo,
+)
+from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies
+from torch.distributed.tensor._ops.utils import (
+    expand_to_full_mesh_op_strategy,
+    generate_redistribute_costs,
+    infer_broadcast_dims_map,
+    is_tensor_shardable,
+    map_placements_after_broadcast,
+    register_op_strategy,
+)
+from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
+
+
+aten = torch.ops.aten
+
+
+@register_op_strategy(aten.t.default)
+def transpose_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    self_strategy = op_schema.args_schema[0]
+    assert isinstance(self_strategy, OpStrategy)
+
+    transpose_strategies = []
+    for input_strategy in self_strategy.strategies:
+        input_spec = input_strategy.output_spec
+        # follow the input spec but transpose the Shard placements
+        output_placements = [
+            Shard(1 - p.dim) if isinstance(p, Shard) else p
+            for p in input_spec.placements
+        ]
+        transpose_strategy = PlacementStrategy(
+            output_specs=DTensorSpec(
+                mesh=input_strategy.output_spec.mesh,
+                placements=tuple(output_placements),
+            ),
+            input_specs=(input_strategy.output_spec,),
+        )
+        transpose_strategies.append(transpose_strategy)
+
+    return OpStrategy(strategies=transpose_strategies)
+
+
+def _mm_like_strategy(
+    mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
+) -> OpStrategy:
+    self_strategy, mat2_strategy = op_schema.args_schema
+    assert isinstance(self_strategy, OpStrategy)
+    assert isinstance(mat2_strategy, OpStrategy)
+    # generate all possible strategies for mm
+    mm_strategy = gen_einsum_strategies(mm_equation, mesh)
+    # filter out invalid strategies and associate costs
+    strategies = mm_strategy.strategies
+    filtered_strategies = []
+    for strtg in strategies:
+        assert strtg.input_specs is not None
+        self_spec = strtg.input_specs[0]
+        mat2_spec = strtg.input_specs[1]
+        if is_tensor_shardable(self_strategy.shape, self_spec) and is_tensor_shardable(
+            mat2_strategy.shape, mat2_spec
+        ):
+            redistribute_cost = [
+                generate_redistribute_costs(self_strategy, self_spec),
+                generate_redistribute_costs(mat2_strategy, mat2_spec),
+            ]
+            strtg.redistribute_cost = redistribute_cost
+            filtered_strategies.append(strtg)
+
+    mm_strategy.strategies = filtered_strategies
+
+    return mm_strategy
+
+
+def _addmm_like_strategy(
+    mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
+) -> OpStrategy:
+    self_strategy, mat1_strategy, mat2_strategy = op_schema.args_schema
+    assert isinstance(self_strategy, OpStrategy)
+    assert isinstance(mat1_strategy, OpStrategy)
+    assert isinstance(mat2_strategy, OpStrategy)
+    self_shape = self_strategy.shape
+    mm_out_shape = torch.Size(
+        [
+            mat2_strategy.shape[-1] if i == len(mat1_strategy.shape) - 1 else dim_size
+            for i, dim_size in enumerate(mat1_strategy.shape)
+        ]
+    )
+    # generate all possible strategies for mm
+    mm_strategy = gen_einsum_strategies(mm_equation, mesh)
+    # filter out invalid strategies and associate costs
+    strategies = mm_strategy.strategies
+    filtered_strategies = []
+    for strtg in strategies:
+        # construct a new strategy by considering the self arg
+        assert strtg.input_specs is not None
+        mat1_spec = strtg.input_specs[0]
+        mat2_spec =
strtg.input_specs[1] + out_spec = strtg.output_spec + + # self arg's spec should follow the output of mm, but need + # to consider broadcast for the self arg + broadcast_dims_map = infer_broadcast_dims_map(mm_out_shape, self_shape) + self_placements = map_placements_after_broadcast( + out_spec.placements, mm_out_shape, broadcast_dims_map + ) + self_spec = DTensorSpec(mesh=mesh, placements=self_placements) + + if is_tensor_shardable(mat1_strategy.shape, mat1_spec) and is_tensor_shardable( + mat2_strategy.shape, mat2_spec + ): + # update input specs with new self spec + strtg.input_specs = (self_spec, mat1_spec, mat2_spec) + + # associate costs + redistribute_cost = [ + generate_redistribute_costs(self_strategy, self_spec), + generate_redistribute_costs(mat1_strategy, mat1_spec), + generate_redistribute_costs(mat2_strategy, mat2_spec), + ] + strtg.redistribute_cost = redistribute_cost + filtered_strategies.append(strtg) + + mm_strategy.strategies = filtered_strategies + + return mm_strategy + + +@register_op_strategy(aten.mm.default) +def mm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + return _mm_like_strategy("mk,kn->mn", mesh, op_schema) + + +@register_op_strategy(aten.addmm.default) +def addmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + return _addmm_like_strategy("mk,kn->mn", mesh, op_schema) + + +@register_op_strategy(aten.bmm.default) +def bmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + return _mm_like_strategy("bmk,bkn->bmn", mesh, op_schema) + + +@register_op_strategy(aten.baddbmm.default) +def baddmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: + return _addmm_like_strategy("bmk,bkn->bmn", mesh, op_schema) + + +@register_op_strategy( + aten._scaled_dot_product_flash_attention.default, schema_info=RuntimeSchemaInfo(5) +) +def scaled_dot_product_flash_attention_strategy( + mesh: DeviceMesh, op_schema: OpSchema +) -> OpStrategy: + # NOTE: currently we only support some simple strategies to support tensor parallelism + # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation + # as it involves: matmul, pointwise, reduction ops together. 
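+    # Sketch of the supported patterns (assuming a q/k/v layout of
+    # (B, H, S, D), an assumption for illustration): full replication;
+    # tensor parallelism, which shards the num-heads dim as Shard(1); and
+    # context parallelism, which shards the sequence dim as Shard(2).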
+    return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5]
+    q_input_strategy = op_schema.args_schema[0]
+    assert isinstance(q_input_strategy, OpStrategy)
+    # assuming q/k/v have the same shape
+    qkv_shape = q_input_strategy.shape
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [outputs, inputs]
+    # in the sdpa case, we have 3 valid tensor outputs and 3 tensor inputs
+    # first we can always accept full replication for both inputs and outputs
+    all_replicate: PlacementList = [
+        Replicate(),
+        Replicate(),
+        None,  # cum_seq_q
+        None,  # cum_seq_k
+        None,  # max_q
+        None,  # max_k
+        None,  # philox_seed
+        None,  # philox_offset
+        Replicate(),
+        Replicate(),
+        Replicate(),
+        Replicate(),
+    ]
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # second we can accept the sharding pattern of tensor parallelism, which
+    # shards on the num-heads dim
+    qkv_sharding = Shard(1)  # num head dim
+    output_sharding = Shard(1)  # num head dim
+    logsumexp_sharding = Shard(1)  # num head dim
+    if return_debug_mask:
+        debug_attn_mask_sharding: Placement = Shard(1)  # num head dim
+    else:
+        # empty debug mask, replicated
+        debug_attn_mask_sharding = Replicate()
+
+    num_heads_dim_sharding: PlacementList = [
+        output_sharding,
+        logsumexp_sharding,
+        None,  # cum_seq_q
+        None,  # cum_seq_k
+        None,  # max_q
+        None,  # max_k
+        None,  # philox_seed
+        None,  # philox_offset
+        debug_attn_mask_sharding,
+        qkv_sharding,
+        qkv_sharding,
+        qkv_sharding,
+    ]
+    single_mesh_dim_strategies.append(num_heads_dim_sharding)
+
+    # Context Parallelism: shards on the sequence dim
+    single_mesh_dim_strategies.append(
+        [
+            Shard(2),  # output
+            Shard(2),  # logsumexp
+            None,  # cum_seq_q
+            None,  # cum_seq_k
+            None,  # max_q
+            None,  # max_k
+            None,  # philox_seed
+            None,  # philox_offset
+            Shard(2),  # debug_attn_mask
+            Shard(2),  # q
+            Shard(2),  # k
+            Shard(2),  # v
+        ]
+    )
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=9
+    )
+
+
+@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default)
+def scaled_dot_product_flash_attention_backward_strategy(
+    mesh: DeviceMesh, op_schema: OpSchema
+) -> OpStrategy:
+    q_input_strategy = op_schema.args_schema[1]
+    assert isinstance(q_input_strategy, OpStrategy)
+    # assuming q/k/v have the same shape
+    qkv_shape = q_input_strategy.shape
+
+    tensor_input_indices = [
+        i
+        for i, arg_spec in enumerate(op_schema.args_schema)
+        if isinstance(arg_spec, OpStrategy)
+    ]
+    num_tensor_inputs = len(tensor_input_indices)
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [outputs, inputs]
+    # in the sdpa backward case, we have 3 tensor outputs and 6 to 10 tensor inputs
+    # first we can always accept full replication for both inputs and outputs
+    all_replicate: PlacementList = [Replicate()] * (3 + num_tensor_inputs)
+
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # second we can accept the sharding pattern of tensor parallelism, which
+    # shards on the num-heads dim
+    grad_output_sharding = Shard(1)  # num head dim
+    qkv_sharding = Shard(1)  # num head dim
+    output_sharding = Shard(1)  # num head dim
+    logsumexp_sharding = Shard(1)  # num head dim
+    grad_qkv_sharding = Shard(1)  # num head dim
+
+    num_heads_dim_sharding: PlacementList = [
+        grad_qkv_sharding,
+        grad_qkv_sharding,
+        grad_qkv_sharding,
+        grad_output_sharding,
+        qkv_sharding,
+        qkv_sharding,
+        qkv_sharding,
+        output_sharding,
+        logsumexp_sharding,
+    ]
+    # accept replicate on the remaining tensor inputs, potentially
+    # cum_seq_q, cum_seq_k, philox_seed, philox_offset
+    # at indices 6, 7, 12, 13, respectively
+    num_heads_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6))
+    single_mesh_dim_strategies.append(num_heads_dim_sharding)
+
+    # Context Parallelism: shards on the sequence dim
+    seq_dim_sharding: PlacementList = [
+        Shard(2),  # grad_q
+        Shard(2),  # grad_k
+        Shard(2),  # grad_v
+        Shard(2),  # grad_output
+        Shard(2),  # q
+        Shard(2),  # k
+        Shard(2),  # v
+        Shard(2),  # output
+        Shard(2),  # logsumexp
+    ]
+    # accept replicate on the remaining tensor inputs, potentially
+    # cum_seq_q, cum_seq_k, philox_seed, philox_offset
+    # at indices 6, 7, 12, 13, respectively
+    seq_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6))
+    single_mesh_dim_strategies.append(seq_dim_sharding)
+
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=3
+    )
+
+
+@register_op_strategy(aten.constant_pad_nd.default)
+def constant_pad_nd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    # TODO(d4l3k): implement a more correct strategy for constant_pad_nd
+    return OpStrategy(
+        [
+            PlacementStrategy(
+                output_specs=DTensorSpec(mesh, (Replicate(),)),
+                input_specs=(
+                    DTensorSpec(mesh, (Replicate(),)),
+                    DTensorSpec(mesh, (Replicate(),)),
+                ),
+                redistribute_cost=[[1]],
+            )
+        ]
+    )
+
+
+@register_op_strategy(
+    aten._scaled_dot_product_efficient_attention.default,
+    schema_info=RuntimeSchemaInfo(4),
+)
+def scaled_dot_product_efficient_attention_strategy(
+    mesh: DeviceMesh, op_schema: OpSchema
+) -> OpStrategy:
+    # NOTE: currently we only support some simple strategies to support tensor parallelism
+    q_input_strategy = op_schema.args_schema[0]
+    assert isinstance(q_input_strategy, OpStrategy)
+    # assuming q/k/v have the same shape
+    qkv_shape = q_input_strategy.shape
+    has_attn_bias = op_schema.args_schema[3] is not None
+    compute_log_sumexp = op_schema.args_schema[4]
+
+    single_mesh_dim_strategies: List[PlacementList] = []
+
+    # placement list stores placements of [outputs, inputs]
+    # in the sdpa case, we have 2 valid tensor outputs and 3 or 4 tensor inputs
+    # first we can always accept full replication for both inputs and outputs
+    all_replicate: PlacementList = [
+        Replicate(),
+        Replicate(),
+        None,
+        None,
+        Replicate(),
+        Replicate(),
+        Replicate(),
+    ]
+    if has_attn_bias:
+        all_replicate.append(Replicate())  # attn bias
+
+    # Context Parallelism: shards on the sequence dim
+    single_mesh_dim_strategies.append(
+        [
+            Shard(2),  # output
+            Shard(2),  # logsumexp
+            None,  # philox_seed
+            None,  # philox_offset
+            Shard(2),  # q
+            Shard(2),  # k
+            Shard(2),  # v
+        ]
+    )
+
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # second we can accept the sharding pattern of tensor parallelism, which
+    # shards on the heads dimension
+    qkv_sharding = Shard(1)
+    output_sharding = Shard(1)
+    if compute_log_sumexp:
+        logsumexp_sharding: Placement = Shard(1)
+    else:
+        # empty logsumexp, replicated
+        logsumexp_sharding = Replicate()
+
+    num_heads_dim_sharding = [
+        output_sharding,
+        logsumexp_sharding,
+        None,
+        None,
+        qkv_sharding,
+        qkv_sharding,
+        qkv_sharding,
+    ]
+    if has_attn_bias:
+        num_heads_dim_sharding.append(Shard(1))
+    single_mesh_dim_strategies.append(num_heads_dim_sharding)
+
+    return expand_to_full_mesh_op_strategy(
+        mesh,
+        op_schema,
+        single_mesh_dim_strategies,
+        input_index=4,
+    )
+
+
+@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default)
+def scaled_dot_product_efficient_attention_backward_strategy(
+    mesh: DeviceMesh,
+    op_schema: OpSchema
+) -> OpStrategy:
+    q_input_strategy = op_schema.args_schema[1]
+    assert isinstance(q_input_strategy, OpStrategy)
+    # assuming q/k/v have the same shape
+    qkv_shape = q_input_strategy.shape
+    has_attn_bias = op_schema.args_schema[4] is not None
+
+    tensor_input_indices = [
+        i
+        for i, arg_spec in enumerate(op_schema.args_schema)
+        if isinstance(arg_spec, OpStrategy)
+    ]
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [outputs, inputs]
+    # in the sdpa backward case, we have 4 tensor outputs and 8 or 9 tensor inputs
+    # NOTE: output sharding of grad_bias is on the heads dim if attn_bias is present;
+    # otherwise grad_bias will be empty and its DTensorSpec will be removed.
+    # first we can always accept full replication for both inputs and outputs
+    all_replicate: PlacementList = [Replicate()] * (12 + has_attn_bias)
+
+    if not has_attn_bias:
+        all_replicate[3] = None  # grad bias is None if attn_bias is not present
+
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # second we can accept the sharding pattern of tensor parallelism, which
+    # shards on the heads dimension
+    grad_output_sharding = Shard(1)
+    qkv_sharding = Shard(1)
+    output_sharding = Shard(1)
+    logsumexp_sharding = Shard(1)
+    grad_qkv_sharding = Shard(1)
+    grad_bias_sharding = Shard(1) if has_attn_bias else None
+
+    num_heads_dim_sharding: PlacementList = [
+        grad_qkv_sharding,
+        grad_qkv_sharding,
+        grad_qkv_sharding,
+        grad_bias_sharding,
+        grad_output_sharding,
+        qkv_sharding,
+        qkv_sharding,
+        qkv_sharding,
+        # the place for the optional input attn_bias,
+        output_sharding,
+        logsumexp_sharding,
+    ]
+    # input sharding of attn_bias on the heads dim if present
+    if has_attn_bias:
+        num_heads_dim_sharding.insert(8, Shard(1))
+    # accept replicate on the remaining scalar tensor inputs,
+    # namely philox_seed and philox_offset
+    num_heads_dim_sharding.extend([Replicate(), Replicate()])
+    single_mesh_dim_strategies.append(num_heads_dim_sharding)
+
+    # Context Parallelism: shards on the sequence dim
+    seq_dim_sharding: PlacementList = [
+        Shard(2),  # grad_q
+        Shard(2),  # grad_k
+        Shard(2),  # grad_v
+        Shard(1) if has_attn_bias else None,  # grad_bias
+        Shard(2),  # grad_output
+        Shard(2),  # q
+        Shard(2),  # k
+        Shard(2),  # v
+        Shard(2),  # output
+        Shard(2),  # logsumexp
+    ]
+    # accept replicate on the remaining tensor inputs,
+    # namely philox_seed and philox_offset
+    if has_attn_bias:
+        seq_dim_sharding.insert(8, Shard(2))  # attn_bias input, sharded on the sequence dim
+    seq_dim_sharding.extend([Replicate(), Replicate()])
+    single_mesh_dim_strategies.append(seq_dim_sharding)
+
+    return expand_to_full_mesh_op_strategy(
+        mesh,
+        op_schema,
+        single_mesh_dim_strategies,
+        input_index=4,
+    )
diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_pointwise_ops.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_pointwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb40865ed9c06ab569060199086ce2cfdda23e51
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_pointwise_ops.py
@@ -0,0 +1,688 @@
+# Copyright (c) Meta Platforms, Inc.
and affiliates +from typing import List, Sequence, Tuple + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( + _is_inplace_op, + _is_out_variant_op, + OpSchema, + OpStrategy, + PlacementStrategy, + RuntimeSchemaInfo, + StrategyType, + TupleStrategy, +) +from torch.distributed.tensor._ops.utils import ( + generate_redistribute_costs, + infer_broadcast_dims_map, + map_placements_after_broadcast, + normalize_dim, + register_op_strategy, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +aten = torch.ops.aten +# leave the remaining pointwise_ops list here for convenience, +# Below ops are some pointwise ops that are yet to be supported, +# they might not be a complete list. +# pointwise_ops = [ +# "fake_quantize_per_channel_affine", +# "fake_quantize_per_tensor_affine", +# "floor_divide", # floor_divide is deprecated +# "frexp", # multiple output pointwise op, need to add support +# "gradient", # need investigation on this op +# "imag", # complex data type only +# "quantized_batch_norm", +# "quantized_max_pool1d", +# "quantized_max_pool2d", +# "real", # complex data type only +# ] + + +linear_pointwise_ops = [ + aten.div.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op. + aten.div_.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op. + aten.to.dtype, + aten.add.Tensor, + aten.add_.Tensor, +] + + +pointwise_ops = [ + # please keep the entries below alphabetically sorted + aten.__ilshift__.Scalar, + aten.__ilshift__.Tensor, + aten.__irshift__.Scalar, + aten.__irshift__.Tensor, + aten.__lshift__.Scalar, + aten.__lshift__.Tensor, + aten.__rshift__.Scalar, + aten.__rshift__.Tensor, + aten._conj.default, + aten.abs.default, + aten.abs.out, + aten.abs_.default, + aten.acos.default, + aten.acos.out, + aten.acos_.default, + aten.acosh.default, + aten.acosh.out, + aten.acosh_.default, + aten.add.Scalar, + aten.add.out, + aten.add_.Scalar, + aten.addcdiv.default, + aten.addcdiv.out, + aten.addcdiv_.default, + aten.addcmul.default, + aten.addcmul.out, + aten.addcmul_.default, + aten.angle.default, + aten.angle.out, + aten.asin.default, + aten.asin.out, + aten.asin_.default, + aten.asinh.default, + aten.asinh.out, + aten.asinh_.default, + aten.atan.default, + aten.atan.out, + aten.atan2.default, + aten.atan2.out, + aten.atan2_.default, + aten.atan_.default, + aten.atanh.default, + aten.atanh.out, + aten.atanh_.default, + aten.bitwise_and.Scalar, + aten.bitwise_and.Scalar_Tensor, + aten.bitwise_and.Scalar_out, + aten.bitwise_and.Tensor, + aten.bitwise_and.Tensor_out, + aten.bitwise_and_.Scalar, + aten.bitwise_and_.Tensor, + aten.bitwise_left_shift.Scalar_Tensor, + aten.bitwise_left_shift.Tensor, + aten.bitwise_left_shift.Tensor_Scalar, + aten.bitwise_left_shift.Tensor_Scalar_out, + aten.bitwise_left_shift.Tensor_out, + aten.bitwise_left_shift_.Tensor, + aten.bitwise_left_shift_.Tensor_Scalar, + aten.bitwise_not.default, + aten.bitwise_not.out, + aten.bitwise_not_.default, + aten.bitwise_or.Scalar, + aten.bitwise_or.Scalar_Tensor, + aten.bitwise_or.Scalar_out, + aten.bitwise_or.Tensor, + aten.bitwise_or.Tensor_out, + aten.bitwise_or_.Scalar, + aten.bitwise_or_.Tensor, + aten.bitwise_right_shift.Scalar_Tensor, + aten.bitwise_right_shift.Tensor, + aten.bitwise_right_shift.Tensor_Scalar, + 
aten.bitwise_right_shift.Tensor_Scalar_out, + aten.bitwise_right_shift.Tensor_out, + aten.bitwise_right_shift_.Tensor, + aten.bitwise_right_shift_.Tensor_Scalar, + aten.bitwise_xor.Scalar, + aten.bitwise_xor.Scalar_Tensor, + aten.bitwise_xor.Scalar_out, + aten.bitwise_xor.Tensor, + aten.bitwise_xor.Tensor_out, + aten.bitwise_xor_.Scalar, + aten.bitwise_xor_.Tensor, + aten.ceil.default, + aten.ceil.out, + aten.ceil_.default, + aten.clamp.default, + aten.clamp.out, + aten.clamp_.default, + aten.clip.default, + aten.clip.out, + aten.clip_.default, + aten.conj_physical.default, + aten.conj_physical.out, + aten.conj_physical_.default, + aten.copysign.Scalar, + aten.copysign.Scalar_out, + aten.copysign.Tensor, + aten.copysign.out, + aten.copysign_.Scalar, + aten.copysign_.Tensor, + aten.cos.default, + aten.cos.out, + aten.cos_.default, + aten.cosh.default, + aten.cosh.out, + aten.cosh_.default, + aten.deg2rad.default, + aten.deg2rad.out, + aten.deg2rad_.default, + aten.digamma.default, + aten.digamma.out, + aten.digamma_.default, + aten.div.Tensor, + aten.div.Tensor_mode, + aten.div.out, + aten.div.out_mode, + aten.div_.Tensor, + aten.div_.Tensor_mode, + aten.eq.Tensor, + aten.eq.Tensor_out, + aten.eq.Scalar, + aten.eq.Scalar_out, + aten.erf.default, + aten.erf.out, + aten.erf_.default, + aten.erfc.default, + aten.erfc.out, + aten.erfc_.default, + aten.erfinv.default, + aten.erfinv.out, + aten.erfinv_.default, + aten.exp.default, + aten.exp.out, + aten.exp2.default, + aten.exp2.out, + aten.exp2_.default, + aten.exp_.default, + aten.expm1.default, + aten.expm1.out, + aten.expm1_.default, + aten.float_power.Scalar, + aten.float_power.Scalar_out, + aten.float_power.Tensor_Scalar, + aten.float_power.Tensor_Scalar_out, + aten.float_power.Tensor_Tensor, + aten.float_power.Tensor_Tensor_out, + aten.float_power_.Scalar, + aten.float_power_.Tensor, + aten.floor.default, + aten.floor.out, + aten.floor_.default, + aten.fmod.Scalar, + aten.fmod.Scalar_out, + aten.fmod.Tensor, + aten.fmod.Tensor_out, + aten.fmod_.Scalar, + aten.fmod_.Tensor, + aten.frac.default, + aten.frac.out, + aten.frac_.default, + aten.ge.Scalar, + aten.ge.Tensor, + aten.gelu.default, + aten.gt.Tensor, + aten.gt.Tensor_out, + aten.gt.Scalar, + aten.gt.Scalar_out, + aten.gt.Scalar, + aten.gt.Tensor, + aten.hypot.default, + aten.hypot.out, + aten.hypot_.default, + aten.i0.default, + aten.i0.out, + aten.i0_.default, + aten.igamma.default, + aten.igamma.out, + aten.igamma_.default, + aten.igammac.default, + aten.igammac.out, + aten.igammac_.default, + aten.isinf.default, + aten.isnan.default, + aten.isneginf.default, + aten.isneginf.out, + aten.isposinf.default, + aten.isposinf.out, + aten.ldexp.default, + aten.ldexp.out, + aten.ldexp_.default, + aten.lt.Tensor, + aten.lt.Tensor_out, + aten.lt.Scalar, + aten.lt.Scalar_out, + aten.le.Scalar, + aten.le.Tensor, + aten.lerp.Scalar, + aten.lerp.Scalar_out, + aten.lerp.Tensor, + aten.lerp.Tensor_out, + aten.lerp_.Scalar, + aten.lerp_.Tensor, + aten.lgamma.default, + aten.lgamma.out, + aten.lgamma_.default, + aten.log.default, + aten.log.out, + aten.log10.default, + aten.log10.out, + aten.log10_.default, + aten.log1p.default, + aten.log1p.out, + aten.log1p_.default, + aten.log2.default, + aten.log2.out, + aten.log2_.default, + aten.log_.default, + aten.logaddexp.default, + aten.logaddexp.out, + aten.logaddexp2.default, + aten.logaddexp2.out, + aten.logical_and.default, + aten.logical_and.out, + aten.logical_and_.default, + aten.logical_not.default, + aten.logical_not.out, + 
aten.logical_not_.default, + aten.logical_or.default, + aten.logical_or.out, + aten.logical_or_.default, + aten.logical_xor.default, + aten.logical_xor.out, + aten.logical_xor_.default, + aten.logit.default, + aten.logit.out, + aten.logit_.default, + aten.masked_fill.Scalar, + aten.maximum.out, + aten.mul.Scalar, + aten.mul.Tensor, + aten.mul.out, + aten.mul_.Scalar, + aten.mul_.Tensor, + aten.mvlgamma.default, + aten.mvlgamma.out, + aten.mvlgamma_.default, + aten.native_dropout_backward.default, + aten.native_dropout_backward.out, + aten.nan_to_num.default, + aten.nan_to_num.out, + aten.nan_to_num_.default, + aten.ne.Scalar, + aten.neg.default, + aten.neg.out, + aten.neg_.default, + aten.nextafter.default, + aten.nextafter.out, + aten.nextafter_.default, + aten.polygamma.default, + aten.polygamma.out, + aten.polygamma_.default, + aten.positive.default, + aten.pow.Scalar, + aten.pow.Scalar_out, + aten.pow.Tensor_Scalar, + aten.pow.Tensor_Scalar_out, + aten.pow.Tensor_Tensor, + aten.pow.Tensor_Tensor_out, + aten.pow_.Scalar, + aten.pow_.Tensor, + aten.reciprocal.default, + aten.reciprocal.out, + aten.reciprocal_.default, + aten.rad2deg.default, + aten.rad2deg.out, + aten.rad2deg_.default, + aten.relu.default, + aten.relu_.default, + aten.remainder.Scalar, + aten.remainder.Scalar_Tensor, + aten.remainder.Scalar_out, + aten.remainder.Tensor, + aten.remainder.Tensor_out, + aten.remainder_.Scalar, + aten.remainder_.Tensor, + aten.round.decimals, + aten.round.decimals_out, + aten.round.default, + aten.round.out, + aten.round_.decimals, + aten.round_.default, + aten.rsqrt.default, + aten.rsqrt.out, + aten.rsqrt_.default, + aten.rsub.Scalar, + aten.sgn.default, + aten.sgn.out, + aten.sgn_.default, + aten.sigmoid.default, + aten.sigmoid.out, + aten.sigmoid_.default, + aten.sign.default, + aten.sign.out, + aten.sign_.default, + aten.signbit.default, + aten.signbit.out, + aten.silu.default, + aten.silu.out, + aten.sin.default, + aten.sin.out, + aten.sin_.default, + aten.sinc.default, + aten.sinc.out, + aten.sinc_.default, + aten.sinh.default, + aten.sinh.out, + aten.sinh_.default, + aten.sqrt.default, + aten.sqrt.out, + aten.sqrt_.default, + aten.square.default, + aten.square.out, + aten.square_.default, + aten.sub.Scalar, + aten.sub.Tensor, + aten.sub.out, + aten.sub_.Scalar, + aten.sub_.Tensor, + aten.tan.default, + aten.tan.out, + aten.tan_.default, + aten.tanh.default, + aten.tanh.out, + aten.tanh_.default, + aten.true_divide.Tensor, + aten.trunc.default, + aten.trunc.out, + aten.trunc_.default, + aten.where.self, + aten.where.self_out, + aten.xlogy.OutScalar_Self, + aten.xlogy.OutScalar_Other, + aten.xlogy.OutTensor, + aten.xlogy.Scalar_Other, + aten.xlogy.Scalar_Self, + aten.xlogy.Tensor, + aten.xlogy_.Scalar_Other, + aten.xlogy_.Tensor, + # backward point-wise ops + # please keep the entries below alphabetically sorted + aten.gelu_backward.default, + aten.sigmoid_backward.default, + aten.silu_backward.default, + aten.tanh_backward.default, + aten.threshold_backward.default, +] + + +def pointwise_strategy( + mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False +) -> OpStrategy: + max_shards_strategy_index = -1 + max_shards = -1 + + if _is_inplace_op(op_schema.op): + # inplace op should follow the first arg strategy + followed_strategy = op_schema.args_schema[0] + elif _is_out_variant_op(op_schema.op): + # out variant op should follow the out kwarg strategy + followed_strategy = op_schema.kwargs_schema["out"] + else: + # normal pointwise op, we choose to follow the arg with + # the max 
shards, in case operands need resharding
+        for idx, arg_strategy in enumerate(op_schema.args_schema):
+            if not isinstance(arg_strategy, OpStrategy):
+                continue
+
+            arg_max_shards = arg_strategy.max_num_shards()
+            if arg_max_shards > max_shards:
+                max_shards_strategy_index = idx
+                max_shards = arg_max_shards
+
+        followed_strategy = op_schema.args_schema[max_shards_strategy_index]
+
+    assert isinstance(
+        followed_strategy, OpStrategy
+    ), f"no strategy to follow for {op_schema}!"
+    return common_pointwise_strategy(
+        mesh, op_schema.args_schema, followed_strategy, linearity
+    )
+
+
+def common_pointwise_strategy(
+    mesh: DeviceMesh,
+    args_schema: Sequence[object],
+    followed_strategy: OpStrategy,
+    linearity: bool,
+) -> OpStrategy:
+    # handle broadcasting
+    common_shape = torch.broadcast_shapes(
+        *[arg.shape for arg in args_schema if isinstance(arg, OpStrategy)]
+    )
+    pointwise_strategy = OpStrategy([])
+
+    for placement_strategy in followed_strategy.strategies:
+        spec_to_follow = placement_strategy.output_spec
+        out_placements: List[Placement] = []
+        for placement in spec_to_follow.placements:
+            if isinstance(placement, Shard):
+                shard_dim = normalize_dim(placement.dim, len(spec_to_follow.shape))
+                common_ndim = len(common_shape)
+                new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim
+                out_placements.append(Shard(new_shard_dim))
+            elif isinstance(placement, Partial) and not linearity:
+                # clear the partial placement if the op does not support linearity;
+                # by default we just replicate the partial, need to see if this
+                # is optimal for all cases
+                out_placements.append(Replicate())
+            else:
+                out_placements.append(placement)
+
+        input_specs: List[DTensorSpec] = []
+        redistribute_costs: List[List[float]] = []
+        for input_arg in args_schema:
+            if isinstance(input_arg, OpStrategy):
+                # every arg follows the out_placements, but we need to handle broadcasting
+                input_arg_spec = input_arg.strategies[0].output_spec
+                input_arg_dims_map = infer_broadcast_dims_map(
+                    common_shape, input_arg_spec.shape
+                )
+                input_target_placements = map_placements_after_broadcast(
+                    tuple(out_placements),
+                    common_shape,
+                    input_arg_dims_map,
+                )
+                input_arg_target_spec = DTensorSpec(
+                    mesh=mesh,
+                    placements=input_target_placements,
+                    tensor_meta=input_arg_spec.tensor_meta,
+                )
+                input_specs.append(input_arg_target_spec)
+                redistribute_costs.append(
+                    generate_redistribute_costs(input_arg, input_arg_target_spec)
+                )
+
+        pointwise_strategy.strategies.append(
+            PlacementStrategy(
+                output_specs=DTensorSpec(
+                    mesh=mesh,
+                    placements=tuple(out_placements),
+                ),
+                input_specs=input_specs,
+                redistribute_cost=redistribute_costs,
+            )
+        )
+    return pointwise_strategy
+
+
+def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    """
+    Linear pointwise operators can propagate pending reductions.
+    For example, c = add(a, b); if a is pending sum, then c will be
+    pending sum as well without any communication overhead.
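+
+    Illustrative sketch (hypothetical placements, not from the original
+    source): if a has placements (Partial("sum"),), then add(a, b) can keep
+    (Partial("sum"),) on its output instead of forcing an all-reduce first,
+    deferring the communication until it is actually needed.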
+ """ + return pointwise_strategy(mesh, op_schema, linearity=True) + + +for op in linear_pointwise_ops: + register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))( + linear_pointwise_strategy + ) + +for op in pointwise_ops: + register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))( + pointwise_strategy + ) + + +# TODO: add all for_each ops +for_each_ops = [ + aten._foreach_abs.default, + aten._foreach_abs_.default, + aten._foreach_addcdiv_.Scalar, + aten._foreach_addcdiv_.ScalarList, + aten._foreach_addcdiv_.Tensor, + aten._foreach_addcmul.Scalar, + aten._foreach_addcmul_.Scalar, + aten._foreach_addcmul_.ScalarList, + aten._foreach_addcmul_.Tensor, + aten._foreach_clamp_max_.Scalar, + aten._foreach_clamp_min_.Scalar, + aten._foreach_div_.List, + aten._foreach_div_.Scalar, + aten._foreach_div_.ScalarList, + aten._foreach_div_.Tensor, + aten._foreach_div.List, + aten._foreach_div.Scalar, + aten._foreach_div.ScalarList, + aten._foreach_div.Tensor, + aten._foreach_lerp_.Scalar, + aten._foreach_maximum_.List, + aten._foreach_mul.Scalar, + aten._foreach_mul.ScalarList, + aten._foreach_mul.Tensor, + aten._foreach_mul.List, + aten._foreach_mul_.Scalar, + aten._foreach_mul_.ScalarList, + aten._foreach_mul_.Tensor, + aten._foreach_mul_.List, + aten._foreach_neg.default, + aten._foreach_neg_.default, + aten._foreach_reciprocal_.default, + aten._foreach_sub.Scalar, + aten._foreach_sub_.Scalar, + aten._foreach_sub.List, + aten._foreach_sub_.List, + aten._foreach_sub.ScalarList, + aten._foreach_sub_.ScalarList, + aten._foreach_sqrt.default, + aten._foreach_sqrt_.default, + aten._foreach_zero_.default, + aten._foreach_exp.default, + aten._foreach_exp_.default, + aten._foreach_cos.default, + aten._foreach_cos_.default, + aten._foreach_log.default, + aten._foreach_log_.default, + aten._amp_foreach_non_finite_check_and_unscale_.default, +] + +for_each_linearity_ops = [ + aten._foreach_add.Scalar, + aten._foreach_add_.Scalar, + aten._foreach_add_.ScalarList, + aten._foreach_add.List, + aten._foreach_add_.List, +] + + +def list_pointwise_strategy( + mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False +) -> StrategyType: + """ + Apply the pointwise strategy to the zipped arguments. For example, if we + run a foreach add of two lists l1 and l2, then we apply the pointwise + strategy on each pair (l1[i], l2[i]). If the first argument is a list but + the second (or later) one is a tensor, then we broadcast the tensor by + replicating it into a list with the length of the first argument. + + Args: + mesh (DeviceMesh): device mesh for pointwise ops + op_schema (OpSchema): schema of the operator to generate strategy for + linearity (bool): specify whether op(a) + op(b) = op(a + b) + + Returns: + OpStrategy: generated strategy + """ + + def args_tuple_strategies(args_schema: Tuple[object, ...]) -> List[TupleStrategy]: + first_arg = args_schema[0] + assert isinstance(first_arg, TupleStrategy) + strategy_len = len(first_arg.childs) + tuple_strategies: List[TupleStrategy] = [] + for arg_idx, arg in enumerate(args_schema): + if isinstance(arg, TupleStrategy): + # every tuple strategy should have the same length + assert len(arg.childs) == strategy_len + tuple_strategies.append(arg) + elif isinstance(arg, OpStrategy): + if arg_idx > 0: # implicitly broadcast + tuple_strategies.append( + TupleStrategy([arg for _ in range(strategy_len)]) + ) + else: + raise RuntimeError( + f"list op only supports tuple strategy! 
{op_schema}" + ) + return tuple_strategies + + args_strategies = args_tuple_strategies(op_schema.args_schema) + follow_strategy: TupleStrategy = args_strategies[0] + list_strategy: List[OpStrategy] = [] + for child_idx, child_strtgy in enumerate(follow_strategy.childs): + assert isinstance(child_strtgy, OpStrategy) + args_schema: List[StrategyType] = [ + arg_strategy.childs[child_idx] for arg_strategy in args_strategies + ] + pointwise_strategy: OpStrategy = common_pointwise_strategy( + mesh, args_schema, child_strtgy, linearity + ) + list_strategy.append(pointwise_strategy) + return TupleStrategy(list_strategy) + + +def list_linear_pointwise_strategy( + mesh: DeviceMesh, op_schema: OpSchema +) -> StrategyType: + """ + for each list op stratgy that supports linearity + """ + return list_pointwise_strategy(mesh, op_schema, linearity=True) + + +for op in for_each_ops: + register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))( + list_pointwise_strategy + ) + +for op in for_each_linearity_ops: + register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))( + list_linear_pointwise_strategy + ) + +fused_ops = [ + aten._fused_adam_.default, + aten._fused_adam.default, + aten._fused_adam.tensor_lr, + aten._fused_adam_.tensor_lr, + aten._fused_adamw_.default, + aten._fused_adamw.default, + aten._fused_adamw.tensor_lr, + aten._fused_adamw_.tensor_lr, +] + +for op in fused_ops: + register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))( + list_pointwise_strategy + ) diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_random_ops.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..726b25e1eed023a2b25a53b36b4bb13d1c75a787 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_random_ops.py @@ -0,0 +1,38 @@ +# mypy: allow-untyped-decorators +# Copyright (c) Meta Platforms, Inc. 
and affiliates +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementStrategy, + StrategyType, +) +from torch.distributed.tensor._ops.utils import is_tensor_partial, register_op_strategy + + +aten = torch.ops.aten + + +@register_op_strategy( + [ + aten.normal_.default, + aten.uniform_.default, + aten.native_dropout.default, + aten.bernoulli_.float, + aten.bernoulli.default, + ] +) +def random_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + self_strategy = op_schema.args_schema[0] + assert isinstance(self_strategy, OpStrategy) + + random_strategy = OpStrategy([]) + for arg_strategy in self_strategy.strategies: + arg_spec = arg_strategy.output_spec + if is_tensor_partial(arg_spec): + # TODO: figure out how inplace random op should behave when it's partial + raise RuntimeError(f"{op_schema.op} with Partial is not supported yet!") + random_strategy.strategies.append(PlacementStrategy(output_specs=arg_spec)) + + return random_strategy diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_tensor_ops.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_tensor_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e9bcb3b0d122402530450b105ba637ab7fc469dc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_tensor_ops.py @@ -0,0 +1,792 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from typing import cast, List, Optional, Sequence, Tuple + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( + _is_inplace_op, + OpSchema, + OpStrategy, + OutputSharding, + PlacementList, + PlacementStrategy, + RuntimeSchemaInfo, + StrategyType, + TupleStrategy, +) +from torch.distributed.tensor._ops._common_rules import pointwise_rule +from torch.distributed.tensor._ops._embedding_ops import _MaskPartial +from torch.distributed.tensor._ops.utils import ( + expand_to_full_mesh_op_strategy, + is_tensor_dim_sharded, + is_tensor_evenly_shardable, + is_tensor_partial, + normalize_dim, + register_op_strategy, + register_prop_rule, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +aten = torch.ops.aten + + +def default_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + # The default strategy just propagates the first input strategy + select_strategy = op_schema.args_schema[0] + assert isinstance(select_strategy, OpStrategy) + default_strategy = [] + for strategy in select_strategy.strategies: + # we create new DTensorSpecs even for default strategy to ensure that + # the tensor metas are distinct between the arguments and outputs + default_strategy.append( + PlacementStrategy( + output_specs=DTensorSpec( + mesh=strategy.output_spec.mesh, + placements=strategy.output_spec.placements, + ) + ) + ) + return OpStrategy(default_strategy) + + +register_op_strategy( + [ + aten.clone.default, + aten.contiguous.default, + aten.copy_.default, + aten.detach.default, + aten.fill_.Scalar, + aten.zero_.default, + ] +)(default_strategy) + +register_op_strategy( + aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]) +)(default_strategy) + + +@register_op_strategy( + [ + aten.equal.default, +
aten.is_same_size.default, + ] +) +def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + # equal_strategy deals with ops that compare two tensors; we need to make sure + # the sharding layouts of the two operands match, so we choose to follow the arg + # with the max num of shards. is_same_size is kept here for completeness, as it + # shares the same strategy in theory. + self_strategy, other_strategy = op_schema.args_schema + assert isinstance(self_strategy, OpStrategy) + assert isinstance(other_strategy, OpStrategy) + + select_strategy = ( + self_strategy + if self_strategy.max_num_shards() >= other_strategy.max_num_shards() + else other_strategy + ) + equal_strategy = OpStrategy([]) + + for arg_strategy in select_strategy.strategies: + arg_spec = arg_strategy.output_spec + if is_tensor_partial(arg_spec): + # if the arg_spec has a partial placement, reshard to replicate; + # otherwise the local shard tensor comparison would be invalid + output_spec = DTensorSpec( + mesh=arg_spec.mesh, + placements=tuple( + Replicate() if isinstance(p, Partial) else p + for p in arg_spec.placements + ), + ) + equal_strategy.strategies.append( + PlacementStrategy(output_specs=output_spec) + ) + else: + equal_strategy.strategies.append(PlacementStrategy(arg_spec)) + return equal_strategy + + +@register_op_strategy( + [ + aten.empty_like.default, + aten.ones_like.default, + aten.rand_like.default, + aten.randn_like.default, + aten.zeros_like.default, + ], + schema_info=RuntimeSchemaInfo(1, ["dtype"]), +) +@register_op_strategy( + [aten.full_like.default], + schema_info=RuntimeSchemaInfo(2, ["dtype"]), +) +@register_op_strategy( + [ + aten.randint_like.default, + aten.randint_like.low_dtype, + aten.randint_like.low_dtype_out, + ], + schema_info=RuntimeSchemaInfo(3, ["dtype"]), +) +def create_like_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + # create_like_strategy deals with ops that create tensors with the same + # shape as the input, but with specific content that does not depend on + # the input. we can propagate sharding, but we have to make sure we + # move from partial to replicated. + select_strategy = op_schema.args_schema[0] + create_like_strategy = OpStrategy([]) + assert isinstance(select_strategy, OpStrategy) + for arg_strategy in select_strategy.strategies: + arg_spec = arg_strategy.output_spec + if is_tensor_partial(arg_spec): + # if the arg_spec has a partial placement, accept partial + # in the input_specs but output replicate for + # the corresponding mesh dims + output_spec = DTensorSpec( + mesh=arg_spec.mesh, + placements=tuple( + Replicate() if isinstance(p, Partial) else p + for p in arg_spec.placements + ), + ) + create_like_strategy.strategies.append( + PlacementStrategy(output_specs=output_spec, input_specs=(arg_spec,)) + ) + + else: + create_like_strategy.strategies.append(PlacementStrategy(arg_spec)) + + return create_like_strategy + + +@register_op_strategy( + [ + aten.new_empty.default, + aten.new_full.default, + aten.new_ones.default, + aten.new_zeros.default, + aten.new_empty_strided.default, + ], + schema_info=RuntimeSchemaInfo(1, ["dtype"]), +) +def new_factory_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + # Currently there are two strategies: + # 1. let the output be replicated + # 2.
let the output follow the input if input and output have the same shape + input_strategy = op_schema.args_schema[0] + assert isinstance(input_strategy, OpStrategy) + input_shape = input_strategy.shape + output_shape = op_schema.args_schema[1] + assert isinstance(output_shape, list) + + new_factory_strategy = OpStrategy([]) + for arg_strategy in input_strategy.strategies: + input_spec = arg_strategy.output_spec + replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) + new_factory_strategy.strategies.append( + PlacementStrategy( + output_specs=replica_spec, + input_specs=(input_spec,), + redistribute_cost=[[0.0] * mesh.ndim], + ) + ) + + if tuple(input_shape) == tuple(output_shape) and input_spec.is_sharded(): + # NOTE: for new_empty_strided, currently the non-replicate sharding + # is supported only when the shape is evenly shardable + if ( + op_schema.op == aten.new_empty_strided.default + and not is_tensor_evenly_shardable(input_shape, input_spec) + ): + continue + + new_factory_strategy.strategies.append( + PlacementStrategy( + output_specs=input_spec, + input_specs=(input_spec,), + # encouraging new tensor placement to be the same as input + redistribute_cost=[[-0.1] * mesh.ndim], + ) + ) + + return new_factory_strategy + + +@register_op_strategy(aten.bucketize.Tensor) +def gen_bucketize_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + """Just propagate the input sharding, but expect the boundaries input to be replicated.""" + input_strategy = op_schema.args_schema[0] + bucketize_strategy = OpStrategy([]) + assert isinstance(input_strategy, OpStrategy) + for arg_strategy in input_strategy.strategies: + arg_spec = DTensorSpec(mesh, arg_strategy.output_spec.placements) + replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) + bucketize_strategy.strategies.append( + PlacementStrategy( + output_specs=arg_spec, input_specs=(arg_spec, replica_spec) + ) + ) + + return bucketize_strategy + + +@register_op_strategy(aten.slice.Tensor, schema_info=RuntimeSchemaInfo(1)) +def gen_slice_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + """Forward all shardings except the slice dimension.""" + defaults = (None, 0, None, None, 1) + input_strategy, dim, start, end, step = ( + op_schema.args_schema + defaults[len(op_schema.args_schema) :] + ) + assert isinstance(input_strategy, OpStrategy) + input_shape = input_strategy.shape + input_ndim = input_strategy.ndim + assert isinstance(dim, int) + if start is None: + start = 0 + if end is None or end > input_shape[dim]: + end = input_shape[dim] + assert isinstance(start, int) + assert isinstance(end, int) + assert isinstance(step, int) + + # normalize args + slice_dim = normalize_dim(dim, input_ndim) + start = normalize_dim(start, input_shape[dim]) + end = normalize_dim(end, input_shape[dim]) + + redundant_slice = start == 0 and end == input_shape[dim] and step == 1 + + slice_strategy = OpStrategy([]) + + for arg_strategy in input_strategy.strategies: + arg_spec = arg_strategy.output_spec + if not is_tensor_dim_sharded(arg_spec, dim=slice_dim) or redundant_slice: + # only add the strategy if the slice dim is not sharded + out_spec = DTensorSpec(mesh, arg_spec.placements) + slice_strategy.strategies.append(PlacementStrategy(output_specs=out_spec)) + if not slice_strategy.strategies: + # if all strategies are filtered out, unshard all specs on the slice dim + # of the input strategy and use that as the op strategy + for arg_strategy in input_strategy.strategies: + arg_spec = arg_strategy.output_spec +
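# as an illustration (hypothetical placements): an arg_spec placed as + # (Shard(0), Shard(1)) with slice_dim=0 falls back below to + # (Replicate(), Shard(1)), so the slice can run locally on each shard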
unshard_spec = DTensorSpec( + mesh, unshard_tensor_dim(arg_spec.placements, dim=slice_dim) + ) + slice_strategy.strategies.append( + PlacementStrategy(output_specs=unshard_spec) + ) + return slice_strategy + + +def unshard_tensor_dim( + placements: Sequence[Placement], dim: int +) -> Tuple[Placement, ...]: + """Disallow the given tensor dimension to be sharded.""" + return tuple( + p if (not isinstance(p, Shard) or p.dim != dim) else Replicate() + for p in placements + ) + + +def replicate_tensor_dim( + placements: Sequence[Placement], dim: int +) -> Tuple[Placement, ...]: + """Force the given tensor dimension to be replicated.""" + # Not using p.is_shard() to avoid mypy complaining about Placement not + # having the attribute dim. + return tuple( + Replicate() if p.is_partial() or isinstance(p, Shard) and p.dim == dim else p + for p in placements + ) + + +@register_op_strategy(aten.slice_scatter.default, schema_info=RuntimeSchemaInfo(2)) +def gen_slice_scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + # 1. number of dimensions in input and src need to match. + # 2. number of elements on all non-dim dimensions needs to match between input and src. + # 3. number of elements in src along dim needs to match the slice size. + # Given the above: + # - We suggest for src to follow the sharding of input, except on the scatter dimension, + # where our best bet for now is to make them replicated as a fall-back. + # TODO: Ideally we'd like to make sure the output is re-sharded afterwards to keep input sharding. + + input_strategy = op_schema.args_schema[0] + assert isinstance(input_strategy, OpStrategy) + input_ndim = input_strategy.ndim + slice_dim = ( + cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0 + ) + slice_dim = normalize_dim(slice_dim, input_ndim) + + slice_scatter_strategy = OpStrategy([]) + # by default follow the input strategy for both input and src + for arg_strategy in input_strategy.strategies: + arg_spec = arg_strategy.output_spec + if not ( + is_tensor_dim_sharded(arg_spec, dim=slice_dim) + or is_tensor_partial(arg_spec) + ): + # only add the strategy if the slice_scatter dim is not sharded or partial + slice_scatter_strategy.strategies.append( + PlacementStrategy(output_specs=arg_spec) + ) + + if not slice_scatter_strategy.strategies: + # if all strategies are filtered out, replicate all specs on the slice_scatter + # dim of the input strategy and use that as the op strategy + for arg_strategy in input_strategy.strategies: + arg_spec = arg_strategy.output_spec + replicate_spec = DTensorSpec( + mesh, replicate_tensor_dim(arg_spec.placements, dim=slice_dim) + ) + slice_scatter_strategy.strategies.append( + PlacementStrategy(output_specs=replicate_spec) + ) + return slice_scatter_strategy + + +@register_op_strategy(aten._local_scalar_dense.default) +def replica_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + """Only allow replication on the input/output.""" + replicate_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) + return OpStrategy([PlacementStrategy(replicate_spec)]) + + +@register_op_strategy( + [aten.scatter_.value, aten.scatter.value, aten.scatter_.src, aten.scatter.src], + schema_info=RuntimeSchemaInfo(1), +) +def scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, input, index, src] + # first we always have replicate all for inputs and output +
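# as an illustration, the single-mesh-dim candidate [Replicate()] * 4 below + # reads as: output, input, index and src all replicated on that mesh dim; + # expand_to_full_mesh_op_strategy then takes the product over all mesh dims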
if len(op_schema.args_strategy) < 3: + # scatter_.src/scatter.src with src being a float number instead of a tensor + all_replicate: PlacementList = [Replicate()] * 3 + else: + all_replicate = [Replicate()] * 4 + single_mesh_dim_strategies.append(all_replicate) + + # TODO: see if we can support input sharding pattern + inplace_op = _is_inplace_op(op_schema.op) + + op_strategy = expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, inplace_op=inplace_op + ) + return op_strategy + + +@register_op_strategy(aten.gather.default) +def gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + dim = cast(int, op_schema.args_schema[1]) + index_strategy = cast(OpStrategy, op_schema.args_schema[2]) + + input_shape = input_strategy.shape + index_shape = index_strategy.shape + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, input, index] + # first we always have replicate all for inputs and output + all_replicate: PlacementList = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # input sharding, input sharded, index accepts mask partial, output follows index + # this only works when the input is sharded on the gather dimension, and + # index has size 1 on the gather dimension + if index_shape[dim] == 1: + index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim) + input_sharding: PlacementList = [ + index_partial_placement, + Shard(dim), + index_partial_placement, + ] + single_mesh_dim_strategies.append(input_sharding) + + # index sharding, input replicated, index sharded, output follows index + # this only works when the sharding dimension is the gather dimension + index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)] + single_mesh_dim_strategies.append(index_sharding) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=1 + ) + + +def _derive_follow_placements_from_tuple_strategy( + tuple_strategy: TupleStrategy, +) -> Sequence[Placement]: + """ + derive the placements to follow from the tuple strategy, mainly used by + aten.stack, aten.cat, where each operand has the same shape and correspondingly + expects the same sharding + """ + + def merge_placement( + cur_placement: Placement, new_placement: Placement + ) -> Placement: + # semantics: if we already have a follow placement, we + # check each placement of the current arg placement + # to see if we want to merge/adjust the placement to follow; + # the priority is: Partial -> Shard -> Replicate + if cur_placement == new_placement: + return cur_placement + + if cur_placement.is_partial(): + if new_placement.is_shard(): + # follow new placement + return new_placement + elif new_placement.is_partial(): + # different partial types, we can't merge and have to replicate all here + return Replicate() + else: + # follow partial + return cur_placement + elif cur_placement.is_shard(): + if new_placement.is_shard(): + # cur/new placements are different shardings (i.e.
different shard dims) + # currently fall back to replicating all args + return Replicate() + else: + # for partial/replicate, follow the current shard placement + return cur_placement + else: + # current replicate, just follow new placement + return new_placement + + follow_placements: Optional[List[Placement]] = None + for arg_strategy in tuple_strategy.childs: + assert isinstance(arg_strategy, OpStrategy) + for placement_strategy in arg_strategy.strategies: + arg_placements = placement_strategy.output_spec.placements + if follow_placements is None: + follow_placements = list(arg_placements) + continue + mesh_ndim = len(follow_placements) + assert follow_placements is not None + for mesh_idx in range(mesh_ndim): + # merge placements with the priority + follow_placements[mesh_idx] = merge_placement( + follow_placements[mesh_idx], arg_placements[mesh_idx] + ) + assert follow_placements is not None, "follow placements should not be None!" + return follow_placements + + +def normalize_shard_for_stack( + placements: Sequence[Placement], insert_dim: int = 0 +) -> Sequence[Placement]: + # stack op would "insert" a new dim, so all sharded dims >= the inserted dim + # need to be normalized with the new Shard placement + normalized_placements: List[Placement] = [] + for placement in placements: + if isinstance(placement, Shard) and placement.dim >= insert_dim: + normalized_placements.append(Shard(placement.dim + 1)) + else: + normalized_placements.append(placement) + return normalized_placements + + +@register_op_strategy(aten.stack.default, RuntimeSchemaInfo(1, needs_pytree=True)) +def stack_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + args_schema = op_schema.args_schema + input_tuple_strategy = args_schema[0] + assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}" + first_input_strategy = input_tuple_strategy.childs[0] + assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}" + common_input_ndim = first_input_strategy.ndim + dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0 + # normalize the dim to be within the common input ndim + dim = normalize_dim(dim, common_input_ndim) + + follow_placements = _derive_follow_placements_from_tuple_strategy( + input_tuple_strategy + ) + + # create op strategy based on the follow placements + op_strategy = OpStrategy([]) + + input_specs = tuple( + DTensorSpec(mesh, tuple(follow_placements)) + for _ in range(len(input_tuple_strategy.childs)) + ) + + follow_placements = normalize_shard_for_stack(follow_placements, dim) + + op_strategy.strategies.append( + PlacementStrategy( + output_specs=DTensorSpec(mesh, tuple(follow_placements)), + input_specs=input_specs, + ) + ) + return op_strategy + + +@register_op_strategy(aten.cat.default, RuntimeSchemaInfo(1, needs_pytree=True)) +def cat_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + args_schema = op_schema.args_schema + input_tuple_strategy = args_schema[0] + assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}" + first_input_strategy = input_tuple_strategy.childs[0] + assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}" + common_input_ndim = first_input_strategy.ndim + dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0 + # normalize the dim to be within the common input ndim + dim = normalize_dim(dim, common_input_ndim) + + follow_placements = _derive_follow_placements_from_tuple_strategy( + input_tuple_strategy + ) + # for cat we unshard the cat dim if it
is sharded + follow_placements = unshard_tensor_dim(follow_placements, dim) + + # create op strategy based on the follow placements + op_strategy = OpStrategy([]) + + input_specs = tuple( + DTensorSpec(mesh, tuple(follow_placements)) + for _ in range(len(input_tuple_strategy.childs)) + ) + op_strategy.strategies.append( + PlacementStrategy( + output_specs=DTensorSpec(mesh, tuple(follow_placements)), + input_specs=input_specs, + ) + ) + return op_strategy + + +@register_prop_rule(aten.index_select.default, schema_info=RuntimeSchemaInfo(1)) +def prop_index_select(op_schema: OpSchema) -> OutputSharding: + values_spec, dim, indices_spec = op_schema.args_schema + + assert isinstance(values_spec, DTensorSpec) + assert isinstance(dim, int) + assert isinstance(indices_spec, DTensorSpec) + + all_indices_spec: List[Optional[DTensorSpec]] = [ + indices_spec if dim == i else None for i in range(values_spec.ndim) + ] + + result = prop_index( + OpSchema( + op=op_schema.op, + args_schema=(values_spec, all_indices_spec), + kwargs_schema=op_schema.kwargs_schema, + ) + ) + if result.redistribute_schema: + schema_suggestion = result.redistribute_schema + result.redistribute_schema = OpSchema( + op=op_schema.op, + args_schema=( + schema_suggestion.args_schema[0], + dim, + schema_suggestion.args_schema[1][dim], + ), + kwargs_schema=op_schema.kwargs_schema, + ) + return result + + +@register_prop_rule(aten.index.Tensor, schema_info=RuntimeSchemaInfo(needs_pytree=True)) +def prop_index(op_schema: OpSchema) -> OutputSharding: + """ + Expect replicated on the first input; _mostly_ pointwise on the second input. + + TODO: exception: when the dtype of the second input is "bool", then a torch.nonzero needs to be triggered first. + """ + # Current sharding constraints: + # For values: + # 1. We currently require that the dimension of values_spec be replicated or partial + # if they are being indexed on. + # 2. Other dimensions of values_spec can remain sharded if they are so. + # For indices: + # Indices can be either sharded or replicated. All index tensors need to be sharded + # in a compatible way, following the pointwise rule (including resolving Partial + # into either sharded or replicated) + + values_spec, multi_indices_spec = op_schema.args_schema + assert isinstance(values_spec, DTensorSpec) + assert isinstance(multi_indices_spec, list) + multi_indices_spec = cast(List[Optional[DTensorSpec]], multi_indices_spec) + valid_indices_spec: List[Tuple[int, DTensorSpec]] = [ + (i, a) for i, a in enumerate(multi_indices_spec) if a is not None + ] + + # 1. All indices have to be sharded equally. Moreover, indices can be broadcast. + # Here, we piggyback on the pointwise sharding rule for indices.
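# e.g. (hypothetical specs) if one index tensor is sharded as (Shard(0),) and + # another is replicated, the pointwise rule either returns a compatible common + # spec or suggests a redistribution, which is handled right below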
+ indices_out = pointwise_rule( + OpSchema( + op=op_schema.op, + args_schema=tuple(v[1] for v in valid_indices_spec), + kwargs_schema={}, + ) + ) + need_reshard_on_indices = indices_out.output_spec is None + + if not need_reshard_on_indices: + # this means that our inputs are already sharded properly and we will use that as our indices_spec + assert isinstance(indices_out.output_spec, DTensorSpec) + indices_spec: DTensorSpec = indices_out.output_spec + else: + assert indices_out.redistribute_schema is not None + valid_indices_suggestion = indices_out.redistribute_schema + for i, v in enumerate(valid_indices_suggestion.args_spec): + multi_indices_spec[valid_indices_spec[i][0]] = v + # we'll need to call pointwise_rule again to see what's our ideal indices_spec and then + # use that to compute our ideal values_spec + indices_output_spec = pointwise_rule(valid_indices_suggestion).output_spec + assert isinstance(indices_output_spec, DTensorSpec) + indices_spec = indices_output_spec + + lookup_dims = {v[0] for v in valid_indices_spec} + + need_reshard_on_values = tuple( + (isinstance(vp, Shard) and (vp.dim in lookup_dims or isinstance(ip, Shard))) + for vp, ip in zip(values_spec.placements, indices_spec.placements) + ) + + if not need_reshard_on_indices and not any(need_reshard_on_values): + value_placements = values_spec.placements + + all_dims_consecutive = all( + b[0] - a[0] == 1 + for b, a in zip(valid_indices_spec[1:], valid_indices_spec[:-1]) + ) + if all_dims_consecutive: + # if all index vectors are consecutive, insert at the dimension of the first index + insert_dim: int = valid_indices_spec[0][0] + else: + # else, insert on the first dimension + insert_dim = 0 + + def place(vp: Placement, ip: Placement) -> Placement: + if isinstance(vp, Shard): + return Shard( + vp.dim + if vp.dim < insert_dim + # accounts for the offset in output dimensions + else vp.dim + + indices_spec.ndim + - sum(1 if vp.dim > v[0] else 0 for v in valid_indices_spec) + ) + if isinstance(ip, Shard): + return Shard(ip.dim + insert_dim) + # Partial or Replicated + return vp + + value_placements = tuple( + place(vp, ip) + for vp, ip in zip(values_spec.placements, indices_spec.placements) + ) + result = OutputSharding( + output_spec=DTensorSpec( + mesh=values_spec.mesh, + placements=value_placements, + ) + ) + return result + else: + result = OutputSharding( + output_spec=None, + redistribute_schema=OpSchema( + op=op_schema.op, + args_schema=( + DTensorSpec( + mesh=values_spec.mesh, + placements=tuple( + [ + Replicate() if need_reshard_on_values[i] else v + for i, v in enumerate(values_spec.placements) + ] + ), + tensor_meta=values_spec.tensor_meta, + ), + multi_indices_spec, + ), + kwargs_schema=op_schema.kwargs_schema, + ), + ) + return result + + +@register_prop_rule( + [ + aten.split.Tensor, + aten.split_with_sizes.default, + aten.split_with_sizes_copy.default, + ], + schema_info=RuntimeSchemaInfo(1), +) +def split_rule(op_schema: OpSchema) -> OutputSharding: + output_spec_list: List[DTensorSpec] = [] + input_spec = cast(DTensorSpec, op_schema.args_schema[0]) + ndim = input_spec.ndim + split_size_or_sections = op_schema.args_schema[1] + dim = cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0 + dim = normalize_dim(dim, ndim) + + # TODO: tensor to split cannot have Partial + # in its placements for now. Will need to + # support it in the future.
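# note: input_spec.sums lists the mesh dims carrying a pending (partial) sum, + # so a non-empty value means some Partial placement is present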
+ if input_spec.sums: + raise NotImplementedError( + f"splitting distributed tensor with " + f"Partial placement is not implemented!\n" + f"DTensorSpec={input_spec}" + ) + + # TODO: just like slice op, split replicates before + # splitting on a sharded dimension + need_reshard = False + if is_tensor_dim_sharded(input_spec, dim=dim): + need_reshard = True + input_spec = DTensorSpec( + mesh=input_spec.mesh, + placements=unshard_tensor_dim(input_spec.placements, dim=dim), + tensor_meta=input_spec.tensor_meta, + ) + + if need_reshard: + return OutputSharding( + None, + redistribute_schema=OpSchema( + op=op_schema.op, + args_schema=(input_spec,) + op_schema.args_schema[1:], + kwargs_schema=op_schema.kwargs_schema, + ), + ) + + def size_split(N, i): + # Last chunk will be smaller if the tensor size N + # along the given dimension dim is not divisible by i. + assert i > 0 + return [i] * (N // i) + ([N % i] if N % i != 0 else []) + + output_size_list = ( + size_split(input_spec.shape[dim], split_size_or_sections) + if isinstance(split_size_or_sections, int) + else split_size_or_sections + ) + output_spec_list = [ + DTensorSpec( + mesh=input_spec.mesh, + placements=input_spec.placements, + ) + for _ in range(len(output_size_list)) + ] + return OutputSharding(output_spec_list) diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_view_ops.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_view_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..451b92c80b24b8df0ed7719392302f836175e926 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_view_ops.py @@ -0,0 +1,666 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from dataclasses import dataclass +from typing import ( + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import torch +from torch import Tensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementStrategy, + RuntimeSchemaInfo, + StrategyType, +) +from torch.distributed.tensor._ops.utils import ( + generate_redistribute_costs, + normalize_dim, + normalize_dims, + prod, + register_op_strategy, +) +from torch.distributed.tensor.placement_types import Placement, Replicate, Shard + + +aten = torch.ops.aten + +Shape = Tuple[int, ...] + + +@dataclass +class DimSpec: + """Specifies how an output dimension maps to an input dimension.""" + + def inputs(self) -> Iterable["DimSpec"]: + return () + + +# Rules that map each dimension of the output to dimensions of the input tensor +DimMap = Tuple[DimSpec, ...] 
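+ +# As a rough illustration of the rule format (shapes hypothetical; InputDim, +# Flatten and view_groups are defined below): reshaping (2, 3, 4) to (2, 12) +# forwards dim 0 and flattens dims 1 and 2: +# +# view_groups((2, 3, 4), (2, 12)) +# == (InputDim(0), Flatten((InputDim(1), InputDim(2)))) +# +# so sharding on tensor dim 0 or on dim 1 (the leftmost flattened dim) can be +# preserved across the reshape, while sharding on dim 2 has to be replicated first.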
+ + +@dataclass +class Singleton(DimSpec): + """Output dimension is a singleton.""" + + +@dataclass +class InputDim(DimSpec): + """Output dimension maps directly to an input dimension.""" + + input_dim: int + + +@dataclass +class Broadcast(DimSpec): + """Output is the broadcast of a singleton input dimension.""" + + dim: DimSpec + dim_size: int + + @classmethod + def new(cls, dim: DimSpec, dim_size: int) -> DimSpec: + return Broadcast(dim, dim_size) + + def inputs(self) -> Iterable[DimSpec]: + return (self.dim,) + + +@dataclass +class NewDim(DimSpec): + """This is a new dimension created by the op.""" + + size: int + + @classmethod + def new(cls, size: int) -> DimSpec: + return Singleton() if size == 1 else NewDim(size) + + +@dataclass +class Repeat(DimSpec): + """Output dimension is the input dimension repeated n-times.""" + + input_dim: DimSpec + times: int + + @classmethod + def new(cls, dim: DimSpec, times: int) -> DimSpec: + if times == 1: + return dim + elif isinstance(dim, Singleton): + # repeating a singleton is the same as broadcasting it + return Broadcast(dim, times) + else: + return Repeat(dim, times) + + def inputs(self) -> Iterable[DimSpec]: + return (self.input_dim,) + + +@dataclass +class Flatten(DimSpec): + """Flatten a set of input dimensions, ensuring right-most adjacent elements remain adjacent in the output.""" + + input_dims: Sequence[DimSpec] + + @classmethod + def new(cls, dims: Sequence[DimSpec]) -> DimSpec: + if len(dims) == 0: + # flattening a scalar leads to a singleton + return Singleton() + elif len(dims) == 1: + # flattening a single dimension is no-op + return dims[0] + else: + return Flatten(dims) + + def inputs(self) -> Iterable[DimSpec]: + return self.input_dims + + +@dataclass +class Split(DimSpec): + """ + This dimension is a member of a decomposition of the input dim. + + Note that input_dim itself could be a Flattened set of input dims. + """ + + input_dim: DimSpec + group_shape: Shape + split_id: int + + @classmethod + def new(cls, dim: DimSpec, group_shape: Tuple[int, ...], idx: int) -> DimSpec: + assert len(group_shape) > 0 + if len(group_shape) == 1: + # not really a group, just return the input dim back + assert idx == 0 + return dim + elif group_shape[idx] == 1: + return Singleton() + else: + # remove singletons from group + # group_mapping = [(new_index, (shape, old_index)) ...] + group_mapping = list( + enumerate((s, i) for i, s in enumerate(group_shape) if s != 1) + ) + new_group_shape = tuple(m[1][0] for m in group_mapping) + new_idx = next(filter(lambda x: x[1][1] == idx, group_mapping))[0] + return Split(dim, new_group_shape, new_idx) + + def inputs(self) -> Iterable[DimSpec]: + return (self.input_dim,) + + +def dim_pad_left(ndim: int, min_dims: int) -> DimMap: + return (Singleton(),) * max(0, min_dims - ndim) + tuple( + InputDim(i) for i in range(ndim) + ) + + +def dim_atleast_3d(ndim: int) -> DimMap: + if ndim == 0: + return (Singleton(), Singleton(), Singleton()) + elif ndim == 1: + return (Singleton(), InputDim(0), Singleton()) + elif ndim == 2: + return (InputDim(0), InputDim(1), Singleton()) + else: + return tuple(InputDim(i) for i in range(ndim)) + + +def expand(input_shape: Shape, shape: Shape) -> DimMap: + """Implement broadcast on multiple dimensions.""" + assert len(shape) >= len(input_shape) + + # 1. create padded input dimensions + padded_input = dim_pad_left(len(input_shape), len(shape)) + # 2. 
check that input shapes are compatible + mapping = [] + for p, desired_s in zip(padded_input, shape): + if isinstance(p, Singleton): + actual_s = 1 + assert desired_s >= 0 + else: + assert isinstance(p, InputDim), f"DimSpec not supported in expand: {p}" + actual_s = input_shape[p.input_dim] + assert actual_s == 1 or desired_s == -1 or desired_s == actual_s + mapping.append( + p + if desired_s in (1, -1) or desired_s == actual_s + else Broadcast.new(p, desired_s) + ) + return tuple(mapping) + + +def normalize_sizes(sizes: Union[Shape, Tuple[Shape]]) -> Shape: + if isinstance(sizes[0], int): + return cast(Shape, sizes) + elif len(sizes) == 1: + return sizes[0] + else: + raise RuntimeError("Size must be int... or tuple") + + +def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap: + if ndim == 0: + return (Singleton(),) + elif ndim == 1: + return (InputDim(0),) + else: + # only flattening dims from start_dim to end_dim (inclusive) + # other dims are passed through + if end_dim < 0: + end_dim += ndim + results: List[DimSpec] = [InputDim(i) for i in range(start_dim)] + results.append( + Flatten.new(tuple(InputDim(i) for i in range(start_dim, end_dim + 1))) + ) + results.extend([InputDim(i) for i in range(end_dim + 1, ndim)]) + return tuple(results) + + +def dim_movedim( + ndim: int, + input: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], +) -> DimMap: + input = normalize_dims(input, ndim) + destination = normalize_dims(destination, ndim) + + assert len(input) == len(destination) + input_set = set(input) + assert len(input_set) == len(input), "Found repeated input dims" + assert len(set(destination)) == len(destination), "Found repeated output dims" + assert max(input) < ndim + assert max(destination) < ndim + + dest = [-1] * ndim + for i, d in zip(input, destination): + dest[d] = i + + unused_inputs_iter = iter(i for i in range(ndim) if i not in input_set) + for i in range(ndim): + if dest[i] == -1: + dest[i] = next(unused_inputs_iter) + + return tuple(InputDim(i) for i in dest) + + +def dim_repeat(ndim: int, sizes: Shape) -> DimMap: + sizes = normalize_sizes(sizes) + assert ( + len(sizes) >= ndim + ), f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}." + pad = len(sizes) - ndim + return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple( + Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:]) + ) + + +def infer_size(total_size: int, sizes: Shape) -> Shape: + """ + One dimension input to view may be "-1". + + Infer the size of this dimension given the total_size. + """ + infers = [i for i, s in enumerate(sizes) if s == -1] + size = prod(sizes) + assert len(infers) <= 1, "can only infer one size" + if infers: + size = -size + missing_size = total_size // size + assert ( + total_size % size == 0 + ), f"size inferred for -1 is not integral {sizes} should have {total_size} elements." + return tuple(s if s != -1 else missing_size for s in sizes) + assert size == total_size, f"sizes do not match {total_size} vs {size}" + return sizes + + +def view_groups(from_size: Shape, to_size: Shape) -> DimMap: + """ + Decompose a reshape operation into forwarding, flattening, or splitting dimensions for each output dimension. 
+ + A view or reshape operation can be decomposed into a set of 3 types of smaller operations: + 1) Forward a dimension from input to output + 2) Flatten a set of dimensions into a single dimension + 3) Split one dimension into multiple dimensions + + view_groups identifies these operations and returns, for each output dimension, which + operation was performed on the input dimensions. For example: + + view_groups([2, 3, 4], [2, 12]) -> ( + InputDim(0), + Flatten((InputDim(1), InputDim(2))) + ) + + - output dimension 0 maps to input dimension 0 + - output dimension 1 maps to the flattened input dimensions 1 and 2 + + + view_groups([2, 3], [3, 2]) -> ( + Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0), + Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1), + ) + + - in the above, the input is flattened into a single dimension and then split + into two separate dimensions with different sizes from the input. + """ + from_nelem = prod(from_size) + to_size = infer_size(from_nelem, normalize_sizes(to_size)) + + assert from_nelem == prod(to_size), "Total view shape does not add up" + + from_idx = 0 + to_idx = 0 + from_len = len(from_size) + to_len = len(to_size) + + result_pp = [] + + while from_idx < from_len or to_idx < to_len: + from_group_dim, to_group_shape = [], [] + + if from_idx >= from_len: + f = 1 + else: + f = from_size[from_idx] + from_group_dim.append(from_idx) + from_idx += 1 + + if to_idx >= to_len: + t = 1 + else: + t = to_size[to_idx] + to_group_shape.append(t) + to_idx += 1 + + # if either group is a singleton, we need to backtrack + if f == 1 and t != 1: + # produces ([1], []) + to_idx -= 1 + to_group_shape = [] + elif f != 1 and t == 1: + # produces ([], [1]) + from_idx -= 1 + from_group_dim = [] + else: + # produces ([1], [1]), ([2], [2]), ([2,3], [6]) + while f != t: + if f < t: + nf = from_size[from_idx] + from_group_dim.append(from_idx) + from_idx += 1 + f *= nf + else: + nt = to_size[to_idx] + to_group_shape.append(nt) + to_idx += 1 + t *= nt + + if len(to_group_shape) > 0: + flattened = Flatten.new( + tuple(InputDim(fi) for fi in from_group_dim if from_size[fi] >= 1) + ) + result_pp += [ + Split.new(flattened, tuple(to_group_shape), i) + for i in range(len(to_group_shape)) + ] + + return tuple(result_pp) + + +def dim_tile(ndim: int, dims: Tuple[int, ...]) -> DimMap: + if len(dims) < ndim: + dims = (1,) * (ndim - len(dims)) + dims + return dim_repeat(ndim, dims) + + +def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap: + dim1 = normalize_dim(dim1, ndim) + dim2 = normalize_dim(dim2, ndim) + assert dim1 < ndim + assert dim2 < ndim + dimmap = [InputDim(i) for i in range(ndim)] + swapdim = dimmap[dim1] + dimmap[dim1] = dimmap[dim2] + dimmap[dim2] = swapdim + return tuple(dimmap) + + +def dim_squeeze(shape: Shape, dim: Optional[int] = None) -> DimMap: + # FIXME: this is wrong when dim=None and one of the dimensions + # equals the size of the mesh. For example squeeze(DTensor(tensor(4), Shard[0])) could + # end up as squeeze(tensor(1)) if we have 4 devices; this would lead to + # removal of a dimension that is not actually a singleton.
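# illustrative mapping (hypothetical shapes): shape (1, 4, 1) with dim=None + # keeps only InputDim(1); with dim=2 it keeps InputDim(0) and InputDim(1)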
+ return tuple( + InputDim(i) + for i, s in enumerate(shape) + if s > 1 or (dim is not None and i != normalize_dim(dim, len(shape))) + ) + + +def dim_unsqueeze(ndim: int, dim: int) -> DimMap: + dims = tuple(InputDim(i) for i in range(ndim)) + if dim < 0: + dim += ndim + 1 + return dims[:dim] + (Singleton(),) + dims[dim:] + + +def dim_view_as_real(shape: Shape) -> DimMap: + ndim = len(shape) + results: List[DimSpec] = [InputDim(i) for i in range(ndim - 1)] + # each complex number is split into two real numbers, + # resulting in one more dimension of size 2 + results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 0)) + results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 1)) + return tuple(results) + + +def dim_reduction( + ndim: int, dim_or_dims: Optional[Union[int, Sequence[int]]], keepdim: bool +) -> DimMap: + """ + General fallback for reduction ops where Partial() does not apply. + + This will cause incoming tensor to be replicated on the reducing dimensions. + """ + if dim_or_dims is None: + dim_or_dims = tuple(range(ndim)) + if isinstance(dim_or_dims, int): + dim_or_dims = (dim_or_dims,) + dim_or_dims = tuple(d if d >= 0 else d + ndim for d in dim_or_dims) + return tuple( + InputDim(i) if i not in dim_or_dims else Singleton() + for i in range(ndim) + if i not in dim_or_dims or keepdim + ) + + +dim_maps: Dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = { + torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1), + torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2), + torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim), + torch.broadcast_to: lambda input, shape: expand(input.shape, shape), + Tensor.expand: lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)), + torch.flatten: lambda tensor: dim_flatten(tensor.ndim), + torch.movedim: lambda input, source, destination: dim_movedim( + input.ndim, source, destination + ), + torch.permute: lambda input, dims: tuple( + InputDim(i) for i in normalize_dims(dims, input.ndim) + ), + torch.ravel: lambda tensor: dim_flatten(tensor.ndim), + Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes), + torch.reshape: lambda input, shape: view_groups(input.shape, shape), + torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim), + torch.tile: lambda input, dims: dim_tile(input.ndim, dims), + torch.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1), + torch.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim), + Tensor.view: lambda input, *shape: view_groups(input.shape, shape), + torch.view_as_complex: lambda input: dim_flatten(input.ndim, input.ndim - 2), + torch.view_as_real: lambda input: dim_view_as_real(input.shape), +} + + +def propagate_shape_and_sharding( + input_src_placements: Sequence[Placement], + local_in_shape: Shape, + rule: DimMap, + mesh_sizes: Shape, +) -> Tuple[Sequence[Placement], Sequence[Placement]]: + """ + Determine input target sharding and output sharding based on + given global tensor shape and input source sharding. + + Sharding propagation follows mapped dimensions: + - An output dimension that maps directly to an input dimension is sharded equally + - An output dimension that is a flattened set of input dimensions can only be + sharded if only the leftmost flattened dimension is sharded. 
+ - An output dimension that is a split of the input dimension can only be sharded + if the leftmost split size is divisible by the mesh dimension + """ + assert len(input_src_placements) == len(mesh_sizes) + # for each input dim, for each mesh dim, provides a list of possible shardable dimensions + mesh_ndim = len(mesh_sizes) + shardable_dims: Dict[int, List[bool]] = {} + + # in case an input dimension disappears (e.g. collapsing, reduction) + # we cannot shard in that dimension (we need a replication fall-back rule) + seen_input_dims: Set[int] = set() + + def collect_used_inputs(cmd: DimSpec) -> None: + if isinstance(cmd, InputDim): + seen_input_dims.add(cmd.input_dim) + for inp in cmd.inputs(): + collect_used_inputs(inp) + + for cmd in rule: + collect_used_inputs(cmd) + for dim in range(len(local_in_shape)): + shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim + + def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: + if isinstance(cmd, InputDim): + return cmd + elif isinstance(cmd, Flatten): + for dim in cmd.input_dims[1:]: + if isinstance(dim, InputDim): + shardable_dims[dim.input_dim] = [False] * mesh_ndim + dim0 = cmd.input_dims[0] + return dim0 if isinstance(dim0, InputDim) else None + elif isinstance(cmd, Split): + in_dim = get_in_dim_to_shard(cmd.input_dim) + out_size = cmd.group_shape[cmd.split_id] + if cmd.split_id == 0 and in_dim is not None: + # we need to check that the input dimension is divisible + # by the size of the submesh we're sharding it on + # NOTE: it would be possible to shard the same input dimension + # on more than one mesh dimension. In that case, the dimension + # needs to be divisible by the product of mesh sizes. + # In order to keep the problem more tractable, we will not consider + # double resharding as a suggestion (e.g. [Shard(0), Shard(0) ]) + # but we will allow it if that's the input and it's compatible + + # 1. is this dimension shardable on each individual mesh dim? + shardable_dims[in_dim.input_dim] = [ + out_size % mesh_dim_size == 0 for mesh_dim_size in mesh_sizes + ] + + # 2. here we special case things like [Shard(0), Shard(0)] + submesh_size = 1 + for size, shard in zip(mesh_sizes, input_src_placements): + if isinstance(shard, Shard) and shard.dim == in_dim: + submesh_size *= size + assert ( + out_size % submesh_size == 0 + ), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}." 
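# e.g. (hypothetical) a dim split with group_shape (6, 2) on mesh sizes + # (2, 3): out_size 6 divides by both, so the split_id 0 component stays + # shardable on either mesh dim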
+ + # we will only shard our first component of the split + return in_dim if cmd.split_id == 0 else None + elif isinstance(cmd, Repeat): + in_dim = get_in_dim_to_shard(cmd.input_dim) + if in_dim is not None: + shardable_dims[in_dim.input_dim] = [False] * mesh_ndim + return None + else: + return None + + # for each output dim, find the corresponding input dim in terms of sharding prop + shard_dim_map = {} + for dim, cmd in enumerate(rule): + in_dim = get_in_dim_to_shard(cmd) + if in_dim is not None: + shard_dim_map[in_dim.input_dim] = dim + + input_tgt_placements = [ + Replicate() + if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim] + else p + for mesh_dim, p in enumerate(input_src_placements) + ] + output_placements = [ + Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p + for p in input_tgt_placements + ] + + return input_tgt_placements, output_placements + + +def register_op_strategy_map( + aten_op_overload: torch._ops.OpOverload, + local_op_name: Callable[..., torch.Tensor], + schema_info: Optional[RuntimeSchemaInfo] = None, +) -> None: + dim_map: Callable[..., DimMap] = dim_maps[local_op_name] + + @register_op_strategy(aten_op_overload, schema_info=schema_info) + def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + rules = dim_map(*op_schema.args_schema, **op_schema.kwargs_schema) + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + global_in_shape = input_strategy.shape + assert global_in_shape is not None, "Shape required." + + output_strategy = OpStrategy([]) + for input_placement_strategy in input_strategy.strategies: + input_src_spec = input_placement_strategy.output_spec + + input_tgt_placements, output_placements = propagate_shape_and_sharding( + input_src_spec.placements, + tuple(global_in_shape), + rules, + mesh.shape, + ) + + # TODO: optimize this. we shouldn't simply blindly replicate + # unshardable dims ... 
+ # FIXME: this can be wrong for situations where we have + # [Shard(0), Shard(0)] + input_tgt_spec = DTensorSpec( + placements=tuple(input_tgt_placements), + mesh=input_src_spec.mesh, + tensor_meta=input_src_spec.tensor_meta, + ) + redistribute_costs = [ + generate_redistribute_costs(input_strategy, input_tgt_spec) + ] + + output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements)) + output_strategy.strategies.append( + PlacementStrategy( + output_specs=output_spec, + input_specs=(input_tgt_spec,), + redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +register_op_strategy_map(aten.squeeze.default, torch.squeeze) +register_op_strategy_map( + aten.squeeze.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.reshape.default, torch.reshape, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten._unsafe_view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.unsqueeze.default, torch.unsqueeze, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.expand.default, Tensor.expand, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.permute.default, torch.permute, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.repeat.default, Tensor.repeat, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.transpose.int, torch.transpose, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map(aten.view_as_complex.default, torch.view_as_complex) +register_op_strategy_map(aten.view_as_real.default, torch.view_as_real) diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/utils.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..334bcc4a37fea343950b63e62a821e16e432acc5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/utils.py @@ -0,0 +1,280 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import functools +import itertools +import operator +from typing import cast, Iterable, List, Optional, Sequence, Tuple, Union + +import torch +from torch.distributed.tensor._api import DTensor +from torch.distributed.tensor._collective_utils import redistribute_cost +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementList, + PlacementStrategy, + RuntimeSchemaInfo, +) +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +# convenient wrapper to register sharding propagation rules +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. +def register_prop_rule(op, schema_info=None): + # pyre-fixme[53]: Captured variable `func` is not annotated. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. 
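# typical use, mirroring the prop rules in the op files above (the op and + # schema_info shown are illustrative): + # + # @register_prop_rule(aten.index_select.default, schema_info=RuntimeSchemaInfo(1)) + # def prop_rule(op_schema: OpSchema) -> OutputSharding: ...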
+ def wrapper(impl): + overloads = op if isinstance(op, list) else [op] + for overload in overloads: + DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule( + overload, impl, schema_info + ) + return impl + + return wrapper + + +def register_op_strategy(op, schema_info=None): + # pyre-fixme[53]: Captured variable `func` is not annotated. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + + # For every ATen op that accepts any args in this list, + # the arg itself can impact the strides (and potentially the sharding strategy) + # of the output tensor. + # thus, we will detect ATen schemas with any of these args and ensure + # that they get specialized here. + arg_names_that_require_specializing_cache_strategy = [ + "memory_format", + ] + + def wrapper(impl): + if isinstance(op, list): + overloads = op + else: + overloads = [op] + + for overload in overloads: + curr_schema_info = None + if schema_info is None: + specialized_args = [ + a.name + for a in overload._schema.arguments + if a.name in arg_names_that_require_specializing_cache_strategy + ] + if any(specialized_args): + curr_schema_info = RuntimeSchemaInfo( + static_kwargkey=specialized_args + ) + else: + curr_schema_info = schema_info + DTensor._op_dispatcher.sharding_propagator.register_op_strategy( + overload, impl, curr_schema_info + ) + return impl + + return wrapper + + +def as_list( + x: Union[List[object], object] + # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type. +) -> Union[List[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type] + # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args, + # which is an object but treated as a list by the tracer. Therefore, keep + # `immutable_list` intact here as well. 
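# e.g. as_list(3) -> [3], while as_list([1, 2]) and a tracer-produced + # immutable_list are returned unchanged (illustrative values)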
+ if type(x) is list or isinstance(x, torch.fx.immutable_collections.immutable_list): + return x + else: + return [x] + + +def normalize_dim(dim: int, ndim: int) -> int: + return dim if dim >= 0 else dim + ndim + + +def normalize_dims(dims: Union[int, Sequence[int]], ndim: int) -> Sequence[int]: + """Normalize a dim or a sequence of dims, so that they are all positive.""" + if isinstance(dims, int): + dims = (normalize_dim(dims, ndim),) + elif isinstance(dims, list): + dims = [normalize_dim(dim, ndim) for dim in dims] + elif isinstance(dims, tuple): + dims = tuple([normalize_dim(dim, ndim) for dim in dims]) + return dims + + +def prod(xs: Iterable[int]) -> int: + return functools.reduce(operator.mul, xs, 1) + + +def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: + """Check if the shape is shardable according to the spec.""" + # number of shards in each tensor dimension + shards_map = [1] * len(shape) + for i, placement in enumerate(spec.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + shards_map[shard_dim] *= spec.mesh.size(i) + + for i, dim_size in enumerate(shape): + # TODO: maybe we should determine is_shardable based on + # whether it's evenly sharded or not + if shards_map[i] > 1 and dim_size < shards_map[i]: + return False + + return True + + +def is_tensor_evenly_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: + """Check if the shape is evenly shardable according to the spec.""" + # number of shards in each tensor dimension + shards_map = [1] * len(shape) + for i, placement in enumerate(spec.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + shards_map[shard_dim] *= spec.mesh.size(i) + + for i, dim_size in enumerate(shape): + if shards_map[i] > 1 and (dim_size % shards_map[i] != 0): + return False + + return True + + +def is_tensor_dim_sharded(spec: DTensorSpec, dim: int) -> bool: + """Return True if tensor dim is sharded.""" + return any(p.is_shard(dim) for p in spec.placements) + + +def is_tensor_partial(spec: DTensorSpec) -> bool: + """Return True if tensor is partial on the mesh.""" + return any(p.is_partial() for p in spec.placements) + + +def infer_broadcast_dims_map( + common_shape: torch.Size, input_shape: torch.Size +) -> List[int]: + # infer the broadcast dims map, where it maps from the common shape dim to the input shape dim + # this is aligned with the broadcast semantics + common_ndim = len(common_shape) + input_ndim = len(input_shape) + broadcast_dims_map = [-1] * common_ndim + for idx in range(-1, -1 - input_ndim, -1): + if input_shape[idx] == common_shape[idx]: + broadcast_dims_map[common_ndim + idx] = input_ndim + idx + return broadcast_dims_map + + +def map_placements_after_broadcast( + placements: Tuple[Placement, ...], + shape: torch.Size, + broadcast_dims_map: List[int], +) -> Tuple[Placement, ...]: + """Map each placement based on the output shape after broadcast.""" + new_placements: List[Placement] = [] + for placement in placements: + if isinstance(placement, (Replicate, Partial)): + new_placements.append(placement) + else: + assert isinstance(placement, Shard) + shard_dim = normalize_dim(placement.dim, len(shape)) + new_shard_dim = broadcast_dims_map[shard_dim] + if new_shard_dim != -1: + # there's a map from the common shape shard dim to + # the input shape shard dim before broadcasting, + # use that instead + new_placements.append(Shard(new_shard_dim)) + else: + # there's no map between common shape shard dim and + # the input shape shard dim before 
broadcasting, + # in this case it means implicit broadcasting happen + # in this dim, so we can just mark it as replicate + # and implict broadcast will broadcast automatically + # to the sharded shape + new_placements.append(Replicate()) + + return tuple(new_placements) + + +def generate_redistribute_costs( + src_strategy: OpStrategy, dst_spec: DTensorSpec +) -> List[float]: + redistribute_costs: List[float] = [] + for strat in src_strategy.strategies: + redistribute_costs.append(redistribute_cost(strat.output_spec, dst_spec)) + + return redistribute_costs + + +def expand_to_full_mesh_op_strategy( + mesh: DeviceMesh, + op_schema: OpSchema, + single_mesh_dim_strategies: List[PlacementList], + *, + input_index: int = 1, + inplace_op: bool = False, +) -> OpStrategy: + # Expand the single_mesh_dim_strategies to full mesh dim strategies. + all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim + + strategy_combs = itertools.product(*all_mesh_dim_strategies) + + all_strategies = [] + for strategy_comb in strategy_combs: + spec_list: List[Optional[DTensorSpec]] = [] + for specs in zip(*strategy_comb): + if specs[0] is not None: + spec_list.append(DTensorSpec(mesh, specs)) + else: + spec_list.append(None) + + input_specs: List[DTensorSpec] = [ + s for s in spec_list[input_index:] if isinstance(s, DTensorSpec) + ] + + input_args_strategy = op_schema.args_strategy + assert len(input_specs) == len(input_args_strategy) + self_spec = input_args_strategy[0].strategies[0].output_spec + + if inplace_op and self_spec.placements != input_specs[0].placements: + # if it's inplace op, we would only allow the placement strategy to be added when the + # input_spec matches the first argument's runtime sharding, otherwise we skip + continue + + # check inputs shardable + inputs_shardable = all( + is_tensor_shardable(inp.shape, s) + for inp, s in zip(input_args_strategy, input_specs) + ) + + # only add to the all_strategies list when all inputs are shardable + if inputs_shardable: + redistribute_cost = [ + generate_redistribute_costs(input_strategy, input_spec) + for input_strategy, input_spec in zip(input_args_strategy, input_specs) + ] + if input_index > 1: + output_specs = tuple(spec_list[:input_index]) + else: + if spec_list[0] is not None: + output_specs = spec_list[0] # type: ignore[assignment] + else: + raise RuntimeError("output spec is None") + strategy = PlacementStrategy( + output_specs=output_specs, + input_specs=input_specs, + redistribute_cost=redistribute_cost, + ) + all_strategies.append(strategy) + + return OpStrategy(all_strategies) diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_random.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_random.py new file mode 100644 index 0000000000000000000000000000000000000000..db4b2832548c07525593329893eb4ec900de36e9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_random.py @@ -0,0 +1,381 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. 
and affiliates +import contextlib +import warnings +from typing import Dict, List, Optional + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed.device_mesh import _get_device_handle, DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import Shard + + +__all__ = [ + "is_rng_supported_mesh", + "manual_seed", + "OffsetBasedRNGTracker", + "TensorParallelRNGTracker", +] + +_rng_tracker: Optional["_RNGStateTracker"] = None + + +def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool: + """Checks if the current device of ``device_mesh`` supports DTensor's random APIs. + Currently DTensor Random APIs only supports cuda/cuda-like devices. We suggest + users call this API to test the availability before using our random APIs. + + Args: + device_mesh (:class:`DeviceMesh`): The device mesh on which we check if the + random ops APIs are supported. + + Returns: + A bool value. True if ``device_mesh`` supports DTensor Random APIs; False otherwise. + + .. warning:: + Currently we only support correct RNG on cuda/cuda-like devices. + """ + device_handle = _get_device_handle(device_mesh.device_type) + if device_handle and hasattr(device_handle, "set_rng_state"): + return True + else: + # TODO: Logs way too much + warnings.warn( + f"DTensor random operators may not have complete support on {device_mesh.device_type} device mesh" + ) + return False + + +def manual_seed(seed: int, device_mesh: DeviceMesh) -> None: + """Sets the seed for generating random numbers for the calling rank. + + Args: + seed (int): The desired seed. + device_mesh (:class:`DeviceMesh`): The device mesh to set the seed. + + Returns: + None + + .. warning:: + When calling this function, :func:`manual_seed` must be called from all ranks of the + default ``ProcessGroup`` even if some ranks may not be a part of the ``device_mesh``, + with the same ``seed`` value. + If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it, + ``manual_seed`` will not set its GPU device's generator seed. + Current implementation only supports a GPU device mesh. + """ + device_handle = _get_device_handle(device_mesh.device_type) + if not device_handle: + raise NotImplementedError( + f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}" + ) + + # allgather the seed over the default PG + object_list = [seed] * dist.get_world_size() + dist.all_gather_object(object_list, seed) + for rank, object in enumerate(object_list): + if seed != int(object): + raise RuntimeError( + f"calling manual_seed function over {device_mesh} but received different seed values on ranks:", + f"seed on rank {dist.get_rank()} is {seed}, and seed on rank {rank} is {object}!", + ) + # instantiate a RNG tracker if haven't. By default DTensor uses an + # OffsetBasedRNGTracker to perform random operators. 
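A minimal usage sketch of the seeding API (assumptions: the job is launched with an initialized default process group, e.g. via torchrun, on CUDA devices; `init_device_mesh` is the public mesh constructor):

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor._random import is_rng_supported_mesh, manual_seed

# build a 1-D mesh over all ranks, then seed DTensor's RNG tracker
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
if is_rng_supported_mesh(mesh):
    manual_seed(42, mesh)  # all ranks must pass the same seed, or this raises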
+ global _rng_tracker + if not _rng_tracker: + _rng_tracker = OffsetBasedRNGTracker(device_mesh.device_type) + + # the current rank is in mesh + if device_mesh.get_coordinate() is not None: + if isinstance(_rng_tracker, TensorParallelRNGTracker): + _rng_tracker._manual_seed(device_mesh, seed) + elif isinstance(_rng_tracker, OffsetBasedRNGTracker): + _rng_tracker._manual_seed(seed) + else: + raise RuntimeError( + f"Unknown type of cuda RNG state tracker: _rng_tracker = {_rng_tracker}" + ) + + +class _RNGStateTracker: + """ + _RNGStateTracker stores Random Number Generator (RNG) state (a ByteTensor object) + in a dict, mapping from a corresponding tag to each state tensor. It also provides + a set of convenient utility methods to help access/modify the state tensors. The most + important interface is _distribute_region which will be used when DTensor executes + a random op (an operator that calls RNG). + """ + + def __init__(self, device_type: str = "cuda"): + self._device_type = device_type + self._device_handle = _get_device_handle(device_type) + if not (self._device_handle and self._device_handle.is_available()): + raise RuntimeError( + f"{self.__class__.__name__} instantiation requires the presence of CUDA/CUDA-like device" + ) + + self._states: Dict[str, Tensor] = {} + self._devices = [self._device_handle.current_device()] + self._use_distribute_region = True + + @property + def rng_states(self) -> Dict[str, Tensor]: + return self._states + + @property + def distribute_region_enabled(self) -> bool: + return self._use_distribute_region + + @distribute_region_enabled.setter + def distribute_region_enabled(self, value) -> None: + self._use_distribute_region = value + + def rng_state_is_sync(self, name) -> bool: + return name in self.rng_states + + def get_seed(self, name: str) -> int: + if name not in self.rng_states: + raise RuntimeError( + f"{self.__class__.__name__} does not have random state for {name}" + ) + + seed_tensor = (self.rng_states[name])[0:8].view(dtype=torch.int64) + return int(seed_tensor.item()) + + def set_seed(self, name: str, seed: int) -> None: + seed_tensor = torch.tensor([seed]).view(torch.uint8) + offset_tensor = torch.tensor([0]).view(torch.uint8) + self.rng_states[name] = torch.cat([seed_tensor, offset_tensor]) + + def _distribute_region(self, spec: DTensorSpec): + pass + + +class OffsetBasedRNGTracker(_RNGStateTracker): + """ + This subclass of ``_RNGStateTracker`` defines the default policy of how RNG states + should be shared and synchronized among all ranks to respect the semantics of DTensor + random operators. + """ + + def __init__(self, device_type: str = "cuda"): + super().__init__(device_type) + # synchronize RNG state using rank 0's current one + rng_state = self._device_handle.get_rng_state().to(device_type) + dist.broadcast(rng_state, 0) + self.rng_states["parallel-rng"] = rng_state.to("cpu") + + def _manual_seed(self, parallel_seed: int) -> None: + self.set_seed("parallel-rng", parallel_seed) + + @contextlib.contextmanager + def _distribute_region(self, spec: DTensorSpec): + # check if the parallel rng state has been synchronized or not + if not self.rng_state_is_sync("parallel-rng"): + raise RuntimeError( + "OffsetBasedRNGTracker requires the random state to be synchronized " + "before entering into a distribute region!" 
+ ) + + if self.distribute_region_enabled: + old_offset = self.get_offset("parallel-rng") + self._set_pre_op_offset(spec) + with torch.random.fork_rng(self._devices, device_type=self._device_type): + self._device_handle.set_rng_state(self.rng_states["parallel-rng"]) + try: + yield # execute the region code + finally: + # update offset to synchronize among ranks + self._set_post_op_offset(spec, old_offset) + else: + yield + + def get_offset(self, name: str) -> int: + if name not in self.rng_states: + raise RuntimeError( + f"{self.__class__.__name__} does not have random state for {name}" + ) + + offset_tensor = (self.rng_states[name])[8:].view(dtype=torch.int64) + return int(offset_tensor.item()) + + def set_offset(self, name: str, offset: int) -> None: + if name not in self.rng_states: + raise RuntimeError( + f"{self.__class__.__name__} does not have random state for {name}" + ) + + seed_tensor = (self.rng_states[name])[0:8] + offset_tensor = torch.tensor([offset]).view(torch.uint8) + self.rng_states[name] = torch.cat([seed_tensor, offset_tensor]) + + def _set_pre_op_offset(self, spec: DTensorSpec) -> None: + """Set the starting RNG offset for the current device's local shard before actual + op execution. The pre_op_offset value should start from the current RNG offset + and increment by the size of the local shard until it reaches the size of the whole + DTensor. For different ranks that hold the same DTensor shard, their pre_op_offset + will be the same. + + Args: + spec (:class:`DTensorSpec`): the spec of the DTensor object on which + we prepare the offset for running random ops. + + Returns: + None + + .. warning:: + Note that the current implementation does not consider DTensor's contiguity. + + Example: + take a DTensor of shape [8, 16] as an example. Assume that the DTensor + is placed on a device mesh with placements ([Shard(1), Replicate(), Shard(0)]), + and the mesh is: + [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] + ``spec.mesh.get_coordinate()`` provides the coordinate of the current rank + in the mesh. For example, the coordinate of rank 5 is (1, 0, 1). + + Another concept to introduce besides the rank coordinate is the shard coordinate. + Each rank holds a local shard of the DTensor. In the example, the DTensor + is partitioned into 4 [4, 8] shards. The first shard has 2 replicas: + rank 0 (coord (0, 0, 0)) and rank 2 (coord (0, 1, 0)) each hold one. + That is, the local shards on rank 0 and rank 2 correspond to the same + shard of the DTensor. To denote each DTensor shard, we use a shard coordinate + (in the example, it will be a tuple (i, j) where shard (i, j) has the slice + DTensor[4 * i : 4 * (i + 1), 8 * j : 8 * (j + 1)], 0 <= i < 2, 0 <= j < 2). + + Once we have the rank coordinate and the shard coordinate, we can calculate on each rank + what shard of the DTensor the rank holds, with the help of dim_map. The dim_map + of the above DTensor is [2, 0], so the shard coordinate of a rank with rank coord + (x, y, z) is simply (z, x), obtained by taking (rank_coord[dim_map[0]], rank_coord[dim_map[1]]). + Following this calculation, + rank 0 and rank 2 hold the shard of coord (0, 0); + rank 1 and rank 3 hold the shard of coord (0, 1); + rank 4 and rank 6 hold the shard of coord (1, 0); + rank 5 and rank 7 hold the shard of coord (1, 1); + + The last value to calculate before obtaining the starting offset is the shard linear index. + The starting offset for each rank will be its shard_linear_index * local_tensor_numel. 
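+ + For instance, for rank 5 in the example above: shard_coord = (1, 1) and + shard_size = (2, 2), so shard_linear_idx = 1 * 2 + 1 = 3; the local [4, 8] + shard has numel 32, so the pre-op offset advances by 3 * 32 = 96 (rounded + up to a multiple of 4, which 96 already is).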
+ """ + dtensor_shape = spec.shape + mesh = spec.mesh + dim_map = spec.dim_map + + # Compute shard coordinate: + # The coordinate on each tensor dim is a tuple (idx, range) + # If a DTensor is partitioned on its dim i into n shards, and the current rank + # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i + coordinate = mesh.get_coordinate() + assert coordinate is not None + shard_coord = [ + coordinate[mesh_dim] if mesh_dim >= 0 else 0 for mesh_dim in dim_map + ] + shard_size = [ + mesh.size(mesh_dim) if mesh_dim >= 0 else 1 for mesh_dim in dim_map + ] + + # compute shard linear index + shard_linear_idx = self._calc_shard_linear_idx(shard_coord, shard_size) + + # compute starting offset using the first shard's size + local_size_on_rank_0 = list(dtensor_shape) + for idx, placement in enumerate(spec.placements): + if isinstance(placement, Shard): + mesh_dim_size = mesh.size(idx) + shard_dim = placement.dim + local_size_on_rank_0[shard_dim] = placement._local_shard_size_on_dim( + dtensor_shape[shard_dim], + mesh_dim_size, + 0, + return_offset=False, + )[0] + + from torch.distributed.tensor._ops.utils import prod + + local_size = prod(local_size_on_rank_0) + + # get current RNG offset + current_offset = self.get_offset("parallel-rng") + + # pytorch: offset must be multiple of 4 + # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp + offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4 + self.set_offset("parallel-rng", current_offset + offset_incr) + + def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None: + """Sets the RNG to a synchronized state after running the local random op. Every + rank should set its RNG offset to `old_offset + DTensor.numel()` where old_offset is + the offset before calling `set_pre_op_offset` i.e. the offset before running DTensor + random ops. + + Args: + spec (:class:`DTensorSpec`): the spec of the DTensor object on which + we post-process the offset for running random ops. 
+ + Returns: + None + """ + dtensor_shape = spec.shape + + from torch.distributed.tensor._ops.utils import prod + + numel = prod(dtensor_shape) + # pytorch: offset must be multiple of 4 + # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp + numel = (numel + 3) // 4 * 4 + self.set_offset("parallel-rng", old_offset + numel) + + def _calc_shard_linear_idx( + self, shard_coord: List[int], shard_size: List[int] + ) -> int: + # compute shard linear index + shard_linear_idx = 0 + shard_coord_stride = 1 + for idx, size in zip(reversed(shard_coord), reversed(shard_size)): + shard_linear_idx += idx * shard_coord_stride + shard_coord_stride *= size + + return shard_linear_idx + + +class TensorParallelRNGTracker(_RNGStateTracker): + def __init__(self, device_type: str = "cuda"): + super().__init__(device_type) + # copy the default RNG state + self.rng_states["tensor-parallel-rng"] = self._device_handle.get_rng_state() + + def _manual_seed( + self, + tp_mesh: DeviceMesh, + base_seed: int = 1234, + ): + tensor_parallel_rank = tp_mesh.get_local_rank() + # this magic number 2718 comes from Megatron's code + # (https://github.com/NVIDIA/Megatron-LM/blob/060415572f4365a2e895f8036c4e37dad0efbdf5/megatron/core/tensor_parallel/random.py#L162-L163) + MegatronMagicNum = 2718 + tensor_parallel_seed = base_seed + MegatronMagicNum + tensor_parallel_rank + self.set_seed("tensor-parallel-rng", tensor_parallel_seed) + + @contextlib.contextmanager + def _distribute_region(self, spec: DTensorSpec): + # check if the tensor parallel rng state has been synchronized or not + if not self.rng_state_is_sync("tensor-parallel-rng"): + raise RuntimeError( + "TensorParallelRNGTracker requires the random state to be synchronized " + "before entering into a distribute region!" + ) + + if self.distribute_region_enabled: + with torch.random.fork_rng(self._devices, device_type=self._device_type): + self._device_handle.set_rng_state( + self.rng_states["tensor-parallel-rng"] + ) + try: + yield + finally: + self.rng_states[ + "tensor-parallel-rng" + ] = self._device_handle.get_rng_state() + else: + yield diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_redistribute.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_redistribute.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4e37e54744f4d5b909258b0351b42650c3a594 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_redistribute.py @@ -0,0 +1,351 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +from functools import lru_cache +from typing import cast, List, NamedTuple, Tuple + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor._api as dtensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +logger = logging.getLogger(__name__) + + +class _TransformInfo(NamedTuple): + mesh_dim: int + src_dst_placements: Tuple[Placement, Placement] + # logical_shape on this mesh dimension + logical_shape: List[int] + + +@lru_cache(maxsize=None) +def _gen_transform_infos( + src_spec: DTensorSpec, + dst_spec: DTensorSpec, +) -> List[_TransformInfo]: + """ + Generate the transform infos from the source placements to the target placements. 
+ + Transforming from the source to the target placement may take multiple steps, i.e. it + might decompose Si -> Sj into Si -> R -> Sj. + This also detects mis-aligned/nested shardings between the src/dst placements. + E.g. suppose the redistribution to perform is (Shard(0), Shard(0)) -> (Replicate(), Shard(0)); + in this case Shard(0) -> Shard(0) for mesh dimension 1 actually needs resharding, because + the former is a nested sharding of tensor dimension 0 (which is already sharded on mesh + dimension 0), whereas the latter is the first sharding on tensor dimension 0. + """ + transform_infos: List[_TransformInfo] = [] + + device_mesh = src_spec.device_mesh + my_coordinate = device_mesh.get_coordinate() + assert my_coordinate is not None + + # logical shape records the logical tensor shape on the mesh dimension + # this is useful to ensure uneven sharding gets the correct output shape + initial_logical_shape = list(src_spec.shape) + mesh_dims_to_logical_shape = [initial_logical_shape] + + if device_mesh.ndim == 1: + # if device_mesh is 1D, redistribute is a simple direct transformation + transform_infos.append( + _TransformInfo( + mesh_dim=0, + src_dst_placements=(src_spec.placements[0], dst_spec.placements[0]), + logical_shape=initial_logical_shape, + ) + ) + return transform_infos + + # Handle multi-dim device mesh placement redistribution + # First, we need to build the logical shape for each mesh dim + # for correctly all-gathering uneven shards on each mesh dim (with dynamic padding) + for i, (src, dst) in enumerate(zip(src_spec.placements, dst_spec.placements)): + current_logical_shape = mesh_dims_to_logical_shape[i] + if isinstance(src, Shard): + if i < device_mesh.ndim - 1: + # calculate and save the logical shape for this sharding + mesh_dim_size = device_mesh.size(mesh_dim=i) + local_shard_size, _ = src._local_shard_size_on_dim( + current_logical_shape[src.dim], + mesh_dim_size, + my_coordinate[i], + ) + new_logical_shape = list(current_logical_shape) + new_logical_shape[src.dim] = local_shard_size + mesh_dims_to_logical_shape.append(new_logical_shape) + else: + mesh_dims_to_logical_shape.append(current_logical_shape) + + # Next, we need to derive the transform infos from src to dst placements, + # here we use a greedy search with step-by-step state transformations + current_placements = list(src_spec.placements) + target_placements = list(dst_spec.placements) + + if src_spec.num_shards > 1: + # If src_spec has sharding, it could potentially be misaligned with dst_spec; + # a common case of this is nested sharding (i.e. (S(0), S(0)) -> (R, S(0))). + # In those cases, we first traverse from inner placement to outer placement + # to detect misaligned shardings and properly replicate nested sharding first. 
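To make the two-pass planning described above concrete, here is an editorial pure-Python toy of the same algorithm over placement strings ("S0" stands for Shard(0), "R" for Replicate); it sketches the control flow of the loops below and is not library code:

def toy_plan(src, dst):
    """Decompose src -> dst placements into single-mesh-dim steps."""
    cur, steps = list(src), []
    # pass 1: inner -> outer, replicating mis-aligned nested shardings
    for i in reversed(range(len(cur))):
        tgt = dst[i]
        if tgt.startswith("S"):
            def sharded_before(ps, i=i, tgt=tgt):
                # mesh dims before i that shard the same tensor dim
                return [j for j in range(i) if ps[j] == tgt]
            if sharded_before(cur) != sharded_before(dst):
                tgt = "R"  # clear the nested sharding first
        if cur[i] != tgt:
            steps.append((i, cur[i], tgt))
            cur[i] = tgt
    # pass 2: outer -> inner, finishing the remaining transforms
    for i, (c, t) in enumerate(zip(list(cur), dst)):
        if c != t:
            steps.append((i, c, t))
            cur[i] = t
    return steps

# (S(0), S(0)) -> (R, S(0)) needs three single-mesh-dim steps:
assert toy_plan(["S0", "S0"], ["R", "S0"]) == [
    (1, "S0", "R"),
    (0, "S0", "R"),
    (1, "R", "S0"),
]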
+ for mesh_dim in reversed(range(len(current_placements))): + current = current_placements[mesh_dim] + target = target_placements[mesh_dim] + # If target is not Shard, we can directly redistribute since we are traversing from inner + # to outer placements here + if isinstance(target, Shard): + # If target is Shard, check for nested sharding on the tensor dim BEFORE the current mesh_dim + shard_dim = target.dim + current_mesh_sharding, target_mesh_sharding = [], [] + for i, (s, p) in enumerate(zip(current_placements, target_placements)): + if i >= mesh_dim: + break + if s.is_shard(shard_dim): + current_mesh_sharding.append(i) + if p.is_shard(shard_dim): + target_mesh_sharding.append(i) + + if current_mesh_sharding != target_mesh_sharding: + # if current/target_placements have misaligned sharding on the tensor dim BEFORE the current + # mesh_dim, we need to replicate the tensor on the mesh dim first to clear the nested sharding + target = Replicate() + + if current != target: + transform_infos.append( + _TransformInfo( + mesh_dim=mesh_dim, + src_dst_placements=(current, target), + logical_shape=mesh_dims_to_logical_shape[mesh_dim], + ) + ) + current_placements[mesh_dim] = target + + # We always traverse from outer placement to inner placement to collect the remaining + # needed transform infos (i.e. the replication from nested sharding might need to further + # perform resharding to Shard again) + for mesh_dim, (current, target) in enumerate( + zip(current_placements, target_placements) + ): + if current != target: + transform_infos.append( + _TransformInfo( + mesh_dim=mesh_dim, + src_dst_placements=(current, target), + logical_shape=mesh_dims_to_logical_shape[mesh_dim], + ) + ) + current_placements[mesh_dim] = target + + return transform_infos + + +def redistribute_local_tensor( + local_tensor: torch.Tensor, + current_spec: DTensorSpec, + target_spec: DTensorSpec, + *, + async_op: bool = False, + is_backward: bool = False, +) -> torch.Tensor: + """ + This redistributes the local tensor (torch.Tensor) from the current DTensorSpec to + the target DTensorSpec, which involves the necessary collective calls to transform + the local shard of the DTensor from its current spec to the target spec. 
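+ + For example, on a 1-D mesh: Shard(0) -> Replicate() all-gathers the local + shards (``_to_replicate_tensor``), Replicate() -> Shard(0) slices out this + rank's chunk (``_replicate_to_shard``), and Partial() -> Replicate() reduces + the partial values (``_reduce_value``). The public entry point that routes + here is ``DTensor.redistribute``.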
+ """ + + if current_spec.mesh != target_spec.mesh: + # TODO: alltoall/permute reshuffling to change device_mesh if they are not the same + raise NotImplementedError("Cross device mesh comm not supported yet!") + + new_local_tensor = None + device_mesh = current_spec.mesh + + my_coordinate = device_mesh.get_coordinate() + + if my_coordinate is None: + # if rank is not part of mesh, we skip redistribute and simply return local_tensor, + # which should be an empty tensor + return local_tensor + + transform_infos = _gen_transform_infos(current_spec, target_spec) + + for transform_info in transform_infos: + i = transform_info.mesh_dim + current, target = transform_info.src_dst_placements + num_chunks = device_mesh.size(mesh_dim=i) + + if current == target: + # short cut, just use the original local tensor + new_local_tensor = local_tensor + continue + + logger.debug("redistribute from %s to %s on mesh dim %s", current, target, i) + + if target.is_replicate(): + # Case 1: target is Replicate + if current.is_partial(): + partial_spec = cast(Partial, current) + new_local_tensor = partial_spec._reduce_value( + local_tensor, device_mesh, i + ) + elif current.is_shard(): + current_placement = cast(Shard, current) + new_local_tensor = current_placement._to_replicate_tensor( + local_tensor, device_mesh, i, transform_info.logical_shape + ) + else: + raise RuntimeError( + f"redistribute from {current} to {target} not supported yet" + ) + elif target.is_shard(): + # Case 2: target is Shard + target_placement = cast(Shard, target) + target_dim = target_placement.dim + if current.is_partial(): + partial_spec = cast(Partial, current) + new_local_tensor = partial_spec._reduce_shard_value( + local_tensor, device_mesh, i, target_placement + ) + elif current.is_replicate(): + # split the tensor and return the corresponding cloned local shard + new_local_tensor = target_placement._replicate_to_shard( + local_tensor, device_mesh, i, my_coordinate[i] + ) + else: + assert ( + current.is_shard() + ), f"Current placement should be shard but found {current}" + shard_spec = cast(Shard, current) + if shard_spec.dim != target_placement.dim: + new_local_tensor = shard_spec._to_new_shard_dim( + local_tensor, + device_mesh, + i, + transform_info.logical_shape, + target_placement.dim, + ) + elif target.is_partial(): + if current.is_replicate(): + partial_spec = cast(Partial, target) + # skip the replicate to partial transformation when we are in backward pass + # In this case we keep the grad as replicate, this is because we don't + # want to convert the replicated gradients back to partial, although + # that's logically conform with the same layout, converting the gradients + # back to partial is actually useless as you would have to do reduce later + # which would be more expensive than keeping it replicate! For this reason, + # we keep the replicate grad here. 
+ new_local_tensor = ( + partial_spec._partition_value(local_tensor, device_mesh, i) + if not is_backward + else local_tensor + ) + elif current.is_shard(): + if not is_backward: + raise RuntimeError( + f"redistribute from {current} to {target} not supported yet" + ) + # for backward shard -> partial, we just need to convert the shard to replicate + current_placement = cast(Shard, current) + new_local_tensor = current_placement._to_replicate_tensor( + local_tensor, device_mesh, i, transform_info.logical_shape + ) + else: + # partial -> partial no op, should never hit + new_local_tensor = local_tensor + + assert new_local_tensor is not None + local_tensor = new_local_tensor + + assert new_local_tensor is not None, "redistribute failed!" + + if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor): + new_local_tensor = new_local_tensor.wait() + + return new_local_tensor + + +class Redistribute(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + # pyre-fixme[2]: Parameter must be annotated. + ctx, + input: "dtensor.DTensor", + device_mesh: DeviceMesh, + placements: Tuple[Placement, ...], + async_op: bool = False, + ): + current_spec = input._spec + ctx.current_spec = current_spec + ctx.async_op = async_op + + if current_spec.placements != placements: + target_spec = DTensorSpec( + device_mesh, placements, tensor_meta=input._spec.tensor_meta + ) + + local_tensor = input._local_tensor + output = redistribute_local_tensor( + local_tensor, current_spec, target_spec, async_op=async_op + ) + else: + # use the same local tensor if placements are the same. + output = input._local_tensor + target_spec = current_spec + + return dtensor.DTensor( + output, + target_spec, + requires_grad=input.requires_grad, + ) + + @staticmethod + def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override] + previous_spec = ctx.current_spec + current_spec = grad_output._spec + async_op = ctx.async_op + + local_tensor = grad_output._local_tensor + output = redistribute_local_tensor( + local_tensor, + current_spec, + previous_spec, + async_op=async_op, + is_backward=True, + ) + # normalize the target placement to replicate if it is partial + normalized_placements: List[Placement] = [] + for previous_placement in previous_spec.placements: + if previous_placement.is_partial(): + # keep target placement to replicate instead of partial in this case + normalized_placements.append(Replicate()) + else: + normalized_placements.append(previous_placement) + + spec = DTensorSpec( + previous_spec.device_mesh, + tuple(normalized_placements), + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + dtype=grad_output.dtype, + ), + ) + output_dtensor = dtensor.DTensor( + output, + spec, + requires_grad=grad_output.requires_grad, + ) + + return ( + output_dtensor, + None, + None, + None, + ) diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..c277655ff4bf0f0cec77d37956fc5f802125ae9b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py @@ -0,0 +1,497 @@ +# mypy: allow-untyped-defs +import threading +from functools import lru_cache +from itertools import chain +from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union + +import torch +from torch._ops import OpOverload +from torch._subclasses 
import FakeTensorMode +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( + OpInfo, + OpSchema, + OpStrategy, + OutputSharding, + OutputSpecType, + PlacementStrategy, + RuntimeSchemaInfo, + StrategyType, + TupleStrategy, +) +from torch.distributed.tensor._utils import ( + compute_local_shape, + compute_local_stride, + try_find_mesh_from_args, +) + + +aten = torch.ops.aten + + +def _length(obj) -> int: + if obj is None: + return 0 + if not isinstance(obj, Sequence): + return 1 + return len(obj) + + +class LocalLRUCache(threading.local): + def __init__(self, user_function: Callable) -> None: + self.cache = lru_cache(None)(user_function) + + def __call__(self, *args, **kwargs) -> object: + return self.cache(*args, **kwargs) + + def cache_info(self): + return self.cache.cache_info() + + +class ShardingPropagator: + def __init__(self) -> None: + self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {} + self.op_strategy_funcs: Dict[ + OpOverload, + Callable[[DeviceMesh, OpSchema], StrategyType], + ] = {} + # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop + self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {} + self.propagate_op_sharding = LocalLRUCache( + self.propagate_op_sharding_non_cached + ) + # op map to save indices of shape (and stride) args which may need to be modified in sharding prop + self.op_to_shape_and_stride_idx: Dict[ + OpOverload, Union[int, Tuple[int, int]] + ] = { + # new factory ops + aten.new_empty.default: 1, + aten.new_full.default: 1, + aten.new_ones.default: 1, + aten.new_zeros.default: 1, + aten.new_empty_strided.default: (1, 2), + # view ops + aten.expand.default: 1, + aten.reshape.default: 1, + aten.view.default: 1, + aten._unsafe_view.default: 1, + } + + def register_sharding_prop_rule( + self, + op_overload: OpOverload, + rule_func: Callable[[OpSchema], OutputSharding], + schema_info: Optional[RuntimeSchemaInfo] = None, + ): + """ + Register a sharding propagation rule for an operator. + """ + self.op_to_rules[op_overload] = rule_func + if schema_info is not None: + self.op_to_schema_info[op_overload] = schema_info + + def register_op_strategy( + self, + op_overload: OpOverload, + strategy_func: Callable[[DeviceMesh, OpSchema], StrategyType], + schema_info: Optional[RuntimeSchemaInfo] = None, + ): + """ + Register a sharding strategy generator for an operator. 
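+ The ``strategy_func`` receives the op's ``DeviceMesh`` plus an ``OpSchema`` + whose tensor args have been wrapped into ``OpStrategy`` objects, and returns + a ``StrategyType`` (an ``OpStrategy``, or a ``TupleStrategy`` for tensor-list + ops) enumerating the supported placements.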
+ """ + self.op_strategy_funcs[op_overload] = strategy_func + if schema_info is not None: + self.op_to_schema_info[op_overload] = schema_info + + @lru_cache # noqa: B019 + def _propagate_tensor_meta( + self, op_schema: OpSchema + ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: + """ + Propagate the tensor metadata, it could either return a TensorMeta + or a list/tuple of TensorMetas + """ + if op_schema.op == aten.equal.default: + # data dependent ops can't be used for fake propagation + return None + + # NOTE: We must call the tracing in fake tensor mode so that it + # avoids materializing memory + with FakeTensorMode(): + fake_args = op_schema.gen_fake_args() + fake_kwargs = op_schema.gen_fake_kwargs() + fake_out = op_schema.op(*fake_args, **fake_kwargs) + + if isinstance(fake_out, torch.Tensor): + return TensorMeta( + shape=fake_out.shape, stride=fake_out.stride(), dtype=fake_out.dtype + ) + + elif isinstance(fake_out, (tuple, list)): + tensor_meta_list: List[Optional[TensorMeta]] = [] + for fake_out_item in fake_out: + if isinstance(fake_out_item, torch.Tensor): + tensor_meta_list.append( + TensorMeta( + shape=fake_out_item.shape, + stride=fake_out_item.stride(), + dtype=fake_out_item.dtype, + ) + ) + else: + tensor_meta_list.append(None) + return ( + tuple(tensor_meta_list) + if isinstance(fake_out, tuple) + else tensor_meta_list + ) + else: + # if fake is not a tensor or tuple of tensor, return as none + return None + + def _wrap_output_spec_tensor_meta( + self, + op: OpOverload, + output_specs: OutputSpecType, + output_tensor_meta: Union[None, TensorMeta, Sequence[Optional[TensorMeta]]], + ) -> None: + """ + Wrap the output_specs with the tensor metadata from the output. + """ + + if isinstance(output_specs, DTensorSpec): + if not isinstance(output_tensor_meta, TensorMeta): + # Either error due to ShardingPropagator or due to incorrect OutputSpec + if not isinstance(output_tensor_meta, (tuple, list)): + raise ValueError( + "ShardingPropagator error: output does not have an associated TensorMeta" + ) + raise ValueError( + f"For the op {op.name()}, `output_specs` has 1 output which does not equal the " + f"number of op outputs: {len(output_tensor_meta)}." + ) + output_specs.tensor_meta = output_tensor_meta + elif isinstance(output_specs, (tuple, list)): + if not isinstance(output_tensor_meta, (tuple, list)) or len( + output_specs + ) != len(output_tensor_meta): + raise ValueError( + f"For the op {op.name()}, `output_specs` has {len(output_specs)} outputs which does not equal the " + f"number of op outputs {_length(output_tensor_meta)}." + ) + for i, spec in enumerate(output_specs): + if isinstance(spec, DTensorSpec): + output_tensor_meta_i = output_tensor_meta[i] + if not isinstance(output_tensor_meta_i, TensorMeta): + raise ValueError( + f"ShardingPropagator error: output {i} does not have an associated TensorMeta" + ) + spec.tensor_meta = output_tensor_meta_i + + def propagate(self, op_info: OpInfo) -> None: + # We cannot use an lru cache if we know that inputs will have dynamic shapes, + # because SymInts are not hashable. + # This is generally ok because this only happens during tracing in torch.compile, + # and tracing does not need to be as fast as eagermode DTensor usages. 
+ if op_info.schema.has_symints: + output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) + else: + output_sharding = cast( + OutputSharding, self.propagate_op_sharding(op_info.schema) + ) + op_info.output_sharding = output_sharding + + def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding: + """ + Propagate the sharding for an operator given the op_schema. + """ + # special case op, we don't need to propagate for local + # scalar. TODO: figure out a better way to handle this + if op_schema.op is aten._local_scalar_dense.default: + return OutputSharding(None, op_schema) + + out_tensor_meta = self._propagate_tensor_meta(op_schema) + + def spec_to_strategy(spec: object) -> object: + if isinstance(spec, DTensorSpec): + return OpStrategy([PlacementStrategy(spec)]) + elif ( + isinstance(spec, (list, tuple)) + and len(spec) > 0 + and isinstance(spec[0], DTensorSpec) + ): + # tensor list create tuple strategy + tuple_strategy = [spec_to_strategy(s) for s in spec] + tuple_strategy = cast(Sequence[StrategyType], tuple_strategy) + return TupleStrategy( + tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy + ) + else: + return spec + + if op_schema.op in self.op_strategy_funcs: + # generate op strategy for the op. + mesh = try_find_mesh_from_args(op_schema.op, op_schema.args_schema) + # swap the args spec with args strategies + args_op_strategy = [spec_to_strategy(i) for i in op_schema.args_schema] + + kwargs_op_strategy = { + k: spec_to_strategy(v) for k, v in op_schema.kwargs_schema.items() + } + + # construct a new OpSchema on args for strategy based propagation + strategy_schema: OpSchema = OpSchema( + op=op_schema.op, + args_schema=tuple(args_op_strategy), + kwargs_schema=kwargs_op_strategy, + ) + + op_strategy = self.op_strategy_funcs[op_schema.op](mesh, strategy_schema) + + if isinstance(op_strategy, OpStrategy): + # single Op strategy + output_strategy = self._select_strategy(op_strategy) + + # check if we need to redistribute the input + needs_redistribute = False + expected_input_specs: List[DTensorSpec] = [] + + # in case where the op does not specify input_specs and output_specs + # is a DTensorSpec, we use output_specs as the spec for each DTensor + # input arg. + if output_strategy.input_specs is None: + assert isinstance(output_strategy.output_specs, DTensorSpec) + + for idx, input_spec in enumerate(op_schema.args_spec): + desired_spec = ( + output_strategy.output_spec + if output_strategy.input_specs is None + else output_strategy.input_specs[idx] + ) + expected_input_specs.append( + desired_spec.shallow_copy_with_tensor_meta( + input_spec.tensor_meta + ) + ) + if input_spec.placements != desired_spec.placements: + needs_redistribute = True + + suggestion_schema = None + if needs_redistribute: + suggestion_schema = OpSchema( + op_schema.op, tuple(expected_input_specs), {} + ) + suggestion_schema._inplace_rewrap_schema_suggestion(op_schema) + + # shape and stride args need to be modified for + # view ops and new factory ops, potentially + if op_schema.op in self.op_to_shape_and_stride_idx: + assert isinstance(output_strategy.output_spec, DTensorSpec) + # It happens when the output has the same shape as the input + # and the input placements are not all Replicate(). 
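For intuition about the shape rewrite this branch triggers, an editorial sketch of the even-sharding case (the real helper, ``compute_local_shape``, also handles uneven shards):

def even_local_shape(global_shape, shards_per_dim):
    # shards_per_dim has one entry per tensor dim; 1 means "not sharded"
    return tuple(g // n for g, n in zip(global_shape, shards_per_dim))

# aten.view.default(x, [8, 16]) with Shard(0) on a mesh dim of size 4
# must run locally as view(local_x, [2, 16]):
assert even_local_shape((8, 16), (4, 1)) == (2, 16)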
+ if output_strategy.output_spec.is_sharded(): + schema = suggestion_schema or op_schema + assert isinstance(out_tensor_meta, TensorMeta) + suggestion_schema = self._adjust_shape_and_stride_args( + out_tensor_meta, schema, output_strategy.output_spec, mesh + ) + needs_redistribute = True + + # construct output spec for the op + if op_schema.return_type_tuple_tensor_like(): + # for ops that return multiple tensors and the output_specs is not + # a tuple, we use a tuple of that single output spec as the new + # output_specs + output_specs: OutputSpecType = output_strategy.output_specs + if isinstance(output_specs, DTensorSpec): + output_specs = tuple( + [ + # create a new DTensorSpec with the same placement as the + # output_specs in output_strategy + DTensorSpec( + mesh=output_specs.mesh, + placements=output_specs.placements, + tensor_meta=output_specs.tensor_meta, + ) + for _ in range(len(op_schema.op._schema.returns)) + ] + ) + elif op_schema.return_type_tensor(): + output_specs = output_strategy.output_specs + else: + output_specs = None + + output_sharding = OutputSharding( + output_specs, + suggestion_schema, + needs_redistribute=needs_redistribute, + ) + elif isinstance(op_strategy, TupleStrategy): + # tuple strategy output sharding processing + # runtime selected placement strategy for each TupleStrategy input arg + selected_strategies: List[PlacementStrategy] = [] + out_spec_list: List[DTensorSpec] = [] + for strategy in op_strategy.childs: + assert isinstance(strategy, OpStrategy) + selected_strategy = self._select_strategy(strategy) + selected_strategies.append(selected_strategy) + out_spec_list.append(selected_strategy.output_spec) + + needs_redistribute = False + suggestion_args: List[object] = [] + tensor_or_list_tensor_arg_idx = 0 + + for arg in op_schema.args_schema: + if ( + arg + and isinstance(arg, (list, tuple)) + and isinstance(arg[0], DTensorSpec) + ): + expected_input_spec_list: List[DTensorSpec] = [] + for idx, arg_spec in enumerate(arg): + expected_input_spec = selected_strategies[idx].input_spec( + tensor_or_list_tensor_arg_idx + ) + expected_input_spec = ( + expected_input_spec.shallow_copy_with_tensor_meta( + arg_spec.tensor_meta + ) + ) + if arg_spec.placements != expected_input_spec.placements: + needs_redistribute = True + expected_input_spec_list.append(expected_input_spec) + suggestion_args.append( + tuple(expected_input_spec_list) + if isinstance(arg, tuple) + else expected_input_spec_list + ) + tensor_or_list_tensor_arg_idx += 1 + + elif isinstance(arg, DTensorSpec): + expected_input_spec = selected_strategies[0].input_spec( + tensor_or_list_tensor_arg_idx + ) + expected_input_spec = ( + expected_input_spec.shallow_copy_with_tensor_meta( + arg.tensor_meta + ) + ) + if arg.placements != expected_input_spec.placements: + needs_redistribute = True + suggestion_args.append(expected_input_spec) + tensor_or_list_tensor_arg_idx += 1 + else: + suggestion_args.append(arg) + + suggestion_schema = None + if needs_redistribute: + suggestion_schema = OpSchema( + op_schema.op, tuple(suggestion_args), op_schema.kwargs_schema + ) + + output_sharding = OutputSharding( + tuple(out_spec_list) if out_tensor_meta is not None else None, + suggestion_schema, + needs_redistribute=needs_redistribute, + ) + else: + raise ValueError("Unsupported op strategy type") + + # associate the output sharding with the output tensor metadata + self._wrap_output_spec_tensor_meta( + op_schema.op, output_sharding.output_spec, out_tensor_meta + ) + return output_sharding + elif op_schema.op in 
self.op_to_rules: + # propagate the sharding with rule + sharding_prop_func = self.op_to_rules[op_schema.op] + + # step 1. there's sharding propagation rule, run + # sharding propagation to get the output sharding + try: + output_sharding = sharding_prop_func(op_schema) + except NotImplementedError as e: + raise e + except Exception as e: + raise RuntimeError( + f"Sharding propagation failed on op {op_schema}.\n" f"Error: {e}" + ) from e + + # step 2. if can't get output_spec from sharding + # propagation (i.e. no rules apply for input + # placements), we return the output sharding + # with schema suggestions, which can be used to + # decide how to do redistribute on inputs + if output_sharding.output_spec is None: + if output_sharding.redistribute_schema is None: + raise RuntimeError( + f"Sharding propagation failed on op {op_schema}!" + ) + else: + # we do auto redistribute on inputs if necessary + # run sharding propagation again with suggested schema + propagation_res = sharding_prop_func( + output_sharding.redistribute_schema + ) + # we set the output sharding with the new propagation result + # so that dispatching know both output_spec and redistribute_schema + # exist, which indicates a reshard is needed + output_sharding.output_spec = propagation_res.output_spec + output_sharding.needs_redistribute = True + + # associate the output sharding with the output tensor metadata + self._wrap_output_spec_tensor_meta( + op_schema.op, output_sharding.output_spec, out_tensor_meta + ) + + return output_sharding + else: + raise NotImplementedError( + f"Operator {op_schema.op} does not have a sharding strategy registered." + ) + + def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy: + if len(strategy.strategies) == 1: + # short cut with only one possible strategy + return strategy.strategies[0] + + strategy_costs: List[float] = [] + for strtg in strategy.strategies: + assert ( + strtg.redistribute_cost is not None + ), "must set redistribute cost each strategy!" 
+ redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost)) + strategy_costs.append(redistribute_cost) + + # for eager execution, we just select the one with the minimal redistribute cost + return strategy.strategies[strategy_costs.index(min(strategy_costs))] + + def _adjust_shape_and_stride_args( + self, + out_tensor_meta: TensorMeta, + schema: OpSchema, + spec: DTensorSpec, + mesh: DeviceMesh, + ) -> OpSchema: + shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op] + if isinstance(shape_stride_idx, tuple): + shape_idx, stride_idx = shape_stride_idx + else: + shape_idx = shape_stride_idx + stride_idx = None + + expected_input_schema = list(schema.args_schema) + # adjust shape to be the same as that of the _local_tensor + # of the DTensor input arg at index 0, which is inferred + expected_input_schema[shape_idx] = compute_local_shape( + out_tensor_meta.shape, mesh, spec.placements + ) + + # adjust the stride arg for aten.new_empty_strided.default + if stride_idx: + expected_input_schema[stride_idx] = compute_local_stride( + out_tensor_meta.stride, mesh, spec.placements + ) + + return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema) diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_shards_wrapper.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_shards_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..de396473b77c91eef0b97e17b555f192c709e022 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_shards_wrapper.py @@ -0,0 +1,316 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Tuple + +import torch +from torch.distributed.checkpoint.metadata import ( + ChunkStorageMetadata, + MetadataIndex, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import ( + TensorWriteData, + WriteItem, + WriteItemType, +) + + +aten = ( + torch.ops.aten +) # pyre-ignore[5]: Globally accessible variable `aten` has no type specified. + + +class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__ + """ + A wrapper class to hold local shards of a DTensor. + This class is used largely for checkpointing purposes and implicity subtypes + the _Checkpointable protocol. 
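+ + Example (illustrative sizes): wrapping two column shards of a [4, 8] tensor, + ``LocalShardsWrapper([torch.randn(4, 4), torch.randn(4, 4)], [(0, 0), (0, 4)])`` + yields a wrapper whose overall size is [4, 8], with one chunk metadata entry + per shard.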
+ """ + + __slots__ = ["_local_shards", "_storage_meta"] + _local_shards: List[torch.Tensor] + _storage_meta: TensorStorageMetadata + + @staticmethod + def __new__( + cls, local_shards: List[torch.Tensor], local_offsets: List[Tuple[int, ...]] + ) -> "LocalShardsWrapper": + assert len(local_shards) > 0 + assert len(local_shards) == len(local_offsets) + assert all( + tensor.device == local_shards[0].device for tensor in local_shards[1:] + ) + + # we calculate the total tensor size by "concat" on second tensor dimension + cat_tensor_shape = list(local_shards[0].size()) + if len(local_shards) > 1: # column-wise sharding + for shard in local_shards[1:]: + cat_tensor_shape[1] += shard.size()[1] + + wrapper_properties = TensorProperties.create_from_tensor(local_shards[0]) + wrapper_shape = torch.Size(cat_tensor_shape) + chunks_meta = [ + ChunkStorageMetadata( + offsets=torch.Size(offset), + sizes=shard.size(), + ) + for shard, offset in zip(local_shards, local_offsets) + ] + + r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + torch.Size(cat_tensor_shape), + ) + r._local_shards = local_shards + r._storage_meta = TensorStorageMetadata( + properties=wrapper_properties, + size=wrapper_shape, + chunks=chunks_meta, + ) + + return r + + # necessary for ops dispatching from this subclass to its local shards + @classmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + dispatcher = { + torch.ops._c10d_functional.all_gather_into_tensor.default: cls.handle_all_gather_into_tensor, + torch.ops._c10d_functional.wait_tensor.default: cls.handle_wait_tensor, + aten._to_copy.default: cls.handle_to_copy, + aten.view.default: cls.handle_view, + aten.equal.default: cls.handle_equal, + aten.detach.default: cls.handle_detach, + aten.clone.default: cls.handle_clone, + } + + if func in dispatcher: + return dispatcher[func]( + args, kwargs + ) # pyre-ignore [29] - `Variable[_VT]` is not a function. + else: + raise NotImplementedError( + f"{func} is not supported for LocalShardsWrapper!" + ) + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_all_gather_into_tensor(args, kwargs): + dim = args[0].local_sizes()[0][1] + cat_tensor = torch.cat( + [t.view(-1) for t in args[0].local_shards()], dim=0 + ).view(-1, dim) + return torch.ops._c10d_functional.all_gather_into_tensor.default( + cat_tensor, *args[1:], **kwargs + ) + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_wait_tensor(args, kwargs): + return torch.ops._c10d_functional.wait_tensor(args[0]) + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_to_copy(args, kwargs): + res_shards_list = [ + aten._to_copy.default(shard, *args[1:], **kwargs) + for shard in args[0].local_shards() + ] + return LocalShardsWrapper(res_shards_list, args[0].local_offsets()) + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_view(args, kwargs): + # TODO, do we need to change the shape of associated offsets? 
+ res_shards_list = [ + aten.view.default(shard, args[1], **kwargs) + for shard in args[0].local_shards() + ] + return LocalShardsWrapper(res_shards_list, args[0].local_offsets()) + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_equal(args, kwargs): + """ + LocalShardsWrapper equal impl also checks for equality of storage metadata + and the order of shards + """ + a, b = args[0], args[1] + if len(a.local_shards()) != len(b.local_shards()): + return False + if not all( + aten.equal.default(x, y) for x, y in zip(a.local_shards(), b.local_shards()) + ): + return False + if not a.storage_metadata() == b.storage_metadata(): + return False + return True + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_detach(args, kwargs): + self_ls = args[0] + deatched_local_shards = [ + aten.detach.default(shard) for shard in self_ls.local_shards() + ] + self_ls._local_shards = deatched_local_shards + self_ls._storage_meta.properties.requires_grad = False + return self_ls + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_clone(args, kwargs): + self_ls = args[0] + desired_memory_format = kwargs.get("memory_format", None) + if desired_memory_format and desired_memory_format != torch.preserve_format: + raise NotImplementedError( + f"{desired_memory_format} is not supported for LocalShardsWrapper!" + ) + cloned_local_shards = [ + shard.clone(memory_format=desired_memory_format) + for shard in self_ls._local_shards + ] + return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets()) + + @property + def device(self) -> torch._C.device: # type: ignore[override] + return self._local_shards[0].device + + @property + def is_meta(self) -> bool: # type: ignore[override] + return self._local_shards[0].is_meta + + # pyre-ignore[14] + def is_pinned(self) -> bool: # type: ignore[override] + return self._storage_meta.properties.pin_memory + + # pyre-ignore[14] + def requires_grad_(self, requires_grad: bool = True) -> "LocalShardsWrapper": + self._storage_meta.properties.requires_grad = requires_grad + [shard.requires_grad_(requires_grad) for shard in self._local_shards] + return self + + def local_shards(self) -> List[torch.Tensor]: + """ + Returns a list of :class:`torch.Tensor' corresponding to the + local shards for this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return self._local_shards + + def local_sizes(self) -> List[torch.Size]: + """ + Returns a list of :class:`torch.Size' corresponding to the + local sizes for the shards on this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return [chunk.sizes for chunk in self._storage_meta.chunks] + + def local_offsets(self) -> List[torch.Size]: + """ + Returns a list of :class:`torch.Size' corresponding to the + local offsets for the shards on this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. 
+ """ + return [chunk.offsets for chunk in self._storage_meta.chunks] + + @property + def local_chunks(self) -> List[ChunkStorageMetadata]: + """ + Returns a :class:`List[ChunkStorageMetadata]` object corresponding to the + metadata for each tensor shard + """ + return self._storage_meta.chunks + + def storage_metadata(self) -> TensorStorageMetadata: + """ + Returns a :class:`TensorStorageMetadata` object corresponding to the + metadata for the local tensor on current rank + """ + return self._storage_meta + + def __create_write_items__( + self, fqn: str, object: Any + ) -> List[WriteItem]: # pyre-ignore[2] + """ + For compatibility with DCP, we support creation of WriteItems + such that they can be saved properly. + """ + return [ + WriteItem( + index=MetadataIndex(fqn, chunks.offsets), + type=WriteItemType.SHARD, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata( + offsets=chunks.offsets, + sizes=chunks.sizes, + ), + properties=self._storage_meta.properties, + size=object.size(), + ), + ) + for tensor, chunks in zip(self.local_shards(), self.local_chunks) + ] + + def __create_chunk_list__(self) -> List[ChunkStorageMetadata]: + """ + For compatibility with DCP, we support creation of chunk lists + such that they can be saved properly. + """ + return self._storage_meta.chunks + + def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor: + """ + For compatibility with DCP, we support finding shard based on index + Return a 'torch.Tensor' shard based on 'MetadataIndex'. + """ + # Fast lookup path + if index.index is not None: + if ( + len(self._local_shards) > index.index + and self._storage_meta.chunks[index.index].offsets == index.offset + ): + return self._local_shards[index.index] + + if index.offset is not None: + for shard, chunk in zip(self._local_shards, self._storage_meta.chunks): + if chunk.offsets == index.offset: + return shard + + raise ValueError( + f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'" + ) + + def _get_tensor_size_bytes(self) -> int: + object_size = 0 + for shard in self.local_shards(): + object_size += shard.nelement() * shard.element_size() + return object_size + + # pyre-fixme[3]: Return type must be annotated. + def __hash__(self): + return id(self) + + # pyre-fixme[14]: `__repr__` overrides method defined in `torch._tensor.Tensor` inconsistently. + # pyre-fixme[3]: Return type must be annotated. + def __repr__(self): + return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" + + def __str__(self) -> str: + return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_tp_conv.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_tp_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..ac11ef2162cbb7f80eecc3914d46ce3d3d5c1852 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_tp_conv.py @@ -0,0 +1,279 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. 
diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_tp_conv.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_tp_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac11ef2162cbb7f80eecc3914d46ce3d3d5c1852
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_tp_conv.py
@@ -0,0 +1,279 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# implement convolution related ops for distributed tensor
+from typing import cast, Dict, List, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.distributed.tensor._api as dtensor
+
+
+aten = torch.ops.aten
+
+
+def _requires_data_exchange(padding):
+    # TODO: whether data exchange is required is currently determined by padding
+    return padding[1] != 0
+
+
+def _is_supported(input_size, kernel_size, stride, padding, dilation):
+    if dilation[1] != 1:
+        raise RuntimeError("Dilation must be 1 for tensor parallel convolution.")
+    if padding[1] != 0:
+        if stride[1] != 1:
+            raise RuntimeError(
+                "Stride must be 1 when there is padding for tensor parallel convolution."
+            )
+        if kernel_size[3] // 2 > input_size[3]:
+            raise RuntimeError(
+                "kernel_size[3] // 2 should be less than or equal to input_size[3] for tensor parallel convolution."
+            )
+    else:
+        if not (input_size[3] % stride[1] == 0 and stride[1] == kernel_size[3]):
+            raise RuntimeError(
+                "It requires that input_size[3] is divisible by stride[1] and stride[1] equals kernel_size[3] "
+                "when there is no padding for tensor parallel convolution."
+            )
+    return True
+
+
+def _ring_send_recv_construct(in_tensor, d1, d2, left, right, rank, size):
+    # dist comms and reconstruct local input tensor
+    send_to_right = in_tensor[:, :, :, -d1:].contiguous()
+    send_to_left = in_tensor[:, :, :, :d2].contiguous()
+    recv_from_right = torch.zeros_like(send_to_left)
+    recv_from_left = torch.zeros_like(send_to_right)
+
+    send_op_right = dist.P2POp(dist.isend, send_to_right, right)
+    send_op_left = dist.P2POp(dist.isend, send_to_left, left)
+    recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right)
+    recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left)
+
+    reqs = dist.batch_isend_irecv(
+        [send_op_right, send_op_left, recv_op_left, recv_op_right]
+    )
+    for req in reqs:
+        req.wait()
+
+    if rank == 0:
+        in_tensor = torch.cat([in_tensor, recv_from_right], dim=-1)
+    elif rank == size - 1:
+        in_tensor = torch.cat([recv_from_left, in_tensor], dim=-1)
+    else:
+        in_tensor = torch.cat([recv_from_left, in_tensor, recv_from_right], dim=-1)
+
+    return in_tensor
+
+
+def _ring_send_recv_aggregate(grad_in_tensor, d1, d2, left, right, rank, size):
+    # dist comms and aggregate gradients for edge pixels
+    send_to_right = grad_in_tensor[:, :, :, -d2:].contiguous()
+    send_to_left = grad_in_tensor[:, :, :, :d1].contiguous()
+    recv_from_right = torch.zeros_like(send_to_left)
+    recv_from_left = torch.zeros_like(send_to_right)
+
+    send_op_right = dist.P2POp(dist.isend, send_to_right, right)
+    send_op_left = dist.P2POp(dist.isend, send_to_left, left)
+    recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right)
+    recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left)
+
+    reqs = dist.batch_isend_irecv(
+        [send_op_right, send_op_left, recv_op_left, recv_op_right]
+    )
+    for req in reqs:
+        req.wait()
+
+    if rank == 0:
+        grad_in_tensor = grad_in_tensor[:, :, :, :-d2]
+        grad_in_tensor[:, :, :, -d1:] = torch.add(
+            grad_in_tensor[:, :, :, -d1:], recv_from_right
+        )
+    elif rank == size - 1:
+        grad_in_tensor = grad_in_tensor[:, :, :, d1:]
+        grad_in_tensor[:, :, :, :d2] = torch.add(
+            grad_in_tensor[:, :, :, :d2], recv_from_left
+        )
+    else:
+        grad_in_tensor = grad_in_tensor[:, :, :, d1:-d2]
+        grad_in_tensor[:, :, :, -d1:] = torch.add(
+            grad_in_tensor[:, :, :, -d1:], recv_from_right
+        )
+        grad_in_tensor[:, :, :, :d2] = torch.add(
+            grad_in_tensor[:, :, :, :d2], recv_from_left
+        )
+
+    # The caller assigns the result of this function, so the aggregated tensor
+    # must be returned (the return statement was missing here).
+    return grad_in_tensor
+
+
+def tp_convolution(
+    op_call: torch._ops.OpOverload,
local_tensor_args: Tuple[object, ...], + local_tensor_kwargs: Dict[str, object], +) -> object: + assert op_call == aten.convolution.default + assert len(local_tensor_args) == 9 + + rank = dist.get_rank() + size = dist.get_world_size() + in_tensor = cast(torch.Tensor, local_tensor_args[0]) + weight = cast(torch.Tensor, local_tensor_args[1]) + stride, padding, dilation = local_tensor_args[3:6] + + assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation) + assert isinstance(padding, List) + + if not _requires_data_exchange(padding): + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + return local_results + else: + # step 0 compute the overlap pixels of the input tensor + d = weight.shape[3] - 1 + d1 = d // 2 + d2 = d - d1 + assert d1 + d2 == d + right = (rank + 1) % size + left = (rank - 1 + size) % size + + # step1 reconstruct local input tensor + in_tensor = _ring_send_recv_construct( + in_tensor, d1, d2, left, right, rank, size + ) + + # step2 feed local input tensor to op_call + local_tensor_args_list = list(local_tensor_args) + local_tensor_args_list[0] = in_tensor + local_tensor_args = cast(Tuple[object, ...], local_tensor_args_list) + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + + # step3 remove extra outputs from the results + padding_w = padding[1] + w = local_results.size(3) + if rank == 0: + local_results = local_results[:, :, :, : w - padding_w] + elif rank == size - 1: + local_results = local_results[:, :, :, padding_w:] + else: + local_results = local_results[:, :, :, padding_w : w - padding_w] + + return local_results + + +def tp_convolution_backward( + op_call: torch._ops.OpOverload, + local_tensor_args: Tuple[object, ...], + local_tensor_kwargs: Dict[str, object], +) -> object: + assert op_call == aten.convolution_backward.default + assert len(local_tensor_args) == 11 + + rank = dist.get_rank() + size = dist.get_world_size() + grad_out_tensor = cast(torch.Tensor, local_tensor_args[0]) + in_tensor = cast(torch.Tensor, local_tensor_args[1]) + weight = cast(torch.Tensor, local_tensor_args[2]) + stride, padding, dilation = local_tensor_args[4:7] + + assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation) + assert isinstance(padding, List) + + if not _requires_data_exchange(padding): + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + return local_results + else: + # step 0 compute the overlap pixels of the input tensor + d = weight.shape[3] - 1 + d1 = d // 2 + d2 = d - d1 + assert d1 + d2 == d + right = (rank + 1) % size + left = (rank - 1 + size) % size + + # step1 reconstruct local input tensor + in_tensor = _ring_send_recv_construct( + in_tensor, d1, d2, left, right, rank, size + ) + + # step2 reconstruct local gradient output tensor + N, C_out, H_out, _ = grad_out_tensor.shape + padding_w = padding[1] + if rank == 0: + grad_out_tensor = torch.nn.functional.pad( + grad_out_tensor, (0, padding_w), "constant", 0 + ) + elif rank == size - 1: + grad_out_tensor = torch.nn.functional.pad( + grad_out_tensor, (padding_w, 0), "constant", 0 + ) + else: + grad_out_tensor = torch.nn.functional.pad( + grad_out_tensor, (padding_w, padding_w), "constant", 0 + ) + + # step3 feed local input tensor to op_call + local_tensor_args_list = list(local_tensor_args) + local_tensor_args_list[0] = grad_out_tensor + local_tensor_args_list[1] = in_tensor + local_tensor_args = cast(Tuple[object, ...], local_tensor_args_list) + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + + # 
step4 aggregate gradients for edge pixels + grad_in_tensor = local_results[0] + grad_in_tensor = _ring_send_recv_aggregate( + grad_in_tensor, d1, d2, left, right, rank, size + ) + + local_results = list(local_results) + local_results[0] = grad_in_tensor + local_results = cast(Tuple[object, ...], local_results) + + return local_results + + +def convolution_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> object: + # extract local tensor and sharding infos to a OpInfo + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + + # sharding propagation + dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + + # local propagation + local_results = tp_convolution( + op_call, tuple(op_info.local_args), op_info.local_kwargs + ) + + return dtensor.DTensor._op_dispatcher.wrap( + local_results, output_sharding.output_spec + ) + + +def convolution_backward_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> object: + # Redistribute grad_output tensor to the same placement as input tensor + args = list(args) + assert isinstance(args[0], dtensor.DTensor) and isinstance(args[1], dtensor.DTensor) + args[0] = args[0].redistribute(args[1].device_mesh, args[1].placements) + args = tuple(args) + + # extract local tensor and sharding infos to a OpInfo + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + + # sharding propagation + dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + + # local propagation + local_results = tp_convolution_backward( + op_call, tuple(op_info.local_args), op_info.local_kwargs + ) + + return dtensor.DTensor._op_dispatcher.wrap( + local_results, output_sharding.output_spec + ) diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_utils.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..55a892063db9260e0162f4f07ed9fb4bb684cc8c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/_utils.py @@ -0,0 +1,316 @@ +from typing import cast, List, Sequence, Tuple + +import torch +import torch.distributed.tensor._api as dtensor +from torch._prims_common import ShapeType +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import ( + _StridedShard, + Partial, + Placement, + Replicate, + Shard, +) + + +# TODO: audit existing code base to see if we can safely remove this API. +def compute_local_shape( + global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] +) -> Tuple[int, ...]: + """ + Compute the shape of a local shard of the given DTensor on its current + coordinate of the mesh. 
+ """ + my_coordinate = mesh.get_coordinate() + + if my_coordinate is None: + # if rank not in the mesh, return empty shape + return (0,) + else: + local_shape = list(global_shape) # start with global shape + ndim = len(global_shape) + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + if isinstance(placement, Shard): + shard_dim = placement.dim + assert ( + shard_dim < ndim + ), f"Sharding dim {shard_dim} greater than tensor ndim {ndim}" + local_shard_size, _ = placement._local_shard_size_on_dim( + local_shape[shard_dim], mesh_dim_size, my_coordinate[idx] + ) + assert isinstance(local_shard_size, int) + local_shape[shard_dim] = local_shard_size + + return tuple(local_shape) + + +def compute_local_shape_and_global_offset( + global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] +) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """ + Compute the local tensor shape and the global offsets into the original tensor + of a DTensor on its current global rank. This is useful for checkpointing purpose. + + Example (2 host with 4GPUs each): + # Below is a DeviceMesh with mesh_shape of (2, 4) + mesh = DeviceMesh(device_type="cuda", + mesh=[ + [0, 1, 2, 3], + [4, 5, 6, 7] + ], + ) + + Let's say we distribute a global_tensor of shape (8,4) over the above DeviceMesh + with a placements of [Shard(0), Shard(0)]. + The local shape and global offset will be as follows: + rank0 -- local_shape:[1, 4], global_offset:[0, 0] + rank1 -- local_shape:[1, 4], global_offset:[1, 0] + rank2 -- local_shape:[1, 4], global_offset:[2, 0] + rank5 -- local_shape:[1, 4], global_offset:[5, 0] + rank3 -- local_shape:[1, 4], global_offset:[3, 0] + rank4 -- local_shape:[1, 4], global_offset:[4, 0] + rank6 -- local_shape:[1, 4], global_offset:[6, 0] + rank7 -- local_shape:[1, 4], global_offset:[7, 0] + + Let's say we distribute a global_tensor of shape (2) over the above DeviceMesh with + a placements of [Shard(0)]. We will not have non-empty local tensor for all the ranks. 
+    The local shape and global offset will be as follows:
+    rank0 -- local_shape:[1,], global_offset:[0,]
+    rank1 -- local_shape:[1,], global_offset:[1,]
+    rank2 -- local_shape:[0,], global_offset:[2,]
+    rank3 -- local_shape:[0,], global_offset:[2,]
+    rank4 -- local_shape:[0,], global_offset:[2,]
+    rank5 -- local_shape:[0,], global_offset:[2,]
+    rank6 -- local_shape:[0,], global_offset:[2,]
+    rank7 -- local_shape:[0,], global_offset:[2,]
+    """
+    my_coordinate = mesh.get_coordinate()
+
+    if my_coordinate is None:
+        # if rank not in the mesh, return empty offset
+        return ((), ())
+    else:
+        local_shape = list(global_shape)
+        global_offset = [0] * len(global_shape)
+        shard_idx_stride_by_mesh_dim = [
+            [0] * mesh.ndim for _ in range(len(global_shape))
+        ]  # index by (shard_dim, mesh_dim)
+        num_shards_by_tensor_dim = [1] * len(global_shape)
+
+        for idx, placement in enumerate(placements):
+            mesh_dim_size = mesh.size(idx)
+            if isinstance(placement, Shard):
+                shard_dim = placement.dim
+                local_offset = [0] * len(global_shape)
+                assert shard_dim < len(
+                    local_shape
+                ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
+                shard_size, shard_offset = placement._local_shard_size_on_dim(
+                    local_shape[shard_dim],
+                    mesh_dim_size,
+                    my_coordinate[idx],
+                    return_offset=True,
+                )
+
+                local_shape[shard_dim] = shard_size
+                local_offset[shard_dim] = shard_offset
+
+                # On a given dimension, if local_offset[shard_dim] is smaller than global_offset[shard_dim],
+                # it means that this dimension has already been sharded by a previous placement.
+                # Therefore, we cannot simply replace global_offset[shard_dim] with local_offset[shard_dim].
+                # Instead, for the given shard_dim, we need to add local_offset[shard_dim] to the existing
+                # global_offset[shard_dim].
+                if global_offset[shard_dim] <= local_offset[shard_dim]:
+                    global_offset[shard_dim] = local_offset[shard_dim]
+                else:
+                    global_offset[shard_dim] += local_offset[shard_dim]
+
+                num_shards_by_tensor_dim[shard_dim] *= mesh_dim_size
+
+        # NOTE: the offset computation relies on the local shard index, and it has no
+        # problem when strided sharding is not present. To compute it correctly, we assume
+        # that the ``_StridedShard.split_factor`` field encodes how many partitions
+        # each local tensor will be further split into when sharding on higher mesh
+        # dimensions. However, this number is only correct if the DTensor is not
+        # sharded after the strided sharding completes. For example,
+        # [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] is the placements
+        # where the DTensor's dim-0 is first sharded on device mesh dim-0, then on
+        # device mesh dim-2, and last on mesh dim-1. We define the
+        # "_StridedShard(0, split_factor=2), Shard(0)" part as the strided sharding
+        # part because strided sharding happens on mesh dim-1 and it was caused by
+        # the fact that sharding on dim-2 occurred ahead. In this case, there's no
+        # further sharding after this strided sharding part and ``split_factor``
+        # correctly encodes the number. Another example is
+        # [_StridedShard(0, split_factor=2), Shard(0), Shard(0)] where the DTensor's
+        # dim-0 is first sharded on mesh dim-1, then on mesh dim-0, and last on mesh
+        # dim-2. This violates our assumption that no further sharding shall occur
+        # after the strided sharding part, and ``split_factor`` won't correctly
+        # encode the number of further splits. So far, the only case where a
+        # _StridedShard placement would appear is FSDP2 + TP on a 2D mesh, and the
+        # above case could only happen on a mesh of 3 or more dimensions.
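+        # Editor's worked example (illustrative, not upstream): using the
+        # right-to-left sharding described in _StridedShard's docstring, take an
+        # [8, 8] tensor on mesh [[0, 1], [2, 3]] with placements
+        # [_StridedShard(0, split_factor=2), Shard(0)]. The shard indices per
+        # rank are 0, 2, 1, 3, so this function yields:
+        #   rank0 -- local_shape:(2, 8), global_offset:(0, 0)
+        #   rank1 -- local_shape:(2, 8), global_offset:(4, 0)
+        #   rank2 -- local_shape:(2, 8), global_offset:(2, 0)
+        #   rank3 -- local_shape:(2, 8), global_offset:(6, 0)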
+        # TODO: change this function to correctly address this.
+        # TODO: this logic can be applied to contiguous sharding as well
+        strided_sharding = any(isinstance(p, _StridedShard) for p in placements)
+        if strided_sharding:
+            strided_part_seen = [False] * len(global_shape)
+            strided_part_end = [False] * len(global_shape)
+            for idx, placement in enumerate(placements):
+                mesh_dim_size = mesh.size(idx)
+                if isinstance(placement, Shard):
+                    shard_dim = placement.dim
+
+                    if strided_part_end[shard_dim]:
+                        raise NotImplementedError(
+                            f"Strided sharding does not allow Shard() to appear after "
+                            f"the strided part has ended. {placement} at idx {idx} in "
+                            f"{placements} violates this assumption."
+                        )
+
+                    if strided_part_seen[shard_dim]:
+                        strided_part_end[shard_dim] = True
+
+                    if isinstance(placement, _StridedShard):
+                        strided_part_seen[shard_dim] = True
+                        shard_idx_stride_by_mesh_dim[shard_dim][
+                            idx
+                        ] = num_shards_by_tensor_dim[shard_dim] // (
+                            placement.split_factor * mesh_dim_size
+                        )
+                    else:
+                        num_shards_by_tensor_dim[shard_dim] //= mesh_dim_size
+                        shard_idx_stride_by_mesh_dim[shard_dim][
+                            idx
+                        ] = num_shards_by_tensor_dim[shard_dim]
+
+            shard_idx = [
+                sum([x * y for x, y in zip(shard_idx_stride, my_coordinate)])
+                for shard_dim, shard_idx_stride in enumerate(
+                    shard_idx_stride_by_mesh_dim
+                )
+            ]
+
+            global_offset = [x * y for x, y in zip(local_shape, shard_idx)]
+
+        return tuple(local_shape), tuple(global_offset)
+
+
+def compute_global_tensor_info(
+    tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement]
+) -> Tuple[List[int], List[int]]:
+    """
+    Compute the global size and stride of a DTensor from the given local tensor.
+    The local size is multiplied by `world_size` per sharding dim.
+    The local stride is multiplied by `world_size` per sharding dim, for any
+    dimension that is outside the sharding dim.
+
+    For example, if we have a local tensor with size (4, 8, 2) and stride (16, 1, 8).
+    If the DTensor placements are [Shard(2)] and world_size is 2;
+    then the global size is (4, 8, 4) and stride is (16 * 2, 1, 8).
+
+    Args:
+        tensor (:class:`torch.Tensor`):
+            Local tensor which DTensor will be constructed from.
+        mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for the DTensor.
+        placements (Sequence[:class:`Placement`]):
+            The attribute of the DTensor that describes its layout
+            on the mesh topology.
+
+    Return:
+        tensor_shape: A List of int which specifies the size of the DTensor built
+            on top of the local tensor.
+        tensor_stride: A List of int which specifies the stride of the DTensor.
+    """
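+    # Editor's note (illustrative, not upstream): tracing the docstring example,
+    # local size (4, 8, 2) with stride (16, 1, 8) and placements [Shard(2)] on a
+    # world of 2. Dim 2 is sharded, so size[2] doubles to 4; every other dim
+    # whose stride is >= stride[2] == 8 is rescaled, so the stride goes from
+    # (16, 1, 8) to (32, 1, 8).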
+    tensor_shape = list(tensor.size())
+    tensor_stride = list(tensor.stride())
+    for idx, placement in enumerate(placements):
+        mesh_dim_size = mesh.size(idx)
+        if placement.is_shard():
+            shard_placement = cast(Shard, placement)
+            if shard_placement.dim < 0:
+                raise AssertionError(
+                    "Shard placements should have negative dims normalized in "
+                    f"the user-facing APIs: {shard_placement}"
+                )
+            shard_dim = shard_placement.dim
+
+            assert (
+                shard_dim < tensor.ndim
+            ), f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."
+
+            local_dim_size = tensor_shape[shard_dim]
+            tensor_shape[shard_dim] = local_dim_size * mesh_dim_size
+
+            # recover the tensor stride by modifying the strides that are larger
+            # than the current stride on the shard_dim
+            for i in range(len(tensor_stride)):
+                if i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim]:
+                    # rescale the stride by the shard size
+                    tensor_stride[i] = tensor_stride[i] * mesh_dim_size
+        elif not isinstance(placement, (Replicate, Partial)):
+            raise RuntimeError(f"placement type {type(placement)} not supported!")
+    return tensor_shape, tensor_stride
+
+
+def try_find_mesh_from_args(
+    op_call: torch._ops.OpOverload, args: Sequence[object]
+) -> DeviceMesh:
+    """
+    Find the device mesh object from args.
+    It raises a ValueError if no mesh is found.
+    NOTE: we can optimize this search if needed
+    """
+    for arg in args:
+        if isinstance(arg, (dtensor.DTensor, DTensorSpec)):
+            return arg.device_mesh
+        elif (
+            isinstance(arg, (list, tuple))
+            and len(arg) > 0
+            and isinstance(arg[0], (dtensor.DTensor, DTensorSpec))
+        ):
+            return arg[0].device_mesh
+
+    raise ValueError(f"Cannot find device mesh from args for op: {op_call}.")
+
+
+def compute_local_stride(
+    global_stride: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
+) -> Tuple[int, ...]:
+    """
+    Compute the stride of a local tensor shard, given the global stride of the DTensor.
+    NOTE: Currently this function assumes the DTensor is evenly shardable.
+    """
+    stride_divisors = [1] * len(global_stride)
+    for mesh_idx, p in enumerate(placements):
+        if p.is_shard():
+            i = cast(Shard, p).dim
+            # tensor dimension i is sharded on mesh dimension mesh_idx,
+            # so we need to divide all the strides larger than stride[i]
+            # (by the submesh size)
+            for j in range(len(global_stride)):
+                if global_stride[j] > global_stride[i]:
+                    stride_divisors[j] *= mesh.size(mesh_idx)
+    return tuple(
+        global_stride[i] // stride_divisors[i] for i in range(len(global_stride))
+    )
+
+
+def normalize_to_torch_size(size) -> torch.Size:  # type: ignore[no-untyped-def]
+    """
+    Unify variable types of the size argument to torch.Size.
+    Acceptable types include:
+        int, Sequence[int], Tuple[int], Tuple[Sequence[int]],
+        or torch.Size
+    """
+    if isinstance(size, torch.Size):
+        return size
+
+    if isinstance(size, int):
+        torch_size = [size]
+    elif len(size) == 1 and isinstance(size[0], Sequence):
+        torch_size = list(size[0])
+    else:
+        torch_size = list(size)
+    return torch.Size(torch_size)
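+
+
+# Editor's sketch (not part of the upstream file): normalize_to_torch_size is
+# pure and easy to exercise directly.
+#
+#   normalize_to_torch_size(4)                 # torch.Size([4])
+#   normalize_to_torch_size([2, 3])            # torch.Size([2, 3])
+#   normalize_to_torch_size(((2, 3),))         # torch.Size([2, 3]), single nested Sequence
+#   normalize_to_torch_size(torch.Size([5]))   # returned as-is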
diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/device_mesh.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/device_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca59ded5eb52bc0a3878e76077ad2879df4bf499
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/device_mesh.py
@@ -0,0 +1,9 @@
+from torch.distributed.device_mesh import (  # noqa: F401
+    _get_device_handle,
+    _mesh_resources,
+    DeviceMesh,
+    init_device_mesh,
+)
+
+
+__all__ = ["init_device_mesh", "DeviceMesh"]
diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/ddp.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/ddp.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccda5bc1b47fb4ac8721011b6d0c7486bfa14f9b
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/ddp.py
@@ -0,0 +1,104 @@
+# mypy: allow-untyped-defs
+from typing import Any, List, Optional, Set, Tuple
+
+import torch.nn as nn
+from torch.distributed.tensor.parallel._data_parallel_utils import (
+    _flatten_tensor,
+    _unflatten_tensor,
+)
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+def _get_submodule_n_params(module: nn.Module, path: str):
+    """
+    Get the submodule and the direct path of the parameter from the module
+    """
+    if "." in path:
+        path_list = path.split(".")
+        parent_module_path = ".".join(path_list[:-1])
+        module = module.get_submodule(parent_module_path)
+        path = path_list[-1]
+    return module, path
+
+
+def _update_module_param(param_list: List[Tuple[nn.Module, str, nn.Parameter]]):
+    """
+    Update parameters within the module
+    """
+    for item in param_list:
+        parent_module, module_path, t = item
+        assert hasattr(parent_module, module_path)
+        delattr(parent_module, module_path)
+        setattr(parent_module, module_path, t)
+
+
+def _reconstruct_dtensor(module: nn.Module, _input: Any):
+    """
+    Reconstruct DTensor parameters from local tensors
+    """
+    param_list = []
+    # TODO: add perf optimizations to this iteration
+    for name, t in module.named_parameters():
+        if hasattr(t, "_st_info"):
+            dtensor = _unflatten_tensor(t, t._st_info)
+            param_list.append((*_get_submodule_n_params(module, name), dtensor))
+    _update_module_param(param_list)  # type: ignore[arg-type]
+
+
+def _localize_dtensor(
+    module: nn.Module, *_: Any, ignored_params: Optional[Set[nn.Parameter]] = None
+):
+    """
+    Convert DTensor parameters to local tensors
+    """
+    if ignored_params is None:
+        ignored_params = set()
+    param_list = []
+    for name, param in module.named_parameters():
+        if param in ignored_params:
+            continue
+        t, sharding_info = _flatten_tensor(param)
+        if sharding_info is not None:
+            t = nn.Parameter(t)
+            t._st_info = sharding_info  # type: ignore[attr-defined]
+            param_list.append((*_get_submodule_n_params(module, name), t))
+    _update_module_param(param_list)  # type: ignore[arg-type]
+
+
+def _pre_dp_module_transform(module: nn.Module):
+    """
+    Enable the composability between Tensor Parallelism (TP) and Data
+    Parallelism (DP) in PyTorch when using DDP. We need to convert Parameters which
+    are DTensors to local tensors before wrapping with the data parallelism API.
+    We then register two hooks, one to convert local tensors back to DTensors
+    pre-forward and one to convert DTensors back to local tensors after forward. By
+    integrating this way, we avoid any special handling of DTensor parameters by DDP
+    and get DTensor's gradients propagated back to DP, e.g. gradient buckets of DDP.
+
+    For now, this API only works with ``DistributedDataParallel``. It will later support
+    other DP methods such as FSDP.
+
+    Args:
+        module (:class:`nn.Module`):
+            Module which has had TP applied to it.
+
+    Example::
+        >>> # xdoctest: +SKIP("distributed")
+        >>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
+        >>> from torch.nn.parallel import DistributedDataParallel as DDP
+        >>> from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
+        >>>
+        >>> # Define the module.
+        >>> m = module(...)
+        >>> parallelize_module(m, PairwiseParallel())
+        >>> # The transform mutates the module in place and returns None,
+        >>> # so it must not be assigned back to ``m``.
+        >>> _pre_dp_module_transform(m)
+        >>> m = DDP(m)
+        >>>
+    """
+
+    _localize_dtensor(module, None, None)
+    # TODO: add test cases and ensure that it works for nested modules
+    module.register_forward_pre_hook(_reconstruct_dtensor)
+    module.register_forward_hook(_localize_dtensor)
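+
+
+# Editor's sketch (not part of the upstream file): the flatten/unflatten round
+# trip that the two hooks above rely on, with names taken from the call sites
+# in this file.
+#
+#   t, st_info = _flatten_tensor(param)    # DTensor -> plain local tensor + info
+#   if st_info is not None:
+#       t = nn.Parameter(t)
+#       t._st_info = st_info               # stashed for the pre-forward hook
+#   dt = _unflatten_tensor(t, t._st_info)  # local tensor -> DTensor again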
diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/input_reshard.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/input_reshard.py
new file mode 100644
index 0000000000000000000000000000000000000000..630c287cae88f50b54dec92f135fcaa517156d82
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/input_reshard.py
@@ -0,0 +1,109 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from functools import partial
+from typing import Any, Optional, Tuple
+
+import torch
+from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard
+
+
+__all__ = [
+    "input_reshard",
+]
+
+
+def input_reshard(
+    module: torch.nn.Module,
+    tp_device_mesh: DeviceMesh,
+    input_reshard_dim: Optional[int] = None,
+) -> torch.nn.Module:
+    """
+    Register hooks to an nn.Module for input resharding, enabling sharding and restoration during backward computation.
+
+    Register hooks to an nn.Module with input resharding so that we can shard
+    per the given `tp_device_mesh` and `input_reshard_dim` and restore the
+    input back when recomputing the activations in the backward. The reason
+    why we can do this is that for Tensor Parallel (TP), the inputs are the same
+    across all TP ranks.
+
+    Args:
+        module (:class:`nn.Module`):
+            Module to be registered with input resharding.
+        tp_device_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for Tensor Parallel.
+        input_reshard_dim (Optional[int]):
+            The dimension on which we perform the sharding
+            of the input. If set to None, there is no sharding of the input.
+            Default: None
+
+    Return:
+        A :class:`nn.Module` object registered with TP input resharding.
+ """ + cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None + + def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> None: + saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks( + partial(_pack_hook_tp, tp_device_mesh, input_reshard_dim), + partial(_unpack_hook_tp, tp_device_mesh, input_reshard_dim), + ) + saved_tensor_hooks.__enter__() + nonlocal cx + cx = saved_tensor_hooks # type: ignore[name-defined] + + def input_reshard_backward_hook( + _: torch.nn.Module, _i: Tuple[Any, ...], _o: Any + ) -> Any: + nonlocal cx + cx.__exit__() # type: ignore[name-defined, union-attr] + + if input_reshard_dim is None: + return module + module.register_forward_pre_hook(input_reshard_forward_pre_hook) + module.register_forward_hook(input_reshard_backward_hook) + return module + + +def _pack_hook_tp( + mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor +) -> Any: # noqa: D401 + """Hook function called after FWD to shard input.""" + if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements): + return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)]) + elif ( + not isinstance(x, DTensor) + and isinstance(x, torch.Tensor) + and x.numel() >= mesh.size() + ): + return ( + DTensor.from_local(x, device_mesh=mesh) + .redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)]) + .to_local() + ) + else: + return x + + +def _unpack_hook_tp( + mesh: DeviceMesh, input_reshard_dim: int, x: Any +) -> torch.Tensor: # noqa: D401 + """Hook function called before activation recomputing in BWD to restore input.""" + if ( + isinstance(x, DTensor) + and len(x._spec.placements) == 1 + and x._spec.placements[0].is_shard() + ): + return x.redistribute(device_mesh=mesh, placements=[Replicate()]) + elif ( + not isinstance(x, DTensor) + and isinstance(x, torch.Tensor) + and x.numel() >= mesh.size() + ): + return ( + DTensor.from_local( + x, device_mesh=mesh, placements=[Shard(input_reshard_dim)] + ) + .redistribute(device_mesh=mesh, placements=[Replicate()]) + .to_local() + ) + else: + return x diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/loss.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..99f1e3ad6ef9adbfc8eaf316d9b082c78bbcbc6f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/loss.py @@ -0,0 +1,490 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import contextlib +from typing import cast, Dict, Optional, Tuple + +import torch +import torch._prims_common as utils +import torch.distributed._functional_collectives as funcol +import torch.distributed.distributed_c10d as c10d +from torch import Tensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._ops._embedding_ops import _MaskPartial +from torch.distributed.tensor._ops._math_ops import ( + _skip_dim, + Reduction, + replicate_reduction_dims, +) +from torch.distributed.tensor.placement_types import Placement + + +aten = torch.ops.aten + + +__all__ = ["loss_parallel"] + + +@contextlib.contextmanager +def loss_parallel(): + """ + A context manager that enables loss parallelism, where efficient parallelized loss computation + can be performed when the input is sharded on the class dimension. 
Currently only the cross-entropy
+    loss is supported.
+
+    Within this context manager, one can use :func:`~torch.nn.functional.cross_entropy` or
+    :class:`~torch.nn.CrossEntropyLoss` as usual, with the following assumptions on the input parameters.
+    The corresponding ``backward()`` call, if any, also needs to happen under this context manager.
+
+    Args:
+        input (:class:`DTensor`):
+            Input logits. Assumed to be sharded on the class dimension.
+        target (Union[:class:`torch.Tensor`, :class:`DTensor`]):
+            Must be ground truth class indices (class probabilities currently not supported).
+            Assumed to be replicated across the ``DeviceMesh``.
+        weight (Union[:class:`torch.Tensor`, :class:`DTensor`], optional):
+            If given, assumed to be replicated across the ``DeviceMesh``.
+        label_smoothing:
+            Currently not supported.
+
+    Returns:
+        A replicated :class:`DTensor`.
+
+    Example:
+        A sharded DTensor is manually created here to showcase the usage.
+        In practice, it is usually the output of a TP module.
+
+        >>> # xdoctest: +SKIP("distributed")
+        >>> from torch.distributed.tensor.parallel import loss_parallel
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>> ...
+        >>> device_mesh = init_device_mesh("cuda", (8,))
+        >>> input = torch.randn(4, 16, device="cuda", requires_grad=True)
+        >>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)])
+        >>> target = torch.randint(16, (4,), device="cuda")
+        >>> with loss_parallel():
+        >>>     loss = F.cross_entropy(dist_input, target, reduction="mean")
+        >>>     loss.backward()
+        >>> ...
+    """
+    _enable_custom_loss_ops()
+
+    # Ensure the custom handlers are removed even if the wrapped computation
+    # raises.
+    try:
+        yield
+    finally:
+        _disable_custom_loss_ops()
+
+
+# Currently only needs to support one dimensional DeviceMesh; in general return
+# the mesh_dim with placements[mesh_dim].is_shard(dim)
+def _find_all_reduce_mesh_dim(placements: Tuple[Placement, ...], dim: int) -> int:
+    if not len(placements) == 1:
+        raise ValueError(
+            "Currently loss_parallel() only supports input on one-dimensional DeviceMesh."
+        )
+    if not placements[0].is_shard(dim):
+        raise ValueError(
+            f"loss_parallel() should be enabled only when the input tensor is sharded on dimension {dim}."
+        )
+    return 0
+
+
+def _cast_to_dtensor(
+    tensor, placements: Tuple[Placement, ...], mesh: DeviceMesh
+) -> DTensor:
+    if isinstance(tensor, DTensor):
+        if tensor.placements == placements:
+            return tensor
+        else:
+            raise RuntimeError(f"Expected {placements} but got {tensor.placements}.")
+    elif isinstance(tensor, torch.Tensor):
+        return DTensor.from_local(
+            tensor, device_mesh=mesh, placements=placements, run_check=False
+        )
+    else:
+        raise TypeError(f"Unsupported type {type(tensor)}")
+
+
+def _propagate_tensor_meta(
+    op_call: torch._ops.OpOverload,
+    args: Tuple[object, ...],
+    kwargs: Dict[str, object],
+) -> TensorMeta:
+    op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
+    tensor_meta = DTensor._op_dispatcher.sharding_propagator._propagate_tensor_meta(
+        op_info.schema
+    )
+    if isinstance(tensor_meta, TensorMeta):
+        return tensor_meta
+    elif isinstance(tensor_meta, tuple):
+        return tensor_meta[0]
+    else:
+        raise RuntimeError(f"Unexpected tensor meta type: {type(tensor_meta)}.")
+
+
+# NOTE: The implementation follows torch._decomp.decomposition._log_softmax,
+# with all_reduce manually inserted to perform distributed computation.
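+# Editor's note (illustrative, not upstream): with the class dim sharded, each
+# rank holds a slice of the logits, and
+#
+#   log_softmax(x) = (x - m) - log(sum_j exp(x_j - m)),   m = max_j x_j
+#
+# so a MAX all-reduce produces the global m and a SUM all-reduce produces the
+# global sum of exp(x - m); everything else stays rank-local, which is exactly
+# what _log_softmax below does.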
+def _log_softmax(x, dim, half_to_float, mesh, mesh_dim): + x = x.contiguous() + if half_to_float: + assert x.dtype == torch.half + computation_dtype, result_dtype = utils.elementwise_dtypes( + x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + x = x.to(computation_dtype) + if x.numel() == 0: + shifted = x + else: + x_max = torch.amax(x, dim, keepdim=True) + x_max = funcol.all_reduce( + x_max, reduceOp=c10d.ReduceOp.MAX.name, group=(mesh, mesh_dim) + ) + shifted = x - x_max + shifted_sumexp = torch.sum(torch.exp(shifted), dim, keepdim=True) + shifted_sumexp = funcol.all_reduce( + shifted_sumexp, reduceOp=c10d.ReduceOp.SUM.name, group=(mesh, mesh_dim) + ) + shifted_logsumexp = torch.log(shifted_sumexp) + result = shifted - shifted_logsumexp + if not half_to_float: + result = result.to(result_dtype) + return result + + +def _log_softmax_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> object: + x = cast(DTensor, args[0]) + dim = cast(int, args[1]) + half_to_float = cast(bool, args[2]) + + spec = x._spec + mesh_dim = _find_all_reduce_mesh_dim(spec.placements, dim) + + output_tensor_meta = _propagate_tensor_meta(op_call, args, kwargs) + + res = _log_softmax(x._local_tensor, dim, half_to_float, spec.mesh, mesh_dim) + + res_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=output_tensor_meta, + ) + + return DTensor( + res, + res_spec, + requires_grad=res.requires_grad, + ) + + +# NOTE: As explained below at _nll_loss_and_log_softmax_backward, the +# _log_softmax_backward_handler does not actually do any computation. +def _log_softmax_backward_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> object: + grad_output = cast(DTensor, args[0]) + input_dtype = cast(torch.dtype, args[3]) + return grad_output.to(input_dtype) + + +# NOTE: The implementation follows torch._decomp.decomposition._nll_loss_forward, +# with customized communication inserted to perform distributed computation. 
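+# Editor's note (illustrative, not upstream): the "distributed gather" in
+# _nll_loss_forward below works because each rank owns one contiguous slice of
+# the class dimension: _MaskPartial._partition_value rewrites each global class
+# index into a rank-local index (masking out indices owned by other ranks), the
+# local torch.gather then produces a partial result, and the SUM all-reduce in
+# _reduce_value fills in the entries contributed by the other ranks.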
+def _nll_loss_forward( + x: Tensor, + target: Tensor, + weight: Optional[Tensor], + local_weight: Optional[Tensor], + reduction: int, + ignore_index: int, + input_shape: torch.Size, + channel_dim: int, + mesh: DeviceMesh, + mesh_dim: int, +) -> Tuple[Tensor, Tensor]: + n_dims = x.dim() + channel_dim = 1 + if n_dims < 2: + channel_dim = 0 + + def _weight_view(weight: Tensor) -> Tensor: + if n_dims > 1: + shape = [ + 1, + ] * n_dims + shape[channel_dim] = weight.shape[0] + w = weight.view(shape) + else: + w = weight + return w + + if weight is not None: + w = _weight_view(weight) + assert local_weight is not None + local_w = _weight_view(local_weight) + x = x * local_w + safe_target = torch.where(target != ignore_index, target, 0) + safe_target_ = safe_target.unsqueeze(channel_dim) + + # The following code block is a distributed version of + # result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim) + partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim) + safe_target_partial_ = partial_placement._partition_value( + safe_target_, mesh, mesh_dim + ) + result_partial = torch.gather(x, channel_dim, safe_target_partial_) + # an all_reduce happens here + result_reduced = partial_placement._reduce_value(result_partial, mesh, mesh_dim) + result = -result_reduced.squeeze(channel_dim) + + result = torch.where(target != ignore_index, result, 0) + + if reduction == Reduction.NONE.value and n_dims > 1: + total_weight = x.new_full((), 0.0) + return result, total_weight + + if weight is not None: + new_shape = list(x.shape) + new_shape[channel_dim] = -1 + w = w.expand(new_shape) + wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim) + wsum = torch.where(target != ignore_index, wsum, 0) + total_weight = wsum.sum() + else: + total_weight = (target != ignore_index).sum().to(x) + + # NOTE: this is correct only on 1D DeviceMesh; o/w additional + # all-reduce on result and total_weight is needed + if reduction == Reduction.SUM.value: + result = result.sum() + elif reduction == Reduction.MEAN.value: + result = result.sum() / total_weight + + return result, total_weight + + +def _nll_loss_forward_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> object: + x = cast(DTensor, args[0]) + target = args[1] + weight = args[2] + reduction = cast(int, args[3]) + ignore_index = cast(int, args[4]) + + channel_dim = 1 if x.dim() >= 2 else 0 + channel_dim_size = x.shape[channel_dim] + spec = x._spec + mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim) + + # Check user input: if target and weight are not DTensors, convert them to DTensors; + # if they are DTensors, check that they have the desired placements. + target_placements = _skip_dim( + replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim + ) + all_replicate_placements = (Replicate(),) * spec.mesh.ndim + target = _cast_to_dtensor(target, target_placements, spec.mesh) + local_weight = None + if weight is not None: + weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh) + # For local computation, both (replicated) weight and (sharded) local_weight + # are needed in _nll_loss_forward(). local_weight is generated here using + # DTensor API, without incurring any communication. 
+ sharded_placements = [ + Shard(0) if i == mesh_dim else Replicate() for i in range(spec.mesh.ndim) + ] + local_weight = weight.redistribute(spec.mesh, sharded_placements)._local_tensor + assert local_weight.shape[0] == x._local_tensor.shape[channel_dim] + + if reduction == Reduction.NONE.value: + output_placements = target_placements + else: + output_placements = all_replicate_placements + + # tensor inputs to _propagate_tensor_meta need to be DTensors + args = list(args) + args[1], args[2] = target, weight + output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs) + + result, total_weight = _nll_loss_forward( + x._local_tensor, + target._local_tensor, + weight._local_tensor if weight is not None else None, + local_weight, + reduction, + ignore_index, + x.shape, + channel_dim, + spec.mesh, + mesh_dim, + ) + out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta) + + return ( + DTensor( + result, + out_spec, + requires_grad=result.requires_grad, + ), + total_weight, + ) + + +# NOTE: The backward computation of cross_entropy goes through two steps: +# backward for nll_loss and then backward for log_softmax. In loss parallel, +# the two steps are fused into the following function (called by _nll_loss_backward_handler) +# to avoid communication when target contains class indices not class probabilities. +# Also note that the _log_softmax_backward_handler does not perform computation. +# The implementation resembles _nll_loss_backward and _log_softmax_backward_data +# from torch._decomp.decomposition. +def _nll_loss_and_log_softmax_backward( + grad_output: Tensor, + x: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, + input_shape: torch.Size, + channel_dim: int, + mesh: DeviceMesh, + mesh_dim: int, +) -> Tensor: + channel_dim = 0 if x.dim() < 2 else 1 + if reduction == Reduction.MEAN.value: + grad_output = grad_output / total_weight + + target = target.unsqueeze(channel_dim) + safe_target = torch.where(target != ignore_index, target, 0) + grad_input = torch.zeros_like(x) + + # The following code block is a distributed version of + # grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0) + partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim) + safe_target = safe_target.squeeze(channel_dim).flatten() + masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim) + # only update grad_input to -1 if not masked + assert partial_placement.mask_buffer.data is not None + grad_update = partial_placement.mask_buffer.data.to(grad_input.dtype) - 1.0 + arange_1d = torch.arange( + masked_safe_target.shape[0], device=masked_safe_target.device + ) + # The first two cases with x.dim() <= 2 are for aten.nll_loss_backward.default; + # the last case is for aten.nll_loss2d_backward.default. 
+    if x.dim() == 1:
+        grad_input[masked_safe_target] = grad_update
+    elif x.dim() == 2:
+        grad_input[arange_1d, masked_safe_target] = grad_update
+    else:
+        grad_input_t = grad_input.transpose(channel_dim, -1)
+        intermediate_shape = grad_input_t.shape
+        grad_input_2d = grad_input_t.reshape(-1, x.shape[channel_dim])
+        grad_input_2d[arange_1d, masked_safe_target] = grad_update
+        grad_input = grad_input_2d.view(intermediate_shape).transpose(channel_dim, -1)
+
+    if grad_input.dim() > grad_output.dim() > 0:
+        grad_output = grad_output.unsqueeze(channel_dim)
+
+    if weight is not None:
+        new_shape = [1 for _ in range(x.dim())]
+        new_shape[channel_dim] = weight.shape[0]
+        weight = weight.reshape(new_shape)
+        # In order for fused computation to work, the following line is rewritten.
+        # grad_output = grad_output * weight
+        new_shape = list(x.shape)
+        new_shape[channel_dim] = -1
+        w = weight.expand(new_shape)
+        w_target = torch.gather(w, channel_dim, target)
+        grad_output = grad_output * w_target
+
+    grad_output = torch.where(target != ignore_index, grad_output, 0)
+
+    # NOTE: Instead of directly returning the grad_input as grad_output for log_softmax,
+    # here we perform backward computation for log_softmax altogether to avoid the
+    # otherwise extra all_gather communication.
+    # return grad_input * grad_output
+    return (grad_input + torch.exp(x)) * grad_output
+
+
+def _nll_loss_backward_handler(
+    op_call: torch._ops.OpOverload,
+    args: Tuple[object, ...],
+    kwargs: Dict[str, object],
+) -> object:
+    grad_output = cast(DTensor, args[0])
+    x = cast(DTensor, args[1])
+    target = args[2]
+    weight = args[3]
+    reduction = cast(int, args[4])
+    ignore_index = cast(int, args[5])
+    total_weight = cast(Tensor, args[6])
+
+    channel_dim = 1 if x.dim() >= 2 else 0
+    spec = x._spec
+    mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)
+
+    # if target and weight are not DTensors, convert them to DTensors
+    target_placements = _skip_dim(
+        replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim
+    )
+    all_replicate_placements = (Replicate(),) * spec.mesh.ndim
+    target = _cast_to_dtensor(target, target_placements, spec.mesh)
+    if weight is not None:
+        weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)
+
+    # tensor inputs to _propagate_tensor_meta need to be DTensors
+    args = list(args)
+    args[2], args[3] = target, weight
+    args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh)
+    output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
+
+    result = _nll_loss_and_log_softmax_backward(
+        grad_output._local_tensor,
+        x._local_tensor,
+        target._local_tensor,
+        weight._local_tensor if weight is not None else None,
+        reduction,
+        ignore_index,
+        total_weight,
+        x.shape,
+        channel_dim,
+        spec.mesh,
+        mesh_dim,
+    )
+    # the output sharding is the same as input sharding: Shard(channel_dim) on mesh_dim
+    out_spec = DTensorSpec(
+        spec.mesh,
+        spec.placements,
+        tensor_meta=output_tensor_meta,
+    )
+
+    return DTensor(
+        result,
+        out_spec,
+        requires_grad=result.requires_grad,
+    )
+
+
+customized_loss_ops = {
+    aten._log_softmax.default: _log_softmax_handler,
+    aten._log_softmax_backward_data.default: _log_softmax_backward_handler,
+    aten.nll_loss_forward.default: _nll_loss_forward_handler,
+    aten.nll_loss2d_forward.default: _nll_loss_forward_handler,
+    aten.nll_loss_backward.default: _nll_loss_backward_handler,
+    aten.nll_loss2d_backward.default: _nll_loss_backward_handler,
+}
+
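+# Editor's note (illustrative, not upstream): the keys of customized_loss_ops
+# are aten overloads, so inside loss_parallel() a call such as
+#
+#   F.cross_entropy(dist_input, target)
+#
+# dispatches aten._log_softmax.default and aten.nll_loss_forward.default to the
+# handlers above instead of DTensor's default sharding-propagation path, and the
+# matching *_backward overloads are intercepted during .backward().
+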
+def _enable_custom_loss_ops():
+    DTensor._op_dispatcher._custom_op_handlers.update(customized_loss_ops)
+
+
+def _disable_custom_loss_ops():
+    for custom_op in customized_loss_ops:
+        DTensor._op_dispatcher._custom_op_handlers.pop(custom_op)
diff --git a/.venv/lib/python3.11/site-packages/torch/distributed/tensor/placement_types.py b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/placement_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d9834ab8b81cddd7efba11a180465513db615d0
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/distributed/tensor/placement_types.py
@@ -0,0 +1,652 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+from dataclasses import dataclass
+from typing import cast, List, Optional, Tuple
+
+import torch
+import torch.distributed._functional_collectives as funcol
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor._collective_utils import (
+    fill_empty_tensor_to_shards,
+    mesh_broadcast,
+    mesh_scatter,
+    pad_tensor,
+    shard_dim_alltoall,
+    unpad_tensor,
+)
+
+
+__all__ = ["Placement", "Shard", "Replicate", "Partial"]
+
+
+class Placement:
+    """
+    The base class for the Placement type, which describes how a DTensor is placed onto the
+    ``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together describe the DTensor layout.
+    It is the base class of the three main DTensor placement types: ``Shard``, ``Replicate``,
+    and ``Partial``.
+
+    This class is not meant to be used directly; it mainly serves as a typing stub.
+    """
+
+    # convenient utils to check for placement types
+    def is_shard(self, dim: Optional[int] = None) -> bool:
+        is_shard_instance = isinstance(self, Shard)
+        if dim is not None and is_shard_instance:
+            return cast(Shard, self).dim == dim
+        else:
+            return is_shard_instance
+
+    def is_replicate(self) -> bool:
+        return isinstance(self, Replicate)
+
+    def is_partial(self) -> bool:
+        return isinstance(self, Partial)
+
+
+@dataclass(frozen=True)
+class Shard(Placement):
+    """
+    The ``Shard(dim)`` placement describes the DTensor sharding on tensor dimension
+    ``dim`` over a corresponding ``DeviceMesh`` dimension, where each rank on the
+    DeviceMesh dimension only holds a shard/piece of the global Tensor. The
+    ``Shard(dim)`` placement follows the ``torch.chunk(dim)`` semantic, where the
+    last few shards on the DeviceMesh dimension might be empty when the tensor dimension
+    is not evenly divisible on the DeviceMesh dimension. The ``Shard`` placement can be
+    used by all DTensor APIs (i.e. distribute_tensor, from_local, etc.)
+
+    Args:
+        dim (int): The tensor dimension over which the DTensor is sharded on its
+            corresponding DeviceMesh dimension.
+
+    .. warning:: sharding on a tensor dimension where the tensor dimension size is not
+        evenly divisible on a DeviceMesh dimension is currently experimental and subject to change.
+    """
+
+    dim: int
+
+    def _split_tensor(
+        self,
+        tensor: torch.Tensor,
+        num_chunks: int,
+        *,
+        with_padding: bool = True,
+        contiguous: bool = True,
+    ) -> Tuple[List[torch.Tensor], List[int]]:
+        """
+        This function uses torch.chunk to split a tensor into num_chunks shards along
+        the Shard placement dimension, and returns a list of shards with their pad sizes.
+
+        Keyword args:
+            with_padding (bool, optional): when True, we pad the tensor on the last
+            few ranks before calling the collectives (i.e. scatter/all_gather, etc.).
+ This is because collectives usually require equal size tensor inputs + """ + assert ( + self.dim <= tensor.ndim + ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + + # chunk tensor over dimension `dim` into n slices + tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim)) + num_empty_tensors = num_chunks - len(tensor_list) + + # if no need to have padding or tensor dim size is evenly sharded already + # we can return early. + if not with_padding or tensor.size(self.dim) % num_chunks == 0: + if contiguous: + tensor_list = [t.contiguous() for t in tensor_list] + return ( + fill_empty_tensor_to_shards(tensor_list, self.dim, num_empty_tensors), + [], + ) + + # compute the chunk size inline with ``torch.chunk`` to calculate padding + full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks + + # Compute chunk size for each chunk for ``self.dim`` + chunk_sizes = [ + tensor_list[idx].size(self.dim) if idx < len(tensor_list) else 0 + for idx in range(num_chunks) + ] + # Compute pad size on each chunk + pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes] + + # Reuse tensor to fill empty chunk with empty tensor + tensor_list = fill_empty_tensor_to_shards( + tensor_list, self.dim, num_empty_tensors + ) + shard_list = [] + for shard, pad_size in zip(tensor_list, pad_sizes): + # Fill the empty tensor with zeroes with padding. + if with_padding and pad_size > 0: + shard = pad_tensor(shard, self.dim, pad_size) + shard = shard.contiguous() if contiguous else shard + shard_list.append(shard) + return shard_list, pad_sizes + + @staticmethod + def _local_shard_size_on_dim( + size_on_dim: int, + num_chunks: int, + rank: int, + return_offset: bool = False, + ) -> Tuple[int, int]: + """ + returns the local shard size and offset on a given tensor dim + """ + # Compute the chunk size inline with ``torch.chunk`` + if size_on_dim % num_chunks == 0: + full_chunk_size = size_on_dim // num_chunks + return full_chunk_size, full_chunk_size * rank if return_offset else -1 + + # uneven sharding case + full_chunk_size = (size_on_dim + num_chunks - 1) // num_chunks + shard_starting_idx = full_chunk_size * rank + + if size_on_dim < shard_starting_idx: + return 0, size_on_dim if return_offset else -1 + else: + local_shard_size = ( + min(size_on_dim, shard_starting_idx + full_chunk_size) + - shard_starting_idx + ) + return local_shard_size, shard_starting_idx if return_offset else -1 + + def _shard_tensor( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + """ + shard and scatter a tensor on a mesh dimension (use coordinate + 0 on the mesh dimension as source of truth) + """ + my_coordinate = mesh.get_coordinate() + num_chunks = mesh.size(mesh_dim=mesh_dim) + + if my_coordinate is None: + # if rank is not part of mesh, we simply return an empty tensor + return tensor.new_empty(0, requires_grad=tensor.requires_grad) + + scatter_list, pad_sizes = self._split_tensor( + tensor, num_chunks, with_padding=True, contiguous=True + ) + + mesh_dim_local_rank = my_coordinate[mesh_dim] + output = torch.empty_like(scatter_list[mesh_dim_local_rank]) + mesh_scatter(output, scatter_list, mesh, mesh_dim=mesh_dim) + + # Only unpad if the local_tensor was padded on the dimension. 
+        if pad_sizes and pad_sizes[mesh_dim_local_rank] > 0:
+            output = unpad_tensor(output, self.dim, pad_sizes[mesh_dim_local_rank])
+        return output
+
+    def _reduce_shard_tensor(
+        self,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        reduce_op: str,
+        mesh_dim: int,
+    ) -> torch.Tensor:
+        """
+        reduce and scatter a tensor on a mesh dimension
+        """
+        my_coordinate = mesh.get_coordinate()
+        num_chunks = mesh.size(mesh_dim=mesh_dim)
+
+        if my_coordinate is None:
+            # if rank is not part of mesh, we simply return local_tensor,
+            # which should be an empty tensor
+            return tensor
+
+        is_padded = tensor.size(self.dim) % num_chunks != 0
+        if is_padded:
+            scattered_list, pad_sizes = self._split_tensor(
+                tensor, num_chunks, with_padding=True, contiguous=True
+            )
+            tensor = torch.cat(scattered_list, dim=self.dim)
+        elif not tensor.is_contiguous():
+            tensor = tensor.contiguous()
+
+        output = funcol.reduce_scatter_tensor(
+            tensor, reduce_op, scatter_dim=self.dim, group=(mesh, mesh_dim)
+        )
+
+        if is_padded:
+            output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]])  # type: ignore[possibly-undefined]
+        return output
+
+    def _to_replicate_tensor(
+        self,
+        local_tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        current_logical_shape: List[int],
+    ) -> torch.Tensor:
+        """
+        This function all_gathers all shards and returns a tensor that
+        is replicated on the previously sharded mesh dimension
+        """
+        num_chunks = mesh.size(mesh_dim=mesh_dim)
+        # check if it's uneven; if so, we need to pad the input tensor before all_gather
+        local_shape = list(local_tensor.size())
+
+        logical_dim_size = current_logical_shape[self.dim]
+        is_padded = logical_dim_size % num_chunks != 0
+
+        if is_padded:
+            full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
+            pad_size = full_chunk_size - local_shape[self.dim]
+            local_tensor = pad_tensor(local_tensor, self.dim, pad_size)
+
+        if not local_tensor.is_contiguous():
+            local_tensor = local_tensor.contiguous()
+
+        result = funcol.all_gather_tensor(
+            local_tensor,
+            gather_dim=self.dim,
+            group=(mesh, mesh_dim),
+        )
+        if is_padded:
+            unpad_size = full_chunk_size * num_chunks - logical_dim_size  # type: ignore[possibly-undefined]
+            result = unpad_tensor(result, self.dim, unpad_size)
+        return result
+
+    def _replicate_to_shard(
+        self,
+        local_tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        shard_index: int,
+    ) -> torch.Tensor:
+        """
+        transform from a replicated tensor to a sharded tensor on
+        the current rank, which performs a local chunk
+        """
+        num_chunks = mesh.size(mesh_dim=mesh_dim)
+        shards, _ = self._split_tensor(
+            local_tensor,
+            num_chunks,
+            with_padding=False,
+            contiguous=False,
+        )
+        return shards[shard_index].clone()
+
+    def _to_new_shard_dim(
+        self,
+        local_tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        current_logical_shape: List[int],
+        new_shard_dim: int,
+    ) -> torch.Tensor:
+        """
+        transform from an existing sharded tensor to a tensor sharded on a new
+        dimension, which performs an alltoall
+        """
+        my_coordinate = mesh.get_coordinate()
+        if my_coordinate is None:
+            # if rank is not part of mesh, we simply return local_tensor,
+            # which should be an empty tensor
+            return local_tensor
+
+        num_chunks = mesh.size(mesh_dim=mesh_dim)
+
+        old_dim_logical_size = current_logical_shape[self.dim]
+        new_dim_logical_size = current_logical_shape[new_shard_dim]
+        old_dim_padding = old_dim_logical_size % num_chunks != 0
+        new_dim_padding = new_dim_logical_size % num_chunks != 0
+        if old_dim_padding:
+            old_dim_full_chunk_size = (
+                old_dim_logical_size + num_chunks - 1
+            ) // num_chunks
+            old_dim_pad_size = old_dim_full_chunk_size - local_tensor.size(self.dim)
+            local_tensor = pad_tensor(local_tensor, self.dim, old_dim_pad_size)
+        if new_dim_padding:
+            new_dim_full_chunk_size = (
+                new_dim_logical_size + num_chunks - 1
+            ) // num_chunks
+            new_dim_pad_size = new_dim_full_chunk_size * num_chunks - local_tensor.size(
+                new_shard_dim
+            )
+            local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size)
+
+        if not local_tensor.is_contiguous():
+            local_tensor = local_tensor.contiguous()
+
+        new_tensor = shard_dim_alltoall(
+            local_tensor, self.dim, new_shard_dim, mesh, mesh_dim
+        )
+
+        if old_dim_padding:
+            old_dim_unpad_size = (
+                old_dim_full_chunk_size * num_chunks - current_logical_shape[self.dim]  # type: ignore[possibly-undefined]
+            )
+            new_tensor = unpad_tensor(new_tensor, self.dim, old_dim_unpad_size)  # type: ignore[possibly-undefined]
+
+        if new_dim_padding:
+            local_shard_size_on_new_dim = self._local_shard_size_on_dim(
+                new_dim_logical_size, num_chunks, my_coordinate[mesh_dim]
+            )[0]
+            new_dim_unpad_size = new_dim_full_chunk_size - local_shard_size_on_new_dim  # type: ignore[possibly-undefined]
+            new_tensor = unpad_tensor(new_tensor, new_shard_dim, new_dim_unpad_size)  # type: ignore[possibly-undefined]
+
+        return new_tensor
+
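+    # Editor's note (illustrative, not upstream): Shard follows torch.chunk
+    # semantics, so splitting a size-5 dim across 4 ranks with
+    # Shard(0)._split_tensor(torch.arange(5), 4) yields local sizes
+    # [2, 2, 1, 0] and, with with_padding=True, pad sizes [0, 0, 1, 2], so every
+    # rank contributes an equally-sized (padded) chunk to the collective.
+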
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Shard):
+            return False
+        return self.dim == other.dim
+
+    def __hash__(self) -> int:
+        return hash(self.dim)
+
+    def __repr__(self) -> str:
+        """
+        machine readable representation of the Shard placement
+        """
+        return f"Shard(dim={self.dim})"
+
+    def __str__(self) -> str:
+        """human readable representation of the Shard placement"""
+        return f"S({self.dim})"
+
+
+# kw_only is only available in python >= 3.10
+kw_only_dataclass = dict(kw_only=True) if "kw_only" in dataclass.__kwdefaults__ else {}
+
+
+@dataclass(frozen=True, **kw_only_dataclass)
+class _StridedShard(Shard):
+    """
+    _StridedShard is only introduced to support 2D FSDP2 + TP sharding, where the tensor
+    is sharded on the TP mesh dimension first, then sharded on the FSDP mesh dimension.
+    We call this right-to-left sharding, which is the opposite of the default
+    left-to-right sharding. See the example below:
+        tensor shape: [8, 8]
+        mesh: [[0, 1], [2, 3]], names=("dp", "tp")
+        placements: [Shard(0), Shard(0)]
+
+    The default sharding behavior shards the tensor on the "dp" mesh dimension first,
+    then on the "tp" dimension. The sharding result will be:
+        Rank | Mesh Coordinate | Shard Index
+        ------------------------------------
+        0    | (0, 0)          | 0 (row 0-1)
+        1    | (0, 1)          | 1 (row 2-3)
+        2    | (1, 0)          | 2 (row 4-5)
+        3    | (1, 1)          | 3 (row 6-7)
+
+    While the FSDP2 + TP sharding behavior does the opposite: it shards the tensor on
+    the "tp" mesh dim first, then the "dp" dim. This right-to-left sharding will produce
+    the result:
+        Rank | Mesh Coordinate | Shard Index
+        ------------------------------------
+        0    | (0, 0)          | 0 (row 0-1)
+        1    | (0, 1)          | 2 (row 4-5)
+        2    | (1, 0)          | 1 (row 2-3)
+        3    | (1, 1)          | 3 (row 6-7)
+
+    The consequence is that any attempt to redistribute this DTensor to a full replica
+    will produce a wrong result, because the shard-to-replicate redistribution always
+    happens right-to-left, regardless of whether the sharding was left-to-right or
+    right-to-left. To address this, we use the _StridedShard placement to make this
+    right-to-left sharding compatible with our left-to-right convention on both tensor
+    distribution and redistribution.
+ + Now with _StridedShard, the right-to-left sharding above can be represented as: + tensor shape: [8, 8] + mesh: [[0, 1], [2, 3]], names=("dp", "tp") + placements: [_StridedShard(0, split_factor=2), Shard(0)] + + A left-to-right processing of `placements` will now produce the same result, which is + different from what the plain `Shard` placement would produce: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 2 (row 4-5) + 2 | (1, 0) | 1 (row 2-3) + 3 | (1, 1) | 3 (row 6-7) + + The argument `split_factor` is the number of existing shards over the tensor sharding + dimension before processing the _StridedShard placement, as if the sharding happened + right-to-left. In the example above, the tensor should first be sharded on the "tp" + dimension into 2 shards before being sharded on the "dp" dimension. Therefore, the + `split_factor` of the _StridedShard placement on the "dp" dim is 2. + + TODO: strided sharding needs to work with uneven sharding. For now it forbids + resharding if the tensor is unevenly sharded. + TODO: we should remove the _StridedShard placement once we can unify it with Shard + """ + + split_factor: int + + def __eq__(self, other: object) -> bool: + if isinstance(other, _StridedShard): + return self.dim == other.dim and self.split_factor == other.split_factor + elif isinstance(other, Shard): + # TODO: this is to avoid an extra all-gather in dtensor op dispatch; + # note that sharding prop would not produce _StridedShard, and a + # placement inequality would introduce an all-gather for resharding + return self.dim == other.dim + return False + + def __hash__(self) -> int: + return hash((self.dim, self.split_factor)) + + def __repr__(self) -> str: + """ + machine-readable representation of the _StridedShard placement + """ + return f"_StridedShard(dim={self.dim}, sf={self.split_factor})" + + def __str__(self) -> str: + """human-readable representation of the _StridedShard placement""" + return f"_S({self.dim}, {self.split_factor})" + + def _split_tensor( + self, + tensor: torch.Tensor, + num_chunks: int, + *, + with_padding: bool = True, + contiguous: bool = True, + ) -> Tuple[List[torch.Tensor], List[int]]: + """ + TODO: currently _StridedShard does not support padding + """ + assert ( + self.dim <= tensor.ndim + ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + + total_split = num_chunks * self.split_factor + assert tensor.size(self.dim) % total_split == 0, ( + "_StridedShard currently only allows even sharding but got tensor size" + f" {tensor.size(self.dim)} on dim {self.dim} and total split" + f" {total_split}={num_chunks} * {self.split_factor}" + ) + + group_size = self.split_factor + total_split_tensor_list = list(torch.chunk(tensor, total_split, dim=self.dim)) + tensor_list = [ + torch.cat( + [ + total_split_tensor_list[i + j * num_chunks] # stride is num_chunks + for j in range(group_size) + ], + dim=self.dim, + ) + for i in range(num_chunks) + ] + + if contiguous: + tensor_list = [t.contiguous() for t in tensor_list] + + return tensor_list, [] + + def _to_replicate_tensor( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: List[int], + ) -> torch.Tensor: + """ + Note: currently _StridedShard does not support padding + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + total_split = num_chunks * self.split_factor + # NOTE: we require strided sharding to be even for now + assert current_logical_shape[self.dim] % total_split == 0, ( + 
"_StridedShard requires even sharding but got tensor size " + f"{current_logical_shape[self.dim]} on dim {self.dim} and " + f"total split {total_split}=num_chunks {num_chunks} " + f"* split_factor {self.split_factor}" + ) + + result = funcol.all_gather_tensor( + local_tensor, + gather_dim=self.dim, + group=(mesh, mesh_dim), + ) + if isinstance(result, funcol.AsyncCollectiveTensor): + result = result.wait() + + tensor_shard_list = torch.chunk(result, total_split, dim=self.dim) + # rearrange the order + new_tensor_shard_list = [] + for idx in range(len(tensor_shard_list)): + # the shard split of index `idx` is assigned a new index within + # _StridedShard._split_tensor: + # the original tensor was split into `total_split` chunks, + # all chunks with the same `idx % num_chunks` are merged into one + # new shard and placed on mesh's local rank `idx % num_chunks` + idx_after_split = idx % num_chunks * self.split_factor + idx // num_chunks + new_tensor_shard_list.append(tensor_shard_list[idx_after_split]) + + return torch.cat(new_tensor_shard_list, dim=self.dim).contiguous() + + +@dataclass(frozen=True) +class Replicate(Placement): + """ + The ``Replicate()`` placement describes the DTensor replicating on a corresponding + ``DeviceMesh`` dimension, where each rank on the DeviceMesh dimension holds a + replica of the global Tensor. The ``Replicate`` placement can be used by all + DTensor APIs (i.e. ``distribute_tensor``, ``DTensor.from_local``, etc.) + """ + + def __eq__(self, other: object) -> bool: + return isinstance(other, Replicate) + + def __hash__(self) -> int: + # every replicate placement is the same + return -1 + + def __repr__(self) -> str: + """ + machine readable representation of the Replicate placement + """ + return "Replicate()" + + def __str__(self) -> str: + """ + human readable representation of the Replicate placement + """ + return "R" + + def _replicate_tensor( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + """ + Replicate (broadcast) a torch.Tensor on a mesh dimension (use + the first coordinate on the mesh dimension as source of truth) + """ + my_coordinate = mesh.get_coordinate() + if my_coordinate is None: + # if rank is not part of mesh, we simply return an empty tensor + return tensor.new_empty(0, requires_grad=tensor.requires_grad) + + tensor = tensor.contiguous() + mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim) + return tensor + + +@dataclass(frozen=True) +class Partial(Placement): + """ + The ``Partial(reduce_op)`` placement describes the DTensor that is pending + reduction on a specified ``DeviceMesh`` dimension, where each rank on the + DeviceMesh dimension holds the partial value of the global Tensor. User can + redistribute the ``Partial`` DTensor to a ``Replicate`` or ``Shard(dim)`` + placement on the specified ``DeviceMesh`` dimension using ``redistribute``, + which would trigger necessary communication operations under the hood (i.e. + ``allreduce``, ``reduce_scatter``). + + Args: + reduce_op (str, optional): The reduction op to be used for the partial DTensor + to produce Replicated/Sharded DTensor. Only element-wise reduction operations + are supported, including: "sum", "avg", "product", "max", "min", default: "sum". + + .. note:: The ``Partial`` placement can be generated as a result of the DTensor operators, + and can only be used by the ``DTensor.from_local`` API. 
+ """ + + reduce_op: str = "sum" + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # Partial placement contract #1: + # _reduce_value: reduce the value of the tensor on the mesh dimension + return funcol.all_reduce( + tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) + ) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # Partial placement contract #2: + # _reduce_shard_value: reduce_scatter the value of the tensor over the mesh dimension + shard_spec = cast(Shard, shard_spec) + return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # Partial placement contract #3: + # _partition_value: partition the value of a replicated tensor on the mesh dimension + + # _partition_value is the conjugate operation of _reduce_value + # - i.e. _partition_value on a sum reduce op is just a divison operation + # - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation + # TODO: if the reduce_op is min/max, etc. the _partition_value should be a + # different operation + assert self.reduce_op == "sum", "only support replicate to PartialSUM for now!" + num_chunks = mesh.size(mesh_dim=mesh_dim) + return tensor / num_chunks + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Partial): + return False + return self.reduce_op == other.reduce_op + + def __hash__(self) -> int: + return 1 + hash(self.reduce_op) + + def __repr__(self) -> str: + """ + machine readable representation of the Partial placement + """ + return f"Partial({self.reduce_op})" + + def __str__(self) -> str: + """ + human readable representation of the Partial placement + """ + return "P" + + +# We keep the old _Partial name for a while for BC reason +_Partial = Partial