cranky-coder08 commited on
Commit
f4cade0
·
verified ·
1 Parent(s): ad5f26a

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. phivenv/Lib/site-packages/numpy.libs/libscipy_openblas64_-caad452230ae4ddb57899b8b3a33c55c.dll +3 -0
  3. phivenv/Lib/site-packages/pip/_vendor/distlib/t64-arm.exe +3 -0
  4. phivenv/Lib/site-packages/pip/_vendor/distlib/t64.exe +3 -0
  5. phivenv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe +3 -0
  6. phivenv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-39.pyc +3 -0
  7. phivenv/Lib/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-39.pyc +3 -0
  8. phivenv/Lib/site-packages/torch/distributed/__pycache__/__init__.cpython-39.pyc +0 -0
  9. phivenv/Lib/site-packages/torch/distributed/__pycache__/_checkpointable.cpython-39.pyc +0 -0
  10. phivenv/Lib/site-packages/torch/distributed/__pycache__/_composable_state.cpython-39.pyc +0 -0
  11. phivenv/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-39.pyc +0 -0
  12. phivenv/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-39.pyc +0 -0
  13. phivenv/Lib/site-packages/torch/distributed/__pycache__/_serialization.cpython-39.pyc +0 -0
  14. phivenv/Lib/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-39.pyc +0 -0
  15. phivenv/Lib/site-packages/torch/distributed/__pycache__/argparse_util.cpython-39.pyc +0 -0
  16. phivenv/Lib/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-39.pyc +0 -0
  17. phivenv/Lib/site-packages/torch/distributed/__pycache__/collective_utils.cpython-39.pyc +0 -0
  18. phivenv/Lib/site-packages/torch/distributed/__pycache__/constants.cpython-39.pyc +0 -0
  19. phivenv/Lib/site-packages/torch/distributed/__pycache__/device_mesh.cpython-39.pyc +0 -0
  20. phivenv/Lib/site-packages/torch/distributed/__pycache__/launch.cpython-39.pyc +0 -0
  21. phivenv/Lib/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-39.pyc +0 -0
  22. phivenv/Lib/site-packages/torch/distributed/__pycache__/remote_device.cpython-39.pyc +0 -0
  23. phivenv/Lib/site-packages/torch/distributed/__pycache__/rendezvous.cpython-39.pyc +0 -0
  24. phivenv/Lib/site-packages/torch/distributed/__pycache__/run.cpython-39.pyc +0 -0
  25. phivenv/Lib/site-packages/torch/distributed/__pycache__/utils.cpython-39.pyc +0 -0
  26. phivenv/Lib/site-packages/torch/distributed/_composable/__init__.py +3 -0
  27. phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-39.pyc +0 -0
  28. phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-39.pyc +0 -0
  29. phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-39.pyc +0 -0
  30. phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-39.pyc +0 -0
  31. phivenv/Lib/site-packages/torch/distributed/_composable/checkpoint_activation.py +132 -0
  32. phivenv/Lib/site-packages/torch/distributed/_composable/contract.py +248 -0
  33. phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__init__.py +3 -0
  34. phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-39.pyc +0 -0
  35. phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-39.pyc +0 -0
  36. phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/fully_shard.py +8 -0
  37. phivenv/Lib/site-packages/torch/distributed/_composable/replicate.py +256 -0
  38. phivenv/Lib/site-packages/torch/distributed/_shard/__init__.py +1 -0
  39. phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-39.pyc +0 -0
  40. phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-39.pyc +0 -0
  41. phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/api.cpython-39.pyc +0 -0
  42. phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-39.pyc +0 -0
  43. phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-39.pyc +0 -0
  44. phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-39.pyc +0 -0
  45. phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-39.pyc +0 -0
  46. phivenv/Lib/site-packages/torch/distributed/_shard/_utils.py +32 -0
  47. phivenv/Lib/site-packages/torch/distributed/_shard/api.py +306 -0
  48. phivenv/Lib/site-packages/torch/distributed/_shard/checkpoint/__init__.py +19 -0
  49. phivenv/Lib/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-39.pyc +0 -0
  50. phivenv/Lib/site-packages/torch/distributed/_shard/common_op_utils.py +65 -0
.gitattributes CHANGED
@@ -48,3 +48,9 @@ phivenv/Lib/site-packages/numpy/_core/__pycache__/fromnumeric.cpython-39.pyc fil
48
  phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_umath.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
49
  phivenv/Lib/site-packages/numpy/_core/__pycache__/_add_newdocs.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
50
  phivenv/Lib/site-packages/numpy.libs/msvcp140-23ebcc0b37c8e3d074511f362feac48b.dll filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
48
  phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_umath.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
49
  phivenv/Lib/site-packages/numpy/_core/__pycache__/_add_newdocs.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
50
  phivenv/Lib/site-packages/numpy.libs/msvcp140-23ebcc0b37c8e3d074511f362feac48b.dll filter=lfs diff=lfs merge=lfs -text
51
+ phivenv/Lib/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
52
+ phivenv/Lib/site-packages/numpy.libs/libscipy_openblas64_-caad452230ae4ddb57899b8b3a33c55c.dll filter=lfs diff=lfs merge=lfs -text
53
+ phivenv/Lib/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
54
+ phivenv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe filter=lfs diff=lfs merge=lfs -text
55
+ phivenv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
56
+ phivenv/Lib/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
phivenv/Lib/site-packages/numpy.libs/libscipy_openblas64_-caad452230ae4ddb57899b8b3a33c55c.dll ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44629a7d27806ea076daeae8e829b0cfbdec9e25099561a19af8e5910bd635c5
3
+ size 32816640
phivenv/Lib/site-packages/pip/_vendor/distlib/t64-arm.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1618387a688f162408e7811350a72269076d52bf6d0f09860548d5b57d677ac
3
+ size 180736
phivenv/Lib/site-packages/pip/_vendor/distlib/t64.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a00a877acefcad45953343ad56a22152f7aaba5fcf2a10215d84169d47fbcd1d
3
+ size 105984
phivenv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:43f1ddcd5bbdcf161d6816b79b4889e7f75d2ce12ab4f7bcc77d16003a17cdaf
3
+ size 166400
phivenv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-39.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8acb4dd7cd594effc85e8c2b9ac052d6f4fe88744cd4749a8e8b8b93ba88246
3
+ size 151716
phivenv/Lib/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-39.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a6e2c125af98ae3013115aad3c6156dd30340dd0c77863105db036c061ddc8e
3
+ size 176641
phivenv/Lib/site-packages/torch/distributed/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (4.49 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/_checkpointable.cpython-39.pyc ADDED
Binary file (1.78 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/_composable_state.cpython-39.pyc ADDED
Binary file (1.49 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-39.pyc ADDED
Binary file (32.8 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-39.pyc ADDED
Binary file (2.75 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/_serialization.cpython-39.pyc ADDED
Binary file (4.76 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-39.pyc ADDED
Binary file (21.2 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/argparse_util.cpython-39.pyc ADDED
Binary file (3.95 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-39.pyc ADDED
Binary file (3.12 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/collective_utils.cpython-39.pyc ADDED
Binary file (5.46 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/constants.cpython-39.pyc ADDED
Binary file (532 Bytes). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/device_mesh.cpython-39.pyc ADDED
Binary file (31.7 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/launch.cpython-39.pyc ADDED
Binary file (7.81 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-39.pyc ADDED
Binary file (341 Bytes). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/remote_device.cpython-39.pyc ADDED
Binary file (3.8 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/rendezvous.cpython-39.pyc ADDED
Binary file (8.85 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/run.cpython-39.pyc ADDED
Binary file (27.2 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/__pycache__/utils.cpython-39.pyc ADDED
Binary file (12.1 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/_composable/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .checkpoint_activation import checkpoint
2
+ from .contract import _get_registry, contract
3
+ from .replicate import replicate
phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (316 Bytes). View file
 
phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-39.pyc ADDED
Binary file (4.54 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-39.pyc ADDED
Binary file (6.96 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-39.pyc ADDED
Binary file (7.13 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/_composable/checkpoint_activation.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from collections.abc import Generator
3
+ from contextlib import AbstractContextManager, contextmanager, nullcontext
4
+ from typing import Any, Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.checkpoint import (
9
+ _checkpoint_without_reentrant_generator,
10
+ _DEFAULT_DETERMINISM_MODE,
11
+ )
12
+
13
+ from .contract import _State, contract
14
+
15
+
16
+ @contextmanager
17
+ def _no_hook(module: nn.Module, user_ctx: Optional[AbstractContextManager] = None):
18
+ r"""
19
+ Disable hooks installed by checkpoint to avoid unintentional recursion
20
+ during backward recomputation.
21
+ """
22
+
23
+ with user_ctx if user_ctx else nullcontext():
24
+ orig_enable_hook = checkpoint.state(module).enable_hook
25
+ checkpoint.state(module).enable_hook = False
26
+ try:
27
+ yield
28
+ finally:
29
+ checkpoint.state(module).enable_hook = orig_enable_hook
30
+
31
+
32
+ class _CheckpointState(_State):
33
+ enable_hook: bool = False
34
+ _ac_generator: Optional[Generator[None, None, None]]
35
+
36
+
37
+ @contract(_CheckpointState)
38
+ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
39
+ r"""
40
+ This is a composable activation checkpointing API. Unlike functional
41
+ activation checkpointing APIs, this one does not require changing model
42
+ source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
43
+ this one does not modify model structure or fully-qualified names either.
44
+ Under the hood, it registers activation checkpointing logic as pre- and
45
+ post-forward hooks. Hence, this API can be easily applied to any model or
46
+ sub-modules in the model.
47
+
48
+ Args:
49
+ module (nn.Module): the target model or sub-module to apply activation
50
+ checkpointing.
51
+
52
+ Example::
53
+ >>> # xdoctest: +SKIP
54
+ >>> import torch.nn as nn
55
+ >>>
56
+ >>> class MyModel(nn.Module):
57
+ >>> def __init__(self) -> None:
58
+ >>> super().__init__()
59
+ >>> self.l1 = nn.Linear(10, 10)
60
+ >>> self.l2 = nn.Linear(10, 10)
61
+ >>>
62
+ >>> def forward(self, x):
63
+ >>> return self.l2(self.l1(x))
64
+ >>>
65
+ >>> model = MyModel()
66
+ >>> checkpoint(model.l1) # apply activation checkpointing only to l1
67
+ >>> model(torch.zeros(2, 10)).sum().backward()
68
+
69
+ """
70
+ torch._C._log_api_usage_once("torch.distributed.checkpoint")
71
+
72
+ use_reentrant = kwargs.pop("use_reentrant", False)
73
+ if use_reentrant:
74
+ raise NotImplementedError(
75
+ "use_reentrant=True is not supported in composable checkpoint. "
76
+ "Please use torch.utils.checkpoint.checkpoint instead."
77
+ )
78
+ preserve_rng_state = kwargs.pop("preserve_rng_state", True)
79
+ user_context_fns = kwargs.pop("context_fn", None)
80
+ determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE)
81
+ debug = kwargs.pop("debug", False)
82
+
83
+ if kwargs:
84
+ raise ValueError(
85
+ "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
86
+ )
87
+
88
+ def forward_pre_hook(
89
+ module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
90
+ ) -> None:
91
+ if checkpoint.state(module).enable_hook:
92
+
93
+ def context_fns():
94
+ if user_context_fns is not None:
95
+ ctx1, ctx2 = user_context_fns()
96
+ return ctx1, _no_hook(module, ctx2)
97
+ else:
98
+ return nullcontext(), _no_hook(module)
99
+
100
+ gen = _checkpoint_without_reentrant_generator(
101
+ module,
102
+ preserve_rng_state,
103
+ context_fns,
104
+ determinism_check,
105
+ debug,
106
+ *args,
107
+ **kwargs,
108
+ )
109
+ checkpoint.state(module)._ac_generator = gen
110
+ next(gen)
111
+
112
+ def forward_hook(module: nn.Module, inputs: tuple[Any, ...], output: Any) -> Any:
113
+ if checkpoint.state(module).enable_hook:
114
+ try:
115
+ gen = checkpoint.state(module)._ac_generator
116
+ assert gen is not None
117
+ next(gen)
118
+ except StopIteration:
119
+ pass
120
+ else:
121
+ raise RuntimeError(
122
+ "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!"
123
+ )
124
+
125
+ # Ensure that we no longer hold on to the generator. always_call=True helps ensure we
126
+ # clear this even in the case of exception in fwd pass.
127
+ checkpoint.state(module)._ac_generator = None
128
+
129
+ checkpoint.state(module).enable_hook = True
130
+ module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
131
+ module.register_forward_hook(forward_hook, prepend=True, always_call=True)
132
+ return module
phivenv/Lib/site-packages/torch/distributed/_composable/contract.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import uuid
3
+ from collections import OrderedDict
4
+ from functools import wraps
5
+ from typing import Callable, Generic, Optional, Protocol
6
+ from typing_extensions import Concatenate, ParamSpec, TypeVar
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.distributed._composable_state import _State
11
+ from torch.distributed.utils import _get_root_modules
12
+
13
+
14
+ _T = TypeVar("_T", covariant=True)
15
+ _P = ParamSpec("_P")
16
+
17
+
18
+ def generate_state_key(string="__composable_api_state_key"):
19
+ return f"{string}_{str(uuid.uuid4())}"
20
+
21
+
22
+ STATE_KEY = generate_state_key()
23
+ REGISTRY_KEY = generate_state_key()
24
+
25
+
26
+ # TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
27
+ # we can add args and kwargs here, and then we can detect whether fully_shard
28
+ # is combined with reentrant activation checkpointing and error out with a clear
29
+ # message.
30
+ class RegistryItem:
31
+ pass
32
+
33
+
34
+ _TState = TypeVar("_TState", bound="_State", covariant=True)
35
+ _M = TypeVar("_M", nn.Module, list[nn.Module])
36
+
37
+
38
+ class _ContractFn(Protocol, Generic[_P, _T, _TState]):
39
+ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ...
40
+
41
+ def state(self, module: nn.Module) -> _TState: ...
42
+
43
+
44
+ def contract(
45
+ state_cls: type[_TState] = _State, # type: ignore[assignment]
46
+ ) -> Callable[
47
+ [Callable[Concatenate[_M, _P], _M]],
48
+ _ContractFn[Concatenate[_M, _P], _M, _TState],
49
+ ]:
50
+ r"""
51
+ Decorate a function as a composable distributed API, where the first
52
+ argument of the function must be an :class:`nn.Module` instance or sequence
53
+ of :class:`nn.Module` instances.
54
+
55
+ The decorator verifies that the decorated function does not modify
56
+ fully-qualified names (FQNs) for parameters, buffers, or modules. The
57
+ decorated function can return different module instances than the input
58
+ modules; the FQN invariant will be enforced following the input order.
59
+
60
+ When a function ``func`` is decorated by ``@contract()``, a
61
+ ``.state(module: nn.Module)`` method will be installed to the decorated
62
+ function. Then you can retrieve and modify the state on a module by calling
63
+ ``func.state(module)``.
64
+
65
+ Example::
66
+ >>> # xdoctest: +SKIP
67
+ >>> import torch.nn as nn
68
+ >>>
69
+ >>> class MyModel(nn.Module):
70
+ >>> def __init__(self) -> None:
71
+ >>> super().__init__()
72
+ >>> self.l1 = nn.Linear(10, 10)
73
+ >>> self.l2 = nn.Linear(10, 10)
74
+ >>>
75
+ >>> def forward(self, x):
76
+ >>> return self.l2(self.l1(x))
77
+ >>>
78
+ >>> @contract()
79
+ >>> def my_feature(module: nn.Module) -> nn.Module:
80
+ >>> my_feature.state(module).some_state = "any value"
81
+ >>> return module
82
+ >>>
83
+ >>> model = MyModel()
84
+ >>> my_feature(model.l1)
85
+ >>> assert my_feature.state(model.l1).some_state == "any value"
86
+ >>> my_feature(model.l2)
87
+ >>> model(torch.randn(2, 10)).sum().backward()
88
+ """
89
+
90
+ # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
91
+ @wraps(state_cls) # type: ignore[arg-type]
92
+ def inner(
93
+ func: Callable[Concatenate[_M, _P], _M],
94
+ ) -> _ContractFn[Concatenate[_M, _P], _M, _TState]:
95
+ @wraps(func)
96
+ def wrapper(
97
+ module: _M,
98
+ *args: _P.args,
99
+ **kwargs: _P.kwargs,
100
+ ) -> _M:
101
+ inp_module = module
102
+ modules: list[nn.Module]
103
+ if isinstance(module, nn.Module):
104
+ modules = [module]
105
+ else:
106
+ # If the user passes a sequence of modules, then we assume that
107
+ # we only need to insert the state object on the root modules
108
+ # (i.e. those without a parent) among the passed-in modules.
109
+ modules = _get_root_modules(list(module))
110
+ state = state_cls() # shared across all modules
111
+ registry_item = RegistryItem() # shared across all modules
112
+
113
+ # `func` is allowed to return different module instances than the
114
+ # input modules as long as FQNs are preserved following the input
115
+ # module order
116
+ all_orig_named_params: list[dict[str, nn.Parameter]] = []
117
+ all_orig_named_buffers: list[dict[str, torch.Tensor]] = []
118
+ all_orig_named_modules: list[dict[str, nn.Module]] = []
119
+
120
+ for module in modules:
121
+ default_all_state: dict[Callable, _State] = OrderedDict()
122
+ default_registry: dict[str, RegistryItem] = OrderedDict()
123
+ all_state: dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload]
124
+ STATE_KEY, default_all_state
125
+ )
126
+ if not isinstance(all_state, dict):
127
+ raise AssertionError(
128
+ f"Distributed composable API states corrupted: {all_state}"
129
+ )
130
+ registry: dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload]
131
+ REGISTRY_KEY, default_registry
132
+ )
133
+ if not isinstance(registry, dict):
134
+ raise AssertionError(
135
+ f"Distributed composable API registry corrupted: {registry}"
136
+ )
137
+ if func in all_state or func.__name__ in registry:
138
+ raise AssertionError(
139
+ "Each distinct composable distributed API can only be applied to a "
140
+ f"module once. {func.__name__} has already been applied to the "
141
+ f"following module:\n{module}"
142
+ )
143
+ all_state.setdefault(func, state)
144
+ registry.setdefault(func.__name__, registry_item)
145
+
146
+ all_orig_named_params.append(OrderedDict(module.named_parameters()))
147
+ all_orig_named_buffers.append(OrderedDict(module.named_buffers()))
148
+ all_orig_named_modules.append(OrderedDict(module.named_modules()))
149
+
150
+ updated = func(inp_module, *args, **kwargs)
151
+ if updated is None:
152
+ updated = inp_module # type: ignore[assignment]
153
+ updated_modules: list[nn.Module]
154
+ if isinstance(updated, nn.Module):
155
+ updated_modules = [updated]
156
+ else:
157
+ updated_modules = _get_root_modules(list(inp_module)) # type: ignore[arg-type, call-overload]
158
+
159
+ all_new_named_params: list[dict[str, nn.Parameter]] = []
160
+ all_new_named_buffers: list[dict[str, torch.Tensor]] = []
161
+ all_new_named_modules: list[dict[str, nn.Module]] = []
162
+ for module in updated_modules:
163
+ all_new_named_params.append(OrderedDict(module.named_parameters()))
164
+ all_new_named_buffers.append(OrderedDict(module.named_buffers()))
165
+ all_new_named_modules.append(OrderedDict(module.named_modules()))
166
+
167
+ num_orig_modules = len(all_orig_named_modules)
168
+ num_new_modules = len(all_new_named_modules)
169
+ if num_orig_modules != num_new_modules:
170
+ raise AssertionError(
171
+ f"{func.__name__} should return the same number of modules as input modules"
172
+ f"Inputs: {num_orig_modules} modules\n"
173
+ f"Outputs: {num_new_modules} modules"
174
+ )
175
+
176
+ def check_fqn(orig_fqns: list[str], new_fqns: list[str], check_key: str):
177
+ if orig_fqns == new_fqns:
178
+ return
179
+
180
+ orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
181
+ orig_only = orig_fqn_set - new_fqn_set
182
+ new_only = new_fqn_set - orig_fqn_set
183
+ if len(orig_only) or len(new_only):
184
+ raise RuntimeError(
185
+ f"{check_key}"
186
+ "Composable distributed API implementations cannot modify FQNs.\n"
187
+ f"FQNs only in original: {orig_only}\n"
188
+ f"FQNs only in new: {new_only}"
189
+ )
190
+ else:
191
+ raise RuntimeError(
192
+ f"{check_key}"
193
+ "Composable distributed API implementations cannot modify "
194
+ "the order of FQNs.\n"
195
+ f"Original FQNs: {orig_only}\n"
196
+ f"New FQNs: {new_only}"
197
+ )
198
+
199
+ for orig_named_params, new_named_params in zip(
200
+ all_orig_named_params, all_new_named_params
201
+ ):
202
+ check_fqn(
203
+ list(orig_named_params.keys()),
204
+ list(new_named_params.keys()),
205
+ "Checking parameters: ",
206
+ )
207
+ for orig_named_buffers, new_named_buffers in zip(
208
+ all_orig_named_buffers, all_new_named_buffers
209
+ ):
210
+ check_fqn(
211
+ list(orig_named_buffers.keys()),
212
+ list(new_named_buffers.keys()),
213
+ "Checking buffers: ",
214
+ )
215
+ for orig_named_modules, new_named_modules in zip(
216
+ all_orig_named_modules, all_new_named_modules
217
+ ):
218
+ check_fqn(
219
+ list(orig_named_modules.keys()),
220
+ list(new_named_modules.keys()),
221
+ "Checking modules: ",
222
+ )
223
+
224
+ # TODO: verify that installed distributed paradigms are compatible with
225
+ # each other.
226
+
227
+ return updated
228
+
229
+ def get_state(module: nn.Module) -> _State:
230
+ return module.__dict__.setdefault( # type: ignore[call-overload]
231
+ STATE_KEY,
232
+ {}, # TODO(@yhcharles): this is a temporary fix, need a better way
233
+ ).get(func) # type: ignore[call-overload]
234
+
235
+ wrapper.state = get_state # type: ignore[attr-defined]
236
+
237
+ return wrapper # type: ignore[return-value]
238
+
239
+ return inner # type: ignore[return-value]
240
+
241
+
242
+ def _get_registry(module: nn.Module) -> Optional[dict[str, RegistryItem]]:
243
+ r"""
244
+ Get an ``OrderedDict`` of composable APIs that have been applied to the
245
+ ``module``, indexed by the API name. If no API has been applied, then this
246
+ returns ``None``.
247
+ """
248
+ return getattr(module, REGISTRY_KEY, None)
phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
2
+
3
+ from .fully_shard import FSDPModule, fully_shard, register_fsdp_forward_method
phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (395 Bytes). View file
 
phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-39.pyc ADDED
Binary file (323 Bytes). View file
 
phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/fully_shard.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # TODO: For backward compatibility, we are importing the public objects
2
+ # originally from this file.
3
+ from torch.distributed.fsdp import ( # noqa: F401
4
+ FSDPModule,
5
+ fully_shard,
6
+ register_fsdp_forward_method,
7
+ UnshardHandle,
8
+ )
phivenv/Lib/site-packages/torch/distributed/_composable/replicate.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import weakref
3
+ from collections.abc import Iterable
4
+ from typing import Any, NoReturn, Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.distributed._composable_state import _State
9
+ from torch.nn.parallel import DistributedDataParallel
10
+
11
+ from .contract import _get_registry, contract
12
+
13
+
14
+ _ROOT_MODULE_PREFIX = ""
15
+
16
+
17
+ class _ReplicateState(_State):
18
+ _ddp_weakref: weakref.ref
19
+
20
+ def __init__(self) -> None:
21
+ super().__init__()
22
+ self.module: nn.Module = nn.ParameterList()
23
+ self.has_initialized: bool = False
24
+ self._param_list: nn.ParameterList = nn.ParameterList()
25
+ # TODO(@fegin): this variable is originally create for testing, we
26
+ # should remove this if possible.
27
+ self._orig_module = self.module
28
+ self._param_names: list[str] = []
29
+ self._no_sync: bool = False
30
+ self._init_args: Optional[tuple[Any, ...]] = None
31
+ self._init_kwargs: dict[str, Any] = {}
32
+ self._comm_hook_args: list[Any] = []
33
+
34
+ def _collect_params(
35
+ self,
36
+ module: nn.Module,
37
+ ignored_modules: set[nn.Module],
38
+ ignored_params: set[nn.Parameter],
39
+ prefix: str = _ROOT_MODULE_PREFIX,
40
+ ) -> None:
41
+ # skip if managed by fully_sharded API
42
+ if _is_fully_sharded(module):
43
+ return
44
+
45
+ # if a module is ignored, all descendants of the module are ignored.
46
+ if module in ignored_modules:
47
+ return
48
+
49
+ recurse_prefix = (
50
+ f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX
51
+ )
52
+
53
+ for n, p in module.named_parameters(recurse=False):
54
+ if p not in ignored_params:
55
+ self._param_list.append(p)
56
+ self._param_names.append(f"{recurse_prefix}{n}")
57
+
58
+ for name, child_module in module.named_children():
59
+ self._collect_params(
60
+ child_module,
61
+ ignored_modules,
62
+ ignored_params,
63
+ prefix=f"{recurse_prefix}{name}",
64
+ )
65
+
66
+ def lazy_init(self) -> None:
67
+ @torch._disable_dynamo(recursive=True)
68
+ def _lazy_init():
69
+ assert self._init_args is not None
70
+ self.init(*self._init_args, **self._init_kwargs)
71
+ self.register_comm_hook()
72
+ self._init_args = ()
73
+ self._init_kwargs = {}
74
+
75
+ _lazy_init()
76
+
77
+ def init(
78
+ self,
79
+ module: nn.Module,
80
+ ignored_modules: set[nn.Module],
81
+ **kwargs,
82
+ ) -> None:
83
+ if self.has_initialized:
84
+ return
85
+
86
+ self.has_initialized = True
87
+ self.module = module
88
+ ignored_params = {p for m in ignored_modules for p in m.parameters()}
89
+ for submodule in module.modules():
90
+ if _is_fully_sharded(submodule):
91
+ ignored_params.update(submodule.parameters())
92
+ from torch.distributed.tensor.parallel.ddp import _localize_dtensor
93
+
94
+ _localize_dtensor(module, ignored_params=ignored_params)
95
+ self._collect_params(module, ignored_modules, ignored_params)
96
+
97
+ if "device_id" in kwargs:
98
+ # replicate() supports a small usability enhancement where
99
+ # user can pass in device_id as a Union[int, torch.device] even for
100
+ # CPU devices so users don't have to change code for CPU/GPU runs.
101
+ # We derive the right device_ids to feed into DDP to support this.
102
+ if kwargs["device_id"] is not None:
103
+ device_id = kwargs["device_id"]
104
+ # Convert to device_ids that DDP expects.
105
+ if isinstance(device_id, torch.device) and device_id.type == "cpu":
106
+ # CPU modules receive device_ids None
107
+ kwargs["device_ids"] = None
108
+ else:
109
+ # GPU modules expect device_ids=[cuda_device]
110
+ kwargs["device_ids"] = [device_id]
111
+ else:
112
+ kwargs["device_ids"] = None
113
+ kwargs.pop("device_id")
114
+
115
+ self._ddp = DistributedDataParallel(self._param_list, **kwargs)
116
+ # Weakref to the DDP instance is currently only used for testing.
117
+ replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp)
118
+
119
+ def register_comm_hook(self) -> None:
120
+ for comm_args, comm_kwargs in self._comm_hook_args:
121
+ self._ddp.register_comm_hook(*comm_args, **comm_kwargs)
122
+ self._comm_hook_args.clear()
123
+
124
+ def record_init_args(self, *args, **kwargs) -> None:
125
+ self._init_args = args
126
+ self._init_kwargs = kwargs
127
+
128
+ def forward_pre_hook(
129
+ self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
130
+ ) -> Any:
131
+ if self._init_args or self._init_kwargs:
132
+ self.lazy_init()
133
+ self._ddp.require_backward_grad_sync = not self._no_sync
134
+ return self._ddp._pre_forward(*args, **kwargs)
135
+
136
+ def forward_post_hook(
137
+ self,
138
+ module: nn.Module,
139
+ input: tuple[torch.Tensor],
140
+ output: torch.Tensor,
141
+ ) -> torch.Tensor:
142
+ return self._ddp._post_forward(output)
143
+
144
+
145
def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
    """Installed as ``__deepcopy__`` on replicated modules; always raises."""
    msg = "DDP does not support deepcopy. Please use state dict for serialization."
    raise AssertionError(msg)
149
+
150
+
151
# Follow the same pattern as FSDP/fully_shard
class DDP:
    """Mixin placed leftmost in a replicated module's dynamically built class.

    ``replicate()`` rewrites ``module.__class__`` to
    ``type(f"DDP{cls.__name__}", (DDP, cls), ...)``, so these methods shadow
    the wrapped module's and can reach the ``replicate``-managed state.
    """

    def __new__(cls, *args, **kwargs):
        """
        Override ``__new__`` to remove the DDP class and directly construct
        the original class for cases like indexing into a container module.
        """
        # Use index 2 since 0 is the dynamically constructed `DDP<...>` class
        # and index 1 is the `DDP` class itself
        orig_cls = cls.__mro__[2]
        return orig_cls.__new__(orig_cls, *args, **kwargs)

    def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None:
        """
        Sets if the module should sync gradients. This can be used to implement
        gradient accumulation without communication.

        Args:
            requires_gradient_sync (bool): Whether to reduce gradients for the
                module's parameters.
        """
        # Stored inverted: the state tracks `_no_sync`, consumed on the next
        # forward pre-hook.
        replicate.state(self)._no_sync = not requires_gradient_sync  # type: ignore[arg-type]

    def register_comm_hook(self, *args, **kwargs) -> None:
        # Queue the hook arguments; they are applied to the real DDP instance
        # once it is lazily constructed.
        replicate.state(self)._comm_hook_args.append((args, kwargs))  # type: ignore[arg-type]
176
+
177
+
178
+ @contract(state_cls=_ReplicateState)
179
+ def replicate(
180
+ module: nn.Module,
181
+ ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
182
+ **kwargs,
183
+ ) -> nn.Module:
184
+ r"""Replicates a module
185
+
186
+ Args:
187
+ module (torch.nn.Module): module to replicate
188
+
189
+ Example::
190
+ >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
191
+ >>> module = nn.Linear(3, 3)
192
+ >>> replicate(module)
193
+ """
194
+ torch._C._log_api_usage_once("torch.distributed.replicate")
195
+
196
+ # TODO(fegin): using kwargs is not a good idea if we would like to make
197
+ # replicate a formal API to replace DDP.
198
+ if "device_id" in kwargs:
199
+ if not isinstance(kwargs["device_id"], (int, torch.device)):
200
+ raise RuntimeError(
201
+ "Expected device_id to be int or torch.device, "
202
+ f"but got {type(kwargs['device_id'])}"
203
+ )
204
+
205
+ if _is_fully_sharded(module):
206
+ raise RuntimeError(
207
+ "Cannot apply `replicate()` on a Module already managed by `fully_shard`"
208
+ )
209
+
210
+ if ignored_modules is None:
211
+ ignored_modules = {}
212
+ else:
213
+ ignored_modules = set(ignored_modules)
214
+
215
+ state = replicate.state(module)
216
+ module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)
217
+ device_mesh = kwargs.get("device_mesh", None)
218
+ if device_mesh is not None:
219
+ from torch.distributed.device_mesh import _mesh_resources
220
+
221
+ root_mesh = _mesh_resources.get_root_mesh(device_mesh)
222
+ # if a root mesh is not the same as device_mesh,
223
+ # meaning the device_mesh is sliced out from the root mesh.
224
+ if root_mesh != device_mesh:
225
+ # TODO: This is a temporary work around to enable DDP + TP.
226
+ # We should do the logic in DDP so that the 2D implementation is
227
+ # sound and the state_dict works out of the box.
228
+ #
229
+ # This won't conflict with what is done in DDP class as the module
230
+ # replicate is going to pass is NOT the original module.
231
+ from torch.distributed.tensor.parallel.ddp import (
232
+ _localize_dtensor,
233
+ _reconstruct_dtensor,
234
+ )
235
+
236
+ module.register_forward_pre_hook(_reconstruct_dtensor)
237
+ module.register_forward_hook(_localize_dtensor)
238
+
239
+ module.register_forward_hook(state.forward_post_hook) # type: ignore[arg-type]
240
+
241
+ state.record_init_args(module, ignored_modules, **kwargs)
242
+
243
+ # Place DDP leftmost for highest priority in the method resolution order
244
+ cls = module.__class__
245
+ dct = {"__deepcopy__": unimplemented_deepcopy}
246
+ new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct)
247
+ module.__class__ = new_cls
248
+ return module
249
+
250
+
251
+ def _is_fully_sharded(module: nn.Module) -> bool:
252
+ r"""Check if module is marked with fully_shard."""
253
+ registry = _get_registry(module)
254
+ if registry is None:
255
+ return False
256
+ return "fully_shard" in registry
phivenv/Lib/site-packages/torch/distributed/_shard/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .api import _shard_tensor, load_with_process_group, shard_module, shard_parameter
phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (292 Bytes). View file
 
phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-39.pyc ADDED
Binary file (1.04 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/api.cpython-39.pyc ADDED
Binary file (9.82 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-39.pyc ADDED
Binary file (2.25 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-39.pyc ADDED
Binary file (2.28 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-39.pyc ADDED
Binary file (1.18 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-39.pyc ADDED
Binary file (1.32 kB). View file
 
phivenv/Lib/site-packages/torch/distributed/_shard/_utils.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Sequence
2
+
3
+ import torch
4
+ from torch.distributed._shard.metadata import ShardMetadata
5
+
6
+
7
# Shared deprecation notice for ShardedTensor-era APIs; DTensor supersedes them.
DEPRECATE_MSG = "Please use DTensor instead and we are deprecating ShardedTensor."
8
+
9
+
10
def narrow_tensor_by_index(
    tensor: torch.Tensor,
    offsets: Sequence[int],
    sizes: Sequence[int],
) -> torch.Tensor:
    """
    Narrow ``tensor`` to ``sizes[d]`` elements starting at ``offsets[d]`` for
    every dimension ``d`` whose requested size is smaller than the input's.
    Dimensions already at (or above) the requested size are left untouched,
    so a full-size request returns ``tensor`` itself.
    """
    shard = tensor
    for dim, (start, length) in enumerate(zip(offsets, sizes)):
        if length >= tensor.size(dim):
            continue
        # `narrow` returns a view, so chaining per-dimension stays zero-copy.
        shard = shard.narrow(dim, start, length)
    return shard
26
+
27
+
28
def narrow_tensor(tensor: torch.Tensor, metadata: ShardMetadata) -> torch.Tensor:
    """
    Narrow ``tensor`` to the shard described by ``metadata``
    (``shard_offsets`` / ``shard_sizes``).
    """
    return narrow_tensor_by_index(
        tensor, metadata.shard_offsets, metadata.shard_sizes
    )
phivenv/Lib/site-packages/torch/distributed/_shard/api.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from contextlib import contextmanager
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+ import torch.nn as nn
8
+ from torch.distributed import distributed_c10d
9
+ from torch.distributed._shard.sharded_tensor import ShardedTensor
10
+
11
+ from .sharder import Sharder
12
+ from .sharding_plan import ShardingPlan
13
+ from .sharding_spec import ChunkShardingSpec, ShardingSpec
14
+
15
+
16
def _shard_tensor(
    tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
) -> ShardedTensor:
    """
    Given a :class:`torch.Tensor`, it shards that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    Args:
        tensor (:class:`torch.Tensor`): Tensor needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.

    .. note::
        Performs collectives (``all_gather_object`` and the spec's ``shard``),
        so every rank in the group must call this with matching arguments.
    """
    if not tensor.is_contiguous():
        raise ValueError("input tensor is not a contiguous Tensor")

    # Fall back to the default process group when none is supplied.
    pg = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )
    world_size = dist.get_world_size(pg)
    current_rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    # A mismatch indicates ranks are running divergent (non-SPMD) programs.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f"src_rank={src_rank} on rank: {current_rank} does not "  # type: ignore[index]
                f"match with src_rank={entry[0]} on rank: {idx}"  # type: ignore[index]
            )
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f"sharding_spec={sharding_spec} on rank: {current_rank} does not "  # type: ignore[index]
                f"match with sharding_spec={entry[1]} on rank: {idx}"  # type: ignore[index]
            )

    # The spec owns the actual scatter/sharding logic.
    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=pg)

    return st
75
+
76
+
77
def shard_parameter(
    module: torch.nn.Module,
    param_name: str,
    sharding_spec: ShardingSpec,
    src_rank=0,
    process_group=None,
):
    """
    Shard the parameter ``module.<param_name>`` in place according to
    ``sharding_spec``, replacing it with a
    :class:`torch.distributed._sharded_tensor.ShardedTensor` parameter.
    ``src_rank`` is the rank whose data is treated as ground truth and
    scattered to the other ranks.

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    # Validate up front so callers fail fast before any collective runs.
    if not hasattr(module, param_name):
        raise AttributeError(f"{module._get_name()} has no attribute `{param_name}`")

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(
            f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}"
        )

    if not tensor.is_contiguous():
        raise ValueError(f"param: {param_name} is not a contiguous Tensor")

    # Shard (collective across the group), then swap the dense parameter out.
    sharded = _shard_tensor(tensor, sharding_spec, src_rank, process_group)
    module.register_parameter(param_name, nn.Parameter(sharded))
129
+
130
+
131
# Tracks the current process group in the load context manager.
_CURRENT_PROCESS_GROUP: Optional[dist.ProcessGroup] = None


@contextmanager
def load_with_process_group(process_group):
    """
    Context manager to set the process group with which to load a ShardedTensor.

    Non-reentrant: entering while another ``load_with_process_group`` context
    is active raises ``RuntimeError``. The group is always cleared on exit,
    even if the body raises.
    """
    global _CURRENT_PROCESS_GROUP
    if _CURRENT_PROCESS_GROUP is not None:
        raise RuntimeError(
            'ProcessGroup already set by previous "load_with_process_group" '
            "context manager"
        )
    _CURRENT_PROCESS_GROUP = process_group
    try:
        yield process_group
    finally:
        _CURRENT_PROCESS_GROUP = None
151
+
152
+
153
def _get_current_process_group():
    """
    Return the process group set by ``load_with_process_group``; fall back to
    the default process group when no load context is active.
    """
    # Read-only access needs no `global` declaration.
    if _CURRENT_PROCESS_GROUP is None:
        return distributed_c10d._get_default_group()
    return _CURRENT_PROCESS_GROUP
163
+
164
+
165
+ def _reshard_output(
166
+ module: torch.nn.Module, resharding_spec: ShardingSpec
167
+ ) -> torch.nn.Module:
168
+ """
169
+ Hook a module with output resharding in the forward pass according
170
+ to the given ``resharding_spec``.
171
+
172
+ Args:
173
+ module (:class:`torch.nn.Module`): Module whose output needs to be resharded.
174
+ resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
175
+ The specification describing how the output of the module will be resharded.
176
+
177
+ Returns:
178
+ A :class:`torch.nn.Module` object with reshard API hooked.
179
+ """
180
+
181
+ def hook_func(_module, _input, output):
182
+ if isinstance(output, ShardedTensor):
183
+ return output.reshard(resharding_spec)
184
+ return output
185
+
186
+ module.register_forward_hook(hook_func)
187
+ return module
188
+
189
+
190
+ def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:
191
+ """
192
+ Hook a module with local shards collection in the forward pass.
193
+
194
+ This API is typically used to convert a sharded representation back to data parallel
195
+ representation. In particular, it returns the local tensor for this Shard. If the
196
+ size along the sharding dimension for the local tensor is 1, this dimension is removed
197
+ from the final result. For example a [4, 16] ShardedTensor across 4 ranks is typically
198
+ a local Tensor of size [16] across each rank and not [1, 16] across each rank.
199
+
200
+ Args:
201
+ module (:class:`torch.nn.Module`): Module whose output is ShardedTensor and the
202
+ local tensor value needs to be returned.
203
+
204
+ Returns:
205
+ A :class:`torch.nn.Module` object with collection API hooked.
206
+ """
207
+
208
+ def hook_func(_module, _input, output):
209
+ if isinstance(output, ShardedTensor):
210
+ local_tensor = output.local_tensor()
211
+ # Squeeze the # of dimensions manually, only applicable to ChunkShardingSpec
212
+ sharding_spec = output._sharding_spec
213
+ if (
214
+ isinstance(sharding_spec, ChunkShardingSpec)
215
+ and local_tensor.size(sharding_spec.dim) == 1 # type: ignore[attr-defined, arg-type]
216
+ ):
217
+ local_tensor = local_tensor.squeeze(
218
+ output._sharding_spec.dim # type: ignore[attr-defined]
219
+ )
220
+ return local_tensor
221
+
222
+ module.register_forward_hook(hook_func)
223
+ return module
224
+
225
+
226
def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_group=None):
    """
    Shards a given module according to the provided sharding `plan`. This method
    first shards all the parameters according to the given sharding `plan`. Then if
    `output_plan` and `return_local_tensor` are specified in the sharding `plan`, it
    will tag the output of modules according `output_plan`, convert the module's
    output back to data parallel according to `return_local_tensor`.

    Needs to be called on all ranks in an SPMD fashion.

    Args:
        module (:class:`torch.nn.Module`): The module to apply sharding to
        plan (:class:`torch.distributed._shard.sharding_plan.ShardingPlan`):
            The ShardingPlan which specified param name to ShardingSpec to apply to
            each parameter.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the module that would be sharded and scattered across the rest
            of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
    """
    # record Sharder paths for sanity check on the plan to ensure items in the plan
    # does not conflict with the submodule tree that the Sharder is working with
    sharder_paths = [
        name for name, spec in plan.plan.items() if isinstance(spec, Sharder)
    ]

    # shard the parameter according to the ShardingPlan
    for name, spec in plan.plan.items():
        if isinstance(spec, ShardingSpec):
            # if found a sharding spec, try to shard the parameter
            module_path, _, param_name = name.rpartition(".")

            for sharder_path in sharder_paths:
                if module_path.startswith(sharder_path):
                    raise RuntimeError(
                        f"ShardingPlan is in-valid, trying to shard a parameter: {name},"
                        f" but there's already a Sharder entry for module {sharder_path},"
                        f" parameter sharding should not conflict with the submodule tree"
                        f" that a Sharder is working with!"
                    )

            mod = module.get_submodule(module_path)
            shard_parameter(
                mod, param_name, spec, src_rank=src_rank, process_group=process_group
            )
        elif isinstance(spec, Sharder):
            parent_mod_path, _, mod_name = name.rpartition(".")
            if name == "":
                raise KeyError("Module path must not be empty for custom sharder!")
            mod = module.get_submodule(name)
            parent_mod = module.get_submodule(parent_mod_path)
            sharded_mod = spec.shard(mod)
            # Swap this submodule with the sharded module.
            # BUG FIX: the original `parent_mod.mod_name = sharded_mod` set an
            # attribute literally named "mod_name", leaving the unsharded
            # submodule in place; setattr with the real child name performs
            # the swap (nn.Module.__setattr__ replaces registered submodules).
            setattr(parent_mod, mod_name, sharded_mod)
        else:
            raise TypeError(
                f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'"
            )

    # reshard output if there's an entry in `reshard_output` for this module
    if plan.output_plan is not None:
        for module_path, output_spec in plan.output_plan.items():
            if isinstance(output_spec, ShardingSpec):
                mod = module.get_submodule(module_path)
                _reshard_output(mod, output_spec)
            else:
                raise TypeError(
                    f"Only `ShardingSpec` is supported as output_plan for '{module_path}'"
                )
    # convert the output back to data parallel for the modules appears in
    # `return_local_tensor` of the plan, we will call `_collect_local_shard`
    # to collect the local tensor for output of modules
    if plan.return_local_tensor is not None:
        for module_path in plan.return_local_tensor:
            mod = module.get_submodule(module_path)
            _collect_local_shard(mod)
phivenv/Lib/site-packages/torch/distributed/_shard/checkpoint/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Keep old package for BC purposes, this file should be removed once
2
+ # everything moves to the `torch.distributed.checkpoint` package.
3
+ import sys
4
+ import warnings
5
+
6
+ import torch
7
+ from torch.distributed.checkpoint import * # noqa: F403
8
+
9
+
10
# Warn eagerly at import time; "always" overrides the default filter that
# silences DeprecationWarning outside __main__.
with warnings.catch_warnings():
    warnings.simplefilter("always")
    warnings.warn(
        "`torch.distributed._shard.checkpoint` will be deprecated, "
        "use `torch.distributed.checkpoint` instead",
        DeprecationWarning,
        stacklevel=2,
    )

# Alias the old module path to the new package so existing
# `import torch.distributed._shard.checkpoint` statements keep working.
sys.modules["torch.distributed._shard.checkpoint"] = torch.distributed.checkpoint
phivenv/Lib/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (628 Bytes). View file
 
phivenv/Lib/site-packages/torch/distributed/_shard/common_op_utils.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from torch.utils import _pytree as pytree
6
+
7
+
8
+ def _basic_validation(op, args=(), kwargs=None):
9
+ """
10
+ Common validation across all ops go in here.
11
+ """
12
+ from torch.distributed._shard.sharded_tensor import ShardedTensor
13
+
14
+ if len(args) == 0 and (kwargs is None or len(kwargs) == 0):
15
+ raise ValueError(f" No input for '{op.__name__}'!")
16
+
17
+ # Validate types
18
+ has_distributed_tensor = False
19
+
20
+ def is_distributed_tensor(e):
21
+ nonlocal has_distributed_tensor
22
+ if isinstance(e, ShardedTensor):
23
+ has_distributed_tensor = True
24
+
25
+ pytree.tree_map_(is_distributed_tensor, args)
26
+ pytree.tree_map_(is_distributed_tensor, kwargs)
27
+
28
+ if not has_distributed_tensor:
29
+ raise TypeError(
30
+ f"torch function '{op.__name__}', with args: {args} and "
31
+ f"kwargs: {kwargs} are called without any distributed tensor!"
32
+ )
33
+
34
+ # Validate all distributed tensors use the same PG.
35
+ cur_pg: Optional[torch.distributed.ProcessGroup] = None
36
+
37
+ def validate_pg(e):
38
+ nonlocal cur_pg
39
+ if isinstance(e, ShardedTensor):
40
+ if cur_pg is not None and e._process_group is not cur_pg:
41
+ raise RuntimeError(
42
+ "All distributed tensors should use the "
43
+ "same ProcessGroup if used together in an op."
44
+ )
45
+ cur_pg = e._process_group
46
+
47
+ pytree.tree_map_(validate_pg, args)
48
+ pytree.tree_map_(validate_pg, kwargs)
49
+
50
+
51
+ def _register_default_op(op, decorator):
52
+ @decorator(op)
53
+ def tensor_default_op(types, args=(), kwargs=None, pg=None):
54
+ """
55
+ Handles ``__torch_function__`` dispatch for the default tensor ops that
56
+ behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or
57
+ ``torch.Tensor.dtype``. We simply lower to the real op call with
58
+ DisableTorchFunctionSubclass context like ``torch.Tensor.__torch_function__``
59
+ to avoid recursions.
60
+ """
61
+ if kwargs is None:
62
+ kwargs = {}
63
+
64
+ with torch._C.DisableTorchFunctionSubclass():
65
+ return op(*args, **kwargs)