Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- phivenv/Lib/site-packages/numpy.libs/libscipy_openblas64_-caad452230ae4ddb57899b8b3a33c55c.dll +3 -0
- phivenv/Lib/site-packages/pip/_vendor/distlib/t64-arm.exe +3 -0
- phivenv/Lib/site-packages/pip/_vendor/distlib/t64.exe +3 -0
- phivenv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe +3 -0
- phivenv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-39.pyc +3 -0
- phivenv/Lib/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-39.pyc +3 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/_checkpointable.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/_composable_state.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/_serialization.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/argparse_util.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/collective_utils.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/constants.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/device_mesh.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/launch.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/remote_device.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/rendezvous.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/run.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/__pycache__/utils.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/_composable/__init__.py +3 -0
- phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/_composable/checkpoint_activation.py +132 -0
- phivenv/Lib/site-packages/torch/distributed/_composable/contract.py +248 -0
- phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__init__.py +3 -0
- phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/fully_shard.py +8 -0
- phivenv/Lib/site-packages/torch/distributed/_composable/replicate.py +256 -0
- phivenv/Lib/site-packages/torch/distributed/_shard/__init__.py +1 -0
- phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/api.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/_shard/_utils.py +32 -0
- phivenv/Lib/site-packages/torch/distributed/_shard/api.py +306 -0
- phivenv/Lib/site-packages/torch/distributed/_shard/checkpoint/__init__.py +19 -0
- phivenv/Lib/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-39.pyc +0 -0
- phivenv/Lib/site-packages/torch/distributed/_shard/common_op_utils.py +65 -0
.gitattributes
CHANGED
|
@@ -48,3 +48,9 @@ phivenv/Lib/site-packages/numpy/_core/__pycache__/fromnumeric.cpython-39.pyc fil
|
|
| 48 |
phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_umath.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
|
| 49 |
phivenv/Lib/site-packages/numpy/_core/__pycache__/_add_newdocs.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
|
| 50 |
phivenv/Lib/site-packages/numpy.libs/msvcp140-23ebcc0b37c8e3d074511f362feac48b.dll filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_umath.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
|
| 49 |
phivenv/Lib/site-packages/numpy/_core/__pycache__/_add_newdocs.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
|
| 50 |
phivenv/Lib/site-packages/numpy.libs/msvcp140-23ebcc0b37c8e3d074511f362feac48b.dll filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
phivenv/Lib/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
phivenv/Lib/site-packages/numpy.libs/libscipy_openblas64_-caad452230ae4ddb57899b8b3a33c55c.dll filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
phivenv/Lib/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
phivenv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
phivenv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
phivenv/Lib/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
|
phivenv/Lib/site-packages/numpy.libs/libscipy_openblas64_-caad452230ae4ddb57899b8b3a33c55c.dll
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:44629a7d27806ea076daeae8e829b0cfbdec9e25099561a19af8e5910bd635c5
|
| 3 |
+
size 32816640
|
phivenv/Lib/site-packages/pip/_vendor/distlib/t64-arm.exe
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f1618387a688f162408e7811350a72269076d52bf6d0f09860548d5b57d677ac
|
| 3 |
+
size 180736
|
phivenv/Lib/site-packages/pip/_vendor/distlib/t64.exe
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a00a877acefcad45953343ad56a22152f7aaba5fcf2a10215d84169d47fbcd1d
|
| 3 |
+
size 105984
|
phivenv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:43f1ddcd5bbdcf161d6816b79b4889e7f75d2ce12ab4f7bcc77d16003a17cdaf
|
| 3 |
+
size 166400
|
phivenv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-39.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a8acb4dd7cd594effc85e8c2b9ac052d6f4fe88744cd4749a8e8b8b93ba88246
|
| 3 |
+
size 151716
|
phivenv/Lib/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-39.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7a6e2c125af98ae3013115aad3c6156dd30340dd0c77863105db036c061ddc8e
|
| 3 |
+
size 176641
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (4.49 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/_checkpointable.cpython-39.pyc
ADDED
|
Binary file (1.78 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/_composable_state.cpython-39.pyc
ADDED
|
Binary file (1.49 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-39.pyc
ADDED
|
Binary file (32.8 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-39.pyc
ADDED
|
Binary file (2.75 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/_serialization.cpython-39.pyc
ADDED
|
Binary file (4.76 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-39.pyc
ADDED
|
Binary file (21.2 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/argparse_util.cpython-39.pyc
ADDED
|
Binary file (3.95 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-39.pyc
ADDED
|
Binary file (3.12 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/collective_utils.cpython-39.pyc
ADDED
|
Binary file (5.46 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/constants.cpython-39.pyc
ADDED
|
Binary file (532 Bytes). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/device_mesh.cpython-39.pyc
ADDED
|
Binary file (31.7 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/launch.cpython-39.pyc
ADDED
|
Binary file (7.81 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-39.pyc
ADDED
|
Binary file (341 Bytes). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/remote_device.cpython-39.pyc
ADDED
|
Binary file (3.8 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/rendezvous.cpython-39.pyc
ADDED
|
Binary file (8.85 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/run.cpython-39.pyc
ADDED
|
Binary file (27.2 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (12.1 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/_composable/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .checkpoint_activation import checkpoint
|
| 2 |
+
from .contract import _get_registry, contract
|
| 3 |
+
from .replicate import replicate
|
phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (316 Bytes). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-39.pyc
ADDED
|
Binary file (4.54 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-39.pyc
ADDED
|
Binary file (6.96 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-39.pyc
ADDED
|
Binary file (7.13 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/_composable/checkpoint_activation.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from collections.abc import Generator
|
| 3 |
+
from contextlib import AbstractContextManager, contextmanager, nullcontext
|
| 4 |
+
from typing import Any, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.utils.checkpoint import (
|
| 9 |
+
_checkpoint_without_reentrant_generator,
|
| 10 |
+
_DEFAULT_DETERMINISM_MODE,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
from .contract import _State, contract
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@contextmanager
|
| 17 |
+
def _no_hook(module: nn.Module, user_ctx: Optional[AbstractContextManager] = None):
|
| 18 |
+
r"""
|
| 19 |
+
Disable hooks installed by checkpoint to avoid unintentional recursion
|
| 20 |
+
during backward recomputation.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
with user_ctx if user_ctx else nullcontext():
|
| 24 |
+
orig_enable_hook = checkpoint.state(module).enable_hook
|
| 25 |
+
checkpoint.state(module).enable_hook = False
|
| 26 |
+
try:
|
| 27 |
+
yield
|
| 28 |
+
finally:
|
| 29 |
+
checkpoint.state(module).enable_hook = orig_enable_hook
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class _CheckpointState(_State):
|
| 33 |
+
enable_hook: bool = False
|
| 34 |
+
_ac_generator: Optional[Generator[None, None, None]]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@contract(_CheckpointState)
|
| 38 |
+
def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
|
| 39 |
+
r"""
|
| 40 |
+
This is a composable activation checkpointing API. Unlike functional
|
| 41 |
+
activation checkpointing APIs, this one does not require changing model
|
| 42 |
+
source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
|
| 43 |
+
this one does not modify model structure or fully-qualified names either.
|
| 44 |
+
Under the hood, it registers activation checkpointing logic as pre- and
|
| 45 |
+
post-forward hooks. Hence, this API can be easily applied to any model or
|
| 46 |
+
sub-modules in the model.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
module (nn.Module): the target model or sub-module to apply activation
|
| 50 |
+
checkpointing.
|
| 51 |
+
|
| 52 |
+
Example::
|
| 53 |
+
>>> # xdoctest: +SKIP
|
| 54 |
+
>>> import torch.nn as nn
|
| 55 |
+
>>>
|
| 56 |
+
>>> class MyModel(nn.Module):
|
| 57 |
+
>>> def __init__(self) -> None:
|
| 58 |
+
>>> super().__init__()
|
| 59 |
+
>>> self.l1 = nn.Linear(10, 10)
|
| 60 |
+
>>> self.l2 = nn.Linear(10, 10)
|
| 61 |
+
>>>
|
| 62 |
+
>>> def forward(self, x):
|
| 63 |
+
>>> return self.l2(self.l1(x))
|
| 64 |
+
>>>
|
| 65 |
+
>>> model = MyModel()
|
| 66 |
+
>>> checkpoint(model.l1) # apply activation checkpointing only to l1
|
| 67 |
+
>>> model(torch.zeros(2, 10)).sum().backward()
|
| 68 |
+
|
| 69 |
+
"""
|
| 70 |
+
torch._C._log_api_usage_once("torch.distributed.checkpoint")
|
| 71 |
+
|
| 72 |
+
use_reentrant = kwargs.pop("use_reentrant", False)
|
| 73 |
+
if use_reentrant:
|
| 74 |
+
raise NotImplementedError(
|
| 75 |
+
"use_reentrant=True is not supported in composable checkpoint. "
|
| 76 |
+
"Please use torch.utils.checkpoint.checkpoint instead."
|
| 77 |
+
)
|
| 78 |
+
preserve_rng_state = kwargs.pop("preserve_rng_state", True)
|
| 79 |
+
user_context_fns = kwargs.pop("context_fn", None)
|
| 80 |
+
determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE)
|
| 81 |
+
debug = kwargs.pop("debug", False)
|
| 82 |
+
|
| 83 |
+
if kwargs:
|
| 84 |
+
raise ValueError(
|
| 85 |
+
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
def forward_pre_hook(
|
| 89 |
+
module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
|
| 90 |
+
) -> None:
|
| 91 |
+
if checkpoint.state(module).enable_hook:
|
| 92 |
+
|
| 93 |
+
def context_fns():
|
| 94 |
+
if user_context_fns is not None:
|
| 95 |
+
ctx1, ctx2 = user_context_fns()
|
| 96 |
+
return ctx1, _no_hook(module, ctx2)
|
| 97 |
+
else:
|
| 98 |
+
return nullcontext(), _no_hook(module)
|
| 99 |
+
|
| 100 |
+
gen = _checkpoint_without_reentrant_generator(
|
| 101 |
+
module,
|
| 102 |
+
preserve_rng_state,
|
| 103 |
+
context_fns,
|
| 104 |
+
determinism_check,
|
| 105 |
+
debug,
|
| 106 |
+
*args,
|
| 107 |
+
**kwargs,
|
| 108 |
+
)
|
| 109 |
+
checkpoint.state(module)._ac_generator = gen
|
| 110 |
+
next(gen)
|
| 111 |
+
|
| 112 |
+
def forward_hook(module: nn.Module, inputs: tuple[Any, ...], output: Any) -> Any:
|
| 113 |
+
if checkpoint.state(module).enable_hook:
|
| 114 |
+
try:
|
| 115 |
+
gen = checkpoint.state(module)._ac_generator
|
| 116 |
+
assert gen is not None
|
| 117 |
+
next(gen)
|
| 118 |
+
except StopIteration:
|
| 119 |
+
pass
|
| 120 |
+
else:
|
| 121 |
+
raise RuntimeError(
|
| 122 |
+
"Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!"
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# Ensure that we no longer hold on to the generator. always_call=True helps ensure we
|
| 126 |
+
# clear this even in the case of exception in fwd pass.
|
| 127 |
+
checkpoint.state(module)._ac_generator = None
|
| 128 |
+
|
| 129 |
+
checkpoint.state(module).enable_hook = True
|
| 130 |
+
module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
|
| 131 |
+
module.register_forward_hook(forward_hook, prepend=True, always_call=True)
|
| 132 |
+
return module
|
phivenv/Lib/site-packages/torch/distributed/_composable/contract.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import uuid
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
from functools import wraps
|
| 5 |
+
from typing import Callable, Generic, Optional, Protocol
|
| 6 |
+
from typing_extensions import Concatenate, ParamSpec, TypeVar
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from torch.distributed._composable_state import _State
|
| 11 |
+
from torch.distributed.utils import _get_root_modules
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_T = TypeVar("_T", covariant=True)
|
| 15 |
+
_P = ParamSpec("_P")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def generate_state_key(string="__composable_api_state_key"):
|
| 19 |
+
return f"{string}_{str(uuid.uuid4())}"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
STATE_KEY = generate_state_key()
|
| 23 |
+
REGISTRY_KEY = generate_state_key()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
|
| 27 |
+
# we can add args and kwargs here, and then we can detect whether fully_shard
|
| 28 |
+
# is combined with reentrant activation checkpointing and error out with a clear
|
| 29 |
+
# message.
|
| 30 |
+
class RegistryItem:
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
_TState = TypeVar("_TState", bound="_State", covariant=True)
|
| 35 |
+
_M = TypeVar("_M", nn.Module, list[nn.Module])
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class _ContractFn(Protocol, Generic[_P, _T, _TState]):
|
| 39 |
+
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ...
|
| 40 |
+
|
| 41 |
+
def state(self, module: nn.Module) -> _TState: ...
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def contract(
|
| 45 |
+
state_cls: type[_TState] = _State, # type: ignore[assignment]
|
| 46 |
+
) -> Callable[
|
| 47 |
+
[Callable[Concatenate[_M, _P], _M]],
|
| 48 |
+
_ContractFn[Concatenate[_M, _P], _M, _TState],
|
| 49 |
+
]:
|
| 50 |
+
r"""
|
| 51 |
+
Decorate a function as a composable distributed API, where the first
|
| 52 |
+
argument of the function must be an :class:`nn.Module` instance or sequence
|
| 53 |
+
of :class:`nn.Module` instances.
|
| 54 |
+
|
| 55 |
+
The decorator verifies that the decorated function does not modify
|
| 56 |
+
fully-qualified names (FQNs) for parameters, buffers, or modules. The
|
| 57 |
+
decorated function can return different module instances than the input
|
| 58 |
+
modules; the FQN invariant will be enforced following the input order.
|
| 59 |
+
|
| 60 |
+
When a function ``func`` is decorated by ``@contract()``, a
|
| 61 |
+
``.state(module: nn.Module)`` method will be installed to the decorated
|
| 62 |
+
function. Then you can retrieve and modify the state on a module by calling
|
| 63 |
+
``func.state(module)``.
|
| 64 |
+
|
| 65 |
+
Example::
|
| 66 |
+
>>> # xdoctest: +SKIP
|
| 67 |
+
>>> import torch.nn as nn
|
| 68 |
+
>>>
|
| 69 |
+
>>> class MyModel(nn.Module):
|
| 70 |
+
>>> def __init__(self) -> None:
|
| 71 |
+
>>> super().__init__()
|
| 72 |
+
>>> self.l1 = nn.Linear(10, 10)
|
| 73 |
+
>>> self.l2 = nn.Linear(10, 10)
|
| 74 |
+
>>>
|
| 75 |
+
>>> def forward(self, x):
|
| 76 |
+
>>> return self.l2(self.l1(x))
|
| 77 |
+
>>>
|
| 78 |
+
>>> @contract()
|
| 79 |
+
>>> def my_feature(module: nn.Module) -> nn.Module:
|
| 80 |
+
>>> my_feature.state(module).some_state = "any value"
|
| 81 |
+
>>> return module
|
| 82 |
+
>>>
|
| 83 |
+
>>> model = MyModel()
|
| 84 |
+
>>> my_feature(model.l1)
|
| 85 |
+
>>> assert my_feature.state(model.l1).some_state == "any value"
|
| 86 |
+
>>> my_feature(model.l2)
|
| 87 |
+
>>> model(torch.randn(2, 10)).sum().backward()
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
# wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
|
| 91 |
+
@wraps(state_cls) # type: ignore[arg-type]
|
| 92 |
+
def inner(
|
| 93 |
+
func: Callable[Concatenate[_M, _P], _M],
|
| 94 |
+
) -> _ContractFn[Concatenate[_M, _P], _M, _TState]:
|
| 95 |
+
@wraps(func)
|
| 96 |
+
def wrapper(
|
| 97 |
+
module: _M,
|
| 98 |
+
*args: _P.args,
|
| 99 |
+
**kwargs: _P.kwargs,
|
| 100 |
+
) -> _M:
|
| 101 |
+
inp_module = module
|
| 102 |
+
modules: list[nn.Module]
|
| 103 |
+
if isinstance(module, nn.Module):
|
| 104 |
+
modules = [module]
|
| 105 |
+
else:
|
| 106 |
+
# If the user passes a sequence of modules, then we assume that
|
| 107 |
+
# we only need to insert the state object on the root modules
|
| 108 |
+
# (i.e. those without a parent) among the passed-in modules.
|
| 109 |
+
modules = _get_root_modules(list(module))
|
| 110 |
+
state = state_cls() # shared across all modules
|
| 111 |
+
registry_item = RegistryItem() # shared across all modules
|
| 112 |
+
|
| 113 |
+
# `func` is allowed to return different module instances than the
|
| 114 |
+
# input modules as long as FQNs are preserved following the input
|
| 115 |
+
# module order
|
| 116 |
+
all_orig_named_params: list[dict[str, nn.Parameter]] = []
|
| 117 |
+
all_orig_named_buffers: list[dict[str, torch.Tensor]] = []
|
| 118 |
+
all_orig_named_modules: list[dict[str, nn.Module]] = []
|
| 119 |
+
|
| 120 |
+
for module in modules:
|
| 121 |
+
default_all_state: dict[Callable, _State] = OrderedDict()
|
| 122 |
+
default_registry: dict[str, RegistryItem] = OrderedDict()
|
| 123 |
+
all_state: dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload]
|
| 124 |
+
STATE_KEY, default_all_state
|
| 125 |
+
)
|
| 126 |
+
if not isinstance(all_state, dict):
|
| 127 |
+
raise AssertionError(
|
| 128 |
+
f"Distributed composable API states corrupted: {all_state}"
|
| 129 |
+
)
|
| 130 |
+
registry: dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload]
|
| 131 |
+
REGISTRY_KEY, default_registry
|
| 132 |
+
)
|
| 133 |
+
if not isinstance(registry, dict):
|
| 134 |
+
raise AssertionError(
|
| 135 |
+
f"Distributed composable API registry corrupted: {registry}"
|
| 136 |
+
)
|
| 137 |
+
if func in all_state or func.__name__ in registry:
|
| 138 |
+
raise AssertionError(
|
| 139 |
+
"Each distinct composable distributed API can only be applied to a "
|
| 140 |
+
f"module once. {func.__name__} has already been applied to the "
|
| 141 |
+
f"following module:\n{module}"
|
| 142 |
+
)
|
| 143 |
+
all_state.setdefault(func, state)
|
| 144 |
+
registry.setdefault(func.__name__, registry_item)
|
| 145 |
+
|
| 146 |
+
all_orig_named_params.append(OrderedDict(module.named_parameters()))
|
| 147 |
+
all_orig_named_buffers.append(OrderedDict(module.named_buffers()))
|
| 148 |
+
all_orig_named_modules.append(OrderedDict(module.named_modules()))
|
| 149 |
+
|
| 150 |
+
updated = func(inp_module, *args, **kwargs)
|
| 151 |
+
if updated is None:
|
| 152 |
+
updated = inp_module # type: ignore[assignment]
|
| 153 |
+
updated_modules: list[nn.Module]
|
| 154 |
+
if isinstance(updated, nn.Module):
|
| 155 |
+
updated_modules = [updated]
|
| 156 |
+
else:
|
| 157 |
+
updated_modules = _get_root_modules(list(inp_module)) # type: ignore[arg-type, call-overload]
|
| 158 |
+
|
| 159 |
+
all_new_named_params: list[dict[str, nn.Parameter]] = []
|
| 160 |
+
all_new_named_buffers: list[dict[str, torch.Tensor]] = []
|
| 161 |
+
all_new_named_modules: list[dict[str, nn.Module]] = []
|
| 162 |
+
for module in updated_modules:
|
| 163 |
+
all_new_named_params.append(OrderedDict(module.named_parameters()))
|
| 164 |
+
all_new_named_buffers.append(OrderedDict(module.named_buffers()))
|
| 165 |
+
all_new_named_modules.append(OrderedDict(module.named_modules()))
|
| 166 |
+
|
| 167 |
+
num_orig_modules = len(all_orig_named_modules)
|
| 168 |
+
num_new_modules = len(all_new_named_modules)
|
| 169 |
+
if num_orig_modules != num_new_modules:
|
| 170 |
+
raise AssertionError(
|
| 171 |
+
f"{func.__name__} should return the same number of modules as input modules"
|
| 172 |
+
f"Inputs: {num_orig_modules} modules\n"
|
| 173 |
+
f"Outputs: {num_new_modules} modules"
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
def check_fqn(orig_fqns: list[str], new_fqns: list[str], check_key: str):
|
| 177 |
+
if orig_fqns == new_fqns:
|
| 178 |
+
return
|
| 179 |
+
|
| 180 |
+
orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
|
| 181 |
+
orig_only = orig_fqn_set - new_fqn_set
|
| 182 |
+
new_only = new_fqn_set - orig_fqn_set
|
| 183 |
+
if len(orig_only) or len(new_only):
|
| 184 |
+
raise RuntimeError(
|
| 185 |
+
f"{check_key}"
|
| 186 |
+
"Composable distributed API implementations cannot modify FQNs.\n"
|
| 187 |
+
f"FQNs only in original: {orig_only}\n"
|
| 188 |
+
f"FQNs only in new: {new_only}"
|
| 189 |
+
)
|
| 190 |
+
else:
|
| 191 |
+
raise RuntimeError(
|
| 192 |
+
f"{check_key}"
|
| 193 |
+
"Composable distributed API implementations cannot modify "
|
| 194 |
+
"the order of FQNs.\n"
|
| 195 |
+
f"Original FQNs: {orig_only}\n"
|
| 196 |
+
f"New FQNs: {new_only}"
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
for orig_named_params, new_named_params in zip(
|
| 200 |
+
all_orig_named_params, all_new_named_params
|
| 201 |
+
):
|
| 202 |
+
check_fqn(
|
| 203 |
+
list(orig_named_params.keys()),
|
| 204 |
+
list(new_named_params.keys()),
|
| 205 |
+
"Checking parameters: ",
|
| 206 |
+
)
|
| 207 |
+
for orig_named_buffers, new_named_buffers in zip(
|
| 208 |
+
all_orig_named_buffers, all_new_named_buffers
|
| 209 |
+
):
|
| 210 |
+
check_fqn(
|
| 211 |
+
list(orig_named_buffers.keys()),
|
| 212 |
+
list(new_named_buffers.keys()),
|
| 213 |
+
"Checking buffers: ",
|
| 214 |
+
)
|
| 215 |
+
for orig_named_modules, new_named_modules in zip(
|
| 216 |
+
all_orig_named_modules, all_new_named_modules
|
| 217 |
+
):
|
| 218 |
+
check_fqn(
|
| 219 |
+
list(orig_named_modules.keys()),
|
| 220 |
+
list(new_named_modules.keys()),
|
| 221 |
+
"Checking modules: ",
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
# TODO: verify that installed distributed paradigms are compatible with
|
| 225 |
+
# each other.
|
| 226 |
+
|
| 227 |
+
return updated
|
| 228 |
+
|
| 229 |
+
def get_state(module: nn.Module) -> _State:
|
| 230 |
+
return module.__dict__.setdefault( # type: ignore[call-overload]
|
| 231 |
+
STATE_KEY,
|
| 232 |
+
{}, # TODO(@yhcharles): this is a temporary fix, need a better way
|
| 233 |
+
).get(func) # type: ignore[call-overload]
|
| 234 |
+
|
| 235 |
+
wrapper.state = get_state # type: ignore[attr-defined]
|
| 236 |
+
|
| 237 |
+
return wrapper # type: ignore[return-value]
|
| 238 |
+
|
| 239 |
+
return inner # type: ignore[return-value]
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def _get_registry(module: nn.Module) -> Optional[dict[str, RegistryItem]]:
|
| 243 |
+
r"""
|
| 244 |
+
Get an ``OrderedDict`` of composable APIs that have been applied to the
|
| 245 |
+
``module``, indexed by the API name. If no API has been applied, then this
|
| 246 |
+
returns ``None``.
|
| 247 |
+
"""
|
| 248 |
+
return getattr(module, REGISTRY_KEY, None)
|
phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
|
| 2 |
+
|
| 3 |
+
from .fully_shard import FSDPModule, fully_shard, register_fsdp_forward_method
|
phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (395 Bytes). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-39.pyc
ADDED
|
Binary file (323 Bytes). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/fully_shard.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TODO: For backward compatibility, we are importing the public objects
|
| 2 |
+
# originally from this file.
|
| 3 |
+
from torch.distributed.fsdp import ( # noqa: F401
|
| 4 |
+
FSDPModule,
|
| 5 |
+
fully_shard,
|
| 6 |
+
register_fsdp_forward_method,
|
| 7 |
+
UnshardHandle,
|
| 8 |
+
)
|
phivenv/Lib/site-packages/torch/distributed/_composable/replicate.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import weakref
|
| 3 |
+
from collections.abc import Iterable
|
| 4 |
+
from typing import Any, NoReturn, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.distributed._composable_state import _State
|
| 9 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 10 |
+
|
| 11 |
+
from .contract import _get_registry, contract
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_ROOT_MODULE_PREFIX = ""
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class _ReplicateState(_State):
    """Composable-API state backing ``replicate()``.

    Holds the lazily-constructed :class:`DistributedDataParallel` instance
    plus the arguments recorded at ``replicate()`` call time so that the
    actual DDP construction can be deferred until the first forward pass.
    """

    # Weak reference to the wrapped DistributedDataParallel instance;
    # assigned in ``init`` and currently only used for testing.
    _ddp_weakref: weakref.ref

    def __init__(self) -> None:
        super().__init__()
        # Placeholder until ``init`` installs the real module.
        self.module: nn.Module = nn.ParameterList()
        self.has_initialized: bool = False
        self._param_list: nn.ParameterList = nn.ParameterList()
        # TODO(@fegin): this variable is originally create for testing, we
        # should remove this if possible.
        self._orig_module = self.module
        self._param_names: list[str] = []
        # When True, gradient synchronization is skipped on the next backward.
        self._no_sync: bool = False
        # Arguments captured by ``record_init_args`` for deferred ``init``.
        self._init_args: Optional[tuple[Any, ...]] = None
        self._init_kwargs: dict[str, Any] = {}
        # Comm hooks registered before the DDP instance exists; replayed
        # by ``register_comm_hook`` after lazy initialization.
        self._comm_hook_args: list[Any] = []

    def _collect_params(
        self,
        module: nn.Module,
        ignored_modules: set[nn.Module],
        ignored_params: set[nn.Parameter],
        prefix: str = _ROOT_MODULE_PREFIX,
    ) -> None:
        """Recursively gather the parameters DDP should manage into
        ``self._param_list`` / ``self._param_names``."""
        # skip if managed by fully_sharded API
        if _is_fully_sharded(module):
            return

        # if a module is ignored, all descendants of the module are ignored.
        if module in ignored_modules:
            return

        recurse_prefix = (
            f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX
        )

        for n, p in module.named_parameters(recurse=False):
            if p not in ignored_params:
                self._param_list.append(p)
                self._param_names.append(f"{recurse_prefix}{n}")

        for name, child_module in module.named_children():
            self._collect_params(
                child_module,
                ignored_modules,
                ignored_params,
                prefix=f"{recurse_prefix}{name}",
            )

    def lazy_init(self) -> None:
        """Perform the deferred ``init`` using the recorded arguments.

        Wrapped in ``torch._disable_dynamo`` so the one-time setup is not
        traced/compiled.
        """

        @torch._disable_dynamo(recursive=True)
        def _lazy_init():
            assert self._init_args is not None
            self.init(*self._init_args, **self._init_kwargs)
            self.register_comm_hook()
            # Clear the recorded args so the pre-forward hook does not
            # re-trigger initialization.
            self._init_args = ()
            self._init_kwargs = {}

        _lazy_init()

    def init(
        self,
        module: nn.Module,
        ignored_modules: set[nn.Module],
        **kwargs,
    ) -> None:
        """Construct the underlying DDP instance over the collected params."""
        if self.has_initialized:
            return

        self.has_initialized = True
        self.module = module
        ignored_params = {p for m in ignored_modules for p in m.parameters()}
        # Parameters already managed by fully_shard are excluded from DDP.
        for submodule in module.modules():
            if _is_fully_sharded(submodule):
                ignored_params.update(submodule.parameters())
        from torch.distributed.tensor.parallel.ddp import _localize_dtensor

        _localize_dtensor(module, ignored_params=ignored_params)
        self._collect_params(module, ignored_modules, ignored_params)

        if "device_id" in kwargs:
            # replicate() supports a small usability enhancement where
            # user can pass in device_id as a Union[int, torch.device] even for
            # CPU devices so users don't have to change code for CPU/GPU runs.
            # We derive the right device_ids to feed into DDP to support this.
            if kwargs["device_id"] is not None:
                device_id = kwargs["device_id"]
                # Convert to device_ids that DDP expects.
                if isinstance(device_id, torch.device) and device_id.type == "cpu":
                    # CPU modules receive device_ids None
                    kwargs["device_ids"] = None
                else:
                    # GPU modules expect device_ids=[cuda_device]
                    kwargs["device_ids"] = [device_id]
            else:
                kwargs["device_ids"] = None
            kwargs.pop("device_id")

        self._ddp = DistributedDataParallel(self._param_list, **kwargs)
        # Weakref to the DDP instance is currently only used for testing.
        replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp)

    def register_comm_hook(self) -> None:
        """Replay comm hooks recorded before the DDP instance existed."""
        for comm_args, comm_kwargs in self._comm_hook_args:
            self._ddp.register_comm_hook(*comm_args, **comm_kwargs)
        self._comm_hook_args.clear()

    def record_init_args(self, *args, **kwargs) -> None:
        """Stash the ``init`` arguments for later ``lazy_init``."""
        self._init_args = args
        self._init_kwargs = kwargs

    def forward_pre_hook(
        self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> Any:
        """Module pre-forward hook: lazily initialize, then delegate to DDP."""
        if self._init_args or self._init_kwargs:
            self.lazy_init()
        self._ddp.require_backward_grad_sync = not self._no_sync
        return self._ddp._pre_forward(*args, **kwargs)

    def forward_post_hook(
        self,
        module: nn.Module,
        input: tuple[torch.Tensor],
        output: torch.Tensor,
    ) -> torch.Tensor:
        """Module post-forward hook: let DDP post-process the output."""
        return self._ddp._post_forward(output)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
    """Stand-in ``__deepcopy__`` installed on replicated modules.

    Always raises, since deep-copying a DDP-managed module is unsupported.
    """
    message = "DDP does not support deepcopy. Please use state dict for serialization."
    raise AssertionError(message)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# Follow the same pattern as FSDP/fully_shard
class DDP:
    """Mixin class dynamically placed in front of a replicated module's MRO
    by ``replicate()`` to expose DDP-specific control methods."""

    def __new__(cls, *args, **kwargs):
        """
        Override ``__new__`` to remove the DDP class and directly construct
        the original class for cases like indexing into a container module.
        """
        # Use index 2 since 0 is the dynamically constructed `DDP<...>` class
        # and index 1 is the `DDP` class itself
        orig_cls = cls.__mro__[2]
        return orig_cls.__new__(orig_cls, *args, **kwargs)

    def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None:
        """
        Sets if the module should sync gradients. This can be used to implement
        gradient accumulation without communication.

        Args:
            requires_gradient_sync (bool): Whether to reduce gradients for the
                module's parameters.
        """
        replicate.state(self)._no_sync = not requires_gradient_sync  # type: ignore[arg-type]

    def register_comm_hook(self, *args, **kwargs) -> None:
        # Record the hook arguments; they are applied to the real DDP
        # instance during lazy initialization.
        replicate.state(self)._comm_hook_args.append((args, kwargs))  # type: ignore[arg-type]
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@contract(state_cls=_ReplicateState)
def replicate(
    module: nn.Module,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    **kwargs,
) -> nn.Module:
    r"""Replicates a module

    Applies data-parallel replication (DDP semantics) to ``module`` via the
    composable API. The actual ``DistributedDataParallel`` construction is
    deferred to the first forward pass (see ``_ReplicateState.lazy_init``).

    Args:
        module (torch.nn.Module): module to replicate
        ignored_modules (Optional[Iterable[torch.nn.Module]]): modules whose
            parameters should not be managed by DDP.

    Example::
        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
    """
    torch._C._log_api_usage_once("torch.distributed.replicate")

    # TODO(fegin): using kwargs is not a good idea if we would like to make
    # replicate a formal API to replace DDP.
    if "device_id" in kwargs:
        if not isinstance(kwargs["device_id"], (int, torch.device)):
            raise RuntimeError(
                "Expected device_id to be int or torch.device, "
                f"but got {type(kwargs['device_id'])}"
            )

    if _is_fully_sharded(module):
        raise RuntimeError(
            "Cannot apply `replicate()` on a Module already managed by `fully_shard`"
        )

    if ignored_modules is None:
        # Fix: the original used `{}`, which builds an empty *dict*, not the
        # set[nn.Module] that `_collect_params`/`init` are annotated to take.
        ignored_modules = set()
    else:
        ignored_modules = set(ignored_modules)

    state = replicate.state(module)
    module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)
    device_mesh = kwargs.get("device_mesh", None)
    if device_mesh is not None:
        from torch.distributed.device_mesh import _mesh_resources

        root_mesh = _mesh_resources.get_root_mesh(device_mesh)
        # if a root mesh is not the same as device_mesh,
        # meaning the device_mesh is sliced out from the root mesh.
        if root_mesh != device_mesh:
            # TODO: This is a temporary work around to enable DDP + TP.
            # We should do the logic in DDP so that the 2D implementation is
            # sound and the state_dict works out of the box.
            #
            # This won't conflict with what is done in DDP class as the module
            # replicate is going to pass is NOT the original module.
            from torch.distributed.tensor.parallel.ddp import (
                _localize_dtensor,
                _reconstruct_dtensor,
            )

            module.register_forward_pre_hook(_reconstruct_dtensor)
            module.register_forward_hook(_localize_dtensor)

    module.register_forward_hook(state.forward_post_hook)  # type: ignore[arg-type]

    state.record_init_args(module, ignored_modules, **kwargs)

    # Place DDP leftmost for highest priority in the method resolution order
    cls = module.__class__
    dct = {"__deepcopy__": unimplemented_deepcopy}
    new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct)
    module.__class__ = new_cls
    return module
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _is_fully_sharded(module: nn.Module) -> bool:
    r"""Check if module is marked with fully_shard."""
    registry = _get_registry(module)
    # A missing registry means no composable API was ever applied.
    return registry is not None and "fully_shard" in registry
|
phivenv/Lib/site-packages/torch/distributed/_shard/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .api import _shard_tensor, load_with_process_group, shard_module, shard_parameter
|
phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (292 Bytes). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-39.pyc
ADDED
|
Binary file (1.04 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/api.cpython-39.pyc
ADDED
|
Binary file (9.82 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-39.pyc
ADDED
|
Binary file (2.25 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-39.pyc
ADDED
|
Binary file (2.28 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-39.pyc
ADDED
|
Binary file (1.18 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-39.pyc
ADDED
|
Binary file (1.32 kB). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/_shard/_utils.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections.abc import Sequence
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch.distributed._shard.metadata import ShardMetadata
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
DEPRECATE_MSG = "Please use DTensor instead and we are deprecating ShardedTensor."
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def narrow_tensor_by_index(
    tensor: torch.Tensor,
    offsets: Sequence[int],
    sizes: Sequence[int],
) -> torch.Tensor:
    """
    Narrow the tensor according to ``offsets`` and ``sizes``.

    Dimensions whose requested size already spans the full tensor are left
    untouched, so a no-op request returns the original tensor object.
    """
    shard = tensor
    for dim, (start, length) in enumerate(zip(offsets, sizes)):
        # A full-size dimension needs no narrowing; skipping the narrow op
        # also avoids autograd recording so the result can stay a leaf
        # variable in the autograd graph.
        if length >= tensor.size(dim):
            continue
        shard = shard.narrow(dim, start, length)
    return shard
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def narrow_tensor(tensor: torch.Tensor, metadata: ShardMetadata) -> torch.Tensor:
    """
    Narrow the tensor according to the metadata

    Thin convenience wrapper that forwards the shard's offsets and sizes
    to :func:`narrow_tensor_by_index`.
    """
    offsets, sizes = metadata.shard_offsets, metadata.shard_sizes
    return narrow_tensor_by_index(tensor, offsets, sizes)
|
phivenv/Lib/site-packages/torch/distributed/_shard/api.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from contextlib import contextmanager
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.distributed import distributed_c10d
|
| 9 |
+
from torch.distributed._shard.sharded_tensor import ShardedTensor
|
| 10 |
+
|
| 11 |
+
from .sharder import Sharder
|
| 12 |
+
from .sharding_plan import ShardingPlan
|
| 13 |
+
from .sharding_spec import ChunkShardingSpec, ShardingSpec
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _shard_tensor(
    tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
) -> ShardedTensor:
    """
    Given a :class:`torch.Tensor`, it shards that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    Args:
        tensor (:class:`torch.Tensor`): Tensor needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    if not tensor.is_contiguous():
        raise ValueError("input tensor is not a contiguous Tensor")

    pg = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )
    world_size = dist.get_world_size(pg)
    current_rank = dist.get_rank(pg)

    # Every rank must pass the same (src_rank, sharding_spec) pair; gather
    # everyone's values and compare against the local ones.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        peer_src_rank, peer_spec = entry  # type: ignore[misc]
        if src_rank != peer_src_rank:
            raise ValueError(
                f"src_rank={src_rank} on rank: {current_rank} does not "
                f"match with src_rank={peer_src_rank} on rank: {idx}"
            )
        if sharding_spec != peer_spec:
            raise ValueError(
                f"sharding_spec={sharding_spec} on rank: {current_rank} does not "
                f"match with sharding_spec={peer_spec} on rank: {idx}"
            )

    # Delegate the actual scatter/shard work to the spec implementation.
    return sharding_spec.shard(tensor, src_rank=src_rank, process_group=pg)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def shard_parameter(
    module: torch.nn.Module,
    param_name: str,
    sharding_spec: ShardingSpec,
    src_rank=0,
    process_group=None,
):
    """
    Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that
    module, it shards that parameter according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._sharded_tensor.ShardedTensor`

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    # Validate that the attribute exists and is a contiguous tensor before
    # attempting any collective work.
    if not hasattr(module, param_name):
        raise AttributeError(f"{module._get_name()} has no attribute `{param_name}`")

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(
            f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}"
        )
    if not tensor.is_contiguous():
        raise ValueError(f"param: {param_name} is not a contiguous Tensor")

    sharded = _shard_tensor(tensor, sharding_spec, src_rank, process_group)

    # Replace param with ShardedTensor.
    module.register_parameter(param_name, nn.Parameter(sharded))
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# Tracks the current process group in the load context manager.
_CURRENT_PROCESS_GROUP: Optional[dist.ProcessGroup] = None


@contextmanager
def load_with_process_group(process_group):
    """
    Context manager to set the process group with which to load a ShardedTensor.

    Nesting is disallowed: entering while another ``load_with_process_group``
    context is active raises ``RuntimeError``.
    """
    global _CURRENT_PROCESS_GROUP
    if _CURRENT_PROCESS_GROUP is None:
        _CURRENT_PROCESS_GROUP = process_group
        try:
            yield process_group
        finally:
            # Always clear on exit so a later context can be entered.
            _CURRENT_PROCESS_GROUP = None
    else:
        raise RuntimeError(
            'ProcessGroup already set by previous "load_with_process_group" '
            "context manager"
        )
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _get_current_process_group():
    """
    Retrieves the current process group set by ``load_with_process_group``.
    If not set, it just returns the default group.
    """
    global _CURRENT_PROCESS_GROUP
    return (
        distributed_c10d._get_default_group()
        if _CURRENT_PROCESS_GROUP is None
        else _CURRENT_PROCESS_GROUP
    )
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _reshard_output(
|
| 166 |
+
module: torch.nn.Module, resharding_spec: ShardingSpec
|
| 167 |
+
) -> torch.nn.Module:
|
| 168 |
+
"""
|
| 169 |
+
Hook a module with output resharding in the forward pass according
|
| 170 |
+
to the given ``resharding_spec``.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
module (:class:`torch.nn.Module`): Module whose output needs to be resharded.
|
| 174 |
+
resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
|
| 175 |
+
The specification describing how the output of the module will be resharded.
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
A :class:`torch.nn.Module` object with reshard API hooked.
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
def hook_func(_module, _input, output):
|
| 182 |
+
if isinstance(output, ShardedTensor):
|
| 183 |
+
return output.reshard(resharding_spec)
|
| 184 |
+
return output
|
| 185 |
+
|
| 186 |
+
module.register_forward_hook(hook_func)
|
| 187 |
+
return module
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:
|
| 191 |
+
"""
|
| 192 |
+
Hook a module with local shards collection in the forward pass.
|
| 193 |
+
|
| 194 |
+
This API is typically used to convert a sharded representation back to data parallel
|
| 195 |
+
representation. In particular, it returns the local tensor for this Shard. If the
|
| 196 |
+
size along the sharding dimension for the local tensor is 1, this dimension is removed
|
| 197 |
+
from the final result. For example a [4, 16] ShardedTensor across 4 ranks is typically
|
| 198 |
+
a local Tensor of size [16] across each rank and not [1, 16] across each rank.
|
| 199 |
+
|
| 200 |
+
Args:
|
| 201 |
+
module (:class:`torch.nn.Module`): Module whose output is ShardedTensor and the
|
| 202 |
+
local tensor value needs to be returned.
|
| 203 |
+
|
| 204 |
+
Returns:
|
| 205 |
+
A :class:`torch.nn.Module` object with collection API hooked.
|
| 206 |
+
"""
|
| 207 |
+
|
| 208 |
+
def hook_func(_module, _input, output):
|
| 209 |
+
if isinstance(output, ShardedTensor):
|
| 210 |
+
local_tensor = output.local_tensor()
|
| 211 |
+
# Squeeze the # of dimensions manually, only applicable to ChunkShardingSpec
|
| 212 |
+
sharding_spec = output._sharding_spec
|
| 213 |
+
if (
|
| 214 |
+
isinstance(sharding_spec, ChunkShardingSpec)
|
| 215 |
+
and local_tensor.size(sharding_spec.dim) == 1 # type: ignore[attr-defined, arg-type]
|
| 216 |
+
):
|
| 217 |
+
local_tensor = local_tensor.squeeze(
|
| 218 |
+
output._sharding_spec.dim # type: ignore[attr-defined]
|
| 219 |
+
)
|
| 220 |
+
return local_tensor
|
| 221 |
+
|
| 222 |
+
module.register_forward_hook(hook_func)
|
| 223 |
+
return module
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_group=None):
    """
    Shards a given module according to the provided sharding `plan`. This method
    first shards all the parameters according to the given sharding `plan`. Then if
    `output_plan` and `return_local_tensor` are specified in the sharding `plan`, it
    will tag the output of modules according `output_plan`, convert the module's
    output back to data parallel according to `return_local_tensor`.

    Needs to be called on all ranks in an SPMD fashion.

    Args:
        module (:class:`torch.nn.Module`): The module to apply sharding to
        plan (:class:`torch.distributed._shard.sharding_plan.ShardingPlan`):
            The ShardingPlan which specified param name to ShardingSpec to apply to
            each parameter.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the module that would be sharded and scattered across the rest
            of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
    """
    # record Sharder paths for sanity check on the plan to ensure items in the plan
    # does not conflict with the submodule tree that the Sharder is working with
    sharder_paths = []
    for name, spec in plan.plan.items():
        if isinstance(spec, Sharder):
            sharder_paths.append(name)

    # shard the parameter according to the ShardingPlan
    for name, spec in plan.plan.items():
        if isinstance(spec, ShardingSpec):
            # if found a sharding spec, try to shard the parameter
            module_path, _, param_name = name.rpartition(".")

            for sharder_path in sharder_paths:
                if module_path.startswith(sharder_path):
                    raise RuntimeError(
                        f"ShardingPlan is in-valid, trying to shard a parameter: {name},"
                        f" but there's already a Sharder entry for module {sharder_path},"
                        f" parameter sharding should not conflict with the submodule tree"
                        f" that a Sharder is working with!"
                    )

            mod = module.get_submodule(module_path)
            shard_parameter(
                mod, param_name, spec, src_rank=src_rank, process_group=process_group
            )
        elif isinstance(spec, Sharder):
            parent_mod_path, _, mod_name = name.rpartition(".")
            if name == "":
                raise KeyError("Module path must not be empty for custom sharder!")
            mod = module.get_submodule(name)
            parent_mod = module.get_submodule(parent_mod_path)
            sharded_mod = spec.shard(mod)
            # swap this submodule with the sharded module
            # Fix: the original wrote `parent_mod.mod_name = sharded_mod`,
            # which sets an attribute literally named "mod_name" instead of
            # replacing the submodule; use setattr with the actual name.
            setattr(parent_mod, mod_name, sharded_mod)
        else:
            raise TypeError(
                f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'"
            )

    # reshard output if there's an entry in `reshard_output` for this module
    if plan.output_plan is not None:
        for module_path, output_spec in plan.output_plan.items():
            if isinstance(output_spec, ShardingSpec):
                mod = module.get_submodule(module_path)
                _reshard_output(mod, output_spec)
            else:
                raise TypeError(
                    f"Only `ShardingSpec` is supported as output_plan for '{module_path}'"
                )
    # convert the output back to data parallel for the modules appears in
    # `return_local_tensor` of the plan, we will call `_collect_local_shard`
    # to collect the local tensor for output of modules
    if plan.return_local_tensor is not None:
        for module_path in plan.return_local_tensor:
            mod = module.get_submodule(module_path)
            _collect_local_shard(mod)
|
phivenv/Lib/site-packages/torch/distributed/_shard/checkpoint/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Keep old package for BC purposes, this file should be removed once
# everything moves to the `torch.distributed.checkpoint` package.
import sys
import warnings

import torch
from torch.distributed.checkpoint import *  # noqa: F403


# Emit the deprecation warning unconditionally ("always") so every import
# of the old package surfaces it, even if DeprecationWarning is filtered.
with warnings.catch_warnings():
    warnings.simplefilter("always")
    warnings.warn(
        "`torch.distributed._shard.checkpoint` will be deprecated, "
        "use `torch.distributed.checkpoint` instead",
        DeprecationWarning,
        stacklevel=2,
    )

# Alias the old module path to the new package so submodule imports and
# attribute access on the old name keep working.
sys.modules["torch.distributed._shard.checkpoint"] = torch.distributed.checkpoint
|
phivenv/Lib/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (628 Bytes). View file
|
|
|
phivenv/Lib/site-packages/torch/distributed/_shard/common_op_utils.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils import _pytree as pytree
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _basic_validation(op, args=(), kwargs=None):
|
| 9 |
+
"""
|
| 10 |
+
Common validation across all ops go in here.
|
| 11 |
+
"""
|
| 12 |
+
from torch.distributed._shard.sharded_tensor import ShardedTensor
|
| 13 |
+
|
| 14 |
+
if len(args) == 0 and (kwargs is None or len(kwargs) == 0):
|
| 15 |
+
raise ValueError(f" No input for '{op.__name__}'!")
|
| 16 |
+
|
| 17 |
+
# Validate types
|
| 18 |
+
has_distributed_tensor = False
|
| 19 |
+
|
| 20 |
+
def is_distributed_tensor(e):
|
| 21 |
+
nonlocal has_distributed_tensor
|
| 22 |
+
if isinstance(e, ShardedTensor):
|
| 23 |
+
has_distributed_tensor = True
|
| 24 |
+
|
| 25 |
+
pytree.tree_map_(is_distributed_tensor, args)
|
| 26 |
+
pytree.tree_map_(is_distributed_tensor, kwargs)
|
| 27 |
+
|
| 28 |
+
if not has_distributed_tensor:
|
| 29 |
+
raise TypeError(
|
| 30 |
+
f"torch function '{op.__name__}', with args: {args} and "
|
| 31 |
+
f"kwargs: {kwargs} are called without any distributed tensor!"
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# Validate all distributed tensors use the same PG.
|
| 35 |
+
cur_pg: Optional[torch.distributed.ProcessGroup] = None
|
| 36 |
+
|
| 37 |
+
def validate_pg(e):
|
| 38 |
+
nonlocal cur_pg
|
| 39 |
+
if isinstance(e, ShardedTensor):
|
| 40 |
+
if cur_pg is not None and e._process_group is not cur_pg:
|
| 41 |
+
raise RuntimeError(
|
| 42 |
+
"All distributed tensors should use the "
|
| 43 |
+
"same ProcessGroup if used together in an op."
|
| 44 |
+
)
|
| 45 |
+
cur_pg = e._process_group
|
| 46 |
+
|
| 47 |
+
pytree.tree_map_(validate_pg, args)
|
| 48 |
+
pytree.tree_map_(validate_pg, kwargs)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _register_default_op(op, decorator):
|
| 52 |
+
@decorator(op)
|
| 53 |
+
def tensor_default_op(types, args=(), kwargs=None, pg=None):
|
| 54 |
+
"""
|
| 55 |
+
Handles ``__torch_function__`` dispatch for the default tensor ops that
|
| 56 |
+
behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or
|
| 57 |
+
``torch.Tensor.dtype``. We simply lower to the real op call with
|
| 58 |
+
DisableTorchFunctionSubclass context like ``torch.Tensor.__torch_function__``
|
| 59 |
+
to avoid recursions.
|
| 60 |
+
"""
|
| 61 |
+
if kwargs is None:
|
| 62 |
+
kwargs = {}
|
| 63 |
+
|
| 64 |
+
with torch._C.DisableTorchFunctionSubclass():
|
| 65 |
+
return op(*args, **kwargs)
|