koichi12 commited on
Commit
620d421
·
verified ·
1 Parent(s): f610d77

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_fsspec_filesystem.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/format_utils.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/logging_handlers.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/planner_helpers.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_loader.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_saver.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/_IR.py +1243 -0
  17. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__init__.py +28 -0
  18. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_IR.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/__init__.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_backward.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_debug.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_unflatten.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_utils.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/microbatch.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/schedules.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/stage.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/_backward.py +370 -0
  28. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/_debug.py +21 -0
  29. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/_unflatten.py +27 -0
  30. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/_utils.py +99 -0
  31. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/microbatch.py +469 -0
  32. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/schedules.py +2162 -0
  33. .venv/lib/python3.11/site-packages/torch/distributed/pipelining/stage.py +1468 -0
  34. .venv/lib/python3.11/site-packages/torch/distributed/tensor/__init__.py +67 -0
  35. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_api.py +1231 -0
  36. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_collective_utils.py +373 -0
  37. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py +510 -0
  38. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_dtensor_spec.py +276 -0
  39. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_op_schema.py +457 -0
  40. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__init__.py +10 -0
  41. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/__init__.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_common_rules.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_conv_ops.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_einsum_strategy.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_embedding_ops.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_experimental_ops.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_math_ops.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_matrix_ops.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_pointwise_ops.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_random_ops.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.24 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-311.pyc ADDED
Binary file (3.53 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_fsspec_filesystem.cpython-311.pyc ADDED
Binary file (8.21 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-311.pyc ADDED
Binary file (10.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-311.pyc ADDED
Binary file (3.22 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-311.pyc ADDED
Binary file (23.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/format_utils.cpython-311.pyc ADDED
Binary file (14 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/logging_handlers.cpython-311.pyc ADDED
Binary file (630 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-311.pyc ADDED
Binary file (7.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-311.pyc ADDED
Binary file (17.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-311.pyc ADDED
Binary file (19.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/planner_helpers.cpython-311.pyc ADDED
Binary file (18.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_loader.cpython-311.pyc ADDED
Binary file (14.3 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_saver.cpython-311.pyc ADDED
Binary file (15.8 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-311.pyc ADDED
Binary file (22.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_IR.py ADDED
@@ -0,0 +1,1243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates
3
+ import copy
4
+ import logging
5
+ import operator
6
+ from collections import defaultdict
7
+ from enum import Enum
8
+ from inspect import Parameter, Signature, signature
9
+ from types import MethodType
10
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
11
+
12
+ import torch
13
+ import torch.fx as fx
14
+ from torch.distributed import ProcessGroup
15
+ from torch.export import ExportedProgram
16
+ from torch.export.unflatten import (
17
+ _assign_attr,
18
+ _AttrKind,
19
+ _sink_params,
20
+ InterpreterModule,
21
+ )
22
+ from torch.fx.node import map_aggregate
23
+ from torch.fx.passes.split_module import split_module
24
+
25
+ from ._backward import _null_coalesce_accumulate, stage_backward
26
+ from ._unflatten import _outline_submodules
27
+ from ._utils import PipeInfo
28
+ from .stage import _PipelineStage
29
+
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # TODO:
34
+ # 1. investigate gradient sync for shared parameters. how does DDP do it?
35
+ # 2. Add parameter movement to split_module
36
+
37
+
38
+ def _find_loss_from_output_and_spec(output_val, spec_val):
39
+ if spec_val is False:
40
+ return None
41
+ if spec_val is True:
42
+ if not isinstance(output_val, fx.Node):
43
+ raise RuntimeError(
44
+ f"Loss spec must specify a dynamic value but got {output_val}"
45
+ )
46
+ return output_val
47
+
48
+ if isinstance(spec_val, (tuple, list)):
49
+ if not isinstance(output_val, (tuple, list)):
50
+ raise RuntimeError(
51
+ f"Output value {output_val} must match type of loss specification "
52
+ f"{spec_val}"
53
+ )
54
+ if len(output_val) != len(spec_val):
55
+ raise RuntimeError(
56
+ f"Output value {output_val} must match length of loss specification "
57
+ f"{spec_val}"
58
+ )
59
+ for out, spec in zip(output_val, spec_val):
60
+ loss_val = _find_loss_from_output_and_spec(out, spec)
61
+ if loss_val is not None:
62
+ return loss_val
63
+ raise RuntimeError(f"Did not find loss value in specification {spec_val}")
64
+
65
+ if isinstance(spec_val, dict):
66
+ if not isinstance(output_val, dict):
67
+ raise RuntimeError(
68
+ f"Output value {output_val} must match type of loss specification "
69
+ f"{spec_val}"
70
+ )
71
+ if set(output_val.keys()) != set(spec_val.keys()):
72
+ raise RuntimeError(
73
+ f"Output value {output_val} must match keys of loss specification "
74
+ f"{spec_val}"
75
+ )
76
+ for k in spec_val:
77
+ loss_val = _find_loss_from_output_and_spec(output_val[k], spec_val[k])
78
+ if loss_val is not None:
79
+ return loss_val
80
+ raise RuntimeError(f"Did not find loss value in specification {spec_val}")
81
+
82
+ raise RuntimeError(f"Unsupported type {type(spec_val)} in loss specification")
83
+
84
+
85
+ def _find_loss_output(mod: torch.nn.Module, g: fx.Graph, output_loss_value_spec):
86
+ output_nodes = [n for n in g.nodes if n.op == "output"]
87
+ assert len(output_nodes) == 1
88
+ output_node = output_nodes[0]
89
+ output_val = output_node.args[0]
90
+ generated_spec: Any = None
91
+
92
+ if isinstance(mod, TrivialLossWrapper):
93
+ # TrivialLossWrapper is pre-defined by PiPPy.
94
+ # It has loss as the only output so we can safely assume the first output arg is the loss.
95
+ assert len(output_node.args) == 1
96
+ loss_node = output_val
97
+ generated_spec = TrivialLossWrapper.loss_spec
98
+ elif output_loss_value_spec is None:
99
+ # Use default spec, i.e. search for "loss" in output values
100
+ if isinstance(output_val, dict) and "loss" in output_val.keys():
101
+ loss_node = output_val["loss"]
102
+ generated_spec = {k: k == "loss" for k in output_val}
103
+ else:
104
+ loss_node = None
105
+ generated_spec = None
106
+ else:
107
+ loss_node = _find_loss_from_output_and_spec(output_val, output_loss_value_spec)
108
+ generated_spec = output_loss_value_spec
109
+
110
+ return loss_node, output_node, generated_spec
111
+
112
+
113
+ def _insert_stage_symbolic_backward(
114
+ g: fx.Graph,
115
+ loss_node: fx.Node,
116
+ output_node: fx.Node,
117
+ ):
118
+ # Collect metadata about tuple output values. TODO: move this to split_module or FX IR
119
+ tuples: Dict[fx.Node, Tuple] = {}
120
+ for node in reversed(g.nodes):
121
+ if node.op == "call_function":
122
+ # In the forward pass, only emit placeholder, module calls, and
123
+ # getitem calls. If we have a target other than getitem in this
124
+ # (forward-only) code, there is a bug.
125
+ assert node.target == operator.getitem, (
126
+ "Found non-getitem call in forward pass. "
127
+ "Please report a bug to PiPPy"
128
+ )
129
+ assert (
130
+ len(node.args) == 2
131
+ ), "Found malformed getitem call. Please report a bug to PiPPy"
132
+ indexed_value, node_idx = tuple(node.args)
133
+
134
+ # indexed_value is a collection that we are indexing into. It could
135
+ # exist in the tuples map if we've processed another `getitem`
136
+ # already.
137
+ existing_list_size = (
138
+ len(tuples[indexed_value]) if indexed_value in tuples else -1
139
+ )
140
+ new_list_size = max(node_idx + 1, existing_list_size)
141
+
142
+ reconstructed_list = [None for _ in range(new_list_size)]
143
+
144
+ # Copy over existing elements if present
145
+ if indexed_value in tuples:
146
+ for i, val in enumerate(tuples[indexed_value]):
147
+ reconstructed_list[i] = val
148
+
149
+ # Populate value represented by this node
150
+ reconstructed_list[node_idx] = node
151
+
152
+ tuples[indexed_value] = tuple(reconstructed_list)
153
+
154
+ # Keep track of nodes that dominate the loss node.
155
+ # We will only emit backward operations for nodes that can contribute
156
+ # to the specified loss value.
157
+ live_nodes = {loss_node: None}
158
+ val_to_grad: Dict[fx.Node, Optional[fx.Node]] = {loss_node: None}
159
+
160
+ def assign_or_accumulate_grad(forward_node, grad_value):
161
+ if forward_node in val_to_grad and forward_node.op != "placeholder":
162
+ grad_value = g.call_function(
163
+ _null_coalesce_accumulate,
164
+ (val_to_grad[forward_node], grad_value),
165
+ )
166
+ val_to_grad[forward_node] = grad_value
167
+
168
+ with g.inserting_before(output_node):
169
+ for node in reversed(g.nodes):
170
+ if node not in live_nodes:
171
+ continue
172
+
173
+ def add_to_live_nodes(n):
174
+ live_nodes.setdefault(n, None)
175
+
176
+ fx.node.map_arg(node.args, add_to_live_nodes)
177
+ fx.node.map_arg(node.kwargs, add_to_live_nodes)
178
+ if node.op == "call_module":
179
+ output_grads: Union[Tuple[Optional[fx.Node], ...], Optional[fx.Node]]
180
+ if node in tuples:
181
+ stage_output = tuples[node]
182
+ output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node])
183
+ outputs_with_grads_idxs = [
184
+ i for i, n in enumerate(tuples[node]) if n in live_nodes
185
+ ]
186
+ else:
187
+ stage_output = (node,)
188
+ output_grads = val_to_grad[node]
189
+ outputs_with_grads_idxs = [0]
190
+
191
+ output_grads = (
192
+ (output_grads,)
193
+ if not isinstance(output_grads, tuple)
194
+ else output_grads
195
+ )
196
+
197
+ grad_call = g.call_function(
198
+ stage_backward,
199
+ kwargs={
200
+ "stage_output": stage_output,
201
+ "output_grads": output_grads,
202
+ "input_values": list(node.all_input_nodes),
203
+ "outputs_with_grads_idxs": outputs_with_grads_idxs,
204
+ },
205
+ )
206
+ # Insert backward stage debug info
207
+ kwargs_copy = dict(grad_call.kwargs)
208
+ grad_call.kwargs = kwargs_copy
209
+
210
+ grad_call_proxy = fx.Proxy(grad_call)
211
+ grads = grad_call_proxy.node
212
+
213
+ input_nodes = list(node.all_input_nodes)
214
+ grads_proxy = fx.Proxy(grads)
215
+ for i, input_node in enumerate(input_nodes):
216
+ assign_or_accumulate_grad(input_node, grads_proxy[i].node) # type: ignore[index]
217
+
218
+ return g
219
+
220
+
221
+ class PipeSequential(torch.nn.Sequential):
222
+ @staticmethod
223
+ def from_sequential(sequential_instance: torch.nn.Sequential):
224
+ return PipeSequential(*[copy.copy(m) for m in sequential_instance])
225
+
226
+ def forward(self, input):
227
+ for i, module in enumerate(self):
228
+ input = module(input)
229
+ if i != len(self) - 1:
230
+ pipe_split()
231
+ return input
232
+
233
+
234
+ class LossWrapper(torch.nn.Module):
235
+ """
236
+ LossWrapper is a convenient abstract class that allows you to wrap up both
237
+ your model as well as its loss function and specify the connectivity between
238
+ the inputs, model, loss function, and output value. Example::
239
+
240
+ class MyModelWrapper(LossWrapper):
241
+ def forward(self, x, targets):
242
+ model_out = self.module(x)
243
+ loss_value = self.loss_fn(model_out, targets)
244
+ return loss_value
245
+
246
+ The above example defines a connectivity where we expect the forward/loss/backward
247
+ training procedure to take two arguments (x and targets), pass x into the module
248
+ to get the output of the feedforward computation, pass the model output and the
249
+ targets value into the loss function, and get and return the loss value, which will
250
+ be backpropagated by PiPPy. The above class would then be instantiated like::
251
+
252
+ model = ... # instantiate the model
253
+ loss_fn = torch.nn.MSELoss() # for the sake of demonstration
254
+
255
+ wrapper = MyModelWrapper(model, loss_fn)
256
+ pipe = Pipe.from_tracing(wrapper, ...)
257
+
258
+ """
259
+
260
+ def __init__(self, module, loss_fn):
261
+ super().__init__()
262
+ self.module = module
263
+ self.loss_fn = loss_fn
264
+
265
+ def forward(self, *args, **kwargs):
266
+ raise NotImplementedError(
267
+ "This instance of LossWrapper does not have an overridden"
268
+ "forward(). Please implement forward() to specify the arguments, "
269
+ "connection between the module and loss, and loss output "
270
+ "value."
271
+ )
272
+
273
+
274
+ class TrivialLossWrapper(LossWrapper):
275
+ def forward(self, x, targets):
276
+ model_out = self.module(x)
277
+ return self.loss_fn(model_out, targets)
278
+
279
+ loss_spec = True
280
+
281
+
282
+ # Pipe model representation
283
+ #
284
+ # Pipe can be thought of as an `nn.Sequential++`. That is to say: it specifies
285
+ # a single topological ordering of pipeline "stages" that, when run in series,
286
+ # constitutes all of the operations of the program. However, unlike `nn.Sequential`,
287
+ # Pipe allows non-local usages of values, so long as those uses still respect
288
+ # topological ordering. In particular:
289
+ #
290
+ # 1. Non-local activations. This type of usage can appear in, for example, skip
291
+ # connections. These values will be directly transmitted from the "def" stage
292
+ # to all stages that use them skipping intermediate stages. During autograd,
293
+ # gradients will be propagated back through this skip connection reverse
294
+ # to how activations propagated in the forward pass.
295
+ # 2. Non-local parameter/module invocations. This occurs when a parameter is used
296
+ # in a stage downstream of where it is resident. These values can be carried
297
+ # forward similarly to (1), but in addition one might want to replicate the
298
+ # value on multiple stages. Gradients for these shared parameters will be
299
+ # accumulated separately on each stage, but there will be an additional
300
+ # gradient accumulation before the optimizer step.
301
+
302
+
303
+ # Register `_pipe_split()` as an ATen operator. This is required for Export to
304
+ # preserve this marker in the graph.
305
+ torch.library.define("pippy::_pipe_split", "() -> ()")
306
+
307
+
308
+ @torch.library.impl("pippy::_pipe_split", "BackendSelect")
309
+ def _pipe_split():
310
+ return None
311
+
312
+
313
+ @torch.library.register_fake("pippy::_pipe_split") # type: ignore[no-redef]
314
+ def _pipe_split(): # noqa: F811
315
+ return None
316
+
317
+
318
+ # Add an alias for convenience
319
+ aten_pipe_split_alias = torch.ops.pippy._pipe_split.default
320
+
321
+ # Ask Export to preserve the `_pipe_split` op.
322
+ # See examples in pytorch/torch/fx/node.py
323
+ fx.node._side_effectful_functions.add(aten_pipe_split_alias)
324
+
325
+
326
+ # User facing API
327
+ def pipe_split():
328
+ """
329
+ pipe_split is a special operator that is used to mark the boundary between
330
+ stages in a module. It is used to split the module into stages. It is a
331
+ no-op if your annotated module is run eagerly.
332
+
333
+ Example:
334
+ >>> # xdoctest: +SKIP
335
+ >>> def forward(self, x):
336
+ >>> x = torch.mm(x, self.mm_param)
337
+ >>> x = torch.relu(x)
338
+ >>> pipe_split()
339
+ >>> x = self.lin(x)
340
+ >>> return x
341
+
342
+ The above example will be split into two stages.
343
+ """
344
+ return torch.ops.pippy._pipe_split()
345
+
346
+
347
+ class MultiUseParameterConfig(Enum):
348
+ TRANSMIT = 1
349
+ REPLICATE = 2
350
+
351
+
352
+ MultiUseParamSpec = Union[MultiUseParameterConfig, Dict[str, MultiUseParameterConfig]]
353
+
354
+
355
+ class DetachExecutor(fx.Interpreter):
356
+ """
357
+ Special interpreter to run the split_gm in testing that detaches all inputs to
358
+ a module invocation. This is needed so that the values at the boundary are
359
+ leaf modules in autograd execution.
360
+ """
361
+
362
+ def __init__(self, module, garbage_collect_values=True):
363
+ garbage_collect_values = False
364
+ super().__init__(module, garbage_collect_values)
365
+ self.value_remap = {}
366
+
367
+ def run(self, *args, initial_env=None):
368
+ self.value_remap = {}
369
+ return super().run(*args, initial_env=initial_env)
370
+
371
+ def call_module(self, target, args, kwargs):
372
+ def detach_tensors(a):
373
+ if isinstance(a, torch.Tensor) and a.requires_grad:
374
+ if a not in self.value_remap:
375
+ new_val = a.detach().requires_grad_(True)
376
+ self.value_remap[a] = new_val
377
+ return self.value_remap[a]
378
+ else:
379
+ return a
380
+
381
+ """
382
+ def dont_traverse_size(a):
383
+ return type(a) != torch.Size
384
+ """
385
+
386
+ args = map_aggregate(
387
+ args,
388
+ detach_tensors, # dont_traverse_size
389
+ )
390
+ kwargs = map_aggregate(
391
+ kwargs,
392
+ detach_tensors, # dont_traverse_size
393
+ )
394
+
395
+ return super().call_module(target, args, kwargs)
396
+
397
+ def call_function(self, target, args, kwargs):
398
+ # HACK to reroute saved input tensors to point to the detach()ed version
399
+ if target == stage_backward:
400
+ kwargs = dict(kwargs)
401
+ kwargs["input_values"] = [
402
+ self.value_remap.get(v, v) for v in kwargs["input_values"]
403
+ ]
404
+ return super().call_function(target, args, kwargs)
405
+
406
+
407
+ class _NodeReference:
408
+ def __init__(self, name):
409
+ self.name = name
410
+
411
+ name: str
412
+
413
+
414
+ class _LinearNodeList:
415
+ def __init__(self, node_list):
416
+ self.serialize_node_list = []
417
+ for node in node_list:
418
+ node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value]
419
+ node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value]
420
+ serialize_node = fx.Node(
421
+ graph=None, # type: ignore[arg-type]
422
+ name=node.name,
423
+ op=node.op,
424
+ target=node.target,
425
+ args=node_args, # type: ignore[arg-type]
426
+ kwargs=node_kwargs, # type: ignore[arg-type]
427
+ return_type=node.type,
428
+ )
429
+ serialize_node.meta = copy.copy(node.meta)
430
+ self.serialize_node_list.append(serialize_node)
431
+
432
+ def to_graph(self):
433
+ graph = fx.Graph()
434
+
435
+ ref_str_to_node: Dict[str, fx.Node] = {}
436
+
437
+ def ref_to_node(arg):
438
+ if isinstance(arg, _NodeReference):
439
+ return ref_str_to_node[arg.name]
440
+ else:
441
+ return arg
442
+
443
+ for node in self.serialize_node_list:
444
+ node_args = map_aggregate(node.args, ref_to_node)
445
+ node_kwargs = map_aggregate(node.kwargs, ref_to_node)
446
+ deser_node = graph.create_node(
447
+ op=node.op,
448
+ target=node.target,
449
+ args=node_args, # type: ignore[arg-type]
450
+ kwargs=node_kwargs, # type: ignore[arg-type]
451
+ name=node.name,
452
+ type_expr=node.type,
453
+ )
454
+ ref_str_to_node[node.name] = deser_node
455
+
456
+ return graph
457
+
458
+
459
+ def _direct_serialization_deserialize(body, nodes):
460
+ """
461
+ Custom `__reduce__` method for serialization.
462
+ DO AS I SAY -- NOT AS I DO. This violates the principle that
463
+ GraphModules serialize via code export & re-tracing. We allow
464
+ for this here because **PIPE STAGES SHOULD NOT BE PERSISTED
465
+ TO DISK -- THIS IS ONLY FOR TRANSMISSION VIA RPC**. Persisting
466
+ these instances to disk will expose internal implementation
467
+ details of `fx.Graph` and related data structures and is
468
+ NOT advised.
469
+ """
470
+
471
+ class DummyModule(torch.nn.Module):
472
+ def __init__(self, body):
473
+ super().__init__()
474
+ self.__dict__.update(body)
475
+
476
+ dummy = DummyModule(body)
477
+
478
+ return fx.GraphModule(dummy, nodes.to_graph())
479
+
480
+
481
+ def _direct_serialization_reduce(self):
482
+ serialization_dict = dict(self.__dict__)
483
+ serialization_dict.pop("_graph")
484
+ return (
485
+ _direct_serialization_deserialize,
486
+ (serialization_dict, _LinearNodeList(self.graph.nodes)),
487
+ )
488
+
489
+
490
+ def _modify_graph_op_device(
491
+ gm: torch.fx.GraphModule,
492
+ new_device: torch.device,
493
+ ):
494
+ """
495
+ Modify the device argument of all "call_function" nodes in the graph. This
496
+ is useful for moving the graph to a different device. In particular for
497
+ generator ops, like torch.ones.
498
+ """
499
+ modified = False
500
+ for node in gm.graph.nodes:
501
+ if node.op == "call_function":
502
+ if "device" in node.kwargs and node.kwargs["device"] != new_device:
503
+ logger.debug(
504
+ f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" # noqa: G004
505
+ )
506
+ node.update_kwarg("device", new_device)
507
+ modified = True
508
+ elif node.op == "call_module":
509
+ # Recursively modify "device" in submodules
510
+ submod = gm.get_submodule(node.target)
511
+ if isinstance(submod, torch.fx.GraphModule):
512
+ _modify_graph_op_device(submod, new_device)
513
+ elif isinstance(submod, InterpreterModule):
514
+ # If unflattening has been performed, we need to access its graph module by `.graph_module`
515
+ _modify_graph_op_device(submod.graph_module, new_device)
516
+ else:
517
+ logger.warning(
518
+ f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" # noqa: G004
519
+ )
520
+
521
+ if modified:
522
+ gm.recompile()
523
+
524
+
525
+ class Pipe(torch.nn.Module):
526
+ def __init__(
527
+ self,
528
+ split_gm: fx.GraphModule,
529
+ num_stages: int,
530
+ has_loss_and_backward: bool,
531
+ loss_spec,
532
+ ):
533
+ # TODO: is there a way not to hard wire init?
534
+ torch.nn.Module.__init__(self)
535
+ self.split_gm: fx.GraphModule = split_gm
536
+ self.executor: DetachExecutor = DetachExecutor(self.split_gm)
537
+ self.num_stages: int = num_stages
538
+ self.has_loss_and_backward = has_loss_and_backward
539
+ self.loss_spec = loss_spec
540
+
541
+ for node in split_gm.graph.nodes:
542
+ assert (
543
+ node.op in {"call_module", "placeholder", "output"}
544
+ or (node.op, node.target) == ("call_function", operator.getitem)
545
+ or (node.op, node.target) == ("call_method", "backward")
546
+ or (node.op, node.target) == ("call_function", stage_backward)
547
+ or (node.op, node.target)
548
+ == ("call_function", _null_coalesce_accumulate)
549
+ ), node
550
+
551
+ # Detect replicated parameters so we know that we have to do an additional allreduce
552
+ # before applying the optimizer
553
+ #
554
+ # Note that this also handles the case where there were multiple calls to a single
555
+ # module from different stages, regardless of whether that module invocation
556
+ # was handled by the logic above.
557
+
558
+ # Map parameter value to a dictionary that maps the user pipeline module
559
+ # to the local qualname within that module
560
+ params_to_users: Dict[torch.nn.Parameter, Dict[str, str]] = {}
561
+
562
+ for m_qualname, mod in self.split_gm.named_children():
563
+ for p_qualname, param in mod.named_parameters():
564
+ params_to_users.setdefault(param, {})
565
+ params_to_users[param][m_qualname] = p_qualname
566
+
567
+ self.replicated_params: List[Dict[str, str]] = [
568
+ use_mapping
569
+ for _, use_mapping in params_to_users.items()
570
+ if len(use_mapping) > 1
571
+ ]
572
+
573
+ # We must break the aliasing relationship between the replicated parameters for correct
574
+ # numerics in reference runs. If we do not do this, the autograd tape in separate stages
575
+ # will have a reference to the same tensor value and will erroneously apply gradient
576
+ # updates multiple times. Therefore, for each replicated parameter set, we deepcopy the
577
+ # values so that we have separate instances.
578
+ for param_mapping in self.replicated_params:
579
+ for submod_name, param_qualname in param_mapping.items():
580
+ submod = getattr(self.split_gm, submod_name)
581
+ atoms = param_qualname.split(".")
582
+ for atom in atoms[:-1]:
583
+ submod = getattr(submod, atom)
584
+ setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1])))
585
+
586
+ def throw(self, *args, **kwargs):
587
+ raise RuntimeError(
588
+ "To run pipeline locally, invoke the Pipe object directly, not `split_gm`"
589
+ )
590
+
591
+ self.split_gm.forward = throw
592
+
593
+ # Make submodules use custom direct-serialized GraphModule
594
+ i = 0
595
+ while True:
596
+ try:
597
+ name = f"submod_{i}"
598
+ submod = getattr(self.split_gm, name)
599
+ submod.__class__.__reduce__ = _direct_serialization_reduce
600
+ i += 1
601
+ except AttributeError:
602
+ break
603
+
604
+ def forward(self, *args, **kwargs):
605
+ executor_args = args
606
+ if len(kwargs) > 0:
607
+ parameters = []
608
+ for node in self.split_gm.graph.nodes:
609
+ if node.op == "placeholder":
610
+ if node.args and len(node.args) > 0:
611
+ parameters.append(
612
+ Parameter(
613
+ node.target,
614
+ Parameter.POSITIONAL_OR_KEYWORD,
615
+ default=node.args[0],
616
+ )
617
+ )
618
+ else:
619
+ parameter_kind = Parameter.POSITIONAL_OR_KEYWORD
620
+ param_name = node.target
621
+ if node.target.startswith("**"):
622
+ parameter_kind = Parameter.VAR_KEYWORD # type: ignore[assignment]
623
+ param_name = param_name[2:]
624
+ elif node.target.startswith("*"):
625
+ parameter_kind = Parameter.VAR_POSITIONAL # type: ignore[assignment]
626
+ param_name = param_name[1:]
627
+ parameters.append(Parameter(param_name, parameter_kind))
628
+ signature = Signature(parameters)
629
+ ba = signature.bind(*args, **kwargs)
630
+ ba.apply_defaults()
631
+ executor_args = ba.arguments.values() # type: ignore[assignment]
632
+
633
+ res = self.executor.run(*executor_args)
634
+
635
+ return res
636
+
637
+ def get_stage_module(self, stage_idx: int) -> torch.nn.Module:
638
+ """
639
+ Return a stage module corresponding to `stage_idx` of the `pipe`.
640
+ """
641
+ if stage_idx < 0 or stage_idx >= self.num_stages:
642
+ raise ValueError(f"Invalid stage index {stage_idx}!")
643
+ return getattr(self.split_gm, f"submod_{stage_idx}")
644
+
645
+ @staticmethod
646
+ def _number_and_count_forward_stages(gm: fx.GraphModule):
647
+ num_stages = 0
648
+ found_idxs: Dict[int, None] = {}
649
+ for node in gm.graph.nodes:
650
+ if node.op == "call_module" and node.target.startswith("submod_"):
651
+ node.meta["stage_idx"] = int(node.target[len("submod_") :])
652
+ found_idxs.setdefault(node.meta["stage_idx"])
653
+ num_stages += 1
654
+
655
+ # this assert will fail if a split point is inserted before the first layer, which creates empty first submodule
656
+ # Update: the following assert may fail against some torch versions >=
657
+ # 2.2.0, as:
658
+ # submod_0, submod_1, submod_2, ...
659
+ # may be named as
660
+ # submod_0, submod_2, submod_4, ...
661
+ # TODO: investigate
662
+ # assert all(i in found_idxs for i in range(num_stages))
663
+
664
+ return num_stages
665
+
666
+ @staticmethod
667
+ def _from_traced(
668
+ mod: torch.nn.Module,
669
+ exported_program: ExportedProgram,
670
+ multi_use_param_spec: Optional[MultiUseParamSpec] = None,
671
+ output_loss_value_spec=None,
672
+ split_policy: Optional[
673
+ Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
674
+ ] = None,
675
+ ):
676
+ """
677
+ Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate
678
+ which value in the output of `forward` is the loss value on which PiPPy should apply
679
+ backpropagation. For example, if your ``forward`` returns a tuple ``(loss, model_out)``,
680
+ you can specify ``output_loss_value_spec=(True, False)``. Or, if your ``forward`` returns
681
+ a dict ``{'loss': loss_value, 'model_out': model_out}``, you can specify
682
+ ``output_loss_value_spec={'loss': True, 'model_out': False}``
683
+ """
684
+
685
+ traced = exported_program.module()
686
+
687
+ if split_policy is not None:
688
+ logger.info("Auto-splitting model")
689
+ traced = split_policy(traced) # type: ignore[arg-type]
690
+
691
+ logger.debug(traced.print_readable(print_output=False))
692
+
693
+ # Deduplicate `get_attr` nodes that refer to the same parameter . Downstream code for moving
694
+ # parameters relies on the invariant that parameter accesses happen once. This is not necessarily
695
+ # the case (especially with custom tracers), so fix that up here.
696
+ get_attr_nodes: Dict[str, fx.Node] = {}
697
+ for node in traced.graph.nodes:
698
+ if node.op == "get_attr":
699
+ get_attr_nodes.setdefault(node.target, node)
700
+
701
+ if get_attr_nodes[node.target] != node:
702
+ node.replace_all_uses_with(get_attr_nodes[node.target])
703
+ traced.graph.erase_node(node)
704
+
705
+ # avoid looking at next node by keeping track of previous pipe_split
706
+ prev_pipe_split_idx = -1
707
+ pipe_split_nodes_to_erase = set()
708
+ for i, node in enumerate(traced.graph.nodes):
709
+ if (node.op, node.target) == ("call_function", pipe_split):
710
+ if prev_pipe_split_idx == i - 1:
711
+ pipe_split_nodes_to_erase.add(node)
712
+ prev_pipe_split_idx = i
713
+
714
+ for node in pipe_split_nodes_to_erase:
715
+ traced.graph.erase_node(node)
716
+
717
+ traced.recompile()
718
+
719
+ part_idx = 0
720
+
721
+ def split_callback(n: fx.Node):
722
+ nonlocal part_idx
723
+ if (n.op, n.target) == (
724
+ "call_function",
725
+ aten_pipe_split_alias,
726
+ ):
727
+ logger.debug(f"Found pipe_split {part_idx}") # noqa: G004
728
+ part_idx += 1
729
+ return part_idx
730
+
731
+ # TODO: what does split do with module invocations? does it move the modules
732
+ # into the submodules?
733
+ split = split_module(traced, mod, split_callback) # type: ignore[arg-type]
734
+ # a (custom) tracer can produce dead code like orphan get_attr nodes
735
+ split.graph.eliminate_dead_code()
736
+
737
+ # peephole to remove pipe_split
738
+ for submodule in split.modules():
739
+ if isinstance(submodule, fx.GraphModule):
740
+ for node in submodule.graph.nodes:
741
+ if (node.op, node.target) == (
742
+ "call_function",
743
+ aten_pipe_split_alias,
744
+ ):
745
+ submodule.graph.erase_node(node)
746
+ submodule.recompile()
747
+
748
+ for name, submodule in split.named_children():
749
+ if isinstance(submodule, fx.GraphModule):
750
+ new_submod = _outline_submodules(submodule.graph)
751
+ # Replace old submod
752
+ split.register_module(name, new_submod)
753
+
754
+ # TODO: backport this into split_module
755
+ def delete_user_reference(node, user):
756
+ """
757
+ Delete reference of `node` from `user`'s arg list.
758
+ Args:
759
+ - node: a `get_attr` node at root.
760
+ - user: a submodule node that uses `node`.
761
+ """
762
+ assert len(user.kwargs) == 0
763
+ use_idxs = [i for i, arg in enumerate(user.args) if arg == node]
764
+ assert len(use_idxs) == 1
765
+ args_copy = list(user.args)
766
+ args_copy.pop(use_idxs[0])
767
+ user.args = tuple(args_copy)
768
+ logger.debug(
769
+ f"Deleted {node} from user {user}, arg index = {use_idxs[0]}" # noqa: G004
770
+ )
771
+
772
+ # A list of param referrals for deferred deletion.
773
+ # To be accumulated in `move_param_to_callee`.
774
+ to_delete = []
775
+
776
+ def _recursive_getattr_with_parent(mod, fqn):
777
+ # Returns getattr call given a nested FQN, and the last parent
778
+ atoms = fqn.split(".")
779
+ for atom in atoms[:-1]:
780
+ if not hasattr(mod, atom):
781
+ return None, None
782
+ mod = getattr(mod, atom)
783
+ if not hasattr(mod, atoms[-1]):
784
+ return mod, None
785
+ attr = getattr(mod, atoms[-1])
786
+ return mod, attr
787
+
788
+ def move_param_to_callee(
789
+ root,
790
+ callee_name,
791
+ param_fqn,
792
+ ):
793
+ """
794
+ Move a parameter from the root module to a submodule.
795
+ Args:
796
+ root: The root module.
797
+ callee_name: The name of the submodule to move the parameter to.
798
+ param_fqn: The fully qualified name of the parameter to move.
799
+ """
800
+ # `atoms` is a list of strings representing the path to the
801
+ # parameter in the original model
802
+ atoms = param_fqn.split(".")
803
+ mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn)
804
+ # Check whether the parameter is a buffer or a parameter
805
+ is_buffer = atoms[-1] in mod_itr._buffers
806
+
807
+ # Check whether the parameter is a tensor
808
+ assert isinstance(param_val, torch.Tensor), (
809
+ f"Expected '{param_fqn}' to be {torch.Tensor} but got {type(param_val)}."
810
+ + (
811
+ f" It might happen if module '{param_fqn}' was passed to some 'leaf function'"
812
+ f"(see https://pytorch.org/docs/stable/fx.html#fx.wrap). Please inspect "
813
+ f"usages of '{param_fqn}' in the traced graph."
814
+ if isinstance(param_val, torch.nn.Module)
815
+ else ""
816
+ )
817
+ )
818
+
819
+ # Get submodule
820
+ callee = root.get_submodule(callee_name)
821
+ assert not hasattr(
822
+ callee, param_fqn
823
+ ), f"Module {callee_name} already has a parameter named {param_fqn}"
824
+
825
+ # Assign the parameter to the submodule
826
+ if is_buffer:
827
+ _assign_attr(
828
+ param_val,
829
+ callee,
830
+ param_fqn,
831
+ attr_kind=_AttrKind.BUFFER,
832
+ persistent=True, # TODO: handle non-persistent buffer
833
+ )
834
+ else:
835
+ _assign_attr(
836
+ param_val,
837
+ callee,
838
+ param_fqn,
839
+ attr_kind=_AttrKind.PARAMETER,
840
+ )
841
+ logger.debug(f"Moved parameter {param_fqn} to {callee_name}") # noqa: G004
842
+
843
+ # Next step is to replace placeholder of submodule with a get_attr.
844
+ # Those placeholders are created by `split_module` inside each
845
+ # submodule.
846
+ # Update: this step is now moved to `_sink_params` because
847
+ # `_sink_params` can do it recursively (i.e. for modules inside
848
+ # submodule)
849
+
850
+ to_delete.append((mod_itr, atoms[-1]))
851
+
852
+ # Get the list of all parameters in the root module
853
+ attr_nodes = list(filter(lambda n: n.op == "get_attr", split.graph.nodes))
854
+ for node in attr_nodes:
855
+ # Check whether the parameter is used in only one submodule
856
+ if len(node.users) > 1:
857
+ logger.info(
858
+ f"Parameter {node.target} used in multiple stages: {node.users}." # noqa: G004
859
+ )
860
+ for user in node.users:
861
+ assert user.op == "call_module"
862
+ # Move parameter into submodule
863
+ move_param_to_callee(
864
+ split,
865
+ user.target,
866
+ node.target,
867
+ )
868
+
869
+ # [aliasing] store tensor id -> list of FQNs, built from state dict
870
+ # Also assign non-persistent buffers
871
+ id_to_fqns: Dict[int, Set[str]] = defaultdict(set)
872
+ for fqn, tensor in mod.state_dict(keep_vars=True).items():
873
+ id_to_fqns[id(tensor)].add(fqn)
874
+ for fqn, tensor in mod.named_buffers():
875
+ id_to_fqns[id(tensor)].add(fqn)
876
+
877
+ # After moving the params to their corresponding hierarchies, we also
878
+ # need to move the `get_attr` nodes from the root of the graph to those
879
+ # hierarchies.
880
+ # [aliasing] use id -> fqn mapping to list out all valid FQNs
881
+ inputs_to_state: Dict[str, List[str]] = {}
882
+ for attr in attr_nodes:
883
+ _, tensor = _recursive_getattr_with_parent(mod, attr.target)
884
+ fqns = list(id_to_fqns[id(tensor)])
885
+ if fqns:
886
+ inputs_to_state[attr.name] = fqns
887
+ elif attr.target in exported_program.constants: # lifted constants
888
+ inputs_to_state[attr.name] = [attr.target]
889
+
890
+ # [aliasing] for each submodule split, assign attributes on FQNs that may be used.
891
+ # We determine this based on whether or not the FQN attribute parent exists.
892
+ # i.e. if the last submodule exists, assign the attribute.
893
+ added_attributes: Dict[str, List[str]] = defaultdict(list)
894
+ for fqn, tensor in mod.state_dict(keep_vars=True).items():
895
+ for name, submod in split.named_children():
896
+ if isinstance(submod, fx.GraphModule):
897
+ parent, child = _recursive_getattr_with_parent(submod, fqn)
898
+ if (
899
+ parent and child is None
900
+ ): # parent exists, attribute doesn't -> assign
901
+ added_attributes[name].append(fqn)
902
+ setattr(parent, fqn.split(".")[-1], tensor)
903
+
904
+ # Deferral deletion: Remove the original attributes (to params) from the
905
+ # root GraphModule
906
+ for mod_itr, last_atom in to_delete:
907
+ try:
908
+ delattr(mod_itr, last_atom)
909
+ except AttributeError:
910
+ # This is expected if the parameter is used in multiple stages
911
+ pass
912
+
913
+ # This is done by (1) `_sink_params` at each submodule;
914
+ for name, submod in split.named_children():
915
+ if isinstance(submod, fx.GraphModule):
916
+ _sink_params(submod, inputs_to_state, [])
917
+ submod.graph.lint()
918
+ submod.recompile()
919
+
920
+ # [aliasing] This step is not super necessary, but helps reduce parameter usage/memory.
921
+ # After _sink_params() routine has run, clean up unused attributes that we previously added.
922
+ # Determine this based on the get_attr nodes - if not used, remove it.
923
+ for name, attributes in added_attributes.items():
924
+ submod = getattr(split, name)
925
+ unused_attributes = set(attributes)
926
+ # track used attributes in the submodule, running DFS on subgraph hierarchy
927
+ stack = [("", submod)] # (scope, submodule)
928
+ while stack:
929
+ scope, _mod = stack.pop()
930
+ if isinstance(_mod, (fx.GraphModule, InterpreterModule)):
931
+ for node in _mod.graph.nodes:
932
+ if node.op == "get_attr":
933
+ # get_attr might get access deeper level attribute
934
+ fqn = scope + "." + node.target if scope else node.target
935
+ if fqn in unused_attributes: # used, remove it
936
+ unused_attributes.remove(fqn)
937
+ for _name, _submod in _mod.named_children():
938
+ stack.append((scope + "." + _name if scope else _name, _submod))
939
+ # delete unused attributes
940
+ for attr in unused_attributes:
941
+ mod_itr, atoms = submod, attr.split(".")
942
+ for atom in atoms[:-1]:
943
+ mod_itr = getattr(mod_itr, atom)
944
+ delattr(mod_itr, atoms[-1])
945
+
946
+ for node in attr_nodes:
947
+ # And (2): remove `get_attr` node from submod's arg list
948
+ for user in copy.copy(node.users):
949
+ assert user.op == "call_module"
950
+ delete_user_reference(node, user)
951
+ # And (3): remove the `get_attr` node from the root graph.
952
+ split.graph.erase_node(node)
953
+
954
+ split.delete_all_unused_submodules()
955
+ split.graph.lint()
956
+ split.recompile()
957
+
958
+ num_stages = Pipe._number_and_count_forward_stages(split)
959
+
960
+ has_loss_and_backward = False
961
+ generated_loss_spec = output_loss_value_spec
962
+
963
+ if output_loss_value_spec is not None:
964
+ loss_node, output_node, generated_loss_spec = _find_loss_output(
965
+ mod, split.graph, output_loss_value_spec
966
+ )
967
+ if loss_node is not None:
968
+ _insert_stage_symbolic_backward(
969
+ split.graph,
970
+ loss_node,
971
+ output_node,
972
+ )
973
+ split.recompile()
974
+ has_loss_and_backward = True
975
+ logger.debug("Pipeline is in training mode, backward pass generated")
976
+ else:
977
+ raise RuntimeError(
978
+ f"Did not find any loss value according to {output_loss_value_spec=}"
979
+ )
980
+ else:
981
+ logger.debug("Pipeline is in inference mode, backward pass not generated")
982
+
983
+ logger.debug("Full pipe model:\n" f"{split}") # noqa: G004
984
+
985
+ return Pipe(
986
+ split,
987
+ num_stages,
988
+ has_loss_and_backward,
989
+ generated_loss_spec,
990
+ )
991
+
992
+ def print_readable(self):
993
+ """
994
+ Print the pipe in a human-readable format.
995
+ This will print both the root pipe and each stage module.
996
+ """
997
+ self.split_gm.print_readable()
998
+
999
+ @staticmethod
1000
+ def _trace_with_export(
1001
+ mod: torch.nn.Module,
1002
+ example_args: Tuple[Any, ...],
1003
+ example_kwargs: Optional[Dict[str, Any]] = None,
1004
+ ) -> ExportedProgram:
1005
+ logger.info("Tracing model ...")
1006
+ try:
1007
+ ep = torch.export.export(
1008
+ mod,
1009
+ example_args,
1010
+ example_kwargs,
1011
+ )
1012
+ except Exception as e:
1013
+ raise RuntimeError(
1014
+ "It seems that we cannot capture your model as a full graph. "
1015
+ "Typical reasons include graph breaks, data/shape-dependent "
1016
+ "control flow, or missing meta kernels for custom operators. "
1017
+ "You can use our manual pipeline interfaces, or try to fix the "
1018
+ "graph breaks, see https://pytorch.org/docs/stable/export.html"
1019
+ ) from e
1020
+
1021
+ return ep
1022
+
1023
+ @staticmethod
1024
+ def from_tracing(
1025
+ mod: torch.nn.Module,
1026
+ example_args: Tuple[Any, ...],
1027
+ example_kwargs: Optional[Dict[str, Any]] = None,
1028
+ split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
1029
+ ):
1030
+ # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across
1031
+ # stages instead of TRANSMIT'ting it
1032
+ multi_use_param_spec = MultiUseParameterConfig.REPLICATE
1033
+
1034
+ # Figure out which output is loss from output_chunk_spec
1035
+ output_loss_value_spec: Any = None
1036
+ # Deprecated
1037
+ """
1038
+ if output_chunk_spec is not None:
1039
+ output_loss_value_spec = map_aggregate(
1040
+ output_chunk_spec, lambda v: isinstance(v, _LossReducer)
1041
+ )
1042
+ """
1043
+
1044
+ # Trace with export
1045
+ exported_program = Pipe._trace_with_export(
1046
+ mod,
1047
+ example_args,
1048
+ example_kwargs,
1049
+ )
1050
+
1051
+ pipe = Pipe._from_traced(
1052
+ mod,
1053
+ exported_program,
1054
+ multi_use_param_spec,
1055
+ output_loss_value_spec=output_loss_value_spec,
1056
+ split_policy=split_policy,
1057
+ )
1058
+
1059
+ # Users want the first pipeline stage to accept kwargs if the original
1060
+ # program does. This is controlled by the `_codegen` field of the graph,
1061
+ # so we make a copy here. Note: we only want the input spec and not the
1062
+ # output spec, because the output spec is for the last stage. Maybe a
1063
+ # TODO? Not sure yet.
1064
+ split = pipe.split_gm
1065
+ traced = exported_program.module()
1066
+ submod0 = next(iter(split.children()))
1067
+ submod0_sign = signature(submod0.forward)
1068
+ model_sign = signature(traced.forward)
1069
+ if len(model_sign.parameters) != len(submod0_sign.parameters):
1070
+ # We don't change the signature of the first stage if it takes
1071
+ # different number of args than original model
1072
+ logger.info(
1073
+ f"Original model takes {len(model_sign.parameters)} args but the " # noqa: G004
1074
+ f"first pipeline stage takes {len(submod0_sign.parameters)}. "
1075
+ "Please provide args to respective pipeline stages."
1076
+ )
1077
+ else:
1078
+ # Support kwargs for the first stage
1079
+ submod0.graph._codegen = copy.deepcopy(traced.graph._codegen)
1080
+ # `_replace` is actually not "private" or internal. based on this doc:
1081
+ # To prevent conflicts with field names, the method and attribute names
1082
+ # start with an underscore
1083
+ submod0.graph._codegen.pytree_info = (
1084
+ submod0.graph._codegen.pytree_info._replace(out_spec=None)
1085
+ )
1086
+ submod0.recompile()
1087
+
1088
+ return pipe
1089
+
1090
+ def __str__(self):
1091
+ return self.split_gm.__str__()
1092
+
1093
+ def __repr__(self):
1094
+ return self.split_gm.__repr__()
1095
+
1096
+ def info(self) -> PipeInfo:
1097
+ """
1098
+ Get information about the pipe.
1099
+
1100
+ Returns
1101
+ -------
1102
+ PipeInfo
1103
+ A dataclass containing information about the pipe.
1104
+ """
1105
+ return PipeInfo(
1106
+ graph=self.split_gm.graph,
1107
+ num_stages=self.num_stages,
1108
+ has_loss_and_backward=self.has_loss_and_backward,
1109
+ )
1110
+
1111
+ def build_stage(
1112
+ self,
1113
+ stage_index: int,
1114
+ device: torch.device,
1115
+ group: Optional[ProcessGroup] = None,
1116
+ ) -> _PipelineStage:
1117
+ """
1118
+ Create a `PipelineStage` given a stage index and distributed group.
1119
+ The `PipelineStage` can run with `PipelineSchedule`s.
1120
+ """
1121
+ # Find stage module
1122
+ stage_module = self.get_stage_module(stage_index)
1123
+
1124
+ # Move ops argument to device
1125
+ # Today PT2 tracer does not treat `x.device` as a symbolic device;
1126
+ # instead, the device of tracing time got burned into the generated
1127
+ # code. Here we provide a workaround for users to manually modify the
1128
+ # "device" kwarg of operations. Such operation may include:
1129
+ # `torch.ones`, `torch.zeros`, `torch.rand`, etc.
1130
+ if isinstance(stage_module, torch.fx.GraphModule):
1131
+ _modify_graph_op_device(stage_module, device)
1132
+ else:
1133
+ logger.warning(
1134
+ f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}" # noqa: G004
1135
+ )
1136
+
1137
+ # Detach pipe info
1138
+ # Note: be careful what's included in `pipe_info`. We don't want to keep
1139
+ # a reference to `Pipe` or `Pipe.split_gm` which stops python from
1140
+ # recycling them. When python recycles them, other stage modules (which
1141
+ # are irrelevant to current rank) can be automatically freed.
1142
+ pipe_info = self.info()
1143
+ return _PipelineStage(stage_module, stage_index, pipe_info, device, group)
1144
+
1145
+
1146
+ class SplitPoint(Enum):
1147
+ BEGINNING = 1
1148
+ END = 2
1149
+
1150
+
1151
+ # For backward compatibility, we kept the PipeSplitWrapper class because `class
1152
+ # SplitPoint` used to be defined in this class.
1153
+ class PipeSplitWrapper:
1154
+ # Create a class alias for BC
1155
+ SplitPoint = SplitPoint
1156
+
1157
+
1158
+ def _split_before_forward(self, *args, **kwargs):
1159
+ pipe_split()
1160
+ return self._orig_forward(*args, **kwargs)
1161
+
1162
+
1163
+ def _split_after_forward(self, *args, **kwargs):
1164
+ try:
1165
+ return self._orig_forward(*args, **kwargs)
1166
+ finally:
1167
+ pipe_split()
1168
+
1169
+
1170
+ def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):
1171
+ # TODO: make this implementation out-of-place?
1172
+ for qualname, split_type in spec.items():
1173
+ atoms = qualname.split(".")
1174
+ predecessor_module = mod
1175
+ for i, atom in enumerate(atoms[:-1]):
1176
+ try:
1177
+ predecessor_module = getattr(predecessor_module, atom)
1178
+ except AttributeError as e:
1179
+ raise AttributeError(
1180
+ f"Specified target {qualname} referenced "
1181
+ f'nonexistent module {".".join(atoms[: i + 1])}'
1182
+ ) from e
1183
+
1184
+ mod_to_wrap = getattr(predecessor_module, atoms[-1])
1185
+ mod_to_wrap._orig_forward = mod_to_wrap.forward
1186
+ if split_type == SplitPoint.BEGINNING:
1187
+ mod_to_wrap.forward = MethodType(_split_before_forward, mod_to_wrap)
1188
+ elif split_type == SplitPoint.END:
1189
+ mod_to_wrap.forward = MethodType(_split_after_forward, mod_to_wrap)
1190
+ else:
1191
+ raise ValueError("Unknown split point type.")
1192
+
1193
+
1194
+ def pipeline(
1195
+ module: torch.nn.Module,
1196
+ mb_args: Tuple[Any, ...],
1197
+ mb_kwargs: Optional[Dict[str, Any]] = None,
1198
+ split_spec: Optional[Dict[str, SplitPoint]] = None,
1199
+ split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
1200
+ ) -> Pipe:
1201
+ """
1202
+ Split a module based on a specification.
1203
+
1204
+ See `Pipe` for more details.
1205
+
1206
+ Arguments
1207
+ ---------
1208
+ module:
1209
+ The module to be splitted.
1210
+ mb_args:
1211
+ Example positional inputs, in micro-batch form.
1212
+ mb_kwargs:
1213
+ Example keyword inputs, in micro-batch form. (default: `None`)
1214
+ split_spec:
1215
+ A dictionary using submodule names as split marker. (default: `None`)
1216
+ split_policy:
1217
+ The policy to use for splitting the module. (default: `None`)
1218
+
1219
+ Returns
1220
+ -------
1221
+ A pipeline representation of class `Pipe`.
1222
+ """
1223
+ if split_spec is not None and split_policy is not None:
1224
+ raise ValueError(
1225
+ "Cannot specify both `split_spec` and `split_policy`. Please use only one of them."
1226
+ )
1227
+
1228
+ if split_spec is not None:
1229
+ # Annotate split points in the module based on user spec
1230
+ annotate_split_points(module, split_spec)
1231
+ return Pipe.from_tracing(
1232
+ mod=module,
1233
+ example_args=mb_args,
1234
+ example_kwargs=mb_kwargs,
1235
+ )
1236
+ else:
1237
+ # Use split policy
1238
+ return Pipe.from_tracing(
1239
+ mod=module,
1240
+ example_args=mb_args,
1241
+ example_kwargs=mb_kwargs,
1242
+ split_policy=split_policy,
1243
+ )
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ from ._IR import Pipe, pipe_split, pipeline, SplitPoint
3
+ from .schedules import (
4
+ _ScheduleForwardOnly,
5
+ Schedule1F1B,
6
+ ScheduleFlexibleInterleaved1F1B,
7
+ ScheduleGPipe,
8
+ ScheduleInterleaved1F1B,
9
+ ScheduleInterleavedZeroBubble,
10
+ ScheduleLoopedBFS,
11
+ )
12
+ from .stage import build_stage, PipelineStage
13
+
14
+
15
+ __all__ = [
16
+ "Pipe",
17
+ "pipe_split",
18
+ "SplitPoint",
19
+ "pipeline",
20
+ "PipelineStage",
21
+ "build_stage",
22
+ "Schedule1F1B",
23
+ "ScheduleFlexibleInterleaved1F1B",
24
+ "ScheduleGPipe",
25
+ "ScheduleInterleaved1F1B",
26
+ "ScheduleLoopedBFS",
27
+ "ScheduleInterleavedZeroBubble",
28
+ ]
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_IR.cpython-311.pyc ADDED
Binary file (54.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (842 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_backward.cpython-311.pyc ADDED
Binary file (16.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_debug.cpython-311.pyc ADDED
Binary file (1.08 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_unflatten.cpython-311.pyc ADDED
Binary file (1.26 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_utils.cpython-311.pyc ADDED
Binary file (4.77 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/microbatch.cpython-311.pyc ADDED
Binary file (16.3 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/schedules.cpython-311.pyc ADDED
Binary file (91.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/stage.cpython-311.pyc ADDED
Binary file (66 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_backward.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates
3
+ import collections
4
+ import logging
5
+ import weakref
6
+ from typing import Any, cast, Deque, Dict, Iterator, List, Optional, Set, Tuple, Union
7
+
8
+ import torch
9
+ from torch.autograd.graph import GradientEdge, Node
10
+ from torch.nn import Parameter
11
+
12
+ from ._debug import map_debug_info
13
+
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Union[Node, None]:
19
+ """
20
+ Get the grad function or grad accumulator for a tensor.
21
+
22
+ Accumulate grad nodes are lazily created, so we need to a
23
+ dummy view in order to trigger its creation.
24
+ """
25
+ if t.requires_grad and t.grad_fn is None:
26
+ # if no grad function (leaf tensors) we use view
27
+ viewed_t = t.view_as(t)
28
+ grad_fn = viewed_t.grad_fn
29
+ if grad_fn is not None:
30
+ return grad_fn.next_functions[0][0]
31
+ else:
32
+ raise RuntimeError(
33
+ "Attempted to get grad_fn, but got None."
34
+ "Is this being created in a no-grad context?"
35
+ )
36
+ else:
37
+ return t.grad_fn
38
+
39
+
40
+ def reverse_closure(
41
+ roots: List[Node], target_nodes: Set[Node]
42
+ ) -> Tuple[Set[Node], Set[Node]]:
43
+ """
44
+ This function returns the reverse closure of the given roots,
45
+ i.e. the set of nodes that can be reached from the roots by following the
46
+ reverse edges of the graph. The target_nodes are the nodes that we want to
47
+ include in the closure.
48
+ """
49
+ # Recurse until we reach a target node
50
+ closure: Set[Node] = set()
51
+ visited_target_nodes = set()
52
+ q: Deque[Node] = collections.deque()
53
+ for node in roots:
54
+ if node is not None and node not in closure:
55
+ closure.add(node)
56
+ q.append(node)
57
+ while q:
58
+ node = q.popleft()
59
+ metadata = cast(Dict[str, List], node.metadata)
60
+ reverse_edges = metadata.get("reverse_edges", [])
61
+ for holder_ref, idx in reverse_edges:
62
+ ref = holder_ref()
63
+ if ref is None:
64
+ # this reverse graph is no longer alive
65
+ # raise RuntimeError("Reverse graph is no longer alive")
66
+ continue
67
+ fn = ref.node
68
+ if fn in closure or fn is None:
69
+ continue
70
+ if fn in target_nodes:
71
+ visited_target_nodes.add(fn)
72
+ continue
73
+ closure.add(fn)
74
+ q.append(fn)
75
+ return closure, visited_target_nodes
76
+
77
+
78
+ # Enable weak pointer
79
+ class Holder:
80
+ def __init__(self, node: Node):
81
+ self.node = node
82
+
83
+
84
+ def construct_reverse_graph(roots: List[Node]) -> List[Holder]:
85
+ q: Deque[Node] = collections.deque()
86
+ root_seen: Set[Node] = set()
87
+ reverse_graph_refs: List[Holder] = []
88
+ for node in roots:
89
+ if node is not None and node not in root_seen:
90
+ q.append(node)
91
+ root_seen.add(node)
92
+ while q:
93
+ node = q.popleft()
94
+ for fn, idx in node.next_functions:
95
+ if fn is not None:
96
+ # Don't necessarily need to store on the graph
97
+ metadata = cast(Dict[str, List], fn.metadata)
98
+ reverse_edges = metadata.get("reverse_edges", [])
99
+ if len(reverse_edges) == 0:
100
+ q.append(fn)
101
+ holder = Holder(node)
102
+ holder_ref = weakref.ref(holder)
103
+ reverse_graph_refs.append(holder)
104
+ reverse_edges.append((holder_ref, idx))
105
+ metadata["reverse_edges"] = reverse_edges
106
+ return reverse_graph_refs
107
+
108
+
109
+ def get_param_groups(inputs: List[Node], params: List[Node]) -> List[Dict[str, Any]]:
110
+ """
111
+ Given a list of inputs and a list of parameters, return a list of parameter
112
+ groups, where each group contains the parameters and the intermediates that
113
+ are connected to the parameters.
114
+
115
+ The returned list of parameter groups is a list of dictionaries, where each
116
+ dictionary contains the following keys:
117
+ - "params": a set of parameters
118
+ - "intermediates": a set of intermediates
119
+
120
+ The returned list of parameter groups is a list of dictionaries,
121
+ """
122
+ # reverse graph that starts with inputs, and goes up to the dOutput or the loss,
123
+ # but omits weights and any subgraphs connecting weights to this closure
124
+ inputs_closure, _ = reverse_closure(inputs, set())
125
+ param_groups: Dict[Node, Dict[str, Set]] = dict() # keyed on intermediates
126
+ for i, param in enumerate(params):
127
+ closure, intersected = reverse_closure([param], inputs_closure)
128
+ param_group: Dict[str, Set] = {
129
+ "params": {param},
130
+ "intermediates": intersected,
131
+ }
132
+ for input_node in intersected:
133
+ existing = param_groups.get(input_node, None)
134
+ if existing is not None:
135
+ existing["params"] = existing["params"].union(param_group["params"])
136
+ existing["intermediates"] = existing["intermediates"].union(
137
+ param_group["intermediates"]
138
+ )
139
+ param_group = existing
140
+ else:
141
+ param_groups[input_node] = param_group
142
+
143
+ # Sanity check: union of all param_groups params should be equal to all params
144
+ union_params: Set[Node] = set()
145
+ seen_ids: Set[int] = set()
146
+ unique_param_groups = []
147
+ for param_group in param_groups.values():
148
+ if id(param_group) not in seen_ids:
149
+ seen_ids.add(id(param_group))
150
+ unique_param_groups.append(param_group)
151
+ union_params = union_params.union(param_group["params"])
152
+
153
+ # The assert will only be true if the input tensor requires gradients,
154
+ # otherwise the autograd graph will miss the first layer of inputs
155
+ # assert union_params == set(params)
156
+ return unique_param_groups
157
+
158
+
159
def stage_backward_input(
    stage_outputs: List[torch.Tensor],
    output_grads: Optional[List[torch.Tensor]],
    input_values: List[torch.Tensor],
    weights: Iterator[Parameter],
):
    """
    Compute the gradients for only the stage inputs with respect to the stage outputs.

    Weight gradients are deliberately NOT computed here; instead, prehooks are
    registered on the "intermediate" autograd nodes shared between the input
    closure and each weight's closure, so that the gradients flowing into those
    nodes are captured in ``param_groups[i]["grads"]`` during the backward call
    below. ``stage_backward_weight`` later replays from those captured grads.

    Returns:
        (dinputs, param_groups): ``dinputs`` is the tuple of input gradients
        (or None when not all inputs require grad); ``param_groups`` is the
        list of dicts produced by ``get_param_groups``, augmented with a
        "grads" entry filled in by the hooks.
    """
    # Collect the autograd graph entry nodes (grad_fn or AccumulateGrad) for
    # outputs, inputs, and weights; non-autograd tensors are filtered out.
    stage_output_grad_fns: List[Node] = list(
        filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs))
    )
    stage_input_grad_fns: List[Node] = list(
        filter(None, map(_get_grad_fn_or_grad_acc, input_values))
    )
    weight_grad_fns: List[Node] = list(
        filter(None, map(_get_grad_fn_or_grad_acc, weights))
    )

    # The reverse-graph references only need to stay alive while grouping.
    reverse_graph_refs = construct_reverse_graph(stage_output_grad_fns)
    param_groups = get_param_groups(stage_input_grad_fns, weight_grad_fns)
    del reverse_graph_refs

    for param_group in param_groups:
        for i, intermediate in enumerate(param_group["intermediates"]):

            def get_hook(param_group, i):
                # Factory binds `param_group` and `i` per iteration to avoid
                # the late-binding closure pitfall.
                def hook(grad_inputs):
                    if param_group.get("grads", None) is None:
                        param_group["grads"] = [None] * len(
                            param_group["intermediates"]
                        )
                    param_group["grads"][i] = grad_inputs

                return hook

            # These are always "split" nodes that we need to recompute, so
            # save their inputs.
            intermediate.register_prehook(get_hook(param_group, i))

    # Stage 0 inputs do not require grads? Should we skip in that case?
    if all(tensor.requires_grad for tensor in input_values):
        if output_grads is None:
            # In case this is the loss and there are no output_grads, then we just use 1s
            output_grads = [
                torch.ones_like(stage_output) for stage_output in stage_outputs
            ]

        # retain_graph=True: the graph is traversed again later by
        # stage_backward_weight via the captured intermediate grads.
        dinputs = torch.autograd.grad(
            stage_outputs,
            inputs=input_values,
            grad_outputs=output_grads,
            retain_graph=True,
        )

        # update the gradients for inputs (accumulate across microbatches)
        for i, inp in enumerate(input_values):
            if inp.grad is None:
                inp.grad = dinputs[i]
            else:
                inp.grad += dinputs[i]
    else:
        dinputs = None
    return dinputs, param_groups
223
+
224
+
225
def stage_backward_weight(
    weights: Iterator[Parameter], param_groups: List[Dict[str, Any]]
):
    """
    Compute and accumulate weight gradients from the intermediate gradients
    captured by the hooks registered in ``stage_backward_input``.

    ``param_groups`` must be the list returned by ``stage_backward_input``,
    with each group's "grads" entry already populated by a prior backward.
    """
    # map weights to param_group_weights
    grad_acc_to_weight = {}
    weight_grads = []
    for index, weight in enumerate(weights):
        grad_acc = _get_grad_fn_or_grad_acc(weight)
        grad_acc_to_weight[grad_acc] = weight, index
        weight_grads.append(weight.grad)
    # NOTE(review): weight_grads snapshots each weight.grad *before* the
    # accumulation below. Entries that were already tensors reflect the
    # in-place `+=` updates; entries that were None stay None even after a
    # fresh grad is assigned — confirm callers expect this.

    for param_group in param_groups:
        # TODO: Handle case where intermediate can have multiple outputs
        intermediate_edges = tuple(
            GradientEdge(i, 0) for i in param_group["intermediates"]
        )
        weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])

        assert all(len(g) == 1 for g in param_group["grads"])
        # [NEW!] Able to pass a GradientEdge to autograd.grad as output
        # We do not need to retain_graph because... guarantee no overlap?
        # print("trying to execute: ", intermediate_edges, weights_edges)
        # Each captured grad is a 1-tuple (asserted above); sum(..., tuple())
        # concatenates them into one flat grad_outputs tuple.
        dweights = torch.autograd.grad(
            intermediate_edges,
            weights_edges,
            grad_outputs=sum(param_group["grads"], tuple()),
        )
        for grad_acc, dw in zip(param_group["params"], dweights):
            weight, index = grad_acc_to_weight[grad_acc]
            if weight.grad is None:
                weight.grad = dw
            else:
                weight.grad += dw
    # return grads in the original order weights were provided in
    return weight_grads
260
+
261
+
262
def stage_backward(
    stage_output,
    output_grads,
    input_values,
    outputs_with_grads_idxs: Optional[List[int]] = None,  # deprecated, not used
):
    """
    This is a helper function to:
    1. compute the gradients for the stage inputs, and
    2. accumulate gradients for the stage module's parameters.

    Given the input value(s) and the corresponding gradient for the output
    value(s), compute and accumulate gradients for all parameter values (leaves
    in the autograd trace) as well as return a list of the gradients for the
    input values.

    ``stage_output``/``output_grads`` may be arbitrarily nested
    tuples/lists/dicts of tensors with matching structure; non-tensor leaves
    are ignored. On any failure the exception is re-raised wrapped in a
    RuntimeError carrying debug info for all three arguments.
    """
    if outputs_with_grads_idxs is not None:
        # Deprecated, not used in runtime calls, only exists in compiler
        stage_output = [stage_output[i] for i in outputs_with_grads_idxs]
        output_grads = [output_grads[i] for i in outputs_with_grads_idxs]

    try:
        # stage_output may be a composite datatype like dict. Extract all individual
        # tensor values here
        stage_output_tensors = []
        output_grad_tensors = []

        def extract_tensors_with_grads(output_val, grad_val):
            # Recursively walk output/grad in lockstep, appending matched
            # (tensor, grad) pairs to the two lists above.
            if isinstance(output_val, torch.Tensor):
                if not output_val.requires_grad and output_val.grad_fn is None:
                    return
                assert isinstance(
                    grad_val, (torch.Tensor, type(None))
                ), f"Expected Tensor or None gradient but got {type(grad_val)}"
                stage_output_tensors.append(output_val)
                output_grad_tensors.append(grad_val)
            elif isinstance(output_val, (tuple, list)):
                if grad_val is None:
                    return
                assert isinstance(
                    grad_val, (tuple, list)
                ), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
                assert len(output_val) == len(grad_val)
                for ov, gv in zip(output_val, grad_val):
                    extract_tensors_with_grads(ov, gv)
            elif isinstance(output_val, dict):
                if grad_val is None:
                    return
                assert isinstance(grad_val, dict)
                assert set(output_val.keys()) == set(grad_val.keys())
                for k in output_val.keys():
                    extract_tensors_with_grads(output_val[k], grad_val[k])
            else:
                # Output is a non-tensor type; just ignore it
                pass

        extract_tensors_with_grads(stage_output, output_grads)

        # backward() (unlike autograd.grad) accumulates into .grad of all
        # leaves, which is how parameter grads get updated here.
        torch.autograd.backward(
            stage_output_tensors, grad_tensors=output_grad_tensors  # type: ignore[arg-type]
        )

        # Extract gradients wrt the input values
        grad_inputs = []
        for val in input_values:
            if isinstance(val, torch.Tensor):
                grad_inputs.append(val.grad)
            else:
                grad_inputs.append(None)

        # Alternative impl: `torch.autograd.grad`.
        # Note that `torch.autograd.grad` will not accumulate gradients into the
        # model's parameters.
        """
        inputs_with_grad = []
        for val in input_values:
            if isinstance(val, torch.Tensor) and val.requires_grad:
                inputs_with_grad.append(val)

        grad_inputs = torch.autograd.grad(
            stage_output_tensors, inputs_with_grad, output_grad_tensors,  # type: ignore[arg-type]
        )
        """

    except Exception as e:
        exc_msg = f"""
        Failed to run stage backward:
        Stage output: {map_debug_info(stage_output)}
        Output gradient: {map_debug_info(output_grads)}
        Input: {map_debug_info(input_values)}
        """
        raise RuntimeError(exc_msg) from e

    return grad_inputs
356
+
357
+
358
+ # TODO: handling requires_grad=False dynamically. Can we analyze this during initial
359
+ # IR emission?
360
+ def _null_coalesce_accumulate(lhs, rhs):
361
+ """
362
+ Coalesce two values, even if one of them is null, returning the non-null
363
+ value.
364
+ """
365
+ if lhs is None:
366
+ return rhs
367
+ elif rhs is None:
368
+ return lhs
369
+ else:
370
+ return torch.add(lhs, rhs)
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_debug.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates
3
+ import torch
4
+
5
+
6
def friendly_debug_info(v):
    """
    Render *v* as a short, human-readable string for debug output.

    Tensors are summarized by shape / requires_grad / dtype rather than
    printing their full contents; every other value falls back to ``str``.
    """
    if not isinstance(v, torch.Tensor):
        return str(v)
    return f"Tensor({v.shape}, grad={v.requires_grad}, dtype={v.dtype})"
14
+
15
+
16
def map_debug_info(a):
    """
    Helper function to apply `friendly_debug_info` to items in `a`.
    `a` may be a list, tuple, or dict.
    """
    # map_aggregate walks the (possibly nested) aggregate and replaces each
    # leaf with its friendly_debug_info string, preserving the structure.
    return torch.fx.node.map_aggregate(a, friendly_debug_info)
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_unflatten.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates
3
+ from typing import Dict
4
+
5
+ import torch
6
+ from torch.export.unflatten import _ModuleFrame
7
+
8
+
9
def _outline_submodules(orig_graph: torch.fx.Graph):
    """
    Rebuild a hierarchical GraphModule from a flat fx graph by driving
    torch.export's private ``_ModuleFrame`` machinery over all nodes.

    Returns the new, linted and recompiled GraphModule.

    NOTE(review): ``_ModuleFrame`` is a private torch.export API; the
    positional arguments below mirror its signature at the pinned torch
    version and may break on upgrade — verify against
    torch/export/unflatten.py when bumping torch.
    """
    # Create an empty GraphModule to hold the outlined modules
    new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
    # Shared bookkeeping dicts mutated by _ModuleFrame during the walk.
    seen_nodes: Dict[str, torch.fx.Node] = {}
    seen_modules: Dict[int, torch.nn.Module] = {}
    _ModuleFrame(
        orig_graph,
        tuple(orig_graph.nodes),
        seen_nodes,
        seen_modules,
        None,
        [""],
        "",
        {},
        module=new_module,
    ).run_outer()
    new_module.graph.lint()
    new_module.recompile()
    return new_module
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_utils.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from typing import List, Tuple, Union
6
+
7
+ import torch
8
+ from torch import fx
9
+
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
def flatten_args_detach(args):
    """
    Flatten the args into a list form and detach the tensors from computational graph.

    Returns:
        (new_args, flat_detached_args): ``new_args`` mirrors the structure of
        ``args`` with every tensor leaf detached (requires_grad preserved),
        and ``flat_detached_args`` is the flat list of all leaves in
        traversal order.
    """
    collected = []

    def visit(leaf):
        # Detach so gradients cannot flow back into the caller's graph, but
        # keep requires_grad so a fresh autograd graph can start from here.
        if isinstance(leaf, torch.Tensor):
            leaf = leaf.detach().requires_grad_(leaf.requires_grad)
        collected.append(leaf)
        return leaf

    rebuilt = fx.node.map_aggregate(args, visit)
    return rebuilt, collected
36
+
37
+
38
def flatten_args(args):
    """
    Flatten the args into a list form.

    Every leaf of the (possibly nested) ``args`` aggregate is appended, in
    traversal order, to the returned list; ``args`` itself is not modified.
    """
    leaves = []

    def record(leaf):
        # Appending mutates the enclosing list; no nonlocal rebinding needed.
        leaves.append(leaf)
        return leaf

    fx.node.map_aggregate(args, record)
    return leaves
55
+
56
+
57
class PipeliningShapeError(RuntimeError):
    """Raised when a runtime tensor's metadata does not match its configured value."""
59
+
60
+
61
def validate_tensor_metadata(desc, expected, given):
    """
    Raise ``PipeliningShapeError`` if ``given`` does not match ``expected``.

    Compares shape, dtype, and stride, in that order; the first mismatch
    raises. ``desc`` is a human-readable label included in the message.

    Args:
        desc: label for error messages (e.g. which stage/value this is)
        expected: tensor carrying the configured metadata
        given: runtime tensor to validate

    Raises:
        PipeliningShapeError: on any shape/dtype/stride mismatch.
    """
    # Idiom fix: use `!=` directly instead of `not ... == ...`.
    if expected.shape != given.shape:
        raise PipeliningShapeError(
            f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}"
        )
    if expected.dtype != given.dtype:
        raise PipeliningShapeError(
            f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}"
        )
    if expected.stride() != given.stride():
        raise PipeliningShapeError(
            f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}"
        )
74
+
75
+
76
def validate_tensors_metadata(
    desc,
    expected_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]],
    actual_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]],
):
    """
    Validate each runtime tensor against its configured counterpart.

    First checks that the two sequences have equal length, then validates
    every pair positionally via ``validate_tensor_metadata``.

    Raises:
        PipeliningShapeError: on length mismatch or any per-tensor mismatch.
    """
    if len(expected_tensors) != len(actual_tensors):
        raise PipeliningShapeError(
            f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})"
        )
    for i, (expected, actual) in enumerate(zip(expected_tensors, actual_tensors)):
        validate_tensor_metadata(f"{desc}: value {i}", expected, actual)
89
+
90
+
91
@dataclass
class PipeInfo:
    """
    Captures information for a pipeline (`Pipe` object).
    """

    # The fx graph describing the pipeline's stage structure.
    graph: fx.Graph
    # Total number of pipeline stages.
    num_stages: int
    # Presumably: whether loss computation and a backward pass are part of
    # the pipeline graph — confirm against Pipe construction sites.
    has_loss_and_backward: bool
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/microbatch.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates
3
+ import logging
4
+ from typing import Any, Dict, List, Optional, Tuple
5
+
6
+ import torch
7
+ from torch.fx.node import map_aggregate
8
+ from torch.utils._pytree import tree_flatten, tree_unflatten
9
+
10
+
11
+ __all__ = [
12
+ "TensorChunkSpec",
13
+ "split_args_kwargs_into_chunks",
14
+ "merge_chunks",
15
+ ]
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ """
20
+ _debug_mask_minibatches specifies to send masked versions of the mini-batch
21
+ through instead of micro-batch slices--this can be used for more stable
22
+ numerical testing (see [A Note About Correctness Testing])
23
+ """
24
+ _debug_mask_minibatches = False
25
+
26
+
27
+ class _CustomReducer:
28
+ """
29
+ Custom reducer class that can be used to specify a custom operation that
30
+ reduces losses of multiple microbatches into one value.
31
+
32
+ Example:
33
+ >>> # xdoctest: +SKIP
34
+ >>> sum_reducer = _CustomReducer(
35
+ >>> torch.tensor(0.0),
36
+ >>> lambda a, b: a + b
37
+ >>> )
38
+ """
39
+
40
+ def __init__(self, init_value, reduce_fn):
41
+ self.init_value = init_value
42
+ self.reduce_fn = reduce_fn
43
+
44
+
45
+ class _LossReducer(_CustomReducer):
46
+ pass
47
+
48
+
49
# Default loss reducer: sums microbatch losses starting from a 0.0 scalar.
sum_reducer = _LossReducer(torch.tensor(0.0), lambda a, b: a + b)

# Default chunking dimension is 0. This is used for the case where the user did
# not specify a chunking dimension.
DEFAULT_CHUNK_DIM = 0
54
+
55
+
56
class TensorChunkSpec:
    """
    Class used to specify chunking of inputs
    """

    # Dimension along which the tensor is split into microbatch chunks.
    split_dim: int

    def __init__(self, split_dim):
        self.split_dim = split_dim

    def __repr__(self):
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})"
        )

    def __str__(self):
        return f"TensorChunkSpec({self.split_dim})"

    @staticmethod
    def from_tuple(
        chunk_dims: Tuple[int, ...],
    ):
        """
        A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk
        dimensions (int's).
        Example:
            >>> # xdoctest: +SKIP
            >>> # There are three positional arguments to the model, and
            >>> # we are chunking them along dimension 0, 0 and 1, respectively
            >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
        """
        return map_aggregate(
            chunk_dims,
            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
        )

    @staticmethod
    def from_dict(
        chunk_dims: Dict[str, int],
    ):
        """
        A helper for creating a dictionary of `TensorChunkSpec` from a
        dictionary of chunk dimensions (int's).
        Example:
            >>> # xdoctest: +SKIP
            >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
            >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
        """
        return map_aggregate(
            chunk_dims,
            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
        )
110
+
111
+
112
# Class used to specify replication of inputs
class _Replicate:
    # Sentinel type used in a chunk spec to mean "do not split this value;
    # copy it into every chunk" (checked by identity in _shard_dict_of_args).
    pass
115
+
116
+
117
def _shard_dict_of_args(
    args_dict,
    args_chunk_spec,
    num_chunks,
):
    """
    Given a dictionary of args, and a dictionary of chunking specs, shard the
    args according to the chunking specs.

    Args:
        args_dict: Dictionary of args
        args_chunk_spec: Dictionary of chunking specs
        num_chunks: Number of chunks to shard the args into

    Returns:
        args_split: List of sharded args

    Note: the returned list may be shorter than ``num_chunks`` when the
    first tensor encountered is smaller than ``num_chunks`` along its split
    dimension (see the downsizing branch below).
    """
    # Stage 1+2: flatten and shard/replicate

    # args_sharded_replicated : [num args, num flat values, num chunks]
    args_sharded_replicated = {}
    arg_specs = []

    # real_num_chunks may shrink below num_chunks (only at the first tensor).
    real_num_chunks = num_chunks
    first_tensor = True

    assert len(args_dict) == len(
        args_chunk_spec
    ), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"

    for arg_key, arg in args_dict.items():
        flat, spec = tree_flatten(arg)
        arg_specs.append(spec)

        chunk_spec = args_chunk_spec[arg_key]
        assert chunk_spec is not None  # Should have been set by caller
        chunk_spec_flat, _ = tree_flatten(chunk_spec)
        if len(flat) != len(chunk_spec_flat):
            raise ValueError(
                f"Argument value {arg} did not have the same number of "
                f"values as as chunk spec {chunk_spec}"
            )

        sharded_arg_flat = []

        for v, chunk_v in zip(flat, chunk_spec_flat):
            if chunk_v is _Replicate or not isinstance(v, torch.Tensor):
                # Replicated (or non-tensor) leaves appear unchanged in
                # every chunk.
                sharded_arg_flat.append([v] * real_num_chunks)
            elif isinstance(chunk_v, TensorChunkSpec):
                # TODO: check type of v. If it's a tensor, use chunk (or debug mask).
                # If it's a collection type, split it as you would expect. Otherwise,
                # Throw an error
                assert isinstance(v, torch.Tensor), f"{v} is not a tensor"

                v_split_dim_size = v.size(chunk_v.split_dim)
                if v_split_dim_size < real_num_chunks:
                    if first_tensor:
                        # We can only adjust number of chunks when we hit this
                        # issue at the first tensor encountered
                        logger.warning(
                            f"Tensor size on chunking dimension is {v_split_dim_size}, "  # noqa: G004
                            f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}."
                        )
                        real_num_chunks = v_split_dim_size
                    else:
                        raise RuntimeError(
                            f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, "
                            f"smaller than the number of chunks {num_chunks}. "
                            "PiPPy cannot reduce the number of chunks because "
                            "other arguments have bigger chunk-dimension sizes. "
                            "Please adjust your num_chunks setting."
                        )

                chunk_tensors = torch.tensor_split(
                    v, real_num_chunks, chunk_v.split_dim
                )

                if _debug_mask_minibatches:
                    # Debug mode: each "chunk" is a full-size zero tensor with
                    # only its slice filled in, for stable numerical testing.
                    expanded_chunks = []

                    split_dim_idx = 0
                    for chunk_tensor in chunk_tensors:
                        new_val = torch.zeros_like(v)
                        upper_idx = split_dim_idx + chunk_tensor.size(chunk_v.split_dim)

                        slice_indices = [slice(None, None, None)] * new_val.ndim
                        slice_indices[chunk_v.split_dim] = slice(
                            split_dim_idx, upper_idx
                        )
                        new_val[slice_indices] = chunk_tensor

                        expanded_chunks.append(new_val)

                        split_dim_idx += chunk_tensor.size(chunk_v.split_dim)

                    sharded_arg_flat.append(expanded_chunks)
                else:
                    sharded_arg_flat.append(chunk_tensors)  # type: ignore[arg-type]

                first_tensor = False
            else:
                raise TypeError(f"Unrecognized chunk spec: {chunk_v}")

        args_sharded_replicated[arg_key] = sharded_arg_flat

    # chunks_flat : [num chunks, num args, num flat values]
    chunks_flat = []
    for chunk_idx in range(real_num_chunks):
        chunk_args = {}
        for key, arg in args_sharded_replicated.items():
            arg_single_chunk = []
            for v_flat in arg:
                arg_single_chunk.append(v_flat[chunk_idx])
            chunk_args[key] = arg_single_chunk
        chunks_flat.append(chunk_args)

    # args_split : [num chunks, num args]
    args_split = []

    for chunk in chunks_flat:
        per_chunk_args = {}
        assert len(arg_specs) == len(chunk)
        for (key, arg), arg_spec in zip(chunk.items(), arg_specs):
            per_chunk_args[key] = tree_unflatten(arg, arg_spec)
        args_split.append(per_chunk_args)

    return args_split
244
+
245
+
246
def split_args_kwargs_into_chunks(
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    chunks: int,
    args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
    kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
) -> Tuple[List[Tuple], List[Dict]]:
    """
    Given a sequence of args and kwargs, split them into a number of chunks
    according to their respective chunking specs.

    Args:
        args: Tuple of args
        kwargs: Dict of kwargs
        chunks: Number of chunks to split the args and kwargs into
        args_chunk_spec: chunking specs for args, in same shape as args
        kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs

    Returns:
        args_split: List of sharded args
        kwargs_split: List of sharded kwargs

    Raises:
        RuntimeError: if args and kwargs cannot be split into the same
            number of chunks.
    """
    # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that
    # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec`
    # and `kwargs_chunk_spec` specifications. The steps are as follows:
    #
    # 1. Use pytree.tree_flatten to flatten each arg and its spec into nto a 1d array of values.
    #    To use a running example: suppose our inputs look like
    #
    #    args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None)
    #    (kwargs not shown but it's a similar process)
    #
    #    Then for this step we would end up with
    #
    #    args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None)
    #
    # 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2
    #
    #    args = ([[A, A], [B, B], [C_1, C_2]], [D, D])
    #
    # 3. Rotate the nesting order such that chunks are the outer dimension
    #
    #    args_chunks = [
    #        ([A, B, C_1], D),
    #        ([A, B, C_2], D),
    #    ]
    #
    # 4. Unflatten each chunk according to the spec
    #
    #    args_chunks = [
    #        ([A, [B, C_1]], D),
    #        ([A, [B, C_2]], D),
    #    ]

    # TODO: _debug_mask_minibatches
    # Handle the case where kwargs is None
    if kwargs is None:
        kwargs = {}

    # If user did not provide args_chunk_spec or kwargs_chunk_spec, we extend
    # their format and use default chunking along dim 0
    if args_chunk_spec is None:
        args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)

    if kwargs_chunk_spec is None:
        kwargs_chunk_spec = dict.fromkeys(kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM))

    # Positional args are sharded through the same dict-based helper by
    # keying them on their index.
    args_split_dict = _shard_dict_of_args(
        dict(enumerate(args)),
        dict(enumerate(args_chunk_spec)),
        chunks,
    )
    real_num_chunks = len(args_split_dict)

    kwargs_split = _shard_dict_of_args(
        kwargs,
        kwargs_chunk_spec,
        real_num_chunks,
    )

    if len(kwargs_split) < real_num_chunks:
        # In case kwargs are sharded into less chunks
        # e.g. when `args` has no tensor, just values
        real_num_chunks = len(kwargs_split)
        # Re-shard args
        args_split_dict = _shard_dict_of_args(
            dict(enumerate(args)),
            dict(enumerate(args_chunk_spec)),
            real_num_chunks,
        )

    if len(args_split_dict) != len(kwargs_split):
        raise RuntimeError(
            "args and kwargs are split into different number of chunks: "
            f"{len(args_split_dict)}, {len(kwargs_split)}"
        )

    # Convert the index-keyed dicts back into positional tuples.
    args_split = []
    for chunk_args in args_split_dict:
        args_split.append(tuple(chunk_args[i] for i in range(len(chunk_args))))

    return args_split, kwargs_split
348
+
349
+
350
def merge_chunks(
    chunks: List[Any],
    chunk_spec,
):
    """
    Given a list of chunks, merge them into a single value according to
    the chunk spec.

    Args:
        chunks: list of chunks
        chunk_spec: Chunking spec for the chunks

    Returns:
        value: Merged value

    Per-leaf behavior: TensorChunkSpec leaves are concatenated along their
    split dim; _CustomReducer leaves are folded with the reducer; all other
    leaves are asserted equal across chunks and passed through unchanged.
    """
    # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
    # steps are similar to the steps in that function but in reverse. Given the
    # input values:
    #
    # chunks = [
    #     ([A, [B, C_1]], D),
    #     ([A, [B, C_2]], D),
    # ]
    # args_spec = ([None, [None, TensorChunkSpec]], None)
    #
    # 1. Flatten the chunks according to the chunk_spec
    #
    #    chunks_flat = [
    #        ([A, B, C_1], D),
    #        ([A, B, C_2], D),
    #    ]
    #
    # 2. Rotate the nesting order such that chunks are the inner dimension
    #
    #    value_inner = ([A, B, [C_1, C_2]], D)
    #
    # 3. Concatenate sharded arguments
    #
    #    value_combined = ([A, B, C], D)
    #
    # 4. Unflatten the combined args given the spec
    #
    #    value = ([A, [B, C]], D)

    # Preliminary: flatten the chunk spec
    if chunk_spec is not None:
        spec_flattened, flatten_spec = tree_flatten(chunk_spec)
    else:
        # If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields
        # We obtain the output structure by flattening chunk 0 and generate the chunk_spec
        chunk0_flat, flatten_spec = tree_flatten(chunks[0])
        spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)

    # Stage 1: flatten chunks
    # chunks_flattened : [num chunks, num args]
    chunks_flattened = []

    for chunk in chunks:
        chunk_flattened, _ = tree_flatten(chunk)
        if len(chunk_flattened) != len(spec_flattened):
            raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")

        chunks_flattened.append(chunk_flattened)

    # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and
    # concatenate sharded operands
    # args_flattened : [num args]
    args_flattened = []
    for arg_idx, arg in enumerate(spec_flattened):
        if isinstance(arg, TensorChunkSpec):
            partial_values = [
                chunks_flattened[chunk_idx][arg_idx]
                for chunk_idx in range(len(chunks_flattened))
            ]

            if _debug_mask_minibatches:
                # Infer size of individual chunks by running `tensor_split` again
                overall_shape = partial_values[0].shape
                for val in partial_values[1:]:
                    assert val.shape == overall_shape
                meta_chunks = torch.tensor_split(
                    torch.empty(*overall_shape, device="meta"),
                    sections=len(partial_values),
                    dim=arg.split_dim,
                )

                # In debug mode each chunk is full-size with only its own
                # slice populated, so slice each chunk back out before cat.
                values_to_cat = []
                chunk_start_idx = 0
                assert len(partial_values) == len(meta_chunks)
                for partial_value, meta_chunk in zip(partial_values, meta_chunks):
                    chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)

                    slice_indices = [slice(None, None, None)] * partial_value.ndim
                    slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx)
                    sliced = partial_value[slice_indices]
                    values_to_cat.append(sliced)

                    chunk_start_idx = chunk_end_idx

            else:
                values_to_cat = partial_values

            args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
        elif isinstance(arg, _CustomReducer):
            # Fold all chunk values into one with the user's reducer.
            reduced_val = arg.init_value

            for chunk_idx in range(len(chunks_flattened)):
                reduced_val = arg.reduce_fn(
                    reduced_val, chunks_flattened[chunk_idx][arg_idx]
                )

            args_flattened.append(reduced_val)
        else:
            # Replicated leaf: all chunks must agree; keep the first.
            value = chunks_flattened[0][arg_idx]
            for chunk_idx in range(1, len(chunks_flattened)):
                assert chunks_flattened[chunk_idx][arg_idx] == value
            args_flattened.append(value)

    # Stage 4: Unflatten combined args
    return tree_unflatten(args_flattened, flatten_spec)
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/schedules.py ADDED
@@ -0,0 +1,2162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates
3
+
4
+ import csv
5
+ import itertools
6
+ import logging
7
+ import re
8
+ from abc import ABC, abstractmethod
9
+ from collections import defaultdict
10
+ from enum import Enum
11
+ from typing import (
12
+ Any,
13
+ Callable,
14
+ Dict,
15
+ List,
16
+ NamedTuple,
17
+ Optional,
18
+ Set,
19
+ Tuple,
20
+ TYPE_CHECKING,
21
+ Union,
22
+ )
23
+
24
+ import torch
25
+ import torch.distributed as dist
26
+ from torch.distributed._composable.fsdp.fully_shard import FSDPModule, UnshardHandle
27
+ from torch.profiler import record_function
28
+
29
+ from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
30
+ from .stage import _PipelineStageBase
31
+
32
+
33
+ if TYPE_CHECKING:
34
+ from torch.distributed import Work
35
+
36
# Public API of this module (consumed by `from ... import *`).
__all__ = [
    "get_schedule_class",
    "PipelineScheduleSingle",
    "PipelineScheduleMulti",
    "Schedule1F1B",
    "ScheduleFlexibleInterleaved1F1B",
    "ScheduleGPipe",
    "ScheduleInterleaved1F1B",
    "ScheduleLoopedBFS",
    "ScheduleInterleavedZeroBubble",
]

# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)
49
+
50
+
51
+ class _ComputationType(Enum):
52
+ # TODO(whc) rename to _ActType?
53
+ FORWARD = 1
54
+ BACKWARD = 2
55
+ WEIGHT = 3
56
+ UNSHARD = 4
57
+ RESHARD = 5
58
+ SEND_F = 6
59
+ RECV_F = 7
60
+ SEND_B = 8
61
+ RECV_B = 9
62
+
63
+ def __str__(self):
64
+ str_map = {
65
+ _ComputationType.FORWARD: "F",
66
+ _ComputationType.BACKWARD: "B",
67
+ _ComputationType.WEIGHT: "W",
68
+ _ComputationType.UNSHARD: "UNSHARD",
69
+ _ComputationType.RESHARD: "RESHARD",
70
+ _ComputationType.SEND_F: "SEND_F",
71
+ _ComputationType.RECV_F: "RECV_F",
72
+ _ComputationType.SEND_B: "SEND_B",
73
+ _ComputationType.RECV_B: "RECV_B",
74
+ }
75
+ return str_map[self]
76
+
77
+ @staticmethod
78
+ def from_str(action):
79
+ if action == "F":
80
+ return _ComputationType.FORWARD
81
+ elif action == "B":
82
+ return _ComputationType.BACKWARD
83
+ elif action == "W":
84
+ return _ComputationType.WEIGHT
85
+ elif action == "UNSHARD":
86
+ return _ComputationType.UNSHARD
87
+ elif action == "RESHARD":
88
+ return _ComputationType.RESHARD
89
+ elif action == "SEND_F":
90
+ return _ComputationType.SEND_F
91
+ elif action == "RECV_F":
92
+ return _ComputationType.RECV_F
93
+ elif action == "SEND_B":
94
+ return _ComputationType.SEND_B
95
+ elif action == "RECV_B":
96
+ return _ComputationType.RECV_B
97
+ else:
98
+ raise RuntimeError(f"Invalid computation type {action}")
99
+
100
+
101
# Module-level aliases so schedule code can write e.g. `FORWARD` instead of
# `_ComputationType.FORWARD`.
FORWARD = _ComputationType.FORWARD
BACKWARD = _ComputationType.BACKWARD
WEIGHT = _ComputationType.WEIGHT
UNSHARD = _ComputationType.UNSHARD
RESHARD = _ComputationType.RESHARD
SEND_F = _ComputationType.SEND_F
RECV_F = _ComputationType.RECV_F
SEND_B = _ComputationType.SEND_B
RECV_B = _ComputationType.RECV_B

# Convenience shorthand for compute actions only since they are used in 'simple schedule format'
F = FORWARD
B = BACKWARD
W = WEIGHT
115
+
116
+ # Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index)
117
+ _action_regex = re.compile(
118
+ r"(\d+)([F,B,W]|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B{0,1})(\d*)"
119
+ )
120
+
121
+
122
class _Action(NamedTuple):
    """One step of a pipeline schedule: a computation/comm/shard action taken
    by a stage, optionally on a specific microbatch."""

    stage_index: int
    computation_type: _ComputationType
    microbatch_index: Optional[int] = None

    def __repr__(self):
        # Compact form "<stage><action><microbatch?>", e.g. "2F0" or "1UNSHARD".
        parts = [str(self.stage_index), str(self.computation_type)]
        if self.microbatch_index is not None:
            parts.append(str(self.microbatch_index))
        return "".join(parts)

    @staticmethod
    def from_str(str):
        """
        Reverse of __repr__

        String should be formatted as [stage][action type][(microbatch)]
        e.g. `2F0`, `1UNSHARD`, `3SEND_F1`
        """
        match = _action_regex.match(str)
        if match is not None:
            stage_index, computation_type, microbatch_index = match.groups()
            return _Action(
                int(stage_index),
                _ComputationType.from_str(computation_type),
                int(microbatch_index) if microbatch_index else None,
            )
        if not str.strip():
            # Blank cells (e.g. in a CSV schedule) mean "no action here".
            return None
        raise RuntimeError(
            f"Invalid action string: {str}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0"
        )
154
+
155
+
156
def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -> str:
    """
    Formats the pipeline order in a timestep (row) x rank (column) grid of actions
    and returns the formatted string
    """
    # The longest per-rank action list determines the number of rows.
    num_steps = max(len(actions) for actions in pipeline_order.values())
    label_width = len(str(num_steps - 1))
    step_labels = [f"Step {str(i).zfill(label_width)}" for i in range(num_steps)]
    # One column of actions per rank, in sorted rank order.
    per_rank = [
        pipeline_order.get(rank, [""] * num_steps) for rank in sorted(pipeline_order)
    ]
    # Transpose: rows become timesteps; short columns padded with empty cells.
    grid = list(itertools.zip_longest(*per_rank, fillvalue=""))
    rank_labels = [f"Rank {i}" for i in range(len(pipeline_order))]
    # Per-column display widths, considering labels and every cell.
    widths = []
    for col in zip(step_labels, *grid):
        widths.append(max(len(str(cell)) if cell is not None else 0 for cell in col))
    # Header row carries the rank labels, offset past the step-label column.
    header = " " * (len(step_labels[0]) + 2) + " ".join(
        f"{label:<{widths[i]}}" for i, label in enumerate(rank_labels)
    )
    rows = []
    for label, row in zip(step_labels, grid):
        cells = " ".join(f"{str(cell):<{widths[i]}}" for i, cell in enumerate(row))
        rows.append(f"{label}: {cells}")
    return header + "\n" + "\n".join(rows) + "\n"
193
+
194
+
195
def _validate_pipeline_order(
    pipeline_order: Dict[int, List[Optional[_Action]]],
    num_microbatches: int,
    num_stages: int,
    enable_zero_bubble: bool = False,
):
    """
    pipeline_order[rank] = [(computation_type, microbatch_index, stage_index), ...]
    Validating that the pipeline order follows the rules:
    1. Forward action for a microbatch must be before the Backward action for that microbatch
    2. Recv for a microbatch must be before the send for that microbatch
    3. Microbatch index is handled in sequential order for each stage
    4. A later stage cannot operate on a microbatch before any of the previous stages have operated on it
    5. Same microbatch cannot be handled in the same time step across ranks
    """
    # microbatch_index -> (last computation type applied, stage that applied it)
    microbatch_process_info: Dict[int, Tuple[_ComputationType, int]] = {}
    max_timestep = max(len(rank_list) for rank_list in pipeline_order.values())
    for timestep in range(max_timestep):
        error_msg: List[str] = []
        current_timestep_actions = []
        # Collect this timestep's non-WEIGHT actions across all ranks.
        for rank in range(len(pipeline_order)):
            action = (
                pipeline_order[rank][timestep]
                if timestep < len(pipeline_order[rank])
                else None
            )
            if action is not None:
                computation_type = action.computation_type
                if computation_type != _ComputationType.WEIGHT:
                    current_timestep_actions.append(action)

        # TODO: enable this
        # if len(current_timestep_actions) == 0:
        #     error_msg.append(
        #         "All actions were None, there is an unnecessary gap in the schedule"
        #     )

        # Rule 5: no microbatch may appear twice in the same timestep.
        unique_microbatch_indices = {
            action.microbatch_index for action in current_timestep_actions
        }
        if len(unique_microbatch_indices) != len(current_timestep_actions):
            error_msg.append(
                "Duplicate microbatch index found in current_timestep_actions"
            )

        for action in current_timestep_actions:
            stage_index = action.stage_index
            computation_type = action.computation_type
            mb_index = action.microbatch_index
            assert (
                mb_index is not None
            ), "All currently supported action types require valid microbatch_index"
            if mb_index >= num_microbatches:
                error_msg.append(f"Microbatch index {mb_index} out of range")

            if mb_index not in microbatch_process_info:
                # First time we see this microbatch: it must start as a
                # FORWARD on stage 0.
                if computation_type != _ComputationType.FORWARD or stage_index != 0:
                    error_msg.append(f"Incorrect start for microbatch {mb_index}")
                microbatch_process_info[mb_index] = (computation_type, stage_index)
            else:
                # Otherwise the action must be the exact successor of the
                # previously recorded (computation, stage) pair.
                prev_computation, prev_stage = microbatch_process_info[mb_index]

                if prev_computation == _ComputationType.FORWARD:
                    if prev_stage == num_stages - 1:
                        # Forward finished on the last stage; backward begins there.
                        expected_stage = num_stages - 1
                        expected_computation = _ComputationType.BACKWARD
                    else:
                        expected_stage = prev_stage + 1
                        expected_computation = _ComputationType.FORWARD
                elif prev_computation == _ComputationType.BACKWARD:
                    if prev_stage == 0:
                        error_msg.append(
                            f"[{mb_index=}] already finished backward computation"
                        )
                        break
                    else:
                        expected_stage = prev_stage - 1
                        expected_computation = _ComputationType.BACKWARD
                else:
                    raise ValueError(
                        f"Computation type {prev_computation} not supported"
                    )

                if expected_computation is not None:
                    if expected_computation != computation_type:
                        error_msg.append(
                            f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}"
                        )

                if expected_stage != stage_index:
                    error_msg.append(
                        f"[{mb_index=}] {expected_stage=} VS. actual {stage_index=}"
                    )

                microbatch_process_info[mb_index] = (
                    expected_computation,
                    expected_stage,
                )

        if not enable_zero_bubble:
            if len(error_msg) != 0:
                raise RuntimeError(
                    f"Error at timestep {timestep}: " + ",".join(error_msg)
                )
            # NOTE(review): this returns after the *first* timestep when zero
            # bubble is disabled, so later timesteps are never validated —
            # looks like it should be `continue`; preserved as-is, confirm
            # against upstream intent.
            return

        # Zero-bubble checks: every BACKWARD must be paired with exactly one
        # later WEIGHT update for the same (microbatch, stage).
        for rank in range(len(pipeline_order)):
            backward_steps: Set[Tuple[int, int]] = set()
            weight_steps: Set[Tuple[int, int]] = set()

            for action in pipeline_order[rank]:
                if action is None:
                    continue

                stage_index = action.stage_index
                computation_type = action.computation_type
                mb_index = action.microbatch_index
                if computation_type == _ComputationType.BACKWARD:
                    if mb_index is not None:
                        backward_steps.add((mb_index, stage_index))
                elif computation_type == _ComputationType.WEIGHT:
                    if (mb_index, stage_index) not in backward_steps:
                        error_msg.append(
                            f"{mb_index=}, {stage_index=} Weight happened before bwd"
                        )
                    if (mb_index, stage_index) in weight_steps:
                        error_msg.append(
                            f"{mb_index=}, {stage_index=} Duplicated weight step"
                        )
                    if mb_index is not None:
                        weight_steps.add((mb_index, stage_index))

            if len(backward_steps) != len(weight_steps):
                error_msg.append("Length weight steps != Length bwd steps")

        if len(error_msg) != 0:
            raise RuntimeError(f"Error at timestep {timestep}: " + ",".join(error_msg))
337
+
338
+
339
class _PipelineSchedule(ABC):
    """Abstract base for pipeline schedules.

    Owns microbatch splitting/merging and per-iteration loss bookkeeping;
    concrete schedules implement `_step_microbatches` and `step`.
    """

    def __init__(
        self,
        n_microbatches: int,
        loss_fn: Optional[Callable[..., torch.Tensor]] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        # From arguments
        self._n_microbatches = n_microbatches
        self._loss_fn = loss_fn
        # args_chunk_spec / kwargs_chunk_spec specify how positional/keyword
        # inputs are chunked into microbatches in `step(x)`; see
        # `TensorChunkSpec` for helpers. (default: `None`)
        self._args_chunk_spec = args_chunk_spec
        self._kwargs_chunk_spec = kwargs_chunk_spec
        self._output_merge_spec = output_merge_spec

        # Derived: having a loss function implies running backward.
        self._has_backward = self._loss_fn is not None

        # Holds the losses for each microbatch of the current iteration.
        self._internal_losses: List[torch.Tensor] = []
        logger.info("Using %s", self.__class__.__name__)

    def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
        # Only the last stage with backward enabled computes/records a loss.
        if stage.is_last and self._has_backward:
            loss = self._compute_loss(output, target_mbs[mb_index])  # type: ignore[index]
            self._internal_losses.append(loss)

    def _maybe_get_loss(self, stage, mb_index):
        # Return the recorded loss for `mb_index`, or None when this stage
        # does not own losses; raise when the index is inconsistent with the
        # losses recorded so far.
        valid_index = 0 <= mb_index < len(self._internal_losses)
        if stage.is_last and self._has_backward and valid_index:
            return self._internal_losses[mb_index]
        elif len(self._internal_losses) != 0 and not valid_index:
            raise RuntimeError(
                f"Loss for microbatch {mb_index} is not available. "
                f"Available losses for microbatches: {self._internal_losses}"
            )
        else:
            return None

    def _update_losses(self, stages, losses):
        """
        Update the losses to those in the internal state
        """
        # Normalize a single stage to a list.
        if not isinstance(stages, list):
            stages = [stages]
        contains_last_stage = any(stage.is_last for stage in stages)

        # Return losses if there is a container passed in
        if contains_last_stage and losses is not None:
            if len(self._internal_losses) != self._n_microbatches:
                raise RuntimeError(
                    f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}"
                )

            # Clean external container first
            losses.clear()
            # Copy internal losses to external container
            losses.extend(self._internal_losses)

        self._internal_losses.clear()

    @abstractmethod
    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the schedule
        implementation.

        Args:
            microbatches: list of microbatch args.
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """
        raise NotImplementedError

    def _check_inputs(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Pre-process/check inputs
        """

        def check_type_and_len(mbs, name: str):
            # Each microbatch container must be a list of exactly
            # `n_microbatches` entries.
            if not isinstance(mbs, list):
                raise TypeError(f"{name} must be a list but got a {type(mbs)}")
            if len(mbs) != self._n_microbatches:
                raise ValueError(
                    f"Expecting {self._n_microbatches} {name} but got {len(mbs)}"
                )

        if arg_mbs is not None:
            check_type_and_len(arg_mbs, "arg_mbs")
        else:
            arg_mbs = [()] * self._n_microbatches

        if kwarg_mbs is not None:
            check_type_and_len(kwarg_mbs, "kwarg_mbs")
        else:
            kwarg_mbs = [{}] * self._n_microbatches

        if target_mbs is not None:
            check_type_and_len(target_mbs, "target_mbs")

        if losses is not None:
            if not isinstance(losses, list):
                raise TypeError(f"losses must be a list but got a {type(losses)}")

        return arg_mbs, kwarg_mbs

    def _compute_loss(self, output, target):
        return self._loss_fn(output, target)  # type: ignore[misc]

    def _split_inputs(
        self,
        args: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Splits a full-batch input into chunks (i.e. microbatches) and returns
        the chunks
        """
        if args or kwargs:
            args_split, kwargs_split = split_args_kwargs_into_chunks(
                args,
                kwargs,
                self._n_microbatches,
                self._args_chunk_spec,
                self._kwargs_chunk_spec,
            )
            return args_split, kwargs_split
        else:
            # Empty inputs (e.g. when called on middle stages)
            # Return a list of empty tuples/dicts with matching length as chunks
            return [()] * self._n_microbatches, [{}] * self._n_microbatches

    def _merge_outputs(self, output_chunks: List[Any]) -> Any:
        """
        Merge output chunks back to a batch state.
        If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim).
        """
        return merge_chunks(
            output_chunks,
            self._output_merge_spec,
        )
514
+
515
+
516
+ def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None):
517
+ """
518
+ Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top.
519
+ """
520
+ if len(p2p_ops) == 0:
521
+ return None
522
+ desc_str = f"{desc}, " if desc else ""
523
+ logger.debug("batch_p2p %s%s", desc_str, p2p_ops)
524
+ return dist.batch_isend_irecv(p2p_ops).pop()
525
+
526
+
527
+ def _sorted_batch_p2p(
528
+ p2p_ops: List[dist.P2POp], desc: Optional[str] = None
529
+ ) -> Dict[int, dist.Work]:
530
+ """
531
+ Sorts the list of P2P ops by the peer rank, and then calls
532
+ batch_isend_irecv. Return a dictionary of works by peer rank. This function
533
+ helps us avoid hangs in case of skip connections.
534
+ """
535
+ # Arrange p2p_ops by peer rank:
536
+ # int is the peer rank;
537
+ # List is the list of ops towards the peer
538
+ ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list)
539
+ work_by_peer: Dict[int, dist.Work] = {}
540
+ if len(p2p_ops) == 0:
541
+ return work_by_peer
542
+
543
+ # Classify the ops by peer rank
544
+ for op in p2p_ops:
545
+ ops_by_peer[op.peer].append(op)
546
+
547
+ # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs)
548
+ for peer, ops in sorted(ops_by_peer.items()):
549
+ work_by_peer[peer] = _batch_p2p(ops, desc=desc)
550
+
551
+ return work_by_peer
552
+
553
+
554
class PipelineScheduleSingle(_PipelineSchedule):
    """
    Base class for single-stage schedules.
    Implements the `step` method.
    Derived classes should implement `_step_microbatches`.
    """

    def __init__(
        self,
        stage: _PipelineStageBase,
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        # Init parent
        super().__init__(
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        # Self attributes
        self._stage = stage
        self._num_stages = stage.num_stages
        # Propagate whether backward will run to the stage object.
        self._stage.has_backward = self._has_backward

        # TODO: later replace this with lazy shape inference during forward
        # Prepare forward send/recv infrastructure for stage
        stage._prepare_forward_infra(n_microbatches)
        if self._has_backward:
            stage._prepare_backward_infra(n_microbatches)

    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """

        # Clean per iteration
        self._stage.clear_runtime_states()

        # Split inputs into microbatches
        args_split, kwargs_split = self._split_inputs(args, kwargs)

        # Split target into microbatches along the batch dimension.
        targets_split = (
            list(torch.tensor_split(target, self._n_microbatches))
            if target is not None
            else None
        )

        # Run microbatches
        self._step_microbatches(args_split, kwargs_split, targets_split, losses)

        # Only the last stage holds the full output; others return None.
        if self._stage.is_last:
            return self._merge_outputs(self._stage.output_chunks)
        else:
            return None
622
+
623
+
624
class _ScheduleForwardOnly(PipelineScheduleSingle):
    """
    The forward-only schedule.
    Will go through all the microbatches and perform only the forward pass
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule
        """
        # This schedule never computes losses or runs backward.
        if target_mbs is not None or losses is not None:
            raise RuntimeError(
                "Forward-only schedule does not support loss computation"
            )

        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Send waits are deferred until the end of the iteration.
        fwd_sends_to_wait: List[dist.Work] = []

        # Run microbatches
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                # Block on incoming activations for this chunk.
                works = _sorted_batch_p2p(
                    self._stage.get_fwd_recv_ops(i), desc="fwd_recv"
                )
                for work in works.values():
                    work.wait()

                self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]

                # Fire sends for this chunk; wait for them later.
                works = _sorted_batch_p2p(
                    self._stage.get_fwd_send_ops(i), desc="fwd_send"
                )
                fwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)

        # Drain all deferred forward sends before returning.
        for work in fwd_sends_to_wait:
            work.wait()
671
+
672
+
673
class ScheduleGPipe(PipelineScheduleSingle):
    """
    The GPipe schedule.
    Will go through all the microbatches in a fill-drain manner.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the GPipe schedule.

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Forward-send waits are deferred until the fill phase completes.
        fwd_sends_to_wait: List[dist.Work] = []

        # Fill phase: all forwards before any backward.
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                works = _sorted_batch_p2p(
                    self._stage.get_fwd_recv_ops(i), desc="fwd_recv"
                )
                for work in works.values():
                    work.wait()

                output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]

                works = _sorted_batch_p2p(
                    self._stage.get_fwd_send_ops(i), desc="fwd_send"
                )
                fwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)

            self._maybe_compute_loss(self._stage, output, target_mbs, i)

        # Wait for all forward sends to finish
        # This should not have performance impact because by the time the first
        # backward arrives all the forward sends should have been finished.
        for work in fwd_sends_to_wait:
            work.wait()

        # No loss function, no need to run backward
        if not self._has_backward:
            return

        # Drain phase: run all backwards, deferring the send waits.
        bwd_sends_to_wait: List[dist.Work] = []
        for i in range(self._n_microbatches):
            with record_function(f"Backward {i}"):
                works = _sorted_batch_p2p(
                    self._stage.get_bwd_recv_ops(i), desc="bwd_recv"
                )
                for work in works.values():
                    work.wait()

                loss = self._maybe_get_loss(self._stage, i)
                self._stage.backward_one_chunk(i, loss=loss)

                works = _sorted_batch_p2p(
                    self._stage.get_bwd_send_ops(i), desc="bwd_send"
                )
                bwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i)

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)

        # Wait for all backward sends to finish
        for work in bwd_sends_to_wait:
            work.wait()
751
+
752
+
753
class Schedule1F1B(PipelineScheduleSingle):
    """
    The 1F1B schedule.
    Will perform one forward and one backward on the microbatches in steady state.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the 1F1B schedule.

        Three phases: warmup (forwards only), steady state (alternating
        backward/forward with fused p2p batches), cooldown (backwards only).

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Last stage has 1 warmup, second-to-last 2 warmups, ...
        # first stage `num_stages` warmups
        warmup_chunks = min(
            self._n_microbatches,
            self._num_stages - self._stage.stage_index,
        )

        # Chunk counters
        # (note: removed an unused `weight_stage_mb_index` counter here)
        fwd_mb_index = 0
        bwd_mb_index = 0

        # Warmup phase
        send_work = None
        fwd_sends = []
        for _ in range(warmup_chunks):
            # Receive activations
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
            if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"):
                recv_work.wait()

            # Compute
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Clear previous chunk's forward sends (hopefully they have well
            # finished, otherwise, we are heavily communication bound, in which
            # case it doesn't create a lot of benefit to compute next chunk
            # eagerly either)
            if send_work:
                send_work.wait()

            # Send activations
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            if fwd_mb_index != warmup_chunks - 1:
                # Safe to fire
                send_work = _batch_p2p(fwd_sends, desc="fwd_send")
            # otherwise:
            # The last foward send is left for fuse with first 1B in 1B1F below

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
            fwd_mb_index += 1

        # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below.

        # 1B1F phase
        while True:  # Don't worry, we have a break inside
            # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)

            # Now, we need to fire the fwd_sends and bwd_recvs together
            if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"):
                fuse_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(bwd_mb_index, loss=loss)

            # Get the bwd send ops, but don't fire, to be fused with the 1F below
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            bwd_mb_index += 1

            if fwd_mb_index == self._n_microbatches:
                # We are done with 1B1F, so break with some left-over bwd_sends
                break

            # We prepare 1F of the `1B1F`
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)

            # Fuse it with bwd_sends above
            if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"):
                fuse_work.wait()

            # Now do the fwd
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)

            # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around)
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            fwd_mb_index += 1

        # Remember we still have some bwd_sends left over after the break? Now it is time to fire it
        send_work = _batch_p2p(bwd_sends, desc="bwd_send")

        # Cooldown
        while bwd_mb_index < self._n_microbatches:
            # prepare bwd recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
            if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"):
                recv_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(bwd_mb_index, loss=loss)

            # Clear previous chunk's backward sends (hopefully they have well finished)
            if send_work:
                send_work.wait()

            # Get the bwd send ops, fire it
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            send_work = _batch_p2p(bwd_sends, desc="bwd_send")
            bwd_mb_index += 1

        # Wait for the last backward send to finish
        if send_work:
            send_work.wait()

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)
887
+
888
+
889
+ def _add_unshard_reshard(
890
+ compute_actions: List[Optional[_Action]],
891
+ max_active_stages: int = 3,
892
+ ) -> List[_Action]:
893
+ """Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP.
894
+
895
+ UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation.
896
+ RESHARD does the opposite, releasing memory (but doing no commmunication)
897
+
898
+ We abandon the "timestep lock" during lowering
899
+
900
+ max_active_stages controls how many prefetches we allow. It should be measured in mb and tuneable but in practice
901
+ 3 stages is probably the thing we want?
902
+ (to account for having one f and one b active, and something else prefetching?)
903
+ """
904
+
905
+ def next_stage_indices(
906
+ count: int, next_actions: List[Optional[_Action]]
907
+ ) -> List[int]:
908
+ """Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute."""
909
+ seen: Set[int] = set()
910
+ ret: List[int] = []
911
+
912
+ for a in next_actions:
913
+ if a is not None and a.stage_index not in seen:
914
+ seen.add(a.stage_index)
915
+ ret.append(a.stage_index)
916
+ if len(ret) == count:
917
+ break
918
+ return ret
919
+
920
+ active_stages: Set[int] = set()
921
+ fsdp_aware_actions: List[_Action] = []
922
+
923
+ def _unshard(stage_index: int):
924
+ active_stages.add(stage_index)
925
+ fsdp_aware_actions.append(_Action(stage_index, UNSHARD, None))
926
+
927
+ def _reshard(stage_index: int):
928
+ active_stages.remove(stage_index)
929
+ fsdp_aware_actions.append(_Action(stage_index, RESHARD, None))
930
+
931
+ for i, action in enumerate(compute_actions):
932
+ if action is None:
933
+ continue
934
+
935
+ # We prefetch the next N stages we'll see, dropping existing stages to make room
936
+ next_n = next_stage_indices(max_active_stages, compute_actions[i:])
937
+ # Fetch needs to be ordered correctly, so don't use a set
938
+ fetch = list(filter(lambda s: s not in active_stages, next_n))
939
+ # Unclear what the best policy is for eviction, but we can maintain order so we do
940
+ evict = list(filter(lambda s: s not in next_n, active_stages))
941
+
942
+ # logger.debug(
943
+ # "_add_unshard_reshard Step %d active: %s fetch %s, evict %s",
944
+ # i,
945
+ # active_stages,
946
+ # fetch,
947
+ # evict,
948
+ # )
949
+
950
+ for stage in evict:
951
+ _reshard(stage)
952
+ for stage in fetch:
953
+ _unshard(stage)
954
+ fsdp_aware_actions.append(action)
955
+
956
+ return fsdp_aware_actions
957
+
958
+
959
+ def _add_send_recv(
960
+ compute_actions: Dict[int, List[_Action]],
961
+ stage_to_rank: Callable[[int], int],
962
+ num_stages: int,
963
+ ) -> Dict[int, List[_Action]]:
964
+ comm_actions: Dict[int, List[_Action]] = {rank: [] for rank in compute_actions}
965
+
966
+ def _has_comms(action: _Action) -> bool:
967
+ if action.computation_type == F:
968
+ return action.stage_index != num_stages - 1
969
+ elif action.computation_type == B:
970
+ return action.stage_index != 0
971
+ return False
972
+
973
+ def _get_comms(action: _Action) -> Tuple[_Action, _Action]:
974
+ assert _has_comms(action), f"{action} is not a valid comm action"
975
+ stage_idx = action.stage_index
976
+ ctype = action.computation_type
977
+ mb_idx = action.microbatch_index
978
+ send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx)
979
+ recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1
980
+ recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx)
981
+ return send, recv
982
+
983
+ def _ready_to_schedule(
984
+ action: Optional[_Action], prev_actions: List[_Action]
985
+ ) -> bool:
986
+ """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
987
+ This helps ensure a sane (non-hanging) ordering of sends and recvs.
988
+ But it also means we might not be able to schedule our next compute action yet.
989
+ """
990
+ if action is None:
991
+ return True
992
+ elif action.computation_type == F and not action.stage_index == 0:
993
+ expected_recv = _Action(
994
+ action.stage_index,
995
+ RECV_F if action.computation_type == F else RECV_B,
996
+ action.microbatch_index,
997
+ )
998
+ return expected_recv in prev_actions
999
+ elif action.computation_type == B and not action.stage_index == num_stages - 1:
1000
+ expected_recv = _Action(
1001
+ action.stage_index,
1002
+ RECV_F if action.computation_type == F else RECV_B,
1003
+ action.microbatch_index,
1004
+ )
1005
+ return expected_recv in prev_actions
1006
+ else:
1007
+ return True
1008
+
1009
+ while compute_actions:
1010
+ progress = False
1011
+ # go in order of ranks even if dict keys aren't ordered
1012
+ for rank in range(len(compute_actions)):
1013
+ assert len(compute_actions[rank]) > 0
1014
+ action = compute_actions[rank][0]
1015
+
1016
+ if not _ready_to_schedule(action, comm_actions[rank]):
1017
+ continue
1018
+
1019
+ if action is not None:
1020
+ comm_actions[rank].append(action)
1021
+ if _has_comms(action):
1022
+ send, recv = _get_comms(action)
1023
+ # TODO we can avoid send/recv if the 2 stages are on the same rank.
1024
+ # should we avoid that in the runtime or here?
1025
+ comm_actions[rank].append(send)
1026
+ comm_actions[stage_to_rank(recv.stage_index)].append(recv)
1027
+
1028
+ compute_actions[rank].pop(0)
1029
+ if len(compute_actions[rank]) == 0:
1030
+ del compute_actions[rank]
1031
+ progress = True
1032
+ assert progress, "Malformed compute schedule, can't schedule sends/recvs"
1033
+ return comm_actions
1034
+
1035
+
1036
+ class PipelineScheduleMulti(_PipelineSchedule):
1037
+ """
1038
+ Base class for multi-stage schedules.
1039
+ Implements the `step` method.
1040
+ """
1041
+
1042
+ def __init__(
1043
+ self,
1044
+ stages: List[_PipelineStageBase],
1045
+ n_microbatches: int,
1046
+ loss_fn: Optional[Callable] = None,
1047
+ args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
1048
+ kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
1049
+ output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
1050
+ stage_index_to_group_rank: Optional[Dict[int, int]] = None,
1051
+ use_full_backward: bool = True,
1052
+ ):
1053
+ if len(stages) <= 1:
1054
+ raise ValueError(
1055
+ f"Multi-stage schedule expects at least two stages but got {len(stages)}"
1056
+ )
1057
+ # Init parent
1058
+ super().__init__(
1059
+ n_microbatches=n_microbatches,
1060
+ loss_fn=loss_fn,
1061
+ args_chunk_spec=args_chunk_spec,
1062
+ kwargs_chunk_spec=kwargs_chunk_spec,
1063
+ output_merge_spec=output_merge_spec,
1064
+ )
1065
+ # Self attributes
1066
+ self._stages = stages
1067
+ self._num_stages = stages[0].num_stages
1068
+ self.pp_group_size = stages[0].group_size
1069
+ self.rank = stages[0].group_rank
1070
+ # Set the pipeline stage states
1071
+ if stage_index_to_group_rank is not None:
1072
+ for stage in self._stages:
1073
+ stage.stage_index_to_group_rank = stage_index_to_group_rank
1074
+ self.stage_index_to_group_rank = stages[0].stage_index_to_group_rank
1075
+
1076
+ # Set the same has_backward flag for stage object
1077
+ for stage in self._stages:
1078
+ stage.has_backward = self._has_backward
1079
+
1080
+ self._should_compute_loss = (
1081
+ lambda stage: stage.is_last and self._loss_fn is not None
1082
+ )
1083
+
1084
+ # This will be set during init of derived schedules
1085
+ self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
1086
+ self.use_full_backward = use_full_backward
1087
+
1088
+ # TODO: later replace this with lazy shape inference during forward
1089
+ # Prepare forward send/recv infrastructure for stage
1090
+ for stage in self._stages:
1091
+ stage._prepare_forward_infra(n_microbatches)
1092
+ if self._has_backward:
1093
+ stage._prepare_backward_infra(n_microbatches)
1094
+
1095
+ def _dump_csv(self, filename):
1096
+ """Dump a CSV representation of the schedule into a file with the provided filename."""
1097
+ with open(filename, "w", newline="") as csvfile:
1098
+ writer = csv.writer(csvfile)
1099
+ for rank in self.pipeline_order:
1100
+ writer.writerow(self.pipeline_order[rank])
1101
+
1102
+ def _validate_schedule(self):
1103
+ # TODO(whc) this should be merged with the logic in test_schedule.py#L453-L554
1104
+ def _validate_rank_actions(
1105
+ actions: Dict[int, List[_Action | None]],
1106
+ num_stages: int,
1107
+ num_microbatches: int,
1108
+ ):
1109
+ # We will count all the actions per stage and ensure they happen in a valid order
1110
+ # (e.g. F before B before W for a given microbatch)
1111
+ stage_actions: Dict[int, Dict[_ComputationType, Set]] = {
1112
+ stage_id: {
1113
+ F: set(),
1114
+ B: set(),
1115
+ W: set(),
1116
+ }
1117
+ for stage_id in range(num_stages)
1118
+ }
1119
+ for rank in actions:
1120
+ for action in actions[rank]:
1121
+ if action is None:
1122
+ continue
1123
+ assert isinstance(
1124
+ action, _Action
1125
+ ), f"Got an invalid action: {action}, expected instance of _Action"
1126
+ s_id = action.stage_index
1127
+ ctype = action.computation_type
1128
+ mb_id = action.microbatch_index
1129
+ if ctype == F:
1130
+ stage_actions[s_id][F].add(mb_id)
1131
+ elif ctype == B:
1132
+ assert (
1133
+ mb_id in stage_actions[s_id][F]
1134
+ ), f"Running Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
1135
+ stage_actions[s_id][B].add(mb_id)
1136
+ elif ctype == W:
1137
+ assert (
1138
+ not self.use_full_backward
1139
+ ), "Schedule contains 'W' actions, but is configured to use full backward"
1140
+ assert (
1141
+ mb_id in stage_actions[s_id][B]
1142
+ ), f"Running Weight for stage {s_id}, microbatch {mb_id} without first running Backward"
1143
+ stage_actions[s_id][W].add(mb_id)
1144
+
1145
+ for s_id in stage_actions:
1146
+ for ctype in (F, B, W):
1147
+ stage_mb = len(stage_actions[s_id][ctype])
1148
+ assert (
1149
+ stage_mb == num_microbatches
1150
+ ), f"Got {stage_mb} {ctype} microbatches for stage {s_id}, expected {num_microbatches}"
1151
+
1152
+ assert (
1153
+ len(self.pipeline_order) == self.pp_group_size
1154
+ ), f"Schedule has incorrect number of ranks - expected {self.pp_group_size}, actual {len(self.pipeline_order)}"
1155
+ for rank in range(self.pp_group_size):
1156
+ assert (
1157
+ rank in self.pipeline_order
1158
+ ), f"Schedule is missing actions for rank {rank}"
1159
+ _validate_rank_actions(
1160
+ self.pipeline_order,
1161
+ self._num_stages,
1162
+ self._n_microbatches,
1163
+ )
1164
+
1165
+ def _load_csv(self, filename, format="compute_only"):
1166
+ """Load a CSV representation of the schedule from a file with the provided filename.
1167
+ This API will most likely get renamed/refactored so is marked as internal for now.
1168
+
1169
+ format must be "compute_only" for PipelineScheduleMulti
1170
+ """
1171
+ assert format == "compute_only"
1172
+ with open(filename, newline="") as csvfile:
1173
+ reader = csv.reader(csvfile)
1174
+ for rank, row in enumerate(reader):
1175
+ self.pipeline_order[rank] = [_Action.from_str(s) for s in row]
1176
+ self._validate_schedule()
1177
+
1178
+ def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
1179
+ """
1180
+ Run one iteration of the pipeline schedule with *whole-batch* input.
1181
+ Will chunk the input into microbatches automatically, and go through the
1182
+ microbatches according to the schedule implementation.
1183
+
1184
+ args: positional arguments to the model (as in non-pipeline case).
1185
+ kwargs: keyword arguments to the model (as in non-pipeline case).
1186
+ target: target for the loss function.
1187
+ losses: a list to store the losses for each microbatch.
1188
+ """
1189
+
1190
+ # Clean per iteration
1191
+ for stage in self._stages:
1192
+ stage.clear_runtime_states()
1193
+
1194
+ # Split inputs into microbatches
1195
+ args_split, kwargs_split = self._split_inputs(args, kwargs)
1196
+
1197
+ # Split target into microbatches
1198
+ if target is not None:
1199
+ targets_split = list(torch.tensor_split(target, self._n_microbatches))
1200
+ else:
1201
+ targets_split = None
1202
+
1203
+ # Run microbatches
1204
+ self._step_microbatches(args_split, kwargs_split, targets_split, losses)
1205
+
1206
+ # Return merged results per original format
1207
+ for stage in self._stages:
1208
+ if stage.is_last:
1209
+ return self._merge_outputs(stage.output_chunks)
1210
+ # Does not contain the last stage
1211
+ return None
1212
+
1213
+ def _step_microbatches(
1214
+ self,
1215
+ arg_mbs: Optional[List] = None,
1216
+ kwarg_mbs: Optional[List] = None,
1217
+ target_mbs: Optional[List] = None,
1218
+ losses: Optional[List] = None,
1219
+ ):
1220
+ """
1221
+ Operate on the microbatches for looped schedules (multiple stages on each rank).
1222
+
1223
+ TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
1224
+ not support models with skip connections.
1225
+ """
1226
+ arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
1227
+
1228
+ # Based on the plan in Step 1 created in __init__:
1229
+ # 2. Perform communication based on the pipeline_order
1230
+ stage_index_to_stage: Dict[int, _PipelineStageBase] = {
1231
+ stage.stage_index: stage for stage in self._stages
1232
+ }
1233
+
1234
+ # determine prev_rank and next_rank based on which ranks are next to
1235
+ # the stages in the pipeline_order
1236
+ all_prev_ranks: Set[int] = set()
1237
+ all_next_ranks: Set[int] = set()
1238
+ for stage_index in stage_index_to_stage.keys():
1239
+ # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections)
1240
+ if stage_index > 0:
1241
+ all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1])
1242
+ if stage_index < self._num_stages - 1:
1243
+ all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1])
1244
+
1245
+ for time_step, action in enumerate(self.pipeline_order[self.rank]):
1246
+ try:
1247
+ ops: List[dist.P2POp] = []
1248
+ if action is not None:
1249
+ computation_type = action.computation_type
1250
+ mb_index = action.microbatch_index
1251
+ stage_index = action.stage_index
1252
+ assert (
1253
+ mb_index is not None
1254
+ ), "All currently supported action types require valid microbatch_index"
1255
+ if computation_type == _ComputationType.FORWARD:
1256
+ # perform forward computation
1257
+ stage = stage_index_to_stage[stage_index]
1258
+ output = stage.forward_one_chunk(
1259
+ mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
1260
+ )
1261
+ self._maybe_compute_loss(stage, output, target_mbs, mb_index)
1262
+ ops.extend(stage.get_fwd_send_ops(mb_index))
1263
+ elif computation_type == _ComputationType.BACKWARD:
1264
+ # perform backward computation
1265
+ stage = stage_index_to_stage[stage_index]
1266
+ loss = self._maybe_get_loss(stage, mb_index)
1267
+ stage.backward_one_chunk(
1268
+ mb_index, loss=loss, full_backward=self.use_full_backward
1269
+ )
1270
+ ops.extend(stage.get_bwd_send_ops(mb_index))
1271
+ elif computation_type == _ComputationType.WEIGHT:
1272
+ # perform weight update
1273
+ if self.use_full_backward:
1274
+ raise ValueError(
1275
+ f"We detected a weight update in the pipeline schedule, but \
1276
+ {self.use_full_backward=}"
1277
+ )
1278
+ stage = stage_index_to_stage[stage_index]
1279
+ stage.backward_weight_one_chunk(mb_index)
1280
+ else:
1281
+ raise ValueError(f"Unknown computation type {computation_type}")
1282
+
1283
+ # Look at the neighboring ranks for this current timestep and determine whether
1284
+ # this current rank needs to do any recv communication
1285
+ for prev_rank in all_prev_ranks:
1286
+ prev_rank_ops = self.pipeline_order[prev_rank]
1287
+ prev_rank_action = None
1288
+ if time_step < len(prev_rank_ops):
1289
+ prev_rank_action = prev_rank_ops[time_step]
1290
+ if prev_rank_action is not None:
1291
+ computation_type = prev_rank_action.computation_type
1292
+ mb_index = prev_rank_action.microbatch_index
1293
+ stage_index = prev_rank_action.stage_index
1294
+ assert (
1295
+ mb_index is not None
1296
+ ), "All currently supported action types require valid microbatch_index"
1297
+ # Only handle sends for the forward from a previous rank
1298
+ if computation_type == _ComputationType.FORWARD:
1299
+ # If not the last stage, then receive fwd activations
1300
+ if stage_index + 1 in stage_index_to_stage:
1301
+ # TODO: We are assuming that stage will always receive from stage-1
1302
+ # however that is not necessarily true of get_fwd_recv_ops
1303
+ stage = stage_index_to_stage[stage_index + 1]
1304
+ ops.extend(stage.get_fwd_recv_ops(mb_index))
1305
+ elif (
1306
+ computation_type == _ComputationType.BACKWARD
1307
+ or computation_type == _ComputationType.WEIGHT
1308
+ ):
1309
+ # Previous rank doing backward or weight update has no influence for the current rank forward recv
1310
+ pass
1311
+ else:
1312
+ raise ValueError(
1313
+ f"Unknown computation type {computation_type}"
1314
+ )
1315
+ for next_rank in all_next_ranks:
1316
+ next_rank_ops = self.pipeline_order[next_rank]
1317
+ next_rank_action = None
1318
+ if time_step < len(next_rank_ops):
1319
+ next_rank_action = next_rank_ops[time_step]
1320
+ if next_rank_action is not None:
1321
+ computation_type = next_rank_action.computation_type
1322
+ mb_index = next_rank_action.microbatch_index
1323
+ stage_index = next_rank_action.stage_index
1324
+ assert (
1325
+ mb_index is not None
1326
+ ), "All currently supported action types require valid microbatch_index"
1327
+ # Only handle receives for the backwards from a next rank
1328
+ if (
1329
+ computation_type == _ComputationType.FORWARD
1330
+ or computation_type == _ComputationType.WEIGHT
1331
+ ):
1332
+ # Next rank doing forward or weight update has no influence for the current rank backward recv
1333
+ pass
1334
+ elif computation_type == _ComputationType.BACKWARD:
1335
+ # If not the first stage, then receive bwd gradients
1336
+ if stage_index - 1 in stage_index_to_stage:
1337
+ # TODO: We are assuming that stage will always receive from stage+1
1338
+ # however that is not necessarily true of get_bwd_recv_ops
1339
+ stage = stage_index_to_stage[stage_index - 1]
1340
+ ops.extend(stage.get_bwd_recv_ops(mb_index))
1341
+ else:
1342
+ raise ValueError(
1343
+ f"Unknown computation type {computation_type}"
1344
+ )
1345
+
1346
+ # do the communication
1347
+ if ops:
1348
+ _batch_p2p(ops).wait()
1349
+ except Exception as e:
1350
+ logger.error(
1351
+ "[Rank %s] pipeline schedule %s caught the following exception \
1352
+ at time_step %s when running action %s",
1353
+ self.rank,
1354
+ self.__class__.__name__,
1355
+ time_step,
1356
+ action,
1357
+ )
1358
+ logger.error("%s", _format_pipeline_order(self.pipeline_order))
1359
+ raise e
1360
+ # Return losses if there is a container passed in
1361
+ self._update_losses(self._stages, losses)
1362
+
1363
+
1364
+ class _PipelineScheduleRuntime(PipelineScheduleMulti):
1365
+ """
1366
+ Provides a simple runtime that requires a 'schedule IR' including specified communication operations.
1367
+
1368
+ Can be instantiated directly by creating _PipelineScheduleRuntime and calling load_csv, or can be
1369
+ subclassed and the subclass can be responsible for creating a schedule IR.
1370
+ """
1371
+
1372
+ def _load_actions(
1373
+ self,
1374
+ actions: Dict[int, List[Optional[_Action]]],
1375
+ format: str = "compute_only",
1376
+ ):
1377
+ """
1378
+ Given an in-memory representation for a simple compute-only schedule, lower it to a complex schedule including
1379
+ communication actions. Stores the schedule in self, and must be called before running step_mo()
1380
+ """
1381
+ assert (
1382
+ self.stage_index_to_group_rank is not None
1383
+ ), "stage_index_to_group_rank is required for PipelineScheduleRuntime"
1384
+ self.pipeline_order_with_comms: Dict[int, List[_Action]] = {}
1385
+ if format == "compute_comms":
1386
+ for rank in actions:
1387
+ self.pipeline_order_with_comms[rank] = []
1388
+ for action in actions[rank]:
1389
+ assert action is not None
1390
+ self.pipeline_order_with_comms[rank].append(action)
1391
+ # TODO what level of validation should we offer for compute+comms schedule?
1392
+ elif format == "compute_only":
1393
+ # Perform schedule lowering
1394
+ for rank in actions:
1395
+ self.pipeline_order_with_comms[rank] = _add_unshard_reshard(
1396
+ actions[rank]
1397
+ )
1398
+
1399
+ self.pipeline_order_with_comms = _add_send_recv(
1400
+ self.pipeline_order_with_comms,
1401
+ stage_to_rank=lambda s: self.stage_index_to_group_rank[s],
1402
+ num_stages=self._num_stages,
1403
+ )
1404
+ else:
1405
+ raise NotImplementedError(f"{format=} is not implemented")
1406
+
1407
+ def _load_csv(self, filename: str, format: str = "compute_only"):
1408
+ """Loads a csv in simple format and then lowers it to include comunication actions
1409
+
1410
+ format must be either "compute_only" or "compute_comms". If compute_only, the lowering passes
1411
+ will automatically be run to generate a compute_comms schedule.
1412
+ """
1413
+ if format == "compute_only":
1414
+ # this will populate self.pipeline_order
1415
+ super()._load_csv(filename)
1416
+ # this will populate self.pipeline_order_with_comms
1417
+ self._load_actions(self.pipeline_order)
1418
+ elif format == "compute_comms":
1419
+ actions = {}
1420
+ with open(filename, newline="") as csvfile:
1421
+ reader = csv.reader(csvfile)
1422
+ for rank, row in enumerate(reader):
1423
+ actions[rank] = [_Action.from_str(s) for s in row]
1424
+ self._load_actions(actions, format=format)
1425
+ else:
1426
+ raise NotImplementedError(f"{format=} is not implemented")
1427
+
1428
+ def _dump_csv(self, filename: str):
1429
+ """Dump a CSV representation of the compute + comms schedule into a file with the provided filename."""
1430
+ # TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible
1431
+ # that it does not exist if it was created from a compute_comms schedule.
1432
+ assert (
1433
+ self.pipeline_order_with_comms is not None
1434
+ ), "Must initialize compute_comms schedule before dump_csv"
1435
+ with open(filename, "w", newline="") as csvfile:
1436
+ writer = csv.writer(csvfile)
1437
+ for rank in self.pipeline_order_with_comms:
1438
+ writer.writerow(self.pipeline_order_with_comms[rank])
1439
+
1440
+ def _step_microbatches(
1441
+ self,
1442
+ arg_mbs: Optional[List] = None,
1443
+ kwarg_mbs: Optional[List] = None,
1444
+ target_mbs: Optional[List] = None,
1445
+ losses: Optional[List] = None,
1446
+ ):
1447
+ """
1448
+ Operate on the microbatches for looped schedules (multiple stages on each rank).
1449
+
1450
+ TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
1451
+ not support models with skip connections.
1452
+ """
1453
+ arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
1454
+
1455
+ # Based on the plan in Step 1 created in __init__:
1456
+ # 2. Perform communication based on the pipeline_order
1457
+ stage_index_to_stage: Dict[int, _PipelineStageBase] = {
1458
+ stage.stage_index: stage for stage in self._stages
1459
+ }
1460
+
1461
+ assert (
1462
+ self.pipeline_order_with_comms is not None
1463
+ ), "Must call _load_actions() before calling _step_microbatches()"
1464
+
1465
+ # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
1466
+ bwd_recv_ops: Dict[Tuple[int, int], Work] = {}
1467
+ fwd_recv_ops: Dict[Tuple[int, int], Work] = {}
1468
+
1469
+ # send ops should be waited on before step() exists, mainly for hygeine
1470
+ send_ops: List[Work] = []
1471
+
1472
+ # we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages
1473
+ unshard_ops: Dict[int, UnshardHandle] = {}
1474
+ unsharded_stages = set()
1475
+
1476
+ def _assert_unsharded(stage_idx: int):
1477
+ """If an unshard is active for `stage_idx`, wait() it and mark `stage_idx` unshared."""
1478
+ if stage_idx in unshard_ops:
1479
+ unshard_ops[stage_idx].wait()
1480
+ del unshard_ops[stage_idx]
1481
+ unsharded_stages.add(stage_idx)
1482
+ assert (
1483
+ stage_idx in unsharded_stages
1484
+ ), f"Attempted to compute on sharded {stage_idx=}"
1485
+
1486
+ for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]):
1487
+ try:
1488
+ comp_type = action.computation_type
1489
+ mb_index: int = (
1490
+ action.microbatch_index
1491
+ if action.microbatch_index is not None
1492
+ else -1
1493
+ )
1494
+ assert mb_index >= 0 or comp_type in (
1495
+ UNSHARD,
1496
+ RESHARD,
1497
+ ), f"{action=} missing mb_index"
1498
+ stage_idx = action.stage_index
1499
+ stage = stage_index_to_stage[stage_idx]
1500
+ stage_uses_fsdp = isinstance(stage.submod, FSDPModule)
1501
+
1502
+ logger.debug(
1503
+ "_PipelineScheduleRuntime running time_step %d, action %s",
1504
+ time_step,
1505
+ action,
1506
+ )
1507
+
1508
+ # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections,
1509
+ # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be
1510
+ # safe to use instead.
1511
+ # However, I was wondering if I should avoid calling batched operators at all in the case that there is
1512
+ # only one operator per batch. I could iterate through the 'fwd_send_ops' one by one and run them.
1513
+ if comp_type == SEND_F:
1514
+ send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index)))
1515
+ elif comp_type == SEND_B:
1516
+ send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index)))
1517
+ elif comp_type == RECV_F:
1518
+ assert (
1519
+ stage_idx,
1520
+ mb_index,
1521
+ ) not in fwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing forward"
1522
+ fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
1523
+ stage.get_fwd_recv_ops(mb_index)
1524
+ )
1525
+ elif comp_type == RECV_B:
1526
+ assert (
1527
+ stage_idx,
1528
+ mb_index,
1529
+ ) not in bwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing backward"
1530
+ bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
1531
+ stage.get_bwd_recv_ops(mb_index)
1532
+ )
1533
+ elif comp_type == UNSHARD:
1534
+ if stage_uses_fsdp:
1535
+ assert (
1536
+ stage_idx not in unsharded_stages
1537
+ and stage_idx not in unshard_ops
1538
+ ), f"Unsharding the same {stage_idx=} twice"
1539
+ unshard_ops[stage_idx] = stage.submod.unshard(async_op=True)
1540
+ elif comp_type == RESHARD:
1541
+ if stage_uses_fsdp:
1542
+ assert (
1543
+ stage_idx in unsharded_stages
1544
+ ), f"Resharding {stage_idx=} without unsharding"
1545
+ assert (
1546
+ stage_idx not in unshard_ops
1547
+ ), f"Resharding {stage_idx=} before finishing unshard"
1548
+ stage.submod.reshard()
1549
+ elif comp_type == FORWARD:
1550
+ if stage_uses_fsdp:
1551
+ _assert_unsharded(stage_idx)
1552
+
1553
+ if not stage.is_first:
1554
+ assert (
1555
+ stage_idx,
1556
+ mb_index,
1557
+ ) in fwd_recv_ops, f"Computing {action=} before receiving input"
1558
+ fwd_recv_ops.pop((stage_idx, mb_index)).wait()
1559
+ output = stage.forward_one_chunk(
1560
+ mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
1561
+ )
1562
+ self._maybe_compute_loss(stage, output, target_mbs, mb_index)
1563
+ elif comp_type == BACKWARD:
1564
+ if stage_uses_fsdp:
1565
+ _assert_unsharded(stage_idx)
1566
+
1567
+ if not stage.is_last:
1568
+ assert (
1569
+ stage_idx,
1570
+ mb_index,
1571
+ ) in bwd_recv_ops, (
1572
+ f"Attempted to run compute {action=} before receiving input"
1573
+ )
1574
+ bwd_recv_ops.pop((stage_idx, mb_index)).wait()
1575
+ loss = self._maybe_get_loss(stage, mb_index)
1576
+ stage.backward_one_chunk(
1577
+ mb_index, loss=loss, full_backward=self.use_full_backward
1578
+ )
1579
+ elif comp_type == WEIGHT:
1580
+ if stage_uses_fsdp:
1581
+ _assert_unsharded(stage_idx)
1582
+
1583
+ if self.use_full_backward:
1584
+ raise ValueError(
1585
+ f"We detected a weight update in the pipeline schedule, but \
1586
+ {self.use_full_backward=}"
1587
+ )
1588
+ stage.backward_weight_one_chunk(mb_index)
1589
+ else:
1590
+ raise ValueError(f"{action=} is unknown or unsupported")
1591
+ except Exception as e:
1592
+ logger.error(
1593
+ "_PipelineScheduleRuntime caught exception at step %s when running action %s. Full Schedule:",
1594
+ time_step,
1595
+ action,
1596
+ )
1597
+ # TODO(whc) what is the best practice for printing a multiline log?
1598
+ # logger will split it into multiple log lines, but this makes it hard to read (too wide)
1599
+ print(_format_pipeline_order(self.pipeline_order_with_comms)) # type: ignore[arg-type]
1600
+ raise e
1601
+
1602
+ # Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them
1603
+ while len(send_ops):
1604
+ send_ops.pop().wait()
1605
+
1606
+ assert len(unshard_ops) == 0, "Unused unshard operations"
1607
+
1608
+ # Return losses if there is a container passed in
1609
+ self._update_losses(self._stages, losses)
1610
+
1611
+
1612
+ class ScheduleLoopedBFS(PipelineScheduleMulti):
1613
+ """
1614
+ Breadth-First Pipeline Parallelism.
1615
+ See https://arxiv.org/abs/2211.05953 for details.
1616
+ Simliar to Interleaved 1F1B, Looped BFS supports multiple stages per rank.
1617
+ What is different is that when microbatches are ready for multiple local
1618
+ stages, Loops BFS will prioritizes the earlier stage, running all available
1619
+ microbatches at once.
1620
+ """
1621
+
1622
+ def __init__(
1623
+ self,
1624
+ stages: List[_PipelineStageBase],
1625
+ n_microbatches: int,
1626
+ loss_fn: Optional[Callable] = None,
1627
+ output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
1628
+ ):
1629
+ super().__init__(
1630
+ stages=stages,
1631
+ n_microbatches=n_microbatches,
1632
+ loss_fn=loss_fn,
1633
+ output_merge_spec=output_merge_spec,
1634
+ )
1635
+
1636
+ # 1. Create the pipeline_order (all ranks do this calculation)
1637
+ # This will be used to keep track of the current state of the entire pipeline
1638
+ # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
1639
+ self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
1640
+ # ========================================================================
1641
+ for rank in range(self.pp_group_size):
1642
+ rank_ops = self._calculate_single_rank_operations(rank)
1643
+ self.pipeline_order[rank] = rank_ops
1644
+
1645
+ def _calculate_single_rank_operations(self, rank):
1646
+ n_local_stages = len(self._stages)
1647
+ stage_indices = range(
1648
+ rank, self.pp_group_size * n_local_stages, self.pp_group_size
1649
+ )
1650
+
1651
+ # Store the list of operations used for that rank
1652
+ rank_ops: List[Optional[_Action]] = []
1653
+ # Pre-padding, rank starts with no-ops based on the warmup.
1654
+ for _ in range(rank):
1655
+ rank_ops.append(None)
1656
+
1657
+ for stage_index in stage_indices:
1658
+ for mb_index in range(self._n_microbatches):
1659
+ rank_ops.append(
1660
+ _Action(stage_index, _ComputationType.FORWARD, mb_index)
1661
+ )
1662
+
1663
+ # wait for the first backward to trickle up
1664
+ # which is 2 for every hop away
1665
+ post_warmup_ops = 2 * (self.pp_group_size - 1 - rank)
1666
+ rank_ops.extend([None] * post_warmup_ops)
1667
+
1668
+ for stage_index in reversed(stage_indices):
1669
+ for mb_index in reversed(range(self._n_microbatches)):
1670
+ rank_ops.append(
1671
+ _Action(stage_index, _ComputationType.BACKWARD, mb_index)
1672
+ )
1673
+ return rank_ops
1674
+
1675
+
1676
+ def _get_1f1b_rank_ops(
1677
+ n_local_stages,
1678
+ pp_group_size,
1679
+ warmup_ops,
1680
+ fwd_bwd_ops,
1681
+ cooldown_ops,
1682
+ rank,
1683
+ forward_stage_index,
1684
+ backward_stage_index,
1685
+ num_1f1b_microbatches=0,
1686
+ enable_zero_bubble=False,
1687
+ ):
1688
+ # All stages start with handling microbatch 0
1689
+ fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
1690
+ bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
1691
+ weight_stage_mb_index: Dict[int, int] = defaultdict(int)
1692
+
1693
+ # Store the list of operations used for that rank
1694
+ rank_ops: List[Optional[_Action]] = []
1695
+ # Pre-padding, rank starts with no-ops based on the warmup.
1696
+ for _ in range(rank):
1697
+ rank_ops.append(None)
1698
+ # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
1699
+ # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
1700
+ # Formula:
1701
+ # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
1702
+ # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
1703
+ # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
1704
+ # warmup_ops = calculated above
1705
+ post_warmup_ops = (
1706
+ n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
1707
+ ) - (warmup_ops + rank)
1708
+
1709
+ if enable_zero_bubble:
1710
+ post_warmup_ops = pp_group_size - rank - 1
1711
+
1712
+ total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
1713
+
1714
+ backward_op_ids = []
1715
+ weight_op_count = 0
1716
+
1717
+ for op in range(total_ops):
1718
+ # Warmup phase
1719
+ if op < warmup_ops:
1720
+ fwd_stage_index = forward_stage_index(op)
1721
+ # This will assign the current microbatch index and update it as well
1722
+ fwd_stage_mb_index[fwd_stage_index] = (
1723
+ mb_index := fwd_stage_mb_index[fwd_stage_index]
1724
+ ) + 1
1725
+ rank_ops.append(
1726
+ _Action(fwd_stage_index, _ComputationType.FORWARD, mb_index)
1727
+ )
1728
+ if op == warmup_ops - 1:
1729
+ # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
1730
+ rank_ops.extend([None] * post_warmup_ops)
1731
+ # 1F1B Phase (forward and backward)
1732
+ elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
1733
+ fwd_stage_index = forward_stage_index(op)
1734
+ fwd_stage_mb_index[fwd_stage_index] = (
1735
+ fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
1736
+ ) + 1
1737
+ rank_ops.append(
1738
+ _Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index)
1739
+ )
1740
+ bwd_stage_index = backward_stage_index(op)
1741
+ bwd_stage_mb_index[bwd_stage_index] = (
1742
+ bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
1743
+ ) + 1
1744
+ rank_ops.append(
1745
+ _Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
1746
+ )
1747
+ backward_op_ids.append(op)
1748
+
1749
+ if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
1750
+ weight_stage_index = backward_stage_index(
1751
+ backward_op_ids[weight_op_count]
1752
+ )
1753
+ weight_stage_mb_index[weight_stage_index] = (
1754
+ weight_mb_index := weight_stage_mb_index[weight_stage_index]
1755
+ ) + 1
1756
+ rank_ops.append(
1757
+ _Action(
1758
+ weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
1759
+ )
1760
+ )
1761
+ weight_op_count += 1
1762
+ # Cooldown phase
1763
+ else:
1764
+ # During cooldown phase, we need steps to align with 1f1b happening in other ranks
1765
+ # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
1766
+ if not enable_zero_bubble:
1767
+ rank_ops.append(None)
1768
+
1769
+ bwd_stage_index = backward_stage_index(op)
1770
+ bwd_stage_mb_index[bwd_stage_index] = (
1771
+ bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
1772
+ ) + 1
1773
+ rank_ops.append(
1774
+ _Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
1775
+ )
1776
+ backward_op_ids.append(op)
1777
+
1778
+ if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
1779
+ weight_stage_index = backward_stage_index(
1780
+ backward_op_ids[weight_op_count]
1781
+ )
1782
+ weight_stage_mb_index[weight_stage_index] = (
1783
+ weight_mb_index := weight_stage_mb_index[weight_stage_index]
1784
+ ) + 1
1785
+ rank_ops.append(
1786
+ _Action(
1787
+ weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
1788
+ )
1789
+ )
1790
+ weight_op_count += 1
1791
+
1792
+ while enable_zero_bubble and weight_op_count < len(backward_op_ids):
1793
+ weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count])
1794
+ weight_stage_mb_index[weight_stage_index] = (
1795
+ weight_mb_index := weight_stage_mb_index[weight_stage_index]
1796
+ ) + 1
1797
+ rank_ops.append(
1798
+ _Action(weight_stage_index, _ComputationType.WEIGHT, weight_mb_index)
1799
+ )
1800
+ weight_op_count += 1
1801
+
1802
+ return rank_ops
1803
+
1804
+
1805
+ class ScheduleInterleaved1F1B(PipelineScheduleMulti):
1806
+ """
1807
+ The Interleaved 1F1B schedule.
1808
+ See https://arxiv.org/pdf/2104.04473 for details.
1809
+ Will perform one forward and one backward on the microbatches in steady
1810
+ state and supports multiple stages per rank. When microbatches are ready for
1811
+ multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch
1812
+ (also called "depth first").
1813
+ """
1814
+
1815
+ def __init__(
1816
+ self,
1817
+ stages: List[_PipelineStageBase],
1818
+ n_microbatches: int,
1819
+ loss_fn: Optional[Callable] = None,
1820
+ args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
1821
+ kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
1822
+ output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
1823
+ ):
1824
+ self.pp_group_size = stages[0].group_size
1825
+ # TODO: is this limitation a must?
1826
+ if n_microbatches % self.pp_group_size != 0:
1827
+ raise ValueError(
1828
+ f"Interleaved 1F1B schedule requires the number of microbatches ({n_microbatches}) \
1829
+ to be a multiple of the number of pipeline ranks ({self.pp_group_size})."
1830
+ )
1831
+
1832
+ super().__init__(
1833
+ stages=stages,
1834
+ n_microbatches=n_microbatches,
1835
+ loss_fn=loss_fn,
1836
+ args_chunk_spec=args_chunk_spec,
1837
+ kwargs_chunk_spec=kwargs_chunk_spec,
1838
+ output_merge_spec=output_merge_spec,
1839
+ )
1840
+
1841
+ self.n_local_stages = len(stages)
1842
+ self.rank = stages[0].group_rank
1843
+ self.group = stages[0].group
1844
+
1845
+ # 1. Create the pipeline_order (all ranks do this calculation)
1846
+ # This will be used to keep track of the current state of the entire pipeline
1847
+ # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
1848
+ self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
1849
+
1850
+ for rank in range(self.pp_group_size):
1851
+ rank_ops = self._calculate_single_rank_operations(rank)
1852
+ self.pipeline_order[rank] = rank_ops
1853
+
1854
+ def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
1855
+ def get_rank_warmup_ops(rank):
1856
+ # Warms up operations for last stage
1857
+ warmups_ops_last_stage = (self.n_local_stages - 1) * self.pp_group_size
1858
+ # Increment warmup operations by 2 for each hop away from the last stage
1859
+ warmup_ops = warmups_ops_last_stage + 2 * ((self.pp_group_size - 1) - rank)
1860
+ # We cannot have more warmup operations than there are number of microbatches, so cap it there
1861
+ return min(warmup_ops, self._n_microbatches * self.n_local_stages)
1862
+
1863
+ warmup_ops = get_rank_warmup_ops(rank)
1864
+ microbatch_ops = self.n_local_stages * self._n_microbatches
1865
+ # fwd_bwd_ops should encompass the remaining forwards
1866
+ fwd_bwd_ops = microbatch_ops - warmup_ops
1867
+ # cooldown_ops should encompass the remaining backwards
1868
+ cooldown_ops = microbatch_ops - fwd_bwd_ops
1869
+ # total ops encompass both forward and backward ops
1870
+ total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
1871
+ # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
1872
+
1873
+ logger.debug(
1874
+ "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
1875
+ rank,
1876
+ warmup_ops,
1877
+ fwd_bwd_ops,
1878
+ cooldown_ops,
1879
+ total_ops,
1880
+ )
1881
+
1882
+ # Calculates the stage index based on step and pp_group_size
1883
+ def forward_stage_index(step):
1884
+ # Get the local index from 0 to n_local_stages-1
1885
+ local_index = (step // self.pp_group_size) % self.n_local_stages
1886
+ return (local_index * self.pp_group_size) + rank
1887
+
1888
+ def backward_stage_index(step):
1889
+ local_index = (
1890
+ self.n_local_stages
1891
+ - 1
1892
+ - ((step - warmup_ops) // self.pp_group_size) % self.n_local_stages
1893
+ )
1894
+ return (local_index * self.pp_group_size) + rank
1895
+
1896
+ return _get_1f1b_rank_ops(
1897
+ self.n_local_stages,
1898
+ self.pp_group_size,
1899
+ warmup_ops,
1900
+ fwd_bwd_ops,
1901
+ cooldown_ops,
1902
+ rank,
1903
+ forward_stage_index,
1904
+ backward_stage_index,
1905
+ )
1906
+
1907
+
1908
+ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
1909
+ """
1910
+ The Flexible Interleaved 1F1B schedule.
1911
+
1912
+ This schedule is mostly similar to the interleaved 1F1B schedule.
1913
+ It differs by being relaxing the requirement of num_microbatch % pp_size == 0.
1914
+ Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and
1915
+ it works as long as n_microbatches % num_rounds is 0. As a few examples, support
1916
+
1917
+ 1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
1918
+ 2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.
1919
+
1920
+ When enable_zero_bubble is True, we will use the ZB1P schedule in https://openreview.net/pdf?id=tuzTN0eIO5
1921
+ """
1922
+
1923
+ def __init__(
1924
+ self,
1925
+ stages: List[_PipelineStageBase],
1926
+ n_microbatches: int,
1927
+ loss_fn: Optional[Callable] = None,
1928
+ args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
1929
+ kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
1930
+ output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
1931
+ enable_zero_bubble: bool = False,
1932
+ ):
1933
+ self.pp_group_size = stages[0].group_size
1934
+ super().__init__(
1935
+ stages=stages,
1936
+ n_microbatches=n_microbatches,
1937
+ loss_fn=loss_fn,
1938
+ args_chunk_spec=args_chunk_spec,
1939
+ kwargs_chunk_spec=kwargs_chunk_spec,
1940
+ output_merge_spec=output_merge_spec,
1941
+ use_full_backward=not enable_zero_bubble,
1942
+ )
1943
+ self.n_local_stages = len(stages)
1944
+ self.rank = stages[0].group_rank
1945
+ self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
1946
+ self.microbatches_per_round = n_microbatches // self.number_of_rounds
1947
+ self.enable_zero_bubble = enable_zero_bubble
1948
+ if n_microbatches % self.number_of_rounds != 0:
1949
+ raise ValueError(
1950
+ "Flexible Interleaved 1F1B requires the number of microbatches to be a "
1951
+ f"multiple of the number of rounds ({self.number_of_rounds}), "
1952
+ f"but got {n_microbatches}."
1953
+ )
1954
+ # 1. Create the pipeline_order (all ranks do this calculation)
1955
+ # This will be used to keep track of the current state of the entire pipeline
1956
+ # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
1957
+ self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
1958
+ for rank in range(self.pp_group_size):
1959
+ rank_ops = self._calculate_single_rank_operations(rank)
1960
+ self.pipeline_order[rank] = rank_ops
1961
+
1962
+ # This function add bubbles to the generated schedule based on dependencies of actions
1963
+ # Note that the ZB1P schedule will not require bubbles to be manually added and it is
1964
+ # only useful when n_microbatches <= microbatches_per_round
1965
+ self.pipeline_order = self._add_bubbles_to_actions(
1966
+ self.n_local_stages * self.pp_group_size,
1967
+ )
1968
+
1969
+ def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
1970
+ def get_rank_warmup_ops(rank):
1971
+ # Warms up operations for last stage
1972
+ warmups_ops_last_stage = (
1973
+ self.n_local_stages - 1
1974
+ ) * self.microbatches_per_round
1975
+ # Increment warmup operations by 2 for each hop away from the last stage
1976
+ multiply_factor = 1 if self.enable_zero_bubble else 2
1977
+ warmup_ops = warmups_ops_last_stage + multiply_factor * (
1978
+ (self.pp_group_size - 1) - rank
1979
+ )
1980
+
1981
+ # We cannot have more warmup operations than there are number of microbatches, so cap it there
1982
+ return min(warmup_ops, self._n_microbatches * self.n_local_stages)
1983
+
1984
+ warmup_ops = get_rank_warmup_ops(rank)
1985
+ microbatch_ops = self.n_local_stages * self._n_microbatches
1986
+ # fwd_bwd_ops should encompass the remaining forwards
1987
+ fwd_bwd_ops = microbatch_ops - warmup_ops
1988
+ # cooldown_ops should encompass the remaining backwards
1989
+ cooldown_ops = microbatch_ops - fwd_bwd_ops
1990
+ # total ops encompass both forward and backward ops
1991
+ total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
1992
+ # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
1993
+ logger.debug(
1994
+ "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
1995
+ rank,
1996
+ warmup_ops,
1997
+ fwd_bwd_ops,
1998
+ cooldown_ops,
1999
+ total_ops,
2000
+ )
2001
+
2002
+ # Calculates the stage index based on step and pp_group_size
2003
+
2004
+ def forward_stage_index(step):
2005
+ # Get the local index from 0 to n_local_stages-1
2006
+ local_index = (step // self.microbatches_per_round) % self.n_local_stages
2007
+ return (local_index * self.pp_group_size) + rank
2008
+
2009
+ def backward_stage_index(step):
2010
+ local_index = (
2011
+ self.n_local_stages
2012
+ - 1
2013
+ - ((step - warmup_ops) // self.microbatches_per_round)
2014
+ % self.n_local_stages
2015
+ )
2016
+ return (local_index * self.pp_group_size) + rank
2017
+
2018
+ if self.enable_zero_bubble:
2019
+ num_1f1b_microbatches = rank
2020
+
2021
+ return _get_1f1b_rank_ops(
2022
+ self.n_local_stages,
2023
+ self.pp_group_size,
2024
+ warmup_ops,
2025
+ fwd_bwd_ops,
2026
+ cooldown_ops,
2027
+ rank,
2028
+ forward_stage_index,
2029
+ backward_stage_index,
2030
+ num_1f1b_microbatches,
2031
+ enable_zero_bubble=True,
2032
+ )
2033
+
2034
+ return _get_1f1b_rank_ops(
2035
+ self.n_local_stages,
2036
+ self.pp_group_size,
2037
+ warmup_ops,
2038
+ fwd_bwd_ops,
2039
+ cooldown_ops,
2040
+ rank,
2041
+ forward_stage_index,
2042
+ backward_stage_index,
2043
+ )
2044
+
2045
+ def _add_bubbles_to_actions(self, num_stages_global):
2046
+ actions = self.pipeline_order
2047
+ if not self.enable_zero_bubble:
2048
+ return actions
2049
+
2050
+ def need_bubble(stage, op, microbatch, num_stages_global, seen_ops):
2051
+ if op == _ComputationType.FORWARD:
2052
+ if stage != 0 and (stage - 1, op, microbatch) not in seen_ops:
2053
+ return True
2054
+ elif op == _ComputationType.BACKWARD:
2055
+ if stage == num_stages_global - 1:
2056
+ return (stage, _ComputationType.FORWARD, microbatch) not in seen_ops
2057
+ return (stage + 1, op, microbatch) not in seen_ops
2058
+ return False
2059
+
2060
+ seen_ops: Set[Tuple[int, _ComputationType, int]] = set()
2061
+ result: Dict[int, List[Optional[_Action]]] = {}
2062
+ next_pointer: Dict[int, int] = {}
2063
+ bubbles_added: Dict[int, int] = {}
2064
+ total_bubbles_added = 0
2065
+
2066
+ for rank in range(self.pp_group_size):
2067
+ result[rank] = []
2068
+ next_pointer[rank] = 0
2069
+ bubbles_added[rank] = 0
2070
+
2071
+ while True:
2072
+ should_stop = True
2073
+
2074
+ temp_seen_ops: Set[Tuple[int, _ComputationType, int]] = set()
2075
+
2076
+ for rank in range(self.pp_group_size):
2077
+ timestamp = next_pointer[rank]
2078
+ if timestamp >= len(actions[rank]):
2079
+ continue
2080
+
2081
+ should_stop = False
2082
+
2083
+ if actions[rank][timestamp] is not None:
2084
+ temp_action = actions[rank][timestamp]
2085
+ assert temp_action is not None
2086
+ stage_index, op, microbatch = temp_action
2087
+ if not need_bubble(
2088
+ stage_index, op, microbatch, num_stages_global, seen_ops
2089
+ ):
2090
+ result[rank].append(actions[rank][timestamp])
2091
+ if microbatch is not None:
2092
+ temp_seen_ops.add((stage_index, op, microbatch))
2093
+ next_pointer[rank] += 1
2094
+ else:
2095
+ result[rank].append(None)
2096
+ bubbles_added[rank] += 1
2097
+ else:
2098
+ next_pointer[rank] += 1
2099
+ result[rank].append(None)
2100
+
2101
+ seen_ops.update(temp_seen_ops)
2102
+ if should_stop:
2103
+ break
2104
+
2105
+ if total_bubbles_added > 0:
2106
+ logger.warning(
2107
+ "Non zero bubbles added: total_bubbles_added=%s bubbles_added=%s",
2108
+ total_bubbles_added,
2109
+ bubbles_added,
2110
+ )
2111
+ return result
2112
+
2113
+
2114
+ class ScheduleInterleavedZeroBubble(ScheduleFlexibleInterleaved1F1B):
2115
+ """
2116
+ The Interleaved Zero Bubble schedule.
2117
+ See https://arxiv.org/pdf/2401.10241 for details.
2118
+ Will perform one forward and one backward on inputs for the microbatches in steady
2119
+ state and supports multiple stages per rank. Uses the backward for weights to fill in
2120
+ the pipeline bubble.
2121
+ """
2122
+
2123
+ def __init__(
2124
+ self,
2125
+ stages: List[_PipelineStageBase],
2126
+ n_microbatches: int,
2127
+ loss_fn: Optional[Callable] = None,
2128
+ args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
2129
+ kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
2130
+ output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
2131
+ ):
2132
+ super().__init__(
2133
+ stages=stages,
2134
+ n_microbatches=n_microbatches,
2135
+ loss_fn=loss_fn,
2136
+ args_chunk_spec=args_chunk_spec,
2137
+ kwargs_chunk_spec=kwargs_chunk_spec,
2138
+ output_merge_spec=output_merge_spec,
2139
+ enable_zero_bubble=True,
2140
+ )
2141
+
2142
+
2143
+ def get_schedule_class(schedule_name: str):
2144
+ """
2145
+ Maps a schedule name to its corresponding class object.
2146
+
2147
+ Args:
2148
+ schedule_name (str): The name of the schedule.
2149
+ """
2150
+ schedule_map = {
2151
+ "1F1B": Schedule1F1B,
2152
+ "Interleaved1F1B": ScheduleInterleaved1F1B,
2153
+ "GPipe": ScheduleGPipe,
2154
+ "FlexibleInterleaved1F1B": ScheduleFlexibleInterleaved1F1B,
2155
+ "LoopedBFS": ScheduleLoopedBFS,
2156
+ "InterleavedZeroBubble": ScheduleInterleavedZeroBubble,
2157
+ "PipelineScheduleSingle": PipelineScheduleSingle,
2158
+ "PipelineScheduleMulti": PipelineScheduleMulti,
2159
+ }
2160
+ if schedule_name not in schedule_map:
2161
+ raise ValueError(f"Unknown schedule name: {schedule_name}")
2162
+ return schedule_map[schedule_name]
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/stage.py ADDED
@@ -0,0 +1,1468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates
3
+ import logging
4
+ import operator
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ import torch.fx as fx
11
+ import torch.nn as nn
12
+ from torch._subclasses.fake_tensor import FakeTensor
13
+ from torch.distributed._composable.fsdp.fully_shard import FSDPModule, fully_shard
14
+ from torch.fx.node import map_aggregate
15
+ from torch.nn.parallel import DistributedDataParallel
16
+
17
+ from ._backward import stage_backward, stage_backward_input, stage_backward_weight
18
+ from ._debug import map_debug_info
19
+ from ._utils import flatten_args, PipeInfo, validate_tensors_metadata
20
+
21
+
22
+ __all__ = [
23
+ "PipelineStage",
24
+ "build_stage",
25
+ ]
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ class _RootArgPlaceholder:
31
+ """
32
+ Placeholder for model-level inputs.
33
+ """
34
+
35
+ def __init__(self, tensor):
36
+ self.meta = tensor.to("meta")
37
+
38
+
39
+ class _RecvInfo:
40
+ """
41
+ Represents a stage input.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ input_name: str,
47
+ source: int,
48
+ buffer: torch.Tensor,
49
+ ):
50
+ # Name of this input
51
+ self.input_name = input_name
52
+ # Stage index of the source of this input
53
+ self.source = source
54
+ # Buffer to receive the input into.
55
+ self.buffer = buffer
56
+
57
+ def __repr__(self):
58
+ return f"_RecvInfo(input={self.input_name}, source={self.source}, shape={self.buffer.size()})"
59
+
60
+
61
+ # An input can be either a received activation or a model input
62
+ InputInfo = Union[_RecvInfo, _RootArgPlaceholder]
63
+
64
+
65
+ def _make_tensor_from_meta(
66
+ example: Union[torch.Tensor, FakeTensor],
67
+ device: torch.device,
68
+ ) -> torch.Tensor:
69
+ """
70
+ Create a real tensor from a tensor.
71
+ """
72
+ return torch.empty(
73
+ example.size(),
74
+ dtype=example.dtype,
75
+ layout=example.layout,
76
+ device=device,
77
+ )
78
+
79
+
80
+ class _PipelineStageBase(ABC):
81
+ """
82
+ Base class for pipeline stages.
83
+ Defines or implements common methods used by the `_PipelineStage` used by
84
+ the tracing frontend and `PipelineStage` used by manual frontend.
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ submodule: torch.nn.Module,
90
+ stage_index: int,
91
+ num_stages: int,
92
+ device: torch.device,
93
+ group: Optional[dist.ProcessGroup] = None,
94
+ dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
95
+ ):
96
+ """
97
+ Args:
98
+ submodule (torch.nn.Module): The module to be executed in this stage.
99
+ stage_index (int): The index of this stage.
100
+ num_stages (int): The total number of stages in this pipeline.
101
+ device (torch.device): The device to run this stage on.
102
+ group (Optional[dist.ProcessGroup]): The process group to use for communication.
103
+ If `None`, the default process group will be used.
104
+ Default: `None`.
105
+ dw_builder (Optional[Callable[[], Callable[..., None]]): If provided, dw_runner is a builder function
106
+ that will build a new dw_runner function that will run parts of module backward that were intentionally
107
+ skipped during the module's actual backward pass. The builder must be invoked by stage after stage runs
108
+ model backwards, and stage should save the latest dw_runner to run during weight pass.
109
+ If not provided, a dw_runner will be generated automatically by traversing the autograd graph.
110
+ When used with schedules that only have F and B steps, the fresh dw_runner function will be called as
111
+ part of B.
112
+ When used with F,B,W schedules, the dw_runner function implements 'W'.
113
+ """
114
+ super().__init__()
115
+ if stage_index >= num_stages:
116
+ raise ValueError(
117
+ f"Stage index {stage_index} is out of range of {num_stages}"
118
+ )
119
+
120
+ self.submod = submodule
121
+ self.stage_index = stage_index
122
+ self.num_stages = num_stages
123
+ self.device = device
124
+ self.group = group
125
+
126
+ self.dw_builder = dw_builder
127
+
128
+ # backward state
129
+ self.backward_state: Dict[int, Tuple[Any, ...]] = {}
130
+
131
+ # store dw_runner per microbatch_id
132
+ self.dw_runner: Dict[int, Callable[..., None]] = {}
133
+
134
+ # `group_rank` is rank in process group `group`.
135
+ self.group_rank = dist.get_rank(self.group)
136
+ self.group_size = dist.get_world_size(self.group)
137
+ if self.group_size > self.num_stages:
138
+ raise RuntimeError(
139
+ f"Pipeline group size {self.group_size} cannot be larger than number of stages {self.num_stages}"
140
+ )
141
+
142
+ # Run time states
143
+ self._outputs_meta: Optional[Tuple[torch.Tensor, ...]] = None
144
+ # map microbatch ID to list of forward tensor args
145
+ self.fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {}
146
+ # Caching chunk outputs for final output merge or reduction
147
+ self.output_chunks: List[Any] = []
148
+
149
+ # Initialize has_backward to false; this will be set to true if loss
150
+ # function is passed to pipeline schedule
151
+ self.has_backward = False
152
+ # Log prefix
153
+ self.log_prefix = f"[Stage {self.stage_index}]"
154
+
155
+ # Forward infra
156
+ self.args_recv_info: Dict[int, Tuple[InputInfo, ...]] = {}
157
+ self.set_requires_grad: Dict[int, bool] = {}
158
+ self.act_send_info: Dict[int, List] = {}
159
+
160
+ # Backward infra will created lazily
161
+ self.grad_recv_info: Dict = {}
162
+ self.grad_send_info: Optional[List] = None
163
+
164
+ # Number of backward chunks seen. This is used to determine when to do
165
+ # grad reduction in DDP or FSDP.
166
+ self._seen_bwd_chunks = 0
167
+
168
+ # To be populated later by the Schedule
169
+ self.chunks: Optional[int] = None
170
+ self.stage_index_to_group_rank: Dict[int, int] = {
171
+ i: i % self.group_size for i in range(self.num_stages)
172
+ }
173
+
174
+ @property
175
+ def has_backward(self) -> bool:
176
+ """
177
+ Returns true if this stage has a backward pass.
178
+ """
179
+ return self._has_backward
180
+
181
+ @has_backward.setter
182
+ def has_backward(self, has_backward: bool):
183
+ self._has_backward = has_backward
184
+
185
+ @property
186
+ def is_first(self):
187
+ """
188
+ Returns true if this stage is the first stage in the pipeline.
189
+ """
190
+ return self.stage_index == 0
191
+
192
+ @property
193
+ def is_last(self):
194
+ """
195
+ Returns true if this stage is the last stage in the pipeline.
196
+ """
197
+ return self.stage_index == self.num_stages - 1
198
+
199
+ def _check_chunk_id(self, chunk_id: int):
200
+ if self.chunks is None:
201
+ raise RuntimeError(
202
+ "Attempted to access chunk_id before chunks have been configured."
203
+ )
204
+ if chunk_id >= self.chunks:
205
+ raise RuntimeError(
206
+ f"Chunk id {chunk_id} is out of range [0, {self.chunks})"
207
+ )
208
+
209
+ def _configure_outputs_meta(self, outputs_meta: Tuple[torch.Tensor, ...]):
210
+ """
211
+ Track the output shapes/dtype of this stage since they determine the send operation(s) which must match
212
+ recv operations of the next stage. The next stage _will_ be freezing its recv buffers based on its initial
213
+ configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches
214
+ which could show up as hangs, silent corruption, or other errors.
215
+ """
216
+ assert (
217
+ self._outputs_meta is None
218
+ ), "Attempting to reconfigure output_meta, which is not supported"
219
+ self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment]
220
+
221
+ def get_outputs_meta(self) -> Tuple[torch.Tensor, ...]:
222
+ """Get the output metadata (meta tensors) reprensenting the outputs of this stage"""
223
+ assert (
224
+ self._outputs_meta is not None
225
+ ), "Attempted to get_outputs_meta() without configuring output meta"
226
+ return self._outputs_meta
227
+
228
+ def _create_grad_send_info(
229
+ self,
230
+ args_recv_info: Tuple,
231
+ ) -> List[Optional[int]]:
232
+ """
233
+ Create a list of stage indices to send gradients to.
234
+ """
235
+ grad_send_info: List[Optional[int]] = []
236
+
237
+ def map_recv_to_send(a):
238
+ # Note: we send gradients back to previous stage as long as in
239
+ # forward it is a received input, regardless of whether it requires
240
+ # grad. It is up to the previous stage to disgard this gradient.
241
+ if isinstance(a, _RecvInfo):
242
+ grad_send_info.append(a.source)
243
+ return a.source
244
+ else:
245
+ grad_send_info.append(None)
246
+ return None
247
+
248
+ map_aggregate(args_recv_info, map_recv_to_send)
249
+
250
+ logger.debug("%s Grad send info: %s", self.log_prefix, grad_send_info)
251
+ return grad_send_info
252
+
253
+ @abstractmethod
254
+ def _prepare_forward_infra(self, num_microbatches: int):
255
+ raise NotImplementedError
256
+
257
+ def _prepare_backward_infra(self, num_microbatches: int):
258
+ # TODO: this is needed for backward_maybe_with_nosync
259
+ self.chunks = num_microbatches
260
+
261
+ for mb_index in range(num_microbatches):
262
+ # `grad_recv_info` is a mirror of `act_send_info`
263
+ self.grad_recv_info[mb_index] = self._create_grad_recv_info(
264
+ self.act_send_info
265
+ )
266
+
267
+ @abstractmethod
268
+ def _create_grad_recv_info(
269
+ self,
270
+ act_send_info: Dict,
271
+ ) -> Tuple[_RecvInfo, ...]:
272
+ raise NotImplementedError
273
+
274
+ def _get_recv_ops(
275
+ self,
276
+ recv_infos: Tuple[InputInfo, ...],
277
+ ) -> List[dist.P2POp]:
278
+ """
279
+ Helper function shared by `get_fwd_recv_ops` and `get_bwd_recv_ops`.
280
+ Returns a list of ops that correspond to the recv infos.
281
+ """
282
+ ops: List[dist.P2POp] = []
283
+ for info in recv_infos:
284
+ if not isinstance(info, _RecvInfo):
285
+ continue
286
+
287
+ peer_rank = self.stage_index_to_group_rank[info.source]
288
+ peer_global_rank = (
289
+ peer_rank
290
+ if self.group is None
291
+ else dist.get_global_rank(self.group, peer_rank)
292
+ ) # TODO
293
+ ops.append(
294
+ dist.P2POp(dist.irecv, info.buffer, peer_global_rank, self.group)
295
+ )
296
+
297
+ return ops
298
+
299
+ def get_fwd_recv_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]:
300
+ """
301
+ Returns a list of ops that are needed to receive the input arguments
302
+ for this stage.
303
+ """
304
+ recv_infos: Tuple[InputInfo, ...] = self.args_recv_info[fwd_chunk_id]
305
+
306
+ # In case there is backward pass, set requires_grad for receive buffers
307
+ # before first forward
308
+ if self.has_backward and not self.set_requires_grad[fwd_chunk_id]:
309
+ for a in recv_infos:
310
+ if isinstance(a, _RecvInfo):
311
+ a.buffer.requires_grad_(True)
312
+
313
+ return self._get_recv_ops(recv_infos)
314
+
315
+ def get_bwd_recv_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]:
316
+ """
317
+ Returns a list of ops that are needed to receive the gradients
318
+ for this stage.
319
+ """
320
+ if not self.has_backward or self.is_last:
321
+ return []
322
+
323
+ recv_infos = self.grad_recv_info[bwd_chunk_id]
324
+ return self._get_recv_ops(recv_infos)
325
+
326
+ def get_fwd_send_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]:
327
+ """
328
+ Get the activation send ops for current stage's forward.
329
+ """
330
+ output = self.output_chunks[fwd_chunk_id]
331
+ # Unify output form to tuple for easy correspondance with
332
+ # `act_send_info`
333
+ output_tuple = output if type(output) is tuple else (output,)
334
+
335
+ ops: List[dist.P2POp] = []
336
+
337
+ for idx, out in enumerate(output_tuple):
338
+ dst_stages = self.act_send_info[idx]
339
+ for dst in dst_stages:
340
+ if dst is None:
341
+ continue
342
+ logger.debug(
343
+ "%s Sending tensor to Stage %s: %s",
344
+ self.log_prefix,
345
+ dst,
346
+ out.size(),
347
+ )
348
+ peer_rank = self.stage_index_to_group_rank[dst]
349
+ peer_global_rank = (
350
+ peer_rank
351
+ if self.group is None
352
+ else dist.get_global_rank(self.group, peer_rank)
353
+ ) # TODO
354
+ ops.append(dist.P2POp(dist.isend, out, peer_global_rank, self.group))
355
+
356
+ return ops
357
+
358
    def get_bwd_send_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]:
        """
        Get the gradient send ops for current stage's backward.

        Returns isend ops shipping input gradients back to the stages the
        corresponding activations came from; empty for the first stage or a
        forward-only run.
        """
        self._check_chunk_id(bwd_chunk_id)

        if not self.has_backward or self.is_first:
            return []

        # Create bwd send infra lazily
        if self.grad_send_info is None:
            # Send info for input grads during backward:
            # List of destinations corresponding to input grads
            # Can be None if an input has no grad
            # `grad_send_info` is a mirror of `args_recv_info`
            self.grad_send_info = self._create_grad_send_info(self.args_recv_info[0])

        ops: List[dist.P2POp] = []
        for grad, grad_recv_stage in zip(self.grads_input, self.grad_send_info):
            if isinstance(grad, torch.Tensor) and grad_recv_stage is not None:
                logger.debug(
                    "%s Sending gradient to Stage %s: %s",
                    self.log_prefix,
                    grad_recv_stage,
                    grad.size(),
                )
                peer_rank = self.stage_index_to_group_rank[grad_recv_stage]
                peer_global_rank = (
                    peer_rank
                    if self.group is None
                    else dist.get_global_rank(self.group, peer_rank)
                )  # TODO
                ops.append(dist.P2POp(dist.isend, grad, peer_global_rank, self.group))
            else:
                # A grad without a destination (or a destination without a
                # grad) indicates a bookkeeping mismatch between
                # `grads_input` and `grad_send_info`.
                if not (grad is None and grad_recv_stage is None):
                    raise RuntimeError(
                        f"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} "
                        f"and is expecting to send gradients to stage {grad_recv_stage}"
                    )
        return ops
398
+
399
+ def clear_runtime_states(self) -> None:
400
+ """
401
+ Clear runtime states of the stage.
402
+ """
403
+ # map microbatch ID to list of forward tensor args
404
+ self.fwd_cache.clear()
405
+ # Caching chunk outputs for final output merge or reduction
406
+ self.output_chunks.clear()
407
+ # Reset bwd chunk counter
408
+ self._seen_bwd_chunks = 0
409
+
410
+ # Clear grad of input buffers in between schedule steps. This is because
411
+ # `torch.autograd.backward()` will accumulate gradients into leaf
412
+ # tensors by default. For gradients to pass back to previous stages, we
413
+ # don't want such accumulation.
414
+ for recv_tuple in self.args_recv_info.values(): # iterate over all chunks
415
+ for a in recv_tuple: # iterate over all input args
416
+ if isinstance(a, _RecvInfo):
417
+ # Set to None is the newer and recommended way to clear grads, compared to `zero_()`.
418
+ # See https://github.com/pytorch/pytorch/pull/92731
419
+ a.buffer.grad = None
420
+
421
+ def _map_tensor_from_recv_info(
422
+ self,
423
+ recv_infos: Tuple[InputInfo, ...],
424
+ ):
425
+ """
426
+ Map tensors from recv infos to a list.
427
+ """
428
+
429
+ def get_recv_tensor(info):
430
+ if isinstance(info, _RecvInfo):
431
+ return info.buffer
432
+ else:
433
+ raise AssertionError(f"Expected _RecvInfo but got {type(info)}")
434
+
435
+ tensors = map_aggregate(
436
+ recv_infos,
437
+ get_recv_tensor,
438
+ )
439
+
440
+ return tensors
441
+
442
+ def _retrieve_recv_activations(self, fwd_chunk_id: int):
443
+ """
444
+ Retrieve the activations received for the current stage during forward.
445
+ """
446
+ recv_infos = self.args_recv_info[fwd_chunk_id]
447
+ activations = self._map_tensor_from_recv_info(recv_infos)
448
+ return activations
449
+
450
+ def _retrieve_recv_grads(
451
+ self,
452
+ bwd_chunk_id: int,
453
+ ):
454
+ """
455
+ Retrieve the gradients received for the current stage during backward.
456
+ """
457
+ recv_infos = self.grad_recv_info[bwd_chunk_id]
458
+ grads = self._map_tensor_from_recv_info(recv_infos)
459
+ return grads
460
+
461
+ def forward_maybe_with_nosync(self, *args, **kwargs):
462
+ # If submod is wrapped with DDP, we use the `no_sync` context manager to
463
+ # avoid gradient all-reduce per microbatch
464
+ if isinstance(self.submod, DistributedDataParallel):
465
+ with self.submod.no_sync(): # type: ignore[operator]
466
+ out_val = self.submod(*args, **kwargs)
467
+ else:
468
+ out_val = self.submod(*args, **kwargs)
469
+ return out_val
470
+
471
    def backward_maybe_with_nosync(self, backward_type, bwd_kwargs: Dict):
        """
        Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the
        other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but
        there are additional state-variables and performance considerations depending on the data parallelism used.
        This helper should adapt any pipeline parallel schedule to work with common/supported data parallel libraries.

        Args:
            backward_type: one of "full", "input", or "weight", selecting which
                stage-backward helper to run.
            bwd_kwargs: keyword arguments for the chosen helper; must contain
                "full_backward" plus whatever keys that helper needs.

        Returns:
            A ``(grads, param_groups)`` pair; ``param_groups`` is None unless
            the helper returned a 2-tuple (the "input" backward case).
        """
        full_backward = bwd_kwargs["full_backward"]
        if full_backward:
            last_backward = self._seen_bwd_chunks == self.chunks - 1  # type: ignore[operator]
        else:
            # For backwards are split into weight and input, we will see twice as many bwd_chunks
            last_backward = self._seen_bwd_chunks == 2 * self.chunks - 1  # type: ignore[operator]

        def perform_backward(backward_type):
            # Returns a thunk so the DDP/FSDP context setup below can wrap
            # the actual backward call.
            if backward_type == "full":
                return lambda: stage_backward(
                    bwd_kwargs["stage_output"],
                    bwd_kwargs["output_grads"],
                    bwd_kwargs["input_values"],
                )
            elif backward_type == "input":
                return lambda: stage_backward_input(
                    bwd_kwargs["stage_output"],
                    bwd_kwargs["output_grads"],
                    bwd_kwargs["input_values"],
                    self.submod.parameters(),
                )
            elif backward_type == "weight":
                return lambda: stage_backward_weight(
                    self.submod.parameters(), bwd_kwargs["param_groups"]
                )
            else:
                raise RuntimeError(f"Unknown backward type: {backward_type}")

        # If submod is wrapped by DDP
        if isinstance(self.submod, DistributedDataParallel):
            if last_backward:
                # Last chunk, prepare for gradient reduction
                # HACK: reaching into DDP implementation details here. Is there a better way?
                self.submod.reducer.prepare_for_backward(  # type: ignore[union-attr, operator]
                    list(
                        torch.nn.parallel.distributed._find_tensors(  # type: ignore[attr-defined]
                            bwd_kwargs["stage_output"]
                        )
                    )
                )
                result = perform_backward(backward_type)()
            else:
                # Not the last chunk: accumulate locally, skip the all-reduce.
                with self.submod.no_sync():  # type: ignore[operator]
                    result = perform_backward(backward_type)()
        # If submod is a FSDP module
        elif isinstance(self.submod, FSDPModule):
            # Disable FSDP's own post-backward work for intermediate chunks;
            # it is re-enabled and run manually on the last one below.
            self.submod.set_is_last_backward(False)
            self.submod.set_reshard_after_backward(False)
            self.submod.set_requires_gradient_sync(False)
            result = perform_backward(backward_type)()
            if last_backward:
                # Manually call post backward for FSDP
                def run_post_backward(fsdp_module: FSDPModule) -> None:
                    fsdp_module.set_is_last_backward(True)
                    fsdp_module.set_reshard_after_backward(True)
                    fsdp_module.set_requires_gradient_sync(True)
                    fsdp_state = fully_shard.state(fsdp_module)
                    for state in fsdp_state._state_ctx.all_states:
                        if state._fsdp_param_group:
                            state._fsdp_param_group.post_backward()

                run_post_backward(self.submod)
        else:
            # Non-DP submodule, regular backward
            result = perform_backward(backward_type)()

        self._seen_bwd_chunks += 1

        if isinstance(result, tuple) and len(result) == 2:
            # for stage_backward_input()
            grads, param_groups = result
        else:
            grads, param_groups = result, None

        return grads, param_groups
553
+
554
    def forward_one_chunk(
        self,
        fwd_chunk_id: int,
        args: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Perform forward pass on the stage with one microbatch.
        `args` and `kwargs` are the inputs from *external* to this stage;
        they apply only to the first stage in most cases. Non-first stages
        take their inputs from the receive buffers instead.

        Returns the submodule's output for this chunk (also appended to
        `output_chunks`); caches the output tuple and flattened inputs in
        `fwd_cache` for the matching backward chunk.
        """

        if self.is_first:
            # First stage doesn't need to receive anything
            composite_args = args
            composite_kwargs = kwargs or {}
        else:
            # Receive activations for this chunk
            # Activations only come in args form
            composite_args = self._retrieve_recv_activations(fwd_chunk_id)
            composite_kwargs = {}

        self._validate_fwd_input(args, kwargs)

        # Compute forward
        try:
            output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs)

        except Exception as e:
            exc_msg = f"""
            {self.log_prefix} failed to run forward:
            args: {map_debug_info(composite_args)}
            kwargs: {map_debug_info(composite_kwargs)}
            """
            raise RuntimeError(exc_msg) from e

        if type(output) is list:
            # HACK: this is a hacky workaround for the fact that export creates
            # output in list format
            output = tuple(output)

        # Unify output form to tuple for easy correspondance with
        # `act_send_info`
        output_tuple = output if type(output) is tuple else (output,)
        # Prepare for final output merge or reduction
        self.output_chunks.append(output)

        # Save activations and inputs for backward
        flat_args = flatten_args(composite_args)
        flat_kwargs = flatten_args(composite_kwargs)
        flatten_input_tensors = flat_args + flat_kwargs
        self.fwd_cache[fwd_chunk_id] = (
            output_tuple,  # stage_output
            flatten_input_tensors,  # input_values
        )

        logger.debug(
            "%s Forwarded chunk %s, outputs: %s",
            self.log_prefix,
            fwd_chunk_id,
            map_debug_info(output),
        )
        self._validate_fwd_outputs(output_tuple)
        return output
618
+
619
    def backward_one_chunk(
        self, bwd_chunk_id: int, loss=None, full_backward: bool = True
    ):
        """
        Perform backward pass on the module.
        This should only be called once per microbatch.

        If full_backward is True (the default), the full backward pass including weight and input gradients will be run,
        and it is an error to call `backward_weight_one_chunk` for this bwd_chunk_id.

        If full_backward is False, it is optional that `dw_runner` was provided to the PipelineStage at __init__ time,
        and a subsequent call to `backward_weight_one_chunk` is required to invoke dw_runner and complete the backward.

        Args:
            bwd_chunk_id: index of the microbatch to run backward for; must
                match a previously forwarded chunk (its cache entry is popped).
            loss: loss to backprop from; only used by the last stage.
        """
        self._check_chunk_id(bwd_chunk_id)

        # Pop (rather than read) so the cached activations are released once
        # this chunk's backward has consumed them.
        (
            stage_output,
            input_values,
        ) = self.fwd_cache.pop(bwd_chunk_id)

        # Compute backward
        if self.is_last:
            # Last stage computes gradients from loss and has no gradients from
            # next stage
            bwd_kwargs = {
                "stage_output": loss,
                "output_grads": None,
                "input_values": input_values,
            }
        else:
            # Otherwise, receive gradients from next stage
            grads_output = self._retrieve_recv_grads(bwd_chunk_id)
            # If an input to the pipeline requires gradient,
            # `torch.autograd.backward` will accumulate the gradient into the
            # `.grad` field of such input
            bwd_kwargs = {
                "stage_output": stage_output,
                "output_grads": grads_output,
                "input_values": input_values,
            }

        # Save full_backward
        bwd_kwargs["full_backward"] = full_backward

        # Custom backward function
        if self.dw_builder:
            # TODO: We may want to change our semantics so we are allowed to ignore
            # the 'dw_builder' and call full_backward directly when it is a full_backward op.
            self.grads_input, _ = self.backward_maybe_with_nosync("full", bwd_kwargs)
            if full_backward:
                self.dw_builder()()
            else:
                # Defer the weight backward to backward_weight_one_chunk.
                self.dw_runner[bwd_chunk_id] = self.dw_builder()
        else:
            if full_backward:
                self.grads_input, _ = self.backward_maybe_with_nosync(
                    "full", bwd_kwargs
                )
            else:
                # perform the partial backwards for the inputs with a custom backward function
                # when the "stage_output" is a loss, then it is a tensor, otherwise it is a tuple of tensors
                if isinstance(bwd_kwargs["stage_output"], torch.Tensor):
                    bwd_kwargs["stage_output"] = (bwd_kwargs["stage_output"],)

                grads_input, param_groups = self.backward_maybe_with_nosync(
                    "input", bwd_kwargs
                )

                # TODO: we dont need to save this, add to dw_runner?
                self.backward_state[bwd_chunk_id] = (
                    input_values,
                    param_groups,
                    bwd_kwargs["stage_output"],
                    bwd_kwargs["output_grads"],
                )
                self.grads_input = grads_input
                # Save a placeholder for the dw_runner
                self.dw_runner[bwd_chunk_id] = lambda: None
        logger.debug("%s Backwarded chunk %s", self.log_prefix, bwd_chunk_id)
698
+
699
    def backward_weight_one_chunk(self, bwd_chunk_id: int):
        """
        Run the deferred weight-gradient step for `bwd_chunk_id`.

        Requires a prior `backward_one_chunk(bwd_chunk_id, full_backward=False)`
        which stashed the state (or the dw_runner) this call consumes.
        """
        assert bwd_chunk_id in self.dw_runner, (
            f"{self.log_prefix} Attempted to run backward_weight_one_chunk for chunk {bwd_chunk_id}"
            " without first calling `backward_one_chunk(full_backward=False)`"
        )

        if self.dw_builder is not None:
            # A user-provided dw_runner performs the weight backward.
            self.dw_runner.pop(bwd_chunk_id)()
        else:
            (
                input_values,
                param_groups,
                stage_output,
                output_grads,
            ) = self.backward_state.pop(bwd_chunk_id)

            if self.stage_index != 0:
                bwd_kwargs = {
                    "stage_output": stage_output,
                    "param_groups": param_groups,
                    "full_backward": False,
                }
                weight_grads, _ = self.backward_maybe_with_nosync("weight", bwd_kwargs)
            else:
                # TODO: figure out a better way to do this:
                # if inputs does not require gradient,
                # then the parameter group will not be fully captured during stage_backward_input
                # in this case, we need call grad directly on the parameters
                # To solve: make input fn do the intersect compute and then finish it off during W
                bwd_kwargs = {
                    "stage_output": stage_output,
                    "output_grads": output_grads,
                    "input_values": input_values,
                    "full_backward": False,
                }
                self.backward_maybe_with_nosync("full", bwd_kwargs)
735
+
736
+ def _validate_fwd_input(self, args, kwargs):
737
+ """Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage."""
738
+
739
+ if self.is_first:
740
+ # TODO why is there a separate recv_info for each pipeline chunk?
741
+ # kwen2501: to avoid passing a `fwd_chunk_id` to this function, we
742
+ # check all chunks against args_recv_info[0]
743
+ expected_args = self.args_recv_info[0]
744
+ else:
745
+ # We don't check inputs for non-0 stages assuming they don't accept
746
+ # user inputs in canonical pipeline scenarios
747
+ return
748
+
749
+ if len(kwargs):
750
+ # TODO- need a mapping of kwarg to position in self.args_recv_info
751
+ # without it, we just validate shapes for args and ignore kwargs
752
+ expected_args = expected_args[: len(expected_args) - len(kwargs)]
753
+
754
+ # TODO- need a mapping of kwarg to position in self.args_recv_info
755
+ # maybe it's impossible to tell whether the len mismatches because
756
+ # (a) the user passed an extra arg or missed an arg
757
+ # (b) the user did not pass a kwarg, which has a default value baked into expected_args
758
+ expected_tensors_meta = [
759
+ e.meta if isinstance(e, _RootArgPlaceholder) else e.buffer
760
+ for e in expected_args
761
+ ]
762
+ validate_tensors_metadata(
763
+ f"Stage {self.stage_index} forward inputs", expected_tensors_meta, args
764
+ )
765
+
766
+ def _validate_fwd_outputs(self, outputs: Tuple[torch.Tensor, ...]):
767
+ """Raises a RuntimeError if this stage produces an output of unexpected shape/dtype.
768
+ Most likely, this could be cause either by incorrect user specification of output shapes, or becuase
769
+ shape inference was done on the original model but then at runtime the model is wrapped with something like
770
+ mixed precision which changes output dtype.
771
+ """
772
+ expected_tensors_meta = self.get_outputs_meta()
773
+ validate_tensors_metadata(
774
+ f"Stage {self.stage_index} forward outputs", expected_tensors_meta, outputs
775
+ )
776
+
777
+
778
class _PipelineStage(_PipelineStageBase):
    """
    A pipeline stage constructed from a traced pipeline (`PipeInfo`): wraps
    one `call_module` node of the pipe graph and derives all of its send/recv
    metadata from the graph's edges.
    """

    def __init__(
        self,
        stage_module: torch.nn.Module,
        stage_index: int,
        pipe_info: PipeInfo,
        device: torch.device,
        group: Optional[dist.ProcessGroup] = None,
    ):
        """
        Create a pipeline stage given a stage_module to be wrapped by this stage
        and a `pipe_info` describing the stage relationship of the pipeline.

        Args:
            stage_module (torch.nn.Module): the module to be wrapped by this stage
            stage_index (int): the index of this stage in the pipeline
            pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()`
            device (torch.device): the device to be used by this stage
            group (Optional[dist.ProcessGroup]): the process group to be used by this stage
        """
        _PipelineStageBase.__init__(
            self,
            stage_module,
            stage_index,
            pipe_info.num_stages,
            device,
            group,
        )
        self.pipe_info = pipe_info

        # Find stage nodes in graph
        submod_nodes = [
            node for node in pipe_info.graph.nodes if node.op == "call_module"
        ]
        if len(submod_nodes) != self.num_stages:
            raise AssertionError(
                f"Number of submodules in pipe graph {len(submod_nodes)} does not match number of stages {self.num_stages}"
            )

        # Find my stage node in graph
        self.node = submod_nodes[self.stage_index]
        self.name = self.node.name
        logger.info(
            "[%s] Creating PipelineStage %s for %s",
            self.group_rank,
            stage_index,
            self.name,
        )

        # Create mapping from stage name to stage index
        self.submod_to_stage_index: Dict[str, int] = {}
        for i, node in enumerate(submod_nodes):
            self.submod_to_stage_index.setdefault(node.name, i)

        # Cast submodule to device
        self._move_submod_to_device()

    def _move_submod_to_device(self):
        # Move submodule to indicated device if possible
        # Note: we cannot move meta module to real devices because meta tensors
        # do not support to() method. One needs to do an in-place tensor swap in
        # that case.
        has_meta_param = any(
            isinstance(p, FakeTensor) or p.is_meta for p in self.submod.parameters()
        )
        if has_meta_param:
            logger.debug("%s Found meta parameters!", self.log_prefix)
        else:
            self.submod.to(self.device)

    def _prepare_forward_infra(self, num_microbatches: int):
        """
        Create send/recv infrastructures for activations (during forward)
        """
        # Flag per chunk to keep track of whether we have set `requires_grad`
        # for receive buffers. Format: {chunk : Boolean}
        for chunk in range(num_microbatches):
            self.args_recv_info[chunk] = self._create_act_recv_info()
            self.set_requires_grad[chunk] = False

        # Send info during forward for each activation
        self.act_send_info = self._create_act_send_info()

    def get_stage_index_of_submod(
        self,
        submod_name: str,
    ):
        """
        Given a submodule name, return the stage index of the submodule.
        """
        if submod_name not in self.submod_to_stage_index:
            raise AssertionError(f"Stage id of {submod_name} not found")

        return self.submod_to_stage_index[submod_name]

    def _create_act_recv_info(
        self,
    ):
        """
        Create a tuple of `_RecvInfo` for inputs to the stage.
        """

        def create_recv_tensor(placeholder, arg_node):
            """
            Create a receive buffer for a placeholder.
            """
            example_value = placeholder.meta["val"]
            if arg_node.op == "placeholder":
                # This is a root level placeholder, thus an input argument to the entire model.
                # We are likely at stage 0, hence no need to create a receive buffer.
                return _RootArgPlaceholder(example_value)

            # Figure out the source stage of this input
            while arg_node.target is operator.getitem:
                # If the input is a getitem, we need to go deeper
                arg_node = arg_node.args[0]

            assert (
                arg_node.op == "call_module"
            ), f"Expecting call_module, got {arg_node.op}"
            src_stage = self.get_stage_index_of_submod(arg_node.name)

            # Create a receive buffer for this placeholder
            logger.debug(
                "%s Creating recv buffer for input '%s' : %s, %s",
                self.log_prefix,
                placeholder.name,
                example_value.shape,
                example_value.dtype,
            )
            buffer = _make_tensor_from_meta(example_value, self.device)

            return _RecvInfo(
                arg_node.name,
                src_stage,
                buffer,
            )

        args_recv_info: List[InputInfo] = []
        # Filter out placeholder nodes from `self.submod` (a GraphModule)
        placeholders = filter(
            lambda node: node.op == "placeholder", self.submod.graph.nodes
        )
        # `placeholders` are nodes internal to submod.
        # `self.node.args` are dependency nodes in the outer graph.
        # The two are 1:1.
        for placeholder, arg_node in zip(placeholders, self.node.args):
            # Create a receive buffer for this placeholder
            recv_info = create_recv_tensor(placeholder, arg_node)
            args_recv_info.append(recv_info)

        logger.debug(
            "%s Activation recv / args info: %s", self.log_prefix, args_recv_info
        )
        # `args` is a Tuple, hence we will return a Tuple[InputInfo]
        return tuple(args_recv_info)

    def find_dst_rank(
        self,
        user: fx.Node,
    ) -> Optional[int]:
        """
        Find the destination rank of a `user` node.
        If the `user` is not a submod, `None` may be returned.
        """
        if user.op == "call_module":
            # User is a stage (`call_module`)
            return self.get_stage_index_of_submod(user.name)
        else:
            # - If user.op == "output":
            #   No need to send back to rank 0
            # - If user.target is stage_backward:
            #   No need to send assuming submod output is stored locally or
            #   should be re-calucated in case of activation checkpointing
            return None

    def _create_act_send_info(self):
        """
        Create a dict of send info for activations.
        The dict is of the form:
        {
            output_index: [dst_rank_0, dst_rank_1, ...],
            ...
        }
        where the list of `dst_rank`s covers the case where an output value may
        be consumed by multiple stages.
        """
        # Output index: List of receiver ranks
        act_send_info: Dict[int, List] = {}
        out_idx = 0

        for user in self.node.users:
            if user.target is operator.getitem:
                # Recursively find the real destination
                gi_dsts = act_send_info.setdefault(out_idx, [])
                for gi_user in user.users:
                    dst_rank = self.find_dst_rank(gi_user)
                    if dst_rank is not None:
                        gi_dsts.append(dst_rank)
                # Next `getitem` will point to the next output index
                out_idx += 1
            else:
                # In case of single output value, `out_idx` will not increase
                dsts = act_send_info.setdefault(out_idx, [])
                dst_rank = self.find_dst_rank(user)
                if dst_rank is not None:
                    dsts.append(dst_rank)

        output_node = self._get_output_node()
        output_vals: Tuple[torch.Tensor] = tuple(
            v.meta["val"] for v in flatten_args(output_node.args)
        )
        self._configure_outputs_meta(output_vals)

        logger.debug("%s Send info: %s", self.log_prefix, act_send_info)
        return act_send_info

    def _get_output_node(self):
        # The FX graph of a GraphModule has exactly one `output` node.
        output_nodes = [node for node in self.submod.graph.nodes if node.op == "output"]
        assert len(output_nodes) == 1
        output_node = output_nodes[0]
        return output_node

    def _create_grad_recv_info(
        self,
        act_send_info: Dict,
    ) -> Tuple[_RecvInfo, ...]:
        """
        Create a tuple of `_RecvInfo` for gradients.
        """
        # Dict[output_index, _RecvInfo]
        grad_recv_info: Dict[int, _RecvInfo] = {}
        output_node = self._get_output_node()

        # The output node may take multiple args, meaning the submod having multiple output values.
        output_vals = flatten_args(output_node.args)

        for out_idx, dst_list in act_send_info.items():
            if not dst_list:
                # No actual receiver for activation so no grad coming back
                continue

            output = output_vals[out_idx]
            example_value = output.meta["val"]
            logger.debug(
                f"{self.log_prefix} Creating grad recv buffer for output {output.name} "  # noqa: G004
                f": {example_value.shape}, {example_value.dtype}"
            )

            # TODO: otherwise needs grad accumulation
            assert len(dst_list) == 1, "Backward of skip connections not supported yet"
            grad_src = dst_list[0]
            grad_recv_info[out_idx] = _RecvInfo(
                f"{grad_src}",  # noqa: G004
                grad_src,
                _make_tensor_from_meta(example_value, self.device),
            )

        # Convert to tuple for convenience in get_ops and retrieve tensor
        grad_recv_info_tuple = tuple(grad_recv_info.values())
        logger.debug("%s Grad recv info: %s", self.log_prefix, grad_recv_info_tuple)
        return grad_recv_info_tuple
1040
+
1041
+
1042
# A helper function to create a pipeline stage based on traced pipeline information
def build_stage(
    stage_module: torch.nn.Module,
    stage_index: int,
    pipe_info: PipeInfo,
    device: torch.device,
    group: Optional[dist.ProcessGroup] = None,
) -> _PipelineStage:
    """
    Create a pipeline stage given a stage_module to be wrapped by this stage
    and pipeline information.

    Args:
        stage_module (torch.nn.Module): the module to be wrapped by this stage
        stage_index (int): the index of this stage in the pipeline
        pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()`
        device (torch.device): the device to be used by this stage
        group (Optional[dist.ProcessGroup]): the process group to be used by this stage

    Returns:
        _PipelineStage: a pipeline stage that can run with `PipelineSchedules`.
    """
    # Thin convenience wrapper over the _PipelineStage constructor.
    return _PipelineStage(
        stage_module,
        stage_index,
        pipe_info,
        device,
        group,
    )
1071
+
1072
+
1073
# Manual PipelineStage functions and definition

# Fixed length of the shape-metadata tensor exchanged between ranks; encoded
# shape records occupy a prefix and remaining slots hold the placeholder.
METADATA_TENSOR_LEN = 100
# Sentinel value written into unused slots of the metadata tensor.
PLACEHOLDER_VAL = -1
1077
+
1078
+
1079
+ def _create_empty_tensors(
1080
+ tensor: Union[torch.Tensor, Iterable[torch.Tensor]], device: torch.device
1081
+ ) -> List[torch.Tensor]:
1082
+ """
1083
+ Creates a list of empty tensors with the same properties (like shape and dtype) as the input tensor(s),
1084
+ and places them on the specified device.
1085
+ Args:
1086
+ tensor (Union[torch.Tensor, List[torch.tensor]]): The input tensor(s).
1087
+ device (torch.device): The device where the new tensors will be placed.
1088
+ Returns:
1089
+ List[torch.Tensor]: A list of empty tensors with the same properties as the input tensor(s).
1090
+ """
1091
+ if isinstance(tensor, torch.Tensor):
1092
+ return [torch.empty_like(tensor, device=device)]
1093
+ elif isinstance(tensor, (list, tuple)):
1094
+ return [torch.empty_like(t, device=device) for t in tensor]
1095
+ raise TypeError(f"Unsupported type {type(tensor)} cannot create empty tensors")
1096
+
1097
+
1098
def _create_metadata_tensor(
    tensors: Optional[List[torch.Tensor]] = None,
    device: Optional[torch.device] = torch.device("cpu"),
) -> torch.Tensor:
    """
    Build a fixed-length int32 tensor encoding the shapes of `tensors`,
    suitable for sending over the wire.

    The encoding concatenates [num_dims, dim1, dim2, ...] records, one per
    tensor; all unused trailing slots hold PLACEHOLDER_VAL. If `tensors` is
    None or empty, the result contains only placeholder values.

    Args:
        tensors: tensors whose shapes should be encoded, or None.
        device: device on which to allocate the metadata tensor.

    Raises:
        ValueError: if the encoded data exceeds METADATA_TENSOR_LEN entries.
    """
    metadata_tensor = torch.full(
        (METADATA_TENSOR_LEN,),
        PLACEHOLDER_VAL,
        dtype=torch.int32,
        device=device,
    )
    if tensors:
        # Encode each tensor as [num_dims, dim1, dim2, ...].
        per_tensor = [
            torch.tensor(
                [len(t.shape), *t.shape],
                dtype=torch.int32,
                device=device,
            )
            for t in tensors
        ]
        # Pack all records into the front of the metadata tensor.
        data_tensor = torch.cat(per_tensor)
        n = data_tensor.shape[0]
        if n > METADATA_TENSOR_LEN:
            raise ValueError(
                f"Metadata tensor size ({n}) exceeds maximum allowed length ({METADATA_TENSOR_LEN})."
            )
        metadata_tensor[:n] = data_tensor
    return metadata_tensor
1142
+
1143
+
1144
+ def _extract_metadata_from_tensor(tensor: torch.Tensor) -> List[torch.Size]:
1145
+ """
1146
+ Extract the number of dimensions and the shape of each tensor from a metadata tensor.
1147
+ """
1148
+ metadata: List[torch.Size] = []
1149
+ i = 0
1150
+ while i < len(tensor) and tensor[i] != PLACEHOLDER_VAL:
1151
+ num_dims = int(tensor[i].item())
1152
+ shape = torch.Size(tensor[i + 1 : i + 1 + num_dims].tolist())
1153
+ metadata.append(shape)
1154
+ i += num_dims + 1
1155
+ return metadata
1156
+
1157
+
1158
+ def _get_stage_shapes(
1159
+ stage_modules: List[nn.Module],
1160
+ stage_ids: List[int],
1161
+ num_stages: int,
1162
+ rank: int,
1163
+ world_size: int,
1164
+ device: torch.device,
1165
+ microbatch: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
1166
+ ):
1167
+ """
1168
+ Performs a dry run through all the pipeline stages (a rank can have multiple pipeline stages in the case of
1169
+ virtual pipelining) and returns the shape of the inputs and outputs of the module.
1170
+ Only the first stage must pass in a microbatch.
1171
+
1172
+ Each rank must call _get_stage_shapes or the program will hang.
1173
+
1174
+ Args:
1175
+ stage_modules: The chunks assigned to this rank. Rhe length should be 1 for any
1176
+ non-interleaved schedules and >1 for any interleaved schedules.
1177
+ stage_ids: The id of the stages assigned to this rank.
1178
+ num_stages: Total number of stages.
1179
+ rank: Rank of the current process.
1180
+ world_size: Number of processes participating in the pipeline.
1181
+ device: Device where the tensors are allocated.
1182
+
1183
+ Returns a dictionary containing the following keys:
1184
+ "inputs": Shape of the inputs to the module
1185
+ "outputs": Shape of the outputs of the module
1186
+ """
1187
+
1188
+ stage_id_to_shapes: Dict[int, Dict[str, list[torch.Size]]] = {}
1189
+ for stage_id, model in zip(stage_ids, stage_modules):
1190
+ input_shape_metadata_tensor = _create_metadata_tensor(device=device)
1191
+ # TODO: Assumes prev_stage == rank - 1 and next_stage == rank + 1
1192
+ prev_rank = (rank - 1) % world_size
1193
+ next_rank = (rank + 1) % world_size
1194
+ shapes = {}
1195
+
1196
+ # first stage doesn't receive anything and uses a microbatch
1197
+ if stage_id == 0:
1198
+ if microbatch is None:
1199
+ raise RuntimeError("Microbatch is required for first stage")
1200
+ example_fwd_inputs = microbatch
1201
+ if isinstance(example_fwd_inputs, torch.Tensor):
1202
+ example_fwd_inputs = [example_fwd_inputs]
1203
+ else:
1204
+ # other stages must receive shape information
1205
+ # TODO: send/recv should take a group, rather than use the default group
1206
+ dist.recv(input_shape_metadata_tensor, prev_rank)
1207
+ metadata = _extract_metadata_from_tensor(input_shape_metadata_tensor)
1208
+ example_fwd_inputs = [
1209
+ torch.empty(shape_list, device=device) for shape_list in metadata
1210
+ ]
1211
+ shapes["inputs"] = [fwd_input.shape for fwd_input in example_fwd_inputs]
1212
+
1213
+ # perform forward
1214
+ # TODO: if forward fails raise a more descriptive error explaining which stage failed
1215
+ fwd_outputs = model(*example_fwd_inputs)
1216
+ fwd_outputs = _create_empty_tensors(fwd_outputs, device)
1217
+ shapes["outputs"] = [fwd_output.shape for fwd_output in fwd_outputs]
1218
+
1219
+ # send shape dims
1220
+ if stage_id != num_stages - 1:
1221
+ output_shape_metadata_tensor = _create_metadata_tensor(
1222
+ fwd_outputs, device=device
1223
+ )
1224
+ dist.send(output_shape_metadata_tensor, next_rank)
1225
+ stage_id_to_shapes[stage_id] = shapes
1226
+ logger.info(stage_id_to_shapes)
1227
+ return stage_id_to_shapes
1228
+
1229
+
1230
+ class PipelineStage(_PipelineStageBase):
1231
+ """
1232
+ A class representing a pipeline stage in a pipeline parallelism setup.
1233
+ This class is created manually by providing a example input (and optionally output)
1234
+ as opposed to the PipelineStage class that is outputed from pipeline().
1235
+ This class extends the `_PipelineStageBase` class and can similarly be used
1236
+ in `PipelineScheule`.
1237
+
1238
+ Args:
1239
+ submodule (nn.Module): The PyTorch module wrapped by this stage.
1240
+ stage_index (int): The ID of this stage.
1241
+ num_stages (int): The total number of stages.
1242
+ device (torch.device): The device where this stage is located.
1243
+ input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The input arguments for the submodule.
1244
+ output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The output arguments for the submodule.
1245
+ group (dist.ProcessGroup, optional): The process group for distributed training. If None, default group.
1246
+ dw_builder: TODO clean up comments
1247
+ """
1248
+
1249
+ def __init__(
1250
+ self,
1251
+ submodule: nn.Module,
1252
+ stage_index: int,
1253
+ num_stages: int,
1254
+ device: torch.device,
1255
+ input_args: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
1256
+ output_args: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = None,
1257
+ group: Optional[dist.ProcessGroup] = None,
1258
+ dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
1259
+ ):
1260
+ super().__init__(submodule, stage_index, num_stages, device, group, dw_builder)
1261
+ self.submod.to(self.device)
1262
+ # When we materialize the model partition on cuda, we call reset_parameters() if it is available
1263
+ self.inputs: List[torch.Tensor] = []
1264
+ self.outputs: List[torch.Tensor] = []
1265
+
1266
+ self.inputs = _create_empty_tensors(input_args, device)
1267
+
1268
+ if output_args is None:
1269
+ logger.info("output_args not provided, performing forward using input_args")
1270
+ self.outputs = self.submod(*self.inputs)
1271
+ # create buffers for the output so that the data is in the correct
1272
+ # shape in order to use in p2p op (send)
1273
+ self.outputs = _create_empty_tensors(self.outputs, device)
1274
+ else:
1275
+ self.outputs = _create_empty_tensors(output_args, device)
1276
+
1277
+ self._configure_outputs_meta(tuple(self.outputs))
1278
+
1279
+ # these are the buffers used in backwards send/recv, they are allocated later
1280
+ self.outputs_grad: List[torch.Tensor] = []
1281
+
1282
+ def stage_global_rank(peer_rank):
1283
+ return (
1284
+ peer_rank
1285
+ if self.group is None
1286
+ else dist.get_global_rank(self.group, peer_rank)
1287
+ )
1288
+
1289
+ self.prev_stage = stage_global_rank((self.group_rank - 1) % self.group_size)
1290
+ self.next_stage = stage_global_rank((self.group_rank + 1) % self.group_size)
1291
+
1292
+ logger.debug(
1293
+ f"finished pipeline stage init, {self.stage_index=}, {self.is_first=}, " # noqa: G004
1294
+ f"{self.is_last=}, {self.num_stages=}, "
1295
+ f"inputs: {[inp.shape for inp in self.inputs]}, "
1296
+ f"output: {[output.shape for output in self.outputs]}"
1297
+ )
1298
+
1299
+ def _prepare_forward_infra(self, num_microbatches: int) -> None:
1300
+ # Receive info during forward
1301
+ # TODO: create args_recv_info lazily? (same needed for PipelineStage)
1302
+ for chunk_id in range(num_microbatches):
1303
+ self.set_requires_grad[chunk_id] = False
1304
+ if not self.is_first:
1305
+ # We assume that we always receive from stage - 1
1306
+ recv_infos = tuple(
1307
+ [
1308
+ _RecvInfo(
1309
+ f"recv_for_{self.stage_index}_from_{self.stage_index - 1}",
1310
+ self.stage_index - 1,
1311
+ _make_tensor_from_meta(inp, self.device),
1312
+ )
1313
+ for inp in self.inputs
1314
+ ]
1315
+ )
1316
+
1317
+ self.args_recv_info[chunk_id] = recv_infos
1318
+ else:
1319
+ self.args_recv_info[chunk_id] = tuple(
1320
+ [_RootArgPlaceholder(i) for i in self.inputs]
1321
+ )
1322
+
1323
+ # Send info during forward for each activation
1324
+ # only need the rank that is being sent to
1325
+ self.act_send_info: Dict[int, List] = {}
1326
+ for idx in range(len(self.outputs)):
1327
+ # We assume we always send to stage + 1
1328
+ if not self.is_last:
1329
+ self.act_send_info[idx] = [self.stage_index + 1]
1330
+ else:
1331
+ self.act_send_info[idx] = []
1332
+
1333
+ def _create_grad_recv_info(
1334
+ self,
1335
+ act_send_info: Dict,
1336
+ ) -> Tuple[_RecvInfo, ...]:
1337
+ grad_recv_info: Tuple[_RecvInfo, ...] = ()
1338
+ if not self.is_last:
1339
+ # Receiving gradients from multiple sources is not supported
1340
+ # hence we only take the first destination
1341
+ grad_recv_info = tuple(
1342
+ [
1343
+ _RecvInfo(
1344
+ f"recv_grad_for_{self.stage_index}_from_{dst_list[0]}",
1345
+ dst_list[0],
1346
+ _make_tensor_from_meta(self.outputs[idx], self.device),
1347
+ )
1348
+ for idx, dst_list in act_send_info.items()
1349
+ ]
1350
+ )
1351
+ return grad_recv_info
1352
+
1353
+ def _init_p2p_neighbors(self):
1354
+ """
1355
+ Set up p2p communitors between previous and next stages
1356
+ by sending a dummy tensor.
1357
+
1358
+ If this is used, must be called for all pipeline stages.
1359
+ """
1360
+ ops = []
1361
+ recv_tensor = torch.zeros(1, device="cuda")
1362
+ send_tensor = torch.ones(1, device="cuda")
1363
+ # forward
1364
+ if not self.is_first:
1365
+ ops.append(dist.P2POp(dist.irecv, recv_tensor, self.prev_stage, self.group))
1366
+ if not self.is_last:
1367
+ ops.append(dist.P2POp(dist.isend, send_tensor, self.next_stage, self.group))
1368
+
1369
+ # backward
1370
+ if not self.is_first:
1371
+ ops.append(dist.P2POp(dist.isend, send_tensor, self.prev_stage, self.group))
1372
+ if not self.is_last:
1373
+ ops.append(dist.P2POp(dist.irecv, recv_tensor, self.next_stage, self.group))
1374
+
1375
+ return True
1376
+
1377
+
1378
+ def _validate_stage_shapes(pipeline_stages: List[PipelineStage]):
1379
+ """
1380
+ Check that the buffer shapes match between stages was expected by performing an all_gather between
1381
+ all stages.
1382
+ """
1383
+ if len(pipeline_stages) == 0:
1384
+ raise ValueError("No pipeline stages provided.")
1385
+
1386
+ virtual_pipeline_size = len(pipeline_stages)
1387
+ all_inputs = []
1388
+ all_outputs = []
1389
+ world_size = pipeline_stages[0].group_size
1390
+ num_stages = pipeline_stages[0].num_stages
1391
+
1392
+ # perform all gathers between all stages
1393
+ for virtual_id, stage in enumerate(pipeline_stages):
1394
+ world_size = stage.group_size
1395
+ stage_id: int = stage.stage_index
1396
+ rank = stage.group_rank
1397
+ # check that world_size and num_stages are consistent across all stages
1398
+ if stage.group_size != world_size:
1399
+ raise ValueError(
1400
+ f"Stage id {stage_id} has world size ({stage.group_size}) \
1401
+ which does not match world size ({world_size}) of other stages."
1402
+ )
1403
+ if stage.num_stages != num_stages:
1404
+ raise ValueError(
1405
+ f"Stage id {stage_id} has num stages ({stage.num_stages}) \
1406
+ which does not match num stages ({num_stages}) of other stages."
1407
+ )
1408
+
1409
+ pg_rank = dist.get_rank(stage.group)
1410
+ if rank != pg_rank:
1411
+ raise ValueError(
1412
+ f"Rank {rank} is not equal to process group rank {pg_rank}"
1413
+ )
1414
+
1415
+ if (num_stages := stage.num_stages) % world_size != 0:
1416
+ raise ValueError(
1417
+ f"Number of stages ({num_stages}) must be a multiple of the world_size ({world_size})"
1418
+ )
1419
+
1420
+ # all gather each ranks inputs
1421
+ tensor_list = [
1422
+ _create_metadata_tensor(device=stage.device)
1423
+ for _ in range(stage.group_size)
1424
+ ]
1425
+ expected_inputs = stage.inputs
1426
+ stage_input = _create_metadata_tensor(expected_inputs, device=stage.device)
1427
+ dist.all_gather(tensor_list, stage_input)
1428
+ stage_input_shapes = [
1429
+ _extract_metadata_from_tensor(tensor) for tensor in tensor_list
1430
+ ]
1431
+
1432
+ # all gather each ranks outputs
1433
+ tensor_list = [
1434
+ _create_metadata_tensor(device=stage.device)
1435
+ for _ in range(stage.group_size)
1436
+ ]
1437
+ expected_outputs = stage.outputs
1438
+ stage_output = _create_metadata_tensor(expected_outputs, device=stage.device)
1439
+ dist.all_gather(tensor_list, stage_output)
1440
+ stage_output_shapes = [
1441
+ _extract_metadata_from_tensor(tensor) for tensor in tensor_list
1442
+ ]
1443
+
1444
+ logger.debug(
1445
+ f"Rank: {pg_rank}" # noqa: G004
1446
+ f"Stage id: {stage_id}"
1447
+ f"Stage num stages: {stage.num_stages}"
1448
+ f"Stage rank: {rank}"
1449
+ f"Stage world size: {world_size}"
1450
+ f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} input shapes: {stage_input_shapes}" # noqa: G003
1451
+ f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} output shapes: {stage_output_shapes}" # noqa: G003
1452
+ )
1453
+
1454
+ all_inputs.extend(stage_input_shapes)
1455
+ all_outputs.extend(stage_output_shapes)
1456
+
1457
+ # log only rank 0's view, they will all be equivalent
1458
+ if pg_rank == 0:
1459
+ logger.info(
1460
+ "all stage inputs: %s \n all stage outputs: %s", all_inputs, all_outputs
1461
+ )
1462
+
1463
+ # Check if the output for stage 0 matches the input at stage 1, and so forth
1464
+ for i in range(virtual_pipeline_size * world_size - 1):
1465
+ if (out := all_outputs[i]) != (inp := all_inputs[i + 1]):
1466
+ raise ValueError(
1467
+ f"Stage_id {i} output shape {out} at does not match stage_id {i + 1} input shape {inp}."
1468
+ )
.venv/lib/python3.11/site-packages/torch/distributed/tensor/__init__.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+
3
+ import torch
4
+ import torch.distributed.tensor._ops # force import all built-in dtensor ops
5
+ from torch.distributed.device_mesh import DeviceMesh, init_device_mesh # noqa: F401
6
+ from torch.distributed.tensor._api import (
7
+ distribute_module,
8
+ distribute_tensor,
9
+ DTensor,
10
+ empty,
11
+ full,
12
+ ones,
13
+ rand,
14
+ randn,
15
+ zeros,
16
+ )
17
+ from torch.distributed.tensor.placement_types import (
18
+ Partial,
19
+ Placement,
20
+ Replicate,
21
+ Shard,
22
+ )
23
+ from torch.optim.optimizer import (
24
+ _foreach_supported_types as _optim_foreach_supported_types,
25
+ )
26
+ from torch.utils._foreach_utils import (
27
+ _foreach_supported_types as _util_foreach_supported_types,
28
+ )
29
+
30
+
31
+ # All public APIs from dtensor package
32
+ __all__ = [
33
+ "DTensor",
34
+ "distribute_tensor",
35
+ "distribute_module",
36
+ "Shard",
37
+ "Replicate",
38
+ "Partial",
39
+ "Placement",
40
+ "ones",
41
+ "empty",
42
+ "full",
43
+ "rand",
44
+ "randn",
45
+ "zeros",
46
+ ]
47
+
48
+
49
+ # Append DTensor to the list of supported types for foreach implementation for optimizer
50
+ # and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA.
51
+ if DTensor not in _optim_foreach_supported_types:
52
+ _optim_foreach_supported_types.append(DTensor)
53
+
54
+ if DTensor not in _util_foreach_supported_types:
55
+ _util_foreach_supported_types.append(DTensor)
56
+
57
+
58
+ # Set namespace for exposed private names
59
+ DTensor.__module__ = "torch.distributed.tensor"
60
+ distribute_tensor.__module__ = "torch.distributed.tensor"
61
+ distribute_module.__module__ = "torch.distributed.tensor"
62
+ ones.__module__ = "torch.distributed.tensor"
63
+ empty.__module__ = "torch.distributed.tensor"
64
+ full.__module__ = "torch.distributed.tensor"
65
+ rand.__module__ = "torch.distributed.tensor"
66
+ randn.__module__ = "torch.distributed.tensor"
67
+ zeros.__module__ = "torch.distributed.tensor"
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_api.py ADDED
@@ -0,0 +1,1231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-decorators
2
+ # mypy: allow-untyped-defs
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates
4
+ import inspect
5
+ import warnings
6
+ from typing import Any, Callable, cast, Optional, Sequence, Tuple
7
+
8
+ import torch
9
+ import torch.distributed.tensor._dispatch as op_dispatch
10
+ import torch.distributed.tensor._random as random
11
+ import torch.nn as nn
12
+ from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
13
+ from torch.distributed.tensor._collective_utils import check_tensor_meta, mesh_broadcast
14
+ from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
15
+ from torch.distributed.tensor._random import (
16
+ is_rng_supported_mesh,
17
+ OffsetBasedRNGTracker,
18
+ )
19
+ from torch.distributed.tensor._redistribute import (
20
+ Redistribute,
21
+ redistribute_local_tensor,
22
+ )
23
+ from torch.distributed.tensor._utils import (
24
+ compute_global_tensor_info,
25
+ compute_local_shape,
26
+ normalize_to_torch_size,
27
+ )
28
+ from torch.distributed.tensor.placement_types import (
29
+ Partial,
30
+ Placement,
31
+ Replicate,
32
+ Shard,
33
+ )
34
+
35
+
36
+ __all__ = [
37
+ "DTensor",
38
+ "distribute_tensor",
39
+ "distribute_module",
40
+ "ones",
41
+ "empty",
42
+ "full",
43
+ "rand",
44
+ "randn",
45
+ "zeros",
46
+ ]
47
+
48
+ aten = torch.ops.aten
49
+
50
+
51
+ # NOTE [Autograd interaction between torch.Tensor]
52
+ #
53
+ # The autograd functions defined below are being used by the public
54
+ # facing APIs (i.e. from_local, to_local) to ensure DTensor to work
55
+ # together with torch.Tensor within the autograd engine. This
56
+ # allows DTensor to only exist on part of the module hierarchy.
57
+ #
58
+ # As an example, we have the a module that consists of submodules
59
+ # A, B, and C, the execution flow would be like:
60
+ # input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor)
61
+ #
62
+ # Suppose I only want to make Module B be a sharded module with
63
+ # DTensor params, the following forward/backward should work:
64
+ #
65
+ # input(torch.Tensor) -> Module A
66
+ # -> DTensor input (from_local) -> Sharded Module B -> DTensor output
67
+ # -> torch.Tensor output (to_local) -> Module C
68
+ #
69
+ # So from_local/to_local must be Autograd functions.
70
+ #
71
+ class _ToTorchTensor(torch.autograd.Function):
72
+ @staticmethod
73
+ def forward( # type: ignore[override]
74
+ ctx,
75
+ input: "DTensor",
76
+ grad_placements: Optional[Sequence[Placement]],
77
+ ):
78
+ ctx.dtensor_spec = input._spec
79
+ ctx.grad_placements = grad_placements
80
+ local_tensor = input._local_tensor
81
+
82
+ # We need to return a fresh Tensor object there as autograd metadata
83
+ # will be inplaced into it. So we don't want to pollute the Tensor
84
+ # object stored in the _local_tensor of this DTensor.
85
+ return local_tensor.view_as(local_tensor)
86
+
87
+ @staticmethod
88
+ def backward(ctx, grad_output: torch.Tensor): # type: ignore[override]
89
+ dtensor_spec = ctx.dtensor_spec
90
+ mesh = dtensor_spec.mesh
91
+ grad_placements = ctx.grad_placements
92
+ dtensor_meta = dtensor_spec.tensor_meta
93
+
94
+ _, tensor_stride = compute_global_tensor_info(
95
+ grad_output, mesh, dtensor_spec.placements
96
+ )
97
+ tensor_stride = tuple(tensor_stride)
98
+ grad_placements = grad_placements or dtensor_spec.placements
99
+ grad_spec = DTensorSpec(
100
+ mesh,
101
+ grad_placements,
102
+ tensor_meta=TensorMeta(
103
+ shape=dtensor_meta.shape,
104
+ stride=tensor_stride,
105
+ dtype=dtensor_meta.dtype,
106
+ ),
107
+ )
108
+
109
+ return (
110
+ DTensor(
111
+ grad_output,
112
+ grad_spec,
113
+ requires_grad=grad_output.requires_grad,
114
+ ),
115
+ None,
116
+ )
117
+
118
+
119
+ class _FromTorchTensor(torch.autograd.Function):
120
+ @staticmethod
121
+ def forward( # type: ignore[override]
122
+ ctx, # pyre-ignore[2]: Parameter must be annotated.
123
+ input: torch.Tensor,
124
+ device_mesh: DeviceMesh,
125
+ placements: Tuple[Placement, ...],
126
+ run_check: bool,
127
+ shape: Optional[torch.Size] = None,
128
+ stride: Optional[Tuple[int, ...]] = None,
129
+ ) -> "DTensor":
130
+ ctx.previous_placement = placements
131
+ ctx.previous_device_mesh = device_mesh
132
+
133
+ if shape and stride:
134
+ tensor_shape, tensor_stride = shape, stride
135
+ elif not shape and not stride:
136
+ # if it's not by default run_check, we assume user is certain that each
137
+ # rank has the same tensor shape, and we just use that to calculate the
138
+ # global shape
139
+ global_shape, global_stride = compute_global_tensor_info(
140
+ input, device_mesh, placements
141
+ )
142
+ tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride)
143
+ else:
144
+ raise RuntimeError(
145
+ f"Found shape:{shape}, stride:{stride}.",
146
+ "Please pass both shape and stride at the same time.",
147
+ )
148
+
149
+ if device_mesh.get_coordinate() is None:
150
+ # if the global rank is not participating in the device mesh, we
151
+ # simply set the local tensor to an empty tensor
152
+ input = input.new_empty(0, requires_grad=input.requires_grad)
153
+ elif run_check:
154
+ # TODO: support uneven sharding when global shape/stride not passed, by
155
+ # building the global TensorMeta during check_tensor_meta
156
+ check_shape_stride = not shape and not stride
157
+ check_tensor_meta(input, check_shape_stride=check_shape_stride)
158
+ # TODO: See if we need to make this run_check logic
159
+ # have a corresponding backward.
160
+ for idx, placement in enumerate(placements):
161
+ if placement.is_replicate():
162
+ # broadcast rank 0 tensor to all ranks
163
+ # only broadcast if run_check is True
164
+ input = input.contiguous()
165
+ mesh_broadcast(input, device_mesh, mesh_dim=idx)
166
+
167
+ dist_spec = DTensorSpec(
168
+ device_mesh,
169
+ placements,
170
+ tensor_meta=TensorMeta(
171
+ tensor_shape,
172
+ tensor_stride,
173
+ input.dtype,
174
+ ),
175
+ )
176
+
177
+ # We want a fresh Tensor object that shares memory with the input tensor
178
+ dist_tensor = DTensor(
179
+ input.view_as(input),
180
+ dist_spec,
181
+ # requires_grad of the dist tensor depends on if input
182
+ # requires_grad or not
183
+ requires_grad=input.requires_grad,
184
+ )
185
+ return dist_tensor
186
+
187
+ @staticmethod
188
+ def backward(ctx, grad_output: "DTensor"): # type: ignore[override]
189
+ previous_placement = ctx.previous_placement
190
+ previous_device_mesh = ctx.previous_device_mesh
191
+
192
+ # reshard to the placement when creating DistributedTensor
193
+ # so that the gradient layout matches, and we could return
194
+ # local gradients directly
195
+ if grad_output.placements != previous_placement:
196
+ current_spec = grad_output._spec
197
+ target_spec = DTensorSpec(
198
+ previous_device_mesh,
199
+ previous_placement,
200
+ tensor_meta=grad_output._spec.tensor_meta,
201
+ )
202
+ local_tensor = grad_output._local_tensor
203
+ output = redistribute_local_tensor(
204
+ local_tensor, current_spec, target_spec, is_backward=True
205
+ )
206
+ # TODO: return the redistributed local tensor directly without
207
+ # differentiable backward. see if this make sense for all cases.
208
+ return output, None, None, None, None, None
209
+
210
+ # TODO: backward is also differentiable now, add a test
211
+ # to test higher level gradients.
212
+ return grad_output.to_local(), None, None, None, None, None
213
+
214
+
215
+ class DTensor(torch.Tensor):
216
+ """
217
+ ``DTensor`` (Distributed Tensor) is a subclass of ``torch.Tensor`` that provides single-device like
218
+ abstraction to program with multi-device ``torch.Tensor``. It describes the distributed tensor sharding
219
+ layout (DTensor Layout) through the :class:`DeviceMesh` and following types of :class:`Placement`:
220
+
221
+ * :class:`Shard`: Tensor sharded on the tensor dimension ``dim`` on the devices of the ``DeviceMesh`` dimension
222
+ * :class:`Replicate`: Tensor replicated on the devices of the ``DeviceMesh`` dimension
223
+ * :class:`Partial`: Tensor is pending reduction on the devices of the ``DeviceMesh`` dimension
224
+
225
+ When calling PyTorch operators, ``DTensor`` overrides the PyTorch operators to perform sharded computation and issue
226
+ communications whenever necessary. Along with the operator computation, ``DTensor`` will transform or propagate the
227
+ placements (DTensor Layout) properly (based on the operator semantic itself) and generate new ``DTensor`` outputs.
228
+
229
+ To ensure numerical correctness of the ``DTensor`` sharded computation when calling PyTorch operators, ``DTensor``
230
+ requires every Tensor argument of the operator be DTensor.
231
+
232
+ """
233
+
234
+ _local_tensor: torch.Tensor
235
+ _spec: DTensorSpec
236
+ __slots__ = ["_local_tensor", "_spec"]
237
+
238
+ # _op_dispatcher instance as a class attribute to handle runtime dispatching logic
239
+ _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher()
240
+
241
+ @staticmethod
242
+ @torch._disable_dynamo
243
+ def __new__(
244
+ cls,
245
+ local_tensor: torch.Tensor,
246
+ spec: DTensorSpec,
247
+ *,
248
+ requires_grad: bool,
249
+ ) -> "DTensor":
250
+ """
251
+ Construct a DTensor from a local tensor, device mesh, and placement and
252
+ other tensor properties (i.e. shape, requires_grad, strides, etc).
253
+
254
+ .. note:: This is not a public API and it's only supposed to be used by the
255
+ operator implementations and internals. If you want to construct a
256
+ DTensor from a local tensor, consider using ``DTensor.from_local``, if
257
+ you want to construct a DTensor from a "global" tensor (where you
258
+ already have tensor initialized and want to shard this tensor),
259
+ consider using ``distribute_tensor``.
260
+ """
261
+ if local_tensor.requires_grad and not requires_grad:
262
+ warnings.warn(
263
+ "To construct DTensor from torch.Tensor, it's recommended to "
264
+ "use local_tensor.detach() and make requires_grad consistent."
265
+ )
266
+
267
+ # new method instruct wrapper tensor from local_tensor and add
268
+ # placement spec, it does not do actual distribution
269
+ assert spec.tensor_meta is not None, "TensorMeta should not be None!"
270
+ r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
271
+ cls,
272
+ spec.tensor_meta.shape,
273
+ strides=spec.tensor_meta.stride,
274
+ dtype=local_tensor.dtype,
275
+ device=local_tensor.device,
276
+ layout=local_tensor.layout,
277
+ requires_grad=requires_grad,
278
+ )
279
+
280
+ r._spec = spec
281
+ r._local_tensor = local_tensor
282
+ return r
283
+
284
+ # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently.
285
+ # pyre-fixme[3]: Return type must be annotated.
286
+ def __repr__(self):
287
+ # TODO: consider all_gather the local tensors for better debugging
288
+ return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})"
289
+
290
+ def __tensor_flatten__(self):
291
+ """
292
+ protocol to inform how to flatten a DTensor to local tensor
293
+ for PT2 tracing
294
+ """
295
+ return ["_local_tensor"], (self._spec, self.requires_grad)
296
+
297
+ @staticmethod
298
+ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
299
+ assert (
300
+ flatten_spec is not None
301
+ ), "Expecting spec to be not None from `__tensor_flatten__` return value!"
302
+ local_tensor = inner_tensors["_local_tensor"]
303
+ spec, requires_grad = flatten_spec
304
+ unflatten_tensor_meta = TensorMeta(
305
+ shape=outer_size,
306
+ stride=outer_stride,
307
+ dtype=spec.tensor_meta.dtype,
308
+ )
309
+ unflatten_spec = DTensorSpec(
310
+ spec.mesh,
311
+ spec.placements,
312
+ tensor_meta=unflatten_tensor_meta,
313
+ )
314
+ return DTensor(
315
+ local_tensor,
316
+ unflatten_spec,
317
+ requires_grad=requires_grad,
318
+ )
319
+
320
+ def __coerce_tangent_metadata__(self):
321
+ if not any(isinstance(p, Partial) for p in self.placements):
322
+ return self
323
+ placements = [
324
+ Replicate() if isinstance(p, Partial) else p for p in self.placements
325
+ ]
326
+ return self.redistribute(device_mesh=self.device_mesh, placements=placements)
327
+
328
+ def __coerce_same_metadata_as_tangent__(self, flatten_spec):
329
+ (spec, _) = flatten_spec # Result of tensor_flatten()
330
+ return self.redistribute(
331
+ device_mesh=self.device_mesh,
332
+ placements=spec.placements,
333
+ )
334
+
335
+ @classmethod
336
+ @torch._disable_dynamo
337
+ # pyre-fixme[3]: Return type must be annotated.
338
+ # pyre-fixme[2]: Parameter must be annotated.
339
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
340
+ return DTensor._op_dispatcher.dispatch(
341
+ func,
342
+ args,
343
+ kwargs or {},
344
+ )
345
+
346
+ @staticmethod
347
+ def from_local(
348
+ local_tensor: torch.Tensor,
349
+ device_mesh: Optional[DeviceMesh] = None,
350
+ placements: Optional[Sequence[Placement]] = None,
351
+ *,
352
+ run_check: bool = False,
353
+ shape: Optional[torch.Size] = None,
354
+ stride: Optional[Tuple[int, ...]] = None,
355
+ ) -> "DTensor":
356
+ """
357
+ Create a :class:`DTensor` from a local torch.Tensor on each rank
358
+ according to the ``device_mesh`` and ``placements`` specified.
359
+
360
+ Args:
361
+ local_tensor (torch.Tensor): local torch.Tensor on each rank.
362
+ device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
363
+ tensor, if not specified, must be called under a DeviceMesh
364
+ context manager, default: None
365
+ placements (List[:class:`Placement`], optional): the placements that
366
+ describes how to place the local torch.Tensor on DeviceMesh, must
367
+ have the same number of elements as ``device_mesh.ndim``.
368
+
369
+ Keyword args:
370
+ run_check (bool, optional): at a cost of extra communications, perform
371
+ sanity check across ranks to check each local tensor's meta information
372
+ to ensure correctness. If have :class:`Replicate` in ``placements``, the
373
+ data on first rank of the device mesh dimension will be broadcasted
374
+ to other ranks. default: False
375
+ shape (torch.Size, optional): A List of int which specifies the size of
376
+ DTensor which build on top of `local_tensor`. Note this needs to be
377
+ provided if the shape of ``local_tensor`` are different across the ranks.
378
+ If not provided, ``shape`` will be computed assuming the given distributed
379
+ tensor is evenly sharded across ranks. default: None
380
+ stride (tuple, optional): A List of int which specifies the stride of DTensor.
381
+ If not provided, ``stride`` will be computed assuming the given distributed
382
+ tensor is evenly sharded across ranks. default: None
383
+
384
+ Returns:
385
+ A :class:`DTensor` object
386
+
387
+ .. note:: When ``run_check=False``, it is the user's responsibility to ensure the
388
+ local tensor passed in is correct across ranks (i.e. the tensor is sharded for
389
+ the ``Shard(dim)`` placement or replicated for the ``Replicate()`` placement).
390
+ If not, the behavior of the created DTensor is undefined.
391
+
392
+ .. note:: ``from_local`` is differentiable, the `requires_grad` of the created
393
+ `DTensor` object will depend on if `local_tensor` requires_grad or not.
394
+ """
395
+ # if same shape/dtype, no need to run_check, if not, must allgather
396
+ # the metadatas to check the size/dtype across ranks
397
+ # There should be no data communication unless there's replication
398
+ # strategy, where we broadcast the replication from the first rank
399
+ # in the mesh dimension
400
+ device_mesh = device_mesh or _mesh_resources.get_current_mesh()
401
+ device_type = device_mesh.device_type
402
+
403
+ # convert the local tensor to desired device base on device mesh's device_type
404
+ if device_type != local_tensor.device.type and not local_tensor.is_meta:
405
+ local_tensor = local_tensor.to(device_type)
406
+
407
+ # set default placements to replicated if not specified
408
+ if placements is None:
409
+ placements = [Replicate() for _ in range(device_mesh.ndim)]
410
+ else:
411
+ placements = list(placements)
412
+ for idx, placement in enumerate(placements):
413
+ # normalize shard dim to be positive
414
+ if placement.is_shard():
415
+ placement = cast(Shard, placement)
416
+ if placement.dim < 0:
417
+ placements[idx] = Shard(placement.dim + local_tensor.ndim)
418
+
419
+ # `from_local` is differentiable, and the gradient of the dist tensor this function
420
+ # created should flow back the gradients to the local_tensor, so we call an autograd
421
+ # function to construct the dist tensor instead.
422
+ return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func
423
+ local_tensor,
424
+ device_mesh,
425
+ tuple(placements),
426
+ run_check,
427
+ shape,
428
+ stride,
429
+ )
430
+
431
+ def to_local(
432
+ self, *, grad_placements: Optional[Sequence[Placement]] = None
433
+ ) -> torch.Tensor:
434
+ """
435
+ Get the local tensor of this DTensor on its current rank. For sharding it returns
436
+ a local shard of the logical tensor view, for replication it returns the replica on
437
+ its current rank.
438
+
439
+ Keyword args:
440
+ grad_placements (List[:class:`Placement`], optional): the placements describes
441
+ the future layout of any gradient layout of the Tensor returned from this
442
+ function.
443
+ `to_local` converts DTensor to local tensor and the returned local tensor
444
+ might not be used as the original DTensor layout later in the code. This
445
+ argument is the hint that user can give to autograd in case the gradient
446
+ layout of the returned tensor does not match the original DTensor layout.
447
+ If not specified, we will assume the gradient layout remains the same
448
+ as the original DTensor and use that for gradient computation.
449
+
450
+ Returns:
451
+ A :class:`torch.Tensor` or ``AsyncCollectiveTensor`` object. it represents the
452
+ local tensor on its current rank. When an ``AsyncCollectiveTensor`` object is returned,
453
+ it means the local tensor is not ready yet (i.e. communication is not finished). In this
454
+ case, user needs to call ``wait`` to wait the local tensor to be ready.
455
+
456
+ .. note:: ``to_local`` is differentiable, the ``requires_grad`` of the local tensor returned
457
+ will depend on if the `DTensor` requires_grad or not.
458
+ """
459
+ if not torch.is_grad_enabled():
460
+ return self._local_tensor
461
+
462
+ if grad_placements is not None and not isinstance(grad_placements, tuple):
463
+ grad_placements = tuple(grad_placements)
464
+ return _ToTorchTensor.apply(
465
+ self, grad_placements
466
+ ) # pyre-ignore[16]: autograd func
467
+
468
+ def redistribute(
469
+ self,
470
+ device_mesh: Optional[DeviceMesh] = None,
471
+ placements: Optional[Sequence[Placement]] = None,
472
+ *,
473
+ async_op: bool = False,
474
+ ) -> "DTensor":
475
+ """
476
+ ``redistribute`` performs necessary collective operations that redistribute the current
477
+ DTensor from its current placements to a new placements, or from is current DeviceMesh
478
+ to a new DeviceMesh. i.e. we can turn a Sharded DTensor to a Replicated DTensor by
479
+ specifying a Replicate placement for each dimension of the DeviceMesh.
480
+
481
+ When redistributing from current to the new placements on one device mesh dimension, we
482
+ will perform the following operations including communication collective or local operation:
483
+
484
+ 1. ``Shard(dim)`` -> ``Replicate()``: ``all_gather``
485
+ 2. ``Shard(src_dim)`` -> ``Shard(dst_dim)``: ``all_to_all``
486
+ 3. ``Replicate()`` -> ``Shard(dim)``: local chunking (i.e. ``torch.chunk``)
487
+ 4. ``Partial()`` -> ``Replicate()``: ``all_reduce``
488
+ 5. ``Partial()`` -> ``Shard(dim)``: ``reduce_scatter``
489
+
490
+
491
+ ``redistribute`` would correctly figure out the necessary redistribute steps for DTensors
492
+ that are created either on 1-D or N-D DeviceMesh.
493
+
494
+ Args:
495
+ device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
496
+ DTensor. If not specified, it would use the current DTensor's DeviceMesh.
497
+ default: None
498
+ placements (List[:class:`Placement`], optional): the new placements that
499
+ describes how to place the DTensor into the DeviceMesh, must
500
+ have the same number of elements as ``device_mesh.ndim``.
501
+ default: replicate on all mesh dimensions
502
+
503
+ Keyword args:
504
+ async_op (bool, optional): whether to perform the DTensor redistribute operation
505
+ asynchronously or not. Default: False
506
+
507
+ Returns:
508
+ A :class:`DTensor` object
509
+
510
+ .. note:: ``redistribute`` is differentiable, which means user do not need to worry about
511
+ the backward formula of the redistribute operation.
512
+
513
+ .. note:: ``redistribute`` currently only supports redistributing DTensor on the same DeviceMesh,
514
+ Please file an issue if you need to redistribute DTensor to different DeviceMesh.
515
+ """
516
+ # NOTE: This redistribute API currently only supports out
517
+ # of place redistribution, i.e. it always create a new
518
+ # DTensor object and leave the original one unchanged.
519
+
520
+ # if device_mesh is not specified, use the current device_mesh
521
+ device_mesh = device_mesh or self.device_mesh
522
+ # raise error if new placements not specified
523
+ if placements is None:
524
+ raise RuntimeError("placements is needed for redistribute!")
525
+
526
+ placements = list(placements)
527
+ for i, placement in enumerate(placements):
528
+ if placement.is_partial():
529
+ raise RuntimeError(
530
+ "Can not redistribute to Partial, redistributing to Partial is for internal use only!"
531
+ )
532
+ elif isinstance(placement, Shard) and placement.dim < 0:
533
+ # normalize shard dim to be positive
534
+ placements[i] = Shard(placement.dim + self.ndim)
535
+ placements = tuple(placements)
536
+
537
+ # pyre-fixme[16]: `Redistribute` has no attribute `apply`.
538
+ return Redistribute.apply(self, device_mesh, placements, async_op)
539
+
540
+ def full_tensor(
541
+ self, *, grad_placements: Optional[Sequence[Placement]] = None
542
+ ) -> torch.Tensor:
543
+ """
544
+ Return the full tensor of this DTensor. It will perform necessary collectives
545
+ to gather the local tensors from other ranks in its DeviceMesh and concatenate
546
+ them together. It's a syntatic sugar of the following code:
547
+
548
+ ``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()``
549
+
550
+ Keyword args:
551
+ grad_placements (List[:class:`Placement`], optional): the placements describes
552
+ the future layout of any gradient layout of the full Tensor returned from this
553
+ function.
554
+ `full_tensor` converts DTensor to a full torch.Tensor and the returned torch.tensor
555
+ might not be used as the original replicated DTensor layout later in the code. This
556
+ argument is the hint that user can give to autograd in case the gradient
557
+ layout of the returned tensor does not match the original replicated DTensor layout.
558
+ If not specified, we will assume the gradient layout of the full tensor be replicated.
559
+
560
+ Returns:
561
+ A :class:`torch.Tensor` object that represents the full tensor of this DTensor.
562
+
563
+ .. note:: ``full_tensor`` is differentiable.
564
+ """
565
+
566
+ redist_res = self.redistribute(
567
+ placements=[Replicate()] * self.device_mesh.ndim, async_op=False
568
+ )
569
+ return _ToTorchTensor.apply(redist_res, grad_placements)
570
+
571
+ @property
572
+ def device_mesh(self) -> DeviceMesh:
573
+ """
574
+ The :class:`DeviceMesh` attribute that associates with this DTensor object.
575
+
576
+ .. note:: ``device_mesh`` is a read-only property, it can not be set.
577
+ """
578
+ return self._spec.mesh
579
+
580
+ @property
581
+ def placements(self) -> Tuple[Placement, ...]:
582
+ """
583
+ The placements attribute of this DTensor that describes the layout of this
584
+ DTensor on the its DeviceMesh.
585
+
586
+ .. note:: ``placements`` is a read-only property, it can not be set.
587
+ """
588
+ return self._spec.placements
589
+
590
+ def __create_write_items__(self, fqn: str, object: Any):
591
+ from torch.distributed.checkpoint.planner_helpers import (
592
+ _create_write_items_for_dtensor,
593
+ )
594
+
595
+ if hasattr(self._local_tensor, "__create_write_items__"):
596
+ return self._local_tensor.__create_write_items__(fqn, object) # type: ignore[attr-defined]
597
+ elif isinstance(self._local_tensor, torch.Tensor):
598
+ return [_create_write_items_for_dtensor(fqn, object)]
599
+ else:
600
+ raise RuntimeError("Unsupported tensor type!")
601
+
602
+ def __create_chunk_list__(self):
603
+ from torch.distributed.checkpoint.planner_helpers import (
604
+ _create_chunk_from_dtensor,
605
+ )
606
+
607
+ if hasattr(self._local_tensor, "__create_chunk_list__"):
608
+ return self._local_tensor.__create_chunk_list__() # type: ignore[attr-defined]
609
+ elif isinstance(self._local_tensor, torch.Tensor):
610
+ return [_create_chunk_from_dtensor(self)]
611
+ else:
612
+ raise RuntimeError("Unsupported tensor type!")
613
+
614
+ def __get_tensor_shard__(self, index):
615
+ if hasattr(self._local_tensor, "__get_tensor_shard__"):
616
+ return self._local_tensor.__get_tensor_shard__(index) # type: ignore[attr-defined]
617
+ elif isinstance(self._local_tensor, torch.Tensor):
618
+ return self.to_local()
619
+ else:
620
+ raise RuntimeError("Unsupported tensor type!")
621
+
622
+
623
def distribute_tensor(
    tensor: torch.Tensor,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) to the ``device_mesh`` according
    to the ``placements`` specified. The rank of ``device_mesh`` and ``placements`` must be the
    same. The ``tensor`` to distribute is the logical or "global" tensor, and the API would use
    the ``tensor`` from first rank of the DeviceMesh dimension as the source of truth to preserve
    the single-device semantic. If you want to construct a DTensor in the middle of the Autograd
    computation, please use :meth:`DTensor.from_local` instead.

    Args:
        tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you
            want to shard a tensor on a dimension that is not evenly divisible by
            the number of devices in that mesh dimension, we use ``torch.chunk``
            semantic to shard the tensor and scatter the shards. The uneven sharding
            behavior is experimental and subject to change.
        device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the
            tensor, if not specified, must be called under a DeviceMesh context
            manager, default: None
        placements (List[:class:`Placement`], optional): the placements that
            describes how to place the tensor on DeviceMesh, must have the same
            number of elements as ``device_mesh.ndim``. If not specified, we will
            by default replicate the tensor across the ``device_mesh`` from the
            first rank of each dimension of the `device_mesh`.

    Returns:
        A :class:`DTensor` or ``XLAShardedTensor`` object.

    Raises:
        RuntimeError: if ``tensor`` is not a leaf tensor, or if a placement of
            an unsupported kind is given.
        ValueError: if ``placements`` length mismatches ``device_mesh.ndim``,
            or if ``tensor`` is already a DTensor with a different mesh/placements.

    .. note::
        When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_tensor``
        return `XLAShardedTensor` instead. see `this issue <https://github.com/pytorch/pytorch/issues/92909>`__
        for more details. The XLA integration is experimental and subject to change.
    """

    torch._C._log_api_usage_once("torch.dtensor.distribute_tensor")

    # get default device mesh if there's nothing specified
    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    device_type = device_mesh.device_type
    if device_type == "xla":
        try:
            # call PyTorch/XLA SPMD for `xla` backend type device mesh.
            # This returns XLAShardedTensor
            from torch_xla.distributed.spmd import (  # type:ignore[import]
                xla_distribute_tensor,
            )

            return xla_distribute_tensor(
                tensor, device_mesh, placements
            )  # type:ignore[return-value]
        except ImportError as e:
            msg = "To use DTensor API with xla, you must install the torch_xla package!"
            raise ImportError(msg) from e

    # instantiate a RNG tracker if haven't. By default DTensor uses an
    # OffsetBasedRNGTracker to perform random operators.
    # TODO: the value assignment to global variable is not the ideal solution
    # we can replace it in future.
    if not random._rng_tracker and is_rng_supported_mesh(device_mesh):
        random._rng_tracker = OffsetBasedRNGTracker(device_type)

    if not tensor.is_leaf:
        raise RuntimeError(
            "`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!"
        )

    # convert tensor to the corresponding device type if it's not in that device type
    if device_type != tensor.device.type and not tensor.is_meta:
        tensor = tensor.to(device_type)

    # set default placements to replicated if not specified
    if placements is None:
        placements = [Replicate() for _ in range(device_mesh.ndim)]

    if len(placements) != device_mesh.ndim:
        raise ValueError(
            f"`placements` must have the same length as `device_mesh.ndim`! "
            f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}."
        )
    if isinstance(tensor, DTensor):
        # if the tensor is already a DTensor, we need to check:
        # 1. if we can further shard this DTensor if the two device meshes belong to
        #    the same parent mesh and further sharding is possible.
        # 2. check if device mesh and placements are the same
        if tensor.device_mesh != device_mesh:
            raise ValueError(
                f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} "
                f"to a different device mesh {device_mesh}."
            )
        if tensor.placements != tuple(placements):
            raise ValueError(
                f"Cannot distribute a DTensor with placements {tensor.placements} "
                f"to a different placements {placements}. do you want to call "
                f"`redistribute` instead?"
            )
        # Mesh and placements match: nothing to do, hand the DTensor back.
        return tensor

    # Work on a detached view so sharding does not record autograd history.
    local_tensor = tensor.detach()

    # TODO(xilun): address sharding order
    # distribute the tensor according to the placements.
    # NOTE: local_tensor is progressively narrowed: each iteration applies the
    # placement of one mesh dimension to the result of the previous ones, so
    # the order of this loop matters.
    placements = list(placements)
    for idx, placement in enumerate(placements):
        if placement.is_shard():
            placement = cast(Shard, placement)
            if placement.dim < 0:
                # normalize shard placement dim to be positive
                placement = Shard(placement.dim + tensor.ndim)
                placements[idx] = placement
            local_tensor = placement._shard_tensor(local_tensor, device_mesh, idx)
        elif placement.is_replicate():
            placement = cast(Replicate, placement)
            local_tensor = placement._replicate_tensor(local_tensor, device_mesh, idx)
        else:
            raise RuntimeError(
                f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!"
            )
    placements = tuple(placements)

    assert local_tensor is not None, "distributing a tensor should not be None"
    # detach the local tensor passed to DTensor since after the construction
    # of DTensor, autograd would work on top of DTensor instead of local tensor
    spec = DTensorSpec(
        mesh=device_mesh,
        placements=placements,
        tensor_meta=TensorMeta(
            shape=tensor.size(),
            stride=tensor.stride(),
            dtype=tensor.dtype,
        ),
    )
    return DTensor(
        local_tensor.requires_grad_(tensor.requires_grad),
        spec,
        requires_grad=tensor.requires_grad,
    )
762
+
763
+
764
def distribute_module(
    module: nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
    partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None,
    input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
    output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
) -> nn.Module:
    """
    This function exposes three hooks to control the parameters/inputs/outputs of the module:

    1. To perform sharding on the module before runtime execution by specifying the
        ``partition_fn`` (i.e. allow user to convert Module parameters to :class:`DTensor`
        parameters according to the `partition_fn` specified).
    2. To control the inputs or outputs of the module during runtime execution by
        specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to
        :class:`DTensor`, convert the output back to ``torch.Tensor``)

    Args:
        module (:class:`nn.Module`): user module to be partitioned.
        device_mesh (:class:`DeviceMesh`): the device mesh to place the module.
        partition_fn (Callable): the function to partition parameters (i.e. shard certain
            parameters across the ``device_mesh``). If ``partition_fn`` is not specified,
            by default we replicate all module parameters of ``module`` across the mesh.
        input_fn (Callable): specify the input distribution, i.e. could control how the
            input of the module is sharded. ``input_fn`` will be installed as a module
            ``forward_pre_hook`` (pre forward hook).
        output_fn (Callable): specify the output distribution, i.e. could control how the
            output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be
            installed as a module ``forward_hook`` (post forward hook).

    Returns:
        A module that contains parameters/buffers that are all ``DTensor`` s.

    Raises:
        ValueError: if ``input_fn`` or ``output_fn`` takes a number of
            arguments other than 2 (deprecated) or 3.

    .. note::
        When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_module``
        return nn.Module with PyTorch/XLA SPMD annotated parameters. See
        `this issue <https://github.com/pytorch/pytorch/issues/92909>`__
        for more details. The XLA integration is experimental and subject to change.

    """

    torch._C._log_api_usage_once("torch.dtensor.distribute_module")

    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    device_type = device_mesh.device_type
    if device_type == "xla":
        try:
            # This function annotates all module parameters for auto-partitioning with
            # PyTorch/XLA SPMD or explicitly partition to :class:`XLAShardedTensor` parameters
            # according to the `partition_fn` specified.
            from torch_xla.distributed.spmd import (  # type:ignore[import]
                xla_distribute_module,
            )

            return xla_distribute_module(
                module, device_mesh, partition_fn, input_fn, output_fn
            )  # type:ignore[return-value]
        except ImportError as e:
            msg = "To use DTensor API with xla, you must install the torch_xla package!"
            raise ImportError(msg) from e

    def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None:
        # This function loops over the immediate module parameters and
        # buffers, replicating all non-DTensor params/buffers to DTensor
        # parameters/buffers, if they have not been partitioned in the
        # partition_fn. We can't easily use `module._apply` here
        # because we don't know what happened inside partition_fn as
        # user could do anything, i.e. install hooks, and we want to
        # preserve those.
        full_replicate = [Replicate()] * mesh.ndim
        for key, param in m._parameters.items():
            if param is not None and not isinstance(param, DTensor):
                m.register_parameter(
                    key,
                    nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)),
                )
        for key, buffer in m._buffers.items():
            if buffer is not None and not isinstance(buffer, DTensor):
                m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate)

    if partition_fn is None:
        # if partition_fn not specified, we by default replicate
        # all module params/buffers
        for name, submod in module.named_modules():
            replicate_module_params_buffers(submod, device_mesh)
    else:
        # apply partition_fn to submodules, then replicate whatever the
        # partition_fn left as plain (non-DTensor) params/buffers
        for name, submod in module.named_modules():
            partition_fn(name, submod, device_mesh)
            replicate_module_params_buffers(submod, device_mesh)

    # register input_fn as module forward pre hook
    if input_fn is not None:
        # check the input_fn signature to support the deprecated 2-arg form
        num_args = len(inspect.signature(input_fn).parameters)
        if num_args == 2:
            # input_fn only takes in inputs and device mesh
            warnings.warn(
                "Deprecating input_fn that takes two arguments (inputs, device_mesh), "
                "please use input_fn that takes in (module, inputs, device_mesh) instead!",
                FutureWarning,
                stacklevel=2,
            )
            module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh))  # type: ignore[call-arg]
        elif num_args == 3:
            # input_fn takes in module, inputs, device mesh
            module.register_forward_pre_hook(
                lambda mod, inputs: input_fn(mod, inputs, device_mesh)
            )
        else:
            raise ValueError(
                f"input_fn should take in 3 arguments, but got {num_args} arguments!"
            )
    # register output_fn as module forward hook
    if output_fn is not None:
        num_args = len(inspect.signature(output_fn).parameters)
        if num_args == 2:
            # output_fn only takes in outputs and device mesh
            warnings.warn(
                "Deprecating output_fn that takes two arguments (inputs, device_mesh), "
                "please use output_fn that takes in (module, inputs, device_mesh) instead!",
                FutureWarning,
                stacklevel=2,
            )
            module.register_forward_hook(
                lambda mod, inputs, outputs: output_fn(outputs, device_mesh)  # type: ignore[call-arg]
            )
        elif num_args == 3:
            module.register_forward_hook(
                lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)
            )
        else:
            raise ValueError(
                f"output_fn should take in 3 arguments, but got {num_args} arguments!"
            )

    return module
901
+
902
+
903
+ # Below are tensor factory function APIs, which are used to create a DTensor directly. We need
904
+ # to make separate factory function APIs because tensor subclass could not override the tensor
905
+ # factory methods, and we need user to call the factory functions with user intended device_mesh
906
+ # and placements to create a proper DTensor.
907
+
908
+
909
def _dtensor_init_helper(  # type: ignore[no-untyped-def]
    init_op,
    size: torch.Size,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
    **kwargs,
) -> DTensor:
    """
    Shared implementation behind the DTensor factory functions
    (``ones``/``empty``/``full``/``rand``/``randn``): initialize only the
    local shard on each rank with ``init_op`` and wrap it in a DTensor.

    Args:
        init_op: a ``torch`` factory callable, e.g. ``torch.ones``.
        size: the global (logical) shape of the DTensor to create.
        device_mesh: mesh to place the DTensor on; defaults to the current
            mesh from the DeviceMesh context.
        placements: per-mesh-dim placements; defaults to all ``Replicate()``.

    Keyword args are forwarded to ``init_op``; ``layout`` must be
    ``torch.strided`` and ``requires_grad`` must be present (all public
    factory wrappers pass both).
    """
    # from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta

    # if device_mesh is None, use the one from mesh resources
    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    kwargs["device"] = device_mesh.device_type

    # set default placements to replicated if not specified
    placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim))

    # check device_mesh against placements
    assert device_mesh.ndim == len(
        placements
    ), "mesh dimension does not match the length of placements"

    assert kwargs["layout"] == torch.strided, "layout value not supported!"
    # Global stride assuming a contiguous logical tensor.
    torch_stride = torch._prims_common.make_contiguous_strides_for(size)

    # get local tensor shape
    local_shape = compute_local_shape(size, device_mesh, placements)
    # initialize the local tensor
    if init_op == torch.full:
        # torch.full takes fill_value positionally, so pop it from kwargs.
        fill_value = kwargs.pop("fill_value", 0)
        local_tensor = init_op(local_shape, fill_value, **kwargs)
    elif init_op == torch.rand or init_op == torch.randn:
        # Random ops must run inside the distributed RNG region so each
        # shard draws the values it would have under single-device semantics.
        # this tensor meta is not used except `shape`
        dtype = kwargs.get("dtype", torch.get_default_dtype())

        tensor_meta = TensorMeta(size, (0,), dtype)
        spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta)

        if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker:
            random._rng_tracker = random.OffsetBasedRNGTracker()

        assert random._rng_tracker is not None
        with random._rng_tracker._distribute_region(spec):
            local_tensor = init_op(local_shape, **kwargs)
    else:
        local_tensor = init_op(local_shape, **kwargs)

    # Spec records the *global* shape/stride; the wrapped tensor is local.
    spec = DTensorSpec(
        device_mesh,
        tuple(placements),
        tensor_meta=TensorMeta(
            size,
            torch_stride,
            local_tensor.dtype,
        ),
    )

    return DTensor(
        local_tensor,
        spec,
        requires_grad=kwargs["requires_grad"],
    )
970
+
971
+
972
def ones(  # type: ignore[no-untyped-def]
    *size,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    requires_grad: bool = False,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Create a :class:`DTensor` whose every element is the scalar value 1, with
    the shape given by ``size``.

    Args:
        size (int...): a sequence of integers defining the shape of the output
            :class:`DTensor`; either variadic ints or a single list/tuple,
            e.g. ``ones(1, 2, 3)``, ``ones([1, 2, 3])`` or ``ones((1, 2, 3))``.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): desired element type; when
            ``None`` the global default is used (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): desired layout of the
            returned DTensor. Default: ``torch.strided``.
        requires_grad (bool, optional): whether autograd should record
            operations on the returned :class:`DTensor`. Default: ``False``.
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    # Normalize the variadic size spec, then delegate to the shared factory.
    return _dtensor_init_helper(
        torch.ones,
        normalize_to_torch_size(size),
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        device_mesh=device_mesh,
        placements=placements,
    )
1013
+
1014
+
1015
def empty(  # type: ignore[no-untyped-def]
    *size,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    requires_grad: bool = False,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Create a :class:`DTensor` with uninitialized data, with the shape given
    by ``size``.

    Args:
        size (int...): a sequence of integers defining the shape of the output
            :class:`DTensor`; either variadic ints or a single list/tuple,
            e.g. ``empty(1, 2, 3)``, ``empty([1, 2, 3])`` or ``empty((1, 2, 3))``.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): desired element type; when
            ``None`` the global default is used (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): desired layout of the
            returned :class:`DTensor`. Default: ``torch.strided``.
        requires_grad (bool, optional): whether autograd should record
            operations on the returned :class:`DTensor`. Default: ``False``.
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    # Normalize the variadic size spec, then delegate to the shared factory.
    return _dtensor_init_helper(
        torch.empty,
        normalize_to_torch_size(size),
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        device_mesh=device_mesh,
        placements=placements,
    )
1056
+
1057
+
1058
def full(  # type: ignore[no-untyped-def]
    size,
    fill_value,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    requires_grad: bool = False,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Create a :class:`DTensor` whose every element is ``fill_value``, placed
    according to ``device_mesh`` and ``placements``, with the shape given by
    ``size``.

    Args:
        size (int...): a sequence of integers defining the shape of the output
            :class:`DTensor`; either variadic ints or a single list/tuple,
            e.g. ``full((1, 2, 3), v)`` or ``full([1, 2, 3], v)``.
        fill_value(Scalar): the value to fill the output tensor with.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): desired element type; when
            ``None`` the global default is used (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): desired layout of the
            returned DTensor. Default: ``torch.strided``.
        requires_grad (bool, optional): whether autograd should record
            operations on the returned :class:`DTensor`. Default: ``False``.
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    # Normalize the size spec, then delegate to the shared factory; the
    # helper pops `fill_value` back out before calling torch.full.
    return _dtensor_init_helper(
        torch.full,
        normalize_to_torch_size(size),
        fill_value=fill_value,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        device_mesh=device_mesh,
        placements=placements,
    )
1103
+
1104
+
1105
def rand(  # type: ignore[no-untyped-def]
    *size,
    requires_grad: bool = False,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Create a :class:`DTensor` filled with random numbers drawn uniformly from
    ``[0, 1)``, with the shape given by ``size``.

    Args:
        size (int...): a sequence of integers defining the shape of the output
            :class:`DTensor`; either variadic ints or a single list/tuple,
            e.g. ``rand(1, 2, 3)``, ``rand([1, 2, 3])`` or ``rand((1, 2, 3))``.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): desired element type; when
            ``None`` the global default is used (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): desired layout of the
            returned DTensor. Default: ``torch.strided``.
        requires_grad (bool, optional): whether autograd should record
            operations on the returned :class:`DTensor`. Default: ``False``.
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    # Normalize the variadic size spec, then delegate to the shared factory,
    # which runs torch.rand inside the distributed RNG region.
    return _dtensor_init_helper(
        torch.rand,
        normalize_to_torch_size(size),
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        device_mesh=device_mesh,
        placements=placements,
    )
1147
+
1148
+
1149
def randn(  # type: ignore[no-untyped-def]
    *size,
    requires_grad: bool = False,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with random numbers from a standard
    normal distribution (mean 0, variance 1). The shape of the tensor is
    defined by the variable argument ``size``.

    Args:
        size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
            Can be a variable number of arguments or a collection like a list or tuple.
            E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned :class:`DTensor`. Default: ``False``.
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    # Normalize the variadic/collection size spec before delegating.
    normalized_size = normalize_to_torch_size(size)
    return _dtensor_init_helper(
        torch.randn,
        normalized_size,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        device_mesh=device_mesh,
        placements=placements,
    )
1191
+
1192
+
1193
def zeros(  # type: ignore[no-untyped-def]
    *size,
    requires_grad: bool = False,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with the scalar value 0.

    Args:
        size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
            Can be a variable number of arguments or a collection like a list or tuple.
            E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..))
    Keyword args:
        requires_grad (bool, optional): If autograd should record operations on the
            returned :class:`DTensor`. Default: ``False``.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`.
            Default: ``torch.strided``.
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    # Delegate straight to the shared init helper with torch.zeros as factory.
    return _dtensor_init_helper(
        torch.zeros,
        normalize_to_torch_size(size),
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        device_mesh=device_mesh,
        placements=placements,
    )
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_collective_utils.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import logging
3
+ import math
4
+ from dataclasses import dataclass
5
+ from functools import lru_cache
6
+ from typing import List, Optional
7
+
8
+ import torch
9
+ import torch.distributed._functional_collectives as funcol
10
+ import torch.distributed.tensor._dtensor_spec as dtensor_spec
11
+ from torch._C._distributed_c10d import _resolve_process_group
12
+ from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
13
+ from torch.distributed.distributed_c10d import (
14
+ _get_group_size_by_name,
15
+ broadcast,
16
+ get_global_rank,
17
+ get_group_rank,
18
+ get_rank,
19
+ GroupMember,
20
+ ProcessGroup,
21
+ scatter,
22
+ Work,
23
+ )
24
+
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
# Register a fake (meta) kernel for the custom `_dtensor::shard_dim_alltoall`
# op so tracing/compile can infer output shapes without communication.
# `torch.library.register_fake` is skipped under torch::deploy, where
# functional collectives are unsupported; we warn instead.
if not torch._running_with_deploy():

    @torch.library.register_fake("_dtensor::shard_dim_alltoall")
    def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name):
        # Mimic the shape effect of the real collective: concatenate
        # `group_size` copies along `gather_dim`, then keep this rank's chunk
        # along `shard_dim`. No data is moved — tensors here are meta/empty.
        group_size = _get_group_size_by_name(group_name)
        stacked_list = [torch.empty_like(input) for _ in range(group_size)]
        group = _resolve_process_group(group_name)
        group_rank = get_group_rank(group, get_rank())

        return torch.cat(stacked_list, dim=gather_dim).chunk(group_size, dim=shard_dim)[
            group_rank
        ]

else:
    import warnings

    warnings.warn(
        "PyTorch Distributed functional collectives do not work with torch::deploy."
    )
48
+
49
+
50
def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim):
    """
    Reshard ``input`` from being sharded on ``shard_dim`` to being sharded on
    ``gather_dim`` across ``mesh_dim`` of the device mesh, via alltoall
    (or an allgather+chunk fallback on CPU, where Gloo lacks alltoall).
    """
    if mesh.device_type != "cpu":
        group_name = funcol._resolve_group_name((mesh, mesh_dim))
        # TODO: enable async op for shard_dim_alltoall
        return torch.ops._dtensor.shard_dim_alltoall(
            input, gather_dim, shard_dim, group_name
        )

    # Gloo does not support alltoall, so falling back to allgather + chunk
    # TODO: This logs way too much
    logger.warning(
        "CPU process group does not support alltoall yet, falling back with allgather + chunk!"
    )
    gathered = funcol.all_gather_tensor(input, gather_dim, (mesh, mesh_dim))
    if isinstance(gathered, funcol.AsyncCollectiveTensor):
        # stick to the same behavior for the alltoall case, remove this once we enable alltoall async
        gathered = gathered.wait()
    local_chunk = torch.chunk(gathered, mesh.size(mesh_dim), dim=shard_dim)[
        mesh.get_local_rank(mesh_dim)
    ]
    if local_chunk.is_contiguous():
        return local_chunk
    return local_chunk.contiguous()
72
+
73
+
74
def mesh_scatter(
    output: torch.Tensor,
    scatter_list: List[torch.Tensor],
    mesh: DeviceMesh,
    mesh_dim: int = 0,
    async_op: bool = False,
) -> Optional[Work]:
    """
    scatter a list of tensors to a device mesh dimension. We by default
    use the first rank of the mesh dimension as the source of truth, i.e
    for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will
    scatter the tensor list on rank 0 to rank 0/1, and tensor list on rank
    2 to rank 2/3.

    Args:
        output (torch.Tensor): the tensor to receive the scattered list.
        scatter_list (List[torch.Tensor]): the tensor list to be scattered.
        mesh_dim (int, optional): indicate which mesh dimension we want
            to scatter on, we by default choose the first rank on the
            mesh dimension as source of truth.

    Returns:
        A :class:`Work` object if ``async_op``, otherwise ``None``
        (also ``None`` for meta tensors, where no communication happens).
    """
    # TODO: Ideally we should use the meta tensor way
    # (to register a meta kernel for the collective op)
    # so that it would avoid the communication. Need to
    # remove the check below once that is done.
    if output.is_meta:
        return None

    dim_group = mesh.get_group(mesh_dim)
    assert isinstance(dim_group, ProcessGroup)

    # `scatter` wants the source as a *global* rank; translate rank 0 of the
    # sub-group unless the group is the whole world.
    src_for_dim = 0
    if dim_group is not GroupMember.WORLD:
        src_for_dim = get_global_rank(dim_group, 0)

    # Only the source rank supplies the tensor list; every other rank passes
    # None and just receives into `output`.
    tensors_to_scatter = scatter_list if src_for_dim == get_rank() else None
    return scatter(
        output,
        scatter_list=tensors_to_scatter,
        src=src_for_dim,
        group=dim_group,
        async_op=async_op,
    )
130
+
131
+
132
def mesh_broadcast(
    tensor: torch.Tensor,
    mesh: DeviceMesh,
    mesh_dim: int = 0,
    async_op: bool = False,
) -> Optional[Work]:
    """
    broadcast the tensor to a device mesh dimension. We by default
    use the first rank of the mesh dimension as the source of truth, i.e
    for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will
    broadcast the tensor on rank 0 to rank 0/1, and tensor on rank 2
    to rank 2/3.

    Args:
        tensor (torch.Tensor): tensor to broadcast.
        mesh_dim (int, optional): indicate which mesh dimension we want
            to broadcast on, we by default choose the first rank on the
            mesh dimension as source of truth.

    Returns:
        A :class:`Work` object if ``async_op``, otherwise ``None``
        (also ``None`` for meta tensors, where no communication happens).
    """
    # TODO: Ideally we should use the meta tensor way
    # (to register a meta kernel for the collective op)
    # so that it would avoid the communication. Need to
    # remove the check below once that is done.
    if tensor.is_meta:
        return None

    dim_group = mesh.get_group(mesh_dim)
    assert isinstance(dim_group, ProcessGroup)

    # `broadcast` wants a global source rank: rank 0 of the sub-group,
    # translated unless the group already is the whole world.
    src_for_dim = (
        get_global_rank(dim_group, 0) if dim_group is not GroupMember.WORLD else 0
    )
    return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op)
168
+
169
+
170
def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
    """Zero-pad ``tensor`` by ``pad_size`` elements at the end of ``pad_dim``.

    Returns ``tensor`` unchanged when ``pad_size`` is 0.
    """
    if not pad_size:
        return tensor
    # F.pad's spec lists (before, after) pairs from the LAST dim backwards,
    # so the final entry is the "after" padding for `pad_dim`.
    pad_spec = [0] * (2 * (tensor.ndim - pad_dim))
    pad_spec[-1] = pad_size
    return torch.nn.functional.pad(tensor, pad_spec)
176
+
177
+
178
def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
    """Drop the trailing ``pad_size`` elements of ``tensor`` along ``pad_dim``.

    Inverse of :func:`pad_tensor`; returns ``tensor`` unchanged when
    ``pad_size`` is 0. The result is a narrow view, not a copy.
    """
    if not pad_size:
        return tensor
    kept_length = tensor.size(pad_dim) - pad_size
    return tensor.narrow(pad_dim, start=0, length=kept_length)
186
+
187
+
188
def fill_empty_tensor_to_shards(
    shards: List[torch.Tensor], shard_dim: int, num_empty_tensors: int
) -> List[torch.Tensor]:
    """Append ``num_empty_tensors`` zero-extent shards to ``shards`` in place.

    Each appended shard matches ``shards[0]`` except that ``shard_dim`` has
    size 0. The (mutated) input list is returned.
    """
    if num_empty_tensors == 0:
        return shards
    empty_size = list(shards[0].size())
    empty_size[shard_dim] = 0  # zero extent along the shard dimension
    empty_shard = shards[0].new_zeros(empty_size)
    # The same zero-size tensor object is reused for every padding slot.
    shards.extend(empty_shard for _ in range(num_empty_tensors))
    return shards
201
+
202
+
203
def check_tensor_meta(
    local_tensor, check_shape_stride=False
) -> Optional["dtensor_spec.TensorMeta"]:
    """Verify tensor metadata is identical on every rank.

    Gathers a small metadata dict (dtype, requires_grad, and optionally
    shape/stride) from all ranks with ``all_gather_object`` and raises
    ``ValueError`` on any mismatch. Always returns ``None`` on success.
    Requires an initialized default process group.
    """
    meta = {
        "dtype": local_tensor.dtype,
        "requires_grad": local_tensor.requires_grad,
    }
    if check_shape_stride:
        meta["shape"] = local_tensor.shape
        meta["stride"] = local_tensor.stride()

    world_size = torch.distributed.get_world_size()
    all_meta = [None] * world_size
    torch.distributed.all_gather_object(all_meta, meta)

    # Every rank must have gathered exactly this rank's metadata.
    if any(other != meta for other in all_meta):
        raise ValueError(
            "Inconsistent tensor metadata (including shape and stride) across ranks."
        )
    return None
225
+
226
+
227
+ def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int:
228
+ assert spec.tensor_meta is not None, "spec should have tensor meta defined!"
229
+ return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape)
230
+
231
+
232
@dataclass
class MeshTopoInfo:
    """
    Mesh information for collective cost estimation
    """

    # The device mesh this topology describes.
    mesh: DeviceMesh
    # Number of devices along each mesh dimension.
    mesh_dim_devices: List[int]
    # Estimated bandwidth per mesh dimension, in GB/s.
    mesh_dim_bandwidth: List[float]
    # Estimated per-hop latency per mesh dimension, in microseconds.
    mesh_dim_latency: List[float]

    @staticmethod
    @lru_cache(None)  # cached per DeviceMesh (keyed on mesh identity/hash)
    def build_from_mesh(mesh: DeviceMesh) -> "MeshTopoInfo":
        """Derive a heuristic topology (bandwidth/latency per dim) from ``mesh``."""
        # Generate mesh topology info for intra-host/inter-host communication pattern
        # Note that we made bunch of assumptions for simplicity:
        # 1. we assume the mesh is homogeneous, and it's gpu/nccl model
        # 2. we assume gpu arch is Ampere or Hopper
        # 3. we assume collectives are all ring base algo for now
        num_devices_per_host = _mesh_resources.num_devices_per_host(mesh.device_type)
        # the base bw number (intra-node), GB/s
        base_bw = 87.7
        mesh_dim_bandwidth = [base_bw] * mesh.ndim
        # the latency in terms of us (intra-node, nv-link)
        mesh_dim_latency = [0.6] * mesh.ndim
        mesh_dim_devices = [1] * mesh.ndim

        total_num_devices = 1
        # Walk dims from innermost to outermost; once the cumulative device
        # count spills past one host, that dim (and all outer ones) are
        # treated as inter-host links.
        for mesh_dim in reversed(range(mesh.ndim)):
            num_devices = mesh.size(mesh_dim)
            mesh_dim_devices[mesh_dim] = num_devices
            total_num_devices *= num_devices
            if total_num_devices > num_devices_per_host:
                # magic number for inter-host communication bandwidth/latency factor
                # This number assumes latest GPU arch, i.e. Ampere or Hopper
                # TODO: see if we need to tweak this or offer a way for user
                # to specify the bandwidths/latency
                mesh_dim_bandwidth[mesh_dim] *= 0.22
                # set to ethernet latency for inter-host
                mesh_dim_latency[mesh_dim] = 2.7

        return MeshTopoInfo(
            mesh, mesh_dim_devices, mesh_dim_bandwidth, mesh_dim_latency
        )
276
+
277
+
278
def allgather_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
    """Estimate ring all-gather time (microseconds) for ``bytes_gb`` GB on ``mesh_dim``."""
    num_devices = mesh_topo.mesh_dim_devices[mesh_dim]
    bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
    hops = num_devices - 1
    # base latency + per-hop comm latency (us)
    latency_us = 6.6 + hops * mesh_topo.mesh_dim_latency[mesh_dim]
    # ring algorithm moves (hops / num_devices) of the data per device (s)
    transfer_s = (bytes_gb * hops / num_devices) / bandwidth
    return latency_us + transfer_s * 1e6  # rescale to us
286
+
287
+
288
def allreduce_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
    """Estimate ring all-reduce time (microseconds) for ``bytes_gb`` GB on ``mesh_dim``."""
    num_devices = mesh_topo.mesh_dim_devices[mesh_dim]
    bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
    # allreduce ~= reduce_scatter + allgather, so nearly 2x the hops/bytes
    hops = 2 * num_devices - 1
    latency_us = 6.6 + hops * mesh_topo.mesh_dim_latency[mesh_dim]
    transfer_s = (bytes_gb * hops / num_devices) / bandwidth
    return latency_us + transfer_s * 1e6  # rescale to us
297
+
298
+
299
def reduce_scatter_cost(
    bytes_gb: float,
    mesh_topo: MeshTopoInfo,
    mesh_dim: int,
) -> float:
    """Estimate ring reduce-scatter time (microseconds) for ``bytes_gb`` GB on ``mesh_dim``."""
    num_devices = mesh_topo.mesh_dim_devices[mesh_dim]
    bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
    hops = num_devices - 1
    # base latency + per-hop comm latency (us)
    latency_us = 6.6 + hops * mesh_topo.mesh_dim_latency[mesh_dim]
    transfer_s = (bytes_gb * hops / num_devices) / bandwidth
    return latency_us + transfer_s * 1e6  # rescale to us
311
+
312
+
313
def redistribute_cost(
    current_spec: "dtensor_spec.DTensorSpec",
    target_spec: "dtensor_spec.DTensorSpec",
) -> float:
    """
    This function returns the cost of redistribute from current to target DTensorSpec.

    NOTE:
    1. Only consider communication cost here, since computation costs for redistribute
       are quite trivial (i.e. we only need to narrow or simple division)
    2. Only consider redistribute cost on same mesh, cross mesh communication cost is
       not quite needed for operator strategy estimation/selection.
    """
    if current_spec.mesh != target_spec.mesh:
        # make infinite cost if meshes are not same
        # TODO: see if we want to support this once there's cross mesh communication
        return float("inf")

    if current_spec.is_replicated():
        # short-cut:
        # comm cost is 0 if current spec is already full replication
        return 0.0

    mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
    cost = 0.0
    # Start from the per-shard payload size in GB; this running value is
    # deliberately mutated below as earlier mesh dims grow (allgather) or
    # shrink (reduce_scatter) the data carried into later dims.
    comm_bytes_gb = (
        spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
    )
    # Transformation that considered for redistribute cost:
    # 1. allgather 2. alltoall
    # 3. allreduce 4. reduce_scatter
    for i, (current, target) in enumerate(
        zip(current_spec.placements, target_spec.placements)
    ):
        if current == target:
            continue

        num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
        if current.is_shard() and target.is_replicate():
            # allgather gives larger comm bytes
            comm_bytes_gb *= num_devices_on_mesh_dim
            # add up allgather comm cost
            cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
        elif current.is_shard() and target.is_shard():
            # should be alltoall comm, since we haven't implement it yet, add penalty
            # to favor allgather instead
            cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + 1.0
        elif current.is_partial() and target.is_replicate():
            # add up allreduce comm cost
            cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
        elif current.is_partial() and target.is_shard():
            # add up reduce_scatter comm cost
            cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
            # after reduce_scatter the comm bytes for further collectives halved.
            comm_bytes_gb /= num_devices_on_mesh_dim
        elif current.is_shard() and target.is_partial():
            # ban shard -> partial as it does not make sense to perform
            # this redistribute
            return float("inf")

    return cost
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py ADDED
@@ -0,0 +1,510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ import contextlib
3
+ import functools
4
+ import logging
5
+ import operator
6
+ import warnings
7
+ from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
8
+
9
+ import torch
10
+ import torch.distributed as dist
11
+ import torch.distributed.tensor._api as dtensor
12
+ import torch.distributed.tensor._random as random
13
+ from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
14
+ from torch.distributed.tensor._op_schema import (
15
+ _is_inplace_op,
16
+ _is_out_variant_op,
17
+ OpInfo,
18
+ OpSchema,
19
+ OutputSpecType,
20
+ )
21
+ from torch.distributed.tensor._random import is_rng_supported_mesh
22
+ from torch.distributed.tensor._redistribute import redistribute_local_tensor
23
+ from torch.distributed.tensor._sharding_prop import ShardingPropagator
24
+ from torch.distributed.tensor._tp_conv import (
25
+ convolution_backward_handler,
26
+ convolution_handler,
27
+ )
28
+ from torch.distributed.tensor._utils import try_find_mesh_from_args
29
+ from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
30
+
31
+
32
+ if TYPE_CHECKING:
33
+ from torch.distributed.device_mesh import DeviceMesh
34
+
35
+ try:
36
+ from torch.utils import _cxx_pytree as pytree
37
+ except ImportError:
38
+ from torch.utils import _pytree as pytree # type: ignore[no-redef]
39
+
40
+ aten = torch.ops.aten
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
def decompose_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    """
    Decomposes a op to core ATen op, this handler is mostly here
    for inference mode usage where the ops are not core aten ops.

    Raises:
        RuntimeError: when ``op_call`` has no registered decomposition.
    """
    decomposed = op_call.decompose(*args, **kwargs)
    if decomposed is NotImplemented:
        raise RuntimeError("Decomposition failed")
    return decomposed
58
+
59
+
60
def is_same_size_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> bool:
    """Shape-compare the first two tensor args (for ``aten.is_same_size``)."""
    first = cast(torch.Tensor, args[0])
    second = cast(torch.Tensor, args[1])
    return first.shape == second.shape
68
+
69
+
70
def found_inf_reduce_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> None:
    """
    Handler for ``aten._amp_foreach_non_finite_check_and_unscale_``: run the
    op on the local shards, then reduce the ``found_inf`` flag across the
    mesh (max-reduce on non-replicated dims) so every rank agrees whether any
    shard contained inf/nan. Mutates ``args[1]`` (the found_inf tensor) in place.
    """
    op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    local_tensor_args = pytree.tree_unflatten(
        cast(List[object], op_info.local_args), op_info.args_tree_spec
    )
    local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
    # The op is in-place; it is invoked for its side effects on the local
    # grads and local found_inf tensor (its return value is unused).
    local_results = op_call(*local_tensor_args, **op_info.local_kwargs)

    grad_dtensor = cast(list[dtensor.DTensor], args[0])[0]
    grad_placements = grad_dtensor.placements
    mesh = grad_dtensor.device_mesh

    # Mirror the grad's placements: replicated dims stay replicated, every
    # other dim becomes Partial("max") so full_tensor() max-reduces found_inf.
    found_inf_placements: list[Placement] = []
    for placement in grad_placements:
        if isinstance(placement, Replicate):
            found_inf_placements.append(placement)
        else:
            found_inf_placements.append(Partial("max"))

    target_tensor = cast(torch.Tensor, args[1])
    spec = DTensorSpec(
        mesh=mesh,
        placements=tuple(found_inf_placements),
        tensor_meta=TensorMeta(
            shape=target_tensor.size(),
            stride=target_tensor.stride(),
            dtype=target_tensor.dtype,
        ),
    )
    found_inf_dtensor = dtensor.DTensor(
        local_tensor=target_tensor, spec=spec, requires_grad=False
    )
    # full_tensor() performs the reduction; copy the reduced flag back into
    # the caller-visible tensor.
    found_inf = found_inf_dtensor.full_tensor()
    target_tensor.copy_(found_inf)
108
+
109
+
110
+ class OpDispatcher:
111
+ """
112
+ Op dispatching class instance to handle args/kwargs pre-processing (un-wrapping), sharding
113
+ propagation, redistribute local args, local compute, and post-processing (re-wrapping). It
114
+ also handles any op specific logic if necessary.
115
+
116
+ NOTE: Given the runtime overhead of Tensor subclass (__torch_dispatch__), the OpDispatcher
117
+ is designed to minimize the CPU overhead by using the tricks of proper unflattening, faster
118
+ pytree if needed, and leveraging various caching mechanisms implemented in the sharding
119
+ propagation and redistribute modules. The CPU overhead is critical to eager mode performance,
120
+ one need to carefully measure the CPU overhead when making significant changes to the
121
+ OpDispatcher and ShardingPropagator.
122
+ """
123
+
124
    def __init__(self) -> None:
        """Set up the sharding propagator, special-op tables, and dispatch flags."""
        self.sharding_propagator = ShardingPropagator()
        # Random ops are run inside a distributed RNG tracker context (see
        # dispatch()) so shards on different ranks draw coordinated values.
        self._random_ops = {
            aten.native_dropout.default,
            aten.normal_.default,
            aten.rand_like.default,
            aten.randn_like.default,
            aten.randint_like.default,
            aten.randint_like.low_dtype,
            aten.randint_like.low_dtype_out,
            aten.uniform_.default,
            aten.bernoulli.default,
            aten.bernoulli_.float,
        }
        # Ops with bespoke handlers that bypass sharding propagation entirely.
        self._custom_op_handlers = {
            aten.linear.default: decompose_handler,
            aten.is_same_size.default: is_same_size_handler,
            aten.convolution.default: convolution_handler,
            aten.convolution_backward.default: convolution_backward_handler,
            aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler,
        }

        # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor)
        # as implicitly replicated or we throw error to user.
        # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave
        # it as False by default.
        self._allow_implicit_replication = False
151
+
152
    def dispatch(
        self,
        op_call: torch._ops.OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> object:
        """
        Main dispatching logic.

        Steps: (1) route custom-handled ops; (2) unwrap args into an OpInfo;
        (3) propagate sharding; (4) redistribute local args if required and
        run the local computation (under an RNG-tracker context for random
        ops); (5) re-wrap local results per the decided output sharding
        (with special handling for inplace, out-variant, and ``aten.equal``).
        """
        # operators that does not need to go through sharding propagation
        if op_call in self._custom_op_handlers:
            return self._custom_op_handlers[op_call](op_call, args, kwargs)  # type: ignore[operator]

        # extract local tensor and sharding infos to a OpInfo
        op_info = self.unwrap_to_op_info(op_call, args, kwargs)
        logger.debug("Dispatching op_call: %s", op_info.schema)

        self.sharding_propagator.propagate(op_info)
        output_sharding = op_info.output_sharding
        logger.debug("output_sharding for %s: %s", op_call, output_sharding)
        assert output_sharding is not None, "output sharding should not be None"

        mesh = op_info.mesh
        if mesh.get_coordinate() is not None:
            # computation that happens in the current rank of the mesh, normal case
            if output_sharding.needs_redistribute:
                # If sharding propagation decision needs redistribute, perform redistribute
                # on args first, which could potentially modify args (i.e. allgather certain arg)
                assert output_sharding.redistribute_schema is not None
                self.redistribute_local_args(
                    op_info, output_sharding.redistribute_schema
                )

            local_tensor_args = (
                pytree.tree_unflatten(
                    cast(List[object], op_info.local_args), op_info.args_tree_spec
                )
                if op_info.args_tree_spec
                else op_info.local_args
            )

            # run local op computation with potentially modified args/kwargs
            local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
            if op_call in self._random_ops:
                if not random._rng_tracker and is_rng_supported_mesh(mesh):
                    # Default to `OffsetBasedRNGTracker` if the parallelism API
                    # did not already construct one
                    random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type)

                first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast(
                    torch.Tensor, local_tensor_args[0]
                )
                rng_context = (
                    random._rng_tracker._distribute_region(first_arg._spec)
                    if random._rng_tracker and not first_local_arg.is_meta
                    else contextlib.nullcontext()
                )
                # For DTensor random operator, run it within a RNGTracker context to
                # ensure the random number generator is properly distributed.
                with rng_context:
                    local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
            else:
                # normal case, run local sharded op computation
                local_results = op_call(*local_tensor_args, **op_info.local_kwargs)

        else:
            # For a non-participating device (happens on rank that does not belong to
            # the device mesh), we do:
            # 1. if the return type is scalar, set the local result to None.
            # 2. if the return type is Tensor or List[Tensor], return empty
            # tensor(s) with correct dtype.
            spec = output_sharding.output_spec
            ret_list = op_info.schema.op._schema.returns

            if spec is None:
                # For a scalar return type, the non-participating device has None
                # as its local result
                local_results = None
            else:

                def default_tensor(spec: DTensorSpec) -> torch.Tensor:
                    # Build a placeholder local tensor that matches the output
                    # spec's dtype (empty for non-scalars, zero for scalars).
                    if spec.tensor_meta is not None:
                        shape = spec.tensor_meta.shape
                        dtype = spec.tensor_meta.dtype
                        if len(shape) == 0:
                            # scalar tensor
                            return torch.zeros((), dtype=dtype)
                        else:
                            # non-scalar tensor
                            return torch.tensor([], dtype=dtype)
                    else:
                        raise RuntimeError(f"{spec} has no tensor metadata.")

                if isinstance(spec, DTensorSpec):
                    # return a Tensor value
                    local_results = default_tensor(spec)
                elif isinstance(spec, Sequence):
                    # return a List[Tensor] value
                    local_results = [
                        default_tensor(s) if s is not None else None for s in spec
                    ]
                    assert isinstance(local_results, List)
                    if None in local_results:
                        ret_type = str(ret_list[0].type)
                        raise NotImplementedError(
                            f"return type {ret_type} in DTensor op is not supported"
                        )

        if output_sharding.output_spec is None:
            if op_call == aten.equal.default:
                # For equal operator, The local results from all devices should be all-gathered
                # and a reduce op (AND) will be performed on the list of results to ensure SPMD
                # execution. We can extend this for more ops if necessary.
                obj_list = [None for _ in range(dist.get_world_size())]
                dist.all_gather_object(obj_list, local_results)  # type: ignore[possibly-undefined]
                obj_list = list(filter(lambda x: x is not None, obj_list))
                # perform reduce on the collection with AND op
                local_results = functools.reduce(operator.and_, obj_list, True)

        if _is_inplace_op(op_call):
            # inplace op should return self instead of re-wrapping
            if output_sharding.output_spec is not None:
                return args[0]
            else:
                return None
        elif _is_out_variant_op(op_call):
            # out variant could possibly have multiple out args (i.e. lu_unpack.out)
            output_specs = (
                (output_sharding.output_spec,)
                if not isinstance(output_sharding.output_spec, tuple)
                else output_sharding.output_spec
            )
            out_dts = []
            spec_idx = 0
            for argument in op_call._schema.arguments:
                if argument.is_out:
                    # Re-stamp each caller-supplied out DTensor with its
                    # propagated spec; the spec order follows the out args.
                    out_dt = cast(dtensor.DTensor, kwargs[argument.name])
                    out_dt._spec = cast(DTensorSpec, output_specs[spec_idx])
                    out_dts.append(out_dt)
                    spec_idx += 1

            assert len(out_dts) >= 1, "out variant should have at least one out arg"
            return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
        else:
            return self.wrap(local_results, output_sharding.output_spec)  # type: ignore[possibly-undefined]
297
+
298
    @staticmethod
    def redistribute_local_args(
        op_info: OpInfo,
        suggested_input_schema: OpSchema,
    ) -> None:
        """
        Redistribute ``op_info.local_args`` in place so each DTensor arg
        matches the placement the sharding propagator suggested.

        Relies on positional alignment: ``op_info.flat_args_schema``,
        ``op_info.local_args``, and the (flattened) suggested schema must all
        line up index-for-index.
        """
        # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
        if op_info.args_tree_spec is not None:
            # Flatten the suggested schema the same way the args were flattened.
            flatten_args_schema_to_reshard = tuple(
                pytree.tree_leaves(suggested_input_schema.args_schema)
            )
        else:
            flatten_args_schema_to_reshard = suggested_input_schema.args_schema

        new_local_args: List[object] = []
        for i, arg_spec in enumerate(op_info.flat_args_schema):
            reshard_arg_spec = flatten_args_schema_to_reshard[i]
            if isinstance(arg_spec, DTensorSpec):
                local_tensor = cast(torch.Tensor, op_info.local_args[i])
                if arg_spec != reshard_arg_spec:
                    # Placement differs: perform the actual collective(s).
                    resharded_local_tensor = redistribute_local_tensor(
                        local_tensor, arg_spec, reshard_arg_spec
                    )
                    new_local_args.append(resharded_local_tensor)
                else:
                    new_local_args.append(local_tensor)
            else:
                # Non-tensor arg: take the (possibly updated) schema value.
                new_local_args.append(reshard_arg_spec)

        op_info.local_args = tuple(new_local_args)
327
+
328
    def unwrap_to_op_info(
        self,
        op_call: torch._ops.OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> OpInfo:
        """
        Unwrap the DTensor args/kwargs of ``op_call`` into an ``OpInfo`` that
        carries both the sharding schema (DTensorSpecs plus non-tensor values)
        and the plain local tensors used to run the local op.

        Raises (via assert) if no DeviceMesh can be found among the arguments.
        """
        # get runtime schema info to determine whether to use pytree to flatten inputs
        runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
            op_call, None
        )

        if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
            # flatten args/kwargs when op says necessary
            tree_args, args_spec = pytree.tree_flatten(args)
            args_list: Sequence[object] = tree_args
        else:
            args_list, args_spec = args, None

        args_schema: List[object] = []
        kwargs_schema: Dict[str, object] = {}
        local_args: List[object] = []
        local_kwargs: Dict[str, object] = {}
        # mesh is inferred from the first DTensor argument encountered
        mesh: Optional[DeviceMesh] = None

        for arg in args_list:
            if isinstance(arg, dtensor.DTensor):
                local_args.append(arg._local_tensor)
                if mesh is not None and mesh != arg.device_mesh:
                    # cross-mesh DTensor arg: only a narrow foreach case is supported
                    # TODO: try replicate dtensor spec in missing dimension would work
                    # for most cases for foreach case except when the first DTensor in
                    # the list is one that also need to be replicated. We need to revisit
                    # how we want to handle this corner case. For now, this case would hit
                    # the cross mesh error even if implicit replication is turned on.
                    spec = self._try_replicate_dtensor_spec_in_missing_dim(
                        op_call, arg, mesh
                    )
                    args_schema.append(spec)
                else:
                    mesh = arg.device_mesh
                    args_schema.append(arg._spec)
            elif isinstance(arg, torch.Tensor):
                # plain tensor mixed in with DTensors: only allowed when it can be
                # treated as replicated (scalar tensor, or implicit replication on)
                mesh = mesh or try_find_mesh_from_args(op_call, args_list)
                args_schema.append(
                    self._try_replicate_spec_for_scalar_tensor(op_call, arg, mesh)
                )
                local_args.append(arg)
            else:
                # non-tensor argument: passed through unchanged to both schemas
                args_schema.append(arg)
                local_args.append(arg)

        # same unwrapping for kwargs (note: cross-mesh handling mirrors args)
        for k, v in kwargs.items():
            if isinstance(v, dtensor.DTensor):
                local_kwargs[k] = v._local_tensor
                if mesh is not None and mesh != v.device_mesh:
                    spec = self._try_replicate_dtensor_spec_in_missing_dim(
                        op_call, v, mesh
                    )
                    kwargs_schema[k] = spec
                else:
                    mesh = v.device_mesh
                    kwargs_schema[k] = v._spec
            elif isinstance(v, torch.Tensor):
                mesh = mesh or try_find_mesh_from_args(op_call, args_list)
                kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
                    op_call, v, mesh
                )
                local_kwargs[k] = v
            else:
                kwargs_schema[k] = v
                local_kwargs[k] = v

        assert mesh is not None, f"found no DeviceMesh from dtensor args for {op_call}!"
        op_info = OpInfo(
            mesh,
            OpSchema(
                op_call,
                pytree.tree_unflatten(args_schema, args_spec)
                if args_spec
                else tuple(args_schema),
                kwargs_schema,
                schema_info=runtime_schema_info,
            ),
            args_schema,
            tuple(local_args),
            local_kwargs,
            args_spec,
        )
        return op_info
416
+
417
+ @staticmethod
418
+ def wrap(res: object, spec: OutputSpecType) -> object:
419
+ if isinstance(res, torch.Tensor):
420
+ if spec is not None:
421
+ assert isinstance(
422
+ spec, DTensorSpec
423
+ ), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
424
+ return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
425
+ else:
426
+ # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor
427
+ assert res.ndim == 0, "output tensor should be scalar!"
428
+ return res
429
+ elif isinstance(res, (list, tuple)):
430
+ assert spec is not None and isinstance(
431
+ spec, (list, tuple)
432
+ ), f"output spec does not match with output! Expected list/tuple, got {spec}."
433
+ res_list = []
434
+ for e, s in zip(res, spec):
435
+ res_list.append(OpDispatcher.wrap(e, s))
436
+
437
+ return tuple(res_list) if isinstance(res, tuple) else res_list
438
+ else:
439
+ # if the res contains only non tensor values (i.e. int/float/none), we simply return it
440
+ # without rewrapping to DTensor.
441
+ return res
442
+
443
+ def _try_replicate_spec_for_scalar_tensor(
444
+ self,
445
+ op_call: torch._ops.OpOverload,
446
+ tensor_arg: torch.Tensor,
447
+ mesh: "DeviceMesh",
448
+ ) -> DTensorSpec:
449
+ # util function to produce a replicate spec for a scalar tensor arg/kwarg
450
+ if tensor_arg.numel() == 1 and tensor_arg.ndim == 1:
451
+ warnings.warn(
452
+ "Found a non-scalar tensor with numel=1 and ndim!=0, "
453
+ "we are implicitly creating a replicated DTensor for it. "
454
+ "However, please consider changing it to a scalar tensor "
455
+ "or explicitly create a DTensor under distributed enviroment."
456
+ )
457
+
458
+ if tensor_arg.numel() == 1 or self._allow_implicit_replication:
459
+ # scalar tensor can be safely treated as replicated
460
+ replication_spec = DTensorSpec(
461
+ mesh,
462
+ (Replicate(),) * mesh.ndim,
463
+ tensor_meta=TensorMeta(
464
+ shape=tensor_arg.shape,
465
+ stride=tensor_arg.stride(),
466
+ dtype=tensor_arg.dtype,
467
+ ),
468
+ )
469
+ else:
470
+ raise RuntimeError(
471
+ f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
472
+ " torch.Tensor to DTensor before calling distributed operators!"
473
+ )
474
+ return replication_spec
475
+
476
    def _try_replicate_dtensor_spec_in_missing_dim(
        self,
        op_call: torch._ops.OpOverload,
        dtensor_arg: "dtensor.DTensor",
        mesh: "DeviceMesh",
    ) -> DTensorSpec:
        # util function to produce a new spec for a DTensor arg/kwarg
        # that puts Replicate() placement in the missing dimension for foreach ops
        from torch.distributed.device_mesh import _mesh_resources

        cur_mesh = dtensor_arg.device_mesh
        root_mesh = _mesh_resources.get_root_mesh(cur_mesh)
        # Only the narrow case of a foreach op whose arg lives on a sub-mesh of
        # the target mesh is supported, and only with implicit replication on;
        # everything else is rejected as a cross-mesh operation below.
        if (
            self._allow_implicit_replication
            and "foreach" in op_call.__name__
            and root_mesh == mesh
        ):
            # replicate on every root-mesh dim, except the dim the sub-mesh maps
            # to, which keeps the arg's original (single) placement.
            # NOTE(review): assumes dtensor_arg lives on a 1-D sub-mesh (only
            # placements[0] is used) — confirm for multi-dim sub-meshes.
            placements = [Replicate() for _ in range(root_mesh.ndim)]
            cur_mesh_root_idx = _mesh_resources.get_root_mesh_dim(cur_mesh)
            placements[cur_mesh_root_idx] = dtensor_arg.placements[0]  # type: ignore[call-overload]
            replicate_spec = DTensorSpec(
                root_mesh,
                tuple(placements),
                tensor_meta=TensorMeta(
                    shape=dtensor_arg.shape,
                    stride=dtensor_arg.stride(),
                    dtype=dtensor_arg.dtype,
                ),
            )
        else:
            raise NotImplementedError(
                f"{op_call}: DTensor does not support cross-mesh operation yet! "
                f"Got meshes: {mesh} {cur_mesh}"
            )
        return replicate_spec
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_dtensor_spec.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, cast, List, NamedTuple, Optional, Tuple
3
+
4
+ import torch
5
+ from torch.distributed.device_mesh import DeviceMesh
6
+ from torch.distributed.tensor.placement_types import (
7
+ Partial,
8
+ Placement,
9
+ Replicate,
10
+ Shard,
11
+ )
12
+
13
+
14
class TensorMeta(NamedTuple):
    """
    Minimal tensor metadata used during sharding propagation.

    Intentionally kept simple (shape/stride/dtype only) — it exists purely so
    sharding propagation can reason about tensors without holding real data.
    """

    # global (logical) shape of the tensor
    shape: torch.Size
    # global strides matching ``shape``
    stride: Tuple[int, ...]
    dtype: torch.dtype
21
+
22
+
23
+ # used internally to propagate the placements
24
+ @dataclass
25
+ class DTensorSpec:
26
+ mesh: DeviceMesh
27
+ placements: Tuple[Placement, ...]
28
+
29
+ # tensor meta will only be set during sharding propagation
30
+ tensor_meta: Optional[TensorMeta] = None
31
+
32
+ def __post_init__(self) -> None:
33
+ if not isinstance(self.placements, tuple):
34
+ self.placements = tuple(self.placements)
35
+ self._hash: Optional[int] = None
36
+
37
+ def __setattr__(self, attr: str, value: Any) -> None:
38
+ super().__setattr__(attr, value)
39
+ # Make sure to recompute the hash in case any of the hashed attributes
40
+ # change (though we do not expect `mesh` or `placements` to change)
41
+ if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"):
42
+ self._hash = None
43
+
44
+ def _hash_impl(self) -> int:
45
+ # hashing and equality check for DTensorSpec are used to cache the sharding
46
+ # propagation results. We only need to consider the mesh, placements, shape
47
+ # dtype and stride.
48
+ # Caveat: we need to keep this in mind and sync hash and eq if we add more
49
+ # fields to them.
50
+ if self.tensor_meta is not None:
51
+ return hash(
52
+ (
53
+ self.mesh,
54
+ self.placements,
55
+ self.tensor_meta.shape,
56
+ self.tensor_meta.stride,
57
+ self.tensor_meta.dtype,
58
+ )
59
+ )
60
+ return hash((self.mesh, self.placements))
61
+
62
+ def __hash__(self) -> int:
63
+ # We lazily cache the spec to avoid recomputing the hash upon each
64
+ # use, where we make sure to update the hash when the `tensor_meta`
65
+ # changes by overriding `__setattr__`. This must be lazy so that Dynamo
66
+ # does not try to hash non-singleton `SymInt`s for the stride.
67
+ if self._hash is None:
68
+ self._hash = self._hash_impl()
69
+ return self._hash
70
+
71
+ def __eq__(self, __o: object) -> bool:
72
+ if not (
73
+ isinstance(__o, DTensorSpec)
74
+ and self.mesh == __o.mesh
75
+ and self.placements == __o.placements
76
+ ):
77
+ return False
78
+ if self.tensor_meta is None or __o.tensor_meta is None:
79
+ return self.tensor_meta == __o.tensor_meta
80
+
81
+ return (
82
+ self.tensor_meta.shape == __o.tensor_meta.shape # type: ignore[union-attr]
83
+ and self.tensor_meta.stride == __o.tensor_meta.stride # type: ignore[union-attr]
84
+ and self.tensor_meta.dtype == __o.tensor_meta.dtype # type: ignore[union-attr]
85
+ )
86
+
87
+ def __str__(self) -> str:
88
+ """
89
+ human readable representation of the DTensorSpec
90
+ """
91
+ if len(self.placements) == 1:
92
+ placement_str = str(self.placements[0])
93
+ else:
94
+ placement_str = str(self.placements)
95
+
96
+ if self.tensor_meta is not None:
97
+ tensor_shape = str(tuple(self.tensor_meta.shape))
98
+ else:
99
+ tensor_shape = "unknown shape"
100
+
101
+ return f"Spec({placement_str} on {tensor_shape})"
102
+
103
+ @property
104
+ def shape(self) -> torch.Size:
105
+ if self.tensor_meta is None:
106
+ raise ValueError("tensor_meta is not set")
107
+ return self.tensor_meta.shape
108
+
109
+ @property
110
+ def stride(self) -> Tuple[int, ...]:
111
+ if self.tensor_meta is None:
112
+ raise ValueError("tensor_meta is not set")
113
+ return self.tensor_meta.stride
114
+
115
+ @property
116
+ def ndim(self) -> int:
117
+ if self.tensor_meta is None:
118
+ raise ValueError("tensor_meta is not set")
119
+ return len(self.tensor_meta.shape)
120
+
121
+ @property
122
+ def num_shards(self) -> int:
123
+ num_shards = 1
124
+ for i, placement in enumerate(self.placements):
125
+ if placement.is_shard():
126
+ num_shards *= self.mesh.size(i)
127
+ return num_shards
128
+
129
+ @property
130
+ def device_mesh(self) -> DeviceMesh:
131
+ # simple aliasing for the mesh field, make some
132
+ # checks that mixes DTensor/DTensorSpec easier
133
+ return self.mesh
134
+
135
+ @property
136
+ def dim_map(self) -> List[int]:
137
+ """
138
+ dim_map is a property we derive from `placements` of
139
+ the distributed tensor. It simply return a list of ints
140
+ where dim_map[i] denotes the sharding mapping to the mesh
141
+ dimension, and len(dim_map) == dist_tensor.ndim
142
+ dim_map[i] = -1: means tensor dim i replicate on mesh
143
+ dim_map[i] = j: means tensor dim i shard on mesh dim j
144
+
145
+ For example, we have a dist tensor that have the shape of
146
+ [18, 20, 30], and device_mesh([0, 1, 2, 3]), placements:
147
+ [Shard(1)], the dim_map of this placement would be:
148
+ [-1, 0, -1]. This representation is pretty helpful during
149
+ sharding propagation where we could know exactly each
150
+ tensor dimension is sharded or not.
151
+
152
+ Note that if placements contains `_Partial`, we have to
153
+ explicitly deal with it, so that when we create a DTensorSpec
154
+ with dim_map, we could properly record the pending sums.
155
+ """
156
+ # dims mapping of dist tensor sharding
157
+ # return size of tensor ndim, -1 represent replicate
158
+ # and int >=0 represent shard on that device mesh dim
159
+ r = [-1] * self.ndim
160
+ for i, placement in enumerate(self.placements):
161
+ if placement.is_shard():
162
+ shard_dim = cast(Shard, placement).dim
163
+ if r[shard_dim] > -1:
164
+ raise ValueError(
165
+ f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]},"
166
+ " DTensor operator implementation does not support things like hybrid"
167
+ " sharding strategies yet (i.e. [Shard(0), Shard(0)])"
168
+ )
169
+ r[shard_dim] = i
170
+ return r
171
+
172
+ @property
173
+ def num_shards_map(self) -> List[int]:
174
+ """
175
+ dim_map is a property we derive from `placements` of
176
+ the distributed tensor. Unlike `dim_map`, `num_shards_map`
177
+ denotes how many shards each tensor dim has. Like `dim_map`:
178
+ len(num_shards_map) == dist_tensor.ndim
179
+ num_shards_map[i] = 1: means tensor dim i is not sharded
180
+ num_shards_map[i] = j: means tensor dim i has j shards in total
181
+
182
+ For example, we have a dist tensor of shape [18, 20, 30],
183
+ a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements
184
+ ([Shard(1), Shard(0)]), the num_shards_map of this distributed tensor
185
+ would be: [4, 2, 1].
186
+ """
187
+ r = [1] * self.ndim
188
+ for i, placement in enumerate(self.placements):
189
+ if placement.is_shard():
190
+ shard_dim = cast(Shard, placement).dim
191
+ r[shard_dim] *= self.mesh.size(i)
192
+
193
+ return r
194
+
195
+ @property
196
+ def sums(self) -> List[int]:
197
+ """
198
+ sums is a property we derive from `placements` of the
199
+ distributed tensor. It simply return a list of ints where
200
+ sums[i] denotes the pending sum (partial) on mesh dim i
201
+ """
202
+ return [
203
+ idx
204
+ for idx, placement in enumerate(self.placements)
205
+ if placement.is_partial()
206
+ ]
207
+
208
+ @classmethod
209
+ def from_dim_map(
210
+ cls,
211
+ mesh: DeviceMesh,
212
+ dim_map: List[int],
213
+ sums: List[int],
214
+ tensor_meta: Optional[TensorMeta] = None,
215
+ ) -> "DTensorSpec":
216
+ """
217
+ Construct a DTensorSpec from dim_map list and pending sum.
218
+
219
+ Args:
220
+ mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec
221
+ dim_map (List[int]): a list of integer that represents sharding on each
222
+ tensor dimension, see `dim_map` property doc for details
223
+ sums (List[int]): a list of integer that represents the dist tensor have
224
+ pending sum on which device mesh dimension.
225
+ tensor meta (TensorMeta): DTensor metadata
226
+
227
+ Return:
228
+ a class:`DTensorSpec` object
229
+ """
230
+ # by default replicate on device mesh dims
231
+ placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)]
232
+
233
+ # find all mesh dims that need pending reductions
234
+ for s in sums:
235
+ placements[s] = Partial()
236
+
237
+ for i, m in enumerate(dim_map):
238
+ if m >= 0:
239
+ placement = placements[m]
240
+ if placement.is_shard():
241
+ placement = cast(Shard, placement)
242
+ raise RuntimeError(
243
+ f"DeviceMesh dimension cann't be mapped to two dimension of the same tensor: {i} and {placement.dim}"
244
+ )
245
+ elif placement.is_partial():
246
+ raise RuntimeError(
247
+ f"DeviceMesh dimension {m} cannot be both shard and partial!"
248
+ )
249
+ placements[m] = Shard(i)
250
+
251
+ return cls(mesh, tuple(placements), tensor_meta=tensor_meta)
252
+
253
+ def is_replicated(self) -> bool:
254
+ """
255
+ return True if the current DTensorSpec replicates on all mesh dims (devices)
256
+ """
257
+ return all(placement.is_replicate() for placement in self.placements)
258
+
259
+ def is_sharded(self) -> bool:
260
+ """
261
+ return True if the current DTensorSpec is sharded on any mesh dims (devices)
262
+ """
263
+ return any(placement.is_shard() for placement in self.placements)
264
+
265
+ def shallow_copy_with_tensor_meta(
266
+ self, tensor_meta: Optional[TensorMeta]
267
+ ) -> "DTensorSpec":
268
+ """
269
+ Shallow copy the DTensorSpec with a new tensor_meta.
270
+ """
271
+ assert tensor_meta is not None, "shallow copy with no tensor_meta!"
272
+ return DTensorSpec(
273
+ self.mesh,
274
+ self.placements,
275
+ tensor_meta=tensor_meta,
276
+ )
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_op_schema.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from dataclasses import dataclass
3
+ from functools import cached_property
4
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
5
+
6
+ import torch
7
+ from torch._ops import OpOverload
8
+ from torch.distributed.device_mesh import DeviceMesh
9
+ from torch.distributed.tensor._dtensor_spec import DTensorSpec
10
+ from torch.distributed.tensor.placement_types import Placement
11
+
12
+
13
+ try:
14
+ from torch.utils._cxx_pytree import tree_leaves, tree_map_only, TreeSpec
15
+ except ImportError:
16
+ from torch.utils._pytree import ( # type: ignore[no-redef, assignment]
17
+ tree_leaves,
18
+ tree_map_only,
19
+ TreeSpec,
20
+ )
21
+
22
+
23
+ # Common type aliases
24
+ ArgsType = Tuple[object, ...]
25
+ KwargsType = Dict[str, object]
26
+
27
+ PlacementList = List[Optional[Placement]]
28
+
29
+ # ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type sould
30
+ # be the same set of possibilities.
31
+ OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]
32
+
33
+
34
def _rebuild_tensor_from_dtensor_meta(arg) -> object:
    """
    Materialize an empty strided tensor from a DTensorSpec's tensor metadata.

    This is used to propagate tensor metadata, must be under fake mode
    """
    meta = arg.tensor_meta
    assert meta is not None, "DTensorSpec does not contain tensor_meta."
    return torch.empty_strided(meta.shape, meta.stride, dtype=meta.dtype)
44
+
45
+
46
def _is_inplace_op(op: OpOverload):
    # Heuristic: ATen in-place variants carry a trailing underscore in their
    # schema name (e.g. ``aten::add_``). Not guaranteed to be entirely
    # correct, but good enough for now.
    return op._schema.name.endswith("_")
51
+
52
+
53
def _is_out_variant_op(op: OpOverload):
    # Heuristic: out variants mention "out" in their overload name. Simple
    # schema analysis — might not be entirely correct, but good enough for now.
    return "out" in op._schema.overload_name
58
+
59
+
60
def _pretty_print_spec(spec: object) -> str:
    """Render a (possibly nested) output spec as a compact string."""
    if spec is None:
        return "None"
    if isinstance(spec, DTensorSpec):
        # concatenate one short placement string per mesh dim
        return "".join(str(p) for p in spec.placements)
    if isinstance(spec, Sequence):
        inner = ", ".join(_pretty_print_spec(s) for s in spec)
        return "(" + inner + ")"
    raise RuntimeError(f"Unknown spec type to print: spec={spec}")
69
+
70
+
71
+ @dataclass
72
+ class PlacementStrategy:
73
+ """
74
+ A placement strategy describes acceptable sharding placements of the output
75
+ and the tensor arguments of an operation.
76
+
77
+ note: when the op return value is a single DTensor object, output_specs is
78
+ DTensorSpec; when the return value is a tuple of Optional[DTensor],
79
+ output_specs is a tuple of Optional[DTensorSpec].
80
+ """
81
+
82
+ output_specs: Union[DTensorSpec, Tuple[Optional[DTensorSpec], ...]]
83
+ input_specs: Optional[Sequence[DTensorSpec]] = None
84
+
85
+ # redistribute costs for this op placement strategy
86
+ # we need a nested list to record the cost for each
87
+ # operand of this operator, and for each operand of
88
+ # this operator it might have multiple placement strategies
89
+ redistribute_cost: Optional[List[List[float]]] = None
90
+
91
+ @cached_property
92
+ def output_spec(self) -> DTensorSpec:
93
+ """
94
+ This function requires that the strategy have exactly one DTensorSpec as the
95
+ output spec. If the output_specs is a tuple, we throw an exception.
96
+ """
97
+ if isinstance(self.output_specs, DTensorSpec):
98
+ return self.output_specs
99
+ else:
100
+ raise ValueError(
101
+ f"function output_spec expects a single DTensorSpec but got: {self.output_specs}"
102
+ )
103
+
104
+ def input_spec(self, index: int = 0) -> DTensorSpec:
105
+ assert self.input_specs is not None, "input_specs of PlacementStrategy is None!"
106
+ assert len(self.input_specs) > index, (
107
+ f"Invalid index {index} for input_specs of length "
108
+ f"{len(self.input_specs)}: {self.input_specs}"
109
+ )
110
+ return self.input_specs[index]
111
+
112
+ def __str__(self) -> str:
113
+ if self.input_specs is not None:
114
+ input_specs_str = f"{_pretty_print_spec(self.input_specs)} -> "
115
+ else:
116
+ input_specs_str = ""
117
+ output_spec_str = _pretty_print_spec(self.output_specs)
118
+ return f"{input_specs_str}{output_spec_str}"
119
+
120
+
121
class StrategyType:
    """
    Base class type for op strategy. We have two StrategyType subclasses:
    OpStrategy and TupleStrategy.
    """
126
+
127
+
128
class OpStrategy(StrategyType):
    """
    OpStrategy that consists of a list of placement strategies associated with the op
    """

    def __init__(self, strategies: List[PlacementStrategy]) -> None:
        super().__init__()
        self.strategies: List[PlacementStrategy] = strategies

    def __str__(self) -> str:
        strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
        mesh_shape = self.mesh_shape
        return f"[{strategy_list_str}] @ mesh: {mesh_shape}"

    def max_num_shards(self) -> int:
        """
        Returns the max number of shards across all placement strategies
        """
        return max(strategy.output_spec.num_shards for strategy in self.strategies)

    @property
    def mesh_shape(self):
        # Uses the first strategy's output spec as representative — assumes all
        # strategies of one op live on the same mesh.
        output_spec = self.strategies[0].output_specs
        if isinstance(output_spec, DTensorSpec):
            return output_spec.mesh.shape
        else:
            assert isinstance(
                output_spec, tuple
            ), "found no DTensorSpec in the OpStrategy!"
            assert output_spec[0] is not None
            return output_spec[0].mesh.shape

    @property
    def ndim(self):
        # rank of the (single) output; raises via output_spec if the first
        # strategy holds a tuple of specs
        return self.strategies[0].output_spec.ndim

    @property
    def shape(self):
        # global shape of the (single) output of the first strategy
        return self.strategies[0].output_spec.shape
167
+
168
+
169
+ class TupleStrategy(StrategyType):
170
+ """
171
+ TupleStrategy represents the output strategy of this op is a tuple
172
+ of strategy, i.e. If the output of this op is a tuple of tensors or list of tensors
173
+ with possibly different placement strategies, we should return a TupleStrategy that
174
+ contains a tuple of OpStrategy, where each child represents the sharding strategy
175
+ of "each element" of the tuple/list of tensors the op returns.
176
+
177
+ NOTE: if the output of the op is a List[Tensor] and they share the same placement
178
+ strategy, then we should return a single OpStrategy instead of a TupleStrategy
179
+ """
180
+
181
+ def __init__(self, childs: Sequence[StrategyType]) -> None:
182
+ super().__init__()
183
+ self.childs: Sequence[StrategyType] = childs
184
+
185
+ def __str__(self) -> str:
186
+ child_strategies_str = ", ".join(
187
+ [f"{str(strat)}" for idx, strat in enumerate(self.childs)]
188
+ )
189
+ return f"TupleStrategy({child_strategies_str})"
190
+
191
+
192
@dataclass
class RuntimeSchemaInfo:
    """
    RuntimeSchemaInfo stores the operator schema related information for runtime (eager)
    execution. This is mainly used for two ways: 1. to generate hash for args to determine
    whether to re-run sharding prop or not 2. to determine if we need pytree
    """

    # This static_argnum records static arg "starting index" for ops that have non-tensor
    # args/kwargs which would affect sharding propagation results. All args starting from
    # this index would be hashed to our sharding cache.
    # Note that only a few ops need this information, e.g. view, transpose, var.dim, etc.
    # The default of 100 effectively means "no static positional args" for most ops.
    static_argnum: int = 100
    # This static_kwargkey records static kwarg names which would affect sharding prop
    static_kwargkey: Optional[List[str]] = None
    # each op can decide if it wants to use pytree flatten/unflatten during operator
    # eager execution, by default we don't need to do flatten/unflatten, only if the
    # op indicate it needs to, this is to accelerate eager performance.
    needs_pytree: bool = False
211
+
212
+
213
@dataclass
class OpSchema:
    """
    OpSchema is a data class that describes an operator input schemas, it includes
    DTensorSpecs (instead of DTensor) and non-tensor args/kwargs (positional order
    preserved). It is mainly used by the DTensor's dispatching logic to perform various
    actions (i.e. sharding propagation, caching sharding decisions, redistribute, etc.)

    NOTE: this should be used as a read only data class
    TODO: make this a frozen dataclass

    Args:
        op: the operator overload we are intercepting
        args_schema: contains args except that the DTensor args have been replaced
            with its DTensorSpec or OpStrategy
        kwargs_schema: contains kwargs except that the DTensor kwargs have been replaced
            with its DTensorSpec or OpStrategy
    """

    op: OpOverload
    args_schema: ArgsType
    kwargs_schema: KwargsType

    schema_info: Optional[RuntimeSchemaInfo] = None

    @property
    def args_spec(self) -> Tuple[DTensorSpec, ...]:
        """
        args_spec: Tuple[DTensorSpec, ...]: contains a clean list of args spec list
            with NO non-DTensor positional arguments (i.e. int/float/tuple, etc)
            mainly used by sharding propagation to propagate the output spec
        """
        args = (
            tree_leaves(self.args_schema)
            if self.schema_info is not None and self.schema_info.needs_pytree
            else self.args_schema
        )
        return tuple(item for item in args if isinstance(item, DTensorSpec))

    @property
    def args_strategy(self) -> Tuple[OpStrategy, ...]:
        # filter out non-relevant values from args schema to get a clean OpStrategy list
        # separate with args_spec for the ease of type annotation
        # TODO: see if we should merge this with args_spec
        args = (
            tree_leaves(self.args_schema)
            if self.schema_info is not None and self.schema_info.needs_pytree
            else self.args_schema
        )
        return tuple(item for item in args if isinstance(item, OpStrategy))

    def __repr__(self) -> str:
        args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema])
        return (
            f"OpSchema(op={self.op},"
            f" args_schema=({args_schema}),"
            f" kwargs_schema={self.kwargs_schema})"
        )

    def __str__(self) -> str:
        # compact rendering; also surfaces the mesh shape of the last
        # spec/strategy argument seen
        args_schema: List[str] = []
        mesh_shape = None
        for arg in self.args_schema:
            if isinstance(arg, DTensorSpec):
                args_schema.append(str(arg))
                mesh_shape = arg.mesh.shape
            elif isinstance(arg, OpStrategy):
                assert len(arg.strategies) == 1
                args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs))
                mesh_shape = arg.mesh_shape
            elif isinstance(arg, TupleStrategy):
                first_op_strtgy = arg.childs[0]
                assert isinstance(first_op_strtgy, OpStrategy)
                mesh_shape = first_op_strtgy.mesh_shape
                args_schema.append(str(arg))
            else:
                args_schema.append(str(arg))
        return f"Op(op={self.op}, args_schema={', '.join(args_schema)} @ mesh: {mesh_shape})"

    def __post_init__(self) -> None:
        # precompute whether any DTensorSpec arg has a symbolic shape, so
        # callers can avoid hashing SymInts
        has_symints = False
        for a in self.args_schema:
            if isinstance(a, DTensorSpec) and a.tensor_meta is not None:
                if any(isinstance(s, torch.SymInt) for s in a.tensor_meta.shape):
                    has_symints = True
                    break
        self.has_symints = has_symints

    def arg_type_tensor_or_tensor_list_like(self, arg_idx: int) -> bool:
        # True when the positional arg at ``arg_idx`` is a DTensorSpec or a
        # list of (optional) DTensorSpecs
        arg = self.args_schema[arg_idx]
        is_tensor = isinstance(arg, DTensorSpec)
        if is_tensor:
            return True

        if not isinstance(arg, list):
            return False

        return all(isinstance(e, DTensorSpec) or e is None for e in arg)

    def return_type_tuple_tensor_like(self) -> bool:
        # all dispatch ops could only return Tuple[Tensor] or have None/ints/floats
        # in the tuple, but the first element must be a Tensor, so this check is enough
        return_types = self.op._schema.returns
        return len(return_types) > 1 and isinstance(
            return_types[0].type, torch.TensorType
        )

    def return_type_tensor(self) -> bool:
        return_types = self.op._schema.returns
        # all dispatch ops only return Tensor or Tuple[Tensor] for tensor like
        # return types, so this check is enough for tensor like types
        return isinstance(return_types[0].type, torch.TensorType)

    def __hash__(self) -> int:
        # Only hash args and kwargs that op indicates to hash
        # (tensor-like args always; non-tensor args only past static_argnum;
        # kwargs only for the declared static kwarg keys)
        if not self.schema_info:
            static_argnum = len(self.args_schema)
            static_kwargkey = None
        else:
            static_argnum = self.schema_info.static_argnum
            static_kwargkey = self.schema_info.static_kwargkey

        args_to_hash = tuple(
            tuple(e) if isinstance(e, list) else e
            for i, e in enumerate(self.args_schema)
            if self.arg_type_tensor_or_tensor_list_like(i) or i >= static_argnum
        )
        if static_kwargkey is not None:
            kwargs_to_hash = tuple(
                self.kwargs_schema.get(k, None) for k in static_kwargkey
            )
            return hash((self.op, args_to_hash, kwargs_to_hash))
        else:
            return hash((self.op, args_to_hash))

    def __eq__(self, other: object) -> bool:
        # NOTE: must stay in sync with __hash__ — only tensor-like args,
        # static args, and static kwargs participate in equality
        # early return checks
        if not isinstance(other, OpSchema):
            return False

        if self.op != other.op:
            return False

        if len(self.args_schema) != len(other.args_schema):
            return False

        # compare each element and early return if any of them is different
        if not self.schema_info:
            static_argnum = len(self.args_schema)
            static_kwargkey = None
        else:
            static_argnum = self.schema_info.static_argnum
            static_kwargkey = self.schema_info.static_kwargkey

        for i, (self_arg, other_arg) in enumerate(
            zip(self.args_schema, other.args_schema)
        ):
            if isinstance(self_arg, DTensorSpec) and self_arg != other_arg:
                return False
            elif i >= static_argnum and self_arg != other_arg:
                return False

        # check kwarg equality when there's a static kwarg key
        if static_kwargkey:
            for key in static_kwargkey:
                if self.kwargs_schema.get(key, None) != other.kwargs_schema.get(
                    key, None
                ):
                    return False

        return True

    def gen_fake_args(self) -> ArgsType:
        """
        gen_fake_args: generate fake args for the operator, this is mainly used
            by sharding propagation rules to generate fake args for the operator
            to run the local tensor operator and get the output spec.
        """
        return tree_map_only(
            DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.args_schema
        )

    def gen_fake_kwargs(self) -> KwargsType:
        """
        gen_fake_kwargs: generate fake kwargs for the operator, this is mainly used
            by sharding propagation rules to generate fake kwargs for the operator
            to run the local tensor operator and get the output spec.
        """
        return tree_map_only(
            DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.kwargs_schema
        )

    def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
        # Rewrite this schema in place: keep origin_schema's non-tensor args
        # but substitute this schema's DTensorSpecs in positional order.
        suggestion_args_spec = self.args_spec
        new_arg_schema: List[object] = []
        idx_of_args_spec = 0
        if (
            origin_schema.schema_info is not None
            and origin_schema.schema_info.needs_pytree
        ):
            args_schema: Sequence[Any] = tree_leaves(origin_schema.args_schema)
        else:
            args_schema = origin_schema.args_schema
        for arg in args_schema:
            if isinstance(arg, DTensorSpec):
                new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
                idx_of_args_spec += 1
            else:
                new_arg_schema.append(arg)
        self.args_schema = tuple(new_arg_schema)
        self.kwargs_schema = origin_schema.kwargs_schema
424
+
425
+
426
@dataclass
class OutputSharding:
    """
    OutputSharding is a data class that is used by the sharding propagation,
    it could set the output_spec upon successful propagation. If needs_redistribute
    is set to True, a redistribute_schema would be returned together to indicate
    the input arguments needs to be redistributed before the op execution.

    NOTE: the redistribute_schema generated by sharding propagation should be
    exactly the same as the operator OpSchema, except the DTensorSpecs
    """

    # spec(s) for the op output; None when the op returns non-tensor values
    output_spec: OutputSpecType
    # input schema to redistribute to, only set when needs_redistribute is True
    redistribute_schema: Optional[OpSchema] = None
    needs_redistribute: bool = False
441
+
442
+
443
@dataclass
class OpInfo:
    """
    All Runtime Op execution info are packed here
    """

    # device mesh shared by the DTensor arguments of this op call
    mesh: DeviceMesh
    # sharding-level schema (DTensorSpecs + non-tensor values) of the call
    schema: OpSchema
    # flattened args schema (pytree leaves when the op needs pytree)
    flat_args_schema: List[object]
    # plain local tensors / values used to run the local op
    local_args: Sequence[object]
    local_kwargs: Dict[str, object]
    # pytree TreeSpec for unflattening args, set only when the op needs pytree
    args_tree_spec: Optional[TreeSpec] = None

    # the output sharding info
    output_sharding: Optional[OutputSharding] = None
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ from ._conv_ops import * # noqa: F403
3
+ from ._embedding_ops import * # noqa: F403
4
+ from ._experimental_ops import * # noqa: F403
5
+ from ._math_ops import * # noqa: F403
6
+ from ._matrix_ops import * # noqa: F403
7
+ from ._pointwise_ops import * # noqa: F403
8
+ from ._random_ops import * # noqa: F403
9
+ from ._tensor_ops import * # noqa: F403
10
+ from ._view_ops import * # noqa: F403
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (508 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_common_rules.cpython-311.pyc ADDED
Binary file (11.1 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_conv_ops.cpython-311.pyc ADDED
Binary file (4.33 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_einsum_strategy.cpython-311.pyc ADDED
Binary file (7.3 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_embedding_ops.cpython-311.pyc ADDED
Binary file (11.1 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_experimental_ops.cpython-311.pyc ADDED
Binary file (1.58 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_math_ops.cpython-311.pyc ADDED
Binary file (42.8 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_matrix_ops.cpython-311.pyc ADDED
Binary file (16.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_pointwise_ops.cpython-311.pyc ADDED
Binary file (30.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_random_ops.cpython-311.pyc ADDED
Binary file (1.84 kB). View file