koichi12 commited on
Commit
dbf954e
·
verified ·
1 Parent(s): 2e7ec00

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/_src/__pycache__/__init__.cpython-311.pyc +0 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/__pycache__/magic_trace.cpython-311.pyc +0 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/delayed_mul_tensor.py +77 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/magic_trace.py +42 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/wrap_type.py +71 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__init__.py +3 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__pycache__/__init__.cpython-311.pyc +0 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__pycache__/_parsing.cpython-311.pyc +0 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__pycache__/rearrange.cpython-311.pyc +0 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/rearrange.py +207 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/calculus/__pycache__/extrapolation.cpython-311.pyc +0 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/calculus/__pycache__/quadrature.cpython-311.pyc +0 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/calculus/approximation.py +246 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/__init__.py +2 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/__pycache__/calculus.cpython-311.pyc +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/calculus.py +531 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/eigen.py +877 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/linalg.py +790 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/matrices.py +1005 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/__init__.py +0 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/__pycache__/__init__.cpython-311.pyc +0 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/__pycache__/__init__.cpython-311.pyc +0 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasLt.h +1853 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_api.h +0 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_v2.h +273 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/__init__.py +0 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/__init__.py +0 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/include/cupti_pcsampling.h +923 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/__pycache__/__init__.cpython-311.pyc +0 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn_infer_v8.h +571 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_version.h +70 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/__pycache__/__init__.cpython-311.pyc +0 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/include/__init__.py +0 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/include/cudalibxt.h +97 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/include/cufftXt.h +269 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/__init__.py +0 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/nccl/__init__.py +0 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/nccl/include/__init__.py +0 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia_nvtx_cu11-11.8.86.dist-info/METADATA +35 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia_nvtx_cu11-11.8.86.dist-info/WHEEL +5 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/__pycache__/_elffile.cpython-311.pyc +0 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/_parser.py +354 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/markers.py +331 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/metadata.py +863 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/py.typed +0 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/specifiers.py +1020 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/tags.py +617 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/version.py +582 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/cachecontrol/__init__.py +28 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/cachecontrol/__pycache__/__init__.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/_src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (215 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/__pycache__/magic_trace.cpython-311.pyc ADDED
Binary file (2.48 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/delayed_mul_tensor.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import torch
7
+
8
+ from . import _Tensor, Tensor
9
+ from .reference import _dims, _enable_layers, llist, ltuple
10
+
11
+
12
class DelayedMulTensor(_Tensor):
    """Lazy elementwise product of two first-class-dims tensors.

    The two operands are stored unmultiplied; the merged level list, the
    batched product and its positional form are each computed on first access
    and cached.  ``sum`` fuses the multiply with the reduction into a single
    ``torch.einsum`` call, so on that path the full product is never
    materialized.
    """

    def __init__(self, lhs, rhs):
        self._lhs, self._rhs = lhs, rhs
        # Lazily-filled caches; populated by the properties below.
        self._data = None
        self._levels_data = None
        self._has_device = lhs._has_device or rhs._has_device
        self._batchtensor_data = None
        self._tensor_data = None

    @property
    def _levels(self):
        # Union of the operands' levels: lhs levels in order, then any levels
        # that appear only on the rhs.
        if self._levels_data is None:
            levels = llist(self._lhs._levels)
            for l in self._rhs._levels:
                if l not in levels:
                    levels.append(l)
            self._levels_data = ltuple(levels)
        return self._levels_data

    @property
    def _batchtensor(self):
        # Fallback path that actually materializes the elementwise product.
        if self._batchtensor_data is None:
            with _enable_layers(self._levels):
                print("bt multiply fallback")
                self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor
        return self._batchtensor_data

    @property
    def _tensor(self):
        # Positional view of the (materialized) product.
        if self._tensor_data is None:
            self._tensor_data = Tensor.from_batched(
                self._batchtensor, self._has_device
            )._tensor
        return self._tensor_data

    @property
    def ndim(self):
        return self._batchtensor.ndim

    @property
    def dims(self):
        return ltuple(super().dims)

    def sum(self, dim):
        """Reduce over ``dim``, fusing the pending multiply into one einsum."""
        dims = _dims(dim, 0, False, False)
        n = ord("a")
        all_levels = self._levels

        def to_char(d):
            # One letter per level for the einsum subscript string; shared
            # levels get the same letter on both operands.
            return chr(n + all_levels.index(d))

        plhs, levelslhs = self._lhs._tensor, self._lhs._levels
        prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
        # NOTE(review): new_dims is computed but never used below.
        new_dims = tuple(d for d in self.dims if d not in dims)
        new_levels = [l for l in self._levels if l not in dims]
        # Output subscripts list only the surviving levels, so einsum sums
        # the reduced ones out of the product.
        fmt = "".join(
            [
                *(to_char(d) for d in levelslhs),
                ",",
                *(to_char(d) for d in levelsrhs),
                "->",
                *(to_char(d) for d in new_levels),
            ]
        )
        result_data = torch.einsum(fmt, (plhs, prhs))
        return Tensor.from_positional(result_data, new_levels, True)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/magic_trace.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import os
7
+ import signal
8
+ import subprocess
9
+ from contextlib import contextmanager
10
+
11
+
12
@contextmanager
def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
    """Profile the current process with Jane Street's magic-trace.

    Downloads the magic-trace binary to ``magic_trace_cache`` if it is not
    already present, attaches it to this process, yields to the caller, and
    on exit stops the tracer (SIGINT) so it writes the trace to ``output``.

    Args:
        output (str): path where magic-trace writes the .fxt trace file
        magic_trace_cache (str): path of the cached magic-trace binary

    Raises:
        subprocess.CalledProcessError: if downloading or chmod-ing the
            binary fails.
        ValueError: if magic-trace fails to attach or exits abnormally.
    """
    pid = os.getpid()
    if not os.path.exists(magic_trace_cache):
        print(f"Downloading magic_trace to: {magic_trace_cache}")
        # check=True: previously a failed download was silently ignored and
        # the Popen below failed (or attached a truncated binary) confusingly.
        subprocess.run(
            [
                "wget",
                "-O",
                magic_trace_cache,
                "-q",
                "https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace",
            ],
            check=True,
        )
        subprocess.run(["chmod", "+x", magic_trace_cache], check=True)
    args = [magic_trace_cache, "attach", "-pid", str(pid), "-o", output]
    p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding="utf-8")
    while True:
        x = p.stderr.readline()
        print(x)
        if "Attached" in x:
            break
        if not x and p.poll() is not None:
            # magic-trace died before attaching; without this check the
            # readline() loop would spin forever on EOF.
            raise ValueError(f"magic_trace failed to attach, exit code: {p.returncode}")
    try:
        yield
    finally:
        p.send_signal(signal.SIGINT)
        r = p.wait()
        print(p.stderr.read())
        p.stderr.close()
        if r != 0:
            raise ValueError(f"magic_trace exited abnormally: {r}")
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/wrap_type.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from types import (
8
+ BuiltinMethodType,
9
+ FunctionType,
10
+ GetSetDescriptorType,
11
+ MethodDescriptorType,
12
+ WrapperDescriptorType,
13
+ )
14
+
15
+ from functorch._C import dim as _C
16
+
17
# C implementation of method wrapping; wrap_type() falls back to
# _py_wrap_method when called with use_c=False.
_wrap_method = _C._wrap_method

# Descriptor types that wrap_type() treats as callable methods when copying a
# pattern class's attributes onto a target type.
FUNC_TYPES = (
    FunctionType,
    MethodDescriptorType,
    BuiltinMethodType,
    WrapperDescriptorType,
)
# Descriptor types that wrap_type() re-exposes as properties.
PROPERTY_TYPES = (GetSetDescriptorType, property)
26
+
27
+
28
+ def _py_wrap_method(orig, __torch_function__):
29
+ def impl(*args, **kwargs):
30
+ return __torch_function__(orig, None, args, kwargs)
31
+
32
+ return impl
33
+
34
+
35
def wrap_type(use_c, to_patch, pattern, __torch_function__):
    """Patch ``to_patch`` so that methods and properties collected from
    ``pattern``'s MRO dispatch through ``__torch_function__``.

    Args:
        use_c: if truthy, wrap with the C implementation (_C._wrap_method),
            otherwise with the pure-Python _py_wrap_method.
        to_patch: the class receiving the wrapped attributes.
        pattern: the class whose MRO supplies the attributes to wrap.
        __torch_function__: handler invoked as (orig, None, args, kwargs).
    """
    if use_c:
        wrap_method = _wrap_method
    else:
        wrap_method = _py_wrap_method

    # Collect attributes base-first so more-derived definitions win.
    all = {}
    for t in reversed(pattern.mro()[:-1]):  # skip object
        all.update(t.__dict__)

    def wrap_attr(orig):
        # A property is wrapped via its getter descriptor.
        return property(wrap_method(orig.__get__, __torch_function__))

    for name, obj in all.items():
        # Structural dunders that must not be redirected.
        if name in (
            "__dict__",
            "__new__",
            "__init__",
            "__repr__",
            "__weakref__",
            "__doc__",
            "__module__",
            "__dir__",
        ):
            continue

        # skip things that have been overloaded
        # things that come from object like `__eq__` still need to be patched, however.
        if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr(
            object, name, None
        ):
            continue

        if isinstance(obj, FUNC_TYPES):
            setattr(to_patch, name, wrap_method(obj, __torch_function__))
        elif isinstance(obj, PROPERTY_TYPES):
            setattr(to_patch, name, wrap_attr(obj))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
# einops-style API built on first-class dimensions; only `rearrange` is part
# of the public interface.
from .rearrange import rearrange

__all__ = ["rearrange"]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (288 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__pycache__/_parsing.cpython-311.pyc ADDED
Binary file (14.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__pycache__/rearrange.cpython-311.pyc ADDED
Binary file (10.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/rearrange.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ from typing import Callable, Dict, List, Sequence, Tuple, Union
5
+
6
+ import torch
7
+
8
+ from functorch._C import dim as _C
9
+ from ._parsing import (
10
+ _ellipsis,
11
+ AnonymousAxis,
12
+ comma_separate,
13
+ parse_pattern,
14
+ validate_rearrange_expressions,
15
+ )
16
+
17
+ __all__ = ["rearrange"]
18
+
19
+ dims = _C.dims
20
+
21
+
22
@functools.lru_cache(256)
def _create_rearrange_callable(
    tensor_ndim: int, pattern: str, **axes_lengths: int
) -> Callable[[torch.Tensor], torch.Tensor]:
    r"""Translate an `einops`-style pattern into a callable that performs the rearrange using first-class dimensions.

    Since an equivalent result is computed for tensors with the same number of dimensions, with the same pattern and
    specified axes lengths, this function can be memoized.

    Args:
        tensor_ndim (int): the number of dimensions in the tensor to rearrange
        pattern (str): the `einops`-style rearrangement pattern
        axes_lengths (int): any additional length specifications for dimensions

    Returns:
        Callable[[torch.Tensor], torch.Tensor]: a callable that performs the rearrangement
    """
    left, right = parse_pattern(pattern, axes_lengths)
    validate_rearrange_expressions(left, right, axes_lengths)

    # Unitary anonymous axes ("1") parse to empty (falsy) sub-lists.
    n_anon_dims = sum(not dim for dim in left.composition)
    if left.has_ellipsis:
        # The ellipsis absorbs every tensor dim the pattern does not name.
        n_ellipsis_dims = tensor_ndim - (len(left.composition) - 1)
        n_named_dims = len(left.identifiers) - 1

        if (pattern_ndim := n_anon_dims + n_named_dims) > tensor_ndim:
            raise ValueError(
                f"Number of dimensions in pattern ({pattern_ndim}) must be less than or equal to the number of "
                f"dimensions in the tensor ({tensor_ndim})"
            )
    else:
        n_ellipsis_dims = 0
        n_named_dims = len(left.identifiers)

        if (pattern_ndim := len(left.composition)) != tensor_ndim:
            raise ValueError(
                f"Number of dimensions in pattern ({pattern_ndim}) must be equal to the number of dimensions in "
                f"the tensor ({tensor_ndim})"
            )
    n_dims = n_named_dims + n_ellipsis_dims + n_anon_dims

    if n_dims == 0:
        # an identity rearrangement on a 0-dimension tensor
        return lambda tensor: tensor

    # One synthetic first-class dim name ("d0", "d1", ...) per tensor dim.
    first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
    identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
    anon_axes: List[AnonymousAxis] = []

    # map the left-hand side identifiers to strings representing first class dims
    dims_i = 0
    for dimension in left.composition:
        if isinstance(dimension, list):
            for identifier in dimension:
                # non-unitary anon axes are not allowed in rearrange & unitary anon axes are represented as empty lists
                assert isinstance(identifier, str)
                identifier_dim_map[identifier] = (first_class_dims[dims_i],)
                dims_i += 1
            if not dimension:
                # unitary anonymous axis
                anon_axis = AnonymousAxis("1")
                identifier_dim_map[anon_axis] = (first_class_dims[dims_i],)
                anon_axes.append(anon_axis)
                dimension.append(anon_axis)
                dims_i += 1
        elif dimension == _ellipsis:
            identifier = _ellipsis
            identifier_dim_map[identifier] = tuple(
                first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)
            )
            dims_i += n_ellipsis_dims
        else:
            raise ValueError(f"Unexpected dimension: {dimension}")

    def composition_to_dims(
        composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]]
    ) -> List[Union[str, Tuple[str, ...]]]:
        """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
        class dims."""
        dim_composition: List[Union[str, Tuple[str, ...]]] = []
        for dimension in composition:
            if isinstance(dimension, list):
                dim_composition.append(
                    tuple(
                        dim
                        for identifier in dimension
                        for dim in identifier_dim_map[identifier]
                    )
                )
            elif dimension == _ellipsis:
                dim_composition.extend(identifier_dim_map[_ellipsis])
            else:
                raise ValueError(f"Unexpected dimension: {dimension}")
        return dim_composition

    left_dims = composition_to_dims(left.composition)
    right_dims = composition_to_dims(right.composition)
    anon_dims = tuple(identifier_dim_map[axis][0] for axis in anon_axes)
    specified_lengths = tuple(
        (identifier_dim_map[axis][0], length) for axis, length in axes_lengths.items()
    )

    # Generate and exec a small function source.  NOTE(review): the generated
    # text contains only the synthetic names d0..dN and the integer lengths
    # from axes_lengths, so no user-controlled strings reach exec().
    custom_rearrange_callable_name = "do_rearrange"
    custom_rearrange_callable_code = (
        (
            f"def {custom_rearrange_callable_name}(tensor):\n"
            f"    {comma_separate(first_class_dims)} = dims({n_dims})\n"
        )
        + (
            "".join(
                f"    {dim}.size = {length}\n" for (dim, length) in specified_lengths
            )
            if specified_lengths
            else ""
        )
        + f"    tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
        + (
            f"    return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
            if anon_dims
            else "    return tensor\n"
        )
    )

    exec(custom_rearrange_callable_code)
    return locals()[custom_rearrange_callable_name]
147
+
148
+
149
def rearrange(
    tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
    pattern: str,
    **axes_lengths: int,
) -> torch.Tensor:
    r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional
    tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
    stack, concatenate and other operations.

    See: https://einops.rocks/api/rearrange/

    Args:
        tensor (Tensor or sequence of Tensor): the tensor(s) to rearrange
        pattern (str): the rearrangement pattern
        axes_lengths (int): any additional length specifications for dimensions

    Returns:
        Tensor: the rearranged tensor

    Examples:
        >>> # suppose we have a set of 32 images in "h w c" format (height-width-channel)
        >>> images = torch.randn((32, 30, 40, 3))

        >>> # stack along first (batch) axis, output is a single array
        >>> rearrange(images, 'b h w c -> b h w c').shape
        torch.Size([32, 30, 40, 3])

        >>> # concatenate images along height (vertical axis), 960 = 32 * 30
        >>> rearrange(images, 'b h w c -> (b h) w c').shape
        torch.Size([960, 40, 3])

        >>> # concatenated images along horizontal axis, 1280 = 32 * 40
        >>> rearrange(images, 'b h w c -> h (b w) c').shape
        torch.Size([30, 1280, 3])

        >>> # reordered axes to "b c h w" format for deep learning
        >>> rearrange(images, 'b h w c -> b c h w').shape
        torch.Size([32, 3, 30, 40])

        >>> # flattened each image into a vector, 3600 = 30 * 40 * 3
        >>> rearrange(images, 'b h w c -> b (c h w)').shape
        torch.Size([32, 3600])

        >>> # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
        >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
        torch.Size([128, 15, 20, 3])

        >>> # space-to-depth operation
        >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
        torch.Size([32, 15, 20, 12])
    """
    if not isinstance(tensor, torch.Tensor):
        # A sequence of tensors is stacked along a new leading dim first.
        tensor = torch.stack(tensor)

    # Memoized on (ndim, pattern, axes_lengths); see _create_rearrange_callable.
    rearrange_callable = _create_rearrange_callable(
        tensor.ndim, pattern, **axes_lengths
    )

    return rearrange_callable(tensor)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/calculus/__pycache__/extrapolation.cpython-311.pyc ADDED
Binary file (89.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/calculus/__pycache__/quadrature.cpython-311.pyc ADDED
Binary file (50.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/calculus/approximation.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..libmp.backend import xrange
2
+ from .calculus import defun
3
+
4
+ #----------------------------------------------------------------------------#
5
+ # Approximation methods #
6
+ #----------------------------------------------------------------------------#
7
+
8
+ # The Chebyshev approximation formula is given at:
9
+ # http://mathworld.wolfram.com/ChebyshevApproximationFormula.html
10
+
11
+ # The only major changes in the following code is that we return the
12
+ # expanded polynomial coefficients instead of Chebyshev coefficients,
13
+ # and that we automatically transform [a,b] -> [-1,1] and back
14
+ # for convenience.
15
+
16
+ # Coefficient in Chebyshev approximation
17
def chebcoeff(ctx,f,a,b,j,N):
    """Return the j-th coefficient of the N-point Chebyshev expansion of f
    on [a, b], sampling f at the Chebyshev nodes mapped onto [a, b]."""
    half = ctx.mpf(0.5)
    acc = ctx.mpf(0)
    for k in range(1, N+1):
        # k-th Chebyshev node on [-1, 1], then mapped affinely into [a, b].
        node = ctx.cospi((k-half)/N)
        acc += f(node*(b-a)*half + (b+a)*half) * ctx.cospi(j*(k-half)/N)
    return 2*acc/N
24
+
25
+ # Generate Chebyshev polynomials T_n(ax+b) in expanded form
26
def chebT(ctx, a=1, b=0):
    """Yield expanded coefficient lists (lowest degree first) of the
    Chebyshev polynomials T_n(a*x + b), for n = 0, 1, 2, ...

    The ``ctx`` argument is accepted for interface uniformity but unused.
    """
    prev = [1]          # T_0 = 1
    yield prev
    cur = [b, a]        # T_1(a*x + b) = a*x + b
    while True:
        yield cur
        # Recurrence: T[n+1](ax+b) = 2*(ax+b)*T[n](ax+b) - T[n-1](ax+b)
        # 2*a*x * cur: multiply by x shifts coefficients up one degree.
        nxt = [0] + [2*a*coef for coef in cur]
        for idx, coef in enumerate(cur):
            nxt[idx] += 2*b*coef
        for idx, coef in enumerate(prev):
            nxt[idx] -= coef
        cur, prev = nxt, cur
37
+
38
@defun
def chebyfit(ctx, f, interval, N, error=False):
    r"""
    Computes a polynomial of degree `N-1` that approximates the
    given function `f` on the interval `[a, b]`. With ``error=True``,
    :func:`~mpmath.chebyfit` also returns an accurate estimate of the
    maximum absolute error; that is, the maximum value of
    `|f(x) - P(x)|` for `x \in [a, b]`.

    :func:`~mpmath.chebyfit` uses the Chebyshev approximation formula,
    which gives a nearly optimal solution: that is, the maximum
    error of the approximating polynomial is very close to
    the smallest possible for any polynomial of the same degree.

    Chebyshev approximation is very useful if one needs repeated
    evaluation of an expensive function, such as function defined
    implicitly by an integral or a differential equation. (For
    example, it could be used to turn a slow mpmath function
    into a fast machine-precision version of the same.)

    **Examples**

    Here we use :func:`~mpmath.chebyfit` to generate a low-degree approximation
    of `f(x) = \cos(x)`, valid on the interval `[1, 2]`::

        >>> from mpmath import *
        >>> mp.dps = 15; mp.pretty = True
        >>> poly, err = chebyfit(cos, [1, 2], 5, error=True)
        >>> nprint(poly)
        [0.00291682, 0.146166, -0.732491, 0.174141, 0.949553]
        >>> nprint(err, 12)
        1.61351758081e-5

    The polynomial can be evaluated using ``polyval``::

        >>> nprint(polyval(poly, 1.6), 12)
        -0.0291858904138
        >>> nprint(cos(1.6), 12)
        -0.0291995223013

    Sampling the true error at 1000 points shows that the error
    estimate generated by ``chebyfit`` is remarkably good::

        >>> error = lambda x: abs(cos(x) - polyval(poly, x))
        >>> nprint(max([error(1+n/1000.) for n in range(1000)]), 12)
        1.61349954245e-5

    **Choice of degree**

    The degree `N` can be set arbitrarily high, to obtain an
    arbitrarily good approximation. As a rule of thumb, an
    `N`-term Chebyshev approximation is good to `N/(b-a)` decimal
    places on a unit interval (although this depends on how
    well-behaved `f` is). The cost grows accordingly: ``chebyfit``
    evaluates the function `(N^2)/2` times to compute the
    coefficients and an additional `N` times to estimate the error.

    **Possible issues**

    One should be careful to use a sufficiently high working
    precision both when calling ``chebyfit`` and when evaluating
    the resulting polynomial, as the polynomial is sometimes
    ill-conditioned. It is for example difficult to reach
    15-digit accuracy when evaluating the polynomial using
    machine precision floats, no matter the theoretical
    accuracy of the polynomial. (The option to return the
    coefficients in Chebyshev form should be made available
    in the future.)

    It is important to note the Chebyshev approximation works
    poorly if `f` is not smooth. A function containing singularities,
    rapid oscillation, etc can be approximated more effectively by
    multiplying it by a weight function that cancels out the
    nonsmooth features, or by dividing the interval into several
    segments.
    """
    a, b = ctx._as_points(interval)
    orig = ctx.prec
    try:
        # Work at elevated precision; the change of basis below can be
        # ill-conditioned.
        ctx.prec = orig + int(N**0.5) + 20
        c = [chebcoeff(ctx,f,a,b,k,N) for k in range(N)]
        d = [ctx.zero] * N
        # Series is -c_0/2 + sum c_k T_k; the k=0 loop term adds c_0 back,
        # leaving the conventional c_0/2 constant term.
        d[0] = -c[0]/2
        h = ctx.mpf(0.5)
        # Chebyshev polynomials with the affine map from [a, b] to [-1, 1]
        # already substituted in, so d ends up in plain power-basis form.
        T = chebT(ctx, ctx.mpf(2)/(b-a), ctx.mpf(-1)*(b+a)/(b-a))
        for (k, Tk) in zip(range(N), T):
            for i in range(len(Tk)):
                d[i] += c[k]*Tk[i]
        # polyval expects the highest-degree coefficient first.
        d = d[::-1]
        # Estimate maximum error
        err = ctx.zero
        for k in range(N):
            x = ctx.cos(ctx.pi*k/N) * (b-a)*h + (b+a)*h
            err = max(err, abs(f(x) - ctx.polyval(d, x)))
    finally:
        # Restore caller precision even if f() raises.
        ctx.prec = orig
    if error:
        return d, +err
    else:
        return d
138
+
139
@defun
def fourier(ctx, f, interval, N):
    r"""
    Computes the Fourier series of degree `N` of the given function
    on the interval `[a, b]`. More precisely, :func:`~mpmath.fourier` returns
    two lists `(c, s)` of coefficients (the cosine series and sine
    series, respectively), such that

    .. math ::

        f(x) \sim \sum_{k=0}^N
            c_k \cos(k m x) + s_k \sin(k m x)

    where `m = 2 \pi / (b-a)`.

    Note that many texts define the first coefficient as `2 c_0` instead
    of `c_0`. The easiest way to evaluate the computed series correctly
    is to pass it to :func:`~mpmath.fourierval`.

    **Examples**

    The function `f(x) = x` has a simple Fourier series on the standard
    interval `[-\pi, \pi]`. The cosine coefficients are all zero (because
    the function has odd symmetry), and the sine coefficients are
    rational numbers::

        >>> from mpmath import *
        >>> mp.dps = 15; mp.pretty = True
        >>> c, s = fourier(lambda x: x, [-pi, pi], 5)
        >>> nprint(c)
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        >>> nprint(s)
        [0.0, 2.0, -1.0, 0.666667, -0.5, 0.4]

    This computes a Fourier series of a nonsymmetric function on
    a nonstandard interval::

        >>> I = [-1, 1.5]
        >>> f = lambda x: x**2 - 4*x + 1
        >>> cs = fourier(f, I, 4)
        >>> nprint(cs[0])
        [0.583333, 1.12479, -1.27552, 0.904708, -0.441296]
        >>> nprint(cs[1])
        [0.0, -2.6255, 0.580905, 0.219974, -0.540057]

    It is instructive to plot a function along with its truncated
    Fourier series::

        >>> plot([f, lambda x: fourierval(cs, I, x)], I) #doctest: +SKIP

    Fourier series generally converge slowly (and may not converge
    pointwise). For example, if `f(x) = \cosh(x)`, a 10-term Fourier
    series gives an `L^2` error corresponding to 2-digit accuracy::

        >>> I = [-1, 1]
        >>> cs = fourier(cosh, I, 9)
        >>> g = lambda x: (cosh(x) - fourierval(cs, I, x))**2
        >>> nprint(sqrt(quad(g, I)))
        0.00467963

    :func:`~mpmath.fourier` uses numerical quadrature. For nonsmooth functions,
    the accuracy (and speed) can be improved by including all singular
    points in the interval specification::

        >>> nprint(fourier(abs, [-1, 1], 0), 10)
        ([0.5000441648], [0.0])
        >>> nprint(fourier(abs, [-1, 0, 1], 0), 10)
        ([0.5], [0.0])

    """
    interval = ctx._as_points(interval)
    a = interval[0]
    b = interval[-1]
    L = b-a
    cos_series = []
    sin_series = []
    cutoff = ctx.eps*10
    for n in xrange(N+1):
        m = 2*n*ctx.pi/L
        # Each coefficient is a Gauss-Legendre quadrature over the interval
        # (interior points of `interval` become quadrature subintervals).
        an = 2*ctx.quadgl(lambda t: f(t)*ctx.cos(m*t), interval)/L
        bn = 2*ctx.quadgl(lambda t: f(t)*ctx.sin(m*t), interval)/L
        if n == 0:
            an /= 2
        # Snap tiny coefficients (quadrature noise) to exact zero.
        if abs(an) < cutoff: an = ctx.zero
        if abs(bn) < cutoff: bn = ctx.zero
        cos_series.append(an)
        sin_series.append(bn)
    return cos_series, sin_series
227
+
228
@defun
def fourierval(ctx, series, interval, x):
    """
    Evaluates a Fourier series (in the format computed
    by :func:`~mpmath.fourier` for the given interval) at the point `x`.

    The series should be a pair `(c, s)` where `c` is the
    cosine series and `s` is the sine series. The two lists
    need not have the same length.
    """
    cs, ss = series
    ab = ctx._as_points(interval)
    # Fundamental angular frequency 2*pi/(b - a); only the interval's
    # endpoints matter here.  (Removed two unused locals that re-read the
    # raw `interval` argument instead of the normalized `ab`.)
    m = 2*ctx.pi/(ab[-1]-ab[0])
    s = ctx.zero
    # Skip exact-zero coefficients to avoid pointless trig evaluations.
    s += ctx.fsum(cs[n]*ctx.cos(m*n*x) for n in xrange(len(cs)) if cs[n])
    s += ctx.fsum(ss[n]*ctx.sin(m*n*x) for n in xrange(len(ss)) if ss[n])
    return s
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from . import eigen # to set methods
2
+ from . import eigen_symmetric # to set methods
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/__pycache__/calculus.cpython-311.pyc ADDED
Binary file (22.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/calculus.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..libmp.backend import xrange
2
+
3
+ # TODO: should use diagonalization-based algorithms
4
+
5
+ class MatrixCalculusMethods(object):
6
+
7
+ def _exp_pade(ctx, a):
8
+ """
9
+ Exponential of a matrix using Pade approximants.
10
+
11
+ See G. H. Golub, C. F. van Loan 'Matrix Computations',
12
+ third Ed., page 572
13
+
14
+ TODO:
15
+ - find a good estimate for q
16
+ - reduce the number of matrix multiplications to improve
17
+ performance
18
+ """
19
+ def eps_pade(p):
20
+ return ctx.mpf(2)**(3-2*p) * \
21
+ ctx.factorial(p)**2/(ctx.factorial(2*p)**2 * (2*p + 1))
22
+ q = 4
23
+ extraq = 8
24
+ while 1:
25
+ if eps_pade(q) < ctx.eps:
26
+ break
27
+ q += 1
28
+ q += extraq
29
+ j = int(max(1, ctx.mag(ctx.mnorm(a,'inf'))))
30
+ extra = q
31
+ prec = ctx.prec
32
+ ctx.dps += extra + 3
33
+ try:
34
+ a = a/2**j
35
+ na = a.rows
36
+ den = ctx.eye(na)
37
+ num = ctx.eye(na)
38
+ x = ctx.eye(na)
39
+ c = ctx.mpf(1)
40
+ for k in range(1, q+1):
41
+ c *= ctx.mpf(q - k + 1)/((2*q - k + 1) * k)
42
+ x = a*x
43
+ cx = c*x
44
+ num += cx
45
+ den += (-1)**k * cx
46
+ f = ctx.lu_solve_mat(den, num)
47
+ for k in range(j):
48
+ f = f*f
49
+ finally:
50
+ ctx.prec = prec
51
+ return f*1
52
+
53
+ def expm(ctx, A, method='taylor'):
54
+ r"""
55
+ Computes the matrix exponential of a square matrix `A`, which is defined
56
+ by the power series
57
+
58
+ .. math ::
59
+
60
+ \exp(A) = I + A + \frac{A^2}{2!} + \frac{A^3}{3!} + \ldots
61
+
62
+ With method='taylor', the matrix exponential is computed
63
+ using the Taylor series. With method='pade', Pade approximants
64
+ are used instead.
65
+
66
+ **Examples**
67
+
68
+ Basic examples::
69
+
70
+ >>> from mpmath import *
71
+ >>> mp.dps = 15; mp.pretty = True
72
+ >>> expm(zeros(3))
73
+ [1.0 0.0 0.0]
74
+ [0.0 1.0 0.0]
75
+ [0.0 0.0 1.0]
76
+ >>> expm(eye(3))
77
+ [2.71828182845905 0.0 0.0]
78
+ [ 0.0 2.71828182845905 0.0]
79
+ [ 0.0 0.0 2.71828182845905]
80
+ >>> expm([[1,1,0],[1,0,1],[0,1,0]])
81
+ [ 3.86814500615414 2.26812870852145 0.841130841230196]
82
+ [ 2.26812870852145 2.44114713886289 1.42699786729125]
83
+ [0.841130841230196 1.42699786729125 1.6000162976327]
84
+ >>> expm([[1,1,0],[1,0,1],[0,1,0]], method='pade')
85
+ [ 3.86814500615414 2.26812870852145 0.841130841230196]
86
+ [ 2.26812870852145 2.44114713886289 1.42699786729125]
87
+ [0.841130841230196 1.42699786729125 1.6000162976327]
88
+ >>> expm([[1+j, 0], [1+j,1]])
89
+ [(1.46869393991589 + 2.28735528717884j) 0.0]
90
+ [ (1.03776739863568 + 3.536943175722j) (2.71828182845905 + 0.0j)]
91
+
92
+ Matrices with large entries are allowed::
93
+
94
+ >>> expm(matrix([[1,2],[2,3]])**25)
95
+ [5.65024064048415e+2050488462815550 9.14228140091932e+2050488462815550]
96
+ [9.14228140091932e+2050488462815550 1.47925220414035e+2050488462815551]
97
+
98
+ The identity `\exp(A+B) = \exp(A) \exp(B)` does not hold for
99
+ noncommuting matrices::
100
+
101
+ >>> A = hilbert(3)
102
+ >>> B = A + eye(3)
103
+ >>> chop(mnorm(A*B - B*A))
104
+ 0.0
105
+ >>> chop(mnorm(expm(A+B) - expm(A)*expm(B)))
106
+ 0.0
107
+ >>> B = A + ones(3)
108
+ >>> mnorm(A*B - B*A)
109
+ 1.8
110
+ >>> mnorm(expm(A+B) - expm(A)*expm(B))
111
+ 42.0927851137247
112
+
113
+ """
114
+ if method == 'pade':
115
+ prec = ctx.prec
116
+ try:
117
+ A = ctx.matrix(A)
118
+ ctx.prec += 2*A.rows
119
+ res = ctx._exp_pade(A)
120
+ finally:
121
+ ctx.prec = prec
122
+ return res
123
+ A = ctx.matrix(A)
124
+ prec = ctx.prec
125
+ j = int(max(1, ctx.mag(ctx.mnorm(A,'inf'))))
126
+ j += int(0.5*prec**0.5)
127
+ try:
128
+ ctx.prec += 10 + 2*j
129
+ tol = +ctx.eps
130
+ A = A/2**j
131
+ T = A
132
+ Y = A**0 + A
133
+ k = 2
134
+ while 1:
135
+ T *= A * (1/ctx.mpf(k))
136
+ if ctx.mnorm(T, 'inf') < tol:
137
+ break
138
+ Y += T
139
+ k += 1
140
+ for k in xrange(j):
141
+ Y = Y*Y
142
+ finally:
143
+ ctx.prec = prec
144
+ Y *= 1
145
+ return Y
146
+
147
+ def cosm(ctx, A):
148
+ r"""
149
+ Gives the cosine of a square matrix `A`, defined in analogy
150
+ with the matrix exponential.
151
+
152
+ Examples::
153
+
154
+ >>> from mpmath import *
155
+ >>> mp.dps = 15; mp.pretty = True
156
+ >>> X = eye(3)
157
+ >>> cosm(X)
158
+ [0.54030230586814 0.0 0.0]
159
+ [ 0.0 0.54030230586814 0.0]
160
+ [ 0.0 0.0 0.54030230586814]
161
+ >>> X = hilbert(3)
162
+ >>> cosm(X)
163
+ [ 0.424403834569555 -0.316643413047167 -0.221474945949293]
164
+ [-0.316643413047167 0.820646708837824 -0.127183694770039]
165
+ [-0.221474945949293 -0.127183694770039 0.909236687217541]
166
+ >>> X = matrix([[1+j,-2],[0,-j]])
167
+ >>> cosm(X)
168
+ [(0.833730025131149 - 0.988897705762865j) (1.07485840848393 - 0.17192140544213j)]
169
+ [ 0.0 (1.54308063481524 + 0.0j)]
170
+ """
171
+ B = 0.5 * (ctx.expm(A*ctx.j) + ctx.expm(A*(-ctx.j)))
172
+ if not sum(A.apply(ctx.im).apply(abs)):
173
+ B = B.apply(ctx.re)
174
+ return B
175
+
176
+ def sinm(ctx, A):
177
+ r"""
178
+ Gives the sine of a square matrix `A`, defined in analogy
179
+ with the matrix exponential.
180
+
181
+ Examples::
182
+
183
+ >>> from mpmath import *
184
+ >>> mp.dps = 15; mp.pretty = True
185
+ >>> X = eye(3)
186
+ >>> sinm(X)
187
+ [0.841470984807897 0.0 0.0]
188
+ [ 0.0 0.841470984807897 0.0]
189
+ [ 0.0 0.0 0.841470984807897]
190
+ >>> X = hilbert(3)
191
+ >>> sinm(X)
192
+ [0.711608512150994 0.339783913247439 0.220742837314741]
193
+ [0.339783913247439 0.244113865695532 0.187231271174372]
194
+ [0.220742837314741 0.187231271174372 0.155816730769635]
195
+ >>> X = matrix([[1+j,-2],[0,-j]])
196
+ >>> sinm(X)
197
+ [(1.29845758141598 + 0.634963914784736j) (-1.96751511930922 + 0.314700021761367j)]
198
+ [ 0.0 (0.0 - 1.1752011936438j)]
199
+ """
200
+ B = (-0.5j) * (ctx.expm(A*ctx.j) - ctx.expm(A*(-ctx.j)))
201
+ if not sum(A.apply(ctx.im).apply(abs)):
202
+ B = B.apply(ctx.re)
203
+ return B
204
+
205
+ def _sqrtm_rot(ctx, A, _may_rotate):
206
+ # If the iteration fails to converge, cheat by performing
207
+ # a rotation by a complex number
208
+ u = ctx.j**0.3
209
+ return ctx.sqrtm(u*A, _may_rotate) / ctx.sqrt(u)
210
+
211
+ def sqrtm(ctx, A, _may_rotate=2):
212
+ r"""
213
+ Computes a square root of the square matrix `A`, i.e. returns
214
+ a matrix `B = A^{1/2}` such that `B^2 = A`. The square root
215
+ of a matrix, if it exists, is not unique.
216
+
217
+ **Examples**
218
+
219
+ Square roots of some simple matrices::
220
+
221
+ >>> from mpmath import *
222
+ >>> mp.dps = 15; mp.pretty = True
223
+ >>> sqrtm([[1,0], [0,1]])
224
+ [1.0 0.0]
225
+ [0.0 1.0]
226
+ >>> sqrtm([[0,0], [0,0]])
227
+ [0.0 0.0]
228
+ [0.0 0.0]
229
+ >>> sqrtm([[2,0],[0,1]])
230
+ [1.4142135623731 0.0]
231
+ [ 0.0 1.0]
232
+ >>> sqrtm([[1,1],[1,0]])
233
+ [ (0.920442065259926 - 0.21728689675164j) (0.568864481005783 + 0.351577584254143j)]
234
+ [(0.568864481005783 + 0.351577584254143j) (0.351577584254143 - 0.568864481005783j)]
235
+ >>> sqrtm([[1,0],[0,1]])
236
+ [1.0 0.0]
237
+ [0.0 1.0]
238
+ >>> sqrtm([[-1,0],[0,1]])
239
+ [(0.0 - 1.0j) 0.0]
240
+ [ 0.0 (1.0 + 0.0j)]
241
+ >>> sqrtm([[j,0],[0,j]])
242
+ [(0.707106781186547 + 0.707106781186547j) 0.0]
243
+ [ 0.0 (0.707106781186547 + 0.707106781186547j)]
244
+
245
+ A square root of a rotation matrix, giving the corresponding
246
+ half-angle rotation matrix::
247
+
248
+ >>> t1 = 0.75
249
+ >>> t2 = t1 * 0.5
250
+ >>> A1 = matrix([[cos(t1), -sin(t1)], [sin(t1), cos(t1)]])
251
+ >>> A2 = matrix([[cos(t2), -sin(t2)], [sin(t2), cos(t2)]])
252
+ >>> sqrtm(A1)
253
+ [0.930507621912314 -0.366272529086048]
254
+ [0.366272529086048 0.930507621912314]
255
+ >>> A2
256
+ [0.930507621912314 -0.366272529086048]
257
+ [0.366272529086048 0.930507621912314]
258
+
259
+ The identity `(A^2)^{1/2} = A` does not necessarily hold::
260
+
261
+ >>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
262
+ >>> sqrtm(A**2)
263
+ [ 4.0 1.0 4.0]
264
+ [ 7.0 8.0 9.0]
265
+ [10.0 2.0 11.0]
266
+ >>> sqrtm(A)**2
267
+ [ 4.0 1.0 4.0]
268
+ [ 7.0 8.0 9.0]
269
+ [10.0 2.0 11.0]
270
+ >>> A = matrix([[-4,1,4],[7,-8,9],[10,2,11]])
271
+ >>> sqrtm(A**2)
272
+ [ 7.43715112194995 -0.324127569985474 1.8481718827526]
273
+ [-0.251549715716942 9.32699765900402 2.48221180985147]
274
+ [ 4.11609388833616 0.775751877098258 13.017955697342]
275
+ >>> chop(sqrtm(A)**2)
276
+ [-4.0 1.0 4.0]
277
+ [ 7.0 -8.0 9.0]
278
+ [10.0 2.0 11.0]
279
+
280
+ For some matrices, a square root does not exist::
281
+
282
+ >>> sqrtm([[0,1], [0,0]])
283
+ Traceback (most recent call last):
284
+ ...
285
+ ZeroDivisionError: matrix is numerically singular
286
+
287
+ Two examples from the documentation for Matlab's ``sqrtm``::
288
+
289
+ >>> mp.dps = 15; mp.pretty = True
290
+ >>> sqrtm([[7,10],[15,22]])
291
+ [1.56669890360128 1.74077655955698]
292
+ [2.61116483933547 4.17786374293675]
293
+ >>>
294
+ >>> X = matrix(\
295
+ ... [[5,-4,1,0,0],
296
+ ... [-4,6,-4,1,0],
297
+ ... [1,-4,6,-4,1],
298
+ ... [0,1,-4,6,-4],
299
+ ... [0,0,1,-4,5]])
300
+ >>> Y = matrix(\
301
+ ... [[2,-1,-0,-0,-0],
302
+ ... [-1,2,-1,0,-0],
303
+ ... [0,-1,2,-1,0],
304
+ ... [-0,0,-1,2,-1],
305
+ ... [-0,-0,-0,-1,2]])
306
+ >>> mnorm(sqrtm(X) - Y)
307
+ 4.53155328326114e-19
308
+
309
+ """
310
+ A = ctx.matrix(A)
311
+ # Trivial
312
+ if A*0 == A:
313
+ return A
314
+ prec = ctx.prec
315
+ if _may_rotate:
316
+ d = ctx.det(A)
317
+ if abs(ctx.im(d)) < 16*ctx.eps and ctx.re(d) < 0:
318
+ return ctx._sqrtm_rot(A, _may_rotate-1)
319
+ try:
320
+ ctx.prec += 10
321
+ tol = ctx.eps * 128
322
+ Y = A
323
+ Z = I = A**0
324
+ k = 0
325
+ # Denman-Beavers iteration
326
+ while 1:
327
+ Yprev = Y
328
+ try:
329
+ Y, Z = 0.5*(Y+ctx.inverse(Z)), 0.5*(Z+ctx.inverse(Y))
330
+ except ZeroDivisionError:
331
+ if _may_rotate:
332
+ Y = ctx._sqrtm_rot(A, _may_rotate-1)
333
+ break
334
+ else:
335
+ raise
336
+ mag1 = ctx.mnorm(Y-Yprev, 'inf')
337
+ mag2 = ctx.mnorm(Y, 'inf')
338
+ if mag1 <= mag2*tol:
339
+ break
340
+ if _may_rotate and k > 6 and not mag1 < mag2 * 0.001:
341
+ return ctx._sqrtm_rot(A, _may_rotate-1)
342
+ k += 1
343
+ if k > ctx.prec:
344
+ raise ctx.NoConvergence
345
+ finally:
346
+ ctx.prec = prec
347
+ Y *= 1
348
+ return Y
349
+
350
+ def logm(ctx, A):
351
+ r"""
352
+ Computes a logarithm of the square matrix `A`, i.e. returns
353
+ a matrix `B = \log(A)` such that `\exp(B) = A`. The logarithm
354
+ of a matrix, if it exists, is not unique.
355
+
356
+ **Examples**
357
+
358
+ Logarithms of some simple matrices::
359
+
360
+ >>> from mpmath import *
361
+ >>> mp.dps = 15; mp.pretty = True
362
+ >>> X = eye(3)
363
+ >>> logm(X)
364
+ [0.0 0.0 0.0]
365
+ [0.0 0.0 0.0]
366
+ [0.0 0.0 0.0]
367
+ >>> logm(2*X)
368
+ [0.693147180559945 0.0 0.0]
369
+ [ 0.0 0.693147180559945 0.0]
370
+ [ 0.0 0.0 0.693147180559945]
371
+ >>> logm(expm(X))
372
+ [1.0 0.0 0.0]
373
+ [0.0 1.0 0.0]
374
+ [0.0 0.0 1.0]
375
+
376
+ A logarithm of a complex matrix::
377
+
378
+ >>> X = matrix([[2+j, 1, 3], [1-j, 1-2*j, 1], [-4, -5, j]])
379
+ >>> B = logm(X)
380
+ >>> nprint(B)
381
+ [ (0.808757 + 0.107759j) (2.20752 + 0.202762j) (1.07376 - 0.773874j)]
382
+ [ (0.905709 - 0.107795j) (0.0287395 - 0.824993j) (0.111619 + 0.514272j)]
383
+ [(-0.930151 + 0.399512j) (-2.06266 - 0.674397j) (0.791552 + 0.519839j)]
384
+ >>> chop(expm(B))
385
+ [(2.0 + 1.0j) 1.0 3.0]
386
+ [(1.0 - 1.0j) (1.0 - 2.0j) 1.0]
387
+ [ -4.0 -5.0 (0.0 + 1.0j)]
388
+
389
+ A matrix `X` close to the identity matrix, for which
390
+ `\log(\exp(X)) = \exp(\log(X)) = X` holds::
391
+
392
+ >>> X = eye(3) + hilbert(3)/4
393
+ >>> X
394
+ [ 1.25 0.125 0.0833333333333333]
395
+ [ 0.125 1.08333333333333 0.0625]
396
+ [0.0833333333333333 0.0625 1.05]
397
+ >>> logm(expm(X))
398
+ [ 1.25 0.125 0.0833333333333333]
399
+ [ 0.125 1.08333333333333 0.0625]
400
+ [0.0833333333333333 0.0625 1.05]
401
+ >>> expm(logm(X))
402
+ [ 1.25 0.125 0.0833333333333333]
403
+ [ 0.125 1.08333333333333 0.0625]
404
+ [0.0833333333333333 0.0625 1.05]
405
+
406
+ A logarithm of a rotation matrix, giving back the angle of
407
+ the rotation::
408
+
409
+ >>> t = 3.7
410
+ >>> A = matrix([[cos(t),sin(t)],[-sin(t),cos(t)]])
411
+ >>> chop(logm(A))
412
+ [ 0.0 -2.58318530717959]
413
+ [2.58318530717959 0.0]
414
+ >>> (2*pi-t)
415
+ 2.58318530717959
416
+
417
+ For some matrices, a logarithm does not exist::
418
+
419
+ >>> logm([[1,0], [0,0]])
420
+ Traceback (most recent call last):
421
+ ...
422
+ ZeroDivisionError: matrix is numerically singular
423
+
424
+ Logarithm of a matrix with large entries::
425
+
426
+ >>> logm(hilbert(3) * 10**20).apply(re)
427
+ [ 45.5597513593433 1.27721006042799 0.317662687717978]
428
+ [ 1.27721006042799 42.5222778973542 2.24003708791604]
429
+ [0.317662687717978 2.24003708791604 42.395212822267]
430
+
431
+ """
432
+ A = ctx.matrix(A)
433
+ prec = ctx.prec
434
+ try:
435
+ ctx.prec += 10
436
+ tol = ctx.eps * 128
437
+ I = A**0
438
+ B = A
439
+ n = 0
440
+ while 1:
441
+ B = ctx.sqrtm(B)
442
+ n += 1
443
+ if ctx.mnorm(B-I, 'inf') < 0.125:
444
+ break
445
+ T = X = B-I
446
+ L = X*0
447
+ k = 1
448
+ while 1:
449
+ if k & 1:
450
+ L += T / k
451
+ else:
452
+ L -= T / k
453
+ T *= X
454
+ if ctx.mnorm(T, 'inf') < tol:
455
+ break
456
+ k += 1
457
+ if k > ctx.prec:
458
+ raise ctx.NoConvergence
459
+ finally:
460
+ ctx.prec = prec
461
+ L *= 2**n
462
+ return L
463
+
464
+ def powm(ctx, A, r):
465
+ r"""
466
+ Computes `A^r = \exp(A \log r)` for a matrix `A` and complex
467
+ number `r`.
468
+
469
+ **Examples**
470
+
471
+ Powers and inverse powers of a matrix::
472
+
473
+ >>> from mpmath import *
474
+ >>> mp.dps = 15; mp.pretty = True
475
+ >>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
476
+ >>> powm(A, 2)
477
+ [ 63.0 20.0 69.0]
478
+ [174.0 89.0 199.0]
479
+ [164.0 48.0 179.0]
480
+ >>> chop(powm(powm(A, 4), 1/4.))
481
+ [ 4.0 1.0 4.0]
482
+ [ 7.0 8.0 9.0]
483
+ [10.0 2.0 11.0]
484
+ >>> powm(extraprec(20)(powm)(A, -4), -1/4.)
485
+ [ 4.0 1.0 4.0]
486
+ [ 7.0 8.0 9.0]
487
+ [10.0 2.0 11.0]
488
+ >>> chop(powm(powm(A, 1+0.5j), 1/(1+0.5j)))
489
+ [ 4.0 1.0 4.0]
490
+ [ 7.0 8.0 9.0]
491
+ [10.0 2.0 11.0]
492
+ >>> powm(extraprec(5)(powm)(A, -1.5), -1/(1.5))
493
+ [ 4.0 1.0 4.0]
494
+ [ 7.0 8.0 9.0]
495
+ [10.0 2.0 11.0]
496
+
497
+ A Fibonacci-generating matrix::
498
+
499
+ >>> powm([[1,1],[1,0]], 10)
500
+ [89.0 55.0]
501
+ [55.0 34.0]
502
+ >>> fib(10)
503
+ 55.0
504
+ >>> powm([[1,1],[1,0]], 6.5)
505
+ [(16.5166626964253 - 0.0121089837381789j) (10.2078589271083 + 0.0195927472575932j)]
506
+ [(10.2078589271083 + 0.0195927472575932j) (6.30880376931698 - 0.0317017309957721j)]
507
+ >>> (phi**6.5 - (1-phi)**6.5)/sqrt(5)
508
+ (10.2078589271083 - 0.0195927472575932j)
509
+ >>> powm([[1,1],[1,0]], 6.2)
510
+ [ (14.3076953002666 - 0.008222855781077j) (8.81733464837593 + 0.0133048601383712j)]
511
+ [(8.81733464837593 + 0.0133048601383712j) (5.49036065189071 - 0.0215277159194482j)]
512
+ >>> (phi**6.2 - (1-phi)**6.2)/sqrt(5)
513
+ (8.81733464837593 - 0.0133048601383712j)
514
+
515
+ """
516
+ A = ctx.matrix(A)
517
+ r = ctx.convert(r)
518
+ prec = ctx.prec
519
+ try:
520
+ ctx.prec += 10
521
+ if ctx.isint(r):
522
+ v = A ** int(r)
523
+ elif ctx.isint(r*2):
524
+ y = int(r*2)
525
+ v = ctx.sqrtm(A) ** y
526
+ else:
527
+ v = ctx.expm(r*ctx.logm(A))
528
+ finally:
529
+ ctx.prec = prec
530
+ v *= 1
531
+ return v
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/eigen.py ADDED
@@ -0,0 +1,877 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ ##################################################################################################
5
+ # module for the eigenvalue problem
6
+ # Copyright 2013 Timo Hartmann (thartmann15 at gmail.com)
7
+ #
8
+ # todo:
9
+ # - implement balancing
10
+ # - agressive early deflation
11
+ #
12
+ ##################################################################################################
13
+
14
+ """
15
+ The eigenvalue problem
16
+ ----------------------
17
+
18
+ This file contains routines for the eigenvalue problem.
19
+
20
+ high level routines:
21
+
22
+ hessenberg : reduction of a real or complex square matrix to upper Hessenberg form
23
+ schur : reduction of a real or complex square matrix to upper Schur form
24
+ eig : eigenvalues and eigenvectors of a real or complex square matrix
25
+
26
+ low level routines:
27
+
28
+ hessenberg_reduce_0 : reduction of a real or complex square matrix to upper Hessenberg form
29
+ hessenberg_reduce_1 : auxiliary routine to hessenberg_reduce_0
30
+ qr_step : a single implicitly shifted QR step for an upper Hessenberg matrix
31
+ hessenberg_qr : Schur decomposition of an upper Hessenberg matrix
32
+ eig_tr_r : right eigenvectors of an upper triangular matrix
33
+ eig_tr_l : left eigenvectors of an upper triangular matrix
34
+ """
35
+
36
+ from ..libmp.backend import xrange
37
+
38
+ class Eigen(object):
39
+ pass
40
+
41
+ def defun(f):
42
+ setattr(Eigen, f.__name__, f)
43
+ return f
44
+
45
+ def hessenberg_reduce_0(ctx, A, T):
46
+ """
47
+ This routine computes the (upper) Hessenberg decomposition of a square matrix A.
48
+ Given A, an unitary matrix Q is calculated such that
49
+
50
+ Q' A Q = H and Q' Q = Q Q' = 1
51
+
52
+ where H is an upper Hessenberg matrix, meaning that it only contains zeros
53
+ below the first subdiagonal. Here ' denotes the hermitian transpose (i.e.
54
+ transposition and conjugation).
55
+
56
+ parameters:
57
+ A (input/output) On input, A contains the square matrix A of
58
+ dimension (n,n). On output, A contains a compressed representation
59
+ of Q and H.
60
+ T (output) An array of length n containing the first elements of
61
+ the Householder reflectors.
62
+ """
63
+
64
+ # internally we work with householder reflections from the right.
65
+ # let u be a row vector (i.e. u[i]=A[i,:i]). then
66
+ # Q is build up by reflectors of the type (1-v'v) where v is a suitable
67
+ # modification of u. these reflectors are applyed to A from the right.
68
+ # because we work with reflectors from the right we have to start with
69
+ # the bottom row of A and work then upwards (this corresponds to
70
+ # some kind of RQ decomposition).
71
+ # the first part of the vectors v (i.e. A[i,:(i-1)]) are stored as row vectors
72
+ # in the lower left part of A (excluding the diagonal and subdiagonal).
73
+ # the last entry of v is stored in T.
74
+ # the upper right part of A (including diagonal and subdiagonal) becomes H.
75
+
76
+
77
+ n = A.rows
78
+ if n <= 2: return
79
+
80
+ for i in xrange(n-1, 1, -1):
81
+
82
+ # scale the vector
83
+
84
+ scale = 0
85
+ for k in xrange(0, i):
86
+ scale += abs(ctx.re(A[i,k])) + abs(ctx.im(A[i,k]))
87
+
88
+ scale_inv = 0
89
+ if scale != 0:
90
+ scale_inv = 1 / scale
91
+
92
+ if scale == 0 or ctx.isinf(scale_inv):
93
+ # sadly there are floating point numbers not equal to zero whose reciprocal is infinity
94
+ T[i] = 0
95
+ A[i,i-1] = 0
96
+ continue
97
+
98
+ # calculate parameters for housholder transformation
99
+
100
+ H = 0
101
+ for k in xrange(0, i):
102
+ A[i,k] *= scale_inv
103
+ rr = ctx.re(A[i,k])
104
+ ii = ctx.im(A[i,k])
105
+ H += rr * rr + ii * ii
106
+
107
+ F = A[i,i-1]
108
+ f = abs(F)
109
+ G = ctx.sqrt(H)
110
+ A[i,i-1] = - G * scale
111
+
112
+ if f == 0:
113
+ T[i] = G
114
+ else:
115
+ ff = F / f
116
+ T[i] = F + G * ff
117
+ A[i,i-1] *= ff
118
+
119
+ H += G * f
120
+ H = 1 / ctx.sqrt(H)
121
+
122
+ T[i] *= H
123
+ for k in xrange(0, i - 1):
124
+ A[i,k] *= H
125
+
126
+ for j in xrange(0, i):
127
+ # apply housholder transformation (from right)
128
+
129
+ G = ctx.conj(T[i]) * A[j,i-1]
130
+ for k in xrange(0, i-1):
131
+ G += ctx.conj(A[i,k]) * A[j,k]
132
+
133
+ A[j,i-1] -= G * T[i]
134
+ for k in xrange(0, i-1):
135
+ A[j,k] -= G * A[i,k]
136
+
137
+ for j in xrange(0, n):
138
+ # apply housholder transformation (from left)
139
+
140
+ G = T[i] * A[i-1,j]
141
+ for k in xrange(0, i-1):
142
+ G += A[i,k] * A[k,j]
143
+
144
+ A[i-1,j] -= G * ctx.conj(T[i])
145
+ for k in xrange(0, i-1):
146
+ A[k,j] -= G * ctx.conj(A[i,k])
147
+
148
+
149
+
150
+ def hessenberg_reduce_1(ctx, A, T):
151
+ """
152
+ This routine forms the unitary matrix Q described in hessenberg_reduce_0.
153
+
154
+ parameters:
155
+ A (input/output) On input, A is the same matrix as delivered by
156
+ hessenberg_reduce_0. On output, A is set to Q.
157
+
158
+ T (input) On input, T is the same array as delivered by hessenberg_reduce_0.
159
+ """
160
+
161
+ n = A.rows
162
+
163
+ if n == 1:
164
+ A[0,0] = 1
165
+ return
166
+
167
+ A[0,0] = A[1,1] = 1
168
+ A[0,1] = A[1,0] = 0
169
+
170
+ for i in xrange(2, n):
171
+ if T[i] != 0:
172
+
173
+ for j in xrange(0, i):
174
+ G = T[i] * A[i-1,j]
175
+ for k in xrange(0, i-1):
176
+ G += A[i,k] * A[k,j]
177
+
178
+ A[i-1,j] -= G * ctx.conj(T[i])
179
+ for k in xrange(0, i-1):
180
+ A[k,j] -= G * ctx.conj(A[i,k])
181
+
182
+ A[i,i] = 1
183
+ for j in xrange(0, i):
184
+ A[j,i] = A[i,j] = 0
185
+
186
+
187
+
188
+ @defun
189
+ def hessenberg(ctx, A, overwrite_a = False):
190
+ """
191
+ This routine computes the Hessenberg decomposition of a square matrix A.
192
+ Given A, an unitary matrix Q is determined such that
193
+
194
+ Q' A Q = H and Q' Q = Q Q' = 1
195
+
196
+ where H is an upper right Hessenberg matrix. Here ' denotes the hermitian
197
+ transpose (i.e. transposition and conjugation).
198
+
199
+ input:
200
+ A : a real or complex square matrix
201
+ overwrite_a : if true, allows modification of A which may improve
202
+ performance. if false, A is not modified.
203
+
204
+ output:
205
+ Q : an unitary matrix
206
+ H : an upper right Hessenberg matrix
207
+
208
+ example:
209
+ >>> from mpmath import mp
210
+ >>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
211
+ >>> Q, H = mp.hessenberg(A)
212
+ >>> mp.nprint(H, 3) # doctest:+SKIP
213
+ [ 3.15 2.23 4.44]
214
+ [-0.769 4.85 3.05]
215
+ [ 0.0 3.61 7.0]
216
+ >>> print(mp.chop(A - Q * H * Q.transpose_conj()))
217
+ [0.0 0.0 0.0]
218
+ [0.0 0.0 0.0]
219
+ [0.0 0.0 0.0]
220
+
221
+ return value: (Q, H)
222
+ """
223
+
224
+ n = A.rows
225
+
226
+ if n == 1:
227
+ return (ctx.matrix([[1]]), A)
228
+
229
+ if not overwrite_a:
230
+ A = A.copy()
231
+
232
+ T = ctx.matrix(n, 1)
233
+
234
+ hessenberg_reduce_0(ctx, A, T)
235
+ Q = A.copy()
236
+ hessenberg_reduce_1(ctx, Q, T)
237
+
238
+ for x in xrange(n):
239
+ for y in xrange(x+2, n):
240
+ A[y,x] = 0
241
+
242
+ return Q, A
243
+
244
+
245
+ ###########################################################################
246
+
247
+
248
+ def qr_step(ctx, n0, n1, A, Q, shift):
249
+ """
250
+ This subroutine executes a single implicitly shifted QR step applied to an
251
+ upper Hessenberg matrix A. Given A and shift as input, first an QR
252
+ decomposition is calculated:
253
+
254
+ Q R = A - shift * 1 .
255
+
256
+ The output is then following matrix:
257
+
258
+ R Q + shift * 1
259
+
260
+ parameters:
261
+ n0, n1 (input) Two integers which specify the submatrix A[n0:n1,n0:n1]
262
+ on which this subroutine operators. The subdiagonal elements
263
+ to the left and below this submatrix must be deflated (i.e. zero).
264
+ following restriction is imposed: n1>=n0+2
265
+ A (input/output) On input, A is an upper Hessenberg matrix.
266
+ On output, A is replaced by "R Q + shift * 1"
267
+ Q (input/output) The parameter Q is multiplied by the unitary matrix
268
+ Q arising from the QR decomposition. Q can also be false, in which
269
+ case the unitary matrix Q is not computated.
270
+ shift (input) a complex number specifying the shift. idealy close to an
271
+ eigenvalue of the bottemmost part of the submatrix A[n0:n1,n0:n1].
272
+
273
+ references:
274
+ Stoer, Bulirsch - Introduction to Numerical Analysis.
275
+ Kresser : Numerical Methods for General and Structured Eigenvalue Problems
276
+ """
277
+
278
+ # implicitly shifted and bulge chasing is explained at p.398/399 in "Stoer, Bulirsch - Introduction to Numerical Analysis"
279
+ # for bulge chasing see also "Watkins - The Matrix Eigenvalue Problem" sec.4.5,p.173
280
+
281
+ # the Givens rotation we used is determined as follows: let c,s be two complex
282
+ # numbers. then we have following relation:
283
+ #
284
+ # v = sqrt(|c|^2 + |s|^2)
285
+ #
286
+ # 1/v [ c~ s~] [c] = [v]
287
+ # [-s c ] [s] [0]
288
+ #
289
+ # the matrix on the left is our Givens rotation.
290
+
291
+ n = A.rows
292
+
293
+ # first step
294
+
295
+ # calculate givens rotation
296
+ c = A[n0 ,n0] - shift
297
+ s = A[n0+1,n0]
298
+
299
+ v = ctx.hypot(ctx.hypot(ctx.re(c), ctx.im(c)), ctx.hypot(ctx.re(s), ctx.im(s)))
300
+
301
+ if v == 0:
302
+ v = 1
303
+ c = 1
304
+ s = 0
305
+ else:
306
+ c /= v
307
+ s /= v
308
+
309
+ cc = ctx.conj(c)
310
+ cs = ctx.conj(s)
311
+
312
+ for k in xrange(n0, n):
313
+ # apply givens rotation from the left
314
+ x = A[n0 ,k]
315
+ y = A[n0+1,k]
316
+ A[n0 ,k] = cc * x + cs * y
317
+ A[n0+1,k] = c * y - s * x
318
+
319
+ for k in xrange(min(n1, n0+3)):
320
+ # apply givens rotation from the right
321
+ x = A[k,n0 ]
322
+ y = A[k,n0+1]
323
+ A[k,n0 ] = c * x + s * y
324
+ A[k,n0+1] = cc * y - cs * x
325
+
326
+ if not isinstance(Q, bool):
327
+ for k in xrange(n):
328
+ # eigenvectors
329
+ x = Q[k,n0 ]
330
+ y = Q[k,n0+1]
331
+ Q[k,n0 ] = c * x + s * y
332
+ Q[k,n0+1] = cc * y - cs * x
333
+
334
+ # chase the bulge
335
+
336
+ for j in xrange(n0, n1 - 2):
337
+ # calculate givens rotation
338
+
339
+ c = A[j+1,j]
340
+ s = A[j+2,j]
341
+
342
+ v = ctx.hypot(ctx.hypot(ctx.re(c), ctx.im(c)), ctx.hypot(ctx.re(s), ctx.im(s)))
343
+
344
+ if v == 0:
345
+ A[j+1,j] = 0
346
+ v = 1
347
+ c = 1
348
+ s = 0
349
+ else:
350
+ A[j+1,j] = v
351
+ c /= v
352
+ s /= v
353
+
354
+ A[j+2,j] = 0
355
+
356
+ cc = ctx.conj(c)
357
+ cs = ctx.conj(s)
358
+
359
+ for k in xrange(j+1, n):
360
+ # apply givens rotation from the left
361
+ x = A[j+1,k]
362
+ y = A[j+2,k]
363
+ A[j+1,k] = cc * x + cs * y
364
+ A[j+2,k] = c * y - s * x
365
+
366
+ for k in xrange(0, min(n1, j+4)):
367
+ # apply givens rotation from the right
368
+ x = A[k,j+1]
369
+ y = A[k,j+2]
370
+ A[k,j+1] = c * x + s * y
371
+ A[k,j+2] = cc * y - cs * x
372
+
373
+ if not isinstance(Q, bool):
374
+ for k in xrange(0, n):
375
+ # eigenvectors
376
+ x = Q[k,j+1]
377
+ y = Q[k,j+2]
378
+ Q[k,j+1] = c * x + s * y
379
+ Q[k,j+2] = cc * y - cs * x
380
+
381
+
382
+
383
+ def hessenberg_qr(ctx, A, Q):
384
+ """
385
+ This routine computes the Schur decomposition of an upper Hessenberg matrix A.
386
+ Given A, an unitary matrix Q is determined such that
387
+
388
+ Q' A Q = R and Q' Q = Q Q' = 1
389
+
390
+ where R is an upper right triangular matrix. Here ' denotes the hermitian
391
+ transpose (i.e. transposition and conjugation).
392
+
393
+ parameters:
394
+ A (input/output) On input, A contains an upper Hessenberg matrix.
395
+ On output, A is replace by the upper right triangluar matrix R.
396
+
397
+ Q (input/output) The parameter Q is multiplied by the unitary
398
+ matrix Q arising from the Schur decomposition. Q can also be
399
+ false, in which case the unitary matrix Q is not computated.
400
+ """
401
+
402
+ n = A.rows
403
+
404
+ norm = 0
405
+ for x in xrange(n):
406
+ for y in xrange(min(x+2, n)):
407
+ norm += ctx.re(A[y,x]) ** 2 + ctx.im(A[y,x]) ** 2
408
+ norm = ctx.sqrt(norm) / n
409
+
410
+ if norm == 0:
411
+ return
412
+
413
+ n0 = 0
414
+ n1 = n
415
+
416
+ eps = ctx.eps / (100 * n)
417
+ maxits = ctx.dps * 4
418
+
419
+ its = totalits = 0
420
+
421
+ while 1:
422
+ # kressner p.32 algo 3
423
+ # the active submatrix is A[n0:n1,n0:n1]
424
+
425
+ k = n0
426
+
427
+ while k + 1 < n1:
428
+ s = abs(ctx.re(A[k,k])) + abs(ctx.im(A[k,k])) + abs(ctx.re(A[k+1,k+1])) + abs(ctx.im(A[k+1,k+1]))
429
+ if s < eps * norm:
430
+ s = norm
431
+ if abs(A[k+1,k]) < eps * s:
432
+ break
433
+ k += 1
434
+
435
+ if k + 1 < n1:
436
+ # deflation found at position (k+1, k)
437
+
438
+ A[k+1,k] = 0
439
+ n0 = k + 1
440
+
441
+ its = 0
442
+
443
+ if n0 + 1 >= n1:
444
+ # block of size at most two has converged
445
+ n0 = 0
446
+ n1 = k + 1
447
+ if n1 < 2:
448
+ # QR algorithm has converged
449
+ return
450
+ else:
451
+ if (its % 30) == 10:
452
+ # exceptional shift
453
+ shift = A[n1-1,n1-2]
454
+ elif (its % 30) == 20:
455
+ # exceptional shift
456
+ shift = abs(A[n1-1,n1-2])
457
+ elif (its % 30) == 29:
458
+ # exceptional shift
459
+ shift = norm
460
+ else:
461
+ # A = [ a b ] det(x-A)=x*x-x*tr(A)+det(A)
462
+ # [ c d ]
463
+ #
464
+ # eigenvalues bad: (tr(A)+sqrt((tr(A))**2-4*det(A)))/2
465
+ # bad because of cancellation if |c| is small and |a-d| is small, too.
466
+ #
467
+ # eigenvalues good: (a+d+sqrt((a-d)**2+4*b*c))/2
468
+
469
+ t = A[n1-2,n1-2] + A[n1-1,n1-1]
470
+ s = (A[n1-1,n1-1] - A[n1-2,n1-2]) ** 2 + 4 * A[n1-1,n1-2] * A[n1-2,n1-1]
471
+ if ctx.re(s) > 0:
472
+ s = ctx.sqrt(s)
473
+ else:
474
+ s = ctx.sqrt(-s) * 1j
475
+ a = (t + s) / 2
476
+ b = (t - s) / 2
477
+ if abs(A[n1-1,n1-1] - a) > abs(A[n1-1,n1-1] - b):
478
+ shift = b
479
+ else:
480
+ shift = a
481
+
482
+ its += 1
483
+ totalits += 1
484
+
485
+ qr_step(ctx, n0, n1, A, Q, shift)
486
+
487
+ if its > maxits:
488
+ raise RuntimeError("qr: failed to converge after %d steps" % its)
489
+
490
+
491
@defun
def schur(ctx, A, overwrite_a = False):
    """
    Compute the Schur decomposition of the square matrix A.

    A unitary matrix Q is determined such that

      Q' A Q = R   and   Q' Q = Q Q' = 1

    where R is upper triangular and ' denotes the hermitian transpose
    (transposition plus conjugation).

    input:
      A            : a real or complex square matrix
      overwrite_a  : if true, allows modification of A which may improve
                     performance. if false, A is not modified.

    output:
      Q : a unitary matrix
      R : an upper right triangular matrix

    return value:  (Q, R)

    example:
      >>> from mpmath import mp
      >>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
      >>> Q, R = mp.schur(A)
      >>> mp.nprint(R, 3) # doctest:+SKIP
      [2.0  0.417  -2.53]
      [0.0    4.0  -4.74]
      [0.0    0.0    9.0]
      >>> print(mp.chop(A - Q * R * Q.transpose_conj()))
      [0.0  0.0  0.0]
      [0.0  0.0  0.0]
      [0.0  0.0  0.0]

    warning: The Schur decomposition is not unique.
    """

    n = A.rows

    # a 1x1 matrix is already triangular
    if n == 1:
        return (ctx.matrix([[1]]), A)

    if not overwrite_a:
        A = A.copy()

    # reduce A to Hessenberg form; the elementary reflectors are
    # accumulated into the unitary factor Q
    T = ctx.matrix(n, 1)
    hessenberg_reduce_0(ctx, A, T)
    Q = A.copy()
    hessenberg_reduce_1(ctx, Q, T)

    # clear the (numerically tiny) entries below the first subdiagonal
    for col in xrange(n):
        for row in xrange(col + 2, n):
            A[row, col] = 0

    # shifted QR iteration drives A to upper triangular form in place
    hessenberg_qr(ctx, A, Q)

    return Q, A
550
+
551
+
552
def eig_tr_r(ctx, A):
    """
    This routine calculates the right eigenvectors of an upper right triangular matrix.

    input:
      A      an upper right triangular matrix

    output:
      ER     a matrix whose columns form the right eigenvectors of A

    return value: ER
    """

    # this subroutine is inspired by the lapack routines ctrevc.f,clatrs.f

    n = A.rows

    # start from the identity; column i will be overwritten with the
    # eigenvector belonging to the eigenvalue A[i,i]
    ER = ctx.eye(n)

    eps = ctx.eps

    unfl = ctx.ldexp(ctx.one, -ctx.prec * 30)
    # since mpmath effectively has no limits on the exponent, we simply scale doubles up
    # original double has prec*20

    smlnum = unfl * (n / eps)
    simin = 1 / ctx.sqrt(eps)

    rmax = 1

    for i in xrange(1, n):
        s = A[i,i]

        # smallest divisor we allow when dividing by (A[j,j] - s) below
        smin = max(eps * abs(s), smlnum)

        # back substitution: solve (A[0:i,0:i] - s*I) v = -A[0:i,i]
        for j in xrange(i - 1, -1, -1):

            r = 0
            for k in xrange(j + 1, i + 1):
                r += A[j,k] * ER[k,i]

            t = A[j,j] - s
            if abs(t) < smin:
                # (nearly) multiple eigenvalue: clamp the divisor to
                # avoid overflow / division by zero
                t = smin

            r = -r / t
            ER[j,i] = r

            rmax = max(rmax, abs(r))
            if rmax > simin:
                # rescale the partial eigenvector to keep entries bounded
                for k in xrange(j, i+1):
                    ER[k,i] /= rmax
                rmax = 1

        if rmax != 1:
            # apply the final scaling to the whole column
            for k in xrange(0, i + 1):
                ER[k,i] /= rmax

    return ER
611
+
612
def eig_tr_l(ctx, A):
    """
    This routine calculates the left eigenvectors of an upper right triangular matrix.

    input:
      A      an upper right triangular matrix

    output:
      EL     a matrix whose rows form the left eigenvectors of A

    return value: EL
    """

    n = A.rows

    # start from the identity; row i will be overwritten with the left
    # eigenvector belonging to the eigenvalue A[i,i]
    EL = ctx.eye(n)

    eps = ctx.eps

    unfl = ctx.ldexp(ctx.one, -ctx.prec * 30)
    # since mpmath effectively has no limits on the exponent, we simply scale doubles up
    # original double has prec*20

    smlnum = unfl * (n / eps)
    simin = 1 / ctx.sqrt(eps)

    rmax = 1

    for i in xrange(0, n - 1):
        s = A[i,i]

        # smallest divisor we allow when dividing by (A[j,j] - s) below
        smin = max(eps * abs(s), smlnum)

        # forward substitution: solve w (A[i+1:,i+1:] - s*I) = -A[i,i+1:]
        for j in xrange(i + 1, n):

            r = 0
            for k in xrange(i, j):
                r += EL[i,k] * A[k,j]

            t = A[j,j] - s
            if abs(t) < smin:
                # (nearly) multiple eigenvalue: clamp the divisor
                t = smin

            r = -r / t
            EL[i,j] = r

            rmax = max(rmax, abs(r))
            if rmax > simin:
                # rescale the partial eigenvector to keep entries bounded
                for k in xrange(i, j + 1):
                    EL[i,k] /= rmax
                rmax = 1

        if rmax != 1:
            # apply the final scaling to the whole row
            for k in xrange(i, n):
                EL[i,k] /= rmax

    return EL
669
+
670
@defun
def eig(ctx, A, left = False, right = True, overwrite_a = False):
    """
    This routine computes the eigenvalues and optionally the left and right
    eigenvectors of a square matrix A. Given A, a vector E and matrices ER
    and EL are calculated such that

                        A ER[:,i] =         E[i] ER[:,i]
                EL[i,:] A         = EL[i,:] E[i]

    E contains the eigenvalues of A. The columns of ER contain the right eigenvectors
    of A whereas the rows of EL contain the left eigenvectors.


    input:
      A           : a real or complex square matrix of shape (n, n)
      left        : if true, the left eigenvectors are calculated.
      right       : if true, the right eigenvectors are calculated.
      overwrite_a : if true, allows modification of A which may improve
                    performance. if false, A is not modified.

    output:
      E    : a list of length n containing the eigenvalues of A.
      ER   : a matrix whose columns contain the right eigenvectors of A.
      EL   : a matrix whose rows contain the left eigenvectors of A.

    return values:
       E            if left and right are both false.
      (E, ER)       if right is true and left is false.
      (E, EL)       if left is true and right is false.
      (E, EL, ER)   if left and right are true.


    examples:
      >>> from mpmath import mp
      >>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
      >>> E, ER = mp.eig(A)
      >>> print(mp.chop(A * ER[:,0] - E[0] * ER[:,0]))
      [0.0]
      [0.0]
      [0.0]

      >>> E, EL, ER = mp.eig(A,left = True, right = True)
      >>> E, EL, ER = mp.eig_sort(E, EL, ER)
      >>> mp.nprint(E)
      [2.0, 4.0, 9.0]
      >>> print(mp.chop(A * ER[:,0] - E[0] * ER[:,0]))
      [0.0]
      [0.0]
      [0.0]
      >>> print(mp.chop( EL[0,:] * A - EL[0,:] * E[0]))
      [0.0  0.0  0.0]

    warning:
     - If there are multiple eigenvalues, the eigenvectors do not necessarily
       span the whole vectorspace, i.e. ER and EL may have not full rank.
       Furthermore in that case the eigenvectors are numerical ill-conditioned.
     - In the general case the eigenvalues have no natural order.

    see also:
      - eigh (or eigsy, eighe) for the symmetric eigenvalue problem.
      - eig_sort for sorting of eigenvalues and eigenvectors
    """

    n = A.rows

    # trivial 1x1 case: the matrix entry is the eigenvalue and [1] is
    # both the left and the right eigenvector
    if n == 1:
        vals = [A[0]]
        if left != right:
            # exactly one family of eigenvectors was requested
            return (vals, ctx.matrix([[1]]))
        return (vals, ctx.matrix([[1]]), ctx.matrix([[1]]))

    if not overwrite_a:
        A = A.copy()

    # reduce to Hessenberg form
    T = ctx.zeros(n, 1)
    hessenberg_reduce_0(ctx, A, T)

    need_vectors = left or right
    if need_vectors:
        # accumulate the reduction transform; it is needed to map the
        # triangular-matrix eigenvectors back to those of A
        Q = A.copy()
        hessenberg_reduce_1(ctx, Q, T)
    else:
        Q = False

    # clear the entries below the first subdiagonal
    for col in xrange(n):
        for row in xrange(col + 2, n):
            A[row, col] = 0

    # QR iteration: A becomes upper triangular
    hessenberg_qr(ctx, A, Q)

    # the eigenvalues sit on the diagonal of the triangular factor
    E = [A[i,i] for i in xrange(n)]

    if not need_vectors:
        return E

    if left:
        EL = eig_tr_l(ctx, A)
        EL = EL * Q.transpose_conj()

    if right:
        ER = eig_tr_r(ctx, A)
        ER = Q * ER

    if left and not right:
        return (E, EL)
    if right and not left:
        return (E, ER)
    return (E, EL, ER)
786
+
787
@defun
def eig_sort(ctx, E, EL = False, ER = False, f = "real"):
    """
    This routine sorts the eigenvalues and eigenvectors delivered by ``eig``.

    parameters:
      E  : the eigenvalues as delivered by eig
      EL : the left  eigenvectors as delivered by eig, or false
      ER : the right eigenvectors as delivered by eig, or false
      f  : either a string ("real" sort by increasing real part, "imag" sort by
           increasing imag part, "abs" sort by absolute value) or a function
           mapping complexs to the reals, i.e. ``f = lambda x: -mp.re(x) ``
           would sort the eigenvalues by decreasing real part.

    return values:
       E            if EL and ER are both false.
      (E, ER)       if ER is not false and left is false.
      (E, EL)       if EL is not false and right is false.
      (E, EL, ER)   if EL and ER are not false.

    example:
      >>> from mpmath import mp
      >>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
      >>> E, EL, ER = mp.eig(A,left = True, right = True)
      >>> E, EL, ER = mp.eig_sort(E, EL, ER)
      >>> mp.nprint(E)
      [2.0, 4.0, 9.0]
      >>> E, EL, ER = mp.eig_sort(E, EL, ER,f = lambda x: -mp.re(x))
      >>> mp.nprint(E)
      [9.0, 4.0, 2.0]
      >>> print(mp.chop(A * ER[:,0] - E[0] * ER[:,0]))
      [0.0]
      [0.0]
      [0.0]
      >>> print(mp.chop( EL[0,:] * A - EL[0,:] * E[0]))
      [0.0  0.0  0.0]
    """

    # map the symbolic key names onto actual key functions
    if isinstance(f, str):
        if f == "real":
            f = ctx.re
        elif f == "imag":
            f = ctx.im
        elif f == "abs":
            f = abs
        else:
            raise RuntimeError("unknown function %s" % f)

    n = len(E)

    have_left = not isinstance(EL, bool)
    have_right = not isinstance(ER, bool)

    # in-place selection sort on f(E); every swap of eigenvalues is
    # mirrored on the corresponding eigenvector row/column
    for i in xrange(n):
        imin = i
        best = f(E[i])

        for j in xrange(i + 1, n):
            cur = f(E[j])
            if cur < best:
                best = cur
                imin = j

        if imin != i:
            E[i], E[imin] = E[imin], E[i]

            if have_left:
                # left eigenvectors live in the rows of EL
                for j in xrange(n):
                    EL[i,j], EL[imin,j] = EL[imin,j], EL[i,j]

            if have_right:
                # right eigenvectors live in the columns of ER
                for j in xrange(n):
                    ER[j,i], ER[j,imin] = ER[j,imin], ER[j,i]

    if not have_left and not have_right:
        return E
    if not have_left:
        return (E, ER)
    if not have_right:
        return (E, EL)
    return (E, EL, ER)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/linalg.py ADDED
@@ -0,0 +1,790 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Linear algebra
3
+ --------------
4
+
5
+ Linear equations
6
+ ................
7
+
8
+ Basic linear algebra is implemented; you can for example solve the linear
9
+ equation system::
10
+
11
+ x + 2*y = -10
12
+ 3*x + 4*y = 10
13
+
14
+ using ``lu_solve``::
15
+
16
+ >>> from mpmath import *
17
+ >>> mp.pretty = False
18
+ >>> A = matrix([[1, 2], [3, 4]])
19
+ >>> b = matrix([-10, 10])
20
+ >>> x = lu_solve(A, b)
21
+ >>> x
22
+ matrix(
23
+ [['30.0'],
24
+ ['-20.0']])
25
+
26
+ If you don't trust the result, use ``residual`` to calculate the residual ||A*x-b||::
27
+
28
+ >>> residual(A, x, b)
29
+ matrix(
30
+ [['3.46944695195361e-18'],
31
+ ['3.46944695195361e-18']])
32
+ >>> str(eps)
33
+ '2.22044604925031e-16'
34
+
35
+ As you can see, the solution is quite accurate. The error is caused by the
36
+ inaccuracy of the internal floating point arithmetic. Though, it's even smaller
37
+ than the current machine epsilon, which basically means you can trust the
38
+ result.
39
+
40
+ If you need more speed, use NumPy, or ``fp.lu_solve`` for a floating-point computation.
41
+
42
+ >>> fp.lu_solve(A, b) # doctest: +ELLIPSIS
43
+ matrix(...)
44
+
45
+ ``lu_solve`` accepts overdetermined systems. It is usually not possible to solve
46
+ such systems, so the residual is minimized instead. Internally this is done
47
+ using Cholesky decomposition to compute a least squares approximation. This means
48
+ that that ``lu_solve`` will square the errors. If you can't afford this, use
49
+ ``qr_solve`` instead. It is twice as slow but more accurate, and it calculates
50
+ the residual automatically.
51
+
52
+
53
+ Matrix factorization
54
+ ....................
55
+
56
+ The function ``lu`` computes an explicit LU factorization of a matrix::
57
+
58
+ >>> P, L, U = lu(matrix([[0,2,3],[4,5,6],[7,8,9]]))
59
+ >>> print(P)
60
+ [0.0 0.0 1.0]
61
+ [1.0 0.0 0.0]
62
+ [0.0 1.0 0.0]
63
+ >>> print(L)
64
+ [ 1.0 0.0 0.0]
65
+ [ 0.0 1.0 0.0]
66
+ [0.571428571428571 0.214285714285714 1.0]
67
+ >>> print(U)
68
+ [7.0 8.0 9.0]
69
+ [0.0 2.0 3.0]
70
+ [0.0 0.0 0.214285714285714]
71
+ >>> print(P.T*L*U)
72
+ [0.0 2.0 3.0]
73
+ [4.0 5.0 6.0]
74
+ [7.0 8.0 9.0]
75
+
76
+ Interval matrices
77
+ -----------------
78
+
79
+ Matrices may contain interval elements. This allows one to perform
80
+ basic linear algebra operations such as matrix multiplication
81
+ and equation solving with rigorous error bounds::
82
+
83
+ >>> a = iv.matrix([['0.1','0.3','1.0'],
84
+ ... ['7.1','5.5','4.8'],
85
+ ... ['3.2','4.4','5.6']])
86
+ >>>
87
+ >>> b = iv.matrix(['4','0.6','0.5'])
88
+ >>> c = iv.lu_solve(a, b)
89
+ >>> print(c)
90
+ [ [5.2582327113062568605927528666, 5.25823271130625686059275702219]]
91
+ [[-13.1550493962678375411635581388, -13.1550493962678375411635540152]]
92
+ [ [7.42069154774972557628979076189, 7.42069154774972557628979190734]]
93
+ >>> print(a*c)
94
+ [ [3.99999999999999999999999844904, 4.00000000000000000000000155096]]
95
+ [[0.599999999999999999999968898009, 0.600000000000000000000031763736]]
96
+ [[0.499999999999999999999979320485, 0.500000000000000000000020679515]]
97
+ """
98
+
99
+ # TODO:
100
+ # *implement high-level qr()
101
+ # *test unitvector
102
+ # *iterative solving
103
+
104
+ from copy import copy
105
+
106
+ from ..libmp.backend import xrange
107
+
108
+ class LinearAlgebraMethods(object):
109
+
110
    def LU_decomp(ctx, A, overwrite=False, use_cache=True):
        """
        LU-factorization of a n*n matrix using the Gauss algorithm.
        Returns L and U in one matrix and the pivot indices.

        Use overwrite to specify whether A will be overwritten with L and U.

        Raises ZeroDivisionError when the matrix is numerically singular.
        """
        if not A.rows == A.cols:
            raise ValueError('need n*n matrix')
        # get from cache if possible (the factorization is memoized on
        # the matrix object itself)
        if use_cache and isinstance(A, ctx.matrix) and A._LU:
            return A._LU
        if not overwrite:
            orig = A
            A = A.copy()
        tol = ctx.absmin(ctx.mnorm(A,1) * ctx.eps) # each pivot element has to be bigger
        n = A.rows
        p = [None]*(n - 1)
        for j in xrange(n - 1):
            # pivoting, choose max(abs(reciprocal row sum)*abs(pivot element))
            # (scaled partial pivoting)
            biggest = 0
            for k in xrange(j, n):
                s = ctx.fsum([ctx.absmin(A[k,l]) for l in xrange(j, n)])
                if ctx.absmin(s) <= tol:
                    raise ZeroDivisionError('matrix is numerically singular')
                current = 1/s * ctx.absmin(A[k,j])
                if current > biggest: # TODO: what if equal?
                    biggest = current
                    p[j] = k
            # swap rows according to p
            ctx.swap_row(A, j, p[j])
            if ctx.absmin(A[j,j]) <= tol:
                raise ZeroDivisionError('matrix is numerically singular')
            # calculate elimination factors and add rows
            for i in xrange(j + 1, n):
                A[i,j] /= A[j,j]
                for k in xrange(j + 1, n):
                    A[i,k] -= A[i,j]*A[j,k]
        # the last pivot never gets checked inside the loop, do it here
        if ctx.absmin(A[n - 1,n - 1]) <= tol:
            raise ZeroDivisionError('matrix is numerically singular')
        # cache decomposition on the original (unmodified) matrix
        if not overwrite and isinstance(orig, ctx.matrix):
            orig._LU = (A, p)
        return A, p
154
+
155
+ def L_solve(ctx, L, b, p=None):
156
+ """
157
+ Solve the lower part of a LU factorized matrix for y.
158
+ """
159
+ if L.rows != L.cols:
160
+ raise RuntimeError("need n*n matrix")
161
+ n = L.rows
162
+ if len(b) != n:
163
+ raise ValueError("Value should be equal to n")
164
+ b = copy(b)
165
+ if p: # swap b according to p
166
+ for k in xrange(0, len(p)):
167
+ ctx.swap_row(b, k, p[k])
168
+ # solve
169
+ for i in xrange(1, n):
170
+ for j in xrange(i):
171
+ b[i] -= L[i,j] * b[j]
172
+ return b
173
+
174
+ def U_solve(ctx, U, y):
175
+ """
176
+ Solve the upper part of a LU factorized matrix for x.
177
+ """
178
+ if U.rows != U.cols:
179
+ raise RuntimeError("need n*n matrix")
180
+ n = U.rows
181
+ if len(y) != n:
182
+ raise ValueError("Value should be equal to n")
183
+ x = copy(y)
184
+ for i in xrange(n - 1, -1, -1):
185
+ for j in xrange(i + 1, n):
186
+ x[i] -= U[i,j] * x[j]
187
+ x[i] /= U[i,i]
188
+ return x
189
+
190
    def lu_solve(ctx, A, b, **kwargs):
        """
        Ax = b => x

        Solve a determined or overdetermined linear equations system.
        Fast LU decomposition is used, which is less accurate than QR decomposition
        (especially for overdetermined systems), but it's twice as efficient.
        Use qr_solve if you want more precision or have to solve a very ill-
        conditioned system.

        If you specify real=True, it does not check for overdetermined complex
        systems.
        """
        prec = ctx.prec
        try:
            # guard digits for the elimination; restored in ``finally``
            ctx.prec += 10
            # do not overwrite A nor b
            A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
            if A.rows < A.cols:
                raise ValueError('cannot solve underdetermined system')
            if A.rows > A.cols:
                # use least-squares method if overdetermined
                # (this increases errors)
                AH = A.H
                A = AH * A
                b = AH * b
                if (kwargs.get('real', False) or
                    not sum(type(i) is ctx.mpc for i in A)):
                    # TODO: necessary to check also b?
                    # the normal equations A^H A are Hermitian
                    # positive-definite, so Cholesky applies
                    x = ctx.cholesky_solve(A, b)
                else:
                    x = ctx.lu_solve(A, b)
            else:
                # LU factorization followed by forward/backward substitution
                A, p = ctx.LU_decomp(A)
                b = ctx.L_solve(A, b, p)
                x = ctx.U_solve(A, b)
        finally:
            ctx.prec = prec
        return x
230
+
231
+ def improve_solution(ctx, A, x, b, maxsteps=1):
232
+ """
233
+ Improve a solution to a linear equation system iteratively.
234
+
235
+ This re-uses the LU decomposition and is thus cheap.
236
+ Usually 3 up to 4 iterations are giving the maximal improvement.
237
+ """
238
+ if A.rows != A.cols:
239
+ raise RuntimeError("need n*n matrix") # TODO: really?
240
+ for _ in xrange(maxsteps):
241
+ r = ctx.residual(A, x, b)
242
+ if ctx.norm(r, 2) < 10*ctx.eps:
243
+ break
244
+ # this uses cached LU decomposition and is thus cheap
245
+ dx = ctx.lu_solve(A, -r)
246
+ x += dx
247
+ return x
248
+
249
+ def lu(ctx, A):
250
+ """
251
+ A -> P, L, U
252
+
253
+ LU factorisation of a square matrix A. L is the lower, U the upper part.
254
+ P is the permutation matrix indicating the row swaps.
255
+
256
+ P*A = L*U
257
+
258
+ If you need efficiency, use the low-level method LU_decomp instead, it's
259
+ much more memory efficient.
260
+ """
261
+ # get factorization
262
+ A, p = ctx.LU_decomp(A)
263
+ n = A.rows
264
+ L = ctx.matrix(n)
265
+ U = ctx.matrix(n)
266
+ for i in xrange(n):
267
+ for j in xrange(n):
268
+ if i > j:
269
+ L[i,j] = A[i,j]
270
+ elif i == j:
271
+ L[i,j] = 1
272
+ U[i,j] = A[i,j]
273
+ else:
274
+ U[i,j] = A[i,j]
275
+ # calculate permutation matrix
276
+ P = ctx.eye(n)
277
+ for k in xrange(len(p)):
278
+ ctx.swap_row(P, k, p[k])
279
+ return P, L, U
280
+
281
+ def unitvector(ctx, n, i):
282
+ """
283
+ Return the i-th n-dimensional unit vector.
284
+ """
285
+ assert 0 < i <= n, 'this unit vector does not exist'
286
+ return [ctx.zero]*(i-1) + [ctx.one] + [ctx.zero]*(n-i)
287
+
288
+ def inverse(ctx, A, **kwargs):
289
+ """
290
+ Calculate the inverse of a matrix.
291
+
292
+ If you want to solve an equation system Ax = b, it's recommended to use
293
+ solve(A, b) instead, it's about 3 times more efficient.
294
+ """
295
+ prec = ctx.prec
296
+ try:
297
+ ctx.prec += 10
298
+ # do not overwrite A
299
+ A = ctx.matrix(A, **kwargs).copy()
300
+ n = A.rows
301
+ # get LU factorisation
302
+ A, p = ctx.LU_decomp(A)
303
+ cols = []
304
+ # calculate unit vectors and solve corresponding system to get columns
305
+ for i in xrange(1, n + 1):
306
+ e = ctx.unitvector(n, i)
307
+ y = ctx.L_solve(A, e, p)
308
+ cols.append(ctx.U_solve(A, y))
309
+ # convert columns to matrix
310
+ inv = []
311
+ for i in xrange(n):
312
+ row = []
313
+ for j in xrange(n):
314
+ row.append(cols[j][i])
315
+ inv.append(row)
316
+ result = ctx.matrix(inv, **kwargs)
317
+ finally:
318
+ ctx.prec = prec
319
+ return result
320
+
321
    def householder(ctx, A):
        """
        (A|b) -> H, p, x, res

        (A|b) is the coefficient matrix with left hand side of an optionally
        overdetermined linear equation system.
        H and p contain all information about the transformation matrices.
        x is the solution, res the residual.

        NOTE: A is modified in place; the returned H is the same object.
        """
        if not isinstance(A, ctx.matrix):
            raise TypeError("A should be a type of ctx.matrix")
        m = A.rows
        n = A.cols
        if m < n - 1:
            raise RuntimeError("Columns should not be less than rows")
        # calculate Householder matrix
        p = []
        for j in xrange(0, n - 1):
            # squared norm of the column below (and including) the pivot
            s = ctx.fsum(abs(A[i,j])**2 for i in xrange(j, m))
            if not abs(s) > ctx.eps:
                raise ValueError('matrix is numerically singular')
            # sign choice avoids cancellation in (s - p[j]*A[j,j])
            p.append(-ctx.sign(ctx.re(A[j,j])) * ctx.sqrt(s))
            kappa = ctx.one / (s - p[j] * A[j,j])
            A[j,j] -= p[j]
            # apply the reflector to the remaining columns
            for k in xrange(j+1, n):
                y = ctx.fsum(ctx.conj(A[i,j]) * A[i,k] for i in xrange(j, m)) * kappa
                for i in xrange(j, m):
                    A[i,k] -= A[i,j] * y
        # solve Rx = c1 by back substitution (diagonal of R is stored in p)
        x = [A[i,n - 1] for i in xrange(n - 1)]
        for i in xrange(n - 2, -1, -1):
            x[i] -= ctx.fsum(A[i,j] * x[j] for j in xrange(i + 1, n - 1))
            x[i] /= p[i]
        # calculate residual
        if not m == n - 1:
            r = [A[m-1-i, n-1] for i in xrange(m - n + 1)]
        else:
            # determined system, residual should be 0
            r = [0]*m # maybe a bad idea, changing r[i] will change all elements
        return A, p, x, r
361
+
362
+ #def qr(ctx, A):
363
+ # """
364
+ # A -> Q, R
365
+ #
366
+ # QR factorisation of a square matrix A using Householder decomposition.
367
+ # Q is orthogonal, this leads to very few numerical errors.
368
+ #
369
+ # A = Q*R
370
+ # """
371
+ # H, p, x, res = householder(A)
372
+ # TODO: implement this
373
+
374
+ def residual(ctx, A, x, b, **kwargs):
375
+ """
376
+ Calculate the residual of a solution to a linear equation system.
377
+
378
+ r = A*x - b for A*x = b
379
+ """
380
+ oldprec = ctx.prec
381
+ try:
382
+ ctx.prec *= 2
383
+ A, x, b = ctx.matrix(A, **kwargs), ctx.matrix(x, **kwargs), ctx.matrix(b, **kwargs)
384
+ return A*x - b
385
+ finally:
386
+ ctx.prec = oldprec
387
+
388
+ def qr_solve(ctx, A, b, norm=None, **kwargs):
389
+ """
390
+ Ax = b => x, ||Ax - b||
391
+
392
+ Solve a determined or overdetermined linear equations system and
393
+ calculate the norm of the residual (error).
394
+ QR decomposition using Householder factorization is applied, which gives very
395
+ accurate results even for ill-conditioned matrices. qr_solve is twice as
396
+ efficient.
397
+ """
398
+ if norm is None:
399
+ norm = ctx.norm
400
+ prec = ctx.prec
401
+ try:
402
+ ctx.prec += 10
403
+ # do not overwrite A nor b
404
+ A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
405
+ if A.rows < A.cols:
406
+ raise ValueError('cannot solve underdetermined system')
407
+ H, p, x, r = ctx.householder(ctx.extend(A, b))
408
+ res = ctx.norm(r)
409
+ # calculate residual "manually" for determined systems
410
+ if res == 0:
411
+ res = ctx.norm(ctx.residual(A, x, b))
412
+ return ctx.matrix(x, **kwargs), res
413
+ finally:
414
+ ctx.prec = prec
415
+
416
    def cholesky(ctx, A, tol=None):
        r"""
        Cholesky decomposition of a symmetric positive-definite matrix `A`.
        Returns a lower triangular matrix `L` such that `A = L \times L^T`.
        More generally, for a complex Hermitian positive-definite matrix,
        a Cholesky decomposition satisfying `A = L \times L^H` is returned.

        The Cholesky decomposition can be used to solve linear equation
        systems twice as efficiently as LU decomposition, or to
        test whether `A` is positive-definite.

        The optional parameter ``tol`` determines the tolerance for
        verifying positive-definiteness.

        **Examples**

        Cholesky decomposition of a positive-definite symmetric matrix::

            >>> from mpmath import *
            >>> mp.dps = 25; mp.pretty = True
            >>> A = eye(3) + hilbert(3)
            >>> nprint(A)
            [     2.0      0.5  0.333333]
            [     0.5  1.33333      0.25]
            [0.333333     0.25       1.2]
            >>> L = cholesky(A)
            >>> nprint(L)
            [ 1.41421      0.0      0.0]
            [0.353553  1.09924      0.0]
            [0.235702  0.15162  1.05899]
            >>> chop(A - L*L.T)
            [0.0  0.0  0.0]
            [0.0  0.0  0.0]
            [0.0  0.0  0.0]

        Cholesky decomposition of a Hermitian matrix::

            >>> A = eye(3) + matrix([[0,0.25j,-0.5j],[-0.25j,0,0],[0.5j,0,0]])
            >>> L = cholesky(A)
            >>> nprint(L)
            [          1.0                0.0                0.0]
            [(0.0 - 0.25j)  (0.968246 + 0.0j)                0.0]
            [ (0.0 + 0.5j)  (0.129099 + 0.0j)  (0.856349 + 0.0j)]
            >>> chop(A - L*L.H)
            [0.0  0.0  0.0]
            [0.0  0.0  0.0]
            [0.0  0.0  0.0]

        Attempted Cholesky decomposition of a matrix that is not positive
        definite::

            >>> A = -eye(3) + hilbert(3)
            >>> L = cholesky(A)
            Traceback (most recent call last):
              ...
            ValueError: matrix is not positive-definite

        **References**

        1. [Wikipedia]_ http://en.wikipedia.org/wiki/Cholesky_decomposition

        """
        if not isinstance(A, ctx.matrix):
            raise RuntimeError("A should be a type of ctx.matrix")
        if not A.rows == A.cols:
            raise ValueError('need n*n matrix')
        if tol is None:
            tol = +ctx.eps
        n = A.rows
        L = ctx.matrix(n)
        for j in xrange(n):
            # the diagonal of a Hermitian matrix must be (numerically) real
            c = ctx.re(A[j,j])
            if abs(c-A[j,j]) > tol:
                raise ValueError('matrix is not Hermitian')
            # s = A[j,j] - sum(|L[j,k]|^2); must stay positive for a
            # positive-definite input
            s = c - ctx.fsum((L[j,k] for k in xrange(j)),
                absolute=True, squared=True)
            if s < tol:
                raise ValueError('matrix is not positive-definite')
            L[j,j] = ctx.sqrt(s)
            # fill the rest of column j below the diagonal
            for i in xrange(j, n):
                it1 = (L[i,k] for k in xrange(j))
                it2 = (L[j,k] for k in xrange(j))
                t = ctx.fdot(it1, it2, conjugate=True)
                L[i,j] = (A[i,j] - t) / L[j,j]
        return L
501
+
502
+ def cholesky_solve(ctx, A, b, **kwargs):
503
+ """
504
+ Ax = b => x
505
+
506
+ Solve a symmetric positive-definite linear equation system.
507
+ This is twice as efficient as lu_solve.
508
+
509
+ Typical use cases:
510
+ * A.T*A
511
+ * Hessian matrix
512
+ * differential equations
513
+ """
514
+ prec = ctx.prec
515
+ try:
516
+ ctx.prec += 10
517
+ # do not overwrite A nor b
518
+ A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
519
+ if A.rows != A.cols:
520
+ raise ValueError('can only solve determined system')
521
+ # Cholesky factorization
522
+ L = ctx.cholesky(A)
523
+ # solve
524
+ n = L.rows
525
+ if len(b) != n:
526
+ raise ValueError("Value should be equal to n")
527
+ for i in xrange(n):
528
+ b[i] -= ctx.fsum(L[i,j] * b[j] for j in xrange(i))
529
+ b[i] /= L[i,i]
530
+ x = ctx.U_solve(L.T, b)
531
+ return x
532
+ finally:
533
+ ctx.prec = prec
534
+
535
+ def det(ctx, A):
536
+ """
537
+ Calculate the determinant of a matrix.
538
+ """
539
+ prec = ctx.prec
540
+ try:
541
+ # do not overwrite A
542
+ A = ctx.matrix(A).copy()
543
+ # use LU factorization to calculate determinant
544
+ try:
545
+ R, p = ctx.LU_decomp(A)
546
+ except ZeroDivisionError:
547
+ return 0
548
+ z = 1
549
+ for i, e in enumerate(p):
550
+ if i != e:
551
+ z *= -1
552
+ for i in xrange(A.rows):
553
+ z *= R[i,i]
554
+ return z
555
+ finally:
556
+ ctx.prec = prec
557
+
558
+ def cond(ctx, A, norm=None):
559
+ """
560
+ Calculate the condition number of a matrix using a specified matrix norm.
561
+
562
+ The condition number estimates the sensitivity of a matrix to errors.
563
+ Example: small input errors for ill-conditioned coefficient matrices
564
+ alter the solution of the system dramatically.
565
+
566
+ For ill-conditioned matrices it's recommended to use qr_solve() instead
567
+ of lu_solve(). This does not help with input errors however, it just avoids
568
+ to add additional errors.
569
+
570
+ Definition: cond(A) = ||A|| * ||A**-1||
571
+ """
572
+ if norm is None:
573
+ norm = lambda x: ctx.mnorm(x,1)
574
+ return norm(A) * norm(ctx.inverse(A))
575
+
576
+ def lu_solve_mat(ctx, a, b):
577
+ """Solve a * x = b where a and b are matrices."""
578
+ r = ctx.matrix(a.rows, b.cols)
579
+ for i in range(b.cols):
580
+ c = ctx.lu_solve(a, b.column(i))
581
+ for j in range(len(c)):
582
+ r[j, i] = c[j]
583
+ return r
584
+
585
    def qr(ctx, A, mode = 'full', edps = 10):
        """
        Compute a QR factorization $A = QR$ where
        A is an m x n matrix of real or complex numbers where m >= n

        mode has following meanings:
        (1) mode = 'raw' returns two matrixes (A, tau) in the
            internal format used by LAPACK
        (2) mode = 'skinny' returns the leading n columns of Q
            and n rows of R
        (3) Any other value returns the leading m columns of Q
            and m rows of R

        edps is the increase in mp precision used for calculations

        **Examples**

        >>> from mpmath import *
        >>> mp.dps = 15
        >>> mp.pretty = True
        >>> A = matrix([[1, 2], [3, 4], [1, 1]])
        >>> Q, R = qr(A)
        >>> Q
        [-0.301511344577764   0.861640436855329  0.408248290463863]
        [-0.904534033733291  -0.123091490979333  -0.408248290463863]
        [-0.301511344577764  -0.492365963917331  0.816496580927726]
        >>> R
        [-3.3166247903554  -4.52267016866645]
        [             0.0  0.738548945875996]
        [             0.0                0.0]
        >>> Q * R
        [1.0  2.0]
        [3.0  4.0]
        [1.0  1.0]
        >>> chop(Q.T * Q)
        [1.0  0.0  0.0]
        [0.0  1.0  0.0]
        [0.0  0.0  1.0]
        >>> B = matrix([[1+0j, 2-3j], [3+j, 4+5j]])
        >>> Q, R = qr(B)
        >>> nprint(Q)
        [     (-0.301511 + 0.0j)   (0.0695795 - 0.95092j)]
        [(-0.904534 - 0.301511j)  (-0.115966 + 0.278318j)]
        >>> nprint(R)
        [(-3.31662 + 0.0j)  (-5.72872 - 2.41209j)]
        [              0.0       (3.91965 + 0.0j)]
        >>> Q * R
        [(1.0 + 0.0j)  (2.0 - 3.0j)]
        [(3.0 + 1.0j)  (4.0 + 5.0j)]
        >>> chop(Q.T * Q.conjugate())
        [1.0  0.0]
        [0.0  1.0]

        """

        # check values before continuing
        assert isinstance(A, ctx.matrix)
        m = A.rows
        n = A.cols
        assert n >= 0
        assert m >= n
        assert edps >= 0

        # check for complex data type
        cmplx = any(type(x) is ctx.mpc for x in A)

        # temporarily increase the precision and initialize
        with ctx.extradps(edps):
            tau = ctx.matrix(n,1)
            A = A.copy()

            # ---------------
            # FACTOR MATRIX A
            # ---------------
            if cmplx:
                one = ctx.mpc('1.0', '0.0')
                zero = ctx.mpc('0.0', '0.0')
                rzero = ctx.mpf('0.0')

                # main loop to factor A (complex)
                # each pass builds one Householder reflector and applies
                # it to the trailing submatrix
                for j in xrange(0, n):
                    alpha = A[j,j]
                    alphr = ctx.re(alpha)
                    alphi = ctx.im(alpha)

                    if (m-j) >= 2:
                        xnorm = ctx.fsum( A[i,j]*ctx.conj(A[i,j]) for i in xrange(j+1, m) )
                        xnorm = ctx.re( ctx.sqrt(xnorm) )
                    else:
                        xnorm = rzero

                    # column already zero below the diagonal: no reflector
                    if (xnorm == rzero) and (alphi == rzero):
                        tau[j] = zero
                        continue

                    # choose sign of beta opposite to alpha to avoid
                    # cancellation
                    if alphr < rzero:
                        beta = ctx.sqrt(alphr**2 + alphi**2 + xnorm**2)
                    else:
                        beta = -ctx.sqrt(alphr**2 + alphi**2 + xnorm**2)

                    tau[j] = ctx.mpc( (beta - alphr) / beta, -alphi / beta )
                    t = -ctx.conj(tau[j])
                    za = one / (alpha - beta)

                    # normalize the reflector vector (stored below the
                    # diagonal in column j)
                    for i in xrange(j+1, m):
                        A[i,j] *= za

                    A[j,j] = one
                    # apply the reflector to the trailing columns
                    for k in xrange(j+1, n):
                        y = ctx.fsum(A[i,j] * ctx.conj(A[i,k]) for i in xrange(j, m))
                        temp = t * ctx.conj(y)
                        for i in xrange(j, m):
                            A[i,k] += A[i,j] * temp

                    A[j,j] = ctx.mpc(beta, '0.0')
            else:
                one = ctx.mpf('1.0')
                zero = ctx.mpf('0.0')

                # main loop to factor A (real)
                for j in xrange(0, n):
                    alpha = A[j,j]

                    if (m-j) > 2:
                        xnorm = ctx.fsum( (A[i,j])**2 for i in xrange(j+1, m) )
                        xnorm = ctx.sqrt(xnorm)
                    elif (m-j) == 2:
                        xnorm = abs( A[m-1,j] )
                    else:
                        xnorm = zero

                    # column already zero below the diagonal: no reflector
                    if xnorm == zero:
                        tau[j] = zero
                        continue

                    # choose sign of beta opposite to alpha to avoid
                    # cancellation
                    if alpha < zero:
                        beta = ctx.sqrt(alpha**2 + xnorm**2)
                    else:
                        beta = -ctx.sqrt(alpha**2 + xnorm**2)

                    tau[j] = (beta - alpha) / beta
                    t = -tau[j]
                    da = one / (alpha - beta)

                    # normalize the reflector vector
                    for i in xrange(j+1, m):
                        A[i,j] *= da

                    A[j,j] = one
                    # apply the reflector to the trailing columns
                    for k in xrange(j+1, n):
                        y = ctx.fsum( A[i,j] * A[i,k] for i in xrange(j, m) )
                        temp = t * y
                        for i in xrange(j,m):
                            A[i,k] += A[i,j] * temp

                    A[j,j] = beta

            # return factorization in same internal format as LAPACK
            if (mode == 'raw') or (mode == 'RAW'):
                return A, tau

            # ----------------------------------
            # FORM Q USING BACKWARD ACCUMULATION
            # ----------------------------------

            # form R before the values are overwritten
            R = A.copy()
            for j in xrange(0, n):
                for i in xrange(j+1, m):
                    R[i,j] = zero

            # set the value of p (number of columns of Q to return)
            p = m
            if (mode == 'skinny') or (mode == 'SKINNY'):
                p = n

            # add columns to A if needed and initialize
            A.cols += (p-n)
            for j in xrange(0, p):
                A[j,j] = one
                for i in xrange(0, j):
                    A[i,j] = zero

            # main loop to form Q: apply the reflectors in reverse order
            for j in xrange(n-1, -1, -1):
                t = -tau[j]
                A[j,j] += t

                for k in xrange(j+1, p):
                    if cmplx:
                        y = ctx.fsum(A[i,j] * ctx.conj(A[i,k]) for i in xrange(j+1, m))
                        temp = t * ctx.conj(y)
                    else:
                        y = ctx.fsum(A[i,j] * A[i,k] for i in xrange(j+1, m))
                        temp = t * y
                    A[j,k] = temp
                    for i in xrange(j+1, m):
                        A[i,k] += A[i,j] * temp

                for i in xrange(j+1, m):
                    A[i, j] *= t

            return A, R[0:p,0:n]

        # ------------------
        # END OF FUNCTION QR
        # ------------------
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/matrices.py ADDED
@@ -0,0 +1,1005 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..libmp.backend import xrange
2
+ import warnings
3
+
4
+ # TODO: interpret list as vectors (for multiplication)
5
+
6
# Separators used when rendering a matrix as text (see _matrix.__nstr__):
# rows are joined with `rowsep`, columns within a row with `colsep`.
rowsep = '\n'
colsep = ' '
8
+
9
+ class _matrix(object):
10
+ """
11
+ Numerical matrix.
12
+
13
+ Specify the dimensions or the data as a nested list.
14
+ Elements default to zero.
15
+ Use a flat list to create a column vector easily.
16
+
17
+ The datatype of the context (mpf for mp, mpi for iv, and float for fp) is used to store the data.
18
+
19
+ Creating matrices
20
+ -----------------
21
+
22
+ Matrices in mpmath are implemented using dictionaries. Only non-zero values
23
+ are stored, so it is cheap to represent sparse matrices.
24
+
25
+ The most basic way to create one is to use the ``matrix`` class directly.
26
+ You can create an empty matrix specifying the dimensions:
27
+
28
+ >>> from mpmath import *
29
+ >>> mp.dps = 15
30
+ >>> matrix(2)
31
+ matrix(
32
+ [['0.0', '0.0'],
33
+ ['0.0', '0.0']])
34
+ >>> matrix(2, 3)
35
+ matrix(
36
+ [['0.0', '0.0', '0.0'],
37
+ ['0.0', '0.0', '0.0']])
38
+
39
+ Calling ``matrix`` with one dimension will create a square matrix.
40
+
41
+ To access the dimensions of a matrix, use the ``rows`` or ``cols`` keyword:
42
+
43
+ >>> A = matrix(3, 2)
44
+ >>> A
45
+ matrix(
46
+ [['0.0', '0.0'],
47
+ ['0.0', '0.0'],
48
+ ['0.0', '0.0']])
49
+ >>> A.rows
50
+ 3
51
+ >>> A.cols
52
+ 2
53
+
54
+ You can also change the dimension of an existing matrix. This will set the
55
+ new elements to 0. If the new dimension is smaller than before, the
56
+ concerning elements are discarded:
57
+
58
+ >>> A.rows = 2
59
+ >>> A
60
+ matrix(
61
+ [['0.0', '0.0'],
62
+ ['0.0', '0.0']])
63
+
64
+ Internally ``mpmathify`` is used every time an element is set. This
65
+ is done using the syntax A[row,column], counting from 0:
66
+
67
+ >>> A = matrix(2)
68
+ >>> A[1,1] = 1 + 1j
69
+ >>> A
70
+ matrix(
71
+ [['0.0', '0.0'],
72
+ ['0.0', mpc(real='1.0', imag='1.0')]])
73
+
74
+ A more comfortable way to create a matrix lets you use nested lists:
75
+
76
+ >>> matrix([[1, 2], [3, 4]])
77
+ matrix(
78
+ [['1.0', '2.0'],
79
+ ['3.0', '4.0']])
80
+
81
+ Convenient advanced functions are available for creating various standard
82
+ matrices, see ``zeros``, ``ones``, ``diag``, ``eye``, ``randmatrix`` and
83
+ ``hilbert``.
84
+
85
+ Vectors
86
+ .......
87
+
88
+ Vectors may also be represented by the ``matrix`` class (with rows = 1 or cols = 1).
89
+ For vectors there are some things which make life easier. A column vector can
90
+ be created using a flat list, a row vectors using an almost flat nested list::
91
+
92
+ >>> matrix([1, 2, 3])
93
+ matrix(
94
+ [['1.0'],
95
+ ['2.0'],
96
+ ['3.0']])
97
+ >>> matrix([[1, 2, 3]])
98
+ matrix(
99
+ [['1.0', '2.0', '3.0']])
100
+
101
+ Optionally vectors can be accessed like lists, using only a single index::
102
+
103
+ >>> x = matrix([1, 2, 3])
104
+ >>> x[1]
105
+ mpf('2.0')
106
+ >>> x[1,0]
107
+ mpf('2.0')
108
+
109
+ Other
110
+ .....
111
+
112
+ Like you probably expected, matrices can be printed::
113
+
114
+ >>> print randmatrix(3) # doctest:+SKIP
115
+ [ 0.782963853573023 0.802057689719883 0.427895717335467]
116
+ [0.0541876859348597 0.708243266653103 0.615134039977379]
117
+ [ 0.856151514955773 0.544759264818486 0.686210904770947]
118
+
119
+ Use ``nstr`` or ``nprint`` to specify the number of digits to print::
120
+
121
+ >>> nprint(randmatrix(5), 3) # doctest:+SKIP
122
+ [2.07e-1 1.66e-1 5.06e-1 1.89e-1 8.29e-1]
123
+ [6.62e-1 6.55e-1 4.47e-1 4.82e-1 2.06e-2]
124
+ [4.33e-1 7.75e-1 6.93e-2 2.86e-1 5.71e-1]
125
+ [1.01e-1 2.53e-1 6.13e-1 3.32e-1 2.59e-1]
126
+ [1.56e-1 7.27e-2 6.05e-1 6.67e-2 2.79e-1]
127
+
128
+ As matrices are mutable, you will need to copy them sometimes::
129
+
130
+ >>> A = matrix(2)
131
+ >>> A
132
+ matrix(
133
+ [['0.0', '0.0'],
134
+ ['0.0', '0.0']])
135
+ >>> B = A.copy()
136
+ >>> B[0,0] = 1
137
+ >>> B
138
+ matrix(
139
+ [['1.0', '0.0'],
140
+ ['0.0', '0.0']])
141
+ >>> A
142
+ matrix(
143
+ [['0.0', '0.0'],
144
+ ['0.0', '0.0']])
145
+
146
+ Finally, it is possible to convert a matrix to a nested list. This is very useful,
147
+ as most Python libraries involving matrices or arrays (namely NumPy or SymPy)
148
+ support this format::
149
+
150
+ >>> B.tolist()
151
+ [[mpf('1.0'), mpf('0.0')], [mpf('0.0'), mpf('0.0')]]
152
+
153
+
154
+ Matrix operations
155
+ -----------------
156
+
157
+ You can add and subtract matrices of compatible dimensions::
158
+
159
+ >>> A = matrix([[1, 2], [3, 4]])
160
+ >>> B = matrix([[-2, 4], [5, 9]])
161
+ >>> A + B
162
+ matrix(
163
+ [['-1.0', '6.0'],
164
+ ['8.0', '13.0']])
165
+ >>> A - B
166
+ matrix(
167
+ [['3.0', '-2.0'],
168
+ ['-2.0', '-5.0']])
169
+ >>> A + ones(3) # doctest:+ELLIPSIS
170
+ Traceback (most recent call last):
171
+ ...
172
+ ValueError: incompatible dimensions for addition
173
+
174
+ It is possible to multiply or add matrices and scalars. In the latter case the
175
+ operation will be done element-wise::
176
+
177
+ >>> A * 2
178
+ matrix(
179
+ [['2.0', '4.0'],
180
+ ['6.0', '8.0']])
181
+ >>> A / 4
182
+ matrix(
183
+ [['0.25', '0.5'],
184
+ ['0.75', '1.0']])
185
+ >>> A - 1
186
+ matrix(
187
+ [['0.0', '1.0'],
188
+ ['2.0', '3.0']])
189
+
190
+ Of course you can perform matrix multiplication, if the dimensions are
191
+ compatible, using ``@`` (for Python >= 3.5) or ``*``. For clarity, ``@`` is
192
+ recommended (`PEP 465 <https://www.python.org/dev/peps/pep-0465/>`), because
193
+ the meaning of ``*`` is different in many other Python libraries such as NumPy.
194
+
195
+ >>> A @ B # doctest:+SKIP
196
+ matrix(
197
+ [['8.0', '22.0'],
198
+ ['14.0', '48.0']])
199
+ >>> A * B # same as A @ B
200
+ matrix(
201
+ [['8.0', '22.0'],
202
+ ['14.0', '48.0']])
203
+ >>> matrix([[1, 2, 3]]) * matrix([[-6], [7], [-2]])
204
+ matrix(
205
+ [['2.0']])
206
+
207
+ ..
208
+ COMMENT: TODO: the above "doctest:+SKIP" may be removed as soon as we
209
+ have dropped support for Python 3.5 and below.
210
+
211
+ You can raise powers of square matrices::
212
+
213
+ >>> A**2
214
+ matrix(
215
+ [['7.0', '10.0'],
216
+ ['15.0', '22.0']])
217
+
218
+ Negative powers will calculate the inverse::
219
+
220
+ >>> A**-1
221
+ matrix(
222
+ [['-2.0', '1.0'],
223
+ ['1.5', '-0.5']])
224
+ >>> A * A**-1
225
+ matrix(
226
+ [['1.0', '1.0842021724855e-19'],
227
+ ['-2.16840434497101e-19', '1.0']])
228
+
229
+
230
+
231
+ Matrix transposition is straightforward::
232
+
233
+ >>> A = ones(2, 3)
234
+ >>> A
235
+ matrix(
236
+ [['1.0', '1.0', '1.0'],
237
+ ['1.0', '1.0', '1.0']])
238
+ >>> A.T
239
+ matrix(
240
+ [['1.0', '1.0'],
241
+ ['1.0', '1.0'],
242
+ ['1.0', '1.0']])
243
+
244
+ Norms
245
+ .....
246
+
247
+ Sometimes you need to know how "large" a matrix or vector is. Due to their
248
+ multidimensional nature it's not possible to compare them, but there are
249
+ several functions to map a matrix or a vector to a positive real number, the
250
+ so called norms.
251
+
252
+ For vectors the p-norm is intended, usually the 1-, the 2- and the oo-norm are
253
+ used.
254
+
255
+ >>> x = matrix([-10, 2, 100])
256
+ >>> norm(x, 1)
257
+ mpf('112.0')
258
+ >>> norm(x, 2)
259
+ mpf('100.5186549850325')
260
+ >>> norm(x, inf)
261
+ mpf('100.0')
262
+
263
+ Please note that the 2-norm is the most used one, though it is more expensive
264
+ to calculate than the 1- or oo-norm.
265
+
266
+ It is possible to generalize some vector norms to matrix norm::
267
+
268
+ >>> A = matrix([[1, -1000], [100, 50]])
269
+ >>> mnorm(A, 1)
270
+ mpf('1050.0')
271
+ >>> mnorm(A, inf)
272
+ mpf('1001.0')
273
+ >>> mnorm(A, 'F')
274
+ mpf('1006.2310867787777')
275
+
276
+ The last norm (the "Frobenius-norm") is an approximation for the 2-norm, which
277
+ is hard to calculate and not available. The Frobenius-norm lacks some
278
+ mathematical properties you might expect from a norm.
279
+ """
280
+
281
    def __init__(self, *args, **kwargs):
        """
        Build a matrix from a nested list, a flat list (column vector),
        one or two integer dimensions (zero matrix), another matrix, or
        any object exposing ``tolist()`` (e.g. a numpy array).
        """
        # Sparse storage: only non-zero entries live in this dict,
        # keyed by (row, col) tuples.
        self.__data = {}
        # LU decomposition cache, this is useful when solving the same system
        # multiple times, when calculating the inverse and when calculating the
        # determinant
        self._LU = None
        if "force_type" in kwargs:
            warnings.warn("The force_type argument was removed, it did not work"
                " properly anyway. If you want to force floating-point or"
                " interval computations, use the respective methods from `fp`"
                " or `mp` instead, e.g., `fp.matrix()` or `iv.matrix()`."
                " If you want to truncate values to integer, use .apply(int) instead.")
        if isinstance(args[0], (list, tuple)):
            if isinstance(args[0][0], (list, tuple)):
                # interpret nested list as matrix
                A = args[0]
                self.__rows = len(A)
                self.__cols = len(A[0])
                for i, row in enumerate(A):
                    for j, a in enumerate(row):
                        # note: this will call __setitem__ which will call self.ctx.convert() to convert the datatype.
                        self[i, j] = a
            else:
                # interpret flat list as a column vector (rows = len(v), one column)
                v = args[0]
                self.__rows = len(v)
                self.__cols = 1
                for i, e in enumerate(v):
                    self[i, 0] = e
        elif isinstance(args[0], int):
            # create empty (all-zero) matrix of given dimensions
            if len(args) == 1:
                self.__rows = self.__cols = args[0]
            else:
                if not isinstance(args[1], int):
                    raise TypeError("expected int")
                self.__rows = args[0]
                self.__cols = args[1]
        elif isinstance(args[0], _matrix):
            # copy-construct from another matrix; elements are re-converted
            # through __setitem__ so the new context's datatype is used
            A = args[0]
            self.__rows = A._matrix__rows
            self.__cols = A._matrix__cols
            for i in xrange(A.__rows):
                for j in xrange(A.__cols):
                    self[i, j] = A[i, j]
        elif hasattr(args[0], 'tolist'):
            # duck-typed path for numpy arrays and similar objects
            A = self.ctx.matrix(args[0].tolist())
            self.__data = A._matrix__data
            self.__rows = A._matrix__rows
            self.__cols = A._matrix__cols
        else:
            raise TypeError('could not interpret given arguments')
333
+
334
+ def apply(self, f):
335
+ """
336
+ Return a copy of self with the function `f` applied elementwise.
337
+ """
338
+ new = self.ctx.matrix(self.__rows, self.__cols)
339
+ for i in xrange(self.__rows):
340
+ for j in xrange(self.__cols):
341
+ new[i,j] = f(self[i,j])
342
+ return new
343
+
344
    def __nstr__(self, n=None, **kwargs):
        """
        Render the matrix as aligned, bracketed rows of strings.

        If `n` is given, each element is formatted with ctx.nstr(..., n);
        otherwise str() is used. Kwargs are forwarded to ctx.nstr.
        """
        # Build table of string representations of the elements
        res = []
        # Track per-column max lengths for pretty alignment
        maxlen = [0] * self.cols
        for i in range(self.rows):
            res.append([])
            for j in range(self.cols):
                if n:
                    string = self.ctx.nstr(self[i,j], n, **kwargs)
                else:
                    string = str(self[i,j])
                res[-1].append(string)
                maxlen[j] = max(len(string), maxlen[j])
        # Patch strings together
        for i, row in enumerate(res):
            for j, elem in enumerate(row):
                # Pad each element up to maxlen so the columns line up
                row[j] = elem.rjust(maxlen[j])
            res[i] = "[" + colsep.join(row) + "]"
        return rowsep.join(res)
365
+
366
    def __str__(self):
        # Delegate to __nstr__ with no digit limit (full precision).
        return self.__nstr__()
368
+
369
+ def _toliststr(self, avoid_type=False):
370
+ """
371
+ Create a list string from a matrix.
372
+
373
+ If avoid_type: avoid multiple 'mpf's.
374
+ """
375
+ # XXX: should be something like self.ctx._types
376
+ typ = self.ctx.mpf
377
+ s = '['
378
+ for i in xrange(self.__rows):
379
+ s += '['
380
+ for j in xrange(self.__cols):
381
+ if not avoid_type or not isinstance(self[i,j], typ):
382
+ a = repr(self[i,j])
383
+ else:
384
+ a = "'" + str(self[i,j]) + "'"
385
+ s += a + ', '
386
+ s = s[:-2]
387
+ s += '],\n '
388
+ s = s[:-3]
389
+ s += ']'
390
+ return s
391
+
392
+ def tolist(self):
393
+ """
394
+ Convert the matrix to a nested list.
395
+ """
396
+ return [[self[i,j] for j in range(self.__cols)] for i in range(self.__rows)]
397
+
398
    def __repr__(self):
        # In pretty mode show the aligned table; otherwise an eval-style
        # 'matrix(...)' form built from the nested-list string.
        if self.ctx.pretty:
            return self.__str__()
        s = 'matrix(\n'
        s += self._toliststr(avoid_type=True) + ')'
        return s
404
+
405
+ def __get_element(self, key):
406
+ '''
407
+ Fast extraction of the i,j element from the matrix
408
+ This function is for private use only because is unsafe:
409
+ 1. Does not check on the value of key it expects key to be a integer tuple (i,j)
410
+ 2. Does not check bounds
411
+ '''
412
+ if key in self.__data:
413
+ return self.__data[key]
414
+ else:
415
+ return self.ctx.zero
416
+
417
    def __set_element(self, key, value):
        '''
        Fast assignment of the i,j element in the matrix
        This function is unsafe:
        1. Does not check on the value of key it expects key to be a integer tuple (i,j)
        2. Does not check bounds
        3. Does not check the value type
        4. Does not reset the LU cache
        '''
        if value: # only store non-zeros
            self.__data[key] = value
        elif key in self.__data:
            # storing a zero means removing the entry from the sparse dict
            del self.__data[key]
430
+
431
+
432
    def __getitem__(self, key):
        '''
        Getitem function for mp matrix class with slice index enabled
        it allows the following assignments
        scalar to a slice of the matrix
        B = A[:,2:6]
        '''
        # Convert vector to matrix indexing
        if isinstance(key, int) or isinstance(key,slice):
            # only sufficient for vectors
            if self.__rows == 1:
                key = (0, key)
            elif self.__cols == 1:
                key = (key, 0)
            else:
                raise IndexError('insufficient indices for matrix')

        if isinstance(key[0],slice) or isinstance(key[1],slice):

            # Rows
            if isinstance(key[0],slice):
                # Check bounds
                if (key[0].start is None or key[0].start >= 0) and \
                    (key[0].stop is None or key[0].stop <= self.__rows+1):
                    # Generate indices
                    rows = xrange(*key[0].indices(self.__rows))
                else:
                    raise IndexError('Row index out of bounds')
            else:
                # Single row
                rows = [key[0]]

            # Columns
            if isinstance(key[1],slice):
                # Check bounds
                if (key[1].start is None or key[1].start >= 0) and \
                    (key[1].stop is None or key[1].stop <= self.__cols+1):
                    # Generate indices
                    columns = xrange(*key[1].indices(self.__cols))
                else:
                    raise IndexError('Column index out of bounds')

            else:
                # Single column
                columns = [key[1]]

            # Create matrix slice
            m = self.ctx.matrix(len(rows),len(columns))

            # Assign elements to the output matrix
            # (__set_element/__get_element skip conversion and bounds checks —
            # safe here because the source values are already converted)
            for i,x in enumerate(rows):
                for j,y in enumerate(columns):
                    m.__set_element((i,j),self.__get_element((x,y)))

            return m

        else:
            # single element extraction
            if key[0] >= self.__rows or key[1] >= self.__cols:
                raise IndexError('matrix index out of range')
            if key in self.__data:
                return self.__data[key]
            else:
                # missing entries of the sparse dict are zeros
                return self.ctx.zero
496
+
497
    def __setitem__(self, key, value):
        # setitem function for mp matrix class with slice index enabled
        # it allows the following assignments
        # scalar to a slice of the matrix
        # A[:,2:6] = 2.5
        # submatrix to matrix (the value matrix should be the same size as the slice size)
        # A[3,:] = B where A is n x m and B is n x 1
        # Convert vector to matrix indexing
        if isinstance(key, int) or isinstance(key,slice):
            # only sufficient for vectors
            if self.__rows == 1:
                key = (0, key)
            elif self.__cols == 1:
                key = (key, 0)
            else:
                raise IndexError('insufficient indices for matrix')
        # Slice indexing
        if isinstance(key[0],slice) or isinstance(key[1],slice):
            # Rows
            if isinstance(key[0],slice):
                # Check bounds
                if (key[0].start is None or key[0].start >= 0) and \
                    (key[0].stop is None or key[0].stop <= self.__rows+1):
                    # generate row indices
                    rows = xrange(*key[0].indices(self.__rows))
                else:
                    raise IndexError('Row index out of bounds')
            else:
                # Single row
                rows = [key[0]]
            # Columns
            if isinstance(key[1],slice):
                # Check bounds
                if (key[1].start is None or key[1].start >= 0) and \
                    (key[1].stop is None or key[1].stop <= self.__cols+1):
                    # Generate column indices
                    columns = xrange(*key[1].indices(self.__cols))
                else:
                    raise IndexError('Column index out of bounds')
            else:
                # Single column
                columns = [key[1]]
            # Assign slice with a scalar
            if isinstance(value,self.ctx.matrix):
                # Assign elements to matrix if input and output dimensions match
                if len(rows) == value.rows and len(columns) == value.cols:
                    for i,x in enumerate(rows):
                        for j,y in enumerate(columns):
                            # fast path: source entries are already converted
                            self.__set_element((x,y), value.__get_element((i,j)))
                else:
                    raise ValueError('Dimensions do not match')
            else:
                # Assign slice with scalars
                value = self.ctx.convert(value)
                for i in rows:
                    for j in columns:
                        self.__set_element((i,j), value)
        else:
            # Single element assignment
            # Check bounds
            if key[0] >= self.__rows or key[1] >= self.__cols:
                raise IndexError('matrix index out of range')
            # Convert and store value
            value = self.ctx.convert(value)
            if value: # only store non-zeros
                self.__data[key] = value
            elif key in self.__data:
                del self.__data[key]

        # any write invalidates the cached LU decomposition
        if self._LU:
            self._LU = None
        return
569
+
570
+ def __iter__(self):
571
+ for i in xrange(self.__rows):
572
+ for j in xrange(self.__cols):
573
+ yield self[i,j]
574
+
575
    def __mul__(self, other):
        """
        Matrix-matrix product if `other` is a matrix; otherwise
        elementwise scalar multiplication.
        """
        if isinstance(other, self.ctx.matrix):
            # dot multiplication
            if self.__cols != other.__rows:
                raise ValueError('dimensions not compatible for multiplication')
            new = self.ctx.matrix(self.__rows, other.__cols)
            # hoist dict.get / zero lookups out of the inner loop for speed;
            # fdot computes the inner product of the (value, value) pairs
            self_zero = self.ctx.zero
            self_get = self.__data.get
            other_zero = other.ctx.zero
            other_get = other.__data.get
            for i in xrange(self.__rows):
                for j in xrange(other.__cols):
                    new[i, j] = self.ctx.fdot((self_get((i,k), self_zero), other_get((k,j), other_zero))
                                              for k in xrange(other.__rows))
            return new
        else:
            # try scalar multiplication
            new = self.ctx.matrix(self.__rows, self.__cols)
            for i in xrange(self.__rows):
                for j in xrange(self.__cols):
                    new[i, j] = other * self[i, j]
            return new
597
+
598
    def __matmul__(self, other):
        # `A @ B` (PEP 465) behaves exactly like `A * B`.
        return self.__mul__(other)
600
+
601
    def __rmul__(self, other):
        # assume other is scalar and thus commutative
        # (matrix * matrix never reaches __rmul__: __mul__ handles it first)
        if isinstance(other, self.ctx.matrix):
            raise TypeError("other should not be type of ctx.matrix")
        return self.__mul__(other)
606
+
607
    def __pow__(self, other):
        """
        Integer power of a square matrix by binary exponentiation;
        a negative exponent inverts the result.
        """
        # avoid cyclic import problems
        #from linalg import inverse
        if not isinstance(other, int):
            raise ValueError('only integer exponents are supported')
        if not self.__rows == self.__cols:
            raise ValueError('only powers of square matrices are defined')
        n = other
        if n == 0:
            # A**0 is the identity
            return self.ctx.eye(self.__rows)
        if n < 0:
            n = -n
            neg = True
        else:
            neg = False
        i = n
        # square-and-multiply: y accumulates the result, z holds the
        # successive squarings (1 * matrix dispatches to __rmul__)
        y = 1
        z = self.copy()
        while i != 0:
            if i % 2 == 1:
                y = y * z
            z = z*z
            i = i // 2
        if neg:
            y = self.ctx.inverse(y)
        return y
633
+
634
    def __div__(self, other):
        # assume other is scalar and do element-wise divison
        # NOTE(review): this assert vanishes under `python -O`; callers would
        # then get an elementwise TypeError instead — consider raising explicitly.
        assert not isinstance(other, self.ctx.matrix)
        new = self.ctx.matrix(self.__rows, self.__cols)
        for i in xrange(self.__rows):
            for j in xrange(self.__cols):
                new[i,j] = self[i,j] / other
        return new

    # Python 3 division operator maps to the same implementation
    __truediv__ = __div__
644
+
645
+ def __add__(self, other):
646
+ if isinstance(other, self.ctx.matrix):
647
+ if not (self.__rows == other.__rows and self.__cols == other.__cols):
648
+ raise ValueError('incompatible dimensions for addition')
649
+ new = self.ctx.matrix(self.__rows, self.__cols)
650
+ for i in xrange(self.__rows):
651
+ for j in xrange(self.__cols):
652
+ new[i,j] = self[i,j] + other[i,j]
653
+ return new
654
+ else:
655
+ # assume other is scalar and add element-wise
656
+ new = self.ctx.matrix(self.__rows, self.__cols)
657
+ for i in xrange(self.__rows):
658
+ for j in xrange(self.__cols):
659
+ new[i,j] += self[i,j] + other
660
+ return new
661
+
662
    def __radd__(self, other):
        # scalar + matrix: addition is commutative elementwise
        return self.__add__(other)
664
+
665
    def __sub__(self, other):
        # dimension check up front gives a clearer error than the
        # generic one raised later by __add__
        if isinstance(other, self.ctx.matrix) and not (self.__rows == other.__rows
            and self.__cols == other.__cols):
            raise ValueError('incompatible dimensions for subtraction')
        return self.__add__(other * (-1))
670
+
671
    def __pos__(self):
        """
        +M returns a copy of M, rounded to current working precision.
        """
        return (+1) * self
676
+
677
    def __neg__(self):
        # elementwise negation via scalar multiplication
        return (-1) * self
679
+
680
    def __rsub__(self, other):
        # other - self, with `other` a scalar (or anything __add__ accepts)
        return -self + other
682
+
683
+ def __eq__(self, other):
684
+ return self.__rows == other.__rows and self.__cols == other.__cols \
685
+ and self.__data == other.__data
686
+
687
+ def __len__(self):
688
+ if self.rows == 1:
689
+ return self.cols
690
+ elif self.cols == 1:
691
+ return self.rows
692
+ else:
693
+ return self.rows # do it like numpy
694
+
695
    def __getrows(self):
        return self.__rows

    def __setrows(self, value):
        # shrinking the matrix drops stored entries beyond the new row count
        for key in self.__data.copy():
            if key[0] >= value:
                del self.__data[key]
        self.__rows = value

    rows = property(__getrows, __setrows, doc='number of rows')
705
+
706
    def __getcols(self):
        return self.__cols

    def __setcols(self, value):
        # shrinking the matrix drops stored entries beyond the new column count
        for key in self.__data.copy():
            if key[1] >= value:
                del self.__data[key]
        self.__cols = value

    cols = property(__getcols, __setcols, doc='number of columns')
716
+
717
+ def transpose(self):
718
+ new = self.ctx.matrix(self.__cols, self.__rows)
719
+ for i in xrange(self.__rows):
720
+ for j in xrange(self.__cols):
721
+ new[j,i] = self[i,j]
722
+ return new
723
+
724
+ T = property(transpose)
725
+
726
    def conjugate(self):
        # elementwise complex conjugate
        return self.apply(self.ctx.conj)
728
+
729
    def transpose_conj(self):
        # conjugate transpose (Hermitian adjoint)
        return self.conjugate().transpose()

    H = property(transpose_conj)
733
+
734
    def copy(self):
        # copying the backing dict is sufficient: assignments replace
        # entries rather than mutating them in place
        new = self.ctx.matrix(self.__rows, self.__cols)
        new.__data = self.__data.copy()
        return new

    __copy__ = copy
740
+
741
+ def column(self, n):
742
+ m = self.ctx.matrix(self.rows, 1)
743
+ for i in range(self.rows):
744
+ m[i] = self[i,n]
745
+ return m
746
+
747
+ class MatrixMethods(object):
748
+
749
    def __init__(ctx):
        # XXX: subclass
        # Each context gets its own `matrix` subclass bound to that context,
        # so elements are converted with that context's `convert`.
        ctx.matrix = type('matrix', (_matrix,), {})
        ctx.matrix.ctx = ctx
        ctx.matrix.convert = ctx.convert
754
+
755
+ def eye(ctx, n, **kwargs):
756
+ """
757
+ Create square identity matrix n x n.
758
+ """
759
+ A = ctx.matrix(n, **kwargs)
760
+ for i in xrange(n):
761
+ A[i,i] = 1
762
+ return A
763
+
764
+ def diag(ctx, diagonal, **kwargs):
765
+ """
766
+ Create square diagonal matrix using given list.
767
+
768
+ Example:
769
+ >>> from mpmath import diag, mp
770
+ >>> mp.pretty = False
771
+ >>> diag([1, 2, 3])
772
+ matrix(
773
+ [['1.0', '0.0', '0.0'],
774
+ ['0.0', '2.0', '0.0'],
775
+ ['0.0', '0.0', '3.0']])
776
+ """
777
+ A = ctx.matrix(len(diagonal), **kwargs)
778
+ for i in xrange(len(diagonal)):
779
+ A[i,i] = diagonal[i]
780
+ return A
781
+
782
+ def zeros(ctx, *args, **kwargs):
783
+ """
784
+ Create matrix m x n filled with zeros.
785
+ One given dimension will create square matrix n x n.
786
+
787
+ Example:
788
+ >>> from mpmath import zeros, mp
789
+ >>> mp.pretty = False
790
+ >>> zeros(2)
791
+ matrix(
792
+ [['0.0', '0.0'],
793
+ ['0.0', '0.0']])
794
+ """
795
+ if len(args) == 1:
796
+ m = n = args[0]
797
+ elif len(args) == 2:
798
+ m = args[0]
799
+ n = args[1]
800
+ else:
801
+ raise TypeError('zeros expected at most 2 arguments, got %i' % len(args))
802
+ A = ctx.matrix(m, n, **kwargs)
803
+ for i in xrange(m):
804
+ for j in xrange(n):
805
+ A[i,j] = 0
806
+ return A
807
+
808
+ def ones(ctx, *args, **kwargs):
809
+ """
810
+ Create matrix m x n filled with ones.
811
+ One given dimension will create square matrix n x n.
812
+
813
+ Example:
814
+ >>> from mpmath import ones, mp
815
+ >>> mp.pretty = False
816
+ >>> ones(2)
817
+ matrix(
818
+ [['1.0', '1.0'],
819
+ ['1.0', '1.0']])
820
+ """
821
+ if len(args) == 1:
822
+ m = n = args[0]
823
+ elif len(args) == 2:
824
+ m = args[0]
825
+ n = args[1]
826
+ else:
827
+ raise TypeError('ones expected at most 2 arguments, got %i' % len(args))
828
+ A = ctx.matrix(m, n, **kwargs)
829
+ for i in xrange(m):
830
+ for j in xrange(n):
831
+ A[i,j] = 1
832
+ return A
833
+
834
+ def hilbert(ctx, m, n=None):
835
+ """
836
+ Create (pseudo) hilbert matrix m x n.
837
+ One given dimension will create hilbert matrix n x n.
838
+
839
+ The matrix is very ill-conditioned and symmetric, positive definite if
840
+ square.
841
+ """
842
+ if n is None:
843
+ n = m
844
+ A = ctx.matrix(m, n)
845
+ for i in xrange(m):
846
+ for j in xrange(n):
847
+ A[i,j] = ctx.one / (i + j + 1)
848
+ return A
849
+
850
    def randmatrix(ctx, m, n=None, min=0, max=1, **kwargs):
        """
        Create a random m x n matrix.

        All values are >= min and <max.
        n defaults to m.

        Note: the `min`/`max` parameter names shadow the builtins, but they
        are part of the public keyword interface and cannot be renamed.

        Example:
        >>> from mpmath import randmatrix
        >>> randmatrix(2) # doctest:+SKIP
        matrix(
        [['0.53491598236191806', '0.57195669543302752'],
         ['0.85589992269513615', '0.82444367501382143']])
        """
        if not n:
            n = m
        A = ctx.matrix(m, n, **kwargs)
        for i in xrange(m):
            for j in xrange(n):
                # uniform on [min, max): scale ctx.rand()'s [0, 1) output
                A[i,j] = ctx.rand() * (max - min) + min
        return A
871
+
872
+ def swap_row(ctx, A, i, j):
873
+ """
874
+ Swap row i with row j.
875
+ """
876
+ if i == j:
877
+ return
878
+ if isinstance(A, ctx.matrix):
879
+ for k in xrange(A.cols):
880
+ A[i,k], A[j,k] = A[j,k], A[i,k]
881
+ elif isinstance(A, list):
882
+ A[i], A[j] = A[j], A[i]
883
+ else:
884
+ raise TypeError('could not interpret type')
885
+
886
    def extend(ctx, A, b):
        """
        Extend matrix A with column b and return result.

        `A` is copied, so the input matrix is left untouched.
        """
        if not isinstance(A, ctx.matrix):
            raise TypeError("A should be a type of ctx.matrix")
        if A.rows != len(b):
            # i.e. b must have one entry per row of A
            raise ValueError("Value should be equal to len(b)")
        A = A.copy()
        # growing `cols` leaves the new column zero; fill it from b
        A.cols += 1
        for i in xrange(A.rows):
            A[i, A.cols-1] = b[i]
        return A
899
+
900
    def norm(ctx, x, p=2):
        r"""
        Gives the entrywise `p`-norm of an iterable *x*, i.e. the vector norm
        `\left(\sum_k |x_k|^p\right)^{1/p}`, for any given `1 \le p \le \infty`.

        Special cases:

        If *x* is not iterable, this just returns ``absmax(x)``.

        ``p=1`` gives the sum of absolute values.

        ``p=2`` is the standard Euclidean vector norm.

        ``p=inf`` gives the magnitude of the largest element.

        For *x* a matrix, ``p=2`` is the Frobenius norm.
        For operator matrix norms, use :func:`~mpmath.mnorm` instead.

        You can use the string 'inf' as well as float('inf') or mpf('inf')
        to specify the infinity norm.

        **Examples**

        >>> from mpmath import *
        >>> mp.dps = 15; mp.pretty = False
        >>> x = matrix([-10, 2, 100])
        >>> norm(x, 1)
        mpf('112.0')
        >>> norm(x, 2)
        mpf('100.5186549850325')
        >>> norm(x, inf)
        mpf('100.0')

        """
        # scalar fallback: non-iterables get their absolute magnitude
        try:
            iter(x)
        except TypeError:
            return ctx.absmax(x)
        # normalize p ('inf' strings, floats) to a context number,
        # keeping exact ints as-is for the fast comparisons below
        if type(p) is not int:
            p = ctx.convert(p)
        if p == ctx.inf:
            return max(ctx.absmax(i) for i in x)
        elif p == 1:
            return ctx.fsum(x, absolute=1)
        elif p == 2:
            # fsum(..., absolute=1, squared=1) sums |x_k|^2 in one pass
            return ctx.sqrt(ctx.fsum(x, absolute=1, squared=1))
        elif p > 1:
            # general p-norm
            return ctx.nthroot(ctx.fsum(abs(i)**p for i in x), p)
        else:
            raise ValueError('p has to be >= 1')
950
+
951
    def mnorm(ctx, A, p=1):
        r"""
        Gives the matrix (operator) `p`-norm of A. Currently ``p=1`` and ``p=inf``
        are supported:

        ``p=1`` gives the 1-norm (maximal column sum)

        ``p=inf`` gives the `\infty`-norm (maximal row sum).
        You can use the string 'inf' as well as float('inf') or mpf('inf')

        ``p=2`` (not implemented) for a square matrix is the usual spectral
        matrix norm, i.e. the largest singular value.

        ``p='f'`` (or 'F', 'fro', 'Frobenius, 'frobenius') gives the
        Frobenius norm, which is the elementwise 2-norm. The Frobenius norm is an
        approximation of the spectral norm and satisfies

        .. math ::

            \frac{1}{\sqrt{\mathrm{rank}(A)}} \|A\|_F \le \|A\|_2 \le \|A\|_F

        The Frobenius norm lacks some mathematical properties that might
        be expected of a norm.

        For general elementwise `p`-norms, use :func:`~mpmath.norm` instead.

        **Examples**

        >>> from mpmath import *
        >>> mp.dps = 15; mp.pretty = False
        >>> A = matrix([[1, -1000], [100, 50]])
        >>> mnorm(A, 1)
        mpf('1050.0')
        >>> mnorm(A, inf)
        mpf('1001.0')
        >>> mnorm(A, 'F')
        mpf('1006.2310867787777')

        """
        A = ctx.matrix(A)
        if type(p) is not int:
            # any prefix of 'frobenius' (case-insensitive) selects the
            # Frobenius norm, which is the elementwise 2-norm
            if type(p) is str and 'frobenius'.startswith(p.lower()):
                return ctx.norm(A, 2)
            p = ctx.convert(p)
        m, n = A.rows, A.cols
        if p == 1:
            # maximal column sum of absolute values
            return max(ctx.fsum((A[i,j] for i in xrange(m)), absolute=1) for j in xrange(n))
        elif p == ctx.inf:
            # maximal row sum of absolute values
            return max(ctx.fsum((A[i,j] for j in xrange(n)), absolute=1) for i in xrange(m))
        else:
            raise NotImplementedError("matrix p-norm for arbitrary p")
1002
+
1003
# Run the doctests embedded in the docstrings above when executed directly.
if __name__ == '__main__':
    import doctest
    doctest.testmod()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (214 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (222 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasLt.h ADDED
@@ -0,0 +1,1853 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 1993-2022 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+ #pragma once
50
+
51
+ #ifndef CUBLASAPI
52
+ #ifdef __CUDACC__
53
+ #define CUBLASAPI __host__ __device__
54
+ #else
55
+ #define CUBLASAPI
56
+ #endif
57
+ #endif
58
+
59
+ #include <cublas_api.h>
60
+
61
+ #include <stdint.h>
62
+ #include <stddef.h>
63
+ #include <stdio.h>
64
+
65
+ #if defined(__cplusplus)
66
+ extern "C" {
67
+ #endif /* __cplusplus */
68
+
69
+ /** Opaque structure holding CUBLASLT context
70
+ */
71
+ typedef struct cublasLtContext* cublasLtHandle_t;
72
+
73
+ cublasStatus_t CUBLASWINAPI cublasLtCreate(cublasLtHandle_t* lightHandle);
74
+
75
+ cublasStatus_t CUBLASWINAPI cublasLtDestroy(cublasLtHandle_t lightHandle);
76
+
77
+ const char* CUBLASWINAPI cublasLtGetStatusName(cublasStatus_t status);
78
+
79
+ const char* CUBLASWINAPI cublasLtGetStatusString(cublasStatus_t status);
80
+
81
+ size_t CUBLASWINAPI cublasLtGetVersion(void);
82
+
83
+ size_t CUBLASWINAPI cublasLtGetCudartVersion(void);
84
+
85
+ cublasStatus_t CUBLASWINAPI cublasLtGetProperty(libraryPropertyType type, int* value);
86
+
87
+ cublasStatus_t CUBLASWINAPI cublasLtHeuristicsCacheGetCapacity(size_t* capacity);
88
+ cublasStatus_t CUBLASWINAPI cublasLtHeuristicsCacheSetCapacity(size_t capacity);
89
+
90
+ /** Semi-opaque descriptor for matrix memory layout
91
+ */
92
+ typedef struct {
93
+ uint64_t data[8];
94
+ } cublasLtMatrixLayoutOpaque_t;
95
+
96
+ /** Opaque descriptor for matrix memory layout
97
+ */
98
+ typedef cublasLtMatrixLayoutOpaque_t* cublasLtMatrixLayout_t;
99
+
100
+ /** Semi-opaque algorithm descriptor (to avoid complicated alloc/free schemes)
101
+ *
102
+ * This structure can be trivially serialized and later restored for use with the same version of cuBLAS library to save
103
+ * on selecting the right configuration again.
104
+ */
105
+ typedef struct {
106
+ uint64_t data[8];
107
+ } cublasLtMatmulAlgo_t;
108
+
109
+ /** Semi-opaque descriptor for cublasLtMatmul() operation details
110
+ */
111
+ typedef struct {
112
+ uint64_t data[23];
113
+ } cublasLtMatmulDescOpaque_t;
114
+
115
+ /** Opaque descriptor for cublasLtMatmul() operation details
116
+ */
117
+ typedef cublasLtMatmulDescOpaque_t* cublasLtMatmulDesc_t;
118
+
119
+ /** Semi-opaque descriptor for cublasLtMatrixTransform() operation details
120
+ */
121
+ typedef struct {
122
+ uint64_t data[8];
123
+ } cublasLtMatrixTransformDescOpaque_t;
124
+
125
+ /** Opaque descriptor for cublasLtMatrixTransform() operation details
126
+ */
127
+ typedef cublasLtMatrixTransformDescOpaque_t* cublasLtMatrixTransformDesc_t;
128
+
129
+ /** Semi-opaque descriptor for cublasLtMatmulPreference() operation details
130
+ */
131
+ typedef struct {
132
+ uint64_t data[10];
133
+ } cublasLtMatmulPreferenceOpaque_t;
134
+
135
+ /** Opaque descriptor for cublasLtMatmulAlgoGetHeuristic() configuration
136
+ */
137
+ typedef cublasLtMatmulPreferenceOpaque_t* cublasLtMatmulPreference_t;
138
+
139
+ /** Tile size (in C/D matrix Rows x Cols)
140
+ *
141
+ * General order of tile IDs is sorted by size first and by first dimension second.
142
+ */
143
+ typedef enum {
144
+ CUBLASLT_MATMUL_TILE_UNDEFINED = 0,
145
+ CUBLASLT_MATMUL_TILE_8x8 = 1,
146
+ CUBLASLT_MATMUL_TILE_8x16 = 2,
147
+ CUBLASLT_MATMUL_TILE_16x8 = 3,
148
+ CUBLASLT_MATMUL_TILE_8x32 = 4,
149
+ CUBLASLT_MATMUL_TILE_16x16 = 5,
150
+ CUBLASLT_MATMUL_TILE_32x8 = 6,
151
+ CUBLASLT_MATMUL_TILE_8x64 = 7,
152
+ CUBLASLT_MATMUL_TILE_16x32 = 8,
153
+ CUBLASLT_MATMUL_TILE_32x16 = 9,
154
+ CUBLASLT_MATMUL_TILE_64x8 = 10,
155
+ CUBLASLT_MATMUL_TILE_32x32 = 11,
156
+ CUBLASLT_MATMUL_TILE_32x64 = 12,
157
+ CUBLASLT_MATMUL_TILE_64x32 = 13,
158
+ CUBLASLT_MATMUL_TILE_32x128 = 14,
159
+ CUBLASLT_MATMUL_TILE_64x64 = 15,
160
+ CUBLASLT_MATMUL_TILE_128x32 = 16,
161
+ CUBLASLT_MATMUL_TILE_64x128 = 17,
162
+ CUBLASLT_MATMUL_TILE_128x64 = 18,
163
+ CUBLASLT_MATMUL_TILE_64x256 = 19,
164
+ CUBLASLT_MATMUL_TILE_128x128 = 20,
165
+ CUBLASLT_MATMUL_TILE_256x64 = 21,
166
+ CUBLASLT_MATMUL_TILE_64x512 = 22,
167
+ CUBLASLT_MATMUL_TILE_128x256 = 23,
168
+ CUBLASLT_MATMUL_TILE_256x128 = 24,
169
+ CUBLASLT_MATMUL_TILE_512x64 = 25,
170
+ CUBLASLT_MATMUL_TILE_64x96 = 26,
171
+ CUBLASLT_MATMUL_TILE_96x64 = 27,
172
+ CUBLASLT_MATMUL_TILE_96x128 = 28,
173
+ CUBLASLT_MATMUL_TILE_128x160 = 29,
174
+ CUBLASLT_MATMUL_TILE_160x128 = 30,
175
+ CUBLASLT_MATMUL_TILE_192x128 = 31,
176
+ CUBLASLT_MATMUL_TILE_128x192 = 32,
177
+ CUBLASLT_MATMUL_TILE_128x96 = 33,
178
+ CUBLASLT_MATMUL_TILE_END
179
+ } cublasLtMatmulTile_t;
180
+
181
+ /** Size and number of stages in which elements are read into shared memory
182
+ *
183
+ * General order of stages IDs is sorted by stage size first and by number of stages second.
184
+ */
185
+ typedef enum {
186
+ CUBLASLT_MATMUL_STAGES_UNDEFINED = 0,
187
+ CUBLASLT_MATMUL_STAGES_16x1 = 1,
188
+ CUBLASLT_MATMUL_STAGES_16x2 = 2,
189
+ CUBLASLT_MATMUL_STAGES_16x3 = 3,
190
+ CUBLASLT_MATMUL_STAGES_16x4 = 4,
191
+ CUBLASLT_MATMUL_STAGES_16x5 = 5,
192
+ CUBLASLT_MATMUL_STAGES_16x6 = 6,
193
+ CUBLASLT_MATMUL_STAGES_32x1 = 7,
194
+ CUBLASLT_MATMUL_STAGES_32x2 = 8,
195
+ CUBLASLT_MATMUL_STAGES_32x3 = 9,
196
+ CUBLASLT_MATMUL_STAGES_32x4 = 10,
197
+ CUBLASLT_MATMUL_STAGES_32x5 = 11,
198
+ CUBLASLT_MATMUL_STAGES_32x6 = 12,
199
+ CUBLASLT_MATMUL_STAGES_64x1 = 13,
200
+ CUBLASLT_MATMUL_STAGES_64x2 = 14,
201
+ CUBLASLT_MATMUL_STAGES_64x3 = 15,
202
+ CUBLASLT_MATMUL_STAGES_64x4 = 16,
203
+ CUBLASLT_MATMUL_STAGES_64x5 = 17,
204
+ CUBLASLT_MATMUL_STAGES_64x6 = 18,
205
+ CUBLASLT_MATMUL_STAGES_128x1 = 19,
206
+ CUBLASLT_MATMUL_STAGES_128x2 = 20,
207
+ CUBLASLT_MATMUL_STAGES_128x3 = 21,
208
+ CUBLASLT_MATMUL_STAGES_128x4 = 22,
209
+ CUBLASLT_MATMUL_STAGES_128x5 = 23,
210
+ CUBLASLT_MATMUL_STAGES_128x6 = 24,
211
+ CUBLASLT_MATMUL_STAGES_32x10 = 25,
212
+ CUBLASLT_MATMUL_STAGES_8x4 = 26,
213
+ CUBLASLT_MATMUL_STAGES_16x10 = 27,
214
+ CUBLASLT_MATMUL_STAGES_8x5 = 28,
215
+ CUBLASLT_MATMUL_STAGES_16x80 = 29,
216
+ CUBLASLT_MATMUL_STAGES_64x80 = 30,
217
+ CUBLASLT_MATMUL_STAGES_8x3 = 31,
218
+ CUBLASLT_MATMUL_STAGES_8xAUTO = 32,
219
+ CUBLASLT_MATMUL_STAGES_16xAUTO = 33,
220
+ CUBLASLT_MATMUL_STAGES_32xAUTO = 34,
221
+ CUBLASLT_MATMUL_STAGES_64xAUTO = 35,
222
+ CUBLASLT_MATMUL_STAGES_128xAUTO = 36,
223
+ CUBLASLT_MATMUL_STAGES_END
224
+ } cublasLtMatmulStages_t;
225
+
226
+ /** Thread Block Cluster size
227
+ *
228
+ * Typically dimensioned similar to cublasLtMatmulTile_t, with the third coordinate unused at this time.
229
+ */
230
+ typedef enum {
231
+ /** Let library pick cluster shape automatically */
232
+ CUBLASLT_CLUSTER_SHAPE_AUTO = 0,
233
+ CUBLASLT_CLUSTER_SHAPE_1x1x1 = 2,
234
+ CUBLASLT_CLUSTER_SHAPE_2x1x1 = 3,
235
+ CUBLASLT_CLUSTER_SHAPE_4x1x1 = 4,
236
+ CUBLASLT_CLUSTER_SHAPE_1x2x1 = 5,
237
+ CUBLASLT_CLUSTER_SHAPE_2x2x1 = 6,
238
+ CUBLASLT_CLUSTER_SHAPE_4x2x1 = 7,
239
+ CUBLASLT_CLUSTER_SHAPE_1x4x1 = 8,
240
+ CUBLASLT_CLUSTER_SHAPE_2x4x1 = 9,
241
+ CUBLASLT_CLUSTER_SHAPE_4x4x1 = 10,
242
+ CUBLASLT_CLUSTER_SHAPE_8x1x1 = 11,
243
+ CUBLASLT_CLUSTER_SHAPE_1x8x1 = 12,
244
+ CUBLASLT_CLUSTER_SHAPE_8x2x1 = 13,
245
+ CUBLASLT_CLUSTER_SHAPE_2x8x1 = 14,
246
+ CUBLASLT_CLUSTER_SHAPE_16x1x1 = 15,
247
+ CUBLASLT_CLUSTER_SHAPE_1x16x1 = 16,
248
+ CUBLASLT_CLUSTER_SHAPE_3x1x1 = 17,
249
+ CUBLASLT_CLUSTER_SHAPE_5x1x1 = 18,
250
+ CUBLASLT_CLUSTER_SHAPE_6x1x1 = 19,
251
+ CUBLASLT_CLUSTER_SHAPE_7x1x1 = 20,
252
+ CUBLASLT_CLUSTER_SHAPE_9x1x1 = 21,
253
+ CUBLASLT_CLUSTER_SHAPE_10x1x1 = 22,
254
+ CUBLASLT_CLUSTER_SHAPE_11x1x1 = 23,
255
+ CUBLASLT_CLUSTER_SHAPE_12x1x1 = 24,
256
+ CUBLASLT_CLUSTER_SHAPE_13x1x1 = 25,
257
+ CUBLASLT_CLUSTER_SHAPE_14x1x1 = 26,
258
+ CUBLASLT_CLUSTER_SHAPE_15x1x1 = 27,
259
+ CUBLASLT_CLUSTER_SHAPE_3x2x1 = 28,
260
+ CUBLASLT_CLUSTER_SHAPE_5x2x1 = 29,
261
+ CUBLASLT_CLUSTER_SHAPE_6x2x1 = 30,
262
+ CUBLASLT_CLUSTER_SHAPE_7x2x1 = 31,
263
+ CUBLASLT_CLUSTER_SHAPE_1x3x1 = 32,
264
+ CUBLASLT_CLUSTER_SHAPE_2x3x1 = 33,
265
+ CUBLASLT_CLUSTER_SHAPE_3x3x1 = 34,
266
+ CUBLASLT_CLUSTER_SHAPE_4x3x1 = 35,
267
+ CUBLASLT_CLUSTER_SHAPE_5x3x1 = 36,
268
+ CUBLASLT_CLUSTER_SHAPE_3x4x1 = 37,
269
+ CUBLASLT_CLUSTER_SHAPE_1x5x1 = 38,
270
+ CUBLASLT_CLUSTER_SHAPE_2x5x1 = 39,
271
+ CUBLASLT_CLUSTER_SHAPE_3x5x1 = 40,
272
+ CUBLASLT_CLUSTER_SHAPE_1x6x1 = 41,
273
+ CUBLASLT_CLUSTER_SHAPE_2x6x1 = 42,
274
+ CUBLASLT_CLUSTER_SHAPE_1x7x1 = 43,
275
+ CUBLASLT_CLUSTER_SHAPE_2x7x1 = 44,
276
+ CUBLASLT_CLUSTER_SHAPE_1x9x1 = 45,
277
+ CUBLASLT_CLUSTER_SHAPE_1x10x1 = 46,
278
+ CUBLASLT_CLUSTER_SHAPE_1x11x1 = 47,
279
+ CUBLASLT_CLUSTER_SHAPE_1x12x1 = 48,
280
+ CUBLASLT_CLUSTER_SHAPE_1x13x1 = 49,
281
+ CUBLASLT_CLUSTER_SHAPE_1x14x1 = 50,
282
+ CUBLASLT_CLUSTER_SHAPE_1x15x1 = 51,
283
+ CUBLASLT_CLUSTER_SHAPE_END
284
+ } cublasLtClusterShape_t;
285
+
286
+ /** Inner size of the kernel
287
+ *
288
+ * Represents various aspects of internal kernel design, that don't impact CUDA grid size but may have other more subtle
289
+ * effects.
290
+ *
291
+ */
292
+ typedef enum {
293
+ CUBLASLT_MATMUL_INNER_SHAPE_UNDEFINED = 0,
294
+ CUBLASLT_MATMUL_INNER_SHAPE_MMA884 = 1,
295
+ CUBLASLT_MATMUL_INNER_SHAPE_MMA1684 = 2,
296
+ CUBLASLT_MATMUL_INNER_SHAPE_MMA1688 = 3,
297
+ CUBLASLT_MATMUL_INNER_SHAPE_MMA16816 = 4,
298
+ CUBLASLT_MATMUL_INNER_SHAPE_END
299
+ } cublasLtMatmulInnerShape_t;
300
+
301
+ /** Pointer mode to use for alpha/beta */
302
+ typedef enum {
303
+ /** matches CUBLAS_POINTER_MODE_HOST, pointer targets a single value host memory */
304
+ CUBLASLT_POINTER_MODE_HOST = CUBLAS_POINTER_MODE_HOST,
305
+ /** matches CUBLAS_POINTER_MODE_DEVICE, pointer targets a single value device memory */
306
+ CUBLASLT_POINTER_MODE_DEVICE = CUBLAS_POINTER_MODE_DEVICE,
307
+ /** pointer targets an array in device memory */
308
+ CUBLASLT_POINTER_MODE_DEVICE_VECTOR = 2,
309
+ /** alpha pointer targets an array in device memory, beta is zero. Note:
310
+ CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is not supported, must be 0. */
311
+ CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO = 3,
312
+ /** alpha pointer targets an array in device memory, beta is a single value in host memory. */
313
+ CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST = 4,
314
+ } cublasLtPointerMode_t;
315
+
316
+ /** Mask to define and query pointer mode capability */
317
+ typedef enum {
318
+ /** no initial filtering is performed when querying pointer mode capabilities, will use gemm pointer mode defined in
319
+ operation description **/
320
+ CUBLASLT_POINTER_MODE_MASK_NO_FILTERING = 0,
321
+ /** see CUBLASLT_POINTER_MODE_HOST */
322
+ CUBLASLT_POINTER_MODE_MASK_HOST = 1,
323
+ /** see CUBLASLT_POINTER_MODE_DEVICE */
324
+ CUBLASLT_POINTER_MODE_MASK_DEVICE = 2,
325
+ /** see CUBLASLT_POINTER_MODE_DEVICE_VECTOR */
326
+ CUBLASLT_POINTER_MODE_MASK_DEVICE_VECTOR = 4,
327
+ /** see CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO */
328
+ CUBLASLT_POINTER_MODE_MASK_ALPHA_DEVICE_VECTOR_BETA_ZERO = 8,
329
+ /** see CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST */
330
+ CUBLASLT_POINTER_MODE_MASK_ALPHA_DEVICE_VECTOR_BETA_HOST = 16,
331
+ } cublasLtPointerModeMask_t;
332
+
333
+ /** Implementation details that may affect numerical behavior of algorithms. */
334
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_FMA (0x01ull << 0)
335
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_HMMA (0x02ull << 0)
336
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_IMMA (0x04ull << 0)
337
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_DMMA (0x08ull << 0)
338
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_TENSOR_OP_MASK (0xfeull << 0)
339
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_TYPE_MASK (0xffull << 0)
340
+
341
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_16F (0x01ull << 8)
342
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_32F (0x02ull << 8)
343
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_64F (0x04ull << 8)
344
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_32I (0x08ull << 8)
345
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_TYPE_MASK (0xffull << 8)
346
+
347
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_16F (0x01ull << 16)
348
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_16BF (0x02ull << 16)
349
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_TF32 (0x04ull << 16)
350
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_32F (0x08ull << 16)
351
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_64F (0x10ull << 16)
352
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8I (0x20ull << 16)
353
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E4M3 (0x40ull << 16)
354
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E5M2 (0x80ull << 16)
355
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK (0xffull << 16)
356
+
357
+ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_GAUSSIAN (0x01ull << 32)
358
+ typedef uint64_t cublasLtNumericalImplFlags_t;
359
+
360
+ /** Execute matrix multiplication (D = alpha * op(A) * op(B) + beta * C).
361
+ *
362
+ * \retval CUBLAS_STATUS_NOT_INITIALIZED if cuBLASLt handle has not been initialized
363
+ * \retval CUBLAS_STATUS_INVALID_VALUE if parameters are in conflict or in an impossible configuration; e.g.
364
+ * when workspaceSizeInBytes is less than workspace required by configured
365
+ * algo
366
+ * \retval CUBLAS_STATUS_NOT_SUPPORTED if current implementation on selected device doesn't support configured
367
+ * operation
368
+ * \retval CUBLAS_STATUS_ARCH_MISMATCH if configured operation cannot be run using selected device
369
+ * \retval CUBLAS_STATUS_EXECUTION_FAILED if cuda reported execution error from the device
370
+ * \retval CUBLAS_STATUS_SUCCESS if the operation completed successfully
371
+ */
372
+ cublasStatus_t CUBLASWINAPI cublasLtMatmul(cublasLtHandle_t lightHandle,
373
+ cublasLtMatmulDesc_t computeDesc,
374
+ const void* alpha, /* host or device pointer */
375
+ const void* A,
376
+ cublasLtMatrixLayout_t Adesc,
377
+ const void* B,
378
+ cublasLtMatrixLayout_t Bdesc,
379
+ const void* beta, /* host or device pointer */
380
+ const void* C,
381
+ cublasLtMatrixLayout_t Cdesc,
382
+ void* D,
383
+ cublasLtMatrixLayout_t Ddesc,
384
+ const cublasLtMatmulAlgo_t* algo,
385
+ void* workspace,
386
+ size_t workspaceSizeInBytes,
387
+ cudaStream_t stream);
388
+
389
+ /** Matrix layout conversion helper (C = alpha * op(A) + beta * op(B))
390
+ *
391
+ * Can be used to change memory order of data or to scale and shift the values.
392
+ *
393
+ * \retval CUBLAS_STATUS_NOT_INITIALIZED if cuBLASLt handle has not been initialized
394
+ * \retval CUBLAS_STATUS_INVALID_VALUE if parameters are in conflict or in an impossible configuration; e.g.
395
+ * when A is not NULL, but Adesc is NULL
396
+ * \retval CUBLAS_STATUS_NOT_SUPPORTED if current implementation on selected device doesn't support configured
397
+ * operation
398
+ * \retval CUBLAS_STATUS_ARCH_MISMATCH if configured operation cannot be run using selected device
399
+ * \retval CUBLAS_STATUS_EXECUTION_FAILED if cuda reported execution error from the device
400
+ * \retval CUBLAS_STATUS_SUCCESS if the operation completed successfully
401
+ */
402
+ cublasStatus_t CUBLASWINAPI cublasLtMatrixTransform(cublasLtHandle_t lightHandle,
403
+ cublasLtMatrixTransformDesc_t transformDesc,
404
+ const void* alpha, /* host or device pointer */
405
+ const void* A,
406
+ cublasLtMatrixLayout_t Adesc,
407
+ const void* beta, /* host or device pointer */
408
+ const void* B,
409
+ cublasLtMatrixLayout_t Bdesc,
410
+ void* C,
411
+ cublasLtMatrixLayout_t Cdesc,
412
+ cudaStream_t stream);
413
+
414
+ /* ---------------------------------------------------------------------------------------*/
415
+ /* Helper functions for cublasLtMatrixLayout_t */
416
+ /* ---------------------------------------------------------------------------------------*/
417
+
418
+ /** Enum for data ordering */
419
+ typedef enum {
420
+ /** Column-major
421
+ *
422
+ * Leading dimension is the stride (in elements) to the beginning of next column in memory.
423
+ */
424
+ CUBLASLT_ORDER_COL = 0,
425
+ /** Row major
426
+ *
427
+ * Leading dimension is the stride (in elements) to the beginning of next row in memory.
428
+ */
429
+ CUBLASLT_ORDER_ROW = 1,
430
+ /** Column-major ordered tiles of 32 columns.
431
+ *
432
+ * Leading dimension is the stride (in elements) to the beginning of next group of 32-columns. E.g. if matrix has 33
433
+ * columns and 2 rows, ld must be at least (32) * 2 = 64.
434
+ */
435
+ CUBLASLT_ORDER_COL32 = 2,
436
+ /** Column-major ordered tiles of composite tiles with total 32 columns and 8 rows, tile composed of interleaved
437
+ * inner tiles of 4 columns within 4 even or odd rows in an alternating pattern.
438
+ *
439
+ * Leading dimension is the stride (in elements) to the beginning of the first 32 column x 8 row tile for the next
440
+ * 32-wide group of columns. E.g. if matrix has 33 columns and 1 row, ld must be at least (32 * 8) * 1 = 256.
441
+ */
442
+ CUBLASLT_ORDER_COL4_4R2_8C = 3,
443
+ /** Column-major ordered tiles of composite tiles with total 32 columns ands 32 rows.
444
+ * Element offset within the tile is calculated as (((row%8)/2*4+row/8)*2+row%2)*32+col.
445
+ *
446
+ * Leading dimension is the stride (in elements) to the beginning of the first 32 column x 32 row tile for the next
447
+ * 32-wide group of columns. E.g. if matrix has 33 columns and 1 row, ld must be at least (32*32)*1 = 1024.
448
+ */
449
+ CUBLASLT_ORDER_COL32_2R_4R4 = 4,
450
+
451
+ } cublasLtOrder_t;
452
+
453
+ /** Attributes of memory layout */
454
+ typedef enum {
455
+ /** Data type, see cudaDataType.
456
+ *
457
+ * uint32_t
458
+ */
459
+ CUBLASLT_MATRIX_LAYOUT_TYPE = 0,
460
+
461
+ /** Memory order of the data, see cublasLtOrder_t.
462
+ *
463
+ * int32_t, default: CUBLASLT_ORDER_COL
464
+ */
465
+ CUBLASLT_MATRIX_LAYOUT_ORDER = 1,
466
+
467
+ /** Number of rows.
468
+ *
469
+ * Usually only values that can be expressed as int32_t are supported.
470
+ *
471
+ * uint64_t
472
+ */
473
+ CUBLASLT_MATRIX_LAYOUT_ROWS = 2,
474
+
475
+ /** Number of columns.
476
+ *
477
+ * Usually only values that can be expressed as int32_t are supported.
478
+ *
479
+ * uint64_t
480
+ */
481
+ CUBLASLT_MATRIX_LAYOUT_COLS = 3,
482
+
483
+ /** Matrix leading dimension.
484
+ *
485
+ * For CUBLASLT_ORDER_COL this is stride (in elements) of matrix column, for more details and documentation for
486
+ * other memory orders see documentation for cublasLtOrder_t values.
487
+ *
488
+ * Currently only non-negative values are supported, must be large enough so that matrix memory locations are not
489
+ * overlapping (e.g. greater or equal to CUBLASLT_MATRIX_LAYOUT_ROWS in case of CUBLASLT_ORDER_COL).
490
+ *
491
+ * int64_t;
492
+ */
493
+ CUBLASLT_MATRIX_LAYOUT_LD = 4,
494
+
495
+ /** Number of matmul operations to perform in the batch.
496
+ *
497
+ * See also CUBLASLT_ALGO_CAP_STRIDED_BATCH_SUPPORT
498
+ *
499
+ * int32_t, default: 1
500
+ */
501
+ CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT = 5,
502
+
503
+ /** Stride (in elements) to the next matrix for strided batch operation.
504
+ *
505
+ * When matrix type is planar-complex (CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET != 0), batch stride
506
+ * is interpreted by cublasLtMatmul() in number of real valued sub-elements. E.g. for data of type CUDA_C_16F,
507
+ * offset of 1024B is encoded as a stride of value 512 (since each element of the real and imaginary matrices
508
+ * is a 2B (16bit) floating point type).
509
+ *
510
+ * NOTE: A bug in cublasLtMatrixTransform() causes it to interpret the batch stride for a planar-complex matrix
511
+ * as if it was specified in number of complex elements. Therefore an offset of 1024B must be encoded as stride
512
+ * value 256 when calling cublasLtMatrixTransform() (each complex element is 4B with real and imaginary values 2B
513
+ * each). This behavior is expected to be corrected in the next major cuBLAS version.
514
+ *
515
+ * int64_t, default: 0
516
+ */
517
+ CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET = 6,
518
+
519
+ /** Stride (in bytes) to the imaginary plane for planar complex layout.
520
+ *
521
+ * int64_t, default: 0 - 0 means that layout is regular (real and imaginary parts of complex numbers are interleaved
522
+ * in memory in each element)
523
+ */
524
+ CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET = 7,
525
+ } cublasLtMatrixLayoutAttribute_t;
526
+
527
+ /** Internal. Do not use directly.
528
+ */
529
+ cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutInit_internal( //
530
+ cublasLtMatrixLayout_t matLayout,
531
+ size_t size,
532
+ cudaDataType type,
533
+ uint64_t rows,
534
+ uint64_t cols,
535
+ int64_t ld);
536
+
537
+ /** Initialize matrix layout descriptor in pre-allocated space.
538
+ *
539
+ * \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
540
+ * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
541
+ */
542
+ static inline cublasStatus_t cublasLtMatrixLayoutInit(
543
+ cublasLtMatrixLayout_t matLayout, cudaDataType type, uint64_t rows, uint64_t cols, int64_t ld) {
544
+ return cublasLtMatrixLayoutInit_internal(matLayout, sizeof(*matLayout), type, rows, cols, ld);
545
+ }
546
+
547
+ /** Create new matrix layout descriptor.
548
+ *
549
+ * \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
550
+ * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
551
+ */
552
+ cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutCreate( //
553
+ cublasLtMatrixLayout_t* matLayout,
554
+ cudaDataType type,
555
+ uint64_t rows,
556
+ uint64_t cols,
557
+ int64_t ld);
558
+
559
+ /** Destroy matrix layout descriptor.
560
+ *
561
+ * \retval CUBLAS_STATUS_SUCCESS if operation was successful
562
+ */
563
+ cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutDestroy(cublasLtMatrixLayout_t matLayout);
564
+
565
+ /** Set matrix layout descriptor attribute.
566
+ *
567
+ * \param[in] matLayout The descriptor
568
+ * \param[in] attr The attribute
569
+ * \param[in] buf memory address containing the new value
570
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
571
+ *
572
+ * \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
573
+ * selected attribute
574
+ * \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
575
+ */
576
+ cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutSetAttribute( //
577
+ cublasLtMatrixLayout_t matLayout,
578
+ cublasLtMatrixLayoutAttribute_t attr,
579
+ const void* buf,
580
+ size_t sizeInBytes);
581
+
582
+ /** Get matrix layout descriptor attribute.
583
+ *
584
+ * \param[in] matLayout The descriptor
585
+ * \param[in] attr The attribute
586
+ * \param[out] buf memory address containing the new value
587
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
588
+ * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
589
+ * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
590
+ *
591
+ * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
592
+ * and buf is NULL or sizeInBytes doesn't match size of internal storage for
593
+ * selected attribute
594
+ * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
595
+ */
596
+ cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutGetAttribute( //
597
+ cublasLtMatrixLayout_t matLayout,
598
+ cublasLtMatrixLayoutAttribute_t attr,
599
+ void* buf,
600
+ size_t sizeInBytes,
601
+ size_t* sizeWritten);
602
+
603
+ /* ---------------------------------------------------------------------------------------*/
604
+ /* Helper functions for cublasLtMatmulDesc_t */
605
+ /* ---------------------------------------------------------------------------------------*/
606
+
607
+ /** Matmul descriptor attributes to define details of the operation. */
608
+ typedef enum {
609
+ /** Compute type, see cudaDataType. Defines data type used for multiply and accumulate operations and the
610
+ * accumulator during matrix multiplication.
611
+ *
612
+ * int32_t
613
+ */
614
+ CUBLASLT_MATMUL_DESC_COMPUTE_TYPE = 0,
615
+
616
+ /** Scale type, see cudaDataType. Defines data type of alpha and beta. Accumulator and value from matrix C are
617
+ * typically converted to scale type before final scaling. Value is then converted from scale type to type of matrix
618
+ * D before being stored in memory.
619
+ *
620
+ * int32_t, default: same as CUBLASLT_MATMUL_DESC_COMPUTE_TYPE
621
+ */
622
+ CUBLASLT_MATMUL_DESC_SCALE_TYPE = 1,
623
+
624
+ /** Pointer mode of alpha and beta, see cublasLtPointerMode_t. When CUBLASLT_POINTER_MODE_DEVICE_VECTOR is in use,
625
+ * alpha/beta vector lenghts must match number of output matrix rows.
626
+ *
627
+ * int32_t, default: CUBLASLT_POINTER_MODE_HOST
628
+ */
629
+ CUBLASLT_MATMUL_DESC_POINTER_MODE = 2,
630
+
631
+ /** Transform of matrix A, see cublasOperation_t.
632
+ *
633
+ * int32_t, default: CUBLAS_OP_N
634
+ */
635
+ CUBLASLT_MATMUL_DESC_TRANSA = 3,
636
+
637
+ /** Transform of matrix B, see cublasOperation_t.
638
+ *
639
+ * int32_t, default: CUBLAS_OP_N
640
+ */
641
+ CUBLASLT_MATMUL_DESC_TRANSB = 4,
642
+
643
+ /** Transform of matrix C, see cublasOperation_t.
644
+ *
645
+ * Currently only CUBLAS_OP_N is supported.
646
+ *
647
+ * int32_t, default: CUBLAS_OP_N
648
+ */
649
+ CUBLASLT_MATMUL_DESC_TRANSC = 5,
650
+
651
+ /** Matrix fill mode, see cublasFillMode_t.
652
+ *
653
+ * int32_t, default: CUBLAS_FILL_MODE_FULL
654
+ */
655
+ CUBLASLT_MATMUL_DESC_FILL_MODE = 6,
656
+
657
+ /** Epilogue function, see cublasLtEpilogue_t.
658
+ *
659
+ * uint32_t, default: CUBLASLT_EPILOGUE_DEFAULT
660
+ */
661
+ CUBLASLT_MATMUL_DESC_EPILOGUE = 7,
662
+
663
+ /** Bias or bias gradient vector pointer in the device memory.
664
+ *
665
+ * Bias case. See CUBLASLT_EPILOGUE_BIAS.
666
+ * For bias data type see CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE.
667
+ *
668
+ * Bias vector length must match matrix D rows count.
669
+ *
670
+ * Bias gradient case. See CUBLASLT_EPILOGUE_DRELU_BGRAD and CUBLASLT_EPILOGUE_DGELU_BGRAD.
671
+ * Bias gradient vector elements are the same type as the output elements
672
+ * (Ctype) with the exception of IMMA kernels (see above).
673
+ *
674
+ * Routines that don't dereference this pointer, like cublasLtMatmulAlgoGetHeuristic()
675
+ * depend on its value to determine expected pointer alignment.
676
+ *
677
+ * Bias case: const void *, default: NULL
678
+ * Bias gradient case: void *, default: NULL
679
+ */
680
+ CUBLASLT_MATMUL_DESC_BIAS_POINTER = 8,
681
+
682
+ /** Batch stride for bias or bias gradient vector.
683
+ *
684
+ * Used together with CUBLASLT_MATMUL_DESC_BIAS_POINTER when matrix D's CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT > 1.
685
+ *
686
+ * int64_t, default: 0
687
+ */
688
+ CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE = 10,
689
+
690
+ /** Pointer for epilogue auxiliary buffer.
691
+ *
692
+ * - Output vector for ReLu bit-mask in forward pass when CUBLASLT_EPILOGUE_RELU_AUX
693
+ * or CUBLASLT_EPILOGUE_RELU_AUX_BIAS epilogue is used.
694
+ * - Input vector for ReLu bit-mask in backward pass when
695
+ * CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is used.
696
+ *
697
+ * - Output of GELU input matrix in forward pass when
698
+ * CUBLASLT_EPILOGUE_GELU_AUX_BIAS epilogue is used.
699
+ * - Input of GELU input matrix for backward pass when
700
+ * CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue is used.
701
+ *
702
+ * For aux data type see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE.
703
+ *
704
+ * Routines that don't dereference this pointer, like cublasLtMatmulAlgoGetHeuristic()
705
+ * depend on its value to determine expected pointer alignment.
706
+ *
707
+ * Requires setting CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD attribute.
708
+ *
709
+ * Forward pass: void *, default: NULL
710
+ * Backward pass: const void *, default: NULL
711
+ */
712
+ CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER = 11,
713
+
714
+ /** Leading dimension for epilogue auxiliary buffer.
715
+ *
716
+ * - ReLu bit-mask matrix leading dimension in elements (i.e. bits)
717
+ * when CUBLASLT_EPILOGUE_RELU_AUX, CUBLASLT_EPILOGUE_RELU_AUX_BIAS or CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is
718
+ * used. Must be divisible by 128 and be no less than the number of rows in the output matrix.
719
+ *
720
+ * - GELU input matrix leading dimension in elements
721
+ * when CUBLASLT_EPILOGUE_GELU_AUX_BIAS or CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue used.
722
+ * Must be divisible by 8 and be no less than the number of rows in the output matrix.
723
+ *
724
+ * int64_t, default: 0
725
+ */
726
+ CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD = 12,
727
+
728
+ /** Batch stride for epilogue auxiliary buffer.
729
+ *
730
+ * - ReLu bit-mask matrix batch stride in elements (i.e. bits)
731
+ * when CUBLASLT_EPILOGUE_RELU_AUX, CUBLASLT_EPILOGUE_RELU_AUX_BIAS or CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is
732
+ * used. Must be divisible by 128.
733
+ *
734
+ * - GELU input matrix batch stride in elements
735
+ * when CUBLASLT_EPILOGUE_GELU_AUX_BIAS or CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue used.
736
+ * Must be divisible by 8.
737
+ *
738
+ * int64_t, default: 0
739
+ */
740
+ CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE = 13,
741
+
742
+ /** Batch stride for alpha vector.
743
+ *
744
+ * Used together with CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST when matrix D's
745
+ * CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT > 1. If CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO is set then
746
+ * CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE must be set to 0 as this mode doesnt supported batched alpha vector.
747
+ *
748
+ * int64_t, default: 0
749
+ */
750
+ CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE = 14,
751
+
752
+ /** Number of SMs to target for parallel execution. Optimizes heuristics for execution on a different number of SMs
753
+ * when user expects a concurrent stream to be using some of the device resources.
754
+ *
755
+ * int32_t, default: 0 - use the number reported by the device.
756
+ */
757
+ CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET = 15,
758
+
759
+ /** Device pointer to the scale factor value that converts data in matrix A to the compute data type range.
760
+ *
761
+ * The scaling factor value must have the same type as the compute type.
762
+ *
763
+ * If not specified, or set to NULL, the scaling factor is assumed to be 1.
764
+ *
765
+ * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
766
+ * will return CUBLAS_INVALID_VALUE.
767
+ *
768
+ * const void *, default: NULL
769
+ */
770
+ CUBLASLT_MATMUL_DESC_A_SCALE_POINTER = 17,
771
+
772
+ /** Device pointer to the scale factor value to convert data in matrix B to compute data type range.
773
+ *
774
+ * The scaling factor value must have the same type as the compute type.
775
+ *
776
+ * If not specified, or set to NULL, the scaling factor is assumed to be 1.
777
+ *
778
+ * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
779
+ * will return CUBLAS_INVALID_VALUE.
780
+ *
781
+ * const void *, default: NULL
782
+ */
783
+ CUBLASLT_MATMUL_DESC_B_SCALE_POINTER = 18,
784
+
785
+ /** Device pointer to the scale factor value to convert data in matrix C to compute data type range.
786
+ *
787
+ * The scaling factor value must have the same type as the compute type.
788
+ *
789
+ * If not specified, or set to NULL, the scaling factor is assumed to be 1.
790
+ *
791
+ * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
792
+ * will return CUBLAS_INVALID_VALUE.
793
+ *
794
+ * const void *, default: NULL
795
+ */
796
+ CUBLASLT_MATMUL_DESC_C_SCALE_POINTER = 19,
797
+
798
+ /** Device pointer to the scale factor value to convert data in matrix D to compute data type range.
799
+ *
800
+ * The scaling factor value must have the same type as the compute type.
801
+ *
802
+ * If not specified, or set to NULL, the scaling factor is assumed to be 1.
803
+ *
804
+ * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
805
+ * will return CUBLAS_INVALID_VALUE.
806
+ *
807
+ * const void *, default: NULL
808
+ */
809
+ CUBLASLT_MATMUL_DESC_D_SCALE_POINTER = 20,
810
+
811
+ /** Device pointer to the memory location that on completion will be set to the maximum of absolute values in the
812
+ * output matrix.
813
+ *
814
+ * The computed value has the same type as the compute type.
815
+ *
816
+ * If not specified or set to NULL, the maximum absolute value is not computed. If set for an unsupported matrix
817
+ * data, scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_INVALID_VALUE.
818
+ *
819
+ * void *, default: NULL
820
+ */
821
+ CUBLASLT_MATMUL_DESC_AMAX_D_POINTER = 21,
822
+
823
+ /** Type of the data to be stored to the memory pointed to by CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
824
+ *
825
+ * If unset, the data type defaults to the type of elements of the output matrix with some exceptions, see details
826
+ * below.
827
+ *
828
+ * ReLu uses a bit-mask.
829
+ *
830
+ * GELU input matrix elements type is the same as the type of elements of
831
+ * the output matrix with some exceptions, see details below.
832
+ *
833
+ * For fp8 kernels with output type CUDA_R_8F_E4M3 the aux data type can be CUDA_R_8F_E4M3 or CUDA_R_16F with some
834
+ * restrictions. See https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulDescAttributes_t for more details.
835
+ *
836
+ * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
837
+ * will return CUBLAS_INVALID_VALUE.
838
+ *
839
+ * int32_t based on cudaDataType, default: -1
840
+ */
841
+ CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE = 22,
842
+
843
+ /** Device pointer to the scaling factor value to convert results from compute type data range to storage
844
+ * data range in the auxiliary matrix that is set via CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
845
+ *
846
+ * The scaling factor value must have the same type as the compute type.
847
+ *
848
+ * If not specified, or set to NULL, the scaling factor is assumed to be 1. If set for an unsupported matrix data,
849
+ * scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_INVALID_VALUE.
850
+ *
851
+ * void *, default: NULL
852
+ */
853
+ CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER = 23,
854
+
855
+ /** Device pointer to the memory location that on completion will be set to the maximum of absolute values in the
856
+ * buffer that is set via CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
857
+ *
858
+ * The computed value has the same type as the compute type.
859
+ *
860
+ * If not specified or set to NULL, the maximum absolute value is not computed. If set for an unsupported matrix
861
+ * data, scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_INVALID_VALUE.
862
+ *
863
+ * void *, default: NULL
864
+ */
865
+ CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_AMAX_POINTER = 24,
866
+
867
+ /** Flag for managing fp8 fast accumulation mode.
868
+ * When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results
869
+ * will not periodically be promoted to a higher precision.
870
+ *
871
+ * int8_t, default: 0 - fast accumulation mode is disabled.
872
+ */
873
+ CUBLASLT_MATMUL_DESC_FAST_ACCUM = 25,
874
+
875
+ /** Type of bias or bias gradient vector in the device memory.
876
+ *
877
+ * Bias case: see CUBLASLT_EPILOGUE_BIAS.
878
+ *
879
+ * Bias vector elements are the same type as the elements of output matrix (Dtype) with the following exceptions:
880
+ * - IMMA kernels with computeType=CUDA_R_32I and Ctype=CUDA_R_8I where the bias vector elements
881
+ * are the same type as alpha, beta (CUBLASLT_MATMUL_DESC_SCALE_TYPE=CUDA_R_32F)
882
+ * - fp8 kernels with an output type of CUDA_R_32F, CUDA_R_8F_E4M3 or CUDA_R_8F_E5M2, See
883
+ * https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul for details.
884
+ *
885
+ * int32_t based on cudaDataType, default: -1
886
+ */
887
+ CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE = 26,
888
+ } cublasLtMatmulDescAttributes_t;
889
+
890
+ /** Internal. Do not use directly.
891
+ */
892
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulDescInit_internal( //
893
+ cublasLtMatmulDesc_t matmulDesc,
894
+ size_t size,
895
+ cublasComputeType_t computeType,
896
+ cudaDataType_t scaleType);
897
+
898
+ /** Initialize matmul operation descriptor in pre-allocated space.
899
+ *
900
+ * \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
901
+ * \retval CUBLAS_STATUS_SUCCESS if desciptor was initialized successfully
902
+ */
903
+ static inline cublasStatus_t cublasLtMatmulDescInit( //
904
+ cublasLtMatmulDesc_t matmulDesc,
905
+ cublasComputeType_t computeType,
906
+ cudaDataType_t scaleType) {
907
+ return cublasLtMatmulDescInit_internal(matmulDesc, sizeof(*matmulDesc), computeType, scaleType);
908
+ }
909
+
910
+ /** Create new matmul operation descriptor.
911
+ *
912
+ * \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
913
+ * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
914
+ */
915
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulDescCreate(cublasLtMatmulDesc_t* matmulDesc,
916
+ cublasComputeType_t computeType,
917
+ cudaDataType_t scaleType);
918
+
919
+ /** Destroy matmul operation descriptor.
920
+ *
921
+ * \retval CUBLAS_STATUS_SUCCESS if operation was successful
922
+ */
923
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulDescDestroy(cublasLtMatmulDesc_t matmulDesc);
924
+
925
+ /** Set matmul operation descriptor attribute.
926
+ *
927
+ * \param[in] matmulDesc The descriptor
928
+ * \param[in] attr The attribute
929
+ * \param[in] buf memory address containing the new value
930
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
931
+ *
932
+ * \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
933
+ * selected attribute
934
+ * \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
935
+ */
936
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulDescSetAttribute( //
937
+ cublasLtMatmulDesc_t matmulDesc,
938
+ cublasLtMatmulDescAttributes_t attr,
939
+ const void* buf,
940
+ size_t sizeInBytes);
941
+
942
+ /** Get matmul operation descriptor attribute.
943
+ *
944
+ * \param[in] matmulDesc The descriptor
945
+ * \param[in] attr The attribute
946
+ * \param[out] buf memory address containing the new value
947
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
948
+ * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
949
+ * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
950
+ *
951
+ * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
952
+ * and buf is NULL or sizeInBytes doesn't match size of internal storage for
953
+ * selected attribute
954
+ * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
955
+ */
956
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulDescGetAttribute( //
957
+ cublasLtMatmulDesc_t matmulDesc,
958
+ cublasLtMatmulDescAttributes_t attr,
959
+ void* buf,
960
+ size_t sizeInBytes,
961
+ size_t* sizeWritten);
962
+
963
+ /* ---------------------------------------------------------------------------------------*/
964
+ /* Helper functions for cublasLtMatrixTransformDesc_t */
965
+ /* ---------------------------------------------------------------------------------------*/
966
+
967
+ /** Matrix transform descriptor attributes to define details of the operation.
968
+ */
969
+ typedef enum {
970
+ /** Scale type, see cudaDataType. Inputs are converted to scale type for scaling and summation and results are then
971
+ * converted to output type to store in memory.
972
+ *
973
+ * int32_t
974
+ */
975
+ CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE,
976
+
977
+ /** Pointer mode of alpha and beta, see cublasLtPointerMode_t.
978
+ *
979
+ * int32_t, default: CUBLASLT_POINTER_MODE_HOST
980
+ */
981
+ CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE,
982
+
983
+ /** Transform of matrix A, see cublasOperation_t.
984
+ *
985
+ * int32_t, default: CUBLAS_OP_N
986
+ */
987
+ CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA,
988
+
989
+ /** Transform of matrix B, see cublasOperation_t.
990
+ *
991
+ * int32_t, default: CUBLAS_OP_N
992
+ */
993
+ CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSB,
994
+ } cublasLtMatrixTransformDescAttributes_t;
995
+
996
+ /** Internal. Do not use directly.
997
+ */
998
+ cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescInit_internal(cublasLtMatrixTransformDesc_t transformDesc,
999
+ size_t size,
1000
+ cudaDataType scaleType);
1001
+
1002
+ /** Initialize matrix transform operation descriptor in pre-allocated space.
1003
+ *
1004
+ * \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
1005
+ * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
1006
+ */
1007
+ static inline cublasStatus_t cublasLtMatrixTransformDescInit(cublasLtMatrixTransformDesc_t transformDesc,
1008
+ cudaDataType scaleType) {
1009
+ return cublasLtMatrixTransformDescInit_internal(transformDesc, sizeof(*transformDesc), scaleType);
1010
+ }
1011
+
1012
+ /** Create new matrix transform operation descriptor.
1013
+ *
1014
+ * \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
1015
+ * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
1016
+ */
1017
+ cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescCreate(cublasLtMatrixTransformDesc_t* transformDesc,
1018
+ cudaDataType scaleType);
1019
+
1020
+ /** Destroy matrix transform operation descriptor.
1021
+ *
1022
+ * \retval CUBLAS_STATUS_SUCCESS if operation was successful
1023
+ */
1024
+ cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescDestroy(cublasLtMatrixTransformDesc_t transformDesc);
1025
+
1026
+ /** Set matrix transform operation descriptor attribute.
1027
+ *
1028
+ * \param[in] transformDesc The descriptor
1029
+ * \param[in] attr The attribute
1030
+ * \param[in] buf memory address containing the new value
1031
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
1032
+ *
1033
+ * \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
1034
+ * selected attribute
1035
+ * \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
1036
+ */
1037
+ cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescSetAttribute( //
1038
+ cublasLtMatrixTransformDesc_t transformDesc,
1039
+ cublasLtMatrixTransformDescAttributes_t attr,
1040
+ const void* buf,
1041
+ size_t sizeInBytes);
1042
+
1043
+ /** Get matrix transform operation descriptor attribute.
1044
+ *
1045
+ * \param[in] transformDesc The descriptor
1046
+ * \param[in] attr The attribute
1047
+ * \param[out] buf memory address containing the new value
1048
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
1049
+ * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number
1050
+ * of bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
1051
+ *
1052
+ * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
1053
+ * and buf is NULL or sizeInBytes doesn't match size of internal storage for
1054
+ * selected attribute
1055
+ * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
1056
+ */
1057
+ cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescGetAttribute( //
1058
+ cublasLtMatrixTransformDesc_t transformDesc,
1059
+ cublasLtMatrixTransformDescAttributes_t attr,
1060
+ void* buf,
1061
+ size_t sizeInBytes,
1062
+ size_t* sizeWritten);
1063
+
1064
+ /** For computation with complex numbers, this enum allows to apply the Gauss Complexity reduction algorithm
1065
+ */
1066
+ typedef enum {
1067
+ CUBLASLT_3M_MODE_DISALLOWED = 0,
1068
+ CUBLASLT_3M_MODE_ALLOWED = 1,
1069
+ } cublasLt3mMode_t;
1070
+
1071
+ /** Reduction scheme for portions of the dot-product calculated in parallel (a. k. a. "split - K").
1072
+ */
1073
+ typedef enum {
1074
+ /** No reduction scheme, dot-product shall be performed in one sequence.
1075
+ */
1076
+ CUBLASLT_REDUCTION_SCHEME_NONE = 0,
1077
+
1078
+ /** Reduction is performed "in place" - using the output buffer (and output data type) and counters (in workspace) to
1079
+ * guarantee the sequentiality.
1080
+ */
1081
+ CUBLASLT_REDUCTION_SCHEME_INPLACE = 1,
1082
+
1083
+ /** Intermediate results are stored in compute type in the workspace and reduced in a separate step.
1084
+ */
1085
+ CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE = 2,
1086
+
1087
+ /** Intermediate results are stored in output type in the workspace and reduced in a separate step.
1088
+ */
1089
+ CUBLASLT_REDUCTION_SCHEME_OUTPUT_TYPE = 4,
1090
+
1091
+ CUBLASLT_REDUCTION_SCHEME_MASK = 0x7,
1092
+ } cublasLtReductionScheme_t;
1093
+
1094
+ /** Postprocessing options for the epilogue
1095
+ */
1096
+ typedef enum {
1097
+ /** No special postprocessing, just scale and quantize results if necessary.
1098
+ */
1099
+ CUBLASLT_EPILOGUE_DEFAULT = 1,
1100
+
1101
+ /** ReLu, apply ReLu point-wise transform to the results (x:=max(x, 0)).
1102
+ */
1103
+ CUBLASLT_EPILOGUE_RELU = 2,
1104
+
1105
+ /** ReLu, apply ReLu point-wise transform to the results (x:=max(x, 0)).
1106
+ *
1107
+ * This epilogue mode produces an extra output, a ReLu bit-mask matrix,
1108
+ * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
1109
+ */
1110
+ CUBLASLT_EPILOGUE_RELU_AUX = (CUBLASLT_EPILOGUE_RELU | 128),
1111
+
1112
+ /** Bias, apply (broadcasted) Bias from bias vector. Bias vector length must match matrix D rows, it must be packed
1113
+ * (stride between vector elements is 1). Bias vector is broadcasted to all columns and added before applying final
1114
+ * postprocessing.
1115
+ */
1116
+ CUBLASLT_EPILOGUE_BIAS = 4,
1117
+
1118
+ /** ReLu and Bias, apply Bias and then ReLu transform
1119
+ */
1120
+ CUBLASLT_EPILOGUE_RELU_BIAS = (CUBLASLT_EPILOGUE_RELU | CUBLASLT_EPILOGUE_BIAS),
1121
+
1122
+ /** ReLu and Bias, apply Bias and then ReLu transform
1123
+ *
1124
+ * This epilogue mode produces an extra output, a ReLu bit-mask matrix,
1125
+ * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
1126
+ */
1127
+ CUBLASLT_EPILOGUE_RELU_AUX_BIAS = (CUBLASLT_EPILOGUE_RELU_AUX | CUBLASLT_EPILOGUE_BIAS),
1128
+
1129
+ /* ReLu gradient. Apply ReLu gradient to matmul output. Store ReLu gradient in the output matrix.
1130
+ *
1131
+ * This epilogue mode requires an extra input,
1132
+ * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
1133
+ */
1134
+ CUBLASLT_EPILOGUE_DRELU = 8 | 128,
1135
+
1136
+ /* ReLu and Bias gradients. Apply independently ReLu and Bias gradient to
1137
+ * matmul output. Store ReLu gradient in the output matrix, and Bias gradient
1138
+ * in the auxiliary output (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
1139
+ *
1140
+ * This epilogue mode requires an extra input,
1141
+ * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
1142
+ */
1143
+ CUBLASLT_EPILOGUE_DRELU_BGRAD = CUBLASLT_EPILOGUE_DRELU | 16,
1144
+
1145
+ /** GELU, apply GELU point-wise transform to the results (x:=GELU(x)).
1146
+ */
1147
+ CUBLASLT_EPILOGUE_GELU = 32,
1148
+
1149
+ /** GELU, apply GELU point-wise transform to the results (x:=GELU(x)).
1150
+ *
1151
+ * This epilogue mode outputs GELU input as a separate matrix (useful for training).
1152
+ * See CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
1153
+ */
1154
+ CUBLASLT_EPILOGUE_GELU_AUX = (CUBLASLT_EPILOGUE_GELU | 128),
1155
+
1156
+ /** GELU and Bias, apply Bias and then GELU transform
1157
+ */
1158
+ CUBLASLT_EPILOGUE_GELU_BIAS = (CUBLASLT_EPILOGUE_GELU | CUBLASLT_EPILOGUE_BIAS),
1159
+
1160
+ /** GELU and Bias, apply Bias and then GELU transform
1161
+ *
1162
+ * This epilogue mode outputs GELU input as a separate matrix (useful for training).
1163
+ * See CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
1164
+ */
1165
+ CUBLASLT_EPILOGUE_GELU_AUX_BIAS = (CUBLASLT_EPILOGUE_GELU_AUX | CUBLASLT_EPILOGUE_BIAS),
1166
+
1167
+ /* GELU gradient. Apply GELU gradient to matmul output. Store GELU gradient in the output matrix.
1168
+ *
1169
+ * This epilogue mode requires an extra input,
1170
+ * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
1171
+ */
1172
+ CUBLASLT_EPILOGUE_DGELU = 64 | 128,
1173
+
1174
+ /* GELU and Bias gradients. Apply independently GELU and Bias gradient to
1175
+ * matmul output. Store GELU gradient in the output matrix, and Bias gradient
1176
+ * in the auxiliary output (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
1177
+ *
1178
+ * This epilogue mode requires an extra input,
1179
+ * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
1180
+ */
1181
+ CUBLASLT_EPILOGUE_DGELU_BGRAD = CUBLASLT_EPILOGUE_DGELU | 16,
1182
+
1183
+ /** Bias gradient based on the input matrix A.
1184
+ *
1185
+ * The bias size corresponds to the number of rows of the matrix D.
1186
+ * The reduction happens over the GEMM's "k" dimension.
1187
+ *
1188
+ * Stores Bias gradient in the auxiliary output
1189
+ * (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
1190
+ */
1191
+ CUBLASLT_EPILOGUE_BGRADA = 256,
1192
+
1193
+ /** Bias gradient based on the input matrix B.
1194
+ *
1195
+ * The bias size corresponds to the number of columns of the matrix D.
1196
+ * The reduction happens over the GEMM's "k" dimension.
1197
+ *
1198
+ * Stores Bias gradient in the auxiliary output
1199
+ * (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
1200
+ */
1201
+ CUBLASLT_EPILOGUE_BGRADB = 512,
1202
+ } cublasLtEpilogue_t;
1203
+
1204
+ /** Matmul heuristic search mode
1205
+ */
1206
+ typedef enum {
1207
+ /** ask heuristics for best algo for given usecase
1208
+ */
1209
+ CUBLASLT_SEARCH_BEST_FIT = 0,
1210
+ /** only try to find best config for preconfigured algo id
1211
+ */
1212
+ CUBLASLT_SEARCH_LIMITED_BY_ALGO_ID = 1,
1213
+ /** reserved for future use
1214
+ */
1215
+ CUBLASLT_SEARCH_RESERVED_02 = 2,
1216
+ /** reserved for future use
1217
+ */
1218
+ CUBLASLT_SEARCH_RESERVED_03 = 3,
1219
+ /** reserved for future use
1220
+ */
1221
+ CUBLASLT_SEARCH_RESERVED_04 = 4,
1222
+ /** reserved for future use
1223
+ */
1224
+ CUBLASLT_SEARCH_RESERVED_05 = 5,
1225
+ } cublasLtMatmulSearch_t;
1226
+
1227
+ /** Algo search preference to fine tune the heuristic function. */
1228
+ typedef enum {
1229
+ /** Search mode, see cublasLtMatmulSearch_t.
1230
+ *
1231
+ * uint32_t, default: CUBLASLT_SEARCH_BEST_FIT
1232
+ */
1233
+ CUBLASLT_MATMUL_PREF_SEARCH_MODE = 0,
1234
+
1235
+ /** Maximum allowed workspace size in bytes.
1236
+ *
1237
+ * uint64_t, default: 0 - no workspace allowed
1238
+ */
1239
+ CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES = 1,
1240
+
1241
+ /** Math mode mask, see cublasMath_t.
1242
+ *
1243
+ * Only algorithms with CUBLASLT_ALGO_CAP_MATHMODE_IMPL that is not masked out by this attribute are allowed.
1244
+ *
1245
+ * uint32_t, default: 1 (allows both default and tensor op math)
1246
+ * DEPRECATED, will be removed in a future release, see cublasLtNumericalImplFlags_t for replacement
1247
+ */
1248
+ CUBLASLT_MATMUL_PREF_MATH_MODE_MASK = 2,
1249
+
1250
+ /** Reduction scheme mask, see cublasLtReductionScheme_t. Filters heuristic result to only include algo configs that
1251
+ * use one of the required modes.
1252
+ *
1253
+ * E.g. mask value of 0x03 will allow only INPLACE and COMPUTE_TYPE reduction schemes.
1254
+ *
1255
+ * uint32_t, default: CUBLASLT_REDUCTION_SCHEME_MASK (allows all reduction schemes)
1256
+ */
1257
+ CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK = 3,
1258
+
1259
+ /** Gaussian mode mask, see cublasLt3mMode_t.
1260
+ *
1261
+ * Only algorithms with CUBLASLT_ALGO_CAP_GAUSSIAN_IMPL that is not masked out by this attribute are allowed.
1262
+ *
1263
+ * uint32_t, default: CUBLASLT_3M_MODE_ALLOWED (allows both gaussian and non-gaussian algorithms)
1264
+ * DEPRECATED, will be removed in a future release, see cublasLtNumericalImplFlags_t for replacement
1265
+ */
1266
+ CUBLASLT_MATMUL_PREF_GAUSSIAN_MODE_MASK = 4,
1267
+
1268
+ /** Minimum buffer alignment for matrix A (in bytes).
1269
+ *
1270
+ * Selecting a smaller value will exclude algorithms that can not work with matrix A that is not as strictly aligned
1271
+ * as they need.
1272
+ *
1273
+ * uint32_t, default: 256
1274
+ */
1275
+ CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES = 5,
1276
+
1277
+ /** Minimum buffer alignment for matrix B (in bytes).
1278
+ *
1279
+ * Selecting a smaller value will exclude algorithms that can not work with matrix B that is not as strictly aligned
1280
+ * as they need.
1281
+ *
1282
+ * uint32_t, default: 256
1283
+ */
1284
+ CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES = 6,
1285
+
1286
+ /** Minimum buffer alignment for matrix C (in bytes).
1287
+ *
1288
+ * Selecting a smaller value will exclude algorithms that can not work with matrix C that is not as strictly aligned
1289
+ * as they need.
1290
+ *
1291
+ * uint32_t, default: 256
1292
+ */
1293
+ CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES = 7,
1294
+
1295
+ /** Minimum buffer alignment for matrix D (in bytes).
1296
+ *
1297
+ * Selecting a smaller value will exclude algorithms that can not work with matrix D that is not as strictly aligned
1298
+ * as they need.
1299
+ *
1300
+ * uint32_t, default: 256
1301
+ */
1302
+ CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES = 8,
1303
+
1304
+ /** Maximum wave count.
1305
+ *
1306
+ * See cublasLtMatmulHeuristicResult_t::wavesCount.
1307
+ *
1308
+ * Selecting a non-zero value will exclude algorithms that report device utilization higher than specified.
1309
+ *
1310
+ * float, default: 0.0f
1311
+ */
1312
+ CUBLASLT_MATMUL_PREF_MAX_WAVES_COUNT = 9,
1313
+
1314
+ /** Pointer mode mask, see cublasLtPointerModeMask_t. Filters heuristic result to only include algorithms that support
1315
+ * all required modes.
1316
+ *
1317
+ * uint32_t, default: (CUBLASLT_POINTER_MODE_MASK_HOST | CUBLASLT_POINTER_MODE_MASK_DEVICE) (only allows algorithms
1318
+ * that support both regular host and device pointers)
1319
+ */
1320
+ CUBLASLT_MATMUL_PREF_POINTER_MODE_MASK = 10,
1321
+
1322
+ /** Epilogue selector mask, see cublasLtEpilogue_t. Filters heuristic result to only include algorithms that support
1323
+ * all required operations.
1324
+ *
1325
+ * uint32_t, default: CUBLASLT_EPILOGUE_DEFAULT (only allows algorithms that support default epilogue)
1326
+ */
1327
+ CUBLASLT_MATMUL_PREF_EPILOGUE_MASK = 11,
1328
+
1329
+ /** Numerical implementation details mask, see cublasLtNumericalImplFlags_t. Filters heuristic result to only include
1330
+ * algorithms that use the allowed implementations.
1331
+ *
1332
+ * uint64_t, default: uint64_t(-1) (allow everything)
1333
+ */
1334
+ CUBLASLT_MATMUL_PREF_IMPL_MASK = 12,
1335
+
1336
+ /** Number of SMs to target for parallel execution. Optimizes heuristics for execution on a different number of SMs
1337
+ * when user expects a concurrent stream to be using some of the device resources.
1338
+ *
1339
+ * Overrides the SM count target set in the matrix multiplication descriptor (see cublasLtMatmulDescAttributes_t).
1340
+ *
1341
+ * int32_t, default: 0 - use the number reported by the device.
1342
+ * DEPRECATED, will be removed in a future release, see cublasLtMatmulDescAttributes_t for replacement
1343
+ */
1344
+ CUBLASLT_MATMUL_PREF_SM_COUNT_TARGET = 13,
1345
+ } cublasLtMatmulPreferenceAttributes_t;
1346
+
1347
+ /** Internal. Do not use directly.
1348
+ */
1349
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceInit_internal(cublasLtMatmulPreference_t pref, size_t size);
1350
+
1351
+ /** Initialize matmul heuristic search preference descriptor in pre-allocated space.
1352
+ *
1353
+ * \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
1354
+ * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
1355
+ */
1356
+ static inline cublasStatus_t cublasLtMatmulPreferenceInit(cublasLtMatmulPreference_t pref) {
1357
+ return cublasLtMatmulPreferenceInit_internal(pref, sizeof(*pref));
1358
+ }
1359
+
1360
+ /** Create new matmul heuristic search preference descriptor.
1361
+ *
1362
+ * \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
1363
+ * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
1364
+ */
1365
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceCreate(cublasLtMatmulPreference_t* pref);
1366
+
1367
+ /** Destroy matmul heuristic search preference descriptor.
1368
+ *
1369
+ * \retval CUBLAS_STATUS_SUCCESS if operation was successful
1370
+ */
1371
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceDestroy(cublasLtMatmulPreference_t pref);
1372
+
1373
+ /** Set matmul heuristic search preference descriptor attribute.
1374
+ *
1375
+ * \param[in] pref The descriptor
1376
+ * \param[in] attr The attribute
1377
+ * \param[in] buf memory address containing the new value
1378
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
1379
+ *
1380
+ * \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
1381
+ * selected attribute
1382
+ * \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
1383
+ */
1384
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceSetAttribute( //
1385
+ cublasLtMatmulPreference_t pref,
1386
+ cublasLtMatmulPreferenceAttributes_t attr,
1387
+ const void* buf,
1388
+ size_t sizeInBytes);
1389
+
1390
+ /** Get matmul heuristic search preference descriptor attribute.
1391
+ *
1392
+ * \param[in] pref The descriptor
1393
+ * \param[in] attr The attribute
1394
+ * \param[out] buf memory address containing the new value
1395
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
1396
+ * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
1397
+ * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
1398
+ *
1399
+ * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
1400
+ * and buf is NULL or sizeInBytes doesn't match size of internal storage for
1401
+ * selected attribute
1402
+ * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
1403
+ */
1404
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceGetAttribute( //
1405
+ cublasLtMatmulPreference_t pref,
1406
+ cublasLtMatmulPreferenceAttributes_t attr,
1407
+ void* buf,
1408
+ size_t sizeInBytes,
1409
+ size_t* sizeWritten);
1410
+
1411
+ /** Results structure used by cublasLtMatmulGetAlgo.
1412
+ *
1413
+ * Holds returned configured algo descriptor and its runtime properties.
1414
+ */
1415
+ typedef struct {
1416
+ /** Matmul algorithm descriptor.
1417
+ *
1418
+ * Must be initialized with cublasLtMatmulAlgoInit() if preferences' CUBLASLT_MATMUL_PERF_SEARCH_MODE is set to
1419
+ * CUBLASLT_SEARCH_LIMITED_BY_ALGO_ID
1420
+ */
1421
+ cublasLtMatmulAlgo_t algo;
1422
+
1423
+ /** Actual size of workspace memory required.
1424
+ */
1425
+ size_t workspaceSize;
1426
+
1427
+ /** Result status, other fields are only valid if after call to cublasLtMatmulAlgoGetHeuristic() this member is set to
1428
+ * CUBLAS_STATUS_SUCCESS.
1429
+ */
1430
+ cublasStatus_t state;
1431
+
1432
+ /** Waves count - a device utilization metric.
1433
+ *
1434
+ * wavesCount value of 1.0f suggests that when kernel is launched it will fully occupy the GPU.
1435
+ */
1436
+ float wavesCount;
1437
+
1438
+ int reserved[4];
1439
+ } cublasLtMatmulHeuristicResult_t;
1440
+
1441
+ /** Query cublasLt heuristic for algorithm appropriate for given use case.
1442
+ *
1443
+ * \param[in] lightHandle Pointer to the allocated cuBLASLt handle for the cuBLASLt
1444
+ * context. See cublasLtHandle_t.
1445
+ * \param[in] operationDesc Handle to the matrix multiplication descriptor.
1446
+ * \param[in] Adesc Handle to the layout descriptors for matrix A.
1447
+ * \param[in] Bdesc Handle to the layout descriptors for matrix B.
1448
+ * \param[in] Cdesc Handle to the layout descriptors for matrix C.
1449
+ * \param[in] Ddesc Handle to the layout descriptors for matrix D.
1450
+ * \param[in] preference Pointer to the structure holding the heuristic search
1451
+ * preferences descriptor. See cublasLtMatrixLayout_t.
1452
+ * \param[in] requestedAlgoCount Size of heuristicResultsArray (in elements) and requested
1453
+ * maximum number of algorithms to return.
1454
+ * \param[in, out] heuristicResultsArray Output algorithms and associated runtime characteristics,
1455
+ * ordered in increasing estimated compute time.
1456
+ * \param[out] returnAlgoCount The number of heuristicResultsArray elements written.
1457
+ *
1458
+ * \retval CUBLAS_STATUS_INVALID_VALUE if requestedAlgoCount is less or equal to zero
1459
+ * \retval CUBLAS_STATUS_NOT_SUPPORTED if no heuristic function available for current configuration
1460
+ * \retval CUBLAS_STATUS_SUCCESS if query was successful, inspect
1461
+ * heuristicResultsArray[0 to (returnAlgoCount - 1)].state
1462
+ * for detail status of results
1463
+ */
1464
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetHeuristic(cublasLtHandle_t lightHandle,
1465
+ cublasLtMatmulDesc_t operationDesc,
1466
+ cublasLtMatrixLayout_t Adesc,
1467
+ cublasLtMatrixLayout_t Bdesc,
1468
+ cublasLtMatrixLayout_t Cdesc,
1469
+ cublasLtMatrixLayout_t Ddesc,
1470
+ cublasLtMatmulPreference_t preference,
1471
+ int requestedAlgoCount,
1472
+ cublasLtMatmulHeuristicResult_t heuristicResultsArray[],
1473
+ int* returnAlgoCount);
1474
+
1475
+ /* ---------------------------------------------------------------------------------------*/
1476
+ /* Lower level API to be able to implement own Heuristic and Find routines */
1477
+ /* ---------------------------------------------------------------------------------------*/
1478
+
1479
+ /** Routine to get all algo IDs that can potentially run
1480
+ *
1481
+ * \param[in] int requestedAlgoCount requested number of algos (must be less or equal to size of algoIdsA
1482
+ * (in elements)) \param[out] algoIdsA array to write algoIds to \param[out] returnAlgoCount number of algoIds
1483
+ * actually written
1484
+ *
1485
+ * \retval CUBLAS_STATUS_INVALID_VALUE if requestedAlgoCount is less or equal to zero
1486
+ * \retval CUBLAS_STATUS_SUCCESS if query was successful, inspect returnAlgoCount to get actual number of IDs
1487
+ * available
1488
+ */
1489
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetIds(cublasLtHandle_t lightHandle,
1490
+ cublasComputeType_t computeType,
1491
+ cudaDataType_t scaleType,
1492
+ cudaDataType_t Atype,
1493
+ cudaDataType_t Btype,
1494
+ cudaDataType_t Ctype,
1495
+ cudaDataType_t Dtype,
1496
+ int requestedAlgoCount,
1497
+ int algoIdsArray[],
1498
+ int* returnAlgoCount);
1499
+
1500
+ /** Initialize algo structure
1501
+ *
1502
+ * \retval CUBLAS_STATUS_INVALID_VALUE if algo is NULL or algoId is outside of recognized range
1503
+ * \retval CUBLAS_STATUS_NOT_SUPPORTED if algoId is not supported for given combination of data types
1504
+ * \retval CUBLAS_STATUS_SUCCESS if the structure was successfully initialized
1505
+ */
1506
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoInit(cublasLtHandle_t lightHandle,
1507
+ cublasComputeType_t computeType,
1508
+ cudaDataType_t scaleType,
1509
+ cudaDataType_t Atype,
1510
+ cudaDataType_t Btype,
1511
+ cudaDataType_t Ctype,
1512
+ cudaDataType_t Dtype,
1513
+ int algoId,
1514
+ cublasLtMatmulAlgo_t* algo);
1515
+
1516
+ /** Check configured algo descriptor for correctness and support on current device.
1517
+ *
1518
+ * Result includes required workspace size and calculated wave count.
1519
+ *
1520
+ * CUBLAS_STATUS_SUCCESS doesn't fully guarantee algo will run (will fail if e.g. buffers are not correctly aligned);
1521
+ * but if cublasLtMatmulAlgoCheck fails, the algo will not run.
1522
+ *
1523
+ * \param[in] algo algo configuration to check
1524
+ * \param[out] result result structure to report algo runtime characteristics; algo field is never updated
1525
+ *
1526
+ * \retval CUBLAS_STATUS_INVALID_VALUE if matrix layout descriptors or operation descriptor don't match algo
1527
+ * descriptor
1528
+ * \retval CUBLAS_STATUS_NOT_SUPPORTED if algo configuration or data type combination is not currently supported on
1529
+ * given device
1530
+ * \retval CUBLAS_STATUS_ARCH_MISMATCH if algo configuration cannot be run using the selected device
1531
+ * \retval CUBLAS_STATUS_SUCCESS if check was successful
1532
+ */
1533
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCheck( //
1534
+ cublasLtHandle_t lightHandle,
1535
+ cublasLtMatmulDesc_t operationDesc,
1536
+ cublasLtMatrixLayout_t Adesc,
1537
+ cublasLtMatrixLayout_t Bdesc,
1538
+ cublasLtMatrixLayout_t Cdesc,
1539
+ cublasLtMatrixLayout_t Ddesc,
1540
+ const cublasLtMatmulAlgo_t* algo, ///< may point to result->algo
1541
+ cublasLtMatmulHeuristicResult_t* result);
1542
+
1543
+ /** Capabilities Attributes that can be retrieved from an initialized Algo structure
1544
+ */
1545
+ typedef enum {
1546
+ /** support for split K, see CUBLASLT_ALGO_CONFIG_SPLITK_NUM
1547
+ *
1548
+ * int32_t, 0 means no support, supported otherwise
1549
+ */
1550
+ CUBLASLT_ALGO_CAP_SPLITK_SUPPORT = 0,
1551
+ /** reduction scheme mask, see cublasLtReductionScheme_t; shows supported reduction schemes, if reduction scheme is
1552
+ * not masked out it is supported.
1553
+ *
1554
+ * e.g. int isReductionSchemeComputeTypeSupported ? (reductionSchemeMask & CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE) ==
1555
+ * CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE ? 1 : 0;
1556
+ *
1557
+ * uint32_t
1558
+ */
1559
+ CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK = 1,
1560
+ /** support for cta swizzling, see CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
1561
+ *
1562
+ * uint32_t, 0 means no support, 1 means supported value of 1, other values are reserved
1563
+ */
1564
+ CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT = 2,
1565
+ /** support strided batch
1566
+ *
1567
+ * int32_t, 0 means no support, supported otherwise
1568
+ */
1569
+ CUBLASLT_ALGO_CAP_STRIDED_BATCH_SUPPORT = 3,
1570
+ /** support results out of place (D != C in D = alpha.A.B + beta.C)
1571
+ *
1572
+ * int32_t, 0 means no support, supported otherwise
1573
+ */
1574
+ CUBLASLT_ALGO_CAP_OUT_OF_PLACE_RESULT_SUPPORT = 4,
1575
+ /** syrk/herk support (on top of regular gemm)
1576
+ *
1577
+ * int32_t, 0 means no support, supported otherwise
1578
+ */
1579
+ CUBLASLT_ALGO_CAP_UPLO_SUPPORT = 5,
1580
+ /** tile ids possible to use, see cublasLtMatmulTile_t; if no tile ids are supported use
1581
+ * CUBLASLT_MATMUL_TILE_UNDEFINED
1582
+ *
1583
+ * use cublasLtMatmulAlgoCapGetAttribute() with sizeInBytes=0 to query actual count
1584
+ *
1585
+ * array of uint32_t
1586
+ */
1587
+ CUBLASLT_ALGO_CAP_TILE_IDS = 6,
1588
+ /** custom option range is from 0 to CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX (inclusive), see
1589
+ * CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
1590
+ *
1591
+ * int32_t
1592
+ */
1593
+ CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX = 7,
1594
+ /** whether algorithm is using regular compute or tensor operations
1595
+ *
1596
+ * int32_t 0 means regular compute, 1 means tensor operations;
1597
+ * DEPRECATED
1598
+ */
1599
+ CUBLASLT_ALGO_CAP_MATHMODE_IMPL = 8,
1600
+ /** whether algorithm implements gaussian optimization of complex matrix multiplication, see cublasMath_t
1601
+ *
1602
+ * int32_t 0 means regular compute, 1 means gaussian;
1603
+ * DEPRECATED
1604
+ */
1605
+ CUBLASLT_ALGO_CAP_GAUSSIAN_IMPL = 9,
1606
+ /** whether algorithm supports custom (not COL or ROW memory order), see cublasLtOrder_t
1607
+ *
1608
+ * int32_t 0 means only COL and ROW memory order is allowed, non-zero means that algo might have different
1609
+ * requirements;
1610
+ */
1611
+ CUBLASLT_ALGO_CAP_CUSTOM_MEMORY_ORDER = 10,
1612
+
1613
+ /** bitmask enumerating pointer modes algorithm supports
1614
+ *
1615
+ * uint32_t, see cublasLtPointerModeMask_t
1616
+ */
1617
+ CUBLASLT_ALGO_CAP_POINTER_MODE_MASK = 11,
1618
+
1619
+ /** bitmask enumerating kinds of postprocessing algorithm supports in the epilogue
1620
+ *
1621
+ * uint32_t, see cublasLtEpilogue_t
1622
+ */
1623
+ CUBLASLT_ALGO_CAP_EPILOGUE_MASK = 12,
1624
+ /** stages ids possible to use, see cublasLtMatmulStages_t; if no stages ids are supported use
1625
+ * CUBLASLT_MATMUL_STAGES_UNDEFINED
1626
+ *
1627
+ * use cublasLtMatmulAlgoCapGetAttribute() with sizeInBytes=0 to query actual count
1628
+ *
1629
+ * array of uint32_t
1630
+ */
1631
+ CUBLASLT_ALGO_CAP_STAGES_IDS = 13,
1632
+ /** support for nagative ld for all of the matrices
1633
+ *
1634
+ * int32_t 0 means no support, supported otherwise
1635
+ */
1636
+ CUBLASLT_ALGO_CAP_LD_NEGATIVE = 14,
1637
+ /** details about algorithm's implementation that affect it's numerical behavior
1638
+ *
1639
+ * uint64_t, see cublasLtNumericalImplFlags_t
1640
+ */
1641
+ CUBLASLT_ALGO_CAP_NUMERICAL_IMPL_FLAGS = 15,
1642
+ /** minimum alignment required for A matrix in bytes
1643
+ * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
1644
+ *
1645
+ * uint32_t
1646
+ */
1647
+ CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_A_BYTES = 16,
1648
+ /** minimum alignment required for B matrix in bytes
1649
+ * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
1650
+ *
1651
+ * uint32_t
1652
+ */
1653
+ CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_B_BYTES = 17,
1654
+ /** minimum alignment required for C matrix in bytes
1655
+ * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
1656
+ *
1657
+ * uint32_t
1658
+ */
1659
+ CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_C_BYTES = 18,
1660
+ /** minimum alignment required for D matrix in bytes
1661
+ * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
1662
+ *
1663
+ * uint32_t
1664
+ */
1665
+ CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_D_BYTES = 19,
1666
+ } cublasLtMatmulAlgoCapAttributes_t;
1667
+
1668
+ /** Get algo capability attribute.
1669
+ *
1670
+ * E.g. to get list of supported Tile IDs:
1671
+ * cublasLtMatmulTile_t tiles[CUBLASLT_MATMUL_TILE_END];
1672
+ * size_t num_tiles, size_written;
1673
+ * if (cublasLtMatmulAlgoCapGetAttribute(algo, CUBLASLT_ALGO_CAP_TILE_IDS, tiles, sizeof(tiles), size_written) ==
1674
+ * CUBLAS_STATUS_SUCCESS) { num_tiles = size_written / sizeof(tiles[0]);
1675
+ * }
1676
+ *
1677
+ * \param[in] algo The algo descriptor
1678
+ * \param[in] attr The attribute
1679
+ * \param[out] buf memory address containing the new value
1680
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
1681
+ * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
1682
+ * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
1683
+ *
1684
+ * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
1685
+ * and buf is NULL or sizeInBytes doesn't match size of internal storage for
1686
+ * selected attribute
1687
+ * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
1688
+ */
1689
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCapGetAttribute(const cublasLtMatmulAlgo_t* algo,
1690
+ cublasLtMatmulAlgoCapAttributes_t attr,
1691
+ void* buf,
1692
+ size_t sizeInBytes,
1693
+ size_t* sizeWritten);
1694
+
1695
+ /** Algo Configuration Attributes that can be set according to the Algo capabilities
1696
+ */
1697
+ typedef enum {
1698
+ /** algorithm index, see cublasLtMatmulAlgoGetIds()
1699
+ *
1700
+ * readonly, set by cublasLtMatmulAlgoInit()
1701
+ * int32_t
1702
+ */
1703
+ CUBLASLT_ALGO_CONFIG_ID = 0,
1704
+ /** tile id, see cublasLtMatmulTile_t
1705
+ *
1706
+ * uint32_t, default: CUBLASLT_MATMUL_TILE_UNDEFINED
1707
+ */
1708
+ CUBLASLT_ALGO_CONFIG_TILE_ID = 1,
1709
+ /** Number of K splits. If the number of K splits is greater than one, SPLITK_NUM parts
1710
+ * of matrix multiplication will be computed in parallel. The results will be accumulated
1711
+ * according to CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME
1712
+ *
1713
+ * int32_t, default: 1
1714
+ */
1715
+ CUBLASLT_ALGO_CONFIG_SPLITK_NUM = 2,
1716
+ /** reduction scheme, see cublasLtReductionScheme_t
1717
+ *
1718
+ * uint32_t, default: CUBLASLT_REDUCTION_SCHEME_NONE
1719
+ */
1720
+ CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME = 3,
1721
+ /** cta swizzling, change mapping from CUDA grid coordinates to parts of the matrices
1722
+ *
1723
+ * possible values: 0, 1, other values reserved
1724
+ *
1725
+ * uint32_t, default: 0
1726
+ */
1727
+ CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING = 4,
1728
+ /** custom option, each algorithm can support some custom options that don't fit description of the other config
1729
+ * attributes, see CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX to get accepted range for any specific case
1730
+ *
1731
+ * uint32_t, default: 0
1732
+ */
1733
+ CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION = 5,
1734
+ /** stages id, see cublasLtMatmulStages_t
1735
+ *
1736
+ * uint32_t, default: CUBLASLT_MATMUL_STAGES_UNDEFINED
1737
+ */
1738
+ CUBLASLT_ALGO_CONFIG_STAGES_ID = 6,
1739
+ /** inner shape id, see cublasLtMatmulInnerShape_t
1740
+ *
1741
+ * uint16_t, default: 0 (CUBLASLT_MATMUL_INNER_SHAPE_UNDEFINED)
1742
+ */
1743
+ CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID = 7,
1744
+ /** Thread Block Cluster shape id, see cublasLtClusterShape_t. Defines cluster size to use.
1745
+ *
1746
+ * uint16_t, default: 0 (CUBLASLT_CLUSTER_SHAPE_AUTO)
1747
+ */
1748
+ CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID = 8,
1749
+ } cublasLtMatmulAlgoConfigAttributes_t;
1750
+
1751
+ /** Set algo configuration attribute.
1752
+ *
1753
+ * \param[in] algo The algo descriptor
1754
+ * \param[in] attr The attribute
1755
+ * \param[in] buf memory address containing the new value
1756
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
1757
+ *
1758
+ * \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
1759
+ * selected attribute
1760
+ * \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
1761
+ */
1762
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigSetAttribute(cublasLtMatmulAlgo_t* algo,
1763
+ cublasLtMatmulAlgoConfigAttributes_t attr,
1764
+ const void* buf,
1765
+ size_t sizeInBytes);
1766
+
1767
+ /** Get algo configuration attribute.
1768
+ *
1769
+ * \param[in] algo The algo descriptor
1770
+ * \param[in] attr The attribute
1771
+ * \param[out] buf memory address containing the new value
1772
+ * \param[in] sizeInBytes size of buf buffer for verification (in bytes)
1773
+ * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
1774
+ * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
1775
+ *
1776
+ * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
1777
+ * and buf is NULL or sizeInBytes doesn't match size of internal storage for
1778
+ * selected attribute
1779
+ * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
1780
+ */
1781
+ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigGetAttribute(const cublasLtMatmulAlgo_t* algo,
1782
+ cublasLtMatmulAlgoConfigAttributes_t attr,
1783
+ void* buf,
1784
+ size_t sizeInBytes,
1785
+ size_t* sizeWritten);
1786
+
1787
+ /** Experimental: Logger callback type.
1788
+ */
1789
+ typedef void (*cublasLtLoggerCallback_t)(int logLevel, const char* functionName, const char* message);
1790
+
1791
+ /** Experimental: Logger callback setter.
1792
+ *
1793
+ * \param[in] callback a user defined callback function to be called by the logger
1794
+ *
1795
+ * \retval CUBLAS_STATUS_SUCCESS if callback was set successfully
1796
+ */
1797
+ cublasStatus_t CUBLASWINAPI cublasLtLoggerSetCallback(cublasLtLoggerCallback_t callback);
1798
+
1799
+ /** Experimental: Log file setter.
1800
+ *
1801
+ * \param[in] file an open file with write permissions
1802
+ *
1803
+ * \retval CUBLAS_STATUS_SUCCESS if log file was set successfully
1804
+ */
1805
+ cublasStatus_t CUBLASWINAPI cublasLtLoggerSetFile(FILE* file);
1806
+
1807
+ /** Experimental: Open log file.
1808
+ *
1809
+ * \param[in] logFile log file path. if the log file does not exist, it will be created
1810
+ *
1811
+ * \retval CUBLAS_STATUS_SUCCESS if log file was created successfully
1812
+ */
1813
+ cublasStatus_t CUBLASWINAPI cublasLtLoggerOpenFile(const char* logFile);
1814
+
1815
+ /** Experimental: Log level setter.
1816
+ *
1817
+ * \param[in] level log level, should be one of the following:
1818
+ * 0. Off
1819
+ * 1. Errors
1820
+ * 2. Performance Trace
1821
+ * 3. Performance Hints
1822
+ * 4. Heuristics Trace
1823
+ * 5. API Trace
1824
+ *
1825
+ * \retval CUBLAS_STATUS_INVALID_VALUE if log level is not one of the above levels
1826
+ *
1827
+ * \retval CUBLAS_STATUS_SUCCESS if log level was set successfully
1828
+ */
1829
+ cublasStatus_t CUBLASWINAPI cublasLtLoggerSetLevel(int level);
1830
+
1831
+ /** Experimental: Log mask setter.
1832
+ *
1833
+ * \param[in] mask log mask, should be a combination of the following masks:
1834
+ * 0. Off
1835
+ * 1. Errors
1836
+ * 2. Performance Trace
1837
+ * 4. Performance Hints
1838
+ * 8. Heuristics Trace
1839
+ * 16. API Trace
1840
+ *
1841
+ * \retval CUBLAS_STATUS_SUCCESS if log mask was set successfully
1842
+ */
1843
+ cublasStatus_t CUBLASWINAPI cublasLtLoggerSetMask(int mask);
1844
+
1845
+ /** Experimental: Disable logging for the entire session.
1846
+ *
1847
+ * \retval CUBLAS_STATUS_SUCCESS if disabled logging
1848
+ */
1849
+ cublasStatus_t CUBLASWINAPI cublasLtLoggerForceDisable();
1850
+
1851
+ #if defined(__cplusplus)
1852
+ }
1853
+ #endif /* __cplusplus */
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_api.h ADDED
The diff for this file is too large to render. See raw diff
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_v2.h ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /*
51
+ * This is the public header file for the new CUBLAS library API, it mapped the generic
52
+ * Cublas name functions to the actual _v2 implementations.
53
+ */
54
+
55
+ #if !defined(CUBLAS_V2_H_)
56
+ #define CUBLAS_V2_H_
57
+
58
+ #undef CUBLASAPI
59
+ #ifdef __CUDACC__
60
+ #define CUBLASAPI __host__ __device__
61
+ #else
62
+ #define CUBLASAPI
63
+ #endif
64
+
65
+ #include "cublas_api.h"
66
+
67
+ #define cublasCreate cublasCreate_v2
68
+ #define cublasDestroy cublasDestroy_v2
69
+ #define cublasGetVersion cublasGetVersion_v2
70
+ #define cublasSetWorkspace cublasSetWorkspace_v2
71
+ #define cublasSetStream cublasSetStream_v2
72
+ #define cublasGetStream cublasGetStream_v2
73
+ #define cublasGetPointerMode cublasGetPointerMode_v2
74
+ #define cublasSetPointerMode cublasSetPointerMode_v2
75
+
76
+ /* Blas3 Routines */
77
+
78
+ #define cublasSnrm2 cublasSnrm2_v2
79
+ #define cublasDnrm2 cublasDnrm2_v2
80
+ #define cublasScnrm2 cublasScnrm2_v2
81
+ #define cublasDznrm2 cublasDznrm2_v2
82
+
83
+ #define cublasSdot cublasSdot_v2
84
+ #define cublasDdot cublasDdot_v2
85
+ #define cublasCdotu cublasCdotu_v2
86
+ #define cublasCdotc cublasCdotc_v2
87
+ #define cublasZdotu cublasZdotu_v2
88
+ #define cublasZdotc cublasZdotc_v2
89
+
90
+ #define cublasSscal cublasSscal_v2
91
+ #define cublasDscal cublasDscal_v2
92
+ #define cublasCscal cublasCscal_v2
93
+ #define cublasCsscal cublasCsscal_v2
94
+ #define cublasZscal cublasZscal_v2
95
+ #define cublasZdscal cublasZdscal_v2
96
+
97
+ #define cublasSaxpy cublasSaxpy_v2
98
+ #define cublasDaxpy cublasDaxpy_v2
99
+ #define cublasCaxpy cublasCaxpy_v2
100
+ #define cublasZaxpy cublasZaxpy_v2
101
+
102
+ #define cublasScopy cublasScopy_v2
103
+ #define cublasDcopy cublasDcopy_v2
104
+ #define cublasCcopy cublasCcopy_v2
105
+ #define cublasZcopy cublasZcopy_v2
106
+
107
+ #define cublasSswap cublasSswap_v2
108
+ #define cublasDswap cublasDswap_v2
109
+ #define cublasCswap cublasCswap_v2
110
+ #define cublasZswap cublasZswap_v2
111
+
112
+ #define cublasIsamax cublasIsamax_v2
113
+ #define cublasIdamax cublasIdamax_v2
114
+ #define cublasIcamax cublasIcamax_v2
115
+ #define cublasIzamax cublasIzamax_v2
116
+
117
+ #define cublasIsamin cublasIsamin_v2
118
+ #define cublasIdamin cublasIdamin_v2
119
+ #define cublasIcamin cublasIcamin_v2
120
+ #define cublasIzamin cublasIzamin_v2
121
+
122
+ #define cublasSasum cublasSasum_v2
123
+ #define cublasDasum cublasDasum_v2
124
+ #define cublasScasum cublasScasum_v2
125
+ #define cublasDzasum cublasDzasum_v2
126
+
127
+ #define cublasSrot cublasSrot_v2
128
+ #define cublasDrot cublasDrot_v2
129
+ #define cublasCrot cublasCrot_v2
130
+ #define cublasCsrot cublasCsrot_v2
131
+ #define cublasZrot cublasZrot_v2
132
+ #define cublasZdrot cublasZdrot_v2
133
+
134
+ #define cublasSrotg cublasSrotg_v2
135
+ #define cublasDrotg cublasDrotg_v2
136
+ #define cublasCrotg cublasCrotg_v2
137
+ #define cublasZrotg cublasZrotg_v2
138
+
139
+ #define cublasSrotm cublasSrotm_v2
140
+ #define cublasDrotm cublasDrotm_v2
141
+
142
+ #define cublasSrotmg cublasSrotmg_v2
143
+ #define cublasDrotmg cublasDrotmg_v2
144
+
145
+ /* Blas2 Routines */
146
+
147
+ #define cublasSgemv cublasSgemv_v2
148
+ #define cublasDgemv cublasDgemv_v2
149
+ #define cublasCgemv cublasCgemv_v2
150
+ #define cublasZgemv cublasZgemv_v2
151
+
152
+ #define cublasSgbmv cublasSgbmv_v2
153
+ #define cublasDgbmv cublasDgbmv_v2
154
+ #define cublasCgbmv cublasCgbmv_v2
155
+ #define cublasZgbmv cublasZgbmv_v2
156
+
157
+ #define cublasStrmv cublasStrmv_v2
158
+ #define cublasDtrmv cublasDtrmv_v2
159
+ #define cublasCtrmv cublasCtrmv_v2
160
+ #define cublasZtrmv cublasZtrmv_v2
161
+
162
+ #define cublasStbmv cublasStbmv_v2
163
+ #define cublasDtbmv cublasDtbmv_v2
164
+ #define cublasCtbmv cublasCtbmv_v2
165
+ #define cublasZtbmv cublasZtbmv_v2
166
+
167
+ #define cublasStpmv cublasStpmv_v2
168
+ #define cublasDtpmv cublasDtpmv_v2
169
+ #define cublasCtpmv cublasCtpmv_v2
170
+ #define cublasZtpmv cublasZtpmv_v2
171
+
172
+ #define cublasStrsv cublasStrsv_v2
173
+ #define cublasDtrsv cublasDtrsv_v2
174
+ #define cublasCtrsv cublasCtrsv_v2
175
+ #define cublasZtrsv cublasZtrsv_v2
176
+
177
+ #define cublasStpsv cublasStpsv_v2
178
+ #define cublasDtpsv cublasDtpsv_v2
179
+ #define cublasCtpsv cublasCtpsv_v2
180
+ #define cublasZtpsv cublasZtpsv_v2
181
+
182
+ #define cublasStbsv cublasStbsv_v2
183
+ #define cublasDtbsv cublasDtbsv_v2
184
+ #define cublasCtbsv cublasCtbsv_v2
185
+ #define cublasZtbsv cublasZtbsv_v2
186
+
187
+ #define cublasSsymv cublasSsymv_v2
188
+ #define cublasDsymv cublasDsymv_v2
189
+ #define cublasCsymv cublasCsymv_v2
190
+ #define cublasZsymv cublasZsymv_v2
191
+ #define cublasChemv cublasChemv_v2
192
+ #define cublasZhemv cublasZhemv_v2
193
+
194
+ #define cublasSsbmv cublasSsbmv_v2
195
+ #define cublasDsbmv cublasDsbmv_v2
196
+ #define cublasChbmv cublasChbmv_v2
197
+ #define cublasZhbmv cublasZhbmv_v2
198
+
199
+ #define cublasSspmv cublasSspmv_v2
200
+ #define cublasDspmv cublasDspmv_v2
201
+ #define cublasChpmv cublasChpmv_v2
202
+ #define cublasZhpmv cublasZhpmv_v2
203
+
204
+ #define cublasSger cublasSger_v2
205
+ #define cublasDger cublasDger_v2
206
+ #define cublasCgeru cublasCgeru_v2
207
+ #define cublasCgerc cublasCgerc_v2
208
+ #define cublasZgeru cublasZgeru_v2
209
+ #define cublasZgerc cublasZgerc_v2
210
+
211
+ #define cublasSsyr cublasSsyr_v2
212
+ #define cublasDsyr cublasDsyr_v2
213
+ #define cublasCsyr cublasCsyr_v2
214
+ #define cublasZsyr cublasZsyr_v2
215
+ #define cublasCher cublasCher_v2
216
+ #define cublasZher cublasZher_v2
217
+
218
+ #define cublasSspr cublasSspr_v2
219
+ #define cublasDspr cublasDspr_v2
220
+ #define cublasChpr cublasChpr_v2
221
+ #define cublasZhpr cublasZhpr_v2
222
+
223
+ #define cublasSsyr2 cublasSsyr2_v2
224
+ #define cublasDsyr2 cublasDsyr2_v2
225
+ #define cublasCsyr2 cublasCsyr2_v2
226
+ #define cublasZsyr2 cublasZsyr2_v2
227
+ #define cublasCher2 cublasCher2_v2
228
+ #define cublasZher2 cublasZher2_v2
229
+
230
+ #define cublasSspr2 cublasSspr2_v2
231
+ #define cublasDspr2 cublasDspr2_v2
232
+ #define cublasChpr2 cublasChpr2_v2
233
+ #define cublasZhpr2 cublasZhpr2_v2
234
+
235
+ /* Blas3 Routines */
236
+
237
+ #define cublasSgemm cublasSgemm_v2
238
+ #define cublasDgemm cublasDgemm_v2
239
+ #define cublasCgemm cublasCgemm_v2
240
+ #define cublasZgemm cublasZgemm_v2
241
+
242
+ #define cublasSsyrk cublasSsyrk_v2
243
+ #define cublasDsyrk cublasDsyrk_v2
244
+ #define cublasCsyrk cublasCsyrk_v2
245
+ #define cublasZsyrk cublasZsyrk_v2
246
+ #define cublasCherk cublasCherk_v2
247
+ #define cublasZherk cublasZherk_v2
248
+
249
+ #define cublasSsyr2k cublasSsyr2k_v2
250
+ #define cublasDsyr2k cublasDsyr2k_v2
251
+ #define cublasCsyr2k cublasCsyr2k_v2
252
+ #define cublasZsyr2k cublasZsyr2k_v2
253
+ #define cublasCher2k cublasCher2k_v2
254
+ #define cublasZher2k cublasZher2k_v2
255
+
256
+ #define cublasSsymm cublasSsymm_v2
257
+ #define cublasDsymm cublasDsymm_v2
258
+ #define cublasCsymm cublasCsymm_v2
259
+ #define cublasZsymm cublasZsymm_v2
260
+ #define cublasChemm cublasChemm_v2
261
+ #define cublasZhemm cublasZhemm_v2
262
+
263
+ #define cublasStrsm cublasStrsm_v2
264
+ #define cublasDtrsm cublasDtrsm_v2
265
+ #define cublasCtrsm cublasCtrsm_v2
266
+ #define cublasZtrsm cublasZtrsm_v2
267
+
268
+ #define cublasStrmm cublasStrmm_v2
269
+ #define cublasDtrmm cublasDtrmm_v2
270
+ #define cublasCtrmm cublasCtrmm_v2
271
+ #define cublasZtrmm cublasZtrmm_v2
272
+
273
+ #endif /* !defined(CUBLAS_V2_H_) */
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/include/cupti_pcsampling.h ADDED
@@ -0,0 +1,923 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2020-2022 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ #if !defined(_CUPTI_PCSAMPLING_H_)
51
+ #define _CUPTI_PCSAMPLING_H_
52
+
53
+ #include <cuda.h>
54
+ #include <stdint.h>
55
+ #include <stddef.h>
56
+ #include "cupti_result.h"
57
+
58
+ #ifndef CUPTIAPI
59
+ #ifdef _WIN32
60
+ #define CUPTIAPI __stdcall
61
+ #else
62
+ #define CUPTIAPI
63
+ #endif
64
+ #endif
65
+
66
+ #define ACTIVITY_RECORD_ALIGNMENT 8
67
+ #if defined(_WIN32) // Windows 32- and 64-bit
68
+ #define START_PACKED_ALIGNMENT __pragma(pack(push,1)) // exact fit - no padding
69
+ #define PACKED_ALIGNMENT __declspec(align(ACTIVITY_RECORD_ALIGNMENT))
70
+ #define END_PACKED_ALIGNMENT __pragma(pack(pop))
71
+ #elif defined(__GNUC__) // GCC
72
+ #define START_PACKED_ALIGNMENT
73
+ #define PACKED_ALIGNMENT __attribute__ ((__packed__)) __attribute__ ((aligned (ACTIVITY_RECORD_ALIGNMENT)))
74
+ #define END_PACKED_ALIGNMENT
75
+ #else // all other compilers
76
+ #define START_PACKED_ALIGNMENT
77
+ #define PACKED_ALIGNMENT
78
+ #define END_PACKED_ALIGNMENT
79
+ #endif
80
+
81
+ #if defined(__cplusplus)
82
+ extern "C" {
83
+ #endif
84
+
85
+ #if defined(__GNUC__) && defined(CUPTI_LIB)
86
+ #pragma GCC visibility push(default)
87
+ #endif
88
+
89
+ /**
90
+ * \defgroup CUPTI_PCSAMPLING_API CUPTI PC Sampling API
91
+ * Functions, types, and enums that implement the CUPTI PC Sampling API.
92
+ * @{
93
+ */
94
+
95
+ #ifndef CUPTI_PCSAMPLING_STRUCT_SIZE
96
+ #define CUPTI_PCSAMPLING_STRUCT_SIZE(type_, lastfield_) (offsetof(type_, lastfield_) + sizeof(((type_*)0)->lastfield_))
97
+ #endif
98
+
99
+ #ifndef CUPTI_STALL_REASON_STRING_SIZE
100
+ #define CUPTI_STALL_REASON_STRING_SIZE 128
101
+ #endif
102
+
103
+ /**
104
+ * \brief PC Sampling collection mode
105
+ */
106
+ typedef enum
107
+ {
108
+ /**
109
+ * INVALID Value
110
+ */
111
+ CUPTI_PC_SAMPLING_COLLECTION_MODE_INVALID = 0,
112
+ /**
113
+ * Continuous mode. Kernels are not serialized in this mode.
114
+ */
115
+ CUPTI_PC_SAMPLING_COLLECTION_MODE_CONTINUOUS = 1,
116
+ /**
117
+ * Serialized mode. Kernels are serialized in this mode.
118
+ */
119
+ CUPTI_PC_SAMPLING_COLLECTION_MODE_KERNEL_SERIALIZED = 2,
120
+ } CUpti_PCSamplingCollectionMode;
121
+
122
+ /**
123
+ * \brief PC Sampling stall reasons
124
+ */
125
+ typedef struct PACKED_ALIGNMENT
126
+ {
127
+ /**
128
+ * [r] Collected stall reason index
129
+ */
130
+ uint32_t pcSamplingStallReasonIndex;
131
+ /**
132
+ * [r] Number of times the PC was sampled with the stallReason.
133
+ */
134
+ uint32_t samples;
135
+ } CUpti_PCSamplingStallReason;
136
+
137
+ /**
138
+ * \brief PC Sampling data
139
+ */
140
+ typedef struct PACKED_ALIGNMENT
141
+ {
142
+ /**
143
+ * [w] Size of the data structure.
144
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
145
+ * available in the structure. Used to preserve backward compatibility.
146
+ */
147
+ size_t size;
148
+ /**
149
+ * [r] Unique cubin id
150
+ */
151
+ uint64_t cubinCrc;
152
+ /**
153
+ * [r] PC offset
154
+ */
155
+ uint64_t pcOffset;
156
+ /**
157
+ * The function's unique symbol index in the module.
158
+ */
159
+ uint32_t functionIndex;
160
+ /**
161
+ * Padding
162
+ */
163
+ uint32_t pad;
164
+ /**
165
+ * [r] The function name. This name string might be shared across all the records
166
+ * including records from activity APIs representing the same function, and so it should not be
167
+ * modified or freed until post processing of all the records is done. Once done, it is user’s responsibility to
168
+ * free the memory using free() function.
169
+ */
170
+ char* functionName;
171
+ /**
172
+ * [r] Collected stall reason count
173
+ */
174
+ size_t stallReasonCount;
175
+ /**
176
+ * [r] Stall reason id
177
+ * Total samples
178
+ */
179
+ CUpti_PCSamplingStallReason *stallReason;
180
+ } CUpti_PCSamplingPCData;
181
+
182
+ /**
183
+ * \brief PC Sampling output data format
184
+ */
185
+ typedef enum
186
+ {
187
+ CUPTI_PC_SAMPLING_OUTPUT_DATA_FORMAT_INVALID = 0,
188
+ /**
189
+ * HW buffer data will be parsed during collection of data
190
+ */
191
+ CUPTI_PC_SAMPLING_OUTPUT_DATA_FORMAT_PARSED = 1,
192
+ } CUpti_PCSamplingOutputDataFormat;
193
+
194
+ /**
195
+ * \brief Collected PC Sampling data
196
+ *
197
+ */
198
+ typedef struct PACKED_ALIGNMENT
199
+ {
200
+ /**
201
+ * [w] Size of the data structure.
202
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
203
+ * available in the structure. Used to preserve backward compatibility.
204
+ */
205
+ size_t size;
206
+ /**
207
+ * [w] Number of PCs to be collected
208
+ */
209
+ size_t collectNumPcs;
210
+ /**
211
+ * [r] Number of samples collected across all PCs.
212
+ * It includes samples for user modules, samples for non-user kernels and dropped samples.
213
+ * It includes counts for all non selected stall reasons.
214
+ * CUPTI does not provide PC records for non-user kernels.
215
+ * CUPTI does not provide PC records for instructions for which all selected stall reason metrics counts are zero.
216
+ */
217
+ uint64_t totalSamples;
218
+ /**
219
+ * [r] Number of samples that were dropped by hardware due to backpressure/overflow.
220
+ */
221
+ uint64_t droppedSamples;
222
+ /**
223
+ * [r] Number of PCs collected
224
+ */
225
+ size_t totalNumPcs;
226
+ /**
227
+ * [r] Number of PCs available for collection
228
+ */
229
+ size_t remainingNumPcs;
230
+ /**
231
+ * [r] Unique identifier for each range.
232
+ * Data collected across multiple ranges in multiple buffers can be identified using range id.
233
+ */
234
+ uint64_t rangeId;
235
+ /**
236
+ * [r] Profiled PC data
237
+ * This data struct should have enough memory to collect number of PCs mentioned in \brief collectNumPcs
238
+ */
239
+ CUpti_PCSamplingPCData *pPcData;
240
+ /**
241
+ * [r] Number of samples collected across all non user kernels PCs.
242
+ * It includes samples for non-user kernels.
243
+ * It includes counts for all non selected stall reasons as well.
244
+ * CUPTI does not provide PC records for non-user kernels.
245
+ */
246
+ uint64_t nonUsrKernelsTotalSamples;
247
+ } CUpti_PCSamplingData;
248
+
249
+ /**
250
+ * \brief PC Sampling configuration attributes
251
+ *
252
+ * PC Sampling configuration attribute types. These attributes can be read
253
+ * using \ref cuptiPCSamplingGetConfigurationAttribute and can be written
254
+ * using \ref cuptiPCSamplingSetConfigurationAttribute. Attributes marked
255
+ * [r] can only be read using \ref cuptiPCSamplingGetConfigurationAttribute
256
+ * [w] can only be written using \ref cuptiPCSamplingSetConfigurationAttribute
257
+ * [rw] can be read using \ref cuptiPCSamplingGetConfigurationAttribute and
258
+ * written using \ref cuptiPCSamplingSetConfigurationAttribute
259
+ */
260
+ typedef enum
261
+ {
262
+ CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_INVALID = 0,
263
+ /**
264
+ * [rw] Sampling period for PC Sampling.
265
+ * DEFAULT - CUPTI defined value based on number of SMs
266
+ * Valid values for the sampling
267
+ * periods are between 5 to 31 both inclusive. This will set the
268
+ * sampling period to (2^samplingPeriod) cycles.
269
+ * For e.g. for sampling period = 5 to 31, cycles = 32, 64, 128,..., 2^31
270
+ * Value is a uint32_t
271
+ */
272
+ CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_PERIOD = 1,
273
+ /**
274
+ * [w] Number of stall reasons to collect.
275
+ * DEFAULT - All stall reasons will be collected
276
+ * Value is a size_t
277
+ * [w] Stall reasons to collect
278
+ * DEFAULT - All stall reasons will be collected
279
+ * Input value should be a pointer pointing to array of stall reason indexes
280
+ * containing all the stall reason indexes to collect.
281
+ */
282
+ CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_STALL_REASON = 2,
283
+ /**
284
+ * [rw] Size of SW buffer for raw PC counter data downloaded from HW buffer
285
+ * DEFAULT - 1 MB, which can accommodate approximately 5500 PCs
286
+ * with all stall reasons
287
+ * Approximately it takes 16 Bytes (and some fixed size memory)
288
+ * to accommodate one PC with one stall reason
289
+ * For e.g. 1 PC with 1 stall reason = 32 Bytes
290
+ * 1 PC with 2 stall reason = 48 Bytes
291
+ * 1 PC with 4 stall reason = 96 Bytes
292
+ * Value is a size_t
293
+ */
294
+ CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SCRATCH_BUFFER_SIZE = 3,
295
+ /**
296
+ * [rw] Size of HW buffer in bytes
297
+ * DEFAULT - 512 MB
298
+ * If sampling period is too less, HW buffer can overflow
299
+ * and drop PC data
300
+ * Value is a size_t
301
+ */
302
+ CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_HARDWARE_BUFFER_SIZE = 4,
303
+ /**
304
+ * [rw] PC Sampling collection mode
305
+ * DEFAULT - CUPTI_PC_SAMPLING_COLLECTION_MODE_CONTINUOUS
306
+ * Input value should be of type \ref CUpti_PCSamplingCollectionMode.
307
+ */
308
+ CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_COLLECTION_MODE = 5,
309
+ /**
310
+ * [rw] Control over PC Sampling data collection range
311
+ * Default - 0
312
+ * 1 - Allows user to start and stop PC Sampling using APIs -
313
+ * \ref cuptiPCSamplingStart() - Start PC Sampling
314
+ * \ref cuptiPCSamplingStop() - Stop PC Sampling
315
+ * Value is a uint32_t
316
+ */
317
+ CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL = 6,
318
+ /**
319
+ * [w] Value for output data format
320
+ * Default - CUPTI_PC_SAMPLING_OUTPUT_DATA_FORMAT_PARSED
321
+ * Input value should be of type \ref CUpti_PCSamplingOutputDataFormat.
322
+ */
323
+ CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_OUTPUT_DATA_FORMAT = 7,
324
+ /**
325
+ * [w] Data buffer to hold collected PC Sampling data PARSED_DATA
326
+ * Default - none.
327
+ * Buffer type is void * which can point to PARSED_DATA
328
+ * Refer \ref CUpti_PCSamplingData for buffer format for PARSED_DATA
329
+ */
330
+ CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_DATA_BUFFER = 8,
331
+ CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_FORCE_INT = 0x7fffffff,
332
+ } CUpti_PCSamplingConfigurationAttributeType;
333
+
334
+ /**
335
+ * \brief PC sampling configuration information structure
336
+ *
337
+ * This structure provides \ref CUpti_PCSamplingConfigurationAttributeType which can be configured
338
+ * or queried for PC sampling configuration
339
+ */
340
+ typedef struct
341
+ {
342
+ /**
343
+ * Refer \ref CUpti_PCSamplingConfigurationAttributeType for all supported attribute types
344
+ */
345
+ CUpti_PCSamplingConfigurationAttributeType attributeType;
346
+ /*
347
+ * Configure or query status for \p attributeType
348
+ * CUPTI_SUCCESS for valid \p attributeType and \p attributeData
349
+ * CUPTI_ERROR_INVALID_OPERATION if \p attributeData is not valid
350
+ * CUPTI_ERROR_INVALID_PARAMETER if \p attributeType is not valid
351
+ */
352
+ CUptiResult attributeStatus;
353
+ union
354
+ {
355
+ /**
356
+ * Invalid Value
357
+ */
358
+ struct
359
+ {
360
+ uint64_t data[3];
361
+ } invalidData;
362
+ /**
363
+ * Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_PERIOD
364
+ */
365
+ struct
366
+ {
367
+ uint32_t samplingPeriod;
368
+ } samplingPeriodData;
369
+ /**
370
+ * Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_STALL_REASON
371
+ */
372
+ struct
373
+ {
374
+ size_t stallReasonCount;
375
+ uint32_t *pStallReasonIndex;
376
+ } stallReasonData;
377
+ /**
378
+ * Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SCRATCH_BUFFER_SIZE
379
+ */
380
+ struct
381
+ {
382
+ size_t scratchBufferSize;
383
+ } scratchBufferSizeData;
384
+ /**
385
+ * Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_HARDWARE_BUFFER_SIZE
386
+ */
387
+ struct
388
+ {
389
+ size_t hardwareBufferSize;
390
+ } hardwareBufferSizeData;
391
+ /**
392
+ * Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_COLLECTION_MODE
393
+ */
394
+ struct
395
+ {
396
+ CUpti_PCSamplingCollectionMode collectionMode;
397
+ } collectionModeData;
398
+ /**
399
+ * Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL
400
+ */
401
+ struct
402
+ {
403
+ uint32_t enableStartStopControl;
404
+ } enableStartStopControlData;
405
+ /**
406
+ * Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_OUTPUT_DATA_FORMAT
407
+ */
408
+ struct
409
+ {
410
+ CUpti_PCSamplingOutputDataFormat outputDataFormat;
411
+ } outputDataFormatData;
412
+ /**
413
+ * Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_DATA_BUFFER
414
+ */
415
+ struct
416
+ {
417
+ void *samplingDataBuffer;
418
+ } samplingDataBufferData;
419
+ } attributeData;
420
+ } CUpti_PCSamplingConfigurationInfo;
421
+
422
+ /**
423
+ * \brief PC sampling configuration structure
424
+ *
425
+ * This structure configures PC sampling using \ref cuptiPCSamplingSetConfigurationAttribute
426
+ * and queries PC sampling default configuration using \ref cuptiPCSamplingGetConfigurationAttribute
427
+ */
428
+ typedef struct
429
+ {
430
+ /**
431
+ * [w] Size of the data structure i.e. CUpti_PCSamplingConfigurationInfoParamsSize
432
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
433
+ * available in the structure. Used to preserve backward compatibility.
434
+ */
435
+ size_t size;
436
+ /**
437
+ * [w] Assign to NULL
438
+ */
439
+ void* pPriv;
440
+ /**
441
+ * [w] CUcontext
442
+ */
443
+ CUcontext ctx;
444
+ /**
445
+ * [w] Number of attributes to configure using \ref cuptiPCSamplingSetConfigurationAttribute or query
446
+ * using \ref cuptiPCSamplingGetConfigurationAttribute
447
+ */
448
+ size_t numAttributes;
449
+ /**
450
+ * Refer \ref CUpti_PCSamplingConfigurationInfo
451
+ */
452
+ CUpti_PCSamplingConfigurationInfo *pPCSamplingConfigurationInfo;
453
+ } CUpti_PCSamplingConfigurationInfoParams;
454
+ #define CUpti_PCSamplingConfigurationInfoParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingConfigurationInfoParams,pPCSamplingConfigurationInfo)
455
+
456
+ /**
457
+ * \brief Write PC Sampling configuration attribute.
458
+ *
459
+ * \param pParams A pointer to \ref CUpti_PCSamplingConfigurationInfoParams
460
+ * containing PC sampling configuration.
461
+ *
462
+ * \retval CUPTI_SUCCESS
463
+ * \retval CUPTI_ERROR_INVALID_OPERATION if this API is called with
464
+ * some invalid \p attrib.
465
+ * \retval CUPTI_ERROR_INVALID_PARAMETER if attribute \p value is not valid
466
+ * or any \p pParams is not valid
467
+ * \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
468
+ * does not support the API
469
+ */
470
+ CUptiResult CUPTIAPI cuptiPCSamplingSetConfigurationAttribute(CUpti_PCSamplingConfigurationInfoParams *pParams);
471
+
472
+ /**
473
+ * \brief Read PC Sampling configuration attribute.
474
+ *
475
+ * \param pParams A pointer to \ref CUpti_PCSamplingConfigurationInfoParams
476
+ * containing PC sampling configuration.
477
+ *
478
+ * \retval CUPTI_SUCCESS
479
+ * \retval CUPTI_ERROR_INVALID_OPERATION if this API is called with
480
+ * some invalid attribute.
481
+ * \retval CUPTI_ERROR_INVALID_PARAMETER if \p attrib is not valid
482
+ * or any \p pParams is not valid
483
+ * \retval CUPTI_ERROR_PARAMETER_SIZE_NOT_SUFFICIENT indicates that
484
+ * the \p value buffer is too small to hold the attribute value
485
+ * \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
486
+ * does not support the API
487
+ */
488
+ CUptiResult CUPTIAPI cuptiPCSamplingGetConfigurationAttribute(CUpti_PCSamplingConfigurationInfoParams *pParams);
489
+
490
+ /**
491
+ * \brief Params for cuptiPCSamplingEnable
492
+ */
493
+ typedef struct
494
+ {
495
+ /**
496
+ * [w] Size of the data structure i.e. CUpti_PCSamplingGetDataParamsSize
497
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
498
+ * available in the structure. Used to preserve backward compatibility.
499
+ */
500
+ size_t size;
501
+ /**
502
+ * [w] Assign to NULL
503
+ */
504
+ void* pPriv;
505
+ /**
506
+ * [w] CUcontext
507
+ */
508
+ CUcontext ctx;
509
+ /**
510
+ * \param pcSamplingData Data buffer to hold collected PC Sampling data PARSED_DATA
511
+ * Buffer type is void * which can point to PARSED_DATA
512
+ * Refer \ref CUpti_PCSamplingData for buffer format for PARSED_DATA
513
+ */
514
+ void *pcSamplingData;
515
+ } CUpti_PCSamplingGetDataParams;
516
+ #define CUpti_PCSamplingGetDataParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingGetDataParams, pcSamplingData)
517
+ /**
518
+ * \brief Flush GPU PC sampling data periodically.
519
+ *
520
+ * Flushing of GPU PC Sampling data is required at following point to maintain uniqueness of PCs:
521
+ * For \brief CUPTI_PC_SAMPLING_COLLECTION_MODE_CONTINUOUS, after every module load-unload-load
522
+ * For \brief CUPTI_PC_SAMPLING_COLLECTION_MODE_KERNEL_SERIALIZED, after every kernel ends
523
+ * If configuration option \brief CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL
524
+ * is enabled, then after every range end i.e. \brief cuptiPCSamplingStop()
525
+ *
526
+ * If application is profiled in \brief CUPTI_PC_SAMPLING_COLLECTION_MODE_CONTINUOUS, with disabled
527
+ * \brief CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL, and there is no module unload,
528
+ * user can collect data in two ways:
529
+ * Use \brief cuptiPCSamplingGetData() API periodically
530
+ * Use \brief cuptiPCSamplingDisable() on application exit and read GPU PC sampling data from sampling
531
+ * data buffer passed during configuration.
532
+ * Note: In case, \brief cuptiPCSamplingGetData() API is not called periodically, then sampling data buffer
533
+ * passed during configuration should be large enough to hold all PCs data.
534
+ * \brief cuptiPCSamplingGetData() API never does device synchronization.
535
+ * It is possible that when the API is called there is some unconsumed data from the HW buffer. In this case
536
+ * CUPTI provides only the data available with it at that moment.
537
+ *
538
+ * \param Refer \ref CUpti_PCSamplingGetDataParams
539
+ *
540
+ * \retval CUPTI_SUCCESS
541
+ * \retval CUPTI_ERROR_INVALID_OPERATION if this API is called without
542
+ * enabling PC sampling.
543
+ * \retval CUPTI_ERROR_INVALID_PARAMETER if any \p pParams is not valid
544
+ * \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
545
+ * does not support the API
546
+ */
547
+ CUptiResult CUPTIAPI cuptiPCSamplingGetData(CUpti_PCSamplingGetDataParams *pParams);
548
+
549
+ /**
550
+ * \brief Params for cuptiPCSamplingEnable
551
+ */
552
+ typedef struct
553
+ {
554
+ /**
555
+ * [w] Size of the data structure i.e. CUpti_PCSamplingEnableParamsSize
556
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
557
+ * available in the structure. Used to preserve backward compatibility.
558
+ */
559
+ size_t size;
560
+ /**
561
+ * [w] Assign to NULL
562
+ */
563
+ void* pPriv;
564
+ /**
565
+ * [w] CUcontext
566
+ */
567
+ CUcontext ctx;
568
+ } CUpti_PCSamplingEnableParams;
569
+ #define CUpti_PCSamplingEnableParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingEnableParams, ctx)
570
+
571
+ /**
572
+ * \brief Enable PC sampling.
573
+ *
574
+ * \param Refer \ref CUpti_PCSamplingEnableParams
575
+ *
576
+ * \retval CUPTI_SUCCESS
577
+ * \retval CUPTI_ERROR_INVALID_PARAMETER if any \p pParams is not valid
578
+ * \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
579
+ * does not support the API
580
+ */
581
+ CUptiResult CUPTIAPI cuptiPCSamplingEnable(CUpti_PCSamplingEnableParams *pParams);
582
+
583
+ /**
584
+ * \brief Params for cuptiPCSamplingDisable
585
+ */
586
+ typedef struct
587
+ {
588
+ /**
589
+ * [w] Size of the data structure i.e. CUpti_PCSamplingDisableParamsSize
590
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
591
+ * available in the structure. Used to preserve backward compatibility.
592
+ */
593
+ size_t size;
594
+ /**
595
+ * [w] Assign to NULL
596
+ */
597
+ void* pPriv;
598
+ /**
599
+ * [w] CUcontext
600
+ */
601
+ CUcontext ctx;
602
+ } CUpti_PCSamplingDisableParams;
603
+ #define CUpti_PCSamplingDisableParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingDisableParams, ctx)
604
+
605
+ /**
606
+ * \brief Disable PC sampling.
607
+ *
608
+ * For application which doesn't destroy the CUDA context explicitly,
609
+ * this API does the PC Sampling tear-down, joins threads and copies PC records in the buffer provided
610
+ * during the PC sampling configuration. PC records which can't be accommodated in the buffer are discarded.
611
+ *
612
+ * \param Refer \ref CUpti_PCSamplingDisableParams
613
+ *
614
+ * \retval CUPTI_SUCCESS
615
+ * \retval CUPTI_ERROR_INVALID_PARAMETER if any \p pParams is not valid
616
+ * \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
617
+ * does not support the API
618
+ */
619
+ CUptiResult CUPTIAPI cuptiPCSamplingDisable(CUpti_PCSamplingDisableParams *pParams);
620
+
621
+ /**
622
+ * \brief Params for cuptiPCSamplingStart
623
+ */
624
+ typedef struct
625
+ {
626
+ /**
627
+ * [w] Size of the data structure i.e. CUpti_PCSamplingStartParamsSize
628
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
629
+ * available in the structure. Used to preserve backward compatibility.
630
+ */
631
+ size_t size;
632
+ /**
633
+ * [w] Assign to NULL
634
+ */
635
+ void* pPriv;
636
+ /**
637
+ * [w] CUcontext
638
+ */
639
+ CUcontext ctx;
640
+ } CUpti_PCSamplingStartParams;
641
+ #define CUpti_PCSamplingStartParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingStartParams, ctx)
642
+
643
+ /**
644
+ * \brief Start PC sampling.
645
+ *
646
+ * User can collect PC Sampling data for user-defined range specified by Start/Stop APIs.
647
+ * This API can be used to mark starting of range. Set configuration option
648
+ * \brief CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL to use this API.
649
+ *
650
+ * \param Refer \ref CUpti_PCSamplingStartParams
651
+ *
652
+ * \retval CUPTI_SUCCESS
653
+ * \retval CUPTI_ERROR_INVALID_OPERATION if this API is called with
654
+ * incorrect PC Sampling configuration.
655
+ * \retval CUPTI_ERROR_INVALID_PARAMETER if any \p pParams is not valid
656
+ * \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
657
+ * does not support the API
658
+ */
659
+ CUptiResult CUPTIAPI cuptiPCSamplingStart(CUpti_PCSamplingStartParams *pParams);
660
+
661
+ /**
662
+ * \brief Params for cuptiPCSamplingStop
663
+ */
664
+ typedef struct
665
+ {
666
+ /**
667
+ * [w] Size of the data structure i.e. CUpti_PCSamplingStopParamsSize
668
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
669
+ * available in the structure. Used to preserve backward compatibility.
670
+ */
671
+ size_t size;
672
+ /**
673
+ * [w] Assign to NULL
674
+ */
675
+ void* pPriv;
676
+ /**
677
+ * [w] CUcontext
678
+ */
679
+ CUcontext ctx;
680
+ } CUpti_PCSamplingStopParams;
681
+ #define CUpti_PCSamplingStopParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingStopParams, ctx)
682
+
683
+ /**
684
+ * \brief Stop PC sampling.
685
+ *
686
+ * User can collect PC Sampling data for user-defined range specified by Start/Stop APIs.
687
+ * This API can be used to mark end of range. Set configuration option
688
+ * \brief CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL to use this API.
689
+ *
690
+ * \param Refer \ref CUpti_PCSamplingStopParams
691
+ *
692
+ * \retval CUPTI_SUCCESS
693
+ * \retval CUPTI_ERROR_INVALID_OPERATION if this API is called with
694
+ * incorrect PC Sampling configuration.
695
+ * \retval CUPTI_ERROR_INVALID_PARAMETER if any \p pParams is not valid
696
+ * \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
697
+ * does not support the API
698
+ */
699
+ CUptiResult CUPTIAPI cuptiPCSamplingStop(CUpti_PCSamplingStopParams *pParams);
700
+
701
+ /**
702
+ * \brief Params for cuptiPCSamplingGetNumStallReasons
703
+ */
704
+ typedef struct
705
+ {
706
+ /**
707
+ * [w] Size of the data structure i.e. CUpti_PCSamplingGetNumStallReasonsParamsSize
708
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
709
+ * available in the structure. Used to preserve backward compatibility.
710
+ */
711
+ size_t size;
712
+ /**
713
+ * [w] Assign to NULL
714
+ */
715
+ void* pPriv;
716
+ /**
717
+ * [w] CUcontext
718
+ */
719
+ CUcontext ctx;
720
+ /**
721
+ * [r] Number of stall reasons
722
+ */
723
+ size_t *numStallReasons;
724
+ } CUpti_PCSamplingGetNumStallReasonsParams;
725
+ #define CUpti_PCSamplingGetNumStallReasonsParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingGetNumStallReasonsParams, numStallReasons)
726
+
727
+ /**
728
+ * \brief Get PC sampling stall reason count.
729
+ *
730
+ * \param Refer \ref CUpti_PCSamplingGetNumStallReasonsParams
731
+ *
732
+ * \retval CUPTI_SUCCESS
733
+ * \retval CUPTI_ERROR_INVALID_PARAMETER if any \p pParams is not valid
734
+ * \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
735
+ * does not support the API
736
+ */
737
+ CUptiResult CUPTIAPI cuptiPCSamplingGetNumStallReasons(CUpti_PCSamplingGetNumStallReasonsParams *pParams);
738
+
739
+ /**
740
+ * \brief Params for cuptiPCSamplingGetStallReasons
741
+ */
742
+ typedef struct
743
+ {
744
+ /**
745
+ * [w] Size of the data structure i.e. CUpti_PCSamplingGetStallReasonsParamsSize
746
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
747
+ * available in the structure. Used to preserve backward compatibility.
748
+ */
749
+ size_t size;
750
+ /**
751
+ * [w] Assign to NULL
752
+ */
753
+ void* pPriv;
754
+ /**
755
+ * [w] CUcontext
756
+ */
757
+ CUcontext ctx;
758
+ /**
759
+ * [w] Number of stall reasons
760
+ */
761
+ size_t numStallReasons;
762
+ /**
763
+ * [r] Stall reason index
764
+ */
765
+ uint32_t *stallReasonIndex;
766
+ /**
767
+ * [r] Stall reasons name
768
+ */
769
+ char **stallReasons;
770
+ } CUpti_PCSamplingGetStallReasonsParams;
771
+ #define CUpti_PCSamplingGetStallReasonsParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingGetStallReasonsParams, stallReasons)
772
+
773
+ /**
774
+ * \brief Get PC sampling stall reasons.
775
+ *
776
+ * \param Refer \ref CUpti_PCSamplingGetStallReasonsParams
777
+ *
778
+ * \retval CUPTI_SUCCESS
779
+ * \retval CUPTI_ERROR_INVALID_PARAMETER if any \p pParams is not valid
780
+ * \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
781
+ * does not support the API
782
+ */
783
+ CUptiResult CUPTIAPI cuptiPCSamplingGetStallReasons(CUpti_PCSamplingGetStallReasonsParams *pParams);
784
+
785
+ /**
786
+ * \brief Params for cuptiGetSassToSourceCorrelation
787
+ */
788
+ typedef struct {
789
+ /**
790
+ * [w] Size of the data structure i.e. CUpti_GetSassToSourceCorrelationParamsSize
791
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
792
+ * available in the structure. Used to preserve backward compatibility.
793
+ */
794
+ size_t size;
795
+ /**
796
+ * [w] Pointer to cubin binary where function belongs.
797
+ */
798
+ const void* cubin;
799
+ /**
800
+ * [w] Function name to which PC belongs.
801
+ */
802
+ const char *functionName;
803
+ /**
804
+ * [w] Size of cubin binary.
805
+ */
806
+ size_t cubinSize;
807
+ /**
808
+ * [r] Line number in the source code.
809
+ */
810
+ uint32_t lineNumber;
811
+ /**
812
+ * [w] PC offset
813
+ */
814
+ uint64_t pcOffset;
815
+ /**
816
+ * [r] Path for the source file.
817
+ */
818
+ char *fileName;
819
+ /**
820
+ * [r] Path for the directory of source file.
821
+ */
822
+ char *dirName;
823
+ } CUpti_GetSassToSourceCorrelationParams;
824
+ #define CUpti_GetSassToSourceCorrelationParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_GetSassToSourceCorrelationParams, dirName)
825
+
826
+ /**
827
+ * \brief SASS to Source correlation.
828
+ *
829
+ * \param Refer \ref CUpti_GetSassToSourceCorrelationParams
830
+ *
831
+ * It is expected from user to free allocated memory for fileName and dirName after use.
832
+ *
833
+ * \retval CUPTI_SUCCESS
834
+ * \retval CUPTI_ERROR_INVALID_PARAMETER if either of the parameters cubin or functionName
835
+ * is NULL or cubinSize is zero or size field is not set correctly.
836
+ * \retval CUPTI_ERROR_INVALID_MODULE provided cubin is invalid.
837
+ * \retval CUPTI_ERROR_UNKNOWN an internal error occurred.
838
+ * This error code is also used for cases when the function is not present in the module.
839
+ * A better error code will be returned in the future release.
840
+ */
841
+ CUptiResult CUPTIAPI cuptiGetSassToSourceCorrelation(CUpti_GetSassToSourceCorrelationParams *pParams);
842
+
843
+ /**
844
+ * \brief Params for cuptiGetCubinCrc
845
+ */
846
+ typedef struct {
847
+ /**
848
+ * [w] Size of configuration structure.
849
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
850
+ * available in the structure. Used to preserve backward compatibility.
851
+ */
852
+ size_t size;
853
+ /**
854
+ * [w] Size of cubin binary.
855
+ */
856
+ size_t cubinSize;
857
+ /**
858
+ * [w] Pointer to cubin binary
859
+ */
860
+ const void* cubin;
861
+ /**
862
+ * [r] Computed CRC will be stored in it.
863
+ */
864
+ uint64_t cubinCrc;
865
+ } CUpti_GetCubinCrcParams;
866
+ #define CUpti_GetCubinCrcParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_GetCubinCrcParams, cubinCrc)
867
+
868
+ /**
869
+ * \brief Get the CRC of cubin.
870
+ *
871
+ * This function returns the CRC of provided cubin binary.
872
+ *
873
+ * \param Refer \ref CUpti_GetCubinCrcParams
874
+ *
875
+ * \retval CUPTI_SUCCESS
876
+ * \retval CUPTI_ERROR_INVALID_PARAMETER if parameter cubin is NULL or
877
+ * provided cubinSize is zero or size field is not set.
878
+ */
879
+ CUptiResult CUPTIAPI cuptiGetCubinCrc(CUpti_GetCubinCrcParams *pParams);
880
+
881
+ /**
882
+ * \brief Function type for callback used by CUPTI to request crc of
883
+ * loaded module.
884
+ *
885
+ * This callback function ask for crc of provided module in function.
886
+ * The provided crc will be stored in PC sampling records i.e. in the field 'cubinCrc' of the PC sampling
887
+ * struct CUpti_PCSamplingPCData. The CRC is uses during the offline source correlation to uniquely identify the module.
888
+ *
889
+ * \param cubin The pointer to cubin binary
890
+ * \param cubinSize The size of cubin binary.
891
+ * \param cubinCrc Returns the computed crc of cubin.
892
+ */
893
+ typedef void (CUPTIAPI *CUpti_ComputeCrcCallbackFunc)(
894
+ const void* cubin,
895
+ size_t cubinSize,
896
+ uint64_t *cubinCrc);
897
+
898
+ /**
899
+ * \brief Register callback function with CUPTI to use
900
+ * your own algorithm to compute cubin crc.
901
+ *
902
+ * This function registers a callback function and it gets called
903
+ * from CUPTI when a CUDA module is loaded.
904
+ *
905
+ * \param funcComputeCubinCrc callback is invoked when a CUDA module
906
+ * is loaded.
907
+ *
908
+ * \retval CUPTI_SUCCESS
909
+ * \retval CUPTI_ERROR_INVALID_PARAMETER if \p funcComputeCubinCrc is NULL.
910
+ */
911
+ CUptiResult CUPTIAPI cuptiRegisterComputeCrcCallback(CUpti_ComputeCrcCallbackFunc funcComputeCubinCrc);
912
+
913
+ /** @} */ /* END CUPTI_PCSAMPLING_API */
914
+
915
+ #if defined(__GNUC__) && defined(CUPTI_LIB)
916
+ #pragma GCC visibility pop
917
+ #endif
918
+
919
+ #if defined(__cplusplus)
920
+ }
921
+ #endif
922
+
923
+ #endif /*_CUPTI_PCSAMPLING_H_*/
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (220 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn_infer_v8.h ADDED
@@ -0,0 +1,571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2017-2022 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /*
51
+ * cudnn_cnn_infer : cuDNN's basic definitions and inference CNN functions.
52
+ */
53
+
54
+ #if !defined(CUDNN_CNN_INFER_H_)
55
+ #define CUDNN_CNN_INFER_H_
56
+
57
+ #pragma once
58
+ #include <cuda_runtime.h>
59
+ #include <stdint.h>
60
+
61
+ #include "cudnn_version.h"
62
+ #include "cudnn_ops_infer.h"
63
+
64
+ /* These version numbers are autogenerated, do not edit manually. */
65
+ #define CUDNN_CNN_INFER_MAJOR 8
66
+ #define CUDNN_CNN_INFER_MINOR 7
67
+ #define CUDNN_CNN_INFER_PATCH 0
68
+
69
+ #if (CUDNN_CNN_INFER_MAJOR != CUDNN_MAJOR) || (CUDNN_CNN_INFER_MINOR != CUDNN_MINOR) || \
70
+ (CUDNN_CNN_INFER_PATCH != CUDNN_PATCHLEVEL)
71
+ #error Version mismatch in cuDNN CNN INFER!!!
72
+ #endif
73
+
74
+ #if defined(__cplusplus)
75
+ extern "C" {
76
+ #endif
77
+
78
+ typedef struct cudnnConvolutionStruct *cudnnConvolutionDescriptor_t;
79
+
80
+ /*
81
+ * convolution mode
82
+ */
83
+ typedef enum { CUDNN_CONVOLUTION = 0, CUDNN_CROSS_CORRELATION = 1 } cudnnConvolutionMode_t;
84
+
85
+ /*
86
+ * CUDNN Reorder
87
+ */
88
+ typedef enum {
89
+ CUDNN_DEFAULT_REORDER = 0,
90
+ CUDNN_NO_REORDER = 1,
91
+ } cudnnReorderType_t;
92
+
93
+ typedef struct cudnnConvolutionFwdAlgoPerfStruct {
94
+ cudnnConvolutionFwdAlgo_t algo;
95
+ cudnnStatus_t status;
96
+ float time;
97
+ size_t memory;
98
+ cudnnDeterminism_t determinism;
99
+ cudnnMathType_t mathType;
100
+ int reserved[3];
101
+ } cudnnConvolutionFwdAlgoPerf_t;
102
+
103
+ /* Create an instance of convolution descriptor */
104
+ cudnnStatus_t CUDNNWINAPI
105
+ cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t *convDesc);
106
+
107
+ /* Destroy an instance of convolution descriptor */
108
+ cudnnStatus_t CUDNNWINAPI
109
+ cudnnDestroyConvolutionDescriptor(cudnnConvolutionDescriptor_t convDesc);
110
+
111
+ cudnnStatus_t CUDNNWINAPI
112
+ cudnnSetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t mathType);
113
+
114
+ cudnnStatus_t CUDNNWINAPI
115
+ cudnnGetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t *mathType);
116
+
117
+ cudnnStatus_t CUDNNWINAPI
118
+ cudnnSetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int groupCount);
119
+
120
+ cudnnStatus_t CUDNNWINAPI
121
+ cudnnGetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int *groupCount);
122
+
123
+ cudnnStatus_t CUDNNWINAPI
124
+ cudnnSetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t reorderType);
125
+
126
+ cudnnStatus_t CUDNNWINAPI
127
+ cudnnGetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t *reorderType);
128
+
129
+ cudnnStatus_t CUDNNWINAPI
130
+ cudnnSetConvolution2dDescriptor(cudnnConvolutionDescriptor_t convDesc,
131
+ int pad_h, /* zero-padding height */
132
+ int pad_w, /* zero-padding width */
133
+ int u, /* vertical filter stride */
134
+ int v, /* horizontal filter stride */
135
+ int dilation_h, /* filter dilation in the vertical dimension */
136
+ int dilation_w, /* filter dilation in the horizontal dimension */
137
+ cudnnConvolutionMode_t mode,
138
+ cudnnDataType_t computeType);
139
+
140
+ cudnnStatus_t CUDNNWINAPI
141
+ cudnnGetConvolution2dDescriptor(const cudnnConvolutionDescriptor_t convDesc,
142
+ int *pad_h, /* zero-padding height */
143
+ int *pad_w, /* zero-padding width */
144
+ int *u, /* vertical filter stride */
145
+ int *v, /* horizontal filter stride */
146
+ int *dilation_h, /* filter dilation in the vertical dimension */
147
+ int *dilation_w, /* filter dilation in the horizontal dimension */
148
+ cudnnConvolutionMode_t *mode,
149
+ cudnnDataType_t *computeType);
150
+
151
+ cudnnStatus_t CUDNNWINAPI
152
+ cudnnSetConvolutionNdDescriptor(cudnnConvolutionDescriptor_t convDesc,
153
+ int arrayLength, /* nbDims-2 size */
154
+ const int padA[],
155
+ const int filterStrideA[],
156
+ const int dilationA[],
157
+ cudnnConvolutionMode_t mode,
158
+ cudnnDataType_t computeType); /* convolution data type */
159
+
160
+ /* Helper function to return the dimensions of the output tensor given a convolution descriptor */
161
+ cudnnStatus_t CUDNNWINAPI
162
+ cudnnGetConvolutionNdDescriptor(const cudnnConvolutionDescriptor_t convDesc,
163
+ int arrayLengthRequested,
164
+ int *arrayLength,
165
+ int padA[],
166
+ int strideA[],
167
+ int dilationA[],
168
+ cudnnConvolutionMode_t *mode,
169
+ cudnnDataType_t *computeType); /* convolution data type */
170
+
171
+ cudnnStatus_t CUDNNWINAPI
172
+ cudnnGetConvolution2dForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc,
173
+ const cudnnTensorDescriptor_t inputTensorDesc,
174
+ const cudnnFilterDescriptor_t filterDesc,
175
+ int *n,
176
+ int *c,
177
+ int *h,
178
+ int *w);
179
+
180
+ /* Helper function to return the dimensions of the output tensor given a convolution descriptor */
181
+ cudnnStatus_t CUDNNWINAPI
182
+ cudnnGetConvolutionNdForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc,
183
+ const cudnnTensorDescriptor_t inputTensorDesc,
184
+ const cudnnFilterDescriptor_t filterDesc,
185
+ int nbDims,
186
+ int tensorOuputDimA[]);
187
+
188
+ /* helper function to provide the convolution forward algo that fit best the requirement */
189
+ cudnnStatus_t CUDNNWINAPI
190
+ cudnnGetConvolutionForwardAlgorithmMaxCount(cudnnHandle_t handle, int *count);
191
+
192
+ cudnnStatus_t CUDNNWINAPI
193
+ cudnnGetConvolutionForwardAlgorithm_v7(cudnnHandle_t handle,
194
+ const cudnnTensorDescriptor_t srcDesc,
195
+ const cudnnFilterDescriptor_t filterDesc,
196
+ const cudnnConvolutionDescriptor_t convDesc,
197
+ const cudnnTensorDescriptor_t destDesc,
198
+ const int requestedAlgoCount,
199
+ int *returnedAlgoCount,
200
+ cudnnConvolutionFwdAlgoPerf_t *perfResults);
201
+
202
+ cudnnStatus_t CUDNNWINAPI
203
+ cudnnFindConvolutionForwardAlgorithm(cudnnHandle_t handle,
204
+ const cudnnTensorDescriptor_t xDesc,
205
+ const cudnnFilterDescriptor_t wDesc,
206
+ const cudnnConvolutionDescriptor_t convDesc,
207
+ const cudnnTensorDescriptor_t yDesc,
208
+ const int requestedAlgoCount,
209
+ int *returnedAlgoCount,
210
+ cudnnConvolutionFwdAlgoPerf_t *perfResults);
211
+
212
+ cudnnStatus_t CUDNNWINAPI
213
+ cudnnFindConvolutionForwardAlgorithmEx(cudnnHandle_t handle,
214
+ const cudnnTensorDescriptor_t xDesc,
215
+ const void *x,
216
+ const cudnnFilterDescriptor_t wDesc,
217
+ const void *w,
218
+ const cudnnConvolutionDescriptor_t convDesc,
219
+ const cudnnTensorDescriptor_t yDesc,
220
+ void *y,
221
+ const int requestedAlgoCount,
222
+ int *returnedAlgoCount,
223
+ cudnnConvolutionFwdAlgoPerf_t *perfResults,
224
+ void *workSpace,
225
+ size_t workSpaceSizeInBytes);
226
+
227
+ cudnnStatus_t CUDNNWINAPI
228
+ cudnnIm2Col(cudnnHandle_t handle,
229
+ const cudnnTensorDescriptor_t xDesc,
230
+ const void *x,
231
+ const cudnnFilterDescriptor_t wDesc,
232
+ const cudnnConvolutionDescriptor_t convDesc,
233
+ void *colBuffer);
234
+
235
+ cudnnStatus_t CUDNNWINAPI
236
+ cudnnReorderFilterAndBias(cudnnHandle_t handle,
237
+ const cudnnFilterDescriptor_t filterDesc,
238
+ cudnnReorderType_t reorderType,
239
+ const void *filterData,
240
+ void *reorderedFilterData,
241
+ int reorderBias,
242
+ const void *biasData,
243
+ void *reorderedBiasData);
244
+
245
+ /* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
246
+ cudnnStatus_t CUDNNWINAPI
247
+ cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle_t handle,
248
+ const cudnnTensorDescriptor_t xDesc,
249
+ const cudnnFilterDescriptor_t wDesc,
250
+ const cudnnConvolutionDescriptor_t convDesc,
251
+ const cudnnTensorDescriptor_t yDesc,
252
+ cudnnConvolutionFwdAlgo_t algo,
253
+ size_t *sizeInBytes);
254
+
255
+ /* Convolution functions: All of the form "output = alpha * Op(inputs) + beta * output" */
256
+
257
+ /* Function to perform the forward pass for batch convolution */
258
+ cudnnStatus_t CUDNNWINAPI
259
+ cudnnConvolutionForward(cudnnHandle_t handle,
260
+ const void *alpha,
261
+ const cudnnTensorDescriptor_t xDesc,
262
+ const void *x,
263
+ const cudnnFilterDescriptor_t wDesc,
264
+ const void *w,
265
+ const cudnnConvolutionDescriptor_t convDesc,
266
+ cudnnConvolutionFwdAlgo_t algo,
267
+ void *workSpace,
268
+ size_t workSpaceSizeInBytes,
269
+ const void *beta,
270
+ const cudnnTensorDescriptor_t yDesc,
271
+ void *y);
272
+
273
+ /* Fused conv/bias/activation operation : y = Act( alpha1 * conv(x) + alpha2 * z + bias ) */
274
+ cudnnStatus_t CUDNNWINAPI
275
+ cudnnConvolutionBiasActivationForward(cudnnHandle_t handle,
276
+ const void *alpha1,
277
+ const cudnnTensorDescriptor_t xDesc,
278
+ const void *x,
279
+ const cudnnFilterDescriptor_t wDesc,
280
+ const void *w,
281
+ const cudnnConvolutionDescriptor_t convDesc,
282
+ cudnnConvolutionFwdAlgo_t algo,
283
+ void *workSpace,
284
+ size_t workSpaceSizeInBytes,
285
+ const void *alpha2,
286
+ const cudnnTensorDescriptor_t zDesc,
287
+ const void *z,
288
+ const cudnnTensorDescriptor_t biasDesc,
289
+ const void *bias,
290
+ const cudnnActivationDescriptor_t activationDesc,
291
+ const cudnnTensorDescriptor_t yDesc,
292
+ void *y);
293
+
294
+ /* helper function to provide the convolution backward data algo that fit best the requirement */
295
+
296
+ typedef struct cudnnConvolutionBwdDataAlgoPerfStruct {
297
+ cudnnConvolutionBwdDataAlgo_t algo;
298
+ cudnnStatus_t status;
299
+ float time;
300
+ size_t memory;
301
+ cudnnDeterminism_t determinism;
302
+ cudnnMathType_t mathType;
303
+ int reserved[3];
304
+ } cudnnConvolutionBwdDataAlgoPerf_t;
305
+
306
+ cudnnStatus_t CUDNNWINAPI
307
+ cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnnHandle_t handle, int *count);
308
+
309
+ cudnnStatus_t CUDNNWINAPI
310
+ cudnnFindConvolutionBackwardDataAlgorithm(cudnnHandle_t handle,
311
+ const cudnnFilterDescriptor_t wDesc,
312
+ const cudnnTensorDescriptor_t dyDesc,
313
+ const cudnnConvolutionDescriptor_t convDesc,
314
+ const cudnnTensorDescriptor_t dxDesc,
315
+ const int requestedAlgoCount,
316
+ int *returnedAlgoCount,
317
+ cudnnConvolutionBwdDataAlgoPerf_t *perfResults);
318
+
319
+ cudnnStatus_t CUDNNWINAPI
320
+ cudnnFindConvolutionBackwardDataAlgorithmEx(cudnnHandle_t handle,
321
+ const cudnnFilterDescriptor_t wDesc,
322
+ const void *w,
323
+ const cudnnTensorDescriptor_t dyDesc,
324
+ const void *dy,
325
+ const cudnnConvolutionDescriptor_t convDesc,
326
+ const cudnnTensorDescriptor_t dxDesc,
327
+ void *dx,
328
+ const int requestedAlgoCount,
329
+ int *returnedAlgoCount,
330
+ cudnnConvolutionBwdDataAlgoPerf_t *perfResults,
331
+ void *workSpace,
332
+ size_t workSpaceSizeInBytes);
333
+
334
+ cudnnStatus_t CUDNNWINAPI
335
+ cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnnHandle_t handle,
336
+ const cudnnFilterDescriptor_t filterDesc,
337
+ const cudnnTensorDescriptor_t diffDesc,
338
+ const cudnnConvolutionDescriptor_t convDesc,
339
+ const cudnnTensorDescriptor_t gradDesc,
340
+ const int requestedAlgoCount,
341
+ int *returnedAlgoCount,
342
+ cudnnConvolutionBwdDataAlgoPerf_t *perfResults);
343
+
344
+ /*
345
+ * convolution algorithm (which requires potentially some workspace)
346
+ */
347
+
348
+ /* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
349
+ cudnnStatus_t CUDNNWINAPI
350
+ cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnHandle_t handle,
351
+ const cudnnFilterDescriptor_t wDesc,
352
+ const cudnnTensorDescriptor_t dyDesc,
353
+ const cudnnConvolutionDescriptor_t convDesc,
354
+ const cudnnTensorDescriptor_t dxDesc,
355
+ cudnnConvolutionBwdDataAlgo_t algo,
356
+ size_t *sizeInBytes);
357
+
358
+ cudnnStatus_t CUDNNWINAPI
359
+ cudnnConvolutionBackwardData(cudnnHandle_t handle,
360
+ const void *alpha,
361
+ const cudnnFilterDescriptor_t wDesc,
362
+ const void *w,
363
+ const cudnnTensorDescriptor_t dyDesc,
364
+ const void *dy,
365
+ const cudnnConvolutionDescriptor_t convDesc,
366
+ cudnnConvolutionBwdDataAlgo_t algo,
367
+ void *workSpace,
368
+ size_t workSpaceSizeInBytes,
369
+ const void *beta,
370
+ const cudnnTensorDescriptor_t dxDesc,
371
+ void *dx);
372
+
373
+ /* Helper function to calculate folding descriptors for dgrad */
374
+ cudnnStatus_t CUDNNWINAPI
375
+ cudnnGetFoldedConvBackwardDataDescriptors(const cudnnHandle_t handle,
376
+ const cudnnFilterDescriptor_t filterDesc,
377
+ const cudnnTensorDescriptor_t diffDesc,
378
+ const cudnnConvolutionDescriptor_t convDesc,
379
+ const cudnnTensorDescriptor_t gradDesc,
380
+ const cudnnTensorFormat_t transformFormat,
381
+ cudnnFilterDescriptor_t foldedFilterDesc,
382
+ cudnnTensorDescriptor_t paddedDiffDesc,
383
+ cudnnConvolutionDescriptor_t foldedConvDesc,
384
+ cudnnTensorDescriptor_t foldedGradDesc,
385
+ cudnnTensorTransformDescriptor_t filterFoldTransDesc,
386
+ cudnnTensorTransformDescriptor_t diffPadTransDesc,
387
+ cudnnTensorTransformDescriptor_t gradFoldTransDesc,
388
+ cudnnTensorTransformDescriptor_t gradUnfoldTransDesc);
389
+
390
+ /* cudnnFusedOps... */
391
+ struct cudnnFusedOpsConstParamStruct;
392
+ typedef struct cudnnFusedOpsConstParamStruct *cudnnFusedOpsConstParamPack_t;
393
+
394
+ struct cudnnFusedOpsVariantParamStruct;
395
+ typedef struct cudnnFusedOpsVariantParamStruct *cudnnFusedOpsVariantParamPack_t;
396
+
397
+ struct cudnnFusedOpsPlanStruct;
398
+ typedef struct cudnnFusedOpsPlanStruct *cudnnFusedOpsPlan_t;
399
+
400
+ typedef enum {
401
+ /* each op in [ ] can be disabled by passing NULL ptr */
402
+ /* [per channel scale], [per channel bias], [activation], convolution, [generate BN stats] */
403
+ CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS = 0,
404
+ /* [per channel scale], [per channel bias], [activation], convolutionBackwardWeights */
405
+ CUDNN_FUSED_SCALE_BIAS_ACTIVATION_WGRAD = 1,
406
+ /* utility for BN training in BN-conv fusion */
407
+ /* computes the equivalent scale and bias from ySum ySqSum and learned scale, bias */
408
+ /* optionally update running stats and generate saved stats */
409
+ CUDNN_FUSED_BN_FINALIZE_STATISTICS_TRAINING = 2,
410
+ /* utility for BN inference in BN-conv fusion */
411
+ /* computes the equivalent scale and bias from learned running stats and learned scale, bias */
412
+ CUDNN_FUSED_BN_FINALIZE_STATISTICS_INFERENCE = 3,
413
+ /* reserved for future use: convolution, [per channel scale], [per channel bias], [residual add], [activation] */
414
+ CUDNN_FUSED_CONV_SCALE_BIAS_ADD_ACTIVATION = 4,
415
+ /* reserved for future use: [per channel scale], [per channel bias], [residual add], activation, bitmask */
416
+ CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK = 5,
417
+ /* reserved for future use */
418
+ CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM = 6,
419
+ } cudnnFusedOps_t;
420
+
421
+ typedef enum {
422
+ /* set XDESC: pass previously initialized cudnnTensorDescriptor_t */
423
+ /* get XDESC: pass previously created cudnnTensorDescriptor_t */
424
+ CUDNN_PARAM_XDESC = 0,
425
+ /* set/get XDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
426
+ CUDNN_PARAM_XDATA_PLACEHOLDER = 1,
427
+ /* set/get BN_MODE: pass cudnnBatchNormMode_t* */
428
+ CUDNN_PARAM_BN_MODE = 2,
429
+ /* set CUDNN_PARAM_BN_EQSCALEBIAS_DESC: pass previously initialized cudnnTensorDescriptor_t */
430
+ /* get CUDNN_PARAM_BN_EQSCALEBIAS_DESC: pass previously created cudnnTensorDescriptor_t */
431
+ CUDNN_PARAM_BN_EQSCALEBIAS_DESC = 3,
432
+ /* set/get BN_EQSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
433
+ CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER = 4,
434
+ /* set/get BN_EQBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
435
+ CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER = 5,
436
+ /* set ACTIVATION_DESC: pass previously initialized cudnnActivationDescriptor_t */
437
+ /* get ACTIVATION_DESC: pass previously created cudnnActivationDescriptor_t */
438
+ CUDNN_PARAM_ACTIVATION_DESC = 6,
439
+ /* set CONV_DESC: pass previously initialized cudnnConvolutionDescriptor_t */
440
+ /* get CONV_DESC: pass previously created cudnnConvolutionDescriptor_t */
441
+ CUDNN_PARAM_CONV_DESC = 7,
442
+ /* set WDESC: pass previously initialized cudnnFilterDescriptor_t */
443
+ /* get WDESC: pass previously created cudnnFilterDescriptor_t */
444
+ CUDNN_PARAM_WDESC = 8,
445
+ /* set/get WDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
446
+ CUDNN_PARAM_WDATA_PLACEHOLDER = 9,
447
+ /* set DWDESC: pass previously initialized cudnnFilterDescriptor_t */
448
+ /* get DWDESC: pass previously created cudnnFilterDescriptor_t */
449
+ CUDNN_PARAM_DWDESC = 10,
450
+ /* set/get DWDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
451
+ CUDNN_PARAM_DWDATA_PLACEHOLDER = 11,
452
+ /* set YDESC: pass previously initialized cudnnTensorDescriptor_t */
453
+ /* get YDESC: pass previously created cudnnTensorDescriptor_t */
454
+ CUDNN_PARAM_YDESC = 12,
455
+ /* set/get YDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
456
+ CUDNN_PARAM_YDATA_PLACEHOLDER = 13,
457
+ /* set DYDESC: pass previously initialized cudnnTensorDescriptor_t */
458
+ /* get DYDESC: pass previously created cudnnTensorDescriptor_t */
459
+ CUDNN_PARAM_DYDESC = 14,
460
+ /* set/get DYDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
461
+ CUDNN_PARAM_DYDATA_PLACEHOLDER = 15,
462
+ /* set YSTATS_DESC: pass previously initialized cudnnTensorDescriptor_t */
463
+ /* get YSTATS_DESC: pass previously created cudnnTensorDescriptor_t */
464
+ CUDNN_PARAM_YSTATS_DESC = 16,
465
+ /* set/get YSUM_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
466
+ CUDNN_PARAM_YSUM_PLACEHOLDER = 17,
467
+ /* set/get YSQSUM_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
468
+ CUDNN_PARAM_YSQSUM_PLACEHOLDER = 18,
469
+ /* set CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC: pass previously initialized cudnnTensorDescriptor_t */
470
+ /* get CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC: pass previously created cudnnTensorDescriptor_t */
471
+ CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC = 19,
472
+ /* set/get CUDNN_PARAM_BN_SCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
473
+ CUDNN_PARAM_BN_SCALE_PLACEHOLDER = 20,
474
+ /* set/get CUDNN_PARAM_BN_BIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
475
+ CUDNN_PARAM_BN_BIAS_PLACEHOLDER = 21,
476
+ /* set/get CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
477
+ CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER = 22,
478
+ /* set/get CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
479
+ CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER = 23,
480
+ /* set/get CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
481
+ CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER = 24,
482
+ /* set/get CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
483
+ CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER = 25,
484
+
485
+ /* set ZDESC: pass previously initialized cudnnTensorDescriptor_t */
486
+ /* get ZDESC: pass previously created cudnnTensorDescriptor_t */
487
+ CUDNN_PARAM_ZDESC = 26,
488
+ /* set/get ZDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
489
+ CUDNN_PARAM_ZDATA_PLACEHOLDER = 27,
490
+ /* set BN_Z_EQSCALEBIAS_DESC: pass previously initialized cudnnTensorDescriptor_t */
491
+ /* get BN_Z_EQSCALEBIAS_DESC: pass previously created cudnnTensorDescriptor_t */
492
+ CUDNN_PARAM_BN_Z_EQSCALEBIAS_DESC = 28,
493
+ /* set/get BN_Z_EQSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
494
+ CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER = 29,
495
+ /* set/get BN_Z_EQBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
496
+ CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER = 30,
497
+
498
+ /* set ACTIVATION_BITMASK_DESC: pass previously initialized cudnnTensorDescriptor_t */
499
+ /* get ACTIVATION_BITMASK_DESC: pass previously created cudnnTensorDescriptor_t */
500
+ CUDNN_PARAM_ACTIVATION_BITMASK_DESC = 31,
501
+ /* set/get ACTIVATION_BITMASK_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
502
+ CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER = 32,
503
+
504
+ /* set DXDESC: pass previously initialized cudnnTensorDescriptor_t */
505
+ /* get DXDESC: pass previously created cudnnTensorDescriptor_t */
506
+ CUDNN_PARAM_DXDESC = 33,
507
+ /* set/get DXDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
508
+ CUDNN_PARAM_DXDATA_PLACEHOLDER = 34,
509
+ /* set DZDESC: pass previously initialized cudnnTensorDescriptor_t */
510
+ /* get DZDESC: pass previously created cudnnTensorDescriptor_t */
511
+ CUDNN_PARAM_DZDESC = 35,
512
+ /* set/get DZDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
513
+ CUDNN_PARAM_DZDATA_PLACEHOLDER = 36,
514
+ /* set/get CUDNN_PARAM_BN_DSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
515
+ CUDNN_PARAM_BN_DSCALE_PLACEHOLDER = 37,
516
+ /* set/get CUDNN_PARAM_BN_DBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
517
+ CUDNN_PARAM_BN_DBIAS_PLACEHOLDER = 38,
518
+ } cudnnFusedOpsConstParamLabel_t;
519
+
520
+ typedef enum {
521
+ CUDNN_PTR_NULL = 0,
522
+ CUDNN_PTR_ELEM_ALIGNED = 1,
523
+ CUDNN_PTR_16B_ALIGNED = 2,
524
+ } cudnnFusedOpsPointerPlaceHolder_t;
525
+
526
+ typedef enum {
527
+ /* set: pass void* pointing to dev memory */
528
+ /* get: pass void** pointing to host memory */
529
+ CUDNN_PTR_XDATA = 0,
530
+ CUDNN_PTR_BN_EQSCALE = 1,
531
+ CUDNN_PTR_BN_EQBIAS = 2,
532
+ CUDNN_PTR_WDATA = 3,
533
+ CUDNN_PTR_DWDATA = 4,
534
+ CUDNN_PTR_YDATA = 5,
535
+ CUDNN_PTR_DYDATA = 6,
536
+ CUDNN_PTR_YSUM = 7,
537
+ CUDNN_PTR_YSQSUM = 8,
538
+ CUDNN_PTR_WORKSPACE = 9,
539
+ CUDNN_PTR_BN_SCALE = 10,
540
+ CUDNN_PTR_BN_BIAS = 11,
541
+ CUDNN_PTR_BN_SAVED_MEAN = 12,
542
+ CUDNN_PTR_BN_SAVED_INVSTD = 13,
543
+ CUDNN_PTR_BN_RUNNING_MEAN = 14,
544
+ CUDNN_PTR_BN_RUNNING_VAR = 15,
545
+ CUDNN_PTR_ZDATA = 16,
546
+ CUDNN_PTR_BN_Z_EQSCALE = 17,
547
+ CUDNN_PTR_BN_Z_EQBIAS = 18,
548
+ CUDNN_PTR_ACTIVATION_BITMASK = 19,
549
+ CUDNN_PTR_DXDATA = 20,
550
+ CUDNN_PTR_DZDATA = 21,
551
+ CUDNN_PTR_BN_DSCALE = 22,
552
+ CUDNN_PTR_BN_DBIAS = 23,
553
+
554
+ /* set/get: pass size_t* pointing to host memory */
555
+ CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES = 100,
556
+ /* set/get: pass int64_t* pointing to host memory */
557
+ CUDNN_SCALAR_INT64_T_BN_ACCUMULATION_COUNT = 101,
558
+ /* set/get: pass double* pointing to host memory */
559
+ CUDNN_SCALAR_DOUBLE_BN_EXP_AVG_FACTOR = 102,
560
+ /* set/get: pass double* pointing to host memory */
561
+ CUDNN_SCALAR_DOUBLE_BN_EPSILON = 103,
562
+ } cudnnFusedOpsVariantParamLabel_t;
563
+
564
+ cudnnStatus_t CUDNNWINAPI
565
+ cudnnCnnInferVersionCheck(void);
566
+
567
+ #if defined(__cplusplus)
568
+ }
569
+ #endif
570
+
571
+ #endif /* CUDNN_CNN_INFER_H_ */
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_version.h ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2017-2022 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /**
51
+ * \file: The master cuDNN version file.
52
+ */
53
+
54
+ #ifndef CUDNN_VERSION_H_
55
+ #define CUDNN_VERSION_H_
56
+
57
+ #define CUDNN_MAJOR 8
58
+ #define CUDNN_MINOR 7
59
+ #define CUDNN_PATCHLEVEL 0
60
+
61
+ #define CUDNN_VERSION (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
62
+
63
+ /* cannot use constexpr here since this is a C-only file */
64
+ /* Below is the max SM version this cuDNN library is aware of and supports natively */
65
+
66
+ #define CUDNN_MAX_SM_MAJOR_NUMBER 9
67
+ #define CUDNN_MAX_SM_MINOR_NUMBER 0
68
+ #define CUDNN_MAX_DEVICE_VERSION (CUDNN_MAX_SM_MAJOR_NUMBER * 100) + (CUDNN_MAX_SM_MINOR_NUMBER * 10)
69
+
70
+ #endif /* CUDNN_VERSION_H */
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (213 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/include/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/include/cudalibxt.h ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Copyright 2013,2014 NVIDIA Corporation. All rights reserved.
2
+ *
3
+ * NOTICE TO LICENSEE:
4
+ *
5
+ * The source code and/or documentation ("Licensed Deliverables") are
6
+ * subject to NVIDIA intellectual property rights under U.S. and
7
+ * international Copyright laws.
8
+ *
9
+ * The Licensed Deliverables contained herein are PROPRIETARY and
10
+ * CONFIDENTIAL to NVIDIA and are being provided under the terms and
11
+ * conditions of a form of NVIDIA software license agreement by and
12
+ * between NVIDIA and Licensee ("License Agreement") or electronically
13
+ * accepted by Licensee. Notwithstanding any terms or conditions to
14
+ * the contrary in the License Agreement, reproduction or disclosure
15
+ * of the Licensed Deliverables to any third party without the express
16
+ * written consent of NVIDIA is prohibited.
17
+ *
18
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
19
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
20
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
21
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
22
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
23
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
24
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
25
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
26
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
27
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
28
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
29
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
30
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
31
+ * OF THESE LICENSED DELIVERABLES.
32
+ *
33
+ * U.S. Government End Users. These Licensed Deliverables are a
34
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
35
+ * 1995), consisting of "commercial computer software" and "commercial
36
+ * computer software documentation" as such terms are used in 48
37
+ * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
38
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
39
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
40
+ * U.S. Government End Users acquire the Licensed Deliverables with
41
+ * only those rights set forth herein.
42
+ *
43
+ * Any use of the Licensed Deliverables in individual and commercial
44
+ * software must include, in the user documentation and internal
45
+ * comments to the code, the above Disclaimer and U.S. Government End
46
+ * Users Notice.
47
+ */
48
+
49
+ /*!
50
+ * \file cudalibxt.h
51
+ * \brief Public header file for the NVIDIA library multi-GPU support structures
52
+ */
53
+
54
+ #ifndef _CUDA_LIB_XT_H_
55
+ #define _CUDA_LIB_XT_H_
56
+ #include <cuda_runtime.h>
57
+
58
+ #define CUDA_XT_DESCRIPTOR_VERSION 0x01000000 // This is added to CUDART_VERSION
59
+
60
+ enum cudaXtCopyType_t {
61
+ LIB_XT_COPY_HOST_TO_DEVICE,
62
+ LIB_XT_COPY_DEVICE_TO_HOST,
63
+ LIB_XT_COPY_DEVICE_TO_DEVICE
64
+ } ;
65
+ typedef enum cudaXtCopyType_t cudaLibXtCopyType;
66
+
67
+ enum libFormat_t {
68
+ LIB_FORMAT_CUFFT = 0x0,
69
+ LIB_FORMAT_UNDEFINED = 0x1
70
+ };
71
+
72
+ typedef enum libFormat_t libFormat;
73
+
74
+ #define MAX_CUDA_DESCRIPTOR_GPUS 64
75
+
76
+ struct cudaXtDesc_t{
77
+ int version; //descriptor version
78
+ int nGPUs; //number of GPUs
79
+ int GPUs[MAX_CUDA_DESCRIPTOR_GPUS]; //array of device IDs
80
+ void *data[MAX_CUDA_DESCRIPTOR_GPUS]; //array of pointers to data, one per GPU
81
+ size_t size[MAX_CUDA_DESCRIPTOR_GPUS]; //array of data sizes, one per GPU
82
+ void *cudaXtState; //opaque CUDA utility structure
83
+ };
84
+ typedef struct cudaXtDesc_t cudaXtDesc;
85
+
86
+ struct cudaLibXtDesc_t{
87
+ int version; //descriptor version
88
+ cudaXtDesc *descriptor; //multi-GPU memory descriptor
89
+ libFormat library; //which library recognizes the format
90
+ int subFormat; //library specific enumerator of sub formats
91
+ void *libDescriptor; //library specific descriptor e.g. FFT transform plan object
92
+ };
93
+ typedef struct cudaLibXtDesc_t cudaLibXtDesc;
94
+
95
+
96
+ #endif
97
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/include/cufftXt.h ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ /* Copyright 2005-2021 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * The source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * The Licensed Deliverables contained herein are PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and are being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /*!
51
+ * \file cufftXt.h
52
+ * \brief Public header file for the NVIDIA CUDA FFT library (CUFFT)
53
+ */
54
+
55
+ #ifndef _CUFFTXT_H_
56
+ #define _CUFFTXT_H_
57
+ #include "cudalibxt.h"
58
+ #include "cufft.h"
59
+
60
+
61
+ #ifndef CUFFTAPI
62
+ #ifdef _WIN32
63
+ #define CUFFTAPI __stdcall
64
+ #else
65
+ #define CUFFTAPI
66
+ #endif
67
+ #endif
68
+
69
+ #ifdef __cplusplus
70
+ extern "C" {
71
+ #endif
72
+
73
+ //
74
+ // cufftXtSubFormat identifies the data layout of
75
+ // a memory descriptor owned by cufft.
76
+ // note that multi GPU cufft does not yet support out-of-place transforms
77
+ //
78
+
79
+ typedef enum cufftXtSubFormat_t {
80
+ CUFFT_XT_FORMAT_INPUT = 0x00, //by default input is in linear order across GPUs
81
+ CUFFT_XT_FORMAT_OUTPUT = 0x01, //by default output is in scrambled order depending on transform
82
+ CUFFT_XT_FORMAT_INPLACE = 0x02, //by default inplace is input order, which is linear across GPUs
83
+ CUFFT_XT_FORMAT_INPLACE_SHUFFLED = 0x03, //shuffled output order after execution of the transform
84
+ CUFFT_XT_FORMAT_1D_INPUT_SHUFFLED = 0x04, //shuffled input order prior to execution of 1D transforms
85
+ CUFFT_XT_FORMAT_DISTRIBUTED_INPUT = 0x05,
86
+ CUFFT_XT_FORMAT_DISTRIBUTED_OUTPUT = 0x06,
87
+ CUFFT_FORMAT_UNDEFINED = 0x07
88
+ } cufftXtSubFormat;
89
+
90
+ //
91
+ // cufftXtCopyType specifies the type of copy for cufftXtMemcpy
92
+ //
93
+ typedef enum cufftXtCopyType_t {
94
+ CUFFT_COPY_HOST_TO_DEVICE = 0x00,
95
+ CUFFT_COPY_DEVICE_TO_HOST = 0x01,
96
+ CUFFT_COPY_DEVICE_TO_DEVICE = 0x02,
97
+ CUFFT_COPY_UNDEFINED = 0x03
98
+ } cufftXtCopyType;
99
+
100
+ //
101
+ // cufftXtQueryType specifies the type of query for cufftXtQueryPlan
102
+ //
103
+ typedef enum cufftXtQueryType_t {
104
+ CUFFT_QUERY_1D_FACTORS = 0x00,
105
+ CUFFT_QUERY_UNDEFINED = 0x01
106
+ } cufftXtQueryType;
107
+
108
+ typedef struct cufftXt1dFactors_t {
109
+ long long int size;
110
+ long long int stringCount;
111
+ long long int stringLength;
112
+ long long int substringLength;
113
+ long long int factor1;
114
+ long long int factor2;
115
+ long long int stringMask;
116
+ long long int substringMask;
117
+ long long int factor1Mask;
118
+ long long int factor2Mask;
119
+ int stringShift;
120
+ int substringShift;
121
+ int factor1Shift;
122
+ int factor2Shift;
123
+ } cufftXt1dFactors;
124
+
125
+ //
126
+ // cufftXtWorkAreaPolicy specifies policy for cufftXtSetWorkAreaPolicy
127
+ //
128
+ typedef enum cufftXtWorkAreaPolicy_t {
129
+ CUFFT_WORKAREA_MINIMAL = 0, /* maximum reduction */
130
+ CUFFT_WORKAREA_USER = 1, /* use workSize parameter as limit */
131
+ CUFFT_WORKAREA_PERFORMANCE = 2, /* default - 1x overhead or more, maximum performance */
132
+ } cufftXtWorkAreaPolicy;
133
+
134
+ // multi-GPU routines
135
+ cufftResult CUFFTAPI cufftXtSetGPUs(cufftHandle handle, int nGPUs, int *whichGPUs);
136
+
137
+ cufftResult CUFFTAPI cufftXtMalloc(cufftHandle plan,
138
+ cudaLibXtDesc ** descriptor,
139
+ cufftXtSubFormat format);
140
+
141
+ cufftResult CUFFTAPI cufftXtMemcpy(cufftHandle plan,
142
+ void *dstPointer,
143
+ void *srcPointer,
144
+ cufftXtCopyType type);
145
+
146
+ cufftResult CUFFTAPI cufftXtFree(cudaLibXtDesc *descriptor);
147
+
148
+ cufftResult CUFFTAPI cufftXtSetWorkArea(cufftHandle plan, void **workArea);
149
+
150
+ cufftResult CUFFTAPI cufftXtExecDescriptorC2C(cufftHandle plan,
151
+ cudaLibXtDesc *input,
152
+ cudaLibXtDesc *output,
153
+ int direction);
154
+
155
+ cufftResult CUFFTAPI cufftXtExecDescriptorR2C(cufftHandle plan,
156
+ cudaLibXtDesc *input,
157
+ cudaLibXtDesc *output);
158
+
159
+ cufftResult CUFFTAPI cufftXtExecDescriptorC2R(cufftHandle plan,
160
+ cudaLibXtDesc *input,
161
+ cudaLibXtDesc *output);
162
+
163
+ cufftResult CUFFTAPI cufftXtExecDescriptorZ2Z(cufftHandle plan,
164
+ cudaLibXtDesc *input,
165
+ cudaLibXtDesc *output,
166
+ int direction);
167
+
168
+ cufftResult CUFFTAPI cufftXtExecDescriptorD2Z(cufftHandle plan,
169
+ cudaLibXtDesc *input,
170
+ cudaLibXtDesc *output);
171
+
172
+ cufftResult CUFFTAPI cufftXtExecDescriptorZ2D(cufftHandle plan,
173
+ cudaLibXtDesc *input,
174
+ cudaLibXtDesc *output);
175
+
176
+ // Utility functions
177
+
178
+ cufftResult CUFFTAPI cufftXtQueryPlan(cufftHandle plan, void *queryStruct, cufftXtQueryType queryType);
179
+
180
+
181
+ // callbacks
182
+
183
+
184
+ typedef enum cufftXtCallbackType_t {
185
+ CUFFT_CB_LD_COMPLEX = 0x0,
186
+ CUFFT_CB_LD_COMPLEX_DOUBLE = 0x1,
187
+ CUFFT_CB_LD_REAL = 0x2,
188
+ CUFFT_CB_LD_REAL_DOUBLE = 0x3,
189
+ CUFFT_CB_ST_COMPLEX = 0x4,
190
+ CUFFT_CB_ST_COMPLEX_DOUBLE = 0x5,
191
+ CUFFT_CB_ST_REAL = 0x6,
192
+ CUFFT_CB_ST_REAL_DOUBLE = 0x7,
193
+ CUFFT_CB_UNDEFINED = 0x8
194
+
195
+ } cufftXtCallbackType;
196
+
197
+ typedef cufftComplex (*cufftCallbackLoadC)(void *dataIn, size_t offset, void *callerInfo, void *sharedPointer);
198
+ typedef cufftDoubleComplex (*cufftCallbackLoadZ)(void *dataIn, size_t offset, void *callerInfo, void *sharedPointer);
199
+ typedef cufftReal (*cufftCallbackLoadR)(void *dataIn, size_t offset, void *callerInfo, void *sharedPointer);
200
+ typedef cufftDoubleReal(*cufftCallbackLoadD)(void *dataIn, size_t offset, void *callerInfo, void *sharedPointer);
201
+
202
+ typedef void (*cufftCallbackStoreC)(void *dataOut, size_t offset, cufftComplex element, void *callerInfo, void *sharedPointer);
203
+ typedef void (*cufftCallbackStoreZ)(void *dataOut, size_t offset, cufftDoubleComplex element, void *callerInfo, void *sharedPointer);
204
+ typedef void (*cufftCallbackStoreR)(void *dataOut, size_t offset, cufftReal element, void *callerInfo, void *sharedPointer);
205
+ typedef void (*cufftCallbackStoreD)(void *dataOut, size_t offset, cufftDoubleReal element, void *callerInfo, void *sharedPointer);
206
+
207
+
208
+ cufftResult CUFFTAPI cufftXtSetCallback(cufftHandle plan, void **callback_routine, cufftXtCallbackType cbType, void **caller_info);
209
+ cufftResult CUFFTAPI cufftXtClearCallback(cufftHandle plan, cufftXtCallbackType cbType);
210
+ cufftResult CUFFTAPI cufftXtSetCallbackSharedSize(cufftHandle plan, cufftXtCallbackType cbType, size_t sharedSize);
211
+
212
+ cufftResult CUFFTAPI cufftXtMakePlanMany(cufftHandle plan,
213
+ int rank,
214
+ long long int *n,
215
+ long long int *inembed,
216
+ long long int istride,
217
+ long long int idist,
218
+ cudaDataType inputtype,
219
+ long long int *onembed,
220
+ long long int ostride,
221
+ long long int odist,
222
+ cudaDataType outputtype,
223
+ long long int batch,
224
+ size_t *workSize,
225
+ cudaDataType executiontype);
226
+
227
+ cufftResult CUFFTAPI cufftXtGetSizeMany(cufftHandle plan,
228
+ int rank,
229
+ long long int *n,
230
+ long long int *inembed,
231
+ long long int istride,
232
+ long long int idist,
233
+ cudaDataType inputtype,
234
+ long long int *onembed,
235
+ long long int ostride,
236
+ long long int odist,
237
+ cudaDataType outputtype,
238
+ long long int batch,
239
+ size_t *workSize,
240
+ cudaDataType executiontype);
241
+
242
+
243
+ cufftResult CUFFTAPI cufftXtExec(cufftHandle plan,
244
+ void *input,
245
+ void *output,
246
+ int direction);
247
+
248
+ cufftResult CUFFTAPI cufftXtExecDescriptor(cufftHandle plan,
249
+ cudaLibXtDesc *input,
250
+ cudaLibXtDesc *output,
251
+ int direction);
252
+
253
+ cufftResult CUFFTAPI cufftXtSetWorkAreaPolicy(cufftHandle plan, cufftXtWorkAreaPolicy policy, size_t *workSize);
254
+
255
+ typedef struct cufftBox3d_t {
256
+ size_t lower[3];
257
+ size_t upper[3];
258
+ size_t strides[3];
259
+ } cufftBox3d;
260
+
261
+ cufftResult CUFFTAPI cufftXtSetDistribution(cufftHandle plan,
262
+ const cufftBox3d *box_in,
263
+ const cufftBox3d *box_out);
264
+
265
+ #ifdef __cplusplus
266
+ }
267
+ #endif
268
+
269
+ #endif
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/nccl/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/nccl/include/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia_nvtx_cu11-11.8.86.dist-info/METADATA ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: nvidia-nvtx-cu11
3
+ Version: 11.8.86
4
+ Summary: NVIDIA Tools Extension
5
+ Home-page: https://developer.nvidia.com/cuda-zone
6
+ Author: Nvidia CUDA Installer Team
7
+ Author-email: cuda_installer@nvidia.com
8
+ License: NVIDIA Proprietary Software
9
+ Keywords: cuda,nvidia,runtime,machine learning,deep learning
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: Intended Audience :: Education
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: License :: Other/Proprietary License
15
+ Classifier: Natural Language :: English
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.5
18
+ Classifier: Programming Language :: Python :: 3.6
19
+ Classifier: Programming Language :: Python :: 3.7
20
+ Classifier: Programming Language :: Python :: 3.8
21
+ Classifier: Programming Language :: Python :: 3.9
22
+ Classifier: Programming Language :: Python :: 3.10
23
+ Classifier: Programming Language :: Python :: 3.11
24
+ Classifier: Programming Language :: Python :: 3 :: Only
25
+ Classifier: Topic :: Scientific/Engineering
26
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
27
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
28
+ Classifier: Topic :: Software Development
29
+ Classifier: Topic :: Software Development :: Libraries
30
+ Classifier: Operating System :: Microsoft :: Windows
31
+ Classifier: Operating System :: POSIX :: Linux
32
+ Requires-Python: >=3
33
+ License-File: License.txt
34
+
35
+ A C-based API for annotating events, code ranges, and resources in your applications. Applications which integrate NVTX can use the Visual Profiler to capture and visualize these events and ranges.
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia_nvtx_cu11-11.8.86.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: bdist_wheel (0.37.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-manylinux1_x86_64
5
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/__pycache__/_elffile.cpython-311.pyc ADDED
Binary file (5.53 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/_parser.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Handwritten parser of dependency specifiers.
2
+
3
+ The docstring for each __parse_* function contains EBNF-inspired grammar representing
4
+ the implementation.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import ast
10
+ from typing import NamedTuple, Sequence, Tuple, Union
11
+
12
+ from ._tokenizer import DEFAULT_RULES, Tokenizer
13
+
14
+
15
+ class Node:
16
+ def __init__(self, value: str) -> None:
17
+ self.value = value
18
+
19
+ def __str__(self) -> str:
20
+ return self.value
21
+
22
+ def __repr__(self) -> str:
23
+ return f"<{self.__class__.__name__}('{self}')>"
24
+
25
+ def serialize(self) -> str:
26
+ raise NotImplementedError
27
+
28
+
29
+ class Variable(Node):
30
+ def serialize(self) -> str:
31
+ return str(self)
32
+
33
+
34
+ class Value(Node):
35
+ def serialize(self) -> str:
36
+ return f'"{self}"'
37
+
38
+
39
+ class Op(Node):
40
+ def serialize(self) -> str:
41
+ return str(self)
42
+
43
+
44
+ MarkerVar = Union[Variable, Value]
45
+ MarkerItem = Tuple[MarkerVar, Op, MarkerVar]
46
+ MarkerAtom = Union[MarkerItem, Sequence["MarkerAtom"]]
47
+ MarkerList = Sequence[Union["MarkerList", MarkerAtom, str]]
48
+
49
+
50
+ class ParsedRequirement(NamedTuple):
51
+ name: str
52
+ url: str
53
+ extras: list[str]
54
+ specifier: str
55
+ marker: MarkerList | None
56
+
57
+
58
+ # --------------------------------------------------------------------------------------
59
+ # Recursive descent parser for dependency specifier
60
+ # --------------------------------------------------------------------------------------
61
+ def parse_requirement(source: str) -> ParsedRequirement:
62
+ return _parse_requirement(Tokenizer(source, rules=DEFAULT_RULES))
63
+
64
+
65
+ def _parse_requirement(tokenizer: Tokenizer) -> ParsedRequirement:
66
+ """
67
+ requirement = WS? IDENTIFIER WS? extras WS? requirement_details
68
+ """
69
+ tokenizer.consume("WS")
70
+
71
+ name_token = tokenizer.expect(
72
+ "IDENTIFIER", expected="package name at the start of dependency specifier"
73
+ )
74
+ name = name_token.text
75
+ tokenizer.consume("WS")
76
+
77
+ extras = _parse_extras(tokenizer)
78
+ tokenizer.consume("WS")
79
+
80
+ url, specifier, marker = _parse_requirement_details(tokenizer)
81
+ tokenizer.expect("END", expected="end of dependency specifier")
82
+
83
+ return ParsedRequirement(name, url, extras, specifier, marker)
84
+
85
+
86
+ def _parse_requirement_details(
87
+ tokenizer: Tokenizer,
88
+ ) -> tuple[str, str, MarkerList | None]:
89
+ """
90
+ requirement_details = AT URL (WS requirement_marker?)?
91
+ | specifier WS? (requirement_marker)?
92
+ """
93
+
94
+ specifier = ""
95
+ url = ""
96
+ marker = None
97
+
98
+ if tokenizer.check("AT"):
99
+ tokenizer.read()
100
+ tokenizer.consume("WS")
101
+
102
+ url_start = tokenizer.position
103
+ url = tokenizer.expect("URL", expected="URL after @").text
104
+ if tokenizer.check("END", peek=True):
105
+ return (url, specifier, marker)
106
+
107
+ tokenizer.expect("WS", expected="whitespace after URL")
108
+
109
+ # The input might end after whitespace.
110
+ if tokenizer.check("END", peek=True):
111
+ return (url, specifier, marker)
112
+
113
+ marker = _parse_requirement_marker(
114
+ tokenizer, span_start=url_start, after="URL and whitespace"
115
+ )
116
+ else:
117
+ specifier_start = tokenizer.position
118
+ specifier = _parse_specifier(tokenizer)
119
+ tokenizer.consume("WS")
120
+
121
+ if tokenizer.check("END", peek=True):
122
+ return (url, specifier, marker)
123
+
124
+ marker = _parse_requirement_marker(
125
+ tokenizer,
126
+ span_start=specifier_start,
127
+ after=(
128
+ "version specifier"
129
+ if specifier
130
+ else "name and no valid version specifier"
131
+ ),
132
+ )
133
+
134
+ return (url, specifier, marker)
135
+
136
+
137
+ def _parse_requirement_marker(
138
+ tokenizer: Tokenizer, *, span_start: int, after: str
139
+ ) -> MarkerList:
140
+ """
141
+ requirement_marker = SEMICOLON marker WS?
142
+ """
143
+
144
+ if not tokenizer.check("SEMICOLON"):
145
+ tokenizer.raise_syntax_error(
146
+ f"Expected end or semicolon (after {after})",
147
+ span_start=span_start,
148
+ )
149
+ tokenizer.read()
150
+
151
+ marker = _parse_marker(tokenizer)
152
+ tokenizer.consume("WS")
153
+
154
+ return marker
155
+
156
+
157
+ def _parse_extras(tokenizer: Tokenizer) -> list[str]:
158
+ """
159
+ extras = (LEFT_BRACKET wsp* extras_list? wsp* RIGHT_BRACKET)?
160
+ """
161
+ if not tokenizer.check("LEFT_BRACKET", peek=True):
162
+ return []
163
+
164
+ with tokenizer.enclosing_tokens(
165
+ "LEFT_BRACKET",
166
+ "RIGHT_BRACKET",
167
+ around="extras",
168
+ ):
169
+ tokenizer.consume("WS")
170
+ extras = _parse_extras_list(tokenizer)
171
+ tokenizer.consume("WS")
172
+
173
+ return extras
174
+
175
+
176
+ def _parse_extras_list(tokenizer: Tokenizer) -> list[str]:
177
+ """
178
+ extras_list = identifier (wsp* ',' wsp* identifier)*
179
+ """
180
+ extras: list[str] = []
181
+
182
+ if not tokenizer.check("IDENTIFIER"):
183
+ return extras
184
+
185
+ extras.append(tokenizer.read().text)
186
+
187
+ while True:
188
+ tokenizer.consume("WS")
189
+ if tokenizer.check("IDENTIFIER", peek=True):
190
+ tokenizer.raise_syntax_error("Expected comma between extra names")
191
+ elif not tokenizer.check("COMMA"):
192
+ break
193
+
194
+ tokenizer.read()
195
+ tokenizer.consume("WS")
196
+
197
+ extra_token = tokenizer.expect("IDENTIFIER", expected="extra name after comma")
198
+ extras.append(extra_token.text)
199
+
200
+ return extras
201
+
202
+
203
+ def _parse_specifier(tokenizer: Tokenizer) -> str:
204
+ """
205
+ specifier = LEFT_PARENTHESIS WS? version_many WS? RIGHT_PARENTHESIS
206
+ | WS? version_many WS?
207
+ """
208
+ with tokenizer.enclosing_tokens(
209
+ "LEFT_PARENTHESIS",
210
+ "RIGHT_PARENTHESIS",
211
+ around="version specifier",
212
+ ):
213
+ tokenizer.consume("WS")
214
+ parsed_specifiers = _parse_version_many(tokenizer)
215
+ tokenizer.consume("WS")
216
+
217
+ return parsed_specifiers
218
+
219
+
220
+ def _parse_version_many(tokenizer: Tokenizer) -> str:
221
+ """
222
+ version_many = (SPECIFIER (WS? COMMA WS? SPECIFIER)*)?
223
+ """
224
+ parsed_specifiers = ""
225
+ while tokenizer.check("SPECIFIER"):
226
+ span_start = tokenizer.position
227
+ parsed_specifiers += tokenizer.read().text
228
+ if tokenizer.check("VERSION_PREFIX_TRAIL", peek=True):
229
+ tokenizer.raise_syntax_error(
230
+ ".* suffix can only be used with `==` or `!=` operators",
231
+ span_start=span_start,
232
+ span_end=tokenizer.position + 1,
233
+ )
234
+ if tokenizer.check("VERSION_LOCAL_LABEL_TRAIL", peek=True):
235
+ tokenizer.raise_syntax_error(
236
+ "Local version label can only be used with `==` or `!=` operators",
237
+ span_start=span_start,
238
+ span_end=tokenizer.position,
239
+ )
240
+ tokenizer.consume("WS")
241
+ if not tokenizer.check("COMMA"):
242
+ break
243
+ parsed_specifiers += tokenizer.read().text
244
+ tokenizer.consume("WS")
245
+
246
+ return parsed_specifiers
247
+
248
+
249
+ # --------------------------------------------------------------------------------------
250
+ # Recursive descent parser for marker expression
251
+ # --------------------------------------------------------------------------------------
252
+ def parse_marker(source: str) -> MarkerList:
253
+ return _parse_full_marker(Tokenizer(source, rules=DEFAULT_RULES))
254
+
255
+
256
+ def _parse_full_marker(tokenizer: Tokenizer) -> MarkerList:
257
+ retval = _parse_marker(tokenizer)
258
+ tokenizer.expect("END", expected="end of marker expression")
259
+ return retval
260
+
261
+
262
+ def _parse_marker(tokenizer: Tokenizer) -> MarkerList:
263
+ """
264
+ marker = marker_atom (BOOLOP marker_atom)+
265
+ """
266
+ expression = [_parse_marker_atom(tokenizer)]
267
+ while tokenizer.check("BOOLOP"):
268
+ token = tokenizer.read()
269
+ expr_right = _parse_marker_atom(tokenizer)
270
+ expression.extend((token.text, expr_right))
271
+ return expression
272
+
273
+
274
+ def _parse_marker_atom(tokenizer: Tokenizer) -> MarkerAtom:
275
+ """
276
+ marker_atom = WS? LEFT_PARENTHESIS WS? marker WS? RIGHT_PARENTHESIS WS?
277
+ | WS? marker_item WS?
278
+ """
279
+
280
+ tokenizer.consume("WS")
281
+ if tokenizer.check("LEFT_PARENTHESIS", peek=True):
282
+ with tokenizer.enclosing_tokens(
283
+ "LEFT_PARENTHESIS",
284
+ "RIGHT_PARENTHESIS",
285
+ around="marker expression",
286
+ ):
287
+ tokenizer.consume("WS")
288
+ marker: MarkerAtom = _parse_marker(tokenizer)
289
+ tokenizer.consume("WS")
290
+ else:
291
+ marker = _parse_marker_item(tokenizer)
292
+ tokenizer.consume("WS")
293
+ return marker
294
+
295
+
296
+ def _parse_marker_item(tokenizer: Tokenizer) -> MarkerItem:
297
+ """
298
+ marker_item = WS? marker_var WS? marker_op WS? marker_var WS?
299
+ """
300
+ tokenizer.consume("WS")
301
+ marker_var_left = _parse_marker_var(tokenizer)
302
+ tokenizer.consume("WS")
303
+ marker_op = _parse_marker_op(tokenizer)
304
+ tokenizer.consume("WS")
305
+ marker_var_right = _parse_marker_var(tokenizer)
306
+ tokenizer.consume("WS")
307
+ return (marker_var_left, marker_op, marker_var_right)
308
+
309
+
310
+ def _parse_marker_var(tokenizer: Tokenizer) -> MarkerVar:
311
+ """
312
+ marker_var = VARIABLE | QUOTED_STRING
313
+ """
314
+ if tokenizer.check("VARIABLE"):
315
+ return process_env_var(tokenizer.read().text.replace(".", "_"))
316
+ elif tokenizer.check("QUOTED_STRING"):
317
+ return process_python_str(tokenizer.read().text)
318
+ else:
319
+ tokenizer.raise_syntax_error(
320
+ message="Expected a marker variable or quoted string"
321
+ )
322
+
323
+
324
+ def process_env_var(env_var: str) -> Variable:
325
+ if env_var in ("platform_python_implementation", "python_implementation"):
326
+ return Variable("platform_python_implementation")
327
+ else:
328
+ return Variable(env_var)
329
+
330
+
331
+ def process_python_str(python_str: str) -> Value:
332
+ value = ast.literal_eval(python_str)
333
+ return Value(str(value))
334
+
335
+
336
+ def _parse_marker_op(tokenizer: Tokenizer) -> Op:
337
+ """
338
+ marker_op = IN | NOT IN | OP
339
+ """
340
+ if tokenizer.check("IN"):
341
+ tokenizer.read()
342
+ return Op("in")
343
+ elif tokenizer.check("NOT"):
344
+ tokenizer.read()
345
+ tokenizer.expect("WS", expected="whitespace after 'not'")
346
+ tokenizer.expect("IN", expected="'in' after 'not'")
347
+ return Op("not in")
348
+ elif tokenizer.check("OP"):
349
+ return Op(tokenizer.read().text)
350
+ else:
351
+ return tokenizer.raise_syntax_error(
352
+ "Expected marker operator, one of "
353
+ "<=, <, !=, ==, >=, >, ~=, ===, in, not in"
354
+ )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/markers.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is dual licensed under the terms of the Apache License, Version
2
+ # 2.0, and the BSD License. See the LICENSE file in the root of this repository
3
+ # for complete details.
4
+
5
+ from __future__ import annotations
6
+
7
+ import operator
8
+ import os
9
+ import platform
10
+ import sys
11
+ from typing import Any, Callable, TypedDict, cast
12
+
13
+ from ._parser import MarkerAtom, MarkerList, Op, Value, Variable
14
+ from ._parser import parse_marker as _parse_marker
15
+ from ._tokenizer import ParserSyntaxError
16
+ from .specifiers import InvalidSpecifier, Specifier
17
+ from .utils import canonicalize_name
18
+
19
# Public API of this module.
__all__ = [
    "InvalidMarker",
    "Marker",
    "UndefinedComparison",
    "UndefinedEnvironmentName",
    "default_environment",
]

# Signature shared by every comparison callable in the _operators table below.
Operator = Callable[[str, str], bool]
28
+
29
+
30
class InvalidMarker(ValueError):
    """
    An invalid marker was found; users should refer to PEP 508.

    Raised by :class:`Marker` when the marker string cannot be parsed.
    """
34
+
35
+
36
class UndefinedComparison(ValueError):
    """
    An invalid operation was attempted on a value that doesn't support it.

    Raised during marker evaluation when an operator is neither a valid
    version specifier nor one of the plain string operators.
    """
40
+
41
+
42
class UndefinedEnvironmentName(ValueError):
    """
    A name was attempted to be used that does not exist inside of the
    environment.
    """
    # NOTE(review): not raised anywhere in this module; presumably retained
    # for backwards API compatibility — confirm before removing.
47
+
48
+
49
class Environment(TypedDict):
    """The set of environment marker variables, one key per marker name.

    ``default_environment()`` below returns an instance populated from the
    running interpreter.
    """

    implementation_name: str
    """The implementation's identifier, e.g. ``'cpython'``."""

    implementation_version: str
    """
    The implementation's version, e.g. ``'3.13.0a2'`` for CPython 3.13.0a2, or
    ``'7.3.13'`` for PyPy3.10 v7.3.13.
    """

    os_name: str
    """
    The value of :py:data:`os.name`. The name of the operating system dependent module
    imported, e.g. ``'posix'``.
    """

    platform_machine: str
    """
    Returns the machine type, e.g. ``'i386'``.

    An empty string if the value cannot be determined.
    """

    platform_release: str
    """
    The system's release, e.g. ``'2.2.0'`` or ``'NT'``.

    An empty string if the value cannot be determined.
    """

    platform_system: str
    """
    The system/OS name, e.g. ``'Linux'``, ``'Windows'`` or ``'Java'``.

    An empty string if the value cannot be determined.
    """

    platform_version: str
    """
    The system's release version, e.g. ``'#3 on degas'``.

    An empty string if the value cannot be determined.
    """

    python_full_version: str
    """
    The Python version as string ``'major.minor.patchlevel'``.

    Note that unlike the Python :py:data:`sys.version`, this value will always include
    the patchlevel (it defaults to 0).
    """

    platform_python_implementation: str
    """
    A string identifying the Python implementation, e.g. ``'CPython'``.
    """

    python_version: str
    """The Python version as string ``'major.minor'``."""

    sys_platform: str
    """
    This string contains a platform identifier that can be used to append
    platform-specific components to :py:data:`sys.path`, for instance.

    For Unix systems, except on Linux and AIX, this is the lowercased OS name as
    returned by ``uname -s`` with the first part of the version as returned by
    ``uname -r`` appended, e.g. ``'sunos5'`` or ``'freebsd8'``, at the time when Python
    was built.
    """
119
+
120
+
121
def _normalize_extra_values(results: Any) -> Any:
    """
    Normalize extra values.

    If the first parsed expression compares the ``extra`` variable against a
    literal, rewrite the literal to its canonicalized form so that marker
    serialization and equality are insensitive to the extra's spelling.
    Mutates and returns ``results``.
    """
    # Only the first item is inspected; nested groups are left untouched.
    if isinstance(results[0], tuple):
        lhs, op, rhs = results[0]
        if isinstance(lhs, Variable) and lhs.value == "extra":
            normalized_extra = canonicalize_name(rhs.value)
            rhs = Value(normalized_extra)
        elif isinstance(rhs, Variable) and rhs.value == "extra":
            normalized_extra = canonicalize_name(lhs.value)
            lhs = Value(normalized_extra)
        results[0] = lhs, op, rhs
    return results
135
+
136
+
137
+ def _format_marker(
138
+ marker: list[str] | MarkerAtom | str, first: bool | None = True
139
+ ) -> str:
140
+ assert isinstance(marker, (list, tuple, str))
141
+
142
+ # Sometimes we have a structure like [[...]] which is a single item list
143
+ # where the single item is itself it's own list. In that case we want skip
144
+ # the rest of this function so that we don't get extraneous () on the
145
+ # outside.
146
+ if (
147
+ isinstance(marker, list)
148
+ and len(marker) == 1
149
+ and isinstance(marker[0], (list, tuple))
150
+ ):
151
+ return _format_marker(marker[0])
152
+
153
+ if isinstance(marker, list):
154
+ inner = (_format_marker(m, first=False) for m in marker)
155
+ if first:
156
+ return " ".join(inner)
157
+ else:
158
+ return "(" + " ".join(inner) + ")"
159
+ elif isinstance(marker, tuple):
160
+ return " ".join([m.serialize() for m in marker])
161
+ else:
162
+ return marker
163
+
164
+
165
# Plain string-comparison fallbacks, keyed by the serialized operator.
# _eval_op tries PEP 440 version semantics (Specifier) first and only falls
# back to these when the operand is not a valid version specifier.
_operators: dict[str, Operator] = {
    "in": lambda lhs, rhs: lhs in rhs,
    "not in": lambda lhs, rhs: lhs not in rhs,
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">=": operator.ge,
    ">": operator.gt,
}
175
+
176
+
177
def _eval_op(lhs: str, op: Op, rhs: str) -> bool:
    """Evaluate a single ``lhs <op> rhs`` marker comparison.

    Version semantics are preferred: if ``op`` plus ``rhs`` form a valid
    PEP 440 specifier, ``lhs`` is tested against it (prereleases allowed).
    Otherwise the operator falls back to plain string comparison.

    Raises UndefinedComparison if the operator supports neither.
    """
    try:
        spec = Specifier("".join([op.serialize(), rhs]))
    except InvalidSpecifier:
        pass
    else:
        return spec.contains(lhs, prereleases=True)

    oper: Operator | None = _operators.get(op.serialize())
    if oper is None:
        raise UndefinedComparison(f"Undefined {op!r} on {lhs!r} and {rhs!r}.")

    return oper(lhs, rhs)
190
+
191
+
192
+ def _normalize(*values: str, key: str) -> tuple[str, ...]:
193
+ # PEP 685 – Comparison of extra names for optional distribution dependencies
194
+ # https://peps.python.org/pep-0685/
195
+ # > When comparing extra names, tools MUST normalize the names being
196
+ # > compared using the semantics outlined in PEP 503 for names
197
+ if key == "extra":
198
+ return tuple(canonicalize_name(v) for v in values)
199
+
200
+ # other environment markers don't have such standards
201
+ return values
202
+
203
+
204
def _evaluate_markers(markers: MarkerList, environment: dict[str, str]) -> bool:
    """Evaluate a parsed marker tree against *environment*.

    *markers* is a flat list alternating between atoms (comparison tuples or
    nested lists) and the connective strings "and"/"or".  Because "and"
    binds tighter than "or", atoms are accumulated into "and"-groups that
    are split at each "or"; the result is the disjunction of the groups'
    conjunctions.
    """
    groups: list[list[bool]] = [[]]

    for marker in markers:
        assert isinstance(marker, (list, tuple, str))

        if isinstance(marker, list):
            # Parenthesized sub-expression: evaluate recursively.
            groups[-1].append(_evaluate_markers(marker, environment))
        elif isinstance(marker, tuple):
            lhs, op, rhs = marker

            # Exactly one side is a Variable; resolve it from the environment
            # and remember which key it was so both operands can be
            # normalized consistently (matters for "extra").
            if isinstance(lhs, Variable):
                environment_key = lhs.value
                lhs_value = environment[environment_key]
                rhs_value = rhs.value
            else:
                lhs_value = lhs.value
                environment_key = rhs.value
                rhs_value = environment[environment_key]

            lhs_value, rhs_value = _normalize(lhs_value, rhs_value, key=environment_key)
            groups[-1].append(_eval_op(lhs_value, op, rhs_value))
        else:
            assert marker in ["and", "or"]
            # "and" extends the current group; "or" starts a fresh one.
            if marker == "or":
                groups.append([])

    return any(all(item) for item in groups)
232
+
233
+
234
def format_full_version(info: sys._version_info) -> str:
    """Render a ``sys.version_info``-like object as a full version string.

    Non-final builds get the release-level initial and serial appended,
    e.g. ``(3, 13, 0, 'alpha', 2)`` becomes ``'3.13.0a2'``.
    """
    pieces = [f"{info.major}.{info.minor}.{info.micro}"]
    if info.releaselevel != "final":
        pieces.append(f"{info.releaselevel[0]}{info.serial}")
    return "".join(pieces)
240
+
241
+
242
def default_environment() -> Environment:
    """Return the environment marker values for the running interpreter.

    Every value is computed at call time from ``sys``, ``os`` or
    ``platform``; keys match the :class:`Environment` TypedDict.
    """
    iver = format_full_version(sys.implementation.version)
    implementation_name = sys.implementation.name
    return {
        "implementation_name": implementation_name,
        "implementation_version": iver,
        "os_name": os.name,
        "platform_machine": platform.machine(),
        "platform_release": platform.release(),
        "platform_system": platform.system(),
        "platform_version": platform.version(),
        "python_full_version": platform.python_version(),
        "platform_python_implementation": platform.python_implementation(),
        "python_version": ".".join(platform.python_version_tuple()[:2]),
        "sys_platform": sys.platform,
    }
258
+
259
+
260
class Marker:
    """A parsed PEP 508 environment marker, e.g. ``python_version >= "3.8"``.

    Supports serialization back to a string (``__str__``), equality and
    hashing on the serialized form, and evaluation against an environment.
    """

    def __init__(self, marker: str) -> None:
        # Note: We create a Marker object without calling this constructor in
        # packaging.requirements.Requirement. If any additional logic is
        # added here, make sure to mirror/adapt Requirement.
        try:
            self._markers = _normalize_extra_values(_parse_marker(marker))
            # The attribute `_markers` can be described in terms of a recursive type:
            # MarkerList = List[Union[Tuple[Node, ...], str, MarkerList]]
            #
            # For example, the following expression:
            # python_version > "3.6" or (python_version == "3.6" and os_name == "unix")
            #
            # is parsed into:
            # [
            #     (<Variable('python_version')>, <Op('>')>, <Value('3.6')>),
            #     'or',
            #     [
            #         (<Variable('python_version')>, <Op('==')>, <Value('3.6')>),
            #         'and',
            #         (<Variable('os_name')>, <Op('==')>, <Value('unix')>)
            #     ]
            # ]
            except ParserSyntaxError as e:
            raise InvalidMarker(str(e)) from e

    def __str__(self) -> str:
        return _format_marker(self._markers)

    def __repr__(self) -> str:
        return f"<Marker('{self}')>"

    def __hash__(self) -> int:
        # Include the class name so a Marker never hash-collides by design
        # with a plain string of its serialization.
        return hash((self.__class__.__name__, str(self)))

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Marker):
            return NotImplemented

        # Equality is defined on the canonical serialized form.
        return str(self) == str(other)

    def evaluate(self, environment: dict[str, str] | None = None) -> bool:
        """Evaluate a marker.

        Return the boolean from evaluating the given marker against the
        environment. environment is an optional argument to override all or
        part of the determined environment.

        The environment is determined from the current Python process.
        """
        current_environment = cast("dict[str, str]", default_environment())
        # "extra" always exists so markers referencing it never KeyError.
        current_environment["extra"] = ""
        if environment is not None:
            current_environment.update(environment)
            # The API used to allow setting extra to None. We need to handle this
            # case for backwards compatibility.
            if current_environment["extra"] is None:
                current_environment["extra"] = ""

        return _evaluate_markers(
            self._markers, _repair_python_full_version(current_environment)
        )
322
+
323
+
324
+ def _repair_python_full_version(env: dict[str, str]) -> dict[str, str]:
325
+ """
326
+ Work around platform.python_version() returning something that is not PEP 440
327
+ compliant for non-tagged Python builds.
328
+ """
329
+ if env["python_full_version"].endswith("+"):
330
+ env["python_full_version"] += "local"
331
+ return env
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/metadata.py ADDED
@@ -0,0 +1,863 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import email.feedparser
4
+ import email.header
5
+ import email.message
6
+ import email.parser
7
+ import email.policy
8
+ import pathlib
9
+ import sys
10
+ import typing
11
+ from typing import (
12
+ Any,
13
+ Callable,
14
+ Generic,
15
+ Literal,
16
+ TypedDict,
17
+ cast,
18
+ )
19
+
20
+ from . import licenses, requirements, specifiers, utils
21
+ from . import version as version_module
22
+ from .licenses import NormalizedLicenseExpression
23
+
24
+ T = typing.TypeVar("T")
25
+
26
+
27
if sys.version_info >= (3, 11):  # pragma: no cover
    # Re-export the builtin so this module's name resolves uniformly
    # regardless of Python version.
    ExceptionGroup = ExceptionGroup
else:  # pragma: no cover

    class ExceptionGroup(Exception):
        """A minimal implementation of :external:exc:`ExceptionGroup` from Python 3.11.

        If :external:exc:`ExceptionGroup` is already defined by Python itself,
        that version is used instead.
        """

        # The message describing the group.
        message: str
        # The exceptions collected into this group.
        exceptions: list[Exception]

        def __init__(self, message: str, exceptions: list[Exception]) -> None:
            self.message = message
            self.exceptions = exceptions

        def __repr__(self) -> str:
            return f"{self.__class__.__name__}({self.message!r}, {self.exceptions!r})"
47
+
48
+
49
class InvalidMetadata(ValueError):
    """A metadata field contains invalid data."""

    field: str
    """The name of the field that contains invalid data."""

    def __init__(self, field: str, message: str) -> None:
        self.field = field
        # Only the message becomes the ValueError text; the field name is
        # carried separately on the exception instance.
        super().__init__(message)
58
+
59
+
60
# The RawMetadata class attempts to make as few assumptions about the underlying
# serialization formats as possible. The idea is that as long as a serialization
# formats offer some very basic primitives in *some* way then we can support
# serializing to and from that format.
class RawMetadata(TypedDict, total=False):
    """A dictionary of raw core metadata.

    Each field in core metadata maps to a key of this dictionary (when data is
    provided). The key is lower-case and underscores are used instead of dashes
    compared to the equivalent core metadata field. Any core metadata field that
    can be specified multiple times or can hold multiple values in a single
    field have a key with a plural name. See :class:`Metadata` whose attributes
    match the keys of this dictionary.

    Core metadata fields that can be specified multiple times are stored as a
    list or dict depending on which is appropriate for the field. Any fields
    which hold multiple values in a single field are stored as a list.

    """

    # Metadata 1.0 - PEP 241
    metadata_version: str
    name: str
    version: str
    platforms: list[str]
    summary: str
    description: str
    keywords: list[str]
    home_page: str
    author: str
    author_email: str
    license: str

    # Metadata 1.1 - PEP 314
    supported_platforms: list[str]
    download_url: str
    classifiers: list[str]
    requires: list[str]
    provides: list[str]
    obsoletes: list[str]

    # Metadata 1.2 - PEP 345
    maintainer: str
    maintainer_email: str
    requires_dist: list[str]
    provides_dist: list[str]
    obsoletes_dist: list[str]
    requires_python: str
    requires_external: list[str]
    project_urls: dict[str, str]

    # Metadata 2.0
    # PEP 426 attempted to completely revamp the metadata format
    # but got stuck without ever being able to build consensus on
    # it and ultimately ended up withdrawn.
    #
    # However, a number of tools had started emitting METADATA with
    # `2.0` Metadata-Version, so for historical reasons, this version
    # was skipped.

    # Metadata 2.1 - PEP 566
    description_content_type: str
    provides_extra: list[str]

    # Metadata 2.2 - PEP 643
    dynamic: list[str]

    # Metadata 2.3 - PEP 685
    # No new fields were added in PEP 685, just some edge cases were
    # tightened up to provide better interoperability.

    # Metadata 2.4 - PEP 639
    license_expression: str
    license_files: list[str]
135
+
136
# RawMetadata keys whose raw representation is a single string.
_STRING_FIELDS = {
    "author",
    "author_email",
    "description",
    "description_content_type",
    "download_url",
    "home_page",
    "license",
    "license_expression",
    "maintainer",
    "maintainer_email",
    "metadata_version",
    "name",
    "requires_python",
    "summary",
    "version",
}

# RawMetadata keys that may appear multiple times and collect into a list.
_LIST_FIELDS = {
    "classifiers",
    "dynamic",
    "license_files",
    "obsoletes",
    "obsoletes_dist",
    "platforms",
    "provides",
    "provides_dist",
    "provides_extra",
    "requires",
    "requires_dist",
    "requires_external",
    "supported_platforms",
}

# RawMetadata keys stored as a mapping (label -> value).
_DICT_FIELDS = {
    "project_urls",
}
173
+
174
+
175
+ def _parse_keywords(data: str) -> list[str]:
176
+ """Split a string of comma-separated keywords into a list of keywords."""
177
+ return [k.strip() for k in data.split(",")]
178
+
179
+
180
+ def _parse_project_urls(data: list[str]) -> dict[str, str]:
181
+ """Parse a list of label/URL string pairings separated by a comma."""
182
+ urls = {}
183
+ for pair in data:
184
+ # Our logic is slightly tricky here as we want to try and do
185
+ # *something* reasonable with malformed data.
186
+ #
187
+ # The main thing that we have to worry about, is data that does
188
+ # not have a ',' at all to split the label from the Value. There
189
+ # isn't a singular right answer here, and we will fail validation
190
+ # later on (if the caller is validating) so it doesn't *really*
191
+ # matter, but since the missing value has to be an empty str
192
+ # and our return value is dict[str, str], if we let the key
193
+ # be the missing value, then they'd have multiple '' values that
194
+ # overwrite each other in a accumulating dict.
195
+ #
196
+ # The other potentional issue is that it's possible to have the
197
+ # same label multiple times in the metadata, with no solid "right"
198
+ # answer with what to do in that case. As such, we'll do the only
199
+ # thing we can, which is treat the field as unparseable and add it
200
+ # to our list of unparsed fields.
201
+ parts = [p.strip() for p in pair.split(",", 1)]
202
+ parts.extend([""] * (max(0, 2 - len(parts)))) # Ensure 2 items
203
+
204
+ # TODO: The spec doesn't say anything about if the keys should be
205
+ # considered case sensitive or not... logically they should
206
+ # be case-preserving and case-insensitive, but doing that
207
+ # would open up more cases where we might have duplicate
208
+ # entries.
209
+ label, url = parts
210
+ if label in urls:
211
+ # The label already exists in our set of urls, so this field
212
+ # is unparseable, and we can just add the whole thing to our
213
+ # unparseable data and stop processing it.
214
+ raise KeyError("duplicate labels in project urls")
215
+ urls[label] = url
216
+
217
+ return urls
218
+
219
+
220
def _get_payload(msg: email.message.Message, source: bytes | str) -> str:
    """Get the body of the message.

    Raises ValueError when *source* is bytes and the body is not valid UTF-8.
    """
    # If our source is a str, then our caller has managed encodings for us,
    # and we don't need to deal with it.
    if isinstance(source, str):
        payload = msg.get_payload()
        assert isinstance(payload, str)
        return payload
    # If our source is a bytes, then we're managing the encoding and we need
    # to deal with it.
    else:
        bpayload = msg.get_payload(decode=True)
        assert isinstance(bpayload, bytes)
        try:
            # Strict decoding: a mis-encoded body is surfaced to the caller
            # rather than silently mangled.
            return bpayload.decode("utf8", "strict")
        except UnicodeDecodeError as exc:
            raise ValueError("payload in an invalid encoding") from exc
237
+
238
+
239
# The various parse_FORMAT functions here are intended to be as lenient as
# possible in their parsing, while still returning a correctly typed
# RawMetadata.
#
# To aid in this, we also generally want to do as little touching of the
# data as possible, except where there are possibly some historic holdovers
# that make valid data awkward to work with.
#
# While this is a lower level, intermediate format than our ``Metadata``
# class, some light touch ups can make a massive difference in usability.

# Map METADATA fields to RawMetadata.
_EMAIL_TO_RAW_MAPPING = {
    "author": "author",
    "author-email": "author_email",
    "classifier": "classifiers",
    "description": "description",
    "description-content-type": "description_content_type",
    "download-url": "download_url",
    "dynamic": "dynamic",
    "home-page": "home_page",
    "keywords": "keywords",
    "license": "license",
    "license-expression": "license_expression",
    "license-file": "license_files",
    "maintainer": "maintainer",
    "maintainer-email": "maintainer_email",
    "metadata-version": "metadata_version",
    "name": "name",
    "obsoletes": "obsoletes",
    "obsoletes-dist": "obsoletes_dist",
    "platform": "platforms",
    "project-url": "project_urls",
    "provides": "provides",
    "provides-dist": "provides_dist",
    "provides-extra": "provides_extra",
    "requires": "requires",
    "requires-dist": "requires_dist",
    "requires-external": "requires_external",
    "requires-python": "requires_python",
    "summary": "summary",
    "supported-platform": "supported_platforms",
    "version": "version",
}
# Inverse mapping: RawMetadata key back to its email header name.
_RAW_TO_EMAIL_MAPPING = {raw: email for email, raw in _EMAIL_TO_RAW_MAPPING.items()}
284
+
285
+
286
def parse_email(data: bytes | str) -> tuple[RawMetadata, dict[str, list[str]]]:
    """Parse a distribution's metadata stored as email headers (e.g. from ``METADATA``).

    This function returns a two-item tuple of dicts. The first dict is of
    recognized fields from the core metadata specification. Fields that can be
    parsed and translated into Python's built-in types are converted
    appropriately. All other fields are left as-is. Fields that are allowed to
    appear multiple times are stored as lists.

    The second dict contains all other fields from the metadata. This includes
    any unrecognized fields. It also includes any fields which are expected to
    be parsed into a built-in type but were not formatted appropriately. Finally,
    any fields that are expected to appear only once but are repeated are
    included in this dict.

    """
    raw: dict[str, str | list[str] | dict[str, str]] = {}
    unparsed: dict[str, list[str]] = {}

    if isinstance(data, str):
        parsed = email.parser.Parser(policy=email.policy.compat32).parsestr(data)
    else:
        parsed = email.parser.BytesParser(policy=email.policy.compat32).parsebytes(data)

    # We have to wrap parsed.keys() in a set, because in the case of multiple
    # values for a key (a list), the key will appear multiple times in the
    # list of keys, but we're avoiding that by using get_all().
    for name in frozenset(parsed.keys()):
        # Header names in RFC are case insensitive, so we'll normalize to all
        # lower case to make comparisons easier.
        name = name.lower()

        # We use get_all() here, even for fields that aren't multiple use,
        # because otherwise someone could have e.g. two Name fields, and we
        # would just silently ignore it rather than doing something about it.
        headers = parsed.get_all(name) or []

        # The way the email module works when parsing bytes is that it
        # unconditionally decodes the bytes as ascii using the surrogateescape
        # handler. When you pull that data back out (such as with get_all() ),
        # it looks to see if the str has any surrogate escapes, and if it does
        # it wraps it in a Header object instead of returning the string.
        #
        # As such, we'll look for those Header objects, and fix up the encoding.
        value = []
        # Flag if we have run into any issues processing the headers, thus
        # signalling that the data belongs in 'unparsed'.
        valid_encoding = True
        for h in headers:
            # It's unclear if this can return more types than just a Header or
            # a str, so we'll just assert here to make sure.
            assert isinstance(h, (email.header.Header, str))

            # If it's a header object, we need to do our little dance to get
            # the real data out of it. In cases where there is invalid data
            # we're going to end up with mojibake, but there's no obvious, good
            # way around that without reimplementing parts of the Header object
            # ourselves.
            #
            # That should be fine since, if mojibake happens, this key is
            # going into the unparsed dict anyways.
            if isinstance(h, email.header.Header):
                # The Header object stores its data as chunks, and each chunk
                # can be independently encoded, so we'll need to check each
                # of them.
                chunks: list[tuple[bytes, str | None]] = []
                for bin, encoding in email.header.decode_header(h):
                    try:
                        bin.decode("utf8", "strict")
                    except UnicodeDecodeError:
                        # Enable mojibake.
                        encoding = "latin1"
                        valid_encoding = False
                    else:
                        encoding = "utf8"
                    chunks.append((bin, encoding))

                # Turn our chunks back into a Header object, then let that
                # Header object do the right thing to turn them into a
                # string for us.
                value.append(str(email.header.make_header(chunks)))
            # This is already a string, so just add it.
            else:
                value.append(h)

        # We've processed all of our values to get them into a list of str,
        # but we may have mojibake data, in which case this is an unparsed
        # field.
        if not valid_encoding:
            unparsed[name] = value
            continue

        raw_name = _EMAIL_TO_RAW_MAPPING.get(name)
        if raw_name is None:
            # This is a bit of a weird situation, we've encountered a key that
            # we don't know what it means, so we don't know whether it's meant
            # to be a list or not.
            #
            # Since we can't really tell one way or another, we'll just leave it
            # as a list, even though it may be a single item list, because that's
            # what makes the most sense for email headers.
            unparsed[name] = value
            continue

        # If this is one of our string fields, then we'll check to see if our
        # value is a list of a single item. If it is then we'll assume that
        # it was emitted as a single string, and unwrap the str from inside
        # the list.
        #
        # If it's any other kind of data, then we haven't the faintest clue
        # what we should parse it as, and we have to just add it to our list
        # of unparsed stuff.
        if raw_name in _STRING_FIELDS and len(value) == 1:
            raw[raw_name] = value[0]
        # If this is one of our list of string fields, then we can just assign
        # the value, since email *only* has strings, and our get_all() call
        # above ensures that this is a list.
        elif raw_name in _LIST_FIELDS:
            raw[raw_name] = value
        # Special Case: Keywords
        # The keywords field is implemented in the metadata spec as a str,
        # but it conceptually is a list of strings, and is serialized using
        # ", ".join(keywords), so we'll do some light data massaging to turn
        # this into what it logically is.
        elif raw_name == "keywords" and len(value) == 1:
            raw[raw_name] = _parse_keywords(value[0])
        # Special Case: Project-URL
        # The project urls is implemented in the metadata spec as a list of
        # specially-formatted strings that represent a key and a value, which
        # is fundamentally a mapping, however the email format doesn't support
        # mappings in a sane way, so it was crammed into a list of strings
        # instead.
        #
        # We will do a little light data massaging to turn this into a map as
        # it logically should be.
        elif raw_name == "project_urls":
            try:
                raw[raw_name] = _parse_project_urls(value)
            except KeyError:
                # Duplicate labels — treat the whole field as unparseable.
                unparsed[name] = value
        # Nothing that we've done has managed to parse this, so it'll just
        # throw it in our unparseable data and move on.
        else:
            unparsed[name] = value

    # We need to support getting the Description from the message payload in
    # addition to getting it from the headers. This does mean, though, there
    # is the possibility of it being set both ways, in which case we put both
    # in 'unparsed' since we don't know which is right.
    try:
        payload = _get_payload(parsed, data)
    except ValueError:
        unparsed.setdefault("description", []).append(
            parsed.get_payload(decode=isinstance(data, bytes))  # type: ignore[call-overload]
        )
    else:
        if payload:
            # Check to see if we've already got a description, if so then both
            # it, and this body move to unparseable.
            if "description" in raw:
                description_header = cast(str, raw.pop("description"))
                unparsed.setdefault("description", []).extend(
                    [description_header, payload]
                )
            elif "description" in unparsed:
                unparsed["description"].append(payload)
            else:
                raw["description"] = payload

    # We need to cast our `raw` to a metadata, because a TypedDict only support
    # literal key names, but we're computing our key names on purpose, but the
    # way this function is implemented, our `TypedDict` can only have valid key
    # names.
    return cast(RawMetadata, raw), unparsed
460
+
461
+
462
# Module-level sentinel; distinct from None so None stays a valid field value.
_NOT_FOUND = object()


# Keep the two values in sync.
_VALID_METADATA_VERSIONS = ["1.0", "1.1", "1.2", "2.1", "2.2", "2.3", "2.4"]
_MetadataVersion = Literal["1.0", "1.1", "1.2", "2.1", "2.2", "2.3", "2.4"]

# Attributes that must be present for metadata to be considered valid.
_REQUIRED_ATTRS = frozenset(["metadata_version", "name", "version"])
470
+
471
+
472
class _Validator(Generic[T]):
    """Validate a metadata field.

    All _process_*() methods correspond to a core metadata field. The method is
    called with the field's raw value. If the raw value is valid it is returned
    in its "enriched" form (e.g. ``version.Version`` for the ``Version`` field).
    If the raw value is invalid, :exc:`InvalidMetadata` is raised (with a cause
    as appropriate).
    """

    # Attribute name on the owning Metadata class (set by __set_name__).
    name: str
    # Equivalent email-header / core-metadata field name, used in errors.
    raw_name: str
    # Metadata version in which this field was introduced.
    added: _MetadataVersion

    def __init__(
        self,
        *,
        added: _MetadataVersion = "1.0",
    ) -> None:
        self.added = added

    def __set_name__(self, _owner: Metadata, name: str) -> None:
        # Record the attribute name and derive the raw (email-header) name.
        self.name = name
        self.raw_name = _RAW_TO_EMAIL_MAPPING[name]

    def __get__(self, instance: Metadata, _owner: type[Metadata]) -> T:
        """Enrich, cache, and return the field's value for *instance*.

        The enriched value is stored in the instance __dict__ and the raw
        entry is removed, so subsequent lookups bypass this descriptor.
        """
        # With Python 3.8, the caching can be replaced with functools.cached_property().
        # No need to check the cache as attribute lookup will resolve into the
        # instance's __dict__ before __get__ is called.
        cache = instance.__dict__
        value = instance._raw.get(self.name)

        # To make the _process_* methods easier, we'll check if the value is None
        # and if this field is NOT a required attribute, and if both of those
        # things are true, we'll skip the converter. This will mean that the
        # converters never have to deal with the None union.
        if self.name in _REQUIRED_ATTRS or value is not None:
            try:
                converter: Callable[[Any], T] = getattr(self, f"_process_{self.name}")
            except AttributeError:
                # No converter for this field; the raw value is used as-is.
                pass
            else:
                value = converter(value)

        # Cache the (possibly enriched) value, then drop the raw entry so the
        # instance attribute shadows this descriptor from now on.
        cache[self.name] = value
        try:
            del instance._raw[self.name]  # type: ignore[misc]
        except KeyError:
            pass

        return cast(T, value)

    def _invalid_metadata(
        self, msg: str, cause: Exception | None = None
    ) -> InvalidMetadata:
        """Build an InvalidMetadata for this field, expanding ``{field}`` in *msg*."""
        exc = InvalidMetadata(
            self.raw_name, msg.format_map({"field": repr(self.raw_name)})
        )
        exc.__cause__ = cause
        return exc

    def _process_metadata_version(self, value: str) -> _MetadataVersion:
        # Implicitly makes Metadata-Version required.
        if value not in _VALID_METADATA_VERSIONS:
            raise self._invalid_metadata(f"{value!r} is not a valid metadata version")
        return cast(_MetadataVersion, value)

    def _process_name(self, value: str) -> str:
        """Validate the project name; the original (non-normalized) name is kept."""
        if not value:
            raise self._invalid_metadata("{field} is a required field")
        # Validate the name as a side-effect.
        try:
            utils.canonicalize_name(value, validate=True)
        except utils.InvalidName as exc:
            raise self._invalid_metadata(
                f"{value!r} is invalid for {{field}}", cause=exc
            ) from exc
        else:
            return value

    def _process_version(self, value: str) -> version_module.Version:
        """Parse the version string into a Version object."""
        if not value:
            raise self._invalid_metadata("{field} is a required field")
        try:
            return version_module.parse(value)
        except version_module.InvalidVersion as exc:
            raise self._invalid_metadata(
                f"{value!r} is invalid for {{field}}", cause=exc
            ) from exc

    def _process_summary(self, value: str) -> str:
        """Check the field contains no newlines."""
        if "\n" in value:
            raise self._invalid_metadata("{field} must be a single line")
        return value

    def _process_description_content_type(self, value: str) -> str:
        """Validate the content type, charset, and Markdown variant parameters."""
        content_types = {"text/plain", "text/x-rst", "text/markdown"}
        # Parse the value through an EmailMessage to reuse its header parsing.
        message = email.message.EmailMessage()
        message["content-type"] = value

        content_type, parameters = (
            # Defaults to `text/plain` if parsing failed.
            message.get_content_type().lower(),
            message["content-type"].params,
        )
        # Check if content-type is valid or defaulted to `text/plain` and thus was
        # not parseable.
        if content_type not in content_types or content_type not in value.lower():
            raise self._invalid_metadata(
                f"{{field}} must be one of {list(content_types)}, not {value!r}"
            )

        charset = parameters.get("charset", "UTF-8")
        if charset != "UTF-8":
            # NOTE(review): list(charset) renders the charset as a list of
            # characters in the message — presumably {charset!r} was intended;
            # kept as-is since the message text is runtime behavior.
            raise self._invalid_metadata(
                f"{{field}} can only specify the UTF-8 charset, not {list(charset)}"
            )

        markdown_variants = {"GFM", "CommonMark"}
        variant = parameters.get("variant", "GFM")  # Use an acceptable default.
        if content_type == "text/markdown" and variant not in markdown_variants:
            raise self._invalid_metadata(
                f"valid Markdown variants for {{field}} are {list(markdown_variants)}, "
                f"not {variant!r}",
            )
        return value

    def _process_dynamic(self, value: list[str]) -> list[str]:
        """Validate the Dynamic field names and return them lowercased."""
        for dynamic_field in map(str.lower, value):
            # These fields can never be declared dynamic.
            if dynamic_field in {"name", "version", "metadata-version"}:
                raise self._invalid_metadata(
                    f"{dynamic_field!r} is not allowed as a dynamic field"
                )
            elif dynamic_field not in _EMAIL_TO_RAW_MAPPING:
                raise self._invalid_metadata(
                    f"{dynamic_field!r} is not a valid dynamic field"
                )
        return list(map(str.lower, value))

    def _process_provides_extra(
        self,
        value: list[str],
    ) -> list[utils.NormalizedName]:
        """Normalize extra names, failing on the first invalid one."""
        normalized_names = []
        try:
            for name in value:
                normalized_names.append(utils.canonicalize_name(name, validate=True))
        except utils.InvalidName as exc:
            # `name` is the entry that failed validation.
            raise self._invalid_metadata(
                f"{name!r} is invalid for {{field}}", cause=exc
            ) from exc
        else:
            return normalized_names

    def _process_requires_python(self, value: str) -> specifiers.SpecifierSet:
        """Parse Requires-Python into a SpecifierSet."""
        try:
            return specifiers.SpecifierSet(value)
        except specifiers.InvalidSpecifier as exc:
            raise self._invalid_metadata(
                f"{value!r} is invalid for {{field}}", cause=exc
            ) from exc

    def _process_requires_dist(
        self,
        value: list[str],
    ) -> list[requirements.Requirement]:
        """Parse each Requires-Dist entry, failing on the first invalid one."""
        reqs = []
        try:
            for req in value:
                reqs.append(requirements.Requirement(req))
        except requirements.InvalidRequirement as exc:
            # `req` is the entry that failed parsing.
            raise self._invalid_metadata(
                f"{req!r} is invalid for {{field}}", cause=exc
            ) from exc
        else:
            return reqs

    def _process_license_expression(
        self, value: str
    ) -> NormalizedLicenseExpression | None:
        """Canonicalize an SPDX license expression."""
        try:
            return licenses.canonicalize_license_expression(value)
        except ValueError as exc:
            raise self._invalid_metadata(
                f"{value!r} is invalid for {{field}}", cause=exc
            ) from exc

    def _process_license_files(self, value: list[str]) -> list[str]:
        """Validate License-File entries: relative, resolved, '/'-delimited paths."""
        paths = []
        for path in value:
            if ".." in path:
                raise self._invalid_metadata(
                    f"{path!r} is invalid for {{field}}, "
                    "parent directory indicators are not allowed"
                )
            if "*" in path:
                raise self._invalid_metadata(
                    f"{path!r} is invalid for {{field}}, paths must be resolved"
                )
            # Reject paths that are absolute on either POSIX or Windows.
            if (
                pathlib.PurePosixPath(path).is_absolute()
                or pathlib.PureWindowsPath(path).is_absolute()
            ):
                raise self._invalid_metadata(
                    f"{path!r} is invalid for {{field}}, paths must be relative"
                )
            # A round-trip through PureWindowsPath.as_posix() changes the
            # string iff it contained backslash separators.
            if pathlib.PureWindowsPath(path).as_posix() != path:
                raise self._invalid_metadata(
                    f"{path!r} is invalid for {{field}}, "
                    "paths must use '/' delimiter"
                )
            paths.append(path)
        return paths
686
+
687
+
688
class Metadata:
    """Representation of distribution metadata.

    Compared to :class:`RawMetadata`, this class provides objects representing
    metadata fields instead of only using built-in types. Any invalid metadata
    will cause :exc:`InvalidMetadata` to be raised (with a
    :py:attr:`~BaseException.__cause__` attribute as appropriate).
    """

    # Backing store of not-yet-enriched values; entries migrate into the
    # instance __dict__ as the _Validator descriptors are accessed.
    _raw: RawMetadata

    @classmethod
    def from_raw(cls, data: RawMetadata, *, validate: bool = True) -> Metadata:
        """Create an instance from :class:`RawMetadata`.

        If *validate* is true, all metadata will be validated. All exceptions
        related to validation will be gathered and raised as an :class:`ExceptionGroup`.
        """
        ins = cls()
        ins._raw = data.copy()  # Mutations occur due to caching enriched values.

        if validate:
            exceptions: list[Exception] = []
            try:
                metadata_version = ins.metadata_version
                # Position in _VALID_METADATA_VERSIONS orders versions by age.
                metadata_age = _VALID_METADATA_VERSIONS.index(metadata_version)
            except InvalidMetadata as metadata_version_exc:
                exceptions.append(metadata_version_exc)
                metadata_version = None

            # Make sure to check for the fields that are present, the required
            # fields (so their absence can be reported).
            fields_to_check = frozenset(ins._raw) | _REQUIRED_ATTRS
            # Remove fields that have already been checked.
            fields_to_check -= {"metadata_version"}

            for key in fields_to_check:
                try:
                    if metadata_version:
                        # Can't use getattr() as that triggers descriptor protocol which
                        # will fail due to no value for the instance argument.
                        try:
                            field_metadata_version = cls.__dict__[key].added
                        except KeyError:
                            exc = InvalidMetadata(key, f"unrecognized field: {key!r}")
                            exceptions.append(exc)
                            continue
                        field_age = _VALID_METADATA_VERSIONS.index(
                            field_metadata_version
                        )
                        # A field newer than the declared metadata version is
                        # invalid for this distribution.
                        if field_age > metadata_age:
                            field = _RAW_TO_EMAIL_MAPPING[key]
                            exc = InvalidMetadata(
                                field,
                                f"{field} introduced in metadata version "
                                f"{field_metadata_version}, not {metadata_version}",
                            )
                            exceptions.append(exc)
                            continue
                    # Attribute access runs the _Validator descriptor, which
                    # raises InvalidMetadata for bad values.
                    getattr(ins, key)
                except InvalidMetadata as exc:
                    exceptions.append(exc)

            if exceptions:
                raise ExceptionGroup("invalid metadata", exceptions)

        return ins

    @classmethod
    def from_email(cls, data: bytes | str, *, validate: bool = True) -> Metadata:
        """Parse metadata from email headers.

        If *validate* is true, the metadata will be validated. All exceptions
        related to validation will be gathered and raised as an :class:`ExceptionGroup`.
        """
        raw, unparsed = parse_email(data)

        if validate:
            exceptions: list[Exception] = []
            for unparsed_key in unparsed:
                # Distinguish known-but-unparsable fields from unknown ones.
                if unparsed_key in _EMAIL_TO_RAW_MAPPING:
                    message = f"{unparsed_key!r} has invalid data"
                else:
                    message = f"unrecognized field: {unparsed_key!r}"
                exceptions.append(InvalidMetadata(unparsed_key, message))

            if exceptions:
                raise ExceptionGroup("unparsed", exceptions)

        try:
            return cls.from_raw(raw, validate=validate)
        except ExceptionGroup as exc_group:
            # Re-raise under a combined banner, suppressing the chained context.
            raise ExceptionGroup(
                "invalid or unparsed metadata", exc_group.exceptions
            ) from None

    metadata_version: _Validator[_MetadataVersion] = _Validator()
    """:external:ref:`core-metadata-metadata-version`
    (required; validated to be a valid metadata version)"""
    # `name` is not normalized/typed to NormalizedName so as to provide access to
    # the original/raw name.
    name: _Validator[str] = _Validator()
    """:external:ref:`core-metadata-name`
    (required; validated using :func:`~packaging.utils.canonicalize_name` and its
    *validate* parameter)"""
    version: _Validator[version_module.Version] = _Validator()
    """:external:ref:`core-metadata-version` (required)"""
    dynamic: _Validator[list[str] | None] = _Validator(
        added="2.2",
    )
    """:external:ref:`core-metadata-dynamic`
    (validated against core metadata field names and lowercased)"""
    platforms: _Validator[list[str] | None] = _Validator()
    """:external:ref:`core-metadata-platform`"""
    supported_platforms: _Validator[list[str] | None] = _Validator(added="1.1")
    """:external:ref:`core-metadata-supported-platform`"""
    summary: _Validator[str | None] = _Validator()
    """:external:ref:`core-metadata-summary` (validated to contain no newlines)"""
    description: _Validator[str | None] = _Validator()  # TODO 2.1: can be in body
    """:external:ref:`core-metadata-description`"""
    description_content_type: _Validator[str | None] = _Validator(added="2.1")
    """:external:ref:`core-metadata-description-content-type` (validated)"""
    keywords: _Validator[list[str] | None] = _Validator()
    """:external:ref:`core-metadata-keywords`"""
    home_page: _Validator[str | None] = _Validator()
    """:external:ref:`core-metadata-home-page`"""
    download_url: _Validator[str | None] = _Validator(added="1.1")
    """:external:ref:`core-metadata-download-url`"""
    author: _Validator[str | None] = _Validator()
    """:external:ref:`core-metadata-author`"""
    author_email: _Validator[str | None] = _Validator()
    """:external:ref:`core-metadata-author-email`"""
    maintainer: _Validator[str | None] = _Validator(added="1.2")
    """:external:ref:`core-metadata-maintainer`"""
    maintainer_email: _Validator[str | None] = _Validator(added="1.2")
    """:external:ref:`core-metadata-maintainer-email`"""
    license: _Validator[str | None] = _Validator()
    """:external:ref:`core-metadata-license`"""
    license_expression: _Validator[NormalizedLicenseExpression | None] = _Validator(
        added="2.4"
    )
    """:external:ref:`core-metadata-license-expression`"""
    license_files: _Validator[list[str] | None] = _Validator(added="2.4")
    """:external:ref:`core-metadata-license-file`"""
    classifiers: _Validator[list[str] | None] = _Validator(added="1.1")
    """:external:ref:`core-metadata-classifier`"""
    requires_dist: _Validator[list[requirements.Requirement] | None] = _Validator(
        added="1.2"
    )
    """:external:ref:`core-metadata-requires-dist`"""
    requires_python: _Validator[specifiers.SpecifierSet | None] = _Validator(
        added="1.2"
    )
    """:external:ref:`core-metadata-requires-python`"""
    # Because `Requires-External` allows for non-PEP 440 version specifiers, we
    # don't do any processing on the values.
    requires_external: _Validator[list[str] | None] = _Validator(added="1.2")
    """:external:ref:`core-metadata-requires-external`"""
    project_urls: _Validator[dict[str, str] | None] = _Validator(added="1.2")
    """:external:ref:`core-metadata-project-url`"""
    # PEP 685 lets us raise an error if an extra doesn't pass `Name` validation
    # regardless of metadata version.
    provides_extra: _Validator[list[utils.NormalizedName] | None] = _Validator(
        added="2.1",
    )
    """:external:ref:`core-metadata-provides-extra`"""
    provides_dist: _Validator[list[str] | None] = _Validator(added="1.2")
    """:external:ref:`core-metadata-provides-dist`"""
    obsoletes_dist: _Validator[list[str] | None] = _Validator(added="1.2")
    """:external:ref:`core-metadata-obsoletes-dist`"""
    requires: _Validator[list[str] | None] = _Validator(added="1.1")
    """``Requires`` (deprecated)"""
    provides: _Validator[list[str] | None] = _Validator(added="1.1")
    """``Provides`` (deprecated)"""
    obsoletes: _Validator[list[str] | None] = _Validator(added="1.1")
    """``Obsoletes`` (deprecated)"""
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/py.typed ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/specifiers.py ADDED
@@ -0,0 +1,1020 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is dual licensed under the terms of the Apache License, Version
2
+ # 2.0, and the BSD License. See the LICENSE file in the root of this repository
3
+ # for complete details.
4
+ """
5
+ .. testsetup::
6
+
7
+ from packaging.specifiers import Specifier, SpecifierSet, InvalidSpecifier
8
+ from packaging.version import Version
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import abc
14
+ import itertools
15
+ import re
16
+ from typing import Callable, Iterable, Iterator, TypeVar, Union
17
+
18
+ from .utils import canonicalize_version
19
+ from .version import Version
20
+
21
# A version supplied either already parsed or as its raw string form.
UnparsedVersion = Union[Version, str]
# TypeVar that preserves the caller's concrete element type through filter().
UnparsedVersionVar = TypeVar("UnparsedVersionVar", bound=UnparsedVersion)
# Signature shared by the Specifier._compare_* operator implementations.
CallableOperator = Callable[[Version, str], bool]
24
+
25
+
26
def _coerce_version(version: UnparsedVersion) -> Version:
    """Return *version* as a :class:`Version`, parsing it when given a string."""
    if isinstance(version, Version):
        return version
    return Version(version)
30
+
31
+
32
class InvalidSpecifier(ValueError):
    """
    Raised when attempting to create a :class:`Specifier` with a specifier
    string that is invalid.

    >>> Specifier("lolwat")
    Traceback (most recent call last):
        ...
    packaging.specifiers.InvalidSpecifier: Invalid specifier: 'lolwat'

    Subclasses :class:`ValueError`, so callers may catch either type.
    """
42
+
43
+
44
class BaseSpecifier(metaclass=abc.ABCMeta):
    """Abstract base class defining the Specifier-like interface."""

    @abc.abstractmethod
    def __str__(self) -> str:
        """
        Returns the str representation of this Specifier-like object. This
        should be representative of the Specifier itself.
        """

    @abc.abstractmethod
    def __hash__(self) -> int:
        """
        Returns a hash value for this Specifier-like object.
        """

    @abc.abstractmethod
    def __eq__(self, other: object) -> bool:
        """
        Returns a boolean representing whether or not the two Specifier-like
        objects are equal.

        :param other: The other object to check against.
        """

    @property
    @abc.abstractmethod
    def prereleases(self) -> bool | None:
        """Whether or not pre-releases as a whole are allowed.

        This can be set to either ``True`` or ``False`` to explicitly enable or disable
        prereleases or it can be set to ``None`` (the default) to use default semantics.
        """

    # Note: the setter is not abstract; only the getter must be overridden.
    @prereleases.setter
    def prereleases(self, value: bool) -> None:
        """Setter for :attr:`prereleases`.

        :param value: The value to set.
        """

    @abc.abstractmethod
    def contains(self, item: str, prereleases: bool | None = None) -> bool:
        """
        Determines if the given item is contained within this specifier.
        """

    @abc.abstractmethod
    def filter(
        self, iterable: Iterable[UnparsedVersionVar], prereleases: bool | None = None
    ) -> Iterator[UnparsedVersionVar]:
        """
        Takes an iterable of items and filters them so that only items which
        are contained within this specifier are allowed in it.
        """
97
+
98
+
99
+ class Specifier(BaseSpecifier):
100
+ """This class abstracts handling of version specifiers.
101
+
102
+ .. tip::
103
+
104
+ It is generally not required to instantiate this manually. You should instead
105
+ prefer to work with :class:`SpecifierSet` instead, which can parse
106
+ comma-separated version specifiers (which is what package metadata contains).
107
+ """
108
+
109
+ _operator_regex_str = r"""
110
+ (?P<operator>(~=|==|!=|<=|>=|<|>|===))
111
+ """
112
+ _version_regex_str = r"""
113
+ (?P<version>
114
+ (?:
115
+ # The identity operators allow for an escape hatch that will
116
+ # do an exact string match of the version you wish to install.
117
+ # This will not be parsed by PEP 440 and we cannot determine
118
+ # any semantic meaning from it. This operator is discouraged
119
+ # but included entirely as an escape hatch.
120
+ (?<====) # Only match for the identity operator
121
+ \s*
122
+ [^\s;)]* # The arbitrary version can be just about anything,
123
+ # we match everything except for whitespace, a
124
+ # semi-colon for marker support, and a closing paren
125
+ # since versions can be enclosed in them.
126
+ )
127
+ |
128
+ (?:
129
+ # The (non)equality operators allow for wild card and local
130
+ # versions to be specified so we have to define these two
131
+ # operators separately to enable that.
132
+ (?<===|!=) # Only match for equals and not equals
133
+
134
+ \s*
135
+ v?
136
+ (?:[0-9]+!)? # epoch
137
+ [0-9]+(?:\.[0-9]+)* # release
138
+
139
+ # You cannot use a wild card and a pre-release, post-release, a dev or
140
+ # local version together so group them with a | and make them optional.
141
+ (?:
142
+ \.\* # Wild card syntax of .*
143
+ |
144
+ (?: # pre release
145
+ [-_\.]?
146
+ (alpha|beta|preview|pre|a|b|c|rc)
147
+ [-_\.]?
148
+ [0-9]*
149
+ )?
150
+ (?: # post release
151
+ (?:-[0-9]+)|(?:[-_\.]?(post|rev|r)[-_\.]?[0-9]*)
152
+ )?
153
+ (?:[-_\.]?dev[-_\.]?[0-9]*)? # dev release
154
+ (?:\+[a-z0-9]+(?:[-_\.][a-z0-9]+)*)? # local
155
+ )?
156
+ )
157
+ |
158
+ (?:
159
+ # The compatible operator requires at least two digits in the
160
+ # release segment.
161
+ (?<=~=) # Only match for the compatible operator
162
+
163
+ \s*
164
+ v?
165
+ (?:[0-9]+!)? # epoch
166
+ [0-9]+(?:\.[0-9]+)+ # release (We have a + instead of a *)
167
+ (?: # pre release
168
+ [-_\.]?
169
+ (alpha|beta|preview|pre|a|b|c|rc)
170
+ [-_\.]?
171
+ [0-9]*
172
+ )?
173
+ (?: # post release
174
+ (?:-[0-9]+)|(?:[-_\.]?(post|rev|r)[-_\.]?[0-9]*)
175
+ )?
176
+ (?:[-_\.]?dev[-_\.]?[0-9]*)? # dev release
177
+ )
178
+ |
179
+ (?:
180
+ # All other operators only allow a sub set of what the
181
+ # (non)equality operators do. Specifically they do not allow
182
+ # local versions to be specified nor do they allow the prefix
183
+ # matching wild cards.
184
+ (?<!==|!=|~=) # We have special cases for these
185
+ # operators so we want to make sure they
186
+ # don't match here.
187
+
188
+ \s*
189
+ v?
190
+ (?:[0-9]+!)? # epoch
191
+ [0-9]+(?:\.[0-9]+)* # release
192
+ (?: # pre release
193
+ [-_\.]?
194
+ (alpha|beta|preview|pre|a|b|c|rc)
195
+ [-_\.]?
196
+ [0-9]*
197
+ )?
198
+ (?: # post release
199
+ (?:-[0-9]+)|(?:[-_\.]?(post|rev|r)[-_\.]?[0-9]*)
200
+ )?
201
+ (?:[-_\.]?dev[-_\.]?[0-9]*)? # dev release
202
+ )
203
+ )
204
+ """
205
+
206
+ _regex = re.compile(
207
+ r"^\s*" + _operator_regex_str + _version_regex_str + r"\s*$",
208
+ re.VERBOSE | re.IGNORECASE,
209
+ )
210
+
211
+ _operators = {
212
+ "~=": "compatible",
213
+ "==": "equal",
214
+ "!=": "not_equal",
215
+ "<=": "less_than_equal",
216
+ ">=": "greater_than_equal",
217
+ "<": "less_than",
218
+ ">": "greater_than",
219
+ "===": "arbitrary",
220
+ }
221
+
222
+ def __init__(self, spec: str = "", prereleases: bool | None = None) -> None:
223
+ """Initialize a Specifier instance.
224
+
225
+ :param spec:
226
+ The string representation of a specifier which will be parsed and
227
+ normalized before use.
228
+ :param prereleases:
229
+ This tells the specifier if it should accept prerelease versions if
230
+ applicable or not. The default of ``None`` will autodetect it from the
231
+ given specifiers.
232
+ :raises InvalidSpecifier:
233
+ If the given specifier is invalid (i.e. bad syntax).
234
+ """
235
+ match = self._regex.search(spec)
236
+ if not match:
237
+ raise InvalidSpecifier(f"Invalid specifier: {spec!r}")
238
+
239
+ self._spec: tuple[str, str] = (
240
+ match.group("operator").strip(),
241
+ match.group("version").strip(),
242
+ )
243
+
244
+ # Store whether or not this Specifier should accept prereleases
245
+ self._prereleases = prereleases
246
+
247
+ # https://github.com/python/mypy/pull/13475#pullrequestreview-1079784515
248
+ @property # type: ignore[override]
249
+ def prereleases(self) -> bool:
250
+ # If there is an explicit prereleases set for this, then we'll just
251
+ # blindly use that.
252
+ if self._prereleases is not None:
253
+ return self._prereleases
254
+
255
+ # Look at all of our specifiers and determine if they are inclusive
256
+ # operators, and if they are if they are including an explicit
257
+ # prerelease.
258
+ operator, version = self._spec
259
+ if operator in ["==", ">=", "<=", "~=", "===", ">", "<"]:
260
+ # The == specifier can include a trailing .*, if it does we
261
+ # want to remove before parsing.
262
+ if operator == "==" and version.endswith(".*"):
263
+ version = version[:-2]
264
+
265
+ # Parse the version, and if it is a pre-release than this
266
+ # specifier allows pre-releases.
267
+ if Version(version).is_prerelease:
268
+ return True
269
+
270
+ return False
271
+
272
+ @prereleases.setter
273
+ def prereleases(self, value: bool) -> None:
274
+ self._prereleases = value
275
+
276
+ @property
277
+ def operator(self) -> str:
278
+ """The operator of this specifier.
279
+
280
+ >>> Specifier("==1.2.3").operator
281
+ '=='
282
+ """
283
+ return self._spec[0]
284
+
285
+ @property
286
+ def version(self) -> str:
287
+ """The version of this specifier.
288
+
289
+ >>> Specifier("==1.2.3").version
290
+ '1.2.3'
291
+ """
292
+ return self._spec[1]
293
+
294
+ def __repr__(self) -> str:
295
+ """A representation of the Specifier that shows all internal state.
296
+
297
+ >>> Specifier('>=1.0.0')
298
+ <Specifier('>=1.0.0')>
299
+ >>> Specifier('>=1.0.0', prereleases=False)
300
+ <Specifier('>=1.0.0', prereleases=False)>
301
+ >>> Specifier('>=1.0.0', prereleases=True)
302
+ <Specifier('>=1.0.0', prereleases=True)>
303
+ """
304
+ pre = (
305
+ f", prereleases={self.prereleases!r}"
306
+ if self._prereleases is not None
307
+ else ""
308
+ )
309
+
310
+ return f"<{self.__class__.__name__}({str(self)!r}{pre})>"
311
+
312
+ def __str__(self) -> str:
313
+ """A string representation of the Specifier that can be round-tripped.
314
+
315
+ >>> str(Specifier('>=1.0.0'))
316
+ '>=1.0.0'
317
+ >>> str(Specifier('>=1.0.0', prereleases=False))
318
+ '>=1.0.0'
319
+ """
320
+ return "{}{}".format(*self._spec)
321
+
322
+ @property
323
+ def _canonical_spec(self) -> tuple[str, str]:
324
+ canonical_version = canonicalize_version(
325
+ self._spec[1],
326
+ strip_trailing_zero=(self._spec[0] != "~="),
327
+ )
328
+ return self._spec[0], canonical_version
329
+
330
+ def __hash__(self) -> int:
331
+ return hash(self._canonical_spec)
332
+
333
+ def __eq__(self, other: object) -> bool:
334
+ """Whether or not the two Specifier-like objects are equal.
335
+
336
+ :param other: The other object to check against.
337
+
338
+ The value of :attr:`prereleases` is ignored.
339
+
340
+ >>> Specifier("==1.2.3") == Specifier("== 1.2.3.0")
341
+ True
342
+ >>> (Specifier("==1.2.3", prereleases=False) ==
343
+ ... Specifier("==1.2.3", prereleases=True))
344
+ True
345
+ >>> Specifier("==1.2.3") == "==1.2.3"
346
+ True
347
+ >>> Specifier("==1.2.3") == Specifier("==1.2.4")
348
+ False
349
+ >>> Specifier("==1.2.3") == Specifier("~=1.2.3")
350
+ False
351
+ """
352
+ if isinstance(other, str):
353
+ try:
354
+ other = self.__class__(str(other))
355
+ except InvalidSpecifier:
356
+ return NotImplemented
357
+ elif not isinstance(other, self.__class__):
358
+ return NotImplemented
359
+
360
+ return self._canonical_spec == other._canonical_spec
361
+
362
+ def _get_operator(self, op: str) -> CallableOperator:
363
+ operator_callable: CallableOperator = getattr(
364
+ self, f"_compare_{self._operators[op]}"
365
+ )
366
+ return operator_callable
367
+
368
+ def _compare_compatible(self, prospective: Version, spec: str) -> bool:
369
+ # Compatible releases have an equivalent combination of >= and ==. That
370
+ # is that ~=2.2 is equivalent to >=2.2,==2.*. This allows us to
371
+ # implement this in terms of the other specifiers instead of
372
+ # implementing it ourselves. The only thing we need to do is construct
373
+ # the other specifiers.
374
+
375
+ # We want everything but the last item in the version, but we want to
376
+ # ignore suffix segments.
377
+ prefix = _version_join(
378
+ list(itertools.takewhile(_is_not_suffix, _version_split(spec)))[:-1]
379
+ )
380
+
381
+ # Add the prefix notation to the end of our string
382
+ prefix += ".*"
383
+
384
+ return self._get_operator(">=")(prospective, spec) and self._get_operator("==")(
385
+ prospective, prefix
386
+ )
387
+
388
+ def _compare_equal(self, prospective: Version, spec: str) -> bool:
389
+ # We need special logic to handle prefix matching
390
+ if spec.endswith(".*"):
391
+ # In the case of prefix matching we want to ignore local segment.
392
+ normalized_prospective = canonicalize_version(
393
+ prospective.public, strip_trailing_zero=False
394
+ )
395
+ # Get the normalized version string ignoring the trailing .*
396
+ normalized_spec = canonicalize_version(spec[:-2], strip_trailing_zero=False)
397
+ # Split the spec out by bangs and dots, and pretend that there is
398
+ # an implicit dot in between a release segment and a pre-release segment.
399
+ split_spec = _version_split(normalized_spec)
400
+
401
+ # Split the prospective version out by bangs and dots, and pretend
402
+ # that there is an implicit dot in between a release segment and
403
+ # a pre-release segment.
404
+ split_prospective = _version_split(normalized_prospective)
405
+
406
+ # 0-pad the prospective version before shortening it to get the correct
407
+ # shortened version.
408
+ padded_prospective, _ = _pad_version(split_prospective, split_spec)
409
+
410
+ # Shorten the prospective version to be the same length as the spec
411
+ # so that we can determine if the specifier is a prefix of the
412
+ # prospective version or not.
413
+ shortened_prospective = padded_prospective[: len(split_spec)]
414
+
415
+ return shortened_prospective == split_spec
416
+ else:
417
+ # Convert our spec string into a Version
418
+ spec_version = Version(spec)
419
+
420
+ # If the specifier does not have a local segment, then we want to
421
+ # act as if the prospective version also does not have a local
422
+ # segment.
423
+ if not spec_version.local:
424
+ prospective = Version(prospective.public)
425
+
426
+ return prospective == spec_version
427
+
428
+ def _compare_not_equal(self, prospective: Version, spec: str) -> bool:
429
+ return not self._compare_equal(prospective, spec)
430
+
431
+ def _compare_less_than_equal(self, prospective: Version, spec: str) -> bool:
432
+ # NB: Local version identifiers are NOT permitted in the version
433
+ # specifier, so local version labels can be universally removed from
434
+ # the prospective version.
435
+ return Version(prospective.public) <= Version(spec)
436
+
437
+ def _compare_greater_than_equal(self, prospective: Version, spec: str) -> bool:
438
+ # NB: Local version identifiers are NOT permitted in the version
439
+ # specifier, so local version labels can be universally removed from
440
+ # the prospective version.
441
+ return Version(prospective.public) >= Version(spec)
442
+
443
    def _compare_less_than(self, prospective: Version, spec_str: str) -> bool:
        # Convert our spec to a Version instance, since we'll want to work with
        # it as a version.
        spec = Version(spec_str)

        # Check to see if the prospective version is less than the spec
        # version. If it's not we can short circuit and just return False now
        # instead of doing extra unneeded work.
        if not prospective < spec:
            return False

        # This special case is here so that, unless the specifier itself
        # includes a pre-release version, we do not accept pre-release
        # versions for the version mentioned in the specifier (e.g. <3.1 should
        # not match 3.1.dev0, but should match 3.0.dev0).
        if not spec.is_prerelease and prospective.is_prerelease:
            if Version(prospective.base_version) == Version(spec.base_version):
                return False

        # If we've gotten to here, it means that prospective version is both
        # less than the spec version *and* it's not a pre-release of the same
        # version in the spec.
        return True
466
+
467
    def _compare_greater_than(self, prospective: Version, spec_str: str) -> bool:
        # Convert our spec to a Version instance, since we'll want to work with
        # it as a version.
        spec = Version(spec_str)

        # Check to see if the prospective version is greater than the spec
        # version. If it's not we can short circuit and just return False now
        # instead of doing extra unneeded work.
        if not prospective > spec:
            return False

        # This special case is here so that, unless the specifier itself
        # includes a post-release version, we do not accept
        # post-release versions for the version mentioned in the specifier
        # (e.g. >3.1 should not match 3.0.post0, but should match 3.2.post0).
        if not spec.is_postrelease and prospective.is_postrelease:
            if Version(prospective.base_version) == Version(spec.base_version):
                return False

        # Ensure that we do not allow a local version of the version mentioned
        # in the specifier, which is technically greater than, to match.
        if prospective.local is not None:
            if Version(prospective.base_version) == Version(spec.base_version):
                return False

        # If we've gotten to here, it means that prospective version is both
        # greater than the spec version *and* it's not a pre-release of the
        # same version in the spec.
        return True
496
+
497
+ def _compare_arbitrary(self, prospective: Version, spec: str) -> bool:
498
+ return str(prospective).lower() == str(spec).lower()
499
+
500
    def __contains__(self, item: str | Version) -> bool:
        """Return whether or not the item is contained in this specifier.

        :param item: The item to check for.

        This is used for the ``in`` operator and behaves the same as
        :meth:`contains` with no ``prereleases`` argument passed.

        >>> "1.2.3" in Specifier(">=1.2.3")
        True
        >>> Version("1.2.3") in Specifier(">=1.2.3")
        True
        >>> "1.0.0" in Specifier(">=1.2.3")
        False
        >>> "1.3.0a1" in Specifier(">=1.2.3")
        False
        >>> "1.3.0a1" in Specifier(">=1.2.3", prereleases=True)
        True
        """
        return self.contains(item)
520
+
521
    def contains(self, item: UnparsedVersion, prereleases: bool | None = None) -> bool:
        """Return whether or not the item is contained in this specifier.

        :param item:
            The item to check for, which can be a version string or a
            :class:`Version` instance.
        :param prereleases:
            Whether or not to match prereleases with this Specifier. If set to
            ``None`` (the default), it uses :attr:`prereleases` to determine
            whether or not prereleases are allowed.

        >>> Specifier(">=1.2.3").contains("1.2.3")
        True
        >>> Specifier(">=1.2.3").contains(Version("1.2.3"))
        True
        >>> Specifier(">=1.2.3").contains("1.0.0")
        False
        >>> Specifier(">=1.2.3").contains("1.3.0a1")
        False
        >>> Specifier(">=1.2.3", prereleases=True).contains("1.3.0a1")
        True
        >>> Specifier(">=1.2.3").contains("1.3.0a1", prereleases=True)
        True
        """

        # Determine if prereleases are to be allowed or not.
        if prereleases is None:
            prereleases = self.prereleases

        # Normalize item to a Version, this allows us to have a shortcut for
        # "2.0" in Specifier(">=2")
        normalized_item = _coerce_version(item)

        # Determine if we should be supporting prereleases in this specifier
        # or not, if we do not support prereleases then we can short circuit
        # the logic if this version is a prerelease.
        if normalized_item.is_prerelease and not prereleases:
            return False

        # Actually do the comparison to determine if this item is contained
        # within this Specifier or not.
        operator_callable: CallableOperator = self._get_operator(self.operator)
        return operator_callable(normalized_item, self.version)
564
+
565
    def filter(
        self, iterable: Iterable[UnparsedVersionVar], prereleases: bool | None = None
    ) -> Iterator[UnparsedVersionVar]:
        """Filter items in the given iterable, that match the specifier.

        :param iterable:
            An iterable that can contain version strings and :class:`Version` instances.
            The items in the iterable will be filtered according to the specifier.
        :param prereleases:
            Whether or not to allow prereleases in the returned iterator. If set to
            ``None`` (the default), it will intelligently decide whether to allow
            prereleases or not (based on the :attr:`prereleases` attribute, and
            whether the only versions matching are prereleases).

        This method is smarter than just ``filter(Specifier().contains, [...])``
        because it implements the rule from :pep:`440` that a prerelease item
        SHOULD be accepted if no other versions match the given specifier.

        >>> list(Specifier(">=1.2.3").filter(["1.2", "1.3", "1.5a1"]))
        ['1.3']
        >>> list(Specifier(">=1.2.3").filter(["1.2", "1.2.3", "1.3", Version("1.4")]))
        ['1.2.3', '1.3', <Version('1.4')>]
        >>> list(Specifier(">=1.2.3").filter(["1.2", "1.5a1"]))
        ['1.5a1']
        >>> list(Specifier(">=1.2.3").filter(["1.3", "1.5a1"], prereleases=True))
        ['1.3', '1.5a1']
        >>> list(Specifier(">=1.2.3", prereleases=True).filter(["1.3", "1.5a1"]))
        ['1.3', '1.5a1']
        """

        yielded = False
        found_prereleases = []

        kw = {"prereleases": prereleases if prereleases is not None else True}

        # Attempt to iterate over all the values in the iterable and if any of
        # them match, yield them.
        for version in iterable:
            parsed_version = _coerce_version(version)

            if self.contains(parsed_version, **kw):
                # If our version is a prerelease, and we were not set to allow
                # prereleases, then we'll store it for later in case nothing
                # else matches this specifier.
                if parsed_version.is_prerelease and not (
                    prereleases or self.prereleases
                ):
                    found_prereleases.append(version)
                # Either this is not a prerelease, or we should have been
                # accepting prereleases from the beginning.
                else:
                    yielded = True
                    yield version

        # Now that we've iterated over everything, determine if we've yielded
        # any values, and if we have not and we have any prereleases stored up
        # then we will go ahead and yield the prereleases.
        if not yielded and found_prereleases:
            for version in found_prereleases:
                yield version
625
+
626
+
627
+ _prefix_regex = re.compile(r"^([0-9]+)((?:a|b|c|rc)[0-9]+)$")
628
+
629
+
630
+ def _version_split(version: str) -> list[str]:
631
+ """Split version into components.
632
+
633
+ The split components are intended for version comparison. The logic does
634
+ not attempt to retain the original version string, so joining the
635
+ components back with :func:`_version_join` may not produce the original
636
+ version string.
637
+ """
638
+ result: list[str] = []
639
+
640
+ epoch, _, rest = version.rpartition("!")
641
+ result.append(epoch or "0")
642
+
643
+ for item in rest.split("."):
644
+ match = _prefix_regex.search(item)
645
+ if match:
646
+ result.extend(match.groups())
647
+ else:
648
+ result.append(item)
649
+ return result
650
+
651
+
652
+ def _version_join(components: list[str]) -> str:
653
+ """Join split version components into a version string.
654
+
655
+ This function assumes the input came from :func:`_version_split`, where the
656
+ first component must be the epoch (either empty or numeric), and all other
657
+ components numeric.
658
+ """
659
+ epoch, *rest = components
660
+ return f"{epoch}!{'.'.join(rest)}"
661
+
662
+
663
+ def _is_not_suffix(segment: str) -> bool:
664
+ return not any(
665
+ segment.startswith(prefix) for prefix in ("dev", "a", "b", "rc", "post")
666
+ )
667
+
668
+
669
+ def _pad_version(left: list[str], right: list[str]) -> tuple[list[str], list[str]]:
670
+ left_split, right_split = [], []
671
+
672
+ # Get the release segment of our versions
673
+ left_split.append(list(itertools.takewhile(lambda x: x.isdigit(), left)))
674
+ right_split.append(list(itertools.takewhile(lambda x: x.isdigit(), right)))
675
+
676
+ # Get the rest of our versions
677
+ left_split.append(left[len(left_split[0]) :])
678
+ right_split.append(right[len(right_split[0]) :])
679
+
680
+ # Insert our padding
681
+ left_split.insert(1, ["0"] * max(0, len(right_split[0]) - len(left_split[0])))
682
+ right_split.insert(1, ["0"] * max(0, len(left_split[0]) - len(right_split[0])))
683
+
684
+ return (
685
+ list(itertools.chain.from_iterable(left_split)),
686
+ list(itertools.chain.from_iterable(right_split)),
687
+ )
688
+
689
+
690
class SpecifierSet(BaseSpecifier):
    """This class abstracts handling of a set of version specifiers.

    It can be passed a single specifier (``>=3.0``), a comma-separated list of
    specifiers (``>=3.0,!=3.1``), or no specifier at all.
    """

    def __init__(
        self,
        specifiers: str | Iterable[Specifier] = "",
        prereleases: bool | None = None,
    ) -> None:
        """Initialize a SpecifierSet instance.

        :param specifiers:
            The string representation of a specifier or a comma-separated list of
            specifiers which will be parsed and normalized before use.
            May also be an iterable of ``Specifier`` instances, which will be used
            as is.
        :param prereleases:
            This tells the SpecifierSet if it should accept prerelease versions if
            applicable or not. The default of ``None`` will autodetect it from the
            given specifiers.

        :raises InvalidSpecifier:
            If the given ``specifiers`` are not parseable, then this exception will
            be raised.
        """

        if isinstance(specifiers, str):
            # Split on `,` to break each individual specifier into its own item, and
            # strip each item to remove leading/trailing whitespace.
            split_specifiers = [s.strip() for s in specifiers.split(",") if s.strip()]

            # Make each individual specifier a Specifier and save in a frozen set
            # for later.
            self._specs = frozenset(map(Specifier, split_specifiers))
        else:
            # Save the supplied specifiers in a frozen set.
            self._specs = frozenset(specifiers)

        # Store our prereleases value so we can use it later to determine if
        # we accept prereleases or not.
        self._prereleases = prereleases

    @property
    def prereleases(self) -> bool | None:
        # If we have been given an explicit prerelease modifier, then we'll
        # pass that through here.
        if self._prereleases is not None:
            return self._prereleases

        # If we don't have any specifiers, and we don't have a forced value,
        # then we'll just return None since we don't know if this should have
        # pre-releases or not.
        if not self._specs:
            return None

        # Otherwise we'll see if any of the given specifiers accept
        # prereleases, if any of them do we'll return True, otherwise False.
        return any(s.prereleases for s in self._specs)

    @prereleases.setter
    def prereleases(self, value: bool) -> None:
        self._prereleases = value

    def __repr__(self) -> str:
        """A representation of the specifier set that shows all internal state.

        Note that the ordering of the individual specifiers within the set may not
        match the input string.

        >>> SpecifierSet('>=1.0.0,!=2.0.0')
        <SpecifierSet('!=2.0.0,>=1.0.0')>
        >>> SpecifierSet('>=1.0.0,!=2.0.0', prereleases=False)
        <SpecifierSet('!=2.0.0,>=1.0.0', prereleases=False)>
        >>> SpecifierSet('>=1.0.0,!=2.0.0', prereleases=True)
        <SpecifierSet('!=2.0.0,>=1.0.0', prereleases=True)>
        """
        pre = (
            f", prereleases={self.prereleases!r}"
            if self._prereleases is not None
            else ""
        )

        return f"<SpecifierSet({str(self)!r}{pre})>"

    def __str__(self) -> str:
        """A string representation of the specifier set that can be round-tripped.

        Note that the ordering of the individual specifiers within the set may not
        match the input string.

        >>> str(SpecifierSet(">=1.0.0,!=1.0.1"))
        '!=1.0.1,>=1.0.0'
        >>> str(SpecifierSet(">=1.0.0,!=1.0.1", prereleases=False))
        '!=1.0.1,>=1.0.0'
        """
        return ",".join(sorted(str(s) for s in self._specs))

    def __hash__(self) -> int:
        # Hash only the frozen specifier set; `_prereleases` is deliberately
        # excluded, matching __eq__ which also ignores it.
        return hash(self._specs)

    def __and__(self, other: SpecifierSet | str) -> SpecifierSet:
        """Return a SpecifierSet which is a combination of the two sets.

        :param other: The other object to combine with.

        >>> SpecifierSet(">=1.0.0,!=1.0.1") & '<=2.0.0,!=2.0.1'
        <SpecifierSet('!=1.0.1,!=2.0.1,<=2.0.0,>=1.0.0')>
        >>> SpecifierSet(">=1.0.0,!=1.0.1") & SpecifierSet('<=2.0.0,!=2.0.1')
        <SpecifierSet('!=1.0.1,!=2.0.1,<=2.0.0,>=1.0.0')>
        """
        if isinstance(other, str):
            other = SpecifierSet(other)
        elif not isinstance(other, SpecifierSet):
            return NotImplemented

        specifier = SpecifierSet()
        specifier._specs = frozenset(self._specs | other._specs)

        # A prerelease override may only survive the combination when the two
        # sides do not explicitly contradict each other.
        if self._prereleases is None and other._prereleases is not None:
            specifier._prereleases = other._prereleases
        elif self._prereleases is not None and other._prereleases is None:
            specifier._prereleases = self._prereleases
        elif self._prereleases == other._prereleases:
            specifier._prereleases = self._prereleases
        else:
            raise ValueError(
                "Cannot combine SpecifierSets with True and False prerelease "
                "overrides."
            )

        return specifier

    def __eq__(self, other: object) -> bool:
        """Whether or not the two SpecifierSet-like objects are equal.

        :param other: The other object to check against.

        The value of :attr:`prereleases` is ignored.

        >>> SpecifierSet(">=1.0.0,!=1.0.1") == SpecifierSet(">=1.0.0,!=1.0.1")
        True
        >>> (SpecifierSet(">=1.0.0,!=1.0.1", prereleases=False) ==
        ...  SpecifierSet(">=1.0.0,!=1.0.1", prereleases=True))
        True
        >>> SpecifierSet(">=1.0.0,!=1.0.1") == ">=1.0.0,!=1.0.1"
        True
        >>> SpecifierSet(">=1.0.0,!=1.0.1") == SpecifierSet(">=1.0.0")
        False
        >>> SpecifierSet(">=1.0.0,!=1.0.1") == SpecifierSet(">=1.0.0,!=1.0.2")
        False
        """
        if isinstance(other, (str, Specifier)):
            other = SpecifierSet(str(other))
        elif not isinstance(other, SpecifierSet):
            return NotImplemented

        return self._specs == other._specs

    def __len__(self) -> int:
        """Returns the number of specifiers in this specifier set."""
        return len(self._specs)

    def __iter__(self) -> Iterator[Specifier]:
        """
        Returns an iterator over all the underlying :class:`Specifier` instances
        in this specifier set.

        >>> sorted(SpecifierSet(">=1.0.0,!=1.0.1"), key=str)
        [<Specifier('!=1.0.1')>, <Specifier('>=1.0.0')>]
        """
        return iter(self._specs)

    def __contains__(self, item: UnparsedVersion) -> bool:
        """Return whether or not the item is contained in this specifier.

        :param item: The item to check for.

        This is used for the ``in`` operator and behaves the same as
        :meth:`contains` with no ``prereleases`` argument passed.

        >>> "1.2.3" in SpecifierSet(">=1.0.0,!=1.0.1")
        True
        >>> Version("1.2.3") in SpecifierSet(">=1.0.0,!=1.0.1")
        True
        >>> "1.0.1" in SpecifierSet(">=1.0.0,!=1.0.1")
        False
        >>> "1.3.0a1" in SpecifierSet(">=1.0.0,!=1.0.1")
        False
        >>> "1.3.0a1" in SpecifierSet(">=1.0.0,!=1.0.1", prereleases=True)
        True
        """
        return self.contains(item)

    def contains(
        self,
        item: UnparsedVersion,
        prereleases: bool | None = None,
        installed: bool | None = None,
    ) -> bool:
        """Return whether or not the item is contained in this SpecifierSet.

        :param item:
            The item to check for, which can be a version string or a
            :class:`Version` instance.
        :param prereleases:
            Whether or not to match prereleases with this SpecifierSet. If set to
            ``None`` (the default), it uses :attr:`prereleases` to determine
            whether or not prereleases are allowed.
        :param installed:
            If truthy, a prerelease ``item`` is compared by its base version,
            treating an installed prerelease as satisfying a final-release spec.

        >>> SpecifierSet(">=1.0.0,!=1.0.1").contains("1.2.3")
        True
        >>> SpecifierSet(">=1.0.0,!=1.0.1").contains(Version("1.2.3"))
        True
        >>> SpecifierSet(">=1.0.0,!=1.0.1").contains("1.0.1")
        False
        >>> SpecifierSet(">=1.0.0,!=1.0.1").contains("1.3.0a1")
        False
        >>> SpecifierSet(">=1.0.0,!=1.0.1", prereleases=True).contains("1.3.0a1")
        True
        >>> SpecifierSet(">=1.0.0,!=1.0.1").contains("1.3.0a1", prereleases=True)
        True
        """
        # Ensure that our item is a Version instance.
        if not isinstance(item, Version):
            item = Version(item)

        # Determine if we're forcing a prerelease or not, if we're not forcing
        # one for this particular filter call, then we'll use whatever the
        # SpecifierSet thinks for whether or not we should support prereleases.
        if prereleases is None:
            prereleases = self.prereleases

        # We can determine if we're going to allow pre-releases by looking to
        # see if any of the underlying items supports them. If none of them do
        # and this item is a pre-release then we do not allow it and we can
        # short circuit that here.
        # Note: This means that 1.0.dev1 would not be contained in something
        # like >=1.0.devabc however it would be in >=1.0.devabc,>0.0.dev0
        if not prereleases and item.is_prerelease:
            return False

        if installed and item.is_prerelease:
            item = Version(item.base_version)

        # We simply dispatch to the underlying specs here to make sure that the
        # given version is contained within all of them.
        # Note: This use of all() here means that an empty set of specifiers
        # will always return True, this is an explicit design decision.
        return all(s.contains(item, prereleases=prereleases) for s in self._specs)

    def filter(
        self, iterable: Iterable[UnparsedVersionVar], prereleases: bool | None = None
    ) -> Iterator[UnparsedVersionVar]:
        """Filter items in the given iterable, that match the specifiers in this set.

        :param iterable:
            An iterable that can contain version strings and :class:`Version` instances.
            The items in the iterable will be filtered according to the specifier.
        :param prereleases:
            Whether or not to allow prereleases in the returned iterator. If set to
            ``None`` (the default), it will intelligently decide whether to allow
            prereleases or not (based on the :attr:`prereleases` attribute, and
            whether the only versions matching are prereleases).

        This method is smarter than just ``filter(SpecifierSet(...).contains, [...])``
        because it implements the rule from :pep:`440` that a prerelease item
        SHOULD be accepted if no other versions match the given specifier.

        >>> list(SpecifierSet(">=1.2.3").filter(["1.2", "1.3", "1.5a1"]))
        ['1.3']
        >>> list(SpecifierSet(">=1.2.3").filter(["1.2", "1.3", Version("1.4")]))
        ['1.3', <Version('1.4')>]
        >>> list(SpecifierSet(">=1.2.3").filter(["1.2", "1.5a1"]))
        []
        >>> list(SpecifierSet(">=1.2.3").filter(["1.3", "1.5a1"], prereleases=True))
        ['1.3', '1.5a1']
        >>> list(SpecifierSet(">=1.2.3", prereleases=True).filter(["1.3", "1.5a1"]))
        ['1.3', '1.5a1']

        An "empty" SpecifierSet will filter items based on the presence of prerelease
        versions in the set.

        >>> list(SpecifierSet("").filter(["1.3", "1.5a1"]))
        ['1.3']
        >>> list(SpecifierSet("").filter(["1.5a1"]))
        ['1.5a1']
        >>> list(SpecifierSet("", prereleases=True).filter(["1.3", "1.5a1"]))
        ['1.3', '1.5a1']
        >>> list(SpecifierSet("").filter(["1.3", "1.5a1"], prereleases=True))
        ['1.3', '1.5a1']
        """
        # Determine if we're forcing a prerelease or not, if we're not forcing
        # one for this particular filter call, then we'll use whatever the
        # SpecifierSet thinks for whether or not we should support prereleases.
        if prereleases is None:
            prereleases = self.prereleases

        # If we have any specifiers, then we want to wrap our iterable in the
        # filter method for each one, this will act as a logical AND amongst
        # each specifier.
        if self._specs:
            for spec in self._specs:
                iterable = spec.filter(iterable, prereleases=bool(prereleases))
            return iter(iterable)
        # If we do not have any specifiers, then we need to have a rough filter
        # which will filter out any pre-releases, unless there are no final
        # releases.
        else:
            filtered: list[UnparsedVersionVar] = []
            found_prereleases: list[UnparsedVersionVar] = []

            for item in iterable:
                parsed_version = _coerce_version(item)

                # Store any item which is a pre-release for later unless we've
                # already found a final version or we are accepting prereleases
                if parsed_version.is_prerelease and not prereleases:
                    if not filtered:
                        found_prereleases.append(item)
                else:
                    filtered.append(item)

            # If we've found no items except for pre-releases, then we'll go
            # ahead and use the pre-releases
            if not filtered and found_prereleases and prereleases is None:
                return iter(found_prereleases)

            return iter(filtered)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/tags.py ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is dual licensed under the terms of the Apache License, Version
2
+ # 2.0, and the BSD License. See the LICENSE file in the root of this repository
3
+ # for complete details.
4
+
5
+ from __future__ import annotations
6
+
7
+ import logging
8
+ import platform
9
+ import re
10
+ import struct
11
+ import subprocess
12
+ import sys
13
+ import sysconfig
14
+ from importlib.machinery import EXTENSION_SUFFIXES
15
+ from typing import (
16
+ Iterable,
17
+ Iterator,
18
+ Sequence,
19
+ Tuple,
20
+ cast,
21
+ )
22
+
23
+ from . import _manylinux, _musllinux
24
+
25
logger = logging.getLogger(__name__)

# Type aliases used throughout this module.
PythonVersion = Sequence[int]
AppleVersion = Tuple[int, int]

# Map long interpreter implementation names to the short forms used in wheel
# tags (e.g. "cpython" -> "cp").
INTERPRETER_SHORT_NAMES: dict[str, str] = {
    "python": "py",  # Generic.
    "cpython": "cp",
    "pypy": "pp",
    "ironpython": "ip",
    "jython": "jy",
}


# True when the running interpreter uses 32-bit pointers.
_32_BIT_INTERPRETER = struct.calcsize("P") == 4
40
+
41
+
42
class Tag:
    """
    A representation of the tag triple for a wheel.

    Instances are considered immutable and thus are hashable. Equality checking
    is also supported.
    """

    __slots__ = ["_abi", "_hash", "_interpreter", "_platform"]

    def __init__(self, interpreter: str, abi: str, platform: str) -> None:
        # Tags compare case-insensitively, so normalize each component once.
        self._interpreter = interpreter.lower()
        self._abi = abi.lower()
        self._platform = platform.lower()
        # The __hash__ of every single element in a Set[Tag] will be evaluated each time
        # that a set calls its `.disjoint()` method, which may be called hundreds of
        # times when scanning a page of links for packages with tags matching that
        # Set[Tag]. Pre-computing the value here produces significant speedups for
        # downstream consumers.
        self._hash = hash((self._interpreter, self._abi, self._platform))

    @property
    def interpreter(self) -> str:
        return self._interpreter

    @property
    def abi(self) -> str:
        return self._abi

    @property
    def platform(self) -> str:
        return self._platform

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Tag):
            return NotImplemented

        return (
            (self._hash == other._hash)  # Short-circuit ASAP for perf reasons.
            and (self._platform == other._platform)
            and (self._abi == other._abi)
            and (self._interpreter == other._interpreter)
        )

    def __hash__(self) -> int:
        return self._hash

    def __str__(self) -> str:
        return f"{self._interpreter}-{self._abi}-{self._platform}"

    def __repr__(self) -> str:
        return f"<{self} @ {id(self)}>"
94
+
95
+
96
def parse_tag(tag: str) -> frozenset[Tag]:
    """
    Parses the provided tag (e.g. `py3-none-any`) into a frozenset of Tag instances.

    Returning a set is required due to the possibility that the tag is a
    compressed tag set.
    """
    interpreters, abis, platforms = tag.split("-")
    # Each dash-separated field may itself be a dot-separated list; the full
    # tag set is the cross product of the three fields.
    return frozenset(
        Tag(interpreter, abi, platform_)
        for interpreter in interpreters.split(".")
        for abi in abis.split(".")
        for platform_ in platforms.split(".")
    )
110
+
111
+
112
+ def _get_config_var(name: str, warn: bool = False) -> int | str | None:
113
+ value: int | str | None = sysconfig.get_config_var(name)
114
+ if value is None and warn:
115
+ logger.debug(
116
+ "Config variable '%s' is unset, Python ABI tag may be incorrect", name
117
+ )
118
+ return value
119
+
120
+
121
+ def _normalize_string(string: str) -> str:
122
+ return string.replace(".", "_").replace("-", "_").replace(" ", "_")
123
+
124
+
125
+ def _is_threaded_cpython(abis: list[str]) -> bool:
126
+ """
127
+ Determine if the ABI corresponds to a threaded (`--disable-gil`) build.
128
+
129
+ The threaded builds are indicated by a "t" in the abiflags.
130
+ """
131
+ if len(abis) == 0:
132
+ return False
133
+ # expect e.g., cp313
134
+ m = re.match(r"cp\d+(.*)", abis[0])
135
+ if not m:
136
+ return False
137
+ abiflags = m.group(1)
138
+ return "t" in abiflags
139
+
140
+
141
+ def _abi3_applies(python_version: PythonVersion, threading: bool) -> bool:
142
+ """
143
+ Determine if the Python version supports abi3.
144
+
145
+ PEP 384 was first implemented in Python 3.2. The threaded (`--disable-gil`)
146
+ builds do not support abi3.
147
+ """
148
+ return len(python_version) > 1 and tuple(python_version) >= (3, 2) and not threading
149
+
150
+
151
def _cpython_abis(py_version: PythonVersion, warn: bool = False) -> list[str]:
    """Return the ABI tags for a CPython of *py_version*, most specific first."""
    py_version = tuple(py_version)  # To allow for version comparison.
    abis = []
    version = _version_nodot(py_version[:2])
    threading = debug = pymalloc = ucs4 = ""
    with_debug = _get_config_var("Py_DEBUG", warn)
    has_refcount = hasattr(sys, "gettotalrefcount")
    # Windows doesn't set Py_DEBUG, so checking for support of debug-compiled
    # extension modules is the best option.
    # https://github.com/pypa/pip/issues/3383#issuecomment-173267692
    has_ext = "_d.pyd" in EXTENSION_SUFFIXES
    if with_debug or (with_debug is None and (has_refcount or has_ext)):
        debug = "d"
    if py_version >= (3, 13) and _get_config_var("Py_GIL_DISABLED", warn):
        threading = "t"
    if py_version < (3, 8):
        # Older CPythons encode pymalloc ("m") and UCS-4 ("u") in the ABI flags.
        with_pymalloc = _get_config_var("WITH_PYMALLOC", warn)
        if with_pymalloc or with_pymalloc is None:
            pymalloc = "m"
        if py_version < (3, 3):
            unicode_size = _get_config_var("Py_UNICODE_SIZE", warn)
            if unicode_size == 4 or (
                unicode_size is None and sys.maxunicode == 0x10FFFF
            ):
                ucs4 = "u"
    elif debug:
        # Debug builds can also load "normal" extension modules.
        # We can also assume no UCS-4 or pymalloc requirement.
        abis.append(f"cp{version}{threading}")
    abis.insert(0, f"cp{version}{threading}{debug}{pymalloc}{ucs4}")
    return abis
182
+
183
+
184
def cpython_tags(
    python_version: PythonVersion | None = None,
    abis: Iterable[str] | None = None,
    platforms: Iterable[str] | None = None,
    *,
    warn: bool = False,
) -> Iterator[Tag]:
    """
    Yields the tags for a CPython interpreter.

    The tags consist of:
    - cp<python_version>-<abi>-<platform>
    - cp<python_version>-abi3-<platform>
    - cp<python_version>-none-<platform>
    - cp<less than python_version>-abi3-<platform>  # Older Python versions down to 3.2.

    If python_version only specifies a major version then user-provided ABIs and
    the 'none' ABI tag will be used.

    If 'abi3' or 'none' are specified in 'abis' then they will be yielded at
    their normal position and not at the beginning.
    """
    if not python_version:
        python_version = sys.version_info[:2]

    interpreter = f"cp{_version_nodot(python_version[:2])}"

    if abis is None:
        if len(python_version) > 1:
            abis = _cpython_abis(python_version, warn)
        else:
            abis = []
    abis = list(abis)
    # 'abi3' and 'none' are explicitly handled later.
    for explicit_abi in ("abi3", "none"):
        try:
            abis.remove(explicit_abi)
        except ValueError:
            pass

    platforms = list(platforms or platform_tags())
    for abi in abis:
        for platform_ in platforms:
            yield Tag(interpreter, abi, platform_)

    threading = _is_threaded_cpython(abis)
    use_abi3 = _abi3_applies(python_version, threading)
    if use_abi3:
        yield from (Tag(interpreter, "abi3", platform_) for platform_ in platforms)
    yield from (Tag(interpreter, "none", platform_) for platform_ in platforms)

    if use_abi3:
        # abi3 wheels built for any older minor version (down to 3.2) are also
        # usable by this interpreter.
        for minor_version in range(python_version[1] - 1, 1, -1):
            for platform_ in platforms:
                version = _version_nodot((python_version[0], minor_version))
                interpreter = f"cp{version}"
                yield Tag(interpreter, "abi3", platform_)
241
+
242
+
243
+ def _generic_abi() -> list[str]:
244
+ """
245
+ Return the ABI tag based on EXT_SUFFIX.
246
+ """
247
+ # The following are examples of `EXT_SUFFIX`.
248
+ # We want to keep the parts which are related to the ABI and remove the
249
+ # parts which are related to the platform:
250
+ # - linux: '.cpython-310-x86_64-linux-gnu.so' => cp310
251
+ # - mac: '.cpython-310-darwin.so' => cp310
252
+ # - win: '.cp310-win_amd64.pyd' => cp310
253
+ # - win: '.pyd' => cp37 (uses _cpython_abis())
254
+ # - pypy: '.pypy38-pp73-x86_64-linux-gnu.so' => pypy38_pp73
255
+ # - graalpy: '.graalpy-38-native-x86_64-darwin.dylib'
256
+ # => graalpy_38_native
257
+
258
+ ext_suffix = _get_config_var("EXT_SUFFIX", warn=True)
259
+ if not isinstance(ext_suffix, str) or ext_suffix[0] != ".":
260
+ raise SystemError("invalid sysconfig.get_config_var('EXT_SUFFIX')")
261
+ parts = ext_suffix.split(".")
262
+ if len(parts) < 3:
263
+ # CPython3.7 and earlier uses ".pyd" on Windows.
264
+ return _cpython_abis(sys.version_info[:2])
265
+ soabi = parts[1]
266
+ if soabi.startswith("cpython"):
267
+ # non-windows
268
+ abi = "cp" + soabi.split("-")[1]
269
+ elif soabi.startswith("cp"):
270
+ # windows
271
+ abi = soabi.split("-")[0]
272
+ elif soabi.startswith("pypy"):
273
+ abi = "-".join(soabi.split("-")[:2])
274
+ elif soabi.startswith("graalpy"):
275
+ abi = "-".join(soabi.split("-")[:3])
276
+ elif soabi:
277
+ # pyston, ironpython, others?
278
+ abi = soabi
279
+ else:
280
+ return []
281
+ return [_normalize_string(abi)]
282
+
283
+
284
+ def generic_tags(
285
+ interpreter: str | None = None,
286
+ abis: Iterable[str] | None = None,
287
+ platforms: Iterable[str] | None = None,
288
+ *,
289
+ warn: bool = False,
290
+ ) -> Iterator[Tag]:
291
+ """
292
+ Yields the tags for a generic interpreter.
293
+
294
+ The tags consist of:
295
+ - <interpreter>-<abi>-<platform>
296
+
297
+ The "none" ABI will be added if it was not explicitly provided.
298
+ """
299
+ if not interpreter:
300
+ interp_name = interpreter_name()
301
+ interp_version = interpreter_version(warn=warn)
302
+ interpreter = "".join([interp_name, interp_version])
303
+ if abis is None:
304
+ abis = _generic_abi()
305
+ else:
306
+ abis = list(abis)
307
+ platforms = list(platforms or platform_tags())
308
+ if "none" not in abis:
309
+ abis.append("none")
310
+ for abi in abis:
311
+ for platform_ in platforms:
312
+ yield Tag(interpreter, abi, platform_)
313
+
314
+
315
+ def _py_interpreter_range(py_version: PythonVersion) -> Iterator[str]:
316
+ """
317
+ Yields Python versions in descending order.
318
+
319
+ After the latest version, the major-only version will be yielded, and then
320
+ all previous versions of that major version.
321
+ """
322
+ if len(py_version) > 1:
323
+ yield f"py{_version_nodot(py_version[:2])}"
324
+ yield f"py{py_version[0]}"
325
+ if len(py_version) > 1:
326
+ for minor in range(py_version[1] - 1, -1, -1):
327
+ yield f"py{_version_nodot((py_version[0], minor))}"
328
+
329
+
330
+ def compatible_tags(
331
+ python_version: PythonVersion | None = None,
332
+ interpreter: str | None = None,
333
+ platforms: Iterable[str] | None = None,
334
+ ) -> Iterator[Tag]:
335
+ """
336
+ Yields the sequence of tags that are compatible with a specific version of Python.
337
+
338
+ The tags consist of:
339
+ - py*-none-<platform>
340
+ - <interpreter>-none-any # ... if `interpreter` is provided.
341
+ - py*-none-any
342
+ """
343
+ if not python_version:
344
+ python_version = sys.version_info[:2]
345
+ platforms = list(platforms or platform_tags())
346
+ for version in _py_interpreter_range(python_version):
347
+ for platform_ in platforms:
348
+ yield Tag(version, "none", platform_)
349
+ if interpreter:
350
+ yield Tag(interpreter, "none", "any")
351
+ for version in _py_interpreter_range(python_version):
352
+ yield Tag(version, "none", "any")
353
+
354
+
355
+ def _mac_arch(arch: str, is_32bit: bool = _32_BIT_INTERPRETER) -> str:
356
+ if not is_32bit:
357
+ return arch
358
+
359
+ if arch.startswith("ppc"):
360
+ return "ppc"
361
+
362
+ return "i386"
363
+
364
+
365
+ def _mac_binary_formats(version: AppleVersion, cpu_arch: str) -> list[str]:
366
+ formats = [cpu_arch]
367
+ if cpu_arch == "x86_64":
368
+ if version < (10, 4):
369
+ return []
370
+ formats.extend(["intel", "fat64", "fat32"])
371
+
372
+ elif cpu_arch == "i386":
373
+ if version < (10, 4):
374
+ return []
375
+ formats.extend(["intel", "fat32", "fat"])
376
+
377
+ elif cpu_arch == "ppc64":
378
+ # TODO: Need to care about 32-bit PPC for ppc64 through 10.2?
379
+ if version > (10, 5) or version < (10, 4):
380
+ return []
381
+ formats.append("fat64")
382
+
383
+ elif cpu_arch == "ppc":
384
+ if version > (10, 6):
385
+ return []
386
+ formats.extend(["fat32", "fat"])
387
+
388
+ if cpu_arch in {"arm64", "x86_64"}:
389
+ formats.append("universal2")
390
+
391
+ if cpu_arch in {"x86_64", "i386", "ppc64", "ppc", "intel"}:
392
+ formats.append("universal")
393
+
394
+ return formats
395
+
396
+
397
+ def mac_platforms(
398
+ version: AppleVersion | None = None, arch: str | None = None
399
+ ) -> Iterator[str]:
400
+ """
401
+ Yields the platform tags for a macOS system.
402
+
403
+ The `version` parameter is a two-item tuple specifying the macOS version to
404
+ generate platform tags for. The `arch` parameter is the CPU architecture to
405
+ generate platform tags for. Both parameters default to the appropriate value
406
+ for the current system.
407
+ """
408
+ version_str, _, cpu_arch = platform.mac_ver()
409
+ if version is None:
410
+ version = cast("AppleVersion", tuple(map(int, version_str.split(".")[:2])))
411
+ if version == (10, 16):
412
+ # When built against an older macOS SDK, Python will report macOS 10.16
413
+ # instead of the real version.
414
+ version_str = subprocess.run(
415
+ [
416
+ sys.executable,
417
+ "-sS",
418
+ "-c",
419
+ "import platform; print(platform.mac_ver()[0])",
420
+ ],
421
+ check=True,
422
+ env={"SYSTEM_VERSION_COMPAT": "0"},
423
+ stdout=subprocess.PIPE,
424
+ text=True,
425
+ ).stdout
426
+ version = cast("AppleVersion", tuple(map(int, version_str.split(".")[:2])))
427
+ else:
428
+ version = version
429
+ if arch is None:
430
+ arch = _mac_arch(cpu_arch)
431
+ else:
432
+ arch = arch
433
+
434
+ if (10, 0) <= version and version < (11, 0):
435
+ # Prior to Mac OS 11, each yearly release of Mac OS bumped the
436
+ # "minor" version number. The major version was always 10.
437
+ major_version = 10
438
+ for minor_version in range(version[1], -1, -1):
439
+ compat_version = major_version, minor_version
440
+ binary_formats = _mac_binary_formats(compat_version, arch)
441
+ for binary_format in binary_formats:
442
+ yield f"macosx_{major_version}_{minor_version}_{binary_format}"
443
+
444
+ if version >= (11, 0):
445
+ # Starting with Mac OS 11, each yearly release bumps the major version
446
+ # number. The minor versions are now the midyear updates.
447
+ minor_version = 0
448
+ for major_version in range(version[0], 10, -1):
449
+ compat_version = major_version, minor_version
450
+ binary_formats = _mac_binary_formats(compat_version, arch)
451
+ for binary_format in binary_formats:
452
+ yield f"macosx_{major_version}_{minor_version}_{binary_format}"
453
+
454
+ if version >= (11, 0):
455
+ # Mac OS 11 on x86_64 is compatible with binaries from previous releases.
456
+ # Arm64 support was introduced in 11.0, so no Arm binaries from previous
457
+ # releases exist.
458
+ #
459
+ # However, the "universal2" binary format can have a
460
+ # macOS version earlier than 11.0 when the x86_64 part of the binary supports
461
+ # that version of macOS.
462
+ major_version = 10
463
+ if arch == "x86_64":
464
+ for minor_version in range(16, 3, -1):
465
+ compat_version = major_version, minor_version
466
+ binary_formats = _mac_binary_formats(compat_version, arch)
467
+ for binary_format in binary_formats:
468
+ yield f"macosx_{major_version}_{minor_version}_{binary_format}"
469
+ else:
470
+ for minor_version in range(16, 3, -1):
471
+ compat_version = major_version, minor_version
472
+ binary_format = "universal2"
473
+ yield f"macosx_{major_version}_{minor_version}_{binary_format}"
474
+
475
+
476
+ def ios_platforms(
477
+ version: AppleVersion | None = None, multiarch: str | None = None
478
+ ) -> Iterator[str]:
479
+ """
480
+ Yields the platform tags for an iOS system.
481
+
482
+ :param version: A two-item tuple specifying the iOS version to generate
483
+ platform tags for. Defaults to the current iOS version.
484
+ :param multiarch: The CPU architecture+ABI to generate platform tags for -
485
+ (the value used by `sys.implementation._multiarch` e.g.,
486
+ `arm64_iphoneos` or `x84_64_iphonesimulator`). Defaults to the current
487
+ multiarch value.
488
+ """
489
+ if version is None:
490
+ # if iOS is the current platform, ios_ver *must* be defined. However,
491
+ # it won't exist for CPython versions before 3.13, which causes a mypy
492
+ # error.
493
+ _, release, _, _ = platform.ios_ver() # type: ignore[attr-defined, unused-ignore]
494
+ version = cast("AppleVersion", tuple(map(int, release.split(".")[:2])))
495
+
496
+ if multiarch is None:
497
+ multiarch = sys.implementation._multiarch
498
+ multiarch = multiarch.replace("-", "_")
499
+
500
+ ios_platform_template = "ios_{major}_{minor}_{multiarch}"
501
+
502
+ # Consider any iOS major.minor version from the version requested, down to
503
+ # 12.0. 12.0 is the first iOS version that is known to have enough features
504
+ # to support CPython. Consider every possible minor release up to X.9. There
505
+ # highest the minor has ever gone is 8 (14.8 and 15.8) but having some extra
506
+ # candidates that won't ever match doesn't really hurt, and it saves us from
507
+ # having to keep an explicit list of known iOS versions in the code. Return
508
+ # the results descending order of version number.
509
+
510
+ # If the requested major version is less than 12, there won't be any matches.
511
+ if version[0] < 12:
512
+ return
513
+
514
+ # Consider the actual X.Y version that was requested.
515
+ yield ios_platform_template.format(
516
+ major=version[0], minor=version[1], multiarch=multiarch
517
+ )
518
+
519
+ # Consider every minor version from X.0 to the minor version prior to the
520
+ # version requested by the platform.
521
+ for minor in range(version[1] - 1, -1, -1):
522
+ yield ios_platform_template.format(
523
+ major=version[0], minor=minor, multiarch=multiarch
524
+ )
525
+
526
+ for major in range(version[0] - 1, 11, -1):
527
+ for minor in range(9, -1, -1):
528
+ yield ios_platform_template.format(
529
+ major=major, minor=minor, multiarch=multiarch
530
+ )
531
+
532
+
533
+ def _linux_platforms(is_32bit: bool = _32_BIT_INTERPRETER) -> Iterator[str]:
534
+ linux = _normalize_string(sysconfig.get_platform())
535
+ if not linux.startswith("linux_"):
536
+ # we should never be here, just yield the sysconfig one and return
537
+ yield linux
538
+ return
539
+ if is_32bit:
540
+ if linux == "linux_x86_64":
541
+ linux = "linux_i686"
542
+ elif linux == "linux_aarch64":
543
+ linux = "linux_armv8l"
544
+ _, arch = linux.split("_", 1)
545
+ archs = {"armv8l": ["armv8l", "armv7l"]}.get(arch, [arch])
546
+ yield from _manylinux.platform_tags(archs)
547
+ yield from _musllinux.platform_tags(archs)
548
+ for arch in archs:
549
+ yield f"linux_{arch}"
550
+
551
+
552
+ def _generic_platforms() -> Iterator[str]:
553
+ yield _normalize_string(sysconfig.get_platform())
554
+
555
+
556
+ def platform_tags() -> Iterator[str]:
557
+ """
558
+ Provides the platform tags for this installation.
559
+ """
560
+ if platform.system() == "Darwin":
561
+ return mac_platforms()
562
+ elif platform.system() == "iOS":
563
+ return ios_platforms()
564
+ elif platform.system() == "Linux":
565
+ return _linux_platforms()
566
+ else:
567
+ return _generic_platforms()
568
+
569
+
570
+ def interpreter_name() -> str:
571
+ """
572
+ Returns the name of the running interpreter.
573
+
574
+ Some implementations have a reserved, two-letter abbreviation which will
575
+ be returned when appropriate.
576
+ """
577
+ name = sys.implementation.name
578
+ return INTERPRETER_SHORT_NAMES.get(name) or name
579
+
580
+
581
+ def interpreter_version(*, warn: bool = False) -> str:
582
+ """
583
+ Returns the version of the running interpreter.
584
+ """
585
+ version = _get_config_var("py_version_nodot", warn=warn)
586
+ if version:
587
+ version = str(version)
588
+ else:
589
+ version = _version_nodot(sys.version_info[:2])
590
+ return version
591
+
592
+
593
+ def _version_nodot(version: PythonVersion) -> str:
594
+ return "".join(map(str, version))
595
+
596
+
597
+ def sys_tags(*, warn: bool = False) -> Iterator[Tag]:
598
+ """
599
+ Returns the sequence of tag triples for the running interpreter.
600
+
601
+ The order of the sequence corresponds to priority order for the
602
+ interpreter, from most to least important.
603
+ """
604
+
605
+ interp_name = interpreter_name()
606
+ if interp_name == "cp":
607
+ yield from cpython_tags(warn=warn)
608
+ else:
609
+ yield from generic_tags()
610
+
611
+ if interp_name == "pp":
612
+ interp = "pp3"
613
+ elif interp_name == "cp":
614
+ interp = "cp" + interpreter_version(warn=warn)
615
+ else:
616
+ interp = None
617
+ yield from compatible_tags(interpreter=interp)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/version.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is dual licensed under the terms of the Apache License, Version
2
+ # 2.0, and the BSD License. See the LICENSE file in the root of this repository
3
+ # for complete details.
4
+ """
5
+ .. testsetup::
6
+
7
+ from packaging.version import parse, Version
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import itertools
13
+ import re
14
+ from typing import Any, Callable, NamedTuple, SupportsInt, Tuple, Union
15
+
16
+ from ._structures import Infinity, InfinityType, NegativeInfinity, NegativeInfinityType
17
+
18
+ __all__ = ["VERSION_PATTERN", "InvalidVersion", "Version", "parse"]
19
+
20
+ LocalType = Tuple[Union[int, str], ...]
21
+
22
+ CmpPrePostDevType = Union[InfinityType, NegativeInfinityType, Tuple[str, int]]
23
+ CmpLocalType = Union[
24
+ NegativeInfinityType,
25
+ Tuple[Union[Tuple[int, str], Tuple[NegativeInfinityType, Union[int, str]]], ...],
26
+ ]
27
+ CmpKey = Tuple[
28
+ int,
29
+ Tuple[int, ...],
30
+ CmpPrePostDevType,
31
+ CmpPrePostDevType,
32
+ CmpPrePostDevType,
33
+ CmpLocalType,
34
+ ]
35
+ VersionComparisonMethod = Callable[[CmpKey, CmpKey], bool]
36
+
37
+
38
+ class _Version(NamedTuple):
39
+ epoch: int
40
+ release: tuple[int, ...]
41
+ dev: tuple[str, int] | None
42
+ pre: tuple[str, int] | None
43
+ post: tuple[str, int] | None
44
+ local: LocalType | None
45
+
46
+
47
+ def parse(version: str) -> Version:
48
+ """Parse the given version string.
49
+
50
+ >>> parse('1.0.dev1')
51
+ <Version('1.0.dev1')>
52
+
53
+ :param version: The version string to parse.
54
+ :raises InvalidVersion: When the version string is not a valid version.
55
+ """
56
+ return Version(version)
57
+
58
+
59
+ class InvalidVersion(ValueError):
60
+ """Raised when a version string is not a valid version.
61
+
62
+ >>> Version("invalid")
63
+ Traceback (most recent call last):
64
+ ...
65
+ packaging.version.InvalidVersion: Invalid version: 'invalid'
66
+ """
67
+
68
+
69
+ class _BaseVersion:
70
+ _key: tuple[Any, ...]
71
+
72
+ def __hash__(self) -> int:
73
+ return hash(self._key)
74
+
75
+ # Please keep the duplicated `isinstance` check
76
+ # in the six comparisons hereunder
77
+ # unless you find a way to avoid adding overhead function calls.
78
+ def __lt__(self, other: _BaseVersion) -> bool:
79
+ if not isinstance(other, _BaseVersion):
80
+ return NotImplemented
81
+
82
+ return self._key < other._key
83
+
84
+ def __le__(self, other: _BaseVersion) -> bool:
85
+ if not isinstance(other, _BaseVersion):
86
+ return NotImplemented
87
+
88
+ return self._key <= other._key
89
+
90
+ def __eq__(self, other: object) -> bool:
91
+ if not isinstance(other, _BaseVersion):
92
+ return NotImplemented
93
+
94
+ return self._key == other._key
95
+
96
+ def __ge__(self, other: _BaseVersion) -> bool:
97
+ if not isinstance(other, _BaseVersion):
98
+ return NotImplemented
99
+
100
+ return self._key >= other._key
101
+
102
+ def __gt__(self, other: _BaseVersion) -> bool:
103
+ if not isinstance(other, _BaseVersion):
104
+ return NotImplemented
105
+
106
+ return self._key > other._key
107
+
108
+ def __ne__(self, other: object) -> bool:
109
+ if not isinstance(other, _BaseVersion):
110
+ return NotImplemented
111
+
112
+ return self._key != other._key
113
+
114
+
115
+ # Deliberately not anchored to the start and end of the string, to make it
116
+ # easier for 3rd party code to reuse
117
+ _VERSION_PATTERN = r"""
118
+ v?
119
+ (?:
120
+ (?:(?P<epoch>[0-9]+)!)? # epoch
121
+ (?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
122
+ (?P<pre> # pre-release
123
+ [-_\.]?
124
+ (?P<pre_l>alpha|a|beta|b|preview|pre|c|rc)
125
+ [-_\.]?
126
+ (?P<pre_n>[0-9]+)?
127
+ )?
128
+ (?P<post> # post release
129
+ (?:-(?P<post_n1>[0-9]+))
130
+ |
131
+ (?:
132
+ [-_\.]?
133
+ (?P<post_l>post|rev|r)
134
+ [-_\.]?
135
+ (?P<post_n2>[0-9]+)?
136
+ )
137
+ )?
138
+ (?P<dev> # dev release
139
+ [-_\.]?
140
+ (?P<dev_l>dev)
141
+ [-_\.]?
142
+ (?P<dev_n>[0-9]+)?
143
+ )?
144
+ )
145
+ (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
146
+ """
147
+
148
+ VERSION_PATTERN = _VERSION_PATTERN
149
+ """
150
+ A string containing the regular expression used to match a valid version.
151
+
152
+ The pattern is not anchored at either end, and is intended for embedding in larger
153
+ expressions (for example, matching a version number as part of a file name). The
154
+ regular expression should be compiled with the ``re.VERBOSE`` and ``re.IGNORECASE``
155
+ flags set.
156
+
157
+ :meta hide-value:
158
+ """
159
+
160
+
161
+ class Version(_BaseVersion):
162
+ """This class abstracts handling of a project's versions.
163
+
164
+ A :class:`Version` instance is comparison aware and can be compared and
165
+ sorted using the standard Python interfaces.
166
+
167
+ >>> v1 = Version("1.0a5")
168
+ >>> v2 = Version("1.0")
169
+ >>> v1
170
+ <Version('1.0a5')>
171
+ >>> v2
172
+ <Version('1.0')>
173
+ >>> v1 < v2
174
+ True
175
+ >>> v1 == v2
176
+ False
177
+ >>> v1 > v2
178
+ False
179
+ >>> v1 >= v2
180
+ False
181
+ >>> v1 <= v2
182
+ True
183
+ """
184
+
185
+ _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
186
+ _key: CmpKey
187
+
188
+ def __init__(self, version: str) -> None:
189
+ """Initialize a Version object.
190
+
191
+ :param version:
192
+ The string representation of a version which will be parsed and normalized
193
+ before use.
194
+ :raises InvalidVersion:
195
+ If the ``version`` does not conform to PEP 440 in any way then this
196
+ exception will be raised.
197
+ """
198
+
199
+ # Validate the version and parse it into pieces
200
+ match = self._regex.search(version)
201
+ if not match:
202
+ raise InvalidVersion(f"Invalid version: {version!r}")
203
+
204
+ # Store the parsed out pieces of the version
205
+ self._version = _Version(
206
+ epoch=int(match.group("epoch")) if match.group("epoch") else 0,
207
+ release=tuple(int(i) for i in match.group("release").split(".")),
208
+ pre=_parse_letter_version(match.group("pre_l"), match.group("pre_n")),
209
+ post=_parse_letter_version(
210
+ match.group("post_l"), match.group("post_n1") or match.group("post_n2")
211
+ ),
212
+ dev=_parse_letter_version(match.group("dev_l"), match.group("dev_n")),
213
+ local=_parse_local_version(match.group("local")),
214
+ )
215
+
216
+ # Generate a key which will be used for sorting
217
+ self._key = _cmpkey(
218
+ self._version.epoch,
219
+ self._version.release,
220
+ self._version.pre,
221
+ self._version.post,
222
+ self._version.dev,
223
+ self._version.local,
224
+ )
225
+
226
+ def __repr__(self) -> str:
227
+ """A representation of the Version that shows all internal state.
228
+
229
+ >>> Version('1.0.0')
230
+ <Version('1.0.0')>
231
+ """
232
+ return f"<Version('{self}')>"
233
+
234
+ def __str__(self) -> str:
235
+ """A string representation of the version that can be round-tripped.
236
+
237
+ >>> str(Version("1.0a5"))
238
+ '1.0a5'
239
+ """
240
+ parts = []
241
+
242
+ # Epoch
243
+ if self.epoch != 0:
244
+ parts.append(f"{self.epoch}!")
245
+
246
+ # Release segment
247
+ parts.append(".".join(str(x) for x in self.release))
248
+
249
+ # Pre-release
250
+ if self.pre is not None:
251
+ parts.append("".join(str(x) for x in self.pre))
252
+
253
+ # Post-release
254
+ if self.post is not None:
255
+ parts.append(f".post{self.post}")
256
+
257
+ # Development release
258
+ if self.dev is not None:
259
+ parts.append(f".dev{self.dev}")
260
+
261
+ # Local version segment
262
+ if self.local is not None:
263
+ parts.append(f"+{self.local}")
264
+
265
+ return "".join(parts)
266
+
267
+ @property
268
+ def epoch(self) -> int:
269
+ """The epoch of the version.
270
+
271
+ >>> Version("2.0.0").epoch
272
+ 0
273
+ >>> Version("1!2.0.0").epoch
274
+ 1
275
+ """
276
+ return self._version.epoch
277
+
278
+ @property
279
+ def release(self) -> tuple[int, ...]:
280
+ """The components of the "release" segment of the version.
281
+
282
+ >>> Version("1.2.3").release
283
+ (1, 2, 3)
284
+ >>> Version("2.0.0").release
285
+ (2, 0, 0)
286
+ >>> Version("1!2.0.0.post0").release
287
+ (2, 0, 0)
288
+
289
+ Includes trailing zeroes but not the epoch or any pre-release / development /
290
+ post-release suffixes.
291
+ """
292
+ return self._version.release
293
+
294
+ @property
295
+ def pre(self) -> tuple[str, int] | None:
296
+ """The pre-release segment of the version.
297
+
298
+ >>> print(Version("1.2.3").pre)
299
+ None
300
+ >>> Version("1.2.3a1").pre
301
+ ('a', 1)
302
+ >>> Version("1.2.3b1").pre
303
+ ('b', 1)
304
+ >>> Version("1.2.3rc1").pre
305
+ ('rc', 1)
306
+ """
307
+ return self._version.pre
308
+
309
+ @property
310
+ def post(self) -> int | None:
311
+ """The post-release number of the version.
312
+
313
+ >>> print(Version("1.2.3").post)
314
+ None
315
+ >>> Version("1.2.3.post1").post
316
+ 1
317
+ """
318
+ return self._version.post[1] if self._version.post else None
319
+
320
+ @property
321
+ def dev(self) -> int | None:
322
+ """The development number of the version.
323
+
324
+ >>> print(Version("1.2.3").dev)
325
+ None
326
+ >>> Version("1.2.3.dev1").dev
327
+ 1
328
+ """
329
+ return self._version.dev[1] if self._version.dev else None
330
+
331
+ @property
332
+ def local(self) -> str | None:
333
+ """The local version segment of the version.
334
+
335
+ >>> print(Version("1.2.3").local)
336
+ None
337
+ >>> Version("1.2.3+abc").local
338
+ 'abc'
339
+ """
340
+ if self._version.local:
341
+ return ".".join(str(x) for x in self._version.local)
342
+ else:
343
+ return None
344
+
345
+ @property
346
+ def public(self) -> str:
347
+ """The public portion of the version.
348
+
349
+ >>> Version("1.2.3").public
350
+ '1.2.3'
351
+ >>> Version("1.2.3+abc").public
352
+ '1.2.3'
353
+ >>> Version("1!1.2.3dev1+abc").public
354
+ '1!1.2.3.dev1'
355
+ """
356
+ return str(self).split("+", 1)[0]
357
+
358
+ @property
359
+ def base_version(self) -> str:
360
+ """The "base version" of the version.
361
+
362
+ >>> Version("1.2.3").base_version
363
+ '1.2.3'
364
+ >>> Version("1.2.3+abc").base_version
365
+ '1.2.3'
366
+ >>> Version("1!1.2.3dev1+abc").base_version
367
+ '1!1.2.3'
368
+
369
+ The "base version" is the public version of the project without any pre or post
370
+ release markers.
371
+ """
372
+ parts = []
373
+
374
+ # Epoch
375
+ if self.epoch != 0:
376
+ parts.append(f"{self.epoch}!")
377
+
378
+ # Release segment
379
+ parts.append(".".join(str(x) for x in self.release))
380
+
381
+ return "".join(parts)
382
+
383
+ @property
384
+ def is_prerelease(self) -> bool:
385
+ """Whether this version is a pre-release.
386
+
387
+ >>> Version("1.2.3").is_prerelease
388
+ False
389
+ >>> Version("1.2.3a1").is_prerelease
390
+ True
391
+ >>> Version("1.2.3b1").is_prerelease
392
+ True
393
+ >>> Version("1.2.3rc1").is_prerelease
394
+ True
395
+ >>> Version("1.2.3dev1").is_prerelease
396
+ True
397
+ """
398
+ return self.dev is not None or self.pre is not None
399
+
400
+ @property
401
+ def is_postrelease(self) -> bool:
402
+ """Whether this version is a post-release.
403
+
404
+ >>> Version("1.2.3").is_postrelease
405
+ False
406
+ >>> Version("1.2.3.post1").is_postrelease
407
+ True
408
+ """
409
+ return self.post is not None
410
+
411
+ @property
412
+ def is_devrelease(self) -> bool:
413
+ """Whether this version is a development release.
414
+
415
+ >>> Version("1.2.3").is_devrelease
416
+ False
417
+ >>> Version("1.2.3.dev1").is_devrelease
418
+ True
419
+ """
420
+ return self.dev is not None
421
+
422
+ @property
423
+ def major(self) -> int:
424
+ """The first item of :attr:`release` or ``0`` if unavailable.
425
+
426
+ >>> Version("1.2.3").major
427
+ 1
428
+ """
429
+ return self.release[0] if len(self.release) >= 1 else 0
430
+
431
+ @property
432
+ def minor(self) -> int:
433
+ """The second item of :attr:`release` or ``0`` if unavailable.
434
+
435
+ >>> Version("1.2.3").minor
436
+ 2
437
+ >>> Version("1").minor
438
+ 0
439
+ """
440
+ return self.release[1] if len(self.release) >= 2 else 0
441
+
442
+ @property
443
+ def micro(self) -> int:
444
+ """The third item of :attr:`release` or ``0`` if unavailable.
445
+
446
+ >>> Version("1.2.3").micro
447
+ 3
448
+ >>> Version("1").micro
449
+ 0
450
+ """
451
+ return self.release[2] if len(self.release) >= 3 else 0
452
+
453
+
454
+ class _TrimmedRelease(Version):
455
+ @property
456
+ def release(self) -> tuple[int, ...]:
457
+ """
458
+ Release segment without any trailing zeros.
459
+
460
+ >>> _TrimmedRelease('1.0.0').release
461
+ (1,)
462
+ >>> _TrimmedRelease('0.0').release
463
+ (0,)
464
+ """
465
+ rel = super().release
466
+ nonzeros = (index for index, val in enumerate(rel) if val)
467
+ last_nonzero = max(nonzeros, default=0)
468
+ return rel[: last_nonzero + 1]
469
+
470
+
471
+ def _parse_letter_version(
472
+ letter: str | None, number: str | bytes | SupportsInt | None
473
+ ) -> tuple[str, int] | None:
474
+ if letter:
475
+ # We consider there to be an implicit 0 in a pre-release if there is
476
+ # not a numeral associated with it.
477
+ if number is None:
478
+ number = 0
479
+
480
+ # We normalize any letters to their lower case form
481
+ letter = letter.lower()
482
+
483
+ # We consider some words to be alternate spellings of other words and
484
+ # in those cases we want to normalize the spellings to our preferred
485
+ # spelling.
486
+ if letter == "alpha":
487
+ letter = "a"
488
+ elif letter == "beta":
489
+ letter = "b"
490
+ elif letter in ["c", "pre", "preview"]:
491
+ letter = "rc"
492
+ elif letter in ["rev", "r"]:
493
+ letter = "post"
494
+
495
+ return letter, int(number)
496
+
497
+ assert not letter
498
+ if number:
499
+ # We assume if we are given a number, but we are not given a letter
500
+ # then this is using the implicit post release syntax (e.g. 1.0-1)
501
+ letter = "post"
502
+
503
+ return letter, int(number)
504
+
505
+ return None
506
+
507
+
508
+ _local_version_separators = re.compile(r"[\._-]")
509
+
510
+
511
+ def _parse_local_version(local: str | None) -> LocalType | None:
512
+ """
513
+ Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
514
+ """
515
+ if local is not None:
516
+ return tuple(
517
+ part.lower() if not part.isdigit() else int(part)
518
+ for part in _local_version_separators.split(local)
519
+ )
520
+ return None
521
+
522
+
523
+ def _cmpkey(
524
+ epoch: int,
525
+ release: tuple[int, ...],
526
+ pre: tuple[str, int] | None,
527
+ post: tuple[str, int] | None,
528
+ dev: tuple[str, int] | None,
529
+ local: LocalType | None,
530
+ ) -> CmpKey:
531
+ # When we compare a release version, we want to compare it with all of the
532
+ # trailing zeros removed. So we'll use a reverse the list, drop all the now
533
+ # leading zeros until we come to something non zero, then take the rest
534
+ # re-reverse it back into the correct order and make it a tuple and use
535
+ # that for our sorting key.
536
+ _release = tuple(
537
+ reversed(list(itertools.dropwhile(lambda x: x == 0, reversed(release))))
538
+ )
539
+
540
+ # We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
541
+ # We'll do this by abusing the pre segment, but we _only_ want to do this
542
+ # if there is not a pre or a post segment. If we have one of those then
543
+ # the normal sorting rules will handle this case correctly.
544
+ if pre is None and post is None and dev is not None:
545
+ _pre: CmpPrePostDevType = NegativeInfinity
546
+ # Versions without a pre-release (except as noted above) should sort after
547
+ # those with one.
548
+ elif pre is None:
549
+ _pre = Infinity
550
+ else:
551
+ _pre = pre
552
+
553
+ # Versions without a post segment should sort before those with one.
554
+ if post is None:
555
+ _post: CmpPrePostDevType = NegativeInfinity
556
+
557
+ else:
558
+ _post = post
559
+
560
+ # Versions without a development segment should sort after those with one.
561
+ if dev is None:
562
+ _dev: CmpPrePostDevType = Infinity
563
+
564
+ else:
565
+ _dev = dev
566
+
567
+ if local is None:
568
+ # Versions without a local segment should sort before those with one.
569
+ _local: CmpLocalType = NegativeInfinity
570
+ else:
571
+ # Versions with a local segment need that segment parsed to implement
572
+ # the sorting rules in PEP440.
573
+ # - Alpha numeric segments sort before numeric segments
574
+ # - Alpha numeric segments sort lexicographically
575
+ # - Numeric segments sort numerically
576
+ # - Shorter versions sort before longer versions when the prefixes
577
+ # match exactly
578
+ _local = tuple(
579
+ (i, "") if isinstance(i, int) else (NegativeInfinity, i) for i in local
580
+ )
581
+
582
+ return epoch, _release, _pre, _post, _dev, _local
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/cachecontrol/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: 2015 Eric Larson
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """CacheControl import Interface.
6
+
7
+ Make it easy to import from cachecontrol without long namespaces.
8
+ """
9
+ __author__ = "Eric Larson"
10
+ __email__ = "eric@ionrock.org"
11
+ __version__ = "0.14.0"
12
+
13
+ from pip._vendor.cachecontrol.adapter import CacheControlAdapter
14
+ from pip._vendor.cachecontrol.controller import CacheController
15
+ from pip._vendor.cachecontrol.wrapper import CacheControl
16
+
17
+ __all__ = [
18
+ "__author__",
19
+ "__email__",
20
+ "__version__",
21
+ "CacheControlAdapter",
22
+ "CacheController",
23
+ "CacheControl",
24
+ ]
25
+
26
+ import logging
27
+
28
+ logging.getLogger(__name__).addHandler(logging.NullHandler())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/cachecontrol/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (999 Bytes). View file