koichi12 commited on
Commit
da0ba90
·
verified ·
1 Parent(s): da4944c

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/apiclient/__init__.py +27 -0
  3. .venv/lib/python3.11/site-packages/apiclient/__pycache__/__init__.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/certifi/__init__.py +4 -0
  5. .venv/lib/python3.11/site-packages/certifi/__main__.py +12 -0
  6. .venv/lib/python3.11/site-packages/certifi/__pycache__/__init__.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/certifi/__pycache__/__main__.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/certifi/__pycache__/core.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/certifi/cacert.pem +0 -0
  10. .venv/lib/python3.11/site-packages/certifi/core.py +114 -0
  11. .venv/lib/python3.11/site-packages/certifi/py.typed +0 -0
  12. .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_auth.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_cache_assets.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_cache_manager.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_datetime.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_deprecation.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_experimental.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/endpoint_helpers.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/insecure_hashlib.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/logging.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/tqdm.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/nvidia/cusolver/lib/libcusolverMg.so.11 +3 -0
  24. .venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/INSTALLER +1 -0
  25. .venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/LICENSE +201 -0
  26. .venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/METADATA +51 -0
  27. .venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/NOTICE +5 -0
  28. .venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/RECORD +58 -0
  29. .venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/WHEEL +5 -0
  30. .venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/top_level.txt +1 -0
  31. .venv/lib/python3.11/site-packages/torch/_C.cpython-311-x86_64-linux-gnu.so +0 -0
  32. .venv/lib/python3.11/site-packages/torch/_VF.py +31 -0
  33. .venv/lib/python3.11/site-packages/torch/_VF.pyi +0 -0
  34. .venv/lib/python3.11/site-packages/torch/__config__.py +23 -0
  35. .venv/lib/python3.11/site-packages/torch/__future__.py +75 -0
  36. .venv/lib/python3.11/site-packages/torch/__init__.py +2665 -0
  37. .venv/lib/python3.11/site-packages/torch/_appdirs.py +667 -0
  38. .venv/lib/python3.11/site-packages/torch/_classes.py +56 -0
  39. .venv/lib/python3.11/site-packages/torch/_compile.py +38 -0
  40. .venv/lib/python3.11/site-packages/torch/_custom_ops.py +324 -0
  41. .venv/lib/python3.11/site-packages/torch/_deploy.py +104 -0
  42. .venv/lib/python3.11/site-packages/torch/_guards.py +925 -0
  43. .venv/lib/python3.11/site-packages/torch/_jit_internal.py +1547 -0
  44. .venv/lib/python3.11/site-packages/torch/_linalg_utils.py +150 -0
  45. .venv/lib/python3.11/site-packages/torch/_lobpcg.py +1157 -0
  46. .venv/lib/python3.11/site-packages/torch/_lowrank.py +294 -0
  47. .venv/lib/python3.11/site-packages/torch/_meta_registrations.py +0 -0
  48. .venv/lib/python3.11/site-packages/torch/_namedtensor_internals.py +159 -0
  49. .venv/lib/python3.11/site-packages/torch/_ops.py +1355 -0
  50. .venv/lib/python3.11/site-packages/torch/_python_dispatcher.py +182 -0
.gitattributes CHANGED
@@ -413,3 +413,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
413
  .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc-builtins.so.12.4 filter=lfs diff=lfs merge=lfs -text
414
  .venv/lib/python3.11/site-packages/cv2/cv2.abi3.so filter=lfs diff=lfs merge=lfs -text
415
  .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_engines_runtime_compiled.so.9 filter=lfs diff=lfs merge=lfs -text
 
 
413
  .venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc-builtins.so.12.4 filter=lfs diff=lfs merge=lfs -text
414
  .venv/lib/python3.11/site-packages/cv2/cv2.abi3.so filter=lfs diff=lfs merge=lfs -text
415
  .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_engines_runtime_compiled.so.9 filter=lfs diff=lfs merge=lfs -text
416
+ .venv/lib/python3.11/site-packages/nvidia/cusolver/lib/libcusolverMg.so.11 filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/apiclient/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Retain apiclient as an alias for googleapiclient."""
2
+
3
+ from googleapiclient import channel, discovery, errors, http, mimeparse, model
4
+
5
+ try:
6
+ from googleapiclient import sample_tools
7
+ except ImportError:
8
+ # Silently ignore, because the vast majority of consumers won't use it and
9
+ # it has deep dependence on oauth2client, an optional dependency.
10
+ sample_tools = None
11
+ from googleapiclient import schema
12
+
13
+ _SUBMODULES = {
14
+ "channel": channel,
15
+ "discovery": discovery,
16
+ "errors": errors,
17
+ "http": http,
18
+ "mimeparse": mimeparse,
19
+ "model": model,
20
+ "sample_tools": sample_tools,
21
+ "schema": schema,
22
+ }
23
+
24
+ import sys
25
+
26
+ for module_name, module in _SUBMODULES.items():
27
+ sys.modules["apiclient.%s" % module_name] = module
.venv/lib/python3.11/site-packages/apiclient/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (954 Bytes). View file
 
.venv/lib/python3.11/site-packages/certifi/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .core import contents, where
2
+
3
+ __all__ = ["contents", "where"]
4
+ __version__ = "2024.12.14"
.venv/lib/python3.11/site-packages/certifi/__main__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from certifi import contents, where
4
+
5
+ parser = argparse.ArgumentParser()
6
+ parser.add_argument("-c", "--contents", action="store_true")
7
+ args = parser.parse_args()
8
+
9
+ if args.contents:
10
+ print(contents())
11
+ else:
12
+ print(where())
.venv/lib/python3.11/site-packages/certifi/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (322 Bytes). View file
 
.venv/lib/python3.11/site-packages/certifi/__pycache__/__main__.cpython-311.pyc ADDED
Binary file (711 Bytes). View file
 
.venv/lib/python3.11/site-packages/certifi/__pycache__/core.cpython-311.pyc ADDED
Binary file (3.75 kB). View file
 
.venv/lib/python3.11/site-packages/certifi/cacert.pem ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/certifi/core.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ certifi.py
3
+ ~~~~~~~~~~
4
+
5
+ This module returns the installation location of cacert.pem or its contents.
6
+ """
7
+ import sys
8
+ import atexit
9
+
10
+ def exit_cacert_ctx() -> None:
11
+ _CACERT_CTX.__exit__(None, None, None) # type: ignore[union-attr]
12
+
13
+
14
+ if sys.version_info >= (3, 11):
15
+
16
+ from importlib.resources import as_file, files
17
+
18
+ _CACERT_CTX = None
19
+ _CACERT_PATH = None
20
+
21
+ def where() -> str:
22
+ # This is slightly terrible, but we want to delay extracting the file
23
+ # in cases where we're inside of a zipimport situation until someone
24
+ # actually calls where(), but we don't want to re-extract the file
25
+ # on every call of where(), so we'll do it once then store it in a
26
+ # global variable.
27
+ global _CACERT_CTX
28
+ global _CACERT_PATH
29
+ if _CACERT_PATH is None:
30
+ # This is slightly janky, the importlib.resources API wants you to
31
+ # manage the cleanup of this file, so it doesn't actually return a
32
+ # path, it returns a context manager that will give you the path
33
+ # when you enter it and will do any cleanup when you leave it. In
34
+ # the common case of not needing a temporary file, it will just
35
+ # return the file system location and the __exit__() is a no-op.
36
+ #
37
+ # We also have to hold onto the actual context manager, because
38
+ # it will do the cleanup whenever it gets garbage collected, so
39
+ # we will also store that at the global level as well.
40
+ _CACERT_CTX = as_file(files("certifi").joinpath("cacert.pem"))
41
+ _CACERT_PATH = str(_CACERT_CTX.__enter__())
42
+ atexit.register(exit_cacert_ctx)
43
+
44
+ return _CACERT_PATH
45
+
46
+ def contents() -> str:
47
+ return files("certifi").joinpath("cacert.pem").read_text(encoding="ascii")
48
+
49
+ elif sys.version_info >= (3, 7):
50
+
51
+ from importlib.resources import path as get_path, read_text
52
+
53
+ _CACERT_CTX = None
54
+ _CACERT_PATH = None
55
+
56
+ def where() -> str:
57
+ # This is slightly terrible, but we want to delay extracting the
58
+ # file in cases where we're inside of a zipimport situation until
59
+ # someone actually calls where(), but we don't want to re-extract
60
+ # the file on every call of where(), so we'll do it once then store
61
+ # it in a global variable.
62
+ global _CACERT_CTX
63
+ global _CACERT_PATH
64
+ if _CACERT_PATH is None:
65
+ # This is slightly janky, the importlib.resources API wants you
66
+ # to manage the cleanup of this file, so it doesn't actually
67
+ # return a path, it returns a context manager that will give
68
+ # you the path when you enter it and will do any cleanup when
69
+ # you leave it. In the common case of not needing a temporary
70
+ # file, it will just return the file system location and the
71
+ # __exit__() is a no-op.
72
+ #
73
+ # We also have to hold onto the actual context manager, because
74
+ # it will do the cleanup whenever it gets garbage collected, so
75
+ # we will also store that at the global level as well.
76
+ _CACERT_CTX = get_path("certifi", "cacert.pem")
77
+ _CACERT_PATH = str(_CACERT_CTX.__enter__())
78
+ atexit.register(exit_cacert_ctx)
79
+
80
+ return _CACERT_PATH
81
+
82
+ def contents() -> str:
83
+ return read_text("certifi", "cacert.pem", encoding="ascii")
84
+
85
+ else:
86
+ import os
87
+ import types
88
+ from typing import Union
89
+
90
+ Package = Union[types.ModuleType, str]
91
+ Resource = Union[str, "os.PathLike"]
92
+
93
+ # This fallback will work for Python versions prior to 3.7 that lack the
94
+ # importlib.resources module but relies on the existing `where` function
95
+ # so won't address issues with environments like PyOxidizer that don't set
96
+ # __file__ on modules.
97
+ def read_text(
98
+ package: Package,
99
+ resource: Resource,
100
+ encoding: str = 'utf-8',
101
+ errors: str = 'strict'
102
+ ) -> str:
103
+ with open(where(), encoding=encoding) as data:
104
+ return data.read()
105
+
106
+ # If we don't have importlib.resources, then we will just do the old logic
107
+ # of assuming we're on the filesystem and munge the path directly.
108
+ def where() -> str:
109
+ f = os.path.dirname(__file__)
110
+
111
+ return os.path.join(f, "cacert.pem")
112
+
113
+ def contents() -> str:
114
+ return read_text("certifi", "cacert.pem", encoding="ascii")
.venv/lib/python3.11/site-packages/certifi/py.typed ADDED
File without changes
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (4.86 kB). View file
 
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_auth.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_cache_assets.cpython-311.pyc ADDED
Binary file (5.77 kB). View file
 
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_cache_manager.cpython-311.pyc ADDED
Binary file (40.5 kB). View file
 
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_datetime.cpython-311.pyc ADDED
Binary file (2.35 kB). View file
 
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_deprecation.cpython-311.pyc ADDED
Binary file (7.45 kB). View file
 
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_experimental.cpython-311.pyc ADDED
Binary file (2.41 kB). View file
 
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/endpoint_helpers.cpython-311.pyc ADDED
Binary file (2.35 kB). View file
 
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/insecure_hashlib.cpython-311.pyc ADDED
Binary file (623 Bytes). View file
 
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/logging.cpython-311.pyc ADDED
Binary file (6.53 kB). View file
 
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/tqdm.cpython-311.pyc ADDED
Binary file (11.7 kB). View file
 
.venv/lib/python3.11/site-packages/nvidia/cusolver/lib/libcusolverMg.so.11 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47662749a295f771b92abe8d99dcd5f151953d56069a19f43977b97868ec21eb
3
+ size 82303400
.venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ pip
.venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
.venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/METADATA ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: prometheus_client
3
+ Version: 0.21.1
4
+ Summary: Python client for the Prometheus monitoring system.
5
+ Home-page: https://github.com/prometheus/client_python
6
+ Author: Brian Brazil
7
+ Author-email: brian.brazil@robustperception.io
8
+ License: Apache Software License 2.0
9
+ Keywords: prometheus monitoring instrumentation client
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: Intended Audience :: Information Technology
13
+ Classifier: Intended Audience :: System Administrators
14
+ Classifier: Programming Language :: Python
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Programming Language :: Python :: 3.8
17
+ Classifier: Programming Language :: Python :: 3.9
18
+ Classifier: Programming Language :: Python :: 3.10
19
+ Classifier: Programming Language :: Python :: 3.11
20
+ Classifier: Programming Language :: Python :: 3.12
21
+ Classifier: Programming Language :: Python :: Implementation :: CPython
22
+ Classifier: Programming Language :: Python :: Implementation :: PyPy
23
+ Classifier: Topic :: System :: Monitoring
24
+ Classifier: License :: OSI Approved :: Apache Software License
25
+ Requires-Python: >=3.8
26
+ Description-Content-Type: text/markdown
27
+ License-File: LICENSE
28
+ License-File: NOTICE
29
+ Provides-Extra: twisted
30
+ Requires-Dist: twisted; extra == "twisted"
31
+
32
+ # Prometheus Python Client
33
+
34
+ The official Python client for [Prometheus](https://prometheus.io).
35
+
36
+ ## Installation
37
+
38
+ ```
39
+ pip install prometheus-client
40
+ ```
41
+
42
+ This package can be found on [PyPI](https://pypi.python.org/pypi/prometheus_client).
43
+
44
+ ## Documentation
45
+
46
+ Documentation is available on https://prometheus.github.io/client_python
47
+
48
+ ## Links
49
+
50
+ * [Releases](https://github.com/prometheus/client_python/releases): The releases page shows the history of the project and acts as a changelog.
51
+ * [PyPI](https://pypi.python.org/pypi/prometheus_client)
.venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/NOTICE ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Prometheus instrumentation library for Python applications
2
+ Copyright 2015 The Prometheus Authors
3
+
4
+ This product bundles decorator 4.0.10 which is available under a "2-clause BSD"
5
+ license. For details, see prometheus_client/decorator.py.
.venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/RECORD ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prometheus_client-0.21.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
2
+ prometheus_client-0.21.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
3
+ prometheus_client-0.21.1.dist-info/METADATA,sha256=r74KhsmW6__tSpz4xH6BX7qsbJfFWYfj24x1elsVtr8,1842
4
+ prometheus_client-0.21.1.dist-info/NOTICE,sha256=TvoYdK6qYPNl9Xl-YX8f-TPhXlCOr3UemEjtRBPXp64,236
5
+ prometheus_client-0.21.1.dist-info/RECORD,,
6
+ prometheus_client-0.21.1.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
7
+ prometheus_client-0.21.1.dist-info/top_level.txt,sha256=AxLEvHEMhTW-Kvb9Ly1DPI3aapigQ2aeg8TXMt9WMRo,18
8
+ prometheus_client/__init__.py,sha256=D-ptlQkWPXqZIJPi5TR0QNMdWr_Ejv-gMq6WAFik_9o,1815
9
+ prometheus_client/__pycache__/__init__.cpython-311.pyc,,
10
+ prometheus_client/__pycache__/asgi.cpython-311.pyc,,
11
+ prometheus_client/__pycache__/context_managers.cpython-311.pyc,,
12
+ prometheus_client/__pycache__/core.cpython-311.pyc,,
13
+ prometheus_client/__pycache__/decorator.cpython-311.pyc,,
14
+ prometheus_client/__pycache__/exposition.cpython-311.pyc,,
15
+ prometheus_client/__pycache__/gc_collector.cpython-311.pyc,,
16
+ prometheus_client/__pycache__/metrics.cpython-311.pyc,,
17
+ prometheus_client/__pycache__/metrics_core.cpython-311.pyc,,
18
+ prometheus_client/__pycache__/mmap_dict.cpython-311.pyc,,
19
+ prometheus_client/__pycache__/multiprocess.cpython-311.pyc,,
20
+ prometheus_client/__pycache__/parser.cpython-311.pyc,,
21
+ prometheus_client/__pycache__/platform_collector.cpython-311.pyc,,
22
+ prometheus_client/__pycache__/process_collector.cpython-311.pyc,,
23
+ prometheus_client/__pycache__/registry.cpython-311.pyc,,
24
+ prometheus_client/__pycache__/samples.cpython-311.pyc,,
25
+ prometheus_client/__pycache__/utils.cpython-311.pyc,,
26
+ prometheus_client/__pycache__/values.cpython-311.pyc,,
27
+ prometheus_client/asgi.py,sha256=ivn-eV7ZU0BEa4E9oWBFbBRUklHPw9f5lcdGsyFuCLo,1606
28
+ prometheus_client/bridge/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
29
+ prometheus_client/bridge/__pycache__/__init__.cpython-311.pyc,,
30
+ prometheus_client/bridge/__pycache__/graphite.cpython-311.pyc,,
31
+ prometheus_client/bridge/graphite.py,sha256=m5-7IyVyGL8C6S9yLxeupS1pfj8KFNPNlazddamQT8s,2897
32
+ prometheus_client/context_managers.py,sha256=E7uksn4D7yBoZWDgjI1VRpR3l2tKivs9DHZ5UAcmPwE,2343
33
+ prometheus_client/core.py,sha256=yyVvSxa8WQnBvAr4JhO3HqdTqClwhbzmVGvwRvWQMIo,860
34
+ prometheus_client/decorator.py,sha256=7MdUokWmzQ17foet2R5QcMubdZ1WDPGYo0_HqLxAw2k,15802
35
+ prometheus_client/exposition.py,sha256=nmushN6NIGo-nOBeaCXfg5bCeyvesVM_DXUWmRjFwr4,26176
36
+ prometheus_client/gc_collector.py,sha256=tBhXXktF9g9h7gvO-DmI2gxPol2_gXI1M6e9ZMazNfY,1514
37
+ prometheus_client/metrics.py,sha256=ypy4Vv0duzCgo4ZXHBNK45uU9hbe7iK-Fohv7EJ_I5A,28109
38
+ prometheus_client/metrics_core.py,sha256=Yz-yqS3pxNdpIRMShQv_IHaKlVS_Q53TaYcP9U8LDlE,15548
39
+ prometheus_client/mmap_dict.py,sha256=-t49kywZHFHk2D9IWtunqKFtr5eEgiN-RjFWg16JE-Q,5393
40
+ prometheus_client/multiprocess.py,sha256=VIvAR0vmjL0lknnTijKt9HS1DNz9rZrS09HqIIcaZLs,7539
41
+ prometheus_client/openmetrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
42
+ prometheus_client/openmetrics/__pycache__/__init__.cpython-311.pyc,,
43
+ prometheus_client/openmetrics/__pycache__/exposition.cpython-311.pyc,,
44
+ prometheus_client/openmetrics/__pycache__/parser.cpython-311.pyc,,
45
+ prometheus_client/openmetrics/exposition.py,sha256=Ef3GeveuojMzOrl-T7cG6Ml2TRN1xIYjpe_puReFrlo,2993
46
+ prometheus_client/openmetrics/parser.py,sha256=c6vQccyW93MXzc22QGdceETg0m_KMeMyEbKrfObG0R8,22125
47
+ prometheus_client/parser.py,sha256=zuVhB8clFPvQ9wOEj1XikN7NoJe8J3pZcQkNgEUkuXg,7434
48
+ prometheus_client/platform_collector.py,sha256=t_GD2oCLN3Pql4TltbNqTap8a4HOtbvBm0OU5_gPn38,1879
49
+ prometheus_client/process_collector.py,sha256=B8y36L1iq0c3KFlvdNj1F5JEQLTec116h6y3m9Jhk90,3864
50
+ prometheus_client/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
51
+ prometheus_client/registry.py,sha256=3R-yxiPitVs36cnIRnotqSJmOPwAQsLz-tl6kw3rcd4,6196
52
+ prometheus_client/samples.py,sha256=smIiOIsAwGXHgM_7xg9Zo5yTEM2gavYvVtgGTjdWMcA,1687
53
+ prometheus_client/twisted/__init__.py,sha256=0RxJjYSOC5p6o2cu6JbfUzc8ReHYQGNv9pKP-U4u7OE,72
54
+ prometheus_client/twisted/__pycache__/__init__.cpython-311.pyc,,
55
+ prometheus_client/twisted/__pycache__/_exposition.cpython-311.pyc,,
56
+ prometheus_client/twisted/_exposition.py,sha256=2TL2BH5sW0i6H7dHkot9aBH9Ld-I60ax55DuaIWnElo,250
57
+ prometheus_client/utils.py,sha256=zKJZaW_hyZgQSmkaD-rgT5l-YsT3--le0BRQ7v_x8eE,594
58
+ prometheus_client/values.py,sha256=hzThQQd0x4mIPR3ddezQpjUoDVdSBnwem4Z48woxpa8,5002
.venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (75.6.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
.venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ prometheus_client
.venv/lib/python3.11/site-packages/torch/_C.cpython-311-x86_64-linux-gnu.so ADDED
Binary file (37.9 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_VF.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This makes the functions in torch._C._VariableFunctions available as
3
+ torch._VF.<funcname>
4
+ without mypy being able to find them.
5
+
6
+ A subset of those functions are mapped to ATen functions in
7
+ torch/jit/_builtins.py
8
+
9
+ See https://github.com/pytorch/pytorch/issues/21478 for the reason for
10
+ introducing torch._VF
11
+
12
+ """
13
+
14
+ import sys
15
+ import types
16
+
17
+ import torch
18
+
19
+
20
+ class VFModule(types.ModuleType):
21
+ vf: types.ModuleType
22
+
23
+ def __init__(self, name: str):
24
+ super().__init__(name)
25
+ self.vf = torch._C._VariableFunctions
26
+
27
+ def __getattr__(self, name: str) -> object:
28
+ return getattr(self.vf, name)
29
+
30
+
31
+ sys.modules[__name__] = VFModule(__name__)
.venv/lib/python3.11/site-packages/torch/_VF.pyi ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/torch/__config__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+
4
+
5
+ def show():
6
+ """
7
+ Return a human-readable string with descriptions of the
8
+ configuration of PyTorch.
9
+ """
10
+ return torch._C._show_config()
11
+
12
+
13
+ # TODO: In principle, we could provide more structured version/config
14
+ # information here. For now only CXX_FLAGS is exposed, as Timer
15
+ # uses them.
16
+ def _cxx_flags():
17
+ """Returns the CXX_FLAGS used when building PyTorch."""
18
+ return torch._C._cxx_flags()
19
+
20
+
21
+ def parallel_info():
22
+ r"""Returns detailed string with parallelization settings"""
23
+ return torch._C._parallel_info()
.venv/lib/python3.11/site-packages/torch/__future__.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _overwrite_module_params_on_conversion: bool = False
2
+ _swap_module_params_on_conversion: bool = False
3
+
4
+
5
+ def set_overwrite_module_params_on_conversion(value: bool) -> None:
6
+ """
7
+ Sets whether to assign new tensors to the parameters instead of changing the
8
+ existing parameters in-place when converting an ``nn.Module``.
9
+
10
+ When enabled, the following methods will assign new parameters to the module:
11
+
12
+ #. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
13
+ #. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
14
+ #. :meth:`nn.Module.to`
15
+ #. :meth:`nn.Module.to_empty`
16
+
17
+ Args:
18
+ value (bool): Whether to assign new tensors or not.
19
+
20
+ """
21
+ global _overwrite_module_params_on_conversion
22
+ _overwrite_module_params_on_conversion = value
23
+
24
+
25
+ def get_overwrite_module_params_on_conversion() -> bool:
26
+ """
27
+ Returns whether to assign new tensors to the parameters instead of changing the
28
+ existing parameters in-place when converting an :class:`torch.nn.Module`. Defaults to ``False``.
29
+
30
+ See :func:`~torch.__future__.set_overwrite_module_params_on_conversion` for more information.
31
+ """
32
+ return _overwrite_module_params_on_conversion
33
+
34
+
35
+ def set_swap_module_params_on_conversion(value: bool) -> None:
36
+ """
37
+ Sets whether to use :func:`~torch.utils.swap_tensors` instead of setting ``.data`` to
38
+ change the existing parameters in-place when converting an ``nn.Module`` and instead
39
+ of ``param.copy_(state_dict[key])`` when loading a state dict into an ``nn.Module``.
40
+
41
+ .. note::
42
+ This function takes precedence over :func:`~torch.__future__.get_overwrite_module_params_on_conversion`
43
+
44
+ When enabled, the following methods will swap the existing parameters in-place:
45
+
46
+ #. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
47
+ #. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
48
+ #. :meth:`nn.Module.to`
49
+ #. :meth:`nn.Module.to_empty`
50
+ #. :meth:`nn.Module.load_state_dict`
51
+
52
+ The semantics for :meth:`~nn.Module.load_state_dict` when this is set are as follows:
53
+
54
+ #. For each parameter/buffer, its corresponding ``state_dict['key']`` is transformed via
55
+ :meth:`~torch.Tensor.module_load` (i.e. ``res = param.module_load(state_dict['key'])``)
56
+ #. If necessary, ``res`` will be wrapped in an :class:`~nn.Parameter`
57
+ #. The parameter/buffer in the module will be swapped via :func:`~torch.utils.swap_tensors`
58
+ with ``res``
59
+
60
+ Args:
61
+ value (bool): Whether to use :func:`~torch.utils.swap_tensors` or not.
62
+
63
+ """
64
+ global _swap_module_params_on_conversion
65
+ _swap_module_params_on_conversion = value
66
+
67
+
68
+ def get_swap_module_params_on_conversion() -> bool:
69
+ """
70
+ Returns whether to use :func:`~torch.utils.swap_tensors` instead of setting .data to
71
+ change the existing parameters in-place when converting an ``nn.Module``. Defaults to ``False``.
72
+
73
+ See :func:`~torch.__future__.set_swap_module_params_on_conversion` for more information.
74
+ """
75
+ return _swap_module_params_on_conversion
.venv/lib/python3.11/site-packages/torch/__init__.py ADDED
@@ -0,0 +1,2665 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The torch package contains data structures for multi-dimensional
3
+ tensors and defines mathematical operations over these tensors.
4
+ Additionally, it provides many utilities for efficient serialization of
5
+ Tensors and arbitrary types, and other useful utilities.
6
+
7
+ It has a CUDA counterpart, that enables you to run your tensor computations
8
+ on an NVIDIA GPU with compute capability >= 3.0.
9
+ """
10
+
11
+ # mypy: allow-untyped-defs
12
+
13
+ import builtins
14
+ import ctypes
15
+ import glob
16
+ import importlib
17
+ import inspect
18
+ import math
19
+ import os
20
+ import platform
21
+ import sys
22
+ import textwrap
23
+ import threading
24
+ from typing import (
25
+ Any as _Any,
26
+ Callable as _Callable,
27
+ Dict as _Dict,
28
+ Optional as _Optional,
29
+ overload as _overload,
30
+ Set as _Set,
31
+ Tuple as _Tuple,
32
+ Type as _Type,
33
+ TYPE_CHECKING,
34
+ TypeVar as _TypeVar,
35
+ Union as _Union,
36
+ )
37
+ from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard
38
+
39
+
40
+ if TYPE_CHECKING:
41
+ from .types import IntLikeType
42
+
43
+
44
+ # multipy/deploy is setting this import before importing torch, this is the most
45
+ # reliable way we have to detect if we're running within deploy.
46
+ # https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137
47
+ def _running_with_deploy() -> builtins.bool:
48
+ return sys.modules.get("torch._meta_registrations", None) is object
49
+
50
+
51
+ from torch._utils import (
52
+ _functionalize_sync as _sync,
53
+ _import_dotted_name,
54
+ classproperty,
55
+ )
56
+ from torch._utils_internal import (
57
+ get_file_path,
58
+ prepare_multiprocessing_environment,
59
+ USE_GLOBAL_DEPS,
60
+ USE_RTLD_GLOBAL_WITH_LIBTORCH,
61
+ )
62
+
63
+
64
+ # TODO(torch_deploy) figure out how to freeze version.py in fbcode build
65
+ if _running_with_deploy():
66
+ __version__ = "torch-deploy-1.8"
67
+ else:
68
+ from torch.torch_version import __version__ as __version__
69
+
70
+ __all__ = [
71
+ "BoolStorage",
72
+ "BoolTensor",
73
+ "ByteStorage",
74
+ "ByteTensor",
75
+ "CharStorage",
76
+ "CharTensor",
77
+ "DoubleStorage",
78
+ "DoubleTensor",
79
+ "FloatStorage",
80
+ "FloatTensor",
81
+ "GradScaler",
82
+ "IntStorage",
83
+ "IntTensor",
84
+ "LongStorage",
85
+ "LongTensor",
86
+ "ShortStorage",
87
+ "ShortTensor",
88
+ "SymBool",
89
+ "SymFloat",
90
+ "SymInt",
91
+ "Tensor",
92
+ "TypedStorage",
93
+ "UntypedStorage",
94
+ "are_deterministic_algorithms_enabled",
95
+ "autocast",
96
+ "chunk",
97
+ "compile",
98
+ "cond",
99
+ "enable_grad",
100
+ "export",
101
+ "get_default_device",
102
+ "get_deterministic_debug_mode",
103
+ "get_device_module",
104
+ "get_float32_matmul_precision",
105
+ "get_rng_state",
106
+ "inference_mode",
107
+ "initial_seed",
108
+ "is_deterministic_algorithms_warn_only_enabled",
109
+ "is_storage",
110
+ "is_tensor",
111
+ "is_warn_always_enabled",
112
+ "load",
113
+ "lobpcg",
114
+ "manual_seed",
115
+ "matmul",
116
+ "no_grad",
117
+ "rand",
118
+ "randn",
119
+ "save",
120
+ "seed",
121
+ "set_default_device",
122
+ "set_default_tensor_type",
123
+ "set_deterministic_debug_mode",
124
+ "set_float32_matmul_precision",
125
+ "set_printoptions",
126
+ "set_rng_state",
127
+ "set_warn_always",
128
+ "split",
129
+ "stack",
130
+ "sym_float",
131
+ "sym_int",
132
+ "sym_ite",
133
+ "sym_max",
134
+ "sym_min",
135
+ "sym_not",
136
+ "typename",
137
+ "unravel_index",
138
+ "use_deterministic_algorithms",
139
+ "vmap",
140
+ ]
141
+
142
+ # Please keep this list sorted
143
+ assert __all__ == sorted(__all__)
144
+
145
+ ################################################################################
146
+ # Load the extension module
147
+ ################################################################################
148
+
149
+ if sys.platform == "win32":
150
+
151
+ def _load_dll_libraries() -> None:
152
+ import sysconfig
153
+
154
+ from torch.version import cuda as cuda_version
155
+
156
+ pfiles_path = os.getenv("ProgramFiles", r"C:\Program Files")
157
+ py_dll_path = os.path.join(sys.exec_prefix, "Library", "bin")
158
+ th_dll_path = os.path.join(os.path.dirname(__file__), "lib")
159
+ usebase_path = os.path.join(
160
+ sysconfig.get_config_var("userbase"), "Library", "bin"
161
+ )
162
+
163
+ # When users create a virtualenv that inherits the base environment,
164
+ # we will need to add the corresponding library directory into
165
+ # DLL search directories. Otherwise, it will rely on `PATH` which
166
+ # is dependent on user settings.
167
+ if sys.exec_prefix != sys.base_exec_prefix:
168
+ base_py_dll_path = os.path.join(sys.base_exec_prefix, "Library", "bin")
169
+ else:
170
+ base_py_dll_path = ""
171
+
172
+ dll_paths = [
173
+ p
174
+ for p in (th_dll_path, py_dll_path, base_py_dll_path, usebase_path)
175
+ if os.path.exists(p)
176
+ ]
177
+
178
+ if not builtins.any(
179
+ os.path.exists(os.path.join(p, "nvToolsExt64_1.dll")) for p in dll_paths
180
+ ):
181
+ nvtoolsext_dll_path = os.path.join(
182
+ os.getenv(
183
+ "NVTOOLSEXT_PATH",
184
+ os.path.join(pfiles_path, "NVIDIA Corporation", "NvToolsExt"),
185
+ ),
186
+ "bin",
187
+ "x64",
188
+ )
189
+ else:
190
+ nvtoolsext_dll_path = ""
191
+
192
+ if cuda_version and builtins.all(
193
+ not glob.glob(os.path.join(p, "cudart64*.dll")) for p in dll_paths
194
+ ):
195
+ cuda_version_1 = cuda_version.replace(".", "_")
196
+ cuda_path_var = "CUDA_PATH_V" + cuda_version_1
197
+ default_path = os.path.join(
198
+ pfiles_path, "NVIDIA GPU Computing Toolkit", "CUDA", f"v{cuda_version}"
199
+ )
200
+ cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), "bin")
201
+ else:
202
+ cuda_path = ""
203
+
204
+ dll_paths.extend(
205
+ p for p in (nvtoolsext_dll_path, cuda_path) if os.path.exists(p)
206
+ )
207
+
208
+ kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
209
+ with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
210
+ prev_error_mode = kernel32.SetErrorMode(0x0001)
211
+
212
+ kernel32.LoadLibraryW.restype = ctypes.c_void_p
213
+ if with_load_library_flags:
214
+ kernel32.LoadLibraryExW.restype = ctypes.c_void_p
215
+
216
+ for dll_path in dll_paths:
217
+ os.add_dll_directory(dll_path)
218
+
219
+ try:
220
+ ctypes.CDLL("vcruntime140.dll")
221
+ ctypes.CDLL("msvcp140.dll")
222
+ ctypes.CDLL("vcruntime140_1.dll")
223
+ except OSError:
224
+ print(
225
+ textwrap.dedent(
226
+ """
227
+ Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure.
228
+ It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe
229
+ """
230
+ ).strip()
231
+ )
232
+
233
+ dlls = glob.glob(os.path.join(th_dll_path, "*.dll"))
234
+ path_patched = False
235
+ for dll in dlls:
236
+ is_loaded = False
237
+ if with_load_library_flags:
238
+ res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
239
+ last_error = ctypes.get_last_error()
240
+ if res is None and last_error != 126:
241
+ err = ctypes.WinError(last_error)
242
+ err.strerror += (
243
+ f' Error loading "{dll}" or one of its dependencies.'
244
+ )
245
+ raise err
246
+ elif res is not None:
247
+ is_loaded = True
248
+ if not is_loaded:
249
+ if not path_patched:
250
+ os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]])
251
+ path_patched = True
252
+ res = kernel32.LoadLibraryW(dll)
253
+ if res is None:
254
+ err = ctypes.WinError(ctypes.get_last_error())
255
+ err.strerror += (
256
+ f' Error loading "{dll}" or one of its dependencies.'
257
+ )
258
+ raise err
259
+
260
+ kernel32.SetErrorMode(prev_error_mode)
261
+
262
+ _load_dll_libraries()
263
+ del _load_dll_libraries
264
+
265
+
266
+ def _preload_cuda_deps(lib_folder: str, lib_name: str) -> None:
267
+ """Preloads cuda deps if they could not be found otherwise."""
268
+ # Should only be called on Linux if default path resolution have failed
269
+ assert platform.system() == "Linux", "Should only be called on Linux"
270
+
271
+ lib_path = None
272
+ for path in sys.path:
273
+ nvidia_path = os.path.join(path, "nvidia")
274
+ if not os.path.exists(nvidia_path):
275
+ continue
276
+ candidate_lib_paths = glob.glob(
277
+ os.path.join(nvidia_path, lib_folder, "lib", lib_name)
278
+ )
279
+ if candidate_lib_paths and not lib_path:
280
+ lib_path = candidate_lib_paths[0]
281
+ if lib_path:
282
+ break
283
+ if not lib_path:
284
+ raise ValueError(f"{lib_name} not found in the system path {sys.path}")
285
+ ctypes.CDLL(lib_path)
286
+
287
+
288
+ # See Note [Global dependencies]
289
+ def _load_global_deps() -> None:
290
+ if _running_with_deploy() or platform.system() == "Windows":
291
+ return
292
+
293
+ # Determine the file extension based on the platform
294
+ lib_ext = ".dylib" if platform.system() == "Darwin" else ".so"
295
+ lib_name = f"libtorch_global_deps{lib_ext}"
296
+ here = os.path.abspath(__file__)
297
+ global_deps_lib_path = os.path.join(os.path.dirname(here), "lib", lib_name)
298
+
299
+ try:
300
+ ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
301
+ except OSError as err:
302
+ # Can only happen for wheel with cuda libs as PYPI deps
303
+ # As PyTorch is not purelib, but nvidia-*-cu12 is
304
+ cuda_libs: _Dict[str, str] = {
305
+ "cublas": "libcublas.so.*[0-9]",
306
+ "cudnn": "libcudnn.so.*[0-9]",
307
+ "cuda_nvrtc": "libnvrtc.so.*[0-9]",
308
+ "cuda_runtime": "libcudart.so.*[0-9]",
309
+ "cuda_cupti": "libcupti.so.*[0-9]",
310
+ "cufft": "libcufft.so.*[0-9]",
311
+ "curand": "libcurand.so.*[0-9]",
312
+ "nvjitlink": "libnvJitLink.so.*[0-9]",
313
+ "cusparse": "libcusparse.so.*[0-9]",
314
+ "cusolver": "libcusolver.so.*[0-9]",
315
+ "nccl": "libnccl.so.*[0-9]",
316
+ "nvtx": "libnvToolsExt.so.*[0-9]",
317
+ }
318
+ is_cuda_lib_err = [
319
+ lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0]
320
+ ]
321
+ if not is_cuda_lib_err:
322
+ raise err
323
+ for lib_folder, lib_name in cuda_libs.items():
324
+ _preload_cuda_deps(lib_folder, lib_name)
325
+ ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
326
+
327
+
328
+ if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and (
329
+ _running_with_deploy() or platform.system() != "Windows"
330
+ ):
331
+ # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a
332
+ # few circumstances:
333
+ #
334
+ # 1. You're in a build environment (e.g., fbcode) where
335
+ # libtorch_global_deps is not available, but you still need
336
+ # to get mkl to link in with RTLD_GLOBAL or it will just
337
+ # not work.
338
+ #
339
+ # 2. You're trying to run PyTorch under UBSAN and you need
340
+ # to ensure that only one copy of libtorch is loaded, so
341
+ # vptr checks work properly
342
+ #
343
+ # If you're using this setting, you must verify that all the libraries
344
+ # you load consistently use the same libstdc++, or you may have
345
+ # mysterious segfaults.
346
+ #
347
+ old_flags = sys.getdlopenflags()
348
+ sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY)
349
+
350
+ from torch._C import * # noqa: F403
351
+
352
+ sys.setdlopenflags(old_flags)
353
+ del old_flags
354
+
355
+ else:
356
+ # Easy way. You want this most of the time, because it will prevent
357
+ # C++ symbols from libtorch clobbering C++ symbols from other
358
+ # libraries, leading to mysterious segfaults.
359
+ #
360
+ # If building in an environment where libtorch_global_deps isn't available
361
+ # like parts of fbsource, but where RTLD_GLOBAL causes segfaults, you will
362
+ # want USE_RTLD_GLOBAL_WITH_LIBTORCH = False and USE_GLOBAL_DEPS = False
363
+ #
364
+ # See Note [Global dependencies]
365
+ if USE_GLOBAL_DEPS:
366
+ _load_global_deps()
367
+ from torch._C import * # noqa: F403
368
+
369
+
370
+ class SymInt:
371
+ """
372
+ Like an int (including magic methods), but redirects all operations on the
373
+ wrapped node. This is used in particular to symbolically record operations
374
+ in the symbolic shape workflow.
375
+ """
376
+
377
+ def __init__(self, node):
378
+ # This field MUST be named node; C++ binding code assumes that this
379
+ # class has a field named node that stores SymNode
380
+ self.node = node
381
+
382
+ def __bool__(self):
383
+ return builtins.bool(self != 0)
384
+
385
+ def __int__(self):
386
+ return self.node.int_()
387
+
388
+ def __index__(self):
389
+ return self.node.int_()
390
+
391
+ # Magic methods installed by torch.fx.experimental.sym_node
392
+
393
+ def __round__(self, ndigits=None):
394
+ return self
395
+
396
+ def __truediv__(self, other):
397
+ if isinstance(other, (builtins.float, SymFloat)):
398
+ return sym_float(self).__float_truediv__(other)
399
+ if not isinstance(other, (builtins.int, SymInt)):
400
+ return NotImplemented
401
+ return self.__int_truediv__(other)
402
+
403
+ def __rtruediv__(self, other):
404
+ if isinstance(other, (builtins.float, SymFloat)):
405
+ return sym_float(self).__rfloat_truediv__(other)
406
+ if not isinstance(other, (builtins.int, SymInt)):
407
+ return NotImplemented
408
+ return self.__rint_truediv__(other)
409
+
410
+ def __floordiv__(self, other):
411
+ if isinstance(other, (builtins.float, SymFloat)):
412
+ return sym_float(math.floor(sym_float(self) / other))
413
+ if not isinstance(other, (builtins.int, SymInt)):
414
+ return NotImplemented
415
+ return self.__int_floordiv__(other)
416
+
417
+ def __rfloordiv__(self, other):
418
+ if isinstance(other, (builtins.float, SymFloat)):
419
+ return sym_float(math.floor(other / sym_float(self)))
420
+ if not isinstance(other, (builtins.int, SymInt)):
421
+ return NotImplemented
422
+ return self.__rint_floordiv__(other)
423
+
424
+ # nb: complex is impossible to handle correctly lol, with
425
+ # negative base and integral float need to diverge semantics and
426
+ # just always return complex. Neener neener pretend this problem
427
+ # doesn't exist
428
+ def __pow__(self, other):
429
+ if isinstance(other, (builtins.float, SymFloat)):
430
+ return sym_float(self).__pow__(other)
431
+ if not isinstance(other, (builtins.int, SymInt)):
432
+ return NotImplemented
433
+ # Guards! This guard is necessary because we need to know it to
434
+ # determine the output type of this operation
435
+ if other >= 0:
436
+ return self.__pow_by_natural__(other)
437
+ else:
438
+ # Mercifully, when the exponent is negative, Python just promotes
439
+ # to doubles and does a float pow:
440
+ #
441
+ # if (Py_SIZE(b) < 0 && c == NULL) {
442
+ # /* if exponent is negative and there's no modulus:
443
+ # return a float. This works because we know
444
+ # that this calls float_pow() which converts its
445
+ # arguments to double. */
446
+ # Py_DECREF(a);
447
+ # Py_DECREF(b);
448
+ # return PyFloat_Type.tp_as_number->nb_power(v, w, x);
449
+ # }
450
+ return sym_float(self).__pow__(sym_float(other))
451
+
452
+ def __rpow__(self, other):
453
+ if isinstance(other, (builtins.float, SymFloat)):
454
+ return sym_float(self).__rpow__(other)
455
+ if not isinstance(other, (builtins.int, SymInt)):
456
+ return NotImplemented
457
+ if self >= 0: # self is exponent
458
+ return self.__rpow_by_natural__(other)
459
+ else:
460
+ return sym_float(self).__rpow__(sym_float(other))
461
+
462
+ def __eq__(self, other: object) -> builtins.bool:
463
+ raise TypeError("type stub not overridden")
464
+
465
+ def __lt__(self, other) -> builtins.bool:
466
+ raise TypeError("type stub not overridden")
467
+
468
+ def __gt__(self, other) -> builtins.bool:
469
+ raise TypeError("type stub not overridden")
470
+
471
+ def __le__(self, other) -> builtins.bool:
472
+ raise TypeError("type stub not overridden")
473
+
474
+ def __ge__(self, other) -> builtins.bool:
475
+ raise TypeError("type stub not overridden")
476
+
477
+ def __add__(self, other) -> "SymInt":
478
+ raise TypeError("type stub not overridden")
479
+
480
+ def __mod__(self, other: "IntLikeType") -> "SymInt":
481
+ raise TypeError("type stub not overridden")
482
+
483
+ def __mul__(self, other) -> "SymInt":
484
+ raise TypeError("type stub not overridden")
485
+
486
+ def __pow_by_natural__(self, other) -> "SymInt":
487
+ raise TypeError("type stub not overridden")
488
+
489
+ def __rpow_by_natural__(self, other) -> "SymInt":
490
+ raise TypeError("type stub not overridden")
491
+
492
+ def __int_truediv__(self, other) -> "SymFloat":
493
+ raise TypeError("type stub not overridden")
494
+
495
+ def __rint_truediv__(self, other) -> "SymFloat":
496
+ raise TypeError("type stub not overridden")
497
+
498
+ def __int_floordiv__(self, other) -> "SymFloat":
499
+ raise TypeError("type stub not overridden")
500
+
501
+ def __rint_floordiv__(self, other) -> "SymFloat":
502
+ raise TypeError("type stub not overridden")
503
+
504
+ def __sym_max__(self, other):
505
+ raise TypeError("type stub not overridden")
506
+
507
+ def __sym_min__(self, other):
508
+ raise TypeError("type stub not overridden")
509
+
510
+ def __sym_float__(self):
511
+ raise TypeError("type stub not overridden")
512
+
513
+ def __neg__(self):
514
+ raise TypeError("type stub not overridden")
515
+
516
+ def __sub__(self, other: "IntLikeType") -> "SymInt":
517
+ raise TypeError("type stub not overridden")
518
+
519
+ def __repr__(self):
520
+ return self.node._graph_repr()
521
+
522
+ def _sympy_(self):
523
+ return self.node.expr
524
+
525
+ def __hash__(self) -> builtins.int:
526
+ if self.node.is_nested_int():
527
+ return hash(self.node.nested_int())
528
+ else:
529
+ # We could support constant SymInts as well, but not doing it for now
530
+ raise TypeError("unhashable type: non-nested SymInt")
531
+ # TODO: Force specialization
532
+ # This can't be done because the TypeError here is load bearing
533
+ # for einops
534
+ # https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
535
+ # return hash(builtins.int(self))
536
+
537
+ def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]:
538
+ """Represent this int as an exact integer ratio"""
539
+ return self, 1
540
+
541
+ def bit_length(self) -> builtins.int:
542
+ # TODO: A more relaxed guard is possible here, where you guard to
543
+ # allow all integer quantities which would result in the same bit
544
+ # length. We can also just make a dedicated Sympy function for
545
+ # computing this quantity and represent it symbolically.
546
+ return builtins.int(self).bit_length()
547
+
548
+ def conjugate(self) -> "SymInt":
549
+ return self
550
+
551
+
552
+ class SymFloat:
553
+ """
554
+ Like an float (including magic methods), but redirects all operations on the
555
+ wrapped node. This is used in particular to symbolically record operations
556
+ in the symbolic shape workflow.
557
+ """
558
+
559
+ def __init__(self, node):
560
+ # This field MUST be named node; C++ binding code assumes that this
561
+ # class has a field named node that stores SymNode
562
+ self.node = node
563
+
564
+ def __truediv__(self, other):
565
+ if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
566
+ return NotImplemented
567
+ return self.__float_truediv__(sym_float(other))
568
+
569
+ def __rtruediv__(self, other):
570
+ if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
571
+ return NotImplemented
572
+ return self.__rfloat_truediv__(sym_float(other))
573
+
574
+ def __floordiv__(self, other):
575
+ if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
576
+ return NotImplemented
577
+ return sym_float(math.floor(self / sym_float(other)))
578
+
579
+ def __rfloordiv__(self, other):
580
+ if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
581
+ return NotImplemented
582
+ return sym_float(math.floor(sym_float(other) / self))
583
+
584
+ def __bool__(self):
585
+ return self.node.bool_()
586
+
587
+ def __float__(self):
588
+ return self.node.guard_float("", 0)
589
+
590
+ # Symbolic power does NOT work with negative base, this is to avoid
591
+ # potential complex outputs
592
+ def __pow__(self, other):
593
+ if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
594
+ return NotImplemented
595
+ torch._check(self >= 0)
596
+ return self.__float_pow__(other)
597
+
598
+ def __rpow__(self, other):
599
+ if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
600
+ return NotImplemented
601
+ torch._check(other >= 0)
602
+ return self.__rfloat_pow__(other)
603
+
604
+ # Magic methods installed by torch.fx.experimental.sym_node
605
+
606
+ def __eq__(self, other: object) -> builtins.bool:
607
+ raise TypeError("type stub not overridden")
608
+
609
+ def __lt__(self, other) -> builtins.bool:
610
+ raise TypeError("type stub not overridden")
611
+
612
+ def __gt__(self, other) -> builtins.bool:
613
+ raise TypeError("type stub not overridden")
614
+
615
+ def __le__(self, other) -> builtins.bool:
616
+ raise TypeError("type stub not overridden")
617
+
618
+ def __ge__(self, other) -> builtins.bool:
619
+ raise TypeError("type stub not overridden")
620
+
621
+ def __float_pow__(self, other) -> "SymFloat":
622
+ raise TypeError("type stub not overridden")
623
+
624
+ def __rfloat_pow__(self, other) -> "SymFloat":
625
+ raise TypeError("type stub not overridden")
626
+
627
+ def __float_truediv__(self, other) -> "SymFloat":
628
+ raise TypeError("type stub not overridden")
629
+
630
+ def __rfloat_truediv__(self, other) -> "SymFloat":
631
+ raise TypeError("type stub not overridden")
632
+
633
+ def __trunc__(self):
634
+ raise TypeError("type stub not overridden")
635
+
636
+ def __sym_max__(self, other):
637
+ raise TypeError("type stub not overridden")
638
+
639
+ def __sym_min__(self, other):
640
+ raise TypeError("type stub not overridden")
641
+
642
+ def __sym_int__(self):
643
+ raise TypeError("type stub not overridden")
644
+
645
+ def is_integer(self):
646
+ """Return True if the float is an integer."""
647
+ raise TypeError("type stub not overridden")
648
+
649
+ def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
650
+ """Represent this float as an exact integer ratio"""
651
+ return builtins.float(self).as_integer_ratio()
652
+
653
+ def __repr__(self):
654
+ return self.node._graph_repr()
655
+
656
+ def _sympy_(self):
657
+ return self.node.expr
658
+
659
+ def __hash__(self):
660
+ return hash(builtins.float(self))
661
+
662
+
663
+ class SymBool:
664
+ """
665
+ Like an bool (including magic methods), but redirects all operations on the
666
+ wrapped node. This is used in particular to symbolically record operations
667
+ in the symbolic shape workflow.
668
+
669
+ Unlike regular bools, regular boolean operators will force extra guards instead
670
+ of symbolically evaluate. Use the bitwise operators instead to handle this.
671
+ """
672
+
673
+ def __init__(self, node):
674
+ # This field MUST be named node; C++ binding code assumes that this
675
+ # class has a field named node that stores SymNode
676
+ self.node = node
677
+
678
+ def __bool__(self):
679
+ return self.node.bool_()
680
+
681
+ def __int__(self):
682
+ return builtins.int(self.node.bool_())
683
+
684
+ # Magic methods installed by torch.fx.experimental.sym_node
685
+ def __and__(self, other) -> "SymBool":
686
+ raise TypeError("type stub not overridden")
687
+
688
+ def __or__(self, other) -> "SymBool":
689
+ raise TypeError("type stub not overridden")
690
+
691
+ # We very carefully define __sym_not__, and not a number of other
692
+ # plausible alternatives:
693
+ #
694
+ # - We do not override __not__ because this is not a real magic
695
+ # method; you cannot override the meaning of the not builtin in
696
+ # Python. We use the name 'sym_not' to clarify that in user code you
697
+ # cannot use the builtin not or operator.not_ or operator.__not__ and
698
+ # hit this magic method; you must use our custom sym_not operator.
699
+ #
700
+ # - We do not override the __invert__ method because SymBool is
701
+ # meant to be usable in situations where bool is expected. However,
702
+ # bitwise negation ~a does the wrong thing with booleans (because
703
+ # bool is a subclass of int, so ~1 = -2 which is not falseish.)
704
+ # This would be a giant footgun, so we get around it by defining
705
+ # our own operator. Note that bitwise and/or do the right thing,
706
+ # so we reuse the conventional operators there for readability.
707
+ #
708
+ def __sym_not__(self) -> "SymBool":
709
+ raise TypeError("type stub not overridden")
710
+
711
+ def __sym_ite__(self, then_val, else_val):
712
+ raise TypeError("type stub not overridden")
713
+
714
+ def __eq__(self, other) -> builtins.bool:
715
+ raise TypeError("type stub not overridden")
716
+
717
+ def __repr__(self):
718
+ return self.node._graph_repr()
719
+
720
+ def _sympy_(self):
721
+ return self.node.expr
722
+
723
+ def __hash__(self):
724
+ if self.node.is_constant():
725
+ return hash(self.node.bool_())
726
+ else:
727
+ # Force specialization
728
+ return hash(builtins.bool(self))
729
+
730
+
731
+ def sym_not(a):
732
+ r"""SymInt-aware utility for logical negation.
733
+
734
+ Args:
735
+ a (SymBool or bool): Object to negate
736
+ """
737
+ import sympy
738
+
739
+ if overrides.has_torch_function_unary(a):
740
+ return overrides.handle_torch_function(sym_not, (a,), a)
741
+ if hasattr(a, "__sym_not__"):
742
+ return a.__sym_not__()
743
+ if isinstance(a, sympy.Basic):
744
+ return ~a # type: ignore[operator]
745
+ return not a
746
+
747
+
748
+ def sym_float(a):
749
+ r"""SymInt-aware utility for float casting.
750
+
751
+ Args:
752
+ a (SymInt, SymFloat, or object): Object to cast
753
+ """
754
+ if overrides.has_torch_function_unary(a):
755
+ return overrides.handle_torch_function(sym_float, (a,), a)
756
+ if isinstance(a, SymFloat):
757
+ return a
758
+ elif hasattr(a, "__sym_float__"):
759
+ return a.__sym_float__()
760
+ return builtins.float(a) # type: ignore[operator]
761
+
762
+
763
+ def sym_int(a):
764
+ r"""SymInt-aware utility for int casting.
765
+
766
+ Args:
767
+ a (SymInt, SymFloat, or object): Object to cast
768
+ """
769
+ if overrides.has_torch_function_unary(a):
770
+ return overrides.handle_torch_function(sym_int, (a,), a)
771
+ if isinstance(a, SymInt):
772
+ return a
773
+ elif isinstance(a, SymFloat):
774
+ return math.trunc(a)
775
+ return builtins.int(a) # type: ignore[operator]
776
+
777
+
778
+ def sym_max(a, b):
779
+ """
780
+ SymInt-aware utility for max which avoids branching on a < b.
781
+ Unlike builtins.max(), this only works for int/float, and it always
782
+ promotes to float if any argument is float (unlike builtins.max, which
783
+ will faithfully preserve the type of the input argument).
784
+ """
785
+ if overrides.has_torch_function((a, b)):
786
+ return overrides.handle_torch_function(sym_max, (a, b), a, b)
787
+ if isinstance(a, (SymInt, SymFloat)):
788
+ return a.__sym_max__(b)
789
+ elif isinstance(b, (SymInt, SymFloat)):
790
+ # Due to promotion semantics, this is operator is commutative:
791
+ # max(1, 1.0) === max(1.0, 1) === 1.0
792
+ return b.__sym_max__(a)
793
+ # TODO: Probably can make bool work too, just lazy
794
+
795
+ all_types, float_types = __all_and_float_types()
796
+
797
+ assert isinstance(a, all_types), type(a)
798
+ assert isinstance(b, all_types), type(b)
799
+ if isinstance(a, float_types) or isinstance(b, float_types):
800
+ return builtins.float(builtins.max(a, b))
801
+ else:
802
+ return builtins.max(a, b)
803
+
804
+
805
+ def __all_and_float_types() -> _Tuple[_Tuple[_Type, ...], _Tuple[_Type, ...]]:
806
+ try:
807
+ import numpy as np
808
+
809
+ all_types: _Tuple[_Type, ...] = (
810
+ np.integer,
811
+ np.floating,
812
+ builtins.int,
813
+ builtins.float,
814
+ )
815
+ float_types: _Tuple[_Type, ...] = (np.floating, builtins.float)
816
+ except ModuleNotFoundError:
817
+ all_types = (builtins.int, builtins.float)
818
+ float_types = (builtins.float,)
819
+
820
+ return all_types, float_types
821
+
822
+
823
+ def sym_min(a, b):
824
+ """SymInt-aware utility for min()."""
825
+ if overrides.has_torch_function((a, b)):
826
+ return overrides.handle_torch_function(sym_min, (a, b), a, b)
827
+ if isinstance(a, (SymInt, SymFloat)):
828
+ return a.__sym_min__(b)
829
+ elif isinstance(b, (SymInt, SymFloat)):
830
+ return b.__sym_min__(a)
831
+
832
+ all_types, float_types = __all_and_float_types()
833
+
834
+ assert isinstance(a, all_types), type(a)
835
+ assert isinstance(b, all_types), type(b)
836
+ if isinstance(a, float_types) or isinstance(b, float_types):
837
+ return builtins.float(builtins.min(a, b))
838
+ else:
839
+ return builtins.min(a, b)
840
+
841
+
842
+ # Drop in replacement for math.sqrt, math.sin, math.cos etc
843
+ def _get_sym_math_fn(name):
844
+ def fn(a):
845
+ if overrides.has_torch_function_unary(a):
846
+ return overrides.handle_torch_function(fn, (a,), a)
847
+ if hasattr(a, f"__sym_{name}__"):
848
+ return getattr(a, f"__sym_{name}__")()
849
+ return getattr(math, name)(a)
850
+
851
+ return fn
852
+
853
+
854
+ __fn, __name, __sym_name = None, "", ""
855
+ for __name in (
856
+ "sqrt",
857
+ "cos",
858
+ "cosh",
859
+ "sin",
860
+ "sinh",
861
+ "tan",
862
+ "tanh",
863
+ "asin",
864
+ "acos",
865
+ "atan",
866
+ ):
867
+ __sym_name = f"_sym_{__name}"
868
+ __fn = _get_sym_math_fn(__name)
869
+ __fn.__qualname__ = __fn.__name__ = __sym_name
870
+ globals()[__sym_name] = __fn
871
+
872
+ del __fn, __name, __sym_name, _get_sym_math_fn
873
+
874
+ # Adding temporary shortcut
875
+ sym_sqrt = globals()["_sym_sqrt"]
876
+ __all__.append("sym_sqrt")
877
+
878
+
879
+ def sym_ite(b, t, f):
880
+ if overrides.has_torch_function((b, t, f)):
881
+ return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f)
882
+ assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f)
883
+ if isinstance(b, SymBool):
884
+ return b.__sym_ite__(t, f)
885
+ return t if b else f
886
+
887
+
888
+ # Check to see if we can load C extensions, and if not provide some guidance
889
+ # on what the problem might be.
890
+ try:
891
+ # _initExtension is chosen (arbitrarily) as a sentinel.
892
+ from torch._C import _initExtension
893
+ except ImportError:
894
+ import torch._C as _C_for_compiled_check
895
+
896
+ # The __file__ check only works for Python 3.7 and above.
897
+ if _C_for_compiled_check.__file__ is None:
898
+ raise ImportError(
899
+ textwrap.dedent(
900
+ """
901
+ Failed to load PyTorch C extensions:
902
+ It appears that PyTorch has loaded the `torch/_C` folder
903
+ of the PyTorch repository rather than the C extensions which
904
+ are expected in the `torch._C` namespace. This can occur when
905
+ using the `install` workflow. e.g.
906
+ $ python setup.py install && python -c "import torch"
907
+
908
+ This error can generally be solved using the `develop` workflow
909
+ $ python setup.py develop && python -c "import torch" # This should succeed
910
+ or by running Python from a different directory.
911
+ """
912
+ ).strip()
913
+ ) from None
914
+ raise # If __file__ is not None the cause is unknown, so just re-raise.
915
+
916
+ # The torch._C submodule is already loaded via `from torch._C import *` above
917
+ # Make an explicit reference to the _C submodule to appease linters
918
+ from torch import _C as _C
919
+
920
+
921
+ __name, __obj = "", None
922
+ for __name in dir(_C):
923
+ if __name[0] != "_" and not __name.endswith("Base"):
924
+ __all__.append(__name)
925
+ __obj = getattr(_C, __name)
926
+ if callable(__obj) or inspect.isclass(__obj):
927
+ if __obj.__module__ != __name__: # "torch"
928
+ # TODO: fix their module from C++ side
929
+ if __name not in {
930
+ "DisableTorchFunctionSubclass",
931
+ "DisableTorchFunction",
932
+ "Generator",
933
+ }:
934
+ __obj.__module__ = __name__ # "torch"
935
+ elif __name == "TensorBase":
936
+ # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
937
+ delattr(sys.modules[__name__], __name)
938
+
939
+ del __name, __obj
940
+
941
+ if not TYPE_CHECKING:
942
+ # issue 38137 and python issue 43367. Submodules of a C extension are
943
+ # non-standard, and attributes of those submodules cannot be pickled since
944
+ # pickle expect to be able to import them as "from _C.sub import attr"
945
+ # which fails with "_C is not a package
946
+ def _import_extension_to_sys_modules(module, memo=None):
947
+ if memo is None:
948
+ memo = set()
949
+ if module in memo:
950
+ return
951
+ memo.add(module)
952
+ module_name = module.__name__
953
+ for name in dir(module):
954
+ member = getattr(module, name)
955
+ member_name = getattr(member, "__name__", "")
956
+ if inspect.ismodule(member) and member_name.startswith(module_name):
957
+ sys.modules.setdefault(member_name, member)
958
+ # Recurse for submodules (e.g., `_C._dynamo.eval_frame`)
959
+ _import_extension_to_sys_modules(member, memo)
960
+
961
+ _import_extension_to_sys_modules(_C)
962
+ del _import_extension_to_sys_modules
963
+
964
+ ################################################################################
965
+ # Define basic utilities
966
+ ################################################################################
967
+
968
+
969
+ def typename(obj: _Any, /) -> str:
970
+ """
971
+ String representation of the type of an object.
972
+
973
+ This function returns a fully qualified string representation of an object's type.
974
+ Args:
975
+ obj (object): The object whose type to represent
976
+ Returns:
977
+ str: the type of the object `o`
978
+ Example:
979
+ >>> x = torch.tensor([1, 2, 3])
980
+ >>> torch.typename(x)
981
+ 'torch.LongTensor'
982
+ >>> torch.typename(torch.nn.Parameter)
983
+ 'torch.nn.parameter.Parameter'
984
+ """
985
+ if isinstance(obj, torch.Tensor):
986
+ return obj.type()
987
+
988
+ module = getattr(obj, "__module__", "") or ""
989
+ qualname = ""
990
+
991
+ if hasattr(obj, "__qualname__"):
992
+ qualname = obj.__qualname__
993
+ elif hasattr(obj, "__name__"):
994
+ qualname = obj.__name__
995
+ else:
996
+ module = obj.__class__.__module__ or ""
997
+ qualname = obj.__class__.__qualname__
998
+
999
+ if module in {"", "builtins"}:
1000
+ return qualname
1001
+ return f"{module}.{qualname}"
1002
+
1003
+
1004
+ def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]:
1005
+ r"""Returns True if `obj` is a PyTorch tensor.
1006
+
1007
+ Note that this function is simply doing ``isinstance(obj, Tensor)``.
1008
+ Using that ``isinstance`` check is better for typechecking with mypy,
1009
+ and more explicit - so it's recommended to use that instead of
1010
+ ``is_tensor``.
1011
+
1012
+ Args:
1013
+ obj (object): Object to test
1014
+ Example::
1015
+
1016
+ >>> x = torch.tensor([1, 2, 3])
1017
+ >>> torch.is_tensor(x)
1018
+ True
1019
+
1020
+ """
1021
+ return isinstance(obj, torch.Tensor)
1022
+
1023
+
1024
+ def is_storage(obj: _Any, /) -> _TypeGuard[_Union["TypedStorage", "UntypedStorage"]]:
1025
+ r"""Returns True if `obj` is a PyTorch storage object.
1026
+
1027
+ Args:
1028
+ obj (Object): Object to test
1029
+ """
1030
+ return type(obj) in _storage_classes
1031
+
1032
+
1033
+ _GLOBAL_DEVICE_CONTEXT = threading.local()
1034
+
1035
+
1036
+ def get_default_device() -> "torch.device":
1037
+ r"""Gets the default ``torch.Tensor`` to be allocated on ``device``"""
1038
+ global _GLOBAL_DEVICE_CONTEXT
1039
+
1040
+ if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
1041
+ device = _GLOBAL_DEVICE_CONTEXT.device_context.device
1042
+ if device.index is not None:
1043
+ return device
1044
+ else:
1045
+ # TODO: Call like get_device_index() method corresponding to
1046
+ # each device type
1047
+ return torch.tensor([]).device
1048
+ else:
1049
+ return torch.device("cpu")
1050
+
1051
+
1052
+ def set_default_device(
1053
+ device: _Optional[_Union["torch.device", str, builtins.int]],
1054
+ ) -> None:
1055
+ """Sets the default ``torch.Tensor`` to be allocated on ``device``. This
1056
+ does not affect factory function calls which are called with an explicit
1057
+ ``device`` argument. Factory calls will be performed as if they
1058
+ were passed ``device`` as an argument.
1059
+
1060
+ To only temporarily change the default device instead of setting it
1061
+ globally, use ``with torch.device(device):`` instead.
1062
+
1063
+ The default device is initially ``cpu``. If you set the default tensor
1064
+ device to another device (e.g., ``cuda``) without a device index, tensors
1065
+ will be allocated on whatever the current device for the device type,
1066
+ even after :func:`torch.cuda.set_device` is called.
1067
+
1068
+ .. warning::
1069
+
1070
+ This function imposes a slight performance cost on every Python
1071
+ call to the torch API (not just factory functions). If this
1072
+ is causing problems for you, please comment on
1073
+ https://github.com/pytorch/pytorch/issues/92701
1074
+
1075
+ .. note::
1076
+
1077
+ This doesn't affect functions that create tensors that share the same memory as the input, like:
1078
+ :func:`torch.from_numpy` and :func:`torch.frombuffer`
1079
+
1080
+ Args:
1081
+ device (device or string): the device to set as default
1082
+
1083
+ Example::
1084
+
1085
+ >>> # xdoctest: +SKIP("requires cuda, changes global state")
1086
+ >>> torch.get_default_device()
1087
+ device(type='cpu')
1088
+ >>> torch.set_default_device('cuda') # current device is 0
1089
+ >>> torch.get_default_device()
1090
+ device(type='cuda', index=0)
1091
+ >>> torch.set_default_device('cuda')
1092
+ >>> torch.cuda.set_device('cuda:1') # current device is 1
1093
+ >>> torch.get_default_device()
1094
+ device(type='cuda', index=1)
1095
+ >>> torch.set_default_device('cuda:1')
1096
+ >>> torch.get_default_device()
1097
+ device(type='cuda', index=1)
1098
+
1099
+ """
1100
+ global _GLOBAL_DEVICE_CONTEXT
1101
+ if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
1102
+ device_context = _GLOBAL_DEVICE_CONTEXT.device_context
1103
+ if device_context is not None:
1104
+ device_context.__exit__(None, None, None)
1105
+
1106
+ if device is None:
1107
+ device_context = None
1108
+ else:
1109
+ from torch.utils._device import DeviceContext
1110
+
1111
+ device_context = DeviceContext(device)
1112
+ device_context.__enter__()
1113
+ _GLOBAL_DEVICE_CONTEXT.device_context = device_context
1114
+
1115
+
1116
+ def set_default_tensor_type(t: _Union[_Type["torch.Tensor"], str], /) -> None:
1117
+ r"""
1118
+ .. warning::
1119
+
1120
+ This function is deprecated as of PyTorch 2.1, please use :func:`torch.set_default_dtype()` and
1121
+ :func:`torch.set_default_device()` as alternatives.
1122
+
1123
+ Sets the default ``torch.Tensor`` type to floating point tensor type
1124
+ ``t``. This type will also be used as default floating point type for
1125
+ type inference in :func:`torch.tensor`.
1126
+
1127
+ The default floating point tensor type is initially ``torch.FloatTensor``.
1128
+
1129
+ Args:
1130
+ t (type or string): the floating point tensor type or its name
1131
+
1132
+ Example::
1133
+
1134
+ >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?")
1135
+ >>> torch.tensor([1.2, 3]).dtype # initial default for floating point is torch.float32
1136
+ torch.float32
1137
+ >>> torch.set_default_tensor_type(torch.DoubleTensor)
1138
+ >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
1139
+ torch.float64
1140
+
1141
+ """
1142
+ if isinstance(t, str):
1143
+ t = _import_dotted_name(t)
1144
+ _C._set_default_tensor_type(t)
1145
+
1146
+
1147
+ def set_default_dtype(d: "torch.dtype", /) -> None:
1148
+ r"""
1149
+
1150
+ Sets the default floating point dtype to :attr:`d`. Supports floating point dtype
1151
+ as inputs. Other dtypes will cause torch to raise an exception.
1152
+
1153
+ When PyTorch is initialized its default floating point dtype is torch.float32,
1154
+ and the intent of set_default_dtype(torch.float64) is to facilitate NumPy-like
1155
+ type inference. The default floating point dtype is used to:
1156
+
1157
+ 1. Implicitly determine the default complex dtype. When the default floating type is float16,
1158
+ the default complex dtype is complex32. For float32, the default complex dtype is complex64.
1159
+ For float64, it is complex128. For bfloat16, an exception will be raised because
1160
+ there is no corresponding complex type for bfloat16.
1161
+ 2. Infer the dtype for tensors constructed using Python floats or complex Python
1162
+ numbers. See examples below.
1163
+ 3. Determine the result of type promotion between bool and integer tensors and
1164
+ Python floats and complex Python numbers.
1165
+
1166
+ Args:
1167
+ d (:class:`torch.dtype`): the floating point dtype to make the default.
1168
+
1169
+ Example:
1170
+ >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?")
1171
+ >>> # initial default for floating point is torch.float32
1172
+ >>> # Python floats are interpreted as float32
1173
+ >>> torch.tensor([1.2, 3]).dtype
1174
+ torch.float32
1175
+ >>> # initial default for floating point is torch.complex64
1176
+ >>> # Complex Python numbers are interpreted as complex64
1177
+ >>> torch.tensor([1.2, 3j]).dtype
1178
+ torch.complex64
1179
+
1180
+ >>> torch.set_default_dtype(torch.float64)
1181
+ >>> # Python floats are now interpreted as float64
1182
+ >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
1183
+ torch.float64
1184
+ >>> # Complex Python numbers are now interpreted as complex128
1185
+ >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor
1186
+ torch.complex128
1187
+
1188
+ >>> torch.set_default_dtype(torch.float16)
1189
+ >>> # Python floats are now interpreted as float16
1190
+ >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
1191
+ torch.float16
1192
+ >>> # Complex Python numbers are now interpreted as complex128
1193
+ >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor
1194
+ torch.complex32
1195
+
1196
+ """
1197
+ _C._set_default_dtype(d)
1198
+
1199
+
1200
+ def use_deterministic_algorithms(
1201
+ mode: builtins.bool,
1202
+ *,
1203
+ warn_only: builtins.bool = False,
1204
+ ) -> None:
1205
+ r"""Sets whether PyTorch operations must use "deterministic"
1206
+ algorithms. That is, algorithms which, given the same input, and when
1207
+ run on the same software and hardware, always produce the same output.
1208
+ When enabled, operations will use deterministic algorithms when available,
1209
+ and if only nondeterministic algorithms are available they will throw a
1210
+ :class:`RuntimeError` when called.
1211
+
1212
+ .. note:: This setting alone is not always enough to make an application
1213
+ reproducible. Refer to :ref:`reproducibility` for more information.
1214
+
1215
+ .. note:: :func:`torch.set_deterministic_debug_mode` offers an alternative
1216
+ interface for this feature.
1217
+
1218
+ The following normally-nondeterministic operations will act
1219
+ deterministically when ``mode=True``:
1220
+
1221
+ * :class:`torch.nn.Conv1d` when called on CUDA tensor
1222
+ * :class:`torch.nn.Conv2d` when called on CUDA tensor
1223
+ * :class:`torch.nn.Conv3d` when called on CUDA tensor
1224
+ * :class:`torch.nn.ConvTranspose1d` when called on CUDA tensor
1225
+ * :class:`torch.nn.ConvTranspose2d` when called on CUDA tensor
1226
+ * :class:`torch.nn.ConvTranspose3d` when called on CUDA tensor
1227
+ * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor
1228
+ * :func:`torch.bmm` when called on sparse-dense CUDA tensors
1229
+ * :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor
1230
+ and the index is a list of tensors
1231
+ * :func:`torch.Tensor.index_put` with ``accumulate=False``
1232
+ * :func:`torch.Tensor.index_put` with ``accumulate=True`` when called on a CPU
1233
+ tensor
1234
+ * :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU
1235
+ tensor
1236
+ * :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor
1237
+ * :func:`torch.gather` when called on a CUDA tensor that requires grad
1238
+ * :func:`torch.index_add` when called on CUDA tensor
1239
+ * :func:`torch.index_select` when attempting to differentiate a CUDA tensor
1240
+ * :func:`torch.repeat_interleave` when attempting to differentiate a CUDA tensor
1241
+ * :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor
1242
+ * :func:`torch.Tensor.scatter` when `src` type is Tensor and called on CUDA tensor
1243
+ * :func:`torch.Tensor.scatter_reduce` when ``reduce='sum'`` or ``reduce='mean'`` and called on CUDA tensor
1244
+
1245
+ The following normally-nondeterministic operations will throw a
1246
+ :class:`RuntimeError` when ``mode=True``:
1247
+
1248
+ * :class:`torch.nn.AvgPool3d` when attempting to differentiate a CUDA tensor
1249
+ * :class:`torch.nn.AdaptiveAvgPool2d` when attempting to differentiate a CUDA tensor
1250
+ * :class:`torch.nn.AdaptiveAvgPool3d` when attempting to differentiate a CUDA tensor
1251
+ * :class:`torch.nn.MaxPool3d` when attempting to differentiate a CUDA tensor
1252
+ * :class:`torch.nn.AdaptiveMaxPool2d` when attempting to differentiate a CUDA tensor
1253
+ * :class:`torch.nn.FractionalMaxPool2d` when attempting to differentiate a CUDA tensor
1254
+ * :class:`torch.nn.FractionalMaxPool3d` when attempting to differentiate a CUDA tensor
1255
+ * :class:`torch.nn.MaxUnpool1d`
1256
+ * :class:`torch.nn.MaxUnpool2d`
1257
+ * :class:`torch.nn.MaxUnpool3d`
1258
+ * :func:`torch.nn.functional.interpolate` when attempting to differentiate a CUDA tensor
1259
+ and one of the following modes is used:
1260
+
1261
+ - ``linear``
1262
+ - ``bilinear``
1263
+ - ``bicubic``
1264
+ - ``trilinear``
1265
+
1266
+ * :class:`torch.nn.ReflectionPad1d` when attempting to differentiate a CUDA tensor
1267
+ * :class:`torch.nn.ReflectionPad2d` when attempting to differentiate a CUDA tensor
1268
+ * :class:`torch.nn.ReflectionPad3d` when attempting to differentiate a CUDA tensor
1269
+ * :class:`torch.nn.ReplicationPad1d` when attempting to differentiate a CUDA tensor
1270
+ * :class:`torch.nn.ReplicationPad3d` when attempting to differentiate a CUDA tensor
1271
+ * :class:`torch.nn.NLLLoss` when called on a CUDA tensor
1272
+ * :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor
1273
+ * :class:`torch.nn.EmbeddingBag` when attempting to differentiate a CUDA tensor when
1274
+ ``mode='max'``
1275
+ * :func:`torch.Tensor.put_` when ``accumulate=False``
1276
+ * :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor
1277
+ * :func:`torch.histc` when called on a CUDA tensor
1278
+ * :func:`torch.bincount` when called on a CUDA tensor and ``weights``
1279
+ tensor is given
1280
+ * :func:`torch.kthvalue` with called on a CUDA tensor
1281
+ * :func:`torch.median` with indices output when called on a CUDA tensor
1282
+ * :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor
1283
+ * :func:`torch.cumsum` when called on a CUDA tensor when dtype is floating point or complex
1284
+ * :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor
1285
+ * :func:`torch.Tensor.resize_` when called with a quantized tensor
1286
+
1287
+ In addition, several operations fill uninitialized memory when this setting
1288
+ is turned on and when
1289
+ :attr:`torch.utils.deterministic.fill_uninitialized_memory` is turned on.
1290
+ See the documentation for that attribute for more information.
1291
+
1292
+ A handful of CUDA operations are nondeterministic if the CUDA version is
1293
+ 10.2 or greater, unless the environment variable ``CUBLAS_WORKSPACE_CONFIG=:4096:8``
1294
+ or ``CUBLAS_WORKSPACE_CONFIG=:16:8`` is set. See the CUDA documentation for more
1295
+ details: `<https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility>`_
1296
+ If one of these environment variable configurations is not set, a :class:`RuntimeError`
1297
+ will be raised from these operations when called with CUDA tensors:
1298
+
1299
+ * :func:`torch.mm`
1300
+ * :func:`torch.mv`
1301
+ * :func:`torch.bmm`
1302
+
1303
+ Note that deterministic operations tend to have worse performance than
1304
+ nondeterministic operations.
1305
+
1306
+ .. note::
1307
+
1308
+ This flag does not detect or prevent nondeterministic behavior caused
1309
+ by calling an inplace operation on a tensor with an internal memory
1310
+ overlap or by giving such a tensor as the :attr:`out` argument for an
1311
+ operation. In these cases, multiple writes of different data may target
1312
+ a single memory location, and the order of writes is not guaranteed.
1313
+
1314
+ Args:
1315
+ mode (:class:`bool`): If True, makes potentially nondeterministic
1316
+ operations switch to a deterministic algorithm or throw a runtime
1317
+ error. If False, allows nondeterministic operations.
1318
+
1319
+ Keyword args:
1320
+ warn_only (:class:`bool`, optional): If True, operations that do not
1321
+ have a deterministic implementation will throw a warning instead of
1322
+ an error. Default: ``False``
1323
+
1324
+ Example::
1325
+
1326
+ >>> # xdoctest: +SKIP
1327
+ >>> torch.use_deterministic_algorithms(True)
1328
+
1329
+ # Forward mode nondeterministic error
1330
+ >>> torch.randn(10, device='cuda').kthvalue(1)
1331
+ ...
1332
+ RuntimeError: kthvalue CUDA does not have a deterministic implementation...
1333
+
1334
+ # Backward mode nondeterministic error
1335
+ >>> torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward()
1336
+ ...
1337
+ RuntimeError: avg_pool3d_backward_cuda does not have a deterministic implementation...
1338
+ """
1339
+ _C._set_deterministic_algorithms(mode, warn_only=warn_only)
1340
+
1341
+
1342
+ def are_deterministic_algorithms_enabled() -> builtins.bool:
1343
+ r"""Returns True if the global deterministic flag is turned on. Refer to
1344
+ :func:`torch.use_deterministic_algorithms` documentation for more details.
1345
+ """
1346
+ return _C._get_deterministic_algorithms()
1347
+
1348
+
1349
+ def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool:
1350
+ r"""Returns True if the global deterministic flag is set to warn only.
1351
+ Refer to :func:`torch.use_deterministic_algorithms` documentation for more
1352
+ details.
1353
+ """
1354
+ return _C._get_deterministic_algorithms_warn_only()
1355
+
1356
+
1357
+ def set_deterministic_debug_mode(debug_mode: _Union[builtins.int, str]) -> None:
1358
+ r"""Sets the debug mode for deterministic operations.
1359
+
1360
+ .. note:: This is an alternative interface for
1361
+ :func:`torch.use_deterministic_algorithms`. Refer to that function's
1362
+ documentation for details about affected operations.
1363
+
1364
+ Args:
1365
+ debug_mode(str or int): If "default" or 0, don't error or warn on
1366
+ nondeterministic operations. If "warn" or 1, warn on
1367
+ nondeterministic operations. If "error" or 2, error on
1368
+ nondeterministic operations.
1369
+ """
1370
+
1371
+ # NOTE: builtins.int is used here because int in this scope resolves
1372
+ # to torch.int
1373
+ if not isinstance(debug_mode, (builtins.int, str)):
1374
+ raise TypeError(f"debug_mode must be str or int, but got {type(debug_mode)}")
1375
+
1376
+ if isinstance(debug_mode, str):
1377
+ if debug_mode == "default":
1378
+ debug_mode = 0
1379
+ elif debug_mode == "warn":
1380
+ debug_mode = 1
1381
+ elif debug_mode == "error":
1382
+ debug_mode = 2
1383
+ else:
1384
+ raise RuntimeError(
1385
+ "invalid value of debug_mode, expected one of `default`, "
1386
+ f"`warn`, `error`, but got {debug_mode}"
1387
+ )
1388
+
1389
+ if debug_mode == 0:
1390
+ _C._set_deterministic_algorithms(False)
1391
+ elif debug_mode == 1:
1392
+ _C._set_deterministic_algorithms(True, warn_only=True)
1393
+ elif debug_mode == 2:
1394
+ _C._set_deterministic_algorithms(True)
1395
+ else:
1396
+ raise RuntimeError(
1397
+ "invalid value of debug_mode, expected 0, 1, or 2, " f"but got {debug_mode}"
1398
+ )
1399
+
1400
+
1401
+ def get_deterministic_debug_mode() -> builtins.int:
1402
+ r"""Returns the current value of the debug mode for deterministic
1403
+ operations. Refer to :func:`torch.set_deterministic_debug_mode`
1404
+ documentation for more details.
1405
+ """
1406
+
1407
+ if _C._get_deterministic_algorithms():
1408
+ if _C._get_deterministic_algorithms_warn_only():
1409
+ return 1
1410
+ else:
1411
+ return 2
1412
+ else:
1413
+ return 0
1414
+
1415
+
1416
+ def get_float32_matmul_precision() -> str:
1417
+ r"""Returns the current value of float32 matrix multiplication precision. Refer to
1418
+ :func:`torch.set_float32_matmul_precision` documentation for more details.
1419
+ """
1420
+ return _C._get_float32_matmul_precision()
1421
+
1422
+
1423
+ def set_float32_matmul_precision(precision: str) -> None:
1424
+ r"""Sets the internal precision of float32 matrix multiplications.
1425
+
1426
+ Running float32 matrix multiplications in lower precision may significantly increase
1427
+ performance, and in some programs the loss of precision has a negligible impact.
1428
+
1429
+ Supports three settings:
1430
+
1431
+ * "highest", float32 matrix multiplications use the float32 datatype (24 mantissa
1432
+ bits with 23 bits explicitly stored) for internal computations.
1433
+ * "high", float32 matrix multiplications either use the TensorFloat32 datatype (10
1434
+ mantissa bits explicitly stored) or treat each float32 number as the sum of two bfloat16 numbers
1435
+ (approximately 16 mantissa bits with 14 bits explicitly stored), if the appropriate fast matrix multiplication
1436
+ algorithms are available. Otherwise float32 matrix multiplications are computed
1437
+ as if the precision is "highest". See below for more information on the bfloat16
1438
+ approach.
1439
+ * "medium", float32 matrix multiplications use the bfloat16 datatype (8 mantissa
1440
+ bits with 7 bits explicitly stored) for internal computations, if a fast matrix multiplication algorithm
1441
+ using that datatype internally is available. Otherwise float32
1442
+ matrix multiplications are computed as if the precision is "high".
1443
+
1444
+ When using "high" precision, float32 multiplications may use a bfloat16-based algorithm
1445
+ that is more complicated than simply truncating to some smaller number mantissa bits
1446
+ (e.g. 10 for TensorFloat32, 7 for bfloat16 explicitly stored). Refer to [Henry2019]_ for a complete
1447
+ description of this algorithm. To briefly explain here, the first step is to realize
1448
+ that we can perfectly encode a single float32 number as the sum of three bfloat16
1449
+ numbers (because float32 has 23 mantissa bits while bfloat16 has 7 explicitly stored, and both have the
1450
+ same number of exponent bits). This means that the product of two float32 numbers can
1451
+ be exactly given by the sum of nine products of bfloat16 numbers. We can then trade
1452
+ accuracy for speed by dropping some of these products. The "high" precision algorithm
1453
+ specifically keeps only the three most significant products, which conveniently excludes
1454
+ all of the products involving the last 8 mantissa bits of either input. This means that
1455
+ we can represent our inputs as the sum of two bfloat16 numbers rather than three.
1456
+ Because bfloat16 fused-multiply-add (FMA) instructions are typically >10x faster than
1457
+ float32 ones, it's faster to do three multiplications and 2 additions with bfloat16
1458
+ precision than it is to do a single multiplication with float32 precision.
1459
+
1460
+ .. [Henry2019] http://arxiv.org/abs/1904.06376
1461
+
1462
+ .. note::
1463
+
1464
+ This does not change the output dtype of float32 matrix multiplications,
1465
+ it controls how the internal computation of the matrix multiplication is performed.
1466
+
1467
+ .. note::
1468
+
1469
+ This does not change the precision of convolution operations. Other flags,
1470
+ like `torch.backends.cudnn.allow_tf32`, may control the precision of convolution
1471
+ operations.
1472
+
1473
+ .. note::
1474
+
1475
+ This flag currently only affects one native device type: CUDA.
1476
+ If "high" or "medium" are set then the TensorFloat32 datatype will be used
1477
+ when computing float32 matrix multiplications, equivalent to setting
1478
+ `torch.backends.cuda.matmul.allow_tf32 = True`. When "highest" (the default)
1479
+ is set then the float32 datatype is used for internal computations, equivalent
1480
+ to setting `torch.backends.cuda.matmul.allow_tf32 = False`.
1481
+
1482
+ Args:
1483
+ precision(str): can be set to "highest" (default), "high", or "medium" (see above).
1484
+
1485
+ """
1486
+ _C._set_float32_matmul_precision(precision)
1487
+
1488
+
1489
+ def set_warn_always(b: builtins.bool, /) -> None:
1490
+ r"""When this flag is False (default) then some PyTorch warnings may only
1491
+ appear once per process. This helps avoid excessive warning information.
1492
+ Setting it to True causes these warnings to always appear, which may be
1493
+ helpful when debugging.
1494
+
1495
+ Args:
1496
+ b (:class:`bool`): If True, force warnings to always be emitted
1497
+ If False, set to the default behaviour
1498
+ """
1499
+ _C._set_warnAlways(b)
1500
+
1501
+
1502
+ def is_warn_always_enabled() -> builtins.bool:
1503
+ r"""Returns True if the global warn_always flag is turned on. Refer to
1504
+ :func:`torch.set_warn_always` documentation for more details.
1505
+ """
1506
+ return _C._get_warnAlways()
1507
+
1508
+
1509
+ ################################################################################
1510
+ # Define error checking functions
1511
+ ################################################################################
1512
+
1513
+ # These error checking functions must be kept consistent with their C++
1514
+ # equivalents. Their C++ equivalents are mentioned where applicable.
1515
+
1516
+
1517
+ def _check_with(
1518
+ error_type,
1519
+ cond: _Union[builtins.bool, SymBool],
1520
+ message: _Callable[[], str],
1521
+ ): # noqa: F811
1522
+ if not isinstance(cond, (builtins.bool, SymBool)):
1523
+ raise TypeError(f"cond must be a bool, but got {type(cond)}")
1524
+
1525
+ from torch.fx.experimental.symbolic_shapes import expect_true
1526
+
1527
+ if expect_true(cond):
1528
+ return
1529
+
1530
+ # error_type must be a subclass of Exception and not subclass of Warning
1531
+ assert issubclass(error_type, Exception) and not issubclass(error_type, Warning)
1532
+
1533
+ if message is None:
1534
+ message_evaluated = (
1535
+ "Expected cond to be True, but got False. (Could this error "
1536
+ "message be improved? If so, please report an enhancement request "
1537
+ "to PyTorch.)"
1538
+ )
1539
+
1540
+ else:
1541
+ if not callable(message):
1542
+ raise TypeError("message must be a callable")
1543
+
1544
+ message_evaluated = str(message())
1545
+
1546
+ raise error_type(message_evaluated)
1547
+
1548
+
1549
+ def _check(cond, message=None): # noqa: F811
1550
+ r"""Throws error containing an optional message if the specified condition
1551
+ is False.
1552
+
1553
+ Error type: ``RuntimeError``
1554
+
1555
+ C++ equivalent: ``TORCH_CHECK``
1556
+
1557
+ Args:
1558
+ cond (:class:`bool`): If False, throw error
1559
+
1560
+ message (Callable, optional): Callable that returns either a string or
1561
+ an object that has a ``__str__()`` method to be used as the error
1562
+ message. Default: ``None``
1563
+ """
1564
+ _check_with(RuntimeError, cond, message)
1565
+
1566
+
1567
+ def _check_is_size(i, message=None):
1568
+ """Checks that a given integer is a valid size (i.e., is non-negative).
1569
+ You should use this over _check(i >= 0) because we can use the semantic
1570
+ information (that i is a size) to make some further inferences in case
1571
+ i is an unbacked SymInt.
1572
+
1573
+ NB: Do NOT use this in contexts where a -1 size would be valid (indicating
1574
+ to infer the size from context, or if you should wrap-around or truncate).
1575
+ Only use this if the only valid value is an honest to goodness size.
1576
+ """
1577
+ # This is responsible for the expect_true
1578
+ _check(i >= 0, message)
1579
+ from torch.fx.experimental.symbolic_shapes import _advise_is_size
1580
+
1581
+ _advise_is_size(i)
1582
+
1583
+
1584
+ def _check_index(cond, message=None): # noqa: F811
1585
+ r"""Throws error containing an optional message if the specified condition
1586
+ is False.
1587
+
1588
+ Error type: ``IndexError``
1589
+
1590
+ C++ equivalent: ``TORCH_CHECK_INDEX``
1591
+
1592
+ Args:
1593
+ cond (:class:`bool`): If False, throw error
1594
+
1595
+ message (Callable, optional): Callable that returns either a string or
1596
+ an object that has a ``__str__()`` method to be used as the error
1597
+ message. Default: ``None``
1598
+ """
1599
+ _check_with(IndexError, cond, message)
1600
+
1601
+
1602
+ def _check_value(cond, message=None): # noqa: F811
1603
+ r"""Throws error containing an optional message if the specified condition
1604
+ is False.
1605
+
1606
+ Error type: ``ValueError``
1607
+
1608
+ C++ equivalent: ``TORCH_CHECK_VALUE``
1609
+
1610
+ Args:
1611
+ cond (:class:`bool`): If False, throw error
1612
+
1613
+ message (Callable, optional): Callable that returns either a string or
1614
+ an object that has a ``__str__()`` method to be used as the error
1615
+ message. Default: ``None``
1616
+ """
1617
+ _check_with(ValueError, cond, message)
1618
+
1619
+
1620
+ def _check_type(cond, message=None): # noqa: F811
1621
+ r"""Throws error containing an optional message if the specified condition
1622
+ is False.
1623
+
1624
+ Error type: ``TypeError``
1625
+
1626
+ C++ equivalent: ``TORCH_CHECK_TYPE``
1627
+
1628
+ Args:
1629
+ cond (:class:`bool`): If False, throw error
1630
+
1631
+ message (Callable, optional): Callable that returns either a string or
1632
+ an object that has a ``__str__()`` method to be used as the error
1633
+ message. Default: ``None``
1634
+ """
1635
+ _check_with(TypeError, cond, message)
1636
+
1637
+
1638
+ def _check_not_implemented(cond, message=None): # noqa: F811
1639
+ r"""Throws error containing an optional message if the specified condition
1640
+ is False.
1641
+
1642
+ Error type: ``NotImplementedError``
1643
+
1644
+ C++ equivalent: ``TORCH_CHECK_NOT_IMPLEMENTED``
1645
+
1646
+ Args:
1647
+ cond (:class:`bool`): If False, throw error
1648
+
1649
+ message (Callable, optional): Callable that returns either a string or
1650
+ an object that has a ``__str__()`` method to be used as the error
1651
+ message. Default: ``None``
1652
+ """
1653
+ _check_with(NotImplementedError, cond, message)
1654
+
1655
+
1656
+ def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811
1657
+ if not is_tensor(cond):
1658
+ raise TypeError(f"cond must be a tensor, but got {type(cond)}")
1659
+
1660
+ if not cond.dtype == torch.bool:
1661
+ raise TypeError(f"cond tensor must have dtype torch.bool, but got {cond.dtype}")
1662
+
1663
+ _check_with(error_type, cond._is_all_true().item(), message) # type: ignore[arg-type]
1664
+
1665
+
1666
+ # C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
1667
+ def _check_tensor_all(cond, message=None): # noqa: F811
1668
+ r"""Throws error containing an optional message if the specified condition
1669
+ is False.
1670
+
1671
+ Error type: ``RuntimeError``
1672
+
1673
+ C++ equivalent: ``TORCH_CHECK_TENSOR_ALL``
1674
+
1675
+ Args:
1676
+ cond (:class:`torch.Tensor`): Tensor of dtype ``torch.bool``. If any
1677
+ element is ``False``, throw error
1678
+
1679
+ message (Callable, optional): Callable that returns either a string or
1680
+ an object that has a ``__str__()`` method to be used as the error
1681
+ message. Default: ``None``
1682
+ """
1683
+ _check_tensor_all_with(RuntimeError, cond, message)
1684
+
1685
+
1686
+ ################################################################################
1687
+ # Define numeric constants
1688
+ ################################################################################
1689
+
1690
+ # For Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and
1691
+ # NumPy consistency (https://numpy.org/devdocs/reference/constants.html)
1692
+ from math import e, inf, nan, pi
1693
+
1694
+
1695
+ newaxis: None = None
1696
+
1697
+ __all__.extend(["e", "pi", "nan", "inf", "newaxis"])
1698
+
1699
+ ################################################################################
1700
+ # Define Storage and Tensor classes
1701
+ ################################################################################
1702
+
1703
+ from torch._tensor import Tensor # usort: skip
1704
+
1705
+ # needs to be after torch.Tensor is defined to avoid circular dependencies
1706
+ from torch import storage as storage # usort: skip
1707
+ from torch.storage import (
1708
+ _LegacyStorage,
1709
+ _StorageBase,
1710
+ _warn_typed_storage_removal,
1711
+ TypedStorage,
1712
+ UntypedStorage,
1713
+ )
1714
+
1715
+
1716
+ # NOTE: New <type>Storage classes should never be added. When adding a new
1717
+ # dtype, use torch.storage.TypedStorage directly.
1718
+ class ByteStorage(_LegacyStorage):
1719
+ @classproperty
1720
+ def dtype(self):
1721
+ _warn_typed_storage_removal(stacklevel=3)
1722
+ return self._dtype
1723
+
1724
+ @classproperty
1725
+ def _dtype(self):
1726
+ return torch.uint8
1727
+
1728
+
1729
+ class DoubleStorage(_LegacyStorage):
1730
+ @classproperty
1731
+ def dtype(self):
1732
+ _warn_typed_storage_removal(stacklevel=3)
1733
+ return self._dtype
1734
+
1735
+ @classproperty
1736
+ def _dtype(self):
1737
+ return torch.double
1738
+
1739
+
1740
+ class FloatStorage(_LegacyStorage):
1741
+ @classproperty
1742
+ def dtype(self):
1743
+ _warn_typed_storage_removal(stacklevel=3)
1744
+ return self._dtype
1745
+
1746
+ @classproperty
1747
+ def _dtype(self):
1748
+ return torch.float
1749
+
1750
+
1751
+ class HalfStorage(_LegacyStorage):
1752
+ @classproperty
1753
+ def dtype(self):
1754
+ _warn_typed_storage_removal(stacklevel=3)
1755
+ return self._dtype
1756
+
1757
+ @classproperty
1758
+ def _dtype(self):
1759
+ return torch.half
1760
+
1761
+
1762
+ class LongStorage(_LegacyStorage):
1763
+ @classproperty
1764
+ def dtype(self):
1765
+ _warn_typed_storage_removal(stacklevel=3)
1766
+ return self._dtype
1767
+
1768
+ @classproperty
1769
+ def _dtype(self):
1770
+ return torch.long
1771
+
1772
+
1773
+ class IntStorage(_LegacyStorage):
1774
+ @classproperty
1775
+ def dtype(self):
1776
+ _warn_typed_storage_removal(stacklevel=3)
1777
+ return self._dtype
1778
+
1779
+ @classproperty
1780
+ def _dtype(self):
1781
+ return torch.int
1782
+
1783
+
1784
+ class ShortStorage(_LegacyStorage):
1785
+ @classproperty
1786
+ def dtype(self):
1787
+ _warn_typed_storage_removal(stacklevel=3)
1788
+ return self._dtype
1789
+
1790
+ @classproperty
1791
+ def _dtype(self):
1792
+ return torch.short
1793
+
1794
+
1795
+ class CharStorage(_LegacyStorage):
1796
+ @classproperty
1797
+ def dtype(self):
1798
+ _warn_typed_storage_removal(stacklevel=3)
1799
+ return self._dtype
1800
+
1801
+ @classproperty
1802
+ def _dtype(self):
1803
+ return torch.int8
1804
+
1805
+
1806
+ class BoolStorage(_LegacyStorage):
1807
+ @classproperty
1808
+ def dtype(self):
1809
+ _warn_typed_storage_removal(stacklevel=3)
1810
+ return self._dtype
1811
+
1812
+ @classproperty
1813
+ def _dtype(self):
1814
+ return torch.bool
1815
+
1816
+
1817
+ class BFloat16Storage(_LegacyStorage):
1818
+ @classproperty
1819
+ def dtype(self):
1820
+ _warn_typed_storage_removal(stacklevel=3)
1821
+ return self._dtype
1822
+
1823
+ @classproperty
1824
+ def _dtype(self):
1825
+ return torch.bfloat16
1826
+
1827
+
1828
+ class ComplexDoubleStorage(_LegacyStorage):
1829
+ @classproperty
1830
+ def dtype(self):
1831
+ _warn_typed_storage_removal(stacklevel=3)
1832
+ return self._dtype
1833
+
1834
+ @classproperty
1835
+ def _dtype(self):
1836
+ return torch.cdouble
1837
+
1838
+
1839
+ class ComplexFloatStorage(_LegacyStorage):
1840
+ @classproperty
1841
+ def dtype(self):
1842
+ _warn_typed_storage_removal(stacklevel=3)
1843
+ return self._dtype
1844
+
1845
+ @classproperty
1846
+ def _dtype(self):
1847
+ return torch.cfloat
1848
+
1849
+
1850
+ class QUInt8Storage(_LegacyStorage):
1851
+ @classproperty
1852
+ def dtype(self):
1853
+ _warn_typed_storage_removal(stacklevel=3)
1854
+ return self._dtype
1855
+
1856
+ @classproperty
1857
+ def _dtype(self):
1858
+ return torch.quint8
1859
+
1860
+
1861
+ class QInt8Storage(_LegacyStorage):
1862
+ @classproperty
1863
+ def dtype(self):
1864
+ _warn_typed_storage_removal(stacklevel=3)
1865
+ return self._dtype
1866
+
1867
+ @classproperty
1868
+ def _dtype(self):
1869
+ return torch.qint8
1870
+
1871
+
1872
+ class QInt32Storage(_LegacyStorage):
1873
+ @classproperty
1874
+ def dtype(self):
1875
+ _warn_typed_storage_removal(stacklevel=3)
1876
+ return self._dtype
1877
+
1878
+ @classproperty
1879
+ def _dtype(self):
1880
+ return torch.qint32
1881
+
1882
+
1883
+ class QUInt4x2Storage(_LegacyStorage):
1884
+ @classproperty
1885
+ def dtype(self):
1886
+ _warn_typed_storage_removal(stacklevel=3)
1887
+ return self._dtype
1888
+
1889
+ @classproperty
1890
+ def _dtype(self):
1891
+ return torch.quint4x2
1892
+
1893
+
1894
+ class QUInt2x4Storage(_LegacyStorage):
1895
+ @classproperty
1896
+ def dtype(self):
1897
+ _warn_typed_storage_removal(stacklevel=3)
1898
+ return self._dtype
1899
+
1900
+ @classproperty
1901
+ def _dtype(self):
1902
+ return torch.quint2x4
1903
+
1904
+
1905
+ _storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
1906
+ UntypedStorage,
1907
+ DoubleStorage,
1908
+ FloatStorage,
1909
+ LongStorage,
1910
+ IntStorage,
1911
+ ShortStorage,
1912
+ CharStorage,
1913
+ ByteStorage,
1914
+ HalfStorage,
1915
+ BoolStorage,
1916
+ QUInt8Storage,
1917
+ QInt8Storage,
1918
+ QInt32Storage,
1919
+ BFloat16Storage,
1920
+ ComplexFloatStorage,
1921
+ ComplexDoubleStorage,
1922
+ QUInt4x2Storage,
1923
+ QUInt2x4Storage,
1924
+ TypedStorage,
1925
+ }
1926
+
1927
+ # The _tensor_classes set is initialized by the call to initialize_python_bindings.
1928
+ _tensor_classes: _Set[_Type["torch.Tensor"]] = set()
1929
+
1930
+ # If you edit these imports, please update torch/__init__.py.in as well
1931
+ from torch import amp as amp, random as random, serialization as serialization
1932
+ from torch._tensor_str import set_printoptions
1933
+ from torch.amp import autocast, GradScaler
1934
+ from torch.random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state
1935
+ from torch.serialization import load, save
1936
+
1937
+
1938
+ ################################################################################
1939
+ # Initialize extension
1940
+ ################################################################################
1941
+
1942
+
1943
+ # Shared memory manager needs to know the exact location of manager executable
1944
+ def _manager_path():
1945
+ if _running_with_deploy() or platform.system() == "Windows":
1946
+ return b""
1947
+ path = get_file_path("torch", "bin", "torch_shm_manager")
1948
+ prepare_multiprocessing_environment(get_file_path("torch"))
1949
+ if not os.path.exists(path):
1950
+ raise RuntimeError("Unable to find torch_shm_manager at " + path)
1951
+ return path.encode("utf-8")
1952
+
1953
+
1954
+ _C._initExtension(_manager_path())
1955
+
1956
+ del _manager_path
1957
+
1958
+ # Appease the type checker: it can't deal with direct setting of globals().
1959
+ # Note that we will see "too many" functions when reexporting this way; there
1960
+ # is not a good way to fix this problem. Perhaps, try to redesign VariableFunctions
1961
+ # so that this import is good enough
1962
+ if TYPE_CHECKING:
1963
+ # Some type signatures pulled in from _VariableFunctions here clash with
1964
+ # signatures already imported. For now these clashes are ignored; see
1965
+ # PR #43339 for details.
1966
+ from torch._C._VariableFunctions import * # type: ignore[assignment, misc] # noqa: F403
1967
+
1968
+ # Fixup segment_reduce visibility
1969
+ _segment_reduce = segment_reduce
1970
+ del segment_reduce # noqa: F821
1971
+
1972
+ # Ops not to be exposed in `torch` namespace,
1973
+ # mostly helper ops.
1974
+ PRIVATE_OPS = ("unique_dim",)
1975
+
1976
+ __name, __obj = "", None
1977
+ for __name in dir(_C._VariableFunctions):
1978
+ if __name.startswith("__") or __name in PRIVATE_OPS:
1979
+ continue
1980
+ __obj = getattr(_C._VariableFunctions, __name)
1981
+ __obj.__module__ = __name__ # "torch"
1982
+ # Hide some APIs that should not be public
1983
+ if __name == "segment_reduce":
1984
+ # TODO: Once the undocumented FC window is passed, remove the line bellow
1985
+ globals()[__name] = __obj
1986
+ __name = "_" + __name
1987
+ globals()[__name] = __obj
1988
+ if not __name.startswith("_"):
1989
+ __all__.append(__name)
1990
+
1991
+ del __name, __obj
1992
+
1993
+ ################################################################################
1994
+ # Add torch.dtype instances to the public API
1995
+ ################################################################################
1996
+
1997
+ import torch
1998
+
1999
+
2000
+ __all__.extend(
2001
+ name for name in dir(torch) if isinstance(getattr(torch, name), torch.dtype)
2002
+ )
2003
+
2004
+ ################################################################################
2005
+ # Import TorchDynamo's lazy APIs to avoid circular dependenices
2006
+ ################################################################################
2007
+
2008
+ # needs to be before from torch.functional import * to avoid circular dependencies
2009
+ from torch._compile import _disable_dynamo # usort: skip
2010
+
2011
+ ################################################################################
2012
+ # Import interface functions defined in Python
2013
+ ################################################################################
2014
+
2015
+ # needs to be after the above ATen bindings so we can overwrite from Python side
2016
+ from torch import _VF as _VF, functional as functional # usort: skip
2017
+ from torch.functional import * # usort: skip # noqa: F403
2018
+
2019
+ ################################################################################
2020
+ # Remove unnecessary members
2021
+ ################################################################################
2022
+
2023
+ del _StorageBase
2024
+ del _LegacyStorage
2025
+
2026
+ ################################################################################
2027
+ # Define _assert
2028
+ ################################################################################
2029
+
2030
+
2031
+ # needs to be before the submodule imports to avoid circular dependencies
2032
+ def _assert(condition, message):
2033
+ r"""A wrapper around Python's assert which is symbolically traceable."""
2034
+ if type(condition) is not torch.Tensor and overrides.has_torch_function(
2035
+ (condition,)
2036
+ ):
2037
+ return overrides.handle_torch_function(
2038
+ _assert, (condition,), condition, message
2039
+ )
2040
+ assert condition, message
2041
+
2042
+
2043
+ ################################################################################
2044
+ # Import most common subpackages
2045
+ ################################################################################
2046
+
2047
+ # Use the redundant form so that type checkers know that these are a part of
2048
+ # the public API. The "regular" import lines are there solely for the runtime
2049
+ # side effect of adding to the imported module's members for other users.
2050
+
2051
+ # needs to be before import torch.nn as nn to avoid circular dependencies
2052
+ from torch.autograd import ( # usort: skip
2053
+ enable_grad as enable_grad,
2054
+ inference_mode as inference_mode,
2055
+ no_grad as no_grad,
2056
+ set_grad_enabled as set_grad_enabled,
2057
+ )
2058
+
2059
+ from torch import (
2060
+ __config__ as __config__,
2061
+ __future__ as __future__,
2062
+ _awaits as _awaits,
2063
+ autograd as autograd,
2064
+ backends as backends,
2065
+ cpu as cpu,
2066
+ cuda as cuda,
2067
+ distributed as distributed,
2068
+ distributions as distributions,
2069
+ fft as fft,
2070
+ futures as futures,
2071
+ hub as hub,
2072
+ jit as jit,
2073
+ linalg as linalg,
2074
+ mps as mps,
2075
+ mtia as mtia,
2076
+ multiprocessing as multiprocessing,
2077
+ nested as nested,
2078
+ nn as nn,
2079
+ optim as optim,
2080
+ overrides as overrides,
2081
+ profiler as profiler,
2082
+ sparse as sparse,
2083
+ special as special,
2084
+ testing as testing,
2085
+ types as types,
2086
+ utils as utils,
2087
+ xpu as xpu,
2088
+ )
2089
+ from torch.signal import windows as windows
2090
+
2091
+
2092
+ # Quantized, sparse, AO, etc. should be last to get imported, as nothing
2093
+ # is expected to depend on them.
2094
+ from torch import ao as ao # usort: skip
2095
+
2096
+ # nn.quant* depends on ao -- so should be after those.
2097
+ import torch.nn.intrinsic
2098
+ import torch.nn.qat
2099
+ import torch.nn.quantizable
2100
+ import torch.nn.quantized
2101
+
2102
+
2103
+ _C._init_names(list(_storage_classes))
2104
+
2105
+ # attach docstrings to torch and tensor functions
2106
+ from torch import _size_docs, _storage_docs, _tensor_docs, _torch_docs
2107
+
2108
+
2109
+ del _torch_docs, _tensor_docs, _storage_docs, _size_docs
2110
+
2111
+
2112
+ def compiled_with_cxx11_abi() -> builtins.bool:
2113
+ r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1"""
2114
+ return _C._GLIBCXX_USE_CXX11_ABI
2115
+
2116
+
2117
+ from torch import _library as _library, _ops as _ops
2118
+
2119
+
2120
+ # Import the ops and classes "namespace"
2121
+ from torch._ops import ops as ops # usort: skip
2122
+ from torch._classes import classes as classes # usort: skip
2123
+
2124
+ sys.modules.setdefault(f"{__name__}.ops", ops)
2125
+ sys.modules.setdefault(f"{__name__}.classes", classes)
2126
+
2127
+ # quantization depends on torch.fx and torch.ops
2128
+ # Import quantization
2129
+ from torch import quantization as quantization # usort: skip
2130
+
2131
+ # Import the quasi random sampler
2132
+ from torch import quasirandom as quasirandom # usort: skip
2133
+
2134
+ # If you are seeing this, it means that this call site was not checked if
2135
+ # the memory format could be preserved, and it was switched to old default
2136
+ # behaviour of contiguous
2137
+ legacy_contiguous_format = contiguous_format # defined by _C._initExtension()
2138
+
2139
+ # Register fork handler to initialize OpenMP in child processes (see gh-28389)
2140
+ from torch.multiprocessing._atfork import register_after_fork
2141
+
2142
+
2143
+ register_after_fork(torch.get_num_threads)
2144
+ del register_after_fork
2145
+
2146
+ # Import tools that require fully imported torch (for applying
2147
+ # torch.jit.script as a decorator, for instance):
2148
+ from torch._lobpcg import lobpcg as lobpcg
2149
+
2150
+
2151
+ # These were previously defined in native_functions.yaml and appeared on the
2152
+ # `torch` namespace, but we moved them to c10 dispatch to facilitate custom
2153
+ # class usage. We add these lines here to preserve backward compatibility.
2154
+ quantized_lstm = ops.aten.quantized_lstm
2155
+ quantized_gru = ops.aten.quantized_gru
2156
+
2157
+ # Import experimental masked operations support. See
2158
+ # [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more
2159
+ # information.
2160
+ from torch import masked as masked
2161
+
2162
+ # Import removed ops with error message about removal
2163
+ from torch._linalg_utils import ( # type: ignore[misc]
2164
+ _symeig as symeig,
2165
+ eig,
2166
+ lstsq,
2167
+ matrix_rank,
2168
+ solve,
2169
+ )
2170
+ from torch.utils.dlpack import from_dlpack, to_dlpack
2171
+
2172
+
2173
+ class _TorchCompileInductorWrapper:
2174
+ compiler_name = "inductor"
2175
+
2176
+ def __init__(self, mode, options, dynamic):
2177
+ self.config: _Dict[str, _Any] = {}
2178
+ self.dynamic = dynamic
2179
+ self.apply_mode(mode)
2180
+ self.apply_options(options)
2181
+
2182
+ if self.config.get("triton.cudagraphs", False):
2183
+ os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
2184
+ # FIXME: CUDA Graph does not work well with CUPTI teardown.
2185
+ # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)
2186
+ # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12)
2187
+ # Workaround: turn off CUPTI teardown when using CUDA Graphs.
2188
+ os.environ["TEARDOWN_CUPTI"] = "0"
2189
+
2190
+ def __eq__(self, other):
2191
+ return (
2192
+ isinstance(other, _TorchCompileInductorWrapper)
2193
+ and self.config == other.config
2194
+ and self.dynamic == other.dynamic
2195
+ )
2196
+
2197
+ def apply_mode(self, mode: _Optional[str]):
2198
+ if mode is None or mode == "default":
2199
+ pass
2200
+ elif mode in {"reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"}:
2201
+ from torch._inductor import list_mode_options
2202
+
2203
+ self.apply_options(list_mode_options(mode, self.dynamic))
2204
+ else:
2205
+ raise RuntimeError(
2206
+ f"Unrecognized mode={mode}, should be one of: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs"
2207
+ )
2208
+
2209
+ def apply_options(self, options: _Optional[_Dict[str, _Any]]):
2210
+ if not options:
2211
+ return
2212
+
2213
+ from torch._inductor import config
2214
+
2215
+ current_config: _Dict[str, _Any] = config.shallow_copy_dict()
2216
+
2217
+ for key, val in options.items():
2218
+ attr_name = key.replace("-", "_")
2219
+ if attr_name not in current_config:
2220
+ raise RuntimeError(
2221
+ f"Unexpected optimization option {key}, known options are {list(current_config.keys())}"
2222
+ )
2223
+ if type(val) is not type(current_config[attr_name]):
2224
+ val_type_str = type(val).__name__
2225
+ expected_type_str = type(current_config[attr_name]).__name__
2226
+ raise RuntimeError(
2227
+ f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}"
2228
+ )
2229
+ self.config[attr_name] = val
2230
+
2231
+ def __call__(self, model_, inputs_):
2232
+ from torch._inductor.compile_fx import compile_fx
2233
+
2234
+ return compile_fx(model_, inputs_, config_patches=self.config)
2235
+
2236
+ def get_compiler_config(self):
2237
+ from torch._inductor.compile_fx import get_patched_config_dict
2238
+
2239
+ return get_patched_config_dict(config_patches=self.config)
2240
+
2241
+ def reset(self):
2242
+ from torch._inductor import config
2243
+
2244
+ if "triton.cudagraphs" in self.config or config.triton.cudagraphs:
2245
+ if self.config.get("triton.cudagraphs", True):
2246
+ from torch._inductor.cudagraph_trees import reset_cudagraph_trees
2247
+
2248
+ reset_cudagraph_trees()
2249
+
2250
+
2251
+ class _TorchCompileWrapper:
2252
+ def __init__(self, backend, mode, options, dynamic):
2253
+ from torch._dynamo.backends.registry import lookup_backend
2254
+
2255
+ if isinstance(backend, str):
2256
+ self.compiler_name = backend
2257
+ elif hasattr(backend, "__name__"):
2258
+ self.compiler_name = backend.__name__
2259
+ else:
2260
+ self.compiler_name = str(backend)
2261
+ self.dynamic = dynamic
2262
+ self.compiler_fn = lookup_backend(backend)
2263
+ self.kwargs = {}
2264
+ # only pass the args if they non-empty
2265
+ if mode and mode != "default":
2266
+ self.kwargs["mode"] = mode
2267
+ if options:
2268
+ self.kwargs["options"] = options
2269
+
2270
+ def __eq__(self, other):
2271
+ return (
2272
+ isinstance(other, _TorchCompileWrapper)
2273
+ and self.compiler_fn == other.compiler_fn
2274
+ and self.kwargs == other.kwargs
2275
+ and self.dynamic == other.dynamic
2276
+ )
2277
+
2278
+ def __call__(self, model_, inputs_):
2279
+ return self.compiler_fn(model_, inputs_, **self.kwargs)
2280
+
2281
+ def reset(self):
2282
+ if hasattr(self.compiler_fn, "reset"):
2283
+ self.compiler_fn.reset()
2284
+
2285
+
2286
+ _InputT = _ParamSpec("_InputT")
2287
+ _RetT = _TypeVar("_RetT")
2288
+
2289
+
2290
+ @_overload
2291
+ def compile(
2292
+ model: _Callable[_InputT, _RetT],
2293
+ *,
2294
+ fullgraph: builtins.bool = False,
2295
+ dynamic: _Optional[builtins.bool] = None,
2296
+ backend: _Union[str, _Callable] = "inductor",
2297
+ mode: _Union[str, None] = None,
2298
+ options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
2299
+ disable: builtins.bool = False,
2300
+ ) -> _Callable[_InputT, _RetT]: ...
2301
+
2302
+
2303
+ @_overload
2304
+ def compile(
2305
+ model: None = None,
2306
+ *,
2307
+ fullgraph: builtins.bool = False,
2308
+ dynamic: _Optional[builtins.bool] = None,
2309
+ backend: _Union[str, _Callable] = "inductor",
2310
+ mode: _Union[str, None] = None,
2311
+ options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
2312
+ disable: builtins.bool = False,
2313
+ ) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...
2314
+
2315
+
2316
+ def compile(
2317
+ model: _Optional[_Callable] = None,
2318
+ *,
2319
+ fullgraph: builtins.bool = False,
2320
+ dynamic: _Optional[builtins.bool] = None,
2321
+ backend: _Union[str, _Callable] = "inductor",
2322
+ mode: _Union[str, None] = None,
2323
+ options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
2324
+ disable: builtins.bool = False,
2325
+ ) -> _Union[
2326
+ _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],
2327
+ _Callable[_InputT, _RetT],
2328
+ ]:
2329
+ """
2330
+ Optimizes given model/function using TorchDynamo and specified backend.
2331
+ If you are compiling an :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile`
2332
+ to compile the module inplace without changing its structure.
2333
+
2334
+ Concretely, for every frame executed within the compiled region, we will attempt
2335
+ to compile it and cache the compiled result on the code object for future
2336
+ use. A single frame may be compiled multiple times if previous compiled
2337
+ results are not applicable for subsequent calls (this is called a "guard
2338
+ failure), you can use TORCH_LOGS=guards to debug these situations.
2339
+ Multiple compiled results can be associated with a frame up to
2340
+ ``torch._dynamo.config.cache_size_limit``, which defaults to 8; at which
2341
+ point we will fall back to eager. Note that compile caches are per
2342
+ *code object*, not frame; if you dynamically create multiple copies of a
2343
+ function, they will all share the same code cache.
2344
+
2345
+ Args:
2346
+ model (Callable): Module/function to optimize
2347
+ fullgraph (bool): If False (default), torch.compile attempts to discover compileable regions
2348
+ in the function that it will optimize. If True, then we require that the entire function be
2349
+ capturable into a single graph. If this is not possible (that is, if there are graph breaks),
2350
+ then this will raise an error.
2351
+ dynamic (bool or None): Use dynamic shape tracing. When this is True, we will up-front attempt
2352
+ to generate a kernel that is as dynamic as possible to avoid recompilations when
2353
+ sizes change. This may not always work as some operations/optimizations will
2354
+ force specialization; use TORCH_LOGS=dynamic to debug overspecialization.
2355
+ When this is False, we will NEVER generate dynamic kernels, we will always specialize.
2356
+ By default (None), we automatically detect if dynamism has occurred and compile a more
2357
+ dynamic kernel upon recompile.
2358
+ backend (str or Callable): backend to be used
2359
+
2360
+ - "inductor" is the default backend, which is a good balance between performance and overhead
2361
+
2362
+ - Non experimental in-tree backends can be seen with `torch._dynamo.list_backends()`
2363
+
2364
+ - Experimental or debug in-tree backends can be seen with `torch._dynamo.list_backends(None)`
2365
+
2366
+ - To register an out-of-tree custom backend:
2367
+ https://pytorch.org/docs/main/torch.compiler_custom_backends.html#registering-custom-backends
2368
+ mode (str): Can be either "default", "reduce-overhead", "max-autotune" or "max-autotune-no-cudagraphs"
2369
+
2370
+ - "default" is the default mode, which is a good balance between performance and overhead
2371
+
2372
+ - "reduce-overhead" is a mode that reduces the overhead of python with CUDA graphs,
2373
+ useful for small batches. Reduction of overhead can come at the cost of more memory
2374
+ usage, as we will cache the workspace memory required for the invocation so that we
2375
+ do not have to reallocate it on subsequent runs. Reduction of overhead is not guaranteed
2376
+ to work; today, we only reduce overhead for CUDA only graphs which do not mutate inputs.
2377
+ There are other circumstances where CUDA graphs are not applicable; use TORCH_LOG=perf_hints
2378
+ to debug.
2379
+
2380
+ - "max-autotune" is a mode that leverages Triton or template based matrix multiplications
2381
+ on supported devices and Triton based convolutions on GPU.
2382
+ It enables CUDA graphs by default on GPU.
2383
+
2384
+ - "max-autotune-no-cudagraphs" is a mode similar to "max-autotune" but without CUDA graphs
2385
+
2386
+ - To see the exact configs that each mode sets you can call `torch._inductor.list_mode_options()`
2387
+
2388
+ options (dict): A dictionary of options to pass to the backend. Some notable ones to try out are
2389
+
2390
+ - `epilogue_fusion` which fuses pointwise ops into templates. Requires `max_autotune` to also be set
2391
+
2392
+ - `max_autotune` which will profile to pick the best matmul configuration
2393
+
2394
+ - `fallback_random` which is useful when debugging accuracy issues
2395
+
2396
+ - `shape_padding` which pads matrix shapes to better align loads on GPUs especially for tensor cores
2397
+
2398
+ - `triton.cudagraphs` which will reduce the overhead of python with CUDA graphs
2399
+
2400
+ - `trace.enabled` which is the most useful debugging flag to turn on
2401
+
2402
+ - `trace.graph_diagram` which will show you a picture of your graph after fusion
2403
+
2404
+ - For inductor you can see the full list of configs that it supports by calling `torch._inductor.list_options()`
2405
+ disable (bool): Turn torch.compile() into a no-op for testing
2406
+
2407
+ Example::
2408
+
2409
+ @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
2410
+ def foo(x):
2411
+ return torch.sin(x) + torch.cos(x)
2412
+
2413
+ """
2414
+ _C._log_api_usage_once("torch.compile")
2415
+ if sys.version_info >= (3, 13):
2416
+ raise RuntimeError("Dynamo is not supported on Python 3.13+")
2417
+
2418
+ # Decorator mode
2419
+ if model is None:
2420
+
2421
+ def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]:
2422
+ if model is None:
2423
+ raise RuntimeError("Model can't be None")
2424
+ return compile(
2425
+ model,
2426
+ fullgraph=fullgraph,
2427
+ dynamic=dynamic,
2428
+ backend=backend,
2429
+ mode=mode,
2430
+ options=options,
2431
+ disable=disable,
2432
+ )
2433
+
2434
+ return fn
2435
+
2436
+ if mode is not None and options is not None:
2437
+ raise RuntimeError(
2438
+ "Either mode or options can be specified, but both can't be specified at the same time."
2439
+ )
2440
+ if mode is None and options is None:
2441
+ mode = "default"
2442
+ if backend == "inductor":
2443
+ backend = _TorchCompileInductorWrapper(mode, options, dynamic)
2444
+ else:
2445
+ backend = _TorchCompileWrapper(backend, mode, options, dynamic)
2446
+
2447
+ return torch._dynamo.optimize(
2448
+ backend=backend,
2449
+ nopython=fullgraph,
2450
+ dynamic=dynamic,
2451
+ disable=disable,
2452
+ )(model) # type: ignore[return-value]
2453
+
2454
+
2455
+ def _register_device_module(device_type, module):
2456
+ r"""Register an external runtime module of the specific :attr:`device_type`
2457
+ supported by torch.
2458
+
2459
+ After the :attr:`module` is registered correctly, the user can refer
2460
+ the external runtime module as part of torch with attribute torch.xxx.
2461
+ """
2462
+ # Make sure the device_type represent a supported device type for torch.
2463
+ device_type = torch.device(device_type).type
2464
+ m = sys.modules[__name__]
2465
+ if hasattr(m, device_type):
2466
+ raise RuntimeError(
2467
+ f"The runtime module of '{device_type}' has already "
2468
+ f"been registered with '{getattr(m, device_type)}'"
2469
+ )
2470
+ setattr(m, device_type, module)
2471
+ torch_module_name = ".".join([__name__, device_type])
2472
+ sys.modules[torch_module_name] = module
2473
+
2474
+
2475
+ from torch import (
2476
+ export as export,
2477
+ func as func,
2478
+ library as library,
2479
+ return_types as return_types,
2480
+ )
2481
+ from torch._higher_order_ops import cond as cond, while_loop as while_loop
2482
+ from torch.func import vmap as vmap
2483
+
2484
+
2485
+ if not TYPE_CHECKING:
2486
+ from torch import _meta_registrations
2487
+
2488
+ # Enable CUDA Sanitizer
2489
+ if "TORCH_CUDA_SANITIZER" in os.environ:
2490
+ import torch.cuda._sanitizer as csan
2491
+
2492
+ csan.enable_cuda_sanitizer()
2493
+
2494
+ # Populate magic methods on SymInt and SymFloat
2495
+ import torch.fx.experimental.sym_node
2496
+
2497
+
2498
+ # Register MPS specific decomps
2499
+ torch.backends.mps._init()
2500
+
2501
+ if not _running_with_deploy():
2502
+ from torch import compiler as compiler
2503
+
2504
+ class _TritonLibrary:
2505
+ lib = torch.library.Library("triton", "DEF")
2506
+ ops_table: _Dict[_Tuple[str, str], _Callable] = {}
2507
+
2508
+ @classmethod
2509
+ def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
2510
+ if (op_key, dispatch_key) not in cls.ops_table:
2511
+ cls.lib.define(full_schema)
2512
+ cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
2513
+ cls.ops_table[(op_key, dispatch_key)] = op_impl
2514
+
2515
+ return cls.ops_table[(op_key, dispatch_key)]
2516
+
2517
+
2518
+ # Deprecated attributes
2519
+ _deprecated_attrs = {
2520
+ "has_mps": torch.backends.mps.is_built,
2521
+ "has_cuda": torch.backends.cuda.is_built,
2522
+ "has_cudnn": torch.backends.cudnn.is_available,
2523
+ "has_mkldnn": torch.backends.mkldnn.is_available,
2524
+ }
2525
+
2526
+ if TYPE_CHECKING:
2527
+ # Import the following modules during type checking to enable code intelligence features,
2528
+ # such as auto-completion in tools like pylance, even when these modules are not explicitly
2529
+ # imported in user code.
2530
+ from torch import (
2531
+ _dynamo as _dynamo,
2532
+ _inductor as _inductor,
2533
+ _subclasses as _subclasses,
2534
+ onnx as onnx,
2535
+ )
2536
+
2537
+ else:
2538
+ _lazy_modules = {
2539
+ "_dynamo",
2540
+ "_inductor",
2541
+ "_export",
2542
+ # ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit
2543
+ "onnx",
2544
+ }
2545
+
2546
+ def __getattr__(name):
2547
+ # Deprecated attrs
2548
+ replacement = _deprecated_attrs.get(name)
2549
+ if replacement is not None:
2550
+ import warnings
2551
+
2552
+ warnings.warn(
2553
+ f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'",
2554
+ stacklevel=2,
2555
+ )
2556
+ return replacement()
2557
+
2558
+ # Lazy modules
2559
+ if name in _lazy_modules:
2560
+ return importlib.import_module(f".{name}", __name__)
2561
+
2562
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
2563
+
2564
+
2565
+ def get_device_module(device: _Optional[_Union[torch.device, str]] = None):
2566
+ """
2567
+ Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...).
2568
+ If no device is given, return the module for the current accelerator or CPU if none is present.
2569
+ """
2570
+ if isinstance(device, torch.device):
2571
+ device_module_name = device.type
2572
+ elif isinstance(device, str):
2573
+ device_module_name = torch.device(device).type
2574
+ elif device is None:
2575
+ # Using default accelerator type. If no accelerator is available, it automatically returns CPU device.
2576
+ device_module_name = torch._C._get_accelerator().type
2577
+ else:
2578
+ raise RuntimeError(
2579
+ f"Invalid value of device '{device}', expect torch.device, str, or None"
2580
+ )
2581
+ device_module = getattr(torch, device_module_name, None)
2582
+ if device_module is None:
2583
+ raise RuntimeError(
2584
+ f"Device '{device_module_name}' does not have a corresponding module registered as 'torch.{device_module_name}'."
2585
+ )
2586
+ return device_module
2587
+
2588
+
2589
+ def _constrain_as_size(
2590
+ symbol,
2591
+ min: _Optional[builtins.int] = None,
2592
+ max: _Optional[builtins.int] = None,
2593
+ ):
2594
+ """
2595
+ This indicates that a given int is size-like, and can be used in any context where a size is expected.
2596
+ You will typically use this when reading out integers from Tensors, e.g., max.item() or lengths.tolist()
2597
+ which then need to be used as tensor constructors. Providing these assertions to PyTorch can help resolve
2598
+ GuardOnDataDependentSymNode errors upon export, since we cannot guard on unbacked SymInts.
2599
+
2600
+ This function has unusual semantics in some circumstances in framework
2601
+ code, we will treat this int as >= 2 (when we do a size-oblivious guard).
2602
+ This makes it easier to use the unbacked int in size contexts,
2603
+ as we will often attempt to guard on a size being zero/one
2604
+ (e.g., when computing the contiguity of a tensor, or testing if
2605
+ broadcasting can occur), which will not work on unbacked SymInts.
2606
+ However, if we conservatively assume that the size is not zero/one, we will
2607
+ end up with a graph that will still work even if the size is zero/one.
2608
+
2609
+ For more details, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit
2610
+ ```
2611
+ """
2612
+ torch.sym_constrain_range_for_size(symbol, min=min, max=max)
2613
+
2614
+
2615
+ from torch import _logging
2616
+
2617
+
2618
+ _logging._init_logs()
2619
+
2620
+
2621
+ def _import_device_backends():
2622
+ """
2623
+ Leverage the Python plugin mechanism to load out-of-the-tree device extensions.
2624
+ See this RFC: https://github.com/pytorch/pytorch/issues/122468
2625
+ """
2626
+ from importlib.metadata import entry_points
2627
+
2628
+ group_name = "torch.backends"
2629
+ if sys.version_info < (3, 10):
2630
+ backend_extensions = entry_points().get(group_name, ())
2631
+ else:
2632
+ backend_extensions = entry_points(group=group_name)
2633
+
2634
+ for backend_extension in backend_extensions:
2635
+ try:
2636
+ # Load the extension
2637
+ entrypoint = backend_extension.load()
2638
+ # Call the entrypoint
2639
+ entrypoint()
2640
+ except Exception as err:
2641
+ raise RuntimeError(
2642
+ f"Failed to load the backend extension: {backend_extension.name}. "
2643
+ f"You can disable extension auto-loading with TORCH_DEVICE_BACKEND_AUTOLOAD=0."
2644
+ ) from err
2645
+
2646
+
2647
+ def _is_device_backend_autoload_enabled() -> builtins.bool:
2648
+ """
2649
+ Whether autoloading out-of-the-tree device extensions is enabled.
2650
+ The switch depends on the value of the environment variable
2651
+ `TORCH_DEVICE_BACKEND_AUTOLOAD`.
2652
+
2653
+ Returns:
2654
+ bool: Whether to enable autoloading the extensions. Enabled by default.
2655
+
2656
+ Examples:
2657
+ >>> torch._is_device_backend_autoload_enabled()
2658
+ True
2659
+ """
2660
+ # enabled by default
2661
+ return os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "1") == "1"
2662
+
2663
+
2664
+ if _is_device_backend_autoload_enabled():
2665
+ _import_device_backends()
.venv/lib/python3.11/site-packages/torch/_appdirs.py ADDED
@@ -0,0 +1,667 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2005-2010 ActiveState Software Inc.
4
+ # Copyright (c) 2013 Eddy Petrișor
5
+
6
+ # flake8: noqa
7
+
8
+ """
9
+ This file is directly from
10
+ https://github.com/ActiveState/appdirs/blob/3fe6a83776843a46f20c2e5587afcffe05e03b39/appdirs.py
11
+
12
+ The license of https://github.com/ActiveState/appdirs copied below:
13
+
14
+
15
+ # This is the MIT license
16
+
17
+ Copyright (c) 2010 ActiveState Software Inc.
18
+
19
+ Permission is hereby granted, free of charge, to any person obtaining a
20
+ copy of this software and associated documentation files (the
21
+ "Software"), to deal in the Software without restriction, including
22
+ without limitation the rights to use, copy, modify, merge, publish,
23
+ distribute, sublicense, and/or sell copies of the Software, and to
24
+ permit persons to whom the Software is furnished to do so, subject to
25
+ the following conditions:
26
+
27
+ The above copyright notice and this permission notice shall be included
28
+ in all copies or substantial portions of the Software.
29
+
30
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
31
+ OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
32
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
33
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
34
+ CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
35
+ TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
36
+ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
37
+ """
38
+
39
+ """Utilities for determining application-specific dirs.
40
+
41
+ See <https://github.com/ActiveState/appdirs> for details and usage.
42
+ """
43
+ # Dev Notes:
44
+ # - MSDN on where to store app data files:
45
+ # http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120
46
+ # - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
47
+ # - XDG spec for Un*x: https://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
48
+
49
+ __version__ = "1.4.4"
50
+ __version_info__ = tuple(int(segment) for segment in __version__.split("."))
51
+
52
+
53
+ import os
54
+ import sys
55
+
56
+
57
+ unicode = str
58
+
59
+ if sys.platform.startswith("java"):
60
+ import platform
61
+
62
+ os_name = platform.java_ver()[3][0]
63
+ if os_name.startswith("Windows"): # "Windows XP", "Windows 7", etc.
64
+ system = "win32"
65
+ elif os_name.startswith("Mac"): # "Mac OS X", etc.
66
+ system = "darwin"
67
+ else: # "Linux", "SunOS", "FreeBSD", etc.
68
+ # Setting this to "linux2" is not ideal, but only Windows or Mac
69
+ # are actually checked for and the rest of the module expects
70
+ # *sys.platform* style strings.
71
+ system = "linux2"
72
+ else:
73
+ system = sys.platform
74
+
75
+
76
+ def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
77
+ r"""Return full path to the user-specific data dir for this application.
78
+
79
+ "appname" is the name of application.
80
+ If None, just the system directory is returned.
81
+ "appauthor" (only used on Windows) is the name of the
82
+ appauthor or distributing body for this application. Typically
83
+ it is the owning company name. This falls back to appname. You may
84
+ pass False to disable it.
85
+ "version" is an optional version path element to append to the
86
+ path. You might want to use this if you want multiple versions
87
+ of your app to be able to run independently. If used, this
88
+ would typically be "<major>.<minor>".
89
+ Only applied when appname is present.
90
+ "roaming" (boolean, default False) can be set True to use the Windows
91
+ roaming appdata directory. That means that for users on a Windows
92
+ network setup for roaming profiles, this user data will be
93
+ sync'd on login. See
94
+ <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
95
+ for a discussion of issues.
96
+
97
+ Typical user data directories are:
98
+ Mac OS X: ~/Library/Application Support/<AppName>
99
+ Unix: ~/.local/share/<AppName> # or in $XDG_DATA_HOME, if defined
100
+ Win XP (not roaming): C:\Documents and Settings\<username>\Application Data\<AppAuthor>\<AppName>
101
+ Win XP (roaming): C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>
102
+ Win 7 (not roaming): C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>
103
+ Win 7 (roaming): C:\Users\<username>\AppData\Roaming\<AppAuthor>\<AppName>
104
+
105
+ For Unix, we follow the XDG spec and support $XDG_DATA_HOME.
106
+ That means, by default "~/.local/share/<AppName>".
107
+ """
108
+ if system == "win32":
109
+ if appauthor is None:
110
+ appauthor = appname
111
+ const = roaming and "CSIDL_APPDATA" or "CSIDL_LOCAL_APPDATA"
112
+ path = os.path.normpath(_get_win_folder(const))
113
+ if appname:
114
+ if appauthor is not False:
115
+ path = os.path.join(path, appauthor, appname)
116
+ else:
117
+ path = os.path.join(path, appname)
118
+ elif system == "darwin":
119
+ path = os.path.expanduser("~/Library/Application Support/")
120
+ if appname:
121
+ path = os.path.join(path, appname)
122
+ else:
123
+ path = os.getenv("XDG_DATA_HOME", os.path.expanduser("~/.local/share"))
124
+ if appname:
125
+ path = os.path.join(path, appname)
126
+ if appname and version:
127
+ path = os.path.join(path, version)
128
+ return path
129
+
130
+
131
+ def site_data_dir(appname=None, appauthor=None, version=None, multipath=False):
132
+ r"""Return full path to the user-shared data dir for this application.
133
+
134
+ "appname" is the name of application.
135
+ If None, just the system directory is returned.
136
+ "appauthor" (only used on Windows) is the name of the
137
+ appauthor or distributing body for this application. Typically
138
+ it is the owning company name. This falls back to appname. You may
139
+ pass False to disable it.
140
+ "version" is an optional version path element to append to the
141
+ path. You might want to use this if you want multiple versions
142
+ of your app to be able to run independently. If used, this
143
+ would typically be "<major>.<minor>".
144
+ Only applied when appname is present.
145
+ "multipath" is an optional parameter only applicable to *nix
146
+ which indicates that the entire list of data dirs should be
147
+ returned. By default, the first item from XDG_DATA_DIRS is
148
+ returned, or '/usr/local/share/<AppName>',
149
+ if XDG_DATA_DIRS is not set
150
+
151
+ Typical site data directories are:
152
+ Mac OS X: /Library/Application Support/<AppName>
153
+ Unix: /usr/local/share/<AppName> or /usr/share/<AppName>
154
+ Win XP: C:\Documents and Settings\All Users\Application Data\<AppAuthor>\<AppName>
155
+ Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
156
+ Win 7: C:\ProgramData\<AppAuthor>\<AppName> # Hidden, but writeable on Win 7.
157
+
158
+ For Unix, this is using the $XDG_DATA_DIRS[0] default.
159
+
160
+ WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
161
+ """
162
+ if system == "win32":
163
+ if appauthor is None:
164
+ appauthor = appname
165
+ path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA"))
166
+ if appname:
167
+ if appauthor is not False:
168
+ path = os.path.join(path, appauthor, appname)
169
+ else:
170
+ path = os.path.join(path, appname)
171
+ elif system == "darwin":
172
+ path = os.path.expanduser("/Library/Application Support")
173
+ if appname:
174
+ path = os.path.join(path, appname)
175
+ else:
176
+ # XDG default for $XDG_DATA_DIRS
177
+ # only first, if multipath is False
178
+ path = os.getenv(
179
+ "XDG_DATA_DIRS", os.pathsep.join(["/usr/local/share", "/usr/share"])
180
+ )
181
+ pathlist = [
182
+ os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)
183
+ ]
184
+ if appname:
185
+ if version:
186
+ appname = os.path.join(appname, version)
187
+ pathlist = [os.sep.join([x, appname]) for x in pathlist]
188
+
189
+ if multipath:
190
+ path = os.pathsep.join(pathlist)
191
+ else:
192
+ path = pathlist[0]
193
+ return path
194
+
195
+ if appname and version:
196
+ path = os.path.join(path, version)
197
+ return path
198
+
199
+
200
+ def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
201
+ r"""Return full path to the user-specific config dir for this application.
202
+
203
+ "appname" is the name of application.
204
+ If None, just the system directory is returned.
205
+ "appauthor" (only used on Windows) is the name of the
206
+ appauthor or distributing body for this application. Typically
207
+ it is the owning company name. This falls back to appname. You may
208
+ pass False to disable it.
209
+ "version" is an optional version path element to append to the
210
+ path. You might want to use this if you want multiple versions
211
+ of your app to be able to run independently. If used, this
212
+ would typically be "<major>.<minor>".
213
+ Only applied when appname is present.
214
+ "roaming" (boolean, default False) can be set True to use the Windows
215
+ roaming appdata directory. That means that for users on a Windows
216
+ network setup for roaming profiles, this user data will be
217
+ sync'd on login. See
218
+ <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
219
+ for a discussion of issues.
220
+
221
+ Typical user config directories are:
222
+ Mac OS X: ~/Library/Preferences/<AppName>
223
+ Unix: ~/.config/<AppName> # or in $XDG_CONFIG_HOME, if defined
224
+ Win *: same as user_data_dir
225
+
226
+ For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME.
227
+ That means, by default "~/.config/<AppName>".
228
+ """
229
+ if system == "win32":
230
+ path = user_data_dir(appname, appauthor, None, roaming)
231
+ elif system == "darwin":
232
+ path = os.path.expanduser("~/Library/Preferences/")
233
+ if appname:
234
+ path = os.path.join(path, appname)
235
+ else:
236
+ path = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
237
+ if appname:
238
+ path = os.path.join(path, appname)
239
+ if appname and version:
240
+ path = os.path.join(path, version)
241
+ return path
242
+
243
+
244
+ def site_config_dir(appname=None, appauthor=None, version=None, multipath=False):
245
+ r"""Return full path to the user-shared data dir for this application.
246
+
247
+ "appname" is the name of application.
248
+ If None, just the system directory is returned.
249
+ "appauthor" (only used on Windows) is the name of the
250
+ appauthor or distributing body for this application. Typically
251
+ it is the owning company name. This falls back to appname. You may
252
+ pass False to disable it.
253
+ "version" is an optional version path element to append to the
254
+ path. You might want to use this if you want multiple versions
255
+ of your app to be able to run independently. If used, this
256
+ would typically be "<major>.<minor>".
257
+ Only applied when appname is present.
258
+ "multipath" is an optional parameter only applicable to *nix
259
+ which indicates that the entire list of config dirs should be
260
+ returned. By default, the first item from XDG_CONFIG_DIRS is
261
+ returned, or '/etc/xdg/<AppName>', if XDG_CONFIG_DIRS is not set
262
+
263
+ Typical site config directories are:
264
+ Mac OS X: same as site_data_dir
265
+ Unix: /etc/xdg/<AppName> or $XDG_CONFIG_DIRS[i]/<AppName> for each value in
266
+ $XDG_CONFIG_DIRS
267
+ Win *: same as site_data_dir
268
+ Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
269
+
270
+ For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False
271
+
272
+ WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
273
+ """
274
+ if system == "win32":
275
+ path = site_data_dir(appname, appauthor)
276
+ if appname and version:
277
+ path = os.path.join(path, version)
278
+ elif system == "darwin":
279
+ path = os.path.expanduser("/Library/Preferences")
280
+ if appname:
281
+ path = os.path.join(path, appname)
282
+ else:
283
+ # XDG default for $XDG_CONFIG_DIRS
284
+ # only first, if multipath is False
285
+ path = os.getenv("XDG_CONFIG_DIRS", "/etc/xdg")
286
+ pathlist = [
287
+ os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)
288
+ ]
289
+ if appname:
290
+ if version:
291
+ appname = os.path.join(appname, version)
292
+ pathlist = [os.sep.join([x, appname]) for x in pathlist]
293
+
294
+ if multipath:
295
+ path = os.pathsep.join(pathlist)
296
+ else:
297
+ path = pathlist[0]
298
+ return path
299
+
300
+
301
+ def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
302
+ r"""Return full path to the user-specific cache dir for this application.
303
+
304
+ "appname" is the name of application.
305
+ If None, just the system directory is returned.
306
+ "appauthor" (only used on Windows) is the name of the
307
+ appauthor or distributing body for this application. Typically
308
+ it is the owning company name. This falls back to appname. You may
309
+ pass False to disable it.
310
+ "version" is an optional version path element to append to the
311
+ path. You might want to use this if you want multiple versions
312
+ of your app to be able to run independently. If used, this
313
+ would typically be "<major>.<minor>".
314
+ Only applied when appname is present.
315
+ "opinion" (boolean) can be False to disable the appending of
316
+ "Cache" to the base app data dir for Windows. See
317
+ discussion below.
318
+
319
+ Typical user cache directories are:
320
+ Mac OS X: ~/Library/Caches/<AppName>
321
+ Unix: ~/.cache/<AppName> (XDG default)
322
+ Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Cache
323
+ Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Cache
324
+
325
+ On Windows the only suggestion in the MSDN docs is that local settings go in
326
+ the `CSIDL_LOCAL_APPDATA` directory. This is identical to the non-roaming
327
+ app data dir (the default returned by `user_data_dir` above). Apps typically
328
+ put cache data somewhere *under* the given dir here. Some examples:
329
+ ...\Mozilla\Firefox\Profiles\<ProfileName>\Cache
330
+ ...\Acme\SuperApp\Cache\1.0
331
+ OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value.
332
+ This can be disabled with the `opinion=False` option.
333
+ """
334
+ if system == "win32":
335
+ if appauthor is None:
336
+ appauthor = appname
337
+ path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
338
+ if appname:
339
+ if appauthor is not False:
340
+ path = os.path.join(path, appauthor, appname)
341
+ else:
342
+ path = os.path.join(path, appname)
343
+ if opinion:
344
+ path = os.path.join(path, "Cache")
345
+ elif system == "darwin":
346
+ path = os.path.expanduser("~/Library/Caches")
347
+ if appname:
348
+ path = os.path.join(path, appname)
349
+ else:
350
+ path = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
351
+ if appname:
352
+ path = os.path.join(path, appname)
353
+ if appname and version:
354
+ path = os.path.join(path, version)
355
+ return path
356
+
357
+
358
+ def user_state_dir(appname=None, appauthor=None, version=None, roaming=False):
359
+ r"""Return full path to the user-specific state dir for this application.
360
+
361
+ "appname" is the name of application.
362
+ If None, just the system directory is returned.
363
+ "appauthor" (only used on Windows) is the name of the
364
+ appauthor or distributing body for this application. Typically
365
+ it is the owning company name. This falls back to appname. You may
366
+ pass False to disable it.
367
+ "version" is an optional version path element to append to the
368
+ path. You might want to use this if you want multiple versions
369
+ of your app to be able to run independently. If used, this
370
+ would typically be "<major>.<minor>".
371
+ Only applied when appname is present.
372
+ "roaming" (boolean, default False) can be set True to use the Windows
373
+ roaming appdata directory. That means that for users on a Windows
374
+ network setup for roaming profiles, this user data will be
375
+ sync'd on login. See
376
+ <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
377
+ for a discussion of issues.
378
+
379
+ Typical user state directories are:
380
+ Mac OS X: same as user_data_dir
381
+ Unix: ~/.local/state/<AppName> # or in $XDG_STATE_HOME, if defined
382
+ Win *: same as user_data_dir
383
+
384
+ For Unix, we follow this Debian proposal <https://wiki.debian.org/XDGBaseDirectorySpecification#state>
385
+ to extend the XDG spec and support $XDG_STATE_HOME.
386
+
387
+ That means, by default "~/.local/state/<AppName>".
388
+ """
389
+ if system in ["win32", "darwin"]:
390
+ path = user_data_dir(appname, appauthor, None, roaming)
391
+ else:
392
+ path = os.getenv("XDG_STATE_HOME", os.path.expanduser("~/.local/state"))
393
+ if appname:
394
+ path = os.path.join(path, appname)
395
+ if appname and version:
396
+ path = os.path.join(path, version)
397
+ return path
398
+
399
+
400
+ def user_log_dir(appname=None, appauthor=None, version=None, opinion=True):
401
+ r"""Return full path to the user-specific log dir for this application.
402
+
403
+ "appname" is the name of application.
404
+ If None, just the system directory is returned.
405
+ "appauthor" (only used on Windows) is the name of the
406
+ appauthor or distributing body for this application. Typically
407
+ it is the owning company name. This falls back to appname. You may
408
+ pass False to disable it.
409
+ "version" is an optional version path element to append to the
410
+ path. You might want to use this if you want multiple versions
411
+ of your app to be able to run independently. If used, this
412
+ would typically be "<major>.<minor>".
413
+ Only applied when appname is present.
414
+ "opinion" (boolean) can be False to disable the appending of
415
+ "Logs" to the base app data dir for Windows, and "log" to the
416
+ base cache dir for Unix. See discussion below.
417
+
418
+ Typical user log directories are:
419
+ Mac OS X: ~/Library/Logs/<AppName>
420
+ Unix: ~/.cache/<AppName>/log # or under $XDG_CACHE_HOME if defined
421
+ Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Logs
422
+ Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Logs
423
+
424
+ On Windows the only suggestion in the MSDN docs is that local settings
425
+ go in the `CSIDL_LOCAL_APPDATA` directory. (Note: I'm interested in
426
+ examples of what some windows apps use for a logs dir.)
427
+
428
+ OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA`
429
+ value for Windows and appends "log" to the user cache dir for Unix.
430
+ This can be disabled with the `opinion=False` option.
431
+ """
432
+ if system == "darwin":
433
+ path = os.path.join(os.path.expanduser("~/Library/Logs"), appname)
434
+ elif system == "win32":
435
+ path = user_data_dir(appname, appauthor, version)
436
+ version = False
437
+ if opinion:
438
+ path = os.path.join(path, "Logs")
439
+ else:
440
+ path = user_cache_dir(appname, appauthor, version)
441
+ version = False
442
+ if opinion:
443
+ path = os.path.join(path, "log")
444
+ if appname and version:
445
+ path = os.path.join(path, version)
446
+ return path
447
+
448
+
449
+ class AppDirs(object):
450
+ """Convenience wrapper for getting application dirs."""
451
+
452
+ def __init__(
453
+ self, appname=None, appauthor=None, version=None, roaming=False, multipath=False
454
+ ):
455
+ self.appname = appname
456
+ self.appauthor = appauthor
457
+ self.version = version
458
+ self.roaming = roaming
459
+ self.multipath = multipath
460
+
461
+ @property
462
+ def user_data_dir(self):
463
+ return user_data_dir(
464
+ self.appname, self.appauthor, version=self.version, roaming=self.roaming
465
+ )
466
+
467
+ @property
468
+ def site_data_dir(self):
469
+ return site_data_dir(
470
+ self.appname, self.appauthor, version=self.version, multipath=self.multipath
471
+ )
472
+
473
+ @property
474
+ def user_config_dir(self):
475
+ return user_config_dir(
476
+ self.appname, self.appauthor, version=self.version, roaming=self.roaming
477
+ )
478
+
479
+ @property
480
+ def site_config_dir(self):
481
+ return site_config_dir(
482
+ self.appname, self.appauthor, version=self.version, multipath=self.multipath
483
+ )
484
+
485
+ @property
486
+ def user_cache_dir(self):
487
+ return user_cache_dir(self.appname, self.appauthor, version=self.version)
488
+
489
+ @property
490
+ def user_state_dir(self):
491
+ return user_state_dir(self.appname, self.appauthor, version=self.version)
492
+
493
+ @property
494
+ def user_log_dir(self):
495
+ return user_log_dir(self.appname, self.appauthor, version=self.version)
496
+
497
+
498
+ # ---- internal support stuff
499
+
500
+
501
+ def _get_win_folder_from_registry(csidl_name):
502
+ """This is a fallback technique at best. I'm not sure if using the
503
+ registry for this guarantees us the correct answer for all CSIDL_*
504
+ names.
505
+ """
506
+ import winreg as _winreg
507
+
508
+ shell_folder_name = {
509
+ "CSIDL_APPDATA": "AppData",
510
+ "CSIDL_COMMON_APPDATA": "Common AppData",
511
+ "CSIDL_LOCAL_APPDATA": "Local AppData",
512
+ }[csidl_name]
513
+
514
+ key = _winreg.OpenKey(
515
+ _winreg.HKEY_CURRENT_USER,
516
+ r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
517
+ )
518
+ dir, type = _winreg.QueryValueEx(key, shell_folder_name)
519
+ return dir
520
+
521
+
522
+ def _get_win_folder_with_pywin32(csidl_name):
523
+ from win32com.shell import shell, shellcon
524
+
525
+ dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0)
526
+ # Try to make this a unicode path because SHGetFolderPath does
527
+ # not return unicode strings when there is unicode data in the
528
+ # path.
529
+ try:
530
+ dir = unicode(dir)
531
+
532
+ # Downgrade to short path name if have highbit chars. See
533
+ # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
534
+ has_high_char = False
535
+ for c in dir:
536
+ if ord(c) > 255:
537
+ has_high_char = True
538
+ break
539
+ if has_high_char:
540
+ try:
541
+ import win32api
542
+
543
+ dir = win32api.GetShortPathName(dir)
544
+ except ImportError:
545
+ pass
546
+ except UnicodeError:
547
+ pass
548
+ return dir
549
+
550
+
551
+ def _get_win_folder_with_ctypes(csidl_name):
552
+ import ctypes
553
+
554
+ csidl_const = {
555
+ "CSIDL_APPDATA": 26,
556
+ "CSIDL_COMMON_APPDATA": 35,
557
+ "CSIDL_LOCAL_APPDATA": 28,
558
+ }[csidl_name]
559
+
560
+ buf = ctypes.create_unicode_buffer(1024)
561
+ ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf)
562
+
563
+ # Downgrade to short path name if have highbit chars. See
564
+ # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
565
+ has_high_char = False
566
+ for c in buf:
567
+ if ord(c) > 255:
568
+ has_high_char = True
569
+ break
570
+ if has_high_char:
571
+ buf2 = ctypes.create_unicode_buffer(1024)
572
+ if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024):
573
+ buf = buf2
574
+
575
+ return buf.value
576
+
577
+
578
+ def _get_win_folder_with_jna(csidl_name):
579
+ import array
580
+
581
+ from com.sun import jna
582
+ from com.sun.jna.platform import win32
583
+
584
+ buf_size = win32.WinDef.MAX_PATH * 2
585
+ buf = array.zeros("c", buf_size)
586
+ shell = win32.Shell32.INSTANCE
587
+ shell.SHGetFolderPath(
588
+ None,
589
+ getattr(win32.ShlObj, csidl_name),
590
+ None,
591
+ win32.ShlObj.SHGFP_TYPE_CURRENT,
592
+ buf,
593
+ )
594
+ dir = jna.Native.toString(buf.tostring()).rstrip("\0")
595
+
596
+ # Downgrade to short path name if have highbit chars. See
597
+ # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
598
+ has_high_char = False
599
+ for c in dir:
600
+ if ord(c) > 255:
601
+ has_high_char = True
602
+ break
603
+ if has_high_char:
604
+ buf = array.zeros("c", buf_size)
605
+ kernel = win32.Kernel32.INSTANCE
606
+ if kernel.GetShortPathName(dir, buf, buf_size):
607
+ dir = jna.Native.toString(buf.tostring()).rstrip("\0")
608
+
609
+ return dir
610
+
611
+
612
+ if system == "win32":
613
+ try:
614
+ import win32com.shell
615
+
616
+ _get_win_folder = _get_win_folder_with_pywin32
617
+ except ImportError:
618
+ try:
619
+ from ctypes import windll
620
+
621
+ _get_win_folder = _get_win_folder_with_ctypes
622
+ except ImportError:
623
+ try:
624
+ import com.sun.jna
625
+
626
+ _get_win_folder = _get_win_folder_with_jna
627
+ except ImportError:
628
+ _get_win_folder = _get_win_folder_from_registry
629
+
630
+
631
+ # ---- self test code
632
+
633
+ if __name__ == "__main__":
634
+ appname = "MyApp"
635
+ appauthor = "MyCompany"
636
+
637
+ props = (
638
+ "user_data_dir",
639
+ "user_config_dir",
640
+ "user_cache_dir",
641
+ "user_state_dir",
642
+ "user_log_dir",
643
+ "site_data_dir",
644
+ "site_config_dir",
645
+ )
646
+
647
+ print(f"-- app dirs {__version__} --")
648
+
649
+ print("-- app dirs (with optional 'version')")
650
+ dirs = AppDirs(appname, appauthor, version="1.0")
651
+ for prop in props:
652
+ print(f"{prop}: {getattr(dirs, prop)}")
653
+
654
+ print("\n-- app dirs (without optional 'version')")
655
+ dirs = AppDirs(appname, appauthor)
656
+ for prop in props:
657
+ print(f"{prop}: {getattr(dirs, prop)}")
658
+
659
+ print("\n-- app dirs (without optional 'appauthor')")
660
+ dirs = AppDirs(appname)
661
+ for prop in props:
662
+ print(f"{prop}: {getattr(dirs, prop)}")
663
+
664
+ print("\n-- app dirs (with disabled 'appauthor')")
665
+ dirs = AppDirs(appname, appauthor=False)
666
+ for prop in props:
667
+ print(f"{prop}: {getattr(dirs, prop)}")
.venv/lib/python3.11/site-packages/torch/_classes.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import types
3
+
4
+ import torch._C
5
+
6
+
7
+ class _ClassNamespace(types.ModuleType):
8
+ def __init__(self, name):
9
+ super().__init__("torch.classes" + name)
10
+ self.name = name
11
+
12
+ def __getattr__(self, attr):
13
+ proxy = torch._C._get_custom_class_python_wrapper(self.name, attr)
14
+ if proxy is None:
15
+ raise RuntimeError(f"Class {self.name}.{attr} not registered!")
16
+ return proxy
17
+
18
+
19
+ class _Classes(types.ModuleType):
20
+ __file__ = "_classes.py"
21
+
22
+ def __init__(self) -> None:
23
+ super().__init__("torch.classes")
24
+
25
+ def __getattr__(self, name):
26
+ namespace = _ClassNamespace(name)
27
+ setattr(self, name, namespace)
28
+ return namespace
29
+
30
+ @property
31
+ def loaded_libraries(self):
32
+ return torch.ops.loaded_libraries
33
+
34
+ def load_library(self, path):
35
+ """
36
+ Loads a shared library from the given path into the current process.
37
+
38
+ The library being loaded may run global initialization code to register
39
+ custom classes with the PyTorch JIT runtime. This allows dynamically
40
+ loading custom classes. For this, you should compile your class
41
+ and the static registration code into a shared library object, and then
42
+ call ``torch.classes.load_library('path/to/libcustom.so')`` to load the
43
+ shared object.
44
+
45
+ After the library is loaded, it is added to the
46
+ ``torch.classes.loaded_libraries`` attribute, a set that may be inspected
47
+ for the paths of all libraries loaded using this function.
48
+
49
+ Args:
50
+ path (str): A path to a shared library to load.
51
+ """
52
+ torch.ops.load_library(path)
53
+
54
+
55
+ # The classes "namespace"
56
+ classes = _Classes()
.venv/lib/python3.11/site-packages/torch/_compile.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ """
3
+ APIs related to torch.compile which lazily import torch._dynamo to avoid
4
+ circular dependencies.
5
+ """
6
+
7
+ import functools
8
+
9
+
10
+ def _disable_dynamo(fn=None, recursive=True):
11
+ """
12
+ This API should be only used inside torch, external users should still use
13
+ torch._dynamo.disable. The main goal of this API is to avoid circular
14
+ imports issues that is common while using _dynamo.disable inside torch
15
+ itself.
16
+
17
+ This API avoids it by lazily importing torch._dynamo from the import time to
18
+ the invocation of the decorated function.
19
+ """
20
+ if fn is not None:
21
+
22
+ @functools.wraps(fn)
23
+ def inner(*args, **kwargs):
24
+ # cache this on the first invocation to avoid adding too much overhead.
25
+ disable_fn = getattr(fn, "__dynamo_disable", None)
26
+ if disable_fn is None:
27
+ import torch._dynamo
28
+
29
+ disable_fn = torch._dynamo.disable(fn, recursive)
30
+ fn.__dynamo_disable = disable_fn
31
+
32
+ return disable_fn(*args, **kwargs)
33
+
34
+ return inner
35
+ else:
36
+ # decorator usage like @_disable_dynamo(recursive=False). The resulting
37
+ # object expects the original decorated function as the arg.
38
+ return functools.partial(_disable_dynamo, recursive=recursive)
.venv/lib/python3.11/site-packages/torch/_custom_ops.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import inspect
3
+
4
+ from torch._custom_op.impl import (
5
+ _custom_op_with_schema,
6
+ _find_custom_op,
7
+ infer_schema,
8
+ parse_qualname,
9
+ validate_namespace,
10
+ )
11
+ from torch.library import get_ctx
12
+
13
+
14
+ __all__ = [
15
+ "custom_op",
16
+ "impl",
17
+ "impl_abstract",
18
+ "get_ctx",
19
+ "impl_save_for_backward",
20
+ "impl_backward",
21
+ ]
22
+
23
+
24
+ def custom_op(qualname, func_or_schema=None):
25
+ r"""Register a new custom operator
26
+
27
+ In PyTorch, defining an op (short for "operator") is a two step-process:
28
+ - we need to define the op (by providing an operator name and schema)
29
+ - we need to implement behavior for how the operator interacts with
30
+ various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
31
+
32
+ This entrypoint defines the custom operator (the first step)
33
+ you must then perform the second step by calling various
34
+ ``impl_*`` APIs.
35
+
36
+ This API may be used as a decorator (see examples).
37
+
38
+ For a detailed guide on custom ops, please see
39
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
40
+
41
+ Arguments:
42
+ qualname (str): Should be a string that looks like
43
+ "namespace::operator_name". Operators in PyTorch need a namespace to
44
+ avoid name collisions; a given operator may only be created once.
45
+ If you are writing a Python library, we recommend the namespace to
46
+ be the name of your top-level module.
47
+ func_or_schema (Union[Callable, str]): Each PyTorch operator needs a
48
+ schema that tells PyTorch the types of the inputs/outputs.
49
+ If this is a Callable, we will automatically infer the schema from
50
+ the type annotations on the function (see examples). Otherwise,
51
+ if you don't want to use type annotations, you may provide us the
52
+ schema string.
53
+
54
+ Example::
55
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
56
+ >>> import torch
57
+ >>> import numpy as np
58
+ >>> from torch import Tensor
59
+ >>>
60
+ >>> # Step 1: define the custom op.
61
+ >>> # We need to provide the API a "prototype function"
62
+ >>> # (a function that returns NotImplementedError), from which
63
+ >>> # we will infer the types of the inputs and outputs.
64
+ >>> @torch._custom_ops.custom_op("mylibrary::numpy_sin")
65
+ >>> def numpy_sin(x: Tensor) -> Tensor:
66
+ >>> raise NotImplementedError
67
+ >>>
68
+ >>> # The custom op is now accessible via the torch.ops module:
69
+ >>> torch.ops.mylibrary.numpy_sin
70
+ >>>
71
+ >>> # Step 2: Register an implementation for various PyTorch subsystems
72
+ >>>
73
+ >>> # Register an implementation for CPU tensors
74
+ >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cpu")
75
+ >>> def numpy_sin_impl_cpu(x):
76
+ >>> return torch.from_numpy(np.sin(x.numpy()))
77
+ >>>
78
+ >>> # Register an implementation for CUDA tensors
79
+ >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cuda")
80
+ >>> def numpy_sin_impl_cuda(x):
81
+ >>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
82
+ >>>
83
+ >>> x = torch.randn(3)
84
+ >>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cpu
85
+ >>>
86
+ >>> x_cuda = x.cuda()
87
+ >>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cuda
88
+
89
+ """
90
+ ns, name = parse_qualname(qualname)
91
+ validate_namespace(ns)
92
+
93
+ def inner(func):
94
+ if not inspect.isfunction(func):
95
+ raise ValueError(
96
+ f"custom_op(...)(func): Expected `func` to be a Python "
97
+ f"function, got: {type(func)}"
98
+ )
99
+
100
+ if func.__name__ != name:
101
+ raise ValueError(
102
+ f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
103
+ f"to have name '{name}' but got '{func.__name__}'. "
104
+ f"Please either change the name of `func` or the qualname that "
105
+ f"is passed to `custom_op`"
106
+ )
107
+
108
+ schema = infer_schema(func, mutates_args=())
109
+ _custom_op_with_schema(qualname, schema)
110
+ return func
111
+
112
+ if func_or_schema is None:
113
+ return inner
114
+ if isinstance(func_or_schema, str):
115
+ _custom_op_with_schema(qualname, func_or_schema)
116
+ else:
117
+ return inner(func_or_schema)
118
+
119
+
120
+ def impl(qualname, *, device_types=("cpu", "cuda"), func=None):
121
+ r"""Register an implementation for a device type for this custom op.
122
+
123
+ If the op is passed multiple Tensor inputs with different device
124
+ types, it will dispatch to the registered implementation for the highest
125
+ priority device type among those present.
126
+ The supported device types, in order of priority, are {'cuda', 'cpu'}.
127
+
128
+ This API may be used as a decorator (see examples).
129
+
130
+ For a detailed guide on custom ops, please see
131
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
132
+
133
+ Arguments:
134
+ device_types (str or Iterable[str]): the device type(s) to register the function for.
135
+
136
+ Example::
137
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
138
+ >>> import torch
139
+ >>> import numpy as np
140
+ >>> from torch import Tensor
141
+ >>>
142
+ >>> # Step 1: define the custom op.
143
+ >>> # We need to provide the API a "prototype function"
144
+ >>> # (a function that returns NotImplementedError), from which
145
+ >>> # we will infer the types of the inputs and outputs.
146
+ >>> @torch._custom_ops.custom_op("mylibrary::numpy_cos")
147
+ >>> def numpy_cos(x: Tensor) -> Tensor:
148
+ >>> raise NotImplementedError
149
+ >>>
150
+ >>> # The custom op is now accessible via the torch.ops module:
151
+ >>> torch.ops.mylibrary.numpy_cos
152
+ >>>
153
+ >>> # Step 2: Register an implementation for various PyTorch subsystems
154
+ >>>
155
+ >>> # Register an implementation for CPU tensors
156
+ >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cpu")
157
+ >>> def numpy_cos_impl_cpu(x):
158
+ >>> return torch.from_numpy(np.cos(x.numpy()))
159
+ >>>
160
+ >>> # Register an implementation for CUDA tensors
161
+ >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cuda")
162
+ >>> def numpy_cos_impl_cuda(x):
163
+ >>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
164
+ >>>
165
+ >>> x = torch.randn(3)
166
+ >>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cpu
167
+ >>>
168
+ >>> x_cuda = x.cuda()
169
+ >>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cuda
170
+
171
+ """
172
+
173
+ def inner(func):
174
+ custom_op = _find_custom_op(qualname, also_check_torch_library=True)
175
+ custom_op.impl(device_types, _stacklevel=3)(func)
176
+ return func
177
+
178
+ if func is None:
179
+ return inner
180
+ return inner(func)
181
+
182
+
183
+ def impl_abstract(qualname, *, func=None):
184
+ r"""Register an abstract implementation for this operator.
185
+
186
+ An "abstract implementation" specifies the behavior of this operator on
187
+ Tensors that carry no data. Given some input Tensors with certain properties
188
+ (sizes/strides/storage_offset/device), it specifies what the properties of
189
+ the output Tensors are.
190
+
191
+ The abstract implementation has the same signature as the operator.
192
+ It is run for both FakeTensors and meta tensors. To write an abstract
193
+ implementation, assume that all Tensor inputs to the operator are
194
+ regular CPU/CUDA/Meta tensors, but they do not have storage, and
195
+ you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
196
+ The abstract implementation must consist of only PyTorch operations
197
+ (and may not directly access the storage or data of any input or
198
+ intermediate Tensors).
199
+
200
+ This API may be used as a decorator (see examples).
201
+
202
+ For a detailed guide on custom ops, please see
203
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
204
+
205
+ Examples::
206
+ >>> import numpy as np
207
+ >>> from torch import Tensor
208
+ >>>
209
+ >>> # Example 1: an operator without data-dependent output shape
210
+ >>> @torch._custom_ops.custom_op("mylibrary::custom_linear")
211
+ >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
212
+ >>> raise NotImplementedError
213
+ >>>
214
+ >>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear")
215
+ >>> def custom_linear_abstract(x, weight):
216
+ >>> assert x.dim() == 2
217
+ >>> assert weight.dim() == 2
218
+ >>> assert bias.dim() == 1
219
+ >>> assert x.shape[1] == weight.shape[1]
220
+ >>> assert weight.shape[0] == bias.shape[0]
221
+ >>> assert x.device == weight.device
222
+ >>>
223
+ >>> return (x @ weight.t()) + bias
224
+ >>>
225
+ >>> # Example 2: an operator with data-dependent output shape
226
+ >>> @torch._custom_ops.custom_op('mylibrary::custom_nonzero')
227
+ >>> def custom_nonzero(x: Tensor) -> Tensor:
228
+ >>> ...
229
+ >>>
230
+ >>> @torch._custom_ops.impl_abstract("mylibrary::custom_nonzero")
231
+ >>> def custom_nonzero_abstract(x):
232
+ >>> # Number of nonzero-elements is data-dependent.
233
+ >>> # Since we cannot peek at the data in an abstract impl,
234
+ >>> # we use the ctx object to construct a new symint that
235
+ >>> # represents the data-dependent size.
236
+ >>> ctx = torch._custom_ops.get_ctx()
237
+ >>> nnz = ctx.create_unbacked_symint()
238
+ >>> shape = [x.dim(), nnz]
239
+ >>> result = x.new_empty(shape, dtype=torch.long)
240
+ >>> return result
241
+ >>>
242
+ >>> @torch._custom_ops.impl("mylibrary::custom_nonzero")
243
+ >>> def custom_nonzero_impl(x):
244
+ >>> x_np = to_numpy(x)
245
+ >>> res = np.stack(np.nonzero(x_np), axis=1)
246
+ >>> # unbacked symbolic ints in PyTorch must be >= 2, so we
247
+ >>> # constrain the range to at least 2
248
+ >>> if res.shape[0] <= 1:
249
+ >>> raise RuntimeError("not supported")
250
+ >>> return torch.tensor(res, device=x.device)
251
+
252
+ """
253
+ import torch.library
254
+
255
+ return torch.library.register_fake(qualname, func, _stacklevel=2)
256
+
257
+
258
+ def impl_save_for_backward(qualname, *, func=None):
259
+ r"""Register a function that tells us what to save for backward.
260
+
261
+ Please see :func:`impl_backward` for more details.
262
+ """
263
+
264
+ def inner(func):
265
+ custom_op = _find_custom_op(qualname, also_check_torch_library=True)
266
+ custom_op.impl_save_for_backward(_stacklevel=3)(func)
267
+ return func
268
+
269
+ if func is None:
270
+ return inner
271
+ return inner(func)
272
+
273
+
274
+ def impl_backward(qualname, output_differentiability=None, *, func=None):
275
+ r"""Registers a backward formula for an operator.
276
+
277
+ In order for an operator to work with autograd, you need to register
278
+ a backward formula. There are two pieces to this:
279
+ 1. You must give us a function to specify what to save for backward.
280
+ Call this the "save for backward" function.
281
+ 2. You must give us a function that computes gradients. Call this the
282
+ "backward" function.
283
+
284
+ Use `impl_save_for_backward` to define a "save for backward" function
285
+ that specifies what gets saved for backward. The function should accept
286
+ two arguments ``(inputs, output)`` and return the quantities to be saved
287
+ for backward.
288
+
289
+ During runtime, when you call the operator in a forwards pass, PyTorch
290
+ will invoke the "save for backward" function with the inputs and output
291
+ of the operator.
292
+
293
+ Use `impl_backward` to define the "backward" function. The backward
294
+ function must accept ``(ctx, saved, *grads)``:
295
+ - ``ctx`` is a context object where we may provide information
296
+ - ``saved`` is exactly what gets returned from the "save for backward"
297
+ function
298
+ - ``grads`` is one or more gradients. The number of gradients matches
299
+ the number of outputs of the operator.
300
+
301
+ The backward function must return a dict that maps the name of
302
+ an input to the operator to its corresponding gradient. All inputs that
303
+ were declared to be Tensors in the operator definition must be accounted
304
+ for in the dict. The gradient may be a Tensor or None.
305
+
306
+ For a detailed guide on custom ops, please see
307
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
308
+
309
+ """
310
+
311
+ def inner(func):
312
+ custom_op = _find_custom_op(qualname, also_check_torch_library=True)
313
+ custom_op.impl_backward(output_differentiability, _stacklevel=3)(func)
314
+ return func
315
+
316
+ if func is None:
317
+ return inner
318
+ return inner(func)
319
+
320
+
321
+ def _destroy(qualname):
322
+ """De-registers a custom op. For testing purposes only"""
323
+ custom_op = _find_custom_op(qualname)
324
+ custom_op._destroy()
.venv/lib/python3.11/site-packages/torch/_deploy.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import io
3
+
4
+ import torch
5
+ from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
6
+ from torch.package._package_pickler import create_pickler
7
+ from torch.package._package_unpickler import PackageUnpickler
8
+ from torch.serialization import _maybe_decode_ascii
9
+
10
+
11
+ def _save_storages(importer, obj):
12
+ serialized_storages = []
13
+ serialized_dtypes = []
14
+
15
+ importer = importer if isinstance(importer, torch.package.PackageImporter) else None
16
+ importers: Importer
17
+ if importer is not None:
18
+ importers = OrderedImporter(importer, sys_importer)
19
+ else:
20
+ importers = sys_importer
21
+
22
+ def persistent_id(obj):
23
+ if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
24
+ if isinstance(obj, torch.storage.TypedStorage):
25
+ # TODO: Once we decide to break serialization FC, we can
26
+ # remove this case
27
+ dtype = obj.dtype
28
+ else:
29
+ dtype = torch.uint8
30
+
31
+ serialized_storages.append(obj)
32
+ serialized_dtypes.append(dtype)
33
+ return ("storage", len(serialized_storages) - 1)
34
+
35
+ if hasattr(obj, "__reduce_deploy__"):
36
+ if _serialized_reduces.get(id(obj)) is None:
37
+ _serialized_reduces[id(obj)] = (
38
+ "reduce_deploy",
39
+ id(obj),
40
+ *obj.__reduce_deploy__(importers),
41
+ )
42
+ return _serialized_reduces[id(obj)]
43
+
44
+ return None
45
+
46
+ # Write the pickle data for `obj`
47
+ data_buf = io.BytesIO()
48
+ pickler = create_pickler(data_buf, importers)
49
+ pickler.persistent_id = persistent_id
50
+ pickler.dump(obj)
51
+ data_value = data_buf.getvalue()
52
+ return (
53
+ data_value,
54
+ serialized_storages,
55
+ serialized_dtypes,
56
+ importer.zip_reader if importer else None,
57
+ )
58
+
59
+
60
+ def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
61
+ def persistent_load(saved_id):
62
+ assert isinstance(saved_id, tuple)
63
+ typename = _maybe_decode_ascii(saved_id[0])
64
+ data = saved_id[1:]
65
+
66
+ if typename == "storage":
67
+ # TODO: Once we decide to break serialization FC, we can
68
+ # stop wrapping with TypedStorage
69
+ storage = serialized_storages[data[0]]
70
+ dtype = serialized_dtypes[data[0]]
71
+ return torch.storage.TypedStorage(
72
+ wrap_storage=storage.untyped(), dtype=dtype
73
+ )
74
+
75
+ if typename == "reduce_deploy":
76
+ reduce_id, func, args = data
77
+ if reduce_id not in _loaded_reduces:
78
+ _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
79
+ return _loaded_reduces[reduce_id]
80
+
81
+ return None
82
+
83
+ importer: Importer
84
+ if zip_reader is not None:
85
+ importer = OrderedImporter(_get_package(zip_reader), sys_importer)
86
+ else:
87
+ importer = sys_importer
88
+
89
+ unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
90
+ unpickler.persistent_load = persistent_load # type: ignore[method-assign]
91
+ result = _deploy_objects[id] = unpickler.load()
92
+ return result
93
+
94
+
95
+ def _get_package(zip_reader):
96
+ if zip_reader not in _raw_packages:
97
+ _raw_packages[zip_reader] = PackageImporter(zip_reader)
98
+ return _raw_packages[zip_reader]
99
+
100
+
101
+ _raw_packages: dict = {}
102
+ _deploy_objects: dict = {}
103
+ _serialized_reduces: dict = {}
104
+ _loaded_reduces: dict = {}
.venv/lib/python3.11/site-packages/torch/_guards.py ADDED
@@ -0,0 +1,925 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from __future__ import annotations
3
+
4
+ import contextlib
5
+ import dataclasses
6
+ import enum
7
+ import functools
8
+ import logging
9
+ import threading
10
+ import traceback
11
+ import unittest.mock
12
+ import weakref
13
+ from abc import abstractmethod
14
+ from contextlib import contextmanager
15
+ from typing import (
16
+ Any,
17
+ Callable,
18
+ Dict,
19
+ Generic,
20
+ List,
21
+ NamedTuple,
22
+ Optional,
23
+ Set,
24
+ Tuple,
25
+ TYPE_CHECKING,
26
+ TypeVar,
27
+ )
28
+
29
+ from torch._C._dynamo.eval_frame import set_context_frame # noqa: F401
30
+ from torch.utils import _pytree as pytree
31
+ from torch.utils._traceback import CapturedTraceback
32
+ from torch.utils.weak import WeakTensorKeyDictionary
33
+
34
+
35
+ log = logging.getLogger(__name__)
36
+
37
+
38
+ if TYPE_CHECKING:
39
+ import sympy
40
+
41
+ # Import the following modules during type checking to enable code intelligence features,
42
+ # such as auto-completion in tools like pylance, even when these modules are not explicitly
43
+ # imported in user code.
44
+ import torch
45
+
46
+
47
+ """
48
+ torch._guards is the definitional source of truth for general purpose guard structures.
49
+
50
+ An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions,
51
+ and no guard installation notions here.
52
+ """
53
+
54
+
55
+ class CompileId(NamedTuple):
56
+ frame_id: int
57
+ # This id is per-frame, and counts how many times we've compiled this
58
+ # frame. This could have been a global id but having this be per-frame
59
+ # gives you a better intuitive sense for how many recompiles have occurred
60
+ # so far.
61
+ frame_compile_id: int
62
+ # TODO: consider also tracking the recompilation count
63
+
64
+ def __str__(self):
65
+ return f"{self.frame_id}/{self.frame_compile_id}"
66
+
67
+
68
+ class TraceId(NamedTuple):
69
+ compile_id: CompileId
70
+ # This starts off as 0, and every time we restart analysis it goes
71
+ # up by one
72
+ attempt: int
73
+
74
+ def __str__(self):
75
+ if self.attempt == 0:
76
+ return str(self.compile_id)
77
+ else:
78
+ return f"{self.compile_id}_{self.attempt}"
79
+
80
+
81
+ class GuardSource(enum.Enum):
82
+ LOCAL = 0
83
+ GLOBAL = 1
84
+ LOCAL_SPECIALIZED_NN_MODULE = 2
85
+ GLOBAL_SPECIALIZED_NN_MODULE = 3
86
+ CONSTANT = 4
87
+ RANDOM_VALUE = 5
88
+ SHAPE_ENV = 6
89
+ LOCAL_FSDP_MODULE = 7
90
+ GLOBAL_FSDP_MODULE = 8
91
+ BACKWARD_STATE = 9
92
+ EPHEMERAL = 10
93
+ SYNTHETIC_LOCAL = 11
94
+ LOCAL_UNSPECIALIZED_NN_MODULE = 12
95
+ GLOBAL_UNSPECIALIZED_NN_MODULE = 13
96
+ LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 14
97
+ GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 15
98
+
99
+ def is_fsdp_module(self) -> bool:
100
+ return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)
101
+
102
+ def is_specialized_nn_module(self) -> bool:
103
+ return (
104
+ self
105
+ in (
106
+ GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
107
+ GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
108
+ )
109
+ # TODO (anijain2305) - Investigate why is_fsdp_module required.
110
+ or self.is_fsdp_module()
111
+ )
112
+
113
+ def is_unspecialized_nn_module(self) -> bool:
114
+ return self in (
115
+ GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
116
+ GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
117
+ GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
118
+ GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
119
+ )
120
+
121
+ def is_unspecialized_builtin_nn_module(self) -> bool:
122
+ return self in (
123
+ GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
124
+ GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
125
+ )
126
+
127
+ def is_local(self):
128
+ return self in (
129
+ GuardSource.LOCAL,
130
+ GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
131
+ GuardSource.LOCAL_FSDP_MODULE,
132
+ GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
133
+ GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
134
+ )
135
+
136
+
137
+ """
138
+ Base class for a "GuardBuilder" role.
139
+
140
+ The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little
141
+ confusing, as its not a builder, but for the sake of avoiding a lot of renames and keeping the original reference
142
+ to torchdynamo's GuardBuilder.
143
+
144
+ Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based
145
+ on GuardSource's select function.
146
+
147
+ There is value in keeping this GuardBuilderBase empty to keep layering clean.
148
+ """
149
+
150
+
151
+ class GuardBuilderBase:
152
+ pass
153
+
154
+
155
+ class ShapeGuard(NamedTuple):
156
+ expr: sympy.Expr
157
+ stack: CapturedTraceback
158
+
159
+
160
+ @dataclasses.dataclass
161
+ class Guard:
162
+ # originating_source is the source that called the make_guard method to
163
+ # construct this guard object. The property name specifies what exactly it
164
+ # is the guard is guarding on. The meaning of the name is dependent on the
165
+ # create_fn; you must look at the use-site inside create_fn to know what
166
+ # name means.
167
+ #
168
+ # That being said, although you might think this is just a "name", name is
169
+ # usually an arbitrary Python expression that will be evaluated with all
170
+ # globals (and locals, if you create a LOCAL guard) to extract the Python
171
+ # object that we want to perform guard tests on. This evaluation
172
+ # typically happens in GuardBuilder.eval. In these cases, name is
173
+ # typically produced by originating_source.name() (not to be confused with
174
+ # GuardSource - the property source).
175
+ #
176
+ # Occasionally, name is not a valid Python expression; sometimes
177
+ # it is meaningless. Example create_fns that are like this include
178
+ # GRAD_MODE and SHAPE_ENV.
179
+ originating_source: Source
180
+ create_fn: Callable[[GuardBuilderBase, Guard], None]
181
+
182
+ # Export only. These values are written to at time of guard check_fn creation.
183
+ guard_types: Optional[List[str]] = None
184
+ code_list: Optional[List[str]] = None
185
+ obj_weakref: Optional[object] = None
186
+ guarded_class_weakref: Optional[type] = None
187
+
188
+ stack: Optional[CapturedTraceback] = None
189
+ user_stack: Optional[traceback.StackSummary] = None
190
+ _hash: Optional[int] = None
191
+
192
+ def __hash__(self):
193
+ if self._hash is None:
194
+ self._hash = hash((self.name, self.source, id(self.create_fn)))
195
+ return self._hash
196
+
197
+ def sort_key(self):
198
+ # Put the duplicate input guards at the end. The duplicate guards have
199
+ # two sources while guard.name only considers one source.
200
+ from torch._dynamo.guards import GuardBuilder
201
+
202
+ is_duplicate_input = (
203
+ isinstance(self.create_fn, functools.partial)
204
+ and self.create_fn.func is GuardBuilder.DUPLICATE_INPUT
205
+ )
206
+ return (
207
+ is_duplicate_input,
208
+ self.source.value if self.source else -1,
209
+ len(self.name),
210
+ self.name,
211
+ self.inner_create_fn().__code__.co_firstlineno,
212
+ )
213
+
214
+ def __lt__(self, other):
215
+ return self.sort_key() < other.sort_key()
216
+
217
+ def inner_create_fn(self):
218
+ if isinstance(self.create_fn, functools.partial):
219
+ return self.create_fn.func
220
+ else:
221
+ return self.create_fn
222
+
223
+ @property
224
+ def name(self) -> str:
225
+ return self.originating_source.name()
226
+
227
+ @property
228
+ def source(self) -> GuardSource:
229
+ return self.originating_source.guard_source()
230
+
231
+ @staticmethod
232
+ def weakref_to_str(obj_weakref):
233
+ """
234
+ This is a workaround of a Python weakref bug.
235
+
236
+ `obj_weakref` is instance returned by `weakref.ref`,
237
+ `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g:
238
+
239
+ class MyConfig(dict):
240
+ def __getattr__(self, x):
241
+ return self[x]
242
+
243
+ obj = MyConfig(offset=5)
244
+ obj_weakref = weakref.ref(obj)
245
+ str(obj_weakref) # raise error: KeyError: '__name__'
246
+ """
247
+ if isinstance(obj_weakref, weakref.ReferenceType):
248
+ obj = obj_weakref()
249
+ if obj is not None:
250
+ return f"<weakref at {hex(id(obj_weakref))}; to '{obj.__class__.__name__}' at {hex(id(obj))}>"
251
+ else:
252
+ return f"<weakref at {hex(id(obj_weakref))}; dead>"
253
+ else:
254
+ return str(obj_weakref)
255
+
256
+ def __repr__(self):
257
+ s = f"""
258
+ {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
259
+ {{
260
+ 'guard_types': {self.guard_types},
261
+ 'code': {self.code_list},
262
+ 'obj_weakref': {self.weakref_to_str(self.obj_weakref)}
263
+ 'guarded_class': {self.guarded_class_weakref}
264
+ }}
265
+ """
266
+ return s
267
+
268
+ def __str__(self):
269
+ output = f"Name: {repr(self.name)}\n"
270
+ source = self.source.name.lower() if self.source else ""
271
+ output += f" Source: {source}\n"
272
+ output += f" Create Function: {self.inner_create_fn().__name__}\n"
273
+ output += f" Guard Types: {self.guard_types}\n"
274
+ output += f" Code List: {self.code_list}\n"
275
+ output += f" Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n"
276
+ output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n"
277
+ return output
278
+
279
+ def create(self, builder: GuardBuilderBase):
280
+ try:
281
+ return self.create_fn(builder, self)
282
+ except Exception:
283
+ log.exception("Error while creating guard:\n%s", str(self).rstrip())
284
+ if self.stack:
285
+ log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
286
+ raise
287
+
288
+ def is_specialized_nn_module(self):
289
+ return self.source.is_specialized_nn_module()
290
+
291
+ def is_fsdp_module(self):
292
+ return self.source.is_fsdp_module()
293
+
294
+ def is_local(self):
295
+ return self.source.is_local()
296
+
297
+ def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref):
298
+ if not self.guard_types:
299
+ self.guard_types = []
300
+
301
+ self.guard_types.append(guard_type)
302
+
303
+ assert self.guarded_class_weakref in (
304
+ guarded_class,
305
+ None,
306
+ ), "Guarded class id must be identical, or None"
307
+ self.guarded_class_weakref = guarded_class
308
+
309
+ if not self.code_list:
310
+ self.code_list = code_list
311
+ else:
312
+ self.code_list.extend(code_list)
313
+
314
+ # Some objects are ephemeral, e.g., list[slice(1, 2)]. If we have
315
+ # multiple guards on the same object, the weakref can die between the
316
+ # invocation of set_export_info calls. So a dead weakref is also
317
+ # acceptable.
318
+ assert (
319
+ self.obj_weakref in (obj_weakref, None)
320
+ or callable(self.obj_weakref)
321
+ and self.obj_weakref() is None
322
+ ), "Guarded object must be identical, None or ephemeral (dead weakref)"
323
+ self.obj_weakref = obj_weakref
324
+
325
+
326
+ T = TypeVar("T")
327
+
328
+ """
329
+ Parent structure for guard env expressions.
330
+ A GuardEnvExpr can have any subtype.
331
+ Note: All subtypes must be handled exhaustively in
332
+ torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
333
+ """
334
+
335
+
336
+ @dataclasses.dataclass
337
+ class GuardEnvExpr:
338
+ pass
339
+
340
+
341
+ """
342
+ A class representing a pair of duplicate inputs.
343
+ input_pos_a and input_pos_b are input positions we have deduped.
344
+ """
345
+
346
+
347
+ @dataclasses.dataclass
348
+ class DuplicateInputs(GuardEnvExpr):
349
+ input_source_a: Source
350
+ input_source_b: Source
351
+
352
+ def __post_init__(self):
353
+ assert self.input_source_a != self.input_source_b
354
+
355
+
356
+ """
357
+ Checkpointable is an interface for driving state snapshotting, left purposely vague for now.
358
+
359
+ copy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that
360
+ can also be taken in at restore_graphstate(T) calls.
361
+
362
+ When to snapshot, is, at the moment, an implementation detail of upstream callers. Checkpointable
363
+ does not provide any garuantees around consistency, idempotency, or safety of calling its APIs, yet.
364
+
365
+ In the future, it will have a closer coupling to a generic Checkpoint management system.
366
+ """
367
+
368
+
369
+ class Checkpointable(Generic[T]):
370
+ @abstractmethod
371
+ def copy_graphstate(self) -> T: ...
372
+
373
+ @abstractmethod
374
+ def restore_graphstate(self, state: T): ...
375
+
376
+
377
+ class GuardsCheckpointState:
378
+ """
379
+ The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext
380
+ """
381
+
382
+ dynamo_guards: Set[Guard] = set()
383
+
384
+ def __init__(self, dynamo_guards):
385
+ self.dynamo_guards = dynamo_guards
386
+
387
+ def diff(self, other):
388
+ """
389
+ Produces a delta against another GuardsCheckpointState.
390
+
391
+ Returns None if no delta is found, otherwise, return a set() of mismatched
392
+ Guard type objects.
393
+ """
394
+ r = self.dynamo_guards.difference(other.dynamo_guards)
395
+ if len(r) == 0:
396
+ return None
397
+ return r
398
+
399
+ def __eq__(self, other):
400
+ return self.diff(other) is None
401
+
402
+
403
+ class ModuleContextCheckpointState:
404
+ nn_modules: Dict[str, torch.nn.Module] = {}
405
+
406
+ def __init__(self, nn_modules):
407
+ self.nn_modules = nn_modules
408
+
409
+ def diff(self, other):
410
+ """
411
+ Produces a delta against another ModuleContextCheckpointState.
412
+
413
+ Returns None if no delta is found, otherwise, return a set() of mismatched
414
+ module key names.
415
+ """
416
+ r = set(self.nn_modules.keys()).difference(set(other.nn_modules.keys()))
417
+ if len(r) == 0:
418
+ return None
419
+ return r
420
+
421
+ def __eq__(self, other):
422
+ return self.diff(other) is None
423
+
424
+
425
+ class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
426
+ def __init__(self) -> None:
427
+ self.nn_modules: Dict[str, Any] = {}
428
+
429
+ def copy_graphstate(self):
430
+ return ModuleContextCheckpointState(dict(self.nn_modules))
431
+
432
+ def restore_graphstate(self, state):
433
+ assert isinstance(state, ModuleContextCheckpointState)
434
+ self.nn_modules = state.nn_modules
435
+
436
+
437
+ class GlobalContextCheckpointState:
438
+ global_state: Dict[str, Tuple[Callable, ...]] = {}
439
+
440
+ def __init__(self, global_states):
441
+ self.global_state = global_states
442
+
443
+ def diff(self, other):
444
+ """
445
+ Produces a delta against another GlobalContextCheckpointState.
446
+
447
+ Returns None if no delta is found, otherwise, return a set() of mismatched
448
+ global key names.
449
+ """
450
+ r = set(self.global_state.keys()).difference(set(other.global_state.keys()))
451
+ if len(r) == 0:
452
+ return None
453
+ return r
454
+
455
+ def __eq__(self, other):
456
+ return self.diff(other) is None
457
+
458
+
459
+ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
460
+ """
461
+ This keeps track of the global torch state during tracing of a function.
462
+ For example, torch.is_grad_enabled.
463
+ """
464
+
465
+ _supported_global_states = {
466
+ "grad_enabled",
467
+ "torch_function_enabled",
468
+ "autocast_enabled",
469
+ "autocast_cpu_enabled",
470
+ "autocast_gpu_dtype",
471
+ "autocast_cpu_dtype",
472
+ "autocast_cache_enabled",
473
+ }
474
+
475
+ def __init__(self) -> None:
476
+ self.global_state: Dict[str, Tuple[Callable, ...]] = {}
477
+
478
+ def copy_graphstate(self):
479
+ return GlobalContextCheckpointState(dict(self.global_state))
480
+
481
+ def restore_graphstate(self, state):
482
+ assert isinstance(state, GlobalContextCheckpointState)
483
+ self.global_state = state.global_state
484
+ assert (
485
+ len(self.global_state) == len(self._supported_global_states)
486
+ and set(self.global_state.keys()) == self._supported_global_states
487
+ ), "Global state mismatch"
488
+ for func, args in self.global_state.values():
489
+ func(args)
490
+
491
+
492
+ """
493
+ A GuardsContext is a checkpointable representation of all the guards in the current tracing
494
+ context. It's lifecycle is bound 1:1 to the tracing context, and it should never be instantiated
495
+ directly outside of it. For passing around internal state representations of this object,
496
+ prefer to extract them with copy_graphstate to produce a GuardsCheckpointState.
497
+ """
498
+
499
+
500
+ # Like a Set[Guard] but will record the user stack on all guards at the
501
+ # time they were installed at their destination
502
+ class GuardsSet:
503
+ def __init__(self, inner=None):
504
+ if inner is None:
505
+ inner = set()
506
+ self.inner = inner
507
+
508
+ def __iter__(self):
509
+ return iter(self.inner)
510
+
511
+ def __len__(self):
512
+ return len(self.inner)
513
+
514
+ # Subtraction along with bool is typically used to determine the delta of
515
+ # added guards between checkpoints for higher order ops
516
+ def __sub__(self, other):
517
+ return GuardsSet(self.inner - other.inner)
518
+
519
+ def __bool__(self):
520
+ return bool(self.inner)
521
+
522
+ def add(self, guard: Guard, *, collect_debug_stack=True, skip=0):
523
+ if guard in self.inner:
524
+ return
525
+ if collect_debug_stack:
526
+ if guard.stack is None:
527
+ guard.stack = CapturedTraceback.extract(skip=1 + skip)
528
+ if guard.user_stack is None:
529
+ guard.user_stack = TracingContext.extract_stack()
530
+ self.inner.add(guard)
531
+
532
+ def update(self, *others: Set[Guard]):
533
+ for o in others:
534
+ for g in o:
535
+ self.add(g, skip=1)
536
+
537
+ def remove_guards_with_source(self, source):
538
+ """Delete all guards with a given source"""
539
+ self.inner = {g for g in self.inner if g.originating_source != source}
540
+
541
+
542
+ class GuardsContext(Checkpointable[GuardsCheckpointState]):
543
+ def __init__(self) -> None:
544
+ self.dynamo_guards: GuardsSet = GuardsSet()
545
+ self.aotautograd_guards: List[GuardEnvExpr] = []
546
+
547
+ def copy_graphstate(self):
548
+ return GuardsCheckpointState(set(self.dynamo_guards.inner))
549
+
550
+ def restore_graphstate(self, state):
551
+ # NB: "steals" the passed in state
552
+ assert isinstance(state, GuardsCheckpointState)
553
+ self.dynamo_guards = GuardsSet(state.dynamo_guards)
554
+
555
+
556
+ _TLS = threading.local()
557
+
558
+ """
559
+ TracingContext is the source of truth for all currently accumulated information
560
+ needed to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems
561
+ are open to managing their own TracingContext with that in mind.
562
+
563
+ The purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid
564
+ having to plumb complex subsystems across multiple verticals.
565
+
566
+ Ex: A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor.
567
+ Accessing the current tracing context via
568
+ TracingContext.get() allows users to accumulate their own guards for processing, without needing to know how
569
+ to plumb objects back up to where frame interpretation happened.
570
+
571
+ Note that you can end up with multiple TracingContext for a single compilation
572
+ of a frame, as we reset the TracingContext whenever we restart analysis.
573
+ CompileContext is a more overarching context that encompasses multiple restarts.
574
+ """
575
+
576
+
577
+ class CompileContext:
578
+ @staticmethod
579
+ def get() -> CompileContext:
580
+ assert _TLS.compile_context is not None
581
+ return _TLS.compile_context
582
+
583
+ @staticmethod
584
+ def try_get() -> Optional[CompileContext]:
585
+ return getattr(_TLS, "compile_context", None)
586
+
587
+ def __init__(self, compile_id):
588
+ assert compile_id is None or isinstance(compile_id, CompileId)
589
+ self.compile_id: Optional[CompileId] = compile_id
590
+ self.attempt = 0
591
+
592
+ @staticmethod
593
+ def current_compile_id():
594
+ self = CompileContext.try_get()
595
+ if self is None:
596
+ return None
597
+ return self.compile_id
598
+
599
+ @staticmethod
600
+ def current_trace_id():
601
+ self = CompileContext.try_get()
602
+ if self is None:
603
+ return None
604
+ if self.compile_id is None:
605
+ return None
606
+ return TraceId(self.compile_id, self.attempt)
607
+
608
+
609
+ class TracingContext:
610
+ """
611
+ Provides the currently installed TracingContext, or None.
612
+
613
+ Note that it is a staticmethod, and invocations outside of `with tracing()` (see below), are valid but
614
+ will return None.
615
+ """
616
+
617
+ @staticmethod
618
+ def try_get() -> Optional[TracingContext]:
619
+ return getattr(_TLS, "tracing_context", None)
620
+
621
+ @staticmethod
622
+ def get() -> TracingContext:
623
+ if ctx := TracingContext.try_get():
624
+ return ctx
625
+ raise RuntimeError(
626
+ "TracingContext.get() must be called within an ongoing trace."
627
+ )
628
+
629
+ def __init__(self, fake_mode):
630
+ self.guards_context = GuardsContext()
631
+ self.module_context = ModuleContext()
632
+ self.global_context = GlobalContext()
633
+ self.fake_mode = fake_mode
634
+ self.frame_summary_stack = []
635
+ # This is morally part of frame_summary_stack, but it is kept separate
636
+ # for clarity. As we process a frame, this variable gets updated
637
+ # to keep track of what line we are in the function. We make a
638
+ # function call, this gets cleared and the frame location is pushed
639
+ # to frame_summary_stack (prepping this variable for the inner frame's
640
+ # progress)
641
+ self.loc_in_frame = None
642
+ # this is only set after aot_autograd
643
+ self.fw_metadata = None
644
+ # this is only set after aot_autograd
645
+ self.aot_graph_name = None
646
+ self.params_flat = None
647
+ # this is for extended return calling convention from backend
648
+ # compiler to aot_autograd
649
+ # Per output, what the compiler specified stride of the output is,
650
+ # or None if no stride is known. This is always the HINT, it
651
+ # is never a SymInt (it would be better if it was a SymInt, but
652
+ # I can't conveniently get this from Inductor atm. Also, be
653
+ # careful not to accidentally induce guards on the SymInt if
654
+ # you ever do change this in aot_autograd.py; you should check
655
+ # on permutations preferentially.)
656
+ self.output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
657
+ # When this is True, whenever we encounter an int in Dynamo tracing,
658
+ # we will (1) force unspec it and (2) force it as a size-like unbacked
659
+ # integer. This is currently used when processing certain lists of
660
+ # ints that are known to be size-like and may have 0/1 entries that we
661
+ # must not specialize on.
662
+ self.force_unspec_int_unbacked_size_like = False
663
+ # See note [Tensor Fakification and Symbol Caching]
664
+ self.tensor_to_context = WeakTensorKeyDictionary()
665
+
666
+ # If this true, Aot Autograd will return output Fake Tensors with appropiate
667
+ # meta on the first invocation
668
+ # see note: [Returning Fake Tensors on First AOT Autograd Call]
669
+ self.fakify_first_call = False
670
+
671
+ def clear(self):
672
+ # Look at the note in output_graph.py in function `save_global_state`
673
+ # for the context on clearing global context.
674
+ self.global_context.global_state = {}
675
+
676
+ @staticmethod
677
+ @contextmanager
678
+ def patch(**kwargs):
679
+ prior = {}
680
+ ctx = TracingContext.get()
681
+
682
+ for key in kwargs.keys():
683
+ # KeyError on invalid entry
684
+ prior[key] = getattr(ctx, key)
685
+ for key, val in kwargs.items():
686
+ setattr(ctx, key, val)
687
+ try:
688
+ yield
689
+ finally:
690
+ for key, val in prior.items():
691
+ setattr(ctx, key, val)
692
+
693
+ @staticmethod
694
+ def extract_stack():
695
+ self = TracingContext.try_get()
696
+ if self is None:
697
+ return traceback.StackSummary()
698
+ stack = self.frame_summary_stack
699
+ if self.loc_in_frame is not None:
700
+ stack = stack + [self.loc_in_frame]
701
+ return traceback.StackSummary.from_list(stack)
702
+
703
+ # Call this when you want to call into some code that isn't necessarily
704
+ # associated with the current frame state
705
+ @staticmethod
706
+ @contextlib.contextmanager
707
+ def clear_frame():
708
+ tc = TracingContext.get()
709
+ with unittest.mock.patch.object(
710
+ tc, "frame_summary_stack", []
711
+ ), unittest.mock.patch.object(tc, "loc_in_frame", None):
712
+ try:
713
+ yield
714
+ except Exception as e:
715
+ # Prevent real_stack from getting attached
716
+ #
717
+ # The invariant is that if an Exception as real_stack, we've
718
+ # appropriately attached a user stack and we no longer need to
719
+ # attach anything. Because we cannot conveniently interpose
720
+ # when an exception is thrown, we instead interpose everywhere
721
+ # we set what the user stack is set (using the context
722
+ # manager). However, our compiler stack does "tail calls"
723
+ # (when it calls into user compiler), at which point the
724
+ # parent exception frames would incorrectly attach an
725
+ # incorrect frame.
726
+ #
727
+ # However, if, somehow, someone raised an exception with this
728
+ # scope that had a stack (for example, because they are
729
+ # restoring the user stack state appropriately as they process
730
+ # node by node), we should respect it. Thus, we cannot
731
+ # unconditionally set None.
732
+ if not hasattr(e, "real_stack"):
733
+ e.real_stack = None # type: ignore[attr-defined]
734
+ raise
735
+
736
+ @staticmethod
737
+ @contextlib.contextmanager
738
+ def current_frame(frame_summary):
739
+ # frame_summary can be None to solely take advantage of real_stack
740
+ # attachment to thrown exceptions
741
+ tc = TracingContext.get()
742
+ if frame_summary is not None:
743
+ tc.frame_summary_stack.append(frame_summary)
744
+ old = tc.loc_in_frame
745
+ tc.loc_in_frame = None
746
+ try:
747
+ yield
748
+ except Exception as e:
749
+ if not hasattr(e, "real_stack"):
750
+ e.real_stack = tc.extract_stack() # type: ignore[attr-defined]
751
+ raise
752
+ finally:
753
+ if frame_summary is not None:
754
+ tc.frame_summary_stack.pop()
755
+ tc.loc_in_frame = old
756
+
757
+ @staticmethod
758
+ @contextlib.contextmanager
759
+ def report_output_strides():
760
+ tc = TracingContext.try_get()
761
+ if tc is None:
762
+ yield None
763
+ return
764
+ old_output_strides = tc.output_strides
765
+ tc.output_strides = []
766
+ try:
767
+ yield tc.output_strides
768
+ finally:
769
+ tc.output_strides = old_output_strides
770
+
771
+ @staticmethod
772
+ def set_current_loc(filename, lineno, frame_name):
773
+ TracingContext.get().loc_in_frame = traceback.FrameSummary(
774
+ filename, lineno, frame_name, lookup_line=False
775
+ )
776
+
777
+
778
+ @contextmanager
779
+ def compile_context(context: Optional[CompileContext]):
780
+ old_context = getattr(_TLS, "compile_context", None)
781
+ _TLS.compile_context = context
782
+ try:
783
+ yield context
784
+ finally:
785
+ if context is not None:
786
+ if context.compile_id is not None:
787
+ set_context_frame(
788
+ (
789
+ context.compile_id.frame_id,
790
+ context.compile_id.frame_compile_id,
791
+ context.attempt,
792
+ )
793
+ )
794
+ _TLS.compile_context = old_context
795
+
796
+
797
+ @contextmanager
798
+ def tracing(context: Optional[TracingContext]):
799
+ """
800
+ This function installs the passed in tracing context as a dynamic scoped
801
+ global variable.
802
+
803
+ Calls to TracingContext.get() while not under a `with tracing()` context
804
+ will return None.
805
+ """
806
+ old_context = getattr(_TLS, "tracing_context", None)
807
+ _TLS.tracing_context = context
808
+ try:
809
+ yield context
810
+ except Exception as e:
811
+ if not hasattr(e, "real_stack") and context is not None:
812
+ e.real_stack = context.extract_stack() # type: ignore[attr-defined]
813
+ raise
814
+ finally:
815
+ if (
816
+ context is not None
817
+ and context.fake_mode is not None
818
+ and context.fake_mode.shape_env is not None
819
+ ):
820
+ context.fake_mode.shape_env.cleanup()
821
+ _TLS.tracing_context = old_context
822
+
823
+
824
+ # Subclasses can be found in torch/_dynamo/source.py
825
+ # TODO(voz): Consider a toplevel torch/_source.py
826
+ @dataclasses.dataclass(frozen=True)
827
+ class Source:
828
+ def is_dict_key(self):
829
+ return False
830
+
831
+ def is_ephemeral(self):
832
+ return False
833
+
834
+ def reconstruct(self, codegen):
835
+ raise NotImplementedError
836
+
837
+ def guard_source(self) -> GuardSource:
838
+ raise NotImplementedError
839
+
840
+ def name(self) -> str:
841
+ raise NotImplementedError
842
+
843
+ def make_guard(self, fn) -> Guard:
844
+ if self.guard_source() is GuardSource.CONSTANT:
845
+ raise NotImplementedError
846
+ return Guard(self, fn)
847
+
848
+ def is_specialized_nn_module(self) -> bool:
849
+ return self.guard_source().is_specialized_nn_module()
850
+
851
+ def subguards_allowed(self):
852
+ """True if you can guard on attributes of this"""
853
+ return self.guard_source() != GuardSource.SYNTHETIC_LOCAL
854
+
855
+
856
+ # Subclasses can be found in torch/_dynamo/source.py
857
+ @dataclasses.dataclass(frozen=True)
858
+ class ChainedSource(Source):
859
+ base: Source
860
+
861
+ def is_dict_key(self):
862
+ # Recurse until you either hit a ConstDictKey or a Source
863
+ return self.base.is_dict_key()
864
+
865
+ def is_ephemeral(self):
866
+ return self.base.is_ephemeral()
867
+
868
+
869
+ def detect_fake_mode(inputs: Any = None):
870
+ """
871
+ Attempts to "detect" what the current fake mode is. If there is one ambiently
872
+ available from TracingContext, we preferentially use that. Otherwise, we
873
+ heuristically detect the fake mode via the following sources, in order of
874
+ priority:
875
+
876
+ - Currently active fake mode on stack
877
+ - Fake mode associated with passed in tensors (inputs does not
878
+ have to be flattened)
879
+ """
880
+ from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
881
+
882
+ fake_modes = []
883
+
884
+ if context := TracingContext.try_get():
885
+ fake_mode = context.fake_mode
886
+ if fake_mode is not None:
887
+ fake_modes.append((fake_mode, "tracing context", 0))
888
+
889
+ from torch.utils._python_dispatch import _get_current_dispatch_mode_stack
890
+
891
+ for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
892
+ if isinstance(m, FakeTensorMode):
893
+ fake_modes.append((m, "active fake mode", i))
894
+
895
+ flat_inputs = pytree.tree_leaves(inputs)
896
+ for i, flat_input in enumerate(flat_inputs):
897
+ if isinstance(flat_input, FakeTensor):
898
+ fake_modes.append((flat_input.fake_mode, "fake tensor input", i))
899
+
900
+ if fake_modes:
901
+ fake_mode, desc1, i1 = fake_modes[0]
902
+ for m, desc2, i2 in fake_modes[1:]:
903
+ assert fake_mode is m, (
904
+ f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
905
+ f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
906
+ f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
907
+ )
908
+ return fake_mode
909
+ else:
910
+ return None
911
+
912
+
913
+ def active_fake_mode():
914
+ """
915
+ Inspects the dispatch mode stack for an active fake mode and returns it.
916
+ Returns None if no fake mode is active.
917
+ """
918
+ from torch._subclasses.fake_tensor import FakeTensorMode
919
+ from torch.utils._python_dispatch import _get_current_dispatch_mode_stack
920
+
921
+ for _, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
922
+ if isinstance(m, FakeTensorMode):
923
+ return m
924
+
925
+ return None
.venv/lib/python3.11/site-packages/torch/_jit_internal.py ADDED
@@ -0,0 +1,1547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ """
3
+ The weak_script annotation needs to be here instead of inside torch/jit/ so it
4
+ can be used in other places in torch/ (namely torch.nn) without running into
5
+ circular dependency problems
6
+ """
7
+
8
+ import ast
9
+ import builtins
10
+ import collections
11
+ import contextlib
12
+ import enum
13
+ import inspect
14
+ import io
15
+ import pickle
16
+ import sys
17
+ import textwrap
18
+ import threading
19
+ import types
20
+ import typing
21
+ import warnings
22
+ import weakref
23
+ from typing import (
24
+ Any,
25
+ Callable,
26
+ Dict,
27
+ Final,
28
+ ForwardRef,
29
+ get_args,
30
+ get_origin,
31
+ List,
32
+ Optional,
33
+ Tuple,
34
+ Type,
35
+ Union,
36
+ )
37
+
38
+ import torch
39
+
40
+ # This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`.
41
+ # Explicitly ask to import `torch.distributed.__init__` first.
42
+ # Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
43
+ import torch.distributed.rpc
44
+ import torch.package._mangling as package_mangling
45
+ from torch._awaits import _Await
46
+ from torch._C import _Await as CAwait, Future as CFuture
47
+ from torch._sources import fake_range, get_source_lines_and_file, parse_def
48
+ from torch.futures import Future
49
+
50
+
51
+ IS_PY39_PLUS: Final[bool] = sys.version_info >= (3, 9)
52
+ IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10)
53
+
54
+ BuiltinUnionType: Union[Type, Tuple[Type, ...]]
55
+ if sys.version_info >= (3, 10):
56
+ # NOTE: IS_PY310_PLUS doesn't work with mypy.
57
+ # cf. https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks
58
+ BuiltinUnionType = types.UnionType
59
+ else:
60
+ BuiltinUnionType = () # trick: this makes isinstance short circuit.
61
+
62
+ LockType: Type
63
+ try:
64
+ import _thread
65
+
66
+ LockType = _thread.LockType
67
+ except ImportError:
68
+ import _dummy_thread # type: ignore[import-not-found]
69
+
70
+ LockType = _dummy_thread.LockType
71
+
72
+ # Wrapper functions that can call either of 2 functions depending on a boolean
73
+ # argument
74
+ boolean_dispatched: "weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]" = (
75
+ weakref.WeakKeyDictionary()
76
+ ) # noqa: T484
77
+
78
+
79
+ FAKE_FILENAME_PREFIX = "__torch_jit_dataclass"
80
+
81
+
82
+ def is_final(ann) -> bool:
83
+ return (
84
+ hasattr(ann, "__module__")
85
+ and ann.__module__ in {"typing", "typing_extensions"}
86
+ and (get_origin(ann) is Final or isinstance(ann, type(Final)))
87
+ )
88
+
89
+
90
+ # allows BroadcastingList instance to be subscriptable
91
+ class BroadcastingListCls:
92
+ def __getitem__(self, types):
93
+ return
94
+
95
+
96
+ # mypy doesn't support parameters on types, so we have to explicitly type each
97
+ # list size
98
+ BroadcastingList1 = BroadcastingListCls()
99
+ for i in range(2, 7):
100
+ globals()[f"BroadcastingList{i}"] = BroadcastingList1
101
+
102
+
103
+ def is_scripting() -> bool:
104
+ r"""
105
+ Function that returns True when in compilation and False otherwise. This
106
+ is useful especially with the @unused decorator to leave code in your
107
+ model that is not yet TorchScript compatible.
108
+ .. testcode::
109
+
110
+ import torch
111
+
112
+ @torch.jit.unused
113
+ def unsupported_linear_op(x):
114
+ return x
115
+
116
+ def linear(x):
117
+ if torch.jit.is_scripting():
118
+ return torch.linear(x)
119
+ else:
120
+ return unsupported_linear_op(x)
121
+ """
122
+ return False
123
+
124
+
125
+ # Retrieves a fully-qualified name (module hierarchy + classname) for a given obj.
126
+ def _qualified_name(obj, mangle_name=True) -> str:
127
+ # This special case allows us to override the qualified name on a type.
128
+ # It's currently used in conjunction with tracing, where we create a
129
+ # fake module to filter only supported attributes. However, since this
130
+ # new type is defined as a local class, we need a mechanism to override
131
+ # its qualname so it appears correctly in the TorchScript system. This,
132
+ # we set '_jit_override_qualname' with the original traced module's
133
+ # qualified name, which is picked up here
134
+ if hasattr(obj, "_jit_override_qualname"):
135
+ return obj._jit_override_qualname
136
+ # short-circuit in cases where the object already has a known qualified name
137
+ if isinstance(obj, torch._C.ScriptFunction):
138
+ return obj.qualified_name
139
+
140
+ if getattr(obj, "__name__", None):
141
+ name = obj.__name__
142
+ # Enum classes do not have `__name__` attr, instead they have `name`.
143
+ elif isinstance(obj, enum.Enum):
144
+ name = obj.name
145
+ else:
146
+ raise RuntimeError("Could not get name of python class object")
147
+
148
+ if name == "<lambda>":
149
+ name = "_lambda" # make name a valid identifier
150
+
151
+ module_name = obj.__module__
152
+
153
+ # If the module is actually a torchbind module, then we should short circuit
154
+ if module_name == "torch._classes":
155
+ return obj.qualified_name
156
+
157
+ # The Python docs are very clear that `__module__` can be None, but I can't
158
+ # figure out when it actually would be.
159
+ if module_name is None:
160
+ raise RuntimeError(
161
+ f"Could not get qualified name for class '{name}': "
162
+ "__module__ can't be None."
163
+ )
164
+
165
+ # if getattr(sys.modules[module_name], name) is not obj:
166
+ # raise RuntimeError(f"Could not get qualified name for class '{name}': "
167
+ # f"the attr {name} on module {module_name} is not the class")
168
+
169
+ # torch.package and TorchScript have separate mangling schemes to avoid
170
+ # name collisions from multiple packages. To avoid them interfering with
171
+ # each other, normalize the package manging here.
172
+ if package_mangling.is_mangled(module_name):
173
+ module_name = module_name.replace("<", "_")
174
+ module_name = module_name.replace(">", "_")
175
+
176
+ # The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h
177
+ # does not need mangle the python class name.
178
+ if mangle_name:
179
+ # __main__ is a builtin module, so rewrite it to "__torch__".
180
+ if module_name == "__main__":
181
+ module_name = "__torch__"
182
+ else:
183
+ # Everything else gets a "__torch__" prefix to avoid name collisions
184
+ # with the names of user values.
185
+ module_name = "__torch__." + module_name
186
+
187
+ if "." in name:
188
+ raise RuntimeError(
189
+ f"Could not get qualified name for class '{name}': "
190
+ f"'{name}' is not a valid identifier"
191
+ )
192
+
193
+ return module_name + "." + name
194
+
195
+
196
+ class SourceLoader:
197
+ def __init__(self):
198
+ self.content = {}
199
+
200
+ def cache(self, fn, source):
201
+ self.content[fn] = source
202
+
203
+ def get_source(self, fn):
204
+ return self.content.get(fn)
205
+
206
+
207
+ loader = SourceLoader()
208
+
209
+
210
+ def createResolutionCallbackFromEnv(lookup_base):
211
+ """
212
+ Creates a resolution callback that will look up qualified names in an
213
+ environment, starting with `lookup_base` for the base of any qualified
214
+ names, then proceeding down the lookup chain with the resolved object.
215
+
216
+ You should not use this directly, it should only be used from the other
217
+ createResolutionCallbackFrom* functions.
218
+ """
219
+
220
+ def lookupInModule(qualified_name, module):
221
+ if "." in qualified_name:
222
+ base, remaining_pieces = qualified_name.split(".", maxsplit=1)
223
+ module_value = getattr(module, base)
224
+ return lookupInModule(remaining_pieces, module_value)
225
+ else:
226
+ return getattr(module, qualified_name)
227
+
228
+ def parseNestedExpr(expr, module) -> Tuple[Any, int]:
229
+ i = 0
230
+ while i < len(expr) and expr[i] not in (",", "[", "]"):
231
+ i += 1
232
+
233
+ # Special case logic for the empty Tuple as a subscript (used
234
+ # in the type annotation `Tuple[()]`)
235
+ if expr[:i] == "()":
236
+ return (), i
237
+
238
+ base = lookupInModule(expr[:i].strip(), module)
239
+ assert base is not None, f"Unresolvable type {expr[:i]}"
240
+ if i == len(expr) or expr[i] != "[":
241
+ return base, i
242
+
243
+ assert expr[i] == "["
244
+ parts = []
245
+ while expr[i] != "]":
246
+ part_len = 0
247
+ i += 1
248
+ part, part_len = parseNestedExpr(expr[i:], module)
249
+ parts.append(part)
250
+ i += part_len
251
+ if len(parts) > 1:
252
+ return base[tuple(parts)], i + 1
253
+ else:
254
+ return base[parts[0]], i + 1
255
+
256
+ def parseExpr(expr, module):
257
+ try:
258
+ value, len_parsed = parseNestedExpr(expr, module)
259
+ assert len_parsed == len(
260
+ expr
261
+ ), "whole expression was not parsed, falling back to c++ parser"
262
+ return value
263
+ except Exception:
264
+ """
265
+ The python resolver fails in several cases in known unit tests, and is intended
266
+ to fall back gracefully to the c++ resolver in general. For example, python 2 style
267
+ annotations which are frequent in our unit tests often fail with types e.g. int not
268
+ resolvable from the calling frame.
269
+ """
270
+ return None
271
+
272
+ return lambda expr: parseExpr(expr, lookup_base)
273
+
274
+
275
+ def createResolutionCallbackFromFrame(frames_up: int = 0):
276
+ """
277
+ Creates a function which, given a string variable name,
278
+ returns the value of the variable in the scope of the caller of
279
+ the function which called createResolutionCallbackFromFrame (by default).
280
+
281
+ This is used to enable access in-scope Python variables inside
282
+ TorchScript fragments.
283
+
284
+ frames_up is number of additional frames to go up on the stack.
285
+ The default value is 0, which correspond to the frame of the caller
286
+ of createResolutionCallbackFromFrame. Also for example, if frames_up is set
287
+ to 1, then the frame of the caller's caller of createResolutionCallbackFromFrame
288
+ will be taken.
289
+
290
+ For example, the following program prints 2::
291
+
292
+ def bar():
293
+ cb = createResolutionCallbackFromFrame(1)
294
+ print(cb("foo"))
295
+
296
+
297
+ def baz():
298
+ foo = 2
299
+ bar()
300
+
301
+
302
+ baz()
303
+ """
304
+ frame = inspect.currentframe()
305
+ i = 0
306
+ while i < frames_up + 1:
307
+ assert frame is not None
308
+ frame = frame.f_back
309
+ i += 1
310
+
311
+ assert frame is not None
312
+ f_locals = frame.f_locals
313
+ f_globals = frame.f_globals
314
+
315
+ class env:
316
+ def __getattr__(self, key):
317
+ if key in f_locals:
318
+ return f_locals[key]
319
+ elif key in f_globals:
320
+ return f_globals[key]
321
+ elif key in dir(builtins):
322
+ return getattr(builtins, key)
323
+
324
+ return createResolutionCallbackFromEnv(env())
325
+
326
+
327
+ def get_closure(fn):
328
+ """
329
+ Get a dictionary of closed over variables from a function
330
+ """
331
+ captures = {}
332
+ captures.update(fn.__globals__)
333
+
334
+ for index, captured_name in enumerate(fn.__code__.co_freevars):
335
+ captures[captured_name] = fn.__closure__[index].cell_contents
336
+
337
+ return captures
338
+
339
+
340
+ # [local resolution in python]
341
+ # Depending on where a variable is defined, and where it is used, we may
342
+ # or may not be able to recover its value when recursively compiling a
343
+ # script function. Remember in the general case, a module or function is
344
+ # first defined and then later scripted. This means we do not have a
345
+ # chance to capture the active frames when the function is defined. Hence any
346
+ # name resolution has to happen later on the created closure. The way
347
+ # python captures type annotations restricts what we can recover. The
348
+ # follow example illustrates the different cases:
349
+ #
350
+ # class MyGlobalClass:
351
+ # ...
352
+ # def my_local_scope():
353
+ # @torch.jit.script
354
+ # class MyClass:
355
+ # ...
356
+ # @torch.jit.script
357
+ # class MyClassUsedAsVar:
358
+ # ...
359
+ # def eg(x: MyClass, y: MyGlobalClass):
360
+ # a_local_capture : Foo
361
+ # return MyClassUsedAsVar(x)
362
+ #
363
+ # MyGlobalClass is defined in the __globals__ dictionary of function
364
+ # 'eg', so it is always recoverable. my_local_scope introduces a new local
365
+ # variable scope in the function. Classes defined here are only visible as
366
+ # local variables. For the case of MyClassUsedAsVar, it is captured
367
+ # because it is used as a variable inside the body of the function, and we
368
+ # can resolve it using the captures returned from `get_closure`. However,
369
+ # the type annotations are not captured by the closure. In Python
370
+ # 3.0--3.9, the _value_ of MyClass and MyGlobalClass will be available as
371
+ # annotations on `eg``, but starting in Python 4.0, they will represented as
372
+ # strings and no longer present. Furthermore, since the body of `eg` does
373
+ # not reference those names, they do not appear in the list of closed over
374
+ # variables. In Python 2.x, type annotations are in comments, leading to a
375
+ # similar situation where their definitions are not available. We anticipate
376
+ # that most users will not run into this issue because their modules and
377
+ # functions will be defined at a global scope like MyGlobalClass. In cases
378
+ # where they are not, it is possible to work around issues by declaring the
379
+ # values global in the function.
380
+ # In Python 3.9 declaring class as global will make it invisible to
381
+ # `inspect.getsource`, see https://bugs.python.org/issue42666 .
382
+ # This could be worked around by manualy adding it to `global()` dictionary.
383
+
384
+
385
+ def createResolutionCallbackFromClosure(fn):
386
+ """
387
+ Create a resolutionCallback by introspecting the function instead of
388
+ looking up the stack for the enclosing scope
389
+ """
390
+ closure = get_closure(fn)
391
+
392
+ class closure_lookup:
393
+ # This is a class since `closure` is a dict and it's easier in
394
+ # `env_helper` if everything just works with `getattr` calls
395
+ def __getattr__(self, key):
396
+ if key in closure:
397
+ return closure[key]
398
+ elif hasattr(typing, key):
399
+ return getattr(typing, key)
400
+ elif hasattr(builtins, key):
401
+ return getattr(builtins, key)
402
+ return None
403
+
404
+ return createResolutionCallbackFromEnv(closure_lookup())
405
+
406
+
407
+ def can_compile_class(cls) -> bool:
408
+ # If any of the functions on a type don't have a code object, this type can't
409
+ # be compiled and is probably a builtin / bound from C
410
+ if is_ignored_fn(cls):
411
+ return False
412
+
413
+ # Ignore the following list of built-in classes.
414
+ ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception)
415
+ if issubclass(cls, ignored_builtin_classes):
416
+ return False
417
+
418
+ names = cls.__dict__
419
+ fns = [
420
+ getattr(cls, name)
421
+ for name in names
422
+ if inspect.isroutine(getattr(cls, name, None))
423
+ ]
424
+ has_code = [hasattr(fn, "__code__") for fn in fns]
425
+ return all(has_code)
426
+
427
+
428
+ def get_callable_argument_names(fn) -> List[str]:
429
+ """
430
+ Gets names of all POSITIONAL_OR_KEYWORD arguments for callable `fn`.
431
+ Returns an empty list when other types of arguments are present.
432
+
433
+ This is used by `torch.jit.trace` to assign meaningful argument names to
434
+ traced functions and modules.
435
+
436
+ Args:
437
+ fn: A callable.
438
+ Returns:
439
+ Argument names: List[str]
440
+ """
441
+ # inspect.signature may fail, give up in that case.
442
+ try:
443
+ callable_signature = inspect.signature(fn)
444
+ except Exception:
445
+ return []
446
+
447
+ argument_names = []
448
+ for name, param in callable_signature.parameters.items():
449
+ # All four other types of arguments do not map to individual values
450
+ # with a keyword as name.
451
+ if not param.kind == param.POSITIONAL_OR_KEYWORD:
452
+ continue
453
+
454
+ argument_names.append(name)
455
+
456
+ return argument_names
457
+
458
+
459
+ def get_annotation_str(annotation):
460
+ """
461
+ Convert an AST node containing a type annotation to the string present in the source
462
+ that represents the same annotation.
463
+ """
464
+ if isinstance(annotation, ast.Name):
465
+ return annotation.id
466
+ elif isinstance(annotation, ast.Attribute):
467
+ return ".".join([get_annotation_str(annotation.value), annotation.attr])
468
+ elif isinstance(annotation, ast.Subscript):
469
+ # In Python3.9+ subscript indicies are not wrapped in ast.Index
470
+ subscript_slice = annotation.slice if IS_PY39_PLUS else annotation.slice.value # type: ignore[attr-defined]
471
+ return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]"
472
+ elif isinstance(annotation, ast.Tuple):
473
+ return ",".join([get_annotation_str(elt) for elt in annotation.elts])
474
+ elif isinstance(annotation, ast.Constant):
475
+ return f"{annotation.value}"
476
+
477
+ # If an AST node is not handled here, it's probably handled in ScriptTypeParser.
478
+ return None
479
+
480
+
481
+ def get_type_hint_captures(fn):
482
+ """
483
+ Get a dictionary containing type resolution mappings necessary to resolve types
484
+ for the literal annotations on 'fn'. These are not considered to be closed-over by fn
485
+ and must be obtained separately (e.g. using this function).
486
+
487
+ Args:
488
+ fn: A callable.
489
+ Returns:
490
+ A Dict[str, Any] containing a mapping from the literal annotations used on
491
+ fn to the Python objects they refer to.
492
+ """
493
+ # First, try to get the source of the function. We'll need to parse it to find the actual string names
494
+ # that were used to annotate the types, since inspect.signature() will only return the class object that
495
+ # the annotation refers to, not the string name. If we can't get the source, simply return an empty dict.
496
+ # This may happen in cases where the function is synthesized dynamically at runtime.
497
+ src = loader.get_source(fn)
498
+ if src is None:
499
+ try:
500
+ src = inspect.getsource(fn)
501
+ except OSError as e:
502
+ raise OSError(
503
+ f"Failed to get source for {fn} using inspect.getsource"
504
+ ) from e
505
+
506
+ # Gather a dictionary of parameter name -> type, skipping any parameters whose annotated
507
+ # types are strings. These are only understood by TorchScript in the context of a type annotation
508
+ # that refers to a class in its own definition, but trying to include a mapping for this in the result
509
+ # function would cause infinite recursion because the class is currently being compiled.
510
+ # In addition, there is logic in ScriptTypeParser to handle this.
511
+ signature = inspect.signature(fn)
512
+ name_to_type = {
513
+ name: parameter.annotation
514
+ for name, parameter in signature.parameters.items()
515
+ if parameter.annotation is not inspect.Parameter.empty
516
+ and not isinstance(parameter.annotation, str)
517
+ }
518
+
519
+ # Then, get the literal type annotations from the function declaration
520
+ # by source inspection. This accounts for the case in which aliases are used
521
+ # to annotate the arguments (e.g device_t = torch.device, and then d: device_t).
522
+ # frontend.py cannot be used here because it includes _jit_internal, so use ast instead.
523
+ a = ast.parse(textwrap.dedent(src))
524
+ if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef):
525
+ raise RuntimeError(f"Expected {fn} to be a function")
526
+ f = a.body[0]
527
+
528
+ # Prepare a dictionary of source annotation -> type, which will be the final result of this function,
529
+ # by using the parsed AST (f) to reconstruct source annotations as strings for each parameter and mapping
530
+ # them to the type object corresponding to the annotation via name_to_type using the parameter name.
531
+ annotation_to_type = {}
532
+
533
+ for arg in f.args.args:
534
+ # Get the source type annotation string for this argument if possible.
535
+ arg_annotation_str = (
536
+ get_annotation_str(arg.annotation) if arg.annotation else None
537
+ )
538
+
539
+ # If the argument has no annotation or get_annotation_str cannot convert it to a string,
540
+ # arg_annotation_str will be None. Skip this arg; ScriptTypeParser will probably handle
541
+ # this in the latter case.
542
+ if arg_annotation_str is None:
543
+ continue
544
+
545
+ # Insert {arg_annotation_str: type} into annotation_to_type if possible. One reason arg_name may not
546
+ # be present in name_to_type is that the annotation itself is a string and not a type object
547
+ # (common for self-refential annotations in classes). Once again, let ScriptTypeParser handle this.
548
+ arg_name = arg.arg
549
+ if arg_name in name_to_type:
550
+ annotation_to_type[arg_annotation_str] = name_to_type[arg_name]
551
+
552
+ # If there is a valid return annotation, include it in annotation_to_type. As with argument annotations,
553
+ # the literal annotation has to be convertible to a string by get_annotation_str, and the actual type
554
+ # of the annotation cannot be a string.
555
+ literal_return_annotation = get_annotation_str(f.returns)
556
+ valid_literal_annotation = literal_return_annotation is not None
557
+ return_annotation = signature.return_annotation
558
+ valid_return_annotation_type = (
559
+ return_annotation is not inspect.Parameter.empty
560
+ and not isinstance(return_annotation, str)
561
+ )
562
+ if valid_literal_annotation and valid_return_annotation_type:
563
+ annotation_to_type[literal_return_annotation] = return_annotation
564
+
565
+ return annotation_to_type
566
+
567
+
568
+ def createResolutionCallbackForClassMethods(cls):
569
+ """
570
+ This looks at all the methods defined in a class and pulls their closed-over
571
+ variables into a dictionary and uses that to resolve variables.
572
+ """
573
+ # cls is a type here, so `ismethod` is false since the methods on the type
574
+ # aren't bound to anything, so Python treats them as regular functions
575
+ fns = [
576
+ getattr(cls, name)
577
+ for name in cls.__dict__
578
+ if inspect.isroutine(getattr(cls, name))
579
+ ]
580
+ # Skip built-ins, as they do not have global scope nor type hints
581
+ # Needed to support `enum.Enum` derived classes in Python-3.11
582
+ # That adds `_new_member_` property which is an alias to `__new__`
583
+ fns = [fn for fn in fns if not inspect.isbuiltin(fn) and hasattr(fn, "__globals__")]
584
+ captures = {}
585
+
586
+ for fn in fns:
587
+ captures.update(get_closure(fn))
588
+ captures.update(get_type_hint_captures(fn))
589
+
590
+ def lookup_in_class(key):
591
+ if key in captures:
592
+ return captures[key]
593
+ else:
594
+ return getattr(builtins, key, None)
595
+
596
+ return lookup_in_class
597
+
598
+
599
+ def boolean_dispatch(
600
+ arg_name,
601
+ arg_index,
602
+ default,
603
+ if_true,
604
+ if_false,
605
+ module_name,
606
+ func_name,
607
+ ):
608
+ """
609
+ Dispatches to either of 2 script functions based on a boolean argument.
610
+ In TorchScript, the boolean argument must be constant so that the correct
611
+ function to use can be determined at compile time.
612
+ """
613
+
614
+ def fn(*args, **kwargs):
615
+ dispatch_flag = default
616
+ if arg_name in kwargs:
617
+ dispatch_flag = kwargs[arg_name]
618
+ elif arg_index < len(args):
619
+ dispatch_flag = args[arg_index]
620
+
621
+ if dispatch_flag:
622
+ return if_true(*args, **kwargs)
623
+ else:
624
+ return if_false(*args, **kwargs)
625
+
626
+ if if_true.__doc__ is None and if_false.__doc__ is not None:
627
+ doc = if_false.__doc__
628
+ if_true.__doc__ = doc
629
+ elif if_false.__doc__ is None and if_true.__doc__ is not None:
630
+ doc = if_true.__doc__
631
+ if_false.__doc__ = doc
632
+ elif if_false.__doc__ is None and if_true.__doc__ is None:
633
+ # neither function has a docstring
634
+ doc = None
635
+ else:
636
+ raise RuntimeError("only one function can have a docstring")
637
+ fn.__doc__ = doc
638
+
639
+ if module_name is not None:
640
+ fn.__module__ = module_name
641
+ if func_name is not None:
642
+ fn.__name__ = func_name
643
+
644
+ boolean_dispatched[fn] = {
645
+ "if_true": if_true,
646
+ "if_false": if_false,
647
+ "index": arg_index,
648
+ "default": default,
649
+ "arg_name": arg_name,
650
+ }
651
+ return fn
652
+
653
+
654
+ class FunctionModifiers:
655
+ """
656
+ Used to denote the behavior of a function in TorchScript. See export() and
657
+ ignore() for details.
658
+ """
659
+
660
+ UNUSED = "unused (ignored and replaced with raising of an exception)"
661
+ IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)"
662
+ EXPORT = "export (compile this function even if nothing calls it)"
663
+ DEFAULT = "default (compile if called from a exported function / forward)"
664
+ COPY_TO_SCRIPT_WRAPPER = (
665
+ "if this method is not scripted, copy the python method onto the scripted model"
666
+ )
667
+ _DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
668
+
669
+
670
+ def export(fn):
671
+ """
672
+ This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a
673
+ :class:`ScriptModule` and should be compiled.
674
+
675
+ ``forward`` implicitly is assumed to be an entry point, so it does not need this decorator.
676
+ Functions and methods called from ``forward`` are compiled as they are seen
677
+ by the compiler, so they do not need this decorator either.
678
+
679
+ Example (using ``@torch.jit.export`` on a method):
680
+
681
+ .. testcode::
682
+
683
+ import torch
684
+ import torch.nn as nn
685
+
686
+ class MyModule(nn.Module):
687
+ def implicitly_compiled_method(self, x):
688
+ return x + 99
689
+
690
+ # `forward` is implicitly decorated with `@torch.jit.export`,
691
+ # so adding it here would have no effect
692
+ def forward(self, x):
693
+ return x + 10
694
+
695
+ @torch.jit.export
696
+ def another_forward(self, x):
697
+ # When the compiler sees this call, it will compile
698
+ # `implicitly_compiled_method`
699
+ return self.implicitly_compiled_method(x)
700
+
701
+ def unused_method(self, x):
702
+ return x - 20
703
+
704
+ # `m` will contain compiled methods:
705
+ # `forward`
706
+ # `another_forward`
707
+ # `implicitly_compiled_method`
708
+ # `unused_method` will not be compiled since it was not called from
709
+ # any compiled methods and wasn't decorated with `@torch.jit.export`
710
+ m = torch.jit.script(MyModule())
711
+ """
712
+ fn._torchscript_modifier = FunctionModifiers.EXPORT
713
+ return fn
714
+
715
+
716
+ def unused(fn):
717
+ """
718
+ This decorator indicates to the compiler that a function or method should
719
+ be ignored and replaced with the raising of an exception. This allows you
720
+ to leave code in your model that is not yet TorchScript compatible and still
721
+ export your model.
722
+
723
+ Example (using ``@torch.jit.unused`` on a method)::
724
+
725
+ import torch
726
+ import torch.nn as nn
727
+
728
+
729
+ class MyModule(nn.Module):
730
+ def __init__(self, use_memory_efficient):
731
+ super().__init__()
732
+ self.use_memory_efficient = use_memory_efficient
733
+
734
+ @torch.jit.unused
735
+ def memory_efficient(self, x):
736
+ import pdb
737
+
738
+ pdb.set_trace()
739
+ return x + 10
740
+
741
+ def forward(self, x):
742
+ # Use not-yet-scriptable memory efficient mode
743
+ if self.use_memory_efficient:
744
+ return self.memory_efficient(x)
745
+ else:
746
+ return x + 10
747
+
748
+
749
+ m = torch.jit.script(MyModule(use_memory_efficient=False))
750
+ m.save("m.pt")
751
+
752
+ m = torch.jit.script(MyModule(use_memory_efficient=True))
753
+ # exception raised
754
+ m(torch.rand(100))
755
+ """
756
+ if isinstance(fn, property):
757
+ prop = fn
758
+ setattr( # noqa: B010
759
+ prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED
760
+ )
761
+
762
+ if prop.fset:
763
+ setattr( # noqa: B010
764
+ prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED
765
+ )
766
+
767
+ return prop
768
+
769
+ fn._torchscript_modifier = FunctionModifiers.UNUSED
770
+ return fn
771
+
772
+
773
+ # No op context manager from python side
774
+ class _IgnoreContextManager(contextlib.AbstractContextManager):
775
+ def __init__(self, **kwargs):
776
+ pass
777
+
778
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
779
+ pass
780
+
781
+
782
+ def ignore(drop=False, **kwargs):
783
+ """
784
+ This decorator indicates to the compiler that a function or method should
785
+ be ignored and left as a Python function. This allows you to leave code in
786
+ your model that is not yet TorchScript compatible. If called from TorchScript,
787
+ ignored functions will dispatch the call to the Python interpreter. Models with ignored
788
+ functions cannot be exported; use :func:`@torch.jit.unused <torch.jit.unused>` instead.
789
+
790
+ Example (using ``@torch.jit.ignore`` on a method)::
791
+
792
+ import torch
793
+ import torch.nn as nn
794
+
795
+
796
+ class MyModule(nn.Module):
797
+ @torch.jit.ignore
798
+ def debugger(self, x):
799
+ import pdb
800
+
801
+ pdb.set_trace()
802
+
803
+ def forward(self, x):
804
+ x += 10
805
+ # The compiler would normally try to compile `debugger`,
806
+ # but since it is `@ignore`d, it will be left as a call
807
+ # to Python
808
+ self.debugger(x)
809
+ return x
810
+
811
+
812
+ m = torch.jit.script(MyModule())
813
+
814
+ # Error! The call `debugger` cannot be saved since it calls into Python
815
+ m.save("m.pt")
816
+
817
+ Example (using ``@torch.jit.ignore(drop=True)`` on a method):
818
+
819
+ .. testcode::
820
+
821
+ import torch
822
+ import torch.nn as nn
823
+
824
+ class MyModule(nn.Module):
825
+ @torch.jit.ignore(drop=True)
826
+ def training_method(self, x):
827
+ import pdb
828
+ pdb.set_trace()
829
+
830
+ def forward(self, x):
831
+ if self.training:
832
+ self.training_method(x)
833
+ return x
834
+
835
+ m = torch.jit.script(MyModule())
836
+
837
+ # This is OK since `training_method` is not saved, the call is replaced
838
+ # with a `raise`.
839
+ m.save("m.pt")
840
+
841
+ .. testcleanup::
842
+
843
+ import os
844
+ os.remove('m.pt')
845
+ """
846
+
847
+ if callable(drop):
848
+ # used without any args, so drop is actually a function
849
+ # @torch.jit.ignore
850
+ # def fn(...):
851
+ fn = drop
852
+ fn._torchscript_modifier = FunctionModifiers.IGNORE
853
+ return fn
854
+
855
+ if not isinstance(drop, bool):
856
+ raise RuntimeError(
857
+ "Argument to @torch.jit.ignore must be a bool or "
858
+ f"a function but got {drop}"
859
+ )
860
+
861
+ # for backwards compat
862
+ drop_on_export = kwargs.pop("drop_on_export", None)
863
+ if drop_on_export:
864
+ warnings.warn(
865
+ "ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function "
866
+ "call on compilation. Use torch.jit.unused now. {}",
867
+ category=FutureWarning,
868
+ )
869
+
870
+ drop = drop_on_export
871
+ elif drop:
872
+ warnings.warn(
873
+ "ignore(True) has been deprecated. TorchScript will now drop the function "
874
+ "call on compilation. Use torch.jit.unused now. {}",
875
+ category=FutureWarning,
876
+ )
877
+
878
+ def decorator(fn):
879
+ if drop:
880
+ fn._torchscript_modifier = FunctionModifiers.UNUSED
881
+ else:
882
+ fn._torchscript_modifier = FunctionModifiers.IGNORE
883
+ return fn
884
+
885
+ return decorator
886
+
887
+
888
+ def _drop(fn):
889
+ fn._torchscript_modifier = FunctionModifiers._DROP
890
+ return fn
891
+
892
+
893
+ def _copy_to_script_wrapper(fn):
894
+ fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
895
+ return fn
896
+
897
+
898
+ def module_has_exports(mod):
899
+ for name in dir(mod):
900
+ if hasattr(mod, name):
901
+ item = getattr(mod, name)
902
+ if callable(item):
903
+ if get_torchscript_modifier(item) is FunctionModifiers.EXPORT:
904
+ return True
905
+ return False
906
+
907
+
908
+ # WARNING: should_drop is currently being used by our JIT code coverage plug-in to mark JIT'd code as covered. If you
909
+ # rename this function, please update references in tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py to
910
+ # allow JIT'd code to still be covered.
911
+ def should_drop(fn) -> bool:
912
+ attr = get_torchscript_modifier(fn)
913
+ if attr is None:
914
+ return False
915
+ return attr is FunctionModifiers.UNUSED or attr is FunctionModifiers._DROP
916
+
917
+
918
+ def is_ignored_fn(fn) -> bool:
919
+ mod = get_torchscript_modifier(fn)
920
+ return (
921
+ mod is FunctionModifiers.UNUSED
922
+ or mod is FunctionModifiers.IGNORE
923
+ or mod is FunctionModifiers._DROP
924
+ )
925
+
926
+
927
+ def _is_drop_fn(fn) -> bool:
928
+ mod = get_torchscript_modifier(fn)
929
+ return mod is FunctionModifiers._DROP
930
+
931
+
932
+ def is_static_fn(cls, fn) -> bool:
933
+ return isinstance(inspect.getattr_static(cls, fn, default=None), staticmethod)
934
+
935
+
936
+ def get_static_fn(cls, fn):
937
+ return inspect.getattr_static(cls, fn).__func__
938
+
939
+
940
+ def get_torchscript_modifier(fn):
941
+ if not callable(fn):
942
+ return None
943
+ if hasattr(fn, "__func__"):
944
+ fn = fn.__func__
945
+ return getattr(fn, "_torchscript_modifier", FunctionModifiers.DEFAULT)
946
+
947
+
948
+ def copy_torchscript_modifier(orig, new) -> None:
949
+ attr = get_torchscript_modifier(orig)
950
+ if attr is None:
951
+ return
952
+ new._torchscript_modifier = attr
953
+
954
+
955
+ # overloading registration
956
+ # overloads get registered in this file, and compiled in torch/jit/__init__.py
957
+ # so that they can be imported in nn/functional.py without an import cycle
958
+
959
+ # qualified_name => list[overload_functions]
960
+ _overloaded_fns: Dict[str, List[Callable]] = {} # noqa: T484
961
+
962
+
963
+ _OVERLOAD_EXAMPLE = """
964
+ Example usage of overload function:
965
+ @torch.jit._overload
966
+ def my_function(x: type0) -> type0: # decl 1
967
+ pass
968
+
969
+ @torch.jit._overload
970
+ def my_function(x: type1) -> type1: # decl 2
971
+ pass
972
+
973
+ def my_function(x): # implementation
974
+ if isinstance(x, type0):
975
+ return x
976
+ elif isinstance(x, type1):
977
+ return x
978
+ """
979
+
980
+
981
+ def get_overload_no_implementation_error_message(kind, obj):
982
+ sourcelines, file_lineno, filename = get_source_lines_and_file(obj)
983
+ return (
984
+ f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
985
+ f"sure a definition is provided and defined after all overload declarations.\n"
986
+ f'File "{filename}", line {file_lineno}:\n'
987
+ + "".join(sourcelines)
988
+ + "\n"
989
+ + _OVERLOAD_EXAMPLE
990
+ )
991
+
992
+
993
+ def _check_overload_body(func):
994
+ try:
995
+ parsed_def = parse_def(func)
996
+ except OSError as e:
997
+ # Parsing the function definition can raise an OSError if source is unavailable.
998
+ # Since this is just an initial check, just raise a warning if this is the case.
999
+ warnings.warn(
1000
+ f"Unable to retrieve source for @torch.jit._overload function: {func}."
1001
+ )
1002
+ return
1003
+
1004
+ body = parsed_def.ast.body[0].body
1005
+
1006
+ def is_pass(x):
1007
+ return isinstance(x, ast.Pass)
1008
+
1009
+ def is_ellipsis(x):
1010
+ return (
1011
+ isinstance(x, ast.Expr)
1012
+ and isinstance(x.value, ast.Constant)
1013
+ and x.value.value is Ellipsis
1014
+ )
1015
+
1016
+ if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])):
1017
+ msg = (
1018
+ "Only `pass` statement or `...` can be the body of overload declaration:\n"
1019
+ )
1020
+ msg += "\n".join(parsed_def.source.split("\n")[:3])
1021
+ msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE
1022
+ raise RuntimeError(msg)
1023
+
1024
+
1025
+ def _overload(func):
1026
+ _check_overload_body(func)
1027
+ qual_name = _qualified_name(func)
1028
+ global _overloaded_fns
1029
+ fn_overload_list = _overloaded_fns.get(qual_name)
1030
+ if fn_overload_list is None:
1031
+ fn_overload_list = []
1032
+ _overloaded_fns[qual_name] = fn_overload_list
1033
+ fn_overload_list.append(func)
1034
+ return func
1035
+
1036
+
1037
+ def _get_fn_overloads(qual_name):
1038
+ return _overloaded_fns.get(qual_name)
1039
+
1040
+
1041
+ def _clear_fn_overloads(qual_name) -> None:
1042
+ del _overloaded_fns[qual_name]
1043
+
1044
+
1045
+ def get_class_name_lineno(method) -> Tuple[str, int]:
1046
+ current_frame = inspect.currentframe()
1047
+
1048
+ # one for the get_class_name call, one for _overload_method call
1049
+ for i in range(2):
1050
+ assert (
1051
+ current_frame is not None
1052
+ ) # assert current frame is not an Optional[FrameType]
1053
+ current_frame = current_frame.f_back
1054
+
1055
+ assert current_frame is not None # same here
1056
+ class_name = current_frame.f_code.co_name
1057
+ line_no = current_frame.f_code.co_firstlineno
1058
+ return class_name, line_no
1059
+
1060
+
1061
+ # At the point the decorator is applied to class methods the method
1062
+ # has no reference to its owning class. _qualified_name would not include
1063
+ # the class it is defined in, so any methods with the same name in the same file
1064
+ # would have the same _qualified_name, even if they were defined in different
1065
+ # classes. This problem only exists in python 2.
1066
+ # We get around this problem by looking at the stack frame and identifying
1067
+ # the class name, and throwing an error whenever overloads are used
1068
+ # when modules of the same name are in the same file
1069
+
1070
+ # qualified_name => class name => list[overload_functions]
1071
+ _overloaded_methods: Dict[str, Dict[str, List[Callable]]] = {} # noqa: T484
1072
+
1073
+
1074
+ # (qualified_name, class name) => class_fileno
1075
+ _overloaded_method_class_fileno: Dict[Tuple[str, str], int] = {}
1076
+
1077
+
1078
+ def _overload_method(func):
1079
+ _check_overload_body(func)
1080
+ qual_name = _qualified_name(func)
1081
+ global _overloaded_methods
1082
+ class_name_map = _overloaded_methods.get(qual_name, None)
1083
+ if class_name_map is None:
1084
+ class_name_map = {}
1085
+ _overloaded_methods[qual_name] = class_name_map
1086
+
1087
+ class_name, line_no = get_class_name_lineno(func)
1088
+ method_overloads = class_name_map.get(class_name, None)
1089
+ if method_overloads is None:
1090
+ method_overloads = []
1091
+ class_name_map[class_name] = method_overloads
1092
+ _overloaded_method_class_fileno[(qual_name, class_name)] = line_no
1093
+ else:
1094
+ existing_lineno = _overloaded_method_class_fileno[(qual_name, class_name)]
1095
+ if existing_lineno != line_no:
1096
+ raise RuntimeError(
1097
+ "Cannot currently overload the same method name in two different"
1098
+ " classes with the same name in the same module"
1099
+ )
1100
+
1101
+ method_overloads.append(func)
1102
+ return func
1103
+
1104
+
1105
+ def _get_overloaded_methods(method, mod_class):
1106
+ # TODO: __name__ not set for submodules in recursive script
1107
+ if not hasattr(method, "__name__"):
1108
+ return None
1109
+ qual_name = _qualified_name(method)
1110
+ class_name_map = _overloaded_methods.get(qual_name, None)
1111
+ if class_name_map is None:
1112
+ return None
1113
+ overloads = class_name_map.get(mod_class.__name__, None)
1114
+ if overloads is None:
1115
+ return None
1116
+
1117
+ method_line_no = get_source_lines_and_file(method)[1]
1118
+ mod_class_fileno = get_source_lines_and_file(mod_class)[1]
1119
+ mod_end_fileno = mod_class_fileno + len(get_source_lines_and_file(mod_class)[0])
1120
+ if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno):
1121
+ raise AssertionError(
1122
+ "Overloads are not useable when a module is redeclared within the same file: "
1123
+ + str(method)
1124
+ )
1125
+ return overloads
1126
+
1127
+
1128
+ def is_tuple(ann) -> bool:
1129
+ if ann is Tuple:
1130
+ raise_error_container_parameter_missing("Tuple")
1131
+
1132
+ # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
1133
+ if not hasattr(ann, "__module__"):
1134
+ return False
1135
+
1136
+ ann_origin = get_origin(ann)
1137
+ if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is tuple:
1138
+ return True
1139
+ return ann.__module__ == "typing" and (ann_origin is Tuple or ann_origin is tuple)
1140
+
1141
+
1142
+ def is_list(ann) -> bool:
1143
+ if ann is List:
1144
+ raise_error_container_parameter_missing("List")
1145
+
1146
+ if not hasattr(ann, "__module__"):
1147
+ return False
1148
+
1149
+ ann_origin = get_origin(ann)
1150
+ if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is list:
1151
+ return True
1152
+ return ann.__module__ == "typing" and (ann_origin is List or ann_origin is list)
1153
+
1154
+
1155
+ def is_dict(ann) -> bool:
1156
+ if ann is Dict:
1157
+ raise_error_container_parameter_missing("Dict")
1158
+
1159
+ if not hasattr(ann, "__module__"):
1160
+ return False
1161
+
1162
+ ann_origin = get_origin(ann)
1163
+ if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is dict:
1164
+ return True
1165
+ return ann.__module__ == "typing" and (ann_origin is Dict or ann_origin is dict)
1166
+
1167
+
1168
+ def is_union(ann):
1169
+ if ann is Union:
1170
+ raise_error_container_parameter_missing("Union")
1171
+
1172
+ return isinstance(ann, BuiltinUnionType) or (
1173
+ hasattr(ann, "__module__")
1174
+ and ann.__module__ == "typing"
1175
+ and (get_origin(ann) is Union)
1176
+ )
1177
+
1178
+
1179
+ def is_optional(ann):
1180
+ if ann is Optional:
1181
+ raise_error_container_parameter_missing("Optional")
1182
+
1183
+ def is_optional_as_optional(ann):
1184
+ return (
1185
+ hasattr(ann, "__module__")
1186
+ and ann.__module__ == "typing"
1187
+ and (get_origin(ann) is Optional)
1188
+ )
1189
+
1190
+ def is_union_as_optional(ann):
1191
+ ann_args = get_args(ann)
1192
+ return len(ann_args) == 2 and (None in ann_args or type(None) in ann_args)
1193
+
1194
+ return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann))
1195
+
1196
+
1197
+ def is_future(ann) -> bool:
1198
+ if ann is Future:
1199
+ raise RuntimeError(
1200
+ "Attempted to use Future without a "
1201
+ "contained type. Please add a contained type, e.g. "
1202
+ "Future[int]"
1203
+ )
1204
+ return get_origin(ann) is Future
1205
+
1206
+
1207
+ def is_await(ann) -> bool:
1208
+ if ann is _Await:
1209
+ return True
1210
+ return get_origin(ann) is _Await
1211
+
1212
+
1213
+ if torch.distributed.rpc.is_available():
1214
+ from torch._C._distributed_rpc import PyRRef
1215
+ from torch.distributed.rpc import RRef
1216
+
1217
+ def is_rref(ann) -> bool:
1218
+ if ann is RRef:
1219
+ raise RuntimeError(
1220
+ "Attempted to use RRef without a "
1221
+ "contained type. Please add a contained type, e.g. "
1222
+ "RRef[int]"
1223
+ )
1224
+ return get_origin(ann) is RRef
1225
+
1226
+ def is_rref_instance(obj) -> bool:
1227
+ return isinstance(obj, PyRRef)
1228
+
1229
+ else:
1230
+
1231
+ def is_rref_instance(obj) -> bool:
1232
+ # If the RPC module doesn't exist then RRefs don't exist either.
1233
+ return False
1234
+
1235
+
1236
+ def _try_get_dispatched_fn(fn):
1237
+ if not callable(fn):
1238
+ return None
1239
+ return boolean_dispatched.get(fn)
1240
+
1241
+
1242
+ def _get_named_tuple_properties(
1243
+ obj,
1244
+ loc: Optional[torch._C._jit_tree_views.SourceRange] = None,
1245
+ rcb=None,
1246
+ ):
1247
+ if loc is None:
1248
+ loc = fake_range()
1249
+
1250
+ assert issubclass(obj, tuple) and hasattr(obj, "_fields")
1251
+ if hasattr(obj, "_field_defaults"):
1252
+ defaults = [
1253
+ obj._field_defaults[field]
1254
+ for field in obj._fields
1255
+ if field in obj._field_defaults
1256
+ ]
1257
+ else:
1258
+ defaults = []
1259
+ # In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function
1260
+ # Also, annotations from base class are not inherited so they need to be queried explicitly
1261
+ if sys.version_info[:2] < (3, 10):
1262
+ obj_annotations = getattr(obj, "__annotations__", {})
1263
+ else:
1264
+ obj_annotations = inspect.get_annotations(obj)
1265
+ if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
1266
+ obj_annotations = inspect.get_annotations(obj.__base__)
1267
+
1268
+ annotations = []
1269
+ for field in obj._fields:
1270
+ if field in obj_annotations:
1271
+ field_type = obj_annotations[field]
1272
+ # [Note: ForwardRef annotations in NamedTuple attributes]
1273
+ # NamedTuple types are slightly different from normal types.
1274
+ #
1275
+ # Normally, annotations are evaluted like this (during jit.script):
1276
+ # 1. Load strings of python code into c++ and parse.
1277
+ # 2. Get annotations as strings
1278
+ # 3. Use the PythonResolver's resolution callback (rcb) to convert
1279
+ # the string into a python object
1280
+ # 4. We call into annotations.py:ann_to_type to convert python obj
1281
+ # from step 3 into a type that torchscript understands.
1282
+ #
1283
+ # NamedTuples are more complicated, because it has sub-types.
1284
+ # Normally, once we have the NamedTuple type object from #3,
1285
+ # we can just look at the annotation literal values and use
1286
+ # ann_to_type directly on them.
1287
+ #
1288
+ # But sometimes, users will annotate with string literals, e.g.
1289
+ # x: 'int'
1290
+ # This also happens with PEP563 (from __forward__ import annotations)
1291
+ #
1292
+ # These annotations appear in the annotation dict as ForwardRef('int').
1293
+ #
1294
+ # Then, we need to convert the string into a python object. This
1295
+ # requires having local context for custom objects or imported types.
1296
+ # rcb() is what gives us this. So, we plumb rcb through the stack so
1297
+ # it can be used in this context for the if block below.
1298
+ #
1299
+ # FAQ:
1300
+ # - Why do we need this special handling for NamedTuple but string
1301
+ # annotations work fine for normal types? Normally, we parse the
1302
+ # string directly and then call rcb() directly from C++.
1303
+ # - Why not use ForwardRef._evaluate? For that, we need globals()
1304
+ # and locals() for the local context where the NamedTuple was defined.
1305
+ # rcb is what lets us look up into these. So, basically rcb does the
1306
+ # hard work for us.
1307
+ if isinstance(field_type, ForwardRef) and rcb is not None:
1308
+ rcb_type = rcb(field_type.__forward_arg__)
1309
+ # rcb returns None if it can't find anything.
1310
+ if rcb_type is None:
1311
+ raise ValueError(
1312
+ f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}."
1313
+ f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858."
1314
+ f" Issue occurred at {loc.highlight()}"
1315
+ )
1316
+ field_type = rcb_type
1317
+ the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb)
1318
+ annotations.append(the_type)
1319
+ else:
1320
+ annotations.append(torch._C.TensorType.getInferred())
1321
+ return type(obj).__name__, obj._fields, annotations, defaults
1322
+
1323
+
1324
+ def _create_named_tuple(
1325
+ t,
1326
+ unqual_name: str,
1327
+ field_names: List[str],
1328
+ defaults: Tuple[Any, ...],
1329
+ ):
1330
+ TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults) # type: ignore[call-arg, no-redef, misc]
1331
+ return TupleType(*t)
1332
+
1333
+
1334
+ @contextlib.contextmanager
1335
+ def _disable_emit_hooks():
1336
+ hooks = torch._C._jit_get_emit_hooks()
1337
+ torch._C._jit_set_emit_hooks(None, None)
1338
+ try:
1339
+ yield
1340
+ finally:
1341
+ torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
1342
+
1343
+
1344
+ def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None: # noqa: F811
1345
+ def __enter__(self) -> None:
1346
+ self.hooks = torch._C._jit_get_emit_hooks()
1347
+ torch._C._jit_set_emit_hooks(None, None)
1348
+
1349
+ def __exit__(self, *args) -> None:
1350
+ torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
1351
+
1352
+
1353
+ def _is_exception(obj) -> bool:
1354
+ if not inspect.isclass(obj):
1355
+ return False
1356
+ return issubclass(obj, Exception)
1357
+
1358
+
1359
+ def raise_error_container_parameter_missing(target_type) -> None:
1360
+ if target_type == "Dict":
1361
+ raise RuntimeError(
1362
+ "Attempted to use Dict without "
1363
+ "contained types. Please add contained type, e.g. "
1364
+ "Dict[int, int]"
1365
+ )
1366
+ raise RuntimeError(
1367
+ f"Attempted to use {target_type} without a "
1368
+ "contained type. Please add a contained type, e.g. "
1369
+ f"{target_type}[int]"
1370
+ )
1371
+
1372
+
1373
+ def check_args_exist(target_type) -> None:
1374
+ if target_type is List or target_type is list:
1375
+ raise_error_container_parameter_missing("List")
1376
+ elif target_type is Tuple or target_type is tuple:
1377
+ raise_error_container_parameter_missing("Tuple")
1378
+ elif target_type is Dict or target_type is dict:
1379
+ raise_error_container_parameter_missing("Dict")
1380
+ elif target_type is None or target_type is Optional:
1381
+ raise_error_container_parameter_missing("Optional")
1382
+
1383
+
1384
+ def check_empty_containers(obj) -> None:
1385
+ if obj == [] or obj == {} or obj == ():
1386
+ warnings.warn(
1387
+ "The inner type of a container is lost when "
1388
+ "calling torch.jit.isinstance in eager mode. For "
1389
+ "example, List[int] would become list and "
1390
+ "therefore falsely return True for List[float] or"
1391
+ " List[str]."
1392
+ )
1393
+
1394
+
1395
+ # supports List/Dict/Tuple and Optional types
1396
+ # TODO support future
1397
+ def container_checker(obj, target_type) -> bool:
1398
+ origin_type = get_origin(target_type)
1399
+ check_args_exist(target_type)
1400
+ if origin_type is None:
1401
+ return False
1402
+ elif origin_type is list or origin_type is List:
1403
+ check_empty_containers(obj)
1404
+ if not isinstance(obj, list):
1405
+ return False
1406
+ arg_type = get_args(target_type)[0]
1407
+ arg_origin = get_origin(arg_type)
1408
+ for el in obj:
1409
+ # check if nested container, ex: List[List[str]]
1410
+ if arg_origin: # processes nested container, ex: List[List[str]]
1411
+ if not container_checker(el, arg_type):
1412
+ return False
1413
+ elif not isinstance(el, arg_type):
1414
+ return False
1415
+ return True
1416
+ elif origin_type is Dict or origin_type is dict:
1417
+ check_empty_containers(obj)
1418
+ if not isinstance(obj, dict):
1419
+ return False
1420
+ key_type = get_args(target_type)[0]
1421
+ val_type = get_args(target_type)[1]
1422
+ for key, val in obj.items():
1423
+ # check if keys are of right type
1424
+ if not isinstance(key, key_type):
1425
+ return False
1426
+ val_origin = get_origin(val_type)
1427
+ if val_origin:
1428
+ if not container_checker(val, val_type):
1429
+ return False
1430
+ elif not isinstance(val, val_type):
1431
+ return False
1432
+ return True
1433
+ elif origin_type is Tuple or origin_type is tuple:
1434
+ check_empty_containers(obj)
1435
+ if not isinstance(obj, tuple):
1436
+ return False
1437
+ arg_types = get_args(target_type)
1438
+ if len(obj) != len(arg_types):
1439
+ return False
1440
+ for el, el_type in zip(obj, arg_types):
1441
+ el_origin = get_origin(el_type)
1442
+ if el_origin:
1443
+ if not container_checker(el, el_type):
1444
+ return False
1445
+ elif not isinstance(el, el_type):
1446
+ return False
1447
+ return True
1448
+ elif origin_type is Union or issubclass(
1449
+ origin_type, BuiltinUnionType
1450
+ ): # also handles Optional
1451
+ if obj is None: # check before recursion because None is always fine
1452
+ return True
1453
+ inner_types = get_args(target_type)
1454
+ for t in inner_types:
1455
+ t_origin = get_origin(t)
1456
+ if t_origin:
1457
+ return container_checker(obj, t)
1458
+ elif isinstance(obj, t):
1459
+ return True
1460
+ return False
1461
+
1462
+
1463
+ def _isinstance(obj, target_type) -> bool:
1464
+ if isinstance(target_type, collections.abc.Container):
1465
+ if not isinstance(target_type, tuple):
1466
+ raise RuntimeError(
1467
+ "The second argument to "
1468
+ "`torch.jit.isinstance` must be a type "
1469
+ "or a tuple of types"
1470
+ )
1471
+ for t_type in target_type:
1472
+ if _isinstance(obj, t_type):
1473
+ return True
1474
+ return False
1475
+
1476
+ origin_type = get_origin(target_type)
1477
+ if origin_type:
1478
+ return container_checker(obj, target_type)
1479
+
1480
+ # Check to handle non-typed optional origin returns as none instead
1481
+ # of as optional in 3.7-3.8
1482
+ check_args_exist(target_type)
1483
+
1484
+ # handle non-containers
1485
+ return isinstance(obj, target_type)
1486
+
1487
+
1488
+ class _TensorExtractor(pickle.Pickler):
1489
+ def __init__(self, *args, tensors: List[torch.Tensor], **kwargs):
1490
+ super().__init__(*args, **kwargs)
1491
+ self.tensors = tensors
1492
+
1493
+ def persistent_id(self, obj):
1494
+ if isinstance(obj, torch.Tensor):
1495
+ self.tensors.append(obj)
1496
+ return ""
1497
+ # Since we just want to extract tensors, we don't mind if an object is
1498
+ # unpicklable if it doesn't contain tensors, as we can just ignore/skip
1499
+ # it. To play it safe, we only do so for common objects that we're sure
1500
+ # don't contain tensors. Feel free to add new types here. Note also that
1501
+ # even if a type isn't listed here this won't block users, since thet
1502
+ # can just add a __getstate__ or __reduce__ method to their class.
1503
+ if isinstance(obj, LockType):
1504
+ return ""
1505
+ # Futures and RRefs don't technically contain a value, they just offer
1506
+ # the means to access a value.
1507
+ if isinstance(obj, CFuture) or is_rref_instance(obj):
1508
+ return ""
1509
+ if isinstance(obj, CAwait):
1510
+ return ""
1511
+ if isinstance(obj, torch.cuda.Event):
1512
+ return ""
1513
+ if isinstance(obj, threading.Thread):
1514
+ return ""
1515
+ return None
1516
+
1517
+
1518
+ def _extract_tensors(obj):
1519
+ r"""
1520
+ This function is exclusively called from C++.
1521
+ See ``torch/csrc/jit/python/python_ivalue.h``.
1522
+
1523
+ It extracts the tensors contained in the given object, through pickling.
1524
+ """
1525
+ tensors: List[torch.Tensor] = []
1526
+ extractor = _TensorExtractor(io.BytesIO(), protocol=-1, tensors=tensors)
1527
+ extractor.dump(obj)
1528
+ return tensors
1529
+
1530
+
1531
+ def _get_model_id(obj) -> Optional[str]:
1532
+ if isinstance(obj, torch.jit.ScriptModule):
1533
+ return str(obj._c._type())
1534
+ elif isinstance(obj, torch.jit.ScriptFunction):
1535
+ return obj.qualified_name
1536
+ else:
1537
+ return None
1538
+
1539
+
1540
+ # In Python-3.11+ typed enums (i.e. IntEnum for example) retain number of base class methods in subclass
1541
+ # that were previously dropped. To preserve the behavior, explicitly drop them there
1542
+
1543
+ if sys.version_info > (3, 10):
1544
+ _drop(enum.Enum.__new__)
1545
+ _drop(enum.Enum.__format__)
1546
+ _drop(enum.Enum.__repr__)
1547
+ _drop(enum.Enum.__str__)
.venv/lib/python3.11/site-packages/torch/_linalg_utils.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ """Various linear algebra utility methods for internal use."""
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ from torch import Tensor
8
+
9
+
10
+ def is_sparse(A):
11
+ """Check if tensor A is a sparse tensor"""
12
+ if isinstance(A, torch.Tensor):
13
+ return A.layout == torch.sparse_coo
14
+
15
+ error_str = "expected Tensor"
16
+ if not torch.jit.is_scripting():
17
+ error_str += f" but got {type(A)}"
18
+ raise TypeError(error_str)
19
+
20
+
21
+ def get_floating_dtype(A):
22
+ """Return the floating point dtype of tensor A.
23
+
24
+ Integer types map to float32.
25
+ """
26
+ dtype = A.dtype
27
+ if dtype in (torch.float16, torch.float32, torch.float64):
28
+ return dtype
29
+ return torch.float32
30
+
31
+
32
+ def matmul(A: Optional[Tensor], B: Tensor) -> Tensor:
33
+ """Multiply two matrices.
34
+
35
+ If A is None, return B. A can be sparse or dense. B is always
36
+ dense.
37
+ """
38
+ if A is None:
39
+ return B
40
+ if is_sparse(A):
41
+ return torch.sparse.mm(A, B)
42
+ return torch.matmul(A, B)
43
+
44
+
45
+ def bform(X: Tensor, A: Optional[Tensor], Y: Tensor) -> Tensor:
46
+ """Return bilinear form of matrices: :math:`X^T A Y`."""
47
+ return matmul(X.mT, matmul(A, Y))
48
+
49
+
50
+ def qform(A: Optional[Tensor], S: Tensor):
51
+ """Return quadratic form :math:`S^T A S`."""
52
+ return bform(S, A, S)
53
+
54
+
55
+ def basis(A):
56
+ """Return orthogonal basis of A columns."""
57
+ return torch.linalg.qr(A).Q
58
+
59
+
60
+ def symeig(A: Tensor, largest: Optional[bool] = False) -> Tuple[Tensor, Tensor]:
61
+ """Return eigenpairs of A with specified ordering."""
62
+ if largest is None:
63
+ largest = False
64
+ E, Z = torch.linalg.eigh(A, UPLO="U")
65
+ # assuming that E is ordered
66
+ if largest:
67
+ E = torch.flip(E, dims=(-1,))
68
+ Z = torch.flip(Z, dims=(-1,))
69
+ return E, Z
70
+
71
+
72
+ # These functions were deprecated and removed
73
+ # This nice error message can be removed in version 1.13+
74
+ def matrix_rank(input, tol=None, symmetric=False, *, out=None) -> Tensor:
75
+ raise RuntimeError(
76
+ "This function was deprecated since version 1.9 and is now removed.\n"
77
+ "Please use the `torch.linalg.matrix_rank` function instead. "
78
+ "The parameter 'symmetric' was renamed in `torch.linalg.matrix_rank()` to 'hermitian'."
79
+ )
80
+
81
+
82
+ def solve(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]:
83
+ raise RuntimeError(
84
+ "This function was deprecated since version 1.9 and is now removed. "
85
+ "`torch.solve` is deprecated in favor of `torch.linalg.solve`. "
86
+ "`torch.linalg.solve` has its arguments reversed and does not return the LU factorization.\n\n"
87
+ "To get the LU factorization see `torch.lu`, which can be used with `torch.lu_solve` or `torch.lu_unpack`.\n"
88
+ "X = torch.solve(B, A).solution "
89
+ "should be replaced with:\n"
90
+ "X = torch.linalg.solve(A, B)"
91
+ )
92
+
93
+
94
+ def lstsq(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]:
95
+ raise RuntimeError(
96
+ "This function was deprecated since version 1.9 and is now removed. "
97
+ "`torch.lstsq` is deprecated in favor of `torch.linalg.lstsq`.\n"
98
+ "`torch.linalg.lstsq` has reversed arguments and does not return the QR decomposition in "
99
+ "the returned tuple (although it returns other information about the problem).\n\n"
100
+ "To get the QR decomposition consider using `torch.linalg.qr`.\n\n"
101
+ "The returned solution in `torch.lstsq` stored the residuals of the solution in the "
102
+ "last m - n columns of the returned value whenever m > n. In torch.linalg.lstsq, "
103
+ "the residuals are in the field 'residuals' of the returned named tuple.\n\n"
104
+ "The unpacking of the solution, as in\n"
105
+ "X, _ = torch.lstsq(B, A).solution[:A.size(1)]\n"
106
+ "should be replaced with:\n"
107
+ "X = torch.linalg.lstsq(A, B).solution"
108
+ )
109
+
110
+
111
+ def _symeig(
112
+ input,
113
+ eigenvectors=False,
114
+ upper=True,
115
+ *,
116
+ out=None,
117
+ ) -> Tuple[Tensor, Tensor]:
118
+ raise RuntimeError(
119
+ "This function was deprecated since version 1.9 and is now removed. "
120
+ "The default behavior has changed from using the upper triangular portion of the matrix by default "
121
+ "to using the lower triangular portion.\n\n"
122
+ "L, _ = torch.symeig(A, upper=upper) "
123
+ "should be replaced with:\n"
124
+ "L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')\n\n"
125
+ "and\n\n"
126
+ "L, V = torch.symeig(A, eigenvectors=True) "
127
+ "should be replaced with:\n"
128
+ "L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L')"
129
+ )
130
+
131
+
132
+ def eig(
133
+ self: Tensor,
134
+ eigenvectors: bool = False,
135
+ *,
136
+ e=None,
137
+ v=None,
138
+ ) -> Tuple[Tensor, Tensor]:
139
+ raise RuntimeError(
140
+ "This function was deprecated since version 1.9 and is now removed. "
141
+ "`torch.linalg.eig` returns complex tensors of dtype `cfloat` or `cdouble` rather than real tensors "
142
+ "mimicking complex tensors.\n\n"
143
+ "L, _ = torch.eig(A) "
144
+ "should be replaced with:\n"
145
+ "L_complex = torch.linalg.eigvals(A)\n\n"
146
+ "and\n\n"
147
+ "L, V = torch.eig(A, eigenvectors=True) "
148
+ "should be replaced with:\n"
149
+ "L_complex, V_complex = torch.linalg.eig(A)"
150
+ )
.venv/lib/python3.11/site-packages/torch/_lobpcg.py ADDED
@@ -0,0 +1,1157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ """Locally Optimal Block Preconditioned Conjugate Gradient methods."""
3
+ # Author: Pearu Peterson
4
+ # Created: February 2020
5
+
6
+ from typing import Dict, Optional, Tuple
7
+
8
+ import torch
9
+ from torch import _linalg_utils as _utils, Tensor
10
+ from torch.overrides import handle_torch_function, has_torch_function
11
+
12
+
13
+ __all__ = ["lobpcg"]
14
+
15
+
16
+ def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
17
+ # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
18
+ F = D.unsqueeze(-2) - D.unsqueeze(-1)
19
+ F.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
20
+ F.pow_(-1)
21
+
22
+ # A.grad = U (D.grad + (U^T U.grad * F)) U^T
23
+ Ut = U.mT.contiguous()
24
+ res = torch.matmul(
25
+ U, torch.matmul(torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, Ut)
26
+ )
27
+
28
+ return res
29
+
30
+
31
+ def _polynomial_coefficients_given_roots(roots):
32
+ """
33
+ Given the `roots` of a polynomial, find the polynomial's coefficients.
34
+
35
+ If roots = (r_1, ..., r_n), then the method returns
36
+ coefficients (a_0, a_1, ..., a_n (== 1)) so that
37
+ p(x) = (x - r_1) * ... * (x - r_n)
38
+ = x^n + a_{n-1} * x^{n-1} + ... a_1 * x_1 + a_0
39
+
40
+ Note: for better performance requires writing a low-level kernel
41
+ """
42
+ poly_order = roots.shape[-1]
43
+ poly_coeffs_shape = list(roots.shape)
44
+ # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0,
45
+ # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)},
46
+ # but we insert one extra coefficient to enable better vectorization below
47
+ poly_coeffs_shape[-1] += 2
48
+ poly_coeffs = roots.new_zeros(poly_coeffs_shape)
49
+ poly_coeffs[..., 0] = 1
50
+ poly_coeffs[..., -1] = 1
51
+
52
+ # perform the Horner's rule
53
+ for i in range(1, poly_order + 1):
54
+ # note that it is computationally hard to compute backward for this method,
55
+ # because then given the coefficients it would require finding the roots and/or
56
+ # calculating the sensitivity based on the Vieta's theorem.
57
+ # So the code below tries to circumvent the explicit root finding by series
58
+ # of operations on memory copies imitating the Horner's method.
59
+ # The memory copies are required to construct nodes in the computational graph
60
+ # by exploting the explicit (not in-place, separate node for each step)
61
+ # recursion of the Horner's method.
62
+ # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity.
63
+ poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs
64
+ out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1)
65
+ out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(
66
+ -1, poly_order - i + 1, i + 1
67
+ )
68
+ poly_coeffs = poly_coeffs_new
69
+
70
+ return poly_coeffs.narrow(-1, 1, poly_order + 1)
71
+
72
+
73
+ def _polynomial_value(poly, x, zero_power, transition):
74
+ """
75
+ A generic method for computing poly(x) using the Horner's rule.
76
+
77
+ Args:
78
+ poly (Tensor): the (possibly batched) 1D Tensor representing
79
+ polynomial coefficients such that
80
+ poly[..., i] = (a_{i_0}, ..., a{i_n} (==1)), and
81
+ poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n
82
+
83
+ x (Tensor): the value (possible batched) to evalate the polynomial `poly` at.
84
+
85
+ zero_power (Tensor): the representation of `x^0`. It is application-specific.
86
+
87
+ transition (Callable): the function that accepts some intermediate result `int_val`,
88
+ the `x` and a specific polynomial coefficient
89
+ `poly[..., k]` for some iteration `k`.
90
+ It basically performs one iteration of the Horner's rule
91
+ defined as `x * int_val + poly[..., k] * zero_power`.
92
+ Note that `zero_power` is not a parameter,
93
+ because the step `+ poly[..., k] * zero_power` depends on `x`,
94
+ whether it is a vector, a matrix, or something else, so this
95
+ functionality is delegated to the user.
96
+ """
97
+
98
+ res = zero_power.clone()
99
+ for k in range(poly.size(-1) - 2, -1, -1):
100
+ res = transition(res, x, poly[..., k])
101
+ return res
102
+
103
+
104
+ def _matrix_polynomial_value(poly, x, zero_power=None):
105
+ """
106
+ Evaluates `poly(x)` for the (batched) matrix input `x`.
107
+ Check out `_polynomial_value` function for more details.
108
+ """
109
+
110
+ # matrix-aware Horner's rule iteration
111
+ def transition(curr_poly_val, x, poly_coeff):
112
+ res = x.matmul(curr_poly_val)
113
+ res.diagonal(dim1=-2, dim2=-1).add_(poly_coeff.unsqueeze(-1))
114
+ return res
115
+
116
+ if zero_power is None:
117
+ zero_power = torch.eye(
118
+ x.size(-1), x.size(-1), dtype=x.dtype, device=x.device
119
+ ).view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1))
120
+
121
+ return _polynomial_value(poly, x, zero_power, transition)
122
+
123
+
124
+ def _vector_polynomial_value(poly, x, zero_power=None):
125
+ """
126
+ Evaluates `poly(x)` for the (batched) vector input `x`.
127
+ Check out `_polynomial_value` function for more details.
128
+ """
129
+
130
+ # vector-aware Horner's rule iteration
131
+ def transition(curr_poly_val, x, poly_coeff):
132
+ res = torch.addcmul(poly_coeff.unsqueeze(-1), x, curr_poly_val)
133
+ return res
134
+
135
+ if zero_power is None:
136
+ zero_power = x.new_ones(1).expand(x.shape)
137
+
138
+ return _polynomial_value(poly, x, zero_power, transition)
139
+
140
+
141
+ def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
142
+ # compute a projection operator onto an orthogonal subspace spanned by the
143
+ # columns of U defined as (I - UU^T)
144
+ Ut = U.mT.contiguous()
145
+ proj_U_ortho = -U.matmul(Ut)
146
+ proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)
147
+
148
+ # compute U_ortho, a basis for the orthogonal complement to the span(U),
149
+ # by projecting a random [..., m, m - k] matrix onto the subspace spanned
150
+ # by the columns of U.
151
+ #
152
+ # fix generator for determinism
153
+ gen = torch.Generator(A.device)
154
+
155
+ # orthogonal complement to the span(U)
156
+ U_ortho = proj_U_ortho.matmul(
157
+ torch.randn(
158
+ (*A.shape[:-1], A.size(-1) - D.size(-1)),
159
+ dtype=A.dtype,
160
+ device=A.device,
161
+ generator=gen,
162
+ )
163
+ )
164
+ U_ortho_t = U_ortho.mT.contiguous()
165
+
166
+ # compute the coefficients of the characteristic polynomial of the tensor D.
167
+ # Note that D is diagonal, so the diagonal elements are exactly the roots
168
+ # of the characteristic polynomial.
169
+ chr_poly_D = _polynomial_coefficients_given_roots(D)
170
+
171
+ # the code belows finds the explicit solution to the Sylvester equation
172
+ # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U
173
+ # and incorporates it into the whole gradient stored in the `res` variable.
174
+ #
175
+ # Equivalent to the following naive implementation:
176
+ # res = A.new_zeros(A.shape)
177
+ # p_res = A.new_zeros(*A.shape[:-1], D.size(-1))
178
+ # for k in range(1, chr_poly_D.size(-1)):
179
+ # p_res.zero_()
180
+ # for i in range(0, k):
181
+ # p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2)
182
+ # res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @ p_res @ U.t())
183
+ #
184
+ # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity
185
+ # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g,
186
+ # and we need to compute g(U_grad, A, U, D)
187
+ #
188
+ # The naive implementation is based on the paper
189
+ # Hu, Qingxi, and Daizhan Cheng.
190
+ # "The polynomial solution to the Sylvester matrix equation."
191
+ # Applied mathematics letters 19.9 (2006): 859-864.
192
+ #
193
+ # We can modify the computation of `p_res` from above in a more efficient way
194
+ # p_res = U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k)).unsqueeze(-2)
195
+ # + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2)
196
+ # + ...
197
+ # + A.matrix_power(k - 1) U_grad * chr_poly_D[k]
198
+ # Note that this saves us from redundant matrix products with A (elimination of matrix_power)
199
+ U_grad_projected = U_grad
200
+ series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
201
+ for k in range(1, chr_poly_D.size(-1)):
202
+ poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D)
203
+ series_acc += U_grad_projected * poly_D.unsqueeze(-2)
204
+ U_grad_projected = A.matmul(U_grad_projected)
205
+
206
+ # compute chr_poly_D(A) which essentially is:
207
+ #
208
+ # chr_poly_D_at_A = A.new_zeros(A.shape)
209
+ # for k in range(chr_poly_D.size(-1)):
210
+ # chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k)
211
+ #
212
+ # Note, however, for better performance we use the Horner's rule
213
+ chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)
214
+
215
+ # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t
216
+ chr_poly_D_at_A_to_U_ortho = torch.matmul(
217
+ U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho)
218
+ )
219
+ # we need to invert 'chr_poly_D_at_A_to_U_ortho`, for that we compute its
220
+ # Cholesky decomposition and then use `torch.cholesky_solve` for better stability.
221
+ # Cholesky decomposition requires the input to be positive-definite.
222
+ # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if
223
+ # 1. `largest` == False, or
224
+ # 2. `largest` == True and `k` is even
225
+ # under the assumption that `A` has distinct eigenvalues.
226
+ #
227
+ # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite
228
+ chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1
229
+ chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky(
230
+ chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
231
+ )
232
+
233
+ # compute the gradient part in span(U)
234
+ res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
235
+
236
+ # incorporate the Sylvester equation solution into the full gradient
237
+ # it resides in span(U_ortho)
238
+ res -= U_ortho.matmul(
239
+ chr_poly_D_at_A_to_U_ortho_sign
240
+ * torch.cholesky_solve(
241
+ U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L
242
+ )
243
+ ).matmul(Ut)
244
+
245
+ return res
246
+
247
+
248
+ def _symeig_backward(D_grad, U_grad, A, D, U, largest):
249
+ # if `U` is square, then the columns of `U` is a complete eigenspace
250
+ if U.size(-1) == U.size(-2):
251
+ return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
252
+ else:
253
+ return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest)
254
+
255
+
256
+ class LOBPCGAutogradFunction(torch.autograd.Function):
257
+ @staticmethod
258
+ def forward( # type: ignore[override]
259
+ ctx,
260
+ A: Tensor,
261
+ k: Optional[int] = None,
262
+ B: Optional[Tensor] = None,
263
+ X: Optional[Tensor] = None,
264
+ n: Optional[int] = None,
265
+ iK: Optional[Tensor] = None,
266
+ niter: Optional[int] = None,
267
+ tol: Optional[float] = None,
268
+ largest: Optional[bool] = None,
269
+ method: Optional[str] = None,
270
+ tracker: None = None,
271
+ ortho_iparams: Optional[Dict[str, int]] = None,
272
+ ortho_fparams: Optional[Dict[str, float]] = None,
273
+ ortho_bparams: Optional[Dict[str, bool]] = None,
274
+ ) -> Tuple[Tensor, Tensor]:
275
+ # makes sure that input is contiguous for efficiency.
276
+ # Note: autograd does not support dense gradients for sparse input yet.
277
+ A = A.contiguous() if (not A.is_sparse) else A
278
+ if B is not None:
279
+ B = B.contiguous() if (not B.is_sparse) else B
280
+
281
+ D, U = _lobpcg(
282
+ A,
283
+ k,
284
+ B,
285
+ X,
286
+ n,
287
+ iK,
288
+ niter,
289
+ tol,
290
+ largest,
291
+ method,
292
+ tracker,
293
+ ortho_iparams,
294
+ ortho_fparams,
295
+ ortho_bparams,
296
+ )
297
+
298
+ ctx.save_for_backward(A, B, D, U)
299
+ ctx.largest = largest
300
+
301
+ return D, U
302
+
303
+ @staticmethod
304
+ def backward(ctx, D_grad, U_grad):
305
+ A_grad = B_grad = None
306
+ grads = [None] * 14
307
+
308
+ A, B, D, U = ctx.saved_tensors
309
+ largest = ctx.largest
310
+
311
+ # lobpcg.backward has some limitations. Checks for unsupported input
312
+ if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
313
+ raise ValueError(
314
+ "lobpcg.backward does not support sparse input yet."
315
+ "Note that lobpcg.forward does though."
316
+ )
317
+ if (
318
+ A.dtype in (torch.complex64, torch.complex128)
319
+ or B is not None
320
+ and B.dtype in (torch.complex64, torch.complex128)
321
+ ):
322
+ raise ValueError(
323
+ "lobpcg.backward does not support complex input yet."
324
+ "Note that lobpcg.forward does though."
325
+ )
326
+ if B is not None:
327
+ raise ValueError(
328
+ "lobpcg.backward does not support backward with B != I yet."
329
+ )
330
+
331
+ if largest is None:
332
+ largest = True
333
+
334
+ # symeig backward
335
+ if B is None:
336
+ A_grad = _symeig_backward(D_grad, U_grad, A, D, U, largest)
337
+
338
+ # A has index 0
339
+ grads[0] = A_grad
340
+ # B has index 2
341
+ grads[2] = B_grad
342
+ return tuple(grads)
343
+
344
+
345
+ def lobpcg(
346
+ A: Tensor,
347
+ k: Optional[int] = None,
348
+ B: Optional[Tensor] = None,
349
+ X: Optional[Tensor] = None,
350
+ n: Optional[int] = None,
351
+ iK: Optional[Tensor] = None,
352
+ niter: Optional[int] = None,
353
+ tol: Optional[float] = None,
354
+ largest: Optional[bool] = None,
355
+ method: Optional[str] = None,
356
+ tracker: None = None,
357
+ ortho_iparams: Optional[Dict[str, int]] = None,
358
+ ortho_fparams: Optional[Dict[str, float]] = None,
359
+ ortho_bparams: Optional[Dict[str, bool]] = None,
360
+ ) -> Tuple[Tensor, Tensor]:
361
+ """Find the k largest (or smallest) eigenvalues and the corresponding
362
+ eigenvectors of a symmetric positive definite generalized
363
+ eigenvalue problem using matrix-free LOBPCG methods.
364
+
365
+ This function is a front-end to the following LOBPCG algorithms
366
+ selectable via `method` argument:
367
+
368
+ `method="basic"` - the LOBPCG method introduced by Andrew
369
+ Knyazev, see [Knyazev2001]. A less robust method, may fail when
370
+ Cholesky is applied to singular input.
371
+
372
+ `method="ortho"` - the LOBPCG method with orthogonal basis
373
+ selection [StathopoulosEtal2002]. A robust method.
374
+
375
+ Supported inputs are dense, sparse, and batches of dense matrices.
376
+
377
+ .. note:: In general, the basic method spends least time per
378
+ iteration. However, the robust methods converge much faster and
379
+ are more stable. So, the usage of the basic method is generally
380
+ not recommended but there exist cases where the usage of the
381
+ basic method may be preferred.
382
+
383
+ .. warning:: The backward method does not support sparse and complex inputs.
384
+ It works only when `B` is not provided (i.e. `B == None`).
385
+ We are actively working on extensions, and the details of
386
+ the algorithms are going to be published promptly.
387
+
388
+ .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not.
389
+ To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric
390
+ in first-order optimization routines, prior to running `lobpcg`
391
+ we do the following symmetrization map: `A -> (A + A.t()) / 2`.
392
+ The map is performed only when the `A` requires gradients.
393
+
394
+ Args:
395
+
396
+ A (Tensor): the input tensor of size :math:`(*, m, m)`
397
+
398
+ B (Tensor, optional): the input tensor of size :math:`(*, m,
399
+ m)`. When not specified, `B` is interpreted as
400
+ identity matrix.
401
+
402
+ X (tensor, optional): the input tensor of size :math:`(*, m, n)`
403
+ where `k <= n <= m`. When specified, it is used as
404
+ initial approximation of eigenvectors. X must be a
405
+ dense tensor.
406
+
407
+ iK (tensor, optional): the input tensor of size :math:`(*, m,
408
+ m)`. When specified, it will be used as preconditioner.
409
+
410
+ k (integer, optional): the number of requested
411
+ eigenpairs. Default is the number of :math:`X`
412
+ columns (when specified) or `1`.
413
+
414
+ n (integer, optional): if :math:`X` is not specified then `n`
415
+ specifies the size of the generated random
416
+ approximation of eigenvectors. Default value for `n`
417
+ is `k`. If :math:`X` is specified, the value of `n`
418
+ (when specified) must be the number of :math:`X`
419
+ columns.
420
+
421
+ tol (float, optional): residual tolerance for stopping
422
+ criterion. Default is `feps ** 0.5` where `feps` is
423
+ smallest non-zero floating-point number of the given
424
+ input tensor `A` data type.
425
+
426
+ largest (bool, optional): when True, solve the eigenproblem for
427
+ the largest eigenvalues. Otherwise, solve the
428
+ eigenproblem for smallest eigenvalues. Default is
429
+ `True`.
430
+
431
+ method (str, optional): select LOBPCG method. See the
432
+ description of the function above. Default is
433
+ "ortho".
434
+
435
+ niter (int, optional): maximum number of iterations. When
436
+ reached, the iteration process is hard-stopped and
437
+ the current approximation of eigenpairs is returned.
438
+ For infinite iteration but until convergence criteria
439
+ is met, use `-1`.
440
+
441
+ tracker (callable, optional) : a function for tracing the
442
+ iteration process. When specified, it is called at
443
+ each iteration step with LOBPCG instance as an
444
+ argument. The LOBPCG instance holds the full state of
445
+ the iteration process in the following attributes:
446
+
447
+ `iparams`, `fparams`, `bparams` - dictionaries of
448
+ integer, float, and boolean valued input
449
+ parameters, respectively
450
+
451
+ `ivars`, `fvars`, `bvars`, `tvars` - dictionaries
452
+ of integer, float, boolean, and Tensor valued
453
+ iteration variables, respectively.
454
+
455
+ `A`, `B`, `iK` - input Tensor arguments.
456
+
457
+ `E`, `X`, `S`, `R` - iteration Tensor variables.
458
+
459
+ For instance:
460
+
461
+ `ivars["istep"]` - the current iteration step
462
+ `X` - the current approximation of eigenvectors
463
+ `E` - the current approximation of eigenvalues
464
+ `R` - the current residual
465
+ `ivars["converged_count"]` - the current number of converged eigenpairs
466
+ `tvars["rerr"]` - the current state of convergence criteria
467
+
468
+ Note that when `tracker` stores Tensor objects from
469
+ the LOBPCG instance, it must make copies of these.
470
+
471
+ If `tracker` sets `bvars["force_stop"] = True`, the
472
+ iteration process will be hard-stopped.
473
+
474
+ ortho_iparams, ortho_fparams, ortho_bparams (dict, optional):
475
+ various parameters to LOBPCG algorithm when using
476
+ `method="ortho"`.
477
+
478
+ Returns:
479
+
480
+ E (Tensor): tensor of eigenvalues of size :math:`(*, k)`
481
+
482
+ X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)`
483
+
484
+ References:
485
+
486
+ [Knyazev2001] Andrew V. Knyazev. (2001) Toward the Optimal
487
+ Preconditioned Eigensolver: Locally Optimal Block Preconditioned
488
+ Conjugate Gradient Method. SIAM J. Sci. Comput., 23(2),
489
+ 517-541. (25 pages)
490
+ https://epubs.siam.org/doi/abs/10.1137/S1064827500366124
491
+
492
+ [StathopoulosEtal2002] Andreas Stathopoulos and Kesheng
493
+ Wu. (2002) A Block Orthogonalization Procedure with Constant
494
+ Synchronization Requirements. SIAM J. Sci. Comput., 23(6),
495
+ 2165-2182. (18 pages)
496
+ https://epubs.siam.org/doi/10.1137/S1064827500370883
497
+
498
+ [DuerschEtal2018] Jed A. Duersch, Meiyue Shao, Chao Yang, Ming
499
+ Gu. (2018) A Robust and Efficient Implementation of LOBPCG.
500
+ SIAM J. Sci. Comput., 40(5), C655-C676. (22 pages)
501
+ https://epubs.siam.org/doi/abs/10.1137/17M1129830
502
+
503
+ """
504
+
505
+ if not torch.jit.is_scripting():
506
+ tensor_ops = (A, B, X, iK)
507
+ if not set(map(type, tensor_ops)).issubset(
508
+ (torch.Tensor, type(None))
509
+ ) and has_torch_function(tensor_ops):
510
+ return handle_torch_function(
511
+ lobpcg,
512
+ tensor_ops,
513
+ A,
514
+ k=k,
515
+ B=B,
516
+ X=X,
517
+ n=n,
518
+ iK=iK,
519
+ niter=niter,
520
+ tol=tol,
521
+ largest=largest,
522
+ method=method,
523
+ tracker=tracker,
524
+ ortho_iparams=ortho_iparams,
525
+ ortho_fparams=ortho_fparams,
526
+ ortho_bparams=ortho_bparams,
527
+ )
528
+
529
+ if not torch._jit_internal.is_scripting():
530
+ if A.requires_grad or (B is not None and B.requires_grad):
531
+ # While it is expected that `A` is symmetric,
532
+ # the `A_grad` might be not. Therefore we perform the trick below,
533
+ # so that `A_grad` becomes symmetric.
534
+ # The symmetrization is important for first-order optimization methods,
535
+ # so that (A - alpha * A_grad) is still a symmetric matrix.
536
+ # Same holds for `B`.
537
+ A_sym = (A + A.mT) / 2
538
+ B_sym = (B + B.mT) / 2 if (B is not None) else None
539
+
540
+ return LOBPCGAutogradFunction.apply(
541
+ A_sym,
542
+ k,
543
+ B_sym,
544
+ X,
545
+ n,
546
+ iK,
547
+ niter,
548
+ tol,
549
+ largest,
550
+ method,
551
+ tracker,
552
+ ortho_iparams,
553
+ ortho_fparams,
554
+ ortho_bparams,
555
+ )
556
+ else:
557
+ if A.requires_grad or (B is not None and B.requires_grad):
558
+ raise RuntimeError(
559
+ "Script and require grads is not supported atm."
560
+ "If you just want to do the forward, use .detach()"
561
+ "on A and B before calling into lobpcg"
562
+ )
563
+
564
+ return _lobpcg(
565
+ A,
566
+ k,
567
+ B,
568
+ X,
569
+ n,
570
+ iK,
571
+ niter,
572
+ tol,
573
+ largest,
574
+ method,
575
+ tracker,
576
+ ortho_iparams,
577
+ ortho_fparams,
578
+ ortho_bparams,
579
+ )
580
+
581
+
582
+ def _lobpcg(
583
+ A: Tensor,
584
+ k: Optional[int] = None,
585
+ B: Optional[Tensor] = None,
586
+ X: Optional[Tensor] = None,
587
+ n: Optional[int] = None,
588
+ iK: Optional[Tensor] = None,
589
+ niter: Optional[int] = None,
590
+ tol: Optional[float] = None,
591
+ largest: Optional[bool] = None,
592
+ method: Optional[str] = None,
593
+ tracker: None = None,
594
+ ortho_iparams: Optional[Dict[str, int]] = None,
595
+ ortho_fparams: Optional[Dict[str, float]] = None,
596
+ ortho_bparams: Optional[Dict[str, bool]] = None,
597
+ ) -> Tuple[Tensor, Tensor]:
598
+ # A must be square:
599
+ assert A.shape[-2] == A.shape[-1], A.shape
600
+ if B is not None:
601
+ # A and B must have the same shapes:
602
+ assert A.shape == B.shape, (A.shape, B.shape)
603
+
604
+ dtype = _utils.get_floating_dtype(A)
605
+ device = A.device
606
+ if tol is None:
607
+ feps = {torch.float32: 1.2e-07, torch.float64: 2.23e-16}[dtype]
608
+ tol = feps**0.5
609
+
610
+ m = A.shape[-1]
611
+ k = (1 if X is None else X.shape[-1]) if k is None else k
612
+ n = (k if n is None else n) if X is None else X.shape[-1]
613
+
614
+ if m < 3 * n:
615
+ raise ValueError(
616
+ f"LPBPCG algorithm is not applicable when the number of A rows (={m})"
617
+ f" is smaller than 3 x the number of requested eigenpairs (={n})"
618
+ )
619
+
620
+ method = "ortho" if method is None else method
621
+
622
+ iparams = {
623
+ "m": m,
624
+ "n": n,
625
+ "k": k,
626
+ "niter": 1000 if niter is None else niter,
627
+ }
628
+
629
+ fparams = {
630
+ "tol": tol,
631
+ }
632
+
633
+ bparams = {"largest": True if largest is None else largest}
634
+
635
+ if method == "ortho":
636
+ if ortho_iparams is not None:
637
+ iparams.update(ortho_iparams)
638
+ if ortho_fparams is not None:
639
+ fparams.update(ortho_fparams)
640
+ if ortho_bparams is not None:
641
+ bparams.update(ortho_bparams)
642
+ iparams["ortho_i_max"] = iparams.get("ortho_i_max", 3)
643
+ iparams["ortho_j_max"] = iparams.get("ortho_j_max", 3)
644
+ fparams["ortho_tol"] = fparams.get("ortho_tol", tol)
645
+ fparams["ortho_tol_drop"] = fparams.get("ortho_tol_drop", tol)
646
+ fparams["ortho_tol_replace"] = fparams.get("ortho_tol_replace", tol)
647
+ bparams["ortho_use_drop"] = bparams.get("ortho_use_drop", False)
648
+
649
+ if not torch.jit.is_scripting():
650
+ LOBPCG.call_tracker = LOBPCG_call_tracker # type: ignore[method-assign]
651
+
652
+ if len(A.shape) > 2:
653
+ N = int(torch.prod(torch.tensor(A.shape[:-2])))
654
+ bA = A.reshape((N,) + A.shape[-2:])
655
+ bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None
656
+ bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None
657
+ bE = torch.empty((N, k), dtype=dtype, device=device)
658
+ bXret = torch.empty((N, m, k), dtype=dtype, device=device)
659
+
660
+ for i in range(N):
661
+ A_ = bA[i]
662
+ B_ = bB[i] if bB is not None else None
663
+ X_ = (
664
+ torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i]
665
+ )
666
+ assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n))
667
+ iparams["batch_index"] = i
668
+ worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker)
669
+ worker.run()
670
+ bE[i] = worker.E[:k]
671
+ bXret[i] = worker.X[:, :k]
672
+
673
+ if not torch.jit.is_scripting():
674
+ LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore[method-assign]
675
+
676
+ return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k))
677
+
678
+ X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X
679
+ assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n))
680
+
681
+ worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker)
682
+
683
+ worker.run()
684
+
685
+ if not torch.jit.is_scripting():
686
+ LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore[method-assign]
687
+
688
+ return worker.E[:k], worker.X[:, :k]
689
+
690
+
691
+ class LOBPCG:
692
+ """Worker class of LOBPCG methods."""
693
+
694
+ def __init__(
695
+ self,
696
+ A: Optional[Tensor],
697
+ B: Optional[Tensor],
698
+ X: Tensor,
699
+ iK: Optional[Tensor],
700
+ iparams: Dict[str, int],
701
+ fparams: Dict[str, float],
702
+ bparams: Dict[str, bool],
703
+ method: str,
704
+ tracker: None,
705
+ ) -> None:
706
+ # constant parameters
707
+ self.A = A
708
+ self.B = B
709
+ self.iK = iK
710
+ self.iparams = iparams
711
+ self.fparams = fparams
712
+ self.bparams = bparams
713
+ self.method = method
714
+ self.tracker = tracker
715
+ m = iparams["m"]
716
+ n = iparams["n"]
717
+
718
+ # variable parameters
719
+ self.X = X
720
+ self.E = torch.zeros((n,), dtype=X.dtype, device=X.device)
721
+ self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device)
722
+ self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device)
723
+ self.tvars: Dict[str, Tensor] = {}
724
+ self.ivars: Dict[str, int] = {"istep": 0}
725
+ self.fvars: Dict[str, float] = {"_": 0.0}
726
+ self.bvars: Dict[str, bool] = {"_": False}
727
+
728
+ def __str__(self):
729
+ lines = ["LOPBCG:"]
730
+ lines += [f" iparams={self.iparams}"]
731
+ lines += [f" fparams={self.fparams}"]
732
+ lines += [f" bparams={self.bparams}"]
733
+ lines += [f" ivars={self.ivars}"]
734
+ lines += [f" fvars={self.fvars}"]
735
+ lines += [f" bvars={self.bvars}"]
736
+ lines += [f" tvars={self.tvars}"]
737
+ lines += [f" A={self.A}"]
738
+ lines += [f" B={self.B}"]
739
+ lines += [f" iK={self.iK}"]
740
+ lines += [f" X={self.X}"]
741
+ lines += [f" E={self.E}"]
742
+ r = ""
743
+ for line in lines:
744
+ r += line + "\n"
745
+ return r
746
+
747
+ def update(self):
748
+ """Set and update iteration variables."""
749
+ if self.ivars["istep"] == 0:
750
+ X_norm = float(torch.norm(self.X))
751
+ iX_norm = X_norm**-1
752
+ A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm
753
+ B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm
754
+ self.fvars["X_norm"] = X_norm
755
+ self.fvars["A_norm"] = A_norm
756
+ self.fvars["B_norm"] = B_norm
757
+ self.ivars["iterations_left"] = self.iparams["niter"]
758
+ self.ivars["converged_count"] = 0
759
+ self.ivars["converged_end"] = 0
760
+
761
+ if self.method == "ortho":
762
+ self._update_ortho()
763
+ else:
764
+ self._update_basic()
765
+
766
+ self.ivars["iterations_left"] = self.ivars["iterations_left"] - 1
767
+ self.ivars["istep"] = self.ivars["istep"] + 1
768
+
769
+ def update_residual(self):
770
+ """Update residual R from A, B, X, E."""
771
+ mm = _utils.matmul
772
+ self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E
773
+
774
+ def update_converged_count(self):
775
+ """Determine the number of converged eigenpairs using backward stable
776
+ convergence criterion, see discussion in Sec 4.3 of [DuerschEtal2018].
777
+
778
+ Users may redefine this method for custom convergence criteria.
779
+ """
780
+ # (...) -> int
781
+ prev_count = self.ivars["converged_count"]
782
+ tol = self.fparams["tol"]
783
+ A_norm = self.fvars["A_norm"]
784
+ B_norm = self.fvars["B_norm"]
785
+ E, X, R = self.E, self.X, self.R
786
+ rerr = (
787
+ torch.norm(R, 2, (0,))
788
+ * (torch.norm(X, 2, (0,)) * (A_norm + E[: X.shape[-1]] * B_norm)) ** -1
789
+ )
790
+ converged = rerr.real < tol # this is a norm so imag is 0.0
791
+ count = 0
792
+ for b in converged:
793
+ if not b:
794
+ # ignore convergence of following pairs to ensure
795
+ # strict ordering of eigenpairs
796
+ break
797
+ count += 1
798
+ assert (
799
+ count >= prev_count
800
+ ), f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease"
801
+ self.ivars["converged_count"] = count
802
+ self.tvars["rerr"] = rerr
803
+ return count
804
+
805
+ def stop_iteration(self):
806
+ """Return True to stop iterations.
807
+
808
+ Note that tracker (if defined) can force-stop iterations by
809
+ setting ``worker.bvars['force_stop'] = True``.
810
+ """
811
+ return (
812
+ self.bvars.get("force_stop", False)
813
+ or self.ivars["iterations_left"] == 0
814
+ or self.ivars["converged_count"] >= self.iparams["k"]
815
+ )
816
+
817
+ def run(self):
818
+ """Run LOBPCG iterations.
819
+
820
+ Use this method as a template for implementing LOBPCG
821
+ iteration scheme with custom tracker that is compatible with
822
+ TorchScript.
823
+ """
824
+ self.update()
825
+
826
+ if not torch.jit.is_scripting() and self.tracker is not None:
827
+ self.call_tracker()
828
+
829
+ while not self.stop_iteration():
830
+ self.update()
831
+
832
+ if not torch.jit.is_scripting() and self.tracker is not None:
833
+ self.call_tracker()
834
+
835
+ @torch.jit.unused
836
+ def call_tracker(self):
837
+ """Interface for tracking iteration process in Python mode.
838
+
839
+ Tracking the iteration process is disabled in TorchScript
840
+ mode. In fact, one should specify tracker=None when JIT
841
+ compiling functions using lobpcg.
842
+ """
843
+ # do nothing when in TorchScript mode
844
+
845
+ # Internal methods
846
+
847
+ def _update_basic(self):
848
+ """
849
+ Update or initialize iteration variables when `method == "basic"`.
850
+ """
851
+ mm = torch.matmul
852
+ ns = self.ivars["converged_end"]
853
+ nc = self.ivars["converged_count"]
854
+ n = self.iparams["n"]
855
+ largest = self.bparams["largest"]
856
+
857
+ if self.ivars["istep"] == 0:
858
+ Ri = self._get_rayleigh_ritz_transform(self.X)
859
+ M = _utils.qform(_utils.qform(self.A, self.X), Ri)
860
+ E, Z = _utils.symeig(M, largest)
861
+ self.X[:] = mm(self.X, mm(Ri, Z))
862
+ self.E[:] = E
863
+ np = 0
864
+ self.update_residual()
865
+ nc = self.update_converged_count()
866
+ self.S[..., :n] = self.X
867
+
868
+ W = _utils.matmul(self.iK, self.R)
869
+ self.ivars["converged_end"] = ns = n + np + W.shape[-1]
870
+ self.S[:, n + np : ns] = W
871
+ else:
872
+ S_ = self.S[:, nc:ns]
873
+ Ri = self._get_rayleigh_ritz_transform(S_)
874
+ M = _utils.qform(_utils.qform(self.A, S_), Ri)
875
+ E_, Z = _utils.symeig(M, largest)
876
+ self.X[:, nc:] = mm(S_, mm(Ri, Z[:, : n - nc]))
877
+ self.E[nc:] = E_[: n - nc]
878
+ P = mm(S_, mm(Ri, Z[:, n : 2 * n - nc]))
879
+ np = P.shape[-1]
880
+
881
+ self.update_residual()
882
+ nc = self.update_converged_count()
883
+ self.S[..., :n] = self.X
884
+ self.S[:, n : n + np] = P
885
+ W = _utils.matmul(self.iK, self.R[:, nc:])
886
+
887
+ self.ivars["converged_end"] = ns = n + np + W.shape[-1]
888
+ self.S[:, n + np : ns] = W
889
+
890
+ def _update_ortho(self):
891
+ """
892
+ Update or initialize iteration variables when `method == "ortho"`.
893
+ """
894
+ mm = torch.matmul
895
+ ns = self.ivars["converged_end"]
896
+ nc = self.ivars["converged_count"]
897
+ n = self.iparams["n"]
898
+ largest = self.bparams["largest"]
899
+
900
+ if self.ivars["istep"] == 0:
901
+ Ri = self._get_rayleigh_ritz_transform(self.X)
902
+ M = _utils.qform(_utils.qform(self.A, self.X), Ri)
903
+ E, Z = _utils.symeig(M, largest)
904
+ self.X = mm(self.X, mm(Ri, Z))
905
+ self.update_residual()
906
+ np = 0
907
+ nc = self.update_converged_count()
908
+ self.S[:, :n] = self.X
909
+ W = self._get_ortho(self.R, self.X)
910
+ ns = self.ivars["converged_end"] = n + np + W.shape[-1]
911
+ self.S[:, n + np : ns] = W
912
+
913
+ else:
914
+ S_ = self.S[:, nc:ns]
915
+ # Rayleigh-Ritz procedure
916
+ E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest)
917
+
918
+ # Update E, X, P
919
+ self.X[:, nc:] = mm(S_, Z[:, : n - nc])
920
+ self.E[nc:] = E_[: n - nc]
921
+ P = mm(S_, mm(Z[:, n - nc :], _utils.basis(Z[: n - nc, n - nc :].mT)))
922
+ np = P.shape[-1]
923
+
924
+ # check convergence
925
+ self.update_residual()
926
+ nc = self.update_converged_count()
927
+
928
+ # update S
929
+ self.S[:, :n] = self.X
930
+ self.S[:, n : n + np] = P
931
+ W = self._get_ortho(self.R[:, nc:], self.S[:, : n + np])
932
+ ns = self.ivars["converged_end"] = n + np + W.shape[-1]
933
+ self.S[:, n + np : ns] = W
934
+
935
+ def _get_rayleigh_ritz_transform(self, S):
936
+ """Return a transformation matrix that is used in Rayleigh-Ritz
937
+ procedure for reducing a general eigenvalue problem :math:`(S^TAS)
938
+ C = (S^TBS) C E` to a standard eigenvalue problem :math: `(Ri^T
939
+ S^TAS Ri) Z = Z E` where `C = Ri Z`.
940
+
941
+ .. note:: In the original Rayleight-Ritz procedure in
942
+ [DuerschEtal2018], the problem is formulated as follows::
943
+
944
+ SAS = S^T A S
945
+ SBS = S^T B S
946
+ D = (<diagonal matrix of SBS>) ** -1/2
947
+ R^T R = Cholesky(D SBS D)
948
+ Ri = D R^-1
949
+ solve symeig problem Ri^T SAS Ri Z = Theta Z
950
+ C = Ri Z
951
+
952
+ To reduce the number of matrix products (denoted by empty
953
+ space between matrices), here we introduce element-wise
954
+ products (denoted by symbol `*`) so that the Rayleight-Ritz
955
+ procedure becomes::
956
+
957
+ SAS = S^T A S
958
+ SBS = S^T B S
959
+ d = (<diagonal of SBS>) ** -1/2 # this is 1-d column vector
960
+ dd = d d^T # this is 2-d matrix
961
+ R^T R = Cholesky(dd * SBS)
962
+ Ri = R^-1 * d # broadcasting
963
+ solve symeig problem Ri^T SAS Ri Z = Theta Z
964
+ C = Ri Z
965
+
966
+ where `dd` is 2-d matrix that replaces matrix products `D M
967
+ D` with one element-wise product `M * dd`; and `d` replaces
968
+ matrix product `D M` with element-wise product `M *
969
+ d`. Also, creating the diagonal matrix `D` is avoided.
970
+
971
+ Args:
972
+ S (Tensor): the matrix basis for the search subspace, size is
973
+ :math:`(m, n)`.
974
+
975
+ Returns:
976
+ Ri (tensor): upper-triangular transformation matrix of size
977
+ :math:`(n, n)`.
978
+
979
+ """
980
+ B = self.B
981
+ mm = torch.matmul
982
+ SBS = _utils.qform(B, S)
983
+ d_row = SBS.diagonal(0, -2, -1) ** -0.5
984
+ d_col = d_row.reshape(d_row.shape[0], 1)
985
+ # TODO use torch.linalg.cholesky_solve once it is implemented
986
+ R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True)
987
+ return torch.linalg.solve_triangular(
988
+ R, d_row.diag_embed(), upper=True, left=False
989
+ )
990
+
991
+ def _get_svqb(self, U: Tensor, drop: bool, tau: float) -> Tensor:
992
+ """Return B-orthonormal U.
993
+
994
+ .. note:: When `drop` is `False` then `svqb` is based on the
995
+ Algorithm 4 from [DuerschPhD2015] that is a slight
996
+ modification of the corresponding algorithm
997
+ introduced in [StathopolousWu2002].
998
+
999
+ Args:
1000
+
1001
+ U (Tensor) : initial approximation, size is (m, n)
1002
+ drop (bool) : when True, drop columns that
1003
+ contribution to the `span([U])` is small.
1004
+ tau (float) : positive tolerance
1005
+
1006
+ Returns:
1007
+
1008
+ U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size
1009
+ is (m, n1), where `n1 = n` if `drop` is `False,
1010
+ otherwise `n1 <= n`.
1011
+
1012
+ """
1013
+ if torch.numel(U) == 0:
1014
+ return U
1015
+ UBU = _utils.qform(self.B, U)
1016
+ d = UBU.diagonal(0, -2, -1)
1017
+
1018
+ # Detect and drop exact zero columns from U. While the test
1019
+ # `abs(d) == 0` is unlikely to be True for random data, it is
1020
+ # possible to construct input data to lobpcg where it will be
1021
+ # True leading to a failure (notice the `d ** -0.5` operation
1022
+ # in the original algorithm). To prevent the failure, we drop
1023
+ # the exact zero columns here and then continue with the
1024
+ # original algorithm below.
1025
+ nz = torch.where(abs(d) != 0.0)
1026
+ assert len(nz) == 1, nz
1027
+ if len(nz[0]) < len(d):
1028
+ U = U[:, nz[0]]
1029
+ if torch.numel(U) == 0:
1030
+ return U
1031
+ UBU = _utils.qform(self.B, U)
1032
+ d = UBU.diagonal(0, -2, -1)
1033
+ nz = torch.where(abs(d) != 0.0)
1034
+ assert len(nz[0]) == len(d)
1035
+
1036
+ # The original algorithm 4 from [DuerschPhD2015].
1037
+ d_col = (d**-0.5).reshape(d.shape[0], 1)
1038
+ DUBUD = (UBU * d_col) * d_col.mT
1039
+ E, Z = _utils.symeig(DUBUD)
1040
+ t = tau * abs(E).max()
1041
+ if drop:
1042
+ keep = torch.where(E > t)
1043
+ assert len(keep) == 1, keep
1044
+ E = E[keep[0]]
1045
+ Z = Z[:, keep[0]]
1046
+ d_col = d_col[keep[0]]
1047
+ else:
1048
+ E[(torch.where(E < t))[0]] = t
1049
+
1050
+ return torch.matmul(U * d_col.mT, Z * E**-0.5)
1051
+
1052
+ def _get_ortho(self, U, V):
1053
+ """Return B-orthonormal U with columns are B-orthogonal to V.
1054
+
1055
+ .. note:: When `bparams["ortho_use_drop"] == False` then
1056
+ `_get_ortho` is based on the Algorithm 3 from
1057
+ [DuerschPhD2015] that is a slight modification of
1058
+ the corresponding algorithm introduced in
1059
+ [StathopolousWu2002]. Otherwise, the method
1060
+ implements Algorithm 6 from [DuerschPhD2015]
1061
+
1062
+ .. note:: If all U columns are B-collinear to V then the
1063
+ returned tensor U will be empty.
1064
+
1065
+ Args:
1066
+
1067
+ U (Tensor) : initial approximation, size is (m, n)
1068
+ V (Tensor) : B-orthogonal external basis, size is (m, k)
1069
+
1070
+ Returns:
1071
+
1072
+ U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`)
1073
+ such that :math:`V^T B U=0`, size is (m, n1),
1074
+ where `n1 = n` if `drop` is `False, otherwise
1075
+ `n1 <= n`.
1076
+ """
1077
+ mm = torch.matmul
1078
+ mm_B = _utils.matmul
1079
+ m = self.iparams["m"]
1080
+ tau_ortho = self.fparams["ortho_tol"]
1081
+ tau_drop = self.fparams["ortho_tol_drop"]
1082
+ tau_replace = self.fparams["ortho_tol_replace"]
1083
+ i_max = self.iparams["ortho_i_max"]
1084
+ j_max = self.iparams["ortho_j_max"]
1085
+ # when use_drop==True, enable dropping U columns that have
1086
+ # small contribution to the `span([U, V])`.
1087
+ use_drop = self.bparams["ortho_use_drop"]
1088
+
1089
+ # clean up variables from the previous call
1090
+ for vkey in list(self.fvars.keys()):
1091
+ if vkey.startswith("ortho_") and vkey.endswith("_rerr"):
1092
+ self.fvars.pop(vkey)
1093
+ self.ivars.pop("ortho_i", 0)
1094
+ self.ivars.pop("ortho_j", 0)
1095
+
1096
+ BV_norm = torch.norm(mm_B(self.B, V))
1097
+ BU = mm_B(self.B, U)
1098
+ VBU = mm(V.mT, BU)
1099
+ i = j = 0
1100
+ stats = ""
1101
+ for i in range(i_max):
1102
+ U = U - mm(V, VBU)
1103
+ drop = False
1104
+ tau_svqb = tau_drop
1105
+ for j in range(j_max):
1106
+ if use_drop:
1107
+ U = self._get_svqb(U, drop, tau_svqb)
1108
+ drop = True
1109
+ tau_svqb = tau_replace
1110
+ else:
1111
+ U = self._get_svqb(U, False, tau_replace)
1112
+ if torch.numel(U) == 0:
1113
+ # all initial U columns are B-collinear to V
1114
+ self.ivars["ortho_i"] = i
1115
+ self.ivars["ortho_j"] = j
1116
+ return U
1117
+ BU = mm_B(self.B, U)
1118
+ UBU = mm(U.mT, BU)
1119
+ U_norm = torch.norm(U)
1120
+ BU_norm = torch.norm(BU)
1121
+ R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype)
1122
+ R_norm = torch.norm(R)
1123
+ # https://github.com/pytorch/pytorch/issues/33810 workaround:
1124
+ rerr = float(R_norm) * float(BU_norm * U_norm) ** -1
1125
+ vkey = f"ortho_UBUmI_rerr[{i}, {j}]"
1126
+ self.fvars[vkey] = rerr
1127
+ if rerr < tau_ortho:
1128
+ break
1129
+ VBU = mm(V.mT, BU)
1130
+ VBU_norm = torch.norm(VBU)
1131
+ U_norm = torch.norm(U)
1132
+ rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1
1133
+ vkey = f"ortho_VBU_rerr[{i}]"
1134
+ self.fvars[vkey] = rerr
1135
+ if rerr < tau_ortho:
1136
+ break
1137
+ if m < U.shape[-1] + V.shape[-1]:
1138
+ # TorchScript needs the class var to be assigned to a local to
1139
+ # do optional type refinement
1140
+ B = self.B
1141
+ assert B is not None
1142
+ raise ValueError(
1143
+ "Overdetermined shape of U:"
1144
+ f" #B-cols(={B.shape[-1]}) >= #U-cols(={U.shape[-1]}) + #V-cols(={V.shape[-1]}) must hold"
1145
+ )
1146
+ self.ivars["ortho_i"] = i
1147
+ self.ivars["ortho_j"] = j
1148
+ return U
1149
+
1150
+
1151
+ # Calling tracker is separated from LOBPCG definitions because
1152
+ # TorchScript does not support user-defined callback arguments:
1153
+ LOBPCG_call_tracker_orig = LOBPCG.call_tracker
1154
+
1155
+
1156
+ def LOBPCG_call_tracker(self):
1157
+ self.tracker(self)
.venv/lib/python3.11/site-packages/torch/_lowrank.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implement various linear algebra algorithms for low rank matrices."""
2
+
3
+ __all__ = ["svd_lowrank", "pca_lowrank"]
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ from torch import _linalg_utils as _utils, Tensor
9
+ from torch.overrides import handle_torch_function, has_torch_function
10
+
11
+
12
+ def get_approximate_basis(
13
+ A: Tensor,
14
+ q: int,
15
+ niter: Optional[int] = 2,
16
+ M: Optional[Tensor] = None,
17
+ ) -> Tensor:
18
+ """Return tensor :math:`Q` with :math:`q` orthonormal columns such
19
+ that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is
20
+ specified, then :math:`Q` is such that :math:`Q Q^H (A - M)`
21
+ approximates :math:`A - M`. without instantiating any tensors
22
+ of the size of :math:`A` or :math:`M`.
23
+
24
+ .. note:: The implementation is based on the Algorithm 4.4 from
25
+ Halko et al., 2009.
26
+
27
+ .. note:: For an adequate approximation of a k-rank matrix
28
+ :math:`A`, where k is not known in advance but could be
29
+ estimated, the number of :math:`Q` columns, q, can be
30
+ choosen according to the following criteria: in general,
31
+ :math:`k <= q <= min(2*k, m, n)`. For large low-rank
32
+ matrices, take :math:`q = k + 5..10`. If k is
33
+ relatively small compared to :math:`min(m, n)`, choosing
34
+ :math:`q = k + 0..2` may be sufficient.
35
+
36
+ .. note:: To obtain repeatable results, reset the seed for the
37
+ pseudorandom number generator
38
+
39
+ Args::
40
+ A (Tensor): the input tensor of size :math:`(*, m, n)`
41
+
42
+ q (int): the dimension of subspace spanned by :math:`Q`
43
+ columns.
44
+
45
+ niter (int, optional): the number of subspace iterations to
46
+ conduct; ``niter`` must be a
47
+ nonnegative integer. In most cases, the
48
+ default value 2 is more than enough.
49
+
50
+ M (Tensor, optional): the input tensor's mean of size
51
+ :math:`(*, m, n)`.
52
+
53
+ References::
54
+ - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
55
+ structure with randomness: probabilistic algorithms for
56
+ constructing approximate matrix decompositions,
57
+ arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
58
+ `arXiv <http://arxiv.org/abs/0909.4061>`_).
59
+ """
60
+
61
+ niter = 2 if niter is None else niter
62
+ dtype = _utils.get_floating_dtype(A) if not A.is_complex() else A.dtype
63
+ matmul = _utils.matmul
64
+
65
+ R = torch.randn(A.shape[-1], q, dtype=dtype, device=A.device)
66
+
67
+ # The following code could be made faster using torch.geqrf + torch.ormqr
68
+ # but geqrf is not differentiable
69
+
70
+ X = matmul(A, R)
71
+ if M is not None:
72
+ X = X - matmul(M, R)
73
+ Q = torch.linalg.qr(X).Q
74
+ for i in range(niter):
75
+ X = matmul(A.mH, Q)
76
+ if M is not None:
77
+ X = X - matmul(M.mH, Q)
78
+ Q = torch.linalg.qr(X).Q
79
+ X = matmul(A, Q)
80
+ if M is not None:
81
+ X = X - matmul(M, Q)
82
+ Q = torch.linalg.qr(X).Q
83
+ return Q
84
+
85
+
86
+ def svd_lowrank(
87
+ A: Tensor,
88
+ q: Optional[int] = 6,
89
+ niter: Optional[int] = 2,
90
+ M: Optional[Tensor] = None,
91
+ ) -> Tuple[Tensor, Tensor, Tensor]:
92
+ r"""Return the singular value decomposition ``(U, S, V)`` of a matrix,
93
+ batches of matrices, or a sparse matrix :math:`A` such that
94
+ :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then
95
+ SVD is computed for the matrix :math:`A - M`.
96
+
97
+ .. note:: The implementation is based on the Algorithm 5.1 from
98
+ Halko et al., 2009.
99
+
100
+ .. note:: For an adequate approximation of a k-rank matrix
101
+ :math:`A`, where k is not known in advance but could be
102
+ estimated, the number of :math:`Q` columns, q, can be
103
+ choosen according to the following criteria: in general,
104
+ :math:`k <= q <= min(2*k, m, n)`. For large low-rank
105
+ matrices, take :math:`q = k + 5..10`. If k is
106
+ relatively small compared to :math:`min(m, n)`, choosing
107
+ :math:`q = k + 0..2` may be sufficient.
108
+
109
+ .. note:: This is a randomized method. To obtain repeatable results,
110
+ set the seed for the pseudorandom number generator
111
+
112
+ .. note:: In general, use the full-rank SVD implementation
113
+ :func:`torch.linalg.svd` for dense matrices due to its 10x
114
+ higher performance characteristics. The low-rank SVD
115
+ will be useful for huge sparse matrices that
116
+ :func:`torch.linalg.svd` cannot handle.
117
+
118
+ Args::
119
+ A (Tensor): the input tensor of size :math:`(*, m, n)`
120
+
121
+ q (int, optional): a slightly overestimated rank of A.
122
+
123
+ niter (int, optional): the number of subspace iterations to
124
+ conduct; niter must be a nonnegative
125
+ integer, and defaults to 2
126
+
127
+ M (Tensor, optional): the input tensor's mean of size
128
+ :math:`(*, m, n)`, which will be broadcasted
129
+ to the size of A in this function.
130
+
131
+ References::
132
+ - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
133
+ structure with randomness: probabilistic algorithms for
134
+ constructing approximate matrix decompositions,
135
+ arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
136
+ `arXiv <https://arxiv.org/abs/0909.4061>`_).
137
+
138
+ """
139
+ if not torch.jit.is_scripting():
140
+ tensor_ops = (A, M)
141
+ if not set(map(type, tensor_ops)).issubset(
142
+ (torch.Tensor, type(None))
143
+ ) and has_torch_function(tensor_ops):
144
+ return handle_torch_function(
145
+ svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M
146
+ )
147
+ return _svd_lowrank(A, q=q, niter=niter, M=M)
148
+
149
+
150
+ def _svd_lowrank(
151
+ A: Tensor,
152
+ q: Optional[int] = 6,
153
+ niter: Optional[int] = 2,
154
+ M: Optional[Tensor] = None,
155
+ ) -> Tuple[Tensor, Tensor, Tensor]:
156
+ # Algorithm 5.1 in Halko et al., 2009
157
+
158
+ q = 6 if q is None else q
159
+ m, n = A.shape[-2:]
160
+ matmul = _utils.matmul
161
+ if M is not None:
162
+ M = M.broadcast_to(A.size())
163
+
164
+ # Assume that A is tall
165
+ if m < n:
166
+ A = A.mH
167
+ if M is not None:
168
+ M = M.mH
169
+
170
+ Q = get_approximate_basis(A, q, niter=niter, M=M)
171
+ B = matmul(Q.mH, A)
172
+ if M is not None:
173
+ B = B - matmul(Q.mH, M)
174
+ U, S, Vh = torch.linalg.svd(B, full_matrices=False)
175
+ V = Vh.mH
176
+ U = Q.matmul(U)
177
+
178
+ if m < n:
179
+ U, V = V, U
180
+
181
+ return U, S, V
182
+
183
+
184
+ def pca_lowrank(
185
+ A: Tensor,
186
+ q: Optional[int] = None,
187
+ center: bool = True,
188
+ niter: int = 2,
189
+ ) -> Tuple[Tensor, Tensor, Tensor]:
190
+ r"""Performs linear Principal Component Analysis (PCA) on a low-rank
191
+ matrix, batches of such matrices, or sparse matrix.
192
+
193
+ This function returns a namedtuple ``(U, S, V)`` which is the
194
+ nearly optimal approximation of a singular value decomposition of
195
+ a centered matrix :math:`A` such that :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`
196
+
197
+ .. note:: The relation of ``(U, S, V)`` to PCA is as follows:
198
+
199
+ - :math:`A` is a data matrix with ``m`` samples and
200
+ ``n`` features
201
+
202
+ - the :math:`V` columns represent the principal directions
203
+
204
+ - :math:`S ** 2 / (m - 1)` contains the eigenvalues of
205
+ :math:`A^T A / (m - 1)` which is the covariance of
206
+ ``A`` when ``center=True`` is provided.
207
+
208
+ - ``matmul(A, V[:, :k])`` projects data to the first k
209
+ principal components
210
+
211
+ .. note:: Different from the standard SVD, the size of returned
212
+ matrices depend on the specified rank and q
213
+ values as follows:
214
+
215
+ - :math:`U` is m x q matrix
216
+
217
+ - :math:`S` is q-vector
218
+
219
+ - :math:`V` is n x q matrix
220
+
221
+ .. note:: To obtain repeatable results, reset the seed for the
222
+ pseudorandom number generator
223
+
224
+ Args:
225
+
226
+ A (Tensor): the input tensor of size :math:`(*, m, n)`
227
+
228
+ q (int, optional): a slightly overestimated rank of
229
+ :math:`A`. By default, ``q = min(6, m,
230
+ n)``.
231
+
232
+ center (bool, optional): if True, center the input tensor,
233
+ otherwise, assume that the input is
234
+ centered.
235
+
236
+ niter (int, optional): the number of subspace iterations to
237
+ conduct; niter must be a nonnegative
238
+ integer, and defaults to 2.
239
+
240
+ References::
241
+
242
+ - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
243
+ structure with randomness: probabilistic algorithms for
244
+ constructing approximate matrix decompositions,
245
+ arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
246
+ `arXiv <http://arxiv.org/abs/0909.4061>`_).
247
+
248
+ """
249
+
250
+ if not torch.jit.is_scripting():
251
+ if type(A) is not torch.Tensor and has_torch_function((A,)):
252
+ return handle_torch_function(
253
+ pca_lowrank, (A,), A, q=q, center=center, niter=niter
254
+ )
255
+
256
+ (m, n) = A.shape[-2:]
257
+
258
+ if q is None:
259
+ q = min(6, m, n)
260
+ elif not (q >= 0 and q <= min(m, n)):
261
+ raise ValueError(
262
+ f"q(={q}) must be non-negative integer and not greater than min(m, n)={min(m, n)}"
263
+ )
264
+ if not (niter >= 0):
265
+ raise ValueError(f"niter(={niter}) must be non-negative integer")
266
+
267
+ dtype = _utils.get_floating_dtype(A)
268
+
269
+ if not center:
270
+ return _svd_lowrank(A, q, niter=niter, M=None)
271
+
272
+ if _utils.is_sparse(A):
273
+ if len(A.shape) != 2:
274
+ raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor")
275
+ c = torch.sparse.sum(A, dim=(-2,)) / m
276
+ # reshape c
277
+ column_indices = c.indices()[0]
278
+ indices = torch.zeros(
279
+ 2,
280
+ len(column_indices),
281
+ dtype=column_indices.dtype,
282
+ device=column_indices.device,
283
+ )
284
+ indices[0] = column_indices
285
+ C_t = torch.sparse_coo_tensor(
286
+ indices, c.values(), (n, 1), dtype=dtype, device=A.device
287
+ )
288
+
289
+ ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device)
290
+ M = torch.sparse.mm(C_t, ones_m1_t).mT
291
+ return _svd_lowrank(A, q, niter=niter, M=M)
292
+ else:
293
+ C = A.mean(dim=(-2,), keepdim=True)
294
+ return _svd_lowrank(A - C, q, niter=niter, M=None)
.venv/lib/python3.11/site-packages/torch/_meta_registrations.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/torch/_namedtensor_internals.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from collections import OrderedDict
3
+
4
+
5
+ """
6
+ This file contains helper functions that implement experimental functionality
7
+ for named tensors in python. All of these are experimental, unstable, and
8
+ subject to change or deletion.
9
+ """
10
+
11
+
12
+ def check_serializing_named_tensor(tensor):
13
+ if tensor.has_names():
14
+ raise RuntimeError(
15
+ "NYI: Named tensors don't support serialization. Please drop "
16
+ "names via `tensor = tensor.rename(None)` before serialization."
17
+ )
18
+
19
+
20
+ def build_dim_map(tensor):
21
+ """Returns a map of { dim: dim_name } where dim is a name if the dim is named
22
+ and the dim index otherwise."""
23
+ return OrderedDict(
24
+ [(idx if name is None else name, name) for idx, name in enumerate(tensor.names)]
25
+ )
26
+
27
+
28
+ def unzip_namedshape(namedshape):
29
+ if isinstance(namedshape, OrderedDict):
30
+ namedshape = namedshape.items()
31
+ if not hasattr(namedshape, "__iter__") and not isinstance(namedshape, tuple):
32
+ raise RuntimeError(
33
+ f"Expected namedshape to be OrderedDict or iterable of tuples, got: {type(namedshape)}"
34
+ )
35
+ if len(namedshape) == 0:
36
+ raise RuntimeError("Expected namedshape to non-empty.")
37
+ return zip(*namedshape)
38
+
39
+
40
+ def namer_api_name(inplace):
41
+ if inplace:
42
+ return "rename_"
43
+ else:
44
+ return "rename"
45
+
46
+
47
+ def is_ellipsis(item):
48
+ return item == Ellipsis or item == "..."
49
+
50
+
51
+ def single_ellipsis_index(names, fn_name):
52
+ ellipsis_indices = [i for i, name in enumerate(names) if is_ellipsis(name)]
53
+ if len(ellipsis_indices) >= 2:
54
+ raise RuntimeError(
55
+ f"{fn_name}: More than one Ellipsis ('...') found in names ("
56
+ f"{names}). This function supports up to one Ellipsis."
57
+ )
58
+ if len(ellipsis_indices) == 1:
59
+ return ellipsis_indices[0]
60
+ return None
61
+
62
+
63
+ def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names):
64
+ return names[numel_pre_glob : len(names) - numel_post_glob]
65
+
66
+
67
+ def replace_ellipsis_by_position(ellipsis_idx, names, tensor_names):
68
+ globbed_names = expand_single_ellipsis(
69
+ ellipsis_idx, len(names) - ellipsis_idx - 1, tensor_names
70
+ )
71
+ return names[:ellipsis_idx] + globbed_names + names[ellipsis_idx + 1 :]
72
+
73
+
74
+ def resolve_ellipsis(names, tensor_names, fn_name):
75
+ """
76
+ Expands ... inside `names` to be equal to a list of names from `tensor_names`.
77
+ """
78
+ ellipsis_idx = single_ellipsis_index(names, fn_name)
79
+ if ellipsis_idx is None:
80
+ return names
81
+ return replace_ellipsis_by_position(ellipsis_idx, names, tensor_names)
82
+
83
+
84
+ def update_names_with_list(tensor, names, inplace):
85
+ # Special case for tensor.rename(None)
86
+ if len(names) == 1 and names[0] is None:
87
+ return tensor._update_names(None, inplace)
88
+
89
+ return tensor._update_names(
90
+ resolve_ellipsis(names, tensor.names, namer_api_name(inplace)), inplace
91
+ )
92
+
93
+
94
+ def update_names_with_mapping(tensor, rename_map, inplace):
95
+ dim_map = build_dim_map(tensor)
96
+ for old_dim in rename_map.keys():
97
+ new_dim = rename_map[old_dim]
98
+ if old_dim in dim_map.keys():
99
+ dim_map[old_dim] = new_dim
100
+ else:
101
+ raise RuntimeError(
102
+ f"{namer_api_name(inplace)}: Tried to rename dim '{old_dim}' to dim "
103
+ f"{new_dim} in Tensor[{tensor.names}] but dim '{old_dim}' does not exist"
104
+ )
105
+ return tensor._update_names(tuple(dim_map.values()), inplace)
106
+
107
+
108
+ def update_names(tensor, names, rename_map, inplace):
109
+ """There are two usages:
110
+
111
+ tensor.rename(*names) returns a view on tensor with named dims `names`.
112
+ `names` must be of length `tensor.dim()`; otherwise, if '...' is in `names`,
113
+ then it is expanded greedily to be equal to the corresponding names from
114
+ `tensor.names`.
115
+
116
+ For example,
117
+ ```
118
+ >>> # xdoctest: +SKIP
119
+ >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
120
+ >>> x.rename('...', 'height', 'width').names
121
+ ('N', 'C', 'height', 'width')
122
+
123
+ >>> # xdoctest: +SKIP
124
+ >>> x.rename('batch', '...', 'width').names
125
+ ('batch', 'C', 'H', 'width')
126
+
127
+ ```
128
+
129
+ tensor.rename(**rename_map) returns a view on tensor that has rename dims
130
+ as specified in the mapping `rename_map`.
131
+
132
+ For example,
133
+ ```
134
+ >>> # xdoctest: +SKIP
135
+ >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
136
+ >>> x.rename(W='width', H='height').names
137
+ ('N', 'C', 'height', 'width')
138
+
139
+ ```
140
+
141
+ Finally, tensor.rename has an in-place version called tensor.rename_.
142
+ """
143
+ has_names = len(names) > 0
144
+ has_rename_pairs = bool(rename_map)
145
+ if has_names and has_rename_pairs:
146
+ raise RuntimeError(
147
+ f"{namer_api_name(inplace)}: This function takes either positional "
148
+ f"args or keyword args, but not both. Use tensor.{namer_api_name(inplace)}(*names) "
149
+ f"to name dims and tensor.{namer_api_name(inplace)}(**rename_map) to rename "
150
+ "dims."
151
+ )
152
+
153
+ # Special case for tensor.rename(*[]), which is valid for a 0 dim tensor.
154
+ if not has_names and not has_rename_pairs:
155
+ return update_names_with_list(tensor, names, inplace)
156
+
157
+ if has_names:
158
+ return update_names_with_list(tensor, names, inplace)
159
+ return update_names_with_mapping(tensor, rename_map, inplace)
.venv/lib/python3.11/site-packages/torch/_ops.py ADDED
@@ -0,0 +1,1355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import abc
3
+ import contextlib
4
+ import ctypes
5
+ import importlib
6
+ import inspect
7
+ import sys
8
+ import types
9
+ from typing import Any, Callable, Dict, List, Set, Type, Union
10
+
11
+ import torch
12
+ import torch.utils._pytree as pytree
13
+ from torch import _utils_internal
14
+ from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey
15
+ from torch._functorch.pyfunctorch import dispatch_functorch
16
+ from torch.utils._python_dispatch import TorchDispatchMode
17
+
18
+
19
+ # Query `hasattr` only once.
20
+ _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
21
+
22
+
23
+ @contextlib.contextmanager
24
+ def dl_open_guard():
25
+ """
26
+ Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
27
+ shared library to load custom operators.
28
+ """
29
+ if not _SET_GLOBAL_FLAGS:
30
+ yield
31
+ return
32
+ old_flags = sys.getdlopenflags()
33
+ sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
34
+ try:
35
+ yield
36
+ finally:
37
+ sys.setdlopenflags(old_flags)
38
+
39
+
40
+ class OperatorBase:
41
+ """
42
+ Base class for OpOverload (which represents C++ ATen operators) and HigherOrderOperator
43
+ (which represents Python-only operators that are unrepresentable in TorchScript).
44
+ """
45
+
46
+ def __init__(self):
47
+ # The dispatch cache precomputes a mapping of dispatch key that the
48
+ # dispatcher wants to dispatch to, to an actual implementation of the
49
+ # dispatch key. Confusingly, the actual implementation could *also* be a
50
+ # dispatch key, but in this case, this refers to the C++ kernel that
51
+ # was registered to some dispatch key. Aliases are permitted in the
52
+ # latter but not the former; for example, you might lookup the
53
+ # entry for AutogradCPU, and this maps you to the Autograd key for
54
+ # the generic autograd kernel that works for all devices. Since this
55
+ # is the Python dispatcher, you can also put an arbitrary Python
56
+ # callable to call instead. This handler gets precisely the
57
+ # args/kwargs that the operator was __call__'ed with.
58
+ # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp
59
+ # for use with OpOverload; cache lookup is done entirely from C++
60
+ # for speed.
61
+ # TODO: The cache is NOT currently used by HigherOrderOperator, but it should!
62
+ self._dispatch_cache: Dict[
63
+ DispatchKey, Union[DispatchKey, Callable[..., Any]]
64
+ ] = {}
65
+
66
+ # This table allows you to override the behavior of a particular
67
+ # dispatch key to call a custom Python function, rather than the
68
+ # ordinary C++ configured behavior. This is the raison d'etre of
69
+ # Python dispatcher: to let you program the dispatcher from Python
70
+ # in case you need something unusual, and don't want to clobber
71
+ # the existing registrations using the Python operator registration
72
+ # API.
73
+ self.py_kernels: Dict[DispatchKey, Callable[..., Any]] = {}
74
+
75
+ # This table allows you to override the behavior of a particular
76
+ # operator for a particular TorchDispatchMode. In practice,
77
+ # we are using this mostly for ProxyTensorMode. Modes can be
78
+ # thought of as an open world extension of dispatch keys, so it
79
+ # makes sense that you should be able to register them, the same
80
+ # way you can register dispatch keys.
81
+ self.python_key_table: Dict[
82
+ Union[Type[TorchDispatchMode], Type[torch.Tensor]], Callable[..., Any]
83
+ ] = {}
84
+
85
+ # This table allows you to override the behavior of functorch
86
+ # transformations. NB: this currently only does something for
87
+ # HigherOrderOperator
88
+ self.functorch_table = {}
89
+
90
+ def __call__(self, *args, **kwargs):
91
+ raise NotImplementedError
92
+
93
+ def has_kernel_for_dispatch_key(self, k):
94
+ return k in self.py_kernels
95
+
96
+ def has_kernel_for_any_dispatch_key(self, ks):
97
+ for k in self.py_kernels:
98
+ if not torch._C._dispatch_is_alias_key(k) and ks.has(k):
99
+ return True
100
+ return False
101
+
102
+ def py_impl(self, k):
103
+ def inner(fn):
104
+ if inspect.isclass(k) and (
105
+ issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
106
+ ):
107
+ assert k not in self.python_key_table
108
+ # TODO(voz): Should we replace setting DispatchKey.Python entirely with setting mode keys?
109
+ self.python_key_table[k] = fn
110
+ self._dispatch_cache.clear()
111
+ return fn
112
+
113
+ if isinstance(k, torch._C._functorch.TransformType):
114
+ assert k not in self.functorch_table
115
+ self.functorch_table[k] = fn
116
+ return fn
117
+
118
+ assert isinstance(k, DispatchKey)
119
+ assert (
120
+ k != DispatchKey.Python
121
+ ), "Please register a mode for the torch._C.DispatchKey.Python key instead."
122
+
123
+ if k in self.py_kernels:
124
+ raise RuntimeError(
125
+ f"Trying to override a python impl for {k} on operator {self.name()}"
126
+ )
127
+ self.py_kernels[k] = fn
128
+ self._dispatch_cache.clear()
129
+ return fn
130
+
131
+ return inner
132
+
133
+ # Registers an implementation to all **3** variants of functionalization that we have:
134
+ # - DispatchKey.Functionalize
135
+ # - functorch.TransformType.Functionalize
136
+ # - FunctionalTensorMode
137
+ # Example:
138
+ # @py_functionalize_impl
139
+ # def functionalize_rule(ctx, inner_f, *args):
140
+ # args_unwrapped = ctx.unwrap_tensors(args)
141
+ # with ctx.redispatch_to_next():
142
+ # out = ctx.functionalize(inner_f)(*args_unwrapped)
143
+ # return ctx.wrap_tensors(out)
144
+ def py_functionalize_impl(self, fn):
145
+ from torch._subclasses.functional_tensor import (
146
+ CppFunctionalizeAPI as _CppFunctionalizeAPI,
147
+ FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
148
+ PythonFunctionalizeAPI as _PythonFunctionalizeAPI,
149
+ )
150
+
151
+ # Construct our three flavors of functionalization,
152
+ # each of which have slightly different wrap/unwrap/redispatch policies
153
+ def functionalize_dk_fn(*args, **kwargs):
154
+ return fn(_CppFunctionalizeAPI(), *args, **kwargs)
155
+
156
+ def functionalize_dispatch_mode_fn(mode, *args, **kwargs):
157
+ return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs)
158
+
159
+ def functionalize_functorch_fn(interpreter, *args, **kwargs):
160
+ return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)
161
+
162
+ self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn)
163
+ self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)(
164
+ functionalize_dispatch_mode_fn
165
+ )
166
+ self.py_impl(torch._C._functorch.TransformType.Functionalize)(
167
+ functionalize_functorch_fn
168
+ )
169
+
170
+ return fn
171
+
172
+ def name(self):
173
+ raise NotImplementedError
174
+
175
+
176
+ # Equivalent to computeDispatchTableEntryWithDebug
177
+ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type]
178
+ # 1. (Direct) operator registration
179
+ if op.has_kernel_for_dispatch_key(k):
180
+ return k
181
+ # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
182
+ cand = DispatchKey.CompositeExplicitAutogradNonFunctional
183
+ if (
184
+ k == DispatchKey.Undefined or is_included_in_alias(k, cand)
185
+ ) and op.has_kernel_for_dispatch_key(cand):
186
+ return cand
187
+ # 2.2 Use CompositeExplicitAutograd kernel if available
188
+ cand = DispatchKey.CompositeExplicitAutograd
189
+ if (
190
+ k == DispatchKey.Undefined or is_included_in_alias(k, cand)
191
+ ) and op.has_kernel_for_dispatch_key(cand):
192
+ return cand
193
+ has_backend_kernel = op.has_kernel_for_any_dispatch_key(
194
+ torch._C._dispatch_get_backend_keyset_from_autograd(k)
195
+ ) or op.has_kernel_for_dispatch_key(DispatchKey.CompositeExplicitAutograd)
196
+ # 2.3. Use CompositeImplicitAutograd kernel if available
197
+ cand = DispatchKey.CompositeImplicitAutogradNestedTensor
198
+ if (
199
+ (k != DispatchKey.Undefined and is_included_in_alias(k, cand))
200
+ and op.has_kernel_for_dispatch_key(cand)
201
+ and not has_backend_kernel
202
+ ):
203
+ return cand
204
+ cand = DispatchKey.CompositeImplicitAutograd
205
+ if (
206
+ k == DispatchKey.Undefined or is_included_in_alias(k, cand)
207
+ ) and op.has_kernel_for_dispatch_key(cand):
208
+ if k == DispatchKey.AutogradOther and op.has_kernel_for_any_dispatch_key(
209
+ torch._C._dispatch_autogradother_backends
210
+ ):
211
+ raise RuntimeError("ambiguous autogradother kernel")
212
+ elif not has_backend_kernel:
213
+ return cand
214
+ # 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
215
+ cand = DispatchKey.Autograd
216
+ if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
217
+ return cand
218
+ # 2.5 Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available
219
+ cand = DispatchKey.FuncTorchBatchedDecomposition
220
+ if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
221
+ return cand
222
+ # Backend fallback
223
+ if torch._C._dispatch_has_backend_fallback(k):
224
+ # The dispatch key itself will implicitly route to backend fallback.
225
+ # This is probably not great for the pure Python implementation.
226
+ return k
227
+ raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
228
+
229
+
230
+ _higher_order_ops: Dict[str, "HigherOrderOperator"] = {}
231
+
232
+ _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [
233
+ DispatchKey.PythonDispatcher, # type: ignore[attr-defined]
234
+ DispatchKey.PythonTLSSnapshot, # type: ignore[attr-defined]
235
+ DispatchKey.ADInplaceOrView,
236
+ DispatchKey.BackendSelect,
237
+ DispatchKey.AutocastCPU, # type: ignore[attr-defined]
238
+ DispatchKey.AutocastCUDA, # type: ignore[attr-defined]
239
+ ]
240
+
241
+
242
+ class HigherOrderOperator(OperatorBase, abc.ABC):
243
+ # The HigherOrderOperator will appear as torch.ops.higher_order.{name}
244
+ #
245
+ # If you're creating a new HigherOrderOperator, please do not change the
246
+ # default. Adding operators to the global torch.ops namespace is a bad
247
+ # practice due to name collisions.
248
+ def __init__(self, name):
249
+ super().__init__()
250
+ if type(self) is HigherOrderOperator:
251
+ raise RuntimeError(
252
+ "Direct instantiation of HigherOrderOperator is not allowed. Please subclass it."
253
+ )
254
+ self._name = name
255
+
256
+ # Make _OPNamespace not scream, this whole name based association needs a good hard look
257
+ self.__name__ = name
258
+ _higher_order_ops[name] = self
259
+ self._ns = "higher_order"
260
+ self.__module__ = "torch.ops.higher_order"
261
+
262
+ self.non_fallthrough_keys = torch._C._dispatch_keyset_full()
263
+
264
+ for dispatch_key in _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS:
265
+ self.fallthrough(dispatch_key)
266
+
267
+ # [NOTE] We have to register pre-dispatch key implementation
268
+ # because sometimes HOP use aot-dispatch tracing to detect certaion
269
+ # mutations. This is problematic when we are functionalizing HOP
270
+ # during pre-dispatch because when the inner tracer starts, it will see
271
+ # that PreDispatch key is still active. In that case, we just redispatch
272
+ # it to next key. This is only safe to do when PreDispatch key stack has no
273
+ # active modes.
274
+
275
+ def py_impl(self, k):
276
+ if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k):
277
+ self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
278
+ return super().py_impl(k)
279
+
280
+ @property
281
+ def namespace(self):
282
+ return self._ns
283
+
284
+ def fallthrough(self, dispatch_key):
285
+ self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key)
286
+
287
+ # Use positional-only argument to avoid naming collide with custom ops arguments
288
+ # that are named "self".
289
+ def dispatch(self, /, dispatch_key, *args, **kwargs):
290
+ from torch.utils._python_dispatch import _get_current_dispatch_mode
291
+
292
+ if dispatch_key in self._dispatch_cache:
293
+ kernel = self._dispatch_cache[dispatch_key]
294
+ assert not isinstance(kernel, DispatchKey)
295
+ return kernel(*args, **kwargs)
296
+
297
+ if dispatch_key == DispatchKey.FuncTorchDynamicLayerFrontMode:
298
+ return dispatch_functorch(self, args, kwargs)
299
+
300
+ if dispatch_key == DispatchKey.Python:
301
+ # Keep the following 1:1 with handle_torch_function_no_python_arg_parser
302
+ # in torch/csrc/utils/python_arg_parser.cpp
303
+
304
+ overloaded_args_list = []
305
+
306
+ def has_python_key(tensor):
307
+ return torch._C._dispatch_keys(tensor).has("Python")
308
+
309
+ def check_overloaded(arg):
310
+ if isinstance(arg, torch.Tensor) and has_python_key(arg):
311
+ overloaded_args_list.append(arg)
312
+
313
+ for arg in (*args, *kwargs.values()):
314
+ check_overloaded(arg)
315
+ if isinstance(arg, (list, tuple)):
316
+ for a in arg:
317
+ check_overloaded(a)
318
+
319
+ overloaded_args = tuple(overloaded_args_list)
320
+ overloaded_types = tuple(type(arg) for arg in overloaded_args)
321
+
322
+ # Step 1: dispatch on any user TorchDispatchModes
323
+ from torch.utils._python_dispatch import _pop_mode_temporarily
324
+
325
+ curr_mode = _get_current_dispatch_mode()
326
+ if curr_mode is not None:
327
+ if type(curr_mode) in self.python_key_table:
328
+ handler = self.python_key_table[type(curr_mode)]
329
+ with _pop_mode_temporarily() as mode:
330
+ # "natural" calling convention: (mode, *args, **kwargs)
331
+ # TODO(rzou): we should support torch_dispatch calling convention too.
332
+ result = handler(mode, *args, **kwargs)
333
+ else:
334
+ raise NotImplementedError(
335
+ f"There was no rule registered for HOP {self._name} and mode {curr_mode}. "
336
+ f"We recommend filing an issue."
337
+ )
338
+ if result is not NotImplemented:
339
+ return result
340
+
341
+ # Step 2: dispatch on any subclasses
342
+ for arg in overloaded_args:
343
+ subclass_type = type(arg)
344
+ if (
345
+ subclass_type.__torch_dispatch__
346
+ == torch._C._disabled_torch_dispatch_impl
347
+ ):
348
+ continue
349
+ if subclass_type in self.python_key_table:
350
+ handler = self.python_key_table[subclass_type]
351
+ # "natural" calling convention: (*args, **kwargs)
352
+ # TODO(rzou): we should support torch_dispatch calling convention too.
353
+ result = handler(*args, **kwargs)
354
+ else:
355
+ raise NotImplementedError(
356
+ f"There was no rule registered for HOP {self._name} and subclass {subclass_type}. "
357
+ f"We recommend filing an issue."
358
+ )
359
+ if result is not NotImplemented:
360
+ return result
361
+
362
+ # All handlers returned NotImplemented
363
+ raise TypeError(
364
+ f"Multiple dispatch failed for {self._name}. There was no registered that "
365
+ f"did not return NotImplemented. Use HOP.py_impl to register some. "
366
+ f"Tried mode: {curr_mode}) and subclasses: "
367
+ f"{[type(a) for a in overloaded_args]}"
368
+ )
369
+
370
+ functionality_key = torch._C._to_functionality_key(dispatch_key) # type: ignore[attr-defined]
371
+ if functionality_key == DispatchKey.PreDispatch:
372
+ from torch.utils._python_dispatch import _pop_mode_temporarily
373
+
374
+ # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
375
+ # calls inside of a mode.
376
+ if (
377
+ _len_torch_dispatch_stack_pre_dispatch() > 0
378
+ ) and not torch._C._dispatch_tls_is_dispatch_key_excluded(
379
+ DispatchKey.Python
380
+ ):
381
+ curr_mode = _get_current_dispatch_mode_pre_dispatch()
382
+ assert (
383
+ curr_mode is not None
384
+ ), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode."
385
+ assert (
386
+ type(curr_mode) in self.python_key_table
387
+ ), f"Current active mode {curr_mode} not registered"
388
+ handler = self.python_key_table[type(curr_mode)]
389
+ with _pop_mode_temporarily(functionality_key) as mode:
390
+ return handler(mode, *args, **kwargs)
391
+
392
+ final_key = resolve_key(self, dispatch_key)
393
+
394
+ # This can current fail due to backend fallbacks. You just have to
395
+ # register them by hand for HigherOrderOperator.
396
+ if final_key not in self.py_kernels:
397
+ raise NotImplementedError(
398
+ f"could not find kernel for HigherOrderOperator {self._name} "
399
+ f"at dispatch key {final_key} (resolved from {dispatch_key})"
400
+ )
401
+
402
+ # [NOTE] We shouldn't cache PreDispatch kernel here because depending
403
+ # on what modes are active, predispatch behaviour is different.
404
+ # Also we do same thing for normal ops:
405
+ # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
406
+ if dispatch_key != DispatchKey.PreDispatch:
407
+ self._dispatch_cache[dispatch_key] = self.py_kernels[final_key]
408
+ kernel = self.py_kernels[final_key]
409
+ # It's illegal to register DispatchKey to py_kernels, since there's no
410
+ # C++ kernel to call into
411
+ assert not isinstance(kernel, DispatchKey)
412
+ return kernel(*args, **kwargs)
413
+
414
+ @abc.abstractmethod
415
+ def __call__(self, /, *args, **kwargs):
416
+ # Dynamo already traces the body of HigherOrderOp beforehand when it
417
+ # so no need to trace into it.
418
+ from torch._dynamo import disable
419
+
420
+ @disable
421
+ def wrapper():
422
+ flat_args = _to_flat_tuple(args, kwargs)
423
+ if torch.overrides.has_torch_function(flat_args):
424
+ return torch.overrides.handle_torch_function(
425
+ self, flat_args, *args, **kwargs
426
+ )
427
+
428
+ dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
429
+ return self.dispatch(
430
+ dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
431
+ )
432
+
433
+ return wrapper()
434
+
435
+ def __str__(self):
436
+ return f"{self.name()}"
437
+
438
+ def name(self):
439
+ return self._name
440
+
441
+
442
+ def _to_flat_tuple(args, kwargs):
443
+ return pytree.arg_tree_leaves(*args, **kwargs)
444
+
445
+
446
+ def _compute_keyset(args, kwargs, non_fallthrough_keys):
447
+ tensors = _get_tensors(args, kwargs)
448
+ return key_extractor(tensors, non_fallthrough_keys)
449
+
450
+
451
+ def _get_tensors(args, kwargs):
452
+ flat_all = _to_flat_tuple(args, kwargs)
453
+ tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
454
+ return tuple(tensor_args)
455
+
456
+
457
+ # Note - this should maintain identical impl to the C++ dispatcher key extraction logic
458
+ # at ATen/core/dispatch/DispatchKeyExtractor.h
459
+ def key_extractor(tensors, key_mask):
460
+ key_set = torch._C._dispatch_tls_local_include_set()
461
+ for tensor in tensors:
462
+ key_set = key_set | torch._C._dispatch_keys(tensor)
463
+ key_set = key_set - torch._C._dispatch_tls_local_exclude_set()
464
+ key_set = key_set & key_mask
465
+ return key_set
466
+
467
+
468
+ # Mode stack for PreDispatchKey
469
+ # it should always have three keys with
470
+ # priority given to FunctionalTensorMode and
471
+ # then ProxyTorchDispatchMode. It means that
472
+ # slot 0 belongs to ProxyTorchDispatchMode and
473
+ # slot 1 belongs to FunctionalTensorMode.
474
+ #
475
+ # SchemaCheckMode is separate from the other 2,
476
+ # and is only valid when the stack is empty.
477
+ # SchemaCheckMode is for testing purposes, and
478
+ # is meant to run in eager mode on concrete inputs,
479
+ # checking for incorrect schemas in regards to
480
+ # aliasing or mutating ops.
481
+ class _ModeStackStateForPreDispatch:
482
+ def __init__(self):
483
+ self.__infra_modes = [None, None]
484
+ self._schema_check_mode = None
485
+
486
+ def set(self, index, mode):
487
+ assert index < len(self.__infra_modes)
488
+ self.__infra_modes[index] = mode
489
+
490
+ def get(self, index):
491
+ assert index < len(self.__infra_modes)
492
+ return self.__infra_modes[index]
493
+
494
+ def count(self):
495
+ return len([i for i in self.__infra_modes if i is not None]) + int(
496
+ self._schema_check_mode is not None
497
+ )
498
+
499
+
500
+ _mode_stack_state_for_pre_dispatch = _ModeStackStateForPreDispatch()
501
+
502
+
503
+ def unset_mode_pre_dispatch(mode_key, schema_check=False):
504
+ current_mode_stack_pre_dispatch = mode_stack_state_for_pre_dispatch()
505
+ assert mode_key is None or mode_key in (
506
+ torch._C._TorchDispatchModeKey.PROXY,
507
+ torch._C._TorchDispatchModeKey.FUNCTIONAL,
508
+ )
509
+ if schema_check:
510
+ assert mode_key is None
511
+
512
+ def _unset_mode():
513
+ if mode_key == torch._C._TorchDispatchModeKey.PROXY:
514
+ current_mode = current_mode_stack_pre_dispatch.get(0)
515
+ mode_stack_state_for_pre_dispatch().set(0, None)
516
+ return current_mode
517
+ elif mode_key == torch._C._TorchDispatchModeKey.FUNCTIONAL:
518
+ current_mode = current_mode_stack_pre_dispatch.get(1)
519
+ mode_stack_state_for_pre_dispatch().set(1, None)
520
+ return current_mode
521
+ else:
522
+ current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode
523
+ mode_stack_state_for_pre_dispatch()._schema_check_mode = None
524
+ return current_mode
525
+
526
+ current_mode = _unset_mode()
527
+
528
+ new_pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()
529
+ # When we are unsetting a mode, we need to check if there is
530
+ # active mode left on the PreDispatch key. If there is nothing
531
+ # active, we need to remove PreDispatch key from local dispatch include
532
+ # set.
533
+ if new_pre_dispatch_len == 0:
534
+ torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, False)
535
+
536
+ return current_mode
537
+
538
+
539
+ def _set_mode_pre_dispatch(mode):
540
+ from torch._subclasses.functional_tensor import FunctionalTensorMode
541
+ from torch._subclasses.schema_check_mode import SchemaCheckMode
542
+ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
543
+
544
+ assert isinstance(
545
+ mode,
546
+ (
547
+ FunctionalTensorMode,
548
+ ProxyTorchDispatchMode,
549
+ SchemaCheckMode,
550
+ ),
551
+ )
552
+
553
+ previous_mode_stack_len = _len_torch_dispatch_stack_pre_dispatch()
554
+ if isinstance(mode, SchemaCheckMode):
555
+ current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode
556
+ if previous_mode_stack_len > 0:
557
+ raise AssertionError(
558
+ "SchemaCheckMode for pre-dispatch must be used exclusively, found other modes on the stack"
559
+ )
560
+ mode_stack_state_for_pre_dispatch()._schema_check_mode = mode
561
+ elif isinstance(mode, FunctionalTensorMode):
562
+ current_mode = mode_stack_state_for_pre_dispatch().get(1)
563
+ assert current_mode is None
564
+ mode_stack_state_for_pre_dispatch().set(1, mode)
565
+ else:
566
+ current_mode = mode_stack_state_for_pre_dispatch().get(0)
567
+ assert current_mode is None
568
+ mode_stack_state_for_pre_dispatch().set(0, mode)
569
+
570
+ # When we are setting a mode, we need to check if there is
571
+ # active mode left on the PreDispatch key. If there was nothing
572
+ # active before setting this mode, it means that PreDispatch key
573
+ # was turned off. So we need to turn it on again.
574
+ if previous_mode_stack_len == 0:
575
+ torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, True)
576
+
577
+
578
+ def _pop_mode_from_pre_dispatch():
579
+ mode_stack = mode_stack_state_for_pre_dispatch()
580
+ pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()
581
+
582
+ if pre_dispatch_len == 0:
583
+ raise AssertionError("Trying to pop empty mode stack")
584
+
585
+ if mode_stack._schema_check_mode is not None:
586
+ return unset_mode_pre_dispatch(None, schema_check=True)
587
+ if mode_stack.get(1) is not None:
588
+ return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.FUNCTIONAL)
589
+ if mode_stack.get(0) is not None:
590
+ return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)
591
+
592
+
593
+ def _len_torch_dispatch_stack_pre_dispatch():
594
+ return mode_stack_state_for_pre_dispatch().count()
595
+
596
+
597
+ def _get_dispatch_mode_pre_dispatch(mode_key):
598
+ assert mode_key in (
599
+ torch._C._TorchDispatchModeKey.PROXY,
600
+ torch._C._TorchDispatchModeKey.FUNCTIONAL,
601
+ )
602
+ if mode_key == torch._C._TorchDispatchModeKey.PROXY:
603
+ return mode_stack_state_for_pre_dispatch().get(0)
604
+ else:
605
+ return mode_stack_state_for_pre_dispatch().get(1)
606
+
607
+
608
+ def _get_current_dispatch_mode_pre_dispatch():
609
+ if mode_stack_state_for_pre_dispatch()._schema_check_mode is not None:
610
+ return mode_stack_state_for_pre_dispatch()._schema_check_mode
611
+ else:
612
+ stack_len = mode_stack_state_for_pre_dispatch().count()
613
+ if stack_len == 2:
614
+ return mode_stack_state_for_pre_dispatch().get(1)
615
+ if stack_len == 1:
616
+ return (
617
+ mode_stack_state_for_pre_dispatch().get(1)
618
+ if mode_stack_state_for_pre_dispatch().get(1) is not None
619
+ else mode_stack_state_for_pre_dispatch().get(0)
620
+ )
621
+ return None
622
+
623
+
624
+ def mode_stack_state_for_pre_dispatch():
625
+ global _mode_stack_state_for_pre_dispatch
626
+ return _mode_stack_state_for_pre_dispatch
627
+
628
+
629
+ cached_ops: Set["OpOverload"] = set()
630
+
631
+
632
+ def add_cached_op(op_overload):
633
+ global cached_ops
634
+ cached_ops.add(op_overload)
635
+
636
+
637
+ def reset_cached_ops():
638
+ global cached_ops
639
+ cached_ops.clear()
640
+
641
+
642
+ def get_cached_ops():
643
+ global cached_ops
644
+ return cached_ops
645
+
646
+
647
+ # Each OpOverload object contains pointer to a a specific operator overload, a pointer to the parent `OpOverloadPacket` object.
648
+ # You can obtain an OpOverload object through attribute query on OpOverloadPacket.
649
+ class OpOverload(OperatorBase):
650
+ def __init__(self, overloadpacket, op, op_dk, schema, tags):
651
+ super().__init__()
652
+ self._op = op
653
+ self._op_dk = op_dk
654
+ self._schema = schema
655
+ self._overloadpacket = overloadpacket
656
+ self._tags = tags
657
+ self._overloadname = (
658
+ "default" if schema.overload_name == "" else schema.overload_name
659
+ )
660
+ self._name = self._schema.name
661
+ if schema.overload_name:
662
+ self._name += "." + schema.overload_name
663
+ self.__name__ = f"{self._schema.name.split('::')[1]}.{self._overloadname}"
664
+ self.__module__ = overloadpacket.__module__
665
+ op.__module__ = overloadpacket.__module__
666
+ self.__qualname__ = self._name
667
+ self.__annotations__ = {}
668
+ # Only compute the OperatorHandle when we need it. Not all OpOverloads have
669
+ # OperatorHandles (the TorchScript ones don't...)
670
+ self._lazy_handle = None
671
+
672
+ # If the OpOverload was constructed from a Library.def in Python.
673
+ self._defined_in_python = self.__qualname__ in torch.library._defs
674
+
675
+ # Logic replicated from aten/src/ATen/native/MathBitsFallback.h
676
+ is_write = None
677
+ for a in self._schema.arguments:
678
+ if a.alias_info is None:
679
+ continue
680
+ if is_write is None:
681
+ is_write = a.alias_info.is_write
682
+ else:
683
+ # We will conservatively call mixed mutable/non-mutable
684
+ # aliased inputs as NOT a view
685
+ is_write = a.alias_info.is_write or is_write
686
+ self.is_view = is_write is not None and not is_write
687
+
688
+ @property
689
+ def _namespace(self):
690
+ return self._schema.name.split("::")[0]
691
+
692
+ @property
693
+ def _opname(self):
694
+ return self._schema.name.split("::")[1]
695
+
696
+ @property
697
+ def _handle(self):
698
+ if self._lazy_handle is None:
699
+ self._lazy_handle = torch._C._dispatch_find_schema_or_throw(
700
+ self._schema.name, self._schema.overload_name
701
+ )
702
+ return self._lazy_handle
703
+
704
+ # it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
705
+ def __deepcopy__(self, memo=None):
706
+ return self
707
+
708
+ def __repr__(self):
709
+ return "<OpOverload(op='{}.{}', overload='{}')>".format(
710
+ *self._schema.name.split("::"), self._overloadname
711
+ )
712
+
713
+ # Use positional-only argument to avoid naming collision with aten ops arguments
714
+ # that are named "self". This way, all the aten ops can be called by kwargs.
715
+ def __call__(self, /, *args, **kwargs):
716
+ return self._op(*args, **kwargs)
717
+
718
+ # Use positional-only argument to avoid naming collision with aten ops arguments
719
+ # that are named "self". This way, all the aten ops can be called by kwargs.
720
+ def redispatch(self, /, keyset, *args, **kwargs):
721
+ return self._handle.redispatch_boxed(keyset, *args, **kwargs)
722
+
723
+ def __hash__(self):
724
+ return hash(self._op)
725
+
726
+ # `my_namespace.my_op_name.overload_name`
727
+ def __str__(self):
728
+ return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)
729
+
730
+ def has_kernel_for_dispatch_key(self, k):
731
+ return super().has_kernel_for_dispatch_key(
732
+ k
733
+ ) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k)
734
+
735
+ def has_kernel_for_any_dispatch_key(self, ks):
736
+ return torch._C._dispatch_has_kernel_for_any_dispatch_key(
737
+ self.name(), ks
738
+ ) or super().has_kernel_for_any_dispatch_key(ks)
739
+
740
+ @property
741
+ def namespace(self):
742
+ return self._schema.name.split("::")[0]
743
+
744
+ def _can_decompose(self):
745
+ dk = DispatchKey.CompositeImplicitAutograd
746
+ return dk in self.py_kernels or torch._C._dispatch_has_kernel_for_dispatch_key(
747
+ self.name(), dk
748
+ )
749
+
750
+ def decompose(self, *args, **kwargs):
751
+ dk = DispatchKey.CompositeImplicitAutograd
752
+ if dk in self.py_kernels:
753
+ # NB: This branch is not too necessary anymore, because we can
754
+ # apply Python CompositeImplicitAutograd *before* tracing
755
+ # using Python dispatcher (also taking advantage of the autograd
756
+ # formula). But it's included for completeness
757
+ return self.py_kernels[dk](*args, **kwargs)
758
+ elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
759
+ return self._op_dk(dk, *args, **kwargs)
760
+ else:
761
+ return NotImplemented
762
+
763
+ # Remove a dispatch key from the dispatch cache. This will force it to get
764
+ # recomputed the next time. Does nothing
765
+ # WARNING: if you register a dispatch key to py_kernels of an OpOverload,
766
+ # calling _del_dispatch on that key is NOT sufficient to apply your change,
767
+ # because a single registration may affect MULTIPLE dispatch keys (e.g.,
768
+ # registering Autograd affects AutogradCPU). del_dispatch is to be used
769
+ # only if you are specifically modifying how get_dispatch handles a
770
+ # particular input 'key'.
771
+ def _uncache_dispatch(self, key):
772
+ self._dispatch_cache.pop(key, None)
773
+
774
+ # This implements the pre-computation logic for the Python dispatcher.
775
+ def _get_dispatch(self, key):
776
+ # This is only called upon a cache miss
777
+ assert key not in self._dispatch_cache, f"{self} {key}"
778
+
779
+ if key == DispatchKey.Python:
780
+ if not isinstance(self, TorchBindOpOverload) and not self.python_key_table:
781
+ self._dispatch_cache[key] = key
782
+ add_cached_op(self)
783
+ return key
784
+
785
+ def handler(*args, **kwargs):
786
+ from torch.utils._python_dispatch import _get_current_dispatch_mode
787
+
788
+ # TODO: We also need to handle tensor subclasses here
789
+ # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
790
+ curr_mode = type(_get_current_dispatch_mode())
791
+ assert (
792
+ curr_mode is not None
793
+ ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
794
+
795
+ if curr_mode not in self.python_key_table:
796
+ if isinstance(self, TorchBindOpOverload):
797
+ with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
798
+ return torch._library.utils.handle_dispatch_mode(
799
+ mode, self, *args, **kwargs
800
+ )
801
+ else:
802
+ return self._op_dk(key, *args, **kwargs)
803
+
804
+ with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
805
+ return self.python_key_table[curr_mode](mode, *args, **kwargs)
806
+
807
+ self._dispatch_cache[key] = handler
808
+ add_cached_op(self)
809
+ return handler
810
+
811
+ functionality_key = torch._C._to_functionality_key(key) # type: ignore[attr-defined]
812
+ if functionality_key == DispatchKey.PreDispatch:
813
+ curr_stack_len = _len_torch_dispatch_stack_pre_dispatch()
814
+ # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
815
+ # calls inside of a mode.
816
+ if (
817
+ curr_stack_len > 0
818
+ and not torch._C._dispatch_tls_is_dispatch_key_excluded(
819
+ DispatchKey.Python
820
+ )
821
+ ):
822
+
823
+ def handler(*args, **kwargs):
824
+ @contextlib.contextmanager
825
+ def _temporarily_pop_modes_from_pre_dispatch():
826
+ top_mode = _pop_mode_from_pre_dispatch()
827
+ try:
828
+ yield top_mode
829
+ finally:
830
+ _set_mode_pre_dispatch(top_mode)
831
+
832
+ with _temporarily_pop_modes_from_pre_dispatch() as curr_mode:
833
+ return torch._library.utils.handle_dispatch_mode(
834
+ curr_mode, self, *args, **kwargs
835
+ )
836
+
837
+ # Note [Not Caching Per-Dispatch-Key Mode Handlers]
838
+ # Note that we're not caching this handler. There isn't really a point, since the slow bit
839
+ # is the handler itself (in python).
840
+ # Also, not caching means that we don't have to reset the cache when any existing
841
+ # modes go out of scope (which in of itself takes time to loop through all operators).
842
+ return handler
843
+
844
+ final_key = resolve_key(self, key)
845
+
846
+ # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
847
+ cache_result = key != DispatchKey.PreDispatch
848
+
849
+ # TODO: We could potentially have lots of debugging wrappers against
850
+ # dispatch keys; design some general registration mechanism instead of
851
+ # having if statement for each of them
852
+ if key == DispatchKey.Functionalize:
853
+ import torch._dispatch.python as pydispatch
854
+
855
+ if pydispatch.CROSSREF_FUNCTIONALIZE:
856
+ handler = pydispatch.make_crossref_functionalize(self, final_key)
857
+ if cache_result:
858
+ self._dispatch_cache[key] = handler
859
+ add_cached_op(self)
860
+ return handler
861
+
862
+ r = self.py_kernels.get(final_key, final_key)
863
+ if cache_result:
864
+ self._dispatch_cache[key] = r
865
+ add_cached_op(self)
866
+ return r
867
+
868
+ def name(self):
869
+ return self._name
870
+
871
+ @property
872
+ def overloadpacket(self):
873
+ return self._overloadpacket
874
+
875
+ @property
876
+ def op(self):
877
+ return self._op
878
+
879
+ @property
880
+ def tags(self):
881
+ return self._tags
882
+
883
+ # TODO: add more methods to expose information about input and output arguments
884
+
885
+
886
+ # TorchBindOpOverload are those custom ops which have at least one overload's
887
+ # schema consists of torch.ScriptObject (i.e. custom class) input.
888
+ # TorchBindOpOverload will skip C++ dispatcher and purely dispatched in python
889
+ # when its inputs contain FakeScriptObject in a similar way as higher order ops.
890
+ class TorchBindOpOverload(OpOverload):
891
+ def _fallthrough_keys(self) -> List[DispatchKey]:
892
+ # TODO: we should be calling the fallback for these, but a fallthrough is almost close
893
+ # enough to the fallback in most cases that we care about.
894
+ _DEFAULT_FALLTHROUGH_KEYS = [
895
+ DispatchKey.Autograd,
896
+ DispatchKey.AutogradCPU,
897
+ DispatchKey.AutogradCUDA,
898
+ DispatchKey.ADInplaceOrView,
899
+ DispatchKey.BackendSelect,
900
+ DispatchKey.PythonTLSSnapshot,
901
+ DispatchKey.PythonDispatcher,
902
+ ]
903
+
904
+ def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
905
+ if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key):
906
+ return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
907
+ self.name(), key
908
+ )
909
+
910
+ return (
911
+ key not in self.py_kernels
912
+ or self.py_kernels[key] is torch.library.fallthrough_kernel
913
+ )
914
+
915
+ return [
916
+ key
917
+ for key in _DEFAULT_FALLTHROUGH_KEYS
918
+ if _may_use_fallthrough_instead_of_fallback(key)
919
+ ]
920
+
921
+ @contextlib.contextmanager
922
+ def _register_as_effectful_op_temporarily(self):
923
+ from torch._higher_order_ops.effects import (
924
+ _EffectType,
925
+ _register_effectful_op,
926
+ SIDE_EFFECTS,
927
+ )
928
+
929
+ try:
930
+ if self not in SIDE_EFFECTS:
931
+ _register_effectful_op(self, _EffectType.ORDERED)
932
+ yield
933
+ finally:
934
+ if self in SIDE_EFFECTS:
935
+ del SIDE_EFFECTS[self]
936
+
937
+ # Use positional-only argument to avoid naming collision with aten ops arguments
938
+ # that are named "self". This way, all the aten ops can be called by kwargs.
939
+ def __call__(self, /, *args, **kwargs):
940
+ if _must_dispatch_in_python(args, kwargs):
941
+ # When any inputs are FakeScriptObject, we need to
942
+ # skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher
943
+ # because C++ dispatcher will check the schema and cannot recognize FakeScriptObject.
944
+ #
945
+ # Note:
946
+ # 1. We only register the torchbind op temporarily as effectful op because we only want
947
+ # the effect token functionalization logic to be applied during tracing. Otherwise, the behavior
948
+ # of the eagerly executing the op might change after tracing.
949
+ # 2. We don't want to register the op as effectful for all torchbind ops in ctor because this might
950
+ # cause unexpected behavior for some autograd.profiler ops e.g. profiler._record_function_exit._RecordFunction.
951
+ with self._register_as_effectful_op_temporarily():
952
+ return self._dispatch_in_python(args, kwargs, self._fallthrough_keys())
953
+ return self._op(*args, **kwargs)
954
+
955
+ def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
956
+ non_fallthrough_keys = torch._C._dispatch_keyset_full()
957
+ for key in fallthrough_keys:
958
+ non_fallthrough_keys = non_fallthrough_keys.remove(key)
959
+
960
+ dispatch_key_set = _compute_keyset(args, kwargs, non_fallthrough_keys)
961
+ dispatch_key = dispatch_key_set.highestPriorityTypeId()
962
+
963
+ handler = (
964
+ self._get_dispatch(dispatch_key)
965
+ if dispatch_key not in self._dispatch_cache
966
+ else self._dispatch_cache[dispatch_key]
967
+ )
968
+
969
+ if isinstance(handler, DispatchKey):
970
+ # fallthrough keys can be registered at runtime via torch.library.impl
971
+ # so need to add it to fallthrough_keys and re-dispatch.
972
+ if torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
973
+ self.name(), dispatch_key
974
+ ):
975
+ return self._dispatch_in_python(
976
+ args, kwargs, fallthrough_keys + [dispatch_key]
977
+ )
978
+
979
+ raise RuntimeError(
980
+ f"Torchbind op {self} received a FakeScriptObject input when dispatching {handler}."
981
+ f" but no python implementation is found."
982
+ f" Please file an issue on this when you encounter this error."
983
+ f" This error can happen when you export or compile the model."
984
+ f" It can still happpen even if a C++ implementation for {dispatch_key}. "
985
+ f" has been registered. That's because FakeScriptObject purely lives in python and cannot work "
986
+ f" with a C++ implementation."
987
+ )
988
+
989
+ assert isinstance(handler, Callable) # type: ignore[arg-type]
990
+ return handler(*args, **kwargs)
991
+
992
+
993
+ def _must_dispatch_in_python(args, kwargs):
994
+ return pytree.tree_any(
995
+ lambda obj: isinstance(
996
+ obj, torch._library.fake_class_registry.FakeScriptObject
997
+ ),
998
+ (args, kwargs),
999
+ )
1000
+
1001
+
1002
+ def _has_script_object_arg(schema: torch.FunctionSchema) -> bool:
1003
+ return any(isinstance(arg.type, torch.ClassType) for arg in schema.arguments)
1004
+
1005
+
1006
+ # OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
1007
+ # You can obtain an OpOverload object through attribute query.
1008
+ class OpOverloadPacket:
1009
+ def __init__(self, qualified_op_name, op_name, op, overload_names):
1010
+ # These attributes are accessible on the object through the properties
1011
+ # defined below but are immutable
1012
+ self._qualified_op_name = qualified_op_name
1013
+ self.__name__ = op_name
1014
+ self._op = op
1015
+ self._overload_names = overload_names
1016
+ self._dir = []
1017
+ self._has_torchbind_op_overload = any(
1018
+ _has_script_object_arg(schema) for schema in self._schemas.values()
1019
+ )
1020
+
1021
+ # it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
1022
+ def __deepcopy__(self, memo=None):
1023
+ return self
1024
+
1025
+ def __repr__(self):
1026
+ return "<OpOverloadPacket(op='{}.{}')>".format(
1027
+ *self._qualified_op_name.split("::")
1028
+ )
1029
+
1030
+ def __hash__(self):
1031
+ return hash(self._op)
1032
+
1033
+ def __str__(self):
1034
+ return "{}.{}".format(*self._qualified_op_name.split("::"))
1035
+
1036
+ @property
1037
+ def op(self):
1038
+ return self._op
1039
+
1040
+ @property
1041
+ def _schemas(self):
1042
+ return {
1043
+ overload_name: torch._C._get_schema(self._qualified_op_name, overload_name)
1044
+ for overload_name in self._overload_names
1045
+ }
1046
+
1047
+ def __getattr__(self, key):
1048
+ # It is not a valid op_name when __file__ is passed in
1049
+ if key == "__file__":
1050
+ return "torch.ops"
1051
+
1052
+ # ensure that query for dunder attributes that does not exist on
1053
+ # opoverloadpacket but instead exists on the self._op object does not unnecessarily call
1054
+ # `_get_operation_overload` (which is an expensive operation).
1055
+ # This is done to prevent any potential slowdown. This list can be extended
1056
+ # if there exists other attributes like `__name__` that only exist on self._op and not on the
1057
+ # opoverloadpacket.
1058
+ # This is ok since we are guaranteed that an overload name for an aten op can't start with '__'
1059
+ try:
1060
+ if key.startswith("__"):
1061
+ return getattr(self._op, key)
1062
+ except AttributeError:
1063
+ # for consistency because it seems weird to
1064
+ # throw an attribute error with a message containing
1065
+ # an object name different from the one the attribute
1066
+ # query was performed on.
1067
+ raise AttributeError(
1068
+ f"'{str(self)}' can't have an overload name beginning with '__' and the "
1069
+ f"underlying op {str(self._op)} has no attribute {key} either."
1070
+ ) from None
1071
+
1072
+ try:
1073
+ # This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
1074
+ use_key = "" if key == "default" else key
1075
+ # TODO: disallow access to overloads registered by JIT
1076
+ op_dk_tags = torch._C._get_operation_overload(
1077
+ self._qualified_op_name, use_key
1078
+ )
1079
+ if op_dk_tags is None:
1080
+ raise AttributeError(
1081
+ f"The underlying op of '{str(self)}' has no overload name '{key}'"
1082
+ )
1083
+
1084
+ op_, op_dk_, tags = op_dk_tags
1085
+ schema = torch._C._get_schema(self._qualified_op_name, use_key)
1086
+ overload = (
1087
+ OpOverload(self, op_, op_dk_, schema, tags)
1088
+ if not _has_script_object_arg(schema)
1089
+ else TorchBindOpOverload(self, op_, op_dk_, schema, tags)
1090
+ )
1091
+ # cache the overload object
1092
+ setattr(self, key, overload)
1093
+ self._dir.append(key)
1094
+ return overload
1095
+ except RuntimeError:
1096
+ raise AttributeError(
1097
+ f"The underlying op of '{str(self)}' has no overload name '{key}'"
1098
+ ) from None
1099
+
1100
+ def __iter__(self):
1101
+ return iter(self._dir)
1102
+
1103
+ # Use positional-only argument to avoid naming collision with aten ops arguments
1104
+ # that are named "self". This way, all the aten ops can be called by kwargs.
1105
+ def __call__(self, /, *args, **kwargs):
1106
+ # overloading __call__ to ensure torch.ops.foo.bar()
1107
+ # is still callable from JIT
1108
+ # We save the function ptr as the `op` attribute on
1109
+ # OpOverloadPacket to access it here.
1110
+
1111
+ # Directly calling OverloadPacket goes into C++, which will check
1112
+ # the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we
1113
+ # intercept it here and call TorchBindOpverload instead.
1114
+ if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
1115
+ return _call_overload_packet_from_python(self, args, kwargs)
1116
+ return self._op(*args, **(kwargs or {}))
1117
+
1118
+ # TODO: use this to make a __dir__
1119
+ def overloads(self):
1120
+ return [n if n else "default" for n in self._overload_names]
1121
+
1122
+
1123
+ # Note - this mirrors the logic of the cpp_function defined in jit/python/init.cpp
1124
+ # _jit_get_operations, which calls _get_operation_for_overload_or_packet.
1125
+ def _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs):
1126
+ # Re-use the torch function handling logic in cpp
1127
+ torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet(
1128
+ op, *args, **kwargs
1129
+ )
1130
+
1131
+ if torch_function_called:
1132
+ return ret
1133
+
1134
+ # The following mirrors getOpWithStack.
1135
+ # In cpp, we do a schema matching for the arguments, and call ToIValue to
1136
+ # to check whether the arguments are valid. But need to do similar things here
1137
+ # and check the schema whether the FakeScriptObject is the corresponding fake class
1138
+ # of the actual class used in schema.
1139
+ exceptions = {}
1140
+ found_op = None
1141
+ for overload_name in op.overloads():
1142
+ op_overload = getattr(op, overload_name)
1143
+ try:
1144
+ _ = torch._C._check_schema_allow_fake_script_object(
1145
+ op_overload._schema, *args, **kwargs
1146
+ )
1147
+ found_op = op_overload
1148
+ break
1149
+ except RuntimeError as e:
1150
+ exceptions[overload_name] = e
1151
+
1152
+ if found_op:
1153
+ return found_op(*args, **kwargs)
1154
+
1155
+ err_msg = (
1156
+ f"Fail to match any TorchBindOverload of {op} with following exceptions:\n"
1157
+ )
1158
+ for i, (key, msg) in enumerate(exceptions.items()):
1159
+ err_msg += f"Overload name {key}:\n {msg}\n"
1160
+ raise RuntimeError(err_msg)
1161
+
1162
+
1163
+ # Resolution of torch.fn is different from torch.ops.aten.fn
1164
+ # torch.fn uses the Python argparser, matches with the
1165
+ # appropriate schema, and calls into the unboxed version of the method
1166
+ # torch.ops.aten.fn resolution is done via the mechanism defined in JIT.
1167
+ # JIT creates a stack of all the overloads and then tries to match the
1168
+ # correct one at runtime and always calls into the boxed version of the method
1169
+ # Autograd codegen creates VariableType, TracerType,
1170
+ # inplace or view type and python bindings.
1171
+ # Aten codegen generates tensor methods for the tensor class.
1172
+
1173
+ # _OpNamespace is a subclass of ModuleType because the torch script
1174
+ # allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
1175
+ # to work from script, we need to ensure ops and foo are modules
1176
+
1177
+
1178
+ class _OpNamespace(types.ModuleType):
1179
+ """
1180
+ An op namespace to dynamically bind Operators into Python.
1181
+
1182
+ Say a user has created a custom Operator called "my_namespace::my_op". To
1183
+ call this op, the user will write torch.ops.my_namespace.my_op(...).
1184
+ At startup, this operation will not yet be bound into Python. Instead, the
1185
+ following sequence of magic tricks will occur:
1186
+ 1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method
1187
+ on the `torch.ops` object, which will create a new `_OpNamespace`
1188
+ object called `my_namespace` and set it as an attribute on the `ops`
1189
+ object.
1190
+ 2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on
1191
+ the `my_namespace` object, which will retrieve the operation via
1192
+ `torch.get_operation`, a function bound from C++, and then in a similar
1193
+ fashion bind this new object onto the `my_namespace` object.
1194
+ 3. `torch.ops.my_namespace.my_op(...)` then calls this new operation
1195
+ and subsequent accesses will incur no further lookup (the namespace and
1196
+ operation will already exist).
1197
+ """
1198
+
1199
+ def __init__(self, name):
1200
+ super().__init__("torch.ops." + name)
1201
+ self.name = name
1202
+ self._dir = []
1203
+
1204
+ def __iter__(self):
1205
+ return iter(self._dir)
1206
+
1207
+ def __getattr__(self, op_name):
1208
+ # It is not a valid op_name when __file__ is passed in
1209
+ if op_name == "__file__":
1210
+ return "torch.ops"
1211
+ elif op_name in ["__origin__", "__self__"]:
1212
+ raise AttributeError(
1213
+ f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'"
1214
+ )
1215
+
1216
+ # Get the op `my_namespace::my_op` if available. This will also check
1217
+ # for overloads and raise an exception if there are more than one.
1218
+ namespace_name = self.name
1219
+ qualified_op_name = f"{namespace_name}::{op_name}"
1220
+ module_name = self.__module__ + "." + namespace_name
1221
+
1222
+ try:
1223
+ op, overload_names = _get_packet(qualified_op_name, module_name)
1224
+ if op is None:
1225
+ raise AttributeError(
1226
+ f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
1227
+ )
1228
+ except RuntimeError as e:
1229
+ # Turn this into AttributeError so getattr(obj, key, default)
1230
+ # works (this is called by TorchScript with __origin__)
1231
+ raise AttributeError(
1232
+ f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
1233
+ ) from e
1234
+
1235
+ op.__module__ = module_name
1236
+ opoverloadpacket = OpOverloadPacket(
1237
+ qualified_op_name, op_name, op, overload_names
1238
+ )
1239
+ opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
1240
+ # cache the opoverloadpacket to ensure that each op corresponds to
1241
+ # a unique OpOverloadPacket object
1242
+ setattr(self, op_name, opoverloadpacket)
1243
+ self._dir.append(op_name)
1244
+ return opoverloadpacket
1245
+
1246
+
1247
+ def _get_packet(qualname, op_module):
1248
+ op, overload_names = torch._C._jit_get_operation(qualname)
1249
+ if op is not None:
1250
+ # let the script frontend know that op is identical to the builtin op
1251
+ # with qualified_op_name
1252
+ torch.jit._builtins._register_builtin(op, qualname)
1253
+ op.__module__ = op_module
1254
+ return op, overload_names
1255
+
1256
+
1257
+ def _refresh_packet(packet):
1258
+ op, overload_names = _get_packet(packet._qualified_op_name, packet._op.__module__)
1259
+ assert op is not None
1260
+ packet._op = op
1261
+ packet._overload_names = overload_names
1262
+
1263
+
1264
+ class _PyOpNamespace(_OpNamespace):
1265
+ def __init__(self, name, ops):
1266
+ super().__init__(name)
1267
+ self._ops = ops
1268
+
1269
+ def __getattr__(self, name):
1270
+ # Following _OpNamespace.__getattr__, we cache the op on the _PyOpNamespace object.
1271
+ op = self._ops.get(name, None)
1272
+ if op is None:
1273
+ raise AttributeError(
1274
+ f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'"
1275
+ )
1276
+ setattr(self, name, op)
1277
+ return op
1278
+
1279
+
1280
+ class _Ops(types.ModuleType):
1281
+ __file__ = "_ops.py"
1282
+
1283
+ def __init__(self):
1284
+ super().__init__("torch.ops")
1285
+ self.loaded_libraries = set()
1286
+ self._higher_order_op_namespace = _PyOpNamespace(
1287
+ "torch.ops.higher_order", _higher_order_ops
1288
+ )
1289
+ self._dir = []
1290
+
1291
+ def __getattr__(self, name):
1292
+ # Check if the name is a HigherOrderOperator
1293
+ if name == "higher_order":
1294
+ return self._higher_order_op_namespace
1295
+
1296
+ # Here we are creating `torch.ops.my_namespace`
1297
+ namespace = _OpNamespace(name)
1298
+ setattr(self, name, namespace)
1299
+ self._dir.append(name)
1300
+ return namespace
1301
+
1302
+ def __iter__(self):
1303
+ return iter(self._dir)
1304
+
1305
+ def import_module(self, module):
1306
+ """
1307
+ Imports a Python module that has torch.library registrations.
1308
+
1309
+ Generally, to extend PyTorch with custom operators, a user will
1310
+ create a Python module whose import triggers registration of
1311
+ the custom operators via a torch.ops.load_library call or a call
1312
+ to one or more torch.library.* APIs.
1313
+
1314
+ It is unexpected for Python modules to have side effects, so some
1315
+ linters and formatters will complain. Use this API to import Python
1316
+ modules that contain these torch.library side effects.
1317
+
1318
+ Args:
1319
+ module (str): The name of the Python module to import
1320
+
1321
+ """
1322
+ importlib.import_module(module)
1323
+
1324
+ def load_library(self, path):
1325
+ """
1326
+ Loads a shared library from the given path into the current process.
1327
+
1328
+ The library being loaded may run global initialization code to register
1329
+ custom operators with the PyTorch JIT runtime. This allows dynamically
1330
+ loading custom operators. For this, you should compile your operator
1331
+ and the static registration code into a shared library object, and then
1332
+ call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
1333
+ shared object.
1334
+
1335
+ After the library is loaded, it is added to the
1336
+ ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
1337
+ for the paths of all libraries loaded using this function.
1338
+
1339
+ Args:
1340
+ path (str): A path to a shared library to load.
1341
+ """
1342
+ if torch._running_with_deploy():
1343
+ return
1344
+
1345
+ path = _utils_internal.resolve_library_path(path)
1346
+ with dl_open_guard():
1347
+ # Import the shared library into the process, thus running its
1348
+ # static (global) initialization code in order to register custom
1349
+ # operators with the JIT.
1350
+ ctypes.CDLL(path)
1351
+ self.loaded_libraries.add(path)
1352
+
1353
+
1354
+ # The ops "namespace"
1355
+ ops: _Ops = _Ops()
.venv/lib/python3.11/site-packages/torch/_python_dispatcher.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import re
3
+
4
+ import torch._C as C
5
+
6
+
7
+ """
8
+ PythonDispatcher class is a thin python-binding to C++ dispatcher and it
9
+ is designed to show how dispatcher precompute works. In particular,
10
+ it shows for a certain op `foo`, what the computed dispatch table looks
11
+ like after user register their kernels to certains dispatch keys.
12
+
13
+ In the real C++ dispatcher we support many dispatch keys for different
14
+ functionalities. For simplicity PythonDispatcher only supports dispatch
15
+ keys for a single example of each use case. These use cases are listed below:
16
+
17
+ - CPU/AutogradCPU: represents in-tree backends which we usually have dedicated inference &
18
+ autograd kernel in pytorch core library.
19
+ E.g. CPU, CUDA
20
+ - FPGA/AutogradOther: represents in-tree backends which we usually have backend specific
21
+ inference kernels, but they share the same autograd kernel specified in AutogradOther.
22
+ E.g. FPGA, SparseCsrCPU
23
+ - XLA/AutogradXLA: represents out-of-tree backends which we don't have either inference or autograd
24
+ kernel defined in pytorch core library. Backend owner is responsible for registering both
25
+ inference & autograd kernels in their extensions(e.g. torch-xla) for the operators they support.
26
+ E.g. XLA, XPU, MPS
27
+ - CompositeExplicitAutograd: alias key mapped to inference kernels of all backends like CPU, CUDA, XLA etc.
28
+ Kernels registered to this key MUST work for inference for all backends.
29
+ - Autograd: alias key mapped to autograd of all backends like AutogradCPU, AutogradXLA, AutogradOther.
30
+ Kernels registered to this key MUST work for autograd for all backends.
31
+ - CompositeImplicitAutograd: alias key CompositeImplicitAutograd = CompositeExplicitAutograd + Autograd
32
+ Kernels registered to this key MUST work for both inference + autograd for all backends.
33
+
34
+ Note we only allow registrations to alias keys inside pytorch core library. E.g
35
+ you shouldn't register a CompositeImplicitAutograd or CompositeExplicitAutograd
36
+ kernel from torch-xla extension, instead you should upstream the kernel into
37
+ pytorch/pytorch repo so that it's available for all backends and continuously
38
+ tested even without the extension.
39
+
40
+ Usage:
41
+ dispatcher = PythonDispatcher()
42
+ dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"])
43
+ print(dispatcher.dispatchTable()) # This tells you exactly which kernel is used for certain backend.
44
+ # For more debugging information
45
+ # print(dispatcher.keys())
46
+ # print(dispatcher.registrations())
47
+ # print(dispatcher.rawRegistrations())
48
+ # print(dispatcher.rawDispatchTable())
49
+ PythonDispatcher calls C++ dispatcher under the hood for to precompute dispatch table.
50
+ This file only provides the simplified API for developers, relevant test code is located in
51
+ test/test_dispatch.py
52
+ """
53
+
54
+
55
+ class PythonDispatcher:
56
+ namespace = "__test__"
57
+ name = "foo"
58
+ # fmt: off
59
+ runtime_keys = [
60
+ "CPU", "AutogradCPU",
61
+ "FPGA", "AutogradOther",
62
+ "XLA", "AutogradXLA",
63
+ "Lazy", "AutogradLazy",
64
+ ]
65
+ # fmt: on
66
+ alias_keys = [
67
+ "CompositeExplicitAutograd",
68
+ "Autograd",
69
+ "CompositeImplicitAutograd",
70
+ ]
71
+ supported_keys = runtime_keys + alias_keys
72
+
73
+ def __init__(self) -> None:
74
+ C._dispatch_check_invariants(self.name) # type: ignore[attr-defined]
75
+ self.ref = C._dispatch_library("FRAGMENT", self.namespace, "")
76
+ self.ref.def_("foo(Tensor x) -> Tensor")
77
+
78
+ """
79
+ Returns a list of dispatch keys supported by PythonDispatcher.
80
+ You can register kernels to these keys.
81
+ """
82
+
83
+ def keys(self):
84
+ return self.supported_keys
85
+
86
+ """
87
+ Register kernels to the target dispatchKeys.
88
+ dispatchKeys(list[str]): a list of dispatch keys that you want to register
89
+ your own kernel. Note that you don't need to write the kernel yourself in
90
+ this PythonDispatcher.E.g. for CPU key, a kernel(e.g fn_CPU for CPU) is
91
+ automatically generated and registered.
92
+ """
93
+
94
+ def register(self, dispatchKeys):
95
+ # Overriden is not supported and triggers a warning in C++ dispatcher.
96
+ if len(set(dispatchKeys)) != len(dispatchKeys):
97
+ raise RuntimeError(
98
+ f"Overriden is not allowed but found duplicates in {dispatchKeys}."
99
+ )
100
+ # We currently forbid this in codegen instead of C++ dispatcher.
101
+ if (
102
+ "CompositeImplicitAutograd" in dispatchKeys
103
+ and "CompositeExplicitAutograd" in dispatchKeys
104
+ ):
105
+ raise RuntimeError(
106
+ "Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed."
107
+ )
108
+ for key in dispatchKeys:
109
+ if key not in self.supported_keys:
110
+ raise RuntimeError(
111
+ f"{key} is not supported, please select a dispatch key in {self.supported_keys}."
112
+ )
113
+ self.ref.impl_t_t("foo", dispatch=key, debug="fn_" + key)
114
+
115
+ """
116
+ Helper function to format (key, kernel).
117
+ """
118
+
119
+ def _format_line(self, key, kernel):
120
+ return f"{key:<15} {kernel}\n"
121
+
122
+ """
123
+ Helper function to print a table header.
124
+ """
125
+
126
+ def _format_header(self, header):
127
+ s = f"""
128
+ {header}
129
+ """
130
+ s += self._format_line("key", "kernel")
131
+ s += "---------------------------\n"
132
+ return s
133
+
134
+ """
135
+ Returns raw output of all registration info for debugging only.
136
+ Use registrations() for a simplified version.
137
+ """
138
+
139
+ def rawRegistrations(self):
140
+ return C._dispatch_dump(f"{self.namespace}::{self.name}") # type: ignore[attr-defined]
141
+
142
+ """
143
+ Returns raw output of computed dispatch table for debugging only.
144
+ Use dispatchTable() for a simplified version.
145
+ """
146
+
147
+ def rawDispatchTable(self):
148
+ return C._dispatch_dump_table(f"{self.namespace}::{self.name}") # type: ignore[attr-defined]
149
+
150
+ """
151
+ Returns a table(str) including all the registrations from users.
152
+ Note this includes registrations to both runtime keys and alias keys.
153
+ """
154
+
155
+ def registrations(self):
156
+ output = self._format_header("Registered Kernels")
157
+ state = self.rawRegistrations()
158
+ state_entries = state.split("\n")
159
+ for line in state_entries:
160
+ first = line.split(":")[0]
161
+ if any(first.startswith(k) for k in self.supported_keys):
162
+ kernel = line.split("::")[0].split(" ")[1]
163
+ output += self._format_line(first, kernel)
164
+ return output
165
+
166
+ """
167
+ Returns the computed dispatch table(str). Note this only include
168
+ runtime keys, registrations to alias keys have been decoded to their
169
+ mapped runtime keys.
170
+ """
171
+
172
+ def dispatchTable(self):
173
+ output = self._format_header("Computed Dispatch Table")
174
+ table = self.rawDispatchTable()
175
+ table_entries = table.split("\n")
176
+ regex = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)")
177
+ for line in table_entries:
178
+ k = line.split(":")[0]
179
+ if k in self.runtime_keys:
180
+ entry = regex.sub("[", line)
181
+ output += self._format_line(k, entry.split(": ")[1])
182
+ return output