Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- .venv/lib/python3.11/site-packages/click/__pycache__/_winconsole.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/click/__pycache__/formatting.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/click/__pycache__/parser.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/click/__pycache__/termui.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/click/_termui_impl.py +788 -0
- .venv/lib/python3.11/site-packages/click/decorators.py +562 -0
- .venv/lib/python3.11/site-packages/click/py.typed +0 -0
- .venv/lib/python3.11/site-packages/click/shell_completion.py +603 -0
- .venv/lib/python3.11/site-packages/click/termui.py +784 -0
- .venv/lib/python3.11/site-packages/click/utils.py +624 -0
- .venv/lib/python3.11/site-packages/httplib2/__init__.py +1799 -0
- .venv/lib/python3.11/site-packages/httplib2/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/httplib2/__pycache__/auth.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/httplib2/__pycache__/certs.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/httplib2/__pycache__/error.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/httplib2/__pycache__/iri2uri.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/httplib2/__pycache__/socks.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/httplib2/auth.py +69 -0
- .venv/lib/python3.11/site-packages/httplib2/cacerts.txt +0 -0
- .venv/lib/python3.11/site-packages/httplib2/certs.py +42 -0
- .venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/INSTALLER +1 -0
- .venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/LICENSE +21 -0
- .venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/METADATA +252 -0
- .venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/RECORD +45 -0
- .venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/WHEEL +4 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_adv.h +671 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_backend.h +60 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn.h +693 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn_v9.h +693 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_graph_v9.h +909 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_ops_v9.h +1316 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_v9.h +68 -0
- .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_version.h +70 -0
- .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libavutil-734d06dd.so.57.28.100 +3 -0
- .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libssl-28bef1ac.so.1.1 +3 -0
- .venv/lib/python3.11/site-packages/pyasn1/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/pyasn1/__pycache__/debug.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/pyasn1/codec/__init__.py +1 -0
- .venv/lib/python3.11/site-packages/pyasn1/codec/ber/__init__.py +1 -0
- .venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/decoder.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/encoder.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/eoo.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/pyasn1/codec/ber/decoder.py +2189 -0
- .venv/lib/python3.11/site-packages/pyasn1/codec/ber/encoder.py +954 -0
- .venv/lib/python3.11/site-packages/pyasn1/codec/ber/eoo.py +28 -0
- .venv/lib/python3.11/site-packages/pyasn1/codec/native/__init__.py +1 -0
.gitattributes
CHANGED
|
@@ -407,3 +407,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
|
|
| 407 |
.venv/lib/python3.11/site-packages/aiohttp/_websocket/mask.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 408 |
.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_fp.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 409 |
.venv/lib/python3.11/site-packages/aiohttp/_websocket/reader_c.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 407 |
.venv/lib/python3.11/site-packages/aiohttp/_websocket/mask.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 408 |
.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_fp.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 409 |
.venv/lib/python3.11/site-packages/aiohttp/_websocket/reader_c.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 410 |
+
.venv/lib/python3.11/site-packages/opencv_python_headless.libs/libssl-28bef1ac.so.1.1 filter=lfs diff=lfs merge=lfs -text
|
| 411 |
+
.venv/lib/python3.11/site-packages/opencv_python_headless.libs/libavutil-734d06dd.so.57.28.100 filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/click/__pycache__/_winconsole.cpython-311.pyc
ADDED
|
Binary file (13.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/click/__pycache__/formatting.cpython-311.pyc
ADDED
|
Binary file (15.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/click/__pycache__/parser.cpython-311.pyc
ADDED
|
Binary file (23.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/click/__pycache__/termui.cpython-311.pyc
ADDED
|
Binary file (34.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/click/_termui_impl.py
ADDED
|
@@ -0,0 +1,788 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module contains implementations for the termui module. To keep the
|
| 3 |
+
import time of Click down, some infrequently used functionality is
|
| 4 |
+
placed in this module and only imported as needed.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import contextlib
|
| 8 |
+
import math
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import time
|
| 12 |
+
import typing as t
|
| 13 |
+
from gettext import gettext as _
|
| 14 |
+
from io import StringIO
|
| 15 |
+
from shutil import which
|
| 16 |
+
from types import TracebackType
|
| 17 |
+
|
| 18 |
+
from ._compat import _default_text_stdout
|
| 19 |
+
from ._compat import CYGWIN
|
| 20 |
+
from ._compat import get_best_encoding
|
| 21 |
+
from ._compat import isatty
|
| 22 |
+
from ._compat import open_stream
|
| 23 |
+
from ._compat import strip_ansi
|
| 24 |
+
from ._compat import term_len
|
| 25 |
+
from ._compat import WIN
|
| 26 |
+
from .exceptions import ClickException
|
| 27 |
+
from .utils import echo
|
| 28 |
+
|
| 29 |
+
V = t.TypeVar("V")
|
| 30 |
+
|
| 31 |
+
if os.name == "nt":
|
| 32 |
+
BEFORE_BAR = "\r"
|
| 33 |
+
AFTER_BAR = "\n"
|
| 34 |
+
else:
|
| 35 |
+
BEFORE_BAR = "\r\033[?25l"
|
| 36 |
+
AFTER_BAR = "\033[?25h\n"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ProgressBar(t.Generic[V]):
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
iterable: t.Optional[t.Iterable[V]],
|
| 43 |
+
length: t.Optional[int] = None,
|
| 44 |
+
fill_char: str = "#",
|
| 45 |
+
empty_char: str = " ",
|
| 46 |
+
bar_template: str = "%(bar)s",
|
| 47 |
+
info_sep: str = " ",
|
| 48 |
+
show_eta: bool = True,
|
| 49 |
+
show_percent: t.Optional[bool] = None,
|
| 50 |
+
show_pos: bool = False,
|
| 51 |
+
item_show_func: t.Optional[t.Callable[[t.Optional[V]], t.Optional[str]]] = None,
|
| 52 |
+
label: t.Optional[str] = None,
|
| 53 |
+
file: t.Optional[t.TextIO] = None,
|
| 54 |
+
color: t.Optional[bool] = None,
|
| 55 |
+
update_min_steps: int = 1,
|
| 56 |
+
width: int = 30,
|
| 57 |
+
) -> None:
|
| 58 |
+
self.fill_char = fill_char
|
| 59 |
+
self.empty_char = empty_char
|
| 60 |
+
self.bar_template = bar_template
|
| 61 |
+
self.info_sep = info_sep
|
| 62 |
+
self.show_eta = show_eta
|
| 63 |
+
self.show_percent = show_percent
|
| 64 |
+
self.show_pos = show_pos
|
| 65 |
+
self.item_show_func = item_show_func
|
| 66 |
+
self.label: str = label or ""
|
| 67 |
+
|
| 68 |
+
if file is None:
|
| 69 |
+
file = _default_text_stdout()
|
| 70 |
+
|
| 71 |
+
# There are no standard streams attached to write to. For example,
|
| 72 |
+
# pythonw on Windows.
|
| 73 |
+
if file is None:
|
| 74 |
+
file = StringIO()
|
| 75 |
+
|
| 76 |
+
self.file = file
|
| 77 |
+
self.color = color
|
| 78 |
+
self.update_min_steps = update_min_steps
|
| 79 |
+
self._completed_intervals = 0
|
| 80 |
+
self.width: int = width
|
| 81 |
+
self.autowidth: bool = width == 0
|
| 82 |
+
|
| 83 |
+
if length is None:
|
| 84 |
+
from operator import length_hint
|
| 85 |
+
|
| 86 |
+
length = length_hint(iterable, -1)
|
| 87 |
+
|
| 88 |
+
if length == -1:
|
| 89 |
+
length = None
|
| 90 |
+
if iterable is None:
|
| 91 |
+
if length is None:
|
| 92 |
+
raise TypeError("iterable or length is required")
|
| 93 |
+
iterable = t.cast(t.Iterable[V], range(length))
|
| 94 |
+
self.iter: t.Iterable[V] = iter(iterable)
|
| 95 |
+
self.length = length
|
| 96 |
+
self.pos = 0
|
| 97 |
+
self.avg: t.List[float] = []
|
| 98 |
+
self.last_eta: float
|
| 99 |
+
self.start: float
|
| 100 |
+
self.start = self.last_eta = time.time()
|
| 101 |
+
self.eta_known: bool = False
|
| 102 |
+
self.finished: bool = False
|
| 103 |
+
self.max_width: t.Optional[int] = None
|
| 104 |
+
self.entered: bool = False
|
| 105 |
+
self.current_item: t.Optional[V] = None
|
| 106 |
+
self.is_hidden: bool = not isatty(self.file)
|
| 107 |
+
self._last_line: t.Optional[str] = None
|
| 108 |
+
|
| 109 |
+
def __enter__(self) -> "ProgressBar[V]":
|
| 110 |
+
self.entered = True
|
| 111 |
+
self.render_progress()
|
| 112 |
+
return self
|
| 113 |
+
|
| 114 |
+
def __exit__(
|
| 115 |
+
self,
|
| 116 |
+
exc_type: t.Optional[t.Type[BaseException]],
|
| 117 |
+
exc_value: t.Optional[BaseException],
|
| 118 |
+
tb: t.Optional[TracebackType],
|
| 119 |
+
) -> None:
|
| 120 |
+
self.render_finish()
|
| 121 |
+
|
| 122 |
+
def __iter__(self) -> t.Iterator[V]:
|
| 123 |
+
if not self.entered:
|
| 124 |
+
raise RuntimeError("You need to use progress bars in a with block.")
|
| 125 |
+
self.render_progress()
|
| 126 |
+
return self.generator()
|
| 127 |
+
|
| 128 |
+
def __next__(self) -> V:
|
| 129 |
+
# Iteration is defined in terms of a generator function,
|
| 130 |
+
# returned by iter(self); use that to define next(). This works
|
| 131 |
+
# because `self.iter` is an iterable consumed by that generator,
|
| 132 |
+
# so it is re-entry safe. Calling `next(self.generator())`
|
| 133 |
+
# twice works and does "what you want".
|
| 134 |
+
return next(iter(self))
|
| 135 |
+
|
| 136 |
+
def render_finish(self) -> None:
|
| 137 |
+
if self.is_hidden:
|
| 138 |
+
return
|
| 139 |
+
self.file.write(AFTER_BAR)
|
| 140 |
+
self.file.flush()
|
| 141 |
+
|
| 142 |
+
@property
|
| 143 |
+
def pct(self) -> float:
|
| 144 |
+
if self.finished:
|
| 145 |
+
return 1.0
|
| 146 |
+
return min(self.pos / (float(self.length or 1) or 1), 1.0)
|
| 147 |
+
|
| 148 |
+
@property
|
| 149 |
+
def time_per_iteration(self) -> float:
|
| 150 |
+
if not self.avg:
|
| 151 |
+
return 0.0
|
| 152 |
+
return sum(self.avg) / float(len(self.avg))
|
| 153 |
+
|
| 154 |
+
@property
|
| 155 |
+
def eta(self) -> float:
|
| 156 |
+
if self.length is not None and not self.finished:
|
| 157 |
+
return self.time_per_iteration * (self.length - self.pos)
|
| 158 |
+
return 0.0
|
| 159 |
+
|
| 160 |
+
def format_eta(self) -> str:
|
| 161 |
+
if self.eta_known:
|
| 162 |
+
t = int(self.eta)
|
| 163 |
+
seconds = t % 60
|
| 164 |
+
t //= 60
|
| 165 |
+
minutes = t % 60
|
| 166 |
+
t //= 60
|
| 167 |
+
hours = t % 24
|
| 168 |
+
t //= 24
|
| 169 |
+
if t > 0:
|
| 170 |
+
return f"{t}d {hours:02}:{minutes:02}:{seconds:02}"
|
| 171 |
+
else:
|
| 172 |
+
return f"{hours:02}:{minutes:02}:{seconds:02}"
|
| 173 |
+
return ""
|
| 174 |
+
|
| 175 |
+
def format_pos(self) -> str:
|
| 176 |
+
pos = str(self.pos)
|
| 177 |
+
if self.length is not None:
|
| 178 |
+
pos += f"/{self.length}"
|
| 179 |
+
return pos
|
| 180 |
+
|
| 181 |
+
def format_pct(self) -> str:
|
| 182 |
+
return f"{int(self.pct * 100): 4}%"[1:]
|
| 183 |
+
|
| 184 |
+
def format_bar(self) -> str:
|
| 185 |
+
if self.length is not None:
|
| 186 |
+
bar_length = int(self.pct * self.width)
|
| 187 |
+
bar = self.fill_char * bar_length
|
| 188 |
+
bar += self.empty_char * (self.width - bar_length)
|
| 189 |
+
elif self.finished:
|
| 190 |
+
bar = self.fill_char * self.width
|
| 191 |
+
else:
|
| 192 |
+
chars = list(self.empty_char * (self.width or 1))
|
| 193 |
+
if self.time_per_iteration != 0:
|
| 194 |
+
chars[
|
| 195 |
+
int(
|
| 196 |
+
(math.cos(self.pos * self.time_per_iteration) / 2.0 + 0.5)
|
| 197 |
+
* self.width
|
| 198 |
+
)
|
| 199 |
+
] = self.fill_char
|
| 200 |
+
bar = "".join(chars)
|
| 201 |
+
return bar
|
| 202 |
+
|
| 203 |
+
def format_progress_line(self) -> str:
|
| 204 |
+
show_percent = self.show_percent
|
| 205 |
+
|
| 206 |
+
info_bits = []
|
| 207 |
+
if self.length is not None and show_percent is None:
|
| 208 |
+
show_percent = not self.show_pos
|
| 209 |
+
|
| 210 |
+
if self.show_pos:
|
| 211 |
+
info_bits.append(self.format_pos())
|
| 212 |
+
if show_percent:
|
| 213 |
+
info_bits.append(self.format_pct())
|
| 214 |
+
if self.show_eta and self.eta_known and not self.finished:
|
| 215 |
+
info_bits.append(self.format_eta())
|
| 216 |
+
if self.item_show_func is not None:
|
| 217 |
+
item_info = self.item_show_func(self.current_item)
|
| 218 |
+
if item_info is not None:
|
| 219 |
+
info_bits.append(item_info)
|
| 220 |
+
|
| 221 |
+
return (
|
| 222 |
+
self.bar_template
|
| 223 |
+
% {
|
| 224 |
+
"label": self.label,
|
| 225 |
+
"bar": self.format_bar(),
|
| 226 |
+
"info": self.info_sep.join(info_bits),
|
| 227 |
+
}
|
| 228 |
+
).rstrip()
|
| 229 |
+
|
| 230 |
+
def render_progress(self) -> None:
|
| 231 |
+
import shutil
|
| 232 |
+
|
| 233 |
+
if self.is_hidden:
|
| 234 |
+
# Only output the label as it changes if the output is not a
|
| 235 |
+
# TTY. Use file=stderr if you expect to be piping stdout.
|
| 236 |
+
if self._last_line != self.label:
|
| 237 |
+
self._last_line = self.label
|
| 238 |
+
echo(self.label, file=self.file, color=self.color)
|
| 239 |
+
|
| 240 |
+
return
|
| 241 |
+
|
| 242 |
+
buf = []
|
| 243 |
+
# Update width in case the terminal has been resized
|
| 244 |
+
if self.autowidth:
|
| 245 |
+
old_width = self.width
|
| 246 |
+
self.width = 0
|
| 247 |
+
clutter_length = term_len(self.format_progress_line())
|
| 248 |
+
new_width = max(0, shutil.get_terminal_size().columns - clutter_length)
|
| 249 |
+
if new_width < old_width:
|
| 250 |
+
buf.append(BEFORE_BAR)
|
| 251 |
+
buf.append(" " * self.max_width) # type: ignore
|
| 252 |
+
self.max_width = new_width
|
| 253 |
+
self.width = new_width
|
| 254 |
+
|
| 255 |
+
clear_width = self.width
|
| 256 |
+
if self.max_width is not None:
|
| 257 |
+
clear_width = self.max_width
|
| 258 |
+
|
| 259 |
+
buf.append(BEFORE_BAR)
|
| 260 |
+
line = self.format_progress_line()
|
| 261 |
+
line_len = term_len(line)
|
| 262 |
+
if self.max_width is None or self.max_width < line_len:
|
| 263 |
+
self.max_width = line_len
|
| 264 |
+
|
| 265 |
+
buf.append(line)
|
| 266 |
+
buf.append(" " * (clear_width - line_len))
|
| 267 |
+
line = "".join(buf)
|
| 268 |
+
# Render the line only if it changed.
|
| 269 |
+
|
| 270 |
+
if line != self._last_line:
|
| 271 |
+
self._last_line = line
|
| 272 |
+
echo(line, file=self.file, color=self.color, nl=False)
|
| 273 |
+
self.file.flush()
|
| 274 |
+
|
| 275 |
+
def make_step(self, n_steps: int) -> None:
|
| 276 |
+
self.pos += n_steps
|
| 277 |
+
if self.length is not None and self.pos >= self.length:
|
| 278 |
+
self.finished = True
|
| 279 |
+
|
| 280 |
+
if (time.time() - self.last_eta) < 1.0:
|
| 281 |
+
return
|
| 282 |
+
|
| 283 |
+
self.last_eta = time.time()
|
| 284 |
+
|
| 285 |
+
# self.avg is a rolling list of length <= 7 of steps where steps are
|
| 286 |
+
# defined as time elapsed divided by the total progress through
|
| 287 |
+
# self.length.
|
| 288 |
+
if self.pos:
|
| 289 |
+
step = (time.time() - self.start) / self.pos
|
| 290 |
+
else:
|
| 291 |
+
step = time.time() - self.start
|
| 292 |
+
|
| 293 |
+
self.avg = self.avg[-6:] + [step]
|
| 294 |
+
|
| 295 |
+
self.eta_known = self.length is not None
|
| 296 |
+
|
| 297 |
+
def update(self, n_steps: int, current_item: t.Optional[V] = None) -> None:
|
| 298 |
+
"""Update the progress bar by advancing a specified number of
|
| 299 |
+
steps, and optionally set the ``current_item`` for this new
|
| 300 |
+
position.
|
| 301 |
+
|
| 302 |
+
:param n_steps: Number of steps to advance.
|
| 303 |
+
:param current_item: Optional item to set as ``current_item``
|
| 304 |
+
for the updated position.
|
| 305 |
+
|
| 306 |
+
.. versionchanged:: 8.0
|
| 307 |
+
Added the ``current_item`` optional parameter.
|
| 308 |
+
|
| 309 |
+
.. versionchanged:: 8.0
|
| 310 |
+
Only render when the number of steps meets the
|
| 311 |
+
``update_min_steps`` threshold.
|
| 312 |
+
"""
|
| 313 |
+
if current_item is not None:
|
| 314 |
+
self.current_item = current_item
|
| 315 |
+
|
| 316 |
+
self._completed_intervals += n_steps
|
| 317 |
+
|
| 318 |
+
if self._completed_intervals >= self.update_min_steps:
|
| 319 |
+
self.make_step(self._completed_intervals)
|
| 320 |
+
self.render_progress()
|
| 321 |
+
self._completed_intervals = 0
|
| 322 |
+
|
| 323 |
+
def finish(self) -> None:
|
| 324 |
+
self.eta_known = False
|
| 325 |
+
self.current_item = None
|
| 326 |
+
self.finished = True
|
| 327 |
+
|
| 328 |
+
def generator(self) -> t.Iterator[V]:
|
| 329 |
+
"""Return a generator which yields the items added to the bar
|
| 330 |
+
during construction, and updates the progress bar *after* the
|
| 331 |
+
yielded block returns.
|
| 332 |
+
"""
|
| 333 |
+
# WARNING: the iterator interface for `ProgressBar` relies on
|
| 334 |
+
# this and only works because this is a simple generator which
|
| 335 |
+
# doesn't create or manage additional state. If this function
|
| 336 |
+
# changes, the impact should be evaluated both against
|
| 337 |
+
# `iter(bar)` and `next(bar)`. `next()` in particular may call
|
| 338 |
+
# `self.generator()` repeatedly, and this must remain safe in
|
| 339 |
+
# order for that interface to work.
|
| 340 |
+
if not self.entered:
|
| 341 |
+
raise RuntimeError("You need to use progress bars in a with block.")
|
| 342 |
+
|
| 343 |
+
if self.is_hidden:
|
| 344 |
+
yield from self.iter
|
| 345 |
+
else:
|
| 346 |
+
for rv in self.iter:
|
| 347 |
+
self.current_item = rv
|
| 348 |
+
|
| 349 |
+
# This allows show_item_func to be updated before the
|
| 350 |
+
# item is processed. Only trigger at the beginning of
|
| 351 |
+
# the update interval.
|
| 352 |
+
if self._completed_intervals == 0:
|
| 353 |
+
self.render_progress()
|
| 354 |
+
|
| 355 |
+
yield rv
|
| 356 |
+
self.update(1)
|
| 357 |
+
|
| 358 |
+
self.finish()
|
| 359 |
+
self.render_progress()
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def pager(generator: t.Iterable[str], color: t.Optional[bool] = None) -> None:
|
| 363 |
+
"""Decide what method to use for paging through text."""
|
| 364 |
+
stdout = _default_text_stdout()
|
| 365 |
+
|
| 366 |
+
# There are no standard streams attached to write to. For example,
|
| 367 |
+
# pythonw on Windows.
|
| 368 |
+
if stdout is None:
|
| 369 |
+
stdout = StringIO()
|
| 370 |
+
|
| 371 |
+
if not isatty(sys.stdin) or not isatty(stdout):
|
| 372 |
+
return _nullpager(stdout, generator, color)
|
| 373 |
+
pager_cmd = (os.environ.get("PAGER", None) or "").strip()
|
| 374 |
+
if pager_cmd:
|
| 375 |
+
if WIN:
|
| 376 |
+
if _tempfilepager(generator, pager_cmd, color):
|
| 377 |
+
return
|
| 378 |
+
elif _pipepager(generator, pager_cmd, color):
|
| 379 |
+
return
|
| 380 |
+
if os.environ.get("TERM") in ("dumb", "emacs"):
|
| 381 |
+
return _nullpager(stdout, generator, color)
|
| 382 |
+
if (WIN or sys.platform.startswith("os2")) and _tempfilepager(
|
| 383 |
+
generator, "more", color
|
| 384 |
+
):
|
| 385 |
+
return
|
| 386 |
+
if _pipepager(generator, "less", color):
|
| 387 |
+
return
|
| 388 |
+
|
| 389 |
+
import tempfile
|
| 390 |
+
|
| 391 |
+
fd, filename = tempfile.mkstemp()
|
| 392 |
+
os.close(fd)
|
| 393 |
+
try:
|
| 394 |
+
if _pipepager(generator, "more", color):
|
| 395 |
+
return
|
| 396 |
+
return _nullpager(stdout, generator, color)
|
| 397 |
+
finally:
|
| 398 |
+
os.unlink(filename)
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def _pipepager(generator: t.Iterable[str], cmd: str, color: t.Optional[bool]) -> bool:
|
| 402 |
+
"""Page through text by feeding it to another program. Invoking a
|
| 403 |
+
pager through this might support colors.
|
| 404 |
+
|
| 405 |
+
Returns True if the command was found, False otherwise and thus another
|
| 406 |
+
pager should be attempted.
|
| 407 |
+
"""
|
| 408 |
+
cmd_absolute = which(cmd)
|
| 409 |
+
if cmd_absolute is None:
|
| 410 |
+
return False
|
| 411 |
+
|
| 412 |
+
import subprocess
|
| 413 |
+
|
| 414 |
+
env = dict(os.environ)
|
| 415 |
+
|
| 416 |
+
# If we're piping to less we might support colors under the
|
| 417 |
+
# condition that
|
| 418 |
+
cmd_detail = cmd.rsplit("/", 1)[-1].split()
|
| 419 |
+
if color is None and cmd_detail[0] == "less":
|
| 420 |
+
less_flags = f"{os.environ.get('LESS', '')}{' '.join(cmd_detail[1:])}"
|
| 421 |
+
if not less_flags:
|
| 422 |
+
env["LESS"] = "-R"
|
| 423 |
+
color = True
|
| 424 |
+
elif "r" in less_flags or "R" in less_flags:
|
| 425 |
+
color = True
|
| 426 |
+
|
| 427 |
+
c = subprocess.Popen(
|
| 428 |
+
[cmd_absolute],
|
| 429 |
+
shell=True,
|
| 430 |
+
stdin=subprocess.PIPE,
|
| 431 |
+
env=env,
|
| 432 |
+
errors="replace",
|
| 433 |
+
text=True,
|
| 434 |
+
)
|
| 435 |
+
assert c.stdin is not None
|
| 436 |
+
try:
|
| 437 |
+
for text in generator:
|
| 438 |
+
if not color:
|
| 439 |
+
text = strip_ansi(text)
|
| 440 |
+
|
| 441 |
+
c.stdin.write(text)
|
| 442 |
+
except (OSError, KeyboardInterrupt):
|
| 443 |
+
pass
|
| 444 |
+
else:
|
| 445 |
+
c.stdin.close()
|
| 446 |
+
|
| 447 |
+
# Less doesn't respect ^C, but catches it for its own UI purposes (aborting
|
| 448 |
+
# search or other commands inside less).
|
| 449 |
+
#
|
| 450 |
+
# That means when the user hits ^C, the parent process (click) terminates,
|
| 451 |
+
# but less is still alive, paging the output and messing up the terminal.
|
| 452 |
+
#
|
| 453 |
+
# If the user wants to make the pager exit on ^C, they should set
|
| 454 |
+
# `LESS='-K'`. It's not our decision to make.
|
| 455 |
+
while True:
|
| 456 |
+
try:
|
| 457 |
+
c.wait()
|
| 458 |
+
except KeyboardInterrupt:
|
| 459 |
+
pass
|
| 460 |
+
else:
|
| 461 |
+
break
|
| 462 |
+
|
| 463 |
+
return True
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def _tempfilepager(
|
| 467 |
+
generator: t.Iterable[str],
|
| 468 |
+
cmd: str,
|
| 469 |
+
color: t.Optional[bool],
|
| 470 |
+
) -> bool:
|
| 471 |
+
"""Page through text by invoking a program on a temporary file.
|
| 472 |
+
|
| 473 |
+
Returns True if the command was found, False otherwise and thus another
|
| 474 |
+
pager should be attempted.
|
| 475 |
+
"""
|
| 476 |
+
# Which is necessary for Windows, it is also recommended in the Popen docs.
|
| 477 |
+
cmd_absolute = which(cmd)
|
| 478 |
+
if cmd_absolute is None:
|
| 479 |
+
return False
|
| 480 |
+
|
| 481 |
+
import subprocess
|
| 482 |
+
import tempfile
|
| 483 |
+
|
| 484 |
+
fd, filename = tempfile.mkstemp()
|
| 485 |
+
# TODO: This never terminates if the passed generator never terminates.
|
| 486 |
+
text = "".join(generator)
|
| 487 |
+
if not color:
|
| 488 |
+
text = strip_ansi(text)
|
| 489 |
+
encoding = get_best_encoding(sys.stdout)
|
| 490 |
+
with open_stream(filename, "wb")[0] as f:
|
| 491 |
+
f.write(text.encode(encoding))
|
| 492 |
+
try:
|
| 493 |
+
subprocess.call([cmd_absolute, filename])
|
| 494 |
+
except OSError:
|
| 495 |
+
# Command not found
|
| 496 |
+
pass
|
| 497 |
+
finally:
|
| 498 |
+
os.close(fd)
|
| 499 |
+
os.unlink(filename)
|
| 500 |
+
|
| 501 |
+
return True
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def _nullpager(
|
| 505 |
+
stream: t.TextIO, generator: t.Iterable[str], color: t.Optional[bool]
|
| 506 |
+
) -> None:
|
| 507 |
+
"""Simply print unformatted text. This is the ultimate fallback."""
|
| 508 |
+
for text in generator:
|
| 509 |
+
if not color:
|
| 510 |
+
text = strip_ansi(text)
|
| 511 |
+
stream.write(text)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
class Editor:
|
| 515 |
+
def __init__(
|
| 516 |
+
self,
|
| 517 |
+
editor: t.Optional[str] = None,
|
| 518 |
+
env: t.Optional[t.Mapping[str, str]] = None,
|
| 519 |
+
require_save: bool = True,
|
| 520 |
+
extension: str = ".txt",
|
| 521 |
+
) -> None:
|
| 522 |
+
self.editor = editor
|
| 523 |
+
self.env = env
|
| 524 |
+
self.require_save = require_save
|
| 525 |
+
self.extension = extension
|
| 526 |
+
|
| 527 |
+
def get_editor(self) -> str:
|
| 528 |
+
if self.editor is not None:
|
| 529 |
+
return self.editor
|
| 530 |
+
for key in "VISUAL", "EDITOR":
|
| 531 |
+
rv = os.environ.get(key)
|
| 532 |
+
if rv:
|
| 533 |
+
return rv
|
| 534 |
+
if WIN:
|
| 535 |
+
return "notepad"
|
| 536 |
+
for editor in "sensible-editor", "vim", "nano":
|
| 537 |
+
if which(editor) is not None:
|
| 538 |
+
return editor
|
| 539 |
+
return "vi"
|
| 540 |
+
|
| 541 |
+
def edit_file(self, filename: str) -> None:
|
| 542 |
+
import subprocess
|
| 543 |
+
|
| 544 |
+
editor = self.get_editor()
|
| 545 |
+
environ: t.Optional[t.Dict[str, str]] = None
|
| 546 |
+
|
| 547 |
+
if self.env:
|
| 548 |
+
environ = os.environ.copy()
|
| 549 |
+
environ.update(self.env)
|
| 550 |
+
|
| 551 |
+
try:
|
| 552 |
+
c = subprocess.Popen(f'{editor} "{filename}"', env=environ, shell=True)
|
| 553 |
+
exit_code = c.wait()
|
| 554 |
+
if exit_code != 0:
|
| 555 |
+
raise ClickException(
|
| 556 |
+
_("{editor}: Editing failed").format(editor=editor)
|
| 557 |
+
)
|
| 558 |
+
except OSError as e:
|
| 559 |
+
raise ClickException(
|
| 560 |
+
_("{editor}: Editing failed: {e}").format(editor=editor, e=e)
|
| 561 |
+
) from e
|
| 562 |
+
|
| 563 |
+
def edit(self, text: t.Optional[t.AnyStr]) -> t.Optional[t.AnyStr]:
|
| 564 |
+
import tempfile
|
| 565 |
+
|
| 566 |
+
if not text:
|
| 567 |
+
data = b""
|
| 568 |
+
elif isinstance(text, (bytes, bytearray)):
|
| 569 |
+
data = text
|
| 570 |
+
else:
|
| 571 |
+
if text and not text.endswith("\n"):
|
| 572 |
+
text += "\n"
|
| 573 |
+
|
| 574 |
+
if WIN:
|
| 575 |
+
data = text.replace("\n", "\r\n").encode("utf-8-sig")
|
| 576 |
+
else:
|
| 577 |
+
data = text.encode("utf-8")
|
| 578 |
+
|
| 579 |
+
fd, name = tempfile.mkstemp(prefix="editor-", suffix=self.extension)
|
| 580 |
+
f: t.BinaryIO
|
| 581 |
+
|
| 582 |
+
try:
|
| 583 |
+
with os.fdopen(fd, "wb") as f:
|
| 584 |
+
f.write(data)
|
| 585 |
+
|
| 586 |
+
# If the filesystem resolution is 1 second, like Mac OS
|
| 587 |
+
# 10.12 Extended, or 2 seconds, like FAT32, and the editor
|
| 588 |
+
# closes very fast, require_save can fail. Set the modified
|
| 589 |
+
# time to be 2 seconds in the past to work around this.
|
| 590 |
+
os.utime(name, (os.path.getatime(name), os.path.getmtime(name) - 2))
|
| 591 |
+
# Depending on the resolution, the exact value might not be
|
| 592 |
+
# recorded, so get the new recorded value.
|
| 593 |
+
timestamp = os.path.getmtime(name)
|
| 594 |
+
|
| 595 |
+
self.edit_file(name)
|
| 596 |
+
|
| 597 |
+
if self.require_save and os.path.getmtime(name) == timestamp:
|
| 598 |
+
return None
|
| 599 |
+
|
| 600 |
+
with open(name, "rb") as f:
|
| 601 |
+
rv = f.read()
|
| 602 |
+
|
| 603 |
+
if isinstance(text, (bytes, bytearray)):
|
| 604 |
+
return rv
|
| 605 |
+
|
| 606 |
+
return rv.decode("utf-8-sig").replace("\r\n", "\n") # type: ignore
|
| 607 |
+
finally:
|
| 608 |
+
os.unlink(name)
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def open_url(url: str, wait: bool = False, locate: bool = False) -> int:
|
| 612 |
+
import subprocess
|
| 613 |
+
|
| 614 |
+
def _unquote_file(url: str) -> str:
|
| 615 |
+
from urllib.parse import unquote
|
| 616 |
+
|
| 617 |
+
if url.startswith("file://"):
|
| 618 |
+
url = unquote(url[7:])
|
| 619 |
+
|
| 620 |
+
return url
|
| 621 |
+
|
| 622 |
+
if sys.platform == "darwin":
|
| 623 |
+
args = ["open"]
|
| 624 |
+
if wait:
|
| 625 |
+
args.append("-W")
|
| 626 |
+
if locate:
|
| 627 |
+
args.append("-R")
|
| 628 |
+
args.append(_unquote_file(url))
|
| 629 |
+
null = open("/dev/null", "w")
|
| 630 |
+
try:
|
| 631 |
+
return subprocess.Popen(args, stderr=null).wait()
|
| 632 |
+
finally:
|
| 633 |
+
null.close()
|
| 634 |
+
elif WIN:
|
| 635 |
+
if locate:
|
| 636 |
+
url = _unquote_file(url)
|
| 637 |
+
args = ["explorer", f"/select,{url}"]
|
| 638 |
+
else:
|
| 639 |
+
args = ["start"]
|
| 640 |
+
if wait:
|
| 641 |
+
args.append("/WAIT")
|
| 642 |
+
args.append("")
|
| 643 |
+
args.append(url)
|
| 644 |
+
try:
|
| 645 |
+
return subprocess.call(args)
|
| 646 |
+
except OSError:
|
| 647 |
+
# Command not found
|
| 648 |
+
return 127
|
| 649 |
+
elif CYGWIN:
|
| 650 |
+
if locate:
|
| 651 |
+
url = _unquote_file(url)
|
| 652 |
+
args = ["cygstart", os.path.dirname(url)]
|
| 653 |
+
else:
|
| 654 |
+
args = ["cygstart"]
|
| 655 |
+
if wait:
|
| 656 |
+
args.append("-w")
|
| 657 |
+
args.append(url)
|
| 658 |
+
try:
|
| 659 |
+
return subprocess.call(args)
|
| 660 |
+
except OSError:
|
| 661 |
+
# Command not found
|
| 662 |
+
return 127
|
| 663 |
+
|
| 664 |
+
try:
|
| 665 |
+
if locate:
|
| 666 |
+
url = os.path.dirname(_unquote_file(url)) or "."
|
| 667 |
+
else:
|
| 668 |
+
url = _unquote_file(url)
|
| 669 |
+
c = subprocess.Popen(["xdg-open", url])
|
| 670 |
+
if wait:
|
| 671 |
+
return c.wait()
|
| 672 |
+
return 0
|
| 673 |
+
except OSError:
|
| 674 |
+
if url.startswith(("http://", "https://")) and not locate and not wait:
|
| 675 |
+
import webbrowser
|
| 676 |
+
|
| 677 |
+
webbrowser.open(url)
|
| 678 |
+
return 0
|
| 679 |
+
return 1
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
def _translate_ch_to_exc(ch: str) -> t.Optional[BaseException]:
|
| 683 |
+
if ch == "\x03":
|
| 684 |
+
raise KeyboardInterrupt()
|
| 685 |
+
|
| 686 |
+
if ch == "\x04" and not WIN: # Unix-like, Ctrl+D
|
| 687 |
+
raise EOFError()
|
| 688 |
+
|
| 689 |
+
if ch == "\x1a" and WIN: # Windows, Ctrl+Z
|
| 690 |
+
raise EOFError()
|
| 691 |
+
|
| 692 |
+
return None
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
if WIN:
|
| 696 |
+
import msvcrt
|
| 697 |
+
|
| 698 |
+
@contextlib.contextmanager
|
| 699 |
+
def raw_terminal() -> t.Iterator[int]:
|
| 700 |
+
yield -1
|
| 701 |
+
|
| 702 |
+
def getchar(echo: bool) -> str:
|
| 703 |
+
# The function `getch` will return a bytes object corresponding to
|
| 704 |
+
# the pressed character. Since Windows 10 build 1803, it will also
|
| 705 |
+
# return \x00 when called a second time after pressing a regular key.
|
| 706 |
+
#
|
| 707 |
+
# `getwch` does not share this probably-bugged behavior. Moreover, it
|
| 708 |
+
# returns a Unicode object by default, which is what we want.
|
| 709 |
+
#
|
| 710 |
+
# Either of these functions will return \x00 or \xe0 to indicate
|
| 711 |
+
# a special key, and you need to call the same function again to get
|
| 712 |
+
# the "rest" of the code. The fun part is that \u00e0 is
|
| 713 |
+
# "latin small letter a with grave", so if you type that on a French
|
| 714 |
+
# keyboard, you _also_ get a \xe0.
|
| 715 |
+
# E.g., consider the Up arrow. This returns \xe0 and then \x48. The
|
| 716 |
+
# resulting Unicode string reads as "a with grave" + "capital H".
|
| 717 |
+
# This is indistinguishable from when the user actually types
|
| 718 |
+
# "a with grave" and then "capital H".
|
| 719 |
+
#
|
| 720 |
+
# When \xe0 is returned, we assume it's part of a special-key sequence
|
| 721 |
+
# and call `getwch` again, but that means that when the user types
|
| 722 |
+
# the \u00e0 character, `getchar` doesn't return until a second
|
| 723 |
+
# character is typed.
|
| 724 |
+
# The alternative is returning immediately, but that would mess up
|
| 725 |
+
# cross-platform handling of arrow keys and others that start with
|
| 726 |
+
# \xe0. Another option is using `getch`, but then we can't reliably
|
| 727 |
+
# read non-ASCII characters, because return values of `getch` are
|
| 728 |
+
# limited to the current 8-bit codepage.
|
| 729 |
+
#
|
| 730 |
+
# Anyway, Click doesn't claim to do this Right(tm), and using `getwch`
|
| 731 |
+
# is doing the right thing in more situations than with `getch`.
|
| 732 |
+
func: t.Callable[[], str]
|
| 733 |
+
|
| 734 |
+
if echo:
|
| 735 |
+
func = msvcrt.getwche # type: ignore
|
| 736 |
+
else:
|
| 737 |
+
func = msvcrt.getwch # type: ignore
|
| 738 |
+
|
| 739 |
+
rv = func()
|
| 740 |
+
|
| 741 |
+
if rv in ("\x00", "\xe0"):
|
| 742 |
+
# \x00 and \xe0 are control characters that indicate special key,
|
| 743 |
+
# see above.
|
| 744 |
+
rv += func()
|
| 745 |
+
|
| 746 |
+
_translate_ch_to_exc(rv)
|
| 747 |
+
return rv
|
| 748 |
+
|
| 749 |
+
else:
|
| 750 |
+
import termios
|
| 751 |
+
import tty
|
| 752 |
+
|
| 753 |
+
@contextlib.contextmanager
|
| 754 |
+
def raw_terminal() -> t.Iterator[int]:
|
| 755 |
+
f: t.Optional[t.TextIO]
|
| 756 |
+
fd: int
|
| 757 |
+
|
| 758 |
+
if not isatty(sys.stdin):
|
| 759 |
+
f = open("/dev/tty")
|
| 760 |
+
fd = f.fileno()
|
| 761 |
+
else:
|
| 762 |
+
fd = sys.stdin.fileno()
|
| 763 |
+
f = None
|
| 764 |
+
|
| 765 |
+
try:
|
| 766 |
+
old_settings = termios.tcgetattr(fd)
|
| 767 |
+
|
| 768 |
+
try:
|
| 769 |
+
tty.setraw(fd)
|
| 770 |
+
yield fd
|
| 771 |
+
finally:
|
| 772 |
+
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
|
| 773 |
+
sys.stdout.flush()
|
| 774 |
+
|
| 775 |
+
if f is not None:
|
| 776 |
+
f.close()
|
| 777 |
+
except termios.error:
|
| 778 |
+
pass
|
| 779 |
+
|
| 780 |
+
def getchar(echo: bool) -> str:
|
| 781 |
+
with raw_terminal() as fd:
|
| 782 |
+
ch = os.read(fd, 32).decode(get_best_encoding(sys.stdin), "replace")
|
| 783 |
+
|
| 784 |
+
if echo and isatty(sys.stdout):
|
| 785 |
+
sys.stdout.write(ch)
|
| 786 |
+
|
| 787 |
+
_translate_ch_to_exc(ch)
|
| 788 |
+
return ch
|
.venv/lib/python3.11/site-packages/click/decorators.py
ADDED
|
@@ -0,0 +1,562 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import types
|
| 3 |
+
import typing as t
|
| 4 |
+
from functools import update_wrapper
|
| 5 |
+
from gettext import gettext as _
|
| 6 |
+
|
| 7 |
+
from .core import Argument
|
| 8 |
+
from .core import Command
|
| 9 |
+
from .core import Context
|
| 10 |
+
from .core import Group
|
| 11 |
+
from .core import Option
|
| 12 |
+
from .core import Parameter
|
| 13 |
+
from .globals import get_current_context
|
| 14 |
+
from .utils import echo
|
| 15 |
+
|
| 16 |
+
if t.TYPE_CHECKING:
|
| 17 |
+
import typing_extensions as te
|
| 18 |
+
|
| 19 |
+
P = te.ParamSpec("P")
|
| 20 |
+
|
| 21 |
+
R = t.TypeVar("R")
|
| 22 |
+
T = t.TypeVar("T")
|
| 23 |
+
_AnyCallable = t.Callable[..., t.Any]
|
| 24 |
+
FC = t.TypeVar("FC", bound=t.Union[_AnyCallable, Command])
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def pass_context(f: "t.Callable[te.Concatenate[Context, P], R]") -> "t.Callable[P, R]":
|
| 28 |
+
"""Marks a callback as wanting to receive the current context
|
| 29 |
+
object as first argument.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def new_func(*args: "P.args", **kwargs: "P.kwargs") -> "R":
|
| 33 |
+
return f(get_current_context(), *args, **kwargs)
|
| 34 |
+
|
| 35 |
+
return update_wrapper(new_func, f)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def pass_obj(f: "t.Callable[te.Concatenate[t.Any, P], R]") -> "t.Callable[P, R]":
|
| 39 |
+
"""Similar to :func:`pass_context`, but only pass the object on the
|
| 40 |
+
context onwards (:attr:`Context.obj`). This is useful if that object
|
| 41 |
+
represents the state of a nested system.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def new_func(*args: "P.args", **kwargs: "P.kwargs") -> "R":
|
| 45 |
+
return f(get_current_context().obj, *args, **kwargs)
|
| 46 |
+
|
| 47 |
+
return update_wrapper(new_func, f)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def make_pass_decorator(
|
| 51 |
+
object_type: t.Type[T], ensure: bool = False
|
| 52 |
+
) -> t.Callable[["t.Callable[te.Concatenate[T, P], R]"], "t.Callable[P, R]"]:
|
| 53 |
+
"""Given an object type this creates a decorator that will work
|
| 54 |
+
similar to :func:`pass_obj` but instead of passing the object of the
|
| 55 |
+
current context, it will find the innermost context of type
|
| 56 |
+
:func:`object_type`.
|
| 57 |
+
|
| 58 |
+
This generates a decorator that works roughly like this::
|
| 59 |
+
|
| 60 |
+
from functools import update_wrapper
|
| 61 |
+
|
| 62 |
+
def decorator(f):
|
| 63 |
+
@pass_context
|
| 64 |
+
def new_func(ctx, *args, **kwargs):
|
| 65 |
+
obj = ctx.find_object(object_type)
|
| 66 |
+
return ctx.invoke(f, obj, *args, **kwargs)
|
| 67 |
+
return update_wrapper(new_func, f)
|
| 68 |
+
return decorator
|
| 69 |
+
|
| 70 |
+
:param object_type: the type of the object to pass.
|
| 71 |
+
:param ensure: if set to `True`, a new object will be created and
|
| 72 |
+
remembered on the context if it's not there yet.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def decorator(f: "t.Callable[te.Concatenate[T, P], R]") -> "t.Callable[P, R]":
|
| 76 |
+
def new_func(*args: "P.args", **kwargs: "P.kwargs") -> "R":
|
| 77 |
+
ctx = get_current_context()
|
| 78 |
+
|
| 79 |
+
obj: t.Optional[T]
|
| 80 |
+
if ensure:
|
| 81 |
+
obj = ctx.ensure_object(object_type)
|
| 82 |
+
else:
|
| 83 |
+
obj = ctx.find_object(object_type)
|
| 84 |
+
|
| 85 |
+
if obj is None:
|
| 86 |
+
raise RuntimeError(
|
| 87 |
+
"Managed to invoke callback without a context"
|
| 88 |
+
f" object of type {object_type.__name__!r}"
|
| 89 |
+
" existing."
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
return ctx.invoke(f, obj, *args, **kwargs)
|
| 93 |
+
|
| 94 |
+
return update_wrapper(new_func, f)
|
| 95 |
+
|
| 96 |
+
return decorator
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def pass_meta_key(
|
| 100 |
+
key: str, *, doc_description: t.Optional[str] = None
|
| 101 |
+
) -> "t.Callable[[t.Callable[te.Concatenate[t.Any, P], R]], t.Callable[P, R]]":
|
| 102 |
+
"""Create a decorator that passes a key from
|
| 103 |
+
:attr:`click.Context.meta` as the first argument to the decorated
|
| 104 |
+
function.
|
| 105 |
+
|
| 106 |
+
:param key: Key in ``Context.meta`` to pass.
|
| 107 |
+
:param doc_description: Description of the object being passed,
|
| 108 |
+
inserted into the decorator's docstring. Defaults to "the 'key'
|
| 109 |
+
key from Context.meta".
|
| 110 |
+
|
| 111 |
+
.. versionadded:: 8.0
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
def decorator(f: "t.Callable[te.Concatenate[t.Any, P], R]") -> "t.Callable[P, R]":
|
| 115 |
+
def new_func(*args: "P.args", **kwargs: "P.kwargs") -> R:
|
| 116 |
+
ctx = get_current_context()
|
| 117 |
+
obj = ctx.meta[key]
|
| 118 |
+
return ctx.invoke(f, obj, *args, **kwargs)
|
| 119 |
+
|
| 120 |
+
return update_wrapper(new_func, f)
|
| 121 |
+
|
| 122 |
+
if doc_description is None:
|
| 123 |
+
doc_description = f"the {key!r} key from :attr:`click.Context.meta`"
|
| 124 |
+
|
| 125 |
+
decorator.__doc__ = (
|
| 126 |
+
f"Decorator that passes {doc_description} as the first argument"
|
| 127 |
+
" to the decorated function."
|
| 128 |
+
)
|
| 129 |
+
return decorator
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
CmdType = t.TypeVar("CmdType", bound=Command)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# variant: no call, directly as decorator for a function.
|
| 136 |
+
@t.overload
|
| 137 |
+
def command(name: _AnyCallable) -> Command: ...
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# variant: with positional name and with positional or keyword cls argument:
|
| 141 |
+
# @command(namearg, CommandCls, ...) or @command(namearg, cls=CommandCls, ...)
|
| 142 |
+
@t.overload
|
| 143 |
+
def command(
|
| 144 |
+
name: t.Optional[str],
|
| 145 |
+
cls: t.Type[CmdType],
|
| 146 |
+
**attrs: t.Any,
|
| 147 |
+
) -> t.Callable[[_AnyCallable], CmdType]: ...
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# variant: name omitted, cls _must_ be a keyword argument, @command(cls=CommandCls, ...)
|
| 151 |
+
@t.overload
|
| 152 |
+
def command(
|
| 153 |
+
name: None = None,
|
| 154 |
+
*,
|
| 155 |
+
cls: t.Type[CmdType],
|
| 156 |
+
**attrs: t.Any,
|
| 157 |
+
) -> t.Callable[[_AnyCallable], CmdType]: ...
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# variant: with optional string name, no cls argument provided.
|
| 161 |
+
@t.overload
|
| 162 |
+
def command(
|
| 163 |
+
name: t.Optional[str] = ..., cls: None = None, **attrs: t.Any
|
| 164 |
+
) -> t.Callable[[_AnyCallable], Command]: ...
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def command(
|
| 168 |
+
name: t.Union[t.Optional[str], _AnyCallable] = None,
|
| 169 |
+
cls: t.Optional[t.Type[CmdType]] = None,
|
| 170 |
+
**attrs: t.Any,
|
| 171 |
+
) -> t.Union[Command, t.Callable[[_AnyCallable], t.Union[Command, CmdType]]]:
|
| 172 |
+
r"""Creates a new :class:`Command` and uses the decorated function as
|
| 173 |
+
callback. This will also automatically attach all decorated
|
| 174 |
+
:func:`option`\s and :func:`argument`\s as parameters to the command.
|
| 175 |
+
|
| 176 |
+
The name of the command defaults to the name of the function with
|
| 177 |
+
underscores replaced by dashes. If you want to change that, you can
|
| 178 |
+
pass the intended name as the first argument.
|
| 179 |
+
|
| 180 |
+
All keyword arguments are forwarded to the underlying command class.
|
| 181 |
+
For the ``params`` argument, any decorated params are appended to
|
| 182 |
+
the end of the list.
|
| 183 |
+
|
| 184 |
+
Once decorated the function turns into a :class:`Command` instance
|
| 185 |
+
that can be invoked as a command line utility or be attached to a
|
| 186 |
+
command :class:`Group`.
|
| 187 |
+
|
| 188 |
+
:param name: the name of the command. This defaults to the function
|
| 189 |
+
name with underscores replaced by dashes.
|
| 190 |
+
:param cls: the command class to instantiate. This defaults to
|
| 191 |
+
:class:`Command`.
|
| 192 |
+
|
| 193 |
+
.. versionchanged:: 8.1
|
| 194 |
+
This decorator can be applied without parentheses.
|
| 195 |
+
|
| 196 |
+
.. versionchanged:: 8.1
|
| 197 |
+
The ``params`` argument can be used. Decorated params are
|
| 198 |
+
appended to the end of the list.
|
| 199 |
+
"""
|
| 200 |
+
|
| 201 |
+
func: t.Optional[t.Callable[[_AnyCallable], t.Any]] = None
|
| 202 |
+
|
| 203 |
+
if callable(name):
|
| 204 |
+
func = name
|
| 205 |
+
name = None
|
| 206 |
+
assert cls is None, "Use 'command(cls=cls)(callable)' to specify a class."
|
| 207 |
+
assert not attrs, "Use 'command(**kwargs)(callable)' to provide arguments."
|
| 208 |
+
|
| 209 |
+
if cls is None:
|
| 210 |
+
cls = t.cast(t.Type[CmdType], Command)
|
| 211 |
+
|
| 212 |
+
def decorator(f: _AnyCallable) -> CmdType:
|
| 213 |
+
if isinstance(f, Command):
|
| 214 |
+
raise TypeError("Attempted to convert a callback into a command twice.")
|
| 215 |
+
|
| 216 |
+
attr_params = attrs.pop("params", None)
|
| 217 |
+
params = attr_params if attr_params is not None else []
|
| 218 |
+
|
| 219 |
+
try:
|
| 220 |
+
decorator_params = f.__click_params__ # type: ignore
|
| 221 |
+
except AttributeError:
|
| 222 |
+
pass
|
| 223 |
+
else:
|
| 224 |
+
del f.__click_params__ # type: ignore
|
| 225 |
+
params.extend(reversed(decorator_params))
|
| 226 |
+
|
| 227 |
+
if attrs.get("help") is None:
|
| 228 |
+
attrs["help"] = f.__doc__
|
| 229 |
+
|
| 230 |
+
if t.TYPE_CHECKING:
|
| 231 |
+
assert cls is not None
|
| 232 |
+
assert not callable(name)
|
| 233 |
+
|
| 234 |
+
cmd = cls(
|
| 235 |
+
name=name or f.__name__.lower().replace("_", "-"),
|
| 236 |
+
callback=f,
|
| 237 |
+
params=params,
|
| 238 |
+
**attrs,
|
| 239 |
+
)
|
| 240 |
+
cmd.__doc__ = f.__doc__
|
| 241 |
+
return cmd
|
| 242 |
+
|
| 243 |
+
if func is not None:
|
| 244 |
+
return decorator(func)
|
| 245 |
+
|
| 246 |
+
return decorator
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
GrpType = t.TypeVar("GrpType", bound=Group)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# variant: no call, directly as decorator for a function.
|
| 253 |
+
@t.overload
|
| 254 |
+
def group(name: _AnyCallable) -> Group: ...
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# variant: with positional name and with positional or keyword cls argument:
|
| 258 |
+
# @group(namearg, GroupCls, ...) or @group(namearg, cls=GroupCls, ...)
|
| 259 |
+
@t.overload
|
| 260 |
+
def group(
|
| 261 |
+
name: t.Optional[str],
|
| 262 |
+
cls: t.Type[GrpType],
|
| 263 |
+
**attrs: t.Any,
|
| 264 |
+
) -> t.Callable[[_AnyCallable], GrpType]: ...
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# variant: name omitted, cls _must_ be a keyword argument, @group(cmd=GroupCls, ...)
|
| 268 |
+
@t.overload
|
| 269 |
+
def group(
|
| 270 |
+
name: None = None,
|
| 271 |
+
*,
|
| 272 |
+
cls: t.Type[GrpType],
|
| 273 |
+
**attrs: t.Any,
|
| 274 |
+
) -> t.Callable[[_AnyCallable], GrpType]: ...
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
# variant: with optional string name, no cls argument provided.
|
| 278 |
+
@t.overload
|
| 279 |
+
def group(
|
| 280 |
+
name: t.Optional[str] = ..., cls: None = None, **attrs: t.Any
|
| 281 |
+
) -> t.Callable[[_AnyCallable], Group]: ...
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def group(
|
| 285 |
+
name: t.Union[str, _AnyCallable, None] = None,
|
| 286 |
+
cls: t.Optional[t.Type[GrpType]] = None,
|
| 287 |
+
**attrs: t.Any,
|
| 288 |
+
) -> t.Union[Group, t.Callable[[_AnyCallable], t.Union[Group, GrpType]]]:
|
| 289 |
+
"""Creates a new :class:`Group` with a function as callback. This
|
| 290 |
+
works otherwise the same as :func:`command` just that the `cls`
|
| 291 |
+
parameter is set to :class:`Group`.
|
| 292 |
+
|
| 293 |
+
.. versionchanged:: 8.1
|
| 294 |
+
This decorator can be applied without parentheses.
|
| 295 |
+
"""
|
| 296 |
+
if cls is None:
|
| 297 |
+
cls = t.cast(t.Type[GrpType], Group)
|
| 298 |
+
|
| 299 |
+
if callable(name):
|
| 300 |
+
return command(cls=cls, **attrs)(name)
|
| 301 |
+
|
| 302 |
+
return command(name, cls, **attrs)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def _param_memo(f: t.Callable[..., t.Any], param: Parameter) -> None:
|
| 306 |
+
if isinstance(f, Command):
|
| 307 |
+
f.params.append(param)
|
| 308 |
+
else:
|
| 309 |
+
if not hasattr(f, "__click_params__"):
|
| 310 |
+
f.__click_params__ = [] # type: ignore
|
| 311 |
+
|
| 312 |
+
f.__click_params__.append(param) # type: ignore
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def argument(
|
| 316 |
+
*param_decls: str, cls: t.Optional[t.Type[Argument]] = None, **attrs: t.Any
|
| 317 |
+
) -> t.Callable[[FC], FC]:
|
| 318 |
+
"""Attaches an argument to the command. All positional arguments are
|
| 319 |
+
passed as parameter declarations to :class:`Argument`; all keyword
|
| 320 |
+
arguments are forwarded unchanged (except ``cls``).
|
| 321 |
+
This is equivalent to creating an :class:`Argument` instance manually
|
| 322 |
+
and attaching it to the :attr:`Command.params` list.
|
| 323 |
+
|
| 324 |
+
For the default argument class, refer to :class:`Argument` and
|
| 325 |
+
:class:`Parameter` for descriptions of parameters.
|
| 326 |
+
|
| 327 |
+
:param cls: the argument class to instantiate. This defaults to
|
| 328 |
+
:class:`Argument`.
|
| 329 |
+
:param param_decls: Passed as positional arguments to the constructor of
|
| 330 |
+
``cls``.
|
| 331 |
+
:param attrs: Passed as keyword arguments to the constructor of ``cls``.
|
| 332 |
+
"""
|
| 333 |
+
if cls is None:
|
| 334 |
+
cls = Argument
|
| 335 |
+
|
| 336 |
+
def decorator(f: FC) -> FC:
|
| 337 |
+
_param_memo(f, cls(param_decls, **attrs))
|
| 338 |
+
return f
|
| 339 |
+
|
| 340 |
+
return decorator
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def option(
|
| 344 |
+
*param_decls: str, cls: t.Optional[t.Type[Option]] = None, **attrs: t.Any
|
| 345 |
+
) -> t.Callable[[FC], FC]:
|
| 346 |
+
"""Attaches an option to the command. All positional arguments are
|
| 347 |
+
passed as parameter declarations to :class:`Option`; all keyword
|
| 348 |
+
arguments are forwarded unchanged (except ``cls``).
|
| 349 |
+
This is equivalent to creating an :class:`Option` instance manually
|
| 350 |
+
and attaching it to the :attr:`Command.params` list.
|
| 351 |
+
|
| 352 |
+
For the default option class, refer to :class:`Option` and
|
| 353 |
+
:class:`Parameter` for descriptions of parameters.
|
| 354 |
+
|
| 355 |
+
:param cls: the option class to instantiate. This defaults to
|
| 356 |
+
:class:`Option`.
|
| 357 |
+
:param param_decls: Passed as positional arguments to the constructor of
|
| 358 |
+
``cls``.
|
| 359 |
+
:param attrs: Passed as keyword arguments to the constructor of ``cls``.
|
| 360 |
+
"""
|
| 361 |
+
if cls is None:
|
| 362 |
+
cls = Option
|
| 363 |
+
|
| 364 |
+
def decorator(f: FC) -> FC:
|
| 365 |
+
_param_memo(f, cls(param_decls, **attrs))
|
| 366 |
+
return f
|
| 367 |
+
|
| 368 |
+
return decorator
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def confirmation_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
|
| 372 |
+
"""Add a ``--yes`` option which shows a prompt before continuing if
|
| 373 |
+
not passed. If the prompt is declined, the program will exit.
|
| 374 |
+
|
| 375 |
+
:param param_decls: One or more option names. Defaults to the single
|
| 376 |
+
value ``"--yes"``.
|
| 377 |
+
:param kwargs: Extra arguments are passed to :func:`option`.
|
| 378 |
+
"""
|
| 379 |
+
|
| 380 |
+
def callback(ctx: Context, param: Parameter, value: bool) -> None:
|
| 381 |
+
if not value:
|
| 382 |
+
ctx.abort()
|
| 383 |
+
|
| 384 |
+
if not param_decls:
|
| 385 |
+
param_decls = ("--yes",)
|
| 386 |
+
|
| 387 |
+
kwargs.setdefault("is_flag", True)
|
| 388 |
+
kwargs.setdefault("callback", callback)
|
| 389 |
+
kwargs.setdefault("expose_value", False)
|
| 390 |
+
kwargs.setdefault("prompt", "Do you want to continue?")
|
| 391 |
+
kwargs.setdefault("help", "Confirm the action without prompting.")
|
| 392 |
+
return option(*param_decls, **kwargs)
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def password_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
|
| 396 |
+
"""Add a ``--password`` option which prompts for a password, hiding
|
| 397 |
+
input and asking to enter the value again for confirmation.
|
| 398 |
+
|
| 399 |
+
:param param_decls: One or more option names. Defaults to the single
|
| 400 |
+
value ``"--password"``.
|
| 401 |
+
:param kwargs: Extra arguments are passed to :func:`option`.
|
| 402 |
+
"""
|
| 403 |
+
if not param_decls:
|
| 404 |
+
param_decls = ("--password",)
|
| 405 |
+
|
| 406 |
+
kwargs.setdefault("prompt", True)
|
| 407 |
+
kwargs.setdefault("confirmation_prompt", True)
|
| 408 |
+
kwargs.setdefault("hide_input", True)
|
| 409 |
+
return option(*param_decls, **kwargs)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def version_option(
|
| 413 |
+
version: t.Optional[str] = None,
|
| 414 |
+
*param_decls: str,
|
| 415 |
+
package_name: t.Optional[str] = None,
|
| 416 |
+
prog_name: t.Optional[str] = None,
|
| 417 |
+
message: t.Optional[str] = None,
|
| 418 |
+
**kwargs: t.Any,
|
| 419 |
+
) -> t.Callable[[FC], FC]:
|
| 420 |
+
"""Add a ``--version`` option which immediately prints the version
|
| 421 |
+
number and exits the program.
|
| 422 |
+
|
| 423 |
+
If ``version`` is not provided, Click will try to detect it using
|
| 424 |
+
:func:`importlib.metadata.version` to get the version for the
|
| 425 |
+
``package_name``. On Python < 3.8, the ``importlib_metadata``
|
| 426 |
+
backport must be installed.
|
| 427 |
+
|
| 428 |
+
If ``package_name`` is not provided, Click will try to detect it by
|
| 429 |
+
inspecting the stack frames. This will be used to detect the
|
| 430 |
+
version, so it must match the name of the installed package.
|
| 431 |
+
|
| 432 |
+
:param version: The version number to show. If not provided, Click
|
| 433 |
+
will try to detect it.
|
| 434 |
+
:param param_decls: One or more option names. Defaults to the single
|
| 435 |
+
value ``"--version"``.
|
| 436 |
+
:param package_name: The package name to detect the version from. If
|
| 437 |
+
not provided, Click will try to detect it.
|
| 438 |
+
:param prog_name: The name of the CLI to show in the message. If not
|
| 439 |
+
provided, it will be detected from the command.
|
| 440 |
+
:param message: The message to show. The values ``%(prog)s``,
|
| 441 |
+
``%(package)s``, and ``%(version)s`` are available. Defaults to
|
| 442 |
+
``"%(prog)s, version %(version)s"``.
|
| 443 |
+
:param kwargs: Extra arguments are passed to :func:`option`.
|
| 444 |
+
:raise RuntimeError: ``version`` could not be detected.
|
| 445 |
+
|
| 446 |
+
.. versionchanged:: 8.0
|
| 447 |
+
Add the ``package_name`` parameter, and the ``%(package)s``
|
| 448 |
+
value for messages.
|
| 449 |
+
|
| 450 |
+
.. versionchanged:: 8.0
|
| 451 |
+
Use :mod:`importlib.metadata` instead of ``pkg_resources``. The
|
| 452 |
+
version is detected based on the package name, not the entry
|
| 453 |
+
point name. The Python package name must match the installed
|
| 454 |
+
package name, or be passed with ``package_name=``.
|
| 455 |
+
"""
|
| 456 |
+
if message is None:
|
| 457 |
+
message = _("%(prog)s, version %(version)s")
|
| 458 |
+
|
| 459 |
+
if version is None and package_name is None:
|
| 460 |
+
frame = inspect.currentframe()
|
| 461 |
+
f_back = frame.f_back if frame is not None else None
|
| 462 |
+
f_globals = f_back.f_globals if f_back is not None else None
|
| 463 |
+
# break reference cycle
|
| 464 |
+
# https://docs.python.org/3/library/inspect.html#the-interpreter-stack
|
| 465 |
+
del frame
|
| 466 |
+
|
| 467 |
+
if f_globals is not None:
|
| 468 |
+
package_name = f_globals.get("__name__")
|
| 469 |
+
|
| 470 |
+
if package_name == "__main__":
|
| 471 |
+
package_name = f_globals.get("__package__")
|
| 472 |
+
|
| 473 |
+
if package_name:
|
| 474 |
+
package_name = package_name.partition(".")[0]
|
| 475 |
+
|
| 476 |
+
def callback(ctx: Context, param: Parameter, value: bool) -> None:
|
| 477 |
+
if not value or ctx.resilient_parsing:
|
| 478 |
+
return
|
| 479 |
+
|
| 480 |
+
nonlocal prog_name
|
| 481 |
+
nonlocal version
|
| 482 |
+
|
| 483 |
+
if prog_name is None:
|
| 484 |
+
prog_name = ctx.find_root().info_name
|
| 485 |
+
|
| 486 |
+
if version is None and package_name is not None:
|
| 487 |
+
metadata: t.Optional[types.ModuleType]
|
| 488 |
+
|
| 489 |
+
try:
|
| 490 |
+
from importlib import metadata
|
| 491 |
+
except ImportError:
|
| 492 |
+
# Python < 3.8
|
| 493 |
+
import importlib_metadata as metadata # type: ignore
|
| 494 |
+
|
| 495 |
+
try:
|
| 496 |
+
version = metadata.version(package_name) # type: ignore
|
| 497 |
+
except metadata.PackageNotFoundError: # type: ignore
|
| 498 |
+
raise RuntimeError(
|
| 499 |
+
f"{package_name!r} is not installed. Try passing"
|
| 500 |
+
" 'package_name' instead."
|
| 501 |
+
) from None
|
| 502 |
+
|
| 503 |
+
if version is None:
|
| 504 |
+
raise RuntimeError(
|
| 505 |
+
f"Could not determine the version for {package_name!r} automatically."
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
echo(
|
| 509 |
+
message % {"prog": prog_name, "package": package_name, "version": version},
|
| 510 |
+
color=ctx.color,
|
| 511 |
+
)
|
| 512 |
+
ctx.exit()
|
| 513 |
+
|
| 514 |
+
if not param_decls:
|
| 515 |
+
param_decls = ("--version",)
|
| 516 |
+
|
| 517 |
+
kwargs.setdefault("is_flag", True)
|
| 518 |
+
kwargs.setdefault("expose_value", False)
|
| 519 |
+
kwargs.setdefault("is_eager", True)
|
| 520 |
+
kwargs.setdefault("help", _("Show the version and exit."))
|
| 521 |
+
kwargs["callback"] = callback
|
| 522 |
+
return option(*param_decls, **kwargs)
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
class HelpOption(Option):
|
| 526 |
+
"""Pre-configured ``--help`` option which immediately prints the help page
|
| 527 |
+
and exits the program.
|
| 528 |
+
"""
|
| 529 |
+
|
| 530 |
+
def __init__(
|
| 531 |
+
self,
|
| 532 |
+
param_decls: t.Optional[t.Sequence[str]] = None,
|
| 533 |
+
**kwargs: t.Any,
|
| 534 |
+
) -> None:
|
| 535 |
+
if not param_decls:
|
| 536 |
+
param_decls = ("--help",)
|
| 537 |
+
|
| 538 |
+
kwargs.setdefault("is_flag", True)
|
| 539 |
+
kwargs.setdefault("expose_value", False)
|
| 540 |
+
kwargs.setdefault("is_eager", True)
|
| 541 |
+
kwargs.setdefault("help", _("Show this message and exit."))
|
| 542 |
+
kwargs.setdefault("callback", self.show_help)
|
| 543 |
+
|
| 544 |
+
super().__init__(param_decls, **kwargs)
|
| 545 |
+
|
| 546 |
+
@staticmethod
|
| 547 |
+
def show_help(ctx: Context, param: Parameter, value: bool) -> None:
|
| 548 |
+
"""Callback that print the help page on ``<stdout>`` and exits."""
|
| 549 |
+
if value and not ctx.resilient_parsing:
|
| 550 |
+
echo(ctx.get_help(), color=ctx.color)
|
| 551 |
+
ctx.exit()
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def help_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
|
| 555 |
+
"""Decorator for the pre-configured ``--help`` option defined above.
|
| 556 |
+
|
| 557 |
+
:param param_decls: One or more option names. Defaults to the single
|
| 558 |
+
value ``"--help"``.
|
| 559 |
+
:param kwargs: Extra arguments are passed to :func:`option`.
|
| 560 |
+
"""
|
| 561 |
+
kwargs.setdefault("cls", HelpOption)
|
| 562 |
+
return option(*param_decls, **kwargs)
|
.venv/lib/python3.11/site-packages/click/py.typed
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/click/shell_completion.py
ADDED
|
@@ -0,0 +1,603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import typing as t
|
| 4 |
+
from gettext import gettext as _
|
| 5 |
+
|
| 6 |
+
from .core import Argument
|
| 7 |
+
from .core import BaseCommand
|
| 8 |
+
from .core import Context
|
| 9 |
+
from .core import MultiCommand
|
| 10 |
+
from .core import Option
|
| 11 |
+
from .core import Parameter
|
| 12 |
+
from .core import ParameterSource
|
| 13 |
+
from .parser import split_arg_string
|
| 14 |
+
from .utils import echo
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def shell_complete(
|
| 18 |
+
cli: BaseCommand,
|
| 19 |
+
ctx_args: t.MutableMapping[str, t.Any],
|
| 20 |
+
prog_name: str,
|
| 21 |
+
complete_var: str,
|
| 22 |
+
instruction: str,
|
| 23 |
+
) -> int:
|
| 24 |
+
"""Perform shell completion for the given CLI program.
|
| 25 |
+
|
| 26 |
+
:param cli: Command being called.
|
| 27 |
+
:param ctx_args: Extra arguments to pass to
|
| 28 |
+
``cli.make_context``.
|
| 29 |
+
:param prog_name: Name of the executable in the shell.
|
| 30 |
+
:param complete_var: Name of the environment variable that holds
|
| 31 |
+
the completion instruction.
|
| 32 |
+
:param instruction: Value of ``complete_var`` with the completion
|
| 33 |
+
instruction and shell, in the form ``instruction_shell``.
|
| 34 |
+
:return: Status code to exit with.
|
| 35 |
+
"""
|
| 36 |
+
shell, _, instruction = instruction.partition("_")
|
| 37 |
+
comp_cls = get_completion_class(shell)
|
| 38 |
+
|
| 39 |
+
if comp_cls is None:
|
| 40 |
+
return 1
|
| 41 |
+
|
| 42 |
+
comp = comp_cls(cli, ctx_args, prog_name, complete_var)
|
| 43 |
+
|
| 44 |
+
if instruction == "source":
|
| 45 |
+
echo(comp.source())
|
| 46 |
+
return 0
|
| 47 |
+
|
| 48 |
+
if instruction == "complete":
|
| 49 |
+
echo(comp.complete())
|
| 50 |
+
return 0
|
| 51 |
+
|
| 52 |
+
return 1
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class CompletionItem:
|
| 56 |
+
"""Represents a completion value and metadata about the value. The
|
| 57 |
+
default metadata is ``type`` to indicate special shell handling,
|
| 58 |
+
and ``help`` if a shell supports showing a help string next to the
|
| 59 |
+
value.
|
| 60 |
+
|
| 61 |
+
Arbitrary parameters can be passed when creating the object, and
|
| 62 |
+
accessed using ``item.attr``. If an attribute wasn't passed,
|
| 63 |
+
accessing it returns ``None``.
|
| 64 |
+
|
| 65 |
+
:param value: The completion suggestion.
|
| 66 |
+
:param type: Tells the shell script to provide special completion
|
| 67 |
+
support for the type. Click uses ``"dir"`` and ``"file"``.
|
| 68 |
+
:param help: String shown next to the value if supported.
|
| 69 |
+
:param kwargs: Arbitrary metadata. The built-in implementations
|
| 70 |
+
don't use this, but custom type completions paired with custom
|
| 71 |
+
shell support could use it.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
__slots__ = ("value", "type", "help", "_info")
|
| 75 |
+
|
| 76 |
+
def __init__(
|
| 77 |
+
self,
|
| 78 |
+
value: t.Any,
|
| 79 |
+
type: str = "plain",
|
| 80 |
+
help: t.Optional[str] = None,
|
| 81 |
+
**kwargs: t.Any,
|
| 82 |
+
) -> None:
|
| 83 |
+
self.value: t.Any = value
|
| 84 |
+
self.type: str = type
|
| 85 |
+
self.help: t.Optional[str] = help
|
| 86 |
+
self._info = kwargs
|
| 87 |
+
|
| 88 |
+
def __getattr__(self, name: str) -> t.Any:
|
| 89 |
+
return self._info.get(name)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Only Bash >= 4.4 has the nosort option.
|
| 93 |
+
_SOURCE_BASH = """\
|
| 94 |
+
%(complete_func)s() {
|
| 95 |
+
local IFS=$'\\n'
|
| 96 |
+
local response
|
| 97 |
+
|
| 98 |
+
response=$(env COMP_WORDS="${COMP_WORDS[*]}" COMP_CWORD=$COMP_CWORD \
|
| 99 |
+
%(complete_var)s=bash_complete $1)
|
| 100 |
+
|
| 101 |
+
for completion in $response; do
|
| 102 |
+
IFS=',' read type value <<< "$completion"
|
| 103 |
+
|
| 104 |
+
if [[ $type == 'dir' ]]; then
|
| 105 |
+
COMPREPLY=()
|
| 106 |
+
compopt -o dirnames
|
| 107 |
+
elif [[ $type == 'file' ]]; then
|
| 108 |
+
COMPREPLY=()
|
| 109 |
+
compopt -o default
|
| 110 |
+
elif [[ $type == 'plain' ]]; then
|
| 111 |
+
COMPREPLY+=($value)
|
| 112 |
+
fi
|
| 113 |
+
done
|
| 114 |
+
|
| 115 |
+
return 0
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
%(complete_func)s_setup() {
|
| 119 |
+
complete -o nosort -F %(complete_func)s %(prog_name)s
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
%(complete_func)s_setup;
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
_SOURCE_ZSH = """\
|
| 126 |
+
#compdef %(prog_name)s
|
| 127 |
+
|
| 128 |
+
%(complete_func)s() {
|
| 129 |
+
local -a completions
|
| 130 |
+
local -a completions_with_descriptions
|
| 131 |
+
local -a response
|
| 132 |
+
(( ! $+commands[%(prog_name)s] )) && return 1
|
| 133 |
+
|
| 134 |
+
response=("${(@f)$(env COMP_WORDS="${words[*]}" COMP_CWORD=$((CURRENT-1)) \
|
| 135 |
+
%(complete_var)s=zsh_complete %(prog_name)s)}")
|
| 136 |
+
|
| 137 |
+
for type key descr in ${response}; do
|
| 138 |
+
if [[ "$type" == "plain" ]]; then
|
| 139 |
+
if [[ "$descr" == "_" ]]; then
|
| 140 |
+
completions+=("$key")
|
| 141 |
+
else
|
| 142 |
+
completions_with_descriptions+=("$key":"$descr")
|
| 143 |
+
fi
|
| 144 |
+
elif [[ "$type" == "dir" ]]; then
|
| 145 |
+
_path_files -/
|
| 146 |
+
elif [[ "$type" == "file" ]]; then
|
| 147 |
+
_path_files -f
|
| 148 |
+
fi
|
| 149 |
+
done
|
| 150 |
+
|
| 151 |
+
if [ -n "$completions_with_descriptions" ]; then
|
| 152 |
+
_describe -V unsorted completions_with_descriptions -U
|
| 153 |
+
fi
|
| 154 |
+
|
| 155 |
+
if [ -n "$completions" ]; then
|
| 156 |
+
compadd -U -V unsorted -a completions
|
| 157 |
+
fi
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
if [[ $zsh_eval_context[-1] == loadautofunc ]]; then
|
| 161 |
+
# autoload from fpath, call function directly
|
| 162 |
+
%(complete_func)s "$@"
|
| 163 |
+
else
|
| 164 |
+
# eval/source/. command, register function for later
|
| 165 |
+
compdef %(complete_func)s %(prog_name)s
|
| 166 |
+
fi
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
_SOURCE_FISH = """\
|
| 170 |
+
function %(complete_func)s;
|
| 171 |
+
set -l response (env %(complete_var)s=fish_complete COMP_WORDS=(commandline -cp) \
|
| 172 |
+
COMP_CWORD=(commandline -t) %(prog_name)s);
|
| 173 |
+
|
| 174 |
+
for completion in $response;
|
| 175 |
+
set -l metadata (string split "," $completion);
|
| 176 |
+
|
| 177 |
+
if test $metadata[1] = "dir";
|
| 178 |
+
__fish_complete_directories $metadata[2];
|
| 179 |
+
else if test $metadata[1] = "file";
|
| 180 |
+
__fish_complete_path $metadata[2];
|
| 181 |
+
else if test $metadata[1] = "plain";
|
| 182 |
+
echo $metadata[2];
|
| 183 |
+
end;
|
| 184 |
+
end;
|
| 185 |
+
end;
|
| 186 |
+
|
| 187 |
+
complete --no-files --command %(prog_name)s --arguments \
|
| 188 |
+
"(%(complete_func)s)";
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class ShellComplete:
|
| 193 |
+
"""Base class for providing shell completion support. A subclass for
|
| 194 |
+
a given shell will override attributes and methods to implement the
|
| 195 |
+
completion instructions (``source`` and ``complete``).
|
| 196 |
+
|
| 197 |
+
:param cli: Command being called.
|
| 198 |
+
:param prog_name: Name of the executable in the shell.
|
| 199 |
+
:param complete_var: Name of the environment variable that holds
|
| 200 |
+
the completion instruction.
|
| 201 |
+
|
| 202 |
+
.. versionadded:: 8.0
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
name: t.ClassVar[str]
|
| 206 |
+
"""Name to register the shell as with :func:`add_completion_class`.
|
| 207 |
+
This is used in completion instructions (``{name}_source`` and
|
| 208 |
+
``{name}_complete``).
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
source_template: t.ClassVar[str]
|
| 212 |
+
"""Completion script template formatted by :meth:`source`. This must
|
| 213 |
+
be provided by subclasses.
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
def __init__(
|
| 217 |
+
self,
|
| 218 |
+
cli: BaseCommand,
|
| 219 |
+
ctx_args: t.MutableMapping[str, t.Any],
|
| 220 |
+
prog_name: str,
|
| 221 |
+
complete_var: str,
|
| 222 |
+
) -> None:
|
| 223 |
+
self.cli = cli
|
| 224 |
+
self.ctx_args = ctx_args
|
| 225 |
+
self.prog_name = prog_name
|
| 226 |
+
self.complete_var = complete_var
|
| 227 |
+
|
| 228 |
+
@property
|
| 229 |
+
def func_name(self) -> str:
|
| 230 |
+
"""The name of the shell function defined by the completion
|
| 231 |
+
script.
|
| 232 |
+
"""
|
| 233 |
+
safe_name = re.sub(r"\W*", "", self.prog_name.replace("-", "_"), flags=re.ASCII)
|
| 234 |
+
return f"_{safe_name}_completion"
|
| 235 |
+
|
| 236 |
+
def source_vars(self) -> t.Dict[str, t.Any]:
|
| 237 |
+
"""Vars for formatting :attr:`source_template`.
|
| 238 |
+
|
| 239 |
+
By default this provides ``complete_func``, ``complete_var``,
|
| 240 |
+
and ``prog_name``.
|
| 241 |
+
"""
|
| 242 |
+
return {
|
| 243 |
+
"complete_func": self.func_name,
|
| 244 |
+
"complete_var": self.complete_var,
|
| 245 |
+
"prog_name": self.prog_name,
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
def source(self) -> str:
|
| 249 |
+
"""Produce the shell script that defines the completion
|
| 250 |
+
function. By default this ``%``-style formats
|
| 251 |
+
:attr:`source_template` with the dict returned by
|
| 252 |
+
:meth:`source_vars`.
|
| 253 |
+
"""
|
| 254 |
+
return self.source_template % self.source_vars()
|
| 255 |
+
|
| 256 |
+
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
|
| 257 |
+
"""Use the env vars defined by the shell script to return a
|
| 258 |
+
tuple of ``args, incomplete``. This must be implemented by
|
| 259 |
+
subclasses.
|
| 260 |
+
"""
|
| 261 |
+
raise NotImplementedError
|
| 262 |
+
|
| 263 |
+
def get_completions(
|
| 264 |
+
self, args: t.List[str], incomplete: str
|
| 265 |
+
) -> t.List[CompletionItem]:
|
| 266 |
+
"""Determine the context and last complete command or parameter
|
| 267 |
+
from the complete args. Call that object's ``shell_complete``
|
| 268 |
+
method to get the completions for the incomplete value.
|
| 269 |
+
|
| 270 |
+
:param args: List of complete args before the incomplete value.
|
| 271 |
+
:param incomplete: Value being completed. May be empty.
|
| 272 |
+
"""
|
| 273 |
+
ctx = _resolve_context(self.cli, self.ctx_args, self.prog_name, args)
|
| 274 |
+
obj, incomplete = _resolve_incomplete(ctx, args, incomplete)
|
| 275 |
+
return obj.shell_complete(ctx, incomplete)
|
| 276 |
+
|
| 277 |
+
def format_completion(self, item: CompletionItem) -> str:
|
| 278 |
+
"""Format a completion item into the form recognized by the
|
| 279 |
+
shell script. This must be implemented by subclasses.
|
| 280 |
+
|
| 281 |
+
:param item: Completion item to format.
|
| 282 |
+
"""
|
| 283 |
+
raise NotImplementedError
|
| 284 |
+
|
| 285 |
+
def complete(self) -> str:
|
| 286 |
+
"""Produce the completion data to send back to the shell.
|
| 287 |
+
|
| 288 |
+
By default this calls :meth:`get_completion_args`, gets the
|
| 289 |
+
completions, then calls :meth:`format_completion` for each
|
| 290 |
+
completion.
|
| 291 |
+
"""
|
| 292 |
+
args, incomplete = self.get_completion_args()
|
| 293 |
+
completions = self.get_completions(args, incomplete)
|
| 294 |
+
out = [self.format_completion(item) for item in completions]
|
| 295 |
+
return "\n".join(out)
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class BashComplete(ShellComplete):
|
| 299 |
+
"""Shell completion for Bash."""
|
| 300 |
+
|
| 301 |
+
name = "bash"
|
| 302 |
+
source_template = _SOURCE_BASH
|
| 303 |
+
|
| 304 |
+
@staticmethod
|
| 305 |
+
def _check_version() -> None:
|
| 306 |
+
import shutil
|
| 307 |
+
import subprocess
|
| 308 |
+
|
| 309 |
+
bash_exe = shutil.which("bash")
|
| 310 |
+
|
| 311 |
+
if bash_exe is None:
|
| 312 |
+
match = None
|
| 313 |
+
else:
|
| 314 |
+
output = subprocess.run(
|
| 315 |
+
[bash_exe, "--norc", "-c", 'echo "${BASH_VERSION}"'],
|
| 316 |
+
stdout=subprocess.PIPE,
|
| 317 |
+
)
|
| 318 |
+
match = re.search(r"^(\d+)\.(\d+)\.\d+", output.stdout.decode())
|
| 319 |
+
|
| 320 |
+
if match is not None:
|
| 321 |
+
major, minor = match.groups()
|
| 322 |
+
|
| 323 |
+
if major < "4" or major == "4" and minor < "4":
|
| 324 |
+
echo(
|
| 325 |
+
_(
|
| 326 |
+
"Shell completion is not supported for Bash"
|
| 327 |
+
" versions older than 4.4."
|
| 328 |
+
),
|
| 329 |
+
err=True,
|
| 330 |
+
)
|
| 331 |
+
else:
|
| 332 |
+
echo(
|
| 333 |
+
_("Couldn't detect Bash version, shell completion is not supported."),
|
| 334 |
+
err=True,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
def source(self) -> str:
|
| 338 |
+
self._check_version()
|
| 339 |
+
return super().source()
|
| 340 |
+
|
| 341 |
+
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
|
| 342 |
+
cwords = split_arg_string(os.environ["COMP_WORDS"])
|
| 343 |
+
cword = int(os.environ["COMP_CWORD"])
|
| 344 |
+
args = cwords[1:cword]
|
| 345 |
+
|
| 346 |
+
try:
|
| 347 |
+
incomplete = cwords[cword]
|
| 348 |
+
except IndexError:
|
| 349 |
+
incomplete = ""
|
| 350 |
+
|
| 351 |
+
return args, incomplete
|
| 352 |
+
|
| 353 |
+
def format_completion(self, item: CompletionItem) -> str:
|
| 354 |
+
return f"{item.type},{item.value}"
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
class ZshComplete(ShellComplete):
|
| 358 |
+
"""Shell completion for Zsh."""
|
| 359 |
+
|
| 360 |
+
name = "zsh"
|
| 361 |
+
source_template = _SOURCE_ZSH
|
| 362 |
+
|
| 363 |
+
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
|
| 364 |
+
cwords = split_arg_string(os.environ["COMP_WORDS"])
|
| 365 |
+
cword = int(os.environ["COMP_CWORD"])
|
| 366 |
+
args = cwords[1:cword]
|
| 367 |
+
|
| 368 |
+
try:
|
| 369 |
+
incomplete = cwords[cword]
|
| 370 |
+
except IndexError:
|
| 371 |
+
incomplete = ""
|
| 372 |
+
|
| 373 |
+
return args, incomplete
|
| 374 |
+
|
| 375 |
+
def format_completion(self, item: CompletionItem) -> str:
|
| 376 |
+
return f"{item.type}\n{item.value}\n{item.help if item.help else '_'}"
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class FishComplete(ShellComplete):
|
| 380 |
+
"""Shell completion for Fish."""
|
| 381 |
+
|
| 382 |
+
name = "fish"
|
| 383 |
+
source_template = _SOURCE_FISH
|
| 384 |
+
|
| 385 |
+
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
|
| 386 |
+
cwords = split_arg_string(os.environ["COMP_WORDS"])
|
| 387 |
+
incomplete = os.environ["COMP_CWORD"]
|
| 388 |
+
args = cwords[1:]
|
| 389 |
+
|
| 390 |
+
# Fish stores the partial word in both COMP_WORDS and
|
| 391 |
+
# COMP_CWORD, remove it from complete args.
|
| 392 |
+
if incomplete and args and args[-1] == incomplete:
|
| 393 |
+
args.pop()
|
| 394 |
+
|
| 395 |
+
return args, incomplete
|
| 396 |
+
|
| 397 |
+
def format_completion(self, item: CompletionItem) -> str:
|
| 398 |
+
if item.help:
|
| 399 |
+
return f"{item.type},{item.value}\t{item.help}"
|
| 400 |
+
|
| 401 |
+
return f"{item.type},{item.value}"
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
ShellCompleteType = t.TypeVar("ShellCompleteType", bound=t.Type[ShellComplete])
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
_available_shells: t.Dict[str, t.Type[ShellComplete]] = {
|
| 408 |
+
"bash": BashComplete,
|
| 409 |
+
"fish": FishComplete,
|
| 410 |
+
"zsh": ZshComplete,
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def add_completion_class(
|
| 415 |
+
cls: ShellCompleteType, name: t.Optional[str] = None
|
| 416 |
+
) -> ShellCompleteType:
|
| 417 |
+
"""Register a :class:`ShellComplete` subclass under the given name.
|
| 418 |
+
The name will be provided by the completion instruction environment
|
| 419 |
+
variable during completion.
|
| 420 |
+
|
| 421 |
+
:param cls: The completion class that will handle completion for the
|
| 422 |
+
shell.
|
| 423 |
+
:param name: Name to register the class under. Defaults to the
|
| 424 |
+
class's ``name`` attribute.
|
| 425 |
+
"""
|
| 426 |
+
if name is None:
|
| 427 |
+
name = cls.name
|
| 428 |
+
|
| 429 |
+
_available_shells[name] = cls
|
| 430 |
+
|
| 431 |
+
return cls
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def get_completion_class(shell: str) -> t.Optional[t.Type[ShellComplete]]:
|
| 435 |
+
"""Look up a registered :class:`ShellComplete` subclass by the name
|
| 436 |
+
provided by the completion instruction environment variable. If the
|
| 437 |
+
name isn't registered, returns ``None``.
|
| 438 |
+
|
| 439 |
+
:param shell: Name the class is registered under.
|
| 440 |
+
"""
|
| 441 |
+
return _available_shells.get(shell)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def _is_incomplete_argument(ctx: Context, param: Parameter) -> bool:
|
| 445 |
+
"""Determine if the given parameter is an argument that can still
|
| 446 |
+
accept values.
|
| 447 |
+
|
| 448 |
+
:param ctx: Invocation context for the command represented by the
|
| 449 |
+
parsed complete args.
|
| 450 |
+
:param param: Argument object being checked.
|
| 451 |
+
"""
|
| 452 |
+
if not isinstance(param, Argument):
|
| 453 |
+
return False
|
| 454 |
+
|
| 455 |
+
assert param.name is not None
|
| 456 |
+
# Will be None if expose_value is False.
|
| 457 |
+
value = ctx.params.get(param.name)
|
| 458 |
+
return (
|
| 459 |
+
param.nargs == -1
|
| 460 |
+
or ctx.get_parameter_source(param.name) is not ParameterSource.COMMANDLINE
|
| 461 |
+
or (
|
| 462 |
+
param.nargs > 1
|
| 463 |
+
and isinstance(value, (tuple, list))
|
| 464 |
+
and len(value) < param.nargs
|
| 465 |
+
)
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def _start_of_option(ctx: Context, value: str) -> bool:
|
| 470 |
+
"""Check if the value looks like the start of an option."""
|
| 471 |
+
if not value:
|
| 472 |
+
return False
|
| 473 |
+
|
| 474 |
+
c = value[0]
|
| 475 |
+
return c in ctx._opt_prefixes
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def _is_incomplete_option(ctx: Context, args: t.List[str], param: Parameter) -> bool:
|
| 479 |
+
"""Determine if the given parameter is an option that needs a value.
|
| 480 |
+
|
| 481 |
+
:param args: List of complete args before the incomplete value.
|
| 482 |
+
:param param: Option object being checked.
|
| 483 |
+
"""
|
| 484 |
+
if not isinstance(param, Option):
|
| 485 |
+
return False
|
| 486 |
+
|
| 487 |
+
if param.is_flag or param.count:
|
| 488 |
+
return False
|
| 489 |
+
|
| 490 |
+
last_option = None
|
| 491 |
+
|
| 492 |
+
for index, arg in enumerate(reversed(args)):
|
| 493 |
+
if index + 1 > param.nargs:
|
| 494 |
+
break
|
| 495 |
+
|
| 496 |
+
if _start_of_option(ctx, arg):
|
| 497 |
+
last_option = arg
|
| 498 |
+
|
| 499 |
+
return last_option is not None and last_option in param.opts
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def _resolve_context(
|
| 503 |
+
cli: BaseCommand,
|
| 504 |
+
ctx_args: t.MutableMapping[str, t.Any],
|
| 505 |
+
prog_name: str,
|
| 506 |
+
args: t.List[str],
|
| 507 |
+
) -> Context:
|
| 508 |
+
"""Produce the context hierarchy starting with the command and
|
| 509 |
+
traversing the complete arguments. This only follows the commands,
|
| 510 |
+
it doesn't trigger input prompts or callbacks.
|
| 511 |
+
|
| 512 |
+
:param cli: Command being called.
|
| 513 |
+
:param prog_name: Name of the executable in the shell.
|
| 514 |
+
:param args: List of complete args before the incomplete value.
|
| 515 |
+
"""
|
| 516 |
+
ctx_args["resilient_parsing"] = True
|
| 517 |
+
ctx = cli.make_context(prog_name, args.copy(), **ctx_args)
|
| 518 |
+
args = ctx.protected_args + ctx.args
|
| 519 |
+
|
| 520 |
+
while args:
|
| 521 |
+
command = ctx.command
|
| 522 |
+
|
| 523 |
+
if isinstance(command, MultiCommand):
|
| 524 |
+
if not command.chain:
|
| 525 |
+
name, cmd, args = command.resolve_command(ctx, args)
|
| 526 |
+
|
| 527 |
+
if cmd is None:
|
| 528 |
+
return ctx
|
| 529 |
+
|
| 530 |
+
ctx = cmd.make_context(name, args, parent=ctx, resilient_parsing=True)
|
| 531 |
+
args = ctx.protected_args + ctx.args
|
| 532 |
+
else:
|
| 533 |
+
sub_ctx = ctx
|
| 534 |
+
|
| 535 |
+
while args:
|
| 536 |
+
name, cmd, args = command.resolve_command(ctx, args)
|
| 537 |
+
|
| 538 |
+
if cmd is None:
|
| 539 |
+
return ctx
|
| 540 |
+
|
| 541 |
+
sub_ctx = cmd.make_context(
|
| 542 |
+
name,
|
| 543 |
+
args,
|
| 544 |
+
parent=ctx,
|
| 545 |
+
allow_extra_args=True,
|
| 546 |
+
allow_interspersed_args=False,
|
| 547 |
+
resilient_parsing=True,
|
| 548 |
+
)
|
| 549 |
+
args = sub_ctx.args
|
| 550 |
+
|
| 551 |
+
ctx = sub_ctx
|
| 552 |
+
args = [*sub_ctx.protected_args, *sub_ctx.args]
|
| 553 |
+
else:
|
| 554 |
+
break
|
| 555 |
+
|
| 556 |
+
return ctx
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def _resolve_incomplete(
|
| 560 |
+
ctx: Context, args: t.List[str], incomplete: str
|
| 561 |
+
) -> t.Tuple[t.Union[BaseCommand, Parameter], str]:
|
| 562 |
+
"""Find the Click object that will handle the completion of the
|
| 563 |
+
incomplete value. Return the object and the incomplete value.
|
| 564 |
+
|
| 565 |
+
:param ctx: Invocation context for the command represented by
|
| 566 |
+
the parsed complete args.
|
| 567 |
+
:param args: List of complete args before the incomplete value.
|
| 568 |
+
:param incomplete: Value being completed. May be empty.
|
| 569 |
+
"""
|
| 570 |
+
# Different shells treat an "=" between a long option name and
|
| 571 |
+
# value differently. Might keep the value joined, return the "="
|
| 572 |
+
# as a separate item, or return the split name and value. Always
|
| 573 |
+
# split and discard the "=" to make completion easier.
|
| 574 |
+
if incomplete == "=":
|
| 575 |
+
incomplete = ""
|
| 576 |
+
elif "=" in incomplete and _start_of_option(ctx, incomplete):
|
| 577 |
+
name, _, incomplete = incomplete.partition("=")
|
| 578 |
+
args.append(name)
|
| 579 |
+
|
| 580 |
+
# The "--" marker tells Click to stop treating values as options
|
| 581 |
+
# even if they start with the option character. If it hasn't been
|
| 582 |
+
# given and the incomplete arg looks like an option, the current
|
| 583 |
+
# command will provide option name completions.
|
| 584 |
+
if "--" not in args and _start_of_option(ctx, incomplete):
|
| 585 |
+
return ctx.command, incomplete
|
| 586 |
+
|
| 587 |
+
params = ctx.command.get_params(ctx)
|
| 588 |
+
|
| 589 |
+
# If the last complete arg is an option name with an incomplete
|
| 590 |
+
# value, the option will provide value completions.
|
| 591 |
+
for param in params:
|
| 592 |
+
if _is_incomplete_option(ctx, args, param):
|
| 593 |
+
return param, incomplete
|
| 594 |
+
|
| 595 |
+
# It's not an option name or value. The first argument without a
|
| 596 |
+
# parsed value will provide value completions.
|
| 597 |
+
for param in params:
|
| 598 |
+
if _is_incomplete_argument(ctx, param):
|
| 599 |
+
return param, incomplete
|
| 600 |
+
|
| 601 |
+
# There were no unparsed arguments, the command may be a group that
|
| 602 |
+
# will provide command name completions.
|
| 603 |
+
return ctx.command, incomplete
|
.venv/lib/python3.11/site-packages/click/termui.py
ADDED
|
@@ -0,0 +1,784 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import io
|
| 3 |
+
import itertools
|
| 4 |
+
import sys
|
| 5 |
+
import typing as t
|
| 6 |
+
from gettext import gettext as _
|
| 7 |
+
|
| 8 |
+
from ._compat import isatty
|
| 9 |
+
from ._compat import strip_ansi
|
| 10 |
+
from .exceptions import Abort
|
| 11 |
+
from .exceptions import UsageError
|
| 12 |
+
from .globals import resolve_color_default
|
| 13 |
+
from .types import Choice
|
| 14 |
+
from .types import convert_type
|
| 15 |
+
from .types import ParamType
|
| 16 |
+
from .utils import echo
|
| 17 |
+
from .utils import LazyFile
|
| 18 |
+
|
| 19 |
+
if t.TYPE_CHECKING:
|
| 20 |
+
from ._termui_impl import ProgressBar
|
| 21 |
+
|
| 22 |
+
V = t.TypeVar("V")
|
| 23 |
+
|
| 24 |
+
# The prompt functions to use. The doc tools currently override these
|
| 25 |
+
# functions to customize how they work.
|
| 26 |
+
visible_prompt_func: t.Callable[[str], str] = input
|
| 27 |
+
|
| 28 |
+
_ansi_colors = {
|
| 29 |
+
"black": 30,
|
| 30 |
+
"red": 31,
|
| 31 |
+
"green": 32,
|
| 32 |
+
"yellow": 33,
|
| 33 |
+
"blue": 34,
|
| 34 |
+
"magenta": 35,
|
| 35 |
+
"cyan": 36,
|
| 36 |
+
"white": 37,
|
| 37 |
+
"reset": 39,
|
| 38 |
+
"bright_black": 90,
|
| 39 |
+
"bright_red": 91,
|
| 40 |
+
"bright_green": 92,
|
| 41 |
+
"bright_yellow": 93,
|
| 42 |
+
"bright_blue": 94,
|
| 43 |
+
"bright_magenta": 95,
|
| 44 |
+
"bright_cyan": 96,
|
| 45 |
+
"bright_white": 97,
|
| 46 |
+
}
|
| 47 |
+
_ansi_reset_all = "\033[0m"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def hidden_prompt_func(prompt: str) -> str:
|
| 51 |
+
import getpass
|
| 52 |
+
|
| 53 |
+
return getpass.getpass(prompt)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _build_prompt(
|
| 57 |
+
text: str,
|
| 58 |
+
suffix: str,
|
| 59 |
+
show_default: bool = False,
|
| 60 |
+
default: t.Optional[t.Any] = None,
|
| 61 |
+
show_choices: bool = True,
|
| 62 |
+
type: t.Optional[ParamType] = None,
|
| 63 |
+
) -> str:
|
| 64 |
+
prompt = text
|
| 65 |
+
if type is not None and show_choices and isinstance(type, Choice):
|
| 66 |
+
prompt += f" ({', '.join(map(str, type.choices))})"
|
| 67 |
+
if default is not None and show_default:
|
| 68 |
+
prompt = f"{prompt} [{_format_default(default)}]"
|
| 69 |
+
return f"{prompt}{suffix}"
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _format_default(default: t.Any) -> t.Any:
|
| 73 |
+
if isinstance(default, (io.IOBase, LazyFile)) and hasattr(default, "name"):
|
| 74 |
+
return default.name
|
| 75 |
+
|
| 76 |
+
return default
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def prompt(
|
| 80 |
+
text: str,
|
| 81 |
+
default: t.Optional[t.Any] = None,
|
| 82 |
+
hide_input: bool = False,
|
| 83 |
+
confirmation_prompt: t.Union[bool, str] = False,
|
| 84 |
+
type: t.Optional[t.Union[ParamType, t.Any]] = None,
|
| 85 |
+
value_proc: t.Optional[t.Callable[[str], t.Any]] = None,
|
| 86 |
+
prompt_suffix: str = ": ",
|
| 87 |
+
show_default: bool = True,
|
| 88 |
+
err: bool = False,
|
| 89 |
+
show_choices: bool = True,
|
| 90 |
+
) -> t.Any:
|
| 91 |
+
"""Prompts a user for input. This is a convenience function that can
|
| 92 |
+
be used to prompt a user for input later.
|
| 93 |
+
|
| 94 |
+
If the user aborts the input by sending an interrupt signal, this
|
| 95 |
+
function will catch it and raise a :exc:`Abort` exception.
|
| 96 |
+
|
| 97 |
+
:param text: the text to show for the prompt.
|
| 98 |
+
:param default: the default value to use if no input happens. If this
|
| 99 |
+
is not given it will prompt until it's aborted.
|
| 100 |
+
:param hide_input: if this is set to true then the input value will
|
| 101 |
+
be hidden.
|
| 102 |
+
:param confirmation_prompt: Prompt a second time to confirm the
|
| 103 |
+
value. Can be set to a string instead of ``True`` to customize
|
| 104 |
+
the message.
|
| 105 |
+
:param type: the type to use to check the value against.
|
| 106 |
+
:param value_proc: if this parameter is provided it's a function that
|
| 107 |
+
is invoked instead of the type conversion to
|
| 108 |
+
convert a value.
|
| 109 |
+
:param prompt_suffix: a suffix that should be added to the prompt.
|
| 110 |
+
:param show_default: shows or hides the default value in the prompt.
|
| 111 |
+
:param err: if set to true the file defaults to ``stderr`` instead of
|
| 112 |
+
``stdout``, the same as with echo.
|
| 113 |
+
:param show_choices: Show or hide choices if the passed type is a Choice.
|
| 114 |
+
For example if type is a Choice of either day or week,
|
| 115 |
+
show_choices is true and text is "Group by" then the
|
| 116 |
+
prompt will be "Group by (day, week): ".
|
| 117 |
+
|
| 118 |
+
.. versionadded:: 8.0
|
| 119 |
+
``confirmation_prompt`` can be a custom string.
|
| 120 |
+
|
| 121 |
+
.. versionadded:: 7.0
|
| 122 |
+
Added the ``show_choices`` parameter.
|
| 123 |
+
|
| 124 |
+
.. versionadded:: 6.0
|
| 125 |
+
Added unicode support for cmd.exe on Windows.
|
| 126 |
+
|
| 127 |
+
.. versionadded:: 4.0
|
| 128 |
+
Added the `err` parameter.
|
| 129 |
+
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
def prompt_func(text: str) -> str:
|
| 133 |
+
f = hidden_prompt_func if hide_input else visible_prompt_func
|
| 134 |
+
try:
|
| 135 |
+
# Write the prompt separately so that we get nice
|
| 136 |
+
# coloring through colorama on Windows
|
| 137 |
+
echo(text.rstrip(" "), nl=False, err=err)
|
| 138 |
+
# Echo a space to stdout to work around an issue where
|
| 139 |
+
# readline causes backspace to clear the whole line.
|
| 140 |
+
return f(" ")
|
| 141 |
+
except (KeyboardInterrupt, EOFError):
|
| 142 |
+
# getpass doesn't print a newline if the user aborts input with ^C.
|
| 143 |
+
# Allegedly this behavior is inherited from getpass(3).
|
| 144 |
+
# A doc bug has been filed at https://bugs.python.org/issue24711
|
| 145 |
+
if hide_input:
|
| 146 |
+
echo(None, err=err)
|
| 147 |
+
raise Abort() from None
|
| 148 |
+
|
| 149 |
+
if value_proc is None:
|
| 150 |
+
value_proc = convert_type(type, default)
|
| 151 |
+
|
| 152 |
+
prompt = _build_prompt(
|
| 153 |
+
text, prompt_suffix, show_default, default, show_choices, type
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
if confirmation_prompt:
|
| 157 |
+
if confirmation_prompt is True:
|
| 158 |
+
confirmation_prompt = _("Repeat for confirmation")
|
| 159 |
+
|
| 160 |
+
confirmation_prompt = _build_prompt(confirmation_prompt, prompt_suffix)
|
| 161 |
+
|
| 162 |
+
while True:
|
| 163 |
+
while True:
|
| 164 |
+
value = prompt_func(prompt)
|
| 165 |
+
if value:
|
| 166 |
+
break
|
| 167 |
+
elif default is not None:
|
| 168 |
+
value = default
|
| 169 |
+
break
|
| 170 |
+
try:
|
| 171 |
+
result = value_proc(value)
|
| 172 |
+
except UsageError as e:
|
| 173 |
+
if hide_input:
|
| 174 |
+
echo(_("Error: The value you entered was invalid."), err=err)
|
| 175 |
+
else:
|
| 176 |
+
echo(_("Error: {e.message}").format(e=e), err=err)
|
| 177 |
+
continue
|
| 178 |
+
if not confirmation_prompt:
|
| 179 |
+
return result
|
| 180 |
+
while True:
|
| 181 |
+
value2 = prompt_func(confirmation_prompt)
|
| 182 |
+
is_empty = not value and not value2
|
| 183 |
+
if value2 or is_empty:
|
| 184 |
+
break
|
| 185 |
+
if value == value2:
|
| 186 |
+
return result
|
| 187 |
+
echo(_("Error: The two entered values do not match."), err=err)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def confirm(
|
| 191 |
+
text: str,
|
| 192 |
+
default: t.Optional[bool] = False,
|
| 193 |
+
abort: bool = False,
|
| 194 |
+
prompt_suffix: str = ": ",
|
| 195 |
+
show_default: bool = True,
|
| 196 |
+
err: bool = False,
|
| 197 |
+
) -> bool:
|
| 198 |
+
"""Prompts for confirmation (yes/no question).
|
| 199 |
+
|
| 200 |
+
If the user aborts the input by sending a interrupt signal this
|
| 201 |
+
function will catch it and raise a :exc:`Abort` exception.
|
| 202 |
+
|
| 203 |
+
:param text: the question to ask.
|
| 204 |
+
:param default: The default value to use when no input is given. If
|
| 205 |
+
``None``, repeat until input is given.
|
| 206 |
+
:param abort: if this is set to `True` a negative answer aborts the
|
| 207 |
+
exception by raising :exc:`Abort`.
|
| 208 |
+
:param prompt_suffix: a suffix that should be added to the prompt.
|
| 209 |
+
:param show_default: shows or hides the default value in the prompt.
|
| 210 |
+
:param err: if set to true the file defaults to ``stderr`` instead of
|
| 211 |
+
``stdout``, the same as with echo.
|
| 212 |
+
|
| 213 |
+
.. versionchanged:: 8.0
|
| 214 |
+
Repeat until input is given if ``default`` is ``None``.
|
| 215 |
+
|
| 216 |
+
.. versionadded:: 4.0
|
| 217 |
+
Added the ``err`` parameter.
|
| 218 |
+
"""
|
| 219 |
+
prompt = _build_prompt(
|
| 220 |
+
text,
|
| 221 |
+
prompt_suffix,
|
| 222 |
+
show_default,
|
| 223 |
+
"y/n" if default is None else ("Y/n" if default else "y/N"),
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
while True:
|
| 227 |
+
try:
|
| 228 |
+
# Write the prompt separately so that we get nice
|
| 229 |
+
# coloring through colorama on Windows
|
| 230 |
+
echo(prompt.rstrip(" "), nl=False, err=err)
|
| 231 |
+
# Echo a space to stdout to work around an issue where
|
| 232 |
+
# readline causes backspace to clear the whole line.
|
| 233 |
+
value = visible_prompt_func(" ").lower().strip()
|
| 234 |
+
except (KeyboardInterrupt, EOFError):
|
| 235 |
+
raise Abort() from None
|
| 236 |
+
if value in ("y", "yes"):
|
| 237 |
+
rv = True
|
| 238 |
+
elif value in ("n", "no"):
|
| 239 |
+
rv = False
|
| 240 |
+
elif default is not None and value == "":
|
| 241 |
+
rv = default
|
| 242 |
+
else:
|
| 243 |
+
echo(_("Error: invalid input"), err=err)
|
| 244 |
+
continue
|
| 245 |
+
break
|
| 246 |
+
if abort and not rv:
|
| 247 |
+
raise Abort()
|
| 248 |
+
return rv
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def echo_via_pager(
|
| 252 |
+
text_or_generator: t.Union[t.Iterable[str], t.Callable[[], t.Iterable[str]], str],
|
| 253 |
+
color: t.Optional[bool] = None,
|
| 254 |
+
) -> None:
|
| 255 |
+
"""This function takes a text and shows it via an environment specific
|
| 256 |
+
pager on stdout.
|
| 257 |
+
|
| 258 |
+
.. versionchanged:: 3.0
|
| 259 |
+
Added the `color` flag.
|
| 260 |
+
|
| 261 |
+
:param text_or_generator: the text to page, or alternatively, a
|
| 262 |
+
generator emitting the text to page.
|
| 263 |
+
:param color: controls if the pager supports ANSI colors or not. The
|
| 264 |
+
default is autodetection.
|
| 265 |
+
"""
|
| 266 |
+
color = resolve_color_default(color)
|
| 267 |
+
|
| 268 |
+
if inspect.isgeneratorfunction(text_or_generator):
|
| 269 |
+
i = t.cast(t.Callable[[], t.Iterable[str]], text_or_generator)()
|
| 270 |
+
elif isinstance(text_or_generator, str):
|
| 271 |
+
i = [text_or_generator]
|
| 272 |
+
else:
|
| 273 |
+
i = iter(t.cast(t.Iterable[str], text_or_generator))
|
| 274 |
+
|
| 275 |
+
# convert every element of i to a text type if necessary
|
| 276 |
+
text_generator = (el if isinstance(el, str) else str(el) for el in i)
|
| 277 |
+
|
| 278 |
+
from ._termui_impl import pager
|
| 279 |
+
|
| 280 |
+
return pager(itertools.chain(text_generator, "\n"), color)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def progressbar(
|
| 284 |
+
iterable: t.Optional[t.Iterable[V]] = None,
|
| 285 |
+
length: t.Optional[int] = None,
|
| 286 |
+
label: t.Optional[str] = None,
|
| 287 |
+
show_eta: bool = True,
|
| 288 |
+
show_percent: t.Optional[bool] = None,
|
| 289 |
+
show_pos: bool = False,
|
| 290 |
+
item_show_func: t.Optional[t.Callable[[t.Optional[V]], t.Optional[str]]] = None,
|
| 291 |
+
fill_char: str = "#",
|
| 292 |
+
empty_char: str = "-",
|
| 293 |
+
bar_template: str = "%(label)s [%(bar)s] %(info)s",
|
| 294 |
+
info_sep: str = " ",
|
| 295 |
+
width: int = 36,
|
| 296 |
+
file: t.Optional[t.TextIO] = None,
|
| 297 |
+
color: t.Optional[bool] = None,
|
| 298 |
+
update_min_steps: int = 1,
|
| 299 |
+
) -> "ProgressBar[V]":
|
| 300 |
+
"""This function creates an iterable context manager that can be used
|
| 301 |
+
to iterate over something while showing a progress bar. It will
|
| 302 |
+
either iterate over the `iterable` or `length` items (that are counted
|
| 303 |
+
up). While iteration happens, this function will print a rendered
|
| 304 |
+
progress bar to the given `file` (defaults to stdout) and will attempt
|
| 305 |
+
to calculate remaining time and more. By default, this progress bar
|
| 306 |
+
will not be rendered if the file is not a terminal.
|
| 307 |
+
|
| 308 |
+
The context manager creates the progress bar. When the context
|
| 309 |
+
manager is entered the progress bar is already created. With every
|
| 310 |
+
iteration over the progress bar, the iterable passed to the bar is
|
| 311 |
+
advanced and the bar is updated. When the context manager exits,
|
| 312 |
+
a newline is printed and the progress bar is finalized on screen.
|
| 313 |
+
|
| 314 |
+
Note: The progress bar is currently designed for use cases where the
|
| 315 |
+
total progress can be expected to take at least several seconds.
|
| 316 |
+
Because of this, the ProgressBar class object won't display
|
| 317 |
+
progress that is considered too fast, and progress where the time
|
| 318 |
+
between steps is less than a second.
|
| 319 |
+
|
| 320 |
+
No printing must happen or the progress bar will be unintentionally
|
| 321 |
+
destroyed.
|
| 322 |
+
|
| 323 |
+
Example usage::
|
| 324 |
+
|
| 325 |
+
with progressbar(items) as bar:
|
| 326 |
+
for item in bar:
|
| 327 |
+
do_something_with(item)
|
| 328 |
+
|
| 329 |
+
Alternatively, if no iterable is specified, one can manually update the
|
| 330 |
+
progress bar through the `update()` method instead of directly
|
| 331 |
+
iterating over the progress bar. The update method accepts the number
|
| 332 |
+
of steps to increment the bar with::
|
| 333 |
+
|
| 334 |
+
with progressbar(length=chunks.total_bytes) as bar:
|
| 335 |
+
for chunk in chunks:
|
| 336 |
+
process_chunk(chunk)
|
| 337 |
+
bar.update(chunks.bytes)
|
| 338 |
+
|
| 339 |
+
The ``update()`` method also takes an optional value specifying the
|
| 340 |
+
``current_item`` at the new position. This is useful when used
|
| 341 |
+
together with ``item_show_func`` to customize the output for each
|
| 342 |
+
manual step::
|
| 343 |
+
|
| 344 |
+
with click.progressbar(
|
| 345 |
+
length=total_size,
|
| 346 |
+
label='Unzipping archive',
|
| 347 |
+
item_show_func=lambda a: a.filename
|
| 348 |
+
) as bar:
|
| 349 |
+
for archive in zip_file:
|
| 350 |
+
archive.extract()
|
| 351 |
+
bar.update(archive.size, archive)
|
| 352 |
+
|
| 353 |
+
:param iterable: an iterable to iterate over. If not provided the length
|
| 354 |
+
is required.
|
| 355 |
+
:param length: the number of items to iterate over. By default the
|
| 356 |
+
progressbar will attempt to ask the iterator about its
|
| 357 |
+
length, which might or might not work. If an iterable is
|
| 358 |
+
also provided this parameter can be used to override the
|
| 359 |
+
length. If an iterable is not provided the progress bar
|
| 360 |
+
will iterate over a range of that length.
|
| 361 |
+
:param label: the label to show next to the progress bar.
|
| 362 |
+
:param show_eta: enables or disables the estimated time display. This is
|
| 363 |
+
automatically disabled if the length cannot be
|
| 364 |
+
determined.
|
| 365 |
+
:param show_percent: enables or disables the percentage display. The
|
| 366 |
+
default is `True` if the iterable has a length or
|
| 367 |
+
`False` if not.
|
| 368 |
+
:param show_pos: enables or disables the absolute position display. The
|
| 369 |
+
default is `False`.
|
| 370 |
+
:param item_show_func: A function called with the current item which
|
| 371 |
+
can return a string to show next to the progress bar. If the
|
| 372 |
+
function returns ``None`` nothing is shown. The current item can
|
| 373 |
+
be ``None``, such as when entering and exiting the bar.
|
| 374 |
+
:param fill_char: the character to use to show the filled part of the
|
| 375 |
+
progress bar.
|
| 376 |
+
:param empty_char: the character to use to show the non-filled part of
|
| 377 |
+
the progress bar.
|
| 378 |
+
:param bar_template: the format string to use as template for the bar.
|
| 379 |
+
The parameters in it are ``label`` for the label,
|
| 380 |
+
``bar`` for the progress bar and ``info`` for the
|
| 381 |
+
info section.
|
| 382 |
+
:param info_sep: the separator between multiple info items (eta etc.)
|
| 383 |
+
:param width: the width of the progress bar in characters, 0 means full
|
| 384 |
+
terminal width
|
| 385 |
+
:param file: The file to write to. If this is not a terminal then
|
| 386 |
+
only the label is printed.
|
| 387 |
+
:param color: controls if the terminal supports ANSI colors or not. The
|
| 388 |
+
default is autodetection. This is only needed if ANSI
|
| 389 |
+
codes are included anywhere in the progress bar output
|
| 390 |
+
which is not the case by default.
|
| 391 |
+
:param update_min_steps: Render only when this many updates have
|
| 392 |
+
completed. This allows tuning for very fast iterators.
|
| 393 |
+
|
| 394 |
+
.. versionchanged:: 8.0
|
| 395 |
+
Output is shown even if execution time is less than 0.5 seconds.
|
| 396 |
+
|
| 397 |
+
.. versionchanged:: 8.0
|
| 398 |
+
``item_show_func`` shows the current item, not the previous one.
|
| 399 |
+
|
| 400 |
+
.. versionchanged:: 8.0
|
| 401 |
+
Labels are echoed if the output is not a TTY. Reverts a change
|
| 402 |
+
in 7.0 that removed all output.
|
| 403 |
+
|
| 404 |
+
.. versionadded:: 8.0
|
| 405 |
+
Added the ``update_min_steps`` parameter.
|
| 406 |
+
|
| 407 |
+
.. versionchanged:: 4.0
|
| 408 |
+
Added the ``color`` parameter. Added the ``update`` method to
|
| 409 |
+
the object.
|
| 410 |
+
|
| 411 |
+
.. versionadded:: 2.0
|
| 412 |
+
"""
|
| 413 |
+
from ._termui_impl import ProgressBar
|
| 414 |
+
|
| 415 |
+
color = resolve_color_default(color)
|
| 416 |
+
return ProgressBar(
|
| 417 |
+
iterable=iterable,
|
| 418 |
+
length=length,
|
| 419 |
+
show_eta=show_eta,
|
| 420 |
+
show_percent=show_percent,
|
| 421 |
+
show_pos=show_pos,
|
| 422 |
+
item_show_func=item_show_func,
|
| 423 |
+
fill_char=fill_char,
|
| 424 |
+
empty_char=empty_char,
|
| 425 |
+
bar_template=bar_template,
|
| 426 |
+
info_sep=info_sep,
|
| 427 |
+
file=file,
|
| 428 |
+
label=label,
|
| 429 |
+
width=width,
|
| 430 |
+
color=color,
|
| 431 |
+
update_min_steps=update_min_steps,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def clear() -> None:
|
| 436 |
+
"""Clears the terminal screen. This will have the effect of clearing
|
| 437 |
+
the whole visible space of the terminal and moving the cursor to the
|
| 438 |
+
top left. This does not do anything if not connected to a terminal.
|
| 439 |
+
|
| 440 |
+
.. versionadded:: 2.0
|
| 441 |
+
"""
|
| 442 |
+
if not isatty(sys.stdout):
|
| 443 |
+
return
|
| 444 |
+
|
| 445 |
+
# ANSI escape \033[2J clears the screen, \033[1;1H moves the cursor
|
| 446 |
+
echo("\033[2J\033[1;1H", nl=False)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def _interpret_color(
|
| 450 |
+
color: t.Union[int, t.Tuple[int, int, int], str], offset: int = 0
|
| 451 |
+
) -> str:
|
| 452 |
+
if isinstance(color, int):
|
| 453 |
+
return f"{38 + offset};5;{color:d}"
|
| 454 |
+
|
| 455 |
+
if isinstance(color, (tuple, list)):
|
| 456 |
+
r, g, b = color
|
| 457 |
+
return f"{38 + offset};2;{r:d};{g:d};{b:d}"
|
| 458 |
+
|
| 459 |
+
return str(_ansi_colors[color] + offset)
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def style(
|
| 463 |
+
text: t.Any,
|
| 464 |
+
fg: t.Optional[t.Union[int, t.Tuple[int, int, int], str]] = None,
|
| 465 |
+
bg: t.Optional[t.Union[int, t.Tuple[int, int, int], str]] = None,
|
| 466 |
+
bold: t.Optional[bool] = None,
|
| 467 |
+
dim: t.Optional[bool] = None,
|
| 468 |
+
underline: t.Optional[bool] = None,
|
| 469 |
+
overline: t.Optional[bool] = None,
|
| 470 |
+
italic: t.Optional[bool] = None,
|
| 471 |
+
blink: t.Optional[bool] = None,
|
| 472 |
+
reverse: t.Optional[bool] = None,
|
| 473 |
+
strikethrough: t.Optional[bool] = None,
|
| 474 |
+
reset: bool = True,
|
| 475 |
+
) -> str:
|
| 476 |
+
"""Styles a text with ANSI styles and returns the new string. By
|
| 477 |
+
default the styling is self contained which means that at the end
|
| 478 |
+
of the string a reset code is issued. This can be prevented by
|
| 479 |
+
passing ``reset=False``.
|
| 480 |
+
|
| 481 |
+
Examples::
|
| 482 |
+
|
| 483 |
+
click.echo(click.style('Hello World!', fg='green'))
|
| 484 |
+
click.echo(click.style('ATTENTION!', blink=True))
|
| 485 |
+
click.echo(click.style('Some things', reverse=True, fg='cyan'))
|
| 486 |
+
click.echo(click.style('More colors', fg=(255, 12, 128), bg=117))
|
| 487 |
+
|
| 488 |
+
Supported color names:
|
| 489 |
+
|
| 490 |
+
* ``black`` (might be a gray)
|
| 491 |
+
* ``red``
|
| 492 |
+
* ``green``
|
| 493 |
+
* ``yellow`` (might be an orange)
|
| 494 |
+
* ``blue``
|
| 495 |
+
* ``magenta``
|
| 496 |
+
* ``cyan``
|
| 497 |
+
* ``white`` (might be light gray)
|
| 498 |
+
* ``bright_black``
|
| 499 |
+
* ``bright_red``
|
| 500 |
+
* ``bright_green``
|
| 501 |
+
* ``bright_yellow``
|
| 502 |
+
* ``bright_blue``
|
| 503 |
+
* ``bright_magenta``
|
| 504 |
+
* ``bright_cyan``
|
| 505 |
+
* ``bright_white``
|
| 506 |
+
* ``reset`` (reset the color code only)
|
| 507 |
+
|
| 508 |
+
If the terminal supports it, color may also be specified as:
|
| 509 |
+
|
| 510 |
+
- An integer in the interval [0, 255]. The terminal must support
|
| 511 |
+
8-bit/256-color mode.
|
| 512 |
+
- An RGB tuple of three integers in [0, 255]. The terminal must
|
| 513 |
+
support 24-bit/true-color mode.
|
| 514 |
+
|
| 515 |
+
See https://en.wikipedia.org/wiki/ANSI_color and
|
| 516 |
+
https://gist.github.com/XVilka/8346728 for more information.
|
| 517 |
+
|
| 518 |
+
:param text: the string to style with ansi codes.
|
| 519 |
+
:param fg: if provided this will become the foreground color.
|
| 520 |
+
:param bg: if provided this will become the background color.
|
| 521 |
+
:param bold: if provided this will enable or disable bold mode.
|
| 522 |
+
:param dim: if provided this will enable or disable dim mode. This is
|
| 523 |
+
badly supported.
|
| 524 |
+
:param underline: if provided this will enable or disable underline.
|
| 525 |
+
:param overline: if provided this will enable or disable overline.
|
| 526 |
+
:param italic: if provided this will enable or disable italic.
|
| 527 |
+
:param blink: if provided this will enable or disable blinking.
|
| 528 |
+
:param reverse: if provided this will enable or disable inverse
|
| 529 |
+
rendering (foreground becomes background and the
|
| 530 |
+
other way round).
|
| 531 |
+
:param strikethrough: if provided this will enable or disable
|
| 532 |
+
striking through text.
|
| 533 |
+
:param reset: by default a reset-all code is added at the end of the
|
| 534 |
+
string which means that styles do not carry over. This
|
| 535 |
+
can be disabled to compose styles.
|
| 536 |
+
|
| 537 |
+
.. versionchanged:: 8.0
|
| 538 |
+
A non-string ``message`` is converted to a string.
|
| 539 |
+
|
| 540 |
+
.. versionchanged:: 8.0
|
| 541 |
+
Added support for 256 and RGB color codes.
|
| 542 |
+
|
| 543 |
+
.. versionchanged:: 8.0
|
| 544 |
+
Added the ``strikethrough``, ``italic``, and ``overline``
|
| 545 |
+
parameters.
|
| 546 |
+
|
| 547 |
+
.. versionchanged:: 7.0
|
| 548 |
+
Added support for bright colors.
|
| 549 |
+
|
| 550 |
+
.. versionadded:: 2.0
|
| 551 |
+
"""
|
| 552 |
+
if not isinstance(text, str):
|
| 553 |
+
text = str(text)
|
| 554 |
+
|
| 555 |
+
bits = []
|
| 556 |
+
|
| 557 |
+
if fg:
|
| 558 |
+
try:
|
| 559 |
+
bits.append(f"\033[{_interpret_color(fg)}m")
|
| 560 |
+
except KeyError:
|
| 561 |
+
raise TypeError(f"Unknown color {fg!r}") from None
|
| 562 |
+
|
| 563 |
+
if bg:
|
| 564 |
+
try:
|
| 565 |
+
bits.append(f"\033[{_interpret_color(bg, 10)}m")
|
| 566 |
+
except KeyError:
|
| 567 |
+
raise TypeError(f"Unknown color {bg!r}") from None
|
| 568 |
+
|
| 569 |
+
if bold is not None:
|
| 570 |
+
bits.append(f"\033[{1 if bold else 22}m")
|
| 571 |
+
if dim is not None:
|
| 572 |
+
bits.append(f"\033[{2 if dim else 22}m")
|
| 573 |
+
if underline is not None:
|
| 574 |
+
bits.append(f"\033[{4 if underline else 24}m")
|
| 575 |
+
if overline is not None:
|
| 576 |
+
bits.append(f"\033[{53 if overline else 55}m")
|
| 577 |
+
if italic is not None:
|
| 578 |
+
bits.append(f"\033[{3 if italic else 23}m")
|
| 579 |
+
if blink is not None:
|
| 580 |
+
bits.append(f"\033[{5 if blink else 25}m")
|
| 581 |
+
if reverse is not None:
|
| 582 |
+
bits.append(f"\033[{7 if reverse else 27}m")
|
| 583 |
+
if strikethrough is not None:
|
| 584 |
+
bits.append(f"\033[{9 if strikethrough else 29}m")
|
| 585 |
+
bits.append(text)
|
| 586 |
+
if reset:
|
| 587 |
+
bits.append(_ansi_reset_all)
|
| 588 |
+
return "".join(bits)
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
def unstyle(text: str) -> str:
|
| 592 |
+
"""Removes ANSI styling information from a string. Usually it's not
|
| 593 |
+
necessary to use this function as Click's echo function will
|
| 594 |
+
automatically remove styling if necessary.
|
| 595 |
+
|
| 596 |
+
.. versionadded:: 2.0
|
| 597 |
+
|
| 598 |
+
:param text: the text to remove style information from.
|
| 599 |
+
"""
|
| 600 |
+
return strip_ansi(text)
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def secho(
|
| 604 |
+
message: t.Optional[t.Any] = None,
|
| 605 |
+
file: t.Optional[t.IO[t.AnyStr]] = None,
|
| 606 |
+
nl: bool = True,
|
| 607 |
+
err: bool = False,
|
| 608 |
+
color: t.Optional[bool] = None,
|
| 609 |
+
**styles: t.Any,
|
| 610 |
+
) -> None:
|
| 611 |
+
"""This function combines :func:`echo` and :func:`style` into one
|
| 612 |
+
call. As such the following two calls are the same::
|
| 613 |
+
|
| 614 |
+
click.secho('Hello World!', fg='green')
|
| 615 |
+
click.echo(click.style('Hello World!', fg='green'))
|
| 616 |
+
|
| 617 |
+
All keyword arguments are forwarded to the underlying functions
|
| 618 |
+
depending on which one they go with.
|
| 619 |
+
|
| 620 |
+
Non-string types will be converted to :class:`str`. However,
|
| 621 |
+
:class:`bytes` are passed directly to :meth:`echo` without applying
|
| 622 |
+
style. If you want to style bytes that represent text, call
|
| 623 |
+
:meth:`bytes.decode` first.
|
| 624 |
+
|
| 625 |
+
.. versionchanged:: 8.0
|
| 626 |
+
A non-string ``message`` is converted to a string. Bytes are
|
| 627 |
+
passed through without style applied.
|
| 628 |
+
|
| 629 |
+
.. versionadded:: 2.0
|
| 630 |
+
"""
|
| 631 |
+
if message is not None and not isinstance(message, (bytes, bytearray)):
|
| 632 |
+
message = style(message, **styles)
|
| 633 |
+
|
| 634 |
+
return echo(message, file=file, nl=nl, err=err, color=color)
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
def edit(
|
| 638 |
+
text: t.Optional[t.AnyStr] = None,
|
| 639 |
+
editor: t.Optional[str] = None,
|
| 640 |
+
env: t.Optional[t.Mapping[str, str]] = None,
|
| 641 |
+
require_save: bool = True,
|
| 642 |
+
extension: str = ".txt",
|
| 643 |
+
filename: t.Optional[str] = None,
|
| 644 |
+
) -> t.Optional[t.AnyStr]:
|
| 645 |
+
r"""Edits the given text in the defined editor. If an editor is given
|
| 646 |
+
(should be the full path to the executable but the regular operating
|
| 647 |
+
system search path is used for finding the executable) it overrides
|
| 648 |
+
the detected editor. Optionally, some environment variables can be
|
| 649 |
+
used. If the editor is closed without changes, `None` is returned. In
|
| 650 |
+
case a file is edited directly the return value is always `None` and
|
| 651 |
+
`require_save` and `extension` are ignored.
|
| 652 |
+
|
| 653 |
+
If the editor cannot be opened a :exc:`UsageError` is raised.
|
| 654 |
+
|
| 655 |
+
Note for Windows: to simplify cross-platform usage, the newlines are
|
| 656 |
+
automatically converted from POSIX to Windows and vice versa. As such,
|
| 657 |
+
the message here will have ``\n`` as newline markers.
|
| 658 |
+
|
| 659 |
+
:param text: the text to edit.
|
| 660 |
+
:param editor: optionally the editor to use. Defaults to automatic
|
| 661 |
+
detection.
|
| 662 |
+
:param env: environment variables to forward to the editor.
|
| 663 |
+
:param require_save: if this is true, then not saving in the editor
|
| 664 |
+
will make the return value become `None`.
|
| 665 |
+
:param extension: the extension to tell the editor about. This defaults
|
| 666 |
+
to `.txt` but changing this might change syntax
|
| 667 |
+
highlighting.
|
| 668 |
+
:param filename: if provided it will edit this file instead of the
|
| 669 |
+
provided text contents. It will not use a temporary
|
| 670 |
+
file as an indirection in that case.
|
| 671 |
+
"""
|
| 672 |
+
from ._termui_impl import Editor
|
| 673 |
+
|
| 674 |
+
ed = Editor(editor=editor, env=env, require_save=require_save, extension=extension)
|
| 675 |
+
|
| 676 |
+
if filename is None:
|
| 677 |
+
return ed.edit(text)
|
| 678 |
+
|
| 679 |
+
ed.edit_file(filename)
|
| 680 |
+
return None
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def launch(url: str, wait: bool = False, locate: bool = False) -> int:
|
| 684 |
+
"""This function launches the given URL (or filename) in the default
|
| 685 |
+
viewer application for this file type. If this is an executable, it
|
| 686 |
+
might launch the executable in a new session. The return value is
|
| 687 |
+
the exit code of the launched application. Usually, ``0`` indicates
|
| 688 |
+
success.
|
| 689 |
+
|
| 690 |
+
Examples::
|
| 691 |
+
|
| 692 |
+
click.launch('https://click.palletsprojects.com/')
|
| 693 |
+
click.launch('/my/downloaded/file', locate=True)
|
| 694 |
+
|
| 695 |
+
.. versionadded:: 2.0
|
| 696 |
+
|
| 697 |
+
:param url: URL or filename of the thing to launch.
|
| 698 |
+
:param wait: Wait for the program to exit before returning. This
|
| 699 |
+
only works if the launched program blocks. In particular,
|
| 700 |
+
``xdg-open`` on Linux does not block.
|
| 701 |
+
:param locate: if this is set to `True` then instead of launching the
|
| 702 |
+
application associated with the URL it will attempt to
|
| 703 |
+
launch a file manager with the file located. This
|
| 704 |
+
might have weird effects if the URL does not point to
|
| 705 |
+
the filesystem.
|
| 706 |
+
"""
|
| 707 |
+
from ._termui_impl import open_url
|
| 708 |
+
|
| 709 |
+
return open_url(url, wait=wait, locate=locate)
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
# If this is provided, getchar() calls into this instead. This is used
|
| 713 |
+
# for unittesting purposes.
|
| 714 |
+
_getchar: t.Optional[t.Callable[[bool], str]] = None
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
def getchar(echo: bool = False) -> str:
|
| 718 |
+
"""Fetches a single character from the terminal and returns it. This
|
| 719 |
+
will always return a unicode character and under certain rare
|
| 720 |
+
circumstances this might return more than one character. The
|
| 721 |
+
situations which more than one character is returned is when for
|
| 722 |
+
whatever reason multiple characters end up in the terminal buffer or
|
| 723 |
+
standard input was not actually a terminal.
|
| 724 |
+
|
| 725 |
+
Note that this will always read from the terminal, even if something
|
| 726 |
+
is piped into the standard input.
|
| 727 |
+
|
| 728 |
+
Note for Windows: in rare cases when typing non-ASCII characters, this
|
| 729 |
+
function might wait for a second character and then return both at once.
|
| 730 |
+
This is because certain Unicode characters look like special-key markers.
|
| 731 |
+
|
| 732 |
+
.. versionadded:: 2.0
|
| 733 |
+
|
| 734 |
+
:param echo: if set to `True`, the character read will also show up on
|
| 735 |
+
the terminal. The default is to not show it.
|
| 736 |
+
"""
|
| 737 |
+
global _getchar
|
| 738 |
+
|
| 739 |
+
if _getchar is None:
|
| 740 |
+
from ._termui_impl import getchar as f
|
| 741 |
+
|
| 742 |
+
_getchar = f
|
| 743 |
+
|
| 744 |
+
return _getchar(echo)
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
def raw_terminal() -> t.ContextManager[int]:
|
| 748 |
+
from ._termui_impl import raw_terminal as f
|
| 749 |
+
|
| 750 |
+
return f()
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
def pause(info: t.Optional[str] = None, err: bool = False) -> None:
|
| 754 |
+
"""This command stops execution and waits for the user to press any
|
| 755 |
+
key to continue. This is similar to the Windows batch "pause"
|
| 756 |
+
command. If the program is not run through a terminal, this command
|
| 757 |
+
will instead do nothing.
|
| 758 |
+
|
| 759 |
+
.. versionadded:: 2.0
|
| 760 |
+
|
| 761 |
+
.. versionadded:: 4.0
|
| 762 |
+
Added the `err` parameter.
|
| 763 |
+
|
| 764 |
+
:param info: The message to print before pausing. Defaults to
|
| 765 |
+
``"Press any key to continue..."``.
|
| 766 |
+
:param err: if set to message goes to ``stderr`` instead of
|
| 767 |
+
``stdout``, the same as with echo.
|
| 768 |
+
"""
|
| 769 |
+
if not isatty(sys.stdin) or not isatty(sys.stdout):
|
| 770 |
+
return
|
| 771 |
+
|
| 772 |
+
if info is None:
|
| 773 |
+
info = _("Press any key to continue...")
|
| 774 |
+
|
| 775 |
+
try:
|
| 776 |
+
if info:
|
| 777 |
+
echo(info, nl=False, err=err)
|
| 778 |
+
try:
|
| 779 |
+
getchar()
|
| 780 |
+
except (KeyboardInterrupt, EOFError):
|
| 781 |
+
pass
|
| 782 |
+
finally:
|
| 783 |
+
if info:
|
| 784 |
+
echo(err=err)
|
.venv/lib/python3.11/site-packages/click/utils.py
ADDED
|
@@ -0,0 +1,624 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import sys
|
| 4 |
+
import typing as t
|
| 5 |
+
from functools import update_wrapper
|
| 6 |
+
from types import ModuleType
|
| 7 |
+
from types import TracebackType
|
| 8 |
+
|
| 9 |
+
from ._compat import _default_text_stderr
|
| 10 |
+
from ._compat import _default_text_stdout
|
| 11 |
+
from ._compat import _find_binary_writer
|
| 12 |
+
from ._compat import auto_wrap_for_ansi
|
| 13 |
+
from ._compat import binary_streams
|
| 14 |
+
from ._compat import open_stream
|
| 15 |
+
from ._compat import should_strip_ansi
|
| 16 |
+
from ._compat import strip_ansi
|
| 17 |
+
from ._compat import text_streams
|
| 18 |
+
from ._compat import WIN
|
| 19 |
+
from .globals import resolve_color_default
|
| 20 |
+
|
| 21 |
+
if t.TYPE_CHECKING:
|
| 22 |
+
import typing_extensions as te
|
| 23 |
+
|
| 24 |
+
P = te.ParamSpec("P")
|
| 25 |
+
|
| 26 |
+
R = t.TypeVar("R")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _posixify(name: str) -> str:
|
| 30 |
+
return "-".join(name.split()).lower()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def safecall(func: "t.Callable[P, R]") -> "t.Callable[P, t.Optional[R]]":
|
| 34 |
+
"""Wraps a function so that it swallows exceptions."""
|
| 35 |
+
|
| 36 |
+
def wrapper(*args: "P.args", **kwargs: "P.kwargs") -> t.Optional[R]:
|
| 37 |
+
try:
|
| 38 |
+
return func(*args, **kwargs)
|
| 39 |
+
except Exception:
|
| 40 |
+
pass
|
| 41 |
+
return None
|
| 42 |
+
|
| 43 |
+
return update_wrapper(wrapper, func)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def make_str(value: t.Any) -> str:
|
| 47 |
+
"""Converts a value into a valid string."""
|
| 48 |
+
if isinstance(value, bytes):
|
| 49 |
+
try:
|
| 50 |
+
return value.decode(sys.getfilesystemencoding())
|
| 51 |
+
except UnicodeError:
|
| 52 |
+
return value.decode("utf-8", "replace")
|
| 53 |
+
return str(value)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def make_default_short_help(help: str, max_length: int = 45) -> str:
|
| 57 |
+
"""Returns a condensed version of help string."""
|
| 58 |
+
# Consider only the first paragraph.
|
| 59 |
+
paragraph_end = help.find("\n\n")
|
| 60 |
+
|
| 61 |
+
if paragraph_end != -1:
|
| 62 |
+
help = help[:paragraph_end]
|
| 63 |
+
|
| 64 |
+
# Collapse newlines, tabs, and spaces.
|
| 65 |
+
words = help.split()
|
| 66 |
+
|
| 67 |
+
if not words:
|
| 68 |
+
return ""
|
| 69 |
+
|
| 70 |
+
# The first paragraph started with a "no rewrap" marker, ignore it.
|
| 71 |
+
if words[0] == "\b":
|
| 72 |
+
words = words[1:]
|
| 73 |
+
|
| 74 |
+
total_length = 0
|
| 75 |
+
last_index = len(words) - 1
|
| 76 |
+
|
| 77 |
+
for i, word in enumerate(words):
|
| 78 |
+
total_length += len(word) + (i > 0)
|
| 79 |
+
|
| 80 |
+
if total_length > max_length: # too long, truncate
|
| 81 |
+
break
|
| 82 |
+
|
| 83 |
+
if word[-1] == ".": # sentence end, truncate without "..."
|
| 84 |
+
return " ".join(words[: i + 1])
|
| 85 |
+
|
| 86 |
+
if total_length == max_length and i != last_index:
|
| 87 |
+
break # not at sentence end, truncate with "..."
|
| 88 |
+
else:
|
| 89 |
+
return " ".join(words) # no truncation needed
|
| 90 |
+
|
| 91 |
+
# Account for the length of the suffix.
|
| 92 |
+
total_length += len("...")
|
| 93 |
+
|
| 94 |
+
# remove words until the length is short enough
|
| 95 |
+
while i > 0:
|
| 96 |
+
total_length -= len(words[i]) + (i > 0)
|
| 97 |
+
|
| 98 |
+
if total_length <= max_length:
|
| 99 |
+
break
|
| 100 |
+
|
| 101 |
+
i -= 1
|
| 102 |
+
|
| 103 |
+
return " ".join(words[:i]) + "..."
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class LazyFile:
|
| 107 |
+
"""A lazy file works like a regular file but it does not fully open
|
| 108 |
+
the file but it does perform some basic checks early to see if the
|
| 109 |
+
filename parameter does make sense. This is useful for safely opening
|
| 110 |
+
files for writing.
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
def __init__(
|
| 114 |
+
self,
|
| 115 |
+
filename: t.Union[str, "os.PathLike[str]"],
|
| 116 |
+
mode: str = "r",
|
| 117 |
+
encoding: t.Optional[str] = None,
|
| 118 |
+
errors: t.Optional[str] = "strict",
|
| 119 |
+
atomic: bool = False,
|
| 120 |
+
):
|
| 121 |
+
self.name: str = os.fspath(filename)
|
| 122 |
+
self.mode = mode
|
| 123 |
+
self.encoding = encoding
|
| 124 |
+
self.errors = errors
|
| 125 |
+
self.atomic = atomic
|
| 126 |
+
self._f: t.Optional[t.IO[t.Any]]
|
| 127 |
+
self.should_close: bool
|
| 128 |
+
|
| 129 |
+
if self.name == "-":
|
| 130 |
+
self._f, self.should_close = open_stream(filename, mode, encoding, errors)
|
| 131 |
+
else:
|
| 132 |
+
if "r" in mode:
|
| 133 |
+
# Open and close the file in case we're opening it for
|
| 134 |
+
# reading so that we can catch at least some errors in
|
| 135 |
+
# some cases early.
|
| 136 |
+
open(filename, mode).close()
|
| 137 |
+
self._f = None
|
| 138 |
+
self.should_close = True
|
| 139 |
+
|
| 140 |
+
def __getattr__(self, name: str) -> t.Any:
|
| 141 |
+
return getattr(self.open(), name)
|
| 142 |
+
|
| 143 |
+
def __repr__(self) -> str:
|
| 144 |
+
if self._f is not None:
|
| 145 |
+
return repr(self._f)
|
| 146 |
+
return f"<unopened file '{format_filename(self.name)}' {self.mode}>"
|
| 147 |
+
|
| 148 |
+
def open(self) -> t.IO[t.Any]:
|
| 149 |
+
"""Opens the file if it's not yet open. This call might fail with
|
| 150 |
+
a :exc:`FileError`. Not handling this error will produce an error
|
| 151 |
+
that Click shows.
|
| 152 |
+
"""
|
| 153 |
+
if self._f is not None:
|
| 154 |
+
return self._f
|
| 155 |
+
try:
|
| 156 |
+
rv, self.should_close = open_stream(
|
| 157 |
+
self.name, self.mode, self.encoding, self.errors, atomic=self.atomic
|
| 158 |
+
)
|
| 159 |
+
except OSError as e:
|
| 160 |
+
from .exceptions import FileError
|
| 161 |
+
|
| 162 |
+
raise FileError(self.name, hint=e.strerror) from e
|
| 163 |
+
self._f = rv
|
| 164 |
+
return rv
|
| 165 |
+
|
| 166 |
+
def close(self) -> None:
|
| 167 |
+
"""Closes the underlying file, no matter what."""
|
| 168 |
+
if self._f is not None:
|
| 169 |
+
self._f.close()
|
| 170 |
+
|
| 171 |
+
def close_intelligently(self) -> None:
|
| 172 |
+
"""This function only closes the file if it was opened by the lazy
|
| 173 |
+
file wrapper. For instance this will never close stdin.
|
| 174 |
+
"""
|
| 175 |
+
if self.should_close:
|
| 176 |
+
self.close()
|
| 177 |
+
|
| 178 |
+
def __enter__(self) -> "LazyFile":
|
| 179 |
+
return self
|
| 180 |
+
|
| 181 |
+
def __exit__(
|
| 182 |
+
self,
|
| 183 |
+
exc_type: t.Optional[t.Type[BaseException]],
|
| 184 |
+
exc_value: t.Optional[BaseException],
|
| 185 |
+
tb: t.Optional[TracebackType],
|
| 186 |
+
) -> None:
|
| 187 |
+
self.close_intelligently()
|
| 188 |
+
|
| 189 |
+
def __iter__(self) -> t.Iterator[t.AnyStr]:
|
| 190 |
+
self.open()
|
| 191 |
+
return iter(self._f) # type: ignore
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class KeepOpenFile:
|
| 195 |
+
def __init__(self, file: t.IO[t.Any]) -> None:
|
| 196 |
+
self._file: t.IO[t.Any] = file
|
| 197 |
+
|
| 198 |
+
def __getattr__(self, name: str) -> t.Any:
|
| 199 |
+
return getattr(self._file, name)
|
| 200 |
+
|
| 201 |
+
def __enter__(self) -> "KeepOpenFile":
|
| 202 |
+
return self
|
| 203 |
+
|
| 204 |
+
def __exit__(
|
| 205 |
+
self,
|
| 206 |
+
exc_type: t.Optional[t.Type[BaseException]],
|
| 207 |
+
exc_value: t.Optional[BaseException],
|
| 208 |
+
tb: t.Optional[TracebackType],
|
| 209 |
+
) -> None:
|
| 210 |
+
pass
|
| 211 |
+
|
| 212 |
+
def __repr__(self) -> str:
|
| 213 |
+
return repr(self._file)
|
| 214 |
+
|
| 215 |
+
def __iter__(self) -> t.Iterator[t.AnyStr]:
|
| 216 |
+
return iter(self._file)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def echo(
|
| 220 |
+
message: t.Optional[t.Any] = None,
|
| 221 |
+
file: t.Optional[t.IO[t.Any]] = None,
|
| 222 |
+
nl: bool = True,
|
| 223 |
+
err: bool = False,
|
| 224 |
+
color: t.Optional[bool] = None,
|
| 225 |
+
) -> None:
|
| 226 |
+
"""Print a message and newline to stdout or a file. This should be
|
| 227 |
+
used instead of :func:`print` because it provides better support
|
| 228 |
+
for different data, files, and environments.
|
| 229 |
+
|
| 230 |
+
Compared to :func:`print`, this does the following:
|
| 231 |
+
|
| 232 |
+
- Ensures that the output encoding is not misconfigured on Linux.
|
| 233 |
+
- Supports Unicode in the Windows console.
|
| 234 |
+
- Supports writing to binary outputs, and supports writing bytes
|
| 235 |
+
to text outputs.
|
| 236 |
+
- Supports colors and styles on Windows.
|
| 237 |
+
- Removes ANSI color and style codes if the output does not look
|
| 238 |
+
like an interactive terminal.
|
| 239 |
+
- Always flushes the output.
|
| 240 |
+
|
| 241 |
+
:param message: The string or bytes to output. Other objects are
|
| 242 |
+
converted to strings.
|
| 243 |
+
:param file: The file to write to. Defaults to ``stdout``.
|
| 244 |
+
:param err: Write to ``stderr`` instead of ``stdout``.
|
| 245 |
+
:param nl: Print a newline after the message. Enabled by default.
|
| 246 |
+
:param color: Force showing or hiding colors and other styles. By
|
| 247 |
+
default Click will remove color if the output does not look like
|
| 248 |
+
an interactive terminal.
|
| 249 |
+
|
| 250 |
+
.. versionchanged:: 6.0
|
| 251 |
+
Support Unicode output on the Windows console. Click does not
|
| 252 |
+
modify ``sys.stdout``, so ``sys.stdout.write()`` and ``print()``
|
| 253 |
+
will still not support Unicode.
|
| 254 |
+
|
| 255 |
+
.. versionchanged:: 4.0
|
| 256 |
+
Added the ``color`` parameter.
|
| 257 |
+
|
| 258 |
+
.. versionadded:: 3.0
|
| 259 |
+
Added the ``err`` parameter.
|
| 260 |
+
|
| 261 |
+
.. versionchanged:: 2.0
|
| 262 |
+
Support colors on Windows if colorama is installed.
|
| 263 |
+
"""
|
| 264 |
+
if file is None:
|
| 265 |
+
if err:
|
| 266 |
+
file = _default_text_stderr()
|
| 267 |
+
else:
|
| 268 |
+
file = _default_text_stdout()
|
| 269 |
+
|
| 270 |
+
# There are no standard streams attached to write to. For example,
|
| 271 |
+
# pythonw on Windows.
|
| 272 |
+
if file is None:
|
| 273 |
+
return
|
| 274 |
+
|
| 275 |
+
# Convert non bytes/text into the native string type.
|
| 276 |
+
if message is not None and not isinstance(message, (str, bytes, bytearray)):
|
| 277 |
+
out: t.Optional[t.Union[str, bytes]] = str(message)
|
| 278 |
+
else:
|
| 279 |
+
out = message
|
| 280 |
+
|
| 281 |
+
if nl:
|
| 282 |
+
out = out or ""
|
| 283 |
+
if isinstance(out, str):
|
| 284 |
+
out += "\n"
|
| 285 |
+
else:
|
| 286 |
+
out += b"\n"
|
| 287 |
+
|
| 288 |
+
if not out:
|
| 289 |
+
file.flush()
|
| 290 |
+
return
|
| 291 |
+
|
| 292 |
+
# If there is a message and the value looks like bytes, we manually
|
| 293 |
+
# need to find the binary stream and write the message in there.
|
| 294 |
+
# This is done separately so that most stream types will work as you
|
| 295 |
+
# would expect. Eg: you can write to StringIO for other cases.
|
| 296 |
+
if isinstance(out, (bytes, bytearray)):
|
| 297 |
+
binary_file = _find_binary_writer(file)
|
| 298 |
+
|
| 299 |
+
if binary_file is not None:
|
| 300 |
+
file.flush()
|
| 301 |
+
binary_file.write(out)
|
| 302 |
+
binary_file.flush()
|
| 303 |
+
return
|
| 304 |
+
|
| 305 |
+
# ANSI style code support. For no message or bytes, nothing happens.
|
| 306 |
+
# When outputting to a file instead of a terminal, strip codes.
|
| 307 |
+
else:
|
| 308 |
+
color = resolve_color_default(color)
|
| 309 |
+
|
| 310 |
+
if should_strip_ansi(file, color):
|
| 311 |
+
out = strip_ansi(out)
|
| 312 |
+
elif WIN:
|
| 313 |
+
if auto_wrap_for_ansi is not None:
|
| 314 |
+
file = auto_wrap_for_ansi(file, color) # type: ignore
|
| 315 |
+
elif not color:
|
| 316 |
+
out = strip_ansi(out)
|
| 317 |
+
|
| 318 |
+
file.write(out) # type: ignore
|
| 319 |
+
file.flush()
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def get_binary_stream(name: "te.Literal['stdin', 'stdout', 'stderr']") -> t.BinaryIO:
|
| 323 |
+
"""Returns a system stream for byte processing.
|
| 324 |
+
|
| 325 |
+
:param name: the name of the stream to open. Valid names are ``'stdin'``,
|
| 326 |
+
``'stdout'`` and ``'stderr'``
|
| 327 |
+
"""
|
| 328 |
+
opener = binary_streams.get(name)
|
| 329 |
+
if opener is None:
|
| 330 |
+
raise TypeError(f"Unknown standard stream '{name}'")
|
| 331 |
+
return opener()
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def get_text_stream(
|
| 335 |
+
name: "te.Literal['stdin', 'stdout', 'stderr']",
|
| 336 |
+
encoding: t.Optional[str] = None,
|
| 337 |
+
errors: t.Optional[str] = "strict",
|
| 338 |
+
) -> t.TextIO:
|
| 339 |
+
"""Returns a system stream for text processing. This usually returns
|
| 340 |
+
a wrapped stream around a binary stream returned from
|
| 341 |
+
:func:`get_binary_stream` but it also can take shortcuts for already
|
| 342 |
+
correctly configured streams.
|
| 343 |
+
|
| 344 |
+
:param name: the name of the stream to open. Valid names are ``'stdin'``,
|
| 345 |
+
``'stdout'`` and ``'stderr'``
|
| 346 |
+
:param encoding: overrides the detected default encoding.
|
| 347 |
+
:param errors: overrides the default error mode.
|
| 348 |
+
"""
|
| 349 |
+
opener = text_streams.get(name)
|
| 350 |
+
if opener is None:
|
| 351 |
+
raise TypeError(f"Unknown standard stream '{name}'")
|
| 352 |
+
return opener(encoding, errors)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def open_file(
|
| 356 |
+
filename: t.Union[str, "os.PathLike[str]"],
|
| 357 |
+
mode: str = "r",
|
| 358 |
+
encoding: t.Optional[str] = None,
|
| 359 |
+
errors: t.Optional[str] = "strict",
|
| 360 |
+
lazy: bool = False,
|
| 361 |
+
atomic: bool = False,
|
| 362 |
+
) -> t.IO[t.Any]:
|
| 363 |
+
"""Open a file, with extra behavior to handle ``'-'`` to indicate
|
| 364 |
+
a standard stream, lazy open on write, and atomic write. Similar to
|
| 365 |
+
the behavior of the :class:`~click.File` param type.
|
| 366 |
+
|
| 367 |
+
If ``'-'`` is given to open ``stdout`` or ``stdin``, the stream is
|
| 368 |
+
wrapped so that using it in a context manager will not close it.
|
| 369 |
+
This makes it possible to use the function without accidentally
|
| 370 |
+
closing a standard stream:
|
| 371 |
+
|
| 372 |
+
.. code-block:: python
|
| 373 |
+
|
| 374 |
+
with open_file(filename) as f:
|
| 375 |
+
...
|
| 376 |
+
|
| 377 |
+
:param filename: The name or Path of the file to open, or ``'-'`` for
|
| 378 |
+
``stdin``/``stdout``.
|
| 379 |
+
:param mode: The mode in which to open the file.
|
| 380 |
+
:param encoding: The encoding to decode or encode a file opened in
|
| 381 |
+
text mode.
|
| 382 |
+
:param errors: The error handling mode.
|
| 383 |
+
:param lazy: Wait to open the file until it is accessed. For read
|
| 384 |
+
mode, the file is temporarily opened to raise access errors
|
| 385 |
+
early, then closed until it is read again.
|
| 386 |
+
:param atomic: Write to a temporary file and replace the given file
|
| 387 |
+
on close.
|
| 388 |
+
|
| 389 |
+
.. versionadded:: 3.0
|
| 390 |
+
"""
|
| 391 |
+
if lazy:
|
| 392 |
+
return t.cast(
|
| 393 |
+
t.IO[t.Any], LazyFile(filename, mode, encoding, errors, atomic=atomic)
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
f, should_close = open_stream(filename, mode, encoding, errors, atomic=atomic)
|
| 397 |
+
|
| 398 |
+
if not should_close:
|
| 399 |
+
f = t.cast(t.IO[t.Any], KeepOpenFile(f))
|
| 400 |
+
|
| 401 |
+
return f
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def format_filename(
|
| 405 |
+
filename: "t.Union[str, bytes, os.PathLike[str], os.PathLike[bytes]]",
|
| 406 |
+
shorten: bool = False,
|
| 407 |
+
) -> str:
|
| 408 |
+
"""Format a filename as a string for display. Ensures the filename can be
|
| 409 |
+
displayed by replacing any invalid bytes or surrogate escapes in the name
|
| 410 |
+
with the replacement character ``�``.
|
| 411 |
+
|
| 412 |
+
Invalid bytes or surrogate escapes will raise an error when written to a
|
| 413 |
+
stream with ``errors="strict"``. This will typically happen with ``stdout``
|
| 414 |
+
when the locale is something like ``en_GB.UTF-8``.
|
| 415 |
+
|
| 416 |
+
Many scenarios *are* safe to write surrogates though, due to PEP 538 and
|
| 417 |
+
PEP 540, including:
|
| 418 |
+
|
| 419 |
+
- Writing to ``stderr``, which uses ``errors="backslashreplace"``.
|
| 420 |
+
- The system has ``LANG=C.UTF-8``, ``C``, or ``POSIX``. Python opens
|
| 421 |
+
stdout and stderr with ``errors="surrogateescape"``.
|
| 422 |
+
- None of ``LANG/LC_*`` are set. Python assumes ``LANG=C.UTF-8``.
|
| 423 |
+
- Python is started in UTF-8 mode with ``PYTHONUTF8=1`` or ``-X utf8``.
|
| 424 |
+
Python opens stdout and stderr with ``errors="surrogateescape"``.
|
| 425 |
+
|
| 426 |
+
:param filename: formats a filename for UI display. This will also convert
|
| 427 |
+
the filename into unicode without failing.
|
| 428 |
+
:param shorten: this optionally shortens the filename to strip of the
|
| 429 |
+
path that leads up to it.
|
| 430 |
+
"""
|
| 431 |
+
if shorten:
|
| 432 |
+
filename = os.path.basename(filename)
|
| 433 |
+
else:
|
| 434 |
+
filename = os.fspath(filename)
|
| 435 |
+
|
| 436 |
+
if isinstance(filename, bytes):
|
| 437 |
+
filename = filename.decode(sys.getfilesystemencoding(), "replace")
|
| 438 |
+
else:
|
| 439 |
+
filename = filename.encode("utf-8", "surrogateescape").decode(
|
| 440 |
+
"utf-8", "replace"
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
return filename
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def get_app_dir(app_name: str, roaming: bool = True, force_posix: bool = False) -> str:
|
| 447 |
+
r"""Returns the config folder for the application. The default behavior
|
| 448 |
+
is to return whatever is most appropriate for the operating system.
|
| 449 |
+
|
| 450 |
+
To give you an idea, for an app called ``"Foo Bar"``, something like
|
| 451 |
+
the following folders could be returned:
|
| 452 |
+
|
| 453 |
+
Mac OS X:
|
| 454 |
+
``~/Library/Application Support/Foo Bar``
|
| 455 |
+
Mac OS X (POSIX):
|
| 456 |
+
``~/.foo-bar``
|
| 457 |
+
Unix:
|
| 458 |
+
``~/.config/foo-bar``
|
| 459 |
+
Unix (POSIX):
|
| 460 |
+
``~/.foo-bar``
|
| 461 |
+
Windows (roaming):
|
| 462 |
+
``C:\Users\<user>\AppData\Roaming\Foo Bar``
|
| 463 |
+
Windows (not roaming):
|
| 464 |
+
``C:\Users\<user>\AppData\Local\Foo Bar``
|
| 465 |
+
|
| 466 |
+
.. versionadded:: 2.0
|
| 467 |
+
|
| 468 |
+
:param app_name: the application name. This should be properly capitalized
|
| 469 |
+
and can contain whitespace.
|
| 470 |
+
:param roaming: controls if the folder should be roaming or not on Windows.
|
| 471 |
+
Has no effect otherwise.
|
| 472 |
+
:param force_posix: if this is set to `True` then on any POSIX system the
|
| 473 |
+
folder will be stored in the home folder with a leading
|
| 474 |
+
dot instead of the XDG config home or darwin's
|
| 475 |
+
application support folder.
|
| 476 |
+
"""
|
| 477 |
+
if WIN:
|
| 478 |
+
key = "APPDATA" if roaming else "LOCALAPPDATA"
|
| 479 |
+
folder = os.environ.get(key)
|
| 480 |
+
if folder is None:
|
| 481 |
+
folder = os.path.expanduser("~")
|
| 482 |
+
return os.path.join(folder, app_name)
|
| 483 |
+
if force_posix:
|
| 484 |
+
return os.path.join(os.path.expanduser(f"~/.{_posixify(app_name)}"))
|
| 485 |
+
if sys.platform == "darwin":
|
| 486 |
+
return os.path.join(
|
| 487 |
+
os.path.expanduser("~/Library/Application Support"), app_name
|
| 488 |
+
)
|
| 489 |
+
return os.path.join(
|
| 490 |
+
os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")),
|
| 491 |
+
_posixify(app_name),
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
class PacifyFlushWrapper:
|
| 496 |
+
"""This wrapper is used to catch and suppress BrokenPipeErrors resulting
|
| 497 |
+
from ``.flush()`` being called on broken pipe during the shutdown/final-GC
|
| 498 |
+
of the Python interpreter. Notably ``.flush()`` is always called on
|
| 499 |
+
``sys.stdout`` and ``sys.stderr``. So as to have minimal impact on any
|
| 500 |
+
other cleanup code, and the case where the underlying file is not a broken
|
| 501 |
+
pipe, all calls and attributes are proxied.
|
| 502 |
+
"""
|
| 503 |
+
|
| 504 |
+
def __init__(self, wrapped: t.IO[t.Any]) -> None:
|
| 505 |
+
self.wrapped = wrapped
|
| 506 |
+
|
| 507 |
+
def flush(self) -> None:
|
| 508 |
+
try:
|
| 509 |
+
self.wrapped.flush()
|
| 510 |
+
except OSError as e:
|
| 511 |
+
import errno
|
| 512 |
+
|
| 513 |
+
if e.errno != errno.EPIPE:
|
| 514 |
+
raise
|
| 515 |
+
|
| 516 |
+
def __getattr__(self, attr: str) -> t.Any:
|
| 517 |
+
return getattr(self.wrapped, attr)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def _detect_program_name(
|
| 521 |
+
path: t.Optional[str] = None, _main: t.Optional[ModuleType] = None
|
| 522 |
+
) -> str:
|
| 523 |
+
"""Determine the command used to run the program, for use in help
|
| 524 |
+
text. If a file or entry point was executed, the file name is
|
| 525 |
+
returned. If ``python -m`` was used to execute a module or package,
|
| 526 |
+
``python -m name`` is returned.
|
| 527 |
+
|
| 528 |
+
This doesn't try to be too precise, the goal is to give a concise
|
| 529 |
+
name for help text. Files are only shown as their name without the
|
| 530 |
+
path. ``python`` is only shown for modules, and the full path to
|
| 531 |
+
``sys.executable`` is not shown.
|
| 532 |
+
|
| 533 |
+
:param path: The Python file being executed. Python puts this in
|
| 534 |
+
``sys.argv[0]``, which is used by default.
|
| 535 |
+
:param _main: The ``__main__`` module. This should only be passed
|
| 536 |
+
during internal testing.
|
| 537 |
+
|
| 538 |
+
.. versionadded:: 8.0
|
| 539 |
+
Based on command args detection in the Werkzeug reloader.
|
| 540 |
+
|
| 541 |
+
:meta private:
|
| 542 |
+
"""
|
| 543 |
+
if _main is None:
|
| 544 |
+
_main = sys.modules["__main__"]
|
| 545 |
+
|
| 546 |
+
if not path:
|
| 547 |
+
path = sys.argv[0]
|
| 548 |
+
|
| 549 |
+
# The value of __package__ indicates how Python was called. It may
|
| 550 |
+
# not exist if a setuptools script is installed as an egg. It may be
|
| 551 |
+
# set incorrectly for entry points created with pip on Windows.
|
| 552 |
+
# It is set to "" inside a Shiv or PEX zipapp.
|
| 553 |
+
if getattr(_main, "__package__", None) in {None, ""} or (
|
| 554 |
+
os.name == "nt"
|
| 555 |
+
and _main.__package__ == ""
|
| 556 |
+
and not os.path.exists(path)
|
| 557 |
+
and os.path.exists(f"{path}.exe")
|
| 558 |
+
):
|
| 559 |
+
# Executed a file, like "python app.py".
|
| 560 |
+
return os.path.basename(path)
|
| 561 |
+
|
| 562 |
+
# Executed a module, like "python -m example".
|
| 563 |
+
# Rewritten by Python from "-m script" to "/path/to/script.py".
|
| 564 |
+
# Need to look at main module to determine how it was executed.
|
| 565 |
+
py_module = t.cast(str, _main.__package__)
|
| 566 |
+
name = os.path.splitext(os.path.basename(path))[0]
|
| 567 |
+
|
| 568 |
+
# A submodule like "example.cli".
|
| 569 |
+
if name != "__main__":
|
| 570 |
+
py_module = f"{py_module}.{name}"
|
| 571 |
+
|
| 572 |
+
return f"python -m {py_module.lstrip('.')}"
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def _expand_args(
|
| 576 |
+
args: t.Iterable[str],
|
| 577 |
+
*,
|
| 578 |
+
user: bool = True,
|
| 579 |
+
env: bool = True,
|
| 580 |
+
glob_recursive: bool = True,
|
| 581 |
+
) -> t.List[str]:
|
| 582 |
+
"""Simulate Unix shell expansion with Python functions.
|
| 583 |
+
|
| 584 |
+
See :func:`glob.glob`, :func:`os.path.expanduser`, and
|
| 585 |
+
:func:`os.path.expandvars`.
|
| 586 |
+
|
| 587 |
+
This is intended for use on Windows, where the shell does not do any
|
| 588 |
+
expansion. It may not exactly match what a Unix shell would do.
|
| 589 |
+
|
| 590 |
+
:param args: List of command line arguments to expand.
|
| 591 |
+
:param user: Expand user home directory.
|
| 592 |
+
:param env: Expand environment variables.
|
| 593 |
+
:param glob_recursive: ``**`` matches directories recursively.
|
| 594 |
+
|
| 595 |
+
.. versionchanged:: 8.1
|
| 596 |
+
Invalid glob patterns are treated as empty expansions rather
|
| 597 |
+
than raising an error.
|
| 598 |
+
|
| 599 |
+
.. versionadded:: 8.0
|
| 600 |
+
|
| 601 |
+
:meta private:
|
| 602 |
+
"""
|
| 603 |
+
from glob import glob
|
| 604 |
+
|
| 605 |
+
out = []
|
| 606 |
+
|
| 607 |
+
for arg in args:
|
| 608 |
+
if user:
|
| 609 |
+
arg = os.path.expanduser(arg)
|
| 610 |
+
|
| 611 |
+
if env:
|
| 612 |
+
arg = os.path.expandvars(arg)
|
| 613 |
+
|
| 614 |
+
try:
|
| 615 |
+
matches = glob(arg, recursive=glob_recursive)
|
| 616 |
+
except re.error:
|
| 617 |
+
matches = []
|
| 618 |
+
|
| 619 |
+
if not matches:
|
| 620 |
+
out.append(arg)
|
| 621 |
+
else:
|
| 622 |
+
out.extend(matches)
|
| 623 |
+
|
| 624 |
+
return out
|
.venv/lib/python3.11/site-packages/httplib2/__init__.py
ADDED
|
@@ -0,0 +1,1799 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""Small, fast HTTP client library for Python."""
|
| 3 |
+
|
| 4 |
+
__author__ = "Joe Gregorio (joe@bitworking.org)"
|
| 5 |
+
__copyright__ = "Copyright 2006, Joe Gregorio"
|
| 6 |
+
__contributors__ = [
|
| 7 |
+
"Thomas Broyer (t.broyer@ltgt.net)",
|
| 8 |
+
"James Antill",
|
| 9 |
+
"Xavier Verges Farrero",
|
| 10 |
+
"Jonathan Feinberg",
|
| 11 |
+
"Blair Zajac",
|
| 12 |
+
"Sam Ruby",
|
| 13 |
+
"Louis Nyffenegger",
|
| 14 |
+
"Mark Pilgrim",
|
| 15 |
+
"Alex Yu",
|
| 16 |
+
"Lai Han",
|
| 17 |
+
]
|
| 18 |
+
__license__ = "MIT"
|
| 19 |
+
__version__ = "0.22.0"
|
| 20 |
+
|
| 21 |
+
import base64
|
| 22 |
+
import calendar
|
| 23 |
+
import copy
|
| 24 |
+
import email
|
| 25 |
+
import email.feedparser
|
| 26 |
+
from email import header
|
| 27 |
+
import email.message
|
| 28 |
+
import email.utils
|
| 29 |
+
import errno
|
| 30 |
+
from gettext import gettext as _
|
| 31 |
+
import gzip
|
| 32 |
+
from hashlib import md5 as _md5
|
| 33 |
+
from hashlib import sha1 as _sha
|
| 34 |
+
import hmac
|
| 35 |
+
import http.client
|
| 36 |
+
import io
|
| 37 |
+
import os
|
| 38 |
+
import random
|
| 39 |
+
import re
|
| 40 |
+
import socket
|
| 41 |
+
import ssl
|
| 42 |
+
import sys
|
| 43 |
+
import time
|
| 44 |
+
import urllib.parse
|
| 45 |
+
import zlib
|
| 46 |
+
|
| 47 |
+
try:
|
| 48 |
+
import socks
|
| 49 |
+
except ImportError:
|
| 50 |
+
# TODO: remove this fallback and copypasted socksipy module upon py2/3 merge,
|
| 51 |
+
# idea is to have soft-dependency on any compatible module called socks
|
| 52 |
+
from . import socks
|
| 53 |
+
from . import auth
|
| 54 |
+
from .error import *
|
| 55 |
+
from .iri2uri import iri2uri
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def has_timeout(timeout):
|
| 59 |
+
if hasattr(socket, "_GLOBAL_DEFAULT_TIMEOUT"):
|
| 60 |
+
return timeout is not None and timeout is not socket._GLOBAL_DEFAULT_TIMEOUT
|
| 61 |
+
return timeout is not None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
__all__ = [
|
| 65 |
+
"debuglevel",
|
| 66 |
+
"FailedToDecompressContent",
|
| 67 |
+
"Http",
|
| 68 |
+
"HttpLib2Error",
|
| 69 |
+
"ProxyInfo",
|
| 70 |
+
"RedirectLimit",
|
| 71 |
+
"RedirectMissingLocation",
|
| 72 |
+
"Response",
|
| 73 |
+
"RETRIES",
|
| 74 |
+
"UnimplementedDigestAuthOptionError",
|
| 75 |
+
"UnimplementedHmacDigestAuthOptionError",
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
# The httplib debug level, set to a non-zero value to get debug output
|
| 79 |
+
debuglevel = 0
|
| 80 |
+
|
| 81 |
+
# A request will be tried 'RETRIES' times if it fails at the socket/connection level.
|
| 82 |
+
RETRIES = 2
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# Open Items:
|
| 86 |
+
# -----------
|
| 87 |
+
|
| 88 |
+
# Are we removing the cached content too soon on PUT (only delete on 200 Maybe?)
|
| 89 |
+
|
| 90 |
+
# Pluggable cache storage (supports storing the cache in
|
| 91 |
+
# flat files by default. We need a plug-in architecture
|
| 92 |
+
# that can support Berkeley DB and Squid)
|
| 93 |
+
|
| 94 |
+
# == Known Issues ==
|
| 95 |
+
# Does not handle a resource that uses conneg and Last-Modified but no ETag as a cache validator.
|
| 96 |
+
# Does not handle Cache-Control: max-stale
|
| 97 |
+
# Does not use Age: headers when calculating cache freshness.
|
| 98 |
+
|
| 99 |
+
# The number of redirections to follow before giving up.
|
| 100 |
+
# Note that only GET redirects are automatically followed.
|
| 101 |
+
# Will also honor 301 requests by saving that info and never
|
| 102 |
+
# requesting that URI again.
|
| 103 |
+
DEFAULT_MAX_REDIRECTS = 5
|
| 104 |
+
|
| 105 |
+
# Which headers are hop-by-hop headers by default
|
| 106 |
+
HOP_BY_HOP = [
|
| 107 |
+
"connection",
|
| 108 |
+
"keep-alive",
|
| 109 |
+
"proxy-authenticate",
|
| 110 |
+
"proxy-authorization",
|
| 111 |
+
"te",
|
| 112 |
+
"trailers",
|
| 113 |
+
"transfer-encoding",
|
| 114 |
+
"upgrade",
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
# https://tools.ietf.org/html/rfc7231#section-8.1.3
|
| 118 |
+
SAFE_METHODS = ("GET", "HEAD", "OPTIONS", "TRACE")
|
| 119 |
+
|
| 120 |
+
# To change, assign to `Http().redirect_codes`
|
| 121 |
+
REDIRECT_CODES = frozenset((300, 301, 302, 303, 307, 308))
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
from httplib2 import certs
|
| 125 |
+
|
| 126 |
+
CA_CERTS = certs.where()
|
| 127 |
+
|
| 128 |
+
# PROTOCOL_TLS is python 3.5.3+. PROTOCOL_SSLv23 is deprecated.
|
| 129 |
+
# Both PROTOCOL_TLS and PROTOCOL_SSLv23 are equivalent and means:
|
| 130 |
+
# > Selects the highest protocol version that both the client and server support.
|
| 131 |
+
# > Despite the name, this option can select “TLS” protocols as well as “SSL”.
|
| 132 |
+
# source: https://docs.python.org/3.5/library/ssl.html#ssl.PROTOCOL_SSLv23
|
| 133 |
+
|
| 134 |
+
# PROTOCOL_TLS_CLIENT is python 3.10.0+. PROTOCOL_TLS is deprecated.
|
| 135 |
+
# > Auto-negotiate the highest protocol version that both the client and server support, and configure the context client-side connections.
|
| 136 |
+
# > The protocol enables CERT_REQUIRED and check_hostname by default.
|
| 137 |
+
# source: https://docs.python.org/3.10/library/ssl.html#ssl.PROTOCOL_TLS
|
| 138 |
+
|
| 139 |
+
DEFAULT_TLS_VERSION = getattr(ssl, "PROTOCOL_TLS_CLIENT", None) or getattr(ssl, "PROTOCOL_TLS", None) or getattr(ssl, "PROTOCOL_SSLv23")
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _build_ssl_context(
|
| 143 |
+
disable_ssl_certificate_validation,
|
| 144 |
+
ca_certs,
|
| 145 |
+
cert_file=None,
|
| 146 |
+
key_file=None,
|
| 147 |
+
maximum_version=None,
|
| 148 |
+
minimum_version=None,
|
| 149 |
+
key_password=None,
|
| 150 |
+
):
|
| 151 |
+
if not hasattr(ssl, "SSLContext"):
|
| 152 |
+
raise RuntimeError("httplib2 requires Python 3.2+ for ssl.SSLContext")
|
| 153 |
+
|
| 154 |
+
context = ssl.SSLContext(DEFAULT_TLS_VERSION)
|
| 155 |
+
# check_hostname and verify_mode should be set in opposite order during disable
|
| 156 |
+
# https://bugs.python.org/issue31431
|
| 157 |
+
if disable_ssl_certificate_validation and hasattr(context, "check_hostname"):
|
| 158 |
+
context.check_hostname = not disable_ssl_certificate_validation
|
| 159 |
+
context.verify_mode = ssl.CERT_NONE if disable_ssl_certificate_validation else ssl.CERT_REQUIRED
|
| 160 |
+
|
| 161 |
+
# SSLContext.maximum_version and SSLContext.minimum_version are python 3.7+.
|
| 162 |
+
# source: https://docs.python.org/3/library/ssl.html#ssl.SSLContext.maximum_version
|
| 163 |
+
if maximum_version is not None:
|
| 164 |
+
if hasattr(context, "maximum_version"):
|
| 165 |
+
if isinstance(maximum_version, str):
|
| 166 |
+
maximum_version = getattr(ssl.TLSVersion, maximum_version)
|
| 167 |
+
context.maximum_version = maximum_version
|
| 168 |
+
else:
|
| 169 |
+
raise RuntimeError("setting tls_maximum_version requires Python 3.7 and OpenSSL 1.1 or newer")
|
| 170 |
+
if minimum_version is not None:
|
| 171 |
+
if hasattr(context, "minimum_version"):
|
| 172 |
+
if isinstance(minimum_version, str):
|
| 173 |
+
minimum_version = getattr(ssl.TLSVersion, minimum_version)
|
| 174 |
+
context.minimum_version = minimum_version
|
| 175 |
+
else:
|
| 176 |
+
raise RuntimeError("setting tls_minimum_version requires Python 3.7 and OpenSSL 1.1 or newer")
|
| 177 |
+
# check_hostname requires python 3.4+
|
| 178 |
+
# we will perform the equivalent in HTTPSConnectionWithTimeout.connect() by calling ssl.match_hostname
|
| 179 |
+
# if check_hostname is not supported.
|
| 180 |
+
if hasattr(context, "check_hostname"):
|
| 181 |
+
context.check_hostname = not disable_ssl_certificate_validation
|
| 182 |
+
|
| 183 |
+
context.load_verify_locations(ca_certs)
|
| 184 |
+
|
| 185 |
+
if cert_file:
|
| 186 |
+
context.load_cert_chain(cert_file, key_file, key_password)
|
| 187 |
+
|
| 188 |
+
return context
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _get_end2end_headers(response):
|
| 192 |
+
hopbyhop = list(HOP_BY_HOP)
|
| 193 |
+
hopbyhop.extend([x.strip() for x in response.get("connection", "").split(",")])
|
| 194 |
+
return [header for header in list(response.keys()) if header not in hopbyhop]
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
_missing = object()
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def _errno_from_exception(e):
|
| 201 |
+
# TODO python 3.11+ cheap try: return e.errno except AttributeError: pass
|
| 202 |
+
errno = getattr(e, "errno", _missing)
|
| 203 |
+
if errno is not _missing:
|
| 204 |
+
return errno
|
| 205 |
+
|
| 206 |
+
# socket.error and common wrap in .args
|
| 207 |
+
args = getattr(e, "args", None)
|
| 208 |
+
if args:
|
| 209 |
+
return _errno_from_exception(args[0])
|
| 210 |
+
|
| 211 |
+
# pysocks.ProxyError wraps in .socket_err
|
| 212 |
+
# https://github.com/httplib2/httplib2/pull/202
|
| 213 |
+
socket_err = getattr(e, "socket_err", None)
|
| 214 |
+
if socket_err:
|
| 215 |
+
return _errno_from_exception(socket_err)
|
| 216 |
+
|
| 217 |
+
return None
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
URI = re.compile(r"^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?")
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def parse_uri(uri):
|
| 224 |
+
"""Parses a URI using the regex given in Appendix B of RFC 3986.
|
| 225 |
+
|
| 226 |
+
(scheme, authority, path, query, fragment) = parse_uri(uri)
|
| 227 |
+
"""
|
| 228 |
+
groups = URI.match(uri).groups()
|
| 229 |
+
return (groups[1], groups[3], groups[4], groups[6], groups[8])
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def urlnorm(uri):
|
| 233 |
+
(scheme, authority, path, query, fragment) = parse_uri(uri)
|
| 234 |
+
if not scheme or not authority:
|
| 235 |
+
raise RelativeURIError("Only absolute URIs are allowed. uri = %s" % uri)
|
| 236 |
+
authority = authority.lower()
|
| 237 |
+
scheme = scheme.lower()
|
| 238 |
+
if not path:
|
| 239 |
+
path = "/"
|
| 240 |
+
# Could do syntax based normalization of the URI before
|
| 241 |
+
# computing the digest. See Section 6.2.2 of Std 66.
|
| 242 |
+
request_uri = query and "?".join([path, query]) or path
|
| 243 |
+
scheme = scheme.lower()
|
| 244 |
+
defrag_uri = scheme + "://" + authority + request_uri
|
| 245 |
+
return scheme, authority, request_uri, defrag_uri
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# Cache filename construction (original borrowed from Venus http://intertwingly.net/code/venus/)
|
| 249 |
+
re_url_scheme = re.compile(r"^\w+://")
|
| 250 |
+
re_unsafe = re.compile(r"[^\w\-_.()=!]+", re.ASCII)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def safename(filename):
|
| 254 |
+
"""Return a filename suitable for the cache.
|
| 255 |
+
Strips dangerous and common characters to create a filename we
|
| 256 |
+
can use to store the cache in.
|
| 257 |
+
"""
|
| 258 |
+
if isinstance(filename, bytes):
|
| 259 |
+
filename_bytes = filename
|
| 260 |
+
filename = filename.decode("utf-8")
|
| 261 |
+
else:
|
| 262 |
+
filename_bytes = filename.encode("utf-8")
|
| 263 |
+
filemd5 = _md5(filename_bytes).hexdigest()
|
| 264 |
+
filename = re_url_scheme.sub("", filename)
|
| 265 |
+
filename = re_unsafe.sub("", filename)
|
| 266 |
+
|
| 267 |
+
# limit length of filename (vital for Windows)
|
| 268 |
+
# https://github.com/httplib2/httplib2/pull/74
|
| 269 |
+
# C:\Users\ <username> \AppData\Local\Temp\ <safe_filename> , <md5>
|
| 270 |
+
# 9 chars + max 104 chars + 20 chars + x + 1 + 32 = max 259 chars
|
| 271 |
+
# Thus max safe filename x = 93 chars. Let it be 90 to make a round sum:
|
| 272 |
+
filename = filename[:90]
|
| 273 |
+
|
| 274 |
+
return ",".join((filename, filemd5))
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
NORMALIZE_SPACE = re.compile(r"(?:\r\n)?[ \t]+")
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def _normalize_headers(headers):
|
| 281 |
+
return dict(
|
| 282 |
+
[
|
| 283 |
+
(_convert_byte_str(key).lower(), NORMALIZE_SPACE.sub(_convert_byte_str(value), " ").strip(),)
|
| 284 |
+
for (key, value) in headers.items()
|
| 285 |
+
]
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def _convert_byte_str(s):
|
| 290 |
+
if not isinstance(s, str):
|
| 291 |
+
return str(s, "utf-8")
|
| 292 |
+
return s
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _parse_cache_control(headers):
|
| 296 |
+
retval = {}
|
| 297 |
+
if "cache-control" in headers:
|
| 298 |
+
parts = headers["cache-control"].split(",")
|
| 299 |
+
parts_with_args = [
|
| 300 |
+
tuple([x.strip().lower() for x in part.split("=", 1)]) for part in parts if -1 != part.find("=")
|
| 301 |
+
]
|
| 302 |
+
parts_wo_args = [(name.strip().lower(), 1) for name in parts if -1 == name.find("=")]
|
| 303 |
+
retval = dict(parts_with_args + parts_wo_args)
|
| 304 |
+
return retval
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# Whether to use a strict mode to parse WWW-Authenticate headers
|
| 308 |
+
# Might lead to bad results in case of ill-formed header value,
|
| 309 |
+
# so disabled by default, falling back to relaxed parsing.
|
| 310 |
+
# Set to true to turn on, useful for testing servers.
|
| 311 |
+
USE_WWW_AUTH_STRICT_PARSING = 0
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def _entry_disposition(response_headers, request_headers):
|
| 315 |
+
"""Determine freshness from the Date, Expires and Cache-Control headers.
|
| 316 |
+
|
| 317 |
+
We don't handle the following:
|
| 318 |
+
|
| 319 |
+
1. Cache-Control: max-stale
|
| 320 |
+
2. Age: headers are not used in the calculations.
|
| 321 |
+
|
| 322 |
+
Not that this algorithm is simpler than you might think
|
| 323 |
+
because we are operating as a private (non-shared) cache.
|
| 324 |
+
This lets us ignore 's-maxage'. We can also ignore
|
| 325 |
+
'proxy-invalidate' since we aren't a proxy.
|
| 326 |
+
We will never return a stale document as
|
| 327 |
+
fresh as a design decision, and thus the non-implementation
|
| 328 |
+
of 'max-stale'. This also lets us safely ignore 'must-revalidate'
|
| 329 |
+
since we operate as if every server has sent 'must-revalidate'.
|
| 330 |
+
Since we are private we get to ignore both 'public' and
|
| 331 |
+
'private' parameters. We also ignore 'no-transform' since
|
| 332 |
+
we don't do any transformations.
|
| 333 |
+
The 'no-store' parameter is handled at a higher level.
|
| 334 |
+
So the only Cache-Control parameters we look at are:
|
| 335 |
+
|
| 336 |
+
no-cache
|
| 337 |
+
only-if-cached
|
| 338 |
+
max-age
|
| 339 |
+
min-fresh
|
| 340 |
+
"""
|
| 341 |
+
|
| 342 |
+
retval = "STALE"
|
| 343 |
+
cc = _parse_cache_control(request_headers)
|
| 344 |
+
cc_response = _parse_cache_control(response_headers)
|
| 345 |
+
|
| 346 |
+
if "pragma" in request_headers and request_headers["pragma"].lower().find("no-cache") != -1:
|
| 347 |
+
retval = "TRANSPARENT"
|
| 348 |
+
if "cache-control" not in request_headers:
|
| 349 |
+
request_headers["cache-control"] = "no-cache"
|
| 350 |
+
elif "no-cache" in cc:
|
| 351 |
+
retval = "TRANSPARENT"
|
| 352 |
+
elif "no-cache" in cc_response:
|
| 353 |
+
retval = "STALE"
|
| 354 |
+
elif "only-if-cached" in cc:
|
| 355 |
+
retval = "FRESH"
|
| 356 |
+
elif "date" in response_headers:
|
| 357 |
+
date = calendar.timegm(email.utils.parsedate_tz(response_headers["date"]))
|
| 358 |
+
now = time.time()
|
| 359 |
+
current_age = max(0, now - date)
|
| 360 |
+
if "max-age" in cc_response:
|
| 361 |
+
try:
|
| 362 |
+
freshness_lifetime = int(cc_response["max-age"])
|
| 363 |
+
except ValueError:
|
| 364 |
+
freshness_lifetime = 0
|
| 365 |
+
elif "expires" in response_headers:
|
| 366 |
+
expires = email.utils.parsedate_tz(response_headers["expires"])
|
| 367 |
+
if None == expires:
|
| 368 |
+
freshness_lifetime = 0
|
| 369 |
+
else:
|
| 370 |
+
freshness_lifetime = max(0, calendar.timegm(expires) - date)
|
| 371 |
+
else:
|
| 372 |
+
freshness_lifetime = 0
|
| 373 |
+
if "max-age" in cc:
|
| 374 |
+
try:
|
| 375 |
+
freshness_lifetime = int(cc["max-age"])
|
| 376 |
+
except ValueError:
|
| 377 |
+
freshness_lifetime = 0
|
| 378 |
+
if "min-fresh" in cc:
|
| 379 |
+
try:
|
| 380 |
+
min_fresh = int(cc["min-fresh"])
|
| 381 |
+
except ValueError:
|
| 382 |
+
min_fresh = 0
|
| 383 |
+
current_age += min_fresh
|
| 384 |
+
if freshness_lifetime > current_age:
|
| 385 |
+
retval = "FRESH"
|
| 386 |
+
return retval
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def _decompressContent(response, new_content):
|
| 390 |
+
content = new_content
|
| 391 |
+
try:
|
| 392 |
+
encoding = response.get("content-encoding", None)
|
| 393 |
+
if encoding in ["gzip", "deflate"]:
|
| 394 |
+
if encoding == "gzip":
|
| 395 |
+
content = gzip.GzipFile(fileobj=io.BytesIO(new_content)).read()
|
| 396 |
+
if encoding == "deflate":
|
| 397 |
+
try:
|
| 398 |
+
content = zlib.decompress(content, zlib.MAX_WBITS)
|
| 399 |
+
except (IOError, zlib.error):
|
| 400 |
+
content = zlib.decompress(content, -zlib.MAX_WBITS)
|
| 401 |
+
response["content-length"] = str(len(content))
|
| 402 |
+
# Record the historical presence of the encoding in a way the won't interfere.
|
| 403 |
+
response["-content-encoding"] = response["content-encoding"]
|
| 404 |
+
del response["content-encoding"]
|
| 405 |
+
except (IOError, zlib.error):
|
| 406 |
+
content = ""
|
| 407 |
+
raise FailedToDecompressContent(
|
| 408 |
+
_("Content purported to be compressed with %s but failed to decompress.") % response.get("content-encoding"),
|
| 409 |
+
response,
|
| 410 |
+
content,
|
| 411 |
+
)
|
| 412 |
+
return content
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def _bind_write_headers(msg):
|
| 416 |
+
def _write_headers(self):
|
| 417 |
+
# Self refers to the Generator object.
|
| 418 |
+
for h, v in msg.items():
|
| 419 |
+
print("%s:" % h, end=" ", file=self._fp)
|
| 420 |
+
if isinstance(v, header.Header):
|
| 421 |
+
print(v.encode(maxlinelen=self._maxheaderlen), file=self._fp)
|
| 422 |
+
else:
|
| 423 |
+
# email.Header got lots of smarts, so use it.
|
| 424 |
+
headers = header.Header(v, maxlinelen=self._maxheaderlen, charset="utf-8", header_name=h)
|
| 425 |
+
print(headers.encode(), file=self._fp)
|
| 426 |
+
# A blank line always separates headers from body.
|
| 427 |
+
print(file=self._fp)
|
| 428 |
+
|
| 429 |
+
return _write_headers
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def _updateCache(request_headers, response_headers, content, cache, cachekey):
|
| 433 |
+
if cachekey:
|
| 434 |
+
cc = _parse_cache_control(request_headers)
|
| 435 |
+
cc_response = _parse_cache_control(response_headers)
|
| 436 |
+
if "no-store" in cc or "no-store" in cc_response:
|
| 437 |
+
cache.delete(cachekey)
|
| 438 |
+
else:
|
| 439 |
+
info = email.message.Message()
|
| 440 |
+
for key, value in response_headers.items():
|
| 441 |
+
if key not in ["status", "content-encoding", "transfer-encoding"]:
|
| 442 |
+
info[key] = value
|
| 443 |
+
|
| 444 |
+
# Add annotations to the cache to indicate what headers
|
| 445 |
+
# are variant for this request.
|
| 446 |
+
vary = response_headers.get("vary", None)
|
| 447 |
+
if vary:
|
| 448 |
+
vary_headers = vary.lower().replace(" ", "").split(",")
|
| 449 |
+
for header in vary_headers:
|
| 450 |
+
key = "-varied-%s" % header
|
| 451 |
+
try:
|
| 452 |
+
info[key] = request_headers[header]
|
| 453 |
+
except KeyError:
|
| 454 |
+
pass
|
| 455 |
+
|
| 456 |
+
status = response_headers.status
|
| 457 |
+
if status == 304:
|
| 458 |
+
status = 200
|
| 459 |
+
|
| 460 |
+
status_header = "status: %d\r\n" % status
|
| 461 |
+
|
| 462 |
+
try:
|
| 463 |
+
header_str = info.as_string()
|
| 464 |
+
except UnicodeEncodeError:
|
| 465 |
+
setattr(info, "_write_headers", _bind_write_headers(info))
|
| 466 |
+
header_str = info.as_string()
|
| 467 |
+
|
| 468 |
+
header_str = re.sub("\r(?!\n)|(?<!\r)\n", "\r\n", header_str)
|
| 469 |
+
text = b"".join([status_header.encode("utf-8"), header_str.encode("utf-8"), content])
|
| 470 |
+
|
| 471 |
+
cache.set(cachekey, text)
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def _cnonce():
|
| 475 |
+
dig = _md5(
|
| 476 |
+
("%s:%s" % (time.ctime(), ["0123456789"[random.randrange(0, 9)] for i in range(20)])).encode("utf-8")
|
| 477 |
+
).hexdigest()
|
| 478 |
+
return dig[:16]
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def _wsse_username_token(cnonce, iso_now, password):
|
| 482 |
+
return (
|
| 483 |
+
base64.b64encode(_sha(("%s%s%s" % (cnonce, iso_now, password)).encode("utf-8")).digest()).strip().decode("utf-8")
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
# For credentials we need two things, first
|
| 488 |
+
# a pool of credential to try (not necesarily tied to BAsic, Digest, etc.)
|
| 489 |
+
# Then we also need a list of URIs that have already demanded authentication
|
| 490 |
+
# That list is tricky since sub-URIs can take the same auth, or the
|
| 491 |
+
# auth scheme may change as you descend the tree.
|
| 492 |
+
# So we also need each Auth instance to be able to tell us
|
| 493 |
+
# how close to the 'top' it is.
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
class Authentication(object):
|
| 497 |
+
def __init__(self, credentials, host, request_uri, headers, response, content, http):
|
| 498 |
+
(scheme, authority, path, query, fragment) = parse_uri(request_uri)
|
| 499 |
+
self.path = path
|
| 500 |
+
self.host = host
|
| 501 |
+
self.credentials = credentials
|
| 502 |
+
self.http = http
|
| 503 |
+
|
| 504 |
+
def depth(self, request_uri):
|
| 505 |
+
(scheme, authority, path, query, fragment) = parse_uri(request_uri)
|
| 506 |
+
return request_uri[len(self.path) :].count("/")
|
| 507 |
+
|
| 508 |
+
def inscope(self, host, request_uri):
|
| 509 |
+
# XXX Should we normalize the request_uri?
|
| 510 |
+
(scheme, authority, path, query, fragment) = parse_uri(request_uri)
|
| 511 |
+
return (host == self.host) and path.startswith(self.path)
|
| 512 |
+
|
| 513 |
+
def request(self, method, request_uri, headers, content):
|
| 514 |
+
"""Modify the request headers to add the appropriate
|
| 515 |
+
Authorization header. Over-rise this in sub-classes."""
|
| 516 |
+
pass
|
| 517 |
+
|
| 518 |
+
def response(self, response, content):
|
| 519 |
+
"""Gives us a chance to update with new nonces
|
| 520 |
+
or such returned from the last authorized response.
|
| 521 |
+
Over-rise this in sub-classes if necessary.
|
| 522 |
+
|
| 523 |
+
Return TRUE is the request is to be retried, for
|
| 524 |
+
example Digest may return stale=true.
|
| 525 |
+
"""
|
| 526 |
+
return False
|
| 527 |
+
|
| 528 |
+
def __eq__(self, auth):
|
| 529 |
+
return False
|
| 530 |
+
|
| 531 |
+
def __ne__(self, auth):
|
| 532 |
+
return True
|
| 533 |
+
|
| 534 |
+
def __lt__(self, auth):
|
| 535 |
+
return True
|
| 536 |
+
|
| 537 |
+
def __gt__(self, auth):
|
| 538 |
+
return False
|
| 539 |
+
|
| 540 |
+
def __le__(self, auth):
|
| 541 |
+
return True
|
| 542 |
+
|
| 543 |
+
def __ge__(self, auth):
|
| 544 |
+
return False
|
| 545 |
+
|
| 546 |
+
def __bool__(self):
|
| 547 |
+
return True
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
class BasicAuthentication(Authentication):
|
| 551 |
+
def __init__(self, credentials, host, request_uri, headers, response, content, http):
|
| 552 |
+
Authentication.__init__(self, credentials, host, request_uri, headers, response, content, http)
|
| 553 |
+
|
| 554 |
+
def request(self, method, request_uri, headers, content):
|
| 555 |
+
"""Modify the request headers to add the appropriate
|
| 556 |
+
Authorization header."""
|
| 557 |
+
headers["authorization"] = "Basic " + base64.b64encode(
|
| 558 |
+
("%s:%s" % self.credentials).encode("utf-8")
|
| 559 |
+
).strip().decode("utf-8")
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
class DigestAuthentication(Authentication):
|
| 563 |
+
"""Only do qop='auth' and MD5, since that
|
| 564 |
+
is all Apache currently implements"""
|
| 565 |
+
|
| 566 |
+
def __init__(self, credentials, host, request_uri, headers, response, content, http):
|
| 567 |
+
Authentication.__init__(self, credentials, host, request_uri, headers, response, content, http)
|
| 568 |
+
self.challenge = auth._parse_www_authenticate(response, "www-authenticate")["digest"]
|
| 569 |
+
qop = self.challenge.get("qop", "auth")
|
| 570 |
+
self.challenge["qop"] = ("auth" in [x.strip() for x in qop.split()]) and "auth" or None
|
| 571 |
+
if self.challenge["qop"] is None:
|
| 572 |
+
raise UnimplementedDigestAuthOptionError(_("Unsupported value for qop: %s." % qop))
|
| 573 |
+
self.challenge["algorithm"] = self.challenge.get("algorithm", "MD5").upper()
|
| 574 |
+
if self.challenge["algorithm"] != "MD5":
|
| 575 |
+
raise UnimplementedDigestAuthOptionError(
|
| 576 |
+
_("Unsupported value for algorithm: %s." % self.challenge["algorithm"])
|
| 577 |
+
)
|
| 578 |
+
self.A1 = "".join([self.credentials[0], ":", self.challenge["realm"], ":", self.credentials[1],])
|
| 579 |
+
self.challenge["nc"] = 1
|
| 580 |
+
|
| 581 |
+
def request(self, method, request_uri, headers, content, cnonce=None):
|
| 582 |
+
"""Modify the request headers"""
|
| 583 |
+
H = lambda x: _md5(x.encode("utf-8")).hexdigest()
|
| 584 |
+
KD = lambda s, d: H("%s:%s" % (s, d))
|
| 585 |
+
A2 = "".join([method, ":", request_uri])
|
| 586 |
+
self.challenge["cnonce"] = cnonce or _cnonce()
|
| 587 |
+
request_digest = '"%s"' % KD(
|
| 588 |
+
H(self.A1),
|
| 589 |
+
"%s:%s:%s:%s:%s"
|
| 590 |
+
% (
|
| 591 |
+
self.challenge["nonce"],
|
| 592 |
+
"%08x" % self.challenge["nc"],
|
| 593 |
+
self.challenge["cnonce"],
|
| 594 |
+
self.challenge["qop"],
|
| 595 |
+
H(A2),
|
| 596 |
+
),
|
| 597 |
+
)
|
| 598 |
+
headers["authorization"] = (
|
| 599 |
+
'Digest username="%s", realm="%s", nonce="%s", '
|
| 600 |
+
'uri="%s", algorithm=%s, response=%s, qop=%s, '
|
| 601 |
+
'nc=%08x, cnonce="%s"'
|
| 602 |
+
) % (
|
| 603 |
+
self.credentials[0],
|
| 604 |
+
self.challenge["realm"],
|
| 605 |
+
self.challenge["nonce"],
|
| 606 |
+
request_uri,
|
| 607 |
+
self.challenge["algorithm"],
|
| 608 |
+
request_digest,
|
| 609 |
+
self.challenge["qop"],
|
| 610 |
+
self.challenge["nc"],
|
| 611 |
+
self.challenge["cnonce"],
|
| 612 |
+
)
|
| 613 |
+
if self.challenge.get("opaque"):
|
| 614 |
+
headers["authorization"] += ', opaque="%s"' % self.challenge["opaque"]
|
| 615 |
+
self.challenge["nc"] += 1
|
| 616 |
+
|
| 617 |
+
def response(self, response, content):
|
| 618 |
+
if "authentication-info" not in response:
|
| 619 |
+
challenge = auth._parse_www_authenticate(response, "www-authenticate").get("digest", {})
|
| 620 |
+
if "true" == challenge.get("stale"):
|
| 621 |
+
self.challenge["nonce"] = challenge["nonce"]
|
| 622 |
+
self.challenge["nc"] = 1
|
| 623 |
+
return True
|
| 624 |
+
else:
|
| 625 |
+
updated_challenge = auth._parse_authentication_info(response, "authentication-info")
|
| 626 |
+
|
| 627 |
+
if "nextnonce" in updated_challenge:
|
| 628 |
+
self.challenge["nonce"] = updated_challenge["nextnonce"]
|
| 629 |
+
self.challenge["nc"] = 1
|
| 630 |
+
return False
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
class HmacDigestAuthentication(Authentication):
|
| 634 |
+
"""Adapted from Robert Sayre's code and DigestAuthentication above."""
|
| 635 |
+
|
| 636 |
+
__author__ = "Thomas Broyer (t.broyer@ltgt.net)"
|
| 637 |
+
|
| 638 |
+
def __init__(self, credentials, host, request_uri, headers, response, content, http):
|
| 639 |
+
Authentication.__init__(self, credentials, host, request_uri, headers, response, content, http)
|
| 640 |
+
challenge = auth._parse_www_authenticate(response, "www-authenticate")
|
| 641 |
+
self.challenge = challenge["hmacdigest"]
|
| 642 |
+
# TODO: self.challenge['domain']
|
| 643 |
+
self.challenge["reason"] = self.challenge.get("reason", "unauthorized")
|
| 644 |
+
if self.challenge["reason"] not in ["unauthorized", "integrity"]:
|
| 645 |
+
self.challenge["reason"] = "unauthorized"
|
| 646 |
+
self.challenge["salt"] = self.challenge.get("salt", "")
|
| 647 |
+
if not self.challenge.get("snonce"):
|
| 648 |
+
raise UnimplementedHmacDigestAuthOptionError(
|
| 649 |
+
_("The challenge doesn't contain a server nonce, or this one is empty.")
|
| 650 |
+
)
|
| 651 |
+
self.challenge["algorithm"] = self.challenge.get("algorithm", "HMAC-SHA-1")
|
| 652 |
+
if self.challenge["algorithm"] not in ["HMAC-SHA-1", "HMAC-MD5"]:
|
| 653 |
+
raise UnimplementedHmacDigestAuthOptionError(
|
| 654 |
+
_("Unsupported value for algorithm: %s." % self.challenge["algorithm"])
|
| 655 |
+
)
|
| 656 |
+
self.challenge["pw-algorithm"] = self.challenge.get("pw-algorithm", "SHA-1")
|
| 657 |
+
if self.challenge["pw-algorithm"] not in ["SHA-1", "MD5"]:
|
| 658 |
+
raise UnimplementedHmacDigestAuthOptionError(
|
| 659 |
+
_("Unsupported value for pw-algorithm: %s." % self.challenge["pw-algorithm"])
|
| 660 |
+
)
|
| 661 |
+
if self.challenge["algorithm"] == "HMAC-MD5":
|
| 662 |
+
self.hashmod = _md5
|
| 663 |
+
else:
|
| 664 |
+
self.hashmod = _sha
|
| 665 |
+
if self.challenge["pw-algorithm"] == "MD5":
|
| 666 |
+
self.pwhashmod = _md5
|
| 667 |
+
else:
|
| 668 |
+
self.pwhashmod = _sha
|
| 669 |
+
self.key = "".join(
|
| 670 |
+
[
|
| 671 |
+
self.credentials[0],
|
| 672 |
+
":",
|
| 673 |
+
self.pwhashmod.new("".join([self.credentials[1], self.challenge["salt"]])).hexdigest().lower(),
|
| 674 |
+
":",
|
| 675 |
+
self.challenge["realm"],
|
| 676 |
+
]
|
| 677 |
+
)
|
| 678 |
+
self.key = self.pwhashmod.new(self.key).hexdigest().lower()
|
| 679 |
+
|
| 680 |
+
def request(self, method, request_uri, headers, content):
|
| 681 |
+
"""Modify the request headers"""
|
| 682 |
+
keys = _get_end2end_headers(headers)
|
| 683 |
+
keylist = "".join(["%s " % k for k in keys])
|
| 684 |
+
headers_val = "".join([headers[k] for k in keys])
|
| 685 |
+
created = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
| 686 |
+
cnonce = _cnonce()
|
| 687 |
+
request_digest = "%s:%s:%s:%s:%s" % (method, request_uri, cnonce, self.challenge["snonce"], headers_val,)
|
| 688 |
+
request_digest = hmac.new(self.key, request_digest, self.hashmod).hexdigest().lower()
|
| 689 |
+
headers["authorization"] = (
|
| 690 |
+
'HMACDigest username="%s", realm="%s", snonce="%s",'
|
| 691 |
+
' cnonce="%s", uri="%s", created="%s", '
|
| 692 |
+
'response="%s", headers="%s"'
|
| 693 |
+
) % (
|
| 694 |
+
self.credentials[0],
|
| 695 |
+
self.challenge["realm"],
|
| 696 |
+
self.challenge["snonce"],
|
| 697 |
+
cnonce,
|
| 698 |
+
request_uri,
|
| 699 |
+
created,
|
| 700 |
+
request_digest,
|
| 701 |
+
keylist,
|
| 702 |
+
)
|
| 703 |
+
|
| 704 |
+
def response(self, response, content):
|
| 705 |
+
challenge = auth._parse_www_authenticate(response, "www-authenticate").get("hmacdigest", {})
|
| 706 |
+
if challenge.get("reason") in ["integrity", "stale"]:
|
| 707 |
+
return True
|
| 708 |
+
return False
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
class WsseAuthentication(Authentication):
|
| 712 |
+
"""This is thinly tested and should not be relied upon.
|
| 713 |
+
At this time there isn't any third party server to test against.
|
| 714 |
+
Blogger and TypePad implemented this algorithm at one point
|
| 715 |
+
but Blogger has since switched to Basic over HTTPS and
|
| 716 |
+
TypePad has implemented it wrong, by never issuing a 401
|
| 717 |
+
challenge but instead requiring your client to telepathically know that
|
| 718 |
+
their endpoint is expecting WSSE profile="UsernameToken"."""
|
| 719 |
+
|
| 720 |
+
def __init__(self, credentials, host, request_uri, headers, response, content, http):
|
| 721 |
+
Authentication.__init__(self, credentials, host, request_uri, headers, response, content, http)
|
| 722 |
+
|
| 723 |
+
def request(self, method, request_uri, headers, content):
|
| 724 |
+
"""Modify the request headers to add the appropriate
|
| 725 |
+
Authorization header."""
|
| 726 |
+
headers["authorization"] = 'WSSE profile="UsernameToken"'
|
| 727 |
+
iso_now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
| 728 |
+
cnonce = _cnonce()
|
| 729 |
+
password_digest = _wsse_username_token(cnonce, iso_now, self.credentials[1])
|
| 730 |
+
headers["X-WSSE"] = ('UsernameToken Username="%s", PasswordDigest="%s", ' 'Nonce="%s", Created="%s"') % (
|
| 731 |
+
self.credentials[0],
|
| 732 |
+
password_digest,
|
| 733 |
+
cnonce,
|
| 734 |
+
iso_now,
|
| 735 |
+
)
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
class GoogleLoginAuthentication(Authentication):
|
| 739 |
+
def __init__(self, credentials, host, request_uri, headers, response, content, http):
|
| 740 |
+
from urllib.parse import urlencode
|
| 741 |
+
|
| 742 |
+
Authentication.__init__(self, credentials, host, request_uri, headers, response, content, http)
|
| 743 |
+
challenge = auth._parse_www_authenticate(response, "www-authenticate")
|
| 744 |
+
service = challenge["googlelogin"].get("service", "xapi")
|
| 745 |
+
# Bloggger actually returns the service in the challenge
|
| 746 |
+
# For the rest we guess based on the URI
|
| 747 |
+
if service == "xapi" and request_uri.find("calendar") > 0:
|
| 748 |
+
service = "cl"
|
| 749 |
+
# No point in guessing Base or Spreadsheet
|
| 750 |
+
# elif request_uri.find("spreadsheets") > 0:
|
| 751 |
+
# service = "wise"
|
| 752 |
+
|
| 753 |
+
auth = dict(Email=credentials[0], Passwd=credentials[1], service=service, source=headers["user-agent"],)
|
| 754 |
+
resp, content = self.http.request(
|
| 755 |
+
"https://www.google.com/accounts/ClientLogin",
|
| 756 |
+
method="POST",
|
| 757 |
+
body=urlencode(auth),
|
| 758 |
+
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
| 759 |
+
)
|
| 760 |
+
lines = content.split("\n")
|
| 761 |
+
d = dict([tuple(line.split("=", 1)) for line in lines if line])
|
| 762 |
+
if resp.status == 403:
|
| 763 |
+
self.Auth = ""
|
| 764 |
+
else:
|
| 765 |
+
self.Auth = d["Auth"]
|
| 766 |
+
|
| 767 |
+
def request(self, method, request_uri, headers, content):
|
| 768 |
+
"""Modify the request headers to add the appropriate
|
| 769 |
+
Authorization header."""
|
| 770 |
+
headers["authorization"] = "GoogleLogin Auth=" + self.Auth
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
AUTH_SCHEME_CLASSES = {
|
| 774 |
+
"basic": BasicAuthentication,
|
| 775 |
+
"wsse": WsseAuthentication,
|
| 776 |
+
"digest": DigestAuthentication,
|
| 777 |
+
"hmacdigest": HmacDigestAuthentication,
|
| 778 |
+
"googlelogin": GoogleLoginAuthentication,
|
| 779 |
+
}
|
| 780 |
+
|
| 781 |
+
AUTH_SCHEME_ORDER = ["hmacdigest", "googlelogin", "digest", "wsse", "basic"]
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
class FileCache(object):
|
| 785 |
+
"""Uses a local directory as a store for cached files.
|
| 786 |
+
Not really safe to use if multiple threads or processes are going to
|
| 787 |
+
be running on the same cache.
|
| 788 |
+
"""
|
| 789 |
+
|
| 790 |
+
def __init__(self, cache, safe=safename): # use safe=lambda x: md5.new(x).hexdigest() for the old behavior
|
| 791 |
+
self.cache = cache
|
| 792 |
+
self.safe = safe
|
| 793 |
+
if not os.path.exists(cache):
|
| 794 |
+
os.makedirs(self.cache)
|
| 795 |
+
|
| 796 |
+
def get(self, key):
|
| 797 |
+
retval = None
|
| 798 |
+
cacheFullPath = os.path.join(self.cache, self.safe(key))
|
| 799 |
+
try:
|
| 800 |
+
f = open(cacheFullPath, "rb")
|
| 801 |
+
retval = f.read()
|
| 802 |
+
f.close()
|
| 803 |
+
except IOError:
|
| 804 |
+
pass
|
| 805 |
+
return retval
|
| 806 |
+
|
| 807 |
+
def set(self, key, value):
|
| 808 |
+
cacheFullPath = os.path.join(self.cache, self.safe(key))
|
| 809 |
+
f = open(cacheFullPath, "wb")
|
| 810 |
+
f.write(value)
|
| 811 |
+
f.close()
|
| 812 |
+
|
| 813 |
+
def delete(self, key):
|
| 814 |
+
cacheFullPath = os.path.join(self.cache, self.safe(key))
|
| 815 |
+
if os.path.exists(cacheFullPath):
|
| 816 |
+
os.remove(cacheFullPath)
|
| 817 |
+
|
| 818 |
+
|
| 819 |
+
class Credentials(object):
|
| 820 |
+
def __init__(self):
|
| 821 |
+
self.credentials = []
|
| 822 |
+
|
| 823 |
+
def add(self, name, password, domain=""):
|
| 824 |
+
self.credentials.append((domain.lower(), name, password))
|
| 825 |
+
|
| 826 |
+
def clear(self):
|
| 827 |
+
self.credentials = []
|
| 828 |
+
|
| 829 |
+
def iter(self, domain):
|
| 830 |
+
for (cdomain, name, password) in self.credentials:
|
| 831 |
+
if cdomain == "" or domain == cdomain:
|
| 832 |
+
yield (name, password)
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
class KeyCerts(Credentials):
|
| 836 |
+
"""Identical to Credentials except that
|
| 837 |
+
name/password are mapped to key/cert."""
|
| 838 |
+
|
| 839 |
+
def add(self, key, cert, domain, password):
|
| 840 |
+
self.credentials.append((domain.lower(), key, cert, password))
|
| 841 |
+
|
| 842 |
+
def iter(self, domain):
|
| 843 |
+
for (cdomain, key, cert, password) in self.credentials:
|
| 844 |
+
if cdomain == "" or domain == cdomain:
|
| 845 |
+
yield (key, cert, password)
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
class AllHosts(object):
|
| 849 |
+
pass
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
class ProxyInfo(object):
|
| 853 |
+
"""Collect information required to use a proxy."""
|
| 854 |
+
|
| 855 |
+
bypass_hosts = ()
|
| 856 |
+
|
| 857 |
+
def __init__(
|
| 858 |
+
self, proxy_type, proxy_host, proxy_port, proxy_rdns=True, proxy_user=None, proxy_pass=None, proxy_headers=None,
|
| 859 |
+
):
|
| 860 |
+
"""Args:
|
| 861 |
+
|
| 862 |
+
proxy_type: The type of proxy server. This must be set to one of
|
| 863 |
+
socks.PROXY_TYPE_XXX constants. For example: p =
|
| 864 |
+
ProxyInfo(proxy_type=socks.PROXY_TYPE_HTTP, proxy_host='localhost',
|
| 865 |
+
proxy_port=8000)
|
| 866 |
+
proxy_host: The hostname or IP address of the proxy server.
|
| 867 |
+
proxy_port: The port that the proxy server is running on.
|
| 868 |
+
proxy_rdns: If True (default), DNS queries will not be performed
|
| 869 |
+
locally, and instead, handed to the proxy to resolve. This is useful
|
| 870 |
+
if the network does not allow resolution of non-local names. In
|
| 871 |
+
httplib2 0.9 and earlier, this defaulted to False.
|
| 872 |
+
proxy_user: The username used to authenticate with the proxy server.
|
| 873 |
+
proxy_pass: The password used to authenticate with the proxy server.
|
| 874 |
+
proxy_headers: Additional or modified headers for the proxy connect
|
| 875 |
+
request.
|
| 876 |
+
"""
|
| 877 |
+
if isinstance(proxy_user, bytes):
|
| 878 |
+
proxy_user = proxy_user.decode()
|
| 879 |
+
if isinstance(proxy_pass, bytes):
|
| 880 |
+
proxy_pass = proxy_pass.decode()
|
| 881 |
+
(
|
| 882 |
+
self.proxy_type,
|
| 883 |
+
self.proxy_host,
|
| 884 |
+
self.proxy_port,
|
| 885 |
+
self.proxy_rdns,
|
| 886 |
+
self.proxy_user,
|
| 887 |
+
self.proxy_pass,
|
| 888 |
+
self.proxy_headers,
|
| 889 |
+
) = (
|
| 890 |
+
proxy_type,
|
| 891 |
+
proxy_host,
|
| 892 |
+
proxy_port,
|
| 893 |
+
proxy_rdns,
|
| 894 |
+
proxy_user,
|
| 895 |
+
proxy_pass,
|
| 896 |
+
proxy_headers,
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
def astuple(self):
|
| 900 |
+
return (
|
| 901 |
+
self.proxy_type,
|
| 902 |
+
self.proxy_host,
|
| 903 |
+
self.proxy_port,
|
| 904 |
+
self.proxy_rdns,
|
| 905 |
+
self.proxy_user,
|
| 906 |
+
self.proxy_pass,
|
| 907 |
+
self.proxy_headers,
|
| 908 |
+
)
|
| 909 |
+
|
| 910 |
+
def isgood(self):
|
| 911 |
+
return socks and (self.proxy_host != None) and (self.proxy_port != None)
|
| 912 |
+
|
| 913 |
+
def applies_to(self, hostname):
|
| 914 |
+
return not self.bypass_host(hostname)
|
| 915 |
+
|
| 916 |
+
def bypass_host(self, hostname):
|
| 917 |
+
"""Has this host been excluded from the proxy config"""
|
| 918 |
+
if self.bypass_hosts is AllHosts:
|
| 919 |
+
return True
|
| 920 |
+
|
| 921 |
+
hostname = "." + hostname.lstrip(".")
|
| 922 |
+
for skip_name in self.bypass_hosts:
|
| 923 |
+
# *.suffix
|
| 924 |
+
if skip_name.startswith(".") and hostname.endswith(skip_name):
|
| 925 |
+
return True
|
| 926 |
+
# exact match
|
| 927 |
+
if hostname == "." + skip_name:
|
| 928 |
+
return True
|
| 929 |
+
return False
|
| 930 |
+
|
| 931 |
+
def __repr__(self):
|
| 932 |
+
return (
|
| 933 |
+
"<ProxyInfo type={p.proxy_type} "
|
| 934 |
+
"host:port={p.proxy_host}:{p.proxy_port} rdns={p.proxy_rdns}"
|
| 935 |
+
+ " user={p.proxy_user} headers={p.proxy_headers}>"
|
| 936 |
+
).format(p=self)
|
| 937 |
+
|
| 938 |
+
|
| 939 |
+
def proxy_info_from_environment(method="http"):
|
| 940 |
+
"""Read proxy info from the environment variables.
|
| 941 |
+
"""
|
| 942 |
+
if method not in ("http", "https"):
|
| 943 |
+
return
|
| 944 |
+
|
| 945 |
+
env_var = method + "_proxy"
|
| 946 |
+
url = os.environ.get(env_var, os.environ.get(env_var.upper()))
|
| 947 |
+
if not url:
|
| 948 |
+
return
|
| 949 |
+
return proxy_info_from_url(url, method, noproxy=None)
|
| 950 |
+
|
| 951 |
+
|
| 952 |
+
def proxy_info_from_url(url, method="http", noproxy=None):
|
| 953 |
+
"""Construct a ProxyInfo from a URL (such as http_proxy env var)
|
| 954 |
+
"""
|
| 955 |
+
url = urllib.parse.urlparse(url)
|
| 956 |
+
|
| 957 |
+
proxy_type = 3 # socks.PROXY_TYPE_HTTP
|
| 958 |
+
pi = ProxyInfo(
|
| 959 |
+
proxy_type=proxy_type,
|
| 960 |
+
proxy_host=url.hostname,
|
| 961 |
+
proxy_port=url.port or dict(https=443, http=80)[method],
|
| 962 |
+
proxy_user=url.username or None,
|
| 963 |
+
proxy_pass=url.password or None,
|
| 964 |
+
proxy_headers=None,
|
| 965 |
+
)
|
| 966 |
+
|
| 967 |
+
bypass_hosts = []
|
| 968 |
+
# If not given an explicit noproxy value, respect values in env vars.
|
| 969 |
+
if noproxy is None:
|
| 970 |
+
noproxy = os.environ.get("no_proxy", os.environ.get("NO_PROXY", ""))
|
| 971 |
+
# Special case: A single '*' character means all hosts should be bypassed.
|
| 972 |
+
if noproxy == "*":
|
| 973 |
+
bypass_hosts = AllHosts
|
| 974 |
+
elif noproxy.strip():
|
| 975 |
+
bypass_hosts = noproxy.split(",")
|
| 976 |
+
bypass_hosts = tuple(filter(bool, bypass_hosts)) # To exclude empty string.
|
| 977 |
+
|
| 978 |
+
pi.bypass_hosts = bypass_hosts
|
| 979 |
+
return pi
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
class HTTPConnectionWithTimeout(http.client.HTTPConnection):
|
| 983 |
+
"""HTTPConnection subclass that supports timeouts
|
| 984 |
+
|
| 985 |
+
HTTPConnection subclass that supports timeouts
|
| 986 |
+
|
| 987 |
+
All timeouts are in seconds. If None is passed for timeout then
|
| 988 |
+
Python's default timeout for sockets will be used. See for example
|
| 989 |
+
the docs of socket.setdefaulttimeout():
|
| 990 |
+
http://docs.python.org/library/socket.html#socket.setdefaulttimeout
|
| 991 |
+
"""
|
| 992 |
+
|
| 993 |
+
def __init__(self, host, port=None, timeout=None, proxy_info=None):
|
| 994 |
+
http.client.HTTPConnection.__init__(self, host, port=port, timeout=timeout)
|
| 995 |
+
|
| 996 |
+
self.proxy_info = proxy_info
|
| 997 |
+
if proxy_info and not isinstance(proxy_info, ProxyInfo):
|
| 998 |
+
self.proxy_info = proxy_info("http")
|
| 999 |
+
|
| 1000 |
+
def connect(self):
|
| 1001 |
+
"""Connect to the host and port specified in __init__."""
|
| 1002 |
+
if self.proxy_info and socks is None:
|
| 1003 |
+
raise ProxiesUnavailableError("Proxy support missing but proxy use was requested!")
|
| 1004 |
+
if self.proxy_info and self.proxy_info.isgood() and self.proxy_info.applies_to(self.host):
|
| 1005 |
+
use_proxy = True
|
| 1006 |
+
(
|
| 1007 |
+
proxy_type,
|
| 1008 |
+
proxy_host,
|
| 1009 |
+
proxy_port,
|
| 1010 |
+
proxy_rdns,
|
| 1011 |
+
proxy_user,
|
| 1012 |
+
proxy_pass,
|
| 1013 |
+
proxy_headers,
|
| 1014 |
+
) = self.proxy_info.astuple()
|
| 1015 |
+
|
| 1016 |
+
host = proxy_host
|
| 1017 |
+
port = proxy_port
|
| 1018 |
+
else:
|
| 1019 |
+
use_proxy = False
|
| 1020 |
+
|
| 1021 |
+
host = self.host
|
| 1022 |
+
port = self.port
|
| 1023 |
+
proxy_type = None
|
| 1024 |
+
|
| 1025 |
+
socket_err = None
|
| 1026 |
+
|
| 1027 |
+
for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
|
| 1028 |
+
af, socktype, proto, canonname, sa = res
|
| 1029 |
+
try:
|
| 1030 |
+
if use_proxy:
|
| 1031 |
+
self.sock = socks.socksocket(af, socktype, proto)
|
| 1032 |
+
self.sock.setproxy(
|
| 1033 |
+
proxy_type, proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass,
|
| 1034 |
+
)
|
| 1035 |
+
else:
|
| 1036 |
+
self.sock = socket.socket(af, socktype, proto)
|
| 1037 |
+
self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
| 1038 |
+
if has_timeout(self.timeout):
|
| 1039 |
+
self.sock.settimeout(self.timeout)
|
| 1040 |
+
if self.debuglevel > 0:
|
| 1041 |
+
print("connect: ({0}, {1}) ************".format(self.host, self.port))
|
| 1042 |
+
if use_proxy:
|
| 1043 |
+
print(
|
| 1044 |
+
"proxy: {0} ************".format(
|
| 1045 |
+
str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers,))
|
| 1046 |
+
)
|
| 1047 |
+
)
|
| 1048 |
+
|
| 1049 |
+
self.sock.connect((self.host, self.port) + sa[2:])
|
| 1050 |
+
except socket.error as e:
|
| 1051 |
+
socket_err = e
|
| 1052 |
+
if self.debuglevel > 0:
|
| 1053 |
+
print("connect fail: ({0}, {1})".format(self.host, self.port))
|
| 1054 |
+
if use_proxy:
|
| 1055 |
+
print(
|
| 1056 |
+
"proxy: {0}".format(
|
| 1057 |
+
str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers,))
|
| 1058 |
+
)
|
| 1059 |
+
)
|
| 1060 |
+
if self.sock:
|
| 1061 |
+
self.sock.close()
|
| 1062 |
+
self.sock = None
|
| 1063 |
+
continue
|
| 1064 |
+
break
|
| 1065 |
+
if not self.sock:
|
| 1066 |
+
raise socket_err
|
| 1067 |
+
|
| 1068 |
+
|
| 1069 |
+
class HTTPSConnectionWithTimeout(http.client.HTTPSConnection):
|
| 1070 |
+
"""This class allows communication via SSL.
|
| 1071 |
+
|
| 1072 |
+
All timeouts are in seconds. If None is passed for timeout then
|
| 1073 |
+
Python's default timeout for sockets will be used. See for example
|
| 1074 |
+
the docs of socket.setdefaulttimeout():
|
| 1075 |
+
http://docs.python.org/library/socket.html#socket.setdefaulttimeout
|
| 1076 |
+
"""
|
| 1077 |
+
|
| 1078 |
+
def __init__(
|
| 1079 |
+
self,
|
| 1080 |
+
host,
|
| 1081 |
+
port=None,
|
| 1082 |
+
key_file=None,
|
| 1083 |
+
cert_file=None,
|
| 1084 |
+
timeout=None,
|
| 1085 |
+
proxy_info=None,
|
| 1086 |
+
ca_certs=None,
|
| 1087 |
+
disable_ssl_certificate_validation=False,
|
| 1088 |
+
tls_maximum_version=None,
|
| 1089 |
+
tls_minimum_version=None,
|
| 1090 |
+
key_password=None,
|
| 1091 |
+
):
|
| 1092 |
+
|
| 1093 |
+
self.disable_ssl_certificate_validation = disable_ssl_certificate_validation
|
| 1094 |
+
self.ca_certs = ca_certs if ca_certs else CA_CERTS
|
| 1095 |
+
|
| 1096 |
+
self.proxy_info = proxy_info
|
| 1097 |
+
if proxy_info and not isinstance(proxy_info, ProxyInfo):
|
| 1098 |
+
self.proxy_info = proxy_info("https")
|
| 1099 |
+
|
| 1100 |
+
context = _build_ssl_context(
|
| 1101 |
+
self.disable_ssl_certificate_validation,
|
| 1102 |
+
self.ca_certs,
|
| 1103 |
+
cert_file,
|
| 1104 |
+
key_file,
|
| 1105 |
+
maximum_version=tls_maximum_version,
|
| 1106 |
+
minimum_version=tls_minimum_version,
|
| 1107 |
+
key_password=key_password,
|
| 1108 |
+
)
|
| 1109 |
+
super(HTTPSConnectionWithTimeout, self).__init__(
|
| 1110 |
+
host, port=port, timeout=timeout, context=context,
|
| 1111 |
+
)
|
| 1112 |
+
self.key_file = key_file
|
| 1113 |
+
self.cert_file = cert_file
|
| 1114 |
+
self.key_password = key_password
|
| 1115 |
+
|
| 1116 |
+
def connect(self):
|
| 1117 |
+
"""Connect to a host on a given (SSL) port."""
|
| 1118 |
+
if self.proxy_info and self.proxy_info.isgood() and self.proxy_info.applies_to(self.host):
|
| 1119 |
+
use_proxy = True
|
| 1120 |
+
(
|
| 1121 |
+
proxy_type,
|
| 1122 |
+
proxy_host,
|
| 1123 |
+
proxy_port,
|
| 1124 |
+
proxy_rdns,
|
| 1125 |
+
proxy_user,
|
| 1126 |
+
proxy_pass,
|
| 1127 |
+
proxy_headers,
|
| 1128 |
+
) = self.proxy_info.astuple()
|
| 1129 |
+
|
| 1130 |
+
host = proxy_host
|
| 1131 |
+
port = proxy_port
|
| 1132 |
+
else:
|
| 1133 |
+
use_proxy = False
|
| 1134 |
+
|
| 1135 |
+
host = self.host
|
| 1136 |
+
port = self.port
|
| 1137 |
+
proxy_type = None
|
| 1138 |
+
proxy_headers = None
|
| 1139 |
+
|
| 1140 |
+
socket_err = None
|
| 1141 |
+
|
| 1142 |
+
address_info = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
|
| 1143 |
+
for family, socktype, proto, canonname, sockaddr in address_info:
|
| 1144 |
+
try:
|
| 1145 |
+
if use_proxy:
|
| 1146 |
+
sock = socks.socksocket(family, socktype, proto)
|
| 1147 |
+
|
| 1148 |
+
sock.setproxy(
|
| 1149 |
+
proxy_type, proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass,
|
| 1150 |
+
)
|
| 1151 |
+
else:
|
| 1152 |
+
sock = socket.socket(family, socktype, proto)
|
| 1153 |
+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
| 1154 |
+
if has_timeout(self.timeout):
|
| 1155 |
+
sock.settimeout(self.timeout)
|
| 1156 |
+
sock.connect((self.host, self.port))
|
| 1157 |
+
|
| 1158 |
+
self.sock = self._context.wrap_socket(sock, server_hostname=self.host)
|
| 1159 |
+
|
| 1160 |
+
# Python 3.3 compatibility: emulate the check_hostname behavior
|
| 1161 |
+
if not hasattr(self._context, "check_hostname") and not self.disable_ssl_certificate_validation:
|
| 1162 |
+
try:
|
| 1163 |
+
ssl.match_hostname(self.sock.getpeercert(), self.host)
|
| 1164 |
+
except Exception:
|
| 1165 |
+
self.sock.shutdown(socket.SHUT_RDWR)
|
| 1166 |
+
self.sock.close()
|
| 1167 |
+
raise
|
| 1168 |
+
|
| 1169 |
+
if self.debuglevel > 0:
|
| 1170 |
+
print("connect: ({0}, {1})".format(self.host, self.port))
|
| 1171 |
+
if use_proxy:
|
| 1172 |
+
print(
|
| 1173 |
+
"proxy: {0}".format(
|
| 1174 |
+
str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers,))
|
| 1175 |
+
)
|
| 1176 |
+
)
|
| 1177 |
+
except (ssl.SSLError, ssl.CertificateError) as e:
|
| 1178 |
+
if sock:
|
| 1179 |
+
sock.close()
|
| 1180 |
+
if self.sock:
|
| 1181 |
+
self.sock.close()
|
| 1182 |
+
self.sock = None
|
| 1183 |
+
raise
|
| 1184 |
+
except (socket.timeout, socket.gaierror):
|
| 1185 |
+
raise
|
| 1186 |
+
except socket.error as e:
|
| 1187 |
+
socket_err = e
|
| 1188 |
+
if self.debuglevel > 0:
|
| 1189 |
+
print("connect fail: ({0}, {1})".format(self.host, self.port))
|
| 1190 |
+
if use_proxy:
|
| 1191 |
+
print(
|
| 1192 |
+
"proxy: {0}".format(
|
| 1193 |
+
str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers,))
|
| 1194 |
+
)
|
| 1195 |
+
)
|
| 1196 |
+
if self.sock:
|
| 1197 |
+
self.sock.close()
|
| 1198 |
+
self.sock = None
|
| 1199 |
+
continue
|
| 1200 |
+
break
|
| 1201 |
+
if not self.sock:
|
| 1202 |
+
raise socket_err
|
| 1203 |
+
|
| 1204 |
+
|
| 1205 |
+
SCHEME_TO_CONNECTION = {
|
| 1206 |
+
"http": HTTPConnectionWithTimeout,
|
| 1207 |
+
"https": HTTPSConnectionWithTimeout,
|
| 1208 |
+
}
|
| 1209 |
+
|
| 1210 |
+
|
| 1211 |
+
class Http(object):
|
| 1212 |
+
"""An HTTP client that handles:
|
| 1213 |
+
|
| 1214 |
+
- all methods
|
| 1215 |
+
- caching
|
| 1216 |
+
- ETags
|
| 1217 |
+
- compression,
|
| 1218 |
+
- HTTPS
|
| 1219 |
+
- Basic
|
| 1220 |
+
- Digest
|
| 1221 |
+
- WSSE
|
| 1222 |
+
|
| 1223 |
+
and more.
|
| 1224 |
+
"""
|
| 1225 |
+
|
| 1226 |
+
def __init__(
|
| 1227 |
+
self,
|
| 1228 |
+
cache=None,
|
| 1229 |
+
timeout=None,
|
| 1230 |
+
proxy_info=proxy_info_from_environment,
|
| 1231 |
+
ca_certs=None,
|
| 1232 |
+
disable_ssl_certificate_validation=False,
|
| 1233 |
+
tls_maximum_version=None,
|
| 1234 |
+
tls_minimum_version=None,
|
| 1235 |
+
):
|
| 1236 |
+
"""If 'cache' is a string then it is used as a directory name for
|
| 1237 |
+
a disk cache. Otherwise it must be an object that supports the
|
| 1238 |
+
same interface as FileCache.
|
| 1239 |
+
|
| 1240 |
+
All timeouts are in seconds. If None is passed for timeout
|
| 1241 |
+
then Python's default timeout for sockets will be used. See
|
| 1242 |
+
for example the docs of socket.setdefaulttimeout():
|
| 1243 |
+
http://docs.python.org/library/socket.html#socket.setdefaulttimeout
|
| 1244 |
+
|
| 1245 |
+
`proxy_info` may be:
|
| 1246 |
+
- a callable that takes the http scheme ('http' or 'https') and
|
| 1247 |
+
returns a ProxyInfo instance per request. By default, uses
|
| 1248 |
+
proxy_info_from_environment.
|
| 1249 |
+
- a ProxyInfo instance (static proxy config).
|
| 1250 |
+
- None (proxy disabled).
|
| 1251 |
+
|
| 1252 |
+
ca_certs is the path of a file containing root CA certificates for SSL
|
| 1253 |
+
server certificate validation. By default, a CA cert file bundled with
|
| 1254 |
+
httplib2 is used.
|
| 1255 |
+
|
| 1256 |
+
If disable_ssl_certificate_validation is true, SSL cert validation will
|
| 1257 |
+
not be performed.
|
| 1258 |
+
|
| 1259 |
+
tls_maximum_version / tls_minimum_version require Python 3.7+ /
|
| 1260 |
+
OpenSSL 1.1.0g+. A value of "TLSv1_3" requires OpenSSL 1.1.1+.
|
| 1261 |
+
"""
|
| 1262 |
+
self.proxy_info = proxy_info
|
| 1263 |
+
self.ca_certs = ca_certs
|
| 1264 |
+
self.disable_ssl_certificate_validation = disable_ssl_certificate_validation
|
| 1265 |
+
self.tls_maximum_version = tls_maximum_version
|
| 1266 |
+
self.tls_minimum_version = tls_minimum_version
|
| 1267 |
+
# Map domain name to an httplib connection
|
| 1268 |
+
self.connections = {}
|
| 1269 |
+
# The location of the cache, for now a directory
|
| 1270 |
+
# where cached responses are held.
|
| 1271 |
+
if cache and isinstance(cache, str):
|
| 1272 |
+
self.cache = FileCache(cache)
|
| 1273 |
+
else:
|
| 1274 |
+
self.cache = cache
|
| 1275 |
+
|
| 1276 |
+
# Name/password
|
| 1277 |
+
self.credentials = Credentials()
|
| 1278 |
+
|
| 1279 |
+
# Key/cert
|
| 1280 |
+
self.certificates = KeyCerts()
|
| 1281 |
+
|
| 1282 |
+
# authorization objects
|
| 1283 |
+
self.authorizations = []
|
| 1284 |
+
|
| 1285 |
+
# If set to False then no redirects are followed, even safe ones.
|
| 1286 |
+
self.follow_redirects = True
|
| 1287 |
+
|
| 1288 |
+
self.redirect_codes = REDIRECT_CODES
|
| 1289 |
+
|
| 1290 |
+
# Which HTTP methods do we apply optimistic concurrency to, i.e.
|
| 1291 |
+
# which methods get an "if-match:" etag header added to them.
|
| 1292 |
+
self.optimistic_concurrency_methods = ["PUT", "PATCH"]
|
| 1293 |
+
|
| 1294 |
+
self.safe_methods = list(SAFE_METHODS)
|
| 1295 |
+
|
| 1296 |
+
# If 'follow_redirects' is True, and this is set to True then
|
| 1297 |
+
# all redirecs are followed, including unsafe ones.
|
| 1298 |
+
self.follow_all_redirects = False
|
| 1299 |
+
|
| 1300 |
+
self.ignore_etag = False
|
| 1301 |
+
|
| 1302 |
+
self.force_exception_to_status_code = False
|
| 1303 |
+
|
| 1304 |
+
self.timeout = timeout
|
| 1305 |
+
|
| 1306 |
+
# Keep Authorization: headers on a redirect.
|
| 1307 |
+
self.forward_authorization_headers = False
|
| 1308 |
+
|
| 1309 |
+
def close(self):
|
| 1310 |
+
"""Close persistent connections, clear sensitive data.
|
| 1311 |
+
Not thread-safe, requires external synchronization against concurrent requests.
|
| 1312 |
+
"""
|
| 1313 |
+
existing, self.connections = self.connections, {}
|
| 1314 |
+
for _, c in existing.items():
|
| 1315 |
+
c.close()
|
| 1316 |
+
self.certificates.clear()
|
| 1317 |
+
self.clear_credentials()
|
| 1318 |
+
|
| 1319 |
+
def __getstate__(self):
|
| 1320 |
+
state_dict = copy.copy(self.__dict__)
|
| 1321 |
+
# In case request is augmented by some foreign object such as
|
| 1322 |
+
# credentials which handle auth
|
| 1323 |
+
if "request" in state_dict:
|
| 1324 |
+
del state_dict["request"]
|
| 1325 |
+
if "connections" in state_dict:
|
| 1326 |
+
del state_dict["connections"]
|
| 1327 |
+
return state_dict
|
| 1328 |
+
|
| 1329 |
+
def __setstate__(self, state):
|
| 1330 |
+
self.__dict__.update(state)
|
| 1331 |
+
self.connections = {}
|
| 1332 |
+
|
| 1333 |
+
def _auth_from_challenge(self, host, request_uri, headers, response, content):
|
| 1334 |
+
"""A generator that creates Authorization objects
|
| 1335 |
+
that can be applied to requests.
|
| 1336 |
+
"""
|
| 1337 |
+
challenges = auth._parse_www_authenticate(response, "www-authenticate")
|
| 1338 |
+
for cred in self.credentials.iter(host):
|
| 1339 |
+
for scheme in AUTH_SCHEME_ORDER:
|
| 1340 |
+
if scheme in challenges:
|
| 1341 |
+
yield AUTH_SCHEME_CLASSES[scheme](cred, host, request_uri, headers, response, content, self)
|
| 1342 |
+
|
| 1343 |
+
def add_credentials(self, name, password, domain=""):
|
| 1344 |
+
"""Add a name and password that will be used
|
| 1345 |
+
any time a request requires authentication."""
|
| 1346 |
+
self.credentials.add(name, password, domain)
|
| 1347 |
+
|
| 1348 |
+
def add_certificate(self, key, cert, domain, password=None):
|
| 1349 |
+
"""Add a key and cert that will be used
|
| 1350 |
+
any time a request requires authentication."""
|
| 1351 |
+
self.certificates.add(key, cert, domain, password)
|
| 1352 |
+
|
| 1353 |
+
def clear_credentials(self):
|
| 1354 |
+
"""Remove all the names and passwords
|
| 1355 |
+
that are used for authentication"""
|
| 1356 |
+
self.credentials.clear()
|
| 1357 |
+
self.authorizations = []
|
| 1358 |
+
|
| 1359 |
+
def _conn_request(self, conn, request_uri, method, body, headers):
|
| 1360 |
+
i = 0
|
| 1361 |
+
seen_bad_status_line = False
|
| 1362 |
+
while i < RETRIES:
|
| 1363 |
+
i += 1
|
| 1364 |
+
try:
|
| 1365 |
+
if conn.sock is None:
|
| 1366 |
+
conn.connect()
|
| 1367 |
+
conn.request(method, request_uri, body, headers)
|
| 1368 |
+
except socket.timeout:
|
| 1369 |
+
conn.close()
|
| 1370 |
+
raise
|
| 1371 |
+
except socket.gaierror:
|
| 1372 |
+
conn.close()
|
| 1373 |
+
raise ServerNotFoundError("Unable to find the server at %s" % conn.host)
|
| 1374 |
+
except socket.error as e:
|
| 1375 |
+
errno_ = _errno_from_exception(e)
|
| 1376 |
+
if errno_ in (errno.ENETUNREACH, errno.EADDRNOTAVAIL) and i < RETRIES:
|
| 1377 |
+
continue # retry on potentially transient errors
|
| 1378 |
+
raise
|
| 1379 |
+
except http.client.HTTPException:
|
| 1380 |
+
if conn.sock is None:
|
| 1381 |
+
if i < RETRIES - 1:
|
| 1382 |
+
conn.close()
|
| 1383 |
+
conn.connect()
|
| 1384 |
+
continue
|
| 1385 |
+
else:
|
| 1386 |
+
conn.close()
|
| 1387 |
+
raise
|
| 1388 |
+
if i < RETRIES - 1:
|
| 1389 |
+
conn.close()
|
| 1390 |
+
conn.connect()
|
| 1391 |
+
continue
|
| 1392 |
+
# Just because the server closed the connection doesn't apparently mean
|
| 1393 |
+
# that the server didn't send a response.
|
| 1394 |
+
pass
|
| 1395 |
+
try:
|
| 1396 |
+
response = conn.getresponse()
|
| 1397 |
+
except (http.client.BadStatusLine, http.client.ResponseNotReady):
|
| 1398 |
+
# If we get a BadStatusLine on the first try then that means
|
| 1399 |
+
# the connection just went stale, so retry regardless of the
|
| 1400 |
+
# number of RETRIES set.
|
| 1401 |
+
if not seen_bad_status_line and i == 1:
|
| 1402 |
+
i = 0
|
| 1403 |
+
seen_bad_status_line = True
|
| 1404 |
+
conn.close()
|
| 1405 |
+
conn.connect()
|
| 1406 |
+
continue
|
| 1407 |
+
else:
|
| 1408 |
+
conn.close()
|
| 1409 |
+
raise
|
| 1410 |
+
except socket.timeout:
|
| 1411 |
+
raise
|
| 1412 |
+
except (socket.error, http.client.HTTPException):
|
| 1413 |
+
conn.close()
|
| 1414 |
+
if i == 0:
|
| 1415 |
+
conn.close()
|
| 1416 |
+
conn.connect()
|
| 1417 |
+
continue
|
| 1418 |
+
else:
|
| 1419 |
+
raise
|
| 1420 |
+
else:
|
| 1421 |
+
content = b""
|
| 1422 |
+
if method == "HEAD":
|
| 1423 |
+
conn.close()
|
| 1424 |
+
else:
|
| 1425 |
+
content = response.read()
|
| 1426 |
+
response = Response(response)
|
| 1427 |
+
if method != "HEAD":
|
| 1428 |
+
content = _decompressContent(response, content)
|
| 1429 |
+
|
| 1430 |
+
break
|
| 1431 |
+
return (response, content)
|
| 1432 |
+
|
| 1433 |
+
def _request(
|
| 1434 |
+
self, conn, host, absolute_uri, request_uri, method, body, headers, redirections, cachekey,
|
| 1435 |
+
):
|
| 1436 |
+
"""Do the actual request using the connection object
|
| 1437 |
+
and also follow one level of redirects if necessary"""
|
| 1438 |
+
|
| 1439 |
+
auths = [(auth.depth(request_uri), auth) for auth in self.authorizations if auth.inscope(host, request_uri)]
|
| 1440 |
+
auth = auths and sorted(auths)[0][1] or None
|
| 1441 |
+
if auth:
|
| 1442 |
+
auth.request(method, request_uri, headers, body)
|
| 1443 |
+
|
| 1444 |
+
(response, content) = self._conn_request(conn, request_uri, method, body, headers)
|
| 1445 |
+
|
| 1446 |
+
if auth:
|
| 1447 |
+
if auth.response(response, body):
|
| 1448 |
+
auth.request(method, request_uri, headers, body)
|
| 1449 |
+
(response, content) = self._conn_request(conn, request_uri, method, body, headers)
|
| 1450 |
+
response._stale_digest = 1
|
| 1451 |
+
|
| 1452 |
+
if response.status == 401:
|
| 1453 |
+
for authorization in self._auth_from_challenge(host, request_uri, headers, response, content):
|
| 1454 |
+
authorization.request(method, request_uri, headers, body)
|
| 1455 |
+
(response, content) = self._conn_request(conn, request_uri, method, body, headers)
|
| 1456 |
+
if response.status != 401:
|
| 1457 |
+
self.authorizations.append(authorization)
|
| 1458 |
+
authorization.response(response, body)
|
| 1459 |
+
break
|
| 1460 |
+
|
| 1461 |
+
if self.follow_all_redirects or method in self.safe_methods or response.status in (303, 308):
|
| 1462 |
+
if self.follow_redirects and response.status in self.redirect_codes:
|
| 1463 |
+
# Pick out the location header and basically start from the beginning
|
| 1464 |
+
# remembering first to strip the ETag header and decrement our 'depth'
|
| 1465 |
+
if redirections:
|
| 1466 |
+
if "location" not in response and response.status != 300:
|
| 1467 |
+
raise RedirectMissingLocation(
|
| 1468 |
+
_("Redirected but the response is missing a Location: header."), response, content,
|
| 1469 |
+
)
|
| 1470 |
+
# Fix-up relative redirects (which violate an RFC 2616 MUST)
|
| 1471 |
+
if "location" in response:
|
| 1472 |
+
location = response["location"]
|
| 1473 |
+
(scheme, authority, path, query, fragment) = parse_uri(location)
|
| 1474 |
+
if authority == None:
|
| 1475 |
+
response["location"] = urllib.parse.urljoin(absolute_uri, location)
|
| 1476 |
+
if response.status == 308 or (response.status == 301 and (method in self.safe_methods)):
|
| 1477 |
+
response["-x-permanent-redirect-url"] = response["location"]
|
| 1478 |
+
if "content-location" not in response:
|
| 1479 |
+
response["content-location"] = absolute_uri
|
| 1480 |
+
_updateCache(headers, response, content, self.cache, cachekey)
|
| 1481 |
+
if "if-none-match" in headers:
|
| 1482 |
+
del headers["if-none-match"]
|
| 1483 |
+
if "if-modified-since" in headers:
|
| 1484 |
+
del headers["if-modified-since"]
|
| 1485 |
+
if "authorization" in headers and not self.forward_authorization_headers:
|
| 1486 |
+
del headers["authorization"]
|
| 1487 |
+
if "location" in response:
|
| 1488 |
+
location = response["location"]
|
| 1489 |
+
old_response = copy.deepcopy(response)
|
| 1490 |
+
if "content-location" not in old_response:
|
| 1491 |
+
old_response["content-location"] = absolute_uri
|
| 1492 |
+
redirect_method = method
|
| 1493 |
+
if response.status in [302, 303]:
|
| 1494 |
+
redirect_method = "GET"
|
| 1495 |
+
body = None
|
| 1496 |
+
(response, content) = self.request(
|
| 1497 |
+
location, method=redirect_method, body=body, headers=headers, redirections=redirections - 1,
|
| 1498 |
+
)
|
| 1499 |
+
response.previous = old_response
|
| 1500 |
+
else:
|
| 1501 |
+
raise RedirectLimit(
|
| 1502 |
+
"Redirected more times than redirection_limit allows.", response, content,
|
| 1503 |
+
)
|
| 1504 |
+
elif response.status in [200, 203] and method in self.safe_methods:
|
| 1505 |
+
# Don't cache 206's since we aren't going to handle byte range requests
|
| 1506 |
+
if "content-location" not in response:
|
| 1507 |
+
response["content-location"] = absolute_uri
|
| 1508 |
+
_updateCache(headers, response, content, self.cache, cachekey)
|
| 1509 |
+
|
| 1510 |
+
return (response, content)
|
| 1511 |
+
|
| 1512 |
+
def _normalize_headers(self, headers):
|
| 1513 |
+
return _normalize_headers(headers)
|
| 1514 |
+
|
| 1515 |
+
# Need to catch and rebrand some exceptions
|
| 1516 |
+
# Then need to optionally turn all exceptions into status codes
|
| 1517 |
+
# including all socket.* and httplib.* exceptions.
|
| 1518 |
+
|
| 1519 |
+
def request(
|
| 1520 |
+
self, uri, method="GET", body=None, headers=None, redirections=DEFAULT_MAX_REDIRECTS, connection_type=None,
|
| 1521 |
+
):
|
| 1522 |
+
""" Performs a single HTTP request.
|
| 1523 |
+
The 'uri' is the URI of the HTTP resource and can begin
|
| 1524 |
+
with either 'http' or 'https'. The value of 'uri' must be an absolute URI.
|
| 1525 |
+
|
| 1526 |
+
The 'method' is the HTTP method to perform, such as GET, POST, DELETE, etc.
|
| 1527 |
+
There is no restriction on the methods allowed.
|
| 1528 |
+
|
| 1529 |
+
The 'body' is the entity body to be sent with the request. It is a string
|
| 1530 |
+
object.
|
| 1531 |
+
|
| 1532 |
+
Any extra headers that are to be sent with the request should be provided in the
|
| 1533 |
+
'headers' dictionary.
|
| 1534 |
+
|
| 1535 |
+
The maximum number of redirect to follow before raising an
|
| 1536 |
+
exception is 'redirections. The default is 5.
|
| 1537 |
+
|
| 1538 |
+
The return value is a tuple of (response, content), the first
|
| 1539 |
+
being and instance of the 'Response' class, the second being
|
| 1540 |
+
a string that contains the response entity body.
|
| 1541 |
+
"""
|
| 1542 |
+
conn_key = ""
|
| 1543 |
+
|
| 1544 |
+
try:
|
| 1545 |
+
if headers is None:
|
| 1546 |
+
headers = {}
|
| 1547 |
+
else:
|
| 1548 |
+
headers = self._normalize_headers(headers)
|
| 1549 |
+
|
| 1550 |
+
if "user-agent" not in headers:
|
| 1551 |
+
headers["user-agent"] = "Python-httplib2/%s (gzip)" % __version__
|
| 1552 |
+
|
| 1553 |
+
uri = iri2uri(uri)
|
| 1554 |
+
# Prevent CWE-75 space injection to manipulate request via part of uri.
|
| 1555 |
+
# Prevent CWE-93 CRLF injection to modify headers via part of uri.
|
| 1556 |
+
uri = uri.replace(" ", "%20").replace("\r", "%0D").replace("\n", "%0A")
|
| 1557 |
+
|
| 1558 |
+
(scheme, authority, request_uri, defrag_uri) = urlnorm(uri)
|
| 1559 |
+
|
| 1560 |
+
conn_key = scheme + ":" + authority
|
| 1561 |
+
conn = self.connections.get(conn_key)
|
| 1562 |
+
if conn is None:
|
| 1563 |
+
if not connection_type:
|
| 1564 |
+
connection_type = SCHEME_TO_CONNECTION[scheme]
|
| 1565 |
+
certs = list(self.certificates.iter(authority))
|
| 1566 |
+
if issubclass(connection_type, HTTPSConnectionWithTimeout):
|
| 1567 |
+
if certs:
|
| 1568 |
+
conn = self.connections[conn_key] = connection_type(
|
| 1569 |
+
authority,
|
| 1570 |
+
key_file=certs[0][0],
|
| 1571 |
+
cert_file=certs[0][1],
|
| 1572 |
+
timeout=self.timeout,
|
| 1573 |
+
proxy_info=self.proxy_info,
|
| 1574 |
+
ca_certs=self.ca_certs,
|
| 1575 |
+
disable_ssl_certificate_validation=self.disable_ssl_certificate_validation,
|
| 1576 |
+
tls_maximum_version=self.tls_maximum_version,
|
| 1577 |
+
tls_minimum_version=self.tls_minimum_version,
|
| 1578 |
+
key_password=certs[0][2],
|
| 1579 |
+
)
|
| 1580 |
+
else:
|
| 1581 |
+
conn = self.connections[conn_key] = connection_type(
|
| 1582 |
+
authority,
|
| 1583 |
+
timeout=self.timeout,
|
| 1584 |
+
proxy_info=self.proxy_info,
|
| 1585 |
+
ca_certs=self.ca_certs,
|
| 1586 |
+
disable_ssl_certificate_validation=self.disable_ssl_certificate_validation,
|
| 1587 |
+
tls_maximum_version=self.tls_maximum_version,
|
| 1588 |
+
tls_minimum_version=self.tls_minimum_version,
|
| 1589 |
+
)
|
| 1590 |
+
else:
|
| 1591 |
+
conn = self.connections[conn_key] = connection_type(
|
| 1592 |
+
authority, timeout=self.timeout, proxy_info=self.proxy_info
|
| 1593 |
+
)
|
| 1594 |
+
conn.set_debuglevel(debuglevel)
|
| 1595 |
+
|
| 1596 |
+
if "range" not in headers and "accept-encoding" not in headers:
|
| 1597 |
+
headers["accept-encoding"] = "gzip, deflate"
|
| 1598 |
+
|
| 1599 |
+
info = email.message.Message()
|
| 1600 |
+
cachekey = None
|
| 1601 |
+
cached_value = None
|
| 1602 |
+
if self.cache:
|
| 1603 |
+
cachekey = defrag_uri
|
| 1604 |
+
cached_value = self.cache.get(cachekey)
|
| 1605 |
+
if cached_value:
|
| 1606 |
+
try:
|
| 1607 |
+
info, content = cached_value.split(b"\r\n\r\n", 1)
|
| 1608 |
+
info = email.message_from_bytes(info)
|
| 1609 |
+
for k, v in info.items():
|
| 1610 |
+
if v.startswith("=?") and v.endswith("?="):
|
| 1611 |
+
info.replace_header(k, str(*email.header.decode_header(v)[0]))
|
| 1612 |
+
except (IndexError, ValueError):
|
| 1613 |
+
self.cache.delete(cachekey)
|
| 1614 |
+
cachekey = None
|
| 1615 |
+
cached_value = None
|
| 1616 |
+
|
| 1617 |
+
if (
|
| 1618 |
+
method in self.optimistic_concurrency_methods
|
| 1619 |
+
and self.cache
|
| 1620 |
+
and "etag" in info
|
| 1621 |
+
and not self.ignore_etag
|
| 1622 |
+
and "if-match" not in headers
|
| 1623 |
+
):
|
| 1624 |
+
# http://www.w3.org/1999/04/Editing/
|
| 1625 |
+
headers["if-match"] = info["etag"]
|
| 1626 |
+
|
| 1627 |
+
# https://tools.ietf.org/html/rfc7234
|
| 1628 |
+
# A cache MUST invalidate the effective Request URI as well as [...] Location and Content-Location
|
| 1629 |
+
# when a non-error status code is received in response to an unsafe request method.
|
| 1630 |
+
if self.cache and cachekey and method not in self.safe_methods:
|
| 1631 |
+
self.cache.delete(cachekey)
|
| 1632 |
+
|
| 1633 |
+
# Check the vary header in the cache to see if this request
|
| 1634 |
+
# matches what varies in the cache.
|
| 1635 |
+
if method in self.safe_methods and "vary" in info:
|
| 1636 |
+
vary = info["vary"]
|
| 1637 |
+
vary_headers = vary.lower().replace(" ", "").split(",")
|
| 1638 |
+
for header in vary_headers:
|
| 1639 |
+
key = "-varied-%s" % header
|
| 1640 |
+
value = info[key]
|
| 1641 |
+
if headers.get(header, None) != value:
|
| 1642 |
+
cached_value = None
|
| 1643 |
+
break
|
| 1644 |
+
|
| 1645 |
+
if (
|
| 1646 |
+
self.cache
|
| 1647 |
+
and cached_value
|
| 1648 |
+
and (method in self.safe_methods or info["status"] == "308")
|
| 1649 |
+
and "range" not in headers
|
| 1650 |
+
):
|
| 1651 |
+
redirect_method = method
|
| 1652 |
+
if info["status"] not in ("307", "308"):
|
| 1653 |
+
redirect_method = "GET"
|
| 1654 |
+
if "-x-permanent-redirect-url" in info:
|
| 1655 |
+
# Should cached permanent redirects be counted in our redirection count? For now, yes.
|
| 1656 |
+
if redirections <= 0:
|
| 1657 |
+
raise RedirectLimit(
|
| 1658 |
+
"Redirected more times than redirection_limit allows.", {}, "",
|
| 1659 |
+
)
|
| 1660 |
+
(response, new_content) = self.request(
|
| 1661 |
+
info["-x-permanent-redirect-url"],
|
| 1662 |
+
method=redirect_method,
|
| 1663 |
+
headers=headers,
|
| 1664 |
+
redirections=redirections - 1,
|
| 1665 |
+
)
|
| 1666 |
+
response.previous = Response(info)
|
| 1667 |
+
response.previous.fromcache = True
|
| 1668 |
+
else:
|
| 1669 |
+
# Determine our course of action:
|
| 1670 |
+
# Is the cached entry fresh or stale?
|
| 1671 |
+
# Has the client requested a non-cached response?
|
| 1672 |
+
#
|
| 1673 |
+
# There seems to be three possible answers:
|
| 1674 |
+
# 1. [FRESH] Return the cache entry w/o doing a GET
|
| 1675 |
+
# 2. [STALE] Do the GET (but add in cache validators if available)
|
| 1676 |
+
# 3. [TRANSPARENT] Do a GET w/o any cache validators (Cache-Control: no-cache) on the request
|
| 1677 |
+
entry_disposition = _entry_disposition(info, headers)
|
| 1678 |
+
|
| 1679 |
+
if entry_disposition == "FRESH":
|
| 1680 |
+
response = Response(info)
|
| 1681 |
+
response.fromcache = True
|
| 1682 |
+
return (response, content)
|
| 1683 |
+
|
| 1684 |
+
if entry_disposition == "STALE":
|
| 1685 |
+
if "etag" in info and not self.ignore_etag and not "if-none-match" in headers:
|
| 1686 |
+
headers["if-none-match"] = info["etag"]
|
| 1687 |
+
if "last-modified" in info and not "last-modified" in headers:
|
| 1688 |
+
headers["if-modified-since"] = info["last-modified"]
|
| 1689 |
+
elif entry_disposition == "TRANSPARENT":
|
| 1690 |
+
pass
|
| 1691 |
+
|
| 1692 |
+
(response, new_content) = self._request(
|
| 1693 |
+
conn, authority, uri, request_uri, method, body, headers, redirections, cachekey,
|
| 1694 |
+
)
|
| 1695 |
+
|
| 1696 |
+
if response.status == 304 and method == "GET":
|
| 1697 |
+
# Rewrite the cache entry with the new end-to-end headers
|
| 1698 |
+
# Take all headers that are in response
|
| 1699 |
+
# and overwrite their values in info.
|
| 1700 |
+
# unless they are hop-by-hop, or are listed in the connection header.
|
| 1701 |
+
|
| 1702 |
+
for key in _get_end2end_headers(response):
|
| 1703 |
+
info[key] = response[key]
|
| 1704 |
+
merged_response = Response(info)
|
| 1705 |
+
if hasattr(response, "_stale_digest"):
|
| 1706 |
+
merged_response._stale_digest = response._stale_digest
|
| 1707 |
+
_updateCache(headers, merged_response, content, self.cache, cachekey)
|
| 1708 |
+
response = merged_response
|
| 1709 |
+
response.status = 200
|
| 1710 |
+
response.fromcache = True
|
| 1711 |
+
|
| 1712 |
+
elif response.status == 200:
|
| 1713 |
+
content = new_content
|
| 1714 |
+
else:
|
| 1715 |
+
self.cache.delete(cachekey)
|
| 1716 |
+
content = new_content
|
| 1717 |
+
else:
|
| 1718 |
+
cc = _parse_cache_control(headers)
|
| 1719 |
+
if "only-if-cached" in cc:
|
| 1720 |
+
info["status"] = "504"
|
| 1721 |
+
response = Response(info)
|
| 1722 |
+
content = b""
|
| 1723 |
+
else:
|
| 1724 |
+
(response, content) = self._request(
|
| 1725 |
+
conn, authority, uri, request_uri, method, body, headers, redirections, cachekey,
|
| 1726 |
+
)
|
| 1727 |
+
except Exception as e:
|
| 1728 |
+
is_timeout = isinstance(e, socket.timeout)
|
| 1729 |
+
if is_timeout:
|
| 1730 |
+
conn = self.connections.pop(conn_key, None)
|
| 1731 |
+
if conn:
|
| 1732 |
+
conn.close()
|
| 1733 |
+
|
| 1734 |
+
if self.force_exception_to_status_code:
|
| 1735 |
+
if isinstance(e, HttpLib2ErrorWithResponse):
|
| 1736 |
+
response = e.response
|
| 1737 |
+
content = e.content
|
| 1738 |
+
response.status = 500
|
| 1739 |
+
response.reason = str(e)
|
| 1740 |
+
elif isinstance(e, socket.timeout):
|
| 1741 |
+
content = b"Request Timeout"
|
| 1742 |
+
response = Response({"content-type": "text/plain", "status": "408", "content-length": len(content),})
|
| 1743 |
+
response.reason = "Request Timeout"
|
| 1744 |
+
else:
|
| 1745 |
+
content = str(e).encode("utf-8")
|
| 1746 |
+
response = Response({"content-type": "text/plain", "status": "400", "content-length": len(content),})
|
| 1747 |
+
response.reason = "Bad Request"
|
| 1748 |
+
else:
|
| 1749 |
+
raise
|
| 1750 |
+
|
| 1751 |
+
return (response, content)
|
| 1752 |
+
|
| 1753 |
+
|
| 1754 |
+
class Response(dict):
|
| 1755 |
+
"""An object more like email.message than httplib.HTTPResponse."""
|
| 1756 |
+
|
| 1757 |
+
"""Is this response from our local cache"""
|
| 1758 |
+
fromcache = False
|
| 1759 |
+
"""HTTP protocol version used by server.
|
| 1760 |
+
|
| 1761 |
+
10 for HTTP/1.0, 11 for HTTP/1.1.
|
| 1762 |
+
"""
|
| 1763 |
+
version = 11
|
| 1764 |
+
|
| 1765 |
+
"Status code returned by server. "
|
| 1766 |
+
status = 200
|
| 1767 |
+
"""Reason phrase returned by server."""
|
| 1768 |
+
reason = "Ok"
|
| 1769 |
+
|
| 1770 |
+
previous = None
|
| 1771 |
+
|
| 1772 |
+
def __init__(self, info):
|
| 1773 |
+
# info is either an email.message or
|
| 1774 |
+
# an httplib.HTTPResponse object.
|
| 1775 |
+
if isinstance(info, http.client.HTTPResponse):
|
| 1776 |
+
for key, value in info.getheaders():
|
| 1777 |
+
key = key.lower()
|
| 1778 |
+
prev = self.get(key)
|
| 1779 |
+
if prev is not None:
|
| 1780 |
+
value = ", ".join((prev, value))
|
| 1781 |
+
self[key] = value
|
| 1782 |
+
self.status = info.status
|
| 1783 |
+
self["status"] = str(self.status)
|
| 1784 |
+
self.reason = info.reason
|
| 1785 |
+
self.version = info.version
|
| 1786 |
+
elif isinstance(info, email.message.Message):
|
| 1787 |
+
for key, value in list(info.items()):
|
| 1788 |
+
self[key.lower()] = value
|
| 1789 |
+
self.status = int(self["status"])
|
| 1790 |
+
else:
|
| 1791 |
+
for key, value in info.items():
|
| 1792 |
+
self[key.lower()] = value
|
| 1793 |
+
self.status = int(self.get("status", self.status))
|
| 1794 |
+
|
| 1795 |
+
def __getattr__(self, name):
|
| 1796 |
+
if name == "dict":
|
| 1797 |
+
return self
|
| 1798 |
+
else:
|
| 1799 |
+
raise AttributeError(name)
|
.venv/lib/python3.11/site-packages/httplib2/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (80.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/httplib2/__pycache__/auth.cpython-311.pyc
ADDED
|
Binary file (4.36 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/httplib2/__pycache__/certs.cpython-311.pyc
ADDED
|
Binary file (1.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/httplib2/__pycache__/error.cpython-311.pyc
ADDED
|
Binary file (2.69 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/httplib2/__pycache__/iri2uri.cpython-311.pyc
ADDED
|
Binary file (4.88 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/httplib2/__pycache__/socks.cpython-311.pyc
ADDED
|
Binary file (26.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/httplib2/auth.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
import pyparsing as pp
|
| 5 |
+
|
| 6 |
+
from .error import *
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
try: # pyparsing>=3.0.0
|
| 10 |
+
downcaseTokens = pp.common.downcaseTokens
|
| 11 |
+
except AttributeError:
|
| 12 |
+
downcaseTokens = pp.downcaseTokens
|
| 13 |
+
|
| 14 |
+
UNQUOTE_PAIRS = re.compile(r"\\(.)")
|
| 15 |
+
unquote = lambda s, l, t: UNQUOTE_PAIRS.sub(r"\1", t[0][1:-1])
|
| 16 |
+
|
| 17 |
+
# https://tools.ietf.org/html/rfc7235#section-1.2
|
| 18 |
+
# https://tools.ietf.org/html/rfc7235#appendix-B
|
| 19 |
+
tchar = "!#$%&'*+-.^_`|~" + pp.nums + pp.alphas
|
| 20 |
+
token = pp.Word(tchar).setName("token")
|
| 21 |
+
token68 = pp.Combine(pp.Word("-._~+/" + pp.nums + pp.alphas) + pp.Optional(pp.Word("=").leaveWhitespace())).setName(
|
| 22 |
+
"token68"
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
quoted_string = pp.dblQuotedString.copy().setName("quoted-string").setParseAction(unquote)
|
| 26 |
+
auth_param_name = token.copy().setName("auth-param-name").addParseAction(downcaseTokens)
|
| 27 |
+
auth_param = auth_param_name + pp.Suppress("=") + (quoted_string | token)
|
| 28 |
+
params = pp.Dict(pp.delimitedList(pp.Group(auth_param)))
|
| 29 |
+
|
| 30 |
+
scheme = token("scheme")
|
| 31 |
+
challenge = scheme + (params("params") | token68("token"))
|
| 32 |
+
|
| 33 |
+
authentication_info = params.copy()
|
| 34 |
+
www_authenticate = pp.delimitedList(pp.Group(challenge))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _parse_authentication_info(headers, headername="authentication-info"):
|
| 38 |
+
"""https://tools.ietf.org/html/rfc7615
|
| 39 |
+
"""
|
| 40 |
+
header = headers.get(headername, "").strip()
|
| 41 |
+
if not header:
|
| 42 |
+
return {}
|
| 43 |
+
try:
|
| 44 |
+
parsed = authentication_info.parseString(header)
|
| 45 |
+
except pp.ParseException as ex:
|
| 46 |
+
# print(ex.explain(ex))
|
| 47 |
+
raise MalformedHeader(headername)
|
| 48 |
+
|
| 49 |
+
return parsed.asDict()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _parse_www_authenticate(headers, headername="www-authenticate"):
|
| 53 |
+
"""Returns a dictionary of dictionaries, one dict per auth_scheme."""
|
| 54 |
+
header = headers.get(headername, "").strip()
|
| 55 |
+
if not header:
|
| 56 |
+
return {}
|
| 57 |
+
try:
|
| 58 |
+
parsed = www_authenticate.parseString(header)
|
| 59 |
+
except pp.ParseException as ex:
|
| 60 |
+
# print(ex.explain(ex))
|
| 61 |
+
raise MalformedHeader(headername)
|
| 62 |
+
|
| 63 |
+
retval = {
|
| 64 |
+
challenge["scheme"].lower(): challenge["params"].asDict()
|
| 65 |
+
if "params" in challenge
|
| 66 |
+
else {"token": challenge.get("token")}
|
| 67 |
+
for challenge in parsed
|
| 68 |
+
}
|
| 69 |
+
return retval
|
.venv/lib/python3.11/site-packages/httplib2/cacerts.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/httplib2/certs.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities for certificate management."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
certifi_available = False
|
| 6 |
+
certifi_where = None
|
| 7 |
+
try:
|
| 8 |
+
from certifi import where as certifi_where
|
| 9 |
+
certifi_available = True
|
| 10 |
+
except ImportError:
|
| 11 |
+
pass
|
| 12 |
+
|
| 13 |
+
custom_ca_locater_available = False
|
| 14 |
+
custom_ca_locater_where = None
|
| 15 |
+
try:
|
| 16 |
+
from ca_certs_locater import get as custom_ca_locater_where
|
| 17 |
+
custom_ca_locater_available = True
|
| 18 |
+
except ImportError:
|
| 19 |
+
pass
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
BUILTIN_CA_CERTS = os.path.join(
|
| 23 |
+
os.path.dirname(os.path.abspath(__file__)), "cacerts.txt"
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def where():
|
| 28 |
+
env = os.environ.get("HTTPLIB2_CA_CERTS")
|
| 29 |
+
if env is not None:
|
| 30 |
+
if os.path.isfile(env):
|
| 31 |
+
return env
|
| 32 |
+
else:
|
| 33 |
+
raise RuntimeError("Environment variable HTTPLIB2_CA_CERTS not a valid file")
|
| 34 |
+
if custom_ca_locater_available:
|
| 35 |
+
return custom_ca_locater_where()
|
| 36 |
+
if certifi_available:
|
| 37 |
+
return certifi_where()
|
| 38 |
+
return BUILTIN_CA_CERTS
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if __name__ == "__main__":
|
| 42 |
+
print(where())
|
.venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/INSTALLER
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
pip
|
.venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 Noam Gat
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
.venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/METADATA
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.1
|
| 2 |
+
Name: lm-format-enforcer
|
| 3 |
+
Version: 0.10.9
|
| 4 |
+
Summary: Enforce the output format (JSON Schema, Regex etc) of a language model
|
| 5 |
+
Home-page: https://github.com/noamgat/lm-format-enforcer
|
| 6 |
+
License: MIT
|
| 7 |
+
Author: Noam Gat
|
| 8 |
+
Author-email: noamgat@gmail.com
|
| 9 |
+
Requires-Python: >=3.8,<4.0
|
| 10 |
+
Classifier: Development Status :: 3 - Alpha
|
| 11 |
+
Classifier: Intended Audience :: Developers
|
| 12 |
+
Classifier: License :: OSI Approved :: MIT License
|
| 13 |
+
Classifier: Operating System :: OS Independent
|
| 14 |
+
Classifier: Programming Language :: Python
|
| 15 |
+
Classifier: Programming Language :: Python :: 3
|
| 16 |
+
Classifier: Programming Language :: Python :: 3.8
|
| 17 |
+
Classifier: Programming Language :: Python :: 3.9
|
| 18 |
+
Classifier: Programming Language :: Python :: 3.10
|
| 19 |
+
Classifier: Programming Language :: Python :: 3.11
|
| 20 |
+
Classifier: Programming Language :: Python :: 3.12
|
| 21 |
+
Classifier: Programming Language :: Python :: 3.13
|
| 22 |
+
Classifier: Programming Language :: Python :: 3 :: Only
|
| 23 |
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
| 24 |
+
Requires-Dist: interegular (>=0.3.2)
|
| 25 |
+
Requires-Dist: packaging
|
| 26 |
+
Requires-Dist: pydantic (>=1.10.8)
|
| 27 |
+
Requires-Dist: pyyaml
|
| 28 |
+
Project-URL: Bug Tracker, https://github.com/noamgat/lm-format-enforcer/issues
|
| 29 |
+
Project-URL: Documentation, https://github.com/noamgat/lm-format-enforcer
|
| 30 |
+
Project-URL: Repository, https://github.com/noamgat/lm-format-enforcer
|
| 31 |
+
Description-Content-Type: text/markdown
|
| 32 |
+
|
| 33 |
+
# lm-format-enforcer
|
| 34 |
+
|
| 35 |
+

|
| 36 |
+
|
| 37 |
+
**Enforce the output format (JSON Schema, Regex etc) of a language model**
|
| 38 |
+
|
| 39 |
+
<a target="_blank" href="https://colab.research.google.com/github/noamgat/lm-format-enforcer/blob/main/samples/colab_llama2_enforcer.ipynb">
|
| 40 |
+
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
| 41 |
+
</a>
|
| 42 |
+
|
| 43 |
+
[](https://codecov.io/gh/noamgat/lm-format-enforcer)
|
| 44 |
+

|
| 45 |
+
|
| 46 |
+
|
| 47 |
+

|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
Language models are able to generate text, but when requiring a precise output format, they do not always perform as instructed.
|
| 51 |
+
Various prompt engineering techniques have been introduced to improve the robustness of the generated text, but they are not always sufficient.
|
| 52 |
+
This project solves the issues by filtering the tokens that the language model is allowed to generate at every timestep, thus ensuring that the output format is respected, while minimizing the limitations on the language model.
|
| 53 |
+
|
| 54 |
+
## Installation
|
| 55 |
+
```pip install lm-format-enforcer```
|
| 56 |
+
|
| 57 |
+
## Basic Tutorial
|
| 58 |
+
```python
|
| 59 |
+
# Requirements if running from Google Colab with a T4 GPU.
|
| 60 |
+
!pip install transformers torch lm-format-enforcer huggingface_hub optimum
|
| 61 |
+
!pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
| 62 |
+
|
| 63 |
+
from pydantic import BaseModel
|
| 64 |
+
from lmformatenforcer import JsonSchemaParser
|
| 65 |
+
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn
|
| 66 |
+
from transformers import pipeline
|
| 67 |
+
|
| 68 |
+
class AnswerFormat(BaseModel):
|
| 69 |
+
first_name: str
|
| 70 |
+
last_name: str
|
| 71 |
+
year_of_birth: int
|
| 72 |
+
num_seasons_in_nba: int
|
| 73 |
+
|
| 74 |
+
# Create a transformers pipeline
|
| 75 |
+
hf_pipeline = pipeline('text-generation', model='TheBloke/Llama-2-7b-Chat-GPTQ', device_map='auto')
|
| 76 |
+
prompt = f'Here is information about Michael Jordan in the following json schema: {AnswerFormat.schema_json()} :\n'
|
| 77 |
+
|
| 78 |
+
# Create a character level parser and build a transformers prefix function from it
|
| 79 |
+
parser = JsonSchemaParser(AnswerFormat.schema())
|
| 80 |
+
prefix_function = build_transformers_prefix_allowed_tokens_fn(hf_pipeline.tokenizer, parser)
|
| 81 |
+
|
| 82 |
+
# Call the pipeline with the prefix function
|
| 83 |
+
output_dict = hf_pipeline(prompt, prefix_allowed_tokens_fn=prefix_function)
|
| 84 |
+
|
| 85 |
+
# Extract the results
|
| 86 |
+
result = output_dict[0]['generated_text'][len(prompt):]
|
| 87 |
+
print(result)
|
| 88 |
+
# {'first_name': 'Michael', 'last_name': 'Jordan', 'year_of_birth': 1963, 'num_seasons_in_nba': 15}
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
## Capabilities / Advantages
|
| 92 |
+
|
| 93 |
+
- Works with any Python language model and tokenizer. Already supports [transformers](https://github.com/huggingface/transformers), [LangChain](https://python.langchain.com/docs/integrations/llms/lmformatenforcer_experimental), [LlamaIndex](https://docs.llamaindex.ai/en/latest/community/integrations/lmformatenforcer.html), [llama.cpp](https://github.com/noamgat/lm-format-enforcer/blob/main/samples/colab_llamacpppython_integration.ipynb), [vLLM](https://github.com/noamgat/lm-format-enforcer/blob/main/samples/colab_vllm_integration.ipynb), [Haystack](https://haystack.deepset.ai/integrations/lmformatenforcer), [NVIDIA TensorRT-LLM](https://github.com/noamgat/lm-format-enforcer/blob/main/samples/colab_trtllm_integration.ipynb) and [ExLlamaV2](https://github.com/noamgat/lm-format-enforcer/blob/main/samples/colab_exllamav2_integration.ipynb).
|
| 94 |
+
- Supports batched generation and beam searches - each input / beam can have different tokens filtered at every timestep
|
| 95 |
+
- Supports JSON Schema, JSON Mode (schemaless) and Regular Expression formats
|
| 96 |
+
- Supports both required and optional fields in JSON schemas
|
| 97 |
+
- Supports nested fields, arrays and dictionaries in JSON schemas
|
| 98 |
+
- Gives the language model freedom to control whitespacing and field ordering in JSON schemas, reducing hallucinations.
|
| 99 |
+
- Does not modify the high level loop of transformers API, so can be used in any scenario.
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
## Comparison to other libraries
|
| 103 |
+
|
| 104 |
+
Capability | LM Format Enforcer | [Guidance](https://github.com/guidance-ai/guidance) | [Jsonformer](https://github.com/1rgs/jsonformer) | [Outlines](https://github.com/outlines-dev/outlines)
|
| 105 |
+
:------------ | :-------------| :-------------| :------------- | :----
|
| 106 |
+
Regular Expressions | ✅ | ✅ | ❌ | ✅
|
| 107 |
+
JSON Schema | ✅ | 🟡 ([Partial conversion is possible](https://github.com/guidance-ai/guidance/blob/main/notebooks/applications/jsonformer.ipynb)) | ✅ | ✅
|
| 108 |
+
Batched Generation | ✅ | ❌ | ❌ | ✅
|
| 109 |
+
Beam Search | ✅ | ❌ | ❌ | ✅
|
| 110 |
+
Integrates into existing pipelines | ✅ | ❌ | ❌ | ✅
|
| 111 |
+
Optional JSON Fields | ✅ | ❌ | ❌ | ❌
|
| 112 |
+
LLM Controls JSON field ordering and whitespace | ✅ | ❌ | ❌ | ❌
|
| 113 |
+
JSON Schema with recursive classes | ✅ | ❌ | ✅ | ❌
|
| 114 |
+
Visual model support | [✅](https://github.com/noamgat/lm-format-enforcer/blob/main/samples/colab_llama32_vision_enforcer.ipynb) | ✅ | ❌ | ❌
|
| 115 |
+
|
| 116 |
+
Spotted a mistake? Library updated with new capabilities? [Open an issue!](https://github.com/noamgat/lm-format-enforcer/issues)
|
| 117 |
+
|
| 118 |
+
## Detailed example
|
| 119 |
+
|
| 120 |
+
We created a Google Colab Notebook which contains a full example of how to use this library to enforce the output format of llama2, including interpreting the intermediate results. The notebook can run on a free GPU-backed runtime in Colab.
|
| 121 |
+
|
| 122 |
+
<a target="_blank" href="https://colab.research.google.com/github/noamgat/lm-format-enforcer/blob/main/samples/colab_llama2_enforcer.ipynb">
|
| 123 |
+
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
| 124 |
+
</a>
|
| 125 |
+
|
| 126 |
+
You can also [view the notebook in GitHub](https://github.com/noamgat/lm-format-enforcer/blob/main/samples/colab_llama2_enforcer.ipynb).
|
| 127 |
+
|
| 128 |
+
For the different ways to integrate with huggingface transformers, see the [unit tests](https://github.com/noamgat/lm-format-enforcer/blob/main/tests/test_transformerenforcer.py).
|
| 129 |
+
|
| 130 |
+
## vLLM Server Integration
|
| 131 |
+
|
| 132 |
+
LM Format Enforcer is integrated into the [vLLM](https://github.com/vllm-project/vllm) inference server. vLLM includes an [OpenAI compatible server](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html) with added capabilities that allow using LM Format Enforcer without writing custom inference code.
|
| 133 |
+
|
| 134 |
+
Use LM Format Enforcer with the vLLM OpenAI Server either by adding the [vLLM command line parameter](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server):
|
| 135 |
+
|
| 136 |
+
```
|
| 137 |
+
python -m vllm.entrypoints.openai.api_server \
|
| 138 |
+
--model mistralai/Mistral-7B-Instruct-v0.2 \
|
| 139 |
+
--guided-decoding-backend lm-format-enforcer
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
Or on a per-request basis, by adding the `guided_decoding_backend` parameter to the request together with the guided decoding parameters:
|
| 143 |
+
|
| 144 |
+
```
|
| 145 |
+
completion = client.chat.completions.create(
|
| 146 |
+
model="mistralai/Mistral-7B-Instruct-v0.2",
|
| 147 |
+
messages=[
|
| 148 |
+
{"role": "user", "content": "Classify this sentiment: LMFE is wonderful!"}
|
| 149 |
+
],
|
| 150 |
+
extra_body={
|
| 151 |
+
"guided_regex": "[Pp]ositive|[Nn]egative",
|
| 152 |
+
"guided_decoding_backend": "lm-format-enforcer"
|
| 153 |
+
}
|
| 154 |
+
)
|
| 155 |
+
```
|
| 156 |
+
Json schema and choice decoding also supported via `guided_json` and `guided_choice` [extra parameters](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-chat-api).
|
| 157 |
+
|
| 158 |
+
## How does it work?
|
| 159 |
+
|
| 160 |
+
The library works by combining a character level parser and a tokenizer prefix tree into a smart token filtering mechanism.
|
| 161 |
+
|
| 162 |
+

|
| 163 |
+
|
| 164 |
+
### Character Level Parser
|
| 165 |
+
|
| 166 |
+
Parsing a string into any kind of formatter can be looked at as an implicit tree structure - at any moment in the parsing process, there is a set of allowed next characters, and if any of them are selected, there is a new set of allowed next characters, and so on.
|
| 167 |
+
|
| 168 |
+
```CharacterLevelParser``` is an interface for parsing according to this implicit structure. ```add_character()``` and ```get_allowed_characters()``` can be seen as tree traversal methods.
|
| 169 |
+
|
| 170 |
+
There are several implementations of this interface:
|
| 171 |
+
- ```JsonSchemaParser``` - parses according to a json schema (or pure json output - `JsonSchemaParser(None) will result in any json object allowed`).
|
| 172 |
+
- ```StringParser``` - forces an exact string (used mainly for diagnostics)
|
| 173 |
+
- ```RegexParser``` - parses according to a regular expression. Note that this cannot use the built in python regex and uses a manually implemented one (via the [interegular](https://pypi.org/project/interegular/) library), so it doesn't cover 100% of the regex standard.
|
| 174 |
+
### Tokenizer Prefix Tree
|
| 175 |
+
|
| 176 |
+
Given a tokenizer used by a certain language model, we can build a prefix tree of all the tokens that the language model can generate. This is done by generating all possible sequences of tokens, and adding them to the tree.
|
| 177 |
+
See ```TokenizerPrefixTree```
|
| 178 |
+
|
| 179 |
+
### Combining the two
|
| 180 |
+
|
| 181 |
+
Given a character level parser and a tokenizer prefix tree, we can elegantly and efficiently filter the tokens that the language model is allowed to generate at the next timestep:
|
| 182 |
+
We only traverse the characters that are in BOTH the character level parsing node and the tokenizer prefix tree node. This allows us to find all of the tokens (including complex subword tokens such as ```","``` which are critical in JSON parsing).
|
| 183 |
+
We do this recursively on both trees and return all of the allowed tokens. When the language model generates a token, we advance the character level parser according to the new characters, ready to filter the next timestep.
|
| 184 |
+
|
| 185 |
+
### How is this approach different? Why is it good?
|
| 186 |
+
|
| 187 |
+
This is not the first library to enforce the output format of a language model. However, other similar libraries (such as Guidance, JsonFormer and Outlines) enforce an exact output format. This means that the language model is not allowed to control whitespacing, field optionality and field ordering (in the JSON usecase). While this seems inconsequencial to humans, it means that the language model may not be generating the JSON formats that it "wants to" generate, and could put its internal states in a suboptimal value, reducing the quality of the output in later timesteps.
|
| 188 |
+
|
| 189 |
+
This forces language model users to know the details of the language model they are using (for example - were JSONs minified before pretraining?) and modify the libraries to generate the precise format.
|
| 190 |
+
|
| 191 |
+
We avoid this problem by scanning potential next tokens and allowing any token sequence that will be parsed into the output format. This means that the language model can control all of these aspects, and output the token sequence that matches its' style in the most natural way, without requiring the developer to know the details.
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
## Diagnostics - Will I always get good results?
|
| 195 |
+
|
| 196 |
+
Using this library guarantees that the output will match the format, but it does not guarantee that the output will be semantically correct. Forcing the language model to conform to a certain output may lead to increased hallucinations. Guiding the model via prompt engineering is still likely to improve results.
|
| 197 |
+
|
| 198 |
+
In order to help you understand the aggressiveness caused by the format enforcement, if you pass ```output_scores=True``` and ```return_dict_in_generate=True``` in the ```kwargs``` to ```generate_enforced()``` (these are existing optional parameters in the ```transformers``` library), you will also get a token-by-token dataframe showing which token was selected, its score, and what was the token that would have been chosen if the format enforcement was not applied. If you see that the format enforcer forced the language model to select tokens with very low weights, it is a likely contributor to the poor results. Try modifying the prompt to guide the language model to not force the format enforcer to be so aggressive.
|
| 199 |
+
|
| 200 |
+
Example using the regular expression format ``` Michael Jordan was Born in (\d)+.```
|
| 201 |
+
|
| 202 |
+
idx | generated_token | generated_token_idx | generated_score | leading_token | leading_token_idx | leading_score
|
| 203 |
+
:------------ | :-------------| :-------------| :------------- | :------------ | :-------------| :-------------
|
| 204 |
+
0 | ▁ | 29871 | 1.000000 | ▁ | 29871 | 1.000000
|
| 205 |
+
1 | Michael | 24083 | 0.000027 | ▁Sure | 18585 | 0.959473
|
| 206 |
+
2 | ▁Jordan | 18284 | 1.000000 | ▁Jordan | 18284 | 1.000000
|
| 207 |
+
3 | ▁was | 471 | 1.000000 | ▁was | 471 | 1.000000
|
| 208 |
+
4 | ▁Born | 19298 | 0.000008 | ▁born | 6345 | 1.000000
|
| 209 |
+
5 | ▁in | 297 | 0.994629 | ▁in | 297 | 0.994629
|
| 210 |
+
6 | ▁ | 29871 | 0.982422 | ▁ | 29871 | 0.982422
|
| 211 |
+
7 | 1 | 29896 | 1.000000 | 1 | 29896 | 1.000000
|
| 212 |
+
8 | 9 | 29929 | 1.000000 | 9 | 29929 | 1.000000
|
| 213 |
+
9 | 6 | 29953 | 1.000000 | 6 | 29953 | 1.000000
|
| 214 |
+
10 | 3 | 29941 | 1.000000 | 3 | 29941 | 1.000000
|
| 215 |
+
11 | . | 29889 | 0.999512 | . | 29889 | 0.999512
|
| 216 |
+
12 | ```</s>``` | 2 | 0.981445 | ```</s>``` | 2 | 0.981445
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
You can see that the model "wanted" to start the answer using ```Sure```, but the format enforcer forced it to use ```Michael``` - there was a big gap in token 1. Afterwards, almost all of the leading scores are all within the allowed token set, meaning the model likely did not hallucinate due to the token forcing. The only exception was timestep 4 - " Born" was forced while the LLM wanted to choose "born". This is a hint for the prompt engineer, to change the prompt to use a lowercase b instead.
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
## Configuration options
|
| 223 |
+
|
| 224 |
+
LM Format Enforcer makes use of several heuristics to avoid edge cases that may happen with LLM's generating structure outputs.
|
| 225 |
+
There are two ways to control these heuristics:
|
| 226 |
+
|
| 227 |
+
### Option 1: via Environment Variables
|
| 228 |
+
|
| 229 |
+
There are several environment variables that can be set, that affect the operation of the library. This method is useful when you don't want to modify the code, for example when using the library through the vLLM OpenAI server.
|
| 230 |
+
|
| 231 |
+
- `LMFE_MAX_CONSECUTIVE_WHITESPACES` - How many consecutive whitespaces are allowed when parsing JsonSchemaObjects. Default: 12.
|
| 232 |
+
- `LMFE_STRICT_JSON_FIELD_ORDER` - Should the JsonSchemaParser force the properties to appear in the same order as they appear in the 'required' list of the JsonSchema? (Note: this is consistent with the order of declaration in Pydantic models). Default: False.
|
| 233 |
+
- `LMFE_MAX_JSON_ARRAY_LENGTH` - What is the maximal JSON array length, if not specified by the schema. Helps LLM Avoid infinite loops. Default: 20.
|
| 234 |
+
|
| 235 |
+
### Option 2: via the CharacterLevelParserConfig class
|
| 236 |
+
When using the library through code, any `CharacterLevelParser` (`JsonSchemaParser`, `RegexParser` etc) constructor receives an optional `CharacterLevelParserConfig` object.
|
| 237 |
+
|
| 238 |
+
Therefore, to configure the heuristics of a single parser, instantiate a `CharacterLevelParserConfig` object, modify its values and pass it to the `CharacterLevelParser`'s constructor.
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
## Known issues and limitations
|
| 243 |
+
|
| 244 |
+
- LM Format Enforcer requires a python API to process the output logits of the language model. This means that until the APIs are extended, it can not be used with OpenAI ChatGPT and similar API based solutions.
|
| 245 |
+
- Regular expression syntax is not 100% supported. See [interegular](https://pypi.org/project/interegular/) for more details.
|
| 246 |
+
- LM Format Enforcer Regex Parser can only generate characters that exist in the tokenizer vocabulary. This may be solved in a later version, see [the issue on GitHub](https://github.com/noamgat/lm-format-enforcer/issues/13).
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
## Contributers and contributing
|
| 250 |
+
|
| 251 |
+
See [CONTRIBUTORS.md](https://github.com/noamgat/lm-format-enforcer/blob/main/CONTRIBUTORS.md) for a list of contributers.
|
| 252 |
+
|
.venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/RECORD
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
lm_format_enforcer-0.10.9.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
| 2 |
+
lm_format_enforcer-0.10.9.dist-info/LICENSE,sha256=0cAjc_naVKu0D7n6XKbBkTbRDadqEFT_HuJSgiGuG_4,1065
|
| 3 |
+
lm_format_enforcer-0.10.9.dist-info/METADATA,sha256=m23Ia1CRbDDrL7N-gnortFDjA-w9YAy_2HP6EmbH5Mg,17143
|
| 4 |
+
lm_format_enforcer-0.10.9.dist-info/RECORD,,
|
| 5 |
+
lm_format_enforcer-0.10.9.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
| 6 |
+
lmformatenforcer/__init__.py,sha256=VQlyXyCK1HC4an39hFEdX9hae5vod1wY9DjHTVnJUCM,850
|
| 7 |
+
lmformatenforcer/__pycache__/__init__.cpython-311.pyc,,
|
| 8 |
+
lmformatenforcer/__pycache__/analyzer.cpython-311.pyc,,
|
| 9 |
+
lmformatenforcer/__pycache__/characterlevelparser.cpython-311.pyc,,
|
| 10 |
+
lmformatenforcer/__pycache__/consts.cpython-311.pyc,,
|
| 11 |
+
lmformatenforcer/__pycache__/exceptions.cpython-311.pyc,,
|
| 12 |
+
lmformatenforcer/__pycache__/jsonschemaparser.cpython-311.pyc,,
|
| 13 |
+
lmformatenforcer/__pycache__/regexparser.cpython-311.pyc,,
|
| 14 |
+
lmformatenforcer/__pycache__/tokenenforcer.cpython-311.pyc,,
|
| 15 |
+
lmformatenforcer/__pycache__/tokenizerprefixtree.cpython-311.pyc,,
|
| 16 |
+
lmformatenforcer/analyzer.py,sha256=imn5kKVaY833GY1D7qfY-A-7hJ0oQz9E3tcXIww5uTY,3893
|
| 17 |
+
lmformatenforcer/characterlevelparser.py,sha256=f3MrMEC_Qfoo1uZjqtDiAiBsRWipMRc1MFDlRB-nd-Y,8557
|
| 18 |
+
lmformatenforcer/consts.py,sha256=g30hYNKwlgc564aX6ZTYbGPQsv5iHiXEPx0eSNyZbSw,1094
|
| 19 |
+
lmformatenforcer/exceptions.py,sha256=oJuEhaGwaawtgGULhTSL-mnVVNiXuaBC_UNlh6MEfX8,104
|
| 20 |
+
lmformatenforcer/external/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
| 21 |
+
lmformatenforcer/external/__pycache__/__init__.cpython-311.pyc,,
|
| 22 |
+
lmformatenforcer/external/__pycache__/jsonschemaobject.cpython-311.pyc,,
|
| 23 |
+
lmformatenforcer/external/__pycache__/jsonschemaobjectutil.cpython-311.pyc,,
|
| 24 |
+
lmformatenforcer/external/jsonschemaobject.py,sha256=XT1C-T8MtGzudm-hG5U1isRofLkPPoA5cA1ZjqYr7cs,10716
|
| 25 |
+
lmformatenforcer/external/jsonschemaobjectutil.py,sha256=3W0IVrUDS6k0ydp-Nhd1ymQrDi4Fe3eO8naX-o20zl0,7044
|
| 26 |
+
lmformatenforcer/integrations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
| 27 |
+
lmformatenforcer/integrations/__pycache__/__init__.cpython-311.pyc,,
|
| 28 |
+
lmformatenforcer/integrations/__pycache__/exllamav2.cpython-311.pyc,,
|
| 29 |
+
lmformatenforcer/integrations/__pycache__/haystackv1.cpython-311.pyc,,
|
| 30 |
+
lmformatenforcer/integrations/__pycache__/haystackv2.cpython-311.pyc,,
|
| 31 |
+
lmformatenforcer/integrations/__pycache__/llamacpp.cpython-311.pyc,,
|
| 32 |
+
lmformatenforcer/integrations/__pycache__/transformers.cpython-311.pyc,,
|
| 33 |
+
lmformatenforcer/integrations/__pycache__/trtllm.cpython-311.pyc,,
|
| 34 |
+
lmformatenforcer/integrations/__pycache__/vllm.cpython-311.pyc,,
|
| 35 |
+
lmformatenforcer/integrations/exllamav2.py,sha256=usZWTriLVp4aCA34Y8niehZTrOuY3-Lpig_q6De4Kk8,2595
|
| 36 |
+
lmformatenforcer/integrations/haystackv1.py,sha256=WZ43iebe8Hag3J3ndAPfRgfxd-JWxhFVoUNvjA485ag,2820
|
| 37 |
+
lmformatenforcer/integrations/haystackv2.py,sha256=YLkgadglK4d0Mf2-gJh4UlJiChZouIPjKxk2hYqAq7o,3510
|
| 38 |
+
lmformatenforcer/integrations/llamacpp.py,sha256=W8MckPo2Jtf0_As3Aa4kCaKVQRsgPIVrKyDLX7EWvNQ,3572
|
| 39 |
+
lmformatenforcer/integrations/transformers.py,sha256=PjglkWNJYq_3lMrLOxqXSD-6vb2WTEWCMx9Czl7DYag,7034
|
| 40 |
+
lmformatenforcer/integrations/trtllm.py,sha256=zbk14D7qjrEHUTXYJRWO1Ql4-_VbdtUbwFfqj1-1BLU,3869
|
| 41 |
+
lmformatenforcer/integrations/vllm.py,sha256=mBnjXlIuVLanwiqikLi6N2ZWBgqbEH5m8sjL5arPDBM,2987
|
| 42 |
+
lmformatenforcer/jsonschemaparser.py,sha256=qLt8-Pu3_ZyTsvhYMq5oPx_kxnW19GPaTyWhOdJ_ctI,34589
|
| 43 |
+
lmformatenforcer/regexparser.py,sha256=YoPSowclmDe6M4YbBabf3im2wAyUqzpKkQXAEMYMnEQ,3888
|
| 44 |
+
lmformatenforcer/tokenenforcer.py,sha256=AnvFQWdMHZ8gC5wgBNH5AURqQtJ44cTrHGjGIKvWpAI,10240
|
| 45 |
+
lmformatenforcer/tokenizerprefixtree.py,sha256=EYtWFfzXVY49OMAKm01MvA3SDSGfZl-mwc3WJwMJOUw,6813
|
.venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/WHEEL
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Wheel-Version: 1.0
|
| 2 |
+
Generator: poetry-core 1.9.1
|
| 3 |
+
Root-Is-Purelib: true
|
| 4 |
+
Tag: py3-none-any
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (193 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_adv.h
ADDED
|
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/* cudnn_adv : cuDNN's advanced and experimental features.
|
| 51 |
+
|
| 52 |
+
*/
|
| 53 |
+
|
| 54 |
+
#if !defined(CUDNN_ADV_H_)
|
| 55 |
+
#define CUDNN_ADV_H_
|
| 56 |
+
|
| 57 |
+
#include <stdint.h>
|
| 58 |
+
|
| 59 |
+
#include "cudnn_version.h"
|
| 60 |
+
#include "cudnn_ops.h"
|
| 61 |
+
|
| 62 |
+
/* These version numbers are autogenerated, do not edit manually. */
|
| 63 |
+
#define CUDNN_ADV_MAJOR 9
|
| 64 |
+
#define CUDNN_ADV_MINOR 1
|
| 65 |
+
#define CUDNN_ADV_PATCH 0
|
| 66 |
+
|
| 67 |
+
#if (CUDNN_ADV_MAJOR != CUDNN_MAJOR) || (CUDNN_ADV_MINOR != CUDNN_MINOR) || (CUDNN_ADV_PATCH != CUDNN_PATCHLEVEL)
|
| 68 |
+
#error Version mismatch in cuDNN ADV INFER!!!
|
| 69 |
+
#endif
|
| 70 |
+
|
| 71 |
+
#if defined(__cplusplus)
|
| 72 |
+
extern "C" {
|
| 73 |
+
#endif
|
| 74 |
+
|
| 75 |
+
/* BASIC RNN API */
|
| 76 |
+
|
| 77 |
+
typedef enum {
|
| 78 |
+
CUDNN_RNN_ALGO_STANDARD = 0,
|
| 79 |
+
CUDNN_RNN_ALGO_PERSIST_STATIC = 1,
|
| 80 |
+
CUDNN_RNN_ALGO_PERSIST_DYNAMIC = 2,
|
| 81 |
+
CUDNN_RNN_ALGO_PERSIST_STATIC_SMALL_H = 3,
|
| 82 |
+
CUDNN_RNN_ALGO_COUNT = 4,
|
| 83 |
+
} cudnnRNNAlgo_t;
|
| 84 |
+
|
| 85 |
+
typedef enum {
|
| 86 |
+
CUDNN_FWD_MODE_INFERENCE = 0,
|
| 87 |
+
CUDNN_FWD_MODE_TRAINING = 1,
|
| 88 |
+
} cudnnForwardMode_t;
|
| 89 |
+
|
| 90 |
+
typedef enum {
|
| 91 |
+
CUDNN_RNN_RELU = 0, /* basic RNN cell type with ReLu activation */
|
| 92 |
+
CUDNN_RNN_TANH = 1, /* basic RNN cell type with tanh activation */
|
| 93 |
+
CUDNN_LSTM = 2, /* LSTM with optional recurrent projection and clipping */
|
| 94 |
+
CUDNN_GRU = 3, /* Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1); */
|
| 95 |
+
} cudnnRNNMode_t;
|
| 96 |
+
|
| 97 |
+
typedef enum {
|
| 98 |
+
CUDNN_RNN_NO_BIAS = 0, /* rnn cell formulas do not use biases */
|
| 99 |
+
CUDNN_RNN_SINGLE_INP_BIAS = 1, /* rnn cell formulas use one input bias in input GEMM */
|
| 100 |
+
CUDNN_RNN_DOUBLE_BIAS = 2, /* default, rnn cell formulas use two bias vectors */
|
| 101 |
+
CUDNN_RNN_SINGLE_REC_BIAS = 3 /* rnn cell formulas use one recurrent bias in recurrent GEMM */
|
| 102 |
+
} cudnnRNNBiasMode_t;
|
| 103 |
+
|
| 104 |
+
typedef enum {
|
| 105 |
+
CUDNN_UNIDIRECTIONAL = 0, /* single direction network */
|
| 106 |
+
CUDNN_BIDIRECTIONAL = 1, /* output concatination at each layer */
|
| 107 |
+
} cudnnDirectionMode_t;
|
| 108 |
+
|
| 109 |
+
typedef enum {
|
| 110 |
+
CUDNN_LINEAR_INPUT = 0, /* adjustable weight matrix in first layer input GEMM */
|
| 111 |
+
CUDNN_SKIP_INPUT = 1, /* fixed identity matrix in the first layer input GEMM */
|
| 112 |
+
} cudnnRNNInputMode_t;
|
| 113 |
+
|
| 114 |
+
typedef enum {
|
| 115 |
+
CUDNN_RNN_CLIP_NONE = 0, /* disables LSTM cell clipping */
|
| 116 |
+
CUDNN_RNN_CLIP_MINMAX = 1, /* enables LSTM cell clipping */
|
| 117 |
+
} cudnnRNNClipMode_t;
|
| 118 |
+
|
| 119 |
+
typedef enum {
|
| 120 |
+
CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED = 0, /* padded, outer stride from one time-step to the next */
|
| 121 |
+
CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED = 1, /* sequence length sorted and packed as in basic RNN api */
|
| 122 |
+
CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED = 2, /* padded, outer stride from one batch to the next */
|
| 123 |
+
} cudnnRNNDataLayout_t;
|
| 124 |
+
|
| 125 |
+
/* For auxFlags in cudnnSetRNNDescriptor_v8() */
|
| 126 |
+
#define CUDNN_RNN_PADDED_IO_DISABLED 0
|
| 127 |
+
#define CUDNN_RNN_PADDED_IO_ENABLED (1U << 0)
|
| 128 |
+
|
| 129 |
+
struct cudnnRNNStruct;
|
| 130 |
+
typedef struct cudnnRNNStruct *cudnnRNNDescriptor_t;
|
| 131 |
+
|
| 132 |
+
struct cudnnRNNDataStruct;
|
| 133 |
+
typedef struct cudnnRNNDataStruct *cudnnRNNDataDescriptor_t;
|
| 134 |
+
|
| 135 |
+
cudnnStatus_t CUDNNWINAPI
|
| 136 |
+
cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc);
|
| 137 |
+
|
| 138 |
+
cudnnStatus_t CUDNNWINAPI
|
| 139 |
+
cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc);
|
| 140 |
+
|
| 141 |
+
/*
|
| 142 |
+
* mathPrec in cudnnSetRNNDescriptor_v8() specifies compute precision.
|
| 143 |
+
* Compute precision is further modified by mathType that sets the
|
| 144 |
+
* preferred option for using NVIDIA Tensor Cores. dataType specify
|
| 145 |
+
* input/output data type and weight/bias type.
|
| 146 |
+
*/
|
| 147 |
+
|
| 148 |
+
cudnnStatus_t CUDNNWINAPI
|
| 149 |
+
cudnnSetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc,
|
| 150 |
+
cudnnRNNAlgo_t algo,
|
| 151 |
+
cudnnRNNMode_t cellMode,
|
| 152 |
+
cudnnRNNBiasMode_t biasMode,
|
| 153 |
+
cudnnDirectionMode_t dirMode,
|
| 154 |
+
cudnnRNNInputMode_t inputMode,
|
| 155 |
+
cudnnDataType_t dataType,
|
| 156 |
+
cudnnDataType_t mathPrec,
|
| 157 |
+
cudnnMathType_t mathType,
|
| 158 |
+
int32_t inputSize,
|
| 159 |
+
int32_t hiddenSize,
|
| 160 |
+
int32_t projSize,
|
| 161 |
+
int32_t numLayers,
|
| 162 |
+
cudnnDropoutDescriptor_t dropoutDesc,
|
| 163 |
+
uint32_t auxFlags);
|
| 164 |
+
|
| 165 |
+
cudnnStatus_t CUDNNWINAPI
|
| 166 |
+
cudnnGetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc,
|
| 167 |
+
cudnnRNNAlgo_t *algo,
|
| 168 |
+
cudnnRNNMode_t *cellMode,
|
| 169 |
+
cudnnRNNBiasMode_t *biasMode,
|
| 170 |
+
cudnnDirectionMode_t *dirMode,
|
| 171 |
+
cudnnRNNInputMode_t *inputMode,
|
| 172 |
+
cudnnDataType_t *dataType,
|
| 173 |
+
cudnnDataType_t *mathPrec,
|
| 174 |
+
cudnnMathType_t *mathType,
|
| 175 |
+
int32_t *inputSize,
|
| 176 |
+
int32_t *hiddenSize,
|
| 177 |
+
int32_t *projSize,
|
| 178 |
+
int32_t *numLayers,
|
| 179 |
+
cudnnDropoutDescriptor_t *dropoutDesc,
|
| 180 |
+
uint32_t *auxFlags);
|
| 181 |
+
|
| 182 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 183 |
+
cudnnRNNSetClip_v8(cudnnRNNDescriptor_t rnnDesc,
|
| 184 |
+
cudnnRNNClipMode_t clipMode,
|
| 185 |
+
cudnnNanPropagation_t clipNanOpt,
|
| 186 |
+
double lclip,
|
| 187 |
+
double rclip);
|
| 188 |
+
|
| 189 |
+
cudnnStatus_t CUDNNWINAPI
|
| 190 |
+
cudnnRNNSetClip_v9(cudnnRNNDescriptor_t rnnDesc, cudnnRNNClipMode_t clipMode, double lclip, double rclip);
|
| 191 |
+
|
| 192 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 193 |
+
cudnnRNNGetClip_v8(cudnnRNNDescriptor_t rnnDesc,
|
| 194 |
+
cudnnRNNClipMode_t *clipMode,
|
| 195 |
+
cudnnNanPropagation_t *clipNanOpt,
|
| 196 |
+
double *lclip,
|
| 197 |
+
double *rclip);
|
| 198 |
+
|
| 199 |
+
cudnnStatus_t CUDNNWINAPI
|
| 200 |
+
cudnnRNNGetClip_v9(cudnnRNNDescriptor_t rnnDesc, cudnnRNNClipMode_t *clipMode, double *lclip, double *rclip);
|
| 201 |
+
|
| 202 |
+
cudnnStatus_t CUDNNWINAPI
|
| 203 |
+
cudnnBuildRNNDynamic(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, int miniBatch);
|
| 204 |
+
|
| 205 |
+
cudnnStatus_t CUDNNWINAPI
|
| 206 |
+
cudnnGetRNNTempSpaceSizes(cudnnHandle_t handle,
|
| 207 |
+
cudnnRNNDescriptor_t rnnDesc,
|
| 208 |
+
cudnnForwardMode_t fwdMode,
|
| 209 |
+
cudnnRNNDataDescriptor_t xDesc,
|
| 210 |
+
size_t *workSpaceSize,
|
| 211 |
+
size_t *reserveSpaceSize);
|
| 212 |
+
|
| 213 |
+
cudnnStatus_t CUDNNWINAPI
|
| 214 |
+
cudnnGetRNNWeightSpaceSize(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, size_t *weightSpaceSize);
|
| 215 |
+
|
| 216 |
+
cudnnStatus_t CUDNNWINAPI
|
| 217 |
+
cudnnGetRNNWeightParams(cudnnHandle_t handle,
|
| 218 |
+
cudnnRNNDescriptor_t rnnDesc,
|
| 219 |
+
int32_t pseudoLayer,
|
| 220 |
+
size_t weightSpaceSize,
|
| 221 |
+
const void *weightSpace,
|
| 222 |
+
int32_t linLayerID,
|
| 223 |
+
cudnnTensorDescriptor_t mDesc,
|
| 224 |
+
void **mAddr,
|
| 225 |
+
cudnnTensorDescriptor_t bDesc,
|
| 226 |
+
void **bAddr);
|
| 227 |
+
|
| 228 |
+
cudnnStatus_t CUDNNWINAPI
|
| 229 |
+
cudnnCreateRNNDataDescriptor(cudnnRNNDataDescriptor_t *rnnDataDesc);
|
| 230 |
+
|
| 231 |
+
cudnnStatus_t CUDNNWINAPI
|
| 232 |
+
cudnnDestroyRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc);
|
| 233 |
+
|
| 234 |
+
cudnnStatus_t CUDNNWINAPI
|
| 235 |
+
cudnnSetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
|
| 236 |
+
cudnnDataType_t dataType,
|
| 237 |
+
cudnnRNNDataLayout_t layout,
|
| 238 |
+
int maxSeqLength,
|
| 239 |
+
int batchSize,
|
| 240 |
+
int vectorSize,
|
| 241 |
+
const int seqLengthArray[], /* length of each sequence in the batch */
|
| 242 |
+
void *paddingFill); /* symbol for filling padding position in output */
|
| 243 |
+
|
| 244 |
+
cudnnStatus_t CUDNNWINAPI
|
| 245 |
+
cudnnGetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
|
| 246 |
+
cudnnDataType_t *dataType,
|
| 247 |
+
cudnnRNNDataLayout_t *layout,
|
| 248 |
+
int *maxSeqLength,
|
| 249 |
+
int *batchSize,
|
| 250 |
+
int *vectorSize,
|
| 251 |
+
int arrayLengthRequested,
|
| 252 |
+
int seqLengthArray[],
|
| 253 |
+
void *paddingFill);
|
| 254 |
+
|
| 255 |
+
cudnnStatus_t CUDNNWINAPI
|
| 256 |
+
cudnnRNNForward(cudnnHandle_t handle,
|
| 257 |
+
cudnnRNNDescriptor_t rnnDesc,
|
| 258 |
+
cudnnForwardMode_t fwdMode,
|
| 259 |
+
const int32_t devSeqLengths[],
|
| 260 |
+
cudnnRNNDataDescriptor_t xDesc,
|
| 261 |
+
const void *x,
|
| 262 |
+
cudnnRNNDataDescriptor_t yDesc,
|
| 263 |
+
void *y,
|
| 264 |
+
cudnnTensorDescriptor_t hDesc,
|
| 265 |
+
const void *hx,
|
| 266 |
+
void *hy,
|
| 267 |
+
cudnnTensorDescriptor_t cDesc,
|
| 268 |
+
const void *cx,
|
| 269 |
+
void *cy,
|
| 270 |
+
size_t weightSpaceSize,
|
| 271 |
+
const void *weightSpace,
|
| 272 |
+
size_t workSpaceSize,
|
| 273 |
+
void *workSpace,
|
| 274 |
+
size_t reserveSpaceSize,
|
| 275 |
+
void *reserveSpace);
|
| 276 |
+
|
| 277 |
+
/* Sequence data descriptor */
|
| 278 |
+
|
| 279 |
+
typedef enum {
|
| 280 |
+
CUDNN_SEQDATA_TIME_DIM = 0, /* index in time */
|
| 281 |
+
CUDNN_SEQDATA_BATCH_DIM = 1, /* index in batch */
|
| 282 |
+
CUDNN_SEQDATA_BEAM_DIM = 2, /* index in beam */
|
| 283 |
+
CUDNN_SEQDATA_VECT_DIM = 3 /* index in vector */
|
| 284 |
+
} cudnnSeqDataAxis_t;
|
| 285 |
+
|
| 286 |
+
struct cudnnSeqDataStruct;
|
| 287 |
+
typedef struct cudnnSeqDataStruct *cudnnSeqDataDescriptor_t CUDNN_DEPRECATED;
|
| 288 |
+
|
| 289 |
+
#define CUDNN_SEQDATA_DIM_COUNT 4 /* dimension count */
|
| 290 |
+
|
| 291 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 292 |
+
cudnnCreateSeqDataDescriptor(cudnnSeqDataDescriptor_t *seqDataDesc);
|
| 293 |
+
|
| 294 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 295 |
+
cudnnDestroySeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc);
|
| 296 |
+
|
| 297 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 298 |
+
cudnnSetSeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc,
|
| 299 |
+
cudnnDataType_t dataType,
|
| 300 |
+
int nbDims,
|
| 301 |
+
const int dimA[],
|
| 302 |
+
const cudnnSeqDataAxis_t axes[],
|
| 303 |
+
size_t seqLengthArraySize,
|
| 304 |
+
const int seqLengthArray[],
|
| 305 |
+
void *paddingFill);
|
| 306 |
+
|
| 307 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 308 |
+
cudnnGetSeqDataDescriptor(const cudnnSeqDataDescriptor_t seqDataDesc,
|
| 309 |
+
cudnnDataType_t *dataType,
|
| 310 |
+
int *nbDims,
|
| 311 |
+
int nbDimsRequested,
|
| 312 |
+
int dimA[],
|
| 313 |
+
cudnnSeqDataAxis_t axes[],
|
| 314 |
+
size_t *seqLengthArraySize,
|
| 315 |
+
size_t seqLengthSizeRequested,
|
| 316 |
+
int seqLengthArray[],
|
| 317 |
+
void *paddingFill);
|
| 318 |
+
|
| 319 |
+
/* Multihead Attention */
|
| 320 |
+
|
| 321 |
+
/*
|
| 322 |
+
* Multi-head attention options passed via 'attnMode' in cudnnSetAttnDescriptor().
|
| 323 |
+
* Use the bitwise OR operator to combine several settings listed below. Additional
|
| 324 |
+
* minor options can be added here w/o changing or introducing new API functions.
|
| 325 |
+
*/
|
| 326 |
+
#define CUDNN_ATTN_QUERYMAP_ALL_TO_ONE 0 /* multiple Q-s map to a single (K,V) set when beam size > 1 */
|
| 327 |
+
#define CUDNN_ATTN_QUERYMAP_ONE_TO_ONE (1U << 0) /* multiple Q-s map to multiple (K,V) sets when beam size > 1 */
|
| 328 |
+
#define CUDNN_ATTN_DISABLE_PROJ_BIASES 0 /* no biases in attention input and output projections */
|
| 329 |
+
#define CUDNN_ATTN_ENABLE_PROJ_BIASES (1U << 1) /* use biases in attention input and output projections */
|
| 330 |
+
|
| 331 |
+
struct cudnnAttnStruct;
|
| 332 |
+
typedef struct cudnnAttnStruct *cudnnAttnDescriptor_t CUDNN_DEPRECATED;
|
| 333 |
+
|
| 334 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 335 |
+
cudnnCreateAttnDescriptor(cudnnAttnDescriptor_t *attnDesc);
|
| 336 |
+
|
| 337 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 338 |
+
cudnnDestroyAttnDescriptor(cudnnAttnDescriptor_t attnDesc);
|
| 339 |
+
|
| 340 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 341 |
+
cudnnSetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
|
| 342 |
+
unsigned attnMode,
|
| 343 |
+
int nHeads,
|
| 344 |
+
double smScaler,
|
| 345 |
+
cudnnDataType_t dataType,
|
| 346 |
+
cudnnDataType_t computePrec,
|
| 347 |
+
cudnnMathType_t mathType,
|
| 348 |
+
cudnnDropoutDescriptor_t attnDropoutDesc,
|
| 349 |
+
cudnnDropoutDescriptor_t postDropoutDesc,
|
| 350 |
+
int qSize,
|
| 351 |
+
int kSize,
|
| 352 |
+
int vSize,
|
| 353 |
+
int qProjSize,
|
| 354 |
+
int kProjSize,
|
| 355 |
+
int vProjSize,
|
| 356 |
+
int oProjSize,
|
| 357 |
+
int qoMaxSeqLength,
|
| 358 |
+
int kvMaxSeqLength,
|
| 359 |
+
int maxBatchSize,
|
| 360 |
+
int maxBeamSize);
|
| 361 |
+
|
| 362 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 363 |
+
cudnnGetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
|
| 364 |
+
unsigned *attnMode,
|
| 365 |
+
int *nHeads,
|
| 366 |
+
double *smScaler,
|
| 367 |
+
cudnnDataType_t *dataType,
|
| 368 |
+
cudnnDataType_t *computePrec,
|
| 369 |
+
cudnnMathType_t *mathType,
|
| 370 |
+
cudnnDropoutDescriptor_t *attnDropoutDesc,
|
| 371 |
+
cudnnDropoutDescriptor_t *postDropoutDesc,
|
| 372 |
+
int *qSize,
|
| 373 |
+
int *kSize,
|
| 374 |
+
int *vSize,
|
| 375 |
+
int *qProjSize,
|
| 376 |
+
int *kProjSize,
|
| 377 |
+
int *vProjSize,
|
| 378 |
+
int *oProjSize,
|
| 379 |
+
int *qoMaxSeqLength,
|
| 380 |
+
int *kvMaxSeqLength,
|
| 381 |
+
int *maxBatchSize,
|
| 382 |
+
int *maxBeamSize);
|
| 383 |
+
|
| 384 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 385 |
+
cudnnGetMultiHeadAttnBuffers(cudnnHandle_t handle,
|
| 386 |
+
const cudnnAttnDescriptor_t attnDesc,
|
| 387 |
+
size_t *weightSizeInBytes,
|
| 388 |
+
size_t *workSpaceSizeInBytes,
|
| 389 |
+
size_t *reserveSpaceSizeInBytes);
|
| 390 |
+
|
| 391 |
+
typedef enum {
|
| 392 |
+
CUDNN_MH_ATTN_Q_WEIGHTS = 0, /* input projection weights for 'queries' */
|
| 393 |
+
CUDNN_MH_ATTN_K_WEIGHTS = 1, /* input projection weights for 'keys' */
|
| 394 |
+
CUDNN_MH_ATTN_V_WEIGHTS = 2, /* input projection weights for 'values' */
|
| 395 |
+
CUDNN_MH_ATTN_O_WEIGHTS = 3, /* output projection weights */
|
| 396 |
+
CUDNN_MH_ATTN_Q_BIASES = 4, /* input projection bias tensor for 'queries' */
|
| 397 |
+
CUDNN_MH_ATTN_K_BIASES = 5, /* input projection bias for 'keys' */
|
| 398 |
+
CUDNN_MH_ATTN_V_BIASES = 6, /* input projection bias for 'values' */
|
| 399 |
+
CUDNN_MH_ATTN_O_BIASES = 7, /* output projection biases */
|
| 400 |
+
} cudnnMultiHeadAttnWeightKind_t;
|
| 401 |
+
|
| 402 |
+
#define CUDNN_ATTN_WKIND_COUNT 8 /* Number of attention weight/bias tensors */
|
| 403 |
+
|
| 404 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 405 |
+
cudnnGetMultiHeadAttnWeights(cudnnHandle_t handle,
|
| 406 |
+
const cudnnAttnDescriptor_t attnDesc,
|
| 407 |
+
cudnnMultiHeadAttnWeightKind_t wKind,
|
| 408 |
+
size_t weightSizeInBytes,
|
| 409 |
+
const void *weights,
|
| 410 |
+
cudnnTensorDescriptor_t wDesc,
|
| 411 |
+
void **wAddr);
|
| 412 |
+
|
| 413 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 414 |
+
cudnnMultiHeadAttnForward(cudnnHandle_t handle,
|
| 415 |
+
const cudnnAttnDescriptor_t attnDesc,
|
| 416 |
+
int currIdx,
|
| 417 |
+
const int loWinIdx[],
|
| 418 |
+
const int hiWinIdx[],
|
| 419 |
+
const int devSeqLengthsQO[],
|
| 420 |
+
const int devSeqLengthsKV[],
|
| 421 |
+
const cudnnSeqDataDescriptor_t qDesc,
|
| 422 |
+
const void *queries,
|
| 423 |
+
const void *residuals,
|
| 424 |
+
const cudnnSeqDataDescriptor_t kDesc,
|
| 425 |
+
const void *keys,
|
| 426 |
+
const cudnnSeqDataDescriptor_t vDesc,
|
| 427 |
+
const void *values,
|
| 428 |
+
const cudnnSeqDataDescriptor_t oDesc,
|
| 429 |
+
void *out,
|
| 430 |
+
size_t weightSizeInBytes,
|
| 431 |
+
const void *weights,
|
| 432 |
+
size_t workSpaceSizeInBytes,
|
| 433 |
+
void *workSpace,
|
| 434 |
+
size_t reserveSpaceSizeInBytes,
|
| 435 |
+
void *reserveSpace);
|
| 436 |
+
|
| 437 |
+
/*
|
| 438 |
+
* \brief Cross-library version checker.
|
| 439 |
+
* This function is implemented differently in each sub-library. Each sublib
|
| 440 |
+
* checks whether its own version matches that of its dependencies.
|
| 441 |
+
* \returns CUDNN_STATUS_SUCCESS if the version check passes,
|
| 442 |
+
* CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH if the versions are inconsistent.
|
| 443 |
+
*/
|
| 444 |
+
cudnnStatus_t CUDNNWINAPI
|
| 445 |
+
cudnnAdvVersionCheck(void);
|
| 446 |
+
|
| 447 |
+
typedef enum {
|
| 448 |
+
CUDNN_WGRAD_MODE_ADD = 0, /* add partial gradients to wgrad output buffers */
|
| 449 |
+
CUDNN_WGRAD_MODE_SET = 1, /* write partial gradients to wgrad output buffers */
|
| 450 |
+
} cudnnWgradMode_t;
|
| 451 |
+
|
| 452 |
+
cudnnStatus_t CUDNNWINAPI
|
| 453 |
+
cudnnRNNBackwardData_v8(cudnnHandle_t handle,
|
| 454 |
+
cudnnRNNDescriptor_t rnnDesc,
|
| 455 |
+
const int32_t devSeqLengths[],
|
| 456 |
+
cudnnRNNDataDescriptor_t yDesc,
|
| 457 |
+
const void *y,
|
| 458 |
+
const void *dy,
|
| 459 |
+
cudnnRNNDataDescriptor_t xDesc,
|
| 460 |
+
void *dx,
|
| 461 |
+
cudnnTensorDescriptor_t hDesc,
|
| 462 |
+
const void *hx,
|
| 463 |
+
const void *dhy,
|
| 464 |
+
void *dhx,
|
| 465 |
+
cudnnTensorDescriptor_t cDesc,
|
| 466 |
+
const void *cx,
|
| 467 |
+
const void *dcy,
|
| 468 |
+
void *dcx,
|
| 469 |
+
size_t weightSpaceSize,
|
| 470 |
+
const void *weightSpace,
|
| 471 |
+
size_t workSpaceSize,
|
| 472 |
+
void *workSpace,
|
| 473 |
+
size_t reserveSpaceSize,
|
| 474 |
+
void *reserveSpace);
|
| 475 |
+
|
| 476 |
+
cudnnStatus_t CUDNNWINAPI
|
| 477 |
+
cudnnRNNBackwardWeights_v8(cudnnHandle_t handle,
|
| 478 |
+
cudnnRNNDescriptor_t rnnDesc,
|
| 479 |
+
cudnnWgradMode_t addGrad,
|
| 480 |
+
const int32_t devSeqLengths[],
|
| 481 |
+
cudnnRNNDataDescriptor_t xDesc,
|
| 482 |
+
const void *x,
|
| 483 |
+
cudnnTensorDescriptor_t hDesc,
|
| 484 |
+
const void *hx,
|
| 485 |
+
cudnnRNNDataDescriptor_t yDesc,
|
| 486 |
+
const void *y,
|
| 487 |
+
size_t weightSpaceSize,
|
| 488 |
+
void *dweightSpace,
|
| 489 |
+
size_t workSpaceSize,
|
| 490 |
+
void *workSpace,
|
| 491 |
+
size_t reserveSpaceSize,
|
| 492 |
+
void *reserveSpace);
|
| 493 |
+
|
| 494 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 495 |
+
cudnnMultiHeadAttnBackwardData(cudnnHandle_t handle,
|
| 496 |
+
const cudnnAttnDescriptor_t attnDesc,
|
| 497 |
+
const int loWinIdx[],
|
| 498 |
+
const int hiWinIdx[],
|
| 499 |
+
const int devSeqLengthsDQDO[],
|
| 500 |
+
const int devSeqLengthsDKDV[],
|
| 501 |
+
const cudnnSeqDataDescriptor_t doDesc,
|
| 502 |
+
const void *dout,
|
| 503 |
+
const cudnnSeqDataDescriptor_t dqDesc,
|
| 504 |
+
void *dqueries,
|
| 505 |
+
const void *queries,
|
| 506 |
+
const cudnnSeqDataDescriptor_t dkDesc,
|
| 507 |
+
void *dkeys,
|
| 508 |
+
const void *keys,
|
| 509 |
+
const cudnnSeqDataDescriptor_t dvDesc,
|
| 510 |
+
void *dvalues,
|
| 511 |
+
const void *values,
|
| 512 |
+
size_t weightSizeInBytes,
|
| 513 |
+
const void *weights,
|
| 514 |
+
size_t workSpaceSizeInBytes,
|
| 515 |
+
void *workSpace,
|
| 516 |
+
size_t reserveSpaceSizeInBytes,
|
| 517 |
+
void *reserveSpace);
|
| 518 |
+
|
| 519 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 520 |
+
cudnnMultiHeadAttnBackwardWeights(cudnnHandle_t handle,
|
| 521 |
+
const cudnnAttnDescriptor_t attnDesc,
|
| 522 |
+
cudnnWgradMode_t addGrad,
|
| 523 |
+
const cudnnSeqDataDescriptor_t qDesc,
|
| 524 |
+
const void *queries,
|
| 525 |
+
const cudnnSeqDataDescriptor_t kDesc,
|
| 526 |
+
const void *keys,
|
| 527 |
+
const cudnnSeqDataDescriptor_t vDesc,
|
| 528 |
+
const void *values,
|
| 529 |
+
const cudnnSeqDataDescriptor_t doDesc,
|
| 530 |
+
const void *dout,
|
| 531 |
+
size_t weightSizeInBytes,
|
| 532 |
+
const void *weights,
|
| 533 |
+
void *dweights,
|
| 534 |
+
size_t workSpaceSizeInBytes,
|
| 535 |
+
void *workSpace,
|
| 536 |
+
size_t reserveSpaceSizeInBytes,
|
| 537 |
+
void *reserveSpace);
|
| 538 |
+
|
| 539 |
+
/*
|
| 540 |
+
* CTC (Connectionist Temporal Classification) loss descriptor create/destory/set/get functions
|
| 541 |
+
*/
|
| 542 |
+
/* Input normalization mode for loss function */
|
| 543 |
+
typedef enum {
|
| 544 |
+
CUDNN_LOSS_NORMALIZATION_NONE = 0,
|
| 545 |
+
CUDNN_LOSS_NORMALIZATION_SOFTMAX = 1,
|
| 546 |
+
} cudnnLossNormalizationMode_t;
|
| 547 |
+
|
| 548 |
+
cudnnStatus_t CUDNNWINAPI
|
| 549 |
+
cudnnCreateCTCLossDescriptor(cudnnCTCLossDescriptor_t *ctcLossDesc);
|
| 550 |
+
|
| 551 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 552 |
+
cudnnSetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType);
|
| 553 |
+
|
| 554 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 555 |
+
cudnnSetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 556 |
+
cudnnDataType_t compType,
|
| 557 |
+
cudnnLossNormalizationMode_t normMode,
|
| 558 |
+
cudnnNanPropagation_t gradMode);
|
| 559 |
+
|
| 560 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 561 |
+
cudnnSetCTCLossDescriptor_v8(cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 562 |
+
cudnnDataType_t compType,
|
| 563 |
+
cudnnLossNormalizationMode_t normMode,
|
| 564 |
+
cudnnNanPropagation_t gradMode,
|
| 565 |
+
int maxLabelLength);
|
| 566 |
+
|
| 567 |
+
cudnnStatus_t CUDNNWINAPI
|
| 568 |
+
cudnnSetCTCLossDescriptor_v9(cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 569 |
+
cudnnDataType_t compType,
|
| 570 |
+
cudnnLossNormalizationMode_t normMode,
|
| 571 |
+
cudnnCTCGradMode_t ctcGradMode,
|
| 572 |
+
int maxLabelLength);
|
| 573 |
+
|
| 574 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 575 |
+
cudnnGetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType);
|
| 576 |
+
|
| 577 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 578 |
+
cudnnGetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 579 |
+
cudnnDataType_t *compType,
|
| 580 |
+
cudnnLossNormalizationMode_t *normMode,
|
| 581 |
+
cudnnNanPropagation_t *gradMode);
|
| 582 |
+
|
| 583 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 584 |
+
cudnnGetCTCLossDescriptor_v8(cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 585 |
+
cudnnDataType_t *compType,
|
| 586 |
+
cudnnLossNormalizationMode_t *normMode,
|
| 587 |
+
cudnnNanPropagation_t *gradMode,
|
| 588 |
+
int *maxLabelLength);
|
| 589 |
+
|
| 590 |
+
cudnnStatus_t CUDNNWINAPI
|
| 591 |
+
cudnnGetCTCLossDescriptor_v9(cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 592 |
+
cudnnDataType_t *compType,
|
| 593 |
+
cudnnLossNormalizationMode_t *normMode,
|
| 594 |
+
cudnnCTCGradMode_t *ctcGradMode,
|
| 595 |
+
int *maxLabelLength);
|
| 596 |
+
|
| 597 |
+
cudnnStatus_t CUDNNWINAPI
|
| 598 |
+
cudnnDestroyCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc);
|
| 599 |
+
|
| 600 |
+
/* return the ctc costs and gradients, given the probabilities and labels */
|
| 601 |
+
cudnnStatus_t CUDNNWINAPI
|
| 602 |
+
cudnnCTCLoss(
|
| 603 |
+
cudnnHandle_t handle,
|
| 604 |
+
const cudnnTensorDescriptor_t
|
| 605 |
+
probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the
|
| 606 |
+
mini batch size, A is the alphabet size) */
|
| 607 |
+
const void *probs, /* probabilities after softmax, in GPU memory */
|
| 608 |
+
const int hostLabels[], /* labels, in CPU memory */
|
| 609 |
+
const int hostLabelLengths[], /* the length of each label, in CPU memory */
|
| 610 |
+
const int hostInputLengths[], /* the lengths of timing steps in each batch, in CPU memory */
|
| 611 |
+
void *costs, /* the returned costs of CTC, in GPU memory */
|
| 612 |
+
const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */
|
| 613 |
+
void *gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */
|
| 614 |
+
cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
|
| 615 |
+
cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 616 |
+
void *workspace, /* pointer to the workspace, in GPU memory */
|
| 617 |
+
size_t workSpaceSizeInBytes); /* size of the workspace */
|
| 618 |
+
|
| 619 |
+
/* return the ctc costs and gradients, given the probabilities and labels */
|
| 620 |
+
cudnnStatus_t CUDNNWINAPI
|
| 621 |
+
cudnnCTCLoss_v8(
|
| 622 |
+
cudnnHandle_t handle,
|
| 623 |
+
cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
|
| 624 |
+
cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 625 |
+
const cudnnTensorDescriptor_t
|
| 626 |
+
probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the
|
| 627 |
+
mini batch size, A is the alphabet size) */
|
| 628 |
+
const void *probs, /* probabilities after softmax, in GPU memory */
|
| 629 |
+
const int labels[], /* labels, in GPU memory */
|
| 630 |
+
const int labelLengths[], /* the length of each label, in GPU memory */
|
| 631 |
+
const int inputLengths[], /* the lengths of timing steps in each batch, in GPU memory */
|
| 632 |
+
void *costs, /* the returned costs of CTC, in GPU memory */
|
| 633 |
+
const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */
|
| 634 |
+
void *gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */
|
| 635 |
+
size_t workSpaceSizeInBytes, /* size of the workspace */
|
| 636 |
+
void *workspace); /* pointer to the workspace, in GPU memory */
|
| 637 |
+
|
| 638 |
+
/* return the workspace size needed for ctc */
|
| 639 |
+
cudnnStatus_t CUDNNWINAPI
|
| 640 |
+
cudnnGetCTCLossWorkspaceSize(
|
| 641 |
+
cudnnHandle_t handle,
|
| 642 |
+
const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the
|
| 643 |
+
timing steps, N is the mini batch size, A is the alphabet size) */
|
| 644 |
+
const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the
|
| 645 |
+
dimensions are T,N,A. To compute costs
|
| 646 |
+
only, set it to NULL */
|
| 647 |
+
const int *labels, /* labels, in CPU memory */
|
| 648 |
+
const int *labelLengths, /* the length of each label, in CPU memory */
|
| 649 |
+
const int *inputLengths, /* the lengths of timing steps in each batch, in CPU memory */
|
| 650 |
+
cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
|
| 651 |
+
cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 652 |
+
size_t *sizeInBytes); /* pointer to the returned workspace size */
|
| 653 |
+
|
| 654 |
+
/* return the workspace size needed for ctc */
|
| 655 |
+
cudnnStatus_t CUDNNWINAPI
|
| 656 |
+
cudnnGetCTCLossWorkspaceSize_v8(
|
| 657 |
+
cudnnHandle_t handle,
|
| 658 |
+
cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
|
| 659 |
+
cudnnCTCLossDescriptor_t ctcLossDesc,
|
| 660 |
+
const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the
|
| 661 |
+
timing steps, N is the mini batch size, A is the alphabet size) */
|
| 662 |
+
const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the
|
| 663 |
+
dimensions are T,N,A. To compute costs
|
| 664 |
+
only, set it to NULL */
|
| 665 |
+
size_t *sizeInBytes); /* pointer to the returned workspace size */
|
| 666 |
+
|
| 667 |
+
#if defined(__cplusplus)
|
| 668 |
+
}
|
| 669 |
+
#endif
|
| 670 |
+
|
| 671 |
+
#endif /* CUDNN_ADV_H_ */
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_backend.h
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
#ifndef _CUDNN_BACKEND_H_
|
| 51 |
+
#define _CUDNN_BACKEND_H_
|
| 52 |
+
|
| 53 |
+
/*
|
| 54 |
+
* The content of this header has been moved into cudnn_graph.h.
|
| 55 |
+
* This header is kept for the backward compatibility purpose.
|
| 56 |
+
*/
|
| 57 |
+
|
| 58 |
+
#include "cudnn_graph.h"
|
| 59 |
+
|
| 60 |
+
#endif /* _CUDNN_BACKEND_H_ */
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn.h
ADDED
|
@@ -0,0 +1,693 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/*
|
| 51 |
+
* cudnn_cnn : cuDNN's basic definitions and CNN functions.
|
| 52 |
+
*/
|
| 53 |
+
|
| 54 |
+
#if !defined(CUDNN_CNN_H_)
|
| 55 |
+
#define CUDNN_CNN_H_
|
| 56 |
+
|
| 57 |
+
#pragma once
|
| 58 |
+
#include <stdint.h>
|
| 59 |
+
|
| 60 |
+
#include "cudnn_version.h"
|
| 61 |
+
#include "cudnn_ops.h"
|
| 62 |
+
|
| 63 |
+
/* These version numbers are autogenerated, do not edit manually. */
|
| 64 |
+
#define CUDNN_CNN_MAJOR 9
|
| 65 |
+
#define CUDNN_CNN_MINOR 1
|
| 66 |
+
#define CUDNN_CNN_PATCH 0
|
| 67 |
+
|
| 68 |
+
#if (CUDNN_CNN_MAJOR != CUDNN_MAJOR) || (CUDNN_CNN_MINOR != CUDNN_MINOR) || (CUDNN_CNN_PATCH != CUDNN_PATCHLEVEL)
|
| 69 |
+
#error Version mismatch in cuDNN CNN INFER!!!
|
| 70 |
+
#endif
|
| 71 |
+
|
| 72 |
+
#if defined(__cplusplus)
|
| 73 |
+
extern "C" {
|
| 74 |
+
#endif
|
| 75 |
+
|
| 76 |
+
typedef struct cudnnConvolutionStruct *cudnnConvolutionDescriptor_t CUDNN_DEPRECATED;
|
| 77 |
+
|
| 78 |
+
typedef struct cudnnConvolutionFwdAlgoPerfStruct {
|
| 79 |
+
cudnnConvolutionFwdAlgo_t algo;
|
| 80 |
+
cudnnStatus_t status;
|
| 81 |
+
float time;
|
| 82 |
+
size_t memory;
|
| 83 |
+
cudnnDeterminism_t determinism;
|
| 84 |
+
cudnnMathType_t mathType;
|
| 85 |
+
int reserved[3];
|
| 86 |
+
} cudnnConvolutionFwdAlgoPerf_t CUDNN_DEPRECATED;
|
| 87 |
+
|
| 88 |
+
/* Create an instance of convolution descriptor */
|
| 89 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 90 |
+
cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t *convDesc);
|
| 91 |
+
|
| 92 |
+
/* Destroy an instance of convolution descriptor */
|
| 93 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 94 |
+
cudnnDestroyConvolutionDescriptor(cudnnConvolutionDescriptor_t convDesc);
|
| 95 |
+
|
| 96 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 97 |
+
cudnnSetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t mathType);
|
| 98 |
+
|
| 99 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 100 |
+
cudnnGetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t *mathType);
|
| 101 |
+
|
| 102 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 103 |
+
cudnnSetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int groupCount);
|
| 104 |
+
|
| 105 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 106 |
+
cudnnGetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int *groupCount);
|
| 107 |
+
|
| 108 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 109 |
+
cudnnSetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t reorderType);
|
| 110 |
+
|
| 111 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 112 |
+
cudnnGetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t *reorderType);
|
| 113 |
+
|
| 114 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 115 |
+
cudnnSetConvolution2dDescriptor(cudnnConvolutionDescriptor_t convDesc,
|
| 116 |
+
int pad_h, /* zero-padding height */
|
| 117 |
+
int pad_w, /* zero-padding width */
|
| 118 |
+
int u, /* vertical filter stride */
|
| 119 |
+
int v, /* horizontal filter stride */
|
| 120 |
+
int dilation_h, /* filter dilation in the vertical dimension */
|
| 121 |
+
int dilation_w, /* filter dilation in the horizontal dimension */
|
| 122 |
+
cudnnConvolutionMode_t mode,
|
| 123 |
+
cudnnDataType_t computeType);
|
| 124 |
+
|
| 125 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 126 |
+
cudnnGetConvolution2dDescriptor(const cudnnConvolutionDescriptor_t convDesc,
|
| 127 |
+
int *pad_h, /* zero-padding height */
|
| 128 |
+
int *pad_w, /* zero-padding width */
|
| 129 |
+
int *u, /* vertical filter stride */
|
| 130 |
+
int *v, /* horizontal filter stride */
|
| 131 |
+
int *dilation_h, /* filter dilation in the vertical dimension */
|
| 132 |
+
int *dilation_w, /* filter dilation in the horizontal dimension */
|
| 133 |
+
cudnnConvolutionMode_t *mode,
|
| 134 |
+
cudnnDataType_t *computeType);
|
| 135 |
+
|
| 136 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 137 |
+
cudnnSetConvolutionNdDescriptor(cudnnConvolutionDescriptor_t convDesc,
|
| 138 |
+
int arrayLength, /* nbDims-2 size */
|
| 139 |
+
const int padA[],
|
| 140 |
+
const int filterStrideA[],
|
| 141 |
+
const int dilationA[],
|
| 142 |
+
cudnnConvolutionMode_t mode,
|
| 143 |
+
cudnnDataType_t computeType); /* convolution data type */
|
| 144 |
+
|
| 145 |
+
/* Helper function to return the dimensions of the output tensor given a convolution descriptor */
|
| 146 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 147 |
+
cudnnGetConvolutionNdDescriptor(const cudnnConvolutionDescriptor_t convDesc,
|
| 148 |
+
int arrayLengthRequested,
|
| 149 |
+
int *arrayLength,
|
| 150 |
+
int padA[],
|
| 151 |
+
int strideA[],
|
| 152 |
+
int dilationA[],
|
| 153 |
+
cudnnConvolutionMode_t *mode,
|
| 154 |
+
cudnnDataType_t *computeType); /* convolution data type */
|
| 155 |
+
|
| 156 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 157 |
+
cudnnGetConvolution2dForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc,
|
| 158 |
+
const cudnnTensorDescriptor_t inputTensorDesc,
|
| 159 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 160 |
+
int *n,
|
| 161 |
+
int *c,
|
| 162 |
+
int *h,
|
| 163 |
+
int *w);
|
| 164 |
+
|
| 165 |
+
/* Helper function to return the dimensions of the output tensor given a convolution descriptor */
|
| 166 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 167 |
+
cudnnGetConvolutionNdForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc,
|
| 168 |
+
const cudnnTensorDescriptor_t inputTensorDesc,
|
| 169 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 170 |
+
int nbDims,
|
| 171 |
+
int tensorOuputDimA[]);
|
| 172 |
+
|
| 173 |
+
/* helper function to provide the convolution forward algo that fit best the requirement */
|
| 174 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 175 |
+
cudnnGetConvolutionForwardAlgorithmMaxCount(cudnnHandle_t handle, int *count);
|
| 176 |
+
|
| 177 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 178 |
+
cudnnGetConvolutionForwardAlgorithm_v7(cudnnHandle_t handle,
|
| 179 |
+
const cudnnTensorDescriptor_t srcDesc,
|
| 180 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 181 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 182 |
+
const cudnnTensorDescriptor_t destDesc,
|
| 183 |
+
const int requestedAlgoCount,
|
| 184 |
+
int *returnedAlgoCount,
|
| 185 |
+
cudnnConvolutionFwdAlgoPerf_t *perfResults);
|
| 186 |
+
|
| 187 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 188 |
+
cudnnFindConvolutionForwardAlgorithm(cudnnHandle_t handle,
|
| 189 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 190 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 191 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 192 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 193 |
+
const int requestedAlgoCount,
|
| 194 |
+
int *returnedAlgoCount,
|
| 195 |
+
cudnnConvolutionFwdAlgoPerf_t *perfResults);
|
| 196 |
+
|
| 197 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 198 |
+
cudnnFindConvolutionForwardAlgorithmEx(cudnnHandle_t handle,
|
| 199 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 200 |
+
const void *x,
|
| 201 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 202 |
+
const void *w,
|
| 203 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 204 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 205 |
+
void *y,
|
| 206 |
+
const int requestedAlgoCount,
|
| 207 |
+
int *returnedAlgoCount,
|
| 208 |
+
cudnnConvolutionFwdAlgoPerf_t *perfResults,
|
| 209 |
+
void *workSpace,
|
| 210 |
+
size_t workSpaceSizeInBytes);
|
| 211 |
+
|
| 212 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 213 |
+
cudnnIm2Col(cudnnHandle_t handle,
|
| 214 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 215 |
+
const void *x,
|
| 216 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 217 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 218 |
+
void *colBuffer);
|
| 219 |
+
|
| 220 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 221 |
+
cudnnReorderFilterAndBias(cudnnHandle_t handle,
|
| 222 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 223 |
+
cudnnReorderType_t reorderType,
|
| 224 |
+
const void *filterData,
|
| 225 |
+
void *reorderedFilterData,
|
| 226 |
+
int reorderBias,
|
| 227 |
+
const void *biasData,
|
| 228 |
+
void *reorderedBiasData);
|
| 229 |
+
|
| 230 |
+
/* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
|
| 231 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 232 |
+
cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle_t handle,
|
| 233 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 234 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 235 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 236 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 237 |
+
cudnnConvolutionFwdAlgo_t algo,
|
| 238 |
+
size_t *sizeInBytes);
|
| 239 |
+
|
| 240 |
+
/* Convolution functions: All of the form "output = alpha * Op(inputs) + beta * output" */
|
| 241 |
+
|
| 242 |
+
/* Function to perform the forward pass for batch convolution */
|
| 243 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 244 |
+
cudnnConvolutionForward(cudnnHandle_t handle,
|
| 245 |
+
const void *alpha,
|
| 246 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 247 |
+
const void *x,
|
| 248 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 249 |
+
const void *w,
|
| 250 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 251 |
+
cudnnConvolutionFwdAlgo_t algo,
|
| 252 |
+
void *workSpace,
|
| 253 |
+
size_t workSpaceSizeInBytes,
|
| 254 |
+
const void *beta,
|
| 255 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 256 |
+
void *y);
|
| 257 |
+
|
| 258 |
+
/* Fused conv/bias/activation operation : y = Act( alpha1 * conv(x) + alpha2 * z + bias ) */
|
| 259 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 260 |
+
cudnnConvolutionBiasActivationForward(cudnnHandle_t handle,
|
| 261 |
+
const void *alpha1,
|
| 262 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 263 |
+
const void *x,
|
| 264 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 265 |
+
const void *w,
|
| 266 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 267 |
+
cudnnConvolutionFwdAlgo_t algo,
|
| 268 |
+
void *workSpace,
|
| 269 |
+
size_t workSpaceSizeInBytes,
|
| 270 |
+
const void *alpha2,
|
| 271 |
+
const cudnnTensorDescriptor_t zDesc,
|
| 272 |
+
const void *z,
|
| 273 |
+
const cudnnTensorDescriptor_t biasDesc,
|
| 274 |
+
const void *bias,
|
| 275 |
+
const cudnnActivationDescriptor_t activationDesc,
|
| 276 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 277 |
+
void *y);
|
| 278 |
+
|
| 279 |
+
/* helper function to provide the convolution backward data algo that fit best the requirement */
|
| 280 |
+
|
| 281 |
+
typedef struct cudnnConvolutionBwdDataAlgoPerfStruct {
|
| 282 |
+
cudnnConvolutionBwdDataAlgo_t algo;
|
| 283 |
+
cudnnStatus_t status;
|
| 284 |
+
float time;
|
| 285 |
+
size_t memory;
|
| 286 |
+
cudnnDeterminism_t determinism;
|
| 287 |
+
cudnnMathType_t mathType;
|
| 288 |
+
int reserved[3];
|
| 289 |
+
} cudnnConvolutionBwdDataAlgoPerf_t CUDNN_DEPRECATED;
|
| 290 |
+
|
| 291 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 292 |
+
cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnnHandle_t handle, int *count);
|
| 293 |
+
|
| 294 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 295 |
+
cudnnFindConvolutionBackwardDataAlgorithm(cudnnHandle_t handle,
|
| 296 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 297 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 298 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 299 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 300 |
+
const int requestedAlgoCount,
|
| 301 |
+
int *returnedAlgoCount,
|
| 302 |
+
cudnnConvolutionBwdDataAlgoPerf_t *perfResults);
|
| 303 |
+
|
| 304 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 305 |
+
cudnnFindConvolutionBackwardDataAlgorithmEx(cudnnHandle_t handle,
|
| 306 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 307 |
+
const void *w,
|
| 308 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 309 |
+
const void *dy,
|
| 310 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 311 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 312 |
+
void *dx,
|
| 313 |
+
const int requestedAlgoCount,
|
| 314 |
+
int *returnedAlgoCount,
|
| 315 |
+
cudnnConvolutionBwdDataAlgoPerf_t *perfResults,
|
| 316 |
+
void *workSpace,
|
| 317 |
+
size_t workSpaceSizeInBytes);
|
| 318 |
+
|
| 319 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 320 |
+
cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnnHandle_t handle,
|
| 321 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 322 |
+
const cudnnTensorDescriptor_t diffDesc,
|
| 323 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 324 |
+
const cudnnTensorDescriptor_t gradDesc,
|
| 325 |
+
const int requestedAlgoCount,
|
| 326 |
+
int *returnedAlgoCount,
|
| 327 |
+
cudnnConvolutionBwdDataAlgoPerf_t *perfResults);
|
| 328 |
+
|
| 329 |
+
/*
|
| 330 |
+
* convolution algorithm (which requires potentially some workspace)
|
| 331 |
+
*/
|
| 332 |
+
|
| 333 |
+
/* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
|
| 334 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 335 |
+
cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnHandle_t handle,
|
| 336 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 337 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 338 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 339 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 340 |
+
cudnnConvolutionBwdDataAlgo_t algo,
|
| 341 |
+
size_t *sizeInBytes);
|
| 342 |
+
|
| 343 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 344 |
+
cudnnConvolutionBackwardData(cudnnHandle_t handle,
|
| 345 |
+
const void *alpha,
|
| 346 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 347 |
+
const void *w,
|
| 348 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 349 |
+
const void *dy,
|
| 350 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 351 |
+
cudnnConvolutionBwdDataAlgo_t algo,
|
| 352 |
+
void *workSpace,
|
| 353 |
+
size_t workSpaceSizeInBytes,
|
| 354 |
+
const void *beta,
|
| 355 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 356 |
+
void *dx);
|
| 357 |
+
|
| 358 |
+
/* Helper function to calculate folding descriptors for dgrad */
|
| 359 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 360 |
+
cudnnGetFoldedConvBackwardDataDescriptors(const cudnnHandle_t handle,
|
| 361 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 362 |
+
const cudnnTensorDescriptor_t diffDesc,
|
| 363 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 364 |
+
const cudnnTensorDescriptor_t gradDesc,
|
| 365 |
+
const cudnnTensorFormat_t transformFormat,
|
| 366 |
+
cudnnFilterDescriptor_t foldedFilterDesc,
|
| 367 |
+
cudnnTensorDescriptor_t paddedDiffDesc,
|
| 368 |
+
cudnnConvolutionDescriptor_t foldedConvDesc,
|
| 369 |
+
cudnnTensorDescriptor_t foldedGradDesc,
|
| 370 |
+
cudnnTensorTransformDescriptor_t filterFoldTransDesc,
|
| 371 |
+
cudnnTensorTransformDescriptor_t diffPadTransDesc,
|
| 372 |
+
cudnnTensorTransformDescriptor_t gradFoldTransDesc,
|
| 373 |
+
cudnnTensorTransformDescriptor_t gradUnfoldTransDesc);
|
| 374 |
+
|
| 375 |
+
/* cudnnFusedOps... */
|
| 376 |
+
struct cudnnFusedOpsConstParamStruct;
|
| 377 |
+
typedef struct cudnnFusedOpsConstParamStruct *cudnnFusedOpsConstParamPack_t CUDNN_DEPRECATED;
|
| 378 |
+
|
| 379 |
+
struct cudnnFusedOpsVariantParamStruct;
|
| 380 |
+
typedef struct cudnnFusedOpsVariantParamStruct *cudnnFusedOpsVariantParamPack_t CUDNN_DEPRECATED;
|
| 381 |
+
|
| 382 |
+
struct cudnnFusedOpsPlanStruct;
|
| 383 |
+
typedef struct cudnnFusedOpsPlanStruct *cudnnFusedOpsPlan_t CUDNN_DEPRECATED;
|
| 384 |
+
|
| 385 |
+
typedef enum {
|
| 386 |
+
/* each op in [ ] can be disabled by passing NULL ptr */
|
| 387 |
+
/* [per channel scale], [per channel bias], [activation], convolution, [generate BN stats] */
|
| 388 |
+
CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS = 0,
|
| 389 |
+
/* [per channel scale], [per channel bias], [activation], convolutionBackwardWeights */
|
| 390 |
+
CUDNN_FUSED_SCALE_BIAS_ACTIVATION_WGRAD = 1,
|
| 391 |
+
/* utility for BN training in BN-conv fusion */
|
| 392 |
+
/* computes the equivalent scale and bias from ySum ySqSum and learned scale, bias */
|
| 393 |
+
/* optionally update running stats and generate saved stats */
|
| 394 |
+
CUDNN_FUSED_BN_FINALIZE_STATISTICS_TRAINING = 2,
|
| 395 |
+
/* utility for BN inference in BN-conv fusion */
|
| 396 |
+
/* computes the equivalent scale and bias from learned running stats and learned scale, bias */
|
| 397 |
+
CUDNN_FUSED_BN_FINALIZE_STATISTICS_INFERENCE = 3,
|
| 398 |
+
/* reserved for future use: convolution, [per channel scale], [per channel bias], [residual add], [activation] */
|
| 399 |
+
CUDNN_FUSED_CONV_SCALE_BIAS_ADD_ACTIVATION = 4,
|
| 400 |
+
/* reserved for future use: [per channel scale], [per channel bias], [residual add], activation, bitmask */
|
| 401 |
+
CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK = 5,
|
| 402 |
+
/* reserved for future use */
|
| 403 |
+
CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM = 6,
|
| 404 |
+
} cudnnFusedOps_t CUDNN_DEPRECATED;
|
| 405 |
+
|
| 406 |
+
typedef enum {
|
| 407 |
+
/* set XDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 408 |
+
/* get XDESC: pass previously created cudnnTensorDescriptor_t */
|
| 409 |
+
CUDNN_PARAM_XDESC = 0,
|
| 410 |
+
/* set/get XDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 411 |
+
CUDNN_PARAM_XDATA_PLACEHOLDER = 1,
|
| 412 |
+
/* set/get BN_MODE: pass cudnnBatchNormMode_t* */
|
| 413 |
+
CUDNN_PARAM_BN_MODE = 2,
|
| 414 |
+
/* set CUDNN_PARAM_BN_EQSCALEBIAS_DESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 415 |
+
/* get CUDNN_PARAM_BN_EQSCALEBIAS_DESC: pass previously created cudnnTensorDescriptor_t */
|
| 416 |
+
CUDNN_PARAM_BN_EQSCALEBIAS_DESC = 3,
|
| 417 |
+
/* set/get BN_EQSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 418 |
+
CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER = 4,
|
| 419 |
+
/* set/get BN_EQBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 420 |
+
CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER = 5,
|
| 421 |
+
/* set ACTIVATION_DESC: pass previously initialized cudnnActivationDescriptor_t */
|
| 422 |
+
/* get ACTIVATION_DESC: pass previously created cudnnActivationDescriptor_t */
|
| 423 |
+
CUDNN_PARAM_ACTIVATION_DESC = 6,
|
| 424 |
+
/* set CONV_DESC: pass previously initialized cudnnConvolutionDescriptor_t */
|
| 425 |
+
/* get CONV_DESC: pass previously created cudnnConvolutionDescriptor_t */
|
| 426 |
+
CUDNN_PARAM_CONV_DESC = 7,
|
| 427 |
+
/* set WDESC: pass previously initialized cudnnFilterDescriptor_t */
|
| 428 |
+
/* get WDESC: pass previously created cudnnFilterDescriptor_t */
|
| 429 |
+
CUDNN_PARAM_WDESC = 8,
|
| 430 |
+
/* set/get WDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 431 |
+
CUDNN_PARAM_WDATA_PLACEHOLDER = 9,
|
| 432 |
+
/* set DWDESC: pass previously initialized cudnnFilterDescriptor_t */
|
| 433 |
+
/* get DWDESC: pass previously created cudnnFilterDescriptor_t */
|
| 434 |
+
CUDNN_PARAM_DWDESC = 10,
|
| 435 |
+
/* set/get DWDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 436 |
+
CUDNN_PARAM_DWDATA_PLACEHOLDER = 11,
|
| 437 |
+
/* set YDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 438 |
+
/* get YDESC: pass previously created cudnnTensorDescriptor_t */
|
| 439 |
+
CUDNN_PARAM_YDESC = 12,
|
| 440 |
+
/* set/get YDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 441 |
+
CUDNN_PARAM_YDATA_PLACEHOLDER = 13,
|
| 442 |
+
/* set DYDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 443 |
+
/* get DYDESC: pass previously created cudnnTensorDescriptor_t */
|
| 444 |
+
CUDNN_PARAM_DYDESC = 14,
|
| 445 |
+
/* set/get DYDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 446 |
+
CUDNN_PARAM_DYDATA_PLACEHOLDER = 15,
|
| 447 |
+
/* set YSTATS_DESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 448 |
+
/* get YSTATS_DESC: pass previously created cudnnTensorDescriptor_t */
|
| 449 |
+
CUDNN_PARAM_YSTATS_DESC = 16,
|
| 450 |
+
/* set/get YSUM_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 451 |
+
CUDNN_PARAM_YSUM_PLACEHOLDER = 17,
|
| 452 |
+
/* set/get YSQSUM_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 453 |
+
CUDNN_PARAM_YSQSUM_PLACEHOLDER = 18,
|
| 454 |
+
/* set CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 455 |
+
/* get CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC: pass previously created cudnnTensorDescriptor_t */
|
| 456 |
+
CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC = 19,
|
| 457 |
+
/* set/get CUDNN_PARAM_BN_SCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 458 |
+
CUDNN_PARAM_BN_SCALE_PLACEHOLDER = 20,
|
| 459 |
+
/* set/get CUDNN_PARAM_BN_BIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 460 |
+
CUDNN_PARAM_BN_BIAS_PLACEHOLDER = 21,
|
| 461 |
+
/* set/get CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 462 |
+
CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER = 22,
|
| 463 |
+
/* set/get CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 464 |
+
CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER = 23,
|
| 465 |
+
/* set/get CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 466 |
+
CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER = 24,
|
| 467 |
+
/* set/get CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 468 |
+
CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER = 25,
|
| 469 |
+
|
| 470 |
+
/* set ZDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 471 |
+
/* get ZDESC: pass previously created cudnnTensorDescriptor_t */
|
| 472 |
+
CUDNN_PARAM_ZDESC = 26,
|
| 473 |
+
/* set/get ZDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 474 |
+
CUDNN_PARAM_ZDATA_PLACEHOLDER = 27,
|
| 475 |
+
/* set BN_Z_EQSCALEBIAS_DESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 476 |
+
/* get BN_Z_EQSCALEBIAS_DESC: pass previously created cudnnTensorDescriptor_t */
|
| 477 |
+
CUDNN_PARAM_BN_Z_EQSCALEBIAS_DESC = 28,
|
| 478 |
+
/* set/get BN_Z_EQSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 479 |
+
CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER = 29,
|
| 480 |
+
/* set/get BN_Z_EQBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 481 |
+
CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER = 30,
|
| 482 |
+
|
| 483 |
+
/* set ACTIVATION_BITMASK_DESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 484 |
+
/* get ACTIVATION_BITMASK_DESC: pass previously created cudnnTensorDescriptor_t */
|
| 485 |
+
CUDNN_PARAM_ACTIVATION_BITMASK_DESC = 31,
|
| 486 |
+
/* set/get ACTIVATION_BITMASK_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 487 |
+
CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER = 32,
|
| 488 |
+
|
| 489 |
+
/* set DXDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 490 |
+
/* get DXDESC: pass previously created cudnnTensorDescriptor_t */
|
| 491 |
+
CUDNN_PARAM_DXDESC = 33,
|
| 492 |
+
/* set/get DXDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 493 |
+
CUDNN_PARAM_DXDATA_PLACEHOLDER = 34,
|
| 494 |
+
/* set DZDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 495 |
+
/* get DZDESC: pass previously created cudnnTensorDescriptor_t */
|
| 496 |
+
CUDNN_PARAM_DZDESC = 35,
|
| 497 |
+
/* set/get DZDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 498 |
+
CUDNN_PARAM_DZDATA_PLACEHOLDER = 36,
|
| 499 |
+
/* set/get CUDNN_PARAM_BN_DSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 500 |
+
CUDNN_PARAM_BN_DSCALE_PLACEHOLDER = 37,
|
| 501 |
+
/* set/get CUDNN_PARAM_BN_DBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 502 |
+
CUDNN_PARAM_BN_DBIAS_PLACEHOLDER = 38,
|
| 503 |
+
} cudnnFusedOpsConstParamLabel_t CUDNN_DEPRECATED;
|
| 504 |
+
|
| 505 |
+
typedef enum {
|
| 506 |
+
CUDNN_PTR_NULL = 0,
|
| 507 |
+
CUDNN_PTR_ELEM_ALIGNED = 1,
|
| 508 |
+
CUDNN_PTR_16B_ALIGNED = 2,
|
| 509 |
+
} cudnnFusedOpsPointerPlaceHolder_t CUDNN_DEPRECATED;
|
| 510 |
+
|
| 511 |
+
typedef enum {
|
| 512 |
+
/* set: pass void* pointing to dev memory */
|
| 513 |
+
/* get: pass void** pointing to host memory */
|
| 514 |
+
CUDNN_PTR_XDATA = 0,
|
| 515 |
+
CUDNN_PTR_BN_EQSCALE = 1,
|
| 516 |
+
CUDNN_PTR_BN_EQBIAS = 2,
|
| 517 |
+
CUDNN_PTR_WDATA = 3,
|
| 518 |
+
CUDNN_PTR_DWDATA = 4,
|
| 519 |
+
CUDNN_PTR_YDATA = 5,
|
| 520 |
+
CUDNN_PTR_DYDATA = 6,
|
| 521 |
+
CUDNN_PTR_YSUM = 7,
|
| 522 |
+
CUDNN_PTR_YSQSUM = 8,
|
| 523 |
+
CUDNN_PTR_WORKSPACE = 9,
|
| 524 |
+
CUDNN_PTR_BN_SCALE = 10,
|
| 525 |
+
CUDNN_PTR_BN_BIAS = 11,
|
| 526 |
+
CUDNN_PTR_BN_SAVED_MEAN = 12,
|
| 527 |
+
CUDNN_PTR_BN_SAVED_INVSTD = 13,
|
| 528 |
+
CUDNN_PTR_BN_RUNNING_MEAN = 14,
|
| 529 |
+
CUDNN_PTR_BN_RUNNING_VAR = 15,
|
| 530 |
+
CUDNN_PTR_ZDATA = 16,
|
| 531 |
+
CUDNN_PTR_BN_Z_EQSCALE = 17,
|
| 532 |
+
CUDNN_PTR_BN_Z_EQBIAS = 18,
|
| 533 |
+
CUDNN_PTR_ACTIVATION_BITMASK = 19,
|
| 534 |
+
CUDNN_PTR_DXDATA = 20,
|
| 535 |
+
CUDNN_PTR_DZDATA = 21,
|
| 536 |
+
CUDNN_PTR_BN_DSCALE = 22,
|
| 537 |
+
CUDNN_PTR_BN_DBIAS = 23,
|
| 538 |
+
|
| 539 |
+
/* set/get: pass size_t* pointing to host memory */
|
| 540 |
+
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES = 100,
|
| 541 |
+
/* set/get: pass int64_t* pointing to host memory */
|
| 542 |
+
CUDNN_SCALAR_INT64_T_BN_ACCUMULATION_COUNT = 101,
|
| 543 |
+
/* set/get: pass double* pointing to host memory */
|
| 544 |
+
CUDNN_SCALAR_DOUBLE_BN_EXP_AVG_FACTOR = 102,
|
| 545 |
+
/* set/get: pass double* pointing to host memory */
|
| 546 |
+
CUDNN_SCALAR_DOUBLE_BN_EPSILON = 103,
|
| 547 |
+
} cudnnFusedOpsVariantParamLabel_t CUDNN_DEPRECATED;
|
| 548 |
+
|
| 549 |
+
cudnnStatus_t CUDNNWINAPI
|
| 550 |
+
cudnnCnnVersionCheck(void);
|
| 551 |
+
|
| 552 |
+
/* helper function to provide the convolution backward filter algo that fit best the requirement */
|
| 553 |
+
|
| 554 |
+
typedef struct cudnnConvolutionBwdFilterAlgoPerfStruct {
|
| 555 |
+
cudnnConvolutionBwdFilterAlgo_t algo;
|
| 556 |
+
cudnnStatus_t status;
|
| 557 |
+
float time;
|
| 558 |
+
size_t memory;
|
| 559 |
+
cudnnDeterminism_t determinism;
|
| 560 |
+
cudnnMathType_t mathType;
|
| 561 |
+
int reserved[3];
|
| 562 |
+
} cudnnConvolutionBwdFilterAlgoPerf_t CUDNN_DEPRECATED;
|
| 563 |
+
|
| 564 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 565 |
+
cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnnHandle_t handle, int *count);
|
| 566 |
+
|
| 567 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 568 |
+
cudnnFindConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle,
|
| 569 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 570 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 571 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 572 |
+
const cudnnFilterDescriptor_t dwDesc,
|
| 573 |
+
const int requestedAlgoCount,
|
| 574 |
+
int *returnedAlgoCount,
|
| 575 |
+
cudnnConvolutionBwdFilterAlgoPerf_t *perfResults);
|
| 576 |
+
|
| 577 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 578 |
+
cudnnFindConvolutionBackwardFilterAlgorithmEx(cudnnHandle_t handle,
|
| 579 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 580 |
+
const void *x,
|
| 581 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 582 |
+
const void *y,
|
| 583 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 584 |
+
const cudnnFilterDescriptor_t dwDesc,
|
| 585 |
+
void *dw,
|
| 586 |
+
const int requestedAlgoCount,
|
| 587 |
+
int *returnedAlgoCount,
|
| 588 |
+
cudnnConvolutionBwdFilterAlgoPerf_t *perfResults,
|
| 589 |
+
void *workSpace,
|
| 590 |
+
size_t workSpaceSizeInBytes);
|
| 591 |
+
|
| 592 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 593 |
+
cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnnHandle_t handle,
|
| 594 |
+
const cudnnTensorDescriptor_t srcDesc,
|
| 595 |
+
const cudnnTensorDescriptor_t diffDesc,
|
| 596 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 597 |
+
const cudnnFilterDescriptor_t gradDesc,
|
| 598 |
+
const int requestedAlgoCount,
|
| 599 |
+
int *returnedAlgoCount,
|
| 600 |
+
cudnnConvolutionBwdFilterAlgoPerf_t *perfResults);
|
| 601 |
+
|
| 602 |
+
/*
|
| 603 |
+
* convolution algorithm (which requires potentially some workspace)
|
| 604 |
+
*/
|
| 605 |
+
|
| 606 |
+
/* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
|
| 607 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 608 |
+
cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnHandle_t handle,
|
| 609 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 610 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 611 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 612 |
+
const cudnnFilterDescriptor_t gradDesc,
|
| 613 |
+
cudnnConvolutionBwdFilterAlgo_t algo,
|
| 614 |
+
size_t *sizeInBytes);
|
| 615 |
+
|
| 616 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 617 |
+
cudnnConvolutionBackwardFilter(cudnnHandle_t handle,
|
| 618 |
+
const void *alpha,
|
| 619 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 620 |
+
const void *x,
|
| 621 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 622 |
+
const void *dy,
|
| 623 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 624 |
+
cudnnConvolutionBwdFilterAlgo_t algo,
|
| 625 |
+
void *workSpace,
|
| 626 |
+
size_t workSpaceSizeInBytes,
|
| 627 |
+
const void *beta,
|
| 628 |
+
const cudnnFilterDescriptor_t dwDesc,
|
| 629 |
+
void *dw);
|
| 630 |
+
|
| 631 |
+
/* Function to compute the bias gradient for batch convolution */
|
| 632 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 633 |
+
cudnnConvolutionBackwardBias(cudnnHandle_t handle,
|
| 634 |
+
const void *alpha,
|
| 635 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 636 |
+
const void *dy,
|
| 637 |
+
const void *beta,
|
| 638 |
+
const cudnnTensorDescriptor_t dbDesc,
|
| 639 |
+
void *db);
|
| 640 |
+
|
| 641 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 642 |
+
cudnnCreateFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t *constPack, cudnnFusedOps_t ops);
|
| 643 |
+
|
| 644 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 645 |
+
cudnnDestroyFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t constPack);
|
| 646 |
+
|
| 647 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 648 |
+
cudnnSetFusedOpsConstParamPackAttribute(cudnnFusedOpsConstParamPack_t constPack,
|
| 649 |
+
cudnnFusedOpsConstParamLabel_t paramLabel,
|
| 650 |
+
const void *param);
|
| 651 |
+
|
| 652 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 653 |
+
cudnnGetFusedOpsConstParamPackAttribute(const cudnnFusedOpsConstParamPack_t constPack,
|
| 654 |
+
cudnnFusedOpsConstParamLabel_t paramLabel,
|
| 655 |
+
void *param,
|
| 656 |
+
int *isNULL);
|
| 657 |
+
|
| 658 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 659 |
+
cudnnCreateFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t *varPack, cudnnFusedOps_t ops);
|
| 660 |
+
|
| 661 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 662 |
+
cudnnDestroyFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t varPack);
|
| 663 |
+
|
| 664 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 665 |
+
cudnnSetFusedOpsVariantParamPackAttribute(cudnnFusedOpsVariantParamPack_t varPack,
|
| 666 |
+
cudnnFusedOpsVariantParamLabel_t paramLabel,
|
| 667 |
+
void *ptr);
|
| 668 |
+
|
| 669 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 670 |
+
cudnnGetFusedOpsVariantParamPackAttribute(const cudnnFusedOpsVariantParamPack_t varPack,
|
| 671 |
+
cudnnFusedOpsVariantParamLabel_t paramLabel,
|
| 672 |
+
void *ptr);
|
| 673 |
+
|
| 674 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 675 |
+
cudnnCreateFusedOpsPlan(cudnnFusedOpsPlan_t *plan, cudnnFusedOps_t ops);
|
| 676 |
+
|
| 677 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 678 |
+
cudnnDestroyFusedOpsPlan(cudnnFusedOpsPlan_t plan);
|
| 679 |
+
|
| 680 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 681 |
+
cudnnMakeFusedOpsPlan(cudnnHandle_t handle,
|
| 682 |
+
cudnnFusedOpsPlan_t plan,
|
| 683 |
+
const cudnnFusedOpsConstParamPack_t constPack,
|
| 684 |
+
size_t *workspaceSizeInBytes);
|
| 685 |
+
|
| 686 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 687 |
+
cudnnFusedOpsExecute(cudnnHandle_t handle, const cudnnFusedOpsPlan_t plan, cudnnFusedOpsVariantParamPack_t varPack);
|
| 688 |
+
|
| 689 |
+
#if defined(__cplusplus)
|
| 690 |
+
}
|
| 691 |
+
#endif
|
| 692 |
+
|
| 693 |
+
#endif /* CUDNN_CNN_H_ */
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn_v9.h
ADDED
|
@@ -0,0 +1,693 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/*
|
| 51 |
+
* cudnn_cnn : cuDNN's basic definitions and CNN functions.
|
| 52 |
+
*/
|
| 53 |
+
|
| 54 |
+
#if !defined(CUDNN_CNN_H_)
|
| 55 |
+
#define CUDNN_CNN_H_
|
| 56 |
+
|
| 57 |
+
#pragma once
|
| 58 |
+
#include <stdint.h>
|
| 59 |
+
|
| 60 |
+
#include "cudnn_version.h"
|
| 61 |
+
#include "cudnn_ops.h"
|
| 62 |
+
|
| 63 |
+
/* These version numbers are autogenerated, do not edit manually. */
|
| 64 |
+
#define CUDNN_CNN_MAJOR 9
|
| 65 |
+
#define CUDNN_CNN_MINOR 1
|
| 66 |
+
#define CUDNN_CNN_PATCH 0
|
| 67 |
+
|
| 68 |
+
#if (CUDNN_CNN_MAJOR != CUDNN_MAJOR) || (CUDNN_CNN_MINOR != CUDNN_MINOR) || (CUDNN_CNN_PATCH != CUDNN_PATCHLEVEL)
|
| 69 |
+
#error Version mismatch in cuDNN CNN INFER!!!
|
| 70 |
+
#endif
|
| 71 |
+
|
| 72 |
+
#if defined(__cplusplus)
|
| 73 |
+
extern "C" {
|
| 74 |
+
#endif
|
| 75 |
+
|
| 76 |
+
typedef struct cudnnConvolutionStruct *cudnnConvolutionDescriptor_t CUDNN_DEPRECATED;
|
| 77 |
+
|
| 78 |
+
typedef struct cudnnConvolutionFwdAlgoPerfStruct {
|
| 79 |
+
cudnnConvolutionFwdAlgo_t algo;
|
| 80 |
+
cudnnStatus_t status;
|
| 81 |
+
float time;
|
| 82 |
+
size_t memory;
|
| 83 |
+
cudnnDeterminism_t determinism;
|
| 84 |
+
cudnnMathType_t mathType;
|
| 85 |
+
int reserved[3];
|
| 86 |
+
} cudnnConvolutionFwdAlgoPerf_t CUDNN_DEPRECATED;
|
| 87 |
+
|
| 88 |
+
/* Create an instance of convolution descriptor */
|
| 89 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 90 |
+
cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t *convDesc);
|
| 91 |
+
|
| 92 |
+
/* Destroy an instance of convolution descriptor */
|
| 93 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 94 |
+
cudnnDestroyConvolutionDescriptor(cudnnConvolutionDescriptor_t convDesc);
|
| 95 |
+
|
| 96 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 97 |
+
cudnnSetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t mathType);
|
| 98 |
+
|
| 99 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 100 |
+
cudnnGetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t *mathType);
|
| 101 |
+
|
| 102 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 103 |
+
cudnnSetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int groupCount);
|
| 104 |
+
|
| 105 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 106 |
+
cudnnGetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int *groupCount);
|
| 107 |
+
|
| 108 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 109 |
+
cudnnSetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t reorderType);
|
| 110 |
+
|
| 111 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 112 |
+
cudnnGetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t *reorderType);
|
| 113 |
+
|
| 114 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 115 |
+
cudnnSetConvolution2dDescriptor(cudnnConvolutionDescriptor_t convDesc,
|
| 116 |
+
int pad_h, /* zero-padding height */
|
| 117 |
+
int pad_w, /* zero-padding width */
|
| 118 |
+
int u, /* vertical filter stride */
|
| 119 |
+
int v, /* horizontal filter stride */
|
| 120 |
+
int dilation_h, /* filter dilation in the vertical dimension */
|
| 121 |
+
int dilation_w, /* filter dilation in the horizontal dimension */
|
| 122 |
+
cudnnConvolutionMode_t mode,
|
| 123 |
+
cudnnDataType_t computeType);
|
| 124 |
+
|
| 125 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 126 |
+
cudnnGetConvolution2dDescriptor(const cudnnConvolutionDescriptor_t convDesc,
|
| 127 |
+
int *pad_h, /* zero-padding height */
|
| 128 |
+
int *pad_w, /* zero-padding width */
|
| 129 |
+
int *u, /* vertical filter stride */
|
| 130 |
+
int *v, /* horizontal filter stride */
|
| 131 |
+
int *dilation_h, /* filter dilation in the vertical dimension */
|
| 132 |
+
int *dilation_w, /* filter dilation in the horizontal dimension */
|
| 133 |
+
cudnnConvolutionMode_t *mode,
|
| 134 |
+
cudnnDataType_t *computeType);
|
| 135 |
+
|
| 136 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 137 |
+
cudnnSetConvolutionNdDescriptor(cudnnConvolutionDescriptor_t convDesc,
|
| 138 |
+
int arrayLength, /* nbDims-2 size */
|
| 139 |
+
const int padA[],
|
| 140 |
+
const int filterStrideA[],
|
| 141 |
+
const int dilationA[],
|
| 142 |
+
cudnnConvolutionMode_t mode,
|
| 143 |
+
cudnnDataType_t computeType); /* convolution data type */
|
| 144 |
+
|
| 145 |
+
/* Helper function to return the dimensions of the output tensor given a convolution descriptor */
|
| 146 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 147 |
+
cudnnGetConvolutionNdDescriptor(const cudnnConvolutionDescriptor_t convDesc,
|
| 148 |
+
int arrayLengthRequested,
|
| 149 |
+
int *arrayLength,
|
| 150 |
+
int padA[],
|
| 151 |
+
int strideA[],
|
| 152 |
+
int dilationA[],
|
| 153 |
+
cudnnConvolutionMode_t *mode,
|
| 154 |
+
cudnnDataType_t *computeType); /* convolution data type */
|
| 155 |
+
|
| 156 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 157 |
+
cudnnGetConvolution2dForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc,
|
| 158 |
+
const cudnnTensorDescriptor_t inputTensorDesc,
|
| 159 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 160 |
+
int *n,
|
| 161 |
+
int *c,
|
| 162 |
+
int *h,
|
| 163 |
+
int *w);
|
| 164 |
+
|
| 165 |
+
/* Helper function to return the dimensions of the output tensor given a convolution descriptor */
|
| 166 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 167 |
+
cudnnGetConvolutionNdForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc,
|
| 168 |
+
const cudnnTensorDescriptor_t inputTensorDesc,
|
| 169 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 170 |
+
int nbDims,
|
| 171 |
+
int tensorOuputDimA[]);
|
| 172 |
+
|
| 173 |
+
/* helper function to provide the convolution forward algo that fit best the requirement */
|
| 174 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 175 |
+
cudnnGetConvolutionForwardAlgorithmMaxCount(cudnnHandle_t handle, int *count);
|
| 176 |
+
|
| 177 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 178 |
+
cudnnGetConvolutionForwardAlgorithm_v7(cudnnHandle_t handle,
|
| 179 |
+
const cudnnTensorDescriptor_t srcDesc,
|
| 180 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 181 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 182 |
+
const cudnnTensorDescriptor_t destDesc,
|
| 183 |
+
const int requestedAlgoCount,
|
| 184 |
+
int *returnedAlgoCount,
|
| 185 |
+
cudnnConvolutionFwdAlgoPerf_t *perfResults);
|
| 186 |
+
|
| 187 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 188 |
+
cudnnFindConvolutionForwardAlgorithm(cudnnHandle_t handle,
|
| 189 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 190 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 191 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 192 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 193 |
+
const int requestedAlgoCount,
|
| 194 |
+
int *returnedAlgoCount,
|
| 195 |
+
cudnnConvolutionFwdAlgoPerf_t *perfResults);
|
| 196 |
+
|
| 197 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 198 |
+
cudnnFindConvolutionForwardAlgorithmEx(cudnnHandle_t handle,
|
| 199 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 200 |
+
const void *x,
|
| 201 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 202 |
+
const void *w,
|
| 203 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 204 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 205 |
+
void *y,
|
| 206 |
+
const int requestedAlgoCount,
|
| 207 |
+
int *returnedAlgoCount,
|
| 208 |
+
cudnnConvolutionFwdAlgoPerf_t *perfResults,
|
| 209 |
+
void *workSpace,
|
| 210 |
+
size_t workSpaceSizeInBytes);
|
| 211 |
+
|
| 212 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 213 |
+
cudnnIm2Col(cudnnHandle_t handle,
|
| 214 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 215 |
+
const void *x,
|
| 216 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 217 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 218 |
+
void *colBuffer);
|
| 219 |
+
|
| 220 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 221 |
+
cudnnReorderFilterAndBias(cudnnHandle_t handle,
|
| 222 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 223 |
+
cudnnReorderType_t reorderType,
|
| 224 |
+
const void *filterData,
|
| 225 |
+
void *reorderedFilterData,
|
| 226 |
+
int reorderBias,
|
| 227 |
+
const void *biasData,
|
| 228 |
+
void *reorderedBiasData);
|
| 229 |
+
|
| 230 |
+
/* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
|
| 231 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 232 |
+
cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle_t handle,
|
| 233 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 234 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 235 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 236 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 237 |
+
cudnnConvolutionFwdAlgo_t algo,
|
| 238 |
+
size_t *sizeInBytes);
|
| 239 |
+
|
| 240 |
+
/* Convolution functions: All of the form "output = alpha * Op(inputs) + beta * output" */
|
| 241 |
+
|
| 242 |
+
/* Function to perform the forward pass for batch convolution */
|
| 243 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 244 |
+
cudnnConvolutionForward(cudnnHandle_t handle,
|
| 245 |
+
const void *alpha,
|
| 246 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 247 |
+
const void *x,
|
| 248 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 249 |
+
const void *w,
|
| 250 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 251 |
+
cudnnConvolutionFwdAlgo_t algo,
|
| 252 |
+
void *workSpace,
|
| 253 |
+
size_t workSpaceSizeInBytes,
|
| 254 |
+
const void *beta,
|
| 255 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 256 |
+
void *y);
|
| 257 |
+
|
| 258 |
+
/* Fused conv/bias/activation operation : y = Act( alpha1 * conv(x) + alpha2 * z + bias ) */
|
| 259 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 260 |
+
cudnnConvolutionBiasActivationForward(cudnnHandle_t handle,
|
| 261 |
+
const void *alpha1,
|
| 262 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 263 |
+
const void *x,
|
| 264 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 265 |
+
const void *w,
|
| 266 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 267 |
+
cudnnConvolutionFwdAlgo_t algo,
|
| 268 |
+
void *workSpace,
|
| 269 |
+
size_t workSpaceSizeInBytes,
|
| 270 |
+
const void *alpha2,
|
| 271 |
+
const cudnnTensorDescriptor_t zDesc,
|
| 272 |
+
const void *z,
|
| 273 |
+
const cudnnTensorDescriptor_t biasDesc,
|
| 274 |
+
const void *bias,
|
| 275 |
+
const cudnnActivationDescriptor_t activationDesc,
|
| 276 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 277 |
+
void *y);
|
| 278 |
+
|
| 279 |
+
/* helper function to provide the convolution backward data algo that fit best the requirement */
|
| 280 |
+
|
| 281 |
+
typedef struct cudnnConvolutionBwdDataAlgoPerfStruct {
|
| 282 |
+
cudnnConvolutionBwdDataAlgo_t algo;
|
| 283 |
+
cudnnStatus_t status;
|
| 284 |
+
float time;
|
| 285 |
+
size_t memory;
|
| 286 |
+
cudnnDeterminism_t determinism;
|
| 287 |
+
cudnnMathType_t mathType;
|
| 288 |
+
int reserved[3];
|
| 289 |
+
} cudnnConvolutionBwdDataAlgoPerf_t CUDNN_DEPRECATED;
|
| 290 |
+
|
| 291 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 292 |
+
cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnnHandle_t handle, int *count);
|
| 293 |
+
|
| 294 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 295 |
+
cudnnFindConvolutionBackwardDataAlgorithm(cudnnHandle_t handle,
|
| 296 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 297 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 298 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 299 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 300 |
+
const int requestedAlgoCount,
|
| 301 |
+
int *returnedAlgoCount,
|
| 302 |
+
cudnnConvolutionBwdDataAlgoPerf_t *perfResults);
|
| 303 |
+
|
| 304 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 305 |
+
cudnnFindConvolutionBackwardDataAlgorithmEx(cudnnHandle_t handle,
|
| 306 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 307 |
+
const void *w,
|
| 308 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 309 |
+
const void *dy,
|
| 310 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 311 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 312 |
+
void *dx,
|
| 313 |
+
const int requestedAlgoCount,
|
| 314 |
+
int *returnedAlgoCount,
|
| 315 |
+
cudnnConvolutionBwdDataAlgoPerf_t *perfResults,
|
| 316 |
+
void *workSpace,
|
| 317 |
+
size_t workSpaceSizeInBytes);
|
| 318 |
+
|
| 319 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 320 |
+
cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnnHandle_t handle,
|
| 321 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 322 |
+
const cudnnTensorDescriptor_t diffDesc,
|
| 323 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 324 |
+
const cudnnTensorDescriptor_t gradDesc,
|
| 325 |
+
const int requestedAlgoCount,
|
| 326 |
+
int *returnedAlgoCount,
|
| 327 |
+
cudnnConvolutionBwdDataAlgoPerf_t *perfResults);
|
| 328 |
+
|
| 329 |
+
/*
|
| 330 |
+
* convolution algorithm (which requires potentially some workspace)
|
| 331 |
+
*/
|
| 332 |
+
|
| 333 |
+
/* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
|
| 334 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 335 |
+
cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnHandle_t handle,
|
| 336 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 337 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 338 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 339 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 340 |
+
cudnnConvolutionBwdDataAlgo_t algo,
|
| 341 |
+
size_t *sizeInBytes);
|
| 342 |
+
|
| 343 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 344 |
+
cudnnConvolutionBackwardData(cudnnHandle_t handle,
|
| 345 |
+
const void *alpha,
|
| 346 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 347 |
+
const void *w,
|
| 348 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 349 |
+
const void *dy,
|
| 350 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 351 |
+
cudnnConvolutionBwdDataAlgo_t algo,
|
| 352 |
+
void *workSpace,
|
| 353 |
+
size_t workSpaceSizeInBytes,
|
| 354 |
+
const void *beta,
|
| 355 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 356 |
+
void *dx);
|
| 357 |
+
|
| 358 |
+
/* Helper function to calculate folding descriptors for dgrad */
|
| 359 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 360 |
+
cudnnGetFoldedConvBackwardDataDescriptors(const cudnnHandle_t handle,
|
| 361 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 362 |
+
const cudnnTensorDescriptor_t diffDesc,
|
| 363 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 364 |
+
const cudnnTensorDescriptor_t gradDesc,
|
| 365 |
+
const cudnnTensorFormat_t transformFormat,
|
| 366 |
+
cudnnFilterDescriptor_t foldedFilterDesc,
|
| 367 |
+
cudnnTensorDescriptor_t paddedDiffDesc,
|
| 368 |
+
cudnnConvolutionDescriptor_t foldedConvDesc,
|
| 369 |
+
cudnnTensorDescriptor_t foldedGradDesc,
|
| 370 |
+
cudnnTensorTransformDescriptor_t filterFoldTransDesc,
|
| 371 |
+
cudnnTensorTransformDescriptor_t diffPadTransDesc,
|
| 372 |
+
cudnnTensorTransformDescriptor_t gradFoldTransDesc,
|
| 373 |
+
cudnnTensorTransformDescriptor_t gradUnfoldTransDesc);
|
| 374 |
+
|
| 375 |
+
/* cudnnFusedOps... */
|
| 376 |
+
struct cudnnFusedOpsConstParamStruct;
|
| 377 |
+
typedef struct cudnnFusedOpsConstParamStruct *cudnnFusedOpsConstParamPack_t CUDNN_DEPRECATED;
|
| 378 |
+
|
| 379 |
+
struct cudnnFusedOpsVariantParamStruct;
|
| 380 |
+
typedef struct cudnnFusedOpsVariantParamStruct *cudnnFusedOpsVariantParamPack_t CUDNN_DEPRECATED;
|
| 381 |
+
|
| 382 |
+
struct cudnnFusedOpsPlanStruct;
|
| 383 |
+
typedef struct cudnnFusedOpsPlanStruct *cudnnFusedOpsPlan_t CUDNN_DEPRECATED;
|
| 384 |
+
|
| 385 |
+
typedef enum {
|
| 386 |
+
/* each op in [ ] can be disabled by passing NULL ptr */
|
| 387 |
+
/* [per channel scale], [per channel bias], [activation], convolution, [generate BN stats] */
|
| 388 |
+
CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS = 0,
|
| 389 |
+
/* [per channel scale], [per channel bias], [activation], convolutionBackwardWeights */
|
| 390 |
+
CUDNN_FUSED_SCALE_BIAS_ACTIVATION_WGRAD = 1,
|
| 391 |
+
/* utility for BN training in BN-conv fusion */
|
| 392 |
+
/* computes the equivalent scale and bias from ySum ySqSum and learned scale, bias */
|
| 393 |
+
/* optionally update running stats and generate saved stats */
|
| 394 |
+
CUDNN_FUSED_BN_FINALIZE_STATISTICS_TRAINING = 2,
|
| 395 |
+
/* utility for BN inference in BN-conv fusion */
|
| 396 |
+
/* computes the equivalent scale and bias from learned running stats and learned scale, bias */
|
| 397 |
+
CUDNN_FUSED_BN_FINALIZE_STATISTICS_INFERENCE = 3,
|
| 398 |
+
/* reserved for future use: convolution, [per channel scale], [per channel bias], [residual add], [activation] */
|
| 399 |
+
CUDNN_FUSED_CONV_SCALE_BIAS_ADD_ACTIVATION = 4,
|
| 400 |
+
/* reserved for future use: [per channel scale], [per channel bias], [residual add], activation, bitmask */
|
| 401 |
+
CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK = 5,
|
| 402 |
+
/* reserved for future use */
|
| 403 |
+
CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM = 6,
|
| 404 |
+
} cudnnFusedOps_t CUDNN_DEPRECATED;
|
| 405 |
+
|
| 406 |
+
typedef enum {
|
| 407 |
+
/* set XDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 408 |
+
/* get XDESC: pass previously created cudnnTensorDescriptor_t */
|
| 409 |
+
CUDNN_PARAM_XDESC = 0,
|
| 410 |
+
/* set/get XDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 411 |
+
CUDNN_PARAM_XDATA_PLACEHOLDER = 1,
|
| 412 |
+
/* set/get BN_MODE: pass cudnnBatchNormMode_t* */
|
| 413 |
+
CUDNN_PARAM_BN_MODE = 2,
|
| 414 |
+
/* set CUDNN_PARAM_BN_EQSCALEBIAS_DESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 415 |
+
/* get CUDNN_PARAM_BN_EQSCALEBIAS_DESC: pass previously created cudnnTensorDescriptor_t */
|
| 416 |
+
CUDNN_PARAM_BN_EQSCALEBIAS_DESC = 3,
|
| 417 |
+
/* set/get BN_EQSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 418 |
+
CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER = 4,
|
| 419 |
+
/* set/get BN_EQBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 420 |
+
CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER = 5,
|
| 421 |
+
/* set ACTIVATION_DESC: pass previously initialized cudnnActivationDescriptor_t */
|
| 422 |
+
/* get ACTIVATION_DESC: pass previously created cudnnActivationDescriptor_t */
|
| 423 |
+
CUDNN_PARAM_ACTIVATION_DESC = 6,
|
| 424 |
+
/* set CONV_DESC: pass previously initialized cudnnConvolutionDescriptor_t */
|
| 425 |
+
/* get CONV_DESC: pass previously created cudnnConvolutionDescriptor_t */
|
| 426 |
+
CUDNN_PARAM_CONV_DESC = 7,
|
| 427 |
+
/* set WDESC: pass previously initialized cudnnFilterDescriptor_t */
|
| 428 |
+
/* get WDESC: pass previously created cudnnFilterDescriptor_t */
|
| 429 |
+
CUDNN_PARAM_WDESC = 8,
|
| 430 |
+
/* set/get WDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 431 |
+
CUDNN_PARAM_WDATA_PLACEHOLDER = 9,
|
| 432 |
+
/* set DWDESC: pass previously initialized cudnnFilterDescriptor_t */
|
| 433 |
+
/* get DWDESC: pass previously created cudnnFilterDescriptor_t */
|
| 434 |
+
CUDNN_PARAM_DWDESC = 10,
|
| 435 |
+
/* set/get DWDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 436 |
+
CUDNN_PARAM_DWDATA_PLACEHOLDER = 11,
|
| 437 |
+
/* set YDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 438 |
+
/* get YDESC: pass previously created cudnnTensorDescriptor_t */
|
| 439 |
+
CUDNN_PARAM_YDESC = 12,
|
| 440 |
+
/* set/get YDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 441 |
+
CUDNN_PARAM_YDATA_PLACEHOLDER = 13,
|
| 442 |
+
/* set DYDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 443 |
+
/* get DYDESC: pass previously created cudnnTensorDescriptor_t */
|
| 444 |
+
CUDNN_PARAM_DYDESC = 14,
|
| 445 |
+
/* set/get DYDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 446 |
+
CUDNN_PARAM_DYDATA_PLACEHOLDER = 15,
|
| 447 |
+
/* set YSTATS_DESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 448 |
+
/* get YSTATS_DESC: pass previously created cudnnTensorDescriptor_t */
|
| 449 |
+
CUDNN_PARAM_YSTATS_DESC = 16,
|
| 450 |
+
/* set/get YSUM_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 451 |
+
CUDNN_PARAM_YSUM_PLACEHOLDER = 17,
|
| 452 |
+
/* set/get YSQSUM_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 453 |
+
CUDNN_PARAM_YSQSUM_PLACEHOLDER = 18,
|
| 454 |
+
/* set CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 455 |
+
/* get CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC: pass previously created cudnnTensorDescriptor_t */
|
| 456 |
+
CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC = 19,
|
| 457 |
+
/* set/get CUDNN_PARAM_BN_SCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 458 |
+
CUDNN_PARAM_BN_SCALE_PLACEHOLDER = 20,
|
| 459 |
+
/* set/get CUDNN_PARAM_BN_BIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 460 |
+
CUDNN_PARAM_BN_BIAS_PLACEHOLDER = 21,
|
| 461 |
+
/* set/get CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 462 |
+
CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER = 22,
|
| 463 |
+
/* set/get CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 464 |
+
CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER = 23,
|
| 465 |
+
/* set/get CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 466 |
+
CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER = 24,
|
| 467 |
+
/* set/get CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 468 |
+
CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER = 25,
|
| 469 |
+
|
| 470 |
+
/* set ZDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 471 |
+
/* get ZDESC: pass previously created cudnnTensorDescriptor_t */
|
| 472 |
+
CUDNN_PARAM_ZDESC = 26,
|
| 473 |
+
/* set/get ZDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 474 |
+
CUDNN_PARAM_ZDATA_PLACEHOLDER = 27,
|
| 475 |
+
/* set BN_Z_EQSCALEBIAS_DESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 476 |
+
/* get BN_Z_EQSCALEBIAS_DESC: pass previously created cudnnTensorDescriptor_t */
|
| 477 |
+
CUDNN_PARAM_BN_Z_EQSCALEBIAS_DESC = 28,
|
| 478 |
+
/* set/get BN_Z_EQSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 479 |
+
CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER = 29,
|
| 480 |
+
/* set/get BN_Z_EQBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 481 |
+
CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER = 30,
|
| 482 |
+
|
| 483 |
+
/* set ACTIVATION_BITMASK_DESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 484 |
+
/* get ACTIVATION_BITMASK_DESC: pass previously created cudnnTensorDescriptor_t */
|
| 485 |
+
CUDNN_PARAM_ACTIVATION_BITMASK_DESC = 31,
|
| 486 |
+
/* set/get ACTIVATION_BITMASK_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 487 |
+
CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER = 32,
|
| 488 |
+
|
| 489 |
+
/* set DXDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 490 |
+
/* get DXDESC: pass previously created cudnnTensorDescriptor_t */
|
| 491 |
+
CUDNN_PARAM_DXDESC = 33,
|
| 492 |
+
/* set/get DXDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 493 |
+
CUDNN_PARAM_DXDATA_PLACEHOLDER = 34,
|
| 494 |
+
/* set DZDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 495 |
+
/* get DZDESC: pass previously created cudnnTensorDescriptor_t */
|
| 496 |
+
CUDNN_PARAM_DZDESC = 35,
|
| 497 |
+
/* set/get DZDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 498 |
+
CUDNN_PARAM_DZDATA_PLACEHOLDER = 36,
|
| 499 |
+
/* set/get CUDNN_PARAM_BN_DSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 500 |
+
CUDNN_PARAM_BN_DSCALE_PLACEHOLDER = 37,
|
| 501 |
+
/* set/get CUDNN_PARAM_BN_DBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 502 |
+
CUDNN_PARAM_BN_DBIAS_PLACEHOLDER = 38,
|
| 503 |
+
} cudnnFusedOpsConstParamLabel_t CUDNN_DEPRECATED;
|
| 504 |
+
|
| 505 |
+
typedef enum {
|
| 506 |
+
CUDNN_PTR_NULL = 0,
|
| 507 |
+
CUDNN_PTR_ELEM_ALIGNED = 1,
|
| 508 |
+
CUDNN_PTR_16B_ALIGNED = 2,
|
| 509 |
+
} cudnnFusedOpsPointerPlaceHolder_t CUDNN_DEPRECATED;
|
| 510 |
+
|
| 511 |
+
typedef enum {
|
| 512 |
+
/* set: pass void* pointing to dev memory */
|
| 513 |
+
/* get: pass void** pointing to host memory */
|
| 514 |
+
CUDNN_PTR_XDATA = 0,
|
| 515 |
+
CUDNN_PTR_BN_EQSCALE = 1,
|
| 516 |
+
CUDNN_PTR_BN_EQBIAS = 2,
|
| 517 |
+
CUDNN_PTR_WDATA = 3,
|
| 518 |
+
CUDNN_PTR_DWDATA = 4,
|
| 519 |
+
CUDNN_PTR_YDATA = 5,
|
| 520 |
+
CUDNN_PTR_DYDATA = 6,
|
| 521 |
+
CUDNN_PTR_YSUM = 7,
|
| 522 |
+
CUDNN_PTR_YSQSUM = 8,
|
| 523 |
+
CUDNN_PTR_WORKSPACE = 9,
|
| 524 |
+
CUDNN_PTR_BN_SCALE = 10,
|
| 525 |
+
CUDNN_PTR_BN_BIAS = 11,
|
| 526 |
+
CUDNN_PTR_BN_SAVED_MEAN = 12,
|
| 527 |
+
CUDNN_PTR_BN_SAVED_INVSTD = 13,
|
| 528 |
+
CUDNN_PTR_BN_RUNNING_MEAN = 14,
|
| 529 |
+
CUDNN_PTR_BN_RUNNING_VAR = 15,
|
| 530 |
+
CUDNN_PTR_ZDATA = 16,
|
| 531 |
+
CUDNN_PTR_BN_Z_EQSCALE = 17,
|
| 532 |
+
CUDNN_PTR_BN_Z_EQBIAS = 18,
|
| 533 |
+
CUDNN_PTR_ACTIVATION_BITMASK = 19,
|
| 534 |
+
CUDNN_PTR_DXDATA = 20,
|
| 535 |
+
CUDNN_PTR_DZDATA = 21,
|
| 536 |
+
CUDNN_PTR_BN_DSCALE = 22,
|
| 537 |
+
CUDNN_PTR_BN_DBIAS = 23,
|
| 538 |
+
|
| 539 |
+
/* set/get: pass size_t* pointing to host memory */
|
| 540 |
+
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES = 100,
|
| 541 |
+
/* set/get: pass int64_t* pointing to host memory */
|
| 542 |
+
CUDNN_SCALAR_INT64_T_BN_ACCUMULATION_COUNT = 101,
|
| 543 |
+
/* set/get: pass double* pointing to host memory */
|
| 544 |
+
CUDNN_SCALAR_DOUBLE_BN_EXP_AVG_FACTOR = 102,
|
| 545 |
+
/* set/get: pass double* pointing to host memory */
|
| 546 |
+
CUDNN_SCALAR_DOUBLE_BN_EPSILON = 103,
|
| 547 |
+
} cudnnFusedOpsVariantParamLabel_t CUDNN_DEPRECATED;
|
| 548 |
+
|
| 549 |
+
cudnnStatus_t CUDNNWINAPI
|
| 550 |
+
cudnnCnnVersionCheck(void);
|
| 551 |
+
|
| 552 |
+
/* helper function to provide the convolution backward filter algo that fit best the requirement */
|
| 553 |
+
|
| 554 |
+
typedef struct cudnnConvolutionBwdFilterAlgoPerfStruct {
|
| 555 |
+
cudnnConvolutionBwdFilterAlgo_t algo;
|
| 556 |
+
cudnnStatus_t status;
|
| 557 |
+
float time;
|
| 558 |
+
size_t memory;
|
| 559 |
+
cudnnDeterminism_t determinism;
|
| 560 |
+
cudnnMathType_t mathType;
|
| 561 |
+
int reserved[3];
|
| 562 |
+
} cudnnConvolutionBwdFilterAlgoPerf_t CUDNN_DEPRECATED;
|
| 563 |
+
|
| 564 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 565 |
+
cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnnHandle_t handle, int *count);
|
| 566 |
+
|
| 567 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 568 |
+
cudnnFindConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle,
|
| 569 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 570 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 571 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 572 |
+
const cudnnFilterDescriptor_t dwDesc,
|
| 573 |
+
const int requestedAlgoCount,
|
| 574 |
+
int *returnedAlgoCount,
|
| 575 |
+
cudnnConvolutionBwdFilterAlgoPerf_t *perfResults);
|
| 576 |
+
|
| 577 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 578 |
+
cudnnFindConvolutionBackwardFilterAlgorithmEx(cudnnHandle_t handle,
|
| 579 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 580 |
+
const void *x,
|
| 581 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 582 |
+
const void *y,
|
| 583 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 584 |
+
const cudnnFilterDescriptor_t dwDesc,
|
| 585 |
+
void *dw,
|
| 586 |
+
const int requestedAlgoCount,
|
| 587 |
+
int *returnedAlgoCount,
|
| 588 |
+
cudnnConvolutionBwdFilterAlgoPerf_t *perfResults,
|
| 589 |
+
void *workSpace,
|
| 590 |
+
size_t workSpaceSizeInBytes);
|
| 591 |
+
|
| 592 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 593 |
+
cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnnHandle_t handle,
|
| 594 |
+
const cudnnTensorDescriptor_t srcDesc,
|
| 595 |
+
const cudnnTensorDescriptor_t diffDesc,
|
| 596 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 597 |
+
const cudnnFilterDescriptor_t gradDesc,
|
| 598 |
+
const int requestedAlgoCount,
|
| 599 |
+
int *returnedAlgoCount,
|
| 600 |
+
cudnnConvolutionBwdFilterAlgoPerf_t *perfResults);
|
| 601 |
+
|
| 602 |
+
/*
|
| 603 |
+
* convolution algorithm (which requires potentially some workspace)
|
| 604 |
+
*/
|
| 605 |
+
|
| 606 |
+
/* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
|
| 607 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 608 |
+
cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnHandle_t handle,
|
| 609 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 610 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 611 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 612 |
+
const cudnnFilterDescriptor_t gradDesc,
|
| 613 |
+
cudnnConvolutionBwdFilterAlgo_t algo,
|
| 614 |
+
size_t *sizeInBytes);
|
| 615 |
+
|
| 616 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 617 |
+
cudnnConvolutionBackwardFilter(cudnnHandle_t handle,
|
| 618 |
+
const void *alpha,
|
| 619 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 620 |
+
const void *x,
|
| 621 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 622 |
+
const void *dy,
|
| 623 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 624 |
+
cudnnConvolutionBwdFilterAlgo_t algo,
|
| 625 |
+
void *workSpace,
|
| 626 |
+
size_t workSpaceSizeInBytes,
|
| 627 |
+
const void *beta,
|
| 628 |
+
const cudnnFilterDescriptor_t dwDesc,
|
| 629 |
+
void *dw);
|
| 630 |
+
|
| 631 |
+
/* Function to compute the bias gradient for batch convolution */
|
| 632 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 633 |
+
cudnnConvolutionBackwardBias(cudnnHandle_t handle,
|
| 634 |
+
const void *alpha,
|
| 635 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 636 |
+
const void *dy,
|
| 637 |
+
const void *beta,
|
| 638 |
+
const cudnnTensorDescriptor_t dbDesc,
|
| 639 |
+
void *db);
|
| 640 |
+
|
| 641 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 642 |
+
cudnnCreateFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t *constPack, cudnnFusedOps_t ops);
|
| 643 |
+
|
| 644 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 645 |
+
cudnnDestroyFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t constPack);
|
| 646 |
+
|
| 647 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 648 |
+
cudnnSetFusedOpsConstParamPackAttribute(cudnnFusedOpsConstParamPack_t constPack,
|
| 649 |
+
cudnnFusedOpsConstParamLabel_t paramLabel,
|
| 650 |
+
const void *param);
|
| 651 |
+
|
| 652 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 653 |
+
cudnnGetFusedOpsConstParamPackAttribute(const cudnnFusedOpsConstParamPack_t constPack,
|
| 654 |
+
cudnnFusedOpsConstParamLabel_t paramLabel,
|
| 655 |
+
void *param,
|
| 656 |
+
int *isNULL);
|
| 657 |
+
|
| 658 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 659 |
+
cudnnCreateFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t *varPack, cudnnFusedOps_t ops);
|
| 660 |
+
|
| 661 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 662 |
+
cudnnDestroyFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t varPack);
|
| 663 |
+
|
| 664 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 665 |
+
cudnnSetFusedOpsVariantParamPackAttribute(cudnnFusedOpsVariantParamPack_t varPack,
|
| 666 |
+
cudnnFusedOpsVariantParamLabel_t paramLabel,
|
| 667 |
+
void *ptr);
|
| 668 |
+
|
| 669 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 670 |
+
cudnnGetFusedOpsVariantParamPackAttribute(const cudnnFusedOpsVariantParamPack_t varPack,
|
| 671 |
+
cudnnFusedOpsVariantParamLabel_t paramLabel,
|
| 672 |
+
void *ptr);
|
| 673 |
+
|
| 674 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 675 |
+
cudnnCreateFusedOpsPlan(cudnnFusedOpsPlan_t *plan, cudnnFusedOps_t ops);
|
| 676 |
+
|
| 677 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 678 |
+
cudnnDestroyFusedOpsPlan(cudnnFusedOpsPlan_t plan);
|
| 679 |
+
|
| 680 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 681 |
+
cudnnMakeFusedOpsPlan(cudnnHandle_t handle,
|
| 682 |
+
cudnnFusedOpsPlan_t plan,
|
| 683 |
+
const cudnnFusedOpsConstParamPack_t constPack,
|
| 684 |
+
size_t *workspaceSizeInBytes);
|
| 685 |
+
|
| 686 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 687 |
+
cudnnFusedOpsExecute(cudnnHandle_t handle, const cudnnFusedOpsPlan_t plan, cudnnFusedOpsVariantParamPack_t varPack);
|
| 688 |
+
|
| 689 |
+
#if defined(__cplusplus)
|
| 690 |
+
}
|
| 691 |
+
#endif
|
| 692 |
+
|
| 693 |
+
#endif /* CUDNN_CNN_H_ */
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_graph_v9.h
ADDED
|
@@ -0,0 +1,909 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/*
|
| 51 |
+
* cudnn_graph : cuDNN's basic definitions operations.
|
| 52 |
+
*/
|
| 53 |
+
|
| 54 |
+
#if !defined(CUDNN_GRAPH_H_)
|
| 55 |
+
#define CUDNN_GRAPH_H_
|
| 56 |
+
|
| 57 |
+
#include <cuda_runtime_api.h>
|
| 58 |
+
#include <library_types.h>
|
| 59 |
+
|
| 60 |
+
#include <stdint.h>
|
| 61 |
+
|
| 62 |
+
#include "cudnn_version.h"
|
| 63 |
+
|
| 64 |
+
/* These version numbers are autogenerated, do not edit manually. */
|
| 65 |
+
#define CUDNN_GRAPH_MAJOR 9
|
| 66 |
+
#define CUDNN_GRAPH_MINOR 1
|
| 67 |
+
#define CUDNN_GRAPH_PATCH 0
|
| 68 |
+
|
| 69 |
+
#if (CUDNN_GRAPH_MAJOR != CUDNN_MAJOR) || (CUDNN_GRAPH_MINOR != CUDNN_MINOR) || (CUDNN_GRAPH_PATCH != CUDNN_PATCHLEVEL)
|
| 70 |
+
#error Version mismatch in cuDNN GRAPH!!!
|
| 71 |
+
#endif
|
| 72 |
+
|
| 73 |
+
#ifndef CUDNNWINAPI
|
| 74 |
+
#ifdef _WIN32
|
| 75 |
+
#define CUDNNWINAPI __stdcall
|
| 76 |
+
#else
|
| 77 |
+
#define CUDNNWINAPI
|
| 78 |
+
#endif
|
| 79 |
+
#endif
|
| 80 |
+
|
| 81 |
+
/* Warnings for deprecated API-s are enabled using the CUDNN_WARN_DEPRECATED macro */
|
| 82 |
+
#if defined(CUDNN_WARN_DEPRECATED) && (defined(__GNUC__) || defined(__clang__))
|
| 83 |
+
/* GCC, Intel C/C++, Cray C/C++, CLANG, IBM XL C/C++ little endian */
|
| 84 |
+
#define CUDNN_DEPRECATED __attribute__((deprecated))
|
| 85 |
+
#define CUDNN_DEPRECATED_ENUM __attribute__((deprecated))
|
| 86 |
+
#elif defined(CUDNN_WARN_DEPRECATED) && defined(_MSC_VER)
|
| 87 |
+
/* Microsoft Visual C++ */
|
| 88 |
+
#define CUDNN_DEPRECATED __declspec(deprecated)
|
| 89 |
+
#define CUDNN_DEPRECATED_ENUM __declspec(deprecated)
|
| 90 |
+
#elif defined(CUDNN_WARN_DEPRECATED) && (__cplusplus >= 201402L)
|
| 91 |
+
/* C++14 compilers */
|
| 92 |
+
#define CUDNN_DEPRECATED [[deprecated]]
|
| 93 |
+
#define CUDNN_DEPRECATED_ENUM [[deprecated]]
|
| 94 |
+
#else
|
| 95 |
+
/* No support for the deprecated attribute */
|
| 96 |
+
#define CUDNN_DEPRECATED
|
| 97 |
+
#define CUDNN_DEPRECATED_ENUM
|
| 98 |
+
#endif
|
| 99 |
+
|
| 100 |
+
#if defined(__cplusplus)
|
| 101 |
+
extern "C" {
|
| 102 |
+
#endif
|
| 103 |
+
|
| 104 |
+
struct cudnnContext;
|
| 105 |
+
typedef struct cudnnContext *cudnnHandle_t;
|
| 106 |
+
|
| 107 |
+
size_t CUDNNWINAPI
|
| 108 |
+
cudnnGetVersion(void);
|
| 109 |
+
|
| 110 |
+
size_t CUDNNWINAPI
|
| 111 |
+
cudnnGetMaxDeviceVersion(void);
|
| 112 |
+
|
| 113 |
+
/* Returns CUDA Runtime version statically linked against cudnn */
|
| 114 |
+
size_t CUDNNWINAPI
|
| 115 |
+
cudnnGetCudartVersion(void);
|
| 116 |
+
|
| 117 |
+
/*
|
| 118 |
+
* CUDNN return codes
|
| 119 |
+
*/
|
| 120 |
+
typedef enum {
|
| 121 |
+
CUDNN_STATUS_SUCCESS = 0,
|
| 122 |
+
|
| 123 |
+
/* Uncategorized errors */
|
| 124 |
+
CUDNN_STATUS_NOT_INITIALIZED = 1001,
|
| 125 |
+
CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH = 1002,
|
| 126 |
+
CUDNN_STATUS_SERIALIZATION_VERSION_MISMATCH = 1003,
|
| 127 |
+
CUDNN_STATUS_DEPRECATED = 1004,
|
| 128 |
+
CUDNN_STATUS_LICENSE_ERROR = 1005,
|
| 129 |
+
CUDNN_STATUS_RUNTIME_IN_PROGRESS = 1006,
|
| 130 |
+
CUDNN_STATUS_RUNTIME_FP_OVERFLOW = 1007,
|
| 131 |
+
|
| 132 |
+
CUDNN_STATUS_BAD_PARAM = 2000,
|
| 133 |
+
CUDNN_STATUS_BAD_PARAM_NULL_POINTER = 2002,
|
| 134 |
+
CUDNN_STATUS_BAD_PARAM_MISALIGNED_POINTER = 2003,
|
| 135 |
+
CUDNN_STATUS_BAD_PARAM_NOT_FINALIZED = 2004,
|
| 136 |
+
CUDNN_STATUS_BAD_PARAM_OUT_OF_BOUND = 2005,
|
| 137 |
+
CUDNN_STATUS_BAD_PARAM_SIZE_INSUFFICIENT = 2006,
|
| 138 |
+
CUDNN_STATUS_BAD_PARAM_STREAM_MISMATCH = 2007,
|
| 139 |
+
CUDNN_STATUS_BAD_PARAM_SHAPE_MISMATCH = 2008,
|
| 140 |
+
CUDNN_STATUS_BAD_PARAM_DUPLICATED_ENTRIES = 2009,
|
| 141 |
+
CUDNN_STATUS_BAD_PARAM_ATTRIBUTE_TYPE = 2010,
|
| 142 |
+
|
| 143 |
+
CUDNN_STATUS_NOT_SUPPORTED = 3000,
|
| 144 |
+
CUDNN_STATUS_NOT_SUPPORTED_GRAPH_PATTERN = 3001,
|
| 145 |
+
CUDNN_STATUS_NOT_SUPPORTED_SHAPE = 3002,
|
| 146 |
+
CUDNN_STATUS_NOT_SUPPORTED_DATA_TYPE = 3003,
|
| 147 |
+
CUDNN_STATUS_NOT_SUPPORTED_LAYOUT = 3004,
|
| 148 |
+
CUDNN_STATUS_NOT_SUPPORTED_INCOMPATIBLE_CUDA_DRIVER = 3005,
|
| 149 |
+
CUDNN_STATUS_NOT_SUPPORTED_INCOMPATIBLE_CUDART = 3006,
|
| 150 |
+
CUDNN_STATUS_NOT_SUPPORTED_ARCH_MISMATCH = 3007,
|
| 151 |
+
CUDNN_STATUS_NOT_SUPPORTED_RUNTIME_PREREQUISITE_MISSING = 3008,
|
| 152 |
+
CUDNN_STATUS_NOT_SUPPORTED_SUBLIBRARY_UNAVAILABLE = 3009,
|
| 153 |
+
CUDNN_STATUS_NOT_SUPPORTED_SHARED_MEMORY_INSUFFICIENT = 3010,
|
| 154 |
+
CUDNN_STATUS_NOT_SUPPORTED_PADDING = 3011,
|
| 155 |
+
CUDNN_STATUS_NOT_SUPPORTED_BAD_LAUNCH_PARAM = 3012,
|
| 156 |
+
|
| 157 |
+
CUDNN_STATUS_INTERNAL_ERROR = 4000,
|
| 158 |
+
CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED = 4001,
|
| 159 |
+
CUDNN_STATUS_INTERNAL_ERROR_UNEXPECTED_VALUE = 4002,
|
| 160 |
+
CUDNN_STATUS_INTERNAL_ERROR_HOST_ALLOCATION_FAILED = 4003,
|
| 161 |
+
CUDNN_STATUS_INTERNAL_ERROR_DEVICE_ALLOCATION_FAILED = 4004,
|
| 162 |
+
CUDNN_STATUS_INTERNAL_ERROR_BAD_LAUNCH_PARAM = 4005,
|
| 163 |
+
CUDNN_STATUS_INTERNAL_ERROR_TEXTURE_CREATION_FAILED = 4006,
|
| 164 |
+
|
| 165 |
+
CUDNN_STATUS_EXECUTION_FAILED = 5000,
|
| 166 |
+
CUDNN_STATUS_EXECUTION_FAILED_CUDA_DRIVER = 5001,
|
| 167 |
+
CUDNN_STATUS_EXECUTION_FAILED_CUBLAS = 5002,
|
| 168 |
+
CUDNN_STATUS_EXECUTION_FAILED_CUDART = 5003,
|
| 169 |
+
CUDNN_STATUS_EXECUTION_FAILED_CURAND = 5004,
|
| 170 |
+
|
| 171 |
+
CUDNN_STATUS_ALLOC_FAILED CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_INTERNAL_ERROR_HOST_ALLOCATION_FAILED,
|
| 172 |
+
CUDNN_STATUS_INVALID_VALUE CUDNN_DEPRECATED_ENUM = 2001 /* please transition to CUDNN_STATUS_BAD_PARAM instead */,
|
| 173 |
+
CUDNN_STATUS_ARCH_MISMATCH CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_NOT_SUPPORTED_ARCH_MISMATCH,
|
| 174 |
+
CUDNN_STATUS_MAPPING_ERROR CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_INTERNAL_ERROR_TEXTURE_CREATION_FAILED,
|
| 175 |
+
CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING CUDNN_DEPRECATED_ENUM =
|
| 176 |
+
CUDNN_STATUS_NOT_SUPPORTED_RUNTIME_PREREQUISITE_MISSING,
|
| 177 |
+
CUDNN_STATUS_VERSION_MISMATCH CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH,
|
| 178 |
+
} cudnnStatus_t;
|
| 179 |
+
|
| 180 |
+
#define CUDNN_STATUS_FULL_ERROR_CODE(category, specific_err) ((cudnnStatus_t)(0 + (category) + (specific_err)))
|
| 181 |
+
#define CUDNN_STATUS_CATEGORY(full_error_code) ((full_error_code) / 1000 * 1000)
|
| 182 |
+
#define CUDNN_STATUS_SPECIFIC_ERROR(full_error_code) ((full_error_code) % 1000)
|
| 183 |
+
|
| 184 |
+
/* human-readable error messages */
|
| 185 |
+
const char *CUDNNWINAPI
|
| 186 |
+
cudnnGetErrorString(cudnnStatus_t status);
|
| 187 |
+
|
| 188 |
+
void CUDNNWINAPI
|
| 189 |
+
cudnnGetLastErrorString(char *message, size_t max_size);
|
| 190 |
+
|
| 191 |
+
/* Forward definition in this version only */
|
| 192 |
+
typedef struct cudnnRuntimeTag_t cudnnRuntimeTag_t CUDNN_DEPRECATED;
|
| 193 |
+
|
| 194 |
+
typedef enum {
|
| 195 |
+
CUDNN_ERRQUERY_RAWCODE = 0,
|
| 196 |
+
CUDNN_ERRQUERY_NONBLOCKING = 1,
|
| 197 |
+
CUDNN_ERRQUERY_BLOCKING = 2,
|
| 198 |
+
} cudnnErrQueryMode_t;
|
| 199 |
+
|
| 200 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 201 |
+
cudnnQueryRuntimeError(cudnnHandle_t handle, cudnnStatus_t *rstatus, cudnnErrQueryMode_t mode, cudnnRuntimeTag_t *tag);
|
| 202 |
+
|
| 203 |
+
cudnnStatus_t CUDNNWINAPI
|
| 204 |
+
cudnnGetProperty(libraryPropertyType type, int *value);
|
| 205 |
+
|
| 206 |
+
cudnnStatus_t CUDNNWINAPI
|
| 207 |
+
cudnnCreate(cudnnHandle_t *handle);
|
| 208 |
+
cudnnStatus_t CUDNNWINAPI
|
| 209 |
+
cudnnDestroy(cudnnHandle_t handle);
|
| 210 |
+
cudnnStatus_t CUDNNWINAPI
|
| 211 |
+
cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId);
|
| 212 |
+
cudnnStatus_t CUDNNWINAPI
|
| 213 |
+
cudnnGetStream(cudnnHandle_t handle, cudaStream_t *streamId);
|
| 214 |
+
/*
|
| 215 |
+
* CUDNN data type
|
| 216 |
+
*/
|
| 217 |
+
typedef enum {
|
| 218 |
+
CUDNN_DATA_FLOAT = 0,
|
| 219 |
+
CUDNN_DATA_DOUBLE = 1,
|
| 220 |
+
CUDNN_DATA_HALF = 2,
|
| 221 |
+
CUDNN_DATA_INT8 = 3,
|
| 222 |
+
CUDNN_DATA_INT32 = 4,
|
| 223 |
+
CUDNN_DATA_INT8x4 CUDNN_DEPRECATED_ENUM = 5,
|
| 224 |
+
CUDNN_DATA_UINT8 = 6,
|
| 225 |
+
CUDNN_DATA_UINT8x4 CUDNN_DEPRECATED_ENUM = 7,
|
| 226 |
+
CUDNN_DATA_INT8x32 CUDNN_DEPRECATED_ENUM = 8,
|
| 227 |
+
CUDNN_DATA_BFLOAT16 = 9,
|
| 228 |
+
CUDNN_DATA_INT64 = 10,
|
| 229 |
+
CUDNN_DATA_BOOLEAN = 11,
|
| 230 |
+
CUDNN_DATA_FP8_E4M3 = 12,
|
| 231 |
+
CUDNN_DATA_FP8_E5M2 = 13,
|
| 232 |
+
CUDNN_DATA_FAST_FLOAT_FOR_FP8 = 14,
|
| 233 |
+
} cudnnDataType_t;
|
| 234 |
+
|
| 235 |
+
/*
|
| 236 |
+
* CUDNN math type
|
| 237 |
+
*/
|
| 238 |
+
typedef enum {
|
| 239 |
+
CUDNN_DEFAULT_MATH = 0,
|
| 240 |
+
CUDNN_TENSOR_OP_MATH = 1,
|
| 241 |
+
CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION = 2,
|
| 242 |
+
CUDNN_FMA_MATH = 3,
|
| 243 |
+
} cudnnMathType_t;
|
| 244 |
+
|
| 245 |
+
/*
|
| 246 |
+
* CUDNN propagate Nan
|
| 247 |
+
*/
|
| 248 |
+
typedef enum {
|
| 249 |
+
CUDNN_NOT_PROPAGATE_NAN CUDNN_DEPRECATED_ENUM = 0,
|
| 250 |
+
CUDNN_PROPAGATE_NAN CUDNN_DEPRECATED_ENUM = 1,
|
| 251 |
+
} cudnnNanPropagation_t;
|
| 252 |
+
|
| 253 |
+
/*
|
| 254 |
+
* Behavior for OOB samples. OOB samples are samples where L+R > T is encountered during the gradient calculation. If
|
| 255 |
+
* gradMode is set to CUDNN_CTC_SKIP_OOB_GRADIENTS, then the CTC loss function does not write to the gradient buffer for
|
| 256 |
+
* that sample. Instead, the current values, even not finite, are retained. If gradMode is set to
|
| 257 |
+
* CUDNN_CTC_ZERO_OOB_GRADIENTS, then the gradient for that sample is set to zero. This guarantees a finite gradient.
|
| 258 |
+
*/
|
| 259 |
+
typedef enum {
|
| 260 |
+
CUDNN_CTC_ZERO_OOB_GRADIENTS = 0,
|
| 261 |
+
CUDNN_CTC_SKIP_OOB_GRADIENTS = 1,
|
| 262 |
+
} cudnnCTCGradMode_t;
|
| 263 |
+
|
| 264 |
+
typedef enum {
|
| 265 |
+
CUDNN_TENSOR_NCHW = 0, /* row major (wStride = 1, hStride = w) */
|
| 266 |
+
CUDNN_TENSOR_NHWC = 1, /* feature maps interleaved ( cStride = 1 )*/
|
| 267 |
+
CUDNN_TENSOR_NCHW_VECT_C = 2, /* each image point is vector of element of C, vector length in data type */
|
| 268 |
+
} cudnnTensorFormat_t;
|
| 269 |
+
|
| 270 |
+
/*
|
| 271 |
+
* CUDNN ReduceTensor op type
|
| 272 |
+
*/
|
| 273 |
+
typedef enum {
|
| 274 |
+
CUDNN_REDUCE_TENSOR_ADD = 0,
|
| 275 |
+
CUDNN_REDUCE_TENSOR_MUL = 1,
|
| 276 |
+
CUDNN_REDUCE_TENSOR_MIN = 2,
|
| 277 |
+
CUDNN_REDUCE_TENSOR_MAX = 3,
|
| 278 |
+
CUDNN_REDUCE_TENSOR_AMAX = 4,
|
| 279 |
+
CUDNN_REDUCE_TENSOR_AVG = 5,
|
| 280 |
+
CUDNN_REDUCE_TENSOR_NORM1 = 6,
|
| 281 |
+
CUDNN_REDUCE_TENSOR_NORM2 = 7,
|
| 282 |
+
CUDNN_REDUCE_TENSOR_MUL_NO_ZEROS = 8,
|
| 283 |
+
} cudnnReduceTensorOp_t;
|
| 284 |
+
|
| 285 |
+
/*
|
| 286 |
+
* activation mode
|
| 287 |
+
*/
|
| 288 |
+
typedef enum {
|
| 289 |
+
CUDNN_ACTIVATION_SIGMOID = 0,
|
| 290 |
+
CUDNN_ACTIVATION_RELU = 1,
|
| 291 |
+
CUDNN_ACTIVATION_TANH = 2,
|
| 292 |
+
CUDNN_ACTIVATION_CLIPPED_RELU = 3,
|
| 293 |
+
CUDNN_ACTIVATION_ELU = 4,
|
| 294 |
+
CUDNN_ACTIVATION_IDENTITY = 5,
|
| 295 |
+
CUDNN_ACTIVATION_SWISH = 6
|
| 296 |
+
} cudnnActivationMode_t CUDNN_DEPRECATED;
|
| 297 |
+
|
| 298 |
+
typedef enum {
|
| 299 |
+
CUDNN_SEV_FATAL = 0,
|
| 300 |
+
CUDNN_SEV_ERROR = 1,
|
| 301 |
+
CUDNN_SEV_WARNING = 2,
|
| 302 |
+
CUDNN_SEV_INFO = 3,
|
| 303 |
+
} cudnnSeverity_t;
|
| 304 |
+
|
| 305 |
+
/* Message masks to be used with cudnnSetCallback() */
|
| 306 |
+
#define CUDNN_SEV_ERROR_EN (1U << CUDNN_SEV_ERROR)
|
| 307 |
+
#define CUDNN_SEV_WARNING_EN (1U << CUDNN_SEV_WARNING)
|
| 308 |
+
#define CUDNN_SEV_INFO_EN (1U << CUDNN_SEV_INFO)
|
| 309 |
+
|
| 310 |
+
/* struct containing useful informaiton for each API call */
|
| 311 |
+
typedef struct cudnnDebugStruct {
|
| 312 |
+
unsigned cudnn_version;
|
| 313 |
+
cudnnStatus_t cudnnStatus;
|
| 314 |
+
unsigned time_sec; /* epoch time in seconds */
|
| 315 |
+
unsigned time_usec; /* microseconds part of epoch time */
|
| 316 |
+
unsigned time_delta; /* time since start in seconds */
|
| 317 |
+
cudnnHandle_t handle; /* cudnn handle */
|
| 318 |
+
cudaStream_t stream; /* cuda stream ID */
|
| 319 |
+
unsigned long long pid; /* process ID */
|
| 320 |
+
unsigned long long tid; /* thread ID */
|
| 321 |
+
int cudaDeviceId; /* CUDA device ID */
|
| 322 |
+
int reserved[15]; /* reserved for future use */
|
| 323 |
+
} cudnnDebug_t;
|
| 324 |
+
|
| 325 |
+
typedef void (*cudnnCallback_t)(cudnnSeverity_t sev, void *udata, const cudnnDebug_t *dbg, const char *msg);
|
| 326 |
+
|
| 327 |
+
cudnnStatus_t CUDNNWINAPI
|
| 328 |
+
cudnnSetCallback(unsigned mask, void *udata, cudnnCallback_t fptr);
|
| 329 |
+
|
| 330 |
+
cudnnStatus_t CUDNNWINAPI
|
| 331 |
+
cudnnGetCallback(unsigned *mask, void **udata, cudnnCallback_t *fptr);
|
| 332 |
+
|
| 333 |
+
/*
|
| 334 |
+
* \brief Cross-library version checker.
|
| 335 |
+
* This function is implemented differently in each sub-library. Each sublib
|
| 336 |
+
* checks whether its own version matches that of its dependencies.
|
| 337 |
+
* \returns CUDNN_STATUS_SUCCESS if the version check passes,
|
| 338 |
+
* CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH if the versions are inconsistent.
|
| 339 |
+
*/
|
| 340 |
+
cudnnStatus_t CUDNNWINAPI
|
| 341 |
+
cudnnGraphVersionCheck(void);
|
| 342 |
+
|
| 343 |
+
/* Maximum supported number of tensor dimensions */
|
| 344 |
+
#define CUDNN_DIM_MAX 8
|
| 345 |
+
|
| 346 |
+
/*
|
| 347 |
+
* convolution mode
|
| 348 |
+
*/
|
| 349 |
+
typedef enum { CUDNN_CONVOLUTION = 0, CUDNN_CROSS_CORRELATION = 1 } cudnnConvolutionMode_t;
|
| 350 |
+
|
| 351 |
+
/*
|
| 352 |
+
* CUDNN Reorder
|
| 353 |
+
*/
|
| 354 |
+
typedef enum {
|
| 355 |
+
CUDNN_DEFAULT_REORDER = 0,
|
| 356 |
+
CUDNN_NO_REORDER = 1,
|
| 357 |
+
} cudnnReorderType_t CUDNN_DEPRECATED;
|
| 358 |
+
|
| 359 |
+
typedef void *cudnnBackendDescriptor_t;
|
| 360 |
+
|
| 361 |
+
typedef struct cudnnFractionStruct {
|
| 362 |
+
int64_t numerator;
|
| 363 |
+
int64_t denominator;
|
| 364 |
+
} cudnnFraction_t;
|
| 365 |
+
|
| 366 |
+
typedef enum {
|
| 367 |
+
CUDNN_POINTWISE_ADD = 0,
|
| 368 |
+
CUDNN_POINTWISE_ADD_SQUARE = 5,
|
| 369 |
+
CUDNN_POINTWISE_DIV = 6,
|
| 370 |
+
CUDNN_POINTWISE_MAX = 3,
|
| 371 |
+
CUDNN_POINTWISE_MIN = 2,
|
| 372 |
+
CUDNN_POINTWISE_MOD = 7,
|
| 373 |
+
CUDNN_POINTWISE_MUL = 1,
|
| 374 |
+
CUDNN_POINTWISE_POW = 8,
|
| 375 |
+
CUDNN_POINTWISE_SUB = 9,
|
| 376 |
+
|
| 377 |
+
CUDNN_POINTWISE_ABS = 10,
|
| 378 |
+
CUDNN_POINTWISE_CEIL = 11,
|
| 379 |
+
CUDNN_POINTWISE_COS = 12,
|
| 380 |
+
CUDNN_POINTWISE_EXP = 13,
|
| 381 |
+
CUDNN_POINTWISE_FLOOR = 14,
|
| 382 |
+
CUDNN_POINTWISE_LOG = 15,
|
| 383 |
+
CUDNN_POINTWISE_NEG = 16,
|
| 384 |
+
CUDNN_POINTWISE_RSQRT = 17,
|
| 385 |
+
CUDNN_POINTWISE_SIN = 18,
|
| 386 |
+
CUDNN_POINTWISE_SQRT = 4,
|
| 387 |
+
CUDNN_POINTWISE_TAN = 19,
|
| 388 |
+
CUDNN_POINTWISE_ERF = 20,
|
| 389 |
+
CUDNN_POINTWISE_IDENTITY = 21,
|
| 390 |
+
CUDNN_POINTWISE_RECIPROCAL = 22,
|
| 391 |
+
CUDNN_POINTWISE_ATAN2 = 23,
|
| 392 |
+
|
| 393 |
+
CUDNN_POINTWISE_RELU_FWD = 100,
|
| 394 |
+
CUDNN_POINTWISE_TANH_FWD = 101,
|
| 395 |
+
CUDNN_POINTWISE_SIGMOID_FWD = 102,
|
| 396 |
+
CUDNN_POINTWISE_ELU_FWD = 103,
|
| 397 |
+
CUDNN_POINTWISE_GELU_FWD = 104,
|
| 398 |
+
CUDNN_POINTWISE_SOFTPLUS_FWD = 105,
|
| 399 |
+
CUDNN_POINTWISE_SWISH_FWD = 106,
|
| 400 |
+
CUDNN_POINTWISE_GELU_APPROX_TANH_FWD = 107,
|
| 401 |
+
|
| 402 |
+
CUDNN_POINTWISE_RELU_BWD = 200,
|
| 403 |
+
CUDNN_POINTWISE_TANH_BWD = 201,
|
| 404 |
+
CUDNN_POINTWISE_SIGMOID_BWD = 202,
|
| 405 |
+
CUDNN_POINTWISE_ELU_BWD = 203,
|
| 406 |
+
CUDNN_POINTWISE_GELU_BWD = 204,
|
| 407 |
+
CUDNN_POINTWISE_SOFTPLUS_BWD = 205,
|
| 408 |
+
CUDNN_POINTWISE_SWISH_BWD = 206,
|
| 409 |
+
CUDNN_POINTWISE_GELU_APPROX_TANH_BWD = 207,
|
| 410 |
+
|
| 411 |
+
CUDNN_POINTWISE_CMP_EQ = 300,
|
| 412 |
+
CUDNN_POINTWISE_CMP_NEQ = 301,
|
| 413 |
+
CUDNN_POINTWISE_CMP_GT = 302,
|
| 414 |
+
CUDNN_POINTWISE_CMP_GE = 303,
|
| 415 |
+
CUDNN_POINTWISE_CMP_LT = 304,
|
| 416 |
+
CUDNN_POINTWISE_CMP_LE = 305,
|
| 417 |
+
|
| 418 |
+
CUDNN_POINTWISE_LOGICAL_AND = 400,
|
| 419 |
+
CUDNN_POINTWISE_LOGICAL_OR = 401,
|
| 420 |
+
CUDNN_POINTWISE_LOGICAL_NOT = 402,
|
| 421 |
+
|
| 422 |
+
CUDNN_POINTWISE_GEN_INDEX = 501,
|
| 423 |
+
|
| 424 |
+
CUDNN_POINTWISE_BINARY_SELECT = 601,
|
| 425 |
+
} cudnnPointwiseMode_t;
|
| 426 |
+
|
| 427 |
+
typedef enum {
|
| 428 |
+
CUDNN_RESAMPLE_NEAREST = 0,
|
| 429 |
+
CUDNN_RESAMPLE_BILINEAR = 1,
|
| 430 |
+
CUDNN_RESAMPLE_AVGPOOL = 2,
|
| 431 |
+
CUDNN_RESAMPLE_AVGPOOL_INCLUDE_PADDING = 2,
|
| 432 |
+
CUDNN_RESAMPLE_AVGPOOL_EXCLUDE_PADDING = 4,
|
| 433 |
+
CUDNN_RESAMPLE_MAXPOOL = 3,
|
| 434 |
+
} cudnnResampleMode_t;
|
| 435 |
+
|
| 436 |
+
typedef enum {
|
| 437 |
+
CUDNN_SIGNAL_SET = 0,
|
| 438 |
+
CUDNN_SIGNAL_WAIT = 1,
|
| 439 |
+
} cudnnSignalMode_t;
|
| 440 |
+
|
| 441 |
+
typedef enum {
|
| 442 |
+
CUDNN_GENSTATS_SUM_SQSUM = 0,
|
| 443 |
+
} cudnnGenStatsMode_t;
|
| 444 |
+
|
| 445 |
+
typedef enum {
|
| 446 |
+
CUDNN_BN_FINALIZE_STATISTICS_TRAINING = 0,
|
| 447 |
+
CUDNN_BN_FINALIZE_STATISTICS_INFERENCE = 1,
|
| 448 |
+
} cudnnBnFinalizeStatsMode_t;
|
| 449 |
+
|
| 450 |
+
typedef enum {
|
| 451 |
+
CUDNN_RNG_DISTRIBUTION_BERNOULLI,
|
| 452 |
+
CUDNN_RNG_DISTRIBUTION_UNIFORM,
|
| 453 |
+
CUDNN_RNG_DISTRIBUTION_NORMAL,
|
| 454 |
+
} cudnnRngDistribution_t;
|
| 455 |
+
|
| 456 |
+
typedef enum {
|
| 457 |
+
CUDNN_ATTR_POINTWISE_MODE = 0,
|
| 458 |
+
CUDNN_ATTR_POINTWISE_MATH_PREC = 1,
|
| 459 |
+
CUDNN_ATTR_POINTWISE_NAN_PROPAGATION CUDNN_DEPRECATED_ENUM = 2,
|
| 460 |
+
CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP = 3,
|
| 461 |
+
CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP = 4,
|
| 462 |
+
CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP_SLOPE = 5,
|
| 463 |
+
CUDNN_ATTR_POINTWISE_ELU_ALPHA = 6,
|
| 464 |
+
CUDNN_ATTR_POINTWISE_SOFTPLUS_BETA = 7,
|
| 465 |
+
CUDNN_ATTR_POINTWISE_SWISH_BETA = 8,
|
| 466 |
+
CUDNN_ATTR_POINTWISE_AXIS = 9,
|
| 467 |
+
|
| 468 |
+
CUDNN_ATTR_CONVOLUTION_COMP_TYPE = 100,
|
| 469 |
+
CUDNN_ATTR_CONVOLUTION_CONV_MODE = 101,
|
| 470 |
+
CUDNN_ATTR_CONVOLUTION_DILATIONS = 102,
|
| 471 |
+
CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES = 103,
|
| 472 |
+
CUDNN_ATTR_CONVOLUTION_POST_PADDINGS = 104,
|
| 473 |
+
CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS = 105,
|
| 474 |
+
CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS = 106,
|
| 475 |
+
|
| 476 |
+
CUDNN_ATTR_ENGINEHEUR_MODE = 200,
|
| 477 |
+
CUDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH = 201,
|
| 478 |
+
CUDNN_ATTR_ENGINEHEUR_RESULTS = 202,
|
| 479 |
+
CUDNN_ATTR_ENGINEHEUR_SM_COUNT_TARGET = 203,
|
| 480 |
+
|
| 481 |
+
CUDNN_ATTR_ENGINECFG_ENGINE = 300,
|
| 482 |
+
CUDNN_ATTR_ENGINECFG_INTERMEDIATE_INFO = 301,
|
| 483 |
+
CUDNN_ATTR_ENGINECFG_KNOB_CHOICES = 302,
|
| 484 |
+
|
| 485 |
+
CUDNN_ATTR_EXECUTION_PLAN_HANDLE = 400,
|
| 486 |
+
CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG = 401,
|
| 487 |
+
CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE = 402,
|
| 488 |
+
CUDNN_ATTR_EXECUTION_PLAN_COMPUTED_INTERMEDIATE_UIDS = 403,
|
| 489 |
+
CUDNN_ATTR_EXECUTION_PLAN_RUN_ONLY_INTERMEDIATE_UIDS = 404,
|
| 490 |
+
CUDNN_ATTR_EXECUTION_PLAN_JSON_REPRESENTATION = 405,
|
| 491 |
+
|
| 492 |
+
CUDNN_ATTR_INTERMEDIATE_INFO_UNIQUE_ID = 500,
|
| 493 |
+
CUDNN_ATTR_INTERMEDIATE_INFO_SIZE = 501,
|
| 494 |
+
CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_DATA_UIDS = 502,
|
| 495 |
+
CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_ATTRIBUTES = 503,
|
| 496 |
+
|
| 497 |
+
CUDNN_ATTR_KNOB_CHOICE_KNOB_TYPE = 600,
|
| 498 |
+
CUDNN_ATTR_KNOB_CHOICE_KNOB_VALUE = 601,
|
| 499 |
+
|
| 500 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA = 700,
|
| 501 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA = 701,
|
| 502 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC = 702,
|
| 503 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W = 703,
|
| 504 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X = 704,
|
| 505 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y = 705,
|
| 506 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA = 706,
|
| 507 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA = 707,
|
| 508 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC = 708,
|
| 509 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W = 709,
|
| 510 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX = 710,
|
| 511 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY = 711,
|
| 512 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA = 712,
|
| 513 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA = 713,
|
| 514 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC = 714,
|
| 515 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW = 715,
|
| 516 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X = 716,
|
| 517 |
+
CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY = 717,
|
| 518 |
+
|
| 519 |
+
CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR = 750,
|
| 520 |
+
CUDNN_ATTR_OPERATION_POINTWISE_XDESC = 751,
|
| 521 |
+
CUDNN_ATTR_OPERATION_POINTWISE_BDESC = 752,
|
| 522 |
+
CUDNN_ATTR_OPERATION_POINTWISE_YDESC = 753,
|
| 523 |
+
CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1 = 754,
|
| 524 |
+
CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2 = 755,
|
| 525 |
+
CUDNN_ATTR_OPERATION_POINTWISE_DXDESC = 756,
|
| 526 |
+
CUDNN_ATTR_OPERATION_POINTWISE_DYDESC = 757,
|
| 527 |
+
CUDNN_ATTR_OPERATION_POINTWISE_TDESC = 758,
|
| 528 |
+
|
| 529 |
+
CUDNN_ATTR_OPERATION_GENSTATS_MODE = 770,
|
| 530 |
+
CUDNN_ATTR_OPERATION_GENSTATS_MATH_PREC = 771,
|
| 531 |
+
CUDNN_ATTR_OPERATION_GENSTATS_XDESC = 772,
|
| 532 |
+
CUDNN_ATTR_OPERATION_GENSTATS_SUMDESC = 773,
|
| 533 |
+
CUDNN_ATTR_OPERATION_GENSTATS_SQSUMDESC = 774,
|
| 534 |
+
|
| 535 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_STATS_MODE = 780,
|
| 536 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_MATH_PREC = 781,
|
| 537 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SUM_DESC = 782,
|
| 538 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SQ_SUM_DESC = 783,
|
| 539 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_SCALE_DESC = 784,
|
| 540 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_BIAS_DESC = 785,
|
| 541 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_MEAN_DESC = 786,
|
| 542 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_VAR_DESC = 787,
|
| 543 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_MEAN_DESC = 788,
|
| 544 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_VAR_DESC = 789,
|
| 545 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_MEAN_DESC = 790,
|
| 546 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_INV_STD_DESC = 791,
|
| 547 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_SCALE_DESC = 792,
|
| 548 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_BIAS_DESC = 793,
|
| 549 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_ACCUM_COUNT_DESC = 794,
|
| 550 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_EPSILON_DESC = 795,
|
| 551 |
+
CUDNN_ATTR_OPERATION_BN_FINALIZE_EXP_AVERATE_FACTOR_DESC = 796,
|
| 552 |
+
|
| 553 |
+
CUDNN_ATTR_OPERATIONGRAPH_HANDLE = 800,
|
| 554 |
+
CUDNN_ATTR_OPERATIONGRAPH_OPS = 801,
|
| 555 |
+
CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT = 802,
|
| 556 |
+
|
| 557 |
+
CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT = 900,
|
| 558 |
+
CUDNN_ATTR_TENSOR_DATA_TYPE = 901,
|
| 559 |
+
CUDNN_ATTR_TENSOR_DIMENSIONS = 902,
|
| 560 |
+
CUDNN_ATTR_TENSOR_STRIDES = 903,
|
| 561 |
+
CUDNN_ATTR_TENSOR_VECTOR_COUNT = 904,
|
| 562 |
+
CUDNN_ATTR_TENSOR_VECTORIZED_DIMENSION = 905,
|
| 563 |
+
CUDNN_ATTR_TENSOR_UNIQUE_ID = 906,
|
| 564 |
+
CUDNN_ATTR_TENSOR_IS_VIRTUAL = 907,
|
| 565 |
+
CUDNN_ATTR_TENSOR_IS_BY_VALUE = 908,
|
| 566 |
+
CUDNN_ATTR_TENSOR_REORDERING_MODE = 909,
|
| 567 |
+
CUDNN_ATTR_TENSOR_RAGGED_OFFSET_DESC = 913,
|
| 568 |
+
|
| 569 |
+
CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS = 1000,
|
| 570 |
+
CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS = 1001,
|
| 571 |
+
CUDNN_ATTR_VARIANT_PACK_INTERMEDIATES = 1002,
|
| 572 |
+
CUDNN_ATTR_VARIANT_PACK_WORKSPACE = 1003,
|
| 573 |
+
|
| 574 |
+
CUDNN_ATTR_LAYOUT_INFO_TENSOR_UID = 1100,
|
| 575 |
+
CUDNN_ATTR_LAYOUT_INFO_TYPES = 1101,
|
| 576 |
+
|
| 577 |
+
CUDNN_ATTR_KNOB_INFO_TYPE = 1200,
|
| 578 |
+
CUDNN_ATTR_KNOB_INFO_MAXIMUM_VALUE = 1201,
|
| 579 |
+
CUDNN_ATTR_KNOB_INFO_MINIMUM_VALUE = 1202,
|
| 580 |
+
CUDNN_ATTR_KNOB_INFO_STRIDE = 1203,
|
| 581 |
+
|
| 582 |
+
CUDNN_ATTR_ENGINE_OPERATION_GRAPH = 1300,
|
| 583 |
+
CUDNN_ATTR_ENGINE_GLOBAL_INDEX = 1301,
|
| 584 |
+
CUDNN_ATTR_ENGINE_KNOB_INFO = 1302,
|
| 585 |
+
CUDNN_ATTR_ENGINE_NUMERICAL_NOTE = 1303,
|
| 586 |
+
CUDNN_ATTR_ENGINE_LAYOUT_INFO = 1304,
|
| 587 |
+
CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE = 1305,
|
| 588 |
+
CUDNN_ATTR_ENGINE_SM_COUNT_TARGET = 1306,
|
| 589 |
+
|
| 590 |
+
CUDNN_ATTR_MATMUL_COMP_TYPE = 1500,
|
| 591 |
+
CUDNN_ATTR_MATMUL_PADDING_VALUE = 1503,
|
| 592 |
+
|
| 593 |
+
CUDNN_ATTR_OPERATION_MATMUL_ADESC = 1520,
|
| 594 |
+
CUDNN_ATTR_OPERATION_MATMUL_BDESC = 1521,
|
| 595 |
+
CUDNN_ATTR_OPERATION_MATMUL_CDESC = 1522,
|
| 596 |
+
CUDNN_ATTR_OPERATION_MATMUL_DESC = 1523,
|
| 597 |
+
CUDNN_ATTR_OPERATION_MATMUL_IRREGULARLY_STRIDED_BATCH_COUNT CUDNN_DEPRECATED_ENUM = 1524,
|
| 598 |
+
CUDNN_ATTR_OPERATION_MATMUL_GEMM_M_OVERRIDE_DESC = 1525,
|
| 599 |
+
CUDNN_ATTR_OPERATION_MATMUL_GEMM_N_OVERRIDE_DESC = 1526,
|
| 600 |
+
CUDNN_ATTR_OPERATION_MATMUL_GEMM_K_OVERRIDE_DESC = 1527,
|
| 601 |
+
|
| 602 |
+
CUDNN_ATTR_REDUCTION_OPERATOR = 1600,
|
| 603 |
+
CUDNN_ATTR_REDUCTION_COMP_TYPE = 1601,
|
| 604 |
+
|
| 605 |
+
CUDNN_ATTR_OPERATION_REDUCTION_XDESC = 1610,
|
| 606 |
+
CUDNN_ATTR_OPERATION_REDUCTION_YDESC = 1611,
|
| 607 |
+
CUDNN_ATTR_OPERATION_REDUCTION_DESC = 1612,
|
| 608 |
+
|
| 609 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MATH_PREC = 1620,
|
| 610 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MEAN_DESC = 1621,
|
| 611 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_INVSTD_DESC = 1622,
|
| 612 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_BN_SCALE_DESC = 1623,
|
| 613 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_X_DESC = 1624,
|
| 614 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DY_DESC = 1625,
|
| 615 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_SCALE_DESC = 1626,
|
| 616 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_BIAS_DESC = 1627,
|
| 617 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_DY_SCALE_DESC = 1628,
|
| 618 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_X_SCALE_DESC = 1629,
|
| 619 |
+
CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_BIAS = 1630,
|
| 620 |
+
|
| 621 |
+
CUDNN_ATTR_RESAMPLE_MODE = 1700,
|
| 622 |
+
CUDNN_ATTR_RESAMPLE_COMP_TYPE = 1701,
|
| 623 |
+
CUDNN_ATTR_RESAMPLE_SPATIAL_DIMS = 1702,
|
| 624 |
+
CUDNN_ATTR_RESAMPLE_POST_PADDINGS = 1703,
|
| 625 |
+
CUDNN_ATTR_RESAMPLE_PRE_PADDINGS = 1704,
|
| 626 |
+
CUDNN_ATTR_RESAMPLE_STRIDES = 1705,
|
| 627 |
+
CUDNN_ATTR_RESAMPLE_WINDOW_DIMS = 1706,
|
| 628 |
+
CUDNN_ATTR_RESAMPLE_NAN_PROPAGATION = 1707,
|
| 629 |
+
CUDNN_ATTR_RESAMPLE_PADDING_MODE = 1708,
|
| 630 |
+
|
| 631 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_FWD_XDESC = 1710,
|
| 632 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_FWD_YDESC = 1711,
|
| 633 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_FWD_IDXDESC = 1712,
|
| 634 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_FWD_ALPHA CUDNN_DEPRECATED_ENUM = 1713,
|
| 635 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_FWD_BETA CUDNN_DEPRECATED_ENUM = 1714,
|
| 636 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_FWD_DESC = 1716,
|
| 637 |
+
|
| 638 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_DXDESC = 1720,
|
| 639 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_DYDESC = 1721,
|
| 640 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_IDXDESC = 1722,
|
| 641 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_ALPHA CUDNN_DEPRECATED_ENUM = 1723,
|
| 642 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_BETA CUDNN_DEPRECATED_ENUM = 1724,
|
| 643 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_DESC = 1725,
|
| 644 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_XDESC = 1726,
|
| 645 |
+
CUDNN_ATTR_OPERATION_RESAMPLE_BWD_YDESC = 1727,
|
| 646 |
+
|
| 647 |
+
CUDNN_ATTR_OPERATION_CONCAT_AXIS = 1800,
|
| 648 |
+
CUDNN_ATTR_OPERATION_CONCAT_INPUT_DESCS = 1801,
|
| 649 |
+
CUDNN_ATTR_OPERATION_CONCAT_INPLACE_INDEX = 1802,
|
| 650 |
+
CUDNN_ATTR_OPERATION_CONCAT_OUTPUT_DESC = 1803,
|
| 651 |
+
|
| 652 |
+
CUDNN_ATTR_OPERATION_SIGNAL_MODE = 1900,
|
| 653 |
+
CUDNN_ATTR_OPERATION_SIGNAL_FLAGDESC = 1901,
|
| 654 |
+
CUDNN_ATTR_OPERATION_SIGNAL_VALUE = 1902,
|
| 655 |
+
CUDNN_ATTR_OPERATION_SIGNAL_XDESC = 1903,
|
| 656 |
+
CUDNN_ATTR_OPERATION_SIGNAL_YDESC = 1904,
|
| 657 |
+
|
| 658 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_MODE = 2000,
|
| 659 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_PHASE = 2001,
|
| 660 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_XDESC = 2002,
|
| 661 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_MEAN_DESC = 2003,
|
| 662 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC = 2004,
|
| 663 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC = 2005,
|
| 664 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC = 2006,
|
| 665 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_EPSILON_DESC = 2007,
|
| 666 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_EXP_AVG_FACTOR_DESC = 2008,
|
| 667 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_INPUT_RUNNING_MEAN_DESC = 2009,
|
| 668 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_INPUT_RUNNING_VAR_DESC = 2010,
|
| 669 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_OUTPUT_RUNNING_MEAN_DESC = 2011,
|
| 670 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_OUTPUT_RUNNING_VAR_DESC = 2012,
|
| 671 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_YDESC = 2013,
|
| 672 |
+
CUDNN_ATTR_OPERATION_NORM_FWD_PEER_STAT_DESCS = 2014,
|
| 673 |
+
|
| 674 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_MODE = 2100,
|
| 675 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_XDESC = 2101,
|
| 676 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_MEAN_DESC = 2102,
|
| 677 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_INV_VARIANCE_DESC = 2103,
|
| 678 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_DYDESC = 2104,
|
| 679 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_SCALE_DESC = 2105,
|
| 680 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_EPSILON_DESC = 2106,
|
| 681 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_DSCALE_DESC = 2107,
|
| 682 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_DBIAS_DESC = 2108,
|
| 683 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_DXDESC = 2109,
|
| 684 |
+
CUDNN_ATTR_OPERATION_NORM_BWD_PEER_STAT_DESCS = 2110,
|
| 685 |
+
|
| 686 |
+
CUDNN_ATTR_OPERATION_RESHAPE_XDESC = 2200,
|
| 687 |
+
CUDNN_ATTR_OPERATION_RESHAPE_YDESC = 2201,
|
| 688 |
+
|
| 689 |
+
CUDNN_ATTR_RNG_DISTRIBUTION = 2300,
|
| 690 |
+
CUDNN_ATTR_RNG_NORMAL_DIST_MEAN = 2301,
|
| 691 |
+
CUDNN_ATTR_RNG_NORMAL_DIST_STANDARD_DEVIATION = 2302,
|
| 692 |
+
CUDNN_ATTR_RNG_UNIFORM_DIST_MAXIMUM = 2303,
|
| 693 |
+
CUDNN_ATTR_RNG_UNIFORM_DIST_MINIMUM = 2304,
|
| 694 |
+
CUDNN_ATTR_RNG_BERNOULLI_DIST_PROBABILITY = 2305,
|
| 695 |
+
|
| 696 |
+
CUDNN_ATTR_OPERATION_RNG_YDESC = 2310,
|
| 697 |
+
CUDNN_ATTR_OPERATION_RNG_SEED = 2311,
|
| 698 |
+
CUDNN_ATTR_OPERATION_RNG_DESC = 2312,
|
| 699 |
+
CUDNN_ATTR_OPERATION_RNG_OFFSET_DESC = 2313,
|
| 700 |
+
} cudnnBackendAttributeName_t;
|
| 701 |
+
|
| 702 |
+
typedef enum {
|
| 703 |
+
CUDNN_TYPE_HANDLE = 0,
|
| 704 |
+
CUDNN_TYPE_DATA_TYPE,
|
| 705 |
+
CUDNN_TYPE_BOOLEAN,
|
| 706 |
+
CUDNN_TYPE_INT64,
|
| 707 |
+
CUDNN_TYPE_FLOAT,
|
| 708 |
+
CUDNN_TYPE_DOUBLE,
|
| 709 |
+
CUDNN_TYPE_VOID_PTR,
|
| 710 |
+
CUDNN_TYPE_CONVOLUTION_MODE,
|
| 711 |
+
CUDNN_TYPE_HEUR_MODE,
|
| 712 |
+
CUDNN_TYPE_KNOB_TYPE,
|
| 713 |
+
CUDNN_TYPE_NAN_PROPOGATION CUDNN_DEPRECATED_ENUM,
|
| 714 |
+
CUDNN_TYPE_NUMERICAL_NOTE,
|
| 715 |
+
CUDNN_TYPE_LAYOUT_TYPE,
|
| 716 |
+
CUDNN_TYPE_ATTRIB_NAME,
|
| 717 |
+
CUDNN_TYPE_POINTWISE_MODE,
|
| 718 |
+
CUDNN_TYPE_BACKEND_DESCRIPTOR,
|
| 719 |
+
CUDNN_TYPE_GENSTATS_MODE,
|
| 720 |
+
CUDNN_TYPE_BN_FINALIZE_STATS_MODE,
|
| 721 |
+
CUDNN_TYPE_REDUCTION_OPERATOR_TYPE,
|
| 722 |
+
CUDNN_TYPE_BEHAVIOR_NOTE,
|
| 723 |
+
CUDNN_TYPE_TENSOR_REORDERING_MODE,
|
| 724 |
+
CUDNN_TYPE_RESAMPLE_MODE,
|
| 725 |
+
CUDNN_TYPE_PADDING_MODE,
|
| 726 |
+
CUDNN_TYPE_INT32,
|
| 727 |
+
CUDNN_TYPE_CHAR,
|
| 728 |
+
CUDNN_TYPE_SIGNAL_MODE,
|
| 729 |
+
CUDNN_TYPE_FRACTION,
|
| 730 |
+
CUDNN_TYPE_NORM_MODE,
|
| 731 |
+
CUDNN_TYPE_NORM_FWD_PHASE,
|
| 732 |
+
CUDNN_TYPE_RNG_DISTRIBUTION
|
| 733 |
+
} cudnnBackendAttributeType_t;
|
| 734 |
+
|
| 735 |
+
typedef enum {
|
| 736 |
+
CUDNN_BACKEND_POINTWISE_DESCRIPTOR = 0,
|
| 737 |
+
CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR,
|
| 738 |
+
CUDNN_BACKEND_ENGINE_DESCRIPTOR,
|
| 739 |
+
CUDNN_BACKEND_ENGINECFG_DESCRIPTOR,
|
| 740 |
+
CUDNN_BACKEND_ENGINEHEUR_DESCRIPTOR,
|
| 741 |
+
CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR,
|
| 742 |
+
CUDNN_BACKEND_INTERMEDIATE_INFO_DESCRIPTOR,
|
| 743 |
+
CUDNN_BACKEND_KNOB_CHOICE_DESCRIPTOR,
|
| 744 |
+
CUDNN_BACKEND_KNOB_INFO_DESCRIPTOR,
|
| 745 |
+
CUDNN_BACKEND_LAYOUT_INFO_DESCRIPTOR,
|
| 746 |
+
CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
|
| 747 |
+
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR,
|
| 748 |
+
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR,
|
| 749 |
+
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR,
|
| 750 |
+
CUDNN_BACKEND_OPERATION_GEN_STATS_DESCRIPTOR,
|
| 751 |
+
CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR,
|
| 752 |
+
CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR,
|
| 753 |
+
CUDNN_BACKEND_TENSOR_DESCRIPTOR,
|
| 754 |
+
CUDNN_BACKEND_MATMUL_DESCRIPTOR,
|
| 755 |
+
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR,
|
| 756 |
+
CUDNN_BACKEND_OPERATION_BN_FINALIZE_STATISTICS_DESCRIPTOR,
|
| 757 |
+
CUDNN_BACKEND_REDUCTION_DESCRIPTOR,
|
| 758 |
+
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR,
|
| 759 |
+
CUDNN_BACKEND_OPERATION_BN_BWD_WEIGHTS_DESCRIPTOR,
|
| 760 |
+
CUDNN_BACKEND_RESAMPLE_DESCRIPTOR,
|
| 761 |
+
CUDNN_BACKEND_OPERATION_RESAMPLE_FWD_DESCRIPTOR,
|
| 762 |
+
CUDNN_BACKEND_OPERATION_RESAMPLE_BWD_DESCRIPTOR,
|
| 763 |
+
CUDNN_BACKEND_OPERATION_CONCAT_DESCRIPTOR,
|
| 764 |
+
CUDNN_BACKEND_OPERATION_SIGNAL_DESCRIPTOR,
|
| 765 |
+
CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR,
|
| 766 |
+
CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR,
|
| 767 |
+
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR,
|
| 768 |
+
CUDNN_BACKEND_RNG_DESCRIPTOR,
|
| 769 |
+
CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR,
|
| 770 |
+
} cudnnBackendDescriptorType_t;
|
| 771 |
+
|
| 772 |
+
typedef enum {
|
| 773 |
+
CUDNN_NUMERICAL_NOTE_TENSOR_CORE = 0,
|
| 774 |
+
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS,
|
| 775 |
+
CUDNN_NUMERICAL_NOTE_REDUCED_PRECISION_REDUCTION,
|
| 776 |
+
CUDNN_NUMERICAL_NOTE_FFT,
|
| 777 |
+
CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC,
|
| 778 |
+
CUDNN_NUMERICAL_NOTE_WINOGRAD,
|
| 779 |
+
CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_4x4,
|
| 780 |
+
CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_6x6,
|
| 781 |
+
CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_13x13,
|
| 782 |
+
CUDNN_NUMERICAL_NOTE_STRICT_NAN_PROP,
|
| 783 |
+
CUDNN_NUMERICAL_NOTE_TYPE_COUNT,
|
| 784 |
+
} cudnnBackendNumericalNote_t;
|
| 785 |
+
|
| 786 |
+
typedef enum {
|
| 787 |
+
CUDNN_BEHAVIOR_NOTE_RUNTIME_COMPILATION = 0,
|
| 788 |
+
CUDNN_BEHAVIOR_NOTE_REQUIRES_FILTER_INT8x32_REORDER = 1,
|
| 789 |
+
CUDNN_BEHAVIOR_NOTE_REQUIRES_BIAS_INT8x32_REORDER = 2,
|
| 790 |
+
CUDNN_BEHAVIOR_NOTE_TYPE_COUNT,
|
| 791 |
+
} cudnnBackendBehaviorNote_t;
|
| 792 |
+
|
| 793 |
+
typedef enum {
|
| 794 |
+
CUDNN_KNOB_TYPE_SPLIT_K CUDNN_DEPRECATED_ENUM = 0,
|
| 795 |
+
CUDNN_KNOB_TYPE_SWIZZLE = 1,
|
| 796 |
+
CUDNN_KNOB_TYPE_TILE_SIZE = 2,
|
| 797 |
+
CUDNN_KNOB_TYPE_USE_TEX CUDNN_DEPRECATED_ENUM = 3,
|
| 798 |
+
CUDNN_KNOB_TYPE_EDGE = 4,
|
| 799 |
+
CUDNN_KNOB_TYPE_KBLOCK CUDNN_DEPRECATED_ENUM = 5,
|
| 800 |
+
CUDNN_KNOB_TYPE_LDGA CUDNN_DEPRECATED_ENUM = 6,
|
| 801 |
+
CUDNN_KNOB_TYPE_LDGB CUDNN_DEPRECATED_ENUM = 7,
|
| 802 |
+
CUDNN_KNOB_TYPE_CHUNK_K CUDNN_DEPRECATED_ENUM = 8,
|
| 803 |
+
CUDNN_KNOB_TYPE_SPLIT_H CUDNN_DEPRECATED_ENUM = 9,
|
| 804 |
+
CUDNN_KNOB_TYPE_WINO_TILE CUDNN_DEPRECATED_ENUM = 10,
|
| 805 |
+
CUDNN_KNOB_TYPE_MULTIPLY = 11,
|
| 806 |
+
CUDNN_KNOB_TYPE_SPLIT_K_BUF = 12,
|
| 807 |
+
CUDNN_KNOB_TYPE_TILEK = 13,
|
| 808 |
+
CUDNN_KNOB_TYPE_STAGES = 14,
|
| 809 |
+
CUDNN_KNOB_TYPE_REDUCTION_MODE = 15,
|
| 810 |
+
CUDNN_KNOB_TYPE_CTA_SPLIT_K_MODE CUDNN_DEPRECATED_ENUM = 16,
|
| 811 |
+
CUDNN_KNOB_TYPE_SPLIT_K_SLC = 17,
|
| 812 |
+
CUDNN_KNOB_TYPE_IDX_MODE CUDNN_DEPRECATED_ENUM = 18,
|
| 813 |
+
CUDNN_KNOB_TYPE_SLICED CUDNN_DEPRECATED_ENUM = 19,
|
| 814 |
+
CUDNN_KNOB_TYPE_SPLIT_RS CUDNN_DEPRECATED_ENUM = 20,
|
| 815 |
+
CUDNN_KNOB_TYPE_SINGLEBUFFER CUDNN_DEPRECATED_ENUM = 21,
|
| 816 |
+
CUDNN_KNOB_TYPE_LDGC CUDNN_DEPRECATED_ENUM = 22,
|
| 817 |
+
CUDNN_KNOB_TYPE_SPECFILT = 23,
|
| 818 |
+
CUDNN_KNOB_TYPE_KERNEL_CFG = 24,
|
| 819 |
+
CUDNN_KNOB_TYPE_WORKSPACE = 25,
|
| 820 |
+
CUDNN_KNOB_TYPE_TILE_CGA CUDNN_DEPRECATED_ENUM = 26,
|
| 821 |
+
CUDNN_KNOB_TYPE_TILE_CGA_M = 27,
|
| 822 |
+
CUDNN_KNOB_TYPE_TILE_CGA_N = 28,
|
| 823 |
+
CUDNN_KNOB_TYPE_BLOCK_SIZE = 29,
|
| 824 |
+
CUDNN_KNOB_TYPE_OCCUPANCY = 30,
|
| 825 |
+
CUDNN_KNOB_TYPE_ARRAY_SIZE_PER_THREAD = 31,
|
| 826 |
+
CUDNN_KNOB_TYPE_NUM_C_PER_BLOCK CUDNN_DEPRECATED_ENUM = 32,
|
| 827 |
+
CUDNN_KNOB_TYPE_SPLIT_COLS = 33,
|
| 828 |
+
CUDNN_KNOB_TYPE_TILE_ROWS = 34,
|
| 829 |
+
CUDNN_KNOB_TYPE_TILE_COLS = 35,
|
| 830 |
+
CUDNN_KNOB_TYPE_LOAD_SIZE = 36,
|
| 831 |
+
CUDNN_KNOB_TYPE_COUNTS,
|
| 832 |
+
} cudnnBackendKnobType_t;
|
| 833 |
+
|
| 834 |
+
typedef enum {
|
| 835 |
+
CUDNN_LAYOUT_TYPE_PREFERRED_NCHW = 0,
|
| 836 |
+
CUDNN_LAYOUT_TYPE_PREFERRED_NHWC = 1,
|
| 837 |
+
CUDNN_LAYOUT_TYPE_PREFERRED_PAD4CK = 2,
|
| 838 |
+
CUDNN_LAYOUT_TYPE_PREFERRED_PAD8CK = 3,
|
| 839 |
+
CUDNN_LAYOUT_TYPE_COUNT = 4,
|
| 840 |
+
} cudnnBackendLayoutType_t;
|
| 841 |
+
|
| 842 |
+
typedef enum {
|
| 843 |
+
CUDNN_HEUR_MODE_INSTANT = 0,
|
| 844 |
+
CUDNN_HEUR_MODE_B = 1,
|
| 845 |
+
CUDNN_HEUR_MODE_FALLBACK = 2,
|
| 846 |
+
CUDNN_HEUR_MODE_A = 3,
|
| 847 |
+
CUDNN_HEUR_MODES_COUNT = 4,
|
| 848 |
+
} cudnnBackendHeurMode_t;
|
| 849 |
+
|
| 850 |
+
typedef enum {
|
| 851 |
+
CUDNN_TENSOR_REORDERING_NONE = 0,
|
| 852 |
+
CUDNN_TENSOR_REORDERING_INT8x32 = 1,
|
| 853 |
+
CUDNN_TENSOR_REORDERING_F16x16 = 2,
|
| 854 |
+
} cudnnBackendTensorReordering_t;
|
| 855 |
+
|
| 856 |
+
typedef enum {
|
| 857 |
+
CUDNN_ZERO_PAD = 0,
|
| 858 |
+
CUDNN_NEG_INF_PAD = 1,
|
| 859 |
+
CUDNN_EDGE_VAL_PAD = 2,
|
| 860 |
+
} cudnnPaddingMode_t;
|
| 861 |
+
|
| 862 |
+
typedef enum {
|
| 863 |
+
CUDNN_LAYER_NORM = 0,
|
| 864 |
+
CUDNN_INSTANCE_NORM = 1,
|
| 865 |
+
CUDNN_BATCH_NORM = 2,
|
| 866 |
+
CUDNN_GROUP_NORM = 3,
|
| 867 |
+
CUDNN_RMS_NORM = 4,
|
| 868 |
+
} cudnnBackendNormMode_t;
|
| 869 |
+
|
| 870 |
+
typedef enum {
|
| 871 |
+
CUDNN_NORM_FWD_INFERENCE = 0,
|
| 872 |
+
CUDNN_NORM_FWD_TRAINING = 1,
|
| 873 |
+
} cudnnBackendNormFwdPhase_t;
|
| 874 |
+
|
| 875 |
+
cudnnStatus_t CUDNNWINAPI
|
| 876 |
+
cudnnBackendCreateDescriptor(cudnnBackendDescriptorType_t descriptorType, cudnnBackendDescriptor_t *descriptor);
|
| 877 |
+
|
| 878 |
+
cudnnStatus_t CUDNNWINAPI
|
| 879 |
+
cudnnBackendDestroyDescriptor(cudnnBackendDescriptor_t descriptor);
|
| 880 |
+
|
| 881 |
+
cudnnStatus_t CUDNNWINAPI
|
| 882 |
+
cudnnBackendInitialize(cudnnBackendDescriptor_t descriptor);
|
| 883 |
+
|
| 884 |
+
cudnnStatus_t CUDNNWINAPI
|
| 885 |
+
cudnnBackendFinalize(cudnnBackendDescriptor_t descriptor);
|
| 886 |
+
|
| 887 |
+
cudnnStatus_t CUDNNWINAPI
|
| 888 |
+
cudnnBackendSetAttribute(cudnnBackendDescriptor_t descriptor,
|
| 889 |
+
cudnnBackendAttributeName_t attributeName,
|
| 890 |
+
cudnnBackendAttributeType_t attributeType,
|
| 891 |
+
int64_t elementCount,
|
| 892 |
+
const void *arrayOfElements);
|
| 893 |
+
|
| 894 |
+
cudnnStatus_t CUDNNWINAPI
|
| 895 |
+
cudnnBackendGetAttribute(cudnnBackendDescriptor_t const descriptor,
|
| 896 |
+
cudnnBackendAttributeName_t attributeName,
|
| 897 |
+
cudnnBackendAttributeType_t attributeType,
|
| 898 |
+
int64_t requestedElementCount,
|
| 899 |
+
int64_t *elementCount,
|
| 900 |
+
void *arrayOfElements);
|
| 901 |
+
|
| 902 |
+
cudnnStatus_t CUDNNWINAPI
|
| 903 |
+
cudnnBackendExecute(cudnnHandle_t handle, cudnnBackendDescriptor_t executionPlan, cudnnBackendDescriptor_t variantPack);
|
| 904 |
+
|
| 905 |
+
#if defined(__cplusplus)
|
| 906 |
+
}
|
| 907 |
+
#endif
|
| 908 |
+
|
| 909 |
+
#endif /* CUDNN_GRAPH_H_ */
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_ops_v9.h
ADDED
|
@@ -0,0 +1,1316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/*
|
| 51 |
+
* cudnn_ops : cuDNN's basic definitions and basic operations.
|
| 52 |
+
*/
|
| 53 |
+
|
| 54 |
+
#if !defined(CUDNN_OPS_H_)
|
| 55 |
+
#define CUDNN_OPS_H_
|
| 56 |
+
|
| 57 |
+
#include <stdint.h>
|
| 58 |
+
|
| 59 |
+
#include "cudnn_version.h"
|
| 60 |
+
#include "cudnn_graph.h"
|
| 61 |
+
|
| 62 |
+
/* These version numbers are autogenerated, do not edit manually. */
|
| 63 |
+
#define CUDNN_OPS_MAJOR 9
|
| 64 |
+
#define CUDNN_OPS_MINOR 1
|
| 65 |
+
#define CUDNN_OPS_PATCH 0
|
| 66 |
+
|
| 67 |
+
#if (CUDNN_OPS_MAJOR != CUDNN_MAJOR) || (CUDNN_OPS_MINOR != CUDNN_MINOR) || (CUDNN_OPS_PATCH != CUDNN_PATCHLEVEL)
|
| 68 |
+
#error Version mismatch in cuDNN OPS INFER!!!
|
| 69 |
+
#endif
|
| 70 |
+
|
| 71 |
+
#if defined(__cplusplus)
|
| 72 |
+
extern "C" {
|
| 73 |
+
#endif
|
| 74 |
+
|
| 75 |
+
/* Data structures to represent Image/Filter and the Neural Network Layer */
|
| 76 |
+
typedef struct cudnnTensorStruct *cudnnTensorDescriptor_t;
|
| 77 |
+
typedef struct cudnnPoolingStruct *cudnnPoolingDescriptor_t CUDNN_DEPRECATED;
|
| 78 |
+
typedef struct cudnnFilterStruct *cudnnFilterDescriptor_t CUDNN_DEPRECATED;
|
| 79 |
+
typedef struct cudnnLRNStruct *cudnnLRNDescriptor_t;
|
| 80 |
+
typedef struct cudnnActivationStruct *cudnnActivationDescriptor_t CUDNN_DEPRECATED;
|
| 81 |
+
typedef struct cudnnSpatialTransformerStruct *cudnnSpatialTransformerDescriptor_t;
|
| 82 |
+
typedef struct cudnnOpTensorStruct *cudnnOpTensorDescriptor_t CUDNN_DEPRECATED;
|
| 83 |
+
typedef struct cudnnReduceTensorStruct *cudnnReduceTensorDescriptor_t CUDNN_DEPRECATED;
|
| 84 |
+
typedef struct cudnnCTCLossStruct *cudnnCTCLossDescriptor_t;
|
| 85 |
+
typedef struct cudnnTensorTransformStruct *cudnnTensorTransformDescriptor_t CUDNN_DEPRECATED;
|
| 86 |
+
/*
|
| 87 |
+
* CUDNN Determinism
|
| 88 |
+
*/
|
| 89 |
+
typedef enum {
|
| 90 |
+
CUDNN_NON_DETERMINISTIC = 0,
|
| 91 |
+
CUDNN_DETERMINISTIC = 1,
|
| 92 |
+
} cudnnDeterminism_t;
|
| 93 |
+
|
| 94 |
+
/* Create an instance of a generic Tensor descriptor */
|
| 95 |
+
cudnnStatus_t CUDNNWINAPI
|
| 96 |
+
cudnnCreateTensorDescriptor(cudnnTensorDescriptor_t *tensorDesc);
|
| 97 |
+
|
| 98 |
+
cudnnStatus_t CUDNNWINAPI
|
| 99 |
+
cudnnSetTensor4dDescriptor(cudnnTensorDescriptor_t tensorDesc,
|
| 100 |
+
cudnnTensorFormat_t format,
|
| 101 |
+
cudnnDataType_t dataType, /* image data type */
|
| 102 |
+
int n, /* number of inputs (batch size) */
|
| 103 |
+
int c, /* number of input feature maps */
|
| 104 |
+
int h, /* height of input section */
|
| 105 |
+
int w); /* width of input section */
|
| 106 |
+
|
| 107 |
+
cudnnStatus_t CUDNNWINAPI
|
| 108 |
+
cudnnSetTensor4dDescriptorEx(cudnnTensorDescriptor_t tensorDesc,
|
| 109 |
+
cudnnDataType_t dataType, /* image data type */
|
| 110 |
+
int n, /* number of inputs (batch size) */
|
| 111 |
+
int c, /* number of input feature maps */
|
| 112 |
+
int h, /* height of input section */
|
| 113 |
+
int w, /* width of input section */
|
| 114 |
+
int nStride,
|
| 115 |
+
int cStride,
|
| 116 |
+
int hStride,
|
| 117 |
+
int wStride);
|
| 118 |
+
|
| 119 |
+
cudnnStatus_t CUDNNWINAPI
|
| 120 |
+
cudnnGetTensor4dDescriptor(const cudnnTensorDescriptor_t tensorDesc,
|
| 121 |
+
cudnnDataType_t *dataType, /* image data type */
|
| 122 |
+
int *n, /* number of inputs (batch size) */
|
| 123 |
+
int *c, /* number of input feature maps */
|
| 124 |
+
int *h, /* height of input section */
|
| 125 |
+
int *w, /* width of input section */
|
| 126 |
+
int *nStride,
|
| 127 |
+
int *cStride,
|
| 128 |
+
int *hStride,
|
| 129 |
+
int *wStride);
|
| 130 |
+
|
| 131 |
+
cudnnStatus_t CUDNNWINAPI
|
| 132 |
+
cudnnSetTensorNdDescriptor(cudnnTensorDescriptor_t tensorDesc,
|
| 133 |
+
cudnnDataType_t dataType,
|
| 134 |
+
int nbDims,
|
| 135 |
+
const int dimA[],
|
| 136 |
+
const int strideA[]);
|
| 137 |
+
|
| 138 |
+
cudnnStatus_t CUDNNWINAPI
|
| 139 |
+
cudnnSetTensorNdDescriptorEx(cudnnTensorDescriptor_t tensorDesc,
|
| 140 |
+
cudnnTensorFormat_t format,
|
| 141 |
+
cudnnDataType_t dataType,
|
| 142 |
+
int nbDims,
|
| 143 |
+
const int dimA[]);
|
| 144 |
+
|
| 145 |
+
cudnnStatus_t CUDNNWINAPI
|
| 146 |
+
cudnnGetTensorNdDescriptor(const cudnnTensorDescriptor_t tensorDesc,
|
| 147 |
+
int nbDimsRequested,
|
| 148 |
+
cudnnDataType_t *dataType,
|
| 149 |
+
int *nbDims,
|
| 150 |
+
int dimA[],
|
| 151 |
+
int strideA[]);
|
| 152 |
+
|
| 153 |
+
cudnnStatus_t CUDNNWINAPI
|
| 154 |
+
cudnnGetTensorSizeInBytes(const cudnnTensorDescriptor_t tensorDesc, size_t *size);
|
| 155 |
+
|
| 156 |
+
/* PixelOffset( n, c, h, w ) = n *input_stride + c * feature_stride + h * h_stride + w * w_stride
|
| 157 |
+
|
| 158 |
+
1)Example of all images in row major order one batch of features after the other (with an optional padding on row)
|
| 159 |
+
input_stride : c x h x h_stride
|
| 160 |
+
feature_stride : h x h_stride
|
| 161 |
+
h_stride : >= w ( h_stride = w if no padding)
|
| 162 |
+
w_stride : 1
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
2)Example of all images in row major with features maps interleaved
|
| 166 |
+
input_stride : c x h x h_stride
|
| 167 |
+
feature_stride : 1
|
| 168 |
+
h_stride : w x c
|
| 169 |
+
w_stride : c
|
| 170 |
+
|
| 171 |
+
3)Example of all images in column major order one batch of features after the other (with optional padding on column)
|
| 172 |
+
input_stride : c x w x w_stride
|
| 173 |
+
feature_stride : w x w_stride
|
| 174 |
+
h_stride : 1
|
| 175 |
+
w_stride : >= h
|
| 176 |
+
|
| 177 |
+
*/
|
| 178 |
+
|
| 179 |
+
/* Destroy an instance of Tensor4d descriptor */
|
| 180 |
+
cudnnStatus_t CUDNNWINAPI
|
| 181 |
+
cudnnDestroyTensorDescriptor(cudnnTensorDescriptor_t tensorDesc);
|
| 182 |
+
|
| 183 |
+
/* Fold/unfold transforms */
|
| 184 |
+
typedef enum {
|
| 185 |
+
CUDNN_TRANSFORM_FOLD = 0U,
|
| 186 |
+
CUDNN_TRANSFORM_UNFOLD = 1U,
|
| 187 |
+
} cudnnFoldingDirection_t;
|
| 188 |
+
|
| 189 |
+
/** Create a destination descriptor for cudnnTransformTensor */
|
| 190 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 191 |
+
cudnnInitTransformDest(const cudnnTensorTransformDescriptor_t transformDesc,
|
| 192 |
+
const cudnnTensorDescriptor_t srcDesc,
|
| 193 |
+
cudnnTensorDescriptor_t destDesc,
|
| 194 |
+
size_t *destSizeInBytes);
|
| 195 |
+
|
| 196 |
+
/** Create an empty tensor transform descriptor */
|
| 197 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 198 |
+
cudnnCreateTensorTransformDescriptor(cudnnTensorTransformDescriptor_t *transformDesc);
|
| 199 |
+
|
| 200 |
+
/** Initialize a previously created tensor transform descriptor. */
|
| 201 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 202 |
+
cudnnSetTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc,
|
| 203 |
+
const uint32_t nbDims,
|
| 204 |
+
const cudnnTensorFormat_t destFormat,
|
| 205 |
+
const int32_t padBeforeA[],
|
| 206 |
+
const int32_t padAfterA[],
|
| 207 |
+
const uint32_t foldA[],
|
| 208 |
+
const cudnnFoldingDirection_t direction);
|
| 209 |
+
|
| 210 |
+
/**
|
| 211 |
+
* Retrieves the values stored in a previously initialized tensor transform
|
| 212 |
+
* descriptor.
|
| 213 |
+
*/
|
| 214 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 215 |
+
cudnnGetTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc,
|
| 216 |
+
uint32_t nbDimsRequested,
|
| 217 |
+
cudnnTensorFormat_t *destFormat,
|
| 218 |
+
int32_t padBeforeA[],
|
| 219 |
+
int32_t padAfterA[],
|
| 220 |
+
uint32_t foldA[],
|
| 221 |
+
cudnnFoldingDirection_t *direction);
|
| 222 |
+
|
| 223 |
+
/**
|
| 224 |
+
* Destroys a previously created tensor transform descriptor.
|
| 225 |
+
*/
|
| 226 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 227 |
+
cudnnDestroyTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc);
|
| 228 |
+
|
| 229 |
+
/* Tensor layout conversion helper (y = alpha * x + beta * y) */
|
| 230 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 231 |
+
cudnnTransformTensor(cudnnHandle_t handle,
|
| 232 |
+
const void *alpha,
|
| 233 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 234 |
+
const void *x,
|
| 235 |
+
const void *beta,
|
| 236 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 237 |
+
void *y);
|
| 238 |
+
|
| 239 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 240 |
+
cudnnTransformTensorEx(cudnnHandle_t handle,
|
| 241 |
+
const cudnnTensorTransformDescriptor_t transDesc,
|
| 242 |
+
const void *alpha,
|
| 243 |
+
const cudnnTensorDescriptor_t srcDesc,
|
| 244 |
+
const void *srcData,
|
| 245 |
+
const void *beta,
|
| 246 |
+
const cudnnTensorDescriptor_t destDesc,
|
| 247 |
+
void *destData);
|
| 248 |
+
|
| 249 |
+
/* Tensor Bias addition : C = alpha * A + beta * C */
|
| 250 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 251 |
+
cudnnAddTensor(cudnnHandle_t handle,
|
| 252 |
+
const void *alpha,
|
| 253 |
+
const cudnnTensorDescriptor_t aDesc,
|
| 254 |
+
const void *A,
|
| 255 |
+
const void *beta,
|
| 256 |
+
const cudnnTensorDescriptor_t cDesc,
|
| 257 |
+
void *C);
|
| 258 |
+
|
| 259 |
+
/*
|
| 260 |
+
* CUDNN OpTensor op type
|
| 261 |
+
*/
|
| 262 |
+
typedef enum {
|
| 263 |
+
CUDNN_OP_TENSOR_ADD = 0,
|
| 264 |
+
CUDNN_OP_TENSOR_MUL = 1,
|
| 265 |
+
CUDNN_OP_TENSOR_MIN = 2,
|
| 266 |
+
CUDNN_OP_TENSOR_MAX = 3,
|
| 267 |
+
CUDNN_OP_TENSOR_SQRT = 4,
|
| 268 |
+
CUDNN_OP_TENSOR_NOT = 5,
|
| 269 |
+
} cudnnOpTensorOp_t;
|
| 270 |
+
|
| 271 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 272 |
+
cudnnCreateOpTensorDescriptor(cudnnOpTensorDescriptor_t *opTensorDesc);
|
| 273 |
+
|
| 274 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 275 |
+
cudnnSetOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc,
|
| 276 |
+
cudnnOpTensorOp_t opTensorOp,
|
| 277 |
+
cudnnDataType_t opTensorCompType,
|
| 278 |
+
cudnnNanPropagation_t opTensorNanOpt);
|
| 279 |
+
|
| 280 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 281 |
+
cudnnGetOpTensorDescriptor(const cudnnOpTensorDescriptor_t opTensorDesc,
|
| 282 |
+
cudnnOpTensorOp_t *opTensorOp,
|
| 283 |
+
cudnnDataType_t *opTensorCompType,
|
| 284 |
+
cudnnNanPropagation_t *opTensorNanOpt);
|
| 285 |
+
|
| 286 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 287 |
+
cudnnDestroyOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc);
|
| 288 |
+
|
| 289 |
+
/* Tensor operation : C = op( alpha1 * A, alpha2 * B ) + beta * C */
|
| 290 |
+
/* B tensor is ignored for CUDNN_OP_TENSOR_SQRT, CUDNN_OP_TENSOR_NOT. */
|
| 291 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 292 |
+
cudnnOpTensor(cudnnHandle_t handle,
|
| 293 |
+
const cudnnOpTensorDescriptor_t opTensorDesc,
|
| 294 |
+
const void *alpha1,
|
| 295 |
+
const cudnnTensorDescriptor_t aDesc,
|
| 296 |
+
const void *A,
|
| 297 |
+
const void *alpha2,
|
| 298 |
+
const cudnnTensorDescriptor_t bDesc,
|
| 299 |
+
const void *B,
|
| 300 |
+
const void *beta,
|
| 301 |
+
const cudnnTensorDescriptor_t cDesc,
|
| 302 |
+
void *C);
|
| 303 |
+
|
| 304 |
+
/*
|
| 305 |
+
* CUDNN ReduceTensor indices type
|
| 306 |
+
*/
|
| 307 |
+
typedef enum {
|
| 308 |
+
CUDNN_REDUCE_TENSOR_NO_INDICES = 0,
|
| 309 |
+
CUDNN_REDUCE_TENSOR_FLATTENED_INDICES = 1,
|
| 310 |
+
} cudnnReduceTensorIndices_t CUDNN_DEPRECATED;
|
| 311 |
+
|
| 312 |
+
/*
|
| 313 |
+
* CUDNN tensor indices type size (all unsigned)
|
| 314 |
+
* Currently not supported, default is 32 bit unsigned.
|
| 315 |
+
*/
|
| 316 |
+
typedef enum {
|
| 317 |
+
CUDNN_32BIT_INDICES = 0,
|
| 318 |
+
CUDNN_64BIT_INDICES = 1,
|
| 319 |
+
CUDNN_16BIT_INDICES = 2,
|
| 320 |
+
CUDNN_8BIT_INDICES = 3,
|
| 321 |
+
} cudnnIndicesType_t CUDNN_DEPRECATED;
|
| 322 |
+
|
| 323 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 324 |
+
cudnnCreateReduceTensorDescriptor(cudnnReduceTensorDescriptor_t *reduceTensorDesc);
|
| 325 |
+
|
| 326 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 327 |
+
cudnnSetReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc,
|
| 328 |
+
cudnnReduceTensorOp_t reduceTensorOp,
|
| 329 |
+
cudnnDataType_t reduceTensorCompType,
|
| 330 |
+
cudnnNanPropagation_t reduceTensorNanOpt,
|
| 331 |
+
cudnnReduceTensorIndices_t reduceTensorIndices,
|
| 332 |
+
cudnnIndicesType_t reduceTensorIndicesType);
|
| 333 |
+
|
| 334 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 335 |
+
cudnnGetReduceTensorDescriptor(const cudnnReduceTensorDescriptor_t reduceTensorDesc,
|
| 336 |
+
cudnnReduceTensorOp_t *reduceTensorOp,
|
| 337 |
+
cudnnDataType_t *reduceTensorCompType,
|
| 338 |
+
cudnnNanPropagation_t *reduceTensorNanOpt,
|
| 339 |
+
cudnnReduceTensorIndices_t *reduceTensorIndices,
|
| 340 |
+
cudnnIndicesType_t *reduceTensorIndicesType);
|
| 341 |
+
|
| 342 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 343 |
+
cudnnDestroyReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc);
|
| 344 |
+
|
| 345 |
+
/* Helper function to return the minimum size of the index space to be passed to the reduction given the input and
|
| 346 |
+
* output tensors */
|
| 347 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 348 |
+
cudnnGetReductionIndicesSize(cudnnHandle_t handle,
|
| 349 |
+
const cudnnReduceTensorDescriptor_t reduceTensorDesc,
|
| 350 |
+
const cudnnTensorDescriptor_t aDesc,
|
| 351 |
+
const cudnnTensorDescriptor_t cDesc,
|
| 352 |
+
size_t *sizeInBytes);
|
| 353 |
+
|
| 354 |
+
/* Helper function to return the minimum size of the workspace to be passed to the reduction given the input and output
|
| 355 |
+
* tensors */
|
| 356 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 357 |
+
cudnnGetReductionWorkspaceSize(cudnnHandle_t handle,
|
| 358 |
+
const cudnnReduceTensorDescriptor_t reduceTensorDesc,
|
| 359 |
+
const cudnnTensorDescriptor_t aDesc,
|
| 360 |
+
const cudnnTensorDescriptor_t cDesc,
|
| 361 |
+
size_t *sizeInBytes);
|
| 362 |
+
|
| 363 |
+
/* Tensor operation : C = reduce op( alpha * A ) + beta * C */
|
| 364 |
+
/* The NaN propagation enum applies to only the min and max reduce ops; the other reduce ops propagate NaN as usual. */
|
| 365 |
+
/* The indices space is ignored for reduce ops other than min or max. */
|
| 366 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 367 |
+
cudnnReduceTensor(cudnnHandle_t handle,
|
| 368 |
+
const cudnnReduceTensorDescriptor_t reduceTensorDesc,
|
| 369 |
+
void *indices,
|
| 370 |
+
size_t indicesSizeInBytes,
|
| 371 |
+
void *workspace,
|
| 372 |
+
size_t workspaceSizeInBytes,
|
| 373 |
+
const void *alpha,
|
| 374 |
+
const cudnnTensorDescriptor_t aDesc,
|
| 375 |
+
const void *A,
|
| 376 |
+
const void *beta,
|
| 377 |
+
const cudnnTensorDescriptor_t cDesc,
|
| 378 |
+
void *C);
|
| 379 |
+
|
| 380 |
+
/* Set all values of a tensor to a given value : y[i] = value[0] */
|
| 381 |
+
cudnnStatus_t CUDNNWINAPI
|
| 382 |
+
cudnnSetTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *valuePtr);
|
| 383 |
+
|
| 384 |
+
/* Scale all values of a tensor by a given factor : y[i] = alpha * y[i] */
|
| 385 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 386 |
+
cudnnScaleTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *alpha);
|
| 387 |
+
|
| 388 |
+
/* Create an instance of FilterStruct */
|
| 389 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 390 |
+
cudnnCreateFilterDescriptor(cudnnFilterDescriptor_t *filterDesc);
|
| 391 |
+
|
| 392 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 393 |
+
cudnnSetFilter4dDescriptor(cudnnFilterDescriptor_t filterDesc,
|
| 394 |
+
cudnnDataType_t dataType, /* image data type */
|
| 395 |
+
cudnnTensorFormat_t format,
|
| 396 |
+
int k, /* number of output feature maps */
|
| 397 |
+
int c, /* number of input feature maps */
|
| 398 |
+
int h, /* height of each input filter */
|
| 399 |
+
int w); /* width of each input filter */
|
| 400 |
+
|
| 401 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 402 |
+
cudnnGetFilter4dDescriptor(const cudnnFilterDescriptor_t filterDesc,
|
| 403 |
+
cudnnDataType_t *dataType, /* image data type */
|
| 404 |
+
cudnnTensorFormat_t *format,
|
| 405 |
+
int *k, /* number of output feature maps */
|
| 406 |
+
int *c, /* number of input feature maps */
|
| 407 |
+
int *h, /* height of each input filter */
|
| 408 |
+
int *w); /* width of each input filter */
|
| 409 |
+
|
| 410 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 411 |
+
cudnnSetFilterNdDescriptor(cudnnFilterDescriptor_t filterDesc,
|
| 412 |
+
cudnnDataType_t dataType, /* image data type */
|
| 413 |
+
cudnnTensorFormat_t format,
|
| 414 |
+
int nbDims,
|
| 415 |
+
const int filterDimA[]);
|
| 416 |
+
|
| 417 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 418 |
+
cudnnGetFilterNdDescriptor(const cudnnFilterDescriptor_t filterDesc,
|
| 419 |
+
int nbDimsRequested,
|
| 420 |
+
cudnnDataType_t *dataType, /* image data type */
|
| 421 |
+
cudnnTensorFormat_t *format,
|
| 422 |
+
int *nbDims,
|
| 423 |
+
int filterDimA[]);
|
| 424 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 425 |
+
cudnnGetFilterSizeInBytes(const cudnnFilterDescriptor_t filterDesc, size_t *size);
|
| 426 |
+
|
| 427 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 428 |
+
cudnnTransformFilter(cudnnHandle_t handle,
|
| 429 |
+
const cudnnTensorTransformDescriptor_t transDesc,
|
| 430 |
+
const void *alpha,
|
| 431 |
+
const cudnnFilterDescriptor_t srcDesc,
|
| 432 |
+
const void *srcData,
|
| 433 |
+
const void *beta,
|
| 434 |
+
const cudnnFilterDescriptor_t destDesc,
|
| 435 |
+
void *destData);
|
| 436 |
+
|
| 437 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 438 |
+
cudnnDestroyFilterDescriptor(cudnnFilterDescriptor_t filterDesc);
|
| 439 |
+
|
| 440 |
+
/*
|
| 441 |
+
* softmax algorithm
|
| 442 |
+
*/
|
| 443 |
+
typedef enum {
|
| 444 |
+
CUDNN_SOFTMAX_FAST = 0, /* straightforward implementation */
|
| 445 |
+
CUDNN_SOFTMAX_ACCURATE = 1, /* subtract max from every point to avoid overflow */
|
| 446 |
+
CUDNN_SOFTMAX_LOG = 2
|
| 447 |
+
} cudnnSoftmaxAlgorithm_t;
|
| 448 |
+
|
| 449 |
+
typedef enum {
|
| 450 |
+
CUDNN_SOFTMAX_MODE_INSTANCE = 0, /* compute the softmax over all C, H, W for each N */
|
| 451 |
+
CUDNN_SOFTMAX_MODE_CHANNEL = 1 /* compute the softmax over all C for each H, W, N */
|
| 452 |
+
} cudnnSoftmaxMode_t;
|
| 453 |
+
|
| 454 |
+
/* Softmax functions: All of the form "output = alpha * Op(inputs) + beta * output" */
|
| 455 |
+
|
| 456 |
+
/* Function to perform forward softmax */
|
| 457 |
+
cudnnStatus_t CUDNNWINAPI
|
| 458 |
+
cudnnSoftmaxForward(cudnnHandle_t handle,
|
| 459 |
+
cudnnSoftmaxAlgorithm_t algo,
|
| 460 |
+
cudnnSoftmaxMode_t mode,
|
| 461 |
+
const void *alpha,
|
| 462 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 463 |
+
const void *x,
|
| 464 |
+
const void *beta,
|
| 465 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 466 |
+
void *y);
|
| 467 |
+
|
| 468 |
+
/*
|
| 469 |
+
* pooling mode
|
| 470 |
+
*/
|
| 471 |
+
typedef enum {
|
| 472 |
+
CUDNN_POOLING_MAX = 0,
|
| 473 |
+
CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING = 1, /* count for average includes padded values */
|
| 474 |
+
CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING = 2, /* count for average does not include padded values */
|
| 475 |
+
CUDNN_POOLING_MAX_DETERMINISTIC = 3
|
| 476 |
+
} cudnnPoolingMode_t CUDNN_DEPRECATED;
|
| 477 |
+
|
| 478 |
+
/* Create an instance of pooling descriptor */
|
| 479 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 480 |
+
cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc);
|
| 481 |
+
|
| 482 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 483 |
+
cudnnSetPooling2dDescriptor(cudnnPoolingDescriptor_t poolingDesc,
|
| 484 |
+
cudnnPoolingMode_t mode,
|
| 485 |
+
cudnnNanPropagation_t maxpoolingNanOpt,
|
| 486 |
+
int windowHeight,
|
| 487 |
+
int windowWidth,
|
| 488 |
+
int verticalPadding,
|
| 489 |
+
int horizontalPadding,
|
| 490 |
+
int verticalStride,
|
| 491 |
+
int horizontalStride);
|
| 492 |
+
|
| 493 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 494 |
+
cudnnGetPooling2dDescriptor(const cudnnPoolingDescriptor_t poolingDesc,
|
| 495 |
+
cudnnPoolingMode_t *mode,
|
| 496 |
+
cudnnNanPropagation_t *maxpoolingNanOpt,
|
| 497 |
+
int *windowHeight,
|
| 498 |
+
int *windowWidth,
|
| 499 |
+
int *verticalPadding,
|
| 500 |
+
int *horizontalPadding,
|
| 501 |
+
int *verticalStride,
|
| 502 |
+
int *horizontalStride);
|
| 503 |
+
|
| 504 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 505 |
+
cudnnSetPoolingNdDescriptor(cudnnPoolingDescriptor_t poolingDesc,
|
| 506 |
+
const cudnnPoolingMode_t mode,
|
| 507 |
+
const cudnnNanPropagation_t maxpoolingNanOpt,
|
| 508 |
+
int nbDims,
|
| 509 |
+
const int windowDimA[],
|
| 510 |
+
const int paddingA[],
|
| 511 |
+
const int strideA[]);
|
| 512 |
+
|
| 513 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 514 |
+
cudnnGetPoolingNdDescriptor(const cudnnPoolingDescriptor_t poolingDesc,
|
| 515 |
+
int nbDimsRequested,
|
| 516 |
+
cudnnPoolingMode_t *mode,
|
| 517 |
+
cudnnNanPropagation_t *maxpoolingNanOpt,
|
| 518 |
+
int *nbDims,
|
| 519 |
+
int windowDimA[],
|
| 520 |
+
int paddingA[],
|
| 521 |
+
int strideA[]);
|
| 522 |
+
|
| 523 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 524 |
+
cudnnGetPoolingNdForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc,
|
| 525 |
+
const cudnnTensorDescriptor_t inputTensorDesc,
|
| 526 |
+
int nbDims,
|
| 527 |
+
int outputTensorDimA[]);
|
| 528 |
+
|
| 529 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 530 |
+
cudnnGetPooling2dForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc,
|
| 531 |
+
const cudnnTensorDescriptor_t inputTensorDesc,
|
| 532 |
+
int *n,
|
| 533 |
+
int *c,
|
| 534 |
+
int *h,
|
| 535 |
+
int *w);
|
| 536 |
+
|
| 537 |
+
/* Destroy an instance of pooling descriptor */
|
| 538 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 539 |
+
cudnnDestroyPoolingDescriptor(cudnnPoolingDescriptor_t poolingDesc);
|
| 540 |
+
|
| 541 |
+
/* Pooling functions: All of the form "output = alpha * Op(inputs) + beta * output" */
|
| 542 |
+
|
| 543 |
+
/* Function to perform forward pooling */
|
| 544 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 545 |
+
cudnnPoolingForward(cudnnHandle_t handle,
|
| 546 |
+
const cudnnPoolingDescriptor_t poolingDesc,
|
| 547 |
+
const void *alpha,
|
| 548 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 549 |
+
const void *x,
|
| 550 |
+
const void *beta,
|
| 551 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 552 |
+
void *y);
|
| 553 |
+
|
| 554 |
+
/* Activation functions: All of the form "output = alpha * Op(inputs) + beta * output" */
|
| 555 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 556 |
+
cudnnCreateActivationDescriptor(cudnnActivationDescriptor_t *activationDesc);
|
| 557 |
+
|
| 558 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 559 |
+
cudnnSetActivationDescriptor(cudnnActivationDescriptor_t activationDesc,
|
| 560 |
+
cudnnActivationMode_t mode,
|
| 561 |
+
cudnnNanPropagation_t reluNanOpt,
|
| 562 |
+
double coef); /* ceiling for clipped RELU, alpha for ELU */
|
| 563 |
+
|
| 564 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 565 |
+
cudnnGetActivationDescriptor(const cudnnActivationDescriptor_t activationDesc,
|
| 566 |
+
cudnnActivationMode_t *mode,
|
| 567 |
+
cudnnNanPropagation_t *reluNanOpt,
|
| 568 |
+
double *coef); /* ceiling for clipped RELU, alpha for ELU */
|
| 569 |
+
|
| 570 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 571 |
+
cudnnSetActivationDescriptorSwishBeta(cudnnActivationDescriptor_t activationDesc, double swish_beta);
|
| 572 |
+
|
| 573 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 574 |
+
cudnnGetActivationDescriptorSwishBeta(cudnnActivationDescriptor_t activationDesc, double *swish_beta);
|
| 575 |
+
|
| 576 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 577 |
+
cudnnDestroyActivationDescriptor(cudnnActivationDescriptor_t activationDesc);
|
| 578 |
+
|
| 579 |
+
/* Function to perform forward activation */
|
| 580 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 581 |
+
cudnnActivationForward(cudnnHandle_t handle,
|
| 582 |
+
cudnnActivationDescriptor_t activationDesc,
|
| 583 |
+
const void *alpha,
|
| 584 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 585 |
+
const void *x,
|
| 586 |
+
const void *beta,
|
| 587 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 588 |
+
void *y);
|
| 589 |
+
|
| 590 |
+
/*
|
| 591 |
+
* Create an instance of LRN (Local Response Normalization) descriptor
|
| 592 |
+
* Uses lrnN=5, lrnAlpha=1e-4, lrnBeta=0.75, lrnK=2.0 as defaults from Krizhevsky'12 ImageNet paper
|
| 593 |
+
*/
|
| 594 |
+
cudnnStatus_t CUDNNWINAPI
|
| 595 |
+
cudnnCreateLRNDescriptor(cudnnLRNDescriptor_t *normDesc);
|
| 596 |
+
|
| 597 |
+
#define CUDNN_LRN_MIN_N 1 /* minimum allowed lrnN */
|
| 598 |
+
#define CUDNN_LRN_MAX_N 16 /* maximum allowed lrnN */
|
| 599 |
+
#define CUDNN_LRN_MIN_K 1e-5 /* minimum allowed lrnK */
|
| 600 |
+
#define CUDNN_LRN_MIN_BETA 0.01 /* minimum allowed lrnBeta */
|
| 601 |
+
|
| 602 |
+
/* LRN layer mode */
|
| 603 |
+
typedef enum {
|
| 604 |
+
CUDNN_LRN_CROSS_CHANNEL_DIM1 = 0, /* Normalize across tensor's dimA[1] dimension */
|
| 605 |
+
} cudnnLRNMode_t;
|
| 606 |
+
|
| 607 |
+
/*
|
| 608 |
+
* Uses a window [center-lookBehind, center+lookAhead], where
|
| 609 |
+
* lookBehind = floor( (lrnN-1)/2 ), lookAhead = lrnN-lookBehind-1.
|
| 610 |
+
* Values of double parameters cast to tensor data type.
|
| 611 |
+
*/
|
| 612 |
+
cudnnStatus_t CUDNNWINAPI
|
| 613 |
+
cudnnSetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned lrnN, double lrnAlpha, double lrnBeta, double lrnK);
|
| 614 |
+
/*
|
| 615 |
+
* Retrieve the settings currently stored in an LRN layer descriptor
|
| 616 |
+
* Any of the provided pointers can be NULL (no corresponding value will be returned)
|
| 617 |
+
*/
|
| 618 |
+
cudnnStatus_t CUDNNWINAPI
|
| 619 |
+
cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned *lrnN, double *lrnAlpha, double *lrnBeta, double *lrnK);
|
| 620 |
+
|
| 621 |
+
/* Destroy an instance of LRN descriptor */
|
| 622 |
+
cudnnStatus_t CUDNNWINAPI
|
| 623 |
+
cudnnDestroyLRNDescriptor(cudnnLRNDescriptor_t lrnDesc);
|
| 624 |
+
|
| 625 |
+
/* LRN functions: output = alpha * normalize(x) + beta * old_y */
|
| 626 |
+
|
| 627 |
+
/* LRN cross-channel forward computation. Double parameters cast to tensor data type */
|
| 628 |
+
cudnnStatus_t CUDNNWINAPI
|
| 629 |
+
cudnnLRNCrossChannelForward(cudnnHandle_t handle,
|
| 630 |
+
cudnnLRNDescriptor_t normDesc,
|
| 631 |
+
cudnnLRNMode_t lrnMode,
|
| 632 |
+
const void *alpha,
|
| 633 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 634 |
+
const void *x,
|
| 635 |
+
const void *beta,
|
| 636 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 637 |
+
void *y);
|
| 638 |
+
|
| 639 |
+
typedef enum {
|
| 640 |
+
CUDNN_DIVNORM_PRECOMPUTED_MEANS = 0,
|
| 641 |
+
} cudnnDivNormMode_t;
|
| 642 |
+
|
| 643 |
+
/* LCN/divisive normalization functions: y = alpha * normalize(x) + beta * y */
|
| 644 |
+
cudnnStatus_t CUDNNWINAPI
|
| 645 |
+
cudnnDivisiveNormalizationForward(cudnnHandle_t handle,
|
| 646 |
+
cudnnLRNDescriptor_t normDesc,
|
| 647 |
+
cudnnDivNormMode_t mode,
|
| 648 |
+
const void *alpha,
|
| 649 |
+
const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */
|
| 650 |
+
const void *x,
|
| 651 |
+
const void *means, /* if NULL, means are assumed to be zero */
|
| 652 |
+
void *temp,
|
| 653 |
+
void *temp2,
|
| 654 |
+
const void *beta,
|
| 655 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 656 |
+
void *y);
|
| 657 |
+
|
| 658 |
+
typedef enum {
|
| 659 |
+
/* bnScale, bnBias tensor dims are 1xCxHxWx.. (one value per CHW...-slice, normalized over N slice) */
|
| 660 |
+
CUDNN_BATCHNORM_PER_ACTIVATION = 0,
|
| 661 |
+
|
| 662 |
+
/* bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors) */
|
| 663 |
+
CUDNN_BATCHNORM_SPATIAL = 1,
|
| 664 |
+
|
| 665 |
+
/*
|
| 666 |
+
* bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors).
|
| 667 |
+
* May be faster than CUDNN_BATCHNORM_SPATIAL but imposes some limits on the range of values
|
| 668 |
+
*/
|
| 669 |
+
CUDNN_BATCHNORM_SPATIAL_PERSISTENT = 2,
|
| 670 |
+
} cudnnBatchNormMode_t CUDNN_DEPRECATED;
|
| 671 |
+
|
| 672 |
+
#define CUDNN_BN_MIN_EPSILON 0.0 /* Minimum epsilon allowed to be used in the Batch Normalization formula */
|
| 673 |
+
|
| 674 |
+
/*
|
| 675 |
+
* Derives a tensor descriptor from layer data descriptor for BatchNormalization
|
| 676 |
+
* scale, invVariance, bnBias, bnScale tensors. Use this tensor desc for
|
| 677 |
+
* bnScaleBiasMeanVarDesc and bnScaleBiasDiffDesc in Batch Normalization forward and backward functions.
|
| 678 |
+
*/
|
| 679 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 680 |
+
cudnnDeriveBNTensorDescriptor(cudnnTensorDescriptor_t derivedBnDesc,
|
| 681 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 682 |
+
cudnnBatchNormMode_t mode);
|
| 683 |
+
|
| 684 |
+
typedef enum {
|
| 685 |
+
CUDNN_BATCHNORM_OPS_BN = 0, /* do batch normalization only */
|
| 686 |
+
CUDNN_BATCHNORM_OPS_BN_ACTIVATION = 1, /* do batchNorm, then activation */
|
| 687 |
+
CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION = 2, /* do batchNorm, then elemWiseAdd, then activation */
|
| 688 |
+
} cudnnBatchNormOps_t CUDNN_DEPRECATED;
|
| 689 |
+
|
| 690 |
+
/*
|
| 691 |
+
* Performs Batch Normalization during Inference:
|
| 692 |
+
* y[i] = bnScale[k]*(x[i]-estimatedMean[k])/sqrt(epsilon+estimatedVariance[k]) + bnBias[k]
|
| 693 |
+
* with bnScale, bnBias, runningMean, runningInvVariance tensors indexed
|
| 694 |
+
* according to spatial or per-activation mode. Refer to cudnnBatchNormalizationForwardTraining
|
| 695 |
+
* above for notes on function arguments.
|
| 696 |
+
*/
|
| 697 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 698 |
+
cudnnBatchNormalizationForwardInference(cudnnHandle_t handle,
|
| 699 |
+
cudnnBatchNormMode_t mode,
|
| 700 |
+
const void *alpha, /* alpha[0] = result blend factor */
|
| 701 |
+
const void *beta, /* beta[0] = dest layer blend factor */
|
| 702 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 703 |
+
const void *x, /* NxCxHxW */
|
| 704 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 705 |
+
void *y, /* NxCxHxW */
|
| 706 |
+
const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
|
| 707 |
+
const void *bnScale,
|
| 708 |
+
const void *bnBias,
|
| 709 |
+
const void *estimatedMean,
|
| 710 |
+
const void *estimatedVariance,
|
| 711 |
+
double epsilon);
|
| 712 |
+
|
| 713 |
+
typedef enum {
|
| 714 |
+
/* bnScale, bnBias tensor dims are 1xCxHxWx.. (one value per CHW...-slice, normalized over N slice) */
|
| 715 |
+
CUDNN_NORM_PER_ACTIVATION = 0,
|
| 716 |
+
|
| 717 |
+
/* bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors) */
|
| 718 |
+
CUDNN_NORM_PER_CHANNEL = 1,
|
| 719 |
+
} cudnnNormMode_t CUDNN_DEPRECATED;
|
| 720 |
+
|
| 721 |
+
typedef enum { CUDNN_NORM_ALGO_STANDARD = 0, CUDNN_NORM_ALGO_PERSIST = 1 } cudnnNormAlgo_t CUDNN_DEPRECATED;
|
| 722 |
+
|
| 723 |
+
/*
|
| 724 |
+
* Derives a tensor descriptor from layer data descriptor for Normalization
|
| 725 |
+
* scale, invVariance, bnBias, bnScale tensors. Use this tensor desc for
|
| 726 |
+
* normScaleBiasMeanVarDesc and normScaleBiasDiffDesc in Normalization forward and backward functions.
|
| 727 |
+
*/
|
| 728 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 729 |
+
cudnnDeriveNormTensorDescriptor(cudnnTensorDescriptor_t derivedNormScaleBiasDesc,
|
| 730 |
+
cudnnTensorDescriptor_t derivedNormMeanVarDesc,
|
| 731 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 732 |
+
cudnnNormMode_t mode,
|
| 733 |
+
int groupCnt); /* Place hold for future work, should be set to 1 now*/
|
| 734 |
+
|
| 735 |
+
typedef enum {
|
| 736 |
+
CUDNN_NORM_OPS_NORM = 0, /* do normalization only */
|
| 737 |
+
CUDNN_NORM_OPS_NORM_ACTIVATION = 1, /* do Norm, then activation */
|
| 738 |
+
CUDNN_NORM_OPS_NORM_ADD_ACTIVATION = 2, /* do Norm, then elemWiseAdd, then activation */
|
| 739 |
+
} cudnnNormOps_t CUDNN_DEPRECATED;
|
| 740 |
+
|
| 741 |
+
/*
|
| 742 |
+
* Performs Normalization during Inference:
|
| 743 |
+
* y[i] = normScale[k]*(x[i]-estimatedMean[k])/sqrt(epsilon+estimatedVariance[k]) + normBias[k]
|
| 744 |
+
* with normScale, normBias, runningMean, runningInvVariance tensors indexed
|
| 745 |
+
* according to per-channel or per-activation mode. Refer to cudnnNormalizationForwardTraining
|
| 746 |
+
* above for notes on function arguments.
|
| 747 |
+
*/
|
| 748 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 749 |
+
cudnnNormalizationForwardInference(cudnnHandle_t handle,
|
| 750 |
+
cudnnNormMode_t mode,
|
| 751 |
+
cudnnNormOps_t normOps,
|
| 752 |
+
cudnnNormAlgo_t algo,
|
| 753 |
+
const void *alpha, /* alpha[0] = result blend factor */
|
| 754 |
+
const void *beta, /* beta[0] = dest layer blend factor */
|
| 755 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 756 |
+
const void *x, /* NxCxHxW */
|
| 757 |
+
const cudnnTensorDescriptor_t normScaleBiasDesc,
|
| 758 |
+
const void *normScale,
|
| 759 |
+
const void *normBias,
|
| 760 |
+
const cudnnTensorDescriptor_t normMeanVarDesc,
|
| 761 |
+
const void *estimatedMean,
|
| 762 |
+
const void *estimatedVariance,
|
| 763 |
+
const cudnnTensorDescriptor_t zDesc,
|
| 764 |
+
const void *z,
|
| 765 |
+
cudnnActivationDescriptor_t activationDesc,
|
| 766 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 767 |
+
void *y, /* NxCxHxW */
|
| 768 |
+
double epsilon,
|
| 769 |
+
int groupCnt); /* Place hold for future work*/
|
| 770 |
+
|
| 771 |
+
/* APIs for spatial transformer network*/
|
| 772 |
+
typedef enum {
|
| 773 |
+
CUDNN_SAMPLER_BILINEAR = 0,
|
| 774 |
+
} cudnnSamplerType_t;
|
| 775 |
+
|
| 776 |
+
cudnnStatus_t CUDNNWINAPI
|
| 777 |
+
cudnnCreateSpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t *stDesc);
|
| 778 |
+
|
| 779 |
+
cudnnStatus_t CUDNNWINAPI
|
| 780 |
+
cudnnSetSpatialTransformerNdDescriptor(cudnnSpatialTransformerDescriptor_t stDesc,
|
| 781 |
+
cudnnSamplerType_t samplerType,
|
| 782 |
+
cudnnDataType_t dataType,
|
| 783 |
+
const int nbDims,
|
| 784 |
+
const int dimA[]);
|
| 785 |
+
|
| 786 |
+
cudnnStatus_t CUDNNWINAPI
|
| 787 |
+
cudnnDestroySpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t stDesc);
|
| 788 |
+
|
| 789 |
+
cudnnStatus_t CUDNNWINAPI
|
| 790 |
+
cudnnSpatialTfGridGeneratorForward(cudnnHandle_t handle,
|
| 791 |
+
const cudnnSpatialTransformerDescriptor_t stDesc,
|
| 792 |
+
const void *theta,
|
| 793 |
+
void *grid);
|
| 794 |
+
|
| 795 |
+
cudnnStatus_t CUDNNWINAPI
|
| 796 |
+
cudnnSpatialTfSamplerForward(cudnnHandle_t handle,
|
| 797 |
+
cudnnSpatialTransformerDescriptor_t stDesc,
|
| 798 |
+
const void *alpha,
|
| 799 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 800 |
+
const void *x,
|
| 801 |
+
const void *grid,
|
| 802 |
+
const void *beta,
|
| 803 |
+
cudnnTensorDescriptor_t yDesc,
|
| 804 |
+
void *y);
|
| 805 |
+
|
| 806 |
+
typedef struct cudnnDropoutStruct *cudnnDropoutDescriptor_t;
|
| 807 |
+
|
| 808 |
+
cudnnStatus_t CUDNNWINAPI
|
| 809 |
+
cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t *dropoutDesc);
|
| 810 |
+
|
| 811 |
+
cudnnStatus_t CUDNNWINAPI
|
| 812 |
+
cudnnDestroyDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc);
|
| 813 |
+
|
| 814 |
+
/*helper function to determine size of the states to be passed to cudnnSetDropoutDescriptor */
|
| 815 |
+
cudnnStatus_t CUDNNWINAPI
|
| 816 |
+
cudnnDropoutGetStatesSize(cudnnHandle_t handle, size_t *sizeInBytes);
|
| 817 |
+
|
| 818 |
+
/*helper function to determine size of the reserve space to be passed to dropout forward/backward calls */
|
| 819 |
+
cudnnStatus_t CUDNNWINAPI
|
| 820 |
+
cudnnDropoutGetReserveSpaceSize(cudnnTensorDescriptor_t xdesc, size_t *sizeInBytes);
|
| 821 |
+
|
| 822 |
+
cudnnStatus_t CUDNNWINAPI
|
| 823 |
+
cudnnSetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
|
| 824 |
+
cudnnHandle_t handle,
|
| 825 |
+
float dropout,
|
| 826 |
+
void *states,
|
| 827 |
+
size_t stateSizeInBytes,
|
| 828 |
+
unsigned long long seed);
|
| 829 |
+
|
| 830 |
+
/* Restores the dropout descriptor to a previously saved-off state */
|
| 831 |
+
cudnnStatus_t CUDNNWINAPI
|
| 832 |
+
cudnnRestoreDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
|
| 833 |
+
cudnnHandle_t handle,
|
| 834 |
+
float dropout,
|
| 835 |
+
void *states,
|
| 836 |
+
size_t stateSizeInBytes,
|
| 837 |
+
unsigned long long seed);
|
| 838 |
+
|
| 839 |
+
cudnnStatus_t CUDNNWINAPI
|
| 840 |
+
cudnnGetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
|
| 841 |
+
cudnnHandle_t handle,
|
| 842 |
+
float *dropout,
|
| 843 |
+
void **states,
|
| 844 |
+
unsigned long long *seed);
|
| 845 |
+
|
| 846 |
+
cudnnStatus_t CUDNNWINAPI
|
| 847 |
+
cudnnDropoutForward(cudnnHandle_t handle,
|
| 848 |
+
const cudnnDropoutDescriptor_t dropoutDesc,
|
| 849 |
+
const cudnnTensorDescriptor_t xdesc,
|
| 850 |
+
const void *x,
|
| 851 |
+
const cudnnTensorDescriptor_t ydesc,
|
| 852 |
+
void *y,
|
| 853 |
+
void *reserveSpace,
|
| 854 |
+
size_t reserveSpaceSizeInBytes);
|
| 855 |
+
|
| 856 |
+
/* TODO: move these enums out to the appropriate submodule */
|
| 857 |
+
typedef enum {
|
| 858 |
+
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM = 0,
|
| 859 |
+
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM = 1,
|
| 860 |
+
CUDNN_CONVOLUTION_FWD_ALGO_GEMM = 2,
|
| 861 |
+
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT = 3,
|
| 862 |
+
CUDNN_CONVOLUTION_FWD_ALGO_FFT = 4,
|
| 863 |
+
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING = 5,
|
| 864 |
+
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD = 6,
|
| 865 |
+
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED = 7,
|
| 866 |
+
CUDNN_CONVOLUTION_FWD_ALGO_COUNT = 8
|
| 867 |
+
} cudnnConvolutionFwdAlgo_t;
|
| 868 |
+
|
| 869 |
+
typedef enum {
|
| 870 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0 = 0, /* non-deterministic */
|
| 871 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 = 1,
|
| 872 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT = 2,
|
| 873 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3 = 3, /* non-deterministic */
|
| 874 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD = 4, /* not implemented */
|
| 875 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED = 5,
|
| 876 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING = 6,
|
| 877 |
+
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT = 7
|
| 878 |
+
} cudnnConvolutionBwdFilterAlgo_t;
|
| 879 |
+
|
| 880 |
+
typedef enum {
|
| 881 |
+
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 = 0, /* non-deterministic */
|
| 882 |
+
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 = 1,
|
| 883 |
+
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT = 2,
|
| 884 |
+
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING = 3,
|
| 885 |
+
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD = 4,
|
| 886 |
+
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED = 5,
|
| 887 |
+
CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT = 6
|
| 888 |
+
} cudnnConvolutionBwdDataAlgo_t;
|
| 889 |
+
|
| 890 |
+
typedef enum { CUDNN_CTC_LOSS_ALGO_DETERMINISTIC = 0, CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC = 1 } cudnnCTCLossAlgo_t;
|
| 891 |
+
|
| 892 |
+
/*
|
| 893 |
+
* \brief Cross-library version checker.
|
| 894 |
+
* This function is implemented differently in each sub-library. Each sublib
|
| 895 |
+
* checks whether its own version matches that of its dependencies.
|
| 896 |
+
* \returns CUDNN_STATUS_SUCCESS if the version check passes,
|
| 897 |
+
* CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH if the versions are inconsistent.
|
| 898 |
+
*/
|
| 899 |
+
cudnnStatus_t CUDNNWINAPI
|
| 900 |
+
cudnnOpsVersionCheck(void);
|
| 901 |
+
|
| 902 |
+
/* Function to perform backward softmax */
|
| 903 |
+
cudnnStatus_t CUDNNWINAPI
|
| 904 |
+
cudnnSoftmaxBackward(cudnnHandle_t handle,
|
| 905 |
+
cudnnSoftmaxAlgorithm_t algo,
|
| 906 |
+
cudnnSoftmaxMode_t mode,
|
| 907 |
+
const void *alpha,
|
| 908 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 909 |
+
const void *y,
|
| 910 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 911 |
+
const void *dy,
|
| 912 |
+
const void *beta,
|
| 913 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 914 |
+
void *dx);
|
| 915 |
+
|
| 916 |
+
/* Function to perform backward pooling */
|
| 917 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 918 |
+
cudnnPoolingBackward(cudnnHandle_t handle,
|
| 919 |
+
const cudnnPoolingDescriptor_t poolingDesc,
|
| 920 |
+
const void *alpha,
|
| 921 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 922 |
+
const void *y,
|
| 923 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 924 |
+
const void *dy,
|
| 925 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 926 |
+
const void *x,
|
| 927 |
+
const void *beta,
|
| 928 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 929 |
+
void *dx);
|
| 930 |
+
|
| 931 |
+
/* Function to perform backward activation */
|
| 932 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 933 |
+
cudnnActivationBackward(cudnnHandle_t handle,
|
| 934 |
+
cudnnActivationDescriptor_t activationDesc,
|
| 935 |
+
const void *alpha,
|
| 936 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 937 |
+
const void *y,
|
| 938 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 939 |
+
const void *dy,
|
| 940 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 941 |
+
const void *x,
|
| 942 |
+
const void *beta,
|
| 943 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 944 |
+
void *dx);
|
| 945 |
+
|
| 946 |
+
/* LRN cross-channel backward computation. Double parameters cast to tensor data type */
|
| 947 |
+
cudnnStatus_t CUDNNWINAPI
|
| 948 |
+
cudnnLRNCrossChannelBackward(cudnnHandle_t handle,
|
| 949 |
+
cudnnLRNDescriptor_t normDesc,
|
| 950 |
+
cudnnLRNMode_t lrnMode,
|
| 951 |
+
const void *alpha,
|
| 952 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 953 |
+
const void *y,
|
| 954 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 955 |
+
const void *dy,
|
| 956 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 957 |
+
const void *x,
|
| 958 |
+
const void *beta,
|
| 959 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 960 |
+
void *dx);
|
| 961 |
+
|
| 962 |
+
cudnnStatus_t CUDNNWINAPI
|
| 963 |
+
cudnnDivisiveNormalizationBackward(cudnnHandle_t handle,
|
| 964 |
+
cudnnLRNDescriptor_t normDesc,
|
| 965 |
+
cudnnDivNormMode_t mode,
|
| 966 |
+
const void *alpha,
|
| 967 |
+
const cudnnTensorDescriptor_t xDesc, /* same desc for x, means, dy, temp, temp2 */
|
| 968 |
+
const void *x,
|
| 969 |
+
const void *means, /* if NULL, means are assumed to be zero */
|
| 970 |
+
const void *dy,
|
| 971 |
+
void *temp,
|
| 972 |
+
void *temp2,
|
| 973 |
+
const void *beta,
|
| 974 |
+
const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */
|
| 975 |
+
void *dx, /* output x differential */
|
| 976 |
+
void *dMeans); /* output means differential, can be NULL */
|
| 977 |
+
|
| 978 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 979 |
+
cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(cudnnHandle_t handle,
|
| 980 |
+
cudnnBatchNormMode_t mode,
|
| 981 |
+
cudnnBatchNormOps_t bnOps,
|
| 982 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 983 |
+
const cudnnTensorDescriptor_t zDesc,
|
| 984 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 985 |
+
const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
|
| 986 |
+
const cudnnActivationDescriptor_t activationDesc,
|
| 987 |
+
size_t *sizeInBytes);
|
| 988 |
+
|
| 989 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 990 |
+
cudnnGetBatchNormalizationBackwardExWorkspaceSize(cudnnHandle_t handle,
|
| 991 |
+
cudnnBatchNormMode_t mode,
|
| 992 |
+
cudnnBatchNormOps_t bnOps,
|
| 993 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 994 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 995 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 996 |
+
const cudnnTensorDescriptor_t dzDesc,
|
| 997 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 998 |
+
const cudnnTensorDescriptor_t dBnScaleBiasDesc,
|
| 999 |
+
const cudnnActivationDescriptor_t activationDesc,
|
| 1000 |
+
size_t *sizeInBytes);
|
| 1001 |
+
|
| 1002 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1003 |
+
cudnnGetBatchNormalizationTrainingExReserveSpaceSize(cudnnHandle_t handle,
|
| 1004 |
+
cudnnBatchNormMode_t mode,
|
| 1005 |
+
cudnnBatchNormOps_t bnOps,
|
| 1006 |
+
const cudnnActivationDescriptor_t activationDesc,
|
| 1007 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1008 |
+
size_t *sizeInBytes);
|
| 1009 |
+
|
| 1010 |
+
/* Computes y = BN(x). Also accumulates moving averages of mean and inverse variances */
|
| 1011 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1012 |
+
cudnnBatchNormalizationForwardTraining(
|
| 1013 |
+
cudnnHandle_t handle,
|
| 1014 |
+
cudnnBatchNormMode_t mode,
|
| 1015 |
+
|
| 1016 |
+
const void *alpha, /* alpha[0] = result blend factor */
|
| 1017 |
+
const void *beta, /* beta[0] = dest layer blend factor */
|
| 1018 |
+
|
| 1019 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1020 |
+
const void *x, /* NxCxHxW */
|
| 1021 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 1022 |
+
void *y, /* NxCxHxW */
|
| 1023 |
+
|
| 1024 |
+
/* Shared desc for the next 6 tensors in the argument list.
|
| 1025 |
+
Data type to be set as follows:
|
| 1026 |
+
type = (typeOf(x) == double) ? double : float
|
| 1027 |
+
Dimensions for this descriptor depend on normalization mode
|
| 1028 |
+
- Spatial Normalization : tensors are expected to have dims 1xCx1x1
|
| 1029 |
+
(normalization is performed across NxHxW)
|
| 1030 |
+
- Per-Activation Normalization : tensors are expected to have dims of 1xCxHxW
|
| 1031 |
+
(normalization is performed across N) */
|
| 1032 |
+
const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
|
| 1033 |
+
|
| 1034 |
+
/* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation */
|
| 1035 |
+
const void *bnScale,
|
| 1036 |
+
const void *bnBias,
|
| 1037 |
+
|
| 1038 |
+
/* MUST use factor=1 in the very first call of a complete training cycle.
|
| 1039 |
+
Use a factor=1/(1+n) at N-th call to the function to get
|
| 1040 |
+
Cumulative Moving Average (CMA) behavior
|
| 1041 |
+
CMA[n] = (x[1]+...+x[n])/n
|
| 1042 |
+
Since CMA[n+1] = (n*CMA[n]+x[n+1])/(n+1) =
|
| 1043 |
+
((n+1)*CMA[n]-CMA[n])/(n+1) + x[n+1]/(n+1) =
|
| 1044 |
+
CMA[n]*(1-1/(n+1)) + x[n+1]*1/(n+1) */
|
| 1045 |
+
double exponentialAverageFactor,
|
| 1046 |
+
|
| 1047 |
+
/* Used in Training phase only.
|
| 1048 |
+
runningMean = newMean*factor + runningMean*(1-factor) */
|
| 1049 |
+
void *resultRunningMean,
|
| 1050 |
+
/* Output in training mode, input in inference. Is the moving average
|
| 1051 |
+
of variance[x] (factor is applied in the same way as for runningMean) */
|
| 1052 |
+
void *resultRunningVariance,
|
| 1053 |
+
|
| 1054 |
+
/* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */
|
| 1055 |
+
double epsilon,
|
| 1056 |
+
|
| 1057 |
+
/* Optionally save intermediate results from the forward pass here
|
| 1058 |
+
- can be reused to speed up backward pass. NULL if unused */
|
| 1059 |
+
void *resultSaveMean,
|
| 1060 |
+
void *resultSaveInvVariance);
|
| 1061 |
+
|
| 1062 |
+
/* Computes y = relu(BN(x) + z). Also accumulates moving averages of mean and inverse variances */
|
| 1063 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1064 |
+
cudnnBatchNormalizationForwardTrainingEx(
|
| 1065 |
+
cudnnHandle_t handle,
|
| 1066 |
+
cudnnBatchNormMode_t mode,
|
| 1067 |
+
cudnnBatchNormOps_t bnOps,
|
| 1068 |
+
|
| 1069 |
+
const void *alpha, /* alpha[0] = result blend factor */
|
| 1070 |
+
const void *beta, /* beta[0] = dest layer blend factor */
|
| 1071 |
+
|
| 1072 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1073 |
+
const void *xData,
|
| 1074 |
+
const cudnnTensorDescriptor_t zDesc,
|
| 1075 |
+
const void *zData,
|
| 1076 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 1077 |
+
void *yData,
|
| 1078 |
+
|
| 1079 |
+
const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
|
| 1080 |
+
const void *bnScale,
|
| 1081 |
+
const void *bnBias,
|
| 1082 |
+
|
| 1083 |
+
double exponentialAverageFactor,
|
| 1084 |
+
void *resultRunningMean,
|
| 1085 |
+
void *resultRunningVariance,
|
| 1086 |
+
|
| 1087 |
+
/* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */
|
| 1088 |
+
double epsilon,
|
| 1089 |
+
|
| 1090 |
+
/* Optionally save intermediate results from the forward pass here
|
| 1091 |
+
- can be reused to speed up backward pass. NULL if unused */
|
| 1092 |
+
void *resultSaveMean,
|
| 1093 |
+
void *resultSaveInvVariance,
|
| 1094 |
+
|
| 1095 |
+
cudnnActivationDescriptor_t activationDesc,
|
| 1096 |
+
void *workspace,
|
| 1097 |
+
size_t workSpaceSizeInBytes,
|
| 1098 |
+
void *reserveSpace,
|
| 1099 |
+
size_t reserveSpaceSizeInBytes);
|
| 1100 |
+
|
| 1101 |
+
/* Performs backward pass of Batch Normalization layer. Returns x gradient,
|
| 1102 |
+
* bnScale gradient and bnBias gradient */
|
| 1103 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1104 |
+
cudnnBatchNormalizationBackward(cudnnHandle_t handle,
|
| 1105 |
+
cudnnBatchNormMode_t mode,
|
| 1106 |
+
const void *alphaDataDiff,
|
| 1107 |
+
const void *betaDataDiff,
|
| 1108 |
+
const void *alphaParamDiff,
|
| 1109 |
+
const void *betaParamDiff,
|
| 1110 |
+
const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */
|
| 1111 |
+
const void *x,
|
| 1112 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 1113 |
+
const void *dy,
|
| 1114 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 1115 |
+
void *dx,
|
| 1116 |
+
/* Shared tensor desc for the 4 tensors below */
|
| 1117 |
+
const cudnnTensorDescriptor_t dBnScaleBiasDesc,
|
| 1118 |
+
const void *bnScale, /* bnBias doesn't affect backpropagation */
|
| 1119 |
+
/* scale and bias diff are not backpropagated below this layer */
|
| 1120 |
+
void *dBnScaleResult,
|
| 1121 |
+
void *dBnBiasResult,
|
| 1122 |
+
/* Same epsilon as forward pass */
|
| 1123 |
+
double epsilon,
|
| 1124 |
+
|
| 1125 |
+
/* Optionally cached intermediate results from
|
| 1126 |
+
forward pass */
|
| 1127 |
+
const void *savedMean,
|
| 1128 |
+
const void *savedInvVariance);
|
| 1129 |
+
|
| 1130 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1131 |
+
cudnnBatchNormalizationBackwardEx(cudnnHandle_t handle,
|
| 1132 |
+
cudnnBatchNormMode_t mode,
|
| 1133 |
+
cudnnBatchNormOps_t bnOps,
|
| 1134 |
+
|
| 1135 |
+
const void *alphaDataDiff,
|
| 1136 |
+
const void *betaDataDiff,
|
| 1137 |
+
const void *alphaParamDiff,
|
| 1138 |
+
const void *betaParamDiff,
|
| 1139 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1140 |
+
const void *xData,
|
| 1141 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 1142 |
+
const void *yData,
|
| 1143 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 1144 |
+
const void *dyData,
|
| 1145 |
+
const cudnnTensorDescriptor_t dzDesc,
|
| 1146 |
+
void *dzData,
|
| 1147 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 1148 |
+
void *dxData,
|
| 1149 |
+
|
| 1150 |
+
/* Shared tensor desc for the 4 tensors below */
|
| 1151 |
+
const cudnnTensorDescriptor_t dBnScaleBiasDesc,
|
| 1152 |
+
const void *bnScaleData,
|
| 1153 |
+
const void *bnBiasData, /* needed if there is activation */
|
| 1154 |
+
void *dBnScaleData,
|
| 1155 |
+
void *dBnBiasData,
|
| 1156 |
+
double epsilon, /* Same epsilon as forward pass */
|
| 1157 |
+
|
| 1158 |
+
/* Optionally cached intermediate results from
|
| 1159 |
+
forward pass */
|
| 1160 |
+
const void *savedMean,
|
| 1161 |
+
const void *savedInvVariance,
|
| 1162 |
+
cudnnActivationDescriptor_t activationDesc,
|
| 1163 |
+
void *workSpace,
|
| 1164 |
+
size_t workSpaceSizeInBytes,
|
| 1165 |
+
void *reserveSpace,
|
| 1166 |
+
size_t reserveSpaceSizeInBytes);
|
| 1167 |
+
|
| 1168 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1169 |
+
cudnnGetNormalizationForwardTrainingWorkspaceSize(cudnnHandle_t handle,
|
| 1170 |
+
cudnnNormMode_t mode,
|
| 1171 |
+
cudnnNormOps_t normOps,
|
| 1172 |
+
cudnnNormAlgo_t algo,
|
| 1173 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1174 |
+
const cudnnTensorDescriptor_t zDesc,
|
| 1175 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 1176 |
+
const cudnnTensorDescriptor_t normScaleBiasDesc,
|
| 1177 |
+
const cudnnActivationDescriptor_t activationDesc,
|
| 1178 |
+
const cudnnTensorDescriptor_t normMeanVarDesc,
|
| 1179 |
+
size_t *sizeInBytes,
|
| 1180 |
+
int groupCnt); /* Place hold for future work, should be set to 1 now*/
|
| 1181 |
+
|
| 1182 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1183 |
+
cudnnGetNormalizationBackwardWorkspaceSize(cudnnHandle_t handle,
|
| 1184 |
+
cudnnNormMode_t mode,
|
| 1185 |
+
cudnnNormOps_t normOps,
|
| 1186 |
+
cudnnNormAlgo_t algo,
|
| 1187 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1188 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 1189 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 1190 |
+
const cudnnTensorDescriptor_t dzDesc,
|
| 1191 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 1192 |
+
const cudnnTensorDescriptor_t dNormScaleBiasDesc,
|
| 1193 |
+
const cudnnActivationDescriptor_t activationDesc,
|
| 1194 |
+
const cudnnTensorDescriptor_t normMeanVarDesc,
|
| 1195 |
+
size_t *sizeInBytes,
|
| 1196 |
+
int groupCnt); /* Place hold for future work, should be set to 1 now*/
|
| 1197 |
+
|
| 1198 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1199 |
+
cudnnGetNormalizationTrainingReserveSpaceSize(cudnnHandle_t handle,
|
| 1200 |
+
cudnnNormMode_t mode,
|
| 1201 |
+
cudnnNormOps_t normOps,
|
| 1202 |
+
cudnnNormAlgo_t algo,
|
| 1203 |
+
const cudnnActivationDescriptor_t activationDesc,
|
| 1204 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1205 |
+
size_t *sizeInBytes,
|
| 1206 |
+
int groupCnt); /* Place hold for future work, should be set to 1 now*/
|
| 1207 |
+
|
| 1208 |
+
/* Computes y = relu(Norm(x) + z). Also accumulates moving averages of mean and inverse variances */
|
| 1209 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1210 |
+
cudnnNormalizationForwardTraining(cudnnHandle_t handle,
|
| 1211 |
+
cudnnNormMode_t mode,
|
| 1212 |
+
cudnnNormOps_t normOps,
|
| 1213 |
+
cudnnNormAlgo_t algo,
|
| 1214 |
+
const void *alpha, /* alpha[0] = result blend factor */
|
| 1215 |
+
const void *beta, /* beta[0] = dest layer blend factor */
|
| 1216 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1217 |
+
const void *xData,
|
| 1218 |
+
const cudnnTensorDescriptor_t normScaleBiasDesc,
|
| 1219 |
+
const void *normScale,
|
| 1220 |
+
const void *normBias,
|
| 1221 |
+
double exponentialAverageFactor,
|
| 1222 |
+
const cudnnTensorDescriptor_t normMeanVarDesc,
|
| 1223 |
+
void *resultRunningMean,
|
| 1224 |
+
void *resultRunningVariance,
|
| 1225 |
+
/* Has to be >= 0. Should be the same in forward and backward functions. */
|
| 1226 |
+
double epsilon,
|
| 1227 |
+
/* Optionally save intermediate results from the forward pass here
|
| 1228 |
+
- can be reused to speed up backward pass. NULL if unused */
|
| 1229 |
+
void *resultSaveMean,
|
| 1230 |
+
void *resultSaveInvVariance,
|
| 1231 |
+
cudnnActivationDescriptor_t activationDesc,
|
| 1232 |
+
const cudnnTensorDescriptor_t zDesc,
|
| 1233 |
+
const void *zData,
|
| 1234 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 1235 |
+
void *yData,
|
| 1236 |
+
void *workspace,
|
| 1237 |
+
size_t workSpaceSizeInBytes,
|
| 1238 |
+
void *reserveSpace,
|
| 1239 |
+
size_t reserveSpaceSizeInBytes,
|
| 1240 |
+
int groupCnt); /* Place hold for future work, should be set to 1 now*/
|
| 1241 |
+
|
| 1242 |
+
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
|
| 1243 |
+
cudnnNormalizationBackward(cudnnHandle_t handle,
|
| 1244 |
+
cudnnNormMode_t mode,
|
| 1245 |
+
cudnnNormOps_t normOps,
|
| 1246 |
+
cudnnNormAlgo_t algo,
|
| 1247 |
+
const void *alphaDataDiff,
|
| 1248 |
+
const void *betaDataDiff,
|
| 1249 |
+
const void *alphaParamDiff,
|
| 1250 |
+
const void *betaParamDiff,
|
| 1251 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1252 |
+
const void *xData,
|
| 1253 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 1254 |
+
const void *yData,
|
| 1255 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 1256 |
+
const void *dyData,
|
| 1257 |
+
const cudnnTensorDescriptor_t dzDesc,
|
| 1258 |
+
void *dzData,
|
| 1259 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 1260 |
+
void *dxData,
|
| 1261 |
+
/* Shared tensor desc for the 4 tensors below */
|
| 1262 |
+
const cudnnTensorDescriptor_t dNormScaleBiasDesc,
|
| 1263 |
+
const void *normScaleData,
|
| 1264 |
+
const void *normBiasData, /* needed if there is activation */
|
| 1265 |
+
void *dNormScaleData,
|
| 1266 |
+
void *dNormBiasData,
|
| 1267 |
+
double epsilon, /* Same epsilon as forward pass */
|
| 1268 |
+
const cudnnTensorDescriptor_t normMeanVarDesc,
|
| 1269 |
+
/* Optionally cached intermediate results from
|
| 1270 |
+
forward pass */
|
| 1271 |
+
const void *savedMean,
|
| 1272 |
+
const void *savedInvVariance,
|
| 1273 |
+
cudnnActivationDescriptor_t activationDesc,
|
| 1274 |
+
void *workSpace,
|
| 1275 |
+
size_t workSpaceSizeInBytes,
|
| 1276 |
+
void *reserveSpace,
|
| 1277 |
+
size_t reserveSpaceSizeInBytes,
|
| 1278 |
+
int groupCnt); /* Place hold for future work, should be set to 1 now*/
|
| 1279 |
+
|
| 1280 |
+
cudnnStatus_t CUDNNWINAPI
|
| 1281 |
+
cudnnSpatialTfGridGeneratorBackward(cudnnHandle_t handle,
|
| 1282 |
+
const cudnnSpatialTransformerDescriptor_t stDesc,
|
| 1283 |
+
const void *dgrid,
|
| 1284 |
+
void *dtheta);
|
| 1285 |
+
|
| 1286 |
+
cudnnStatus_t CUDNNWINAPI
|
| 1287 |
+
cudnnSpatialTfSamplerBackward(cudnnHandle_t handle,
|
| 1288 |
+
cudnnSpatialTransformerDescriptor_t stDesc,
|
| 1289 |
+
const void *alpha,
|
| 1290 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 1291 |
+
const void *x,
|
| 1292 |
+
const void *beta,
|
| 1293 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 1294 |
+
void *dx,
|
| 1295 |
+
const void *alphaDgrid,
|
| 1296 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 1297 |
+
const void *dy,
|
| 1298 |
+
const void *grid,
|
| 1299 |
+
const void *betaDgrid,
|
| 1300 |
+
void *dgrid);
|
| 1301 |
+
|
| 1302 |
+
cudnnStatus_t CUDNNWINAPI
|
| 1303 |
+
cudnnDropoutBackward(cudnnHandle_t handle,
|
| 1304 |
+
const cudnnDropoutDescriptor_t dropoutDesc,
|
| 1305 |
+
const cudnnTensorDescriptor_t dydesc,
|
| 1306 |
+
const void *dy,
|
| 1307 |
+
const cudnnTensorDescriptor_t dxdesc,
|
| 1308 |
+
void *dx,
|
| 1309 |
+
void *reserveSpace,
|
| 1310 |
+
size_t reserveSpaceSizeInBytes);
|
| 1311 |
+
|
| 1312 |
+
#if defined(__cplusplus)
|
| 1313 |
+
}
|
| 1314 |
+
#endif
|
| 1315 |
+
|
| 1316 |
+
#endif /* CUDNN_OPS_H_ */
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_v9.h
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/* cudnn : Neural Networks Library */
|
| 51 |
+
|
| 52 |
+
#if !defined(CUDNN_H_)
|
| 53 |
+
#define CUDNN_H_
|
| 54 |
+
#if defined(__cplusplus)
|
| 55 |
+
extern "C" {
|
| 56 |
+
#endif
|
| 57 |
+
|
| 58 |
+
#include <cuda_runtime_api.h>
|
| 59 |
+
#include "cudnn_version.h"
|
| 60 |
+
#include "cudnn_graph.h"
|
| 61 |
+
#include "cudnn_ops.h"
|
| 62 |
+
#include "cudnn_adv.h"
|
| 63 |
+
#include "cudnn_cnn.h"
|
| 64 |
+
|
| 65 |
+
#if defined(__cplusplus)
|
| 66 |
+
}
|
| 67 |
+
#endif
|
| 68 |
+
#endif /* CUDNN_H_ */
|
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_version.h
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/**
|
| 51 |
+
* \file: The master cuDNN version file.
|
| 52 |
+
*/
|
| 53 |
+
|
| 54 |
+
#ifndef CUDNN_VERSION_H_
|
| 55 |
+
#define CUDNN_VERSION_H_
|
| 56 |
+
|
| 57 |
+
#define CUDNN_MAJOR 9
|
| 58 |
+
#define CUDNN_MINOR 1
|
| 59 |
+
#define CUDNN_PATCHLEVEL 0
|
| 60 |
+
|
| 61 |
+
#define CUDNN_VERSION (CUDNN_MAJOR * 10000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
|
| 62 |
+
|
| 63 |
+
/* cannot use constexpr here since this is a C-only file */
|
| 64 |
+
/* Below is the max SM version this cuDNN library is aware of and supports natively */
|
| 65 |
+
|
| 66 |
+
#define CUDNN_MAX_SM_MAJOR_NUMBER 9
|
| 67 |
+
#define CUDNN_MAX_SM_MINOR_NUMBER 0
|
| 68 |
+
#define CUDNN_MAX_DEVICE_VERSION (CUDNN_MAX_SM_MAJOR_NUMBER * 100 + CUDNN_MAX_SM_MINOR_NUMBER * 10)
|
| 69 |
+
|
| 70 |
+
#endif /* CUDNN_VERSION_H */
|
.venv/lib/python3.11/site-packages/opencv_python_headless.libs/libavutil-734d06dd.so.57.28.100
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b207c6d08f72e7bc047e3d17fac1a9f23c61059ffd30cba8040489e8ad79ae33
|
| 3 |
+
size 844673
|
.venv/lib/python3.11/site-packages/opencv_python_headless.libs/libssl-28bef1ac.so.1.1
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cf90306b659880ae21114a928d1a3421c66715f16e49a824fb9ee55113d2b767
|
| 3 |
+
size 736177
|
.venv/lib/python3.11/site-packages/pyasn1/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (199 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/pyasn1/__pycache__/debug.cpython-311.pyc
ADDED
|
Binary file (6.88 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/pyasn1/codec/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# This file is necessary to make this directory a package.
|
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# This file is necessary to make this directory a package.
|
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (189 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/decoder.cpython-311.pyc
ADDED
|
Binary file (79.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/encoder.cpython-311.pyc
ADDED
|
Binary file (34.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/eoo.cpython-311.pyc
ADDED
|
Binary file (1.15 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/decoder.py
ADDED
|
@@ -0,0 +1,2189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# This file is part of pyasn1 software.
|
| 3 |
+
#
|
| 4 |
+
# Copyright (c) 2005-2020, Ilya Etingof <etingof@gmail.com>
|
| 5 |
+
# License: https://pyasn1.readthedocs.io/en/latest/license.html
|
| 6 |
+
#
|
| 7 |
+
import io
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import warnings
|
| 11 |
+
|
| 12 |
+
from pyasn1 import debug
|
| 13 |
+
from pyasn1 import error
|
| 14 |
+
from pyasn1.codec.ber import eoo
|
| 15 |
+
from pyasn1.codec.streaming import asSeekableStream
|
| 16 |
+
from pyasn1.codec.streaming import isEndOfStream
|
| 17 |
+
from pyasn1.codec.streaming import peekIntoStream
|
| 18 |
+
from pyasn1.codec.streaming import readFromStream
|
| 19 |
+
from pyasn1.compat import _MISSING
|
| 20 |
+
from pyasn1.error import PyAsn1Error
|
| 21 |
+
from pyasn1.type import base
|
| 22 |
+
from pyasn1.type import char
|
| 23 |
+
from pyasn1.type import tag
|
| 24 |
+
from pyasn1.type import tagmap
|
| 25 |
+
from pyasn1.type import univ
|
| 26 |
+
from pyasn1.type import useful
|
| 27 |
+
|
| 28 |
+
__all__ = ['StreamingDecoder', 'Decoder', 'decode']
|
| 29 |
+
|
| 30 |
+
LOG = debug.registerLoggee(__name__, flags=debug.DEBUG_DECODER)
|
| 31 |
+
|
| 32 |
+
noValue = base.noValue
|
| 33 |
+
|
| 34 |
+
SubstrateUnderrunError = error.SubstrateUnderrunError
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class AbstractPayloadDecoder(object):
|
| 38 |
+
protoComponent = None
|
| 39 |
+
|
| 40 |
+
def valueDecoder(self, substrate, asn1Spec,
|
| 41 |
+
tagSet=None, length=None, state=None,
|
| 42 |
+
decodeFun=None, substrateFun=None,
|
| 43 |
+
**options):
|
| 44 |
+
"""Decode value with fixed byte length.
|
| 45 |
+
|
| 46 |
+
The decoder is allowed to consume as many bytes as necessary.
|
| 47 |
+
"""
|
| 48 |
+
raise error.PyAsn1Error('SingleItemDecoder not implemented for %s' % (tagSet,)) # TODO: Seems more like an NotImplementedError?
|
| 49 |
+
|
| 50 |
+
def indefLenValueDecoder(self, substrate, asn1Spec,
|
| 51 |
+
tagSet=None, length=None, state=None,
|
| 52 |
+
decodeFun=None, substrateFun=None,
|
| 53 |
+
**options):
|
| 54 |
+
"""Decode value with undefined length.
|
| 55 |
+
|
| 56 |
+
The decoder is allowed to consume as many bytes as necessary.
|
| 57 |
+
"""
|
| 58 |
+
raise error.PyAsn1Error('Indefinite length mode decoder not implemented for %s' % (tagSet,)) # TODO: Seems more like an NotImplementedError?
|
| 59 |
+
|
| 60 |
+
@staticmethod
|
| 61 |
+
def _passAsn1Object(asn1Object, options):
|
| 62 |
+
if 'asn1Object' not in options:
|
| 63 |
+
options['asn1Object'] = asn1Object
|
| 64 |
+
|
| 65 |
+
return options
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class AbstractSimplePayloadDecoder(AbstractPayloadDecoder):
|
| 69 |
+
@staticmethod
|
| 70 |
+
def substrateCollector(asn1Object, substrate, length, options):
|
| 71 |
+
for chunk in readFromStream(substrate, length, options):
|
| 72 |
+
yield chunk
|
| 73 |
+
|
| 74 |
+
def _createComponent(self, asn1Spec, tagSet, value, **options):
|
| 75 |
+
if options.get('native'):
|
| 76 |
+
return value
|
| 77 |
+
elif asn1Spec is None:
|
| 78 |
+
return self.protoComponent.clone(value, tagSet=tagSet)
|
| 79 |
+
elif value is noValue:
|
| 80 |
+
return asn1Spec
|
| 81 |
+
else:
|
| 82 |
+
return asn1Spec.clone(value)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class RawPayloadDecoder(AbstractSimplePayloadDecoder):
|
| 86 |
+
protoComponent = univ.Any('')
|
| 87 |
+
|
| 88 |
+
def valueDecoder(self, substrate, asn1Spec,
|
| 89 |
+
tagSet=None, length=None, state=None,
|
| 90 |
+
decodeFun=None, substrateFun=None,
|
| 91 |
+
**options):
|
| 92 |
+
if substrateFun:
|
| 93 |
+
asn1Object = self._createComponent(asn1Spec, tagSet, '', **options)
|
| 94 |
+
|
| 95 |
+
for chunk in substrateFun(asn1Object, substrate, length, options):
|
| 96 |
+
yield chunk
|
| 97 |
+
|
| 98 |
+
return
|
| 99 |
+
|
| 100 |
+
for value in decodeFun(substrate, asn1Spec, tagSet, length, **options):
|
| 101 |
+
yield value
|
| 102 |
+
|
| 103 |
+
def indefLenValueDecoder(self, substrate, asn1Spec,
|
| 104 |
+
tagSet=None, length=None, state=None,
|
| 105 |
+
decodeFun=None, substrateFun=None,
|
| 106 |
+
**options):
|
| 107 |
+
if substrateFun:
|
| 108 |
+
asn1Object = self._createComponent(asn1Spec, tagSet, '', **options)
|
| 109 |
+
|
| 110 |
+
for chunk in substrateFun(asn1Object, substrate, length, options):
|
| 111 |
+
yield chunk
|
| 112 |
+
|
| 113 |
+
return
|
| 114 |
+
|
| 115 |
+
while True:
|
| 116 |
+
for value in decodeFun(
|
| 117 |
+
substrate, asn1Spec, tagSet, length,
|
| 118 |
+
allowEoo=True, **options):
|
| 119 |
+
|
| 120 |
+
if value is eoo.endOfOctets:
|
| 121 |
+
return
|
| 122 |
+
|
| 123 |
+
yield value
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
rawPayloadDecoder = RawPayloadDecoder()
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class IntegerPayloadDecoder(AbstractSimplePayloadDecoder):
|
| 130 |
+
protoComponent = univ.Integer(0)
|
| 131 |
+
|
| 132 |
+
def valueDecoder(self, substrate, asn1Spec,
|
| 133 |
+
tagSet=None, length=None, state=None,
|
| 134 |
+
decodeFun=None, substrateFun=None,
|
| 135 |
+
**options):
|
| 136 |
+
|
| 137 |
+
if tagSet[0].tagFormat != tag.tagFormatSimple:
|
| 138 |
+
raise error.PyAsn1Error('Simple tag format expected')
|
| 139 |
+
|
| 140 |
+
for chunk in readFromStream(substrate, length, options):
|
| 141 |
+
if isinstance(chunk, SubstrateUnderrunError):
|
| 142 |
+
yield chunk
|
| 143 |
+
|
| 144 |
+
if chunk:
|
| 145 |
+
value = int.from_bytes(bytes(chunk), 'big', signed=True)
|
| 146 |
+
|
| 147 |
+
else:
|
| 148 |
+
value = 0
|
| 149 |
+
|
| 150 |
+
yield self._createComponent(asn1Spec, tagSet, value, **options)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class BooleanPayloadDecoder(IntegerPayloadDecoder):
|
| 154 |
+
protoComponent = univ.Boolean(0)
|
| 155 |
+
|
| 156 |
+
def _createComponent(self, asn1Spec, tagSet, value, **options):
|
| 157 |
+
return IntegerPayloadDecoder._createComponent(
|
| 158 |
+
self, asn1Spec, tagSet, value and 1 or 0, **options)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class BitStringPayloadDecoder(AbstractSimplePayloadDecoder):
|
| 162 |
+
protoComponent = univ.BitString(())
|
| 163 |
+
supportConstructedForm = True
|
| 164 |
+
|
| 165 |
+
def valueDecoder(self, substrate, asn1Spec,
|
| 166 |
+
tagSet=None, length=None, state=None,
|
| 167 |
+
decodeFun=None, substrateFun=None,
|
| 168 |
+
**options):
|
| 169 |
+
|
| 170 |
+
if substrateFun:
|
| 171 |
+
asn1Object = self._createComponent(asn1Spec, tagSet, noValue, **options)
|
| 172 |
+
|
| 173 |
+
for chunk in substrateFun(asn1Object, substrate, length, options):
|
| 174 |
+
yield chunk
|
| 175 |
+
|
| 176 |
+
return
|
| 177 |
+
|
| 178 |
+
if not length:
|
| 179 |
+
raise error.PyAsn1Error('Empty BIT STRING substrate')
|
| 180 |
+
|
| 181 |
+
for chunk in isEndOfStream(substrate):
|
| 182 |
+
if isinstance(chunk, SubstrateUnderrunError):
|
| 183 |
+
yield chunk
|
| 184 |
+
|
| 185 |
+
if chunk:
|
| 186 |
+
raise error.PyAsn1Error('Empty BIT STRING substrate')
|
| 187 |
+
|
| 188 |
+
if tagSet[0].tagFormat == tag.tagFormatSimple: # XXX what tag to check?
|
| 189 |
+
|
| 190 |
+
for trailingBits in readFromStream(substrate, 1, options):
|
| 191 |
+
if isinstance(trailingBits, SubstrateUnderrunError):
|
| 192 |
+
yield trailingBits
|
| 193 |
+
|
| 194 |
+
trailingBits = ord(trailingBits)
|
| 195 |
+
if trailingBits > 7:
|
| 196 |
+
raise error.PyAsn1Error(
|
| 197 |
+
'Trailing bits overflow %s' % trailingBits
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
for chunk in readFromStream(substrate, length - 1, options):
|
| 201 |
+
if isinstance(chunk, SubstrateUnderrunError):
|
| 202 |
+
yield chunk
|
| 203 |
+
|
| 204 |
+
value = self.protoComponent.fromOctetString(
|
| 205 |
+
chunk, internalFormat=True, padding=trailingBits)
|
| 206 |
+
|
| 207 |
+
yield self._createComponent(asn1Spec, tagSet, value, **options)
|
| 208 |
+
|
| 209 |
+
return
|
| 210 |
+
|
| 211 |
+
if not self.supportConstructedForm:
|
| 212 |
+
raise error.PyAsn1Error('Constructed encoding form prohibited '
|
| 213 |
+
'at %s' % self.__class__.__name__)
|
| 214 |
+
|
| 215 |
+
if LOG:
|
| 216 |
+
LOG('assembling constructed serialization')
|
| 217 |
+
|
| 218 |
+
# All inner fragments are of the same type, treat them as octet string
|
| 219 |
+
substrateFun = self.substrateCollector
|
| 220 |
+
|
| 221 |
+
bitString = self.protoComponent.fromOctetString(b'', internalFormat=True)
|
| 222 |
+
|
| 223 |
+
current_position = substrate.tell()
|
| 224 |
+
|
| 225 |
+
while substrate.tell() - current_position < length:
|
| 226 |
+
for component in decodeFun(
|
| 227 |
+
substrate, self.protoComponent, substrateFun=substrateFun,
|
| 228 |
+
**options):
|
| 229 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 230 |
+
yield component
|
| 231 |
+
|
| 232 |
+
trailingBits = component[0]
|
| 233 |
+
if trailingBits > 7:
|
| 234 |
+
raise error.PyAsn1Error(
|
| 235 |
+
'Trailing bits overflow %s' % trailingBits
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
bitString = self.protoComponent.fromOctetString(
|
| 239 |
+
component[1:], internalFormat=True,
|
| 240 |
+
prepend=bitString, padding=trailingBits
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
yield self._createComponent(asn1Spec, tagSet, bitString, **options)
|
| 244 |
+
|
| 245 |
+
def indefLenValueDecoder(self, substrate, asn1Spec,
|
| 246 |
+
tagSet=None, length=None, state=None,
|
| 247 |
+
decodeFun=None, substrateFun=None,
|
| 248 |
+
**options):
|
| 249 |
+
|
| 250 |
+
if substrateFun:
|
| 251 |
+
asn1Object = self._createComponent(asn1Spec, tagSet, noValue, **options)
|
| 252 |
+
|
| 253 |
+
for chunk in substrateFun(asn1Object, substrate, length, options):
|
| 254 |
+
yield chunk
|
| 255 |
+
|
| 256 |
+
return
|
| 257 |
+
|
| 258 |
+
# All inner fragments are of the same type, treat them as octet string
|
| 259 |
+
substrateFun = self.substrateCollector
|
| 260 |
+
|
| 261 |
+
bitString = self.protoComponent.fromOctetString(b'', internalFormat=True)
|
| 262 |
+
|
| 263 |
+
while True: # loop over fragments
|
| 264 |
+
|
| 265 |
+
for component in decodeFun(
|
| 266 |
+
substrate, self.protoComponent, substrateFun=substrateFun,
|
| 267 |
+
allowEoo=True, **options):
|
| 268 |
+
|
| 269 |
+
if component is eoo.endOfOctets:
|
| 270 |
+
break
|
| 271 |
+
|
| 272 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 273 |
+
yield component
|
| 274 |
+
|
| 275 |
+
if component is eoo.endOfOctets:
|
| 276 |
+
break
|
| 277 |
+
|
| 278 |
+
trailingBits = component[0]
|
| 279 |
+
if trailingBits > 7:
|
| 280 |
+
raise error.PyAsn1Error(
|
| 281 |
+
'Trailing bits overflow %s' % trailingBits
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
bitString = self.protoComponent.fromOctetString(
|
| 285 |
+
component[1:], internalFormat=True,
|
| 286 |
+
prepend=bitString, padding=trailingBits
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
yield self._createComponent(asn1Spec, tagSet, bitString, **options)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class OctetStringPayloadDecoder(AbstractSimplePayloadDecoder):
|
| 293 |
+
protoComponent = univ.OctetString('')
|
| 294 |
+
supportConstructedForm = True
|
| 295 |
+
|
| 296 |
+
def valueDecoder(self, substrate, asn1Spec,
|
| 297 |
+
tagSet=None, length=None, state=None,
|
| 298 |
+
decodeFun=None, substrateFun=None,
|
| 299 |
+
**options):
|
| 300 |
+
if substrateFun:
|
| 301 |
+
asn1Object = self._createComponent(asn1Spec, tagSet, noValue, **options)
|
| 302 |
+
|
| 303 |
+
for chunk in substrateFun(asn1Object, substrate, length, options):
|
| 304 |
+
yield chunk
|
| 305 |
+
|
| 306 |
+
return
|
| 307 |
+
|
| 308 |
+
if tagSet[0].tagFormat == tag.tagFormatSimple: # XXX what tag to check?
|
| 309 |
+
for chunk in readFromStream(substrate, length, options):
|
| 310 |
+
if isinstance(chunk, SubstrateUnderrunError):
|
| 311 |
+
yield chunk
|
| 312 |
+
|
| 313 |
+
yield self._createComponent(asn1Spec, tagSet, chunk, **options)
|
| 314 |
+
|
| 315 |
+
return
|
| 316 |
+
|
| 317 |
+
if not self.supportConstructedForm:
|
| 318 |
+
raise error.PyAsn1Error('Constructed encoding form prohibited at %s' % self.__class__.__name__)
|
| 319 |
+
|
| 320 |
+
if LOG:
|
| 321 |
+
LOG('assembling constructed serialization')
|
| 322 |
+
|
| 323 |
+
# All inner fragments are of the same type, treat them as octet string
|
| 324 |
+
substrateFun = self.substrateCollector
|
| 325 |
+
|
| 326 |
+
header = b''
|
| 327 |
+
|
| 328 |
+
original_position = substrate.tell()
|
| 329 |
+
# head = popSubstream(substrate, length)
|
| 330 |
+
while substrate.tell() - original_position < length:
|
| 331 |
+
for component in decodeFun(
|
| 332 |
+
substrate, self.protoComponent, substrateFun=substrateFun,
|
| 333 |
+
**options):
|
| 334 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 335 |
+
yield component
|
| 336 |
+
|
| 337 |
+
header += component
|
| 338 |
+
|
| 339 |
+
yield self._createComponent(asn1Spec, tagSet, header, **options)
|
| 340 |
+
|
| 341 |
+
def indefLenValueDecoder(self, substrate, asn1Spec,
|
| 342 |
+
tagSet=None, length=None, state=None,
|
| 343 |
+
decodeFun=None, substrateFun=None,
|
| 344 |
+
**options):
|
| 345 |
+
if substrateFun and substrateFun is not self.substrateCollector:
|
| 346 |
+
asn1Object = self._createComponent(asn1Spec, tagSet, noValue, **options)
|
| 347 |
+
|
| 348 |
+
for chunk in substrateFun(asn1Object, substrate, length, options):
|
| 349 |
+
yield chunk
|
| 350 |
+
|
| 351 |
+
return
|
| 352 |
+
|
| 353 |
+
# All inner fragments are of the same type, treat them as octet string
|
| 354 |
+
substrateFun = self.substrateCollector
|
| 355 |
+
|
| 356 |
+
header = b''
|
| 357 |
+
|
| 358 |
+
while True: # loop over fragments
|
| 359 |
+
|
| 360 |
+
for component in decodeFun(
|
| 361 |
+
substrate, self.protoComponent, substrateFun=substrateFun,
|
| 362 |
+
allowEoo=True, **options):
|
| 363 |
+
|
| 364 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 365 |
+
yield component
|
| 366 |
+
|
| 367 |
+
if component is eoo.endOfOctets:
|
| 368 |
+
break
|
| 369 |
+
|
| 370 |
+
if component is eoo.endOfOctets:
|
| 371 |
+
break
|
| 372 |
+
|
| 373 |
+
header += component
|
| 374 |
+
|
| 375 |
+
yield self._createComponent(asn1Spec, tagSet, header, **options)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class NullPayloadDecoder(AbstractSimplePayloadDecoder):
|
| 379 |
+
protoComponent = univ.Null('')
|
| 380 |
+
|
| 381 |
+
def valueDecoder(self, substrate, asn1Spec,
|
| 382 |
+
tagSet=None, length=None, state=None,
|
| 383 |
+
decodeFun=None, substrateFun=None,
|
| 384 |
+
**options):
|
| 385 |
+
|
| 386 |
+
if tagSet[0].tagFormat != tag.tagFormatSimple:
|
| 387 |
+
raise error.PyAsn1Error('Simple tag format expected')
|
| 388 |
+
|
| 389 |
+
for chunk in readFromStream(substrate, length, options):
|
| 390 |
+
if isinstance(chunk, SubstrateUnderrunError):
|
| 391 |
+
yield chunk
|
| 392 |
+
|
| 393 |
+
component = self._createComponent(asn1Spec, tagSet, '', **options)
|
| 394 |
+
|
| 395 |
+
if chunk:
|
| 396 |
+
raise error.PyAsn1Error('Unexpected %d-octet substrate for Null' % length)
|
| 397 |
+
|
| 398 |
+
yield component
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class ObjectIdentifierPayloadDecoder(AbstractSimplePayloadDecoder):
|
| 402 |
+
protoComponent = univ.ObjectIdentifier(())
|
| 403 |
+
|
| 404 |
+
def valueDecoder(self, substrate, asn1Spec,
|
| 405 |
+
tagSet=None, length=None, state=None,
|
| 406 |
+
decodeFun=None, substrateFun=None,
|
| 407 |
+
**options):
|
| 408 |
+
if tagSet[0].tagFormat != tag.tagFormatSimple:
|
| 409 |
+
raise error.PyAsn1Error('Simple tag format expected')
|
| 410 |
+
|
| 411 |
+
for chunk in readFromStream(substrate, length, options):
|
| 412 |
+
if isinstance(chunk, SubstrateUnderrunError):
|
| 413 |
+
yield chunk
|
| 414 |
+
|
| 415 |
+
if not chunk:
|
| 416 |
+
raise error.PyAsn1Error('Empty substrate')
|
| 417 |
+
|
| 418 |
+
oid = ()
|
| 419 |
+
index = 0
|
| 420 |
+
substrateLen = len(chunk)
|
| 421 |
+
while index < substrateLen:
|
| 422 |
+
subId = chunk[index]
|
| 423 |
+
index += 1
|
| 424 |
+
if subId < 128:
|
| 425 |
+
oid += (subId,)
|
| 426 |
+
elif subId > 128:
|
| 427 |
+
# Construct subid from a number of octets
|
| 428 |
+
nextSubId = subId
|
| 429 |
+
subId = 0
|
| 430 |
+
while nextSubId >= 128:
|
| 431 |
+
subId = (subId << 7) + (nextSubId & 0x7F)
|
| 432 |
+
if index >= substrateLen:
|
| 433 |
+
raise error.SubstrateUnderrunError(
|
| 434 |
+
'Short substrate for sub-OID past %s' % (oid,)
|
| 435 |
+
)
|
| 436 |
+
nextSubId = chunk[index]
|
| 437 |
+
index += 1
|
| 438 |
+
oid += ((subId << 7) + nextSubId,)
|
| 439 |
+
elif subId == 128:
|
| 440 |
+
# ASN.1 spec forbids leading zeros (0x80) in OID
|
| 441 |
+
# encoding, tolerating it opens a vulnerability. See
|
| 442 |
+
# https://www.esat.kuleuven.be/cosic/publications/article-1432.pdf
|
| 443 |
+
# page 7
|
| 444 |
+
raise error.PyAsn1Error('Invalid octet 0x80 in OID encoding')
|
| 445 |
+
|
| 446 |
+
# Decode two leading arcs
|
| 447 |
+
if 0 <= oid[0] <= 39:
|
| 448 |
+
oid = (0,) + oid
|
| 449 |
+
elif 40 <= oid[0] <= 79:
|
| 450 |
+
oid = (1, oid[0] - 40) + oid[1:]
|
| 451 |
+
elif oid[0] >= 80:
|
| 452 |
+
oid = (2, oid[0] - 80) + oid[1:]
|
| 453 |
+
else:
|
| 454 |
+
raise error.PyAsn1Error('Malformed first OID octet: %s' % chunk[0])
|
| 455 |
+
|
| 456 |
+
yield self._createComponent(asn1Spec, tagSet, oid, **options)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class RelativeOIDPayloadDecoder(AbstractSimplePayloadDecoder):
|
| 460 |
+
protoComponent = univ.RelativeOID(())
|
| 461 |
+
|
| 462 |
+
def valueDecoder(self, substrate, asn1Spec,
|
| 463 |
+
tagSet=None, length=None, state=None,
|
| 464 |
+
decodeFun=None, substrateFun=None,
|
| 465 |
+
**options):
|
| 466 |
+
if tagSet[0].tagFormat != tag.tagFormatSimple:
|
| 467 |
+
raise error.PyAsn1Error('Simple tag format expected')
|
| 468 |
+
|
| 469 |
+
for chunk in readFromStream(substrate, length, options):
|
| 470 |
+
if isinstance(chunk, SubstrateUnderrunError):
|
| 471 |
+
yield chunk
|
| 472 |
+
|
| 473 |
+
if not chunk:
|
| 474 |
+
raise error.PyAsn1Error('Empty substrate')
|
| 475 |
+
|
| 476 |
+
reloid = ()
|
| 477 |
+
index = 0
|
| 478 |
+
substrateLen = len(chunk)
|
| 479 |
+
while index < substrateLen:
|
| 480 |
+
subId = chunk[index]
|
| 481 |
+
index += 1
|
| 482 |
+
if subId < 128:
|
| 483 |
+
reloid += (subId,)
|
| 484 |
+
elif subId > 128:
|
| 485 |
+
# Construct subid from a number of octets
|
| 486 |
+
nextSubId = subId
|
| 487 |
+
subId = 0
|
| 488 |
+
while nextSubId >= 128:
|
| 489 |
+
subId = (subId << 7) + (nextSubId & 0x7F)
|
| 490 |
+
if index >= substrateLen:
|
| 491 |
+
raise error.SubstrateUnderrunError(
|
| 492 |
+
'Short substrate for sub-OID past %s' % (reloid,)
|
| 493 |
+
)
|
| 494 |
+
nextSubId = chunk[index]
|
| 495 |
+
index += 1
|
| 496 |
+
reloid += ((subId << 7) + nextSubId,)
|
| 497 |
+
elif subId == 128:
|
| 498 |
+
# ASN.1 spec forbids leading zeros (0x80) in OID
|
| 499 |
+
# encoding, tolerating it opens a vulnerability. See
|
| 500 |
+
# https://www.esat.kuleuven.be/cosic/publications/article-1432.pdf
|
| 501 |
+
# page 7
|
| 502 |
+
raise error.PyAsn1Error('Invalid octet 0x80 in RELATIVE-OID encoding')
|
| 503 |
+
|
| 504 |
+
yield self._createComponent(asn1Spec, tagSet, reloid, **options)
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class RealPayloadDecoder(AbstractSimplePayloadDecoder):
|
| 508 |
+
protoComponent = univ.Real()
|
| 509 |
+
|
| 510 |
+
def valueDecoder(self, substrate, asn1Spec,
|
| 511 |
+
tagSet=None, length=None, state=None,
|
| 512 |
+
decodeFun=None, substrateFun=None,
|
| 513 |
+
**options):
|
| 514 |
+
if tagSet[0].tagFormat != tag.tagFormatSimple:
|
| 515 |
+
raise error.PyAsn1Error('Simple tag format expected')
|
| 516 |
+
|
| 517 |
+
for chunk in readFromStream(substrate, length, options):
|
| 518 |
+
if isinstance(chunk, SubstrateUnderrunError):
|
| 519 |
+
yield chunk
|
| 520 |
+
|
| 521 |
+
if not chunk:
|
| 522 |
+
yield self._createComponent(asn1Spec, tagSet, 0.0, **options)
|
| 523 |
+
return
|
| 524 |
+
|
| 525 |
+
fo = chunk[0]
|
| 526 |
+
chunk = chunk[1:]
|
| 527 |
+
if fo & 0x80: # binary encoding
|
| 528 |
+
if not chunk:
|
| 529 |
+
raise error.PyAsn1Error("Incomplete floating-point value")
|
| 530 |
+
|
| 531 |
+
if LOG:
|
| 532 |
+
LOG('decoding binary encoded REAL')
|
| 533 |
+
|
| 534 |
+
n = (fo & 0x03) + 1
|
| 535 |
+
|
| 536 |
+
if n == 4:
|
| 537 |
+
n = chunk[0]
|
| 538 |
+
chunk = chunk[1:]
|
| 539 |
+
|
| 540 |
+
eo, chunk = chunk[:n], chunk[n:]
|
| 541 |
+
|
| 542 |
+
if not eo or not chunk:
|
| 543 |
+
raise error.PyAsn1Error('Real exponent screwed')
|
| 544 |
+
|
| 545 |
+
e = eo[0] & 0x80 and -1 or 0
|
| 546 |
+
|
| 547 |
+
while eo: # exponent
|
| 548 |
+
e <<= 8
|
| 549 |
+
e |= eo[0]
|
| 550 |
+
eo = eo[1:]
|
| 551 |
+
|
| 552 |
+
b = fo >> 4 & 0x03 # base bits
|
| 553 |
+
|
| 554 |
+
if b > 2:
|
| 555 |
+
raise error.PyAsn1Error('Illegal Real base')
|
| 556 |
+
|
| 557 |
+
if b == 1: # encbase = 8
|
| 558 |
+
e *= 3
|
| 559 |
+
|
| 560 |
+
elif b == 2: # encbase = 16
|
| 561 |
+
e *= 4
|
| 562 |
+
p = 0
|
| 563 |
+
|
| 564 |
+
while chunk: # value
|
| 565 |
+
p <<= 8
|
| 566 |
+
p |= chunk[0]
|
| 567 |
+
chunk = chunk[1:]
|
| 568 |
+
|
| 569 |
+
if fo & 0x40: # sign bit
|
| 570 |
+
p = -p
|
| 571 |
+
|
| 572 |
+
sf = fo >> 2 & 0x03 # scale bits
|
| 573 |
+
p *= 2 ** sf
|
| 574 |
+
value = (p, 2, e)
|
| 575 |
+
|
| 576 |
+
elif fo & 0x40: # infinite value
|
| 577 |
+
if LOG:
|
| 578 |
+
LOG('decoding infinite REAL')
|
| 579 |
+
|
| 580 |
+
value = fo & 0x01 and '-inf' or 'inf'
|
| 581 |
+
|
| 582 |
+
elif fo & 0xc0 == 0: # character encoding
|
| 583 |
+
if not chunk:
|
| 584 |
+
raise error.PyAsn1Error("Incomplete floating-point value")
|
| 585 |
+
|
| 586 |
+
if LOG:
|
| 587 |
+
LOG('decoding character encoded REAL')
|
| 588 |
+
|
| 589 |
+
try:
|
| 590 |
+
if fo & 0x3 == 0x1: # NR1
|
| 591 |
+
value = (int(chunk), 10, 0)
|
| 592 |
+
|
| 593 |
+
elif fo & 0x3 == 0x2: # NR2
|
| 594 |
+
value = float(chunk)
|
| 595 |
+
|
| 596 |
+
elif fo & 0x3 == 0x3: # NR3
|
| 597 |
+
value = float(chunk)
|
| 598 |
+
|
| 599 |
+
else:
|
| 600 |
+
raise error.SubstrateUnderrunError(
|
| 601 |
+
'Unknown NR (tag %s)' % fo
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
except ValueError:
|
| 605 |
+
raise error.SubstrateUnderrunError(
|
| 606 |
+
'Bad character Real syntax'
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
else:
|
| 610 |
+
raise error.SubstrateUnderrunError(
|
| 611 |
+
'Unknown encoding (tag %s)' % fo
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
yield self._createComponent(asn1Spec, tagSet, value, **options)
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
class AbstractConstructedPayloadDecoder(AbstractPayloadDecoder):
|
| 618 |
+
protoComponent = None
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
class ConstructedPayloadDecoderBase(AbstractConstructedPayloadDecoder):
|
| 622 |
+
protoRecordComponent = None
|
| 623 |
+
protoSequenceComponent = None
|
| 624 |
+
|
| 625 |
+
def _getComponentTagMap(self, asn1Object, idx):
|
| 626 |
+
raise NotImplementedError
|
| 627 |
+
|
| 628 |
+
def _getComponentPositionByType(self, asn1Object, tagSet, idx):
|
| 629 |
+
raise NotImplementedError
|
| 630 |
+
|
| 631 |
+
def _decodeComponentsSchemaless(
|
| 632 |
+
self, substrate, tagSet=None, decodeFun=None,
|
| 633 |
+
length=None, **options):
|
| 634 |
+
|
| 635 |
+
asn1Object = None
|
| 636 |
+
|
| 637 |
+
components = []
|
| 638 |
+
componentTypes = set()
|
| 639 |
+
|
| 640 |
+
original_position = substrate.tell()
|
| 641 |
+
|
| 642 |
+
while length == -1 or substrate.tell() < original_position + length:
|
| 643 |
+
for component in decodeFun(substrate, **options):
|
| 644 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 645 |
+
yield component
|
| 646 |
+
|
| 647 |
+
if length == -1 and component is eoo.endOfOctets:
|
| 648 |
+
break
|
| 649 |
+
|
| 650 |
+
components.append(component)
|
| 651 |
+
componentTypes.add(component.tagSet)
|
| 652 |
+
|
| 653 |
+
# Now we have to guess is it SEQUENCE/SET or SEQUENCE OF/SET OF
|
| 654 |
+
# The heuristics is:
|
| 655 |
+
# * 1+ components of different types -> likely SEQUENCE/SET
|
| 656 |
+
# * otherwise -> likely SEQUENCE OF/SET OF
|
| 657 |
+
if len(componentTypes) > 1:
|
| 658 |
+
protoComponent = self.protoRecordComponent
|
| 659 |
+
|
| 660 |
+
else:
|
| 661 |
+
protoComponent = self.protoSequenceComponent
|
| 662 |
+
|
| 663 |
+
asn1Object = protoComponent.clone(
|
| 664 |
+
# construct tagSet from base tag from prototype ASN.1 object
|
| 665 |
+
# and additional tags recovered from the substrate
|
| 666 |
+
tagSet=tag.TagSet(protoComponent.tagSet.baseTag, *tagSet.superTags)
|
| 667 |
+
)
|
| 668 |
+
|
| 669 |
+
if LOG:
|
| 670 |
+
LOG('guessed %r container type (pass `asn1Spec` to guide the '
|
| 671 |
+
'decoder)' % asn1Object)
|
| 672 |
+
|
| 673 |
+
for idx, component in enumerate(components):
|
| 674 |
+
asn1Object.setComponentByPosition(
|
| 675 |
+
idx, component,
|
| 676 |
+
verifyConstraints=False,
|
| 677 |
+
matchTags=False, matchConstraints=False
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
yield asn1Object
|
| 681 |
+
|
| 682 |
+
def valueDecoder(self, substrate, asn1Spec,
|
| 683 |
+
tagSet=None, length=None, state=None,
|
| 684 |
+
decodeFun=None, substrateFun=None,
|
| 685 |
+
**options):
|
| 686 |
+
if tagSet[0].tagFormat != tag.tagFormatConstructed:
|
| 687 |
+
raise error.PyAsn1Error('Constructed tag format expected')
|
| 688 |
+
|
| 689 |
+
original_position = substrate.tell()
|
| 690 |
+
|
| 691 |
+
if substrateFun:
|
| 692 |
+
if asn1Spec is not None:
|
| 693 |
+
asn1Object = asn1Spec.clone()
|
| 694 |
+
|
| 695 |
+
elif self.protoComponent is not None:
|
| 696 |
+
asn1Object = self.protoComponent.clone(tagSet=tagSet)
|
| 697 |
+
|
| 698 |
+
else:
|
| 699 |
+
asn1Object = self.protoRecordComponent, self.protoSequenceComponent
|
| 700 |
+
|
| 701 |
+
for chunk in substrateFun(asn1Object, substrate, length, options):
|
| 702 |
+
yield chunk
|
| 703 |
+
|
| 704 |
+
return
|
| 705 |
+
|
| 706 |
+
if asn1Spec is None:
|
| 707 |
+
for asn1Object in self._decodeComponentsSchemaless(
|
| 708 |
+
substrate, tagSet=tagSet, decodeFun=decodeFun,
|
| 709 |
+
length=length, **options):
|
| 710 |
+
if isinstance(asn1Object, SubstrateUnderrunError):
|
| 711 |
+
yield asn1Object
|
| 712 |
+
|
| 713 |
+
if substrate.tell() < original_position + length:
|
| 714 |
+
if LOG:
|
| 715 |
+
for trailing in readFromStream(substrate, context=options):
|
| 716 |
+
if isinstance(trailing, SubstrateUnderrunError):
|
| 717 |
+
yield trailing
|
| 718 |
+
|
| 719 |
+
LOG('Unused trailing %d octets encountered: %s' % (
|
| 720 |
+
len(trailing), debug.hexdump(trailing)))
|
| 721 |
+
|
| 722 |
+
yield asn1Object
|
| 723 |
+
|
| 724 |
+
return
|
| 725 |
+
|
| 726 |
+
asn1Object = asn1Spec.clone()
|
| 727 |
+
asn1Object.clear()
|
| 728 |
+
|
| 729 |
+
options = self._passAsn1Object(asn1Object, options)
|
| 730 |
+
|
| 731 |
+
if asn1Spec.typeId in (univ.Sequence.typeId, univ.Set.typeId):
|
| 732 |
+
|
| 733 |
+
namedTypes = asn1Spec.componentType
|
| 734 |
+
|
| 735 |
+
isSetType = asn1Spec.typeId == univ.Set.typeId
|
| 736 |
+
isDeterministic = not isSetType and not namedTypes.hasOptionalOrDefault
|
| 737 |
+
|
| 738 |
+
if LOG:
|
| 739 |
+
LOG('decoding %sdeterministic %s type %r chosen by type ID' % (
|
| 740 |
+
not isDeterministic and 'non-' or '', isSetType and 'SET' or '',
|
| 741 |
+
asn1Spec))
|
| 742 |
+
|
| 743 |
+
seenIndices = set()
|
| 744 |
+
idx = 0
|
| 745 |
+
while substrate.tell() - original_position < length:
|
| 746 |
+
if not namedTypes:
|
| 747 |
+
componentType = None
|
| 748 |
+
|
| 749 |
+
elif isSetType:
|
| 750 |
+
componentType = namedTypes.tagMapUnique
|
| 751 |
+
|
| 752 |
+
else:
|
| 753 |
+
try:
|
| 754 |
+
if isDeterministic:
|
| 755 |
+
componentType = namedTypes[idx].asn1Object
|
| 756 |
+
|
| 757 |
+
elif namedTypes[idx].isOptional or namedTypes[idx].isDefaulted:
|
| 758 |
+
componentType = namedTypes.getTagMapNearPosition(idx)
|
| 759 |
+
|
| 760 |
+
else:
|
| 761 |
+
componentType = namedTypes[idx].asn1Object
|
| 762 |
+
|
| 763 |
+
except IndexError:
|
| 764 |
+
raise error.PyAsn1Error(
|
| 765 |
+
'Excessive components decoded at %r' % (asn1Spec,)
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
for component in decodeFun(substrate, componentType, **options):
|
| 769 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 770 |
+
yield component
|
| 771 |
+
|
| 772 |
+
if not isDeterministic and namedTypes:
|
| 773 |
+
if isSetType:
|
| 774 |
+
idx = namedTypes.getPositionByType(component.effectiveTagSet)
|
| 775 |
+
|
| 776 |
+
elif namedTypes[idx].isOptional or namedTypes[idx].isDefaulted:
|
| 777 |
+
idx = namedTypes.getPositionNearType(component.effectiveTagSet, idx)
|
| 778 |
+
|
| 779 |
+
asn1Object.setComponentByPosition(
|
| 780 |
+
idx, component,
|
| 781 |
+
verifyConstraints=False,
|
| 782 |
+
matchTags=False, matchConstraints=False
|
| 783 |
+
)
|
| 784 |
+
|
| 785 |
+
seenIndices.add(idx)
|
| 786 |
+
idx += 1
|
| 787 |
+
|
| 788 |
+
if LOG:
|
| 789 |
+
LOG('seen component indices %s' % seenIndices)
|
| 790 |
+
|
| 791 |
+
if namedTypes:
|
| 792 |
+
if not namedTypes.requiredComponents.issubset(seenIndices):
|
| 793 |
+
raise error.PyAsn1Error(
|
| 794 |
+
'ASN.1 object %s has uninitialized '
|
| 795 |
+
'components' % asn1Object.__class__.__name__)
|
| 796 |
+
|
| 797 |
+
if namedTypes.hasOpenTypes:
|
| 798 |
+
|
| 799 |
+
openTypes = options.get('openTypes', {})
|
| 800 |
+
|
| 801 |
+
if LOG:
|
| 802 |
+
LOG('user-specified open types map:')
|
| 803 |
+
|
| 804 |
+
for k, v in openTypes.items():
|
| 805 |
+
LOG('%s -> %r' % (k, v))
|
| 806 |
+
|
| 807 |
+
if openTypes or options.get('decodeOpenTypes', False):
|
| 808 |
+
|
| 809 |
+
for idx, namedType in enumerate(namedTypes.namedTypes):
|
| 810 |
+
if not namedType.openType:
|
| 811 |
+
continue
|
| 812 |
+
|
| 813 |
+
if namedType.isOptional and not asn1Object.getComponentByPosition(idx).isValue:
|
| 814 |
+
continue
|
| 815 |
+
|
| 816 |
+
governingValue = asn1Object.getComponentByName(
|
| 817 |
+
namedType.openType.name
|
| 818 |
+
)
|
| 819 |
+
|
| 820 |
+
try:
|
| 821 |
+
openType = openTypes[governingValue]
|
| 822 |
+
|
| 823 |
+
except KeyError:
|
| 824 |
+
|
| 825 |
+
if LOG:
|
| 826 |
+
LOG('default open types map of component '
|
| 827 |
+
'"%s.%s" governed by component "%s.%s"'
|
| 828 |
+
':' % (asn1Object.__class__.__name__,
|
| 829 |
+
namedType.name,
|
| 830 |
+
asn1Object.__class__.__name__,
|
| 831 |
+
namedType.openType.name))
|
| 832 |
+
|
| 833 |
+
for k, v in namedType.openType.items():
|
| 834 |
+
LOG('%s -> %r' % (k, v))
|
| 835 |
+
|
| 836 |
+
try:
|
| 837 |
+
openType = namedType.openType[governingValue]
|
| 838 |
+
|
| 839 |
+
except KeyError:
|
| 840 |
+
if LOG:
|
| 841 |
+
LOG('failed to resolve open type by governing '
|
| 842 |
+
'value %r' % (governingValue,))
|
| 843 |
+
continue
|
| 844 |
+
|
| 845 |
+
if LOG:
|
| 846 |
+
LOG('resolved open type %r by governing '
|
| 847 |
+
'value %r' % (openType, governingValue))
|
| 848 |
+
|
| 849 |
+
containerValue = asn1Object.getComponentByPosition(idx)
|
| 850 |
+
|
| 851 |
+
if containerValue.typeId in (
|
| 852 |
+
univ.SetOf.typeId, univ.SequenceOf.typeId):
|
| 853 |
+
|
| 854 |
+
for pos, containerElement in enumerate(
|
| 855 |
+
containerValue):
|
| 856 |
+
|
| 857 |
+
stream = asSeekableStream(containerValue[pos].asOctets())
|
| 858 |
+
|
| 859 |
+
for component in decodeFun(stream, asn1Spec=openType, **options):
|
| 860 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 861 |
+
yield component
|
| 862 |
+
|
| 863 |
+
containerValue[pos] = component
|
| 864 |
+
|
| 865 |
+
else:
|
| 866 |
+
stream = asSeekableStream(asn1Object.getComponentByPosition(idx).asOctets())
|
| 867 |
+
|
| 868 |
+
for component in decodeFun(stream, asn1Spec=openType, **options):
|
| 869 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 870 |
+
yield component
|
| 871 |
+
|
| 872 |
+
asn1Object.setComponentByPosition(idx, component)
|
| 873 |
+
|
| 874 |
+
else:
|
| 875 |
+
inconsistency = asn1Object.isInconsistent
|
| 876 |
+
if inconsistency:
|
| 877 |
+
raise error.PyAsn1Error(
|
| 878 |
+
f"ASN.1 object {asn1Object.__class__.__name__} is inconsistent")
|
| 879 |
+
|
| 880 |
+
else:
|
| 881 |
+
componentType = asn1Spec.componentType
|
| 882 |
+
|
| 883 |
+
if LOG:
|
| 884 |
+
LOG('decoding type %r chosen by given `asn1Spec`' % componentType)
|
| 885 |
+
|
| 886 |
+
idx = 0
|
| 887 |
+
|
| 888 |
+
while substrate.tell() - original_position < length:
|
| 889 |
+
for component in decodeFun(substrate, componentType, **options):
|
| 890 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 891 |
+
yield component
|
| 892 |
+
|
| 893 |
+
asn1Object.setComponentByPosition(
|
| 894 |
+
idx, component,
|
| 895 |
+
verifyConstraints=False,
|
| 896 |
+
matchTags=False, matchConstraints=False
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
idx += 1
|
| 900 |
+
|
| 901 |
+
yield asn1Object
|
| 902 |
+
|
| 903 |
+
def indefLenValueDecoder(self, substrate, asn1Spec,
|
| 904 |
+
tagSet=None, length=None, state=None,
|
| 905 |
+
decodeFun=None, substrateFun=None,
|
| 906 |
+
**options):
|
| 907 |
+
if tagSet[0].tagFormat != tag.tagFormatConstructed:
|
| 908 |
+
raise error.PyAsn1Error('Constructed tag format expected')
|
| 909 |
+
|
| 910 |
+
if substrateFun is not None:
|
| 911 |
+
if asn1Spec is not None:
|
| 912 |
+
asn1Object = asn1Spec.clone()
|
| 913 |
+
|
| 914 |
+
elif self.protoComponent is not None:
|
| 915 |
+
asn1Object = self.protoComponent.clone(tagSet=tagSet)
|
| 916 |
+
|
| 917 |
+
else:
|
| 918 |
+
asn1Object = self.protoRecordComponent, self.protoSequenceComponent
|
| 919 |
+
|
| 920 |
+
for chunk in substrateFun(asn1Object, substrate, length, options):
|
| 921 |
+
yield chunk
|
| 922 |
+
|
| 923 |
+
return
|
| 924 |
+
|
| 925 |
+
if asn1Spec is None:
|
| 926 |
+
for asn1Object in self._decodeComponentsSchemaless(
|
| 927 |
+
substrate, tagSet=tagSet, decodeFun=decodeFun,
|
| 928 |
+
length=length, **dict(options, allowEoo=True)):
|
| 929 |
+
if isinstance(asn1Object, SubstrateUnderrunError):
|
| 930 |
+
yield asn1Object
|
| 931 |
+
|
| 932 |
+
yield asn1Object
|
| 933 |
+
|
| 934 |
+
return
|
| 935 |
+
|
| 936 |
+
asn1Object = asn1Spec.clone()
|
| 937 |
+
asn1Object.clear()
|
| 938 |
+
|
| 939 |
+
options = self._passAsn1Object(asn1Object, options)
|
| 940 |
+
|
| 941 |
+
if asn1Spec.typeId in (univ.Sequence.typeId, univ.Set.typeId):
|
| 942 |
+
|
| 943 |
+
namedTypes = asn1Object.componentType
|
| 944 |
+
|
| 945 |
+
isSetType = asn1Object.typeId == univ.Set.typeId
|
| 946 |
+
isDeterministic = not isSetType and not namedTypes.hasOptionalOrDefault
|
| 947 |
+
|
| 948 |
+
if LOG:
|
| 949 |
+
LOG('decoding %sdeterministic %s type %r chosen by type ID' % (
|
| 950 |
+
not isDeterministic and 'non-' or '', isSetType and 'SET' or '',
|
| 951 |
+
asn1Spec))
|
| 952 |
+
|
| 953 |
+
seenIndices = set()
|
| 954 |
+
|
| 955 |
+
idx = 0
|
| 956 |
+
|
| 957 |
+
while True: # loop over components
|
| 958 |
+
if len(namedTypes) <= idx:
|
| 959 |
+
asn1Spec = None
|
| 960 |
+
|
| 961 |
+
elif isSetType:
|
| 962 |
+
asn1Spec = namedTypes.tagMapUnique
|
| 963 |
+
|
| 964 |
+
else:
|
| 965 |
+
try:
|
| 966 |
+
if isDeterministic:
|
| 967 |
+
asn1Spec = namedTypes[idx].asn1Object
|
| 968 |
+
|
| 969 |
+
elif namedTypes[idx].isOptional or namedTypes[idx].isDefaulted:
|
| 970 |
+
asn1Spec = namedTypes.getTagMapNearPosition(idx)
|
| 971 |
+
|
| 972 |
+
else:
|
| 973 |
+
asn1Spec = namedTypes[idx].asn1Object
|
| 974 |
+
|
| 975 |
+
except IndexError:
|
| 976 |
+
raise error.PyAsn1Error(
|
| 977 |
+
'Excessive components decoded at %r' % (asn1Object,)
|
| 978 |
+
)
|
| 979 |
+
|
| 980 |
+
for component in decodeFun(substrate, asn1Spec, allowEoo=True, **options):
|
| 981 |
+
|
| 982 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 983 |
+
yield component
|
| 984 |
+
|
| 985 |
+
if component is eoo.endOfOctets:
|
| 986 |
+
break
|
| 987 |
+
|
| 988 |
+
if component is eoo.endOfOctets:
|
| 989 |
+
break
|
| 990 |
+
|
| 991 |
+
if not isDeterministic and namedTypes:
|
| 992 |
+
if isSetType:
|
| 993 |
+
idx = namedTypes.getPositionByType(component.effectiveTagSet)
|
| 994 |
+
|
| 995 |
+
elif namedTypes[idx].isOptional or namedTypes[idx].isDefaulted:
|
| 996 |
+
idx = namedTypes.getPositionNearType(component.effectiveTagSet, idx)
|
| 997 |
+
|
| 998 |
+
asn1Object.setComponentByPosition(
|
| 999 |
+
idx, component,
|
| 1000 |
+
verifyConstraints=False,
|
| 1001 |
+
matchTags=False, matchConstraints=False
|
| 1002 |
+
)
|
| 1003 |
+
|
| 1004 |
+
seenIndices.add(idx)
|
| 1005 |
+
idx += 1
|
| 1006 |
+
|
| 1007 |
+
if LOG:
|
| 1008 |
+
LOG('seen component indices %s' % seenIndices)
|
| 1009 |
+
|
| 1010 |
+
if namedTypes:
|
| 1011 |
+
if not namedTypes.requiredComponents.issubset(seenIndices):
|
| 1012 |
+
raise error.PyAsn1Error(
|
| 1013 |
+
'ASN.1 object %s has uninitialized '
|
| 1014 |
+
'components' % asn1Object.__class__.__name__)
|
| 1015 |
+
|
| 1016 |
+
if namedTypes.hasOpenTypes:
|
| 1017 |
+
|
| 1018 |
+
openTypes = options.get('openTypes', {})
|
| 1019 |
+
|
| 1020 |
+
if LOG:
|
| 1021 |
+
LOG('user-specified open types map:')
|
| 1022 |
+
|
| 1023 |
+
for k, v in openTypes.items():
|
| 1024 |
+
LOG('%s -> %r' % (k, v))
|
| 1025 |
+
|
| 1026 |
+
if openTypes or options.get('decodeOpenTypes', False):
|
| 1027 |
+
|
| 1028 |
+
for idx, namedType in enumerate(namedTypes.namedTypes):
|
| 1029 |
+
if not namedType.openType:
|
| 1030 |
+
continue
|
| 1031 |
+
|
| 1032 |
+
if namedType.isOptional and not asn1Object.getComponentByPosition(idx).isValue:
|
| 1033 |
+
continue
|
| 1034 |
+
|
| 1035 |
+
governingValue = asn1Object.getComponentByName(
|
| 1036 |
+
namedType.openType.name
|
| 1037 |
+
)
|
| 1038 |
+
|
| 1039 |
+
try:
|
| 1040 |
+
openType = openTypes[governingValue]
|
| 1041 |
+
|
| 1042 |
+
except KeyError:
|
| 1043 |
+
|
| 1044 |
+
if LOG:
|
| 1045 |
+
LOG('default open types map of component '
|
| 1046 |
+
'"%s.%s" governed by component "%s.%s"'
|
| 1047 |
+
':' % (asn1Object.__class__.__name__,
|
| 1048 |
+
namedType.name,
|
| 1049 |
+
asn1Object.__class__.__name__,
|
| 1050 |
+
namedType.openType.name))
|
| 1051 |
+
|
| 1052 |
+
for k, v in namedType.openType.items():
|
| 1053 |
+
LOG('%s -> %r' % (k, v))
|
| 1054 |
+
|
| 1055 |
+
try:
|
| 1056 |
+
openType = namedType.openType[governingValue]
|
| 1057 |
+
|
| 1058 |
+
except KeyError:
|
| 1059 |
+
if LOG:
|
| 1060 |
+
LOG('failed to resolve open type by governing '
|
| 1061 |
+
'value %r' % (governingValue,))
|
| 1062 |
+
continue
|
| 1063 |
+
|
| 1064 |
+
if LOG:
|
| 1065 |
+
LOG('resolved open type %r by governing '
|
| 1066 |
+
'value %r' % (openType, governingValue))
|
| 1067 |
+
|
| 1068 |
+
containerValue = asn1Object.getComponentByPosition(idx)
|
| 1069 |
+
|
| 1070 |
+
if containerValue.typeId in (
|
| 1071 |
+
univ.SetOf.typeId, univ.SequenceOf.typeId):
|
| 1072 |
+
|
| 1073 |
+
for pos, containerElement in enumerate(
|
| 1074 |
+
containerValue):
|
| 1075 |
+
|
| 1076 |
+
stream = asSeekableStream(containerValue[pos].asOctets())
|
| 1077 |
+
|
| 1078 |
+
for component in decodeFun(stream, asn1Spec=openType,
|
| 1079 |
+
**dict(options, allowEoo=True)):
|
| 1080 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 1081 |
+
yield component
|
| 1082 |
+
|
| 1083 |
+
if component is eoo.endOfOctets:
|
| 1084 |
+
break
|
| 1085 |
+
|
| 1086 |
+
containerValue[pos] = component
|
| 1087 |
+
|
| 1088 |
+
else:
|
| 1089 |
+
stream = asSeekableStream(asn1Object.getComponentByPosition(idx).asOctets())
|
| 1090 |
+
for component in decodeFun(stream, asn1Spec=openType,
|
| 1091 |
+
**dict(options, allowEoo=True)):
|
| 1092 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 1093 |
+
yield component
|
| 1094 |
+
|
| 1095 |
+
if component is eoo.endOfOctets:
|
| 1096 |
+
break
|
| 1097 |
+
|
| 1098 |
+
asn1Object.setComponentByPosition(idx, component)
|
| 1099 |
+
|
| 1100 |
+
else:
|
| 1101 |
+
inconsistency = asn1Object.isInconsistent
|
| 1102 |
+
if inconsistency:
|
| 1103 |
+
raise error.PyAsn1Error(
|
| 1104 |
+
f"ASN.1 object {asn1Object.__class__.__name__} is inconsistent")
|
| 1105 |
+
|
| 1106 |
+
else:
|
| 1107 |
+
componentType = asn1Spec.componentType
|
| 1108 |
+
|
| 1109 |
+
if LOG:
|
| 1110 |
+
LOG('decoding type %r chosen by given `asn1Spec`' % componentType)
|
| 1111 |
+
|
| 1112 |
+
idx = 0
|
| 1113 |
+
|
| 1114 |
+
while True:
|
| 1115 |
+
|
| 1116 |
+
for component in decodeFun(
|
| 1117 |
+
substrate, componentType, allowEoo=True, **options):
|
| 1118 |
+
|
| 1119 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 1120 |
+
yield component
|
| 1121 |
+
|
| 1122 |
+
if component is eoo.endOfOctets:
|
| 1123 |
+
break
|
| 1124 |
+
|
| 1125 |
+
if component is eoo.endOfOctets:
|
| 1126 |
+
break
|
| 1127 |
+
|
| 1128 |
+
asn1Object.setComponentByPosition(
|
| 1129 |
+
idx, component,
|
| 1130 |
+
verifyConstraints=False,
|
| 1131 |
+
matchTags=False, matchConstraints=False
|
| 1132 |
+
)
|
| 1133 |
+
|
| 1134 |
+
idx += 1
|
| 1135 |
+
|
| 1136 |
+
yield asn1Object
|
| 1137 |
+
|
| 1138 |
+
|
| 1139 |
+
class SequenceOrSequenceOfPayloadDecoder(ConstructedPayloadDecoderBase):
|
| 1140 |
+
protoRecordComponent = univ.Sequence()
|
| 1141 |
+
protoSequenceComponent = univ.SequenceOf()
|
| 1142 |
+
|
| 1143 |
+
|
| 1144 |
+
class SequencePayloadDecoder(SequenceOrSequenceOfPayloadDecoder):
|
| 1145 |
+
protoComponent = univ.Sequence()
|
| 1146 |
+
|
| 1147 |
+
|
| 1148 |
+
class SequenceOfPayloadDecoder(SequenceOrSequenceOfPayloadDecoder):
|
| 1149 |
+
protoComponent = univ.SequenceOf()
|
| 1150 |
+
|
| 1151 |
+
|
| 1152 |
+
class SetOrSetOfPayloadDecoder(ConstructedPayloadDecoderBase):
|
| 1153 |
+
protoRecordComponent = univ.Set()
|
| 1154 |
+
protoSequenceComponent = univ.SetOf()
|
| 1155 |
+
|
| 1156 |
+
|
| 1157 |
+
class SetPayloadDecoder(SetOrSetOfPayloadDecoder):
|
| 1158 |
+
protoComponent = univ.Set()
|
| 1159 |
+
|
| 1160 |
+
|
| 1161 |
+
class SetOfPayloadDecoder(SetOrSetOfPayloadDecoder):
|
| 1162 |
+
protoComponent = univ.SetOf()
|
| 1163 |
+
|
| 1164 |
+
|
| 1165 |
+
class ChoicePayloadDecoder(ConstructedPayloadDecoderBase):
|
| 1166 |
+
protoComponent = univ.Choice()
|
| 1167 |
+
|
| 1168 |
+
def valueDecoder(self, substrate, asn1Spec,
|
| 1169 |
+
tagSet=None, length=None, state=None,
|
| 1170 |
+
decodeFun=None, substrateFun=None,
|
| 1171 |
+
**options):
|
| 1172 |
+
if asn1Spec is None:
|
| 1173 |
+
asn1Object = self.protoComponent.clone(tagSet=tagSet)
|
| 1174 |
+
|
| 1175 |
+
else:
|
| 1176 |
+
asn1Object = asn1Spec.clone()
|
| 1177 |
+
|
| 1178 |
+
if substrateFun:
|
| 1179 |
+
for chunk in substrateFun(asn1Object, substrate, length, options):
|
| 1180 |
+
yield chunk
|
| 1181 |
+
|
| 1182 |
+
return
|
| 1183 |
+
|
| 1184 |
+
options = self._passAsn1Object(asn1Object, options)
|
| 1185 |
+
|
| 1186 |
+
if asn1Object.tagSet == tagSet:
|
| 1187 |
+
if LOG:
|
| 1188 |
+
LOG('decoding %s as explicitly tagged CHOICE' % (tagSet,))
|
| 1189 |
+
|
| 1190 |
+
for component in decodeFun(
|
| 1191 |
+
substrate, asn1Object.componentTagMap, **options):
|
| 1192 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 1193 |
+
yield component
|
| 1194 |
+
|
| 1195 |
+
else:
|
| 1196 |
+
if LOG:
|
| 1197 |
+
LOG('decoding %s as untagged CHOICE' % (tagSet,))
|
| 1198 |
+
|
| 1199 |
+
for component in decodeFun(
|
| 1200 |
+
substrate, asn1Object.componentTagMap, tagSet, length,
|
| 1201 |
+
state, **options):
|
| 1202 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 1203 |
+
yield component
|
| 1204 |
+
|
| 1205 |
+
effectiveTagSet = component.effectiveTagSet
|
| 1206 |
+
|
| 1207 |
+
if LOG:
|
| 1208 |
+
LOG('decoded component %s, effective tag set %s' % (component, effectiveTagSet))
|
| 1209 |
+
|
| 1210 |
+
asn1Object.setComponentByType(
|
| 1211 |
+
effectiveTagSet, component,
|
| 1212 |
+
verifyConstraints=False,
|
| 1213 |
+
matchTags=False, matchConstraints=False,
|
| 1214 |
+
innerFlag=False
|
| 1215 |
+
)
|
| 1216 |
+
|
| 1217 |
+
yield asn1Object
|
| 1218 |
+
|
| 1219 |
+
def indefLenValueDecoder(self, substrate, asn1Spec,
|
| 1220 |
+
tagSet=None, length=None, state=None,
|
| 1221 |
+
decodeFun=None, substrateFun=None,
|
| 1222 |
+
**options):
|
| 1223 |
+
if asn1Spec is None:
|
| 1224 |
+
asn1Object = self.protoComponent.clone(tagSet=tagSet)
|
| 1225 |
+
|
| 1226 |
+
else:
|
| 1227 |
+
asn1Object = asn1Spec.clone()
|
| 1228 |
+
|
| 1229 |
+
if substrateFun:
|
| 1230 |
+
for chunk in substrateFun(asn1Object, substrate, length, options):
|
| 1231 |
+
yield chunk
|
| 1232 |
+
|
| 1233 |
+
return
|
| 1234 |
+
|
| 1235 |
+
options = self._passAsn1Object(asn1Object, options)
|
| 1236 |
+
|
| 1237 |
+
isTagged = asn1Object.tagSet == tagSet
|
| 1238 |
+
|
| 1239 |
+
if LOG:
|
| 1240 |
+
LOG('decoding %s as %stagged CHOICE' % (
|
| 1241 |
+
tagSet, isTagged and 'explicitly ' or 'un'))
|
| 1242 |
+
|
| 1243 |
+
while True:
|
| 1244 |
+
|
| 1245 |
+
if isTagged:
|
| 1246 |
+
iterator = decodeFun(
|
| 1247 |
+
substrate, asn1Object.componentType.tagMapUnique,
|
| 1248 |
+
**dict(options, allowEoo=True))
|
| 1249 |
+
|
| 1250 |
+
else:
|
| 1251 |
+
iterator = decodeFun(
|
| 1252 |
+
substrate, asn1Object.componentType.tagMapUnique,
|
| 1253 |
+
tagSet, length, state, **dict(options, allowEoo=True))
|
| 1254 |
+
|
| 1255 |
+
for component in iterator:
|
| 1256 |
+
|
| 1257 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 1258 |
+
yield component
|
| 1259 |
+
|
| 1260 |
+
if component is eoo.endOfOctets:
|
| 1261 |
+
break
|
| 1262 |
+
|
| 1263 |
+
effectiveTagSet = component.effectiveTagSet
|
| 1264 |
+
|
| 1265 |
+
if LOG:
|
| 1266 |
+
LOG('decoded component %s, effective tag set '
|
| 1267 |
+
'%s' % (component, effectiveTagSet))
|
| 1268 |
+
|
| 1269 |
+
asn1Object.setComponentByType(
|
| 1270 |
+
effectiveTagSet, component,
|
| 1271 |
+
verifyConstraints=False,
|
| 1272 |
+
matchTags=False, matchConstraints=False,
|
| 1273 |
+
innerFlag=False
|
| 1274 |
+
)
|
| 1275 |
+
|
| 1276 |
+
if not isTagged:
|
| 1277 |
+
break
|
| 1278 |
+
|
| 1279 |
+
if not isTagged or component is eoo.endOfOctets:
|
| 1280 |
+
break
|
| 1281 |
+
|
| 1282 |
+
yield asn1Object
|
| 1283 |
+
|
| 1284 |
+
|
| 1285 |
+
class AnyPayloadDecoder(AbstractSimplePayloadDecoder):
|
| 1286 |
+
protoComponent = univ.Any()
|
| 1287 |
+
|
| 1288 |
+
def valueDecoder(self, substrate, asn1Spec,
|
| 1289 |
+
tagSet=None, length=None, state=None,
|
| 1290 |
+
decodeFun=None, substrateFun=None,
|
| 1291 |
+
**options):
|
| 1292 |
+
if asn1Spec is None:
|
| 1293 |
+
isUntagged = True
|
| 1294 |
+
|
| 1295 |
+
elif asn1Spec.__class__ is tagmap.TagMap:
|
| 1296 |
+
isUntagged = tagSet not in asn1Spec.tagMap
|
| 1297 |
+
|
| 1298 |
+
else:
|
| 1299 |
+
isUntagged = tagSet != asn1Spec.tagSet
|
| 1300 |
+
|
| 1301 |
+
if isUntagged:
|
| 1302 |
+
fullPosition = substrate.markedPosition
|
| 1303 |
+
currentPosition = substrate.tell()
|
| 1304 |
+
|
| 1305 |
+
substrate.seek(fullPosition, os.SEEK_SET)
|
| 1306 |
+
length += currentPosition - fullPosition
|
| 1307 |
+
|
| 1308 |
+
if LOG:
|
| 1309 |
+
for chunk in peekIntoStream(substrate, length):
|
| 1310 |
+
if isinstance(chunk, SubstrateUnderrunError):
|
| 1311 |
+
yield chunk
|
| 1312 |
+
LOG('decoding as untagged ANY, substrate '
|
| 1313 |
+
'%s' % debug.hexdump(chunk))
|
| 1314 |
+
|
| 1315 |
+
if substrateFun:
|
| 1316 |
+
for chunk in substrateFun(
|
| 1317 |
+
self._createComponent(asn1Spec, tagSet, noValue, **options),
|
| 1318 |
+
substrate, length, options):
|
| 1319 |
+
yield chunk
|
| 1320 |
+
|
| 1321 |
+
return
|
| 1322 |
+
|
| 1323 |
+
for chunk in readFromStream(substrate, length, options):
|
| 1324 |
+
if isinstance(chunk, SubstrateUnderrunError):
|
| 1325 |
+
yield chunk
|
| 1326 |
+
|
| 1327 |
+
yield self._createComponent(asn1Spec, tagSet, chunk, **options)
|
| 1328 |
+
|
| 1329 |
+
def indefLenValueDecoder(self, substrate, asn1Spec,
|
| 1330 |
+
tagSet=None, length=None, state=None,
|
| 1331 |
+
decodeFun=None, substrateFun=None,
|
| 1332 |
+
**options):
|
| 1333 |
+
if asn1Spec is None:
|
| 1334 |
+
isTagged = False
|
| 1335 |
+
|
| 1336 |
+
elif asn1Spec.__class__ is tagmap.TagMap:
|
| 1337 |
+
isTagged = tagSet in asn1Spec.tagMap
|
| 1338 |
+
|
| 1339 |
+
else:
|
| 1340 |
+
isTagged = tagSet == asn1Spec.tagSet
|
| 1341 |
+
|
| 1342 |
+
if isTagged:
|
| 1343 |
+
# tagged Any type -- consume header substrate
|
| 1344 |
+
chunk = b''
|
| 1345 |
+
|
| 1346 |
+
if LOG:
|
| 1347 |
+
LOG('decoding as tagged ANY')
|
| 1348 |
+
|
| 1349 |
+
else:
|
| 1350 |
+
# TODO: Seems not to be tested
|
| 1351 |
+
fullPosition = substrate.markedPosition
|
| 1352 |
+
currentPosition = substrate.tell()
|
| 1353 |
+
|
| 1354 |
+
substrate.seek(fullPosition, os.SEEK_SET)
|
| 1355 |
+
for chunk in readFromStream(substrate, currentPosition - fullPosition, options):
|
| 1356 |
+
if isinstance(chunk, SubstrateUnderrunError):
|
| 1357 |
+
yield chunk
|
| 1358 |
+
|
| 1359 |
+
if LOG:
|
| 1360 |
+
LOG('decoding as untagged ANY, header substrate %s' % debug.hexdump(chunk))
|
| 1361 |
+
|
| 1362 |
+
# Any components do not inherit initial tag
|
| 1363 |
+
asn1Spec = self.protoComponent
|
| 1364 |
+
|
| 1365 |
+
if substrateFun and substrateFun is not self.substrateCollector:
|
| 1366 |
+
asn1Object = self._createComponent(
|
| 1367 |
+
asn1Spec, tagSet, noValue, **options)
|
| 1368 |
+
|
| 1369 |
+
for chunk in substrateFun(
|
| 1370 |
+
asn1Object, chunk + substrate, length + len(chunk), options):
|
| 1371 |
+
yield chunk
|
| 1372 |
+
|
| 1373 |
+
return
|
| 1374 |
+
|
| 1375 |
+
if LOG:
|
| 1376 |
+
LOG('assembling constructed serialization')
|
| 1377 |
+
|
| 1378 |
+
# All inner fragments are of the same type, treat them as octet string
|
| 1379 |
+
substrateFun = self.substrateCollector
|
| 1380 |
+
|
| 1381 |
+
while True: # loop over fragments
|
| 1382 |
+
|
| 1383 |
+
for component in decodeFun(
|
| 1384 |
+
substrate, asn1Spec, substrateFun=substrateFun,
|
| 1385 |
+
allowEoo=True, **options):
|
| 1386 |
+
|
| 1387 |
+
if isinstance(component, SubstrateUnderrunError):
|
| 1388 |
+
yield component
|
| 1389 |
+
|
| 1390 |
+
if component is eoo.endOfOctets:
|
| 1391 |
+
break
|
| 1392 |
+
|
| 1393 |
+
if component is eoo.endOfOctets:
|
| 1394 |
+
break
|
| 1395 |
+
|
| 1396 |
+
chunk += component
|
| 1397 |
+
|
| 1398 |
+
if substrateFun:
|
| 1399 |
+
yield chunk # TODO: Weird
|
| 1400 |
+
|
| 1401 |
+
else:
|
| 1402 |
+
yield self._createComponent(asn1Spec, tagSet, chunk, **options)
|
| 1403 |
+
|
| 1404 |
+
|
| 1405 |
+
# character string types
|
| 1406 |
+
class UTF8StringPayloadDecoder(OctetStringPayloadDecoder):
|
| 1407 |
+
protoComponent = char.UTF8String()
|
| 1408 |
+
|
| 1409 |
+
|
| 1410 |
+
class NumericStringPayloadDecoder(OctetStringPayloadDecoder):
|
| 1411 |
+
protoComponent = char.NumericString()
|
| 1412 |
+
|
| 1413 |
+
|
| 1414 |
+
class PrintableStringPayloadDecoder(OctetStringPayloadDecoder):
|
| 1415 |
+
protoComponent = char.PrintableString()
|
| 1416 |
+
|
| 1417 |
+
|
| 1418 |
+
class TeletexStringPayloadDecoder(OctetStringPayloadDecoder):
|
| 1419 |
+
protoComponent = char.TeletexString()
|
| 1420 |
+
|
| 1421 |
+
|
| 1422 |
+
class VideotexStringPayloadDecoder(OctetStringPayloadDecoder):
|
| 1423 |
+
protoComponent = char.VideotexString()
|
| 1424 |
+
|
| 1425 |
+
|
| 1426 |
+
class IA5StringPayloadDecoder(OctetStringPayloadDecoder):
|
| 1427 |
+
protoComponent = char.IA5String()
|
| 1428 |
+
|
| 1429 |
+
|
| 1430 |
+
class GraphicStringPayloadDecoder(OctetStringPayloadDecoder):
|
| 1431 |
+
protoComponent = char.GraphicString()
|
| 1432 |
+
|
| 1433 |
+
|
| 1434 |
+
class VisibleStringPayloadDecoder(OctetStringPayloadDecoder):
|
| 1435 |
+
protoComponent = char.VisibleString()
|
| 1436 |
+
|
| 1437 |
+
|
| 1438 |
+
class GeneralStringPayloadDecoder(OctetStringPayloadDecoder):
|
| 1439 |
+
protoComponent = char.GeneralString()
|
| 1440 |
+
|
| 1441 |
+
|
| 1442 |
+
class UniversalStringPayloadDecoder(OctetStringPayloadDecoder):
|
| 1443 |
+
protoComponent = char.UniversalString()
|
| 1444 |
+
|
| 1445 |
+
|
| 1446 |
+
class BMPStringPayloadDecoder(OctetStringPayloadDecoder):
|
| 1447 |
+
protoComponent = char.BMPString()
|
| 1448 |
+
|
| 1449 |
+
|
| 1450 |
+
# "useful" types
|
| 1451 |
+
class ObjectDescriptorPayloadDecoder(OctetStringPayloadDecoder):
|
| 1452 |
+
protoComponent = useful.ObjectDescriptor()
|
| 1453 |
+
|
| 1454 |
+
|
| 1455 |
+
class GeneralizedTimePayloadDecoder(OctetStringPayloadDecoder):
|
| 1456 |
+
protoComponent = useful.GeneralizedTime()
|
| 1457 |
+
|
| 1458 |
+
|
| 1459 |
+
class UTCTimePayloadDecoder(OctetStringPayloadDecoder):
|
| 1460 |
+
protoComponent = useful.UTCTime()
|
| 1461 |
+
|
| 1462 |
+
|
| 1463 |
+
TAG_MAP = {
|
| 1464 |
+
univ.Integer.tagSet: IntegerPayloadDecoder(),
|
| 1465 |
+
univ.Boolean.tagSet: BooleanPayloadDecoder(),
|
| 1466 |
+
univ.BitString.tagSet: BitStringPayloadDecoder(),
|
| 1467 |
+
univ.OctetString.tagSet: OctetStringPayloadDecoder(),
|
| 1468 |
+
univ.Null.tagSet: NullPayloadDecoder(),
|
| 1469 |
+
univ.ObjectIdentifier.tagSet: ObjectIdentifierPayloadDecoder(),
|
| 1470 |
+
univ.RelativeOID.tagSet: RelativeOIDPayloadDecoder(),
|
| 1471 |
+
univ.Enumerated.tagSet: IntegerPayloadDecoder(),
|
| 1472 |
+
univ.Real.tagSet: RealPayloadDecoder(),
|
| 1473 |
+
univ.Sequence.tagSet: SequenceOrSequenceOfPayloadDecoder(), # conflicts with SequenceOf
|
| 1474 |
+
univ.Set.tagSet: SetOrSetOfPayloadDecoder(), # conflicts with SetOf
|
| 1475 |
+
univ.Choice.tagSet: ChoicePayloadDecoder(), # conflicts with Any
|
| 1476 |
+
# character string types
|
| 1477 |
+
char.UTF8String.tagSet: UTF8StringPayloadDecoder(),
|
| 1478 |
+
char.NumericString.tagSet: NumericStringPayloadDecoder(),
|
| 1479 |
+
char.PrintableString.tagSet: PrintableStringPayloadDecoder(),
|
| 1480 |
+
char.TeletexString.tagSet: TeletexStringPayloadDecoder(),
|
| 1481 |
+
char.VideotexString.tagSet: VideotexStringPayloadDecoder(),
|
| 1482 |
+
char.IA5String.tagSet: IA5StringPayloadDecoder(),
|
| 1483 |
+
char.GraphicString.tagSet: GraphicStringPayloadDecoder(),
|
| 1484 |
+
char.VisibleString.tagSet: VisibleStringPayloadDecoder(),
|
| 1485 |
+
char.GeneralString.tagSet: GeneralStringPayloadDecoder(),
|
| 1486 |
+
char.UniversalString.tagSet: UniversalStringPayloadDecoder(),
|
| 1487 |
+
char.BMPString.tagSet: BMPStringPayloadDecoder(),
|
| 1488 |
+
# useful types
|
| 1489 |
+
useful.ObjectDescriptor.tagSet: ObjectDescriptorPayloadDecoder(),
|
| 1490 |
+
useful.GeneralizedTime.tagSet: GeneralizedTimePayloadDecoder(),
|
| 1491 |
+
useful.UTCTime.tagSet: UTCTimePayloadDecoder()
|
| 1492 |
+
}
|
| 1493 |
+
|
| 1494 |
+
# Type-to-codec map for ambiguous ASN.1 types
|
| 1495 |
+
TYPE_MAP = {
|
| 1496 |
+
univ.Set.typeId: SetPayloadDecoder(),
|
| 1497 |
+
univ.SetOf.typeId: SetOfPayloadDecoder(),
|
| 1498 |
+
univ.Sequence.typeId: SequencePayloadDecoder(),
|
| 1499 |
+
univ.SequenceOf.typeId: SequenceOfPayloadDecoder(),
|
| 1500 |
+
univ.Choice.typeId: ChoicePayloadDecoder(),
|
| 1501 |
+
univ.Any.typeId: AnyPayloadDecoder()
|
| 1502 |
+
}
|
| 1503 |
+
|
| 1504 |
+
# Put in non-ambiguous types for faster codec lookup
|
| 1505 |
+
for typeDecoder in TAG_MAP.values():
|
| 1506 |
+
if typeDecoder.protoComponent is not None:
|
| 1507 |
+
typeId = typeDecoder.protoComponent.__class__.typeId
|
| 1508 |
+
if typeId is not None and typeId not in TYPE_MAP:
|
| 1509 |
+
TYPE_MAP[typeId] = typeDecoder
|
| 1510 |
+
|
| 1511 |
+
|
| 1512 |
+
(stDecodeTag,
|
| 1513 |
+
stDecodeLength,
|
| 1514 |
+
stGetValueDecoder,
|
| 1515 |
+
stGetValueDecoderByAsn1Spec,
|
| 1516 |
+
stGetValueDecoderByTag,
|
| 1517 |
+
stTryAsExplicitTag,
|
| 1518 |
+
stDecodeValue,
|
| 1519 |
+
stDumpRawValue,
|
| 1520 |
+
stErrorCondition,
|
| 1521 |
+
stStop) = [x for x in range(10)]
|
| 1522 |
+
|
| 1523 |
+
|
| 1524 |
+
EOO_SENTINEL = bytes((0, 0))
|
| 1525 |
+
|
| 1526 |
+
|
| 1527 |
+
class SingleItemDecoder(object):
|
| 1528 |
+
defaultErrorState = stErrorCondition
|
| 1529 |
+
#defaultErrorState = stDumpRawValue
|
| 1530 |
+
defaultRawDecoder = AnyPayloadDecoder()
|
| 1531 |
+
|
| 1532 |
+
supportIndefLength = True
|
| 1533 |
+
|
| 1534 |
+
TAG_MAP = TAG_MAP
|
| 1535 |
+
TYPE_MAP = TYPE_MAP
|
| 1536 |
+
|
| 1537 |
+
def __init__(self, tagMap=_MISSING, typeMap=_MISSING, **ignored):
|
| 1538 |
+
self._tagMap = tagMap if tagMap is not _MISSING else self.TAG_MAP
|
| 1539 |
+
self._typeMap = typeMap if typeMap is not _MISSING else self.TYPE_MAP
|
| 1540 |
+
|
| 1541 |
+
# Tag & TagSet objects caches
|
| 1542 |
+
self._tagCache = {}
|
| 1543 |
+
self._tagSetCache = {}
|
| 1544 |
+
|
| 1545 |
+
def __call__(self, substrate, asn1Spec=None,
|
| 1546 |
+
tagSet=None, length=None, state=stDecodeTag,
|
| 1547 |
+
decodeFun=None, substrateFun=None,
|
| 1548 |
+
**options):
|
| 1549 |
+
|
| 1550 |
+
allowEoo = options.pop('allowEoo', False)
|
| 1551 |
+
|
| 1552 |
+
if LOG:
|
| 1553 |
+
LOG('decoder called at scope %s with state %d, working with up '
|
| 1554 |
+
'to %s octets of substrate: '
|
| 1555 |
+
'%s' % (debug.scope, state, length, substrate))
|
| 1556 |
+
|
| 1557 |
+
# Look for end-of-octets sentinel
|
| 1558 |
+
if allowEoo and self.supportIndefLength:
|
| 1559 |
+
|
| 1560 |
+
for eoo_candidate in readFromStream(substrate, 2, options):
|
| 1561 |
+
if isinstance(eoo_candidate, SubstrateUnderrunError):
|
| 1562 |
+
yield eoo_candidate
|
| 1563 |
+
|
| 1564 |
+
if eoo_candidate == EOO_SENTINEL:
|
| 1565 |
+
if LOG:
|
| 1566 |
+
LOG('end-of-octets sentinel found')
|
| 1567 |
+
yield eoo.endOfOctets
|
| 1568 |
+
return
|
| 1569 |
+
|
| 1570 |
+
else:
|
| 1571 |
+
substrate.seek(-2, os.SEEK_CUR)
|
| 1572 |
+
|
| 1573 |
+
tagMap = self._tagMap
|
| 1574 |
+
typeMap = self._typeMap
|
| 1575 |
+
tagCache = self._tagCache
|
| 1576 |
+
tagSetCache = self._tagSetCache
|
| 1577 |
+
|
| 1578 |
+
value = noValue
|
| 1579 |
+
|
| 1580 |
+
substrate.markedPosition = substrate.tell()
|
| 1581 |
+
|
| 1582 |
+
while state is not stStop:
|
| 1583 |
+
|
| 1584 |
+
if state is stDecodeTag:
|
| 1585 |
+
# Decode tag
|
| 1586 |
+
isShortTag = True
|
| 1587 |
+
|
| 1588 |
+
for firstByte in readFromStream(substrate, 1, options):
|
| 1589 |
+
if isinstance(firstByte, SubstrateUnderrunError):
|
| 1590 |
+
yield firstByte
|
| 1591 |
+
|
| 1592 |
+
firstOctet = ord(firstByte)
|
| 1593 |
+
|
| 1594 |
+
try:
|
| 1595 |
+
lastTag = tagCache[firstOctet]
|
| 1596 |
+
|
| 1597 |
+
except KeyError:
|
| 1598 |
+
integerTag = firstOctet
|
| 1599 |
+
tagClass = integerTag & 0xC0
|
| 1600 |
+
tagFormat = integerTag & 0x20
|
| 1601 |
+
tagId = integerTag & 0x1F
|
| 1602 |
+
|
| 1603 |
+
if tagId == 0x1F:
|
| 1604 |
+
isShortTag = False
|
| 1605 |
+
lengthOctetIdx = 0
|
| 1606 |
+
tagId = 0
|
| 1607 |
+
|
| 1608 |
+
while True:
|
| 1609 |
+
for integerByte in readFromStream(substrate, 1, options):
|
| 1610 |
+
if isinstance(integerByte, SubstrateUnderrunError):
|
| 1611 |
+
yield integerByte
|
| 1612 |
+
|
| 1613 |
+
if not integerByte:
|
| 1614 |
+
raise error.SubstrateUnderrunError(
|
| 1615 |
+
'Short octet stream on long tag decoding'
|
| 1616 |
+
)
|
| 1617 |
+
|
| 1618 |
+
integerTag = ord(integerByte)
|
| 1619 |
+
lengthOctetIdx += 1
|
| 1620 |
+
tagId <<= 7
|
| 1621 |
+
tagId |= (integerTag & 0x7F)
|
| 1622 |
+
|
| 1623 |
+
if not integerTag & 0x80:
|
| 1624 |
+
break
|
| 1625 |
+
|
| 1626 |
+
lastTag = tag.Tag(
|
| 1627 |
+
tagClass=tagClass, tagFormat=tagFormat, tagId=tagId
|
| 1628 |
+
)
|
| 1629 |
+
|
| 1630 |
+
if isShortTag:
|
| 1631 |
+
# cache short tags
|
| 1632 |
+
tagCache[firstOctet] = lastTag
|
| 1633 |
+
|
| 1634 |
+
if tagSet is None:
|
| 1635 |
+
if isShortTag:
|
| 1636 |
+
try:
|
| 1637 |
+
tagSet = tagSetCache[firstOctet]
|
| 1638 |
+
|
| 1639 |
+
except KeyError:
|
| 1640 |
+
# base tag not recovered
|
| 1641 |
+
tagSet = tag.TagSet((), lastTag)
|
| 1642 |
+
tagSetCache[firstOctet] = tagSet
|
| 1643 |
+
else:
|
| 1644 |
+
tagSet = tag.TagSet((), lastTag)
|
| 1645 |
+
|
| 1646 |
+
else:
|
| 1647 |
+
tagSet = lastTag + tagSet
|
| 1648 |
+
|
| 1649 |
+
state = stDecodeLength
|
| 1650 |
+
|
| 1651 |
+
if LOG:
|
| 1652 |
+
LOG('tag decoded into %s, decoding length' % tagSet)
|
| 1653 |
+
|
| 1654 |
+
if state is stDecodeLength:
|
| 1655 |
+
# Decode length
|
| 1656 |
+
for firstOctet in readFromStream(substrate, 1, options):
|
| 1657 |
+
if isinstance(firstOctet, SubstrateUnderrunError):
|
| 1658 |
+
yield firstOctet
|
| 1659 |
+
|
| 1660 |
+
firstOctet = ord(firstOctet)
|
| 1661 |
+
|
| 1662 |
+
if firstOctet < 128:
|
| 1663 |
+
length = firstOctet
|
| 1664 |
+
|
| 1665 |
+
elif firstOctet > 128:
|
| 1666 |
+
size = firstOctet & 0x7F
|
| 1667 |
+
# encoded in size bytes
|
| 1668 |
+
for encodedLength in readFromStream(substrate, size, options):
|
| 1669 |
+
if isinstance(encodedLength, SubstrateUnderrunError):
|
| 1670 |
+
yield encodedLength
|
| 1671 |
+
encodedLength = list(encodedLength)
|
| 1672 |
+
# missing check on maximum size, which shouldn't be a
|
| 1673 |
+
# problem, we can handle more than is possible
|
| 1674 |
+
if len(encodedLength) != size:
|
| 1675 |
+
raise error.SubstrateUnderrunError(
|
| 1676 |
+
'%s<%s at %s' % (size, len(encodedLength), tagSet)
|
| 1677 |
+
)
|
| 1678 |
+
|
| 1679 |
+
length = 0
|
| 1680 |
+
for lengthOctet in encodedLength:
|
| 1681 |
+
length <<= 8
|
| 1682 |
+
length |= lengthOctet
|
| 1683 |
+
size += 1
|
| 1684 |
+
|
| 1685 |
+
else: # 128 means indefinite
|
| 1686 |
+
length = -1
|
| 1687 |
+
|
| 1688 |
+
if length == -1 and not self.supportIndefLength:
|
| 1689 |
+
raise error.PyAsn1Error('Indefinite length encoding not supported by this codec')
|
| 1690 |
+
|
| 1691 |
+
state = stGetValueDecoder
|
| 1692 |
+
|
| 1693 |
+
if LOG:
|
| 1694 |
+
LOG('value length decoded into %d' % length)
|
| 1695 |
+
|
| 1696 |
+
if state is stGetValueDecoder:
|
| 1697 |
+
if asn1Spec is None:
|
| 1698 |
+
state = stGetValueDecoderByTag
|
| 1699 |
+
|
| 1700 |
+
else:
|
| 1701 |
+
state = stGetValueDecoderByAsn1Spec
|
| 1702 |
+
#
|
| 1703 |
+
# There're two ways of creating subtypes in ASN.1 what influences
|
| 1704 |
+
# decoder operation. These methods are:
|
| 1705 |
+
# 1) Either base types used in or no IMPLICIT tagging has been
|
| 1706 |
+
# applied on subtyping.
|
| 1707 |
+
# 2) Subtype syntax drops base type information (by means of
|
| 1708 |
+
# IMPLICIT tagging.
|
| 1709 |
+
# The first case allows for complete tag recovery from substrate
|
| 1710 |
+
# while the second one requires original ASN.1 type spec for
|
| 1711 |
+
# decoding.
|
| 1712 |
+
#
|
| 1713 |
+
# In either case a set of tags (tagSet) is coming from substrate
|
| 1714 |
+
# in an incremental, tag-by-tag fashion (this is the case of
|
| 1715 |
+
# EXPLICIT tag which is most basic). Outermost tag comes first
|
| 1716 |
+
# from the wire.
|
| 1717 |
+
#
|
| 1718 |
+
if state is stGetValueDecoderByTag:
|
| 1719 |
+
try:
|
| 1720 |
+
concreteDecoder = tagMap[tagSet]
|
| 1721 |
+
|
| 1722 |
+
except KeyError:
|
| 1723 |
+
concreteDecoder = None
|
| 1724 |
+
|
| 1725 |
+
if concreteDecoder:
|
| 1726 |
+
state = stDecodeValue
|
| 1727 |
+
|
| 1728 |
+
else:
|
| 1729 |
+
try:
|
| 1730 |
+
concreteDecoder = tagMap[tagSet[:1]]
|
| 1731 |
+
|
| 1732 |
+
except KeyError:
|
| 1733 |
+
concreteDecoder = None
|
| 1734 |
+
|
| 1735 |
+
if concreteDecoder:
|
| 1736 |
+
state = stDecodeValue
|
| 1737 |
+
else:
|
| 1738 |
+
state = stTryAsExplicitTag
|
| 1739 |
+
|
| 1740 |
+
if LOG:
|
| 1741 |
+
LOG('codec %s chosen by a built-in type, decoding %s' % (concreteDecoder and concreteDecoder.__class__.__name__ or "<none>", state is stDecodeValue and 'value' or 'as explicit tag'))
|
| 1742 |
+
debug.scope.push(concreteDecoder is None and '?' or concreteDecoder.protoComponent.__class__.__name__)
|
| 1743 |
+
|
| 1744 |
+
if state is stGetValueDecoderByAsn1Spec:
|
| 1745 |
+
|
| 1746 |
+
if asn1Spec.__class__ is tagmap.TagMap:
|
| 1747 |
+
try:
|
| 1748 |
+
chosenSpec = asn1Spec[tagSet]
|
| 1749 |
+
|
| 1750 |
+
except KeyError:
|
| 1751 |
+
chosenSpec = None
|
| 1752 |
+
|
| 1753 |
+
if LOG:
|
| 1754 |
+
LOG('candidate ASN.1 spec is a map of:')
|
| 1755 |
+
|
| 1756 |
+
for firstOctet, v in asn1Spec.presentTypes.items():
|
| 1757 |
+
LOG(' %s -> %s' % (firstOctet, v.__class__.__name__))
|
| 1758 |
+
|
| 1759 |
+
if asn1Spec.skipTypes:
|
| 1760 |
+
LOG('but neither of: ')
|
| 1761 |
+
for firstOctet, v in asn1Spec.skipTypes.items():
|
| 1762 |
+
LOG(' %s -> %s' % (firstOctet, v.__class__.__name__))
|
| 1763 |
+
LOG('new candidate ASN.1 spec is %s, chosen by %s' % (chosenSpec is None and '<none>' or chosenSpec.prettyPrintType(), tagSet))
|
| 1764 |
+
|
| 1765 |
+
elif tagSet == asn1Spec.tagSet or tagSet in asn1Spec.tagMap:
|
| 1766 |
+
chosenSpec = asn1Spec
|
| 1767 |
+
if LOG:
|
| 1768 |
+
LOG('candidate ASN.1 spec is %s' % asn1Spec.__class__.__name__)
|
| 1769 |
+
|
| 1770 |
+
else:
|
| 1771 |
+
chosenSpec = None
|
| 1772 |
+
|
| 1773 |
+
if chosenSpec is not None:
|
| 1774 |
+
try:
|
| 1775 |
+
# ambiguous type or just faster codec lookup
|
| 1776 |
+
concreteDecoder = typeMap[chosenSpec.typeId]
|
| 1777 |
+
|
| 1778 |
+
if LOG:
|
| 1779 |
+
LOG('value decoder chosen for an ambiguous type by type ID %s' % (chosenSpec.typeId,))
|
| 1780 |
+
|
| 1781 |
+
except KeyError:
|
| 1782 |
+
# use base type for codec lookup to recover untagged types
|
| 1783 |
+
baseTagSet = tag.TagSet(chosenSpec.tagSet.baseTag, chosenSpec.tagSet.baseTag)
|
| 1784 |
+
try:
|
| 1785 |
+
# base type or tagged subtype
|
| 1786 |
+
concreteDecoder = tagMap[baseTagSet]
|
| 1787 |
+
|
| 1788 |
+
if LOG:
|
| 1789 |
+
LOG('value decoder chosen by base %s' % (baseTagSet,))
|
| 1790 |
+
|
| 1791 |
+
except KeyError:
|
| 1792 |
+
concreteDecoder = None
|
| 1793 |
+
|
| 1794 |
+
if concreteDecoder:
|
| 1795 |
+
asn1Spec = chosenSpec
|
| 1796 |
+
state = stDecodeValue
|
| 1797 |
+
|
| 1798 |
+
else:
|
| 1799 |
+
state = stTryAsExplicitTag
|
| 1800 |
+
|
| 1801 |
+
else:
|
| 1802 |
+
concreteDecoder = None
|
| 1803 |
+
state = stTryAsExplicitTag
|
| 1804 |
+
|
| 1805 |
+
if LOG:
|
| 1806 |
+
LOG('codec %s chosen by ASN.1 spec, decoding %s' % (state is stDecodeValue and concreteDecoder.__class__.__name__ or "<none>", state is stDecodeValue and 'value' or 'as explicit tag'))
|
| 1807 |
+
debug.scope.push(chosenSpec is None and '?' or chosenSpec.__class__.__name__)
|
| 1808 |
+
|
| 1809 |
+
if state is stDecodeValue:
|
| 1810 |
+
if not options.get('recursiveFlag', True) and not substrateFun: # deprecate this
|
| 1811 |
+
def substrateFun(asn1Object, _substrate, _length, _options):
|
| 1812 |
+
"""Legacy hack to keep the recursiveFlag=False option supported.
|
| 1813 |
+
|
| 1814 |
+
The decode(..., substrateFun=userCallback) option was introduced in 0.1.4 as a generalization
|
| 1815 |
+
of the old recursiveFlag=False option. Users should pass their callback instead of using
|
| 1816 |
+
recursiveFlag.
|
| 1817 |
+
"""
|
| 1818 |
+
yield asn1Object
|
| 1819 |
+
|
| 1820 |
+
original_position = substrate.tell()
|
| 1821 |
+
|
| 1822 |
+
if length == -1: # indef length
|
| 1823 |
+
for value in concreteDecoder.indefLenValueDecoder(
|
| 1824 |
+
substrate, asn1Spec,
|
| 1825 |
+
tagSet, length, stGetValueDecoder,
|
| 1826 |
+
self, substrateFun, **options):
|
| 1827 |
+
if isinstance(value, SubstrateUnderrunError):
|
| 1828 |
+
yield value
|
| 1829 |
+
|
| 1830 |
+
else:
|
| 1831 |
+
for value in concreteDecoder.valueDecoder(
|
| 1832 |
+
substrate, asn1Spec,
|
| 1833 |
+
tagSet, length, stGetValueDecoder,
|
| 1834 |
+
self, substrateFun, **options):
|
| 1835 |
+
if isinstance(value, SubstrateUnderrunError):
|
| 1836 |
+
yield value
|
| 1837 |
+
|
| 1838 |
+
bytesRead = substrate.tell() - original_position
|
| 1839 |
+
if not substrateFun and bytesRead != length:
|
| 1840 |
+
raise PyAsn1Error(
|
| 1841 |
+
"Read %s bytes instead of expected %s." % (bytesRead, length))
|
| 1842 |
+
elif substrateFun and bytesRead > length:
|
| 1843 |
+
# custom substrateFun may be used for partial decoding, reading less is expected there
|
| 1844 |
+
raise PyAsn1Error(
|
| 1845 |
+
"Read %s bytes are more than expected %s." % (bytesRead, length))
|
| 1846 |
+
|
| 1847 |
+
if LOG:
|
| 1848 |
+
LOG('codec %s yields type %s, value:\n%s\n...' % (
|
| 1849 |
+
concreteDecoder.__class__.__name__, value.__class__.__name__,
|
| 1850 |
+
isinstance(value, base.Asn1Item) and value.prettyPrint() or value))
|
| 1851 |
+
|
| 1852 |
+
state = stStop
|
| 1853 |
+
break
|
| 1854 |
+
|
| 1855 |
+
if state is stTryAsExplicitTag:
|
| 1856 |
+
if (tagSet and
|
| 1857 |
+
tagSet[0].tagFormat == tag.tagFormatConstructed and
|
| 1858 |
+
tagSet[0].tagClass != tag.tagClassUniversal):
|
| 1859 |
+
# Assume explicit tagging
|
| 1860 |
+
concreteDecoder = rawPayloadDecoder
|
| 1861 |
+
state = stDecodeValue
|
| 1862 |
+
|
| 1863 |
+
else:
|
| 1864 |
+
concreteDecoder = None
|
| 1865 |
+
state = self.defaultErrorState
|
| 1866 |
+
|
| 1867 |
+
if LOG:
|
| 1868 |
+
LOG('codec %s chosen, decoding %s' % (concreteDecoder and concreteDecoder.__class__.__name__ or "<none>", state is stDecodeValue and 'value' or 'as failure'))
|
| 1869 |
+
|
| 1870 |
+
if state is stDumpRawValue:
|
| 1871 |
+
concreteDecoder = self.defaultRawDecoder
|
| 1872 |
+
|
| 1873 |
+
if LOG:
|
| 1874 |
+
LOG('codec %s chosen, decoding value' % concreteDecoder.__class__.__name__)
|
| 1875 |
+
|
| 1876 |
+
state = stDecodeValue
|
| 1877 |
+
|
| 1878 |
+
if state is stErrorCondition:
|
| 1879 |
+
raise error.PyAsn1Error(
|
| 1880 |
+
'%s not in asn1Spec: %r' % (tagSet, asn1Spec)
|
| 1881 |
+
)
|
| 1882 |
+
|
| 1883 |
+
if LOG:
|
| 1884 |
+
debug.scope.pop()
|
| 1885 |
+
LOG('decoder left scope %s, call completed' % debug.scope)
|
| 1886 |
+
|
| 1887 |
+
yield value
|
| 1888 |
+
|
| 1889 |
+
|
| 1890 |
+
class StreamingDecoder(object):
|
| 1891 |
+
"""Create an iterator that turns BER/CER/DER byte stream into ASN.1 objects.
|
| 1892 |
+
|
| 1893 |
+
On each iteration, consume whatever BER/CER/DER serialization is
|
| 1894 |
+
available in the `substrate` stream-like object and turns it into
|
| 1895 |
+
one or more, possibly nested, ASN.1 objects.
|
| 1896 |
+
|
| 1897 |
+
Parameters
|
| 1898 |
+
----------
|
| 1899 |
+
substrate: :py:class:`file`, :py:class:`io.BytesIO`
|
| 1900 |
+
BER/CER/DER serialization in form of a byte stream
|
| 1901 |
+
|
| 1902 |
+
Keyword Args
|
| 1903 |
+
------------
|
| 1904 |
+
asn1Spec: :py:class:`~pyasn1.type.base.PyAsn1Item`
|
| 1905 |
+
A pyasn1 type object to act as a template guiding the decoder.
|
| 1906 |
+
Depending on the ASN.1 structure being decoded, `asn1Spec` may
|
| 1907 |
+
or may not be required. One of the reasons why `asn1Spec` may
|
| 1908 |
+
me required is that ASN.1 structure is encoded in the *IMPLICIT*
|
| 1909 |
+
tagging mode.
|
| 1910 |
+
|
| 1911 |
+
Yields
|
| 1912 |
+
------
|
| 1913 |
+
: :py:class:`~pyasn1.type.base.PyAsn1Item`, :py:class:`~pyasn1.error.SubstrateUnderrunError`
|
| 1914 |
+
Decoded ASN.1 object (possibly, nested) or
|
| 1915 |
+
:py:class:`~pyasn1.error.SubstrateUnderrunError` object indicating
|
| 1916 |
+
insufficient BER/CER/DER serialization on input to fully recover ASN.1
|
| 1917 |
+
objects from it.
|
| 1918 |
+
|
| 1919 |
+
In the latter case the caller is advised to ensure some more data in
|
| 1920 |
+
the input stream, then call the iterator again. The decoder will resume
|
| 1921 |
+
the decoding process using the newly arrived data.
|
| 1922 |
+
|
| 1923 |
+
The `context` property of :py:class:`~pyasn1.error.SubstrateUnderrunError`
|
| 1924 |
+
object might hold a reference to the partially populated ASN.1 object
|
| 1925 |
+
being reconstructed.
|
| 1926 |
+
|
| 1927 |
+
Raises
|
| 1928 |
+
------
|
| 1929 |
+
~pyasn1.error.PyAsn1Error, ~pyasn1.error.EndOfStreamError
|
| 1930 |
+
`PyAsn1Error` on deserialization error, `EndOfStreamError` on
|
| 1931 |
+
premature stream closure.
|
| 1932 |
+
|
| 1933 |
+
Examples
|
| 1934 |
+
--------
|
| 1935 |
+
Decode BER serialisation without ASN.1 schema
|
| 1936 |
+
|
| 1937 |
+
.. code-block:: pycon
|
| 1938 |
+
|
| 1939 |
+
>>> stream = io.BytesIO(
|
| 1940 |
+
... b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03')
|
| 1941 |
+
>>>
|
| 1942 |
+
>>> for asn1Object in StreamingDecoder(stream):
|
| 1943 |
+
... print(asn1Object)
|
| 1944 |
+
>>>
|
| 1945 |
+
SequenceOf:
|
| 1946 |
+
1 2 3
|
| 1947 |
+
|
| 1948 |
+
Decode BER serialisation with ASN.1 schema
|
| 1949 |
+
|
| 1950 |
+
.. code-block:: pycon
|
| 1951 |
+
|
| 1952 |
+
>>> stream = io.BytesIO(
|
| 1953 |
+
... b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03')
|
| 1954 |
+
>>>
|
| 1955 |
+
>>> schema = SequenceOf(componentType=Integer())
|
| 1956 |
+
>>>
|
| 1957 |
+
>>> decoder = StreamingDecoder(stream, asn1Spec=schema)
|
| 1958 |
+
>>> for asn1Object in decoder:
|
| 1959 |
+
... print(asn1Object)
|
| 1960 |
+
>>>
|
| 1961 |
+
SequenceOf:
|
| 1962 |
+
1 2 3
|
| 1963 |
+
"""
|
| 1964 |
+
|
| 1965 |
+
SINGLE_ITEM_DECODER = SingleItemDecoder
|
| 1966 |
+
|
| 1967 |
+
def __init__(self, substrate, asn1Spec=None, **options):
|
| 1968 |
+
self._singleItemDecoder = self.SINGLE_ITEM_DECODER(**options)
|
| 1969 |
+
self._substrate = asSeekableStream(substrate)
|
| 1970 |
+
self._asn1Spec = asn1Spec
|
| 1971 |
+
self._options = options
|
| 1972 |
+
|
| 1973 |
+
def __iter__(self):
|
| 1974 |
+
while True:
|
| 1975 |
+
for asn1Object in self._singleItemDecoder(
|
| 1976 |
+
self._substrate, self._asn1Spec, **self._options):
|
| 1977 |
+
yield asn1Object
|
| 1978 |
+
|
| 1979 |
+
for chunk in isEndOfStream(self._substrate):
|
| 1980 |
+
if isinstance(chunk, SubstrateUnderrunError):
|
| 1981 |
+
yield
|
| 1982 |
+
|
| 1983 |
+
break
|
| 1984 |
+
|
| 1985 |
+
if chunk:
|
| 1986 |
+
break
|
| 1987 |
+
|
| 1988 |
+
|
| 1989 |
+
class Decoder(object):
|
| 1990 |
+
"""Create a BER decoder object.
|
| 1991 |
+
|
| 1992 |
+
Parse BER/CER/DER octet-stream into one, possibly nested, ASN.1 object.
|
| 1993 |
+
"""
|
| 1994 |
+
STREAMING_DECODER = StreamingDecoder
|
| 1995 |
+
|
| 1996 |
+
@classmethod
|
| 1997 |
+
def __call__(cls, substrate, asn1Spec=None, **options):
|
| 1998 |
+
"""Turns BER/CER/DER octet stream into an ASN.1 object.
|
| 1999 |
+
|
| 2000 |
+
Takes BER/CER/DER octet-stream in form of :py:class:`bytes`
|
| 2001 |
+
and decode it into an ASN.1 object
|
| 2002 |
+
(e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative) which
|
| 2003 |
+
may be a scalar or an arbitrary nested structure.
|
| 2004 |
+
|
| 2005 |
+
Parameters
|
| 2006 |
+
----------
|
| 2007 |
+
substrate: :py:class:`bytes`
|
| 2008 |
+
BER/CER/DER octet-stream to parse
|
| 2009 |
+
|
| 2010 |
+
Keyword Args
|
| 2011 |
+
------------
|
| 2012 |
+
asn1Spec: :py:class:`~pyasn1.type.base.PyAsn1Item`
|
| 2013 |
+
A pyasn1 type object (:py:class:`~pyasn1.type.base.PyAsn1Item`
|
| 2014 |
+
derivative) to act as a template guiding the decoder.
|
| 2015 |
+
Depending on the ASN.1 structure being decoded, `asn1Spec` may or
|
| 2016 |
+
may not be required. Most common reason for it to require is that
|
| 2017 |
+
ASN.1 structure is encoded in *IMPLICIT* tagging mode.
|
| 2018 |
+
|
| 2019 |
+
substrateFun: :py:class:`Union[
|
| 2020 |
+
Callable[[pyasn1.type.base.PyAsn1Item, bytes, int],
|
| 2021 |
+
Tuple[pyasn1.type.base.PyAsn1Item, bytes]],
|
| 2022 |
+
Callable[[pyasn1.type.base.PyAsn1Item, io.BytesIO, int, dict],
|
| 2023 |
+
Generator[Union[pyasn1.type.base.PyAsn1Item,
|
| 2024 |
+
pyasn1.error.SubstrateUnderrunError],
|
| 2025 |
+
None, None]]
|
| 2026 |
+
]`
|
| 2027 |
+
User callback meant to generalize special use cases like non-recursive or
|
| 2028 |
+
partial decoding. A 3-arg non-streaming variant is supported for backwards
|
| 2029 |
+
compatiblilty in addition to the newer 4-arg streaming variant.
|
| 2030 |
+
The callback will receive the uninitialized object recovered from substrate
|
| 2031 |
+
as 1st argument, the uninterpreted payload as 2nd argument, and the length
|
| 2032 |
+
of the uninterpreted payload as 3rd argument. The streaming variant will
|
| 2033 |
+
additionally receive the decode(..., **options) kwargs as 4th argument.
|
| 2034 |
+
The non-streaming variant shall return an object that will be propagated
|
| 2035 |
+
as decode() return value as 1st item, and the remainig payload for further
|
| 2036 |
+
decode passes as 2nd item.
|
| 2037 |
+
The streaming variant shall yield an object that will be propagated as
|
| 2038 |
+
decode() return value, and leave the remaining payload in the stream.
|
| 2039 |
+
|
| 2040 |
+
Returns
|
| 2041 |
+
-------
|
| 2042 |
+
: :py:class:`tuple`
|
| 2043 |
+
A tuple of :py:class:`~pyasn1.type.base.PyAsn1Item` object
|
| 2044 |
+
recovered from BER/CER/DER substrate and the unprocessed trailing
|
| 2045 |
+
portion of the `substrate` (may be empty)
|
| 2046 |
+
|
| 2047 |
+
Raises
|
| 2048 |
+
------
|
| 2049 |
+
: :py:class:`~pyasn1.error.PyAsn1Error`
|
| 2050 |
+
:py:class:`~pyasn1.error.SubstrateUnderrunError` on insufficient
|
| 2051 |
+
input or :py:class:`~pyasn1.error.PyAsn1Error` on decoding error.
|
| 2052 |
+
|
| 2053 |
+
Examples
|
| 2054 |
+
--------
|
| 2055 |
+
Decode BER/CER/DER serialisation without ASN.1 schema
|
| 2056 |
+
|
| 2057 |
+
.. code-block:: pycon
|
| 2058 |
+
|
| 2059 |
+
>>> s, unprocessed = decode(b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03')
|
| 2060 |
+
>>> str(s)
|
| 2061 |
+
SequenceOf:
|
| 2062 |
+
1 2 3
|
| 2063 |
+
|
| 2064 |
+
Decode BER/CER/DER serialisation with ASN.1 schema
|
| 2065 |
+
|
| 2066 |
+
.. code-block:: pycon
|
| 2067 |
+
|
| 2068 |
+
>>> seq = SequenceOf(componentType=Integer())
|
| 2069 |
+
>>> s, unprocessed = decode(
|
| 2070 |
+
b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03', asn1Spec=seq)
|
| 2071 |
+
>>> str(s)
|
| 2072 |
+
SequenceOf:
|
| 2073 |
+
1 2 3
|
| 2074 |
+
|
| 2075 |
+
"""
|
| 2076 |
+
substrate = asSeekableStream(substrate)
|
| 2077 |
+
|
| 2078 |
+
if "substrateFun" in options:
|
| 2079 |
+
origSubstrateFun = options["substrateFun"]
|
| 2080 |
+
|
| 2081 |
+
def substrateFunWrapper(asn1Object, substrate, length, options=None):
|
| 2082 |
+
"""Support both 0.4 and 0.5 style APIs.
|
| 2083 |
+
|
| 2084 |
+
substrateFun API has changed in 0.5 for use with streaming decoders. To stay backwards compatible,
|
| 2085 |
+
we first try if we received a streaming user callback. If that fails,we assume we've received a
|
| 2086 |
+
non-streaming v0.4 user callback and convert it for streaming on the fly
|
| 2087 |
+
"""
|
| 2088 |
+
try:
|
| 2089 |
+
substrate_gen = origSubstrateFun(asn1Object, substrate, length, options)
|
| 2090 |
+
except TypeError as _value:
|
| 2091 |
+
if _value.__traceback__.tb_next:
|
| 2092 |
+
# Traceback depth > 1 means TypeError from inside user provided function
|
| 2093 |
+
raise
|
| 2094 |
+
# invariant maintained at Decoder.__call__ entry
|
| 2095 |
+
assert isinstance(substrate, io.BytesIO) # nosec assert_used
|
| 2096 |
+
substrate_gen = Decoder._callSubstrateFunV4asV5(origSubstrateFun, asn1Object, substrate, length)
|
| 2097 |
+
for value in substrate_gen:
|
| 2098 |
+
yield value
|
| 2099 |
+
|
| 2100 |
+
options["substrateFun"] = substrateFunWrapper
|
| 2101 |
+
|
| 2102 |
+
streamingDecoder = cls.STREAMING_DECODER(
|
| 2103 |
+
substrate, asn1Spec, **options)
|
| 2104 |
+
|
| 2105 |
+
for asn1Object in streamingDecoder:
|
| 2106 |
+
if isinstance(asn1Object, SubstrateUnderrunError):
|
| 2107 |
+
raise error.SubstrateUnderrunError('Short substrate on input')
|
| 2108 |
+
|
| 2109 |
+
try:
|
| 2110 |
+
tail = next(readFromStream(substrate))
|
| 2111 |
+
|
| 2112 |
+
except error.EndOfStreamError:
|
| 2113 |
+
tail = b''
|
| 2114 |
+
|
| 2115 |
+
return asn1Object, tail
|
| 2116 |
+
|
| 2117 |
+
@staticmethod
|
| 2118 |
+
def _callSubstrateFunV4asV5(substrateFunV4, asn1Object, substrate, length):
|
| 2119 |
+
substrate_bytes = substrate.read()
|
| 2120 |
+
if length == -1:
|
| 2121 |
+
length = len(substrate_bytes)
|
| 2122 |
+
value, nextSubstrate = substrateFunV4(asn1Object, substrate_bytes, length)
|
| 2123 |
+
nbytes = substrate.write(nextSubstrate)
|
| 2124 |
+
substrate.truncate()
|
| 2125 |
+
substrate.seek(-nbytes, os.SEEK_CUR)
|
| 2126 |
+
yield value
|
| 2127 |
+
|
| 2128 |
+
#: Turns BER octet stream into an ASN.1 object.
|
| 2129 |
+
#:
|
| 2130 |
+
#: Takes BER octet-stream and decode it into an ASN.1 object
|
| 2131 |
+
#: (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative) which
|
| 2132 |
+
#: may be a scalar or an arbitrary nested structure.
|
| 2133 |
+
#:
|
| 2134 |
+
#: Parameters
|
| 2135 |
+
#: ----------
|
| 2136 |
+
#: substrate: :py:class:`bytes`
|
| 2137 |
+
#: BER octet-stream
|
| 2138 |
+
#:
|
| 2139 |
+
#: Keyword Args
|
| 2140 |
+
#: ------------
|
| 2141 |
+
#: asn1Spec: any pyasn1 type object e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative
|
| 2142 |
+
#: A pyasn1 type object to act as a template guiding the decoder. Depending on the ASN.1 structure
|
| 2143 |
+
#: being decoded, *asn1Spec* may or may not be required. Most common reason for
|
| 2144 |
+
#: it to require is that ASN.1 structure is encoded in *IMPLICIT* tagging mode.
|
| 2145 |
+
#:
|
| 2146 |
+
#: Returns
|
| 2147 |
+
#: -------
|
| 2148 |
+
#: : :py:class:`tuple`
|
| 2149 |
+
#: A tuple of pyasn1 object recovered from BER substrate (:py:class:`~pyasn1.type.base.PyAsn1Item` derivative)
|
| 2150 |
+
#: and the unprocessed trailing portion of the *substrate* (may be empty)
|
| 2151 |
+
#:
|
| 2152 |
+
#: Raises
|
| 2153 |
+
#: ------
|
| 2154 |
+
#: ~pyasn1.error.PyAsn1Error, ~pyasn1.error.SubstrateUnderrunError
|
| 2155 |
+
#: On decoding errors
|
| 2156 |
+
#:
|
| 2157 |
+
#: Notes
|
| 2158 |
+
#: -----
|
| 2159 |
+
#: This function is deprecated. Please use :py:class:`Decoder` or
|
| 2160 |
+
#: :py:class:`StreamingDecoder` class instance.
|
| 2161 |
+
#:
|
| 2162 |
+
#: Examples
|
| 2163 |
+
#: --------
|
| 2164 |
+
#: Decode BER serialisation without ASN.1 schema
|
| 2165 |
+
#:
|
| 2166 |
+
#: .. code-block:: pycon
|
| 2167 |
+
#:
|
| 2168 |
+
#: >>> s, _ = decode(b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03')
|
| 2169 |
+
#: >>> str(s)
|
| 2170 |
+
#: SequenceOf:
|
| 2171 |
+
#: 1 2 3
|
| 2172 |
+
#:
|
| 2173 |
+
#: Decode BER serialisation with ASN.1 schema
|
| 2174 |
+
#:
|
| 2175 |
+
#: .. code-block:: pycon
|
| 2176 |
+
#:
|
| 2177 |
+
#: >>> seq = SequenceOf(componentType=Integer())
|
| 2178 |
+
#: >>> s, _ = decode(b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03', asn1Spec=seq)
|
| 2179 |
+
#: >>> str(s)
|
| 2180 |
+
#: SequenceOf:
|
| 2181 |
+
#: 1 2 3
|
| 2182 |
+
#:
|
| 2183 |
+
decode = Decoder()
|
| 2184 |
+
|
| 2185 |
+
def __getattr__(attr: str):
|
| 2186 |
+
if newAttr := {"tagMap": "TAG_MAP", "typeMap": "TYPE_MAP"}.get(attr):
|
| 2187 |
+
warnings.warn(f"{attr} is deprecated. Please use {newAttr} instead.", DeprecationWarning)
|
| 2188 |
+
return globals()[newAttr]
|
| 2189 |
+
raise AttributeError(attr)
|
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/encoder.py
ADDED
|
@@ -0,0 +1,954 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# This file is part of pyasn1 software.
|
| 3 |
+
#
|
| 4 |
+
# Copyright (c) 2005-2020, Ilya Etingof <etingof@gmail.com>
|
| 5 |
+
# License: https://pyasn1.readthedocs.io/en/latest/license.html
|
| 6 |
+
#
|
| 7 |
+
import sys
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
from pyasn1 import debug
|
| 11 |
+
from pyasn1 import error
|
| 12 |
+
from pyasn1.codec.ber import eoo
|
| 13 |
+
from pyasn1.compat import _MISSING
|
| 14 |
+
from pyasn1.compat.integer import to_bytes
|
| 15 |
+
from pyasn1.type import char
|
| 16 |
+
from pyasn1.type import tag
|
| 17 |
+
from pyasn1.type import univ
|
| 18 |
+
from pyasn1.type import useful
|
| 19 |
+
|
| 20 |
+
__all__ = ['Encoder', 'encode']
|
| 21 |
+
|
| 22 |
+
LOG = debug.registerLoggee(__name__, flags=debug.DEBUG_ENCODER)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AbstractItemEncoder(object):
|
| 26 |
+
supportIndefLenMode = True
|
| 27 |
+
|
| 28 |
+
# An outcome of otherwise legit call `encodeFun(eoo.endOfOctets)`
|
| 29 |
+
eooIntegerSubstrate = (0, 0)
|
| 30 |
+
eooOctetsSubstrate = bytes(eooIntegerSubstrate)
|
| 31 |
+
|
| 32 |
+
# noinspection PyMethodMayBeStatic
|
| 33 |
+
def encodeTag(self, singleTag, isConstructed):
|
| 34 |
+
tagClass, tagFormat, tagId = singleTag
|
| 35 |
+
encodedTag = tagClass | tagFormat
|
| 36 |
+
if isConstructed:
|
| 37 |
+
encodedTag |= tag.tagFormatConstructed
|
| 38 |
+
|
| 39 |
+
if tagId < 31:
|
| 40 |
+
return encodedTag | tagId,
|
| 41 |
+
|
| 42 |
+
else:
|
| 43 |
+
substrate = tagId & 0x7f,
|
| 44 |
+
|
| 45 |
+
tagId >>= 7
|
| 46 |
+
|
| 47 |
+
while tagId:
|
| 48 |
+
substrate = (0x80 | (tagId & 0x7f),) + substrate
|
| 49 |
+
tagId >>= 7
|
| 50 |
+
|
| 51 |
+
return (encodedTag | 0x1F,) + substrate
|
| 52 |
+
|
| 53 |
+
def encodeLength(self, length, defMode):
|
| 54 |
+
if not defMode and self.supportIndefLenMode:
|
| 55 |
+
return (0x80,)
|
| 56 |
+
|
| 57 |
+
if length < 0x80:
|
| 58 |
+
return length,
|
| 59 |
+
|
| 60 |
+
else:
|
| 61 |
+
substrate = ()
|
| 62 |
+
while length:
|
| 63 |
+
substrate = (length & 0xff,) + substrate
|
| 64 |
+
length >>= 8
|
| 65 |
+
|
| 66 |
+
substrateLen = len(substrate)
|
| 67 |
+
|
| 68 |
+
if substrateLen > 126:
|
| 69 |
+
raise error.PyAsn1Error('Length octets overflow (%d)' % substrateLen)
|
| 70 |
+
|
| 71 |
+
return (0x80 | substrateLen,) + substrate
|
| 72 |
+
|
| 73 |
+
def encodeValue(self, value, asn1Spec, encodeFun, **options):
|
| 74 |
+
raise error.PyAsn1Error('Not implemented')
|
| 75 |
+
|
| 76 |
+
def encode(self, value, asn1Spec=None, encodeFun=None, **options):
|
| 77 |
+
|
| 78 |
+
if asn1Spec is None:
|
| 79 |
+
tagSet = value.tagSet
|
| 80 |
+
else:
|
| 81 |
+
tagSet = asn1Spec.tagSet
|
| 82 |
+
|
| 83 |
+
# untagged item?
|
| 84 |
+
if not tagSet:
|
| 85 |
+
substrate, isConstructed, isOctets = self.encodeValue(
|
| 86 |
+
value, asn1Spec, encodeFun, **options
|
| 87 |
+
)
|
| 88 |
+
return substrate
|
| 89 |
+
|
| 90 |
+
defMode = options.get('defMode', True)
|
| 91 |
+
|
| 92 |
+
substrate = b''
|
| 93 |
+
|
| 94 |
+
for idx, singleTag in enumerate(tagSet.superTags):
|
| 95 |
+
|
| 96 |
+
defModeOverride = defMode
|
| 97 |
+
|
| 98 |
+
# base tag?
|
| 99 |
+
if not idx:
|
| 100 |
+
try:
|
| 101 |
+
substrate, isConstructed, isOctets = self.encodeValue(
|
| 102 |
+
value, asn1Spec, encodeFun, **options
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
except error.PyAsn1Error as exc:
|
| 106 |
+
raise error.PyAsn1Error(
|
| 107 |
+
'Error encoding %r: %s' % (value, exc))
|
| 108 |
+
|
| 109 |
+
if LOG:
|
| 110 |
+
LOG('encoded %svalue %s into %s' % (
|
| 111 |
+
isConstructed and 'constructed ' or '', value, substrate
|
| 112 |
+
))
|
| 113 |
+
|
| 114 |
+
if not substrate and isConstructed and options.get('ifNotEmpty', False):
|
| 115 |
+
return substrate
|
| 116 |
+
|
| 117 |
+
if not isConstructed:
|
| 118 |
+
defModeOverride = True
|
| 119 |
+
|
| 120 |
+
if LOG:
|
| 121 |
+
LOG('overridden encoding mode into definitive for primitive type')
|
| 122 |
+
|
| 123 |
+
header = self.encodeTag(singleTag, isConstructed)
|
| 124 |
+
|
| 125 |
+
if LOG:
|
| 126 |
+
LOG('encoded %stag %s into %s' % (
|
| 127 |
+
isConstructed and 'constructed ' or '',
|
| 128 |
+
singleTag, debug.hexdump(bytes(header))))
|
| 129 |
+
|
| 130 |
+
header += self.encodeLength(len(substrate), defModeOverride)
|
| 131 |
+
|
| 132 |
+
if LOG:
|
| 133 |
+
LOG('encoded %s octets (tag + payload) into %s' % (
|
| 134 |
+
len(substrate), debug.hexdump(bytes(header))))
|
| 135 |
+
|
| 136 |
+
if isOctets:
|
| 137 |
+
substrate = bytes(header) + substrate
|
| 138 |
+
|
| 139 |
+
if not defModeOverride:
|
| 140 |
+
substrate += self.eooOctetsSubstrate
|
| 141 |
+
|
| 142 |
+
else:
|
| 143 |
+
substrate = header + substrate
|
| 144 |
+
|
| 145 |
+
if not defModeOverride:
|
| 146 |
+
substrate += self.eooIntegerSubstrate
|
| 147 |
+
|
| 148 |
+
if not isOctets:
|
| 149 |
+
substrate = bytes(substrate)
|
| 150 |
+
|
| 151 |
+
return substrate
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class EndOfOctetsEncoder(AbstractItemEncoder):
|
| 155 |
+
def encodeValue(self, value, asn1Spec, encodeFun, **options):
|
| 156 |
+
return b'', False, True
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class BooleanEncoder(AbstractItemEncoder):
|
| 160 |
+
supportIndefLenMode = False
|
| 161 |
+
|
| 162 |
+
def encodeValue(self, value, asn1Spec, encodeFun, **options):
|
| 163 |
+
return value and (1,) or (0,), False, False
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class IntegerEncoder(AbstractItemEncoder):
|
| 167 |
+
supportIndefLenMode = False
|
| 168 |
+
supportCompactZero = False
|
| 169 |
+
|
| 170 |
+
def encodeValue(self, value, asn1Spec, encodeFun, **options):
|
| 171 |
+
if value == 0:
|
| 172 |
+
if LOG:
|
| 173 |
+
LOG('encoding %spayload for zero INTEGER' % (
|
| 174 |
+
self.supportCompactZero and 'no ' or ''
|
| 175 |
+
))
|
| 176 |
+
|
| 177 |
+
# de-facto way to encode zero
|
| 178 |
+
if self.supportCompactZero:
|
| 179 |
+
return (), False, False
|
| 180 |
+
else:
|
| 181 |
+
return (0,), False, False
|
| 182 |
+
|
| 183 |
+
return to_bytes(int(value), signed=True), False, True
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class BitStringEncoder(AbstractItemEncoder):
|
| 187 |
+
def encodeValue(self, value, asn1Spec, encodeFun, **options):
|
| 188 |
+
if asn1Spec is not None:
|
| 189 |
+
# TODO: try to avoid ASN.1 schema instantiation
|
| 190 |
+
value = asn1Spec.clone(value)
|
| 191 |
+
|
| 192 |
+
valueLength = len(value)
|
| 193 |
+
if valueLength % 8:
|
| 194 |
+
alignedValue = value << (8 - valueLength % 8)
|
| 195 |
+
else:
|
| 196 |
+
alignedValue = value
|
| 197 |
+
|
| 198 |
+
maxChunkSize = options.get('maxChunkSize', 0)
|
| 199 |
+
if not maxChunkSize or len(alignedValue) <= maxChunkSize * 8:
|
| 200 |
+
substrate = alignedValue.asOctets()
|
| 201 |
+
return bytes((len(substrate) * 8 - valueLength,)) + substrate, False, True
|
| 202 |
+
|
| 203 |
+
if LOG:
|
| 204 |
+
LOG('encoding into up to %s-octet chunks' % maxChunkSize)
|
| 205 |
+
|
| 206 |
+
baseTag = value.tagSet.baseTag
|
| 207 |
+
|
| 208 |
+
# strip off explicit tags
|
| 209 |
+
if baseTag:
|
| 210 |
+
tagSet = tag.TagSet(baseTag, baseTag)
|
| 211 |
+
|
| 212 |
+
else:
|
| 213 |
+
tagSet = tag.TagSet()
|
| 214 |
+
|
| 215 |
+
alignedValue = alignedValue.clone(tagSet=tagSet)
|
| 216 |
+
|
| 217 |
+
stop = 0
|
| 218 |
+
substrate = b''
|
| 219 |
+
while stop < valueLength:
|
| 220 |
+
start = stop
|
| 221 |
+
stop = min(start + maxChunkSize * 8, valueLength)
|
| 222 |
+
substrate += encodeFun(alignedValue[start:stop], asn1Spec, **options)
|
| 223 |
+
|
| 224 |
+
return substrate, True, True
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class OctetStringEncoder(AbstractItemEncoder):
|
| 228 |
+
|
| 229 |
+
def encodeValue(self, value, asn1Spec, encodeFun, **options):
|
| 230 |
+
|
| 231 |
+
if asn1Spec is None:
|
| 232 |
+
substrate = value.asOctets()
|
| 233 |
+
|
| 234 |
+
elif not isinstance(value, bytes):
|
| 235 |
+
substrate = asn1Spec.clone(value).asOctets()
|
| 236 |
+
|
| 237 |
+
else:
|
| 238 |
+
substrate = value
|
| 239 |
+
|
| 240 |
+
maxChunkSize = options.get('maxChunkSize', 0)
|
| 241 |
+
|
| 242 |
+
if not maxChunkSize or len(substrate) <= maxChunkSize:
|
| 243 |
+
return substrate, False, True
|
| 244 |
+
|
| 245 |
+
if LOG:
|
| 246 |
+
LOG('encoding into up to %s-octet chunks' % maxChunkSize)
|
| 247 |
+
|
| 248 |
+
# strip off explicit tags for inner chunks
|
| 249 |
+
|
| 250 |
+
if asn1Spec is None:
|
| 251 |
+
baseTag = value.tagSet.baseTag
|
| 252 |
+
|
| 253 |
+
# strip off explicit tags
|
| 254 |
+
if baseTag:
|
| 255 |
+
tagSet = tag.TagSet(baseTag, baseTag)
|
| 256 |
+
|
| 257 |
+
else:
|
| 258 |
+
tagSet = tag.TagSet()
|
| 259 |
+
|
| 260 |
+
asn1Spec = value.clone(tagSet=tagSet)
|
| 261 |
+
|
| 262 |
+
elif not isinstance(value, bytes):
|
| 263 |
+
baseTag = asn1Spec.tagSet.baseTag
|
| 264 |
+
|
| 265 |
+
# strip off explicit tags
|
| 266 |
+
if baseTag:
|
| 267 |
+
tagSet = tag.TagSet(baseTag, baseTag)
|
| 268 |
+
|
| 269 |
+
else:
|
| 270 |
+
tagSet = tag.TagSet()
|
| 271 |
+
|
| 272 |
+
asn1Spec = asn1Spec.clone(tagSet=tagSet)
|
| 273 |
+
|
| 274 |
+
pos = 0
|
| 275 |
+
substrate = b''
|
| 276 |
+
|
| 277 |
+
while True:
|
| 278 |
+
chunk = value[pos:pos + maxChunkSize]
|
| 279 |
+
if not chunk:
|
| 280 |
+
break
|
| 281 |
+
|
| 282 |
+
substrate += encodeFun(chunk, asn1Spec, **options)
|
| 283 |
+
pos += maxChunkSize
|
| 284 |
+
|
| 285 |
+
return substrate, True, True
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class NullEncoder(AbstractItemEncoder):
|
| 289 |
+
supportIndefLenMode = False
|
| 290 |
+
|
| 291 |
+
def encodeValue(self, value, asn1Spec, encodeFun, **options):
|
| 292 |
+
return b'', False, True
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class ObjectIdentifierEncoder(AbstractItemEncoder):
|
| 296 |
+
supportIndefLenMode = False
|
| 297 |
+
|
| 298 |
+
def encodeValue(self, value, asn1Spec, encodeFun, **options):
|
| 299 |
+
if asn1Spec is not None:
|
| 300 |
+
value = asn1Spec.clone(value)
|
| 301 |
+
|
| 302 |
+
oid = value.asTuple()
|
| 303 |
+
|
| 304 |
+
# Build the first pair
|
| 305 |
+
try:
|
| 306 |
+
first = oid[0]
|
| 307 |
+
second = oid[1]
|
| 308 |
+
|
| 309 |
+
except IndexError:
|
| 310 |
+
raise error.PyAsn1Error('Short OID %s' % (value,))
|
| 311 |
+
|
| 312 |
+
if 0 <= second <= 39:
|
| 313 |
+
if first == 1:
|
| 314 |
+
oid = (second + 40,) + oid[2:]
|
| 315 |
+
elif first == 0:
|
| 316 |
+
oid = (second,) + oid[2:]
|
| 317 |
+
elif first == 2:
|
| 318 |
+
oid = (second + 80,) + oid[2:]
|
| 319 |
+
else:
|
| 320 |
+
raise error.PyAsn1Error('Impossible first/second arcs at %s' % (value,))
|
| 321 |
+
|
| 322 |
+
elif first == 2:
|
| 323 |
+
oid = (second + 80,) + oid[2:]
|
| 324 |
+
|
| 325 |
+
else:
|
| 326 |
+
raise error.PyAsn1Error('Impossible first/second arcs at %s' % (value,))
|
| 327 |
+
|
| 328 |
+
octets = ()
|
| 329 |
+
|
| 330 |
+
# Cycle through subIds
|
| 331 |
+
for subOid in oid:
|
| 332 |
+
if 0 <= subOid <= 127:
|
| 333 |
+
# Optimize for the common case
|
| 334 |
+
octets += (subOid,)
|
| 335 |
+
|
| 336 |
+
elif subOid > 127:
|
| 337 |
+
# Pack large Sub-Object IDs
|
| 338 |
+
res = (subOid & 0x7f,)
|
| 339 |
+
subOid >>= 7
|
| 340 |
+
|
| 341 |
+
while subOid:
|
| 342 |
+
res = (0x80 | (subOid & 0x7f),) + res
|
| 343 |
+
subOid >>= 7
|
| 344 |
+
|
| 345 |
+
# Add packed Sub-Object ID to resulted Object ID
|
| 346 |
+
octets += res
|
| 347 |
+
|
| 348 |
+
else:
|
| 349 |
+
raise error.PyAsn1Error('Negative OID arc %s at %s' % (subOid, value))
|
| 350 |
+
|
| 351 |
+
return octets, False, False
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class RelativeOIDEncoder(AbstractItemEncoder):
|
| 355 |
+
supportIndefLenMode = False
|
| 356 |
+
|
| 357 |
+
def encodeValue(self, value, asn1Spec, encodeFun, **options):
|
| 358 |
+
if asn1Spec is not None:
|
| 359 |
+
value = asn1Spec.clone(value)
|
| 360 |
+
|
| 361 |
+
octets = ()
|
| 362 |
+
|
| 363 |
+
# Cycle through subIds
|
| 364 |
+
for subOid in value.asTuple():
|
| 365 |
+
if 0 <= subOid <= 127:
|
| 366 |
+
# Optimize for the common case
|
| 367 |
+
octets += (subOid,)
|
| 368 |
+
|
| 369 |
+
elif subOid > 127:
|
| 370 |
+
# Pack large Sub-Object IDs
|
| 371 |
+
res = (subOid & 0x7f,)
|
| 372 |
+
subOid >>= 7
|
| 373 |
+
|
| 374 |
+
while subOid:
|
| 375 |
+
res = (0x80 | (subOid & 0x7f),) + res
|
| 376 |
+
subOid >>= 7
|
| 377 |
+
|
| 378 |
+
# Add packed Sub-Object ID to resulted RELATIVE-OID
|
| 379 |
+
octets += res
|
| 380 |
+
|
| 381 |
+
else:
|
| 382 |
+
raise error.PyAsn1Error('Negative RELATIVE-OID arc %s at %s' % (subOid, value))
|
| 383 |
+
|
| 384 |
+
return octets, False, False
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class RealEncoder(AbstractItemEncoder):
|
| 388 |
+
supportIndefLenMode = False
|
| 389 |
+
binEncBase = 2 # set to None to choose encoding base automatically
|
| 390 |
+
|
| 391 |
+
@staticmethod
|
| 392 |
+
def _dropFloatingPoint(m, encbase, e):
|
| 393 |
+
ms, es = 1, 1
|
| 394 |
+
if m < 0:
|
| 395 |
+
ms = -1 # mantissa sign
|
| 396 |
+
|
| 397 |
+
if e < 0:
|
| 398 |
+
es = -1 # exponent sign
|
| 399 |
+
|
| 400 |
+
m *= ms
|
| 401 |
+
|
| 402 |
+
if encbase == 8:
|
| 403 |
+
m *= 2 ** (abs(e) % 3 * es)
|
| 404 |
+
e = abs(e) // 3 * es
|
| 405 |
+
|
| 406 |
+
elif encbase == 16:
|
| 407 |
+
m *= 2 ** (abs(e) % 4 * es)
|
| 408 |
+
e = abs(e) // 4 * es
|
| 409 |
+
|
| 410 |
+
while True:
|
| 411 |
+
if int(m) != m:
|
| 412 |
+
m *= encbase
|
| 413 |
+
e -= 1
|
| 414 |
+
continue
|
| 415 |
+
break
|
| 416 |
+
|
| 417 |
+
return ms, int(m), encbase, e
|
| 418 |
+
|
| 419 |
+
def _chooseEncBase(self, value):
|
| 420 |
+
m, b, e = value
|
| 421 |
+
encBase = [2, 8, 16]
|
| 422 |
+
if value.binEncBase in encBase:
|
| 423 |
+
return self._dropFloatingPoint(m, value.binEncBase, e)
|
| 424 |
+
|
| 425 |
+
elif self.binEncBase in encBase:
|
| 426 |
+
return self._dropFloatingPoint(m, self.binEncBase, e)
|
| 427 |
+
|
| 428 |
+
# auto choosing base 2/8/16
|
| 429 |
+
mantissa = [m, m, m]
|
| 430 |
+
exponent = [e, e, e]
|
| 431 |
+
sign = 1
|
| 432 |
+
encbase = 2
|
| 433 |
+
e = float('inf')
|
| 434 |
+
|
| 435 |
+
for i in range(3):
|
| 436 |
+
(sign,
|
| 437 |
+
mantissa[i],
|
| 438 |
+
encBase[i],
|
| 439 |
+
exponent[i]) = self._dropFloatingPoint(mantissa[i], encBase[i], exponent[i])
|
| 440 |
+
|
| 441 |
+
if abs(exponent[i]) < abs(e) or (abs(exponent[i]) == abs(e) and mantissa[i] < m):
|
| 442 |
+
e = exponent[i]
|
| 443 |
+
m = int(mantissa[i])
|
| 444 |
+
encbase = encBase[i]
|
| 445 |
+
|
| 446 |
+
if LOG:
|
| 447 |
+
LOG('automatically chosen REAL encoding base %s, sign %s, mantissa %s, '
|
| 448 |
+
'exponent %s' % (encbase, sign, m, e))
|
| 449 |
+
|
| 450 |
+
return sign, m, encbase, e
|
| 451 |
+
|
| 452 |
+
def encodeValue(self, value, asn1Spec, encodeFun, **options):
|
| 453 |
+
if asn1Spec is not None:
|
| 454 |
+
value = asn1Spec.clone(value)
|
| 455 |
+
|
| 456 |
+
if value.isPlusInf:
|
| 457 |
+
return (0x40,), False, False
|
| 458 |
+
|
| 459 |
+
if value.isMinusInf:
|
| 460 |
+
return (0x41,), False, False
|
| 461 |
+
|
| 462 |
+
m, b, e = value
|
| 463 |
+
|
| 464 |
+
if not m:
|
| 465 |
+
return b'', False, True
|
| 466 |
+
|
| 467 |
+
if b == 10:
|
| 468 |
+
if LOG:
|
| 469 |
+
LOG('encoding REAL into character form')
|
| 470 |
+
|
| 471 |
+
return b'\x03%dE%s%d' % (m, e == 0 and b'+' or b'', e), False, True
|
| 472 |
+
|
| 473 |
+
elif b == 2:
|
| 474 |
+
fo = 0x80 # binary encoding
|
| 475 |
+
ms, m, encbase, e = self._chooseEncBase(value)
|
| 476 |
+
|
| 477 |
+
if ms < 0: # mantissa sign
|
| 478 |
+
fo |= 0x40 # sign bit
|
| 479 |
+
|
| 480 |
+
# exponent & mantissa normalization
|
| 481 |
+
if encbase == 2:
|
| 482 |
+
while m & 0x1 == 0:
|
| 483 |
+
m >>= 1
|
| 484 |
+
e += 1
|
| 485 |
+
|
| 486 |
+
elif encbase == 8:
|
| 487 |
+
while m & 0x7 == 0:
|
| 488 |
+
m >>= 3
|
| 489 |
+
e += 1
|
| 490 |
+
fo |= 0x10
|
| 491 |
+
|
| 492 |
+
else: # encbase = 16
|
| 493 |
+
while m & 0xf == 0:
|
| 494 |
+
m >>= 4
|
| 495 |
+
e += 1
|
| 496 |
+
fo |= 0x20
|
| 497 |
+
|
| 498 |
+
sf = 0 # scale factor
|
| 499 |
+
|
| 500 |
+
while m & 0x1 == 0:
|
| 501 |
+
m >>= 1
|
| 502 |
+
sf += 1
|
| 503 |
+
|
| 504 |
+
if sf > 3:
|
| 505 |
+
raise error.PyAsn1Error('Scale factor overflow') # bug if raised
|
| 506 |
+
|
| 507 |
+
fo |= sf << 2
|
| 508 |
+
eo = b''
|
| 509 |
+
if e == 0 or e == -1:
|
| 510 |
+
eo = bytes((e & 0xff,))
|
| 511 |
+
|
| 512 |
+
else:
|
| 513 |
+
while e not in (0, -1):
|
| 514 |
+
eo = bytes((e & 0xff,)) + eo
|
| 515 |
+
e >>= 8
|
| 516 |
+
|
| 517 |
+
if e == 0 and eo and eo[0] & 0x80:
|
| 518 |
+
eo = bytes((0,)) + eo
|
| 519 |
+
|
| 520 |
+
if e == -1 and eo and not (eo[0] & 0x80):
|
| 521 |
+
eo = bytes((0xff,)) + eo
|
| 522 |
+
|
| 523 |
+
n = len(eo)
|
| 524 |
+
if n > 0xff:
|
| 525 |
+
raise error.PyAsn1Error('Real exponent overflow')
|
| 526 |
+
|
| 527 |
+
if n == 1:
|
| 528 |
+
pass
|
| 529 |
+
|
| 530 |
+
elif n == 2:
|
| 531 |
+
fo |= 1
|
| 532 |
+
|
| 533 |
+
elif n == 3:
|
| 534 |
+
fo |= 2
|
| 535 |
+
|
| 536 |
+
else:
|
| 537 |
+
fo |= 3
|
| 538 |
+
eo = bytes((n & 0xff,)) + eo
|
| 539 |
+
|
| 540 |
+
po = b''
|
| 541 |
+
|
| 542 |
+
while m:
|
| 543 |
+
po = bytes((m & 0xff,)) + po
|
| 544 |
+
m >>= 8
|
| 545 |
+
|
| 546 |
+
substrate = bytes((fo,)) + eo + po
|
| 547 |
+
|
| 548 |
+
return substrate, False, True
|
| 549 |
+
|
| 550 |
+
else:
|
| 551 |
+
raise error.PyAsn1Error('Prohibited Real base %s' % b)
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
class SequenceEncoder(AbstractItemEncoder):
|
| 555 |
+
omitEmptyOptionals = False
|
| 556 |
+
|
| 557 |
+
# TODO: handling three flavors of input is too much -- split over codecs
|
| 558 |
+
|
| 559 |
+
def encodeValue(self, value, asn1Spec, encodeFun, **options):
|
| 560 |
+
|
| 561 |
+
substrate = b''
|
| 562 |
+
|
| 563 |
+
omitEmptyOptionals = options.get(
|
| 564 |
+
'omitEmptyOptionals', self.omitEmptyOptionals)
|
| 565 |
+
|
| 566 |
+
if LOG:
|
| 567 |
+
LOG('%sencoding empty OPTIONAL components' % (
|
| 568 |
+
omitEmptyOptionals and 'not ' or ''))
|
| 569 |
+
|
| 570 |
+
if asn1Spec is None:
|
| 571 |
+
# instance of ASN.1 schema
|
| 572 |
+
inconsistency = value.isInconsistent
|
| 573 |
+
if inconsistency:
|
| 574 |
+
raise error.PyAsn1Error(
|
| 575 |
+
f"ASN.1 object {value.__class__.__name__} is inconsistent")
|
| 576 |
+
|
| 577 |
+
namedTypes = value.componentType
|
| 578 |
+
|
| 579 |
+
for idx, component in enumerate(value.values()):
|
| 580 |
+
if namedTypes:
|
| 581 |
+
namedType = namedTypes[idx]
|
| 582 |
+
|
| 583 |
+
if namedType.isOptional and not component.isValue:
|
| 584 |
+
if LOG:
|
| 585 |
+
LOG('not encoding OPTIONAL component %r' % (namedType,))
|
| 586 |
+
continue
|
| 587 |
+
|
| 588 |
+
if namedType.isDefaulted and component == namedType.asn1Object:
|
| 589 |
+
if LOG:
|
| 590 |
+
LOG('not encoding DEFAULT component %r' % (namedType,))
|
| 591 |
+
continue
|
| 592 |
+
|
| 593 |
+
if omitEmptyOptionals:
|
| 594 |
+
options.update(ifNotEmpty=namedType.isOptional)
|
| 595 |
+
|
| 596 |
+
# wrap open type blob if needed
|
| 597 |
+
if namedTypes and namedType.openType:
|
| 598 |
+
|
| 599 |
+
wrapType = namedType.asn1Object
|
| 600 |
+
|
| 601 |
+
if wrapType.typeId in (
|
| 602 |
+
univ.SetOf.typeId, univ.SequenceOf.typeId):
|
| 603 |
+
|
| 604 |
+
substrate += encodeFun(
|
| 605 |
+
component, asn1Spec,
|
| 606 |
+
**dict(options, wrapType=wrapType.componentType))
|
| 607 |
+
|
| 608 |
+
else:
|
| 609 |
+
chunk = encodeFun(component, asn1Spec, **options)
|
| 610 |
+
|
| 611 |
+
if wrapType.isSameTypeWith(component):
|
| 612 |
+
substrate += chunk
|
| 613 |
+
|
| 614 |
+
else:
|
| 615 |
+
substrate += encodeFun(chunk, wrapType, **options)
|
| 616 |
+
|
| 617 |
+
if LOG:
|
| 618 |
+
LOG('wrapped with wrap type %r' % (wrapType,))
|
| 619 |
+
|
| 620 |
+
else:
|
| 621 |
+
substrate += encodeFun(component, asn1Spec, **options)
|
| 622 |
+
|
| 623 |
+
else:
|
| 624 |
+
# bare Python value + ASN.1 schema
|
| 625 |
+
for idx, namedType in enumerate(asn1Spec.componentType.namedTypes):
|
| 626 |
+
|
| 627 |
+
try:
|
| 628 |
+
component = value[namedType.name]
|
| 629 |
+
|
| 630 |
+
except KeyError:
|
| 631 |
+
raise error.PyAsn1Error('Component name "%s" not found in %r' % (
|
| 632 |
+
namedType.name, value))
|
| 633 |
+
|
| 634 |
+
if namedType.isOptional and namedType.name not in value:
|
| 635 |
+
if LOG:
|
| 636 |
+
LOG('not encoding OPTIONAL component %r' % (namedType,))
|
| 637 |
+
continue
|
| 638 |
+
|
| 639 |
+
if namedType.isDefaulted and component == namedType.asn1Object:
|
| 640 |
+
if LOG:
|
| 641 |
+
LOG('not encoding DEFAULT component %r' % (namedType,))
|
| 642 |
+
continue
|
| 643 |
+
|
| 644 |
+
if omitEmptyOptionals:
|
| 645 |
+
options.update(ifNotEmpty=namedType.isOptional)
|
| 646 |
+
|
| 647 |
+
componentSpec = namedType.asn1Object
|
| 648 |
+
|
| 649 |
+
# wrap open type blob if needed
|
| 650 |
+
if namedType.openType:
|
| 651 |
+
|
| 652 |
+
if componentSpec.typeId in (
|
| 653 |
+
univ.SetOf.typeId, univ.SequenceOf.typeId):
|
| 654 |
+
|
| 655 |
+
substrate += encodeFun(
|
| 656 |
+
component, componentSpec,
|
| 657 |
+
**dict(options, wrapType=componentSpec.componentType))
|
| 658 |
+
|
| 659 |
+
else:
|
| 660 |
+
chunk = encodeFun(component, componentSpec, **options)
|
| 661 |
+
|
| 662 |
+
if componentSpec.isSameTypeWith(component):
|
| 663 |
+
substrate += chunk
|
| 664 |
+
|
| 665 |
+
else:
|
| 666 |
+
substrate += encodeFun(chunk, componentSpec, **options)
|
| 667 |
+
|
| 668 |
+
if LOG:
|
| 669 |
+
LOG('wrapped with wrap type %r' % (componentSpec,))
|
| 670 |
+
|
| 671 |
+
else:
|
| 672 |
+
substrate += encodeFun(component, componentSpec, **options)
|
| 673 |
+
|
| 674 |
+
return substrate, True, True
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
class SequenceOfEncoder(AbstractItemEncoder):
|
| 678 |
+
def _encodeComponents(self, value, asn1Spec, encodeFun, **options):
|
| 679 |
+
|
| 680 |
+
if asn1Spec is None:
|
| 681 |
+
inconsistency = value.isInconsistent
|
| 682 |
+
if inconsistency:
|
| 683 |
+
raise error.PyAsn1Error(
|
| 684 |
+
f"ASN.1 object {value.__class__.__name__} is inconsistent")
|
| 685 |
+
|
| 686 |
+
else:
|
| 687 |
+
asn1Spec = asn1Spec.componentType
|
| 688 |
+
|
| 689 |
+
chunks = []
|
| 690 |
+
|
| 691 |
+
wrapType = options.pop('wrapType', None)
|
| 692 |
+
|
| 693 |
+
for idx, component in enumerate(value):
|
| 694 |
+
chunk = encodeFun(component, asn1Spec, **options)
|
| 695 |
+
|
| 696 |
+
if (wrapType is not None and
|
| 697 |
+
not wrapType.isSameTypeWith(component)):
|
| 698 |
+
# wrap encoded value with wrapper container (e.g. ANY)
|
| 699 |
+
chunk = encodeFun(chunk, wrapType, **options)
|
| 700 |
+
|
| 701 |
+
if LOG:
|
| 702 |
+
LOG('wrapped with wrap type %r' % (wrapType,))
|
| 703 |
+
|
| 704 |
+
chunks.append(chunk)
|
| 705 |
+
|
| 706 |
+
return chunks
|
| 707 |
+
|
| 708 |
+
def encodeValue(self, value, asn1Spec, encodeFun, **options):
|
| 709 |
+
chunks = self._encodeComponents(
|
| 710 |
+
value, asn1Spec, encodeFun, **options)
|
| 711 |
+
|
| 712 |
+
return b''.join(chunks), True, True
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
class ChoiceEncoder(AbstractItemEncoder):
|
| 716 |
+
def encodeValue(self, value, asn1Spec, encodeFun, **options):
|
| 717 |
+
if asn1Spec is None:
|
| 718 |
+
component = value.getComponent()
|
| 719 |
+
else:
|
| 720 |
+
names = [namedType.name for namedType in asn1Spec.componentType.namedTypes
|
| 721 |
+
if namedType.name in value]
|
| 722 |
+
if len(names) != 1:
|
| 723 |
+
raise error.PyAsn1Error('%s components for Choice at %r' % (len(names) and 'Multiple ' or 'None ', value))
|
| 724 |
+
|
| 725 |
+
name = names[0]
|
| 726 |
+
|
| 727 |
+
component = value[name]
|
| 728 |
+
asn1Spec = asn1Spec[name]
|
| 729 |
+
|
| 730 |
+
return encodeFun(component, asn1Spec, **options), True, True
|
| 731 |
+
|
| 732 |
+
|
| 733 |
+
class AnyEncoder(OctetStringEncoder):
|
| 734 |
+
def encodeValue(self, value, asn1Spec, encodeFun, **options):
|
| 735 |
+
if asn1Spec is None:
|
| 736 |
+
value = value.asOctets()
|
| 737 |
+
elif not isinstance(value, bytes):
|
| 738 |
+
value = asn1Spec.clone(value).asOctets()
|
| 739 |
+
|
| 740 |
+
return value, not options.get('defMode', True), True
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
TAG_MAP = {
|
| 744 |
+
eoo.endOfOctets.tagSet: EndOfOctetsEncoder(),
|
| 745 |
+
univ.Boolean.tagSet: BooleanEncoder(),
|
| 746 |
+
univ.Integer.tagSet: IntegerEncoder(),
|
| 747 |
+
univ.BitString.tagSet: BitStringEncoder(),
|
| 748 |
+
univ.OctetString.tagSet: OctetStringEncoder(),
|
| 749 |
+
univ.Null.tagSet: NullEncoder(),
|
| 750 |
+
univ.ObjectIdentifier.tagSet: ObjectIdentifierEncoder(),
|
| 751 |
+
univ.RelativeOID.tagSet: RelativeOIDEncoder(),
|
| 752 |
+
univ.Enumerated.tagSet: IntegerEncoder(),
|
| 753 |
+
univ.Real.tagSet: RealEncoder(),
|
| 754 |
+
# Sequence & Set have same tags as SequenceOf & SetOf
|
| 755 |
+
univ.SequenceOf.tagSet: SequenceOfEncoder(),
|
| 756 |
+
univ.SetOf.tagSet: SequenceOfEncoder(),
|
| 757 |
+
univ.Choice.tagSet: ChoiceEncoder(),
|
| 758 |
+
# character string types
|
| 759 |
+
char.UTF8String.tagSet: OctetStringEncoder(),
|
| 760 |
+
char.NumericString.tagSet: OctetStringEncoder(),
|
| 761 |
+
char.PrintableString.tagSet: OctetStringEncoder(),
|
| 762 |
+
char.TeletexString.tagSet: OctetStringEncoder(),
|
| 763 |
+
char.VideotexString.tagSet: OctetStringEncoder(),
|
| 764 |
+
char.IA5String.tagSet: OctetStringEncoder(),
|
| 765 |
+
char.GraphicString.tagSet: OctetStringEncoder(),
|
| 766 |
+
char.VisibleString.tagSet: OctetStringEncoder(),
|
| 767 |
+
char.GeneralString.tagSet: OctetStringEncoder(),
|
| 768 |
+
char.UniversalString.tagSet: OctetStringEncoder(),
|
| 769 |
+
char.BMPString.tagSet: OctetStringEncoder(),
|
| 770 |
+
# useful types
|
| 771 |
+
useful.ObjectDescriptor.tagSet: OctetStringEncoder(),
|
| 772 |
+
useful.GeneralizedTime.tagSet: OctetStringEncoder(),
|
| 773 |
+
useful.UTCTime.tagSet: OctetStringEncoder()
|
| 774 |
+
}
|
| 775 |
+
|
| 776 |
+
# Put in ambiguous & non-ambiguous types for faster codec lookup
|
| 777 |
+
TYPE_MAP = {
|
| 778 |
+
univ.Boolean.typeId: BooleanEncoder(),
|
| 779 |
+
univ.Integer.typeId: IntegerEncoder(),
|
| 780 |
+
univ.BitString.typeId: BitStringEncoder(),
|
| 781 |
+
univ.OctetString.typeId: OctetStringEncoder(),
|
| 782 |
+
univ.Null.typeId: NullEncoder(),
|
| 783 |
+
univ.ObjectIdentifier.typeId: ObjectIdentifierEncoder(),
|
| 784 |
+
univ.RelativeOID.typeId: RelativeOIDEncoder(),
|
| 785 |
+
univ.Enumerated.typeId: IntegerEncoder(),
|
| 786 |
+
univ.Real.typeId: RealEncoder(),
|
| 787 |
+
# Sequence & Set have same tags as SequenceOf & SetOf
|
| 788 |
+
univ.Set.typeId: SequenceEncoder(),
|
| 789 |
+
univ.SetOf.typeId: SequenceOfEncoder(),
|
| 790 |
+
univ.Sequence.typeId: SequenceEncoder(),
|
| 791 |
+
univ.SequenceOf.typeId: SequenceOfEncoder(),
|
| 792 |
+
univ.Choice.typeId: ChoiceEncoder(),
|
| 793 |
+
univ.Any.typeId: AnyEncoder(),
|
| 794 |
+
# character string types
|
| 795 |
+
char.UTF8String.typeId: OctetStringEncoder(),
|
| 796 |
+
char.NumericString.typeId: OctetStringEncoder(),
|
| 797 |
+
char.PrintableString.typeId: OctetStringEncoder(),
|
| 798 |
+
char.TeletexString.typeId: OctetStringEncoder(),
|
| 799 |
+
char.VideotexString.typeId: OctetStringEncoder(),
|
| 800 |
+
char.IA5String.typeId: OctetStringEncoder(),
|
| 801 |
+
char.GraphicString.typeId: OctetStringEncoder(),
|
| 802 |
+
char.VisibleString.typeId: OctetStringEncoder(),
|
| 803 |
+
char.GeneralString.typeId: OctetStringEncoder(),
|
| 804 |
+
char.UniversalString.typeId: OctetStringEncoder(),
|
| 805 |
+
char.BMPString.typeId: OctetStringEncoder(),
|
| 806 |
+
# useful types
|
| 807 |
+
useful.ObjectDescriptor.typeId: OctetStringEncoder(),
|
| 808 |
+
useful.GeneralizedTime.typeId: OctetStringEncoder(),
|
| 809 |
+
useful.UTCTime.typeId: OctetStringEncoder()
|
| 810 |
+
}
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
class SingleItemEncoder(object):
|
| 814 |
+
fixedDefLengthMode = None
|
| 815 |
+
fixedChunkSize = None
|
| 816 |
+
|
| 817 |
+
TAG_MAP = TAG_MAP
|
| 818 |
+
TYPE_MAP = TYPE_MAP
|
| 819 |
+
|
| 820 |
+
def __init__(self, tagMap=_MISSING, typeMap=_MISSING, **ignored):
|
| 821 |
+
self._tagMap = tagMap if tagMap is not _MISSING else self.TAG_MAP
|
| 822 |
+
self._typeMap = typeMap if typeMap is not _MISSING else self.TYPE_MAP
|
| 823 |
+
|
| 824 |
+
def __call__(self, value, asn1Spec=None, **options):
|
| 825 |
+
try:
|
| 826 |
+
if asn1Spec is None:
|
| 827 |
+
typeId = value.typeId
|
| 828 |
+
else:
|
| 829 |
+
typeId = asn1Spec.typeId
|
| 830 |
+
|
| 831 |
+
except AttributeError:
|
| 832 |
+
raise error.PyAsn1Error('Value %r is not ASN.1 type instance '
|
| 833 |
+
'and "asn1Spec" not given' % (value,))
|
| 834 |
+
|
| 835 |
+
if LOG:
|
| 836 |
+
LOG('encoder called in %sdef mode, chunk size %s for type %s, '
|
| 837 |
+
'value:\n%s' % (not options.get('defMode', True) and 'in' or '',
|
| 838 |
+
options.get('maxChunkSize', 0),
|
| 839 |
+
asn1Spec is None and value.prettyPrintType() or
|
| 840 |
+
asn1Spec.prettyPrintType(), value))
|
| 841 |
+
|
| 842 |
+
if self.fixedDefLengthMode is not None:
|
| 843 |
+
options.update(defMode=self.fixedDefLengthMode)
|
| 844 |
+
|
| 845 |
+
if self.fixedChunkSize is not None:
|
| 846 |
+
options.update(maxChunkSize=self.fixedChunkSize)
|
| 847 |
+
|
| 848 |
+
try:
|
| 849 |
+
concreteEncoder = self._typeMap[typeId]
|
| 850 |
+
|
| 851 |
+
if LOG:
|
| 852 |
+
LOG('using value codec %s chosen by type ID '
|
| 853 |
+
'%s' % (concreteEncoder.__class__.__name__, typeId))
|
| 854 |
+
|
| 855 |
+
except KeyError:
|
| 856 |
+
if asn1Spec is None:
|
| 857 |
+
tagSet = value.tagSet
|
| 858 |
+
else:
|
| 859 |
+
tagSet = asn1Spec.tagSet
|
| 860 |
+
|
| 861 |
+
# use base type for codec lookup to recover untagged types
|
| 862 |
+
baseTagSet = tag.TagSet(tagSet.baseTag, tagSet.baseTag)
|
| 863 |
+
|
| 864 |
+
try:
|
| 865 |
+
concreteEncoder = self._tagMap[baseTagSet]
|
| 866 |
+
|
| 867 |
+
except KeyError:
|
| 868 |
+
raise error.PyAsn1Error('No encoder for %r (%s)' % (value, tagSet))
|
| 869 |
+
|
| 870 |
+
if LOG:
|
| 871 |
+
LOG('using value codec %s chosen by tagSet '
|
| 872 |
+
'%s' % (concreteEncoder.__class__.__name__, tagSet))
|
| 873 |
+
|
| 874 |
+
substrate = concreteEncoder.encode(value, asn1Spec, self, **options)
|
| 875 |
+
|
| 876 |
+
if LOG:
|
| 877 |
+
LOG('codec %s built %s octets of substrate: %s\nencoder '
|
| 878 |
+
'completed' % (concreteEncoder, len(substrate),
|
| 879 |
+
debug.hexdump(substrate)))
|
| 880 |
+
|
| 881 |
+
return substrate
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
class Encoder(object):
|
| 885 |
+
SINGLE_ITEM_ENCODER = SingleItemEncoder
|
| 886 |
+
|
| 887 |
+
def __init__(self, tagMap=_MISSING, typeMap=_MISSING, **options):
|
| 888 |
+
self._singleItemEncoder = self.SINGLE_ITEM_ENCODER(
|
| 889 |
+
tagMap=tagMap, typeMap=typeMap, **options
|
| 890 |
+
)
|
| 891 |
+
|
| 892 |
+
def __call__(self, pyObject, asn1Spec=None, **options):
|
| 893 |
+
return self._singleItemEncoder(
|
| 894 |
+
pyObject, asn1Spec=asn1Spec, **options)
|
| 895 |
+
|
| 896 |
+
|
| 897 |
+
#: Turns ASN.1 object into BER octet stream.
|
| 898 |
+
#:
|
| 899 |
+
#: Takes any ASN.1 object (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative)
|
| 900 |
+
#: walks all its components recursively and produces a BER octet stream.
|
| 901 |
+
#:
|
| 902 |
+
#: Parameters
|
| 903 |
+
#: ----------
|
| 904 |
+
#: value: either a Python or pyasn1 object (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative)
|
| 905 |
+
#: A Python or pyasn1 object to encode. If Python object is given, `asnSpec`
|
| 906 |
+
#: parameter is required to guide the encoding process.
|
| 907 |
+
#:
|
| 908 |
+
#: Keyword Args
|
| 909 |
+
#: ------------
|
| 910 |
+
#: asn1Spec:
|
| 911 |
+
#: Optional ASN.1 schema or value object e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative
|
| 912 |
+
#:
|
| 913 |
+
#: defMode: :py:class:`bool`
|
| 914 |
+
#: If :obj:`False`, produces indefinite length encoding
|
| 915 |
+
#:
|
| 916 |
+
#: maxChunkSize: :py:class:`int`
|
| 917 |
+
#: Maximum chunk size in chunked encoding mode (0 denotes unlimited chunk size)
|
| 918 |
+
#:
|
| 919 |
+
#: Returns
|
| 920 |
+
#: -------
|
| 921 |
+
#: : :py:class:`bytes`
|
| 922 |
+
#: Given ASN.1 object encoded into BER octetstream
|
| 923 |
+
#:
|
| 924 |
+
#: Raises
|
| 925 |
+
#: ------
|
| 926 |
+
#: ~pyasn1.error.PyAsn1Error
|
| 927 |
+
#: On encoding errors
|
| 928 |
+
#:
|
| 929 |
+
#: Examples
|
| 930 |
+
#: --------
|
| 931 |
+
#: Encode Python value into BER with ASN.1 schema
|
| 932 |
+
#:
|
| 933 |
+
#: .. code-block:: pycon
|
| 934 |
+
#:
|
| 935 |
+
#: >>> seq = SequenceOf(componentType=Integer())
|
| 936 |
+
#: >>> encode([1, 2, 3], asn1Spec=seq)
|
| 937 |
+
#: b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03'
|
| 938 |
+
#:
|
| 939 |
+
#: Encode ASN.1 value object into BER
|
| 940 |
+
#:
|
| 941 |
+
#: .. code-block:: pycon
|
| 942 |
+
#:
|
| 943 |
+
#: >>> seq = SequenceOf(componentType=Integer())
|
| 944 |
+
#: >>> seq.extend([1, 2, 3])
|
| 945 |
+
#: >>> encode(seq)
|
| 946 |
+
#: b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03'
|
| 947 |
+
#:
|
| 948 |
+
encode = Encoder()
|
| 949 |
+
|
| 950 |
+
def __getattr__(attr: str):
|
| 951 |
+
if newAttr := {"tagMap": "TAG_MAP", "typeMap": "TYPE_MAP"}.get(attr):
|
| 952 |
+
warnings.warn(f"{attr} is deprecated. Please use {newAttr} instead.", DeprecationWarning)
|
| 953 |
+
return globals()[newAttr]
|
| 954 |
+
raise AttributeError(attr)
|
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/eoo.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# This file is part of pyasn1 software.
|
| 3 |
+
#
|
| 4 |
+
# Copyright (c) 2005-2020, Ilya Etingof <etingof@gmail.com>
|
| 5 |
+
# License: https://pyasn1.readthedocs.io/en/latest/license.html
|
| 6 |
+
#
|
| 7 |
+
from pyasn1.type import base
|
| 8 |
+
from pyasn1.type import tag
|
| 9 |
+
|
| 10 |
+
__all__ = ['endOfOctets']
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class EndOfOctets(base.SimpleAsn1Type):
|
| 14 |
+
defaultValue = 0
|
| 15 |
+
tagSet = tag.initTagSet(
|
| 16 |
+
tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 0x00)
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
_instance = None
|
| 20 |
+
|
| 21 |
+
def __new__(cls, *args, **kwargs):
|
| 22 |
+
if cls._instance is None:
|
| 23 |
+
cls._instance = object.__new__(cls, *args, **kwargs)
|
| 24 |
+
|
| 25 |
+
return cls._instance
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
endOfOctets = EndOfOctets()
|
.venv/lib/python3.11/site-packages/pyasn1/codec/native/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# This file is necessary to make this directory a package.
|