koichi12 commited on
Commit
bc06d6a
·
verified ·
1 Parent(s): dc52d38

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .venv/lib/python3.11/site-packages/click/__pycache__/_winconsole.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/click/__pycache__/formatting.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/click/__pycache__/parser.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/click/__pycache__/termui.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/click/_termui_impl.py +788 -0
  7. .venv/lib/python3.11/site-packages/click/decorators.py +562 -0
  8. .venv/lib/python3.11/site-packages/click/py.typed +0 -0
  9. .venv/lib/python3.11/site-packages/click/shell_completion.py +603 -0
  10. .venv/lib/python3.11/site-packages/click/termui.py +784 -0
  11. .venv/lib/python3.11/site-packages/click/utils.py +624 -0
  12. .venv/lib/python3.11/site-packages/httplib2/__init__.py +1799 -0
  13. .venv/lib/python3.11/site-packages/httplib2/__pycache__/__init__.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/httplib2/__pycache__/auth.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/httplib2/__pycache__/certs.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/httplib2/__pycache__/error.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/httplib2/__pycache__/iri2uri.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/httplib2/__pycache__/socks.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/httplib2/auth.py +69 -0
  20. .venv/lib/python3.11/site-packages/httplib2/cacerts.txt +0 -0
  21. .venv/lib/python3.11/site-packages/httplib2/certs.py +42 -0
  22. .venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/INSTALLER +1 -0
  23. .venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/LICENSE +21 -0
  24. .venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/METADATA +252 -0
  25. .venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/RECORD +45 -0
  26. .venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/WHEEL +4 -0
  27. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/__init__.py +0 -0
  28. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/__pycache__/__init__.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_adv.h +671 -0
  30. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_backend.h +60 -0
  31. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn.h +693 -0
  32. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn_v9.h +693 -0
  33. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_graph_v9.h +909 -0
  34. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_ops_v9.h +1316 -0
  35. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_v9.h +68 -0
  36. .venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_version.h +70 -0
  37. .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libavutil-734d06dd.so.57.28.100 +3 -0
  38. .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libssl-28bef1ac.so.1.1 +3 -0
  39. .venv/lib/python3.11/site-packages/pyasn1/__pycache__/__init__.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/pyasn1/__pycache__/debug.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/pyasn1/codec/__init__.py +1 -0
  42. .venv/lib/python3.11/site-packages/pyasn1/codec/ber/__init__.py +1 -0
  43. .venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/__init__.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/decoder.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/encoder.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/eoo.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/pyasn1/codec/ber/decoder.py +2189 -0
  48. .venv/lib/python3.11/site-packages/pyasn1/codec/ber/encoder.py +954 -0
  49. .venv/lib/python3.11/site-packages/pyasn1/codec/ber/eoo.py +28 -0
  50. .venv/lib/python3.11/site-packages/pyasn1/codec/native/__init__.py +1 -0
.gitattributes CHANGED
@@ -407,3 +407,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
407
  .venv/lib/python3.11/site-packages/aiohttp/_websocket/mask.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
408
  .venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_fp.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
409
  .venv/lib/python3.11/site-packages/aiohttp/_websocket/reader_c.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
 
 
 
407
  .venv/lib/python3.11/site-packages/aiohttp/_websocket/mask.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
408
  .venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_fp.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
409
  .venv/lib/python3.11/site-packages/aiohttp/_websocket/reader_c.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
410
+ .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libssl-28bef1ac.so.1.1 filter=lfs diff=lfs merge=lfs -text
411
+ .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libavutil-734d06dd.so.57.28.100 filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/click/__pycache__/_winconsole.cpython-311.pyc ADDED
Binary file (13.3 kB). View file
 
.venv/lib/python3.11/site-packages/click/__pycache__/formatting.cpython-311.pyc ADDED
Binary file (15.7 kB). View file
 
.venv/lib/python3.11/site-packages/click/__pycache__/parser.cpython-311.pyc ADDED
Binary file (23.1 kB). View file
 
.venv/lib/python3.11/site-packages/click/__pycache__/termui.cpython-311.pyc ADDED
Binary file (34.5 kB). View file
 
.venv/lib/python3.11/site-packages/click/_termui_impl.py ADDED
@@ -0,0 +1,788 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module contains implementations for the termui module. To keep the
3
+ import time of Click down, some infrequently used functionality is
4
+ placed in this module and only imported as needed.
5
+ """
6
+
7
+ import contextlib
8
+ import math
9
+ import os
10
+ import sys
11
+ import time
12
+ import typing as t
13
+ from gettext import gettext as _
14
+ from io import StringIO
15
+ from shutil import which
16
+ from types import TracebackType
17
+
18
+ from ._compat import _default_text_stdout
19
+ from ._compat import CYGWIN
20
+ from ._compat import get_best_encoding
21
+ from ._compat import isatty
22
+ from ._compat import open_stream
23
+ from ._compat import strip_ansi
24
+ from ._compat import term_len
25
+ from ._compat import WIN
26
+ from .exceptions import ClickException
27
+ from .utils import echo
28
+
29
+ V = t.TypeVar("V")
30
+
31
+ if os.name == "nt":
32
+ BEFORE_BAR = "\r"
33
+ AFTER_BAR = "\n"
34
+ else:
35
+ BEFORE_BAR = "\r\033[?25l"
36
+ AFTER_BAR = "\033[?25h\n"
37
+
38
+
39
+ class ProgressBar(t.Generic[V]):
40
+ def __init__(
41
+ self,
42
+ iterable: t.Optional[t.Iterable[V]],
43
+ length: t.Optional[int] = None,
44
+ fill_char: str = "#",
45
+ empty_char: str = " ",
46
+ bar_template: str = "%(bar)s",
47
+ info_sep: str = " ",
48
+ show_eta: bool = True,
49
+ show_percent: t.Optional[bool] = None,
50
+ show_pos: bool = False,
51
+ item_show_func: t.Optional[t.Callable[[t.Optional[V]], t.Optional[str]]] = None,
52
+ label: t.Optional[str] = None,
53
+ file: t.Optional[t.TextIO] = None,
54
+ color: t.Optional[bool] = None,
55
+ update_min_steps: int = 1,
56
+ width: int = 30,
57
+ ) -> None:
58
+ self.fill_char = fill_char
59
+ self.empty_char = empty_char
60
+ self.bar_template = bar_template
61
+ self.info_sep = info_sep
62
+ self.show_eta = show_eta
63
+ self.show_percent = show_percent
64
+ self.show_pos = show_pos
65
+ self.item_show_func = item_show_func
66
+ self.label: str = label or ""
67
+
68
+ if file is None:
69
+ file = _default_text_stdout()
70
+
71
+ # There are no standard streams attached to write to. For example,
72
+ # pythonw on Windows.
73
+ if file is None:
74
+ file = StringIO()
75
+
76
+ self.file = file
77
+ self.color = color
78
+ self.update_min_steps = update_min_steps
79
+ self._completed_intervals = 0
80
+ self.width: int = width
81
+ self.autowidth: bool = width == 0
82
+
83
+ if length is None:
84
+ from operator import length_hint
85
+
86
+ length = length_hint(iterable, -1)
87
+
88
+ if length == -1:
89
+ length = None
90
+ if iterable is None:
91
+ if length is None:
92
+ raise TypeError("iterable or length is required")
93
+ iterable = t.cast(t.Iterable[V], range(length))
94
+ self.iter: t.Iterable[V] = iter(iterable)
95
+ self.length = length
96
+ self.pos = 0
97
+ self.avg: t.List[float] = []
98
+ self.last_eta: float
99
+ self.start: float
100
+ self.start = self.last_eta = time.time()
101
+ self.eta_known: bool = False
102
+ self.finished: bool = False
103
+ self.max_width: t.Optional[int] = None
104
+ self.entered: bool = False
105
+ self.current_item: t.Optional[V] = None
106
+ self.is_hidden: bool = not isatty(self.file)
107
+ self._last_line: t.Optional[str] = None
108
+
109
+ def __enter__(self) -> "ProgressBar[V]":
110
+ self.entered = True
111
+ self.render_progress()
112
+ return self
113
+
114
+ def __exit__(
115
+ self,
116
+ exc_type: t.Optional[t.Type[BaseException]],
117
+ exc_value: t.Optional[BaseException],
118
+ tb: t.Optional[TracebackType],
119
+ ) -> None:
120
+ self.render_finish()
121
+
122
+ def __iter__(self) -> t.Iterator[V]:
123
+ if not self.entered:
124
+ raise RuntimeError("You need to use progress bars in a with block.")
125
+ self.render_progress()
126
+ return self.generator()
127
+
128
+ def __next__(self) -> V:
129
+ # Iteration is defined in terms of a generator function,
130
+ # returned by iter(self); use that to define next(). This works
131
+ # because `self.iter` is an iterable consumed by that generator,
132
+ # so it is re-entry safe. Calling `next(self.generator())`
133
+ # twice works and does "what you want".
134
+ return next(iter(self))
135
+
136
+ def render_finish(self) -> None:
137
+ if self.is_hidden:
138
+ return
139
+ self.file.write(AFTER_BAR)
140
+ self.file.flush()
141
+
142
+ @property
143
+ def pct(self) -> float:
144
+ if self.finished:
145
+ return 1.0
146
+ return min(self.pos / (float(self.length or 1) or 1), 1.0)
147
+
148
+ @property
149
+ def time_per_iteration(self) -> float:
150
+ if not self.avg:
151
+ return 0.0
152
+ return sum(self.avg) / float(len(self.avg))
153
+
154
+ @property
155
+ def eta(self) -> float:
156
+ if self.length is not None and not self.finished:
157
+ return self.time_per_iteration * (self.length - self.pos)
158
+ return 0.0
159
+
160
+ def format_eta(self) -> str:
161
+ if self.eta_known:
162
+ t = int(self.eta)
163
+ seconds = t % 60
164
+ t //= 60
165
+ minutes = t % 60
166
+ t //= 60
167
+ hours = t % 24
168
+ t //= 24
169
+ if t > 0:
170
+ return f"{t}d {hours:02}:{minutes:02}:{seconds:02}"
171
+ else:
172
+ return f"{hours:02}:{minutes:02}:{seconds:02}"
173
+ return ""
174
+
175
+ def format_pos(self) -> str:
176
+ pos = str(self.pos)
177
+ if self.length is not None:
178
+ pos += f"/{self.length}"
179
+ return pos
180
+
181
+ def format_pct(self) -> str:
182
+ return f"{int(self.pct * 100): 4}%"[1:]
183
+
184
+ def format_bar(self) -> str:
185
+ if self.length is not None:
186
+ bar_length = int(self.pct * self.width)
187
+ bar = self.fill_char * bar_length
188
+ bar += self.empty_char * (self.width - bar_length)
189
+ elif self.finished:
190
+ bar = self.fill_char * self.width
191
+ else:
192
+ chars = list(self.empty_char * (self.width or 1))
193
+ if self.time_per_iteration != 0:
194
+ chars[
195
+ int(
196
+ (math.cos(self.pos * self.time_per_iteration) / 2.0 + 0.5)
197
+ * self.width
198
+ )
199
+ ] = self.fill_char
200
+ bar = "".join(chars)
201
+ return bar
202
+
203
+ def format_progress_line(self) -> str:
204
+ show_percent = self.show_percent
205
+
206
+ info_bits = []
207
+ if self.length is not None and show_percent is None:
208
+ show_percent = not self.show_pos
209
+
210
+ if self.show_pos:
211
+ info_bits.append(self.format_pos())
212
+ if show_percent:
213
+ info_bits.append(self.format_pct())
214
+ if self.show_eta and self.eta_known and not self.finished:
215
+ info_bits.append(self.format_eta())
216
+ if self.item_show_func is not None:
217
+ item_info = self.item_show_func(self.current_item)
218
+ if item_info is not None:
219
+ info_bits.append(item_info)
220
+
221
+ return (
222
+ self.bar_template
223
+ % {
224
+ "label": self.label,
225
+ "bar": self.format_bar(),
226
+ "info": self.info_sep.join(info_bits),
227
+ }
228
+ ).rstrip()
229
+
230
+ def render_progress(self) -> None:
231
+ import shutil
232
+
233
+ if self.is_hidden:
234
+ # Only output the label as it changes if the output is not a
235
+ # TTY. Use file=stderr if you expect to be piping stdout.
236
+ if self._last_line != self.label:
237
+ self._last_line = self.label
238
+ echo(self.label, file=self.file, color=self.color)
239
+
240
+ return
241
+
242
+ buf = []
243
+ # Update width in case the terminal has been resized
244
+ if self.autowidth:
245
+ old_width = self.width
246
+ self.width = 0
247
+ clutter_length = term_len(self.format_progress_line())
248
+ new_width = max(0, shutil.get_terminal_size().columns - clutter_length)
249
+ if new_width < old_width:
250
+ buf.append(BEFORE_BAR)
251
+ buf.append(" " * self.max_width) # type: ignore
252
+ self.max_width = new_width
253
+ self.width = new_width
254
+
255
+ clear_width = self.width
256
+ if self.max_width is not None:
257
+ clear_width = self.max_width
258
+
259
+ buf.append(BEFORE_BAR)
260
+ line = self.format_progress_line()
261
+ line_len = term_len(line)
262
+ if self.max_width is None or self.max_width < line_len:
263
+ self.max_width = line_len
264
+
265
+ buf.append(line)
266
+ buf.append(" " * (clear_width - line_len))
267
+ line = "".join(buf)
268
+ # Render the line only if it changed.
269
+
270
+ if line != self._last_line:
271
+ self._last_line = line
272
+ echo(line, file=self.file, color=self.color, nl=False)
273
+ self.file.flush()
274
+
275
+ def make_step(self, n_steps: int) -> None:
276
+ self.pos += n_steps
277
+ if self.length is not None and self.pos >= self.length:
278
+ self.finished = True
279
+
280
+ if (time.time() - self.last_eta) < 1.0:
281
+ return
282
+
283
+ self.last_eta = time.time()
284
+
285
+ # self.avg is a rolling list of length <= 7 of steps where steps are
286
+ # defined as time elapsed divided by the total progress through
287
+ # self.length.
288
+ if self.pos:
289
+ step = (time.time() - self.start) / self.pos
290
+ else:
291
+ step = time.time() - self.start
292
+
293
+ self.avg = self.avg[-6:] + [step]
294
+
295
+ self.eta_known = self.length is not None
296
+
297
+ def update(self, n_steps: int, current_item: t.Optional[V] = None) -> None:
298
+ """Update the progress bar by advancing a specified number of
299
+ steps, and optionally set the ``current_item`` for this new
300
+ position.
301
+
302
+ :param n_steps: Number of steps to advance.
303
+ :param current_item: Optional item to set as ``current_item``
304
+ for the updated position.
305
+
306
+ .. versionchanged:: 8.0
307
+ Added the ``current_item`` optional parameter.
308
+
309
+ .. versionchanged:: 8.0
310
+ Only render when the number of steps meets the
311
+ ``update_min_steps`` threshold.
312
+ """
313
+ if current_item is not None:
314
+ self.current_item = current_item
315
+
316
+ self._completed_intervals += n_steps
317
+
318
+ if self._completed_intervals >= self.update_min_steps:
319
+ self.make_step(self._completed_intervals)
320
+ self.render_progress()
321
+ self._completed_intervals = 0
322
+
323
+ def finish(self) -> None:
324
+ self.eta_known = False
325
+ self.current_item = None
326
+ self.finished = True
327
+
328
+ def generator(self) -> t.Iterator[V]:
329
+ """Return a generator which yields the items added to the bar
330
+ during construction, and updates the progress bar *after* the
331
+ yielded block returns.
332
+ """
333
+ # WARNING: the iterator interface for `ProgressBar` relies on
334
+ # this and only works because this is a simple generator which
335
+ # doesn't create or manage additional state. If this function
336
+ # changes, the impact should be evaluated both against
337
+ # `iter(bar)` and `next(bar)`. `next()` in particular may call
338
+ # `self.generator()` repeatedly, and this must remain safe in
339
+ # order for that interface to work.
340
+ if not self.entered:
341
+ raise RuntimeError("You need to use progress bars in a with block.")
342
+
343
+ if self.is_hidden:
344
+ yield from self.iter
345
+ else:
346
+ for rv in self.iter:
347
+ self.current_item = rv
348
+
349
+ # This allows show_item_func to be updated before the
350
+ # item is processed. Only trigger at the beginning of
351
+ # the update interval.
352
+ if self._completed_intervals == 0:
353
+ self.render_progress()
354
+
355
+ yield rv
356
+ self.update(1)
357
+
358
+ self.finish()
359
+ self.render_progress()
360
+
361
+
362
+ def pager(generator: t.Iterable[str], color: t.Optional[bool] = None) -> None:
363
+ """Decide what method to use for paging through text."""
364
+ stdout = _default_text_stdout()
365
+
366
+ # There are no standard streams attached to write to. For example,
367
+ # pythonw on Windows.
368
+ if stdout is None:
369
+ stdout = StringIO()
370
+
371
+ if not isatty(sys.stdin) or not isatty(stdout):
372
+ return _nullpager(stdout, generator, color)
373
+ pager_cmd = (os.environ.get("PAGER", None) or "").strip()
374
+ if pager_cmd:
375
+ if WIN:
376
+ if _tempfilepager(generator, pager_cmd, color):
377
+ return
378
+ elif _pipepager(generator, pager_cmd, color):
379
+ return
380
+ if os.environ.get("TERM") in ("dumb", "emacs"):
381
+ return _nullpager(stdout, generator, color)
382
+ if (WIN or sys.platform.startswith("os2")) and _tempfilepager(
383
+ generator, "more", color
384
+ ):
385
+ return
386
+ if _pipepager(generator, "less", color):
387
+ return
388
+
389
+ import tempfile
390
+
391
+ fd, filename = tempfile.mkstemp()
392
+ os.close(fd)
393
+ try:
394
+ if _pipepager(generator, "more", color):
395
+ return
396
+ return _nullpager(stdout, generator, color)
397
+ finally:
398
+ os.unlink(filename)
399
+
400
+
401
+ def _pipepager(generator: t.Iterable[str], cmd: str, color: t.Optional[bool]) -> bool:
402
+ """Page through text by feeding it to another program. Invoking a
403
+ pager through this might support colors.
404
+
405
+ Returns True if the command was found, False otherwise and thus another
406
+ pager should be attempted.
407
+ """
408
+ cmd_absolute = which(cmd)
409
+ if cmd_absolute is None:
410
+ return False
411
+
412
+ import subprocess
413
+
414
+ env = dict(os.environ)
415
+
416
+ # If we're piping to less we might support colors under the
417
+ # condition that
418
+ cmd_detail = cmd.rsplit("/", 1)[-1].split()
419
+ if color is None and cmd_detail[0] == "less":
420
+ less_flags = f"{os.environ.get('LESS', '')}{' '.join(cmd_detail[1:])}"
421
+ if not less_flags:
422
+ env["LESS"] = "-R"
423
+ color = True
424
+ elif "r" in less_flags or "R" in less_flags:
425
+ color = True
426
+
427
+ c = subprocess.Popen(
428
+ [cmd_absolute],
429
+ shell=True,
430
+ stdin=subprocess.PIPE,
431
+ env=env,
432
+ errors="replace",
433
+ text=True,
434
+ )
435
+ assert c.stdin is not None
436
+ try:
437
+ for text in generator:
438
+ if not color:
439
+ text = strip_ansi(text)
440
+
441
+ c.stdin.write(text)
442
+ except (OSError, KeyboardInterrupt):
443
+ pass
444
+ else:
445
+ c.stdin.close()
446
+
447
+ # Less doesn't respect ^C, but catches it for its own UI purposes (aborting
448
+ # search or other commands inside less).
449
+ #
450
+ # That means when the user hits ^C, the parent process (click) terminates,
451
+ # but less is still alive, paging the output and messing up the terminal.
452
+ #
453
+ # If the user wants to make the pager exit on ^C, they should set
454
+ # `LESS='-K'`. It's not our decision to make.
455
+ while True:
456
+ try:
457
+ c.wait()
458
+ except KeyboardInterrupt:
459
+ pass
460
+ else:
461
+ break
462
+
463
+ return True
464
+
465
+
466
+ def _tempfilepager(
467
+ generator: t.Iterable[str],
468
+ cmd: str,
469
+ color: t.Optional[bool],
470
+ ) -> bool:
471
+ """Page through text by invoking a program on a temporary file.
472
+
473
+ Returns True if the command was found, False otherwise and thus another
474
+ pager should be attempted.
475
+ """
476
+ # Which is necessary for Windows, it is also recommended in the Popen docs.
477
+ cmd_absolute = which(cmd)
478
+ if cmd_absolute is None:
479
+ return False
480
+
481
+ import subprocess
482
+ import tempfile
483
+
484
+ fd, filename = tempfile.mkstemp()
485
+ # TODO: This never terminates if the passed generator never terminates.
486
+ text = "".join(generator)
487
+ if not color:
488
+ text = strip_ansi(text)
489
+ encoding = get_best_encoding(sys.stdout)
490
+ with open_stream(filename, "wb")[0] as f:
491
+ f.write(text.encode(encoding))
492
+ try:
493
+ subprocess.call([cmd_absolute, filename])
494
+ except OSError:
495
+ # Command not found
496
+ pass
497
+ finally:
498
+ os.close(fd)
499
+ os.unlink(filename)
500
+
501
+ return True
502
+
503
+
504
+ def _nullpager(
505
+ stream: t.TextIO, generator: t.Iterable[str], color: t.Optional[bool]
506
+ ) -> None:
507
+ """Simply print unformatted text. This is the ultimate fallback."""
508
+ for text in generator:
509
+ if not color:
510
+ text = strip_ansi(text)
511
+ stream.write(text)
512
+
513
+
514
+ class Editor:
515
+ def __init__(
516
+ self,
517
+ editor: t.Optional[str] = None,
518
+ env: t.Optional[t.Mapping[str, str]] = None,
519
+ require_save: bool = True,
520
+ extension: str = ".txt",
521
+ ) -> None:
522
+ self.editor = editor
523
+ self.env = env
524
+ self.require_save = require_save
525
+ self.extension = extension
526
+
527
+ def get_editor(self) -> str:
528
+ if self.editor is not None:
529
+ return self.editor
530
+ for key in "VISUAL", "EDITOR":
531
+ rv = os.environ.get(key)
532
+ if rv:
533
+ return rv
534
+ if WIN:
535
+ return "notepad"
536
+ for editor in "sensible-editor", "vim", "nano":
537
+ if which(editor) is not None:
538
+ return editor
539
+ return "vi"
540
+
541
+ def edit_file(self, filename: str) -> None:
542
+ import subprocess
543
+
544
+ editor = self.get_editor()
545
+ environ: t.Optional[t.Dict[str, str]] = None
546
+
547
+ if self.env:
548
+ environ = os.environ.copy()
549
+ environ.update(self.env)
550
+
551
+ try:
552
+ c = subprocess.Popen(f'{editor} "{filename}"', env=environ, shell=True)
553
+ exit_code = c.wait()
554
+ if exit_code != 0:
555
+ raise ClickException(
556
+ _("{editor}: Editing failed").format(editor=editor)
557
+ )
558
+ except OSError as e:
559
+ raise ClickException(
560
+ _("{editor}: Editing failed: {e}").format(editor=editor, e=e)
561
+ ) from e
562
+
563
+ def edit(self, text: t.Optional[t.AnyStr]) -> t.Optional[t.AnyStr]:
564
+ import tempfile
565
+
566
+ if not text:
567
+ data = b""
568
+ elif isinstance(text, (bytes, bytearray)):
569
+ data = text
570
+ else:
571
+ if text and not text.endswith("\n"):
572
+ text += "\n"
573
+
574
+ if WIN:
575
+ data = text.replace("\n", "\r\n").encode("utf-8-sig")
576
+ else:
577
+ data = text.encode("utf-8")
578
+
579
+ fd, name = tempfile.mkstemp(prefix="editor-", suffix=self.extension)
580
+ f: t.BinaryIO
581
+
582
+ try:
583
+ with os.fdopen(fd, "wb") as f:
584
+ f.write(data)
585
+
586
+ # If the filesystem resolution is 1 second, like Mac OS
587
+ # 10.12 Extended, or 2 seconds, like FAT32, and the editor
588
+ # closes very fast, require_save can fail. Set the modified
589
+ # time to be 2 seconds in the past to work around this.
590
+ os.utime(name, (os.path.getatime(name), os.path.getmtime(name) - 2))
591
+ # Depending on the resolution, the exact value might not be
592
+ # recorded, so get the new recorded value.
593
+ timestamp = os.path.getmtime(name)
594
+
595
+ self.edit_file(name)
596
+
597
+ if self.require_save and os.path.getmtime(name) == timestamp:
598
+ return None
599
+
600
+ with open(name, "rb") as f:
601
+ rv = f.read()
602
+
603
+ if isinstance(text, (bytes, bytearray)):
604
+ return rv
605
+
606
+ return rv.decode("utf-8-sig").replace("\r\n", "\n") # type: ignore
607
+ finally:
608
+ os.unlink(name)
609
+
610
+
611
+ def open_url(url: str, wait: bool = False, locate: bool = False) -> int:
612
+ import subprocess
613
+
614
+ def _unquote_file(url: str) -> str:
615
+ from urllib.parse import unquote
616
+
617
+ if url.startswith("file://"):
618
+ url = unquote(url[7:])
619
+
620
+ return url
621
+
622
+ if sys.platform == "darwin":
623
+ args = ["open"]
624
+ if wait:
625
+ args.append("-W")
626
+ if locate:
627
+ args.append("-R")
628
+ args.append(_unquote_file(url))
629
+ null = open("/dev/null", "w")
630
+ try:
631
+ return subprocess.Popen(args, stderr=null).wait()
632
+ finally:
633
+ null.close()
634
+ elif WIN:
635
+ if locate:
636
+ url = _unquote_file(url)
637
+ args = ["explorer", f"/select,{url}"]
638
+ else:
639
+ args = ["start"]
640
+ if wait:
641
+ args.append("/WAIT")
642
+ args.append("")
643
+ args.append(url)
644
+ try:
645
+ return subprocess.call(args)
646
+ except OSError:
647
+ # Command not found
648
+ return 127
649
+ elif CYGWIN:
650
+ if locate:
651
+ url = _unquote_file(url)
652
+ args = ["cygstart", os.path.dirname(url)]
653
+ else:
654
+ args = ["cygstart"]
655
+ if wait:
656
+ args.append("-w")
657
+ args.append(url)
658
+ try:
659
+ return subprocess.call(args)
660
+ except OSError:
661
+ # Command not found
662
+ return 127
663
+
664
+ try:
665
+ if locate:
666
+ url = os.path.dirname(_unquote_file(url)) or "."
667
+ else:
668
+ url = _unquote_file(url)
669
+ c = subprocess.Popen(["xdg-open", url])
670
+ if wait:
671
+ return c.wait()
672
+ return 0
673
+ except OSError:
674
+ if url.startswith(("http://", "https://")) and not locate and not wait:
675
+ import webbrowser
676
+
677
+ webbrowser.open(url)
678
+ return 0
679
+ return 1
680
+
681
+
682
+ def _translate_ch_to_exc(ch: str) -> t.Optional[BaseException]:
683
+ if ch == "\x03":
684
+ raise KeyboardInterrupt()
685
+
686
+ if ch == "\x04" and not WIN: # Unix-like, Ctrl+D
687
+ raise EOFError()
688
+
689
+ if ch == "\x1a" and WIN: # Windows, Ctrl+Z
690
+ raise EOFError()
691
+
692
+ return None
693
+
694
+
695
+ if WIN:
696
+ import msvcrt
697
+
698
+ @contextlib.contextmanager
699
+ def raw_terminal() -> t.Iterator[int]:
700
+ yield -1
701
+
702
+ def getchar(echo: bool) -> str:
703
+ # The function `getch` will return a bytes object corresponding to
704
+ # the pressed character. Since Windows 10 build 1803, it will also
705
+ # return \x00 when called a second time after pressing a regular key.
706
+ #
707
+ # `getwch` does not share this probably-bugged behavior. Moreover, it
708
+ # returns a Unicode object by default, which is what we want.
709
+ #
710
+ # Either of these functions will return \x00 or \xe0 to indicate
711
+ # a special key, and you need to call the same function again to get
712
+ # the "rest" of the code. The fun part is that \u00e0 is
713
+ # "latin small letter a with grave", so if you type that on a French
714
+ # keyboard, you _also_ get a \xe0.
715
+ # E.g., consider the Up arrow. This returns \xe0 and then \x48. The
716
+ # resulting Unicode string reads as "a with grave" + "capital H".
717
+ # This is indistinguishable from when the user actually types
718
+ # "a with grave" and then "capital H".
719
+ #
720
+ # When \xe0 is returned, we assume it's part of a special-key sequence
721
+ # and call `getwch` again, but that means that when the user types
722
+ # the \u00e0 character, `getchar` doesn't return until a second
723
+ # character is typed.
724
+ # The alternative is returning immediately, but that would mess up
725
+ # cross-platform handling of arrow keys and others that start with
726
+ # \xe0. Another option is using `getch`, but then we can't reliably
727
+ # read non-ASCII characters, because return values of `getch` are
728
+ # limited to the current 8-bit codepage.
729
+ #
730
+ # Anyway, Click doesn't claim to do this Right(tm), and using `getwch`
731
+ # is doing the right thing in more situations than with `getch`.
732
+ func: t.Callable[[], str]
733
+
734
+ if echo:
735
+ func = msvcrt.getwche # type: ignore
736
+ else:
737
+ func = msvcrt.getwch # type: ignore
738
+
739
+ rv = func()
740
+
741
+ if rv in ("\x00", "\xe0"):
742
+ # \x00 and \xe0 are control characters that indicate special key,
743
+ # see above.
744
+ rv += func()
745
+
746
+ _translate_ch_to_exc(rv)
747
+ return rv
748
+
749
+ else:
750
+ import termios
751
+ import tty
752
+
753
+ @contextlib.contextmanager
754
+ def raw_terminal() -> t.Iterator[int]:
755
+ f: t.Optional[t.TextIO]
756
+ fd: int
757
+
758
+ if not isatty(sys.stdin):
759
+ f = open("/dev/tty")
760
+ fd = f.fileno()
761
+ else:
762
+ fd = sys.stdin.fileno()
763
+ f = None
764
+
765
+ try:
766
+ old_settings = termios.tcgetattr(fd)
767
+
768
+ try:
769
+ tty.setraw(fd)
770
+ yield fd
771
+ finally:
772
+ termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
773
+ sys.stdout.flush()
774
+
775
+ if f is not None:
776
+ f.close()
777
+ except termios.error:
778
+ pass
779
+
780
+ def getchar(echo: bool) -> str:
781
+ with raw_terminal() as fd:
782
+ ch = os.read(fd, 32).decode(get_best_encoding(sys.stdin), "replace")
783
+
784
+ if echo and isatty(sys.stdout):
785
+ sys.stdout.write(ch)
786
+
787
+ _translate_ch_to_exc(ch)
788
+ return ch
.venv/lib/python3.11/site-packages/click/decorators.py ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import types
3
+ import typing as t
4
+ from functools import update_wrapper
5
+ from gettext import gettext as _
6
+
7
+ from .core import Argument
8
+ from .core import Command
9
+ from .core import Context
10
+ from .core import Group
11
+ from .core import Option
12
+ from .core import Parameter
13
+ from .globals import get_current_context
14
+ from .utils import echo
15
+
16
+ if t.TYPE_CHECKING:
17
+ import typing_extensions as te
18
+
19
+ P = te.ParamSpec("P")
20
+
21
+ R = t.TypeVar("R")
22
+ T = t.TypeVar("T")
23
+ _AnyCallable = t.Callable[..., t.Any]
24
+ FC = t.TypeVar("FC", bound=t.Union[_AnyCallable, Command])
25
+
26
+
27
+ def pass_context(f: "t.Callable[te.Concatenate[Context, P], R]") -> "t.Callable[P, R]":
28
+ """Marks a callback as wanting to receive the current context
29
+ object as first argument.
30
+ """
31
+
32
+ def new_func(*args: "P.args", **kwargs: "P.kwargs") -> "R":
33
+ return f(get_current_context(), *args, **kwargs)
34
+
35
+ return update_wrapper(new_func, f)
36
+
37
+
38
+ def pass_obj(f: "t.Callable[te.Concatenate[t.Any, P], R]") -> "t.Callable[P, R]":
39
+ """Similar to :func:`pass_context`, but only pass the object on the
40
+ context onwards (:attr:`Context.obj`). This is useful if that object
41
+ represents the state of a nested system.
42
+ """
43
+
44
+ def new_func(*args: "P.args", **kwargs: "P.kwargs") -> "R":
45
+ return f(get_current_context().obj, *args, **kwargs)
46
+
47
+ return update_wrapper(new_func, f)
48
+
49
+
50
+ def make_pass_decorator(
51
+ object_type: t.Type[T], ensure: bool = False
52
+ ) -> t.Callable[["t.Callable[te.Concatenate[T, P], R]"], "t.Callable[P, R]"]:
53
+ """Given an object type this creates a decorator that will work
54
+ similar to :func:`pass_obj` but instead of passing the object of the
55
+ current context, it will find the innermost context of type
56
+ :func:`object_type`.
57
+
58
+ This generates a decorator that works roughly like this::
59
+
60
+ from functools import update_wrapper
61
+
62
+ def decorator(f):
63
+ @pass_context
64
+ def new_func(ctx, *args, **kwargs):
65
+ obj = ctx.find_object(object_type)
66
+ return ctx.invoke(f, obj, *args, **kwargs)
67
+ return update_wrapper(new_func, f)
68
+ return decorator
69
+
70
+ :param object_type: the type of the object to pass.
71
+ :param ensure: if set to `True`, a new object will be created and
72
+ remembered on the context if it's not there yet.
73
+ """
74
+
75
+ def decorator(f: "t.Callable[te.Concatenate[T, P], R]") -> "t.Callable[P, R]":
76
+ def new_func(*args: "P.args", **kwargs: "P.kwargs") -> "R":
77
+ ctx = get_current_context()
78
+
79
+ obj: t.Optional[T]
80
+ if ensure:
81
+ obj = ctx.ensure_object(object_type)
82
+ else:
83
+ obj = ctx.find_object(object_type)
84
+
85
+ if obj is None:
86
+ raise RuntimeError(
87
+ "Managed to invoke callback without a context"
88
+ f" object of type {object_type.__name__!r}"
89
+ " existing."
90
+ )
91
+
92
+ return ctx.invoke(f, obj, *args, **kwargs)
93
+
94
+ return update_wrapper(new_func, f)
95
+
96
+ return decorator
97
+
98
+
99
+ def pass_meta_key(
100
+ key: str, *, doc_description: t.Optional[str] = None
101
+ ) -> "t.Callable[[t.Callable[te.Concatenate[t.Any, P], R]], t.Callable[P, R]]":
102
+ """Create a decorator that passes a key from
103
+ :attr:`click.Context.meta` as the first argument to the decorated
104
+ function.
105
+
106
+ :param key: Key in ``Context.meta`` to pass.
107
+ :param doc_description: Description of the object being passed,
108
+ inserted into the decorator's docstring. Defaults to "the 'key'
109
+ key from Context.meta".
110
+
111
+ .. versionadded:: 8.0
112
+ """
113
+
114
+ def decorator(f: "t.Callable[te.Concatenate[t.Any, P], R]") -> "t.Callable[P, R]":
115
+ def new_func(*args: "P.args", **kwargs: "P.kwargs") -> R:
116
+ ctx = get_current_context()
117
+ obj = ctx.meta[key]
118
+ return ctx.invoke(f, obj, *args, **kwargs)
119
+
120
+ return update_wrapper(new_func, f)
121
+
122
+ if doc_description is None:
123
+ doc_description = f"the {key!r} key from :attr:`click.Context.meta`"
124
+
125
+ decorator.__doc__ = (
126
+ f"Decorator that passes {doc_description} as the first argument"
127
+ " to the decorated function."
128
+ )
129
+ return decorator
130
+
131
+
132
+ CmdType = t.TypeVar("CmdType", bound=Command)
133
+
134
+
135
+ # variant: no call, directly as decorator for a function.
136
+ @t.overload
137
+ def command(name: _AnyCallable) -> Command: ...
138
+
139
+
140
+ # variant: with positional name and with positional or keyword cls argument:
141
+ # @command(namearg, CommandCls, ...) or @command(namearg, cls=CommandCls, ...)
142
+ @t.overload
143
+ def command(
144
+ name: t.Optional[str],
145
+ cls: t.Type[CmdType],
146
+ **attrs: t.Any,
147
+ ) -> t.Callable[[_AnyCallable], CmdType]: ...
148
+
149
+
150
+ # variant: name omitted, cls _must_ be a keyword argument, @command(cls=CommandCls, ...)
151
+ @t.overload
152
+ def command(
153
+ name: None = None,
154
+ *,
155
+ cls: t.Type[CmdType],
156
+ **attrs: t.Any,
157
+ ) -> t.Callable[[_AnyCallable], CmdType]: ...
158
+
159
+
160
+ # variant: with optional string name, no cls argument provided.
161
+ @t.overload
162
+ def command(
163
+ name: t.Optional[str] = ..., cls: None = None, **attrs: t.Any
164
+ ) -> t.Callable[[_AnyCallable], Command]: ...
165
+
166
+
167
+ def command(
168
+ name: t.Union[t.Optional[str], _AnyCallable] = None,
169
+ cls: t.Optional[t.Type[CmdType]] = None,
170
+ **attrs: t.Any,
171
+ ) -> t.Union[Command, t.Callable[[_AnyCallable], t.Union[Command, CmdType]]]:
172
+ r"""Creates a new :class:`Command` and uses the decorated function as
173
+ callback. This will also automatically attach all decorated
174
+ :func:`option`\s and :func:`argument`\s as parameters to the command.
175
+
176
+ The name of the command defaults to the name of the function with
177
+ underscores replaced by dashes. If you want to change that, you can
178
+ pass the intended name as the first argument.
179
+
180
+ All keyword arguments are forwarded to the underlying command class.
181
+ For the ``params`` argument, any decorated params are appended to
182
+ the end of the list.
183
+
184
+ Once decorated the function turns into a :class:`Command` instance
185
+ that can be invoked as a command line utility or be attached to a
186
+ command :class:`Group`.
187
+
188
+ :param name: the name of the command. This defaults to the function
189
+ name with underscores replaced by dashes.
190
+ :param cls: the command class to instantiate. This defaults to
191
+ :class:`Command`.
192
+
193
+ .. versionchanged:: 8.1
194
+ This decorator can be applied without parentheses.
195
+
196
+ .. versionchanged:: 8.1
197
+ The ``params`` argument can be used. Decorated params are
198
+ appended to the end of the list.
199
+ """
200
+
201
+ func: t.Optional[t.Callable[[_AnyCallable], t.Any]] = None
202
+
203
+ if callable(name):
204
+ func = name
205
+ name = None
206
+ assert cls is None, "Use 'command(cls=cls)(callable)' to specify a class."
207
+ assert not attrs, "Use 'command(**kwargs)(callable)' to provide arguments."
208
+
209
+ if cls is None:
210
+ cls = t.cast(t.Type[CmdType], Command)
211
+
212
+ def decorator(f: _AnyCallable) -> CmdType:
213
+ if isinstance(f, Command):
214
+ raise TypeError("Attempted to convert a callback into a command twice.")
215
+
216
+ attr_params = attrs.pop("params", None)
217
+ params = attr_params if attr_params is not None else []
218
+
219
+ try:
220
+ decorator_params = f.__click_params__ # type: ignore
221
+ except AttributeError:
222
+ pass
223
+ else:
224
+ del f.__click_params__ # type: ignore
225
+ params.extend(reversed(decorator_params))
226
+
227
+ if attrs.get("help") is None:
228
+ attrs["help"] = f.__doc__
229
+
230
+ if t.TYPE_CHECKING:
231
+ assert cls is not None
232
+ assert not callable(name)
233
+
234
+ cmd = cls(
235
+ name=name or f.__name__.lower().replace("_", "-"),
236
+ callback=f,
237
+ params=params,
238
+ **attrs,
239
+ )
240
+ cmd.__doc__ = f.__doc__
241
+ return cmd
242
+
243
+ if func is not None:
244
+ return decorator(func)
245
+
246
+ return decorator
247
+
248
+
249
+ GrpType = t.TypeVar("GrpType", bound=Group)
250
+
251
+
252
+ # variant: no call, directly as decorator for a function.
253
+ @t.overload
254
+ def group(name: _AnyCallable) -> Group: ...
255
+
256
+
257
+ # variant: with positional name and with positional or keyword cls argument:
258
+ # @group(namearg, GroupCls, ...) or @group(namearg, cls=GroupCls, ...)
259
+ @t.overload
260
+ def group(
261
+ name: t.Optional[str],
262
+ cls: t.Type[GrpType],
263
+ **attrs: t.Any,
264
+ ) -> t.Callable[[_AnyCallable], GrpType]: ...
265
+
266
+
267
+ # variant: name omitted, cls _must_ be a keyword argument, @group(cmd=GroupCls, ...)
268
+ @t.overload
269
+ def group(
270
+ name: None = None,
271
+ *,
272
+ cls: t.Type[GrpType],
273
+ **attrs: t.Any,
274
+ ) -> t.Callable[[_AnyCallable], GrpType]: ...
275
+
276
+
277
+ # variant: with optional string name, no cls argument provided.
278
+ @t.overload
279
+ def group(
280
+ name: t.Optional[str] = ..., cls: None = None, **attrs: t.Any
281
+ ) -> t.Callable[[_AnyCallable], Group]: ...
282
+
283
+
284
+ def group(
285
+ name: t.Union[str, _AnyCallable, None] = None,
286
+ cls: t.Optional[t.Type[GrpType]] = None,
287
+ **attrs: t.Any,
288
+ ) -> t.Union[Group, t.Callable[[_AnyCallable], t.Union[Group, GrpType]]]:
289
+ """Creates a new :class:`Group` with a function as callback. This
290
+ works otherwise the same as :func:`command` just that the `cls`
291
+ parameter is set to :class:`Group`.
292
+
293
+ .. versionchanged:: 8.1
294
+ This decorator can be applied without parentheses.
295
+ """
296
+ if cls is None:
297
+ cls = t.cast(t.Type[GrpType], Group)
298
+
299
+ if callable(name):
300
+ return command(cls=cls, **attrs)(name)
301
+
302
+ return command(name, cls, **attrs)
303
+
304
+
305
+ def _param_memo(f: t.Callable[..., t.Any], param: Parameter) -> None:
306
+ if isinstance(f, Command):
307
+ f.params.append(param)
308
+ else:
309
+ if not hasattr(f, "__click_params__"):
310
+ f.__click_params__ = [] # type: ignore
311
+
312
+ f.__click_params__.append(param) # type: ignore
313
+
314
+
315
+ def argument(
316
+ *param_decls: str, cls: t.Optional[t.Type[Argument]] = None, **attrs: t.Any
317
+ ) -> t.Callable[[FC], FC]:
318
+ """Attaches an argument to the command. All positional arguments are
319
+ passed as parameter declarations to :class:`Argument`; all keyword
320
+ arguments are forwarded unchanged (except ``cls``).
321
+ This is equivalent to creating an :class:`Argument` instance manually
322
+ and attaching it to the :attr:`Command.params` list.
323
+
324
+ For the default argument class, refer to :class:`Argument` and
325
+ :class:`Parameter` for descriptions of parameters.
326
+
327
+ :param cls: the argument class to instantiate. This defaults to
328
+ :class:`Argument`.
329
+ :param param_decls: Passed as positional arguments to the constructor of
330
+ ``cls``.
331
+ :param attrs: Passed as keyword arguments to the constructor of ``cls``.
332
+ """
333
+ if cls is None:
334
+ cls = Argument
335
+
336
+ def decorator(f: FC) -> FC:
337
+ _param_memo(f, cls(param_decls, **attrs))
338
+ return f
339
+
340
+ return decorator
341
+
342
+
343
+ def option(
344
+ *param_decls: str, cls: t.Optional[t.Type[Option]] = None, **attrs: t.Any
345
+ ) -> t.Callable[[FC], FC]:
346
+ """Attaches an option to the command. All positional arguments are
347
+ passed as parameter declarations to :class:`Option`; all keyword
348
+ arguments are forwarded unchanged (except ``cls``).
349
+ This is equivalent to creating an :class:`Option` instance manually
350
+ and attaching it to the :attr:`Command.params` list.
351
+
352
+ For the default option class, refer to :class:`Option` and
353
+ :class:`Parameter` for descriptions of parameters.
354
+
355
+ :param cls: the option class to instantiate. This defaults to
356
+ :class:`Option`.
357
+ :param param_decls: Passed as positional arguments to the constructor of
358
+ ``cls``.
359
+ :param attrs: Passed as keyword arguments to the constructor of ``cls``.
360
+ """
361
+ if cls is None:
362
+ cls = Option
363
+
364
+ def decorator(f: FC) -> FC:
365
+ _param_memo(f, cls(param_decls, **attrs))
366
+ return f
367
+
368
+ return decorator
369
+
370
+
371
+ def confirmation_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
372
+ """Add a ``--yes`` option which shows a prompt before continuing if
373
+ not passed. If the prompt is declined, the program will exit.
374
+
375
+ :param param_decls: One or more option names. Defaults to the single
376
+ value ``"--yes"``.
377
+ :param kwargs: Extra arguments are passed to :func:`option`.
378
+ """
379
+
380
+ def callback(ctx: Context, param: Parameter, value: bool) -> None:
381
+ if not value:
382
+ ctx.abort()
383
+
384
+ if not param_decls:
385
+ param_decls = ("--yes",)
386
+
387
+ kwargs.setdefault("is_flag", True)
388
+ kwargs.setdefault("callback", callback)
389
+ kwargs.setdefault("expose_value", False)
390
+ kwargs.setdefault("prompt", "Do you want to continue?")
391
+ kwargs.setdefault("help", "Confirm the action without prompting.")
392
+ return option(*param_decls, **kwargs)
393
+
394
+
395
+ def password_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
396
+ """Add a ``--password`` option which prompts for a password, hiding
397
+ input and asking to enter the value again for confirmation.
398
+
399
+ :param param_decls: One or more option names. Defaults to the single
400
+ value ``"--password"``.
401
+ :param kwargs: Extra arguments are passed to :func:`option`.
402
+ """
403
+ if not param_decls:
404
+ param_decls = ("--password",)
405
+
406
+ kwargs.setdefault("prompt", True)
407
+ kwargs.setdefault("confirmation_prompt", True)
408
+ kwargs.setdefault("hide_input", True)
409
+ return option(*param_decls, **kwargs)
410
+
411
+
412
+ def version_option(
413
+ version: t.Optional[str] = None,
414
+ *param_decls: str,
415
+ package_name: t.Optional[str] = None,
416
+ prog_name: t.Optional[str] = None,
417
+ message: t.Optional[str] = None,
418
+ **kwargs: t.Any,
419
+ ) -> t.Callable[[FC], FC]:
420
+ """Add a ``--version`` option which immediately prints the version
421
+ number and exits the program.
422
+
423
+ If ``version`` is not provided, Click will try to detect it using
424
+ :func:`importlib.metadata.version` to get the version for the
425
+ ``package_name``. On Python < 3.8, the ``importlib_metadata``
426
+ backport must be installed.
427
+
428
+ If ``package_name`` is not provided, Click will try to detect it by
429
+ inspecting the stack frames. This will be used to detect the
430
+ version, so it must match the name of the installed package.
431
+
432
+ :param version: The version number to show. If not provided, Click
433
+ will try to detect it.
434
+ :param param_decls: One or more option names. Defaults to the single
435
+ value ``"--version"``.
436
+ :param package_name: The package name to detect the version from. If
437
+ not provided, Click will try to detect it.
438
+ :param prog_name: The name of the CLI to show in the message. If not
439
+ provided, it will be detected from the command.
440
+ :param message: The message to show. The values ``%(prog)s``,
441
+ ``%(package)s``, and ``%(version)s`` are available. Defaults to
442
+ ``"%(prog)s, version %(version)s"``.
443
+ :param kwargs: Extra arguments are passed to :func:`option`.
444
+ :raise RuntimeError: ``version`` could not be detected.
445
+
446
+ .. versionchanged:: 8.0
447
+ Add the ``package_name`` parameter, and the ``%(package)s``
448
+ value for messages.
449
+
450
+ .. versionchanged:: 8.0
451
+ Use :mod:`importlib.metadata` instead of ``pkg_resources``. The
452
+ version is detected based on the package name, not the entry
453
+ point name. The Python package name must match the installed
454
+ package name, or be passed with ``package_name=``.
455
+ """
456
+ if message is None:
457
+ message = _("%(prog)s, version %(version)s")
458
+
459
+ if version is None and package_name is None:
460
+ frame = inspect.currentframe()
461
+ f_back = frame.f_back if frame is not None else None
462
+ f_globals = f_back.f_globals if f_back is not None else None
463
+ # break reference cycle
464
+ # https://docs.python.org/3/library/inspect.html#the-interpreter-stack
465
+ del frame
466
+
467
+ if f_globals is not None:
468
+ package_name = f_globals.get("__name__")
469
+
470
+ if package_name == "__main__":
471
+ package_name = f_globals.get("__package__")
472
+
473
+ if package_name:
474
+ package_name = package_name.partition(".")[0]
475
+
476
+ def callback(ctx: Context, param: Parameter, value: bool) -> None:
477
+ if not value or ctx.resilient_parsing:
478
+ return
479
+
480
+ nonlocal prog_name
481
+ nonlocal version
482
+
483
+ if prog_name is None:
484
+ prog_name = ctx.find_root().info_name
485
+
486
+ if version is None and package_name is not None:
487
+ metadata: t.Optional[types.ModuleType]
488
+
489
+ try:
490
+ from importlib import metadata
491
+ except ImportError:
492
+ # Python < 3.8
493
+ import importlib_metadata as metadata # type: ignore
494
+
495
+ try:
496
+ version = metadata.version(package_name) # type: ignore
497
+ except metadata.PackageNotFoundError: # type: ignore
498
+ raise RuntimeError(
499
+ f"{package_name!r} is not installed. Try passing"
500
+ " 'package_name' instead."
501
+ ) from None
502
+
503
+ if version is None:
504
+ raise RuntimeError(
505
+ f"Could not determine the version for {package_name!r} automatically."
506
+ )
507
+
508
+ echo(
509
+ message % {"prog": prog_name, "package": package_name, "version": version},
510
+ color=ctx.color,
511
+ )
512
+ ctx.exit()
513
+
514
+ if not param_decls:
515
+ param_decls = ("--version",)
516
+
517
+ kwargs.setdefault("is_flag", True)
518
+ kwargs.setdefault("expose_value", False)
519
+ kwargs.setdefault("is_eager", True)
520
+ kwargs.setdefault("help", _("Show the version and exit."))
521
+ kwargs["callback"] = callback
522
+ return option(*param_decls, **kwargs)
523
+
524
+
525
+ class HelpOption(Option):
526
+ """Pre-configured ``--help`` option which immediately prints the help page
527
+ and exits the program.
528
+ """
529
+
530
+ def __init__(
531
+ self,
532
+ param_decls: t.Optional[t.Sequence[str]] = None,
533
+ **kwargs: t.Any,
534
+ ) -> None:
535
+ if not param_decls:
536
+ param_decls = ("--help",)
537
+
538
+ kwargs.setdefault("is_flag", True)
539
+ kwargs.setdefault("expose_value", False)
540
+ kwargs.setdefault("is_eager", True)
541
+ kwargs.setdefault("help", _("Show this message and exit."))
542
+ kwargs.setdefault("callback", self.show_help)
543
+
544
+ super().__init__(param_decls, **kwargs)
545
+
546
+ @staticmethod
547
+ def show_help(ctx: Context, param: Parameter, value: bool) -> None:
548
+ """Callback that print the help page on ``<stdout>`` and exits."""
549
+ if value and not ctx.resilient_parsing:
550
+ echo(ctx.get_help(), color=ctx.color)
551
+ ctx.exit()
552
+
553
+
554
+ def help_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
555
+ """Decorator for the pre-configured ``--help`` option defined above.
556
+
557
+ :param param_decls: One or more option names. Defaults to the single
558
+ value ``"--help"``.
559
+ :param kwargs: Extra arguments are passed to :func:`option`.
560
+ """
561
+ kwargs.setdefault("cls", HelpOption)
562
+ return option(*param_decls, **kwargs)
.venv/lib/python3.11/site-packages/click/py.typed ADDED
File without changes
.venv/lib/python3.11/site-packages/click/shell_completion.py ADDED
@@ -0,0 +1,603 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import typing as t
4
+ from gettext import gettext as _
5
+
6
+ from .core import Argument
7
+ from .core import BaseCommand
8
+ from .core import Context
9
+ from .core import MultiCommand
10
+ from .core import Option
11
+ from .core import Parameter
12
+ from .core import ParameterSource
13
+ from .parser import split_arg_string
14
+ from .utils import echo
15
+
16
+
17
+ def shell_complete(
18
+ cli: BaseCommand,
19
+ ctx_args: t.MutableMapping[str, t.Any],
20
+ prog_name: str,
21
+ complete_var: str,
22
+ instruction: str,
23
+ ) -> int:
24
+ """Perform shell completion for the given CLI program.
25
+
26
+ :param cli: Command being called.
27
+ :param ctx_args: Extra arguments to pass to
28
+ ``cli.make_context``.
29
+ :param prog_name: Name of the executable in the shell.
30
+ :param complete_var: Name of the environment variable that holds
31
+ the completion instruction.
32
+ :param instruction: Value of ``complete_var`` with the completion
33
+ instruction and shell, in the form ``instruction_shell``.
34
+ :return: Status code to exit with.
35
+ """
36
+ shell, _, instruction = instruction.partition("_")
37
+ comp_cls = get_completion_class(shell)
38
+
39
+ if comp_cls is None:
40
+ return 1
41
+
42
+ comp = comp_cls(cli, ctx_args, prog_name, complete_var)
43
+
44
+ if instruction == "source":
45
+ echo(comp.source())
46
+ return 0
47
+
48
+ if instruction == "complete":
49
+ echo(comp.complete())
50
+ return 0
51
+
52
+ return 1
53
+
54
+
55
+ class CompletionItem:
56
+ """Represents a completion value and metadata about the value. The
57
+ default metadata is ``type`` to indicate special shell handling,
58
+ and ``help`` if a shell supports showing a help string next to the
59
+ value.
60
+
61
+ Arbitrary parameters can be passed when creating the object, and
62
+ accessed using ``item.attr``. If an attribute wasn't passed,
63
+ accessing it returns ``None``.
64
+
65
+ :param value: The completion suggestion.
66
+ :param type: Tells the shell script to provide special completion
67
+ support for the type. Click uses ``"dir"`` and ``"file"``.
68
+ :param help: String shown next to the value if supported.
69
+ :param kwargs: Arbitrary metadata. The built-in implementations
70
+ don't use this, but custom type completions paired with custom
71
+ shell support could use it.
72
+ """
73
+
74
+ __slots__ = ("value", "type", "help", "_info")
75
+
76
+ def __init__(
77
+ self,
78
+ value: t.Any,
79
+ type: str = "plain",
80
+ help: t.Optional[str] = None,
81
+ **kwargs: t.Any,
82
+ ) -> None:
83
+ self.value: t.Any = value
84
+ self.type: str = type
85
+ self.help: t.Optional[str] = help
86
+ self._info = kwargs
87
+
88
+ def __getattr__(self, name: str) -> t.Any:
89
+ return self._info.get(name)
90
+
91
+
92
+ # Only Bash >= 4.4 has the nosort option.
93
+ _SOURCE_BASH = """\
94
+ %(complete_func)s() {
95
+ local IFS=$'\\n'
96
+ local response
97
+
98
+ response=$(env COMP_WORDS="${COMP_WORDS[*]}" COMP_CWORD=$COMP_CWORD \
99
+ %(complete_var)s=bash_complete $1)
100
+
101
+ for completion in $response; do
102
+ IFS=',' read type value <<< "$completion"
103
+
104
+ if [[ $type == 'dir' ]]; then
105
+ COMPREPLY=()
106
+ compopt -o dirnames
107
+ elif [[ $type == 'file' ]]; then
108
+ COMPREPLY=()
109
+ compopt -o default
110
+ elif [[ $type == 'plain' ]]; then
111
+ COMPREPLY+=($value)
112
+ fi
113
+ done
114
+
115
+ return 0
116
+ }
117
+
118
+ %(complete_func)s_setup() {
119
+ complete -o nosort -F %(complete_func)s %(prog_name)s
120
+ }
121
+
122
+ %(complete_func)s_setup;
123
+ """
124
+
125
+ _SOURCE_ZSH = """\
126
+ #compdef %(prog_name)s
127
+
128
+ %(complete_func)s() {
129
+ local -a completions
130
+ local -a completions_with_descriptions
131
+ local -a response
132
+ (( ! $+commands[%(prog_name)s] )) && return 1
133
+
134
+ response=("${(@f)$(env COMP_WORDS="${words[*]}" COMP_CWORD=$((CURRENT-1)) \
135
+ %(complete_var)s=zsh_complete %(prog_name)s)}")
136
+
137
+ for type key descr in ${response}; do
138
+ if [[ "$type" == "plain" ]]; then
139
+ if [[ "$descr" == "_" ]]; then
140
+ completions+=("$key")
141
+ else
142
+ completions_with_descriptions+=("$key":"$descr")
143
+ fi
144
+ elif [[ "$type" == "dir" ]]; then
145
+ _path_files -/
146
+ elif [[ "$type" == "file" ]]; then
147
+ _path_files -f
148
+ fi
149
+ done
150
+
151
+ if [ -n "$completions_with_descriptions" ]; then
152
+ _describe -V unsorted completions_with_descriptions -U
153
+ fi
154
+
155
+ if [ -n "$completions" ]; then
156
+ compadd -U -V unsorted -a completions
157
+ fi
158
+ }
159
+
160
+ if [[ $zsh_eval_context[-1] == loadautofunc ]]; then
161
+ # autoload from fpath, call function directly
162
+ %(complete_func)s "$@"
163
+ else
164
+ # eval/source/. command, register function for later
165
+ compdef %(complete_func)s %(prog_name)s
166
+ fi
167
+ """
168
+
169
+ _SOURCE_FISH = """\
170
+ function %(complete_func)s;
171
+ set -l response (env %(complete_var)s=fish_complete COMP_WORDS=(commandline -cp) \
172
+ COMP_CWORD=(commandline -t) %(prog_name)s);
173
+
174
+ for completion in $response;
175
+ set -l metadata (string split "," $completion);
176
+
177
+ if test $metadata[1] = "dir";
178
+ __fish_complete_directories $metadata[2];
179
+ else if test $metadata[1] = "file";
180
+ __fish_complete_path $metadata[2];
181
+ else if test $metadata[1] = "plain";
182
+ echo $metadata[2];
183
+ end;
184
+ end;
185
+ end;
186
+
187
+ complete --no-files --command %(prog_name)s --arguments \
188
+ "(%(complete_func)s)";
189
+ """
190
+
191
+
192
+ class ShellComplete:
193
+ """Base class for providing shell completion support. A subclass for
194
+ a given shell will override attributes and methods to implement the
195
+ completion instructions (``source`` and ``complete``).
196
+
197
+ :param cli: Command being called.
198
+ :param prog_name: Name of the executable in the shell.
199
+ :param complete_var: Name of the environment variable that holds
200
+ the completion instruction.
201
+
202
+ .. versionadded:: 8.0
203
+ """
204
+
205
+ name: t.ClassVar[str]
206
+ """Name to register the shell as with :func:`add_completion_class`.
207
+ This is used in completion instructions (``{name}_source`` and
208
+ ``{name}_complete``).
209
+ """
210
+
211
+ source_template: t.ClassVar[str]
212
+ """Completion script template formatted by :meth:`source`. This must
213
+ be provided by subclasses.
214
+ """
215
+
216
+ def __init__(
217
+ self,
218
+ cli: BaseCommand,
219
+ ctx_args: t.MutableMapping[str, t.Any],
220
+ prog_name: str,
221
+ complete_var: str,
222
+ ) -> None:
223
+ self.cli = cli
224
+ self.ctx_args = ctx_args
225
+ self.prog_name = prog_name
226
+ self.complete_var = complete_var
227
+
228
+ @property
229
+ def func_name(self) -> str:
230
+ """The name of the shell function defined by the completion
231
+ script.
232
+ """
233
+ safe_name = re.sub(r"\W*", "", self.prog_name.replace("-", "_"), flags=re.ASCII)
234
+ return f"_{safe_name}_completion"
235
+
236
+ def source_vars(self) -> t.Dict[str, t.Any]:
237
+ """Vars for formatting :attr:`source_template`.
238
+
239
+ By default this provides ``complete_func``, ``complete_var``,
240
+ and ``prog_name``.
241
+ """
242
+ return {
243
+ "complete_func": self.func_name,
244
+ "complete_var": self.complete_var,
245
+ "prog_name": self.prog_name,
246
+ }
247
+
248
+ def source(self) -> str:
249
+ """Produce the shell script that defines the completion
250
+ function. By default this ``%``-style formats
251
+ :attr:`source_template` with the dict returned by
252
+ :meth:`source_vars`.
253
+ """
254
+ return self.source_template % self.source_vars()
255
+
256
+ def get_completion_args(self) -> t.Tuple[t.List[str], str]:
257
+ """Use the env vars defined by the shell script to return a
258
+ tuple of ``args, incomplete``. This must be implemented by
259
+ subclasses.
260
+ """
261
+ raise NotImplementedError
262
+
263
+ def get_completions(
264
+ self, args: t.List[str], incomplete: str
265
+ ) -> t.List[CompletionItem]:
266
+ """Determine the context and last complete command or parameter
267
+ from the complete args. Call that object's ``shell_complete``
268
+ method to get the completions for the incomplete value.
269
+
270
+ :param args: List of complete args before the incomplete value.
271
+ :param incomplete: Value being completed. May be empty.
272
+ """
273
+ ctx = _resolve_context(self.cli, self.ctx_args, self.prog_name, args)
274
+ obj, incomplete = _resolve_incomplete(ctx, args, incomplete)
275
+ return obj.shell_complete(ctx, incomplete)
276
+
277
+ def format_completion(self, item: CompletionItem) -> str:
278
+ """Format a completion item into the form recognized by the
279
+ shell script. This must be implemented by subclasses.
280
+
281
+ :param item: Completion item to format.
282
+ """
283
+ raise NotImplementedError
284
+
285
+ def complete(self) -> str:
286
+ """Produce the completion data to send back to the shell.
287
+
288
+ By default this calls :meth:`get_completion_args`, gets the
289
+ completions, then calls :meth:`format_completion` for each
290
+ completion.
291
+ """
292
+ args, incomplete = self.get_completion_args()
293
+ completions = self.get_completions(args, incomplete)
294
+ out = [self.format_completion(item) for item in completions]
295
+ return "\n".join(out)
296
+
297
+
298
+ class BashComplete(ShellComplete):
299
+ """Shell completion for Bash."""
300
+
301
+ name = "bash"
302
+ source_template = _SOURCE_BASH
303
+
304
+ @staticmethod
305
+ def _check_version() -> None:
306
+ import shutil
307
+ import subprocess
308
+
309
+ bash_exe = shutil.which("bash")
310
+
311
+ if bash_exe is None:
312
+ match = None
313
+ else:
314
+ output = subprocess.run(
315
+ [bash_exe, "--norc", "-c", 'echo "${BASH_VERSION}"'],
316
+ stdout=subprocess.PIPE,
317
+ )
318
+ match = re.search(r"^(\d+)\.(\d+)\.\d+", output.stdout.decode())
319
+
320
+ if match is not None:
321
+ major, minor = match.groups()
322
+
323
+ if major < "4" or major == "4" and minor < "4":
324
+ echo(
325
+ _(
326
+ "Shell completion is not supported for Bash"
327
+ " versions older than 4.4."
328
+ ),
329
+ err=True,
330
+ )
331
+ else:
332
+ echo(
333
+ _("Couldn't detect Bash version, shell completion is not supported."),
334
+ err=True,
335
+ )
336
+
337
+ def source(self) -> str:
338
+ self._check_version()
339
+ return super().source()
340
+
341
+ def get_completion_args(self) -> t.Tuple[t.List[str], str]:
342
+ cwords = split_arg_string(os.environ["COMP_WORDS"])
343
+ cword = int(os.environ["COMP_CWORD"])
344
+ args = cwords[1:cword]
345
+
346
+ try:
347
+ incomplete = cwords[cword]
348
+ except IndexError:
349
+ incomplete = ""
350
+
351
+ return args, incomplete
352
+
353
+ def format_completion(self, item: CompletionItem) -> str:
354
+ return f"{item.type},{item.value}"
355
+
356
+
357
+ class ZshComplete(ShellComplete):
358
+ """Shell completion for Zsh."""
359
+
360
+ name = "zsh"
361
+ source_template = _SOURCE_ZSH
362
+
363
+ def get_completion_args(self) -> t.Tuple[t.List[str], str]:
364
+ cwords = split_arg_string(os.environ["COMP_WORDS"])
365
+ cword = int(os.environ["COMP_CWORD"])
366
+ args = cwords[1:cword]
367
+
368
+ try:
369
+ incomplete = cwords[cword]
370
+ except IndexError:
371
+ incomplete = ""
372
+
373
+ return args, incomplete
374
+
375
+ def format_completion(self, item: CompletionItem) -> str:
376
+ return f"{item.type}\n{item.value}\n{item.help if item.help else '_'}"
377
+
378
+
379
+ class FishComplete(ShellComplete):
380
+ """Shell completion for Fish."""
381
+
382
+ name = "fish"
383
+ source_template = _SOURCE_FISH
384
+
385
+ def get_completion_args(self) -> t.Tuple[t.List[str], str]:
386
+ cwords = split_arg_string(os.environ["COMP_WORDS"])
387
+ incomplete = os.environ["COMP_CWORD"]
388
+ args = cwords[1:]
389
+
390
+ # Fish stores the partial word in both COMP_WORDS and
391
+ # COMP_CWORD, remove it from complete args.
392
+ if incomplete and args and args[-1] == incomplete:
393
+ args.pop()
394
+
395
+ return args, incomplete
396
+
397
+ def format_completion(self, item: CompletionItem) -> str:
398
+ if item.help:
399
+ return f"{item.type},{item.value}\t{item.help}"
400
+
401
+ return f"{item.type},{item.value}"
402
+
403
+
404
+ ShellCompleteType = t.TypeVar("ShellCompleteType", bound=t.Type[ShellComplete])
405
+
406
+
407
+ _available_shells: t.Dict[str, t.Type[ShellComplete]] = {
408
+ "bash": BashComplete,
409
+ "fish": FishComplete,
410
+ "zsh": ZshComplete,
411
+ }
412
+
413
+
414
+ def add_completion_class(
415
+ cls: ShellCompleteType, name: t.Optional[str] = None
416
+ ) -> ShellCompleteType:
417
+ """Register a :class:`ShellComplete` subclass under the given name.
418
+ The name will be provided by the completion instruction environment
419
+ variable during completion.
420
+
421
+ :param cls: The completion class that will handle completion for the
422
+ shell.
423
+ :param name: Name to register the class under. Defaults to the
424
+ class's ``name`` attribute.
425
+ """
426
+ if name is None:
427
+ name = cls.name
428
+
429
+ _available_shells[name] = cls
430
+
431
+ return cls
432
+
433
+
434
+ def get_completion_class(shell: str) -> t.Optional[t.Type[ShellComplete]]:
435
+ """Look up a registered :class:`ShellComplete` subclass by the name
436
+ provided by the completion instruction environment variable. If the
437
+ name isn't registered, returns ``None``.
438
+
439
+ :param shell: Name the class is registered under.
440
+ """
441
+ return _available_shells.get(shell)
442
+
443
+
444
+ def _is_incomplete_argument(ctx: Context, param: Parameter) -> bool:
445
+ """Determine if the given parameter is an argument that can still
446
+ accept values.
447
+
448
+ :param ctx: Invocation context for the command represented by the
449
+ parsed complete args.
450
+ :param param: Argument object being checked.
451
+ """
452
+ if not isinstance(param, Argument):
453
+ return False
454
+
455
+ assert param.name is not None
456
+ # Will be None if expose_value is False.
457
+ value = ctx.params.get(param.name)
458
+ return (
459
+ param.nargs == -1
460
+ or ctx.get_parameter_source(param.name) is not ParameterSource.COMMANDLINE
461
+ or (
462
+ param.nargs > 1
463
+ and isinstance(value, (tuple, list))
464
+ and len(value) < param.nargs
465
+ )
466
+ )
467
+
468
+
469
+ def _start_of_option(ctx: Context, value: str) -> bool:
470
+ """Check if the value looks like the start of an option."""
471
+ if not value:
472
+ return False
473
+
474
+ c = value[0]
475
+ return c in ctx._opt_prefixes
476
+
477
+
478
+ def _is_incomplete_option(ctx: Context, args: t.List[str], param: Parameter) -> bool:
479
+ """Determine if the given parameter is an option that needs a value.
480
+
481
+ :param args: List of complete args before the incomplete value.
482
+ :param param: Option object being checked.
483
+ """
484
+ if not isinstance(param, Option):
485
+ return False
486
+
487
+ if param.is_flag or param.count:
488
+ return False
489
+
490
+ last_option = None
491
+
492
+ for index, arg in enumerate(reversed(args)):
493
+ if index + 1 > param.nargs:
494
+ break
495
+
496
+ if _start_of_option(ctx, arg):
497
+ last_option = arg
498
+
499
+ return last_option is not None and last_option in param.opts
500
+
501
+
502
+ def _resolve_context(
503
+ cli: BaseCommand,
504
+ ctx_args: t.MutableMapping[str, t.Any],
505
+ prog_name: str,
506
+ args: t.List[str],
507
+ ) -> Context:
508
+ """Produce the context hierarchy starting with the command and
509
+ traversing the complete arguments. This only follows the commands,
510
+ it doesn't trigger input prompts or callbacks.
511
+
512
+ :param cli: Command being called.
513
+ :param prog_name: Name of the executable in the shell.
514
+ :param args: List of complete args before the incomplete value.
515
+ """
516
+ ctx_args["resilient_parsing"] = True
517
+ ctx = cli.make_context(prog_name, args.copy(), **ctx_args)
518
+ args = ctx.protected_args + ctx.args
519
+
520
+ while args:
521
+ command = ctx.command
522
+
523
+ if isinstance(command, MultiCommand):
524
+ if not command.chain:
525
+ name, cmd, args = command.resolve_command(ctx, args)
526
+
527
+ if cmd is None:
528
+ return ctx
529
+
530
+ ctx = cmd.make_context(name, args, parent=ctx, resilient_parsing=True)
531
+ args = ctx.protected_args + ctx.args
532
+ else:
533
+ sub_ctx = ctx
534
+
535
+ while args:
536
+ name, cmd, args = command.resolve_command(ctx, args)
537
+
538
+ if cmd is None:
539
+ return ctx
540
+
541
+ sub_ctx = cmd.make_context(
542
+ name,
543
+ args,
544
+ parent=ctx,
545
+ allow_extra_args=True,
546
+ allow_interspersed_args=False,
547
+ resilient_parsing=True,
548
+ )
549
+ args = sub_ctx.args
550
+
551
+ ctx = sub_ctx
552
+ args = [*sub_ctx.protected_args, *sub_ctx.args]
553
+ else:
554
+ break
555
+
556
+ return ctx
557
+
558
+
559
+ def _resolve_incomplete(
560
+ ctx: Context, args: t.List[str], incomplete: str
561
+ ) -> t.Tuple[t.Union[BaseCommand, Parameter], str]:
562
+ """Find the Click object that will handle the completion of the
563
+ incomplete value. Return the object and the incomplete value.
564
+
565
+ :param ctx: Invocation context for the command represented by
566
+ the parsed complete args.
567
+ :param args: List of complete args before the incomplete value.
568
+ :param incomplete: Value being completed. May be empty.
569
+ """
570
+ # Different shells treat an "=" between a long option name and
571
+ # value differently. Might keep the value joined, return the "="
572
+ # as a separate item, or return the split name and value. Always
573
+ # split and discard the "=" to make completion easier.
574
+ if incomplete == "=":
575
+ incomplete = ""
576
+ elif "=" in incomplete and _start_of_option(ctx, incomplete):
577
+ name, _, incomplete = incomplete.partition("=")
578
+ args.append(name)
579
+
580
+ # The "--" marker tells Click to stop treating values as options
581
+ # even if they start with the option character. If it hasn't been
582
+ # given and the incomplete arg looks like an option, the current
583
+ # command will provide option name completions.
584
+ if "--" not in args and _start_of_option(ctx, incomplete):
585
+ return ctx.command, incomplete
586
+
587
+ params = ctx.command.get_params(ctx)
588
+
589
+ # If the last complete arg is an option name with an incomplete
590
+ # value, the option will provide value completions.
591
+ for param in params:
592
+ if _is_incomplete_option(ctx, args, param):
593
+ return param, incomplete
594
+
595
+ # It's not an option name or value. The first argument without a
596
+ # parsed value will provide value completions.
597
+ for param in params:
598
+ if _is_incomplete_argument(ctx, param):
599
+ return param, incomplete
600
+
601
+ # There were no unparsed arguments, the command may be a group that
602
+ # will provide command name completions.
603
+ return ctx.command, incomplete
.venv/lib/python3.11/site-packages/click/termui.py ADDED
@@ -0,0 +1,784 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import io
3
+ import itertools
4
+ import sys
5
+ import typing as t
6
+ from gettext import gettext as _
7
+
8
+ from ._compat import isatty
9
+ from ._compat import strip_ansi
10
+ from .exceptions import Abort
11
+ from .exceptions import UsageError
12
+ from .globals import resolve_color_default
13
+ from .types import Choice
14
+ from .types import convert_type
15
+ from .types import ParamType
16
+ from .utils import echo
17
+ from .utils import LazyFile
18
+
19
+ if t.TYPE_CHECKING:
20
+ from ._termui_impl import ProgressBar
21
+
22
+ V = t.TypeVar("V")
23
+
24
+ # The prompt functions to use. The doc tools currently override these
25
+ # functions to customize how they work.
26
+ visible_prompt_func: t.Callable[[str], str] = input
27
+
28
+ _ansi_colors = {
29
+ "black": 30,
30
+ "red": 31,
31
+ "green": 32,
32
+ "yellow": 33,
33
+ "blue": 34,
34
+ "magenta": 35,
35
+ "cyan": 36,
36
+ "white": 37,
37
+ "reset": 39,
38
+ "bright_black": 90,
39
+ "bright_red": 91,
40
+ "bright_green": 92,
41
+ "bright_yellow": 93,
42
+ "bright_blue": 94,
43
+ "bright_magenta": 95,
44
+ "bright_cyan": 96,
45
+ "bright_white": 97,
46
+ }
47
+ _ansi_reset_all = "\033[0m"
48
+
49
+
50
+ def hidden_prompt_func(prompt: str) -> str:
51
+ import getpass
52
+
53
+ return getpass.getpass(prompt)
54
+
55
+
56
+ def _build_prompt(
57
+ text: str,
58
+ suffix: str,
59
+ show_default: bool = False,
60
+ default: t.Optional[t.Any] = None,
61
+ show_choices: bool = True,
62
+ type: t.Optional[ParamType] = None,
63
+ ) -> str:
64
+ prompt = text
65
+ if type is not None and show_choices and isinstance(type, Choice):
66
+ prompt += f" ({', '.join(map(str, type.choices))})"
67
+ if default is not None and show_default:
68
+ prompt = f"{prompt} [{_format_default(default)}]"
69
+ return f"{prompt}{suffix}"
70
+
71
+
72
+ def _format_default(default: t.Any) -> t.Any:
73
+ if isinstance(default, (io.IOBase, LazyFile)) and hasattr(default, "name"):
74
+ return default.name
75
+
76
+ return default
77
+
78
+
79
+ def prompt(
80
+ text: str,
81
+ default: t.Optional[t.Any] = None,
82
+ hide_input: bool = False,
83
+ confirmation_prompt: t.Union[bool, str] = False,
84
+ type: t.Optional[t.Union[ParamType, t.Any]] = None,
85
+ value_proc: t.Optional[t.Callable[[str], t.Any]] = None,
86
+ prompt_suffix: str = ": ",
87
+ show_default: bool = True,
88
+ err: bool = False,
89
+ show_choices: bool = True,
90
+ ) -> t.Any:
91
+ """Prompts a user for input. This is a convenience function that can
92
+ be used to prompt a user for input later.
93
+
94
+ If the user aborts the input by sending an interrupt signal, this
95
+ function will catch it and raise a :exc:`Abort` exception.
96
+
97
+ :param text: the text to show for the prompt.
98
+ :param default: the default value to use if no input happens. If this
99
+ is not given it will prompt until it's aborted.
100
+ :param hide_input: if this is set to true then the input value will
101
+ be hidden.
102
+ :param confirmation_prompt: Prompt a second time to confirm the
103
+ value. Can be set to a string instead of ``True`` to customize
104
+ the message.
105
+ :param type: the type to use to check the value against.
106
+ :param value_proc: if this parameter is provided it's a function that
107
+ is invoked instead of the type conversion to
108
+ convert a value.
109
+ :param prompt_suffix: a suffix that should be added to the prompt.
110
+ :param show_default: shows or hides the default value in the prompt.
111
+ :param err: if set to true the file defaults to ``stderr`` instead of
112
+ ``stdout``, the same as with echo.
113
+ :param show_choices: Show or hide choices if the passed type is a Choice.
114
+ For example if type is a Choice of either day or week,
115
+ show_choices is true and text is "Group by" then the
116
+ prompt will be "Group by (day, week): ".
117
+
118
+ .. versionadded:: 8.0
119
+ ``confirmation_prompt`` can be a custom string.
120
+
121
+ .. versionadded:: 7.0
122
+ Added the ``show_choices`` parameter.
123
+
124
+ .. versionadded:: 6.0
125
+ Added unicode support for cmd.exe on Windows.
126
+
127
+ .. versionadded:: 4.0
128
+ Added the `err` parameter.
129
+
130
+ """
131
+
132
+ def prompt_func(text: str) -> str:
133
+ f = hidden_prompt_func if hide_input else visible_prompt_func
134
+ try:
135
+ # Write the prompt separately so that we get nice
136
+ # coloring through colorama on Windows
137
+ echo(text.rstrip(" "), nl=False, err=err)
138
+ # Echo a space to stdout to work around an issue where
139
+ # readline causes backspace to clear the whole line.
140
+ return f(" ")
141
+ except (KeyboardInterrupt, EOFError):
142
+ # getpass doesn't print a newline if the user aborts input with ^C.
143
+ # Allegedly this behavior is inherited from getpass(3).
144
+ # A doc bug has been filed at https://bugs.python.org/issue24711
145
+ if hide_input:
146
+ echo(None, err=err)
147
+ raise Abort() from None
148
+
149
+ if value_proc is None:
150
+ value_proc = convert_type(type, default)
151
+
152
+ prompt = _build_prompt(
153
+ text, prompt_suffix, show_default, default, show_choices, type
154
+ )
155
+
156
+ if confirmation_prompt:
157
+ if confirmation_prompt is True:
158
+ confirmation_prompt = _("Repeat for confirmation")
159
+
160
+ confirmation_prompt = _build_prompt(confirmation_prompt, prompt_suffix)
161
+
162
+ while True:
163
+ while True:
164
+ value = prompt_func(prompt)
165
+ if value:
166
+ break
167
+ elif default is not None:
168
+ value = default
169
+ break
170
+ try:
171
+ result = value_proc(value)
172
+ except UsageError as e:
173
+ if hide_input:
174
+ echo(_("Error: The value you entered was invalid."), err=err)
175
+ else:
176
+ echo(_("Error: {e.message}").format(e=e), err=err)
177
+ continue
178
+ if not confirmation_prompt:
179
+ return result
180
+ while True:
181
+ value2 = prompt_func(confirmation_prompt)
182
+ is_empty = not value and not value2
183
+ if value2 or is_empty:
184
+ break
185
+ if value == value2:
186
+ return result
187
+ echo(_("Error: The two entered values do not match."), err=err)
188
+
189
+
190
+ def confirm(
191
+ text: str,
192
+ default: t.Optional[bool] = False,
193
+ abort: bool = False,
194
+ prompt_suffix: str = ": ",
195
+ show_default: bool = True,
196
+ err: bool = False,
197
+ ) -> bool:
198
+ """Prompts for confirmation (yes/no question).
199
+
200
+ If the user aborts the input by sending a interrupt signal this
201
+ function will catch it and raise a :exc:`Abort` exception.
202
+
203
+ :param text: the question to ask.
204
+ :param default: The default value to use when no input is given. If
205
+ ``None``, repeat until input is given.
206
+ :param abort: if this is set to `True` a negative answer aborts the
207
+ exception by raising :exc:`Abort`.
208
+ :param prompt_suffix: a suffix that should be added to the prompt.
209
+ :param show_default: shows or hides the default value in the prompt.
210
+ :param err: if set to true the file defaults to ``stderr`` instead of
211
+ ``stdout``, the same as with echo.
212
+
213
+ .. versionchanged:: 8.0
214
+ Repeat until input is given if ``default`` is ``None``.
215
+
216
+ .. versionadded:: 4.0
217
+ Added the ``err`` parameter.
218
+ """
219
+ prompt = _build_prompt(
220
+ text,
221
+ prompt_suffix,
222
+ show_default,
223
+ "y/n" if default is None else ("Y/n" if default else "y/N"),
224
+ )
225
+
226
+ while True:
227
+ try:
228
+ # Write the prompt separately so that we get nice
229
+ # coloring through colorama on Windows
230
+ echo(prompt.rstrip(" "), nl=False, err=err)
231
+ # Echo a space to stdout to work around an issue where
232
+ # readline causes backspace to clear the whole line.
233
+ value = visible_prompt_func(" ").lower().strip()
234
+ except (KeyboardInterrupt, EOFError):
235
+ raise Abort() from None
236
+ if value in ("y", "yes"):
237
+ rv = True
238
+ elif value in ("n", "no"):
239
+ rv = False
240
+ elif default is not None and value == "":
241
+ rv = default
242
+ else:
243
+ echo(_("Error: invalid input"), err=err)
244
+ continue
245
+ break
246
+ if abort and not rv:
247
+ raise Abort()
248
+ return rv
249
+
250
+
251
+ def echo_via_pager(
252
+ text_or_generator: t.Union[t.Iterable[str], t.Callable[[], t.Iterable[str]], str],
253
+ color: t.Optional[bool] = None,
254
+ ) -> None:
255
+ """This function takes a text and shows it via an environment specific
256
+ pager on stdout.
257
+
258
+ .. versionchanged:: 3.0
259
+ Added the `color` flag.
260
+
261
+ :param text_or_generator: the text to page, or alternatively, a
262
+ generator emitting the text to page.
263
+ :param color: controls if the pager supports ANSI colors or not. The
264
+ default is autodetection.
265
+ """
266
+ color = resolve_color_default(color)
267
+
268
+ if inspect.isgeneratorfunction(text_or_generator):
269
+ i = t.cast(t.Callable[[], t.Iterable[str]], text_or_generator)()
270
+ elif isinstance(text_or_generator, str):
271
+ i = [text_or_generator]
272
+ else:
273
+ i = iter(t.cast(t.Iterable[str], text_or_generator))
274
+
275
+ # convert every element of i to a text type if necessary
276
+ text_generator = (el if isinstance(el, str) else str(el) for el in i)
277
+
278
+ from ._termui_impl import pager
279
+
280
+ return pager(itertools.chain(text_generator, "\n"), color)
281
+
282
+
283
+ def progressbar(
284
+ iterable: t.Optional[t.Iterable[V]] = None,
285
+ length: t.Optional[int] = None,
286
+ label: t.Optional[str] = None,
287
+ show_eta: bool = True,
288
+ show_percent: t.Optional[bool] = None,
289
+ show_pos: bool = False,
290
+ item_show_func: t.Optional[t.Callable[[t.Optional[V]], t.Optional[str]]] = None,
291
+ fill_char: str = "#",
292
+ empty_char: str = "-",
293
+ bar_template: str = "%(label)s [%(bar)s] %(info)s",
294
+ info_sep: str = " ",
295
+ width: int = 36,
296
+ file: t.Optional[t.TextIO] = None,
297
+ color: t.Optional[bool] = None,
298
+ update_min_steps: int = 1,
299
+ ) -> "ProgressBar[V]":
300
+ """This function creates an iterable context manager that can be used
301
+ to iterate over something while showing a progress bar. It will
302
+ either iterate over the `iterable` or `length` items (that are counted
303
+ up). While iteration happens, this function will print a rendered
304
+ progress bar to the given `file` (defaults to stdout) and will attempt
305
+ to calculate remaining time and more. By default, this progress bar
306
+ will not be rendered if the file is not a terminal.
307
+
308
+ The context manager creates the progress bar. When the context
309
+ manager is entered the progress bar is already created. With every
310
+ iteration over the progress bar, the iterable passed to the bar is
311
+ advanced and the bar is updated. When the context manager exits,
312
+ a newline is printed and the progress bar is finalized on screen.
313
+
314
+ Note: The progress bar is currently designed for use cases where the
315
+ total progress can be expected to take at least several seconds.
316
+ Because of this, the ProgressBar class object won't display
317
+ progress that is considered too fast, and progress where the time
318
+ between steps is less than a second.
319
+
320
+ No printing must happen or the progress bar will be unintentionally
321
+ destroyed.
322
+
323
+ Example usage::
324
+
325
+ with progressbar(items) as bar:
326
+ for item in bar:
327
+ do_something_with(item)
328
+
329
+ Alternatively, if no iterable is specified, one can manually update the
330
+ progress bar through the `update()` method instead of directly
331
+ iterating over the progress bar. The update method accepts the number
332
+ of steps to increment the bar with::
333
+
334
+ with progressbar(length=chunks.total_bytes) as bar:
335
+ for chunk in chunks:
336
+ process_chunk(chunk)
337
+ bar.update(chunks.bytes)
338
+
339
+ The ``update()`` method also takes an optional value specifying the
340
+ ``current_item`` at the new position. This is useful when used
341
+ together with ``item_show_func`` to customize the output for each
342
+ manual step::
343
+
344
+ with click.progressbar(
345
+ length=total_size,
346
+ label='Unzipping archive',
347
+ item_show_func=lambda a: a.filename
348
+ ) as bar:
349
+ for archive in zip_file:
350
+ archive.extract()
351
+ bar.update(archive.size, archive)
352
+
353
+ :param iterable: an iterable to iterate over. If not provided the length
354
+ is required.
355
+ :param length: the number of items to iterate over. By default the
356
+ progressbar will attempt to ask the iterator about its
357
+ length, which might or might not work. If an iterable is
358
+ also provided this parameter can be used to override the
359
+ length. If an iterable is not provided the progress bar
360
+ will iterate over a range of that length.
361
+ :param label: the label to show next to the progress bar.
362
+ :param show_eta: enables or disables the estimated time display. This is
363
+ automatically disabled if the length cannot be
364
+ determined.
365
+ :param show_percent: enables or disables the percentage display. The
366
+ default is `True` if the iterable has a length or
367
+ `False` if not.
368
+ :param show_pos: enables or disables the absolute position display. The
369
+ default is `False`.
370
+ :param item_show_func: A function called with the current item which
371
+ can return a string to show next to the progress bar. If the
372
+ function returns ``None`` nothing is shown. The current item can
373
+ be ``None``, such as when entering and exiting the bar.
374
+ :param fill_char: the character to use to show the filled part of the
375
+ progress bar.
376
+ :param empty_char: the character to use to show the non-filled part of
377
+ the progress bar.
378
+ :param bar_template: the format string to use as template for the bar.
379
+ The parameters in it are ``label`` for the label,
380
+ ``bar`` for the progress bar and ``info`` for the
381
+ info section.
382
+ :param info_sep: the separator between multiple info items (eta etc.)
383
+ :param width: the width of the progress bar in characters, 0 means full
384
+ terminal width
385
+ :param file: The file to write to. If this is not a terminal then
386
+ only the label is printed.
387
+ :param color: controls if the terminal supports ANSI colors or not. The
388
+ default is autodetection. This is only needed if ANSI
389
+ codes are included anywhere in the progress bar output
390
+ which is not the case by default.
391
+ :param update_min_steps: Render only when this many updates have
392
+ completed. This allows tuning for very fast iterators.
393
+
394
+ .. versionchanged:: 8.0
395
+ Output is shown even if execution time is less than 0.5 seconds.
396
+
397
+ .. versionchanged:: 8.0
398
+ ``item_show_func`` shows the current item, not the previous one.
399
+
400
+ .. versionchanged:: 8.0
401
+ Labels are echoed if the output is not a TTY. Reverts a change
402
+ in 7.0 that removed all output.
403
+
404
+ .. versionadded:: 8.0
405
+ Added the ``update_min_steps`` parameter.
406
+
407
+ .. versionchanged:: 4.0
408
+ Added the ``color`` parameter. Added the ``update`` method to
409
+ the object.
410
+
411
+ .. versionadded:: 2.0
412
+ """
413
+ from ._termui_impl import ProgressBar
414
+
415
+ color = resolve_color_default(color)
416
+ return ProgressBar(
417
+ iterable=iterable,
418
+ length=length,
419
+ show_eta=show_eta,
420
+ show_percent=show_percent,
421
+ show_pos=show_pos,
422
+ item_show_func=item_show_func,
423
+ fill_char=fill_char,
424
+ empty_char=empty_char,
425
+ bar_template=bar_template,
426
+ info_sep=info_sep,
427
+ file=file,
428
+ label=label,
429
+ width=width,
430
+ color=color,
431
+ update_min_steps=update_min_steps,
432
+ )
433
+
434
+
435
+ def clear() -> None:
436
+ """Clears the terminal screen. This will have the effect of clearing
437
+ the whole visible space of the terminal and moving the cursor to the
438
+ top left. This does not do anything if not connected to a terminal.
439
+
440
+ .. versionadded:: 2.0
441
+ """
442
+ if not isatty(sys.stdout):
443
+ return
444
+
445
+ # ANSI escape \033[2J clears the screen, \033[1;1H moves the cursor
446
+ echo("\033[2J\033[1;1H", nl=False)
447
+
448
+
449
+ def _interpret_color(
450
+ color: t.Union[int, t.Tuple[int, int, int], str], offset: int = 0
451
+ ) -> str:
452
+ if isinstance(color, int):
453
+ return f"{38 + offset};5;{color:d}"
454
+
455
+ if isinstance(color, (tuple, list)):
456
+ r, g, b = color
457
+ return f"{38 + offset};2;{r:d};{g:d};{b:d}"
458
+
459
+ return str(_ansi_colors[color] + offset)
460
+
461
+
462
+ def style(
463
+ text: t.Any,
464
+ fg: t.Optional[t.Union[int, t.Tuple[int, int, int], str]] = None,
465
+ bg: t.Optional[t.Union[int, t.Tuple[int, int, int], str]] = None,
466
+ bold: t.Optional[bool] = None,
467
+ dim: t.Optional[bool] = None,
468
+ underline: t.Optional[bool] = None,
469
+ overline: t.Optional[bool] = None,
470
+ italic: t.Optional[bool] = None,
471
+ blink: t.Optional[bool] = None,
472
+ reverse: t.Optional[bool] = None,
473
+ strikethrough: t.Optional[bool] = None,
474
+ reset: bool = True,
475
+ ) -> str:
476
+ """Styles a text with ANSI styles and returns the new string. By
477
+ default the styling is self contained which means that at the end
478
+ of the string a reset code is issued. This can be prevented by
479
+ passing ``reset=False``.
480
+
481
+ Examples::
482
+
483
+ click.echo(click.style('Hello World!', fg='green'))
484
+ click.echo(click.style('ATTENTION!', blink=True))
485
+ click.echo(click.style('Some things', reverse=True, fg='cyan'))
486
+ click.echo(click.style('More colors', fg=(255, 12, 128), bg=117))
487
+
488
+ Supported color names:
489
+
490
+ * ``black`` (might be a gray)
491
+ * ``red``
492
+ * ``green``
493
+ * ``yellow`` (might be an orange)
494
+ * ``blue``
495
+ * ``magenta``
496
+ * ``cyan``
497
+ * ``white`` (might be light gray)
498
+ * ``bright_black``
499
+ * ``bright_red``
500
+ * ``bright_green``
501
+ * ``bright_yellow``
502
+ * ``bright_blue``
503
+ * ``bright_magenta``
504
+ * ``bright_cyan``
505
+ * ``bright_white``
506
+ * ``reset`` (reset the color code only)
507
+
508
+ If the terminal supports it, color may also be specified as:
509
+
510
+ - An integer in the interval [0, 255]. The terminal must support
511
+ 8-bit/256-color mode.
512
+ - An RGB tuple of three integers in [0, 255]. The terminal must
513
+ support 24-bit/true-color mode.
514
+
515
+ See https://en.wikipedia.org/wiki/ANSI_color and
516
+ https://gist.github.com/XVilka/8346728 for more information.
517
+
518
+ :param text: the string to style with ansi codes.
519
+ :param fg: if provided this will become the foreground color.
520
+ :param bg: if provided this will become the background color.
521
+ :param bold: if provided this will enable or disable bold mode.
522
+ :param dim: if provided this will enable or disable dim mode. This is
523
+ badly supported.
524
+ :param underline: if provided this will enable or disable underline.
525
+ :param overline: if provided this will enable or disable overline.
526
+ :param italic: if provided this will enable or disable italic.
527
+ :param blink: if provided this will enable or disable blinking.
528
+ :param reverse: if provided this will enable or disable inverse
529
+ rendering (foreground becomes background and the
530
+ other way round).
531
+ :param strikethrough: if provided this will enable or disable
532
+ striking through text.
533
+ :param reset: by default a reset-all code is added at the end of the
534
+ string which means that styles do not carry over. This
535
+ can be disabled to compose styles.
536
+
537
+ .. versionchanged:: 8.0
538
+ A non-string ``message`` is converted to a string.
539
+
540
+ .. versionchanged:: 8.0
541
+ Added support for 256 and RGB color codes.
542
+
543
+ .. versionchanged:: 8.0
544
+ Added the ``strikethrough``, ``italic``, and ``overline``
545
+ parameters.
546
+
547
+ .. versionchanged:: 7.0
548
+ Added support for bright colors.
549
+
550
+ .. versionadded:: 2.0
551
+ """
552
+ if not isinstance(text, str):
553
+ text = str(text)
554
+
555
+ bits = []
556
+
557
+ if fg:
558
+ try:
559
+ bits.append(f"\033[{_interpret_color(fg)}m")
560
+ except KeyError:
561
+ raise TypeError(f"Unknown color {fg!r}") from None
562
+
563
+ if bg:
564
+ try:
565
+ bits.append(f"\033[{_interpret_color(bg, 10)}m")
566
+ except KeyError:
567
+ raise TypeError(f"Unknown color {bg!r}") from None
568
+
569
+ if bold is not None:
570
+ bits.append(f"\033[{1 if bold else 22}m")
571
+ if dim is not None:
572
+ bits.append(f"\033[{2 if dim else 22}m")
573
+ if underline is not None:
574
+ bits.append(f"\033[{4 if underline else 24}m")
575
+ if overline is not None:
576
+ bits.append(f"\033[{53 if overline else 55}m")
577
+ if italic is not None:
578
+ bits.append(f"\033[{3 if italic else 23}m")
579
+ if blink is not None:
580
+ bits.append(f"\033[{5 if blink else 25}m")
581
+ if reverse is not None:
582
+ bits.append(f"\033[{7 if reverse else 27}m")
583
+ if strikethrough is not None:
584
+ bits.append(f"\033[{9 if strikethrough else 29}m")
585
+ bits.append(text)
586
+ if reset:
587
+ bits.append(_ansi_reset_all)
588
+ return "".join(bits)
589
+
590
+
591
+ def unstyle(text: str) -> str:
592
+ """Removes ANSI styling information from a string. Usually it's not
593
+ necessary to use this function as Click's echo function will
594
+ automatically remove styling if necessary.
595
+
596
+ .. versionadded:: 2.0
597
+
598
+ :param text: the text to remove style information from.
599
+ """
600
+ return strip_ansi(text)
601
+
602
+
603
+ def secho(
604
+ message: t.Optional[t.Any] = None,
605
+ file: t.Optional[t.IO[t.AnyStr]] = None,
606
+ nl: bool = True,
607
+ err: bool = False,
608
+ color: t.Optional[bool] = None,
609
+ **styles: t.Any,
610
+ ) -> None:
611
+ """This function combines :func:`echo` and :func:`style` into one
612
+ call. As such the following two calls are the same::
613
+
614
+ click.secho('Hello World!', fg='green')
615
+ click.echo(click.style('Hello World!', fg='green'))
616
+
617
+ All keyword arguments are forwarded to the underlying functions
618
+ depending on which one they go with.
619
+
620
+ Non-string types will be converted to :class:`str`. However,
621
+ :class:`bytes` are passed directly to :meth:`echo` without applying
622
+ style. If you want to style bytes that represent text, call
623
+ :meth:`bytes.decode` first.
624
+
625
+ .. versionchanged:: 8.0
626
+ A non-string ``message`` is converted to a string. Bytes are
627
+ passed through without style applied.
628
+
629
+ .. versionadded:: 2.0
630
+ """
631
+ if message is not None and not isinstance(message, (bytes, bytearray)):
632
+ message = style(message, **styles)
633
+
634
+ return echo(message, file=file, nl=nl, err=err, color=color)
635
+
636
+
637
+ def edit(
638
+ text: t.Optional[t.AnyStr] = None,
639
+ editor: t.Optional[str] = None,
640
+ env: t.Optional[t.Mapping[str, str]] = None,
641
+ require_save: bool = True,
642
+ extension: str = ".txt",
643
+ filename: t.Optional[str] = None,
644
+ ) -> t.Optional[t.AnyStr]:
645
+ r"""Edits the given text in the defined editor. If an editor is given
646
+ (should be the full path to the executable but the regular operating
647
+ system search path is used for finding the executable) it overrides
648
+ the detected editor. Optionally, some environment variables can be
649
+ used. If the editor is closed without changes, `None` is returned. In
650
+ case a file is edited directly the return value is always `None` and
651
+ `require_save` and `extension` are ignored.
652
+
653
+ If the editor cannot be opened a :exc:`UsageError` is raised.
654
+
655
+ Note for Windows: to simplify cross-platform usage, the newlines are
656
+ automatically converted from POSIX to Windows and vice versa. As such,
657
+ the message here will have ``\n`` as newline markers.
658
+
659
+ :param text: the text to edit.
660
+ :param editor: optionally the editor to use. Defaults to automatic
661
+ detection.
662
+ :param env: environment variables to forward to the editor.
663
+ :param require_save: if this is true, then not saving in the editor
664
+ will make the return value become `None`.
665
+ :param extension: the extension to tell the editor about. This defaults
666
+ to `.txt` but changing this might change syntax
667
+ highlighting.
668
+ :param filename: if provided it will edit this file instead of the
669
+ provided text contents. It will not use a temporary
670
+ file as an indirection in that case.
671
+ """
672
+ from ._termui_impl import Editor
673
+
674
+ ed = Editor(editor=editor, env=env, require_save=require_save, extension=extension)
675
+
676
+ if filename is None:
677
+ return ed.edit(text)
678
+
679
+ ed.edit_file(filename)
680
+ return None
681
+
682
+
683
+ def launch(url: str, wait: bool = False, locate: bool = False) -> int:
684
+ """This function launches the given URL (or filename) in the default
685
+ viewer application for this file type. If this is an executable, it
686
+ might launch the executable in a new session. The return value is
687
+ the exit code of the launched application. Usually, ``0`` indicates
688
+ success.
689
+
690
+ Examples::
691
+
692
+ click.launch('https://click.palletsprojects.com/')
693
+ click.launch('/my/downloaded/file', locate=True)
694
+
695
+ .. versionadded:: 2.0
696
+
697
+ :param url: URL or filename of the thing to launch.
698
+ :param wait: Wait for the program to exit before returning. This
699
+ only works if the launched program blocks. In particular,
700
+ ``xdg-open`` on Linux does not block.
701
+ :param locate: if this is set to `True` then instead of launching the
702
+ application associated with the URL it will attempt to
703
+ launch a file manager with the file located. This
704
+ might have weird effects if the URL does not point to
705
+ the filesystem.
706
+ """
707
+ from ._termui_impl import open_url
708
+
709
+ return open_url(url, wait=wait, locate=locate)
710
+
711
+
712
+ # If this is provided, getchar() calls into this instead. This is used
713
+ # for unittesting purposes.
714
+ _getchar: t.Optional[t.Callable[[bool], str]] = None
715
+
716
+
717
+ def getchar(echo: bool = False) -> str:
718
+ """Fetches a single character from the terminal and returns it. This
719
+ will always return a unicode character and under certain rare
720
+ circumstances this might return more than one character. The
721
+ situations which more than one character is returned is when for
722
+ whatever reason multiple characters end up in the terminal buffer or
723
+ standard input was not actually a terminal.
724
+
725
+ Note that this will always read from the terminal, even if something
726
+ is piped into the standard input.
727
+
728
+ Note for Windows: in rare cases when typing non-ASCII characters, this
729
+ function might wait for a second character and then return both at once.
730
+ This is because certain Unicode characters look like special-key markers.
731
+
732
+ .. versionadded:: 2.0
733
+
734
+ :param echo: if set to `True`, the character read will also show up on
735
+ the terminal. The default is to not show it.
736
+ """
737
+ global _getchar
738
+
739
+ if _getchar is None:
740
+ from ._termui_impl import getchar as f
741
+
742
+ _getchar = f
743
+
744
+ return _getchar(echo)
745
+
746
+
747
+ def raw_terminal() -> t.ContextManager[int]:
748
+ from ._termui_impl import raw_terminal as f
749
+
750
+ return f()
751
+
752
+
753
+ def pause(info: t.Optional[str] = None, err: bool = False) -> None:
754
+ """This command stops execution and waits for the user to press any
755
+ key to continue. This is similar to the Windows batch "pause"
756
+ command. If the program is not run through a terminal, this command
757
+ will instead do nothing.
758
+
759
+ .. versionadded:: 2.0
760
+
761
+ .. versionadded:: 4.0
762
+ Added the `err` parameter.
763
+
764
+ :param info: The message to print before pausing. Defaults to
765
+ ``"Press any key to continue..."``.
766
+ :param err: if set to message goes to ``stderr`` instead of
767
+ ``stdout``, the same as with echo.
768
+ """
769
+ if not isatty(sys.stdin) or not isatty(sys.stdout):
770
+ return
771
+
772
+ if info is None:
773
+ info = _("Press any key to continue...")
774
+
775
+ try:
776
+ if info:
777
+ echo(info, nl=False, err=err)
778
+ try:
779
+ getchar()
780
+ except (KeyboardInterrupt, EOFError):
781
+ pass
782
+ finally:
783
+ if info:
784
+ echo(err=err)
.venv/lib/python3.11/site-packages/click/utils.py ADDED
@@ -0,0 +1,624 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import sys
4
+ import typing as t
5
+ from functools import update_wrapper
6
+ from types import ModuleType
7
+ from types import TracebackType
8
+
9
+ from ._compat import _default_text_stderr
10
+ from ._compat import _default_text_stdout
11
+ from ._compat import _find_binary_writer
12
+ from ._compat import auto_wrap_for_ansi
13
+ from ._compat import binary_streams
14
+ from ._compat import open_stream
15
+ from ._compat import should_strip_ansi
16
+ from ._compat import strip_ansi
17
+ from ._compat import text_streams
18
+ from ._compat import WIN
19
+ from .globals import resolve_color_default
20
+
21
+ if t.TYPE_CHECKING:
22
+ import typing_extensions as te
23
+
24
+ P = te.ParamSpec("P")
25
+
26
+ R = t.TypeVar("R")
27
+
28
+
29
+ def _posixify(name: str) -> str:
30
+ return "-".join(name.split()).lower()
31
+
32
+
33
+ def safecall(func: "t.Callable[P, R]") -> "t.Callable[P, t.Optional[R]]":
34
+ """Wraps a function so that it swallows exceptions."""
35
+
36
+ def wrapper(*args: "P.args", **kwargs: "P.kwargs") -> t.Optional[R]:
37
+ try:
38
+ return func(*args, **kwargs)
39
+ except Exception:
40
+ pass
41
+ return None
42
+
43
+ return update_wrapper(wrapper, func)
44
+
45
+
46
+ def make_str(value: t.Any) -> str:
47
+ """Converts a value into a valid string."""
48
+ if isinstance(value, bytes):
49
+ try:
50
+ return value.decode(sys.getfilesystemencoding())
51
+ except UnicodeError:
52
+ return value.decode("utf-8", "replace")
53
+ return str(value)
54
+
55
+
56
+ def make_default_short_help(help: str, max_length: int = 45) -> str:
57
+ """Returns a condensed version of help string."""
58
+ # Consider only the first paragraph.
59
+ paragraph_end = help.find("\n\n")
60
+
61
+ if paragraph_end != -1:
62
+ help = help[:paragraph_end]
63
+
64
+ # Collapse newlines, tabs, and spaces.
65
+ words = help.split()
66
+
67
+ if not words:
68
+ return ""
69
+
70
+ # The first paragraph started with a "no rewrap" marker, ignore it.
71
+ if words[0] == "\b":
72
+ words = words[1:]
73
+
74
+ total_length = 0
75
+ last_index = len(words) - 1
76
+
77
+ for i, word in enumerate(words):
78
+ total_length += len(word) + (i > 0)
79
+
80
+ if total_length > max_length: # too long, truncate
81
+ break
82
+
83
+ if word[-1] == ".": # sentence end, truncate without "..."
84
+ return " ".join(words[: i + 1])
85
+
86
+ if total_length == max_length and i != last_index:
87
+ break # not at sentence end, truncate with "..."
88
+ else:
89
+ return " ".join(words) # no truncation needed
90
+
91
+ # Account for the length of the suffix.
92
+ total_length += len("...")
93
+
94
+ # remove words until the length is short enough
95
+ while i > 0:
96
+ total_length -= len(words[i]) + (i > 0)
97
+
98
+ if total_length <= max_length:
99
+ break
100
+
101
+ i -= 1
102
+
103
+ return " ".join(words[:i]) + "..."
104
+
105
+
106
+ class LazyFile:
107
+ """A lazy file works like a regular file but it does not fully open
108
+ the file but it does perform some basic checks early to see if the
109
+ filename parameter does make sense. This is useful for safely opening
110
+ files for writing.
111
+ """
112
+
113
+ def __init__(
114
+ self,
115
+ filename: t.Union[str, "os.PathLike[str]"],
116
+ mode: str = "r",
117
+ encoding: t.Optional[str] = None,
118
+ errors: t.Optional[str] = "strict",
119
+ atomic: bool = False,
120
+ ):
121
+ self.name: str = os.fspath(filename)
122
+ self.mode = mode
123
+ self.encoding = encoding
124
+ self.errors = errors
125
+ self.atomic = atomic
126
+ self._f: t.Optional[t.IO[t.Any]]
127
+ self.should_close: bool
128
+
129
+ if self.name == "-":
130
+ self._f, self.should_close = open_stream(filename, mode, encoding, errors)
131
+ else:
132
+ if "r" in mode:
133
+ # Open and close the file in case we're opening it for
134
+ # reading so that we can catch at least some errors in
135
+ # some cases early.
136
+ open(filename, mode).close()
137
+ self._f = None
138
+ self.should_close = True
139
+
140
+ def __getattr__(self, name: str) -> t.Any:
141
+ return getattr(self.open(), name)
142
+
143
+ def __repr__(self) -> str:
144
+ if self._f is not None:
145
+ return repr(self._f)
146
+ return f"<unopened file '{format_filename(self.name)}' {self.mode}>"
147
+
148
+ def open(self) -> t.IO[t.Any]:
149
+ """Opens the file if it's not yet open. This call might fail with
150
+ a :exc:`FileError`. Not handling this error will produce an error
151
+ that Click shows.
152
+ """
153
+ if self._f is not None:
154
+ return self._f
155
+ try:
156
+ rv, self.should_close = open_stream(
157
+ self.name, self.mode, self.encoding, self.errors, atomic=self.atomic
158
+ )
159
+ except OSError as e:
160
+ from .exceptions import FileError
161
+
162
+ raise FileError(self.name, hint=e.strerror) from e
163
+ self._f = rv
164
+ return rv
165
+
166
+ def close(self) -> None:
167
+ """Closes the underlying file, no matter what."""
168
+ if self._f is not None:
169
+ self._f.close()
170
+
171
+ def close_intelligently(self) -> None:
172
+ """This function only closes the file if it was opened by the lazy
173
+ file wrapper. For instance this will never close stdin.
174
+ """
175
+ if self.should_close:
176
+ self.close()
177
+
178
+ def __enter__(self) -> "LazyFile":
179
+ return self
180
+
181
+ def __exit__(
182
+ self,
183
+ exc_type: t.Optional[t.Type[BaseException]],
184
+ exc_value: t.Optional[BaseException],
185
+ tb: t.Optional[TracebackType],
186
+ ) -> None:
187
+ self.close_intelligently()
188
+
189
+ def __iter__(self) -> t.Iterator[t.AnyStr]:
190
+ self.open()
191
+ return iter(self._f) # type: ignore
192
+
193
+
194
+ class KeepOpenFile:
195
+ def __init__(self, file: t.IO[t.Any]) -> None:
196
+ self._file: t.IO[t.Any] = file
197
+
198
+ def __getattr__(self, name: str) -> t.Any:
199
+ return getattr(self._file, name)
200
+
201
+ def __enter__(self) -> "KeepOpenFile":
202
+ return self
203
+
204
+ def __exit__(
205
+ self,
206
+ exc_type: t.Optional[t.Type[BaseException]],
207
+ exc_value: t.Optional[BaseException],
208
+ tb: t.Optional[TracebackType],
209
+ ) -> None:
210
+ pass
211
+
212
+ def __repr__(self) -> str:
213
+ return repr(self._file)
214
+
215
+ def __iter__(self) -> t.Iterator[t.AnyStr]:
216
+ return iter(self._file)
217
+
218
+
219
+ def echo(
220
+ message: t.Optional[t.Any] = None,
221
+ file: t.Optional[t.IO[t.Any]] = None,
222
+ nl: bool = True,
223
+ err: bool = False,
224
+ color: t.Optional[bool] = None,
225
+ ) -> None:
226
+ """Print a message and newline to stdout or a file. This should be
227
+ used instead of :func:`print` because it provides better support
228
+ for different data, files, and environments.
229
+
230
+ Compared to :func:`print`, this does the following:
231
+
232
+ - Ensures that the output encoding is not misconfigured on Linux.
233
+ - Supports Unicode in the Windows console.
234
+ - Supports writing to binary outputs, and supports writing bytes
235
+ to text outputs.
236
+ - Supports colors and styles on Windows.
237
+ - Removes ANSI color and style codes if the output does not look
238
+ like an interactive terminal.
239
+ - Always flushes the output.
240
+
241
+ :param message: The string or bytes to output. Other objects are
242
+ converted to strings.
243
+ :param file: The file to write to. Defaults to ``stdout``.
244
+ :param err: Write to ``stderr`` instead of ``stdout``.
245
+ :param nl: Print a newline after the message. Enabled by default.
246
+ :param color: Force showing or hiding colors and other styles. By
247
+ default Click will remove color if the output does not look like
248
+ an interactive terminal.
249
+
250
+ .. versionchanged:: 6.0
251
+ Support Unicode output on the Windows console. Click does not
252
+ modify ``sys.stdout``, so ``sys.stdout.write()`` and ``print()``
253
+ will still not support Unicode.
254
+
255
+ .. versionchanged:: 4.0
256
+ Added the ``color`` parameter.
257
+
258
+ .. versionadded:: 3.0
259
+ Added the ``err`` parameter.
260
+
261
+ .. versionchanged:: 2.0
262
+ Support colors on Windows if colorama is installed.
263
+ """
264
+ if file is None:
265
+ if err:
266
+ file = _default_text_stderr()
267
+ else:
268
+ file = _default_text_stdout()
269
+
270
+ # There are no standard streams attached to write to. For example,
271
+ # pythonw on Windows.
272
+ if file is None:
273
+ return
274
+
275
+ # Convert non bytes/text into the native string type.
276
+ if message is not None and not isinstance(message, (str, bytes, bytearray)):
277
+ out: t.Optional[t.Union[str, bytes]] = str(message)
278
+ else:
279
+ out = message
280
+
281
+ if nl:
282
+ out = out or ""
283
+ if isinstance(out, str):
284
+ out += "\n"
285
+ else:
286
+ out += b"\n"
287
+
288
+ if not out:
289
+ file.flush()
290
+ return
291
+
292
+ # If there is a message and the value looks like bytes, we manually
293
+ # need to find the binary stream and write the message in there.
294
+ # This is done separately so that most stream types will work as you
295
+ # would expect. Eg: you can write to StringIO for other cases.
296
+ if isinstance(out, (bytes, bytearray)):
297
+ binary_file = _find_binary_writer(file)
298
+
299
+ if binary_file is not None:
300
+ file.flush()
301
+ binary_file.write(out)
302
+ binary_file.flush()
303
+ return
304
+
305
+ # ANSI style code support. For no message or bytes, nothing happens.
306
+ # When outputting to a file instead of a terminal, strip codes.
307
+ else:
308
+ color = resolve_color_default(color)
309
+
310
+ if should_strip_ansi(file, color):
311
+ out = strip_ansi(out)
312
+ elif WIN:
313
+ if auto_wrap_for_ansi is not None:
314
+ file = auto_wrap_for_ansi(file, color) # type: ignore
315
+ elif not color:
316
+ out = strip_ansi(out)
317
+
318
+ file.write(out) # type: ignore
319
+ file.flush()
320
+
321
+
322
+ def get_binary_stream(name: "te.Literal['stdin', 'stdout', 'stderr']") -> t.BinaryIO:
323
+ """Returns a system stream for byte processing.
324
+
325
+ :param name: the name of the stream to open. Valid names are ``'stdin'``,
326
+ ``'stdout'`` and ``'stderr'``
327
+ """
328
+ opener = binary_streams.get(name)
329
+ if opener is None:
330
+ raise TypeError(f"Unknown standard stream '{name}'")
331
+ return opener()
332
+
333
+
334
+ def get_text_stream(
335
+ name: "te.Literal['stdin', 'stdout', 'stderr']",
336
+ encoding: t.Optional[str] = None,
337
+ errors: t.Optional[str] = "strict",
338
+ ) -> t.TextIO:
339
+ """Returns a system stream for text processing. This usually returns
340
+ a wrapped stream around a binary stream returned from
341
+ :func:`get_binary_stream` but it also can take shortcuts for already
342
+ correctly configured streams.
343
+
344
+ :param name: the name of the stream to open. Valid names are ``'stdin'``,
345
+ ``'stdout'`` and ``'stderr'``
346
+ :param encoding: overrides the detected default encoding.
347
+ :param errors: overrides the default error mode.
348
+ """
349
+ opener = text_streams.get(name)
350
+ if opener is None:
351
+ raise TypeError(f"Unknown standard stream '{name}'")
352
+ return opener(encoding, errors)
353
+
354
+
355
+ def open_file(
356
+ filename: t.Union[str, "os.PathLike[str]"],
357
+ mode: str = "r",
358
+ encoding: t.Optional[str] = None,
359
+ errors: t.Optional[str] = "strict",
360
+ lazy: bool = False,
361
+ atomic: bool = False,
362
+ ) -> t.IO[t.Any]:
363
+ """Open a file, with extra behavior to handle ``'-'`` to indicate
364
+ a standard stream, lazy open on write, and atomic write. Similar to
365
+ the behavior of the :class:`~click.File` param type.
366
+
367
+ If ``'-'`` is given to open ``stdout`` or ``stdin``, the stream is
368
+ wrapped so that using it in a context manager will not close it.
369
+ This makes it possible to use the function without accidentally
370
+ closing a standard stream:
371
+
372
+ .. code-block:: python
373
+
374
+ with open_file(filename) as f:
375
+ ...
376
+
377
+ :param filename: The name or Path of the file to open, or ``'-'`` for
378
+ ``stdin``/``stdout``.
379
+ :param mode: The mode in which to open the file.
380
+ :param encoding: The encoding to decode or encode a file opened in
381
+ text mode.
382
+ :param errors: The error handling mode.
383
+ :param lazy: Wait to open the file until it is accessed. For read
384
+ mode, the file is temporarily opened to raise access errors
385
+ early, then closed until it is read again.
386
+ :param atomic: Write to a temporary file and replace the given file
387
+ on close.
388
+
389
+ .. versionadded:: 3.0
390
+ """
391
+ if lazy:
392
+ return t.cast(
393
+ t.IO[t.Any], LazyFile(filename, mode, encoding, errors, atomic=atomic)
394
+ )
395
+
396
+ f, should_close = open_stream(filename, mode, encoding, errors, atomic=atomic)
397
+
398
+ if not should_close:
399
+ f = t.cast(t.IO[t.Any], KeepOpenFile(f))
400
+
401
+ return f
402
+
403
+
404
+ def format_filename(
405
+ filename: "t.Union[str, bytes, os.PathLike[str], os.PathLike[bytes]]",
406
+ shorten: bool = False,
407
+ ) -> str:
408
+ """Format a filename as a string for display. Ensures the filename can be
409
+ displayed by replacing any invalid bytes or surrogate escapes in the name
410
+ with the replacement character ``�``.
411
+
412
+ Invalid bytes or surrogate escapes will raise an error when written to a
413
+ stream with ``errors="strict"``. This will typically happen with ``stdout``
414
+ when the locale is something like ``en_GB.UTF-8``.
415
+
416
+ Many scenarios *are* safe to write surrogates though, due to PEP 538 and
417
+ PEP 540, including:
418
+
419
+ - Writing to ``stderr``, which uses ``errors="backslashreplace"``.
420
+ - The system has ``LANG=C.UTF-8``, ``C``, or ``POSIX``. Python opens
421
+ stdout and stderr with ``errors="surrogateescape"``.
422
+ - None of ``LANG/LC_*`` are set. Python assumes ``LANG=C.UTF-8``.
423
+ - Python is started in UTF-8 mode with ``PYTHONUTF8=1`` or ``-X utf8``.
424
+ Python opens stdout and stderr with ``errors="surrogateescape"``.
425
+
426
+ :param filename: formats a filename for UI display. This will also convert
427
+ the filename into unicode without failing.
428
+ :param shorten: this optionally shortens the filename to strip of the
429
+ path that leads up to it.
430
+ """
431
+ if shorten:
432
+ filename = os.path.basename(filename)
433
+ else:
434
+ filename = os.fspath(filename)
435
+
436
+ if isinstance(filename, bytes):
437
+ filename = filename.decode(sys.getfilesystemencoding(), "replace")
438
+ else:
439
+ filename = filename.encode("utf-8", "surrogateescape").decode(
440
+ "utf-8", "replace"
441
+ )
442
+
443
+ return filename
444
+
445
+
446
+ def get_app_dir(app_name: str, roaming: bool = True, force_posix: bool = False) -> str:
447
+ r"""Returns the config folder for the application. The default behavior
448
+ is to return whatever is most appropriate for the operating system.
449
+
450
+ To give you an idea, for an app called ``"Foo Bar"``, something like
451
+ the following folders could be returned:
452
+
453
+ Mac OS X:
454
+ ``~/Library/Application Support/Foo Bar``
455
+ Mac OS X (POSIX):
456
+ ``~/.foo-bar``
457
+ Unix:
458
+ ``~/.config/foo-bar``
459
+ Unix (POSIX):
460
+ ``~/.foo-bar``
461
+ Windows (roaming):
462
+ ``C:\Users\<user>\AppData\Roaming\Foo Bar``
463
+ Windows (not roaming):
464
+ ``C:\Users\<user>\AppData\Local\Foo Bar``
465
+
466
+ .. versionadded:: 2.0
467
+
468
+ :param app_name: the application name. This should be properly capitalized
469
+ and can contain whitespace.
470
+ :param roaming: controls if the folder should be roaming or not on Windows.
471
+ Has no effect otherwise.
472
+ :param force_posix: if this is set to `True` then on any POSIX system the
473
+ folder will be stored in the home folder with a leading
474
+ dot instead of the XDG config home or darwin's
475
+ application support folder.
476
+ """
477
+ if WIN:
478
+ key = "APPDATA" if roaming else "LOCALAPPDATA"
479
+ folder = os.environ.get(key)
480
+ if folder is None:
481
+ folder = os.path.expanduser("~")
482
+ return os.path.join(folder, app_name)
483
+ if force_posix:
484
+ return os.path.join(os.path.expanduser(f"~/.{_posixify(app_name)}"))
485
+ if sys.platform == "darwin":
486
+ return os.path.join(
487
+ os.path.expanduser("~/Library/Application Support"), app_name
488
+ )
489
+ return os.path.join(
490
+ os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")),
491
+ _posixify(app_name),
492
+ )
493
+
494
+
495
+ class PacifyFlushWrapper:
496
+ """This wrapper is used to catch and suppress BrokenPipeErrors resulting
497
+ from ``.flush()`` being called on broken pipe during the shutdown/final-GC
498
+ of the Python interpreter. Notably ``.flush()`` is always called on
499
+ ``sys.stdout`` and ``sys.stderr``. So as to have minimal impact on any
500
+ other cleanup code, and the case where the underlying file is not a broken
501
+ pipe, all calls and attributes are proxied.
502
+ """
503
+
504
+ def __init__(self, wrapped: t.IO[t.Any]) -> None:
505
+ self.wrapped = wrapped
506
+
507
+ def flush(self) -> None:
508
+ try:
509
+ self.wrapped.flush()
510
+ except OSError as e:
511
+ import errno
512
+
513
+ if e.errno != errno.EPIPE:
514
+ raise
515
+
516
+ def __getattr__(self, attr: str) -> t.Any:
517
+ return getattr(self.wrapped, attr)
518
+
519
+
520
+ def _detect_program_name(
521
+ path: t.Optional[str] = None, _main: t.Optional[ModuleType] = None
522
+ ) -> str:
523
+ """Determine the command used to run the program, for use in help
524
+ text. If a file or entry point was executed, the file name is
525
+ returned. If ``python -m`` was used to execute a module or package,
526
+ ``python -m name`` is returned.
527
+
528
+ This doesn't try to be too precise, the goal is to give a concise
529
+ name for help text. Files are only shown as their name without the
530
+ path. ``python`` is only shown for modules, and the full path to
531
+ ``sys.executable`` is not shown.
532
+
533
+ :param path: The Python file being executed. Python puts this in
534
+ ``sys.argv[0]``, which is used by default.
535
+ :param _main: The ``__main__`` module. This should only be passed
536
+ during internal testing.
537
+
538
+ .. versionadded:: 8.0
539
+ Based on command args detection in the Werkzeug reloader.
540
+
541
+ :meta private:
542
+ """
543
+ if _main is None:
544
+ _main = sys.modules["__main__"]
545
+
546
+ if not path:
547
+ path = sys.argv[0]
548
+
549
+ # The value of __package__ indicates how Python was called. It may
550
+ # not exist if a setuptools script is installed as an egg. It may be
551
+ # set incorrectly for entry points created with pip on Windows.
552
+ # It is set to "" inside a Shiv or PEX zipapp.
553
+ if getattr(_main, "__package__", None) in {None, ""} or (
554
+ os.name == "nt"
555
+ and _main.__package__ == ""
556
+ and not os.path.exists(path)
557
+ and os.path.exists(f"{path}.exe")
558
+ ):
559
+ # Executed a file, like "python app.py".
560
+ return os.path.basename(path)
561
+
562
+ # Executed a module, like "python -m example".
563
+ # Rewritten by Python from "-m script" to "/path/to/script.py".
564
+ # Need to look at main module to determine how it was executed.
565
+ py_module = t.cast(str, _main.__package__)
566
+ name = os.path.splitext(os.path.basename(path))[0]
567
+
568
+ # A submodule like "example.cli".
569
+ if name != "__main__":
570
+ py_module = f"{py_module}.{name}"
571
+
572
+ return f"python -m {py_module.lstrip('.')}"
573
+
574
+
575
+ def _expand_args(
576
+ args: t.Iterable[str],
577
+ *,
578
+ user: bool = True,
579
+ env: bool = True,
580
+ glob_recursive: bool = True,
581
+ ) -> t.List[str]:
582
+ """Simulate Unix shell expansion with Python functions.
583
+
584
+ See :func:`glob.glob`, :func:`os.path.expanduser`, and
585
+ :func:`os.path.expandvars`.
586
+
587
+ This is intended for use on Windows, where the shell does not do any
588
+ expansion. It may not exactly match what a Unix shell would do.
589
+
590
+ :param args: List of command line arguments to expand.
591
+ :param user: Expand user home directory.
592
+ :param env: Expand environment variables.
593
+ :param glob_recursive: ``**`` matches directories recursively.
594
+
595
+ .. versionchanged:: 8.1
596
+ Invalid glob patterns are treated as empty expansions rather
597
+ than raising an error.
598
+
599
+ .. versionadded:: 8.0
600
+
601
+ :meta private:
602
+ """
603
+ from glob import glob
604
+
605
+ out = []
606
+
607
+ for arg in args:
608
+ if user:
609
+ arg = os.path.expanduser(arg)
610
+
611
+ if env:
612
+ arg = os.path.expandvars(arg)
613
+
614
+ try:
615
+ matches = glob(arg, recursive=glob_recursive)
616
+ except re.error:
617
+ matches = []
618
+
619
+ if not matches:
620
+ out.append(arg)
621
+ else:
622
+ out.extend(matches)
623
+
624
+ return out
.venv/lib/python3.11/site-packages/httplib2/__init__.py ADDED
@@ -0,0 +1,1799 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Small, fast HTTP client library for Python."""
3
+
4
+ __author__ = "Joe Gregorio (joe@bitworking.org)"
5
+ __copyright__ = "Copyright 2006, Joe Gregorio"
6
+ __contributors__ = [
7
+ "Thomas Broyer (t.broyer@ltgt.net)",
8
+ "James Antill",
9
+ "Xavier Verges Farrero",
10
+ "Jonathan Feinberg",
11
+ "Blair Zajac",
12
+ "Sam Ruby",
13
+ "Louis Nyffenegger",
14
+ "Mark Pilgrim",
15
+ "Alex Yu",
16
+ "Lai Han",
17
+ ]
18
+ __license__ = "MIT"
19
+ __version__ = "0.22.0"
20
+
21
+ import base64
22
+ import calendar
23
+ import copy
24
+ import email
25
+ import email.feedparser
26
+ from email import header
27
+ import email.message
28
+ import email.utils
29
+ import errno
30
+ from gettext import gettext as _
31
+ import gzip
32
+ from hashlib import md5 as _md5
33
+ from hashlib import sha1 as _sha
34
+ import hmac
35
+ import http.client
36
+ import io
37
+ import os
38
+ import random
39
+ import re
40
+ import socket
41
+ import ssl
42
+ import sys
43
+ import time
44
+ import urllib.parse
45
+ import zlib
46
+
47
+ try:
48
+ import socks
49
+ except ImportError:
50
+ # TODO: remove this fallback and copypasted socksipy module upon py2/3 merge,
51
+ # idea is to have soft-dependency on any compatible module called socks
52
+ from . import socks
53
+ from . import auth
54
+ from .error import *
55
+ from .iri2uri import iri2uri
56
+
57
+
58
+ def has_timeout(timeout):
59
+ if hasattr(socket, "_GLOBAL_DEFAULT_TIMEOUT"):
60
+ return timeout is not None and timeout is not socket._GLOBAL_DEFAULT_TIMEOUT
61
+ return timeout is not None
62
+
63
+
64
+ __all__ = [
65
+ "debuglevel",
66
+ "FailedToDecompressContent",
67
+ "Http",
68
+ "HttpLib2Error",
69
+ "ProxyInfo",
70
+ "RedirectLimit",
71
+ "RedirectMissingLocation",
72
+ "Response",
73
+ "RETRIES",
74
+ "UnimplementedDigestAuthOptionError",
75
+ "UnimplementedHmacDigestAuthOptionError",
76
+ ]
77
+
78
+ # The httplib debug level, set to a non-zero value to get debug output
79
+ debuglevel = 0
80
+
81
+ # A request will be tried 'RETRIES' times if it fails at the socket/connection level.
82
+ RETRIES = 2
83
+
84
+
85
+ # Open Items:
86
+ # -----------
87
+
88
+ # Are we removing the cached content too soon on PUT (only delete on 200 Maybe?)
89
+
90
+ # Pluggable cache storage (supports storing the cache in
91
+ # flat files by default. We need a plug-in architecture
92
+ # that can support Berkeley DB and Squid)
93
+
94
+ # == Known Issues ==
95
+ # Does not handle a resource that uses conneg and Last-Modified but no ETag as a cache validator.
96
+ # Does not handle Cache-Control: max-stale
97
+ # Does not use Age: headers when calculating cache freshness.
98
+
99
+ # The number of redirections to follow before giving up.
100
+ # Note that only GET redirects are automatically followed.
101
+ # Will also honor 301 requests by saving that info and never
102
+ # requesting that URI again.
103
+ DEFAULT_MAX_REDIRECTS = 5
104
+
105
+ # Which headers are hop-by-hop headers by default
106
+ HOP_BY_HOP = [
107
+ "connection",
108
+ "keep-alive",
109
+ "proxy-authenticate",
110
+ "proxy-authorization",
111
+ "te",
112
+ "trailers",
113
+ "transfer-encoding",
114
+ "upgrade",
115
+ ]
116
+
117
+ # https://tools.ietf.org/html/rfc7231#section-8.1.3
118
+ SAFE_METHODS = ("GET", "HEAD", "OPTIONS", "TRACE")
119
+
120
+ # To change, assign to `Http().redirect_codes`
121
+ REDIRECT_CODES = frozenset((300, 301, 302, 303, 307, 308))
122
+
123
+
124
+ from httplib2 import certs
125
+
126
+ CA_CERTS = certs.where()
127
+
128
+ # PROTOCOL_TLS is python 3.5.3+. PROTOCOL_SSLv23 is deprecated.
129
+ # Both PROTOCOL_TLS and PROTOCOL_SSLv23 are equivalent and means:
130
+ # > Selects the highest protocol version that both the client and server support.
131
+ # > Despite the name, this option can select “TLS” protocols as well as “SSL”.
132
+ # source: https://docs.python.org/3.5/library/ssl.html#ssl.PROTOCOL_SSLv23
133
+
134
+ # PROTOCOL_TLS_CLIENT is python 3.10.0+. PROTOCOL_TLS is deprecated.
135
+ # > Auto-negotiate the highest protocol version that both the client and server support, and configure the context client-side connections.
136
+ # > The protocol enables CERT_REQUIRED and check_hostname by default.
137
+ # source: https://docs.python.org/3.10/library/ssl.html#ssl.PROTOCOL_TLS
138
+
139
+ DEFAULT_TLS_VERSION = getattr(ssl, "PROTOCOL_TLS_CLIENT", None) or getattr(ssl, "PROTOCOL_TLS", None) or getattr(ssl, "PROTOCOL_SSLv23")
140
+
141
+
142
+ def _build_ssl_context(
143
+ disable_ssl_certificate_validation,
144
+ ca_certs,
145
+ cert_file=None,
146
+ key_file=None,
147
+ maximum_version=None,
148
+ minimum_version=None,
149
+ key_password=None,
150
+ ):
151
+ if not hasattr(ssl, "SSLContext"):
152
+ raise RuntimeError("httplib2 requires Python 3.2+ for ssl.SSLContext")
153
+
154
+ context = ssl.SSLContext(DEFAULT_TLS_VERSION)
155
+ # check_hostname and verify_mode should be set in opposite order during disable
156
+ # https://bugs.python.org/issue31431
157
+ if disable_ssl_certificate_validation and hasattr(context, "check_hostname"):
158
+ context.check_hostname = not disable_ssl_certificate_validation
159
+ context.verify_mode = ssl.CERT_NONE if disable_ssl_certificate_validation else ssl.CERT_REQUIRED
160
+
161
+ # SSLContext.maximum_version and SSLContext.minimum_version are python 3.7+.
162
+ # source: https://docs.python.org/3/library/ssl.html#ssl.SSLContext.maximum_version
163
+ if maximum_version is not None:
164
+ if hasattr(context, "maximum_version"):
165
+ if isinstance(maximum_version, str):
166
+ maximum_version = getattr(ssl.TLSVersion, maximum_version)
167
+ context.maximum_version = maximum_version
168
+ else:
169
+ raise RuntimeError("setting tls_maximum_version requires Python 3.7 and OpenSSL 1.1 or newer")
170
+ if minimum_version is not None:
171
+ if hasattr(context, "minimum_version"):
172
+ if isinstance(minimum_version, str):
173
+ minimum_version = getattr(ssl.TLSVersion, minimum_version)
174
+ context.minimum_version = minimum_version
175
+ else:
176
+ raise RuntimeError("setting tls_minimum_version requires Python 3.7 and OpenSSL 1.1 or newer")
177
+ # check_hostname requires python 3.4+
178
+ # we will perform the equivalent in HTTPSConnectionWithTimeout.connect() by calling ssl.match_hostname
179
+ # if check_hostname is not supported.
180
+ if hasattr(context, "check_hostname"):
181
+ context.check_hostname = not disable_ssl_certificate_validation
182
+
183
+ context.load_verify_locations(ca_certs)
184
+
185
+ if cert_file:
186
+ context.load_cert_chain(cert_file, key_file, key_password)
187
+
188
+ return context
189
+
190
+
191
+ def _get_end2end_headers(response):
192
+ hopbyhop = list(HOP_BY_HOP)
193
+ hopbyhop.extend([x.strip() for x in response.get("connection", "").split(",")])
194
+ return [header for header in list(response.keys()) if header not in hopbyhop]
195
+
196
+
197
+ _missing = object()
198
+
199
+
200
+ def _errno_from_exception(e):
201
+ # TODO python 3.11+ cheap try: return e.errno except AttributeError: pass
202
+ errno = getattr(e, "errno", _missing)
203
+ if errno is not _missing:
204
+ return errno
205
+
206
+ # socket.error and common wrap in .args
207
+ args = getattr(e, "args", None)
208
+ if args:
209
+ return _errno_from_exception(args[0])
210
+
211
+ # pysocks.ProxyError wraps in .socket_err
212
+ # https://github.com/httplib2/httplib2/pull/202
213
+ socket_err = getattr(e, "socket_err", None)
214
+ if socket_err:
215
+ return _errno_from_exception(socket_err)
216
+
217
+ return None
218
+
219
+
220
+ URI = re.compile(r"^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?")
221
+
222
+
223
+ def parse_uri(uri):
224
+ """Parses a URI using the regex given in Appendix B of RFC 3986.
225
+
226
+ (scheme, authority, path, query, fragment) = parse_uri(uri)
227
+ """
228
+ groups = URI.match(uri).groups()
229
+ return (groups[1], groups[3], groups[4], groups[6], groups[8])
230
+
231
+
232
+ def urlnorm(uri):
233
+ (scheme, authority, path, query, fragment) = parse_uri(uri)
234
+ if not scheme or not authority:
235
+ raise RelativeURIError("Only absolute URIs are allowed. uri = %s" % uri)
236
+ authority = authority.lower()
237
+ scheme = scheme.lower()
238
+ if not path:
239
+ path = "/"
240
+ # Could do syntax based normalization of the URI before
241
+ # computing the digest. See Section 6.2.2 of Std 66.
242
+ request_uri = query and "?".join([path, query]) or path
243
+ scheme = scheme.lower()
244
+ defrag_uri = scheme + "://" + authority + request_uri
245
+ return scheme, authority, request_uri, defrag_uri
246
+
247
+
248
+ # Cache filename construction (original borrowed from Venus http://intertwingly.net/code/venus/)
249
+ re_url_scheme = re.compile(r"^\w+://")
250
+ re_unsafe = re.compile(r"[^\w\-_.()=!]+", re.ASCII)
251
+
252
+
253
+ def safename(filename):
254
+ """Return a filename suitable for the cache.
255
+ Strips dangerous and common characters to create a filename we
256
+ can use to store the cache in.
257
+ """
258
+ if isinstance(filename, bytes):
259
+ filename_bytes = filename
260
+ filename = filename.decode("utf-8")
261
+ else:
262
+ filename_bytes = filename.encode("utf-8")
263
+ filemd5 = _md5(filename_bytes).hexdigest()
264
+ filename = re_url_scheme.sub("", filename)
265
+ filename = re_unsafe.sub("", filename)
266
+
267
+ # limit length of filename (vital for Windows)
268
+ # https://github.com/httplib2/httplib2/pull/74
269
+ # C:\Users\ <username> \AppData\Local\Temp\ <safe_filename> , <md5>
270
+ # 9 chars + max 104 chars + 20 chars + x + 1 + 32 = max 259 chars
271
+ # Thus max safe filename x = 93 chars. Let it be 90 to make a round sum:
272
+ filename = filename[:90]
273
+
274
+ return ",".join((filename, filemd5))
275
+
276
+
277
+ NORMALIZE_SPACE = re.compile(r"(?:\r\n)?[ \t]+")
278
+
279
+
280
+ def _normalize_headers(headers):
281
+ return dict(
282
+ [
283
+ (_convert_byte_str(key).lower(), NORMALIZE_SPACE.sub(_convert_byte_str(value), " ").strip(),)
284
+ for (key, value) in headers.items()
285
+ ]
286
+ )
287
+
288
+
289
+ def _convert_byte_str(s):
290
+ if not isinstance(s, str):
291
+ return str(s, "utf-8")
292
+ return s
293
+
294
+
295
+ def _parse_cache_control(headers):
296
+ retval = {}
297
+ if "cache-control" in headers:
298
+ parts = headers["cache-control"].split(",")
299
+ parts_with_args = [
300
+ tuple([x.strip().lower() for x in part.split("=", 1)]) for part in parts if -1 != part.find("=")
301
+ ]
302
+ parts_wo_args = [(name.strip().lower(), 1) for name in parts if -1 == name.find("=")]
303
+ retval = dict(parts_with_args + parts_wo_args)
304
+ return retval
305
+
306
+
307
+ # Whether to use a strict mode to parse WWW-Authenticate headers
308
+ # Might lead to bad results in case of ill-formed header value,
309
+ # so disabled by default, falling back to relaxed parsing.
310
+ # Set to true to turn on, useful for testing servers.
311
+ USE_WWW_AUTH_STRICT_PARSING = 0
312
+
313
+
314
+ def _entry_disposition(response_headers, request_headers):
315
+ """Determine freshness from the Date, Expires and Cache-Control headers.
316
+
317
+ We don't handle the following:
318
+
319
+ 1. Cache-Control: max-stale
320
+ 2. Age: headers are not used in the calculations.
321
+
322
+ Not that this algorithm is simpler than you might think
323
+ because we are operating as a private (non-shared) cache.
324
+ This lets us ignore 's-maxage'. We can also ignore
325
+ 'proxy-invalidate' since we aren't a proxy.
326
+ We will never return a stale document as
327
+ fresh as a design decision, and thus the non-implementation
328
+ of 'max-stale'. This also lets us safely ignore 'must-revalidate'
329
+ since we operate as if every server has sent 'must-revalidate'.
330
+ Since we are private we get to ignore both 'public' and
331
+ 'private' parameters. We also ignore 'no-transform' since
332
+ we don't do any transformations.
333
+ The 'no-store' parameter is handled at a higher level.
334
+ So the only Cache-Control parameters we look at are:
335
+
336
+ no-cache
337
+ only-if-cached
338
+ max-age
339
+ min-fresh
340
+ """
341
+
342
+ retval = "STALE"
343
+ cc = _parse_cache_control(request_headers)
344
+ cc_response = _parse_cache_control(response_headers)
345
+
346
+ if "pragma" in request_headers and request_headers["pragma"].lower().find("no-cache") != -1:
347
+ retval = "TRANSPARENT"
348
+ if "cache-control" not in request_headers:
349
+ request_headers["cache-control"] = "no-cache"
350
+ elif "no-cache" in cc:
351
+ retval = "TRANSPARENT"
352
+ elif "no-cache" in cc_response:
353
+ retval = "STALE"
354
+ elif "only-if-cached" in cc:
355
+ retval = "FRESH"
356
+ elif "date" in response_headers:
357
+ date = calendar.timegm(email.utils.parsedate_tz(response_headers["date"]))
358
+ now = time.time()
359
+ current_age = max(0, now - date)
360
+ if "max-age" in cc_response:
361
+ try:
362
+ freshness_lifetime = int(cc_response["max-age"])
363
+ except ValueError:
364
+ freshness_lifetime = 0
365
+ elif "expires" in response_headers:
366
+ expires = email.utils.parsedate_tz(response_headers["expires"])
367
+ if None == expires:
368
+ freshness_lifetime = 0
369
+ else:
370
+ freshness_lifetime = max(0, calendar.timegm(expires) - date)
371
+ else:
372
+ freshness_lifetime = 0
373
+ if "max-age" in cc:
374
+ try:
375
+ freshness_lifetime = int(cc["max-age"])
376
+ except ValueError:
377
+ freshness_lifetime = 0
378
+ if "min-fresh" in cc:
379
+ try:
380
+ min_fresh = int(cc["min-fresh"])
381
+ except ValueError:
382
+ min_fresh = 0
383
+ current_age += min_fresh
384
+ if freshness_lifetime > current_age:
385
+ retval = "FRESH"
386
+ return retval
387
+
388
+
389
+ def _decompressContent(response, new_content):
390
+ content = new_content
391
+ try:
392
+ encoding = response.get("content-encoding", None)
393
+ if encoding in ["gzip", "deflate"]:
394
+ if encoding == "gzip":
395
+ content = gzip.GzipFile(fileobj=io.BytesIO(new_content)).read()
396
+ if encoding == "deflate":
397
+ try:
398
+ content = zlib.decompress(content, zlib.MAX_WBITS)
399
+ except (IOError, zlib.error):
400
+ content = zlib.decompress(content, -zlib.MAX_WBITS)
401
+ response["content-length"] = str(len(content))
402
+ # Record the historical presence of the encoding in a way the won't interfere.
403
+ response["-content-encoding"] = response["content-encoding"]
404
+ del response["content-encoding"]
405
+ except (IOError, zlib.error):
406
+ content = ""
407
+ raise FailedToDecompressContent(
408
+ _("Content purported to be compressed with %s but failed to decompress.") % response.get("content-encoding"),
409
+ response,
410
+ content,
411
+ )
412
+ return content
413
+
414
+
415
+ def _bind_write_headers(msg):
416
+ def _write_headers(self):
417
+ # Self refers to the Generator object.
418
+ for h, v in msg.items():
419
+ print("%s:" % h, end=" ", file=self._fp)
420
+ if isinstance(v, header.Header):
421
+ print(v.encode(maxlinelen=self._maxheaderlen), file=self._fp)
422
+ else:
423
+ # email.Header got lots of smarts, so use it.
424
+ headers = header.Header(v, maxlinelen=self._maxheaderlen, charset="utf-8", header_name=h)
425
+ print(headers.encode(), file=self._fp)
426
+ # A blank line always separates headers from body.
427
+ print(file=self._fp)
428
+
429
+ return _write_headers
430
+
431
+
432
+ def _updateCache(request_headers, response_headers, content, cache, cachekey):
433
+ if cachekey:
434
+ cc = _parse_cache_control(request_headers)
435
+ cc_response = _parse_cache_control(response_headers)
436
+ if "no-store" in cc or "no-store" in cc_response:
437
+ cache.delete(cachekey)
438
+ else:
439
+ info = email.message.Message()
440
+ for key, value in response_headers.items():
441
+ if key not in ["status", "content-encoding", "transfer-encoding"]:
442
+ info[key] = value
443
+
444
+ # Add annotations to the cache to indicate what headers
445
+ # are variant for this request.
446
+ vary = response_headers.get("vary", None)
447
+ if vary:
448
+ vary_headers = vary.lower().replace(" ", "").split(",")
449
+ for header in vary_headers:
450
+ key = "-varied-%s" % header
451
+ try:
452
+ info[key] = request_headers[header]
453
+ except KeyError:
454
+ pass
455
+
456
+ status = response_headers.status
457
+ if status == 304:
458
+ status = 200
459
+
460
+ status_header = "status: %d\r\n" % status
461
+
462
+ try:
463
+ header_str = info.as_string()
464
+ except UnicodeEncodeError:
465
+ setattr(info, "_write_headers", _bind_write_headers(info))
466
+ header_str = info.as_string()
467
+
468
+ header_str = re.sub("\r(?!\n)|(?<!\r)\n", "\r\n", header_str)
469
+ text = b"".join([status_header.encode("utf-8"), header_str.encode("utf-8"), content])
470
+
471
+ cache.set(cachekey, text)
472
+
473
+
474
+ def _cnonce():
475
+ dig = _md5(
476
+ ("%s:%s" % (time.ctime(), ["0123456789"[random.randrange(0, 9)] for i in range(20)])).encode("utf-8")
477
+ ).hexdigest()
478
+ return dig[:16]
479
+
480
+
481
+ def _wsse_username_token(cnonce, iso_now, password):
482
+ return (
483
+ base64.b64encode(_sha(("%s%s%s" % (cnonce, iso_now, password)).encode("utf-8")).digest()).strip().decode("utf-8")
484
+ )
485
+
486
+
487
+ # For credentials we need two things, first
488
+ # a pool of credential to try (not necesarily tied to BAsic, Digest, etc.)
489
+ # Then we also need a list of URIs that have already demanded authentication
490
+ # That list is tricky since sub-URIs can take the same auth, or the
491
+ # auth scheme may change as you descend the tree.
492
+ # So we also need each Auth instance to be able to tell us
493
+ # how close to the 'top' it is.
494
+
495
+
496
+ class Authentication(object):
497
+ def __init__(self, credentials, host, request_uri, headers, response, content, http):
498
+ (scheme, authority, path, query, fragment) = parse_uri(request_uri)
499
+ self.path = path
500
+ self.host = host
501
+ self.credentials = credentials
502
+ self.http = http
503
+
504
+ def depth(self, request_uri):
505
+ (scheme, authority, path, query, fragment) = parse_uri(request_uri)
506
+ return request_uri[len(self.path) :].count("/")
507
+
508
+ def inscope(self, host, request_uri):
509
+ # XXX Should we normalize the request_uri?
510
+ (scheme, authority, path, query, fragment) = parse_uri(request_uri)
511
+ return (host == self.host) and path.startswith(self.path)
512
+
513
+ def request(self, method, request_uri, headers, content):
514
+ """Modify the request headers to add the appropriate
515
+ Authorization header. Over-rise this in sub-classes."""
516
+ pass
517
+
518
+ def response(self, response, content):
519
+ """Gives us a chance to update with new nonces
520
+ or such returned from the last authorized response.
521
+ Over-rise this in sub-classes if necessary.
522
+
523
+ Return TRUE is the request is to be retried, for
524
+ example Digest may return stale=true.
525
+ """
526
+ return False
527
+
528
+ def __eq__(self, auth):
529
+ return False
530
+
531
+ def __ne__(self, auth):
532
+ return True
533
+
534
+ def __lt__(self, auth):
535
+ return True
536
+
537
+ def __gt__(self, auth):
538
+ return False
539
+
540
+ def __le__(self, auth):
541
+ return True
542
+
543
+ def __ge__(self, auth):
544
+ return False
545
+
546
+ def __bool__(self):
547
+ return True
548
+
549
+
550
+ class BasicAuthentication(Authentication):
551
+ def __init__(self, credentials, host, request_uri, headers, response, content, http):
552
+ Authentication.__init__(self, credentials, host, request_uri, headers, response, content, http)
553
+
554
+ def request(self, method, request_uri, headers, content):
555
+ """Modify the request headers to add the appropriate
556
+ Authorization header."""
557
+ headers["authorization"] = "Basic " + base64.b64encode(
558
+ ("%s:%s" % self.credentials).encode("utf-8")
559
+ ).strip().decode("utf-8")
560
+
561
+
562
+ class DigestAuthentication(Authentication):
563
+ """Only do qop='auth' and MD5, since that
564
+ is all Apache currently implements"""
565
+
566
+ def __init__(self, credentials, host, request_uri, headers, response, content, http):
567
+ Authentication.__init__(self, credentials, host, request_uri, headers, response, content, http)
568
+ self.challenge = auth._parse_www_authenticate(response, "www-authenticate")["digest"]
569
+ qop = self.challenge.get("qop", "auth")
570
+ self.challenge["qop"] = ("auth" in [x.strip() for x in qop.split()]) and "auth" or None
571
+ if self.challenge["qop"] is None:
572
+ raise UnimplementedDigestAuthOptionError(_("Unsupported value for qop: %s." % qop))
573
+ self.challenge["algorithm"] = self.challenge.get("algorithm", "MD5").upper()
574
+ if self.challenge["algorithm"] != "MD5":
575
+ raise UnimplementedDigestAuthOptionError(
576
+ _("Unsupported value for algorithm: %s." % self.challenge["algorithm"])
577
+ )
578
+ self.A1 = "".join([self.credentials[0], ":", self.challenge["realm"], ":", self.credentials[1],])
579
+ self.challenge["nc"] = 1
580
+
581
+ def request(self, method, request_uri, headers, content, cnonce=None):
582
+ """Modify the request headers"""
583
+ H = lambda x: _md5(x.encode("utf-8")).hexdigest()
584
+ KD = lambda s, d: H("%s:%s" % (s, d))
585
+ A2 = "".join([method, ":", request_uri])
586
+ self.challenge["cnonce"] = cnonce or _cnonce()
587
+ request_digest = '"%s"' % KD(
588
+ H(self.A1),
589
+ "%s:%s:%s:%s:%s"
590
+ % (
591
+ self.challenge["nonce"],
592
+ "%08x" % self.challenge["nc"],
593
+ self.challenge["cnonce"],
594
+ self.challenge["qop"],
595
+ H(A2),
596
+ ),
597
+ )
598
+ headers["authorization"] = (
599
+ 'Digest username="%s", realm="%s", nonce="%s", '
600
+ 'uri="%s", algorithm=%s, response=%s, qop=%s, '
601
+ 'nc=%08x, cnonce="%s"'
602
+ ) % (
603
+ self.credentials[0],
604
+ self.challenge["realm"],
605
+ self.challenge["nonce"],
606
+ request_uri,
607
+ self.challenge["algorithm"],
608
+ request_digest,
609
+ self.challenge["qop"],
610
+ self.challenge["nc"],
611
+ self.challenge["cnonce"],
612
+ )
613
+ if self.challenge.get("opaque"):
614
+ headers["authorization"] += ', opaque="%s"' % self.challenge["opaque"]
615
+ self.challenge["nc"] += 1
616
+
617
+ def response(self, response, content):
618
+ if "authentication-info" not in response:
619
+ challenge = auth._parse_www_authenticate(response, "www-authenticate").get("digest", {})
620
+ if "true" == challenge.get("stale"):
621
+ self.challenge["nonce"] = challenge["nonce"]
622
+ self.challenge["nc"] = 1
623
+ return True
624
+ else:
625
+ updated_challenge = auth._parse_authentication_info(response, "authentication-info")
626
+
627
+ if "nextnonce" in updated_challenge:
628
+ self.challenge["nonce"] = updated_challenge["nextnonce"]
629
+ self.challenge["nc"] = 1
630
+ return False
631
+
632
+
633
+ class HmacDigestAuthentication(Authentication):
634
+ """Adapted from Robert Sayre's code and DigestAuthentication above."""
635
+
636
+ __author__ = "Thomas Broyer (t.broyer@ltgt.net)"
637
+
638
+ def __init__(self, credentials, host, request_uri, headers, response, content, http):
639
+ Authentication.__init__(self, credentials, host, request_uri, headers, response, content, http)
640
+ challenge = auth._parse_www_authenticate(response, "www-authenticate")
641
+ self.challenge = challenge["hmacdigest"]
642
+ # TODO: self.challenge['domain']
643
+ self.challenge["reason"] = self.challenge.get("reason", "unauthorized")
644
+ if self.challenge["reason"] not in ["unauthorized", "integrity"]:
645
+ self.challenge["reason"] = "unauthorized"
646
+ self.challenge["salt"] = self.challenge.get("salt", "")
647
+ if not self.challenge.get("snonce"):
648
+ raise UnimplementedHmacDigestAuthOptionError(
649
+ _("The challenge doesn't contain a server nonce, or this one is empty.")
650
+ )
651
+ self.challenge["algorithm"] = self.challenge.get("algorithm", "HMAC-SHA-1")
652
+ if self.challenge["algorithm"] not in ["HMAC-SHA-1", "HMAC-MD5"]:
653
+ raise UnimplementedHmacDigestAuthOptionError(
654
+ _("Unsupported value for algorithm: %s." % self.challenge["algorithm"])
655
+ )
656
+ self.challenge["pw-algorithm"] = self.challenge.get("pw-algorithm", "SHA-1")
657
+ if self.challenge["pw-algorithm"] not in ["SHA-1", "MD5"]:
658
+ raise UnimplementedHmacDigestAuthOptionError(
659
+ _("Unsupported value for pw-algorithm: %s." % self.challenge["pw-algorithm"])
660
+ )
661
+ if self.challenge["algorithm"] == "HMAC-MD5":
662
+ self.hashmod = _md5
663
+ else:
664
+ self.hashmod = _sha
665
+ if self.challenge["pw-algorithm"] == "MD5":
666
+ self.pwhashmod = _md5
667
+ else:
668
+ self.pwhashmod = _sha
669
+ self.key = "".join(
670
+ [
671
+ self.credentials[0],
672
+ ":",
673
+ self.pwhashmod.new("".join([self.credentials[1], self.challenge["salt"]])).hexdigest().lower(),
674
+ ":",
675
+ self.challenge["realm"],
676
+ ]
677
+ )
678
+ self.key = self.pwhashmod.new(self.key).hexdigest().lower()
679
+
680
+ def request(self, method, request_uri, headers, content):
681
+ """Modify the request headers"""
682
+ keys = _get_end2end_headers(headers)
683
+ keylist = "".join(["%s " % k for k in keys])
684
+ headers_val = "".join([headers[k] for k in keys])
685
+ created = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
686
+ cnonce = _cnonce()
687
+ request_digest = "%s:%s:%s:%s:%s" % (method, request_uri, cnonce, self.challenge["snonce"], headers_val,)
688
+ request_digest = hmac.new(self.key, request_digest, self.hashmod).hexdigest().lower()
689
+ headers["authorization"] = (
690
+ 'HMACDigest username="%s", realm="%s", snonce="%s",'
691
+ ' cnonce="%s", uri="%s", created="%s", '
692
+ 'response="%s", headers="%s"'
693
+ ) % (
694
+ self.credentials[0],
695
+ self.challenge["realm"],
696
+ self.challenge["snonce"],
697
+ cnonce,
698
+ request_uri,
699
+ created,
700
+ request_digest,
701
+ keylist,
702
+ )
703
+
704
+ def response(self, response, content):
705
+ challenge = auth._parse_www_authenticate(response, "www-authenticate").get("hmacdigest", {})
706
+ if challenge.get("reason") in ["integrity", "stale"]:
707
+ return True
708
+ return False
709
+
710
+
711
+ class WsseAuthentication(Authentication):
712
+ """This is thinly tested and should not be relied upon.
713
+ At this time there isn't any third party server to test against.
714
+ Blogger and TypePad implemented this algorithm at one point
715
+ but Blogger has since switched to Basic over HTTPS and
716
+ TypePad has implemented it wrong, by never issuing a 401
717
+ challenge but instead requiring your client to telepathically know that
718
+ their endpoint is expecting WSSE profile="UsernameToken"."""
719
+
720
+ def __init__(self, credentials, host, request_uri, headers, response, content, http):
721
+ Authentication.__init__(self, credentials, host, request_uri, headers, response, content, http)
722
+
723
+ def request(self, method, request_uri, headers, content):
724
+ """Modify the request headers to add the appropriate
725
+ Authorization header."""
726
+ headers["authorization"] = 'WSSE profile="UsernameToken"'
727
+ iso_now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
728
+ cnonce = _cnonce()
729
+ password_digest = _wsse_username_token(cnonce, iso_now, self.credentials[1])
730
+ headers["X-WSSE"] = ('UsernameToken Username="%s", PasswordDigest="%s", ' 'Nonce="%s", Created="%s"') % (
731
+ self.credentials[0],
732
+ password_digest,
733
+ cnonce,
734
+ iso_now,
735
+ )
736
+
737
+
738
+ class GoogleLoginAuthentication(Authentication):
739
+ def __init__(self, credentials, host, request_uri, headers, response, content, http):
740
+ from urllib.parse import urlencode
741
+
742
+ Authentication.__init__(self, credentials, host, request_uri, headers, response, content, http)
743
+ challenge = auth._parse_www_authenticate(response, "www-authenticate")
744
+ service = challenge["googlelogin"].get("service", "xapi")
745
+ # Bloggger actually returns the service in the challenge
746
+ # For the rest we guess based on the URI
747
+ if service == "xapi" and request_uri.find("calendar") > 0:
748
+ service = "cl"
749
+ # No point in guessing Base or Spreadsheet
750
+ # elif request_uri.find("spreadsheets") > 0:
751
+ # service = "wise"
752
+
753
+ auth = dict(Email=credentials[0], Passwd=credentials[1], service=service, source=headers["user-agent"],)
754
+ resp, content = self.http.request(
755
+ "https://www.google.com/accounts/ClientLogin",
756
+ method="POST",
757
+ body=urlencode(auth),
758
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
759
+ )
760
+ lines = content.split("\n")
761
+ d = dict([tuple(line.split("=", 1)) for line in lines if line])
762
+ if resp.status == 403:
763
+ self.Auth = ""
764
+ else:
765
+ self.Auth = d["Auth"]
766
+
767
+ def request(self, method, request_uri, headers, content):
768
+ """Modify the request headers to add the appropriate
769
+ Authorization header."""
770
+ headers["authorization"] = "GoogleLogin Auth=" + self.Auth
771
+
772
+
773
+ AUTH_SCHEME_CLASSES = {
774
+ "basic": BasicAuthentication,
775
+ "wsse": WsseAuthentication,
776
+ "digest": DigestAuthentication,
777
+ "hmacdigest": HmacDigestAuthentication,
778
+ "googlelogin": GoogleLoginAuthentication,
779
+ }
780
+
781
+ AUTH_SCHEME_ORDER = ["hmacdigest", "googlelogin", "digest", "wsse", "basic"]
782
+
783
+
784
+ class FileCache(object):
785
+ """Uses a local directory as a store for cached files.
786
+ Not really safe to use if multiple threads or processes are going to
787
+ be running on the same cache.
788
+ """
789
+
790
+ def __init__(self, cache, safe=safename): # use safe=lambda x: md5.new(x).hexdigest() for the old behavior
791
+ self.cache = cache
792
+ self.safe = safe
793
+ if not os.path.exists(cache):
794
+ os.makedirs(self.cache)
795
+
796
+ def get(self, key):
797
+ retval = None
798
+ cacheFullPath = os.path.join(self.cache, self.safe(key))
799
+ try:
800
+ f = open(cacheFullPath, "rb")
801
+ retval = f.read()
802
+ f.close()
803
+ except IOError:
804
+ pass
805
+ return retval
806
+
807
+ def set(self, key, value):
808
+ cacheFullPath = os.path.join(self.cache, self.safe(key))
809
+ f = open(cacheFullPath, "wb")
810
+ f.write(value)
811
+ f.close()
812
+
813
+ def delete(self, key):
814
+ cacheFullPath = os.path.join(self.cache, self.safe(key))
815
+ if os.path.exists(cacheFullPath):
816
+ os.remove(cacheFullPath)
817
+
818
+
819
+ class Credentials(object):
820
+ def __init__(self):
821
+ self.credentials = []
822
+
823
+ def add(self, name, password, domain=""):
824
+ self.credentials.append((domain.lower(), name, password))
825
+
826
+ def clear(self):
827
+ self.credentials = []
828
+
829
+ def iter(self, domain):
830
+ for (cdomain, name, password) in self.credentials:
831
+ if cdomain == "" or domain == cdomain:
832
+ yield (name, password)
833
+
834
+
835
+ class KeyCerts(Credentials):
836
+ """Identical to Credentials except that
837
+ name/password are mapped to key/cert."""
838
+
839
+ def add(self, key, cert, domain, password):
840
+ self.credentials.append((domain.lower(), key, cert, password))
841
+
842
+ def iter(self, domain):
843
+ for (cdomain, key, cert, password) in self.credentials:
844
+ if cdomain == "" or domain == cdomain:
845
+ yield (key, cert, password)
846
+
847
+
848
+ class AllHosts(object):
849
+ pass
850
+
851
+
852
+ class ProxyInfo(object):
853
+ """Collect information required to use a proxy."""
854
+
855
+ bypass_hosts = ()
856
+
857
+ def __init__(
858
+ self, proxy_type, proxy_host, proxy_port, proxy_rdns=True, proxy_user=None, proxy_pass=None, proxy_headers=None,
859
+ ):
860
+ """Args:
861
+
862
+ proxy_type: The type of proxy server. This must be set to one of
863
+ socks.PROXY_TYPE_XXX constants. For example: p =
864
+ ProxyInfo(proxy_type=socks.PROXY_TYPE_HTTP, proxy_host='localhost',
865
+ proxy_port=8000)
866
+ proxy_host: The hostname or IP address of the proxy server.
867
+ proxy_port: The port that the proxy server is running on.
868
+ proxy_rdns: If True (default), DNS queries will not be performed
869
+ locally, and instead, handed to the proxy to resolve. This is useful
870
+ if the network does not allow resolution of non-local names. In
871
+ httplib2 0.9 and earlier, this defaulted to False.
872
+ proxy_user: The username used to authenticate with the proxy server.
873
+ proxy_pass: The password used to authenticate with the proxy server.
874
+ proxy_headers: Additional or modified headers for the proxy connect
875
+ request.
876
+ """
877
+ if isinstance(proxy_user, bytes):
878
+ proxy_user = proxy_user.decode()
879
+ if isinstance(proxy_pass, bytes):
880
+ proxy_pass = proxy_pass.decode()
881
+ (
882
+ self.proxy_type,
883
+ self.proxy_host,
884
+ self.proxy_port,
885
+ self.proxy_rdns,
886
+ self.proxy_user,
887
+ self.proxy_pass,
888
+ self.proxy_headers,
889
+ ) = (
890
+ proxy_type,
891
+ proxy_host,
892
+ proxy_port,
893
+ proxy_rdns,
894
+ proxy_user,
895
+ proxy_pass,
896
+ proxy_headers,
897
+ )
898
+
899
+ def astuple(self):
900
+ return (
901
+ self.proxy_type,
902
+ self.proxy_host,
903
+ self.proxy_port,
904
+ self.proxy_rdns,
905
+ self.proxy_user,
906
+ self.proxy_pass,
907
+ self.proxy_headers,
908
+ )
909
+
910
+ def isgood(self):
911
+ return socks and (self.proxy_host != None) and (self.proxy_port != None)
912
+
913
+ def applies_to(self, hostname):
914
+ return not self.bypass_host(hostname)
915
+
916
+ def bypass_host(self, hostname):
917
+ """Has this host been excluded from the proxy config"""
918
+ if self.bypass_hosts is AllHosts:
919
+ return True
920
+
921
+ hostname = "." + hostname.lstrip(".")
922
+ for skip_name in self.bypass_hosts:
923
+ # *.suffix
924
+ if skip_name.startswith(".") and hostname.endswith(skip_name):
925
+ return True
926
+ # exact match
927
+ if hostname == "." + skip_name:
928
+ return True
929
+ return False
930
+
931
+ def __repr__(self):
932
+ return (
933
+ "<ProxyInfo type={p.proxy_type} "
934
+ "host:port={p.proxy_host}:{p.proxy_port} rdns={p.proxy_rdns}"
935
+ + " user={p.proxy_user} headers={p.proxy_headers}>"
936
+ ).format(p=self)
937
+
938
+
939
+ def proxy_info_from_environment(method="http"):
940
+ """Read proxy info from the environment variables.
941
+ """
942
+ if method not in ("http", "https"):
943
+ return
944
+
945
+ env_var = method + "_proxy"
946
+ url = os.environ.get(env_var, os.environ.get(env_var.upper()))
947
+ if not url:
948
+ return
949
+ return proxy_info_from_url(url, method, noproxy=None)
950
+
951
+
952
+ def proxy_info_from_url(url, method="http", noproxy=None):
953
+ """Construct a ProxyInfo from a URL (such as http_proxy env var)
954
+ """
955
+ url = urllib.parse.urlparse(url)
956
+
957
+ proxy_type = 3 # socks.PROXY_TYPE_HTTP
958
+ pi = ProxyInfo(
959
+ proxy_type=proxy_type,
960
+ proxy_host=url.hostname,
961
+ proxy_port=url.port or dict(https=443, http=80)[method],
962
+ proxy_user=url.username or None,
963
+ proxy_pass=url.password or None,
964
+ proxy_headers=None,
965
+ )
966
+
967
+ bypass_hosts = []
968
+ # If not given an explicit noproxy value, respect values in env vars.
969
+ if noproxy is None:
970
+ noproxy = os.environ.get("no_proxy", os.environ.get("NO_PROXY", ""))
971
+ # Special case: A single '*' character means all hosts should be bypassed.
972
+ if noproxy == "*":
973
+ bypass_hosts = AllHosts
974
+ elif noproxy.strip():
975
+ bypass_hosts = noproxy.split(",")
976
+ bypass_hosts = tuple(filter(bool, bypass_hosts)) # To exclude empty string.
977
+
978
+ pi.bypass_hosts = bypass_hosts
979
+ return pi
980
+
981
+
982
+ class HTTPConnectionWithTimeout(http.client.HTTPConnection):
983
+ """HTTPConnection subclass that supports timeouts
984
+
985
+ HTTPConnection subclass that supports timeouts
986
+
987
+ All timeouts are in seconds. If None is passed for timeout then
988
+ Python's default timeout for sockets will be used. See for example
989
+ the docs of socket.setdefaulttimeout():
990
+ http://docs.python.org/library/socket.html#socket.setdefaulttimeout
991
+ """
992
+
993
+ def __init__(self, host, port=None, timeout=None, proxy_info=None):
994
+ http.client.HTTPConnection.__init__(self, host, port=port, timeout=timeout)
995
+
996
+ self.proxy_info = proxy_info
997
+ if proxy_info and not isinstance(proxy_info, ProxyInfo):
998
+ self.proxy_info = proxy_info("http")
999
+
1000
+ def connect(self):
1001
+ """Connect to the host and port specified in __init__."""
1002
+ if self.proxy_info and socks is None:
1003
+ raise ProxiesUnavailableError("Proxy support missing but proxy use was requested!")
1004
+ if self.proxy_info and self.proxy_info.isgood() and self.proxy_info.applies_to(self.host):
1005
+ use_proxy = True
1006
+ (
1007
+ proxy_type,
1008
+ proxy_host,
1009
+ proxy_port,
1010
+ proxy_rdns,
1011
+ proxy_user,
1012
+ proxy_pass,
1013
+ proxy_headers,
1014
+ ) = self.proxy_info.astuple()
1015
+
1016
+ host = proxy_host
1017
+ port = proxy_port
1018
+ else:
1019
+ use_proxy = False
1020
+
1021
+ host = self.host
1022
+ port = self.port
1023
+ proxy_type = None
1024
+
1025
+ socket_err = None
1026
+
1027
+ for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
1028
+ af, socktype, proto, canonname, sa = res
1029
+ try:
1030
+ if use_proxy:
1031
+ self.sock = socks.socksocket(af, socktype, proto)
1032
+ self.sock.setproxy(
1033
+ proxy_type, proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass,
1034
+ )
1035
+ else:
1036
+ self.sock = socket.socket(af, socktype, proto)
1037
+ self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
1038
+ if has_timeout(self.timeout):
1039
+ self.sock.settimeout(self.timeout)
1040
+ if self.debuglevel > 0:
1041
+ print("connect: ({0}, {1}) ************".format(self.host, self.port))
1042
+ if use_proxy:
1043
+ print(
1044
+ "proxy: {0} ************".format(
1045
+ str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers,))
1046
+ )
1047
+ )
1048
+
1049
+ self.sock.connect((self.host, self.port) + sa[2:])
1050
+ except socket.error as e:
1051
+ socket_err = e
1052
+ if self.debuglevel > 0:
1053
+ print("connect fail: ({0}, {1})".format(self.host, self.port))
1054
+ if use_proxy:
1055
+ print(
1056
+ "proxy: {0}".format(
1057
+ str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers,))
1058
+ )
1059
+ )
1060
+ if self.sock:
1061
+ self.sock.close()
1062
+ self.sock = None
1063
+ continue
1064
+ break
1065
+ if not self.sock:
1066
+ raise socket_err
1067
+
1068
+
1069
+ class HTTPSConnectionWithTimeout(http.client.HTTPSConnection):
1070
+ """This class allows communication via SSL.
1071
+
1072
+ All timeouts are in seconds. If None is passed for timeout then
1073
+ Python's default timeout for sockets will be used. See for example
1074
+ the docs of socket.setdefaulttimeout():
1075
+ http://docs.python.org/library/socket.html#socket.setdefaulttimeout
1076
+ """
1077
+
1078
+ def __init__(
1079
+ self,
1080
+ host,
1081
+ port=None,
1082
+ key_file=None,
1083
+ cert_file=None,
1084
+ timeout=None,
1085
+ proxy_info=None,
1086
+ ca_certs=None,
1087
+ disable_ssl_certificate_validation=False,
1088
+ tls_maximum_version=None,
1089
+ tls_minimum_version=None,
1090
+ key_password=None,
1091
+ ):
1092
+
1093
+ self.disable_ssl_certificate_validation = disable_ssl_certificate_validation
1094
+ self.ca_certs = ca_certs if ca_certs else CA_CERTS
1095
+
1096
+ self.proxy_info = proxy_info
1097
+ if proxy_info and not isinstance(proxy_info, ProxyInfo):
1098
+ self.proxy_info = proxy_info("https")
1099
+
1100
+ context = _build_ssl_context(
1101
+ self.disable_ssl_certificate_validation,
1102
+ self.ca_certs,
1103
+ cert_file,
1104
+ key_file,
1105
+ maximum_version=tls_maximum_version,
1106
+ minimum_version=tls_minimum_version,
1107
+ key_password=key_password,
1108
+ )
1109
+ super(HTTPSConnectionWithTimeout, self).__init__(
1110
+ host, port=port, timeout=timeout, context=context,
1111
+ )
1112
+ self.key_file = key_file
1113
+ self.cert_file = cert_file
1114
+ self.key_password = key_password
1115
+
1116
+ def connect(self):
1117
+ """Connect to a host on a given (SSL) port."""
1118
+ if self.proxy_info and self.proxy_info.isgood() and self.proxy_info.applies_to(self.host):
1119
+ use_proxy = True
1120
+ (
1121
+ proxy_type,
1122
+ proxy_host,
1123
+ proxy_port,
1124
+ proxy_rdns,
1125
+ proxy_user,
1126
+ proxy_pass,
1127
+ proxy_headers,
1128
+ ) = self.proxy_info.astuple()
1129
+
1130
+ host = proxy_host
1131
+ port = proxy_port
1132
+ else:
1133
+ use_proxy = False
1134
+
1135
+ host = self.host
1136
+ port = self.port
1137
+ proxy_type = None
1138
+ proxy_headers = None
1139
+
1140
+ socket_err = None
1141
+
1142
+ address_info = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
1143
+ for family, socktype, proto, canonname, sockaddr in address_info:
1144
+ try:
1145
+ if use_proxy:
1146
+ sock = socks.socksocket(family, socktype, proto)
1147
+
1148
+ sock.setproxy(
1149
+ proxy_type, proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass,
1150
+ )
1151
+ else:
1152
+ sock = socket.socket(family, socktype, proto)
1153
+ sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
1154
+ if has_timeout(self.timeout):
1155
+ sock.settimeout(self.timeout)
1156
+ sock.connect((self.host, self.port))
1157
+
1158
+ self.sock = self._context.wrap_socket(sock, server_hostname=self.host)
1159
+
1160
+ # Python 3.3 compatibility: emulate the check_hostname behavior
1161
+ if not hasattr(self._context, "check_hostname") and not self.disable_ssl_certificate_validation:
1162
+ try:
1163
+ ssl.match_hostname(self.sock.getpeercert(), self.host)
1164
+ except Exception:
1165
+ self.sock.shutdown(socket.SHUT_RDWR)
1166
+ self.sock.close()
1167
+ raise
1168
+
1169
+ if self.debuglevel > 0:
1170
+ print("connect: ({0}, {1})".format(self.host, self.port))
1171
+ if use_proxy:
1172
+ print(
1173
+ "proxy: {0}".format(
1174
+ str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers,))
1175
+ )
1176
+ )
1177
+ except (ssl.SSLError, ssl.CertificateError) as e:
1178
+ if sock:
1179
+ sock.close()
1180
+ if self.sock:
1181
+ self.sock.close()
1182
+ self.sock = None
1183
+ raise
1184
+ except (socket.timeout, socket.gaierror):
1185
+ raise
1186
+ except socket.error as e:
1187
+ socket_err = e
1188
+ if self.debuglevel > 0:
1189
+ print("connect fail: ({0}, {1})".format(self.host, self.port))
1190
+ if use_proxy:
1191
+ print(
1192
+ "proxy: {0}".format(
1193
+ str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers,))
1194
+ )
1195
+ )
1196
+ if self.sock:
1197
+ self.sock.close()
1198
+ self.sock = None
1199
+ continue
1200
+ break
1201
+ if not self.sock:
1202
+ raise socket_err
1203
+
1204
+
1205
+ SCHEME_TO_CONNECTION = {
1206
+ "http": HTTPConnectionWithTimeout,
1207
+ "https": HTTPSConnectionWithTimeout,
1208
+ }
1209
+
1210
+
1211
+ class Http(object):
1212
+ """An HTTP client that handles:
1213
+
1214
+ - all methods
1215
+ - caching
1216
+ - ETags
1217
+ - compression,
1218
+ - HTTPS
1219
+ - Basic
1220
+ - Digest
1221
+ - WSSE
1222
+
1223
+ and more.
1224
+ """
1225
+
1226
+ def __init__(
1227
+ self,
1228
+ cache=None,
1229
+ timeout=None,
1230
+ proxy_info=proxy_info_from_environment,
1231
+ ca_certs=None,
1232
+ disable_ssl_certificate_validation=False,
1233
+ tls_maximum_version=None,
1234
+ tls_minimum_version=None,
1235
+ ):
1236
+ """If 'cache' is a string then it is used as a directory name for
1237
+ a disk cache. Otherwise it must be an object that supports the
1238
+ same interface as FileCache.
1239
+
1240
+ All timeouts are in seconds. If None is passed for timeout
1241
+ then Python's default timeout for sockets will be used. See
1242
+ for example the docs of socket.setdefaulttimeout():
1243
+ http://docs.python.org/library/socket.html#socket.setdefaulttimeout
1244
+
1245
+ `proxy_info` may be:
1246
+ - a callable that takes the http scheme ('http' or 'https') and
1247
+ returns a ProxyInfo instance per request. By default, uses
1248
+ proxy_info_from_environment.
1249
+ - a ProxyInfo instance (static proxy config).
1250
+ - None (proxy disabled).
1251
+
1252
+ ca_certs is the path of a file containing root CA certificates for SSL
1253
+ server certificate validation. By default, a CA cert file bundled with
1254
+ httplib2 is used.
1255
+
1256
+ If disable_ssl_certificate_validation is true, SSL cert validation will
1257
+ not be performed.
1258
+
1259
+ tls_maximum_version / tls_minimum_version require Python 3.7+ /
1260
+ OpenSSL 1.1.0g+. A value of "TLSv1_3" requires OpenSSL 1.1.1+.
1261
+ """
1262
+ self.proxy_info = proxy_info
1263
+ self.ca_certs = ca_certs
1264
+ self.disable_ssl_certificate_validation = disable_ssl_certificate_validation
1265
+ self.tls_maximum_version = tls_maximum_version
1266
+ self.tls_minimum_version = tls_minimum_version
1267
+ # Map domain name to an httplib connection
1268
+ self.connections = {}
1269
+ # The location of the cache, for now a directory
1270
+ # where cached responses are held.
1271
+ if cache and isinstance(cache, str):
1272
+ self.cache = FileCache(cache)
1273
+ else:
1274
+ self.cache = cache
1275
+
1276
+ # Name/password
1277
+ self.credentials = Credentials()
1278
+
1279
+ # Key/cert
1280
+ self.certificates = KeyCerts()
1281
+
1282
+ # authorization objects
1283
+ self.authorizations = []
1284
+
1285
+ # If set to False then no redirects are followed, even safe ones.
1286
+ self.follow_redirects = True
1287
+
1288
+ self.redirect_codes = REDIRECT_CODES
1289
+
1290
+ # Which HTTP methods do we apply optimistic concurrency to, i.e.
1291
+ # which methods get an "if-match:" etag header added to them.
1292
+ self.optimistic_concurrency_methods = ["PUT", "PATCH"]
1293
+
1294
+ self.safe_methods = list(SAFE_METHODS)
1295
+
1296
+ # If 'follow_redirects' is True, and this is set to True then
1297
+ # all redirecs are followed, including unsafe ones.
1298
+ self.follow_all_redirects = False
1299
+
1300
+ self.ignore_etag = False
1301
+
1302
+ self.force_exception_to_status_code = False
1303
+
1304
+ self.timeout = timeout
1305
+
1306
+ # Keep Authorization: headers on a redirect.
1307
+ self.forward_authorization_headers = False
1308
+
1309
+ def close(self):
1310
+ """Close persistent connections, clear sensitive data.
1311
+ Not thread-safe, requires external synchronization against concurrent requests.
1312
+ """
1313
+ existing, self.connections = self.connections, {}
1314
+ for _, c in existing.items():
1315
+ c.close()
1316
+ self.certificates.clear()
1317
+ self.clear_credentials()
1318
+
1319
+ def __getstate__(self):
1320
+ state_dict = copy.copy(self.__dict__)
1321
+ # In case request is augmented by some foreign object such as
1322
+ # credentials which handle auth
1323
+ if "request" in state_dict:
1324
+ del state_dict["request"]
1325
+ if "connections" in state_dict:
1326
+ del state_dict["connections"]
1327
+ return state_dict
1328
+
1329
+ def __setstate__(self, state):
1330
+ self.__dict__.update(state)
1331
+ self.connections = {}
1332
+
1333
+ def _auth_from_challenge(self, host, request_uri, headers, response, content):
1334
+ """A generator that creates Authorization objects
1335
+ that can be applied to requests.
1336
+ """
1337
+ challenges = auth._parse_www_authenticate(response, "www-authenticate")
1338
+ for cred in self.credentials.iter(host):
1339
+ for scheme in AUTH_SCHEME_ORDER:
1340
+ if scheme in challenges:
1341
+ yield AUTH_SCHEME_CLASSES[scheme](cred, host, request_uri, headers, response, content, self)
1342
+
1343
+ def add_credentials(self, name, password, domain=""):
1344
+ """Add a name and password that will be used
1345
+ any time a request requires authentication."""
1346
+ self.credentials.add(name, password, domain)
1347
+
1348
+ def add_certificate(self, key, cert, domain, password=None):
1349
+ """Add a key and cert that will be used
1350
+ any time a request requires authentication."""
1351
+ self.certificates.add(key, cert, domain, password)
1352
+
1353
+ def clear_credentials(self):
1354
+ """Remove all the names and passwords
1355
+ that are used for authentication"""
1356
+ self.credentials.clear()
1357
+ self.authorizations = []
1358
+
1359
+ def _conn_request(self, conn, request_uri, method, body, headers):
1360
+ i = 0
1361
+ seen_bad_status_line = False
1362
+ while i < RETRIES:
1363
+ i += 1
1364
+ try:
1365
+ if conn.sock is None:
1366
+ conn.connect()
1367
+ conn.request(method, request_uri, body, headers)
1368
+ except socket.timeout:
1369
+ conn.close()
1370
+ raise
1371
+ except socket.gaierror:
1372
+ conn.close()
1373
+ raise ServerNotFoundError("Unable to find the server at %s" % conn.host)
1374
+ except socket.error as e:
1375
+ errno_ = _errno_from_exception(e)
1376
+ if errno_ in (errno.ENETUNREACH, errno.EADDRNOTAVAIL) and i < RETRIES:
1377
+ continue # retry on potentially transient errors
1378
+ raise
1379
+ except http.client.HTTPException:
1380
+ if conn.sock is None:
1381
+ if i < RETRIES - 1:
1382
+ conn.close()
1383
+ conn.connect()
1384
+ continue
1385
+ else:
1386
+ conn.close()
1387
+ raise
1388
+ if i < RETRIES - 1:
1389
+ conn.close()
1390
+ conn.connect()
1391
+ continue
1392
+ # Just because the server closed the connection doesn't apparently mean
1393
+ # that the server didn't send a response.
1394
+ pass
1395
+ try:
1396
+ response = conn.getresponse()
1397
+ except (http.client.BadStatusLine, http.client.ResponseNotReady):
1398
+ # If we get a BadStatusLine on the first try then that means
1399
+ # the connection just went stale, so retry regardless of the
1400
+ # number of RETRIES set.
1401
+ if not seen_bad_status_line and i == 1:
1402
+ i = 0
1403
+ seen_bad_status_line = True
1404
+ conn.close()
1405
+ conn.connect()
1406
+ continue
1407
+ else:
1408
+ conn.close()
1409
+ raise
1410
+ except socket.timeout:
1411
+ raise
1412
+ except (socket.error, http.client.HTTPException):
1413
+ conn.close()
1414
+ if i == 0:
1415
+ conn.close()
1416
+ conn.connect()
1417
+ continue
1418
+ else:
1419
+ raise
1420
+ else:
1421
+ content = b""
1422
+ if method == "HEAD":
1423
+ conn.close()
1424
+ else:
1425
+ content = response.read()
1426
+ response = Response(response)
1427
+ if method != "HEAD":
1428
+ content = _decompressContent(response, content)
1429
+
1430
+ break
1431
+ return (response, content)
1432
+
1433
+ def _request(
1434
+ self, conn, host, absolute_uri, request_uri, method, body, headers, redirections, cachekey,
1435
+ ):
1436
+ """Do the actual request using the connection object
1437
+ and also follow one level of redirects if necessary"""
1438
+
1439
+ auths = [(auth.depth(request_uri), auth) for auth in self.authorizations if auth.inscope(host, request_uri)]
1440
+ auth = auths and sorted(auths)[0][1] or None
1441
+ if auth:
1442
+ auth.request(method, request_uri, headers, body)
1443
+
1444
+ (response, content) = self._conn_request(conn, request_uri, method, body, headers)
1445
+
1446
+ if auth:
1447
+ if auth.response(response, body):
1448
+ auth.request(method, request_uri, headers, body)
1449
+ (response, content) = self._conn_request(conn, request_uri, method, body, headers)
1450
+ response._stale_digest = 1
1451
+
1452
+ if response.status == 401:
1453
+ for authorization in self._auth_from_challenge(host, request_uri, headers, response, content):
1454
+ authorization.request(method, request_uri, headers, body)
1455
+ (response, content) = self._conn_request(conn, request_uri, method, body, headers)
1456
+ if response.status != 401:
1457
+ self.authorizations.append(authorization)
1458
+ authorization.response(response, body)
1459
+ break
1460
+
1461
+ if self.follow_all_redirects or method in self.safe_methods or response.status in (303, 308):
1462
+ if self.follow_redirects and response.status in self.redirect_codes:
1463
+ # Pick out the location header and basically start from the beginning
1464
+ # remembering first to strip the ETag header and decrement our 'depth'
1465
+ if redirections:
1466
+ if "location" not in response and response.status != 300:
1467
+ raise RedirectMissingLocation(
1468
+ _("Redirected but the response is missing a Location: header."), response, content,
1469
+ )
1470
+ # Fix-up relative redirects (which violate an RFC 2616 MUST)
1471
+ if "location" in response:
1472
+ location = response["location"]
1473
+ (scheme, authority, path, query, fragment) = parse_uri(location)
1474
+ if authority == None:
1475
+ response["location"] = urllib.parse.urljoin(absolute_uri, location)
1476
+ if response.status == 308 or (response.status == 301 and (method in self.safe_methods)):
1477
+ response["-x-permanent-redirect-url"] = response["location"]
1478
+ if "content-location" not in response:
1479
+ response["content-location"] = absolute_uri
1480
+ _updateCache(headers, response, content, self.cache, cachekey)
1481
+ if "if-none-match" in headers:
1482
+ del headers["if-none-match"]
1483
+ if "if-modified-since" in headers:
1484
+ del headers["if-modified-since"]
1485
+ if "authorization" in headers and not self.forward_authorization_headers:
1486
+ del headers["authorization"]
1487
+ if "location" in response:
1488
+ location = response["location"]
1489
+ old_response = copy.deepcopy(response)
1490
+ if "content-location" not in old_response:
1491
+ old_response["content-location"] = absolute_uri
1492
+ redirect_method = method
1493
+ if response.status in [302, 303]:
1494
+ redirect_method = "GET"
1495
+ body = None
1496
+ (response, content) = self.request(
1497
+ location, method=redirect_method, body=body, headers=headers, redirections=redirections - 1,
1498
+ )
1499
+ response.previous = old_response
1500
+ else:
1501
+ raise RedirectLimit(
1502
+ "Redirected more times than redirection_limit allows.", response, content,
1503
+ )
1504
+ elif response.status in [200, 203] and method in self.safe_methods:
1505
+ # Don't cache 206's since we aren't going to handle byte range requests
1506
+ if "content-location" not in response:
1507
+ response["content-location"] = absolute_uri
1508
+ _updateCache(headers, response, content, self.cache, cachekey)
1509
+
1510
+ return (response, content)
1511
+
1512
+ def _normalize_headers(self, headers):
1513
+ return _normalize_headers(headers)
1514
+
1515
+ # Need to catch and rebrand some exceptions
1516
+ # Then need to optionally turn all exceptions into status codes
1517
+ # including all socket.* and httplib.* exceptions.
1518
+
1519
+ def request(
1520
+ self, uri, method="GET", body=None, headers=None, redirections=DEFAULT_MAX_REDIRECTS, connection_type=None,
1521
+ ):
1522
+ """ Performs a single HTTP request.
1523
+ The 'uri' is the URI of the HTTP resource and can begin
1524
+ with either 'http' or 'https'. The value of 'uri' must be an absolute URI.
1525
+
1526
+ The 'method' is the HTTP method to perform, such as GET, POST, DELETE, etc.
1527
+ There is no restriction on the methods allowed.
1528
+
1529
+ The 'body' is the entity body to be sent with the request. It is a string
1530
+ object.
1531
+
1532
+ Any extra headers that are to be sent with the request should be provided in the
1533
+ 'headers' dictionary.
1534
+
1535
+ The maximum number of redirect to follow before raising an
1536
+ exception is 'redirections. The default is 5.
1537
+
1538
+ The return value is a tuple of (response, content), the first
1539
+ being and instance of the 'Response' class, the second being
1540
+ a string that contains the response entity body.
1541
+ """
1542
+ conn_key = ""
1543
+
1544
+ try:
1545
+ if headers is None:
1546
+ headers = {}
1547
+ else:
1548
+ headers = self._normalize_headers(headers)
1549
+
1550
+ if "user-agent" not in headers:
1551
+ headers["user-agent"] = "Python-httplib2/%s (gzip)" % __version__
1552
+
1553
+ uri = iri2uri(uri)
1554
+ # Prevent CWE-75 space injection to manipulate request via part of uri.
1555
+ # Prevent CWE-93 CRLF injection to modify headers via part of uri.
1556
+ uri = uri.replace(" ", "%20").replace("\r", "%0D").replace("\n", "%0A")
1557
+
1558
+ (scheme, authority, request_uri, defrag_uri) = urlnorm(uri)
1559
+
1560
+ conn_key = scheme + ":" + authority
1561
+ conn = self.connections.get(conn_key)
1562
+ if conn is None:
1563
+ if not connection_type:
1564
+ connection_type = SCHEME_TO_CONNECTION[scheme]
1565
+ certs = list(self.certificates.iter(authority))
1566
+ if issubclass(connection_type, HTTPSConnectionWithTimeout):
1567
+ if certs:
1568
+ conn = self.connections[conn_key] = connection_type(
1569
+ authority,
1570
+ key_file=certs[0][0],
1571
+ cert_file=certs[0][1],
1572
+ timeout=self.timeout,
1573
+ proxy_info=self.proxy_info,
1574
+ ca_certs=self.ca_certs,
1575
+ disable_ssl_certificate_validation=self.disable_ssl_certificate_validation,
1576
+ tls_maximum_version=self.tls_maximum_version,
1577
+ tls_minimum_version=self.tls_minimum_version,
1578
+ key_password=certs[0][2],
1579
+ )
1580
+ else:
1581
+ conn = self.connections[conn_key] = connection_type(
1582
+ authority,
1583
+ timeout=self.timeout,
1584
+ proxy_info=self.proxy_info,
1585
+ ca_certs=self.ca_certs,
1586
+ disable_ssl_certificate_validation=self.disable_ssl_certificate_validation,
1587
+ tls_maximum_version=self.tls_maximum_version,
1588
+ tls_minimum_version=self.tls_minimum_version,
1589
+ )
1590
+ else:
1591
+ conn = self.connections[conn_key] = connection_type(
1592
+ authority, timeout=self.timeout, proxy_info=self.proxy_info
1593
+ )
1594
+ conn.set_debuglevel(debuglevel)
1595
+
1596
+ if "range" not in headers and "accept-encoding" not in headers:
1597
+ headers["accept-encoding"] = "gzip, deflate"
1598
+
1599
+ info = email.message.Message()
1600
+ cachekey = None
1601
+ cached_value = None
1602
+ if self.cache:
1603
+ cachekey = defrag_uri
1604
+ cached_value = self.cache.get(cachekey)
1605
+ if cached_value:
1606
+ try:
1607
+ info, content = cached_value.split(b"\r\n\r\n", 1)
1608
+ info = email.message_from_bytes(info)
1609
+ for k, v in info.items():
1610
+ if v.startswith("=?") and v.endswith("?="):
1611
+ info.replace_header(k, str(*email.header.decode_header(v)[0]))
1612
+ except (IndexError, ValueError):
1613
+ self.cache.delete(cachekey)
1614
+ cachekey = None
1615
+ cached_value = None
1616
+
1617
+ if (
1618
+ method in self.optimistic_concurrency_methods
1619
+ and self.cache
1620
+ and "etag" in info
1621
+ and not self.ignore_etag
1622
+ and "if-match" not in headers
1623
+ ):
1624
+ # http://www.w3.org/1999/04/Editing/
1625
+ headers["if-match"] = info["etag"]
1626
+
1627
+ # https://tools.ietf.org/html/rfc7234
1628
+ # A cache MUST invalidate the effective Request URI as well as [...] Location and Content-Location
1629
+ # when a non-error status code is received in response to an unsafe request method.
1630
+ if self.cache and cachekey and method not in self.safe_methods:
1631
+ self.cache.delete(cachekey)
1632
+
1633
+ # Check the vary header in the cache to see if this request
1634
+ # matches what varies in the cache.
1635
+ if method in self.safe_methods and "vary" in info:
1636
+ vary = info["vary"]
1637
+ vary_headers = vary.lower().replace(" ", "").split(",")
1638
+ for header in vary_headers:
1639
+ key = "-varied-%s" % header
1640
+ value = info[key]
1641
+ if headers.get(header, None) != value:
1642
+ cached_value = None
1643
+ break
1644
+
1645
+ if (
1646
+ self.cache
1647
+ and cached_value
1648
+ and (method in self.safe_methods or info["status"] == "308")
1649
+ and "range" not in headers
1650
+ ):
1651
+ redirect_method = method
1652
+ if info["status"] not in ("307", "308"):
1653
+ redirect_method = "GET"
1654
+ if "-x-permanent-redirect-url" in info:
1655
+ # Should cached permanent redirects be counted in our redirection count? For now, yes.
1656
+ if redirections <= 0:
1657
+ raise RedirectLimit(
1658
+ "Redirected more times than redirection_limit allows.", {}, "",
1659
+ )
1660
+ (response, new_content) = self.request(
1661
+ info["-x-permanent-redirect-url"],
1662
+ method=redirect_method,
1663
+ headers=headers,
1664
+ redirections=redirections - 1,
1665
+ )
1666
+ response.previous = Response(info)
1667
+ response.previous.fromcache = True
1668
+ else:
1669
+ # Determine our course of action:
1670
+ # Is the cached entry fresh or stale?
1671
+ # Has the client requested a non-cached response?
1672
+ #
1673
+ # There seems to be three possible answers:
1674
+ # 1. [FRESH] Return the cache entry w/o doing a GET
1675
+ # 2. [STALE] Do the GET (but add in cache validators if available)
1676
+ # 3. [TRANSPARENT] Do a GET w/o any cache validators (Cache-Control: no-cache) on the request
1677
+ entry_disposition = _entry_disposition(info, headers)
1678
+
1679
+ if entry_disposition == "FRESH":
1680
+ response = Response(info)
1681
+ response.fromcache = True
1682
+ return (response, content)
1683
+
1684
+ if entry_disposition == "STALE":
1685
+ if "etag" in info and not self.ignore_etag and not "if-none-match" in headers:
1686
+ headers["if-none-match"] = info["etag"]
1687
+ if "last-modified" in info and not "last-modified" in headers:
1688
+ headers["if-modified-since"] = info["last-modified"]
1689
+ elif entry_disposition == "TRANSPARENT":
1690
+ pass
1691
+
1692
+ (response, new_content) = self._request(
1693
+ conn, authority, uri, request_uri, method, body, headers, redirections, cachekey,
1694
+ )
1695
+
1696
+ if response.status == 304 and method == "GET":
1697
+ # Rewrite the cache entry with the new end-to-end headers
1698
+ # Take all headers that are in response
1699
+ # and overwrite their values in info.
1700
+ # unless they are hop-by-hop, or are listed in the connection header.
1701
+
1702
+ for key in _get_end2end_headers(response):
1703
+ info[key] = response[key]
1704
+ merged_response = Response(info)
1705
+ if hasattr(response, "_stale_digest"):
1706
+ merged_response._stale_digest = response._stale_digest
1707
+ _updateCache(headers, merged_response, content, self.cache, cachekey)
1708
+ response = merged_response
1709
+ response.status = 200
1710
+ response.fromcache = True
1711
+
1712
+ elif response.status == 200:
1713
+ content = new_content
1714
+ else:
1715
+ self.cache.delete(cachekey)
1716
+ content = new_content
1717
+ else:
1718
+ cc = _parse_cache_control(headers)
1719
+ if "only-if-cached" in cc:
1720
+ info["status"] = "504"
1721
+ response = Response(info)
1722
+ content = b""
1723
+ else:
1724
+ (response, content) = self._request(
1725
+ conn, authority, uri, request_uri, method, body, headers, redirections, cachekey,
1726
+ )
1727
+ except Exception as e:
1728
+ is_timeout = isinstance(e, socket.timeout)
1729
+ if is_timeout:
1730
+ conn = self.connections.pop(conn_key, None)
1731
+ if conn:
1732
+ conn.close()
1733
+
1734
+ if self.force_exception_to_status_code:
1735
+ if isinstance(e, HttpLib2ErrorWithResponse):
1736
+ response = e.response
1737
+ content = e.content
1738
+ response.status = 500
1739
+ response.reason = str(e)
1740
+ elif isinstance(e, socket.timeout):
1741
+ content = b"Request Timeout"
1742
+ response = Response({"content-type": "text/plain", "status": "408", "content-length": len(content),})
1743
+ response.reason = "Request Timeout"
1744
+ else:
1745
+ content = str(e).encode("utf-8")
1746
+ response = Response({"content-type": "text/plain", "status": "400", "content-length": len(content),})
1747
+ response.reason = "Bad Request"
1748
+ else:
1749
+ raise
1750
+
1751
+ return (response, content)
1752
+
1753
+
1754
+ class Response(dict):
1755
+ """An object more like email.message than httplib.HTTPResponse."""
1756
+
1757
+ """Is this response from our local cache"""
1758
+ fromcache = False
1759
+ """HTTP protocol version used by server.
1760
+
1761
+ 10 for HTTP/1.0, 11 for HTTP/1.1.
1762
+ """
1763
+ version = 11
1764
+
1765
+ "Status code returned by server. "
1766
+ status = 200
1767
+ """Reason phrase returned by server."""
1768
+ reason = "Ok"
1769
+
1770
+ previous = None
1771
+
1772
+ def __init__(self, info):
1773
+ # info is either an email.message or
1774
+ # an httplib.HTTPResponse object.
1775
+ if isinstance(info, http.client.HTTPResponse):
1776
+ for key, value in info.getheaders():
1777
+ key = key.lower()
1778
+ prev = self.get(key)
1779
+ if prev is not None:
1780
+ value = ", ".join((prev, value))
1781
+ self[key] = value
1782
+ self.status = info.status
1783
+ self["status"] = str(self.status)
1784
+ self.reason = info.reason
1785
+ self.version = info.version
1786
+ elif isinstance(info, email.message.Message):
1787
+ for key, value in list(info.items()):
1788
+ self[key.lower()] = value
1789
+ self.status = int(self["status"])
1790
+ else:
1791
+ for key, value in info.items():
1792
+ self[key.lower()] = value
1793
+ self.status = int(self.get("status", self.status))
1794
+
1795
+ def __getattr__(self, name):
1796
+ if name == "dict":
1797
+ return self
1798
+ else:
1799
+ raise AttributeError(name)
.venv/lib/python3.11/site-packages/httplib2/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (80.4 kB). View file
 
.venv/lib/python3.11/site-packages/httplib2/__pycache__/auth.cpython-311.pyc ADDED
Binary file (4.36 kB). View file
 
.venv/lib/python3.11/site-packages/httplib2/__pycache__/certs.cpython-311.pyc ADDED
Binary file (1.7 kB). View file
 
.venv/lib/python3.11/site-packages/httplib2/__pycache__/error.cpython-311.pyc ADDED
Binary file (2.69 kB). View file
 
.venv/lib/python3.11/site-packages/httplib2/__pycache__/iri2uri.cpython-311.pyc ADDED
Binary file (4.88 kB). View file
 
.venv/lib/python3.11/site-packages/httplib2/__pycache__/socks.cpython-311.pyc ADDED
Binary file (26.7 kB). View file
 
.venv/lib/python3.11/site-packages/httplib2/auth.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import re
3
+
4
+ import pyparsing as pp
5
+
6
+ from .error import *
7
+
8
+
9
+ try: # pyparsing>=3.0.0
10
+ downcaseTokens = pp.common.downcaseTokens
11
+ except AttributeError:
12
+ downcaseTokens = pp.downcaseTokens
13
+
14
+ UNQUOTE_PAIRS = re.compile(r"\\(.)")
15
+ unquote = lambda s, l, t: UNQUOTE_PAIRS.sub(r"\1", t[0][1:-1])
16
+
17
+ # https://tools.ietf.org/html/rfc7235#section-1.2
18
+ # https://tools.ietf.org/html/rfc7235#appendix-B
19
+ tchar = "!#$%&'*+-.^_`|~" + pp.nums + pp.alphas
20
+ token = pp.Word(tchar).setName("token")
21
+ token68 = pp.Combine(pp.Word("-._~+/" + pp.nums + pp.alphas) + pp.Optional(pp.Word("=").leaveWhitespace())).setName(
22
+ "token68"
23
+ )
24
+
25
+ quoted_string = pp.dblQuotedString.copy().setName("quoted-string").setParseAction(unquote)
26
+ auth_param_name = token.copy().setName("auth-param-name").addParseAction(downcaseTokens)
27
+ auth_param = auth_param_name + pp.Suppress("=") + (quoted_string | token)
28
+ params = pp.Dict(pp.delimitedList(pp.Group(auth_param)))
29
+
30
+ scheme = token("scheme")
31
+ challenge = scheme + (params("params") | token68("token"))
32
+
33
+ authentication_info = params.copy()
34
+ www_authenticate = pp.delimitedList(pp.Group(challenge))
35
+
36
+
37
+ def _parse_authentication_info(headers, headername="authentication-info"):
38
+ """https://tools.ietf.org/html/rfc7615
39
+ """
40
+ header = headers.get(headername, "").strip()
41
+ if not header:
42
+ return {}
43
+ try:
44
+ parsed = authentication_info.parseString(header)
45
+ except pp.ParseException as ex:
46
+ # print(ex.explain(ex))
47
+ raise MalformedHeader(headername)
48
+
49
+ return parsed.asDict()
50
+
51
+
52
+ def _parse_www_authenticate(headers, headername="www-authenticate"):
53
+ """Returns a dictionary of dictionaries, one dict per auth_scheme."""
54
+ header = headers.get(headername, "").strip()
55
+ if not header:
56
+ return {}
57
+ try:
58
+ parsed = www_authenticate.parseString(header)
59
+ except pp.ParseException as ex:
60
+ # print(ex.explain(ex))
61
+ raise MalformedHeader(headername)
62
+
63
+ retval = {
64
+ challenge["scheme"].lower(): challenge["params"].asDict()
65
+ if "params" in challenge
66
+ else {"token": challenge.get("token")}
67
+ for challenge in parsed
68
+ }
69
+ return retval
.venv/lib/python3.11/site-packages/httplib2/cacerts.txt ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/httplib2/certs.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utilities for certificate management."""
2
+
3
+ import os
4
+
5
+ certifi_available = False
6
+ certifi_where = None
7
+ try:
8
+ from certifi import where as certifi_where
9
+ certifi_available = True
10
+ except ImportError:
11
+ pass
12
+
13
+ custom_ca_locater_available = False
14
+ custom_ca_locater_where = None
15
+ try:
16
+ from ca_certs_locater import get as custom_ca_locater_where
17
+ custom_ca_locater_available = True
18
+ except ImportError:
19
+ pass
20
+
21
+
22
+ BUILTIN_CA_CERTS = os.path.join(
23
+ os.path.dirname(os.path.abspath(__file__)), "cacerts.txt"
24
+ )
25
+
26
+
27
+ def where():
28
+ env = os.environ.get("HTTPLIB2_CA_CERTS")
29
+ if env is not None:
30
+ if os.path.isfile(env):
31
+ return env
32
+ else:
33
+ raise RuntimeError("Environment variable HTTPLIB2_CA_CERTS not a valid file")
34
+ if custom_ca_locater_available:
35
+ return custom_ca_locater_where()
36
+ if certifi_available:
37
+ return certifi_where()
38
+ return BUILTIN_CA_CERTS
39
+
40
+
41
+ if __name__ == "__main__":
42
+ print(where())
.venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ pip
.venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Noam Gat
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
.venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/METADATA ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: lm-format-enforcer
3
+ Version: 0.10.9
4
+ Summary: Enforce the output format (JSON Schema, Regex etc) of a language model
5
+ Home-page: https://github.com/noamgat/lm-format-enforcer
6
+ License: MIT
7
+ Author: Noam Gat
8
+ Author-email: noamgat@gmail.com
9
+ Requires-Python: >=3.8,<4.0
10
+ Classifier: Development Status :: 3 - Alpha
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Operating System :: OS Independent
14
+ Classifier: Programming Language :: Python
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Programming Language :: Python :: 3.8
17
+ Classifier: Programming Language :: Python :: 3.9
18
+ Classifier: Programming Language :: Python :: 3.10
19
+ Classifier: Programming Language :: Python :: 3.11
20
+ Classifier: Programming Language :: Python :: 3.12
21
+ Classifier: Programming Language :: Python :: 3.13
22
+ Classifier: Programming Language :: Python :: 3 :: Only
23
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
24
+ Requires-Dist: interegular (>=0.3.2)
25
+ Requires-Dist: packaging
26
+ Requires-Dist: pydantic (>=1.10.8)
27
+ Requires-Dist: pyyaml
28
+ Project-URL: Bug Tracker, https://github.com/noamgat/lm-format-enforcer/issues
29
+ Project-URL: Documentation, https://github.com/noamgat/lm-format-enforcer
30
+ Project-URL: Repository, https://github.com/noamgat/lm-format-enforcer
31
+ Description-Content-Type: text/markdown
32
+
33
+ # lm-format-enforcer
34
+
35
+ ![LMFE Logo](https://raw.githubusercontent.com/noamgat/lm-format-enforcer/main/docs/Logo.png)
36
+
37
+ **Enforce the output format (JSON Schema, Regex etc) of a language model**
38
+
39
+ <a target="_blank" href="https://colab.research.google.com/github/noamgat/lm-format-enforcer/blob/main/samples/colab_llama2_enforcer.ipynb">
40
+ <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
41
+ </a>
42
+
43
+ [![Code Coverage](https://codecov.io/gh/noamgat/lm-format-enforcer/graph/badge.svg?token=63U3S58VWS)](https://codecov.io/gh/noamgat/lm-format-enforcer)
44
+ ![Tests](https://github.com/noamgat/lm-format-enforcer/actions/workflows/run_tests.yml/badge.svg)
45
+
46
+
47
+ ![Solution at a glance](https://raw.githubusercontent.com/noamgat/lm-format-enforcer/main/docs/Intro.webp)
48
+
49
+
50
+ Language models are able to generate text, but when requiring a precise output format, they do not always perform as instructed.
51
+ Various prompt engineering techniques have been introduced to improve the robustness of the generated text, but they are not always sufficient.
52
+ This project solves the issues by filtering the tokens that the language model is allowed to generate at every timestep, thus ensuring that the output format is respected, while minimizing the limitations on the language model.
53
+
54
+ ## Installation
55
+ ```pip install lm-format-enforcer```
56
+
57
+ ## Basic Tutorial
58
+ ```python
59
+ # Requirements if running from Google Colab with a T4 GPU.
60
+ !pip install transformers torch lm-format-enforcer huggingface_hub optimum
61
+ !pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
62
+
63
+ from pydantic import BaseModel
64
+ from lmformatenforcer import JsonSchemaParser
65
+ from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn
66
+ from transformers import pipeline
67
+
68
+ class AnswerFormat(BaseModel):
69
+ first_name: str
70
+ last_name: str
71
+ year_of_birth: int
72
+ num_seasons_in_nba: int
73
+
74
+ # Create a transformers pipeline
75
+ hf_pipeline = pipeline('text-generation', model='TheBloke/Llama-2-7b-Chat-GPTQ', device_map='auto')
76
+ prompt = f'Here is information about Michael Jordan in the following json schema: {AnswerFormat.schema_json()} :\n'
77
+
78
+ # Create a character level parser and build a transformers prefix function from it
79
+ parser = JsonSchemaParser(AnswerFormat.schema())
80
+ prefix_function = build_transformers_prefix_allowed_tokens_fn(hf_pipeline.tokenizer, parser)
81
+
82
+ # Call the pipeline with the prefix function
83
+ output_dict = hf_pipeline(prompt, prefix_allowed_tokens_fn=prefix_function)
84
+
85
+ # Extract the results
86
+ result = output_dict[0]['generated_text'][len(prompt):]
87
+ print(result)
88
+ # {'first_name': 'Michael', 'last_name': 'Jordan', 'year_of_birth': 1963, 'num_seasons_in_nba': 15}
89
+ ```
90
+
91
+ ## Capabilities / Advantages
92
+
93
+ - Works with any Python language model and tokenizer. Already supports [transformers](https://github.com/huggingface/transformers), [LangChain](https://python.langchain.com/docs/integrations/llms/lmformatenforcer_experimental), [LlamaIndex](https://docs.llamaindex.ai/en/latest/community/integrations/lmformatenforcer.html), [llama.cpp](https://github.com/noamgat/lm-format-enforcer/blob/main/samples/colab_llamacpppython_integration.ipynb), [vLLM](https://github.com/noamgat/lm-format-enforcer/blob/main/samples/colab_vllm_integration.ipynb), [Haystack](https://haystack.deepset.ai/integrations/lmformatenforcer), [NVIDIA TensorRT-LLM](https://github.com/noamgat/lm-format-enforcer/blob/main/samples/colab_trtllm_integration.ipynb) and [ExLlamaV2](https://github.com/noamgat/lm-format-enforcer/blob/main/samples/colab_exllamav2_integration.ipynb).
94
+ - Supports batched generation and beam searches - each input / beam can have different tokens filtered at every timestep
95
+ - Supports JSON Schema, JSON Mode (schemaless) and Regular Expression formats
96
+ - Supports both required and optional fields in JSON schemas
97
+ - Supports nested fields, arrays and dictionaries in JSON schemas
98
+ - Gives the language model freedom to control whitespacing and field ordering in JSON schemas, reducing hallucinations.
99
+ - Does not modify the high level loop of transformers API, so can be used in any scenario.
100
+
101
+
102
+ ## Comparison to other libraries
103
+
104
+ Capability | LM Format Enforcer | [Guidance](https://github.com/guidance-ai/guidance) | [Jsonformer](https://github.com/1rgs/jsonformer) | [Outlines](https://github.com/outlines-dev/outlines)
105
+ :------------ | :-------------| :-------------| :------------- | :----
106
+ Regular Expressions | ✅ | ✅ | ❌ | ✅
107
+ JSON Schema | ✅ | 🟡 ([Partial conversion is possible](https://github.com/guidance-ai/guidance/blob/main/notebooks/applications/jsonformer.ipynb)) | ✅ | ✅
108
+ Batched Generation | ✅ | ❌ | ❌ | ✅
109
+ Beam Search | ✅ | ❌ | ❌ | ✅
110
+ Integrates into existing pipelines | ✅ | ❌ | ❌ | ✅
111
+ Optional JSON Fields | ✅ | ❌ | ❌ | ❌
112
+ LLM Controls JSON field ordering and whitespace | ✅ | ❌ | ❌ | ❌
113
+ JSON Schema with recursive classes | ✅ | ❌ | ✅ | ❌
114
+ Visual model support | [✅](https://github.com/noamgat/lm-format-enforcer/blob/main/samples/colab_llama32_vision_enforcer.ipynb) | ✅ | ❌ | ❌
115
+
116
+ Spotted a mistake? Library updated with new capabilities? [Open an issue!](https://github.com/noamgat/lm-format-enforcer/issues)
117
+
118
+ ## Detailed example
119
+
120
+ We created a Google Colab Notebook which contains a full example of how to use this library to enforce the output format of llama2, including interpreting the intermediate results. The notebook can run on a free GPU-backed runtime in Colab.
121
+
122
+ <a target="_blank" href="https://colab.research.google.com/github/noamgat/lm-format-enforcer/blob/main/samples/colab_llama2_enforcer.ipynb">
123
+ <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
124
+ </a>
125
+
126
+ You can also [view the notebook in GitHub](https://github.com/noamgat/lm-format-enforcer/blob/main/samples/colab_llama2_enforcer.ipynb).
127
+
128
+ For the different ways to integrate with huggingface transformers, see the [unit tests](https://github.com/noamgat/lm-format-enforcer/blob/main/tests/test_transformerenforcer.py).
129
+
130
+ ## vLLM Server Integration
131
+
132
+ LM Format Enforcer is integrated into the [vLLM](https://github.com/vllm-project/vllm) inference server. vLLM includes an [OpenAI compatible server](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html) with added capabilities that allow using LM Format Enforcer without writing custom inference code.
133
+
134
+ Use LM Format Enforcer with the vLLM OpenAI Server either by adding the [vLLM command line parameter](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server):
135
+
136
+ ```
137
+ python -m vllm.entrypoints.openai.api_server \
138
+ --model mistralai/Mistral-7B-Instruct-v0.2 \
139
+ --guided-decoding-backend lm-format-enforcer
140
+ ```
141
+
142
+ Or on a per-request basis, by adding the `guided_decoding_backend` parameter to the request together with the guided decoding parameters:
143
+
144
+ ```
145
+ completion = client.chat.completions.create(
146
+ model="mistralai/Mistral-7B-Instruct-v0.2",
147
+ messages=[
148
+ {"role": "user", "content": "Classify this sentiment: LMFE is wonderful!"}
149
+ ],
150
+ extra_body={
151
+ "guided_regex": "[Pp]ositive|[Nn]egative",
152
+ "guided_decoding_backend": "lm-format-enforcer"
153
+ }
154
+ )
155
+ ```
156
+ Json schema and choice decoding also supported via `guided_json` and `guided_choice` [extra parameters](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-chat-api).
157
+
158
+ ## How does it work?
159
+
160
+ The library works by combining a character level parser and a tokenizer prefix tree into a smart token filtering mechanism.
161
+
162
+ ![An example of the character level parser and tokenizer prefix tree in a certain timestep](https://raw.githubusercontent.com/noamgat/lm-format-enforcer/main/docs/Trees.drawio.svg?sanitize=true)
163
+
164
+ ### Character Level Parser
165
+
166
+ Parsing a string into any kind of formatter can be looked at as an implicit tree structure - at any moment in the parsing process, there is a set of allowed next characters, and if any of them are selected, there is a new set of allowed next characters, and so on.
167
+
168
+ ```CharacterLevelParser``` is an interface for parsing according to this implicit structure. ```add_character()``` and ```get_allowed_characters()``` can be seen as tree traversal methods.
169
+
170
+ There are several implementations of this interface:
171
+ - ```JsonSchemaParser``` - parses according to a json schema (or pure json output - `JsonSchemaParser(None) will result in any json object allowed`).
172
+ - ```StringParser``` - forces an exact string (used mainly for diagnostics)
173
+ - ```RegexParser``` - parses according to a regular expression. Note that this cannot use the built in python regex and uses a manually implemented one (via the [interegular](https://pypi.org/project/interegular/) library), so it doesn't cover 100% of the regex standard.
174
+ ### Tokenizer Prefix Tree
175
+
176
+ Given a tokenizer used by a certain language model, we can build a prefix tree of all the tokens that the language model can generate. This is done by generating all possible sequences of tokens, and adding them to the tree.
177
+ See ```TokenizerPrefixTree```
178
+
179
+ ### Combining the two
180
+
181
+ Given a character level parser and a tokenizer prefix tree, we can elegantly and efficiently filter the tokens that the language model is allowed to generate at the next timestep:
182
+ We only traverse the characters that are in BOTH the character level parsing node and the tokenizer prefix tree node. This allows us to find all of the tokens (including complex subword tokens such as ```","``` which are critical in JSON parsing).
183
+ We do this recursively on both trees and return all of the allowed tokens. When the language model generates a token, we advance the character level parser according to the new characters, ready to filter the next timestep.
184
+
185
+ ### How is this approach different? Why is it good?
186
+
187
+ This is not the first library to enforce the output format of a language model. However, other similar libraries (such as Guidance, JsonFormer and Outlines) enforce an exact output format. This means that the language model is not allowed to control whitespacing, field optionality and field ordering (in the JSON usecase). While this seems inconsequencial to humans, it means that the language model may not be generating the JSON formats that it "wants to" generate, and could put its internal states in a suboptimal value, reducing the quality of the output in later timesteps.
188
+
189
+ This forces language model users to know the details of the language model they are using (for example - were JSONs minified before pretraining?) and modify the libraries to generate the precise format.
190
+
191
+ We avoid this problem by scanning potential next tokens and allowing any token sequence that will be parsed into the output format. This means that the language model can control all of these aspects, and output the token sequence that matches its' style in the most natural way, without requiring the developer to know the details.
192
+
193
+
194
+ ## Diagnostics - Will I always get good results?
195
+
196
+ Using this library guarantees that the output will match the format, but it does not guarantee that the output will be semantically correct. Forcing the language model to conform to a certain output may lead to increased hallucinations. Guiding the model via prompt engineering is still likely to improve results.
197
+
198
+ In order to help you understand the aggressiveness caused by the format enforcement, if you pass ```output_scores=True``` and ```return_dict_in_generate=True``` in the ```kwargs``` to ```generate_enforced()``` (these are existing optional parameters in the ```transformers``` library), you will also get a token-by-token dataframe showing which token was selected, its score, and what was the token that would have been chosen if the format enforcement was not applied. If you see that the format enforcer forced the language model to select tokens with very low weights, it is a likely contributor to the poor results. Try modifying the prompt to guide the language model to not force the format enforcer to be so aggressive.
199
+
200
+ Example using the regular expression format ``` Michael Jordan was Born in (\d)+.```
201
+
202
+ idx | generated_token | generated_token_idx | generated_score | leading_token | leading_token_idx | leading_score
203
+ :------------ | :-------------| :-------------| :------------- | :------------ | :-------------| :-------------
204
+ 0 | ▁ | 29871 | 1.000000 | ▁ | 29871 | 1.000000
205
+ 1 | Michael | 24083 | 0.000027 | ▁Sure | 18585 | 0.959473
206
+ 2 | ▁Jordan | 18284 | 1.000000 | ▁Jordan | 18284 | 1.000000
207
+ 3 | ▁was | 471 | 1.000000 | ▁was | 471 | 1.000000
208
+ 4 | ▁Born | 19298 | 0.000008 | ▁born | 6345 | 1.000000
209
+ 5 | ▁in | 297 | 0.994629 | ▁in | 297 | 0.994629
210
+ 6 | ▁ | 29871 | 0.982422 | ▁ | 29871 | 0.982422
211
+ 7 | 1 | 29896 | 1.000000 | 1 | 29896 | 1.000000
212
+ 8 | 9 | 29929 | 1.000000 | 9 | 29929 | 1.000000
213
+ 9 | 6 | 29953 | 1.000000 | 6 | 29953 | 1.000000
214
+ 10 | 3 | 29941 | 1.000000 | 3 | 29941 | 1.000000
215
+ 11 | . | 29889 | 0.999512 | . | 29889 | 0.999512
216
+ 12 | ```</s>``` | 2 | 0.981445 | ```</s>``` | 2 | 0.981445
217
+
218
+
219
+ You can see that the model "wanted" to start the answer using ```Sure```, but the format enforcer forced it to use ```Michael``` - there was a big gap in token 1. Afterwards, almost all of the leading scores are all within the allowed token set, meaning the model likely did not hallucinate due to the token forcing. The only exception was timestep 4 - " Born" was forced while the LLM wanted to choose "born". This is a hint for the prompt engineer, to change the prompt to use a lowercase b instead.
220
+
221
+
222
+ ## Configuration options
223
+
224
+ LM Format Enforcer makes use of several heuristics to avoid edge cases that may happen with LLM's generating structure outputs.
225
+ There are two ways to control these heuristics:
226
+
227
+ ### Option 1: via Environment Variables
228
+
229
+ There are several environment variables that can be set, that affect the operation of the library. This method is useful when you don't want to modify the code, for example when using the library through the vLLM OpenAI server.
230
+
231
+ - `LMFE_MAX_CONSECUTIVE_WHITESPACES` - How many consecutive whitespaces are allowed when parsing JsonSchemaObjects. Default: 12.
232
+ - `LMFE_STRICT_JSON_FIELD_ORDER` - Should the JsonSchemaParser force the properties to appear in the same order as they appear in the 'required' list of the JsonSchema? (Note: this is consistent with the order of declaration in Pydantic models). Default: False.
233
+ - `LMFE_MAX_JSON_ARRAY_LENGTH` - What is the maximal JSON array length, if not specified by the schema. Helps LLM Avoid infinite loops. Default: 20.
234
+
235
+ ### Option 2: via the CharacterLevelParserConfig class
236
+ When using the library through code, any `CharacterLevelParser` (`JsonSchemaParser`, `RegexParser` etc) constructor receives an optional `CharacterLevelParserConfig` object.
237
+
238
+ Therefore, to configure the heuristics of a single parser, instantiate a `CharacterLevelParserConfig` object, modify its values and pass it to the `CharacterLevelParser`'s constructor.
239
+
240
+
241
+
242
+ ## Known issues and limitations
243
+
244
+ - LM Format Enforcer requires a python API to process the output logits of the language model. This means that until the APIs are extended, it can not be used with OpenAI ChatGPT and similar API based solutions.
245
+ - Regular expression syntax is not 100% supported. See [interegular](https://pypi.org/project/interegular/) for more details.
246
+ - LM Format Enforcer Regex Parser can only generate characters that exist in the tokenizer vocabulary. This may be solved in a later version, see [the issue on GitHub](https://github.com/noamgat/lm-format-enforcer/issues/13).
247
+
248
+
249
+ ## Contributers and contributing
250
+
251
+ See [CONTRIBUTORS.md](https://github.com/noamgat/lm-format-enforcer/blob/main/CONTRIBUTORS.md) for a list of contributers.
252
+
.venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/RECORD ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ lm_format_enforcer-0.10.9.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
2
+ lm_format_enforcer-0.10.9.dist-info/LICENSE,sha256=0cAjc_naVKu0D7n6XKbBkTbRDadqEFT_HuJSgiGuG_4,1065
3
+ lm_format_enforcer-0.10.9.dist-info/METADATA,sha256=m23Ia1CRbDDrL7N-gnortFDjA-w9YAy_2HP6EmbH5Mg,17143
4
+ lm_format_enforcer-0.10.9.dist-info/RECORD,,
5
+ lm_format_enforcer-0.10.9.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
6
+ lmformatenforcer/__init__.py,sha256=VQlyXyCK1HC4an39hFEdX9hae5vod1wY9DjHTVnJUCM,850
7
+ lmformatenforcer/__pycache__/__init__.cpython-311.pyc,,
8
+ lmformatenforcer/__pycache__/analyzer.cpython-311.pyc,,
9
+ lmformatenforcer/__pycache__/characterlevelparser.cpython-311.pyc,,
10
+ lmformatenforcer/__pycache__/consts.cpython-311.pyc,,
11
+ lmformatenforcer/__pycache__/exceptions.cpython-311.pyc,,
12
+ lmformatenforcer/__pycache__/jsonschemaparser.cpython-311.pyc,,
13
+ lmformatenforcer/__pycache__/regexparser.cpython-311.pyc,,
14
+ lmformatenforcer/__pycache__/tokenenforcer.cpython-311.pyc,,
15
+ lmformatenforcer/__pycache__/tokenizerprefixtree.cpython-311.pyc,,
16
+ lmformatenforcer/analyzer.py,sha256=imn5kKVaY833GY1D7qfY-A-7hJ0oQz9E3tcXIww5uTY,3893
17
+ lmformatenforcer/characterlevelparser.py,sha256=f3MrMEC_Qfoo1uZjqtDiAiBsRWipMRc1MFDlRB-nd-Y,8557
18
+ lmformatenforcer/consts.py,sha256=g30hYNKwlgc564aX6ZTYbGPQsv5iHiXEPx0eSNyZbSw,1094
19
+ lmformatenforcer/exceptions.py,sha256=oJuEhaGwaawtgGULhTSL-mnVVNiXuaBC_UNlh6MEfX8,104
20
+ lmformatenforcer/external/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
21
+ lmformatenforcer/external/__pycache__/__init__.cpython-311.pyc,,
22
+ lmformatenforcer/external/__pycache__/jsonschemaobject.cpython-311.pyc,,
23
+ lmformatenforcer/external/__pycache__/jsonschemaobjectutil.cpython-311.pyc,,
24
+ lmformatenforcer/external/jsonschemaobject.py,sha256=XT1C-T8MtGzudm-hG5U1isRofLkPPoA5cA1ZjqYr7cs,10716
25
+ lmformatenforcer/external/jsonschemaobjectutil.py,sha256=3W0IVrUDS6k0ydp-Nhd1ymQrDi4Fe3eO8naX-o20zl0,7044
26
+ lmformatenforcer/integrations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
+ lmformatenforcer/integrations/__pycache__/__init__.cpython-311.pyc,,
28
+ lmformatenforcer/integrations/__pycache__/exllamav2.cpython-311.pyc,,
29
+ lmformatenforcer/integrations/__pycache__/haystackv1.cpython-311.pyc,,
30
+ lmformatenforcer/integrations/__pycache__/haystackv2.cpython-311.pyc,,
31
+ lmformatenforcer/integrations/__pycache__/llamacpp.cpython-311.pyc,,
32
+ lmformatenforcer/integrations/__pycache__/transformers.cpython-311.pyc,,
33
+ lmformatenforcer/integrations/__pycache__/trtllm.cpython-311.pyc,,
34
+ lmformatenforcer/integrations/__pycache__/vllm.cpython-311.pyc,,
35
+ lmformatenforcer/integrations/exllamav2.py,sha256=usZWTriLVp4aCA34Y8niehZTrOuY3-Lpig_q6De4Kk8,2595
36
+ lmformatenforcer/integrations/haystackv1.py,sha256=WZ43iebe8Hag3J3ndAPfRgfxd-JWxhFVoUNvjA485ag,2820
37
+ lmformatenforcer/integrations/haystackv2.py,sha256=YLkgadglK4d0Mf2-gJh4UlJiChZouIPjKxk2hYqAq7o,3510
38
+ lmformatenforcer/integrations/llamacpp.py,sha256=W8MckPo2Jtf0_As3Aa4kCaKVQRsgPIVrKyDLX7EWvNQ,3572
39
+ lmformatenforcer/integrations/transformers.py,sha256=PjglkWNJYq_3lMrLOxqXSD-6vb2WTEWCMx9Czl7DYag,7034
40
+ lmformatenforcer/integrations/trtllm.py,sha256=zbk14D7qjrEHUTXYJRWO1Ql4-_VbdtUbwFfqj1-1BLU,3869
41
+ lmformatenforcer/integrations/vllm.py,sha256=mBnjXlIuVLanwiqikLi6N2ZWBgqbEH5m8sjL5arPDBM,2987
42
+ lmformatenforcer/jsonschemaparser.py,sha256=qLt8-Pu3_ZyTsvhYMq5oPx_kxnW19GPaTyWhOdJ_ctI,34589
43
+ lmformatenforcer/regexparser.py,sha256=YoPSowclmDe6M4YbBabf3im2wAyUqzpKkQXAEMYMnEQ,3888
44
+ lmformatenforcer/tokenenforcer.py,sha256=AnvFQWdMHZ8gC5wgBNH5AURqQtJ44cTrHGjGIKvWpAI,10240
45
+ lmformatenforcer/tokenizerprefixtree.py,sha256=EYtWFfzXVY49OMAKm01MvA3SDSGfZl-mwc3WJwMJOUw,6813
.venv/lib/python3.11/site-packages/lm_format_enforcer-0.10.9.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: poetry-core 1.9.1
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (193 Bytes). View file
 
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_adv.h ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /* cudnn_adv : cuDNN's advanced and experimental features.
51
+
52
+ */
53
+
54
+ #if !defined(CUDNN_ADV_H_)
55
+ #define CUDNN_ADV_H_
56
+
57
+ #include <stdint.h>
58
+
59
+ #include "cudnn_version.h"
60
+ #include "cudnn_ops.h"
61
+
62
+ /* These version numbers are autogenerated, do not edit manually. */
63
+ #define CUDNN_ADV_MAJOR 9
64
+ #define CUDNN_ADV_MINOR 1
65
+ #define CUDNN_ADV_PATCH 0
66
+
67
+ #if (CUDNN_ADV_MAJOR != CUDNN_MAJOR) || (CUDNN_ADV_MINOR != CUDNN_MINOR) || (CUDNN_ADV_PATCH != CUDNN_PATCHLEVEL)
68
+ #error Version mismatch in cuDNN ADV INFER!!!
69
+ #endif
70
+
71
+ #if defined(__cplusplus)
72
+ extern "C" {
73
+ #endif
74
+
75
+ /* BASIC RNN API */
76
+
77
+ typedef enum {
78
+ CUDNN_RNN_ALGO_STANDARD = 0,
79
+ CUDNN_RNN_ALGO_PERSIST_STATIC = 1,
80
+ CUDNN_RNN_ALGO_PERSIST_DYNAMIC = 2,
81
+ CUDNN_RNN_ALGO_PERSIST_STATIC_SMALL_H = 3,
82
+ CUDNN_RNN_ALGO_COUNT = 4,
83
+ } cudnnRNNAlgo_t;
84
+
85
+ typedef enum {
86
+ CUDNN_FWD_MODE_INFERENCE = 0,
87
+ CUDNN_FWD_MODE_TRAINING = 1,
88
+ } cudnnForwardMode_t;
89
+
90
+ typedef enum {
91
+ CUDNN_RNN_RELU = 0, /* basic RNN cell type with ReLu activation */
92
+ CUDNN_RNN_TANH = 1, /* basic RNN cell type with tanh activation */
93
+ CUDNN_LSTM = 2, /* LSTM with optional recurrent projection and clipping */
94
+ CUDNN_GRU = 3, /* Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1); */
95
+ } cudnnRNNMode_t;
96
+
97
+ typedef enum {
98
+ CUDNN_RNN_NO_BIAS = 0, /* rnn cell formulas do not use biases */
99
+ CUDNN_RNN_SINGLE_INP_BIAS = 1, /* rnn cell formulas use one input bias in input GEMM */
100
+ CUDNN_RNN_DOUBLE_BIAS = 2, /* default, rnn cell formulas use two bias vectors */
101
+ CUDNN_RNN_SINGLE_REC_BIAS = 3 /* rnn cell formulas use one recurrent bias in recurrent GEMM */
102
+ } cudnnRNNBiasMode_t;
103
+
104
+ typedef enum {
105
+ CUDNN_UNIDIRECTIONAL = 0, /* single direction network */
106
+ CUDNN_BIDIRECTIONAL = 1, /* output concatination at each layer */
107
+ } cudnnDirectionMode_t;
108
+
109
+ typedef enum {
110
+ CUDNN_LINEAR_INPUT = 0, /* adjustable weight matrix in first layer input GEMM */
111
+ CUDNN_SKIP_INPUT = 1, /* fixed identity matrix in the first layer input GEMM */
112
+ } cudnnRNNInputMode_t;
113
+
114
+ typedef enum {
115
+ CUDNN_RNN_CLIP_NONE = 0, /* disables LSTM cell clipping */
116
+ CUDNN_RNN_CLIP_MINMAX = 1, /* enables LSTM cell clipping */
117
+ } cudnnRNNClipMode_t;
118
+
119
+ typedef enum {
120
+ CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED = 0, /* padded, outer stride from one time-step to the next */
121
+ CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED = 1, /* sequence length sorted and packed as in basic RNN api */
122
+ CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED = 2, /* padded, outer stride from one batch to the next */
123
+ } cudnnRNNDataLayout_t;
124
+
125
+ /* For auxFlags in cudnnSetRNNDescriptor_v8() */
126
+ #define CUDNN_RNN_PADDED_IO_DISABLED 0
127
+ #define CUDNN_RNN_PADDED_IO_ENABLED (1U << 0)
128
+
129
+ struct cudnnRNNStruct;
130
+ typedef struct cudnnRNNStruct *cudnnRNNDescriptor_t;
131
+
132
+ struct cudnnRNNDataStruct;
133
+ typedef struct cudnnRNNDataStruct *cudnnRNNDataDescriptor_t;
134
+
135
+ cudnnStatus_t CUDNNWINAPI
136
+ cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc);
137
+
138
+ cudnnStatus_t CUDNNWINAPI
139
+ cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc);
140
+
141
+ /*
142
+ * mathPrec in cudnnSetRNNDescriptor_v8() specifies compute precision.
143
+ * Compute precision is further modified by mathType that sets the
144
+ * preferred option for using NVIDIA Tensor Cores. dataType specify
145
+ * input/output data type and weight/bias type.
146
+ */
147
+
148
+ cudnnStatus_t CUDNNWINAPI
149
+ cudnnSetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc,
150
+ cudnnRNNAlgo_t algo,
151
+ cudnnRNNMode_t cellMode,
152
+ cudnnRNNBiasMode_t biasMode,
153
+ cudnnDirectionMode_t dirMode,
154
+ cudnnRNNInputMode_t inputMode,
155
+ cudnnDataType_t dataType,
156
+ cudnnDataType_t mathPrec,
157
+ cudnnMathType_t mathType,
158
+ int32_t inputSize,
159
+ int32_t hiddenSize,
160
+ int32_t projSize,
161
+ int32_t numLayers,
162
+ cudnnDropoutDescriptor_t dropoutDesc,
163
+ uint32_t auxFlags);
164
+
165
+ cudnnStatus_t CUDNNWINAPI
166
+ cudnnGetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc,
167
+ cudnnRNNAlgo_t *algo,
168
+ cudnnRNNMode_t *cellMode,
169
+ cudnnRNNBiasMode_t *biasMode,
170
+ cudnnDirectionMode_t *dirMode,
171
+ cudnnRNNInputMode_t *inputMode,
172
+ cudnnDataType_t *dataType,
173
+ cudnnDataType_t *mathPrec,
174
+ cudnnMathType_t *mathType,
175
+ int32_t *inputSize,
176
+ int32_t *hiddenSize,
177
+ int32_t *projSize,
178
+ int32_t *numLayers,
179
+ cudnnDropoutDescriptor_t *dropoutDesc,
180
+ uint32_t *auxFlags);
181
+
182
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
183
+ cudnnRNNSetClip_v8(cudnnRNNDescriptor_t rnnDesc,
184
+ cudnnRNNClipMode_t clipMode,
185
+ cudnnNanPropagation_t clipNanOpt,
186
+ double lclip,
187
+ double rclip);
188
+
189
+ cudnnStatus_t CUDNNWINAPI
190
+ cudnnRNNSetClip_v9(cudnnRNNDescriptor_t rnnDesc, cudnnRNNClipMode_t clipMode, double lclip, double rclip);
191
+
192
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
193
+ cudnnRNNGetClip_v8(cudnnRNNDescriptor_t rnnDesc,
194
+ cudnnRNNClipMode_t *clipMode,
195
+ cudnnNanPropagation_t *clipNanOpt,
196
+ double *lclip,
197
+ double *rclip);
198
+
199
+ cudnnStatus_t CUDNNWINAPI
200
+ cudnnRNNGetClip_v9(cudnnRNNDescriptor_t rnnDesc, cudnnRNNClipMode_t *clipMode, double *lclip, double *rclip);
201
+
202
+ cudnnStatus_t CUDNNWINAPI
203
+ cudnnBuildRNNDynamic(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, int miniBatch);
204
+
205
+ cudnnStatus_t CUDNNWINAPI
206
+ cudnnGetRNNTempSpaceSizes(cudnnHandle_t handle,
207
+ cudnnRNNDescriptor_t rnnDesc,
208
+ cudnnForwardMode_t fwdMode,
209
+ cudnnRNNDataDescriptor_t xDesc,
210
+ size_t *workSpaceSize,
211
+ size_t *reserveSpaceSize);
212
+
213
+ cudnnStatus_t CUDNNWINAPI
214
+ cudnnGetRNNWeightSpaceSize(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, size_t *weightSpaceSize);
215
+
216
+ cudnnStatus_t CUDNNWINAPI
217
+ cudnnGetRNNWeightParams(cudnnHandle_t handle,
218
+ cudnnRNNDescriptor_t rnnDesc,
219
+ int32_t pseudoLayer,
220
+ size_t weightSpaceSize,
221
+ const void *weightSpace,
222
+ int32_t linLayerID,
223
+ cudnnTensorDescriptor_t mDesc,
224
+ void **mAddr,
225
+ cudnnTensorDescriptor_t bDesc,
226
+ void **bAddr);
227
+
228
+ cudnnStatus_t CUDNNWINAPI
229
+ cudnnCreateRNNDataDescriptor(cudnnRNNDataDescriptor_t *rnnDataDesc);
230
+
231
+ cudnnStatus_t CUDNNWINAPI
232
+ cudnnDestroyRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc);
233
+
234
+ cudnnStatus_t CUDNNWINAPI
235
+ cudnnSetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
236
+ cudnnDataType_t dataType,
237
+ cudnnRNNDataLayout_t layout,
238
+ int maxSeqLength,
239
+ int batchSize,
240
+ int vectorSize,
241
+ const int seqLengthArray[], /* length of each sequence in the batch */
242
+ void *paddingFill); /* symbol for filling padding position in output */
243
+
244
+ cudnnStatus_t CUDNNWINAPI
245
+ cudnnGetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
246
+ cudnnDataType_t *dataType,
247
+ cudnnRNNDataLayout_t *layout,
248
+ int *maxSeqLength,
249
+ int *batchSize,
250
+ int *vectorSize,
251
+ int arrayLengthRequested,
252
+ int seqLengthArray[],
253
+ void *paddingFill);
254
+
255
+ cudnnStatus_t CUDNNWINAPI
256
+ cudnnRNNForward(cudnnHandle_t handle,
257
+ cudnnRNNDescriptor_t rnnDesc,
258
+ cudnnForwardMode_t fwdMode,
259
+ const int32_t devSeqLengths[],
260
+ cudnnRNNDataDescriptor_t xDesc,
261
+ const void *x,
262
+ cudnnRNNDataDescriptor_t yDesc,
263
+ void *y,
264
+ cudnnTensorDescriptor_t hDesc,
265
+ const void *hx,
266
+ void *hy,
267
+ cudnnTensorDescriptor_t cDesc,
268
+ const void *cx,
269
+ void *cy,
270
+ size_t weightSpaceSize,
271
+ const void *weightSpace,
272
+ size_t workSpaceSize,
273
+ void *workSpace,
274
+ size_t reserveSpaceSize,
275
+ void *reserveSpace);
276
+
277
+ /* Sequence data descriptor */
278
+
279
+ typedef enum {
280
+ CUDNN_SEQDATA_TIME_DIM = 0, /* index in time */
281
+ CUDNN_SEQDATA_BATCH_DIM = 1, /* index in batch */
282
+ CUDNN_SEQDATA_BEAM_DIM = 2, /* index in beam */
283
+ CUDNN_SEQDATA_VECT_DIM = 3 /* index in vector */
284
+ } cudnnSeqDataAxis_t;
285
+
286
+ struct cudnnSeqDataStruct;
287
+ typedef struct cudnnSeqDataStruct *cudnnSeqDataDescriptor_t CUDNN_DEPRECATED;
288
+
289
+ #define CUDNN_SEQDATA_DIM_COUNT 4 /* dimension count */
290
+
291
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
292
+ cudnnCreateSeqDataDescriptor(cudnnSeqDataDescriptor_t *seqDataDesc);
293
+
294
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
295
+ cudnnDestroySeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc);
296
+
297
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
298
+ cudnnSetSeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc,
299
+ cudnnDataType_t dataType,
300
+ int nbDims,
301
+ const int dimA[],
302
+ const cudnnSeqDataAxis_t axes[],
303
+ size_t seqLengthArraySize,
304
+ const int seqLengthArray[],
305
+ void *paddingFill);
306
+
307
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
308
+ cudnnGetSeqDataDescriptor(const cudnnSeqDataDescriptor_t seqDataDesc,
309
+ cudnnDataType_t *dataType,
310
+ int *nbDims,
311
+ int nbDimsRequested,
312
+ int dimA[],
313
+ cudnnSeqDataAxis_t axes[],
314
+ size_t *seqLengthArraySize,
315
+ size_t seqLengthSizeRequested,
316
+ int seqLengthArray[],
317
+ void *paddingFill);
318
+
319
+ /* Multihead Attention */
320
+
321
+ /*
322
+ * Multi-head attention options passed via 'attnMode' in cudnnSetAttnDescriptor().
323
+ * Use the bitwise OR operator to combine several settings listed below. Additional
324
+ * minor options can be added here w/o changing or introducing new API functions.
325
+ */
326
+ #define CUDNN_ATTN_QUERYMAP_ALL_TO_ONE 0 /* multiple Q-s map to a single (K,V) set when beam size > 1 */
327
+ #define CUDNN_ATTN_QUERYMAP_ONE_TO_ONE (1U << 0) /* multiple Q-s map to multiple (K,V) sets when beam size > 1 */
328
+ #define CUDNN_ATTN_DISABLE_PROJ_BIASES 0 /* no biases in attention input and output projections */
329
+ #define CUDNN_ATTN_ENABLE_PROJ_BIASES (1U << 1) /* use biases in attention input and output projections */
330
+
331
+ struct cudnnAttnStruct;
332
+ typedef struct cudnnAttnStruct *cudnnAttnDescriptor_t CUDNN_DEPRECATED;
333
+
334
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
335
+ cudnnCreateAttnDescriptor(cudnnAttnDescriptor_t *attnDesc);
336
+
337
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
338
+ cudnnDestroyAttnDescriptor(cudnnAttnDescriptor_t attnDesc);
339
+
340
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
341
+ cudnnSetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
342
+ unsigned attnMode,
343
+ int nHeads,
344
+ double smScaler,
345
+ cudnnDataType_t dataType,
346
+ cudnnDataType_t computePrec,
347
+ cudnnMathType_t mathType,
348
+ cudnnDropoutDescriptor_t attnDropoutDesc,
349
+ cudnnDropoutDescriptor_t postDropoutDesc,
350
+ int qSize,
351
+ int kSize,
352
+ int vSize,
353
+ int qProjSize,
354
+ int kProjSize,
355
+ int vProjSize,
356
+ int oProjSize,
357
+ int qoMaxSeqLength,
358
+ int kvMaxSeqLength,
359
+ int maxBatchSize,
360
+ int maxBeamSize);
361
+
362
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
363
+ cudnnGetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
364
+ unsigned *attnMode,
365
+ int *nHeads,
366
+ double *smScaler,
367
+ cudnnDataType_t *dataType,
368
+ cudnnDataType_t *computePrec,
369
+ cudnnMathType_t *mathType,
370
+ cudnnDropoutDescriptor_t *attnDropoutDesc,
371
+ cudnnDropoutDescriptor_t *postDropoutDesc,
372
+ int *qSize,
373
+ int *kSize,
374
+ int *vSize,
375
+ int *qProjSize,
376
+ int *kProjSize,
377
+ int *vProjSize,
378
+ int *oProjSize,
379
+ int *qoMaxSeqLength,
380
+ int *kvMaxSeqLength,
381
+ int *maxBatchSize,
382
+ int *maxBeamSize);
383
+
384
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
385
+ cudnnGetMultiHeadAttnBuffers(cudnnHandle_t handle,
386
+ const cudnnAttnDescriptor_t attnDesc,
387
+ size_t *weightSizeInBytes,
388
+ size_t *workSpaceSizeInBytes,
389
+ size_t *reserveSpaceSizeInBytes);
390
+
391
+ typedef enum {
392
+ CUDNN_MH_ATTN_Q_WEIGHTS = 0, /* input projection weights for 'queries' */
393
+ CUDNN_MH_ATTN_K_WEIGHTS = 1, /* input projection weights for 'keys' */
394
+ CUDNN_MH_ATTN_V_WEIGHTS = 2, /* input projection weights for 'values' */
395
+ CUDNN_MH_ATTN_O_WEIGHTS = 3, /* output projection weights */
396
+ CUDNN_MH_ATTN_Q_BIASES = 4, /* input projection bias tensor for 'queries' */
397
+ CUDNN_MH_ATTN_K_BIASES = 5, /* input projection bias for 'keys' */
398
+ CUDNN_MH_ATTN_V_BIASES = 6, /* input projection bias for 'values' */
399
+ CUDNN_MH_ATTN_O_BIASES = 7, /* output projection biases */
400
+ } cudnnMultiHeadAttnWeightKind_t;
401
+
402
+ #define CUDNN_ATTN_WKIND_COUNT 8 /* Number of attention weight/bias tensors */
403
+
404
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
405
+ cudnnGetMultiHeadAttnWeights(cudnnHandle_t handle,
406
+ const cudnnAttnDescriptor_t attnDesc,
407
+ cudnnMultiHeadAttnWeightKind_t wKind,
408
+ size_t weightSizeInBytes,
409
+ const void *weights,
410
+ cudnnTensorDescriptor_t wDesc,
411
+ void **wAddr);
412
+
413
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
414
+ cudnnMultiHeadAttnForward(cudnnHandle_t handle,
415
+ const cudnnAttnDescriptor_t attnDesc,
416
+ int currIdx,
417
+ const int loWinIdx[],
418
+ const int hiWinIdx[],
419
+ const int devSeqLengthsQO[],
420
+ const int devSeqLengthsKV[],
421
+ const cudnnSeqDataDescriptor_t qDesc,
422
+ const void *queries,
423
+ const void *residuals,
424
+ const cudnnSeqDataDescriptor_t kDesc,
425
+ const void *keys,
426
+ const cudnnSeqDataDescriptor_t vDesc,
427
+ const void *values,
428
+ const cudnnSeqDataDescriptor_t oDesc,
429
+ void *out,
430
+ size_t weightSizeInBytes,
431
+ const void *weights,
432
+ size_t workSpaceSizeInBytes,
433
+ void *workSpace,
434
+ size_t reserveSpaceSizeInBytes,
435
+ void *reserveSpace);
436
+
437
+ /*
438
+ * \brief Cross-library version checker.
439
+ * This function is implemented differently in each sub-library. Each sublib
440
+ * checks whether its own version matches that of its dependencies.
441
+ * \returns CUDNN_STATUS_SUCCESS if the version check passes,
442
+ * CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH if the versions are inconsistent.
443
+ */
444
+ cudnnStatus_t CUDNNWINAPI
445
+ cudnnAdvVersionCheck(void);
446
+
447
+ typedef enum {
448
+ CUDNN_WGRAD_MODE_ADD = 0, /* add partial gradients to wgrad output buffers */
449
+ CUDNN_WGRAD_MODE_SET = 1, /* write partial gradients to wgrad output buffers */
450
+ } cudnnWgradMode_t;
451
+
452
+ cudnnStatus_t CUDNNWINAPI
453
+ cudnnRNNBackwardData_v8(cudnnHandle_t handle,
454
+ cudnnRNNDescriptor_t rnnDesc,
455
+ const int32_t devSeqLengths[],
456
+ cudnnRNNDataDescriptor_t yDesc,
457
+ const void *y,
458
+ const void *dy,
459
+ cudnnRNNDataDescriptor_t xDesc,
460
+ void *dx,
461
+ cudnnTensorDescriptor_t hDesc,
462
+ const void *hx,
463
+ const void *dhy,
464
+ void *dhx,
465
+ cudnnTensorDescriptor_t cDesc,
466
+ const void *cx,
467
+ const void *dcy,
468
+ void *dcx,
469
+ size_t weightSpaceSize,
470
+ const void *weightSpace,
471
+ size_t workSpaceSize,
472
+ void *workSpace,
473
+ size_t reserveSpaceSize,
474
+ void *reserveSpace);
475
+
476
+ cudnnStatus_t CUDNNWINAPI
477
+ cudnnRNNBackwardWeights_v8(cudnnHandle_t handle,
478
+ cudnnRNNDescriptor_t rnnDesc,
479
+ cudnnWgradMode_t addGrad,
480
+ const int32_t devSeqLengths[],
481
+ cudnnRNNDataDescriptor_t xDesc,
482
+ const void *x,
483
+ cudnnTensorDescriptor_t hDesc,
484
+ const void *hx,
485
+ cudnnRNNDataDescriptor_t yDesc,
486
+ const void *y,
487
+ size_t weightSpaceSize,
488
+ void *dweightSpace,
489
+ size_t workSpaceSize,
490
+ void *workSpace,
491
+ size_t reserveSpaceSize,
492
+ void *reserveSpace);
493
+
494
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
495
+ cudnnMultiHeadAttnBackwardData(cudnnHandle_t handle,
496
+ const cudnnAttnDescriptor_t attnDesc,
497
+ const int loWinIdx[],
498
+ const int hiWinIdx[],
499
+ const int devSeqLengthsDQDO[],
500
+ const int devSeqLengthsDKDV[],
501
+ const cudnnSeqDataDescriptor_t doDesc,
502
+ const void *dout,
503
+ const cudnnSeqDataDescriptor_t dqDesc,
504
+ void *dqueries,
505
+ const void *queries,
506
+ const cudnnSeqDataDescriptor_t dkDesc,
507
+ void *dkeys,
508
+ const void *keys,
509
+ const cudnnSeqDataDescriptor_t dvDesc,
510
+ void *dvalues,
511
+ const void *values,
512
+ size_t weightSizeInBytes,
513
+ const void *weights,
514
+ size_t workSpaceSizeInBytes,
515
+ void *workSpace,
516
+ size_t reserveSpaceSizeInBytes,
517
+ void *reserveSpace);
518
+
519
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
520
+ cudnnMultiHeadAttnBackwardWeights(cudnnHandle_t handle,
521
+ const cudnnAttnDescriptor_t attnDesc,
522
+ cudnnWgradMode_t addGrad,
523
+ const cudnnSeqDataDescriptor_t qDesc,
524
+ const void *queries,
525
+ const cudnnSeqDataDescriptor_t kDesc,
526
+ const void *keys,
527
+ const cudnnSeqDataDescriptor_t vDesc,
528
+ const void *values,
529
+ const cudnnSeqDataDescriptor_t doDesc,
530
+ const void *dout,
531
+ size_t weightSizeInBytes,
532
+ const void *weights,
533
+ void *dweights,
534
+ size_t workSpaceSizeInBytes,
535
+ void *workSpace,
536
+ size_t reserveSpaceSizeInBytes,
537
+ void *reserveSpace);
538
+
539
+ /*
540
+ * CTC (Connectionist Temporal Classification) loss descriptor create/destory/set/get functions
541
+ */
542
+ /* Input normalization mode for loss function */
543
+ typedef enum {
544
+ CUDNN_LOSS_NORMALIZATION_NONE = 0,
545
+ CUDNN_LOSS_NORMALIZATION_SOFTMAX = 1,
546
+ } cudnnLossNormalizationMode_t;
547
+
548
+ cudnnStatus_t CUDNNWINAPI
549
+ cudnnCreateCTCLossDescriptor(cudnnCTCLossDescriptor_t *ctcLossDesc);
550
+
551
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
552
+ cudnnSetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType);
553
+
554
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
555
+ cudnnSetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc,
556
+ cudnnDataType_t compType,
557
+ cudnnLossNormalizationMode_t normMode,
558
+ cudnnNanPropagation_t gradMode);
559
+
560
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
561
+ cudnnSetCTCLossDescriptor_v8(cudnnCTCLossDescriptor_t ctcLossDesc,
562
+ cudnnDataType_t compType,
563
+ cudnnLossNormalizationMode_t normMode,
564
+ cudnnNanPropagation_t gradMode,
565
+ int maxLabelLength);
566
+
567
+ cudnnStatus_t CUDNNWINAPI
568
+ cudnnSetCTCLossDescriptor_v9(cudnnCTCLossDescriptor_t ctcLossDesc,
569
+ cudnnDataType_t compType,
570
+ cudnnLossNormalizationMode_t normMode,
571
+ cudnnCTCGradMode_t ctcGradMode,
572
+ int maxLabelLength);
573
+
574
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
575
+ cudnnGetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType);
576
+
577
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
578
+ cudnnGetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc,
579
+ cudnnDataType_t *compType,
580
+ cudnnLossNormalizationMode_t *normMode,
581
+ cudnnNanPropagation_t *gradMode);
582
+
583
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
584
+ cudnnGetCTCLossDescriptor_v8(cudnnCTCLossDescriptor_t ctcLossDesc,
585
+ cudnnDataType_t *compType,
586
+ cudnnLossNormalizationMode_t *normMode,
587
+ cudnnNanPropagation_t *gradMode,
588
+ int *maxLabelLength);
589
+
590
+ cudnnStatus_t CUDNNWINAPI
591
+ cudnnGetCTCLossDescriptor_v9(cudnnCTCLossDescriptor_t ctcLossDesc,
592
+ cudnnDataType_t *compType,
593
+ cudnnLossNormalizationMode_t *normMode,
594
+ cudnnCTCGradMode_t *ctcGradMode,
595
+ int *maxLabelLength);
596
+
597
+ cudnnStatus_t CUDNNWINAPI
598
+ cudnnDestroyCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc);
599
+
600
+ /* return the ctc costs and gradients, given the probabilities and labels */
601
+ cudnnStatus_t CUDNNWINAPI
602
+ cudnnCTCLoss(
603
+ cudnnHandle_t handle,
604
+ const cudnnTensorDescriptor_t
605
+ probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the
606
+ mini batch size, A is the alphabet size) */
607
+ const void *probs, /* probabilities after softmax, in GPU memory */
608
+ const int hostLabels[], /* labels, in CPU memory */
609
+ const int hostLabelLengths[], /* the length of each label, in CPU memory */
610
+ const int hostInputLengths[], /* the lengths of timing steps in each batch, in CPU memory */
611
+ void *costs, /* the returned costs of CTC, in GPU memory */
612
+ const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */
613
+ void *gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */
614
+ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
615
+ cudnnCTCLossDescriptor_t ctcLossDesc,
616
+ void *workspace, /* pointer to the workspace, in GPU memory */
617
+ size_t workSpaceSizeInBytes); /* size of the workspace */
618
+
619
+ /* return the ctc costs and gradients, given the probabilities and labels */
620
+ cudnnStatus_t CUDNNWINAPI
621
+ cudnnCTCLoss_v8(
622
+ cudnnHandle_t handle,
623
+ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
624
+ cudnnCTCLossDescriptor_t ctcLossDesc,
625
+ const cudnnTensorDescriptor_t
626
+ probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the
627
+ mini batch size, A is the alphabet size) */
628
+ const void *probs, /* probabilities after softmax, in GPU memory */
629
+ const int labels[], /* labels, in GPU memory */
630
+ const int labelLengths[], /* the length of each label, in GPU memory */
631
+ const int inputLengths[], /* the lengths of timing steps in each batch, in GPU memory */
632
+ void *costs, /* the returned costs of CTC, in GPU memory */
633
+ const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */
634
+ void *gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */
635
+ size_t workSpaceSizeInBytes, /* size of the workspace */
636
+ void *workspace); /* pointer to the workspace, in GPU memory */
637
+
638
+ /* return the workspace size needed for ctc */
639
+ cudnnStatus_t CUDNNWINAPI
640
+ cudnnGetCTCLossWorkspaceSize(
641
+ cudnnHandle_t handle,
642
+ const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the
643
+ timing steps, N is the mini batch size, A is the alphabet size) */
644
+ const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the
645
+ dimensions are T,N,A. To compute costs
646
+ only, set it to NULL */
647
+ const int *labels, /* labels, in CPU memory */
648
+ const int *labelLengths, /* the length of each label, in CPU memory */
649
+ const int *inputLengths, /* the lengths of timing steps in each batch, in CPU memory */
650
+ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
651
+ cudnnCTCLossDescriptor_t ctcLossDesc,
652
+ size_t *sizeInBytes); /* pointer to the returned workspace size */
653
+
654
+ /* return the workspace size needed for ctc */
655
+ cudnnStatus_t CUDNNWINAPI
656
+ cudnnGetCTCLossWorkspaceSize_v8(
657
+ cudnnHandle_t handle,
658
+ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
659
+ cudnnCTCLossDescriptor_t ctcLossDesc,
660
+ const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the
661
+ timing steps, N is the mini batch size, A is the alphabet size) */
662
+ const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the
663
+ dimensions are T,N,A. To compute costs
664
+ only, set it to NULL */
665
+ size_t *sizeInBytes); /* pointer to the returned workspace size */
666
+
667
+ #if defined(__cplusplus)
668
+ }
669
+ #endif
670
+
671
+ #endif /* CUDNN_ADV_H_ */
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_backend.h ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ #ifndef _CUDNN_BACKEND_H_
51
+ #define _CUDNN_BACKEND_H_
52
+
53
+ /*
54
+ * The content of this header has been moved into cudnn_graph.h.
55
+ * This header is kept for the backward compatibility purpose.
56
+ */
57
+
58
+ #include "cudnn_graph.h"
59
+
60
+ #endif /* _CUDNN_BACKEND_H_ */
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn.h ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /*
51
+ * cudnn_cnn : cuDNN's basic definitions and CNN functions.
52
+ */
53
+
54
+ #if !defined(CUDNN_CNN_H_)
55
+ #define CUDNN_CNN_H_
56
+
57
+ #pragma once
58
+ #include <stdint.h>
59
+
60
+ #include "cudnn_version.h"
61
+ #include "cudnn_ops.h"
62
+
63
+ /* These version numbers are autogenerated, do not edit manually. */
64
+ #define CUDNN_CNN_MAJOR 9
65
+ #define CUDNN_CNN_MINOR 1
66
+ #define CUDNN_CNN_PATCH 0
67
+
68
+ #if (CUDNN_CNN_MAJOR != CUDNN_MAJOR) || (CUDNN_CNN_MINOR != CUDNN_MINOR) || (CUDNN_CNN_PATCH != CUDNN_PATCHLEVEL)
69
+ #error Version mismatch in cuDNN CNN INFER!!!
70
+ #endif
71
+
72
+ #if defined(__cplusplus)
73
+ extern "C" {
74
+ #endif
75
+
76
+ typedef struct cudnnConvolutionStruct *cudnnConvolutionDescriptor_t CUDNN_DEPRECATED;
77
+
78
+ typedef struct cudnnConvolutionFwdAlgoPerfStruct {
79
+ cudnnConvolutionFwdAlgo_t algo;
80
+ cudnnStatus_t status;
81
+ float time;
82
+ size_t memory;
83
+ cudnnDeterminism_t determinism;
84
+ cudnnMathType_t mathType;
85
+ int reserved[3];
86
+ } cudnnConvolutionFwdAlgoPerf_t CUDNN_DEPRECATED;
87
+
88
+ /* Create an instance of convolution descriptor */
89
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
90
+ cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t *convDesc);
91
+
92
+ /* Destroy an instance of convolution descriptor */
93
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
94
+ cudnnDestroyConvolutionDescriptor(cudnnConvolutionDescriptor_t convDesc);
95
+
96
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
97
+ cudnnSetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t mathType);
98
+
99
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
100
+ cudnnGetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t *mathType);
101
+
102
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
103
+ cudnnSetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int groupCount);
104
+
105
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
106
+ cudnnGetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int *groupCount);
107
+
108
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
109
+ cudnnSetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t reorderType);
110
+
111
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
112
+ cudnnGetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t *reorderType);
113
+
114
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
115
+ cudnnSetConvolution2dDescriptor(cudnnConvolutionDescriptor_t convDesc,
116
+ int pad_h, /* zero-padding height */
117
+ int pad_w, /* zero-padding width */
118
+ int u, /* vertical filter stride */
119
+ int v, /* horizontal filter stride */
120
+ int dilation_h, /* filter dilation in the vertical dimension */
121
+ int dilation_w, /* filter dilation in the horizontal dimension */
122
+ cudnnConvolutionMode_t mode,
123
+ cudnnDataType_t computeType);
124
+
125
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
126
+ cudnnGetConvolution2dDescriptor(const cudnnConvolutionDescriptor_t convDesc,
127
+ int *pad_h, /* zero-padding height */
128
+ int *pad_w, /* zero-padding width */
129
+ int *u, /* vertical filter stride */
130
+ int *v, /* horizontal filter stride */
131
+ int *dilation_h, /* filter dilation in the vertical dimension */
132
+ int *dilation_w, /* filter dilation in the horizontal dimension */
133
+ cudnnConvolutionMode_t *mode,
134
+ cudnnDataType_t *computeType);
135
+
136
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
137
+ cudnnSetConvolutionNdDescriptor(cudnnConvolutionDescriptor_t convDesc,
138
+ int arrayLength, /* nbDims-2 size */
139
+ const int padA[],
140
+ const int filterStrideA[],
141
+ const int dilationA[],
142
+ cudnnConvolutionMode_t mode,
143
+ cudnnDataType_t computeType); /* convolution data type */
144
+
145
+ /* Helper function to return the dimensions of the output tensor given a convolution descriptor */
146
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
147
+ cudnnGetConvolutionNdDescriptor(const cudnnConvolutionDescriptor_t convDesc,
148
+ int arrayLengthRequested,
149
+ int *arrayLength,
150
+ int padA[],
151
+ int strideA[],
152
+ int dilationA[],
153
+ cudnnConvolutionMode_t *mode,
154
+ cudnnDataType_t *computeType); /* convolution data type */
155
+
156
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
157
+ cudnnGetConvolution2dForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc,
158
+ const cudnnTensorDescriptor_t inputTensorDesc,
159
+ const cudnnFilterDescriptor_t filterDesc,
160
+ int *n,
161
+ int *c,
162
+ int *h,
163
+ int *w);
164
+
165
+ /* Helper function to return the dimensions of the output tensor given a convolution descriptor */
166
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
167
+ cudnnGetConvolutionNdForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc,
168
+ const cudnnTensorDescriptor_t inputTensorDesc,
169
+ const cudnnFilterDescriptor_t filterDesc,
170
+ int nbDims,
171
+ int tensorOuputDimA[]);
172
+
173
+ /* helper function to provide the convolution forward algo that fit best the requirement */
174
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
175
+ cudnnGetConvolutionForwardAlgorithmMaxCount(cudnnHandle_t handle, int *count);
176
+
177
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
178
+ cudnnGetConvolutionForwardAlgorithm_v7(cudnnHandle_t handle,
179
+ const cudnnTensorDescriptor_t srcDesc,
180
+ const cudnnFilterDescriptor_t filterDesc,
181
+ const cudnnConvolutionDescriptor_t convDesc,
182
+ const cudnnTensorDescriptor_t destDesc,
183
+ const int requestedAlgoCount,
184
+ int *returnedAlgoCount,
185
+ cudnnConvolutionFwdAlgoPerf_t *perfResults);
186
+
187
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
188
+ cudnnFindConvolutionForwardAlgorithm(cudnnHandle_t handle,
189
+ const cudnnTensorDescriptor_t xDesc,
190
+ const cudnnFilterDescriptor_t wDesc,
191
+ const cudnnConvolutionDescriptor_t convDesc,
192
+ const cudnnTensorDescriptor_t yDesc,
193
+ const int requestedAlgoCount,
194
+ int *returnedAlgoCount,
195
+ cudnnConvolutionFwdAlgoPerf_t *perfResults);
196
+
197
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
198
+ cudnnFindConvolutionForwardAlgorithmEx(cudnnHandle_t handle,
199
+ const cudnnTensorDescriptor_t xDesc,
200
+ const void *x,
201
+ const cudnnFilterDescriptor_t wDesc,
202
+ const void *w,
203
+ const cudnnConvolutionDescriptor_t convDesc,
204
+ const cudnnTensorDescriptor_t yDesc,
205
+ void *y,
206
+ const int requestedAlgoCount,
207
+ int *returnedAlgoCount,
208
+ cudnnConvolutionFwdAlgoPerf_t *perfResults,
209
+ void *workSpace,
210
+ size_t workSpaceSizeInBytes);
211
+
212
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
213
+ cudnnIm2Col(cudnnHandle_t handle,
214
+ const cudnnTensorDescriptor_t xDesc,
215
+ const void *x,
216
+ const cudnnFilterDescriptor_t wDesc,
217
+ const cudnnConvolutionDescriptor_t convDesc,
218
+ void *colBuffer);
219
+
220
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
221
+ cudnnReorderFilterAndBias(cudnnHandle_t handle,
222
+ const cudnnFilterDescriptor_t filterDesc,
223
+ cudnnReorderType_t reorderType,
224
+ const void *filterData,
225
+ void *reorderedFilterData,
226
+ int reorderBias,
227
+ const void *biasData,
228
+ void *reorderedBiasData);
229
+
230
+ /* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
231
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
232
+ cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle_t handle,
233
+ const cudnnTensorDescriptor_t xDesc,
234
+ const cudnnFilterDescriptor_t wDesc,
235
+ const cudnnConvolutionDescriptor_t convDesc,
236
+ const cudnnTensorDescriptor_t yDesc,
237
+ cudnnConvolutionFwdAlgo_t algo,
238
+ size_t *sizeInBytes);
239
+
240
+ /* Convolution functions: All of the form "output = alpha * Op(inputs) + beta * output" */
241
+
242
+ /* Function to perform the forward pass for batch convolution */
243
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
244
+ cudnnConvolutionForward(cudnnHandle_t handle,
245
+ const void *alpha,
246
+ const cudnnTensorDescriptor_t xDesc,
247
+ const void *x,
248
+ const cudnnFilterDescriptor_t wDesc,
249
+ const void *w,
250
+ const cudnnConvolutionDescriptor_t convDesc,
251
+ cudnnConvolutionFwdAlgo_t algo,
252
+ void *workSpace,
253
+ size_t workSpaceSizeInBytes,
254
+ const void *beta,
255
+ const cudnnTensorDescriptor_t yDesc,
256
+ void *y);
257
+
258
+ /* Fused conv/bias/activation operation : y = Act( alpha1 * conv(x) + alpha2 * z + bias ) */
259
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
260
+ cudnnConvolutionBiasActivationForward(cudnnHandle_t handle,
261
+ const void *alpha1,
262
+ const cudnnTensorDescriptor_t xDesc,
263
+ const void *x,
264
+ const cudnnFilterDescriptor_t wDesc,
265
+ const void *w,
266
+ const cudnnConvolutionDescriptor_t convDesc,
267
+ cudnnConvolutionFwdAlgo_t algo,
268
+ void *workSpace,
269
+ size_t workSpaceSizeInBytes,
270
+ const void *alpha2,
271
+ const cudnnTensorDescriptor_t zDesc,
272
+ const void *z,
273
+ const cudnnTensorDescriptor_t biasDesc,
274
+ const void *bias,
275
+ const cudnnActivationDescriptor_t activationDesc,
276
+ const cudnnTensorDescriptor_t yDesc,
277
+ void *y);
278
+
279
+ /* helper function to provide the convolution backward data algo that fit best the requirement */
280
+
281
+ typedef struct cudnnConvolutionBwdDataAlgoPerfStruct {
282
+ cudnnConvolutionBwdDataAlgo_t algo;
283
+ cudnnStatus_t status;
284
+ float time;
285
+ size_t memory;
286
+ cudnnDeterminism_t determinism;
287
+ cudnnMathType_t mathType;
288
+ int reserved[3];
289
+ } cudnnConvolutionBwdDataAlgoPerf_t CUDNN_DEPRECATED;
290
+
291
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
292
+ cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnnHandle_t handle, int *count);
293
+
294
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
295
+ cudnnFindConvolutionBackwardDataAlgorithm(cudnnHandle_t handle,
296
+ const cudnnFilterDescriptor_t wDesc,
297
+ const cudnnTensorDescriptor_t dyDesc,
298
+ const cudnnConvolutionDescriptor_t convDesc,
299
+ const cudnnTensorDescriptor_t dxDesc,
300
+ const int requestedAlgoCount,
301
+ int *returnedAlgoCount,
302
+ cudnnConvolutionBwdDataAlgoPerf_t *perfResults);
303
+
304
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
305
+ cudnnFindConvolutionBackwardDataAlgorithmEx(cudnnHandle_t handle,
306
+ const cudnnFilterDescriptor_t wDesc,
307
+ const void *w,
308
+ const cudnnTensorDescriptor_t dyDesc,
309
+ const void *dy,
310
+ const cudnnConvolutionDescriptor_t convDesc,
311
+ const cudnnTensorDescriptor_t dxDesc,
312
+ void *dx,
313
+ const int requestedAlgoCount,
314
+ int *returnedAlgoCount,
315
+ cudnnConvolutionBwdDataAlgoPerf_t *perfResults,
316
+ void *workSpace,
317
+ size_t workSpaceSizeInBytes);
318
+
319
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
320
+ cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnnHandle_t handle,
321
+ const cudnnFilterDescriptor_t filterDesc,
322
+ const cudnnTensorDescriptor_t diffDesc,
323
+ const cudnnConvolutionDescriptor_t convDesc,
324
+ const cudnnTensorDescriptor_t gradDesc,
325
+ const int requestedAlgoCount,
326
+ int *returnedAlgoCount,
327
+ cudnnConvolutionBwdDataAlgoPerf_t *perfResults);
328
+
329
+ /*
330
+ * convolution algorithm (which requires potentially some workspace)
331
+ */
332
+
333
+ /* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
334
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
335
+ cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnHandle_t handle,
336
+ const cudnnFilterDescriptor_t wDesc,
337
+ const cudnnTensorDescriptor_t dyDesc,
338
+ const cudnnConvolutionDescriptor_t convDesc,
339
+ const cudnnTensorDescriptor_t dxDesc,
340
+ cudnnConvolutionBwdDataAlgo_t algo,
341
+ size_t *sizeInBytes);
342
+
343
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
344
+ cudnnConvolutionBackwardData(cudnnHandle_t handle,
345
+ const void *alpha,
346
+ const cudnnFilterDescriptor_t wDesc,
347
+ const void *w,
348
+ const cudnnTensorDescriptor_t dyDesc,
349
+ const void *dy,
350
+ const cudnnConvolutionDescriptor_t convDesc,
351
+ cudnnConvolutionBwdDataAlgo_t algo,
352
+ void *workSpace,
353
+ size_t workSpaceSizeInBytes,
354
+ const void *beta,
355
+ const cudnnTensorDescriptor_t dxDesc,
356
+ void *dx);
357
+
358
+ /* Helper function to calculate folding descriptors for dgrad */
359
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
360
+ cudnnGetFoldedConvBackwardDataDescriptors(const cudnnHandle_t handle,
361
+ const cudnnFilterDescriptor_t filterDesc,
362
+ const cudnnTensorDescriptor_t diffDesc,
363
+ const cudnnConvolutionDescriptor_t convDesc,
364
+ const cudnnTensorDescriptor_t gradDesc,
365
+ const cudnnTensorFormat_t transformFormat,
366
+ cudnnFilterDescriptor_t foldedFilterDesc,
367
+ cudnnTensorDescriptor_t paddedDiffDesc,
368
+ cudnnConvolutionDescriptor_t foldedConvDesc,
369
+ cudnnTensorDescriptor_t foldedGradDesc,
370
+ cudnnTensorTransformDescriptor_t filterFoldTransDesc,
371
+ cudnnTensorTransformDescriptor_t diffPadTransDesc,
372
+ cudnnTensorTransformDescriptor_t gradFoldTransDesc,
373
+ cudnnTensorTransformDescriptor_t gradUnfoldTransDesc);
374
+
375
+ /* cudnnFusedOps... */
376
+ struct cudnnFusedOpsConstParamStruct;
377
+ typedef struct cudnnFusedOpsConstParamStruct *cudnnFusedOpsConstParamPack_t CUDNN_DEPRECATED;
378
+
379
+ struct cudnnFusedOpsVariantParamStruct;
380
+ typedef struct cudnnFusedOpsVariantParamStruct *cudnnFusedOpsVariantParamPack_t CUDNN_DEPRECATED;
381
+
382
+ struct cudnnFusedOpsPlanStruct;
383
+ typedef struct cudnnFusedOpsPlanStruct *cudnnFusedOpsPlan_t CUDNN_DEPRECATED;
384
+
385
+ typedef enum {
386
+ /* each op in [ ] can be disabled by passing NULL ptr */
387
+ /* [per channel scale], [per channel bias], [activation], convolution, [generate BN stats] */
388
+ CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS = 0,
389
+ /* [per channel scale], [per channel bias], [activation], convolutionBackwardWeights */
390
+ CUDNN_FUSED_SCALE_BIAS_ACTIVATION_WGRAD = 1,
391
+ /* utility for BN training in BN-conv fusion */
392
+ /* computes the equivalent scale and bias from ySum ySqSum and learned scale, bias */
393
+ /* optionally update running stats and generate saved stats */
394
+ CUDNN_FUSED_BN_FINALIZE_STATISTICS_TRAINING = 2,
395
+ /* utility for BN inference in BN-conv fusion */
396
+ /* computes the equivalent scale and bias from learned running stats and learned scale, bias */
397
+ CUDNN_FUSED_BN_FINALIZE_STATISTICS_INFERENCE = 3,
398
+ /* reserved for future use: convolution, [per channel scale], [per channel bias], [residual add], [activation] */
399
+ CUDNN_FUSED_CONV_SCALE_BIAS_ADD_ACTIVATION = 4,
400
+ /* reserved for future use: [per channel scale], [per channel bias], [residual add], activation, bitmask */
401
+ CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK = 5,
402
+ /* reserved for future use */
403
+ CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM = 6,
404
+ } cudnnFusedOps_t CUDNN_DEPRECATED;
405
+
406
+ typedef enum {
407
+ /* set XDESC: pass previously initialized cudnnTensorDescriptor_t */
408
+ /* get XDESC: pass previously created cudnnTensorDescriptor_t */
409
+ CUDNN_PARAM_XDESC = 0,
410
+ /* set/get XDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
411
+ CUDNN_PARAM_XDATA_PLACEHOLDER = 1,
412
+ /* set/get BN_MODE: pass cudnnBatchNormMode_t* */
413
+ CUDNN_PARAM_BN_MODE = 2,
414
+ /* set CUDNN_PARAM_BN_EQSCALEBIAS_DESC: pass previously initialized cudnnTensorDescriptor_t */
415
+ /* get CUDNN_PARAM_BN_EQSCALEBIAS_DESC: pass previously created cudnnTensorDescriptor_t */
416
+ CUDNN_PARAM_BN_EQSCALEBIAS_DESC = 3,
417
+ /* set/get BN_EQSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
418
+ CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER = 4,
419
+ /* set/get BN_EQBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
420
+ CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER = 5,
421
+ /* set ACTIVATION_DESC: pass previously initialized cudnnActivationDescriptor_t */
422
+ /* get ACTIVATION_DESC: pass previously created cudnnActivationDescriptor_t */
423
+ CUDNN_PARAM_ACTIVATION_DESC = 6,
424
+ /* set CONV_DESC: pass previously initialized cudnnConvolutionDescriptor_t */
425
+ /* get CONV_DESC: pass previously created cudnnConvolutionDescriptor_t */
426
+ CUDNN_PARAM_CONV_DESC = 7,
427
+ /* set WDESC: pass previously initialized cudnnFilterDescriptor_t */
428
+ /* get WDESC: pass previously created cudnnFilterDescriptor_t */
429
+ CUDNN_PARAM_WDESC = 8,
430
+ /* set/get WDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
431
+ CUDNN_PARAM_WDATA_PLACEHOLDER = 9,
432
+ /* set DWDESC: pass previously initialized cudnnFilterDescriptor_t */
433
+ /* get DWDESC: pass previously created cudnnFilterDescriptor_t */
434
+ CUDNN_PARAM_DWDESC = 10,
435
+ /* set/get DWDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
436
+ CUDNN_PARAM_DWDATA_PLACEHOLDER = 11,
437
+ /* set YDESC: pass previously initialized cudnnTensorDescriptor_t */
438
+ /* get YDESC: pass previously created cudnnTensorDescriptor_t */
439
+ CUDNN_PARAM_YDESC = 12,
440
+ /* set/get YDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
441
+ CUDNN_PARAM_YDATA_PLACEHOLDER = 13,
442
+ /* set DYDESC: pass previously initialized cudnnTensorDescriptor_t */
443
+ /* get DYDESC: pass previously created cudnnTensorDescriptor_t */
444
+ CUDNN_PARAM_DYDESC = 14,
445
+ /* set/get DYDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
446
+ CUDNN_PARAM_DYDATA_PLACEHOLDER = 15,
447
+ /* set YSTATS_DESC: pass previously initialized cudnnTensorDescriptor_t */
448
+ /* get YSTATS_DESC: pass previously created cudnnTensorDescriptor_t */
449
+ CUDNN_PARAM_YSTATS_DESC = 16,
450
+ /* set/get YSUM_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
451
+ CUDNN_PARAM_YSUM_PLACEHOLDER = 17,
452
+ /* set/get YSQSUM_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
453
+ CUDNN_PARAM_YSQSUM_PLACEHOLDER = 18,
454
+ /* set CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC: pass previously initialized cudnnTensorDescriptor_t */
455
+ /* get CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC: pass previously created cudnnTensorDescriptor_t */
456
+ CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC = 19,
457
+ /* set/get CUDNN_PARAM_BN_SCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
458
+ CUDNN_PARAM_BN_SCALE_PLACEHOLDER = 20,
459
+ /* set/get CUDNN_PARAM_BN_BIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
460
+ CUDNN_PARAM_BN_BIAS_PLACEHOLDER = 21,
461
+ /* set/get CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
462
+ CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER = 22,
463
+ /* set/get CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
464
+ CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER = 23,
465
+ /* set/get CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
466
+ CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER = 24,
467
+ /* set/get CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
468
+ CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER = 25,
469
+
470
+ /* set ZDESC: pass previously initialized cudnnTensorDescriptor_t */
471
+ /* get ZDESC: pass previously created cudnnTensorDescriptor_t */
472
+ CUDNN_PARAM_ZDESC = 26,
473
+ /* set/get ZDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
474
+ CUDNN_PARAM_ZDATA_PLACEHOLDER = 27,
475
+ /* set BN_Z_EQSCALEBIAS_DESC: pass previously initialized cudnnTensorDescriptor_t */
476
+ /* get BN_Z_EQSCALEBIAS_DESC: pass previously created cudnnTensorDescriptor_t */
477
+ CUDNN_PARAM_BN_Z_EQSCALEBIAS_DESC = 28,
478
+ /* set/get BN_Z_EQSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
479
+ CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER = 29,
480
+ /* set/get BN_Z_EQBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
481
+ CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER = 30,
482
+
483
+ /* set ACTIVATION_BITMASK_DESC: pass previously initialized cudnnTensorDescriptor_t */
484
+ /* get ACTIVATION_BITMASK_DESC: pass previously created cudnnTensorDescriptor_t */
485
+ CUDNN_PARAM_ACTIVATION_BITMASK_DESC = 31,
486
+ /* set/get ACTIVATION_BITMASK_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
487
+ CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER = 32,
488
+
489
+ /* set DXDESC: pass previously initialized cudnnTensorDescriptor_t */
490
+ /* get DXDESC: pass previously created cudnnTensorDescriptor_t */
491
+ CUDNN_PARAM_DXDESC = 33,
492
+ /* set/get DXDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
493
+ CUDNN_PARAM_DXDATA_PLACEHOLDER = 34,
494
+ /* set DZDESC: pass previously initialized cudnnTensorDescriptor_t */
495
+ /* get DZDESC: pass previously created cudnnTensorDescriptor_t */
496
+ CUDNN_PARAM_DZDESC = 35,
497
+ /* set/get DZDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
498
+ CUDNN_PARAM_DZDATA_PLACEHOLDER = 36,
499
+ /* set/get CUDNN_PARAM_BN_DSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
500
+ CUDNN_PARAM_BN_DSCALE_PLACEHOLDER = 37,
501
+ /* set/get CUDNN_PARAM_BN_DBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
502
+ CUDNN_PARAM_BN_DBIAS_PLACEHOLDER = 38,
503
+ } cudnnFusedOpsConstParamLabel_t CUDNN_DEPRECATED;
504
+
505
+ typedef enum {
506
+ CUDNN_PTR_NULL = 0,
507
+ CUDNN_PTR_ELEM_ALIGNED = 1,
508
+ CUDNN_PTR_16B_ALIGNED = 2,
509
+ } cudnnFusedOpsPointerPlaceHolder_t CUDNN_DEPRECATED;
510
+
511
+ typedef enum {
512
+ /* set: pass void* pointing to dev memory */
513
+ /* get: pass void** pointing to host memory */
514
+ CUDNN_PTR_XDATA = 0,
515
+ CUDNN_PTR_BN_EQSCALE = 1,
516
+ CUDNN_PTR_BN_EQBIAS = 2,
517
+ CUDNN_PTR_WDATA = 3,
518
+ CUDNN_PTR_DWDATA = 4,
519
+ CUDNN_PTR_YDATA = 5,
520
+ CUDNN_PTR_DYDATA = 6,
521
+ CUDNN_PTR_YSUM = 7,
522
+ CUDNN_PTR_YSQSUM = 8,
523
+ CUDNN_PTR_WORKSPACE = 9,
524
+ CUDNN_PTR_BN_SCALE = 10,
525
+ CUDNN_PTR_BN_BIAS = 11,
526
+ CUDNN_PTR_BN_SAVED_MEAN = 12,
527
+ CUDNN_PTR_BN_SAVED_INVSTD = 13,
528
+ CUDNN_PTR_BN_RUNNING_MEAN = 14,
529
+ CUDNN_PTR_BN_RUNNING_VAR = 15,
530
+ CUDNN_PTR_ZDATA = 16,
531
+ CUDNN_PTR_BN_Z_EQSCALE = 17,
532
+ CUDNN_PTR_BN_Z_EQBIAS = 18,
533
+ CUDNN_PTR_ACTIVATION_BITMASK = 19,
534
+ CUDNN_PTR_DXDATA = 20,
535
+ CUDNN_PTR_DZDATA = 21,
536
+ CUDNN_PTR_BN_DSCALE = 22,
537
+ CUDNN_PTR_BN_DBIAS = 23,
538
+
539
+ /* set/get: pass size_t* pointing to host memory */
540
+ CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES = 100,
541
+ /* set/get: pass int64_t* pointing to host memory */
542
+ CUDNN_SCALAR_INT64_T_BN_ACCUMULATION_COUNT = 101,
543
+ /* set/get: pass double* pointing to host memory */
544
+ CUDNN_SCALAR_DOUBLE_BN_EXP_AVG_FACTOR = 102,
545
+ /* set/get: pass double* pointing to host memory */
546
+ CUDNN_SCALAR_DOUBLE_BN_EPSILON = 103,
547
+ } cudnnFusedOpsVariantParamLabel_t CUDNN_DEPRECATED;
548
+
549
+ cudnnStatus_t CUDNNWINAPI
550
+ cudnnCnnVersionCheck(void);
551
+
552
+ /* helper function to provide the convolution backward filter algo that fit best the requirement */
553
+
554
+ typedef struct cudnnConvolutionBwdFilterAlgoPerfStruct {
555
+ cudnnConvolutionBwdFilterAlgo_t algo;
556
+ cudnnStatus_t status;
557
+ float time;
558
+ size_t memory;
559
+ cudnnDeterminism_t determinism;
560
+ cudnnMathType_t mathType;
561
+ int reserved[3];
562
+ } cudnnConvolutionBwdFilterAlgoPerf_t CUDNN_DEPRECATED;
563
+
564
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
565
+ cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnnHandle_t handle, int *count);
566
+
567
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
568
+ cudnnFindConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle,
569
+ const cudnnTensorDescriptor_t xDesc,
570
+ const cudnnTensorDescriptor_t dyDesc,
571
+ const cudnnConvolutionDescriptor_t convDesc,
572
+ const cudnnFilterDescriptor_t dwDesc,
573
+ const int requestedAlgoCount,
574
+ int *returnedAlgoCount,
575
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults);
576
+
577
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
578
+ cudnnFindConvolutionBackwardFilterAlgorithmEx(cudnnHandle_t handle,
579
+ const cudnnTensorDescriptor_t xDesc,
580
+ const void *x,
581
+ const cudnnTensorDescriptor_t dyDesc,
582
+ const void *y,
583
+ const cudnnConvolutionDescriptor_t convDesc,
584
+ const cudnnFilterDescriptor_t dwDesc,
585
+ void *dw,
586
+ const int requestedAlgoCount,
587
+ int *returnedAlgoCount,
588
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults,
589
+ void *workSpace,
590
+ size_t workSpaceSizeInBytes);
591
+
592
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
593
+ cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnnHandle_t handle,
594
+ const cudnnTensorDescriptor_t srcDesc,
595
+ const cudnnTensorDescriptor_t diffDesc,
596
+ const cudnnConvolutionDescriptor_t convDesc,
597
+ const cudnnFilterDescriptor_t gradDesc,
598
+ const int requestedAlgoCount,
599
+ int *returnedAlgoCount,
600
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults);
601
+
602
+ /*
603
+ * convolution algorithm (which requires potentially some workspace)
604
+ */
605
+
606
+ /* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
607
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
608
+ cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnHandle_t handle,
609
+ const cudnnTensorDescriptor_t xDesc,
610
+ const cudnnTensorDescriptor_t dyDesc,
611
+ const cudnnConvolutionDescriptor_t convDesc,
612
+ const cudnnFilterDescriptor_t gradDesc,
613
+ cudnnConvolutionBwdFilterAlgo_t algo,
614
+ size_t *sizeInBytes);
615
+
616
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
617
+ cudnnConvolutionBackwardFilter(cudnnHandle_t handle,
618
+ const void *alpha,
619
+ const cudnnTensorDescriptor_t xDesc,
620
+ const void *x,
621
+ const cudnnTensorDescriptor_t dyDesc,
622
+ const void *dy,
623
+ const cudnnConvolutionDescriptor_t convDesc,
624
+ cudnnConvolutionBwdFilterAlgo_t algo,
625
+ void *workSpace,
626
+ size_t workSpaceSizeInBytes,
627
+ const void *beta,
628
+ const cudnnFilterDescriptor_t dwDesc,
629
+ void *dw);
630
+
631
+ /* Function to compute the bias gradient for batch convolution */
632
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
633
+ cudnnConvolutionBackwardBias(cudnnHandle_t handle,
634
+ const void *alpha,
635
+ const cudnnTensorDescriptor_t dyDesc,
636
+ const void *dy,
637
+ const void *beta,
638
+ const cudnnTensorDescriptor_t dbDesc,
639
+ void *db);
640
+
641
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
642
+ cudnnCreateFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t *constPack, cudnnFusedOps_t ops);
643
+
644
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
645
+ cudnnDestroyFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t constPack);
646
+
647
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
648
+ cudnnSetFusedOpsConstParamPackAttribute(cudnnFusedOpsConstParamPack_t constPack,
649
+ cudnnFusedOpsConstParamLabel_t paramLabel,
650
+ const void *param);
651
+
652
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
653
+ cudnnGetFusedOpsConstParamPackAttribute(const cudnnFusedOpsConstParamPack_t constPack,
654
+ cudnnFusedOpsConstParamLabel_t paramLabel,
655
+ void *param,
656
+ int *isNULL);
657
+
658
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
659
+ cudnnCreateFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t *varPack, cudnnFusedOps_t ops);
660
+
661
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
662
+ cudnnDestroyFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t varPack);
663
+
664
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
665
+ cudnnSetFusedOpsVariantParamPackAttribute(cudnnFusedOpsVariantParamPack_t varPack,
666
+ cudnnFusedOpsVariantParamLabel_t paramLabel,
667
+ void *ptr);
668
+
669
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
670
+ cudnnGetFusedOpsVariantParamPackAttribute(const cudnnFusedOpsVariantParamPack_t varPack,
671
+ cudnnFusedOpsVariantParamLabel_t paramLabel,
672
+ void *ptr);
673
+
674
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
675
+ cudnnCreateFusedOpsPlan(cudnnFusedOpsPlan_t *plan, cudnnFusedOps_t ops);
676
+
677
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
678
+ cudnnDestroyFusedOpsPlan(cudnnFusedOpsPlan_t plan);
679
+
680
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
681
+ cudnnMakeFusedOpsPlan(cudnnHandle_t handle,
682
+ cudnnFusedOpsPlan_t plan,
683
+ const cudnnFusedOpsConstParamPack_t constPack,
684
+ size_t *workspaceSizeInBytes);
685
+
686
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
687
+ cudnnFusedOpsExecute(cudnnHandle_t handle, const cudnnFusedOpsPlan_t plan, cudnnFusedOpsVariantParamPack_t varPack);
688
+
689
+ #if defined(__cplusplus)
690
+ }
691
+ #endif
692
+
693
+ #endif /* CUDNN_CNN_H_ */
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn_v9.h ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /*
51
+ * cudnn_cnn : cuDNN's basic definitions and CNN functions.
52
+ */
53
+
54
+ #if !defined(CUDNN_CNN_H_)
55
+ #define CUDNN_CNN_H_
56
+
57
+ #pragma once
58
+ #include <stdint.h>
59
+
60
+ #include "cudnn_version.h"
61
+ #include "cudnn_ops.h"
62
+
63
+ /* These version numbers are autogenerated, do not edit manually. */
64
+ #define CUDNN_CNN_MAJOR 9
65
+ #define CUDNN_CNN_MINOR 1
66
+ #define CUDNN_CNN_PATCH 0
67
+
68
+ #if (CUDNN_CNN_MAJOR != CUDNN_MAJOR) || (CUDNN_CNN_MINOR != CUDNN_MINOR) || (CUDNN_CNN_PATCH != CUDNN_PATCHLEVEL)
69
+ #error Version mismatch in cuDNN CNN INFER!!!
70
+ #endif
71
+
72
+ #if defined(__cplusplus)
73
+ extern "C" {
74
+ #endif
75
+
76
+ typedef struct cudnnConvolutionStruct *cudnnConvolutionDescriptor_t CUDNN_DEPRECATED;
77
+
78
+ typedef struct cudnnConvolutionFwdAlgoPerfStruct {
79
+ cudnnConvolutionFwdAlgo_t algo;
80
+ cudnnStatus_t status;
81
+ float time;
82
+ size_t memory;
83
+ cudnnDeterminism_t determinism;
84
+ cudnnMathType_t mathType;
85
+ int reserved[3];
86
+ } cudnnConvolutionFwdAlgoPerf_t CUDNN_DEPRECATED;
87
+
88
+ /* Create an instance of convolution descriptor */
89
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
90
+ cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t *convDesc);
91
+
92
+ /* Destroy an instance of convolution descriptor */
93
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
94
+ cudnnDestroyConvolutionDescriptor(cudnnConvolutionDescriptor_t convDesc);
95
+
96
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
97
+ cudnnSetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t mathType);
98
+
99
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
100
+ cudnnGetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t *mathType);
101
+
102
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
103
+ cudnnSetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int groupCount);
104
+
105
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
106
+ cudnnGetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int *groupCount);
107
+
108
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
109
+ cudnnSetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t reorderType);
110
+
111
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
112
+ cudnnGetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t *reorderType);
113
+
114
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
115
+ cudnnSetConvolution2dDescriptor(cudnnConvolutionDescriptor_t convDesc,
116
+ int pad_h, /* zero-padding height */
117
+ int pad_w, /* zero-padding width */
118
+ int u, /* vertical filter stride */
119
+ int v, /* horizontal filter stride */
120
+ int dilation_h, /* filter dilation in the vertical dimension */
121
+ int dilation_w, /* filter dilation in the horizontal dimension */
122
+ cudnnConvolutionMode_t mode,
123
+ cudnnDataType_t computeType);
124
+
125
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
126
+ cudnnGetConvolution2dDescriptor(const cudnnConvolutionDescriptor_t convDesc,
127
+ int *pad_h, /* zero-padding height */
128
+ int *pad_w, /* zero-padding width */
129
+ int *u, /* vertical filter stride */
130
+ int *v, /* horizontal filter stride */
131
+ int *dilation_h, /* filter dilation in the vertical dimension */
132
+ int *dilation_w, /* filter dilation in the horizontal dimension */
133
+ cudnnConvolutionMode_t *mode,
134
+ cudnnDataType_t *computeType);
135
+
136
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
137
+ cudnnSetConvolutionNdDescriptor(cudnnConvolutionDescriptor_t convDesc,
138
+ int arrayLength, /* nbDims-2 size */
139
+ const int padA[],
140
+ const int filterStrideA[],
141
+ const int dilationA[],
142
+ cudnnConvolutionMode_t mode,
143
+ cudnnDataType_t computeType); /* convolution data type */
144
+
145
+ /* Helper function to return the dimensions of the output tensor given a convolution descriptor */
146
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
147
+ cudnnGetConvolutionNdDescriptor(const cudnnConvolutionDescriptor_t convDesc,
148
+ int arrayLengthRequested,
149
+ int *arrayLength,
150
+ int padA[],
151
+ int strideA[],
152
+ int dilationA[],
153
+ cudnnConvolutionMode_t *mode,
154
+ cudnnDataType_t *computeType); /* convolution data type */
155
+
156
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
157
+ cudnnGetConvolution2dForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc,
158
+ const cudnnTensorDescriptor_t inputTensorDesc,
159
+ const cudnnFilterDescriptor_t filterDesc,
160
+ int *n,
161
+ int *c,
162
+ int *h,
163
+ int *w);
164
+
165
+ /* Helper function to return the dimensions of the output tensor given a convolution descriptor */
166
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
167
+ cudnnGetConvolutionNdForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc,
168
+ const cudnnTensorDescriptor_t inputTensorDesc,
169
+ const cudnnFilterDescriptor_t filterDesc,
170
+ int nbDims,
171
+ int tensorOuputDimA[]);
172
+
173
+ /* helper function to provide the convolution forward algo that fit best the requirement */
174
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
175
+ cudnnGetConvolutionForwardAlgorithmMaxCount(cudnnHandle_t handle, int *count);
176
+
177
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
178
+ cudnnGetConvolutionForwardAlgorithm_v7(cudnnHandle_t handle,
179
+ const cudnnTensorDescriptor_t srcDesc,
180
+ const cudnnFilterDescriptor_t filterDesc,
181
+ const cudnnConvolutionDescriptor_t convDesc,
182
+ const cudnnTensorDescriptor_t destDesc,
183
+ const int requestedAlgoCount,
184
+ int *returnedAlgoCount,
185
+ cudnnConvolutionFwdAlgoPerf_t *perfResults);
186
+
187
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
188
+ cudnnFindConvolutionForwardAlgorithm(cudnnHandle_t handle,
189
+ const cudnnTensorDescriptor_t xDesc,
190
+ const cudnnFilterDescriptor_t wDesc,
191
+ const cudnnConvolutionDescriptor_t convDesc,
192
+ const cudnnTensorDescriptor_t yDesc,
193
+ const int requestedAlgoCount,
194
+ int *returnedAlgoCount,
195
+ cudnnConvolutionFwdAlgoPerf_t *perfResults);
196
+
197
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
198
+ cudnnFindConvolutionForwardAlgorithmEx(cudnnHandle_t handle,
199
+ const cudnnTensorDescriptor_t xDesc,
200
+ const void *x,
201
+ const cudnnFilterDescriptor_t wDesc,
202
+ const void *w,
203
+ const cudnnConvolutionDescriptor_t convDesc,
204
+ const cudnnTensorDescriptor_t yDesc,
205
+ void *y,
206
+ const int requestedAlgoCount,
207
+ int *returnedAlgoCount,
208
+ cudnnConvolutionFwdAlgoPerf_t *perfResults,
209
+ void *workSpace,
210
+ size_t workSpaceSizeInBytes);
211
+
212
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
213
+ cudnnIm2Col(cudnnHandle_t handle,
214
+ const cudnnTensorDescriptor_t xDesc,
215
+ const void *x,
216
+ const cudnnFilterDescriptor_t wDesc,
217
+ const cudnnConvolutionDescriptor_t convDesc,
218
+ void *colBuffer);
219
+
220
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
221
+ cudnnReorderFilterAndBias(cudnnHandle_t handle,
222
+ const cudnnFilterDescriptor_t filterDesc,
223
+ cudnnReorderType_t reorderType,
224
+ const void *filterData,
225
+ void *reorderedFilterData,
226
+ int reorderBias,
227
+ const void *biasData,
228
+ void *reorderedBiasData);
229
+
230
+ /* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
231
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
232
+ cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle_t handle,
233
+ const cudnnTensorDescriptor_t xDesc,
234
+ const cudnnFilterDescriptor_t wDesc,
235
+ const cudnnConvolutionDescriptor_t convDesc,
236
+ const cudnnTensorDescriptor_t yDesc,
237
+ cudnnConvolutionFwdAlgo_t algo,
238
+ size_t *sizeInBytes);
239
+
240
+ /* Convolution functions: All of the form "output = alpha * Op(inputs) + beta * output" */
241
+
242
+ /* Function to perform the forward pass for batch convolution */
243
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
244
+ cudnnConvolutionForward(cudnnHandle_t handle,
245
+ const void *alpha,
246
+ const cudnnTensorDescriptor_t xDesc,
247
+ const void *x,
248
+ const cudnnFilterDescriptor_t wDesc,
249
+ const void *w,
250
+ const cudnnConvolutionDescriptor_t convDesc,
251
+ cudnnConvolutionFwdAlgo_t algo,
252
+ void *workSpace,
253
+ size_t workSpaceSizeInBytes,
254
+ const void *beta,
255
+ const cudnnTensorDescriptor_t yDesc,
256
+ void *y);
257
+
258
+ /* Fused conv/bias/activation operation : y = Act( alpha1 * conv(x) + alpha2 * z + bias ) */
259
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
260
+ cudnnConvolutionBiasActivationForward(cudnnHandle_t handle,
261
+ const void *alpha1,
262
+ const cudnnTensorDescriptor_t xDesc,
263
+ const void *x,
264
+ const cudnnFilterDescriptor_t wDesc,
265
+ const void *w,
266
+ const cudnnConvolutionDescriptor_t convDesc,
267
+ cudnnConvolutionFwdAlgo_t algo,
268
+ void *workSpace,
269
+ size_t workSpaceSizeInBytes,
270
+ const void *alpha2,
271
+ const cudnnTensorDescriptor_t zDesc,
272
+ const void *z,
273
+ const cudnnTensorDescriptor_t biasDesc,
274
+ const void *bias,
275
+ const cudnnActivationDescriptor_t activationDesc,
276
+ const cudnnTensorDescriptor_t yDesc,
277
+ void *y);
278
+
279
+ /* helper function to provide the convolution backward data algo that fit best the requirement */
280
+
281
+ typedef struct cudnnConvolutionBwdDataAlgoPerfStruct {
282
+ cudnnConvolutionBwdDataAlgo_t algo;
283
+ cudnnStatus_t status;
284
+ float time;
285
+ size_t memory;
286
+ cudnnDeterminism_t determinism;
287
+ cudnnMathType_t mathType;
288
+ int reserved[3];
289
+ } cudnnConvolutionBwdDataAlgoPerf_t CUDNN_DEPRECATED;
290
+
291
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
292
+ cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnnHandle_t handle, int *count);
293
+
294
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
295
+ cudnnFindConvolutionBackwardDataAlgorithm(cudnnHandle_t handle,
296
+ const cudnnFilterDescriptor_t wDesc,
297
+ const cudnnTensorDescriptor_t dyDesc,
298
+ const cudnnConvolutionDescriptor_t convDesc,
299
+ const cudnnTensorDescriptor_t dxDesc,
300
+ const int requestedAlgoCount,
301
+ int *returnedAlgoCount,
302
+ cudnnConvolutionBwdDataAlgoPerf_t *perfResults);
303
+
304
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
305
+ cudnnFindConvolutionBackwardDataAlgorithmEx(cudnnHandle_t handle,
306
+ const cudnnFilterDescriptor_t wDesc,
307
+ const void *w,
308
+ const cudnnTensorDescriptor_t dyDesc,
309
+ const void *dy,
310
+ const cudnnConvolutionDescriptor_t convDesc,
311
+ const cudnnTensorDescriptor_t dxDesc,
312
+ void *dx,
313
+ const int requestedAlgoCount,
314
+ int *returnedAlgoCount,
315
+ cudnnConvolutionBwdDataAlgoPerf_t *perfResults,
316
+ void *workSpace,
317
+ size_t workSpaceSizeInBytes);
318
+
319
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
320
+ cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnnHandle_t handle,
321
+ const cudnnFilterDescriptor_t filterDesc,
322
+ const cudnnTensorDescriptor_t diffDesc,
323
+ const cudnnConvolutionDescriptor_t convDesc,
324
+ const cudnnTensorDescriptor_t gradDesc,
325
+ const int requestedAlgoCount,
326
+ int *returnedAlgoCount,
327
+ cudnnConvolutionBwdDataAlgoPerf_t *perfResults);
328
+
329
+ /*
330
+ * convolution algorithm (which requires potentially some workspace)
331
+ */
332
+
333
+ /* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
334
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
335
+ cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnHandle_t handle,
336
+ const cudnnFilterDescriptor_t wDesc,
337
+ const cudnnTensorDescriptor_t dyDesc,
338
+ const cudnnConvolutionDescriptor_t convDesc,
339
+ const cudnnTensorDescriptor_t dxDesc,
340
+ cudnnConvolutionBwdDataAlgo_t algo,
341
+ size_t *sizeInBytes);
342
+
343
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
344
+ cudnnConvolutionBackwardData(cudnnHandle_t handle,
345
+ const void *alpha,
346
+ const cudnnFilterDescriptor_t wDesc,
347
+ const void *w,
348
+ const cudnnTensorDescriptor_t dyDesc,
349
+ const void *dy,
350
+ const cudnnConvolutionDescriptor_t convDesc,
351
+ cudnnConvolutionBwdDataAlgo_t algo,
352
+ void *workSpace,
353
+ size_t workSpaceSizeInBytes,
354
+ const void *beta,
355
+ const cudnnTensorDescriptor_t dxDesc,
356
+ void *dx);
357
+
358
+ /* Helper function to calculate folding descriptors for dgrad */
359
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
360
+ cudnnGetFoldedConvBackwardDataDescriptors(const cudnnHandle_t handle,
361
+ const cudnnFilterDescriptor_t filterDesc,
362
+ const cudnnTensorDescriptor_t diffDesc,
363
+ const cudnnConvolutionDescriptor_t convDesc,
364
+ const cudnnTensorDescriptor_t gradDesc,
365
+ const cudnnTensorFormat_t transformFormat,
366
+ cudnnFilterDescriptor_t foldedFilterDesc,
367
+ cudnnTensorDescriptor_t paddedDiffDesc,
368
+ cudnnConvolutionDescriptor_t foldedConvDesc,
369
+ cudnnTensorDescriptor_t foldedGradDesc,
370
+ cudnnTensorTransformDescriptor_t filterFoldTransDesc,
371
+ cudnnTensorTransformDescriptor_t diffPadTransDesc,
372
+ cudnnTensorTransformDescriptor_t gradFoldTransDesc,
373
+ cudnnTensorTransformDescriptor_t gradUnfoldTransDesc);
374
+
375
+ /* cudnnFusedOps... */
376
+ struct cudnnFusedOpsConstParamStruct;
377
+ typedef struct cudnnFusedOpsConstParamStruct *cudnnFusedOpsConstParamPack_t CUDNN_DEPRECATED;
378
+
379
+ struct cudnnFusedOpsVariantParamStruct;
380
+ typedef struct cudnnFusedOpsVariantParamStruct *cudnnFusedOpsVariantParamPack_t CUDNN_DEPRECATED;
381
+
382
+ struct cudnnFusedOpsPlanStruct;
383
+ typedef struct cudnnFusedOpsPlanStruct *cudnnFusedOpsPlan_t CUDNN_DEPRECATED;
384
+
385
+ typedef enum {
386
+ /* each op in [ ] can be disabled by passing NULL ptr */
387
+ /* [per channel scale], [per channel bias], [activation], convolution, [generate BN stats] */
388
+ CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS = 0,
389
+ /* [per channel scale], [per channel bias], [activation], convolutionBackwardWeights */
390
+ CUDNN_FUSED_SCALE_BIAS_ACTIVATION_WGRAD = 1,
391
+ /* utility for BN training in BN-conv fusion */
392
+ /* computes the equivalent scale and bias from ySum ySqSum and learned scale, bias */
393
+ /* optionally update running stats and generate saved stats */
394
+ CUDNN_FUSED_BN_FINALIZE_STATISTICS_TRAINING = 2,
395
+ /* utility for BN inference in BN-conv fusion */
396
+ /* computes the equivalent scale and bias from learned running stats and learned scale, bias */
397
+ CUDNN_FUSED_BN_FINALIZE_STATISTICS_INFERENCE = 3,
398
+ /* reserved for future use: convolution, [per channel scale], [per channel bias], [residual add], [activation] */
399
+ CUDNN_FUSED_CONV_SCALE_BIAS_ADD_ACTIVATION = 4,
400
+ /* reserved for future use: [per channel scale], [per channel bias], [residual add], activation, bitmask */
401
+ CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK = 5,
402
+ /* reserved for future use */
403
+ CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM = 6,
404
+ } cudnnFusedOps_t CUDNN_DEPRECATED;
405
+
406
+ typedef enum {
407
+ /* set XDESC: pass previously initialized cudnnTensorDescriptor_t */
408
+ /* get XDESC: pass previously created cudnnTensorDescriptor_t */
409
+ CUDNN_PARAM_XDESC = 0,
410
+ /* set/get XDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
411
+ CUDNN_PARAM_XDATA_PLACEHOLDER = 1,
412
+ /* set/get BN_MODE: pass cudnnBatchNormMode_t* */
413
+ CUDNN_PARAM_BN_MODE = 2,
414
+ /* set CUDNN_PARAM_BN_EQSCALEBIAS_DESC: pass previously initialized cudnnTensorDescriptor_t */
415
+ /* get CUDNN_PARAM_BN_EQSCALEBIAS_DESC: pass previously created cudnnTensorDescriptor_t */
416
+ CUDNN_PARAM_BN_EQSCALEBIAS_DESC = 3,
417
+ /* set/get BN_EQSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
418
+ CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER = 4,
419
+ /* set/get BN_EQBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
420
+ CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER = 5,
421
+ /* set ACTIVATION_DESC: pass previously initialized cudnnActivationDescriptor_t */
422
+ /* get ACTIVATION_DESC: pass previously created cudnnActivationDescriptor_t */
423
+ CUDNN_PARAM_ACTIVATION_DESC = 6,
424
+ /* set CONV_DESC: pass previously initialized cudnnConvolutionDescriptor_t */
425
+ /* get CONV_DESC: pass previously created cudnnConvolutionDescriptor_t */
426
+ CUDNN_PARAM_CONV_DESC = 7,
427
+ /* set WDESC: pass previously initialized cudnnFilterDescriptor_t */
428
+ /* get WDESC: pass previously created cudnnFilterDescriptor_t */
429
+ CUDNN_PARAM_WDESC = 8,
430
+ /* set/get WDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
431
+ CUDNN_PARAM_WDATA_PLACEHOLDER = 9,
432
+ /* set DWDESC: pass previously initialized cudnnFilterDescriptor_t */
433
+ /* get DWDESC: pass previously created cudnnFilterDescriptor_t */
434
+ CUDNN_PARAM_DWDESC = 10,
435
+ /* set/get DWDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
436
+ CUDNN_PARAM_DWDATA_PLACEHOLDER = 11,
437
+ /* set YDESC: pass previously initialized cudnnTensorDescriptor_t */
438
+ /* get YDESC: pass previously created cudnnTensorDescriptor_t */
439
+ CUDNN_PARAM_YDESC = 12,
440
+ /* set/get YDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
441
+ CUDNN_PARAM_YDATA_PLACEHOLDER = 13,
442
+ /* set DYDESC: pass previously initialized cudnnTensorDescriptor_t */
443
+ /* get DYDESC: pass previously created cudnnTensorDescriptor_t */
444
+ CUDNN_PARAM_DYDESC = 14,
445
+ /* set/get DYDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
446
+ CUDNN_PARAM_DYDATA_PLACEHOLDER = 15,
447
+ /* set YSTATS_DESC: pass previously initialized cudnnTensorDescriptor_t */
448
+ /* get YSTATS_DESC: pass previously created cudnnTensorDescriptor_t */
449
+ CUDNN_PARAM_YSTATS_DESC = 16,
450
+ /* set/get YSUM_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
451
+ CUDNN_PARAM_YSUM_PLACEHOLDER = 17,
452
+ /* set/get YSQSUM_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
453
+ CUDNN_PARAM_YSQSUM_PLACEHOLDER = 18,
454
+ /* set CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC: pass previously initialized cudnnTensorDescriptor_t */
455
+ /* get CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC: pass previously created cudnnTensorDescriptor_t */
456
+ CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC = 19,
457
+ /* set/get CUDNN_PARAM_BN_SCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
458
+ CUDNN_PARAM_BN_SCALE_PLACEHOLDER = 20,
459
+ /* set/get CUDNN_PARAM_BN_BIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
460
+ CUDNN_PARAM_BN_BIAS_PLACEHOLDER = 21,
461
+ /* set/get CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
462
+ CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER = 22,
463
+ /* set/get CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
464
+ CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER = 23,
465
+ /* set/get CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
466
+ CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER = 24,
467
+ /* set/get CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
468
+ CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER = 25,
469
+
470
+ /* set ZDESC: pass previously initialized cudnnTensorDescriptor_t */
471
+ /* get ZDESC: pass previously created cudnnTensorDescriptor_t */
472
+ CUDNN_PARAM_ZDESC = 26,
473
+ /* set/get ZDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
474
+ CUDNN_PARAM_ZDATA_PLACEHOLDER = 27,
475
+ /* set BN_Z_EQSCALEBIAS_DESC: pass previously initialized cudnnTensorDescriptor_t */
476
+ /* get BN_Z_EQSCALEBIAS_DESC: pass previously created cudnnTensorDescriptor_t */
477
+ CUDNN_PARAM_BN_Z_EQSCALEBIAS_DESC = 28,
478
+ /* set/get BN_Z_EQSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
479
+ CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER = 29,
480
+ /* set/get BN_Z_EQBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
481
+ CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER = 30,
482
+
483
+ /* set ACTIVATION_BITMASK_DESC: pass previously initialized cudnnTensorDescriptor_t */
484
+ /* get ACTIVATION_BITMASK_DESC: pass previously created cudnnTensorDescriptor_t */
485
+ CUDNN_PARAM_ACTIVATION_BITMASK_DESC = 31,
486
+ /* set/get ACTIVATION_BITMASK_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
487
+ CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER = 32,
488
+
489
+ /* set DXDESC: pass previously initialized cudnnTensorDescriptor_t */
490
+ /* get DXDESC: pass previously created cudnnTensorDescriptor_t */
491
+ CUDNN_PARAM_DXDESC = 33,
492
+ /* set/get DXDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
493
+ CUDNN_PARAM_DXDATA_PLACEHOLDER = 34,
494
+ /* set DZDESC: pass previously initialized cudnnTensorDescriptor_t */
495
+ /* get DZDESC: pass previously created cudnnTensorDescriptor_t */
496
+ CUDNN_PARAM_DZDESC = 35,
497
+ /* set/get DZDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
498
+ CUDNN_PARAM_DZDATA_PLACEHOLDER = 36,
499
+ /* set/get CUDNN_PARAM_BN_DSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
500
+ CUDNN_PARAM_BN_DSCALE_PLACEHOLDER = 37,
501
+ /* set/get CUDNN_PARAM_BN_DBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
502
+ CUDNN_PARAM_BN_DBIAS_PLACEHOLDER = 38,
503
+ } cudnnFusedOpsConstParamLabel_t CUDNN_DEPRECATED;
504
+
505
+ typedef enum {
506
+ CUDNN_PTR_NULL = 0,
507
+ CUDNN_PTR_ELEM_ALIGNED = 1,
508
+ CUDNN_PTR_16B_ALIGNED = 2,
509
+ } cudnnFusedOpsPointerPlaceHolder_t CUDNN_DEPRECATED;
510
+
511
+ typedef enum {
512
+ /* set: pass void* pointing to dev memory */
513
+ /* get: pass void** pointing to host memory */
514
+ CUDNN_PTR_XDATA = 0,
515
+ CUDNN_PTR_BN_EQSCALE = 1,
516
+ CUDNN_PTR_BN_EQBIAS = 2,
517
+ CUDNN_PTR_WDATA = 3,
518
+ CUDNN_PTR_DWDATA = 4,
519
+ CUDNN_PTR_YDATA = 5,
520
+ CUDNN_PTR_DYDATA = 6,
521
+ CUDNN_PTR_YSUM = 7,
522
+ CUDNN_PTR_YSQSUM = 8,
523
+ CUDNN_PTR_WORKSPACE = 9,
524
+ CUDNN_PTR_BN_SCALE = 10,
525
+ CUDNN_PTR_BN_BIAS = 11,
526
+ CUDNN_PTR_BN_SAVED_MEAN = 12,
527
+ CUDNN_PTR_BN_SAVED_INVSTD = 13,
528
+ CUDNN_PTR_BN_RUNNING_MEAN = 14,
529
+ CUDNN_PTR_BN_RUNNING_VAR = 15,
530
+ CUDNN_PTR_ZDATA = 16,
531
+ CUDNN_PTR_BN_Z_EQSCALE = 17,
532
+ CUDNN_PTR_BN_Z_EQBIAS = 18,
533
+ CUDNN_PTR_ACTIVATION_BITMASK = 19,
534
+ CUDNN_PTR_DXDATA = 20,
535
+ CUDNN_PTR_DZDATA = 21,
536
+ CUDNN_PTR_BN_DSCALE = 22,
537
+ CUDNN_PTR_BN_DBIAS = 23,
538
+
539
+ /* set/get: pass size_t* pointing to host memory */
540
+ CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES = 100,
541
+ /* set/get: pass int64_t* pointing to host memory */
542
+ CUDNN_SCALAR_INT64_T_BN_ACCUMULATION_COUNT = 101,
543
+ /* set/get: pass double* pointing to host memory */
544
+ CUDNN_SCALAR_DOUBLE_BN_EXP_AVG_FACTOR = 102,
545
+ /* set/get: pass double* pointing to host memory */
546
+ CUDNN_SCALAR_DOUBLE_BN_EPSILON = 103,
547
+ } cudnnFusedOpsVariantParamLabel_t CUDNN_DEPRECATED;
548
+
549
+ cudnnStatus_t CUDNNWINAPI
550
+ cudnnCnnVersionCheck(void);
551
+
552
+ /* helper function to provide the convolution backward filter algo that fit best the requirement */
553
+
554
+ typedef struct cudnnConvolutionBwdFilterAlgoPerfStruct {
555
+ cudnnConvolutionBwdFilterAlgo_t algo;
556
+ cudnnStatus_t status;
557
+ float time;
558
+ size_t memory;
559
+ cudnnDeterminism_t determinism;
560
+ cudnnMathType_t mathType;
561
+ int reserved[3];
562
+ } cudnnConvolutionBwdFilterAlgoPerf_t CUDNN_DEPRECATED;
563
+
564
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
565
+ cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnnHandle_t handle, int *count);
566
+
567
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
568
+ cudnnFindConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle,
569
+ const cudnnTensorDescriptor_t xDesc,
570
+ const cudnnTensorDescriptor_t dyDesc,
571
+ const cudnnConvolutionDescriptor_t convDesc,
572
+ const cudnnFilterDescriptor_t dwDesc,
573
+ const int requestedAlgoCount,
574
+ int *returnedAlgoCount,
575
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults);
576
+
577
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
578
+ cudnnFindConvolutionBackwardFilterAlgorithmEx(cudnnHandle_t handle,
579
+ const cudnnTensorDescriptor_t xDesc,
580
+ const void *x,
581
+ const cudnnTensorDescriptor_t dyDesc,
582
+ const void *y,
583
+ const cudnnConvolutionDescriptor_t convDesc,
584
+ const cudnnFilterDescriptor_t dwDesc,
585
+ void *dw,
586
+ const int requestedAlgoCount,
587
+ int *returnedAlgoCount,
588
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults,
589
+ void *workSpace,
590
+ size_t workSpaceSizeInBytes);
591
+
592
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
593
+ cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnnHandle_t handle,
594
+ const cudnnTensorDescriptor_t srcDesc,
595
+ const cudnnTensorDescriptor_t diffDesc,
596
+ const cudnnConvolutionDescriptor_t convDesc,
597
+ const cudnnFilterDescriptor_t gradDesc,
598
+ const int requestedAlgoCount,
599
+ int *returnedAlgoCount,
600
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults);
601
+
602
+ /*
603
+ * convolution algorithm (which requires potentially some workspace)
604
+ */
605
+
606
+ /* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
607
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
608
+ cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnHandle_t handle,
609
+ const cudnnTensorDescriptor_t xDesc,
610
+ const cudnnTensorDescriptor_t dyDesc,
611
+ const cudnnConvolutionDescriptor_t convDesc,
612
+ const cudnnFilterDescriptor_t gradDesc,
613
+ cudnnConvolutionBwdFilterAlgo_t algo,
614
+ size_t *sizeInBytes);
615
+
616
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
617
+ cudnnConvolutionBackwardFilter(cudnnHandle_t handle,
618
+ const void *alpha,
619
+ const cudnnTensorDescriptor_t xDesc,
620
+ const void *x,
621
+ const cudnnTensorDescriptor_t dyDesc,
622
+ const void *dy,
623
+ const cudnnConvolutionDescriptor_t convDesc,
624
+ cudnnConvolutionBwdFilterAlgo_t algo,
625
+ void *workSpace,
626
+ size_t workSpaceSizeInBytes,
627
+ const void *beta,
628
+ const cudnnFilterDescriptor_t dwDesc,
629
+ void *dw);
630
+
631
+ /* Function to compute the bias gradient for batch convolution */
632
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
633
+ cudnnConvolutionBackwardBias(cudnnHandle_t handle,
634
+ const void *alpha,
635
+ const cudnnTensorDescriptor_t dyDesc,
636
+ const void *dy,
637
+ const void *beta,
638
+ const cudnnTensorDescriptor_t dbDesc,
639
+ void *db);
640
+
641
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
642
+ cudnnCreateFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t *constPack, cudnnFusedOps_t ops);
643
+
644
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
645
+ cudnnDestroyFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t constPack);
646
+
647
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
648
+ cudnnSetFusedOpsConstParamPackAttribute(cudnnFusedOpsConstParamPack_t constPack,
649
+ cudnnFusedOpsConstParamLabel_t paramLabel,
650
+ const void *param);
651
+
652
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
653
+ cudnnGetFusedOpsConstParamPackAttribute(const cudnnFusedOpsConstParamPack_t constPack,
654
+ cudnnFusedOpsConstParamLabel_t paramLabel,
655
+ void *param,
656
+ int *isNULL);
657
+
658
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
659
+ cudnnCreateFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t *varPack, cudnnFusedOps_t ops);
660
+
661
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
662
+ cudnnDestroyFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t varPack);
663
+
664
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
665
+ cudnnSetFusedOpsVariantParamPackAttribute(cudnnFusedOpsVariantParamPack_t varPack,
666
+ cudnnFusedOpsVariantParamLabel_t paramLabel,
667
+ void *ptr);
668
+
669
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
670
+ cudnnGetFusedOpsVariantParamPackAttribute(const cudnnFusedOpsVariantParamPack_t varPack,
671
+ cudnnFusedOpsVariantParamLabel_t paramLabel,
672
+ void *ptr);
673
+
674
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
675
+ cudnnCreateFusedOpsPlan(cudnnFusedOpsPlan_t *plan, cudnnFusedOps_t ops);
676
+
677
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
678
+ cudnnDestroyFusedOpsPlan(cudnnFusedOpsPlan_t plan);
679
+
680
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
681
+ cudnnMakeFusedOpsPlan(cudnnHandle_t handle,
682
+ cudnnFusedOpsPlan_t plan,
683
+ const cudnnFusedOpsConstParamPack_t constPack,
684
+ size_t *workspaceSizeInBytes);
685
+
686
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
687
+ cudnnFusedOpsExecute(cudnnHandle_t handle, const cudnnFusedOpsPlan_t plan, cudnnFusedOpsVariantParamPack_t varPack);
688
+
689
+ #if defined(__cplusplus)
690
+ }
691
+ #endif
692
+
693
+ #endif /* CUDNN_CNN_H_ */
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_graph_v9.h ADDED
@@ -0,0 +1,909 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /*
51
+ * cudnn_graph : cuDNN's basic definitions operations.
52
+ */
53
+
54
+ #if !defined(CUDNN_GRAPH_H_)
55
+ #define CUDNN_GRAPH_H_
56
+
57
+ #include <cuda_runtime_api.h>
58
+ #include <library_types.h>
59
+
60
+ #include <stdint.h>
61
+
62
+ #include "cudnn_version.h"
63
+
64
+ /* These version numbers are autogenerated, do not edit manually. */
65
+ #define CUDNN_GRAPH_MAJOR 9
66
+ #define CUDNN_GRAPH_MINOR 1
67
+ #define CUDNN_GRAPH_PATCH 0
68
+
69
+ #if (CUDNN_GRAPH_MAJOR != CUDNN_MAJOR) || (CUDNN_GRAPH_MINOR != CUDNN_MINOR) || (CUDNN_GRAPH_PATCH != CUDNN_PATCHLEVEL)
70
+ #error Version mismatch in cuDNN GRAPH!!!
71
+ #endif
72
+
73
+ #ifndef CUDNNWINAPI
74
+ #ifdef _WIN32
75
+ #define CUDNNWINAPI __stdcall
76
+ #else
77
+ #define CUDNNWINAPI
78
+ #endif
79
+ #endif
80
+
81
+ /* Warnings for deprecated API-s are enabled using the CUDNN_WARN_DEPRECATED macro */
82
+ #if defined(CUDNN_WARN_DEPRECATED) && (defined(__GNUC__) || defined(__clang__))
83
+ /* GCC, Intel C/C++, Cray C/C++, CLANG, IBM XL C/C++ little endian */
84
+ #define CUDNN_DEPRECATED __attribute__((deprecated))
85
+ #define CUDNN_DEPRECATED_ENUM __attribute__((deprecated))
86
+ #elif defined(CUDNN_WARN_DEPRECATED) && defined(_MSC_VER)
87
+ /* Microsoft Visual C++ */
88
+ #define CUDNN_DEPRECATED __declspec(deprecated)
89
+ #define CUDNN_DEPRECATED_ENUM __declspec(deprecated)
90
+ #elif defined(CUDNN_WARN_DEPRECATED) && (__cplusplus >= 201402L)
91
+ /* C++14 compilers */
92
+ #define CUDNN_DEPRECATED [[deprecated]]
93
+ #define CUDNN_DEPRECATED_ENUM [[deprecated]]
94
+ #else
95
+ /* No support for the deprecated attribute */
96
+ #define CUDNN_DEPRECATED
97
+ #define CUDNN_DEPRECATED_ENUM
98
+ #endif
99
+
100
+ #if defined(__cplusplus)
101
+ extern "C" {
102
+ #endif
103
+
104
+ struct cudnnContext;
105
+ typedef struct cudnnContext *cudnnHandle_t;
106
+
107
+ size_t CUDNNWINAPI
108
+ cudnnGetVersion(void);
109
+
110
+ size_t CUDNNWINAPI
111
+ cudnnGetMaxDeviceVersion(void);
112
+
113
+ /* Returns CUDA Runtime version statically linked against cudnn */
114
+ size_t CUDNNWINAPI
115
+ cudnnGetCudartVersion(void);
116
+
117
+ /*
118
+ * CUDNN return codes
119
+ */
120
+ typedef enum {
121
+ CUDNN_STATUS_SUCCESS = 0,
122
+
123
+ /* Uncategorized errors */
124
+ CUDNN_STATUS_NOT_INITIALIZED = 1001,
125
+ CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH = 1002,
126
+ CUDNN_STATUS_SERIALIZATION_VERSION_MISMATCH = 1003,
127
+ CUDNN_STATUS_DEPRECATED = 1004,
128
+ CUDNN_STATUS_LICENSE_ERROR = 1005,
129
+ CUDNN_STATUS_RUNTIME_IN_PROGRESS = 1006,
130
+ CUDNN_STATUS_RUNTIME_FP_OVERFLOW = 1007,
131
+
132
+ CUDNN_STATUS_BAD_PARAM = 2000,
133
+ CUDNN_STATUS_BAD_PARAM_NULL_POINTER = 2002,
134
+ CUDNN_STATUS_BAD_PARAM_MISALIGNED_POINTER = 2003,
135
+ CUDNN_STATUS_BAD_PARAM_NOT_FINALIZED = 2004,
136
+ CUDNN_STATUS_BAD_PARAM_OUT_OF_BOUND = 2005,
137
+ CUDNN_STATUS_BAD_PARAM_SIZE_INSUFFICIENT = 2006,
138
+ CUDNN_STATUS_BAD_PARAM_STREAM_MISMATCH = 2007,
139
+ CUDNN_STATUS_BAD_PARAM_SHAPE_MISMATCH = 2008,
140
+ CUDNN_STATUS_BAD_PARAM_DUPLICATED_ENTRIES = 2009,
141
+ CUDNN_STATUS_BAD_PARAM_ATTRIBUTE_TYPE = 2010,
142
+
143
+ CUDNN_STATUS_NOT_SUPPORTED = 3000,
144
+ CUDNN_STATUS_NOT_SUPPORTED_GRAPH_PATTERN = 3001,
145
+ CUDNN_STATUS_NOT_SUPPORTED_SHAPE = 3002,
146
+ CUDNN_STATUS_NOT_SUPPORTED_DATA_TYPE = 3003,
147
+ CUDNN_STATUS_NOT_SUPPORTED_LAYOUT = 3004,
148
+ CUDNN_STATUS_NOT_SUPPORTED_INCOMPATIBLE_CUDA_DRIVER = 3005,
149
+ CUDNN_STATUS_NOT_SUPPORTED_INCOMPATIBLE_CUDART = 3006,
150
+ CUDNN_STATUS_NOT_SUPPORTED_ARCH_MISMATCH = 3007,
151
+ CUDNN_STATUS_NOT_SUPPORTED_RUNTIME_PREREQUISITE_MISSING = 3008,
152
+ CUDNN_STATUS_NOT_SUPPORTED_SUBLIBRARY_UNAVAILABLE = 3009,
153
+ CUDNN_STATUS_NOT_SUPPORTED_SHARED_MEMORY_INSUFFICIENT = 3010,
154
+ CUDNN_STATUS_NOT_SUPPORTED_PADDING = 3011,
155
+ CUDNN_STATUS_NOT_SUPPORTED_BAD_LAUNCH_PARAM = 3012,
156
+
157
+ CUDNN_STATUS_INTERNAL_ERROR = 4000,
158
+ CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED = 4001,
159
+ CUDNN_STATUS_INTERNAL_ERROR_UNEXPECTED_VALUE = 4002,
160
+ CUDNN_STATUS_INTERNAL_ERROR_HOST_ALLOCATION_FAILED = 4003,
161
+ CUDNN_STATUS_INTERNAL_ERROR_DEVICE_ALLOCATION_FAILED = 4004,
162
+ CUDNN_STATUS_INTERNAL_ERROR_BAD_LAUNCH_PARAM = 4005,
163
+ CUDNN_STATUS_INTERNAL_ERROR_TEXTURE_CREATION_FAILED = 4006,
164
+
165
+ CUDNN_STATUS_EXECUTION_FAILED = 5000,
166
+ CUDNN_STATUS_EXECUTION_FAILED_CUDA_DRIVER = 5001,
167
+ CUDNN_STATUS_EXECUTION_FAILED_CUBLAS = 5002,
168
+ CUDNN_STATUS_EXECUTION_FAILED_CUDART = 5003,
169
+ CUDNN_STATUS_EXECUTION_FAILED_CURAND = 5004,
170
+
171
+ CUDNN_STATUS_ALLOC_FAILED CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_INTERNAL_ERROR_HOST_ALLOCATION_FAILED,
172
+ CUDNN_STATUS_INVALID_VALUE CUDNN_DEPRECATED_ENUM = 2001 /* please transition to CUDNN_STATUS_BAD_PARAM instead */,
173
+ CUDNN_STATUS_ARCH_MISMATCH CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_NOT_SUPPORTED_ARCH_MISMATCH,
174
+ CUDNN_STATUS_MAPPING_ERROR CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_INTERNAL_ERROR_TEXTURE_CREATION_FAILED,
175
+ CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING CUDNN_DEPRECATED_ENUM =
176
+ CUDNN_STATUS_NOT_SUPPORTED_RUNTIME_PREREQUISITE_MISSING,
177
+ CUDNN_STATUS_VERSION_MISMATCH CUDNN_DEPRECATED_ENUM = CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH,
178
+ } cudnnStatus_t;
179
+
180
+ #define CUDNN_STATUS_FULL_ERROR_CODE(category, specific_err) ((cudnnStatus_t)(0 + (category) + (specific_err)))
181
+ #define CUDNN_STATUS_CATEGORY(full_error_code) ((full_error_code) / 1000 * 1000)
182
+ #define CUDNN_STATUS_SPECIFIC_ERROR(full_error_code) ((full_error_code) % 1000)
183
+
184
+ /* human-readable error messages */
185
+ const char *CUDNNWINAPI
186
+ cudnnGetErrorString(cudnnStatus_t status);
187
+
188
+ void CUDNNWINAPI
189
+ cudnnGetLastErrorString(char *message, size_t max_size);
190
+
191
+ /* Forward definition in this version only */
192
+ typedef struct cudnnRuntimeTag_t cudnnRuntimeTag_t CUDNN_DEPRECATED;
193
+
194
+ typedef enum {
195
+ CUDNN_ERRQUERY_RAWCODE = 0,
196
+ CUDNN_ERRQUERY_NONBLOCKING = 1,
197
+ CUDNN_ERRQUERY_BLOCKING = 2,
198
+ } cudnnErrQueryMode_t;
199
+
200
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
201
+ cudnnQueryRuntimeError(cudnnHandle_t handle, cudnnStatus_t *rstatus, cudnnErrQueryMode_t mode, cudnnRuntimeTag_t *tag);
202
+
203
+ cudnnStatus_t CUDNNWINAPI
204
+ cudnnGetProperty(libraryPropertyType type, int *value);
205
+
206
+ cudnnStatus_t CUDNNWINAPI
207
+ cudnnCreate(cudnnHandle_t *handle);
208
+ cudnnStatus_t CUDNNWINAPI
209
+ cudnnDestroy(cudnnHandle_t handle);
210
+ cudnnStatus_t CUDNNWINAPI
211
+ cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId);
212
+ cudnnStatus_t CUDNNWINAPI
213
+ cudnnGetStream(cudnnHandle_t handle, cudaStream_t *streamId);
214
+ /*
215
+ * CUDNN data type
216
+ */
217
+ typedef enum {
218
+ CUDNN_DATA_FLOAT = 0,
219
+ CUDNN_DATA_DOUBLE = 1,
220
+ CUDNN_DATA_HALF = 2,
221
+ CUDNN_DATA_INT8 = 3,
222
+ CUDNN_DATA_INT32 = 4,
223
+ CUDNN_DATA_INT8x4 CUDNN_DEPRECATED_ENUM = 5,
224
+ CUDNN_DATA_UINT8 = 6,
225
+ CUDNN_DATA_UINT8x4 CUDNN_DEPRECATED_ENUM = 7,
226
+ CUDNN_DATA_INT8x32 CUDNN_DEPRECATED_ENUM = 8,
227
+ CUDNN_DATA_BFLOAT16 = 9,
228
+ CUDNN_DATA_INT64 = 10,
229
+ CUDNN_DATA_BOOLEAN = 11,
230
+ CUDNN_DATA_FP8_E4M3 = 12,
231
+ CUDNN_DATA_FP8_E5M2 = 13,
232
+ CUDNN_DATA_FAST_FLOAT_FOR_FP8 = 14,
233
+ } cudnnDataType_t;
234
+
235
+ /*
236
+ * CUDNN math type
237
+ */
238
+ typedef enum {
239
+ CUDNN_DEFAULT_MATH = 0,
240
+ CUDNN_TENSOR_OP_MATH = 1,
241
+ CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION = 2,
242
+ CUDNN_FMA_MATH = 3,
243
+ } cudnnMathType_t;
244
+
245
+ /*
246
+ * CUDNN propagate Nan
247
+ */
248
+ typedef enum {
249
+ CUDNN_NOT_PROPAGATE_NAN CUDNN_DEPRECATED_ENUM = 0,
250
+ CUDNN_PROPAGATE_NAN CUDNN_DEPRECATED_ENUM = 1,
251
+ } cudnnNanPropagation_t;
252
+
253
+ /*
254
+ * Behavior for OOB samples. OOB samples are samples where L+R > T is encountered during the gradient calculation. If
255
+ * gradMode is set to CUDNN_CTC_SKIP_OOB_GRADIENTS, then the CTC loss function does not write to the gradient buffer for
256
+ * that sample. Instead, the current values, even not finite, are retained. If gradMode is set to
257
+ * CUDNN_CTC_ZERO_OOB_GRADIENTS, then the gradient for that sample is set to zero. This guarantees a finite gradient.
258
+ */
259
+ typedef enum {
260
+ CUDNN_CTC_ZERO_OOB_GRADIENTS = 0,
261
+ CUDNN_CTC_SKIP_OOB_GRADIENTS = 1,
262
+ } cudnnCTCGradMode_t;
263
+
264
+ typedef enum {
265
+ CUDNN_TENSOR_NCHW = 0, /* row major (wStride = 1, hStride = w) */
266
+ CUDNN_TENSOR_NHWC = 1, /* feature maps interleaved ( cStride = 1 )*/
267
+ CUDNN_TENSOR_NCHW_VECT_C = 2, /* each image point is vector of element of C, vector length in data type */
268
+ } cudnnTensorFormat_t;
269
+
270
+ /*
271
+ * CUDNN ReduceTensor op type
272
+ */
273
+ typedef enum {
274
+ CUDNN_REDUCE_TENSOR_ADD = 0,
275
+ CUDNN_REDUCE_TENSOR_MUL = 1,
276
+ CUDNN_REDUCE_TENSOR_MIN = 2,
277
+ CUDNN_REDUCE_TENSOR_MAX = 3,
278
+ CUDNN_REDUCE_TENSOR_AMAX = 4,
279
+ CUDNN_REDUCE_TENSOR_AVG = 5,
280
+ CUDNN_REDUCE_TENSOR_NORM1 = 6,
281
+ CUDNN_REDUCE_TENSOR_NORM2 = 7,
282
+ CUDNN_REDUCE_TENSOR_MUL_NO_ZEROS = 8,
283
+ } cudnnReduceTensorOp_t;
284
+
285
+ /*
286
+ * activation mode
287
+ */
288
+ typedef enum {
289
+ CUDNN_ACTIVATION_SIGMOID = 0,
290
+ CUDNN_ACTIVATION_RELU = 1,
291
+ CUDNN_ACTIVATION_TANH = 2,
292
+ CUDNN_ACTIVATION_CLIPPED_RELU = 3,
293
+ CUDNN_ACTIVATION_ELU = 4,
294
+ CUDNN_ACTIVATION_IDENTITY = 5,
295
+ CUDNN_ACTIVATION_SWISH = 6
296
+ } cudnnActivationMode_t CUDNN_DEPRECATED;
297
+
298
+ typedef enum {
299
+ CUDNN_SEV_FATAL = 0,
300
+ CUDNN_SEV_ERROR = 1,
301
+ CUDNN_SEV_WARNING = 2,
302
+ CUDNN_SEV_INFO = 3,
303
+ } cudnnSeverity_t;
304
+
305
+ /* Message masks to be used with cudnnSetCallback() */
306
+ #define CUDNN_SEV_ERROR_EN (1U << CUDNN_SEV_ERROR)
307
+ #define CUDNN_SEV_WARNING_EN (1U << CUDNN_SEV_WARNING)
308
+ #define CUDNN_SEV_INFO_EN (1U << CUDNN_SEV_INFO)
309
+
310
+ /* struct containing useful informaiton for each API call */
311
+ typedef struct cudnnDebugStruct {
312
+ unsigned cudnn_version;
313
+ cudnnStatus_t cudnnStatus;
314
+ unsigned time_sec; /* epoch time in seconds */
315
+ unsigned time_usec; /* microseconds part of epoch time */
316
+ unsigned time_delta; /* time since start in seconds */
317
+ cudnnHandle_t handle; /* cudnn handle */
318
+ cudaStream_t stream; /* cuda stream ID */
319
+ unsigned long long pid; /* process ID */
320
+ unsigned long long tid; /* thread ID */
321
+ int cudaDeviceId; /* CUDA device ID */
322
+ int reserved[15]; /* reserved for future use */
323
+ } cudnnDebug_t;
324
+
325
+ typedef void (*cudnnCallback_t)(cudnnSeverity_t sev, void *udata, const cudnnDebug_t *dbg, const char *msg);
326
+
327
+ cudnnStatus_t CUDNNWINAPI
328
+ cudnnSetCallback(unsigned mask, void *udata, cudnnCallback_t fptr);
329
+
330
+ cudnnStatus_t CUDNNWINAPI
331
+ cudnnGetCallback(unsigned *mask, void **udata, cudnnCallback_t *fptr);
332
+
333
+ /*
334
+ * \brief Cross-library version checker.
335
+ * This function is implemented differently in each sub-library. Each sublib
336
+ * checks whether its own version matches that of its dependencies.
337
+ * \returns CUDNN_STATUS_SUCCESS if the version check passes,
338
+ * CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH if the versions are inconsistent.
339
+ */
340
+ cudnnStatus_t CUDNNWINAPI
341
+ cudnnGraphVersionCheck(void);
342
+
343
+ /* Maximum supported number of tensor dimensions */
344
+ #define CUDNN_DIM_MAX 8
345
+
346
+ /*
347
+ * convolution mode
348
+ */
349
+ typedef enum { CUDNN_CONVOLUTION = 0, CUDNN_CROSS_CORRELATION = 1 } cudnnConvolutionMode_t;
350
+
351
+ /*
352
+ * CUDNN Reorder
353
+ */
354
+ typedef enum {
355
+ CUDNN_DEFAULT_REORDER = 0,
356
+ CUDNN_NO_REORDER = 1,
357
+ } cudnnReorderType_t CUDNN_DEPRECATED;
358
+
359
+ typedef void *cudnnBackendDescriptor_t;
360
+
361
+ typedef struct cudnnFractionStruct {
362
+ int64_t numerator;
363
+ int64_t denominator;
364
+ } cudnnFraction_t;
365
+
366
+ typedef enum {
367
+ CUDNN_POINTWISE_ADD = 0,
368
+ CUDNN_POINTWISE_ADD_SQUARE = 5,
369
+ CUDNN_POINTWISE_DIV = 6,
370
+ CUDNN_POINTWISE_MAX = 3,
371
+ CUDNN_POINTWISE_MIN = 2,
372
+ CUDNN_POINTWISE_MOD = 7,
373
+ CUDNN_POINTWISE_MUL = 1,
374
+ CUDNN_POINTWISE_POW = 8,
375
+ CUDNN_POINTWISE_SUB = 9,
376
+
377
+ CUDNN_POINTWISE_ABS = 10,
378
+ CUDNN_POINTWISE_CEIL = 11,
379
+ CUDNN_POINTWISE_COS = 12,
380
+ CUDNN_POINTWISE_EXP = 13,
381
+ CUDNN_POINTWISE_FLOOR = 14,
382
+ CUDNN_POINTWISE_LOG = 15,
383
+ CUDNN_POINTWISE_NEG = 16,
384
+ CUDNN_POINTWISE_RSQRT = 17,
385
+ CUDNN_POINTWISE_SIN = 18,
386
+ CUDNN_POINTWISE_SQRT = 4,
387
+ CUDNN_POINTWISE_TAN = 19,
388
+ CUDNN_POINTWISE_ERF = 20,
389
+ CUDNN_POINTWISE_IDENTITY = 21,
390
+ CUDNN_POINTWISE_RECIPROCAL = 22,
391
+ CUDNN_POINTWISE_ATAN2 = 23,
392
+
393
+ CUDNN_POINTWISE_RELU_FWD = 100,
394
+ CUDNN_POINTWISE_TANH_FWD = 101,
395
+ CUDNN_POINTWISE_SIGMOID_FWD = 102,
396
+ CUDNN_POINTWISE_ELU_FWD = 103,
397
+ CUDNN_POINTWISE_GELU_FWD = 104,
398
+ CUDNN_POINTWISE_SOFTPLUS_FWD = 105,
399
+ CUDNN_POINTWISE_SWISH_FWD = 106,
400
+ CUDNN_POINTWISE_GELU_APPROX_TANH_FWD = 107,
401
+
402
+ CUDNN_POINTWISE_RELU_BWD = 200,
403
+ CUDNN_POINTWISE_TANH_BWD = 201,
404
+ CUDNN_POINTWISE_SIGMOID_BWD = 202,
405
+ CUDNN_POINTWISE_ELU_BWD = 203,
406
+ CUDNN_POINTWISE_GELU_BWD = 204,
407
+ CUDNN_POINTWISE_SOFTPLUS_BWD = 205,
408
+ CUDNN_POINTWISE_SWISH_BWD = 206,
409
+ CUDNN_POINTWISE_GELU_APPROX_TANH_BWD = 207,
410
+
411
+ CUDNN_POINTWISE_CMP_EQ = 300,
412
+ CUDNN_POINTWISE_CMP_NEQ = 301,
413
+ CUDNN_POINTWISE_CMP_GT = 302,
414
+ CUDNN_POINTWISE_CMP_GE = 303,
415
+ CUDNN_POINTWISE_CMP_LT = 304,
416
+ CUDNN_POINTWISE_CMP_LE = 305,
417
+
418
+ CUDNN_POINTWISE_LOGICAL_AND = 400,
419
+ CUDNN_POINTWISE_LOGICAL_OR = 401,
420
+ CUDNN_POINTWISE_LOGICAL_NOT = 402,
421
+
422
+ CUDNN_POINTWISE_GEN_INDEX = 501,
423
+
424
+ CUDNN_POINTWISE_BINARY_SELECT = 601,
425
+ } cudnnPointwiseMode_t;
426
+
427
+ typedef enum {
428
+ CUDNN_RESAMPLE_NEAREST = 0,
429
+ CUDNN_RESAMPLE_BILINEAR = 1,
430
+ CUDNN_RESAMPLE_AVGPOOL = 2,
431
+ CUDNN_RESAMPLE_AVGPOOL_INCLUDE_PADDING = 2,
432
+ CUDNN_RESAMPLE_AVGPOOL_EXCLUDE_PADDING = 4,
433
+ CUDNN_RESAMPLE_MAXPOOL = 3,
434
+ } cudnnResampleMode_t;
435
+
436
+ typedef enum {
437
+ CUDNN_SIGNAL_SET = 0,
438
+ CUDNN_SIGNAL_WAIT = 1,
439
+ } cudnnSignalMode_t;
440
+
441
+ typedef enum {
442
+ CUDNN_GENSTATS_SUM_SQSUM = 0,
443
+ } cudnnGenStatsMode_t;
444
+
445
+ typedef enum {
446
+ CUDNN_BN_FINALIZE_STATISTICS_TRAINING = 0,
447
+ CUDNN_BN_FINALIZE_STATISTICS_INFERENCE = 1,
448
+ } cudnnBnFinalizeStatsMode_t;
449
+
450
+ typedef enum {
451
+ CUDNN_RNG_DISTRIBUTION_BERNOULLI,
452
+ CUDNN_RNG_DISTRIBUTION_UNIFORM,
453
+ CUDNN_RNG_DISTRIBUTION_NORMAL,
454
+ } cudnnRngDistribution_t;
455
+
456
+ typedef enum {
457
+ CUDNN_ATTR_POINTWISE_MODE = 0,
458
+ CUDNN_ATTR_POINTWISE_MATH_PREC = 1,
459
+ CUDNN_ATTR_POINTWISE_NAN_PROPAGATION CUDNN_DEPRECATED_ENUM = 2,
460
+ CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP = 3,
461
+ CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP = 4,
462
+ CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP_SLOPE = 5,
463
+ CUDNN_ATTR_POINTWISE_ELU_ALPHA = 6,
464
+ CUDNN_ATTR_POINTWISE_SOFTPLUS_BETA = 7,
465
+ CUDNN_ATTR_POINTWISE_SWISH_BETA = 8,
466
+ CUDNN_ATTR_POINTWISE_AXIS = 9,
467
+
468
+ CUDNN_ATTR_CONVOLUTION_COMP_TYPE = 100,
469
+ CUDNN_ATTR_CONVOLUTION_CONV_MODE = 101,
470
+ CUDNN_ATTR_CONVOLUTION_DILATIONS = 102,
471
+ CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES = 103,
472
+ CUDNN_ATTR_CONVOLUTION_POST_PADDINGS = 104,
473
+ CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS = 105,
474
+ CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS = 106,
475
+
476
+ CUDNN_ATTR_ENGINEHEUR_MODE = 200,
477
+ CUDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH = 201,
478
+ CUDNN_ATTR_ENGINEHEUR_RESULTS = 202,
479
+ CUDNN_ATTR_ENGINEHEUR_SM_COUNT_TARGET = 203,
480
+
481
+ CUDNN_ATTR_ENGINECFG_ENGINE = 300,
482
+ CUDNN_ATTR_ENGINECFG_INTERMEDIATE_INFO = 301,
483
+ CUDNN_ATTR_ENGINECFG_KNOB_CHOICES = 302,
484
+
485
+ CUDNN_ATTR_EXECUTION_PLAN_HANDLE = 400,
486
+ CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG = 401,
487
+ CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE = 402,
488
+ CUDNN_ATTR_EXECUTION_PLAN_COMPUTED_INTERMEDIATE_UIDS = 403,
489
+ CUDNN_ATTR_EXECUTION_PLAN_RUN_ONLY_INTERMEDIATE_UIDS = 404,
490
+ CUDNN_ATTR_EXECUTION_PLAN_JSON_REPRESENTATION = 405,
491
+
492
+ CUDNN_ATTR_INTERMEDIATE_INFO_UNIQUE_ID = 500,
493
+ CUDNN_ATTR_INTERMEDIATE_INFO_SIZE = 501,
494
+ CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_DATA_UIDS = 502,
495
+ CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_ATTRIBUTES = 503,
496
+
497
+ CUDNN_ATTR_KNOB_CHOICE_KNOB_TYPE = 600,
498
+ CUDNN_ATTR_KNOB_CHOICE_KNOB_VALUE = 601,
499
+
500
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA = 700,
501
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA = 701,
502
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC = 702,
503
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W = 703,
504
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X = 704,
505
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y = 705,
506
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA = 706,
507
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA = 707,
508
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC = 708,
509
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W = 709,
510
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX = 710,
511
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY = 711,
512
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA = 712,
513
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA = 713,
514
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC = 714,
515
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW = 715,
516
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X = 716,
517
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY = 717,
518
+
519
+ CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR = 750,
520
+ CUDNN_ATTR_OPERATION_POINTWISE_XDESC = 751,
521
+ CUDNN_ATTR_OPERATION_POINTWISE_BDESC = 752,
522
+ CUDNN_ATTR_OPERATION_POINTWISE_YDESC = 753,
523
+ CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1 = 754,
524
+ CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2 = 755,
525
+ CUDNN_ATTR_OPERATION_POINTWISE_DXDESC = 756,
526
+ CUDNN_ATTR_OPERATION_POINTWISE_DYDESC = 757,
527
+ CUDNN_ATTR_OPERATION_POINTWISE_TDESC = 758,
528
+
529
+ CUDNN_ATTR_OPERATION_GENSTATS_MODE = 770,
530
+ CUDNN_ATTR_OPERATION_GENSTATS_MATH_PREC = 771,
531
+ CUDNN_ATTR_OPERATION_GENSTATS_XDESC = 772,
532
+ CUDNN_ATTR_OPERATION_GENSTATS_SUMDESC = 773,
533
+ CUDNN_ATTR_OPERATION_GENSTATS_SQSUMDESC = 774,
534
+
535
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_STATS_MODE = 780,
536
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_MATH_PREC = 781,
537
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SUM_DESC = 782,
538
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SQ_SUM_DESC = 783,
539
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_SCALE_DESC = 784,
540
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_BIAS_DESC = 785,
541
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_MEAN_DESC = 786,
542
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_VAR_DESC = 787,
543
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_MEAN_DESC = 788,
544
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_VAR_DESC = 789,
545
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_MEAN_DESC = 790,
546
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_INV_STD_DESC = 791,
547
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_SCALE_DESC = 792,
548
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_BIAS_DESC = 793,
549
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_ACCUM_COUNT_DESC = 794,
550
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_EPSILON_DESC = 795,
551
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_EXP_AVERATE_FACTOR_DESC = 796,
552
+
553
+ CUDNN_ATTR_OPERATIONGRAPH_HANDLE = 800,
554
+ CUDNN_ATTR_OPERATIONGRAPH_OPS = 801,
555
+ CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT = 802,
556
+
557
+ CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT = 900,
558
+ CUDNN_ATTR_TENSOR_DATA_TYPE = 901,
559
+ CUDNN_ATTR_TENSOR_DIMENSIONS = 902,
560
+ CUDNN_ATTR_TENSOR_STRIDES = 903,
561
+ CUDNN_ATTR_TENSOR_VECTOR_COUNT = 904,
562
+ CUDNN_ATTR_TENSOR_VECTORIZED_DIMENSION = 905,
563
+ CUDNN_ATTR_TENSOR_UNIQUE_ID = 906,
564
+ CUDNN_ATTR_TENSOR_IS_VIRTUAL = 907,
565
+ CUDNN_ATTR_TENSOR_IS_BY_VALUE = 908,
566
+ CUDNN_ATTR_TENSOR_REORDERING_MODE = 909,
567
+ CUDNN_ATTR_TENSOR_RAGGED_OFFSET_DESC = 913,
568
+
569
+ CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS = 1000,
570
+ CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS = 1001,
571
+ CUDNN_ATTR_VARIANT_PACK_INTERMEDIATES = 1002,
572
+ CUDNN_ATTR_VARIANT_PACK_WORKSPACE = 1003,
573
+
574
+ CUDNN_ATTR_LAYOUT_INFO_TENSOR_UID = 1100,
575
+ CUDNN_ATTR_LAYOUT_INFO_TYPES = 1101,
576
+
577
+ CUDNN_ATTR_KNOB_INFO_TYPE = 1200,
578
+ CUDNN_ATTR_KNOB_INFO_MAXIMUM_VALUE = 1201,
579
+ CUDNN_ATTR_KNOB_INFO_MINIMUM_VALUE = 1202,
580
+ CUDNN_ATTR_KNOB_INFO_STRIDE = 1203,
581
+
582
+ CUDNN_ATTR_ENGINE_OPERATION_GRAPH = 1300,
583
+ CUDNN_ATTR_ENGINE_GLOBAL_INDEX = 1301,
584
+ CUDNN_ATTR_ENGINE_KNOB_INFO = 1302,
585
+ CUDNN_ATTR_ENGINE_NUMERICAL_NOTE = 1303,
586
+ CUDNN_ATTR_ENGINE_LAYOUT_INFO = 1304,
587
+ CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE = 1305,
588
+ CUDNN_ATTR_ENGINE_SM_COUNT_TARGET = 1306,
589
+
590
+ CUDNN_ATTR_MATMUL_COMP_TYPE = 1500,
591
+ CUDNN_ATTR_MATMUL_PADDING_VALUE = 1503,
592
+
593
+ CUDNN_ATTR_OPERATION_MATMUL_ADESC = 1520,
594
+ CUDNN_ATTR_OPERATION_MATMUL_BDESC = 1521,
595
+ CUDNN_ATTR_OPERATION_MATMUL_CDESC = 1522,
596
+ CUDNN_ATTR_OPERATION_MATMUL_DESC = 1523,
597
+ CUDNN_ATTR_OPERATION_MATMUL_IRREGULARLY_STRIDED_BATCH_COUNT CUDNN_DEPRECATED_ENUM = 1524,
598
+ CUDNN_ATTR_OPERATION_MATMUL_GEMM_M_OVERRIDE_DESC = 1525,
599
+ CUDNN_ATTR_OPERATION_MATMUL_GEMM_N_OVERRIDE_DESC = 1526,
600
+ CUDNN_ATTR_OPERATION_MATMUL_GEMM_K_OVERRIDE_DESC = 1527,
601
+
602
+ CUDNN_ATTR_REDUCTION_OPERATOR = 1600,
603
+ CUDNN_ATTR_REDUCTION_COMP_TYPE = 1601,
604
+
605
+ CUDNN_ATTR_OPERATION_REDUCTION_XDESC = 1610,
606
+ CUDNN_ATTR_OPERATION_REDUCTION_YDESC = 1611,
607
+ CUDNN_ATTR_OPERATION_REDUCTION_DESC = 1612,
608
+
609
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MATH_PREC = 1620,
610
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MEAN_DESC = 1621,
611
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_INVSTD_DESC = 1622,
612
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_BN_SCALE_DESC = 1623,
613
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_X_DESC = 1624,
614
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DY_DESC = 1625,
615
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_SCALE_DESC = 1626,
616
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_BIAS_DESC = 1627,
617
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_DY_SCALE_DESC = 1628,
618
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_X_SCALE_DESC = 1629,
619
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_BIAS = 1630,
620
+
621
+ CUDNN_ATTR_RESAMPLE_MODE = 1700,
622
+ CUDNN_ATTR_RESAMPLE_COMP_TYPE = 1701,
623
+ CUDNN_ATTR_RESAMPLE_SPATIAL_DIMS = 1702,
624
+ CUDNN_ATTR_RESAMPLE_POST_PADDINGS = 1703,
625
+ CUDNN_ATTR_RESAMPLE_PRE_PADDINGS = 1704,
626
+ CUDNN_ATTR_RESAMPLE_STRIDES = 1705,
627
+ CUDNN_ATTR_RESAMPLE_WINDOW_DIMS = 1706,
628
+ CUDNN_ATTR_RESAMPLE_NAN_PROPAGATION = 1707,
629
+ CUDNN_ATTR_RESAMPLE_PADDING_MODE = 1708,
630
+
631
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_XDESC = 1710,
632
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_YDESC = 1711,
633
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_IDXDESC = 1712,
634
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_ALPHA CUDNN_DEPRECATED_ENUM = 1713,
635
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_BETA CUDNN_DEPRECATED_ENUM = 1714,
636
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_DESC = 1716,
637
+
638
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_DXDESC = 1720,
639
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_DYDESC = 1721,
640
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_IDXDESC = 1722,
641
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_ALPHA CUDNN_DEPRECATED_ENUM = 1723,
642
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_BETA CUDNN_DEPRECATED_ENUM = 1724,
643
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_DESC = 1725,
644
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_XDESC = 1726,
645
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_YDESC = 1727,
646
+
647
+ CUDNN_ATTR_OPERATION_CONCAT_AXIS = 1800,
648
+ CUDNN_ATTR_OPERATION_CONCAT_INPUT_DESCS = 1801,
649
+ CUDNN_ATTR_OPERATION_CONCAT_INPLACE_INDEX = 1802,
650
+ CUDNN_ATTR_OPERATION_CONCAT_OUTPUT_DESC = 1803,
651
+
652
+ CUDNN_ATTR_OPERATION_SIGNAL_MODE = 1900,
653
+ CUDNN_ATTR_OPERATION_SIGNAL_FLAGDESC = 1901,
654
+ CUDNN_ATTR_OPERATION_SIGNAL_VALUE = 1902,
655
+ CUDNN_ATTR_OPERATION_SIGNAL_XDESC = 1903,
656
+ CUDNN_ATTR_OPERATION_SIGNAL_YDESC = 1904,
657
+
658
+ CUDNN_ATTR_OPERATION_NORM_FWD_MODE = 2000,
659
+ CUDNN_ATTR_OPERATION_NORM_FWD_PHASE = 2001,
660
+ CUDNN_ATTR_OPERATION_NORM_FWD_XDESC = 2002,
661
+ CUDNN_ATTR_OPERATION_NORM_FWD_MEAN_DESC = 2003,
662
+ CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC = 2004,
663
+ CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC = 2005,
664
+ CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC = 2006,
665
+ CUDNN_ATTR_OPERATION_NORM_FWD_EPSILON_DESC = 2007,
666
+ CUDNN_ATTR_OPERATION_NORM_FWD_EXP_AVG_FACTOR_DESC = 2008,
667
+ CUDNN_ATTR_OPERATION_NORM_FWD_INPUT_RUNNING_MEAN_DESC = 2009,
668
+ CUDNN_ATTR_OPERATION_NORM_FWD_INPUT_RUNNING_VAR_DESC = 2010,
669
+ CUDNN_ATTR_OPERATION_NORM_FWD_OUTPUT_RUNNING_MEAN_DESC = 2011,
670
+ CUDNN_ATTR_OPERATION_NORM_FWD_OUTPUT_RUNNING_VAR_DESC = 2012,
671
+ CUDNN_ATTR_OPERATION_NORM_FWD_YDESC = 2013,
672
+ CUDNN_ATTR_OPERATION_NORM_FWD_PEER_STAT_DESCS = 2014,
673
+
674
+ CUDNN_ATTR_OPERATION_NORM_BWD_MODE = 2100,
675
+ CUDNN_ATTR_OPERATION_NORM_BWD_XDESC = 2101,
676
+ CUDNN_ATTR_OPERATION_NORM_BWD_MEAN_DESC = 2102,
677
+ CUDNN_ATTR_OPERATION_NORM_BWD_INV_VARIANCE_DESC = 2103,
678
+ CUDNN_ATTR_OPERATION_NORM_BWD_DYDESC = 2104,
679
+ CUDNN_ATTR_OPERATION_NORM_BWD_SCALE_DESC = 2105,
680
+ CUDNN_ATTR_OPERATION_NORM_BWD_EPSILON_DESC = 2106,
681
+ CUDNN_ATTR_OPERATION_NORM_BWD_DSCALE_DESC = 2107,
682
+ CUDNN_ATTR_OPERATION_NORM_BWD_DBIAS_DESC = 2108,
683
+ CUDNN_ATTR_OPERATION_NORM_BWD_DXDESC = 2109,
684
+ CUDNN_ATTR_OPERATION_NORM_BWD_PEER_STAT_DESCS = 2110,
685
+
686
+ CUDNN_ATTR_OPERATION_RESHAPE_XDESC = 2200,
687
+ CUDNN_ATTR_OPERATION_RESHAPE_YDESC = 2201,
688
+
689
+ CUDNN_ATTR_RNG_DISTRIBUTION = 2300,
690
+ CUDNN_ATTR_RNG_NORMAL_DIST_MEAN = 2301,
691
+ CUDNN_ATTR_RNG_NORMAL_DIST_STANDARD_DEVIATION = 2302,
692
+ CUDNN_ATTR_RNG_UNIFORM_DIST_MAXIMUM = 2303,
693
+ CUDNN_ATTR_RNG_UNIFORM_DIST_MINIMUM = 2304,
694
+ CUDNN_ATTR_RNG_BERNOULLI_DIST_PROBABILITY = 2305,
695
+
696
+ CUDNN_ATTR_OPERATION_RNG_YDESC = 2310,
697
+ CUDNN_ATTR_OPERATION_RNG_SEED = 2311,
698
+ CUDNN_ATTR_OPERATION_RNG_DESC = 2312,
699
+ CUDNN_ATTR_OPERATION_RNG_OFFSET_DESC = 2313,
700
+ } cudnnBackendAttributeName_t;
701
+
702
+ typedef enum {
703
+ CUDNN_TYPE_HANDLE = 0,
704
+ CUDNN_TYPE_DATA_TYPE,
705
+ CUDNN_TYPE_BOOLEAN,
706
+ CUDNN_TYPE_INT64,
707
+ CUDNN_TYPE_FLOAT,
708
+ CUDNN_TYPE_DOUBLE,
709
+ CUDNN_TYPE_VOID_PTR,
710
+ CUDNN_TYPE_CONVOLUTION_MODE,
711
+ CUDNN_TYPE_HEUR_MODE,
712
+ CUDNN_TYPE_KNOB_TYPE,
713
+ CUDNN_TYPE_NAN_PROPOGATION CUDNN_DEPRECATED_ENUM,
714
+ CUDNN_TYPE_NUMERICAL_NOTE,
715
+ CUDNN_TYPE_LAYOUT_TYPE,
716
+ CUDNN_TYPE_ATTRIB_NAME,
717
+ CUDNN_TYPE_POINTWISE_MODE,
718
+ CUDNN_TYPE_BACKEND_DESCRIPTOR,
719
+ CUDNN_TYPE_GENSTATS_MODE,
720
+ CUDNN_TYPE_BN_FINALIZE_STATS_MODE,
721
+ CUDNN_TYPE_REDUCTION_OPERATOR_TYPE,
722
+ CUDNN_TYPE_BEHAVIOR_NOTE,
723
+ CUDNN_TYPE_TENSOR_REORDERING_MODE,
724
+ CUDNN_TYPE_RESAMPLE_MODE,
725
+ CUDNN_TYPE_PADDING_MODE,
726
+ CUDNN_TYPE_INT32,
727
+ CUDNN_TYPE_CHAR,
728
+ CUDNN_TYPE_SIGNAL_MODE,
729
+ CUDNN_TYPE_FRACTION,
730
+ CUDNN_TYPE_NORM_MODE,
731
+ CUDNN_TYPE_NORM_FWD_PHASE,
732
+ CUDNN_TYPE_RNG_DISTRIBUTION
733
+ } cudnnBackendAttributeType_t;
734
+
735
+ typedef enum {
736
+ CUDNN_BACKEND_POINTWISE_DESCRIPTOR = 0,
737
+ CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR,
738
+ CUDNN_BACKEND_ENGINE_DESCRIPTOR,
739
+ CUDNN_BACKEND_ENGINECFG_DESCRIPTOR,
740
+ CUDNN_BACKEND_ENGINEHEUR_DESCRIPTOR,
741
+ CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR,
742
+ CUDNN_BACKEND_INTERMEDIATE_INFO_DESCRIPTOR,
743
+ CUDNN_BACKEND_KNOB_CHOICE_DESCRIPTOR,
744
+ CUDNN_BACKEND_KNOB_INFO_DESCRIPTOR,
745
+ CUDNN_BACKEND_LAYOUT_INFO_DESCRIPTOR,
746
+ CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
747
+ CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR,
748
+ CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR,
749
+ CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR,
750
+ CUDNN_BACKEND_OPERATION_GEN_STATS_DESCRIPTOR,
751
+ CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR,
752
+ CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR,
753
+ CUDNN_BACKEND_TENSOR_DESCRIPTOR,
754
+ CUDNN_BACKEND_MATMUL_DESCRIPTOR,
755
+ CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR,
756
+ CUDNN_BACKEND_OPERATION_BN_FINALIZE_STATISTICS_DESCRIPTOR,
757
+ CUDNN_BACKEND_REDUCTION_DESCRIPTOR,
758
+ CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR,
759
+ CUDNN_BACKEND_OPERATION_BN_BWD_WEIGHTS_DESCRIPTOR,
760
+ CUDNN_BACKEND_RESAMPLE_DESCRIPTOR,
761
+ CUDNN_BACKEND_OPERATION_RESAMPLE_FWD_DESCRIPTOR,
762
+ CUDNN_BACKEND_OPERATION_RESAMPLE_BWD_DESCRIPTOR,
763
+ CUDNN_BACKEND_OPERATION_CONCAT_DESCRIPTOR,
764
+ CUDNN_BACKEND_OPERATION_SIGNAL_DESCRIPTOR,
765
+ CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR,
766
+ CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR,
767
+ CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR,
768
+ CUDNN_BACKEND_RNG_DESCRIPTOR,
769
+ CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR,
770
+ } cudnnBackendDescriptorType_t;
771
+
772
+ typedef enum {
773
+ CUDNN_NUMERICAL_NOTE_TENSOR_CORE = 0,
774
+ CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS,
775
+ CUDNN_NUMERICAL_NOTE_REDUCED_PRECISION_REDUCTION,
776
+ CUDNN_NUMERICAL_NOTE_FFT,
777
+ CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC,
778
+ CUDNN_NUMERICAL_NOTE_WINOGRAD,
779
+ CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_4x4,
780
+ CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_6x6,
781
+ CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_13x13,
782
+ CUDNN_NUMERICAL_NOTE_STRICT_NAN_PROP,
783
+ CUDNN_NUMERICAL_NOTE_TYPE_COUNT,
784
+ } cudnnBackendNumericalNote_t;
785
+
786
+ typedef enum {
787
+ CUDNN_BEHAVIOR_NOTE_RUNTIME_COMPILATION = 0,
788
+ CUDNN_BEHAVIOR_NOTE_REQUIRES_FILTER_INT8x32_REORDER = 1,
789
+ CUDNN_BEHAVIOR_NOTE_REQUIRES_BIAS_INT8x32_REORDER = 2,
790
+ CUDNN_BEHAVIOR_NOTE_TYPE_COUNT,
791
+ } cudnnBackendBehaviorNote_t;
792
+
793
+ typedef enum {
794
+ CUDNN_KNOB_TYPE_SPLIT_K CUDNN_DEPRECATED_ENUM = 0,
795
+ CUDNN_KNOB_TYPE_SWIZZLE = 1,
796
+ CUDNN_KNOB_TYPE_TILE_SIZE = 2,
797
+ CUDNN_KNOB_TYPE_USE_TEX CUDNN_DEPRECATED_ENUM = 3,
798
+ CUDNN_KNOB_TYPE_EDGE = 4,
799
+ CUDNN_KNOB_TYPE_KBLOCK CUDNN_DEPRECATED_ENUM = 5,
800
+ CUDNN_KNOB_TYPE_LDGA CUDNN_DEPRECATED_ENUM = 6,
801
+ CUDNN_KNOB_TYPE_LDGB CUDNN_DEPRECATED_ENUM = 7,
802
+ CUDNN_KNOB_TYPE_CHUNK_K CUDNN_DEPRECATED_ENUM = 8,
803
+ CUDNN_KNOB_TYPE_SPLIT_H CUDNN_DEPRECATED_ENUM = 9,
804
+ CUDNN_KNOB_TYPE_WINO_TILE CUDNN_DEPRECATED_ENUM = 10,
805
+ CUDNN_KNOB_TYPE_MULTIPLY = 11,
806
+ CUDNN_KNOB_TYPE_SPLIT_K_BUF = 12,
807
+ CUDNN_KNOB_TYPE_TILEK = 13,
808
+ CUDNN_KNOB_TYPE_STAGES = 14,
809
+ CUDNN_KNOB_TYPE_REDUCTION_MODE = 15,
810
+ CUDNN_KNOB_TYPE_CTA_SPLIT_K_MODE CUDNN_DEPRECATED_ENUM = 16,
811
+ CUDNN_KNOB_TYPE_SPLIT_K_SLC = 17,
812
+ CUDNN_KNOB_TYPE_IDX_MODE CUDNN_DEPRECATED_ENUM = 18,
813
+ CUDNN_KNOB_TYPE_SLICED CUDNN_DEPRECATED_ENUM = 19,
814
+ CUDNN_KNOB_TYPE_SPLIT_RS CUDNN_DEPRECATED_ENUM = 20,
815
+ CUDNN_KNOB_TYPE_SINGLEBUFFER CUDNN_DEPRECATED_ENUM = 21,
816
+ CUDNN_KNOB_TYPE_LDGC CUDNN_DEPRECATED_ENUM = 22,
817
+ CUDNN_KNOB_TYPE_SPECFILT = 23,
818
+ CUDNN_KNOB_TYPE_KERNEL_CFG = 24,
819
+ CUDNN_KNOB_TYPE_WORKSPACE = 25,
820
+ CUDNN_KNOB_TYPE_TILE_CGA CUDNN_DEPRECATED_ENUM = 26,
821
+ CUDNN_KNOB_TYPE_TILE_CGA_M = 27,
822
+ CUDNN_KNOB_TYPE_TILE_CGA_N = 28,
823
+ CUDNN_KNOB_TYPE_BLOCK_SIZE = 29,
824
+ CUDNN_KNOB_TYPE_OCCUPANCY = 30,
825
+ CUDNN_KNOB_TYPE_ARRAY_SIZE_PER_THREAD = 31,
826
+ CUDNN_KNOB_TYPE_NUM_C_PER_BLOCK CUDNN_DEPRECATED_ENUM = 32,
827
+ CUDNN_KNOB_TYPE_SPLIT_COLS = 33,
828
+ CUDNN_KNOB_TYPE_TILE_ROWS = 34,
829
+ CUDNN_KNOB_TYPE_TILE_COLS = 35,
830
+ CUDNN_KNOB_TYPE_LOAD_SIZE = 36,
831
+ CUDNN_KNOB_TYPE_COUNTS,
832
+ } cudnnBackendKnobType_t;
833
+
834
+ typedef enum {
835
+ CUDNN_LAYOUT_TYPE_PREFERRED_NCHW = 0,
836
+ CUDNN_LAYOUT_TYPE_PREFERRED_NHWC = 1,
837
+ CUDNN_LAYOUT_TYPE_PREFERRED_PAD4CK = 2,
838
+ CUDNN_LAYOUT_TYPE_PREFERRED_PAD8CK = 3,
839
+ CUDNN_LAYOUT_TYPE_COUNT = 4,
840
+ } cudnnBackendLayoutType_t;
841
+
842
+ typedef enum {
843
+ CUDNN_HEUR_MODE_INSTANT = 0,
844
+ CUDNN_HEUR_MODE_B = 1,
845
+ CUDNN_HEUR_MODE_FALLBACK = 2,
846
+ CUDNN_HEUR_MODE_A = 3,
847
+ CUDNN_HEUR_MODES_COUNT = 4,
848
+ } cudnnBackendHeurMode_t;
849
+
850
+ typedef enum {
851
+ CUDNN_TENSOR_REORDERING_NONE = 0,
852
+ CUDNN_TENSOR_REORDERING_INT8x32 = 1,
853
+ CUDNN_TENSOR_REORDERING_F16x16 = 2,
854
+ } cudnnBackendTensorReordering_t;
855
+
856
+ typedef enum {
857
+ CUDNN_ZERO_PAD = 0,
858
+ CUDNN_NEG_INF_PAD = 1,
859
+ CUDNN_EDGE_VAL_PAD = 2,
860
+ } cudnnPaddingMode_t;
861
+
862
+ typedef enum {
863
+ CUDNN_LAYER_NORM = 0,
864
+ CUDNN_INSTANCE_NORM = 1,
865
+ CUDNN_BATCH_NORM = 2,
866
+ CUDNN_GROUP_NORM = 3,
867
+ CUDNN_RMS_NORM = 4,
868
+ } cudnnBackendNormMode_t;
869
+
870
+ typedef enum {
871
+ CUDNN_NORM_FWD_INFERENCE = 0,
872
+ CUDNN_NORM_FWD_TRAINING = 1,
873
+ } cudnnBackendNormFwdPhase_t;
874
+
875
+ cudnnStatus_t CUDNNWINAPI
876
+ cudnnBackendCreateDescriptor(cudnnBackendDescriptorType_t descriptorType, cudnnBackendDescriptor_t *descriptor);
877
+
878
+ cudnnStatus_t CUDNNWINAPI
879
+ cudnnBackendDestroyDescriptor(cudnnBackendDescriptor_t descriptor);
880
+
881
+ cudnnStatus_t CUDNNWINAPI
882
+ cudnnBackendInitialize(cudnnBackendDescriptor_t descriptor);
883
+
884
+ cudnnStatus_t CUDNNWINAPI
885
+ cudnnBackendFinalize(cudnnBackendDescriptor_t descriptor);
886
+
887
+ cudnnStatus_t CUDNNWINAPI
888
+ cudnnBackendSetAttribute(cudnnBackendDescriptor_t descriptor,
889
+ cudnnBackendAttributeName_t attributeName,
890
+ cudnnBackendAttributeType_t attributeType,
891
+ int64_t elementCount,
892
+ const void *arrayOfElements);
893
+
894
+ cudnnStatus_t CUDNNWINAPI
895
+ cudnnBackendGetAttribute(cudnnBackendDescriptor_t const descriptor,
896
+ cudnnBackendAttributeName_t attributeName,
897
+ cudnnBackendAttributeType_t attributeType,
898
+ int64_t requestedElementCount,
899
+ int64_t *elementCount,
900
+ void *arrayOfElements);
901
+
902
+ cudnnStatus_t CUDNNWINAPI
903
+ cudnnBackendExecute(cudnnHandle_t handle, cudnnBackendDescriptor_t executionPlan, cudnnBackendDescriptor_t variantPack);
904
+
905
+ #if defined(__cplusplus)
906
+ }
907
+ #endif
908
+
909
+ #endif /* CUDNN_GRAPH_H_ */
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_ops_v9.h ADDED
@@ -0,0 +1,1316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /*
51
+ * cudnn_ops : cuDNN's basic definitions and basic operations.
52
+ */
53
+
54
+ #if !defined(CUDNN_OPS_H_)
55
+ #define CUDNN_OPS_H_
56
+
57
+ #include <stdint.h>
58
+
59
+ #include "cudnn_version.h"
60
+ #include "cudnn_graph.h"
61
+
62
+ /* These version numbers are autogenerated, do not edit manually. */
63
+ #define CUDNN_OPS_MAJOR 9
64
+ #define CUDNN_OPS_MINOR 1
65
+ #define CUDNN_OPS_PATCH 0
66
+
67
+ #if (CUDNN_OPS_MAJOR != CUDNN_MAJOR) || (CUDNN_OPS_MINOR != CUDNN_MINOR) || (CUDNN_OPS_PATCH != CUDNN_PATCHLEVEL)
68
+ #error Version mismatch in cuDNN OPS INFER!!!
69
+ #endif
70
+
71
+ #if defined(__cplusplus)
72
+ extern "C" {
73
+ #endif
74
+
75
+ /* Data structures to represent Image/Filter and the Neural Network Layer */
76
+ typedef struct cudnnTensorStruct *cudnnTensorDescriptor_t;
77
+ typedef struct cudnnPoolingStruct *cudnnPoolingDescriptor_t CUDNN_DEPRECATED;
78
+ typedef struct cudnnFilterStruct *cudnnFilterDescriptor_t CUDNN_DEPRECATED;
79
+ typedef struct cudnnLRNStruct *cudnnLRNDescriptor_t;
80
+ typedef struct cudnnActivationStruct *cudnnActivationDescriptor_t CUDNN_DEPRECATED;
81
+ typedef struct cudnnSpatialTransformerStruct *cudnnSpatialTransformerDescriptor_t;
82
+ typedef struct cudnnOpTensorStruct *cudnnOpTensorDescriptor_t CUDNN_DEPRECATED;
83
+ typedef struct cudnnReduceTensorStruct *cudnnReduceTensorDescriptor_t CUDNN_DEPRECATED;
84
+ typedef struct cudnnCTCLossStruct *cudnnCTCLossDescriptor_t;
85
+ typedef struct cudnnTensorTransformStruct *cudnnTensorTransformDescriptor_t CUDNN_DEPRECATED;
86
+ /*
87
+ * CUDNN Determinism
88
+ */
89
+ typedef enum {
90
+ CUDNN_NON_DETERMINISTIC = 0,
91
+ CUDNN_DETERMINISTIC = 1,
92
+ } cudnnDeterminism_t;
93
+
94
+ /* Create an instance of a generic Tensor descriptor */
95
+ cudnnStatus_t CUDNNWINAPI
96
+ cudnnCreateTensorDescriptor(cudnnTensorDescriptor_t *tensorDesc);
97
+
98
+ cudnnStatus_t CUDNNWINAPI
99
+ cudnnSetTensor4dDescriptor(cudnnTensorDescriptor_t tensorDesc,
100
+ cudnnTensorFormat_t format,
101
+ cudnnDataType_t dataType, /* image data type */
102
+ int n, /* number of inputs (batch size) */
103
+ int c, /* number of input feature maps */
104
+ int h, /* height of input section */
105
+ int w); /* width of input section */
106
+
107
+ cudnnStatus_t CUDNNWINAPI
108
+ cudnnSetTensor4dDescriptorEx(cudnnTensorDescriptor_t tensorDesc,
109
+ cudnnDataType_t dataType, /* image data type */
110
+ int n, /* number of inputs (batch size) */
111
+ int c, /* number of input feature maps */
112
+ int h, /* height of input section */
113
+ int w, /* width of input section */
114
+ int nStride,
115
+ int cStride,
116
+ int hStride,
117
+ int wStride);
118
+
119
+ cudnnStatus_t CUDNNWINAPI
120
+ cudnnGetTensor4dDescriptor(const cudnnTensorDescriptor_t tensorDesc,
121
+ cudnnDataType_t *dataType, /* image data type */
122
+ int *n, /* number of inputs (batch size) */
123
+ int *c, /* number of input feature maps */
124
+ int *h, /* height of input section */
125
+ int *w, /* width of input section */
126
+ int *nStride,
127
+ int *cStride,
128
+ int *hStride,
129
+ int *wStride);
130
+
131
+ cudnnStatus_t CUDNNWINAPI
132
+ cudnnSetTensorNdDescriptor(cudnnTensorDescriptor_t tensorDesc,
133
+ cudnnDataType_t dataType,
134
+ int nbDims,
135
+ const int dimA[],
136
+ const int strideA[]);
137
+
138
+ cudnnStatus_t CUDNNWINAPI
139
+ cudnnSetTensorNdDescriptorEx(cudnnTensorDescriptor_t tensorDesc,
140
+ cudnnTensorFormat_t format,
141
+ cudnnDataType_t dataType,
142
+ int nbDims,
143
+ const int dimA[]);
144
+
145
+ cudnnStatus_t CUDNNWINAPI
146
+ cudnnGetTensorNdDescriptor(const cudnnTensorDescriptor_t tensorDesc,
147
+ int nbDimsRequested,
148
+ cudnnDataType_t *dataType,
149
+ int *nbDims,
150
+ int dimA[],
151
+ int strideA[]);
152
+
153
+ cudnnStatus_t CUDNNWINAPI
154
+ cudnnGetTensorSizeInBytes(const cudnnTensorDescriptor_t tensorDesc, size_t *size);
155
+
156
+ /* PixelOffset( n, c, h, w ) = n *input_stride + c * feature_stride + h * h_stride + w * w_stride
157
+
158
+ 1)Example of all images in row major order one batch of features after the other (with an optional padding on row)
159
+ input_stride : c x h x h_stride
160
+ feature_stride : h x h_stride
161
+ h_stride : >= w ( h_stride = w if no padding)
162
+ w_stride : 1
163
+
164
+
165
+ 2)Example of all images in row major with features maps interleaved
166
+ input_stride : c x h x h_stride
167
+ feature_stride : 1
168
+ h_stride : w x c
169
+ w_stride : c
170
+
171
+ 3)Example of all images in column major order one batch of features after the other (with optional padding on column)
172
+ input_stride : c x w x w_stride
173
+ feature_stride : w x w_stride
174
+ h_stride : 1
175
+ w_stride : >= h
176
+
177
+ */
178
+
179
+ /* Destroy an instance of Tensor4d descriptor */
180
+ cudnnStatus_t CUDNNWINAPI
181
+ cudnnDestroyTensorDescriptor(cudnnTensorDescriptor_t tensorDesc);
182
+
183
+ /* Fold/unfold transforms */
184
+ typedef enum {
185
+ CUDNN_TRANSFORM_FOLD = 0U,
186
+ CUDNN_TRANSFORM_UNFOLD = 1U,
187
+ } cudnnFoldingDirection_t;
188
+
189
+ /** Create a destination descriptor for cudnnTransformTensor */
190
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
191
+ cudnnInitTransformDest(const cudnnTensorTransformDescriptor_t transformDesc,
192
+ const cudnnTensorDescriptor_t srcDesc,
193
+ cudnnTensorDescriptor_t destDesc,
194
+ size_t *destSizeInBytes);
195
+
196
+ /** Create an empty tensor transform descriptor */
197
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
198
+ cudnnCreateTensorTransformDescriptor(cudnnTensorTransformDescriptor_t *transformDesc);
199
+
200
+ /** Initialize a previously created tensor transform descriptor. */
201
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
202
+ cudnnSetTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc,
203
+ const uint32_t nbDims,
204
+ const cudnnTensorFormat_t destFormat,
205
+ const int32_t padBeforeA[],
206
+ const int32_t padAfterA[],
207
+ const uint32_t foldA[],
208
+ const cudnnFoldingDirection_t direction);
209
+
210
+ /**
211
+ * Retrieves the values stored in a previously initialized tensor transform
212
+ * descriptor.
213
+ */
214
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
215
+ cudnnGetTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc,
216
+ uint32_t nbDimsRequested,
217
+ cudnnTensorFormat_t *destFormat,
218
+ int32_t padBeforeA[],
219
+ int32_t padAfterA[],
220
+ uint32_t foldA[],
221
+ cudnnFoldingDirection_t *direction);
222
+
223
+ /**
224
+ * Destroys a previously created tensor transform descriptor.
225
+ */
226
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
227
+ cudnnDestroyTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc);
228
+
229
+ /* Tensor layout conversion helper (y = alpha * x + beta * y) */
230
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
231
+ cudnnTransformTensor(cudnnHandle_t handle,
232
+ const void *alpha,
233
+ const cudnnTensorDescriptor_t xDesc,
234
+ const void *x,
235
+ const void *beta,
236
+ const cudnnTensorDescriptor_t yDesc,
237
+ void *y);
238
+
239
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
240
+ cudnnTransformTensorEx(cudnnHandle_t handle,
241
+ const cudnnTensorTransformDescriptor_t transDesc,
242
+ const void *alpha,
243
+ const cudnnTensorDescriptor_t srcDesc,
244
+ const void *srcData,
245
+ const void *beta,
246
+ const cudnnTensorDescriptor_t destDesc,
247
+ void *destData);
248
+
249
+ /* Tensor Bias addition : C = alpha * A + beta * C */
250
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
251
+ cudnnAddTensor(cudnnHandle_t handle,
252
+ const void *alpha,
253
+ const cudnnTensorDescriptor_t aDesc,
254
+ const void *A,
255
+ const void *beta,
256
+ const cudnnTensorDescriptor_t cDesc,
257
+ void *C);
258
+
259
+ /*
260
+ * CUDNN OpTensor op type
261
+ */
262
+ typedef enum {
263
+ CUDNN_OP_TENSOR_ADD = 0,
264
+ CUDNN_OP_TENSOR_MUL = 1,
265
+ CUDNN_OP_TENSOR_MIN = 2,
266
+ CUDNN_OP_TENSOR_MAX = 3,
267
+ CUDNN_OP_TENSOR_SQRT = 4,
268
+ CUDNN_OP_TENSOR_NOT = 5,
269
+ } cudnnOpTensorOp_t;
270
+
271
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
272
+ cudnnCreateOpTensorDescriptor(cudnnOpTensorDescriptor_t *opTensorDesc);
273
+
274
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
275
+ cudnnSetOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc,
276
+ cudnnOpTensorOp_t opTensorOp,
277
+ cudnnDataType_t opTensorCompType,
278
+ cudnnNanPropagation_t opTensorNanOpt);
279
+
280
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
281
+ cudnnGetOpTensorDescriptor(const cudnnOpTensorDescriptor_t opTensorDesc,
282
+ cudnnOpTensorOp_t *opTensorOp,
283
+ cudnnDataType_t *opTensorCompType,
284
+ cudnnNanPropagation_t *opTensorNanOpt);
285
+
286
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
287
+ cudnnDestroyOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc);
288
+
289
+ /* Tensor operation : C = op( alpha1 * A, alpha2 * B ) + beta * C */
290
+ /* B tensor is ignored for CUDNN_OP_TENSOR_SQRT, CUDNN_OP_TENSOR_NOT. */
291
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
292
+ cudnnOpTensor(cudnnHandle_t handle,
293
+ const cudnnOpTensorDescriptor_t opTensorDesc,
294
+ const void *alpha1,
295
+ const cudnnTensorDescriptor_t aDesc,
296
+ const void *A,
297
+ const void *alpha2,
298
+ const cudnnTensorDescriptor_t bDesc,
299
+ const void *B,
300
+ const void *beta,
301
+ const cudnnTensorDescriptor_t cDesc,
302
+ void *C);
303
+
304
+ /*
305
+ * CUDNN ReduceTensor indices type
306
+ */
307
+ typedef enum {
308
+ CUDNN_REDUCE_TENSOR_NO_INDICES = 0,
309
+ CUDNN_REDUCE_TENSOR_FLATTENED_INDICES = 1,
310
+ } cudnnReduceTensorIndices_t CUDNN_DEPRECATED;
311
+
312
+ /*
313
+ * CUDNN tensor indices type size (all unsigned)
314
+ * Currently not supported, default is 32 bit unsigned.
315
+ */
316
+ typedef enum {
317
+ CUDNN_32BIT_INDICES = 0,
318
+ CUDNN_64BIT_INDICES = 1,
319
+ CUDNN_16BIT_INDICES = 2,
320
+ CUDNN_8BIT_INDICES = 3,
321
+ } cudnnIndicesType_t CUDNN_DEPRECATED;
322
+
323
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
324
+ cudnnCreateReduceTensorDescriptor(cudnnReduceTensorDescriptor_t *reduceTensorDesc);
325
+
326
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
327
+ cudnnSetReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc,
328
+ cudnnReduceTensorOp_t reduceTensorOp,
329
+ cudnnDataType_t reduceTensorCompType,
330
+ cudnnNanPropagation_t reduceTensorNanOpt,
331
+ cudnnReduceTensorIndices_t reduceTensorIndices,
332
+ cudnnIndicesType_t reduceTensorIndicesType);
333
+
334
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
335
+ cudnnGetReduceTensorDescriptor(const cudnnReduceTensorDescriptor_t reduceTensorDesc,
336
+ cudnnReduceTensorOp_t *reduceTensorOp,
337
+ cudnnDataType_t *reduceTensorCompType,
338
+ cudnnNanPropagation_t *reduceTensorNanOpt,
339
+ cudnnReduceTensorIndices_t *reduceTensorIndices,
340
+ cudnnIndicesType_t *reduceTensorIndicesType);
341
+
342
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
343
+ cudnnDestroyReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc);
344
+
345
+ /* Helper function to return the minimum size of the index space to be passed to the reduction given the input and
346
+ * output tensors */
347
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
348
+ cudnnGetReductionIndicesSize(cudnnHandle_t handle,
349
+ const cudnnReduceTensorDescriptor_t reduceTensorDesc,
350
+ const cudnnTensorDescriptor_t aDesc,
351
+ const cudnnTensorDescriptor_t cDesc,
352
+ size_t *sizeInBytes);
353
+
354
+ /* Helper function to return the minimum size of the workspace to be passed to the reduction given the input and output
355
+ * tensors */
356
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
357
+ cudnnGetReductionWorkspaceSize(cudnnHandle_t handle,
358
+ const cudnnReduceTensorDescriptor_t reduceTensorDesc,
359
+ const cudnnTensorDescriptor_t aDesc,
360
+ const cudnnTensorDescriptor_t cDesc,
361
+ size_t *sizeInBytes);
362
+
363
+ /* Tensor operation : C = reduce op( alpha * A ) + beta * C */
364
+ /* The NaN propagation enum applies to only the min and max reduce ops; the other reduce ops propagate NaN as usual. */
365
+ /* The indices space is ignored for reduce ops other than min or max. */
366
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
367
+ cudnnReduceTensor(cudnnHandle_t handle,
368
+ const cudnnReduceTensorDescriptor_t reduceTensorDesc,
369
+ void *indices,
370
+ size_t indicesSizeInBytes,
371
+ void *workspace,
372
+ size_t workspaceSizeInBytes,
373
+ const void *alpha,
374
+ const cudnnTensorDescriptor_t aDesc,
375
+ const void *A,
376
+ const void *beta,
377
+ const cudnnTensorDescriptor_t cDesc,
378
+ void *C);
379
+
380
+ /* Set all values of a tensor to a given value : y[i] = value[0] */
381
+ cudnnStatus_t CUDNNWINAPI
382
+ cudnnSetTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *valuePtr);
383
+
384
+ /* Scale all values of a tensor by a given factor : y[i] = alpha * y[i] */
385
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
386
+ cudnnScaleTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *alpha);
387
+
388
+ /* Create an instance of FilterStruct */
389
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
390
+ cudnnCreateFilterDescriptor(cudnnFilterDescriptor_t *filterDesc);
391
+
392
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
393
+ cudnnSetFilter4dDescriptor(cudnnFilterDescriptor_t filterDesc,
394
+ cudnnDataType_t dataType, /* image data type */
395
+ cudnnTensorFormat_t format,
396
+ int k, /* number of output feature maps */
397
+ int c, /* number of input feature maps */
398
+ int h, /* height of each input filter */
399
+ int w); /* width of each input filter */
400
+
401
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
402
+ cudnnGetFilter4dDescriptor(const cudnnFilterDescriptor_t filterDesc,
403
+ cudnnDataType_t *dataType, /* image data type */
404
+ cudnnTensorFormat_t *format,
405
+ int *k, /* number of output feature maps */
406
+ int *c, /* number of input feature maps */
407
+ int *h, /* height of each input filter */
408
+ int *w); /* width of each input filter */
409
+
410
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
411
+ cudnnSetFilterNdDescriptor(cudnnFilterDescriptor_t filterDesc,
412
+ cudnnDataType_t dataType, /* image data type */
413
+ cudnnTensorFormat_t format,
414
+ int nbDims,
415
+ const int filterDimA[]);
416
+
417
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
418
+ cudnnGetFilterNdDescriptor(const cudnnFilterDescriptor_t filterDesc,
419
+ int nbDimsRequested,
420
+ cudnnDataType_t *dataType, /* image data type */
421
+ cudnnTensorFormat_t *format,
422
+ int *nbDims,
423
+ int filterDimA[]);
424
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
425
+ cudnnGetFilterSizeInBytes(const cudnnFilterDescriptor_t filterDesc, size_t *size);
426
+
427
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
428
+ cudnnTransformFilter(cudnnHandle_t handle,
429
+ const cudnnTensorTransformDescriptor_t transDesc,
430
+ const void *alpha,
431
+ const cudnnFilterDescriptor_t srcDesc,
432
+ const void *srcData,
433
+ const void *beta,
434
+ const cudnnFilterDescriptor_t destDesc,
435
+ void *destData);
436
+
437
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
438
+ cudnnDestroyFilterDescriptor(cudnnFilterDescriptor_t filterDesc);
439
+
440
+ /*
441
+ * softmax algorithm
442
+ */
443
+ typedef enum {
444
+ CUDNN_SOFTMAX_FAST = 0, /* straightforward implementation */
445
+ CUDNN_SOFTMAX_ACCURATE = 1, /* subtract max from every point to avoid overflow */
446
+ CUDNN_SOFTMAX_LOG = 2
447
+ } cudnnSoftmaxAlgorithm_t;
448
+
449
+ typedef enum {
450
+ CUDNN_SOFTMAX_MODE_INSTANCE = 0, /* compute the softmax over all C, H, W for each N */
451
+ CUDNN_SOFTMAX_MODE_CHANNEL = 1 /* compute the softmax over all C for each H, W, N */
452
+ } cudnnSoftmaxMode_t;
453
+
454
+ /* Softmax functions: All of the form "output = alpha * Op(inputs) + beta * output" */
455
+
456
+ /* Function to perform forward softmax */
457
+ cudnnStatus_t CUDNNWINAPI
458
+ cudnnSoftmaxForward(cudnnHandle_t handle,
459
+ cudnnSoftmaxAlgorithm_t algo,
460
+ cudnnSoftmaxMode_t mode,
461
+ const void *alpha,
462
+ const cudnnTensorDescriptor_t xDesc,
463
+ const void *x,
464
+ const void *beta,
465
+ const cudnnTensorDescriptor_t yDesc,
466
+ void *y);
467
+
468
+ /*
469
+ * pooling mode
470
+ */
471
+ typedef enum {
472
+ CUDNN_POOLING_MAX = 0,
473
+ CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING = 1, /* count for average includes padded values */
474
+ CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING = 2, /* count for average does not include padded values */
475
+ CUDNN_POOLING_MAX_DETERMINISTIC = 3
476
+ } cudnnPoolingMode_t CUDNN_DEPRECATED;
477
+
478
+ /* Create an instance of pooling descriptor */
479
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
480
+ cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc);
481
+
482
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
483
+ cudnnSetPooling2dDescriptor(cudnnPoolingDescriptor_t poolingDesc,
484
+ cudnnPoolingMode_t mode,
485
+ cudnnNanPropagation_t maxpoolingNanOpt,
486
+ int windowHeight,
487
+ int windowWidth,
488
+ int verticalPadding,
489
+ int horizontalPadding,
490
+ int verticalStride,
491
+ int horizontalStride);
492
+
493
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
494
+ cudnnGetPooling2dDescriptor(const cudnnPoolingDescriptor_t poolingDesc,
495
+ cudnnPoolingMode_t *mode,
496
+ cudnnNanPropagation_t *maxpoolingNanOpt,
497
+ int *windowHeight,
498
+ int *windowWidth,
499
+ int *verticalPadding,
500
+ int *horizontalPadding,
501
+ int *verticalStride,
502
+ int *horizontalStride);
503
+
504
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
505
+ cudnnSetPoolingNdDescriptor(cudnnPoolingDescriptor_t poolingDesc,
506
+ const cudnnPoolingMode_t mode,
507
+ const cudnnNanPropagation_t maxpoolingNanOpt,
508
+ int nbDims,
509
+ const int windowDimA[],
510
+ const int paddingA[],
511
+ const int strideA[]);
512
+
513
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
514
+ cudnnGetPoolingNdDescriptor(const cudnnPoolingDescriptor_t poolingDesc,
515
+ int nbDimsRequested,
516
+ cudnnPoolingMode_t *mode,
517
+ cudnnNanPropagation_t *maxpoolingNanOpt,
518
+ int *nbDims,
519
+ int windowDimA[],
520
+ int paddingA[],
521
+ int strideA[]);
522
+
523
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
524
+ cudnnGetPoolingNdForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc,
525
+ const cudnnTensorDescriptor_t inputTensorDesc,
526
+ int nbDims,
527
+ int outputTensorDimA[]);
528
+
529
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
530
+ cudnnGetPooling2dForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc,
531
+ const cudnnTensorDescriptor_t inputTensorDesc,
532
+ int *n,
533
+ int *c,
534
+ int *h,
535
+ int *w);
536
+
537
+ /* Destroy an instance of pooling descriptor */
538
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
539
+ cudnnDestroyPoolingDescriptor(cudnnPoolingDescriptor_t poolingDesc);
540
+
541
+ /* Pooling functions: All of the form "output = alpha * Op(inputs) + beta * output" */
542
+
543
+ /* Function to perform forward pooling */
544
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
545
+ cudnnPoolingForward(cudnnHandle_t handle,
546
+ const cudnnPoolingDescriptor_t poolingDesc,
547
+ const void *alpha,
548
+ const cudnnTensorDescriptor_t xDesc,
549
+ const void *x,
550
+ const void *beta,
551
+ const cudnnTensorDescriptor_t yDesc,
552
+ void *y);
553
+
554
+ /* Activation functions: All of the form "output = alpha * Op(inputs) + beta * output" */
555
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
556
+ cudnnCreateActivationDescriptor(cudnnActivationDescriptor_t *activationDesc);
557
+
558
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
559
+ cudnnSetActivationDescriptor(cudnnActivationDescriptor_t activationDesc,
560
+ cudnnActivationMode_t mode,
561
+ cudnnNanPropagation_t reluNanOpt,
562
+ double coef); /* ceiling for clipped RELU, alpha for ELU */
563
+
564
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
565
+ cudnnGetActivationDescriptor(const cudnnActivationDescriptor_t activationDesc,
566
+ cudnnActivationMode_t *mode,
567
+ cudnnNanPropagation_t *reluNanOpt,
568
+ double *coef); /* ceiling for clipped RELU, alpha for ELU */
569
+
570
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
571
+ cudnnSetActivationDescriptorSwishBeta(cudnnActivationDescriptor_t activationDesc, double swish_beta);
572
+
573
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
574
+ cudnnGetActivationDescriptorSwishBeta(cudnnActivationDescriptor_t activationDesc, double *swish_beta);
575
+
576
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
577
+ cudnnDestroyActivationDescriptor(cudnnActivationDescriptor_t activationDesc);
578
+
579
+ /* Function to perform forward activation */
580
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
581
+ cudnnActivationForward(cudnnHandle_t handle,
582
+ cudnnActivationDescriptor_t activationDesc,
583
+ const void *alpha,
584
+ const cudnnTensorDescriptor_t xDesc,
585
+ const void *x,
586
+ const void *beta,
587
+ const cudnnTensorDescriptor_t yDesc,
588
+ void *y);
589
+
590
+ /*
591
+ * Create an instance of LRN (Local Response Normalization) descriptor
592
+ * Uses lrnN=5, lrnAlpha=1e-4, lrnBeta=0.75, lrnK=2.0 as defaults from Krizhevsky'12 ImageNet paper
593
+ */
594
+ cudnnStatus_t CUDNNWINAPI
595
+ cudnnCreateLRNDescriptor(cudnnLRNDescriptor_t *normDesc);
596
+
597
+ #define CUDNN_LRN_MIN_N 1 /* minimum allowed lrnN */
598
+ #define CUDNN_LRN_MAX_N 16 /* maximum allowed lrnN */
599
+ #define CUDNN_LRN_MIN_K 1e-5 /* minimum allowed lrnK */
600
+ #define CUDNN_LRN_MIN_BETA 0.01 /* minimum allowed lrnBeta */
601
+
602
+ /* LRN layer mode */
603
+ typedef enum {
604
+ CUDNN_LRN_CROSS_CHANNEL_DIM1 = 0, /* Normalize across tensor's dimA[1] dimension */
605
+ } cudnnLRNMode_t;
606
+
607
+ /*
608
+ * Uses a window [center-lookBehind, center+lookAhead], where
609
+ * lookBehind = floor( (lrnN-1)/2 ), lookAhead = lrnN-lookBehind-1.
610
+ * Values of double parameters cast to tensor data type.
611
+ */
612
+ cudnnStatus_t CUDNNWINAPI
613
+ cudnnSetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned lrnN, double lrnAlpha, double lrnBeta, double lrnK);
614
+ /*
615
+ * Retrieve the settings currently stored in an LRN layer descriptor
616
+ * Any of the provided pointers can be NULL (no corresponding value will be returned)
617
+ */
618
+ cudnnStatus_t CUDNNWINAPI
619
+ cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned *lrnN, double *lrnAlpha, double *lrnBeta, double *lrnK);
620
+
621
+ /* Destroy an instance of LRN descriptor */
622
+ cudnnStatus_t CUDNNWINAPI
623
+ cudnnDestroyLRNDescriptor(cudnnLRNDescriptor_t lrnDesc);
624
+
625
+ /* LRN functions: output = alpha * normalize(x) + beta * old_y */
626
+
627
+ /* LRN cross-channel forward computation. Double parameters cast to tensor data type */
628
+ cudnnStatus_t CUDNNWINAPI
629
+ cudnnLRNCrossChannelForward(cudnnHandle_t handle,
630
+ cudnnLRNDescriptor_t normDesc,
631
+ cudnnLRNMode_t lrnMode,
632
+ const void *alpha,
633
+ const cudnnTensorDescriptor_t xDesc,
634
+ const void *x,
635
+ const void *beta,
636
+ const cudnnTensorDescriptor_t yDesc,
637
+ void *y);
638
+
639
+ typedef enum {
640
+ CUDNN_DIVNORM_PRECOMPUTED_MEANS = 0,
641
+ } cudnnDivNormMode_t;
642
+
643
+ /* LCN/divisive normalization functions: y = alpha * normalize(x) + beta * y */
644
+ cudnnStatus_t CUDNNWINAPI
645
+ cudnnDivisiveNormalizationForward(cudnnHandle_t handle,
646
+ cudnnLRNDescriptor_t normDesc,
647
+ cudnnDivNormMode_t mode,
648
+ const void *alpha,
649
+ const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */
650
+ const void *x,
651
+ const void *means, /* if NULL, means are assumed to be zero */
652
+ void *temp,
653
+ void *temp2,
654
+ const void *beta,
655
+ const cudnnTensorDescriptor_t yDesc,
656
+ void *y);
657
+
658
+ typedef enum {
659
+ /* bnScale, bnBias tensor dims are 1xCxHxWx.. (one value per CHW...-slice, normalized over N slice) */
660
+ CUDNN_BATCHNORM_PER_ACTIVATION = 0,
661
+
662
+ /* bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors) */
663
+ CUDNN_BATCHNORM_SPATIAL = 1,
664
+
665
+ /*
666
+ * bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors).
667
+ * May be faster than CUDNN_BATCHNORM_SPATIAL but imposes some limits on the range of values
668
+ */
669
+ CUDNN_BATCHNORM_SPATIAL_PERSISTENT = 2,
670
+ } cudnnBatchNormMode_t CUDNN_DEPRECATED;
671
+
672
+ #define CUDNN_BN_MIN_EPSILON 0.0 /* Minimum epsilon allowed to be used in the Batch Normalization formula */
673
+
674
+ /*
675
+ * Derives a tensor descriptor from layer data descriptor for BatchNormalization
676
+ * scale, invVariance, bnBias, bnScale tensors. Use this tensor desc for
677
+ * bnScaleBiasMeanVarDesc and bnScaleBiasDiffDesc in Batch Normalization forward and backward functions.
678
+ */
679
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
680
+ cudnnDeriveBNTensorDescriptor(cudnnTensorDescriptor_t derivedBnDesc,
681
+ const cudnnTensorDescriptor_t xDesc,
682
+ cudnnBatchNormMode_t mode);
683
+
684
+ typedef enum {
685
+ CUDNN_BATCHNORM_OPS_BN = 0, /* do batch normalization only */
686
+ CUDNN_BATCHNORM_OPS_BN_ACTIVATION = 1, /* do batchNorm, then activation */
687
+ CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION = 2, /* do batchNorm, then elemWiseAdd, then activation */
688
+ } cudnnBatchNormOps_t CUDNN_DEPRECATED;
689
+
690
+ /*
691
+ * Performs Batch Normalization during Inference:
692
+ * y[i] = bnScale[k]*(x[i]-estimatedMean[k])/sqrt(epsilon+estimatedVariance[k]) + bnBias[k]
693
+ * with bnScale, bnBias, runningMean, runningInvVariance tensors indexed
694
+ * according to spatial or per-activation mode. Refer to cudnnBatchNormalizationForwardTraining
695
+ * above for notes on function arguments.
696
+ */
697
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
698
+ cudnnBatchNormalizationForwardInference(cudnnHandle_t handle,
699
+ cudnnBatchNormMode_t mode,
700
+ const void *alpha, /* alpha[0] = result blend factor */
701
+ const void *beta, /* beta[0] = dest layer blend factor */
702
+ const cudnnTensorDescriptor_t xDesc,
703
+ const void *x, /* NxCxHxW */
704
+ const cudnnTensorDescriptor_t yDesc,
705
+ void *y, /* NxCxHxW */
706
+ const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
707
+ const void *bnScale,
708
+ const void *bnBias,
709
+ const void *estimatedMean,
710
+ const void *estimatedVariance,
711
+ double epsilon);
712
+
713
+ typedef enum {
714
+ /* bnScale, bnBias tensor dims are 1xCxHxWx.. (one value per CHW...-slice, normalized over N slice) */
715
+ CUDNN_NORM_PER_ACTIVATION = 0,
716
+
717
+ /* bnScale, bnBias tensor dims are 1xCx1x1 (one value per C-dim normalized over Nx1xHxW subtensors) */
718
+ CUDNN_NORM_PER_CHANNEL = 1,
719
+ } cudnnNormMode_t CUDNN_DEPRECATED;
720
+
721
+ typedef enum { CUDNN_NORM_ALGO_STANDARD = 0, CUDNN_NORM_ALGO_PERSIST = 1 } cudnnNormAlgo_t CUDNN_DEPRECATED;
722
+
723
+ /*
724
+ * Derives a tensor descriptor from layer data descriptor for Normalization
725
+ * scale, invVariance, bnBias, bnScale tensors. Use this tensor desc for
726
+ * normScaleBiasMeanVarDesc and normScaleBiasDiffDesc in Normalization forward and backward functions.
727
+ */
728
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
729
+ cudnnDeriveNormTensorDescriptor(cudnnTensorDescriptor_t derivedNormScaleBiasDesc,
730
+ cudnnTensorDescriptor_t derivedNormMeanVarDesc,
731
+ const cudnnTensorDescriptor_t xDesc,
732
+ cudnnNormMode_t mode,
733
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
734
+
735
+ typedef enum {
736
+ CUDNN_NORM_OPS_NORM = 0, /* do normalization only */
737
+ CUDNN_NORM_OPS_NORM_ACTIVATION = 1, /* do Norm, then activation */
738
+ CUDNN_NORM_OPS_NORM_ADD_ACTIVATION = 2, /* do Norm, then elemWiseAdd, then activation */
739
+ } cudnnNormOps_t CUDNN_DEPRECATED;
740
+
741
+ /*
742
+ * Performs Normalization during Inference:
743
+ * y[i] = normScale[k]*(x[i]-estimatedMean[k])/sqrt(epsilon+estimatedVariance[k]) + normBias[k]
744
+ * with normScale, normBias, runningMean, runningInvVariance tensors indexed
745
+ * according to per-channel or per-activation mode. Refer to cudnnNormalizationForwardTraining
746
+ * above for notes on function arguments.
747
+ */
748
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
749
+ cudnnNormalizationForwardInference(cudnnHandle_t handle,
750
+ cudnnNormMode_t mode,
751
+ cudnnNormOps_t normOps,
752
+ cudnnNormAlgo_t algo,
753
+ const void *alpha, /* alpha[0] = result blend factor */
754
+ const void *beta, /* beta[0] = dest layer blend factor */
755
+ const cudnnTensorDescriptor_t xDesc,
756
+ const void *x, /* NxCxHxW */
757
+ const cudnnTensorDescriptor_t normScaleBiasDesc,
758
+ const void *normScale,
759
+ const void *normBias,
760
+ const cudnnTensorDescriptor_t normMeanVarDesc,
761
+ const void *estimatedMean,
762
+ const void *estimatedVariance,
763
+ const cudnnTensorDescriptor_t zDesc,
764
+ const void *z,
765
+ cudnnActivationDescriptor_t activationDesc,
766
+ const cudnnTensorDescriptor_t yDesc,
767
+ void *y, /* NxCxHxW */
768
+ double epsilon,
769
+ int groupCnt); /* Place hold for future work*/
770
+
771
+ /* APIs for spatial transformer network*/
772
+ typedef enum {
773
+ CUDNN_SAMPLER_BILINEAR = 0,
774
+ } cudnnSamplerType_t;
775
+
776
+ cudnnStatus_t CUDNNWINAPI
777
+ cudnnCreateSpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t *stDesc);
778
+
779
+ cudnnStatus_t CUDNNWINAPI
780
+ cudnnSetSpatialTransformerNdDescriptor(cudnnSpatialTransformerDescriptor_t stDesc,
781
+ cudnnSamplerType_t samplerType,
782
+ cudnnDataType_t dataType,
783
+ const int nbDims,
784
+ const int dimA[]);
785
+
786
+ cudnnStatus_t CUDNNWINAPI
787
+ cudnnDestroySpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t stDesc);
788
+
789
+ cudnnStatus_t CUDNNWINAPI
790
+ cudnnSpatialTfGridGeneratorForward(cudnnHandle_t handle,
791
+ const cudnnSpatialTransformerDescriptor_t stDesc,
792
+ const void *theta,
793
+ void *grid);
794
+
795
+ cudnnStatus_t CUDNNWINAPI
796
+ cudnnSpatialTfSamplerForward(cudnnHandle_t handle,
797
+ cudnnSpatialTransformerDescriptor_t stDesc,
798
+ const void *alpha,
799
+ const cudnnTensorDescriptor_t xDesc,
800
+ const void *x,
801
+ const void *grid,
802
+ const void *beta,
803
+ cudnnTensorDescriptor_t yDesc,
804
+ void *y);
805
+
806
+ typedef struct cudnnDropoutStruct *cudnnDropoutDescriptor_t;
807
+
808
+ cudnnStatus_t CUDNNWINAPI
809
+ cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t *dropoutDesc);
810
+
811
+ cudnnStatus_t CUDNNWINAPI
812
+ cudnnDestroyDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc);
813
+
814
+ /*helper function to determine size of the states to be passed to cudnnSetDropoutDescriptor */
815
+ cudnnStatus_t CUDNNWINAPI
816
+ cudnnDropoutGetStatesSize(cudnnHandle_t handle, size_t *sizeInBytes);
817
+
818
+ /*helper function to determine size of the reserve space to be passed to dropout forward/backward calls */
819
+ cudnnStatus_t CUDNNWINAPI
820
+ cudnnDropoutGetReserveSpaceSize(cudnnTensorDescriptor_t xdesc, size_t *sizeInBytes);
821
+
822
+ cudnnStatus_t CUDNNWINAPI
823
+ cudnnSetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
824
+ cudnnHandle_t handle,
825
+ float dropout,
826
+ void *states,
827
+ size_t stateSizeInBytes,
828
+ unsigned long long seed);
829
+
830
+ /* Restores the dropout descriptor to a previously saved-off state */
831
+ cudnnStatus_t CUDNNWINAPI
832
+ cudnnRestoreDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
833
+ cudnnHandle_t handle,
834
+ float dropout,
835
+ void *states,
836
+ size_t stateSizeInBytes,
837
+ unsigned long long seed);
838
+
839
+ cudnnStatus_t CUDNNWINAPI
840
+ cudnnGetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc,
841
+ cudnnHandle_t handle,
842
+ float *dropout,
843
+ void **states,
844
+ unsigned long long *seed);
845
+
846
+ cudnnStatus_t CUDNNWINAPI
847
+ cudnnDropoutForward(cudnnHandle_t handle,
848
+ const cudnnDropoutDescriptor_t dropoutDesc,
849
+ const cudnnTensorDescriptor_t xdesc,
850
+ const void *x,
851
+ const cudnnTensorDescriptor_t ydesc,
852
+ void *y,
853
+ void *reserveSpace,
854
+ size_t reserveSpaceSizeInBytes);
855
+
856
+ /* TODO: move these enums out to the appropriate submodule */
857
+ typedef enum {
858
+ CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM = 0,
859
+ CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM = 1,
860
+ CUDNN_CONVOLUTION_FWD_ALGO_GEMM = 2,
861
+ CUDNN_CONVOLUTION_FWD_ALGO_DIRECT = 3,
862
+ CUDNN_CONVOLUTION_FWD_ALGO_FFT = 4,
863
+ CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING = 5,
864
+ CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD = 6,
865
+ CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED = 7,
866
+ CUDNN_CONVOLUTION_FWD_ALGO_COUNT = 8
867
+ } cudnnConvolutionFwdAlgo_t;
868
+
869
+ typedef enum {
870
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0 = 0, /* non-deterministic */
871
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 = 1,
872
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT = 2,
873
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3 = 3, /* non-deterministic */
874
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD = 4, /* not implemented */
875
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED = 5,
876
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING = 6,
877
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT = 7
878
+ } cudnnConvolutionBwdFilterAlgo_t;
879
+
880
+ typedef enum {
881
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 = 0, /* non-deterministic */
882
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 = 1,
883
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT = 2,
884
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING = 3,
885
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD = 4,
886
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED = 5,
887
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT = 6
888
+ } cudnnConvolutionBwdDataAlgo_t;
889
+
890
+ typedef enum { CUDNN_CTC_LOSS_ALGO_DETERMINISTIC = 0, CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC = 1 } cudnnCTCLossAlgo_t;
891
+
892
+ /*
893
+ * \brief Cross-library version checker.
894
+ * This function is implemented differently in each sub-library. Each sublib
895
+ * checks whether its own version matches that of its dependencies.
896
+ * \returns CUDNN_STATUS_SUCCESS if the version check passes,
897
+ * CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH if the versions are inconsistent.
898
+ */
899
+ cudnnStatus_t CUDNNWINAPI
900
+ cudnnOpsVersionCheck(void);
901
+
902
+ /* Function to perform backward softmax */
903
+ cudnnStatus_t CUDNNWINAPI
904
+ cudnnSoftmaxBackward(cudnnHandle_t handle,
905
+ cudnnSoftmaxAlgorithm_t algo,
906
+ cudnnSoftmaxMode_t mode,
907
+ const void *alpha,
908
+ const cudnnTensorDescriptor_t yDesc,
909
+ const void *y,
910
+ const cudnnTensorDescriptor_t dyDesc,
911
+ const void *dy,
912
+ const void *beta,
913
+ const cudnnTensorDescriptor_t dxDesc,
914
+ void *dx);
915
+
916
+ /* Function to perform backward pooling */
917
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
918
+ cudnnPoolingBackward(cudnnHandle_t handle,
919
+ const cudnnPoolingDescriptor_t poolingDesc,
920
+ const void *alpha,
921
+ const cudnnTensorDescriptor_t yDesc,
922
+ const void *y,
923
+ const cudnnTensorDescriptor_t dyDesc,
924
+ const void *dy,
925
+ const cudnnTensorDescriptor_t xDesc,
926
+ const void *x,
927
+ const void *beta,
928
+ const cudnnTensorDescriptor_t dxDesc,
929
+ void *dx);
930
+
931
+ /* Function to perform backward activation */
932
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
933
+ cudnnActivationBackward(cudnnHandle_t handle,
934
+ cudnnActivationDescriptor_t activationDesc,
935
+ const void *alpha,
936
+ const cudnnTensorDescriptor_t yDesc,
937
+ const void *y,
938
+ const cudnnTensorDescriptor_t dyDesc,
939
+ const void *dy,
940
+ const cudnnTensorDescriptor_t xDesc,
941
+ const void *x,
942
+ const void *beta,
943
+ const cudnnTensorDescriptor_t dxDesc,
944
+ void *dx);
945
+
946
+ /* LRN cross-channel backward computation. Double parameters cast to tensor data type */
947
+ cudnnStatus_t CUDNNWINAPI
948
+ cudnnLRNCrossChannelBackward(cudnnHandle_t handle,
949
+ cudnnLRNDescriptor_t normDesc,
950
+ cudnnLRNMode_t lrnMode,
951
+ const void *alpha,
952
+ const cudnnTensorDescriptor_t yDesc,
953
+ const void *y,
954
+ const cudnnTensorDescriptor_t dyDesc,
955
+ const void *dy,
956
+ const cudnnTensorDescriptor_t xDesc,
957
+ const void *x,
958
+ const void *beta,
959
+ const cudnnTensorDescriptor_t dxDesc,
960
+ void *dx);
961
+
962
+ cudnnStatus_t CUDNNWINAPI
963
+ cudnnDivisiveNormalizationBackward(cudnnHandle_t handle,
964
+ cudnnLRNDescriptor_t normDesc,
965
+ cudnnDivNormMode_t mode,
966
+ const void *alpha,
967
+ const cudnnTensorDescriptor_t xDesc, /* same desc for x, means, dy, temp, temp2 */
968
+ const void *x,
969
+ const void *means, /* if NULL, means are assumed to be zero */
970
+ const void *dy,
971
+ void *temp,
972
+ void *temp2,
973
+ const void *beta,
974
+ const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */
975
+ void *dx, /* output x differential */
976
+ void *dMeans); /* output means differential, can be NULL */
977
+
978
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
979
+ cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(cudnnHandle_t handle,
980
+ cudnnBatchNormMode_t mode,
981
+ cudnnBatchNormOps_t bnOps,
982
+ const cudnnTensorDescriptor_t xDesc,
983
+ const cudnnTensorDescriptor_t zDesc,
984
+ const cudnnTensorDescriptor_t yDesc,
985
+ const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
986
+ const cudnnActivationDescriptor_t activationDesc,
987
+ size_t *sizeInBytes);
988
+
989
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
990
+ cudnnGetBatchNormalizationBackwardExWorkspaceSize(cudnnHandle_t handle,
991
+ cudnnBatchNormMode_t mode,
992
+ cudnnBatchNormOps_t bnOps,
993
+ const cudnnTensorDescriptor_t xDesc,
994
+ const cudnnTensorDescriptor_t yDesc,
995
+ const cudnnTensorDescriptor_t dyDesc,
996
+ const cudnnTensorDescriptor_t dzDesc,
997
+ const cudnnTensorDescriptor_t dxDesc,
998
+ const cudnnTensorDescriptor_t dBnScaleBiasDesc,
999
+ const cudnnActivationDescriptor_t activationDesc,
1000
+ size_t *sizeInBytes);
1001
+
1002
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1003
+ cudnnGetBatchNormalizationTrainingExReserveSpaceSize(cudnnHandle_t handle,
1004
+ cudnnBatchNormMode_t mode,
1005
+ cudnnBatchNormOps_t bnOps,
1006
+ const cudnnActivationDescriptor_t activationDesc,
1007
+ const cudnnTensorDescriptor_t xDesc,
1008
+ size_t *sizeInBytes);
1009
+
1010
+ /* Computes y = BN(x). Also accumulates moving averages of mean and inverse variances */
1011
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1012
+ cudnnBatchNormalizationForwardTraining(
1013
+ cudnnHandle_t handle,
1014
+ cudnnBatchNormMode_t mode,
1015
+
1016
+ const void *alpha, /* alpha[0] = result blend factor */
1017
+ const void *beta, /* beta[0] = dest layer blend factor */
1018
+
1019
+ const cudnnTensorDescriptor_t xDesc,
1020
+ const void *x, /* NxCxHxW */
1021
+ const cudnnTensorDescriptor_t yDesc,
1022
+ void *y, /* NxCxHxW */
1023
+
1024
+ /* Shared desc for the next 6 tensors in the argument list.
1025
+ Data type to be set as follows:
1026
+ type = (typeOf(x) == double) ? double : float
1027
+ Dimensions for this descriptor depend on normalization mode
1028
+ - Spatial Normalization : tensors are expected to have dims 1xCx1x1
1029
+ (normalization is performed across NxHxW)
1030
+ - Per-Activation Normalization : tensors are expected to have dims of 1xCxHxW
1031
+ (normalization is performed across N) */
1032
+ const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
1033
+
1034
+ /* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation */
1035
+ const void *bnScale,
1036
+ const void *bnBias,
1037
+
1038
+ /* MUST use factor=1 in the very first call of a complete training cycle.
1039
+ Use a factor=1/(1+n) at N-th call to the function to get
1040
+ Cumulative Moving Average (CMA) behavior
1041
+ CMA[n] = (x[1]+...+x[n])/n
1042
+ Since CMA[n+1] = (n*CMA[n]+x[n+1])/(n+1) =
1043
+ ((n+1)*CMA[n]-CMA[n])/(n+1) + x[n+1]/(n+1) =
1044
+ CMA[n]*(1-1/(n+1)) + x[n+1]*1/(n+1) */
1045
+ double exponentialAverageFactor,
1046
+
1047
+ /* Used in Training phase only.
1048
+ runningMean = newMean*factor + runningMean*(1-factor) */
1049
+ void *resultRunningMean,
1050
+ /* Output in training mode, input in inference. Is the moving average
1051
+ of variance[x] (factor is applied in the same way as for runningMean) */
1052
+ void *resultRunningVariance,
1053
+
1054
+ /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */
1055
+ double epsilon,
1056
+
1057
+ /* Optionally save intermediate results from the forward pass here
1058
+ - can be reused to speed up backward pass. NULL if unused */
1059
+ void *resultSaveMean,
1060
+ void *resultSaveInvVariance);
1061
+
1062
+ /* Computes y = relu(BN(x) + z). Also accumulates moving averages of mean and inverse variances */
1063
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1064
+ cudnnBatchNormalizationForwardTrainingEx(
1065
+ cudnnHandle_t handle,
1066
+ cudnnBatchNormMode_t mode,
1067
+ cudnnBatchNormOps_t bnOps,
1068
+
1069
+ const void *alpha, /* alpha[0] = result blend factor */
1070
+ const void *beta, /* beta[0] = dest layer blend factor */
1071
+
1072
+ const cudnnTensorDescriptor_t xDesc,
1073
+ const void *xData,
1074
+ const cudnnTensorDescriptor_t zDesc,
1075
+ const void *zData,
1076
+ const cudnnTensorDescriptor_t yDesc,
1077
+ void *yData,
1078
+
1079
+ const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
1080
+ const void *bnScale,
1081
+ const void *bnBias,
1082
+
1083
+ double exponentialAverageFactor,
1084
+ void *resultRunningMean,
1085
+ void *resultRunningVariance,
1086
+
1087
+ /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */
1088
+ double epsilon,
1089
+
1090
+ /* Optionally save intermediate results from the forward pass here
1091
+ - can be reused to speed up backward pass. NULL if unused */
1092
+ void *resultSaveMean,
1093
+ void *resultSaveInvVariance,
1094
+
1095
+ cudnnActivationDescriptor_t activationDesc,
1096
+ void *workspace,
1097
+ size_t workSpaceSizeInBytes,
1098
+ void *reserveSpace,
1099
+ size_t reserveSpaceSizeInBytes);
1100
+
1101
+ /* Performs backward pass of Batch Normalization layer. Returns x gradient,
1102
+ * bnScale gradient and bnBias gradient */
1103
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1104
+ cudnnBatchNormalizationBackward(cudnnHandle_t handle,
1105
+ cudnnBatchNormMode_t mode,
1106
+ const void *alphaDataDiff,
1107
+ const void *betaDataDiff,
1108
+ const void *alphaParamDiff,
1109
+ const void *betaParamDiff,
1110
+ const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */
1111
+ const void *x,
1112
+ const cudnnTensorDescriptor_t dyDesc,
1113
+ const void *dy,
1114
+ const cudnnTensorDescriptor_t dxDesc,
1115
+ void *dx,
1116
+ /* Shared tensor desc for the 4 tensors below */
1117
+ const cudnnTensorDescriptor_t dBnScaleBiasDesc,
1118
+ const void *bnScale, /* bnBias doesn't affect backpropagation */
1119
+ /* scale and bias diff are not backpropagated below this layer */
1120
+ void *dBnScaleResult,
1121
+ void *dBnBiasResult,
1122
+ /* Same epsilon as forward pass */
1123
+ double epsilon,
1124
+
1125
+ /* Optionally cached intermediate results from
1126
+ forward pass */
1127
+ const void *savedMean,
1128
+ const void *savedInvVariance);
1129
+
1130
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1131
+ cudnnBatchNormalizationBackwardEx(cudnnHandle_t handle,
1132
+ cudnnBatchNormMode_t mode,
1133
+ cudnnBatchNormOps_t bnOps,
1134
+
1135
+ const void *alphaDataDiff,
1136
+ const void *betaDataDiff,
1137
+ const void *alphaParamDiff,
1138
+ const void *betaParamDiff,
1139
+ const cudnnTensorDescriptor_t xDesc,
1140
+ const void *xData,
1141
+ const cudnnTensorDescriptor_t yDesc,
1142
+ const void *yData,
1143
+ const cudnnTensorDescriptor_t dyDesc,
1144
+ const void *dyData,
1145
+ const cudnnTensorDescriptor_t dzDesc,
1146
+ void *dzData,
1147
+ const cudnnTensorDescriptor_t dxDesc,
1148
+ void *dxData,
1149
+
1150
+ /* Shared tensor desc for the 4 tensors below */
1151
+ const cudnnTensorDescriptor_t dBnScaleBiasDesc,
1152
+ const void *bnScaleData,
1153
+ const void *bnBiasData, /* needed if there is activation */
1154
+ void *dBnScaleData,
1155
+ void *dBnBiasData,
1156
+ double epsilon, /* Same epsilon as forward pass */
1157
+
1158
+ /* Optionally cached intermediate results from
1159
+ forward pass */
1160
+ const void *savedMean,
1161
+ const void *savedInvVariance,
1162
+ cudnnActivationDescriptor_t activationDesc,
1163
+ void *workSpace,
1164
+ size_t workSpaceSizeInBytes,
1165
+ void *reserveSpace,
1166
+ size_t reserveSpaceSizeInBytes);
1167
+
1168
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1169
+ cudnnGetNormalizationForwardTrainingWorkspaceSize(cudnnHandle_t handle,
1170
+ cudnnNormMode_t mode,
1171
+ cudnnNormOps_t normOps,
1172
+ cudnnNormAlgo_t algo,
1173
+ const cudnnTensorDescriptor_t xDesc,
1174
+ const cudnnTensorDescriptor_t zDesc,
1175
+ const cudnnTensorDescriptor_t yDesc,
1176
+ const cudnnTensorDescriptor_t normScaleBiasDesc,
1177
+ const cudnnActivationDescriptor_t activationDesc,
1178
+ const cudnnTensorDescriptor_t normMeanVarDesc,
1179
+ size_t *sizeInBytes,
1180
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
1181
+
1182
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1183
+ cudnnGetNormalizationBackwardWorkspaceSize(cudnnHandle_t handle,
1184
+ cudnnNormMode_t mode,
1185
+ cudnnNormOps_t normOps,
1186
+ cudnnNormAlgo_t algo,
1187
+ const cudnnTensorDescriptor_t xDesc,
1188
+ const cudnnTensorDescriptor_t yDesc,
1189
+ const cudnnTensorDescriptor_t dyDesc,
1190
+ const cudnnTensorDescriptor_t dzDesc,
1191
+ const cudnnTensorDescriptor_t dxDesc,
1192
+ const cudnnTensorDescriptor_t dNormScaleBiasDesc,
1193
+ const cudnnActivationDescriptor_t activationDesc,
1194
+ const cudnnTensorDescriptor_t normMeanVarDesc,
1195
+ size_t *sizeInBytes,
1196
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
1197
+
1198
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1199
+ cudnnGetNormalizationTrainingReserveSpaceSize(cudnnHandle_t handle,
1200
+ cudnnNormMode_t mode,
1201
+ cudnnNormOps_t normOps,
1202
+ cudnnNormAlgo_t algo,
1203
+ const cudnnActivationDescriptor_t activationDesc,
1204
+ const cudnnTensorDescriptor_t xDesc,
1205
+ size_t *sizeInBytes,
1206
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
1207
+
1208
+ /* Computes y = relu(Norm(x) + z). Also accumulates moving averages of mean and inverse variances */
1209
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1210
+ cudnnNormalizationForwardTraining(cudnnHandle_t handle,
1211
+ cudnnNormMode_t mode,
1212
+ cudnnNormOps_t normOps,
1213
+ cudnnNormAlgo_t algo,
1214
+ const void *alpha, /* alpha[0] = result blend factor */
1215
+ const void *beta, /* beta[0] = dest layer blend factor */
1216
+ const cudnnTensorDescriptor_t xDesc,
1217
+ const void *xData,
1218
+ const cudnnTensorDescriptor_t normScaleBiasDesc,
1219
+ const void *normScale,
1220
+ const void *normBias,
1221
+ double exponentialAverageFactor,
1222
+ const cudnnTensorDescriptor_t normMeanVarDesc,
1223
+ void *resultRunningMean,
1224
+ void *resultRunningVariance,
1225
+ /* Has to be >= 0. Should be the same in forward and backward functions. */
1226
+ double epsilon,
1227
+ /* Optionally save intermediate results from the forward pass here
1228
+ - can be reused to speed up backward pass. NULL if unused */
1229
+ void *resultSaveMean,
1230
+ void *resultSaveInvVariance,
1231
+ cudnnActivationDescriptor_t activationDesc,
1232
+ const cudnnTensorDescriptor_t zDesc,
1233
+ const void *zData,
1234
+ const cudnnTensorDescriptor_t yDesc,
1235
+ void *yData,
1236
+ void *workspace,
1237
+ size_t workSpaceSizeInBytes,
1238
+ void *reserveSpace,
1239
+ size_t reserveSpaceSizeInBytes,
1240
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
1241
+
1242
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
1243
+ cudnnNormalizationBackward(cudnnHandle_t handle,
1244
+ cudnnNormMode_t mode,
1245
+ cudnnNormOps_t normOps,
1246
+ cudnnNormAlgo_t algo,
1247
+ const void *alphaDataDiff,
1248
+ const void *betaDataDiff,
1249
+ const void *alphaParamDiff,
1250
+ const void *betaParamDiff,
1251
+ const cudnnTensorDescriptor_t xDesc,
1252
+ const void *xData,
1253
+ const cudnnTensorDescriptor_t yDesc,
1254
+ const void *yData,
1255
+ const cudnnTensorDescriptor_t dyDesc,
1256
+ const void *dyData,
1257
+ const cudnnTensorDescriptor_t dzDesc,
1258
+ void *dzData,
1259
+ const cudnnTensorDescriptor_t dxDesc,
1260
+ void *dxData,
1261
+ /* Shared tensor desc for the 4 tensors below */
1262
+ const cudnnTensorDescriptor_t dNormScaleBiasDesc,
1263
+ const void *normScaleData,
1264
+ const void *normBiasData, /* needed if there is activation */
1265
+ void *dNormScaleData,
1266
+ void *dNormBiasData,
1267
+ double epsilon, /* Same epsilon as forward pass */
1268
+ const cudnnTensorDescriptor_t normMeanVarDesc,
1269
+ /* Optionally cached intermediate results from
1270
+ forward pass */
1271
+ const void *savedMean,
1272
+ const void *savedInvVariance,
1273
+ cudnnActivationDescriptor_t activationDesc,
1274
+ void *workSpace,
1275
+ size_t workSpaceSizeInBytes,
1276
+ void *reserveSpace,
1277
+ size_t reserveSpaceSizeInBytes,
1278
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
1279
+
1280
+ cudnnStatus_t CUDNNWINAPI
1281
+ cudnnSpatialTfGridGeneratorBackward(cudnnHandle_t handle,
1282
+ const cudnnSpatialTransformerDescriptor_t stDesc,
1283
+ const void *dgrid,
1284
+ void *dtheta);
1285
+
1286
+ cudnnStatus_t CUDNNWINAPI
1287
+ cudnnSpatialTfSamplerBackward(cudnnHandle_t handle,
1288
+ cudnnSpatialTransformerDescriptor_t stDesc,
1289
+ const void *alpha,
1290
+ const cudnnTensorDescriptor_t xDesc,
1291
+ const void *x,
1292
+ const void *beta,
1293
+ const cudnnTensorDescriptor_t dxDesc,
1294
+ void *dx,
1295
+ const void *alphaDgrid,
1296
+ const cudnnTensorDescriptor_t dyDesc,
1297
+ const void *dy,
1298
+ const void *grid,
1299
+ const void *betaDgrid,
1300
+ void *dgrid);
1301
+
1302
+ cudnnStatus_t CUDNNWINAPI
1303
+ cudnnDropoutBackward(cudnnHandle_t handle,
1304
+ const cudnnDropoutDescriptor_t dropoutDesc,
1305
+ const cudnnTensorDescriptor_t dydesc,
1306
+ const void *dy,
1307
+ const cudnnTensorDescriptor_t dxdesc,
1308
+ void *dx,
1309
+ void *reserveSpace,
1310
+ size_t reserveSpaceSizeInBytes);
1311
+
1312
+ #if defined(__cplusplus)
1313
+ }
1314
+ #endif
1315
+
1316
+ #endif /* CUDNN_OPS_H_ */
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_v9.h ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /* cudnn : Neural Networks Library */
51
+
52
+ #if !defined(CUDNN_H_)
53
+ #define CUDNN_H_
54
+ #if defined(__cplusplus)
55
+ extern "C" {
56
+ #endif
57
+
58
+ #include <cuda_runtime_api.h>
59
+ #include "cudnn_version.h"
60
+ #include "cudnn_graph.h"
61
+ #include "cudnn_ops.h"
62
+ #include "cudnn_adv.h"
63
+ #include "cudnn_cnn.h"
64
+
65
+ #if defined(__cplusplus)
66
+ }
67
+ #endif
68
+ #endif /* CUDNN_H_ */
.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_version.h ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2014-2023 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /**
51
+ * \file: The master cuDNN version file.
52
+ */
53
+
54
+ #ifndef CUDNN_VERSION_H_
55
+ #define CUDNN_VERSION_H_
56
+
57
+ #define CUDNN_MAJOR 9
58
+ #define CUDNN_MINOR 1
59
+ #define CUDNN_PATCHLEVEL 0
60
+
61
+ #define CUDNN_VERSION (CUDNN_MAJOR * 10000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
62
+
63
+ /* cannot use constexpr here since this is a C-only file */
64
+ /* Below is the max SM version this cuDNN library is aware of and supports natively */
65
+
66
+ #define CUDNN_MAX_SM_MAJOR_NUMBER 9
67
+ #define CUDNN_MAX_SM_MINOR_NUMBER 0
68
+ #define CUDNN_MAX_DEVICE_VERSION (CUDNN_MAX_SM_MAJOR_NUMBER * 100 + CUDNN_MAX_SM_MINOR_NUMBER * 10)
69
+
70
+ #endif /* CUDNN_VERSION_H */
.venv/lib/python3.11/site-packages/opencv_python_headless.libs/libavutil-734d06dd.so.57.28.100 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b207c6d08f72e7bc047e3d17fac1a9f23c61059ffd30cba8040489e8ad79ae33
3
+ size 844673
.venv/lib/python3.11/site-packages/opencv_python_headless.libs/libssl-28bef1ac.so.1.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf90306b659880ae21114a928d1a3421c66715f16e49a824fb9ee55113d2b767
3
+ size 736177
.venv/lib/python3.11/site-packages/pyasn1/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (199 Bytes). View file
 
.venv/lib/python3.11/site-packages/pyasn1/__pycache__/debug.cpython-311.pyc ADDED
Binary file (6.88 kB). View file
 
.venv/lib/python3.11/site-packages/pyasn1/codec/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # This file is necessary to make this directory a package.
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # This file is necessary to make this directory a package.
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (189 Bytes). View file
 
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/decoder.cpython-311.pyc ADDED
Binary file (79.4 kB). View file
 
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/encoder.cpython-311.pyc ADDED
Binary file (34.5 kB). View file
 
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/__pycache__/eoo.cpython-311.pyc ADDED
Binary file (1.15 kB). View file
 
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/decoder.py ADDED
@@ -0,0 +1,2189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # This file is part of pyasn1 software.
3
+ #
4
+ # Copyright (c) 2005-2020, Ilya Etingof <etingof@gmail.com>
5
+ # License: https://pyasn1.readthedocs.io/en/latest/license.html
6
+ #
7
+ import io
8
+ import os
9
+ import sys
10
+ import warnings
11
+
12
+ from pyasn1 import debug
13
+ from pyasn1 import error
14
+ from pyasn1.codec.ber import eoo
15
+ from pyasn1.codec.streaming import asSeekableStream
16
+ from pyasn1.codec.streaming import isEndOfStream
17
+ from pyasn1.codec.streaming import peekIntoStream
18
+ from pyasn1.codec.streaming import readFromStream
19
+ from pyasn1.compat import _MISSING
20
+ from pyasn1.error import PyAsn1Error
21
+ from pyasn1.type import base
22
+ from pyasn1.type import char
23
+ from pyasn1.type import tag
24
+ from pyasn1.type import tagmap
25
+ from pyasn1.type import univ
26
+ from pyasn1.type import useful
27
+
28
+ __all__ = ['StreamingDecoder', 'Decoder', 'decode']
29
+
30
+ LOG = debug.registerLoggee(__name__, flags=debug.DEBUG_DECODER)
31
+
32
+ noValue = base.noValue
33
+
34
+ SubstrateUnderrunError = error.SubstrateUnderrunError
35
+
36
+
37
+ class AbstractPayloadDecoder(object):
38
+ protoComponent = None
39
+
40
+ def valueDecoder(self, substrate, asn1Spec,
41
+ tagSet=None, length=None, state=None,
42
+ decodeFun=None, substrateFun=None,
43
+ **options):
44
+ """Decode value with fixed byte length.
45
+
46
+ The decoder is allowed to consume as many bytes as necessary.
47
+ """
48
+ raise error.PyAsn1Error('SingleItemDecoder not implemented for %s' % (tagSet,)) # TODO: Seems more like an NotImplementedError?
49
+
50
+ def indefLenValueDecoder(self, substrate, asn1Spec,
51
+ tagSet=None, length=None, state=None,
52
+ decodeFun=None, substrateFun=None,
53
+ **options):
54
+ """Decode value with undefined length.
55
+
56
+ The decoder is allowed to consume as many bytes as necessary.
57
+ """
58
+ raise error.PyAsn1Error('Indefinite length mode decoder not implemented for %s' % (tagSet,)) # TODO: Seems more like an NotImplementedError?
59
+
60
+ @staticmethod
61
+ def _passAsn1Object(asn1Object, options):
62
+ if 'asn1Object' not in options:
63
+ options['asn1Object'] = asn1Object
64
+
65
+ return options
66
+
67
+
68
+ class AbstractSimplePayloadDecoder(AbstractPayloadDecoder):
69
+ @staticmethod
70
+ def substrateCollector(asn1Object, substrate, length, options):
71
+ for chunk in readFromStream(substrate, length, options):
72
+ yield chunk
73
+
74
+ def _createComponent(self, asn1Spec, tagSet, value, **options):
75
+ if options.get('native'):
76
+ return value
77
+ elif asn1Spec is None:
78
+ return self.protoComponent.clone(value, tagSet=tagSet)
79
+ elif value is noValue:
80
+ return asn1Spec
81
+ else:
82
+ return asn1Spec.clone(value)
83
+
84
+
85
+ class RawPayloadDecoder(AbstractSimplePayloadDecoder):
86
+ protoComponent = univ.Any('')
87
+
88
+ def valueDecoder(self, substrate, asn1Spec,
89
+ tagSet=None, length=None, state=None,
90
+ decodeFun=None, substrateFun=None,
91
+ **options):
92
+ if substrateFun:
93
+ asn1Object = self._createComponent(asn1Spec, tagSet, '', **options)
94
+
95
+ for chunk in substrateFun(asn1Object, substrate, length, options):
96
+ yield chunk
97
+
98
+ return
99
+
100
+ for value in decodeFun(substrate, asn1Spec, tagSet, length, **options):
101
+ yield value
102
+
103
+ def indefLenValueDecoder(self, substrate, asn1Spec,
104
+ tagSet=None, length=None, state=None,
105
+ decodeFun=None, substrateFun=None,
106
+ **options):
107
+ if substrateFun:
108
+ asn1Object = self._createComponent(asn1Spec, tagSet, '', **options)
109
+
110
+ for chunk in substrateFun(asn1Object, substrate, length, options):
111
+ yield chunk
112
+
113
+ return
114
+
115
+ while True:
116
+ for value in decodeFun(
117
+ substrate, asn1Spec, tagSet, length,
118
+ allowEoo=True, **options):
119
+
120
+ if value is eoo.endOfOctets:
121
+ return
122
+
123
+ yield value
124
+
125
+
126
+ rawPayloadDecoder = RawPayloadDecoder()
127
+
128
+
129
+ class IntegerPayloadDecoder(AbstractSimplePayloadDecoder):
130
+ protoComponent = univ.Integer(0)
131
+
132
+ def valueDecoder(self, substrate, asn1Spec,
133
+ tagSet=None, length=None, state=None,
134
+ decodeFun=None, substrateFun=None,
135
+ **options):
136
+
137
+ if tagSet[0].tagFormat != tag.tagFormatSimple:
138
+ raise error.PyAsn1Error('Simple tag format expected')
139
+
140
+ for chunk in readFromStream(substrate, length, options):
141
+ if isinstance(chunk, SubstrateUnderrunError):
142
+ yield chunk
143
+
144
+ if chunk:
145
+ value = int.from_bytes(bytes(chunk), 'big', signed=True)
146
+
147
+ else:
148
+ value = 0
149
+
150
+ yield self._createComponent(asn1Spec, tagSet, value, **options)
151
+
152
+
153
+ class BooleanPayloadDecoder(IntegerPayloadDecoder):
154
+ protoComponent = univ.Boolean(0)
155
+
156
+ def _createComponent(self, asn1Spec, tagSet, value, **options):
157
+ return IntegerPayloadDecoder._createComponent(
158
+ self, asn1Spec, tagSet, value and 1 or 0, **options)
159
+
160
+
161
+ class BitStringPayloadDecoder(AbstractSimplePayloadDecoder):
162
+ protoComponent = univ.BitString(())
163
+ supportConstructedForm = True
164
+
165
+ def valueDecoder(self, substrate, asn1Spec,
166
+ tagSet=None, length=None, state=None,
167
+ decodeFun=None, substrateFun=None,
168
+ **options):
169
+
170
+ if substrateFun:
171
+ asn1Object = self._createComponent(asn1Spec, tagSet, noValue, **options)
172
+
173
+ for chunk in substrateFun(asn1Object, substrate, length, options):
174
+ yield chunk
175
+
176
+ return
177
+
178
+ if not length:
179
+ raise error.PyAsn1Error('Empty BIT STRING substrate')
180
+
181
+ for chunk in isEndOfStream(substrate):
182
+ if isinstance(chunk, SubstrateUnderrunError):
183
+ yield chunk
184
+
185
+ if chunk:
186
+ raise error.PyAsn1Error('Empty BIT STRING substrate')
187
+
188
+ if tagSet[0].tagFormat == tag.tagFormatSimple: # XXX what tag to check?
189
+
190
+ for trailingBits in readFromStream(substrate, 1, options):
191
+ if isinstance(trailingBits, SubstrateUnderrunError):
192
+ yield trailingBits
193
+
194
+ trailingBits = ord(trailingBits)
195
+ if trailingBits > 7:
196
+ raise error.PyAsn1Error(
197
+ 'Trailing bits overflow %s' % trailingBits
198
+ )
199
+
200
+ for chunk in readFromStream(substrate, length - 1, options):
201
+ if isinstance(chunk, SubstrateUnderrunError):
202
+ yield chunk
203
+
204
+ value = self.protoComponent.fromOctetString(
205
+ chunk, internalFormat=True, padding=trailingBits)
206
+
207
+ yield self._createComponent(asn1Spec, tagSet, value, **options)
208
+
209
+ return
210
+
211
+ if not self.supportConstructedForm:
212
+ raise error.PyAsn1Error('Constructed encoding form prohibited '
213
+ 'at %s' % self.__class__.__name__)
214
+
215
+ if LOG:
216
+ LOG('assembling constructed serialization')
217
+
218
+ # All inner fragments are of the same type, treat them as octet string
219
+ substrateFun = self.substrateCollector
220
+
221
+ bitString = self.protoComponent.fromOctetString(b'', internalFormat=True)
222
+
223
+ current_position = substrate.tell()
224
+
225
+ while substrate.tell() - current_position < length:
226
+ for component in decodeFun(
227
+ substrate, self.protoComponent, substrateFun=substrateFun,
228
+ **options):
229
+ if isinstance(component, SubstrateUnderrunError):
230
+ yield component
231
+
232
+ trailingBits = component[0]
233
+ if trailingBits > 7:
234
+ raise error.PyAsn1Error(
235
+ 'Trailing bits overflow %s' % trailingBits
236
+ )
237
+
238
+ bitString = self.protoComponent.fromOctetString(
239
+ component[1:], internalFormat=True,
240
+ prepend=bitString, padding=trailingBits
241
+ )
242
+
243
+ yield self._createComponent(asn1Spec, tagSet, bitString, **options)
244
+
245
+ def indefLenValueDecoder(self, substrate, asn1Spec,
246
+ tagSet=None, length=None, state=None,
247
+ decodeFun=None, substrateFun=None,
248
+ **options):
249
+
250
+ if substrateFun:
251
+ asn1Object = self._createComponent(asn1Spec, tagSet, noValue, **options)
252
+
253
+ for chunk in substrateFun(asn1Object, substrate, length, options):
254
+ yield chunk
255
+
256
+ return
257
+
258
+ # All inner fragments are of the same type, treat them as octet string
259
+ substrateFun = self.substrateCollector
260
+
261
+ bitString = self.protoComponent.fromOctetString(b'', internalFormat=True)
262
+
263
+ while True: # loop over fragments
264
+
265
+ for component in decodeFun(
266
+ substrate, self.protoComponent, substrateFun=substrateFun,
267
+ allowEoo=True, **options):
268
+
269
+ if component is eoo.endOfOctets:
270
+ break
271
+
272
+ if isinstance(component, SubstrateUnderrunError):
273
+ yield component
274
+
275
+ if component is eoo.endOfOctets:
276
+ break
277
+
278
+ trailingBits = component[0]
279
+ if trailingBits > 7:
280
+ raise error.PyAsn1Error(
281
+ 'Trailing bits overflow %s' % trailingBits
282
+ )
283
+
284
+ bitString = self.protoComponent.fromOctetString(
285
+ component[1:], internalFormat=True,
286
+ prepend=bitString, padding=trailingBits
287
+ )
288
+
289
+ yield self._createComponent(asn1Spec, tagSet, bitString, **options)
290
+
291
+
292
+ class OctetStringPayloadDecoder(AbstractSimplePayloadDecoder):
293
+ protoComponent = univ.OctetString('')
294
+ supportConstructedForm = True
295
+
296
+ def valueDecoder(self, substrate, asn1Spec,
297
+ tagSet=None, length=None, state=None,
298
+ decodeFun=None, substrateFun=None,
299
+ **options):
300
+ if substrateFun:
301
+ asn1Object = self._createComponent(asn1Spec, tagSet, noValue, **options)
302
+
303
+ for chunk in substrateFun(asn1Object, substrate, length, options):
304
+ yield chunk
305
+
306
+ return
307
+
308
+ if tagSet[0].tagFormat == tag.tagFormatSimple: # XXX what tag to check?
309
+ for chunk in readFromStream(substrate, length, options):
310
+ if isinstance(chunk, SubstrateUnderrunError):
311
+ yield chunk
312
+
313
+ yield self._createComponent(asn1Spec, tagSet, chunk, **options)
314
+
315
+ return
316
+
317
+ if not self.supportConstructedForm:
318
+ raise error.PyAsn1Error('Constructed encoding form prohibited at %s' % self.__class__.__name__)
319
+
320
+ if LOG:
321
+ LOG('assembling constructed serialization')
322
+
323
+ # All inner fragments are of the same type, treat them as octet string
324
+ substrateFun = self.substrateCollector
325
+
326
+ header = b''
327
+
328
+ original_position = substrate.tell()
329
+ # head = popSubstream(substrate, length)
330
+ while substrate.tell() - original_position < length:
331
+ for component in decodeFun(
332
+ substrate, self.protoComponent, substrateFun=substrateFun,
333
+ **options):
334
+ if isinstance(component, SubstrateUnderrunError):
335
+ yield component
336
+
337
+ header += component
338
+
339
+ yield self._createComponent(asn1Spec, tagSet, header, **options)
340
+
341
+ def indefLenValueDecoder(self, substrate, asn1Spec,
342
+ tagSet=None, length=None, state=None,
343
+ decodeFun=None, substrateFun=None,
344
+ **options):
345
+ if substrateFun and substrateFun is not self.substrateCollector:
346
+ asn1Object = self._createComponent(asn1Spec, tagSet, noValue, **options)
347
+
348
+ for chunk in substrateFun(asn1Object, substrate, length, options):
349
+ yield chunk
350
+
351
+ return
352
+
353
+ # All inner fragments are of the same type, treat them as octet string
354
+ substrateFun = self.substrateCollector
355
+
356
+ header = b''
357
+
358
+ while True: # loop over fragments
359
+
360
+ for component in decodeFun(
361
+ substrate, self.protoComponent, substrateFun=substrateFun,
362
+ allowEoo=True, **options):
363
+
364
+ if isinstance(component, SubstrateUnderrunError):
365
+ yield component
366
+
367
+ if component is eoo.endOfOctets:
368
+ break
369
+
370
+ if component is eoo.endOfOctets:
371
+ break
372
+
373
+ header += component
374
+
375
+ yield self._createComponent(asn1Spec, tagSet, header, **options)
376
+
377
+
378
+ class NullPayloadDecoder(AbstractSimplePayloadDecoder):
379
+ protoComponent = univ.Null('')
380
+
381
+ def valueDecoder(self, substrate, asn1Spec,
382
+ tagSet=None, length=None, state=None,
383
+ decodeFun=None, substrateFun=None,
384
+ **options):
385
+
386
+ if tagSet[0].tagFormat != tag.tagFormatSimple:
387
+ raise error.PyAsn1Error('Simple tag format expected')
388
+
389
+ for chunk in readFromStream(substrate, length, options):
390
+ if isinstance(chunk, SubstrateUnderrunError):
391
+ yield chunk
392
+
393
+ component = self._createComponent(asn1Spec, tagSet, '', **options)
394
+
395
+ if chunk:
396
+ raise error.PyAsn1Error('Unexpected %d-octet substrate for Null' % length)
397
+
398
+ yield component
399
+
400
+
401
+ class ObjectIdentifierPayloadDecoder(AbstractSimplePayloadDecoder):
402
+ protoComponent = univ.ObjectIdentifier(())
403
+
404
+ def valueDecoder(self, substrate, asn1Spec,
405
+ tagSet=None, length=None, state=None,
406
+ decodeFun=None, substrateFun=None,
407
+ **options):
408
+ if tagSet[0].tagFormat != tag.tagFormatSimple:
409
+ raise error.PyAsn1Error('Simple tag format expected')
410
+
411
+ for chunk in readFromStream(substrate, length, options):
412
+ if isinstance(chunk, SubstrateUnderrunError):
413
+ yield chunk
414
+
415
+ if not chunk:
416
+ raise error.PyAsn1Error('Empty substrate')
417
+
418
+ oid = ()
419
+ index = 0
420
+ substrateLen = len(chunk)
421
+ while index < substrateLen:
422
+ subId = chunk[index]
423
+ index += 1
424
+ if subId < 128:
425
+ oid += (subId,)
426
+ elif subId > 128:
427
+ # Construct subid from a number of octets
428
+ nextSubId = subId
429
+ subId = 0
430
+ while nextSubId >= 128:
431
+ subId = (subId << 7) + (nextSubId & 0x7F)
432
+ if index >= substrateLen:
433
+ raise error.SubstrateUnderrunError(
434
+ 'Short substrate for sub-OID past %s' % (oid,)
435
+ )
436
+ nextSubId = chunk[index]
437
+ index += 1
438
+ oid += ((subId << 7) + nextSubId,)
439
+ elif subId == 128:
440
+ # ASN.1 spec forbids leading zeros (0x80) in OID
441
+ # encoding, tolerating it opens a vulnerability. See
442
+ # https://www.esat.kuleuven.be/cosic/publications/article-1432.pdf
443
+ # page 7
444
+ raise error.PyAsn1Error('Invalid octet 0x80 in OID encoding')
445
+
446
+ # Decode two leading arcs
447
+ if 0 <= oid[0] <= 39:
448
+ oid = (0,) + oid
449
+ elif 40 <= oid[0] <= 79:
450
+ oid = (1, oid[0] - 40) + oid[1:]
451
+ elif oid[0] >= 80:
452
+ oid = (2, oid[0] - 80) + oid[1:]
453
+ else:
454
+ raise error.PyAsn1Error('Malformed first OID octet: %s' % chunk[0])
455
+
456
+ yield self._createComponent(asn1Spec, tagSet, oid, **options)
457
+
458
+
459
+ class RelativeOIDPayloadDecoder(AbstractSimplePayloadDecoder):
460
+ protoComponent = univ.RelativeOID(())
461
+
462
+ def valueDecoder(self, substrate, asn1Spec,
463
+ tagSet=None, length=None, state=None,
464
+ decodeFun=None, substrateFun=None,
465
+ **options):
466
+ if tagSet[0].tagFormat != tag.tagFormatSimple:
467
+ raise error.PyAsn1Error('Simple tag format expected')
468
+
469
+ for chunk in readFromStream(substrate, length, options):
470
+ if isinstance(chunk, SubstrateUnderrunError):
471
+ yield chunk
472
+
473
+ if not chunk:
474
+ raise error.PyAsn1Error('Empty substrate')
475
+
476
+ reloid = ()
477
+ index = 0
478
+ substrateLen = len(chunk)
479
+ while index < substrateLen:
480
+ subId = chunk[index]
481
+ index += 1
482
+ if subId < 128:
483
+ reloid += (subId,)
484
+ elif subId > 128:
485
+ # Construct subid from a number of octets
486
+ nextSubId = subId
487
+ subId = 0
488
+ while nextSubId >= 128:
489
+ subId = (subId << 7) + (nextSubId & 0x7F)
490
+ if index >= substrateLen:
491
+ raise error.SubstrateUnderrunError(
492
+ 'Short substrate for sub-OID past %s' % (reloid,)
493
+ )
494
+ nextSubId = chunk[index]
495
+ index += 1
496
+ reloid += ((subId << 7) + nextSubId,)
497
+ elif subId == 128:
498
+ # ASN.1 spec forbids leading zeros (0x80) in OID
499
+ # encoding, tolerating it opens a vulnerability. See
500
+ # https://www.esat.kuleuven.be/cosic/publications/article-1432.pdf
501
+ # page 7
502
+ raise error.PyAsn1Error('Invalid octet 0x80 in RELATIVE-OID encoding')
503
+
504
+ yield self._createComponent(asn1Spec, tagSet, reloid, **options)
505
+
506
+
507
+ class RealPayloadDecoder(AbstractSimplePayloadDecoder):
508
+ protoComponent = univ.Real()
509
+
510
+ def valueDecoder(self, substrate, asn1Spec,
511
+ tagSet=None, length=None, state=None,
512
+ decodeFun=None, substrateFun=None,
513
+ **options):
514
+ if tagSet[0].tagFormat != tag.tagFormatSimple:
515
+ raise error.PyAsn1Error('Simple tag format expected')
516
+
517
+ for chunk in readFromStream(substrate, length, options):
518
+ if isinstance(chunk, SubstrateUnderrunError):
519
+ yield chunk
520
+
521
+ if not chunk:
522
+ yield self._createComponent(asn1Spec, tagSet, 0.0, **options)
523
+ return
524
+
525
+ fo = chunk[0]
526
+ chunk = chunk[1:]
527
+ if fo & 0x80: # binary encoding
528
+ if not chunk:
529
+ raise error.PyAsn1Error("Incomplete floating-point value")
530
+
531
+ if LOG:
532
+ LOG('decoding binary encoded REAL')
533
+
534
+ n = (fo & 0x03) + 1
535
+
536
+ if n == 4:
537
+ n = chunk[0]
538
+ chunk = chunk[1:]
539
+
540
+ eo, chunk = chunk[:n], chunk[n:]
541
+
542
+ if not eo or not chunk:
543
+ raise error.PyAsn1Error('Real exponent screwed')
544
+
545
+ e = eo[0] & 0x80 and -1 or 0
546
+
547
+ while eo: # exponent
548
+ e <<= 8
549
+ e |= eo[0]
550
+ eo = eo[1:]
551
+
552
+ b = fo >> 4 & 0x03 # base bits
553
+
554
+ if b > 2:
555
+ raise error.PyAsn1Error('Illegal Real base')
556
+
557
+ if b == 1: # encbase = 8
558
+ e *= 3
559
+
560
+ elif b == 2: # encbase = 16
561
+ e *= 4
562
+ p = 0
563
+
564
+ while chunk: # value
565
+ p <<= 8
566
+ p |= chunk[0]
567
+ chunk = chunk[1:]
568
+
569
+ if fo & 0x40: # sign bit
570
+ p = -p
571
+
572
+ sf = fo >> 2 & 0x03 # scale bits
573
+ p *= 2 ** sf
574
+ value = (p, 2, e)
575
+
576
+ elif fo & 0x40: # infinite value
577
+ if LOG:
578
+ LOG('decoding infinite REAL')
579
+
580
+ value = fo & 0x01 and '-inf' or 'inf'
581
+
582
+ elif fo & 0xc0 == 0: # character encoding
583
+ if not chunk:
584
+ raise error.PyAsn1Error("Incomplete floating-point value")
585
+
586
+ if LOG:
587
+ LOG('decoding character encoded REAL')
588
+
589
+ try:
590
+ if fo & 0x3 == 0x1: # NR1
591
+ value = (int(chunk), 10, 0)
592
+
593
+ elif fo & 0x3 == 0x2: # NR2
594
+ value = float(chunk)
595
+
596
+ elif fo & 0x3 == 0x3: # NR3
597
+ value = float(chunk)
598
+
599
+ else:
600
+ raise error.SubstrateUnderrunError(
601
+ 'Unknown NR (tag %s)' % fo
602
+ )
603
+
604
+ except ValueError:
605
+ raise error.SubstrateUnderrunError(
606
+ 'Bad character Real syntax'
607
+ )
608
+
609
+ else:
610
+ raise error.SubstrateUnderrunError(
611
+ 'Unknown encoding (tag %s)' % fo
612
+ )
613
+
614
+ yield self._createComponent(asn1Spec, tagSet, value, **options)
615
+
616
+
617
+ class AbstractConstructedPayloadDecoder(AbstractPayloadDecoder):
618
+ protoComponent = None
619
+
620
+
621
+ class ConstructedPayloadDecoderBase(AbstractConstructedPayloadDecoder):
622
+ protoRecordComponent = None
623
+ protoSequenceComponent = None
624
+
625
+ def _getComponentTagMap(self, asn1Object, idx):
626
+ raise NotImplementedError
627
+
628
+ def _getComponentPositionByType(self, asn1Object, tagSet, idx):
629
+ raise NotImplementedError
630
+
631
+ def _decodeComponentsSchemaless(
632
+ self, substrate, tagSet=None, decodeFun=None,
633
+ length=None, **options):
634
+
635
+ asn1Object = None
636
+
637
+ components = []
638
+ componentTypes = set()
639
+
640
+ original_position = substrate.tell()
641
+
642
+ while length == -1 or substrate.tell() < original_position + length:
643
+ for component in decodeFun(substrate, **options):
644
+ if isinstance(component, SubstrateUnderrunError):
645
+ yield component
646
+
647
+ if length == -1 and component is eoo.endOfOctets:
648
+ break
649
+
650
+ components.append(component)
651
+ componentTypes.add(component.tagSet)
652
+
653
+ # Now we have to guess is it SEQUENCE/SET or SEQUENCE OF/SET OF
654
+ # The heuristics is:
655
+ # * 1+ components of different types -> likely SEQUENCE/SET
656
+ # * otherwise -> likely SEQUENCE OF/SET OF
657
+ if len(componentTypes) > 1:
658
+ protoComponent = self.protoRecordComponent
659
+
660
+ else:
661
+ protoComponent = self.protoSequenceComponent
662
+
663
+ asn1Object = protoComponent.clone(
664
+ # construct tagSet from base tag from prototype ASN.1 object
665
+ # and additional tags recovered from the substrate
666
+ tagSet=tag.TagSet(protoComponent.tagSet.baseTag, *tagSet.superTags)
667
+ )
668
+
669
+ if LOG:
670
+ LOG('guessed %r container type (pass `asn1Spec` to guide the '
671
+ 'decoder)' % asn1Object)
672
+
673
+ for idx, component in enumerate(components):
674
+ asn1Object.setComponentByPosition(
675
+ idx, component,
676
+ verifyConstraints=False,
677
+ matchTags=False, matchConstraints=False
678
+ )
679
+
680
+ yield asn1Object
681
+
682
+ def valueDecoder(self, substrate, asn1Spec,
683
+ tagSet=None, length=None, state=None,
684
+ decodeFun=None, substrateFun=None,
685
+ **options):
686
+ if tagSet[0].tagFormat != tag.tagFormatConstructed:
687
+ raise error.PyAsn1Error('Constructed tag format expected')
688
+
689
+ original_position = substrate.tell()
690
+
691
+ if substrateFun:
692
+ if asn1Spec is not None:
693
+ asn1Object = asn1Spec.clone()
694
+
695
+ elif self.protoComponent is not None:
696
+ asn1Object = self.protoComponent.clone(tagSet=tagSet)
697
+
698
+ else:
699
+ asn1Object = self.protoRecordComponent, self.protoSequenceComponent
700
+
701
+ for chunk in substrateFun(asn1Object, substrate, length, options):
702
+ yield chunk
703
+
704
+ return
705
+
706
+ if asn1Spec is None:
707
+ for asn1Object in self._decodeComponentsSchemaless(
708
+ substrate, tagSet=tagSet, decodeFun=decodeFun,
709
+ length=length, **options):
710
+ if isinstance(asn1Object, SubstrateUnderrunError):
711
+ yield asn1Object
712
+
713
+ if substrate.tell() < original_position + length:
714
+ if LOG:
715
+ for trailing in readFromStream(substrate, context=options):
716
+ if isinstance(trailing, SubstrateUnderrunError):
717
+ yield trailing
718
+
719
+ LOG('Unused trailing %d octets encountered: %s' % (
720
+ len(trailing), debug.hexdump(trailing)))
721
+
722
+ yield asn1Object
723
+
724
+ return
725
+
726
+ asn1Object = asn1Spec.clone()
727
+ asn1Object.clear()
728
+
729
+ options = self._passAsn1Object(asn1Object, options)
730
+
731
+ if asn1Spec.typeId in (univ.Sequence.typeId, univ.Set.typeId):
732
+
733
+ namedTypes = asn1Spec.componentType
734
+
735
+ isSetType = asn1Spec.typeId == univ.Set.typeId
736
+ isDeterministic = not isSetType and not namedTypes.hasOptionalOrDefault
737
+
738
+ if LOG:
739
+ LOG('decoding %sdeterministic %s type %r chosen by type ID' % (
740
+ not isDeterministic and 'non-' or '', isSetType and 'SET' or '',
741
+ asn1Spec))
742
+
743
+ seenIndices = set()
744
+ idx = 0
745
+ while substrate.tell() - original_position < length:
746
+ if not namedTypes:
747
+ componentType = None
748
+
749
+ elif isSetType:
750
+ componentType = namedTypes.tagMapUnique
751
+
752
+ else:
753
+ try:
754
+ if isDeterministic:
755
+ componentType = namedTypes[idx].asn1Object
756
+
757
+ elif namedTypes[idx].isOptional or namedTypes[idx].isDefaulted:
758
+ componentType = namedTypes.getTagMapNearPosition(idx)
759
+
760
+ else:
761
+ componentType = namedTypes[idx].asn1Object
762
+
763
+ except IndexError:
764
+ raise error.PyAsn1Error(
765
+ 'Excessive components decoded at %r' % (asn1Spec,)
766
+ )
767
+
768
+ for component in decodeFun(substrate, componentType, **options):
769
+ if isinstance(component, SubstrateUnderrunError):
770
+ yield component
771
+
772
+ if not isDeterministic and namedTypes:
773
+ if isSetType:
774
+ idx = namedTypes.getPositionByType(component.effectiveTagSet)
775
+
776
+ elif namedTypes[idx].isOptional or namedTypes[idx].isDefaulted:
777
+ idx = namedTypes.getPositionNearType(component.effectiveTagSet, idx)
778
+
779
+ asn1Object.setComponentByPosition(
780
+ idx, component,
781
+ verifyConstraints=False,
782
+ matchTags=False, matchConstraints=False
783
+ )
784
+
785
+ seenIndices.add(idx)
786
+ idx += 1
787
+
788
+ if LOG:
789
+ LOG('seen component indices %s' % seenIndices)
790
+
791
+ if namedTypes:
792
+ if not namedTypes.requiredComponents.issubset(seenIndices):
793
+ raise error.PyAsn1Error(
794
+ 'ASN.1 object %s has uninitialized '
795
+ 'components' % asn1Object.__class__.__name__)
796
+
797
+ if namedTypes.hasOpenTypes:
798
+
799
+ openTypes = options.get('openTypes', {})
800
+
801
+ if LOG:
802
+ LOG('user-specified open types map:')
803
+
804
+ for k, v in openTypes.items():
805
+ LOG('%s -> %r' % (k, v))
806
+
807
+ if openTypes or options.get('decodeOpenTypes', False):
808
+
809
+ for idx, namedType in enumerate(namedTypes.namedTypes):
810
+ if not namedType.openType:
811
+ continue
812
+
813
+ if namedType.isOptional and not asn1Object.getComponentByPosition(idx).isValue:
814
+ continue
815
+
816
+ governingValue = asn1Object.getComponentByName(
817
+ namedType.openType.name
818
+ )
819
+
820
+ try:
821
+ openType = openTypes[governingValue]
822
+
823
+ except KeyError:
824
+
825
+ if LOG:
826
+ LOG('default open types map of component '
827
+ '"%s.%s" governed by component "%s.%s"'
828
+ ':' % (asn1Object.__class__.__name__,
829
+ namedType.name,
830
+ asn1Object.__class__.__name__,
831
+ namedType.openType.name))
832
+
833
+ for k, v in namedType.openType.items():
834
+ LOG('%s -> %r' % (k, v))
835
+
836
+ try:
837
+ openType = namedType.openType[governingValue]
838
+
839
+ except KeyError:
840
+ if LOG:
841
+ LOG('failed to resolve open type by governing '
842
+ 'value %r' % (governingValue,))
843
+ continue
844
+
845
+ if LOG:
846
+ LOG('resolved open type %r by governing '
847
+ 'value %r' % (openType, governingValue))
848
+
849
+ containerValue = asn1Object.getComponentByPosition(idx)
850
+
851
+ if containerValue.typeId in (
852
+ univ.SetOf.typeId, univ.SequenceOf.typeId):
853
+
854
+ for pos, containerElement in enumerate(
855
+ containerValue):
856
+
857
+ stream = asSeekableStream(containerValue[pos].asOctets())
858
+
859
+ for component in decodeFun(stream, asn1Spec=openType, **options):
860
+ if isinstance(component, SubstrateUnderrunError):
861
+ yield component
862
+
863
+ containerValue[pos] = component
864
+
865
+ else:
866
+ stream = asSeekableStream(asn1Object.getComponentByPosition(idx).asOctets())
867
+
868
+ for component in decodeFun(stream, asn1Spec=openType, **options):
869
+ if isinstance(component, SubstrateUnderrunError):
870
+ yield component
871
+
872
+ asn1Object.setComponentByPosition(idx, component)
873
+
874
+ else:
875
+ inconsistency = asn1Object.isInconsistent
876
+ if inconsistency:
877
+ raise error.PyAsn1Error(
878
+ f"ASN.1 object {asn1Object.__class__.__name__} is inconsistent")
879
+
880
+ else:
881
+ componentType = asn1Spec.componentType
882
+
883
+ if LOG:
884
+ LOG('decoding type %r chosen by given `asn1Spec`' % componentType)
885
+
886
+ idx = 0
887
+
888
+ while substrate.tell() - original_position < length:
889
+ for component in decodeFun(substrate, componentType, **options):
890
+ if isinstance(component, SubstrateUnderrunError):
891
+ yield component
892
+
893
+ asn1Object.setComponentByPosition(
894
+ idx, component,
895
+ verifyConstraints=False,
896
+ matchTags=False, matchConstraints=False
897
+ )
898
+
899
+ idx += 1
900
+
901
+ yield asn1Object
902
+
903
+ def indefLenValueDecoder(self, substrate, asn1Spec,
904
+ tagSet=None, length=None, state=None,
905
+ decodeFun=None, substrateFun=None,
906
+ **options):
907
+ if tagSet[0].tagFormat != tag.tagFormatConstructed:
908
+ raise error.PyAsn1Error('Constructed tag format expected')
909
+
910
+ if substrateFun is not None:
911
+ if asn1Spec is not None:
912
+ asn1Object = asn1Spec.clone()
913
+
914
+ elif self.protoComponent is not None:
915
+ asn1Object = self.protoComponent.clone(tagSet=tagSet)
916
+
917
+ else:
918
+ asn1Object = self.protoRecordComponent, self.protoSequenceComponent
919
+
920
+ for chunk in substrateFun(asn1Object, substrate, length, options):
921
+ yield chunk
922
+
923
+ return
924
+
925
+ if asn1Spec is None:
926
+ for asn1Object in self._decodeComponentsSchemaless(
927
+ substrate, tagSet=tagSet, decodeFun=decodeFun,
928
+ length=length, **dict(options, allowEoo=True)):
929
+ if isinstance(asn1Object, SubstrateUnderrunError):
930
+ yield asn1Object
931
+
932
+ yield asn1Object
933
+
934
+ return
935
+
936
+ asn1Object = asn1Spec.clone()
937
+ asn1Object.clear()
938
+
939
+ options = self._passAsn1Object(asn1Object, options)
940
+
941
+ if asn1Spec.typeId in (univ.Sequence.typeId, univ.Set.typeId):
942
+
943
+ namedTypes = asn1Object.componentType
944
+
945
+ isSetType = asn1Object.typeId == univ.Set.typeId
946
+ isDeterministic = not isSetType and not namedTypes.hasOptionalOrDefault
947
+
948
+ if LOG:
949
+ LOG('decoding %sdeterministic %s type %r chosen by type ID' % (
950
+ not isDeterministic and 'non-' or '', isSetType and 'SET' or '',
951
+ asn1Spec))
952
+
953
+ seenIndices = set()
954
+
955
+ idx = 0
956
+
957
+ while True: # loop over components
958
+ if len(namedTypes) <= idx:
959
+ asn1Spec = None
960
+
961
+ elif isSetType:
962
+ asn1Spec = namedTypes.tagMapUnique
963
+
964
+ else:
965
+ try:
966
+ if isDeterministic:
967
+ asn1Spec = namedTypes[idx].asn1Object
968
+
969
+ elif namedTypes[idx].isOptional or namedTypes[idx].isDefaulted:
970
+ asn1Spec = namedTypes.getTagMapNearPosition(idx)
971
+
972
+ else:
973
+ asn1Spec = namedTypes[idx].asn1Object
974
+
975
+ except IndexError:
976
+ raise error.PyAsn1Error(
977
+ 'Excessive components decoded at %r' % (asn1Object,)
978
+ )
979
+
980
+ for component in decodeFun(substrate, asn1Spec, allowEoo=True, **options):
981
+
982
+ if isinstance(component, SubstrateUnderrunError):
983
+ yield component
984
+
985
+ if component is eoo.endOfOctets:
986
+ break
987
+
988
+ if component is eoo.endOfOctets:
989
+ break
990
+
991
+ if not isDeterministic and namedTypes:
992
+ if isSetType:
993
+ idx = namedTypes.getPositionByType(component.effectiveTagSet)
994
+
995
+ elif namedTypes[idx].isOptional or namedTypes[idx].isDefaulted:
996
+ idx = namedTypes.getPositionNearType(component.effectiveTagSet, idx)
997
+
998
+ asn1Object.setComponentByPosition(
999
+ idx, component,
1000
+ verifyConstraints=False,
1001
+ matchTags=False, matchConstraints=False
1002
+ )
1003
+
1004
+ seenIndices.add(idx)
1005
+ idx += 1
1006
+
1007
+ if LOG:
1008
+ LOG('seen component indices %s' % seenIndices)
1009
+
1010
+ if namedTypes:
1011
+ if not namedTypes.requiredComponents.issubset(seenIndices):
1012
+ raise error.PyAsn1Error(
1013
+ 'ASN.1 object %s has uninitialized '
1014
+ 'components' % asn1Object.__class__.__name__)
1015
+
1016
+ if namedTypes.hasOpenTypes:
1017
+
1018
+ openTypes = options.get('openTypes', {})
1019
+
1020
+ if LOG:
1021
+ LOG('user-specified open types map:')
1022
+
1023
+ for k, v in openTypes.items():
1024
+ LOG('%s -> %r' % (k, v))
1025
+
1026
+ if openTypes or options.get('decodeOpenTypes', False):
1027
+
1028
+ for idx, namedType in enumerate(namedTypes.namedTypes):
1029
+ if not namedType.openType:
1030
+ continue
1031
+
1032
+ if namedType.isOptional and not asn1Object.getComponentByPosition(idx).isValue:
1033
+ continue
1034
+
1035
+ governingValue = asn1Object.getComponentByName(
1036
+ namedType.openType.name
1037
+ )
1038
+
1039
+ try:
1040
+ openType = openTypes[governingValue]
1041
+
1042
+ except KeyError:
1043
+
1044
+ if LOG:
1045
+ LOG('default open types map of component '
1046
+ '"%s.%s" governed by component "%s.%s"'
1047
+ ':' % (asn1Object.__class__.__name__,
1048
+ namedType.name,
1049
+ asn1Object.__class__.__name__,
1050
+ namedType.openType.name))
1051
+
1052
+ for k, v in namedType.openType.items():
1053
+ LOG('%s -> %r' % (k, v))
1054
+
1055
+ try:
1056
+ openType = namedType.openType[governingValue]
1057
+
1058
+ except KeyError:
1059
+ if LOG:
1060
+ LOG('failed to resolve open type by governing '
1061
+ 'value %r' % (governingValue,))
1062
+ continue
1063
+
1064
+ if LOG:
1065
+ LOG('resolved open type %r by governing '
1066
+ 'value %r' % (openType, governingValue))
1067
+
1068
+ containerValue = asn1Object.getComponentByPosition(idx)
1069
+
1070
+ if containerValue.typeId in (
1071
+ univ.SetOf.typeId, univ.SequenceOf.typeId):
1072
+
1073
+ for pos, containerElement in enumerate(
1074
+ containerValue):
1075
+
1076
+ stream = asSeekableStream(containerValue[pos].asOctets())
1077
+
1078
+ for component in decodeFun(stream, asn1Spec=openType,
1079
+ **dict(options, allowEoo=True)):
1080
+ if isinstance(component, SubstrateUnderrunError):
1081
+ yield component
1082
+
1083
+ if component is eoo.endOfOctets:
1084
+ break
1085
+
1086
+ containerValue[pos] = component
1087
+
1088
+ else:
1089
+ stream = asSeekableStream(asn1Object.getComponentByPosition(idx).asOctets())
1090
+ for component in decodeFun(stream, asn1Spec=openType,
1091
+ **dict(options, allowEoo=True)):
1092
+ if isinstance(component, SubstrateUnderrunError):
1093
+ yield component
1094
+
1095
+ if component is eoo.endOfOctets:
1096
+ break
1097
+
1098
+ asn1Object.setComponentByPosition(idx, component)
1099
+
1100
+ else:
1101
+ inconsistency = asn1Object.isInconsistent
1102
+ if inconsistency:
1103
+ raise error.PyAsn1Error(
1104
+ f"ASN.1 object {asn1Object.__class__.__name__} is inconsistent")
1105
+
1106
+ else:
1107
+ componentType = asn1Spec.componentType
1108
+
1109
+ if LOG:
1110
+ LOG('decoding type %r chosen by given `asn1Spec`' % componentType)
1111
+
1112
+ idx = 0
1113
+
1114
+ while True:
1115
+
1116
+ for component in decodeFun(
1117
+ substrate, componentType, allowEoo=True, **options):
1118
+
1119
+ if isinstance(component, SubstrateUnderrunError):
1120
+ yield component
1121
+
1122
+ if component is eoo.endOfOctets:
1123
+ break
1124
+
1125
+ if component is eoo.endOfOctets:
1126
+ break
1127
+
1128
+ asn1Object.setComponentByPosition(
1129
+ idx, component,
1130
+ verifyConstraints=False,
1131
+ matchTags=False, matchConstraints=False
1132
+ )
1133
+
1134
+ idx += 1
1135
+
1136
+ yield asn1Object
1137
+
1138
+
1139
+ class SequenceOrSequenceOfPayloadDecoder(ConstructedPayloadDecoderBase):
1140
+ protoRecordComponent = univ.Sequence()
1141
+ protoSequenceComponent = univ.SequenceOf()
1142
+
1143
+
1144
+ class SequencePayloadDecoder(SequenceOrSequenceOfPayloadDecoder):
1145
+ protoComponent = univ.Sequence()
1146
+
1147
+
1148
+ class SequenceOfPayloadDecoder(SequenceOrSequenceOfPayloadDecoder):
1149
+ protoComponent = univ.SequenceOf()
1150
+
1151
+
1152
+ class SetOrSetOfPayloadDecoder(ConstructedPayloadDecoderBase):
1153
+ protoRecordComponent = univ.Set()
1154
+ protoSequenceComponent = univ.SetOf()
1155
+
1156
+
1157
+ class SetPayloadDecoder(SetOrSetOfPayloadDecoder):
1158
+ protoComponent = univ.Set()
1159
+
1160
+
1161
+ class SetOfPayloadDecoder(SetOrSetOfPayloadDecoder):
1162
+ protoComponent = univ.SetOf()
1163
+
1164
+
1165
+ class ChoicePayloadDecoder(ConstructedPayloadDecoderBase):
1166
+ protoComponent = univ.Choice()
1167
+
1168
+ def valueDecoder(self, substrate, asn1Spec,
1169
+ tagSet=None, length=None, state=None,
1170
+ decodeFun=None, substrateFun=None,
1171
+ **options):
1172
+ if asn1Spec is None:
1173
+ asn1Object = self.protoComponent.clone(tagSet=tagSet)
1174
+
1175
+ else:
1176
+ asn1Object = asn1Spec.clone()
1177
+
1178
+ if substrateFun:
1179
+ for chunk in substrateFun(asn1Object, substrate, length, options):
1180
+ yield chunk
1181
+
1182
+ return
1183
+
1184
+ options = self._passAsn1Object(asn1Object, options)
1185
+
1186
+ if asn1Object.tagSet == tagSet:
1187
+ if LOG:
1188
+ LOG('decoding %s as explicitly tagged CHOICE' % (tagSet,))
1189
+
1190
+ for component in decodeFun(
1191
+ substrate, asn1Object.componentTagMap, **options):
1192
+ if isinstance(component, SubstrateUnderrunError):
1193
+ yield component
1194
+
1195
+ else:
1196
+ if LOG:
1197
+ LOG('decoding %s as untagged CHOICE' % (tagSet,))
1198
+
1199
+ for component in decodeFun(
1200
+ substrate, asn1Object.componentTagMap, tagSet, length,
1201
+ state, **options):
1202
+ if isinstance(component, SubstrateUnderrunError):
1203
+ yield component
1204
+
1205
+ effectiveTagSet = component.effectiveTagSet
1206
+
1207
+ if LOG:
1208
+ LOG('decoded component %s, effective tag set %s' % (component, effectiveTagSet))
1209
+
1210
+ asn1Object.setComponentByType(
1211
+ effectiveTagSet, component,
1212
+ verifyConstraints=False,
1213
+ matchTags=False, matchConstraints=False,
1214
+ innerFlag=False
1215
+ )
1216
+
1217
+ yield asn1Object
1218
+
1219
+ def indefLenValueDecoder(self, substrate, asn1Spec,
1220
+ tagSet=None, length=None, state=None,
1221
+ decodeFun=None, substrateFun=None,
1222
+ **options):
1223
+ if asn1Spec is None:
1224
+ asn1Object = self.protoComponent.clone(tagSet=tagSet)
1225
+
1226
+ else:
1227
+ asn1Object = asn1Spec.clone()
1228
+
1229
+ if substrateFun:
1230
+ for chunk in substrateFun(asn1Object, substrate, length, options):
1231
+ yield chunk
1232
+
1233
+ return
1234
+
1235
+ options = self._passAsn1Object(asn1Object, options)
1236
+
1237
+ isTagged = asn1Object.tagSet == tagSet
1238
+
1239
+ if LOG:
1240
+ LOG('decoding %s as %stagged CHOICE' % (
1241
+ tagSet, isTagged and 'explicitly ' or 'un'))
1242
+
1243
+ while True:
1244
+
1245
+ if isTagged:
1246
+ iterator = decodeFun(
1247
+ substrate, asn1Object.componentType.tagMapUnique,
1248
+ **dict(options, allowEoo=True))
1249
+
1250
+ else:
1251
+ iterator = decodeFun(
1252
+ substrate, asn1Object.componentType.tagMapUnique,
1253
+ tagSet, length, state, **dict(options, allowEoo=True))
1254
+
1255
+ for component in iterator:
1256
+
1257
+ if isinstance(component, SubstrateUnderrunError):
1258
+ yield component
1259
+
1260
+ if component is eoo.endOfOctets:
1261
+ break
1262
+
1263
+ effectiveTagSet = component.effectiveTagSet
1264
+
1265
+ if LOG:
1266
+ LOG('decoded component %s, effective tag set '
1267
+ '%s' % (component, effectiveTagSet))
1268
+
1269
+ asn1Object.setComponentByType(
1270
+ effectiveTagSet, component,
1271
+ verifyConstraints=False,
1272
+ matchTags=False, matchConstraints=False,
1273
+ innerFlag=False
1274
+ )
1275
+
1276
+ if not isTagged:
1277
+ break
1278
+
1279
+ if not isTagged or component is eoo.endOfOctets:
1280
+ break
1281
+
1282
+ yield asn1Object
1283
+
1284
+
1285
+ class AnyPayloadDecoder(AbstractSimplePayloadDecoder):
1286
+ protoComponent = univ.Any()
1287
+
1288
+ def valueDecoder(self, substrate, asn1Spec,
1289
+ tagSet=None, length=None, state=None,
1290
+ decodeFun=None, substrateFun=None,
1291
+ **options):
1292
+ if asn1Spec is None:
1293
+ isUntagged = True
1294
+
1295
+ elif asn1Spec.__class__ is tagmap.TagMap:
1296
+ isUntagged = tagSet not in asn1Spec.tagMap
1297
+
1298
+ else:
1299
+ isUntagged = tagSet != asn1Spec.tagSet
1300
+
1301
+ if isUntagged:
1302
+ fullPosition = substrate.markedPosition
1303
+ currentPosition = substrate.tell()
1304
+
1305
+ substrate.seek(fullPosition, os.SEEK_SET)
1306
+ length += currentPosition - fullPosition
1307
+
1308
+ if LOG:
1309
+ for chunk in peekIntoStream(substrate, length):
1310
+ if isinstance(chunk, SubstrateUnderrunError):
1311
+ yield chunk
1312
+ LOG('decoding as untagged ANY, substrate '
1313
+ '%s' % debug.hexdump(chunk))
1314
+
1315
+ if substrateFun:
1316
+ for chunk in substrateFun(
1317
+ self._createComponent(asn1Spec, tagSet, noValue, **options),
1318
+ substrate, length, options):
1319
+ yield chunk
1320
+
1321
+ return
1322
+
1323
+ for chunk in readFromStream(substrate, length, options):
1324
+ if isinstance(chunk, SubstrateUnderrunError):
1325
+ yield chunk
1326
+
1327
+ yield self._createComponent(asn1Spec, tagSet, chunk, **options)
1328
+
1329
+ def indefLenValueDecoder(self, substrate, asn1Spec,
1330
+ tagSet=None, length=None, state=None,
1331
+ decodeFun=None, substrateFun=None,
1332
+ **options):
1333
+ if asn1Spec is None:
1334
+ isTagged = False
1335
+
1336
+ elif asn1Spec.__class__ is tagmap.TagMap:
1337
+ isTagged = tagSet in asn1Spec.tagMap
1338
+
1339
+ else:
1340
+ isTagged = tagSet == asn1Spec.tagSet
1341
+
1342
+ if isTagged:
1343
+ # tagged Any type -- consume header substrate
1344
+ chunk = b''
1345
+
1346
+ if LOG:
1347
+ LOG('decoding as tagged ANY')
1348
+
1349
+ else:
1350
+ # TODO: Seems not to be tested
1351
+ fullPosition = substrate.markedPosition
1352
+ currentPosition = substrate.tell()
1353
+
1354
+ substrate.seek(fullPosition, os.SEEK_SET)
1355
+ for chunk in readFromStream(substrate, currentPosition - fullPosition, options):
1356
+ if isinstance(chunk, SubstrateUnderrunError):
1357
+ yield chunk
1358
+
1359
+ if LOG:
1360
+ LOG('decoding as untagged ANY, header substrate %s' % debug.hexdump(chunk))
1361
+
1362
+ # Any components do not inherit initial tag
1363
+ asn1Spec = self.protoComponent
1364
+
1365
+ if substrateFun and substrateFun is not self.substrateCollector:
1366
+ asn1Object = self._createComponent(
1367
+ asn1Spec, tagSet, noValue, **options)
1368
+
1369
+ for chunk in substrateFun(
1370
+ asn1Object, chunk + substrate, length + len(chunk), options):
1371
+ yield chunk
1372
+
1373
+ return
1374
+
1375
+ if LOG:
1376
+ LOG('assembling constructed serialization')
1377
+
1378
+ # All inner fragments are of the same type, treat them as octet string
1379
+ substrateFun = self.substrateCollector
1380
+
1381
+ while True: # loop over fragments
1382
+
1383
+ for component in decodeFun(
1384
+ substrate, asn1Spec, substrateFun=substrateFun,
1385
+ allowEoo=True, **options):
1386
+
1387
+ if isinstance(component, SubstrateUnderrunError):
1388
+ yield component
1389
+
1390
+ if component is eoo.endOfOctets:
1391
+ break
1392
+
1393
+ if component is eoo.endOfOctets:
1394
+ break
1395
+
1396
+ chunk += component
1397
+
1398
+ if substrateFun:
1399
+ yield chunk # TODO: Weird
1400
+
1401
+ else:
1402
+ yield self._createComponent(asn1Spec, tagSet, chunk, **options)
1403
+
1404
+
1405
+ # character string types
1406
+ class UTF8StringPayloadDecoder(OctetStringPayloadDecoder):
1407
+ protoComponent = char.UTF8String()
1408
+
1409
+
1410
+ class NumericStringPayloadDecoder(OctetStringPayloadDecoder):
1411
+ protoComponent = char.NumericString()
1412
+
1413
+
1414
+ class PrintableStringPayloadDecoder(OctetStringPayloadDecoder):
1415
+ protoComponent = char.PrintableString()
1416
+
1417
+
1418
+ class TeletexStringPayloadDecoder(OctetStringPayloadDecoder):
1419
+ protoComponent = char.TeletexString()
1420
+
1421
+
1422
+ class VideotexStringPayloadDecoder(OctetStringPayloadDecoder):
1423
+ protoComponent = char.VideotexString()
1424
+
1425
+
1426
+ class IA5StringPayloadDecoder(OctetStringPayloadDecoder):
1427
+ protoComponent = char.IA5String()
1428
+
1429
+
1430
+ class GraphicStringPayloadDecoder(OctetStringPayloadDecoder):
1431
+ protoComponent = char.GraphicString()
1432
+
1433
+
1434
+ class VisibleStringPayloadDecoder(OctetStringPayloadDecoder):
1435
+ protoComponent = char.VisibleString()
1436
+
1437
+
1438
+ class GeneralStringPayloadDecoder(OctetStringPayloadDecoder):
1439
+ protoComponent = char.GeneralString()
1440
+
1441
+
1442
+ class UniversalStringPayloadDecoder(OctetStringPayloadDecoder):
1443
+ protoComponent = char.UniversalString()
1444
+
1445
+
1446
+ class BMPStringPayloadDecoder(OctetStringPayloadDecoder):
1447
+ protoComponent = char.BMPString()
1448
+
1449
+
1450
+ # "useful" types
1451
+ class ObjectDescriptorPayloadDecoder(OctetStringPayloadDecoder):
1452
+ protoComponent = useful.ObjectDescriptor()
1453
+
1454
+
1455
+ class GeneralizedTimePayloadDecoder(OctetStringPayloadDecoder):
1456
+ protoComponent = useful.GeneralizedTime()
1457
+
1458
+
1459
+ class UTCTimePayloadDecoder(OctetStringPayloadDecoder):
1460
+ protoComponent = useful.UTCTime()
1461
+
1462
+
1463
+ TAG_MAP = {
1464
+ univ.Integer.tagSet: IntegerPayloadDecoder(),
1465
+ univ.Boolean.tagSet: BooleanPayloadDecoder(),
1466
+ univ.BitString.tagSet: BitStringPayloadDecoder(),
1467
+ univ.OctetString.tagSet: OctetStringPayloadDecoder(),
1468
+ univ.Null.tagSet: NullPayloadDecoder(),
1469
+ univ.ObjectIdentifier.tagSet: ObjectIdentifierPayloadDecoder(),
1470
+ univ.RelativeOID.tagSet: RelativeOIDPayloadDecoder(),
1471
+ univ.Enumerated.tagSet: IntegerPayloadDecoder(),
1472
+ univ.Real.tagSet: RealPayloadDecoder(),
1473
+ univ.Sequence.tagSet: SequenceOrSequenceOfPayloadDecoder(), # conflicts with SequenceOf
1474
+ univ.Set.tagSet: SetOrSetOfPayloadDecoder(), # conflicts with SetOf
1475
+ univ.Choice.tagSet: ChoicePayloadDecoder(), # conflicts with Any
1476
+ # character string types
1477
+ char.UTF8String.tagSet: UTF8StringPayloadDecoder(),
1478
+ char.NumericString.tagSet: NumericStringPayloadDecoder(),
1479
+ char.PrintableString.tagSet: PrintableStringPayloadDecoder(),
1480
+ char.TeletexString.tagSet: TeletexStringPayloadDecoder(),
1481
+ char.VideotexString.tagSet: VideotexStringPayloadDecoder(),
1482
+ char.IA5String.tagSet: IA5StringPayloadDecoder(),
1483
+ char.GraphicString.tagSet: GraphicStringPayloadDecoder(),
1484
+ char.VisibleString.tagSet: VisibleStringPayloadDecoder(),
1485
+ char.GeneralString.tagSet: GeneralStringPayloadDecoder(),
1486
+ char.UniversalString.tagSet: UniversalStringPayloadDecoder(),
1487
+ char.BMPString.tagSet: BMPStringPayloadDecoder(),
1488
+ # useful types
1489
+ useful.ObjectDescriptor.tagSet: ObjectDescriptorPayloadDecoder(),
1490
+ useful.GeneralizedTime.tagSet: GeneralizedTimePayloadDecoder(),
1491
+ useful.UTCTime.tagSet: UTCTimePayloadDecoder()
1492
+ }
1493
+
1494
+ # Type-to-codec map for ambiguous ASN.1 types
1495
+ TYPE_MAP = {
1496
+ univ.Set.typeId: SetPayloadDecoder(),
1497
+ univ.SetOf.typeId: SetOfPayloadDecoder(),
1498
+ univ.Sequence.typeId: SequencePayloadDecoder(),
1499
+ univ.SequenceOf.typeId: SequenceOfPayloadDecoder(),
1500
+ univ.Choice.typeId: ChoicePayloadDecoder(),
1501
+ univ.Any.typeId: AnyPayloadDecoder()
1502
+ }
1503
+
1504
+ # Put in non-ambiguous types for faster codec lookup
1505
+ for typeDecoder in TAG_MAP.values():
1506
+ if typeDecoder.protoComponent is not None:
1507
+ typeId = typeDecoder.protoComponent.__class__.typeId
1508
+ if typeId is not None and typeId not in TYPE_MAP:
1509
+ TYPE_MAP[typeId] = typeDecoder
1510
+
1511
+
1512
+ (stDecodeTag,
1513
+ stDecodeLength,
1514
+ stGetValueDecoder,
1515
+ stGetValueDecoderByAsn1Spec,
1516
+ stGetValueDecoderByTag,
1517
+ stTryAsExplicitTag,
1518
+ stDecodeValue,
1519
+ stDumpRawValue,
1520
+ stErrorCondition,
1521
+ stStop) = [x for x in range(10)]
1522
+
1523
+
1524
+ EOO_SENTINEL = bytes((0, 0))
1525
+
1526
+
1527
+ class SingleItemDecoder(object):
1528
+ defaultErrorState = stErrorCondition
1529
+ #defaultErrorState = stDumpRawValue
1530
+ defaultRawDecoder = AnyPayloadDecoder()
1531
+
1532
+ supportIndefLength = True
1533
+
1534
+ TAG_MAP = TAG_MAP
1535
+ TYPE_MAP = TYPE_MAP
1536
+
1537
+ def __init__(self, tagMap=_MISSING, typeMap=_MISSING, **ignored):
1538
+ self._tagMap = tagMap if tagMap is not _MISSING else self.TAG_MAP
1539
+ self._typeMap = typeMap if typeMap is not _MISSING else self.TYPE_MAP
1540
+
1541
+ # Tag & TagSet objects caches
1542
+ self._tagCache = {}
1543
+ self._tagSetCache = {}
1544
+
1545
+ def __call__(self, substrate, asn1Spec=None,
1546
+ tagSet=None, length=None, state=stDecodeTag,
1547
+ decodeFun=None, substrateFun=None,
1548
+ **options):
1549
+
1550
+ allowEoo = options.pop('allowEoo', False)
1551
+
1552
+ if LOG:
1553
+ LOG('decoder called at scope %s with state %d, working with up '
1554
+ 'to %s octets of substrate: '
1555
+ '%s' % (debug.scope, state, length, substrate))
1556
+
1557
+ # Look for end-of-octets sentinel
1558
+ if allowEoo and self.supportIndefLength:
1559
+
1560
+ for eoo_candidate in readFromStream(substrate, 2, options):
1561
+ if isinstance(eoo_candidate, SubstrateUnderrunError):
1562
+ yield eoo_candidate
1563
+
1564
+ if eoo_candidate == EOO_SENTINEL:
1565
+ if LOG:
1566
+ LOG('end-of-octets sentinel found')
1567
+ yield eoo.endOfOctets
1568
+ return
1569
+
1570
+ else:
1571
+ substrate.seek(-2, os.SEEK_CUR)
1572
+
1573
+ tagMap = self._tagMap
1574
+ typeMap = self._typeMap
1575
+ tagCache = self._tagCache
1576
+ tagSetCache = self._tagSetCache
1577
+
1578
+ value = noValue
1579
+
1580
+ substrate.markedPosition = substrate.tell()
1581
+
1582
+ while state is not stStop:
1583
+
1584
+ if state is stDecodeTag:
1585
+ # Decode tag
1586
+ isShortTag = True
1587
+
1588
+ for firstByte in readFromStream(substrate, 1, options):
1589
+ if isinstance(firstByte, SubstrateUnderrunError):
1590
+ yield firstByte
1591
+
1592
+ firstOctet = ord(firstByte)
1593
+
1594
+ try:
1595
+ lastTag = tagCache[firstOctet]
1596
+
1597
+ except KeyError:
1598
+ integerTag = firstOctet
1599
+ tagClass = integerTag & 0xC0
1600
+ tagFormat = integerTag & 0x20
1601
+ tagId = integerTag & 0x1F
1602
+
1603
+ if tagId == 0x1F:
1604
+ isShortTag = False
1605
+ lengthOctetIdx = 0
1606
+ tagId = 0
1607
+
1608
+ while True:
1609
+ for integerByte in readFromStream(substrate, 1, options):
1610
+ if isinstance(integerByte, SubstrateUnderrunError):
1611
+ yield integerByte
1612
+
1613
+ if not integerByte:
1614
+ raise error.SubstrateUnderrunError(
1615
+ 'Short octet stream on long tag decoding'
1616
+ )
1617
+
1618
+ integerTag = ord(integerByte)
1619
+ lengthOctetIdx += 1
1620
+ tagId <<= 7
1621
+ tagId |= (integerTag & 0x7F)
1622
+
1623
+ if not integerTag & 0x80:
1624
+ break
1625
+
1626
+ lastTag = tag.Tag(
1627
+ tagClass=tagClass, tagFormat=tagFormat, tagId=tagId
1628
+ )
1629
+
1630
+ if isShortTag:
1631
+ # cache short tags
1632
+ tagCache[firstOctet] = lastTag
1633
+
1634
+ if tagSet is None:
1635
+ if isShortTag:
1636
+ try:
1637
+ tagSet = tagSetCache[firstOctet]
1638
+
1639
+ except KeyError:
1640
+ # base tag not recovered
1641
+ tagSet = tag.TagSet((), lastTag)
1642
+ tagSetCache[firstOctet] = tagSet
1643
+ else:
1644
+ tagSet = tag.TagSet((), lastTag)
1645
+
1646
+ else:
1647
+ tagSet = lastTag + tagSet
1648
+
1649
+ state = stDecodeLength
1650
+
1651
+ if LOG:
1652
+ LOG('tag decoded into %s, decoding length' % tagSet)
1653
+
1654
+ if state is stDecodeLength:
1655
+ # Decode length
1656
+ for firstOctet in readFromStream(substrate, 1, options):
1657
+ if isinstance(firstOctet, SubstrateUnderrunError):
1658
+ yield firstOctet
1659
+
1660
+ firstOctet = ord(firstOctet)
1661
+
1662
+ if firstOctet < 128:
1663
+ length = firstOctet
1664
+
1665
+ elif firstOctet > 128:
1666
+ size = firstOctet & 0x7F
1667
+ # encoded in size bytes
1668
+ for encodedLength in readFromStream(substrate, size, options):
1669
+ if isinstance(encodedLength, SubstrateUnderrunError):
1670
+ yield encodedLength
1671
+ encodedLength = list(encodedLength)
1672
+ # missing check on maximum size, which shouldn't be a
1673
+ # problem, we can handle more than is possible
1674
+ if len(encodedLength) != size:
1675
+ raise error.SubstrateUnderrunError(
1676
+ '%s<%s at %s' % (size, len(encodedLength), tagSet)
1677
+ )
1678
+
1679
+ length = 0
1680
+ for lengthOctet in encodedLength:
1681
+ length <<= 8
1682
+ length |= lengthOctet
1683
+ size += 1
1684
+
1685
+ else: # 128 means indefinite
1686
+ length = -1
1687
+
1688
+ if length == -1 and not self.supportIndefLength:
1689
+ raise error.PyAsn1Error('Indefinite length encoding not supported by this codec')
1690
+
1691
+ state = stGetValueDecoder
1692
+
1693
+ if LOG:
1694
+ LOG('value length decoded into %d' % length)
1695
+
1696
+ if state is stGetValueDecoder:
1697
+ if asn1Spec is None:
1698
+ state = stGetValueDecoderByTag
1699
+
1700
+ else:
1701
+ state = stGetValueDecoderByAsn1Spec
1702
+ #
1703
+ # There're two ways of creating subtypes in ASN.1 what influences
1704
+ # decoder operation. These methods are:
1705
+ # 1) Either base types used in or no IMPLICIT tagging has been
1706
+ # applied on subtyping.
1707
+ # 2) Subtype syntax drops base type information (by means of
1708
+ # IMPLICIT tagging.
1709
+ # The first case allows for complete tag recovery from substrate
1710
+ # while the second one requires original ASN.1 type spec for
1711
+ # decoding.
1712
+ #
1713
+ # In either case a set of tags (tagSet) is coming from substrate
1714
+ # in an incremental, tag-by-tag fashion (this is the case of
1715
+ # EXPLICIT tag which is most basic). Outermost tag comes first
1716
+ # from the wire.
1717
+ #
1718
+ if state is stGetValueDecoderByTag:
1719
+ try:
1720
+ concreteDecoder = tagMap[tagSet]
1721
+
1722
+ except KeyError:
1723
+ concreteDecoder = None
1724
+
1725
+ if concreteDecoder:
1726
+ state = stDecodeValue
1727
+
1728
+ else:
1729
+ try:
1730
+ concreteDecoder = tagMap[tagSet[:1]]
1731
+
1732
+ except KeyError:
1733
+ concreteDecoder = None
1734
+
1735
+ if concreteDecoder:
1736
+ state = stDecodeValue
1737
+ else:
1738
+ state = stTryAsExplicitTag
1739
+
1740
+ if LOG:
1741
+ LOG('codec %s chosen by a built-in type, decoding %s' % (concreteDecoder and concreteDecoder.__class__.__name__ or "<none>", state is stDecodeValue and 'value' or 'as explicit tag'))
1742
+ debug.scope.push(concreteDecoder is None and '?' or concreteDecoder.protoComponent.__class__.__name__)
1743
+
1744
+ if state is stGetValueDecoderByAsn1Spec:
1745
+
1746
+ if asn1Spec.__class__ is tagmap.TagMap:
1747
+ try:
1748
+ chosenSpec = asn1Spec[tagSet]
1749
+
1750
+ except KeyError:
1751
+ chosenSpec = None
1752
+
1753
+ if LOG:
1754
+ LOG('candidate ASN.1 spec is a map of:')
1755
+
1756
+ for firstOctet, v in asn1Spec.presentTypes.items():
1757
+ LOG(' %s -> %s' % (firstOctet, v.__class__.__name__))
1758
+
1759
+ if asn1Spec.skipTypes:
1760
+ LOG('but neither of: ')
1761
+ for firstOctet, v in asn1Spec.skipTypes.items():
1762
+ LOG(' %s -> %s' % (firstOctet, v.__class__.__name__))
1763
+ LOG('new candidate ASN.1 spec is %s, chosen by %s' % (chosenSpec is None and '<none>' or chosenSpec.prettyPrintType(), tagSet))
1764
+
1765
+ elif tagSet == asn1Spec.tagSet or tagSet in asn1Spec.tagMap:
1766
+ chosenSpec = asn1Spec
1767
+ if LOG:
1768
+ LOG('candidate ASN.1 spec is %s' % asn1Spec.__class__.__name__)
1769
+
1770
+ else:
1771
+ chosenSpec = None
1772
+
1773
+ if chosenSpec is not None:
1774
+ try:
1775
+ # ambiguous type or just faster codec lookup
1776
+ concreteDecoder = typeMap[chosenSpec.typeId]
1777
+
1778
+ if LOG:
1779
+ LOG('value decoder chosen for an ambiguous type by type ID %s' % (chosenSpec.typeId,))
1780
+
1781
+ except KeyError:
1782
+ # use base type for codec lookup to recover untagged types
1783
+ baseTagSet = tag.TagSet(chosenSpec.tagSet.baseTag, chosenSpec.tagSet.baseTag)
1784
+ try:
1785
+ # base type or tagged subtype
1786
+ concreteDecoder = tagMap[baseTagSet]
1787
+
1788
+ if LOG:
1789
+ LOG('value decoder chosen by base %s' % (baseTagSet,))
1790
+
1791
+ except KeyError:
1792
+ concreteDecoder = None
1793
+
1794
+ if concreteDecoder:
1795
+ asn1Spec = chosenSpec
1796
+ state = stDecodeValue
1797
+
1798
+ else:
1799
+ state = stTryAsExplicitTag
1800
+
1801
+ else:
1802
+ concreteDecoder = None
1803
+ state = stTryAsExplicitTag
1804
+
1805
+ if LOG:
1806
+ LOG('codec %s chosen by ASN.1 spec, decoding %s' % (state is stDecodeValue and concreteDecoder.__class__.__name__ or "<none>", state is stDecodeValue and 'value' or 'as explicit tag'))
1807
+ debug.scope.push(chosenSpec is None and '?' or chosenSpec.__class__.__name__)
1808
+
1809
+ if state is stDecodeValue:
1810
+ if not options.get('recursiveFlag', True) and not substrateFun: # deprecate this
1811
+ def substrateFun(asn1Object, _substrate, _length, _options):
1812
+ """Legacy hack to keep the recursiveFlag=False option supported.
1813
+
1814
+ The decode(..., substrateFun=userCallback) option was introduced in 0.1.4 as a generalization
1815
+ of the old recursiveFlag=False option. Users should pass their callback instead of using
1816
+ recursiveFlag.
1817
+ """
1818
+ yield asn1Object
1819
+
1820
+ original_position = substrate.tell()
1821
+
1822
+ if length == -1: # indef length
1823
+ for value in concreteDecoder.indefLenValueDecoder(
1824
+ substrate, asn1Spec,
1825
+ tagSet, length, stGetValueDecoder,
1826
+ self, substrateFun, **options):
1827
+ if isinstance(value, SubstrateUnderrunError):
1828
+ yield value
1829
+
1830
+ else:
1831
+ for value in concreteDecoder.valueDecoder(
1832
+ substrate, asn1Spec,
1833
+ tagSet, length, stGetValueDecoder,
1834
+ self, substrateFun, **options):
1835
+ if isinstance(value, SubstrateUnderrunError):
1836
+ yield value
1837
+
1838
+ bytesRead = substrate.tell() - original_position
1839
+ if not substrateFun and bytesRead != length:
1840
+ raise PyAsn1Error(
1841
+ "Read %s bytes instead of expected %s." % (bytesRead, length))
1842
+ elif substrateFun and bytesRead > length:
1843
+ # custom substrateFun may be used for partial decoding, reading less is expected there
1844
+ raise PyAsn1Error(
1845
+ "Read %s bytes are more than expected %s." % (bytesRead, length))
1846
+
1847
+ if LOG:
1848
+ LOG('codec %s yields type %s, value:\n%s\n...' % (
1849
+ concreteDecoder.__class__.__name__, value.__class__.__name__,
1850
+ isinstance(value, base.Asn1Item) and value.prettyPrint() or value))
1851
+
1852
+ state = stStop
1853
+ break
1854
+
1855
+ if state is stTryAsExplicitTag:
1856
+ if (tagSet and
1857
+ tagSet[0].tagFormat == tag.tagFormatConstructed and
1858
+ tagSet[0].tagClass != tag.tagClassUniversal):
1859
+ # Assume explicit tagging
1860
+ concreteDecoder = rawPayloadDecoder
1861
+ state = stDecodeValue
1862
+
1863
+ else:
1864
+ concreteDecoder = None
1865
+ state = self.defaultErrorState
1866
+
1867
+ if LOG:
1868
+ LOG('codec %s chosen, decoding %s' % (concreteDecoder and concreteDecoder.__class__.__name__ or "<none>", state is stDecodeValue and 'value' or 'as failure'))
1869
+
1870
+ if state is stDumpRawValue:
1871
+ concreteDecoder = self.defaultRawDecoder
1872
+
1873
+ if LOG:
1874
+ LOG('codec %s chosen, decoding value' % concreteDecoder.__class__.__name__)
1875
+
1876
+ state = stDecodeValue
1877
+
1878
+ if state is stErrorCondition:
1879
+ raise error.PyAsn1Error(
1880
+ '%s not in asn1Spec: %r' % (tagSet, asn1Spec)
1881
+ )
1882
+
1883
+ if LOG:
1884
+ debug.scope.pop()
1885
+ LOG('decoder left scope %s, call completed' % debug.scope)
1886
+
1887
+ yield value
1888
+
1889
+
1890
+ class StreamingDecoder(object):
1891
+ """Create an iterator that turns BER/CER/DER byte stream into ASN.1 objects.
1892
+
1893
+ On each iteration, consume whatever BER/CER/DER serialization is
1894
+ available in the `substrate` stream-like object and turns it into
1895
+ one or more, possibly nested, ASN.1 objects.
1896
+
1897
+ Parameters
1898
+ ----------
1899
+ substrate: :py:class:`file`, :py:class:`io.BytesIO`
1900
+ BER/CER/DER serialization in form of a byte stream
1901
+
1902
+ Keyword Args
1903
+ ------------
1904
+ asn1Spec: :py:class:`~pyasn1.type.base.PyAsn1Item`
1905
+ A pyasn1 type object to act as a template guiding the decoder.
1906
+ Depending on the ASN.1 structure being decoded, `asn1Spec` may
1907
+ or may not be required. One of the reasons why `asn1Spec` may
1908
+ me required is that ASN.1 structure is encoded in the *IMPLICIT*
1909
+ tagging mode.
1910
+
1911
+ Yields
1912
+ ------
1913
+ : :py:class:`~pyasn1.type.base.PyAsn1Item`, :py:class:`~pyasn1.error.SubstrateUnderrunError`
1914
+ Decoded ASN.1 object (possibly, nested) or
1915
+ :py:class:`~pyasn1.error.SubstrateUnderrunError` object indicating
1916
+ insufficient BER/CER/DER serialization on input to fully recover ASN.1
1917
+ objects from it.
1918
+
1919
+ In the latter case the caller is advised to ensure some more data in
1920
+ the input stream, then call the iterator again. The decoder will resume
1921
+ the decoding process using the newly arrived data.
1922
+
1923
+ The `context` property of :py:class:`~pyasn1.error.SubstrateUnderrunError`
1924
+ object might hold a reference to the partially populated ASN.1 object
1925
+ being reconstructed.
1926
+
1927
+ Raises
1928
+ ------
1929
+ ~pyasn1.error.PyAsn1Error, ~pyasn1.error.EndOfStreamError
1930
+ `PyAsn1Error` on deserialization error, `EndOfStreamError` on
1931
+ premature stream closure.
1932
+
1933
+ Examples
1934
+ --------
1935
+ Decode BER serialisation without ASN.1 schema
1936
+
1937
+ .. code-block:: pycon
1938
+
1939
+ >>> stream = io.BytesIO(
1940
+ ... b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03')
1941
+ >>>
1942
+ >>> for asn1Object in StreamingDecoder(stream):
1943
+ ... print(asn1Object)
1944
+ >>>
1945
+ SequenceOf:
1946
+ 1 2 3
1947
+
1948
+ Decode BER serialisation with ASN.1 schema
1949
+
1950
+ .. code-block:: pycon
1951
+
1952
+ >>> stream = io.BytesIO(
1953
+ ... b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03')
1954
+ >>>
1955
+ >>> schema = SequenceOf(componentType=Integer())
1956
+ >>>
1957
+ >>> decoder = StreamingDecoder(stream, asn1Spec=schema)
1958
+ >>> for asn1Object in decoder:
1959
+ ... print(asn1Object)
1960
+ >>>
1961
+ SequenceOf:
1962
+ 1 2 3
1963
+ """
1964
+
1965
+ SINGLE_ITEM_DECODER = SingleItemDecoder
1966
+
1967
+ def __init__(self, substrate, asn1Spec=None, **options):
1968
+ self._singleItemDecoder = self.SINGLE_ITEM_DECODER(**options)
1969
+ self._substrate = asSeekableStream(substrate)
1970
+ self._asn1Spec = asn1Spec
1971
+ self._options = options
1972
+
1973
+ def __iter__(self):
1974
+ while True:
1975
+ for asn1Object in self._singleItemDecoder(
1976
+ self._substrate, self._asn1Spec, **self._options):
1977
+ yield asn1Object
1978
+
1979
+ for chunk in isEndOfStream(self._substrate):
1980
+ if isinstance(chunk, SubstrateUnderrunError):
1981
+ yield
1982
+
1983
+ break
1984
+
1985
+ if chunk:
1986
+ break
1987
+
1988
+
1989
+ class Decoder(object):
1990
+ """Create a BER decoder object.
1991
+
1992
+ Parse BER/CER/DER octet-stream into one, possibly nested, ASN.1 object.
1993
+ """
1994
+ STREAMING_DECODER = StreamingDecoder
1995
+
1996
+ @classmethod
1997
+ def __call__(cls, substrate, asn1Spec=None, **options):
1998
+ """Turns BER/CER/DER octet stream into an ASN.1 object.
1999
+
2000
+ Takes BER/CER/DER octet-stream in form of :py:class:`bytes`
2001
+ and decode it into an ASN.1 object
2002
+ (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative) which
2003
+ may be a scalar or an arbitrary nested structure.
2004
+
2005
+ Parameters
2006
+ ----------
2007
+ substrate: :py:class:`bytes`
2008
+ BER/CER/DER octet-stream to parse
2009
+
2010
+ Keyword Args
2011
+ ------------
2012
+ asn1Spec: :py:class:`~pyasn1.type.base.PyAsn1Item`
2013
+ A pyasn1 type object (:py:class:`~pyasn1.type.base.PyAsn1Item`
2014
+ derivative) to act as a template guiding the decoder.
2015
+ Depending on the ASN.1 structure being decoded, `asn1Spec` may or
2016
+ may not be required. Most common reason for it to require is that
2017
+ ASN.1 structure is encoded in *IMPLICIT* tagging mode.
2018
+
2019
+ substrateFun: :py:class:`Union[
2020
+ Callable[[pyasn1.type.base.PyAsn1Item, bytes, int],
2021
+ Tuple[pyasn1.type.base.PyAsn1Item, bytes]],
2022
+ Callable[[pyasn1.type.base.PyAsn1Item, io.BytesIO, int, dict],
2023
+ Generator[Union[pyasn1.type.base.PyAsn1Item,
2024
+ pyasn1.error.SubstrateUnderrunError],
2025
+ None, None]]
2026
+ ]`
2027
+ User callback meant to generalize special use cases like non-recursive or
2028
+ partial decoding. A 3-arg non-streaming variant is supported for backwards
2029
+ compatiblilty in addition to the newer 4-arg streaming variant.
2030
+ The callback will receive the uninitialized object recovered from substrate
2031
+ as 1st argument, the uninterpreted payload as 2nd argument, and the length
2032
+ of the uninterpreted payload as 3rd argument. The streaming variant will
2033
+ additionally receive the decode(..., **options) kwargs as 4th argument.
2034
+ The non-streaming variant shall return an object that will be propagated
2035
+ as decode() return value as 1st item, and the remainig payload for further
2036
+ decode passes as 2nd item.
2037
+ The streaming variant shall yield an object that will be propagated as
2038
+ decode() return value, and leave the remaining payload in the stream.
2039
+
2040
+ Returns
2041
+ -------
2042
+ : :py:class:`tuple`
2043
+ A tuple of :py:class:`~pyasn1.type.base.PyAsn1Item` object
2044
+ recovered from BER/CER/DER substrate and the unprocessed trailing
2045
+ portion of the `substrate` (may be empty)
2046
+
2047
+ Raises
2048
+ ------
2049
+ : :py:class:`~pyasn1.error.PyAsn1Error`
2050
+ :py:class:`~pyasn1.error.SubstrateUnderrunError` on insufficient
2051
+ input or :py:class:`~pyasn1.error.PyAsn1Error` on decoding error.
2052
+
2053
+ Examples
2054
+ --------
2055
+ Decode BER/CER/DER serialisation without ASN.1 schema
2056
+
2057
+ .. code-block:: pycon
2058
+
2059
+ >>> s, unprocessed = decode(b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03')
2060
+ >>> str(s)
2061
+ SequenceOf:
2062
+ 1 2 3
2063
+
2064
+ Decode BER/CER/DER serialisation with ASN.1 schema
2065
+
2066
+ .. code-block:: pycon
2067
+
2068
+ >>> seq = SequenceOf(componentType=Integer())
2069
+ >>> s, unprocessed = decode(
2070
+ b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03', asn1Spec=seq)
2071
+ >>> str(s)
2072
+ SequenceOf:
2073
+ 1 2 3
2074
+
2075
+ """
2076
+ substrate = asSeekableStream(substrate)
2077
+
2078
+ if "substrateFun" in options:
2079
+ origSubstrateFun = options["substrateFun"]
2080
+
2081
+ def substrateFunWrapper(asn1Object, substrate, length, options=None):
2082
+ """Support both 0.4 and 0.5 style APIs.
2083
+
2084
+ substrateFun API has changed in 0.5 for use with streaming decoders. To stay backwards compatible,
2085
+ we first try if we received a streaming user callback. If that fails,we assume we've received a
2086
+ non-streaming v0.4 user callback and convert it for streaming on the fly
2087
+ """
2088
+ try:
2089
+ substrate_gen = origSubstrateFun(asn1Object, substrate, length, options)
2090
+ except TypeError as _value:
2091
+ if _value.__traceback__.tb_next:
2092
+ # Traceback depth > 1 means TypeError from inside user provided function
2093
+ raise
2094
+ # invariant maintained at Decoder.__call__ entry
2095
+ assert isinstance(substrate, io.BytesIO) # nosec assert_used
2096
+ substrate_gen = Decoder._callSubstrateFunV4asV5(origSubstrateFun, asn1Object, substrate, length)
2097
+ for value in substrate_gen:
2098
+ yield value
2099
+
2100
+ options["substrateFun"] = substrateFunWrapper
2101
+
2102
+ streamingDecoder = cls.STREAMING_DECODER(
2103
+ substrate, asn1Spec, **options)
2104
+
2105
+ for asn1Object in streamingDecoder:
2106
+ if isinstance(asn1Object, SubstrateUnderrunError):
2107
+ raise error.SubstrateUnderrunError('Short substrate on input')
2108
+
2109
+ try:
2110
+ tail = next(readFromStream(substrate))
2111
+
2112
+ except error.EndOfStreamError:
2113
+ tail = b''
2114
+
2115
+ return asn1Object, tail
2116
+
2117
+ @staticmethod
2118
+ def _callSubstrateFunV4asV5(substrateFunV4, asn1Object, substrate, length):
2119
+ substrate_bytes = substrate.read()
2120
+ if length == -1:
2121
+ length = len(substrate_bytes)
2122
+ value, nextSubstrate = substrateFunV4(asn1Object, substrate_bytes, length)
2123
+ nbytes = substrate.write(nextSubstrate)
2124
+ substrate.truncate()
2125
+ substrate.seek(-nbytes, os.SEEK_CUR)
2126
+ yield value
2127
+
2128
+ #: Turns BER octet stream into an ASN.1 object.
2129
+ #:
2130
+ #: Takes BER octet-stream and decode it into an ASN.1 object
2131
+ #: (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative) which
2132
+ #: may be a scalar or an arbitrary nested structure.
2133
+ #:
2134
+ #: Parameters
2135
+ #: ----------
2136
+ #: substrate: :py:class:`bytes`
2137
+ #: BER octet-stream
2138
+ #:
2139
+ #: Keyword Args
2140
+ #: ------------
2141
+ #: asn1Spec: any pyasn1 type object e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative
2142
+ #: A pyasn1 type object to act as a template guiding the decoder. Depending on the ASN.1 structure
2143
+ #: being decoded, *asn1Spec* may or may not be required. Most common reason for
2144
+ #: it to require is that ASN.1 structure is encoded in *IMPLICIT* tagging mode.
2145
+ #:
2146
+ #: Returns
2147
+ #: -------
2148
+ #: : :py:class:`tuple`
2149
+ #: A tuple of pyasn1 object recovered from BER substrate (:py:class:`~pyasn1.type.base.PyAsn1Item` derivative)
2150
+ #: and the unprocessed trailing portion of the *substrate* (may be empty)
2151
+ #:
2152
+ #: Raises
2153
+ #: ------
2154
+ #: ~pyasn1.error.PyAsn1Error, ~pyasn1.error.SubstrateUnderrunError
2155
+ #: On decoding errors
2156
+ #:
2157
+ #: Notes
2158
+ #: -----
2159
+ #: This function is deprecated. Please use :py:class:`Decoder` or
2160
+ #: :py:class:`StreamingDecoder` class instance.
2161
+ #:
2162
+ #: Examples
2163
+ #: --------
2164
+ #: Decode BER serialisation without ASN.1 schema
2165
+ #:
2166
+ #: .. code-block:: pycon
2167
+ #:
2168
+ #: >>> s, _ = decode(b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03')
2169
+ #: >>> str(s)
2170
+ #: SequenceOf:
2171
+ #: 1 2 3
2172
+ #:
2173
+ #: Decode BER serialisation with ASN.1 schema
2174
+ #:
2175
+ #: .. code-block:: pycon
2176
+ #:
2177
+ #: >>> seq = SequenceOf(componentType=Integer())
2178
+ #: >>> s, _ = decode(b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03', asn1Spec=seq)
2179
+ #: >>> str(s)
2180
+ #: SequenceOf:
2181
+ #: 1 2 3
2182
+ #:
2183
+ decode = Decoder()
2184
+
2185
+ def __getattr__(attr: str):
2186
+ if newAttr := {"tagMap": "TAG_MAP", "typeMap": "TYPE_MAP"}.get(attr):
2187
+ warnings.warn(f"{attr} is deprecated. Please use {newAttr} instead.", DeprecationWarning)
2188
+ return globals()[newAttr]
2189
+ raise AttributeError(attr)
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/encoder.py ADDED
@@ -0,0 +1,954 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # This file is part of pyasn1 software.
3
+ #
4
+ # Copyright (c) 2005-2020, Ilya Etingof <etingof@gmail.com>
5
+ # License: https://pyasn1.readthedocs.io/en/latest/license.html
6
+ #
7
+ import sys
8
+ import warnings
9
+
10
+ from pyasn1 import debug
11
+ from pyasn1 import error
12
+ from pyasn1.codec.ber import eoo
13
+ from pyasn1.compat import _MISSING
14
+ from pyasn1.compat.integer import to_bytes
15
+ from pyasn1.type import char
16
+ from pyasn1.type import tag
17
+ from pyasn1.type import univ
18
+ from pyasn1.type import useful
19
+
20
+ __all__ = ['Encoder', 'encode']
21
+
22
+ LOG = debug.registerLoggee(__name__, flags=debug.DEBUG_ENCODER)
23
+
24
+
25
+ class AbstractItemEncoder(object):
26
+ supportIndefLenMode = True
27
+
28
+ # An outcome of otherwise legit call `encodeFun(eoo.endOfOctets)`
29
+ eooIntegerSubstrate = (0, 0)
30
+ eooOctetsSubstrate = bytes(eooIntegerSubstrate)
31
+
32
+ # noinspection PyMethodMayBeStatic
33
+ def encodeTag(self, singleTag, isConstructed):
34
+ tagClass, tagFormat, tagId = singleTag
35
+ encodedTag = tagClass | tagFormat
36
+ if isConstructed:
37
+ encodedTag |= tag.tagFormatConstructed
38
+
39
+ if tagId < 31:
40
+ return encodedTag | tagId,
41
+
42
+ else:
43
+ substrate = tagId & 0x7f,
44
+
45
+ tagId >>= 7
46
+
47
+ while tagId:
48
+ substrate = (0x80 | (tagId & 0x7f),) + substrate
49
+ tagId >>= 7
50
+
51
+ return (encodedTag | 0x1F,) + substrate
52
+
53
+ def encodeLength(self, length, defMode):
54
+ if not defMode and self.supportIndefLenMode:
55
+ return (0x80,)
56
+
57
+ if length < 0x80:
58
+ return length,
59
+
60
+ else:
61
+ substrate = ()
62
+ while length:
63
+ substrate = (length & 0xff,) + substrate
64
+ length >>= 8
65
+
66
+ substrateLen = len(substrate)
67
+
68
+ if substrateLen > 126:
69
+ raise error.PyAsn1Error('Length octets overflow (%d)' % substrateLen)
70
+
71
+ return (0x80 | substrateLen,) + substrate
72
+
73
+ def encodeValue(self, value, asn1Spec, encodeFun, **options):
74
+ raise error.PyAsn1Error('Not implemented')
75
+
76
+ def encode(self, value, asn1Spec=None, encodeFun=None, **options):
77
+
78
+ if asn1Spec is None:
79
+ tagSet = value.tagSet
80
+ else:
81
+ tagSet = asn1Spec.tagSet
82
+
83
+ # untagged item?
84
+ if not tagSet:
85
+ substrate, isConstructed, isOctets = self.encodeValue(
86
+ value, asn1Spec, encodeFun, **options
87
+ )
88
+ return substrate
89
+
90
+ defMode = options.get('defMode', True)
91
+
92
+ substrate = b''
93
+
94
+ for idx, singleTag in enumerate(tagSet.superTags):
95
+
96
+ defModeOverride = defMode
97
+
98
+ # base tag?
99
+ if not idx:
100
+ try:
101
+ substrate, isConstructed, isOctets = self.encodeValue(
102
+ value, asn1Spec, encodeFun, **options
103
+ )
104
+
105
+ except error.PyAsn1Error as exc:
106
+ raise error.PyAsn1Error(
107
+ 'Error encoding %r: %s' % (value, exc))
108
+
109
+ if LOG:
110
+ LOG('encoded %svalue %s into %s' % (
111
+ isConstructed and 'constructed ' or '', value, substrate
112
+ ))
113
+
114
+ if not substrate and isConstructed and options.get('ifNotEmpty', False):
115
+ return substrate
116
+
117
+ if not isConstructed:
118
+ defModeOverride = True
119
+
120
+ if LOG:
121
+ LOG('overridden encoding mode into definitive for primitive type')
122
+
123
+ header = self.encodeTag(singleTag, isConstructed)
124
+
125
+ if LOG:
126
+ LOG('encoded %stag %s into %s' % (
127
+ isConstructed and 'constructed ' or '',
128
+ singleTag, debug.hexdump(bytes(header))))
129
+
130
+ header += self.encodeLength(len(substrate), defModeOverride)
131
+
132
+ if LOG:
133
+ LOG('encoded %s octets (tag + payload) into %s' % (
134
+ len(substrate), debug.hexdump(bytes(header))))
135
+
136
+ if isOctets:
137
+ substrate = bytes(header) + substrate
138
+
139
+ if not defModeOverride:
140
+ substrate += self.eooOctetsSubstrate
141
+
142
+ else:
143
+ substrate = header + substrate
144
+
145
+ if not defModeOverride:
146
+ substrate += self.eooIntegerSubstrate
147
+
148
+ if not isOctets:
149
+ substrate = bytes(substrate)
150
+
151
+ return substrate
152
+
153
+
154
+ class EndOfOctetsEncoder(AbstractItemEncoder):
155
+ def encodeValue(self, value, asn1Spec, encodeFun, **options):
156
+ return b'', False, True
157
+
158
+
159
+ class BooleanEncoder(AbstractItemEncoder):
160
+ supportIndefLenMode = False
161
+
162
+ def encodeValue(self, value, asn1Spec, encodeFun, **options):
163
+ return value and (1,) or (0,), False, False
164
+
165
+
166
+ class IntegerEncoder(AbstractItemEncoder):
167
+ supportIndefLenMode = False
168
+ supportCompactZero = False
169
+
170
+ def encodeValue(self, value, asn1Spec, encodeFun, **options):
171
+ if value == 0:
172
+ if LOG:
173
+ LOG('encoding %spayload for zero INTEGER' % (
174
+ self.supportCompactZero and 'no ' or ''
175
+ ))
176
+
177
+ # de-facto way to encode zero
178
+ if self.supportCompactZero:
179
+ return (), False, False
180
+ else:
181
+ return (0,), False, False
182
+
183
+ return to_bytes(int(value), signed=True), False, True
184
+
185
+
186
+ class BitStringEncoder(AbstractItemEncoder):
187
+ def encodeValue(self, value, asn1Spec, encodeFun, **options):
188
+ if asn1Spec is not None:
189
+ # TODO: try to avoid ASN.1 schema instantiation
190
+ value = asn1Spec.clone(value)
191
+
192
+ valueLength = len(value)
193
+ if valueLength % 8:
194
+ alignedValue = value << (8 - valueLength % 8)
195
+ else:
196
+ alignedValue = value
197
+
198
+ maxChunkSize = options.get('maxChunkSize', 0)
199
+ if not maxChunkSize or len(alignedValue) <= maxChunkSize * 8:
200
+ substrate = alignedValue.asOctets()
201
+ return bytes((len(substrate) * 8 - valueLength,)) + substrate, False, True
202
+
203
+ if LOG:
204
+ LOG('encoding into up to %s-octet chunks' % maxChunkSize)
205
+
206
+ baseTag = value.tagSet.baseTag
207
+
208
+ # strip off explicit tags
209
+ if baseTag:
210
+ tagSet = tag.TagSet(baseTag, baseTag)
211
+
212
+ else:
213
+ tagSet = tag.TagSet()
214
+
215
+ alignedValue = alignedValue.clone(tagSet=tagSet)
216
+
217
+ stop = 0
218
+ substrate = b''
219
+ while stop < valueLength:
220
+ start = stop
221
+ stop = min(start + maxChunkSize * 8, valueLength)
222
+ substrate += encodeFun(alignedValue[start:stop], asn1Spec, **options)
223
+
224
+ return substrate, True, True
225
+
226
+
227
+ class OctetStringEncoder(AbstractItemEncoder):
228
+
229
+ def encodeValue(self, value, asn1Spec, encodeFun, **options):
230
+
231
+ if asn1Spec is None:
232
+ substrate = value.asOctets()
233
+
234
+ elif not isinstance(value, bytes):
235
+ substrate = asn1Spec.clone(value).asOctets()
236
+
237
+ else:
238
+ substrate = value
239
+
240
+ maxChunkSize = options.get('maxChunkSize', 0)
241
+
242
+ if not maxChunkSize or len(substrate) <= maxChunkSize:
243
+ return substrate, False, True
244
+
245
+ if LOG:
246
+ LOG('encoding into up to %s-octet chunks' % maxChunkSize)
247
+
248
+ # strip off explicit tags for inner chunks
249
+
250
+ if asn1Spec is None:
251
+ baseTag = value.tagSet.baseTag
252
+
253
+ # strip off explicit tags
254
+ if baseTag:
255
+ tagSet = tag.TagSet(baseTag, baseTag)
256
+
257
+ else:
258
+ tagSet = tag.TagSet()
259
+
260
+ asn1Spec = value.clone(tagSet=tagSet)
261
+
262
+ elif not isinstance(value, bytes):
263
+ baseTag = asn1Spec.tagSet.baseTag
264
+
265
+ # strip off explicit tags
266
+ if baseTag:
267
+ tagSet = tag.TagSet(baseTag, baseTag)
268
+
269
+ else:
270
+ tagSet = tag.TagSet()
271
+
272
+ asn1Spec = asn1Spec.clone(tagSet=tagSet)
273
+
274
+ pos = 0
275
+ substrate = b''
276
+
277
+ while True:
278
+ chunk = value[pos:pos + maxChunkSize]
279
+ if not chunk:
280
+ break
281
+
282
+ substrate += encodeFun(chunk, asn1Spec, **options)
283
+ pos += maxChunkSize
284
+
285
+ return substrate, True, True
286
+
287
+
288
+ class NullEncoder(AbstractItemEncoder):
289
+ supportIndefLenMode = False
290
+
291
+ def encodeValue(self, value, asn1Spec, encodeFun, **options):
292
+ return b'', False, True
293
+
294
+
295
+ class ObjectIdentifierEncoder(AbstractItemEncoder):
296
+ supportIndefLenMode = False
297
+
298
+ def encodeValue(self, value, asn1Spec, encodeFun, **options):
299
+ if asn1Spec is not None:
300
+ value = asn1Spec.clone(value)
301
+
302
+ oid = value.asTuple()
303
+
304
+ # Build the first pair
305
+ try:
306
+ first = oid[0]
307
+ second = oid[1]
308
+
309
+ except IndexError:
310
+ raise error.PyAsn1Error('Short OID %s' % (value,))
311
+
312
+ if 0 <= second <= 39:
313
+ if first == 1:
314
+ oid = (second + 40,) + oid[2:]
315
+ elif first == 0:
316
+ oid = (second,) + oid[2:]
317
+ elif first == 2:
318
+ oid = (second + 80,) + oid[2:]
319
+ else:
320
+ raise error.PyAsn1Error('Impossible first/second arcs at %s' % (value,))
321
+
322
+ elif first == 2:
323
+ oid = (second + 80,) + oid[2:]
324
+
325
+ else:
326
+ raise error.PyAsn1Error('Impossible first/second arcs at %s' % (value,))
327
+
328
+ octets = ()
329
+
330
+ # Cycle through subIds
331
+ for subOid in oid:
332
+ if 0 <= subOid <= 127:
333
+ # Optimize for the common case
334
+ octets += (subOid,)
335
+
336
+ elif subOid > 127:
337
+ # Pack large Sub-Object IDs
338
+ res = (subOid & 0x7f,)
339
+ subOid >>= 7
340
+
341
+ while subOid:
342
+ res = (0x80 | (subOid & 0x7f),) + res
343
+ subOid >>= 7
344
+
345
+ # Add packed Sub-Object ID to resulted Object ID
346
+ octets += res
347
+
348
+ else:
349
+ raise error.PyAsn1Error('Negative OID arc %s at %s' % (subOid, value))
350
+
351
+ return octets, False, False
352
+
353
+
354
+ class RelativeOIDEncoder(AbstractItemEncoder):
355
+ supportIndefLenMode = False
356
+
357
+ def encodeValue(self, value, asn1Spec, encodeFun, **options):
358
+ if asn1Spec is not None:
359
+ value = asn1Spec.clone(value)
360
+
361
+ octets = ()
362
+
363
+ # Cycle through subIds
364
+ for subOid in value.asTuple():
365
+ if 0 <= subOid <= 127:
366
+ # Optimize for the common case
367
+ octets += (subOid,)
368
+
369
+ elif subOid > 127:
370
+ # Pack large Sub-Object IDs
371
+ res = (subOid & 0x7f,)
372
+ subOid >>= 7
373
+
374
+ while subOid:
375
+ res = (0x80 | (subOid & 0x7f),) + res
376
+ subOid >>= 7
377
+
378
+ # Add packed Sub-Object ID to resulted RELATIVE-OID
379
+ octets += res
380
+
381
+ else:
382
+ raise error.PyAsn1Error('Negative RELATIVE-OID arc %s at %s' % (subOid, value))
383
+
384
+ return octets, False, False
385
+
386
+
387
+ class RealEncoder(AbstractItemEncoder):
388
+ supportIndefLenMode = False
389
+ binEncBase = 2 # set to None to choose encoding base automatically
390
+
391
+ @staticmethod
392
+ def _dropFloatingPoint(m, encbase, e):
393
+ ms, es = 1, 1
394
+ if m < 0:
395
+ ms = -1 # mantissa sign
396
+
397
+ if e < 0:
398
+ es = -1 # exponent sign
399
+
400
+ m *= ms
401
+
402
+ if encbase == 8:
403
+ m *= 2 ** (abs(e) % 3 * es)
404
+ e = abs(e) // 3 * es
405
+
406
+ elif encbase == 16:
407
+ m *= 2 ** (abs(e) % 4 * es)
408
+ e = abs(e) // 4 * es
409
+
410
+ while True:
411
+ if int(m) != m:
412
+ m *= encbase
413
+ e -= 1
414
+ continue
415
+ break
416
+
417
+ return ms, int(m), encbase, e
418
+
419
+ def _chooseEncBase(self, value):
420
+ m, b, e = value
421
+ encBase = [2, 8, 16]
422
+ if value.binEncBase in encBase:
423
+ return self._dropFloatingPoint(m, value.binEncBase, e)
424
+
425
+ elif self.binEncBase in encBase:
426
+ return self._dropFloatingPoint(m, self.binEncBase, e)
427
+
428
+ # auto choosing base 2/8/16
429
+ mantissa = [m, m, m]
430
+ exponent = [e, e, e]
431
+ sign = 1
432
+ encbase = 2
433
+ e = float('inf')
434
+
435
+ for i in range(3):
436
+ (sign,
437
+ mantissa[i],
438
+ encBase[i],
439
+ exponent[i]) = self._dropFloatingPoint(mantissa[i], encBase[i], exponent[i])
440
+
441
+ if abs(exponent[i]) < abs(e) or (abs(exponent[i]) == abs(e) and mantissa[i] < m):
442
+ e = exponent[i]
443
+ m = int(mantissa[i])
444
+ encbase = encBase[i]
445
+
446
+ if LOG:
447
+ LOG('automatically chosen REAL encoding base %s, sign %s, mantissa %s, '
448
+ 'exponent %s' % (encbase, sign, m, e))
449
+
450
+ return sign, m, encbase, e
451
+
452
+ def encodeValue(self, value, asn1Spec, encodeFun, **options):
453
+ if asn1Spec is not None:
454
+ value = asn1Spec.clone(value)
455
+
456
+ if value.isPlusInf:
457
+ return (0x40,), False, False
458
+
459
+ if value.isMinusInf:
460
+ return (0x41,), False, False
461
+
462
+ m, b, e = value
463
+
464
+ if not m:
465
+ return b'', False, True
466
+
467
+ if b == 10:
468
+ if LOG:
469
+ LOG('encoding REAL into character form')
470
+
471
+ return b'\x03%dE%s%d' % (m, e == 0 and b'+' or b'', e), False, True
472
+
473
+ elif b == 2:
474
+ fo = 0x80 # binary encoding
475
+ ms, m, encbase, e = self._chooseEncBase(value)
476
+
477
+ if ms < 0: # mantissa sign
478
+ fo |= 0x40 # sign bit
479
+
480
+ # exponent & mantissa normalization
481
+ if encbase == 2:
482
+ while m & 0x1 == 0:
483
+ m >>= 1
484
+ e += 1
485
+
486
+ elif encbase == 8:
487
+ while m & 0x7 == 0:
488
+ m >>= 3
489
+ e += 1
490
+ fo |= 0x10
491
+
492
+ else: # encbase = 16
493
+ while m & 0xf == 0:
494
+ m >>= 4
495
+ e += 1
496
+ fo |= 0x20
497
+
498
+ sf = 0 # scale factor
499
+
500
+ while m & 0x1 == 0:
501
+ m >>= 1
502
+ sf += 1
503
+
504
+ if sf > 3:
505
+ raise error.PyAsn1Error('Scale factor overflow') # bug if raised
506
+
507
+ fo |= sf << 2
508
+ eo = b''
509
+ if e == 0 or e == -1:
510
+ eo = bytes((e & 0xff,))
511
+
512
+ else:
513
+ while e not in (0, -1):
514
+ eo = bytes((e & 0xff,)) + eo
515
+ e >>= 8
516
+
517
+ if e == 0 and eo and eo[0] & 0x80:
518
+ eo = bytes((0,)) + eo
519
+
520
+ if e == -1 and eo and not (eo[0] & 0x80):
521
+ eo = bytes((0xff,)) + eo
522
+
523
+ n = len(eo)
524
+ if n > 0xff:
525
+ raise error.PyAsn1Error('Real exponent overflow')
526
+
527
+ if n == 1:
528
+ pass
529
+
530
+ elif n == 2:
531
+ fo |= 1
532
+
533
+ elif n == 3:
534
+ fo |= 2
535
+
536
+ else:
537
+ fo |= 3
538
+ eo = bytes((n & 0xff,)) + eo
539
+
540
+ po = b''
541
+
542
+ while m:
543
+ po = bytes((m & 0xff,)) + po
544
+ m >>= 8
545
+
546
+ substrate = bytes((fo,)) + eo + po
547
+
548
+ return substrate, False, True
549
+
550
+ else:
551
+ raise error.PyAsn1Error('Prohibited Real base %s' % b)
552
+
553
+
554
+ class SequenceEncoder(AbstractItemEncoder):
555
+ omitEmptyOptionals = False
556
+
557
+ # TODO: handling three flavors of input is too much -- split over codecs
558
+
559
+ def encodeValue(self, value, asn1Spec, encodeFun, **options):
560
+
561
+ substrate = b''
562
+
563
+ omitEmptyOptionals = options.get(
564
+ 'omitEmptyOptionals', self.omitEmptyOptionals)
565
+
566
+ if LOG:
567
+ LOG('%sencoding empty OPTIONAL components' % (
568
+ omitEmptyOptionals and 'not ' or ''))
569
+
570
+ if asn1Spec is None:
571
+ # instance of ASN.1 schema
572
+ inconsistency = value.isInconsistent
573
+ if inconsistency:
574
+ raise error.PyAsn1Error(
575
+ f"ASN.1 object {value.__class__.__name__} is inconsistent")
576
+
577
+ namedTypes = value.componentType
578
+
579
+ for idx, component in enumerate(value.values()):
580
+ if namedTypes:
581
+ namedType = namedTypes[idx]
582
+
583
+ if namedType.isOptional and not component.isValue:
584
+ if LOG:
585
+ LOG('not encoding OPTIONAL component %r' % (namedType,))
586
+ continue
587
+
588
+ if namedType.isDefaulted and component == namedType.asn1Object:
589
+ if LOG:
590
+ LOG('not encoding DEFAULT component %r' % (namedType,))
591
+ continue
592
+
593
+ if omitEmptyOptionals:
594
+ options.update(ifNotEmpty=namedType.isOptional)
595
+
596
+ # wrap open type blob if needed
597
+ if namedTypes and namedType.openType:
598
+
599
+ wrapType = namedType.asn1Object
600
+
601
+ if wrapType.typeId in (
602
+ univ.SetOf.typeId, univ.SequenceOf.typeId):
603
+
604
+ substrate += encodeFun(
605
+ component, asn1Spec,
606
+ **dict(options, wrapType=wrapType.componentType))
607
+
608
+ else:
609
+ chunk = encodeFun(component, asn1Spec, **options)
610
+
611
+ if wrapType.isSameTypeWith(component):
612
+ substrate += chunk
613
+
614
+ else:
615
+ substrate += encodeFun(chunk, wrapType, **options)
616
+
617
+ if LOG:
618
+ LOG('wrapped with wrap type %r' % (wrapType,))
619
+
620
+ else:
621
+ substrate += encodeFun(component, asn1Spec, **options)
622
+
623
+ else:
624
+ # bare Python value + ASN.1 schema
625
+ for idx, namedType in enumerate(asn1Spec.componentType.namedTypes):
626
+
627
+ try:
628
+ component = value[namedType.name]
629
+
630
+ except KeyError:
631
+ raise error.PyAsn1Error('Component name "%s" not found in %r' % (
632
+ namedType.name, value))
633
+
634
+ if namedType.isOptional and namedType.name not in value:
635
+ if LOG:
636
+ LOG('not encoding OPTIONAL component %r' % (namedType,))
637
+ continue
638
+
639
+ if namedType.isDefaulted and component == namedType.asn1Object:
640
+ if LOG:
641
+ LOG('not encoding DEFAULT component %r' % (namedType,))
642
+ continue
643
+
644
+ if omitEmptyOptionals:
645
+ options.update(ifNotEmpty=namedType.isOptional)
646
+
647
+ componentSpec = namedType.asn1Object
648
+
649
+ # wrap open type blob if needed
650
+ if namedType.openType:
651
+
652
+ if componentSpec.typeId in (
653
+ univ.SetOf.typeId, univ.SequenceOf.typeId):
654
+
655
+ substrate += encodeFun(
656
+ component, componentSpec,
657
+ **dict(options, wrapType=componentSpec.componentType))
658
+
659
+ else:
660
+ chunk = encodeFun(component, componentSpec, **options)
661
+
662
+ if componentSpec.isSameTypeWith(component):
663
+ substrate += chunk
664
+
665
+ else:
666
+ substrate += encodeFun(chunk, componentSpec, **options)
667
+
668
+ if LOG:
669
+ LOG('wrapped with wrap type %r' % (componentSpec,))
670
+
671
+ else:
672
+ substrate += encodeFun(component, componentSpec, **options)
673
+
674
+ return substrate, True, True
675
+
676
+
677
+ class SequenceOfEncoder(AbstractItemEncoder):
678
+ def _encodeComponents(self, value, asn1Spec, encodeFun, **options):
679
+
680
+ if asn1Spec is None:
681
+ inconsistency = value.isInconsistent
682
+ if inconsistency:
683
+ raise error.PyAsn1Error(
684
+ f"ASN.1 object {value.__class__.__name__} is inconsistent")
685
+
686
+ else:
687
+ asn1Spec = asn1Spec.componentType
688
+
689
+ chunks = []
690
+
691
+ wrapType = options.pop('wrapType', None)
692
+
693
+ for idx, component in enumerate(value):
694
+ chunk = encodeFun(component, asn1Spec, **options)
695
+
696
+ if (wrapType is not None and
697
+ not wrapType.isSameTypeWith(component)):
698
+ # wrap encoded value with wrapper container (e.g. ANY)
699
+ chunk = encodeFun(chunk, wrapType, **options)
700
+
701
+ if LOG:
702
+ LOG('wrapped with wrap type %r' % (wrapType,))
703
+
704
+ chunks.append(chunk)
705
+
706
+ return chunks
707
+
708
+ def encodeValue(self, value, asn1Spec, encodeFun, **options):
709
+ chunks = self._encodeComponents(
710
+ value, asn1Spec, encodeFun, **options)
711
+
712
+ return b''.join(chunks), True, True
713
+
714
+
715
+ class ChoiceEncoder(AbstractItemEncoder):
716
+ def encodeValue(self, value, asn1Spec, encodeFun, **options):
717
+ if asn1Spec is None:
718
+ component = value.getComponent()
719
+ else:
720
+ names = [namedType.name for namedType in asn1Spec.componentType.namedTypes
721
+ if namedType.name in value]
722
+ if len(names) != 1:
723
+ raise error.PyAsn1Error('%s components for Choice at %r' % (len(names) and 'Multiple ' or 'None ', value))
724
+
725
+ name = names[0]
726
+
727
+ component = value[name]
728
+ asn1Spec = asn1Spec[name]
729
+
730
+ return encodeFun(component, asn1Spec, **options), True, True
731
+
732
+
733
+ class AnyEncoder(OctetStringEncoder):
734
+ def encodeValue(self, value, asn1Spec, encodeFun, **options):
735
+ if asn1Spec is None:
736
+ value = value.asOctets()
737
+ elif not isinstance(value, bytes):
738
+ value = asn1Spec.clone(value).asOctets()
739
+
740
+ return value, not options.get('defMode', True), True
741
+
742
+
743
+ TAG_MAP = {
744
+ eoo.endOfOctets.tagSet: EndOfOctetsEncoder(),
745
+ univ.Boolean.tagSet: BooleanEncoder(),
746
+ univ.Integer.tagSet: IntegerEncoder(),
747
+ univ.BitString.tagSet: BitStringEncoder(),
748
+ univ.OctetString.tagSet: OctetStringEncoder(),
749
+ univ.Null.tagSet: NullEncoder(),
750
+ univ.ObjectIdentifier.tagSet: ObjectIdentifierEncoder(),
751
+ univ.RelativeOID.tagSet: RelativeOIDEncoder(),
752
+ univ.Enumerated.tagSet: IntegerEncoder(),
753
+ univ.Real.tagSet: RealEncoder(),
754
+ # Sequence & Set have same tags as SequenceOf & SetOf
755
+ univ.SequenceOf.tagSet: SequenceOfEncoder(),
756
+ univ.SetOf.tagSet: SequenceOfEncoder(),
757
+ univ.Choice.tagSet: ChoiceEncoder(),
758
+ # character string types
759
+ char.UTF8String.tagSet: OctetStringEncoder(),
760
+ char.NumericString.tagSet: OctetStringEncoder(),
761
+ char.PrintableString.tagSet: OctetStringEncoder(),
762
+ char.TeletexString.tagSet: OctetStringEncoder(),
763
+ char.VideotexString.tagSet: OctetStringEncoder(),
764
+ char.IA5String.tagSet: OctetStringEncoder(),
765
+ char.GraphicString.tagSet: OctetStringEncoder(),
766
+ char.VisibleString.tagSet: OctetStringEncoder(),
767
+ char.GeneralString.tagSet: OctetStringEncoder(),
768
+ char.UniversalString.tagSet: OctetStringEncoder(),
769
+ char.BMPString.tagSet: OctetStringEncoder(),
770
+ # useful types
771
+ useful.ObjectDescriptor.tagSet: OctetStringEncoder(),
772
+ useful.GeneralizedTime.tagSet: OctetStringEncoder(),
773
+ useful.UTCTime.tagSet: OctetStringEncoder()
774
+ }
775
+
776
+ # Put in ambiguous & non-ambiguous types for faster codec lookup
777
+ TYPE_MAP = {
778
+ univ.Boolean.typeId: BooleanEncoder(),
779
+ univ.Integer.typeId: IntegerEncoder(),
780
+ univ.BitString.typeId: BitStringEncoder(),
781
+ univ.OctetString.typeId: OctetStringEncoder(),
782
+ univ.Null.typeId: NullEncoder(),
783
+ univ.ObjectIdentifier.typeId: ObjectIdentifierEncoder(),
784
+ univ.RelativeOID.typeId: RelativeOIDEncoder(),
785
+ univ.Enumerated.typeId: IntegerEncoder(),
786
+ univ.Real.typeId: RealEncoder(),
787
+ # Sequence & Set have same tags as SequenceOf & SetOf
788
+ univ.Set.typeId: SequenceEncoder(),
789
+ univ.SetOf.typeId: SequenceOfEncoder(),
790
+ univ.Sequence.typeId: SequenceEncoder(),
791
+ univ.SequenceOf.typeId: SequenceOfEncoder(),
792
+ univ.Choice.typeId: ChoiceEncoder(),
793
+ univ.Any.typeId: AnyEncoder(),
794
+ # character string types
795
+ char.UTF8String.typeId: OctetStringEncoder(),
796
+ char.NumericString.typeId: OctetStringEncoder(),
797
+ char.PrintableString.typeId: OctetStringEncoder(),
798
+ char.TeletexString.typeId: OctetStringEncoder(),
799
+ char.VideotexString.typeId: OctetStringEncoder(),
800
+ char.IA5String.typeId: OctetStringEncoder(),
801
+ char.GraphicString.typeId: OctetStringEncoder(),
802
+ char.VisibleString.typeId: OctetStringEncoder(),
803
+ char.GeneralString.typeId: OctetStringEncoder(),
804
+ char.UniversalString.typeId: OctetStringEncoder(),
805
+ char.BMPString.typeId: OctetStringEncoder(),
806
+ # useful types
807
+ useful.ObjectDescriptor.typeId: OctetStringEncoder(),
808
+ useful.GeneralizedTime.typeId: OctetStringEncoder(),
809
+ useful.UTCTime.typeId: OctetStringEncoder()
810
+ }
811
+
812
+
813
+ class SingleItemEncoder(object):
814
+ fixedDefLengthMode = None
815
+ fixedChunkSize = None
816
+
817
+ TAG_MAP = TAG_MAP
818
+ TYPE_MAP = TYPE_MAP
819
+
820
+ def __init__(self, tagMap=_MISSING, typeMap=_MISSING, **ignored):
821
+ self._tagMap = tagMap if tagMap is not _MISSING else self.TAG_MAP
822
+ self._typeMap = typeMap if typeMap is not _MISSING else self.TYPE_MAP
823
+
824
+ def __call__(self, value, asn1Spec=None, **options):
825
+ try:
826
+ if asn1Spec is None:
827
+ typeId = value.typeId
828
+ else:
829
+ typeId = asn1Spec.typeId
830
+
831
+ except AttributeError:
832
+ raise error.PyAsn1Error('Value %r is not ASN.1 type instance '
833
+ 'and "asn1Spec" not given' % (value,))
834
+
835
+ if LOG:
836
+ LOG('encoder called in %sdef mode, chunk size %s for type %s, '
837
+ 'value:\n%s' % (not options.get('defMode', True) and 'in' or '',
838
+ options.get('maxChunkSize', 0),
839
+ asn1Spec is None and value.prettyPrintType() or
840
+ asn1Spec.prettyPrintType(), value))
841
+
842
+ if self.fixedDefLengthMode is not None:
843
+ options.update(defMode=self.fixedDefLengthMode)
844
+
845
+ if self.fixedChunkSize is not None:
846
+ options.update(maxChunkSize=self.fixedChunkSize)
847
+
848
+ try:
849
+ concreteEncoder = self._typeMap[typeId]
850
+
851
+ if LOG:
852
+ LOG('using value codec %s chosen by type ID '
853
+ '%s' % (concreteEncoder.__class__.__name__, typeId))
854
+
855
+ except KeyError:
856
+ if asn1Spec is None:
857
+ tagSet = value.tagSet
858
+ else:
859
+ tagSet = asn1Spec.tagSet
860
+
861
+ # use base type for codec lookup to recover untagged types
862
+ baseTagSet = tag.TagSet(tagSet.baseTag, tagSet.baseTag)
863
+
864
+ try:
865
+ concreteEncoder = self._tagMap[baseTagSet]
866
+
867
+ except KeyError:
868
+ raise error.PyAsn1Error('No encoder for %r (%s)' % (value, tagSet))
869
+
870
+ if LOG:
871
+ LOG('using value codec %s chosen by tagSet '
872
+ '%s' % (concreteEncoder.__class__.__name__, tagSet))
873
+
874
+ substrate = concreteEncoder.encode(value, asn1Spec, self, **options)
875
+
876
+ if LOG:
877
+ LOG('codec %s built %s octets of substrate: %s\nencoder '
878
+ 'completed' % (concreteEncoder, len(substrate),
879
+ debug.hexdump(substrate)))
880
+
881
+ return substrate
882
+
883
+
884
+ class Encoder(object):
885
+ SINGLE_ITEM_ENCODER = SingleItemEncoder
886
+
887
+ def __init__(self, tagMap=_MISSING, typeMap=_MISSING, **options):
888
+ self._singleItemEncoder = self.SINGLE_ITEM_ENCODER(
889
+ tagMap=tagMap, typeMap=typeMap, **options
890
+ )
891
+
892
+ def __call__(self, pyObject, asn1Spec=None, **options):
893
+ return self._singleItemEncoder(
894
+ pyObject, asn1Spec=asn1Spec, **options)
895
+
896
+
897
+ #: Turns ASN.1 object into BER octet stream.
898
+ #:
899
+ #: Takes any ASN.1 object (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative)
900
+ #: walks all its components recursively and produces a BER octet stream.
901
+ #:
902
+ #: Parameters
903
+ #: ----------
904
+ #: value: either a Python or pyasn1 object (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative)
905
+ #: A Python or pyasn1 object to encode. If Python object is given, `asnSpec`
906
+ #: parameter is required to guide the encoding process.
907
+ #:
908
+ #: Keyword Args
909
+ #: ------------
910
+ #: asn1Spec:
911
+ #: Optional ASN.1 schema or value object e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative
912
+ #:
913
+ #: defMode: :py:class:`bool`
914
+ #: If :obj:`False`, produces indefinite length encoding
915
+ #:
916
+ #: maxChunkSize: :py:class:`int`
917
+ #: Maximum chunk size in chunked encoding mode (0 denotes unlimited chunk size)
918
+ #:
919
+ #: Returns
920
+ #: -------
921
+ #: : :py:class:`bytes`
922
+ #: Given ASN.1 object encoded into BER octetstream
923
+ #:
924
+ #: Raises
925
+ #: ------
926
+ #: ~pyasn1.error.PyAsn1Error
927
+ #: On encoding errors
928
+ #:
929
+ #: Examples
930
+ #: --------
931
+ #: Encode Python value into BER with ASN.1 schema
932
+ #:
933
+ #: .. code-block:: pycon
934
+ #:
935
+ #: >>> seq = SequenceOf(componentType=Integer())
936
+ #: >>> encode([1, 2, 3], asn1Spec=seq)
937
+ #: b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03'
938
+ #:
939
+ #: Encode ASN.1 value object into BER
940
+ #:
941
+ #: .. code-block:: pycon
942
+ #:
943
+ #: >>> seq = SequenceOf(componentType=Integer())
944
+ #: >>> seq.extend([1, 2, 3])
945
+ #: >>> encode(seq)
946
+ #: b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03'
947
+ #:
948
+ encode = Encoder()
949
+
950
+ def __getattr__(attr: str):
951
+ if newAttr := {"tagMap": "TAG_MAP", "typeMap": "TYPE_MAP"}.get(attr):
952
+ warnings.warn(f"{attr} is deprecated. Please use {newAttr} instead.", DeprecationWarning)
953
+ return globals()[newAttr]
954
+ raise AttributeError(attr)
.venv/lib/python3.11/site-packages/pyasn1/codec/ber/eoo.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # This file is part of pyasn1 software.
3
+ #
4
+ # Copyright (c) 2005-2020, Ilya Etingof <etingof@gmail.com>
5
+ # License: https://pyasn1.readthedocs.io/en/latest/license.html
6
+ #
7
+ from pyasn1.type import base
8
+ from pyasn1.type import tag
9
+
10
+ __all__ = ['endOfOctets']
11
+
12
+
13
+ class EndOfOctets(base.SimpleAsn1Type):
14
+ defaultValue = 0
15
+ tagSet = tag.initTagSet(
16
+ tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 0x00)
17
+ )
18
+
19
+ _instance = None
20
+
21
+ def __new__(cls, *args, **kwargs):
22
+ if cls._instance is None:
23
+ cls._instance = object.__new__(cls, *args, **kwargs)
24
+
25
+ return cls._instance
26
+
27
+
28
+ endOfOctets = EndOfOctets()
.venv/lib/python3.11/site-packages/pyasn1/codec/native/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # This file is necessary to make this directory a package.