Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_fsspec_filesystem.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/format_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/logging_handlers.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/planner_helpers.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_loader.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_saver.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/_IR.py +1243 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__init__.py +28 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_IR.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_backward.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_debug.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_unflatten.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/microbatch.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/schedules.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/stage.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/_backward.py +370 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/_debug.py +21 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/_unflatten.py +27 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/_utils.py +99 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/microbatch.py +469 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/schedules.py +2162 -0
- .venv/lib/python3.11/site-packages/torch/distributed/pipelining/stage.py +1468 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/__init__.py +67 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_api.py +1231 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_collective_utils.py +373 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py +510 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_dtensor_spec.py +276 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_op_schema.py +457 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__init__.py +10 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_common_rules.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_conv_ops.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_einsum_strategy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_embedding_ops.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_experimental_ops.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_math_ops.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_matrix_ops.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_pointwise_ops.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_random_ops.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.24 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-311.pyc
ADDED
|
Binary file (3.53 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_fsspec_filesystem.cpython-311.pyc
ADDED
|
Binary file (8.21 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-311.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-311.pyc
ADDED
|
Binary file (3.22 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-311.pyc
ADDED
|
Binary file (23.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/format_utils.cpython-311.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/logging_handlers.cpython-311.pyc
ADDED
|
Binary file (630 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-311.pyc
ADDED
|
Binary file (7.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-311.pyc
ADDED
|
Binary file (17.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-311.pyc
ADDED
|
Binary file (19.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/planner_helpers.cpython-311.pyc
ADDED
|
Binary file (18.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_loader.cpython-311.pyc
ADDED
|
Binary file (14.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_saver.cpython-311.pyc
ADDED
|
Binary file (15.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (22.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_IR.py
ADDED
|
@@ -0,0 +1,1243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
| 3 |
+
import copy
|
| 4 |
+
import logging
|
| 5 |
+
import operator
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
from enum import Enum
|
| 8 |
+
from inspect import Parameter, Signature, signature
|
| 9 |
+
from types import MethodType
|
| 10 |
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.fx as fx
|
| 14 |
+
from torch.distributed import ProcessGroup
|
| 15 |
+
from torch.export import ExportedProgram
|
| 16 |
+
from torch.export.unflatten import (
|
| 17 |
+
_assign_attr,
|
| 18 |
+
_AttrKind,
|
| 19 |
+
_sink_params,
|
| 20 |
+
InterpreterModule,
|
| 21 |
+
)
|
| 22 |
+
from torch.fx.node import map_aggregate
|
| 23 |
+
from torch.fx.passes.split_module import split_module
|
| 24 |
+
|
| 25 |
+
from ._backward import _null_coalesce_accumulate, stage_backward
|
| 26 |
+
from ._unflatten import _outline_submodules
|
| 27 |
+
from ._utils import PipeInfo
|
| 28 |
+
from .stage import _PipelineStage
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
# TODO:
|
| 34 |
+
# 1. investigate gradient sync for shared parameters. how does DDP do it?
|
| 35 |
+
# 2. Add parameter movement to split_module
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _find_loss_from_output_and_spec(output_val, spec_val):
|
| 39 |
+
if spec_val is False:
|
| 40 |
+
return None
|
| 41 |
+
if spec_val is True:
|
| 42 |
+
if not isinstance(output_val, fx.Node):
|
| 43 |
+
raise RuntimeError(
|
| 44 |
+
f"Loss spec must specify a dynamic value but got {output_val}"
|
| 45 |
+
)
|
| 46 |
+
return output_val
|
| 47 |
+
|
| 48 |
+
if isinstance(spec_val, (tuple, list)):
|
| 49 |
+
if not isinstance(output_val, (tuple, list)):
|
| 50 |
+
raise RuntimeError(
|
| 51 |
+
f"Output value {output_val} must match type of loss specification "
|
| 52 |
+
f"{spec_val}"
|
| 53 |
+
)
|
| 54 |
+
if len(output_val) != len(spec_val):
|
| 55 |
+
raise RuntimeError(
|
| 56 |
+
f"Output value {output_val} must match length of loss specification "
|
| 57 |
+
f"{spec_val}"
|
| 58 |
+
)
|
| 59 |
+
for out, spec in zip(output_val, spec_val):
|
| 60 |
+
loss_val = _find_loss_from_output_and_spec(out, spec)
|
| 61 |
+
if loss_val is not None:
|
| 62 |
+
return loss_val
|
| 63 |
+
raise RuntimeError(f"Did not find loss value in specification {spec_val}")
|
| 64 |
+
|
| 65 |
+
if isinstance(spec_val, dict):
|
| 66 |
+
if not isinstance(output_val, dict):
|
| 67 |
+
raise RuntimeError(
|
| 68 |
+
f"Output value {output_val} must match type of loss specification "
|
| 69 |
+
f"{spec_val}"
|
| 70 |
+
)
|
| 71 |
+
if set(output_val.keys()) != set(spec_val.keys()):
|
| 72 |
+
raise RuntimeError(
|
| 73 |
+
f"Output value {output_val} must match keys of loss specification "
|
| 74 |
+
f"{spec_val}"
|
| 75 |
+
)
|
| 76 |
+
for k in spec_val:
|
| 77 |
+
loss_val = _find_loss_from_output_and_spec(output_val[k], spec_val[k])
|
| 78 |
+
if loss_val is not None:
|
| 79 |
+
return loss_val
|
| 80 |
+
raise RuntimeError(f"Did not find loss value in specification {spec_val}")
|
| 81 |
+
|
| 82 |
+
raise RuntimeError(f"Unsupported type {type(spec_val)} in loss specification")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _find_loss_output(mod: torch.nn.Module, g: fx.Graph, output_loss_value_spec):
|
| 86 |
+
output_nodes = [n for n in g.nodes if n.op == "output"]
|
| 87 |
+
assert len(output_nodes) == 1
|
| 88 |
+
output_node = output_nodes[0]
|
| 89 |
+
output_val = output_node.args[0]
|
| 90 |
+
generated_spec: Any = None
|
| 91 |
+
|
| 92 |
+
if isinstance(mod, TrivialLossWrapper):
|
| 93 |
+
# TrivialLossWrapper is pre-defined by PiPPy.
|
| 94 |
+
# It has loss as the only output so we can safely assume the first output arg is the loss.
|
| 95 |
+
assert len(output_node.args) == 1
|
| 96 |
+
loss_node = output_val
|
| 97 |
+
generated_spec = TrivialLossWrapper.loss_spec
|
| 98 |
+
elif output_loss_value_spec is None:
|
| 99 |
+
# Use default spec, i.e. search for "loss" in output values
|
| 100 |
+
if isinstance(output_val, dict) and "loss" in output_val.keys():
|
| 101 |
+
loss_node = output_val["loss"]
|
| 102 |
+
generated_spec = {k: k == "loss" for k in output_val}
|
| 103 |
+
else:
|
| 104 |
+
loss_node = None
|
| 105 |
+
generated_spec = None
|
| 106 |
+
else:
|
| 107 |
+
loss_node = _find_loss_from_output_and_spec(output_val, output_loss_value_spec)
|
| 108 |
+
generated_spec = output_loss_value_spec
|
| 109 |
+
|
| 110 |
+
return loss_node, output_node, generated_spec
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _insert_stage_symbolic_backward(
|
| 114 |
+
g: fx.Graph,
|
| 115 |
+
loss_node: fx.Node,
|
| 116 |
+
output_node: fx.Node,
|
| 117 |
+
):
|
| 118 |
+
# Collect metadata about tuple output values. TODO: move this to split_module or FX IR
|
| 119 |
+
tuples: Dict[fx.Node, Tuple] = {}
|
| 120 |
+
for node in reversed(g.nodes):
|
| 121 |
+
if node.op == "call_function":
|
| 122 |
+
# In the forward pass, only emit placeholder, module calls, and
|
| 123 |
+
# getitem calls. If we have a target other than getitem in this
|
| 124 |
+
# (forward-only) code, there is a bug.
|
| 125 |
+
assert node.target == operator.getitem, (
|
| 126 |
+
"Found non-getitem call in forward pass. "
|
| 127 |
+
"Please report a bug to PiPPy"
|
| 128 |
+
)
|
| 129 |
+
assert (
|
| 130 |
+
len(node.args) == 2
|
| 131 |
+
), "Found malformed getitem call. Please report a bug to PiPPy"
|
| 132 |
+
indexed_value, node_idx = tuple(node.args)
|
| 133 |
+
|
| 134 |
+
# indexed_value is a collection that we are indexing into. It could
|
| 135 |
+
# exist in the tuples map if we've processed another `getitem`
|
| 136 |
+
# already.
|
| 137 |
+
existing_list_size = (
|
| 138 |
+
len(tuples[indexed_value]) if indexed_value in tuples else -1
|
| 139 |
+
)
|
| 140 |
+
new_list_size = max(node_idx + 1, existing_list_size)
|
| 141 |
+
|
| 142 |
+
reconstructed_list = [None for _ in range(new_list_size)]
|
| 143 |
+
|
| 144 |
+
# Copy over existing elements if present
|
| 145 |
+
if indexed_value in tuples:
|
| 146 |
+
for i, val in enumerate(tuples[indexed_value]):
|
| 147 |
+
reconstructed_list[i] = val
|
| 148 |
+
|
| 149 |
+
# Populate value represented by this node
|
| 150 |
+
reconstructed_list[node_idx] = node
|
| 151 |
+
|
| 152 |
+
tuples[indexed_value] = tuple(reconstructed_list)
|
| 153 |
+
|
| 154 |
+
# Keep track of nodes that dominate the loss node.
|
| 155 |
+
# We will only emit backward operations for nodes that can contribute
|
| 156 |
+
# to the specified loss value.
|
| 157 |
+
live_nodes = {loss_node: None}
|
| 158 |
+
val_to_grad: Dict[fx.Node, Optional[fx.Node]] = {loss_node: None}
|
| 159 |
+
|
| 160 |
+
def assign_or_accumulate_grad(forward_node, grad_value):
|
| 161 |
+
if forward_node in val_to_grad and forward_node.op != "placeholder":
|
| 162 |
+
grad_value = g.call_function(
|
| 163 |
+
_null_coalesce_accumulate,
|
| 164 |
+
(val_to_grad[forward_node], grad_value),
|
| 165 |
+
)
|
| 166 |
+
val_to_grad[forward_node] = grad_value
|
| 167 |
+
|
| 168 |
+
with g.inserting_before(output_node):
|
| 169 |
+
for node in reversed(g.nodes):
|
| 170 |
+
if node not in live_nodes:
|
| 171 |
+
continue
|
| 172 |
+
|
| 173 |
+
def add_to_live_nodes(n):
|
| 174 |
+
live_nodes.setdefault(n, None)
|
| 175 |
+
|
| 176 |
+
fx.node.map_arg(node.args, add_to_live_nodes)
|
| 177 |
+
fx.node.map_arg(node.kwargs, add_to_live_nodes)
|
| 178 |
+
if node.op == "call_module":
|
| 179 |
+
output_grads: Union[Tuple[Optional[fx.Node], ...], Optional[fx.Node]]
|
| 180 |
+
if node in tuples:
|
| 181 |
+
stage_output = tuples[node]
|
| 182 |
+
output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node])
|
| 183 |
+
outputs_with_grads_idxs = [
|
| 184 |
+
i for i, n in enumerate(tuples[node]) if n in live_nodes
|
| 185 |
+
]
|
| 186 |
+
else:
|
| 187 |
+
stage_output = (node,)
|
| 188 |
+
output_grads = val_to_grad[node]
|
| 189 |
+
outputs_with_grads_idxs = [0]
|
| 190 |
+
|
| 191 |
+
output_grads = (
|
| 192 |
+
(output_grads,)
|
| 193 |
+
if not isinstance(output_grads, tuple)
|
| 194 |
+
else output_grads
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
grad_call = g.call_function(
|
| 198 |
+
stage_backward,
|
| 199 |
+
kwargs={
|
| 200 |
+
"stage_output": stage_output,
|
| 201 |
+
"output_grads": output_grads,
|
| 202 |
+
"input_values": list(node.all_input_nodes),
|
| 203 |
+
"outputs_with_grads_idxs": outputs_with_grads_idxs,
|
| 204 |
+
},
|
| 205 |
+
)
|
| 206 |
+
# Insert backward stage debug info
|
| 207 |
+
kwargs_copy = dict(grad_call.kwargs)
|
| 208 |
+
grad_call.kwargs = kwargs_copy
|
| 209 |
+
|
| 210 |
+
grad_call_proxy = fx.Proxy(grad_call)
|
| 211 |
+
grads = grad_call_proxy.node
|
| 212 |
+
|
| 213 |
+
input_nodes = list(node.all_input_nodes)
|
| 214 |
+
grads_proxy = fx.Proxy(grads)
|
| 215 |
+
for i, input_node in enumerate(input_nodes):
|
| 216 |
+
assign_or_accumulate_grad(input_node, grads_proxy[i].node) # type: ignore[index]
|
| 217 |
+
|
| 218 |
+
return g
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class PipeSequential(torch.nn.Sequential):
|
| 222 |
+
@staticmethod
|
| 223 |
+
def from_sequential(sequential_instance: torch.nn.Sequential):
|
| 224 |
+
return PipeSequential(*[copy.copy(m) for m in sequential_instance])
|
| 225 |
+
|
| 226 |
+
def forward(self, input):
|
| 227 |
+
for i, module in enumerate(self):
|
| 228 |
+
input = module(input)
|
| 229 |
+
if i != len(self) - 1:
|
| 230 |
+
pipe_split()
|
| 231 |
+
return input
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class LossWrapper(torch.nn.Module):
|
| 235 |
+
"""
|
| 236 |
+
LossWrapper is a convenient abstract class that allows you to wrap up both
|
| 237 |
+
your model as well as its loss function and specify the connectivity between
|
| 238 |
+
the inputs, model, loss function, and output value. Example::
|
| 239 |
+
|
| 240 |
+
class MyModelWrapper(LossWrapper):
|
| 241 |
+
def forward(self, x, targets):
|
| 242 |
+
model_out = self.module(x)
|
| 243 |
+
loss_value = self.loss_fn(model_out, targets)
|
| 244 |
+
return loss_value
|
| 245 |
+
|
| 246 |
+
The above example defines a connectivity where we expect the forward/loss/backward
|
| 247 |
+
training procedure to take two arguments (x and targets), pass x into the module
|
| 248 |
+
to get the output of the feedforward computation, pass the model output and the
|
| 249 |
+
targets value into the loss function, and get and return the loss value, which will
|
| 250 |
+
be backpropagated by PiPPy. The above class would then be instantiated like::
|
| 251 |
+
|
| 252 |
+
model = ... # instantiate the model
|
| 253 |
+
loss_fn = torch.nn.MSELoss() # for the sake of demonstration
|
| 254 |
+
|
| 255 |
+
wrapper = MyModelWrapper(model, loss_fn)
|
| 256 |
+
pipe = Pipe.from_tracing(wrapper, ...)
|
| 257 |
+
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
def __init__(self, module, loss_fn):
|
| 261 |
+
super().__init__()
|
| 262 |
+
self.module = module
|
| 263 |
+
self.loss_fn = loss_fn
|
| 264 |
+
|
| 265 |
+
def forward(self, *args, **kwargs):
|
| 266 |
+
raise NotImplementedError(
|
| 267 |
+
"This instance of LossWrapper does not have an overridden"
|
| 268 |
+
"forward(). Please implement forward() to specify the arguments, "
|
| 269 |
+
"connection between the module and loss, and loss output "
|
| 270 |
+
"value."
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class TrivialLossWrapper(LossWrapper):
|
| 275 |
+
def forward(self, x, targets):
|
| 276 |
+
model_out = self.module(x)
|
| 277 |
+
return self.loss_fn(model_out, targets)
|
| 278 |
+
|
| 279 |
+
loss_spec = True
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# Pipe model representation
|
| 283 |
+
#
|
| 284 |
+
# Pipe can be thought of as an `nn.Sequential++`. That is to say: it specifies
|
| 285 |
+
# a single topological ordering of pipeline "stages" that, when run in series,
|
| 286 |
+
# constitutes all of the operations of the program. However, unlike `nn.Sequential`,
|
| 287 |
+
# Pipe allows non-local usages of values, so long as those uses still respect
|
| 288 |
+
# topological ordering. In particular:
|
| 289 |
+
#
|
| 290 |
+
# 1. Non-local activations. This type of usage can appear in, for example, skip
|
| 291 |
+
# connections. These values will be directly transmitted from the "def" stage
|
| 292 |
+
# to all stages that use them skipping intermediate stages. During autograd,
|
| 293 |
+
# gradients will be propagated back through this skip connection reverse
|
| 294 |
+
# to how activations propagated in the forward pass.
|
| 295 |
+
# 2. Non-local parameter/module invocations. This occurs when a parameter is used
|
| 296 |
+
# in a stage downstream of where it is resident. These values can be carried
|
| 297 |
+
# forward similarly to (1), but in addition one might want to replicate the
|
| 298 |
+
# value on multiple stages. Gradients for these shared parameters will be
|
| 299 |
+
# accumulated separately on each stage, but there will be an additional
|
| 300 |
+
# gradient accumulation before the optimizer step.
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
# Register `_pipe_split()` as an ATen operator. This is required for Export to
|
| 304 |
+
# preserve this marker in the graph.
|
| 305 |
+
torch.library.define("pippy::_pipe_split", "() -> ()")
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
@torch.library.impl("pippy::_pipe_split", "BackendSelect")
|
| 309 |
+
def _pipe_split():
|
| 310 |
+
return None
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
@torch.library.register_fake("pippy::_pipe_split") # type: ignore[no-redef]
|
| 314 |
+
def _pipe_split(): # noqa: F811
|
| 315 |
+
return None
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
# Add an alias for convenience
|
| 319 |
+
aten_pipe_split_alias = torch.ops.pippy._pipe_split.default
|
| 320 |
+
|
| 321 |
+
# Ask Export to preserve the `_pipe_split` op.
|
| 322 |
+
# See examples in pytorch/torch/fx/node.py
|
| 323 |
+
fx.node._side_effectful_functions.add(aten_pipe_split_alias)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
# User facing API
|
| 327 |
+
def pipe_split():
|
| 328 |
+
"""
|
| 329 |
+
pipe_split is a special operator that is used to mark the boundary between
|
| 330 |
+
stages in a module. It is used to split the module into stages. It is a
|
| 331 |
+
no-op if your annotated module is run eagerly.
|
| 332 |
+
|
| 333 |
+
Example:
|
| 334 |
+
>>> # xdoctest: +SKIP
|
| 335 |
+
>>> def forward(self, x):
|
| 336 |
+
>>> x = torch.mm(x, self.mm_param)
|
| 337 |
+
>>> x = torch.relu(x)
|
| 338 |
+
>>> pipe_split()
|
| 339 |
+
>>> x = self.lin(x)
|
| 340 |
+
>>> return x
|
| 341 |
+
|
| 342 |
+
The above example will be split into two stages.
|
| 343 |
+
"""
|
| 344 |
+
return torch.ops.pippy._pipe_split()
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
class MultiUseParameterConfig(Enum):
|
| 348 |
+
TRANSMIT = 1
|
| 349 |
+
REPLICATE = 2
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
MultiUseParamSpec = Union[MultiUseParameterConfig, Dict[str, MultiUseParameterConfig]]
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class DetachExecutor(fx.Interpreter):
|
| 356 |
+
"""
|
| 357 |
+
Special interpreter to run the split_gm in testing that detaches all inputs to
|
| 358 |
+
a module invocation. This is needed so that the values at the boundary are
|
| 359 |
+
leaf modules in autograd execution.
|
| 360 |
+
"""
|
| 361 |
+
|
| 362 |
+
def __init__(self, module, garbage_collect_values=True):
|
| 363 |
+
garbage_collect_values = False
|
| 364 |
+
super().__init__(module, garbage_collect_values)
|
| 365 |
+
self.value_remap = {}
|
| 366 |
+
|
| 367 |
+
def run(self, *args, initial_env=None):
|
| 368 |
+
self.value_remap = {}
|
| 369 |
+
return super().run(*args, initial_env=initial_env)
|
| 370 |
+
|
| 371 |
+
def call_module(self, target, args, kwargs):
|
| 372 |
+
def detach_tensors(a):
|
| 373 |
+
if isinstance(a, torch.Tensor) and a.requires_grad:
|
| 374 |
+
if a not in self.value_remap:
|
| 375 |
+
new_val = a.detach().requires_grad_(True)
|
| 376 |
+
self.value_remap[a] = new_val
|
| 377 |
+
return self.value_remap[a]
|
| 378 |
+
else:
|
| 379 |
+
return a
|
| 380 |
+
|
| 381 |
+
"""
|
| 382 |
+
def dont_traverse_size(a):
|
| 383 |
+
return type(a) != torch.Size
|
| 384 |
+
"""
|
| 385 |
+
|
| 386 |
+
args = map_aggregate(
|
| 387 |
+
args,
|
| 388 |
+
detach_tensors, # dont_traverse_size
|
| 389 |
+
)
|
| 390 |
+
kwargs = map_aggregate(
|
| 391 |
+
kwargs,
|
| 392 |
+
detach_tensors, # dont_traverse_size
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
return super().call_module(target, args, kwargs)
|
| 396 |
+
|
| 397 |
+
def call_function(self, target, args, kwargs):
|
| 398 |
+
# HACK to reroute saved input tensors to point to the detach()ed version
|
| 399 |
+
if target == stage_backward:
|
| 400 |
+
kwargs = dict(kwargs)
|
| 401 |
+
kwargs["input_values"] = [
|
| 402 |
+
self.value_remap.get(v, v) for v in kwargs["input_values"]
|
| 403 |
+
]
|
| 404 |
+
return super().call_function(target, args, kwargs)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class _NodeReference:
|
| 408 |
+
def __init__(self, name):
|
| 409 |
+
self.name = name
|
| 410 |
+
|
| 411 |
+
name: str
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
class _LinearNodeList:
    """
    A picklable, graph-free flattening of an fx node list.

    Each node is copied into a detached `fx.Node` (graph=None) whose arg/kwarg
    Node references are replaced by `_NodeReference` placeholders, so the list
    can be serialized without dragging the whole `fx.Graph` along.
    `to_graph()` reverses the process.
    """

    def __init__(self, node_list):
        self.serialize_node_list = []
        for node in node_list:
            # Swap real Node args for name-based references.
            node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name))  # type: ignore[arg-type,return-value]
            node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name))  # type: ignore[arg-type,return-value]
            serialize_node = fx.Node(
                graph=None,  # type: ignore[arg-type]
                name=node.name,
                op=node.op,
                target=node.target,
                args=node_args,  # type: ignore[arg-type]
                kwargs=node_kwargs,  # type: ignore[arg-type]
                return_type=node.type,
            )
            # Shallow-copy meta so the serialized node does not alias the
            # original node's meta dict.
            serialize_node.meta = copy.copy(node.meta)
            self.serialize_node_list.append(serialize_node)

    def to_graph(self):
        """Rebuild an fx.Graph from the serialized node list.

        Assumes the list is topologically ordered (producers before users),
        which holds for node lists taken from an fx.Graph.
        """
        graph = fx.Graph()

        ref_str_to_node: Dict[str, fx.Node] = {}

        # Resolve a _NodeReference back to the freshly created node of the
        # same name; pass every other value through unchanged.
        def ref_to_node(arg):
            if isinstance(arg, _NodeReference):
                return ref_str_to_node[arg.name]
            else:
                return arg

        for node in self.serialize_node_list:
            node_args = map_aggregate(node.args, ref_to_node)
            node_kwargs = map_aggregate(node.kwargs, ref_to_node)
            deser_node = graph.create_node(
                op=node.op,
                target=node.target,
                args=node_args,  # type: ignore[arg-type]
                kwargs=node_kwargs,  # type: ignore[arg-type]
                name=node.name,
                type_expr=node.type,
            )
            ref_str_to_node[node.name] = deser_node

        return graph
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def _direct_serialization_deserialize(body, nodes):
    """
    Custom `__reduce__` method for serialization.
    DO AS I SAY -- NOT AS I DO. This violates the principle that
    GraphModules serialize via code export & re-tracing. We allow
    for this here because **PIPE STAGES SHOULD NOT BE PERSISTED
    TO DISK -- THIS IS ONLY FOR TRANSMISSION VIA RPC**. Persisting
    these instances to disk will expose internal implementation
    details of `fx.Graph` and related data structures and is
    NOT advised.

    Args:
        body: the serialized instance `__dict__` of the original GraphModule
            (minus its graph; see `_direct_serialization_reduce`).
        nodes: a `_LinearNodeList` holding the graph's nodes.
    """

    class DummyModule(torch.nn.Module):
        def __init__(self, body):
            super().__init__()
            # Restore the original module state (parameters, buffers,
            # submodules, plain attributes) directly into __dict__.
            self.__dict__.update(body)

    dummy = DummyModule(body)

    # Recreate the GraphModule from the restored state plus the rebuilt graph.
    return fx.GraphModule(dummy, nodes.to_graph())
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def _direct_serialization_reduce(self):
    """Pickle a GraphModule as (state-dict-without-graph, linearized node list).

    Installed as `__reduce__` on pipe stage GraphModule classes; pairs with
    `_direct_serialization_deserialize`, which rebuilds the module over RPC.
    """
    state = dict(self.__dict__)
    # The graph itself is not picklable; ship it as a _LinearNodeList instead.
    del state["_graph"]
    return _direct_serialization_deserialize, (state, _LinearNodeList(self.graph.nodes))
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def _modify_graph_op_device(
    gm: torch.fx.GraphModule,
    new_device: torch.device,
):
    """
    Modify the device argument of all "call_function" nodes in the graph. This
    is useful for moving the graph to a different device. In particular for
    generator ops, like torch.ones.
    """
    modified = False
    for node in gm.graph.nodes:
        if node.op == "call_function":
            if "device" in node.kwargs and node.kwargs["device"] != new_device:
                logger.debug(
                    f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}"  # noqa: G004
                )
                node.update_kwarg("device", new_device)
                modified = True
        elif node.op == "call_module":
            # Recursively modify "device" in submodules
            # (each recursive call handles its own recompile; `modified` here
            # tracks only changes to *this* graph)
            submod = gm.get_submodule(node.target)
            if isinstance(submod, torch.fx.GraphModule):
                _modify_graph_op_device(submod, new_device)
            elif isinstance(submod, InterpreterModule):
                # If unflattening has been performed, we need to access its graph module by `.graph_module`
                _modify_graph_op_device(submod.graph_module, new_device)
            else:
                logger.warning(
                    f"Skipping device modification for submodule {node.target} because it is a {type(submod)}"  # noqa: G004
                )

    # Only regenerate code if a kwarg actually changed.
    if modified:
        gm.recompile()
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
class Pipe(torch.nn.Module):
|
| 526 |
+
    def __init__(
        self,
        split_gm: fx.GraphModule,
        num_stages: int,
        has_loss_and_backward: bool,
        loss_spec,
    ):
        # TODO: is there a way not to hard wire init?
        torch.nn.Module.__init__(self)
        self.split_gm: fx.GraphModule = split_gm
        # Interpreter used for local (reference) execution of the pipeline.
        self.executor: DetachExecutor = DetachExecutor(self.split_gm)
        self.num_stages: int = num_stages
        self.has_loss_and_backward = has_loss_and_backward
        self.loss_spec = loss_spec

        # Sanity-check: the split graph may only contain stage calls, I/O
        # nodes, and the few auxiliary ops the backward pass inserts.
        for node in split_gm.graph.nodes:
            assert (
                node.op in {"call_module", "placeholder", "output"}
                or (node.op, node.target) == ("call_function", operator.getitem)
                or (node.op, node.target) == ("call_method", "backward")
                or (node.op, node.target) == ("call_function", stage_backward)
                or (node.op, node.target)
                == ("call_function", _null_coalesce_accumulate)
            ), node

        # Detect replicated parameters so we know that we have to do an additional allreduce
        # before applying the optimizer
        #
        # Note that this also handles the case where there were multiple calls to a single
        # module from different stages, regardless of whether that module invocation
        # was handled by the logic above.

        # Map parameter value to a dictionary that maps the user pipeline module
        # to the local qualname within that module
        params_to_users: Dict[torch.nn.Parameter, Dict[str, str]] = {}

        for m_qualname, mod in self.split_gm.named_children():
            for p_qualname, param in mod.named_parameters():
                params_to_users.setdefault(param, {})
                params_to_users[param][m_qualname] = p_qualname

        # A parameter appearing under more than one stage submodule is
        # replicated across stages.
        self.replicated_params: List[Dict[str, str]] = [
            use_mapping
            for _, use_mapping in params_to_users.items()
            if len(use_mapping) > 1
        ]

        # We must break the aliasing relationship between the replicated parameters for correct
        # numerics in reference runs. If we do not do this, the autograd tape in separate stages
        # will have a reference to the same tensor value and will erroneously apply gradient
        # updates multiple times. Therefore, for each replicated parameter set, we deepcopy the
        # values so that we have separate instances.
        for param_mapping in self.replicated_params:
            for submod_name, param_qualname in param_mapping.items():
                submod = getattr(self.split_gm, submod_name)
                atoms = param_qualname.split(".")
                for atom in atoms[:-1]:
                    submod = getattr(submod, atom)
                setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1])))

        # Guard against users calling `split_gm` directly: its forward is
        # replaced with a function that always raises.
        def throw(self, *args, **kwargs):
            raise RuntimeError(
                "To run pipeline locally, invoke the Pipe object directly, not `split_gm`"
            )

        self.split_gm.forward = throw

        # Make submodules use custom direct-serialized GraphModule
        # (probe submod_0, submod_1, ... until one is missing)
        i = 0
        while True:
            try:
                name = f"submod_{i}"
                submod = getattr(self.split_gm, name)
                submod.__class__.__reduce__ = _direct_serialization_reduce
                i += 1
            except AttributeError:
                break
|
| 603 |
+
|
| 604 |
+
    def forward(self, *args, **kwargs):
        """Run the whole pipeline locally via the DetachExecutor.

        When kwargs are given, a Signature is reconstructed from the split
        graph's placeholder nodes so kwargs can be bound to positional
        executor arguments.
        """
        executor_args = args
        if len(kwargs) > 0:
            parameters = []
            for node in self.split_gm.graph.nodes:
                if node.op == "placeholder":
                    # A placeholder with args carries a default value
                    # (fx convention: node.args[0] is the default).
                    if node.args and len(node.args) > 0:
                        parameters.append(
                            Parameter(
                                node.target,
                                Parameter.POSITIONAL_OR_KEYWORD,
                                default=node.args[0],
                            )
                        )
                    else:
                        parameter_kind = Parameter.POSITIONAL_OR_KEYWORD
                        param_name = node.target
                        # fx names *args / **kwargs placeholders with the
                        # star prefixes; strip them and set the proper kind.
                        if node.target.startswith("**"):
                            parameter_kind = Parameter.VAR_KEYWORD  # type: ignore[assignment]
                            param_name = param_name[2:]
                        elif node.target.startswith("*"):
                            parameter_kind = Parameter.VAR_POSITIONAL  # type: ignore[assignment]
                            param_name = param_name[1:]
                        parameters.append(Parameter(param_name, parameter_kind))
            signature = Signature(parameters)
            ba = signature.bind(*args, **kwargs)
            ba.apply_defaults()
            # Bound arguments, in placeholder order, become positional inputs.
            executor_args = ba.arguments.values()  # type: ignore[assignment]

        res = self.executor.run(*executor_args)

        return res
|
| 636 |
+
|
| 637 |
+
def get_stage_module(self, stage_idx: int) -> torch.nn.Module:
|
| 638 |
+
"""
|
| 639 |
+
Return a stage module corresponding to `stage_idx` of the `pipe`.
|
| 640 |
+
"""
|
| 641 |
+
if stage_idx < 0 or stage_idx >= self.num_stages:
|
| 642 |
+
raise ValueError(f"Invalid stage index {stage_idx}!")
|
| 643 |
+
return getattr(self.split_gm, f"submod_{stage_idx}")
|
| 644 |
+
|
| 645 |
+
@staticmethod
|
| 646 |
+
def _number_and_count_forward_stages(gm: fx.GraphModule):
|
| 647 |
+
num_stages = 0
|
| 648 |
+
found_idxs: Dict[int, None] = {}
|
| 649 |
+
for node in gm.graph.nodes:
|
| 650 |
+
if node.op == "call_module" and node.target.startswith("submod_"):
|
| 651 |
+
node.meta["stage_idx"] = int(node.target[len("submod_") :])
|
| 652 |
+
found_idxs.setdefault(node.meta["stage_idx"])
|
| 653 |
+
num_stages += 1
|
| 654 |
+
|
| 655 |
+
# this assert will fail if a split point is inserted before the first layer, which creates empty first submodule
|
| 656 |
+
# Update: the following assert may fail against some torch versions >=
|
| 657 |
+
# 2.2.0, as:
|
| 658 |
+
# submod_0, submod_1, submod_2, ...
|
| 659 |
+
# may be named as
|
| 660 |
+
# submod_0, submod_2, submod_4, ...
|
| 661 |
+
# TODO: investigate
|
| 662 |
+
# assert all(i in found_idxs for i in range(num_stages))
|
| 663 |
+
|
| 664 |
+
return num_stages
|
| 665 |
+
|
| 666 |
+
    @staticmethod
    def _from_traced(
        mod: torch.nn.Module,
        exported_program: ExportedProgram,
        multi_use_param_spec: Optional[MultiUseParamSpec] = None,
        output_loss_value_spec=None,
        split_policy: Optional[
            Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
        ] = None,
    ):
        """
        Lower an exported program into a `Pipe`: split the traced graph at
        pipe_split markers, move parameters/buffers into their stage
        submodules, and (optionally) generate the symbolic backward pass.

        Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate
        which value in the output of `forward` is the loss value on which PiPPy should apply
        backpropagation. For example, if your ``forward`` returns a tuple ``(loss, model_out)``,
        you can specify ``output_loss_value_spec=(True, False)``. Or, if your ``forward`` returns
        a dict ``{'loss': loss_value, 'model_out': model_out}``, you can specify
        ``output_loss_value_spec={'loss': True, 'model_out': False}``
        """

        traced = exported_program.module()

        if split_policy is not None:
            logger.info("Auto-splitting model")
            traced = split_policy(traced)  # type: ignore[arg-type]

        logger.debug(traced.print_readable(print_output=False))

        # Deduplicate `get_attr` nodes that refer to the same parameter . Downstream code for moving
        # parameters relies on the invariant that parameter accesses happen once. This is not necessarily
        # the case (especially with custom tracers), so fix that up here.
        get_attr_nodes: Dict[str, fx.Node] = {}
        for node in traced.graph.nodes:
            if node.op == "get_attr":
                get_attr_nodes.setdefault(node.target, node)

                # A later duplicate: redirect its uses to the first occurrence
                # and erase it.
                if get_attr_nodes[node.target] != node:
                    node.replace_all_uses_with(get_attr_nodes[node.target])
                    traced.graph.erase_node(node)

        # avoid looking at next node by keeping track of previous pipe_split
        # (collapses consecutive pipe_split calls into one)
        prev_pipe_split_idx = -1
        pipe_split_nodes_to_erase = set()
        for i, node in enumerate(traced.graph.nodes):
            if (node.op, node.target) == ("call_function", pipe_split):
                if prev_pipe_split_idx == i - 1:
                    pipe_split_nodes_to_erase.add(node)
                prev_pipe_split_idx = i

        for node in pipe_split_nodes_to_erase:
            traced.graph.erase_node(node)

        traced.recompile()

        part_idx = 0

        # Partition callback for split_module: bump the partition index each
        # time a pipe_split marker is encountered.
        def split_callback(n: fx.Node):
            nonlocal part_idx
            if (n.op, n.target) == (
                "call_function",
                aten_pipe_split_alias,
            ):
                logger.debug(f"Found pipe_split {part_idx}")  # noqa: G004
                part_idx += 1
            return part_idx

        # TODO: what does split do with module invocations? does it move the modules
        # into the submodules?
        split = split_module(traced, mod, split_callback)  # type: ignore[arg-type]
        # a (custom) tracer can produce dead code like orphan get_attr nodes
        split.graph.eliminate_dead_code()

        # peephole to remove pipe_split
        for submodule in split.modules():
            if isinstance(submodule, fx.GraphModule):
                for node in submodule.graph.nodes:
                    if (node.op, node.target) == (
                        "call_function",
                        aten_pipe_split_alias,
                    ):
                        submodule.graph.erase_node(node)
                submodule.recompile()

        # Restore the original nested module hierarchy inside each stage.
        for name, submodule in split.named_children():
            if isinstance(submodule, fx.GraphModule):
                new_submod = _outline_submodules(submodule.graph)
                # Replace old submod
                split.register_module(name, new_submod)

        # TODO: backport this into split_module
        def delete_user_reference(node, user):
            """
            Delete reference of `node` from `user`'s arg list.
            Args:
                - node: a `get_attr` node at root.
                - user: a submodule node that uses `node`.
            """
            assert len(user.kwargs) == 0
            use_idxs = [i for i, arg in enumerate(user.args) if arg == node]
            assert len(use_idxs) == 1
            args_copy = list(user.args)
            args_copy.pop(use_idxs[0])
            user.args = tuple(args_copy)
            logger.debug(
                f"Deleted {node} from user {user}, arg index = {use_idxs[0]}"  # noqa: G004
            )

        # A list of param referrals for deferred deletion.
        # To be accumulated in `move_param_to_callee`.
        to_delete = []

        def _recursive_getattr_with_parent(mod, fqn):
            # Returns getattr call given a nested FQN, and the last parent
            # ((None, None) if an intermediate attribute is missing;
            # (parent, None) if only the leaf is missing)
            atoms = fqn.split(".")
            for atom in atoms[:-1]:
                if not hasattr(mod, atom):
                    return None, None
                mod = getattr(mod, atom)
            if not hasattr(mod, atoms[-1]):
                return mod, None
            attr = getattr(mod, atoms[-1])
            return mod, attr

        def move_param_to_callee(
            root,
            callee_name,
            param_fqn,
        ):
            """
            Move a parameter from the root module to a submodule.
            Args:
                root: The root module.
                callee_name: The name of the submodule to move the parameter to.
                param_fqn: The fully qualified name of the parameter to move.
            """
            # `atoms` is a list of strings representing the path to the
            # parameter in the original model
            atoms = param_fqn.split(".")
            mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn)
            # Check whether the parameter is a buffer or a parameter
            is_buffer = atoms[-1] in mod_itr._buffers

            # Check whether the parameter is a tensor
            assert isinstance(param_val, torch.Tensor), (
                f"Expected '{param_fqn}' to be {torch.Tensor} but got {type(param_val)}."
                + (
                    f" It might happen if module '{param_fqn}' was passed to some 'leaf function'"
                    f"(see https://pytorch.org/docs/stable/fx.html#fx.wrap). Please inspect "
                    f"usages of '{param_fqn}' in the traced graph."
                    if isinstance(param_val, torch.nn.Module)
                    else ""
                )
            )

            # Get submodule
            callee = root.get_submodule(callee_name)
            assert not hasattr(
                callee, param_fqn
            ), f"Module {callee_name} already has a parameter named {param_fqn}"

            # Assign the parameter to the submodule
            if is_buffer:
                _assign_attr(
                    param_val,
                    callee,
                    param_fqn,
                    attr_kind=_AttrKind.BUFFER,
                    persistent=True,  # TODO: handle non-persistent buffer
                )
            else:
                _assign_attr(
                    param_val,
                    callee,
                    param_fqn,
                    attr_kind=_AttrKind.PARAMETER,
                )
            logger.debug(f"Moved parameter {param_fqn} to {callee_name}")  # noqa: G004

            # Next step is to replace placeholder of submodule with a get_attr.
            # Those placeholders are created by `split_module` inside each
            # submodule.
            # Update: this step is now moved to `_sink_params` because
            # `_sink_params` can do it recursively (i.e. for modules inside
            # submodule)

            to_delete.append((mod_itr, atoms[-1]))

        # Get the list of all parameters in the root module
        attr_nodes = list(filter(lambda n: n.op == "get_attr", split.graph.nodes))
        for node in attr_nodes:
            # Check whether the parameter is used in only one submodule
            if len(node.users) > 1:
                logger.info(
                    f"Parameter {node.target} used in multiple stages: {node.users}."  # noqa: G004
                )
            for user in node.users:
                assert user.op == "call_module"
                # Move parameter into submodule
                move_param_to_callee(
                    split,
                    user.target,
                    node.target,
                )

        # [aliasing] store tensor id -> list of FQNs, built from state dict
        # Also assign non-persistent buffers
        id_to_fqns: Dict[int, Set[str]] = defaultdict(set)
        for fqn, tensor in mod.state_dict(keep_vars=True).items():
            id_to_fqns[id(tensor)].add(fqn)
        for fqn, tensor in mod.named_buffers():
            id_to_fqns[id(tensor)].add(fqn)

        # After moving the params to their corresponding hierarchies, we also
        # need to move the `get_attr` nodes from the root of the graph to those
        # hierarchies.
        # [aliasing] use id -> fqn mapping to list out all valid FQNs
        inputs_to_state: Dict[str, List[str]] = {}
        for attr in attr_nodes:
            _, tensor = _recursive_getattr_with_parent(mod, attr.target)
            fqns = list(id_to_fqns[id(tensor)])
            if fqns:
                inputs_to_state[attr.name] = fqns
            elif attr.target in exported_program.constants:  # lifted constants
                inputs_to_state[attr.name] = [attr.target]

        # [aliasing] for each submodule split, assign attributes on FQNs that may be used.
        # We determine this based on whether or not the FQN attribute parent exists.
        # i.e. if the last submodule exists, assign the attribute.
        added_attributes: Dict[str, List[str]] = defaultdict(list)
        for fqn, tensor in mod.state_dict(keep_vars=True).items():
            for name, submod in split.named_children():
                if isinstance(submod, fx.GraphModule):
                    parent, child = _recursive_getattr_with_parent(submod, fqn)
                    if (
                        parent and child is None
                    ):  # parent exists, attribute doesn't -> assign
                        added_attributes[name].append(fqn)
                        setattr(parent, fqn.split(".")[-1], tensor)

        # Deferral deletion: Remove the original attributes (to params) from the
        # root GraphModule
        for mod_itr, last_atom in to_delete:
            try:
                delattr(mod_itr, last_atom)
            except AttributeError:
                # This is expected if the parameter is used in multiple stages
                pass

        # This is done by (1) `_sink_params` at each submodule;
        for name, submod in split.named_children():
            if isinstance(submod, fx.GraphModule):
                _sink_params(submod, inputs_to_state, [])
                submod.graph.lint()
                submod.recompile()

        # [aliasing] This step is not super necessary, but helps reduce parameter usage/memory.
        # After _sink_params() routine has run, clean up unused attributes that we previously added.
        # Determine this based on the get_attr nodes - if not used, remove it.
        for name, attributes in added_attributes.items():
            submod = getattr(split, name)
            unused_attributes = set(attributes)
            # track used attributes in the submodule, running DFS on subgraph hierarchy
            stack = [("", submod)]  # (scope, submodule)
            while stack:
                scope, _mod = stack.pop()
                if isinstance(_mod, (fx.GraphModule, InterpreterModule)):
                    for node in _mod.graph.nodes:
                        if node.op == "get_attr":
                            # get_attr might get access deeper level attribute
                            fqn = scope + "." + node.target if scope else node.target
                            if fqn in unused_attributes:  # used, remove it
                                unused_attributes.remove(fqn)
                    for _name, _submod in _mod.named_children():
                        stack.append((scope + "." + _name if scope else _name, _submod))
            # delete unused attributes
            for attr in unused_attributes:
                mod_itr, atoms = submod, attr.split(".")
                for atom in atoms[:-1]:
                    mod_itr = getattr(mod_itr, atom)
                delattr(mod_itr, atoms[-1])

        for node in attr_nodes:
            # And (2): remove `get_attr` node from submod's arg list
            for user in copy.copy(node.users):
                assert user.op == "call_module"
                delete_user_reference(node, user)
            # And (3): remove the `get_attr` node from the root graph.
            split.graph.erase_node(node)

        split.delete_all_unused_submodules()
        split.graph.lint()
        split.recompile()

        num_stages = Pipe._number_and_count_forward_stages(split)

        has_loss_and_backward = False
        generated_loss_spec = output_loss_value_spec

        if output_loss_value_spec is not None:
            loss_node, output_node, generated_loss_spec = _find_loss_output(
                mod, split.graph, output_loss_value_spec
            )
            if loss_node is not None:
                _insert_stage_symbolic_backward(
                    split.graph,
                    loss_node,
                    output_node,
                )
                split.recompile()
                has_loss_and_backward = True
                logger.debug("Pipeline is in training mode, backward pass generated")
            else:
                raise RuntimeError(
                    f"Did not find any loss value according to {output_loss_value_spec=}"
                )
        else:
            logger.debug("Pipeline is in inference mode, backward pass not generated")

        logger.debug("Full pipe model:\n" f"{split}")  # noqa: G004

        return Pipe(
            split,
            num_stages,
            has_loss_and_backward,
            generated_loss_spec,
        )
|
| 991 |
+
|
| 992 |
+
def print_readable(self):
|
| 993 |
+
"""
|
| 994 |
+
Print the pipe in a human-readable format.
|
| 995 |
+
This will print both the root pipe and each stage module.
|
| 996 |
+
"""
|
| 997 |
+
self.split_gm.print_readable()
|
| 998 |
+
|
| 999 |
+
@staticmethod
|
| 1000 |
+
def _trace_with_export(
|
| 1001 |
+
mod: torch.nn.Module,
|
| 1002 |
+
example_args: Tuple[Any, ...],
|
| 1003 |
+
example_kwargs: Optional[Dict[str, Any]] = None,
|
| 1004 |
+
) -> ExportedProgram:
|
| 1005 |
+
logger.info("Tracing model ...")
|
| 1006 |
+
try:
|
| 1007 |
+
ep = torch.export.export(
|
| 1008 |
+
mod,
|
| 1009 |
+
example_args,
|
| 1010 |
+
example_kwargs,
|
| 1011 |
+
)
|
| 1012 |
+
except Exception as e:
|
| 1013 |
+
raise RuntimeError(
|
| 1014 |
+
"It seems that we cannot capture your model as a full graph. "
|
| 1015 |
+
"Typical reasons include graph breaks, data/shape-dependent "
|
| 1016 |
+
"control flow, or missing meta kernels for custom operators. "
|
| 1017 |
+
"You can use our manual pipeline interfaces, or try to fix the "
|
| 1018 |
+
"graph breaks, see https://pytorch.org/docs/stable/export.html"
|
| 1019 |
+
) from e
|
| 1020 |
+
|
| 1021 |
+
return ep
|
| 1022 |
+
|
| 1023 |
+
    @staticmethod
    def from_tracing(
        mod: torch.nn.Module,
        example_args: Tuple[Any, ...],
        example_kwargs: Optional[Dict[str, Any]] = None,
        split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
    ):
        """Trace `mod` with torch.export and lower it into a `Pipe`."""
        # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across
        # stages instead of TRANSMIT'ting it
        multi_use_param_spec = MultiUseParameterConfig.REPLICATE

        # Figure out which output is loss from output_chunk_spec
        output_loss_value_spec: Any = None
        # Deprecated
        """
        if output_chunk_spec is not None:
            output_loss_value_spec = map_aggregate(
                output_chunk_spec, lambda v: isinstance(v, _LossReducer)
            )
        """

        # Trace with export
        exported_program = Pipe._trace_with_export(
            mod,
            example_args,
            example_kwargs,
        )

        pipe = Pipe._from_traced(
            mod,
            exported_program,
            multi_use_param_spec,
            output_loss_value_spec=output_loss_value_spec,
            split_policy=split_policy,
        )

        # Users want the first pipeline stage to accept kwargs if the original
        # program does. This is controlled by the `_codegen` field of the graph,
        # so we make a copy here. Note: we only want the input spec and not the
        # output spec, because the output spec is for the last stage. Maybe a
        # TODO? Not sure yet.
        split = pipe.split_gm
        traced = exported_program.module()
        submod0 = next(iter(split.children()))
        submod0_sign = signature(submod0.forward)
        model_sign = signature(traced.forward)
        if len(model_sign.parameters) != len(submod0_sign.parameters):
            # We don't change the signature of the first stage if it takes
            # different number of args than original model
            logger.info(
                f"Original model takes {len(model_sign.parameters)} args but the "  # noqa: G004
                f"first pipeline stage takes {len(submod0_sign.parameters)}. "
                "Please provide args to respective pipeline stages."
            )
        else:
            # Support kwargs for the first stage
            submod0.graph._codegen = copy.deepcopy(traced.graph._codegen)
            # `_replace` is actually not "private" or internal. based on this doc:
            # To prevent conflicts with field names, the method and attribute names
            # start with an underscore
            submod0.graph._codegen.pytree_info = (
                submod0.graph._codegen.pytree_info._replace(out_spec=None)
            )
            submod0.recompile()

        return pipe
|
| 1089 |
+
|
| 1090 |
+
def __str__(self):
|
| 1091 |
+
return self.split_gm.__str__()
|
| 1092 |
+
|
| 1093 |
+
def __repr__(self):
|
| 1094 |
+
return self.split_gm.__repr__()
|
| 1095 |
+
|
| 1096 |
+
def info(self) -> PipeInfo:
|
| 1097 |
+
"""
|
| 1098 |
+
Get information about the pipe.
|
| 1099 |
+
|
| 1100 |
+
Returns
|
| 1101 |
+
-------
|
| 1102 |
+
PipeInfo
|
| 1103 |
+
A dataclass containing information about the pipe.
|
| 1104 |
+
"""
|
| 1105 |
+
return PipeInfo(
|
| 1106 |
+
graph=self.split_gm.graph,
|
| 1107 |
+
num_stages=self.num_stages,
|
| 1108 |
+
has_loss_and_backward=self.has_loss_and_backward,
|
| 1109 |
+
)
|
| 1110 |
+
|
| 1111 |
+
    def build_stage(
        self,
        stage_index: int,
        device: torch.device,
        group: Optional[ProcessGroup] = None,
    ) -> _PipelineStage:
        """
        Create a `PipelineStage` given a stage index and distributed group.
        The `PipelineStage` can run with `PipelineSchedule`s.
        """
        # Find stage module
        stage_module = self.get_stage_module(stage_index)

        # Move ops argument to device
        # Today PT2 tracer does not treat `x.device` as a symbolic device;
        # instead, the device of tracing time got burned into the generated
        # code. Here we provide a workaround for users to manually modify the
        # "device" kwarg of operations. Such operation may include:
        # `torch.ones`, `torch.zeros`, `torch.rand`, etc.
        if isinstance(stage_module, torch.fx.GraphModule):
            _modify_graph_op_device(stage_module, device)
        else:
            logger.warning(
                f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}"  # noqa: G004
            )

        # Detach pipe info
        # Note: be careful what's included in `pipe_info`. We don't want to keep
        # a reference to `Pipe` or `Pipe.split_gm` which stops python from
        # recycling them. When python recycles them, other stage modules (which
        # are irrelevant to current rank) can be automatically freed.
        pipe_info = self.info()
        return _PipelineStage(stage_module, stage_index, pipe_info, device, group)
|
| 1144 |
+
|
| 1145 |
+
|
| 1146 |
+
class SplitPoint(Enum):
    """Where a pipeline split is inserted relative to a submodule's forward."""

    # Split before the submodule's forward runs.
    BEGINNING = 1
    # Split after the submodule's forward returns.
    END = 2
|
| 1149 |
+
|
| 1150 |
+
|
| 1151 |
+
# For backward compatibility, we kept the PipeSplitWrapper class because `class
|
| 1152 |
+
# SplitPoint` used to be defined in this class.
|
| 1153 |
+
class PipeSplitWrapper:
    """Backward-compatibility shim: `SplitPoint` used to be defined here."""

    # Create a class alias for BC so `PipeSplitWrapper.SplitPoint` still works.
    SplitPoint = SplitPoint
|
| 1156 |
+
|
| 1157 |
+
|
| 1158 |
+
def _split_before_forward(self, *args, **kwargs):
    """Replacement `forward` that emits a pipeline split marker, then runs
    the module's original forward."""
    pipe_split()
    return self._orig_forward(*args, **kwargs)
|
| 1161 |
+
|
| 1162 |
+
|
| 1163 |
+
def _split_after_forward(self, *args, **kwargs):
    """Replacement `forward` that runs the module's original forward, then
    emits a pipeline split marker."""
    try:
        return self._orig_forward(*args, **kwargs)
    finally:
        # `finally` guarantees the split marker is emitted even if the
        # original forward raises.
        pipe_split()
|
| 1168 |
+
|
| 1169 |
+
|
| 1170 |
+
def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):
    """Patch, in place, the `forward` of every submodule named in `spec` so
    that a `pipe_split()` marker fires before (BEGINNING) or after (END) it.

    Raises AttributeError if a qualified name in `spec` does not resolve to
    an existing submodule of `mod`.
    """
    # TODO: make this implementation out-of-place?
    for qualname, split_type in spec.items():
        path = qualname.split(".")
        parent = mod
        # Walk down the qualified name to the module owning the target.
        for depth, name in enumerate(path[:-1]):
            try:
                parent = getattr(parent, name)
            except AttributeError as e:
                raise AttributeError(
                    f"Specified target {qualname} referenced "
                    f'nonexistent module {".".join(path[: depth + 1])}'
                ) from e

        target = getattr(parent, path[-1])
        # Stash the original forward so the wrapper can delegate to it.
        target._orig_forward = target.forward
        if split_type == SplitPoint.BEGINNING:
            wrapper = _split_before_forward
        elif split_type == SplitPoint.END:
            wrapper = _split_after_forward
        else:
            raise ValueError("Unknown split point type.")
        target.forward = MethodType(wrapper, target)
|
| 1192 |
+
|
| 1193 |
+
|
| 1194 |
+
def pipeline(
    module: torch.nn.Module,
    mb_args: Tuple[Any, ...],
    mb_kwargs: Optional[Dict[str, Any]] = None,
    split_spec: Optional[Dict[str, SplitPoint]] = None,
    split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
) -> Pipe:
    """
    Split a module based on a specification.

    See `Pipe` for more details.

    Arguments
    ---------
    module:
        The module to be splitted.
    mb_args:
        Example positional inputs, in micro-batch form.
    mb_kwargs:
        Example keyword inputs, in micro-batch form. (default: `None`)
    split_spec:
        A dictionary using submodule names as split marker. (default: `None`)
    split_policy:
        The policy to use for splitting the module. (default: `None`)

    Returns
    -------
    A pipeline representation of class `Pipe`.
    """
    # The two splitting mechanisms are mutually exclusive.
    if split_spec is not None and split_policy is not None:
        raise ValueError(
            "Cannot specify both `split_spec` and `split_policy`. Please use only one of them."
        )

    if split_spec is None:
        # Use split policy
        return Pipe.from_tracing(
            mod=module,
            example_args=mb_args,
            example_kwargs=mb_kwargs,
            split_policy=split_policy,
        )

    # Annotate split points in the module based on user spec
    annotate_split_points(module, split_spec)
    return Pipe.from_tracing(
        mod=module,
        example_args=mb_args,
        example_kwargs=mb_kwargs,
    )
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
| 2 |
+
from ._IR import Pipe, pipe_split, pipeline, SplitPoint
|
| 3 |
+
from .schedules import (
|
| 4 |
+
_ScheduleForwardOnly,
|
| 5 |
+
Schedule1F1B,
|
| 6 |
+
ScheduleFlexibleInterleaved1F1B,
|
| 7 |
+
ScheduleGPipe,
|
| 8 |
+
ScheduleInterleaved1F1B,
|
| 9 |
+
ScheduleInterleavedZeroBubble,
|
| 10 |
+
ScheduleLoopedBFS,
|
| 11 |
+
)
|
| 12 |
+
from .stage import build_stage, PipelineStage
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"Pipe",
|
| 17 |
+
"pipe_split",
|
| 18 |
+
"SplitPoint",
|
| 19 |
+
"pipeline",
|
| 20 |
+
"PipelineStage",
|
| 21 |
+
"build_stage",
|
| 22 |
+
"Schedule1F1B",
|
| 23 |
+
"ScheduleFlexibleInterleaved1F1B",
|
| 24 |
+
"ScheduleGPipe",
|
| 25 |
+
"ScheduleInterleaved1F1B",
|
| 26 |
+
"ScheduleLoopedBFS",
|
| 27 |
+
"ScheduleInterleavedZeroBubble",
|
| 28 |
+
]
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_IR.cpython-311.pyc
ADDED
|
Binary file (54.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (842 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_backward.cpython-311.pyc
ADDED
|
Binary file (16.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_debug.cpython-311.pyc
ADDED
|
Binary file (1.08 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_unflatten.cpython-311.pyc
ADDED
|
Binary file (1.26 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/_utils.cpython-311.pyc
ADDED
|
Binary file (4.77 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/microbatch.cpython-311.pyc
ADDED
|
Binary file (16.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/schedules.cpython-311.pyc
ADDED
|
Binary file (91.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/__pycache__/stage.cpython-311.pyc
ADDED
|
Binary file (66 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_backward.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
| 3 |
+
import collections
|
| 4 |
+
import logging
|
| 5 |
+
import weakref
|
| 6 |
+
from typing import Any, cast, Deque, Dict, Iterator, List, Optional, Set, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.autograd.graph import GradientEdge, Node
|
| 10 |
+
from torch.nn import Parameter
|
| 11 |
+
|
| 12 |
+
from ._debug import map_debug_info
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Union[Node, None]:
|
| 19 |
+
"""
|
| 20 |
+
Get the grad function or grad accumulator for a tensor.
|
| 21 |
+
|
| 22 |
+
Accumulate grad nodes are lazily created, so we need to a
|
| 23 |
+
dummy view in order to trigger its creation.
|
| 24 |
+
"""
|
| 25 |
+
if t.requires_grad and t.grad_fn is None:
|
| 26 |
+
# if no grad function (leaf tensors) we use view
|
| 27 |
+
viewed_t = t.view_as(t)
|
| 28 |
+
grad_fn = viewed_t.grad_fn
|
| 29 |
+
if grad_fn is not None:
|
| 30 |
+
return grad_fn.next_functions[0][0]
|
| 31 |
+
else:
|
| 32 |
+
raise RuntimeError(
|
| 33 |
+
"Attempted to get grad_fn, but got None."
|
| 34 |
+
"Is this being created in a no-grad context?"
|
| 35 |
+
)
|
| 36 |
+
else:
|
| 37 |
+
return t.grad_fn
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def reverse_closure(
    roots: List[Node], target_nodes: Set[Node]
) -> Tuple[Set[Node], Set[Node]]:
    """
    Compute the reverse closure of `roots`: every node reachable by following
    the "reverse_edges" entries stored in node metadata (populated by
    `construct_reverse_graph`). Nodes in `target_nodes` stop the expansion
    and are collected separately.

    Returns `(closure, visited_target_nodes)`.
    """
    # Seed with the non-None roots (duplicates collapse in the set).
    closure: Set[Node] = {root for root in roots if root is not None}
    visited_target_nodes: Set[Node] = set()
    pending: Deque[Node] = collections.deque(closure)
    while pending:
        current = pending.popleft()
        meta = cast(Dict[str, List], current.metadata)
        for holder_ref, _idx in meta.get("reverse_edges", []):
            holder = holder_ref()
            if holder is None:
                # this reverse graph is no longer alive; skip the dead edge
                continue
            nxt = holder.node
            if nxt is None or nxt in closure:
                continue
            if nxt in target_nodes:
                # Record the hit but do not expand past the target frontier.
                visited_target_nodes.add(nxt)
                continue
            closure.add(nxt)
            pending.append(nxt)
    return closure, visited_target_nodes
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Enable weak pointer
|
| 79 |
+
class Holder:
    """Wrapper object holding an autograd `Node` so it can be referenced via
    `weakref` (see `construct_reverse_graph`, which stores weak refs to
    these holders in node metadata)."""

    def __init__(self, node: Node):
        # The wrapped autograd graph node.
        self.node = node
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def construct_reverse_graph(roots: List[Node]) -> List[Holder]:
    """Annotate the autograd graph reachable from `roots` with back-pointers.

    For each forward edge `node -> fn`, a `(weakref(Holder(node)), idx)` pair
    is appended to `fn.metadata["reverse_edges"]`. The strong `Holder`
    references are returned; dropping the returned list invalidates the
    weak back-pointers.
    """
    pending: Deque[Node] = collections.deque()
    seeded: Set[Node] = set()
    holders: List[Holder] = []
    for root in roots:
        if root is None or root in seeded:
            continue
        seeded.add(root)
        pending.append(root)
    while pending:
        node = pending.popleft()
        for fn, idx in node.next_functions:
            if fn is None:
                continue
            # Don't necessarily need to store on the graph
            meta = cast(Dict[str, List], fn.metadata)
            edges = meta.get("reverse_edges", [])
            if not edges:
                # First visit to `fn`: schedule it for traversal.
                pending.append(fn)
            holder = Holder(node)
            holders.append(holder)
            edges.append((weakref.ref(holder), idx))
            meta["reverse_edges"] = edges
    return holders
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def get_param_groups(inputs: List[Node], params: List[Node]) -> List[Dict[str, Any]]:
    """
    Given a list of inputs and a list of parameters, return a list of parameter
    groups, where each group contains the parameters and the intermediates that
    are connected to the parameters.

    Each returned group is a dict with keys:
      - "params": a set of parameter nodes
      - "intermediates": a set of intermediate nodes connecting those params
        to the input closure
    """
    # reverse graph that starts with inputs, and goes up to the dOutput or the loss,
    # but omits weights and any subgraphs connecting weights to this closure
    inputs_closure, _ = reverse_closure(inputs, set())
    param_groups: Dict[Node, Dict[str, Set]] = dict()  # keyed on intermediates
    for i, param in enumerate(params):
        # `intersected` = intermediates where this param's closure meets the
        # input closure.
        closure, intersected = reverse_closure([param], inputs_closure)
        param_group: Dict[str, Set] = {
            "params": {param},
            "intermediates": intersected,
        }
        for input_node in intersected:
            existing = param_groups.get(input_node, None)
            if existing is not None:
                # Merge into the already-registered group. Note the rebinding
                # of `param_group` to `existing` below: subsequent
                # intermediates of this param then map to the merged group,
                # which is how overlapping groups get unified.
                existing["params"] = existing["params"].union(param_group["params"])
                existing["intermediates"] = existing["intermediates"].union(
                    param_group["intermediates"]
                )
                param_group = existing
            else:
                param_groups[input_node] = param_group

    # Sanity check: union of all param_groups params should be equal to all params
    union_params: Set[Node] = set()
    seen_ids: Set[int] = set()
    unique_param_groups = []
    for param_group in param_groups.values():
        # Dedup by identity: several intermediates may key the same merged
        # group dict.
        if id(param_group) not in seen_ids:
            seen_ids.add(id(param_group))
            unique_param_groups.append(param_group)
            union_params = union_params.union(param_group["params"])

    # The assert will only be true if the input tensor requires gradients,
    # otherwise the autograd graph will miss the first layer of inputs
    # assert union_params == set(params)
    return unique_param_groups
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def stage_backward_input(
    stage_outputs: List[torch.Tensor],
    output_grads: Optional[List[torch.Tensor]],
    input_values: List[torch.Tensor],
    weights: Iterator[Parameter],
):
    """
    Compute the gradients for only the stage inputs with respect to the stage
    outputs, deferring weight gradients to `stage_backward_weight`.

    Returns `(dinputs, param_groups)`, where `dinputs` is None when not all
    inputs require grad, and `param_groups` carries the hooks' captured
    intermediate gradients needed later for the weight backward.
    """
    # Resolve the autograd nodes for outputs, inputs and weights, dropping
    # any tensors without a grad fn / accumulator.
    stage_output_grad_fns: List[Node] = list(
        filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs))
    )
    stage_input_grad_fns: List[Node] = list(
        filter(None, map(_get_grad_fn_or_grad_acc, input_values))
    )
    weight_grad_fns: List[Node] = list(
        filter(None, map(_get_grad_fn_or_grad_acc, weights))
    )

    # Back-pointers only need to stay alive while grouping; drop them after.
    reverse_graph_refs = construct_reverse_graph(stage_output_grad_fns)
    param_groups = get_param_groups(stage_input_grad_fns, weight_grad_fns)
    del reverse_graph_refs

    for param_group in param_groups:
        for i, intermediate in enumerate(param_group["intermediates"]):

            def get_hook(param_group, i):
                # Factory binds `param_group` and `i` per iteration
                # (avoids the late-binding closure pitfall).
                def hook(grad_inputs):
                    if param_group.get("grads", None) is None:
                        param_group["grads"] = [None] * len(
                            param_group["intermediates"]
                        )
                    param_group["grads"][i] = grad_inputs

                return hook

            # These are always "split" nodes that we need to recompute, so
            # save their inputs.
            intermediate.register_prehook(get_hook(param_group, i))

    # Stage 0 inputs do not require grads? Should we skip in that case?
    if all(tensor.requires_grad for tensor in input_values):
        if output_grads is None:
            # In case this is the loss and there are no output_grads, then we just use 1s
            output_grads = [
                torch.ones_like(stage_output) for stage_output in stage_outputs
            ]

        # retain_graph=True: the graph is reused by the weight backward pass.
        dinputs = torch.autograd.grad(
            stage_outputs,
            inputs=input_values,
            grad_outputs=output_grads,
            retain_graph=True,
        )

        # update the gradients for inputs
        for i, inp in enumerate(input_values):
            if inp.grad is None:
                inp.grad = dinputs[i]
            else:
                inp.grad += dinputs[i]
    else:
        dinputs = None
    return dinputs, param_groups
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def stage_backward_weight(
    weights: Iterator[Parameter], param_groups: List[Dict[str, Any]]
):
    """
    Compute and accumulate weight gradients from the intermediate gradients
    captured during `stage_backward_input`.

    Returns the weights' `.grad` values as captured BEFORE this accumulation,
    in the original order `weights` were provided in.
    """
    # map weights to param_group_weights
    grad_acc_to_weight = {}
    weight_grads = []
    for index, weight in enumerate(weights):
        grad_acc = _get_grad_fn_or_grad_acc(weight)
        grad_acc_to_weight[grad_acc] = weight, index
        # NOTE(review): this snapshots `.grad` prior to the accumulation
        # below — confirm callers expect pre-update values.
        weight_grads.append(weight.grad)

    for param_group in param_groups:
        # TODO: Handle case where intermediate can have multiple outputs
        intermediate_edges = tuple(
            GradientEdge(i, 0) for i in param_group["intermediates"]
        )
        weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])

        assert all(len(g) == 1 for g in param_group["grads"])
        # [NEW!] Able to pass a GradientEdge to autograd.grad as output
        # We do not need to retain_graph because... guarantee no overlap?
        # print("trying to execute: ", intermediate_edges, weights_edges)
        dweights = torch.autograd.grad(
            intermediate_edges,
            weights_edges,
            grad_outputs=sum(param_group["grads"], tuple()),
        )
        # Accumulate each computed gradient into its owning parameter.
        for grad_acc, dw in zip(param_group["params"], dweights):
            weight, index = grad_acc_to_weight[grad_acc]
            if weight.grad is None:
                weight.grad = dw
            else:
                weight.grad += dw
    # return grads in the original order weights were provided in
    return weight_grads
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def stage_backward(
    stage_output,
    output_grads,
    input_values,
    outputs_with_grads_idxs: Optional[List[int]] = None,  # deprecated, not used
):
    """
    This is a helper function to:
    1. compute the gradients for the stage inputs, and
    2. accumulate gradients for the stage module's parameters.

    Given the input value(s) and the corresponding gradient for the output
    value(s), compute and accumulate gradients for all parameter values (leaves
    in the autograd trace) as well as return a list of the gradients for the
    input values
    """
    if outputs_with_grads_idxs is not None:
        # Deprecated, not used in runtime calls, only exists in compiler
        stage_output = [stage_output[i] for i in outputs_with_grads_idxs]
        output_grads = [output_grads[i] for i in outputs_with_grads_idxs]

    try:
        # stage_output may be a composite datatype like dict. Extract all individual
        # tensor values here
        stage_output_tensors = []
        output_grad_tensors = []

        def extract_tensors_with_grads(output_val, grad_val):
            # Recursively walk matching (output, grad) structures, collecting
            # grad-requiring tensors and their grads into the two lists above.
            if isinstance(output_val, torch.Tensor):
                if not output_val.requires_grad and output_val.grad_fn is None:
                    return
                assert isinstance(
                    grad_val, (torch.Tensor, type(None))
                ), f"Expected Tensor or None gradient but got {type(grad_val)}"
                stage_output_tensors.append(output_val)
                output_grad_tensors.append(grad_val)
            elif isinstance(output_val, (tuple, list)):
                if grad_val is None:
                    return
                assert isinstance(
                    grad_val, (tuple, list)
                ), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
                assert len(output_val) == len(grad_val)
                for ov, gv in zip(output_val, grad_val):
                    extract_tensors_with_grads(ov, gv)
            elif isinstance(output_val, dict):
                if grad_val is None:
                    return
                assert isinstance(grad_val, dict)
                assert set(output_val.keys()) == set(grad_val.keys())
                for k in output_val.keys():
                    extract_tensors_with_grads(output_val[k], grad_val[k])
            else:
                # Output is a non-tensor type; just ignore it
                pass

        extract_tensors_with_grads(stage_output, output_grads)

        # `backward` (vs `grad`) accumulates into parameters' `.grad`.
        torch.autograd.backward(
            stage_output_tensors, grad_tensors=output_grad_tensors  # type: ignore[arg-type]
        )

        # Extract gradients wrt the input values
        grad_inputs = []
        for val in input_values:
            if isinstance(val, torch.Tensor):
                grad_inputs.append(val.grad)
            else:
                grad_inputs.append(None)

        # Alternative impl: `torch.autograd.grad`.
        # Note that `torch.autograd.grad` will not accumulate gradients into the
        # model's parameters.
        """
        inputs_with_grad = []
        for val in input_values:
            if isinstance(val, torch.Tensor) and val.requires_grad:
                inputs_with_grad.append(val)

        grad_inputs = torch.autograd.grad(
            stage_output_tensors, inputs_with_grad, output_grad_tensors,  # type: ignore[arg-type]
        )
        """

    except Exception as e:
        exc_msg = f"""
        Failed to run stage backward:
        Stage output: {map_debug_info(stage_output)}
        Output gradient: {map_debug_info(output_grads)}
        Input: {map_debug_info(input_values)}
        """
        raise RuntimeError(exc_msg) from e

    return grad_inputs
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
# TODO: handling requires_grad=False dynamically. Can we analyze this during initial
|
| 359 |
+
# IR emission?
|
| 360 |
+
def _null_coalesce_accumulate(lhs, rhs):
|
| 361 |
+
"""
|
| 362 |
+
Coalesce two values, even if one of them is null, returning the non-null
|
| 363 |
+
value.
|
| 364 |
+
"""
|
| 365 |
+
if lhs is None:
|
| 366 |
+
return rhs
|
| 367 |
+
elif rhs is None:
|
| 368 |
+
return lhs
|
| 369 |
+
else:
|
| 370 |
+
return torch.add(lhs, rhs)
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_debug.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def friendly_debug_info(v):
    """
    Render a value for debug output: tensors become a compact summary of
    shape / requires_grad / dtype, anything else is passed through `str`.
    """
    if not isinstance(v, torch.Tensor):
        return str(v)
    return f"Tensor({v.shape}, grad={v.requires_grad}, dtype={v.dtype})"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def map_debug_info(a):
    """
    Helper function to apply `friendly_debug_info` to items in `a`.
    `a` may be a list, tuple, or dict (any aggregate `fx.node.map_aggregate`
    understands); scalars are passed through directly.
    """
    return torch.fx.node.map_aggregate(a, friendly_debug_info)
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_unflatten.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.export.unflatten import _ModuleFrame
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _outline_submodules(orig_graph: torch.fx.Graph):
    """
    Rebuild `orig_graph` into a fresh `GraphModule` with submodule structure
    outlined via `torch.export.unflatten._ModuleFrame`.
    """
    # Create an empty GraphModule to hold the outlined modules
    new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
    seen_nodes: Dict[str, torch.fx.Node] = {}
    seen_modules: Dict[int, torch.nn.Module] = {}
    # NOTE(review): the positional args below follow the private
    # `_ModuleFrame` signature (graph, nodes, seen_nodes, seen_modules,
    # parent, module_stack, module_id, ...) — verify against the installed
    # torch.export.unflatten version before changing.
    _ModuleFrame(
        orig_graph,
        tuple(orig_graph.nodes),
        seen_nodes,
        seen_modules,
        None,
        [""],
        "",
        {},
        module=new_module,
    ).run_outer()
    # Validate and regenerate code for the rebuilt graph.
    new_module.graph.lint()
    new_module.recompile()
    return new_module
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/_utils.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
| 3 |
+
import logging
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import List, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import fx
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def flatten_args_detach(args):
    """
    Flatten the args into a list form and detach the tensors from the
    computational graph.

    Returns `(new_args, flat_detached_args)`: `new_args` mirrors the
    structure of `args` with tensors replaced by detached copies, and
    `flat_detached_args` lists every leaf in traversal order.
    """
    collected = []

    def _detach_and_record(leaf):
        if isinstance(leaf, torch.Tensor):
            # Detach but preserve the requires_grad flag so the copy can
            # start a fresh autograd graph.
            leaf = leaf.detach().requires_grad_(leaf.requires_grad)
        collected.append(leaf)
        return leaf

    remapped = fx.node.map_aggregate(
        args,
        _detach_and_record,
    )

    return remapped, collected
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def flatten_args(args):
    """
    Flatten the args into a list form (leaves in traversal order), leaving
    `args` itself unmodified.
    """
    leaves = []

    def _record(leaf):
        leaves.append(leaf)
        return leaf

    # map_aggregate walks tuples/lists/dicts; we only use its traversal
    # side effect and discard the rebuilt structure.
    fx.node.map_aggregate(
        args,
        _record,
    )

    return leaves
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class PipeliningShapeError(RuntimeError):
    """Shape (or dtype/stride) mismatch between configured and runtime values."""
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def validate_tensor_metadata(desc, expected, given):
    """
    Check that `given` matches `expected` in shape, dtype, and stride.

    Raises `PipeliningShapeError` (prefixed with `desc`) on the first
    mismatch; returns None when all three attributes agree.
    """
    # `!=` replaces the original `not ... == ...` — equivalent and clearer.
    if expected.shape != given.shape:
        raise PipeliningShapeError(
            f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}"
        )
    if expected.dtype != given.dtype:
        raise PipeliningShapeError(
            f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}"
        )
    if expected.stride() != given.stride():
        raise PipeliningShapeError(
            f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}"
        )
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def validate_tensors_metadata(
    desc,
    expected_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]],
    actual_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]],
):
    """
    Validate that two tensor sequences match in length and pairwise in
    shape, dtype, and stride. Raises `PipeliningShapeError` on mismatch.
    """
    if len(expected_tensors) != len(actual_tensors):
        raise PipeliningShapeError(
            f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})"
        )
    # enumerate + zip instead of index arithmetic; lengths already checked.
    for i, (expected, actual) in enumerate(zip(expected_tensors, actual_tensors)):
        validate_tensor_metadata(f"{desc}: value {i}", expected, actual)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@dataclass
class PipeInfo:
    """
    Captures information for a pipeline (`Pipe` object), detached from the
    `Pipe` itself so holding it does not keep the full pipe alive.
    """

    # The fx graph of the split pipeline.
    graph: fx.Graph
    # Number of pipeline stages.
    num_stages: int
    # Whether the pipe contains loss computation and a backward pass.
    has_loss_and_backward: bool
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/microbatch.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch.fx.node import map_aggregate
|
| 8 |
+
from torch.utils._pytree import tree_flatten, tree_unflatten
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"TensorChunkSpec",
|
| 13 |
+
"split_args_kwargs_into_chunks",
|
| 14 |
+
"merge_chunks",
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
"""
|
| 20 |
+
_debug_mask_minibatches specifies to send masked versions of the mini-batch
|
| 21 |
+
through instead of micro-batch slices--this can be used for more stable
|
| 22 |
+
numerical testing (see [A Note About Correctness Testing])
|
| 23 |
+
"""
|
| 24 |
+
_debug_mask_minibatches = False
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class _CustomReducer:
|
| 28 |
+
"""
|
| 29 |
+
Custom reducer class that can be used to specify a custom operation that
|
| 30 |
+
reduces losses of multiple microbatches into one value.
|
| 31 |
+
|
| 32 |
+
Example:
|
| 33 |
+
>>> # xdoctest: +SKIP
|
| 34 |
+
>>> sum_reducer = _CustomReducer(
|
| 35 |
+
>>> torch.tensor(0.0),
|
| 36 |
+
>>> lambda a, b: a + b
|
| 37 |
+
>>> )
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(self, init_value, reduce_fn):
|
| 41 |
+
self.init_value = init_value
|
| 42 |
+
self.reduce_fn = reduce_fn
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class _LossReducer(_CustomReducer):
|
| 46 |
+
pass
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
sum_reducer = _LossReducer(torch.tensor(0.0), lambda a, b: a + b)
|
| 50 |
+
|
| 51 |
+
# Default chunking dimension is 0. This is used for the case where the user did
|
| 52 |
+
# not specify a chunking dimension.
|
| 53 |
+
DEFAULT_CHUNK_DIM = 0
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class TensorChunkSpec:
|
| 57 |
+
"""
|
| 58 |
+
Class used to specify chunking of inputs
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
def __init__(self, split_dim):
|
| 62 |
+
self.split_dim = split_dim
|
| 63 |
+
|
| 64 |
+
split_dim: int
|
| 65 |
+
|
| 66 |
+
def __repr__(self):
|
| 67 |
+
return (
|
| 68 |
+
f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})"
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
def __str__(self):
|
| 72 |
+
return f"TensorChunkSpec({self.split_dim})"
|
| 73 |
+
|
| 74 |
+
@staticmethod
|
| 75 |
+
def from_tuple(
|
| 76 |
+
chunk_dims: Tuple[int, ...],
|
| 77 |
+
):
|
| 78 |
+
"""
|
| 79 |
+
A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk
|
| 80 |
+
dimensions (int's).
|
| 81 |
+
Example:
|
| 82 |
+
>>> # xdoctest: +SKIP
|
| 83 |
+
>>> # There are three positional arguments to the model, and
|
| 84 |
+
>>> # we are chunking them along dimension 0, 0 and 1, respectively
|
| 85 |
+
>>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
|
| 86 |
+
"""
|
| 87 |
+
args_chunk_spec = map_aggregate(
|
| 88 |
+
chunk_dims,
|
| 89 |
+
lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value]
|
| 90 |
+
)
|
| 91 |
+
return args_chunk_spec
|
| 92 |
+
|
| 93 |
+
@staticmethod
|
| 94 |
+
def from_dict(
|
| 95 |
+
chunk_dims: Dict[str, int],
|
| 96 |
+
):
|
| 97 |
+
"""
|
| 98 |
+
A helper for creating a dictionary of `TensorChunkSpec` from a
|
| 99 |
+
dictionary of chunk dimensions (int's).
|
| 100 |
+
Example:
|
| 101 |
+
>>> # xdoctest: +SKIP
|
| 102 |
+
>>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
|
| 103 |
+
>>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
|
| 104 |
+
"""
|
| 105 |
+
kwargs_chunk_spec = map_aggregate(
|
| 106 |
+
chunk_dims,
|
| 107 |
+
lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value]
|
| 108 |
+
)
|
| 109 |
+
return kwargs_chunk_spec
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# Class used to specify replication of inputs
|
| 113 |
+
class _Replicate:
|
| 114 |
+
pass
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _shard_dict_of_args(
|
| 118 |
+
args_dict,
|
| 119 |
+
args_chunk_spec,
|
| 120 |
+
num_chunks,
|
| 121 |
+
):
|
| 122 |
+
"""
|
| 123 |
+
Given a dictionary of args, and a dictionary of chunking specs, shard the
|
| 124 |
+
args according to the chunking specs.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
args_dict: Dictionary of args
|
| 128 |
+
args_chunk_spec: Dictionary of chunking specs
|
| 129 |
+
num_chunks: Number of chunks to shard the args into
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
args_split: List of sharded args
|
| 133 |
+
"""
|
| 134 |
+
# Stage 1+2: flatten and shard/replicate
|
| 135 |
+
|
| 136 |
+
# args_sharded_replicated : [num args, num flat values, num chunks]
|
| 137 |
+
args_sharded_replicated = {}
|
| 138 |
+
arg_specs = []
|
| 139 |
+
|
| 140 |
+
real_num_chunks = num_chunks
|
| 141 |
+
first_tensor = True
|
| 142 |
+
|
| 143 |
+
assert len(args_dict) == len(
|
| 144 |
+
args_chunk_spec
|
| 145 |
+
), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
|
| 146 |
+
|
| 147 |
+
for arg_key, arg in args_dict.items():
|
| 148 |
+
flat, spec = tree_flatten(arg)
|
| 149 |
+
arg_specs.append(spec)
|
| 150 |
+
|
| 151 |
+
chunk_spec = args_chunk_spec[arg_key]
|
| 152 |
+
assert chunk_spec is not None # Should have been set by caller
|
| 153 |
+
chunk_spec_flat, _ = tree_flatten(chunk_spec)
|
| 154 |
+
if len(flat) != len(chunk_spec_flat):
|
| 155 |
+
raise ValueError(
|
| 156 |
+
f"Argument value {arg} did not have the same number of "
|
| 157 |
+
f"values as as chunk spec {chunk_spec}"
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
sharded_arg_flat = []
|
| 161 |
+
|
| 162 |
+
for v, chunk_v in zip(flat, chunk_spec_flat):
|
| 163 |
+
if chunk_v is _Replicate or not isinstance(v, torch.Tensor):
|
| 164 |
+
sharded_arg_flat.append([v] * real_num_chunks)
|
| 165 |
+
elif isinstance(chunk_v, TensorChunkSpec):
|
| 166 |
+
# TODO: check type of v. If it's a tensor, use chunk (or debug mask).
|
| 167 |
+
# If it's a collection type, split it as you would expect. Otherwise,
|
| 168 |
+
# Throw an error
|
| 169 |
+
assert isinstance(v, torch.Tensor), f"{v} is not a tensor"
|
| 170 |
+
|
| 171 |
+
v_split_dim_size = v.size(chunk_v.split_dim)
|
| 172 |
+
if v_split_dim_size < real_num_chunks:
|
| 173 |
+
if first_tensor:
|
| 174 |
+
# We can only adjust number of chunks when we hit this
|
| 175 |
+
# issue at the first tensor encountered
|
| 176 |
+
logger.warning(
|
| 177 |
+
f"Tensor size on chunking dimension is {v_split_dim_size}, " # noqa: G004
|
| 178 |
+
f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}."
|
| 179 |
+
)
|
| 180 |
+
real_num_chunks = v_split_dim_size
|
| 181 |
+
else:
|
| 182 |
+
raise RuntimeError(
|
| 183 |
+
f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, "
|
| 184 |
+
f"smaller than the number of chunks {num_chunks}. "
|
| 185 |
+
"PiPPy cannot reduce the number of chunks because "
|
| 186 |
+
"other arguments have bigger chunk-dimension sizes. "
|
| 187 |
+
"Please adjust your num_chunks setting."
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
chunk_tensors = torch.tensor_split(
|
| 191 |
+
v, real_num_chunks, chunk_v.split_dim
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
if _debug_mask_minibatches:
|
| 195 |
+
expanded_chunks = []
|
| 196 |
+
|
| 197 |
+
split_dim_idx = 0
|
| 198 |
+
for chunk_tensor in chunk_tensors:
|
| 199 |
+
new_val = torch.zeros_like(v)
|
| 200 |
+
upper_idx = split_dim_idx + chunk_tensor.size(chunk_v.split_dim)
|
| 201 |
+
|
| 202 |
+
slice_indices = [slice(None, None, None)] * new_val.ndim
|
| 203 |
+
slice_indices[chunk_v.split_dim] = slice(
|
| 204 |
+
split_dim_idx, upper_idx
|
| 205 |
+
)
|
| 206 |
+
new_val[slice_indices] = chunk_tensor
|
| 207 |
+
|
| 208 |
+
expanded_chunks.append(new_val)
|
| 209 |
+
|
| 210 |
+
split_dim_idx += chunk_tensor.size(chunk_v.split_dim)
|
| 211 |
+
|
| 212 |
+
sharded_arg_flat.append(expanded_chunks)
|
| 213 |
+
else:
|
| 214 |
+
sharded_arg_flat.append(chunk_tensors) # type: ignore[arg-type]
|
| 215 |
+
|
| 216 |
+
first_tensor = False
|
| 217 |
+
else:
|
| 218 |
+
raise TypeError(f"Unrecognized chunk spec: {chunk_v}")
|
| 219 |
+
|
| 220 |
+
args_sharded_replicated[arg_key] = sharded_arg_flat
|
| 221 |
+
|
| 222 |
+
# chunks_flat : [num chunks, num args, num flat values]
|
| 223 |
+
chunks_flat = []
|
| 224 |
+
for chunk_idx in range(real_num_chunks):
|
| 225 |
+
chunk_args = {}
|
| 226 |
+
for key, arg in args_sharded_replicated.items():
|
| 227 |
+
arg_single_chunk = []
|
| 228 |
+
for v_flat in arg:
|
| 229 |
+
arg_single_chunk.append(v_flat[chunk_idx])
|
| 230 |
+
chunk_args[key] = arg_single_chunk
|
| 231 |
+
chunks_flat.append(chunk_args)
|
| 232 |
+
|
| 233 |
+
# args_split : [num chunks, num args]
|
| 234 |
+
args_split = []
|
| 235 |
+
|
| 236 |
+
for chunk in chunks_flat:
|
| 237 |
+
per_chunk_args = {}
|
| 238 |
+
assert len(arg_specs) == len(chunk)
|
| 239 |
+
for (key, arg), arg_spec in zip(chunk.items(), arg_specs):
|
| 240 |
+
per_chunk_args[key] = tree_unflatten(arg, arg_spec)
|
| 241 |
+
args_split.append(per_chunk_args)
|
| 242 |
+
|
| 243 |
+
return args_split
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def split_args_kwargs_into_chunks(
|
| 247 |
+
args: Tuple[Any, ...],
|
| 248 |
+
kwargs: Optional[Dict[str, Any]],
|
| 249 |
+
chunks: int,
|
| 250 |
+
args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
|
| 251 |
+
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
|
| 252 |
+
) -> Tuple[List[Tuple], List[Dict]]:
|
| 253 |
+
"""
|
| 254 |
+
Given a sequence of args and kwargs, split them into a number of chunks
|
| 255 |
+
according to their respective chunking specs.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
args: Tuple of args
|
| 259 |
+
kwargs: Dict of kwargs
|
| 260 |
+
chunks: Number of chunks to split the args and kwargs into
|
| 261 |
+
args_chunk_spec: chunking specs for args, in same shape as args
|
| 262 |
+
kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs
|
| 263 |
+
|
| 264 |
+
Returns:
|
| 265 |
+
args_split: List of sharded args
|
| 266 |
+
kwargs_split: List of sharded kwargs
|
| 267 |
+
"""
|
| 268 |
+
# Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that
|
| 269 |
+
# the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec`
|
| 270 |
+
# and `kwargs_chunk_spec` specifications. The steps are as follows:
|
| 271 |
+
#
|
| 272 |
+
# 1. Use pytree.tree_flatten to flatten each arg and its spec into nto a 1d array of values.
|
| 273 |
+
# To use a running example: suppose our inputs look like
|
| 274 |
+
#
|
| 275 |
+
# args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None)
|
| 276 |
+
# (kwargs not shown but it's a similar process)
|
| 277 |
+
#
|
| 278 |
+
# Then for this step we would end up with
|
| 279 |
+
#
|
| 280 |
+
# args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None)
|
| 281 |
+
#
|
| 282 |
+
# 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2
|
| 283 |
+
#
|
| 284 |
+
# args = ([[A, A], [B, B], [C_1, C_2]], [D, D])
|
| 285 |
+
#
|
| 286 |
+
# 3. Rotate the nesting order such that chunks are the outer dimension
|
| 287 |
+
#
|
| 288 |
+
# args_chunks = [
|
| 289 |
+
# ([A, B, C_1], D),
|
| 290 |
+
# ([A, B, C_2], D),
|
| 291 |
+
# ]
|
| 292 |
+
#
|
| 293 |
+
# 4. Unflatten each chunk according to the spec
|
| 294 |
+
#
|
| 295 |
+
# args_chunks = [
|
| 296 |
+
# ([A, [B, C_1]], D),
|
| 297 |
+
# ([A, [B, C_2]], D),
|
| 298 |
+
# ]
|
| 299 |
+
|
| 300 |
+
# TODO: _debug_mask_minibatches
|
| 301 |
+
# Handle the case where kwargs is None
|
| 302 |
+
if kwargs is None:
|
| 303 |
+
kwargs = {}
|
| 304 |
+
|
| 305 |
+
# If user did not provide args_chunk_spec or kwargs_chunk_spec, we extend
|
| 306 |
+
# their format and use default chunking along dim 0
|
| 307 |
+
if args_chunk_spec is None:
|
| 308 |
+
args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)
|
| 309 |
+
|
| 310 |
+
if kwargs_chunk_spec is None:
|
| 311 |
+
kwargs_chunk_spec = dict.fromkeys(kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM))
|
| 312 |
+
|
| 313 |
+
args_split_dict = _shard_dict_of_args(
|
| 314 |
+
dict(enumerate(args)),
|
| 315 |
+
dict(enumerate(args_chunk_spec)),
|
| 316 |
+
chunks,
|
| 317 |
+
)
|
| 318 |
+
real_num_chunks = len(args_split_dict)
|
| 319 |
+
|
| 320 |
+
kwargs_split = _shard_dict_of_args(
|
| 321 |
+
kwargs,
|
| 322 |
+
kwargs_chunk_spec,
|
| 323 |
+
real_num_chunks,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
if len(kwargs_split) < real_num_chunks:
|
| 327 |
+
# In case kwargs are sharded into less chunks
|
| 328 |
+
# e.g. when `args` has no tensor, just values
|
| 329 |
+
real_num_chunks = len(kwargs_split)
|
| 330 |
+
# Re-shard args
|
| 331 |
+
args_split_dict = _shard_dict_of_args(
|
| 332 |
+
dict(enumerate(args)),
|
| 333 |
+
dict(enumerate(args_chunk_spec)),
|
| 334 |
+
real_num_chunks,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
if len(args_split_dict) != len(kwargs_split):
|
| 338 |
+
raise RuntimeError(
|
| 339 |
+
"args and kwargs are split into different number of chunks: "
|
| 340 |
+
f"{len(args_split_dict)}, {len(kwargs_split)}"
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
args_split = []
|
| 344 |
+
for chunk_args in args_split_dict:
|
| 345 |
+
args_split.append(tuple(chunk_args[i] for i in range(len(chunk_args))))
|
| 346 |
+
|
| 347 |
+
return args_split, kwargs_split
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def merge_chunks(
|
| 351 |
+
chunks: List[Any],
|
| 352 |
+
chunk_spec,
|
| 353 |
+
):
|
| 354 |
+
"""
|
| 355 |
+
Given a list of chunks, merge them into a single value according to
|
| 356 |
+
the chunk spec.
|
| 357 |
+
|
| 358 |
+
Args:
|
| 359 |
+
chunks: list of chunks
|
| 360 |
+
chunk_spec: Chunking spec for the chunks
|
| 361 |
+
|
| 362 |
+
Returns:
|
| 363 |
+
value: Merged value
|
| 364 |
+
"""
|
| 365 |
+
# This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
|
| 366 |
+
# steps are similar to the steps in that function but in reverse. Given the
|
| 367 |
+
# input values:
|
| 368 |
+
#
|
| 369 |
+
# chunks = [
|
| 370 |
+
# ([A, [B, C_1]], D),
|
| 371 |
+
# ([A, [B, C_2]], D),
|
| 372 |
+
# ]
|
| 373 |
+
# args_spec = ([None, [None, TensorChunkSpec]], None)
|
| 374 |
+
#
|
| 375 |
+
# 1. Flatten the chunks according to the chunk_spec
|
| 376 |
+
#
|
| 377 |
+
# chunks_flat = [
|
| 378 |
+
# ([A, B, C_1], D),
|
| 379 |
+
# ([A, B, C_2], D),
|
| 380 |
+
# ]
|
| 381 |
+
#
|
| 382 |
+
# 2. Rotate the nesting order such that chunks are the inner dimension
|
| 383 |
+
#
|
| 384 |
+
# value_inner = ([A, B, [C_1, C_2]], D)
|
| 385 |
+
#
|
| 386 |
+
# 3. Concatenate sharded arguments
|
| 387 |
+
#
|
| 388 |
+
# value_combined = ([A, B, C], D)
|
| 389 |
+
#
|
| 390 |
+
# 4. Unflatten the combined args given the spec
|
| 391 |
+
#
|
| 392 |
+
# value = ([A, [B, C]], D)
|
| 393 |
+
|
| 394 |
+
# Preliminary: flatten the chunk spec
|
| 395 |
+
if chunk_spec is not None:
|
| 396 |
+
spec_flattened, flatten_spec = tree_flatten(chunk_spec)
|
| 397 |
+
else:
|
| 398 |
+
# If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields
|
| 399 |
+
# We obtain the output structure by flattening chunk 0 and generate the chunk_spec
|
| 400 |
+
chunk0_flat, flatten_spec = tree_flatten(chunks[0])
|
| 401 |
+
spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)
|
| 402 |
+
|
| 403 |
+
# Stage 1: flatten chunks
|
| 404 |
+
# chunks_flattened : [num chunks, num args]
|
| 405 |
+
chunks_flattened = []
|
| 406 |
+
|
| 407 |
+
for chunk in chunks:
|
| 408 |
+
chunk_flattened, _ = tree_flatten(chunk)
|
| 409 |
+
if len(chunk_flattened) != len(spec_flattened):
|
| 410 |
+
raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")
|
| 411 |
+
|
| 412 |
+
chunks_flattened.append(chunk_flattened)
|
| 413 |
+
|
| 414 |
+
# Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and
|
| 415 |
+
# concatenate sharded operands
|
| 416 |
+
# args_flattened : [num args]
|
| 417 |
+
args_flattened = []
|
| 418 |
+
for arg_idx, arg in enumerate(spec_flattened):
|
| 419 |
+
if isinstance(arg, TensorChunkSpec):
|
| 420 |
+
partial_values = [
|
| 421 |
+
chunks_flattened[chunk_idx][arg_idx]
|
| 422 |
+
for chunk_idx in range(len(chunks_flattened))
|
| 423 |
+
]
|
| 424 |
+
|
| 425 |
+
if _debug_mask_minibatches:
|
| 426 |
+
# Infer size of individual chunks by running `tensor_split` again
|
| 427 |
+
overall_shape = partial_values[0].shape
|
| 428 |
+
for val in partial_values[1:]:
|
| 429 |
+
assert val.shape == overall_shape
|
| 430 |
+
meta_chunks = torch.tensor_split(
|
| 431 |
+
torch.empty(*overall_shape, device="meta"),
|
| 432 |
+
sections=len(partial_values),
|
| 433 |
+
dim=arg.split_dim,
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
values_to_cat = []
|
| 437 |
+
chunk_start_idx = 0
|
| 438 |
+
assert len(partial_values) == len(meta_chunks)
|
| 439 |
+
for partial_value, meta_chunk in zip(partial_values, meta_chunks):
|
| 440 |
+
chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)
|
| 441 |
+
|
| 442 |
+
slice_indices = [slice(None, None, None)] * partial_value.ndim
|
| 443 |
+
slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx)
|
| 444 |
+
sliced = partial_value[slice_indices]
|
| 445 |
+
values_to_cat.append(sliced)
|
| 446 |
+
|
| 447 |
+
chunk_start_idx = chunk_end_idx
|
| 448 |
+
|
| 449 |
+
else:
|
| 450 |
+
values_to_cat = partial_values
|
| 451 |
+
|
| 452 |
+
args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
|
| 453 |
+
elif isinstance(arg, _CustomReducer):
|
| 454 |
+
reduced_val = arg.init_value
|
| 455 |
+
|
| 456 |
+
for chunk_idx in range(len(chunks_flattened)):
|
| 457 |
+
reduced_val = arg.reduce_fn(
|
| 458 |
+
reduced_val, chunks_flattened[chunk_idx][arg_idx]
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
args_flattened.append(reduced_val)
|
| 462 |
+
else:
|
| 463 |
+
value = chunks_flattened[0][arg_idx]
|
| 464 |
+
for chunk_idx in range(1, len(chunks_flattened)):
|
| 465 |
+
assert chunks_flattened[chunk_idx][arg_idx] == value
|
| 466 |
+
args_flattened.append(value)
|
| 467 |
+
|
| 468 |
+
# Stage 4: Unflatten combined args
|
| 469 |
+
return tree_unflatten(args_flattened, flatten_spec)
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/schedules.py
ADDED
|
@@ -0,0 +1,2162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
| 3 |
+
|
| 4 |
+
import csv
|
| 5 |
+
import itertools
|
| 6 |
+
import logging
|
| 7 |
+
import re
|
| 8 |
+
from abc import ABC, abstractmethod
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
from enum import Enum
|
| 11 |
+
from typing import (
|
| 12 |
+
Any,
|
| 13 |
+
Callable,
|
| 14 |
+
Dict,
|
| 15 |
+
List,
|
| 16 |
+
NamedTuple,
|
| 17 |
+
Optional,
|
| 18 |
+
Set,
|
| 19 |
+
Tuple,
|
| 20 |
+
TYPE_CHECKING,
|
| 21 |
+
Union,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.distributed as dist
|
| 26 |
+
from torch.distributed._composable.fsdp.fully_shard import FSDPModule, UnshardHandle
|
| 27 |
+
from torch.profiler import record_function
|
| 28 |
+
|
| 29 |
+
from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
|
| 30 |
+
from .stage import _PipelineStageBase
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if TYPE_CHECKING:
|
| 34 |
+
from torch.distributed import Work
|
| 35 |
+
|
| 36 |
+
__all__ = [
|
| 37 |
+
"get_schedule_class",
|
| 38 |
+
"PipelineScheduleSingle",
|
| 39 |
+
"PipelineScheduleMulti",
|
| 40 |
+
"Schedule1F1B",
|
| 41 |
+
"ScheduleFlexibleInterleaved1F1B",
|
| 42 |
+
"ScheduleGPipe",
|
| 43 |
+
"ScheduleInterleaved1F1B",
|
| 44 |
+
"ScheduleLoopedBFS",
|
| 45 |
+
"ScheduleInterleavedZeroBubble",
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
logger = logging.getLogger(__name__)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class _ComputationType(Enum):
|
| 52 |
+
# TODO(whc) rename to _ActType?
|
| 53 |
+
FORWARD = 1
|
| 54 |
+
BACKWARD = 2
|
| 55 |
+
WEIGHT = 3
|
| 56 |
+
UNSHARD = 4
|
| 57 |
+
RESHARD = 5
|
| 58 |
+
SEND_F = 6
|
| 59 |
+
RECV_F = 7
|
| 60 |
+
SEND_B = 8
|
| 61 |
+
RECV_B = 9
|
| 62 |
+
|
| 63 |
+
def __str__(self):
|
| 64 |
+
str_map = {
|
| 65 |
+
_ComputationType.FORWARD: "F",
|
| 66 |
+
_ComputationType.BACKWARD: "B",
|
| 67 |
+
_ComputationType.WEIGHT: "W",
|
| 68 |
+
_ComputationType.UNSHARD: "UNSHARD",
|
| 69 |
+
_ComputationType.RESHARD: "RESHARD",
|
| 70 |
+
_ComputationType.SEND_F: "SEND_F",
|
| 71 |
+
_ComputationType.RECV_F: "RECV_F",
|
| 72 |
+
_ComputationType.SEND_B: "SEND_B",
|
| 73 |
+
_ComputationType.RECV_B: "RECV_B",
|
| 74 |
+
}
|
| 75 |
+
return str_map[self]
|
| 76 |
+
|
| 77 |
+
@staticmethod
|
| 78 |
+
def from_str(action):
|
| 79 |
+
if action == "F":
|
| 80 |
+
return _ComputationType.FORWARD
|
| 81 |
+
elif action == "B":
|
| 82 |
+
return _ComputationType.BACKWARD
|
| 83 |
+
elif action == "W":
|
| 84 |
+
return _ComputationType.WEIGHT
|
| 85 |
+
elif action == "UNSHARD":
|
| 86 |
+
return _ComputationType.UNSHARD
|
| 87 |
+
elif action == "RESHARD":
|
| 88 |
+
return _ComputationType.RESHARD
|
| 89 |
+
elif action == "SEND_F":
|
| 90 |
+
return _ComputationType.SEND_F
|
| 91 |
+
elif action == "RECV_F":
|
| 92 |
+
return _ComputationType.RECV_F
|
| 93 |
+
elif action == "SEND_B":
|
| 94 |
+
return _ComputationType.SEND_B
|
| 95 |
+
elif action == "RECV_B":
|
| 96 |
+
return _ComputationType.RECV_B
|
| 97 |
+
else:
|
| 98 |
+
raise RuntimeError(f"Invalid computation type {action}")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Module-level aliases so schedule code can name action types without the
# `_ComputationType.` prefix.
FORWARD = _ComputationType.FORWARD
BACKWARD = _ComputationType.BACKWARD
WEIGHT = _ComputationType.WEIGHT
UNSHARD = _ComputationType.UNSHARD
RESHARD = _ComputationType.RESHARD
SEND_F = _ComputationType.SEND_F
RECV_F = _ComputationType.RECV_F
SEND_B = _ComputationType.SEND_B
RECV_B = _ComputationType.RECV_B

# Convenience shorthand for compute actions only since they are used in 'simple schedule format'
F = FORWARD
B = BACKWARD
W = WEIGHT
|
| 115 |
+
|
| 116 |
+
# Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index)
|
| 117 |
+
_action_regex = re.compile(
|
| 118 |
+
r"(\d+)([F,B,W]|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B{0,1})(\d*)"
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class _Action(NamedTuple):
    # One cell of a pipeline schedule: which stage does what to which microbatch.
    stage_index: int
    computation_type: _ComputationType
    microbatch_index: Optional[int] = None

    def __repr__(self):
        """Compact form '<stage><action>[<microbatch>]', e.g. '2F0', '1UNSHARD'."""
        parts = [str(self.stage_index), str(self.computation_type)]
        if self.microbatch_index is not None:
            parts.append(str(self.microbatch_index))
        return "".join(parts)

    @staticmethod
    def from_str(str):
        """
        Reverse of __repr__

        String should be formatted as [stage][action type][(microbatch)]
        e.g. `2F0`, `1UNSHARD`, `3SEND_F1`
        """
        match = _action_regex.match(str)
        if match is not None:
            stage_str, type_str, mb_str = match.groups()
            # Microbatch index is optional (comm/FSDP actions may omit it).
            return _Action(
                int(stage_str),
                _ComputationType.from_str(type_str),
                int(mb_str) if len(mb_str) else None,
            )
        if str == "" or str.isspace():
            # Blank cells in a schedule grid mean "no action this timestep".
            return None
        raise RuntimeError(
            f"Invalid action string: {str}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0"
        )
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -> str:
    """
    Formats the pipeline order in a timestep (row) x rank (column) grid of actions
    and returns the formatted string
    """
    # The rank with the longest action list determines the number of rows.
    num_steps = max(len(actions) for actions in pipeline_order.values())
    label_width = len(str(num_steps - 1))
    step_labels = [f"Step {str(i).zfill(label_width)}" for i in range(num_steps)]
    # One action list per rank, in ascending rank order.
    rank_actions = [
        pipeline_order.get(rank, [""] * num_steps) for rank in sorted(pipeline_order)
    ]
    # Rows become timesteps, columns become ranks; short lists pad with "".
    transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
    rank_labels = [f"Rank {i}" for i in range(len(pipeline_order))]
    # Per-column width: the widest rendered cell, header labels included.
    max_lengths = [
        max(len(str(item)) if item is not None else 0 for item in col)
        for col in zip(step_labels, *transposed_actions)
    ]
    header_row = " " * (len(step_labels[0]) + 2) + " ".join(
        f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
    )
    lines = [header_row]
    for label, row in zip(step_labels, transposed_actions):
        cells = " ".join(
            f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row)
        )
        lines.append(f"{label}: {cells}")
    # Grid always ends with a trailing newline.
    return "\n".join(lines) + "\n"
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _validate_pipeline_order(
    pipeline_order: Dict[int, List[Optional[_Action]]],
    num_microbatches: int,
    num_stages: int,
    enable_zero_bubble: bool = False,
):
    """
    pipeline_order[rank] = [(computation_type, microbatch_index, stage_index), ...]
    Validating that the pipeline order follows the rules:
    1. Forward action for a microbatch must be before the Backward action for that microbatch
    2. Recv for a microbatch must be before the send for that microbatch
    3. Microbatch index is handled in sequential order for each stage
    4. A later stage cannot operate on a microbatch before any of the previous stages have operated on it
    5. Same microbatch cannot be handled in the same time step across ranks

    Raises RuntimeError with every violation found for a timestep.
    """
    # microbatch_index: (current computation type, current stage)
    # i.e. the last (computation, stage) pair recorded for each microbatch,
    # from which the next expected action is derived.
    microbatch_process_info: Dict[int, Tuple[_ComputationType, int]] = {}
    max_timestep = max(len(rank_list) for rank_list in pipeline_order.values())
    for timestep in range(max_timestep):
        error_msg: List[str] = []
        current_timestep_actions = []
        # Gather every non-WEIGHT action scheduled at this timestep; ranks
        # whose schedule is shorter than max_timestep contribute None.
        for rank in range(len(pipeline_order)):
            action = (
                pipeline_order[rank][timestep]
                if timestep < len(pipeline_order[rank])
                else None
            )

            if action is not None:
                computation_type = action.computation_type
                if computation_type != _ComputationType.WEIGHT:
                    current_timestep_actions.append(action)

        # TODO: enable this
        # if len(current_timestep_actions) == 0:
        #     error_msg.append(
        #         "All actions were None, there is an unnecessary gap in the schedule"
        #     )

        # Ensure that no microbatch is operated on twice in current_timestep_actions
        unique_microbatch_indices = {
            action.microbatch_index for action in current_timestep_actions
        }
        if len(unique_microbatch_indices) != len(current_timestep_actions):
            error_msg.append(
                "Duplicate microbatch index found in current_timestep_actions"
            )

        # Check each action against the expected next step for its microbatch.
        for action in current_timestep_actions:
            stage_index = action.stage_index
            computation_type = action.computation_type
            mb_index = action.microbatch_index
            assert (
                mb_index is not None
            ), "All currently supported action types require valid microbatch_index"
            if mb_index >= num_microbatches:
                error_msg.append(f"Microbatch index {mb_index} out of range")

            # first microbatch
            if mb_index not in microbatch_process_info:
                # A microbatch's first action must be FORWARD on stage 0.
                if computation_type != _ComputationType.FORWARD or stage_index != 0:
                    error_msg.append(f"Incorrect start for microbatch {mb_index}")
                microbatch_process_info[mb_index] = (computation_type, stage_index)
            else:
                # if the microbatch is included, check that the current stage is right after prev
                prev_computation, prev_stage = microbatch_process_info[mb_index]

                if prev_computation == _ComputationType.FORWARD:
                    if prev_stage == num_stages - 1:
                        # Forward finished on the last stage: backward begins there.
                        expected_stage = num_stages - 1
                        expected_computation = _ComputationType.BACKWARD
                    else:
                        expected_stage = prev_stage + 1
                        expected_computation = _ComputationType.FORWARD
                elif prev_computation == _ComputationType.BACKWARD:
                    if prev_stage == 0:
                        # Backward already reached stage 0; no further action is legal.
                        error_msg.append(
                            f"[{mb_index=}] already finished backward computation"
                        )
                        break
                    else:
                        expected_stage = prev_stage - 1
                        expected_computation = _ComputationType.BACKWARD
                else:
                    raise ValueError(
                        f"Computation type {prev_computation} not supported"
                    )

                if expected_computation is not None:
                    if expected_computation != computation_type:
                        error_msg.append(
                            f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}"
                        )

                if expected_stage != stage_index:
                    error_msg.append(
                        f"[{mb_index=}] {expected_stage=} VS. actual {stage_index=}"
                    )

                microbatch_process_info[mb_index] = (
                    expected_computation,
                    expected_stage,
                )

        if not enable_zero_bubble:
            if len(error_msg) != 0:
                raise RuntimeError(
                    f"Error at timestep {timestep}: " + ",".join(error_msg)
                )
            # NOTE(review): this `return` ends validation after the first
            # timestep's checks pass; presumably `continue` was intended so
            # that every timestep gets validated — confirm against upstream.
            return

        # Zero-bubble schedules additionally require, per rank, that every
        # BACKWARD has exactly one matching WEIGHT update occurring after it.
        for rank in range(len(pipeline_order)):
            backward_steps: Set[Tuple[int, int]] = set()
            weight_steps: Set[Tuple[int, int]] = set()

            for action in pipeline_order[rank]:
                if action is None:
                    continue

                stage_index = action.stage_index
                computation_type = action.computation_type
                mb_index = action.microbatch_index
                if computation_type == _ComputationType.BACKWARD:
                    if mb_index is not None:
                        backward_steps.add((mb_index, stage_index))
                elif computation_type == _ComputationType.WEIGHT:
                    if (mb_index, stage_index) not in backward_steps:
                        error_msg.append(
                            f"{mb_index=}, {stage_index=} Weight happened before bwd"
                        )
                    if (mb_index, stage_index) in weight_steps:
                        error_msg.append(
                            f"{mb_index=}, {stage_index=} Duplicated weight step"
                        )
                    if mb_index is not None:
                        weight_steps.add((mb_index, stage_index))

            if len(backward_steps) != len(weight_steps):
                error_msg.append("Length weight steps != Length bwd steps")

        if len(error_msg) != 0:
            raise RuntimeError(f"Error at timestep {timestep}: " + ",".join(error_msg))
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class _PipelineSchedule(ABC):
    """
    Abstract base class for pipeline schedules.

    Owns the machinery shared by all schedules: splitting whole-batch inputs
    into microbatches, merging microbatch outputs back into a batch, and
    collecting per-microbatch losses. Concrete schedules implement `step`
    and `_step_microbatches`.
    """

    def __init__(
        self,
        n_microbatches: int,
        loss_fn: Optional[Callable[..., torch.Tensor]] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        # From arguments
        self._n_microbatches = n_microbatches
        self._loss_fn = loss_fn
        # Chunking specification for positional inputs. (default: `None`)
        self._args_chunk_spec = args_chunk_spec
        # Chunking specification for keyword inputs. (default: `None`)
        self._kwargs_chunk_spec = kwargs_chunk_spec
        # Spec for merging per-microbatch outputs back into one batch output.
        self._output_merge_spec = output_merge_spec
        """
        # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs.
        # They are used to convert batch to microbatches in `step(x)`. See
        # `TensorChunkSpec` for helper methods for creating them.
        """

        # Derived
        # Backward passes are scheduled only when a loss function is supplied.
        self._has_backward = self._loss_fn is not None

        # Holds the losses for each microbatch.
        self._internal_losses: List[torch.Tensor] = []
        logger.info("Using %s", self.__class__.__name__)

    def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
        # Loss is only computed on the last stage, and only when training
        # (i.e. a loss_fn was provided).
        if stage.is_last and self._has_backward:
            loss = self._compute_loss(output, target_mbs[mb_index])  # type: ignore[index]
            self._internal_losses.append(loss)

    def _maybe_get_loss(self, stage, mb_index):
        # Return the stored loss for `mb_index` when one applies (last stage,
        # training); None when no loss applies; raise if losses exist but the
        # index is out of range (indicates a schedule bookkeeping bug).
        valid_index = 0 <= mb_index < len(self._internal_losses)
        if stage.is_last and self._has_backward and valid_index:
            return self._internal_losses[mb_index]
        elif len(self._internal_losses) != 0 and not valid_index:
            raise RuntimeError(
                f"Loss for microbatch {mb_index} is not available. "
                f"Available losses for microbatches: {self._internal_losses}"
            )
        else:
            return None

    def _update_losses(self, stages, losses):
        """
        Update the losses to those in the internal state
        """
        # if stages not a list turn into a list
        if not isinstance(stages, list):
            stages = [stages]
        contains_last_stage = any(stage.is_last for stage in stages)

        # Return losses if there is a container passed in
        if contains_last_stage and losses is not None:
            if len(self._internal_losses) != self._n_microbatches:
                raise RuntimeError(
                    f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}"
                )

            # Clean external container first
            losses.clear()
            # Copy internal losses to external container
            losses.extend(self._internal_losses)

        # Internal buffer is reset every iteration regardless of output.
        self._internal_losses.clear()

    @abstractmethod
    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the schedule
        implementation.

        Args:
            microbatches: list of microbatch args.
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """
        raise NotImplementedError

    def _check_inputs(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Pre-process/check inputs
        """

        def check_type_and_len(mbs, name: str):
            # Every per-microbatch argument must be a list of exactly
            # n_microbatches entries.
            if not isinstance(mbs, list):
                raise TypeError(f"{name} must be a list but got a {type(mbs)}")
            if len(mbs) != self._n_microbatches:
                raise ValueError(
                    f"Expecting {self._n_microbatches} {name} but got {len(mbs)}"
                )

        if arg_mbs is not None:
            check_type_and_len(arg_mbs, "arg_mbs")
        else:
            # Default: empty positional args for every microbatch.
            arg_mbs = [()] * self._n_microbatches

        if kwarg_mbs is not None:
            check_type_and_len(kwarg_mbs, "kwarg_mbs")
        else:
            # Default: empty keyword args for every microbatch.
            kwarg_mbs = [{}] * self._n_microbatches

        if target_mbs is not None:
            check_type_and_len(target_mbs, "target_mbs")

        if losses is not None:
            if not isinstance(losses, list):
                raise TypeError(f"losses must be a list but got a {type(losses)}")

        return arg_mbs, kwarg_mbs

    def _compute_loss(self, output, target):
        # Apply the user-provided loss function to one microbatch's output.
        return self._loss_fn(output, target)  # type: ignore[misc]

    def _split_inputs(
        self,
        args: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Splits a full-batch input into chunks (i.e. microbatches) and returns
        the chunks
        """
        if args or kwargs:
            args_split, kwargs_split = split_args_kwargs_into_chunks(
                args,
                kwargs,
                self._n_microbatches,
                self._args_chunk_spec,
                self._kwargs_chunk_spec,
            )
            return args_split, kwargs_split
        else:
            # Empty inputs (e.g. when called on middle stages)
            # Return a list of empty tuples/dicts with matching length as chunks
            return [()] * self._n_microbatches, [{}] * self._n_microbatches

    def _merge_outputs(self, output_chunks: List[Any]) -> Any:
        """
        Merge output chunks back to a batch state.
        If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim).
        """
        return merge_chunks(
            output_chunks,
            self._output_merge_spec,
        )
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None):
|
| 517 |
+
"""
|
| 518 |
+
Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top.
|
| 519 |
+
"""
|
| 520 |
+
if len(p2p_ops) == 0:
|
| 521 |
+
return None
|
| 522 |
+
desc_str = f"{desc}, " if desc else ""
|
| 523 |
+
logger.debug("batch_p2p %s%s", desc_str, p2p_ops)
|
| 524 |
+
return dist.batch_isend_irecv(p2p_ops).pop()
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def _sorted_batch_p2p(
|
| 528 |
+
p2p_ops: List[dist.P2POp], desc: Optional[str] = None
|
| 529 |
+
) -> Dict[int, dist.Work]:
|
| 530 |
+
"""
|
| 531 |
+
Sorts the list of P2P ops by the peer rank, and then calls
|
| 532 |
+
batch_isend_irecv. Return a dictionary of works by peer rank. This function
|
| 533 |
+
helps us avoid hangs in case of skip connections.
|
| 534 |
+
"""
|
| 535 |
+
# Arrange p2p_ops by peer rank:
|
| 536 |
+
# int is the peer rank;
|
| 537 |
+
# List is the list of ops towards the peer
|
| 538 |
+
ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list)
|
| 539 |
+
work_by_peer: Dict[int, dist.Work] = {}
|
| 540 |
+
if len(p2p_ops) == 0:
|
| 541 |
+
return work_by_peer
|
| 542 |
+
|
| 543 |
+
# Classify the ops by peer rank
|
| 544 |
+
for op in p2p_ops:
|
| 545 |
+
ops_by_peer[op.peer].append(op)
|
| 546 |
+
|
| 547 |
+
# Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs)
|
| 548 |
+
for peer, ops in sorted(ops_by_peer.items()):
|
| 549 |
+
work_by_peer[peer] = _batch_p2p(ops, desc=desc)
|
| 550 |
+
|
| 551 |
+
return work_by_peer
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
class PipelineScheduleSingle(_PipelineSchedule):
    """
    Base class for single-stage schedules.
    Implements the `step` method.
    Derived classes should implement `_step_microbatches`.
    """

    def __init__(
        self,
        stage: _PipelineStageBase,
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        # Init parent
        super().__init__(
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        # Self attributes
        # The single pipeline stage this rank executes.
        self._stage = stage
        self._num_stages = stage.num_stages
        # Set the same has_backward flag for stage object
        self._stage.has_backward = self._has_backward

        # TODO: later replace this with lazy shape inference during forward
        # Prepare forward send/recv infrastructure for stage
        stage._prepare_forward_infra(n_microbatches)
        if self._has_backward:
            stage._prepare_backward_infra(n_microbatches)

    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """

        # Clean per iteration
        self._stage.clear_runtime_states()

        # Split inputs into microbatches
        args_split, kwargs_split = self._split_inputs(args, kwargs)

        # Split target into microbatches
        if target is not None:
            targets_split = list(torch.tensor_split(target, self._n_microbatches))
        else:
            targets_split = None

        # Run microbatches
        self._step_microbatches(args_split, kwargs_split, targets_split, losses)

        # Return merged results per original format
        # Only the last stage has the model outputs to merge and return.
        if self._stage.is_last:
            return self._merge_outputs(self._stage.output_chunks)
        else:
            return None
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
class _ScheduleForwardOnly(PipelineScheduleSingle):
    """
    The forward-only schedule.
    Will go through all the microbatches and perform only the forward pass
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule
        """
        # Loss-related arguments are meaningless without a backward pass.
        if target_mbs is not None or losses is not None:
            raise RuntimeError(
                "Forward-only schedule does not support loss computation"
            )

        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Send completions are collected here and awaited only once, at the end.
        pending_fwd_sends: List[dist.Work] = []

        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                # Block until this chunk's input activations have arrived.
                recv_works = _sorted_batch_p2p(
                    self._stage.get_fwd_recv_ops(i), desc="fwd_recv"
                )
                for recv_work in recv_works.values():
                    recv_work.wait()

                self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]

                # Kick off this chunk's sends without waiting on them.
                send_works = _sorted_batch_p2p(
                    self._stage.get_fwd_send_ops(i), desc="fwd_send"
                )
                pending_fwd_sends.extend(send_works.values())

            logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)

        # Wait for all forward sends to finish
        # This should not have performance impact because by the time the first
        # backward arrives all the forward sends should have been finished.
        for send_work in pending_fwd_sends:
            send_work.wait()
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
class ScheduleGPipe(PipelineScheduleSingle):
    """
    The GPipe schedule.
    Will go through all the microbatches in a fill-drain manner.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the GPipe schedule.

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Delay send waits
        fwd_sends_to_wait: List[dist.Work] = []

        # Run microbatches
        # Fill phase: all forwards, in microbatch order.
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                ops = self._stage.get_fwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_recv")
                for work in works.values():
                    work.wait()

                output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]

                # Sends are launched but not awaited yet (see below).
                ops = self._stage.get_fwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_send")
                fwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)

            # Record this microbatch's loss (last stage, training only).
            self._maybe_compute_loss(self._stage, output, target_mbs, i)

        # Wait for all forward sends to finish
        # This should not have performance impact because by the time the first
        # backward arrives all the forward sends should have been finished.
        for work in fwd_sends_to_wait:
            work.wait()

        # No loss function, no need to run backward
        if not self._has_backward:
            return

        # Run backward
        # Delay send waits
        # Drain phase: all backwards, in microbatch order.
        bwd_sends_to_wait: List[dist.Work] = []
        for i in range(self._n_microbatches):
            with record_function(f"Backward {i}"):
                ops = self._stage.get_bwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_recv")
                for work in works.values():
                    work.wait()

                loss = self._maybe_get_loss(self._stage, i)
                self._stage.backward_one_chunk(i, loss=loss)

                ops = self._stage.get_bwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_send")
                bwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i)

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)

        # Wait for all backward sends to finish
        for work in bwd_sends_to_wait:
            work.wait()
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
class Schedule1F1B(PipelineScheduleSingle):
    """
    The 1F1B schedule.
    Will perform one forward and one backward on the microbatches in steady state.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the 1F1B schedule.

        The schedule runs in three phases:
        1. Warmup: forward-only chunks (more warmups for earlier stages).
        2. 1B1F steady state: one backward then one forward per iteration,
           fusing each send with the next phase's recv into a single batched p2p.
        3. Cooldown: backward-only chunks for the remaining microbatches.

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Last stage has 1 warmup, second-to-last 2 warmups, ...
        # first stage `num_stages` warmups
        warmup_chunks = min(
            self._n_microbatches,
            self._num_stages - self._stage.stage_index,
        )

        # Chunk counters
        fwd_mb_index = 0
        bwd_mb_index = 0
        # NOTE(review): weight_stage_mb_index is never read in this method —
        # confirm whether it can be removed.
        weight_stage_mb_index = 0

        # Warmup phase
        send_work = None
        fwd_sends = []
        for _ in range(warmup_chunks):
            # Receive activations
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
            if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"):
                recv_work.wait()

            # Compute
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Clear previous chunk's forward sends (hopefully they have well
            # finished, otherwise, we are heavily communication bound, in which
            # case it doesn't create a lot of benefit to compute next chunk
            # eagerly either)
            if send_work:
                send_work.wait()

            # Send activations
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            if fwd_mb_index != warmup_chunks - 1:
                # Safe to fire
                send_work = _batch_p2p(fwd_sends, desc="fwd_send")
            # otherwise:
            #   The last forward send is left for fuse with first 1B in 1B1F below

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
            fwd_mb_index += 1

        # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below.

        # 1B1F phase
        while True:  # Don't worry, we have a break inside
            # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)

            # Now, we need to fire the fwd_sends and bwd_recvs together
            if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"):
                fuse_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(bwd_mb_index, loss=loss)

            # Get the bwd send ops, but don't fire, to be fused with the 1F below
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            bwd_mb_index += 1

            if fwd_mb_index == self._n_microbatches:
                # We are done with 1B1F, so break with some left-over bwd_sends
                break

            # We prepare 1F of the `1B1F`
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)

            # Fuse it with bwd_sends above
            if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"):
                fuse_work.wait()

            # Now do the fwd
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)

            # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around)
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            fwd_mb_index += 1

        # Remember we still have some bwd_sends left over after the break? Now it is time to fire it
        send_work = _batch_p2p(bwd_sends, desc="bwd_send")

        # Cooldown
        while bwd_mb_index < self._n_microbatches:
            # prepare bwd recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
            if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"):
                recv_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(bwd_mb_index, loss=loss)

            # Clear previous chunk's backward sends (hopefully they have well finished)
            if send_work:
                send_work.wait()

            # Get the bwd send ops, fire it
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            send_work = _batch_p2p(bwd_sends, desc="bwd_send")
            bwd_mb_index += 1

        # Wait for the last backward send to finish
        if send_work:
            send_work.wait()

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)
|
| 888 |
+
|
| 889 |
+
def _add_unshard_reshard(
    compute_actions: List[Optional[_Action]],
    max_active_stages: int = 3,
) -> List[_Action]:
    """Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP.

    UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation.
    RESHARD does the opposite, releasing memory (but doing no communication)

    We abandon the "timestep lock" during lowering

    max_active_stages controls how many prefetches we allow. It should be measured in mb and tuneable but in practice
    3 stages is probably the thing we want?
    (to account for having one f and one b active, and something else prefetching?)
    """

    def next_stage_indices(
        count: int, next_actions: List[Optional[_Action]]
    ) -> List[int]:
        """Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute."""
        seen: Set[int] = set()
        ret: List[int] = []

        for a in next_actions:
            if a is not None and a.stage_index not in seen:
                seen.add(a.stage_index)
                ret.append(a.stage_index)
                if len(ret) == count:
                    break
        return ret

    # Stages whose parameters are currently unsharded (all-gathered).
    active_stages: Set[int] = set()
    fsdp_aware_actions: List[_Action] = []

    def _unshard(stage_index: int):
        active_stages.add(stage_index)
        fsdp_aware_actions.append(_Action(stage_index, UNSHARD, None))

    def _reshard(stage_index: int):
        active_stages.remove(stage_index)
        fsdp_aware_actions.append(_Action(stage_index, RESHARD, None))

    for i, action in enumerate(compute_actions):
        if action is None:
            continue

        # We prefetch the next N stages we'll see, dropping existing stages to make room
        next_n = next_stage_indices(max_active_stages, compute_actions[i:])
        # Fetch needs to be ordered correctly, so don't use a set
        fetch = [s for s in next_n if s not in active_stages]
        # Unclear what the best policy is for eviction, but we can maintain order so we do
        evict = [s for s in active_stages if s not in next_n]

        # Evict first to free memory before prefetching the new stages.
        for stage in evict:
            _reshard(stage)
        for stage in fetch:
            _unshard(stage)
        fsdp_aware_actions.append(action)

    return fsdp_aware_actions
|
| 958 |
+
|
| 959 |
+
def _add_send_recv(
    compute_actions: Dict[int, List[_Action]],
    stage_to_rank: Callable[[int], int],
    num_stages: int,
) -> Dict[int, List[_Action]]:
    """Lower a per-rank compute-only schedule to one with explicit SEND/RECV actions.

    For every compute action that has a cross-stage neighbor, a SEND is appended on
    the producing rank and the matching RECV is appended on the consuming rank.
    Actions are consumed round-robin across ranks, and an action is only scheduled
    once its required RECV has already been placed on its rank.
    """
    comm_actions: Dict[int, List[_Action]] = {rank: [] for rank in compute_actions}

    def _has_comms(action: _Action) -> bool:
        # F on the last stage and B on the first stage have no neighbor to talk to.
        if action.computation_type == F:
            return action.stage_index != num_stages - 1
        elif action.computation_type == B:
            return action.stage_index != 0
        return False

    def _get_comms(action: _Action) -> Tuple[_Action, _Action]:
        """Build the (send, recv) action pair implied by a compute action."""
        assert _has_comms(action), f"{action} is not a valid comm action"
        stage_idx = action.stage_index
        ctype = action.computation_type
        mb_idx = action.microbatch_index
        send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx)
        recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1
        recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx)
        return send, recv

    def _ready_to_schedule(
        action: Optional[_Action], prev_actions: List[_Action]
    ) -> bool:
        """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
        This helps ensure a sane (non-hanging) ordering of sends and recvs.
        But it also means we might not be able to schedule our next compute action yet.
        """
        if action is None:
            return True
        if action.computation_type == F and action.stage_index != 0:
            # F needs its input activation: the matching RECV_F must already be placed.
            return (
                _Action(action.stage_index, RECV_F, action.microbatch_index)
                in prev_actions
            )
        if action.computation_type == B and action.stage_index != num_stages - 1:
            # B needs its incoming gradient: the matching RECV_B must already be placed.
            return (
                _Action(action.stage_index, RECV_B, action.microbatch_index)
                in prev_actions
            )
        return True

    while compute_actions:
        progress = False
        # Iterate a sorted snapshot of the remaining rank keys. The previous code
        # iterated `range(len(compute_actions))`, which breaks (KeyError / skipped
        # ranks) as soon as any rank's action list is exhausted and its key is
        # deleted below, since the remaining keys no longer match 0..len-1.
        for rank in sorted(compute_actions):
            assert len(compute_actions[rank]) > 0
            action = compute_actions[rank][0]

            if not _ready_to_schedule(action, comm_actions[rank]):
                continue

            if action is not None:
                comm_actions[rank].append(action)
                if _has_comms(action):
                    send, recv = _get_comms(action)
                    # TODO we can avoid send/recv if the 2 stages are on the same rank.
                    # should we avoid that in the runtime or here?
                    comm_actions[rank].append(send)
                    comm_actions[stage_to_rank(recv.stage_index)].append(recv)

            compute_actions[rank].pop(0)
            if len(compute_actions[rank]) == 0:
                del compute_actions[rank]
            progress = True
        assert progress, "Malformed compute schedule, can't schedule sends/recvs"
    return comm_actions
|
| 1035 |
+
|
| 1036 |
+
class PipelineScheduleMulti(_PipelineSchedule):
    """
    Base class for multi-stage schedules.
    Implements the `step` method.

    Derived schedules populate ``self.pipeline_order`` (one action list per rank);
    ``_step_microbatches`` then executes this rank's actions in timestep order,
    issuing batched p2p recvs inferred from neighboring ranks' actions.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
        stage_index_to_group_rank: Optional[Dict[int, int]] = None,
        use_full_backward: bool = True,
    ):
        if len(stages) <= 1:
            raise ValueError(
                f"Multi-stage schedule expects at least two stages but got {len(stages)}"
            )
        # Init parent
        super().__init__(
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        # Self attributes
        self._stages = stages
        self._num_stages = stages[0].num_stages
        self.pp_group_size = stages[0].group_size
        self.rank = stages[0].group_rank
        # Set the pipeline stage states
        if stage_index_to_group_rank is not None:
            for stage in self._stages:
                stage.stage_index_to_group_rank = stage_index_to_group_rank
        self.stage_index_to_group_rank = stages[0].stage_index_to_group_rank

        # Set the same has_backward flag for stage object
        for stage in self._stages:
            stage.has_backward = self._has_backward

        # Loss is only computed on the last stage, and only when a loss_fn was given.
        self._should_compute_loss = (
            lambda stage: stage.is_last and self._loss_fn is not None
        )

        # This will be set during init of derived schedules
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        self.use_full_backward = use_full_backward

        # TODO: later replace this with lazy shape inference during forward
        # Prepare forward send/recv infrastructure for stage
        for stage in self._stages:
            stage._prepare_forward_infra(n_microbatches)
            if self._has_backward:
                stage._prepare_backward_infra(n_microbatches)

    def _dump_csv(self, filename):
        """Dump a CSV representation of the schedule into a file with the provided filename."""
        with open(filename, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            for rank in self.pipeline_order:
                writer.writerow(self.pipeline_order[rank])

    def _validate_schedule(self):
        # TODO(whc) this should be merged with the logic in test_schedule.py#L453-L554
        def _validate_rank_actions(
            actions: Dict[int, List[_Action | None]],
            num_stages: int,
            num_microbatches: int,
        ):
            # We will count all the actions per stage and ensure they happen in a valid order
            # (e.g. F before B before W for a given microbatch)
            stage_actions: Dict[int, Dict[_ComputationType, Set]] = {
                stage_id: {
                    F: set(),
                    B: set(),
                    W: set(),
                }
                for stage_id in range(num_stages)
            }
            for rank in actions:
                for action in actions[rank]:
                    if action is None:
                        continue
                    assert isinstance(
                        action, _Action
                    ), f"Got an invalid action: {action}, expected instance of _Action"
                    s_id = action.stage_index
                    ctype = action.computation_type
                    mb_id = action.microbatch_index
                    if ctype == F:
                        stage_actions[s_id][F].add(mb_id)
                    elif ctype == B:
                        # A microbatch's backward requires its forward to have run first.
                        assert (
                            mb_id in stage_actions[s_id][F]
                        ), f"Running Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
                        stage_actions[s_id][B].add(mb_id)
                    elif ctype == W:
                        assert (
                            not self.use_full_backward
                        ), "Schedule contains 'W' actions, but is configured to use full backward"
                        assert (
                            mb_id in stage_actions[s_id][B]
                        ), f"Running Weight for stage {s_id}, microbatch {mb_id} without first running Backward"
                        stage_actions[s_id][W].add(mb_id)

            # NOTE(review): this requires num_microbatches entries for each of F, B,
            # AND W per stage.  A schedule with use_full_backward=True emits no W
            # actions (the W branch above even asserts against them), so its W sets
            # stay empty and this check would fail — confirm whether full-backward
            # schedules are expected to pass through this validation.
            for s_id in stage_actions:
                for ctype in (F, B, W):
                    stage_mb = len(stage_actions[s_id][ctype])
                    assert (
                        stage_mb == num_microbatches
                    ), f"Got {stage_mb} {ctype} microbatches for stage {s_id}, expected {num_microbatches}"

        assert (
            len(self.pipeline_order) == self.pp_group_size
        ), f"Schedule has incorrect number of ranks - expected {self.pp_group_size}, actual {len(self.pipeline_order)}"
        for rank in range(self.pp_group_size):
            assert (
                rank in self.pipeline_order
            ), f"Schedule is missing actions for rank {rank}"
        _validate_rank_actions(
            self.pipeline_order,
            self._num_stages,
            self._n_microbatches,
        )

    def _load_csv(self, filename, format="compute_only"):
        """Load a CSV representation of the schedule from a file with the provided filename.
        This API will most likely get renamed/refactored so is marked as internal for now.

        format must be "compute_only" for PipelineScheduleMulti
        """
        assert format == "compute_only"
        with open(filename, newline="") as csvfile:
            reader = csv.reader(csvfile)
            # One CSV row per rank, in rank order.
            for rank, row in enumerate(reader):
                self.pipeline_order[rank] = [_Action.from_str(s) for s in row]
        self._validate_schedule()

    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """

        # Clean per iteration
        for stage in self._stages:
            stage.clear_runtime_states()

        # Split inputs into microbatches
        args_split, kwargs_split = self._split_inputs(args, kwargs)

        # Split target into microbatches
        if target is not None:
            targets_split = list(torch.tensor_split(target, self._n_microbatches))
        else:
            targets_split = None

        # Run microbatches
        self._step_microbatches(args_split, kwargs_split, targets_split, losses)

        # Return merged results per original format
        for stage in self._stages:
            if stage.is_last:
                return self._merge_outputs(stage.output_chunks)
        # Does not contain the last stage
        return None

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Operate on the microbatches for looped schedules (multiple stages on each rank).

        TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
        not support models with skip connections.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Based on the plan in Step 1 created in __init__:
        # 2. Perform communication based on the pipeline_order
        stage_index_to_stage: Dict[int, _PipelineStageBase] = {
            stage.stage_index: stage for stage in self._stages
        }

        # determine prev_rank and next_rank based on which ranks are next to
        # the stages in the pipeline_order
        all_prev_ranks: Set[int] = set()
        all_next_ranks: Set[int] = set()
        for stage_index in stage_index_to_stage.keys():
            # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections)
            if stage_index > 0:
                all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1])
            if stage_index < self._num_stages - 1:
                all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1])

        for time_step, action in enumerate(self.pipeline_order[self.rank]):
            try:
                # p2p ops accumulated for this timestep: our own sends plus recvs
                # implied by neighbor ranks' actions at the same timestep.
                ops: List[dist.P2POp] = []
                if action is not None:
                    computation_type = action.computation_type
                    mb_index = action.microbatch_index
                    stage_index = action.stage_index
                    assert (
                        mb_index is not None
                    ), "All currently supported action types require valid microbatch_index"
                    if computation_type == _ComputationType.FORWARD:
                        # perform forward computation
                        stage = stage_index_to_stage[stage_index]
                        output = stage.forward_one_chunk(
                            mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
                        )
                        self._maybe_compute_loss(stage, output, target_mbs, mb_index)
                        ops.extend(stage.get_fwd_send_ops(mb_index))
                    elif computation_type == _ComputationType.BACKWARD:
                        # perform backward computation
                        stage = stage_index_to_stage[stage_index]
                        loss = self._maybe_get_loss(stage, mb_index)
                        stage.backward_one_chunk(
                            mb_index, loss=loss, full_backward=self.use_full_backward
                        )
                        ops.extend(stage.get_bwd_send_ops(mb_index))
                    elif computation_type == _ComputationType.WEIGHT:
                        # perform weight update
                        if self.use_full_backward:
                            raise ValueError(
                                f"We detected a weight update in the pipeline schedule, but \
                                {self.use_full_backward=}"
                            )
                        stage = stage_index_to_stage[stage_index]
                        stage.backward_weight_one_chunk(mb_index)
                    else:
                        raise ValueError(f"Unknown computation type {computation_type}")

                # Look at the neighboring ranks for this current timestep and determine whether
                # this current rank needs to do any recv communication
                for prev_rank in all_prev_ranks:
                    prev_rank_ops = self.pipeline_order[prev_rank]
                    prev_rank_action = None
                    if time_step < len(prev_rank_ops):
                        prev_rank_action = prev_rank_ops[time_step]
                    if prev_rank_action is not None:
                        computation_type = prev_rank_action.computation_type
                        mb_index = prev_rank_action.microbatch_index
                        stage_index = prev_rank_action.stage_index
                        assert (
                            mb_index is not None
                        ), "All currently supported action types require valid microbatch_index"
                        # Only handle sends for the forward from a previous rank
                        if computation_type == _ComputationType.FORWARD:
                            # If not the last stage, then receive fwd activations
                            if stage_index + 1 in stage_index_to_stage:
                                # TODO: We are assuming that stage will always receive from stage-1
                                # however that is not necessarily true of get_fwd_recv_ops
                                stage = stage_index_to_stage[stage_index + 1]
                                ops.extend(stage.get_fwd_recv_ops(mb_index))
                        elif (
                            computation_type == _ComputationType.BACKWARD
                            or computation_type == _ComputationType.WEIGHT
                        ):
                            # Previous rank doing backward or weight update has no influence for the current rank forward recv
                            pass
                        else:
                            raise ValueError(
                                f"Unknown computation type {computation_type}"
                            )
                for next_rank in all_next_ranks:
                    next_rank_ops = self.pipeline_order[next_rank]
                    next_rank_action = None
                    if time_step < len(next_rank_ops):
                        next_rank_action = next_rank_ops[time_step]
                    if next_rank_action is not None:
                        computation_type = next_rank_action.computation_type
                        mb_index = next_rank_action.microbatch_index
                        stage_index = next_rank_action.stage_index
                        assert (
                            mb_index is not None
                        ), "All currently supported action types require valid microbatch_index"
                        # Only handle receives for the backwards from a next rank
                        if (
                            computation_type == _ComputationType.FORWARD
                            or computation_type == _ComputationType.WEIGHT
                        ):
                            # Next rank doing forward or weight update has no influence for the current rank backward recv
                            pass
                        elif computation_type == _ComputationType.BACKWARD:
                            # If not the first stage, then receive bwd gradients
                            if stage_index - 1 in stage_index_to_stage:
                                # TODO: We are assuming that stage will always receive from stage+1
                                # however that is not necessarily true of get_bwd_recv_ops
                                stage = stage_index_to_stage[stage_index - 1]
                                ops.extend(stage.get_bwd_recv_ops(mb_index))
                        else:
                            raise ValueError(
                                f"Unknown computation type {computation_type}"
                            )

                # do the communication
                if ops:
                    _batch_p2p(ops).wait()
            except Exception as e:
                logger.error(
                    "[Rank %s] pipeline schedule %s caught the following exception \
                     at time_step %s when running action %s",
                    self.rank,
                    self.__class__.__name__,
                    time_step,
                    action,
                )
                logger.error("%s", _format_pipeline_order(self.pipeline_order))
                raise e
        # Return losses if there is a container passed in
        self._update_losses(self._stages, losses)
|
| 1363 |
+
|
| 1364 |
+
class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
| 1365 |
+
"""
|
| 1366 |
+
Provides a simple runtime that requires a 'schedule IR' including specified communication operations.
|
| 1367 |
+
|
| 1368 |
+
Can be instantiated directly by creating _PipelineScheduleRuntime and calling load_csv, or can be
|
| 1369 |
+
subclassed and the subclass can be responsible for creating a schedule IR.
|
| 1370 |
+
"""
|
| 1371 |
+
|
| 1372 |
+
def _load_actions(
|
| 1373 |
+
self,
|
| 1374 |
+
actions: Dict[int, List[Optional[_Action]]],
|
| 1375 |
+
format: str = "compute_only",
|
| 1376 |
+
):
|
| 1377 |
+
"""
|
| 1378 |
+
Given an in-memory representation for a simple compute-only schedule, lower it to a complex schedule including
|
| 1379 |
+
communication actions. Stores the schedule in self, and must be called before running step_mo()
|
| 1380 |
+
"""
|
| 1381 |
+
assert (
|
| 1382 |
+
self.stage_index_to_group_rank is not None
|
| 1383 |
+
), "stage_index_to_group_rank is required for PipelineScheduleRuntime"
|
| 1384 |
+
self.pipeline_order_with_comms: Dict[int, List[_Action]] = {}
|
| 1385 |
+
if format == "compute_comms":
|
| 1386 |
+
for rank in actions:
|
| 1387 |
+
self.pipeline_order_with_comms[rank] = []
|
| 1388 |
+
for action in actions[rank]:
|
| 1389 |
+
assert action is not None
|
| 1390 |
+
self.pipeline_order_with_comms[rank].append(action)
|
| 1391 |
+
# TODO what level of validation should we offer for compute+comms schedule?
|
| 1392 |
+
elif format == "compute_only":
|
| 1393 |
+
# Perform schedule lowering
|
| 1394 |
+
for rank in actions:
|
| 1395 |
+
self.pipeline_order_with_comms[rank] = _add_unshard_reshard(
|
| 1396 |
+
actions[rank]
|
| 1397 |
+
)
|
| 1398 |
+
|
| 1399 |
+
self.pipeline_order_with_comms = _add_send_recv(
|
| 1400 |
+
self.pipeline_order_with_comms,
|
| 1401 |
+
stage_to_rank=lambda s: self.stage_index_to_group_rank[s],
|
| 1402 |
+
num_stages=self._num_stages,
|
| 1403 |
+
)
|
| 1404 |
+
else:
|
| 1405 |
+
raise NotImplementedError(f"{format=} is not implemented")
|
| 1406 |
+
|
| 1407 |
+
def _load_csv(self, filename: str, format: str = "compute_only"):
|
| 1408 |
+
"""Loads a csv in simple format and then lowers it to include comunication actions
|
| 1409 |
+
|
| 1410 |
+
format must be either "compute_only" or "compute_comms". If compute_only, the lowering passes
|
| 1411 |
+
will automatically be run to generate a compute_comms schedule.
|
| 1412 |
+
"""
|
| 1413 |
+
if format == "compute_only":
|
| 1414 |
+
# this will populate self.pipeline_order
|
| 1415 |
+
super()._load_csv(filename)
|
| 1416 |
+
# this will populate self.pipeline_order_with_comms
|
| 1417 |
+
self._load_actions(self.pipeline_order)
|
| 1418 |
+
elif format == "compute_comms":
|
| 1419 |
+
actions = {}
|
| 1420 |
+
with open(filename, newline="") as csvfile:
|
| 1421 |
+
reader = csv.reader(csvfile)
|
| 1422 |
+
for rank, row in enumerate(reader):
|
| 1423 |
+
actions[rank] = [_Action.from_str(s) for s in row]
|
| 1424 |
+
self._load_actions(actions, format=format)
|
| 1425 |
+
else:
|
| 1426 |
+
raise NotImplementedError(f"{format=} is not implemented")
|
| 1427 |
+
|
| 1428 |
+
def _dump_csv(self, filename: str):
|
| 1429 |
+
"""Dump a CSV representation of the compute + comms schedule into a file with the provided filename."""
|
| 1430 |
+
# TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible
|
| 1431 |
+
# that it does not exist if it was created from a compute_comms schedule.
|
| 1432 |
+
assert (
|
| 1433 |
+
self.pipeline_order_with_comms is not None
|
| 1434 |
+
), "Must initialize compute_comms schedule before dump_csv"
|
| 1435 |
+
with open(filename, "w", newline="") as csvfile:
|
| 1436 |
+
writer = csv.writer(csvfile)
|
| 1437 |
+
for rank in self.pipeline_order_with_comms:
|
| 1438 |
+
writer.writerow(self.pipeline_order_with_comms[rank])
|
| 1439 |
+
|
| 1440 |
+
def _step_microbatches(
|
| 1441 |
+
self,
|
| 1442 |
+
arg_mbs: Optional[List] = None,
|
| 1443 |
+
kwarg_mbs: Optional[List] = None,
|
| 1444 |
+
target_mbs: Optional[List] = None,
|
| 1445 |
+
losses: Optional[List] = None,
|
| 1446 |
+
):
|
| 1447 |
+
"""
|
| 1448 |
+
Operate on the microbatches for looped schedules (multiple stages on each rank).
|
| 1449 |
+
|
| 1450 |
+
TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
|
| 1451 |
+
not support models with skip connections.
|
| 1452 |
+
"""
|
| 1453 |
+
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
|
| 1454 |
+
|
| 1455 |
+
# Based on the plan in Step 1 created in __init__:
|
| 1456 |
+
# 2. Perform communication based on the pipeline_order
|
| 1457 |
+
stage_index_to_stage: Dict[int, _PipelineStageBase] = {
|
| 1458 |
+
stage.stage_index: stage for stage in self._stages
|
| 1459 |
+
}
|
| 1460 |
+
|
| 1461 |
+
assert (
|
| 1462 |
+
self.pipeline_order_with_comms is not None
|
| 1463 |
+
), "Must call _load_actions() before calling _step_microbatches()"
|
| 1464 |
+
|
| 1465 |
+
# recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
|
| 1466 |
+
bwd_recv_ops: Dict[Tuple[int, int], Work] = {}
|
| 1467 |
+
fwd_recv_ops: Dict[Tuple[int, int], Work] = {}
|
| 1468 |
+
|
| 1469 |
+
# send ops should be waited on before step() exists, mainly for hygeine
|
| 1470 |
+
send_ops: List[Work] = []
|
| 1471 |
+
|
| 1472 |
+
# we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages
|
| 1473 |
+
unshard_ops: Dict[int, UnshardHandle] = {}
|
| 1474 |
+
unsharded_stages = set()
|
| 1475 |
+
|
| 1476 |
+
def _assert_unsharded(stage_idx: int):
|
| 1477 |
+
"""If an unshard is active for `stage_idx`, wait() it and mark `stage_idx` unshared."""
|
| 1478 |
+
if stage_idx in unshard_ops:
|
| 1479 |
+
unshard_ops[stage_idx].wait()
|
| 1480 |
+
del unshard_ops[stage_idx]
|
| 1481 |
+
unsharded_stages.add(stage_idx)
|
| 1482 |
+
assert (
|
| 1483 |
+
stage_idx in unsharded_stages
|
| 1484 |
+
), f"Attempted to compute on sharded {stage_idx=}"
|
| 1485 |
+
|
| 1486 |
+
for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]):
|
| 1487 |
+
try:
|
| 1488 |
+
comp_type = action.computation_type
|
| 1489 |
+
mb_index: int = (
|
| 1490 |
+
action.microbatch_index
|
| 1491 |
+
if action.microbatch_index is not None
|
| 1492 |
+
else -1
|
| 1493 |
+
)
|
| 1494 |
+
assert mb_index >= 0 or comp_type in (
|
| 1495 |
+
UNSHARD,
|
| 1496 |
+
RESHARD,
|
| 1497 |
+
), f"{action=} missing mb_index"
|
| 1498 |
+
stage_idx = action.stage_index
|
| 1499 |
+
stage = stage_index_to_stage[stage_idx]
|
| 1500 |
+
stage_uses_fsdp = isinstance(stage.submod, FSDPModule)
|
| 1501 |
+
|
| 1502 |
+
logger.debug(
|
| 1503 |
+
"_PipelineScheduleRuntime running time_step %d, action %s",
|
| 1504 |
+
time_step,
|
| 1505 |
+
action,
|
| 1506 |
+
)
|
| 1507 |
+
|
| 1508 |
+
# TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections,
|
| 1509 |
+
# since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be
|
| 1510 |
+
# safe to use instead.
|
| 1511 |
+
# However, I was wondering if I should avoid calling batched operators at all in the case that there is
|
| 1512 |
+
# only one operator per batch. I could iterate through the 'fwd_send_ops' one by one and run them.
|
| 1513 |
+
if comp_type == SEND_F:
|
| 1514 |
+
send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index)))
|
| 1515 |
+
elif comp_type == SEND_B:
|
| 1516 |
+
send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index)))
|
| 1517 |
+
elif comp_type == RECV_F:
|
| 1518 |
+
assert (
|
| 1519 |
+
stage_idx,
|
| 1520 |
+
mb_index,
|
| 1521 |
+
) not in fwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing forward"
|
| 1522 |
+
fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
|
| 1523 |
+
stage.get_fwd_recv_ops(mb_index)
|
| 1524 |
+
)
|
| 1525 |
+
elif comp_type == RECV_B:
|
| 1526 |
+
assert (
|
| 1527 |
+
stage_idx,
|
| 1528 |
+
mb_index,
|
| 1529 |
+
) not in bwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing backward"
|
| 1530 |
+
bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
|
| 1531 |
+
stage.get_bwd_recv_ops(mb_index)
|
| 1532 |
+
)
|
| 1533 |
+
elif comp_type == UNSHARD:
|
| 1534 |
+
if stage_uses_fsdp:
|
| 1535 |
+
assert (
|
| 1536 |
+
stage_idx not in unsharded_stages
|
| 1537 |
+
and stage_idx not in unshard_ops
|
| 1538 |
+
), f"Unsharding the same {stage_idx=} twice"
|
| 1539 |
+
unshard_ops[stage_idx] = stage.submod.unshard(async_op=True)
|
| 1540 |
+
elif comp_type == RESHARD:
|
| 1541 |
+
if stage_uses_fsdp:
|
| 1542 |
+
assert (
|
| 1543 |
+
stage_idx in unsharded_stages
|
| 1544 |
+
), f"Resharding {stage_idx=} without unsharding"
|
| 1545 |
+
assert (
|
| 1546 |
+
stage_idx not in unshard_ops
|
| 1547 |
+
), f"Resharding {stage_idx=} before finishing unshard"
|
| 1548 |
+
stage.submod.reshard()
|
| 1549 |
+
elif comp_type == FORWARD:
|
| 1550 |
+
if stage_uses_fsdp:
|
| 1551 |
+
_assert_unsharded(stage_idx)
|
| 1552 |
+
|
| 1553 |
+
if not stage.is_first:
|
| 1554 |
+
assert (
|
| 1555 |
+
stage_idx,
|
| 1556 |
+
mb_index,
|
| 1557 |
+
) in fwd_recv_ops, f"Computing {action=} before receiving input"
|
| 1558 |
+
fwd_recv_ops.pop((stage_idx, mb_index)).wait()
|
| 1559 |
+
output = stage.forward_one_chunk(
|
| 1560 |
+
mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
|
| 1561 |
+
)
|
| 1562 |
+
self._maybe_compute_loss(stage, output, target_mbs, mb_index)
|
| 1563 |
+
elif comp_type == BACKWARD:
|
| 1564 |
+
if stage_uses_fsdp:
|
| 1565 |
+
_assert_unsharded(stage_idx)
|
| 1566 |
+
|
| 1567 |
+
if not stage.is_last:
|
| 1568 |
+
assert (
|
| 1569 |
+
stage_idx,
|
| 1570 |
+
mb_index,
|
| 1571 |
+
) in bwd_recv_ops, (
|
| 1572 |
+
f"Attempted to run compute {action=} before receiving input"
|
| 1573 |
+
)
|
| 1574 |
+
bwd_recv_ops.pop((stage_idx, mb_index)).wait()
|
| 1575 |
+
loss = self._maybe_get_loss(stage, mb_index)
|
| 1576 |
+
stage.backward_one_chunk(
|
| 1577 |
+
mb_index, loss=loss, full_backward=self.use_full_backward
|
| 1578 |
+
)
|
| 1579 |
+
elif comp_type == WEIGHT:
|
| 1580 |
+
if stage_uses_fsdp:
|
| 1581 |
+
_assert_unsharded(stage_idx)
|
| 1582 |
+
|
| 1583 |
+
if self.use_full_backward:
|
| 1584 |
+
raise ValueError(
|
| 1585 |
+
f"We detected a weight update in the pipeline schedule, but \
|
| 1586 |
+
{self.use_full_backward=}"
|
| 1587 |
+
)
|
| 1588 |
+
stage.backward_weight_one_chunk(mb_index)
|
| 1589 |
+
else:
|
| 1590 |
+
raise ValueError(f"{action=} is unknown or unsupported")
|
| 1591 |
+
except Exception as e:
|
| 1592 |
+
logger.error(
|
| 1593 |
+
"_PipelineScheduleRuntime caught exception at step %s when running action %s. Full Schedule:",
|
| 1594 |
+
time_step,
|
| 1595 |
+
action,
|
| 1596 |
+
)
|
| 1597 |
+
# TODO(whc) what is the best practice for printing a multiline log?
|
| 1598 |
+
# logger will split it into multiple log lines, but this makes it hard to read (too wide)
|
| 1599 |
+
print(_format_pipeline_order(self.pipeline_order_with_comms)) # type: ignore[arg-type]
|
| 1600 |
+
raise e
|
| 1601 |
+
|
| 1602 |
+
# Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them
|
| 1603 |
+
while len(send_ops):
|
| 1604 |
+
send_ops.pop().wait()
|
| 1605 |
+
|
| 1606 |
+
assert len(unshard_ops) == 0, "Unused unshard operations"
|
| 1607 |
+
|
| 1608 |
+
# Return losses if there is a container passed in
|
| 1609 |
+
self._update_losses(self._stages, losses)
|
| 1610 |
+
|
| 1611 |
+
|
| 1612 |
+
class ScheduleLoopedBFS(PipelineScheduleMulti):
|
| 1613 |
+
"""
|
| 1614 |
+
Breadth-First Pipeline Parallelism.
|
| 1615 |
+
See https://arxiv.org/abs/2211.05953 for details.
|
| 1616 |
+
Simliar to Interleaved 1F1B, Looped BFS supports multiple stages per rank.
|
| 1617 |
+
What is different is that when microbatches are ready for multiple local
|
| 1618 |
+
stages, Loops BFS will prioritizes the earlier stage, running all available
|
| 1619 |
+
microbatches at once.
|
| 1620 |
+
"""
|
| 1621 |
+
|
| 1622 |
+
def __init__(
|
| 1623 |
+
self,
|
| 1624 |
+
stages: List[_PipelineStageBase],
|
| 1625 |
+
n_microbatches: int,
|
| 1626 |
+
loss_fn: Optional[Callable] = None,
|
| 1627 |
+
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
| 1628 |
+
):
|
| 1629 |
+
super().__init__(
|
| 1630 |
+
stages=stages,
|
| 1631 |
+
n_microbatches=n_microbatches,
|
| 1632 |
+
loss_fn=loss_fn,
|
| 1633 |
+
output_merge_spec=output_merge_spec,
|
| 1634 |
+
)
|
| 1635 |
+
|
| 1636 |
+
# 1. Create the pipeline_order (all ranks do this calculation)
|
| 1637 |
+
# This will be used to keep track of the current state of the entire pipeline
|
| 1638 |
+
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
|
| 1639 |
+
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
|
| 1640 |
+
# ========================================================================
|
| 1641 |
+
for rank in range(self.pp_group_size):
|
| 1642 |
+
rank_ops = self._calculate_single_rank_operations(rank)
|
| 1643 |
+
self.pipeline_order[rank] = rank_ops
|
| 1644 |
+
|
| 1645 |
+
def _calculate_single_rank_operations(self, rank):
|
| 1646 |
+
n_local_stages = len(self._stages)
|
| 1647 |
+
stage_indices = range(
|
| 1648 |
+
rank, self.pp_group_size * n_local_stages, self.pp_group_size
|
| 1649 |
+
)
|
| 1650 |
+
|
| 1651 |
+
# Store the list of operations used for that rank
|
| 1652 |
+
rank_ops: List[Optional[_Action]] = []
|
| 1653 |
+
# Pre-padding, rank starts with no-ops based on the warmup.
|
| 1654 |
+
for _ in range(rank):
|
| 1655 |
+
rank_ops.append(None)
|
| 1656 |
+
|
| 1657 |
+
for stage_index in stage_indices:
|
| 1658 |
+
for mb_index in range(self._n_microbatches):
|
| 1659 |
+
rank_ops.append(
|
| 1660 |
+
_Action(stage_index, _ComputationType.FORWARD, mb_index)
|
| 1661 |
+
)
|
| 1662 |
+
|
| 1663 |
+
# wait for the first backward to trickle up
|
| 1664 |
+
# which is 2 for every hop away
|
| 1665 |
+
post_warmup_ops = 2 * (self.pp_group_size - 1 - rank)
|
| 1666 |
+
rank_ops.extend([None] * post_warmup_ops)
|
| 1667 |
+
|
| 1668 |
+
for stage_index in reversed(stage_indices):
|
| 1669 |
+
for mb_index in reversed(range(self._n_microbatches)):
|
| 1670 |
+
rank_ops.append(
|
| 1671 |
+
_Action(stage_index, _ComputationType.BACKWARD, mb_index)
|
| 1672 |
+
)
|
| 1673 |
+
return rank_ops
|
| 1674 |
+
|
| 1675 |
+
|
| 1676 |
+
def _get_1f1b_rank_ops(
|
| 1677 |
+
n_local_stages,
|
| 1678 |
+
pp_group_size,
|
| 1679 |
+
warmup_ops,
|
| 1680 |
+
fwd_bwd_ops,
|
| 1681 |
+
cooldown_ops,
|
| 1682 |
+
rank,
|
| 1683 |
+
forward_stage_index,
|
| 1684 |
+
backward_stage_index,
|
| 1685 |
+
num_1f1b_microbatches=0,
|
| 1686 |
+
enable_zero_bubble=False,
|
| 1687 |
+
):
|
| 1688 |
+
# All stages start with handling microbatch 0
|
| 1689 |
+
fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
|
| 1690 |
+
bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
|
| 1691 |
+
weight_stage_mb_index: Dict[int, int] = defaultdict(int)
|
| 1692 |
+
|
| 1693 |
+
# Store the list of operations used for that rank
|
| 1694 |
+
rank_ops: List[Optional[_Action]] = []
|
| 1695 |
+
# Pre-padding, rank starts with no-ops based on the warmup.
|
| 1696 |
+
for _ in range(rank):
|
| 1697 |
+
rank_ops.append(None)
|
| 1698 |
+
# These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
|
| 1699 |
+
# when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
|
| 1700 |
+
# Formula:
|
| 1701 |
+
# pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
|
| 1702 |
+
# post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
|
| 1703 |
+
# earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
|
| 1704 |
+
# warmup_ops = calculated above
|
| 1705 |
+
post_warmup_ops = (
|
| 1706 |
+
n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
|
| 1707 |
+
) - (warmup_ops + rank)
|
| 1708 |
+
|
| 1709 |
+
if enable_zero_bubble:
|
| 1710 |
+
post_warmup_ops = pp_group_size - rank - 1
|
| 1711 |
+
|
| 1712 |
+
total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
|
| 1713 |
+
|
| 1714 |
+
backward_op_ids = []
|
| 1715 |
+
weight_op_count = 0
|
| 1716 |
+
|
| 1717 |
+
for op in range(total_ops):
|
| 1718 |
+
# Warmup phase
|
| 1719 |
+
if op < warmup_ops:
|
| 1720 |
+
fwd_stage_index = forward_stage_index(op)
|
| 1721 |
+
# This will assign the current microbatch index and update it as well
|
| 1722 |
+
fwd_stage_mb_index[fwd_stage_index] = (
|
| 1723 |
+
mb_index := fwd_stage_mb_index[fwd_stage_index]
|
| 1724 |
+
) + 1
|
| 1725 |
+
rank_ops.append(
|
| 1726 |
+
_Action(fwd_stage_index, _ComputationType.FORWARD, mb_index)
|
| 1727 |
+
)
|
| 1728 |
+
if op == warmup_ops - 1:
|
| 1729 |
+
# This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
|
| 1730 |
+
rank_ops.extend([None] * post_warmup_ops)
|
| 1731 |
+
# 1F1B Phase (forward and backward)
|
| 1732 |
+
elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
|
| 1733 |
+
fwd_stage_index = forward_stage_index(op)
|
| 1734 |
+
fwd_stage_mb_index[fwd_stage_index] = (
|
| 1735 |
+
fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
|
| 1736 |
+
) + 1
|
| 1737 |
+
rank_ops.append(
|
| 1738 |
+
_Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index)
|
| 1739 |
+
)
|
| 1740 |
+
bwd_stage_index = backward_stage_index(op)
|
| 1741 |
+
bwd_stage_mb_index[bwd_stage_index] = (
|
| 1742 |
+
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
|
| 1743 |
+
) + 1
|
| 1744 |
+
rank_ops.append(
|
| 1745 |
+
_Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
|
| 1746 |
+
)
|
| 1747 |
+
backward_op_ids.append(op)
|
| 1748 |
+
|
| 1749 |
+
if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
|
| 1750 |
+
weight_stage_index = backward_stage_index(
|
| 1751 |
+
backward_op_ids[weight_op_count]
|
| 1752 |
+
)
|
| 1753 |
+
weight_stage_mb_index[weight_stage_index] = (
|
| 1754 |
+
weight_mb_index := weight_stage_mb_index[weight_stage_index]
|
| 1755 |
+
) + 1
|
| 1756 |
+
rank_ops.append(
|
| 1757 |
+
_Action(
|
| 1758 |
+
weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
|
| 1759 |
+
)
|
| 1760 |
+
)
|
| 1761 |
+
weight_op_count += 1
|
| 1762 |
+
# Cooldown phase
|
| 1763 |
+
else:
|
| 1764 |
+
# During cooldown phase, we need steps to align with 1f1b happening in other ranks
|
| 1765 |
+
# TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
|
| 1766 |
+
if not enable_zero_bubble:
|
| 1767 |
+
rank_ops.append(None)
|
| 1768 |
+
|
| 1769 |
+
bwd_stage_index = backward_stage_index(op)
|
| 1770 |
+
bwd_stage_mb_index[bwd_stage_index] = (
|
| 1771 |
+
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
|
| 1772 |
+
) + 1
|
| 1773 |
+
rank_ops.append(
|
| 1774 |
+
_Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
|
| 1775 |
+
)
|
| 1776 |
+
backward_op_ids.append(op)
|
| 1777 |
+
|
| 1778 |
+
if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
|
| 1779 |
+
weight_stage_index = backward_stage_index(
|
| 1780 |
+
backward_op_ids[weight_op_count]
|
| 1781 |
+
)
|
| 1782 |
+
weight_stage_mb_index[weight_stage_index] = (
|
| 1783 |
+
weight_mb_index := weight_stage_mb_index[weight_stage_index]
|
| 1784 |
+
) + 1
|
| 1785 |
+
rank_ops.append(
|
| 1786 |
+
_Action(
|
| 1787 |
+
weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
|
| 1788 |
+
)
|
| 1789 |
+
)
|
| 1790 |
+
weight_op_count += 1
|
| 1791 |
+
|
| 1792 |
+
while enable_zero_bubble and weight_op_count < len(backward_op_ids):
|
| 1793 |
+
weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count])
|
| 1794 |
+
weight_stage_mb_index[weight_stage_index] = (
|
| 1795 |
+
weight_mb_index := weight_stage_mb_index[weight_stage_index]
|
| 1796 |
+
) + 1
|
| 1797 |
+
rank_ops.append(
|
| 1798 |
+
_Action(weight_stage_index, _ComputationType.WEIGHT, weight_mb_index)
|
| 1799 |
+
)
|
| 1800 |
+
weight_op_count += 1
|
| 1801 |
+
|
| 1802 |
+
return rank_ops
|
| 1803 |
+
|
| 1804 |
+
|
| 1805 |
+
class ScheduleInterleaved1F1B(PipelineScheduleMulti):
|
| 1806 |
+
"""
|
| 1807 |
+
The Interleaved 1F1B schedule.
|
| 1808 |
+
See https://arxiv.org/pdf/2104.04473 for details.
|
| 1809 |
+
Will perform one forward and one backward on the microbatches in steady
|
| 1810 |
+
state and supports multiple stages per rank. When microbatches are ready for
|
| 1811 |
+
multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch
|
| 1812 |
+
(also called "depth first").
|
| 1813 |
+
"""
|
| 1814 |
+
|
| 1815 |
+
def __init__(
|
| 1816 |
+
self,
|
| 1817 |
+
stages: List[_PipelineStageBase],
|
| 1818 |
+
n_microbatches: int,
|
| 1819 |
+
loss_fn: Optional[Callable] = None,
|
| 1820 |
+
args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
|
| 1821 |
+
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
|
| 1822 |
+
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
| 1823 |
+
):
|
| 1824 |
+
self.pp_group_size = stages[0].group_size
|
| 1825 |
+
# TODO: is this limitation a must?
|
| 1826 |
+
if n_microbatches % self.pp_group_size != 0:
|
| 1827 |
+
raise ValueError(
|
| 1828 |
+
f"Interleaved 1F1B schedule requires the number of microbatches ({n_microbatches}) \
|
| 1829 |
+
to be a multiple of the number of pipeline ranks ({self.pp_group_size})."
|
| 1830 |
+
)
|
| 1831 |
+
|
| 1832 |
+
super().__init__(
|
| 1833 |
+
stages=stages,
|
| 1834 |
+
n_microbatches=n_microbatches,
|
| 1835 |
+
loss_fn=loss_fn,
|
| 1836 |
+
args_chunk_spec=args_chunk_spec,
|
| 1837 |
+
kwargs_chunk_spec=kwargs_chunk_spec,
|
| 1838 |
+
output_merge_spec=output_merge_spec,
|
| 1839 |
+
)
|
| 1840 |
+
|
| 1841 |
+
self.n_local_stages = len(stages)
|
| 1842 |
+
self.rank = stages[0].group_rank
|
| 1843 |
+
self.group = stages[0].group
|
| 1844 |
+
|
| 1845 |
+
# 1. Create the pipeline_order (all ranks do this calculation)
|
| 1846 |
+
# This will be used to keep track of the current state of the entire pipeline
|
| 1847 |
+
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
|
| 1848 |
+
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
|
| 1849 |
+
|
| 1850 |
+
for rank in range(self.pp_group_size):
|
| 1851 |
+
rank_ops = self._calculate_single_rank_operations(rank)
|
| 1852 |
+
self.pipeline_order[rank] = rank_ops
|
| 1853 |
+
|
| 1854 |
+
def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
|
| 1855 |
+
def get_rank_warmup_ops(rank):
|
| 1856 |
+
# Warms up operations for last stage
|
| 1857 |
+
warmups_ops_last_stage = (self.n_local_stages - 1) * self.pp_group_size
|
| 1858 |
+
# Increment warmup operations by 2 for each hop away from the last stage
|
| 1859 |
+
warmup_ops = warmups_ops_last_stage + 2 * ((self.pp_group_size - 1) - rank)
|
| 1860 |
+
# We cannot have more warmup operations than there are number of microbatches, so cap it there
|
| 1861 |
+
return min(warmup_ops, self._n_microbatches * self.n_local_stages)
|
| 1862 |
+
|
| 1863 |
+
warmup_ops = get_rank_warmup_ops(rank)
|
| 1864 |
+
microbatch_ops = self.n_local_stages * self._n_microbatches
|
| 1865 |
+
# fwd_bwd_ops should encompass the remaining forwards
|
| 1866 |
+
fwd_bwd_ops = microbatch_ops - warmup_ops
|
| 1867 |
+
# cooldown_ops should encompass the remaining backwards
|
| 1868 |
+
cooldown_ops = microbatch_ops - fwd_bwd_ops
|
| 1869 |
+
# total ops encompass both forward and backward ops
|
| 1870 |
+
total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
|
| 1871 |
+
# warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
|
| 1872 |
+
|
| 1873 |
+
logger.debug(
|
| 1874 |
+
"rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
|
| 1875 |
+
rank,
|
| 1876 |
+
warmup_ops,
|
| 1877 |
+
fwd_bwd_ops,
|
| 1878 |
+
cooldown_ops,
|
| 1879 |
+
total_ops,
|
| 1880 |
+
)
|
| 1881 |
+
|
| 1882 |
+
# Calculates the stage index based on step and pp_group_size
|
| 1883 |
+
def forward_stage_index(step):
|
| 1884 |
+
# Get the local index from 0 to n_local_stages-1
|
| 1885 |
+
local_index = (step // self.pp_group_size) % self.n_local_stages
|
| 1886 |
+
return (local_index * self.pp_group_size) + rank
|
| 1887 |
+
|
| 1888 |
+
def backward_stage_index(step):
|
| 1889 |
+
local_index = (
|
| 1890 |
+
self.n_local_stages
|
| 1891 |
+
- 1
|
| 1892 |
+
- ((step - warmup_ops) // self.pp_group_size) % self.n_local_stages
|
| 1893 |
+
)
|
| 1894 |
+
return (local_index * self.pp_group_size) + rank
|
| 1895 |
+
|
| 1896 |
+
return _get_1f1b_rank_ops(
|
| 1897 |
+
self.n_local_stages,
|
| 1898 |
+
self.pp_group_size,
|
| 1899 |
+
warmup_ops,
|
| 1900 |
+
fwd_bwd_ops,
|
| 1901 |
+
cooldown_ops,
|
| 1902 |
+
rank,
|
| 1903 |
+
forward_stage_index,
|
| 1904 |
+
backward_stage_index,
|
| 1905 |
+
)
|
| 1906 |
+
|
| 1907 |
+
|
| 1908 |
+
class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
|
| 1909 |
+
"""
|
| 1910 |
+
The Flexible Interleaved 1F1B schedule.
|
| 1911 |
+
|
| 1912 |
+
This schedule is mostly similar to the interleaved 1F1B schedule.
|
| 1913 |
+
It differs by being relaxing the requirement of num_microbatch % pp_size == 0.
|
| 1914 |
+
Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and
|
| 1915 |
+
it works as long as n_microbatches % num_rounds is 0. As a few examples, support
|
| 1916 |
+
|
| 1917 |
+
1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
|
| 1918 |
+
2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.
|
| 1919 |
+
|
| 1920 |
+
When enable_zero_bubble is True, we will use the ZB1P schedule in https://openreview.net/pdf?id=tuzTN0eIO5
|
| 1921 |
+
"""
|
| 1922 |
+
|
| 1923 |
+
def __init__(
|
| 1924 |
+
self,
|
| 1925 |
+
stages: List[_PipelineStageBase],
|
| 1926 |
+
n_microbatches: int,
|
| 1927 |
+
loss_fn: Optional[Callable] = None,
|
| 1928 |
+
args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
|
| 1929 |
+
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
|
| 1930 |
+
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
| 1931 |
+
enable_zero_bubble: bool = False,
|
| 1932 |
+
):
|
| 1933 |
+
self.pp_group_size = stages[0].group_size
|
| 1934 |
+
super().__init__(
|
| 1935 |
+
stages=stages,
|
| 1936 |
+
n_microbatches=n_microbatches,
|
| 1937 |
+
loss_fn=loss_fn,
|
| 1938 |
+
args_chunk_spec=args_chunk_spec,
|
| 1939 |
+
kwargs_chunk_spec=kwargs_chunk_spec,
|
| 1940 |
+
output_merge_spec=output_merge_spec,
|
| 1941 |
+
use_full_backward=not enable_zero_bubble,
|
| 1942 |
+
)
|
| 1943 |
+
self.n_local_stages = len(stages)
|
| 1944 |
+
self.rank = stages[0].group_rank
|
| 1945 |
+
self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
|
| 1946 |
+
self.microbatches_per_round = n_microbatches // self.number_of_rounds
|
| 1947 |
+
self.enable_zero_bubble = enable_zero_bubble
|
| 1948 |
+
if n_microbatches % self.number_of_rounds != 0:
|
| 1949 |
+
raise ValueError(
|
| 1950 |
+
"Flexible Interleaved 1F1B requires the number of microbatches to be a "
|
| 1951 |
+
f"multiple of the number of rounds ({self.number_of_rounds}), "
|
| 1952 |
+
f"but got {n_microbatches}."
|
| 1953 |
+
)
|
| 1954 |
+
# 1. Create the pipeline_order (all ranks do this calculation)
|
| 1955 |
+
# This will be used to keep track of the current state of the entire pipeline
|
| 1956 |
+
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
|
| 1957 |
+
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
|
| 1958 |
+
for rank in range(self.pp_group_size):
|
| 1959 |
+
rank_ops = self._calculate_single_rank_operations(rank)
|
| 1960 |
+
self.pipeline_order[rank] = rank_ops
|
| 1961 |
+
|
| 1962 |
+
# This function add bubbles to the generated schedule based on dependencies of actions
|
| 1963 |
+
# Note that the ZB1P schedule will not require bubbles to be manually added and it is
|
| 1964 |
+
# only useful when n_microbatches <= microbatches_per_round
|
| 1965 |
+
self.pipeline_order = self._add_bubbles_to_actions(
|
| 1966 |
+
self.n_local_stages * self.pp_group_size,
|
| 1967 |
+
)
|
| 1968 |
+
|
| 1969 |
+
def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
|
| 1970 |
+
def get_rank_warmup_ops(rank):
|
| 1971 |
+
# Warms up operations for last stage
|
| 1972 |
+
warmups_ops_last_stage = (
|
| 1973 |
+
self.n_local_stages - 1
|
| 1974 |
+
) * self.microbatches_per_round
|
| 1975 |
+
# Increment warmup operations by 2 for each hop away from the last stage
|
| 1976 |
+
multiply_factor = 1 if self.enable_zero_bubble else 2
|
| 1977 |
+
warmup_ops = warmups_ops_last_stage + multiply_factor * (
|
| 1978 |
+
(self.pp_group_size - 1) - rank
|
| 1979 |
+
)
|
| 1980 |
+
|
| 1981 |
+
# We cannot have more warmup operations than there are number of microbatches, so cap it there
|
| 1982 |
+
return min(warmup_ops, self._n_microbatches * self.n_local_stages)
|
| 1983 |
+
|
| 1984 |
+
warmup_ops = get_rank_warmup_ops(rank)
|
| 1985 |
+
microbatch_ops = self.n_local_stages * self._n_microbatches
|
| 1986 |
+
# fwd_bwd_ops should encompass the remaining forwards
|
| 1987 |
+
fwd_bwd_ops = microbatch_ops - warmup_ops
|
| 1988 |
+
# cooldown_ops should encompass the remaining backwards
|
| 1989 |
+
cooldown_ops = microbatch_ops - fwd_bwd_ops
|
| 1990 |
+
# total ops encompass both forward and backward ops
|
| 1991 |
+
total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
|
| 1992 |
+
# warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
|
| 1993 |
+
logger.debug(
|
| 1994 |
+
"rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
|
| 1995 |
+
rank,
|
| 1996 |
+
warmup_ops,
|
| 1997 |
+
fwd_bwd_ops,
|
| 1998 |
+
cooldown_ops,
|
| 1999 |
+
total_ops,
|
| 2000 |
+
)
|
| 2001 |
+
|
| 2002 |
+
# Calculates the stage index based on step and pp_group_size
|
| 2003 |
+
|
| 2004 |
+
def forward_stage_index(step):
|
| 2005 |
+
# Get the local index from 0 to n_local_stages-1
|
| 2006 |
+
local_index = (step // self.microbatches_per_round) % self.n_local_stages
|
| 2007 |
+
return (local_index * self.pp_group_size) + rank
|
| 2008 |
+
|
| 2009 |
+
def backward_stage_index(step):
|
| 2010 |
+
local_index = (
|
| 2011 |
+
self.n_local_stages
|
| 2012 |
+
- 1
|
| 2013 |
+
- ((step - warmup_ops) // self.microbatches_per_round)
|
| 2014 |
+
% self.n_local_stages
|
| 2015 |
+
)
|
| 2016 |
+
return (local_index * self.pp_group_size) + rank
|
| 2017 |
+
|
| 2018 |
+
if self.enable_zero_bubble:
|
| 2019 |
+
num_1f1b_microbatches = rank
|
| 2020 |
+
|
| 2021 |
+
return _get_1f1b_rank_ops(
|
| 2022 |
+
self.n_local_stages,
|
| 2023 |
+
self.pp_group_size,
|
| 2024 |
+
warmup_ops,
|
| 2025 |
+
fwd_bwd_ops,
|
| 2026 |
+
cooldown_ops,
|
| 2027 |
+
rank,
|
| 2028 |
+
forward_stage_index,
|
| 2029 |
+
backward_stage_index,
|
| 2030 |
+
num_1f1b_microbatches,
|
| 2031 |
+
enable_zero_bubble=True,
|
| 2032 |
+
)
|
| 2033 |
+
|
| 2034 |
+
return _get_1f1b_rank_ops(
|
| 2035 |
+
self.n_local_stages,
|
| 2036 |
+
self.pp_group_size,
|
| 2037 |
+
warmup_ops,
|
| 2038 |
+
fwd_bwd_ops,
|
| 2039 |
+
cooldown_ops,
|
| 2040 |
+
rank,
|
| 2041 |
+
forward_stage_index,
|
| 2042 |
+
backward_stage_index,
|
| 2043 |
+
)
|
| 2044 |
+
|
| 2045 |
+
def _add_bubbles_to_actions(self, num_stages_global):
|
| 2046 |
+
actions = self.pipeline_order
|
| 2047 |
+
if not self.enable_zero_bubble:
|
| 2048 |
+
return actions
|
| 2049 |
+
|
| 2050 |
+
def need_bubble(stage, op, microbatch, num_stages_global, seen_ops):
|
| 2051 |
+
if op == _ComputationType.FORWARD:
|
| 2052 |
+
if stage != 0 and (stage - 1, op, microbatch) not in seen_ops:
|
| 2053 |
+
return True
|
| 2054 |
+
elif op == _ComputationType.BACKWARD:
|
| 2055 |
+
if stage == num_stages_global - 1:
|
| 2056 |
+
return (stage, _ComputationType.FORWARD, microbatch) not in seen_ops
|
| 2057 |
+
return (stage + 1, op, microbatch) not in seen_ops
|
| 2058 |
+
return False
|
| 2059 |
+
|
| 2060 |
+
seen_ops: Set[Tuple[int, _ComputationType, int]] = set()
|
| 2061 |
+
result: Dict[int, List[Optional[_Action]]] = {}
|
| 2062 |
+
next_pointer: Dict[int, int] = {}
|
| 2063 |
+
bubbles_added: Dict[int, int] = {}
|
| 2064 |
+
total_bubbles_added = 0
|
| 2065 |
+
|
| 2066 |
+
for rank in range(self.pp_group_size):
|
| 2067 |
+
result[rank] = []
|
| 2068 |
+
next_pointer[rank] = 0
|
| 2069 |
+
bubbles_added[rank] = 0
|
| 2070 |
+
|
| 2071 |
+
while True:
|
| 2072 |
+
should_stop = True
|
| 2073 |
+
|
| 2074 |
+
temp_seen_ops: Set[Tuple[int, _ComputationType, int]] = set()
|
| 2075 |
+
|
| 2076 |
+
for rank in range(self.pp_group_size):
|
| 2077 |
+
timestamp = next_pointer[rank]
|
| 2078 |
+
if timestamp >= len(actions[rank]):
|
| 2079 |
+
continue
|
| 2080 |
+
|
| 2081 |
+
should_stop = False
|
| 2082 |
+
|
| 2083 |
+
if actions[rank][timestamp] is not None:
|
| 2084 |
+
temp_action = actions[rank][timestamp]
|
| 2085 |
+
assert temp_action is not None
|
| 2086 |
+
stage_index, op, microbatch = temp_action
|
| 2087 |
+
if not need_bubble(
|
| 2088 |
+
stage_index, op, microbatch, num_stages_global, seen_ops
|
| 2089 |
+
):
|
| 2090 |
+
result[rank].append(actions[rank][timestamp])
|
| 2091 |
+
if microbatch is not None:
|
| 2092 |
+
temp_seen_ops.add((stage_index, op, microbatch))
|
| 2093 |
+
next_pointer[rank] += 1
|
| 2094 |
+
else:
|
| 2095 |
+
result[rank].append(None)
|
| 2096 |
+
bubbles_added[rank] += 1
|
| 2097 |
+
else:
|
| 2098 |
+
next_pointer[rank] += 1
|
| 2099 |
+
result[rank].append(None)
|
| 2100 |
+
|
| 2101 |
+
seen_ops.update(temp_seen_ops)
|
| 2102 |
+
if should_stop:
|
| 2103 |
+
break
|
| 2104 |
+
|
| 2105 |
+
if total_bubbles_added > 0:
|
| 2106 |
+
logger.warning(
|
| 2107 |
+
"Non zero bubbles added: total_bubbles_added=%s bubbles_added=%s",
|
| 2108 |
+
total_bubbles_added,
|
| 2109 |
+
bubbles_added,
|
| 2110 |
+
)
|
| 2111 |
+
return result
|
| 2112 |
+
|
| 2113 |
+
|
| 2114 |
+
class ScheduleInterleavedZeroBubble(ScheduleFlexibleInterleaved1F1B):
|
| 2115 |
+
"""
|
| 2116 |
+
The Interleaved Zero Bubble schedule.
|
| 2117 |
+
See https://arxiv.org/pdf/2401.10241 for details.
|
| 2118 |
+
Will perform one forward and one backward on inputs for the microbatches in steady
|
| 2119 |
+
state and supports multiple stages per rank. Uses the backward for weights to fill in
|
| 2120 |
+
the pipeline bubble.
|
| 2121 |
+
"""
|
| 2122 |
+
|
| 2123 |
+
def __init__(
|
| 2124 |
+
self,
|
| 2125 |
+
stages: List[_PipelineStageBase],
|
| 2126 |
+
n_microbatches: int,
|
| 2127 |
+
loss_fn: Optional[Callable] = None,
|
| 2128 |
+
args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
|
| 2129 |
+
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
|
| 2130 |
+
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
| 2131 |
+
):
|
| 2132 |
+
super().__init__(
|
| 2133 |
+
stages=stages,
|
| 2134 |
+
n_microbatches=n_microbatches,
|
| 2135 |
+
loss_fn=loss_fn,
|
| 2136 |
+
args_chunk_spec=args_chunk_spec,
|
| 2137 |
+
kwargs_chunk_spec=kwargs_chunk_spec,
|
| 2138 |
+
output_merge_spec=output_merge_spec,
|
| 2139 |
+
enable_zero_bubble=True,
|
| 2140 |
+
)
|
| 2141 |
+
|
| 2142 |
+
|
| 2143 |
+
def get_schedule_class(schedule_name: str):
|
| 2144 |
+
"""
|
| 2145 |
+
Maps a schedule name to its corresponding class object.
|
| 2146 |
+
|
| 2147 |
+
Args:
|
| 2148 |
+
schedule_name (str): The name of the schedule.
|
| 2149 |
+
"""
|
| 2150 |
+
schedule_map = {
|
| 2151 |
+
"1F1B": Schedule1F1B,
|
| 2152 |
+
"Interleaved1F1B": ScheduleInterleaved1F1B,
|
| 2153 |
+
"GPipe": ScheduleGPipe,
|
| 2154 |
+
"FlexibleInterleaved1F1B": ScheduleFlexibleInterleaved1F1B,
|
| 2155 |
+
"LoopedBFS": ScheduleLoopedBFS,
|
| 2156 |
+
"InterleavedZeroBubble": ScheduleInterleavedZeroBubble,
|
| 2157 |
+
"PipelineScheduleSingle": PipelineScheduleSingle,
|
| 2158 |
+
"PipelineScheduleMulti": PipelineScheduleMulti,
|
| 2159 |
+
}
|
| 2160 |
+
if schedule_name not in schedule_map:
|
| 2161 |
+
raise ValueError(f"Unknown schedule name: {schedule_name}")
|
| 2162 |
+
return schedule_map[schedule_name]
|
.venv/lib/python3.11/site-packages/torch/distributed/pipelining/stage.py
ADDED
|
@@ -0,0 +1,1468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
| 3 |
+
import logging
|
| 4 |
+
import operator
|
| 5 |
+
from abc import ABC, abstractmethod
|
| 6 |
+
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
import torch.fx as fx
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from torch._subclasses.fake_tensor import FakeTensor
|
| 13 |
+
from torch.distributed._composable.fsdp.fully_shard import FSDPModule, fully_shard
|
| 14 |
+
from torch.fx.node import map_aggregate
|
| 15 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 16 |
+
|
| 17 |
+
from ._backward import stage_backward, stage_backward_input, stage_backward_weight
|
| 18 |
+
from ._debug import map_debug_info
|
| 19 |
+
from ._utils import flatten_args, PipeInfo, validate_tensors_metadata
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
"PipelineStage",
|
| 24 |
+
"build_stage",
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class _RootArgPlaceholder:
    """
    Placeholder for model-level inputs.

    Used in `args_recv_info` to mark positions that are fed directly by the
    caller rather than received from a peer stage.
    """

    def __init__(self, tensor):
        # Keep only a meta-device copy: shape/dtype/layout matter, data does not.
        self.meta = tensor.to("meta")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class _RecvInfo:
|
| 40 |
+
"""
|
| 41 |
+
Represents a stage input.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
input_name: str,
|
| 47 |
+
source: int,
|
| 48 |
+
buffer: torch.Tensor,
|
| 49 |
+
):
|
| 50 |
+
# Name of this input
|
| 51 |
+
self.input_name = input_name
|
| 52 |
+
# Stage index of the source of this input
|
| 53 |
+
self.source = source
|
| 54 |
+
# Buffer to receive the input into.
|
| 55 |
+
self.buffer = buffer
|
| 56 |
+
|
| 57 |
+
def __repr__(self):
|
| 58 |
+
return f"_RecvInfo(input={self.input_name}, source={self.source}, shape={self.buffer.size()})"
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# An input can be either a received activation (from a peer stage) or a
# model-level input supplied directly by the caller.
InputInfo = Union[_RecvInfo, _RootArgPlaceholder]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _make_tensor_from_meta(
|
| 66 |
+
example: Union[torch.Tensor, FakeTensor],
|
| 67 |
+
device: torch.device,
|
| 68 |
+
) -> torch.Tensor:
|
| 69 |
+
"""
|
| 70 |
+
Create a real tensor from a tensor.
|
| 71 |
+
"""
|
| 72 |
+
return torch.empty(
|
| 73 |
+
example.size(),
|
| 74 |
+
dtype=example.dtype,
|
| 75 |
+
layout=example.layout,
|
| 76 |
+
device=device,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class _PipelineStageBase(ABC):
|
| 81 |
+
"""
|
| 82 |
+
Base class for pipeline stages.
|
| 83 |
+
Defines or implements common methods used by the `_PipelineStage` used by
|
| 84 |
+
the tracing frontend and `PipelineStage` used by manual frontend.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
    def __init__(
        self,
        submodule: torch.nn.Module,
        stage_index: int,
        num_stages: int,
        device: torch.device,
        group: Optional[dist.ProcessGroup] = None,
        dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
    ):
        """
        Args:
            submodule (torch.nn.Module): The module to be executed in this stage.
            stage_index (int): The index of this stage.
            num_stages (int): The total number of stages in this pipeline.
            device (torch.device): The device to run this stage on.
            group (Optional[dist.ProcessGroup]): The process group to use for communication.
                If `None`, the default process group will be used.
                Default: `None`.
            dw_builder (Optional[Callable[[], Callable[..., None]]): If provided, dw_runner is a builder function
                that will build a new dw_runner function that will run parts of module backward that were intentionally
                skipped during the module's actual backward pass. The builder must be invoked by stage after stage runs
                model backwards, and stage should save the latest dw_runner to run during weight pass.
                If not provided, a dw_runner will be generated automatically by traversing the autograd graph.
                When used with schedules that only have F and B steps, the fresh dw_runner function will be called as
                part of B.
                When used with F,B,W schedules, the dw_runner function implements 'W'.

        Raises:
            ValueError: If `stage_index` is not in `[0, num_stages)`.
            RuntimeError: If the process group is larger than `num_stages`.
        """
        super().__init__()
        if stage_index >= num_stages:
            raise ValueError(
                f"Stage index {stage_index} is out of range of {num_stages}"
            )

        self.submod = submodule
        self.stage_index = stage_index
        self.num_stages = num_stages
        self.device = device
        self.group = group

        self.dw_builder = dw_builder

        # backward state
        self.backward_state: Dict[int, Tuple[Any, ...]] = {}

        # store dw_runner per microbatch_id
        self.dw_runner: Dict[int, Callable[..., None]] = {}

        # `group_rank` is rank in process group `group`.
        self.group_rank = dist.get_rank(self.group)
        self.group_size = dist.get_world_size(self.group)
        if self.group_size > self.num_stages:
            raise RuntimeError(
                f"Pipeline group size {self.group_size} cannot be larger than number of stages {self.num_stages}"
            )

        # Run time states
        # Frozen meta description of this stage's outputs (set once, validated after).
        self._outputs_meta: Optional[Tuple[torch.Tensor, ...]] = None
        # map microbatch ID to list of forward tensor args
        self.fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {}
        # Caching chunk outputs for final output merge or reduction
        self.output_chunks: List[Any] = []

        # Initialize has_backward to false; this will be set to true if loss
        # function is passed to pipeline schedule
        self.has_backward = False
        # Log prefix
        self.log_prefix = f"[Stage {self.stage_index}]"

        # Forward infra
        self.args_recv_info: Dict[int, Tuple[InputInfo, ...]] = {}
        self.set_requires_grad: Dict[int, bool] = {}
        self.act_send_info: Dict[int, List] = {}

        # Backward infra will created lazily
        self.grad_recv_info: Dict = {}
        self.grad_send_info: Optional[List] = None

        # Number of backward chunks seen. This is used to determine when to do
        # grad reduction in DDP or FSDP.
        self._seen_bwd_chunks = 0

        # To be populated later by the Schedule
        self.chunks: Optional[int] = None
        # Default round-robin mapping from stage index to rank within `group`;
        # a schedule may overwrite this mapping later.
        self.stage_index_to_group_rank: Dict[int, int] = {
            i: i % self.group_size for i in range(self.num_stages)
        }
|
| 173 |
+
|
| 174 |
+
    @property
    def has_backward(self) -> bool:
        """
        Returns true if this stage has a backward pass.
        """
        return self._has_backward

    @has_backward.setter
    def has_backward(self, has_backward: bool):
        # Set to True by the schedule when a loss function is provided.
        self._has_backward = has_backward
|
| 184 |
+
|
| 185 |
+
@property
|
| 186 |
+
def is_first(self):
|
| 187 |
+
"""
|
| 188 |
+
Returns true if this stage is the first stage in the pipeline.
|
| 189 |
+
"""
|
| 190 |
+
return self.stage_index == 0
|
| 191 |
+
|
| 192 |
+
@property
|
| 193 |
+
def is_last(self):
|
| 194 |
+
"""
|
| 195 |
+
Returns true if this stage is the last stage in the pipeline.
|
| 196 |
+
"""
|
| 197 |
+
return self.stage_index == self.num_stages - 1
|
| 198 |
+
|
| 199 |
+
def _check_chunk_id(self, chunk_id: int):
|
| 200 |
+
if self.chunks is None:
|
| 201 |
+
raise RuntimeError(
|
| 202 |
+
"Attempted to access chunk_id before chunks have been configured."
|
| 203 |
+
)
|
| 204 |
+
if chunk_id >= self.chunks:
|
| 205 |
+
raise RuntimeError(
|
| 206 |
+
f"Chunk id {chunk_id} is out of range [0, {self.chunks})"
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
    def _configure_outputs_meta(self, outputs_meta: Tuple[torch.Tensor, ...]):
        """
        Track the output shapes/dtype of this stage since they determine the send operation(s) which must match
        recv operations of the next stage. The next stage _will_ be freezing its recv buffers based on its initial
        configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches
        which could show up as hangs, silent corruption, or other errors.
        """
        # One-shot: reconfiguring would invalidate the peer stage's frozen buffers.
        assert (
            self._outputs_meta is None
        ), "Attempting to reconfigure output_meta, which is not supported"
        self._outputs_meta = tuple(outputs_meta)  # type: ignore[assignment]
|
| 220 |
+
|
| 221 |
+
    def get_outputs_meta(self) -> Tuple[torch.Tensor, ...]:
        """Get the output metadata (meta tensors) representing the outputs of this stage"""
        assert (
            self._outputs_meta is not None
        ), "Attempted to get_outputs_meta() without configuring output meta"
        return self._outputs_meta
|
| 227 |
+
|
| 228 |
+
def _create_grad_send_info(
|
| 229 |
+
self,
|
| 230 |
+
args_recv_info: Tuple,
|
| 231 |
+
) -> List[Optional[int]]:
|
| 232 |
+
"""
|
| 233 |
+
Create a list of stage indices to send gradients to.
|
| 234 |
+
"""
|
| 235 |
+
grad_send_info: List[Optional[int]] = []
|
| 236 |
+
|
| 237 |
+
def map_recv_to_send(a):
|
| 238 |
+
# Note: we send gradients back to previous stage as long as in
|
| 239 |
+
# forward it is a received input, regardless of whether it requires
|
| 240 |
+
# grad. It is up to the previous stage to disgard this gradient.
|
| 241 |
+
if isinstance(a, _RecvInfo):
|
| 242 |
+
grad_send_info.append(a.source)
|
| 243 |
+
return a.source
|
| 244 |
+
else:
|
| 245 |
+
grad_send_info.append(None)
|
| 246 |
+
return None
|
| 247 |
+
|
| 248 |
+
map_aggregate(args_recv_info, map_recv_to_send)
|
| 249 |
+
|
| 250 |
+
logger.debug("%s Grad send info: %s", self.log_prefix, grad_send_info)
|
| 251 |
+
return grad_send_info
|
| 252 |
+
|
| 253 |
+
    @abstractmethod
    def _prepare_forward_infra(self, num_microbatches: int):
        # Implemented by subclasses: set up forward-communication state for
        # `num_microbatches` chunks (e.g. populating `args_recv_info`).
        raise NotImplementedError
|
| 256 |
+
|
| 257 |
+
    def _prepare_backward_infra(self, num_microbatches: int):
        """
        Set up backward-communication state for `num_microbatches` chunks.
        """
        # TODO: this is needed for backward_maybe_with_nosync
        self.chunks = num_microbatches

        for mb_index in range(num_microbatches):
            # `grad_recv_info` is a mirror of `act_send_info`
            self.grad_recv_info[mb_index] = self._create_grad_recv_info(
                self.act_send_info
            )
|
| 266 |
+
|
| 267 |
+
    @abstractmethod
    def _create_grad_recv_info(
        self,
        act_send_info: Dict,
    ) -> Tuple[_RecvInfo, ...]:
        # Implemented by subclasses: build recv infos for incoming gradients,
        # mirroring the activation send info.
        raise NotImplementedError
|
| 273 |
+
|
| 274 |
+
def _get_recv_ops(
|
| 275 |
+
self,
|
| 276 |
+
recv_infos: Tuple[InputInfo, ...],
|
| 277 |
+
) -> List[dist.P2POp]:
|
| 278 |
+
"""
|
| 279 |
+
Helper function shared by `get_fwd_recv_ops` and `get_bwd_recv_ops`.
|
| 280 |
+
Returns a list of ops that correspond to the recv infos.
|
| 281 |
+
"""
|
| 282 |
+
ops: List[dist.P2POp] = []
|
| 283 |
+
for info in recv_infos:
|
| 284 |
+
if not isinstance(info, _RecvInfo):
|
| 285 |
+
continue
|
| 286 |
+
|
| 287 |
+
peer_rank = self.stage_index_to_group_rank[info.source]
|
| 288 |
+
peer_global_rank = (
|
| 289 |
+
peer_rank
|
| 290 |
+
if self.group is None
|
| 291 |
+
else dist.get_global_rank(self.group, peer_rank)
|
| 292 |
+
) # TODO
|
| 293 |
+
ops.append(
|
| 294 |
+
dist.P2POp(dist.irecv, info.buffer, peer_global_rank, self.group)
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
return ops
|
| 298 |
+
|
| 299 |
+
    def get_fwd_recv_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]:
        """
        Returns a list of ops that are needed to receive the input arguments
        for this stage.
        """
        recv_infos: Tuple[InputInfo, ...] = self.args_recv_info[fwd_chunk_id]

        # In case there is backward pass, set requires_grad for receive buffers
        # before first forward
        # NOTE(review): `set_requires_grad[fwd_chunk_id]` is only read here and
        # never updated in this method — presumably it is maintained where
        # `args_recv_info` is built; confirm. If not, `requires_grad_` is
        # simply (harmlessly) re-applied on every call.
        if self.has_backward and not self.set_requires_grad[fwd_chunk_id]:
            for a in recv_infos:
                if isinstance(a, _RecvInfo):
                    a.buffer.requires_grad_(True)

        return self._get_recv_ops(recv_infos)
|
| 314 |
+
|
| 315 |
+
def get_bwd_recv_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]:
|
| 316 |
+
"""
|
| 317 |
+
Returns a list of ops that are needed to receive the gradients
|
| 318 |
+
for this stage.
|
| 319 |
+
"""
|
| 320 |
+
if not self.has_backward or self.is_last:
|
| 321 |
+
return []
|
| 322 |
+
|
| 323 |
+
recv_infos = self.grad_recv_info[bwd_chunk_id]
|
| 324 |
+
return self._get_recv_ops(recv_infos)
|
| 325 |
+
|
| 326 |
+
    def get_fwd_send_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]:
        """
        Get the activation send ops for current stage's forward.

        One isend is created per (output index, destination stage) pair in
        `act_send_info`; `None` destinations are skipped.
        """
        output = self.output_chunks[fwd_chunk_id]
        # Unify output form to tuple for easy correspondance with
        # `act_send_info`
        output_tuple = output if type(output) is tuple else (output,)

        ops: List[dist.P2POp] = []

        for idx, out in enumerate(output_tuple):
            dst_stages = self.act_send_info[idx]
            for dst in dst_stages:
                if dst is None:
                    continue
                logger.debug(
                    "%s Sending tensor to Stage %s: %s",
                    self.log_prefix,
                    dst,
                    out.size(),
                )
                # Translate stage index -> rank within group -> global rank.
                peer_rank = self.stage_index_to_group_rank[dst]
                peer_global_rank = (
                    peer_rank
                    if self.group is None
                    else dist.get_global_rank(self.group, peer_rank)
                )  # TODO
                ops.append(dist.P2POp(dist.isend, out, peer_global_rank, self.group))

        return ops
|
| 357 |
+
|
| 358 |
+
    def get_bwd_send_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]:
        """
        Get the gradient send ops for current stage's backward.

        NOTE(review): reads `self.grads_input`, which is populated outside this
        view (presumably by the backward step for this chunk) — confirm it is
        up to date for `bwd_chunk_id` before these ops are issued.
        """
        self._check_chunk_id(bwd_chunk_id)

        # First stage has no upstream to send gradients to.
        if not self.has_backward or self.is_first:
            return []

        # Create bwd send infra lazily
        if self.grad_send_info is None:
            # Send info for input grads during backward:
            # List of destinations corresponding to input grads
            # Can be None if an input has no grad
            # `grad_send_info` is a mirror of `args_recv_info`
            self.grad_send_info = self._create_grad_send_info(self.args_recv_info[0])

        ops: List[dist.P2POp] = []
        for grad, grad_recv_stage in zip(self.grads_input, self.grad_send_info):
            if isinstance(grad, torch.Tensor) and grad_recv_stage is not None:
                logger.debug(
                    "%s Sending gradient to Stage %s: %s",
                    self.log_prefix,
                    grad_recv_stage,
                    grad.size(),
                )
                peer_rank = self.stage_index_to_group_rank[grad_recv_stage]
                peer_global_rank = (
                    peer_rank
                    if self.group is None
                    else dist.get_global_rank(self.group, peer_rank)
                )  # TODO
                ops.append(dist.P2POp(dist.isend, grad, peer_global_rank, self.group))
            else:
                # Only a (None grad, None destination) pair is a legitimate
                # no-op; anything else is a mismatch between the computed grads
                # and the send plan.
                if not (grad is None and grad_recv_stage is None):
                    raise RuntimeError(
                        f"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} "
                        f"and is expecting to send gradients to stage {grad_recv_stage}"
                    )
        return ops
|
| 398 |
+
|
| 399 |
+
def clear_runtime_states(self) -> None:
|
| 400 |
+
"""
|
| 401 |
+
Clear runtime states of the stage.
|
| 402 |
+
"""
|
| 403 |
+
# map microbatch ID to list of forward tensor args
|
| 404 |
+
self.fwd_cache.clear()
|
| 405 |
+
# Caching chunk outputs for final output merge or reduction
|
| 406 |
+
self.output_chunks.clear()
|
| 407 |
+
# Reset bwd chunk counter
|
| 408 |
+
self._seen_bwd_chunks = 0
|
| 409 |
+
|
| 410 |
+
# Clear grad of input buffers in between schedule steps. This is because
|
| 411 |
+
# `torch.autograd.backward()` will accumulate gradients into leaf
|
| 412 |
+
# tensors by default. For gradients to pass back to previous stages, we
|
| 413 |
+
# don't want such accumulation.
|
| 414 |
+
for recv_tuple in self.args_recv_info.values(): # iterate over all chunks
|
| 415 |
+
for a in recv_tuple: # iterate over all input args
|
| 416 |
+
if isinstance(a, _RecvInfo):
|
| 417 |
+
# Set to None is the newer and recommended way to clear grads, compared to `zero_()`.
|
| 418 |
+
# See https://github.com/pytorch/pytorch/pull/92731
|
| 419 |
+
a.buffer.grad = None
|
| 420 |
+
|
| 421 |
+
def _map_tensor_from_recv_info(
|
| 422 |
+
self,
|
| 423 |
+
recv_infos: Tuple[InputInfo, ...],
|
| 424 |
+
):
|
| 425 |
+
"""
|
| 426 |
+
Map tensors from recv infos to a list.
|
| 427 |
+
"""
|
| 428 |
+
|
| 429 |
+
def get_recv_tensor(info):
|
| 430 |
+
if isinstance(info, _RecvInfo):
|
| 431 |
+
return info.buffer
|
| 432 |
+
else:
|
| 433 |
+
raise AssertionError(f"Expected _RecvInfo but got {type(info)}")
|
| 434 |
+
|
| 435 |
+
tensors = map_aggregate(
|
| 436 |
+
recv_infos,
|
| 437 |
+
get_recv_tensor,
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
return tensors
|
| 441 |
+
|
| 442 |
+
def _retrieve_recv_activations(self, fwd_chunk_id: int):
|
| 443 |
+
"""
|
| 444 |
+
Retrieve the activations received for the current stage during forward.
|
| 445 |
+
"""
|
| 446 |
+
recv_infos = self.args_recv_info[fwd_chunk_id]
|
| 447 |
+
activations = self._map_tensor_from_recv_info(recv_infos)
|
| 448 |
+
return activations
|
| 449 |
+
|
| 450 |
+
def _retrieve_recv_grads(
|
| 451 |
+
self,
|
| 452 |
+
bwd_chunk_id: int,
|
| 453 |
+
):
|
| 454 |
+
"""
|
| 455 |
+
Retrieve the gradients received for the current stage during backward.
|
| 456 |
+
"""
|
| 457 |
+
recv_infos = self.grad_recv_info[bwd_chunk_id]
|
| 458 |
+
grads = self._map_tensor_from_recv_info(recv_infos)
|
| 459 |
+
return grads
|
| 460 |
+
|
| 461 |
+
def forward_maybe_with_nosync(self, *args, **kwargs):
|
| 462 |
+
# If submod is wrapped with DDP, we use the `no_sync` context manager to
|
| 463 |
+
# avoid gradient all-reduce per microbatch
|
| 464 |
+
if isinstance(self.submod, DistributedDataParallel):
|
| 465 |
+
with self.submod.no_sync(): # type: ignore[operator]
|
| 466 |
+
out_val = self.submod(*args, **kwargs)
|
| 467 |
+
else:
|
| 468 |
+
out_val = self.submod(*args, **kwargs)
|
| 469 |
+
return out_val
|
| 470 |
+
|
| 471 |
+
    def backward_maybe_with_nosync(self, backward_type, bwd_kwargs: Dict):
        """
        Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the
        other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but
        there are additional state-variables and performance considerations depending on the data parallelism used.
        This helper should adapt any pipeline parallel schedule to work with common/supported data parallel libraries.

        Args:
            backward_type: One of "full", "input", or "weight", selecting which
                backward helper to run.
            bwd_kwargs: Arguments for the chosen backward helper; must contain
                "full_backward", plus "stage_output"/"output_grads"/"input_values"
                for "full"/"input", or "param_groups" for "weight".

        Returns:
            A `(grads, param_groups)` pair; `param_groups` is None except for
            the "input" backward, which returns both.
        """
        full_backward = bwd_kwargs["full_backward"]
        if full_backward:
            last_backward = self._seen_bwd_chunks == self.chunks - 1  # type: ignore[operator]
        else:
            # For backwards are split into weight and input, we will see twice as many bwd_chunks
            last_backward = self._seen_bwd_chunks == 2 * self.chunks - 1  # type: ignore[operator]

        def perform_backward(backward_type):
            # Returns a zero-arg callable so the DP-specific wrappers below can
            # decide the context in which the backward actually runs.
            if backward_type == "full":
                return lambda: stage_backward(
                    bwd_kwargs["stage_output"],
                    bwd_kwargs["output_grads"],
                    bwd_kwargs["input_values"],
                )
            elif backward_type == "input":
                return lambda: stage_backward_input(
                    bwd_kwargs["stage_output"],
                    bwd_kwargs["output_grads"],
                    bwd_kwargs["input_values"],
                    self.submod.parameters(),
                )
            elif backward_type == "weight":
                return lambda: stage_backward_weight(
                    self.submod.parameters(), bwd_kwargs["param_groups"]
                )
            else:
                raise RuntimeError(f"Unknown backward type: {backward_type}")

        # If submod is wrapped by DDP
        if isinstance(self.submod, DistributedDataParallel):
            if last_backward:
                # Last chunk, prepare for gradient reduction
                # HACK: reaching into DDP implementation details here. Is there a better way?
                self.submod.reducer.prepare_for_backward(  # type: ignore[union-attr, operator]
                    list(
                        torch.nn.parallel.distributed._find_tensors(  # type: ignore[attr-defined]
                            bwd_kwargs["stage_output"]
                        )
                    )
                )
                result = perform_backward(backward_type)()
            else:
                # Intermediate chunk: accumulate locally without all-reduce.
                with self.submod.no_sync():  # type: ignore[operator]
                    result = perform_backward(backward_type)()
        # If submod is a FSDP module
        elif isinstance(self.submod, FSDPModule):
            # Defer reshard/reduce until the last chunk's backward.
            self.submod.set_is_last_backward(False)
            self.submod.set_reshard_after_backward(False)
            self.submod.set_requires_gradient_sync(False)
            result = perform_backward(backward_type)()
            if last_backward:
                # Manually call post backward for FSDP
                def run_post_backward(fsdp_module: FSDPModule) -> None:
                    fsdp_module.set_is_last_backward(True)
                    fsdp_module.set_reshard_after_backward(True)
                    fsdp_module.set_requires_gradient_sync(True)
                    fsdp_state = fully_shard.state(fsdp_module)
                    for state in fsdp_state._state_ctx.all_states:
                        if state._fsdp_param_group:
                            state._fsdp_param_group.post_backward()

                run_post_backward(self.submod)
        else:
            # Non-DP submodule, regular backward
            result = perform_backward(backward_type)()

        self._seen_bwd_chunks += 1

        if isinstance(result, tuple) and len(result) == 2:
            # for stage_backward_input()
            grads, param_groups = result
        else:
            grads, param_groups = result, None

        return grads, param_groups
|
| 553 |
+
|
| 554 |
+
    def forward_one_chunk(
        self,
        fwd_chunk_id: int,
        args: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Perform the forward pass on this stage for one microbatch.

        `args` and `kwargs` are the inputs coming from *outside* this stage;
        in most cases they only apply to the first stage (later stages receive
        their activations from the previous stage instead).

        Args:
            fwd_chunk_id: index of the microbatch being processed.
            args: positional inputs for the first stage.
            kwargs: keyword inputs for the first stage (may be None).

        Returns:
            The submodule's output for this microbatch (tuple-ified internally
            for send bookkeeping, but returned in its original form).

        Raises:
            RuntimeError: if the wrapped submodule's forward raises; the
                original exception is chained as the cause.
        """

        if self.is_first:
            # First stage doesn't need to receive anything
            composite_args = args
            composite_kwargs = kwargs or {}
        else:
            # Receive activations for this chunk
            # Activations only come in args form
            composite_args = self._retrieve_recv_activations(fwd_chunk_id)
            composite_kwargs = {}

        # Only validates on the first stage; later stages return immediately.
        self._validate_fwd_input(args, kwargs)

        # Compute forward
        try:
            output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs)

        except Exception as e:
            exc_msg = f"""
            {self.log_prefix} failed to run forward:
            args: {map_debug_info(composite_args)}
            kwargs: {map_debug_info(composite_kwargs)}
            """
            raise RuntimeError(exc_msg) from e

        if type(output) is list:
            # HACK: this is a hacky workaround for the fact that export creates
            # output in list format
            output = tuple(output)

        # Unify output form to tuple for easy correspondence with
        # `act_send_info`
        output_tuple = output if type(output) is tuple else (output,)
        # Prepare for final output merge or reduction
        self.output_chunks.append(output)

        # Save activations and inputs for backward
        flat_args = flatten_args(composite_args)
        flat_kwargs = flatten_args(composite_kwargs)
        flatten_input_tensors = flat_args + flat_kwargs
        # Cached entries are popped again by backward_one_chunk for the same
        # chunk id.
        self.fwd_cache[fwd_chunk_id] = (
            output_tuple,  # stage_output
            flatten_input_tensors,  # input_values
        )

        logger.debug(
            "%s Forwarded chunk %s, outputs: %s",
            self.log_prefix,
            fwd_chunk_id,
            map_debug_info(output),
        )
        self._validate_fwd_outputs(output_tuple)
        return output
|
| 618 |
+
|
| 619 |
+
    def backward_one_chunk(
        self, bwd_chunk_id: int, loss=None, full_backward: bool = True
    ):
        """
        Perform backward pass on the module.
        This should only be called once per microbatch.

        If full_backward is True (the default), the full backward pass including weight and input gradients will be run,
        and it is an error to call `backward_weight_one_chunk` for this bwd_chunk_id.

        If full_backward is False, it is optional that `dw_runner` was provided to the PipelineStage at __init__ time,
        and a subsequent call to `backward_weight_one_chunk` is required to invoke dw_runner and complete the backward.

        Args:
            bwd_chunk_id: index of the microbatch whose forward state (cached
                by `forward_one_chunk`) should be consumed.
            loss: the loss tensor; only meaningful on the last stage, which
                seeds backward from the loss instead of received gradients.
            full_backward: whether to compute both input and weight gradients
                now, or defer weight gradients to `backward_weight_one_chunk`.
        """
        self._check_chunk_id(bwd_chunk_id)

        # Pop (not get): each chunk's forward state is consumed exactly once.
        (
            stage_output,
            input_values,
        ) = self.fwd_cache.pop(bwd_chunk_id)

        # Compute backward
        if self.is_last:
            # Last stage computes gradients from loss and has no gradients from
            # next stage
            bwd_kwargs = {
                "stage_output": loss,
                "output_grads": None,
                "input_values": input_values,
            }
        else:
            # Otherwise, receive gradients from next stage
            grads_output = self._retrieve_recv_grads(bwd_chunk_id)
            # If an input to the pipeline requires gradient,
            # `torch.autograd.backward` will accumulate the gradient into the
            # `.grad` field of such input
            bwd_kwargs = {
                "stage_output": stage_output,
                "output_grads": grads_output,
                "input_values": input_values,
            }

        # Save full_backward
        bwd_kwargs["full_backward"] = full_backward

        # Custom backward function
        if self.dw_builder:
            # TODO: We may want to change our semantics so we are allowed to ignore
            # the 'dw_builder' and call full_backward directly when it is a full_backward op.
            self.grads_input, _ = self.backward_maybe_with_nosync("full", bwd_kwargs)
            if full_backward:
                self.dw_builder()()
            else:
                # Defer the weight-gradient step to backward_weight_one_chunk.
                self.dw_runner[bwd_chunk_id] = self.dw_builder()
        else:
            if full_backward:
                self.grads_input, _ = self.backward_maybe_with_nosync(
                    "full", bwd_kwargs
                )
            else:
                # perform the partial backwards for the inputs with a custom backward function
                # when the "stage_output" is a loss, then it is a tensor, otherwise it is a tuple of tensors
                if isinstance(bwd_kwargs["stage_output"], torch.Tensor):
                    bwd_kwargs["stage_output"] = (bwd_kwargs["stage_output"],)

                grads_input, param_groups = self.backward_maybe_with_nosync(
                    "input", bwd_kwargs
                )

                # TODO: we dont need to save this, add to dw_runner?
                # Stash everything backward_weight_one_chunk needs later.
                self.backward_state[bwd_chunk_id] = (
                    input_values,
                    param_groups,
                    bwd_kwargs["stage_output"],
                    bwd_kwargs["output_grads"],
                )
                self.grads_input = grads_input
                # Save a placeholder for the dw_runner
                self.dw_runner[bwd_chunk_id] = lambda: None
        logger.debug("%s Backwarded chunk %s", self.log_prefix, bwd_chunk_id)
|
| 698 |
+
|
| 699 |
+
    def backward_weight_one_chunk(self, bwd_chunk_id: int):
        """
        Run the deferred weight-gradient step for `bwd_chunk_id`.

        Requires a prior `backward_one_chunk(bwd_chunk_id, full_backward=False)`
        call, which either stored a `dw_runner` (when `dw_builder` was given)
        or stashed the autograd state needed to finish the backward here.
        """
        assert bwd_chunk_id in self.dw_runner, (
            f"{self.log_prefix} Attempted to run backward_weight_one_chunk for chunk {bwd_chunk_id}"
            " without first calling `backward_one_chunk(full_backward=False)`"
        )

        if self.dw_builder is not None:
            # User-provided weight-gradient callable; pop so it runs only once.
            self.dw_runner.pop(bwd_chunk_id)()
        else:
            (
                input_values,
                param_groups,
                stage_output,
                output_grads,
            ) = self.backward_state.pop(bwd_chunk_id)

            if self.stage_index != 0:
                bwd_kwargs = {
                    "stage_output": stage_output,
                    "param_groups": param_groups,
                    "full_backward": False,
                }
                weight_grads, _ = self.backward_maybe_with_nosync("weight", bwd_kwargs)
            else:
                # TODO: figure out a better way to do this:
                # if inputs does not require gradient,
                # then the parameter group will not be fully captured during stage_backward_input
                # in this case, we need call grad directly on the parameters
                # To solve: make input fn do the intersect compute and then finish it off during W
                bwd_kwargs = {
                    "stage_output": stage_output,
                    "output_grads": output_grads,
                    "input_values": input_values,
                    "full_backward": False,
                }
                self.backward_maybe_with_nosync("full", bwd_kwargs)
|
| 735 |
+
|
| 736 |
+
def _validate_fwd_input(self, args, kwargs):
|
| 737 |
+
"""Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage."""
|
| 738 |
+
|
| 739 |
+
if self.is_first:
|
| 740 |
+
# TODO why is there a separate recv_info for each pipeline chunk?
|
| 741 |
+
# kwen2501: to avoid passing a `fwd_chunk_id` to this function, we
|
| 742 |
+
# check all chunks against args_recv_info[0]
|
| 743 |
+
expected_args = self.args_recv_info[0]
|
| 744 |
+
else:
|
| 745 |
+
# We don't check inputs for non-0 stages assuming they don't accept
|
| 746 |
+
# user inputs in canonical pipeline scenarios
|
| 747 |
+
return
|
| 748 |
+
|
| 749 |
+
if len(kwargs):
|
| 750 |
+
# TODO- need a mapping of kwarg to position in self.args_recv_info
|
| 751 |
+
# without it, we just validate shapes for args and ignore kwargs
|
| 752 |
+
expected_args = expected_args[: len(expected_args) - len(kwargs)]
|
| 753 |
+
|
| 754 |
+
# TODO- need a mapping of kwarg to position in self.args_recv_info
|
| 755 |
+
# maybe it's impossible to tell whether the len mismatches because
|
| 756 |
+
# (a) the user passed an extra arg or missed an arg
|
| 757 |
+
# (b) the user did not pass a kwarg, which has a default value baked into expected_args
|
| 758 |
+
expected_tensors_meta = [
|
| 759 |
+
e.meta if isinstance(e, _RootArgPlaceholder) else e.buffer
|
| 760 |
+
for e in expected_args
|
| 761 |
+
]
|
| 762 |
+
validate_tensors_metadata(
|
| 763 |
+
f"Stage {self.stage_index} forward inputs", expected_tensors_meta, args
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
def _validate_fwd_outputs(self, outputs: Tuple[torch.Tensor, ...]):
|
| 767 |
+
"""Raises a RuntimeError if this stage produces an output of unexpected shape/dtype.
|
| 768 |
+
Most likely, this could be cause either by incorrect user specification of output shapes, or becuase
|
| 769 |
+
shape inference was done on the original model but then at runtime the model is wrapped with something like
|
| 770 |
+
mixed precision which changes output dtype.
|
| 771 |
+
"""
|
| 772 |
+
expected_tensors_meta = self.get_outputs_meta()
|
| 773 |
+
validate_tensors_metadata(
|
| 774 |
+
f"Stage {self.stage_index} forward outputs", expected_tensors_meta, outputs
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
class _PipelineStage(_PipelineStageBase):
    """A pipeline stage backed by a traced pipeline graph (`PipeInfo`).

    Uses the fx graph in `pipe_info` to figure out which stages this stage
    receives activations from and sends activations to.
    """

    def __init__(
        self,
        stage_module: torch.nn.Module,
        stage_index: int,
        pipe_info: PipeInfo,
        device: torch.device,
        group: Optional[dist.ProcessGroup] = None,
    ):
        """
        Create a pipeline stage given a stage_module to be wrapped by this stage
        and a `pipe_info` describing the stage relationship of the pipeline.

        Args:
            stage_module (torch.nn.Module): the module to be wrapped by this stage
            stage_index (int): the index of this stage in the pipeline
            pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()`
            device (torch.device): the device to be used by this stage
            group (Optional[dist.ProcessGroup]): the process group to be used by this stage

        Raises:
            AssertionError: if the number of `call_module` nodes in the pipe
                graph does not match `pipe_info.num_stages`.
        """
        _PipelineStageBase.__init__(
            self,
            stage_module,
            stage_index,
            pipe_info.num_stages,
            device,
            group,
        )
        self.pipe_info = pipe_info

        # Find stage nodes in graph: each `call_module` node is one stage.
        submod_nodes = [
            node for node in pipe_info.graph.nodes if node.op == "call_module"
        ]
        if len(submod_nodes) != self.num_stages:
            raise AssertionError(
                f"Number of submodules in pipe graph {len(submod_nodes)} does not match number of stages {self.num_stages}"
            )

        # Find my stage node in graph
        self.node = submod_nodes[self.stage_index]
        self.name = self.node.name
        logger.info(
            "[%s] Creating PipelineStage %s for %s",
            self.group_rank,
            stage_index,
            self.name,
        )

        # Create mapping from stage name to stage index
        self.submod_to_stage_index: Dict[str, int] = {}
        for i, node in enumerate(submod_nodes):
            self.submod_to_stage_index.setdefault(node.name, i)

        # Cast submodule to device
        self._move_submod_to_device()

    def _move_submod_to_device(self):
        # Move submodule to indicated device if possible
        # Note: we cannot move meta module to real devices because meta tensors
        # do not support to() method. One needs to do an in-place tensor swap in
        # that case.
        has_meta_param = any(
            isinstance(p, FakeTensor) or p.is_meta for p in self.submod.parameters()
        )
        if has_meta_param:
            logger.debug("%s Found meta parameters!", self.log_prefix)
        else:
            self.submod.to(self.device)

    def _prepare_forward_infra(self, num_microbatches: int):
        """
        Create send/recv infrastructures for activations (during forward)
        """
        # Flag per chunk to keep track of whether we have set `requires_grad`
        # for receive buffers. Format: {chunk : Boolean}
        for chunk in range(num_microbatches):
            self.args_recv_info[chunk] = self._create_act_recv_info()
            self.set_requires_grad[chunk] = False

        # Send info during forward for each activation
        self.act_send_info = self._create_act_send_info()

    def get_stage_index_of_submod(
        self,
        submod_name: str,
    ):
        """
        Given a submodule name, return the stage index of the submodule.

        Raises:
            AssertionError: if `submod_name` is not a known stage submodule.
        """
        if submod_name not in self.submod_to_stage_index:
            raise AssertionError(f"Stage id of {submod_name} not found")

        return self.submod_to_stage_index[submod_name]

    def _create_act_recv_info(
        self,
    ):
        """
        Create a tuple of `_RecvInfo` for inputs to the stage.
        """

        def create_recv_tensor(placeholder, arg_node):
            """
            Create a receive buffer for a placeholder.

            Returns a `_RootArgPlaceholder` for model-level inputs (no recv
            needed) or a `_RecvInfo` carrying the source stage and a buffer.
            """
            example_value = placeholder.meta["val"]
            if arg_node.op == "placeholder":
                # This is a root level placeholder, thus an input argument to the entire model.
                # We are likely at stage 0, hence no need to create a receive buffer.
                return _RootArgPlaceholder(example_value)

            # Figure out the source stage of this input
            while arg_node.target is operator.getitem:
                # If the input is a getitem, we need to go deeper
                arg_node = arg_node.args[0]

            assert (
                arg_node.op == "call_module"
            ), f"Expecting call_module, got {arg_node.op}"
            src_stage = self.get_stage_index_of_submod(arg_node.name)

            # Create a receive buffer for this placeholder
            logger.debug(
                "%s Creating recv buffer for input '%s' : %s, %s",
                self.log_prefix,
                placeholder.name,
                example_value.shape,
                example_value.dtype,
            )
            buffer = _make_tensor_from_meta(example_value, self.device)

            return _RecvInfo(
                arg_node.name,
                src_stage,
                buffer,
            )

        args_recv_info: List[InputInfo] = []
        # Filter out placeholder nodes from `self.submod` (a GraphModule)
        placeholders = filter(
            lambda node: node.op == "placeholder", self.submod.graph.nodes
        )
        # `placeholders` are nodes internal to submod.
        # `self.node.args` are dependency nodes in the outer graph.
        # The two are 1:1.
        for placeholder, arg_node in zip(placeholders, self.node.args):
            # Create a receive buffer for this placeholder
            recv_info = create_recv_tensor(placeholder, arg_node)
            args_recv_info.append(recv_info)

        logger.debug(
            "%s Activation recv / args info: %s", self.log_prefix, args_recv_info
        )
        # `args` is a Tuple, hence we will return a Tuple[InputInfo]
        return tuple(args_recv_info)

    def find_dst_rank(
        self,
        user: fx.Node,
    ) -> Optional[int]:
        """
        Find the destination rank of a `user` node.
        If the `user` is not a submod, `None` may be returned.
        """
        if user.op == "call_module":
            # User is a stage (`call_module`)
            return self.get_stage_index_of_submod(user.name)
        else:
            # - If user.op == "output":
            #   No need to send back to rank 0
            # - If user.target is stage_backward:
            #   No need to send assuming submod output is stored locally or
            #   should be re-calculated in case of activation checkpointing
            return None

    def _create_act_send_info(self):
        """
        Create a dict of send info for activations.
        The dict is of the form:
        {
            output_index: [dst_rank_0, dst_rank_1, ...],
            ...
        }
        where the list of `dst_rank`s covers the case where an output value may
        be consumed by multiple stages.
        """
        # Output index: List of receiver ranks
        act_send_info: Dict[int, List] = {}
        out_idx = 0

        for user in self.node.users:
            if user.target is operator.getitem:
                # Recursively find the real destination
                gi_dsts = act_send_info.setdefault(out_idx, [])
                for gi_user in user.users:
                    dst_rank = self.find_dst_rank(gi_user)
                    if dst_rank is not None:
                        gi_dsts.append(dst_rank)
                # Next `getitem` will point to the next output index
                out_idx += 1
            else:
                # In case of single output value, `out_idx` will not increase
                dsts = act_send_info.setdefault(out_idx, [])
                dst_rank = self.find_dst_rank(user)
                if dst_rank is not None:
                    dsts.append(dst_rank)

        output_node = self._get_output_node()
        output_vals: Tuple[torch.Tensor, ...] = tuple(
            v.meta["val"] for v in flatten_args(output_node.args)
        )
        self._configure_outputs_meta(output_vals)

        logger.debug("%s Send info: %s", self.log_prefix, act_send_info)
        return act_send_info

    def _get_output_node(self):
        # The graph is expected to contain exactly one `output` node.
        output_nodes = [node for node in self.submod.graph.nodes if node.op == "output"]
        assert len(output_nodes) == 1
        output_node = output_nodes[0]
        return output_node

    def _create_grad_recv_info(
        self,
        act_send_info: Dict,
    ) -> Tuple[_RecvInfo, ...]:
        """
        Create a tuple of `_RecvInfo` for gradients.
        """
        # Dict[output_index, _RecvInfo]
        grad_recv_info: Dict[int, _RecvInfo] = {}
        output_node = self._get_output_node()

        # The output node may take multiple args, meaning the submod having multiple output values.
        output_vals = flatten_args(output_node.args)

        for out_idx, dst_list in act_send_info.items():
            if not dst_list:
                # No actual receiver for activation so no grad coming back
                continue

            output = output_vals[out_idx]
            example_value = output.meta["val"]
            logger.debug(
                f"{self.log_prefix} Creating grad recv buffer for output {output.name} "  # noqa: G004
                f": {example_value.shape}, {example_value.dtype}"
            )

            # TODO: otherwise needs grad accumulation
            assert len(dst_list) == 1, "Backward of skip connections not supported yet"
            grad_src = dst_list[0]
            grad_recv_info[out_idx] = _RecvInfo(
                f"{grad_src}",  # noqa: G004
                grad_src,
                _make_tensor_from_meta(example_value, self.device),
            )

        # Convert to tuple for convenience in get_ops and retrieve tensor
        grad_recv_info_tuple = tuple(grad_recv_info.values())
        logger.debug("%s Grad recv info: %s", self.log_prefix, grad_recv_info_tuple)
        return grad_recv_info_tuple
|
| 1040 |
+
|
| 1041 |
+
|
| 1042 |
+
# A helper function to create a pipeline stage based on traced pipeline information
|
| 1043 |
+
def build_stage(
    stage_module: torch.nn.Module,
    stage_index: int,
    pipe_info: PipeInfo,
    device: torch.device,
    group: Optional[dist.ProcessGroup] = None,
) -> _PipelineStage:
    """
    Create a pipeline stage given a stage_module to be wrapped by this stage
    and pipeline information.

    Args:
        stage_module (torch.nn.Module): the module to be wrapped by this stage
        stage_index (int): the index of this stage in the pipeline
        pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()`
        device (torch.device): the device to be used by this stage
        group (Optional[dist.ProcessGroup]): the process group to be used by this stage

    Returns:
        _PipelineStage: a pipeline stage that can run with `PipelineSchedules`.
    """
    # Thin factory: all the real work happens in _PipelineStage.__init__.
    return _PipelineStage(stage_module, stage_index, pipe_info, device, group)
|
| 1071 |
+
|
| 1072 |
+
|
| 1073 |
+
# Manual PipelineStage functions and definition

# Fixed length of the int32 metadata tensor exchanged between ranks during
# shape inference (see _create_metadata_tensor); unused trailing slots are
# filled with PLACEHOLDER_VAL.
METADATA_TENSOR_LEN = 100
# Sentinel value marking unused slots in a metadata tensor.
PLACEHOLDER_VAL = -1
|
| 1077 |
+
|
| 1078 |
+
|
| 1079 |
+
def _create_empty_tensors(
|
| 1080 |
+
tensor: Union[torch.Tensor, Iterable[torch.Tensor]], device: torch.device
|
| 1081 |
+
) -> List[torch.Tensor]:
|
| 1082 |
+
"""
|
| 1083 |
+
Creates a list of empty tensors with the same properties (like shape and dtype) as the input tensor(s),
|
| 1084 |
+
and places them on the specified device.
|
| 1085 |
+
Args:
|
| 1086 |
+
tensor (Union[torch.Tensor, List[torch.tensor]]): The input tensor(s).
|
| 1087 |
+
device (torch.device): The device where the new tensors will be placed.
|
| 1088 |
+
Returns:
|
| 1089 |
+
List[torch.Tensor]: A list of empty tensors with the same properties as the input tensor(s).
|
| 1090 |
+
"""
|
| 1091 |
+
if isinstance(tensor, torch.Tensor):
|
| 1092 |
+
return [torch.empty_like(tensor, device=device)]
|
| 1093 |
+
elif isinstance(tensor, (list, tuple)):
|
| 1094 |
+
return [torch.empty_like(t, device=device) for t in tensor]
|
| 1095 |
+
raise TypeError(f"Unsupported type {type(tensor)} cannot create empty tensors")
|
| 1096 |
+
|
| 1097 |
+
|
| 1098 |
+
def _create_metadata_tensor(
    tensors: Optional[List[torch.Tensor]] = None,
    device: Optional[torch.device] = torch.device("cpu"),
) -> torch.Tensor:
    """
    Build a fixed-size int32 tensor encoding the shapes of `tensors` so it can
    be sent over the wire.

    Per-tensor layout is [num_dims, dim1, dim2, ...]; all encodings are
    concatenated and the remaining slots keep PLACEHOLDER_VAL. When `tensors`
    is None or empty, the result is all placeholders.

    Inputs:
        tensors: tensors whose shape dimensions are encoded and concatenated.
        device: device on which the metadata tensor is created.

    Raises:
        ValueError: if the concatenated encoding exceeds METADATA_TENSOR_LEN.
    """
    metadata_tensor = torch.full(
        (METADATA_TENSOR_LEN,),
        PLACEHOLDER_VAL,
        dtype=torch.int32,
        device=device,
    )
    if tensors:
        # Encode each tensor as [num_dims, dim1, dim2, ...].
        shape_encodings = []
        for t in tensors:
            encoding = [len(t.shape), *t.shape]
            shape_encodings.append(
                torch.tensor(encoding, dtype=torch.int32, device=device)
            )
        # Concatenate the data into a single tensor
        data_tensor = torch.cat(shape_encodings)
        dt_shape = data_tensor.shape[0]
        if dt_shape > METADATA_TENSOR_LEN:
            raise ValueError(
                f"Metadata tensor size ({dt_shape}) exceeds maximum allowed length ({METADATA_TENSOR_LEN})."
            )
        metadata_tensor[:dt_shape] = data_tensor
    return metadata_tensor
|
| 1142 |
+
|
| 1143 |
+
|
| 1144 |
+
def _extract_metadata_from_tensor(tensor: torch.Tensor) -> List[torch.Size]:
    """
    Decode a metadata tensor (see `_create_metadata_tensor`) back into a list
    of tensor shapes. Decoding stops at the first PLACEHOLDER_VAL or at the
    end of the tensor.
    """
    shapes: List[torch.Size] = []
    pos = 0
    while pos < len(tensor):
        header = tensor[pos]
        if header == PLACEHOLDER_VAL:
            # Remaining slots are padding.
            break
        num_dims = int(header.item())
        dims = tensor[pos + 1 : pos + 1 + num_dims].tolist()
        shapes.append(torch.Size(dims))
        pos += num_dims + 1
    return shapes
|
| 1156 |
+
|
| 1157 |
+
|
| 1158 |
+
def _get_stage_shapes(
    stage_modules: List[nn.Module],
    stage_ids: List[int],
    num_stages: int,
    rank: int,
    world_size: int,
    device: torch.device,
    microbatch: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
):
    """
    Performs a dry run through all the pipeline stages (a rank can have multiple pipeline stages in the case of
    virtual pipelining) and returns the shape of the inputs and outputs of the module.
    Only the first stage must pass in a microbatch.

    Each rank must call _get_stage_shapes or the program will hang.

    Args:
        stage_modules: The chunks assigned to this rank. The length should be 1 for any
                       non-interleaved schedules and >1 for any interleaved schedules.
        stage_ids: The id of the stages assigned to this rank.
        num_stages: Total number of stages.
        rank: Rank of the current process.
        world_size: Number of processes participating in the pipeline.
        device: Device where the tensors are allocated.

    Returns a dictionary containing the following keys:
        "inputs": Shape of the inputs to the module
        "outputs": Shape of the outputs of the module

    Raises:
        RuntimeError: if this rank holds stage 0 and no `microbatch` is given.
    """

    stage_id_to_shapes: Dict[int, Dict[str, list[torch.Size]]] = {}
    for stage_id, model in zip(stage_ids, stage_modules):
        input_shape_metadata_tensor = _create_metadata_tensor(device=device)
        # TODO: Assumes prev_stage == rank - 1 and next_stage == rank + 1
        prev_rank = (rank - 1) % world_size
        next_rank = (rank + 1) % world_size
        shapes = {}

        # first stage doesn't receive anything and uses a microbatch
        if stage_id == 0:
            if microbatch is None:
                raise RuntimeError("Microbatch is required for first stage")
            example_fwd_inputs = microbatch
            if isinstance(example_fwd_inputs, torch.Tensor):
                example_fwd_inputs = [example_fwd_inputs]
        else:
            # other stages must receive shape information
            # TODO: send/recv should take a group, rather than use the default group
            dist.recv(input_shape_metadata_tensor, prev_rank)
            metadata = _extract_metadata_from_tensor(input_shape_metadata_tensor)
            # Materialize dummy inputs matching the shapes received from the
            # previous stage.
            example_fwd_inputs = [
                torch.empty(shape_list, device=device) for shape_list in metadata
            ]
        shapes["inputs"] = [fwd_input.shape for fwd_input in example_fwd_inputs]

        # perform forward
        # TODO: if forward fails raise a more descriptive error explaining which stage failed
        fwd_outputs = model(*example_fwd_inputs)
        fwd_outputs = _create_empty_tensors(fwd_outputs, device)
        shapes["outputs"] = [fwd_output.shape for fwd_output in fwd_outputs]

        # send shape dims to the next rank (the last stage has no consumer)
        if stage_id != num_stages - 1:
            output_shape_metadata_tensor = _create_metadata_tensor(
                fwd_outputs, device=device
            )
            dist.send(output_shape_metadata_tensor, next_rank)
        stage_id_to_shapes[stage_id] = shapes
    logger.info(stage_id_to_shapes)
    return stage_id_to_shapes
|
| 1228 |
+
|
| 1229 |
+
|
| 1230 |
+
class PipelineStage(_PipelineStageBase):
|
| 1231 |
+
    """
    A class representing a pipeline stage in a pipeline parallelism setup.
    This class is created manually by providing an example input (and optionally output)
    as opposed to the PipelineStage class that is output from pipeline().
    This class extends the `_PipelineStageBase` class and can similarly be used
    in `PipelineSchedule`.

    Args:
        submodule (nn.Module): The PyTorch module wrapped by this stage.
        stage_index (int): The ID of this stage.
        num_stages (int): The total number of stages.
        device (torch.device): The device where this stage is located.
        input_args (Union[torch.Tensor, Tuple[torch.Tensor]], optional): The input arguments for the submodule.
        output_args (Union[torch.Tensor, Tuple[torch.Tensor]], optional): The output arguments for the submodule.
        group (dist.ProcessGroup, optional): The process group for distributed training. If None, default group.
        dw_builder (Callable, optional): Builder for a deferred weight-gradient step; see `backward_weight_one_chunk`.
    """
|
| 1248 |
+
|
| 1249 |
+
    def __init__(
        self,
        submodule: nn.Module,
        stage_index: int,
        num_stages: int,
        device: torch.device,
        input_args: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
        output_args: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = None,
        group: Optional[dist.ProcessGroup] = None,
        dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
    ):
        """Initialize a manual pipeline stage from example input (and optional output) tensors.

        If `output_args` is omitted, a dry-run forward with empty tensors
        shaped like `input_args` infers the output shapes.
        """
        super().__init__(submodule, stage_index, num_stages, device, group, dw_builder)
        self.submod.to(self.device)
        # When we materialize the model partition on cuda, we call reset_parameters() if it is available
        self.inputs: List[torch.Tensor] = []
        self.outputs: List[torch.Tensor] = []

        # Empty tensors mirroring input_args; used as recv buffers / dry-run inputs.
        self.inputs = _create_empty_tensors(input_args, device)

        if output_args is None:
            logger.info("output_args not provided, performing forward using input_args")
            self.outputs = self.submod(*self.inputs)
            # create buffers for the output so that the data is in the correct
            # shape in order to use in p2p op (send)
            self.outputs = _create_empty_tensors(self.outputs, device)
        else:
            self.outputs = _create_empty_tensors(output_args, device)

        self._configure_outputs_meta(tuple(self.outputs))

        # these are the buffers used in backwards send/recv, they are allocated later
        self.outputs_grad: List[torch.Tensor] = []

        def stage_global_rank(peer_rank):
            # Translate a rank within self.group to a global rank; identity
            # when using the default group.
            return (
                peer_rank
                if self.group is None
                else dist.get_global_rank(self.group, peer_rank)
            )

        # Assumes a simple chain topology: prev/next are adjacent group ranks.
        self.prev_stage = stage_global_rank((self.group_rank - 1) % self.group_size)
        self.next_stage = stage_global_rank((self.group_rank + 1) % self.group_size)

        logger.debug(
            f"finished pipeline stage init, {self.stage_index=}, {self.is_first=}, "  # noqa: G004
            f"{self.is_last=}, {self.num_stages=}, "
            f"inputs: {[inp.shape for inp in self.inputs]}, "
            f"output: {[output.shape for output in self.outputs]}"
        )
|
| 1298 |
+
|
| 1299 |
+
def _prepare_forward_infra(self, num_microbatches: int) -> None:
|
| 1300 |
+
# Receive info during forward
|
| 1301 |
+
# TODO: create args_recv_info lazily? (same needed for PipelineStage)
|
| 1302 |
+
for chunk_id in range(num_microbatches):
|
| 1303 |
+
self.set_requires_grad[chunk_id] = False
|
| 1304 |
+
if not self.is_first:
|
| 1305 |
+
# We assume that we always receive from stage - 1
|
| 1306 |
+
recv_infos = tuple(
|
| 1307 |
+
[
|
| 1308 |
+
_RecvInfo(
|
| 1309 |
+
f"recv_for_{self.stage_index}_from_{self.stage_index - 1}",
|
| 1310 |
+
self.stage_index - 1,
|
| 1311 |
+
_make_tensor_from_meta(inp, self.device),
|
| 1312 |
+
)
|
| 1313 |
+
for inp in self.inputs
|
| 1314 |
+
]
|
| 1315 |
+
)
|
| 1316 |
+
|
| 1317 |
+
self.args_recv_info[chunk_id] = recv_infos
|
| 1318 |
+
else:
|
| 1319 |
+
self.args_recv_info[chunk_id] = tuple(
|
| 1320 |
+
[_RootArgPlaceholder(i) for i in self.inputs]
|
| 1321 |
+
)
|
| 1322 |
+
|
| 1323 |
+
# Send info during forward for each activation
|
| 1324 |
+
# only need the rank that is being sent to
|
| 1325 |
+
self.act_send_info: Dict[int, List] = {}
|
| 1326 |
+
for idx in range(len(self.outputs)):
|
| 1327 |
+
# We assume we always send to stage + 1
|
| 1328 |
+
if not self.is_last:
|
| 1329 |
+
self.act_send_info[idx] = [self.stage_index + 1]
|
| 1330 |
+
else:
|
| 1331 |
+
self.act_send_info[idx] = []
|
| 1332 |
+
|
| 1333 |
+
def _create_grad_recv_info(
|
| 1334 |
+
self,
|
| 1335 |
+
act_send_info: Dict,
|
| 1336 |
+
) -> Tuple[_RecvInfo, ...]:
|
| 1337 |
+
grad_recv_info: Tuple[_RecvInfo, ...] = ()
|
| 1338 |
+
if not self.is_last:
|
| 1339 |
+
# Receiving gradients from multiple sources is not supported
|
| 1340 |
+
# hence we only take the first destination
|
| 1341 |
+
grad_recv_info = tuple(
|
| 1342 |
+
[
|
| 1343 |
+
_RecvInfo(
|
| 1344 |
+
f"recv_grad_for_{self.stage_index}_from_{dst_list[0]}",
|
| 1345 |
+
dst_list[0],
|
| 1346 |
+
_make_tensor_from_meta(self.outputs[idx], self.device),
|
| 1347 |
+
)
|
| 1348 |
+
for idx, dst_list in act_send_info.items()
|
| 1349 |
+
]
|
| 1350 |
+
)
|
| 1351 |
+
return grad_recv_info
|
| 1352 |
+
|
| 1353 |
+
def _init_p2p_neighbors(self):
|
| 1354 |
+
"""
|
| 1355 |
+
Set up p2p communitors between previous and next stages
|
| 1356 |
+
by sending a dummy tensor.
|
| 1357 |
+
|
| 1358 |
+
If this is used, must be called for all pipeline stages.
|
| 1359 |
+
"""
|
| 1360 |
+
ops = []
|
| 1361 |
+
recv_tensor = torch.zeros(1, device="cuda")
|
| 1362 |
+
send_tensor = torch.ones(1, device="cuda")
|
| 1363 |
+
# forward
|
| 1364 |
+
if not self.is_first:
|
| 1365 |
+
ops.append(dist.P2POp(dist.irecv, recv_tensor, self.prev_stage, self.group))
|
| 1366 |
+
if not self.is_last:
|
| 1367 |
+
ops.append(dist.P2POp(dist.isend, send_tensor, self.next_stage, self.group))
|
| 1368 |
+
|
| 1369 |
+
# backward
|
| 1370 |
+
if not self.is_first:
|
| 1371 |
+
ops.append(dist.P2POp(dist.isend, send_tensor, self.prev_stage, self.group))
|
| 1372 |
+
if not self.is_last:
|
| 1373 |
+
ops.append(dist.P2POp(dist.irecv, recv_tensor, self.next_stage, self.group))
|
| 1374 |
+
|
| 1375 |
+
return True
|
| 1376 |
+
|
| 1377 |
+
|
| 1378 |
+
def _validate_stage_shapes(pipeline_stages: List[PipelineStage]):
|
| 1379 |
+
"""
|
| 1380 |
+
Check that the buffer shapes match between stages was expected by performing an all_gather between
|
| 1381 |
+
all stages.
|
| 1382 |
+
"""
|
| 1383 |
+
if len(pipeline_stages) == 0:
|
| 1384 |
+
raise ValueError("No pipeline stages provided.")
|
| 1385 |
+
|
| 1386 |
+
virtual_pipeline_size = len(pipeline_stages)
|
| 1387 |
+
all_inputs = []
|
| 1388 |
+
all_outputs = []
|
| 1389 |
+
world_size = pipeline_stages[0].group_size
|
| 1390 |
+
num_stages = pipeline_stages[0].num_stages
|
| 1391 |
+
|
| 1392 |
+
# perform all gathers between all stages
|
| 1393 |
+
for virtual_id, stage in enumerate(pipeline_stages):
|
| 1394 |
+
world_size = stage.group_size
|
| 1395 |
+
stage_id: int = stage.stage_index
|
| 1396 |
+
rank = stage.group_rank
|
| 1397 |
+
# check that world_size and num_stages are consistent across all stages
|
| 1398 |
+
if stage.group_size != world_size:
|
| 1399 |
+
raise ValueError(
|
| 1400 |
+
f"Stage id {stage_id} has world size ({stage.group_size}) \
|
| 1401 |
+
which does not match world size ({world_size}) of other stages."
|
| 1402 |
+
)
|
| 1403 |
+
if stage.num_stages != num_stages:
|
| 1404 |
+
raise ValueError(
|
| 1405 |
+
f"Stage id {stage_id} has num stages ({stage.num_stages}) \
|
| 1406 |
+
which does not match num stages ({num_stages}) of other stages."
|
| 1407 |
+
)
|
| 1408 |
+
|
| 1409 |
+
pg_rank = dist.get_rank(stage.group)
|
| 1410 |
+
if rank != pg_rank:
|
| 1411 |
+
raise ValueError(
|
| 1412 |
+
f"Rank {rank} is not equal to process group rank {pg_rank}"
|
| 1413 |
+
)
|
| 1414 |
+
|
| 1415 |
+
if (num_stages := stage.num_stages) % world_size != 0:
|
| 1416 |
+
raise ValueError(
|
| 1417 |
+
f"Number of stages ({num_stages}) must be a multiple of the world_size ({world_size})"
|
| 1418 |
+
)
|
| 1419 |
+
|
| 1420 |
+
# all gather each ranks inputs
|
| 1421 |
+
tensor_list = [
|
| 1422 |
+
_create_metadata_tensor(device=stage.device)
|
| 1423 |
+
for _ in range(stage.group_size)
|
| 1424 |
+
]
|
| 1425 |
+
expected_inputs = stage.inputs
|
| 1426 |
+
stage_input = _create_metadata_tensor(expected_inputs, device=stage.device)
|
| 1427 |
+
dist.all_gather(tensor_list, stage_input)
|
| 1428 |
+
stage_input_shapes = [
|
| 1429 |
+
_extract_metadata_from_tensor(tensor) for tensor in tensor_list
|
| 1430 |
+
]
|
| 1431 |
+
|
| 1432 |
+
# all gather each ranks outputs
|
| 1433 |
+
tensor_list = [
|
| 1434 |
+
_create_metadata_tensor(device=stage.device)
|
| 1435 |
+
for _ in range(stage.group_size)
|
| 1436 |
+
]
|
| 1437 |
+
expected_outputs = stage.outputs
|
| 1438 |
+
stage_output = _create_metadata_tensor(expected_outputs, device=stage.device)
|
| 1439 |
+
dist.all_gather(tensor_list, stage_output)
|
| 1440 |
+
stage_output_shapes = [
|
| 1441 |
+
_extract_metadata_from_tensor(tensor) for tensor in tensor_list
|
| 1442 |
+
]
|
| 1443 |
+
|
| 1444 |
+
logger.debug(
|
| 1445 |
+
f"Rank: {pg_rank}" # noqa: G004
|
| 1446 |
+
f"Stage id: {stage_id}"
|
| 1447 |
+
f"Stage num stages: {stage.num_stages}"
|
| 1448 |
+
f"Stage rank: {rank}"
|
| 1449 |
+
f"Stage world size: {world_size}"
|
| 1450 |
+
f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} input shapes: {stage_input_shapes}" # noqa: G003
|
| 1451 |
+
f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} output shapes: {stage_output_shapes}" # noqa: G003
|
| 1452 |
+
)
|
| 1453 |
+
|
| 1454 |
+
all_inputs.extend(stage_input_shapes)
|
| 1455 |
+
all_outputs.extend(stage_output_shapes)
|
| 1456 |
+
|
| 1457 |
+
# log only rank 0's view, they will all be equivalent
|
| 1458 |
+
if pg_rank == 0:
|
| 1459 |
+
logger.info(
|
| 1460 |
+
"all stage inputs: %s \n all stage outputs: %s", all_inputs, all_outputs
|
| 1461 |
+
)
|
| 1462 |
+
|
| 1463 |
+
# Check if the output for stage 0 matches the input at stage 1, and so forth
|
| 1464 |
+
for i in range(virtual_pipeline_size * world_size - 1):
|
| 1465 |
+
if (out := all_outputs[i]) != (inp := all_inputs[i + 1]):
|
| 1466 |
+
raise ValueError(
|
| 1467 |
+
f"Stage_id {i} output shape {out} at does not match stage_id {i + 1} input shape {inp}."
|
| 1468 |
+
)
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/__init__.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.distributed.tensor._ops # force import all built-in dtensor ops
|
| 5 |
+
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh # noqa: F401
|
| 6 |
+
from torch.distributed.tensor._api import (
|
| 7 |
+
distribute_module,
|
| 8 |
+
distribute_tensor,
|
| 9 |
+
DTensor,
|
| 10 |
+
empty,
|
| 11 |
+
full,
|
| 12 |
+
ones,
|
| 13 |
+
rand,
|
| 14 |
+
randn,
|
| 15 |
+
zeros,
|
| 16 |
+
)
|
| 17 |
+
from torch.distributed.tensor.placement_types import (
|
| 18 |
+
Partial,
|
| 19 |
+
Placement,
|
| 20 |
+
Replicate,
|
| 21 |
+
Shard,
|
| 22 |
+
)
|
| 23 |
+
from torch.optim.optimizer import (
|
| 24 |
+
_foreach_supported_types as _optim_foreach_supported_types,
|
| 25 |
+
)
|
| 26 |
+
from torch.utils._foreach_utils import (
|
| 27 |
+
_foreach_supported_types as _util_foreach_supported_types,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# All public APIs from dtensor package
|
| 32 |
+
__all__ = [
|
| 33 |
+
"DTensor",
|
| 34 |
+
"distribute_tensor",
|
| 35 |
+
"distribute_module",
|
| 36 |
+
"Shard",
|
| 37 |
+
"Replicate",
|
| 38 |
+
"Partial",
|
| 39 |
+
"Placement",
|
| 40 |
+
"ones",
|
| 41 |
+
"empty",
|
| 42 |
+
"full",
|
| 43 |
+
"rand",
|
| 44 |
+
"randn",
|
| 45 |
+
"zeros",
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Append DTensor to the list of supported types for foreach implementation for optimizer
|
| 50 |
+
# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA.
|
| 51 |
+
if DTensor not in _optim_foreach_supported_types:
|
| 52 |
+
_optim_foreach_supported_types.append(DTensor)
|
| 53 |
+
|
| 54 |
+
if DTensor not in _util_foreach_supported_types:
|
| 55 |
+
_util_foreach_supported_types.append(DTensor)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Set namespace for exposed private names
|
| 59 |
+
DTensor.__module__ = "torch.distributed.tensor"
|
| 60 |
+
distribute_tensor.__module__ = "torch.distributed.tensor"
|
| 61 |
+
distribute_module.__module__ = "torch.distributed.tensor"
|
| 62 |
+
ones.__module__ = "torch.distributed.tensor"
|
| 63 |
+
empty.__module__ = "torch.distributed.tensor"
|
| 64 |
+
full.__module__ = "torch.distributed.tensor"
|
| 65 |
+
rand.__module__ = "torch.distributed.tensor"
|
| 66 |
+
randn.__module__ = "torch.distributed.tensor"
|
| 67 |
+
zeros.__module__ = "torch.distributed.tensor"
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_api.py
ADDED
|
@@ -0,0 +1,1231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-decorators
|
| 2 |
+
# mypy: allow-untyped-defs
|
| 3 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
| 4 |
+
import inspect
|
| 5 |
+
import warnings
|
| 6 |
+
from typing import Any, Callable, cast, Optional, Sequence, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.distributed.tensor._dispatch as op_dispatch
|
| 10 |
+
import torch.distributed.tensor._random as random
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
|
| 13 |
+
from torch.distributed.tensor._collective_utils import check_tensor_meta, mesh_broadcast
|
| 14 |
+
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
|
| 15 |
+
from torch.distributed.tensor._random import (
|
| 16 |
+
is_rng_supported_mesh,
|
| 17 |
+
OffsetBasedRNGTracker,
|
| 18 |
+
)
|
| 19 |
+
from torch.distributed.tensor._redistribute import (
|
| 20 |
+
Redistribute,
|
| 21 |
+
redistribute_local_tensor,
|
| 22 |
+
)
|
| 23 |
+
from torch.distributed.tensor._utils import (
|
| 24 |
+
compute_global_tensor_info,
|
| 25 |
+
compute_local_shape,
|
| 26 |
+
normalize_to_torch_size,
|
| 27 |
+
)
|
| 28 |
+
from torch.distributed.tensor.placement_types import (
|
| 29 |
+
Partial,
|
| 30 |
+
Placement,
|
| 31 |
+
Replicate,
|
| 32 |
+
Shard,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
__all__ = [
|
| 37 |
+
"DTensor",
|
| 38 |
+
"distribute_tensor",
|
| 39 |
+
"distribute_module",
|
| 40 |
+
"ones",
|
| 41 |
+
"empty",
|
| 42 |
+
"full",
|
| 43 |
+
"rand",
|
| 44 |
+
"randn",
|
| 45 |
+
"zeros",
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
aten = torch.ops.aten
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# NOTE [Autograd interaction between torch.Tensor]
|
| 52 |
+
#
|
| 53 |
+
# The autograd functions defined below are being used by the public
|
| 54 |
+
# facing APIs (i.e. from_local, to_local) to ensure DTensor to work
|
| 55 |
+
# together with torch.Tensor within the autograd engine. This
|
| 56 |
+
# allows DTensor to only exist on part of the module hierarchy.
|
| 57 |
+
#
|
| 58 |
+
# As an example, we have the a module that consists of submodules
|
| 59 |
+
# A, B, and C, the execution flow would be like:
|
| 60 |
+
# input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor)
|
| 61 |
+
#
|
| 62 |
+
# Suppose I only want to make Module B be a sharded module with
|
| 63 |
+
# DTensor params, the following forward/backward should work:
|
| 64 |
+
#
|
| 65 |
+
# input(torch.Tensor) -> Module A
|
| 66 |
+
# -> DTensor input (from_local) -> Sharded Module B -> DTensor output
|
| 67 |
+
# -> torch.Tensor output (to_local) -> Module C
|
| 68 |
+
#
|
| 69 |
+
# So from_local/to_local must be Autograd functions.
|
| 70 |
+
#
|
| 71 |
+
class _ToTorchTensor(torch.autograd.Function):
|
| 72 |
+
@staticmethod
|
| 73 |
+
def forward( # type: ignore[override]
|
| 74 |
+
ctx,
|
| 75 |
+
input: "DTensor",
|
| 76 |
+
grad_placements: Optional[Sequence[Placement]],
|
| 77 |
+
):
|
| 78 |
+
ctx.dtensor_spec = input._spec
|
| 79 |
+
ctx.grad_placements = grad_placements
|
| 80 |
+
local_tensor = input._local_tensor
|
| 81 |
+
|
| 82 |
+
# We need to return a fresh Tensor object there as autograd metadata
|
| 83 |
+
# will be inplaced into it. So we don't want to pollute the Tensor
|
| 84 |
+
# object stored in the _local_tensor of this DTensor.
|
| 85 |
+
return local_tensor.view_as(local_tensor)
|
| 86 |
+
|
| 87 |
+
@staticmethod
|
| 88 |
+
def backward(ctx, grad_output: torch.Tensor): # type: ignore[override]
|
| 89 |
+
dtensor_spec = ctx.dtensor_spec
|
| 90 |
+
mesh = dtensor_spec.mesh
|
| 91 |
+
grad_placements = ctx.grad_placements
|
| 92 |
+
dtensor_meta = dtensor_spec.tensor_meta
|
| 93 |
+
|
| 94 |
+
_, tensor_stride = compute_global_tensor_info(
|
| 95 |
+
grad_output, mesh, dtensor_spec.placements
|
| 96 |
+
)
|
| 97 |
+
tensor_stride = tuple(tensor_stride)
|
| 98 |
+
grad_placements = grad_placements or dtensor_spec.placements
|
| 99 |
+
grad_spec = DTensorSpec(
|
| 100 |
+
mesh,
|
| 101 |
+
grad_placements,
|
| 102 |
+
tensor_meta=TensorMeta(
|
| 103 |
+
shape=dtensor_meta.shape,
|
| 104 |
+
stride=tensor_stride,
|
| 105 |
+
dtype=dtensor_meta.dtype,
|
| 106 |
+
),
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
return (
|
| 110 |
+
DTensor(
|
| 111 |
+
grad_output,
|
| 112 |
+
grad_spec,
|
| 113 |
+
requires_grad=grad_output.requires_grad,
|
| 114 |
+
),
|
| 115 |
+
None,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class _FromTorchTensor(torch.autograd.Function):
|
| 120 |
+
@staticmethod
|
| 121 |
+
def forward( # type: ignore[override]
|
| 122 |
+
ctx, # pyre-ignore[2]: Parameter must be annotated.
|
| 123 |
+
input: torch.Tensor,
|
| 124 |
+
device_mesh: DeviceMesh,
|
| 125 |
+
placements: Tuple[Placement, ...],
|
| 126 |
+
run_check: bool,
|
| 127 |
+
shape: Optional[torch.Size] = None,
|
| 128 |
+
stride: Optional[Tuple[int, ...]] = None,
|
| 129 |
+
) -> "DTensor":
|
| 130 |
+
ctx.previous_placement = placements
|
| 131 |
+
ctx.previous_device_mesh = device_mesh
|
| 132 |
+
|
| 133 |
+
if shape and stride:
|
| 134 |
+
tensor_shape, tensor_stride = shape, stride
|
| 135 |
+
elif not shape and not stride:
|
| 136 |
+
# if it's not by default run_check, we assume user is certain that each
|
| 137 |
+
# rank has the same tensor shape, and we just use that to calculate the
|
| 138 |
+
# global shape
|
| 139 |
+
global_shape, global_stride = compute_global_tensor_info(
|
| 140 |
+
input, device_mesh, placements
|
| 141 |
+
)
|
| 142 |
+
tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride)
|
| 143 |
+
else:
|
| 144 |
+
raise RuntimeError(
|
| 145 |
+
f"Found shape:{shape}, stride:{stride}.",
|
| 146 |
+
"Please pass both shape and stride at the same time.",
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
if device_mesh.get_coordinate() is None:
|
| 150 |
+
# if the global rank is not participating in the device mesh, we
|
| 151 |
+
# simply set the local tensor to an empty tensor
|
| 152 |
+
input = input.new_empty(0, requires_grad=input.requires_grad)
|
| 153 |
+
elif run_check:
|
| 154 |
+
# TODO: support uneven sharding when global shape/stride not passed, by
|
| 155 |
+
# building the global TensorMeta during check_tensor_meta
|
| 156 |
+
check_shape_stride = not shape and not stride
|
| 157 |
+
check_tensor_meta(input, check_shape_stride=check_shape_stride)
|
| 158 |
+
# TODO: See if we need to make this run_check logic
|
| 159 |
+
# have a corresponding backward.
|
| 160 |
+
for idx, placement in enumerate(placements):
|
| 161 |
+
if placement.is_replicate():
|
| 162 |
+
# broadcast rank 0 tensor to all ranks
|
| 163 |
+
# only broadcast if run_check is True
|
| 164 |
+
input = input.contiguous()
|
| 165 |
+
mesh_broadcast(input, device_mesh, mesh_dim=idx)
|
| 166 |
+
|
| 167 |
+
dist_spec = DTensorSpec(
|
| 168 |
+
device_mesh,
|
| 169 |
+
placements,
|
| 170 |
+
tensor_meta=TensorMeta(
|
| 171 |
+
tensor_shape,
|
| 172 |
+
tensor_stride,
|
| 173 |
+
input.dtype,
|
| 174 |
+
),
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# We want a fresh Tensor object that shares memory with the input tensor
|
| 178 |
+
dist_tensor = DTensor(
|
| 179 |
+
input.view_as(input),
|
| 180 |
+
dist_spec,
|
| 181 |
+
# requires_grad of the dist tensor depends on if input
|
| 182 |
+
# requires_grad or not
|
| 183 |
+
requires_grad=input.requires_grad,
|
| 184 |
+
)
|
| 185 |
+
return dist_tensor
|
| 186 |
+
|
| 187 |
+
@staticmethod
|
| 188 |
+
def backward(ctx, grad_output: "DTensor"): # type: ignore[override]
|
| 189 |
+
previous_placement = ctx.previous_placement
|
| 190 |
+
previous_device_mesh = ctx.previous_device_mesh
|
| 191 |
+
|
| 192 |
+
# reshard to the placement when creating DistributedTensor
|
| 193 |
+
# so that the gradient layout matches, and we could return
|
| 194 |
+
# local gradients directly
|
| 195 |
+
if grad_output.placements != previous_placement:
|
| 196 |
+
current_spec = grad_output._spec
|
| 197 |
+
target_spec = DTensorSpec(
|
| 198 |
+
previous_device_mesh,
|
| 199 |
+
previous_placement,
|
| 200 |
+
tensor_meta=grad_output._spec.tensor_meta,
|
| 201 |
+
)
|
| 202 |
+
local_tensor = grad_output._local_tensor
|
| 203 |
+
output = redistribute_local_tensor(
|
| 204 |
+
local_tensor, current_spec, target_spec, is_backward=True
|
| 205 |
+
)
|
| 206 |
+
# TODO: return the redistributed local tensor directly without
|
| 207 |
+
# differentiable backward. see if this make sense for all cases.
|
| 208 |
+
return output, None, None, None, None, None
|
| 209 |
+
|
| 210 |
+
# TODO: backward is also differentiable now, add a test
|
| 211 |
+
# to test higher level gradients.
|
| 212 |
+
return grad_output.to_local(), None, None, None, None, None
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class DTensor(torch.Tensor):
|
| 216 |
+
"""
|
| 217 |
+
``DTensor`` (Distributed Tensor) is a subclass of ``torch.Tensor`` that provides single-device like
|
| 218 |
+
abstraction to program with multi-device ``torch.Tensor``. It describes the distributed tensor sharding
|
| 219 |
+
layout (DTensor Layout) through the :class:`DeviceMesh` and following types of :class:`Placement`:
|
| 220 |
+
|
| 221 |
+
* :class:`Shard`: Tensor sharded on the tensor dimension ``dim`` on the devices of the ``DeviceMesh`` dimension
|
| 222 |
+
* :class:`Replicate`: Tensor replicated on the devices of the ``DeviceMesh`` dimension
|
| 223 |
+
* :class:`Partial`: Tensor is pending reduction on the devices of the ``DeviceMesh`` dimension
|
| 224 |
+
|
| 225 |
+
When calling PyTorch operators, ``DTensor`` overrides the PyTorch operators to perform sharded computation and issue
|
| 226 |
+
communications whenever necessary. Along with the operator computation, ``DTensor`` will transform or propagate the
|
| 227 |
+
placements (DTensor Layout) properly (based on the operator semantic itself) and generate new ``DTensor`` outputs.
|
| 228 |
+
|
| 229 |
+
To ensure numerical correctness of the ``DTensor`` sharded computation when calling PyTorch operators, ``DTensor``
|
| 230 |
+
requires every Tensor argument of the operator be DTensor.
|
| 231 |
+
|
| 232 |
+
"""
|
| 233 |
+
|
| 234 |
+
_local_tensor: torch.Tensor
|
| 235 |
+
_spec: DTensorSpec
|
| 236 |
+
__slots__ = ["_local_tensor", "_spec"]
|
| 237 |
+
|
| 238 |
+
# _op_dispatcher instance as a class attribute to handle runtime dispatching logic
|
| 239 |
+
_op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher()
|
| 240 |
+
|
| 241 |
+
@staticmethod
|
| 242 |
+
@torch._disable_dynamo
|
| 243 |
+
def __new__(
|
| 244 |
+
cls,
|
| 245 |
+
local_tensor: torch.Tensor,
|
| 246 |
+
spec: DTensorSpec,
|
| 247 |
+
*,
|
| 248 |
+
requires_grad: bool,
|
| 249 |
+
) -> "DTensor":
|
| 250 |
+
"""
|
| 251 |
+
Construct a DTensor from a local tensor, device mesh, and placement and
|
| 252 |
+
other tensor properties (i.e. shape, requires_grad, strides, etc).
|
| 253 |
+
|
| 254 |
+
.. note:: This is not a public API and it's only supposed to be used by the
|
| 255 |
+
operator implementations and internals. If you want to construct a
|
| 256 |
+
DTensor from a local tensor, consider using ``DTensor.from_local``, if
|
| 257 |
+
you want to construct a DTensor from a "global" tensor (where you
|
| 258 |
+
already have tensor initialized and want to shard this tensor),
|
| 259 |
+
consider using ``distribute_tensor``.
|
| 260 |
+
"""
|
| 261 |
+
if local_tensor.requires_grad and not requires_grad:
|
| 262 |
+
warnings.warn(
|
| 263 |
+
"To construct DTensor from torch.Tensor, it's recommended to "
|
| 264 |
+
"use local_tensor.detach() and make requires_grad consistent."
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# new method instruct wrapper tensor from local_tensor and add
|
| 268 |
+
# placement spec, it does not do actual distribution
|
| 269 |
+
assert spec.tensor_meta is not None, "TensorMeta should not be None!"
|
| 270 |
+
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
| 271 |
+
cls,
|
| 272 |
+
spec.tensor_meta.shape,
|
| 273 |
+
strides=spec.tensor_meta.stride,
|
| 274 |
+
dtype=local_tensor.dtype,
|
| 275 |
+
device=local_tensor.device,
|
| 276 |
+
layout=local_tensor.layout,
|
| 277 |
+
requires_grad=requires_grad,
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
r._spec = spec
|
| 281 |
+
r._local_tensor = local_tensor
|
| 282 |
+
return r
|
| 283 |
+
|
| 284 |
+
# pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently.
|
| 285 |
+
# pyre-fixme[3]: Return type must be annotated.
|
| 286 |
+
def __repr__(self):
|
| 287 |
+
# TODO: consider all_gather the local tensors for better debugging
|
| 288 |
+
return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})"
|
| 289 |
+
|
| 290 |
+
def __tensor_flatten__(self):
|
| 291 |
+
"""
|
| 292 |
+
protocol to inform how to flatten a DTensor to local tensor
|
| 293 |
+
for PT2 tracing
|
| 294 |
+
"""
|
| 295 |
+
return ["_local_tensor"], (self._spec, self.requires_grad)
|
| 296 |
+
|
| 297 |
+
@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
    """
    Inverse of ``__tensor_flatten__``: rebuild a DTensor from the inner
    local tensor and the (spec, requires_grad) context, using the outer
    size/stride supplied by the tracer for the wrapper's metadata.
    """
    assert (
        flatten_spec is not None
    ), "Expecting spec to be not None from `__tensor_flatten__` return value!"
    local = inner_tensors["_local_tensor"]
    saved_spec, requires_grad = flatten_spec
    # Rebuild the wrapper metadata from the tracer-provided outer shape and
    # stride; only the dtype is carried over from the saved spec.
    rebuilt_spec = DTensorSpec(
        saved_spec.mesh,
        saved_spec.placements,
        tensor_meta=TensorMeta(
            shape=outer_size,
            stride=outer_stride,
            dtype=saved_spec.tensor_meta.dtype,
        ),
    )
    return DTensor(local, rebuilt_spec, requires_grad=requires_grad)
def __coerce_tangent_metadata__(self):
    """
    Autograd tangent hook: coerce any ``Partial`` placement to ``Replicate``
    so tangents carry no pending reductions. Returns ``self`` unchanged when
    no mesh dimension is Partial.
    """
    if not any(isinstance(p, Partial) for p in self.placements):
        return self
    coerced = [
        Replicate() if isinstance(p, Partial) else p for p in self.placements
    ]
    return self.redistribute(device_mesh=self.device_mesh, placements=coerced)
def __coerce_same_metadata_as_tangent__(self, flatten_spec):
    """Redistribute ``self`` to the placements recorded in a tangent's flatten spec."""
    spec, _ = flatten_spec  # Result of tensor_flatten()
    return self.redistribute(device_mesh=self.device_mesh, placements=spec.placements)
@classmethod
@torch._disable_dynamo
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Route every ATen op seen on DTensor arguments through the shared op dispatcher."""
    # A missing kwargs dict is normalized to {} before dispatch.
    return DTensor._op_dispatcher.dispatch(func, args, kwargs or {})
@staticmethod
def from_local(
    local_tensor: torch.Tensor,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
    *,
    run_check: bool = False,
    shape: Optional[torch.Size] = None,
    stride: Optional[Tuple[int, ...]] = None,
) -> "DTensor":
    """
    Create a :class:`DTensor` from a per-rank local ``torch.Tensor`` laid
    out according to ``device_mesh`` and ``placements``.

    Args:
        local_tensor (torch.Tensor): this rank's local shard/replica.
        device_mesh (:class:`DeviceMesh`, optional): mesh to place the tensor
            on; if omitted, must be called inside a DeviceMesh context
            manager. default: None
        placements (List[:class:`Placement`], optional): how the local tensor
            is laid out on the mesh; must have ``device_mesh.ndim`` entries.
            If omitted, the tensor is treated as replicated on every mesh
            dimension.

    Keyword args:
        run_check (bool, optional): when True, pay extra communication to
            validate local tensor metadata across ranks, and broadcast data
            from the first rank of each mesh dimension that is
            :class:`Replicate`. default: False
        shape (torch.Size, optional): global shape of the resulting DTensor.
            Required when local shapes differ across ranks; otherwise it is
            computed assuming even sharding. default: None
        stride (tuple, optional): global stride of the resulting DTensor;
            computed assuming even sharding when omitted. default: None

    Returns:
        A :class:`DTensor` object

    .. note:: With ``run_check=False`` the caller is responsible for the
        cross-rank consistency of ``local_tensor`` (sharded for
        ``Shard(dim)``, identical for ``Replicate()``); otherwise the
        resulting DTensor's behavior is undefined.

    .. note:: ``from_local`` is differentiable; the created DTensor's
        ``requires_grad`` follows ``local_tensor``.
    """
    # No data movement happens here unless a Replicate placement with
    # run_check triggers a broadcast from the first rank of a mesh dim;
    # metadata validation across ranks is likewise gated on run_check.
    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    device_type = device_mesh.device_type

    # Move the local tensor onto the mesh's device type (meta tensors stay put).
    if local_tensor.device.type != device_type and not local_tensor.is_meta:
        local_tensor = local_tensor.to(device_type)

    if placements is None:
        # default: replicated on every mesh dimension
        placements = [Replicate() for _ in range(device_mesh.ndim)]
    else:
        placements = list(placements)
        for idx, placement in enumerate(placements):
            # normalize negative shard dims to positive ones
            if placement.is_shard():
                placement = cast(Shard, placement)
                if placement.dim < 0:
                    placements[idx] = Shard(placement.dim + local_tensor.ndim)

    # Construct through an autograd.Function so gradients of the resulting
    # DTensor flow back into local_tensor.
    return _FromTorchTensor.apply(  # pyre-ignore[16]: autograd func
        local_tensor,
        device_mesh,
        tuple(placements),
        run_check,
        shape,
        stride,
    )
def to_local(
    self, *, grad_placements: Optional[Sequence[Placement]] = None
) -> torch.Tensor:
    """
    Return this rank's local tensor: the local shard for sharded placements,
    or the local replica for replicated ones.

    Keyword args:
        grad_placements (List[:class:`Placement`], optional): hint describing
            the layout the gradient of the returned tensor will have, in case
            the caller uses the local tensor under a layout different from the
            original DTensor's. When omitted, autograd assumes the gradient
            keeps the original DTensor layout.

    Returns:
        A :class:`torch.Tensor` or ``AsyncCollectiveTensor``. An
        ``AsyncCollectiveTensor`` means pending communication has not
        finished; call ``wait`` on it to get the materialized local tensor.

    .. note:: ``to_local`` is differentiable; the returned tensor's
        ``requires_grad`` follows the DTensor's.
    """
    # Fast path: without grad mode there is no need to go through the
    # autograd.Function, just hand back the inner tensor.
    if not torch.is_grad_enabled():
        return self._local_tensor

    if grad_placements is not None and not isinstance(grad_placements, tuple):
        grad_placements = tuple(grad_placements)
    return _ToTorchTensor.apply(
        self, grad_placements
    )  # pyre-ignore[16]: autograd func
def redistribute(
    self,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
    *,
    async_op: bool = False,
) -> "DTensor":
    """
    Move this DTensor from its current placements to ``placements``,
    issuing whatever collectives are required. For example, a sharded
    DTensor becomes replicated by passing ``Replicate()`` for each mesh
    dimension.

    Per mesh dimension, the transition maps to:

    1. ``Shard(dim)`` -> ``Replicate()``: ``all_gather``
    2. ``Shard(src_dim)`` -> ``Shard(dst_dim)``: ``all_to_all``
    3. ``Replicate()`` -> ``Shard(dim)``: local chunking (i.e. ``torch.chunk``)
    4. ``Partial()`` -> ``Replicate()``: ``all_reduce``
    5. ``Partial()`` -> ``Shard(dim)``: ``reduce_scatter``

    The necessary steps are derived automatically for both 1-D and N-D
    meshes.

    Args:
        device_mesh (:class:`DeviceMesh`, optional): target mesh; defaults
            to this DTensor's current mesh. default: None
        placements (List[:class:`Placement`], optional): required; the new
            layout, one entry per mesh dimension.

    Keyword args:
        async_op (bool, optional): perform the redistribution
            asynchronously. Default: False

    Returns:
        A :class:`DTensor` object

    .. note:: ``redistribute`` is differentiable — autograd handles the
        backward redistribution automatically.

    .. note:: Only redistribution within the same DeviceMesh is supported;
        please file an issue if you need cross-mesh redistribution.
    """
    # Out-of-place only: a fresh DTensor is returned and `self` is untouched.
    device_mesh = device_mesh or self.device_mesh
    if placements is None:
        raise RuntimeError("placements is needed for redistribute!")

    normalized = list(placements)
    for i, placement in enumerate(normalized):
        if placement.is_partial():
            # Partial is a transient internal state, never a valid target.
            raise RuntimeError(
                "Can not redistribute to Partial, redistributing to Partial is for internal use only!"
            )
        elif isinstance(placement, Shard) and placement.dim < 0:
            # normalize shard dim to be positive
            normalized[i] = Shard(placement.dim + self.ndim)

    # pyre-fixme[16]: `Redistribute` has no attribute `apply`.
    return Redistribute.apply(self, device_mesh, tuple(normalized), async_op)
def full_tensor(
    self, *, grad_placements: Optional[Sequence[Placement]] = None
) -> torch.Tensor:
    """
    Materialize the full (global) tensor by gathering local shards from all
    ranks of the mesh. Syntactic sugar for:

    ``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()``

    Keyword args:
        grad_placements (List[:class:`Placement`], optional): hint describing
            the layout the gradient of the returned full tensor will have, in
            case it is used under a layout other than replicated. When
            omitted, autograd assumes a replicated gradient layout.

    Returns:
        A :class:`torch.Tensor` holding the full tensor of this DTensor.

    .. note:: ``full_tensor`` is differentiable.
    """
    # Replicate across every mesh dimension, then unwrap via the autograd
    # function so the gradient hint is honored.
    replicated = self.redistribute(
        placements=[Replicate()] * self.device_mesh.ndim, async_op=False
    )
    return _ToTorchTensor.apply(replicated, grad_placements)
@property
def device_mesh(self) -> DeviceMesh:
    """
    The :class:`DeviceMesh` this DTensor is placed on.

    .. note:: ``device_mesh`` is a read-only property, it can not be set.
    """
    return self._spec.mesh
@property
def placements(self) -> Tuple[Placement, ...]:
    """
    The placements describing how this DTensor is laid out across its
    DeviceMesh, one entry per mesh dimension.

    .. note:: ``placements`` is a read-only property, it can not be set.
    """
    return self._spec.placements
def __create_write_items__(self, fqn: str, object: Any):
    """Distributed-checkpoint hook: produce the write items for this tensor under ``fqn``."""
    from torch.distributed.checkpoint.planner_helpers import (
        _create_write_items_for_dtensor,
    )

    # Delegate to the inner tensor when it implements the protocol itself
    # (e.g. nested tensor subclasses); otherwise handle the plain case.
    inner = self._local_tensor
    if hasattr(inner, "__create_write_items__"):
        return inner.__create_write_items__(fqn, object)  # type: ignore[attr-defined]
    if isinstance(inner, torch.Tensor):
        return [_create_write_items_for_dtensor(fqn, object)]
    raise RuntimeError("Unsupported tensor type!")
def __create_chunk_list__(self):
    """Distributed-checkpoint hook: describe this tensor's chunk(s) for planning."""
    from torch.distributed.checkpoint.planner_helpers import (
        _create_chunk_from_dtensor,
    )

    # Prefer the inner tensor's own protocol implementation when present.
    inner = self._local_tensor
    if hasattr(inner, "__create_chunk_list__"):
        return inner.__create_chunk_list__()  # type: ignore[attr-defined]
    if isinstance(inner, torch.Tensor):
        return [_create_chunk_from_dtensor(self)]
    raise RuntimeError("Unsupported tensor type!")
def __get_tensor_shard__(self, index):
    """Distributed-checkpoint hook: return the shard data for the given chunk index."""
    # Prefer the inner tensor's own protocol implementation when present;
    # for a plain inner tensor the whole local tensor is the shard.
    inner = self._local_tensor
    if hasattr(inner, "__get_tensor_shard__"):
        return inner.__get_tensor_shard__(index)  # type: ignore[attr-defined]
    if isinstance(inner, torch.Tensor):
        return self.to_local()
    raise RuntimeError("Unsupported tensor type!")
def distribute_tensor(
    tensor: torch.Tensor,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) across
    ``device_mesh`` according to ``placements``. ``tensor`` is the logical
    "global" tensor; to preserve single-device semantics, the copy on the
    first rank of each mesh dimension is treated as the source of truth.
    To build a DTensor mid-autograd, use :meth:`DTensor.from_local` instead.

    Args:
        tensor (torch.Tensor): tensor to distribute. Sharding a dimension not
            evenly divisible by the mesh dimension size follows ``torch.chunk``
            semantics; uneven sharding is experimental and subject to change.
        device_mesh (:class:`DeviceMesh`, optional): mesh to distribute on; if
            omitted, must be called under a DeviceMesh context manager.
            default: None
        placements (List[:class:`Placement`], optional): one placement per
            mesh dimension. When omitted, the tensor is replicated across the
            whole mesh from the first rank of each mesh dimension.

    Returns:
        A :class:`DTensor` or ``XLAShardedTensor`` object.

    .. note::
        With an ``xla`` DeviceMesh this returns an `XLAShardedTensor` instead;
        see `this issue <https://github.com/pytorch/pytorch/issues/92909>`__.
        The XLA integration is experimental and subject to change.
    """

    torch._C._log_api_usage_once("torch.dtensor.distribute_tensor")

    # fall back to the ambient mesh when none is given
    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    device_type = device_mesh.device_type
    if device_type == "xla":
        try:
            # PyTorch/XLA SPMD path: delegates and returns an XLAShardedTensor.
            from torch_xla.distributed.spmd import (  # type:ignore[import]
                xla_distribute_tensor,
            )

            return xla_distribute_tensor(
                tensor, device_mesh, placements
            )  # type:ignore[return-value]
        except ImportError as e:
            msg = "To use DTensor API with xla, you must install the torch_xla package!"
            raise ImportError(msg) from e

    # Lazily install the default offset-based RNG tracker so random ops
    # on DTensor behave correctly per shard.
    # TODO: the value assignment to global variable is not the ideal solution
    # we can replace it in future.
    if not random._rng_tracker and is_rng_supported_mesh(device_mesh):
        random._rng_tracker = OffsetBasedRNGTracker(device_type)

    if not tensor.is_leaf:
        raise RuntimeError(
            "`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!"
        )

    # move the tensor onto the mesh's device type (meta tensors stay put)
    if device_type != tensor.device.type and not tensor.is_meta:
        tensor = tensor.to(device_type)

    if placements is None:
        # default: replicate on every mesh dimension
        placements = [Replicate() for _ in range(device_mesh.ndim)]

    if len(placements) != device_mesh.ndim:
        raise ValueError(
            f"`placements` must have the same length as `device_mesh.ndim`! "
            f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}."
        )
    if isinstance(tensor, DTensor):
        # Already a DTensor: only a no-op request (same mesh, same
        # placements) is accepted; anything else must go through
        # `redistribute` or is simply unsupported.
        if tensor.device_mesh != device_mesh:
            raise ValueError(
                f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} "
                f"to a different device mesh {device_mesh}."
            )
        if tensor.placements != tuple(placements):
            raise ValueError(
                f"Cannot distribute a DTensor with placements {tensor.placements} "
                f"to a different placements {placements}. do you want to call "
                f"`redistribute` instead?"
            )
        return tensor

    local_tensor = tensor.detach()

    # TODO(xilun): address sharding order
    # Apply each mesh dimension's placement in turn: Shard carves out this
    # rank's chunk, Replicate keeps (and scatters) the full tensor.
    placements = list(placements)
    for mesh_dim, placement in enumerate(placements):
        if placement.is_shard():
            placement = cast(Shard, placement)
            if placement.dim < 0:
                # normalize shard placement dim
                placement = Shard(placement.dim + tensor.ndim)
                placements[mesh_dim] = placement
            local_tensor = placement._shard_tensor(local_tensor, device_mesh, mesh_dim)
        elif placement.is_replicate():
            placement = cast(Replicate, placement)
            local_tensor = placement._replicate_tensor(
                local_tensor, device_mesh, mesh_dim
            )
        else:
            raise RuntimeError(
                f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {mesh_dim}!"
            )
    placements = tuple(placements)

    assert local_tensor is not None, "distributing a tensor should not be None"
    # The local tensor is detached: after construction autograd tracks the
    # DTensor wrapper rather than the inner shard.
    spec = DTensorSpec(
        mesh=device_mesh,
        placements=placements,
        tensor_meta=TensorMeta(
            shape=tensor.size(),
            stride=tensor.stride(),
            dtype=tensor.dtype,
        ),
    )
    return DTensor(
        local_tensor.requires_grad_(tensor.requires_grad),
        spec,
        requires_grad=tensor.requires_grad,
    )
def distribute_module(
    module: nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
    partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None,
    input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
    output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
) -> nn.Module:
    """
    Convert a module's parameters/buffers to :class:`DTensor` s and optionally
    control its runtime inputs/outputs, via three hooks:

    1. ``partition_fn`` shards parameters ahead of execution (convert module
       parameters to :class:`DTensor` parameters per the function's logic).
    2. ``input_fn`` / ``output_fn`` run during execution to convert inputs to
       :class:`DTensor` and/or outputs back to ``torch.Tensor``.

    Args:
        module (:class:`nn.Module`): user module to be partitioned.
        device_mesh (:class:`DeviceMesh`): the device mesh to place the module.
        partition_fn (Callable): partitions parameters (i.e. shards certain
            parameters across the ``device_mesh``). When omitted, every module
            parameter is replicated across the mesh by default.
        input_fn (Callable): controls how module inputs are distributed;
            installed as a module ``forward_pre_hook`` (pre forward hook).
        output_fn (Callable): controls how module outputs are distributed, or
            converts them back to torch.Tensor; installed as a module
            ``forward_hook`` (post forward hook).

    Returns:
        A module whose parameters/buffers are all ``DTensor`` s.

    .. note::
        With an ``xla`` DeviceMesh this returns an nn.Module whose parameters
        carry PyTorch/XLA SPMD annotations instead. See
        `this issue <https://github.com/pytorch/pytorch/issues/92909>`__
        for more details. The XLA integration is experimental and subject to change.

    """

    torch._C._log_api_usage_once("torch.dtensor.distribute_module")

    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    device_type = device_mesh.device_type
    if device_type == "xla":
        try:
            # PyTorch/XLA SPMD path: annotates parameters for auto-partitioning
            # or explicitly partitions them to XLAShardedTensor parameters per
            # the provided partition_fn.
            from torch_xla.distributed.spmd import (  # type:ignore[import]
                xla_distribute_module,
            )

            return xla_distribute_module(
                module, device_mesh, partition_fn, input_fn, output_fn
            )  # type:ignore[return-value]
        except ImportError as e:
            msg = "To use DTensor API with xla, you must install the torch_xla package!"
            raise ImportError(msg) from e

    def _replicate_leftover_tensors(m: nn.Module, mesh: DeviceMesh) -> None:
        # Replicate every immediate parameter/buffer the partition_fn left as
        # a plain tensor. We walk _parameters/_buffers directly rather than
        # using `module._apply` because partition_fn is arbitrary user code
        # (it may install hooks, etc.) whose effects must be preserved.
        full_replicate = [Replicate()] * mesh.ndim
        for key, param in m._parameters.items():
            if param is not None and not isinstance(param, DTensor):
                m.register_parameter(
                    key,
                    nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)),
                )
        for key, buffer in m._buffers.items():
            if buffer is not None and not isinstance(buffer, DTensor):
                m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate)

    if partition_fn is None:
        # No partition_fn: replicate all params/buffers of every submodule.
        for _, submod in module.named_modules():
            _replicate_leftover_tensors(submod, device_mesh)
    else:
        # Apply partition_fn to each submodule, then replicate whatever it
        # left untouched so the whole module ends up holding DTensors.
        for name, submod in module.named_modules():
            partition_fn(name, submod, device_mesh)
            _replicate_leftover_tensors(submod, device_mesh)

    # register input_fn as module forward pre hook
    if input_fn is not None:
        # Dispatch on the input_fn arity: the 2-arg form is deprecated.
        num_args = len(inspect.signature(input_fn).parameters)
        if num_args == 2:
            warnings.warn(
                "Deprecating input_fn that takes two arguments (inputs, device_mesh), "
                "please use input_fn that takes in (module, inputs, device_mesh) instead!",
                FutureWarning,
                stacklevel=2,
            )
            module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh))  # type: ignore[call-arg]
        elif num_args == 3:
            module.register_forward_pre_hook(
                lambda mod, inputs: input_fn(mod, inputs, device_mesh)
            )
        else:
            raise ValueError(
                f"input_fn should take in 3 arguments, but got {num_args} arguments!"
            )
    # register output_fn as module forward hook
    if output_fn is not None:
        num_args = len(inspect.signature(output_fn).parameters)
        if num_args == 2:
            warnings.warn(
                "Deprecating output_fn that takes two arguments (inputs, device_mesh), "
                "please use output_fn that takes in (module, inputs, device_mesh) instead!",
                FutureWarning,
                stacklevel=2,
            )
            module.register_forward_hook(
                lambda mod, inputs, outputs: output_fn(outputs, device_mesh)  # type: ignore[call-arg]
            )
        elif num_args == 3:
            module.register_forward_hook(
                lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)
            )
        else:
            raise ValueError(
                f"output_fn should take in 3 arguments, but got {num_args} arguments!"
            )

    return module
+
|
| 902 |
+
|
| 903 |
+
# Below are tensor factory function APIs, which are used to create a DTensor directly. We need
|
| 904 |
+
# to make separate factory function APIs because tensor subclass could not override the tensor
|
| 905 |
+
# factory methods, and we need user to call the factory functions with user intended device_mesh
|
| 906 |
+
# and placements to create a proper DTensor.
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
def _dtensor_init_helper( # type: ignore[no-untyped-def]
|
| 910 |
+
init_op,
|
| 911 |
+
size: torch.Size,
|
| 912 |
+
device_mesh: Optional[DeviceMesh] = None,
|
| 913 |
+
placements: Optional[Sequence[Placement]] = None,
|
| 914 |
+
**kwargs,
|
| 915 |
+
) -> DTensor:
|
| 916 |
+
# from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
|
| 917 |
+
|
| 918 |
+
# if device_mesh is None, use the one from mesh resources
|
| 919 |
+
device_mesh = device_mesh or _mesh_resources.get_current_mesh()
|
| 920 |
+
kwargs["device"] = device_mesh.device_type
|
| 921 |
+
|
| 922 |
+
# set default placements to replicated if not specified
|
| 923 |
+
placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim))
|
| 924 |
+
|
| 925 |
+
# check device_mesh againts placements
|
| 926 |
+
assert device_mesh.ndim == len(
|
| 927 |
+
placements
|
| 928 |
+
), "mesh dimension does not match the length of placements"
|
| 929 |
+
|
| 930 |
+
assert kwargs["layout"] == torch.strided, "layout value not supported!"
|
| 931 |
+
torch_stride = torch._prims_common.make_contiguous_strides_for(size)
|
| 932 |
+
|
| 933 |
+
# get local tensor shape
|
| 934 |
+
local_shape = compute_local_shape(size, device_mesh, placements)
|
| 935 |
+
# initialize the local tensor
|
| 936 |
+
if init_op == torch.full:
|
| 937 |
+
fill_value = kwargs.pop("fill_value", 0)
|
| 938 |
+
local_tensor = init_op(local_shape, fill_value, **kwargs)
|
| 939 |
+
elif init_op == torch.rand or init_op == torch.randn:
|
| 940 |
+
# this tensor meta is not used except `shape`
|
| 941 |
+
dtype = kwargs.get("dtype", torch.get_default_dtype())
|
| 942 |
+
|
| 943 |
+
tensor_meta = TensorMeta(size, (0,), dtype)
|
| 944 |
+
spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta)
|
| 945 |
+
|
| 946 |
+
if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker:
|
| 947 |
+
random._rng_tracker = random.OffsetBasedRNGTracker()
|
| 948 |
+
|
| 949 |
+
assert random._rng_tracker is not None
|
| 950 |
+
with random._rng_tracker._distribute_region(spec):
|
| 951 |
+
local_tensor = init_op(local_shape, **kwargs)
|
| 952 |
+
else:
|
| 953 |
+
local_tensor = init_op(local_shape, **kwargs)
|
| 954 |
+
|
| 955 |
+
spec = DTensorSpec(
|
| 956 |
+
device_mesh,
|
| 957 |
+
tuple(placements),
|
| 958 |
+
tensor_meta=TensorMeta(
|
| 959 |
+
size,
|
| 960 |
+
torch_stride,
|
| 961 |
+
local_tensor.dtype,
|
| 962 |
+
),
|
| 963 |
+
)
|
| 964 |
+
|
| 965 |
+
return DTensor(
|
| 966 |
+
local_tensor,
|
| 967 |
+
spec,
|
| 968 |
+
requires_grad=kwargs["requires_grad"],
|
| 969 |
+
)
|
| 970 |
+
|
| 971 |
+
|
| 972 |
+
def ones( # type: ignore[no-untyped-def]
|
| 973 |
+
*size,
|
| 974 |
+
dtype: Optional[torch.dtype] = None,
|
| 975 |
+
layout: torch.layout = torch.strided,
|
| 976 |
+
requires_grad: bool = False,
|
| 977 |
+
device_mesh: Optional[DeviceMesh] = None,
|
| 978 |
+
placements: Optional[Sequence[Placement]] = None,
|
| 979 |
+
) -> DTensor:
|
| 980 |
+
"""
|
| 981 |
+
Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined
|
| 982 |
+
by the variable argument ``size``.
|
| 983 |
+
|
| 984 |
+
Args:
|
| 985 |
+
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
|
| 986 |
+
Can be a variable number of arguments or a collection like a list or tuple.
|
| 987 |
+
E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
|
| 988 |
+
|
| 989 |
+
Keyword args:
|
| 990 |
+
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
|
| 991 |
+
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
|
| 992 |
+
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
|
| 993 |
+
Default: ``torch.strided``.
|
| 994 |
+
requires_grad (bool, optional): If autograd should record operations on the
|
| 995 |
+
returned :class:`DTensor`. Default: ``False``.
|
| 996 |
+
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
|
| 997 |
+
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
|
| 998 |
+
|
| 999 |
+
Returns:
|
| 1000 |
+
A :class:`DTensor` object on each rank
|
| 1001 |
+
"""
|
| 1002 |
+
torch_size = normalize_to_torch_size(size)
|
| 1003 |
+
|
| 1004 |
+
return _dtensor_init_helper(
|
| 1005 |
+
torch.ones,
|
| 1006 |
+
torch_size,
|
| 1007 |
+
dtype=dtype,
|
| 1008 |
+
layout=layout,
|
| 1009 |
+
requires_grad=requires_grad,
|
| 1010 |
+
device_mesh=device_mesh,
|
| 1011 |
+
placements=placements,
|
| 1012 |
+
)
|
| 1013 |
+
|
| 1014 |
+
|
| 1015 |
+
def empty( # type: ignore[no-untyped-def]
|
| 1016 |
+
*size,
|
| 1017 |
+
dtype: Optional[torch.dtype] = None,
|
| 1018 |
+
layout: torch.layout = torch.strided,
|
| 1019 |
+
requires_grad: bool = False,
|
| 1020 |
+
device_mesh: Optional[DeviceMesh] = None,
|
| 1021 |
+
placements: Optional[Sequence[Placement]] = None,
|
| 1022 |
+
) -> DTensor:
|
| 1023 |
+
"""
|
| 1024 |
+
Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor`
|
| 1025 |
+
is defined by the variable argument ``size``.
|
| 1026 |
+
|
| 1027 |
+
Args:
|
| 1028 |
+
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
|
| 1029 |
+
Can be a variable number of arguments or a collection like a list or tuple.
|
| 1030 |
+
E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..))
|
| 1031 |
+
|
| 1032 |
+
Keyword args:
|
| 1033 |
+
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
|
| 1034 |
+
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).\
|
| 1035 |
+
layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`.
|
| 1036 |
+
Default: ``torch.strided``.
|
| 1037 |
+
requires_grad (bool, optional): If autograd should record operations on the
|
| 1038 |
+
returned :class:`DTensor`. Default: ``False``.
|
| 1039 |
+
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
|
| 1040 |
+
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
|
| 1041 |
+
|
| 1042 |
+
Returns:
|
| 1043 |
+
A :class:`DTensor` object on each rank
|
| 1044 |
+
"""
|
| 1045 |
+
torch_size = normalize_to_torch_size(size)
|
| 1046 |
+
|
| 1047 |
+
return _dtensor_init_helper(
|
| 1048 |
+
torch.empty,
|
| 1049 |
+
torch_size,
|
| 1050 |
+
dtype=dtype,
|
| 1051 |
+
layout=layout,
|
| 1052 |
+
requires_grad=requires_grad,
|
| 1053 |
+
device_mesh=device_mesh,
|
| 1054 |
+
placements=placements,
|
| 1055 |
+
)
|
| 1056 |
+
|
| 1057 |
+
|
| 1058 |
+
def full( # type: ignore[no-untyped-def]
|
| 1059 |
+
size,
|
| 1060 |
+
fill_value,
|
| 1061 |
+
*,
|
| 1062 |
+
dtype: Optional[torch.dtype] = None,
|
| 1063 |
+
layout: torch.layout = torch.strided,
|
| 1064 |
+
requires_grad: bool = False,
|
| 1065 |
+
device_mesh: Optional[DeviceMesh] = None,
|
| 1066 |
+
placements: Optional[Sequence[Placement]] = None,
|
| 1067 |
+
) -> DTensor:
|
| 1068 |
+
"""
|
| 1069 |
+
Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh`` and
|
| 1070 |
+
``placements``, with the shape defined by the argument ``size``.
|
| 1071 |
+
|
| 1072 |
+
Args:
|
| 1073 |
+
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
|
| 1074 |
+
Can be a variable number of arguments or a collection like a list or tuple.
|
| 1075 |
+
E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
|
| 1076 |
+
fill_value(Scalar): the value to fill the output tensor with.
|
| 1077 |
+
|
| 1078 |
+
Keyword args:
|
| 1079 |
+
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
|
| 1080 |
+
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
|
| 1081 |
+
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
|
| 1082 |
+
Default: ``torch.strided``.
|
| 1083 |
+
requires_grad (bool, optional): If autograd should record operations on the
|
| 1084 |
+
returned :class:`DTensor`. Default: ``False``.
|
| 1085 |
+
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
|
| 1086 |
+
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
|
| 1087 |
+
|
| 1088 |
+
Returns:
|
| 1089 |
+
A :class:`DTensor` object on each rank
|
| 1090 |
+
"""
|
| 1091 |
+
torch_size = normalize_to_torch_size(size)
|
| 1092 |
+
|
| 1093 |
+
return _dtensor_init_helper(
|
| 1094 |
+
torch.full,
|
| 1095 |
+
torch_size,
|
| 1096 |
+
fill_value=fill_value,
|
| 1097 |
+
dtype=dtype,
|
| 1098 |
+
layout=layout,
|
| 1099 |
+
requires_grad=requires_grad,
|
| 1100 |
+
device_mesh=device_mesh,
|
| 1101 |
+
placements=placements,
|
| 1102 |
+
)
|
| 1103 |
+
|
| 1104 |
+
|
| 1105 |
+
def rand( # type: ignore[no-untyped-def]
|
| 1106 |
+
*size,
|
| 1107 |
+
requires_grad: bool = False,
|
| 1108 |
+
dtype: Optional[torch.dtype] = None,
|
| 1109 |
+
layout: torch.layout = torch.strided,
|
| 1110 |
+
device_mesh: Optional[DeviceMesh] = None,
|
| 1111 |
+
placements: Optional[Sequence[Placement]] = None,
|
| 1112 |
+
) -> DTensor:
|
| 1113 |
+
"""
|
| 1114 |
+
Returns a :class:`DTensor` filled with random numbers from a uniform distribution
|
| 1115 |
+
on the interval ``[0, 1)``. The shape of the tensor is defined by the variable
|
| 1116 |
+
argument ``size``.
|
| 1117 |
+
|
| 1118 |
+
Args:
|
| 1119 |
+
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
|
| 1120 |
+
Can be a variable number of arguments or a collection like a list or tuple.
|
| 1121 |
+
E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
|
| 1122 |
+
|
| 1123 |
+
Keyword args:
|
| 1124 |
+
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
|
| 1125 |
+
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
|
| 1126 |
+
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
|
| 1127 |
+
Default: ``torch.strided``.
|
| 1128 |
+
requires_grad (bool, optional): If autograd should record operations on the
|
| 1129 |
+
returned :class:`DTensor`. Default: ``False``.
|
| 1130 |
+
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
|
| 1131 |
+
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
|
| 1132 |
+
|
| 1133 |
+
Returns:
|
| 1134 |
+
A :class:`DTensor` object on each rank
|
| 1135 |
+
"""
|
| 1136 |
+
torch_size = normalize_to_torch_size(size)
|
| 1137 |
+
|
| 1138 |
+
return _dtensor_init_helper(
|
| 1139 |
+
torch.rand,
|
| 1140 |
+
torch_size,
|
| 1141 |
+
dtype=dtype,
|
| 1142 |
+
layout=layout,
|
| 1143 |
+
requires_grad=requires_grad,
|
| 1144 |
+
device_mesh=device_mesh,
|
| 1145 |
+
placements=placements,
|
| 1146 |
+
)
|
| 1147 |
+
|
| 1148 |
+
|
| 1149 |
+
def randn( # type: ignore[no-untyped-def]
|
| 1150 |
+
*size,
|
| 1151 |
+
requires_grad: bool = False,
|
| 1152 |
+
dtype: Optional[torch.dtype] = None,
|
| 1153 |
+
layout: torch.layout = torch.strided,
|
| 1154 |
+
device_mesh: Optional[DeviceMesh] = None,
|
| 1155 |
+
placements: Optional[Sequence[Placement]] = None,
|
| 1156 |
+
) -> DTensor:
|
| 1157 |
+
"""
|
| 1158 |
+
Returns a :class:`DTensor` filled with random numbers from a normal distribution
|
| 1159 |
+
with mean 0 and variance 1. The shape of the tensor is defined by the variable
|
| 1160 |
+
argument ``size``.
|
| 1161 |
+
|
| 1162 |
+
Args:
|
| 1163 |
+
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
|
| 1164 |
+
Can be a variable number of arguments or a collection like a list or tuple.
|
| 1165 |
+
E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
|
| 1166 |
+
|
| 1167 |
+
Keyword args:
|
| 1168 |
+
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
|
| 1169 |
+
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
|
| 1170 |
+
layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
|
| 1171 |
+
Default: ``torch.strided``.
|
| 1172 |
+
requires_grad (bool, optional): If autograd should record operations on the
|
| 1173 |
+
returned :class:`DTensor`. Default: ``False``.
|
| 1174 |
+
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
|
| 1175 |
+
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
|
| 1176 |
+
|
| 1177 |
+
Returns:
|
| 1178 |
+
A :class:`DTensor` object on each rank
|
| 1179 |
+
"""
|
| 1180 |
+
torch_size = normalize_to_torch_size(size)
|
| 1181 |
+
|
| 1182 |
+
return _dtensor_init_helper(
|
| 1183 |
+
torch.randn,
|
| 1184 |
+
torch_size,
|
| 1185 |
+
dtype=dtype,
|
| 1186 |
+
layout=layout,
|
| 1187 |
+
requires_grad=requires_grad,
|
| 1188 |
+
device_mesh=device_mesh,
|
| 1189 |
+
placements=placements,
|
| 1190 |
+
)
|
| 1191 |
+
|
| 1192 |
+
|
| 1193 |
+
def zeros( # type: ignore[no-untyped-def]
|
| 1194 |
+
*size,
|
| 1195 |
+
requires_grad: bool = False,
|
| 1196 |
+
dtype: Optional[torch.dtype] = None,
|
| 1197 |
+
layout: torch.layout = torch.strided,
|
| 1198 |
+
device_mesh: Optional[DeviceMesh] = None,
|
| 1199 |
+
placements: Optional[Sequence[Placement]] = None,
|
| 1200 |
+
) -> DTensor:
|
| 1201 |
+
"""
|
| 1202 |
+
Returns a :class:`DTensor` filled with the scalar value 0.
|
| 1203 |
+
|
| 1204 |
+
Args:
|
| 1205 |
+
size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
|
| 1206 |
+
Can be a variable number of arguments or a collection like a list or tuple.
|
| 1207 |
+
E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..))
|
| 1208 |
+
Keyword args:
|
| 1209 |
+
requires_grad (bool, optional): If autograd should record operations on the
|
| 1210 |
+
returned :class:`DTensor`. Default: ``False``.
|
| 1211 |
+
dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
|
| 1212 |
+
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
|
| 1213 |
+
layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`.
|
| 1214 |
+
Default: ``torch.strided``.
|
| 1215 |
+
device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
|
| 1216 |
+
placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
|
| 1217 |
+
|
| 1218 |
+
Returns:
|
| 1219 |
+
A :class:`DTensor` object on each rank
|
| 1220 |
+
"""
|
| 1221 |
+
torch_size = normalize_to_torch_size(size)
|
| 1222 |
+
|
| 1223 |
+
return _dtensor_init_helper(
|
| 1224 |
+
torch.zeros,
|
| 1225 |
+
torch_size,
|
| 1226 |
+
dtype=dtype,
|
| 1227 |
+
layout=layout,
|
| 1228 |
+
requires_grad=requires_grad,
|
| 1229 |
+
device_mesh=device_mesh,
|
| 1230 |
+
placements=placements,
|
| 1231 |
+
)
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_collective_utils.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import logging
|
| 3 |
+
import math
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from functools import lru_cache
|
| 6 |
+
from typing import List, Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.distributed._functional_collectives as funcol
|
| 10 |
+
import torch.distributed.tensor._dtensor_spec as dtensor_spec
|
| 11 |
+
from torch._C._distributed_c10d import _resolve_process_group
|
| 12 |
+
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
|
| 13 |
+
from torch.distributed.distributed_c10d import (
|
| 14 |
+
_get_group_size_by_name,
|
| 15 |
+
broadcast,
|
| 16 |
+
get_global_rank,
|
| 17 |
+
get_group_rank,
|
| 18 |
+
get_rank,
|
| 19 |
+
GroupMember,
|
| 20 |
+
ProcessGroup,
|
| 21 |
+
scatter,
|
| 22 |
+
Work,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
if not torch._running_with_deploy():
|
| 30 |
+
|
| 31 |
+
@torch.library.register_fake("_dtensor::shard_dim_alltoall")
|
| 32 |
+
def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name):
|
| 33 |
+
group_size = _get_group_size_by_name(group_name)
|
| 34 |
+
stacked_list = [torch.empty_like(input) for _ in range(group_size)]
|
| 35 |
+
group = _resolve_process_group(group_name)
|
| 36 |
+
group_rank = get_group_rank(group, get_rank())
|
| 37 |
+
|
| 38 |
+
return torch.cat(stacked_list, dim=gather_dim).chunk(group_size, dim=shard_dim)[
|
| 39 |
+
group_rank
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
else:
|
| 43 |
+
import warnings
|
| 44 |
+
|
| 45 |
+
warnings.warn(
|
| 46 |
+
"PyTorch Distributed functional collectives do not work with torch::deploy."
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim):
|
| 51 |
+
if mesh.device_type == "cpu":
|
| 52 |
+
# Gloo does not support alltoall, so falling back to allgather + chunk
|
| 53 |
+
|
| 54 |
+
# TODO: This logs way too much
|
| 55 |
+
logger.warning(
|
| 56 |
+
"CPU process group does not support alltoall yet, falling back with allgather + chunk!"
|
| 57 |
+
)
|
| 58 |
+
out = funcol.all_gather_tensor(input, gather_dim, (mesh, mesh_dim))
|
| 59 |
+
if isinstance(out, funcol.AsyncCollectiveTensor):
|
| 60 |
+
# stick to the same behavior for the alltoall case, remove this once we enable alltoall async
|
| 61 |
+
out = out.wait()
|
| 62 |
+
out = torch.chunk(out, mesh.size(mesh_dim), dim=shard_dim)[
|
| 63 |
+
mesh.get_local_rank(mesh_dim)
|
| 64 |
+
]
|
| 65 |
+
return out.contiguous() if not out.is_contiguous() else out
|
| 66 |
+
|
| 67 |
+
group_name = funcol._resolve_group_name((mesh, mesh_dim))
|
| 68 |
+
# TODO: enable async op for shard_dim_alltoall
|
| 69 |
+
return torch.ops._dtensor.shard_dim_alltoall(
|
| 70 |
+
input, gather_dim, shard_dim, group_name
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def mesh_scatter(
|
| 75 |
+
output: torch.Tensor,
|
| 76 |
+
scatter_list: List[torch.Tensor],
|
| 77 |
+
mesh: DeviceMesh,
|
| 78 |
+
mesh_dim: int = 0,
|
| 79 |
+
async_op: bool = False,
|
| 80 |
+
) -> Optional[Work]:
|
| 81 |
+
"""
|
| 82 |
+
scatter a list of tensors to a device mesh dimension. We by default
|
| 83 |
+
use the first rank of the mesh dimension as the source of truth, i.e
|
| 84 |
+
for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will
|
| 85 |
+
scatter the tensor list on rank 0 to rank 0/1, and tensor list on rank
|
| 86 |
+
2 to rank 2/3.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
output (torch.Tensor): the tensor to receive the scattered list.
|
| 90 |
+
scatter_list (List[torch.Tensor]): the tensor list to be scattered.
|
| 91 |
+
mesh_dim (int, optional): indicate which mesh dimension we want
|
| 92 |
+
to scatter on, we by default choose the first rank on the
|
| 93 |
+
mesh dimension as source of truth.
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
A :class:`Work` object
|
| 97 |
+
"""
|
| 98 |
+
# TODO: Ideally we should use the meta tensor way
|
| 99 |
+
# (to register a meta kernel for the collective op)
|
| 100 |
+
# so that it would avoid the communication. Need to
|
| 101 |
+
# remove the check below once that is done.
|
| 102 |
+
if output.is_meta:
|
| 103 |
+
return None
|
| 104 |
+
dim_group = mesh.get_group(mesh_dim)
|
| 105 |
+
assert isinstance(dim_group, ProcessGroup)
|
| 106 |
+
# src need to be global rank
|
| 107 |
+
src_for_dim = 0
|
| 108 |
+
|
| 109 |
+
if dim_group is not GroupMember.WORLD:
|
| 110 |
+
src_for_dim = get_global_rank(dim_group, 0)
|
| 111 |
+
|
| 112 |
+
if src_for_dim == get_rank():
|
| 113 |
+
fut = scatter(
|
| 114 |
+
output,
|
| 115 |
+
scatter_list=scatter_list,
|
| 116 |
+
src=src_for_dim,
|
| 117 |
+
group=dim_group,
|
| 118 |
+
async_op=async_op,
|
| 119 |
+
)
|
| 120 |
+
else:
|
| 121 |
+
fut = scatter(
|
| 122 |
+
output,
|
| 123 |
+
scatter_list=None,
|
| 124 |
+
src=src_for_dim,
|
| 125 |
+
group=dim_group,
|
| 126 |
+
async_op=async_op,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
return fut
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def mesh_broadcast(
|
| 133 |
+
tensor: torch.Tensor,
|
| 134 |
+
mesh: DeviceMesh,
|
| 135 |
+
mesh_dim: int = 0,
|
| 136 |
+
async_op: bool = False,
|
| 137 |
+
) -> Optional[Work]:
|
| 138 |
+
"""
|
| 139 |
+
broadcast the tensor to a device mesh dimension. We by default
|
| 140 |
+
use the first rank of the mesh dimension as the source of truth, i.e
|
| 141 |
+
for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will
|
| 142 |
+
broadcast the tensor on rank 0 to rank 0/1, and tensor on rank 2
|
| 143 |
+
to rank 2/3.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
tensor (torch.Tensor): tensor to broadcast.
|
| 147 |
+
mesh_dim (int, optional): indicate which mesh dimension we want
|
| 148 |
+
to scatter on, we by default choose the first rank on the
|
| 149 |
+
mesh dimension as source of truth.
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
A :class:`Work` object
|
| 153 |
+
"""
|
| 154 |
+
# TODO: Ideally we should use the meta tensor way
|
| 155 |
+
# (to register a meta kernel for the collective op)
|
| 156 |
+
# so that it would avoid the communication. Need to
|
| 157 |
+
# remove the check below once that is done.
|
| 158 |
+
if tensor.is_meta:
|
| 159 |
+
return None
|
| 160 |
+
dim_group = mesh.get_group(mesh_dim)
|
| 161 |
+
assert isinstance(dim_group, ProcessGroup)
|
| 162 |
+
# src need to be global rank
|
| 163 |
+
src_for_dim = 0
|
| 164 |
+
if dim_group is not GroupMember.WORLD:
|
| 165 |
+
src_for_dim = get_global_rank(dim_group, 0)
|
| 166 |
+
|
| 167 |
+
return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
|
| 171 |
+
if pad_size == 0:
|
| 172 |
+
return tensor
|
| 173 |
+
pad = [0, 0] * (tensor.ndim - pad_dim)
|
| 174 |
+
pad[-1] = pad_size
|
| 175 |
+
return torch.nn.functional.pad(tensor, pad)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
|
| 179 |
+
if pad_size == 0:
|
| 180 |
+
return tensor
|
| 181 |
+
return tensor.narrow(
|
| 182 |
+
pad_dim,
|
| 183 |
+
start=0,
|
| 184 |
+
length=tensor.size(pad_dim) - pad_size,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def fill_empty_tensor_to_shards(
|
| 189 |
+
shards: List[torch.Tensor], shard_dim: int, num_empty_tensors: int
|
| 190 |
+
) -> List[torch.Tensor]:
|
| 191 |
+
if num_empty_tensors == 0:
|
| 192 |
+
return shards
|
| 193 |
+
tensor_size = list(shards[0].size())
|
| 194 |
+
tensor_size = [
|
| 195 |
+
size if idx != shard_dim else 0 for idx, size in enumerate(tensor_size)
|
| 196 |
+
]
|
| 197 |
+
tensor = shards[0].new_zeros(tensor_size)
|
| 198 |
+
for _ in range(num_empty_tensors):
|
| 199 |
+
shards.append(tensor)
|
| 200 |
+
return shards
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def check_tensor_meta(
|
| 204 |
+
local_tensor, check_shape_stride=False
|
| 205 |
+
) -> Optional["dtensor_spec.TensorMeta"]:
|
| 206 |
+
local_metadata = {
|
| 207 |
+
"dtype": local_tensor.dtype,
|
| 208 |
+
"requires_grad": local_tensor.requires_grad,
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
if check_shape_stride:
|
| 212 |
+
local_metadata.update(
|
| 213 |
+
{"shape": local_tensor.shape, "stride": local_tensor.stride()}
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
gathered_metadata = [None for _ in range(torch.distributed.get_world_size())]
|
| 217 |
+
torch.distributed.all_gather_object(gathered_metadata, local_metadata)
|
| 218 |
+
|
| 219 |
+
# Check if metadata is consistent across ranks
|
| 220 |
+
if not all(meta == local_metadata for meta in gathered_metadata):
|
| 221 |
+
raise ValueError(
|
| 222 |
+
"Inconsistent tensor metadata (including shape and stride) across ranks."
|
| 223 |
+
)
|
| 224 |
+
return None
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int:
|
| 228 |
+
assert spec.tensor_meta is not None, "spec should have tensor meta defined!"
|
| 229 |
+
return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
@dataclass
|
| 233 |
+
class MeshTopoInfo:
|
| 234 |
+
"""
|
| 235 |
+
Mesh information for collective cost estimation
|
| 236 |
+
"""
|
| 237 |
+
|
| 238 |
+
mesh: DeviceMesh
|
| 239 |
+
mesh_dim_devices: List[int]
|
| 240 |
+
mesh_dim_bandwidth: List[float]
|
| 241 |
+
mesh_dim_latency: List[float]
|
| 242 |
+
|
| 243 |
+
@staticmethod
|
| 244 |
+
@lru_cache(None)
|
| 245 |
+
def build_from_mesh(mesh: DeviceMesh) -> "MeshTopoInfo":
|
| 246 |
+
# Generate mesh topology info for intra-host/inter-host communication pattern
|
| 247 |
+
# Note that we made bunch of assumptions for simplicity:
|
| 248 |
+
# 1. we assume the mesh is homogeneous, and it's gpu/nccl model
|
| 249 |
+
# 2. we assume gpu arch is Ampere or Hopper
|
| 250 |
+
# 3. we assume collectives are all ring base algo for now
|
| 251 |
+
num_devices_per_host = _mesh_resources.num_devices_per_host(mesh.device_type)
|
| 252 |
+
# the base bw number (intra-node), GB/s
|
| 253 |
+
base_bw = 87.7
|
| 254 |
+
mesh_dim_bandwidth = [base_bw] * mesh.ndim
|
| 255 |
+
# the latency in terms of us (intra-node, nv-link)
|
| 256 |
+
mesh_dim_latency = [0.6] * mesh.ndim
|
| 257 |
+
mesh_dim_devices = [1] * mesh.ndim
|
| 258 |
+
|
| 259 |
+
total_num_devices = 1
|
| 260 |
+
for mesh_dim in reversed(range(mesh.ndim)):
|
| 261 |
+
num_devices = mesh.size(mesh_dim)
|
| 262 |
+
mesh_dim_devices[mesh_dim] = num_devices
|
| 263 |
+
total_num_devices *= num_devices
|
| 264 |
+
if total_num_devices > num_devices_per_host:
|
| 265 |
+
# magic number for inter-host communication bandwidth/latency factor
|
| 266 |
+
# This number assumes latest GPU arch, i.e. Ampere or Hopper
|
| 267 |
+
# TODO: see if we need to tweak this or offer a way for user
|
| 268 |
+
# to specify the bandwidths/latency
|
| 269 |
+
mesh_dim_bandwidth[mesh_dim] *= 0.22
|
| 270 |
+
# set to ethernet latency for inter-host
|
| 271 |
+
mesh_dim_latency[mesh_dim] = 2.7
|
| 272 |
+
|
| 273 |
+
return MeshTopoInfo(
|
| 274 |
+
mesh, mesh_dim_devices, mesh_dim_bandwidth, mesh_dim_latency
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def allgather_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
|
| 279 |
+
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
|
| 280 |
+
mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
|
| 281 |
+
num_hops = num_devices_on_mesh_dim - 1
|
| 282 |
+
# base latency + comm latency
|
| 283 |
+
latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] # us
|
| 284 |
+
bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth # s
|
| 285 |
+
return latency + bw * 1e6 # rescale to us
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def allreduce_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
|
| 289 |
+
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
|
| 290 |
+
mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
|
| 291 |
+
# allreduce have almost 2x comm bytes compare to allgather/reduce_scatter
|
| 292 |
+
num_hops = 2 * num_devices_on_mesh_dim - 1
|
| 293 |
+
|
| 294 |
+
latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]
|
| 295 |
+
bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth
|
| 296 |
+
return latency + bw * 1e6
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def reduce_scatter_cost(
|
| 300 |
+
bytes_gb: float,
|
| 301 |
+
mesh_topo: MeshTopoInfo,
|
| 302 |
+
mesh_dim: int,
|
| 303 |
+
) -> float:
|
| 304 |
+
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
|
| 305 |
+
mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
|
| 306 |
+
num_hops = num_devices_on_mesh_dim - 1
|
| 307 |
+
# base latency + comm latency
|
| 308 |
+
latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]
|
| 309 |
+
bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth
|
| 310 |
+
return latency + bw * 1e6
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def redistribute_cost(
|
| 314 |
+
current_spec: "dtensor_spec.DTensorSpec",
|
| 315 |
+
target_spec: "dtensor_spec.DTensorSpec",
|
| 316 |
+
) -> float:
|
| 317 |
+
"""
|
| 318 |
+
This function returns the cost of redistribute from current to target DTensorSpec.
|
| 319 |
+
|
| 320 |
+
NOTE:
|
| 321 |
+
1. Only consider communication cost here, since computation costs for redistribute
|
| 322 |
+
are quite trival (i.e. we only need to narrow or simple division)
|
| 323 |
+
2. Only consider redistribute cost on same mesh, cross mesh communication cost is
|
| 324 |
+
not quite needed for operator strategy estimation/selection.
|
| 325 |
+
"""
|
| 326 |
+
if current_spec.mesh != target_spec.mesh:
|
| 327 |
+
# make infinite cost if meshes are not same
|
| 328 |
+
# TODO: see if we want to support this once there's cross mesh communication
|
| 329 |
+
return float("inf")
|
| 330 |
+
|
| 331 |
+
if current_spec.is_replicated():
|
| 332 |
+
# short-cut:
|
| 333 |
+
# comm cost is 0 if current spec is already full replication
|
| 334 |
+
return 0.0
|
| 335 |
+
|
| 336 |
+
mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
|
| 337 |
+
cost = 0.0
|
| 338 |
+
comm_bytes_gb = (
|
| 339 |
+
spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
|
| 340 |
+
)
|
| 341 |
+
# Transformation that considered for redistribute cost:
|
| 342 |
+
# 1. allgather 2. alltoall
|
| 343 |
+
# 3. allreduce 4. reduce_scatter
|
| 344 |
+
for i, (current, target) in enumerate(
|
| 345 |
+
zip(current_spec.placements, target_spec.placements)
|
| 346 |
+
):
|
| 347 |
+
if current == target:
|
| 348 |
+
continue
|
| 349 |
+
|
| 350 |
+
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
|
| 351 |
+
if current.is_shard() and target.is_replicate():
|
| 352 |
+
# allgather gives larger comm bytes
|
| 353 |
+
comm_bytes_gb *= num_devices_on_mesh_dim
|
| 354 |
+
# add up allgather comm cost
|
| 355 |
+
cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
|
| 356 |
+
elif current.is_shard() and target.is_shard():
|
| 357 |
+
# should be alltoall comm, since we haven't implement it yet, add penalty
|
| 358 |
+
# to favor allgather instead
|
| 359 |
+
cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + 1.0
|
| 360 |
+
elif current.is_partial() and target.is_replicate():
|
| 361 |
+
# add up allreduce comm cost
|
| 362 |
+
cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
|
| 363 |
+
elif current.is_partial() and target.is_shard():
|
| 364 |
+
# add up reduce_scatter comm cost
|
| 365 |
+
cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
|
| 366 |
+
# after reduce_scatter the comm bytes for further collectives halved.
|
| 367 |
+
comm_bytes_gb /= num_devices_on_mesh_dim
|
| 368 |
+
elif current.is_shard() and target.is_partial():
|
| 369 |
+
# ban shard -> partial as it does not make sense to perform
|
| 370 |
+
# this redistribute
|
| 371 |
+
return float("inf")
|
| 372 |
+
|
| 373 |
+
return cost
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py
ADDED
|
@@ -0,0 +1,510 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
| 2 |
+
import contextlib
|
| 3 |
+
import functools
|
| 4 |
+
import logging
|
| 5 |
+
import operator
|
| 6 |
+
import warnings
|
| 7 |
+
from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
import torch.distributed.tensor._api as dtensor
|
| 12 |
+
import torch.distributed.tensor._random as random
|
| 13 |
+
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
|
| 14 |
+
from torch.distributed.tensor._op_schema import (
|
| 15 |
+
_is_inplace_op,
|
| 16 |
+
_is_out_variant_op,
|
| 17 |
+
OpInfo,
|
| 18 |
+
OpSchema,
|
| 19 |
+
OutputSpecType,
|
| 20 |
+
)
|
| 21 |
+
from torch.distributed.tensor._random import is_rng_supported_mesh
|
| 22 |
+
from torch.distributed.tensor._redistribute import redistribute_local_tensor
|
| 23 |
+
from torch.distributed.tensor._sharding_prop import ShardingPropagator
|
| 24 |
+
from torch.distributed.tensor._tp_conv import (
|
| 25 |
+
convolution_backward_handler,
|
| 26 |
+
convolution_handler,
|
| 27 |
+
)
|
| 28 |
+
from torch.distributed.tensor._utils import try_find_mesh_from_args
|
| 29 |
+
from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if TYPE_CHECKING:
|
| 33 |
+
from torch.distributed.device_mesh import DeviceMesh
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
from torch.utils import _cxx_pytree as pytree
|
| 37 |
+
except ImportError:
|
| 38 |
+
from torch.utils import _pytree as pytree # type: ignore[no-redef]
|
| 39 |
+
|
| 40 |
+
aten = torch.ops.aten
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def decompose_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    """
    Fall back to the op's registered decomposition into core ATen ops.

    Mostly used in inference mode, where an op that is not a core ATen op
    must be rewritten in terms of core ops before DTensor can handle it.

    Raises:
        RuntimeError: if the op has no decomposition registered.
    """
    decomposed = op_call.decompose(*args, **kwargs)
    if decomposed is NotImplemented:
        raise RuntimeError("Decomposition failed")
    return decomposed
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def is_same_size_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> bool:
    """
    Local handler for ``aten.is_same_size``: two tensors are "the same size"
    exactly when their shapes compare equal, so no sharding propagation or
    communication is needed.
    """
    first = cast(torch.Tensor, args[0])
    second = cast(torch.Tensor, args[1])
    return first.shape == second.shape
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def found_inf_reduce_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> None:
    """
    Handler for ``aten._amp_foreach_non_finite_check_and_unscale_`` (see the
    ``_custom_op_handlers`` registration in ``OpDispatcher.__init__``).

    Runs the in-place unscale/non-finite check on the local shards, then
    reduces the ``found_inf`` flag across the mesh: any rank that saw a
    non-finite value must propagate it to every rank. ``args[0]`` is the list
    of gradient DTensors, ``args[1]`` is the ``found_inf`` tensor mutated in
    place.
    """
    # Unwrap args/kwargs to local tensors; the op mutates its inputs in place,
    # so its return value is intentionally discarded below.
    op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    local_tensor_args = pytree.tree_unflatten(
        cast(List[object], op_info.local_args), op_info.args_tree_spec
    )
    local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
    # Execute locally for its side effects (unscales grads, sets local found_inf).
    local_results = op_call(*local_tensor_args, **op_info.local_kwargs)

    # Derive the reduction placements from the first gradient's placements:
    # replicated mesh dims stay replicated; every other dim becomes a
    # Partial("max") so that full_tensor() max-reduces the found_inf flag
    # (max of {0, 1} == logical OR).
    grad_dtensor = cast(list[dtensor.DTensor], args[0])[0]
    grad_placements = grad_dtensor.placements
    mesh = grad_dtensor.device_mesh

    found_inf_placements: list[Placement] = []
    for placement in grad_placements:
        if isinstance(placement, Replicate):
            found_inf_placements.append(placement)
        else:
            found_inf_placements.append(Partial("max"))

    target_tensor = cast(torch.Tensor, args[1])
    spec = DTensorSpec(
        mesh=mesh,
        placements=tuple(found_inf_placements),
        tensor_meta=TensorMeta(
            shape=target_tensor.size(),
            stride=target_tensor.stride(),
            dtype=target_tensor.dtype,
        ),
    )
    # Wrap the local found_inf as a DTensor so full_tensor() performs the
    # cross-rank reduction, then write the reduced value back in place so the
    # caller's tensor reflects the global result.
    found_inf_dtensor = dtensor.DTensor(
        local_tensor=target_tensor, spec=spec, requires_grad=False
    )
    found_inf = found_inf_dtensor.full_tensor()
    target_tensor.copy_(found_inf)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class OpDispatcher:
    """
    Op dispatching class instance to handle args/kwargs pre-processing (un-wrapping), sharding
    propagation, redistribute local args, local compute, and post-processing (re-wrapping). It
    also handles any op specific logic if necessary.

    NOTE: Given the runtime overhead of Tensor subclass (__torch_dispatch__), the OpDispatcher
    is designed to minimize the CPU overhead by using the tricks of proper unflattening, faster
    pytree if needed, and leveraging various caching mechanisms implemented in the sharding
    propagation and redistribute modules. The CPU overhead is critical to eager mode performance,
    one need to carefully measure the CPU overhead when making significant changes to the
    OpDispatcher and ShardingPropagator.
    """

    def __init__(self) -> None:
        self.sharding_propagator = ShardingPropagator()
        # Ops whose local execution must run under the distributed RNG tracker
        # so each shard draws from the correct slice of the random stream.
        self._random_ops = {
            aten.native_dropout.default,
            aten.normal_.default,
            aten.rand_like.default,
            aten.randn_like.default,
            aten.randint_like.default,
            aten.randint_like.low_dtype,
            aten.randint_like.low_dtype_out,
            aten.uniform_.default,
            aten.bernoulli.default,
            aten.bernoulli_.float,
        }
        # Ops that bypass sharding propagation entirely and are served by a
        # dedicated handler (see dispatch()).
        self._custom_op_handlers = {
            aten.linear.default: decompose_handler,
            aten.is_same_size.default: is_same_size_handler,
            aten.convolution.default: convolution_handler,
            aten.convolution_backward.default: convolution_backward_handler,
            aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler,
        }

        # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor)
        # as implicitly replicated or we throw error to user.
        # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave
        # it as False by default.
        self._allow_implicit_replication = False

    def dispatch(
        self,
        op_call: torch._ops.OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> object:
        """
        Main dispatching logic: unwrap DTensor args, propagate sharding,
        redistribute local args if needed, run the local op, and re-wrap the
        result according to the propagated output spec.
        """
        # operators that does not need to go through sharding propagation
        if op_call in self._custom_op_handlers:
            return self._custom_op_handlers[op_call](op_call, args, kwargs)  # type: ignore[operator]

        # extract local tensor and sharding infos to a OpInfo
        op_info = self.unwrap_to_op_info(op_call, args, kwargs)
        logger.debug("Dispatching op_call: %s", op_info.schema)

        self.sharding_propagator.propagate(op_info)
        output_sharding = op_info.output_sharding
        logger.debug("output_sharding for %s: %s", op_call, output_sharding)
        assert output_sharding is not None, "output sharding should not be None"

        mesh = op_info.mesh
        if mesh.get_coordinate() is not None:
            # computation that happens in the current rank of the mesh, normal case
            if output_sharding.needs_redistribute:
                # If sharding propagation decision needs redistribute, perform redistribute
                # on args first, which could potentially modify args (i.e. allgather certain arg)
                assert output_sharding.redistribute_schema is not None
                self.redistribute_local_args(
                    op_info, output_sharding.redistribute_schema
                )

            local_tensor_args = (
                pytree.tree_unflatten(
                    cast(List[object], op_info.local_args), op_info.args_tree_spec
                )
                if op_info.args_tree_spec
                else op_info.local_args
            )

            # run local op computation with potentially modified args/kwargs
            local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
            if op_call in self._random_ops:
                if not random._rng_tracker and is_rng_supported_mesh(mesh):
                    # Default to `OffsetBasedRNGTracker` if the parallelism API
                    # did not already construct one
                    random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type)

                first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast(
                    torch.Tensor, local_tensor_args[0]
                )
                rng_context = (
                    random._rng_tracker._distribute_region(first_arg._spec)
                    if random._rng_tracker and not first_local_arg.is_meta
                    else contextlib.nullcontext()
                )
                # For DTensor random operator, run it within a RNGTracker context to
                # ensure the random number generator is properly distributed.
                with rng_context:
                    local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
            else:
                # normal case, run local sharded op computation
                local_results = op_call(*local_tensor_args, **op_info.local_kwargs)

        else:
            # For a non-participating device (happens on rank that does not belong to
            # the device mesh), we do:
            # 1. if the return type is scalar, set the local result to None.
            # 2. if the return type is Tensor or List[Tensor], return empty
            # tensor(s) with correct dtype.
            spec = output_sharding.output_spec
            ret_list = op_info.schema.op._schema.returns

            if spec is None:
                # For a scalar return type, the non-participating device has None
                # as its local result
                local_results = None
            else:

                def default_tensor(spec: DTensorSpec) -> torch.Tensor:
                    # Placeholder local value for a non-participating rank,
                    # shaped from the spec's tensor metadata.
                    if spec.tensor_meta is not None:
                        shape = spec.tensor_meta.shape
                        dtype = spec.tensor_meta.dtype
                        if len(shape) == 0:
                            # scalar tensor
                            return torch.zeros((), dtype=dtype)
                        else:
                            # non-scalar tensor
                            return torch.tensor([], dtype=dtype)
                    else:
                        raise RuntimeError(f"{spec} has no tensor metadata.")

                if isinstance(spec, DTensorSpec):
                    # return a Tensor value
                    local_results = default_tensor(spec)
                elif isinstance(spec, Sequence):
                    # return a List[Tensor] value
                    local_results = [
                        default_tensor(s) if s is not None else None for s in spec
                    ]
                    assert isinstance(local_results, List)
                    if None in local_results:
                        ret_type = str(ret_list[0].type)
                        raise NotImplementedError(
                            f"return type {ret_type} in DTensor op is not supported"
                        )

        if output_sharding.output_spec is None:
            if op_call == aten.equal.default:
                # For equal operator, The local results from all devices should be all-gathered
                # and a reduce op (AND) will be performed on the list of results to ensure SPMD
                # execution. We can extend this for more ops if necessary.
                obj_list = [None for _ in range(dist.get_world_size())]
                dist.all_gather_object(obj_list, local_results)  # type: ignore[possibly-undefined]
                obj_list = list(filter(lambda x: x is not None, obj_list))
                # perform reduce on the collection with AND op
                local_results = functools.reduce(operator.and_, obj_list, True)

        if _is_inplace_op(op_call):
            # inplace op should return self instead of re-wrapping
            if output_sharding.output_spec is not None:
                return args[0]
            else:
                return None
        elif _is_out_variant_op(op_call):
            # out variant could possibly have multiple out args (i.e. lu_unpack.out)
            output_specs = (
                (output_sharding.output_spec,)
                if not isinstance(output_sharding.output_spec, tuple)
                else output_sharding.output_spec
            )
            out_dts = []
            spec_idx = 0
            for argument in op_call._schema.arguments:
                if argument.is_out:
                    out_dt = cast(dtensor.DTensor, kwargs[argument.name])
                    out_dt._spec = cast(DTensorSpec, output_specs[spec_idx])
                    out_dts.append(out_dt)
                    spec_idx += 1

            assert len(out_dts) >= 1, "out variant should have at least one out arg"
            return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
        else:
            return self.wrap(local_results, output_sharding.output_spec)  # type: ignore[possibly-undefined]

    @staticmethod
    def redistribute_local_args(
        op_info: OpInfo,
        suggested_input_schema: OpSchema,
    ) -> None:
        """
        Redistribute ``op_info.local_args`` in place so each local tensor
        matches the placement suggested by sharding propagation. Args whose
        spec already matches are passed through untouched.
        """
        # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
        if op_info.args_tree_spec is not None:
            flatten_args_schema_to_reshard = tuple(
                pytree.tree_leaves(suggested_input_schema.args_schema)
            )
        else:
            flatten_args_schema_to_reshard = suggested_input_schema.args_schema

        new_local_args: List[object] = []
        for i, arg_spec in enumerate(op_info.flat_args_schema):
            reshard_arg_spec = flatten_args_schema_to_reshard[i]
            if isinstance(arg_spec, DTensorSpec):
                local_tensor = cast(torch.Tensor, op_info.local_args[i])
                if arg_spec != reshard_arg_spec:
                    resharded_local_tensor = redistribute_local_tensor(
                        local_tensor, arg_spec, reshard_arg_spec
                    )
                    new_local_args.append(resharded_local_tensor)
                else:
                    new_local_args.append(local_tensor)
            else:
                # non-tensor args (scalars, etc.) are forwarded as suggested
                new_local_args.append(reshard_arg_spec)

        op_info.local_args = tuple(new_local_args)

    def unwrap_to_op_info(
        self,
        op_call: torch._ops.OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> OpInfo:
        """
        Unwrap DTensor args/kwargs into (a) a schema of DTensorSpecs used for
        sharding propagation and (b) the plain local tensors used for local
        compute, returning both bundled in an OpInfo. Raises if no DeviceMesh
        can be found among the arguments.
        """
        # get runtime schema info to determine whether to use pytree to flatten inputs
        runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
            op_call, None
        )

        if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
            # flatten args/kwargs when op says necessary
            tree_args, args_spec = pytree.tree_flatten(args)
            args_list: Sequence[object] = tree_args
        else:
            args_list, args_spec = args, None

        args_schema: List[object] = []
        kwargs_schema: Dict[str, object] = {}
        local_args: List[object] = []
        local_kwargs: Dict[str, object] = {}
        mesh: Optional[DeviceMesh] = None

        for arg in args_list:
            if isinstance(arg, dtensor.DTensor):
                local_args.append(arg._local_tensor)
                if mesh is not None and mesh != arg.device_mesh:
                    # TODO: try replicate dtensor spec in missing dimension would work
                    # for most cases for foreach case except when the first DTensor in
                    # the list is one that also need to be replicated. We need to revisit
                    # how we want to handle this corner case. For now, this case would hit
                    # the cross mesh error even if implicit replication is turned on.
                    spec = self._try_replicate_dtensor_spec_in_missing_dim(
                        op_call, arg, mesh
                    )
                    args_schema.append(spec)
                else:
                    mesh = arg.device_mesh
                    args_schema.append(arg._spec)
            elif isinstance(arg, torch.Tensor):
                # plain torch.Tensor mixed with DTensor args: only allowed for
                # scalar tensors (or with implicit replication turned on)
                mesh = mesh or try_find_mesh_from_args(op_call, args_list)
                args_schema.append(
                    self._try_replicate_spec_for_scalar_tensor(op_call, arg, mesh)
                )
                local_args.append(arg)
            else:
                # non-tensor arg: passed through unchanged in both views
                args_schema.append(arg)
                local_args.append(arg)

        for k, v in kwargs.items():
            if isinstance(v, dtensor.DTensor):
                local_kwargs[k] = v._local_tensor
                if mesh is not None and mesh != v.device_mesh:
                    spec = self._try_replicate_dtensor_spec_in_missing_dim(
                        op_call, v, mesh
                    )
                    kwargs_schema[k] = spec
                else:
                    mesh = v.device_mesh
                    kwargs_schema[k] = v._spec
            elif isinstance(v, torch.Tensor):
                mesh = mesh or try_find_mesh_from_args(op_call, args_list)
                kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
                    op_call, v, mesh
                )
                local_kwargs[k] = v
            else:
                kwargs_schema[k] = v
                local_kwargs[k] = v

        assert mesh is not None, f"found no DeviceMesh from dtensor args for {op_call}!"
        op_info = OpInfo(
            mesh,
            OpSchema(
                op_call,
                pytree.tree_unflatten(args_schema, args_spec)
                if args_spec
                else tuple(args_schema),
                kwargs_schema,
                schema_info=runtime_schema_info,
            ),
            args_schema,
            tuple(local_args),
            local_kwargs,
            args_spec,
        )
        return op_info

    @staticmethod
    def wrap(res: object, spec: OutputSpecType) -> object:
        """
        Re-wrap local op results into DTensors according to the output spec,
        recursing into list/tuple results; non-tensor values pass through.
        """
        if isinstance(res, torch.Tensor):
            if spec is not None:
                assert isinstance(
                    spec, DTensorSpec
                ), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
                return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
            else:
                # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor
                assert res.ndim == 0, "output tensor should be scalar!"
                return res
        elif isinstance(res, (list, tuple)):
            assert spec is not None and isinstance(
                spec, (list, tuple)
            ), f"output spec does not match with output! Expected list/tuple, got {spec}."
            res_list = []
            for e, s in zip(res, spec):
                res_list.append(OpDispatcher.wrap(e, s))

            return tuple(res_list) if isinstance(res, tuple) else res_list
        else:
            # if the res contains only non tensor values (i.e. int/float/none), we simply return it
            # without rewrapping to DTensor.
            return res

    def _try_replicate_spec_for_scalar_tensor(
        self,
        op_call: torch._ops.OpOverload,
        tensor_arg: torch.Tensor,
        mesh: "DeviceMesh",
    ) -> DTensorSpec:
        # util function to produce a replicate spec for a scalar tensor arg/kwarg
        if tensor_arg.numel() == 1 and tensor_arg.ndim == 1:
            warnings.warn(
                "Found a non-scalar tensor with numel=1 and ndim!=0, "
                "we are implicitly creating a replicated DTensor for it. "
                "However, please consider changing it to a scalar tensor "
                "or explicitly create a DTensor under distributed enviroment."
            )

        if tensor_arg.numel() == 1 or self._allow_implicit_replication:
            # scalar tensor can be safely treated as replicated
            replication_spec = DTensorSpec(
                mesh,
                (Replicate(),) * mesh.ndim,
                tensor_meta=TensorMeta(
                    shape=tensor_arg.shape,
                    stride=tensor_arg.stride(),
                    dtype=tensor_arg.dtype,
                ),
            )
        else:
            raise RuntimeError(
                f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
                " torch.Tensor to DTensor before calling distributed operators!"
            )
        return replication_spec

    def _try_replicate_dtensor_spec_in_missing_dim(
        self,
        op_call: torch._ops.OpOverload,
        dtensor_arg: "dtensor.DTensor",
        mesh: "DeviceMesh",
    ) -> DTensorSpec:
        # util function to produce a new spec for a DTensor arg/kwarg
        # that puts Replicate() placement in the missing dimension for foreach ops
        from torch.distributed.device_mesh import _mesh_resources

        cur_mesh = dtensor_arg.device_mesh
        root_mesh = _mesh_resources.get_root_mesh(cur_mesh)
        if (
            self._allow_implicit_replication
            and "foreach" in op_call.__name__
            and root_mesh == mesh
        ):
            placements = [Replicate() for _ in range(root_mesh.ndim)]
            cur_mesh_root_idx = _mesh_resources.get_root_mesh_dim(cur_mesh)
            placements[cur_mesh_root_idx] = dtensor_arg.placements[0]  # type: ignore[call-overload]
            replicate_spec = DTensorSpec(
                root_mesh,
                tuple(placements),
                tensor_meta=TensorMeta(
                    shape=dtensor_arg.shape,
                    stride=dtensor_arg.stride(),
                    dtype=dtensor_arg.dtype,
                ),
            )
        else:
            raise NotImplementedError(
                f"{op_call}: DTensor does not support cross-mesh operation yet! "
                f"Got meshes: {mesh} {cur_mesh}"
            )
        return replicate_spec
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_dtensor_spec.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Any, cast, List, NamedTuple, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.distributed.device_mesh import DeviceMesh
|
| 6 |
+
from torch.distributed.tensor.placement_types import (
|
| 7 |
+
Partial,
|
| 8 |
+
Placement,
|
| 9 |
+
Replicate,
|
| 10 |
+
Shard,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TensorMeta(NamedTuple):
    """Hashable record of a tensor's shape, stride, and dtype."""

    # simple named tuple to represent tensor metadata
    # intentionally to stay simple only for sharding
    # propagation purposes.
    shape: torch.Size
    stride: Tuple[int, ...]
    dtype: torch.dtype
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# used internally to propagate the placements
|
| 24 |
+
@dataclass
|
| 25 |
+
class DTensorSpec:
|
| 26 |
+
mesh: DeviceMesh
|
| 27 |
+
placements: Tuple[Placement, ...]
|
| 28 |
+
|
| 29 |
+
# tensor meta will only be set during sharding propagation
|
| 30 |
+
tensor_meta: Optional[TensorMeta] = None
|
| 31 |
+
|
| 32 |
+
def __post_init__(self) -> None:
|
| 33 |
+
if not isinstance(self.placements, tuple):
|
| 34 |
+
self.placements = tuple(self.placements)
|
| 35 |
+
self._hash: Optional[int] = None
|
| 36 |
+
|
| 37 |
+
def __setattr__(self, attr: str, value: Any) -> None:
|
| 38 |
+
super().__setattr__(attr, value)
|
| 39 |
+
# Make sure to recompute the hash in case any of the hashed attributes
|
| 40 |
+
# change (though we do not expect `mesh` or `placements` to change)
|
| 41 |
+
if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"):
|
| 42 |
+
self._hash = None
|
| 43 |
+
|
| 44 |
+
def _hash_impl(self) -> int:
|
| 45 |
+
# hashing and equality check for DTensorSpec are used to cache the sharding
|
| 46 |
+
# propagation results. We only need to consider the mesh, placements, shape
|
| 47 |
+
# dtype and stride.
|
| 48 |
+
# Caveat: we need to keep this in mind and sync hash and eq if we add more
|
| 49 |
+
# fields to them.
|
| 50 |
+
if self.tensor_meta is not None:
|
| 51 |
+
return hash(
|
| 52 |
+
(
|
| 53 |
+
self.mesh,
|
| 54 |
+
self.placements,
|
| 55 |
+
self.tensor_meta.shape,
|
| 56 |
+
self.tensor_meta.stride,
|
| 57 |
+
self.tensor_meta.dtype,
|
| 58 |
+
)
|
| 59 |
+
)
|
| 60 |
+
return hash((self.mesh, self.placements))
|
| 61 |
+
|
| 62 |
+
def __hash__(self) -> int:
|
| 63 |
+
# We lazily cache the spec to avoid recomputing the hash upon each
|
| 64 |
+
# use, where we make sure to update the hash when the `tensor_meta`
|
| 65 |
+
# changes by overriding `__setattr__`. This must be lazy so that Dynamo
|
| 66 |
+
# does not try to hash non-singleton `SymInt`s for the stride.
|
| 67 |
+
if self._hash is None:
|
| 68 |
+
self._hash = self._hash_impl()
|
| 69 |
+
return self._hash
|
| 70 |
+
|
| 71 |
+
def __eq__(self, __o: object) -> bool:
|
| 72 |
+
if not (
|
| 73 |
+
isinstance(__o, DTensorSpec)
|
| 74 |
+
and self.mesh == __o.mesh
|
| 75 |
+
and self.placements == __o.placements
|
| 76 |
+
):
|
| 77 |
+
return False
|
| 78 |
+
if self.tensor_meta is None or __o.tensor_meta is None:
|
| 79 |
+
return self.tensor_meta == __o.tensor_meta
|
| 80 |
+
|
| 81 |
+
return (
|
| 82 |
+
self.tensor_meta.shape == __o.tensor_meta.shape # type: ignore[union-attr]
|
| 83 |
+
and self.tensor_meta.stride == __o.tensor_meta.stride # type: ignore[union-attr]
|
| 84 |
+
and self.tensor_meta.dtype == __o.tensor_meta.dtype # type: ignore[union-attr]
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
def __str__(self) -> str:
|
| 88 |
+
"""
|
| 89 |
+
human readable representation of the DTensorSpec
|
| 90 |
+
"""
|
| 91 |
+
if len(self.placements) == 1:
|
| 92 |
+
placement_str = str(self.placements[0])
|
| 93 |
+
else:
|
| 94 |
+
placement_str = str(self.placements)
|
| 95 |
+
|
| 96 |
+
if self.tensor_meta is not None:
|
| 97 |
+
tensor_shape = str(tuple(self.tensor_meta.shape))
|
| 98 |
+
else:
|
| 99 |
+
tensor_shape = "unknown shape"
|
| 100 |
+
|
| 101 |
+
return f"Spec({placement_str} on {tensor_shape})"
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def shape(self) -> torch.Size:
|
| 105 |
+
if self.tensor_meta is None:
|
| 106 |
+
raise ValueError("tensor_meta is not set")
|
| 107 |
+
return self.tensor_meta.shape
|
| 108 |
+
|
| 109 |
+
@property
|
| 110 |
+
def stride(self) -> Tuple[int, ...]:
|
| 111 |
+
if self.tensor_meta is None:
|
| 112 |
+
raise ValueError("tensor_meta is not set")
|
| 113 |
+
return self.tensor_meta.stride
|
| 114 |
+
|
| 115 |
+
@property
|
| 116 |
+
def ndim(self) -> int:
|
| 117 |
+
if self.tensor_meta is None:
|
| 118 |
+
raise ValueError("tensor_meta is not set")
|
| 119 |
+
return len(self.tensor_meta.shape)
|
| 120 |
+
|
| 121 |
+
@property
|
| 122 |
+
def num_shards(self) -> int:
|
| 123 |
+
num_shards = 1
|
| 124 |
+
for i, placement in enumerate(self.placements):
|
| 125 |
+
if placement.is_shard():
|
| 126 |
+
num_shards *= self.mesh.size(i)
|
| 127 |
+
return num_shards
|
| 128 |
+
|
| 129 |
+
    @property
    def device_mesh(self) -> DeviceMesh:
        # Alias for the ``mesh`` field so code that handles both DTensor
        # (which exposes ``device_mesh``) and DTensorSpec can use one name.
        return self.mesh
|
| 134 |
+
|
| 135 |
+
@property
|
| 136 |
+
def dim_map(self) -> List[int]:
|
| 137 |
+
"""
|
| 138 |
+
dim_map is a property we derive from `placements` of
|
| 139 |
+
the distributed tensor. It simply return a list of ints
|
| 140 |
+
where dim_map[i] denotes the sharding mapping to the mesh
|
| 141 |
+
dimension, and len(dim_map) == dist_tensor.ndim
|
| 142 |
+
dim_map[i] = -1: means tensor dim i replicate on mesh
|
| 143 |
+
dim_map[i] = j: means tensor dim i shard on mesh dim j
|
| 144 |
+
|
| 145 |
+
For example, we have a dist tensor that have the shape of
|
| 146 |
+
[18, 20, 30], and device_mesh([0, 1, 2, 3]), placements:
|
| 147 |
+
[Shard(1)], the dim_map of this placement would be:
|
| 148 |
+
[-1, 0, -1]. This representation is pretty helpful during
|
| 149 |
+
sharding propagation where we could know exactly each
|
| 150 |
+
tensor dimension is sharded or not.
|
| 151 |
+
|
| 152 |
+
Note that if placements contains `_Partial`, we have to
|
| 153 |
+
explicitly deal with it, so that when we create a DTensorSpec
|
| 154 |
+
with dim_map, we could properly record the pending sums.
|
| 155 |
+
"""
|
| 156 |
+
# dims mapping of dist tensor sharding
|
| 157 |
+
# return size of tensor ndim, -1 represent replicate
|
| 158 |
+
# and int >=0 represent shard on that device mesh dim
|
| 159 |
+
r = [-1] * self.ndim
|
| 160 |
+
for i, placement in enumerate(self.placements):
|
| 161 |
+
if placement.is_shard():
|
| 162 |
+
shard_dim = cast(Shard, placement).dim
|
| 163 |
+
if r[shard_dim] > -1:
|
| 164 |
+
raise ValueError(
|
| 165 |
+
f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]},"
|
| 166 |
+
" DTensor operator implementation does not support things like hybrid"
|
| 167 |
+
" sharding strategies yet (i.e. [Shard(0), Shard(0)])"
|
| 168 |
+
)
|
| 169 |
+
r[shard_dim] = i
|
| 170 |
+
return r
|
| 171 |
+
|
| 172 |
+
    @property
    def num_shards_map(self) -> List[int]:
        """
        num_shards_map is a property we derive from `placements` of
        the distributed tensor. Unlike `dim_map`, `num_shards_map`
        denotes how many shards each tensor dim has. Like `dim_map`:
            len(num_shards_map) == dist_tensor.ndim
            num_shards_map[i] = 1: means tensor dim i is not sharded
            num_shards_map[i] = j: means tensor dim i has j shards in total

        For example, we have a dist tensor of shape [18, 20, 30],
        a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements
        ([Shard(1), Shard(0)]), the num_shards_map of this distributed tensor
        would be: [4, 2, 1].
        """
        # Multiply (rather than assign) so a tensor dim sharded on several
        # mesh dims accumulates the full shard count.
        r = [1] * self.ndim
        for i, placement in enumerate(self.placements):
            if placement.is_shard():
                shard_dim = cast(Shard, placement).dim
                r[shard_dim] *= self.mesh.size(i)

        return r
|
| 194 |
+
|
| 195 |
+
@property
|
| 196 |
+
def sums(self) -> List[int]:
|
| 197 |
+
"""
|
| 198 |
+
sums is a property we derive from `placements` of the
|
| 199 |
+
distributed tensor. It simply return a list of ints where
|
| 200 |
+
sums[i] denotes the pending sum (partial) on mesh dim i
|
| 201 |
+
"""
|
| 202 |
+
return [
|
| 203 |
+
idx
|
| 204 |
+
for idx, placement in enumerate(self.placements)
|
| 205 |
+
if placement.is_partial()
|
| 206 |
+
]
|
| 207 |
+
|
| 208 |
+
@classmethod
|
| 209 |
+
def from_dim_map(
|
| 210 |
+
cls,
|
| 211 |
+
mesh: DeviceMesh,
|
| 212 |
+
dim_map: List[int],
|
| 213 |
+
sums: List[int],
|
| 214 |
+
tensor_meta: Optional[TensorMeta] = None,
|
| 215 |
+
) -> "DTensorSpec":
|
| 216 |
+
"""
|
| 217 |
+
Construct a DTensorSpec from dim_map list and pending sum.
|
| 218 |
+
|
| 219 |
+
Args:
|
| 220 |
+
mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec
|
| 221 |
+
dim_map (List[int]): a list of integer that represents sharding on each
|
| 222 |
+
tensor dimension, see `dim_map` property doc for details
|
| 223 |
+
sums (List[int]): a list of integer that represents the dist tensor have
|
| 224 |
+
pending sum on which device mesh dimension.
|
| 225 |
+
tensor meta (TensorMeta): DTensor metadata
|
| 226 |
+
|
| 227 |
+
Return:
|
| 228 |
+
a class:`DTensorSpec` object
|
| 229 |
+
"""
|
| 230 |
+
# by default replicate on device mesh dims
|
| 231 |
+
placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)]
|
| 232 |
+
|
| 233 |
+
# find all mesh dims that need pending reductions
|
| 234 |
+
for s in sums:
|
| 235 |
+
placements[s] = Partial()
|
| 236 |
+
|
| 237 |
+
for i, m in enumerate(dim_map):
|
| 238 |
+
if m >= 0:
|
| 239 |
+
placement = placements[m]
|
| 240 |
+
if placement.is_shard():
|
| 241 |
+
placement = cast(Shard, placement)
|
| 242 |
+
raise RuntimeError(
|
| 243 |
+
f"DeviceMesh dimension cann't be mapped to two dimension of the same tensor: {i} and {placement.dim}"
|
| 244 |
+
)
|
| 245 |
+
elif placement.is_partial():
|
| 246 |
+
raise RuntimeError(
|
| 247 |
+
f"DeviceMesh dimension {m} cannot be both shard and partial!"
|
| 248 |
+
)
|
| 249 |
+
placements[m] = Shard(i)
|
| 250 |
+
|
| 251 |
+
return cls(mesh, tuple(placements), tensor_meta=tensor_meta)
|
| 252 |
+
|
| 253 |
+
def is_replicated(self) -> bool:
|
| 254 |
+
"""
|
| 255 |
+
return True if the current DTensorSpec replicates on all mesh dims (devices)
|
| 256 |
+
"""
|
| 257 |
+
return all(placement.is_replicate() for placement in self.placements)
|
| 258 |
+
|
| 259 |
+
def is_sharded(self) -> bool:
|
| 260 |
+
"""
|
| 261 |
+
return True if the current DTensorSpec is sharded on any mesh dims (devices)
|
| 262 |
+
"""
|
| 263 |
+
return any(placement.is_shard() for placement in self.placements)
|
| 264 |
+
|
| 265 |
+
    def shallow_copy_with_tensor_meta(
        self, tensor_meta: Optional[TensorMeta]
    ) -> "DTensorSpec":
        """
        Shallow copy the DTensorSpec with a new tensor_meta.

        The ``mesh`` and ``placements`` objects are shared with ``self``,
        not copied; only ``tensor_meta`` differs on the returned spec.
        """
        assert tensor_meta is not None, "shallow copy with no tensor_meta!"
        return DTensorSpec(
            self.mesh,
            self.placements,
            tensor_meta=tensor_meta,
        )
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_op_schema.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from functools import cached_property
|
| 4 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch._ops import OpOverload
|
| 8 |
+
from torch.distributed.device_mesh import DeviceMesh
|
| 9 |
+
from torch.distributed.tensor._dtensor_spec import DTensorSpec
|
| 10 |
+
from torch.distributed.tensor.placement_types import Placement
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from torch.utils._cxx_pytree import tree_leaves, tree_map_only, TreeSpec
|
| 15 |
+
except ImportError:
|
| 16 |
+
from torch.utils._pytree import ( # type: ignore[no-redef, assignment]
|
| 17 |
+
tree_leaves,
|
| 18 |
+
tree_map_only,
|
| 19 |
+
TreeSpec,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Common type aliases
|
| 24 |
+
ArgsType = Tuple[object, ...]
|
| 25 |
+
KwargsType = Dict[str, object]
|
| 26 |
+
|
| 27 |
+
PlacementList = List[Optional[Placement]]
|
| 28 |
+
|
| 29 |
+
# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type sould
|
| 30 |
+
# be the same set of possibilities.
|
| 31 |
+
OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _rebuild_tensor_from_dtensor_meta(arg) -> object:
    """
    Build an empty strided tensor matching ``arg.tensor_meta``
    (shape/stride/dtype). Used to propagate tensor metadata; must be
    under fake mode.
    """
    meta = arg.tensor_meta
    assert meta is not None, "DTensorSpec does not contain tensor_meta."
    return torch.empty_strided(meta.shape, meta.stride, dtype=meta.dtype)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _is_inplace_op(op: OpOverload):
    # Heuristic from the function schema: ATen in-place overloads end in a
    # trailing underscore (e.g. ``add_``). Might not be entirely correct,
    # but it's good enough for now.
    schema_name = op._schema.name
    return schema_name[-1] == "_"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _is_out_variant_op(op: OpOverload):
    # Heuristic from the function schema: out-variant overloads carry "out"
    # in their overload name. Might not be entirely correct, but it's good
    # enough for now.
    overload_name = op._schema.overload_name
    return "out" in overload_name
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _pretty_print_spec(spec: object) -> str:
    """Render a DTensorSpec (or a nested sequence of them) compactly for logging."""
    if spec is None:
        return "None"
    if isinstance(spec, DTensorSpec):
        # Concatenate the placements without separators, e.g. "S(0)R".
        return "".join(str(p) for p in spec.placements)
    if isinstance(spec, Sequence):
        inner = ", ".join(_pretty_print_spec(s) for s in spec)
        return "(" + inner + ")"
    raise RuntimeError(f"Unknown spec type to print: spec={spec}")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
class PlacementStrategy:
    """
    One acceptable sharding choice for an op: the placement spec(s) of its
    output and, optionally, of its tensor inputs.

    note: when the op return value is a single DTensor object, output_specs is
    DTensorSpec; when the return value is a tuple of Optional[DTensor],
    output_specs is a tuple of Optional[DTensorSpec].
    """

    output_specs: Union[DTensorSpec, Tuple[Optional[DTensorSpec], ...]]
    input_specs: Optional[Sequence[DTensorSpec]] = None

    # Nested cost table: one row per operand of the op, one entry per
    # candidate placement strategy of that operand.
    redistribute_cost: Optional[List[List[float]]] = None

    @cached_property
    def output_spec(self) -> DTensorSpec:
        """
        Return the single output DTensorSpec.

        Raises ValueError when ``output_specs`` is a tuple rather than a
        lone DTensorSpec.
        """
        if not isinstance(self.output_specs, DTensorSpec):
            raise ValueError(
                f"function output_spec expects a single DTensorSpec but got: {self.output_specs}"
            )
        return self.output_specs

    def input_spec(self, index: int = 0) -> DTensorSpec:
        """Return the input spec at ``index``; asserts input_specs exists."""
        assert self.input_specs is not None, "input_specs of PlacementStrategy is None!"
        assert len(self.input_specs) > index, (
            f"Invalid index {index} for input_specs of length "
            f"{len(self.input_specs)}: {self.input_specs}"
        )
        return self.input_specs[index]

    def __str__(self) -> str:
        prefix = (
            f"{_pretty_print_spec(self.input_specs)} -> "
            if self.input_specs is not None
            else ""
        )
        return f"{prefix}{_pretty_print_spec(self.output_specs)}"
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class StrategyType:
    """
    Base class for op sharding strategies.

    Two concrete StrategyType subclasses exist: OpStrategy (a list of
    candidate placement strategies) and TupleStrategy (one child strategy
    per element of a tuple/list output).
    """
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class OpStrategy(StrategyType):
    """
    The full list of candidate placement strategies associated with one op.
    """

    def __init__(self, strategies: List[PlacementStrategy]) -> None:
        super().__init__()
        self.strategies: List[PlacementStrategy] = strategies

    def __str__(self) -> str:
        joined = ", ".join(str(candidate) for candidate in self.strategies)
        return f"[{joined}] @ mesh: {self.mesh_shape}"

    def max_num_shards(self) -> int:
        """
        Returns the max number of shards across all placement strategies
        """
        return max(candidate.output_spec.num_shards for candidate in self.strategies)

    @property
    def mesh_shape(self):
        # Recover the mesh from the first strategy's output spec(s); all
        # strategies of one op live on the same mesh.
        spec = self.strategies[0].output_specs
        if isinstance(spec, DTensorSpec):
            return spec.mesh.shape
        assert isinstance(spec, tuple), "found no DTensorSpec in the OpStrategy!"
        assert spec[0] is not None
        return spec[0].mesh.shape

    @property
    def ndim(self):
        return self.strategies[0].output_spec.ndim

    @property
    def shape(self):
        return self.strategies[0].output_spec.shape
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class TupleStrategy(StrategyType):
    """
    Strategy for an op whose output is a tuple/list of tensors whose elements
    may each be placed differently: ``childs`` holds one child strategy
    (typically an OpStrategy) per element of the returned tuple/list.

    NOTE: if the op returns a List[Tensor] whose elements all share the same
    placement strategy, return a single OpStrategy instead of a TupleStrategy.
    """

    def __init__(self, childs: Sequence[StrategyType]) -> None:
        super().__init__()
        self.childs: Sequence[StrategyType] = childs

    def __str__(self) -> str:
        inner = ", ".join(str(child) for child in self.childs)
        return f"TupleStrategy({inner})"
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
@dataclass
class RuntimeSchemaInfo:
    """
    RuntimeSchemaInfo stores the operator schema related information for runtime (eager)
    execution. It is used in two ways: 1. to generate a hash for the args to determine
    whether to re-run sharding prop or not 2. to determine if we need pytree
    flatten/unflatten of the args.
    """

    # This static_argnum records static arg "starting index" for ops that have non-tensor
    # args/kwargs which would affect sharding propagation results. All args starting from
    # this index would be hashed to our sharding cache.
    # Note that only a few ops need this information, e.g. view, transpose, var.dim, etc.
    static_argnum: int = 100
    # This static_kwargkey records static kwarg names which would affect sharding prop
    static_kwargkey: Optional[List[str]] = None
    # each op can decide if it wants to use pytree flatten/unflatten during operator
    # eager execution, by default we don't need to do flatten/unflatten, only if the
    # op indicate it needs to, this is to accelerate eager performance.
    needs_pytree: bool = False
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
@dataclass
class OpSchema:
    """
    OpSchema is a data class that describes an operator input schemas, it includes
    DTensorSpecs (instead of DTensor) and non-tensor args/kwargs (positional order
    preserved). It is mainly used by the DTensor's dispatching logic to perform various
    actions (i.e. sharding propagation, caching sharding decisions, redistribute, etc.)

    NOTE: this should be used as a read only data class
    TODO: make this a frozen dataclass

    Args:
        op: the operator overload we are intercepting
        args_schema: contains args except that the DTensor args have been replaced
            with its DTensorSpec or OpStrategy
        kwargs_schema: contains kwargs except that the DTensor kwargs have been replaced
            with its DTensorSpec or OpStrategy
    """

    op: OpOverload
    args_schema: ArgsType
    kwargs_schema: KwargsType

    schema_info: Optional[RuntimeSchemaInfo] = None

    @property
    def args_spec(self) -> Tuple[DTensorSpec, ...]:
        """
        args_spec: Tuple[DTensorSpec, ...]: contains a clean list of args spec list
        with NO non-DTensor positional arguments (i.e. int/float/tuple, etc)
        mainly used by sharding propagation to propagate the output spec
        """
        # Only pytree-flatten when the op opted in via schema_info; flattening
        # is skipped otherwise to keep eager dispatch fast.
        args = (
            tree_leaves(self.args_schema)
            if self.schema_info is not None and self.schema_info.needs_pytree
            else self.args_schema
        )
        return tuple(item for item in args if isinstance(item, DTensorSpec))

    @property
    def args_strategy(self) -> Tuple[OpStrategy, ...]:
        # filter out non-relevant values from args schema to get a clean OpStrategy list
        # separate with args_spec for the ease of type annotation
        # TODO: see if we should merge this with args_spec
        args = (
            tree_leaves(self.args_schema)
            if self.schema_info is not None and self.schema_info.needs_pytree
            else self.args_schema
        )
        return tuple(item for item in args if isinstance(item, OpStrategy))

    def __repr__(self) -> str:
        args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema])
        return (
            f"OpSchema(op={self.op},"
            f" args_schema=({args_schema}),"
            f" kwargs_schema={self.kwargs_schema})"
        )

    def __str__(self) -> str:
        # Compact one-line rendering; also records the mesh shape of the last
        # tensor-like arg encountered, for display only.
        args_schema: List[str] = []
        mesh_shape = None
        for arg in self.args_schema:
            if isinstance(arg, DTensorSpec):
                args_schema.append(str(arg))
                mesh_shape = arg.mesh.shape
            elif isinstance(arg, OpStrategy):
                assert len(arg.strategies) == 1
                args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs))
                mesh_shape = arg.mesh_shape
            elif isinstance(arg, TupleStrategy):
                first_op_strtgy = arg.childs[0]
                assert isinstance(first_op_strtgy, OpStrategy)
                mesh_shape = first_op_strtgy.mesh_shape
                args_schema.append(str(arg))
            else:
                args_schema.append(str(arg))
        return f"Op(op={self.op}, args_schema={', '.join(args_schema)} @ mesh: {mesh_shape})"

    def __post_init__(self) -> None:
        # Cache whether any tensor arg has a symbolic (SymInt) shape; computed
        # once here so later checks don't rescan the args.
        has_symints = False
        for a in self.args_schema:
            if isinstance(a, DTensorSpec) and a.tensor_meta is not None:
                if any(isinstance(s, torch.SymInt) for s in a.tensor_meta.shape):
                    has_symints = True
                    break
        self.has_symints = has_symints

    def arg_type_tensor_or_tensor_list_like(self, arg_idx: int) -> bool:
        """Return True if the arg at ``arg_idx`` is a DTensorSpec or a list of
        DTensorSpec/None entries (tensor-list-like)."""
        arg = self.args_schema[arg_idx]
        is_tensor = isinstance(arg, DTensorSpec)
        if is_tensor:
            return True

        if not isinstance(arg, list):
            return False

        return all(isinstance(e, DTensorSpec) or e is None for e in arg)

    def return_type_tuple_tensor_like(self) -> bool:
        # all dispatch ops could only return Tuple[Tensor] or have None/ints/floats
        # in the tuple, but the first element must be a Tensor, so this check is enough
        return_types = self.op._schema.returns
        return len(return_types) > 1 and isinstance(
            return_types[0].type, torch.TensorType
        )

    def return_type_tensor(self) -> bool:
        return_types = self.op._schema.returns
        # all dispatch ops only return Tensor or Tuple[Tensor] for tensor like
        # return types, so this check is enough for tensor like types
        return isinstance(return_types[0].type, torch.TensorType)

    def __hash__(self) -> int:
        # Only hash args and kwargs that op indicates to hash:
        # tensor-like args always participate; other args only from
        # static_argnum onward; kwargs only those named in static_kwargkey.
        if not self.schema_info:
            static_argnum = len(self.args_schema)
            static_kwargkey = None
        else:
            static_argnum = self.schema_info.static_argnum
            static_kwargkey = self.schema_info.static_kwargkey

        # Lists are unhashable, so tensor-list args are converted to tuples.
        args_to_hash = tuple(
            tuple(e) if isinstance(e, list) else e
            for i, e in enumerate(self.args_schema)
            if self.arg_type_tensor_or_tensor_list_like(i) or i >= static_argnum
        )
        if static_kwargkey is not None:
            kwargs_to_hash = tuple(
                self.kwargs_schema.get(k, None) for k in static_kwargkey
            )
            return hash((self.op, args_to_hash, kwargs_to_hash))
        else:
            return hash((self.op, args_to_hash))

    def __eq__(self, other: object) -> bool:
        # Mirrors __hash__: only the op, the tensor-like args, the args from
        # static_argnum onward, and the static kwargs participate in equality.
        # early return checks
        if not isinstance(other, OpSchema):
            return False

        if self.op != other.op:
            return False

        if len(self.args_schema) != len(other.args_schema):
            return False

        # compare each element and early return if any of them is different
        if not self.schema_info:
            static_argnum = len(self.args_schema)
            static_kwargkey = None
        else:
            static_argnum = self.schema_info.static_argnum
            static_kwargkey = self.schema_info.static_kwargkey

        for i, (self_arg, other_arg) in enumerate(
            zip(self.args_schema, other.args_schema)
        ):
            if isinstance(self_arg, DTensorSpec) and self_arg != other_arg:
                return False
            elif i >= static_argnum and self_arg != other_arg:
                return False

        # check kwarg equality when there's a static kwarg key
        if static_kwargkey:
            for key in static_kwargkey:
                if self.kwargs_schema.get(key, None) != other.kwargs_schema.get(
                    key, None
                ):
                    return False

        return True

    def gen_fake_args(self) -> ArgsType:
        """
        gen_fake_args: generate fake args for the operator, this is mainly used
        by sharding propagation rules to generate fake args for the operator
        to run the local tensor operator and get the output spec.
        """
        return tree_map_only(
            DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.args_schema
        )

    def gen_fake_kwargs(self) -> KwargsType:
        """
        gen_fake_kwargs: generate fake kwargs for the operator, this is mainly used
        by sharding propagation rules to generate fake kwargs for the operator
        to run the local tensor operator and get the output spec.
        """
        return tree_map_only(
            DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.kwargs_schema
        )

    def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
        # Rewrite this schema in place: keep origin_schema's non-tensor args
        # and kwargs, but substitute this schema's DTensorSpecs (in order)
        # for origin_schema's tensor args.
        suggestion_args_spec = self.args_spec
        new_arg_schema: List[object] = []
        idx_of_args_spec = 0
        if (
            origin_schema.schema_info is not None
            and origin_schema.schema_info.needs_pytree
        ):
            args_schema: Sequence[Any] = tree_leaves(origin_schema.args_schema)
        else:
            args_schema = origin_schema.args_schema
        for arg in args_schema:
            if isinstance(arg, DTensorSpec):
                new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
                idx_of_args_spec += 1
            else:
                new_arg_schema.append(arg)
        self.args_schema = tuple(new_arg_schema)
        self.kwargs_schema = origin_schema.kwargs_schema
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
@dataclass
class OutputSharding:
    """
    OutputSharding is a data class that is used by the sharding propagation,
    it could set the output_spec upon successful propagation. If needs_redistribute
    is set to True, a redistribute_schema would be returned together to indicate
    the input arguments needs to be redistributed before the op execution.

    NOTE: the redistribute_schema generated by sharding propagation should be
    exactly the same as the operator OpSchema, except the DTensorSpecs
    """

    # Propagated spec(s) of the op output (single spec, sequence, or None).
    output_spec: OutputSpecType
    # Schema describing the placements inputs should be redistributed to;
    # only meaningful when needs_redistribute is True.
    redistribute_schema: Optional[OpSchema] = None
    # True when the inputs must be redistributed before running the op.
    needs_redistribute: bool = False
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
@dataclass
class OpInfo:
    """
    All Runtime Op execution info are packed here
    """

    # Device mesh the op is being dispatched on.
    mesh: DeviceMesh
    # Op schema with DTensor args replaced by their DTensorSpecs.
    schema: OpSchema
    # Flattened (pytree-leaf) view of the args schema.
    flat_args_schema: List[object]
    # Arguments for the local op call — presumably the DTensor args replaced
    # by their local shards; verify against the dispatch caller.
    local_args: Sequence[object]
    local_kwargs: Dict[str, object]
    # Pytree spec for unflattening flat_args_schema; None when pytree unused.
    args_tree_spec: Optional[TreeSpec] = None

    # the output sharding info
    output_sharding: Optional[OutputSharding] = None
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
| 2 |
+
from ._conv_ops import * # noqa: F403
|
| 3 |
+
from ._embedding_ops import * # noqa: F403
|
| 4 |
+
from ._experimental_ops import * # noqa: F403
|
| 5 |
+
from ._math_ops import * # noqa: F403
|
| 6 |
+
from ._matrix_ops import * # noqa: F403
|
| 7 |
+
from ._pointwise_ops import * # noqa: F403
|
| 8 |
+
from ._random_ops import * # noqa: F403
|
| 9 |
+
from ._tensor_ops import * # noqa: F403
|
| 10 |
+
from ._view_ops import * # noqa: F403
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (508 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_common_rules.cpython-311.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_conv_ops.cpython-311.pyc
ADDED
|
Binary file (4.33 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_einsum_strategy.cpython-311.pyc
ADDED
|
Binary file (7.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_embedding_ops.cpython-311.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_experimental_ops.cpython-311.pyc
ADDED
|
Binary file (1.58 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_math_ops.cpython-311.pyc
ADDED
|
Binary file (42.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_matrix_ops.cpython-311.pyc
ADDED
|
Binary file (16.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_pointwise_ops.cpython-311.pyc
ADDED
|
Binary file (30.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/distributed/tensor/_ops/__pycache__/_random_ops.cpython-311.pyc
ADDED
|
Binary file (1.84 kB). View file
|
|
|