Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/apiclient/__init__.py +27 -0
- .venv/lib/python3.11/site-packages/apiclient/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/certifi/__init__.py +4 -0
- .venv/lib/python3.11/site-packages/certifi/__main__.py +12 -0
- .venv/lib/python3.11/site-packages/certifi/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/certifi/__pycache__/__main__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/certifi/__pycache__/core.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/certifi/cacert.pem +0 -0
- .venv/lib/python3.11/site-packages/certifi/core.py +114 -0
- .venv/lib/python3.11/site-packages/certifi/py.typed +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_auth.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_cache_assets.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_cache_manager.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_datetime.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_deprecation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_experimental.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/endpoint_helpers.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/insecure_hashlib.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/logging.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/tqdm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cusolver/lib/libcusolverMg.so.11 +3 -0
- .venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/INSTALLER +1 -0
- .venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/LICENSE +201 -0
- .venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/METADATA +51 -0
- .venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/NOTICE +5 -0
- .venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/RECORD +58 -0
- .venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/WHEEL +5 -0
- .venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/top_level.txt +1 -0
- .venv/lib/python3.11/site-packages/torch/_C.cpython-311-x86_64-linux-gnu.so +0 -0
- .venv/lib/python3.11/site-packages/torch/_VF.py +31 -0
- .venv/lib/python3.11/site-packages/torch/_VF.pyi +0 -0
- .venv/lib/python3.11/site-packages/torch/__config__.py +23 -0
- .venv/lib/python3.11/site-packages/torch/__future__.py +75 -0
- .venv/lib/python3.11/site-packages/torch/__init__.py +2665 -0
- .venv/lib/python3.11/site-packages/torch/_appdirs.py +667 -0
- .venv/lib/python3.11/site-packages/torch/_classes.py +56 -0
- .venv/lib/python3.11/site-packages/torch/_compile.py +38 -0
- .venv/lib/python3.11/site-packages/torch/_custom_ops.py +324 -0
- .venv/lib/python3.11/site-packages/torch/_deploy.py +104 -0
- .venv/lib/python3.11/site-packages/torch/_guards.py +925 -0
- .venv/lib/python3.11/site-packages/torch/_jit_internal.py +1547 -0
- .venv/lib/python3.11/site-packages/torch/_linalg_utils.py +150 -0
- .venv/lib/python3.11/site-packages/torch/_lobpcg.py +1157 -0
- .venv/lib/python3.11/site-packages/torch/_lowrank.py +294 -0
- .venv/lib/python3.11/site-packages/torch/_meta_registrations.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_namedtensor_internals.py +159 -0
- .venv/lib/python3.11/site-packages/torch/_ops.py +1355 -0
- .venv/lib/python3.11/site-packages/torch/_python_dispatcher.py +182 -0
.gitattributes
CHANGED
|
@@ -413,3 +413,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
|
|
| 413 |
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc-builtins.so.12.4 filter=lfs diff=lfs merge=lfs -text
|
| 414 |
.venv/lib/python3.11/site-packages/cv2/cv2.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 415 |
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_engines_runtime_compiled.so.9 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 413 |
.venv/lib/python3.11/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc-builtins.so.12.4 filter=lfs diff=lfs merge=lfs -text
|
| 414 |
.venv/lib/python3.11/site-packages/cv2/cv2.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 415 |
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_engines_runtime_compiled.so.9 filter=lfs diff=lfs merge=lfs -text
|
| 416 |
+
.venv/lib/python3.11/site-packages/nvidia/cusolver/lib/libcusolverMg.so.11 filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/apiclient/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Retain apiclient as an alias for googleapiclient."""
|
| 2 |
+
|
| 3 |
+
from googleapiclient import channel, discovery, errors, http, mimeparse, model
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from googleapiclient import sample_tools
|
| 7 |
+
except ImportError:
|
| 8 |
+
# Silently ignore, because the vast majority of consumers won't use it and
|
| 9 |
+
# it has deep dependence on oauth2client, an optional dependency.
|
| 10 |
+
sample_tools = None
|
| 11 |
+
from googleapiclient import schema
|
| 12 |
+
|
| 13 |
+
_SUBMODULES = {
|
| 14 |
+
"channel": channel,
|
| 15 |
+
"discovery": discovery,
|
| 16 |
+
"errors": errors,
|
| 17 |
+
"http": http,
|
| 18 |
+
"mimeparse": mimeparse,
|
| 19 |
+
"model": model,
|
| 20 |
+
"sample_tools": sample_tools,
|
| 21 |
+
"schema": schema,
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
for module_name, module in _SUBMODULES.items():
|
| 27 |
+
sys.modules["apiclient.%s" % module_name] = module
|
.venv/lib/python3.11/site-packages/apiclient/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (954 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/certifi/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .core import contents, where
|
| 2 |
+
|
| 3 |
+
__all__ = ["contents", "where"]
|
| 4 |
+
__version__ = "2024.12.14"
|
.venv/lib/python3.11/site-packages/certifi/__main__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
from certifi import contents, where
|
| 4 |
+
|
| 5 |
+
parser = argparse.ArgumentParser()
|
| 6 |
+
parser.add_argument("-c", "--contents", action="store_true")
|
| 7 |
+
args = parser.parse_args()
|
| 8 |
+
|
| 9 |
+
if args.contents:
|
| 10 |
+
print(contents())
|
| 11 |
+
else:
|
| 12 |
+
print(where())
|
.venv/lib/python3.11/site-packages/certifi/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (322 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/certifi/__pycache__/__main__.cpython-311.pyc
ADDED
|
Binary file (711 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/certifi/__pycache__/core.cpython-311.pyc
ADDED
|
Binary file (3.75 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/certifi/cacert.pem
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/certifi/core.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
certifi.py
|
| 3 |
+
~~~~~~~~~~
|
| 4 |
+
|
| 5 |
+
This module returns the installation location of cacert.pem or its contents.
|
| 6 |
+
"""
|
| 7 |
+
import sys
|
| 8 |
+
import atexit
|
| 9 |
+
|
| 10 |
+
def exit_cacert_ctx() -> None:
|
| 11 |
+
_CACERT_CTX.__exit__(None, None, None) # type: ignore[union-attr]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
if sys.version_info >= (3, 11):
|
| 15 |
+
|
| 16 |
+
from importlib.resources import as_file, files
|
| 17 |
+
|
| 18 |
+
_CACERT_CTX = None
|
| 19 |
+
_CACERT_PATH = None
|
| 20 |
+
|
| 21 |
+
def where() -> str:
|
| 22 |
+
# This is slightly terrible, but we want to delay extracting the file
|
| 23 |
+
# in cases where we're inside of a zipimport situation until someone
|
| 24 |
+
# actually calls where(), but we don't want to re-extract the file
|
| 25 |
+
# on every call of where(), so we'll do it once then store it in a
|
| 26 |
+
# global variable.
|
| 27 |
+
global _CACERT_CTX
|
| 28 |
+
global _CACERT_PATH
|
| 29 |
+
if _CACERT_PATH is None:
|
| 30 |
+
# This is slightly janky, the importlib.resources API wants you to
|
| 31 |
+
# manage the cleanup of this file, so it doesn't actually return a
|
| 32 |
+
# path, it returns a context manager that will give you the path
|
| 33 |
+
# when you enter it and will do any cleanup when you leave it. In
|
| 34 |
+
# the common case of not needing a temporary file, it will just
|
| 35 |
+
# return the file system location and the __exit__() is a no-op.
|
| 36 |
+
#
|
| 37 |
+
# We also have to hold onto the actual context manager, because
|
| 38 |
+
# it will do the cleanup whenever it gets garbage collected, so
|
| 39 |
+
# we will also store that at the global level as well.
|
| 40 |
+
_CACERT_CTX = as_file(files("certifi").joinpath("cacert.pem"))
|
| 41 |
+
_CACERT_PATH = str(_CACERT_CTX.__enter__())
|
| 42 |
+
atexit.register(exit_cacert_ctx)
|
| 43 |
+
|
| 44 |
+
return _CACERT_PATH
|
| 45 |
+
|
| 46 |
+
def contents() -> str:
|
| 47 |
+
return files("certifi").joinpath("cacert.pem").read_text(encoding="ascii")
|
| 48 |
+
|
| 49 |
+
elif sys.version_info >= (3, 7):
|
| 50 |
+
|
| 51 |
+
from importlib.resources import path as get_path, read_text
|
| 52 |
+
|
| 53 |
+
_CACERT_CTX = None
|
| 54 |
+
_CACERT_PATH = None
|
| 55 |
+
|
| 56 |
+
def where() -> str:
|
| 57 |
+
# This is slightly terrible, but we want to delay extracting the
|
| 58 |
+
# file in cases where we're inside of a zipimport situation until
|
| 59 |
+
# someone actually calls where(), but we don't want to re-extract
|
| 60 |
+
# the file on every call of where(), so we'll do it once then store
|
| 61 |
+
# it in a global variable.
|
| 62 |
+
global _CACERT_CTX
|
| 63 |
+
global _CACERT_PATH
|
| 64 |
+
if _CACERT_PATH is None:
|
| 65 |
+
# This is slightly janky, the importlib.resources API wants you
|
| 66 |
+
# to manage the cleanup of this file, so it doesn't actually
|
| 67 |
+
# return a path, it returns a context manager that will give
|
| 68 |
+
# you the path when you enter it and will do any cleanup when
|
| 69 |
+
# you leave it. In the common case of not needing a temporary
|
| 70 |
+
# file, it will just return the file system location and the
|
| 71 |
+
# __exit__() is a no-op.
|
| 72 |
+
#
|
| 73 |
+
# We also have to hold onto the actual context manager, because
|
| 74 |
+
# it will do the cleanup whenever it gets garbage collected, so
|
| 75 |
+
# we will also store that at the global level as well.
|
| 76 |
+
_CACERT_CTX = get_path("certifi", "cacert.pem")
|
| 77 |
+
_CACERT_PATH = str(_CACERT_CTX.__enter__())
|
| 78 |
+
atexit.register(exit_cacert_ctx)
|
| 79 |
+
|
| 80 |
+
return _CACERT_PATH
|
| 81 |
+
|
| 82 |
+
def contents() -> str:
|
| 83 |
+
return read_text("certifi", "cacert.pem", encoding="ascii")
|
| 84 |
+
|
| 85 |
+
else:
|
| 86 |
+
import os
|
| 87 |
+
import types
|
| 88 |
+
from typing import Union
|
| 89 |
+
|
| 90 |
+
Package = Union[types.ModuleType, str]
|
| 91 |
+
Resource = Union[str, "os.PathLike"]
|
| 92 |
+
|
| 93 |
+
# This fallback will work for Python versions prior to 3.7 that lack the
|
| 94 |
+
# importlib.resources module but relies on the existing `where` function
|
| 95 |
+
# so won't address issues with environments like PyOxidizer that don't set
|
| 96 |
+
# __file__ on modules.
|
| 97 |
+
def read_text(
|
| 98 |
+
package: Package,
|
| 99 |
+
resource: Resource,
|
| 100 |
+
encoding: str = 'utf-8',
|
| 101 |
+
errors: str = 'strict'
|
| 102 |
+
) -> str:
|
| 103 |
+
with open(where(), encoding=encoding) as data:
|
| 104 |
+
return data.read()
|
| 105 |
+
|
| 106 |
+
# If we don't have importlib.resources, then we will just do the old logic
|
| 107 |
+
# of assuming we're on the filesystem and munge the path directly.
|
| 108 |
+
def where() -> str:
|
| 109 |
+
f = os.path.dirname(__file__)
|
| 110 |
+
|
| 111 |
+
return os.path.join(f, "cacert.pem")
|
| 112 |
+
|
| 113 |
+
def contents() -> str:
|
| 114 |
+
return read_text("certifi", "cacert.pem", encoding="ascii")
|
.venv/lib/python3.11/site-packages/certifi/py.typed
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (4.86 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_auth.cpython-311.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_cache_assets.cpython-311.pyc
ADDED
|
Binary file (5.77 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_cache_manager.cpython-311.pyc
ADDED
|
Binary file (40.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_datetime.cpython-311.pyc
ADDED
|
Binary file (2.35 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_deprecation.cpython-311.pyc
ADDED
|
Binary file (7.45 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/_experimental.cpython-311.pyc
ADDED
|
Binary file (2.41 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/endpoint_helpers.cpython-311.pyc
ADDED
|
Binary file (2.35 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/insecure_hashlib.cpython-311.pyc
ADDED
|
Binary file (623 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/logging.cpython-311.pyc
ADDED
|
Binary file (6.53 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/huggingface_hub/utils/__pycache__/tqdm.cpython-311.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/nvidia/cusolver/lib/libcusolverMg.so.11
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:47662749a295f771b92abe8d99dcd5f151953d56069a19f43977b97868ec21eb
|
| 3 |
+
size 82303400
|
.venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/INSTALLER
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
pip
|
.venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
.venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/METADATA
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.1
|
| 2 |
+
Name: prometheus_client
|
| 3 |
+
Version: 0.21.1
|
| 4 |
+
Summary: Python client for the Prometheus monitoring system.
|
| 5 |
+
Home-page: https://github.com/prometheus/client_python
|
| 6 |
+
Author: Brian Brazil
|
| 7 |
+
Author-email: brian.brazil@robustperception.io
|
| 8 |
+
License: Apache Software License 2.0
|
| 9 |
+
Keywords: prometheus monitoring instrumentation client
|
| 10 |
+
Classifier: Development Status :: 4 - Beta
|
| 11 |
+
Classifier: Intended Audience :: Developers
|
| 12 |
+
Classifier: Intended Audience :: Information Technology
|
| 13 |
+
Classifier: Intended Audience :: System Administrators
|
| 14 |
+
Classifier: Programming Language :: Python
|
| 15 |
+
Classifier: Programming Language :: Python :: 3
|
| 16 |
+
Classifier: Programming Language :: Python :: 3.8
|
| 17 |
+
Classifier: Programming Language :: Python :: 3.9
|
| 18 |
+
Classifier: Programming Language :: Python :: 3.10
|
| 19 |
+
Classifier: Programming Language :: Python :: 3.11
|
| 20 |
+
Classifier: Programming Language :: Python :: 3.12
|
| 21 |
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
| 22 |
+
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
| 23 |
+
Classifier: Topic :: System :: Monitoring
|
| 24 |
+
Classifier: License :: OSI Approved :: Apache Software License
|
| 25 |
+
Requires-Python: >=3.8
|
| 26 |
+
Description-Content-Type: text/markdown
|
| 27 |
+
License-File: LICENSE
|
| 28 |
+
License-File: NOTICE
|
| 29 |
+
Provides-Extra: twisted
|
| 30 |
+
Requires-Dist: twisted; extra == "twisted"
|
| 31 |
+
|
| 32 |
+
# Prometheus Python Client
|
| 33 |
+
|
| 34 |
+
The official Python client for [Prometheus](https://prometheus.io).
|
| 35 |
+
|
| 36 |
+
## Installation
|
| 37 |
+
|
| 38 |
+
```
|
| 39 |
+
pip install prometheus-client
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
This package can be found on [PyPI](https://pypi.python.org/pypi/prometheus_client).
|
| 43 |
+
|
| 44 |
+
## Documentation
|
| 45 |
+
|
| 46 |
+
Documentation is available on https://prometheus.github.io/client_python
|
| 47 |
+
|
| 48 |
+
## Links
|
| 49 |
+
|
| 50 |
+
* [Releases](https://github.com/prometheus/client_python/releases): The releases page shows the history of the project and acts as a changelog.
|
| 51 |
+
* [PyPI](https://pypi.python.org/pypi/prometheus_client)
|
.venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/NOTICE
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Prometheus instrumentation library for Python applications
|
| 2 |
+
Copyright 2015 The Prometheus Authors
|
| 3 |
+
|
| 4 |
+
This product bundles decorator 4.0.10 which is available under a "2-clause BSD"
|
| 5 |
+
license. For details, see prometheus_client/decorator.py.
|
.venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/RECORD
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
prometheus_client-0.21.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
| 2 |
+
prometheus_client-0.21.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
| 3 |
+
prometheus_client-0.21.1.dist-info/METADATA,sha256=r74KhsmW6__tSpz4xH6BX7qsbJfFWYfj24x1elsVtr8,1842
|
| 4 |
+
prometheus_client-0.21.1.dist-info/NOTICE,sha256=TvoYdK6qYPNl9Xl-YX8f-TPhXlCOr3UemEjtRBPXp64,236
|
| 5 |
+
prometheus_client-0.21.1.dist-info/RECORD,,
|
| 6 |
+
prometheus_client-0.21.1.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
| 7 |
+
prometheus_client-0.21.1.dist-info/top_level.txt,sha256=AxLEvHEMhTW-Kvb9Ly1DPI3aapigQ2aeg8TXMt9WMRo,18
|
| 8 |
+
prometheus_client/__init__.py,sha256=D-ptlQkWPXqZIJPi5TR0QNMdWr_Ejv-gMq6WAFik_9o,1815
|
| 9 |
+
prometheus_client/__pycache__/__init__.cpython-311.pyc,,
|
| 10 |
+
prometheus_client/__pycache__/asgi.cpython-311.pyc,,
|
| 11 |
+
prometheus_client/__pycache__/context_managers.cpython-311.pyc,,
|
| 12 |
+
prometheus_client/__pycache__/core.cpython-311.pyc,,
|
| 13 |
+
prometheus_client/__pycache__/decorator.cpython-311.pyc,,
|
| 14 |
+
prometheus_client/__pycache__/exposition.cpython-311.pyc,,
|
| 15 |
+
prometheus_client/__pycache__/gc_collector.cpython-311.pyc,,
|
| 16 |
+
prometheus_client/__pycache__/metrics.cpython-311.pyc,,
|
| 17 |
+
prometheus_client/__pycache__/metrics_core.cpython-311.pyc,,
|
| 18 |
+
prometheus_client/__pycache__/mmap_dict.cpython-311.pyc,,
|
| 19 |
+
prometheus_client/__pycache__/multiprocess.cpython-311.pyc,,
|
| 20 |
+
prometheus_client/__pycache__/parser.cpython-311.pyc,,
|
| 21 |
+
prometheus_client/__pycache__/platform_collector.cpython-311.pyc,,
|
| 22 |
+
prometheus_client/__pycache__/process_collector.cpython-311.pyc,,
|
| 23 |
+
prometheus_client/__pycache__/registry.cpython-311.pyc,,
|
| 24 |
+
prometheus_client/__pycache__/samples.cpython-311.pyc,,
|
| 25 |
+
prometheus_client/__pycache__/utils.cpython-311.pyc,,
|
| 26 |
+
prometheus_client/__pycache__/values.cpython-311.pyc,,
|
| 27 |
+
prometheus_client/asgi.py,sha256=ivn-eV7ZU0BEa4E9oWBFbBRUklHPw9f5lcdGsyFuCLo,1606
|
| 28 |
+
prometheus_client/bridge/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
| 29 |
+
prometheus_client/bridge/__pycache__/__init__.cpython-311.pyc,,
|
| 30 |
+
prometheus_client/bridge/__pycache__/graphite.cpython-311.pyc,,
|
| 31 |
+
prometheus_client/bridge/graphite.py,sha256=m5-7IyVyGL8C6S9yLxeupS1pfj8KFNPNlazddamQT8s,2897
|
| 32 |
+
prometheus_client/context_managers.py,sha256=E7uksn4D7yBoZWDgjI1VRpR3l2tKivs9DHZ5UAcmPwE,2343
|
| 33 |
+
prometheus_client/core.py,sha256=yyVvSxa8WQnBvAr4JhO3HqdTqClwhbzmVGvwRvWQMIo,860
|
| 34 |
+
prometheus_client/decorator.py,sha256=7MdUokWmzQ17foet2R5QcMubdZ1WDPGYo0_HqLxAw2k,15802
|
| 35 |
+
prometheus_client/exposition.py,sha256=nmushN6NIGo-nOBeaCXfg5bCeyvesVM_DXUWmRjFwr4,26176
|
| 36 |
+
prometheus_client/gc_collector.py,sha256=tBhXXktF9g9h7gvO-DmI2gxPol2_gXI1M6e9ZMazNfY,1514
|
| 37 |
+
prometheus_client/metrics.py,sha256=ypy4Vv0duzCgo4ZXHBNK45uU9hbe7iK-Fohv7EJ_I5A,28109
|
| 38 |
+
prometheus_client/metrics_core.py,sha256=Yz-yqS3pxNdpIRMShQv_IHaKlVS_Q53TaYcP9U8LDlE,15548
|
| 39 |
+
prometheus_client/mmap_dict.py,sha256=-t49kywZHFHk2D9IWtunqKFtr5eEgiN-RjFWg16JE-Q,5393
|
| 40 |
+
prometheus_client/multiprocess.py,sha256=VIvAR0vmjL0lknnTijKt9HS1DNz9rZrS09HqIIcaZLs,7539
|
| 41 |
+
prometheus_client/openmetrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
| 42 |
+
prometheus_client/openmetrics/__pycache__/__init__.cpython-311.pyc,,
|
| 43 |
+
prometheus_client/openmetrics/__pycache__/exposition.cpython-311.pyc,,
|
| 44 |
+
prometheus_client/openmetrics/__pycache__/parser.cpython-311.pyc,,
|
| 45 |
+
prometheus_client/openmetrics/exposition.py,sha256=Ef3GeveuojMzOrl-T7cG6Ml2TRN1xIYjpe_puReFrlo,2993
|
| 46 |
+
prometheus_client/openmetrics/parser.py,sha256=c6vQccyW93MXzc22QGdceETg0m_KMeMyEbKrfObG0R8,22125
|
| 47 |
+
prometheus_client/parser.py,sha256=zuVhB8clFPvQ9wOEj1XikN7NoJe8J3pZcQkNgEUkuXg,7434
|
| 48 |
+
prometheus_client/platform_collector.py,sha256=t_GD2oCLN3Pql4TltbNqTap8a4HOtbvBm0OU5_gPn38,1879
|
| 49 |
+
prometheus_client/process_collector.py,sha256=B8y36L1iq0c3KFlvdNj1F5JEQLTec116h6y3m9Jhk90,3864
|
| 50 |
+
prometheus_client/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
| 51 |
+
prometheus_client/registry.py,sha256=3R-yxiPitVs36cnIRnotqSJmOPwAQsLz-tl6kw3rcd4,6196
|
| 52 |
+
prometheus_client/samples.py,sha256=smIiOIsAwGXHgM_7xg9Zo5yTEM2gavYvVtgGTjdWMcA,1687
|
| 53 |
+
prometheus_client/twisted/__init__.py,sha256=0RxJjYSOC5p6o2cu6JbfUzc8ReHYQGNv9pKP-U4u7OE,72
|
| 54 |
+
prometheus_client/twisted/__pycache__/__init__.cpython-311.pyc,,
|
| 55 |
+
prometheus_client/twisted/__pycache__/_exposition.cpython-311.pyc,,
|
| 56 |
+
prometheus_client/twisted/_exposition.py,sha256=2TL2BH5sW0i6H7dHkot9aBH9Ld-I60ax55DuaIWnElo,250
|
| 57 |
+
prometheus_client/utils.py,sha256=zKJZaW_hyZgQSmkaD-rgT5l-YsT3--le0BRQ7v_x8eE,594
|
| 58 |
+
prometheus_client/values.py,sha256=hzThQQd0x4mIPR3ddezQpjUoDVdSBnwem4Z48woxpa8,5002
|
.venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/WHEEL
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Wheel-Version: 1.0
|
| 2 |
+
Generator: setuptools (75.6.0)
|
| 3 |
+
Root-Is-Purelib: true
|
| 4 |
+
Tag: py3-none-any
|
| 5 |
+
|
.venv/lib/python3.11/site-packages/prometheus_client-0.21.1.dist-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
prometheus_client
|
.venv/lib/python3.11/site-packages/torch/_C.cpython-311-x86_64-linux-gnu.so
ADDED
|
Binary file (37.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_VF.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This makes the functions in torch._C._VariableFunctions available as
|
| 3 |
+
torch._VF.<funcname>
|
| 4 |
+
without mypy being able to find them.
|
| 5 |
+
|
| 6 |
+
A subset of those functions are mapped to ATen functions in
|
| 7 |
+
torch/jit/_builtins.py
|
| 8 |
+
|
| 9 |
+
See https://github.com/pytorch/pytorch/issues/21478 for the reason for
|
| 10 |
+
introducing torch._VF
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import sys
|
| 15 |
+
import types
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class VFModule(types.ModuleType):
|
| 21 |
+
vf: types.ModuleType
|
| 22 |
+
|
| 23 |
+
def __init__(self, name: str):
|
| 24 |
+
super().__init__(name)
|
| 25 |
+
self.vf = torch._C._VariableFunctions
|
| 26 |
+
|
| 27 |
+
def __getattr__(self, name: str) -> object:
|
| 28 |
+
return getattr(self.vf, name)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
sys.modules[__name__] = VFModule(__name__)
|
.venv/lib/python3.11/site-packages/torch/_VF.pyi
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/torch/__config__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def show():
|
| 6 |
+
"""
|
| 7 |
+
Return a human-readable string with descriptions of the
|
| 8 |
+
configuration of PyTorch.
|
| 9 |
+
"""
|
| 10 |
+
return torch._C._show_config()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# TODO: In principle, we could provide more structured version/config
|
| 14 |
+
# information here. For now only CXX_FLAGS is exposed, as Timer
|
| 15 |
+
# uses them.
|
| 16 |
+
def _cxx_flags():
|
| 17 |
+
"""Returns the CXX_FLAGS used when building PyTorch."""
|
| 18 |
+
return torch._C._cxx_flags()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def parallel_info():
|
| 22 |
+
r"""Returns detailed string with parallelization settings"""
|
| 23 |
+
return torch._C._parallel_info()
|
.venv/lib/python3.11/site-packages/torch/__future__.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_overwrite_module_params_on_conversion: bool = False
|
| 2 |
+
_swap_module_params_on_conversion: bool = False
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def set_overwrite_module_params_on_conversion(value: bool) -> None:
|
| 6 |
+
"""
|
| 7 |
+
Sets whether to assign new tensors to the parameters instead of changing the
|
| 8 |
+
existing parameters in-place when converting an ``nn.Module``.
|
| 9 |
+
|
| 10 |
+
When enabled, the following methods will assign new parameters to the module:
|
| 11 |
+
|
| 12 |
+
#. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
|
| 13 |
+
#. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
|
| 14 |
+
#. :meth:`nn.Module.to`
|
| 15 |
+
#. :meth:`nn.Module.to_empty`
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
value (bool): Whether to assign new tensors or not.
|
| 19 |
+
|
| 20 |
+
"""
|
| 21 |
+
global _overwrite_module_params_on_conversion
|
| 22 |
+
_overwrite_module_params_on_conversion = value
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_overwrite_module_params_on_conversion() -> bool:
|
| 26 |
+
"""
|
| 27 |
+
Returns whether to assign new tensors to the parameters instead of changing the
|
| 28 |
+
existing parameters in-place when converting an :class:`torch.nn.Module`. Defaults to ``False``.
|
| 29 |
+
|
| 30 |
+
See :func:`~torch.__future__.set_overwrite_module_params_on_conversion` for more information.
|
| 31 |
+
"""
|
| 32 |
+
return _overwrite_module_params_on_conversion
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def set_swap_module_params_on_conversion(value: bool) -> None:
|
| 36 |
+
"""
|
| 37 |
+
Sets whether to use :func:`~torch.utils.swap_tensors` instead of setting ``.data`` to
|
| 38 |
+
change the existing parameters in-place when converting an ``nn.Module`` and instead
|
| 39 |
+
of ``param.copy_(state_dict[key])`` when loading a state dict into an ``nn.Module``.
|
| 40 |
+
|
| 41 |
+
.. note::
|
| 42 |
+
This function takes precedence over :func:`~torch.__future__.get_overwrite_module_params_on_conversion`
|
| 43 |
+
|
| 44 |
+
When enabled, the following methods will swap the existing parameters in-place:
|
| 45 |
+
|
| 46 |
+
#. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
|
| 47 |
+
#. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
|
| 48 |
+
#. :meth:`nn.Module.to`
|
| 49 |
+
#. :meth:`nn.Module.to_empty`
|
| 50 |
+
#. :meth:`nn.Module.load_state_dict`
|
| 51 |
+
|
| 52 |
+
The semantics for :meth:`~nn.Module.load_state_dict` when this is set are as follows:
|
| 53 |
+
|
| 54 |
+
#. For each parameter/buffer, its corresponding ``state_dict['key']`` is transformed via
|
| 55 |
+
:meth:`~torch.Tensor.module_load` (i.e. ``res = param.module_load(state_dict['key'])``)
|
| 56 |
+
#. If necessary, ``res`` will be wrapped in an :class:`~nn.Parameter`
|
| 57 |
+
#. The parameter/buffer in the module will be swapped via :func:`~torch.utils.swap_tensors`
|
| 58 |
+
with ``res``
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
value (bool): Whether to use :func:`~torch.utils.swap_tensors` or not.
|
| 62 |
+
|
| 63 |
+
"""
|
| 64 |
+
global _swap_module_params_on_conversion
|
| 65 |
+
_swap_module_params_on_conversion = value
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_swap_module_params_on_conversion() -> bool:
|
| 69 |
+
"""
|
| 70 |
+
Returns whether to use :func:`~torch.utils.swap_tensors` instead of setting .data to
|
| 71 |
+
change the existing parameters in-place when converting an ``nn.Module``. Defaults to ``False``.
|
| 72 |
+
|
| 73 |
+
See :func:`~torch.__future__.set_swap_module_params_on_conversion` for more information.
|
| 74 |
+
"""
|
| 75 |
+
return _swap_module_params_on_conversion
|
.venv/lib/python3.11/site-packages/torch/__init__.py
ADDED
|
@@ -0,0 +1,2665 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The torch package contains data structures for multi-dimensional
|
| 3 |
+
tensors and defines mathematical operations over these tensors.
|
| 4 |
+
Additionally, it provides many utilities for efficient serialization of
|
| 5 |
+
Tensors and arbitrary types, and other useful utilities.
|
| 6 |
+
|
| 7 |
+
It has a CUDA counterpart, that enables you to run your tensor computations
|
| 8 |
+
on an NVIDIA GPU with compute capability >= 3.0.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
# mypy: allow-untyped-defs
|
| 12 |
+
|
| 13 |
+
import builtins
|
| 14 |
+
import ctypes
|
| 15 |
+
import glob
|
| 16 |
+
import importlib
|
| 17 |
+
import inspect
|
| 18 |
+
import math
|
| 19 |
+
import os
|
| 20 |
+
import platform
|
| 21 |
+
import sys
|
| 22 |
+
import textwrap
|
| 23 |
+
import threading
|
| 24 |
+
from typing import (
|
| 25 |
+
Any as _Any,
|
| 26 |
+
Callable as _Callable,
|
| 27 |
+
Dict as _Dict,
|
| 28 |
+
Optional as _Optional,
|
| 29 |
+
overload as _overload,
|
| 30 |
+
Set as _Set,
|
| 31 |
+
Tuple as _Tuple,
|
| 32 |
+
Type as _Type,
|
| 33 |
+
TYPE_CHECKING,
|
| 34 |
+
TypeVar as _TypeVar,
|
| 35 |
+
Union as _Union,
|
| 36 |
+
)
|
| 37 |
+
from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if TYPE_CHECKING:
|
| 41 |
+
from .types import IntLikeType
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# multipy/deploy is setting this import before importing torch, this is the most
|
| 45 |
+
# reliable way we have to detect if we're running within deploy.
|
| 46 |
+
# https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137
|
| 47 |
+
def _running_with_deploy() -> builtins.bool:
|
| 48 |
+
return sys.modules.get("torch._meta_registrations", None) is object
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
from torch._utils import (
|
| 52 |
+
_functionalize_sync as _sync,
|
| 53 |
+
_import_dotted_name,
|
| 54 |
+
classproperty,
|
| 55 |
+
)
|
| 56 |
+
from torch._utils_internal import (
|
| 57 |
+
get_file_path,
|
| 58 |
+
prepare_multiprocessing_environment,
|
| 59 |
+
USE_GLOBAL_DEPS,
|
| 60 |
+
USE_RTLD_GLOBAL_WITH_LIBTORCH,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# TODO(torch_deploy) figure out how to freeze version.py in fbcode build
if _running_with_deploy():
    # torch::deploy builds cannot ship version.py, so use a fixed sentinel.
    __version__ = "torch-deploy-1.8"
else:
    from torch.torch_version import __version__ as __version__

# Explicit public API of the top-level `torch` namespace.
__all__ = [
    "BoolStorage",
    "BoolTensor",
    "ByteStorage",
    "ByteTensor",
    "CharStorage",
    "CharTensor",
    "DoubleStorage",
    "DoubleTensor",
    "FloatStorage",
    "FloatTensor",
    "GradScaler",
    "IntStorage",
    "IntTensor",
    "LongStorage",
    "LongTensor",
    "ShortStorage",
    "ShortTensor",
    "SymBool",
    "SymFloat",
    "SymInt",
    "Tensor",
    "TypedStorage",
    "UntypedStorage",
    "are_deterministic_algorithms_enabled",
    "autocast",
    "chunk",
    "compile",
    "cond",
    "enable_grad",
    "export",
    "get_default_device",
    "get_deterministic_debug_mode",
    "get_device_module",
    "get_float32_matmul_precision",
    "get_rng_state",
    "inference_mode",
    "initial_seed",
    "is_deterministic_algorithms_warn_only_enabled",
    "is_storage",
    "is_tensor",
    "is_warn_always_enabled",
    "load",
    "lobpcg",
    "manual_seed",
    "matmul",
    "no_grad",
    "rand",
    "randn",
    "save",
    "seed",
    "set_default_device",
    "set_default_tensor_type",
    "set_deterministic_debug_mode",
    "set_float32_matmul_precision",
    "set_printoptions",
    "set_rng_state",
    "set_warn_always",
    "split",
    "stack",
    "sym_float",
    "sym_int",
    "sym_ite",
    "sym_max",
    "sym_min",
    "sym_not",
    "typename",
    "unravel_index",
    "use_deterministic_algorithms",
    "vmap",
]

# Please keep this list sorted
assert __all__ == sorted(__all__)
|
| 144 |
+
|
| 145 |
+
################################################################################
|
| 146 |
+
# Load the extension module
|
| 147 |
+
################################################################################
|
| 148 |
+
|
| 149 |
+
if sys.platform == "win32":

    def _load_dll_libraries() -> None:
        # Make torch's DLL dependencies (MSVC runtime, CUDA, nvToolsExt)
        # resolvable before the C extension module is imported.
        import sysconfig

        from torch.version import cuda as cuda_version

        pfiles_path = os.getenv("ProgramFiles", r"C:\Program Files")
        py_dll_path = os.path.join(sys.exec_prefix, "Library", "bin")
        th_dll_path = os.path.join(os.path.dirname(__file__), "lib")
        usebase_path = os.path.join(
            sysconfig.get_config_var("userbase"), "Library", "bin"
        )

        # When users create a virtualenv that inherits the base environment,
        # we will need to add the corresponding library directory into
        # DLL search directories. Otherwise, it will rely on `PATH` which
        # is dependent on user settings.
        if sys.exec_prefix != sys.base_exec_prefix:
            base_py_dll_path = os.path.join(sys.base_exec_prefix, "Library", "bin")
        else:
            base_py_dll_path = ""

        # Only keep directories that actually exist on this machine.
        dll_paths = [
            p
            for p in (th_dll_path, py_dll_path, base_py_dll_path, usebase_path)
            if os.path.exists(p)
        ]

        # nvToolsExt: only fall back to the NVTOOLSEXT_PATH install location
        # when no candidate directory already contains the DLL.
        if not builtins.any(
            os.path.exists(os.path.join(p, "nvToolsExt64_1.dll")) for p in dll_paths
        ):
            nvtoolsext_dll_path = os.path.join(
                os.getenv(
                    "NVTOOLSEXT_PATH",
                    os.path.join(pfiles_path, "NVIDIA Corporation", "NvToolsExt"),
                ),
                "bin",
                "x64",
            )
        else:
            nvtoolsext_dll_path = ""

        # CUDA runtime: likewise, only add the toolkit's bin directory when
        # cudart is not already reachable from a candidate directory.
        if cuda_version and builtins.all(
            not glob.glob(os.path.join(p, "cudart64*.dll")) for p in dll_paths
        ):
            cuda_version_1 = cuda_version.replace(".", "_")
            cuda_path_var = "CUDA_PATH_V" + cuda_version_1
            default_path = os.path.join(
                pfiles_path, "NVIDIA GPU Computing Toolkit", "CUDA", f"v{cuda_version}"
            )
            cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), "bin")
        else:
            cuda_path = ""

        dll_paths.extend(
            p for p in (nvtoolsext_dll_path, cuda_path) if os.path.exists(p)
        )

        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        # AddDllDirectory implies LoadLibraryExW supports the search-flag API.
        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
        # 0x0001 == SEM_FAILCRITICALERRORS: suppress error popups while loading.
        prev_error_mode = kernel32.SetErrorMode(0x0001)

        kernel32.LoadLibraryW.restype = ctypes.c_void_p
        if with_load_library_flags:
            kernel32.LoadLibraryExW.restype = ctypes.c_void_p

        for dll_path in dll_paths:
            os.add_dll_directory(dll_path)

        try:
            ctypes.CDLL("vcruntime140.dll")
            ctypes.CDLL("msvcp140.dll")
            ctypes.CDLL("vcruntime140_1.dll")
        except OSError:
            print(
                textwrap.dedent(
                    """
                    Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure.
                    It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe
                    """
                ).strip()
            )

        dlls = glob.glob(os.path.join(th_dll_path, "*.dll"))
        path_patched = False
        for dll in dlls:
            is_loaded = False
            if with_load_library_flags:
                # 0x00001100 == LOAD_LIBRARY_SEARCH_DEFAULT_DIRS
                #               | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR
                res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
                last_error = ctypes.get_last_error()
                if res is None and last_error != 126:
                    # 126 == ERROR_MOD_NOT_FOUND; for that case we retry via
                    # PATH below instead of failing immediately.
                    err = ctypes.WinError(last_error)
                    err.strerror += (
                        f' Error loading "{dll}" or one of its dependencies.'
                    )
                    raise err
                elif res is not None:
                    is_loaded = True
            if not is_loaded:
                if not path_patched:
                    # Fallback: prepend the candidate dirs to PATH once.
                    os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]])
                    path_patched = True
                res = kernel32.LoadLibraryW(dll)
                if res is None:
                    err = ctypes.WinError(ctypes.get_last_error())
                    err.strerror += (
                        f' Error loading "{dll}" or one of its dependencies.'
                    )
                    raise err

        # Restore the process error mode we clobbered above.
        kernel32.SetErrorMode(prev_error_mode)

    _load_dll_libraries()
    del _load_dll_libraries
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def _preload_cuda_deps(lib_folder: str, lib_name: str) -> None:
    """Preloads cuda deps if they could not be found otherwise.

    Scans every ``sys.path`` entry for a ``nvidia/<lib_folder>/lib`` directory
    containing a file matching the ``lib_name`` glob and loads the first match
    via ctypes, raising ``ValueError`` when nothing matches.
    """
    # Should only be called on Linux if default path resolution have failed
    assert platform.system() == "Linux", "Should only be called on Linux"

    lib_path = None
    for entry in sys.path:
        nvidia_dir = os.path.join(entry, "nvidia")
        if not os.path.exists(nvidia_dir):
            continue
        matches = glob.glob(os.path.join(nvidia_dir, lib_folder, "lib", lib_name))
        if matches:
            lib_path = matches[0]
            break
    if not lib_path:
        raise ValueError(f"{lib_name} not found in the system path {sys.path}")
    ctypes.CDLL(lib_path)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# See Note [Global dependencies]
|
| 289 |
+
def _load_global_deps() -> None:
    # Load libtorch_global_deps with RTLD_GLOBAL so its shared dependencies
    # are visible to subsequent dlopens; no-op under deploy and on Windows.
    if _running_with_deploy() or platform.system() == "Windows":
        return

    # Determine the file extension based on the platform
    lib_ext = ".dylib" if platform.system() == "Darwin" else ".so"
    lib_name = f"libtorch_global_deps{lib_ext}"
    here = os.path.abspath(__file__)
    global_deps_lib_path = os.path.join(os.path.dirname(here), "lib", lib_name)

    try:
        ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
    except OSError as err:
        # Can only happen for wheel with cuda libs as PYPI deps
        # As PyTorch is not purelib, but nvidia-*-cu12 is
        # Maps the pip package's folder (under site-packages/nvidia/, see
        # _preload_cuda_deps) to the shared-library glob expected inside it.
        cuda_libs: _Dict[str, str] = {
            "cublas": "libcublas.so.*[0-9]",
            "cudnn": "libcudnn.so.*[0-9]",
            "cuda_nvrtc": "libnvrtc.so.*[0-9]",
            "cuda_runtime": "libcudart.so.*[0-9]",
            "cuda_cupti": "libcupti.so.*[0-9]",
            "cufft": "libcufft.so.*[0-9]",
            "curand": "libcurand.so.*[0-9]",
            "nvjitlink": "libnvJitLink.so.*[0-9]",
            "cusparse": "libcusparse.so.*[0-9]",
            "cusolver": "libcusolver.so.*[0-9]",
            "nccl": "libnccl.so.*[0-9]",
            "nvtx": "libnvToolsExt.so.*[0-9]",
        }
        # Only retry when the load error mentions one of the CUDA libs;
        # any other failure is re-raised unchanged.
        is_cuda_lib_err = [
            lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0]
        ]
        if not is_cuda_lib_err:
            raise err
        # Preload every CUDA dependency from the pip packages, then retry.
        for lib_folder, lib_name in cuda_libs.items():
            _preload_cuda_deps(lib_folder, lib_name)
        ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and (
    _running_with_deploy() or platform.system() != "Windows"
):
    # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a
    # few circumstances:
    #
    #   1. You're in a build environment (e.g., fbcode) where
    #      libtorch_global_deps is not available, but you still need
    #      to get mkl to link in with RTLD_GLOBAL or it will just
    #      not work.
    #
    #   2. You're trying to run PyTorch under UBSAN and you need
    #      to ensure that only one copy of libtorch is loaded, so
    #      vptr checks work properly
    #
    # If you're using this setting, you must verify that all the libraries
    # you load consistently use the same libstdc++, or you may have
    # mysterious segfaults.
    #
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY)

    from torch._C import *  # noqa: F403

    # Restore the interpreter's default dlopen flags once _C is loaded.
    sys.setdlopenflags(old_flags)
    del old_flags

else:
    # Easy way. You want this most of the time, because it will prevent
    # C++ symbols from libtorch clobbering C++ symbols from other
    # libraries, leading to mysterious segfaults.
    #
    # If building in an environment where libtorch_global_deps isn't available
    # like parts of fbsource, but where RTLD_GLOBAL causes segfaults, you will
    # want USE_RTLD_GLOBAL_WITH_LIBTORCH = False and USE_GLOBAL_DEPS = False
    #
    # See Note [Global dependencies]
    if USE_GLOBAL_DEPS:
        _load_global_deps()
    from torch._C import *  # noqa: F403
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class SymInt:
    """
    Like an int (including magic methods), but redirects all operations on the
    wrapped node. This is used in particular to symbolically record operations
    in the symbolic shape workflow.
    """

    def __init__(self, node):
        # This field MUST be named node; C++ binding code assumes that this
        # class has a field named node that stores SymNode
        self.node = node

    def __bool__(self):
        # Truthiness mirrors int: zero is falsy.
        return builtins.bool(self != 0)

    def __int__(self):
        return self.node.int_()

    def __index__(self):
        # Lets a SymInt be used wherever an index is required (slicing, etc.).
        return self.node.int_()

    # Magic methods installed by torch.fx.experimental.sym_node

    def __round__(self, ndigits=None):
        # Rounding an integral value is the identity, whatever ndigits is.
        return self

    def __truediv__(self, other):
        # Division by a float promotes self to SymFloat; int / int goes
        # through the dedicated __int_truediv__ stub.
        if isinstance(other, (builtins.float, SymFloat)):
            return sym_float(self).__float_truediv__(other)
        if not isinstance(other, (builtins.int, SymInt)):
            return NotImplemented
        return self.__int_truediv__(other)

    def __rtruediv__(self, other):
        if isinstance(other, (builtins.float, SymFloat)):
            return sym_float(self).__rfloat_truediv__(other)
        if not isinstance(other, (builtins.int, SymInt)):
            return NotImplemented
        return self.__rint_truediv__(other)

    def __floordiv__(self, other):
        # Floor division by a float: compute in float space, then floor.
        if isinstance(other, (builtins.float, SymFloat)):
            return sym_float(math.floor(sym_float(self) / other))
        if not isinstance(other, (builtins.int, SymInt)):
            return NotImplemented
        return self.__int_floordiv__(other)

    def __rfloordiv__(self, other):
        if isinstance(other, (builtins.float, SymFloat)):
            return sym_float(math.floor(other / sym_float(self)))
        if not isinstance(other, (builtins.int, SymInt)):
            return NotImplemented
        return self.__rint_floordiv__(other)

    # nb: complex is impossible to handle correctly lol, with
    # negative base and integral float need to diverge semantics and
    # just always return complex. Neener neener pretend this problem
    # doesn't exist
    def __pow__(self, other):
        if isinstance(other, (builtins.float, SymFloat)):
            return sym_float(self).__pow__(other)
        if not isinstance(other, (builtins.int, SymInt)):
            return NotImplemented
        # Guards! This guard is necessary because we need to know it to
        # determine the output type of this operation
        if other >= 0:
            return self.__pow_by_natural__(other)
        else:
            # Mercifully, when the exponent is negative, Python just promotes
            # to doubles and does a float pow:
            #
            #   if (Py_SIZE(b) < 0 && c == NULL) {
            #       /* if exponent is negative and there's no modulus:
            #              return a float. This works because we know
            #              that this calls float_pow() which converts its
            #              arguments to double. */
            #       Py_DECREF(a);
            #       Py_DECREF(b);
            #       return PyFloat_Type.tp_as_number->nb_power(v, w, x);
            #   }
            return sym_float(self).__pow__(sym_float(other))

    def __rpow__(self, other):
        if isinstance(other, (builtins.float, SymFloat)):
            return sym_float(self).__rpow__(other)
        if not isinstance(other, (builtins.int, SymInt)):
            return NotImplemented
        if self >= 0:  # self is exponent
            return self.__rpow_by_natural__(other)
        else:
            return sym_float(self).__rpow__(sym_float(other))

    # The defs below are type stubs only: the real implementations are
    # installed at import time by torch.fx.experimental.sym_node, so reaching
    # one of these TypeErrors means the patching did not happen.

    def __eq__(self, other: object) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __lt__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __gt__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __le__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __ge__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __add__(self, other) -> "SymInt":
        raise TypeError("type stub not overridden")

    def __mod__(self, other: "IntLikeType") -> "SymInt":
        raise TypeError("type stub not overridden")

    def __mul__(self, other) -> "SymInt":
        raise TypeError("type stub not overridden")

    def __pow_by_natural__(self, other) -> "SymInt":
        raise TypeError("type stub not overridden")

    def __rpow_by_natural__(self, other) -> "SymInt":
        raise TypeError("type stub not overridden")

    def __int_truediv__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    def __rint_truediv__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    # NOTE(review): int // int yields an int, so the "SymFloat" return
    # annotations on the two floordiv stubs below look suspect — confirm
    # against the installed sym_node implementations before changing.
    def __int_floordiv__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    def __rint_floordiv__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    def __sym_max__(self, other):
        raise TypeError("type stub not overridden")

    def __sym_min__(self, other):
        raise TypeError("type stub not overridden")

    def __sym_float__(self):
        raise TypeError("type stub not overridden")

    def __neg__(self):
        raise TypeError("type stub not overridden")

    def __sub__(self, other: "IntLikeType") -> "SymInt":
        raise TypeError("type stub not overridden")

    def __repr__(self):
        return self.node._graph_repr()

    def _sympy_(self):
        # sympy conversion hook (sympy.sympify looks for `_sympy_`).
        return self.node.expr

    def __hash__(self) -> builtins.int:
        if self.node.is_nested_int():
            return hash(self.node.nested_int())
        else:
            # We could support constant SymInts as well, but not doing it for now
            raise TypeError("unhashable type: non-nested SymInt")
        # TODO: Force specialization
        # This can't be done because the TypeError here is load bearing
        # for einops
        # https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
        # return hash(builtins.int(self))

    def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]:
        """Represent this int as an exact integer ratio"""
        return self, 1

    def bit_length(self) -> builtins.int:
        # TODO: A more relaxed guard is possible here, where you guard to
        # allow all integer quantities which would result in the same bit
        # length. We can also just make a dedicated Sympy function for
        # computing this quantity and represent it symbolically.
        return builtins.int(self).bit_length()

    def conjugate(self) -> "SymInt":
        # Conjugate of a real value is itself (mirrors int.conjugate).
        return self
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
class SymFloat:
    """
    Like an float (including magic methods), but redirects all operations on the
    wrapped node. This is used in particular to symbolically record operations
    in the symbolic shape workflow.
    """

    def __init__(self, node):
        # This field MUST be named node; C++ binding code assumes that this
        # class has a field named node that stores SymNode
        self.node = node

    def __truediv__(self, other):
        # Any numeric operand is promoted to SymFloat before dividing.
        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
            return NotImplemented
        return self.__float_truediv__(sym_float(other))

    def __rtruediv__(self, other):
        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
            return NotImplemented
        return self.__rfloat_truediv__(sym_float(other))

    def __floordiv__(self, other):
        # float // x == floor(float / x); the result stays a SymFloat.
        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
            return NotImplemented
        return sym_float(math.floor(self / sym_float(other)))

    def __rfloordiv__(self, other):
        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
            return NotImplemented
        return sym_float(math.floor(sym_float(other) / self))

    def __bool__(self):
        return self.node.bool_()

    def __float__(self):
        # Concretizes through node.guard_float — presumably this installs a
        # guard on the symbolic value; confirm in SymNode.
        return self.node.guard_float("", 0)

    # Symbolic power does NOT work with negative base, this is to avoid
    # potential complex outputs
    def __pow__(self, other):
        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
            return NotImplemented
        torch._check(self >= 0)
        return self.__float_pow__(other)

    def __rpow__(self, other):
        if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
            return NotImplemented
        torch._check(other >= 0)
        return self.__rfloat_pow__(other)

    # Magic methods installed by torch.fx.experimental.sym_node
    # (the defs below are type stubs only; reaching one of these TypeErrors
    # means the import-time patching did not happen)

    def __eq__(self, other: object) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __lt__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __gt__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __le__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __ge__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __float_pow__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    def __rfloat_pow__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    def __float_truediv__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    def __rfloat_truediv__(self, other) -> "SymFloat":
        raise TypeError("type stub not overridden")

    def __trunc__(self):
        raise TypeError("type stub not overridden")

    def __sym_max__(self, other):
        raise TypeError("type stub not overridden")

    def __sym_min__(self, other):
        raise TypeError("type stub not overridden")

    def __sym_int__(self):
        raise TypeError("type stub not overridden")

    def is_integer(self):
        """Return True if the float is an integer."""
        raise TypeError("type stub not overridden")

    def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
        """Represent this float as an exact integer ratio"""
        return builtins.float(self).as_integer_ratio()

    def __repr__(self):
        return self.node._graph_repr()

    def _sympy_(self):
        # sympy conversion hook (sympy.sympify looks for `_sympy_`).
        return self.node.expr

    def __hash__(self):
        # Hashing concretizes via __float__ (which goes through guard_float).
        return hash(builtins.float(self))
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
class SymBool:
    """
    Like an bool (including magic methods), but redirects all operations on the
    wrapped node. This is used in particular to symbolically record operations
    in the symbolic shape workflow.

    Unlike regular bools, regular boolean operators will force extra guards instead
    of symbolically evaluate. Use the bitwise operators instead to handle this.
    """

    def __init__(self, node):
        # This field MUST be named node; C++ binding code assumes that this
        # class has a field named node that stores SymNode
        self.node = node

    def __bool__(self):
        return self.node.bool_()

    def __int__(self):
        # bool -> int conversion (0 or 1), mirroring builtin bool.
        return builtins.int(self.node.bool_())

    # Magic methods installed by torch.fx.experimental.sym_node
    # (the defs below are type stubs only; reaching one of these TypeErrors
    # means the import-time patching did not happen)
    def __and__(self, other) -> "SymBool":
        raise TypeError("type stub not overridden")

    def __or__(self, other) -> "SymBool":
        raise TypeError("type stub not overridden")

    # We very carefully define __sym_not__, and not a number of other
    # plausible alternatives:
    #
    #   - We do not override __not__ because this is not a real magic
    #     method; you cannot override the meaning of the not builtin in
    #     Python. We use the name 'sym_not' to clarify that in user code you
    #     cannot use the builtin not or operator.not_ or operator.__not__ and
    #     hit this magic method; you must use our custom sym_not operator.
    #
    #   - We do not override the __invert__ method because SymBool is
    #     meant to be usable in situations where bool is expected. However,
    #     bitwise negation ~a does the wrong thing with booleans (because
    #     bool is a subclass of int, so ~1 = -2 which is not falseish.)
    #     This would be a giant footgun, so we get around it by defining
    #     our own operator. Note that bitwise and/or do the right thing,
    #     so we reuse the conventional operators there for readability.
    #
    def __sym_not__(self) -> "SymBool":
        raise TypeError("type stub not overridden")

    def __sym_ite__(self, then_val, else_val):
        raise TypeError("type stub not overridden")

    def __eq__(self, other) -> builtins.bool:
        raise TypeError("type stub not overridden")

    def __repr__(self):
        return self.node._graph_repr()

    def _sympy_(self):
        # sympy conversion hook (sympy.sympify looks for `_sympy_`).
        return self.node.expr

    def __hash__(self):
        if self.node.is_constant():
            return hash(self.node.bool_())
        else:
            # Force specialization
            return hash(builtins.bool(self))
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
def sym_not(a):
    r"""SymInt-aware utility for logical negation.

    Args:
        a (SymBool or bool): Object to negate
    """
    import sympy

    if overrides.has_torch_function_unary(a):
        return overrides.handle_torch_function(sym_not, (a,), a)
    if hasattr(a, "__sym_not__"):
        return a.__sym_not__()
    return ~a if isinstance(a, sympy.Basic) else (not a)  # type: ignore[operator]
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
def sym_float(a):
    r"""SymInt-aware utility for float casting.

    Args:
        a (SymInt, SymFloat, or object): Object to cast
    """
    if overrides.has_torch_function_unary(a):
        return overrides.handle_torch_function(sym_float, (a,), a)
    if isinstance(a, SymFloat):
        return a
    if hasattr(a, "__sym_float__"):
        return a.__sym_float__()
    return builtins.float(a)  # type: ignore[operator]
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
def sym_int(a):
    r"""SymInt-aware utility for int casting.

    Args:
        a (SymInt, SymFloat, or object): Object to cast
    """
    if overrides.has_torch_function_unary(a):
        return overrides.handle_torch_function(sym_int, (a,), a)
    if isinstance(a, SymInt):
        return a
    if isinstance(a, SymFloat):
        # Truncate toward zero, like builtin int() on a float.
        return math.trunc(a)
    return builtins.int(a)  # type: ignore[operator]
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
def sym_max(a, b):
    """
    SymInt-aware utility for max which avoids branching on a < b.
    Unlike builtins.max(), this only works for int/float, and it always
    promotes to float if any argument is float (unlike builtins.max, which
    will faithfully preserve the type of the input argument).
    """
    if overrides.has_torch_function((a, b)):
        return overrides.handle_torch_function(sym_max, (a, b), a, b)
    if isinstance(a, (SymInt, SymFloat)):
        return a.__sym_max__(b)
    if isinstance(b, (SymInt, SymFloat)):
        # Due to promotion semantics, this is operator is commutative:
        # max(1, 1.0) === max(1.0, 1) === 1.0
        return b.__sym_max__(a)
    # TODO: Probably can make bool work too, just lazy

    all_types, float_types = __all_and_float_types()

    assert isinstance(a, all_types), type(a)
    assert isinstance(b, all_types), type(b)
    result = builtins.max(a, b)
    if isinstance(a, float_types) or isinstance(b, float_types):
        return builtins.float(result)
    return result
|
| 803 |
+
|
| 804 |
+
|
| 805 |
+
def __all_and_float_types() -> _Tuple[_Tuple[_Type, ...], _Tuple[_Type, ...]]:
    """Return ``(all numeric types, float types)``, including numpy's
    ``integer``/``floating`` scalar bases when numpy is installed."""
    all_types: _Tuple[_Type, ...] = (builtins.int, builtins.float)
    float_types: _Tuple[_Type, ...] = (builtins.float,)
    try:
        import numpy as np
    except ModuleNotFoundError:
        return all_types, float_types
    return (
        (np.integer, np.floating) + all_types,
        (np.floating,) + float_types,
    )
|
| 821 |
+
|
| 822 |
+
|
| 823 |
+
def sym_min(a, b):
    """SymInt-aware utility for min().

    Mirrors :func:`sym_max`: symbolic operands dispatch to ``__sym_min__``,
    and plain numbers promote to float when either argument is a float.
    """
    if overrides.has_torch_function((a, b)):
        return overrides.handle_torch_function(sym_min, (a, b), a, b)
    if isinstance(a, (SymInt, SymFloat)):
        return a.__sym_min__(b)
    if isinstance(b, (SymInt, SymFloat)):
        return b.__sym_min__(a)

    all_types, float_types = __all_and_float_types()

    assert isinstance(a, all_types), type(a)
    assert isinstance(b, all_types), type(b)
    result = builtins.min(a, b)
    if isinstance(a, float_types) or isinstance(b, float_types):
        return builtins.float(result)
    return result
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
# Drop in replacement for math.sqrt, math.sin, math.cos etc
|
| 843 |
+
def _get_sym_math_fn(name):
    """Build a drop-in replacement for ``math.<name>`` that dispatches to a
    ``__sym_<name>__`` method when the argument provides one."""
    attr = f"__sym_{name}__"

    def fn(a):
        if overrides.has_torch_function_unary(a):
            return overrides.handle_torch_function(fn, (a,), a)
        if hasattr(a, attr):
            return getattr(a, attr)()
        return getattr(math, name)(a)

    return fn
|
| 852 |
+
|
| 853 |
+
|
| 854 |
+
# Install `_sym_sqrt`, `_sym_cos`, ... as module-level functions; the loop
# variables are pre-declared (and deleted below) to keep the namespace clean.
__fn, __name, __sym_name = None, "", ""
for __name in (
    "sqrt",
    "cos",
    "cosh",
    "sin",
    "sinh",
    "tan",
    "tanh",
    "asin",
    "acos",
    "atan",
):
    __sym_name = f"_sym_{__name}"
    __fn = _get_sym_math_fn(__name)
    # Rename the closure so tracebacks/repr show `_sym_<name>`, not `fn`.
    __fn.__qualname__ = __fn.__name__ = __sym_name
    globals()[__sym_name] = __fn

del __fn, __name, __sym_name, _get_sym_math_fn

# Adding temporary shortcut
sym_sqrt = globals()["_sym_sqrt"]
__all__.append("sym_sqrt")
|
| 877 |
+
|
| 878 |
+
|
| 879 |
+
def sym_ite(b, t, f):
    """SymBool-aware if-then-else: ``t`` when ``b`` is true, else ``f``.

    ``t`` and ``f`` must have the same type so the result type is stable
    under symbolic tracing.
    """
    if overrides.has_torch_function((b, t, f)):
        return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f)
    assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f)
    if not isinstance(b, SymBool):
        return t if b else f
    # Defer to the symbolic node so the selection is traced, not evaluated.
    return b.__sym_ite__(t, f)
|
| 888 |
+
# Check to see if we can load C extensions, and if not provide some guidance
# on what the problem might be.
try:
    # _initExtension is chosen (arbitrarily) as a sentinel.
    from torch._C import _initExtension
except ImportError:
    import torch._C as _C_for_compiled_check

    # The __file__ check only works for Python 3.7 and above.
    # __file__ is None for namespace packages: the import above resolved to
    # the source `torch/_C` directory instead of the compiled extension.
    if _C_for_compiled_check.__file__ is None:
        raise ImportError(
            textwrap.dedent(
                """
                Failed to load PyTorch C extensions:
                    It appears that PyTorch has loaded the `torch/_C` folder
                    of the PyTorch repository rather than the C extensions which
                    are expected in the `torch._C` namespace. This can occur when
                    using the `install` workflow. e.g.
                        $ python setup.py install && python -c "import torch"

                    This error can generally be solved using the `develop` workflow
                        $ python setup.py develop && python -c "import torch"  # This should succeed
                    or by running Python from a different directory.
                """
            ).strip()
        ) from None
    raise  # If __file__ is not None the cause is unknown, so just re-raise.
|
| 916 |
+
# The torch._C submodule is already loaded via `from torch._C import *` above
|
| 917 |
+
# Make an explicit reference to the _C submodule to appease linters
|
| 918 |
+
from torch import _C as _C
|
| 919 |
+
|
| 920 |
+
|
| 921 |
+
__name, __obj = "", None
# Re-export every public name from the C extension at the torch.* top level
# and fix up __module__ so callables/classes report "torch", not "torch._C".
for __name in dir(_C):
    if __name[0] != "_" and not __name.endswith("Base"):
        __all__.append(__name)
        __obj = getattr(_C, __name)
        if callable(__obj) or inspect.isclass(__obj):
            if __obj.__module__ != __name__:  # "torch"
                # TODO: fix their module from C++ side
                if __name not in {
                    "DisableTorchFunctionSubclass",
                    "DisableTorchFunction",
                    "Generator",
                }:
                    __obj.__module__ = __name__  # "torch"
    elif __name == "TensorBase":
        # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
        delattr(sys.modules[__name__], __name)

# Drop the loop temporaries from the module namespace.
del __name, __obj
|
| 941 |
+
if not TYPE_CHECKING:
    # issue 38137 and python issue 43367. Submodules of a C extension are
    # non-standard, and attributes of those submodules cannot be pickled since
    # pickle expect to be able to import them as "from _C.sub import attr"
    # which fails with "_C is not a package
    def _import_extension_to_sys_modules(module, memo=None):
        # Register every submodule of the C extension in sys.modules so that
        # "import torch._C.sub" (and pickling of its attributes) works.
        # `memo` guards against cycles in the module graph.
        if memo is None:
            memo = set()
        if module in memo:
            return
        memo.add(module)
        module_name = module.__name__
        for name in dir(module):
            member = getattr(module, name)
            member_name = getattr(member, "__name__", "")
            # Only register genuine submodules of this extension (name prefix
            # match), not unrelated modules reachable as attributes.
            if inspect.ismodule(member) and member_name.startswith(module_name):
                sys.modules.setdefault(member_name, member)
                # Recurse for submodules (e.g., `_C._dynamo.eval_frame`)
                _import_extension_to_sys_modules(member, memo)

    _import_extension_to_sys_modules(_C)
    # One-shot helper; remove it from the module namespace.
    del _import_extension_to_sys_modules
|
| 964 |
+
################################################################################
|
| 965 |
+
# Define basic utilities
|
| 966 |
+
################################################################################
|
| 967 |
+
|
| 968 |
+
|
| 969 |
+
def typename(obj: _Any, /) -> str:
    """Return a fully qualified string naming the type of ``obj``.

    Tensors report their legacy type string (e.g. ``'torch.LongTensor'``);
    everything else is rendered as ``module.qualname``, with the module
    prefix dropped for builtins.

    Args:
        obj (object): The object whose type to represent

    Returns:
        str: the fully qualified name of ``obj``'s type

    Example:
        >>> x = torch.tensor([1, 2, 3])
        >>> torch.typename(x)
        'torch.LongTensor'
        >>> torch.typename(torch.nn.Parameter)
        'torch.nn.parameter.Parameter'
    """
    if isinstance(obj, torch.Tensor):
        return obj.type()

    # ``obj`` may itself be a class/function (carries __qualname__ or
    # __name__); otherwise describe its class instead.
    if hasattr(obj, "__qualname__"):
        mod = getattr(obj, "__module__", "") or ""
        qual = obj.__qualname__
    elif hasattr(obj, "__name__"):
        mod = getattr(obj, "__module__", "") or ""
        qual = obj.__name__
    else:
        mod = obj.__class__.__module__ or ""
        qual = obj.__class__.__qualname__

    return qual if mod in {"", "builtins"} else f"{mod}.{qual}"
|
| 1003 |
+
|
| 1004 |
+
def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]:
|
| 1005 |
+
r"""Returns True if `obj` is a PyTorch tensor.
|
| 1006 |
+
|
| 1007 |
+
Note that this function is simply doing ``isinstance(obj, Tensor)``.
|
| 1008 |
+
Using that ``isinstance`` check is better for typechecking with mypy,
|
| 1009 |
+
and more explicit - so it's recommended to use that instead of
|
| 1010 |
+
``is_tensor``.
|
| 1011 |
+
|
| 1012 |
+
Args:
|
| 1013 |
+
obj (object): Object to test
|
| 1014 |
+
Example::
|
| 1015 |
+
|
| 1016 |
+
>>> x = torch.tensor([1, 2, 3])
|
| 1017 |
+
>>> torch.is_tensor(x)
|
| 1018 |
+
True
|
| 1019 |
+
|
| 1020 |
+
"""
|
| 1021 |
+
return isinstance(obj, torch.Tensor)
|
| 1022 |
+
|
| 1023 |
+
|
| 1024 |
+
def is_storage(obj: _Any, /) -> _TypeGuard[_Union["TypedStorage", "UntypedStorage"]]:
|
| 1025 |
+
r"""Returns True if `obj` is a PyTorch storage object.
|
| 1026 |
+
|
| 1027 |
+
Args:
|
| 1028 |
+
obj (Object): Object to test
|
| 1029 |
+
"""
|
| 1030 |
+
return type(obj) in _storage_classes
|
| 1031 |
+
|
| 1032 |
+
|
| 1033 |
+
# Thread-local slot whose ``device_context`` attribute is managed by
# set_default_device() and read by get_default_device().
_GLOBAL_DEVICE_CONTEXT = threading.local()
|
| 1035 |
+
|
| 1036 |
+
def get_default_device() -> "torch.device":
    r"""Return the device on which default ``torch.Tensor``s are allocated.

    Reflects the device installed via :func:`torch.set_default_device`;
    when no default has been set — or it was cleared with
    ``set_default_device(None)`` — this is ``cpu``.
    """
    # set_default_device(None) stores None in the attribute rather than
    # deleting it, so test the value, not just hasattr() — otherwise this
    # function would raise AttributeError (None.device) after the default
    # has been cleared.
    device_context = getattr(_GLOBAL_DEVICE_CONTEXT, "device_context", None)
    if device_context is None:
        return torch.device("cpu")

    device = device_context.device
    if device.index is not None:
        return device

    # TODO: Call like get_device_index() method corresponding to
    # each device type
    # Device has no index (e.g. plain "cuda"): materialize an empty tensor to
    # discover the currently active index for that device type.
    return torch.tensor([]).device
|
| 1051 |
+
|
| 1052 |
+
def set_default_device(
    device: _Optional[_Union["torch.device", str, builtins.int]],
) -> None:
    """Sets the default ``torch.Tensor`` to be allocated on ``device``. This
    does not affect factory function calls which are called with an explicit
    ``device`` argument. Factory calls will be performed as if they
    were passed ``device`` as an argument.

    To only temporarily change the default device instead of setting it
    globally, use ``with torch.device(device):`` instead.

    The default device is initially ``cpu``. If you set the default tensor
    device to another device (e.g., ``cuda``) without a device index, tensors
    will be allocated on whatever the current device for the device type,
    even after :func:`torch.cuda.set_device` is called.

    .. warning::

        This function imposes a slight performance cost on every Python
        call to the torch API (not just factory functions). If this
        is causing problems for you, please comment on
        https://github.com/pytorch/pytorch/issues/92701

    .. note::

        This doesn't affect functions that create tensors that share the same memory as the input, like:
        :func:`torch.from_numpy` and :func:`torch.frombuffer`

    Args:
        device (device or string): the device to set as default

    Example::

        >>> # xdoctest: +SKIP("requires cuda, changes global state")
        >>> torch.get_default_device()
        device(type='cpu')
        >>> torch.set_default_device('cuda')  # current device is 0
        >>> torch.get_default_device()
        device(type='cuda', index=0)
        >>> torch.set_default_device('cuda')
        >>> torch.cuda.set_device('cuda:1')  # current device is 1
        >>> torch.get_default_device()
        device(type='cuda', index=1)
        >>> torch.set_default_device('cuda:1')
        >>> torch.get_default_device()
        device(type='cuda', index=1)

    """
    global _GLOBAL_DEVICE_CONTEXT
    # Exit the previously installed DeviceContext (if any) before installing
    # a new one, so the TorchFunctionMode stack stays balanced.
    if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
        device_context = _GLOBAL_DEVICE_CONTEXT.device_context
        if device_context is not None:
            device_context.__exit__(None, None, None)

    if device is None:
        # Clearing the default: record None (the attribute stays present).
        device_context = None
    else:
        from torch.utils._device import DeviceContext

        device_context = DeviceContext(device)
        # Enter the mode for the rest of the process (until replaced/cleared).
        device_context.__enter__()
    _GLOBAL_DEVICE_CONTEXT.device_context = device_context
|
| 1115 |
+
|
| 1116 |
+
def set_default_tensor_type(t: _Union[_Type["torch.Tensor"], str], /) -> None:
    r"""Set the default floating point ``torch.Tensor`` type to ``t``.

    .. warning::

        This function is deprecated as of PyTorch 2.1; please use
        :func:`torch.set_default_dtype` and :func:`torch.set_default_device`
        as alternatives.

    The chosen type is also used as the default floating point type for type
    inference in :func:`torch.tensor`. The initial default is
    ``torch.FloatTensor``.

    Args:
        t (type or string): the floating point tensor type or its
            fully-qualified name (e.g. ``"torch.DoubleTensor"``)

    Example::

        >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?")
        >>> torch.tensor([1.2, 3]).dtype  # initial default for floating point is torch.float32
        torch.float32
        >>> torch.set_default_tensor_type(torch.DoubleTensor)
        >>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
        torch.float64

    """
    # A dotted name like "torch.DoubleTensor" is resolved to the class first.
    if isinstance(t, str):
        t = _import_dotted_name(t)
    _C._set_default_tensor_type(t)
|
| 1146 |
+
|
| 1147 |
+
def set_default_dtype(d: "torch.dtype", /) -> None:
    r"""

    Sets the default floating point dtype to :attr:`d`. Supports floating point dtype
    as inputs. Other dtypes will cause torch to raise an exception.

    When PyTorch is initialized its default floating point dtype is torch.float32,
    and the intent of set_default_dtype(torch.float64) is to facilitate NumPy-like
    type inference. The default floating point dtype is used to:

    1. Implicitly determine the default complex dtype. When the default floating type is float16,
       the default complex dtype is complex32. For float32, the default complex dtype is complex64.
       For float64, it is complex128. For bfloat16, an exception will be raised because
       there is no corresponding complex type for bfloat16.
    2. Infer the dtype for tensors constructed using Python floats or complex Python
       numbers. See examples below.
    3. Determine the result of type promotion between bool and integer tensors and
       Python floats and complex Python numbers.

    Args:
        d (:class:`torch.dtype`): the floating point dtype to make the default.

    Example:
        >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?")
        >>> # initial default for floating point is torch.float32
        >>> # Python floats are interpreted as float32
        >>> torch.tensor([1.2, 3]).dtype
        torch.float32
        >>> # initial default complex dtype is torch.complex64
        >>> # Complex Python numbers are interpreted as complex64
        >>> torch.tensor([1.2, 3j]).dtype
        torch.complex64

        >>> torch.set_default_dtype(torch.float64)
        >>> # Python floats are now interpreted as float64
        >>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
        torch.float64
        >>> # Complex Python numbers are now interpreted as complex128
        >>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
        torch.complex128

        >>> torch.set_default_dtype(torch.float16)
        >>> # Python floats are now interpreted as float16
        >>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
        torch.float16
        >>> # Complex Python numbers are now interpreted as complex32
        >>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
        torch.complex32

    """
    # Doc fixes vs. previous revision: the first complex example wrongly said
    # "initial default for floating point is torch.complex64" (it describes
    # the default *complex* dtype), and the float16 example said complex
    # numbers become complex128 while the shown output is torch.complex32.
    _C._set_default_dtype(d)
|
| 1199 |
+
|
| 1200 |
+
def use_deterministic_algorithms(
    mode: builtins.bool,
    *,
    warn_only: builtins.bool = False,
) -> None:
    r"""Globally require (or stop requiring) deterministic algorithms.

    When ``mode=True``, operations use deterministic implementations where
    available — given the same input on the same software and hardware they
    always produce the same output — and operations that only have
    nondeterministic implementations raise a :class:`RuntimeError` when
    called (or emit a warning instead, if ``warn_only=True``). When
    ``mode=False``, nondeterministic (usually faster) implementations are
    allowed again.

    Affected operations include, among many others: CUDA convolutions
    (:class:`torch.nn.Conv1d`/``2d``/``3d`` and transposed variants), CUDA
    scatter/gather/index kernels (:func:`torch.Tensor.scatter_add_`,
    :func:`torch.gather`, :func:`torch.index_add`, :func:`torch.index_select`),
    various pooling/padding backward passes, :func:`torch.bmm` on
    sparse-dense CUDA tensors, and CUDA histogram/median/kthvalue kernels.
    Refer to :ref:`reproducibility` for the complete, version-specific lists
    of operations that become deterministic versus those that raise.

    In addition, when this setting and
    :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both on,
    several operations fill uninitialized memory deterministically.

    If the CUDA version is 10.2 or greater, CUBLAS matmuls (:func:`torch.mm`,
    :func:`torch.mv`, :func:`torch.bmm`) additionally require the environment
    variable ``CUBLAS_WORKSPACE_CONFIG=:4096:8`` (or ``:16:8``) to be set, or
    they raise; see
    `<https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility>`_.

    .. note:: This setting alone is not always enough to make an application
        reproducible. Refer to :ref:`reproducibility` for more information.

    .. note:: Deterministic operations tend to have worse performance than
        nondeterministic ones.

    .. note:: This flag does not detect or prevent nondeterminism caused by
        inplace operations on tensors with internal memory overlap, where the
        order of overlapping writes is not guaranteed.

    .. note:: :func:`torch.set_deterministic_debug_mode` offers an alternative
        interface for this feature.

    Args:
        mode (:class:`bool`): If True, makes potentially nondeterministic
            operations switch to a deterministic algorithm or throw a runtime
            error. If False, allows nondeterministic operations.

    Keyword args:
        warn_only (:class:`bool`, optional): If True, operations that do not
            have a deterministic implementation will throw a warning instead of
            an error. Default: ``False``

    Example::

        >>> # xdoctest: +SKIP
        >>> torch.use_deterministic_algorithms(True)

        # Forward mode nondeterministic error
        >>> torch.randn(10, device='cuda').kthvalue(1)
        ...
        RuntimeError: kthvalue CUDA does not have a deterministic implementation...

        # Backward mode nondeterministic error
        >>> torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward()
        ...
        RuntimeError: avg_pool3d_backward_cuda does not have a deterministic implementation...
    """
    _C._set_deterministic_algorithms(mode, warn_only=warn_only)
|
| 1341 |
+
|
| 1342 |
+
def are_deterministic_algorithms_enabled() -> builtins.bool:
    r"""Report whether the global deterministic flag is turned on.

    See :func:`torch.use_deterministic_algorithms` for details.
    """
    return _C._get_deterministic_algorithms()
|
| 1348 |
+
|
| 1349 |
+
def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool:
    r"""Report whether the global deterministic flag is in warn-only mode.

    See :func:`torch.use_deterministic_algorithms` for details.
    """
    return _C._get_deterministic_algorithms_warn_only()
|
| 1356 |
+
|
| 1357 |
+
def set_deterministic_debug_mode(debug_mode: _Union[builtins.int, str]) -> None:
    r"""Set the debug mode for deterministic operations.

    .. note:: This is an alternative interface for
        :func:`torch.use_deterministic_algorithms`. Refer to that function's
        documentation for details about affected operations.

    Args:
        debug_mode(str or int): If "default" or 0, don't error or warn on
            nondeterministic operations. If "warn" or 1, warn on
            nondeterministic operations. If "error" or 2, error on
            nondeterministic operations.
    """

    # NOTE: builtins.int is used here because int in this scope resolves
    # to torch.int
    if not isinstance(debug_mode, (builtins.int, str)):
        raise TypeError(f"debug_mode must be str or int, but got {type(debug_mode)}")

    if isinstance(debug_mode, str):
        # Map the symbolic names onto the numeric levels used below.
        name_to_level = {"default": 0, "warn": 1, "error": 2}
        if debug_mode not in name_to_level:
            raise RuntimeError(
                "invalid value of debug_mode, expected one of `default`, "
                f"`warn`, `error`, but got {debug_mode}"
            )
        debug_mode = name_to_level[debug_mode]

    if debug_mode == 0:
        _C._set_deterministic_algorithms(False)
    elif debug_mode == 1:
        _C._set_deterministic_algorithms(True, warn_only=True)
    elif debug_mode == 2:
        _C._set_deterministic_algorithms(True)
    else:
        raise RuntimeError(
            "invalid value of debug_mode, expected 0, 1, or 2, " f"but got {debug_mode}"
        )
|
| 1400 |
+
|
| 1401 |
+
def get_deterministic_debug_mode() -> builtins.int:
    r"""Return the current deterministic-operations debug mode (0, 1, or 2).

    See :func:`torch.set_deterministic_debug_mode` for the meaning of each
    level.
    """

    if not _C._get_deterministic_algorithms():
        return 0
    # Deterministic mode is on: distinguish warn-only (1) from error (2).
    return 1 if _C._get_deterministic_algorithms_warn_only() else 2
|
| 1415 |
+
|
| 1416 |
+
def get_float32_matmul_precision() -> str:
    r"""Return the active float32 matmul precision: "highest", "high", or "medium".

    See :func:`torch.set_float32_matmul_precision` for what each level means.
    """
    return _C._get_float32_matmul_precision()
|
| 1422 |
+
|
| 1423 |
+
def set_float32_matmul_precision(precision: str) -> None:
    r"""Set the internal precision of float32 matrix multiplications.

    Running float32 matmuls in lower precision may significantly increase
    performance, and in some programs the loss of precision has a negligible
    impact. Three settings are supported:

    * "highest" (default): use the float32 datatype (24 mantissa bits with
      23 bits explicitly stored) for internal computations.
    * "high": use the TensorFloat32 datatype (10 mantissa bits explicitly
      stored), or treat each float32 as the sum of two bfloat16 numbers
      (approximately 16 mantissa bits with 14 bits explicitly stored) when
      fast matmul kernels for that decomposition are available; otherwise
      compute as if the precision were "highest".
    * "medium": use the bfloat16 datatype (8 mantissa bits with 7 bits
      explicitly stored) internally when a fast matmul kernel using it is
      available; otherwise compute as if the precision were "high".

    The bfloat16 decomposition used by "high" works because a float32 number
    can be encoded exactly as the sum of three bfloat16 numbers, so a float32
    product equals the sum of nine bfloat16 products; "high" keeps only the
    three most significant of those products (excluding everything involving
    the last 8 mantissa bits of either input), so each input needs only two
    bfloat16 terms. Since bfloat16 fused-multiply-adds are typically >10x
    faster than float32 ones, this trade is a net win. See [Henry2019]_ for
    the complete description of the algorithm.

    .. [Henry2019] http://arxiv.org/abs/1904.06376

    .. note::

        This does not change the output dtype of float32 matrix
        multiplications — only how the internal computation is performed.

    .. note::

        This does not change the precision of convolution operations; other
        flags, like `torch.backends.cudnn.allow_tf32`, control those.

    .. note::

        This flag currently only affects one native device type: CUDA.
        "high" or "medium" are equivalent to setting
        `torch.backends.cuda.matmul.allow_tf32 = True`; "highest" (the
        default) is equivalent to setting it to `False`.

    Args:
        precision(str): can be set to "highest" (default), "high", or "medium" (see above).

    """
    _C._set_float32_matmul_precision(precision)
|
| 1488 |
+
|
| 1489 |
+
def set_warn_always(b: builtins.bool, /) -> None:
    r"""Control whether repeated PyTorch warnings are always emitted.

    By default (``False``) some PyTorch warnings may appear only once per
    process, to avoid excessive warning noise. Setting ``True`` forces every
    occurrence to be emitted, which can be helpful when debugging.

    Args:
        b (:class:`bool`): If True, force warnings to always be emitted
            If False, set to the default behaviour
    """
    _C._set_warnAlways(b)
|
| 1501 |
+
|
| 1502 |
+
def is_warn_always_enabled() -> builtins.bool:
    r"""Report whether the global warn_always flag is turned on.

    See :func:`torch.set_warn_always` for details.
    """
    return _C._get_warnAlways()
|
| 1508 |
+
|
| 1509 |
+
################################################################################
|
| 1510 |
+
# Define error checking functions
|
| 1511 |
+
################################################################################
|
| 1512 |
+
|
| 1513 |
+
# These error checking functions must be kept consistent with their C++
|
| 1514 |
+
# equivalents. Their C++ equivalents are mentioned where applicable.
|
| 1515 |
+
|
| 1516 |
+
|
| 1517 |
+
def _check_with(
|
| 1518 |
+
error_type,
|
| 1519 |
+
cond: _Union[builtins.bool, SymBool],
|
| 1520 |
+
message: _Callable[[], str],
|
| 1521 |
+
): # noqa: F811
|
| 1522 |
+
if not isinstance(cond, (builtins.bool, SymBool)):
|
| 1523 |
+
raise TypeError(f"cond must be a bool, but got {type(cond)}")
|
| 1524 |
+
|
| 1525 |
+
from torch.fx.experimental.symbolic_shapes import expect_true
|
| 1526 |
+
|
| 1527 |
+
if expect_true(cond):
|
| 1528 |
+
return
|
| 1529 |
+
|
| 1530 |
+
# error_type must be a subclass of Exception and not subclass of Warning
|
| 1531 |
+
assert issubclass(error_type, Exception) and not issubclass(error_type, Warning)
|
| 1532 |
+
|
| 1533 |
+
if message is None:
|
| 1534 |
+
message_evaluated = (
|
| 1535 |
+
"Expected cond to be True, but got False. (Could this error "
|
| 1536 |
+
"message be improved? If so, please report an enhancement request "
|
| 1537 |
+
"to PyTorch.)"
|
| 1538 |
+
)
|
| 1539 |
+
|
| 1540 |
+
else:
|
| 1541 |
+
if not callable(message):
|
| 1542 |
+
raise TypeError("message must be a callable")
|
| 1543 |
+
|
| 1544 |
+
message_evaluated = str(message())
|
| 1545 |
+
|
| 1546 |
+
raise error_type(message_evaluated)
|
| 1547 |
+
|
| 1548 |
+
|
| 1549 |
+
def _check(cond, message=None):  # noqa: F811
    r"""Raise a ``RuntimeError`` with an optional message when ``cond`` is False.

    C++ equivalent: ``TORCH_CHECK``

    Args:
        cond (:class:`bool`): If False, throw error

        message (Callable, optional): Callable returning either a string or an
            object with a ``__str__()`` method, used as the error message.
            Default: ``None``
    """
    _check_with(RuntimeError, cond, message)
|
| 1565 |
+
|
| 1566 |
+
|
| 1567 |
+
def _check_is_size(i, message=None):
    """Check that integer ``i`` is a valid size (i.e., non-negative).

    Prefer this over ``_check(i >= 0)``: the extra semantic knowledge that
    ``i`` is a size lets the symbolic-shapes machinery draw further
    inferences when ``i`` is an unbacked SymInt.

    NB: Do NOT use this where a -1 size would be legal (infer-from-context,
    wrap-around, or truncation semantics). Only use it when the sole valid
    value is an honest-to-goodness size.
    """
    # This is responsible for the expect_true side of the check.
    _check(i >= 0, message)
    from torch.fx.experimental.symbolic_shapes import _advise_is_size

    _advise_is_size(i)
|
| 1582 |
+
|
| 1583 |
+
|
| 1584 |
+
def _check_index(cond, message=None):  # noqa: F811
    r"""Raise an ``IndexError`` with an optional message when ``cond`` is False.

    C++ equivalent: ``TORCH_CHECK_INDEX``

    Args:
        cond (:class:`bool`): If False, throw error

        message (Callable, optional): Callable returning either a string or an
            object with a ``__str__()`` method, used as the error message.
            Default: ``None``
    """
    _check_with(IndexError, cond, message)
|
| 1600 |
+
|
| 1601 |
+
|
| 1602 |
+
def _check_value(cond, message=None):  # noqa: F811
    r"""Raise a ``ValueError`` with an optional message when ``cond`` is False.

    C++ equivalent: ``TORCH_CHECK_VALUE``

    Args:
        cond (:class:`bool`): If False, throw error

        message (Callable, optional): Callable returning either a string or an
            object with a ``__str__()`` method, used as the error message.
            Default: ``None``
    """
    _check_with(ValueError, cond, message)
|
| 1618 |
+
|
| 1619 |
+
|
| 1620 |
+
def _check_type(cond, message=None):  # noqa: F811
    r"""Raise a ``TypeError`` with an optional message when ``cond`` is False.

    C++ equivalent: ``TORCH_CHECK_TYPE``

    Args:
        cond (:class:`bool`): If False, throw error

        message (Callable, optional): Callable returning either a string or an
            object with a ``__str__()`` method, used as the error message.
            Default: ``None``
    """
    _check_with(TypeError, cond, message)
|
| 1636 |
+
|
| 1637 |
+
|
| 1638 |
+
def _check_not_implemented(cond, message=None):  # noqa: F811
    r"""Raise a ``NotImplementedError`` with an optional message when ``cond`` is False.

    C++ equivalent: ``TORCH_CHECK_NOT_IMPLEMENTED``

    Args:
        cond (:class:`bool`): If False, throw error

        message (Callable, optional): Callable returning either a string or an
            object with a ``__str__()`` method, used as the error message.
            Default: ``None``
    """
    _check_with(NotImplementedError, cond, message)
|
| 1654 |
+
|
| 1655 |
+
|
| 1656 |
+
def _check_tensor_all_with(error_type, cond, message=None):  # noqa: F811
    """Validate that ``cond`` is a ``torch.bool`` tensor whose elements are all True.

    Delegates to :func:`_check_with` after reducing the tensor with
    ``_is_all_true()``.
    """
    if not is_tensor(cond):
        raise TypeError(f"cond must be a tensor, but got {type(cond)}")

    if cond.dtype != torch.bool:
        raise TypeError(f"cond tensor must have dtype torch.bool, but got {cond.dtype}")

    _check_with(error_type, cond._is_all_true().item(), message)  # type: ignore[arg-type]
|
| 1664 |
+
|
| 1665 |
+
|
| 1666 |
+
# C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
|
| 1667 |
+
def _check_tensor_all(cond, message=None):  # noqa: F811
    r"""Raise a ``RuntimeError`` with an optional message unless ``cond`` is all True.

    C++ equivalent: ``TORCH_CHECK_TENSOR_ALL``

    Args:
        cond (:class:`torch.Tensor`): Tensor of dtype ``torch.bool``. If any
            element is ``False``, throw error

        message (Callable, optional): Callable returning either a string or an
            object with a ``__str__()`` method, used as the error message.
            Default: ``None``
    """
    _check_tensor_all_with(RuntimeError, cond, message)
|
| 1684 |
+
|
| 1685 |
+
|
| 1686 |
+
################################################################################
|
| 1687 |
+
# Define numeric constants
|
| 1688 |
+
################################################################################
|
| 1689 |
+
|
| 1690 |
+
# For Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and
|
| 1691 |
+
# NumPy consistency (https://numpy.org/devdocs/reference/constants.html)
|
| 1692 |
+
from math import e, inf, nan, pi
|
| 1693 |
+
|
| 1694 |
+
|
| 1695 |
+
newaxis: None = None
|
| 1696 |
+
|
| 1697 |
+
__all__.extend(["e", "pi", "nan", "inf", "newaxis"])
|
| 1698 |
+
|
| 1699 |
+
################################################################################
|
| 1700 |
+
# Define Storage and Tensor classes
|
| 1701 |
+
################################################################################
|
| 1702 |
+
|
| 1703 |
+
from torch._tensor import Tensor # usort: skip
|
| 1704 |
+
|
| 1705 |
+
# needs to be after torch.Tensor is defined to avoid circular dependencies
|
| 1706 |
+
from torch import storage as storage # usort: skip
|
| 1707 |
+
from torch.storage import (
|
| 1708 |
+
_LegacyStorage,
|
| 1709 |
+
_StorageBase,
|
| 1710 |
+
_warn_typed_storage_removal,
|
| 1711 |
+
TypedStorage,
|
| 1712 |
+
UntypedStorage,
|
| 1713 |
+
)
|
| 1714 |
+
|
| 1715 |
+
|
| 1716 |
+
# NOTE: New <type>Storage classes should never be added. When adding a new
|
| 1717 |
+
# dtype, use torch.storage.TypedStorage directly.
|
| 1718 |
+
class ByteStorage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.uint8`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.uint8
|
| 1727 |
+
|
| 1728 |
+
|
| 1729 |
+
class DoubleStorage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.double`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.double
|
| 1738 |
+
|
| 1739 |
+
|
| 1740 |
+
class FloatStorage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.float`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.float
|
| 1749 |
+
|
| 1750 |
+
|
| 1751 |
+
class HalfStorage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.half`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.half
|
| 1760 |
+
|
| 1761 |
+
|
| 1762 |
+
class LongStorage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.long`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.long
|
| 1771 |
+
|
| 1772 |
+
|
| 1773 |
+
class IntStorage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.int`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.int
|
| 1782 |
+
|
| 1783 |
+
|
| 1784 |
+
class ShortStorage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.short`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.short
|
| 1793 |
+
|
| 1794 |
+
|
| 1795 |
+
class CharStorage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.int8`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.int8
|
| 1804 |
+
|
| 1805 |
+
|
| 1806 |
+
class BoolStorage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.bool`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.bool
|
| 1815 |
+
|
| 1816 |
+
|
| 1817 |
+
class BFloat16Storage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.bfloat16`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.bfloat16
|
| 1826 |
+
|
| 1827 |
+
|
| 1828 |
+
class ComplexDoubleStorage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.cdouble`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.cdouble
|
| 1837 |
+
|
| 1838 |
+
|
| 1839 |
+
class ComplexFloatStorage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.cfloat`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.cfloat
|
| 1848 |
+
|
| 1849 |
+
|
| 1850 |
+
class QUInt8Storage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.quint8`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.quint8
|
| 1859 |
+
|
| 1860 |
+
|
| 1861 |
+
class QInt8Storage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.qint8`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.qint8
|
| 1870 |
+
|
| 1871 |
+
|
| 1872 |
+
class QInt32Storage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.qint32`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.qint32
|
| 1881 |
+
|
| 1882 |
+
|
| 1883 |
+
class QUInt4x2Storage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.quint4x2`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.quint4x2
|
| 1892 |
+
|
| 1893 |
+
|
| 1894 |
+
class QUInt2x4Storage(_LegacyStorage):
    """Deprecated legacy storage class for ``torch.quint2x4`` elements."""

    @classproperty
    def dtype(self):
        # Public accessor; emits the typed-storage deprecation warning.
        _warn_typed_storage_removal(stacklevel=3)
        return self._dtype

    @classproperty
    def _dtype(self):
        # Internal accessor that skips the deprecation warning.
        return torch.quint2x4
|
| 1903 |
+
|
| 1904 |
+
|
| 1905 |
+
# All storage classes exposed on the torch namespace; consumed by
# _C._init_names below to register their C-side names.
_storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
    UntypedStorage,
    DoubleStorage,
    FloatStorage,
    LongStorage,
    IntStorage,
    ShortStorage,
    CharStorage,
    ByteStorage,
    HalfStorage,
    BoolStorage,
    QUInt8Storage,
    QInt8Storage,
    QInt32Storage,
    BFloat16Storage,
    ComplexFloatStorage,
    ComplexDoubleStorage,
    QUInt4x2Storage,
    QUInt2x4Storage,
    TypedStorage,
}
|
| 1926 |
+
|
| 1927 |
+
# The _tensor_classes set is initialized by the call to initialize_python_bindings.
|
| 1928 |
+
_tensor_classes: _Set[_Type["torch.Tensor"]] = set()
|
| 1929 |
+
|
| 1930 |
+
# If you edit these imports, please update torch/__init__.py.in as well
|
| 1931 |
+
from torch import amp as amp, random as random, serialization as serialization
|
| 1932 |
+
from torch._tensor_str import set_printoptions
|
| 1933 |
+
from torch.amp import autocast, GradScaler
|
| 1934 |
+
from torch.random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state
|
| 1935 |
+
from torch.serialization import load, save
|
| 1936 |
+
|
| 1937 |
+
|
| 1938 |
+
################################################################################
|
| 1939 |
+
# Initialize extension
|
| 1940 |
+
################################################################################
|
| 1941 |
+
|
| 1942 |
+
|
| 1943 |
+
# Shared memory manager needs to know the exact location of manager executable
|
| 1944 |
+
def _manager_path():
    """Return the path of the torch_shm_manager binary as UTF-8 bytes.

    Returns ``b""`` when no shared-memory manager is used (torch.deploy or
    Windows). Raises ``RuntimeError`` when the binary is expected but missing.
    """
    if _running_with_deploy() or platform.system() == "Windows":
        return b""
    path = get_file_path("torch", "bin", "torch_shm_manager")
    # Side effect: sets up env vars the multiprocessing machinery reads.
    prepare_multiprocessing_environment(get_file_path("torch"))
    if not os.path.exists(path):
        raise RuntimeError("Unable to find torch_shm_manager at " + path)
    return path.encode("utf-8")
|
| 1952 |
+
|
| 1953 |
+
|
| 1954 |
+
_C._initExtension(_manager_path())
|
| 1955 |
+
|
| 1956 |
+
del _manager_path
|
| 1957 |
+
|
| 1958 |
+
# Appease the type checker: it can't deal with direct setting of globals().
|
| 1959 |
+
# Note that we will see "too many" functions when reexporting this way; there
|
| 1960 |
+
# is not a good way to fix this problem. Perhaps, try to redesign VariableFunctions
|
| 1961 |
+
# so that this import is good enough
|
| 1962 |
+
if TYPE_CHECKING:
|
| 1963 |
+
# Some type signatures pulled in from _VariableFunctions here clash with
|
| 1964 |
+
# signatures already imported. For now these clashes are ignored; see
|
| 1965 |
+
# PR #43339 for details.
|
| 1966 |
+
from torch._C._VariableFunctions import * # type: ignore[assignment, misc] # noqa: F403
|
| 1967 |
+
|
| 1968 |
+
# Fixup segment_reduce visibility
|
| 1969 |
+
_segment_reduce = segment_reduce
|
| 1970 |
+
del segment_reduce # noqa: F821
|
| 1971 |
+
|
| 1972 |
+
# Ops not to be exposed in `torch` namespace,
|
| 1973 |
+
# mostly helper ops.
|
| 1974 |
+
PRIVATE_OPS = ("unique_dim",)
|
| 1975 |
+
|
| 1976 |
+
# Re-export every public op from _C._VariableFunctions into the torch
# namespace, adding non-underscore names to __all__.
__name, __obj = "", None
for __name in dir(_C._VariableFunctions):
    # Skip dunders and deliberately-private helper ops.
    if __name.startswith("__") or __name in PRIVATE_OPS:
        continue
    __obj = getattr(_C._VariableFunctions, __name)
    __obj.__module__ = __name__  # "torch"
    # Hide some APIs that should not be public
    if __name == "segment_reduce":
        # TODO: Once the undocumented FC window is passed, remove the line below
        globals()[__name] = __obj
        __name = "_" + __name
    globals()[__name] = __obj
    if not __name.startswith("_"):
        __all__.append(__name)

del __name, __obj
|
| 1992 |
+
|
| 1993 |
+
################################################################################
|
| 1994 |
+
# Add torch.dtype instances to the public API
|
| 1995 |
+
################################################################################
|
| 1996 |
+
|
| 1997 |
+
import torch
|
| 1998 |
+
|
| 1999 |
+
|
| 2000 |
+
__all__.extend(
|
| 2001 |
+
name for name in dir(torch) if isinstance(getattr(torch, name), torch.dtype)
|
| 2002 |
+
)
|
| 2003 |
+
|
| 2004 |
+
################################################################################
|
| 2005 |
+
# Import TorchDynamo's lazy APIs to avoid circular dependenices
|
| 2006 |
+
################################################################################
|
| 2007 |
+
|
| 2008 |
+
# needs to be before from torch.functional import * to avoid circular dependencies
|
| 2009 |
+
from torch._compile import _disable_dynamo # usort: skip
|
| 2010 |
+
|
| 2011 |
+
################################################################################
|
| 2012 |
+
# Import interface functions defined in Python
|
| 2013 |
+
################################################################################
|
| 2014 |
+
|
| 2015 |
+
# needs to be after the above ATen bindings so we can overwrite from Python side
|
| 2016 |
+
from torch import _VF as _VF, functional as functional # usort: skip
|
| 2017 |
+
from torch.functional import * # usort: skip # noqa: F403
|
| 2018 |
+
|
| 2019 |
+
################################################################################
|
| 2020 |
+
# Remove unnecessary members
|
| 2021 |
+
################################################################################
|
| 2022 |
+
|
| 2023 |
+
del _StorageBase
|
| 2024 |
+
del _LegacyStorage
|
| 2025 |
+
|
| 2026 |
+
################################################################################
|
| 2027 |
+
# Define _assert
|
| 2028 |
+
################################################################################
|
| 2029 |
+
|
| 2030 |
+
|
| 2031 |
+
# needs to be before the submodule imports to avoid circular dependencies
|
| 2032 |
+
def _assert(condition, message):
|
| 2033 |
+
r"""A wrapper around Python's assert which is symbolically traceable."""
|
| 2034 |
+
if type(condition) is not torch.Tensor and overrides.has_torch_function(
|
| 2035 |
+
(condition,)
|
| 2036 |
+
):
|
| 2037 |
+
return overrides.handle_torch_function(
|
| 2038 |
+
_assert, (condition,), condition, message
|
| 2039 |
+
)
|
| 2040 |
+
assert condition, message
|
| 2041 |
+
|
| 2042 |
+
|
| 2043 |
+
################################################################################
|
| 2044 |
+
# Import most common subpackages
|
| 2045 |
+
################################################################################
|
| 2046 |
+
|
| 2047 |
+
# Use the redundant form so that type checkers know that these are a part of
|
| 2048 |
+
# the public API. The "regular" import lines are there solely for the runtime
|
| 2049 |
+
# side effect of adding to the imported module's members for other users.
|
| 2050 |
+
|
| 2051 |
+
# needs to be before import torch.nn as nn to avoid circular dependencies
|
| 2052 |
+
from torch.autograd import ( # usort: skip
|
| 2053 |
+
enable_grad as enable_grad,
|
| 2054 |
+
inference_mode as inference_mode,
|
| 2055 |
+
no_grad as no_grad,
|
| 2056 |
+
set_grad_enabled as set_grad_enabled,
|
| 2057 |
+
)
|
| 2058 |
+
|
| 2059 |
+
from torch import (
|
| 2060 |
+
__config__ as __config__,
|
| 2061 |
+
__future__ as __future__,
|
| 2062 |
+
_awaits as _awaits,
|
| 2063 |
+
autograd as autograd,
|
| 2064 |
+
backends as backends,
|
| 2065 |
+
cpu as cpu,
|
| 2066 |
+
cuda as cuda,
|
| 2067 |
+
distributed as distributed,
|
| 2068 |
+
distributions as distributions,
|
| 2069 |
+
fft as fft,
|
| 2070 |
+
futures as futures,
|
| 2071 |
+
hub as hub,
|
| 2072 |
+
jit as jit,
|
| 2073 |
+
linalg as linalg,
|
| 2074 |
+
mps as mps,
|
| 2075 |
+
mtia as mtia,
|
| 2076 |
+
multiprocessing as multiprocessing,
|
| 2077 |
+
nested as nested,
|
| 2078 |
+
nn as nn,
|
| 2079 |
+
optim as optim,
|
| 2080 |
+
overrides as overrides,
|
| 2081 |
+
profiler as profiler,
|
| 2082 |
+
sparse as sparse,
|
| 2083 |
+
special as special,
|
| 2084 |
+
testing as testing,
|
| 2085 |
+
types as types,
|
| 2086 |
+
utils as utils,
|
| 2087 |
+
xpu as xpu,
|
| 2088 |
+
)
|
| 2089 |
+
from torch.signal import windows as windows
|
| 2090 |
+
|
| 2091 |
+
|
| 2092 |
+
# Quantized, sparse, AO, etc. should be last to get imported, as nothing
|
| 2093 |
+
# is expected to depend on them.
|
| 2094 |
+
from torch import ao as ao # usort: skip
|
| 2095 |
+
|
| 2096 |
+
# nn.quant* depends on ao -- so should be after those.
|
| 2097 |
+
import torch.nn.intrinsic
|
| 2098 |
+
import torch.nn.qat
|
| 2099 |
+
import torch.nn.quantizable
|
| 2100 |
+
import torch.nn.quantized
|
| 2101 |
+
|
| 2102 |
+
|
| 2103 |
+
_C._init_names(list(_storage_classes))
|
| 2104 |
+
|
| 2105 |
+
# attach docstrings to torch and tensor functions
|
| 2106 |
+
from torch import _size_docs, _storage_docs, _tensor_docs, _torch_docs
|
| 2107 |
+
|
| 2108 |
+
|
| 2109 |
+
del _torch_docs, _tensor_docs, _storage_docs, _size_docs
|
| 2110 |
+
|
| 2111 |
+
|
| 2112 |
+
def compiled_with_cxx11_abi() -> builtins.bool:
    r"""Return whether PyTorch was built with ``_GLIBCXX_USE_CXX11_ABI=1``."""
    return _C._GLIBCXX_USE_CXX11_ABI
|
| 2115 |
+
|
| 2116 |
+
|
| 2117 |
+
from torch import _library as _library, _ops as _ops
|
| 2118 |
+
|
| 2119 |
+
|
| 2120 |
+
# Import the ops and classes "namespace"
|
| 2121 |
+
from torch._ops import ops as ops # usort: skip
|
| 2122 |
+
from torch._classes import classes as classes # usort: skip
|
| 2123 |
+
|
| 2124 |
+
sys.modules.setdefault(f"{__name__}.ops", ops)
|
| 2125 |
+
sys.modules.setdefault(f"{__name__}.classes", classes)
|
| 2126 |
+
|
| 2127 |
+
# quantization depends on torch.fx and torch.ops
|
| 2128 |
+
# Import quantization
|
| 2129 |
+
from torch import quantization as quantization # usort: skip
|
| 2130 |
+
|
| 2131 |
+
# Import the quasi random sampler
|
| 2132 |
+
from torch import quasirandom as quasirandom # usort: skip
|
| 2133 |
+
|
| 2134 |
+
# If you are seeing this, it means that this call site was not checked if
|
| 2135 |
+
# the memory format could be preserved, and it was switched to old default
|
| 2136 |
+
# behaviour of contiguous
|
| 2137 |
+
legacy_contiguous_format = contiguous_format # defined by _C._initExtension()
|
| 2138 |
+
|
| 2139 |
+
# Register fork handler to initialize OpenMP in child processes (see gh-28389)
|
| 2140 |
+
from torch.multiprocessing._atfork import register_after_fork
|
| 2141 |
+
|
| 2142 |
+
|
| 2143 |
+
register_after_fork(torch.get_num_threads)
|
| 2144 |
+
del register_after_fork
|
| 2145 |
+
|
| 2146 |
+
# Import tools that require fully imported torch (for applying
|
| 2147 |
+
# torch.jit.script as a decorator, for instance):
|
| 2148 |
+
from torch._lobpcg import lobpcg as lobpcg
|
| 2149 |
+
|
| 2150 |
+
|
| 2151 |
+
# These were previously defined in native_functions.yaml and appeared on the
|
| 2152 |
+
# `torch` namespace, but we moved them to c10 dispatch to facilitate custom
|
| 2153 |
+
# class usage. We add these lines here to preserve backward compatibility.
|
| 2154 |
+
quantized_lstm = ops.aten.quantized_lstm
|
| 2155 |
+
quantized_gru = ops.aten.quantized_gru
|
| 2156 |
+
|
| 2157 |
+
# Import experimental masked operations support. See
|
| 2158 |
+
# [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more
|
| 2159 |
+
# information.
|
| 2160 |
+
from torch import masked as masked
|
| 2161 |
+
|
| 2162 |
+
# Import removed ops with error message about removal
|
| 2163 |
+
from torch._linalg_utils import ( # type: ignore[misc]
|
| 2164 |
+
_symeig as symeig,
|
| 2165 |
+
eig,
|
| 2166 |
+
lstsq,
|
| 2167 |
+
matrix_rank,
|
| 2168 |
+
solve,
|
| 2169 |
+
)
|
| 2170 |
+
from torch.utils.dlpack import from_dlpack, to_dlpack
|
| 2171 |
+
|
| 2172 |
+
|
| 2173 |
+
class _TorchCompileInductorWrapper:
    """Backend wrapper used by ``torch.compile`` for the ``inductor`` backend.

    Collects config patches from ``mode``/``options`` into ``self.config`` and
    applies them when the wrapper is called as a compiler function.
    """

    compiler_name = "inductor"

    def __init__(self, mode, options, dynamic):
        # Accumulated inductor config patches (attr name -> value).
        self.config: _Dict[str, _Any] = {}
        self.dynamic = dynamic
        self.apply_mode(mode)
        self.apply_options(options)

        if self.config.get("triton.cudagraphs", False):
            os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
            # FIXME: CUDA Graph does not work well with CUPTI teardown.
            # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)
            # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12)
            # Workaround: turn off CUPTI teardown when using CUDA Graphs.
            os.environ["TEARDOWN_CUPTI"] = "0"

    def __eq__(self, other):
        # Two wrappers are interchangeable iff they produce the same config.
        return (
            isinstance(other, _TorchCompileInductorWrapper)
            and self.config == other.config
            and self.dynamic == other.dynamic
        )

    def apply_mode(self, mode: _Optional[str]):
        """Translate a named compile mode into config options."""
        if mode is None or mode == "default":
            pass
        elif mode in {"reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"}:
            from torch._inductor import list_mode_options

            self.apply_options(list_mode_options(mode, self.dynamic))
        else:
            raise RuntimeError(
                f"Unrecognized mode={mode}, should be one of: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs"
            )

    def apply_options(self, options: _Optional[_Dict[str, _Any]]):
        """Validate user options against inductor's config and record them."""
        if not options:
            return

        from torch._inductor import config

        current_config: _Dict[str, _Any] = config.shallow_copy_dict()

        for key, val in options.items():
            # Option keys may use dashes; config attributes use underscores.
            attr_name = key.replace("-", "_")
            if attr_name not in current_config:
                raise RuntimeError(
                    f"Unexpected optimization option {key}, known options are {list(current_config.keys())}"
                )
            if type(val) is not type(current_config[attr_name]):
                val_type_str = type(val).__name__
                expected_type_str = type(current_config[attr_name]).__name__
                raise RuntimeError(
                    f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}"
                )
            self.config[attr_name] = val

    def __call__(self, model_, inputs_):
        # Invoked by dynamo as the backend compiler function.
        from torch._inductor.compile_fx import compile_fx

        return compile_fx(model_, inputs_, config_patches=self.config)

    def get_compiler_config(self):
        from torch._inductor.compile_fx import get_patched_config_dict

        return get_patched_config_dict(config_patches=self.config)

    def reset(self):
        # Clear cudagraph trees when cudagraphs were (or may have been) enabled.
        from torch._inductor import config

        if "triton.cudagraphs" in self.config or config.triton.cudagraphs:
            if self.config.get("triton.cudagraphs", True):
                from torch._inductor.cudagraph_trees import reset_cudagraph_trees

                reset_cudagraph_trees()
|
| 2249 |
+
|
| 2250 |
+
|
| 2251 |
+
class _TorchCompileWrapper:
    """Backend wrapper used by ``torch.compile`` for non-inductor backends.

    Resolves ``backend`` (a registry name or a callable) to a compiler
    function and forwards ``mode``/``options`` as keyword arguments.
    """

    def __init__(self, backend, mode, options, dynamic):
        from torch._dynamo.backends.registry import lookup_backend

        # Derive a printable name for the backend.
        if isinstance(backend, str):
            self.compiler_name = backend
        elif hasattr(backend, "__name__"):
            self.compiler_name = backend.__name__
        else:
            self.compiler_name = str(backend)
        self.dynamic = dynamic
        self.compiler_fn = lookup_backend(backend)
        self.kwargs = {}
        # only pass the args if they are non-empty
        if mode and mode != "default":
            self.kwargs["mode"] = mode
        if options:
            self.kwargs["options"] = options

    def __eq__(self, other):
        # Equal when the same compiler function would be called identically.
        return (
            isinstance(other, _TorchCompileWrapper)
            and self.compiler_fn == other.compiler_fn
            and self.kwargs == other.kwargs
            and self.dynamic == other.dynamic
        )

    def __call__(self, model_, inputs_):
        # Invoked by dynamo as the backend compiler function.
        return self.compiler_fn(model_, inputs_, **self.kwargs)

    def reset(self):
        # Backends may optionally expose a reset hook.
        if hasattr(self.compiler_fn, "reset"):
            self.compiler_fn.reset()
|
| 2284 |
+
|
| 2285 |
+
|
| 2286 |
+
_InputT = _ParamSpec("_InputT")
|
| 2287 |
+
_RetT = _TypeVar("_RetT")
|
| 2288 |
+
|
| 2289 |
+
|
| 2290 |
+
# Overload: called directly with a model/function -> returns the compiled callable.
@_overload
def compile(
    model: _Callable[_InputT, _RetT],
    *,
    fullgraph: builtins.bool = False,
    dynamic: _Optional[builtins.bool] = None,
    backend: _Union[str, _Callable] = "inductor",
    mode: _Union[str, None] = None,
    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
    disable: builtins.bool = False,
) -> _Callable[_InputT, _RetT]: ...
|
| 2301 |
+
|
| 2302 |
+
|
| 2303 |
+
# Overload: called with model=None (decorator form) -> returns a decorator
# that compiles the function it is applied to.
@_overload
def compile(
    model: None = None,
    *,
    fullgraph: builtins.bool = False,
    dynamic: _Optional[builtins.bool] = None,
    backend: _Union[str, _Callable] = "inductor",
    mode: _Union[str, None] = None,
    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
    disable: builtins.bool = False,
) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...
|
| 2314 |
+
|
| 2315 |
+
|
| 2316 |
+
def compile(
|
| 2317 |
+
model: _Optional[_Callable] = None,
|
| 2318 |
+
*,
|
| 2319 |
+
fullgraph: builtins.bool = False,
|
| 2320 |
+
dynamic: _Optional[builtins.bool] = None,
|
| 2321 |
+
backend: _Union[str, _Callable] = "inductor",
|
| 2322 |
+
mode: _Union[str, None] = None,
|
| 2323 |
+
options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
|
| 2324 |
+
disable: builtins.bool = False,
|
| 2325 |
+
) -> _Union[
|
| 2326 |
+
_Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],
|
| 2327 |
+
_Callable[_InputT, _RetT],
|
| 2328 |
+
]:
|
| 2329 |
+
"""
|
| 2330 |
+
Optimizes given model/function using TorchDynamo and specified backend.
|
| 2331 |
+
If you are compiling an :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile`
|
| 2332 |
+
to compile the module inplace without changing its structure.
|
| 2333 |
+
|
| 2334 |
+
Concretely, for every frame executed within the compiled region, we will attempt
|
| 2335 |
+
to compile it and cache the compiled result on the code object for future
|
| 2336 |
+
use. A single frame may be compiled multiple times if previous compiled
|
| 2337 |
+
results are not applicable for subsequent calls (this is called a "guard
|
| 2338 |
+
failure), you can use TORCH_LOGS=guards to debug these situations.
|
| 2339 |
+
Multiple compiled results can be associated with a frame up to
|
| 2340 |
+
``torch._dynamo.config.cache_size_limit``, which defaults to 8; at which
|
| 2341 |
+
point we will fall back to eager. Note that compile caches are per
|
| 2342 |
+
*code object*, not frame; if you dynamically create multiple copies of a
|
| 2343 |
+
function, they will all share the same code cache.
|
| 2344 |
+
|
| 2345 |
+
Args:
|
| 2346 |
+
model (Callable): Module/function to optimize
|
| 2347 |
+
fullgraph (bool): If False (default), torch.compile attempts to discover compileable regions
|
| 2348 |
+
in the function that it will optimize. If True, then we require that the entire function be
|
| 2349 |
+
capturable into a single graph. If this is not possible (that is, if there are graph breaks),
|
| 2350 |
+
then this will raise an error.
|
| 2351 |
+
dynamic (bool or None): Use dynamic shape tracing. When this is True, we will up-front attempt
|
| 2352 |
+
to generate a kernel that is as dynamic as possible to avoid recompilations when
|
| 2353 |
+
sizes change. This may not always work as some operations/optimizations will
|
| 2354 |
+
force specialization; use TORCH_LOGS=dynamic to debug overspecialization.
|
| 2355 |
+
When this is False, we will NEVER generate dynamic kernels, we will always specialize.
|
| 2356 |
+
By default (None), we automatically detect if dynamism has occurred and compile a more
|
| 2357 |
+
dynamic kernel upon recompile.
|
| 2358 |
+
backend (str or Callable): backend to be used
|
| 2359 |
+
|
| 2360 |
+
- "inductor" is the default backend, which is a good balance between performance and overhead
|
| 2361 |
+
|
| 2362 |
+
- Non experimental in-tree backends can be seen with `torch._dynamo.list_backends()`
|
| 2363 |
+
|
| 2364 |
+
- Experimental or debug in-tree backends can be seen with `torch._dynamo.list_backends(None)`
|
| 2365 |
+
|
| 2366 |
+
- To register an out-of-tree custom backend:
|
| 2367 |
+
https://pytorch.org/docs/main/torch.compiler_custom_backends.html#registering-custom-backends
|
| 2368 |
+
mode (str): Can be either "default", "reduce-overhead", "max-autotune" or "max-autotune-no-cudagraphs"
|
| 2369 |
+
|
| 2370 |
+
- "default" is the default mode, which is a good balance between performance and overhead
|
| 2371 |
+
|
| 2372 |
+
- "reduce-overhead" is a mode that reduces the overhead of python with CUDA graphs,
|
| 2373 |
+
useful for small batches. Reduction of overhead can come at the cost of more memory
|
| 2374 |
+
usage, as we will cache the workspace memory required for the invocation so that we
|
| 2375 |
+
do not have to reallocate it on subsequent runs. Reduction of overhead is not guaranteed
|
| 2376 |
+
to work; today, we only reduce overhead for CUDA only graphs which do not mutate inputs.
|
| 2377 |
+
There are other circumstances where CUDA graphs are not applicable; use TORCH_LOG=perf_hints
|
| 2378 |
+
to debug.
|
| 2379 |
+
|
| 2380 |
+
- "max-autotune" is a mode that leverages Triton or template based matrix multiplications
|
| 2381 |
+
on supported devices and Triton based convolutions on GPU.
|
| 2382 |
+
It enables CUDA graphs by default on GPU.
|
| 2383 |
+
|
| 2384 |
+
- "max-autotune-no-cudagraphs" is a mode similar to "max-autotune" but without CUDA graphs
|
| 2385 |
+
|
| 2386 |
+
- To see the exact configs that each mode sets you can call `torch._inductor.list_mode_options()`
|
| 2387 |
+
|
| 2388 |
+
options (dict): A dictionary of options to pass to the backend. Some notable ones to try out are
|
| 2389 |
+
|
| 2390 |
+
- `epilogue_fusion` which fuses pointwise ops into templates. Requires `max_autotune` to also be set
|
| 2391 |
+
|
| 2392 |
+
- `max_autotune` which will profile to pick the best matmul configuration
|
| 2393 |
+
|
| 2394 |
+
- `fallback_random` which is useful when debugging accuracy issues
|
| 2395 |
+
|
| 2396 |
+
- `shape_padding` which pads matrix shapes to better align loads on GPUs especially for tensor cores
|
| 2397 |
+
|
| 2398 |
+
- `triton.cudagraphs` which will reduce the overhead of python with CUDA graphs
|
| 2399 |
+
|
| 2400 |
+
- `trace.enabled` which is the most useful debugging flag to turn on
|
| 2401 |
+
|
| 2402 |
+
- `trace.graph_diagram` which will show you a picture of your graph after fusion
|
| 2403 |
+
|
| 2404 |
+
- For inductor you can see the full list of configs that it supports by calling `torch._inductor.list_options()`
|
| 2405 |
+
disable (bool): Turn torch.compile() into a no-op for testing
|
| 2406 |
+
|
| 2407 |
+
Example::
|
| 2408 |
+
|
| 2409 |
+
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
|
| 2410 |
+
def foo(x):
|
| 2411 |
+
return torch.sin(x) + torch.cos(x)
|
| 2412 |
+
|
| 2413 |
+
"""
|
| 2414 |
+
_C._log_api_usage_once("torch.compile")
|
| 2415 |
+
if sys.version_info >= (3, 13):
|
| 2416 |
+
raise RuntimeError("Dynamo is not supported on Python 3.13+")
|
| 2417 |
+
|
| 2418 |
+
# Decorator mode
|
| 2419 |
+
if model is None:
|
| 2420 |
+
|
| 2421 |
+
def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]:
|
| 2422 |
+
if model is None:
|
| 2423 |
+
raise RuntimeError("Model can't be None")
|
| 2424 |
+
return compile(
|
| 2425 |
+
model,
|
| 2426 |
+
fullgraph=fullgraph,
|
| 2427 |
+
dynamic=dynamic,
|
| 2428 |
+
backend=backend,
|
| 2429 |
+
mode=mode,
|
| 2430 |
+
options=options,
|
| 2431 |
+
disable=disable,
|
| 2432 |
+
)
|
| 2433 |
+
|
| 2434 |
+
return fn
|
| 2435 |
+
|
| 2436 |
+
if mode is not None and options is not None:
|
| 2437 |
+
raise RuntimeError(
|
| 2438 |
+
"Either mode or options can be specified, but both can't be specified at the same time."
|
| 2439 |
+
)
|
| 2440 |
+
if mode is None and options is None:
|
| 2441 |
+
mode = "default"
|
| 2442 |
+
if backend == "inductor":
|
| 2443 |
+
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
|
| 2444 |
+
else:
|
| 2445 |
+
backend = _TorchCompileWrapper(backend, mode, options, dynamic)
|
| 2446 |
+
|
| 2447 |
+
return torch._dynamo.optimize(
|
| 2448 |
+
backend=backend,
|
| 2449 |
+
nopython=fullgraph,
|
| 2450 |
+
dynamic=dynamic,
|
| 2451 |
+
disable=disable,
|
| 2452 |
+
)(model) # type: ignore[return-value]
|
| 2453 |
+
|
| 2454 |
+
|
| 2455 |
+
def _register_device_module(device_type, module):
|
| 2456 |
+
r"""Register an external runtime module of the specific :attr:`device_type`
|
| 2457 |
+
supported by torch.
|
| 2458 |
+
|
| 2459 |
+
After the :attr:`module` is registered correctly, the user can refer
|
| 2460 |
+
the external runtime module as part of torch with attribute torch.xxx.
|
| 2461 |
+
"""
|
| 2462 |
+
# Make sure the device_type represent a supported device type for torch.
|
| 2463 |
+
device_type = torch.device(device_type).type
|
| 2464 |
+
m = sys.modules[__name__]
|
| 2465 |
+
if hasattr(m, device_type):
|
| 2466 |
+
raise RuntimeError(
|
| 2467 |
+
f"The runtime module of '{device_type}' has already "
|
| 2468 |
+
f"been registered with '{getattr(m, device_type)}'"
|
| 2469 |
+
)
|
| 2470 |
+
setattr(m, device_type, module)
|
| 2471 |
+
torch_module_name = ".".join([__name__, device_type])
|
| 2472 |
+
sys.modules[torch_module_name] = module
|
| 2473 |
+
|
| 2474 |
+
|
| 2475 |
+
from torch import (
|
| 2476 |
+
export as export,
|
| 2477 |
+
func as func,
|
| 2478 |
+
library as library,
|
| 2479 |
+
return_types as return_types,
|
| 2480 |
+
)
|
| 2481 |
+
from torch._higher_order_ops import cond as cond, while_loop as while_loop
|
| 2482 |
+
from torch.func import vmap as vmap
|
| 2483 |
+
|
| 2484 |
+
|
| 2485 |
+
if not TYPE_CHECKING:
|
| 2486 |
+
from torch import _meta_registrations
|
| 2487 |
+
|
| 2488 |
+
# Enable CUDA Sanitizer
|
| 2489 |
+
if "TORCH_CUDA_SANITIZER" in os.environ:
|
| 2490 |
+
import torch.cuda._sanitizer as csan
|
| 2491 |
+
|
| 2492 |
+
csan.enable_cuda_sanitizer()
|
| 2493 |
+
|
| 2494 |
+
# Populate magic methods on SymInt and SymFloat
|
| 2495 |
+
import torch.fx.experimental.sym_node
|
| 2496 |
+
|
| 2497 |
+
|
| 2498 |
+
# Register MPS specific decomps
|
| 2499 |
+
torch.backends.mps._init()
|
| 2500 |
+
|
| 2501 |
+
if not _running_with_deploy():
|
| 2502 |
+
from torch import compiler as compiler
|
| 2503 |
+
|
| 2504 |
+
class _TritonLibrary:
|
| 2505 |
+
lib = torch.library.Library("triton", "DEF")
|
| 2506 |
+
ops_table: _Dict[_Tuple[str, str], _Callable] = {}
|
| 2507 |
+
|
| 2508 |
+
@classmethod
|
| 2509 |
+
def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
|
| 2510 |
+
if (op_key, dispatch_key) not in cls.ops_table:
|
| 2511 |
+
cls.lib.define(full_schema)
|
| 2512 |
+
cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
|
| 2513 |
+
cls.ops_table[(op_key, dispatch_key)] = op_impl
|
| 2514 |
+
|
| 2515 |
+
return cls.ops_table[(op_key, dispatch_key)]
|
| 2516 |
+
|
| 2517 |
+
|
| 2518 |
+
# Deprecated attributes
|
| 2519 |
+
_deprecated_attrs = {
|
| 2520 |
+
"has_mps": torch.backends.mps.is_built,
|
| 2521 |
+
"has_cuda": torch.backends.cuda.is_built,
|
| 2522 |
+
"has_cudnn": torch.backends.cudnn.is_available,
|
| 2523 |
+
"has_mkldnn": torch.backends.mkldnn.is_available,
|
| 2524 |
+
}
|
| 2525 |
+
|
| 2526 |
+
if TYPE_CHECKING:
|
| 2527 |
+
# Import the following modules during type checking to enable code intelligence features,
|
| 2528 |
+
# such as auto-completion in tools like pylance, even when these modules are not explicitly
|
| 2529 |
+
# imported in user code.
|
| 2530 |
+
from torch import (
|
| 2531 |
+
_dynamo as _dynamo,
|
| 2532 |
+
_inductor as _inductor,
|
| 2533 |
+
_subclasses as _subclasses,
|
| 2534 |
+
onnx as onnx,
|
| 2535 |
+
)
|
| 2536 |
+
|
| 2537 |
+
else:
|
| 2538 |
+
_lazy_modules = {
|
| 2539 |
+
"_dynamo",
|
| 2540 |
+
"_inductor",
|
| 2541 |
+
"_export",
|
| 2542 |
+
# ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit
|
| 2543 |
+
"onnx",
|
| 2544 |
+
}
|
| 2545 |
+
|
| 2546 |
+
def __getattr__(name):
|
| 2547 |
+
# Deprecated attrs
|
| 2548 |
+
replacement = _deprecated_attrs.get(name)
|
| 2549 |
+
if replacement is not None:
|
| 2550 |
+
import warnings
|
| 2551 |
+
|
| 2552 |
+
warnings.warn(
|
| 2553 |
+
f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'",
|
| 2554 |
+
stacklevel=2,
|
| 2555 |
+
)
|
| 2556 |
+
return replacement()
|
| 2557 |
+
|
| 2558 |
+
# Lazy modules
|
| 2559 |
+
if name in _lazy_modules:
|
| 2560 |
+
return importlib.import_module(f".{name}", __name__)
|
| 2561 |
+
|
| 2562 |
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
| 2563 |
+
|
| 2564 |
+
|
| 2565 |
+
def get_device_module(device: _Optional[_Union[torch.device, str]] = None):
|
| 2566 |
+
"""
|
| 2567 |
+
Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...).
|
| 2568 |
+
If no device is given, return the module for the current accelerator or CPU if none is present.
|
| 2569 |
+
"""
|
| 2570 |
+
if isinstance(device, torch.device):
|
| 2571 |
+
device_module_name = device.type
|
| 2572 |
+
elif isinstance(device, str):
|
| 2573 |
+
device_module_name = torch.device(device).type
|
| 2574 |
+
elif device is None:
|
| 2575 |
+
# Using default accelerator type. If no accelerator is available, it automatically returns CPU device.
|
| 2576 |
+
device_module_name = torch._C._get_accelerator().type
|
| 2577 |
+
else:
|
| 2578 |
+
raise RuntimeError(
|
| 2579 |
+
f"Invalid value of device '{device}', expect torch.device, str, or None"
|
| 2580 |
+
)
|
| 2581 |
+
device_module = getattr(torch, device_module_name, None)
|
| 2582 |
+
if device_module is None:
|
| 2583 |
+
raise RuntimeError(
|
| 2584 |
+
f"Device '{device_module_name}' does not have a corresponding module registered as 'torch.{device_module_name}'."
|
| 2585 |
+
)
|
| 2586 |
+
return device_module
|
| 2587 |
+
|
| 2588 |
+
|
| 2589 |
+
def _constrain_as_size(
|
| 2590 |
+
symbol,
|
| 2591 |
+
min: _Optional[builtins.int] = None,
|
| 2592 |
+
max: _Optional[builtins.int] = None,
|
| 2593 |
+
):
|
| 2594 |
+
"""
|
| 2595 |
+
This indicates that a given int is size-like, and can be used in any context where a size is expected.
|
| 2596 |
+
You will typically use this when reading out integers from Tensors, e.g., max.item() or lengths.tolist()
|
| 2597 |
+
which then need to be used as tensor constructors. Providing these assertions to PyTorch can help resolve
|
| 2598 |
+
GuardOnDataDependentSymNode errors upon export, since we cannot guard on unbacked SymInts.
|
| 2599 |
+
|
| 2600 |
+
This function has unusual semantics in some circumstances in framework
|
| 2601 |
+
code, we will treat this int as >= 2 (when we do a size-oblivious guard).
|
| 2602 |
+
This makes it easier to use the unbacked int in size contexts,
|
| 2603 |
+
as we will often attempt to guard on a size being zero/one
|
| 2604 |
+
(e.g., when computing the contiguity of a tensor, or testing if
|
| 2605 |
+
broadcasting can occur), which will not work on unbacked SymInts.
|
| 2606 |
+
However, if we conservatively assume that the size is not zero/one, we will
|
| 2607 |
+
end up with a graph that will still work even if the size is zero/one.
|
| 2608 |
+
|
| 2609 |
+
For more details, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit
|
| 2610 |
+
```
|
| 2611 |
+
"""
|
| 2612 |
+
torch.sym_constrain_range_for_size(symbol, min=min, max=max)
|
| 2613 |
+
|
| 2614 |
+
|
| 2615 |
+
from torch import _logging
|
| 2616 |
+
|
| 2617 |
+
|
| 2618 |
+
_logging._init_logs()
|
| 2619 |
+
|
| 2620 |
+
|
| 2621 |
+
def _import_device_backends():
|
| 2622 |
+
"""
|
| 2623 |
+
Leverage the Python plugin mechanism to load out-of-the-tree device extensions.
|
| 2624 |
+
See this RFC: https://github.com/pytorch/pytorch/issues/122468
|
| 2625 |
+
"""
|
| 2626 |
+
from importlib.metadata import entry_points
|
| 2627 |
+
|
| 2628 |
+
group_name = "torch.backends"
|
| 2629 |
+
if sys.version_info < (3, 10):
|
| 2630 |
+
backend_extensions = entry_points().get(group_name, ())
|
| 2631 |
+
else:
|
| 2632 |
+
backend_extensions = entry_points(group=group_name)
|
| 2633 |
+
|
| 2634 |
+
for backend_extension in backend_extensions:
|
| 2635 |
+
try:
|
| 2636 |
+
# Load the extension
|
| 2637 |
+
entrypoint = backend_extension.load()
|
| 2638 |
+
# Call the entrypoint
|
| 2639 |
+
entrypoint()
|
| 2640 |
+
except Exception as err:
|
| 2641 |
+
raise RuntimeError(
|
| 2642 |
+
f"Failed to load the backend extension: {backend_extension.name}. "
|
| 2643 |
+
f"You can disable extension auto-loading with TORCH_DEVICE_BACKEND_AUTOLOAD=0."
|
| 2644 |
+
) from err
|
| 2645 |
+
|
| 2646 |
+
|
| 2647 |
+
def _is_device_backend_autoload_enabled() -> builtins.bool:
|
| 2648 |
+
"""
|
| 2649 |
+
Whether autoloading out-of-the-tree device extensions is enabled.
|
| 2650 |
+
The switch depends on the value of the environment variable
|
| 2651 |
+
`TORCH_DEVICE_BACKEND_AUTOLOAD`.
|
| 2652 |
+
|
| 2653 |
+
Returns:
|
| 2654 |
+
bool: Whether to enable autoloading the extensions. Enabled by default.
|
| 2655 |
+
|
| 2656 |
+
Examples:
|
| 2657 |
+
>>> torch._is_device_backend_autoload_enabled()
|
| 2658 |
+
True
|
| 2659 |
+
"""
|
| 2660 |
+
# enabled by default
|
| 2661 |
+
return os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "1") == "1"
|
| 2662 |
+
|
| 2663 |
+
|
| 2664 |
+
if _is_device_backend_autoload_enabled():
|
| 2665 |
+
_import_device_backends()
|
.venv/lib/python3.11/site-packages/torch/_appdirs.py
ADDED
|
@@ -0,0 +1,667 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# Copyright (c) 2005-2010 ActiveState Software Inc.
|
| 4 |
+
# Copyright (c) 2013 Eddy Petrișor
|
| 5 |
+
|
| 6 |
+
# flake8: noqa
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
This file is directly from
|
| 10 |
+
https://github.com/ActiveState/appdirs/blob/3fe6a83776843a46f20c2e5587afcffe05e03b39/appdirs.py
|
| 11 |
+
|
| 12 |
+
The license of https://github.com/ActiveState/appdirs copied below:
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# This is the MIT license
|
| 16 |
+
|
| 17 |
+
Copyright (c) 2010 ActiveState Software Inc.
|
| 18 |
+
|
| 19 |
+
Permission is hereby granted, free of charge, to any person obtaining a
|
| 20 |
+
copy of this software and associated documentation files (the
|
| 21 |
+
"Software"), to deal in the Software without restriction, including
|
| 22 |
+
without limitation the rights to use, copy, modify, merge, publish,
|
| 23 |
+
distribute, sublicense, and/or sell copies of the Software, and to
|
| 24 |
+
permit persons to whom the Software is furnished to do so, subject to
|
| 25 |
+
the following conditions:
|
| 26 |
+
|
| 27 |
+
The above copyright notice and this permission notice shall be included
|
| 28 |
+
in all copies or substantial portions of the Software.
|
| 29 |
+
|
| 30 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
| 31 |
+
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
| 32 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
| 33 |
+
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
| 34 |
+
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
| 35 |
+
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
| 36 |
+
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
"""Utilities for determining application-specific dirs.
|
| 40 |
+
|
| 41 |
+
See <https://github.com/ActiveState/appdirs> for details and usage.
|
| 42 |
+
"""
|
| 43 |
+
# Dev Notes:
|
| 44 |
+
# - MSDN on where to store app data files:
|
| 45 |
+
# http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120
|
| 46 |
+
# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
|
| 47 |
+
# - XDG spec for Un*x: https://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
|
| 48 |
+
|
| 49 |
+
__version__ = "1.4.4"
|
| 50 |
+
__version_info__ = tuple(int(segment) for segment in __version__.split("."))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
import os
|
| 54 |
+
import sys
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
unicode = str
|
| 58 |
+
|
| 59 |
+
if sys.platform.startswith("java"):
|
| 60 |
+
import platform
|
| 61 |
+
|
| 62 |
+
os_name = platform.java_ver()[3][0]
|
| 63 |
+
if os_name.startswith("Windows"): # "Windows XP", "Windows 7", etc.
|
| 64 |
+
system = "win32"
|
| 65 |
+
elif os_name.startswith("Mac"): # "Mac OS X", etc.
|
| 66 |
+
system = "darwin"
|
| 67 |
+
else: # "Linux", "SunOS", "FreeBSD", etc.
|
| 68 |
+
# Setting this to "linux2" is not ideal, but only Windows or Mac
|
| 69 |
+
# are actually checked for and the rest of the module expects
|
| 70 |
+
# *sys.platform* style strings.
|
| 71 |
+
system = "linux2"
|
| 72 |
+
else:
|
| 73 |
+
system = sys.platform
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
|
| 77 |
+
r"""Return full path to the user-specific data dir for this application.
|
| 78 |
+
|
| 79 |
+
"appname" is the name of application.
|
| 80 |
+
If None, just the system directory is returned.
|
| 81 |
+
"appauthor" (only used on Windows) is the name of the
|
| 82 |
+
appauthor or distributing body for this application. Typically
|
| 83 |
+
it is the owning company name. This falls back to appname. You may
|
| 84 |
+
pass False to disable it.
|
| 85 |
+
"version" is an optional version path element to append to the
|
| 86 |
+
path. You might want to use this if you want multiple versions
|
| 87 |
+
of your app to be able to run independently. If used, this
|
| 88 |
+
would typically be "<major>.<minor>".
|
| 89 |
+
Only applied when appname is present.
|
| 90 |
+
"roaming" (boolean, default False) can be set True to use the Windows
|
| 91 |
+
roaming appdata directory. That means that for users on a Windows
|
| 92 |
+
network setup for roaming profiles, this user data will be
|
| 93 |
+
sync'd on login. See
|
| 94 |
+
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
|
| 95 |
+
for a discussion of issues.
|
| 96 |
+
|
| 97 |
+
Typical user data directories are:
|
| 98 |
+
Mac OS X: ~/Library/Application Support/<AppName>
|
| 99 |
+
Unix: ~/.local/share/<AppName> # or in $XDG_DATA_HOME, if defined
|
| 100 |
+
Win XP (not roaming): C:\Documents and Settings\<username>\Application Data\<AppAuthor>\<AppName>
|
| 101 |
+
Win XP (roaming): C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>
|
| 102 |
+
Win 7 (not roaming): C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>
|
| 103 |
+
Win 7 (roaming): C:\Users\<username>\AppData\Roaming\<AppAuthor>\<AppName>
|
| 104 |
+
|
| 105 |
+
For Unix, we follow the XDG spec and support $XDG_DATA_HOME.
|
| 106 |
+
That means, by default "~/.local/share/<AppName>".
|
| 107 |
+
"""
|
| 108 |
+
if system == "win32":
|
| 109 |
+
if appauthor is None:
|
| 110 |
+
appauthor = appname
|
| 111 |
+
const = roaming and "CSIDL_APPDATA" or "CSIDL_LOCAL_APPDATA"
|
| 112 |
+
path = os.path.normpath(_get_win_folder(const))
|
| 113 |
+
if appname:
|
| 114 |
+
if appauthor is not False:
|
| 115 |
+
path = os.path.join(path, appauthor, appname)
|
| 116 |
+
else:
|
| 117 |
+
path = os.path.join(path, appname)
|
| 118 |
+
elif system == "darwin":
|
| 119 |
+
path = os.path.expanduser("~/Library/Application Support/")
|
| 120 |
+
if appname:
|
| 121 |
+
path = os.path.join(path, appname)
|
| 122 |
+
else:
|
| 123 |
+
path = os.getenv("XDG_DATA_HOME", os.path.expanduser("~/.local/share"))
|
| 124 |
+
if appname:
|
| 125 |
+
path = os.path.join(path, appname)
|
| 126 |
+
if appname and version:
|
| 127 |
+
path = os.path.join(path, version)
|
| 128 |
+
return path
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def site_data_dir(appname=None, appauthor=None, version=None, multipath=False):
|
| 132 |
+
r"""Return full path to the user-shared data dir for this application.
|
| 133 |
+
|
| 134 |
+
"appname" is the name of application.
|
| 135 |
+
If None, just the system directory is returned.
|
| 136 |
+
"appauthor" (only used on Windows) is the name of the
|
| 137 |
+
appauthor or distributing body for this application. Typically
|
| 138 |
+
it is the owning company name. This falls back to appname. You may
|
| 139 |
+
pass False to disable it.
|
| 140 |
+
"version" is an optional version path element to append to the
|
| 141 |
+
path. You might want to use this if you want multiple versions
|
| 142 |
+
of your app to be able to run independently. If used, this
|
| 143 |
+
would typically be "<major>.<minor>".
|
| 144 |
+
Only applied when appname is present.
|
| 145 |
+
"multipath" is an optional parameter only applicable to *nix
|
| 146 |
+
which indicates that the entire list of data dirs should be
|
| 147 |
+
returned. By default, the first item from XDG_DATA_DIRS is
|
| 148 |
+
returned, or '/usr/local/share/<AppName>',
|
| 149 |
+
if XDG_DATA_DIRS is not set
|
| 150 |
+
|
| 151 |
+
Typical site data directories are:
|
| 152 |
+
Mac OS X: /Library/Application Support/<AppName>
|
| 153 |
+
Unix: /usr/local/share/<AppName> or /usr/share/<AppName>
|
| 154 |
+
Win XP: C:\Documents and Settings\All Users\Application Data\<AppAuthor>\<AppName>
|
| 155 |
+
Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
|
| 156 |
+
Win 7: C:\ProgramData\<AppAuthor>\<AppName> # Hidden, but writeable on Win 7.
|
| 157 |
+
|
| 158 |
+
For Unix, this is using the $XDG_DATA_DIRS[0] default.
|
| 159 |
+
|
| 160 |
+
WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
|
| 161 |
+
"""
|
| 162 |
+
if system == "win32":
|
| 163 |
+
if appauthor is None:
|
| 164 |
+
appauthor = appname
|
| 165 |
+
path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA"))
|
| 166 |
+
if appname:
|
| 167 |
+
if appauthor is not False:
|
| 168 |
+
path = os.path.join(path, appauthor, appname)
|
| 169 |
+
else:
|
| 170 |
+
path = os.path.join(path, appname)
|
| 171 |
+
elif system == "darwin":
|
| 172 |
+
path = os.path.expanduser("/Library/Application Support")
|
| 173 |
+
if appname:
|
| 174 |
+
path = os.path.join(path, appname)
|
| 175 |
+
else:
|
| 176 |
+
# XDG default for $XDG_DATA_DIRS
|
| 177 |
+
# only first, if multipath is False
|
| 178 |
+
path = os.getenv(
|
| 179 |
+
"XDG_DATA_DIRS", os.pathsep.join(["/usr/local/share", "/usr/share"])
|
| 180 |
+
)
|
| 181 |
+
pathlist = [
|
| 182 |
+
os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)
|
| 183 |
+
]
|
| 184 |
+
if appname:
|
| 185 |
+
if version:
|
| 186 |
+
appname = os.path.join(appname, version)
|
| 187 |
+
pathlist = [os.sep.join([x, appname]) for x in pathlist]
|
| 188 |
+
|
| 189 |
+
if multipath:
|
| 190 |
+
path = os.pathsep.join(pathlist)
|
| 191 |
+
else:
|
| 192 |
+
path = pathlist[0]
|
| 193 |
+
return path
|
| 194 |
+
|
| 195 |
+
if appname and version:
|
| 196 |
+
path = os.path.join(path, version)
|
| 197 |
+
return path
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
|
| 201 |
+
r"""Return full path to the user-specific config dir for this application.
|
| 202 |
+
|
| 203 |
+
"appname" is the name of application.
|
| 204 |
+
If None, just the system directory is returned.
|
| 205 |
+
"appauthor" (only used on Windows) is the name of the
|
| 206 |
+
appauthor or distributing body for this application. Typically
|
| 207 |
+
it is the owning company name. This falls back to appname. You may
|
| 208 |
+
pass False to disable it.
|
| 209 |
+
"version" is an optional version path element to append to the
|
| 210 |
+
path. You might want to use this if you want multiple versions
|
| 211 |
+
of your app to be able to run independently. If used, this
|
| 212 |
+
would typically be "<major>.<minor>".
|
| 213 |
+
Only applied when appname is present.
|
| 214 |
+
"roaming" (boolean, default False) can be set True to use the Windows
|
| 215 |
+
roaming appdata directory. That means that for users on a Windows
|
| 216 |
+
network setup for roaming profiles, this user data will be
|
| 217 |
+
sync'd on login. See
|
| 218 |
+
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
|
| 219 |
+
for a discussion of issues.
|
| 220 |
+
|
| 221 |
+
Typical user config directories are:
|
| 222 |
+
Mac OS X: ~/Library/Preferences/<AppName>
|
| 223 |
+
Unix: ~/.config/<AppName> # or in $XDG_CONFIG_HOME, if defined
|
| 224 |
+
Win *: same as user_data_dir
|
| 225 |
+
|
| 226 |
+
For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME.
|
| 227 |
+
That means, by default "~/.config/<AppName>".
|
| 228 |
+
"""
|
| 229 |
+
if system == "win32":
|
| 230 |
+
path = user_data_dir(appname, appauthor, None, roaming)
|
| 231 |
+
elif system == "darwin":
|
| 232 |
+
path = os.path.expanduser("~/Library/Preferences/")
|
| 233 |
+
if appname:
|
| 234 |
+
path = os.path.join(path, appname)
|
| 235 |
+
else:
|
| 236 |
+
path = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
|
| 237 |
+
if appname:
|
| 238 |
+
path = os.path.join(path, appname)
|
| 239 |
+
if appname and version:
|
| 240 |
+
path = os.path.join(path, version)
|
| 241 |
+
return path
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def site_config_dir(appname=None, appauthor=None, version=None, multipath=False):
    r"""Return full path to the user-shared config dir for this application.

    "appname" is the name of application. If None, just the system
        directory is returned.
    "appauthor" (only used on Windows) is the name of the appauthor or
        distributing body for this application; falls back to appname.
        Pass False to disable it.
    "version" is an optional version path element appended to the path
        (only when appname is present), typically "<major>.<minor>".
    "multipath" is an optional parameter, only applicable on *nix, that
        requests the entire list of config dirs. By default only the
        first entry of $XDG_CONFIG_DIRS is returned, or
        '/etc/xdg/<AppName>' if $XDG_CONFIG_DIRS is not set.

    Typical site config directories:
        Mac OS X: same as site_data_dir
        Unix:     /etc/xdg/<AppName> or $XDG_CONFIG_DIRS[i]/<AppName>
        Win *:    same as site_data_dir
        Vista:    (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)

    WARNING: Do not use this on Windows — see the Vista note above.
    """
    if system == "win32":
        path = site_data_dir(appname, appauthor)
        if appname and version:
            path = os.path.join(path, version)
        return path

    if system == "darwin":
        path = os.path.expanduser("/Library/Preferences")
        return os.path.join(path, appname) if appname else path

    # *nix: honour $XDG_CONFIG_DIRS, falling back to the XDG default.
    raw = os.getenv("XDG_CONFIG_DIRS", "/etc/xdg")
    dirs = [os.path.expanduser(entry.rstrip(os.sep)) for entry in raw.split(os.pathsep)]
    if appname:
        suffix = os.path.join(appname, version) if version else appname
        dirs = [os.sep.join([d, suffix]) for d in dirs]

    if multipath:
        return os.pathsep.join(dirs)
    return dirs[0]
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
    r"""Return full path to the user-specific cache dir for this application.

    "appname" is the name of application. If None, just the system
        directory is returned.
    "appauthor" (only used on Windows) is the name of the appauthor or
        distributing body for this application; falls back to appname.
        Pass False to disable it.
    "version" is an optional version path element appended to the path
        (only when appname is present), typically "<major>.<minor>".
    "opinion" (boolean) can be False to disable appending "Cache" to the
        base app data dir on Windows (see note below).

    Typical user cache directories:
        Mac OS X: ~/Library/Caches/<AppName>
        Unix:     ~/.cache/<AppName> (XDG default)
        Win XP:   C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Cache
        Vista:    C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Cache

    On Windows the only MSDN suggestion is CSIDL_LOCAL_APPDATA, which is
    identical to the non-roaming app data dir returned by user_data_dir.
    Apps typically put cache data somewhere *under* that dir (e.g.
    ...\Mozilla\Firefox\Profiles\<ProfileName>\Cache), so OPINION: this
    function appends "Cache" to the CSIDL_LOCAL_APPDATA value; disable
    with opinion=False.
    """
    if system == "win32":
        author = appname if appauthor is None else appauthor
        path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
        if appname:
            if author is not False:
                path = os.path.join(path, author, appname)
            else:
                path = os.path.join(path, appname)
            if opinion:
                path = os.path.join(path, "Cache")
    elif system == "darwin":
        base = os.path.expanduser("~/Library/Caches")
        path = os.path.join(base, appname) if appname else base
    else:
        base = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
        path = os.path.join(base, appname) if appname else base
    if appname and version:
        path = os.path.join(path, version)
    return path
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def user_state_dir(appname=None, appauthor=None, version=None, roaming=False):
    r"""Return full path to the user-specific state dir for this application.

    "appname" is the name of application. If None, just the system
        directory is returned.
    "appauthor" (only used on Windows) is the name of the appauthor or
        distributing body for this application; falls back to appname.
        Pass False to disable it.
    "version" is an optional version path element appended to the path
        (only when appname is present), typically "<major>.<minor>".
    "roaming" (boolean, default False): on Windows, use the roaming
        appdata directory so the data is sync'd on login for users with
        roaming profiles. See
        <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
        for a discussion of issues.

    Typical user state directories:
        Mac OS X: same as user_data_dir
        Unix:     ~/.local/state/<AppName>  # or under $XDG_STATE_HOME
        Win *:    same as user_data_dir

    For Unix we follow this Debian proposal
    <https://wiki.debian.org/XDGBaseDirectorySpecification#state>
    extending the XDG spec with $XDG_STATE_HOME, i.e. by default
    "~/.local/state/<AppName>".
    """
    if system in ("win32", "darwin"):
        path = user_data_dir(appname, appauthor, None, roaming)
    else:
        base = os.getenv("XDG_STATE_HOME", os.path.expanduser("~/.local/state"))
        path = os.path.join(base, appname) if appname else base
    if appname and version:
        path = os.path.join(path, version)
    return path
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def user_log_dir(appname=None, appauthor=None, version=None, opinion=True):
    r"""Return full path to the user-specific log dir for this application.

    "appname" is the name of application.
        If None, just the system directory is returned.
    "appauthor" (only used on Windows) is the name of the appauthor or
        distributing body for this application; falls back to appname.
        Pass False to disable it.
    "version" is an optional version path element appended to the path
        (only when appname is present), typically "<major>.<minor>".
    "opinion" (boolean) can be False to disable the appending of "Logs"
        to the base app data dir for Windows, and "log" to the base
        cache dir for Unix. See discussion below.

    Typical user log directories are:
        Mac OS X:   ~/Library/Logs/<AppName>
        Unix:       ~/.cache/<AppName>/log  # or under $XDG_CACHE_HOME if defined
        Win XP:     C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Logs
        Vista:      C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Logs

    On Windows the only suggestion in the MSDN docs is that local settings
    go in the `CSIDL_LOCAL_APPDATA` directory.

    OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA`
    value for Windows and appends "log" to the user cache dir for Unix.
    This can be disabled with the `opinion=False` option.
    """
    if system == "darwin":
        # BUG FIX: the previous code unconditionally joined ``appname``
        # (``os.path.join(..., appname)``), which raised TypeError when
        # appname was None — contradicting the documented contract that
        # None yields just the system directory.
        path = os.path.expanduser("~/Library/Logs")
        if appname:
            path = os.path.join(path, appname)
    elif system == "win32":
        path = user_data_dir(appname, appauthor, version)
        version = False  # version is already embedded by user_data_dir
        if opinion:
            path = os.path.join(path, "Logs")
    else:
        path = user_cache_dir(appname, appauthor, version)
        version = False  # version is already embedded by user_cache_dir
        if opinion:
            path = os.path.join(path, "log")
    if appname and version:
        path = os.path.join(path, version)
    return path
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
class AppDirs(object):
    """Convenience wrapper bundling the module-level ``*_dir`` functions.

    The constructor captures appname/appauthor/version/roaming/multipath
    once; each directory kind is then exposed as a read-only property
    that delegates to the corresponding module-level function.
    """

    def __init__(self, appname=None, appauthor=None, version=None,
                 roaming=False, multipath=False):
        self.appname = appname
        self.appauthor = appauthor
        self.version = version
        self.roaming = roaming
        self.multipath = multipath

    @property
    def user_data_dir(self):
        return user_data_dir(self.appname, self.appauthor,
                             version=self.version, roaming=self.roaming)

    @property
    def site_data_dir(self):
        return site_data_dir(self.appname, self.appauthor,
                             version=self.version, multipath=self.multipath)

    @property
    def user_config_dir(self):
        return user_config_dir(self.appname, self.appauthor,
                               version=self.version, roaming=self.roaming)

    @property
    def site_config_dir(self):
        return site_config_dir(self.appname, self.appauthor,
                               version=self.version, multipath=self.multipath)

    @property
    def user_cache_dir(self):
        return user_cache_dir(self.appname, self.appauthor, version=self.version)

    @property
    def user_state_dir(self):
        return user_state_dir(self.appname, self.appauthor, version=self.version)

    @property
    def user_log_dir(self):
        return user_log_dir(self.appname, self.appauthor, version=self.version)
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
# ---- internal support stuff
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def _get_win_folder_from_registry(csidl_name):
    """Look up a Windows shell folder path in the registry.

    This is a fallback technique at best: reading the registry is not
    guaranteed to give the correct answer for all CSIDL_* names.
    """
    import winreg as _winreg

    value_name = {
        "CSIDL_APPDATA": "AppData",
        "CSIDL_COMMON_APPDATA": "Common AppData",
        "CSIDL_LOCAL_APPDATA": "Local AppData",
    }[csidl_name]

    key = _winreg.OpenKey(
        _winreg.HKEY_CURRENT_USER,
        r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
    )
    folder, _value_type = _winreg.QueryValueEx(key, value_name)
    return folder
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def _get_win_folder_with_pywin32(csidl_name):
    """Resolve a CSIDL_* folder via pywin32's SHGetFolderPath."""
    from win32com.shell import shell, shellcon

    dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0)
    # Try to make this a unicode path because SHGetFolderPath does
    # not return unicode strings when there is unicode data in the
    # path.
    try:
        # BUG FIX: the original called the Python-2-only builtin
        # ``unicode``, which raises NameError on Python 3 (a NameError
        # the ``except UnicodeError`` below does not catch). ``str`` is
        # the Python 3 equivalent; this file is Python-3-only (it uses
        # f-strings and ``import winreg``).
        dir = str(dir)

        # Downgrade to short path name if have highbit chars. See
        # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
        has_high_char = any(ord(c) > 255 for c in dir)
        if has_high_char:
            try:
                import win32api

                dir = win32api.GetShortPathName(dir)
            except ImportError:
                pass
    except UnicodeError:
        pass
    return dir
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def _get_win_folder_with_ctypes(csidl_name):
    """Resolve a CSIDL_* folder by calling shell32 directly via ctypes."""
    import ctypes

    csidl = {
        "CSIDL_APPDATA": 26,
        "CSIDL_COMMON_APPDATA": 35,
        "CSIDL_LOCAL_APPDATA": 28,
    }[csidl_name]

    buf = ctypes.create_unicode_buffer(1024)
    ctypes.windll.shell32.SHGetFolderPathW(None, csidl, None, 0, buf)

    # Downgrade to a short path name if the value contains high-bit
    # characters. See <http://bugs.activestate.com/show_bug.cgi?id=85099>.
    if any(ord(ch) > 255 for ch in buf):
        short = ctypes.create_unicode_buffer(1024)
        if ctypes.windll.kernel32.GetShortPathNameW(buf.value, short, 1024):
            buf = short

    return buf.value
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
def _get_win_folder_with_jna(csidl_name):
    """Resolve a CSIDL_* folder via JNA (the Jython-on-Windows path)."""
    import array

    from com.sun import jna
    from com.sun.jna.platform import win32

    buf_size = win32.WinDef.MAX_PATH * 2
    buf = array.zeros("c", buf_size)
    win32.Shell32.INSTANCE.SHGetFolderPath(
        None,
        getattr(win32.ShlObj, csidl_name),
        None,
        win32.ShlObj.SHGFP_TYPE_CURRENT,
        buf,
    )
    folder = jna.Native.toString(buf.tostring()).rstrip("\0")

    # Downgrade to a short path name if the value contains high-bit
    # characters. See <http://bugs.activestate.com/show_bug.cgi?id=85099>.
    if any(ord(ch) > 255 for ch in folder):
        buf = array.zeros("c", buf_size)
        if win32.Kernel32.INSTANCE.GetShortPathName(folder, buf, buf_size):
            folder = jna.Native.toString(buf.tostring()).rstrip("\0")

    return folder
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
if system == "win32":
    # Pick the best available backend for resolving Windows shell
    # folders, in order of preference: pywin32, ctypes, JNA (for
    # Jython), and finally a raw registry read as a last resort.
    # NOTE: the probes are deliberate — ``from ctypes import windll``
    # (not just ``import ctypes``) so a non-Windows ctypes build fails
    # the probe with ImportError and falls through.
    try:
        import win32com.shell

        _get_win_folder = _get_win_folder_with_pywin32
    except ImportError:
        try:
            from ctypes import windll

            _get_win_folder = _get_win_folder_with_ctypes
        except ImportError:
            try:
                import com.sun.jna

                _get_win_folder = _get_win_folder_with_jna
            except ImportError:
                _get_win_folder = _get_win_folder_from_registry
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
# ---- self test code
|
| 632 |
+
|
| 633 |
+
if __name__ == "__main__":
|
| 634 |
+
appname = "MyApp"
|
| 635 |
+
appauthor = "MyCompany"
|
| 636 |
+
|
| 637 |
+
props = (
|
| 638 |
+
"user_data_dir",
|
| 639 |
+
"user_config_dir",
|
| 640 |
+
"user_cache_dir",
|
| 641 |
+
"user_state_dir",
|
| 642 |
+
"user_log_dir",
|
| 643 |
+
"site_data_dir",
|
| 644 |
+
"site_config_dir",
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
print(f"-- app dirs {__version__} --")
|
| 648 |
+
|
| 649 |
+
print("-- app dirs (with optional 'version')")
|
| 650 |
+
dirs = AppDirs(appname, appauthor, version="1.0")
|
| 651 |
+
for prop in props:
|
| 652 |
+
print(f"{prop}: {getattr(dirs, prop)}")
|
| 653 |
+
|
| 654 |
+
print("\n-- app dirs (without optional 'version')")
|
| 655 |
+
dirs = AppDirs(appname, appauthor)
|
| 656 |
+
for prop in props:
|
| 657 |
+
print(f"{prop}: {getattr(dirs, prop)}")
|
| 658 |
+
|
| 659 |
+
print("\n-- app dirs (without optional 'appauthor')")
|
| 660 |
+
dirs = AppDirs(appname)
|
| 661 |
+
for prop in props:
|
| 662 |
+
print(f"{prop}: {getattr(dirs, prop)}")
|
| 663 |
+
|
| 664 |
+
print("\n-- app dirs (with disabled 'appauthor')")
|
| 665 |
+
dirs = AppDirs(appname, appauthor=False)
|
| 666 |
+
for prop in props:
|
| 667 |
+
print(f"{prop}: {getattr(dirs, prop)}")
|
.venv/lib/python3.11/site-packages/torch/_classes.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import types
|
| 3 |
+
|
| 4 |
+
import torch._C
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class _ClassNamespace(types.ModuleType):
|
| 8 |
+
def __init__(self, name):
|
| 9 |
+
super().__init__("torch.classes" + name)
|
| 10 |
+
self.name = name
|
| 11 |
+
|
| 12 |
+
def __getattr__(self, attr):
|
| 13 |
+
proxy = torch._C._get_custom_class_python_wrapper(self.name, attr)
|
| 14 |
+
if proxy is None:
|
| 15 |
+
raise RuntimeError(f"Class {self.name}.{attr} not registered!")
|
| 16 |
+
return proxy
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class _Classes(types.ModuleType):
|
| 20 |
+
__file__ = "_classes.py"
|
| 21 |
+
|
| 22 |
+
def __init__(self) -> None:
|
| 23 |
+
super().__init__("torch.classes")
|
| 24 |
+
|
| 25 |
+
def __getattr__(self, name):
|
| 26 |
+
namespace = _ClassNamespace(name)
|
| 27 |
+
setattr(self, name, namespace)
|
| 28 |
+
return namespace
|
| 29 |
+
|
| 30 |
+
@property
|
| 31 |
+
def loaded_libraries(self):
|
| 32 |
+
return torch.ops.loaded_libraries
|
| 33 |
+
|
| 34 |
+
def load_library(self, path):
|
| 35 |
+
"""
|
| 36 |
+
Loads a shared library from the given path into the current process.
|
| 37 |
+
|
| 38 |
+
The library being loaded may run global initialization code to register
|
| 39 |
+
custom classes with the PyTorch JIT runtime. This allows dynamically
|
| 40 |
+
loading custom classes. For this, you should compile your class
|
| 41 |
+
and the static registration code into a shared library object, and then
|
| 42 |
+
call ``torch.classes.load_library('path/to/libcustom.so')`` to load the
|
| 43 |
+
shared object.
|
| 44 |
+
|
| 45 |
+
After the library is loaded, it is added to the
|
| 46 |
+
``torch.classes.loaded_libraries`` attribute, a set that may be inspected
|
| 47 |
+
for the paths of all libraries loaded using this function.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
path (str): A path to a shared library to load.
|
| 51 |
+
"""
|
| 52 |
+
torch.ops.load_library(path)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# The classes "namespace": singleton exposed to users as ``torch.classes``.
classes = _Classes()
|
.venv/lib/python3.11/site-packages/torch/_compile.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
"""
|
| 3 |
+
APIs related to torch.compile which lazily import torch._dynamo to avoid
|
| 4 |
+
circular dependencies.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import functools
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _disable_dynamo(fn=None, recursive=True):
|
| 11 |
+
"""
|
| 12 |
+
This API should be only used inside torch, external users should still use
|
| 13 |
+
torch._dynamo.disable. The main goal of this API is to avoid circular
|
| 14 |
+
imports issues that is common while using _dynamo.disable inside torch
|
| 15 |
+
itself.
|
| 16 |
+
|
| 17 |
+
This API avoids it by lazily importing torch._dynamo from the import time to
|
| 18 |
+
the invocation of the decorated function.
|
| 19 |
+
"""
|
| 20 |
+
if fn is not None:
|
| 21 |
+
|
| 22 |
+
@functools.wraps(fn)
|
| 23 |
+
def inner(*args, **kwargs):
|
| 24 |
+
# cache this on the first invocation to avoid adding too much overhead.
|
| 25 |
+
disable_fn = getattr(fn, "__dynamo_disable", None)
|
| 26 |
+
if disable_fn is None:
|
| 27 |
+
import torch._dynamo
|
| 28 |
+
|
| 29 |
+
disable_fn = torch._dynamo.disable(fn, recursive)
|
| 30 |
+
fn.__dynamo_disable = disable_fn
|
| 31 |
+
|
| 32 |
+
return disable_fn(*args, **kwargs)
|
| 33 |
+
|
| 34 |
+
return inner
|
| 35 |
+
else:
|
| 36 |
+
# decorator usage like @_disable_dynamo(recursive=False). The resulting
|
| 37 |
+
# object expects the original decorated function as the arg.
|
| 38 |
+
return functools.partial(_disable_dynamo, recursive=recursive)
|
.venv/lib/python3.11/site-packages/torch/_custom_ops.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import inspect
|
| 3 |
+
|
| 4 |
+
from torch._custom_op.impl import (
|
| 5 |
+
_custom_op_with_schema,
|
| 6 |
+
_find_custom_op,
|
| 7 |
+
infer_schema,
|
| 8 |
+
parse_qualname,
|
| 9 |
+
validate_namespace,
|
| 10 |
+
)
|
| 11 |
+
from torch.library import get_ctx
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Public API of this (legacy) custom-ops shim module.
__all__ = [
    "custom_op",
    "impl",
    "impl_abstract",
    "get_ctx",
    "impl_save_for_backward",
    "impl_backward",
]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def custom_op(qualname, func_or_schema=None):
|
| 25 |
+
r"""Register a new custom operator
|
| 26 |
+
|
| 27 |
+
In PyTorch, defining an op (short for "operator") is a two step-process:
|
| 28 |
+
- we need to define the op (by providing an operator name and schema)
|
| 29 |
+
- we need to implement behavior for how the operator interacts with
|
| 30 |
+
various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
|
| 31 |
+
|
| 32 |
+
This entrypoint defines the custom operator (the first step)
|
| 33 |
+
you must then perform the second step by calling various
|
| 34 |
+
``impl_*`` APIs.
|
| 35 |
+
|
| 36 |
+
This API may be used as a decorator (see examples).
|
| 37 |
+
|
| 38 |
+
For a detailed guide on custom ops, please see
|
| 39 |
+
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
|
| 40 |
+
|
| 41 |
+
Arguments:
|
| 42 |
+
qualname (str): Should be a string that looks like
|
| 43 |
+
"namespace::operator_name". Operators in PyTorch need a namespace to
|
| 44 |
+
avoid name collisions; a given operator may only be created once.
|
| 45 |
+
If you are writing a Python library, we recommend the namespace to
|
| 46 |
+
be the name of your top-level module.
|
| 47 |
+
func_or_schema (Union[Callable, str]): Each PyTorch operator needs a
|
| 48 |
+
schema that tells PyTorch the types of the inputs/outputs.
|
| 49 |
+
If this is a Callable, we will automatically infer the schema from
|
| 50 |
+
the type annotations on the function (see examples). Otherwise,
|
| 51 |
+
if you don't want to use type annotations, you may provide us the
|
| 52 |
+
schema string.
|
| 53 |
+
|
| 54 |
+
Example::
|
| 55 |
+
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
| 56 |
+
>>> import torch
|
| 57 |
+
>>> import numpy as np
|
| 58 |
+
>>> from torch import Tensor
|
| 59 |
+
>>>
|
| 60 |
+
>>> # Step 1: define the custom op.
|
| 61 |
+
>>> # We need to provide the API a "prototype function"
|
| 62 |
+
>>> # (a function that returns NotImplementedError), from which
|
| 63 |
+
>>> # we will infer the types of the inputs and outputs.
|
| 64 |
+
>>> @torch._custom_ops.custom_op("mylibrary::numpy_sin")
|
| 65 |
+
>>> def numpy_sin(x: Tensor) -> Tensor:
|
| 66 |
+
>>> raise NotImplementedError
|
| 67 |
+
>>>
|
| 68 |
+
>>> # The custom op is now accessible via the torch.ops module:
|
| 69 |
+
>>> torch.ops.mylibrary.numpy_sin
|
| 70 |
+
>>>
|
| 71 |
+
>>> # Step 2: Register an implementation for various PyTorch subsystems
|
| 72 |
+
>>>
|
| 73 |
+
>>> # Register an implementation for CPU tensors
|
| 74 |
+
>>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cpu")
|
| 75 |
+
>>> def numpy_sin_impl_cpu(x):
|
| 76 |
+
>>> return torch.from_numpy(np.sin(x.numpy()))
|
| 77 |
+
>>>
|
| 78 |
+
>>> # Register an implementation for CUDA tensors
|
| 79 |
+
>>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cuda")
|
| 80 |
+
>>> def numpy_sin_impl_cuda(x):
|
| 81 |
+
>>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
|
| 82 |
+
>>>
|
| 83 |
+
>>> x = torch.randn(3)
|
| 84 |
+
>>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cpu
|
| 85 |
+
>>>
|
| 86 |
+
>>> x_cuda = x.cuda()
|
| 87 |
+
>>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cuda
|
| 88 |
+
|
| 89 |
+
"""
|
| 90 |
+
ns, name = parse_qualname(qualname)
|
| 91 |
+
validate_namespace(ns)
|
| 92 |
+
|
| 93 |
+
def inner(func):
|
| 94 |
+
if not inspect.isfunction(func):
|
| 95 |
+
raise ValueError(
|
| 96 |
+
f"custom_op(...)(func): Expected `func` to be a Python "
|
| 97 |
+
f"function, got: {type(func)}"
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
if func.__name__ != name:
|
| 101 |
+
raise ValueError(
|
| 102 |
+
f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
|
| 103 |
+
f"to have name '{name}' but got '{func.__name__}'. "
|
| 104 |
+
f"Please either change the name of `func` or the qualname that "
|
| 105 |
+
f"is passed to `custom_op`"
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
schema = infer_schema(func, mutates_args=())
|
| 109 |
+
_custom_op_with_schema(qualname, schema)
|
| 110 |
+
return func
|
| 111 |
+
|
| 112 |
+
if func_or_schema is None:
|
| 113 |
+
return inner
|
| 114 |
+
if isinstance(func_or_schema, str):
|
| 115 |
+
_custom_op_with_schema(qualname, func_or_schema)
|
| 116 |
+
else:
|
| 117 |
+
return inner(func_or_schema)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def impl(qualname, *, device_types=("cpu", "cuda"), func=None):
|
| 121 |
+
r"""Register an implementation for a device type for this custom op.
|
| 122 |
+
|
| 123 |
+
If the op is passed multiple Tensor inputs with different device
|
| 124 |
+
types, it will dispatch to the registered implementation for the highest
|
| 125 |
+
priority device type among those present.
|
| 126 |
+
The supported device types, in order of priority, are {'cuda', 'cpu'}.
|
| 127 |
+
|
| 128 |
+
This API may be used as a decorator (see examples).
|
| 129 |
+
|
| 130 |
+
For a detailed guide on custom ops, please see
|
| 131 |
+
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
|
| 132 |
+
|
| 133 |
+
Arguments:
|
| 134 |
+
device_types (str or Iterable[str]): the device type(s) to register the function for.
|
| 135 |
+
|
| 136 |
+
Example::
|
| 137 |
+
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
| 138 |
+
>>> import torch
|
| 139 |
+
>>> import numpy as np
|
| 140 |
+
>>> from torch import Tensor
|
| 141 |
+
>>>
|
| 142 |
+
>>> # Step 1: define the custom op.
|
| 143 |
+
>>> # We need to provide the API a "prototype function"
|
| 144 |
+
>>> # (a function that returns NotImplementedError), from which
|
| 145 |
+
>>> # we will infer the types of the inputs and outputs.
|
| 146 |
+
>>> @torch._custom_ops.custom_op("mylibrary::numpy_cos")
|
| 147 |
+
>>> def numpy_cos(x: Tensor) -> Tensor:
|
| 148 |
+
>>> raise NotImplementedError
|
| 149 |
+
>>>
|
| 150 |
+
>>> # The custom op is now accessible via the torch.ops module:
|
| 151 |
+
>>> torch.ops.mylibrary.numpy_cos
|
| 152 |
+
>>>
|
| 153 |
+
>>> # Step 2: Register an implementation for various PyTorch subsystems
|
| 154 |
+
>>>
|
| 155 |
+
>>> # Register an implementation for CPU tensors
|
| 156 |
+
>>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cpu")
|
| 157 |
+
>>> def numpy_cos_impl_cpu(x):
|
| 158 |
+
>>> return torch.from_numpy(np.cos(x.numpy()))
|
| 159 |
+
>>>
|
| 160 |
+
>>> # Register an implementation for CUDA tensors
|
| 161 |
+
>>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cuda")
|
| 162 |
+
>>> def numpy_cos_impl_cuda(x):
|
| 163 |
+
>>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
|
| 164 |
+
>>>
|
| 165 |
+
>>> x = torch.randn(3)
|
| 166 |
+
>>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cpu
|
| 167 |
+
>>>
|
| 168 |
+
>>> x_cuda = x.cuda()
|
| 169 |
+
>>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cuda
|
| 170 |
+
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
def inner(func):
|
| 174 |
+
custom_op = _find_custom_op(qualname, also_check_torch_library=True)
|
| 175 |
+
custom_op.impl(device_types, _stacklevel=3)(func)
|
| 176 |
+
return func
|
| 177 |
+
|
| 178 |
+
if func is None:
|
| 179 |
+
return inner
|
| 180 |
+
return inner(func)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def impl_abstract(qualname, *, func=None):
|
| 184 |
+
r"""Register an abstract implementation for this operator.
|
| 185 |
+
|
| 186 |
+
An "abstract implementation" specifies the behavior of this operator on
|
| 187 |
+
Tensors that carry no data. Given some input Tensors with certain properties
|
| 188 |
+
(sizes/strides/storage_offset/device), it specifies what the properties of
|
| 189 |
+
the output Tensors are.
|
| 190 |
+
|
| 191 |
+
The abstract implementation has the same signature as the operator.
|
| 192 |
+
It is run for both FakeTensors and meta tensors. To write an abstract
|
| 193 |
+
implementation, assume that all Tensor inputs to the operator are
|
| 194 |
+
regular CPU/CUDA/Meta tensors, but they do not have storage, and
|
| 195 |
+
you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
|
| 196 |
+
The abstract implementation must consist of only PyTorch operations
|
| 197 |
+
(and may not directly access the storage or data of any input or
|
| 198 |
+
intermediate Tensors).
|
| 199 |
+
|
| 200 |
+
This API may be used as a decorator (see examples).
|
| 201 |
+
|
| 202 |
+
For a detailed guide on custom ops, please see
|
| 203 |
+
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
|
| 204 |
+
|
| 205 |
+
Examples::
|
| 206 |
+
>>> import numpy as np
|
| 207 |
+
>>> from torch import Tensor
|
| 208 |
+
>>>
|
| 209 |
+
>>> # Example 1: an operator without data-dependent output shape
|
| 210 |
+
>>> @torch._custom_ops.custom_op("mylibrary::custom_linear")
|
| 211 |
+
>>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
|
| 212 |
+
>>> raise NotImplementedError
|
| 213 |
+
>>>
|
| 214 |
+
>>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear")
|
| 215 |
+
>>> def custom_linear_abstract(x, weight):
|
| 216 |
+
>>> assert x.dim() == 2
|
| 217 |
+
>>> assert weight.dim() == 2
|
| 218 |
+
>>> assert bias.dim() == 1
|
| 219 |
+
>>> assert x.shape[1] == weight.shape[1]
|
| 220 |
+
>>> assert weight.shape[0] == bias.shape[0]
|
| 221 |
+
>>> assert x.device == weight.device
|
| 222 |
+
>>>
|
| 223 |
+
>>> return (x @ weight.t()) + bias
|
| 224 |
+
>>>
|
| 225 |
+
>>> # Example 2: an operator with data-dependent output shape
|
| 226 |
+
>>> @torch._custom_ops.custom_op('mylibrary::custom_nonzero')
|
| 227 |
+
>>> def custom_nonzero(x: Tensor) -> Tensor:
|
| 228 |
+
>>> ...
|
| 229 |
+
>>>
|
| 230 |
+
>>> @torch._custom_ops.impl_abstract("mylibrary::custom_nonzero")
|
| 231 |
+
>>> def custom_nonzero_abstract(x):
|
| 232 |
+
>>> # Number of nonzero-elements is data-dependent.
|
| 233 |
+
>>> # Since we cannot peek at the data in an abstract impl,
|
| 234 |
+
>>> # we use the ctx object to construct a new symint that
|
| 235 |
+
>>> # represents the data-dependent size.
|
| 236 |
+
>>> ctx = torch._custom_ops.get_ctx()
|
| 237 |
+
>>> nnz = ctx.create_unbacked_symint()
|
| 238 |
+
>>> shape = [x.dim(), nnz]
|
| 239 |
+
>>> result = x.new_empty(shape, dtype=torch.long)
|
| 240 |
+
>>> return result
|
| 241 |
+
>>>
|
| 242 |
+
>>> @torch._custom_ops.impl("mylibrary::custom_nonzero")
|
| 243 |
+
>>> def custom_nonzero_impl(x):
|
| 244 |
+
>>> x_np = to_numpy(x)
|
| 245 |
+
>>> res = np.stack(np.nonzero(x_np), axis=1)
|
| 246 |
+
>>> # unbacked symbolic ints in PyTorch must be >= 2, so we
|
| 247 |
+
>>> # constrain the range to at least 2
|
| 248 |
+
>>> if res.shape[0] <= 1:
|
| 249 |
+
>>> raise RuntimeError("not supported")
|
| 250 |
+
>>> return torch.tensor(res, device=x.device)
|
| 251 |
+
|
| 252 |
+
"""
|
| 253 |
+
import torch.library
|
| 254 |
+
|
| 255 |
+
return torch.library.register_fake(qualname, func, _stacklevel=2)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def impl_save_for_backward(qualname, *, func=None):
|
| 259 |
+
r"""Register a function that tells us what to save for backward.
|
| 260 |
+
|
| 261 |
+
Please see :func:`impl_backward` for more details.
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
def inner(func):
|
| 265 |
+
custom_op = _find_custom_op(qualname, also_check_torch_library=True)
|
| 266 |
+
custom_op.impl_save_for_backward(_stacklevel=3)(func)
|
| 267 |
+
return func
|
| 268 |
+
|
| 269 |
+
if func is None:
|
| 270 |
+
return inner
|
| 271 |
+
return inner(func)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def impl_backward(qualname, output_differentiability=None, *, func=None):
|
| 275 |
+
r"""Registers a backward formula for an operator.
|
| 276 |
+
|
| 277 |
+
In order for an operator to work with autograd, you need to register
|
| 278 |
+
a backward formula. There are two pieces to this:
|
| 279 |
+
1. You must give us a function to specify what to save for backward.
|
| 280 |
+
Call this the "save for backward" function.
|
| 281 |
+
2. You must give us a function that computes gradients. Call this the
|
| 282 |
+
"backward" function.
|
| 283 |
+
|
| 284 |
+
Use `impl_save_for_backward` to define a "save for backward" function
|
| 285 |
+
that specifies what gets saved for backward. The function should accept
|
| 286 |
+
two arguments ``(inputs, output)`` and return the quantities to be saved
|
| 287 |
+
for backward.
|
| 288 |
+
|
| 289 |
+
During runtime, when you call the operator in a forwards pass, PyTorch
|
| 290 |
+
will invoke the "save for backward" function with the inputs and output
|
| 291 |
+
of the operator.
|
| 292 |
+
|
| 293 |
+
Use `impl_backward` to define the "backward" function. The backward
|
| 294 |
+
function must accept ``(ctx, saved, *grads)``:
|
| 295 |
+
- ``ctx`` is a context object where we may provide information
|
| 296 |
+
- ``saved`` is exactly what gets returned from the "save for backward"
|
| 297 |
+
function
|
| 298 |
+
- ``grads`` is one or more gradients. The number of gradients matches
|
| 299 |
+
the number of outputs of the operator.
|
| 300 |
+
|
| 301 |
+
The backward function must return a dict that maps the name of
|
| 302 |
+
an input to the operator to its corresponding gradient. All inputs that
|
| 303 |
+
were declared to be Tensors in the operator definition must be accounted
|
| 304 |
+
for in the dict. The gradient may be a Tensor or None.
|
| 305 |
+
|
| 306 |
+
For a detailed guide on custom ops, please see
|
| 307 |
+
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
|
| 308 |
+
|
| 309 |
+
"""
|
| 310 |
+
|
| 311 |
+
def inner(func):
|
| 312 |
+
custom_op = _find_custom_op(qualname, also_check_torch_library=True)
|
| 313 |
+
custom_op.impl_backward(output_differentiability, _stacklevel=3)(func)
|
| 314 |
+
return func
|
| 315 |
+
|
| 316 |
+
if func is None:
|
| 317 |
+
return inner
|
| 318 |
+
return inner(func)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def _destroy(qualname):
|
| 322 |
+
"""De-registers a custom op. For testing purposes only"""
|
| 323 |
+
custom_op = _find_custom_op(qualname)
|
| 324 |
+
custom_op._destroy()
|
.venv/lib/python3.11/site-packages/torch/_deploy.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import io
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
|
| 6 |
+
from torch.package._package_pickler import create_pickler
|
| 7 |
+
from torch.package._package_unpickler import PackageUnpickler
|
| 8 |
+
from torch.serialization import _maybe_decode_ascii
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _save_storages(importer, obj):
|
| 12 |
+
serialized_storages = []
|
| 13 |
+
serialized_dtypes = []
|
| 14 |
+
|
| 15 |
+
importer = importer if isinstance(importer, torch.package.PackageImporter) else None
|
| 16 |
+
importers: Importer
|
| 17 |
+
if importer is not None:
|
| 18 |
+
importers = OrderedImporter(importer, sys_importer)
|
| 19 |
+
else:
|
| 20 |
+
importers = sys_importer
|
| 21 |
+
|
| 22 |
+
def persistent_id(obj):
|
| 23 |
+
if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
|
| 24 |
+
if isinstance(obj, torch.storage.TypedStorage):
|
| 25 |
+
# TODO: Once we decide to break serialization FC, we can
|
| 26 |
+
# remove this case
|
| 27 |
+
dtype = obj.dtype
|
| 28 |
+
else:
|
| 29 |
+
dtype = torch.uint8
|
| 30 |
+
|
| 31 |
+
serialized_storages.append(obj)
|
| 32 |
+
serialized_dtypes.append(dtype)
|
| 33 |
+
return ("storage", len(serialized_storages) - 1)
|
| 34 |
+
|
| 35 |
+
if hasattr(obj, "__reduce_deploy__"):
|
| 36 |
+
if _serialized_reduces.get(id(obj)) is None:
|
| 37 |
+
_serialized_reduces[id(obj)] = (
|
| 38 |
+
"reduce_deploy",
|
| 39 |
+
id(obj),
|
| 40 |
+
*obj.__reduce_deploy__(importers),
|
| 41 |
+
)
|
| 42 |
+
return _serialized_reduces[id(obj)]
|
| 43 |
+
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
# Write the pickle data for `obj`
|
| 47 |
+
data_buf = io.BytesIO()
|
| 48 |
+
pickler = create_pickler(data_buf, importers)
|
| 49 |
+
pickler.persistent_id = persistent_id
|
| 50 |
+
pickler.dump(obj)
|
| 51 |
+
data_value = data_buf.getvalue()
|
| 52 |
+
return (
|
| 53 |
+
data_value,
|
| 54 |
+
serialized_storages,
|
| 55 |
+
serialized_dtypes,
|
| 56 |
+
importer.zip_reader if importer else None,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
|
| 61 |
+
def persistent_load(saved_id):
|
| 62 |
+
assert isinstance(saved_id, tuple)
|
| 63 |
+
typename = _maybe_decode_ascii(saved_id[0])
|
| 64 |
+
data = saved_id[1:]
|
| 65 |
+
|
| 66 |
+
if typename == "storage":
|
| 67 |
+
# TODO: Once we decide to break serialization FC, we can
|
| 68 |
+
# stop wrapping with TypedStorage
|
| 69 |
+
storage = serialized_storages[data[0]]
|
| 70 |
+
dtype = serialized_dtypes[data[0]]
|
| 71 |
+
return torch.storage.TypedStorage(
|
| 72 |
+
wrap_storage=storage.untyped(), dtype=dtype
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
if typename == "reduce_deploy":
|
| 76 |
+
reduce_id, func, args = data
|
| 77 |
+
if reduce_id not in _loaded_reduces:
|
| 78 |
+
_loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
|
| 79 |
+
return _loaded_reduces[reduce_id]
|
| 80 |
+
|
| 81 |
+
return None
|
| 82 |
+
|
| 83 |
+
importer: Importer
|
| 84 |
+
if zip_reader is not None:
|
| 85 |
+
importer = OrderedImporter(_get_package(zip_reader), sys_importer)
|
| 86 |
+
else:
|
| 87 |
+
importer = sys_importer
|
| 88 |
+
|
| 89 |
+
unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
|
| 90 |
+
unpickler.persistent_load = persistent_load # type: ignore[method-assign]
|
| 91 |
+
result = _deploy_objects[id] = unpickler.load()
|
| 92 |
+
return result
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _get_package(zip_reader):
|
| 96 |
+
if zip_reader not in _raw_packages:
|
| 97 |
+
_raw_packages[zip_reader] = PackageImporter(zip_reader)
|
| 98 |
+
return _raw_packages[zip_reader]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
_raw_packages: dict = {}
|
| 102 |
+
_deploy_objects: dict = {}
|
| 103 |
+
_serialized_reduces: dict = {}
|
| 104 |
+
_loaded_reduces: dict = {}
|
.venv/lib/python3.11/site-packages/torch/_guards.py
ADDED
|
@@ -0,0 +1,925 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import contextlib
|
| 5 |
+
import dataclasses
|
| 6 |
+
import enum
|
| 7 |
+
import functools
|
| 8 |
+
import logging
|
| 9 |
+
import threading
|
| 10 |
+
import traceback
|
| 11 |
+
import unittest.mock
|
| 12 |
+
import weakref
|
| 13 |
+
from abc import abstractmethod
|
| 14 |
+
from contextlib import contextmanager
|
| 15 |
+
from typing import (
|
| 16 |
+
Any,
|
| 17 |
+
Callable,
|
| 18 |
+
Dict,
|
| 19 |
+
Generic,
|
| 20 |
+
List,
|
| 21 |
+
NamedTuple,
|
| 22 |
+
Optional,
|
| 23 |
+
Set,
|
| 24 |
+
Tuple,
|
| 25 |
+
TYPE_CHECKING,
|
| 26 |
+
TypeVar,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
from torch._C._dynamo.eval_frame import set_context_frame # noqa: F401
|
| 30 |
+
from torch.utils import _pytree as pytree
|
| 31 |
+
from torch.utils._traceback import CapturedTraceback
|
| 32 |
+
from torch.utils.weak import WeakTensorKeyDictionary
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
log = logging.getLogger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
if TYPE_CHECKING:
|
| 39 |
+
import sympy
|
| 40 |
+
|
| 41 |
+
# Import the following modules during type checking to enable code intelligence features,
|
| 42 |
+
# such as auto-completion in tools like pylance, even when these modules are not explicitly
|
| 43 |
+
# imported in user code.
|
| 44 |
+
import torch
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
"""
|
| 48 |
+
torch._guards is the definitional source of truth for general purpose guard structures.
|
| 49 |
+
|
| 50 |
+
An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions,
|
| 51 |
+
and no guard installation notions here.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class CompileId(NamedTuple):
|
| 56 |
+
frame_id: int
|
| 57 |
+
# This id is per-frame, and counts how many times we've compiled this
|
| 58 |
+
# frame. This could have been a global id but having this be per-frame
|
| 59 |
+
# gives you a better intuitive sense for how many recompiles have occurred
|
| 60 |
+
# so far.
|
| 61 |
+
frame_compile_id: int
|
| 62 |
+
# TODO: consider also tracking the recompilation count
|
| 63 |
+
|
| 64 |
+
def __str__(self):
|
| 65 |
+
return f"{self.frame_id}/{self.frame_compile_id}"
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class TraceId(NamedTuple):
|
| 69 |
+
compile_id: CompileId
|
| 70 |
+
# This starts off as 0, and every time we restart analysis it goes
|
| 71 |
+
# up by one
|
| 72 |
+
attempt: int
|
| 73 |
+
|
| 74 |
+
def __str__(self):
|
| 75 |
+
if self.attempt == 0:
|
| 76 |
+
return str(self.compile_id)
|
| 77 |
+
else:
|
| 78 |
+
return f"{self.compile_id}_{self.attempt}"
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class GuardSource(enum.Enum):
|
| 82 |
+
LOCAL = 0
|
| 83 |
+
GLOBAL = 1
|
| 84 |
+
LOCAL_SPECIALIZED_NN_MODULE = 2
|
| 85 |
+
GLOBAL_SPECIALIZED_NN_MODULE = 3
|
| 86 |
+
CONSTANT = 4
|
| 87 |
+
RANDOM_VALUE = 5
|
| 88 |
+
SHAPE_ENV = 6
|
| 89 |
+
LOCAL_FSDP_MODULE = 7
|
| 90 |
+
GLOBAL_FSDP_MODULE = 8
|
| 91 |
+
BACKWARD_STATE = 9
|
| 92 |
+
EPHEMERAL = 10
|
| 93 |
+
SYNTHETIC_LOCAL = 11
|
| 94 |
+
LOCAL_UNSPECIALIZED_NN_MODULE = 12
|
| 95 |
+
GLOBAL_UNSPECIALIZED_NN_MODULE = 13
|
| 96 |
+
LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 14
|
| 97 |
+
GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 15
|
| 98 |
+
|
| 99 |
+
def is_fsdp_module(self) -> bool:
|
| 100 |
+
return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)
|
| 101 |
+
|
| 102 |
+
def is_specialized_nn_module(self) -> bool:
|
| 103 |
+
return (
|
| 104 |
+
self
|
| 105 |
+
in (
|
| 106 |
+
GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
|
| 107 |
+
GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
|
| 108 |
+
)
|
| 109 |
+
# TODO (anijain2305) - Investigate why is_fsdp_module required.
|
| 110 |
+
or self.is_fsdp_module()
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def is_unspecialized_nn_module(self) -> bool:
|
| 114 |
+
return self in (
|
| 115 |
+
GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
|
| 116 |
+
GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
|
| 117 |
+
GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
|
| 118 |
+
GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
def is_unspecialized_builtin_nn_module(self) -> bool:
|
| 122 |
+
return self in (
|
| 123 |
+
GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
|
| 124 |
+
GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
def is_local(self):
|
| 128 |
+
return self in (
|
| 129 |
+
GuardSource.LOCAL,
|
| 130 |
+
GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
|
| 131 |
+
GuardSource.LOCAL_FSDP_MODULE,
|
| 132 |
+
GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
|
| 133 |
+
GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
"""
|
| 138 |
+
Base class for a "GuardBuilder" role.
|
| 139 |
+
|
| 140 |
+
The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little
|
| 141 |
+
confusing, as its not a builder, but for the sake of avoiding a lot of renames and keeping the original reference
|
| 142 |
+
to torchdynamo's GuardBuilder.
|
| 143 |
+
|
| 144 |
+
Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based
|
| 145 |
+
on GuardSource's select function.
|
| 146 |
+
|
| 147 |
+
There is value in keeping this GuardBuilderBase empty to keep layering clean.
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class GuardBuilderBase:
|
| 152 |
+
pass
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class ShapeGuard(NamedTuple):
|
| 156 |
+
expr: sympy.Expr
|
| 157 |
+
stack: CapturedTraceback
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@dataclasses.dataclass
|
| 161 |
+
class Guard:
|
| 162 |
+
# originating_source is the source that called the make_guard method to
|
| 163 |
+
# construct this guard object. The property name specifies what exactly it
|
| 164 |
+
# is the guard is guarding on. The meaning of the name is dependent on the
|
| 165 |
+
# create_fn; you must look at the use-site inside create_fn to know what
|
| 166 |
+
# name means.
|
| 167 |
+
#
|
| 168 |
+
# That being said, although you might think this is just a "name", name is
|
| 169 |
+
# usually an arbitrary Python expression that will be evaluated with all
|
| 170 |
+
# globals (and locals, if you create a LOCAL guard) to extract the Python
|
| 171 |
+
# object that we want to perform guard tests on. This evaluation
|
| 172 |
+
# typically happens in GuardBuilder.eval. In these cases, name is
|
| 173 |
+
# typically produced by originating_source.name() (not to be confused with
|
| 174 |
+
# GuardSource - the property source).
|
| 175 |
+
#
|
| 176 |
+
# Occasionally, name is not a valid Python expression; sometimes
|
| 177 |
+
# it is meaningless. Example create_fns that are like this include
|
| 178 |
+
# GRAD_MODE and SHAPE_ENV.
|
| 179 |
+
originating_source: Source
|
| 180 |
+
create_fn: Callable[[GuardBuilderBase, Guard], None]
|
| 181 |
+
|
| 182 |
+
# Export only. These values are written to at time of guard check_fn creation.
|
| 183 |
+
guard_types: Optional[List[str]] = None
|
| 184 |
+
code_list: Optional[List[str]] = None
|
| 185 |
+
obj_weakref: Optional[object] = None
|
| 186 |
+
guarded_class_weakref: Optional[type] = None
|
| 187 |
+
|
| 188 |
+
stack: Optional[CapturedTraceback] = None
|
| 189 |
+
user_stack: Optional[traceback.StackSummary] = None
|
| 190 |
+
_hash: Optional[int] = None
|
| 191 |
+
|
| 192 |
+
def __hash__(self):
|
| 193 |
+
if self._hash is None:
|
| 194 |
+
self._hash = hash((self.name, self.source, id(self.create_fn)))
|
| 195 |
+
return self._hash
|
| 196 |
+
|
| 197 |
+
def sort_key(self):
|
| 198 |
+
# Put the duplicate input guards at the end. The duplicate guards have
|
| 199 |
+
# two sources while guard.name only considers one source.
|
| 200 |
+
from torch._dynamo.guards import GuardBuilder
|
| 201 |
+
|
| 202 |
+
is_duplicate_input = (
|
| 203 |
+
isinstance(self.create_fn, functools.partial)
|
| 204 |
+
and self.create_fn.func is GuardBuilder.DUPLICATE_INPUT
|
| 205 |
+
)
|
| 206 |
+
return (
|
| 207 |
+
is_duplicate_input,
|
| 208 |
+
self.source.value if self.source else -1,
|
| 209 |
+
len(self.name),
|
| 210 |
+
self.name,
|
| 211 |
+
self.inner_create_fn().__code__.co_firstlineno,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
def __lt__(self, other):
|
| 215 |
+
return self.sort_key() < other.sort_key()
|
| 216 |
+
|
| 217 |
+
def inner_create_fn(self):
|
| 218 |
+
if isinstance(self.create_fn, functools.partial):
|
| 219 |
+
return self.create_fn.func
|
| 220 |
+
else:
|
| 221 |
+
return self.create_fn
|
| 222 |
+
|
| 223 |
+
@property
|
| 224 |
+
def name(self) -> str:
|
| 225 |
+
return self.originating_source.name()
|
| 226 |
+
|
| 227 |
+
@property
|
| 228 |
+
def source(self) -> GuardSource:
|
| 229 |
+
return self.originating_source.guard_source()
|
| 230 |
+
|
| 231 |
+
@staticmethod
|
| 232 |
+
def weakref_to_str(obj_weakref):
|
| 233 |
+
"""
|
| 234 |
+
This is a workaround of a Python weakref bug.
|
| 235 |
+
|
| 236 |
+
`obj_weakref` is instance returned by `weakref.ref`,
|
| 237 |
+
`str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g:
|
| 238 |
+
|
| 239 |
+
class MyConfig(dict):
|
| 240 |
+
def __getattr__(self, x):
|
| 241 |
+
return self[x]
|
| 242 |
+
|
| 243 |
+
obj = MyConfig(offset=5)
|
| 244 |
+
obj_weakref = weakref.ref(obj)
|
| 245 |
+
str(obj_weakref) # raise error: KeyError: '__name__'
|
| 246 |
+
"""
|
| 247 |
+
if isinstance(obj_weakref, weakref.ReferenceType):
|
| 248 |
+
obj = obj_weakref()
|
| 249 |
+
if obj is not None:
|
| 250 |
+
return f"<weakref at {hex(id(obj_weakref))}; to '{obj.__class__.__name__}' at {hex(id(obj))}>"
|
| 251 |
+
else:
|
| 252 |
+
return f"<weakref at {hex(id(obj_weakref))}; dead>"
|
| 253 |
+
else:
|
| 254 |
+
return str(obj_weakref)
|
| 255 |
+
|
| 256 |
+
def __repr__(self):
|
| 257 |
+
s = f"""
|
| 258 |
+
{self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
|
| 259 |
+
{{
|
| 260 |
+
'guard_types': {self.guard_types},
|
| 261 |
+
'code': {self.code_list},
|
| 262 |
+
'obj_weakref': {self.weakref_to_str(self.obj_weakref)}
|
| 263 |
+
'guarded_class': {self.guarded_class_weakref}
|
| 264 |
+
}}
|
| 265 |
+
"""
|
| 266 |
+
return s
|
| 267 |
+
|
| 268 |
+
def __str__(self):
|
| 269 |
+
output = f"Name: {repr(self.name)}\n"
|
| 270 |
+
source = self.source.name.lower() if self.source else ""
|
| 271 |
+
output += f" Source: {source}\n"
|
| 272 |
+
output += f" Create Function: {self.inner_create_fn().__name__}\n"
|
| 273 |
+
output += f" Guard Types: {self.guard_types}\n"
|
| 274 |
+
output += f" Code List: {self.code_list}\n"
|
| 275 |
+
output += f" Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n"
|
| 276 |
+
output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n"
|
| 277 |
+
return output
|
| 278 |
+
|
| 279 |
+
def create(self, builder: GuardBuilderBase):
|
| 280 |
+
try:
|
| 281 |
+
return self.create_fn(builder, self)
|
| 282 |
+
except Exception:
|
| 283 |
+
log.exception("Error while creating guard:\n%s", str(self).rstrip())
|
| 284 |
+
if self.stack:
|
| 285 |
+
log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
|
| 286 |
+
raise
|
| 287 |
+
|
| 288 |
+
def is_specialized_nn_module(self):
|
| 289 |
+
return self.source.is_specialized_nn_module()
|
| 290 |
+
|
| 291 |
+
def is_fsdp_module(self):
|
| 292 |
+
return self.source.is_fsdp_module()
|
| 293 |
+
|
| 294 |
+
def is_local(self):
|
| 295 |
+
return self.source.is_local()
|
| 296 |
+
|
| 297 |
+
def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref):
|
| 298 |
+
if not self.guard_types:
|
| 299 |
+
self.guard_types = []
|
| 300 |
+
|
| 301 |
+
self.guard_types.append(guard_type)
|
| 302 |
+
|
| 303 |
+
assert self.guarded_class_weakref in (
|
| 304 |
+
guarded_class,
|
| 305 |
+
None,
|
| 306 |
+
), "Guarded class id must be identical, or None"
|
| 307 |
+
self.guarded_class_weakref = guarded_class
|
| 308 |
+
|
| 309 |
+
if not self.code_list:
|
| 310 |
+
self.code_list = code_list
|
| 311 |
+
else:
|
| 312 |
+
self.code_list.extend(code_list)
|
| 313 |
+
|
| 314 |
+
# Some objects are ephemeral, e.g., list[slice(1, 2)]. If we have
|
| 315 |
+
# multiple guards on the same object, the weakref can die between the
|
| 316 |
+
# invocation of set_export_info calls. So a dead weakref is also
|
| 317 |
+
# acceptable.
|
| 318 |
+
assert (
|
| 319 |
+
self.obj_weakref in (obj_weakref, None)
|
| 320 |
+
or callable(self.obj_weakref)
|
| 321 |
+
and self.obj_weakref() is None
|
| 322 |
+
), "Guarded object must be identical, None or ephemeral (dead weakref)"
|
| 323 |
+
self.obj_weakref = obj_weakref
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
T = TypeVar("T")
|
| 327 |
+
|
| 328 |
+
"""
|
| 329 |
+
Parent structure for guard env expressions.
|
| 330 |
+
A GuardEnvExpr can have any subtype.
|
| 331 |
+
Note: All subtypes must be handled exhaustively in
|
| 332 |
+
torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
|
| 333 |
+
"""
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
@dataclasses.dataclass
|
| 337 |
+
class GuardEnvExpr:
|
| 338 |
+
pass
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
"""
|
| 342 |
+
A class representing a pair of duplicate inputs.
|
| 343 |
+
input_pos_a and input_pos_b are input positions we have deduped.
|
| 344 |
+
"""
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
@dataclasses.dataclass
|
| 348 |
+
class DuplicateInputs(GuardEnvExpr):
|
| 349 |
+
input_source_a: Source
|
| 350 |
+
input_source_b: Source
|
| 351 |
+
|
| 352 |
+
def __post_init__(self):
|
| 353 |
+
assert self.input_source_a != self.input_source_b
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
"""
|
| 357 |
+
Checkpointable is an interface for driving state snapshotting, left purposely vague for now.
|
| 358 |
+
|
| 359 |
+
copy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that
|
| 360 |
+
can also be taken in at restore_graphstate(T) calls.
|
| 361 |
+
|
| 362 |
+
When to snapshot, is, at the moment, an implementation detail of upstream callers. Checkpointable
|
| 363 |
+
does not provide any garuantees around consistency, idempotency, or safety of calling its APIs, yet.
|
| 364 |
+
|
| 365 |
+
In the future, it will have a closer coupling to a generic Checkpoint management system.
|
| 366 |
+
"""
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class Checkpointable(Generic[T]):
|
| 370 |
+
@abstractmethod
|
| 371 |
+
def copy_graphstate(self) -> T: ...
|
| 372 |
+
|
| 373 |
+
@abstractmethod
|
| 374 |
+
def restore_graphstate(self, state: T): ...
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
class GuardsCheckpointState:
|
| 378 |
+
"""
|
| 379 |
+
The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext
|
| 380 |
+
"""
|
| 381 |
+
|
| 382 |
+
dynamo_guards: Set[Guard] = set()
|
| 383 |
+
|
| 384 |
+
def __init__(self, dynamo_guards):
|
| 385 |
+
self.dynamo_guards = dynamo_guards
|
| 386 |
+
|
| 387 |
+
def diff(self, other):
|
| 388 |
+
"""
|
| 389 |
+
Produces a delta against another GuardsCheckpointState.
|
| 390 |
+
|
| 391 |
+
Returns None if no delta is found, otherwise, return a set() of mismatched
|
| 392 |
+
Guard type objects.
|
| 393 |
+
"""
|
| 394 |
+
r = self.dynamo_guards.difference(other.dynamo_guards)
|
| 395 |
+
if len(r) == 0:
|
| 396 |
+
return None
|
| 397 |
+
return r
|
| 398 |
+
|
| 399 |
+
def __eq__(self, other):
|
| 400 |
+
return self.diff(other) is None
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
class ModuleContextCheckpointState:
|
| 404 |
+
nn_modules: Dict[str, torch.nn.Module] = {}
|
| 405 |
+
|
| 406 |
+
def __init__(self, nn_modules):
|
| 407 |
+
self.nn_modules = nn_modules
|
| 408 |
+
|
| 409 |
+
def diff(self, other):
|
| 410 |
+
"""
|
| 411 |
+
Produces a delta against another ModuleContextCheckpointState.
|
| 412 |
+
|
| 413 |
+
Returns None if no delta is found, otherwise, return a set() of mismatched
|
| 414 |
+
module key names.
|
| 415 |
+
"""
|
| 416 |
+
r = set(self.nn_modules.keys()).difference(set(other.nn_modules.keys()))
|
| 417 |
+
if len(r) == 0:
|
| 418 |
+
return None
|
| 419 |
+
return r
|
| 420 |
+
|
| 421 |
+
def __eq__(self, other):
|
| 422 |
+
return self.diff(other) is None
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
|
| 426 |
+
def __init__(self) -> None:
|
| 427 |
+
self.nn_modules: Dict[str, Any] = {}
|
| 428 |
+
|
| 429 |
+
def copy_graphstate(self):
|
| 430 |
+
return ModuleContextCheckpointState(dict(self.nn_modules))
|
| 431 |
+
|
| 432 |
+
def restore_graphstate(self, state):
|
| 433 |
+
assert isinstance(state, ModuleContextCheckpointState)
|
| 434 |
+
self.nn_modules = state.nn_modules
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
class GlobalContextCheckpointState:
|
| 438 |
+
global_state: Dict[str, Tuple[Callable, ...]] = {}
|
| 439 |
+
|
| 440 |
+
def __init__(self, global_states):
|
| 441 |
+
self.global_state = global_states
|
| 442 |
+
|
| 443 |
+
def diff(self, other):
|
| 444 |
+
"""
|
| 445 |
+
Produces a delta against another GlobalContextCheckpointState.
|
| 446 |
+
|
| 447 |
+
Returns None if no delta is found, otherwise, return a set() of mismatched
|
| 448 |
+
global key names.
|
| 449 |
+
"""
|
| 450 |
+
r = set(self.global_state.keys()).difference(set(other.global_state.keys()))
|
| 451 |
+
if len(r) == 0:
|
| 452 |
+
return None
|
| 453 |
+
return r
|
| 454 |
+
|
| 455 |
+
def __eq__(self, other):
|
| 456 |
+
return self.diff(other) is None
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
|
| 460 |
+
"""
|
| 461 |
+
This keeps track of the global torch state during tracing of a function.
|
| 462 |
+
For example, torch.is_grad_enabled.
|
| 463 |
+
"""
|
| 464 |
+
|
| 465 |
+
_supported_global_states = {
|
| 466 |
+
"grad_enabled",
|
| 467 |
+
"torch_function_enabled",
|
| 468 |
+
"autocast_enabled",
|
| 469 |
+
"autocast_cpu_enabled",
|
| 470 |
+
"autocast_gpu_dtype",
|
| 471 |
+
"autocast_cpu_dtype",
|
| 472 |
+
"autocast_cache_enabled",
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
def __init__(self) -> None:
|
| 476 |
+
self.global_state: Dict[str, Tuple[Callable, ...]] = {}
|
| 477 |
+
|
| 478 |
+
def copy_graphstate(self):
|
| 479 |
+
return GlobalContextCheckpointState(dict(self.global_state))
|
| 480 |
+
|
| 481 |
+
def restore_graphstate(self, state):
|
| 482 |
+
assert isinstance(state, GlobalContextCheckpointState)
|
| 483 |
+
self.global_state = state.global_state
|
| 484 |
+
assert (
|
| 485 |
+
len(self.global_state) == len(self._supported_global_states)
|
| 486 |
+
and set(self.global_state.keys()) == self._supported_global_states
|
| 487 |
+
), "Global state mismatch"
|
| 488 |
+
for func, args in self.global_state.values():
|
| 489 |
+
func(args)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
"""
|
| 493 |
+
A GuardsContext is a checkpointable representation of all the guards in the current tracing
|
| 494 |
+
context. It's lifecycle is bound 1:1 to the tracing context, and it should never be instantiated
|
| 495 |
+
directly outside of it. For passing around internal state representations of this object,
|
| 496 |
+
prefer to extract them with copy_graphstate to produce a GuardsCheckpointState.
|
| 497 |
+
"""
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
# Like a Set[Guard] but will record the user stack on all guards at the
|
| 501 |
+
# time they were installed at their destination
|
| 502 |
+
class GuardsSet:
|
| 503 |
+
def __init__(self, inner=None):
|
| 504 |
+
if inner is None:
|
| 505 |
+
inner = set()
|
| 506 |
+
self.inner = inner
|
| 507 |
+
|
| 508 |
+
def __iter__(self):
|
| 509 |
+
return iter(self.inner)
|
| 510 |
+
|
| 511 |
+
def __len__(self):
|
| 512 |
+
return len(self.inner)
|
| 513 |
+
|
| 514 |
+
# Subtraction along with bool is typically used to determine the delta of
|
| 515 |
+
# added guards between checkpoints for higher order ops
|
| 516 |
+
def __sub__(self, other):
|
| 517 |
+
return GuardsSet(self.inner - other.inner)
|
| 518 |
+
|
| 519 |
+
def __bool__(self):
|
| 520 |
+
return bool(self.inner)
|
| 521 |
+
|
| 522 |
+
def add(self, guard: Guard, *, collect_debug_stack=True, skip=0):
|
| 523 |
+
if guard in self.inner:
|
| 524 |
+
return
|
| 525 |
+
if collect_debug_stack:
|
| 526 |
+
if guard.stack is None:
|
| 527 |
+
guard.stack = CapturedTraceback.extract(skip=1 + skip)
|
| 528 |
+
if guard.user_stack is None:
|
| 529 |
+
guard.user_stack = TracingContext.extract_stack()
|
| 530 |
+
self.inner.add(guard)
|
| 531 |
+
|
| 532 |
+
def update(self, *others: Set[Guard]):
|
| 533 |
+
for o in others:
|
| 534 |
+
for g in o:
|
| 535 |
+
self.add(g, skip=1)
|
| 536 |
+
|
| 537 |
+
def remove_guards_with_source(self, source):
|
| 538 |
+
"""Delete all guards with a given source"""
|
| 539 |
+
self.inner = {g for g in self.inner if g.originating_source != source}
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
class GuardsContext(Checkpointable[GuardsCheckpointState]):
|
| 543 |
+
def __init__(self) -> None:
|
| 544 |
+
self.dynamo_guards: GuardsSet = GuardsSet()
|
| 545 |
+
self.aotautograd_guards: List[GuardEnvExpr] = []
|
| 546 |
+
|
| 547 |
+
def copy_graphstate(self):
|
| 548 |
+
return GuardsCheckpointState(set(self.dynamo_guards.inner))
|
| 549 |
+
|
| 550 |
+
def restore_graphstate(self, state):
|
| 551 |
+
# NB: "steals" the passed in state
|
| 552 |
+
assert isinstance(state, GuardsCheckpointState)
|
| 553 |
+
self.dynamo_guards = GuardsSet(state.dynamo_guards)
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
_TLS = threading.local()
|
| 557 |
+
|
| 558 |
+
"""
|
| 559 |
+
TracingContext is the source of truth for all currently accumulated information
|
| 560 |
+
needed to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems
|
| 561 |
+
are open to managing their own TracingContext with that in mind.
|
| 562 |
+
|
| 563 |
+
The purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid
|
| 564 |
+
having to plumb complex subsystems across multiple verticals.
|
| 565 |
+
|
| 566 |
+
Ex: A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor.
|
| 567 |
+
Accessing the current tracing context via
|
| 568 |
+
TracingContext.get() allows users to accumulate their own guards for processing, without needing to know how
|
| 569 |
+
to plumb objects back up to where frame interpretation happened.
|
| 570 |
+
|
| 571 |
+
Note that you can end up with multiple TracingContext for a single compilation
|
| 572 |
+
of a frame, as we reset the TracingContext whenever we restart analysis.
|
| 573 |
+
CompileContext is a more overarching context that encompasses multiple restarts.
|
| 574 |
+
"""
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
class CompileContext:
|
| 578 |
+
@staticmethod
|
| 579 |
+
def get() -> CompileContext:
|
| 580 |
+
assert _TLS.compile_context is not None
|
| 581 |
+
return _TLS.compile_context
|
| 582 |
+
|
| 583 |
+
@staticmethod
|
| 584 |
+
def try_get() -> Optional[CompileContext]:
|
| 585 |
+
return getattr(_TLS, "compile_context", None)
|
| 586 |
+
|
| 587 |
+
def __init__(self, compile_id):
|
| 588 |
+
assert compile_id is None or isinstance(compile_id, CompileId)
|
| 589 |
+
self.compile_id: Optional[CompileId] = compile_id
|
| 590 |
+
self.attempt = 0
|
| 591 |
+
|
| 592 |
+
@staticmethod
|
| 593 |
+
def current_compile_id():
|
| 594 |
+
self = CompileContext.try_get()
|
| 595 |
+
if self is None:
|
| 596 |
+
return None
|
| 597 |
+
return self.compile_id
|
| 598 |
+
|
| 599 |
+
@staticmethod
|
| 600 |
+
def current_trace_id():
|
| 601 |
+
self = CompileContext.try_get()
|
| 602 |
+
if self is None:
|
| 603 |
+
return None
|
| 604 |
+
if self.compile_id is None:
|
| 605 |
+
return None
|
| 606 |
+
return TraceId(self.compile_id, self.attempt)
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
class TracingContext:
|
| 610 |
+
"""
|
| 611 |
+
Provides the currently installed TracingContext, or None.
|
| 612 |
+
|
| 613 |
+
Note that it is a staticmethod, and invocations outside of `with tracing()` (see below), are valid but
|
| 614 |
+
will return None.
|
| 615 |
+
"""
|
| 616 |
+
|
| 617 |
+
@staticmethod
|
| 618 |
+
def try_get() -> Optional[TracingContext]:
|
| 619 |
+
return getattr(_TLS, "tracing_context", None)
|
| 620 |
+
|
| 621 |
+
@staticmethod
|
| 622 |
+
def get() -> TracingContext:
|
| 623 |
+
if ctx := TracingContext.try_get():
|
| 624 |
+
return ctx
|
| 625 |
+
raise RuntimeError(
|
| 626 |
+
"TracingContext.get() must be called within an ongoing trace."
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
def __init__(self, fake_mode):
|
| 630 |
+
self.guards_context = GuardsContext()
|
| 631 |
+
self.module_context = ModuleContext()
|
| 632 |
+
self.global_context = GlobalContext()
|
| 633 |
+
self.fake_mode = fake_mode
|
| 634 |
+
self.frame_summary_stack = []
|
| 635 |
+
# This is morally part of frame_summary_stack, but it is kept separate
|
| 636 |
+
# for clarity. As we process a frame, this variable gets updated
|
| 637 |
+
# to keep track of what line we are in the function. We make a
|
| 638 |
+
# function call, this gets cleared and the frame location is pushed
|
| 639 |
+
# to frame_summary_stack (prepping this variable for the inner frame's
|
| 640 |
+
# progress)
|
| 641 |
+
self.loc_in_frame = None
|
| 642 |
+
# this is only set after aot_autograd
|
| 643 |
+
self.fw_metadata = None
|
| 644 |
+
# this is only set after aot_autograd
|
| 645 |
+
self.aot_graph_name = None
|
| 646 |
+
self.params_flat = None
|
| 647 |
+
# this is for extended return calling convention from backend
|
| 648 |
+
# compiler to aot_autograd
|
| 649 |
+
# Per output, what the compiler specified stride of the output is,
|
| 650 |
+
# or None if no stride is known. This is always the HINT, it
|
| 651 |
+
# is never a SymInt (it would be better if it was a SymInt, but
|
| 652 |
+
# I can't conveniently get this from Inductor atm. Also, be
|
| 653 |
+
# careful not to accidentally induce guards on the SymInt if
|
| 654 |
+
# you ever do change this in aot_autograd.py; you should check
|
| 655 |
+
# on permutations preferentially.)
|
| 656 |
+
self.output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
|
| 657 |
+
# When this is True, whenever we encounter an int in Dynamo tracing,
|
| 658 |
+
# we will (1) force unspec it and (2) force it as a size-like unbacked
|
| 659 |
+
# integer. This is currently used when processing certain lists of
|
| 660 |
+
# ints that are known to be size-like and may have 0/1 entries that we
|
| 661 |
+
# must not specialize on.
|
| 662 |
+
self.force_unspec_int_unbacked_size_like = False
|
| 663 |
+
# See note [Tensor Fakification and Symbol Caching]
|
| 664 |
+
self.tensor_to_context = WeakTensorKeyDictionary()
|
| 665 |
+
|
| 666 |
+
# If this true, Aot Autograd will return output Fake Tensors with appropiate
|
| 667 |
+
# meta on the first invocation
|
| 668 |
+
# see note: [Returning Fake Tensors on First AOT Autograd Call]
|
| 669 |
+
self.fakify_first_call = False
|
| 670 |
+
|
| 671 |
+
def clear(self):
|
| 672 |
+
# Look at the note in output_graph.py in function `save_global_state`
|
| 673 |
+
# for the context on clearing global context.
|
| 674 |
+
self.global_context.global_state = {}
|
| 675 |
+
|
| 676 |
+
@staticmethod
|
| 677 |
+
@contextmanager
|
| 678 |
+
def patch(**kwargs):
|
| 679 |
+
prior = {}
|
| 680 |
+
ctx = TracingContext.get()
|
| 681 |
+
|
| 682 |
+
for key in kwargs.keys():
|
| 683 |
+
# KeyError on invalid entry
|
| 684 |
+
prior[key] = getattr(ctx, key)
|
| 685 |
+
for key, val in kwargs.items():
|
| 686 |
+
setattr(ctx, key, val)
|
| 687 |
+
try:
|
| 688 |
+
yield
|
| 689 |
+
finally:
|
| 690 |
+
for key, val in prior.items():
|
| 691 |
+
setattr(ctx, key, val)
|
| 692 |
+
|
| 693 |
+
@staticmethod
|
| 694 |
+
def extract_stack():
|
| 695 |
+
self = TracingContext.try_get()
|
| 696 |
+
if self is None:
|
| 697 |
+
return traceback.StackSummary()
|
| 698 |
+
stack = self.frame_summary_stack
|
| 699 |
+
if self.loc_in_frame is not None:
|
| 700 |
+
stack = stack + [self.loc_in_frame]
|
| 701 |
+
return traceback.StackSummary.from_list(stack)
|
| 702 |
+
|
| 703 |
+
# Call this when you want to call into some code that isn't necessarily
|
| 704 |
+
# associated with the current frame state
|
| 705 |
+
@staticmethod
|
| 706 |
+
@contextlib.contextmanager
|
| 707 |
+
def clear_frame():
|
| 708 |
+
tc = TracingContext.get()
|
| 709 |
+
with unittest.mock.patch.object(
|
| 710 |
+
tc, "frame_summary_stack", []
|
| 711 |
+
), unittest.mock.patch.object(tc, "loc_in_frame", None):
|
| 712 |
+
try:
|
| 713 |
+
yield
|
| 714 |
+
except Exception as e:
|
| 715 |
+
# Prevent real_stack from getting attached
|
| 716 |
+
#
|
| 717 |
+
# The invariant is that if an Exception as real_stack, we've
|
| 718 |
+
# appropriately attached a user stack and we no longer need to
|
| 719 |
+
# attach anything. Because we cannot conveniently interpose
|
| 720 |
+
# when an exception is thrown, we instead interpose everywhere
|
| 721 |
+
# we set what the user stack is set (using the context
|
| 722 |
+
# manager). However, our compiler stack does "tail calls"
|
| 723 |
+
# (when it calls into user compiler), at which point the
|
| 724 |
+
# parent exception frames would incorrectly attach an
|
| 725 |
+
# incorrect frame.
|
| 726 |
+
#
|
| 727 |
+
# However, if, somehow, someone raised an exception with this
|
| 728 |
+
# scope that had a stack (for example, because they are
|
| 729 |
+
# restoring the user stack state appropriately as they process
|
| 730 |
+
# node by node), we should respect it. Thus, we cannot
|
| 731 |
+
# unconditionally set None.
|
| 732 |
+
if not hasattr(e, "real_stack"):
|
| 733 |
+
e.real_stack = None # type: ignore[attr-defined]
|
| 734 |
+
raise
|
| 735 |
+
|
| 736 |
+
@staticmethod
|
| 737 |
+
@contextlib.contextmanager
|
| 738 |
+
def current_frame(frame_summary):
|
| 739 |
+
# frame_summary can be None to solely take advantage of real_stack
|
| 740 |
+
# attachment to thrown exceptions
|
| 741 |
+
tc = TracingContext.get()
|
| 742 |
+
if frame_summary is not None:
|
| 743 |
+
tc.frame_summary_stack.append(frame_summary)
|
| 744 |
+
old = tc.loc_in_frame
|
| 745 |
+
tc.loc_in_frame = None
|
| 746 |
+
try:
|
| 747 |
+
yield
|
| 748 |
+
except Exception as e:
|
| 749 |
+
if not hasattr(e, "real_stack"):
|
| 750 |
+
e.real_stack = tc.extract_stack() # type: ignore[attr-defined]
|
| 751 |
+
raise
|
| 752 |
+
finally:
|
| 753 |
+
if frame_summary is not None:
|
| 754 |
+
tc.frame_summary_stack.pop()
|
| 755 |
+
tc.loc_in_frame = old
|
| 756 |
+
|
| 757 |
+
@staticmethod
|
| 758 |
+
@contextlib.contextmanager
|
| 759 |
+
def report_output_strides():
|
| 760 |
+
tc = TracingContext.try_get()
|
| 761 |
+
if tc is None:
|
| 762 |
+
yield None
|
| 763 |
+
return
|
| 764 |
+
old_output_strides = tc.output_strides
|
| 765 |
+
tc.output_strides = []
|
| 766 |
+
try:
|
| 767 |
+
yield tc.output_strides
|
| 768 |
+
finally:
|
| 769 |
+
tc.output_strides = old_output_strides
|
| 770 |
+
|
| 771 |
+
@staticmethod
|
| 772 |
+
def set_current_loc(filename, lineno, frame_name):
|
| 773 |
+
TracingContext.get().loc_in_frame = traceback.FrameSummary(
|
| 774 |
+
filename, lineno, frame_name, lookup_line=False
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
@contextmanager
|
| 779 |
+
def compile_context(context: Optional[CompileContext]):
|
| 780 |
+
old_context = getattr(_TLS, "compile_context", None)
|
| 781 |
+
_TLS.compile_context = context
|
| 782 |
+
try:
|
| 783 |
+
yield context
|
| 784 |
+
finally:
|
| 785 |
+
if context is not None:
|
| 786 |
+
if context.compile_id is not None:
|
| 787 |
+
set_context_frame(
|
| 788 |
+
(
|
| 789 |
+
context.compile_id.frame_id,
|
| 790 |
+
context.compile_id.frame_compile_id,
|
| 791 |
+
context.attempt,
|
| 792 |
+
)
|
| 793 |
+
)
|
| 794 |
+
_TLS.compile_context = old_context
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
@contextmanager
|
| 798 |
+
def tracing(context: Optional[TracingContext]):
|
| 799 |
+
"""
|
| 800 |
+
This function installs the passed in tracing context as a dynamic scoped
|
| 801 |
+
global variable.
|
| 802 |
+
|
| 803 |
+
Calls to TracingContext.get() while not under a `with tracing()` context
|
| 804 |
+
will return None.
|
| 805 |
+
"""
|
| 806 |
+
old_context = getattr(_TLS, "tracing_context", None)
|
| 807 |
+
_TLS.tracing_context = context
|
| 808 |
+
try:
|
| 809 |
+
yield context
|
| 810 |
+
except Exception as e:
|
| 811 |
+
if not hasattr(e, "real_stack") and context is not None:
|
| 812 |
+
e.real_stack = context.extract_stack() # type: ignore[attr-defined]
|
| 813 |
+
raise
|
| 814 |
+
finally:
|
| 815 |
+
if (
|
| 816 |
+
context is not None
|
| 817 |
+
and context.fake_mode is not None
|
| 818 |
+
and context.fake_mode.shape_env is not None
|
| 819 |
+
):
|
| 820 |
+
context.fake_mode.shape_env.cleanup()
|
| 821 |
+
_TLS.tracing_context = old_context
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
# Subclasses can be found in torch/_dynamo/source.py
|
| 825 |
+
# TODO(voz): Consider a toplevel torch/_source.py
|
| 826 |
+
@dataclasses.dataclass(frozen=True)
|
| 827 |
+
class Source:
|
| 828 |
+
def is_dict_key(self):
|
| 829 |
+
return False
|
| 830 |
+
|
| 831 |
+
def is_ephemeral(self):
|
| 832 |
+
return False
|
| 833 |
+
|
| 834 |
+
def reconstruct(self, codegen):
|
| 835 |
+
raise NotImplementedError
|
| 836 |
+
|
| 837 |
+
def guard_source(self) -> GuardSource:
|
| 838 |
+
raise NotImplementedError
|
| 839 |
+
|
| 840 |
+
def name(self) -> str:
|
| 841 |
+
raise NotImplementedError
|
| 842 |
+
|
| 843 |
+
def make_guard(self, fn) -> Guard:
|
| 844 |
+
if self.guard_source() is GuardSource.CONSTANT:
|
| 845 |
+
raise NotImplementedError
|
| 846 |
+
return Guard(self, fn)
|
| 847 |
+
|
| 848 |
+
def is_specialized_nn_module(self) -> bool:
|
| 849 |
+
return self.guard_source().is_specialized_nn_module()
|
| 850 |
+
|
| 851 |
+
def subguards_allowed(self):
|
| 852 |
+
"""True if you can guard on attributes of this"""
|
| 853 |
+
return self.guard_source() != GuardSource.SYNTHETIC_LOCAL
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
# Subclasses can be found in torch/_dynamo/source.py
|
| 857 |
+
@dataclasses.dataclass(frozen=True)
|
| 858 |
+
class ChainedSource(Source):
|
| 859 |
+
base: Source
|
| 860 |
+
|
| 861 |
+
def is_dict_key(self):
|
| 862 |
+
# Recurse until you either hit a ConstDictKey or a Source
|
| 863 |
+
return self.base.is_dict_key()
|
| 864 |
+
|
| 865 |
+
def is_ephemeral(self):
|
| 866 |
+
return self.base.is_ephemeral()
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
def detect_fake_mode(inputs: Any = None):
|
| 870 |
+
"""
|
| 871 |
+
Attempts to "detect" what the current fake mode is. If there is one ambiently
|
| 872 |
+
available from TracingContext, we preferentially use that. Otherwise, we
|
| 873 |
+
heuristically detect the fake mode via the following sources, in order of
|
| 874 |
+
priority:
|
| 875 |
+
|
| 876 |
+
- Currently active fake mode on stack
|
| 877 |
+
- Fake mode associated with passed in tensors (inputs does not
|
| 878 |
+
have to be flattened)
|
| 879 |
+
"""
|
| 880 |
+
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
| 881 |
+
|
| 882 |
+
fake_modes = []
|
| 883 |
+
|
| 884 |
+
if context := TracingContext.try_get():
|
| 885 |
+
fake_mode = context.fake_mode
|
| 886 |
+
if fake_mode is not None:
|
| 887 |
+
fake_modes.append((fake_mode, "tracing context", 0))
|
| 888 |
+
|
| 889 |
+
from torch.utils._python_dispatch import _get_current_dispatch_mode_stack
|
| 890 |
+
|
| 891 |
+
for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
|
| 892 |
+
if isinstance(m, FakeTensorMode):
|
| 893 |
+
fake_modes.append((m, "active fake mode", i))
|
| 894 |
+
|
| 895 |
+
flat_inputs = pytree.tree_leaves(inputs)
|
| 896 |
+
for i, flat_input in enumerate(flat_inputs):
|
| 897 |
+
if isinstance(flat_input, FakeTensor):
|
| 898 |
+
fake_modes.append((flat_input.fake_mode, "fake tensor input", i))
|
| 899 |
+
|
| 900 |
+
if fake_modes:
|
| 901 |
+
fake_mode, desc1, i1 = fake_modes[0]
|
| 902 |
+
for m, desc2, i2 in fake_modes[1:]:
|
| 903 |
+
assert fake_mode is m, (
|
| 904 |
+
f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
|
| 905 |
+
f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
|
| 906 |
+
f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
|
| 907 |
+
)
|
| 908 |
+
return fake_mode
|
| 909 |
+
else:
|
| 910 |
+
return None
|
| 911 |
+
|
| 912 |
+
|
| 913 |
+
def active_fake_mode():
|
| 914 |
+
"""
|
| 915 |
+
Inspects the dispatch mode stack for an active fake mode and returns it.
|
| 916 |
+
Returns None if no fake mode is active.
|
| 917 |
+
"""
|
| 918 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 919 |
+
from torch.utils._python_dispatch import _get_current_dispatch_mode_stack
|
| 920 |
+
|
| 921 |
+
for _, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
|
| 922 |
+
if isinstance(m, FakeTensorMode):
|
| 923 |
+
return m
|
| 924 |
+
|
| 925 |
+
return None
|
.venv/lib/python3.11/site-packages/torch/_jit_internal.py
ADDED
|
@@ -0,0 +1,1547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
"""
|
| 3 |
+
The weak_script annotation needs to be here instead of inside torch/jit/ so it
|
| 4 |
+
can be used in other places in torch/ (namely torch.nn) without running into
|
| 5 |
+
circular dependency problems
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import ast
|
| 9 |
+
import builtins
|
| 10 |
+
import collections
|
| 11 |
+
import contextlib
|
| 12 |
+
import enum
|
| 13 |
+
import inspect
|
| 14 |
+
import io
|
| 15 |
+
import pickle
|
| 16 |
+
import sys
|
| 17 |
+
import textwrap
|
| 18 |
+
import threading
|
| 19 |
+
import types
|
| 20 |
+
import typing
|
| 21 |
+
import warnings
|
| 22 |
+
import weakref
|
| 23 |
+
from typing import (
|
| 24 |
+
Any,
|
| 25 |
+
Callable,
|
| 26 |
+
Dict,
|
| 27 |
+
Final,
|
| 28 |
+
ForwardRef,
|
| 29 |
+
get_args,
|
| 30 |
+
get_origin,
|
| 31 |
+
List,
|
| 32 |
+
Optional,
|
| 33 |
+
Tuple,
|
| 34 |
+
Type,
|
| 35 |
+
Union,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
import torch
|
| 39 |
+
|
| 40 |
+
# This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`.
|
| 41 |
+
# Explicitly ask to import `torch.distributed.__init__` first.
|
| 42 |
+
# Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
|
| 43 |
+
import torch.distributed.rpc
|
| 44 |
+
import torch.package._mangling as package_mangling
|
| 45 |
+
from torch._awaits import _Await
|
| 46 |
+
from torch._C import _Await as CAwait, Future as CFuture
|
| 47 |
+
from torch._sources import fake_range, get_source_lines_and_file, parse_def
|
| 48 |
+
from torch.futures import Future
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Cached interpreter-version checks; Final so they read as constants.
IS_PY39_PLUS: Final[bool] = sys.version_info >= (3, 9)
IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10)

# The runtime type of `X | Y` unions (PEP 604), available only on 3.10+.
BuiltinUnionType: Union[Type, Tuple[Type, ...]]
if sys.version_info >= (3, 10):
    # NOTE: IS_PY310_PLUS doesn't work with mypy.
    # cf. https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks
    BuiltinUnionType = types.UnionType
else:
    BuiltinUnionType = ()  # trick: this makes isinstance short circuit.

# Concrete lock type, taken from the real threading implementation when
# available, otherwise from the single-threaded stub.
LockType: Type
try:
    import _thread

    LockType = _thread.LockType
except ImportError:
    import _dummy_thread  # type: ignore[import-not-found]

    LockType = _dummy_thread.LockType

# Wrapper functions that can call either of 2 functions depending on a boolean
# argument
boolean_dispatched: "weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]" = (
    weakref.WeakKeyDictionary()
)  # noqa: T484


# Filename prefix used for synthesized (dataclass-generated) sources.
FAKE_FILENAME_PREFIX = "__torch_jit_dataclass"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def is_final(ann) -> bool:
    """Return True if *ann* is a ``Final`` annotation from typing/typing_extensions."""
    declaring_module = getattr(ann, "__module__", None)
    if declaring_module not in ("typing", "typing_extensions"):
        return False
    # Either a subscripted Final[...] or the bare Final special form itself.
    return get_origin(ann) is Final or isinstance(ann, type(Final))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# allows BroadcastingList instance to be subscriptable
class BroadcastingListCls:
    """Sentinel whose instances accept any subscript (``BroadcastingListN[T]``)."""

    def __getitem__(self, types):
        # Subscription is purely syntactic for the TorchScript frontend;
        # no value is produced.
        return None
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# mypy doesn't support parameters on types, so we have to explicitly type each
# list size
BroadcastingList1 = BroadcastingListCls()
# BroadcastingList2..BroadcastingList6 are aliases of the same sentinel instance.
for i in range(2, 7):
    globals()[f"BroadcastingList{i}"] = BroadcastingList1
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def is_scripting() -> bool:
    r"""
    Return True when inside TorchScript compilation and False otherwise.

    This is especially useful together with the @unused decorator, letting you
    keep code in your model that is not yet TorchScript compatible.
    .. testcode::

        import torch

        @torch.jit.unused
        def unsupported_linear_op(x):
            return x

        def linear(x):
            if torch.jit.is_scripting():
                return torch.linear(x)
            else:
                return unsupported_linear_op(x)
    """
    # Always False in eager Python; the TorchScript compiler substitutes its
    # own builtin that evaluates to True during compilation.
    return False
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj.
def _qualified_name(obj, mangle_name=True) -> str:
    """Return the fully-qualified name (module hierarchy + name) for *obj*.

    Args:
        obj: A class, function, or enum member.
        mangle_name: When True (default), prefix the module hierarchy with
            ``__torch__`` so compiled names cannot collide with user values.

    Raises:
        RuntimeError: If no usable name or module can be determined, or the
            resulting name is not a valid identifier.
    """
    # This special case allows us to override the qualified name on a type.
    # It's currently used in conjunction with tracing, where we create a
    # fake module to filter only supported attributes. However, since this
    # new type is defined as a local class, we need a mechanism to override
    # its qualname so it appears correctly in the TorchScript system. This,
    # we set '_jit_override_qualname' with the original traced module's
    # qualified name, which is picked up here
    if hasattr(obj, "_jit_override_qualname"):
        return obj._jit_override_qualname
    # short-circuit in cases where the object already has a known qualified name
    if isinstance(obj, torch._C.ScriptFunction):
        return obj.qualified_name

    if getattr(obj, "__name__", None):
        name = obj.__name__
    # Enum classes do not have `__name__` attr, instead they have `name`.
    elif isinstance(obj, enum.Enum):
        name = obj.name
    else:
        raise RuntimeError("Could not get name of python class object")

    if name == "<lambda>":
        name = "_lambda"  # make name a valid identifier

    module_name = obj.__module__

    # If the module is actually a torchbind module, then we should short circuit
    if module_name == "torch._classes":
        return obj.qualified_name

    # The Python docs are very clear that `__module__` can be None, but I can't
    # figure out when it actually would be.
    if module_name is None:
        raise RuntimeError(
            f"Could not get qualified name for class '{name}': "
            "__module__ can't be None."
        )

    # if getattr(sys.modules[module_name], name) is not obj:
    #     raise RuntimeError(f"Could not get qualified name for class '{name}': "
    #                        f"the attr {name} on module {module_name} is not the class")

    # torch.package and TorchScript have separate mangling schemes to avoid
    # name collisions from multiple packages. To avoid them interfering with
    # each other, normalize the package manging here.
    if package_mangling.is_mangled(module_name):
        module_name = module_name.replace("<", "_")
        module_name = module_name.replace(">", "_")

    # The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h
    # does not need mangle the python class name.
    if mangle_name:
        # __main__ is a builtin module, so rewrite it to "__torch__".
        if module_name == "__main__":
            module_name = "__torch__"
        else:
            # Everything else gets a "__torch__" prefix to avoid name collisions
            # with the names of user values.
            module_name = "__torch__." + module_name

    if "." in name:
        raise RuntimeError(
            f"Could not get qualified name for class '{name}': "
            f"'{name}' is not a valid identifier"
        )

    return module_name + "." + name
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class SourceLoader:
    """In-memory cache mapping callables to their source text."""

    def __init__(self):
        # Maps fn -> source string.
        self.content = {}

    def cache(self, fn, source):
        """Remember *source* as the source text of *fn*."""
        self.content[fn] = source

    def get_source(self, fn):
        """Return the cached source of *fn*, or None when nothing was cached."""
        return self.content.get(fn)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
# Module-wide source cache (used when functions are synthesized at runtime).
loader = SourceLoader()
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def createResolutionCallbackFromEnv(lookup_base):
    """
    Creates a resolution callback that will look up qualified names in an
    environment, starting with `lookup_base` for the base of any qualified
    names, then proceeding down the lookup chain with the resolved object.

    You should not use this directly, it should only be used from the other
    createResolutionCallbackFrom* functions.
    """

    def lookupInModule(qualified_name, module):
        # Resolve a dotted name one component at a time via getattr.
        if "." in qualified_name:
            base, remaining_pieces = qualified_name.split(".", maxsplit=1)
            module_value = getattr(module, base)
            return lookupInModule(remaining_pieces, module_value)
        else:
            return getattr(module, qualified_name)

    def parseNestedExpr(expr, module) -> Tuple[Any, int]:
        # Returns (resolved value, number of characters consumed from expr).
        i = 0
        while i < len(expr) and expr[i] not in (",", "[", "]"):
            i += 1

        # Special case logic for the empty Tuple as a subscript (used
        # in the type annotation `Tuple[()]`)
        if expr[:i] == "()":
            return (), i

        base = lookupInModule(expr[:i].strip(), module)
        assert base is not None, f"Unresolvable type {expr[:i]}"
        if i == len(expr) or expr[i] != "[":
            return base, i

        # Recursively parse the (possibly comma-separated) subscript arguments.
        assert expr[i] == "["
        parts = []
        while expr[i] != "]":
            part_len = 0
            i += 1
            part, part_len = parseNestedExpr(expr[i:], module)
            parts.append(part)
            i += part_len
        if len(parts) > 1:
            return base[tuple(parts)], i + 1
        else:
            return base[parts[0]], i + 1

    def parseExpr(expr, module):
        try:
            value, len_parsed = parseNestedExpr(expr, module)
            assert len_parsed == len(
                expr
            ), "whole expression was not parsed, falling back to c++ parser"
            return value
        except Exception:
            """
            The python resolver fails in several cases in known unit tests, and is intended
            to fall back gracefully to the c++ resolver in general. For example, python 2 style
            annotations which are frequent in our unit tests often fail with types e.g. int not
            resolvable from the calling frame.
            """
            return None

    return lambda expr: parseExpr(expr, lookup_base)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def createResolutionCallbackFromFrame(frames_up: int = 0):
    """
    Create a function which, given a string variable name, returns the value of
    that variable in the scope of the caller of the function which called
    createResolutionCallbackFromFrame (by default).

    This is used to enable access in-scope Python variables inside
    TorchScript fragments.

    frames_up is the number of additional frames to go up on the stack.
    The default value is 0, which corresponds to the frame of the caller
    of createResolutionCallbackFromFrame. For example, if frames_up is set
    to 1, then the frame of the caller's caller of
    createResolutionCallbackFromFrame will be taken.

    For example, the following program prints 2::

        def bar():
            cb = createResolutionCallbackFromFrame(1)
            print(cb("foo"))


        def baz():
            foo = 2
            bar()


        baz()
    """
    # Walk up past our own frame plus `frames_up` extra frames.
    frame = inspect.currentframe()
    for _ in range(frames_up + 1):
        assert frame is not None
        frame = frame.f_back

    assert frame is not None
    f_locals = frame.f_locals
    f_globals = frame.f_globals

    class env:
        def __getattr__(self, key):
            # Lookup order mirrors normal name resolution:
            # locals, then globals, then builtins (implicitly None otherwise).
            for scope in (f_locals, f_globals):
                if key in scope:
                    return scope[key]
            if key in dir(builtins):
                return getattr(builtins, key)

    return createResolutionCallbackFromEnv(env())
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def get_closure(fn):
    """Return a dict of names visible to *fn*: its globals plus closed-over cells."""
    captures = dict(fn.__globals__)

    # Pair each free variable name with its closure cell; when there are no
    # free variables the closure may be None, hence the empty-tuple fallback.
    for cell, free_name in zip(fn.__closure__ or (), fn.__code__.co_freevars):
        captures[free_name] = cell.cell_contents

    return captures
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# [local resolution in python]
|
| 341 |
+
# Depending on where a variable is defined, and where it is used, we may
|
| 342 |
+
# or may not be able to recover its value when recursively compiling a
|
| 343 |
+
# script function. Remember in the general case, a module or function is
|
| 344 |
+
# first defined and then later scripted. This means we do not have a
|
| 345 |
+
# chance to capture the active frames when the function is defined. Hence any
|
| 346 |
+
# name resolution has to happen later on the created closure. The way
|
| 347 |
+
# python captures type annotations restricts what we can recover. The
|
| 348 |
+
# following example illustrates the different cases:
|
| 349 |
+
#
|
| 350 |
+
# class MyGlobalClass:
|
| 351 |
+
# ...
|
| 352 |
+
# def my_local_scope():
|
| 353 |
+
# @torch.jit.script
|
| 354 |
+
# class MyClass:
|
| 355 |
+
# ...
|
| 356 |
+
# @torch.jit.script
|
| 357 |
+
# class MyClassUsedAsVar:
|
| 358 |
+
# ...
|
| 359 |
+
# def eg(x: MyClass, y: MyGlobalClass):
|
| 360 |
+
# a_local_capture : Foo
|
| 361 |
+
# return MyClassUsedAsVar(x)
|
| 362 |
+
#
|
| 363 |
+
# MyGlobalClass is defined in the __globals__ dictionary of function
|
| 364 |
+
# 'eg', so it is always recoverable. my_local_scope introduces a new local
|
| 365 |
+
# variable scope in the function. Classes defined here are only visible as
|
| 366 |
+
# local variables. For the case of MyClassUsedAsVar, it is captured
|
| 367 |
+
# because it is used as a variable inside the body of the function, and we
|
| 368 |
+
# can resolve it using the captures returned from `get_closure`. However,
|
| 369 |
+
# the type annotations are not captured by the closure. In Python
|
| 370 |
+
# 3.0--3.9, the _value_ of MyClass and MyGlobalClass will be available as
|
| 371 |
+
# annotations on `eg``, but starting in Python 4.0, they will represented as
|
| 372 |
+
# strings and no longer present. Furthermore, since the body of `eg` does
|
| 373 |
+
# not reference those names, they do not appear in the list of closed over
|
| 374 |
+
# variables. In Python 2.x, type annotations are in comments, leading to a
|
| 375 |
+
# similar situation where their definitions are not available. We anticipate
|
| 376 |
+
# that most users will not run into this issue because their modules and
|
| 377 |
+
# functions will be defined at a global scope like MyGlobalClass. In cases
|
| 378 |
+
# where they are not, it is possible to work around issues by declaring the
|
| 379 |
+
# values global in the function.
|
| 380 |
+
# In Python 3.9 declaring class as global will make it invisible to
|
| 381 |
+
# `inspect.getsource`, see https://bugs.python.org/issue42666 .
|
| 382 |
+
# This could be worked around by manually adding it to the `globals()` dictionary.
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def createResolutionCallbackFromClosure(fn):
    """
    Create a resolutionCallback by introspecting the function instead of
    looking up the stack for the enclosing scope.
    """
    captured = get_closure(fn)

    class closure_lookup:
        # A class rather than a dict, so the env helper can resolve everything
        # uniformly through `getattr` calls.
        def __getattr__(self, key):
            if key in captured:
                return captured[key]
            # Fall back to typing, then builtins, mirroring name resolution
            # for annotations.
            for module in (typing, builtins):
                if hasattr(module, key):
                    return getattr(module, key)
            return None

    return createResolutionCallbackFromEnv(closure_lookup())
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def can_compile_class(cls) -> bool:
    """Return True when every routine on *cls* carries Python bytecode.

    If any of the functions on a type don't have a code object, this type
    can't be compiled and is probably a builtin / bound from C.
    """
    if is_ignored_fn(cls):
        return False

    # Ignore the following list of built-in classes.
    if issubclass(cls, (torch.nn.Module, tuple, list, Exception)):
        return False

    routines = (
        getattr(cls, attr_name)
        for attr_name in cls.__dict__
        if inspect.isroutine(getattr(cls, attr_name, None))
    )
    return all(hasattr(routine, "__code__") for routine in routines)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def get_callable_argument_names(fn) -> List[str]:
    """
    Get the names of all POSITIONAL_OR_KEYWORD arguments of callable `fn`,
    skipping arguments of any other kind (they do not map to individual
    values with a keyword as name).

    This is used by `torch.jit.trace` to assign meaningful argument names to
    traced functions and modules.

    Args:
        fn: A callable.
    Returns:
        Argument names: List[str]
    """
    try:
        sig = inspect.signature(fn)
    except Exception:
        # inspect.signature may fail (e.g. for some builtins); give up then.
        return []

    return [
        name
        for name, param in sig.parameters.items()
        if param.kind == param.POSITIONAL_OR_KEYWORD
    ]
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def get_annotation_str(annotation):
    """
    Convert an AST node holding a type annotation back into the source string
    that represents the same annotation. Returns None for node kinds not
    handled here (those are probably handled in ScriptTypeParser).
    """
    if isinstance(annotation, ast.Name):
        return annotation.id
    if isinstance(annotation, ast.Attribute):
        return f"{get_annotation_str(annotation.value)}.{annotation.attr}"
    if isinstance(annotation, ast.Subscript):
        # Python 3.9+ stores the subscript directly; older versions wrap it
        # in an ast.Index node.
        if sys.version_info >= (3, 9):
            subscript_slice = annotation.slice
        else:
            subscript_slice = annotation.slice.value  # type: ignore[attr-defined]
        return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]"
    if isinstance(annotation, ast.Tuple):
        return ",".join(get_annotation_str(elt) for elt in annotation.elts)
    if isinstance(annotation, ast.Constant):
        return f"{annotation.value}"
    return None
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def get_type_hint_captures(fn):
    """
    Get a dictionary containing type resolution mappings necessary to resolve types
    for the literal annotations on 'fn'. These are not considered to be closed-over by fn
    and must be obtained separately (e.g. using this function).

    Args:
        fn: A callable.
    Returns:
        A Dict[str, Any] containing a mapping from the literal annotations used on
        fn to the Python objects they refer to.

    Raises:
        OSError: If the source of `fn` cannot be retrieved.
        RuntimeError: If the retrieved source does not parse to a single function.
    """
    # First, try to get the source of the function. We'll need to parse it to find the actual string names
    # that were used to annotate the types, since inspect.signature() will only return the class object that
    # the annotation refers to, not the string name. If we can't get the source, simply return an empty dict.
    # This may happen in cases where the function is synthesized dynamically at runtime.
    src = loader.get_source(fn)
    if src is None:
        try:
            src = inspect.getsource(fn)
        except OSError as e:
            raise OSError(
                f"Failed to get source for {fn} using inspect.getsource"
            ) from e

    # Gather a dictionary of parameter name -> type, skipping any parameters whose annotated
    # types are strings. These are only understood by TorchScript in the context of a type annotation
    # that refers to a class in its own definition, but trying to include a mapping for this in the result
    # function would cause infinite recursion because the class is currently being compiled.
    # In addition, there is logic in ScriptTypeParser to handle this.
    signature = inspect.signature(fn)
    name_to_type = {
        name: parameter.annotation
        for name, parameter in signature.parameters.items()
        if parameter.annotation is not inspect.Parameter.empty
        and not isinstance(parameter.annotation, str)
    }

    # Then, get the literal type annotations from the function declaration
    # by source inspection. This accounts for the case in which aliases are used
    # to annotate the arguments (e.g device_t = torch.device, and then d: device_t).
    # frontend.py cannot be used here because it includes _jit_internal, so use ast instead.
    a = ast.parse(textwrap.dedent(src))
    if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef):
        raise RuntimeError(f"Expected {fn} to be a function")
    f = a.body[0]

    # Prepare a dictionary of source annotation -> type, which will be the final result of this function,
    # by using the parsed AST (f) to reconstruct source annotations as strings for each parameter and mapping
    # them to the type object corresponding to the annotation via name_to_type using the parameter name.
    annotation_to_type = {}

    for arg in f.args.args:
        # Get the source type annotation string for this argument if possible.
        arg_annotation_str = (
            get_annotation_str(arg.annotation) if arg.annotation else None
        )

        # If the argument has no annotation or get_annotation_str cannot convert it to a string,
        # arg_annotation_str will be None. Skip this arg; ScriptTypeParser will probably handle
        # this in the latter case.
        if arg_annotation_str is None:
            continue

        # Insert {arg_annotation_str: type} into annotation_to_type if possible. One reason arg_name may not
        # be present in name_to_type is that the annotation itself is a string and not a type object
        # (common for self-refential annotations in classes). Once again, let ScriptTypeParser handle this.
        arg_name = arg.arg
        if arg_name in name_to_type:
            annotation_to_type[arg_annotation_str] = name_to_type[arg_name]

    # If there is a valid return annotation, include it in annotation_to_type. As with argument annotations,
    # the literal annotation has to be convertible to a string by get_annotation_str, and the actual type
    # of the annotation cannot be a string.
    literal_return_annotation = get_annotation_str(f.returns)
    valid_literal_annotation = literal_return_annotation is not None
    return_annotation = signature.return_annotation
    valid_return_annotation_type = (
        return_annotation is not inspect.Parameter.empty
        and not isinstance(return_annotation, str)
    )
    if valid_literal_annotation and valid_return_annotation_type:
        annotation_to_type[literal_return_annotation] = return_annotation

    return annotation_to_type
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def createResolutionCallbackForClassMethods(cls):
    """
    Look at all the methods defined in a class, pull their closed-over
    variables into a single dictionary, and return a callback that resolves
    names against it (falling back to builtins).
    """
    # cls is a type here, so `ismethod` is false since the methods on the type
    # aren't bound to anything, so Python treats them as regular functions.
    # Skip built-ins, as they do not have global scope nor type hints.
    # Needed to support `enum.Enum` derived classes in Python-3.11, which
    # adds a `_new_member_` property that is an alias to `__new__`.
    methods = [
        attr
        for attr in (getattr(cls, name) for name in cls.__dict__)
        if inspect.isroutine(attr)
        and not inspect.isbuiltin(attr)
        and hasattr(attr, "__globals__")
    ]

    captures: Dict[str, Any] = {}
    for method in methods:
        captures.update(get_closure(method))
        captures.update(get_type_hint_captures(method))

    def lookup_in_class(key):
        return captures[key] if key in captures else getattr(builtins, key, None)

    return lookup_in_class
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def boolean_dispatch(
    arg_name,
    arg_index,
    default,
    if_true,
    if_false,
    module_name,
    func_name,
):
    """
    Build a wrapper dispatching to one of two script functions based on a
    boolean argument. In TorchScript, the boolean argument must be constant
    so that the correct function to use can be determined at compile time.
    """

    def fn(*args, **kwargs):
        # Resolve the flag: keyword wins, then positional, then the default.
        if arg_name in kwargs:
            flag = kwargs[arg_name]
        elif arg_index < len(args):
            flag = args[arg_index]
        else:
            flag = default

        target = if_true if flag else if_false
        return target(*args, **kwargs)

    # Share the single docstring between both implementations; at most one
    # of the two functions may carry one.
    true_doc, false_doc = if_true.__doc__, if_false.__doc__
    if true_doc is None and false_doc is not None:
        doc = false_doc
        if_true.__doc__ = doc
    elif false_doc is None and true_doc is not None:
        doc = true_doc
        if_false.__doc__ = doc
    elif false_doc is None and true_doc is None:
        # neither function has a docstring
        doc = None
    else:
        raise RuntimeError("only one function can have a docstring")
    fn.__doc__ = doc

    if module_name is not None:
        fn.__module__ = module_name
    if func_name is not None:
        fn.__name__ = func_name

    boolean_dispatched[fn] = {
        "if_true": if_true,
        "if_false": if_false,
        "index": arg_index,
        "default": default,
        "arg_name": arg_name,
    }
    return fn
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
class FunctionModifiers:
    """
    Used to denote the behavior of a function in TorchScript. See export() and
    ignore() for details.
    """

    # Body skipped; calls raise at runtime, but the model stays exportable.
    UNUSED = "unused (ignored and replaced with raising of an exception)"
    # Left as a Python call; such models cannot be torch.jit.save'd.
    IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)"
    # Compiled even when nothing references it.
    EXPORT = "export (compile this function even if nothing calls it)"
    # Compiled lazily, when reached from an exported entry point.
    DEFAULT = "default (compile if called from a exported function / forward)"
    # Copied verbatim onto the scripted wrapper when not scripted itself.
    COPY_TO_SCRIPT_WRAPPER = (
        "if this method is not scripted, copy the python method onto the scripted model"
    )
    # Fully dropped; even the declaration may be unscriptable.
    _DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
def export(fn):
    """
    Decorator marking a method on an ``nn.Module`` as an entry point into a
    :class:`ScriptModule`, so it is compiled even when nothing calls it.

    ``forward`` implicitly is assumed to be an entry point, so it does not need this decorator.
    Functions and methods called from ``forward`` are compiled as they are seen
    by the compiler, so they do not need this decorator either.

    Example (using ``@torch.jit.export`` on a method):

    .. testcode::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            def implicitly_compiled_method(self, x):
                return x + 99

            # `forward` is implicitly decorated with `@torch.jit.export`,
            # so adding it here would have no effect
            def forward(self, x):
                return x + 10

            @torch.jit.export
            def another_forward(self, x):
                # When the compiler sees this call, it will compile
                # `implicitly_compiled_method`
                return self.implicitly_compiled_method(x)

            def unused_method(self, x):
                return x - 20

        # `m` will contain compiled methods:
        #     `forward`
        #     `another_forward`
        #     `implicitly_compiled_method`
        # `unused_method` will not be compiled since it was not called from
        # any compiled methods and wasn't decorated with `@torch.jit.export`
        m = torch.jit.script(MyModule())
    """
    # Tag the function so the compiler knows to compile it unconditionally.
    fn._torchscript_modifier = FunctionModifiers.EXPORT
    return fn
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
def unused(fn):
    """
    Decorator telling the compiler to ignore a function or method and replace
    it with the raising of an exception. This allows you to leave code in your
    model that is not yet TorchScript compatible and still export your model.

    Example (using ``@torch.jit.unused`` on a method)::

        import torch
        import torch.nn as nn


        class MyModule(nn.Module):
            def __init__(self, use_memory_efficient):
                super().__init__()
                self.use_memory_efficient = use_memory_efficient

            @torch.jit.unused
            def memory_efficient(self, x):
                import pdb

                pdb.set_trace()
                return x + 10

            def forward(self, x):
                # Use not-yet-scriptable memory efficient mode
                if self.use_memory_efficient:
                    return self.memory_efficient(x)
                else:
                    return x + 10


        m = torch.jit.script(MyModule(use_memory_efficient=False))
        m.save("m.pt")

        m = torch.jit.script(MyModule(use_memory_efficient=True))
        # exception raised
        m(torch.rand(100))
    """
    if isinstance(fn, property):
        prop = fn
        # Mark the getter...
        setattr(  # noqa: B010
            prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED
        )
        # ...and the setter, when one is defined.
        if prop.fset:
            setattr(  # noqa: B010
                prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED
            )
        return prop

    fn._torchscript_modifier = FunctionModifiers.UNUSED
    return fn
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
# No op context manager from python side
|
| 774 |
+
class _IgnoreContextManager(contextlib.AbstractContextManager):
|
| 775 |
+
def __init__(self, **kwargs):
|
| 776 |
+
pass
|
| 777 |
+
|
| 778 |
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
| 779 |
+
pass
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
def ignore(drop=False, **kwargs):
    """
    This decorator indicates to the compiler that a function or method should
    be ignored and left as a Python function. This allows you to leave code in
    your model that is not yet TorchScript compatible. If called from TorchScript,
    ignored functions will dispatch the call to the Python interpreter. Models with ignored
    functions cannot be exported; use :func:`@torch.jit.unused <torch.jit.unused>` instead.

    Example (using ``@torch.jit.ignore`` on a method)::

        import torch
        import torch.nn as nn


        class MyModule(nn.Module):
            @torch.jit.ignore
            def debugger(self, x):
                import pdb

                pdb.set_trace()

            def forward(self, x):
                x += 10
                # The compiler would normally try to compile `debugger`,
                # but since it is `@ignore`d, it will be left as a call
                # to Python
                self.debugger(x)
                return x


        m = torch.jit.script(MyModule())

        # Error! The call `debugger` cannot be saved since it calls into Python
        m.save("m.pt")

    Example (using ``@torch.jit.ignore(drop=True)`` on a method):

    .. testcode::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            @torch.jit.ignore(drop=True)
            def training_method(self, x):
                import pdb
                pdb.set_trace()

            def forward(self, x):
                if self.training:
                    self.training_method(x)
                return x

        m = torch.jit.script(MyModule())

        # This is OK since `training_method` is not saved, the call is replaced
        # with a `raise`.
        m.save("m.pt")

    .. testcleanup::

        import os
        os.remove('m.pt')
    """

    if callable(drop):
        # used without any args, so drop is actually a function
        #   @torch.jit.ignore
        #   def fn(...):
        fn = drop
        fn._torchscript_modifier = FunctionModifiers.IGNORE
        return fn

    if not isinstance(drop, bool):
        raise RuntimeError(
            "Argument to @torch.jit.ignore must be a bool or "
            f"a function but got {drop}"
        )

    # for backwards compat
    drop_on_export = kwargs.pop("drop_on_export", None)
    if drop_on_export:
        warnings.warn(
            "ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function "
            "call on compilation. Use torch.jit.unused now. {}",
            category=FutureWarning,
        )

        drop = drop_on_export
    elif drop:
        warnings.warn(
            "ignore(True) has been deprecated. TorchScript will now drop the function "
            "call on compilation. Use torch.jit.unused now. {}",
            category=FutureWarning,
        )

    # Decorator returned for the parameterized forms @ignore(...) /
    # @ignore(drop=True): drop=True maps to UNUSED, otherwise IGNORE.
    def decorator(fn):
        if drop:
            fn._torchscript_modifier = FunctionModifiers.UNUSED
        else:
            fn._torchscript_modifier = FunctionModifiers.IGNORE
        return fn

    return decorator
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
def _drop(fn):
    """Mark ``fn`` so the TorchScript compiler drops it entirely."""
    setattr(fn, "_torchscript_modifier", FunctionModifiers._DROP)
    return fn
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
def _copy_to_script_wrapper(fn):
    """Mark ``fn`` to be copied onto the generated script wrapper."""
    setattr(fn, "_torchscript_modifier", FunctionModifiers.COPY_TO_SCRIPT_WRAPPER)
    return fn
|
| 896 |
+
|
| 897 |
+
|
| 898 |
+
def module_has_exports(mod):
    """Return True if any callable attribute of ``mod`` carries the EXPORT modifier."""
    return any(
        get_torchscript_modifier(getattr(mod, attr_name)) is FunctionModifiers.EXPORT
        for attr_name in dir(mod)
        if hasattr(mod, attr_name) and callable(getattr(mod, attr_name))
    )
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
# WARNING: should_drop is currently being used by our JIT code coverage plug-in to mark JIT'd code as covered. If you
|
| 909 |
+
# rename this function, please update references in tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py to
|
| 910 |
+
# allow JIT'd code to still be covered.
|
| 911 |
+
def should_drop(fn) -> bool:
    """Return True if calls to ``fn`` should be dropped (UNUSED or _DROP modifier)."""
    modifier = get_torchscript_modifier(fn)
    if modifier is None:
        return False
    return modifier in (FunctionModifiers.UNUSED, FunctionModifiers._DROP)
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
def is_ignored_fn(fn) -> bool:
    """Return True if ``fn`` is marked UNUSED, IGNORE, or _DROP."""
    modifier = get_torchscript_modifier(fn)
    return modifier in (
        FunctionModifiers.UNUSED,
        FunctionModifiers.IGNORE,
        FunctionModifiers._DROP,
    )
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
def _is_drop_fn(fn) -> bool:
    """Return True if ``fn`` carries the _DROP modifier."""
    return get_torchscript_modifier(fn) is FunctionModifiers._DROP
|
| 930 |
+
|
| 931 |
+
|
| 932 |
+
def is_static_fn(cls, fn) -> bool:
    """Return True if attribute ``fn`` of ``cls`` is declared as a staticmethod."""
    static_attr = inspect.getattr_static(cls, fn, default=None)
    return isinstance(static_attr, staticmethod)
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
def get_static_fn(cls, fn):
    """Return the plain function wrapped by the staticmethod ``fn`` on ``cls``."""
    wrapped = inspect.getattr_static(cls, fn)
    return wrapped.__func__
|
| 938 |
+
|
| 939 |
+
|
| 940 |
+
def get_torchscript_modifier(fn):
    """Return the TorchScript modifier attached to ``fn``, DEFAULT if none, None if not callable."""
    if not callable(fn):
        return None
    # Unwrap bound methods to reach the underlying function object.
    target = getattr(fn, "__func__", fn)
    return getattr(target, "_torchscript_modifier", FunctionModifiers.DEFAULT)
|
| 946 |
+
|
| 947 |
+
|
| 948 |
+
def copy_torchscript_modifier(orig, new) -> None:
    """Copy the TorchScript modifier from ``orig`` onto ``new``, if present."""
    modifier = get_torchscript_modifier(orig)
    if modifier is not None:
        new._torchscript_modifier = modifier
|
| 953 |
+
|
| 954 |
+
|
| 955 |
+
# overloading registration
# overloads get registered in this file, and compiled in torch/jit/__init__.py
# so that they can be imported in nn/functional.py without an import cycle

# qualified_name => list[overload_functions]
_overloaded_fns: Dict[str, List[Callable]] = {}  # noqa: T484


# Shown in error messages when an overload is declared or used incorrectly.
_OVERLOAD_EXAMPLE = """
Example usage of overload function:
@torch.jit._overload
def my_function(x: type0) -> type0: # decl 1
    pass

@torch.jit._overload
def my_function(x: type1) -> type1: # decl 2
    pass

def my_function(x): # implementation
    if isinstance(x, type0):
        return x
    elif isinstance(x, type1):
        return x
"""
|
| 979 |
+
|
| 980 |
+
|
| 981 |
+
def get_overload_no_implementation_error_message(kind, obj):
    """Build the error shown when overload declarations have no implementation.

    Args:
        kind: a short word for the kind of object ("function" / "method"),
            interpolated into the message.
        obj: the overloaded callable whose implementation is missing.

    Returns:
        A formatted message including the declaration's source location and an
        example of correct overload usage.
    """
    sourcelines, file_lineno, filename = get_source_lines_and_file(obj)
    return (
        f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
        f"sure a definition is provided and defined after all overload declarations.\n"
        # Fix: use the actual filename from get_source_lines_and_file instead of
        # a hard-coded placeholder; `filename` was previously unpacked but unused.
        f'File "{filename}", line {file_lineno}:\n'
        + "".join(sourcelines)
        + "\n"
        + _OVERLOAD_EXAMPLE
    )
|
| 991 |
+
|
| 992 |
+
|
| 993 |
+
def _check_overload_body(func):
    """Validate that an overload declaration's body is only ``pass`` or ``...``."""
    try:
        parsed_def = parse_def(func)
    except OSError as e:
        # Parsing the function definition can raise an OSError if source is unavailable.
        # Since this is just an initial check, just raise a warning if this is the case.
        warnings.warn(
            f"Unable to retrieve source for @torch.jit._overload function: {func}."
        )
        return

    # First statement list of the (single) parsed function definition.
    body = parsed_def.ast.body[0].body

    def is_pass(x):
        return isinstance(x, ast.Pass)

    def is_ellipsis(x):
        return (
            isinstance(x, ast.Expr)
            and isinstance(x.value, ast.Constant)
            and x.value.value is Ellipsis
        )

    if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])):
        msg = (
            "Only `pass` statement or `...` can be the body of overload declaration:\n"
        )
        # Include the first three source lines of the offending declaration.
        msg += "\n".join(parsed_def.source.split("\n")[:3])
        msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE
        raise RuntimeError(msg)
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
def _overload(func):
    """Register ``func`` as an overload declaration for a free function."""
    _check_overload_body(func)
    qual_name = _qualified_name(func)
    global _overloaded_fns
    _overloaded_fns.setdefault(qual_name, []).append(func)
    return func
|
| 1035 |
+
|
| 1036 |
+
|
| 1037 |
+
def _get_fn_overloads(qual_name):
    """Return the registered overload list for ``qual_name``, or None."""
    return _overloaded_fns.get(qual_name, None)
|
| 1039 |
+
|
| 1040 |
+
|
| 1041 |
+
def _clear_fn_overloads(qual_name) -> None:
    """Drop all registered overloads for ``qual_name`` (KeyError if absent)."""
    _overloaded_fns.pop(qual_name)
|
| 1043 |
+
|
| 1044 |
+
|
| 1045 |
+
def get_class_name_lineno(method) -> Tuple[str, int]:
    """Return (code name, first line number) of the frame two levels up.

    Intended to be invoked from ``_overload_method`` so that two frames up is
    the class body in which the decorator was applied.
    """
    frame = inspect.currentframe()

    # one for the get_class_name call, one for _overload_method call
    for _ in range(2):
        assert frame is not None  # narrow Optional[FrameType]
        frame = frame.f_back

    assert frame is not None  # same here
    return frame.f_code.co_name, frame.f_code.co_firstlineno
|
| 1059 |
+
|
| 1060 |
+
|
| 1061 |
+
# At the point the decorator is applied to class methods the method
|
| 1062 |
+
# has no reference to its owning class. _qualified_name would not include
|
| 1063 |
+
# the class it is defined in, so any methods with the same name in the same file
|
| 1064 |
+
# would have the same _qualified_name, even if they were defined in different
|
| 1065 |
+
# classes. This problem only exists in python 2.
|
| 1066 |
+
# We get around this problem by looking at the stack frame and identifying
|
| 1067 |
+
# the class name, and throwing an error whenever overloads are used
|
| 1068 |
+
# when modules of the same name are in the same file
|
| 1069 |
+
|
| 1070 |
+
# qualified_name => class name => list[overload_functions]
_overloaded_methods: Dict[str, Dict[str, List[Callable]]] = {}  # noqa: T484


# (qualified_name, class name) => class_fileno
# Tracks where each class registered its overloads, to detect two classes
# with the same name in the same module (see _overload_method).
_overloaded_method_class_fileno: Dict[Tuple[str, str], int] = {}
|
| 1076 |
+
|
| 1077 |
+
|
| 1078 |
+
def _overload_method(func):
    """Register ``func`` as an overload declaration for a method.

    Overloads are keyed by qualified name and then by the enclosing class
    name recovered from the call stack, since methods carry no reference to
    their class at decoration time.
    """
    _check_overload_body(func)
    qual_name = _qualified_name(func)
    global _overloaded_methods
    class_name_map = _overloaded_methods.get(qual_name, None)
    if class_name_map is None:
        class_name_map = {}
        _overloaded_methods[qual_name] = class_name_map

    class_name, line_no = get_class_name_lineno(func)
    method_overloads = class_name_map.get(class_name, None)
    if method_overloads is None:
        method_overloads = []
        class_name_map[class_name] = method_overloads
        # Remember the class-body line so a redefinition can be detected below.
        _overloaded_method_class_fileno[(qual_name, class_name)] = line_no
    else:
        existing_lineno = _overloaded_method_class_fileno[(qual_name, class_name)]
        if existing_lineno != line_no:
            raise RuntimeError(
                "Cannot currently overload the same method name in two different"
                " classes with the same name in the same module"
            )

    method_overloads.append(func)
    return func
|
| 1103 |
+
|
| 1104 |
+
|
| 1105 |
+
def _get_overloaded_methods(method, mod_class):
    """Return registered overloads of ``method`` on ``mod_class``, or None.

    Raises AssertionError when the method's source is outside the class's
    source span, which indicates the class was redeclared in the same file.
    """
    # TODO: __name__ not set for submodules in recursive script
    if not hasattr(method, "__name__"):
        return None
    qual_name = _qualified_name(method)
    class_name_map = _overloaded_methods.get(qual_name, None)
    if class_name_map is None:
        return None
    overloads = class_name_map.get(mod_class.__name__, None)
    if overloads is None:
        return None

    # Sanity check: the method must be defined within the class's source span.
    method_line_no = get_source_lines_and_file(method)[1]
    mod_class_fileno = get_source_lines_and_file(mod_class)[1]
    mod_end_fileno = mod_class_fileno + len(get_source_lines_and_file(mod_class)[0])
    if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno):
        raise AssertionError(
            "Overloads are not useable when a module is redeclared within the same file: "
            + str(method)
        )
    return overloads
|
| 1126 |
+
|
| 1127 |
+
|
| 1128 |
+
def is_tuple(ann) -> bool:
    """Return True if ``ann`` annotates a tuple type; raise on bare ``Tuple``."""
    if ann is Tuple:
        raise_error_container_parameter_missing("Tuple")

    # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
    if not hasattr(ann, "__module__"):
        return False

    origin = get_origin(ann)
    if IS_PY39_PLUS and ann.__module__ == "builtins" and origin is tuple:
        return True
    return ann.__module__ == "typing" and (origin is Tuple or origin is tuple)
|
| 1140 |
+
|
| 1141 |
+
|
| 1142 |
+
def is_list(ann) -> bool:
    """Return True if ``ann`` annotates a list type; raise on bare ``List``."""
    if ann is List:
        raise_error_container_parameter_missing("List")

    if not hasattr(ann, "__module__"):
        return False

    origin = get_origin(ann)
    if IS_PY39_PLUS and ann.__module__ == "builtins" and origin is list:
        return True
    return ann.__module__ == "typing" and (origin is List or origin is list)
|
| 1153 |
+
|
| 1154 |
+
|
| 1155 |
+
def is_dict(ann) -> bool:
    """Return True if ``ann`` annotates a dict type; raise on bare ``Dict``."""
    if ann is Dict:
        raise_error_container_parameter_missing("Dict")

    if not hasattr(ann, "__module__"):
        return False

    origin = get_origin(ann)
    if IS_PY39_PLUS and ann.__module__ == "builtins" and origin is dict:
        return True
    return ann.__module__ == "typing" and (origin is Dict or origin is dict)
|
| 1166 |
+
|
| 1167 |
+
|
| 1168 |
+
def is_union(ann):
|
| 1169 |
+
if ann is Union:
|
| 1170 |
+
raise_error_container_parameter_missing("Union")
|
| 1171 |
+
|
| 1172 |
+
return isinstance(ann, BuiltinUnionType) or (
|
| 1173 |
+
hasattr(ann, "__module__")
|
| 1174 |
+
and ann.__module__ == "typing"
|
| 1175 |
+
and (get_origin(ann) is Union)
|
| 1176 |
+
)
|
| 1177 |
+
|
| 1178 |
+
|
| 1179 |
+
def is_optional(ann):
    """Return True if ``ann`` is Optional[X] or a two-way Union including None."""
    if ann is Optional:
        raise_error_container_parameter_missing("Optional")

    def spelled_as_optional(candidate):
        return (
            hasattr(candidate, "__module__")
            and candidate.__module__ == "typing"
            and (get_origin(candidate) is Optional)
        )

    def union_pairs_with_none(candidate):
        args = get_args(candidate)
        return len(args) == 2 and (None in args or type(None) in args)

    return spelled_as_optional(ann) or (is_union(ann) and union_pairs_with_none(ann))
|
| 1195 |
+
|
| 1196 |
+
|
| 1197 |
+
def is_future(ann) -> bool:
    """Return True if ``ann`` is ``Future[T]``; raise on a bare ``Future``."""
    if ann is Future:
        raise RuntimeError(
            "Attempted to use Future without a contained type. "
            "Please add a contained type, e.g. Future[int]"
        )
    return get_origin(ann) is Future
|
| 1205 |
+
|
| 1206 |
+
|
| 1207 |
+
def is_await(ann) -> bool:
|
| 1208 |
+
if ann is _Await:
|
| 1209 |
+
return True
|
| 1210 |
+
return get_origin(ann) is _Await
|
| 1211 |
+
|
| 1212 |
+
|
| 1213 |
+
# RRef helpers are only defined when the RPC module is available; otherwise a
# trivially-False is_rref_instance is provided so callers need no guards.
if torch.distributed.rpc.is_available():
    from torch._C._distributed_rpc import PyRRef
    from torch.distributed.rpc import RRef

    def is_rref(ann) -> bool:
        # True when the annotation is RRef[T]; a bare RRef is an error.
        if ann is RRef:
            raise RuntimeError(
                "Attempted to use RRef without a "
                "contained type. Please add a contained type, e.g. "
                "RRef[int]"
            )
        return get_origin(ann) is RRef

    def is_rref_instance(obj) -> bool:
        # True when obj is a live RRef object (C++ PyRRef binding).
        return isinstance(obj, PyRRef)

else:

    def is_rref_instance(obj) -> bool:
        # If the RPC module doesn't exist then RRefs don't exist either.
        return False
|
| 1234 |
+
|
| 1235 |
+
|
| 1236 |
+
def _try_get_dispatched_fn(fn):
    """Look up ``fn`` in the boolean-dispatch table; None if not callable or absent."""
    if callable(fn):
        return boolean_dispatched.get(fn)
    return None
|
| 1240 |
+
|
| 1241 |
+
|
| 1242 |
+
def _get_named_tuple_properties(
    obj,
    loc: Optional[torch._C._jit_tree_views.SourceRange] = None,
    rcb=None,
):
    # Extract (type name, field names, field types, defaults) from a NamedTuple
    # class so TorchScript can build an equivalent tuple type. ``rcb`` is a
    # resolution callback used to resolve string/ForwardRef annotations.
    if loc is None:
        loc = fake_range()

    assert issubclass(obj, tuple) and hasattr(obj, "_fields")
    if hasattr(obj, "_field_defaults"):
        defaults = [
            obj._field_defaults[field]
            for field in obj._fields
            if field in obj._field_defaults
        ]
    else:
        defaults = []
    # In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function
    # Also, annotations from base class are not inherited so they need to be queried explicitly
    if sys.version_info[:2] < (3, 10):
        obj_annotations = getattr(obj, "__annotations__", {})
    else:
        obj_annotations = inspect.get_annotations(obj)
        if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
            obj_annotations = inspect.get_annotations(obj.__base__)

    annotations = []
    for field in obj._fields:
        if field in obj_annotations:
            field_type = obj_annotations[field]
            # [Note: ForwardRef annotations in NamedTuple attributes]
            # NamedTuple types are slightly different from normal types.
            #
            # Normally, annotations are evaluted like this (during jit.script):
            # 1. Load strings of python code into c++ and parse.
            # 2. Get annotations as strings
            # 3. Use the PythonResolver's resolution callback (rcb) to convert
            #    the string into a python object
            # 4. We call into annotations.py:ann_to_type to convert python obj
            #    from step 3 into a type that torchscript understands.
            #
            # NamedTuples are more complicated, because it has sub-types.
            # Normally, once we have the NamedTuple type object from #3,
            # we can just look at the annotation literal values and use
            # ann_to_type directly on them.
            #
            # But sometimes, users will annotate with string literals, e.g.
            #    x: 'int'
            # This also happens with PEP563 (from __forward__ import annotations)
            #
            # These annotations appear in the annotation dict as ForwardRef('int').
            #
            # Then, we need to convert the string into a python object. This
            # requires having local context for custom objects or imported types.
            # rcb() is what gives us this. So, we plumb rcb through the stack so
            # it can be used in this context for the if block below.
            #
            # FAQ:
            # - Why do we need this special handling for NamedTuple but string
            #   annotations work fine for normal types? Normally, we parse the
            #   string directly and then call rcb() directly from C++.
            # - Why not use ForwardRef._evaluate? For that, we need globals()
            #   and locals() for the local context where the NamedTuple was defined.
            #   rcb is what lets us look up into these. So, basically rcb does the
            #   hard work for us.
            if isinstance(field_type, ForwardRef) and rcb is not None:
                rcb_type = rcb(field_type.__forward_arg__)
                # rcb returns None if it can't find anything.
                if rcb_type is None:
                    raise ValueError(
                        f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}."
                        f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858."
                        f" Issue occurred at {loc.highlight()}"
                    )
                field_type = rcb_type
            the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb)
            annotations.append(the_type)
        else:
            # Un-annotated fields default to an inferred Tensor type.
            annotations.append(torch._C.TensorType.getInferred())
    return type(obj).__name__, obj._fields, annotations, defaults
|
| 1322 |
+
|
| 1323 |
+
|
| 1324 |
+
def _create_named_tuple(
|
| 1325 |
+
t,
|
| 1326 |
+
unqual_name: str,
|
| 1327 |
+
field_names: List[str],
|
| 1328 |
+
defaults: Tuple[Any, ...],
|
| 1329 |
+
):
|
| 1330 |
+
TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults) # type: ignore[call-arg, no-redef, misc]
|
| 1331 |
+
return TupleType(*t)
|
| 1332 |
+
|
| 1333 |
+
|
| 1334 |
+
@contextlib.contextmanager
|
| 1335 |
+
def _disable_emit_hooks():
|
| 1336 |
+
hooks = torch._C._jit_get_emit_hooks()
|
| 1337 |
+
torch._C._jit_set_emit_hooks(None, None)
|
| 1338 |
+
try:
|
| 1339 |
+
yield
|
| 1340 |
+
finally:
|
| 1341 |
+
torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
|
| 1342 |
+
|
| 1343 |
+
|
| 1344 |
+
def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None:  # noqa: F811
    # NOTE(review): the nested functions below are defined but never attached
    # to anything, and the enclosing function returns None — this looks
    # vestigial; confirm before relying on it.
    def __enter__(self) -> None:
        self.hooks = torch._C._jit_get_emit_hooks()
        torch._C._jit_set_emit_hooks(None, None)

    def __exit__(self, *args) -> None:
        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
|
| 1351 |
+
|
| 1352 |
+
|
| 1353 |
+
def _is_exception(obj) -> bool:
|
| 1354 |
+
if not inspect.isclass(obj):
|
| 1355 |
+
return False
|
| 1356 |
+
return issubclass(obj, Exception)
|
| 1357 |
+
|
| 1358 |
+
|
| 1359 |
+
def raise_error_container_parameter_missing(target_type) -> None:
    """Always raise a RuntimeError: ``target_type`` was used without type args."""
    if target_type == "Dict":
        message = (
            "Attempted to use Dict without contained types. "
            "Please add contained type, e.g. Dict[int, int]"
        )
    else:
        message = (
            f"Attempted to use {target_type} without a contained type. "
            f"Please add a contained type, e.g. {target_type}[int]"
        )
    raise RuntimeError(message)
|
| 1371 |
+
|
| 1372 |
+
|
| 1373 |
+
def check_args_exist(target_type) -> None:
    """Raise (via helper) if ``target_type`` is a bare container annotation."""
    bare_forms = (
        ((List, list), "List"),
        ((Tuple, tuple), "Tuple"),
        ((Dict, dict), "Dict"),
        ((None, Optional), "Optional"),
    )
    for aliases, display_name in bare_forms:
        if any(target_type is alias for alias in aliases):
            raise_error_container_parameter_missing(display_name)
            return
|
| 1382 |
+
|
| 1383 |
+
|
| 1384 |
+
def check_empty_containers(obj) -> None:
    """Warn when an empty container is checked: its element type cannot be verified."""
    is_empty_container = obj == [] or obj == {} or obj == ()
    if is_empty_container:
        warnings.warn(
            "The inner type of a container is lost when "
            "calling torch.jit.isinstance in eager mode. For "
            "example, List[int] would become list and "
            "therefore falsely return True for List[float] or"
            " List[str]."
        )
|
| 1393 |
+
|
| 1394 |
+
|
| 1395 |
+
# supports List/Dict/Tuple and Optional types
# TODO support future
def container_checker(obj, target_type) -> bool:
    """Recursively check that ``obj`` matches the typed container ``target_type``."""
    origin_type = get_origin(target_type)
    check_args_exist(target_type)
    if origin_type is None:
        return False
    elif origin_type is list or origin_type is List:
        check_empty_containers(obj)
        if not isinstance(obj, list):
            return False
        arg_type = get_args(target_type)[0]
        arg_origin = get_origin(arg_type)
        for el in obj:
            # check if nested container, ex: List[List[str]]
            if arg_origin:  # processes nested container, ex: List[List[str]]
                if not container_checker(el, arg_type):
                    return False
            elif not isinstance(el, arg_type):
                return False
        return True
    elif origin_type is Dict or origin_type is dict:
        check_empty_containers(obj)
        if not isinstance(obj, dict):
            return False
        key_type = get_args(target_type)[0]
        val_type = get_args(target_type)[1]
        for key, val in obj.items():
            # check if keys are of right type
            if not isinstance(key, key_type):
                return False
            # values may themselves be typed containers
            val_origin = get_origin(val_type)
            if val_origin:
                if not container_checker(val, val_type):
                    return False
            elif not isinstance(val, val_type):
                return False
        return True
    elif origin_type is Tuple or origin_type is tuple:
        check_empty_containers(obj)
        if not isinstance(obj, tuple):
            return False
        arg_types = get_args(target_type)
        # tuples are fixed-arity: length must match the annotation exactly
        if len(obj) != len(arg_types):
            return False
        for el, el_type in zip(obj, arg_types):
            el_origin = get_origin(el_type)
            if el_origin:
                if not container_checker(el, el_type):
                    return False
            elif not isinstance(el, el_type):
                return False
        return True
    elif origin_type is Union or issubclass(
        origin_type, BuiltinUnionType
    ):  # also handles Optional
        if obj is None:  # check before recursion because None is always fine
            return True
        inner_types = get_args(target_type)
        for t in inner_types:
            t_origin = get_origin(t)
            if t_origin:
                return container_checker(obj, t)
            elif isinstance(obj, t):
                return True
        return False
|
| 1461 |
+
|
| 1462 |
+
|
| 1463 |
+
def _isinstance(obj, target_type) -> bool:
    """Eager-mode backing for ``torch.jit.isinstance``, supporting typing containers."""
    if isinstance(target_type, collections.abc.Container):
        if not isinstance(target_type, tuple):
            raise RuntimeError(
                "The second argument to "
                "`torch.jit.isinstance` must be a type "
                "or a tuple of types"
            )
        return any(_isinstance(obj, candidate) for candidate in target_type)

    if get_origin(target_type):
        return container_checker(obj, target_type)

    # Check to handle non-typed optional origin returns as none instead
    # of as optional in 3.7-3.8
    check_args_exist(target_type)

    # handle non-containers
    return isinstance(obj, target_type)
|
| 1486 |
+
|
| 1487 |
+
|
| 1488 |
+
class _TensorExtractor(pickle.Pickler):
    # Pickler subclass that records every tensor it encounters into the
    # provided ``tensors`` list instead of serializing it.
    def __init__(self, *args, tensors: List[torch.Tensor], **kwargs):
        super().__init__(*args, **kwargs)
        self.tensors = tensors

    def persistent_id(self, obj):
        # Collect tensors and replace them with an empty persistent id so the
        # pickle stream does not carry tensor data.
        if isinstance(obj, torch.Tensor):
            self.tensors.append(obj)
            return ""
        # Since we just want to extract tensors, we don't mind if an object is
        # unpicklable if it doesn't contain tensors, as we can just ignore/skip
        # it. To play it safe, we only do so for common objects that we're sure
        # don't contain tensors. Feel free to add new types here. Note also that
        # even if a type isn't listed here this won't block users, since they
        # can just add a __getstate__ or __reduce__ method to their class.
        if isinstance(obj, LockType):
            return ""
        # Futures and RRefs don't technically contain a value, they just offer
        # the means to access a value.
        if isinstance(obj, CFuture) or is_rref_instance(obj):
            return ""
        if isinstance(obj, CAwait):
            return ""
        if isinstance(obj, torch.cuda.Event):
            return ""
        if isinstance(obj, threading.Thread):
            return ""
        # None means "pickle normally" so nested containers are traversed.
        return None
|
| 1516 |
+
|
| 1517 |
+
|
| 1518 |
+
def _extract_tensors(obj):
    r"""
    This function is exclusively called from C++.
    See ``torch/csrc/jit/python/python_ivalue.h``.

    It extracts the tensors contained in the given object, through pickling.
    """
    found: List[torch.Tensor] = []
    _TensorExtractor(io.BytesIO(), protocol=-1, tensors=found).dump(obj)
    return found
|
| 1529 |
+
|
| 1530 |
+
|
| 1531 |
+
def _get_model_id(obj) -> Optional[str]:
|
| 1532 |
+
if isinstance(obj, torch.jit.ScriptModule):
|
| 1533 |
+
return str(obj._c._type())
|
| 1534 |
+
elif isinstance(obj, torch.jit.ScriptFunction):
|
| 1535 |
+
return obj.qualified_name
|
| 1536 |
+
else:
|
| 1537 |
+
return None
|
| 1538 |
+
|
| 1539 |
+
|
| 1540 |
+
# In Python-3.11+ typed enums (i.e. IntEnum for example) retain number of base class methods in subclass
# that were previously dropped. To preserve the behavior, explicitly drop them there

if sys.version_info > (3, 10):
    # Mark these Enum dunders with the _DROP modifier so TorchScript skips them.
    _drop(enum.Enum.__new__)
    _drop(enum.Enum.__format__)
    _drop(enum.Enum.__repr__)
    _drop(enum.Enum.__str__)
|
.venv/lib/python3.11/site-packages/torch/_linalg_utils.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
"""Various linear algebra utility methods for internal use."""
|
| 3 |
+
|
| 4 |
+
from typing import Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def is_sparse(A):
    """Return True when ``A`` is a sparse COO tensor, False for any other tensor.

    Raises:
        TypeError: if ``A`` is not a ``torch.Tensor`` at all.
    """
    if isinstance(A, torch.Tensor):
        return A.layout == torch.sparse_coo

    message = "expected Tensor"
    # TorchScript cannot format arbitrary Python types into strings, so the
    # offending type is only reported when running in eager mode.
    if not torch.jit.is_scripting():
        message += f" but got {type(A)}"
    raise TypeError(message)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_floating_dtype(A):
    """Return a floating-point dtype suitable for tensor ``A``.

    The tensor's own dtype is kept when it is already half, single, or double
    precision; any other dtype (integers in particular) maps to
    ``torch.float32``.
    """
    floating = (torch.float16, torch.float32, torch.float64)
    return A.dtype if A.dtype in floating else torch.float32
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def matmul(A: Optional[Tensor], B: Tensor) -> Tensor:
    """Multiply two matrices.

    ``A`` may be ``None`` (treated as the identity, so ``B`` is returned
    unchanged), sparse, or dense. ``B`` is always dense.
    """
    if A is None:
        return B
    # Sparse COO inputs need the dedicated sparse kernel.
    multiply = torch.sparse.mm if is_sparse(A) else torch.matmul
    return multiply(A, B)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def bform(X: Tensor, A: Optional[Tensor], Y: Tensor) -> Tensor:
    """Return the bilinear form :math:`X^T A Y` (``A=None`` acts as identity)."""
    inner = matmul(A, Y)
    return matmul(X.mT, inner)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def qform(A: Optional[Tensor], S: Tensor):
    """Return the quadratic form :math:`S^T A S` (``A=None`` acts as identity)."""
    # A quadratic form is just the bilinear form with both sides equal to S.
    return bform(S, A, S)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def basis(A):
    """Return an orthonormal basis for the column space of ``A``.

    Computed via the reduced QR factorization; only the ``Q`` factor is used.
    """
    Q, _ = torch.linalg.qr(A)
    return Q
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def symeig(A: Tensor, largest: Optional[bool] = False) -> Tuple[Tensor, Tensor]:
    """Return eigenpairs of symmetric ``A`` with the requested ordering.

    ``torch.linalg.eigh`` produces eigenvalues in ascending order; when
    ``largest`` is truthy, both the eigenvalues and the matching eigenvector
    columns are reversed so the largest eigenvalue comes first.  ``None`` is
    treated the same as ``False``.
    """
    want_largest = False if largest is None else largest
    E, Z = torch.linalg.eigh(A, UPLO="U")
    if not want_largest:
        return E, Z
    return torch.flip(E, dims=(-1,)), torch.flip(Z, dims=(-1,))
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# These functions were deprecated and removed
|
| 73 |
+
# This nice error message can be removed in version 1.13+
|
| 74 |
+
def matrix_rank(input, tol=None, symmetric=False, *, out=None) -> Tensor:
    """Removed stub kept only to raise a helpful migration error.

    Raises:
        RuntimeError: always; callers must use ``torch.linalg.matrix_rank``.
    """
    message = (
        "This function was deprecated since version 1.9 and is now removed.\n"
        "Please use the `torch.linalg.matrix_rank` function instead. "
        "The parameter 'symmetric' was renamed in `torch.linalg.matrix_rank()` to 'hermitian'."
    )
    raise RuntimeError(message)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def solve(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]:
    """Removed stub kept only to raise a helpful migration error.

    Raises:
        RuntimeError: always; callers must use ``torch.linalg.solve``.
    """
    message = (
        "This function was deprecated since version 1.9 and is now removed. "
        "`torch.solve` is deprecated in favor of `torch.linalg.solve`. "
        "`torch.linalg.solve` has its arguments reversed and does not return the LU factorization.\n\n"
        "To get the LU factorization see `torch.lu`, which can be used with `torch.lu_solve` or `torch.lu_unpack`.\n"
        "X = torch.solve(B, A).solution "
        "should be replaced with:\n"
        "X = torch.linalg.solve(A, B)"
    )
    raise RuntimeError(message)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def lstsq(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]:
    """Removed stub kept only to raise a helpful migration error.

    Raises:
        RuntimeError: always; callers must use ``torch.linalg.lstsq``.
    """
    message = (
        "This function was deprecated since version 1.9 and is now removed. "
        "`torch.lstsq` is deprecated in favor of `torch.linalg.lstsq`.\n"
        "`torch.linalg.lstsq` has reversed arguments and does not return the QR decomposition in "
        "the returned tuple (although it returns other information about the problem).\n\n"
        "To get the QR decomposition consider using `torch.linalg.qr`.\n\n"
        "The returned solution in `torch.lstsq` stored the residuals of the solution in the "
        "last m - n columns of the returned value whenever m > n. In torch.linalg.lstsq, "
        "the residuals are in the field 'residuals' of the returned named tuple.\n\n"
        "The unpacking of the solution, as in\n"
        "X, _ = torch.lstsq(B, A).solution[:A.size(1)]\n"
        "should be replaced with:\n"
        "X = torch.linalg.lstsq(A, B).solution"
    )
    raise RuntimeError(message)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _symeig(
|
| 112 |
+
input,
|
| 113 |
+
eigenvectors=False,
|
| 114 |
+
upper=True,
|
| 115 |
+
*,
|
| 116 |
+
out=None,
|
| 117 |
+
) -> Tuple[Tensor, Tensor]:
|
| 118 |
+
raise RuntimeError(
|
| 119 |
+
"This function was deprecated since version 1.9 and is now removed. "
|
| 120 |
+
"The default behavior has changed from using the upper triangular portion of the matrix by default "
|
| 121 |
+
"to using the lower triangular portion.\n\n"
|
| 122 |
+
"L, _ = torch.symeig(A, upper=upper) "
|
| 123 |
+
"should be replaced with:\n"
|
| 124 |
+
"L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')\n\n"
|
| 125 |
+
"and\n\n"
|
| 126 |
+
"L, V = torch.symeig(A, eigenvectors=True) "
|
| 127 |
+
"should be replaced with:\n"
|
| 128 |
+
"L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L')"
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def eig(
    self: Tensor,
    eigenvectors: bool = False,
    *,
    e=None,
    v=None,
) -> Tuple[Tensor, Tensor]:
    """Removed stub kept only to raise a helpful migration error.

    Raises:
        RuntimeError: always; callers must use ``torch.linalg.eig`` or
            ``torch.linalg.eigvals``.
    """
    message = (
        "This function was deprecated since version 1.9 and is now removed. "
        "`torch.linalg.eig` returns complex tensors of dtype `cfloat` or `cdouble` rather than real tensors "
        "mimicking complex tensors.\n\n"
        "L, _ = torch.eig(A) "
        "should be replaced with:\n"
        "L_complex = torch.linalg.eigvals(A)\n\n"
        "and\n\n"
        "L, V = torch.eig(A, eigenvectors=True) "
        "should be replaced with:\n"
        "L_complex, V_complex = torch.linalg.eig(A)"
    )
    raise RuntimeError(message)
|
.venv/lib/python3.11/site-packages/torch/_lobpcg.py
ADDED
|
@@ -0,0 +1,1157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
"""Locally Optimal Block Preconditioned Conjugate Gradient methods."""
|
| 3 |
+
# Author: Pearu Peterson
|
| 4 |
+
# Created: February 2020
|
| 5 |
+
|
| 6 |
+
from typing import Dict, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import _linalg_utils as _utils, Tensor
|
| 10 |
+
from torch.overrides import handle_torch_function, has_torch_function
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
__all__ = ["lobpcg"]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
|
| 17 |
+
# compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
|
| 18 |
+
F = D.unsqueeze(-2) - D.unsqueeze(-1)
|
| 19 |
+
F.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
|
| 20 |
+
F.pow_(-1)
|
| 21 |
+
|
| 22 |
+
# A.grad = U (D.grad + (U^T U.grad * F)) U^T
|
| 23 |
+
Ut = U.mT.contiguous()
|
| 24 |
+
res = torch.matmul(
|
| 25 |
+
U, torch.matmul(torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, Ut)
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
return res
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _polynomial_coefficients_given_roots(roots):
    """
    Given the `roots` of a polynomial, find the polynomial's coefficients.

    If roots = (r_1, ..., r_n), then the method returns
    coefficients (a_0, a_1, ..., a_n (== 1)) so that
    p(x) = (x - r_1) * ... * (x - r_n)
         = x^n + a_{n-1} * x^{n-1} + ... a_1 * x_1 + a_0

    Note: for better performance requires writing a low-level kernel
    """
    poly_order = roots.shape[-1]
    poly_coeffs_shape = list(roots.shape)
    # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0,
    # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)},
    # but we insert one extra coefficient to enable better vectorization below
    poly_coeffs_shape[-1] += 2
    poly_coeffs = roots.new_zeros(poly_coeffs_shape)
    poly_coeffs[..., 0] = 1
    poly_coeffs[..., -1] = 1

    # perform the Horner's rule
    for i in range(1, poly_order + 1):
        # note that it is computationally hard to compute backward for this method,
        # because then given the coefficients it would require finding the roots and/or
        # calculating the sensitivity based on the Vieta's theorem.
        # So the code below tries to circumvent the explicit root finding by series
        # of operations on memory copies imitating the Horner's method.
        # The memory copies are required to construct nodes in the computational graph
        # by exploting the explicit (not in-place, separate node for each step)
        # recursion of the Horner's method.
        # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity.
        #
        # When gradients are not needed the clone is skipped and the update is
        # performed in place on the same buffer (poly_coeffs_new aliases
        # poly_coeffs); with gradients, the clone gives each Horner step its
        # own autograd node.
        poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs
        # `out` is a view into poly_coeffs_new, so the in-place `-=` below
        # updates the relevant coefficient window of the new buffer.
        out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1)
        out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(
            -1, poly_order - i + 1, i + 1
        )
        poly_coeffs = poly_coeffs_new

    # Drop the two helper slots; the result is (a_0, ..., a_n) with a_n == 1.
    return poly_coeffs.narrow(-1, 1, poly_order + 1)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _polynomial_value(poly, x, zero_power, transition):
|
| 74 |
+
"""
|
| 75 |
+
A generic method for computing poly(x) using the Horner's rule.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
poly (Tensor): the (possibly batched) 1D Tensor representing
|
| 79 |
+
polynomial coefficients such that
|
| 80 |
+
poly[..., i] = (a_{i_0}, ..., a{i_n} (==1)), and
|
| 81 |
+
poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n
|
| 82 |
+
|
| 83 |
+
x (Tensor): the value (possible batched) to evalate the polynomial `poly` at.
|
| 84 |
+
|
| 85 |
+
zero_power (Tensor): the representation of `x^0`. It is application-specific.
|
| 86 |
+
|
| 87 |
+
transition (Callable): the function that accepts some intermediate result `int_val`,
|
| 88 |
+
the `x` and a specific polynomial coefficient
|
| 89 |
+
`poly[..., k]` for some iteration `k`.
|
| 90 |
+
It basically performs one iteration of the Horner's rule
|
| 91 |
+
defined as `x * int_val + poly[..., k] * zero_power`.
|
| 92 |
+
Note that `zero_power` is not a parameter,
|
| 93 |
+
because the step `+ poly[..., k] * zero_power` depends on `x`,
|
| 94 |
+
whether it is a vector, a matrix, or something else, so this
|
| 95 |
+
functionality is delegated to the user.
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
res = zero_power.clone()
|
| 99 |
+
for k in range(poly.size(-1) - 2, -1, -1):
|
| 100 |
+
res = transition(res, x, poly[..., k])
|
| 101 |
+
return res
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _matrix_polynomial_value(poly, x, zero_power=None):
    """
    Evaluates `poly(x)` for the (batched) matrix input `x`.
    Check out `_polynomial_value` function for more details.
    """

    def step(acc, mat, coeff):
        # One matrix-aware Horner iteration: acc <- mat @ acc + coeff * I,
        # with the coeff * I term added in place on the diagonal to avoid
        # materializing the scaled identity.
        nxt = mat.matmul(acc)
        nxt.diagonal(dim1=-2, dim2=-1).add_(coeff.unsqueeze(-1))
        return nxt

    if zero_power is None:
        # x^0 for matrices is the identity, broadcast over the batch dims.
        n = x.size(-1)
        batch_ones = [1] * len(list(x.shape[:-2]))
        zero_power = torch.eye(n, n, dtype=x.dtype, device=x.device).view(
            *batch_ones, n, n
        )

    return _polynomial_value(poly, x, zero_power, step)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _vector_polynomial_value(poly, x, zero_power=None):
    """
    Evaluates `poly(x)` for the (batched) vector input `x`.
    Check out `_polynomial_value` function for more details.
    """

    def step(acc, vec, coeff):
        # One elementwise Horner iteration: acc <- vec * acc + coeff.
        return torch.addcmul(coeff.unsqueeze(-1), vec, acc)

    if zero_power is None:
        # x^0 for vectors is elementwise 1, expanded (not copied) to x's shape.
        zero_power = x.new_ones(1).expand(x.shape)

    return _polynomial_value(poly, x, zero_power, step)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
    """Backward of a symmetric eigendecomposition when only k < m eigenpairs
    were computed.

    Combines the complete-eigenspace gradient restricted to span(U) with an
    explicit polynomial solution of the Sylvester equation for the component
    of the gradient living in the orthogonal complement of span(U).
    """
    # compute a projection operator onto an orthogonal subspace spanned by the
    # columns of U defined as (I - UU^T)
    Ut = U.mT.contiguous()
    proj_U_ortho = -U.matmul(Ut)
    proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)

    # compute U_ortho, a basis for the orthogonal complement to the span(U),
    # by projecting a random [..., m, m - k] matrix onto the subspace spanned
    # by the columns of U.
    #
    # fix generator for determinism
    gen = torch.Generator(A.device)

    # orthogonal complement to the span(U)
    U_ortho = proj_U_ortho.matmul(
        torch.randn(
            (*A.shape[:-1], A.size(-1) - D.size(-1)),
            dtype=A.dtype,
            device=A.device,
            generator=gen,
        )
    )
    U_ortho_t = U_ortho.mT.contiguous()

    # compute the coefficients of the characteristic polynomial of the tensor D.
    # Note that D is diagonal, so the diagonal elements are exactly the roots
    # of the characteristic polynomial.
    chr_poly_D = _polynomial_coefficients_given_roots(D)

    # the code belows finds the explicit solution to the Sylvester equation
    # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U
    # and incorporates it into the whole gradient stored in the `res` variable.
    #
    # Equivalent to the following naive implementation:
    # res = A.new_zeros(A.shape)
    # p_res = A.new_zeros(*A.shape[:-1], D.size(-1))
    # for k in range(1, chr_poly_D.size(-1)):
    #     p_res.zero_()
    #     for i in range(0, k):
    #         p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2)
    #     res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @ p_res @ U.t())
    #
    # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity
    # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g,
    # and we need to compute g(U_grad, A, U, D)
    #
    # The naive implementation is based on the paper
    # Hu, Qingxi, and Daizhan Cheng.
    # "The polynomial solution to the Sylvester matrix equation."
    # Applied mathematics letters 19.9 (2006): 859-864.
    #
    # We can modify the computation of `p_res` from above in a more efficient way
    # p_res =   U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k)).unsqueeze(-2)
    #       + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2)
    #       + ...
    #       + A.matrix_power(k - 1) U_grad * chr_poly_D[k]
    # Note that this saves us from redundant matrix products with A (elimination of matrix_power)
    U_grad_projected = U_grad
    series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
    for k in range(1, chr_poly_D.size(-1)):
        poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D)
        series_acc += U_grad_projected * poly_D.unsqueeze(-2)
        U_grad_projected = A.matmul(U_grad_projected)

    # compute chr_poly_D(A) which essentially is:
    #
    # chr_poly_D_at_A = A.new_zeros(A.shape)
    # for k in range(chr_poly_D.size(-1)):
    #     chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k)
    #
    # Note, however, for better performance we use the Horner's rule
    chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)

    # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t
    chr_poly_D_at_A_to_U_ortho = torch.matmul(
        U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho)
    )
    # we need to invert 'chr_poly_D_at_A_to_U_ortho`, for that we compute its
    # Cholesky decomposition and then use `torch.cholesky_solve` for better stability.
    # Cholesky decomposition requires the input to be positive-definite.
    # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if
    # 1. `largest` == False, or
    # 2. `largest` == True and `k` is even
    # under the assumption that `A` has distinct eigenvalues.
    #
    # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite
    #
    # NOTE(review): `k` here is the leftover value of the loop variable from
    # the `for k in range(1, chr_poly_D.size(-1))` loop above (its final
    # value equals the number of computed eigenpairs) — presumably
    # intentional, but worth confirming; it would raise NameError if the
    # loop body never ran.
    chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1
    chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky(
        chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
    )

    # compute the gradient part in span(U)
    res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)

    # incorporate the Sylvester equation solution into the full gradient
    # it resides in span(U_ortho)
    res -= U_ortho.matmul(
        chr_poly_D_at_A_to_U_ortho_sign
        * torch.cholesky_solve(
            U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L
        )
    ).matmul(Ut)

    return res
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _symeig_backward(D_grad, U_grad, A, D, U, largest):
    """Dispatch the symeig backward to the complete- or partial-eigenspace path."""
    # A square U means every eigenvector was computed, which permits the
    # cheaper closed-form gradient; otherwise the Sylvester-equation based
    # partial-eigenspace path is required.
    has_complete_eigenspace = U.size(-1) == U.size(-2)
    if has_complete_eigenspace:
        return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
    return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class LOBPCGAutogradFunction(torch.autograd.Function):
    """Autograd wrapper around the ``_lobpcg`` solver.

    ``forward`` runs the (non-differentiable) LOBPCG iteration and stashes
    the tensors needed for the gradient; ``backward`` implements the
    symmetric-eigendecomposition gradient, supported only for dense real
    input with ``B is None``.
    """

    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        A: Tensor,
        k: Optional[int] = None,
        B: Optional[Tensor] = None,
        X: Optional[Tensor] = None,
        n: Optional[int] = None,
        iK: Optional[Tensor] = None,
        niter: Optional[int] = None,
        tol: Optional[float] = None,
        largest: Optional[bool] = None,
        method: Optional[str] = None,
        tracker: None = None,
        ortho_iparams: Optional[Dict[str, int]] = None,
        ortho_fparams: Optional[Dict[str, float]] = None,
        ortho_bparams: Optional[Dict[str, bool]] = None,
    ) -> Tuple[Tensor, Tensor]:
        # makes sure that input is contiguous for efficiency.
        # Note: autograd does not support dense gradients for sparse input yet.
        A = A.contiguous() if (not A.is_sparse) else A
        if B is not None:
            B = B.contiguous() if (not B.is_sparse) else B

        # Delegate to the actual solver; D are eigenvalues, U eigenvectors.
        D, U = _lobpcg(
            A,
            k,
            B,
            X,
            n,
            iK,
            niter,
            tol,
            largest,
            method,
            tracker,
            ortho_iparams,
            ortho_fparams,
            ortho_bparams,
        )

        # The inputs and the eigenpairs are all needed to assemble the
        # gradient in backward().
        ctx.save_for_backward(A, B, D, U)
        ctx.largest = largest

        return D, U

    @staticmethod
    def backward(ctx, D_grad, U_grad):
        A_grad = B_grad = None
        # One gradient slot per forward() argument after ctx (14 in total);
        # only A (index 0) and B (index 2) can ever receive a gradient.
        grads = [None] * 14

        A, B, D, U = ctx.saved_tensors
        largest = ctx.largest

        # lobpcg.backward has some limitations. Checks for unsupported input
        if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
            raise ValueError(
                "lobpcg.backward does not support sparse input yet."
                "Note that lobpcg.forward does though."
            )
        if (
            A.dtype in (torch.complex64, torch.complex128)
            or B is not None
            and B.dtype in (torch.complex64, torch.complex128)
        ):
            raise ValueError(
                "lobpcg.backward does not support complex input yet."
                "Note that lobpcg.forward does though."
            )
        if B is not None:
            raise ValueError(
                "lobpcg.backward does not support backward with B != I yet."
            )

        # forward() treats largest=None as the default, which is True.
        if largest is None:
            largest = True

        # symeig backward
        if B is None:
            A_grad = _symeig_backward(D_grad, U_grad, A, D, U, largest)

        # A has index 0
        grads[0] = A_grad
        # B has index 2
        grads[2] = B_grad
        return tuple(grads)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def lobpcg(
    A: Tensor,
    k: Optional[int] = None,
    B: Optional[Tensor] = None,
    X: Optional[Tensor] = None,
    n: Optional[int] = None,
    iK: Optional[Tensor] = None,
    niter: Optional[int] = None,
    tol: Optional[float] = None,
    largest: Optional[bool] = None,
    method: Optional[str] = None,
    tracker: None = None,
    ortho_iparams: Optional[Dict[str, int]] = None,
    ortho_fparams: Optional[Dict[str, float]] = None,
    ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
    """Find the k largest (or smallest) eigenvalues and the corresponding
    eigenvectors of a symmetric positive definite generalized
    eigenvalue problem using matrix-free LOBPCG methods.

    This function is a front-end to the following LOBPCG algorithms
    selectable via `method` argument:

      `method="basic"` - the LOBPCG method introduced by Andrew
      Knyazev, see [Knyazev2001]. A less robust method, may fail when
      Cholesky is applied to singular input.

      `method="ortho"` - the LOBPCG method with orthogonal basis
      selection [StathopoulosEtal2002]. A robust method.

    Supported inputs are dense, sparse, and batches of dense matrices.

    .. warning:: The backward method does not support sparse and complex inputs.
      It works only when `B` is not provided (i.e. `B == None`).

    .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not.
      To keep `A - t * A.grad` symmetric in first-order optimization
      routines, prior to running `lobpcg` we apply the symmetrization
      map `A -> (A + A.t()) / 2` (only when `A` requires gradients).

    Args:
      A (Tensor): the input tensor of size :math:`(*, m, m)`
      k (integer, optional): the number of requested eigenpairs.
          Default is the number of :math:`X` columns (when specified) or 1.
      B (Tensor, optional): the input tensor of size :math:`(*, m, m)`.
          When not specified, `B` is interpreted as identity matrix.
      X (tensor, optional): dense tensor of size :math:`(*, m, n)` where
          `k <= n <= m`, used as initial approximation of eigenvectors.
      n (integer, optional): if :math:`X` is not specified then `n`
          specifies the size of the generated random approximation of
          eigenvectors. Defaults to `k`.
      iK (tensor, optional): preconditioner tensor of size :math:`(*, m, m)`.
      niter (int, optional): maximum number of iterations; use `-1` to
          iterate until convergence.
      tol (float, optional): residual tolerance for stopping criterion.
          Default is `feps ** 0.5` where `feps` is smallest non-zero
          floating-point number of the `A` data type.
      largest (bool, optional): when True (default), solve for the
          largest eigenvalues, otherwise for the smallest.
      method (str, optional): "basic" or "ortho" (default).
      tracker (callable, optional): called at each iteration step with
          the LOBPCG instance as argument; may force-stop the iteration
          by setting `bvars["force_stop"] = True`.
      ortho_iparams, ortho_fparams, ortho_bparams (dict, optional):
          various parameters to the `method="ortho"` algorithm.

    Returns:
      E (Tensor): tensor of eigenvalues of size :math:`(*, k)`
      X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)`

    References:
      [Knyazev2001] https://epubs.siam.org/doi/abs/10.1137/S1064827500366124
      [StathopoulosEtal2002] https://epubs.siam.org/doi/10.1137/S1064827500370883
      [DuerschEtal2018] https://epubs.siam.org/doi/abs/10.1137/17M1129830
    """
    if not torch.jit.is_scripting():
        tensor_ops = (A, B, X, iK)
        # Defer to __torch_function__ overrides for tensor-like subclasses.
        if not set(map(type, tensor_ops)).issubset(
            (torch.Tensor, type(None))
        ) and has_torch_function(tensor_ops):
            return handle_torch_function(
                lobpcg,
                tensor_ops,
                A,
                k=k,
                B=B,
                X=X,
                n=n,
                iK=iK,
                niter=niter,
                tol=tol,
                largest=largest,
                method=method,
                tracker=tracker,
                ortho_iparams=ortho_iparams,
                ortho_fparams=ortho_fparams,
                ortho_bparams=ortho_bparams,
            )

    if not torch._jit_internal.is_scripting():
        if A.requires_grad or (B is not None and B.requires_grad):
            # While it is expected that `A` is symmetric,
            # the `A_grad` might be not. Therefore we perform the trick below,
            # so that `A_grad` becomes symmetric.
            # The symmetrization is important for first-order optimization methods,
            # so that (A - alpha * A_grad) is still a symmetric matrix.
            # Same holds for `B`.
            A_sym = (A + A.mT) / 2
            B_sym = (B + B.mT) / 2 if (B is not None) else None

            return LOBPCGAutogradFunction.apply(
                A_sym,
                k,
                B_sym,
                X,
                n,
                iK,
                niter,
                tol,
                largest,
                method,
                tracker,
                ortho_iparams,
                ortho_fparams,
                ortho_bparams,
            )
    else:
        if A.requires_grad or (B is not None and B.requires_grad):
            raise RuntimeError(
                "Script and require grads is not supported atm."
                "If you just want to do the forward, use .detach()"
                "on A and B before calling into lobpcg"
            )

    return _lobpcg(
        A,
        k,
        B,
        X,
        n,
        iK,
        niter,
        tol,
        largest,
        method,
        tracker,
        ortho_iparams,
        ortho_fparams,
        ortho_bparams,
    )
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
def _lobpcg(
    A: Tensor,
    k: Optional[int] = None,
    B: Optional[Tensor] = None,
    X: Optional[Tensor] = None,
    n: Optional[int] = None,
    iK: Optional[Tensor] = None,
    niter: Optional[int] = None,
    tol: Optional[float] = None,
    largest: Optional[bool] = None,
    method: Optional[str] = None,
    tracker: None = None,
    ortho_iparams: Optional[Dict[str, int]] = None,
    ortho_fparams: Optional[Dict[str, float]] = None,
    ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
    """Dispatch LOBPCG to one or more `LOBPCG` workers.

    Validates shapes, fills in parameter defaults, then runs one worker
    per batch element (for batched `A`) or a single worker otherwise.
    See `lobpcg` for the argument documentation.
    """
    # A must be square:
    assert A.shape[-2] == A.shape[-1], A.shape
    if B is not None:
        # A and B must have the same shapes:
        assert A.shape == B.shape, (A.shape, B.shape)

    dtype = _utils.get_floating_dtype(A)
    device = A.device
    if tol is None:
        # feps is the smallest representable step of the working dtype;
        # sqrt(feps) is the customary residual tolerance.
        feps = {torch.float32: 1.2e-07, torch.float64: 2.23e-16}[dtype]
        tol = feps**0.5

    m = A.shape[-1]
    k = (1 if X is None else X.shape[-1]) if k is None else k
    n = (k if n is None else n) if X is None else X.shape[-1]

    # The search subspace S holds [X, P, W], i.e. 3*n columns of length m.
    if m < 3 * n:
        raise ValueError(
            f"LPBPCG algorithm is not applicable when the number of A rows (={m})"
            f" is smaller than 3 x the number of requested eigenpairs (={n})"
        )

    method = "ortho" if method is None else method

    iparams = {
        "m": m,
        "n": n,
        "k": k,
        "niter": 1000 if niter is None else niter,
    }

    fparams = {
        "tol": tol,
    }

    bparams = {"largest": True if largest is None else largest}

    if method == "ortho":
        if ortho_iparams is not None:
            iparams.update(ortho_iparams)
        if ortho_fparams is not None:
            fparams.update(ortho_fparams)
        if ortho_bparams is not None:
            bparams.update(ortho_bparams)
        # Defaults for the orthogonalization loops (Algorithm 3/6 of
        # [DuerschPhD2015]); user-supplied values take precedence.
        iparams["ortho_i_max"] = iparams.get("ortho_i_max", 3)
        iparams["ortho_j_max"] = iparams.get("ortho_j_max", 3)
        fparams["ortho_tol"] = fparams.get("ortho_tol", tol)
        fparams["ortho_tol_drop"] = fparams.get("ortho_tol_drop", tol)
        fparams["ortho_tol_replace"] = fparams.get("ortho_tol_replace", tol)
        bparams["ortho_use_drop"] = bparams.get("ortho_use_drop", False)

    if not torch.jit.is_scripting():
        LOBPCG.call_tracker = LOBPCG_call_tracker  # type: ignore[method-assign]

    if len(A.shape) > 2:
        # Batched input: flatten leading dims and solve each problem
        # independently, then restore the batch shape.
        N = int(torch.prod(torch.tensor(A.shape[:-2])))
        bA = A.reshape((N,) + A.shape[-2:])
        bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None
        bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None
        bE = torch.empty((N, k), dtype=dtype, device=device)
        bXret = torch.empty((N, m, k), dtype=dtype, device=device)

        for i in range(N):
            A_ = bA[i]
            B_ = bB[i] if bB is not None else None
            X_ = (
                torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i]
            )
            assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n))
            iparams["batch_index"] = i
            worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker)
            worker.run()
            bE[i] = worker.E[:k]
            bXret[i] = worker.X[:, :k]

        if not torch.jit.is_scripting():
            LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore[method-assign]

        return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k))

    X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X
    assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n))

    worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker)

    worker.run()

    if not torch.jit.is_scripting():
        LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore[method-assign]

    return worker.E[:k], worker.X[:, :k]
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
class LOBPCG:
    """Worker class of LOBPCG methods.

    Holds the full iteration state of a single (non-batched) LOBPCG
    solve; `_lobpcg` constructs one worker per problem and calls `run`.
    """

    def __init__(
        self,
        A: Optional[Tensor],
        B: Optional[Tensor],
        X: Tensor,
        iK: Optional[Tensor],
        iparams: Dict[str, int],
        fparams: Dict[str, float],
        bparams: Dict[str, bool],
        method: str,
        tracker: None,
    ) -> None:
        # constant parameters
        self.A = A
        self.B = B
        self.iK = iK
        self.iparams = iparams
        self.fparams = fparams
        self.bparams = bparams
        self.method = method
        self.tracker = tracker
        m = iparams["m"]
        n = iparams["n"]

        # variable parameters
        self.X = X
        # E: current eigenvalue approximations; R: residuals;
        # S: search subspace basis [X, P, W] (hence 3*n columns).
        self.E = torch.zeros((n,), dtype=X.dtype, device=X.device)
        self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device)
        self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device)
        self.tvars: Dict[str, Tensor] = {}
        self.ivars: Dict[str, int] = {"istep": 0}
        self.fvars: Dict[str, float] = {"_": 0.0}
        self.bvars: Dict[str, bool] = {"_": False}

    def __str__(self):
        lines = ["LOPBCG:"]
        lines += [f"  iparams={self.iparams}"]
        lines += [f"  fparams={self.fparams}"]
        lines += [f"  bparams={self.bparams}"]
        lines += [f"  ivars={self.ivars}"]
        lines += [f"  fvars={self.fvars}"]
        lines += [f"  bvars={self.bvars}"]
        lines += [f"  tvars={self.tvars}"]
        lines += [f"  A={self.A}"]
        lines += [f"  B={self.B}"]
        lines += [f"  iK={self.iK}"]
        lines += [f"  X={self.X}"]
        lines += [f"  E={self.E}"]
        r = ""
        for line in lines:
            r += line + "\n"
        return r

    def update(self):
        """Set and update iteration variables."""
        if self.ivars["istep"] == 0:
            # First step: record norms used by the convergence criterion.
            X_norm = float(torch.norm(self.X))
            iX_norm = X_norm**-1
            A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm
            B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm
            self.fvars["X_norm"] = X_norm
            self.fvars["A_norm"] = A_norm
            self.fvars["B_norm"] = B_norm
            self.ivars["iterations_left"] = self.iparams["niter"]
            self.ivars["converged_count"] = 0
            self.ivars["converged_end"] = 0

        if self.method == "ortho":
            self._update_ortho()
        else:
            self._update_basic()

        self.ivars["iterations_left"] = self.ivars["iterations_left"] - 1
        self.ivars["istep"] = self.ivars["istep"] + 1

    def update_residual(self):
        """Update residual R from A, B, X, E."""
        mm = _utils.matmul
        self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E

    def update_converged_count(self):
        """Determine the number of converged eigenpairs using backward stable
        convergence criterion, see discussion in Sec 4.3 of [DuerschEtal2018].

        Users may redefine this method for custom convergence criteria.
        """
        # (...) -> int
        prev_count = self.ivars["converged_count"]
        tol = self.fparams["tol"]
        A_norm = self.fvars["A_norm"]
        B_norm = self.fvars["B_norm"]
        E, X, R = self.E, self.X, self.R
        rerr = (
            torch.norm(R, 2, (0,))
            * (torch.norm(X, 2, (0,)) * (A_norm + E[: X.shape[-1]] * B_norm)) ** -1
        )
        converged = rerr.real < tol  # this is a norm so imag is 0.0
        count = 0
        for b in converged:
            if not b:
                # ignore convergence of following pairs to ensure
                # strict ordering of eigenpairs
                break
            count += 1
        assert (
            count >= prev_count
        ), f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease"
        self.ivars["converged_count"] = count
        self.tvars["rerr"] = rerr
        return count

    def stop_iteration(self):
        """Return True to stop iterations.

        Note that tracker (if defined) can force-stop iterations by
        setting ``worker.bvars['force_stop'] = True``.
        """
        return (
            self.bvars.get("force_stop", False)
            or self.ivars["iterations_left"] == 0
            or self.ivars["converged_count"] >= self.iparams["k"]
        )

    def run(self):
        """Run LOBPCG iterations.

        Use this method as a template for implementing LOBPCG
        iteration scheme with custom tracker that is compatible with
        TorchScript.
        """
        self.update()

        if not torch.jit.is_scripting() and self.tracker is not None:
            self.call_tracker()

        while not self.stop_iteration():
            self.update()

            if not torch.jit.is_scripting() and self.tracker is not None:
                self.call_tracker()

    @torch.jit.unused
    def call_tracker(self):
        """Interface for tracking iteration process in Python mode.

        Tracking the iteration process is disabled in TorchScript
        mode. In fact, one should specify tracker=None when JIT
        compiling functions using lobpcg.
        """
        # do nothing when in TorchScript mode

    # Internal methods

    def _update_basic(self):
        """
        Update or initialize iteration variables when `method == "basic"`.
        """
        mm = torch.matmul
        ns = self.ivars["converged_end"]
        nc = self.ivars["converged_count"]
        n = self.iparams["n"]
        largest = self.bparams["largest"]

        if self.ivars["istep"] == 0:
            # Initial Rayleigh-Ritz on the starting block X.
            Ri = self._get_rayleigh_ritz_transform(self.X)
            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
            E, Z = _utils.symeig(M, largest)
            self.X[:] = mm(self.X, mm(Ri, Z))
            self.E[:] = E
            np = 0
            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X

            W = _utils.matmul(self.iK, self.R)
            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W
        else:
            # Rayleigh-Ritz on the active part of the subspace S.
            S_ = self.S[:, nc:ns]
            Ri = self._get_rayleigh_ritz_transform(S_)
            M = _utils.qform(_utils.qform(self.A, S_), Ri)
            E_, Z = _utils.symeig(M, largest)
            self.X[:, nc:] = mm(S_, mm(Ri, Z[:, : n - nc]))
            self.E[nc:] = E_[: n - nc]
            P = mm(S_, mm(Ri, Z[:, n : 2 * n - nc]))
            np = P.shape[-1]

            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X
            self.S[:, n : n + np] = P
            W = _utils.matmul(self.iK, self.R[:, nc:])

            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W

    def _update_ortho(self):
        """
        Update or initialize iteration variables when `method == "ortho"`.
        """
        mm = torch.matmul
        ns = self.ivars["converged_end"]
        nc = self.ivars["converged_count"]
        n = self.iparams["n"]
        largest = self.bparams["largest"]

        if self.ivars["istep"] == 0:
            Ri = self._get_rayleigh_ritz_transform(self.X)
            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
            E, Z = _utils.symeig(M, largest)
            self.X = mm(self.X, mm(Ri, Z))
            self.update_residual()
            np = 0
            nc = self.update_converged_count()
            self.S[:, :n] = self.X
            W = self._get_ortho(self.R, self.X)
            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W

        else:
            S_ = self.S[:, nc:ns]
            # Rayleigh-Ritz procedure
            E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest)

            # Update E, X, P
            self.X[:, nc:] = mm(S_, Z[:, : n - nc])
            self.E[nc:] = E_[: n - nc]
            P = mm(S_, mm(Z[:, n - nc :], _utils.basis(Z[: n - nc, n - nc :].mT)))
            np = P.shape[-1]

            # check convergence
            self.update_residual()
            nc = self.update_converged_count()

            # update S
            self.S[:, :n] = self.X
            self.S[:, n : n + np] = P
            W = self._get_ortho(self.R[:, nc:], self.S[:, : n + np])
            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W

    def _get_rayleigh_ritz_transform(self, S):
        """Return a transformation matrix that is used in Rayleigh-Ritz
        procedure for reducing a general eigenvalue problem :math:`(S^TAS)
        C = (S^TBS) C E` to a standard eigenvalue problem :math: `(Ri^T
        S^TAS Ri) Z = Z E` where `C = Ri Z`.

        .. note:: Instead of forming the diagonal scaling matrix `D =
                  diag(SBS) ** -1/2` explicitly, element-wise products
                  with the vector `d` (and the rank-1 matrix `d d^T`)
                  are used, which avoids extra matrix products; see
                  [DuerschEtal2018] for the original formulation.

        Args:
          S (Tensor): the matrix basis for the search subspace, size is
                      :math:`(m, n)`.

        Returns:
          Ri (tensor): upper-triangular transformation matrix of size
                       :math:`(n, n)`.

        """
        B = self.B
        mm = torch.matmul
        SBS = _utils.qform(B, S)
        d_row = SBS.diagonal(0, -2, -1) ** -0.5
        d_col = d_row.reshape(d_row.shape[0], 1)
        # TODO use torch.linalg.cholesky_solve once it is implemented
        R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True)
        return torch.linalg.solve_triangular(
            R, d_row.diag_embed(), upper=True, left=False
        )

    def _get_svqb(self, U: Tensor, drop: bool, tau: float) -> Tensor:
        """Return B-orthonormal U.

        .. note:: When `drop` is `False` then `svqb` is based on the
                  Algorithm 4 from [DuerschPhD2015] that is a slight
                  modification of the corresponding algorithm
                  introduced in [StathopolousWu2002].

        Args:

          U (Tensor) : initial approximation, size is (m, n)
          drop (bool) : when True, drop columns that
                     contribution to the `span([U])` is small.
          tau (float) : positive tolerance

        Returns:

          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size
                       is (m, n1), where `n1 = n` if `drop` is `False,
                       otherwise `n1 <= n`.

        """
        if torch.numel(U) == 0:
            return U
        UBU = _utils.qform(self.B, U)
        d = UBU.diagonal(0, -2, -1)

        # Detect and drop exact zero columns from U. While the test
        # `abs(d) == 0` is unlikely to be True for random data, it is
        # possible to construct input data to lobpcg where it will be
        # True leading to a failure (notice the `d ** -0.5` operation
        # in the original algorithm). To prevent the failure, we drop
        # the exact zero columns here and then continue with the
        # original algorithm below.
        nz = torch.where(abs(d) != 0.0)
        assert len(nz) == 1, nz
        if len(nz[0]) < len(d):
            U = U[:, nz[0]]
            if torch.numel(U) == 0:
                return U
            UBU = _utils.qform(self.B, U)
            d = UBU.diagonal(0, -2, -1)
            nz = torch.where(abs(d) != 0.0)
            assert len(nz[0]) == len(d)

        # The original algorithm 4 from [DuerschPhD2015].
        d_col = (d**-0.5).reshape(d.shape[0], 1)
        DUBUD = (UBU * d_col) * d_col.mT
        E, Z = _utils.symeig(DUBUD)
        t = tau * abs(E).max()
        if drop:
            keep = torch.where(E > t)
            assert len(keep) == 1, keep
            E = E[keep[0]]
            Z = Z[:, keep[0]]
            d_col = d_col[keep[0]]
        else:
            E[(torch.where(E < t))[0]] = t

        return torch.matmul(U * d_col.mT, Z * E**-0.5)

    def _get_ortho(self, U, V):
        """Return B-orthonormal U with columns are B-orthogonal to V.

        .. note:: When `bparams["ortho_use_drop"] == False` then
                  `_get_ortho` is based on the Algorithm 3 from
                  [DuerschPhD2015] that is a slight modification of
                  the corresponding algorithm introduced in
                  [StathopolousWu2002]. Otherwise, the method
                  implements Algorithm 6 from [DuerschPhD2015]

        .. note:: If all U columns are B-collinear to V then the
                  returned tensor U will be empty.

        Args:

          U (Tensor) : initial approximation, size is (m, n)
          V (Tensor) : B-orthogonal external basis, size is (m, k)

        Returns:

          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`)
                       such that :math:`V^T B U=0`, size is (m, n1),
                       where `n1 = n` if `drop` is `False, otherwise
                       `n1 <= n`.
        """
        mm = torch.matmul
        mm_B = _utils.matmul
        m = self.iparams["m"]
        tau_ortho = self.fparams["ortho_tol"]
        tau_drop = self.fparams["ortho_tol_drop"]
        tau_replace = self.fparams["ortho_tol_replace"]
        i_max = self.iparams["ortho_i_max"]
        j_max = self.iparams["ortho_j_max"]
        # when use_drop==True, enable dropping U columns that have
        # small contribution to the `span([U, V])`.
        use_drop = self.bparams["ortho_use_drop"]

        # clean up variables from the previous call
        for vkey in list(self.fvars.keys()):
            if vkey.startswith("ortho_") and vkey.endswith("_rerr"):
                self.fvars.pop(vkey)
        self.ivars.pop("ortho_i", 0)
        self.ivars.pop("ortho_j", 0)

        BV_norm = torch.norm(mm_B(self.B, V))
        BU = mm_B(self.B, U)
        VBU = mm(V.mT, BU)
        i = j = 0
        stats = ""
        for i in range(i_max):
            U = U - mm(V, VBU)
            drop = False
            tau_svqb = tau_drop
            for j in range(j_max):
                if use_drop:
                    U = self._get_svqb(U, drop, tau_svqb)
                    drop = True
                    tau_svqb = tau_replace
                else:
                    U = self._get_svqb(U, False, tau_replace)
                if torch.numel(U) == 0:
                    # all initial U columns are B-collinear to V
                    self.ivars["ortho_i"] = i
                    self.ivars["ortho_j"] = j
                    return U
                BU = mm_B(self.B, U)
                UBU = mm(U.mT, BU)
                U_norm = torch.norm(U)
                BU_norm = torch.norm(BU)
                R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype)
                R_norm = torch.norm(R)
                # https://github.com/pytorch/pytorch/issues/33810 workaround:
                rerr = float(R_norm) * float(BU_norm * U_norm) ** -1
                vkey = f"ortho_UBUmI_rerr[{i}, {j}]"
                self.fvars[vkey] = rerr
                if rerr < tau_ortho:
                    break
            VBU = mm(V.mT, BU)
            VBU_norm = torch.norm(VBU)
            U_norm = torch.norm(U)
            rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1
            vkey = f"ortho_VBU_rerr[{i}]"
            self.fvars[vkey] = rerr
            if rerr < tau_ortho:
                break
            if m < U.shape[-1] + V.shape[-1]:
                # TorchScript needs the class var to be assigned to a local to
                # do optional type refinement
                B = self.B
                assert B is not None
                raise ValueError(
                    "Overdetermined shape of U:"
                    f" #B-cols(={B.shape[-1]}) >= #U-cols(={U.shape[-1]}) + #V-cols(={V.shape[-1]}) must hold"
                )
        self.ivars["ortho_i"] = i
        self.ivars["ortho_j"] = j
        return U
|
| 1149 |
+
|
| 1150 |
+
|
| 1151 |
+
# Calling tracker is separated from LOBPCG definitions because
# TorchScript does not support user-defined callback arguments:
# keep a reference to the original (no-op) tracker hook so that
# `_lobpcg` can restore it after temporarily patching the class.
LOBPCG_call_tracker_orig = LOBPCG.call_tracker
|
| 1154 |
+
|
| 1155 |
+
|
| 1156 |
+
def LOBPCG_call_tracker(self):
    # Python-only replacement for LOBPCG.call_tracker: forward the solver
    # instance to the user-supplied tracker callback.
    callback = self.tracker
    callback(self)
|
.venv/lib/python3.11/site-packages/torch/_lowrank.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Implement various linear algebra algorithms for low rank matrices."""
|
| 2 |
+
|
| 3 |
+
__all__ = ["svd_lowrank", "pca_lowrank"]
|
| 4 |
+
|
| 5 |
+
from typing import Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import _linalg_utils as _utils, Tensor
|
| 9 |
+
from torch.overrides import handle_torch_function, has_torch_function
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_approximate_basis(
    A: Tensor,
    q: int,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tensor:
    """Return tensor :math:`Q` with :math:`q` orthonormal columns such
    that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is
    specified, then :math:`Q` is such that :math:`Q Q^H (A - M)`
    approximates :math:`A - M`, without instantiating any tensors
    of the size of :math:`A` or :math:`M`.

    .. note:: The implementation is based on the Algorithm 4.4 from
              Halko et al., 2009.

    .. note:: For an adequate approximation of a k-rank matrix
              :math:`A`, where k is not known in advance but could be
              estimated, the number of :math:`Q` columns, q, can be
              chosen according to the following criteria: in general,
              :math:`k <= q <= min(2*k, m, n)`. For large low-rank
              matrices, take :math:`q = k + 5..10`. If k is
              relatively small compared to :math:`min(m, n)`, choosing
              :math:`q = k + 0..2` may be sufficient.

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator.

    Args:
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int): the dimension of subspace spanned by :math:`Q`
                 columns.

        niter (int, optional): the number of subspace iterations to
                               conduct; ``niter`` must be a
                               nonnegative integer. In most cases, the
                               default value 2 is more than enough.

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, m, n)`.

    References:
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <http://arxiv.org/abs/0909.4061>`_).
    """

    niter = 2 if niter is None else niter
    # Complex inputs keep their dtype; otherwise promote to a floating dtype.
    dtype = _utils.get_floating_dtype(A) if not A.is_complex() else A.dtype
    matmul = _utils.matmul

    # Random test matrix used to sample the range of A (resp. A - M).
    R = torch.randn(A.shape[-1], q, dtype=dtype, device=A.device)

    # The following code could be made faster using torch.geqrf + torch.ormqr
    # but geqrf is not differentiable

    X = matmul(A, R)
    if M is not None:
        X = X - matmul(M, R)
    Q = torch.linalg.qr(X).Q
    # Subspace (power) iterations improve the basis for matrices with
    # slowly decaying spectra; the loop index itself is unused.
    for _ in range(niter):
        X = matmul(A.mH, Q)
        if M is not None:
            X = X - matmul(M.mH, Q)
        Q = torch.linalg.qr(X).Q
        X = matmul(A, Q)
        if M is not None:
            X = X - matmul(M, Q)
        Q = torch.linalg.qr(X).Q
    return Q
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def svd_lowrank(
    A: Tensor,
    q: Optional[int] = 6,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Return the singular value decomposition ``(U, S, V)`` of a matrix,
    batches of matrices, or a sparse matrix :math:`A` such that
    :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then
    SVD is computed for the matrix :math:`A - M`.

    .. note:: The implementation is based on the Algorithm 5.1 from
              Halko et al., 2009.

    .. note:: For an adequate approximation of a k-rank matrix
              :math:`A`, where k is not known in advance but could be
              estimated, the number of :math:`Q` columns, q, can be
              chosen according to the following criteria: in general,
              :math:`k <= q <= min(2*k, m, n)`. For large low-rank
              matrices, take :math:`q = k + 5..10`. If k is
              relatively small compared to :math:`min(m, n)`, choosing
              :math:`q = k + 0..2` may be sufficient.

    .. note:: This is a randomized method. To obtain repeatable results,
              set the seed for the pseudorandom number generator.

    .. note:: In general, use the full-rank SVD implementation
              :func:`torch.linalg.svd` for dense matrices due to its 10x
              higher performance characteristics. The low-rank SVD
              will be useful for huge sparse matrices that
              :func:`torch.linalg.svd` cannot handle.

    Args:
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of A.

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, m, n)`, which will be broadcasted
                              to the size of A in this function.

    References:
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <https://arxiv.org/abs/0909.4061>`_).
    """
    if not torch.jit.is_scripting():
        # Dispatch to __torch_function__ overrides (e.g. tensor subclasses)
        # before running the plain implementation.
        tensor_ops = (A, M)
        if not set(map(type, tensor_ops)).issubset(
            (torch.Tensor, type(None))
        ) and has_torch_function(tensor_ops):
            return handle_torch_function(
                svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M
            )
    return _svd_lowrank(A, q=q, niter=niter, M=M)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _svd_lowrank(
|
| 151 |
+
A: Tensor,
|
| 152 |
+
q: Optional[int] = 6,
|
| 153 |
+
niter: Optional[int] = 2,
|
| 154 |
+
M: Optional[Tensor] = None,
|
| 155 |
+
) -> Tuple[Tensor, Tensor, Tensor]:
|
| 156 |
+
# Algorithm 5.1 in Halko et al., 2009
|
| 157 |
+
|
| 158 |
+
q = 6 if q is None else q
|
| 159 |
+
m, n = A.shape[-2:]
|
| 160 |
+
matmul = _utils.matmul
|
| 161 |
+
if M is not None:
|
| 162 |
+
M = M.broadcast_to(A.size())
|
| 163 |
+
|
| 164 |
+
# Assume that A is tall
|
| 165 |
+
if m < n:
|
| 166 |
+
A = A.mH
|
| 167 |
+
if M is not None:
|
| 168 |
+
M = M.mH
|
| 169 |
+
|
| 170 |
+
Q = get_approximate_basis(A, q, niter=niter, M=M)
|
| 171 |
+
B = matmul(Q.mH, A)
|
| 172 |
+
if M is not None:
|
| 173 |
+
B = B - matmul(Q.mH, M)
|
| 174 |
+
U, S, Vh = torch.linalg.svd(B, full_matrices=False)
|
| 175 |
+
V = Vh.mH
|
| 176 |
+
U = Q.matmul(U)
|
| 177 |
+
|
| 178 |
+
if m < n:
|
| 179 |
+
U, V = V, U
|
| 180 |
+
|
| 181 |
+
return U, S, V
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def pca_lowrank(
    A: Tensor,
    q: Optional[int] = None,
    center: bool = True,
    niter: int = 2,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Performs linear Principal Component Analysis (PCA) on a low-rank
    matrix, batches of such matrices, or sparse matrix.

    This function returns a namedtuple ``(U, S, V)`` which is the
    nearly optimal approximation of a singular value decomposition of
    a centered matrix :math:`A` such that :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`

    .. note:: The relation of ``(U, S, V)`` to PCA is as follows:

              - :math:`A` is a data matrix with ``m`` samples and
                ``n`` features

              - the :math:`V` columns represent the principal directions

              - :math:`S ** 2 / (m - 1)` contains the eigenvalues of
                :math:`A^T A / (m - 1)` which is the covariance of
                ``A`` when ``center=True`` is provided.

              - ``matmul(A, V[:, :k])`` projects data to the first k
                principal components

    .. note:: Different from the standard SVD, the size of returned
              matrices depend on the specified rank and q
              values as follows:

              - :math:`U` is m x q matrix

              - :math:`S` is q-vector

              - :math:`V` is n x q matrix

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator.

    Args:

        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of
                           :math:`A`. By default, ``q = min(6, m, n)``.

        center (bool, optional): if True, center the input tensor,
                                 otherwise, assume that the input is
                                 centered.

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2.

    References:

        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <http://arxiv.org/abs/0909.4061>`_).
    """

    if not torch.jit.is_scripting():
        # Give tensor subclasses a chance to override via __torch_function__.
        if type(A) is not torch.Tensor and has_torch_function((A,)):
            return handle_torch_function(
                pca_lowrank, (A,), A, q=q, center=center, niter=niter
            )

    (m, n) = A.shape[-2:]

    if q is None:
        q = min(6, m, n)
    elif not (0 <= q <= min(m, n)):
        raise ValueError(
            f"q(={q}) must be non-negative integer and not greater than min(m, n)={min(m, n)}"
        )
    if niter < 0:
        raise ValueError(f"niter(={niter}) must be non-negative integer")

    dtype = _utils.get_floating_dtype(A)

    if not center:
        return _svd_lowrank(A, q, niter=niter, M=None)

    if _utils.is_sparse(A):
        if len(A.shape) != 2:
            raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor")
        # Column means as a sparse (n, 1) matrix so that the dense mean
        # matrix M = ones(m, 1) @ means^T is represented implicitly.
        col_means = torch.sparse.sum(A, dim=(-2,)) / m
        # reshape the 1-D sparse means into an (n, 1) sparse matrix
        nz_cols = col_means.indices()[0]
        idx = torch.zeros(
            2,
            len(nz_cols),
            dtype=nz_cols.dtype,
            device=nz_cols.device,
        )
        idx[0] = nz_cols
        C_t = torch.sparse_coo_tensor(
            idx, col_means.values(), (n, 1), dtype=dtype, device=A.device
        )

        ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device)
        M = torch.sparse.mm(C_t, ones_m1_t).mT
        return _svd_lowrank(A, q, niter=niter, M=M)

    # Dense path: subtract per-column means directly.
    mean_rows = A.mean(dim=(-2,), keepdim=True)
    return _svd_lowrank(A - mean_rows, q, niter=niter, M=None)
|
.venv/lib/python3.11/site-packages/torch/_meta_registrations.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/torch/_namedtensor_internals.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
This file contains helper functions that implement experimental functionality
|
| 7 |
+
for named tensors in python. All of these are experimental, unstable, and
|
| 8 |
+
subject to change or deletion.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def check_serializing_named_tensor(tensor):
    """Raise if *tensor* carries dimension names: named tensors do not
    support serialization yet."""
    if not tensor.has_names():
        return
    raise RuntimeError(
        "NYI: Named tensors don't support serialization. Please drop "
        "names via `tensor = tensor.rename(None)` before serialization."
    )
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def build_dim_map(tensor):
    """Returns a map of { dim: dim_name } where dim is a name if the dim is named
    and the dim index otherwise."""
    mapping = OrderedDict()
    for idx, name in enumerate(tensor.names):
        key = idx if name is None else name
        mapping[key] = name
    return mapping
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def unzip_namedshape(namedshape):
    """Split a named shape (mapping or iterable of (name, size) pairs) into
    a (names, sizes) pair of tuples."""
    if isinstance(namedshape, OrderedDict):
        namedshape = namedshape.items()
    iterable = hasattr(namedshape, "__iter__") or isinstance(namedshape, tuple)
    if not iterable:
        raise RuntimeError(
            f"Expected namedshape to be OrderedDict or iterable of tuples, got: {type(namedshape)}"
        )
    if len(namedshape) == 0:
        raise RuntimeError("Expected namedshape to non-empty.")
    return zip(*namedshape)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def namer_api_name(inplace):
    """Name of the public renaming API: ``rename_`` in-place, ``rename`` otherwise."""
    return "rename_" if inplace else "rename"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def is_ellipsis(item):
    """True if *item* denotes an ellipsis: the builtin ``...`` or the string '...'."""
    return item in (Ellipsis, "...")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def single_ellipsis_index(names, fn_name):
    """Index of the single ellipsis in *names*, or None if there is none.

    Raises RuntimeError (tagged with *fn_name*) when more than one
    ellipsis is present.
    """
    found = [idx for idx, name in enumerate(names) if is_ellipsis(name)]
    if len(found) >= 2:
        raise RuntimeError(
            f"{fn_name}: More than one Ellipsis ('...') found in names ("
            f"{names}). This function supports up to one Ellipsis."
        )
    return found[0] if found else None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names):
    """Names covered by an ellipsis, given how many names precede and follow it."""
    stop = len(names) - numel_post_glob
    return names[numel_pre_glob:stop]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def replace_ellipsis_by_position(ellipsis_idx, names, tensor_names):
    """Splice the tensor's own names in place of the ellipsis at *ellipsis_idx*."""
    trailing = len(names) - ellipsis_idx - 1
    globbed = expand_single_ellipsis(ellipsis_idx, trailing, tensor_names)
    return names[:ellipsis_idx] + globbed + names[ellipsis_idx + 1 :]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def resolve_ellipsis(names, tensor_names, fn_name):
    """
    Expands ... inside `names` to be equal to a list of names from `tensor_names`.
    """
    idx = single_ellipsis_index(names, fn_name)
    if idx is None:
        return names
    return replace_ellipsis_by_position(idx, names, tensor_names)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def update_names_with_list(tensor, names, inplace):
    """Apply positional dimension names (expanding '...'); ``rename(None)``
    clears all names."""
    # Special case for tensor.rename(None)
    if len(names) == 1 and names[0] is None:
        return tensor._update_names(None, inplace)

    resolved = resolve_ellipsis(names, tensor.names, namer_api_name(inplace))
    return tensor._update_names(resolved, inplace)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def update_names_with_mapping(tensor, rename_map, inplace):
    """Rename a subset of dims according to an {old_name: new_name} mapping."""
    dim_map = build_dim_map(tensor)
    for old_dim, new_dim in rename_map.items():
        if old_dim not in dim_map:
            raise RuntimeError(
                f"{namer_api_name(inplace)}: Tried to rename dim '{old_dim}' to dim "
                f"{new_dim} in Tensor[{tensor.names}] but dim '{old_dim}' does not exist"
            )
        dim_map[old_dim] = new_dim
    return tensor._update_names(tuple(dim_map.values()), inplace)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def update_names(tensor, names, rename_map, inplace):
    """There are two usages:

    tensor.rename(*names) returns a view on tensor with named dims `names`.
    `names` must be of length `tensor.dim()`; otherwise, if '...' is in `names`,
    then it is expanded greedily to be equal to the corresponding names from
    `tensor.names`.

    For example,
    ```
    >>> # xdoctest: +SKIP
    >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
    >>> x.rename('...', 'height', 'width').names
    ('N', 'C', 'height', 'width')

    >>> # xdoctest: +SKIP
    >>> x.rename('batch', '...', 'width').names
    ('batch', 'C', 'H', 'width')

    ```

    tensor.rename(**rename_map) returns a view on tensor that has rename dims
    as specified in the mapping `rename_map`.

    For example,
    ```
    >>> # xdoctest: +SKIP
    >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
    >>> x.rename(W='width', H='height').names
    ('N', 'C', 'height', 'width')

    ```

    Finally, tensor.rename has an in-place version called tensor.rename_.
    """
    has_names = len(names) > 0
    has_rename_pairs = bool(rename_map)
    if has_names and has_rename_pairs:
        raise RuntimeError(
            f"{namer_api_name(inplace)}: This function takes either positional "
            f"args or keyword args, but not both. Use tensor.{namer_api_name(inplace)}(*names) "
            f"to name dims and tensor.{namer_api_name(inplace)}(**rename_map) to rename "
            "dims."
        )

    if has_rename_pairs:
        return update_names_with_mapping(tensor, rename_map, inplace)
    # Covers both tensor.rename(*names) and the zero-argument call, which is
    # valid for a 0-dim tensor.
    return update_names_with_list(tensor, names, inplace)
|
.venv/lib/python3.11/site-packages/torch/_ops.py
ADDED
|
@@ -0,0 +1,1355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import abc
|
| 3 |
+
import contextlib
|
| 4 |
+
import ctypes
|
| 5 |
+
import importlib
|
| 6 |
+
import inspect
|
| 7 |
+
import sys
|
| 8 |
+
import types
|
| 9 |
+
from typing import Any, Callable, Dict, List, Set, Type, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.utils._pytree as pytree
|
| 13 |
+
from torch import _utils_internal
|
| 14 |
+
from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey
|
| 15 |
+
from torch._functorch.pyfunctorch import dispatch_functorch
|
| 16 |
+
from torch.utils._python_dispatch import TorchDispatchMode
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Query `hasattr` only once.
|
| 20 |
+
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@contextlib.contextmanager
|
| 24 |
+
def dl_open_guard():
|
| 25 |
+
"""
|
| 26 |
+
Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
|
| 27 |
+
shared library to load custom operators.
|
| 28 |
+
"""
|
| 29 |
+
if not _SET_GLOBAL_FLAGS:
|
| 30 |
+
yield
|
| 31 |
+
return
|
| 32 |
+
old_flags = sys.getdlopenflags()
|
| 33 |
+
sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
|
| 34 |
+
try:
|
| 35 |
+
yield
|
| 36 |
+
finally:
|
| 37 |
+
sys.setdlopenflags(old_flags)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class OperatorBase:
    """
    Base class for OpOverload (which represents C++ ATen operators) and HigherOrderOperator
    (which represents Python-only operators that are unrepresentable in TorchScript).

    Holds the three Python-side override tables (per-DispatchKey kernels,
    per-TorchDispatchMode/subclass handlers, and functorch transform rules)
    plus the dispatch cache consulted by the Python dispatcher.
    """

    def __init__(self):
        # The dispatch cache precomputes a mapping of dispatch key that the
        # dispatcher wants to dispatch to, to an actual implementation of the
        # dispatch key. Confusingly, the actual implementation could *also* be a
        # dispatch key, but in this case, this refers to the C++ kernel that
        # was registered to some dispatch key. Aliases are permitted in the
        # latter but not the former; for example, you might lookup the
        # entry for AutogradCPU, and this maps you to the Autograd key for
        # the generic autograd kernel that works for all devices. Since this
        # is the Python dispatcher, you can also put an arbitrary Python
        # callable to call instead. This handler gets precisely the
        # args/kwargs that the operator was __call__'ed with.
        # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp
        # for use with OpOverload; cache lookup is done entirely from C++
        # for speed.
        # TODO: The cache is NOT currently used by HigherOrderOperator, but it should!
        self._dispatch_cache: Dict[
            DispatchKey, Union[DispatchKey, Callable[..., Any]]
        ] = {}

        # This table allows you to override the behavior of a particular
        # dispatch key to call a custom Python function, rather than the
        # ordinary C++ configured behavior. This is the raison d'etre of
        # Python dispatcher: to let you program the dispatcher from Python
        # in case you need something unusual, and don't want to clobber
        # the existing registrations using the Python operator registration
        # API.
        self.py_kernels: Dict[DispatchKey, Callable[..., Any]] = {}

        # This table allows you to override the behavior of a particular
        # operator for a particular TorchDispatchMode. In practice,
        # we are using this mostly for ProxyTensorMode. Modes can be
        # thought of as an open world extension of dispatch keys, so it
        # makes sense that you should be able to register them, the same
        # way you can register dispatch keys.
        self.python_key_table: Dict[
            Union[Type[TorchDispatchMode], Type[torch.Tensor]], Callable[..., Any]
        ] = {}

        # This table allows you to override the behavior of functorch
        # transformations. NB: this currently only does something for
        # HigherOrderOperator
        self.functorch_table = {}

    def __call__(self, *args, **kwargs):
        """Invoke the operator; concrete subclasses must implement this."""
        raise NotImplementedError

    def has_kernel_for_dispatch_key(self, k):
        """Return True if a Python kernel is registered for dispatch key ``k``."""
        return k in self.py_kernels

    def has_kernel_for_any_dispatch_key(self, ks):
        """Return True if any *non-alias* registered key is contained in keyset ``ks``."""
        for k in self.py_kernels:
            if not torch._C._dispatch_is_alias_key(k) and ks.has(k):
                return True
        return False

    def py_impl(self, k):
        """Decorator factory registering ``fn`` as the Python implementation for ``k``.

        ``k`` may be a TorchDispatchMode subclass or Tensor subclass (mode/subclass
        table), a functorch TransformType (functorch table), or a DispatchKey
        (py_kernels).  Registering DispatchKey.Python directly is disallowed;
        duplicate DispatchKey registrations raise RuntimeError.
        """
        def inner(fn):
            if inspect.isclass(k) and (
                issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
            ):
                assert k not in self.python_key_table
                # TODO(voz): Should we replace setting DispatchKey.Python entirely with setting mode keys?
                self.python_key_table[k] = fn
                # Invalidate precomputed entries; they may now be stale.
                self._dispatch_cache.clear()
                return fn

            if isinstance(k, torch._C._functorch.TransformType):
                assert k not in self.functorch_table
                self.functorch_table[k] = fn
                return fn

            assert isinstance(k, DispatchKey)
            assert (
                k != DispatchKey.Python
            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."

            if k in self.py_kernels:
                raise RuntimeError(
                    f"Trying to override a python impl for {k} on operator {self.name()}"
                )
            self.py_kernels[k] = fn
            self._dispatch_cache.clear()
            return fn

        return inner

    # Registers an implementation to all **3** variants of functionalization that we have:
    # - DispatchKey.Functionalize
    # - functorch.TransformType.Functionalize
    # - FunctionalTensorMode
    # Example:
    #   @py_functionalize_impl
    #   def functionalize_rule(ctx, inner_f, *args):
    #       args_unwrapped = ctx.unwrap_tensors(args)
    #       with ctx.redispatch_to_next():
    #           out = ctx.functionalize(inner_f)(*args_unwrapped)
    #           return ctx.wrap_tensors(out)
    def py_functionalize_impl(self, fn):
        """Register ``fn`` under all three functionalization entry points (see above)."""
        from torch._subclasses.functional_tensor import (
            CppFunctionalizeAPI as _CppFunctionalizeAPI,
            FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
            PythonFunctionalizeAPI as _PythonFunctionalizeAPI,
        )

        # Construct our three flavors of functionalization,
        # each of which have slightly different wrap/unwrap/redispatch policies
        def functionalize_dk_fn(*args, **kwargs):
            return fn(_CppFunctionalizeAPI(), *args, **kwargs)

        def functionalize_dispatch_mode_fn(mode, *args, **kwargs):
            return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs)

        def functionalize_functorch_fn(interpreter, *args, **kwargs):
            return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)

        self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn)
        self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)(
            functionalize_dispatch_mode_fn
        )
        self.py_impl(torch._C._functorch.TransformType.Functionalize)(
            functionalize_functorch_fn
        )

        return fn

    def name(self):
        """Return the operator's fully-qualified name; implemented by subclasses."""
        raise NotImplementedError
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# Equivalent to computeDispatchTableEntryWithDebug
|
| 177 |
+
def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type]
|
| 178 |
+
# 1. (Direct) operator registration
|
| 179 |
+
if op.has_kernel_for_dispatch_key(k):
|
| 180 |
+
return k
|
| 181 |
+
# 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
|
| 182 |
+
cand = DispatchKey.CompositeExplicitAutogradNonFunctional
|
| 183 |
+
if (
|
| 184 |
+
k == DispatchKey.Undefined or is_included_in_alias(k, cand)
|
| 185 |
+
) and op.has_kernel_for_dispatch_key(cand):
|
| 186 |
+
return cand
|
| 187 |
+
# 2.2 Use CompositeExplicitAutograd kernel if available
|
| 188 |
+
cand = DispatchKey.CompositeExplicitAutograd
|
| 189 |
+
if (
|
| 190 |
+
k == DispatchKey.Undefined or is_included_in_alias(k, cand)
|
| 191 |
+
) and op.has_kernel_for_dispatch_key(cand):
|
| 192 |
+
return cand
|
| 193 |
+
has_backend_kernel = op.has_kernel_for_any_dispatch_key(
|
| 194 |
+
torch._C._dispatch_get_backend_keyset_from_autograd(k)
|
| 195 |
+
) or op.has_kernel_for_dispatch_key(DispatchKey.CompositeExplicitAutograd)
|
| 196 |
+
# 2.3. Use CompositeImplicitAutograd kernel if available
|
| 197 |
+
cand = DispatchKey.CompositeImplicitAutogradNestedTensor
|
| 198 |
+
if (
|
| 199 |
+
(k != DispatchKey.Undefined and is_included_in_alias(k, cand))
|
| 200 |
+
and op.has_kernel_for_dispatch_key(cand)
|
| 201 |
+
and not has_backend_kernel
|
| 202 |
+
):
|
| 203 |
+
return cand
|
| 204 |
+
cand = DispatchKey.CompositeImplicitAutograd
|
| 205 |
+
if (
|
| 206 |
+
k == DispatchKey.Undefined or is_included_in_alias(k, cand)
|
| 207 |
+
) and op.has_kernel_for_dispatch_key(cand):
|
| 208 |
+
if k == DispatchKey.AutogradOther and op.has_kernel_for_any_dispatch_key(
|
| 209 |
+
torch._C._dispatch_autogradother_backends
|
| 210 |
+
):
|
| 211 |
+
raise RuntimeError("ambiguous autogradother kernel")
|
| 212 |
+
elif not has_backend_kernel:
|
| 213 |
+
return cand
|
| 214 |
+
# 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
|
| 215 |
+
cand = DispatchKey.Autograd
|
| 216 |
+
if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
|
| 217 |
+
return cand
|
| 218 |
+
# 2.5 Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available
|
| 219 |
+
cand = DispatchKey.FuncTorchBatchedDecomposition
|
| 220 |
+
if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
|
| 221 |
+
return cand
|
| 222 |
+
# Backend fallback
|
| 223 |
+
if torch._C._dispatch_has_backend_fallback(k):
|
| 224 |
+
# The dispatch key itself will implicitly route to backend fallback.
|
| 225 |
+
# This is probably not great for the pure Python implementation.
|
| 226 |
+
return k
|
| 227 |
+
raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# Global registry of HigherOrderOperator instances, keyed by op name;
# HigherOrderOperator.__init__ inserts into it.
_higher_order_ops: Dict[str, "HigherOrderOperator"] = {}

# Dispatch keys that a HigherOrderOperator treats as fallthrough by default
# (removed from non_fallthrough_keys in HigherOrderOperator.__init__).
_HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [
    DispatchKey.PythonDispatcher,  # type: ignore[attr-defined]
    DispatchKey.PythonTLSSnapshot,  # type: ignore[attr-defined]
    DispatchKey.ADInplaceOrView,
    DispatchKey.BackendSelect,
    DispatchKey.AutocastCPU,  # type: ignore[attr-defined]
    DispatchKey.AutocastCUDA,  # type: ignore[attr-defined]
]
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class HigherOrderOperator(OperatorBase, abc.ABC):
|
| 243 |
+
# The HigherOrderOperator will appear as torch.ops.higher_order.{name}
|
| 244 |
+
#
|
| 245 |
+
# If you're creating a new HigherOrderOperator, please do not change the
|
| 246 |
+
# default. Adding operators to the global torch.ops namespace is a bad
|
| 247 |
+
# practice due to name collisions.
|
| 248 |
+
def __init__(self, name):
|
| 249 |
+
super().__init__()
|
| 250 |
+
if type(self) is HigherOrderOperator:
|
| 251 |
+
raise RuntimeError(
|
| 252 |
+
"Direct instantiation of HigherOrderOperator is not allowed. Please subclass it."
|
| 253 |
+
)
|
| 254 |
+
self._name = name
|
| 255 |
+
|
| 256 |
+
# Make _OPNamespace not scream, this whole name based association needs a good hard look
|
| 257 |
+
self.__name__ = name
|
| 258 |
+
_higher_order_ops[name] = self
|
| 259 |
+
self._ns = "higher_order"
|
| 260 |
+
self.__module__ = "torch.ops.higher_order"
|
| 261 |
+
|
| 262 |
+
self.non_fallthrough_keys = torch._C._dispatch_keyset_full()
|
| 263 |
+
|
| 264 |
+
for dispatch_key in _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS:
|
| 265 |
+
self.fallthrough(dispatch_key)
|
| 266 |
+
|
| 267 |
+
# [NOTE] We have to register pre-dispatch key implementation
|
| 268 |
+
# because sometimes HOP use aot-dispatch tracing to detect certaion
|
| 269 |
+
# mutations. This is problematic when we are functionalizing HOP
|
| 270 |
+
# during pre-dispatch because when the inner tracer starts, it will see
|
| 271 |
+
# that PreDispatch key is still active. In that case, we just redispatch
|
| 272 |
+
# it to next key. This is only safe to do when PreDispatch key stack has no
|
| 273 |
+
# active modes.
|
| 274 |
+
|
| 275 |
+
def py_impl(self, k):
|
| 276 |
+
if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k):
|
| 277 |
+
self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
|
| 278 |
+
return super().py_impl(k)
|
| 279 |
+
|
| 280 |
+
@property
|
| 281 |
+
def namespace(self):
|
| 282 |
+
return self._ns
|
| 283 |
+
|
| 284 |
+
def fallthrough(self, dispatch_key):
|
| 285 |
+
self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key)
|
| 286 |
+
|
| 287 |
+
# Use positional-only argument to avoid naming collide with custom ops arguments
|
| 288 |
+
# that are named "self".
|
| 289 |
+
def dispatch(self, /, dispatch_key, *args, **kwargs):
|
| 290 |
+
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
| 291 |
+
|
| 292 |
+
if dispatch_key in self._dispatch_cache:
|
| 293 |
+
kernel = self._dispatch_cache[dispatch_key]
|
| 294 |
+
assert not isinstance(kernel, DispatchKey)
|
| 295 |
+
return kernel(*args, **kwargs)
|
| 296 |
+
|
| 297 |
+
if dispatch_key == DispatchKey.FuncTorchDynamicLayerFrontMode:
|
| 298 |
+
return dispatch_functorch(self, args, kwargs)
|
| 299 |
+
|
| 300 |
+
if dispatch_key == DispatchKey.Python:
|
| 301 |
+
# Keep the following 1:1 with handle_torch_function_no_python_arg_parser
|
| 302 |
+
# in torch/csrc/utils/python_arg_parser.cpp
|
| 303 |
+
|
| 304 |
+
overloaded_args_list = []
|
| 305 |
+
|
| 306 |
+
def has_python_key(tensor):
|
| 307 |
+
return torch._C._dispatch_keys(tensor).has("Python")
|
| 308 |
+
|
| 309 |
+
def check_overloaded(arg):
|
| 310 |
+
if isinstance(arg, torch.Tensor) and has_python_key(arg):
|
| 311 |
+
overloaded_args_list.append(arg)
|
| 312 |
+
|
| 313 |
+
for arg in (*args, *kwargs.values()):
|
| 314 |
+
check_overloaded(arg)
|
| 315 |
+
if isinstance(arg, (list, tuple)):
|
| 316 |
+
for a in arg:
|
| 317 |
+
check_overloaded(a)
|
| 318 |
+
|
| 319 |
+
overloaded_args = tuple(overloaded_args_list)
|
| 320 |
+
overloaded_types = tuple(type(arg) for arg in overloaded_args)
|
| 321 |
+
|
| 322 |
+
# Step 1: dispatch on any user TorchDispatchModes
|
| 323 |
+
from torch.utils._python_dispatch import _pop_mode_temporarily
|
| 324 |
+
|
| 325 |
+
curr_mode = _get_current_dispatch_mode()
|
| 326 |
+
if curr_mode is not None:
|
| 327 |
+
if type(curr_mode) in self.python_key_table:
|
| 328 |
+
handler = self.python_key_table[type(curr_mode)]
|
| 329 |
+
with _pop_mode_temporarily() as mode:
|
| 330 |
+
# "natural" calling convention: (mode, *args, **kwargs)
|
| 331 |
+
# TODO(rzou): we should support torch_dispatch calling convention too.
|
| 332 |
+
result = handler(mode, *args, **kwargs)
|
| 333 |
+
else:
|
| 334 |
+
raise NotImplementedError(
|
| 335 |
+
f"There was no rule registered for HOP {self._name} and mode {curr_mode}. "
|
| 336 |
+
f"We recommend filing an issue."
|
| 337 |
+
)
|
| 338 |
+
if result is not NotImplemented:
|
| 339 |
+
return result
|
| 340 |
+
|
| 341 |
+
# Step 2: dispatch on any subclasses
|
| 342 |
+
for arg in overloaded_args:
|
| 343 |
+
subclass_type = type(arg)
|
| 344 |
+
if (
|
| 345 |
+
subclass_type.__torch_dispatch__
|
| 346 |
+
== torch._C._disabled_torch_dispatch_impl
|
| 347 |
+
):
|
| 348 |
+
continue
|
| 349 |
+
if subclass_type in self.python_key_table:
|
| 350 |
+
handler = self.python_key_table[subclass_type]
|
| 351 |
+
# "natural" calling convention: (*args, **kwargs)
|
| 352 |
+
# TODO(rzou): we should support torch_dispatch calling convention too.
|
| 353 |
+
result = handler(*args, **kwargs)
|
| 354 |
+
else:
|
| 355 |
+
raise NotImplementedError(
|
| 356 |
+
f"There was no rule registered for HOP {self._name} and subclass {subclass_type}. "
|
| 357 |
+
f"We recommend filing an issue."
|
| 358 |
+
)
|
| 359 |
+
if result is not NotImplemented:
|
| 360 |
+
return result
|
| 361 |
+
|
| 362 |
+
# All handlers returned NotImplemented
|
| 363 |
+
raise TypeError(
|
| 364 |
+
f"Multiple dispatch failed for {self._name}. There was no registered that "
|
| 365 |
+
f"did not return NotImplemented. Use HOP.py_impl to register some. "
|
| 366 |
+
f"Tried mode: {curr_mode}) and subclasses: "
|
| 367 |
+
f"{[type(a) for a in overloaded_args]}"
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
functionality_key = torch._C._to_functionality_key(dispatch_key) # type: ignore[attr-defined]
|
| 371 |
+
if functionality_key == DispatchKey.PreDispatch:
|
| 372 |
+
from torch.utils._python_dispatch import _pop_mode_temporarily
|
| 373 |
+
|
| 374 |
+
# The check for Python in the exclude set is so we properly respect `with no_dispatch()`
|
| 375 |
+
# calls inside of a mode.
|
| 376 |
+
if (
|
| 377 |
+
_len_torch_dispatch_stack_pre_dispatch() > 0
|
| 378 |
+
) and not torch._C._dispatch_tls_is_dispatch_key_excluded(
|
| 379 |
+
DispatchKey.Python
|
| 380 |
+
):
|
| 381 |
+
curr_mode = _get_current_dispatch_mode_pre_dispatch()
|
| 382 |
+
assert (
|
| 383 |
+
curr_mode is not None
|
| 384 |
+
), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode."
|
| 385 |
+
assert (
|
| 386 |
+
type(curr_mode) in self.python_key_table
|
| 387 |
+
), f"Current active mode {curr_mode} not registered"
|
| 388 |
+
handler = self.python_key_table[type(curr_mode)]
|
| 389 |
+
with _pop_mode_temporarily(functionality_key) as mode:
|
| 390 |
+
return handler(mode, *args, **kwargs)
|
| 391 |
+
|
| 392 |
+
final_key = resolve_key(self, dispatch_key)
|
| 393 |
+
|
| 394 |
+
# This can current fail due to backend fallbacks. You just have to
|
| 395 |
+
# register them by hand for HigherOrderOperator.
|
| 396 |
+
if final_key not in self.py_kernels:
|
| 397 |
+
raise NotImplementedError(
|
| 398 |
+
f"could not find kernel for HigherOrderOperator {self._name} "
|
| 399 |
+
f"at dispatch key {final_key} (resolved from {dispatch_key})"
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
# [NOTE] We shouldn't cache PreDispatch kernel here because depending
|
| 403 |
+
# on what modes are active, predispatch behaviour is different.
|
| 404 |
+
# Also we do same thing for normal ops:
|
| 405 |
+
# See Note [Not Caching Per-Dispatch-Key Mode Handlers]
|
| 406 |
+
if dispatch_key != DispatchKey.PreDispatch:
|
| 407 |
+
self._dispatch_cache[dispatch_key] = self.py_kernels[final_key]
|
| 408 |
+
kernel = self.py_kernels[final_key]
|
| 409 |
+
# It's illegal to register DispatchKey to py_kernels, since there's no
|
| 410 |
+
# C++ kernel to call into
|
| 411 |
+
assert not isinstance(kernel, DispatchKey)
|
| 412 |
+
return kernel(*args, **kwargs)
|
| 413 |
+
|
| 414 |
+
@abc.abstractmethod
|
| 415 |
+
def __call__(self, /, *args, **kwargs):
|
| 416 |
+
# Dynamo already traces the body of HigherOrderOp beforehand when it
|
| 417 |
+
# so no need to trace into it.
|
| 418 |
+
from torch._dynamo import disable
|
| 419 |
+
|
| 420 |
+
@disable
|
| 421 |
+
def wrapper():
|
| 422 |
+
flat_args = _to_flat_tuple(args, kwargs)
|
| 423 |
+
if torch.overrides.has_torch_function(flat_args):
|
| 424 |
+
return torch.overrides.handle_torch_function(
|
| 425 |
+
self, flat_args, *args, **kwargs
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
|
| 429 |
+
return self.dispatch(
|
| 430 |
+
dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
return wrapper()
|
| 434 |
+
|
| 435 |
+
def __str__(self):
|
| 436 |
+
return f"{self.name()}"
|
| 437 |
+
|
| 438 |
+
def name(self):
|
| 439 |
+
return self._name
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def _to_flat_tuple(args, kwargs):
|
| 443 |
+
return pytree.arg_tree_leaves(*args, **kwargs)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def _compute_keyset(args, kwargs, non_fallthrough_keys):
|
| 447 |
+
tensors = _get_tensors(args, kwargs)
|
| 448 |
+
return key_extractor(tensors, non_fallthrough_keys)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def _get_tensors(args, kwargs):
|
| 452 |
+
flat_all = _to_flat_tuple(args, kwargs)
|
| 453 |
+
tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
|
| 454 |
+
return tuple(tensor_args)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
|
| 458 |
+
# at ATen/core/dispatch/DispatchKeyExtractor.h
|
| 459 |
+
def key_extractor(tensors, key_mask):
|
| 460 |
+
key_set = torch._C._dispatch_tls_local_include_set()
|
| 461 |
+
for tensor in tensors:
|
| 462 |
+
key_set = key_set | torch._C._dispatch_keys(tensor)
|
| 463 |
+
key_set = key_set - torch._C._dispatch_tls_local_exclude_set()
|
| 464 |
+
key_set = key_set & key_mask
|
| 465 |
+
return key_set
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
# Mode stack for PreDispatchKey
|
| 469 |
+
# it should always have three keys with
|
| 470 |
+
# priority given to FunctionalTensorMode and
|
| 471 |
+
# then ProxyTorchDispatchMode. It means that
|
| 472 |
+
# slot 0 belongs to ProxyTorchDispatchMode and
|
| 473 |
+
# slot 1 belongs to FunctionalTensorMode.
|
| 474 |
+
#
|
| 475 |
+
# SchemaCheckMode is separate from the other 2,
|
| 476 |
+
# and is only valid when the stack is empty.
|
| 477 |
+
# SchemaCheckMode is for testing purposes, and
|
| 478 |
+
# is meant to run in eager mode on concrete inputs,
|
| 479 |
+
# checking for incorrect schemas in regards to
|
| 480 |
+
# aliasing or mutating ops.
|
| 481 |
+
class _ModeStackStateForPreDispatch:
|
| 482 |
+
def __init__(self):
|
| 483 |
+
self.__infra_modes = [None, None]
|
| 484 |
+
self._schema_check_mode = None
|
| 485 |
+
|
| 486 |
+
def set(self, index, mode):
|
| 487 |
+
assert index < len(self.__infra_modes)
|
| 488 |
+
self.__infra_modes[index] = mode
|
| 489 |
+
|
| 490 |
+
def get(self, index):
|
| 491 |
+
assert index < len(self.__infra_modes)
|
| 492 |
+
return self.__infra_modes[index]
|
| 493 |
+
|
| 494 |
+
def count(self):
|
| 495 |
+
return len([i for i in self.__infra_modes if i is not None]) + int(
|
| 496 |
+
self._schema_check_mode is not None
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
# Process-global singleton holding the PreDispatch mode-stack state; access it
# through mode_stack_state_for_pre_dispatch().
_mode_stack_state_for_pre_dispatch = _ModeStackStateForPreDispatch()
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def unset_mode_pre_dispatch(mode_key, schema_check=False):
|
| 504 |
+
current_mode_stack_pre_dispatch = mode_stack_state_for_pre_dispatch()
|
| 505 |
+
assert mode_key is None or mode_key in (
|
| 506 |
+
torch._C._TorchDispatchModeKey.PROXY,
|
| 507 |
+
torch._C._TorchDispatchModeKey.FUNCTIONAL,
|
| 508 |
+
)
|
| 509 |
+
if schema_check:
|
| 510 |
+
assert mode_key is None
|
| 511 |
+
|
| 512 |
+
def _unset_mode():
|
| 513 |
+
if mode_key == torch._C._TorchDispatchModeKey.PROXY:
|
| 514 |
+
current_mode = current_mode_stack_pre_dispatch.get(0)
|
| 515 |
+
mode_stack_state_for_pre_dispatch().set(0, None)
|
| 516 |
+
return current_mode
|
| 517 |
+
elif mode_key == torch._C._TorchDispatchModeKey.FUNCTIONAL:
|
| 518 |
+
current_mode = current_mode_stack_pre_dispatch.get(1)
|
| 519 |
+
mode_stack_state_for_pre_dispatch().set(1, None)
|
| 520 |
+
return current_mode
|
| 521 |
+
else:
|
| 522 |
+
current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode
|
| 523 |
+
mode_stack_state_for_pre_dispatch()._schema_check_mode = None
|
| 524 |
+
return current_mode
|
| 525 |
+
|
| 526 |
+
current_mode = _unset_mode()
|
| 527 |
+
|
| 528 |
+
new_pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()
|
| 529 |
+
# When we are unsetting a mode, we need to check if there is
|
| 530 |
+
# active mode left on the PreDispatch key. If there is nothing
|
| 531 |
+
# active, we need to remove PreDispatch key from local dispatch include
|
| 532 |
+
# set.
|
| 533 |
+
if new_pre_dispatch_len == 0:
|
| 534 |
+
torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, False)
|
| 535 |
+
|
| 536 |
+
return current_mode
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def _set_mode_pre_dispatch(mode):
|
| 540 |
+
from torch._subclasses.functional_tensor import FunctionalTensorMode
|
| 541 |
+
from torch._subclasses.schema_check_mode import SchemaCheckMode
|
| 542 |
+
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
|
| 543 |
+
|
| 544 |
+
assert isinstance(
|
| 545 |
+
mode,
|
| 546 |
+
(
|
| 547 |
+
FunctionalTensorMode,
|
| 548 |
+
ProxyTorchDispatchMode,
|
| 549 |
+
SchemaCheckMode,
|
| 550 |
+
),
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
previous_mode_stack_len = _len_torch_dispatch_stack_pre_dispatch()
|
| 554 |
+
if isinstance(mode, SchemaCheckMode):
|
| 555 |
+
current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode
|
| 556 |
+
if previous_mode_stack_len > 0:
|
| 557 |
+
raise AssertionError(
|
| 558 |
+
"SchemaCheckMode for pre-dispatch must be used exclusively, found other modes on the stack"
|
| 559 |
+
)
|
| 560 |
+
mode_stack_state_for_pre_dispatch()._schema_check_mode = mode
|
| 561 |
+
elif isinstance(mode, FunctionalTensorMode):
|
| 562 |
+
current_mode = mode_stack_state_for_pre_dispatch().get(1)
|
| 563 |
+
assert current_mode is None
|
| 564 |
+
mode_stack_state_for_pre_dispatch().set(1, mode)
|
| 565 |
+
else:
|
| 566 |
+
current_mode = mode_stack_state_for_pre_dispatch().get(0)
|
| 567 |
+
assert current_mode is None
|
| 568 |
+
mode_stack_state_for_pre_dispatch().set(0, mode)
|
| 569 |
+
|
| 570 |
+
# When we are setting a mode, we need to check if there is
|
| 571 |
+
# active mode left on the PreDispatch key. If there was nothing
|
| 572 |
+
# active before setting this mode, it means that PreDispatch key
|
| 573 |
+
# was turned off. So we need to turn it on again.
|
| 574 |
+
if previous_mode_stack_len == 0:
|
| 575 |
+
torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, True)
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
def _pop_mode_from_pre_dispatch():
|
| 579 |
+
mode_stack = mode_stack_state_for_pre_dispatch()
|
| 580 |
+
pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()
|
| 581 |
+
|
| 582 |
+
if pre_dispatch_len == 0:
|
| 583 |
+
raise AssertionError("Trying to pop empty mode stack")
|
| 584 |
+
|
| 585 |
+
if mode_stack._schema_check_mode is not None:
|
| 586 |
+
return unset_mode_pre_dispatch(None, schema_check=True)
|
| 587 |
+
if mode_stack.get(1) is not None:
|
| 588 |
+
return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.FUNCTIONAL)
|
| 589 |
+
if mode_stack.get(0) is not None:
|
| 590 |
+
return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def _len_torch_dispatch_stack_pre_dispatch():
|
| 594 |
+
return mode_stack_state_for_pre_dispatch().count()
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
def _get_dispatch_mode_pre_dispatch(mode_key):
|
| 598 |
+
assert mode_key in (
|
| 599 |
+
torch._C._TorchDispatchModeKey.PROXY,
|
| 600 |
+
torch._C._TorchDispatchModeKey.FUNCTIONAL,
|
| 601 |
+
)
|
| 602 |
+
if mode_key == torch._C._TorchDispatchModeKey.PROXY:
|
| 603 |
+
return mode_stack_state_for_pre_dispatch().get(0)
|
| 604 |
+
else:
|
| 605 |
+
return mode_stack_state_for_pre_dispatch().get(1)
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def _get_current_dispatch_mode_pre_dispatch():
|
| 609 |
+
if mode_stack_state_for_pre_dispatch()._schema_check_mode is not None:
|
| 610 |
+
return mode_stack_state_for_pre_dispatch()._schema_check_mode
|
| 611 |
+
else:
|
| 612 |
+
stack_len = mode_stack_state_for_pre_dispatch().count()
|
| 613 |
+
if stack_len == 2:
|
| 614 |
+
return mode_stack_state_for_pre_dispatch().get(1)
|
| 615 |
+
if stack_len == 1:
|
| 616 |
+
return (
|
| 617 |
+
mode_stack_state_for_pre_dispatch().get(1)
|
| 618 |
+
if mode_stack_state_for_pre_dispatch().get(1) is not None
|
| 619 |
+
else mode_stack_state_for_pre_dispatch().get(0)
|
| 620 |
+
)
|
| 621 |
+
return None
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def mode_stack_state_for_pre_dispatch():
|
| 625 |
+
global _mode_stack_state_for_pre_dispatch
|
| 626 |
+
return _mode_stack_state_for_pre_dispatch
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
# Registry of OpOverload objects recorded via add_cached_op(); read back with
# get_cached_ops() and emptied by reset_cached_ops().
cached_ops: Set["OpOverload"] = set()
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
def add_cached_op(op_overload):
|
| 633 |
+
global cached_ops
|
| 634 |
+
cached_ops.add(op_overload)
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
def reset_cached_ops():
|
| 638 |
+
global cached_ops
|
| 639 |
+
cached_ops.clear()
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
def get_cached_ops():
|
| 643 |
+
global cached_ops
|
| 644 |
+
return cached_ops
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
# Each OpOverload object contains pointer to a a specific operator overload, a pointer to the parent `OpOverloadPacket` object.
|
| 648 |
+
# You can obtain an OpOverload object through attribute query on OpOverloadPacket.
|
| 649 |
+
class OpOverload(OperatorBase):
|
| 650 |
+
def __init__(self, overloadpacket, op, op_dk, schema, tags):
|
| 651 |
+
super().__init__()
|
| 652 |
+
self._op = op
|
| 653 |
+
self._op_dk = op_dk
|
| 654 |
+
self._schema = schema
|
| 655 |
+
self._overloadpacket = overloadpacket
|
| 656 |
+
self._tags = tags
|
| 657 |
+
self._overloadname = (
|
| 658 |
+
"default" if schema.overload_name == "" else schema.overload_name
|
| 659 |
+
)
|
| 660 |
+
self._name = self._schema.name
|
| 661 |
+
if schema.overload_name:
|
| 662 |
+
self._name += "." + schema.overload_name
|
| 663 |
+
self.__name__ = f"{self._schema.name.split('::')[1]}.{self._overloadname}"
|
| 664 |
+
self.__module__ = overloadpacket.__module__
|
| 665 |
+
op.__module__ = overloadpacket.__module__
|
| 666 |
+
self.__qualname__ = self._name
|
| 667 |
+
self.__annotations__ = {}
|
| 668 |
+
# Only compute the OperatorHandle when we need it. Not all OpOverloads have
|
| 669 |
+
# OperatorHandles (the TorchScript ones don't...)
|
| 670 |
+
self._lazy_handle = None
|
| 671 |
+
|
| 672 |
+
# If the OpOverload was constructed from a Library.def in Python.
|
| 673 |
+
self._defined_in_python = self.__qualname__ in torch.library._defs
|
| 674 |
+
|
| 675 |
+
# Logic replicated from aten/src/ATen/native/MathBitsFallback.h
|
| 676 |
+
is_write = None
|
| 677 |
+
for a in self._schema.arguments:
|
| 678 |
+
if a.alias_info is None:
|
| 679 |
+
continue
|
| 680 |
+
if is_write is None:
|
| 681 |
+
is_write = a.alias_info.is_write
|
| 682 |
+
else:
|
| 683 |
+
# We will conservatively call mixed mutable/non-mutable
|
| 684 |
+
# aliased inputs as NOT a view
|
| 685 |
+
is_write = a.alias_info.is_write or is_write
|
| 686 |
+
self.is_view = is_write is not None and not is_write
|
| 687 |
+
|
| 688 |
+
@property
|
| 689 |
+
def _namespace(self):
|
| 690 |
+
return self._schema.name.split("::")[0]
|
| 691 |
+
|
| 692 |
+
@property
|
| 693 |
+
def _opname(self):
|
| 694 |
+
return self._schema.name.split("::")[1]
|
| 695 |
+
|
| 696 |
+
@property
|
| 697 |
+
def _handle(self):
|
| 698 |
+
if self._lazy_handle is None:
|
| 699 |
+
self._lazy_handle = torch._C._dispatch_find_schema_or_throw(
|
| 700 |
+
self._schema.name, self._schema.overload_name
|
| 701 |
+
)
|
| 702 |
+
return self._lazy_handle
|
| 703 |
+
|
| 704 |
+
# it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
|
| 705 |
+
def __deepcopy__(self, memo=None):
|
| 706 |
+
return self
|
| 707 |
+
|
| 708 |
+
def __repr__(self):
|
| 709 |
+
return "<OpOverload(op='{}.{}', overload='{}')>".format(
|
| 710 |
+
*self._schema.name.split("::"), self._overloadname
|
| 711 |
+
)
|
| 712 |
+
|
| 713 |
+
# Use positional-only argument to avoid naming collision with aten ops arguments
|
| 714 |
+
# that are named "self". This way, all the aten ops can be called by kwargs.
|
| 715 |
+
def __call__(self, /, *args, **kwargs):
|
| 716 |
+
return self._op(*args, **kwargs)
|
| 717 |
+
|
| 718 |
+
# Use positional-only argument to avoid naming collision with aten ops arguments
|
| 719 |
+
# that are named "self". This way, all the aten ops can be called by kwargs.
|
| 720 |
+
def redispatch(self, /, keyset, *args, **kwargs):
|
| 721 |
+
return self._handle.redispatch_boxed(keyset, *args, **kwargs)
|
| 722 |
+
|
| 723 |
+
def __hash__(self):
|
| 724 |
+
return hash(self._op)
|
| 725 |
+
|
| 726 |
+
# `my_namespace.my_op_name.overload_name`
|
| 727 |
+
def __str__(self):
|
| 728 |
+
return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)
|
| 729 |
+
|
| 730 |
+
def has_kernel_for_dispatch_key(self, k):
|
| 731 |
+
return super().has_kernel_for_dispatch_key(
|
| 732 |
+
k
|
| 733 |
+
) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k)
|
| 734 |
+
|
| 735 |
+
def has_kernel_for_any_dispatch_key(self, ks):
|
| 736 |
+
return torch._C._dispatch_has_kernel_for_any_dispatch_key(
|
| 737 |
+
self.name(), ks
|
| 738 |
+
) or super().has_kernel_for_any_dispatch_key(ks)
|
| 739 |
+
|
| 740 |
+
@property
|
| 741 |
+
def namespace(self):
|
| 742 |
+
return self._schema.name.split("::")[0]
|
| 743 |
+
|
| 744 |
+
def _can_decompose(self):
|
| 745 |
+
dk = DispatchKey.CompositeImplicitAutograd
|
| 746 |
+
return dk in self.py_kernels or torch._C._dispatch_has_kernel_for_dispatch_key(
|
| 747 |
+
self.name(), dk
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
def decompose(self, *args, **kwargs):
|
| 751 |
+
dk = DispatchKey.CompositeImplicitAutograd
|
| 752 |
+
if dk in self.py_kernels:
|
| 753 |
+
# NB: This branch is not too necessary anymore, because we can
|
| 754 |
+
# apply Python CompositeImplicitAutograd *before* tracing
|
| 755 |
+
# using Python dispatcher (also taking advantage of the autograd
|
| 756 |
+
# formula). But it's included for completeness
|
| 757 |
+
return self.py_kernels[dk](*args, **kwargs)
|
| 758 |
+
elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
|
| 759 |
+
return self._op_dk(dk, *args, **kwargs)
|
| 760 |
+
else:
|
| 761 |
+
return NotImplemented
|
| 762 |
+
|
| 763 |
+
# Remove a dispatch key from the dispatch cache. This will force it to get
|
| 764 |
+
# recomputed the next time. Does nothing
|
| 765 |
+
# WARNING: if you register a dispatch key to py_kernels of an OpOverload,
|
| 766 |
+
# calling _del_dispatch on that key is NOT sufficient to apply your change,
|
| 767 |
+
# because a single registration may affect MULTIPLE dispatch keys (e.g.,
|
| 768 |
+
# registering Autograd affects AutogradCPU). del_dispatch is to be used
|
| 769 |
+
# only if you are specifically modifying how get_dispatch handles a
|
| 770 |
+
# particular input 'key'.
|
| 771 |
+
def _uncache_dispatch(self, key):
|
| 772 |
+
self._dispatch_cache.pop(key, None)
|
| 773 |
+
|
| 774 |
+
# This implements the pre-computation logic for the Python dispatcher.
|
| 775 |
+
def _get_dispatch(self, key):
|
| 776 |
+
# This is only called upon a cache miss
|
| 777 |
+
assert key not in self._dispatch_cache, f"{self} {key}"
|
| 778 |
+
|
| 779 |
+
if key == DispatchKey.Python:
|
| 780 |
+
if not isinstance(self, TorchBindOpOverload) and not self.python_key_table:
|
| 781 |
+
self._dispatch_cache[key] = key
|
| 782 |
+
add_cached_op(self)
|
| 783 |
+
return key
|
| 784 |
+
|
| 785 |
+
def handler(*args, **kwargs):
|
| 786 |
+
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
| 787 |
+
|
| 788 |
+
# TODO: We also need to handle tensor subclasses here
|
| 789 |
+
# TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
|
| 790 |
+
curr_mode = type(_get_current_dispatch_mode())
|
| 791 |
+
assert (
|
| 792 |
+
curr_mode is not None
|
| 793 |
+
), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
|
| 794 |
+
|
| 795 |
+
if curr_mode not in self.python_key_table:
|
| 796 |
+
if isinstance(self, TorchBindOpOverload):
|
| 797 |
+
with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
|
| 798 |
+
return torch._library.utils.handle_dispatch_mode(
|
| 799 |
+
mode, self, *args, **kwargs
|
| 800 |
+
)
|
| 801 |
+
else:
|
| 802 |
+
return self._op_dk(key, *args, **kwargs)
|
| 803 |
+
|
| 804 |
+
with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
|
| 805 |
+
return self.python_key_table[curr_mode](mode, *args, **kwargs)
|
| 806 |
+
|
| 807 |
+
self._dispatch_cache[key] = handler
|
| 808 |
+
add_cached_op(self)
|
| 809 |
+
return handler
|
| 810 |
+
|
| 811 |
+
functionality_key = torch._C._to_functionality_key(key) # type: ignore[attr-defined]
|
| 812 |
+
if functionality_key == DispatchKey.PreDispatch:
|
| 813 |
+
curr_stack_len = _len_torch_dispatch_stack_pre_dispatch()
|
| 814 |
+
# The check for Python in the exclude set is so we properly respect `with no_dispatch()`
|
| 815 |
+
# calls inside of a mode.
|
| 816 |
+
if (
|
| 817 |
+
curr_stack_len > 0
|
| 818 |
+
and not torch._C._dispatch_tls_is_dispatch_key_excluded(
|
| 819 |
+
DispatchKey.Python
|
| 820 |
+
)
|
| 821 |
+
):
|
| 822 |
+
|
| 823 |
+
def handler(*args, **kwargs):
|
| 824 |
+
@contextlib.contextmanager
|
| 825 |
+
def _temporarily_pop_modes_from_pre_dispatch():
|
| 826 |
+
top_mode = _pop_mode_from_pre_dispatch()
|
| 827 |
+
try:
|
| 828 |
+
yield top_mode
|
| 829 |
+
finally:
|
| 830 |
+
_set_mode_pre_dispatch(top_mode)
|
| 831 |
+
|
| 832 |
+
with _temporarily_pop_modes_from_pre_dispatch() as curr_mode:
|
| 833 |
+
return torch._library.utils.handle_dispatch_mode(
|
| 834 |
+
curr_mode, self, *args, **kwargs
|
| 835 |
+
)
|
| 836 |
+
|
| 837 |
+
# Note [Not Caching Per-Dispatch-Key Mode Handlers]
|
| 838 |
+
# Note that we're not caching this handler. There isn't really a point, since the slow bit
|
| 839 |
+
# is the handler itself (in python).
|
| 840 |
+
# Also, not caching means that we don't have to reset the cache when any existing
|
| 841 |
+
# modes go out of scope (which in of itself takes time to loop through all operators).
|
| 842 |
+
return handler
|
| 843 |
+
|
| 844 |
+
final_key = resolve_key(self, key)
|
| 845 |
+
|
| 846 |
+
# See Note [Not Caching Per-Dispatch-Key Mode Handlers]
|
| 847 |
+
cache_result = key != DispatchKey.PreDispatch
|
| 848 |
+
|
| 849 |
+
# TODO: We could potentially have lots of debugging wrappers against
|
| 850 |
+
# dispatch keys; design some general registration mechanism instead of
|
| 851 |
+
# having if statement for each of them
|
| 852 |
+
if key == DispatchKey.Functionalize:
|
| 853 |
+
import torch._dispatch.python as pydispatch
|
| 854 |
+
|
| 855 |
+
if pydispatch.CROSSREF_FUNCTIONALIZE:
|
| 856 |
+
handler = pydispatch.make_crossref_functionalize(self, final_key)
|
| 857 |
+
if cache_result:
|
| 858 |
+
self._dispatch_cache[key] = handler
|
| 859 |
+
add_cached_op(self)
|
| 860 |
+
return handler
|
| 861 |
+
|
| 862 |
+
r = self.py_kernels.get(final_key, final_key)
|
| 863 |
+
if cache_result:
|
| 864 |
+
self._dispatch_cache[key] = r
|
| 865 |
+
add_cached_op(self)
|
| 866 |
+
return r
|
| 867 |
+
|
| 868 |
+
def name(self):
|
| 869 |
+
return self._name
|
| 870 |
+
|
| 871 |
+
@property
|
| 872 |
+
def overloadpacket(self):
|
| 873 |
+
return self._overloadpacket
|
| 874 |
+
|
| 875 |
+
@property
|
| 876 |
+
def op(self):
|
| 877 |
+
return self._op
|
| 878 |
+
|
| 879 |
+
@property
|
| 880 |
+
def tags(self):
|
| 881 |
+
return self._tags
|
| 882 |
+
|
| 883 |
+
# TODO: add more methods to expose information about input and output arguments
|
| 884 |
+
|
| 885 |
+
|
| 886 |
+
# TorchBindOpOverload are those custom ops which have at least one overload's
|
| 887 |
+
# schema consists of torch.ScriptObject (i.e. custom class) input.
|
| 888 |
+
# TorchBindOpOverload will skip C++ dispatcher and purely dispatched in python
|
| 889 |
+
# when its inputs contain FakeScriptObject in a similar way as higher order ops.
|
| 890 |
+
class TorchBindOpOverload(OpOverload):
|
| 891 |
+
def _fallthrough_keys(self) -> List[DispatchKey]:
|
| 892 |
+
# TODO: we should be calling the fallback for these, but a fallthrough is almost close
|
| 893 |
+
# enough to the fallback in most cases that we care about.
|
| 894 |
+
_DEFAULT_FALLTHROUGH_KEYS = [
|
| 895 |
+
DispatchKey.Autograd,
|
| 896 |
+
DispatchKey.AutogradCPU,
|
| 897 |
+
DispatchKey.AutogradCUDA,
|
| 898 |
+
DispatchKey.ADInplaceOrView,
|
| 899 |
+
DispatchKey.BackendSelect,
|
| 900 |
+
DispatchKey.PythonTLSSnapshot,
|
| 901 |
+
DispatchKey.PythonDispatcher,
|
| 902 |
+
]
|
| 903 |
+
|
| 904 |
+
def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
|
| 905 |
+
if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key):
|
| 906 |
+
return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
|
| 907 |
+
self.name(), key
|
| 908 |
+
)
|
| 909 |
+
|
| 910 |
+
return (
|
| 911 |
+
key not in self.py_kernels
|
| 912 |
+
or self.py_kernels[key] is torch.library.fallthrough_kernel
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
return [
|
| 916 |
+
key
|
| 917 |
+
for key in _DEFAULT_FALLTHROUGH_KEYS
|
| 918 |
+
if _may_use_fallthrough_instead_of_fallback(key)
|
| 919 |
+
]
|
| 920 |
+
|
| 921 |
+
@contextlib.contextmanager
|
| 922 |
+
def _register_as_effectful_op_temporarily(self):
|
| 923 |
+
from torch._higher_order_ops.effects import (
|
| 924 |
+
_EffectType,
|
| 925 |
+
_register_effectful_op,
|
| 926 |
+
SIDE_EFFECTS,
|
| 927 |
+
)
|
| 928 |
+
|
| 929 |
+
try:
|
| 930 |
+
if self not in SIDE_EFFECTS:
|
| 931 |
+
_register_effectful_op(self, _EffectType.ORDERED)
|
| 932 |
+
yield
|
| 933 |
+
finally:
|
| 934 |
+
if self in SIDE_EFFECTS:
|
| 935 |
+
del SIDE_EFFECTS[self]
|
| 936 |
+
|
| 937 |
+
# Use positional-only argument to avoid naming collision with aten ops arguments
|
| 938 |
+
# that are named "self". This way, all the aten ops can be called by kwargs.
|
| 939 |
+
def __call__(self, /, *args, **kwargs):
|
| 940 |
+
if _must_dispatch_in_python(args, kwargs):
|
| 941 |
+
# When any inputs are FakeScriptObject, we need to
|
| 942 |
+
# skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher
|
| 943 |
+
# because C++ dispatcher will check the schema and cannot recognize FakeScriptObject.
|
| 944 |
+
#
|
| 945 |
+
# Note:
|
| 946 |
+
# 1. We only register the torchbind op temporarily as effectful op because we only want
|
| 947 |
+
# the effect token functionalization logic to be applied during tracing. Otherwise, the behavior
|
| 948 |
+
# of the eagerly executing the op might change after tracing.
|
| 949 |
+
# 2. We don't want to register the op as effectful for all torchbind ops in ctor because this might
|
| 950 |
+
# cause unexpected behavior for some autograd.profiler ops e.g. profiler._record_function_exit._RecordFunction.
|
| 951 |
+
with self._register_as_effectful_op_temporarily():
|
| 952 |
+
return self._dispatch_in_python(args, kwargs, self._fallthrough_keys())
|
| 953 |
+
return self._op(*args, **kwargs)
|
| 954 |
+
|
| 955 |
+
def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
|
| 956 |
+
non_fallthrough_keys = torch._C._dispatch_keyset_full()
|
| 957 |
+
for key in fallthrough_keys:
|
| 958 |
+
non_fallthrough_keys = non_fallthrough_keys.remove(key)
|
| 959 |
+
|
| 960 |
+
dispatch_key_set = _compute_keyset(args, kwargs, non_fallthrough_keys)
|
| 961 |
+
dispatch_key = dispatch_key_set.highestPriorityTypeId()
|
| 962 |
+
|
| 963 |
+
handler = (
|
| 964 |
+
self._get_dispatch(dispatch_key)
|
| 965 |
+
if dispatch_key not in self._dispatch_cache
|
| 966 |
+
else self._dispatch_cache[dispatch_key]
|
| 967 |
+
)
|
| 968 |
+
|
| 969 |
+
if isinstance(handler, DispatchKey):
|
| 970 |
+
# fallthrough keys can be registered at runtime via torch.library.impl
|
| 971 |
+
# so need to add it to fallthrough_keys and re-dispatch.
|
| 972 |
+
if torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
|
| 973 |
+
self.name(), dispatch_key
|
| 974 |
+
):
|
| 975 |
+
return self._dispatch_in_python(
|
| 976 |
+
args, kwargs, fallthrough_keys + [dispatch_key]
|
| 977 |
+
)
|
| 978 |
+
|
| 979 |
+
raise RuntimeError(
|
| 980 |
+
f"Torchbind op {self} received a FakeScriptObject input when dispatching {handler}."
|
| 981 |
+
f" but no python implementation is found."
|
| 982 |
+
f" Please file an issue on this when you encounter this error."
|
| 983 |
+
f" This error can happen when you export or compile the model."
|
| 984 |
+
f" It can still happpen even if a C++ implementation for {dispatch_key}. "
|
| 985 |
+
f" has been registered. That's because FakeScriptObject purely lives in python and cannot work "
|
| 986 |
+
f" with a C++ implementation."
|
| 987 |
+
)
|
| 988 |
+
|
| 989 |
+
assert isinstance(handler, Callable) # type: ignore[arg-type]
|
| 990 |
+
return handler(*args, **kwargs)
|
| 991 |
+
|
| 992 |
+
|
| 993 |
+
def _must_dispatch_in_python(args, kwargs):
|
| 994 |
+
return pytree.tree_any(
|
| 995 |
+
lambda obj: isinstance(
|
| 996 |
+
obj, torch._library.fake_class_registry.FakeScriptObject
|
| 997 |
+
),
|
| 998 |
+
(args, kwargs),
|
| 999 |
+
)
|
| 1000 |
+
|
| 1001 |
+
|
| 1002 |
+
def _has_script_object_arg(schema: torch.FunctionSchema) -> bool:
|
| 1003 |
+
return any(isinstance(arg.type, torch.ClassType) for arg in schema.arguments)
|
| 1004 |
+
|
| 1005 |
+
|
| 1006 |
+
# OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
|
| 1007 |
+
# You can obtain an OpOverload object through attribute query.
|
| 1008 |
+
class OpOverloadPacket:
|
| 1009 |
+
def __init__(self, qualified_op_name, op_name, op, overload_names):
|
| 1010 |
+
# These attributes are accessible on the object through the properties
|
| 1011 |
+
# defined below but are immutable
|
| 1012 |
+
self._qualified_op_name = qualified_op_name
|
| 1013 |
+
self.__name__ = op_name
|
| 1014 |
+
self._op = op
|
| 1015 |
+
self._overload_names = overload_names
|
| 1016 |
+
self._dir = []
|
| 1017 |
+
self._has_torchbind_op_overload = any(
|
| 1018 |
+
_has_script_object_arg(schema) for schema in self._schemas.values()
|
| 1019 |
+
)
|
| 1020 |
+
|
| 1021 |
+
# it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
|
| 1022 |
+
def __deepcopy__(self, memo=None):
|
| 1023 |
+
return self
|
| 1024 |
+
|
| 1025 |
+
def __repr__(self):
|
| 1026 |
+
return "<OpOverloadPacket(op='{}.{}')>".format(
|
| 1027 |
+
*self._qualified_op_name.split("::")
|
| 1028 |
+
)
|
| 1029 |
+
|
| 1030 |
+
def __hash__(self):
|
| 1031 |
+
return hash(self._op)
|
| 1032 |
+
|
| 1033 |
+
def __str__(self):
|
| 1034 |
+
return "{}.{}".format(*self._qualified_op_name.split("::"))
|
| 1035 |
+
|
| 1036 |
+
@property
|
| 1037 |
+
def op(self):
|
| 1038 |
+
return self._op
|
| 1039 |
+
|
| 1040 |
+
@property
|
| 1041 |
+
def _schemas(self):
|
| 1042 |
+
return {
|
| 1043 |
+
overload_name: torch._C._get_schema(self._qualified_op_name, overload_name)
|
| 1044 |
+
for overload_name in self._overload_names
|
| 1045 |
+
}
|
| 1046 |
+
|
| 1047 |
+
def __getattr__(self, key):
|
| 1048 |
+
# It is not a valid op_name when __file__ is passed in
|
| 1049 |
+
if key == "__file__":
|
| 1050 |
+
return "torch.ops"
|
| 1051 |
+
|
| 1052 |
+
# ensure that query for dunder attributes that does not exist on
|
| 1053 |
+
# opoverloadpacket but instead exists on the self._op object does not unnecessarily call
|
| 1054 |
+
# `_get_operation_overload` (which is an expensive operation).
|
| 1055 |
+
# This is done to prevent any potential slowdown. This list can be extended
|
| 1056 |
+
# if there exists other attributes like `__name__` that only exist on self._op and not on the
|
| 1057 |
+
# opoverloadpacket.
|
| 1058 |
+
# This is ok since we are guaranteed that an overload name for an aten op can't start with '__'
|
| 1059 |
+
try:
|
| 1060 |
+
if key.startswith("__"):
|
| 1061 |
+
return getattr(self._op, key)
|
| 1062 |
+
except AttributeError:
|
| 1063 |
+
# for consistency because it seems weird to
|
| 1064 |
+
# throw an attribute error with a message containing
|
| 1065 |
+
# an object name different from the one the attribute
|
| 1066 |
+
# query was performed on.
|
| 1067 |
+
raise AttributeError(
|
| 1068 |
+
f"'{str(self)}' can't have an overload name beginning with '__' and the "
|
| 1069 |
+
f"underlying op {str(self._op)} has no attribute {key} either."
|
| 1070 |
+
) from None
|
| 1071 |
+
|
| 1072 |
+
try:
|
| 1073 |
+
# This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
|
| 1074 |
+
use_key = "" if key == "default" else key
|
| 1075 |
+
# TODO: disallow access to overloads registered by JIT
|
| 1076 |
+
op_dk_tags = torch._C._get_operation_overload(
|
| 1077 |
+
self._qualified_op_name, use_key
|
| 1078 |
+
)
|
| 1079 |
+
if op_dk_tags is None:
|
| 1080 |
+
raise AttributeError(
|
| 1081 |
+
f"The underlying op of '{str(self)}' has no overload name '{key}'"
|
| 1082 |
+
)
|
| 1083 |
+
|
| 1084 |
+
op_, op_dk_, tags = op_dk_tags
|
| 1085 |
+
schema = torch._C._get_schema(self._qualified_op_name, use_key)
|
| 1086 |
+
overload = (
|
| 1087 |
+
OpOverload(self, op_, op_dk_, schema, tags)
|
| 1088 |
+
if not _has_script_object_arg(schema)
|
| 1089 |
+
else TorchBindOpOverload(self, op_, op_dk_, schema, tags)
|
| 1090 |
+
)
|
| 1091 |
+
# cache the overload object
|
| 1092 |
+
setattr(self, key, overload)
|
| 1093 |
+
self._dir.append(key)
|
| 1094 |
+
return overload
|
| 1095 |
+
except RuntimeError:
|
| 1096 |
+
raise AttributeError(
|
| 1097 |
+
f"The underlying op of '{str(self)}' has no overload name '{key}'"
|
| 1098 |
+
) from None
|
| 1099 |
+
|
| 1100 |
+
def __iter__(self):
|
| 1101 |
+
return iter(self._dir)
|
| 1102 |
+
|
| 1103 |
+
# Use positional-only argument to avoid naming collision with aten ops arguments
|
| 1104 |
+
# that are named "self". This way, all the aten ops can be called by kwargs.
|
| 1105 |
+
def __call__(self, /, *args, **kwargs):
|
| 1106 |
+
# overloading __call__ to ensure torch.ops.foo.bar()
|
| 1107 |
+
# is still callable from JIT
|
| 1108 |
+
# We save the function ptr as the `op` attribute on
|
| 1109 |
+
# OpOverloadPacket to access it here.
|
| 1110 |
+
|
| 1111 |
+
# Directly calling OverloadPacket goes into C++, which will check
|
| 1112 |
+
# the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we
|
| 1113 |
+
# intercept it here and call TorchBindOpverload instead.
|
| 1114 |
+
if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
|
| 1115 |
+
return _call_overload_packet_from_python(self, args, kwargs)
|
| 1116 |
+
return self._op(*args, **(kwargs or {}))
|
| 1117 |
+
|
| 1118 |
+
# TODO: use this to make a __dir__
|
| 1119 |
+
def overloads(self):
|
| 1120 |
+
return [n if n else "default" for n in self._overload_names]
|
| 1121 |
+
|
| 1122 |
+
|
| 1123 |
+
# Note - this mirrors the logic of the cpp_function defined in jit/python/init.cpp
|
| 1124 |
+
# _jit_get_operations, which calls _get_operation_for_overload_or_packet.
|
| 1125 |
+
def _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs):
|
| 1126 |
+
# Re-use the torch function handling logic in cpp
|
| 1127 |
+
torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet(
|
| 1128 |
+
op, *args, **kwargs
|
| 1129 |
+
)
|
| 1130 |
+
|
| 1131 |
+
if torch_function_called:
|
| 1132 |
+
return ret
|
| 1133 |
+
|
| 1134 |
+
# The following mirrors getOpWithStack.
|
| 1135 |
+
# In cpp, we do a schema matching for the arguments, and call ToIValue to
|
| 1136 |
+
# to check whether the arguments are valid. But need to do similar things here
|
| 1137 |
+
# and check the schema whether the FakeScriptObject is the corresponding fake class
|
| 1138 |
+
# of the actual class used in schema.
|
| 1139 |
+
exceptions = {}
|
| 1140 |
+
found_op = None
|
| 1141 |
+
for overload_name in op.overloads():
|
| 1142 |
+
op_overload = getattr(op, overload_name)
|
| 1143 |
+
try:
|
| 1144 |
+
_ = torch._C._check_schema_allow_fake_script_object(
|
| 1145 |
+
op_overload._schema, *args, **kwargs
|
| 1146 |
+
)
|
| 1147 |
+
found_op = op_overload
|
| 1148 |
+
break
|
| 1149 |
+
except RuntimeError as e:
|
| 1150 |
+
exceptions[overload_name] = e
|
| 1151 |
+
|
| 1152 |
+
if found_op:
|
| 1153 |
+
return found_op(*args, **kwargs)
|
| 1154 |
+
|
| 1155 |
+
err_msg = (
|
| 1156 |
+
f"Fail to match any TorchBindOverload of {op} with following exceptions:\n"
|
| 1157 |
+
)
|
| 1158 |
+
for i, (key, msg) in enumerate(exceptions.items()):
|
| 1159 |
+
err_msg += f"Overload name {key}:\n {msg}\n"
|
| 1160 |
+
raise RuntimeError(err_msg)
|
| 1161 |
+
|
| 1162 |
+
|
| 1163 |
+
# Resolution of torch.fn is different from torch.ops.aten.fn
|
| 1164 |
+
# torch.fn uses the Python argparser, matches with the
|
| 1165 |
+
# appropriate schema, and calls into the unboxed version of the method
|
| 1166 |
+
# torch.ops.aten.fn resolution is done via the mechanism defined in JIT.
|
| 1167 |
+
# JIT creates a stack of all the overloads and then tries to match the
|
| 1168 |
+
# correct one at runtime and always calls into the boxed version of the method
|
| 1169 |
+
# Autograd codegen creates VariableType, TracerType,
|
| 1170 |
+
# inplace or view type and python bindings.
|
| 1171 |
+
# Aten codegen generates tensor methods for the tensor class.
|
| 1172 |
+
|
| 1173 |
+
# _OpNamespace is a subclass of ModuleType because the torch script
|
| 1174 |
+
# allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
|
| 1175 |
+
# to work from script, we need to ensure ops and foo are modules
|
| 1176 |
+
|
| 1177 |
+
|
| 1178 |
+
class _OpNamespace(types.ModuleType):
|
| 1179 |
+
"""
|
| 1180 |
+
An op namespace to dynamically bind Operators into Python.
|
| 1181 |
+
|
| 1182 |
+
Say a user has created a custom Operator called "my_namespace::my_op". To
|
| 1183 |
+
call this op, the user will write torch.ops.my_namespace.my_op(...).
|
| 1184 |
+
At startup, this operation will not yet be bound into Python. Instead, the
|
| 1185 |
+
following sequence of magic tricks will occur:
|
| 1186 |
+
1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method
|
| 1187 |
+
on the `torch.ops` object, which will create a new `_OpNamespace`
|
| 1188 |
+
object called `my_namespace` and set it as an attribute on the `ops`
|
| 1189 |
+
object.
|
| 1190 |
+
2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on
|
| 1191 |
+
the `my_namespace` object, which will retrieve the operation via
|
| 1192 |
+
`torch.get_operation`, a function bound from C++, and then in a similar
|
| 1193 |
+
fashion bind this new object onto the `my_namespace` object.
|
| 1194 |
+
3. `torch.ops.my_namespace.my_op(...)` then calls this new operation
|
| 1195 |
+
and subsequent accesses will incur no further lookup (the namespace and
|
| 1196 |
+
operation will already exist).
|
| 1197 |
+
"""
|
| 1198 |
+
|
| 1199 |
+
def __init__(self, name):
|
| 1200 |
+
super().__init__("torch.ops." + name)
|
| 1201 |
+
self.name = name
|
| 1202 |
+
self._dir = []
|
| 1203 |
+
|
| 1204 |
+
def __iter__(self):
|
| 1205 |
+
return iter(self._dir)
|
| 1206 |
+
|
| 1207 |
+
def __getattr__(self, op_name):
|
| 1208 |
+
# It is not a valid op_name when __file__ is passed in
|
| 1209 |
+
if op_name == "__file__":
|
| 1210 |
+
return "torch.ops"
|
| 1211 |
+
elif op_name in ["__origin__", "__self__"]:
|
| 1212 |
+
raise AttributeError(
|
| 1213 |
+
f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'"
|
| 1214 |
+
)
|
| 1215 |
+
|
| 1216 |
+
# Get the op `my_namespace::my_op` if available. This will also check
|
| 1217 |
+
# for overloads and raise an exception if there are more than one.
|
| 1218 |
+
namespace_name = self.name
|
| 1219 |
+
qualified_op_name = f"{namespace_name}::{op_name}"
|
| 1220 |
+
module_name = self.__module__ + "." + namespace_name
|
| 1221 |
+
|
| 1222 |
+
try:
|
| 1223 |
+
op, overload_names = _get_packet(qualified_op_name, module_name)
|
| 1224 |
+
if op is None:
|
| 1225 |
+
raise AttributeError(
|
| 1226 |
+
f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
|
| 1227 |
+
)
|
| 1228 |
+
except RuntimeError as e:
|
| 1229 |
+
# Turn this into AttributeError so getattr(obj, key, default)
|
| 1230 |
+
# works (this is called by TorchScript with __origin__)
|
| 1231 |
+
raise AttributeError(
|
| 1232 |
+
f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
|
| 1233 |
+
) from e
|
| 1234 |
+
|
| 1235 |
+
op.__module__ = module_name
|
| 1236 |
+
opoverloadpacket = OpOverloadPacket(
|
| 1237 |
+
qualified_op_name, op_name, op, overload_names
|
| 1238 |
+
)
|
| 1239 |
+
opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
|
| 1240 |
+
# cache the opoverloadpacket to ensure that each op corresponds to
|
| 1241 |
+
# a unique OpOverloadPacket object
|
| 1242 |
+
setattr(self, op_name, opoverloadpacket)
|
| 1243 |
+
self._dir.append(op_name)
|
| 1244 |
+
return opoverloadpacket
|
| 1245 |
+
|
| 1246 |
+
|
| 1247 |
+
def _get_packet(qualname, op_module):
|
| 1248 |
+
op, overload_names = torch._C._jit_get_operation(qualname)
|
| 1249 |
+
if op is not None:
|
| 1250 |
+
# let the script frontend know that op is identical to the builtin op
|
| 1251 |
+
# with qualified_op_name
|
| 1252 |
+
torch.jit._builtins._register_builtin(op, qualname)
|
| 1253 |
+
op.__module__ = op_module
|
| 1254 |
+
return op, overload_names
|
| 1255 |
+
|
| 1256 |
+
|
| 1257 |
+
def _refresh_packet(packet):
|
| 1258 |
+
op, overload_names = _get_packet(packet._qualified_op_name, packet._op.__module__)
|
| 1259 |
+
assert op is not None
|
| 1260 |
+
packet._op = op
|
| 1261 |
+
packet._overload_names = overload_names
|
| 1262 |
+
|
| 1263 |
+
|
| 1264 |
+
class _PyOpNamespace(_OpNamespace):
|
| 1265 |
+
def __init__(self, name, ops):
|
| 1266 |
+
super().__init__(name)
|
| 1267 |
+
self._ops = ops
|
| 1268 |
+
|
| 1269 |
+
def __getattr__(self, name):
|
| 1270 |
+
# Following _OpNamespace.__getattr__, we cache the op on the _PyOpNamespace object.
|
| 1271 |
+
op = self._ops.get(name, None)
|
| 1272 |
+
if op is None:
|
| 1273 |
+
raise AttributeError(
|
| 1274 |
+
f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'"
|
| 1275 |
+
)
|
| 1276 |
+
setattr(self, name, op)
|
| 1277 |
+
return op
|
| 1278 |
+
|
| 1279 |
+
|
| 1280 |
+
class _Ops(types.ModuleType):
|
| 1281 |
+
__file__ = "_ops.py"
|
| 1282 |
+
|
| 1283 |
+
def __init__(self):
|
| 1284 |
+
super().__init__("torch.ops")
|
| 1285 |
+
self.loaded_libraries = set()
|
| 1286 |
+
self._higher_order_op_namespace = _PyOpNamespace(
|
| 1287 |
+
"torch.ops.higher_order", _higher_order_ops
|
| 1288 |
+
)
|
| 1289 |
+
self._dir = []
|
| 1290 |
+
|
| 1291 |
+
def __getattr__(self, name):
|
| 1292 |
+
# Check if the name is a HigherOrderOperator
|
| 1293 |
+
if name == "higher_order":
|
| 1294 |
+
return self._higher_order_op_namespace
|
| 1295 |
+
|
| 1296 |
+
# Here we are creating `torch.ops.my_namespace`
|
| 1297 |
+
namespace = _OpNamespace(name)
|
| 1298 |
+
setattr(self, name, namespace)
|
| 1299 |
+
self._dir.append(name)
|
| 1300 |
+
return namespace
|
| 1301 |
+
|
| 1302 |
+
def __iter__(self):
|
| 1303 |
+
return iter(self._dir)
|
| 1304 |
+
|
| 1305 |
+
def import_module(self, module):
|
| 1306 |
+
"""
|
| 1307 |
+
Imports a Python module that has torch.library registrations.
|
| 1308 |
+
|
| 1309 |
+
Generally, to extend PyTorch with custom operators, a user will
|
| 1310 |
+
create a Python module whose import triggers registration of
|
| 1311 |
+
the custom operators via a torch.ops.load_library call or a call
|
| 1312 |
+
to one or more torch.library.* APIs.
|
| 1313 |
+
|
| 1314 |
+
It is unexpected for Python modules to have side effects, so some
|
| 1315 |
+
linters and formatters will complain. Use this API to import Python
|
| 1316 |
+
modules that contain these torch.library side effects.
|
| 1317 |
+
|
| 1318 |
+
Args:
|
| 1319 |
+
module (str): The name of the Python module to import
|
| 1320 |
+
|
| 1321 |
+
"""
|
| 1322 |
+
importlib.import_module(module)
|
| 1323 |
+
|
| 1324 |
+
def load_library(self, path):
|
| 1325 |
+
"""
|
| 1326 |
+
Loads a shared library from the given path into the current process.
|
| 1327 |
+
|
| 1328 |
+
The library being loaded may run global initialization code to register
|
| 1329 |
+
custom operators with the PyTorch JIT runtime. This allows dynamically
|
| 1330 |
+
loading custom operators. For this, you should compile your operator
|
| 1331 |
+
and the static registration code into a shared library object, and then
|
| 1332 |
+
call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
|
| 1333 |
+
shared object.
|
| 1334 |
+
|
| 1335 |
+
After the library is loaded, it is added to the
|
| 1336 |
+
``torch.ops.loaded_libraries`` attribute, a set that may be inspected
|
| 1337 |
+
for the paths of all libraries loaded using this function.
|
| 1338 |
+
|
| 1339 |
+
Args:
|
| 1340 |
+
path (str): A path to a shared library to load.
|
| 1341 |
+
"""
|
| 1342 |
+
if torch._running_with_deploy():
|
| 1343 |
+
return
|
| 1344 |
+
|
| 1345 |
+
path = _utils_internal.resolve_library_path(path)
|
| 1346 |
+
with dl_open_guard():
|
| 1347 |
+
# Import the shared library into the process, thus running its
|
| 1348 |
+
# static (global) initialization code in order to register custom
|
| 1349 |
+
# operators with the JIT.
|
| 1350 |
+
ctypes.CDLL(path)
|
| 1351 |
+
self.loaded_libraries.add(path)
|
| 1352 |
+
|
| 1353 |
+
|
| 1354 |
+
# The ops "namespace"
|
| 1355 |
+
ops: _Ops = _Ops()
|
.venv/lib/python3.11/site-packages/torch/_python_dispatcher.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
import torch._C as C
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
PythonDispatcher class is a thin python-binding to C++ dispatcher and it
|
| 9 |
+
is designed to show how dispatcher precompute works. In particular,
|
| 10 |
+
it shows for a certain op `foo`, what the computed dispatch table looks
|
| 11 |
+
like after user register their kernels to certains dispatch keys.
|
| 12 |
+
|
| 13 |
+
In the real C++ dispatcher we support many dispatch keys for different
|
| 14 |
+
functionalities. For simplicity PythonDispatcher only supports dispatch
|
| 15 |
+
keys for a single example of each use case. These use cases are listed below:
|
| 16 |
+
|
| 17 |
+
- CPU/AutogradCPU: represents in-tree backends which we usually have dedicated inference &
|
| 18 |
+
autograd kernel in pytorch core library.
|
| 19 |
+
E.g. CPU, CUDA
|
| 20 |
+
- FPGA/AutogradOther: represents in-tree backends which we usually have backend specific
|
| 21 |
+
inference kernels, but they share the same autograd kernel specified in AutogradOther.
|
| 22 |
+
E.g. FPGA, SparseCsrCPU
|
| 23 |
+
- XLA/AutogradXLA: represents out-of-tree backends which we don't have either inference or autograd
|
| 24 |
+
kernel defined in pytorch core library. Backend owner is responsible for registering both
|
| 25 |
+
inference & autograd kernels in their extensions(e.g. torch-xla) for the operators they support.
|
| 26 |
+
E.g. XLA, XPU, MPS
|
| 27 |
+
- CompositeExplicitAutograd: alias key mapped to inference kernels of all backends like CPU, CUDA, XLA etc.
|
| 28 |
+
Kernels registered to this key MUST work for inference for all backends.
|
| 29 |
+
- Autograd: alias key mapped to autograd of all backends like AutogradCPU, AutogradXLA, AutogradOther.
|
| 30 |
+
Kernels registered to this key MUST work for autograd for all backends.
|
| 31 |
+
- CompositeImplicitAutograd: alias key CompositeImplicitAutograd = CompositeExplicitAutograd + Autograd
|
| 32 |
+
Kernels registered to this key MUST work for both inference + autograd for all backends.
|
| 33 |
+
|
| 34 |
+
Note we only allow registrations to alias keys inside pytorch core library. E.g
|
| 35 |
+
you shouldn't register a CompositeImplicitAutograd or CompositeExplicitAutograd
|
| 36 |
+
kernel from torch-xla extension, instead you should upstream the kernel into
|
| 37 |
+
pytorch/pytorch repo so that it's available for all backends and continuously
|
| 38 |
+
tested even without the extension.
|
| 39 |
+
|
| 40 |
+
Usage:
|
| 41 |
+
dispatcher = PythonDispatcher()
|
| 42 |
+
dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"])
|
| 43 |
+
print(dispatcher.dispatchTable()) # This tells you exactly which kernel is used for certain backend.
|
| 44 |
+
# For more debugging information
|
| 45 |
+
# print(dispatcher.keys())
|
| 46 |
+
# print(dispatcher.registrations())
|
| 47 |
+
# print(dispatcher.rawRegistrations())
|
| 48 |
+
# print(dispatcher.rawDispatchTable())
|
| 49 |
+
PythonDispatcher calls C++ dispatcher under the hood for to precompute dispatch table.
|
| 50 |
+
This file only provides the simplified API for developers, relevant test code is located in
|
| 51 |
+
test/test_dispatch.py
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class PythonDispatcher:
    """Simplified Python interface to the C++ dispatcher, for exploring how
    kernel registrations resolve into a computed dispatch table.

    It registers a single test operator (``__test__::foo``); for each dispatch
    key passed to :meth:`register` a kernel (named ``fn_<key>``) is generated
    and registered automatically, so developers never write kernels by hand.

    Usage:
        dispatcher = PythonDispatcher()
        dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"])
        print(dispatcher.dispatchTable())  # which kernel each backend uses

    Relevant test code is located in test/test_dispatch.py.
    """

    # The test operator lives under the "__test__" namespace as "foo".
    namespace = "__test__"
    name = "foo"
    # fmt: off
    runtime_keys = [
        "CPU", "AutogradCPU",
        "FPGA", "AutogradOther",
        "XLA", "AutogradXLA",
        "Lazy", "AutogradLazy",
    ]
    # fmt: on
    # Alias keys fan out to multiple runtime keys when the table is computed.
    alias_keys = [
        "CompositeExplicitAutograd",
        "Autograd",
        "CompositeImplicitAutograd",
    ]
    supported_keys = runtime_keys + alias_keys

    def __init__(self) -> None:
        # Sanity-check dispatcher invariants before defining the test op.
        C._dispatch_check_invariants(self.name)  # type: ignore[attr-defined]
        self.ref = C._dispatch_library("FRAGMENT", self.namespace, "")
        self.ref.def_("foo(Tensor x) -> Tensor")

    def keys(self):
        """Return the list of dispatch keys supported by PythonDispatcher.

        You can register kernels to these keys.
        """
        return self.supported_keys

    def register(self, dispatchKeys):
        """Register auto-generated kernels to the target dispatch keys.

        Args:
            dispatchKeys (list[str]): dispatch keys to register a kernel for.
                You don't write the kernel yourself: for each key a kernel
                (e.g. ``fn_CPU`` for CPU) is generated and registered.

        Raises:
            RuntimeError: on duplicate keys, on registering to both composite
                alias keys, or on an unsupported key.
        """
        # Overriding is not supported and triggers a warning in C++ dispatcher.
        if len(set(dispatchKeys)) != len(dispatchKeys):
            raise RuntimeError(
                f"Overriden is not allowed but found duplicates in {dispatchKeys}."
            )
        # We currently forbid this in codegen instead of C++ dispatcher.
        if (
            "CompositeImplicitAutograd" in dispatchKeys
            and "CompositeExplicitAutograd" in dispatchKeys
        ):
            raise RuntimeError(
                "Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed."
            )
        for key in dispatchKeys:
            if key not in self.supported_keys:
                raise RuntimeError(
                    f"{key} is not supported, please select a dispatch key in {self.supported_keys}."
                )
            self.ref.impl_t_t("foo", dispatch=key, debug="fn_" + key)

    def _format_line(self, key, kernel):
        """Helper function to format one (key, kernel) table row."""
        return f"{key:<15} {kernel}\n"

    def _format_header(self, header):
        """Helper function to format a table header."""
        s = f"""
{header}
"""
        s += self._format_line("key", "kernel")
        s += "---------------------------\n"
        return s

    def rawRegistrations(self):
        """Return raw output of all registration info, for debugging only.

        Use registrations() for a simplified version.
        """
        return C._dispatch_dump(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    def rawDispatchTable(self):
        """Return raw output of the computed dispatch table, for debugging only.

        Use dispatchTable() for a simplified version.
        """
        return C._dispatch_dump_table(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    def registrations(self):
        """Return a table (str) including all the registrations from users.

        Note this includes registrations to both runtime keys and alias keys.
        """
        output = self._format_header("Registered Kernels")
        state = self.rawRegistrations()
        state_entries = state.split("\n")
        for line in state_entries:
            first = line.split(":")[0]
            if any(first.startswith(k) for k in self.supported_keys):
                kernel = line.split("::")[0].split(" ")[1]
                output += self._format_line(first, kernel)
        return output

    def dispatchTable(self):
        """Return the computed dispatch table (str).

        Note this only includes runtime keys; registrations to alias keys have
        been decoded to their mapped runtime keys.
        """
        output = self._format_header("Computed Dispatch Table")
        table = self.rawDispatchTable()
        table_entries = table.split("\n")
        # Strip the "registered at .../FallbackKernel.cpp:..." location noise
        # so fallback entries render as just their bracketed kernel info.
        regex = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)")
        for line in table_entries:
            k = line.split(":")[0]
            if k in self.runtime_keys:
                entry = regex.sub("[", line)
                output += self._format_line(k, entry.split(": ")[1])
        return output