koichi12 commited on
Commit
211e5eb
·
verified ·
1 Parent(s): dbf954e

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec-2024.2.0.dist-info/LICENSE +29 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec-2024.2.0.dist-info/METADATA +167 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_mpmath.py +7 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_str.py +14 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_visualization.py +32 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/__init__.py +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_adv_infer_v8.h +658 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_adv_train_v8.h +540 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_backend_v8.h +600 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn_train.h +219 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn_train_v8.h +219 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_ops_train_v8.h +501 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_v8.h +78 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_version_v8.h +70 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/__init__.py +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/curand/__init__.py +0 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/curand/include/__init__.py +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/curand/include/curand_kernel.h +1665 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/curand/include/curand_mtgp32_host.h +516 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/curand/include/curand_poisson.h +751 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/curand/include/curand_uniform.h +498 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverSp.h +923 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/locators.cpython-311.pyc +0 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/markers.cpython-311.pyc +0 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/metadata.cpython-311.pyc +0 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/resources.cpython-311.pyc +0 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/scripts.cpython-311.pyc +0 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/util.cpython-311.pyc +0 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/version.cpython-311.pyc +0 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/wheel.cpython-311.pyc +0 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distro/__pycache__/__init__.cpython-311.pyc +0 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distro/py.typed +0 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/_emoji_codes.py +0 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/_wrap.py +93 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/constrain.py +37 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/file_proxy.py +57 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/highlighter.py +232 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/json.py +139 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/layout.py +442 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/progress_bar.py +223 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/syntax.py +958 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/tree.py +249 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pybind11-2.13.6.dist-info/INSTALLER +1 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pybind11-2.13.6.dist-info/METADATA +220 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/__init__.py +36 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-311.pyc +0 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-311.pyc +0 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/__init__.py +19 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/__init__.cpython-311.pyc +0 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py +180 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec-2024.2.0.dist-info/LICENSE ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2018, Martin Durant
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions are met:
8
+
9
+ * Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ * Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation
14
+ and/or other materials provided with the distribution.
15
+
16
+ * Neither the name of the copyright holder nor the names of its
17
+ contributors may be used to endorse or promote products derived from
18
+ this software without specific prior written permission.
19
+
20
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
tuning-competition-baseline/.venv/lib/python3.11/site-packages/fsspec-2024.2.0.dist-info/METADATA ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: fsspec
3
+ Version: 2024.2.0
4
+ Summary: File-system specification
5
+ Home-page: https://github.com/fsspec/filesystem_spec
6
+ Maintainer: Martin Durant
7
+ Maintainer-email: mdurant@anaconda.com
8
+ License: BSD
9
+ Project-URL: Changelog, https://filesystem-spec.readthedocs.io/en/latest/changelog.html
10
+ Project-URL: Documentation, https://filesystem-spec.readthedocs.io/en/latest/
11
+ Keywords: file
12
+ Classifier: Development Status :: 4 - Beta
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: License :: OSI Approved :: BSD License
15
+ Classifier: Operating System :: OS Independent
16
+ Classifier: Programming Language :: Python :: 3.8
17
+ Classifier: Programming Language :: Python :: 3.9
18
+ Classifier: Programming Language :: Python :: 3.10
19
+ Classifier: Programming Language :: Python :: 3.11
20
+ Requires-Python: >=3.8
21
+ Description-Content-Type: text/markdown
22
+ License-File: LICENSE
23
+ Provides-Extra: abfs
24
+ Requires-Dist: adlfs ; extra == 'abfs'
25
+ Provides-Extra: adl
26
+ Requires-Dist: adlfs ; extra == 'adl'
27
+ Provides-Extra: arrow
28
+ Requires-Dist: pyarrow >=1 ; extra == 'arrow'
29
+ Provides-Extra: dask
30
+ Requires-Dist: dask ; extra == 'dask'
31
+ Requires-Dist: distributed ; extra == 'dask'
32
+ Provides-Extra: devel
33
+ Requires-Dist: pytest ; extra == 'devel'
34
+ Requires-Dist: pytest-cov ; extra == 'devel'
35
+ Provides-Extra: dropbox
36
+ Requires-Dist: dropboxdrivefs ; extra == 'dropbox'
37
+ Requires-Dist: requests ; extra == 'dropbox'
38
+ Requires-Dist: dropbox ; extra == 'dropbox'
39
+ Provides-Extra: entrypoints
40
+ Provides-Extra: full
41
+ Requires-Dist: adlfs ; extra == 'full'
42
+ Requires-Dist: aiohttp !=4.0.0a0,!=4.0.0a1 ; extra == 'full'
43
+ Requires-Dist: dask ; extra == 'full'
44
+ Requires-Dist: distributed ; extra == 'full'
45
+ Requires-Dist: dropbox ; extra == 'full'
46
+ Requires-Dist: dropboxdrivefs ; extra == 'full'
47
+ Requires-Dist: fusepy ; extra == 'full'
48
+ Requires-Dist: gcsfs ; extra == 'full'
49
+ Requires-Dist: libarchive-c ; extra == 'full'
50
+ Requires-Dist: ocifs ; extra == 'full'
51
+ Requires-Dist: panel ; extra == 'full'
52
+ Requires-Dist: paramiko ; extra == 'full'
53
+ Requires-Dist: pyarrow >=1 ; extra == 'full'
54
+ Requires-Dist: pygit2 ; extra == 'full'
55
+ Requires-Dist: requests ; extra == 'full'
56
+ Requires-Dist: s3fs ; extra == 'full'
57
+ Requires-Dist: smbprotocol ; extra == 'full'
58
+ Requires-Dist: tqdm ; extra == 'full'
59
+ Provides-Extra: fuse
60
+ Requires-Dist: fusepy ; extra == 'fuse'
61
+ Provides-Extra: gcs
62
+ Requires-Dist: gcsfs ; extra == 'gcs'
63
+ Provides-Extra: git
64
+ Requires-Dist: pygit2 ; extra == 'git'
65
+ Provides-Extra: github
66
+ Requires-Dist: requests ; extra == 'github'
67
+ Provides-Extra: gs
68
+ Requires-Dist: gcsfs ; extra == 'gs'
69
+ Provides-Extra: gui
70
+ Requires-Dist: panel ; extra == 'gui'
71
+ Provides-Extra: hdfs
72
+ Requires-Dist: pyarrow >=1 ; extra == 'hdfs'
73
+ Provides-Extra: http
74
+ Requires-Dist: aiohttp !=4.0.0a0,!=4.0.0a1 ; extra == 'http'
75
+ Provides-Extra: libarchive
76
+ Requires-Dist: libarchive-c ; extra == 'libarchive'
77
+ Provides-Extra: oci
78
+ Requires-Dist: ocifs ; extra == 'oci'
79
+ Provides-Extra: s3
80
+ Requires-Dist: s3fs ; extra == 's3'
81
+ Provides-Extra: sftp
82
+ Requires-Dist: paramiko ; extra == 'sftp'
83
+ Provides-Extra: smb
84
+ Requires-Dist: smbprotocol ; extra == 'smb'
85
+ Provides-Extra: ssh
86
+ Requires-Dist: paramiko ; extra == 'ssh'
87
+ Provides-Extra: tqdm
88
+ Requires-Dist: tqdm ; extra == 'tqdm'
89
+
90
+ # filesystem_spec
91
+
92
+ [![PyPI version](https://badge.fury.io/py/fsspec.svg)](https://pypi.python.org/pypi/fsspec/)
93
+ [![Anaconda-Server Badge](https://anaconda.org/conda-forge/fsspec/badges/version.svg)](https://anaconda.org/conda-forge/fsspec)
94
+ ![Build](https://github.com/fsspec/filesystem_spec/workflows/CI/badge.svg)
95
+ [![Docs](https://readthedocs.org/projects/filesystem-spec/badge/?version=latest)](https://filesystem-spec.readthedocs.io/en/latest/?badge=latest)
96
+ [![PyPi downloads](https://img.shields.io/pypi/dm/fsspec?label=pypi%20downloads&style=flat)](https://pepy.tech/project/fsspec)
97
+
98
+ A specification for pythonic filesystems.
99
+
100
+ ## Install
101
+
102
+ ```bash
103
+ pip install fsspec
104
+ ```
105
+
106
+ would install the base fsspec. Various optionally supported features might require specification of custom
107
+ extra require, e.g. `pip install fsspec[ssh]` will install dependencies for `ssh` backends support.
108
+ Use `pip install fsspec[full]` for installation of all known extra dependencies.
109
+
110
+ Up-to-date package also provided through conda-forge distribution:
111
+
112
+ ```bash
113
+ conda install -c conda-forge fsspec
114
+ ```
115
+
116
+
117
+ ## Purpose
118
+
119
+ To produce a template or specification for a file-system interface, that specific implementations should follow,
120
+ so that applications making use of them can rely on a common behaviour and not have to worry about the specific
121
+ internal implementation decisions with any given backend. Many such implementations are included in this package,
122
+ or in sister projects such as `s3fs` and `gcsfs`.
123
+
124
+ In addition, if this is well-designed, then additional functionality, such as a key-value store or FUSE
125
+ mounting of the file-system implementation may be available for all implementations "for free".
126
+
127
+ ## Documentation
128
+
129
+ Please refer to [RTD](https://filesystem-spec.readthedocs.io/en/latest/?badge=latest)
130
+
131
+ ## Develop
132
+
133
+ fsspec uses GitHub Actions for CI. Environment files can be found
134
+ in the "ci/" directory. Note that the main environment is called "py38",
135
+ but it is expected that the version of python installed be adjustable at
136
+ CI runtime. For local use, pick a version suitable for you.
137
+
138
+ ### Testing
139
+
140
+ Tests can be run in the dev environment, if activated, via ``pytest fsspec``.
141
+
142
+ The full fsspec suite requires a system-level docker, docker-compose, and fuse
143
+ installation. If only making changes to one backend implementation, it is
144
+ not generally necessary to run all tests locally.
145
+
146
+ It is expected that contributors ensure that any change to fsspec does not
147
+ cause issues or regressions for either other fsspec-related packages such
148
+ as gcsfs and s3fs, nor for downstream users of fsspec. The "downstream" CI
149
+ run and corresponding environment file run a set of tests from the dask
150
+ test suite, and very minimal tests against pandas and zarr from the
151
+ test_downstream.py module in this repo.
152
+
153
+ ### Code Formatting
154
+
155
+ fsspec uses [Black](https://black.readthedocs.io/en/stable) to ensure
156
+ a consistent code format throughout the project.
157
+ Run ``black fsspec`` from the root of the filesystem_spec repository to
158
+ auto-format your code. Additionally, many editors have plugins that will apply
159
+ ``black`` as you edit files. ``black`` is included in the ``tox`` environments.
160
+
161
+ Optionally, you may wish to setup [pre-commit hooks](https://pre-commit.com) to
162
+ automatically run ``black`` when you make a git commit.
163
+ Run ``pre-commit install --install-hooks`` from the root of the
164
+ filesystem_spec repository to setup pre-commit hooks. ``black`` will now be run
165
+ before you commit, reformatting any changed files. You can format without
166
+ committing via ``pre-commit run`` or skip these checks with ``git commit
167
+ --no-verify``.
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_mpmath.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from mpmath.libmp import *
2
+ from mpmath import *
3
+
4
+ def test_newstyle_classes():
5
+ for cls in [mp, fp, iv, mpf, mpc]:
6
+ for s in cls.__class__.__mro__:
7
+ assert isinstance(s, type)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_str.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mpmath import nstr, matrix, inf
2
+
3
+ def test_nstr():
4
+ m = matrix([[0.75, 0.190940654, -0.0299195971],
5
+ [0.190940654, 0.65625, 0.205663228],
6
+ [-0.0299195971, 0.205663228, 0.64453125e-20]])
7
+ assert nstr(m, 4, min_fixed=-inf) == \
8
+ '''[ 0.75 0.1909 -0.02992]
9
+ [ 0.1909 0.6563 0.2057]
10
+ [-0.02992 0.2057 0.000000000000000000006445]'''
11
+ assert nstr(m, 4) == \
12
+ '''[ 0.75 0.1909 -0.02992]
13
+ [ 0.1909 0.6563 0.2057]
14
+ [-0.02992 0.2057 6.445e-21]'''
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/test_visualization.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Limited tests of the visualization module. Right now it just makes
3
+ sure that passing custom Axes works.
4
+
5
+ """
6
+
7
+ from mpmath import mp, fp
8
+
9
+ def test_axes():
10
+ try:
11
+ import matplotlib
12
+ version = matplotlib.__version__.split("-")[0]
13
+ version = version.split(".")[:2]
14
+ if [int(_) for _ in version] < [0,99]:
15
+ raise ImportError
16
+ import pylab
17
+ except ImportError:
18
+ print("\nSkipping test (pylab not available or too old version)\n")
19
+ return
20
+ fig = pylab.figure()
21
+ axes = fig.add_subplot(111)
22
+ for ctx in [mp, fp]:
23
+ ctx.plot(lambda x: x**2, [0, 3], axes=axes)
24
+ assert axes.get_xlabel() == 'x'
25
+ assert axes.get_ylabel() == 'f(x)'
26
+
27
+ fig = pylab.figure()
28
+ axes = fig.add_subplot(111)
29
+ for ctx in [mp, fp]:
30
+ ctx.cplot(lambda z: z, [-2, 2], [-10, 10], axes=axes)
31
+ assert axes.get_xlabel() == 'Re(z)'
32
+ assert axes.get_ylabel() == 'Im(z)'
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_adv_infer_v8.h ADDED
@@ -0,0 +1,658 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2017-2022 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /* cudnn_adv_infer : cuDNN's advanced and experimental features.
51
+
52
+ */
53
+
54
+ #if !defined(CUDNN_ADV_INFER_H_)
55
+ #define CUDNN_ADV_INFER_H_
56
+
57
+ #include <cuda_runtime.h>
58
+ #include <stdint.h>
59
+
60
+ #include "cudnn_version.h"
61
+ #include "cudnn_ops_infer.h"
62
+
63
+ /* These version numbers are autogenerated, do not edit manually. */
64
+ #define CUDNN_ADV_INFER_MAJOR 8
65
+ #define CUDNN_ADV_INFER_MINOR 7
66
+ #define CUDNN_ADV_INFER_PATCH 0
67
+
68
+ #if (CUDNN_ADV_INFER_MAJOR != CUDNN_MAJOR) || (CUDNN_ADV_INFER_MINOR != CUDNN_MINOR) || \
69
+ (CUDNN_ADV_INFER_PATCH != CUDNN_PATCHLEVEL)
70
+ #error Version mismatch in cuDNN ADV INFER!!!
71
+ #endif
72
+
73
+ #if defined(__cplusplus)
74
+ extern "C" {
75
+ #endif
76
+
77
+ /* BASIC RNN API */
78
+
79
+ typedef enum {
80
+ CUDNN_FWD_MODE_INFERENCE = 0,
81
+ CUDNN_FWD_MODE_TRAINING = 1,
82
+ } cudnnForwardMode_t;
83
+
84
+ typedef enum {
85
+ CUDNN_RNN_RELU = 0, /* basic RNN cell type with ReLu activation */
86
+ CUDNN_RNN_TANH = 1, /* basic RNN cell type with tanh activation */
87
+ CUDNN_LSTM = 2, /* LSTM with optional recurrent projection and clipping */
88
+ CUDNN_GRU = 3, /* Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1); */
89
+ } cudnnRNNMode_t;
90
+
91
+ typedef enum {
92
+ CUDNN_RNN_NO_BIAS = 0, /* rnn cell formulas do not use biases */
93
+ CUDNN_RNN_SINGLE_INP_BIAS = 1, /* rnn cell formulas use one input bias in input GEMM */
94
+ CUDNN_RNN_DOUBLE_BIAS = 2, /* default, rnn cell formulas use two bias vectors */
95
+ CUDNN_RNN_SINGLE_REC_BIAS = 3 /* rnn cell formulas use one recurrent bias in recurrent GEMM */
96
+ } cudnnRNNBiasMode_t;
97
+
98
+ typedef enum {
99
+ CUDNN_UNIDIRECTIONAL = 0, /* single direction network */
100
+ CUDNN_BIDIRECTIONAL = 1, /* output concatination at each layer */
101
+ } cudnnDirectionMode_t;
102
+
103
+ typedef enum {
104
+ CUDNN_LINEAR_INPUT = 0, /* adjustable weight matrix in first layer input GEMM */
105
+ CUDNN_SKIP_INPUT = 1, /* fixed identity matrix in the first layer input GEMM */
106
+ } cudnnRNNInputMode_t;
107
+
108
+ typedef enum {
109
+ CUDNN_RNN_CLIP_NONE = 0, /* disables LSTM cell clipping */
110
+ CUDNN_RNN_CLIP_MINMAX = 1, /* enables LSTM cell clipping */
111
+ } cudnnRNNClipMode_t;
112
+
113
+ typedef enum {
114
+ CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED = 0, /* padded, outer stride from one time-step to the next */
115
+ CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED = 1, /* sequence length sorted and packed as in basic RNN api */
116
+ CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED = 2, /* padded, outer stride from one batch to the next */
117
+ } cudnnRNNDataLayout_t;
118
+
119
+ /* Legacy type for backward compatibility */
120
+ typedef unsigned cudnnRNNPaddingMode_t;
121
+
122
+ /* For auxFlags in cudnnSetRNNDescriptor_v8() and cudnnSetRNNPaddingMode() */
123
+ #define CUDNN_RNN_PADDED_IO_DISABLED 0
124
+ #define CUDNN_RNN_PADDED_IO_ENABLED (1U << 0)
125
+
126
+ struct cudnnRNNStruct;
127
+ typedef struct cudnnRNNStruct *cudnnRNNDescriptor_t;
128
+
129
+ struct cudnnPersistentRNNPlan;
130
+ typedef struct cudnnPersistentRNNPlan *cudnnPersistentRNNPlan_t;
131
+
132
+ struct cudnnRNNDataStruct;
133
+ typedef struct cudnnRNNDataStruct *cudnnRNNDataDescriptor_t;
134
+
135
+ cudnnStatus_t CUDNNWINAPI
136
+ cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc);
137
+
138
+ cudnnStatus_t CUDNNWINAPI
139
+ cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc);
140
+
141
+ cudnnStatus_t CUDNNWINAPI
142
+ cudnnSetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc,
143
+ cudnnRNNAlgo_t algo,
144
+ cudnnRNNMode_t cellMode,
145
+ cudnnRNNBiasMode_t biasMode,
146
+ cudnnDirectionMode_t dirMode,
147
+ cudnnRNNInputMode_t inputMode,
148
+ cudnnDataType_t dataType,
149
+ cudnnDataType_t mathPrec,
150
+ cudnnMathType_t mathType,
151
+ int32_t inputSize,
152
+ int32_t hiddenSize,
153
+ int32_t projSize,
154
+ int32_t numLayers,
155
+ cudnnDropoutDescriptor_t dropoutDesc,
156
+ uint32_t auxFlags);
157
+
158
+ cudnnStatus_t CUDNNWINAPI
159
+ cudnnGetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc,
160
+ cudnnRNNAlgo_t *algo,
161
+ cudnnRNNMode_t *cellMode,
162
+ cudnnRNNBiasMode_t *biasMode,
163
+ cudnnDirectionMode_t *dirMode,
164
+ cudnnRNNInputMode_t *inputMode,
165
+ cudnnDataType_t *dataType,
166
+ cudnnDataType_t *mathPrec,
167
+ cudnnMathType_t *mathType,
168
+ int32_t *inputSize,
169
+ int32_t *hiddenSize,
170
+ int32_t *projSize,
171
+ int32_t *numLayers,
172
+ cudnnDropoutDescriptor_t *dropoutDesc,
173
+ uint32_t *auxFlags);
174
+
175
+ /*
176
+ * mathPrec in cudnnSetRNNDescriptor_v6() specifies compute precision
177
+ * compute precision is further modified by cudnnSetRNNMatrixMathType()
178
+ * dataType in cudnnGetRNNParamsSize() and wDesc specify weight storage
179
+ * dropout is between RNN layers, not between recurrent steps
180
+ */
181
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
182
+ cudnnSetRNNDescriptor_v6(cudnnHandle_t handle,
183
+ cudnnRNNDescriptor_t rnnDesc,
184
+ const int hiddenSize,
185
+ const int numLayers,
186
+ cudnnDropoutDescriptor_t dropoutDesc,
187
+ cudnnRNNInputMode_t inputMode,
188
+ cudnnDirectionMode_t direction,
189
+ cudnnRNNMode_t cellMode,
190
+ cudnnRNNAlgo_t algo,
191
+ cudnnDataType_t mathPrec);
192
+
193
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
194
+ cudnnGetRNNDescriptor_v6(cudnnHandle_t handle,
195
+ cudnnRNNDescriptor_t rnnDesc,
196
+ int *hiddenSize,
197
+ int *numLayers,
198
+ cudnnDropoutDescriptor_t *dropoutDesc,
199
+ cudnnRNNInputMode_t *inputMode,
200
+ cudnnDirectionMode_t *direction,
201
+ cudnnRNNMode_t *cellMode,
202
+ cudnnRNNAlgo_t *algo,
203
+ cudnnDataType_t *mathPrec);
204
+
205
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
206
+ cudnnSetRNNMatrixMathType(cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t mType);
207
+
208
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
209
+ cudnnGetRNNMatrixMathType(cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t *mType);
210
+
211
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
212
+ cudnnSetRNNBiasMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNBiasMode_t biasMode);
213
+
214
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
215
+ cudnnGetRNNBiasMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNBiasMode_t *biasMode);
216
+
217
+ cudnnStatus_t CUDNNWINAPI
218
+ cudnnRNNSetClip_v8(cudnnRNNDescriptor_t rnnDesc,
219
+ cudnnRNNClipMode_t clipMode,
220
+ cudnnNanPropagation_t clipNanOpt,
221
+ double lclip,
222
+ double rclip);
223
+
224
+ cudnnStatus_t CUDNNWINAPI
225
+ cudnnRNNGetClip_v8(cudnnRNNDescriptor_t rnnDesc,
226
+ cudnnRNNClipMode_t *clipMode,
227
+ cudnnNanPropagation_t *clipNanOpt,
228
+ double *lclip,
229
+ double *rclip);
230
+
231
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
232
+ cudnnRNNSetClip(cudnnHandle_t handle,
233
+ cudnnRNNDescriptor_t rnnDesc,
234
+ cudnnRNNClipMode_t clipMode,
235
+ cudnnNanPropagation_t clipNanOpt,
236
+ double lclip,
237
+ double rclip);
238
+
239
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
240
+ cudnnRNNGetClip(cudnnHandle_t handle,
241
+ cudnnRNNDescriptor_t rnnDesc,
242
+ cudnnRNNClipMode_t *clipMode,
243
+ cudnnNanPropagation_t *clipNanOpt,
244
+ double *lclip,
245
+ double *rclip);
246
+
247
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
248
+ cudnnSetRNNProjectionLayers(cudnnHandle_t handle,
249
+ cudnnRNNDescriptor_t rnnDesc,
250
+ const int recProjSize,
251
+ const int outProjSize);
252
+
253
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
254
+ cudnnGetRNNProjectionLayers(cudnnHandle_t handle,
255
+ const cudnnRNNDescriptor_t rnnDesc,
256
+ int *recProjSize,
257
+ int *outProjSize);
258
+
259
+ /* Expensive. Creates the plan for the specific settings. */
260
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
261
+ cudnnCreatePersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc,
262
+ const int minibatch,
263
+ const cudnnDataType_t dataType,
264
+ cudnnPersistentRNNPlan_t *plan);
265
+
266
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
267
+ cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan);
268
+
269
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
270
+ cudnnSetPersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, cudnnPersistentRNNPlan_t plan);
271
+
272
+ cudnnStatus_t CUDNNWINAPI
273
+ cudnnBuildRNNDynamic(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, int miniBatch);
274
+
275
+ /* dataType in weight descriptors and input descriptors is used to describe storage */
276
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
277
+ cudnnGetRNNWorkspaceSize(cudnnHandle_t handle,
278
+ const cudnnRNNDescriptor_t rnnDesc,
279
+ const int seqLength,
280
+ const cudnnTensorDescriptor_t *xDesc,
281
+ size_t *sizeInBytes);
282
+
283
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
284
+ cudnnGetRNNTrainingReserveSize(cudnnHandle_t handle,
285
+ const cudnnRNNDescriptor_t rnnDesc,
286
+ const int seqLength,
287
+ const cudnnTensorDescriptor_t *xDesc,
288
+ size_t *sizeInBytes);
289
+
290
+ cudnnStatus_t CUDNNWINAPI
291
+ cudnnGetRNNTempSpaceSizes(cudnnHandle_t handle,
292
+ cudnnRNNDescriptor_t rnnDesc,
293
+ cudnnForwardMode_t fMode,
294
+ cudnnRNNDataDescriptor_t xDesc,
295
+ size_t *workSpaceSize,
296
+ size_t *reserveSpaceSize);
297
+
298
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
299
+ cudnnGetRNNParamsSize(cudnnHandle_t handle,
300
+ const cudnnRNNDescriptor_t rnnDesc,
301
+ const cudnnTensorDescriptor_t xDesc,
302
+ size_t *sizeInBytes,
303
+ cudnnDataType_t dataType);
304
+
305
+ cudnnStatus_t CUDNNWINAPI
306
+ cudnnGetRNNWeightSpaceSize(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, size_t *weightSpaceSize);
307
+
308
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
309
+ cudnnGetRNNLinLayerMatrixParams(cudnnHandle_t handle,
310
+ const cudnnRNNDescriptor_t rnnDesc,
311
+ const int pseudoLayer,
312
+ const cudnnTensorDescriptor_t xDesc,
313
+ const cudnnFilterDescriptor_t wDesc,
314
+ const void *w,
315
+ const int linLayerID,
316
+ cudnnFilterDescriptor_t linLayerMatDesc,
317
+ void **linLayerMat);
318
+
319
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
320
+ cudnnGetRNNLinLayerBiasParams(cudnnHandle_t handle,
321
+ const cudnnRNNDescriptor_t rnnDesc,
322
+ const int pseudoLayer,
323
+ const cudnnTensorDescriptor_t xDesc,
324
+ const cudnnFilterDescriptor_t wDesc,
325
+ const void *w,
326
+ const int linLayerID,
327
+ cudnnFilterDescriptor_t linLayerBiasDesc,
328
+ void **linLayerBias);
329
+
330
+ cudnnStatus_t CUDNNWINAPI
331
+ cudnnGetRNNWeightParams(cudnnHandle_t handle,
332
+ cudnnRNNDescriptor_t rnnDesc,
333
+ int32_t pseudoLayer,
334
+ size_t weightSpaceSize,
335
+ const void *weightSpace,
336
+ int32_t linLayerID,
337
+ cudnnTensorDescriptor_t mDesc,
338
+ void **mAddr,
339
+ cudnnTensorDescriptor_t bDesc,
340
+ void **bAddr);
341
+
342
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
343
+ cudnnRNNForwardInference(cudnnHandle_t handle,
344
+ const cudnnRNNDescriptor_t rnnDesc,
345
+ const int seqLength,
346
+ const cudnnTensorDescriptor_t *xDesc,
347
+ const void *x,
348
+ const cudnnTensorDescriptor_t hxDesc,
349
+ const void *hx,
350
+ const cudnnTensorDescriptor_t cxDesc,
351
+ const void *cx,
352
+ const cudnnFilterDescriptor_t wDesc,
353
+ const void *w,
354
+ const cudnnTensorDescriptor_t *yDesc,
355
+ void *y,
356
+ const cudnnTensorDescriptor_t hyDesc,
357
+ void *hy,
358
+ const cudnnTensorDescriptor_t cyDesc,
359
+ void *cy,
360
+ void *workSpace,
361
+ size_t workSpaceSizeInBytes);
362
+
363
+ /* RNN EX API */
364
+
365
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
366
+ cudnnSetRNNPaddingMode(cudnnRNNDescriptor_t rnnDesc, unsigned paddingMode);
367
+
368
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
369
+ cudnnGetRNNPaddingMode(cudnnRNNDescriptor_t rnnDesc, unsigned *paddingMode);
370
+
371
+ cudnnStatus_t CUDNNWINAPI
372
+ cudnnCreateRNNDataDescriptor(cudnnRNNDataDescriptor_t *rnnDataDesc);
373
+
374
+ cudnnStatus_t CUDNNWINAPI
375
+ cudnnDestroyRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc);
376
+
377
+ cudnnStatus_t CUDNNWINAPI
378
+ cudnnSetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
379
+ cudnnDataType_t dataType,
380
+ cudnnRNNDataLayout_t layout,
381
+ int maxSeqLength,
382
+ int batchSize,
383
+ int vectorSize,
384
+ const int seqLengthArray[], /* length of each sequence in the batch */
385
+ void *paddingFill); /* symbol for filling padding position in output */
386
+
387
+ cudnnStatus_t CUDNNWINAPI
388
+ cudnnGetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
389
+ cudnnDataType_t *dataType,
390
+ cudnnRNNDataLayout_t *layout,
391
+ int *maxSeqLength,
392
+ int *batchSize,
393
+ int *vectorSize,
394
+ int arrayLengthRequested,
395
+ int seqLengthArray[],
396
+ void *paddingFill);
397
+
398
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
399
+ cudnnRNNForwardInferenceEx(cudnnHandle_t handle,
400
+ const cudnnRNNDescriptor_t rnnDesc,
401
+ const cudnnRNNDataDescriptor_t xDesc,
402
+ const void *x,
403
+ const cudnnTensorDescriptor_t hxDesc,
404
+ const void *hx,
405
+ const cudnnTensorDescriptor_t cxDesc,
406
+ const void *cx,
407
+ const cudnnFilterDescriptor_t wDesc,
408
+ const void *w,
409
+ const cudnnRNNDataDescriptor_t yDesc,
410
+ void *y,
411
+ const cudnnTensorDescriptor_t hyDesc,
412
+ void *hy,
413
+ const cudnnTensorDescriptor_t cyDesc,
414
+ void *cy,
415
+ const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */
416
+ const void *keys, /* reserved, should pass NULL */
417
+ const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */
418
+ void *cAttn, /* reserved, should pass NULL */
419
+ const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */
420
+ void *iAttn, /* reserved, should pass NULL */
421
+ const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */
422
+ void *queries, /* reserved, should pass NULL */
423
+ void *workSpace,
424
+ size_t workSpaceSizeInBytes);
425
+
426
+ cudnnStatus_t CUDNNWINAPI
427
+ cudnnRNNForward(cudnnHandle_t handle,
428
+ cudnnRNNDescriptor_t rnnDesc,
429
+ cudnnForwardMode_t fwdMode,
430
+ const int32_t devSeqLengths[],
431
+ cudnnRNNDataDescriptor_t xDesc,
432
+ const void *x,
433
+ cudnnRNNDataDescriptor_t yDesc,
434
+ void *y,
435
+ cudnnTensorDescriptor_t hDesc,
436
+ const void *hx,
437
+ void *hy,
438
+ cudnnTensorDescriptor_t cDesc,
439
+ const void *cx,
440
+ void *cy,
441
+ size_t weightSpaceSize,
442
+ const void *weightSpace,
443
+ size_t workSpaceSize,
444
+ void *workSpace,
445
+ size_t reserveSpaceSize,
446
+ void *reserveSpace);
447
+
448
+ /* RNN FIND API */
449
+
450
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
451
+ cudnnSetRNNAlgorithmDescriptor(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, cudnnAlgorithmDescriptor_t algoDesc);
452
+
453
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
454
+ cudnnGetRNNForwardInferenceAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count);
455
+
456
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
457
+ cudnnFindRNNForwardInferenceAlgorithmEx(cudnnHandle_t handle,
458
+ const cudnnRNNDescriptor_t rnnDesc,
459
+ const int seqLength,
460
+ const cudnnTensorDescriptor_t *xDesc,
461
+ const void *x,
462
+ const cudnnTensorDescriptor_t hxDesc,
463
+ const void *hx,
464
+ const cudnnTensorDescriptor_t cxDesc,
465
+ const void *cx,
466
+ const cudnnFilterDescriptor_t wDesc,
467
+ const void *w,
468
+ const cudnnTensorDescriptor_t *yDesc,
469
+ void *y,
470
+ const cudnnTensorDescriptor_t hyDesc,
471
+ void *hy,
472
+ const cudnnTensorDescriptor_t cyDesc,
473
+ void *cy,
474
+ const float findIntensity,
475
+ const int requestedAlgoCount,
476
+ int *returnedAlgoCount,
477
+ cudnnAlgorithmPerformance_t *perfResults,
478
+ void *workspace,
479
+ size_t workSpaceSizeInBytes);
480
+
481
+ /* Sequence data descriptor */
482
+
483
+ typedef enum {
484
+ CUDNN_SEQDATA_TIME_DIM = 0, /* index in time */
485
+ CUDNN_SEQDATA_BATCH_DIM = 1, /* index in batch */
486
+ CUDNN_SEQDATA_BEAM_DIM = 2, /* index in beam */
487
+ CUDNN_SEQDATA_VECT_DIM = 3 /* index in vector */
488
+ } cudnnSeqDataAxis_t;
489
+
490
+ struct cudnnSeqDataStruct;
491
+ typedef struct cudnnSeqDataStruct *cudnnSeqDataDescriptor_t;
492
+
493
+ #define CUDNN_SEQDATA_DIM_COUNT 4 /* dimension count */
494
+
495
+ cudnnStatus_t CUDNNWINAPI
496
+ cudnnCreateSeqDataDescriptor(cudnnSeqDataDescriptor_t *seqDataDesc);
497
+
498
+ cudnnStatus_t CUDNNWINAPI
499
+ cudnnDestroySeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc);
500
+
501
+ cudnnStatus_t CUDNNWINAPI
502
+ cudnnSetSeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc,
503
+ cudnnDataType_t dataType,
504
+ int nbDims,
505
+ const int dimA[],
506
+ const cudnnSeqDataAxis_t axes[],
507
+ size_t seqLengthArraySize,
508
+ const int seqLengthArray[],
509
+ void *paddingFill);
510
+
511
+ cudnnStatus_t CUDNNWINAPI
512
+ cudnnGetSeqDataDescriptor(const cudnnSeqDataDescriptor_t seqDataDesc,
513
+ cudnnDataType_t *dataType,
514
+ int *nbDims,
515
+ int nbDimsRequested,
516
+ int dimA[],
517
+ cudnnSeqDataAxis_t axes[],
518
+ size_t *seqLengthArraySize,
519
+ size_t seqLengthSizeRequested,
520
+ int seqLengthArray[],
521
+ void *paddingFill);
522
+
523
+ /* Multihead Attention */
524
+
525
+ /* Legacy type for backward compatibility */
526
+ typedef unsigned cudnnAttnQueryMap_t;
527
+
528
+ /*
529
+ * Multi-head attention options passed via 'attnMode' in cudnnSetAttnDescriptor().
530
+ * Use the bitwise OR operator to combine several settings listed below. Additional
531
+ * minor options can be added here w/o changing or introducing new API functions.
532
+ */
533
+ #define CUDNN_ATTN_QUERYMAP_ALL_TO_ONE 0 /* multiple Q-s map to a single (K,V) set when beam size > 1 */
534
+ #define CUDNN_ATTN_QUERYMAP_ONE_TO_ONE (1U << 0) /* multiple Q-s map to multiple (K,V) sets when beam size > 1 */
535
+ #define CUDNN_ATTN_DISABLE_PROJ_BIASES 0 /* no biases in attention input and output projections */
536
+ #define CUDNN_ATTN_ENABLE_PROJ_BIASES (1U << 1) /* use biases in attention input and output projections */
537
+
538
+ struct cudnnAttnStruct;
539
+ typedef struct cudnnAttnStruct *cudnnAttnDescriptor_t;
540
+
541
+ cudnnStatus_t CUDNNWINAPI
542
+ cudnnCreateAttnDescriptor(cudnnAttnDescriptor_t *attnDesc);
543
+
544
+ cudnnStatus_t CUDNNWINAPI
545
+ cudnnDestroyAttnDescriptor(cudnnAttnDescriptor_t attnDesc);
546
+
547
+ cudnnStatus_t CUDNNWINAPI
548
+ cudnnSetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
549
+ unsigned attnMode,
550
+ int nHeads,
551
+ double smScaler,
552
+ cudnnDataType_t dataType,
553
+ cudnnDataType_t computePrec,
554
+ cudnnMathType_t mathType,
555
+ cudnnDropoutDescriptor_t attnDropoutDesc,
556
+ cudnnDropoutDescriptor_t postDropoutDesc,
557
+ int qSize,
558
+ int kSize,
559
+ int vSize,
560
+ int qProjSize,
561
+ int kProjSize,
562
+ int vProjSize,
563
+ int oProjSize,
564
+ int qoMaxSeqLength,
565
+ int kvMaxSeqLength,
566
+ int maxBatchSize,
567
+ int maxBeamSize);
568
+
569
+ cudnnStatus_t CUDNNWINAPI
570
+ cudnnGetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
571
+ unsigned *attnMode,
572
+ int *nHeads,
573
+ double *smScaler,
574
+ cudnnDataType_t *dataType,
575
+ cudnnDataType_t *computePrec,
576
+ cudnnMathType_t *mathType,
577
+ cudnnDropoutDescriptor_t *attnDropoutDesc,
578
+ cudnnDropoutDescriptor_t *postDropoutDesc,
579
+ int *qSize,
580
+ int *kSize,
581
+ int *vSize,
582
+ int *qProjSize,
583
+ int *kProjSize,
584
+ int *vProjSize,
585
+ int *oProjSize,
586
+ int *qoMaxSeqLength,
587
+ int *kvMaxSeqLength,
588
+ int *maxBatchSize,
589
+ int *maxBeamSize);
590
+
591
+ cudnnStatus_t CUDNNWINAPI
592
+ cudnnGetMultiHeadAttnBuffers(cudnnHandle_t handle,
593
+ const cudnnAttnDescriptor_t attnDesc,
594
+ size_t *weightSizeInBytes,
595
+ size_t *workSpaceSizeInBytes,
596
+ size_t *reserveSpaceSizeInBytes);
597
+
598
+ typedef enum {
599
+ CUDNN_MH_ATTN_Q_WEIGHTS = 0, /* input projection weights for 'queries' */
600
+ CUDNN_MH_ATTN_K_WEIGHTS = 1, /* input projection weights for 'keys' */
601
+ CUDNN_MH_ATTN_V_WEIGHTS = 2, /* input projection weights for 'values' */
602
+ CUDNN_MH_ATTN_O_WEIGHTS = 3, /* output projection weights */
603
+ CUDNN_MH_ATTN_Q_BIASES = 4, /* input projection bias tensor for 'queries' */
604
+ CUDNN_MH_ATTN_K_BIASES = 5, /* input projection bias for 'keys' */
605
+ CUDNN_MH_ATTN_V_BIASES = 6, /* input projection bias for 'values' */
606
+ CUDNN_MH_ATTN_O_BIASES = 7, /* output projection biases */
607
+ } cudnnMultiHeadAttnWeightKind_t;
608
+
609
+ #define CUDNN_ATTN_WKIND_COUNT 8 /* Number of attention weight/bias tensors */
610
+
611
+ cudnnStatus_t CUDNNWINAPI
612
+ cudnnGetMultiHeadAttnWeights(cudnnHandle_t handle,
613
+ const cudnnAttnDescriptor_t attnDesc,
614
+ cudnnMultiHeadAttnWeightKind_t wKind,
615
+ size_t weightSizeInBytes,
616
+ const void *weights,
617
+ cudnnTensorDescriptor_t wDesc,
618
+ void **wAddr);
619
+
620
+ cudnnStatus_t CUDNNWINAPI
621
+ cudnnMultiHeadAttnForward(cudnnHandle_t handle,
622
+ const cudnnAttnDescriptor_t attnDesc,
623
+ int currIdx,
624
+ const int loWinIdx[],
625
+ const int hiWinIdx[],
626
+ const int devSeqLengthsQO[],
627
+ const int devSeqLengthsKV[],
628
+ const cudnnSeqDataDescriptor_t qDesc,
629
+ const void *queries,
630
+ const void *residuals,
631
+ const cudnnSeqDataDescriptor_t kDesc,
632
+ const void *keys,
633
+ const cudnnSeqDataDescriptor_t vDesc,
634
+ const void *values,
635
+ const cudnnSeqDataDescriptor_t oDesc,
636
+ void *out,
637
+ size_t weightSizeInBytes,
638
+ const void *weights,
639
+ size_t workSpaceSizeInBytes,
640
+ void *workSpace,
641
+ size_t reserveSpaceSizeInBytes,
642
+ void *reserveSpace);
643
+
644
+ /*
645
+ * \brief Cross-library version checker.
646
+ * This function is implemented differently in each sub-library. Each sublib
647
+ * checks whether its own version matches that of its dependencies.
648
+ * \returns CUDNN_STATUS_SUCCESS if the version check passes,
649
+ * CUDNN_STATUS_VERSION_MISMATCH if the versions are inconsistent.
650
+ */
651
+ cudnnStatus_t CUDNNWINAPI
652
+ cudnnAdvInferVersionCheck(void);
653
+
654
+ #if defined(__cplusplus)
655
+ }
656
+ #endif
657
+
658
+ #endif /* CUDNN_ADV_INFER_H_ */
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_adv_train_v8.h ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2017-2022 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /* cudnn_adv_train : cuDNN's advanced and experimental features.
51
+
52
+ */
53
+
54
+ #if !defined(CUDNN_ADV_TRAIN_H_)
55
+ #define CUDNN_ADV_TRAIN_H_
56
+
57
+ #include <cuda_runtime.h>
58
+ #include <stdint.h>
59
+
60
+ #include "cudnn_version.h"
61
+ #include "cudnn_ops_infer.h"
62
+ #include "cudnn_ops_train.h"
63
+ #include "cudnn_adv_infer.h"
64
+
65
+ /* These version numbers are autogenerated, do not edit manually. */
66
+ #define CUDNN_ADV_TRAIN_MAJOR 8
67
+ #define CUDNN_ADV_TRAIN_MINOR 7
68
+ #define CUDNN_ADV_TRAIN_PATCH 0
69
+
70
+ #if (CUDNN_ADV_TRAIN_MAJOR != CUDNN_MAJOR) || (CUDNN_ADV_TRAIN_MINOR != CUDNN_MINOR) || \
71
+ (CUDNN_ADV_TRAIN_PATCH != CUDNN_PATCHLEVEL)
72
+ #error Version mismatch in cuDNN ADV TRAIN!!!
73
+ #endif
74
+
75
+ #if defined(__cplusplus)
76
+ extern "C" {
77
+ #endif
78
+
79
+ typedef enum {
80
+ CUDNN_WGRAD_MODE_ADD = 0, /* add partial gradients to wgrad output buffers */
81
+ CUDNN_WGRAD_MODE_SET = 1, /* write partial gradients to wgrad output buffers */
82
+ } cudnnWgradMode_t;
83
+
84
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
85
+ cudnnRNNForwardTraining(cudnnHandle_t handle,
86
+ const cudnnRNNDescriptor_t rnnDesc,
87
+ const int seqLength,
88
+ const cudnnTensorDescriptor_t *xDesc,
89
+ const void *x,
90
+ const cudnnTensorDescriptor_t hxDesc,
91
+ const void *hx,
92
+ const cudnnTensorDescriptor_t cxDesc,
93
+ const void *cx,
94
+ const cudnnFilterDescriptor_t wDesc,
95
+ const void *w,
96
+ const cudnnTensorDescriptor_t *yDesc,
97
+ void *y,
98
+ const cudnnTensorDescriptor_t hyDesc,
99
+ void *hy,
100
+ const cudnnTensorDescriptor_t cyDesc,
101
+ void *cy,
102
+ void *workSpace,
103
+ size_t workSpaceSizeInBytes,
104
+ void *reserveSpace,
105
+ size_t reserveSpaceSizeInBytes);
106
+
107
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
108
+ cudnnRNNBackwardData(cudnnHandle_t handle,
109
+ const cudnnRNNDescriptor_t rnnDesc,
110
+ const int seqLength,
111
+ const cudnnTensorDescriptor_t *yDesc,
112
+ const void *y,
113
+ const cudnnTensorDescriptor_t *dyDesc,
114
+ const void *dy,
115
+ const cudnnTensorDescriptor_t dhyDesc,
116
+ const void *dhy,
117
+ const cudnnTensorDescriptor_t dcyDesc,
118
+ const void *dcy,
119
+ const cudnnFilterDescriptor_t wDesc,
120
+ const void *w,
121
+ const cudnnTensorDescriptor_t hxDesc,
122
+ const void *hx,
123
+ const cudnnTensorDescriptor_t cxDesc,
124
+ const void *cx,
125
+ const cudnnTensorDescriptor_t *dxDesc,
126
+ void *dx,
127
+ const cudnnTensorDescriptor_t dhxDesc,
128
+ void *dhx,
129
+ const cudnnTensorDescriptor_t dcxDesc,
130
+ void *dcx,
131
+ void *workSpace,
132
+ size_t workSpaceSizeInBytes,
133
+ void *reserveSpace,
134
+ size_t reserveSpaceSizeInBytes);
135
+
136
+ cudnnStatus_t CUDNNWINAPI
137
+ cudnnRNNBackwardData_v8(cudnnHandle_t handle,
138
+ cudnnRNNDescriptor_t rnnDesc,
139
+ const int32_t devSeqLengths[],
140
+ cudnnRNNDataDescriptor_t yDesc,
141
+ const void *y,
142
+ const void *dy,
143
+ cudnnRNNDataDescriptor_t xDesc,
144
+ void *dx,
145
+ cudnnTensorDescriptor_t hDesc,
146
+ const void *hx,
147
+ const void *dhy,
148
+ void *dhx,
149
+ cudnnTensorDescriptor_t cDesc,
150
+ const void *cx,
151
+ const void *dcy,
152
+ void *dcx,
153
+ size_t weightSpaceSize,
154
+ const void *weightSpace,
155
+ size_t workSpaceSize,
156
+ void *workSpace,
157
+ size_t reserveSpaceSize,
158
+ void *reserveSpace);
159
+
160
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
161
+ cudnnRNNBackwardWeights(cudnnHandle_t handle,
162
+ const cudnnRNNDescriptor_t rnnDesc,
163
+ const int seqLength,
164
+ const cudnnTensorDescriptor_t *xDesc,
165
+ const void *x,
166
+ const cudnnTensorDescriptor_t hxDesc,
167
+ const void *hx,
168
+ const cudnnTensorDescriptor_t *yDesc,
169
+ const void *y,
170
+ const void *workSpace,
171
+ size_t workSpaceSizeInBytes,
172
+ const cudnnFilterDescriptor_t dwDesc,
173
+ void *dw,
174
+ const void *reserveSpace,
175
+ size_t reserveSpaceSizeInBytes);
176
+
177
+ cudnnStatus_t CUDNNWINAPI
178
+ cudnnRNNBackwardWeights_v8(cudnnHandle_t handle,
179
+ cudnnRNNDescriptor_t rnnDesc,
180
+ cudnnWgradMode_t addGrad,
181
+ const int32_t devSeqLengths[],
182
+ cudnnRNNDataDescriptor_t xDesc,
183
+ const void *x,
184
+ cudnnTensorDescriptor_t hDesc,
185
+ const void *hx,
186
+ cudnnRNNDataDescriptor_t yDesc,
187
+ const void *y,
188
+ size_t weightSpaceSize,
189
+ void *dweightSpace,
190
+ size_t workSpaceSize,
191
+ void *workSpace,
192
+ size_t reserveSpaceSize,
193
+ void *reserveSpace);
194
+
195
+ /* RNN EX API */
196
+
197
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
198
+ cudnnRNNForwardTrainingEx(cudnnHandle_t handle,
199
+ const cudnnRNNDescriptor_t rnnDesc,
200
+ const cudnnRNNDataDescriptor_t xDesc,
201
+ const void *x,
202
+ const cudnnTensorDescriptor_t hxDesc,
203
+ const void *hx,
204
+ const cudnnTensorDescriptor_t cxDesc,
205
+ const void *cx,
206
+ const cudnnFilterDescriptor_t wDesc,
207
+ const void *w,
208
+ const cudnnRNNDataDescriptor_t yDesc,
209
+ void *y,
210
+ const cudnnTensorDescriptor_t hyDesc,
211
+ void *hy,
212
+ const cudnnTensorDescriptor_t cyDesc,
213
+ void *cy,
214
+ const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */
215
+ const void *keys, /* reserved, should pass NULL */
216
+ const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */
217
+ void *cAttn, /* reserved, should pass NULL */
218
+ const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */
219
+ void *iAttn, /* reserved, should pass NULL */
220
+ const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */
221
+ void *queries, /* reserved, should pass NULL */
222
+ void *workSpace,
223
+ size_t workSpaceSizeInBytes,
224
+ void *reserveSpace,
225
+ size_t reserveSpaceSizeInBytes);
226
+
227
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
228
+ cudnnRNNBackwardDataEx(cudnnHandle_t handle,
229
+ const cudnnRNNDescriptor_t rnnDesc,
230
+ const cudnnRNNDataDescriptor_t yDesc,
231
+ const void *y,
232
+ const cudnnRNNDataDescriptor_t dyDesc,
233
+ const void *dy,
234
+ const cudnnRNNDataDescriptor_t dcDesc, /* reserved, should pass NULL */
235
+ const void *dcAttn, /* reserved, should pass NULL */
236
+ const cudnnTensorDescriptor_t dhyDesc,
237
+ const void *dhy,
238
+ const cudnnTensorDescriptor_t dcyDesc,
239
+ const void *dcy,
240
+ const cudnnFilterDescriptor_t wDesc,
241
+ const void *w,
242
+ const cudnnTensorDescriptor_t hxDesc,
243
+ const void *hx,
244
+ const cudnnTensorDescriptor_t cxDesc,
245
+ const void *cx,
246
+ const cudnnRNNDataDescriptor_t dxDesc,
247
+ void *dx,
248
+ const cudnnTensorDescriptor_t dhxDesc,
249
+ void *dhx,
250
+ const cudnnTensorDescriptor_t dcxDesc,
251
+ void *dcx,
252
+ const cudnnRNNDataDescriptor_t dkDesc, /* reserved, should pass NULL */
253
+ void *dkeys, /* reserved, should pass NULL */
254
+ void *workSpace,
255
+ size_t workSpaceSizeInBytes,
256
+ void *reserveSpace,
257
+ size_t reserveSpaceSizeInBytes);
258
+
259
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
260
+ cudnnRNNBackwardWeightsEx(cudnnHandle_t handle,
261
+ const cudnnRNNDescriptor_t rnnDesc,
262
+ const cudnnRNNDataDescriptor_t xDesc,
263
+ const void *x,
264
+ const cudnnTensorDescriptor_t hxDesc,
265
+ const void *hx,
266
+ const cudnnRNNDataDescriptor_t yDesc,
267
+ const void *y,
268
+ void *workSpace,
269
+ size_t workSpaceSizeInBytes,
270
+ const cudnnFilterDescriptor_t dwDesc,
271
+ void *dw,
272
+ void *reserveSpace,
273
+ size_t reserveSpaceSizeInBytes);
274
+
275
+ /* RNN FIND API */
276
+
277
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
278
+ cudnnGetRNNForwardTrainingAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count);
279
+
280
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
281
+ cudnnFindRNNForwardTrainingAlgorithmEx(cudnnHandle_t handle,
282
+ const cudnnRNNDescriptor_t rnnDesc,
283
+ const int seqLength,
284
+ const cudnnTensorDescriptor_t *xDesc,
285
+ const void *x,
286
+ const cudnnTensorDescriptor_t hxDesc,
287
+ const void *hx,
288
+ const cudnnTensorDescriptor_t cxDesc,
289
+ const void *cx,
290
+ const cudnnFilterDescriptor_t wDesc,
291
+ const void *w,
292
+ const cudnnTensorDescriptor_t *yDesc,
293
+ void *y,
294
+ const cudnnTensorDescriptor_t hyDesc,
295
+ void *hy,
296
+ const cudnnTensorDescriptor_t cyDesc,
297
+ void *cy,
298
+ const float findIntensity,
299
+ const int requestedAlgoCount,
300
+ int *returnedAlgoCount,
301
+ cudnnAlgorithmPerformance_t *perfResults,
302
+ void *workspace,
303
+ size_t workSpaceSizeInBytes,
304
+ void *reserveSpace,
305
+ size_t reserveSpaceSizeInBytes);
306
+
307
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
308
+ cudnnGetRNNBackwardDataAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count);
309
+
310
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
311
+ cudnnFindRNNBackwardDataAlgorithmEx(cudnnHandle_t handle,
312
+ const cudnnRNNDescriptor_t rnnDesc,
313
+ const int seqLength,
314
+ const cudnnTensorDescriptor_t *yDesc,
315
+ const void *y,
316
+ const cudnnTensorDescriptor_t *dyDesc,
317
+ const void *dy,
318
+ const cudnnTensorDescriptor_t dhyDesc,
319
+ const void *dhy,
320
+ const cudnnTensorDescriptor_t dcyDesc,
321
+ const void *dcy,
322
+ const cudnnFilterDescriptor_t wDesc,
323
+ const void *w,
324
+ const cudnnTensorDescriptor_t hxDesc,
325
+ const void *hx,
326
+ const cudnnTensorDescriptor_t cxDesc,
327
+ const void *cx,
328
+ const cudnnTensorDescriptor_t *dxDesc,
329
+ void *dx,
330
+ const cudnnTensorDescriptor_t dhxDesc,
331
+ void *dhx,
332
+ const cudnnTensorDescriptor_t dcxDesc,
333
+ void *dcx,
334
+ const float findIntensity,
335
+ const int requestedAlgoCount,
336
+ int *returnedAlgoCount,
337
+ cudnnAlgorithmPerformance_t *perfResults,
338
+ void *workspace,
339
+ size_t workSpaceSizeInBytes,
340
+ void *reserveSpace,
341
+ size_t reserveSpaceSizeInBytes);
342
+
343
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
344
+ cudnnGetRNNBackwardWeightsAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count);
345
+
346
+ CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
347
+ cudnnFindRNNBackwardWeightsAlgorithmEx(cudnnHandle_t handle,
348
+ const cudnnRNNDescriptor_t rnnDesc,
349
+ const int seqLength,
350
+ const cudnnTensorDescriptor_t *xDesc,
351
+ const void *x,
352
+ const cudnnTensorDescriptor_t hxDesc,
353
+ const void *hx,
354
+ const cudnnTensorDescriptor_t *yDesc,
355
+ const void *y,
356
+ const float findIntensity,
357
+ const int requestedAlgoCount,
358
+ int *returnedAlgoCount,
359
+ cudnnAlgorithmPerformance_t *perfResults,
360
+ const void *workspace,
361
+ size_t workSpaceSizeInBytes,
362
+ const cudnnFilterDescriptor_t dwDesc,
363
+ void *dw,
364
+ const void *reserveSpace,
365
+ size_t reserveSpaceSizeInBytes);
366
+
367
+ cudnnStatus_t CUDNNWINAPI
368
+ cudnnMultiHeadAttnBackwardData(cudnnHandle_t handle,
369
+ const cudnnAttnDescriptor_t attnDesc,
370
+ const int loWinIdx[],
371
+ const int hiWinIdx[],
372
+ const int devSeqLengthsDQDO[],
373
+ const int devSeqLengthsDKDV[],
374
+ const cudnnSeqDataDescriptor_t doDesc,
375
+ const void *dout,
376
+ const cudnnSeqDataDescriptor_t dqDesc,
377
+ void *dqueries,
378
+ const void *queries,
379
+ const cudnnSeqDataDescriptor_t dkDesc,
380
+ void *dkeys,
381
+ const void *keys,
382
+ const cudnnSeqDataDescriptor_t dvDesc,
383
+ void *dvalues,
384
+ const void *values,
385
+ size_t weightSizeInBytes,
386
+ const void *weights,
387
+ size_t workSpaceSizeInBytes,
388
+ void *workSpace,
389
+ size_t reserveSpaceSizeInBytes,
390
+ void *reserveSpace);
391
+
392
+ cudnnStatus_t CUDNNWINAPI
393
+ cudnnMultiHeadAttnBackwardWeights(cudnnHandle_t handle,
394
+ const cudnnAttnDescriptor_t attnDesc,
395
+ cudnnWgradMode_t addGrad,
396
+ const cudnnSeqDataDescriptor_t qDesc,
397
+ const void *queries,
398
+ const cudnnSeqDataDescriptor_t kDesc,
399
+ const void *keys,
400
+ const cudnnSeqDataDescriptor_t vDesc,
401
+ const void *values,
402
+ const cudnnSeqDataDescriptor_t doDesc,
403
+ const void *dout,
404
+ size_t weightSizeInBytes,
405
+ const void *weights,
406
+ void *dweights,
407
+ size_t workSpaceSizeInBytes,
408
+ void *workSpace,
409
+ size_t reserveSpaceSizeInBytes,
410
+ void *reserveSpace);
411
+
412
+ /*
413
+ * CTC (Connectionist Temporal Classification) loss descriptor create/destory/set/get functions
414
+ */
415
+ /* Input normalization mode for loss function */
416
+ typedef enum {
417
+ CUDNN_LOSS_NORMALIZATION_NONE = 0,
418
+ CUDNN_LOSS_NORMALIZATION_SOFTMAX = 1,
419
+ } cudnnLossNormalizationMode_t;
420
+
421
+ cudnnStatus_t CUDNNWINAPI
422
+ cudnnCreateCTCLossDescriptor(cudnnCTCLossDescriptor_t *ctcLossDesc);
423
+
424
+ cudnnStatus_t CUDNNWINAPI
425
+ cudnnSetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType);
426
+
427
+ cudnnStatus_t CUDNNWINAPI
428
+ cudnnSetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc,
429
+ cudnnDataType_t compType,
430
+ cudnnLossNormalizationMode_t normMode,
431
+ cudnnNanPropagation_t gradMode);
432
+
433
+ cudnnStatus_t CUDNNWINAPI
434
+ cudnnSetCTCLossDescriptor_v8(cudnnCTCLossDescriptor_t ctcLossDesc,
435
+ cudnnDataType_t compType,
436
+ cudnnLossNormalizationMode_t normMode,
437
+ cudnnNanPropagation_t gradMode,
438
+ int maxLabelLength);
439
+
440
+ cudnnStatus_t CUDNNWINAPI
441
+ cudnnGetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType);
442
+
443
+ cudnnStatus_t CUDNNWINAPI
444
+ cudnnGetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc,
445
+ cudnnDataType_t *compType,
446
+ cudnnLossNormalizationMode_t *normMode,
447
+ cudnnNanPropagation_t *gradMode);
448
+
449
+ cudnnStatus_t CUDNNWINAPI
450
+ cudnnGetCTCLossDescriptor_v8(cudnnCTCLossDescriptor_t ctcLossDesc,
451
+ cudnnDataType_t *compType,
452
+ cudnnLossNormalizationMode_t *normMode,
453
+ cudnnNanPropagation_t *gradMode,
454
+ int *maxLabelLength);
455
+
456
+ cudnnStatus_t CUDNNWINAPI
457
+ cudnnDestroyCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc);
458
+
459
+ /* return the ctc costs and gradients, given the probabilities and labels */
460
+ cudnnStatus_t CUDNNWINAPI
461
+ cudnnCTCLoss(
462
+ cudnnHandle_t handle,
463
+ const cudnnTensorDescriptor_t
464
+ probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the
465
+ mini batch size, A is the alphabet size) */
466
+ const void *probs, /* probabilities after softmax, in GPU memory */
467
+ const int hostLabels[], /* labels, in CPU memory */
468
+ const int hostLabelLengths[], /* the length of each label, in CPU memory */
469
+ const int hostInputLengths[], /* the lengths of timing steps in each batch, in CPU memory */
470
+ void *costs, /* the returned costs of CTC, in GPU memory */
471
+ const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */
472
+ void *gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */
473
+ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
474
+ cudnnCTCLossDescriptor_t ctcLossDesc,
475
+ void *workspace, /* pointer to the workspace, in GPU memory */
476
+ size_t workSpaceSizeInBytes); /* size of the workspace */
477
+
478
+ /* return the ctc costs and gradients, given the probabilities and labels */
479
+ cudnnStatus_t CUDNNWINAPI
480
+ cudnnCTCLoss_v8(
481
+ cudnnHandle_t handle,
482
+ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
483
+ cudnnCTCLossDescriptor_t ctcLossDesc,
484
+ const cudnnTensorDescriptor_t
485
+ probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the
486
+ mini batch size, A is the alphabet size) */
487
+ const void *probs, /* probabilities after softmax, in GPU memory */
488
+ const int labels[], /* labels, in GPU memory */
489
+ const int labelLengths[], /* the length of each label, in GPU memory */
490
+ const int inputLengths[], /* the lengths of timing steps in each batch, in GPU memory */
491
+ void *costs, /* the returned costs of CTC, in GPU memory */
492
+ const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */
493
+ void *gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */
494
+ size_t workSpaceSizeInBytes, /* size of the workspace */
495
+ void *workspace); /* pointer to the workspace, in GPU memory */
496
+
497
+ /* return the workspace size needed for ctc */
498
+ cudnnStatus_t CUDNNWINAPI
499
+ cudnnGetCTCLossWorkspaceSize(
500
+ cudnnHandle_t handle,
501
+ const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the
502
+ timing steps, N is the mini batch size, A is the alphabet size) */
503
+ const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the
504
+ dimensions are T,N,A. To compute costs
505
+ only, set it to NULL */
506
+ const int *labels, /* labels, in CPU memory */
507
+ const int *labelLengths, /* the length of each label, in CPU memory */
508
+ const int *inputLengths, /* the lengths of timing steps in each batch, in CPU memory */
509
+ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
510
+ cudnnCTCLossDescriptor_t ctcLossDesc,
511
+ size_t *sizeInBytes); /* pointer to the returned workspace size */
512
+
513
+ /* return the workspace size needed for ctc */
514
+ cudnnStatus_t CUDNNWINAPI
515
+ cudnnGetCTCLossWorkspaceSize_v8(
516
+ cudnnHandle_t handle,
517
+ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
518
+ cudnnCTCLossDescriptor_t ctcLossDesc,
519
+ const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the
520
+ timing steps, N is the mini batch size, A is the alphabet size) */
521
+ const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the
522
+ dimensions are T,N,A. To compute costs
523
+ only, set it to NULL */
524
+ size_t *sizeInBytes); /* pointer to the returned workspace size */
525
+
526
+ /*
527
+ * \brief Cross-library version checker.
528
+ * This function is implemented differently in each sub-library. Each sublib
529
+ * checks whether its own version matches that of its dependencies.
530
+ * \returns CUDNN_STATUS_SUCCESS if the version check passes,
531
+ * CUDNN_STATUS_VERSION_MISMATCH if the versions are inconsistent.
532
+ */
533
+ cudnnStatus_t CUDNNWINAPI
534
+ cudnnAdvTrainVersionCheck(void);
535
+
536
+ #if defined(__cplusplus)
537
+ }
538
+ #endif
539
+
540
+ #endif /* CUDNN_ADV_TRAIN_H_ */
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_backend_v8.h ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2017-2022 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ #ifndef _CUDNN_BACKEND_H_
51
+ #define _CUDNN_BACKEND_H_
52
+
53
+ /*
54
+ * The content in this header file is under development to be included in cudnn.h in the future
55
+ * Production code should have all include of this header file remove.
56
+ */
57
+
58
+ #include "cudnn_ops_infer.h"
59
+ #include "cudnn_cnn_infer.h"
60
+
61
+ /* NOTE: definition in extern "C" to be copied later to public header */
62
+ #if defined(__cplusplus)
63
+ extern "C" {
64
+ #endif
65
+
66
+ typedef void *cudnnBackendDescriptor_t;
67
+
68
+ typedef struct cudnnFractionStruct {
69
+ int64_t numerator;
70
+ int64_t denominator;
71
+ } cudnnFraction_t;
72
+
73
+ typedef enum {
74
+ CUDNN_POINTWISE_ADD = 0,
75
+ CUDNN_POINTWISE_ADD_SQUARE = 5,
76
+ CUDNN_POINTWISE_DIV = 6,
77
+ CUDNN_POINTWISE_MAX = 3,
78
+ CUDNN_POINTWISE_MIN = 2,
79
+ CUDNN_POINTWISE_MOD = 7,
80
+ CUDNN_POINTWISE_MUL = 1,
81
+ CUDNN_POINTWISE_POW = 8,
82
+ CUDNN_POINTWISE_SUB = 9,
83
+
84
+ CUDNN_POINTWISE_ABS = 10,
85
+ CUDNN_POINTWISE_CEIL = 11,
86
+ CUDNN_POINTWISE_COS = 12,
87
+ CUDNN_POINTWISE_EXP = 13,
88
+ CUDNN_POINTWISE_FLOOR = 14,
89
+ CUDNN_POINTWISE_LOG = 15,
90
+ CUDNN_POINTWISE_NEG = 16,
91
+ CUDNN_POINTWISE_RSQRT = 17,
92
+ CUDNN_POINTWISE_SIN = 18,
93
+ CUDNN_POINTWISE_SQRT = 4,
94
+ CUDNN_POINTWISE_TAN = 19,
95
+ CUDNN_POINTWISE_ERF = 20,
96
+ CUDNN_POINTWISE_IDENTITY = 21,
97
+
98
+ CUDNN_POINTWISE_RELU_FWD = 100,
99
+ CUDNN_POINTWISE_TANH_FWD = 101,
100
+ CUDNN_POINTWISE_SIGMOID_FWD = 102,
101
+ CUDNN_POINTWISE_ELU_FWD = 103,
102
+ CUDNN_POINTWISE_GELU_FWD = 104,
103
+ CUDNN_POINTWISE_SOFTPLUS_FWD = 105,
104
+ CUDNN_POINTWISE_SWISH_FWD = 106,
105
+ CUDNN_POINTWISE_GELU_APPROX_TANH_FWD = 107,
106
+
107
+ CUDNN_POINTWISE_RELU_BWD = 200,
108
+ CUDNN_POINTWISE_TANH_BWD = 201,
109
+ CUDNN_POINTWISE_SIGMOID_BWD = 202,
110
+ CUDNN_POINTWISE_ELU_BWD = 203,
111
+ CUDNN_POINTWISE_GELU_BWD = 204,
112
+ CUDNN_POINTWISE_SOFTPLUS_BWD = 205,
113
+ CUDNN_POINTWISE_SWISH_BWD = 206,
114
+ CUDNN_POINTWISE_GELU_APPROX_TANH_BWD = 207,
115
+
116
+ CUDNN_POINTWISE_CMP_EQ = 300,
117
+ CUDNN_POINTWISE_CMP_NEQ = 301,
118
+ CUDNN_POINTWISE_CMP_GT = 302,
119
+ CUDNN_POINTWISE_CMP_GE = 303,
120
+ CUDNN_POINTWISE_CMP_LT = 304,
121
+ CUDNN_POINTWISE_CMP_LE = 305,
122
+
123
+ CUDNN_POINTWISE_LOGICAL_AND = 400,
124
+ CUDNN_POINTWISE_LOGICAL_OR = 401,
125
+ CUDNN_POINTWISE_LOGICAL_NOT = 402,
126
+
127
+ CUDNN_POINTWISE_GEN_INDEX = 501,
128
+
129
+ CUDNN_POINTWISE_BINARY_SELECT = 601,
130
+ } cudnnPointwiseMode_t;
131
+
132
+ typedef enum {
133
+ CUDNN_RESAMPLE_NEAREST = 0,
134
+ CUDNN_RESAMPLE_BILINEAR = 1,
135
+ CUDNN_RESAMPLE_AVGPOOL = 2,
136
+ CUDNN_RESAMPLE_AVGPOOL_INCLUDE_PADDING = 2,
137
+ CUDNN_RESAMPLE_AVGPOOL_EXCLUDE_PADDING = 4,
138
+ CUDNN_RESAMPLE_MAXPOOL = 3,
139
+ } cudnnResampleMode_t;
140
+
141
+ typedef enum {
142
+ CUDNN_SIGNAL_SET = 0,
143
+ CUDNN_SIGNAL_WAIT = 1,
144
+ } cudnnSignalMode_t;
145
+
146
+ typedef enum {
147
+ CUDNN_GENSTATS_SUM_SQSUM = 0,
148
+ } cudnnGenStatsMode_t;
149
+
150
+ typedef enum {
151
+ CUDNN_BN_FINALIZE_STATISTICS_TRAINING = 0,
152
+ CUDNN_BN_FINALIZE_STATISTICS_INFERENCE = 1,
153
+ } cudnnBnFinalizeStatsMode_t;
154
+
155
+ typedef enum {
156
+ CUDNN_RNG_DISTRIBUTION_BERNOULLI,
157
+ CUDNN_RNG_DISTRIBUTION_UNIFORM,
158
+ CUDNN_RNG_DISTRIBUTION_NORMAL,
159
+ } cudnnRngDistribution_t;
160
+
161
+ typedef enum {
162
+ CUDNN_ATTR_POINTWISE_MODE = 0,
163
+ CUDNN_ATTR_POINTWISE_MATH_PREC = 1,
164
+ CUDNN_ATTR_POINTWISE_NAN_PROPAGATION = 2,
165
+ CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP = 3,
166
+ CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP = 4,
167
+ CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP_SLOPE = 5,
168
+ CUDNN_ATTR_POINTWISE_ELU_ALPHA = 6,
169
+ CUDNN_ATTR_POINTWISE_SOFTPLUS_BETA = 7,
170
+ CUDNN_ATTR_POINTWISE_SWISH_BETA = 8,
171
+ CUDNN_ATTR_POINTWISE_AXIS = 9,
172
+
173
+ CUDNN_ATTR_CONVOLUTION_COMP_TYPE = 100,
174
+ CUDNN_ATTR_CONVOLUTION_CONV_MODE = 101,
175
+ CUDNN_ATTR_CONVOLUTION_DILATIONS = 102,
176
+ CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES = 103,
177
+ CUDNN_ATTR_CONVOLUTION_POST_PADDINGS = 104,
178
+ CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS = 105,
179
+ CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS = 106,
180
+
181
+ CUDNN_ATTR_ENGINEHEUR_MODE = 200,
182
+ CUDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH = 201,
183
+ CUDNN_ATTR_ENGINEHEUR_RESULTS = 202,
184
+
185
+ CUDNN_ATTR_ENGINECFG_ENGINE = 300,
186
+ CUDNN_ATTR_ENGINECFG_INTERMEDIATE_INFO = 301,
187
+ CUDNN_ATTR_ENGINECFG_KNOB_CHOICES = 302,
188
+
189
+ CUDNN_ATTR_EXECUTION_PLAN_HANDLE = 400,
190
+ CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG = 401,
191
+ CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE = 402,
192
+ CUDNN_ATTR_EXECUTION_PLAN_COMPUTED_INTERMEDIATE_UIDS = 403,
193
+ CUDNN_ATTR_EXECUTION_PLAN_RUN_ONLY_INTERMEDIATE_UIDS = 404,
194
+ CUDNN_ATTR_EXECUTION_PLAN_JSON_REPRESENTATION = 405,
195
+
196
+ CUDNN_ATTR_INTERMEDIATE_INFO_UNIQUE_ID = 500,
197
+ CUDNN_ATTR_INTERMEDIATE_INFO_SIZE = 501,
198
+ CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_DATA_UIDS = 502,
199
+ CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_ATTRIBUTES = 503,
200
+
201
+ CUDNN_ATTR_KNOB_CHOICE_KNOB_TYPE = 600,
202
+ CUDNN_ATTR_KNOB_CHOICE_KNOB_VALUE = 601,
203
+
204
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA = 700,
205
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA = 701,
206
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC = 702,
207
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W = 703,
208
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X = 704,
209
+ CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y = 705,
210
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA = 706,
211
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA = 707,
212
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC = 708,
213
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W = 709,
214
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX = 710,
215
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY = 711,
216
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA = 712,
217
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA = 713,
218
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC = 714,
219
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW = 715,
220
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X = 716,
221
+ CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY = 717,
222
+
223
+ CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR = 750,
224
+ CUDNN_ATTR_OPERATION_POINTWISE_XDESC = 751,
225
+ CUDNN_ATTR_OPERATION_POINTWISE_BDESC = 752,
226
+ CUDNN_ATTR_OPERATION_POINTWISE_YDESC = 753,
227
+ CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1 = 754,
228
+ CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2 = 755,
229
+ CUDNN_ATTR_OPERATION_POINTWISE_DXDESC = 756,
230
+ CUDNN_ATTR_OPERATION_POINTWISE_DYDESC = 757,
231
+ CUDNN_ATTR_OPERATION_POINTWISE_TDESC = 758,
232
+
233
+ CUDNN_ATTR_OPERATION_GENSTATS_MODE = 770,
234
+ CUDNN_ATTR_OPERATION_GENSTATS_MATH_PREC = 771,
235
+ CUDNN_ATTR_OPERATION_GENSTATS_XDESC = 772,
236
+ CUDNN_ATTR_OPERATION_GENSTATS_SUMDESC = 773,
237
+ CUDNN_ATTR_OPERATION_GENSTATS_SQSUMDESC = 774,
238
+
239
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_STATS_MODE = 780,
240
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_MATH_PREC = 781,
241
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SUM_DESC = 782,
242
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SQ_SUM_DESC = 783,
243
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_SCALE_DESC = 784,
244
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_BIAS_DESC = 785,
245
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_MEAN_DESC = 786,
246
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_VAR_DESC = 787,
247
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_MEAN_DESC = 788,
248
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_VAR_DESC = 789,
249
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_MEAN_DESC = 790,
250
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_INV_STD_DESC = 791,
251
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_SCALE_DESC = 792,
252
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_BIAS_DESC = 793,
253
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_ACCUM_COUNT_DESC = 794,
254
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_EPSILON_DESC = 795,
255
+ CUDNN_ATTR_OPERATION_BN_FINALIZE_EXP_AVERATE_FACTOR_DESC = 796,
256
+
257
+ CUDNN_ATTR_OPERATIONGRAPH_HANDLE = 800,
258
+ CUDNN_ATTR_OPERATIONGRAPH_OPS = 801,
259
+ CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT = 802,
260
+
261
+ CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT = 900,
262
+ CUDNN_ATTR_TENSOR_DATA_TYPE = 901,
263
+ CUDNN_ATTR_TENSOR_DIMENSIONS = 902,
264
+ CUDNN_ATTR_TENSOR_STRIDES = 903,
265
+ CUDNN_ATTR_TENSOR_VECTOR_COUNT = 904,
266
+ CUDNN_ATTR_TENSOR_VECTORIZED_DIMENSION = 905,
267
+ CUDNN_ATTR_TENSOR_UNIQUE_ID = 906,
268
+ CUDNN_ATTR_TENSOR_IS_VIRTUAL = 907,
269
+ CUDNN_ATTR_TENSOR_IS_BY_VALUE = 908,
270
+ CUDNN_ATTR_TENSOR_REORDERING_MODE = 909,
271
+
272
+ CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS = 1000,
273
+ CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS = 1001,
274
+ CUDNN_ATTR_VARIANT_PACK_INTERMEDIATES = 1002,
275
+ CUDNN_ATTR_VARIANT_PACK_WORKSPACE = 1003,
276
+
277
+ CUDNN_ATTR_LAYOUT_INFO_TENSOR_UID = 1100,
278
+ CUDNN_ATTR_LAYOUT_INFO_TYPES = 1101,
279
+
280
+ CUDNN_ATTR_KNOB_INFO_TYPE = 1200,
281
+ CUDNN_ATTR_KNOB_INFO_MAXIMUM_VALUE = 1201,
282
+ CUDNN_ATTR_KNOB_INFO_MINIMUM_VALUE = 1202,
283
+ CUDNN_ATTR_KNOB_INFO_STRIDE = 1203,
284
+
285
+ CUDNN_ATTR_ENGINE_OPERATION_GRAPH = 1300,
286
+ CUDNN_ATTR_ENGINE_GLOBAL_INDEX = 1301,
287
+ CUDNN_ATTR_ENGINE_KNOB_INFO = 1302,
288
+ CUDNN_ATTR_ENGINE_NUMERICAL_NOTE = 1303,
289
+ CUDNN_ATTR_ENGINE_LAYOUT_INFO = 1304,
290
+ CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE = 1305,
291
+
292
+ CUDNN_ATTR_MATMUL_COMP_TYPE = 1500,
293
+
294
+ CUDNN_ATTR_OPERATION_MATMUL_ADESC = 1520,
295
+ CUDNN_ATTR_OPERATION_MATMUL_BDESC = 1521,
296
+ CUDNN_ATTR_OPERATION_MATMUL_CDESC = 1522,
297
+ CUDNN_ATTR_OPERATION_MATMUL_DESC = 1523,
298
+ CUDNN_ATTR_OPERATION_MATMUL_IRREGULARLY_STRIDED_BATCH_COUNT = 1524,
299
+ CUDNN_ATTR_OPERATION_MATMUL_GEMM_M_OVERRIDE_DESC = 1525,
300
+ CUDNN_ATTR_OPERATION_MATMUL_GEMM_N_OVERRIDE_DESC = 1526,
301
+ CUDNN_ATTR_OPERATION_MATMUL_GEMM_K_OVERRIDE_DESC = 1527,
302
+
303
+ CUDNN_ATTR_REDUCTION_OPERATOR = 1600,
304
+ CUDNN_ATTR_REDUCTION_COMP_TYPE = 1601,
305
+
306
+ CUDNN_ATTR_OPERATION_REDUCTION_XDESC = 1610,
307
+ CUDNN_ATTR_OPERATION_REDUCTION_YDESC = 1611,
308
+ CUDNN_ATTR_OPERATION_REDUCTION_DESC = 1612,
309
+
310
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MATH_PREC = 1620,
311
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MEAN_DESC = 1621,
312
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_INVSTD_DESC = 1622,
313
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_BN_SCALE_DESC = 1623,
314
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_X_DESC = 1624,
315
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DY_DESC = 1625,
316
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_SCALE_DESC = 1626,
317
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_BIAS_DESC = 1627,
318
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_DY_SCALE_DESC = 1628,
319
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_X_SCALE_DESC = 1629,
320
+ CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_BIAS = 1630,
321
+
322
+ CUDNN_ATTR_RESAMPLE_MODE = 1700,
323
+ CUDNN_ATTR_RESAMPLE_COMP_TYPE = 1701,
324
+ CUDNN_ATTR_RESAMPLE_SPATIAL_DIMS = 1702,
325
+ CUDNN_ATTR_RESAMPLE_POST_PADDINGS = 1703,
326
+ CUDNN_ATTR_RESAMPLE_PRE_PADDINGS = 1704,
327
+ CUDNN_ATTR_RESAMPLE_STRIDES = 1705,
328
+ CUDNN_ATTR_RESAMPLE_WINDOW_DIMS = 1706,
329
+ CUDNN_ATTR_RESAMPLE_NAN_PROPAGATION = 1707,
330
+ CUDNN_ATTR_RESAMPLE_PADDING_MODE = 1708,
331
+
332
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_XDESC = 1710,
333
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_YDESC = 1711,
334
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_IDXDESC = 1712,
335
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_ALPHA = 1713,
336
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_BETA = 1714,
337
+ CUDNN_ATTR_OPERATION_RESAMPLE_FWD_DESC = 1716,
338
+
339
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_DXDESC = 1720,
340
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_DYDESC = 1721,
341
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_IDXDESC = 1722,
342
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_ALPHA = 1723,
343
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_BETA = 1724,
344
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_DESC = 1725,
345
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_XDESC = 1726,
346
+ CUDNN_ATTR_OPERATION_RESAMPLE_BWD_YDESC = 1727,
347
+
348
+ CUDNN_ATTR_OPERATION_CONCAT_AXIS = 1800,
349
+ CUDNN_ATTR_OPERATION_CONCAT_INPUT_DESCS = 1801,
350
+ CUDNN_ATTR_OPERATION_CONCAT_INPLACE_INDEX = 1802,
351
+ CUDNN_ATTR_OPERATION_CONCAT_OUTPUT_DESC = 1803,
352
+
353
+ CUDNN_ATTR_OPERATION_SIGNAL_MODE = 1900,
354
+ CUDNN_ATTR_OPERATION_SIGNAL_FLAGDESC = 1901,
355
+ CUDNN_ATTR_OPERATION_SIGNAL_VALUE = 1902,
356
+ CUDNN_ATTR_OPERATION_SIGNAL_XDESC = 1903,
357
+ CUDNN_ATTR_OPERATION_SIGNAL_YDESC = 1904,
358
+
359
+ CUDNN_ATTR_OPERATION_NORM_FWD_MODE = 2000,
360
+ CUDNN_ATTR_OPERATION_NORM_FWD_PHASE = 2001,
361
+ CUDNN_ATTR_OPERATION_NORM_FWD_XDESC = 2002,
362
+ CUDNN_ATTR_OPERATION_NORM_FWD_MEAN_DESC = 2003,
363
+ CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC = 2004,
364
+ CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC = 2005,
365
+ CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC = 2006,
366
+ CUDNN_ATTR_OPERATION_NORM_FWD_EPSILON_DESC = 2007,
367
+ CUDNN_ATTR_OPERATION_NORM_FWD_EXP_AVG_FACTOR_DESC = 2008,
368
+ CUDNN_ATTR_OPERATION_NORM_FWD_INPUT_RUNNING_MEAN_DESC = 2009,
369
+ CUDNN_ATTR_OPERATION_NORM_FWD_INPUT_RUNNING_VAR_DESC = 2010,
370
+ CUDNN_ATTR_OPERATION_NORM_FWD_OUTPUT_RUNNING_MEAN_DESC = 2011,
371
+ CUDNN_ATTR_OPERATION_NORM_FWD_OUTPUT_RUNNING_VAR_DESC = 2012,
372
+ CUDNN_ATTR_OPERATION_NORM_FWD_YDESC = 2013,
373
+ CUDNN_ATTR_OPERATION_NORM_FWD_PEER_STAT_DESCS = 2014,
374
+
375
+ CUDNN_ATTR_OPERATION_NORM_BWD_MODE = 2100,
376
+ CUDNN_ATTR_OPERATION_NORM_BWD_XDESC = 2101,
377
+ CUDNN_ATTR_OPERATION_NORM_BWD_MEAN_DESC = 2102,
378
+ CUDNN_ATTR_OPERATION_NORM_BWD_INV_VARIANCE_DESC = 2103,
379
+ CUDNN_ATTR_OPERATION_NORM_BWD_DYDESC = 2104,
380
+ CUDNN_ATTR_OPERATION_NORM_BWD_SCALE_DESC = 2105,
381
+ CUDNN_ATTR_OPERATION_NORM_BWD_EPSILON_DESC = 2106,
382
+ CUDNN_ATTR_OPERATION_NORM_BWD_DSCALE_DESC = 2107,
383
+ CUDNN_ATTR_OPERATION_NORM_BWD_DBIAS_DESC = 2108,
384
+ CUDNN_ATTR_OPERATION_NORM_BWD_DXDESC = 2109,
385
+ CUDNN_ATTR_OPERATION_NORM_BWD_PEER_STAT_DESCS = 2110,
386
+
387
+ CUDNN_ATTR_OPERATION_RESHAPE_XDESC = 2200,
388
+ CUDNN_ATTR_OPERATION_RESHAPE_YDESC = 2201,
389
+
390
+ CUDNN_ATTR_RNG_DISTRIBUTION = 2300,
391
+ CUDNN_ATTR_RNG_NORMAL_DIST_MEAN = 2301,
392
+ CUDNN_ATTR_RNG_NORMAL_DIST_STANDARD_DEVIATION = 2302,
393
+ CUDNN_ATTR_RNG_UNIFORM_DIST_MAXIMUM = 2303,
394
+ CUDNN_ATTR_RNG_UNIFORM_DIST_MINIMUM = 2304,
395
+ CUDNN_ATTR_RNG_BERNOULLI_DIST_PROBABILITY = 2305,
396
+
397
+ CUDNN_ATTR_OPERATION_RNG_YDESC = 2310,
398
+ CUDNN_ATTR_OPERATION_RNG_SEED = 2311,
399
+ CUDNN_ATTR_OPERATION_RNG_DESC = 2312,
400
+
401
+ } cudnnBackendAttributeName_t;
402
+
403
+ typedef enum {
404
+ CUDNN_TYPE_HANDLE = 0,
405
+ CUDNN_TYPE_DATA_TYPE,
406
+ CUDNN_TYPE_BOOLEAN,
407
+ CUDNN_TYPE_INT64,
408
+ CUDNN_TYPE_FLOAT,
409
+ CUDNN_TYPE_DOUBLE,
410
+ CUDNN_TYPE_VOID_PTR,
411
+ CUDNN_TYPE_CONVOLUTION_MODE,
412
+ CUDNN_TYPE_HEUR_MODE,
413
+ CUDNN_TYPE_KNOB_TYPE,
414
+ CUDNN_TYPE_NAN_PROPOGATION,
415
+ CUDNN_TYPE_NUMERICAL_NOTE,
416
+ CUDNN_TYPE_LAYOUT_TYPE,
417
+ CUDNN_TYPE_ATTRIB_NAME,
418
+ CUDNN_TYPE_POINTWISE_MODE,
419
+ CUDNN_TYPE_BACKEND_DESCRIPTOR,
420
+ CUDNN_TYPE_GENSTATS_MODE,
421
+ CUDNN_TYPE_BN_FINALIZE_STATS_MODE,
422
+ CUDNN_TYPE_REDUCTION_OPERATOR_TYPE,
423
+ CUDNN_TYPE_BEHAVIOR_NOTE,
424
+ CUDNN_TYPE_TENSOR_REORDERING_MODE,
425
+ CUDNN_TYPE_RESAMPLE_MODE,
426
+ CUDNN_TYPE_PADDING_MODE,
427
+ CUDNN_TYPE_INT32,
428
+ CUDNN_TYPE_CHAR,
429
+ CUDNN_TYPE_SIGNAL_MODE,
430
+ CUDNN_TYPE_FRACTION,
431
+ CUDNN_TYPE_NORM_MODE,
432
+ CUDNN_TYPE_NORM_FWD_PHASE,
433
+ CUDNN_TYPE_RNG_DISTRIBUTION
434
+ } cudnnBackendAttributeType_t;
435
+
436
+ typedef enum {
437
+ CUDNN_BACKEND_POINTWISE_DESCRIPTOR = 0,
438
+ CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR,
439
+ CUDNN_BACKEND_ENGINE_DESCRIPTOR,
440
+ CUDNN_BACKEND_ENGINECFG_DESCRIPTOR,
441
+ CUDNN_BACKEND_ENGINEHEUR_DESCRIPTOR,
442
+ CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR,
443
+ CUDNN_BACKEND_INTERMEDIATE_INFO_DESCRIPTOR,
444
+ CUDNN_BACKEND_KNOB_CHOICE_DESCRIPTOR,
445
+ CUDNN_BACKEND_KNOB_INFO_DESCRIPTOR,
446
+ CUDNN_BACKEND_LAYOUT_INFO_DESCRIPTOR,
447
+ CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
448
+ CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR,
449
+ CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR,
450
+ CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR,
451
+ CUDNN_BACKEND_OPERATION_GEN_STATS_DESCRIPTOR,
452
+ CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR,
453
+ CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR,
454
+ CUDNN_BACKEND_TENSOR_DESCRIPTOR,
455
+ CUDNN_BACKEND_MATMUL_DESCRIPTOR,
456
+ CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR,
457
+ CUDNN_BACKEND_OPERATION_BN_FINALIZE_STATISTICS_DESCRIPTOR,
458
+ CUDNN_BACKEND_REDUCTION_DESCRIPTOR,
459
+ CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR,
460
+ CUDNN_BACKEND_OPERATION_BN_BWD_WEIGHTS_DESCRIPTOR,
461
+ CUDNN_BACKEND_RESAMPLE_DESCRIPTOR,
462
+ CUDNN_BACKEND_OPERATION_RESAMPLE_FWD_DESCRIPTOR,
463
+ CUDNN_BACKEND_OPERATION_RESAMPLE_BWD_DESCRIPTOR,
464
+ CUDNN_BACKEND_OPERATION_CONCAT_DESCRIPTOR,
465
+ CUDNN_BACKEND_OPERATION_SIGNAL_DESCRIPTOR,
466
+ CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR,
467
+ CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR,
468
+ CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR,
469
+ CUDNN_BACKEND_RNG_DESCRIPTOR,
470
+ CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR
471
+ } cudnnBackendDescriptorType_t;
472
+
473
+ typedef enum {
474
+ CUDNN_NUMERICAL_NOTE_TENSOR_CORE = 0,
475
+ CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS,
476
+ CUDNN_NUMERICAL_NOTE_REDUCED_PRECISION_REDUCTION,
477
+ CUDNN_NUMERICAL_NOTE_FFT,
478
+ CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC,
479
+ CUDNN_NUMERICAL_NOTE_WINOGRAD,
480
+ CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_4x4,
481
+ CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_6x6,
482
+ CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_13x13,
483
+ CUDNN_NUMERICAL_NOTE_TYPE_COUNT,
484
+ } cudnnBackendNumericalNote_t;
485
+
486
+ typedef enum {
487
+ CUDNN_BEHAVIOR_NOTE_RUNTIME_COMPILATION = 0,
488
+ CUDNN_BEHAVIOR_NOTE_REQUIRES_FILTER_INT8x32_REORDER = 1,
489
+ CUDNN_BEHAVIOR_NOTE_REQUIRES_BIAS_INT8x32_REORDER = 2,
490
+ CUDNN_BEHAVIOR_NOTE_TYPE_COUNT,
491
+ } cudnnBackendBehaviorNote_t;
492
+
493
+ typedef enum {
494
+ CUDNN_KNOB_TYPE_SPLIT_K = 0,
495
+ CUDNN_KNOB_TYPE_SWIZZLE = 1,
496
+ CUDNN_KNOB_TYPE_TILE_SIZE = 2,
497
+ CUDNN_KNOB_TYPE_USE_TEX = 3,
498
+ CUDNN_KNOB_TYPE_EDGE = 4,
499
+ CUDNN_KNOB_TYPE_KBLOCK = 5,
500
+ CUDNN_KNOB_TYPE_LDGA = 6,
501
+ CUDNN_KNOB_TYPE_LDGB = 7,
502
+ CUDNN_KNOB_TYPE_CHUNK_K = 8,
503
+ CUDNN_KNOB_TYPE_SPLIT_H = 9,
504
+ CUDNN_KNOB_TYPE_WINO_TILE = 10,
505
+ CUDNN_KNOB_TYPE_MULTIPLY = 11,
506
+ CUDNN_KNOB_TYPE_SPLIT_K_BUF = 12,
507
+ CUDNN_KNOB_TYPE_TILEK = 13,
508
+ CUDNN_KNOB_TYPE_STAGES = 14,
509
+ CUDNN_KNOB_TYPE_REDUCTION_MODE = 15,
510
+ CUDNN_KNOB_TYPE_CTA_SPLIT_K_MODE = 16,
511
+ CUDNN_KNOB_TYPE_SPLIT_K_SLC = 17,
512
+ CUDNN_KNOB_TYPE_IDX_MODE = 18,
513
+ CUDNN_KNOB_TYPE_SLICED = 19,
514
+ CUDNN_KNOB_TYPE_SPLIT_RS = 20,
515
+ CUDNN_KNOB_TYPE_SINGLEBUFFER = 21,
516
+ CUDNN_KNOB_TYPE_LDGC = 22,
517
+ CUDNN_KNOB_TYPE_SPECFILT = 23,
518
+ CUDNN_KNOB_TYPE_KERNEL_CFG = 24,
519
+ CUDNN_KNOB_TYPE_WORKSPACE = 25,
520
+ CUDNN_KNOB_TYPE_TILE_CGA = 26,
521
+ CUDNN_KNOB_TYPE_TILE_CGA_M = 27,
522
+ CUDNN_KNOB_TYPE_TILE_CGA_N = 28,
523
+
524
+ CUDNN_KNOB_TYPE_COUNTS,
525
+ } cudnnBackendKnobType_t;
526
+
527
+ typedef enum {
528
+ CUDNN_LAYOUT_TYPE_PREFERRED_NCHW = 0,
529
+ CUDNN_LAYOUT_TYPE_PREFERRED_NHWC = 1,
530
+ CUDNN_LAYOUT_TYPE_PREFERRED_PAD4CK = 2,
531
+ CUDNN_LAYOUT_TYPE_PREFERRED_PAD8CK = 3,
532
+ CUDNN_LAYOUT_TYPE_COUNT = 4,
533
+ } cudnnBackendLayoutType_t;
534
+
535
+ typedef enum {
536
+ CUDNN_HEUR_MODE_INSTANT = 0,
537
+ CUDNN_HEUR_MODE_B = 1,
538
+ CUDNN_HEUR_MODE_FALLBACK = 2,
539
+ CUDNN_HEUR_MODE_A = 3,
540
+ CUDNN_HEUR_MODES_COUNT = 4,
541
+ } cudnnBackendHeurMode_t;
542
+
543
+ typedef enum {
544
+ CUDNN_TENSOR_REORDERING_NONE = 0,
545
+ CUDNN_TENSOR_REORDERING_INT8x32 = 1,
546
+ } cudnnBackendTensorReordering_t;
547
+
548
+ typedef enum {
549
+ CUDNN_ZERO_PAD = 0,
550
+ CUDNN_NEG_INF_PAD = 1,
551
+ CUDNN_EDGE_VAL_PAD = 2,
552
+ } cudnnPaddingMode_t;
553
+
554
+ typedef enum {
555
+ CUDNN_LAYER_NORM = 0,
556
+ CUDNN_INSTANCE_NORM = 1,
557
+ CUDNN_BATCH_NORM = 2,
558
+ CUDNN_GROUP_NORM = 3,
559
+ } cudnnBackendNormMode_t;
560
+
561
+ typedef enum {
562
+ CUDNN_NORM_FWD_INFERENCE = 0,
563
+ CUDNN_NORM_FWD_TRAINING = 1,
564
+ } cudnnBackendNormFwdPhase_t;
565
+
566
+ cudnnStatus_t CUDNNWINAPI
567
+ cudnnBackendCreateDescriptor(cudnnBackendDescriptorType_t descriptorType, cudnnBackendDescriptor_t *descriptor);
568
+
569
+ cudnnStatus_t CUDNNWINAPI
570
+ cudnnBackendDestroyDescriptor(cudnnBackendDescriptor_t descriptor);
571
+
572
+ cudnnStatus_t CUDNNWINAPI
573
+ cudnnBackendInitialize(cudnnBackendDescriptor_t descriptor);
574
+
575
+ cudnnStatus_t CUDNNWINAPI
576
+ cudnnBackendFinalize(cudnnBackendDescriptor_t descriptor);
577
+
578
+ cudnnStatus_t CUDNNWINAPI
579
+ cudnnBackendSetAttribute(cudnnBackendDescriptor_t descriptor,
580
+ cudnnBackendAttributeName_t attributeName,
581
+ cudnnBackendAttributeType_t attributeType,
582
+ int64_t elementCount,
583
+ const void *arrayOfElements);
584
+
585
+ cudnnStatus_t CUDNNWINAPI
586
+ cudnnBackendGetAttribute(cudnnBackendDescriptor_t const descriptor,
587
+ cudnnBackendAttributeName_t attributeName,
588
+ cudnnBackendAttributeType_t attributeType,
589
+ int64_t requestedElementCount,
590
+ int64_t *elementCount,
591
+ void *arrayOfElements);
592
+
593
+ cudnnStatus_t CUDNNWINAPI
594
+ cudnnBackendExecute(cudnnHandle_t handle, cudnnBackendDescriptor_t executionPlan, cudnnBackendDescriptor_t variantPack);
595
+
596
+ #if defined(__cplusplus)
597
+ }
598
+ #endif
599
+
600
+ #endif /* _CUDNN_BACKEND_H_ */
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn_train.h ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2017-2022 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /*
51
+ * cudnn_cnn_train : cuDNN's basic definitions and inference CNN functions.
52
+ */
53
+
54
+ #pragma once
55
+ #include <cuda_runtime.h>
56
+ #include <stdint.h>
57
+
58
+ #include "cudnn_version.h"
59
+ #include "cudnn_ops_infer.h"
60
+ #include "cudnn_ops_train.h"
61
+ #include "cudnn_cnn_infer.h"
62
+
63
+ /* These version numbers are autogenerated, do not edit manually. */
64
+ #define CUDNN_CNN_TRAIN_MAJOR 8
65
+ #define CUDNN_CNN_TRAIN_MINOR 7
66
+ #define CUDNN_CNN_TRAIN_PATCH 0
67
+
68
+ #if (CUDNN_CNN_TRAIN_MAJOR != CUDNN_MAJOR) || (CUDNN_CNN_TRAIN_MINOR != CUDNN_MINOR) || \
69
+ (CUDNN_CNN_TRAIN_PATCH != CUDNN_PATCHLEVEL)
70
+ #error Version mismatch in cuDNN CNN INFER!!!
71
+ #endif
72
+
73
+ #if defined(__cplusplus)
74
+ extern "C" {
75
+ #endif
76
+
77
+ /* helper function to provide the convolution backward filter algo that fit best the requirement */
78
+
79
+ typedef struct cudnnConvolutionBwdFilterAlgoPerfStruct {
80
+ cudnnConvolutionBwdFilterAlgo_t algo;
81
+ cudnnStatus_t status;
82
+ float time;
83
+ size_t memory;
84
+ cudnnDeterminism_t determinism;
85
+ cudnnMathType_t mathType;
86
+ int reserved[3];
87
+ } cudnnConvolutionBwdFilterAlgoPerf_t;
88
+
89
+ cudnnStatus_t CUDNNWINAPI
90
+ cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnnHandle_t handle, int *count);
91
+
92
+ cudnnStatus_t CUDNNWINAPI
93
+ cudnnFindConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle,
94
+ const cudnnTensorDescriptor_t xDesc,
95
+ const cudnnTensorDescriptor_t dyDesc,
96
+ const cudnnConvolutionDescriptor_t convDesc,
97
+ const cudnnFilterDescriptor_t dwDesc,
98
+ const int requestedAlgoCount,
99
+ int *returnedAlgoCount,
100
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults);
101
+
102
+ cudnnStatus_t CUDNNWINAPI
103
+ cudnnFindConvolutionBackwardFilterAlgorithmEx(cudnnHandle_t handle,
104
+ const cudnnTensorDescriptor_t xDesc,
105
+ const void *x,
106
+ const cudnnTensorDescriptor_t dyDesc,
107
+ const void *y,
108
+ const cudnnConvolutionDescriptor_t convDesc,
109
+ const cudnnFilterDescriptor_t dwDesc,
110
+ void *dw,
111
+ const int requestedAlgoCount,
112
+ int *returnedAlgoCount,
113
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults,
114
+ void *workSpace,
115
+ size_t workSpaceSizeInBytes);
116
+
117
+ cudnnStatus_t CUDNNWINAPI
118
+ cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnnHandle_t handle,
119
+ const cudnnTensorDescriptor_t srcDesc,
120
+ const cudnnTensorDescriptor_t diffDesc,
121
+ const cudnnConvolutionDescriptor_t convDesc,
122
+ const cudnnFilterDescriptor_t gradDesc,
123
+ const int requestedAlgoCount,
124
+ int *returnedAlgoCount,
125
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults);
126
+
127
+ /*
128
+ * convolution algorithm (which requires potentially some workspace)
129
+ */
130
+
131
+ /* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
132
+ cudnnStatus_t CUDNNWINAPI
133
+ cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnHandle_t handle,
134
+ const cudnnTensorDescriptor_t xDesc,
135
+ const cudnnTensorDescriptor_t dyDesc,
136
+ const cudnnConvolutionDescriptor_t convDesc,
137
+ const cudnnFilterDescriptor_t gradDesc,
138
+ cudnnConvolutionBwdFilterAlgo_t algo,
139
+ size_t *sizeInBytes);
140
+
141
+ cudnnStatus_t CUDNNWINAPI
142
+ cudnnConvolutionBackwardFilter(cudnnHandle_t handle,
143
+ const void *alpha,
144
+ const cudnnTensorDescriptor_t xDesc,
145
+ const void *x,
146
+ const cudnnTensorDescriptor_t dyDesc,
147
+ const void *dy,
148
+ const cudnnConvolutionDescriptor_t convDesc,
149
+ cudnnConvolutionBwdFilterAlgo_t algo,
150
+ void *workSpace,
151
+ size_t workSpaceSizeInBytes,
152
+ const void *beta,
153
+ const cudnnFilterDescriptor_t dwDesc,
154
+ void *dw);
155
+
156
+ /* Function to compute the bias gradient for batch convolution */
157
+ cudnnStatus_t CUDNNWINAPI
158
+ cudnnConvolutionBackwardBias(cudnnHandle_t handle,
159
+ const void *alpha,
160
+ const cudnnTensorDescriptor_t dyDesc,
161
+ const void *dy,
162
+ const void *beta,
163
+ const cudnnTensorDescriptor_t dbDesc,
164
+ void *db);
165
+
166
+ cudnnStatus_t CUDNNWINAPI
167
+ cudnnCreateFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t *constPack, cudnnFusedOps_t ops);
168
+
169
+ cudnnStatus_t CUDNNWINAPI
170
+ cudnnDestroyFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t constPack);
171
+
172
+ cudnnStatus_t CUDNNWINAPI
173
+ cudnnSetFusedOpsConstParamPackAttribute(cudnnFusedOpsConstParamPack_t constPack,
174
+ cudnnFusedOpsConstParamLabel_t paramLabel,
175
+ const void *param);
176
+
177
+ cudnnStatus_t CUDNNWINAPI
178
+ cudnnGetFusedOpsConstParamPackAttribute(const cudnnFusedOpsConstParamPack_t constPack,
179
+ cudnnFusedOpsConstParamLabel_t paramLabel,
180
+ void *param,
181
+ int *isNULL);
182
+
183
+ cudnnStatus_t CUDNNWINAPI
184
+ cudnnCreateFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t *varPack, cudnnFusedOps_t ops);
185
+
186
+ cudnnStatus_t CUDNNWINAPI
187
+ cudnnDestroyFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t varPack);
188
+
189
+ cudnnStatus_t CUDNNWINAPI
190
+ cudnnSetFusedOpsVariantParamPackAttribute(cudnnFusedOpsVariantParamPack_t varPack,
191
+ cudnnFusedOpsVariantParamLabel_t paramLabel,
192
+ void *ptr);
193
+
194
+ cudnnStatus_t CUDNNWINAPI
195
+ cudnnGetFusedOpsVariantParamPackAttribute(const cudnnFusedOpsVariantParamPack_t varPack,
196
+ cudnnFusedOpsVariantParamLabel_t paramLabel,
197
+ void *ptr);
198
+
199
+ cudnnStatus_t CUDNNWINAPI
200
+ cudnnCreateFusedOpsPlan(cudnnFusedOpsPlan_t *plan, cudnnFusedOps_t ops);
201
+
202
+ cudnnStatus_t CUDNNWINAPI
203
+ cudnnDestroyFusedOpsPlan(cudnnFusedOpsPlan_t plan);
204
+
205
+ cudnnStatus_t CUDNNWINAPI
206
+ cudnnMakeFusedOpsPlan(cudnnHandle_t handle,
207
+ cudnnFusedOpsPlan_t plan,
208
+ const cudnnFusedOpsConstParamPack_t constPack,
209
+ size_t *workspaceSizeInBytes);
210
+
211
+ cudnnStatus_t CUDNNWINAPI
212
+ cudnnFusedOpsExecute(cudnnHandle_t handle, const cudnnFusedOpsPlan_t plan, cudnnFusedOpsVariantParamPack_t varPack);
213
+
214
+ cudnnStatus_t CUDNNWINAPI
215
+ cudnnCnnTrainVersionCheck(void);
216
+
217
+ #if defined(__cplusplus)
218
+ }
219
+ #endif
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn_train_v8.h ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2017-2022 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /*
51
+ * cudnn_cnn_train : cuDNN's basic definitions and inference CNN functions.
52
+ */
53
+
54
+ #pragma once
55
+ #include <cuda_runtime.h>
56
+ #include <stdint.h>
57
+
58
+ #include "cudnn_version.h"
59
+ #include "cudnn_ops_infer.h"
60
+ #include "cudnn_ops_train.h"
61
+ #include "cudnn_cnn_infer.h"
62
+
63
+ /* These version numbers are autogenerated, do not edit manually. */
64
+ #define CUDNN_CNN_TRAIN_MAJOR 8
65
+ #define CUDNN_CNN_TRAIN_MINOR 7
66
+ #define CUDNN_CNN_TRAIN_PATCH 0
67
+
68
+ #if (CUDNN_CNN_TRAIN_MAJOR != CUDNN_MAJOR) || (CUDNN_CNN_TRAIN_MINOR != CUDNN_MINOR) || \
69
+ (CUDNN_CNN_TRAIN_PATCH != CUDNN_PATCHLEVEL)
70
+ #error Version mismatch in cuDNN CNN INFER!!!
71
+ #endif
72
+
73
+ #if defined(__cplusplus)
74
+ extern "C" {
75
+ #endif
76
+
77
+ /* helper function to provide the convolution backward filter algo that fit best the requirement */
78
+
79
+ typedef struct cudnnConvolutionBwdFilterAlgoPerfStruct {
80
+ cudnnConvolutionBwdFilterAlgo_t algo;
81
+ cudnnStatus_t status;
82
+ float time;
83
+ size_t memory;
84
+ cudnnDeterminism_t determinism;
85
+ cudnnMathType_t mathType;
86
+ int reserved[3];
87
+ } cudnnConvolutionBwdFilterAlgoPerf_t;
88
+
89
+ cudnnStatus_t CUDNNWINAPI
90
+ cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnnHandle_t handle, int *count);
91
+
92
+ cudnnStatus_t CUDNNWINAPI
93
+ cudnnFindConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle,
94
+ const cudnnTensorDescriptor_t xDesc,
95
+ const cudnnTensorDescriptor_t dyDesc,
96
+ const cudnnConvolutionDescriptor_t convDesc,
97
+ const cudnnFilterDescriptor_t dwDesc,
98
+ const int requestedAlgoCount,
99
+ int *returnedAlgoCount,
100
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults);
101
+
102
+ cudnnStatus_t CUDNNWINAPI
103
+ cudnnFindConvolutionBackwardFilterAlgorithmEx(cudnnHandle_t handle,
104
+ const cudnnTensorDescriptor_t xDesc,
105
+ const void *x,
106
+ const cudnnTensorDescriptor_t dyDesc,
107
+ const void *y,
108
+ const cudnnConvolutionDescriptor_t convDesc,
109
+ const cudnnFilterDescriptor_t dwDesc,
110
+ void *dw,
111
+ const int requestedAlgoCount,
112
+ int *returnedAlgoCount,
113
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults,
114
+ void *workSpace,
115
+ size_t workSpaceSizeInBytes);
116
+
117
+ cudnnStatus_t CUDNNWINAPI
118
+ cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnnHandle_t handle,
119
+ const cudnnTensorDescriptor_t srcDesc,
120
+ const cudnnTensorDescriptor_t diffDesc,
121
+ const cudnnConvolutionDescriptor_t convDesc,
122
+ const cudnnFilterDescriptor_t gradDesc,
123
+ const int requestedAlgoCount,
124
+ int *returnedAlgoCount,
125
+ cudnnConvolutionBwdFilterAlgoPerf_t *perfResults);
126
+
127
+ /*
128
+ * convolution algorithm (which requires potentially some workspace)
129
+ */
130
+
131
+ /* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
132
+ cudnnStatus_t CUDNNWINAPI
133
+ cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnHandle_t handle,
134
+ const cudnnTensorDescriptor_t xDesc,
135
+ const cudnnTensorDescriptor_t dyDesc,
136
+ const cudnnConvolutionDescriptor_t convDesc,
137
+ const cudnnFilterDescriptor_t gradDesc,
138
+ cudnnConvolutionBwdFilterAlgo_t algo,
139
+ size_t *sizeInBytes);
140
+
141
+ cudnnStatus_t CUDNNWINAPI
142
+ cudnnConvolutionBackwardFilter(cudnnHandle_t handle,
143
+ const void *alpha,
144
+ const cudnnTensorDescriptor_t xDesc,
145
+ const void *x,
146
+ const cudnnTensorDescriptor_t dyDesc,
147
+ const void *dy,
148
+ const cudnnConvolutionDescriptor_t convDesc,
149
+ cudnnConvolutionBwdFilterAlgo_t algo,
150
+ void *workSpace,
151
+ size_t workSpaceSizeInBytes,
152
+ const void *beta,
153
+ const cudnnFilterDescriptor_t dwDesc,
154
+ void *dw);
155
+
156
+ /* Function to compute the bias gradient for batch convolution */
157
+ cudnnStatus_t CUDNNWINAPI
158
+ cudnnConvolutionBackwardBias(cudnnHandle_t handle,
159
+ const void *alpha,
160
+ const cudnnTensorDescriptor_t dyDesc,
161
+ const void *dy,
162
+ const void *beta,
163
+ const cudnnTensorDescriptor_t dbDesc,
164
+ void *db);
165
+
166
+ cudnnStatus_t CUDNNWINAPI
167
+ cudnnCreateFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t *constPack, cudnnFusedOps_t ops);
168
+
169
+ cudnnStatus_t CUDNNWINAPI
170
+ cudnnDestroyFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t constPack);
171
+
172
+ cudnnStatus_t CUDNNWINAPI
173
+ cudnnSetFusedOpsConstParamPackAttribute(cudnnFusedOpsConstParamPack_t constPack,
174
+ cudnnFusedOpsConstParamLabel_t paramLabel,
175
+ const void *param);
176
+
177
+ cudnnStatus_t CUDNNWINAPI
178
+ cudnnGetFusedOpsConstParamPackAttribute(const cudnnFusedOpsConstParamPack_t constPack,
179
+ cudnnFusedOpsConstParamLabel_t paramLabel,
180
+ void *param,
181
+ int *isNULL);
182
+
183
+ cudnnStatus_t CUDNNWINAPI
184
+ cudnnCreateFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t *varPack, cudnnFusedOps_t ops);
185
+
186
+ cudnnStatus_t CUDNNWINAPI
187
+ cudnnDestroyFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t varPack);
188
+
189
+ cudnnStatus_t CUDNNWINAPI
190
+ cudnnSetFusedOpsVariantParamPackAttribute(cudnnFusedOpsVariantParamPack_t varPack,
191
+ cudnnFusedOpsVariantParamLabel_t paramLabel,
192
+ void *ptr);
193
+
194
+ cudnnStatus_t CUDNNWINAPI
195
+ cudnnGetFusedOpsVariantParamPackAttribute(const cudnnFusedOpsVariantParamPack_t varPack,
196
+ cudnnFusedOpsVariantParamLabel_t paramLabel,
197
+ void *ptr);
198
+
199
+ cudnnStatus_t CUDNNWINAPI
200
+ cudnnCreateFusedOpsPlan(cudnnFusedOpsPlan_t *plan, cudnnFusedOps_t ops);
201
+
202
+ cudnnStatus_t CUDNNWINAPI
203
+ cudnnDestroyFusedOpsPlan(cudnnFusedOpsPlan_t plan);
204
+
205
+ cudnnStatus_t CUDNNWINAPI
206
+ cudnnMakeFusedOpsPlan(cudnnHandle_t handle,
207
+ cudnnFusedOpsPlan_t plan,
208
+ const cudnnFusedOpsConstParamPack_t constPack,
209
+ size_t *workspaceSizeInBytes);
210
+
211
+ cudnnStatus_t CUDNNWINAPI
212
+ cudnnFusedOpsExecute(cudnnHandle_t handle, const cudnnFusedOpsPlan_t plan, cudnnFusedOpsVariantParamPack_t varPack);
213
+
214
+ cudnnStatus_t CUDNNWINAPI
215
+ cudnnCnnTrainVersionCheck(void);
216
+
217
+ #if defined(__cplusplus)
218
+ }
219
+ #endif
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_ops_train_v8.h ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2017-2022 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /*
51
+ * cudnn_ops_train : cuDNN's basic training operations and algorithms.
52
+ */
53
+
54
+ #if !defined(CUDNN_OPS_TRAIN_H_)
55
+ #define CUDNN_OPS_TRAIN_H_
56
+
57
+ #include <cuda_runtime.h>
58
+ #include <stdint.h>
59
+
60
+ #include "cudnn_version.h"
61
+ #include "cudnn_ops_infer.h"
62
+
63
+ /* These version numbers are autogenerated, do not edit manually. */
64
+ #define CUDNN_OPS_TRAIN_MAJOR 8
65
+ #define CUDNN_OPS_TRAIN_MINOR 7
66
+ #define CUDNN_OPS_TRAIN_PATCH 0
67
+
68
+ #if (CUDNN_OPS_TRAIN_MAJOR != CUDNN_MAJOR) || (CUDNN_OPS_TRAIN_MINOR != CUDNN_MINOR) || \
69
+ (CUDNN_OPS_TRAIN_PATCH != CUDNN_PATCHLEVEL)
70
+ #error Version mismatch in cuDNN OPS TRAIN!!!
71
+ #endif
72
+
73
+ #if defined(__cplusplus)
74
+ extern "C" {
75
+ #endif
76
+
77
+ /* Function to perform backward softmax */
78
+ cudnnStatus_t CUDNNWINAPI
79
+ cudnnSoftmaxBackward(cudnnHandle_t handle,
80
+ cudnnSoftmaxAlgorithm_t algo,
81
+ cudnnSoftmaxMode_t mode,
82
+ const void *alpha,
83
+ const cudnnTensorDescriptor_t yDesc,
84
+ const void *y,
85
+ const cudnnTensorDescriptor_t dyDesc,
86
+ const void *dy,
87
+ const void *beta,
88
+ const cudnnTensorDescriptor_t dxDesc,
89
+ void *dx);
90
+
91
+ /* Function to perform backward pooling */
92
+ cudnnStatus_t CUDNNWINAPI
93
+ cudnnPoolingBackward(cudnnHandle_t handle,
94
+ const cudnnPoolingDescriptor_t poolingDesc,
95
+ const void *alpha,
96
+ const cudnnTensorDescriptor_t yDesc,
97
+ const void *y,
98
+ const cudnnTensorDescriptor_t dyDesc,
99
+ const void *dy,
100
+ const cudnnTensorDescriptor_t xDesc,
101
+ const void *x,
102
+ const void *beta,
103
+ const cudnnTensorDescriptor_t dxDesc,
104
+ void *dx);
105
+
106
+ /* Function to perform backward activation */
107
+ cudnnStatus_t CUDNNWINAPI
108
+ cudnnActivationBackward(cudnnHandle_t handle,
109
+ cudnnActivationDescriptor_t activationDesc,
110
+ const void *alpha,
111
+ const cudnnTensorDescriptor_t yDesc,
112
+ const void *y,
113
+ const cudnnTensorDescriptor_t dyDesc,
114
+ const void *dy,
115
+ const cudnnTensorDescriptor_t xDesc,
116
+ const void *x,
117
+ const void *beta,
118
+ const cudnnTensorDescriptor_t dxDesc,
119
+ void *dx);
120
+
121
+ /* LRN cross-channel backward computation. Double parameters cast to tensor data type */
122
+ cudnnStatus_t CUDNNWINAPI
123
+ cudnnLRNCrossChannelBackward(cudnnHandle_t handle,
124
+ cudnnLRNDescriptor_t normDesc,
125
+ cudnnLRNMode_t lrnMode,
126
+ const void *alpha,
127
+ const cudnnTensorDescriptor_t yDesc,
128
+ const void *y,
129
+ const cudnnTensorDescriptor_t dyDesc,
130
+ const void *dy,
131
+ const cudnnTensorDescriptor_t xDesc,
132
+ const void *x,
133
+ const void *beta,
134
+ const cudnnTensorDescriptor_t dxDesc,
135
+ void *dx);
136
+
137
+ cudnnStatus_t CUDNNWINAPI
138
+ cudnnDivisiveNormalizationBackward(cudnnHandle_t handle,
139
+ cudnnLRNDescriptor_t normDesc,
140
+ cudnnDivNormMode_t mode,
141
+ const void *alpha,
142
+ const cudnnTensorDescriptor_t xDesc, /* same desc for x, means, dy, temp, temp2 */
143
+ const void *x,
144
+ const void *means, /* if NULL, means are assumed to be zero */
145
+ const void *dy,
146
+ void *temp,
147
+ void *temp2,
148
+ const void *beta,
149
+ const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */
150
+ void *dx, /* output x differential */
151
+ void *dMeans); /* output means differential, can be NULL */
152
+
153
+ cudnnStatus_t CUDNNWINAPI
154
+ cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(cudnnHandle_t handle,
155
+ cudnnBatchNormMode_t mode,
156
+ cudnnBatchNormOps_t bnOps,
157
+ const cudnnTensorDescriptor_t xDesc,
158
+ const cudnnTensorDescriptor_t zDesc,
159
+ const cudnnTensorDescriptor_t yDesc,
160
+ const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
161
+ const cudnnActivationDescriptor_t activationDesc,
162
+ size_t *sizeInBytes);
163
+
164
+ cudnnStatus_t CUDNNWINAPI
165
+ cudnnGetBatchNormalizationBackwardExWorkspaceSize(cudnnHandle_t handle,
166
+ cudnnBatchNormMode_t mode,
167
+ cudnnBatchNormOps_t bnOps,
168
+ const cudnnTensorDescriptor_t xDesc,
169
+ const cudnnTensorDescriptor_t yDesc,
170
+ const cudnnTensorDescriptor_t dyDesc,
171
+ const cudnnTensorDescriptor_t dzDesc,
172
+ const cudnnTensorDescriptor_t dxDesc,
173
+ const cudnnTensorDescriptor_t dBnScaleBiasDesc,
174
+ const cudnnActivationDescriptor_t activationDesc,
175
+ size_t *sizeInBytes);
176
+
177
+ cudnnStatus_t CUDNNWINAPI
178
+ cudnnGetBatchNormalizationTrainingExReserveSpaceSize(cudnnHandle_t handle,
179
+ cudnnBatchNormMode_t mode,
180
+ cudnnBatchNormOps_t bnOps,
181
+ const cudnnActivationDescriptor_t activationDesc,
182
+ const cudnnTensorDescriptor_t xDesc,
183
+ size_t *sizeInBytes);
184
+
185
+ /* Computes y = BN(x). Also accumulates moving averages of mean and inverse variances */
186
+ cudnnStatus_t CUDNNWINAPI
187
+ cudnnBatchNormalizationForwardTraining(
188
+ cudnnHandle_t handle,
189
+ cudnnBatchNormMode_t mode,
190
+
191
+ const void *alpha, /* alpha[0] = result blend factor */
192
+ const void *beta, /* beta[0] = dest layer blend factor */
193
+
194
+ const cudnnTensorDescriptor_t xDesc,
195
+ const void *x, /* NxCxHxW */
196
+ const cudnnTensorDescriptor_t yDesc,
197
+ void *y, /* NxCxHxW */
198
+
199
+ /* Shared desc for the next 6 tensors in the argument list.
200
+ Data type to be set as follows:
201
+ type = (typeOf(x) == double) ? double : float
202
+ Dimensions for this descriptor depend on normalization mode
203
+ - Spatial Normalization : tensors are expected to have dims 1xCx1x1
204
+ (normalization is performed across NxHxW)
205
+ - Per-Activation Normalization : tensors are expected to have dims of 1xCxHxW
206
+ (normalization is performed across N) */
207
+ const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
208
+
209
+ /* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation */
210
+ const void *bnScale,
211
+ const void *bnBias,
212
+
213
+ /* MUST use factor=1 in the very first call of a complete training cycle.
214
+ Use a factor=1/(1+n) at N-th call to the function to get
215
+ Cumulative Moving Average (CMA) behavior
216
+ CMA[n] = (x[1]+...+x[n])/n
217
+ Since CMA[n+1] = (n*CMA[n]+x[n+1])/(n+1) =
218
+ ((n+1)*CMA[n]-CMA[n])/(n+1) + x[n+1]/(n+1) =
219
+ CMA[n]*(1-1/(n+1)) + x[n+1]*1/(n+1) */
220
+ double exponentialAverageFactor,
221
+
222
+ /* Used in Training phase only.
223
+ runningMean = newMean*factor + runningMean*(1-factor) */
224
+ void *resultRunningMean,
225
+ /* Output in training mode, input in inference. Is the moving average
226
+ of variance[x] (factor is applied in the same way as for runningMean) */
227
+ void *resultRunningVariance,
228
+
229
+ /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */
230
+ double epsilon,
231
+
232
+ /* Optionally save intermediate results from the forward pass here
233
+ - can be reused to speed up backward pass. NULL if unused */
234
+ void *resultSaveMean,
235
+ void *resultSaveInvVariance);
236
+
237
+ /* Computes y = relu(BN(x) + z). Also accumulates moving averages of mean and inverse variances */
238
+ cudnnStatus_t CUDNNWINAPI
239
+ cudnnBatchNormalizationForwardTrainingEx(
240
+ cudnnHandle_t handle,
241
+ cudnnBatchNormMode_t mode,
242
+ cudnnBatchNormOps_t bnOps,
243
+
244
+ const void *alpha, /* alpha[0] = result blend factor */
245
+ const void *beta, /* beta[0] = dest layer blend factor */
246
+
247
+ const cudnnTensorDescriptor_t xDesc,
248
+ const void *xData,
249
+ const cudnnTensorDescriptor_t zDesc,
250
+ const void *zData,
251
+ const cudnnTensorDescriptor_t yDesc,
252
+ void *yData,
253
+
254
+ const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc,
255
+ const void *bnScale,
256
+ const void *bnBias,
257
+
258
+ double exponentialAverageFactor,
259
+ void *resultRunningMean,
260
+ void *resultRunningVariance,
261
+
262
+ /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */
263
+ double epsilon,
264
+
265
+ /* Optionally save intermediate results from the forward pass here
266
+ - can be reused to speed up backward pass. NULL if unused */
267
+ void *resultSaveMean,
268
+ void *resultSaveInvVariance,
269
+
270
+ cudnnActivationDescriptor_t activationDesc,
271
+ void *workspace,
272
+ size_t workSpaceSizeInBytes,
273
+ void *reserveSpace,
274
+ size_t reserveSpaceSizeInBytes);
275
+
276
+ /* Performs backward pass of Batch Normalization layer. Returns x gradient,
277
+ * bnScale gradient and bnBias gradient */
278
+ cudnnStatus_t CUDNNWINAPI
279
+ cudnnBatchNormalizationBackward(cudnnHandle_t handle,
280
+ cudnnBatchNormMode_t mode,
281
+ const void *alphaDataDiff,
282
+ const void *betaDataDiff,
283
+ const void *alphaParamDiff,
284
+ const void *betaParamDiff,
285
+ const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */
286
+ const void *x,
287
+ const cudnnTensorDescriptor_t dyDesc,
288
+ const void *dy,
289
+ const cudnnTensorDescriptor_t dxDesc,
290
+ void *dx,
291
+ /* Shared tensor desc for the 4 tensors below */
292
+ const cudnnTensorDescriptor_t dBnScaleBiasDesc,
293
+ const void *bnScale, /* bnBias doesn't affect backpropagation */
294
+ /* scale and bias diff are not backpropagated below this layer */
295
+ void *dBnScaleResult,
296
+ void *dBnBiasResult,
297
+ /* Same epsilon as forward pass */
298
+ double epsilon,
299
+
300
+ /* Optionally cached intermediate results from
301
+ forward pass */
302
+ const void *savedMean,
303
+ const void *savedInvVariance);
304
+
305
+ cudnnStatus_t CUDNNWINAPI
306
+ cudnnBatchNormalizationBackwardEx(cudnnHandle_t handle,
307
+ cudnnBatchNormMode_t mode,
308
+ cudnnBatchNormOps_t bnOps,
309
+
310
+ const void *alphaDataDiff,
311
+ const void *betaDataDiff,
312
+ const void *alphaParamDiff,
313
+ const void *betaParamDiff,
314
+ const cudnnTensorDescriptor_t xDesc,
315
+ const void *xData,
316
+ const cudnnTensorDescriptor_t yDesc,
317
+ const void *yData,
318
+ const cudnnTensorDescriptor_t dyDesc,
319
+ const void *dyData,
320
+ const cudnnTensorDescriptor_t dzDesc,
321
+ void *dzData,
322
+ const cudnnTensorDescriptor_t dxDesc,
323
+ void *dxData,
324
+
325
+ /* Shared tensor desc for the 4 tensors below */
326
+ const cudnnTensorDescriptor_t dBnScaleBiasDesc,
327
+ const void *bnScaleData,
328
+ const void *bnBiasData, /* needed if there is activation */
329
+ void *dBnScaleData,
330
+ void *dBnBiasData,
331
+ double epsilon, /* Same epsilon as forward pass */
332
+
333
+ /* Optionally cached intermediate results from
334
+ forward pass */
335
+ const void *savedMean,
336
+ const void *savedInvVariance,
337
+ cudnnActivationDescriptor_t activationDesc,
338
+ void *workSpace,
339
+ size_t workSpaceSizeInBytes,
340
+ void *reserveSpace,
341
+ size_t reserveSpaceSizeInBytes);
342
+
343
+ cudnnStatus_t CUDNNWINAPI
344
+ cudnnGetNormalizationForwardTrainingWorkspaceSize(cudnnHandle_t handle,
345
+ cudnnNormMode_t mode,
346
+ cudnnNormOps_t normOps,
347
+ cudnnNormAlgo_t algo,
348
+ const cudnnTensorDescriptor_t xDesc,
349
+ const cudnnTensorDescriptor_t zDesc,
350
+ const cudnnTensorDescriptor_t yDesc,
351
+ const cudnnTensorDescriptor_t normScaleBiasDesc,
352
+ const cudnnActivationDescriptor_t activationDesc,
353
+ const cudnnTensorDescriptor_t normMeanVarDesc,
354
+ size_t *sizeInBytes,
355
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
356
+
357
+ cudnnStatus_t CUDNNWINAPI
358
+ cudnnGetNormalizationBackwardWorkspaceSize(cudnnHandle_t handle,
359
+ cudnnNormMode_t mode,
360
+ cudnnNormOps_t normOps,
361
+ cudnnNormAlgo_t algo,
362
+ const cudnnTensorDescriptor_t xDesc,
363
+ const cudnnTensorDescriptor_t yDesc,
364
+ const cudnnTensorDescriptor_t dyDesc,
365
+ const cudnnTensorDescriptor_t dzDesc,
366
+ const cudnnTensorDescriptor_t dxDesc,
367
+ const cudnnTensorDescriptor_t dNormScaleBiasDesc,
368
+ const cudnnActivationDescriptor_t activationDesc,
369
+ const cudnnTensorDescriptor_t normMeanVarDesc,
370
+ size_t *sizeInBytes,
371
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
372
+
373
+ cudnnStatus_t CUDNNWINAPI
374
+ cudnnGetNormalizationTrainingReserveSpaceSize(cudnnHandle_t handle,
375
+ cudnnNormMode_t mode,
376
+ cudnnNormOps_t normOps,
377
+ cudnnNormAlgo_t algo,
378
+ const cudnnActivationDescriptor_t activationDesc,
379
+ const cudnnTensorDescriptor_t xDesc,
380
+ size_t *sizeInBytes,
381
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
382
+
383
+ /* Computes y = relu(Norm(x) + z). Also accumulates moving averages of mean and inverse variances */
384
+ cudnnStatus_t CUDNNWINAPI
385
+ cudnnNormalizationForwardTraining(cudnnHandle_t handle,
386
+ cudnnNormMode_t mode,
387
+ cudnnNormOps_t normOps,
388
+ cudnnNormAlgo_t algo,
389
+ const void *alpha, /* alpha[0] = result blend factor */
390
+ const void *beta, /* beta[0] = dest layer blend factor */
391
+ const cudnnTensorDescriptor_t xDesc,
392
+ const void *xData,
393
+ const cudnnTensorDescriptor_t normScaleBiasDesc,
394
+ const void *normScale,
395
+ const void *normBias,
396
+ double exponentialAverageFactor,
397
+ const cudnnTensorDescriptor_t normMeanVarDesc,
398
+ void *resultRunningMean,
399
+ void *resultRunningVariance,
400
+ /* Has to be >= 0. Should be the same in forward and backward functions. */
401
+ double epsilon,
402
+ /* Optionally save intermediate results from the forward pass here
403
+ - can be reused to speed up backward pass. NULL if unused */
404
+ void *resultSaveMean,
405
+ void *resultSaveInvVariance,
406
+ cudnnActivationDescriptor_t activationDesc,
407
+ const cudnnTensorDescriptor_t zDesc,
408
+ const void *zData,
409
+ const cudnnTensorDescriptor_t yDesc,
410
+ void *yData,
411
+ void *workspace,
412
+ size_t workSpaceSizeInBytes,
413
+ void *reserveSpace,
414
+ size_t reserveSpaceSizeInBytes,
415
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
416
+
417
+ cudnnStatus_t CUDNNWINAPI
418
+ cudnnNormalizationBackward(cudnnHandle_t handle,
419
+ cudnnNormMode_t mode,
420
+ cudnnNormOps_t normOps,
421
+ cudnnNormAlgo_t algo,
422
+ const void *alphaDataDiff,
423
+ const void *betaDataDiff,
424
+ const void *alphaParamDiff,
425
+ const void *betaParamDiff,
426
+ const cudnnTensorDescriptor_t xDesc,
427
+ const void *xData,
428
+ const cudnnTensorDescriptor_t yDesc,
429
+ const void *yData,
430
+ const cudnnTensorDescriptor_t dyDesc,
431
+ const void *dyData,
432
+ const cudnnTensorDescriptor_t dzDesc,
433
+ void *dzData,
434
+ const cudnnTensorDescriptor_t dxDesc,
435
+ void *dxData,
436
+ /* Shared tensor desc for the 4 tensors below */
437
+ const cudnnTensorDescriptor_t dNormScaleBiasDesc,
438
+ const void *normScaleData,
439
+ const void *normBiasData, /* needed if there is activation */
440
+ void *dNormScaleData,
441
+ void *dNormBiasData,
442
+ double epsilon, /* Same epsilon as forward pass */
443
+ const cudnnTensorDescriptor_t normMeanVarDesc,
444
+ /* Optionally cached intermediate results from
445
+ forward pass */
446
+ const void *savedMean,
447
+ const void *savedInvVariance,
448
+ cudnnActivationDescriptor_t activationDesc,
449
+ void *workSpace,
450
+ size_t workSpaceSizeInBytes,
451
+ void *reserveSpace,
452
+ size_t reserveSpaceSizeInBytes,
453
+ int groupCnt); /* Place hold for future work, should be set to 1 now*/
454
+
455
+ cudnnStatus_t CUDNNWINAPI
456
+ cudnnSpatialTfGridGeneratorBackward(cudnnHandle_t handle,
457
+ const cudnnSpatialTransformerDescriptor_t stDesc,
458
+ const void *dgrid,
459
+ void *dtheta);
460
+
461
+ cudnnStatus_t CUDNNWINAPI
462
+ cudnnSpatialTfSamplerBackward(cudnnHandle_t handle,
463
+ cudnnSpatialTransformerDescriptor_t stDesc,
464
+ const void *alpha,
465
+ const cudnnTensorDescriptor_t xDesc,
466
+ const void *x,
467
+ const void *beta,
468
+ const cudnnTensorDescriptor_t dxDesc,
469
+ void *dx,
470
+ const void *alphaDgrid,
471
+ const cudnnTensorDescriptor_t dyDesc,
472
+ const void *dy,
473
+ const void *grid,
474
+ const void *betaDgrid,
475
+ void *dgrid);
476
+
477
+ cudnnStatus_t CUDNNWINAPI
478
+ cudnnDropoutBackward(cudnnHandle_t handle,
479
+ const cudnnDropoutDescriptor_t dropoutDesc,
480
+ const cudnnTensorDescriptor_t dydesc,
481
+ const void *dy,
482
+ const cudnnTensorDescriptor_t dxdesc,
483
+ void *dx,
484
+ void *reserveSpace,
485
+ size_t reserveSpaceSizeInBytes);
486
+
487
+ /*
488
+ * \brief Cross-library version checker.
489
+ * This function is implemented differently in each sub-library. Each sublib
490
+ * checks whether its own version matches that of its dependencies.
491
+ * \returns CUDNN_STATUS_SUCCESS if the version check passes,
492
+ * CUDNN_STATUS_VERSION_MISMATCH if the versions are inconsistent.
493
+ */
494
+ cudnnStatus_t CUDNNWINAPI
495
+ cudnnOpsTrainVersionCheck(void);
496
+
497
+ #if defined(__cplusplus)
498
+ }
499
+ #endif
500
+
501
+ #endif /* CUDNN_OPS_TRAIN_H_ */
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_v8.h ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2017-2022 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /* cudnn : Neural Networks Library
51
+
52
+ */
53
+
54
+ #if !defined(CUDNN_H_)
55
+ #define CUDNN_H_
56
+
57
+ #include <cuda_runtime.h>
58
+ #include <stdint.h>
59
+
60
+ #include "cudnn_version.h"
61
+ #include "cudnn_ops_infer.h"
62
+ #include "cudnn_ops_train.h"
63
+ #include "cudnn_adv_infer.h"
64
+ #include "cudnn_adv_train.h"
65
+ #include "cudnn_cnn_infer.h"
66
+ #include "cudnn_cnn_train.h"
67
+
68
+ #include "cudnn_backend.h"
69
+
70
+ #if defined(__cplusplus)
71
+ extern "C" {
72
+ #endif
73
+
74
+ #if defined(__cplusplus)
75
+ }
76
+ #endif
77
+
78
+ #endif /* CUDNN_H_ */
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_version_v8.h ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2017-2022 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /**
51
+ * \file: The master cuDNN version file.
52
+ */
53
+
54
+ #ifndef CUDNN_VERSION_H_
55
+ #define CUDNN_VERSION_H_
56
+
57
+ #define CUDNN_MAJOR 8
58
+ #define CUDNN_MINOR 7
59
+ #define CUDNN_PATCHLEVEL 0
60
+
61
+ #define CUDNN_VERSION (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
62
+
63
+ /* cannot use constexpr here since this is a C-only file */
64
+ /* Below is the max SM version this cuDNN library is aware of and supports natively */
65
+
66
+ #define CUDNN_MAX_SM_MAJOR_NUMBER 9
67
+ #define CUDNN_MAX_SM_MINOR_NUMBER 0
68
+ #define CUDNN_MAX_DEVICE_VERSION (CUDNN_MAX_SM_MAJOR_NUMBER * 100) + (CUDNN_MAX_SM_MINOR_NUMBER * 10)
69
+
70
+ #endif /* CUDNN_VERSION_H */
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/curand/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/curand/include/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/curand/include/curand_kernel.h ADDED
@@ -0,0 +1,1665 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ /* Copyright 2010-2014 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * The source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * The Licensed Deliverables contained herein are PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and are being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+
51
+ #if !defined(CURAND_KERNEL_H_)
52
+ #define CURAND_KERNEL_H_
53
+
54
+ /**
55
+ * \defgroup DEVICE Device API
56
+ *
57
+ * @{
58
+ */
59
+
60
+ #if !defined(QUALIFIERS)
61
+ #define QUALIFIERS static __forceinline__ __device__
62
+ #endif
63
+
64
+
65
+ #ifdef __CUDACC_RTC__
66
+ #define CURAND_DETAIL_USE_CUDA_STL
67
+ #endif
68
+
69
+ #if __cplusplus >= 201103L
70
+ # ifdef CURAND_DETAIL_USE_CUDA_STL
71
+ # define CURAND_STD cuda::std
72
+ # include <cuda/std/type_traits>
73
+ # else
74
+ # define CURAND_STD std
75
+ # include <type_traits>
76
+ # endif // CURAND_DETAIL_USE_CUDA_STL
77
+ #else
78
+ // To support C++03 compilation
79
+ # define CURAND_STD curand_detail
80
+ namespace curand_detail {
81
+ template<bool B, class T = void>
82
+ struct enable_if {};
83
+
84
+ template<class T>
85
+ struct enable_if<true, T> { typedef T type; };
86
+
87
+ template<class T, class U>
88
+ struct is_same { static const bool value = false; };
89
+
90
+ template<class T>
91
+ struct is_same<T, T> { static const bool value = true; };
92
+ } // namespace curand_detail
93
+ #endif // __cplusplus >= 201103L
94
+
95
+ #ifndef __CUDACC_RTC__
96
+ #include <math.h>
97
+ #endif // __CUDACC_RTC__
98
+
99
+ #include "curand.h"
100
+ #include "curand_discrete.h"
101
+ #include "curand_precalc.h"
102
+ #include "curand_mrg32k3a.h"
103
+ #include "curand_mtgp32_kernel.h"
104
+ #include "curand_philox4x32_x.h"
105
+ #include "curand_globals.h"
106
+
107
+ /* Test RNG */
108
+ /* This generator uses the formula:
109
+ x_n = x_(n-1) + 1 mod 2^32
110
+ x_0 = (unsigned int)seed * 3
111
+ Subsequences are spaced 31337 steps apart.
112
+ */
113
+ struct curandStateTest {
114
+ unsigned int v;
115
+ };
116
+
117
+ /** \cond UNHIDE_TYPEDEFS */
118
+ typedef struct curandStateTest curandStateTest_t;
119
+ /** \endcond */
120
+
121
+ /* XORSHIFT FAMILY RNGs */
122
+ /* These generators are a family proposed by Marsaglia. They keep state
123
+ in 32 bit chunks, then use repeated shift and xor operations to scramble
124
+ the bits. The following generators are a combination of a simple Weyl
125
+ generator with an N variable XORSHIFT generator.
126
+ */
127
+
128
+ /* XORSHIFT RNG */
129
+ /* This generator uses the xorwow formula of
130
+ www.jstatsoft.org/v08/i14/paper page 5
131
+ Has period 2^192 - 2^32.
132
+ */
133
+ /**
134
+ * CURAND XORWOW state
135
+ */
136
+ struct curandStateXORWOW;
137
+
138
+ /*
139
+ * Implementation details not in reference documentation */
140
+ struct curandStateXORWOW {
141
+ unsigned int d, v[5];
142
+ int boxmuller_flag;
143
+ int boxmuller_flag_double;
144
+ float boxmuller_extra;
145
+ double boxmuller_extra_double;
146
+ };
147
+
148
+ /*
149
+ * CURAND XORWOW state
150
+ */
151
+ /** \cond UNHIDE_TYPEDEFS */
152
+ typedef struct curandStateXORWOW curandStateXORWOW_t;
153
+
154
+ #define EXTRA_FLAG_NORMAL 0x00000001
155
+ #define EXTRA_FLAG_LOG_NORMAL 0x00000002
156
+ /** \endcond */
157
+
158
+ /* Combined Multiple Recursive Generators */
159
+ /* These generators are a family proposed by L'Ecuyer. They keep state
160
+ in sets of doubles, then use repeated modular arithmetic multiply operations
161
+ to scramble the bits in each set, and combine the result.
162
+ */
163
+
164
+ /* MRG32k3a RNG */
165
+ /* This generator uses the MRG32k3A formula of
166
+ http://www.iro.umontreal.ca/~lecuyer/myftp/streams00/c++/streams4.pdf
167
+ Has period 2^191.
168
+ */
169
+
170
+ /* moduli for the recursions */
171
+ /** \cond UNHIDE_DEFINES */
172
+ #define MRG32K3A_MOD1 4294967087.
173
+ #define MRG32K3A_MOD2 4294944443.
174
+
175
+ /* Constants used in generation */
176
+
177
+ #define MRG32K3A_A12 1403580.
178
+ #define MRG32K3A_A13N 810728.
179
+ #define MRG32K3A_A21 527612.
180
+ #define MRG32K3A_A23N 1370589.
181
+ #define MRG32K3A_NORM (2.3283065498378288e-10)
182
+ //
183
+ // #define MRG32K3A_BITS_NORM ((double)((POW32_DOUBLE-1.0)/MOD1))
184
+ // above constant, used verbatim, rounds differently on some host systems.
185
+ #define MRG32K3A_BITS_NORM 1.000000048662
186
+
187
+ /** \endcond */
188
+
189
+
190
+
191
+
192
+ /**
193
+ * CURAND MRG32K3A state
194
+ */
195
+ struct curandStateMRG32k3a;
196
+
197
+ /* Implementation details not in reference documentation */
198
+ struct curandStateMRG32k3a {
199
+ unsigned int s1[3];
200
+ unsigned int s2[3];
201
+ int boxmuller_flag;
202
+ int boxmuller_flag_double;
203
+ float boxmuller_extra;
204
+ double boxmuller_extra_double;
205
+ };
206
+
207
+ /*
208
+ * CURAND MRG32K3A state
209
+ */
210
+ /** \cond UNHIDE_TYPEDEFS */
211
+ typedef struct curandStateMRG32k3a curandStateMRG32k3a_t;
212
+ /** \endcond */
213
+
214
+ /* SOBOL QRNG */
215
+ /**
216
+ * CURAND Sobol32 state
217
+ */
218
+ struct curandStateSobol32;
219
+
220
+ /* Implementation details not in reference documentation */
221
+ struct curandStateSobol32 {
222
+ unsigned int i, x, c;
223
+ unsigned int direction_vectors[32];
224
+ };
225
+
226
+ /*
227
+ * CURAND Sobol32 state
228
+ */
229
+ /** \cond UNHIDE_TYPEDEFS */
230
+ typedef struct curandStateSobol32 curandStateSobol32_t;
231
+ /** \endcond */
232
+
233
+ /**
234
+ * CURAND Scrambled Sobol32 state
235
+ */
236
+ struct curandStateScrambledSobol32;
237
+
238
+ /* Implementation details not in reference documentation */
239
+ struct curandStateScrambledSobol32 {
240
+ unsigned int i, x, c;
241
+ unsigned int direction_vectors[32];
242
+ };
243
+
244
+ /*
245
+ * CURAND Scrambled Sobol32 state
246
+ */
247
+ /** \cond UNHIDE_TYPEDEFS */
248
+ typedef struct curandStateScrambledSobol32 curandStateScrambledSobol32_t;
249
+ /** \endcond */
250
+
251
+ /**
252
+ * CURAND Sobol64 state
253
+ */
254
+ struct curandStateSobol64;
255
+
256
+ /* Implementation details not in reference documentation */
257
+ struct curandStateSobol64 {
258
+ unsigned long long i, x, c;
259
+ unsigned long long direction_vectors[64];
260
+ };
261
+
262
+ /*
263
+ * CURAND Sobol64 state
264
+ */
265
+ /** \cond UNHIDE_TYPEDEFS */
266
+ typedef struct curandStateSobol64 curandStateSobol64_t;
267
+ /** \endcond */
268
+
269
+ /**
270
+ * CURAND Scrambled Sobol64 state
271
+ */
272
+ struct curandStateScrambledSobol64;
273
+
274
+ /* Implementation details not in reference documentation */
275
+ struct curandStateScrambledSobol64 {
276
+ unsigned long long i, x, c;
277
+ unsigned long long direction_vectors[64];
278
+ };
279
+
280
+ /*
281
+ * CURAND Scrambled Sobol64 state
282
+ */
283
+ /** \cond UNHIDE_TYPEDEFS */
284
+ typedef struct curandStateScrambledSobol64 curandStateScrambledSobol64_t;
285
+ /** \endcond */
286
+
287
+ /*
288
+ * Default RNG
289
+ */
290
+ /** \cond UNHIDE_TYPEDEFS */
291
+ typedef struct curandStateXORWOW curandState_t;
292
+ typedef struct curandStateXORWOW curandState;
293
+ /** \endcond */
294
+
295
+ /****************************************************************************/
296
+ /* Utility functions needed by RNGs */
297
+ /****************************************************************************/
298
+ /** \cond UNHIDE_UTILITIES */
299
+ /*
300
+ multiply vector by matrix, store in result
301
+ matrix is n x n, measured in 32 bit units
302
+ matrix is stored in row major order
303
+ vector and result cannot be same pointer
304
+ */
305
+ template<int N>
306
+ QUALIFIERS void __curand_matvec_inplace(unsigned int *vector, unsigned int *matrix)
307
+ {
308
+ unsigned int result[N] = { 0 };
309
+ for(int i = 0; i < N; i++) {
310
+ #ifdef __CUDA_ARCH__
311
+ #pragma unroll 16
312
+ #endif
313
+ for(int j = 0; j < 32; j++) {
314
+ if(vector[i] & (1 << j)) {
315
+ for(int k = 0; k < N; k++) {
316
+ result[k] ^= matrix[N * (i * 32 + j) + k];
317
+ }
318
+ }
319
+ }
320
+ }
321
+ for(int i = 0; i < N; i++) {
322
+ vector[i] = result[i];
323
+ }
324
+ }
325
+
326
+ QUALIFIERS void __curand_matvec(unsigned int *vector, unsigned int *matrix,
327
+ unsigned int *result, int n)
328
+ {
329
+ for(int i = 0; i < n; i++) {
330
+ result[i] = 0;
331
+ }
332
+ for(int i = 0; i < n; i++) {
333
+ for(int j = 0; j < 32; j++) {
334
+ if(vector[i] & (1 << j)) {
335
+ for(int k = 0; k < n; k++) {
336
+ result[k] ^= matrix[n * (i * 32 + j) + k];
337
+ }
338
+ }
339
+ }
340
+ }
341
+ }
342
+
343
+ /* generate identity matrix */
344
+ QUALIFIERS void __curand_matidentity(unsigned int *matrix, int n)
345
+ {
346
+ int r;
347
+ for(int i = 0; i < n * 32; i++) {
348
+ for(int j = 0; j < n; j++) {
349
+ r = i & 31;
350
+ if(i / 32 == j) {
351
+ matrix[i * n + j] = (1 << r);
352
+ } else {
353
+ matrix[i * n + j] = 0;
354
+ }
355
+ }
356
+ }
357
+ }
358
+
359
+ /* multiply matrixA by matrixB, store back in matrixA
360
+ matrixA and matrixB must not be same matrix */
361
+ QUALIFIERS void __curand_matmat(unsigned int *matrixA, unsigned int *matrixB, int n)
362
+ {
363
+ unsigned int result[MAX_XOR_N];
364
+ for(int i = 0; i < n * 32; i++) {
365
+ __curand_matvec(matrixA + i * n, matrixB, result, n);
366
+ for(int j = 0; j < n; j++) {
367
+ matrixA[i * n + j] = result[j];
368
+ }
369
+ }
370
+ }
371
+
372
+ /* copy vectorA to vector */
373
+ QUALIFIERS void __curand_veccopy(unsigned int *vector, unsigned int *vectorA, int n)
374
+ {
375
+ for(int i = 0; i < n; i++) {
376
+ vector[i] = vectorA[i];
377
+ }
378
+ }
379
+
380
+ /* copy matrixA to matrix */
381
+ QUALIFIERS void __curand_matcopy(unsigned int *matrix, unsigned int *matrixA, int n)
382
+ {
383
+ for(int i = 0; i < n * n * 32; i++) {
384
+ matrix[i] = matrixA[i];
385
+ }
386
+ }
387
+
388
+ /* compute matrixA to power p, store result in matrix */
389
+ QUALIFIERS void __curand_matpow(unsigned int *matrix, unsigned int *matrixA,
390
+ unsigned long long p, int n)
391
+ {
392
+ unsigned int matrixR[MAX_XOR_N * MAX_XOR_N * 32];
393
+ unsigned int matrixS[MAX_XOR_N * MAX_XOR_N * 32];
394
+ __curand_matidentity(matrix, n);
395
+ __curand_matcopy(matrixR, matrixA, n);
396
+ while(p) {
397
+ if(p & 1) {
398
+ __curand_matmat(matrix, matrixR, n);
399
+ }
400
+ __curand_matcopy(matrixS, matrixR, n);
401
+ __curand_matmat(matrixR, matrixS, n);
402
+ p >>= 1;
403
+ }
404
+ }
405
+
406
+ /****************************************************************************/
407
+ /* Utility functions needed by MRG32k3a RNG */
408
+ /* Matrix operations modulo some integer less than 2**32, done in */
409
+ /* double precision floating point, with care not to overflow 53 bits */
410
+ /****************************************************************************/
411
+
412
+ /* return i mod m. */
413
+ /* assumes i and m are integers represented accurately in doubles */
414
+
415
+ QUALIFIERS double curand_MRGmod(double i, double m)
416
+ {
417
+ double quo;
418
+ double rem;
419
+ quo = floor(i/m);
420
+ rem = i - (quo*m);
421
+ if (rem < 0.0) rem += m;
422
+ return rem;
423
+ }
424
+
425
+ /* Multiplication modulo m. Inputs i and j less than 2**32 */
426
+ /* Ensure intermediate results do not exceed 2**53 */
427
+
428
+ QUALIFIERS double curand_MRGmodMul(double i, double j, double m)
429
+ {
430
+ double tempHi;
431
+ double tempLo;
432
+
433
+ tempHi = floor(i/131072.0);
434
+ tempLo = i - (tempHi*131072.0);
435
+ tempLo = curand_MRGmod( curand_MRGmod( (tempHi * j), m) * 131072.0 + curand_MRGmod(tempLo * j, m),m);
436
+
437
+ if (tempLo < 0.0) tempLo += m;
438
+ return tempLo;
439
+ }
440
+
441
+ /* multiply 3 by 3 matrices of doubles, modulo m */
442
+
443
+ QUALIFIERS void curand_MRGmatMul3x3(unsigned int i1[][3],unsigned int i2[][3],unsigned int o[][3],double m)
444
+ {
445
+ int i,j;
446
+ double temp[3][3];
447
+ for (i=0; i<3; i++){
448
+ for (j=0; j<3; j++){
449
+ temp[i][j] = ( curand_MRGmodMul(i1[i][0], i2[0][j], m) +
450
+ curand_MRGmodMul(i1[i][1], i2[1][j], m) +
451
+ curand_MRGmodMul(i1[i][2], i2[2][j], m));
452
+ temp[i][j] = curand_MRGmod( temp[i][j], m );
453
+ }
454
+ }
455
+ for (i=0; i<3; i++){
456
+ for (j=0; j<3; j++){
457
+ o[i][j] = (unsigned int)temp[i][j];
458
+ }
459
+ }
460
+ }
461
+
462
+ /* multiply 3 by 3 matrix times 3 by 1 vector of doubles, modulo m */
463
+
464
+ QUALIFIERS void curand_MRGmatVecMul3x3( unsigned int i[][3], unsigned int v[], double m)
465
+ {
466
+ int k;
467
+ double t[3];
468
+ for (k = 0; k < 3; k++) {
469
+ t[k] = ( curand_MRGmodMul(i[k][0], v[0], m) +
470
+ curand_MRGmodMul(i[k][1], v[1], m) +
471
+ curand_MRGmodMul(i[k][2], v[2], m) );
472
+ t[k] = curand_MRGmod( t[k], m );
473
+ }
474
+ for (k = 0; k < 3; k++) {
475
+ v[k] = (unsigned int)t[k];
476
+ }
477
+
478
+ }
479
+
480
+ /* raise a 3 by 3 matrix of doubles to a 64 bit integer power pow, modulo m */
481
+ /* input is index zero of an array of 3 by 3 matrices m, */
482
+ /* each m = m[0]**(2**index) */
483
+
484
+ QUALIFIERS void curand_MRGmatPow3x3( unsigned int in[][3][3], unsigned int o[][3], double m, unsigned long long pow )
485
+ {
486
+ int i,j;
487
+ for ( i = 0; i < 3; i++ ) {
488
+ for ( j = 0; j < 3; j++ ) {
489
+ o[i][j] = 0;
490
+ if ( i == j ) o[i][j] = 1;
491
+ }
492
+ }
493
+ i = 0;
494
+ curand_MRGmatVecMul3x3(o,o[0],m);
495
+ while (pow) {
496
+ if ( pow & 1ll ) {
497
+ curand_MRGmatMul3x3(in[i], o, o, m);
498
+ }
499
+ i++;
500
+ pow >>= 1;
501
+ }
502
+ }
503
+
504
+ /* raise a 3 by 3 matrix of doubles to the power */
505
+ /* 2 to the power (pow modulo 191), modulo m */
506
+
507
+ QUALIFIERS void curnand_MRGmatPow2Pow3x3( double in[][3], double o[][3], double m, unsigned long pow )
508
+ {
509
+ unsigned int temp[3][3];
510
+ int i,j;
511
+ pow = pow % 191;
512
+ for ( i = 0; i < 3; i++ ) {
513
+ for ( j = 0; j < 3; j++ ) {
514
+ temp[i][j] = (unsigned int)in[i][j];
515
+ }
516
+ }
517
+ while (pow) {
518
+ curand_MRGmatMul3x3(temp, temp, temp, m);
519
+ pow--;
520
+ }
521
+ for ( i = 0; i < 3; i++ ) {
522
+ for ( j = 0; j < 3; j++ ) {
523
+ o[i][j] = temp[i][j];
524
+ }
525
+ }
526
+ }
527
+
528
+ /** \endcond */
529
+
530
+ /****************************************************************************/
531
+ /* Kernel implementations of RNGs */
532
+ /****************************************************************************/
533
+
534
+ /* Test RNG */
535
+
536
+ QUALIFIERS void curand_init(unsigned long long seed,
537
+ unsigned long long subsequence,
538
+ unsigned long long offset,
539
+ curandStateTest_t *state)
540
+ {
541
+ state->v = (unsigned int)(seed * 3) + (unsigned int)(subsequence * 31337) + \
542
+ (unsigned int)offset;
543
+ }
544
+
545
+
546
+ QUALIFIERS unsigned int curand(curandStateTest_t *state)
547
+ {
548
+ unsigned int r = state->v++;
549
+ return r;
550
+ }
551
+
552
+ QUALIFIERS void skipahead(unsigned long long n, curandStateTest_t *state)
553
+ {
554
+ state->v += (unsigned int)n;
555
+ }
556
+
557
+ /* XORWOW RNG */
558
+
559
+ template <typename T, int n>
560
+ QUALIFIERS void __curand_generate_skipahead_matrix_xor(unsigned int matrix[])
561
+ {
562
+ T state;
563
+ // Generate matrix that advances one step
564
+ // matrix has n * n * 32 32-bit elements
565
+ // solve for matrix by stepping single bit states
566
+ for(int i = 0; i < 32 * n; i++) {
567
+ state.d = 0;
568
+ for(int j = 0; j < n; j++) {
569
+ state.v[j] = 0;
570
+ }
571
+ state.v[i / 32] = (1 << (i & 31));
572
+ curand(&state);
573
+ for(int j = 0; j < n; j++) {
574
+ matrix[i * n + j] = state.v[j];
575
+ }
576
+ }
577
+ }
578
+
579
+ template <typename T, int n>
580
+ QUALIFIERS void _skipahead_scratch(unsigned long long x, T *state, unsigned int *scratch)
581
+ {
582
+ // unsigned int matrix[n * n * 32];
583
+ unsigned int *matrix = scratch;
584
+ // unsigned int matrixA[n * n * 32];
585
+ unsigned int *matrixA = scratch + (n * n * 32);
586
+ // unsigned int vector[n];
587
+ unsigned int *vector = scratch + (n * n * 32) + (n * n * 32);
588
+ // unsigned int result[n];
589
+ unsigned int *result = scratch + (n * n * 32) + (n * n * 32) + n;
590
+ unsigned long long p = x;
591
+ for(int i = 0; i < n; i++) {
592
+ vector[i] = state->v[i];
593
+ }
594
+ int matrix_num = 0;
595
+ while(p && (matrix_num < PRECALC_NUM_MATRICES - 1)) {
596
+ for(unsigned int t = 0; t < (p & PRECALC_BLOCK_MASK); t++) {
597
+ #ifdef __CUDA_ARCH__
598
+ __curand_matvec(vector, precalc_xorwow_offset_matrix[matrix_num], result, n);
599
+ #else
600
+ __curand_matvec(vector, precalc_xorwow_offset_matrix_host[matrix_num], result, n);
601
+ #endif
602
+ __curand_veccopy(vector, result, n);
603
+ }
604
+ p >>= PRECALC_BLOCK_SIZE;
605
+ matrix_num++;
606
+ }
607
+ if(p) {
608
+ #ifdef __CUDA_ARCH__
609
+ __curand_matcopy(matrix, precalc_xorwow_offset_matrix[PRECALC_NUM_MATRICES - 1], n);
610
+ __curand_matcopy(matrixA, precalc_xorwow_offset_matrix[PRECALC_NUM_MATRICES - 1], n);
611
+ #else
612
+ __curand_matcopy(matrix, precalc_xorwow_offset_matrix_host[PRECALC_NUM_MATRICES - 1], n);
613
+ __curand_matcopy(matrixA, precalc_xorwow_offset_matrix_host[PRECALC_NUM_MATRICES - 1], n);
614
+ #endif
615
+ }
616
+ while(p) {
617
+ for(unsigned int t = 0; t < (p & SKIPAHEAD_MASK); t++) {
618
+ __curand_matvec(vector, matrixA, result, n);
619
+ __curand_veccopy(vector, result, n);
620
+ }
621
+ p >>= SKIPAHEAD_BLOCKSIZE;
622
+ if(p) {
623
+ for(int i = 0; i < SKIPAHEAD_BLOCKSIZE; i++) {
624
+ __curand_matmat(matrix, matrixA, n);
625
+ __curand_matcopy(matrixA, matrix, n);
626
+ }
627
+ }
628
+ }
629
+ for(int i = 0; i < n; i++) {
630
+ state->v[i] = vector[i];
631
+ }
632
+ state->d += 362437 * (unsigned int)x;
633
+ }
634
+
635
+ template <typename T, int n>
636
+ QUALIFIERS void _skipahead_sequence_scratch(unsigned long long x, T *state, unsigned int *scratch)
637
+ {
638
+ // unsigned int matrix[n * n * 32];
639
+ unsigned int *matrix = scratch;
640
+ // unsigned int matrixA[n * n * 32];
641
+ unsigned int *matrixA = scratch + (n * n * 32);
642
+ // unsigned int vector[n];
643
+ unsigned int *vector = scratch + (n * n * 32) + (n * n * 32);
644
+ // unsigned int result[n];
645
+ unsigned int *result = scratch + (n * n * 32) + (n * n * 32) + n;
646
+ unsigned long long p = x;
647
+ for(int i = 0; i < n; i++) {
648
+ vector[i] = state->v[i];
649
+ }
650
+ int matrix_num = 0;
651
+ while(p && matrix_num < PRECALC_NUM_MATRICES - 1) {
652
+ for(unsigned int t = 0; t < (p & PRECALC_BLOCK_MASK); t++) {
653
+ #ifdef __CUDA_ARCH__
654
+ __curand_matvec(vector, precalc_xorwow_matrix[matrix_num], result, n);
655
+ #else
656
+ __curand_matvec(vector, precalc_xorwow_matrix_host[matrix_num], result, n);
657
+ #endif
658
+ __curand_veccopy(vector, result, n);
659
+ }
660
+ p >>= PRECALC_BLOCK_SIZE;
661
+ matrix_num++;
662
+ }
663
+ if(p) {
664
+ #ifdef __CUDA_ARCH__
665
+ __curand_matcopy(matrix, precalc_xorwow_matrix[PRECALC_NUM_MATRICES - 1], n);
666
+ __curand_matcopy(matrixA, precalc_xorwow_matrix[PRECALC_NUM_MATRICES - 1], n);
667
+ #else
668
+ __curand_matcopy(matrix, precalc_xorwow_matrix_host[PRECALC_NUM_MATRICES - 1], n);
669
+ __curand_matcopy(matrixA, precalc_xorwow_matrix_host[PRECALC_NUM_MATRICES - 1], n);
670
+ #endif
671
+ }
672
+ while(p) {
673
+ for(unsigned int t = 0; t < (p & SKIPAHEAD_MASK); t++) {
674
+ __curand_matvec(vector, matrixA, result, n);
675
+ __curand_veccopy(vector, result, n);
676
+ }
677
+ p >>= SKIPAHEAD_BLOCKSIZE;
678
+ if(p) {
679
+ for(int i = 0; i < SKIPAHEAD_BLOCKSIZE; i++) {
680
+ __curand_matmat(matrix, matrixA, n);
681
+ __curand_matcopy(matrixA, matrix, n);
682
+ }
683
+ }
684
+ }
685
+ for(int i = 0; i < n; i++) {
686
+ state->v[i] = vector[i];
687
+ }
688
+ /* No update of state->d needed, guaranteed to be a multiple of 2^32 */
689
+ }
690
+
691
+ template <typename T, int N>
692
+ QUALIFIERS void _skipahead_inplace(const unsigned long long x, T *state)
693
+ {
694
+ unsigned long long p = x;
695
+ int matrix_num = 0;
696
+ while(p) {
697
+ for(unsigned int t = 0; t < (p & PRECALC_BLOCK_MASK); t++) {
698
+ #ifdef __CUDA_ARCH__
699
+ __curand_matvec_inplace<N>(state->v, precalc_xorwow_offset_matrix[matrix_num]);
700
+ #else
701
+ __curand_matvec_inplace<N>(state->v, precalc_xorwow_offset_matrix_host[matrix_num]);
702
+ #endif
703
+ }
704
+ p >>= PRECALC_BLOCK_SIZE;
705
+ matrix_num++;
706
+ }
707
+ state->d += 362437 * (unsigned int)x;
708
+ }
709
+
710
+ template <typename T, int N>
711
+ QUALIFIERS void _skipahead_sequence_inplace(unsigned long long x, T *state)
712
+ {
713
+ int matrix_num = 0;
714
+ while(x) {
715
+ for(unsigned int t = 0; t < (x & PRECALC_BLOCK_MASK); t++) {
716
+ #ifdef __CUDA_ARCH__
717
+ __curand_matvec_inplace<N>(state->v, precalc_xorwow_matrix[matrix_num]);
718
+ #else
719
+ __curand_matvec_inplace<N>(state->v, precalc_xorwow_matrix_host[matrix_num]);
720
+ #endif
721
+ }
722
+ x >>= PRECALC_BLOCK_SIZE;
723
+ matrix_num++;
724
+ }
725
+ /* No update of state->d needed, guaranteed to be a multiple of 2^32 */
726
+ }
727
+
728
+ /**
729
+ * \brief Update XORWOW state to skip \p n elements.
730
+ *
731
+ * Update the XORWOW state in \p state to skip ahead \p n elements.
732
+ *
733
+ * All values of \p n are valid. Large values require more computation and so
734
+ * will take more time to complete.
735
+ *
736
+ * \param n - Number of elements to skip
737
+ * \param state - Pointer to state to update
738
+ */
739
+ QUALIFIERS void skipahead(unsigned long long n, curandStateXORWOW_t *state)
740
+ {
741
+ _skipahead_inplace<curandStateXORWOW_t, 5>(n, state);
742
+ }
743
+
744
+ /**
745
+ * \brief Update XORWOW state to skip ahead \p n subsequences.
746
+ *
747
+ * Update the XORWOW state in \p state to skip ahead \p n subsequences. Each
748
+ * subsequence is \xmlonly<ph outputclass="xmlonly">2<sup>67</sup></ph>\endxmlonly elements long, so this means the function will skip ahead
749
+ * \xmlonly<ph outputclass="xmlonly">2<sup>67</sup></ph>\endxmlonly * n elements.
750
+ *
751
+ * All values of \p n are valid. Large values require more computation and so
752
+ * will take more time to complete.
753
+ *
754
+ * \param n - Number of subsequences to skip
755
+ * \param state - Pointer to state to update
756
+ */
757
+ QUALIFIERS void skipahead_sequence(unsigned long long n, curandStateXORWOW_t *state)
758
+ {
759
+ _skipahead_sequence_inplace<curandStateXORWOW_t, 5>(n, state);
760
+ }
761
+
762
+ QUALIFIERS void _curand_init_scratch(unsigned long long seed,
763
+ unsigned long long subsequence,
764
+ unsigned long long offset,
765
+ curandStateXORWOW_t *state,
766
+ unsigned int *scratch)
767
+ {
768
+ // Break up seed, apply salt
769
+ // Constants are arbitrary nonzero values
770
+ unsigned int s0 = ((unsigned int)seed) ^ 0xaad26b49UL;
771
+ unsigned int s1 = (unsigned int)(seed >> 32) ^ 0xf7dcefddUL;
772
+ // Simple multiplication to mix up bits
773
+ // Constants are arbitrary odd values
774
+ unsigned int t0 = 1099087573UL * s0;
775
+ unsigned int t1 = 2591861531UL * s1;
776
+ state->d = 6615241 + t1 + t0;
777
+ state->v[0] = 123456789UL + t0;
778
+ state->v[1] = 362436069UL ^ t0;
779
+ state->v[2] = 521288629UL + t1;
780
+ state->v[3] = 88675123UL ^ t1;
781
+ state->v[4] = 5783321UL + t0;
782
+ _skipahead_sequence_scratch<curandStateXORWOW_t, 5>(subsequence, state, scratch);
783
+ _skipahead_scratch<curandStateXORWOW_t, 5>(offset, state, scratch);
784
+ state->boxmuller_flag = 0;
785
+ state->boxmuller_flag_double = 0;
786
+ state->boxmuller_extra = 0.f;
787
+ state->boxmuller_extra_double = 0.;
788
+ }
789
+
790
+ QUALIFIERS void _curand_init_inplace(unsigned long long seed,
791
+ unsigned long long subsequence,
792
+ unsigned long long offset,
793
+ curandStateXORWOW_t *state)
794
+ {
795
+ // Break up seed, apply salt
796
+ // Constants are arbitrary nonzero values
797
+ unsigned int s0 = ((unsigned int)seed) ^ 0xaad26b49UL;
798
+ unsigned int s1 = (unsigned int)(seed >> 32) ^ 0xf7dcefddUL;
799
+ // Simple multiplication to mix up bits
800
+ // Constants are arbitrary odd values
801
+ unsigned int t0 = 1099087573UL * s0;
802
+ unsigned int t1 = 2591861531UL * s1;
803
+ state->d = 6615241 + t1 + t0;
804
+ state->v[0] = 123456789UL + t0;
805
+ state->v[1] = 362436069UL ^ t0;
806
+ state->v[2] = 521288629UL + t1;
807
+ state->v[3] = 88675123UL ^ t1;
808
+ state->v[4] = 5783321UL + t0;
809
+ _skipahead_sequence_inplace<curandStateXORWOW_t, 5>(subsequence, state);
810
+ _skipahead_inplace<curandStateXORWOW_t, 5>(offset, state);
811
+ state->boxmuller_flag = 0;
812
+ state->boxmuller_flag_double = 0;
813
+ state->boxmuller_extra = 0.f;
814
+ state->boxmuller_extra_double = 0.;
815
+ }
816
+
817
+ /**
818
+ * \brief Initialize XORWOW state.
819
+ *
820
+ * Initialize XORWOW state in \p state with the given \p seed, \p subsequence,
821
+ * and \p offset.
822
+ *
823
+ * All input values of \p seed, \p subsequence, and \p offset are legal. Large
824
+ * values for \p subsequence and \p offset require more computation and so will
825
+ * take more time to complete.
826
+ *
827
+ * A value of 0 for \p seed sets the state to the values of the original
828
+ * published version of the \p xorwow algorithm.
829
+ *
830
+ * \param seed - Arbitrary bits to use as a seed
831
+ * \param subsequence - Subsequence to start at
832
+ * \param offset - Absolute offset into sequence
833
+ * \param state - Pointer to state to initialize
834
+ */
835
+ QUALIFIERS void curand_init(unsigned long long seed,
836
+ unsigned long long subsequence,
837
+ unsigned long long offset,
838
+ curandStateXORWOW_t *state)
839
+ {
840
+ _curand_init_inplace(seed, subsequence, offset, state);
841
+ }
842
+
843
+ /**
844
+ * \brief Return 32-bits of pseudorandomness from an XORWOW generator.
845
+ *
846
+ * Return 32-bits of pseudorandomness from the XORWOW generator in \p state,
847
+ * increment position of generator by one.
848
+ *
849
+ * \param state - Pointer to state to update
850
+ *
851
+ * \return 32-bits of pseudorandomness as an unsigned int, all bits valid to use.
852
+ */
853
+ QUALIFIERS unsigned int curand(curandStateXORWOW_t *state)
854
+ {
855
+ unsigned int t;
856
+ t = (state->v[0] ^ (state->v[0] >> 2));
857
+ state->v[0] = state->v[1];
858
+ state->v[1] = state->v[2];
859
+ state->v[2] = state->v[3];
860
+ state->v[3] = state->v[4];
861
+ state->v[4] = (state->v[4] ^ (state->v[4] <<4)) ^ (t ^ (t << 1));
862
+ state->d += 362437;
863
+ return state->v[4] + state->d;
864
+ }
865
+
866
+
867
+ /**
868
+ * \brief Return 32-bits of pseudorandomness from an Philox4_32_10 generator.
869
+ *
870
+ * Return 32-bits of pseudorandomness from the Philox4_32_10 generator in \p state,
871
+ * increment position of generator by one.
872
+ *
873
+ * \param state - Pointer to state to update
874
+ *
875
+ * \return 32-bits of pseudorandomness as an unsigned int, all bits valid to use.
876
+ */
877
+
878
+ QUALIFIERS unsigned int curand(curandStatePhilox4_32_10_t *state)
879
+ {
880
+ // Maintain the invariant: output[STATE] is always "good" and
881
+ // is the next value to be returned by curand.
882
+ unsigned int ret;
883
+ switch(state->STATE++){
884
+ default:
885
+ ret = state->output.x;
886
+ break;
887
+ case 1:
888
+ ret = state->output.y;
889
+ break;
890
+ case 2:
891
+ ret = state->output.z;
892
+ break;
893
+ case 3:
894
+ ret = state->output.w;
895
+ break;
896
+ }
897
+ if(state->STATE == 4){
898
+ Philox_State_Incr(state);
899
+ state->output = curand_Philox4x32_10(state->ctr,state->key);
900
+ state->STATE = 0;
901
+ }
902
+ return ret;
903
+ }
904
+
905
+ /**
906
+ * \brief Return tuple of 4 32-bit pseudorandoms from a Philox4_32_10 generator.
907
+ *
908
+ * Return 128 bits of pseudorandomness from the Philox4_32_10 generator in \p state,
909
+ * increment position of generator by four.
910
+ *
911
+ * \param state - Pointer to state to update
912
+ *
913
+ * \return 128-bits of pseudorandomness as a uint4, all bits valid to use.
914
+ */
915
+
916
+ QUALIFIERS uint4 curand4(curandStatePhilox4_32_10_t *state)
917
+ {
918
+ uint4 r;
919
+
920
+ uint4 tmp = state->output;
921
+ Philox_State_Incr(state);
922
+ state->output= curand_Philox4x32_10(state->ctr,state->key);
923
+ switch(state->STATE){
924
+ case 0:
925
+ return tmp;
926
+ case 1:
927
+ r.x = tmp.y;
928
+ r.y = tmp.z;
929
+ r.z = tmp.w;
930
+ r.w = state->output.x;
931
+ break;
932
+ case 2:
933
+ r.x = tmp.z;
934
+ r.y = tmp.w;
935
+ r.z = state->output.x;
936
+ r.w = state->output.y;
937
+ break;
938
+ case 3:
939
+ r.x = tmp.w;
940
+ r.y = state->output.x;
941
+ r.z = state->output.y;
942
+ r.w = state->output.z;
943
+ break;
944
+ default:
945
+ // NOT possible but needed to avoid compiler warnings
946
+ return tmp;
947
+ }
948
+ return r;
949
+ }
950
+
951
+ /**
952
+ * \brief Update Philox4_32_10 state to skip \p n elements.
953
+ *
954
+ * Update the Philox4_32_10 state in \p state to skip ahead \p n elements.
955
+ *
956
+ * All values of \p n are valid.
957
+ *
958
+ * \param n - Number of elements to skip
959
+ * \param state - Pointer to state to update
960
+ */
961
+ QUALIFIERS void skipahead(unsigned long long n, curandStatePhilox4_32_10_t *state)
962
+ {
963
+ state->STATE += (n & 3);
964
+ n /= 4;
965
+ if( state->STATE > 3 ){
966
+ n += 1;
967
+ state->STATE -= 4;
968
+ }
969
+ Philox_State_Incr(state, n);
970
+ state->output = curand_Philox4x32_10(state->ctr,state->key);
971
+ }
972
+
973
+ /**
974
+ * \brief Update Philox4_32_10 state to skip ahead \p n subsequences.
975
+ *
976
+ * Update the Philox4_32_10 state in \p state to skip ahead \p n subsequences. Each
977
+ * subsequence is \xmlonly<ph outputclass="xmlonly">2<sup>66</sup></ph>\endxmlonly elements long, so this means the function will skip ahead
978
+ * \xmlonly<ph outputclass="xmlonly">2<sup>66</sup></ph>\endxmlonly * n elements.
979
+ *
980
+ * All values of \p n are valid.
981
+ *
982
+ * \param n - Number of subsequences to skip
983
+ * \param state - Pointer to state to update
984
+ */
985
+ QUALIFIERS void skipahead_sequence(unsigned long long n, curandStatePhilox4_32_10_t *state)
986
+ {
987
+ Philox_State_Incr_hi(state, n);
988
+ state->output = curand_Philox4x32_10(state->ctr,state->key);
989
+ }
990
+
991
+ /**
992
+ * \brief Initialize Philox4_32_10 state.
993
+ *
994
+ * Initialize Philox4_32_10 state in \p state with the given \p seed, p\ subsequence,
995
+ * and \p offset.
996
+ *
997
+ * All input values for \p seed, \p subseqence and \p offset are legal. Each of the
998
+ * \xmlonly<ph outputclass="xmlonly">2<sup>64</sup></ph>\endxmlonly possible
999
+ * values of seed selects an independent sequence of length
1000
+ * \xmlonly<ph outputclass="xmlonly">2<sup>130</sup></ph>\endxmlonly.
1001
+ * The first
1002
+ * \xmlonly<ph outputclass="xmlonly">2<sup>66</sup> * subsequence + offset</ph>\endxmlonly.
1003
+ * values of the sequence are skipped.
1004
+ * I.e., subsequences are of length
1005
+ * \xmlonly<ph outputclass="xmlonly">2<sup>66</sup></ph>\endxmlonly.
1006
+ *
1007
+ * \param seed - Arbitrary bits to use as a seed
1008
+ * \param subsequence - Subsequence to start at
1009
+ * \param offset - Absolute offset into subsequence
1010
+ * \param state - Pointer to state to initialize
1011
+ */
1012
+ QUALIFIERS void curand_init(unsigned long long seed,
1013
+ unsigned long long subsequence,
1014
+ unsigned long long offset,
1015
+ curandStatePhilox4_32_10_t *state)
1016
+ {
1017
+ state->ctr = make_uint4(0, 0, 0, 0);
1018
+ state->key.x = (unsigned int)seed;
1019
+ state->key.y = (unsigned int)(seed>>32);
1020
+ state->STATE = 0;
1021
+ state->boxmuller_flag = 0;
1022
+ state->boxmuller_flag_double = 0;
1023
+ state->boxmuller_extra = 0.f;
1024
+ state->boxmuller_extra_double = 0.;
1025
+ skipahead_sequence(subsequence, state);
1026
+ skipahead(offset, state);
1027
+ }
1028
+
1029
+
1030
+ /* MRG32k3a RNG */
1031
+
1032
+ /* Base generator for MRG32k3a */
1033
+ #if __CUDA_ARCH__ > 600
1034
+ QUALIFIERS unsigned long long __curand_umad(unsigned int a, unsigned int b, unsigned long long c)
1035
+ {
1036
+ unsigned long long r;
1037
+ asm("mad.wide.u32 %0, %1, %2, %3;"
1038
+ : "=l"(r) : "r"(a), "r"(b), "l"(c));
1039
+ return r;
1040
+ }
1041
+ QUALIFIERS unsigned long long __curand_umul(unsigned int a, unsigned int b)
1042
+ {
1043
+ unsigned long long r;
1044
+ asm("mul.wide.u32 %0, %1, %2;"
1045
+ : "=l"(r) : "r"(a), "r"(b));
1046
+ return r;
1047
+ }
1048
+
1049
+ QUALIFIERS double curand_MRG32k3a (curandStateMRG32k3a_t *state)
1050
+ {
1051
+ const unsigned int m1 = 4294967087u;
1052
+ const unsigned int m2 = 4294944443u;
1053
+ const unsigned int m1c = 209u;
1054
+ const unsigned int m2c = 22853u;
1055
+ const unsigned int a12 = 1403580u;
1056
+ const unsigned int a13n = 810728u;
1057
+ const unsigned int a21 = 527612u;
1058
+ const unsigned int a23n = 1370589u;
1059
+
1060
+ unsigned long long p1, p2;
1061
+ const unsigned long long p3 = __curand_umul(a13n, m1 - state->s1[0]);
1062
+ p1 = __curand_umad(a12, state->s1[1], p3);
1063
+
1064
+ // Putting addition inside and changing umul to umad
1065
+ // slowed this function down on GV100
1066
+ p1 = __curand_umul(p1 >> 32, m1c) + (p1 & 0xffffffff);
1067
+ if (p1 >= m1) p1 -= m1;
1068
+
1069
+ state->s1[0] = state->s1[1]; state->s1[1] = state->s1[2]; state->s1[2] = p1;
1070
+ const unsigned long long p4 = __curand_umul(a23n, m2 - state->s2[0]);
1071
+ p2 = __curand_umad(a21, state->s2[2], p4);
1072
+
1073
+ // Putting addition inside and changing umul to umad
1074
+ // slowed this function down on GV100
1075
+ p2 = __curand_umul(p2 >> 32, m2c) + (p2 & 0xffffffff);
1076
+ p2 = __curand_umul(p2 >> 32, m2c) + (p2 & 0xffffffff);
1077
+ if (p2 >= m2) p2 -= m2;
1078
+
1079
+ state->s2[0] = state->s2[1]; state->s2[1] = state->s2[2]; state->s2[2] = p2;
1080
+
1081
+ const unsigned int p5 = (unsigned int)p1 - (unsigned int)p2;
1082
+ if(p1 <= p2) return p5 + m1;
1083
+ return p5;
1084
+ }
1085
+ #elif __CUDA_ARCH__ > 0
1086
+ /* nj's implementation */
1087
+ QUALIFIERS double curand_MRG32k3a (curandStateMRG32k3a_t *state)
1088
+ {
1089
+ const double m1 = 4294967087.;
1090
+ const double m2 = 4294944443.;
1091
+ const double a12 = 1403580.;
1092
+ const double a13n = 810728.;
1093
+ const double a21 = 527612.;
1094
+ const double a23n = 1370589.;
1095
+
1096
+ const double rh1 = 2.3283065498378290e-010; /* (1.0 / m1)__hi */
1097
+ const double rl1 = -1.7354913086174288e-026; /* (1.0 / m1)__lo */
1098
+ const double rh2 = 2.3283188252407387e-010; /* (1.0 / m2)__hi */
1099
+ const double rl2 = 2.4081018096503646e-026; /* (1.0 / m2)__lo */
1100
+
1101
+ double q, p1, p2;
1102
+ p1 = a12 * state->s1[1] - a13n * state->s1[0];
1103
+ q = trunc (fma (p1, rh1, p1 * rl1));
1104
+ p1 -= q * m1;
1105
+ if (p1 < 0.0) p1 += m1;
1106
+ state->s1[0] = state->s1[1]; state->s1[1] = state->s1[2]; state->s1[2] = (unsigned int)p1;
1107
+ p2 = a21 * state->s2[2] - a23n * state->s2[0];
1108
+ q = trunc (fma (p2, rh2, p2 * rl2));
1109
+ p2 -= q * m2;
1110
+ if (p2 < 0.0) p2 += m2;
1111
+ state->s2[0] = state->s2[1]; state->s2[1] = state->s2[2]; state->s2[2] = (unsigned int)p2;
1112
+ if (p1 <= p2) return (p1 - p2 + m1);
1113
+ else return (p1 - p2);
1114
+ }
1115
+ /* end nj's implementation */
1116
+ #else
1117
+ QUALIFIERS double curand_MRG32k3a(curandStateMRG32k3a_t *state)
1118
+ {
1119
+ double p1,p2,r;
1120
+ p1 = (MRG32K3A_A12 * state->s1[1]) - (MRG32K3A_A13N * state->s1[0]);
1121
+ p1 = curand_MRGmod(p1, MRG32K3A_MOD1);
1122
+ if (p1 < 0.0) p1 += MRG32K3A_MOD1;
1123
+ state->s1[0] = state->s1[1];
1124
+ state->s1[1] = state->s1[2];
1125
+ state->s1[2] = (unsigned int)p1;
1126
+ p2 = (MRG32K3A_A21 * state->s2[2]) - (MRG32K3A_A23N * state->s2[0]);
1127
+ p2 = curand_MRGmod(p2, MRG32K3A_MOD2);
1128
+ if (p2 < 0) p2 += MRG32K3A_MOD2;
1129
+ state->s2[0] = state->s2[1];
1130
+ state->s2[1] = state->s2[2];
1131
+ state->s2[2] = (unsigned int)p2;
1132
+ r = p1 - p2;
1133
+ if (r <= 0) r += MRG32K3A_MOD1;
1134
+ return r;
1135
+ }
1136
+ #endif
1137
+
1138
+
1139
+ /**
1140
+ * \brief Return 32-bits of pseudorandomness from an MRG32k3a generator.
1141
+ *
1142
+ * Return 32-bits of pseudorandomness from the MRG32k3a generator in \p state,
1143
+ * increment position of generator by one.
1144
+ *
1145
+ * \param state - Pointer to state to update
1146
+ *
1147
+ * \return 32-bits of pseudorandomness as an unsigned int, all bits valid to use.
1148
+ */
1149
+ QUALIFIERS unsigned int curand(curandStateMRG32k3a_t *state)
1150
+ {
1151
+ double dRet;
1152
+ dRet = (double)curand_MRG32k3a(state)*(double)MRG32K3A_BITS_NORM;
1153
+ return (unsigned int)dRet;
1154
+ }
1155
+
1156
+
1157
+
1158
+ /**
1159
+ * \brief Update MRG32k3a state to skip \p n elements.
1160
+ *
1161
+ * Update the MRG32k3a state in \p state to skip ahead \p n elements.
1162
+ *
1163
+ * All values of \p n are valid. Large values require more computation and so
1164
+ * will take more time to complete.
1165
+ *
1166
+ * \param n - Number of elements to skip
1167
+ * \param state - Pointer to state to update
1168
+ */
1169
+ QUALIFIERS void skipahead(unsigned long long n, curandStateMRG32k3a_t *state)
1170
+ {
1171
+ unsigned int t[3][3];
1172
+ #ifdef __CUDA_ARCH__
1173
+ curand_MRGmatPow3x3( mrg32k3aM1, t, MRG32K3A_MOD1, n);
1174
+ curand_MRGmatVecMul3x3( t, state->s1, MRG32K3A_MOD1);
1175
+ curand_MRGmatPow3x3(mrg32k3aM2, t, MRG32K3A_MOD2, n);
1176
+ curand_MRGmatVecMul3x3( t, state->s2, MRG32K3A_MOD2);
1177
+ #else
1178
+ curand_MRGmatPow3x3( mrg32k3aM1Host, t, MRG32K3A_MOD1, n);
1179
+ curand_MRGmatVecMul3x3( t, state->s1, MRG32K3A_MOD1);
1180
+ curand_MRGmatPow3x3(mrg32k3aM2Host, t, MRG32K3A_MOD2, n);
1181
+ curand_MRGmatVecMul3x3( t, state->s2, MRG32K3A_MOD2);
1182
+ #endif
1183
+ }
1184
+
1185
+ /**
1186
+ * \brief Update MRG32k3a state to skip ahead \p n subsequences.
1187
+ *
1188
+ * Update the MRG32k3a state in \p state to skip ahead \p n subsequences. Each
1189
+ * subsequence is \xmlonly<ph outputclass="xmlonly">2<sup>127</sup></ph>\endxmlonly
1190
+ *
1191
+ * \xmlonly<ph outputclass="xmlonly">2<sup>76</sup></ph>\endxmlonly elements long, so this means the function will skip ahead
1192
+ * \xmlonly<ph outputclass="xmlonly">2<sup>67</sup></ph>\endxmlonly * n elements.
1193
+ *
1194
+ * Valid values of \p n are 0 to \xmlonly<ph outputclass="xmlonly">2<sup>51</sup></ph>\endxmlonly. Note \p n will be masked to 51 bits
1195
+ *
1196
+ * \param n - Number of subsequences to skip
1197
+ * \param state - Pointer to state to update
1198
+ */
1199
+ QUALIFIERS void skipahead_subsequence(unsigned long long n, curandStateMRG32k3a_t *state)
1200
+ {
1201
+ unsigned int t[3][3];
1202
+ #ifdef __CUDA_ARCH__
1203
+ curand_MRGmatPow3x3( mrg32k3aM1SubSeq, t, MRG32K3A_MOD1, n);
1204
+ curand_MRGmatVecMul3x3( t, state->s1, MRG32K3A_MOD1);
1205
+ curand_MRGmatPow3x3( mrg32k3aM2SubSeq, t, MRG32K3A_MOD2, n);
1206
+ curand_MRGmatVecMul3x3( t, state->s2, MRG32K3A_MOD2);
1207
+ #else
1208
+ curand_MRGmatPow3x3( mrg32k3aM1SubSeqHost, t, MRG32K3A_MOD1, n);
1209
+ curand_MRGmatVecMul3x3( t, state->s1, MRG32K3A_MOD1);
1210
+ curand_MRGmatPow3x3( mrg32k3aM2SubSeqHost, t, MRG32K3A_MOD2, n);
1211
+ curand_MRGmatVecMul3x3( t, state->s2, MRG32K3A_MOD2);
1212
+ #endif
1213
+ }
1214
+
1215
+ /**
1216
+ * \brief Update MRG32k3a state to skip ahead \p n sequences.
1217
+ *
1218
+ * Update the MRG32k3a state in \p state to skip ahead \p n sequences. Each
1219
+ * sequence is \xmlonly<ph outputclass="xmlonly">2<sup>127</sup></ph>\endxmlonly elements long, so this means the function will skip ahead
1220
+ * \xmlonly<ph outputclass="xmlonly">2<sup>127</sup></ph>\endxmlonly * n elements.
1221
+ *
1222
+ * All values of \p n are valid. Large values require more computation and so
1223
+ * will take more time to complete.
1224
+ *
1225
+ * \param n - Number of sequences to skip
1226
+ * \param state - Pointer to state to update
1227
+ */
1228
+ QUALIFIERS void skipahead_sequence(unsigned long long n, curandStateMRG32k3a_t *state)
1229
+ {
1230
+ unsigned int t[3][3];
1231
+ #ifdef __CUDA_ARCH__
1232
+ curand_MRGmatPow3x3( mrg32k3aM1Seq, t, MRG32K3A_MOD1, n);
1233
+ curand_MRGmatVecMul3x3( t, state->s1, MRG32K3A_MOD1);
1234
+ curand_MRGmatPow3x3( mrg32k3aM2Seq, t, MRG32K3A_MOD2, n);
1235
+ curand_MRGmatVecMul3x3( t, state->s2, MRG32K3A_MOD2);
1236
+ #else
1237
+ curand_MRGmatPow3x3( mrg32k3aM1SeqHost, t, MRG32K3A_MOD1, n);
1238
+ curand_MRGmatVecMul3x3( t, state->s1, MRG32K3A_MOD1);
1239
+ curand_MRGmatPow3x3( mrg32k3aM2SeqHost, t, MRG32K3A_MOD2, n);
1240
+ curand_MRGmatVecMul3x3( t, state->s2, MRG32K3A_MOD2);
1241
+ #endif
1242
+ }
1243
+
1244
+
1245
+ /**
1246
+ * \brief Initialize MRG32k3a state.
1247
+ *
1248
+ * Initialize MRG32k3a state in \p state with the given \p seed, \p subsequence,
1249
+ * and \p offset.
1250
+ *
1251
+ * All input values of \p seed, \p subsequence, and \p offset are legal.
1252
+ * \p subsequence will be truncated to 51 bits to avoid running into the next sequence
1253
+ *
1254
+ * A value of 0 for \p seed sets the state to the values of the original
1255
+ * published version of the \p MRG32k3a algorithm.
1256
+ *
1257
+ * \param seed - Arbitrary bits to use as a seed
1258
+ * \param subsequence - Subsequence to start at
1259
+ * \param offset - Absolute offset into sequence
1260
+ * \param state - Pointer to state to initialize
1261
+ */
1262
+ QUALIFIERS void curand_init(unsigned long long seed,
1263
+ unsigned long long subsequence,
1264
+ unsigned long long offset,
1265
+ curandStateMRG32k3a_t *state)
1266
+ {
1267
+ int i;
1268
+ for ( i=0; i<3; i++ ) {
1269
+ state->s1[i] = 12345u;
1270
+ state->s2[i] = 12345u;
1271
+ }
1272
+ if (seed != 0ull) {
1273
+ unsigned int x1 = ((unsigned int)seed) ^ 0x55555555UL;
1274
+ unsigned int x2 = (unsigned int)((seed >> 32) ^ 0xAAAAAAAAUL);
1275
+ state->s1[0] = (unsigned int)curand_MRGmodMul(x1, state->s1[0], MRG32K3A_MOD1);
1276
+ state->s1[1] = (unsigned int)curand_MRGmodMul(x2, state->s1[1], MRG32K3A_MOD1);
1277
+ state->s1[2] = (unsigned int)curand_MRGmodMul(x1, state->s1[2], MRG32K3A_MOD1);
1278
+ state->s2[0] = (unsigned int)curand_MRGmodMul(x2, state->s2[0], MRG32K3A_MOD2);
1279
+ state->s2[1] = (unsigned int)curand_MRGmodMul(x1, state->s2[1], MRG32K3A_MOD2);
1280
+ state->s2[2] = (unsigned int)curand_MRGmodMul(x2, state->s2[2], MRG32K3A_MOD2);
1281
+ }
1282
+ skipahead_subsequence( subsequence, state );
1283
+ skipahead( offset, state );
1284
+ state->boxmuller_flag = 0;
1285
+ state->boxmuller_flag_double = 0;
1286
+ state->boxmuller_extra = 0.f;
1287
+ state->boxmuller_extra_double = 0.;
1288
+ }
1289
+
1290
+ /**
1291
+ * \brief Update Sobol32 state to skip \p n elements.
1292
+ *
1293
+ * Update the Sobol32 state in \p state to skip ahead \p n elements.
1294
+ *
1295
+ * All values of \p n are valid.
1296
+ *
1297
+ * \param n - Number of elements to skip
1298
+ * \param state - Pointer to state to update
1299
+ */
1300
+ template <typename T>
1301
+ QUALIFIERS
1302
+ typename CURAND_STD::enable_if<CURAND_STD::is_same<curandStateSobol32_t*, T>::value || CURAND_STD::is_same<curandStateScrambledSobol32_t*, T>::value>::type
1303
+ skipahead(unsigned int n, T state)
1304
+ {
1305
+ unsigned int i_gray;
1306
+ state->x = state->c;
1307
+ state->i += n;
1308
+ /* Convert state->i to gray code */
1309
+ i_gray = state->i ^ (state->i >> 1);
1310
+ for(unsigned int k = 0; k < 32; k++) {
1311
+ if(i_gray & (1 << k)) {
1312
+ state->x ^= state->direction_vectors[k];
1313
+ }
1314
+ }
1315
+ return;
1316
+ }
1317
+
1318
+ /**
1319
+ * \brief Update Sobol64 state to skip \p n elements.
1320
+ *
1321
+ * Update the Sobol64 state in \p state to skip ahead \p n elements.
1322
+ *
1323
+ * All values of \p n are valid.
1324
+ *
1325
+ * \param n - Number of elements to skip
1326
+ * \param state - Pointer to state to update
1327
+ */
1328
+ template <typename T>
1329
+ QUALIFIERS
1330
+ typename CURAND_STD::enable_if<CURAND_STD::is_same<curandStateSobol64_t*, T>::value || CURAND_STD::is_same<curandStateScrambledSobol64_t*, T>::value>::type
1331
+ skipahead(unsigned long long n, T state)
1332
+ {
1333
+ unsigned long long i_gray;
1334
+ state->x = state->c;
1335
+ state->i += n;
1336
+ /* Convert state->i to gray code */
1337
+ i_gray = state->i ^ (state->i >> 1);
1338
+ for(unsigned k = 0; k < 64; k++) {
1339
+ if(i_gray & (1ULL << k)) {
1340
+ state->x ^= state->direction_vectors[k];
1341
+ }
1342
+ }
1343
+ return;
1344
+ }
1345
+
1346
+ /**
1347
+ * \brief Initialize Sobol32 state.
1348
+ *
1349
+ * Initialize Sobol32 state in \p state with the given \p direction \p vectors and
1350
+ * \p offset.
1351
+ *
1352
+ * The direction vector is a device pointer to an array of 32 unsigned ints.
1353
+ * All input values of \p offset are legal.
1354
+ *
1355
+ * \param direction_vectors - Pointer to array of 32 unsigned ints representing the
1356
+ * direction vectors for the desired dimension
1357
+ * \param offset - Absolute offset into sequence
1358
+ * \param state - Pointer to state to initialize
1359
+ */
1360
+ QUALIFIERS void curand_init(curandDirectionVectors32_t direction_vectors,
1361
+ unsigned int offset,
1362
+ curandStateSobol32_t *state)
1363
+ {
1364
+ state->i = 0;
1365
+ state->c = 0;
1366
+ for(int i = 0; i < 32; i++) {
1367
+ state->direction_vectors[i] = direction_vectors[i];
1368
+ }
1369
+ state->x = 0;
1370
+ skipahead<curandStateSobol32_t *>(offset, state);
1371
+ }
1372
+ /**
1373
+ * \brief Initialize Scrambled Sobol32 state.
1374
+ *
1375
+ * Initialize Sobol32 state in \p state with the given \p direction \p vectors and
1376
+ * \p offset.
1377
+ *
1378
+ * The direction vector is a device pointer to an array of 32 unsigned ints.
1379
+ * All input values of \p offset are legal.
1380
+ *
1381
+ * \param direction_vectors - Pointer to array of 32 unsigned ints representing the
1382
+ direction vectors for the desired dimension
1383
+ * \param scramble_c Scramble constant
1384
+ * \param offset - Absolute offset into sequence
1385
+ * \param state - Pointer to state to initialize
1386
+ */
1387
+ QUALIFIERS void curand_init(curandDirectionVectors32_t direction_vectors,
1388
+ unsigned int scramble_c,
1389
+ unsigned int offset,
1390
+ curandStateScrambledSobol32_t *state)
1391
+ {
1392
+ state->i = 0;
1393
+ state->c = scramble_c;
1394
+ for(int i = 0; i < 32; i++) {
1395
+ state->direction_vectors[i] = direction_vectors[i];
1396
+ }
1397
+ state->x = state->c;
1398
+ skipahead<curandStateScrambledSobol32_t *>(offset, state);
1399
+ }
1400
+
1401
+ QUALIFIERS int __curand_find_trailing_zero(unsigned int x)
1402
+ {
1403
+ #if __CUDA_ARCH__ > 0
1404
+ int y = __ffs(~x);
1405
+ if(y)
1406
+ return y - 1;
1407
+ return 31;
1408
+ #else
1409
+ int i = 1;
1410
+ while(x & 1) {
1411
+ i++;
1412
+ x >>= 1;
1413
+ }
1414
+ i = i - 1;
1415
+ return i == 32 ? 31 : i;
1416
+ #endif
1417
+ }
1418
+
1419
+ QUALIFIERS int __curand_find_trailing_zero(unsigned long long x)
1420
+ {
1421
+ #if __CUDA_ARCH__ > 0
1422
+ int y = __ffsll(~x);
1423
+ if(y)
1424
+ return y - 1;
1425
+ return 63;
1426
+ #else
1427
+ int i = 1;
1428
+ while(x & 1) {
1429
+ i++;
1430
+ x >>= 1;
1431
+ }
1432
+ i = i - 1;
1433
+ return i == 64 ? 63 : i;
1434
+ #endif
1435
+ }
1436
+
1437
+ /**
1438
+ * \brief Initialize Sobol64 state.
1439
+ *
1440
+ * Initialize Sobol64 state in \p state with the given \p direction \p vectors and
1441
+ * \p offset.
1442
+ *
1443
+ * The direction vector is a device pointer to an array of 64 unsigned long longs.
1444
+ * All input values of \p offset are legal.
1445
+ *
1446
+ * \param direction_vectors - Pointer to array of 64 unsigned long longs representing the
1447
+ direction vectors for the desired dimension
1448
+ * \param offset - Absolute offset into sequence
1449
+ * \param state - Pointer to state to initialize
1450
+ */
1451
+ QUALIFIERS void curand_init(curandDirectionVectors64_t direction_vectors,
1452
+ unsigned long long offset,
1453
+ curandStateSobol64_t *state)
1454
+ {
1455
+ state->i = 0;
1456
+ state->c = 0;
1457
+ for(int i = 0; i < 64; i++) {
1458
+ state->direction_vectors[i] = direction_vectors[i];
1459
+ }
1460
+ state->x = 0;
1461
+ skipahead<curandStateSobol64_t *>(offset, state);
1462
+ }
1463
+
1464
+ /**
1465
+ * \brief Initialize Scrambled Sobol64 state.
1466
+ *
1467
+ * Initialize Sobol64 state in \p state with the given \p direction \p vectors and
1468
+ * \p offset.
1469
+ *
1470
+ * The direction vector is a device pointer to an array of 64 unsigned long longs.
1471
+ * All input values of \p offset are legal.
1472
+ *
1473
+ * \param direction_vectors - Pointer to array of 64 unsigned long longs representing the
1474
+ direction vectors for the desired dimension
1475
+ * \param scramble_c Scramble constant
1476
+ * \param offset - Absolute offset into sequence
1477
+ * \param state - Pointer to state to initialize
1478
+ */
1479
+ QUALIFIERS void curand_init(curandDirectionVectors64_t direction_vectors,
1480
+ unsigned long long scramble_c,
1481
+ unsigned long long offset,
1482
+ curandStateScrambledSobol64_t *state)
1483
+ {
1484
+ state->i = 0;
1485
+ state->c = scramble_c;
1486
+ for(int i = 0; i < 64; i++) {
1487
+ state->direction_vectors[i] = direction_vectors[i];
1488
+ }
1489
+ state->x = state->c;
1490
+ skipahead<curandStateScrambledSobol64_t *>(offset, state);
1491
+ }
1492
+
1493
+ /**
1494
+ * \brief Return 32-bits of quasirandomness from a Sobol32 generator.
1495
+ *
1496
+ * Return 32-bits of quasirandomness from the Sobol32 generator in \p state,
1497
+ * increment position of generator by one.
1498
+ *
1499
+ * \param state - Pointer to state to update
1500
+ *
1501
+ * \return 32-bits of quasirandomness as an unsigned int, all bits valid to use.
1502
+ */
1503
+
1504
+ QUALIFIERS unsigned int curand(curandStateSobol32_t * state)
1505
+ {
1506
+ /* Moving from i to i+1 element in gray code is flipping one bit,
1507
+ the trailing zero bit of i
1508
+ */
1509
+ unsigned int res = state->x;
1510
+ state->x ^= state->direction_vectors[__curand_find_trailing_zero(state->i)];
1511
+ state->i ++;
1512
+ return res;
1513
+ }
1514
+
1515
+ /**
1516
+ * \brief Return 32-bits of quasirandomness from a scrambled Sobol32 generator.
1517
+ *
1518
+ * Return 32-bits of quasirandomness from the scrambled Sobol32 generator in \p state,
1519
+ * increment position of generator by one.
1520
+ *
1521
+ * \param state - Pointer to state to update
1522
+ *
1523
+ * \return 32-bits of quasirandomness as an unsigned int, all bits valid to use.
1524
+ */
1525
+
1526
+ QUALIFIERS unsigned int curand(curandStateScrambledSobol32_t * state)
1527
+ {
1528
+ /* Moving from i to i+1 element in gray code is flipping one bit,
1529
+ the trailing zero bit of i
1530
+ */
1531
+ unsigned int res = state->x;
1532
+ state->x ^= state->direction_vectors[__curand_find_trailing_zero(state->i)];
1533
+ state->i ++;
1534
+ return res;
1535
+ }
1536
+
1537
+ /**
1538
+ * \brief Return 64-bits of quasirandomness from a Sobol64 generator.
1539
+ *
1540
+ * Return 64-bits of quasirandomness from the Sobol64 generator in \p state,
1541
+ * increment position of generator by one.
1542
+ *
1543
+ * \param state - Pointer to state to update
1544
+ *
1545
+ * \return 64-bits of quasirandomness as an unsigned long long, all bits valid to use.
1546
+ */
1547
+
1548
+ QUALIFIERS unsigned long long curand(curandStateSobol64_t * state)
1549
+ {
1550
+ /* Moving from i to i+1 element in gray code is flipping one bit,
1551
+ the trailing zero bit of i
1552
+ */
1553
+ unsigned long long res = state->x;
1554
+ state->x ^= state->direction_vectors[__curand_find_trailing_zero(state->i)];
1555
+ state->i ++;
1556
+ return res;
1557
+ }
1558
+
1559
+ /**
1560
+ * \brief Return 64-bits of quasirandomness from a scrambled Sobol64 generator.
1561
+ *
1562
+ * Return 64-bits of quasirandomness from the scrambled Sobol32 generator in \p state,
1563
+ * increment position of generator by one.
1564
+ *
1565
+ * \param state - Pointer to state to update
1566
+ *
1567
+ * \return 64-bits of quasirandomness as an unsigned long long, all bits valid to use.
1568
+ */
1569
+
1570
+ QUALIFIERS unsigned long long curand(curandStateScrambledSobol64_t * state)
1571
+ {
1572
+ /* Moving from i to i+1 element in gray code is flipping one bit,
1573
+ the trailing zero bit of i
1574
+ */
1575
+ unsigned long long res = state->x;
1576
+ state->x ^= state->direction_vectors[__curand_find_trailing_zero(state->i)];
1577
+ state->i ++;
1578
+ return res;
1579
+ }
1580
+
1581
+ #include "curand_uniform.h"
1582
+ #include "curand_normal.h"
1583
+ #include "curand_lognormal.h"
1584
+ #include "curand_poisson.h"
1585
+ #include "curand_discrete2.h"
1586
+
1587
+ __device__ static inline unsigned int *__get_precalculated_matrix(int n)
1588
+ {
1589
+ if(n == 0) {
1590
+ return precalc_xorwow_matrix[n];
1591
+ }
1592
+ if(n == 2) {
1593
+ return precalc_xorwow_offset_matrix[n];
1594
+ }
1595
+ return precalc_xorwow_matrix[n];
1596
+ }
1597
+
1598
+ #ifndef __CUDACC_RTC__
1599
+ __host__ static inline unsigned int *__get_precalculated_matrix_host(int n)
1600
+ {
1601
+ if(n == 1) {
1602
+ return precalc_xorwow_matrix_host[n];
1603
+ }
1604
+ if(n == 3) {
1605
+ return precalc_xorwow_offset_matrix_host[n];
1606
+ }
1607
+ return precalc_xorwow_matrix_host[n];
1608
+ }
1609
+ #endif // #ifndef __CUDACC_RTC__
1610
+
1611
+ __device__ static inline unsigned int *__get_mrg32k3a_matrix(int n)
1612
+ {
1613
+ if(n == 0) {
1614
+ return mrg32k3aM1[n][0];
1615
+ }
1616
+ if(n == 2) {
1617
+ return mrg32k3aM2[n][0];
1618
+ }
1619
+ if(n == 4) {
1620
+ return mrg32k3aM1SubSeq[n][0];
1621
+ }
1622
+ if(n == 6) {
1623
+ return mrg32k3aM2SubSeq[n][0];
1624
+ }
1625
+ if(n == 8) {
1626
+ return mrg32k3aM1Seq[n][0];
1627
+ }
1628
+ if(n == 10) {
1629
+ return mrg32k3aM2Seq[n][0];
1630
+ }
1631
+ return mrg32k3aM1[n][0];
1632
+ }
1633
+
1634
+ #ifndef __CUDACC_RTC__
1635
+ __host__ static inline unsigned int *__get_mrg32k3a_matrix_host(int n)
1636
+ {
1637
+ if(n == 1) {
1638
+ return mrg32k3aM1Host[n][0];
1639
+ }
1640
+ if(n == 3) {
1641
+ return mrg32k3aM2Host[n][0];
1642
+ }
1643
+ if(n == 5) {
1644
+ return mrg32k3aM1SubSeqHost[n][0];
1645
+ }
1646
+ if(n == 7) {
1647
+ return mrg32k3aM2SubSeqHost[n][0];
1648
+ }
1649
+ if(n == 9) {
1650
+ return mrg32k3aM1SeqHost[n][0];
1651
+ }
1652
+ if(n == 11) {
1653
+ return mrg32k3aM2SeqHost[n][0];
1654
+ }
1655
+ return mrg32k3aM1Host[n][0];
1656
+ }
1657
+
1658
+ __host__ static inline double *__get__cr_lgamma_table_host(void) {
1659
+ return __cr_lgamma_table;
1660
+ }
1661
+ #endif // #ifndef __CUDACC_RTC__
1662
+
1663
+ /** @} */
1664
+
1665
+ #endif // !defined(CURAND_KERNEL_H_)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/curand/include/curand_mtgp32_host.h ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2010-2014 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ /*
51
+ * curand_mtgp32_host.h
52
+ *
53
+ *
54
+ * MTGP32-11213
55
+ *
56
+ * Mersenne Twister RNG for the GPU
57
+ *
58
+ * The period of generated integers is 2<sup>11213</sup>-1.
59
+ *
60
+ * This code generates 32-bit unsigned integers, and
61
+ * single precision floating point numbers uniformly distributed
62
+ * in the range [1, 2). (float r; 1.0 <= r < 2.0)
63
+ */
64
+
65
+ /*
66
+ * Copyright (c) 2009, 2010 Mutsuo Saito, Makoto Matsumoto and Hiroshima
67
+ * University. All rights reserved.
68
+ * Copyright (c) 2011 Mutsuo Saito, Makoto Matsumoto, Hiroshima
69
+ * University and University of Tokyo. All rights reserved.
70
+ *
71
+ * Redistribution and use in source and binary forms, with or without
72
+ * modification, are permitted provided that the following conditions are
73
+ * met:
74
+ *
75
+ * * Redistributions of source code must retain the above copyright
76
+ * notice, this list of conditions and the following disclaimer.
77
+ * * Redistributions in binary form must reproduce the above
78
+ * copyright notice, this list of conditions and the following
79
+ * disclaimer in the documentation and/or other materials provided
80
+ * with the distribution.
81
+ * * Neither the name of the Hiroshima University nor the names of
82
+ * its contributors may be used to endorse or promote products
83
+ * derived from this software without specific prior written
84
+ * permission.
85
+ *
86
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
87
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
88
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
89
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
90
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
91
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
92
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
93
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
94
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
95
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
96
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
97
+ */
98
+ #if !defined CURAND_MTGP32_HOST_H
99
+ #define CURAND_MTGP32_HOST_H
100
+
101
+ #if !defined(QUALIFIERS)
102
+ #define QUALIFIERS static inline __device__
103
+ #endif
104
+
105
+ #include <cuda.h>
106
+ #include <stdlib.h>
107
+ #include <memory.h>
108
+ #include <string.h>
109
+ #include "curand.h"
110
+ #include "curand_mtgp32.h"
111
+ #include "curand_mtgp32dc_p_11213.h"
112
+
113
+
114
+ /**
115
+ * \addtogroup DEVICE Device API
116
+ *
117
+ * @{
118
+ */
119
+
120
+ static const unsigned int non_zero = 0x4d544750;
121
+
122
/*
 * First mixing function used during initialization by
 * mtgp32_init_by_array() and mtgp32_init_by_str().
 * @param[in] x 32-bit integer
 * @return 32-bit integer
 */
static __forceinline__ unsigned int ini_func1(unsigned int x) {
    const unsigned int folded = x ^ (x >> 27);
    return folded * 1664525u;
}
131
+
132
/*
 * Second mixing function used during initialization by
 * mtgp32_init_by_array() and mtgp32_init_by_str().
 * @param[in] x 32-bit integer
 * @return 32-bit integer
 */
static __forceinline__ unsigned int ini_func2(unsigned int x) {
    const unsigned int folded = x ^ (x >> 27);
    return folded * 1566083941u;
}
141
+
142
/*
 * Seed an MTGP32 internal state vector from a single 32-bit seed.
 *
 * Called from host-side setup code; the device generator uses a
 * different structure and allocation method. \b para should be one of
 * the elements in the parameter table (mtgp32-param-ref.c).
 *
 * @param[out] state MTGP internal status vector (para->mexp/32 + 1 words).
 * @param[in]  para  parameter structure.
 * @param[in]  seed  a 32-bit integer used as the seed.
 */
static __forceinline__ __host__
void mtgp32_init_state(unsigned int state[],
                       const mtgp32_params_fast_t *para, unsigned int seed) {
    const int word_count = para->mexp / 32 + 1;
    const unsigned int hidden_seed = para->tbl[4] ^ (para->tbl[8] << 16);

    /* Fold the parameter-derived value down to one byte and pre-fill the
       whole state with it, so distinct parameter sets diverge immediately. */
    unsigned int fill = hidden_seed;
    fill += fill >> 16;
    fill += fill >> 8;
    memset(state, fill & 0xff, sizeof(unsigned int) * word_count);

    state[0] = seed;
    state[1] = hidden_seed;
    /* Mix each word from its predecessor; XORed over the pre-filled bytes. */
    for (int idx = 1; idx < word_count; idx++) {
        state[idx] ^= (1812433253) * (state[idx - 1] ^ (state[idx - 1] >> 30)) + idx;
    }
}
173
+
174
+ /*
175
+ * This function initializes the internal state array
176
+ * with a 32-bit integer array. \b para should be one of the elements in
177
+ * the parameter table (mtgp32-param-ref.c).
178
+ *
179
+ * @param[out] mtgp32 MTGP structure.
180
+ * @param[in] para parameter structure
181
+ * @param[in] array a 32-bit integer array used as a seed.
182
+ * @param[in] length length of the array.
183
+ * @return CURAND_STATUS_SUCCESS
184
+ */
185
static __forceinline__ __host__
int mtgp32_init_by_array(unsigned int state[],
                         const mtgp32_params_fast_t *para,
                         unsigned int *array, int length) {
    int i, j, count;
    unsigned int r;
    int lag;
    int mid;
    int size = para->mexp / 32 + 1;  /* number of 32-bit state words */
    unsigned int hidden_seed;
    unsigned int tmp;

    /* Mixing lag chosen from the state size (larger states use a longer
       lag); mid is the second tap point. */
    if (size >= 623) {
        lag = 11;
    } else if (size >= 68) {
        lag = 7;
    } else if (size >= 39) {
        lag = 5;
    } else {
        lag = 3;
    }
    mid = (size - lag) / 2;

    /* Fold a parameter-set-specific value to one byte and pre-fill the
       whole state, so different parameter sets start from different bases. */
    hidden_seed = para->tbl[4] ^ (para->tbl[8] << 16);
    tmp = hidden_seed;
    tmp += tmp >> 16;
    tmp += tmp >> 8;
    memset(state, tmp & 0xff, sizeof(unsigned int) * size);
    state[0] = hidden_seed;

    /* Run at least one full pass over the state, or one step per seed
       word, whichever is larger. */
    if (length + 1 > size) {
        count = length + 1;
    } else {
        count = size;
    }
    /* Bootstrap step mixing the three tap points into word 0. */
    r = ini_func1(state[0] ^ state[mid] ^ state[size - 1]);
    state[mid] += r;
    r += length;
    state[(mid + lag) % size] += r;
    state[0] = r;
    i = 1;
    count--;
    /* Phase 1: fold each caller-supplied seed word into the state. */
    for (i = 1, j = 0; (j < count) && (j < length); j++) {
        r = ini_func1(state[i] ^ state[(i + mid) % size]
                      ^ state[(i + size - 1) % size]);
        state[(i + mid) % size] += r;
        r += array[j] + i;
        state[(i + mid + lag) % size] += r;
        state[i] = r;
        i = (i + 1) % size;
    }
    /* Phase 2: keep stirring (no seed input) until `count` steps are done. */
    for (; j < count; j++) {
        r = ini_func1(state[i] ^ state[(i + mid) % size]
                      ^ state[(i + size - 1) % size]);
        state[(i + mid) % size] += r;
        r += i;
        state[(i + mid + lag) % size] += r;
        state[i] = r;
        i = (i + 1) % size;
    }
    /* Phase 3: one final full pass with the second mixing function,
       using XOR so every word is disturbed. */
    for (j = 0; j < size; j++) {
        r = ini_func2(state[i] + state[(i + mid) % size]
                      + state[(i + size - 1) % size]);
        state[(i + mid) % size] ^= r;
        r -= i;
        state[(i + mid + lag) % size] ^= r;
        state[i] = r;
        i = (i + 1) % size;
    }
    /* Guard against the degenerate all-zero state: force a non-zero word. */
    if (state[size - 1] == 0) {
        state[size - 1] = non_zero;
    }
    return 0;
}
259
+
260
+ /*
261
+ * This function initializes the internal state array
262
+ * with a character array. \b para should be one of the elements in
263
+ * the parameter table (mtgp32-param-ref.c).
264
+ * This is the same algorithm with mtgp32_init_by_array(), but hope to
265
+ * be more useful.
266
+ *
267
+ * @param[out] mtgp32 MTGP structure.
268
+ * @param[in] para parameter structure
269
+ * @param[in] array a character array used as a seed. (terminated by zero.)
270
+ * @return memory allocation result. if 0 then O.K.
271
+ */
272
static __forceinline__ __host__
int mtgp32_init_by_str(unsigned int state[],
                       const mtgp32_params_fast_t *para, unsigned char *array) {
    int i, j, count;
    unsigned int r;
    int lag;
    int mid;
    int size = para->mexp / 32 + 1;  /* number of 32-bit state words */
    /* Seed length is taken from the NUL-terminated string; otherwise this
       routine is identical to mtgp32_init_by_array() with byte-sized
       seed words. */
    int length = (unsigned int)strlen((char *)array);
    unsigned int hidden_seed;
    unsigned int tmp;

    /* Mixing lag chosen from the state size; mid is the second tap point. */
    if (size >= 623) {
        lag = 11;
    } else if (size >= 68) {
        lag = 7;
    } else if (size >= 39) {
        lag = 5;
    } else {
        lag = 3;
    }
    mid = (size - lag) / 2;

    /* Fold a parameter-set-specific value to one byte and pre-fill the
       whole state. */
    hidden_seed = para->tbl[4] ^ (para->tbl[8] << 16);
    tmp = hidden_seed;
    tmp += tmp >> 16;
    tmp += tmp >> 8;
    memset(state, tmp & 0xff, sizeof(unsigned int) * size);
    state[0] = hidden_seed;

    /* Run at least one full pass over the state, or one step per seed
       byte, whichever is larger. */
    if (length + 1 > size) {
        count = length + 1;
    } else {
        count = size;
    }
    /* Bootstrap step mixing the three tap points into word 0. */
    r = ini_func1(state[0] ^ state[mid] ^ state[size - 1]);
    state[mid] += r;
    r += length;
    state[(mid + lag) % size] += r;
    state[0] = r;
    i = 1;
    count--;
    /* Phase 1: fold each seed byte into the state. */
    for (i = 1, j = 0; (j < count) && (j < length); j++) {
        r = ini_func1(state[i] ^ state[(i + mid) % size]
                      ^ state[(i + size - 1) % size]);
        state[(i + mid) % size] += r;
        r += array[j] + i;
        state[(i + mid + lag) % size] += r;
        state[i] = r;
        i = (i + 1) % size;
    }
    /* Phase 2: keep stirring (no seed input) until `count` steps are done. */
    for (; j < count; j++) {
        r = ini_func1(state[i] ^ state[(i + mid) % size]
                      ^ state[(i + size - 1) % size]);
        state[(i + mid) % size] += r;
        r += i;
        state[(i + mid + lag) % size] += r;
        state[i] = r;
        i = (i + 1) % size;
    }
    /* Phase 3: one final full pass with the second mixing function. */
    for (j = 0; j < size; j++) {
        r = ini_func2(state[i] + state[(i + mid) % size]
                      + state[(i + size - 1) % size]);
        state[(i + mid) % size] ^= r;
        r -= i;
        state[(i + mid + lag) % size] ^= r;
        state[i] = r;
        i = (i + 1) % size;
    }
    /* Guard against the degenerate all-zero state: force a non-zero word. */
    if (state[size - 1] == 0) {
        state[size - 1] = non_zero;
    }
    return 0;
}
346
+
347
/*
 * Re-pack `block_num` MTGP32 parameter sets into flat per-field arrays in
 * host memory and copy each array to the device-side kernel-parameter
 * structure *p.
 *
 * BUG FIX: the original freed the scratch buffers both inside the
 * allocation-failure branch AND in the unconditional cleanup at the end of
 * the function. The pointers were never reset to NULL after the first
 * free, so any buffer that had been successfully allocated before the
 * failing malloc was freed twice (undefined behavior). Cleanup now
 * happens in exactly one place; free(NULL) is a defined no-op, so the
 * NULL checks are unnecessary as well.
 *
 * \param params    - host array of block_num parameter sets
 * \param p         - device-memory destination structure
 * \param block_num - number of parameter sets to pack
 *
 * \return
 * - CURAND_STATUS_ALLOCATION_FAILED if host memory could not be allocated
 * - CURAND_STATUS_INITIALIZATION_FAILED if a copy to device memory failed
 * - CURAND_STATUS_SUCCESS otherwise
 */
template<typename ParamsType>
static __forceinline__ __host__
curandStatus_t curandMakeMTGP32ConstantsImpl(const mtgp32_params_fast_t params[], ParamsType * p, const int block_num)
{
    const int size1 = sizeof(unsigned int) * block_num;            /* one word per block */
    const int size2 = sizeof(unsigned int) * block_num * TBL_SIZE; /* one table per block */
    unsigned int *h_pos_tbl           = (unsigned int *)malloc(size1);
    unsigned int *h_sh1_tbl           = (unsigned int *)malloc(size1);
    unsigned int *h_sh2_tbl           = (unsigned int *)malloc(size1);
    unsigned int *h_param_tbl         = (unsigned int *)malloc(size2);
    unsigned int *h_temper_tbl        = (unsigned int *)malloc(size2);
    unsigned int *h_single_temper_tbl = (unsigned int *)malloc(size2);
    unsigned int *h_mask              = (unsigned int *)malloc(sizeof(unsigned int));
    curandStatus_t status = CURAND_STATUS_SUCCESS;

    if (h_pos_tbl == NULL
            || h_sh1_tbl == NULL
            || h_sh2_tbl == NULL
            || h_param_tbl == NULL
            || h_temper_tbl == NULL
            || h_single_temper_tbl == NULL
            || h_mask == NULL) {
        /* Do NOT free here: the single cleanup below releases whatever
           was allocated (freeing here too caused a double free). */
        status = CURAND_STATUS_ALLOCATION_FAILED;
    } else {
        /* Flatten the array-of-structs into per-field arrays. */
        h_mask[0] = params[0].mask;
        for (int i = 0; i < block_num; i++) {
            h_pos_tbl[i] = params[i].pos;
            h_sh1_tbl[i] = params[i].sh1;
            h_sh2_tbl[i] = params[i].sh2;
            for (int j = 0; j < TBL_SIZE; j++) {
                h_param_tbl[i * TBL_SIZE + j]         = params[i].tbl[j];
                h_temper_tbl[i * TBL_SIZE + j]        = params[i].tmp_tbl[j];
                h_single_temper_tbl[i * TBL_SIZE + j] = params[i].flt_tmp_tbl[j];
            }
        }
        /* Copy each flattened array to device memory; stop at first failure. */
        if (cudaMemcpy( p->pos_tbl,
                        h_pos_tbl, size1, cudaMemcpyHostToDevice) != cudaSuccess)
        {
            status = CURAND_STATUS_INITIALIZATION_FAILED;
        } else
        if (cudaMemcpy( p->sh1_tbl,
                        h_sh1_tbl, size1, cudaMemcpyHostToDevice) != cudaSuccess)
        {
            status = CURAND_STATUS_INITIALIZATION_FAILED;
        } else
        if (cudaMemcpy( p->sh2_tbl,
                        h_sh2_tbl, size1, cudaMemcpyHostToDevice) != cudaSuccess)
        {
            status = CURAND_STATUS_INITIALIZATION_FAILED;
        } else
        if (cudaMemcpy( p->param_tbl,
                        h_param_tbl, size2, cudaMemcpyHostToDevice) != cudaSuccess)
        {
            status = CURAND_STATUS_INITIALIZATION_FAILED;
        } else
        if (cudaMemcpy( p->temper_tbl,
                        h_temper_tbl, size2, cudaMemcpyHostToDevice) != cudaSuccess)
        {
            status = CURAND_STATUS_INITIALIZATION_FAILED;
        } else
        if (cudaMemcpy( p->single_temper_tbl,
                        h_single_temper_tbl, size2, cudaMemcpyHostToDevice) != cudaSuccess)
        {
            status = CURAND_STATUS_INITIALIZATION_FAILED;
        } else
        if (cudaMemcpy( p->mask,
                        h_mask, sizeof(unsigned int), cudaMemcpyHostToDevice) != cudaSuccess)
        {
            status = CURAND_STATUS_INITIALIZATION_FAILED;
        }
    }
    /* Single cleanup path; free(NULL) is a no-op. */
    free(h_pos_tbl);
    free(h_sh1_tbl);
    free(h_sh2_tbl);
    free(h_param_tbl);
    free(h_temper_tbl);
    free(h_single_temper_tbl);
    free(h_mask);
    return status;
}
442
+
443
+ /**
444
+ * \brief Set up constant parameters for the mtgp32 generator
445
+ *
446
+ * This host-side helper function re-organizes CURAND_NUM_MTGP32_PARAMS sets of
447
+ * generator parameters for use by kernel functions and copies the
448
+ * result to the specified location in device memory.
449
+ *
450
+ * \param params - Pointer to an array of type mtgp32_params_fast_t in host memory
451
+ * \param p - pointer to a structure of type mtgp32_kernel_params_t in device memory.
452
+ *
453
+ * \return
454
+ * - CURAND_STATUS_ALLOCATION_FAILED if host memory could not be allocated
455
+ * - CURAND_STATUS_INITIALIZATION_FAILED if the copy to device memory failed
456
+ * - CURAND_STATUS_SUCCESS otherwise
457
+ */
458
static __forceinline__ __host__
curandStatus_t curandMakeMTGP32Constants(const mtgp32_params_fast_t params[], mtgp32_kernel_params_t * p)
{
    /* Thin wrapper over the templated implementation with the fixed
       block count CURAND_NUM_MTGP32_PARAMS. */
    return curandMakeMTGP32ConstantsImpl(params, p, CURAND_NUM_MTGP32_PARAMS);
}
463
+
464
+ /**
465
+ * \brief Set up initial states for the mtgp32 generator
466
+ *
467
+ * This host-side helper function initializes a number of states (one parameter set per state) for
468
+ * an mtgp32 generator. To accomplish this it allocates a state array in host memory,
469
+ * initializes that array, and copies the result to device memory.
470
+ *
471
+ * \param s - pointer to an array of states in device memory
472
+ * \param params - Pointer to an array of type mtgp32_params_fast_t in host memory
473
+ * \param k - pointer to a structure of type mtgp32_kernel_params_t in device memory
474
+ * \param n - number of parameter sets/states to initialize
475
+ * \param seed - seed value
476
+ *
477
+ * \return
478
+ * - CURAND_STATUS_ALLOCATION_FAILED if host memory state could not be allocated
479
+ * - CURAND_STATUS_INITIALIZATION_FAILED if the copy to device memory failed
480
+ * - CURAND_STATUS_SUCCESS otherwise
481
+ */
482
static __forceinline__ __host__
curandStatus_t CURANDAPI curandMakeMTGP32KernelState(curandStateMtgp32_t *s,
                                                     mtgp32_params_fast_t params[],
                                                     mtgp32_kernel_params_t *k,
                                                     int n,
                                                     unsigned long long seed)
{
    /* Stage the n states in host memory, then copy them to the device
       in one transfer. */
    curandStateMtgp32_t *staging =
        (curandStateMtgp32_t *) malloc(sizeof(curandStateMtgp32_t) * n);
    if (staging == NULL) {
        return CURAND_STATUS_ALLOCATION_FAILED;
    }

    curandStatus_t status = CURAND_STATUS_SUCCESS;
    /* Fold the 64-bit seed down to 32 bits. */
    seed = seed ^ (seed >> 32);
    for (int i = 0; i < n; i++) {
        /* Each state gets its own parameter set and a distinct seed. */
        mtgp32_init_state(&(staging[i].s[0]), &params[i], (unsigned int)seed + i + 1);
        staging[i].offset = 0;
        staging[i].pIdx = i;
        staging[i].k = k;
    }
    if (cudaMemcpy(s, staging,
                   sizeof(curandStateMtgp32_t) * n,
                   cudaMemcpyHostToDevice) != cudaSuccess) {
        status = CURAND_STATUS_INITIALIZATION_FAILED;
    }
    free(staging);
    return status;
}
511
+
512
+ /** @} */
513
+
514
+ #endif
515
+
516
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/curand/include/curand_poisson.h ADDED
@@ -0,0 +1,751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ /* Copyright 2010-2014 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * The source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * The Licensed Deliverables contained herein are PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and are being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+
51
+ #if !defined(CURAND_POISSON_H_)
52
+ #define CURAND_POISSON_H_
53
+
54
+ /**
55
+ * \defgroup DEVICE Device API
56
+ *
57
+ * @{
58
+ */
59
+
60
+ #ifndef __CUDACC_RTC__
61
+ #include <math.h>
62
+ #endif // __CUDACC_RTC__
63
+
64
+ #include "curand_mrg32k3a.h"
65
+ #include "curand_mtgp32_kernel.h"
66
+ #include "curand_philox4x32_x.h"
67
+
68
+ #define CR_CUDART_PI 3.1415926535897931e+0
69
+ #define CR_CUDART_TWO_TO_52 4503599627370496.0
70
+
71
+
72
/* 1/sqrt(a): on device, the approximate flush-to-zero PTX instruction;
   on host, computed via sqrtf. */
QUALIFIERS float __cr_rsqrt(float a)
{
#ifdef __CUDA_ARCH__
    asm ("rsqrt.approx.f32.ftz %0, %1;" : "=f"(a) : "f"(a));
#else
    a = 1.0f / sqrtf (a);
#endif
    return a;
}
81
+
82
/* exp(a): on device, scale by log2(e) and use the approximate base-2
   exponential PTX instruction; on host, computed via expf. */
QUALIFIERS float __cr_exp (float a)
{
#ifdef __CUDA_ARCH__
    a = a * 1.4426950408889634074;  /* log2(e) */
    asm ("ex2.approx.f32.ftz %0, %1;" : "=f"(a) : "f"(a));
#else
    a = expf (a);
#endif
    return a;
}
92
+
93
/* ln(a): on device, the approximate base-2 logarithm PTX instruction
   rescaled by ln(2); on host, computed via logf. */
QUALIFIERS float __cr_log (float a)
{
#ifdef __CUDA_ARCH__
    asm ("lg2.approx.f32.ftz %0, %1;" : "=f"(a) : "f"(a));
    a = a * 0.69314718055994530942;  /* ln(2) */
#else
    a = logf (a);
#endif
    return a;
}
103
+
104
/* 1/a: on device, the approximate flush-to-zero reciprocal PTX
   instruction; on host, a plain division. */
QUALIFIERS float __cr_rcp (float a)
{
#ifdef __CUDA_ARCH__
    asm ("rcp.approx.f32.ftz %0, %1;" : "=f"(a) : "f"(a));
#else
    a = 1.0f / a;
#endif
    return a;
}
113
+
114
+ /* Computes regularized gamma function: gammainc(a,x)/gamma(a) */
115
QUALIFIERS float __cr_pgammainc (float a, float x)
{
    /* Sigmoid-style approximation: 1 / (1 + exp(alpha*(a-x) - beta))^2,
       where alpha and beta depend only on a. */
    float t, alpha, beta;

    /* First level parametrization constants */
    float ma1 = 1.43248035075540910f,
          ma2 = 0.12400979329415655f,
          ma3 = 0.00025361074907033f,
          mb1 = 0.21096734870196546f,
          mb2 = 1.97381164089999420f,
          mb3 = 0.94201734077887530f;

    /* Second level parametrization constants (depends only on a) */
    /* NOTE(review): rsqrt of (a - ma2)/(a - mb2) assumes a large enough
       to keep both positive — confirm at call sites. */
    alpha = __cr_rsqrt (a - ma2);
    alpha = ma1 * alpha + ma3;
    beta = __cr_rsqrt (a - mb2);
    beta = mb1 * beta + mb3;

    /* Final approximation (depends on a and x) */

    t = a - x;
    t = alpha * t - beta;
    t = 1.0f + __cr_exp (t);
    t = t * t;
    t = __cr_rcp (t);

    /* Negative a,x or a,x=NAN requires special handling */
    //t = !(x > 0 && a >= 0) ? 0.0 : t;

    return t;
}
147
+
148
+ /* Computes inverse of pgammainc */
149
QUALIFIERS float __cr_pgammaincinv (float a, float y)
{
    /* Inverse of __cr_pgammainc: solves the sigmoid approximation
       y = 1/(1+exp(alpha*(a-x)-beta))^2 for x, using the same alpha/beta
       parametrization. */
    float t, alpha, beta;

    /* First level parametrization constants */

    float ma1 = 1.43248035075540910f,
          ma2 = 0.12400979329415655f,
          ma3 = 0.00025361074907033f,
          mb1 = 0.21096734870196546f,
          mb2 = 1.97381164089999420f,
          mb3 = 0.94201734077887530f;

    /* Second level parametrization constants (depends only on a) */

    alpha = __cr_rsqrt (a - ma2);
    alpha = ma1 * alpha + ma3;
    beta = __cr_rsqrt (a - mb2);
    beta = mb1 * beta + mb3;

    /* Final approximation (depends on a and y) */

    t = __cr_rsqrt (y) - 1.0f;
    t = __cr_log (t);
    t = beta + t;
    t = - t * __cr_rcp (alpha) + a;
    /* Negative a,x or a,x=NAN requires special handling */
    //t = !(y > 0 && a >= 0) ? 0.0 : t;
    return t;
}
179
+
180
/* lgamma(k) for k = 1..9: __cr_lgamma_table[k-1] == lgamma(k); matches
   the host-side switch in __cr_lgamma_integer. */
#if defined(__CUDACC_RDC__) && (__cplusplus >= 201703L) && defined(__cpp_inline_variables)
/* With relocatable device code and C++17, declare as an inline variable
   so translation units share one definition. */
inline __constant__ double __cr_lgamma_table [] = {
#else
static __constant__ double __cr_lgamma_table [] = {
#endif
    0.000000000000000000e-1,  /* lgamma(1) */
    0.000000000000000000e-1,  /* lgamma(2) */
    6.931471805599453094e-1,  /* lgamma(3) = ln 2 */
    1.791759469228055001e0,
    3.178053830347945620e0,
    4.787491742782045994e0,
    6.579251212010100995e0,
    8.525161361065414300e0,
    1.060460290274525023e1    /* lgamma(9) */
};
195
+
196
+
197
/* lgamma(a) for integer a: table lookup for a <= 8, Stirling series
   otherwise. NOTE(review): a < 1 would index the table out of range —
   callers appear to pass positive counts; confirm. */
QUALIFIERS double __cr_lgamma_integer(int a)
{
    double s;
    double t;
    double fa = fabs((float)a);
    double sum;

    if (a > 8) {
        /* Stirling approximation; coefficients from Hart et al, "Computer
         * Approximations", Wiley 1968. Approximation 5404.
         */
        s = 1.0 / fa;
        t = s * s;
        sum = -0.1633436431e-2;
        sum = sum * t + 0.83645878922e-3;
        sum = sum * t - 0.5951896861197e-3;
        sum = sum * t + 0.793650576493454e-3;
        sum = sum * t - 0.277777777735865004e-2;
        sum = sum * t + 0.833333333333331018375e-1;
        sum = sum * s + 0.918938533204672;   /* 0.5*ln(2*pi) */
        s = 0.5 * log (fa);
        t = fa - 0.5;
        s = s * t;
        t = s - fa;
        s = s + sum;
        t = t + s;
        return t;
    } else {
#ifdef __CUDA_ARCH__
        /* Device: constant-memory table, lgamma(a) at index a-1. */
        return __cr_lgamma_table [(int) fa-1];
#else
        /* Host: __constant__ data is not readable here, so mirror the
           table in a switch. */
        switch(a) {
            case 1: return 0.000000000000000000e-1;
            case 2: return 0.000000000000000000e-1;
            case 3: return 6.931471805599453094e-1;
            case 4: return 1.791759469228055001e0;
            case 5: return 3.178053830347945620e0;
            case 6: return 4.787491742782045994e0;
            case 7: return 6.579251212010100995e0;
            case 8: return 8.525161361065414300e0;
            default: return 1.060460290274525023e1;
        }
#endif
    }
}
242
+
243
+ #define KNUTH_FLOAT_CONST 60.0
244
template <typename T>
// Donald E. Knuth Seminumerical Algorithms. The Art of Computer Programming, Volume 2
// Product-of-uniforms method: starting from exp(+lambda), multiply by
// uniform draws until the running product falls to 1; the number of
// draws minus one is Poisson(lambda) distributed.
QUALIFIERS unsigned int curand_poisson_knuth(T *state, float lambda)
{
    unsigned int draws = 0;
    float product = expf(lambda);
    do {
        draws++;
        product *= curand_uniform(state);
    } while (product > 1.0);
    return draws - 1;
}
256
+
257
template <typename T>
// Donald E. Knuth Seminumerical Algorithms. The Art of Computer Programming, Volume 2
// Four independent Poisson(lambda) samples via the same product-of-uniforms
// method as curand_poisson_knuth, unrolled over the components of a uint4.
QUALIFIERS uint4 curand_poisson_knuth4(T *state, float lambda)
{
    uint4 k = {0,0,0,0};
    float exp_lambda = expf(lambda);
    float4 p={ exp_lambda,exp_lambda,exp_lambda,exp_lambda };
    /* Each component consumes its own sequence of uniforms. */
    do{
        k.x++;
        p.x *= curand_uniform(state);
    }while (p.x > 1.0);
    do{
        k.y++;
        p.y *= curand_uniform(state);
    }while (p.y > 1.0);
    do{
        k.z++;
        p.z *= curand_uniform(state);
    }while (p.z > 1.0);
    do{
        k.w++;
        p.w *= curand_uniform(state);
    }while (p.w > 1.0);

    /* Loops overshoot by one draw each; correct the counts. */
    k.x--;
    k.y--;
    k.z--;
    k.w--;
    return k;
}
287
+
288
template <typename T>
// Marsaglia, Tsang, Wang Journal of Statistical Software, square histogram.
// Maps one raw draw x to a discrete sample: pick column j from the uniform
// u, return j if u lies below the column boundary V[j], otherwise the
// alias column K[j]; `shift` re-bases the result onto the distribution's
// support.
QUALIFIERS unsigned int _curand_M2_double(T x, curandDistributionM2Shift_t distributionM2)
{
    double u = _curand_uniform_double(x);
    int j = (int) floor(distributionM2->length*u);


#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)
    /* sm_35+: __ldg reads the tables through the read-only data cache. */
    double histogramVj = __ldg( &(distributionM2->histogram->V[j]));
    unsigned int histogramKj = __ldg( &(distributionM2->histogram->K[j]));
#else
    double histogramVj = distributionM2->histogram->V[j];
    unsigned int histogramKj = distributionM2->histogram->K[j];
#endif
    //if (u < distributionM2->histogram->V[j]) return distributionM2->shift + j;
    //return distributionM2->shift + distributionM2->histogram->K[j];
    if (u < histogramVj) return distributionM2->shift + j;
    return distributionM2->shift + histogramKj;
}
308
+
309
+ template <typename T>
310
+ // Marsaglia, Tsang, Wang Journal of Statistical Software, square histogram.
311
+ QUALIFIERS uint4 _curand_M2_double4(T x, curandDistributionM2Shift_t distributionM2)
312
+ {
313
+ double4 u;
314
+ uint4 result = {0,0,0,0};
315
+ int4 flag = {1,1,1,1};
316
+
317
+ u.x = _curand_uniform_double(x.x);
318
+ u.y = _curand_uniform_double(x.y);
319
+ u.z = _curand_uniform_double(x.z);
320
+ u.w = _curand_uniform_double(x.w);
321
+
322
+ int4 j;
323
+ j.x = (int) floor(distributionM2->length*u.x);
324
+ j.y = (int) floor(distributionM2->length*u.y);
325
+ j.z = (int) floor(distributionM2->length*u.z);
326
+ j.w = (int) floor(distributionM2->length*u.w);
327
+ // int result;
328
+
329
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)
330
+ double histogramVjx = __ldg( &(distributionM2->histogram->V[j.x]));
331
+ double histogramVjy = __ldg( &(distributionM2->histogram->V[j.y]));
332
+ double histogramVjz = __ldg( &(distributionM2->histogram->V[j.z]));
333
+ double histogramVjw = __ldg( &(distributionM2->histogram->V[j.w]));
334
+
335
+ unsigned int histogramKjx = __ldg( &(distributionM2->histogram->K[j.x]));
336
+ unsigned int histogramKjy = __ldg( &(distributionM2->histogram->K[j.y]));
337
+ unsigned int histogramKjz = __ldg( &(distributionM2->histogram->K[j.z]));
338
+ unsigned int histogramKjw = __ldg( &(distributionM2->histogram->K[j.w]));
339
+ #else
340
+ double histogramVjx = distributionM2->histogram->V[j.x];
341
+ double histogramVjy = distributionM2->histogram->V[j.y];
342
+ double histogramVjz = distributionM2->histogram->V[j.z];
343
+ double histogramVjw = distributionM2->histogram->V[j.w];
344
+
345
+ unsigned int histogramKjx = distributionM2->histogram->K[j.x];
346
+ unsigned int histogramKjy = distributionM2->histogram->K[j.y];
347
+ unsigned int histogramKjz = distributionM2->histogram->K[j.z];
348
+ unsigned int histogramKjw = distributionM2->histogram->K[j.w];
349
+ #endif
350
+
351
+ if (u.x < histogramVjx){ result.x = distributionM2->shift + j.x; flag.x = 0; }
352
+ if (u.y < histogramVjy){ result.y = distributionM2->shift + j.y; flag.y = 0; }
353
+ if (u.z < histogramVjz){ result.z = distributionM2->shift + j.z; flag.z = 0; }
354
+ if (u.w < histogramVjw){ result.w = distributionM2->shift + j.w; flag.w = 0; }
355
+ //return distributionM2->shift + distributionM2->histogram->K[j];
356
+
357
+ if(flag.x) result.x = distributionM2->shift + histogramKjx;
358
+ if(flag.y) result.y = distributionM2->shift + histogramKjy;
359
+ if(flag.z) result.z = distributionM2->shift + histogramKjz;
360
+ if(flag.w) result.w = distributionM2->shift + histogramKjw;
361
+
362
+ return result;
363
+ }
364
+
365
+ template <typename STATE>
366
+ QUALIFIERS unsigned int curand_M2_double(STATE *state, curandDistributionM2Shift_t distributionM2)
367
+ {
368
+ return _curand_M2_double(curand(state), distributionM2);
369
+ }
370
+
371
+ template <typename STATE>
372
+ QUALIFIERS uint4 curand_M2_double4(STATE *state, curandDistributionM2Shift_t distributionM2)
373
+ {
374
+ return _curand_M2_double4(curand4(state), distributionM2);
375
+ }
376
+
377
+
378
+ template <typename T>
379
+ QUALIFIERS unsigned int _curand_binary_search_double(T x, curandDistributionShift_t distribution)
380
+ {
381
+ double u = _curand_uniform_double(x);
382
+ int min = 0;
383
+ int max = distribution->length-1;
384
+ do{
385
+ int mid = (max + min)/2;
386
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)
387
+ double probability_mid = __ldg( &(distribution->probability[mid]));
388
+ #else
389
+ double probability_mid = distribution->probability[mid];
390
+ #endif
391
+ if (u <= probability_mid){
392
+ max = mid;
393
+ }else{
394
+ min = mid+1;
395
+ }
396
+ }while (min < max);
397
+ return distribution->shift + min;
398
+ }
399
+
400
+ template <typename STATE>
401
+ QUALIFIERS unsigned int curand_binary_search_double(STATE *state, curandDistributionShift_t distribution)
402
+ {
403
+ return _curand_binary_search_double(curand(state), distribution);
404
+ }
405
+
406
+ // Generates uniformly distributed double values in range (0.0; 1.0) from uniformly distributed
407
+ // unsigned int. We can't use standard _curand_uniform_double since it can generate 1.0.
408
+ // This is required only for _curand_poisson_ITR_double.
409
+ QUALIFIERS double _curand_uniform_double_excluding_one(unsigned int x)
410
+ {
411
+ return x * CURAND_2POW32_INV_DOUBLE + (CURAND_2POW32_INV_DOUBLE/2.0);
412
+ }
413
+
414
+ // Overload for unsigned long long.
415
+ // This is required only for _curand_poisson_ITR_double.
416
+ QUALIFIERS double _curand_uniform_double_excluding_one(unsigned long long x)
417
+ {
418
+ return (x >> 11) * CURAND_2POW53_INV_DOUBLE + (CURAND_2POW53_INV_DOUBLE/4.0);
419
+ }
420
+
421
+ #define MAGIC_DOUBLE_CONST 500.0
422
+ template <typename T>
423
+ //George S. Fishman Discrete-event simulation: modeling, programming, and analysis
424
+ QUALIFIERS unsigned int _curand_poisson_ITR_double(T x, double lambda)
425
+ {
426
+ double L,p = 1.0;
427
+ double q = 1.0;
428
+ unsigned int k = 0;
429
+ int pow=0;
430
+ // This algorithm requires u to be in (0;1) range, however, _curand_uniform_double
431
+ // returns a number in range (0;1]. If u is 1.0 the inner loop never ends. The
432
+ // following operation transforms the range from (0;1] to (0;1).
433
+ double u = _curand_uniform_double_excluding_one(x);
434
+ do{
435
+ if (lambda > (double)(pow+MAGIC_DOUBLE_CONST)){
436
+ L = exp(-MAGIC_DOUBLE_CONST);
437
+ }else{
438
+ L = exp((double)(pow - lambda));
439
+ }
440
+ p *= L;
441
+ q *= L;
442
+ pow += (int) MAGIC_DOUBLE_CONST;
443
+ while (u > q){
444
+ k++;
445
+ p *= ((double)lambda / (double) k);
446
+ q += p;
447
+ }
448
+ }while((double)pow < lambda);
449
+ return k;
450
+ }
451
+
452
+ template <typename T>
453
+ /* Rejection Method for Poisson distribution based on gammainc approximation */
454
+ QUALIFIERS unsigned int curand_poisson_gammainc(T state, float lambda){
455
+ float y, x, t, z,v;
456
+ float logl = __cr_log (lambda);
457
+ while (true) {
458
+ y = curand_uniform (state);
459
+ x = __cr_pgammaincinv (lambda, y);
460
+ x = floorf (x);
461
+ z = curand_uniform (state);
462
+ v = (__cr_pgammainc (lambda, x + 1.0f) - __cr_pgammainc (lambda, x)) * 1.3f;
463
+ z = z*v;
464
+ t = (float)__cr_exp (-lambda + x * logl - (float)__cr_lgamma_integer ((int)(1.0f + x)));
465
+ if ((z < t) && (v>=1e-20))
466
+ break;
467
+ }
468
+ return (unsigned int)x;
469
+ }
470
+
471
+ template <typename T>
472
+ /* Rejection Method for Poisson distribution based on gammainc approximation */
473
+ QUALIFIERS uint4 curand_poisson_gammainc4(T state, float lambda){
474
+ uint4 result;
475
+ float y, x, t, z,v;
476
+ float logl = __cr_log (lambda);
477
+ while (true) {
478
+ y = curand_uniform(state);
479
+ x = __cr_pgammaincinv (lambda, y);
480
+ x = floorf (x);
481
+ z = curand_uniform (state);
482
+ v = (__cr_pgammainc (lambda, x + 1.0f) - __cr_pgammainc (lambda, x)) * 1.3f;
483
+ z = z*v;
484
+ t = (float)__cr_exp (-lambda + x * logl - (float)__cr_lgamma_integer ((int)(1.0f + x)));
485
+ if ((z < t) && (v>=1e-20))
486
+ break;
487
+ }
488
+ result.x = (unsigned int)x;
489
+
490
+ while (true) {
491
+ y = curand_uniform(state);
492
+ x = __cr_pgammaincinv (lambda, y);
493
+ x = floorf (x);
494
+ z = curand_uniform (state);
495
+ v = (__cr_pgammainc (lambda, x + 1.0f) - __cr_pgammainc (lambda, x)) * 1.3f;
496
+ z = z*v;
497
+ t = (float)__cr_exp (-lambda + x * logl - (float)__cr_lgamma_integer ((int)(1.0f + x)));
498
+ if ((z < t) && (v>=1e-20))
499
+ break;
500
+ }
501
+ result.y = (unsigned int)x;
502
+
503
+ while (true) {
504
+ y = curand_uniform(state);
505
+ x = __cr_pgammaincinv (lambda, y);
506
+ x = floorf (x);
507
+ z = curand_uniform (state);
508
+ v = (__cr_pgammainc (lambda, x + 1.0f) - __cr_pgammainc (lambda, x)) * 1.3f;
509
+ z = z*v;
510
+ t = (float)__cr_exp (-lambda + x * logl - (float)__cr_lgamma_integer ((int)(1.0f + x)));
511
+ if ((z < t) && (v>=1e-20))
512
+ break;
513
+ }
514
+ result.z = (unsigned int)x;
515
+
516
+ while (true) {
517
+ y = curand_uniform(state);
518
+ x = __cr_pgammaincinv (lambda, y);
519
+ x = floorf (x);
520
+ z = curand_uniform (state);
521
+ v = (__cr_pgammainc (lambda, x + 1.0f) - __cr_pgammainc (lambda, x)) * 1.3f;
522
+ z = z*v;
523
+ t = (float)__cr_exp (-lambda + x * logl - (float)__cr_lgamma_integer ((int)(1.0f + x)));
524
+ if ((z < t) && (v>=1e-20))
525
+ break;
526
+ }
527
+ result.w = (unsigned int)x;
528
+
529
+ return result;
530
+ }
531
+ // Note below that the round to nearest integer, where needed,is done in line with code that
532
+ // assumes the range of values is < 2**32
533
+
534
+ template <typename T>
535
+ QUALIFIERS unsigned int _curand_poisson(T x, double lambda)
536
+ {
537
+ if (lambda < 1000)
538
+ return _curand_poisson_ITR_double(x, lambda);
539
+ return (unsigned int)((sqrt(lambda) * _curand_normal_icdf_double(x)) + lambda + 0.5); //Round to nearest
540
+ }
541
+
542
+ template <typename T>
543
+ QUALIFIERS unsigned int _curand_poisson_from_normal(T x, double lambda)
544
+ {
545
+ return (unsigned int)((sqrt(lambda) * _curand_normal_icdf(x)) + lambda + 0.5); //Round to nearest
546
+ }
547
+
548
+ template <typename STATE>
549
+ QUALIFIERS unsigned int curand_poisson_from_normal(STATE state, double lambda)
550
+ {
551
+ return (unsigned int)((sqrt(lambda) * curand_normal(state)) + lambda + 0.5); //Round to nearest
552
+ }
553
+
554
+ template <typename STATE>
555
+ QUALIFIERS uint4 curand_poisson_from_normal4(STATE state, double lambda)
556
+ {
557
+ uint4 result;
558
+ float4 _res;
559
+
560
+ _res = curand_normal4(state);
561
+
562
+ result.x = (unsigned int)((sqrt(lambda) * _res.x) + lambda + 0.5); //Round to nearest
563
+ result.y = (unsigned int)((sqrt(lambda) * _res.y) + lambda + 0.5); //Round to nearest
564
+ result.z = (unsigned int)((sqrt(lambda) * _res.z) + lambda + 0.5); //Round to nearest
565
+ result.w = (unsigned int)((sqrt(lambda) * _res.w) + lambda + 0.5); //Round to nearest
566
+ return result; //Round to nearest
567
+ }
568
+
569
+ /**
570
+ * \brief Return a Poisson-distributed unsigned int from a XORWOW generator.
571
+ *
572
+ * Return a single unsigned int from a Poisson
573
+ * distribution with lambda \p lambda from the XORWOW generator in \p state,
574
+ * increment the position of the generator by a variable amount, depending
575
+ * on the algorithm used.
576
+ *
577
+ * \param state - Pointer to state to update
578
+ * \param lambda - Lambda of the Poisson distribution
579
+ *
580
+ * \return Poisson-distributed unsigned int with lambda \p lambda
581
+ */
582
+ QUALIFIERS unsigned int curand_poisson(curandStateXORWOW_t *state, double lambda)
583
+ {
584
+ if (lambda < 64)
585
+ return curand_poisson_knuth(state, (float)lambda);
586
+ if (lambda > 4000)
587
+ return (unsigned int)((sqrt(lambda) * curand_normal_double(state)) + lambda + 0.5); //Round to nearest
588
+ return curand_poisson_gammainc(state, (float)lambda);
589
+ }
590
+
591
+ /**
592
+ * \brief Return a Poisson-distributed unsigned int from a Philox4_32_10 generator.
593
+ *
594
+ * Return a single unsigned int from a Poisson
595
+ * distribution with lambda \p lambda from the Philox4_32_10 generator in \p state,
596
+ * increment the position of the generator by a variable amount, depending
597
+ * on the algorithm used.
598
+ *
599
+ * \param state - Pointer to state to update
600
+ * \param lambda - Lambda of the Poisson distribution
601
+ *
602
+ * \return Poisson-distributed unsigned int with lambda \p lambda
603
+ */
604
+ QUALIFIERS unsigned int curand_poisson(curandStatePhilox4_32_10_t *state, double lambda)
605
+ {
606
+ if (lambda < 64)
607
+ return curand_poisson_knuth(state, (float)lambda);
608
+ if (lambda > 4000)
609
+ return (unsigned int)((sqrt(lambda) * curand_normal_double(state)) + lambda + 0.5); //Round to nearest
610
+ return curand_poisson_gammainc(state, (float)lambda);
611
+ }
612
+ /**
613
+ * \brief Return four Poisson-distributed unsigned ints from a Philox4_32_10 generator.
614
+ *
615
+ * Return a four unsigned ints from a Poisson
616
+ * distribution with lambda \p lambda from the Philox4_32_10 generator in \p state,
617
+ * increment the position of the generator by a variable amount, depending
618
+ * on the algorithm used.
619
+ *
620
+ * \param state - Pointer to state to update
621
+ * \param lambda - Lambda of the Poisson distribution
622
+ *
623
+ * \return Poisson-distributed unsigned int with lambda \p lambda
624
+ */
625
+ QUALIFIERS uint4 curand_poisson4(curandStatePhilox4_32_10_t *state, double lambda)
626
+ {
627
+ uint4 result;
628
+ double4 _res;
629
+ if (lambda < 64)
630
+ return curand_poisson_knuth4(state, (float)lambda);
631
+ if (lambda > 4000) {
632
+ _res = curand_normal4_double(state);
633
+ result.x = (unsigned int)((sqrt(lambda) * _res.x) + lambda + 0.5); //Round to nearest
634
+ result.y = (unsigned int)((sqrt(lambda) * _res.y) + lambda + 0.5); //Round to nearest
635
+ result.z = (unsigned int)((sqrt(lambda) * _res.z) + lambda + 0.5); //Round to nearest
636
+ result.w = (unsigned int)((sqrt(lambda) * _res.w) + lambda + 0.5); //Round to nearest
637
+ return result;
638
+ }
639
+ return curand_poisson_gammainc4(state, (float)lambda);
640
+ }
641
+
642
+
643
+
644
+ /**
645
+ * \brief Return a Poisson-distributed unsigned int from a MRG32k3A generator.
646
+ *
647
+ * Return a single unsigned int from a Poisson
648
+ * distribution with lambda \p lambda from the MRG32k3a generator in \p state,
649
+ * increment the position of the generator by a variable amount, depending
650
+ * on the algorithm used.
651
+ *
652
+ * \param state - Pointer to state to update
653
+ * \param lambda - Lambda of the Poisson distribution
654
+ *
655
+ * \return Poisson-distributed unsigned int with lambda \p lambda
656
+ */
657
+ QUALIFIERS unsigned int curand_poisson(curandStateMRG32k3a_t *state, double lambda)
658
+ {
659
+ if (lambda < 64)
660
+ return curand_poisson_knuth(state, (float)lambda);
661
+ if (lambda > 4000)
662
+ return (unsigned int)((sqrt(lambda) * curand_normal_double(state)) + lambda + 0.5); //Round to nearest
663
+ return curand_poisson_gammainc(state, (float)lambda);
664
+ }
665
+
666
+ /**
667
+ * \brief Return a Poisson-distributed unsigned int from a MTGP32 generator.
668
+ *
669
+ * Return a single int from a Poisson
670
+ * distribution with lambda \p lambda from the MTGP32 generator in \p state,
671
+ * increment the position of the generator by one.
672
+ *
673
+ * \param state - Pointer to state to update
674
+ * \param lambda - Lambda of the Poisson distribution
675
+ *
676
+ * \return Poisson-distributed unsigned int with lambda \p lambda
677
+ */
678
+ QUALIFIERS unsigned int curand_poisson(curandStateMtgp32_t *state, double lambda)
679
+ {
680
+ return _curand_poisson(curand(state), lambda);
681
+ }
682
+
683
+ /**
684
+ * \brief Return a Poisson-distributed unsigned int from a Sobol32 generator.
685
+ *
686
+ * Return a single unsigned int from a Poisson
687
+ * distribution with lambda \p lambda from the Sobol32 generator in \p state,
688
+ * increment the position of the generator by one.
689
+ *
690
+ * \param state - Pointer to state to update
691
+ * \param lambda - Lambda of the Poisson distribution
692
+ *
693
+ * \return Poisson-distributed unsigned int with lambda \p lambda
694
+ */
695
+
696
+ QUALIFIERS unsigned int curand_poisson(curandStateSobol32_t *state, double lambda)
697
+ {
698
+ return _curand_poisson(curand(state), lambda);
699
+ }
700
+
701
+ /**
702
+ * \brief Return a Poisson-distributed unsigned int from a scrambled Sobol32 generator.
703
+ *
704
+ * Return a single unsigned int from a Poisson
705
+ * distribution with lambda \p lambda from the scrambled Sobol32 generator in \p state,
706
+ * increment the position of the generator by one.
707
+ *
708
+ * \param state - Pointer to state to update
709
+ * \param lambda - Lambda of the Poisson distribution
710
+ *
711
+ * \return Poisson-distributed unsigned int with lambda \p lambda
712
+ */
713
+ QUALIFIERS unsigned int curand_poisson(curandStateScrambledSobol32_t *state, double lambda)
714
+ {
715
+ return _curand_poisson(curand(state), lambda);
716
+ }
717
+
718
+ /**
719
+ * \brief Return a Poisson-distributed unsigned int from a Sobol64 generator.
720
+ *
721
+ * Return a single unsigned int from a Poisson
722
+ * distribution with lambda \p lambda from the Sobol64 generator in \p state,
723
+ * increment position of generator by one.
724
+ *
725
+ * \param state - Pointer to state to update
726
+ * \param lambda - Lambda of the Poisson distribution
727
+ *
728
+ * \return Poisson-distributed unsigned int with lambda \p lambda
729
+ */
730
+ QUALIFIERS unsigned int curand_poisson(curandStateSobol64_t *state, double lambda)
731
+ {
732
+ return _curand_poisson(curand(state), lambda);
733
+ }
734
+
735
+ /**
736
+ * \brief Return a Poisson-distributed unsigned int from a scrambled Sobol64 generator.
737
+ *
738
+ * Return a single unsigned int from a Poisson
739
+ * distribution with lambda \p lambda from the scrambled Sobol64 generator in \p state,
740
+ * increment position of generator by one.
741
+ *
742
+ * \param state - Pointer to state to update
743
+ * \param lambda - Lambda of the Poisson distribution
744
+ *
745
+ * \return Poisson-distributed unsigned int with lambda \p lambda
746
+ */
747
+ QUALIFIERS unsigned int curand_poisson(curandStateScrambledSobol64_t *state, double lambda)
748
+ {
749
+ return _curand_poisson(curand(state), lambda);
750
+ }
751
+ #endif // !defined(CURAND_POISSON_H_)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/curand/include/curand_uniform.h ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ /* Copyright 2010-2018 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * The source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * The Licensed Deliverables contained herein are PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and are being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+
51
+ #if !defined(CURAND_UNIFORM_H_)
52
+ #define CURAND_UNIFORM_H_
53
+
54
+ /**
55
+ * \defgroup DEVICE Device API
56
+ *
57
+ * @{
58
+ */
59
+
60
+ #ifndef __CUDACC_RTC__
61
+ #include <math.h>
62
+ #endif // __CUDACC_RTC__
63
+
64
+ #include "curand_mrg32k3a.h"
65
+ #include "curand_mtgp32_kernel.h"
66
+ #include "curand_philox4x32_x.h"
67
+
68
+
69
+ QUALIFIERS float _curand_uniform(unsigned int x)
70
+ {
71
+ return x * CURAND_2POW32_INV + (CURAND_2POW32_INV/2.0f);
72
+ }
73
+
74
+ QUALIFIERS float4 _curand_uniform4(uint4 x)
75
+ {
76
+ float4 y;
77
+ y.x = x.x * CURAND_2POW32_INV + (CURAND_2POW32_INV/2.0f);
78
+ y.y = x.y * CURAND_2POW32_INV + (CURAND_2POW32_INV/2.0f);
79
+ y.z = x.z * CURAND_2POW32_INV + (CURAND_2POW32_INV/2.0f);
80
+ y.w = x.w * CURAND_2POW32_INV + (CURAND_2POW32_INV/2.0f);
81
+ return y;
82
+ }
83
+
84
+ QUALIFIERS float _curand_uniform(unsigned long long x)
85
+ {
86
+ unsigned int t;
87
+ t = (unsigned int)(x >> 32);
88
+ return t * CURAND_2POW32_INV + (CURAND_2POW32_INV/2.0f);
89
+ }
90
+
91
+ QUALIFIERS double _curand_uniform_double(unsigned int x)
92
+ {
93
+ return x * CURAND_2POW32_INV_DOUBLE + CURAND_2POW32_INV_DOUBLE;
94
+ }
95
+
96
+ QUALIFIERS double _curand_uniform_double(unsigned long long x)
97
+ {
98
+ return (x >> 11) * CURAND_2POW53_INV_DOUBLE + (CURAND_2POW53_INV_DOUBLE/2.0);
99
+ }
100
+
101
+ QUALIFIERS double _curand_uniform_double_hq(unsigned int x, unsigned int y)
102
+ {
103
+ unsigned long long z = (unsigned long long)x ^
104
+ ((unsigned long long)y << (53 - 32));
105
+ return z * CURAND_2POW53_INV_DOUBLE + (CURAND_2POW53_INV_DOUBLE/2.0);
106
+ }
107
+
108
+ QUALIFIERS float curand_uniform(curandStateTest_t *state)
109
+ {
110
+ return _curand_uniform(curand(state));
111
+ }
112
+
113
+ QUALIFIERS double curand_uniform_double(curandStateTest_t *state)
114
+ {
115
+ return _curand_uniform_double(curand(state));
116
+ }
117
+
118
+ /**
119
+ * \brief Return a uniformly distributed float from an XORWOW generator.
120
+ *
121
+ * Return a uniformly distributed float between \p 0.0f and \p 1.0f
122
+ * from the XORWOW generator in \p state, increment position of generator.
123
+ * Output range excludes \p 0.0f but includes \p 1.0f. Denormalized floating
124
+ * point outputs are never returned.
125
+ *
126
+ * The implementation may use any number of calls to \p curand() to
127
+ * get enough random bits to create the return value. The current
128
+ * implementation uses one call.
129
+ *
130
+ * \param state - Pointer to state to update
131
+ *
132
+ * \return uniformly distributed float between \p 0.0f and \p 1.0f
133
+ */
134
+ QUALIFIERS float curand_uniform(curandStateXORWOW_t *state)
135
+ {
136
+ return _curand_uniform(curand(state));
137
+ }
138
+
139
+ /**
140
+ * \brief Return a uniformly distributed double from an XORWOW generator.
141
+ *
142
+ * Return a uniformly distributed double between \p 0.0 and \p 1.0
143
+ * from the XORWOW generator in \p state, increment position of generator.
144
+ * Output range excludes \p 0.0 but includes \p 1.0. Denormalized floating
145
+ * point outputs are never returned.
146
+ *
147
+ * The implementation may use any number of calls to \p curand() to
148
+ * get enough random bits to create the return value. The current
149
+ * implementation uses exactly two calls.
150
+ *
151
+ * \param state - Pointer to state to update
152
+ *
153
+ * \return uniformly distributed double between \p 0.0 and \p 1.0
154
+ */
155
+ QUALIFIERS double curand_uniform_double(curandStateXORWOW_t *state)
156
+ {
157
+ unsigned int x, y;
158
+ x = curand(state);
159
+ y = curand(state);
160
+ return _curand_uniform_double_hq(x, y);
161
+ }
162
+ /**
163
+ * \brief Return a uniformly distributed float from an MRG32k3a generator.
164
+ *
165
+ * Return a uniformly distributed float between \p 0.0f and \p 1.0f
166
+ * from the MRG32k3a generator in \p state, increment position of generator.
167
+ * Output range excludes \p 0.0f but includes \p 1.0f. Denormalized floating
168
+ * point outputs are never returned.
169
+ *
170
+ * The implementation returns up to 23 bits of mantissa, with the minimum
171
+ * return value \f$ 2^{-32} \f$
172
+ *
173
+ * \param state - Pointer to state to update
174
+ *
175
+ * \return uniformly distributed float between \p 0.0f and \p 1.0f
176
+ */
177
+ QUALIFIERS float curand_uniform(curandStateMRG32k3a_t *state)
178
+ {
179
+ return ((float)(curand_MRG32k3a(state)*MRG32K3A_NORM));
180
+ }
181
+
182
+ /**
183
+ * \brief Return a uniformly distributed double from an MRG32k3a generator.
184
+ *
185
+ * Return a uniformly distributed double between \p 0.0 and \p 1.0
186
+ * from the MRG32k3a generator in \p state, increment position of generator.
187
+ * Output range excludes \p 0.0 but includes \p 1.0. Denormalized floating
188
+ * point outputs are never returned.
189
+ *
190
+ * Note the implementation returns at most 32 random bits of mantissa as
191
+ * outlined in the seminal paper by L'Ecuyer.
192
+ *
193
+ * \param state - Pointer to state to update
194
+ *
195
+ * \return uniformly distributed double between \p 0.0 and \p 1.0
196
+ */
197
+ QUALIFIERS double curand_uniform_double(curandStateMRG32k3a_t *state)
198
+ {
199
+ return curand_MRG32k3a(state)*MRG32K3A_NORM;
200
+ }
201
+
202
+
203
+
204
+ /**
205
+ * \brief Return a uniformly distributed tuple of 2 doubles from an Philox4_32_10 generator.
206
+ *
207
+ * Return a uniformly distributed 2 doubles (double4) between \p 0.0 and \p 1.0
208
+ * from the Philox4_32_10 generator in \p state, increment position of generator by 4.
209
+ * Output range excludes \p 0.0 but includes \p 1.0. Denormalized floating
210
+ * point outputs are never returned.
211
+ *
212
+ * \param state - Pointer to state to update
213
+ *
214
+ * \return 2 uniformly distributed doubles between \p 0.0 and \p 1.0
215
+ */
216
+
217
+ QUALIFIERS double2 curand_uniform2_double(curandStatePhilox4_32_10_t *state)
218
+ {
219
+ uint4 _x;
220
+ double2 result;
221
+ _x = curand4(state);
222
+ result.x = _curand_uniform_double_hq(_x.x,_x.y);
223
+ result.y = _curand_uniform_double_hq(_x.z,_x.w);
224
+ return result;
225
+ }
226
+
227
+
228
+ // not a part of API
229
+ QUALIFIERS double4 curand_uniform4_double(curandStatePhilox4_32_10_t *state)
230
+ {
231
+ uint4 _x, _y;
232
+ double4 result;
233
+ _x = curand4(state);
234
+ _y = curand4(state);
235
+ result.x = _curand_uniform_double_hq(_x.x,_x.y);
236
+ result.y = _curand_uniform_double_hq(_x.z,_x.w);
237
+ result.z = _curand_uniform_double_hq(_y.x,_y.y);
238
+ result.w = _curand_uniform_double_hq(_y.z,_y.w);
239
+ return result;
240
+ }
241
+
242
+ /**
243
+ * \brief Return a uniformly distributed float from a Philox4_32_10 generator.
244
+ *
245
+ * Return a uniformly distributed float between \p 0.0f and \p 1.0f
246
+ * from the Philox4_32_10 generator in \p state, increment position of generator.
247
+ * Output range excludes \p 0.0f but includes \p 1.0f. Denormalized floating
248
+ * point outputs are never returned.
249
+ *
250
+ * \param state - Pointer to state to update
251
+ *
252
+ * \return uniformly distributed float between \p 0.0 and \p 1.0
253
+ *
254
+ */
255
+ QUALIFIERS float curand_uniform(curandStatePhilox4_32_10_t *state)
256
+ {
257
+ return _curand_uniform(curand(state));
258
+ }
259
+
260
+ /**
261
+ * \brief Return a uniformly distributed tuple of 4 floats from a Philox4_32_10 generator.
262
+ *
263
+ * Return a uniformly distributed 4 floats between \p 0.0f and \p 1.0f
264
+ * from the Philox4_32_10 generator in \p state, increment position of generator by 4.
265
+ * Output range excludes \p 0.0f but includes \p 1.0f. Denormalized floating
266
+ * point outputs are never returned.
267
+ *
268
+ * \param state - Pointer to state to update
269
+ *
270
+ * \return uniformly distributed float between \p 0.0 and \p 1.0
271
+ *
272
+ */
273
+ QUALIFIERS float4 curand_uniform4(curandStatePhilox4_32_10_t *state)
274
+ {
275
+ return _curand_uniform4(curand4(state));
276
+ }
277
+
278
+ /**
279
+ * \brief Return a uniformly distributed float from a MTGP32 generator.
280
+ *
281
+ * Return a uniformly distributed float between \p 0.0f and \p 1.0f
282
+ * from the MTGP32 generator in \p state, increment position of generator.
283
+ * Output range excludes \p 0.0f but includes \p 1.0f. Denormalized floating
284
+ * point outputs are never returned.
285
+ *
286
+ * \param state - Pointer to state to update
287
+ *
288
+ * \return uniformly distributed float between \p 0.0f and \p 1.0f
289
+ */
290
+ QUALIFIERS float curand_uniform(curandStateMtgp32_t *state)
291
+ {
292
+ return _curand_uniform(curand(state));
293
+ }
294
+ /**
295
+ * \brief Return a uniformly distributed double from a MTGP32 generator.
296
+ *
297
+ * Return a uniformly distributed double between \p 0.0f and \p 1.0f
298
+ * from the MTGP32 generator in \p state, increment position of generator.
299
+ * Output range excludes \p 0.0f but includes \p 1.0f. Denormalized floating
300
+ * point outputs are never returned.
301
+ *
302
+ * Note that the implementation uses only 32 random bits to generate a single double
303
+ * precision value.
304
+ *
305
+ * \param state - Pointer to state to update
306
+ *
307
+ * \return uniformly distributed double between \p 0.0f and \p 1.0f
308
+ */
309
+ QUALIFIERS double curand_uniform_double(curandStateMtgp32_t *state)
310
+ {
311
+ return _curand_uniform_double(curand(state));
312
+ }
313
+
314
+ /**
315
+ * \brief Return a uniformly distributed double from a Philox4_32_10 generator.
316
+ *
317
+ * Return a uniformly distributed double between \p 0.0f and \p 1.0f
318
+ * from the Philox4_32_10 generator in \p state, increment position of generator.
319
+ * Output range excludes \p 0.0f but includes \p 1.0f. Denormalized floating
320
+ * point outputs are never returned.
321
+ *
322
+ * Note that the implementation uses only 32 random bits to generate a single double
323
+ * precision value.
324
+ *
325
+ * \p curand_uniform2_double() is recommended for higher quality uniformly distributed
326
+ * double precision values.
327
+ *
328
+ * \param state - Pointer to state to update
329
+ *
330
+ * \return uniformly distributed double between \p 0.0f and \p 1.0f
331
+ */
332
+
333
+ QUALIFIERS double curand_uniform_double(curandStatePhilox4_32_10_t *state)
334
+ {
335
+ return _curand_uniform_double(curand(state));
336
+ }
337
+
338
+
339
+ /**
340
+ * \brief Return a uniformly distributed float from a Sobol32 generator.
341
+ *
342
+ * Return a uniformly distributed float between \p 0.0f and \p 1.0f
343
+ * from the Sobol32 generator in \p state, increment position of generator.
344
+ * Output range excludes \p 0.0f but includes \p 1.0f. Denormalized floating
345
+ * point outputs are never returned.
346
+ *
347
+ * The implementation is guaranteed to use a single call to \p curand().
348
+ *
349
+ * \param state - Pointer to state to update
350
+ *
351
+ * \return uniformly distributed float between \p 0.0f and \p 1.0f
352
+ */
353
+ QUALIFIERS float curand_uniform(curandStateSobol32_t *state)
354
+ {
355
+ return _curand_uniform(curand(state));
356
+ }
357
+
358
+ /**
359
+ * \brief Return a uniformly distributed double from a Sobol32 generator.
360
+ *
361
+ * Return a uniformly distributed double between \p 0.0 and \p 1.0
362
+ * from the Sobol32 generator in \p state, increment position of generator.
363
+ * Output range excludes \p 0.0 but includes \p 1.0. Denormalized floating
364
+ * point outputs are never returned.
365
+ *
366
+ * The implementation is guaranteed to use a single call to \p curand()
367
+ * to preserve the quasirandom properties of the sequence.
368
+ *
369
+ * Note that the implementation uses only 32 random bits to generate a single double
370
+ * precision value.
371
+ *
372
+ * \param state - Pointer to state to update
373
+ *
374
+ * \return uniformly distributed double between \p 0.0 and \p 1.0
375
+ */
376
+ QUALIFIERS double curand_uniform_double(curandStateSobol32_t *state)
377
+ {
378
+ return _curand_uniform_double(curand(state));
379
+ }
380
+ /**
381
+ * \brief Return a uniformly distributed float from a scrambled Sobol32 generator.
382
+ *
383
+ * Return a uniformly distributed float between \p 0.0f and \p 1.0f
384
+ * from the scrambled Sobol32 generator in \p state, increment position of generator.
385
+ * Output range excludes \p 0.0f but includes \p 1.0f. Denormalized floating
386
+ * point outputs are never returned.
387
+ *
388
+ * The implementation is guaranteed to use a single call to \p curand().
389
+ *
390
+ * \param state - Pointer to state to update
391
+ *
392
+ * \return uniformly distributed float between \p 0.0f and \p 1.0f
393
+ */
394
+ QUALIFIERS float curand_uniform(curandStateScrambledSobol32_t *state)
395
+ {
396
+ return _curand_uniform(curand(state));
397
+ }
398
+
399
+ /**
400
+ * \brief Return a uniformly distributed double from a scrambled Sobol32 generator.
401
+ *
402
+ * Return a uniformly distributed double between \p 0.0 and \p 1.0
403
+ * from the scrambled Sobol32 generator in \p state, increment position of generator.
404
+ * Output range excludes \p 0.0 but includes \p 1.0. Denormalized floating
405
+ * point outputs are never returned.
406
+ *
407
+ * The implementation is guaranteed to use a single call to \p curand()
408
+ * to preserve the quasirandom properties of the sequence.
409
+ *
410
+ * Note that the implementation uses only 32 random bits to generate a single double
411
+ * precision value.
412
+ *
413
+ * \param state - Pointer to state to update
414
+ *
415
+ * \return uniformly distributed double between \p 0.0 and \p 1.0
416
+ */
417
+ QUALIFIERS double curand_uniform_double(curandStateScrambledSobol32_t *state)
418
+ {
419
+ return _curand_uniform_double(curand(state));
420
+ }
421
+ /**
422
+ * \brief Return a uniformly distributed float from a Sobol64 generator.
423
+ *
424
+ * Return a uniformly distributed float between \p 0.0f and \p 1.0f
425
+ * from the Sobol64 generator in \p state, increment position of generator.
426
+ * Output range excludes \p 0.0f but includes \p 1.0f. Denormalized floating
427
+ * point outputs are never returned.
428
+ *
429
+ * The implementation is guaranteed to use a single call to \p curand().
430
+ *
431
+ * \param state - Pointer to state to update
432
+ *
433
+ * \return uniformly distributed float between \p 0.0f and \p 1.0f
434
+ */
435
+ QUALIFIERS float curand_uniform(curandStateSobol64_t *state)
436
+ {
437
+ return _curand_uniform(curand(state));
438
+ }
439
+
440
+ /**
441
+ * \brief Return a uniformly distributed double from a Sobol64 generator.
442
+ *
443
+ * Return a uniformly distributed double between \p 0.0 and \p 1.0
444
+ * from the Sobol64 generator in \p state, increment position of generator.
445
+ * Output range excludes \p 0.0 but includes \p 1.0. Denormalized floating
446
+ * point outputs are never returned.
447
+ *
448
+ * The implementation is guaranteed to use a single call to \p curand()
449
+ * to preserve the quasirandom properties of the sequence.
450
+ *
451
+ * \param state - Pointer to state to update
452
+ *
453
+ * \return uniformly distributed double between \p 0.0 and \p 1.0
454
+ */
455
+ QUALIFIERS double curand_uniform_double(curandStateSobol64_t *state)
456
+ {
457
+ return _curand_uniform_double(curand(state));
458
+ }
459
+ /**
460
+ * \brief Return a uniformly distributed float from a scrambled Sobol64 generator.
461
+ *
462
+ * Return a uniformly distributed float between \p 0.0f and \p 1.0f
463
+ * from the scrambled Sobol64 generator in \p state, increment position of generator.
464
+ * Output range excludes \p 0.0f but includes \p 1.0f. Denormalized floating
465
+ * point outputs are never returned.
466
+ *
467
+ * The implementation is guaranteed to use a single call to \p curand().
468
+ *
469
+ * \param state - Pointer to state to update
470
+ *
471
+ * \return uniformly distributed float between \p 0.0f and \p 1.0f
472
+ */
473
+ QUALIFIERS float curand_uniform(curandStateScrambledSobol64_t *state)
474
+ {
475
+ return _curand_uniform(curand(state));
476
+ }
477
+
478
+ /**
479
+ * \brief Return a uniformly distributed double from a scrambled Sobol64 generator.
480
+ *
481
+ * Return a uniformly distributed double between \p 0.0 and \p 1.0
482
+ * from the scrambled Sobol64 generator in \p state, increment position of generator.
483
+ * Output range excludes \p 0.0 but includes \p 1.0. Denormalized floating
484
+ * point outputs are never returned.
485
+ *
486
+ * The implementation is guaranteed to use a single call to \p curand()
487
+ * to preserve the quasirandom properties of the sequence.
488
+ *
489
+ * \param state - Pointer to state to update
490
+ *
491
+ * \return uniformly distributed double between \p 0.0 and \p 1.0
492
+ */
493
+ QUALIFIERS double curand_uniform_double(curandStateScrambledSobol64_t *state)
494
+ {
495
+ return _curand_uniform_double(curand(state));
496
+ }
497
+
498
+ #endif // !defined(CURAND_UNIFORM_H_)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cusolver/include/cusolverSp.h ADDED
@@ -0,0 +1,923 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2014 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ #if !defined(CUSOLVERSP_H_)
51
+ #define CUSOLVERSP_H_
52
+
53
+ #include "cusparse.h"
54
+ #include "cublas_v2.h"
55
+ #include "cusolver_common.h"
56
+
57
+ #if defined(__cplusplus)
58
+ extern "C" {
59
+ #endif /* __cplusplus */
60
+
61
+ struct cusolverSpContext;
62
+ typedef struct cusolverSpContext *cusolverSpHandle_t;
63
+
64
+ struct csrqrInfo;
65
+ typedef struct csrqrInfo *csrqrInfo_t;
66
+
67
+ cusolverStatus_t CUSOLVERAPI cusolverSpCreate(cusolverSpHandle_t *handle);
68
+ cusolverStatus_t CUSOLVERAPI cusolverSpDestroy(cusolverSpHandle_t handle);
69
+ cusolverStatus_t CUSOLVERAPI
70
+ cusolverSpSetStream(cusolverSpHandle_t handle, cudaStream_t streamId);
71
+ cusolverStatus_t CUSOLVERAPI
72
+ cusolverSpGetStream(cusolverSpHandle_t handle, cudaStream_t *streamId);
73
+
74
+ cusolverStatus_t CUSOLVERAPI cusolverSpXcsrissymHost(
75
+ cusolverSpHandle_t handle,
76
+ int m,
77
+ int nnzA,
78
+ const cusparseMatDescr_t descrA,
79
+ const int * csrRowPtrA,
80
+ const int * csrEndPtrA,
81
+ const int * csrColIndA,
82
+ int * issym);
83
+
84
+ /* -------- GPU linear solver by LU factorization
85
+ * solve A*x = b, A can be singular
86
+ * [ls] stands for linear solve
87
+ * [v] stands for vector
88
+ * [lu] stands for LU factorization
89
+ */
90
+ cusolverStatus_t CUSOLVERAPI cusolverSpScsrlsvluHost(
91
+ cusolverSpHandle_t handle,
92
+ int n,
93
+ int nnzA,
94
+ const cusparseMatDescr_t descrA,
95
+ const float * csrValA,
96
+ const int * csrRowPtrA,
97
+ const int * csrColIndA,
98
+ const float * b,
99
+ float tol,
100
+ int reorder,
101
+ float * x,
102
+ int * singularity);
103
+
104
+ cusolverStatus_t CUSOLVERAPI cusolverSpDcsrlsvluHost(
105
+ cusolverSpHandle_t handle,
106
+ int n,
107
+ int nnzA,
108
+ const cusparseMatDescr_t descrA,
109
+ const double * csrValA,
110
+ const int * csrRowPtrA,
111
+ const int * csrColIndA,
112
+ const double * b,
113
+ double tol,
114
+ int reorder,
115
+ double * x,
116
+ int * singularity);
117
+
118
+ cusolverStatus_t CUSOLVERAPI cusolverSpCcsrlsvluHost(
119
+ cusolverSpHandle_t handle,
120
+ int n,
121
+ int nnzA,
122
+ const cusparseMatDescr_t descrA,
123
+ const cuComplex * csrValA,
124
+ const int * csrRowPtrA,
125
+ const int * csrColIndA,
126
+ const cuComplex * b,
127
+ float tol,
128
+ int reorder,
129
+ cuComplex * x,
130
+ int * singularity);
131
+
132
+ cusolverStatus_t CUSOLVERAPI cusolverSpZcsrlsvluHost(
133
+ cusolverSpHandle_t handle,
134
+ int n,
135
+ int nnzA,
136
+ const cusparseMatDescr_t descrA,
137
+ const cuDoubleComplex * csrValA,
138
+ const int * csrRowPtrA,
139
+ const int * csrColIndA,
140
+ const cuDoubleComplex * b,
141
+ double tol,
142
+ int reorder,
143
+ cuDoubleComplex * x,
144
+ int * singularity);
145
+
146
+ /* -------- GPU linear solver by QR factorization
147
+ * solve A*x = b, A can be singular
148
+ * [ls] stands for linear solve
149
+ * [v] stands for vector
150
+ * [qr] stands for QR factorization
151
+ */
152
+ cusolverStatus_t CUSOLVERAPI cusolverSpScsrlsvqr(
153
+ cusolverSpHandle_t handle,
154
+ int m,
155
+ int nnz,
156
+ const cusparseMatDescr_t descrA,
157
+ const float * csrVal,
158
+ const int * csrRowPtr,
159
+ const int * csrColInd,
160
+ const float * b,
161
+ float tol,
162
+ int reorder,
163
+ float * x,
164
+ int * singularity);
165
+
166
+ cusolverStatus_t CUSOLVERAPI cusolverSpDcsrlsvqr(
167
+ cusolverSpHandle_t handle,
168
+ int m,
169
+ int nnz,
170
+ const cusparseMatDescr_t descrA,
171
+ const double * csrVal,
172
+ const int * csrRowPtr,
173
+ const int * csrColInd,
174
+ const double * b,
175
+ double tol,
176
+ int reorder,
177
+ double * x,
178
+ int * singularity);
179
+
180
+ cusolverStatus_t CUSOLVERAPI cusolverSpCcsrlsvqr(
181
+ cusolverSpHandle_t handle,
182
+ int m,
183
+ int nnz,
184
+ const cusparseMatDescr_t descrA,
185
+ const cuComplex * csrVal,
186
+ const int * csrRowPtr,
187
+ const int * csrColInd,
188
+ const cuComplex * b,
189
+ float tol,
190
+ int reorder,
191
+ cuComplex * x,
192
+ int * singularity);
193
+
194
+ cusolverStatus_t CUSOLVERAPI cusolverSpZcsrlsvqr(
195
+ cusolverSpHandle_t handle,
196
+ int m,
197
+ int nnz,
198
+ const cusparseMatDescr_t descrA,
199
+ const cuDoubleComplex * csrVal,
200
+ const int * csrRowPtr,
201
+ const int * csrColInd,
202
+ const cuDoubleComplex * b,
203
+ double tol,
204
+ int reorder,
205
+ cuDoubleComplex * x,
206
+ int * singularity);
207
+
208
+ /* -------- CPU linear solver by QR factorization
209
+ * solve A*x = b, A can be singular
210
+ * [ls] stands for linear solve
211
+ * [v] stands for vector
212
+ * [qr] stands for QR factorization
213
+ */
214
+ cusolverStatus_t CUSOLVERAPI cusolverSpScsrlsvqrHost(
215
+ cusolverSpHandle_t handle,
216
+ int m,
217
+ int nnz,
218
+ const cusparseMatDescr_t descrA,
219
+ const float * csrValA,
220
+ const int * csrRowPtrA,
221
+ const int * csrColIndA,
222
+ const float * b,
223
+ float tol,
224
+ int reorder,
225
+ float * x,
226
+ int * singularity);
227
+
228
+ cusolverStatus_t CUSOLVERAPI cusolverSpDcsrlsvqrHost(
229
+ cusolverSpHandle_t handle,
230
+ int m,
231
+ int nnz,
232
+ const cusparseMatDescr_t descrA,
233
+ const double * csrValA,
234
+ const int * csrRowPtrA,
235
+ const int * csrColIndA,
236
+ const double * b,
237
+ double tol,
238
+ int reorder,
239
+ double * x,
240
+ int * singularity);
241
+
242
+ cusolverStatus_t CUSOLVERAPI cusolverSpCcsrlsvqrHost(
243
+ cusolverSpHandle_t handle,
244
+ int m,
245
+ int nnz,
246
+ const cusparseMatDescr_t descrA,
247
+ const cuComplex * csrValA,
248
+ const int * csrRowPtrA,
249
+ const int * csrColIndA,
250
+ const cuComplex * b,
251
+ float tol,
252
+ int reorder,
253
+ cuComplex * x,
254
+ int * singularity);
255
+
256
+ cusolverStatus_t CUSOLVERAPI cusolverSpZcsrlsvqrHost(
257
+ cusolverSpHandle_t handle,
258
+ int m,
259
+ int nnz,
260
+ const cusparseMatDescr_t descrA,
261
+ const cuDoubleComplex * csrValA,
262
+ const int * csrRowPtrA,
263
+ const int * csrColIndA,
264
+ const cuDoubleComplex * b,
265
+ double tol,
266
+ int reorder,
267
+ cuDoubleComplex * x,
268
+ int * singularity);
269
+
270
+ /* -------- CPU linear solver by Cholesky factorization
271
+ * solve A*x = b, A can be singular
272
+ * [ls] stands for linear solve
273
+ * [v] stands for vector
274
+ * [chol] stands for Cholesky factorization
275
+ *
276
+ * Only works for symmetric positive definite matrix.
277
+ * The upper part of A is ignored.
278
+ */
279
+ cusolverStatus_t CUSOLVERAPI cusolverSpScsrlsvcholHost(
280
+ cusolverSpHandle_t handle,
281
+ int m,
282
+ int nnz,
283
+ const cusparseMatDescr_t descrA,
284
+ const float * csrVal,
285
+ const int * csrRowPtr,
286
+ const int * csrColInd,
287
+ const float * b,
288
+ float tol,
289
+ int reorder,
290
+ float * x,
291
+ int * singularity);
292
+
293
+ cusolverStatus_t CUSOLVERAPI cusolverSpDcsrlsvcholHost(
294
+ cusolverSpHandle_t handle,
295
+ int m,
296
+ int nnz,
297
+ const cusparseMatDescr_t descrA,
298
+ const double * csrVal,
299
+ const int * csrRowPtr,
300
+ const int * csrColInd,
301
+ const double * b,
302
+ double tol,
303
+ int reorder,
304
+ double * x,
305
+ int * singularity);
306
+
307
+ cusolverStatus_t CUSOLVERAPI cusolverSpCcsrlsvcholHost(
308
+ cusolverSpHandle_t handle,
309
+ int m,
310
+ int nnz,
311
+ const cusparseMatDescr_t descrA,
312
+ const cuComplex * csrVal,
313
+ const int * csrRowPtr,
314
+ const int * csrColInd,
315
+ const cuComplex * b,
316
+ float tol,
317
+ int reorder,
318
+ cuComplex * x,
319
+ int * singularity);
320
+
321
+ cusolverStatus_t CUSOLVERAPI cusolverSpZcsrlsvcholHost(
322
+ cusolverSpHandle_t handle,
323
+ int m,
324
+ int nnz,
325
+ const cusparseMatDescr_t descrA,
326
+ const cuDoubleComplex * csrVal,
327
+ const int * csrRowPtr,
328
+ const int * csrColInd,
329
+ const cuDoubleComplex * b,
330
+ double tol,
331
+ int reorder,
332
+ cuDoubleComplex * x,
333
+ int * singularity);
334
+
335
+ /* -------- GPU linear solver by Cholesky factorization
336
+ * solve A*x = b, A can be singular
337
+ * [ls] stands for linear solve
338
+ * [v] stands for vector
339
+ * [chol] stands for Cholesky factorization
340
+ *
341
+ * Only works for symmetric positive definite matrix.
342
+ * The upper part of A is ignored.
343
+ */
344
+ cusolverStatus_t CUSOLVERAPI cusolverSpScsrlsvchol(
345
+ cusolverSpHandle_t handle,
346
+ int m,
347
+ int nnz,
348
+ const cusparseMatDescr_t descrA,
349
+ const float * csrVal,
350
+ const int * csrRowPtr,
351
+ const int * csrColInd,
352
+ const float * b,
353
+ float tol,
354
+ int reorder,
355
+ // output
356
+ float *x,
357
+ int * singularity);
358
+
359
+ cusolverStatus_t CUSOLVERAPI cusolverSpDcsrlsvchol(
360
+ cusolverSpHandle_t handle,
361
+ int m,
362
+ int nnz,
363
+ const cusparseMatDescr_t descrA,
364
+ const double * csrVal,
365
+ const int * csrRowPtr,
366
+ const int * csrColInd,
367
+ const double * b,
368
+ double tol,
369
+ int reorder,
370
+ // output
371
+ double *x,
372
+ int * singularity);
373
+
374
+ cusolverStatus_t CUSOLVERAPI cusolverSpCcsrlsvchol(
375
+ cusolverSpHandle_t handle,
376
+ int m,
377
+ int nnz,
378
+ const cusparseMatDescr_t descrA,
379
+ const cuComplex * csrVal,
380
+ const int * csrRowPtr,
381
+ const int * csrColInd,
382
+ const cuComplex * b,
383
+ float tol,
384
+ int reorder,
385
+ // output
386
+ cuComplex *x,
387
+ int * singularity);
388
+
389
+ cusolverStatus_t CUSOLVERAPI cusolverSpZcsrlsvchol(
390
+ cusolverSpHandle_t handle,
391
+ int m,
392
+ int nnz,
393
+ const cusparseMatDescr_t descrA,
394
+ const cuDoubleComplex * csrVal,
395
+ const int * csrRowPtr,
396
+ const int * csrColInd,
397
+ const cuDoubleComplex * b,
398
+ double tol,
399
+ int reorder,
400
+ // output
401
+ cuDoubleComplex *x,
402
+ int * singularity);
403
+
404
+ /* ----------- CPU least square solver by QR factorization
405
+ * solve min|b - A*x|
406
+ * [lsq] stands for least square
407
+ * [v] stands for vector
408
+ * [qr] stands for QR factorization
409
+ */
410
+ cusolverStatus_t CUSOLVERAPI cusolverSpScsrlsqvqrHost(
411
+ cusolverSpHandle_t handle,
412
+ int m,
413
+ int n,
414
+ int nnz,
415
+ const cusparseMatDescr_t descrA,
416
+ const float * csrValA,
417
+ const int * csrRowPtrA,
418
+ const int * csrColIndA,
419
+ const float * b,
420
+ float tol,
421
+ int * rankA,
422
+ float * x,
423
+ int * p,
424
+ float * min_norm);
425
+
426
+ cusolverStatus_t CUSOLVERAPI cusolverSpDcsrlsqvqrHost(
427
+ cusolverSpHandle_t handle,
428
+ int m,
429
+ int n,
430
+ int nnz,
431
+ const cusparseMatDescr_t descrA,
432
+ const double * csrValA,
433
+ const int * csrRowPtrA,
434
+ const int * csrColIndA,
435
+ const double * b,
436
+ double tol,
437
+ int * rankA,
438
+ double * x,
439
+ int * p,
440
+ double * min_norm);
441
+
442
+ cusolverStatus_t CUSOLVERAPI cusolverSpCcsrlsqvqrHost(
443
+ cusolverSpHandle_t handle,
444
+ int m,
445
+ int n,
446
+ int nnz,
447
+ const cusparseMatDescr_t descrA,
448
+ const cuComplex * csrValA,
449
+ const int * csrRowPtrA,
450
+ const int * csrColIndA,
451
+ const cuComplex * b,
452
+ float tol,
453
+ int * rankA,
454
+ cuComplex * x,
455
+ int * p,
456
+ float * min_norm);
457
+
458
+ cusolverStatus_t CUSOLVERAPI cusolverSpZcsrlsqvqrHost(
459
+ cusolverSpHandle_t handle,
460
+ int m,
461
+ int n,
462
+ int nnz,
463
+ const cusparseMatDescr_t descrA,
464
+ const cuDoubleComplex * csrValA,
465
+ const int * csrRowPtrA,
466
+ const int * csrColIndA,
467
+ const cuDoubleComplex * b,
468
+ double tol,
469
+ int * rankA,
470
+ cuDoubleComplex * x,
471
+ int * p,
472
+ double * min_norm);
473
+
474
+ /* --------- CPU eigenvalue solver by shift inverse
475
+ * solve A*x = lambda * x
476
+ * where lambda is the eigenvalue nearest mu0.
477
+ * [eig] stands for eigenvalue solver
478
+ * [si] stands for shift-inverse
479
+ */
480
+ cusolverStatus_t CUSOLVERAPI cusolverSpScsreigvsiHost(
481
+ cusolverSpHandle_t handle,
482
+ int m,
483
+ int nnz,
484
+ const cusparseMatDescr_t descrA,
485
+ const float * csrValA,
486
+ const int * csrRowPtrA,
487
+ const int * csrColIndA,
488
+ float mu0,
489
+ const float * x0,
490
+ int maxite,
491
+ float tol,
492
+ float * mu,
493
+ float * x);
494
+
495
+ cusolverStatus_t CUSOLVERAPI cusolverSpDcsreigvsiHost(
496
+ cusolverSpHandle_t handle,
497
+ int m,
498
+ int nnz,
499
+ const cusparseMatDescr_t descrA,
500
+ const double * csrValA,
501
+ const int * csrRowPtrA,
502
+ const int * csrColIndA,
503
+ double mu0,
504
+ const double * x0,
505
+ int maxite,
506
+ double tol,
507
+ double * mu,
508
+ double * x);
509
+
510
+ cusolverStatus_t CUSOLVERAPI cusolverSpCcsreigvsiHost(
511
+ cusolverSpHandle_t handle,
512
+ int m,
513
+ int nnz,
514
+ const cusparseMatDescr_t descrA,
515
+ const cuComplex * csrValA,
516
+ const int * csrRowPtrA,
517
+ const int * csrColIndA,
518
+ cuComplex mu0,
519
+ const cuComplex * x0,
520
+ int maxite,
521
+ float tol,
522
+ cuComplex * mu,
523
+ cuComplex * x);
524
+
525
+ cusolverStatus_t CUSOLVERAPI cusolverSpZcsreigvsiHost(
526
+ cusolverSpHandle_t handle,
527
+ int m,
528
+ int nnz,
529
+ const cusparseMatDescr_t descrA,
530
+ const cuDoubleComplex * csrValA,
531
+ const int * csrRowPtrA,
532
+ const int * csrColIndA,
533
+ cuDoubleComplex mu0,
534
+ const cuDoubleComplex * x0,
535
+ int maxite,
536
+ double tol,
537
+ cuDoubleComplex * mu,
538
+ cuDoubleComplex * x);
539
+
540
+ /* --------- GPU eigenvalue solver by shift inverse
541
+ * solve A*x = lambda * x
542
+ * where lambda is the eigenvalue nearest mu0.
543
+ * [eig] stands for eigenvalue solver
544
+ * [si] stands for shift-inverse
545
+ */
546
+ cusolverStatus_t CUSOLVERAPI cusolverSpScsreigvsi(
547
+ cusolverSpHandle_t handle,
548
+ int m,
549
+ int nnz,
550
+ const cusparseMatDescr_t descrA,
551
+ const float * csrValA,
552
+ const int * csrRowPtrA,
553
+ const int * csrColIndA,
554
+ float mu0,
555
+ const float * x0,
556
+ int maxite,
557
+ float eps,
558
+ float * mu,
559
+ float * x);
560
+
561
+ cusolverStatus_t CUSOLVERAPI cusolverSpDcsreigvsi(
562
+ cusolverSpHandle_t handle,
563
+ int m,
564
+ int nnz,
565
+ const cusparseMatDescr_t descrA,
566
+ const double * csrValA,
567
+ const int * csrRowPtrA,
568
+ const int * csrColIndA,
569
+ double mu0,
570
+ const double * x0,
571
+ int maxite,
572
+ double eps,
573
+ double * mu,
574
+ double * x);
575
+
576
+ cusolverStatus_t CUSOLVERAPI cusolverSpCcsreigvsi(
577
+ cusolverSpHandle_t handle,
578
+ int m,
579
+ int nnz,
580
+ const cusparseMatDescr_t descrA,
581
+ const cuComplex * csrValA,
582
+ const int * csrRowPtrA,
583
+ const int * csrColIndA,
584
+ cuComplex mu0,
585
+ const cuComplex * x0,
586
+ int maxite,
587
+ float eps,
588
+ cuComplex * mu,
589
+ cuComplex * x);
590
+
591
+ cusolverStatus_t CUSOLVERAPI cusolverSpZcsreigvsi(
592
+ cusolverSpHandle_t handle,
593
+ int m,
594
+ int nnz,
595
+ const cusparseMatDescr_t descrA,
596
+ const cuDoubleComplex * csrValA,
597
+ const int * csrRowPtrA,
598
+ const int * csrColIndA,
599
+ cuDoubleComplex mu0,
600
+ const cuDoubleComplex * x0,
601
+ int maxite,
602
+ double eps,
603
+ cuDoubleComplex * mu,
604
+ cuDoubleComplex * x);
605
+
606
+ // ----------- enclosed eigenvalues
607
+
608
+ cusolverStatus_t CUSOLVERAPI cusolverSpScsreigsHost(
609
+ cusolverSpHandle_t handle,
610
+ int m,
611
+ int nnz,
612
+ const cusparseMatDescr_t descrA,
613
+ const float * csrValA,
614
+ const int * csrRowPtrA,
615
+ const int * csrColIndA,
616
+ cuComplex left_bottom_corner,
617
+ cuComplex right_upper_corner,
618
+ int * num_eigs);
619
+
620
+ cusolverStatus_t CUSOLVERAPI cusolverSpDcsreigsHost(
621
+ cusolverSpHandle_t handle,
622
+ int m,
623
+ int nnz,
624
+ const cusparseMatDescr_t descrA,
625
+ const double * csrValA,
626
+ const int * csrRowPtrA,
627
+ const int * csrColIndA,
628
+ cuDoubleComplex left_bottom_corner,
629
+ cuDoubleComplex right_upper_corner,
630
+ int * num_eigs);
631
+
632
+ cusolverStatus_t CUSOLVERAPI cusolverSpCcsreigsHost(
633
+ cusolverSpHandle_t handle,
634
+ int m,
635
+ int nnz,
636
+ const cusparseMatDescr_t descrA,
637
+ const cuComplex * csrValA,
638
+ const int * csrRowPtrA,
639
+ const int * csrColIndA,
640
+ cuComplex left_bottom_corner,
641
+ cuComplex right_upper_corner,
642
+ int * num_eigs);
643
+
644
+ cusolverStatus_t CUSOLVERAPI cusolverSpZcsreigsHost(
645
+ cusolverSpHandle_t handle,
646
+ int m,
647
+ int nnz,
648
+ const cusparseMatDescr_t descrA,
649
+ const cuDoubleComplex * csrValA,
650
+ const int * csrRowPtrA,
651
+ const int * csrColIndA,
652
+ cuDoubleComplex left_bottom_corner,
653
+ cuDoubleComplex right_upper_corner,
654
+ int * num_eigs);
655
+
656
+ /* --------- CPU symrcm
657
+ * Symmetric reverse Cuthill McKee permutation
658
+ *
659
+ */
660
+ cusolverStatus_t CUSOLVERAPI cusolverSpXcsrsymrcmHost(
661
+ cusolverSpHandle_t handle,
662
+ int n,
663
+ int nnzA,
664
+ const cusparseMatDescr_t descrA,
665
+ const int * csrRowPtrA,
666
+ const int * csrColIndA,
667
+ int * p);
668
+
669
+ /* --------- CPU symmdq
670
+ * Symmetric minimum degree algorithm by quotient graph
671
+ *
672
+ */
673
+ cusolverStatus_t CUSOLVERAPI cusolverSpXcsrsymmdqHost(
674
+ cusolverSpHandle_t handle,
675
+ int n,
676
+ int nnzA,
677
+ const cusparseMatDescr_t descrA,
678
+ const int * csrRowPtrA,
679
+ const int * csrColIndA,
680
+ int * p);
681
+
682
+ /* --------- CPU symmdq
683
+ * Symmetric Approximate minimum degree algorithm by quotient graph
684
+ *
685
+ */
686
+ cusolverStatus_t CUSOLVERAPI cusolverSpXcsrsymamdHost(
687
+ cusolverSpHandle_t handle,
688
+ int n,
689
+ int nnzA,
690
+ const cusparseMatDescr_t descrA,
691
+ const int * csrRowPtrA,
692
+ const int * csrColIndA,
693
+ int * p);
694
+
695
+ /* --------- CPU metis
696
+ * symmetric reordering
697
+ */
698
+ cusolverStatus_t CUSOLVERAPI cusolverSpXcsrmetisndHost(
699
+ cusolverSpHandle_t handle,
700
+ int n,
701
+ int nnzA,
702
+ const cusparseMatDescr_t descrA,
703
+ const int * csrRowPtrA,
704
+ const int * csrColIndA,
705
+ const int64_t * options,
706
+ int * p);
707
+
708
+ /* --------- CPU zfd
709
+ * Zero free diagonal reordering
710
+ */
711
+ cusolverStatus_t CUSOLVERAPI cusolverSpScsrzfdHost(
712
+ cusolverSpHandle_t handle,
713
+ int n,
714
+ int nnz,
715
+ const cusparseMatDescr_t descrA,
716
+ const float * csrValA,
717
+ const int * csrRowPtrA,
718
+ const int * csrColIndA,
719
+ int * P,
720
+ int * numnz);
721
+
722
+ cusolverStatus_t CUSOLVERAPI cusolverSpDcsrzfdHost(
723
+ cusolverSpHandle_t handle,
724
+ int n,
725
+ int nnz,
726
+ const cusparseMatDescr_t descrA,
727
+ const double * csrValA,
728
+ const int * csrRowPtrA,
729
+ const int * csrColIndA,
730
+ int * P,
731
+ int * numnz);
732
+
733
+ cusolverStatus_t CUSOLVERAPI cusolverSpCcsrzfdHost(
734
+ cusolverSpHandle_t handle,
735
+ int n,
736
+ int nnz,
737
+ const cusparseMatDescr_t descrA,
738
+ const cuComplex * csrValA,
739
+ const int * csrRowPtrA,
740
+ const int * csrColIndA,
741
+ int * P,
742
+ int * numnz);
743
+
744
+ cusolverStatus_t CUSOLVERAPI cusolverSpZcsrzfdHost(
745
+ cusolverSpHandle_t handle,
746
+ int n,
747
+ int nnz,
748
+ const cusparseMatDescr_t descrA,
749
+ const cuDoubleComplex * csrValA,
750
+ const int * csrRowPtrA,
751
+ const int * csrColIndA,
752
+ int * P,
753
+ int * numnz);
754
+
755
+ /* --------- CPU permuation
756
+ * P*A*Q^T
757
+ *
758
+ */
759
+ cusolverStatus_t CUSOLVERAPI cusolverSpXcsrperm_bufferSizeHost(
760
+ cusolverSpHandle_t handle,
761
+ int m,
762
+ int n,
763
+ int nnzA,
764
+ const cusparseMatDescr_t descrA,
765
+ const int * csrRowPtrA,
766
+ const int * csrColIndA,
767
+ const int * p,
768
+ const int * q,
769
+ size_t * bufferSizeInBytes);
770
+
771
+ cusolverStatus_t CUSOLVERAPI cusolverSpXcsrpermHost(
772
+ cusolverSpHandle_t handle,
773
+ int m,
774
+ int n,
775
+ int nnzA,
776
+ const cusparseMatDescr_t descrA,
777
+ int * csrRowPtrA,
778
+ int * csrColIndA,
779
+ const int * p,
780
+ const int * q,
781
+ int * map,
782
+ void * pBuffer);
783
+
784
+ /*
785
+ * Low-level API: Batched QR
786
+ *
787
+ */
788
+
789
+ cusolverStatus_t CUSOLVERAPI cusolverSpCreateCsrqrInfo(csrqrInfo_t *info);
790
+
791
+ cusolverStatus_t CUSOLVERAPI cusolverSpDestroyCsrqrInfo(csrqrInfo_t info);
792
+
793
+ cusolverStatus_t CUSOLVERAPI cusolverSpXcsrqrAnalysisBatched(
794
+ cusolverSpHandle_t handle,
795
+ int m,
796
+ int n,
797
+ int nnzA,
798
+ const cusparseMatDescr_t descrA,
799
+ const int * csrRowPtrA,
800
+ const int * csrColIndA,
801
+ csrqrInfo_t info);
802
+
803
+ cusolverStatus_t CUSOLVERAPI cusolverSpScsrqrBufferInfoBatched(
804
+ cusolverSpHandle_t handle,
805
+ int m,
806
+ int n,
807
+ int nnz,
808
+ const cusparseMatDescr_t descrA,
809
+ const float * csrVal,
810
+ const int * csrRowPtr,
811
+ const int * csrColInd,
812
+ int batchSize,
813
+ csrqrInfo_t info,
814
+ size_t * internalDataInBytes,
815
+ size_t * workspaceInBytes);
816
+
817
+ cusolverStatus_t CUSOLVERAPI cusolverSpDcsrqrBufferInfoBatched(
818
+ cusolverSpHandle_t handle,
819
+ int m,
820
+ int n,
821
+ int nnz,
822
+ const cusparseMatDescr_t descrA,
823
+ const double * csrVal,
824
+ const int * csrRowPtr,
825
+ const int * csrColInd,
826
+ int batchSize,
827
+ csrqrInfo_t info,
828
+ size_t * internalDataInBytes,
829
+ size_t * workspaceInBytes);
830
+
831
+ cusolverStatus_t CUSOLVERAPI cusolverSpCcsrqrBufferInfoBatched(
832
+ cusolverSpHandle_t handle,
833
+ int m,
834
+ int n,
835
+ int nnz,
836
+ const cusparseMatDescr_t descrA,
837
+ const cuComplex * csrVal,
838
+ const int * csrRowPtr,
839
+ const int * csrColInd,
840
+ int batchSize,
841
+ csrqrInfo_t info,
842
+ size_t * internalDataInBytes,
843
+ size_t * workspaceInBytes);
844
+
845
+ cusolverStatus_t CUSOLVERAPI cusolverSpZcsrqrBufferInfoBatched(
846
+ cusolverSpHandle_t handle,
847
+ int m,
848
+ int n,
849
+ int nnz,
850
+ const cusparseMatDescr_t descrA,
851
+ const cuDoubleComplex * csrVal,
852
+ const int * csrRowPtr,
853
+ const int * csrColInd,
854
+ int batchSize,
855
+ csrqrInfo_t info,
856
+ size_t * internalDataInBytes,
857
+ size_t * workspaceInBytes);
858
+
859
+ cusolverStatus_t CUSOLVERAPI cusolverSpScsrqrsvBatched(
860
+ cusolverSpHandle_t handle,
861
+ int m,
862
+ int n,
863
+ int nnz,
864
+ const cusparseMatDescr_t descrA,
865
+ const float * csrValA,
866
+ const int * csrRowPtrA,
867
+ const int * csrColIndA,
868
+ const float * b,
869
+ float * x,
870
+ int batchSize,
871
+ csrqrInfo_t info,
872
+ void * pBuffer);
873
+
874
+ cusolverStatus_t CUSOLVERAPI cusolverSpDcsrqrsvBatched(
875
+ cusolverSpHandle_t handle,
876
+ int m,
877
+ int n,
878
+ int nnz,
879
+ const cusparseMatDescr_t descrA,
880
+ const double * csrValA,
881
+ const int * csrRowPtrA,
882
+ const int * csrColIndA,
883
+ const double * b,
884
+ double * x,
885
+ int batchSize,
886
+ csrqrInfo_t info,
887
+ void * pBuffer);
888
+
889
+ cusolverStatus_t CUSOLVERAPI cusolverSpCcsrqrsvBatched(
890
+ cusolverSpHandle_t handle,
891
+ int m,
892
+ int n,
893
+ int nnz,
894
+ const cusparseMatDescr_t descrA,
895
+ const cuComplex * csrValA,
896
+ const int * csrRowPtrA,
897
+ const int * csrColIndA,
898
+ const cuComplex * b,
899
+ cuComplex * x,
900
+ int batchSize,
901
+ csrqrInfo_t info,
902
+ void * pBuffer);
903
+
904
+ cusolverStatus_t CUSOLVERAPI cusolverSpZcsrqrsvBatched(
905
+ cusolverSpHandle_t handle,
906
+ int m,
907
+ int n,
908
+ int nnz,
909
+ const cusparseMatDescr_t descrA,
910
+ const cuDoubleComplex * csrValA,
911
+ const int * csrRowPtrA,
912
+ const int * csrColIndA,
913
+ const cuDoubleComplex * b,
914
+ cuDoubleComplex * x,
915
+ int batchSize,
916
+ csrqrInfo_t info,
917
+ void * pBuffer);
918
+
919
+ #if defined(__cplusplus)
920
+ }
921
+ #endif /* __cplusplus */
922
+
923
+ #endif // define CUSOLVERSP_H_
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/locators.cpython-311.pyc ADDED
Binary file (65.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/markers.cpython-311.pyc ADDED
Binary file (8.54 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/metadata.cpython-311.pyc ADDED
Binary file (47.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/resources.cpython-311.pyc ADDED
Binary file (19 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/scripts.cpython-311.pyc ADDED
Binary file (21.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/util.cpython-311.pyc ADDED
Binary file (98.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/version.cpython-311.pyc ADDED
Binary file (34.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/__pycache__/wheel.cpython-311.pyc ADDED
Binary file (60.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distro/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.22 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distro/py.typed ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/_emoji_codes.py ADDED
The diff for this file is too large to render. See raw diff
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/_wrap.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from typing import Iterable
5
+
6
+ from ._loop import loop_last
7
+ from .cells import cell_len, chop_cells
8
+
9
+ re_word = re.compile(r"\s*\S+\s*")
10
+
11
+
12
+ def words(text: str) -> Iterable[tuple[int, int, str]]:
13
+ """Yields each word from the text as a tuple
14
+ containing (start_index, end_index, word). A "word" in this context may
15
+ include the actual word and any whitespace to the right.
16
+ """
17
+ position = 0
18
+ word_match = re_word.match(text, position)
19
+ while word_match is not None:
20
+ start, end = word_match.span()
21
+ word = word_match.group(0)
22
+ yield start, end, word
23
+ word_match = re_word.match(text, end)
24
+
25
+
26
+ def divide_line(text: str, width: int, fold: bool = True) -> list[int]:
27
+ """Given a string of text, and a width (measured in cells), return a list
28
+ of cell offsets which the string should be split at in order for it to fit
29
+ within the given width.
30
+
31
+ Args:
32
+ text: The text to examine.
33
+ width: The available cell width.
34
+ fold: If True, words longer than `width` will be folded onto a new line.
35
+
36
+ Returns:
37
+ A list of indices to break the line at.
38
+ """
39
+ break_positions: list[int] = [] # offsets to insert the breaks at
40
+ append = break_positions.append
41
+ cell_offset = 0
42
+ _cell_len = cell_len
43
+
44
+ for start, _end, word in words(text):
45
+ word_length = _cell_len(word.rstrip())
46
+ remaining_space = width - cell_offset
47
+ word_fits_remaining_space = remaining_space >= word_length
48
+
49
+ if word_fits_remaining_space:
50
+ # Simplest case - the word fits within the remaining width for this line.
51
+ cell_offset += _cell_len(word)
52
+ else:
53
+ # Not enough space remaining for this word on the current line.
54
+ if word_length > width:
55
+ # The word doesn't fit on any line, so we can't simply
56
+ # place it on the next line...
57
+ if fold:
58
+ # Fold the word across multiple lines.
59
+ folded_word = chop_cells(word, width=width)
60
+ for last, line in loop_last(folded_word):
61
+ if start:
62
+ append(start)
63
+ if last:
64
+ cell_offset = _cell_len(line)
65
+ else:
66
+ start += len(line)
67
+ else:
68
+ # Folding isn't allowed, so crop the word.
69
+ if start:
70
+ append(start)
71
+ cell_offset = _cell_len(word)
72
+ elif cell_offset and start:
73
+ # The word doesn't fit within the remaining space on the current
74
+ # line, but it *can* fit on to the next (empty) line.
75
+ append(start)
76
+ cell_offset = _cell_len(word)
77
+
78
+ return break_positions
79
+
80
+
81
+ if __name__ == "__main__": # pragma: no cover
82
+ from .console import Console
83
+
84
+ console = Console(width=10)
85
+ console.print("12345 abcdefghijklmnopqrstuvwyxzABCDEFGHIJKLMNOPQRSTUVWXYZ 12345")
86
+ print(chop_cells("abcdefghijklmnopqrstuvwxyz", 10))
87
+
88
+ console = Console(width=20)
89
+ console.rule()
90
+ console.print("TextualはPythonの高速アプリケーション開発フレームワークです")
91
+
92
+ console.rule()
93
+ console.print("アプリケーションは1670万色を使用でき")
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/constrain.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, TYPE_CHECKING
2
+
3
+ from .jupyter import JupyterMixin
4
+ from .measure import Measurement
5
+
6
+ if TYPE_CHECKING:
7
+ from .console import Console, ConsoleOptions, RenderableType, RenderResult
8
+
9
+
10
+ class Constrain(JupyterMixin):
11
+ """Constrain the width of a renderable to a given number of characters.
12
+
13
+ Args:
14
+ renderable (RenderableType): A renderable object.
15
+ width (int, optional): The maximum width (in characters) to render. Defaults to 80.
16
+ """
17
+
18
+ def __init__(self, renderable: "RenderableType", width: Optional[int] = 80) -> None:
19
+ self.renderable = renderable
20
+ self.width = width
21
+
22
+ def __rich_console__(
23
+ self, console: "Console", options: "ConsoleOptions"
24
+ ) -> "RenderResult":
25
+ if self.width is None:
26
+ yield self.renderable
27
+ else:
28
+ child_options = options.update_width(min(self.width, options.max_width))
29
+ yield from console.render(self.renderable, child_options)
30
+
31
+ def __rich_measure__(
32
+ self, console: "Console", options: "ConsoleOptions"
33
+ ) -> "Measurement":
34
+ if self.width is not None:
35
+ options = options.update_width(self.width)
36
+ measurement = Measurement.get(console, options, self.renderable)
37
+ return measurement
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/file_proxy.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from typing import IO, TYPE_CHECKING, Any, List
3
+
4
+ from .ansi import AnsiDecoder
5
+ from .text import Text
6
+
7
+ if TYPE_CHECKING:
8
+ from .console import Console
9
+
10
+
11
+ class FileProxy(io.TextIOBase):
12
+ """Wraps a file (e.g. sys.stdout) and redirects writes to a console."""
13
+
14
+ def __init__(self, console: "Console", file: IO[str]) -> None:
15
+ self.__console = console
16
+ self.__file = file
17
+ self.__buffer: List[str] = []
18
+ self.__ansi_decoder = AnsiDecoder()
19
+
20
+ @property
21
+ def rich_proxied_file(self) -> IO[str]:
22
+ """Get proxied file."""
23
+ return self.__file
24
+
25
+ def __getattr__(self, name: str) -> Any:
26
+ return getattr(self.__file, name)
27
+
28
+ def write(self, text: str) -> int:
29
+ if not isinstance(text, str):
30
+ raise TypeError(f"write() argument must be str, not {type(text).__name__}")
31
+ buffer = self.__buffer
32
+ lines: List[str] = []
33
+ while text:
34
+ line, new_line, text = text.partition("\n")
35
+ if new_line:
36
+ lines.append("".join(buffer) + line)
37
+ buffer.clear()
38
+ else:
39
+ buffer.append(line)
40
+ break
41
+ if lines:
42
+ console = self.__console
43
+ with console:
44
+ output = Text("\n").join(
45
+ self.__ansi_decoder.decode_line(line) for line in lines
46
+ )
47
+ console.print(output)
48
+ return len(text)
49
+
50
+ def flush(self) -> None:
51
+ output = "".join(self.__buffer)
52
+ if output:
53
+ self.__console.print(output)
54
+ del self.__buffer[:]
55
+
56
+ def fileno(self) -> int:
57
+ return self.__file.fileno()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/highlighter.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from abc import ABC, abstractmethod
3
+ from typing import List, Union
4
+
5
+ from .text import Span, Text
6
+
7
+
8
+ def _combine_regex(*regexes: str) -> str:
9
+ """Combine a number of regexes in to a single regex.
10
+
11
+ Returns:
12
+ str: New regex with all regexes ORed together.
13
+ """
14
+ return "|".join(regexes)
15
+
16
+
17
+ class Highlighter(ABC):
18
+ """Abstract base class for highlighters."""
19
+
20
+ def __call__(self, text: Union[str, Text]) -> Text:
21
+ """Highlight a str or Text instance.
22
+
23
+ Args:
24
+ text (Union[str, ~Text]): Text to highlight.
25
+
26
+ Raises:
27
+ TypeError: If not called with text or str.
28
+
29
+ Returns:
30
+ Text: A test instance with highlighting applied.
31
+ """
32
+ if isinstance(text, str):
33
+ highlight_text = Text(text)
34
+ elif isinstance(text, Text):
35
+ highlight_text = text.copy()
36
+ else:
37
+ raise TypeError(f"str or Text instance required, not {text!r}")
38
+ self.highlight(highlight_text)
39
+ return highlight_text
40
+
41
+ @abstractmethod
42
+ def highlight(self, text: Text) -> None:
43
+ """Apply highlighting in place to text.
44
+
45
+ Args:
46
+ text (~Text): A text object highlight.
47
+ """
48
+
49
+
50
+ class NullHighlighter(Highlighter):
51
+ """A highlighter object that doesn't highlight.
52
+
53
+ May be used to disable highlighting entirely.
54
+
55
+ """
56
+
57
+ def highlight(self, text: Text) -> None:
58
+ """Nothing to do"""
59
+
60
+
61
+ class RegexHighlighter(Highlighter):
62
+ """Applies highlighting from a list of regular expressions."""
63
+
64
+ highlights: List[str] = []
65
+ base_style: str = ""
66
+
67
+ def highlight(self, text: Text) -> None:
68
+ """Highlight :class:`rich.text.Text` using regular expressions.
69
+
70
+ Args:
71
+ text (~Text): Text to highlighted.
72
+
73
+ """
74
+
75
+ highlight_regex = text.highlight_regex
76
+ for re_highlight in self.highlights:
77
+ highlight_regex(re_highlight, style_prefix=self.base_style)
78
+
79
+
80
+ class ReprHighlighter(RegexHighlighter):
81
+ """Highlights the text typically produced from ``__repr__`` methods."""
82
+
83
+ base_style = "repr."
84
+ highlights = [
85
+ r"(?P<tag_start><)(?P<tag_name>[-\w.:|]*)(?P<tag_contents>[\w\W]*)(?P<tag_end>>)",
86
+ r'(?P<attrib_name>[\w_]{1,50})=(?P<attrib_value>"?[\w_]+"?)?',
87
+ r"(?P<brace>[][{}()])",
88
+ _combine_regex(
89
+ r"(?P<ipv4>[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})",
90
+ r"(?P<ipv6>([A-Fa-f0-9]{1,4}::?){1,7}[A-Fa-f0-9]{1,4})",
91
+ r"(?P<eui64>(?:[0-9A-Fa-f]{1,2}-){7}[0-9A-Fa-f]{1,2}|(?:[0-9A-Fa-f]{1,2}:){7}[0-9A-Fa-f]{1,2}|(?:[0-9A-Fa-f]{4}\.){3}[0-9A-Fa-f]{4})",
92
+ r"(?P<eui48>(?:[0-9A-Fa-f]{1,2}-){5}[0-9A-Fa-f]{1,2}|(?:[0-9A-Fa-f]{1,2}:){5}[0-9A-Fa-f]{1,2}|(?:[0-9A-Fa-f]{4}\.){2}[0-9A-Fa-f]{4})",
93
+ r"(?P<uuid>[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12})",
94
+ r"(?P<call>[\w.]*?)\(",
95
+ r"\b(?P<bool_true>True)\b|\b(?P<bool_false>False)\b|\b(?P<none>None)\b",
96
+ r"(?P<ellipsis>\.\.\.)",
97
+ r"(?P<number_complex>(?<!\w)(?:\-?[0-9]+\.?[0-9]*(?:e[-+]?\d+?)?)(?:[-+](?:[0-9]+\.?[0-9]*(?:e[-+]?\d+)?))?j)",
98
+ r"(?P<number>(?<!\w)\-?[0-9]+\.?[0-9]*(e[-+]?\d+?)?\b|0x[0-9a-fA-F]*)",
99
+ r"(?P<path>\B(/[-\w._+]+)*\/)(?P<filename>[-\w._+]*)?",
100
+ r"(?<![\\\w])(?P<str>b?'''.*?(?<!\\)'''|b?'.*?(?<!\\)'|b?\"\"\".*?(?<!\\)\"\"\"|b?\".*?(?<!\\)\")",
101
+ r"(?P<url>(file|https|http|ws|wss)://[-0-9a-zA-Z$_+!`(),.?/;:&=%#~]*)",
102
+ ),
103
+ ]
104
+
105
+
106
+ class JSONHighlighter(RegexHighlighter):
107
+ """Highlights JSON"""
108
+
109
+ # Captures the start and end of JSON strings, handling escaped quotes
110
+ JSON_STR = r"(?<![\\\w])(?P<str>b?\".*?(?<!\\)\")"
111
+ JSON_WHITESPACE = {" ", "\n", "\r", "\t"}
112
+
113
+ base_style = "json."
114
+ highlights = [
115
+ _combine_regex(
116
+ r"(?P<brace>[\{\[\(\)\]\}])",
117
+ r"\b(?P<bool_true>true)\b|\b(?P<bool_false>false)\b|\b(?P<null>null)\b",
118
+ r"(?P<number>(?<!\w)\-?[0-9]+\.?[0-9]*(e[\-\+]?\d+?)?\b|0x[0-9a-fA-F]*)",
119
+ JSON_STR,
120
+ ),
121
+ ]
122
+
123
+ def highlight(self, text: Text) -> None:
124
+ super().highlight(text)
125
+
126
+ # Additional work to handle highlighting JSON keys
127
+ plain = text.plain
128
+ append = text.spans.append
129
+ whitespace = self.JSON_WHITESPACE
130
+ for match in re.finditer(self.JSON_STR, plain):
131
+ start, end = match.span()
132
+ cursor = end
133
+ while cursor < len(plain):
134
+ char = plain[cursor]
135
+ cursor += 1
136
+ if char == ":":
137
+ append(Span(start, end, "json.key"))
138
+ elif char in whitespace:
139
+ continue
140
+ break
141
+
142
+
143
+ class ISO8601Highlighter(RegexHighlighter):
144
+ """Highlights the ISO8601 date time strings.
145
+ Regex reference: https://www.oreilly.com/library/view/regular-expressions-cookbook/9781449327453/ch04s07.html
146
+ """
147
+
148
+ base_style = "iso8601."
149
+ highlights = [
150
+ #
151
+ # Dates
152
+ #
153
+ # Calendar month (e.g. 2008-08). The hyphen is required
154
+ r"^(?P<year>[0-9]{4})-(?P<month>1[0-2]|0[1-9])$",
155
+ # Calendar date w/o hyphens (e.g. 20080830)
156
+ r"^(?P<date>(?P<year>[0-9]{4})(?P<month>1[0-2]|0[1-9])(?P<day>3[01]|0[1-9]|[12][0-9]))$",
157
+ # Ordinal date (e.g. 2008-243). The hyphen is optional
158
+ r"^(?P<date>(?P<year>[0-9]{4})-?(?P<day>36[0-6]|3[0-5][0-9]|[12][0-9]{2}|0[1-9][0-9]|00[1-9]))$",
159
+ #
160
+ # Weeks
161
+ #
162
+ # Week of the year (e.g., 2008-W35). The hyphen is optional
163
+ r"^(?P<date>(?P<year>[0-9]{4})-?W(?P<week>5[0-3]|[1-4][0-9]|0[1-9]))$",
164
+ # Week date (e.g., 2008-W35-6). The hyphens are optional
165
+ r"^(?P<date>(?P<year>[0-9]{4})-?W(?P<week>5[0-3]|[1-4][0-9]|0[1-9])-?(?P<day>[1-7]))$",
166
+ #
167
+ # Times
168
+ #
169
+ # Hours and minutes (e.g., 17:21). The colon is optional
170
+ r"^(?P<time>(?P<hour>2[0-3]|[01][0-9]):?(?P<minute>[0-5][0-9]))$",
171
+ # Hours, minutes, and seconds w/o colons (e.g., 172159)
172
+ r"^(?P<time>(?P<hour>2[0-3]|[01][0-9])(?P<minute>[0-5][0-9])(?P<second>[0-5][0-9]))$",
173
+ # Time zone designator (e.g., Z, +07 or +07:00). The colons and the minutes are optional
174
+ r"^(?P<timezone>(Z|[+-](?:2[0-3]|[01][0-9])(?::?(?:[0-5][0-9]))?))$",
175
+ # Hours, minutes, and seconds with time zone designator (e.g., 17:21:59+07:00).
176
+ # All the colons are optional. The minutes in the time zone designator are also optional
177
+ r"^(?P<time>(?P<hour>2[0-3]|[01][0-9])(?P<minute>[0-5][0-9])(?P<second>[0-5][0-9]))(?P<timezone>Z|[+-](?:2[0-3]|[01][0-9])(?::?(?:[0-5][0-9]))?)$",
178
+ #
179
+ # Date and Time
180
+ #
181
+ # Calendar date with hours, minutes, and seconds (e.g., 2008-08-30 17:21:59 or 20080830 172159).
182
+ # A space is required between the date and the time. The hyphens and colons are optional.
183
+ # This regex matches dates and times that specify some hyphens or colons but omit others.
184
+ # This does not follow ISO 8601
185
+ r"^(?P<date>(?P<year>[0-9]{4})(?P<hyphen>-)?(?P<month>1[0-2]|0[1-9])(?(hyphen)-)(?P<day>3[01]|0[1-9]|[12][0-9])) (?P<time>(?P<hour>2[0-3]|[01][0-9])(?(hyphen):)(?P<minute>[0-5][0-9])(?(hyphen):)(?P<second>[0-5][0-9]))$",
186
+ #
187
+ # XML Schema dates and times
188
+ #
189
+ # Date, with optional time zone (e.g., 2008-08-30 or 2008-08-30+07:00).
190
+ # Hyphens are required. This is the XML Schema 'date' type
191
+ r"^(?P<date>(?P<year>-?(?:[1-9][0-9]*)?[0-9]{4})-(?P<month>1[0-2]|0[1-9])-(?P<day>3[01]|0[1-9]|[12][0-9]))(?P<timezone>Z|[+-](?:2[0-3]|[01][0-9]):[0-5][0-9])?$",
192
+ # Time, with optional fractional seconds and time zone (e.g., 01:45:36 or 01:45:36.123+07:00).
193
+ # There is no limit on the number of digits for the fractional seconds. This is the XML Schema 'time' type
194
+ r"^(?P<time>(?P<hour>2[0-3]|[01][0-9]):(?P<minute>[0-5][0-9]):(?P<second>[0-5][0-9])(?P<frac>\.[0-9]+)?)(?P<timezone>Z|[+-](?:2[0-3]|[01][0-9]):[0-5][0-9])?$",
195
+ # Date and time, with optional fractional seconds and time zone (e.g., 2008-08-30T01:45:36 or 2008-08-30T01:45:36.123Z).
196
+ # This is the XML Schema 'dateTime' type
197
+ r"^(?P<date>(?P<year>-?(?:[1-9][0-9]*)?[0-9]{4})-(?P<month>1[0-2]|0[1-9])-(?P<day>3[01]|0[1-9]|[12][0-9]))T(?P<time>(?P<hour>2[0-3]|[01][0-9]):(?P<minute>[0-5][0-9]):(?P<second>[0-5][0-9])(?P<ms>\.[0-9]+)?)(?P<timezone>Z|[+-](?:2[0-3]|[01][0-9]):[0-5][0-9])?$",
198
+ ]
199
+
200
+
201
+ if __name__ == "__main__": # pragma: no cover
202
+ from .console import Console
203
+
204
+ console = Console()
205
+ console.print("[bold green]hello world![/bold green]")
206
+ console.print("'[bold green]hello world![/bold green]'")
207
+
208
+ console.print(" /foo")
209
+ console.print("/foo/")
210
+ console.print("/foo/bar")
211
+ console.print("foo/bar/baz")
212
+
213
+ console.print("/foo/bar/baz?foo=bar+egg&egg=baz")
214
+ console.print("/foo/bar/baz/")
215
+ console.print("/foo/bar/baz/egg")
216
+ console.print("/foo/bar/baz/egg.py")
217
+ console.print("/foo/bar/baz/egg.py word")
218
+ console.print(" /foo/bar/baz/egg.py word")
219
+ console.print("foo /foo/bar/baz/egg.py word")
220
+ console.print("foo /foo/bar/ba._++z/egg+.py word")
221
+ console.print("https://example.org?foo=bar#header")
222
+
223
+ console.print(1234567.34)
224
+ console.print(1 / 2)
225
+ console.print(-1 / 123123123123)
226
+
227
+ console.print(
228
+ "127.0.1.1 bar 192.168.1.4 2001:0db8:85a3:0000:0000:8a2e:0370:7334 foo"
229
+ )
230
+ import json
231
+
232
+ console.print_json(json.dumps(obj={"name": "apple", "count": 1}), indent=None)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/json.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from json import loads, dumps
3
+ from typing import Any, Callable, Optional, Union
4
+
5
+ from .text import Text
6
+ from .highlighter import JSONHighlighter, NullHighlighter
7
+
8
+
9
+ class JSON:
10
+ """A renderable which pretty prints JSON.
11
+
12
+ Args:
13
+ json (str): JSON encoded data.
14
+ indent (Union[None, int, str], optional): Number of characters to indent by. Defaults to 2.
15
+ highlight (bool, optional): Enable highlighting. Defaults to True.
16
+ skip_keys (bool, optional): Skip keys not of a basic type. Defaults to False.
17
+ ensure_ascii (bool, optional): Escape all non-ascii characters. Defaults to False.
18
+ check_circular (bool, optional): Check for circular references. Defaults to True.
19
+ allow_nan (bool, optional): Allow NaN and Infinity values. Defaults to True.
20
+ default (Callable, optional): A callable that converts values that can not be encoded
21
+ in to something that can be JSON encoded. Defaults to None.
22
+ sort_keys (bool, optional): Sort dictionary keys. Defaults to False.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ json: str,
28
+ indent: Union[None, int, str] = 2,
29
+ highlight: bool = True,
30
+ skip_keys: bool = False,
31
+ ensure_ascii: bool = False,
32
+ check_circular: bool = True,
33
+ allow_nan: bool = True,
34
+ default: Optional[Callable[[Any], Any]] = None,
35
+ sort_keys: bool = False,
36
+ ) -> None:
37
+ data = loads(json)
38
+ json = dumps(
39
+ data,
40
+ indent=indent,
41
+ skipkeys=skip_keys,
42
+ ensure_ascii=ensure_ascii,
43
+ check_circular=check_circular,
44
+ allow_nan=allow_nan,
45
+ default=default,
46
+ sort_keys=sort_keys,
47
+ )
48
+ highlighter = JSONHighlighter() if highlight else NullHighlighter()
49
+ self.text = highlighter(json)
50
+ self.text.no_wrap = True
51
+ self.text.overflow = None
52
+
53
+ @classmethod
54
+ def from_data(
55
+ cls,
56
+ data: Any,
57
+ indent: Union[None, int, str] = 2,
58
+ highlight: bool = True,
59
+ skip_keys: bool = False,
60
+ ensure_ascii: bool = False,
61
+ check_circular: bool = True,
62
+ allow_nan: bool = True,
63
+ default: Optional[Callable[[Any], Any]] = None,
64
+ sort_keys: bool = False,
65
+ ) -> "JSON":
66
+ """Encodes a JSON object from arbitrary data.
67
+
68
+ Args:
69
+ data (Any): An object that may be encoded in to JSON
70
+ indent (Union[None, int, str], optional): Number of characters to indent by. Defaults to 2.
71
+ highlight (bool, optional): Enable highlighting. Defaults to True.
72
+ default (Callable, optional): Optional callable which will be called for objects that cannot be serialized. Defaults to None.
73
+ skip_keys (bool, optional): Skip keys not of a basic type. Defaults to False.
74
+ ensure_ascii (bool, optional): Escape all non-ascii characters. Defaults to False.
75
+ check_circular (bool, optional): Check for circular references. Defaults to True.
76
+ allow_nan (bool, optional): Allow NaN and Infinity values. Defaults to True.
77
+ default (Callable, optional): A callable that converts values that can not be encoded
78
+ in to something that can be JSON encoded. Defaults to None.
79
+ sort_keys (bool, optional): Sort dictionary keys. Defaults to False.
80
+
81
+ Returns:
82
+ JSON: New JSON object from the given data.
83
+ """
84
+ json_instance: "JSON" = cls.__new__(cls)
85
+ json = dumps(
86
+ data,
87
+ indent=indent,
88
+ skipkeys=skip_keys,
89
+ ensure_ascii=ensure_ascii,
90
+ check_circular=check_circular,
91
+ allow_nan=allow_nan,
92
+ default=default,
93
+ sort_keys=sort_keys,
94
+ )
95
+ highlighter = JSONHighlighter() if highlight else NullHighlighter()
96
+ json_instance.text = highlighter(json)
97
+ json_instance.text.no_wrap = True
98
+ json_instance.text.overflow = None
99
+ return json_instance
100
+
101
+ def __rich__(self) -> Text:
102
+ return self.text
103
+
104
+
105
+ if __name__ == "__main__":
106
+ import argparse
107
+ import sys
108
+
109
+ parser = argparse.ArgumentParser(description="Pretty print json")
110
+ parser.add_argument(
111
+ "path",
112
+ metavar="PATH",
113
+ help="path to file, or - for stdin",
114
+ )
115
+ parser.add_argument(
116
+ "-i",
117
+ "--indent",
118
+ metavar="SPACES",
119
+ type=int,
120
+ help="Number of spaces in an indent",
121
+ default=2,
122
+ )
123
+ args = parser.parse_args()
124
+
125
+ from pip._vendor.rich.console import Console
126
+
127
+ console = Console()
128
+ error_console = Console(stderr=True)
129
+
130
+ try:
131
+ if args.path == "-":
132
+ json_data = sys.stdin.read()
133
+ else:
134
+ json_data = Path(args.path).read_text()
135
+ except Exception as error:
136
+ error_console.print(f"Unable to read {args.path!r}; {error}")
137
+ sys.exit(-1)
138
+
139
+ console.print(JSON(json_data, indent=args.indent), soft_wrap=True)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/layout.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from itertools import islice
3
+ from operator import itemgetter
4
+ from threading import RLock
5
+ from typing import (
6
+ TYPE_CHECKING,
7
+ Dict,
8
+ Iterable,
9
+ List,
10
+ NamedTuple,
11
+ Optional,
12
+ Sequence,
13
+ Tuple,
14
+ Union,
15
+ )
16
+
17
+ from ._ratio import ratio_resolve
18
+ from .align import Align
19
+ from .console import Console, ConsoleOptions, RenderableType, RenderResult
20
+ from .highlighter import ReprHighlighter
21
+ from .panel import Panel
22
+ from .pretty import Pretty
23
+ from .region import Region
24
+ from .repr import Result, rich_repr
25
+ from .segment import Segment
26
+ from .style import StyleType
27
+
28
+ if TYPE_CHECKING:
29
+ from pip._vendor.rich.tree import Tree
30
+
31
+
32
+ class LayoutRender(NamedTuple):
33
+ """An individual layout render."""
34
+
35
+ region: Region
36
+ render: List[List[Segment]]
37
+
38
+
39
+ RegionMap = Dict["Layout", Region]
40
+ RenderMap = Dict["Layout", LayoutRender]
41
+
42
+
43
+ class LayoutError(Exception):
44
+ """Layout related error."""
45
+
46
+
47
+ class NoSplitter(LayoutError):
48
+ """Requested splitter does not exist."""
49
+
50
+
51
+ class _Placeholder:
52
+ """An internal renderable used as a Layout placeholder."""
53
+
54
+ highlighter = ReprHighlighter()
55
+
56
+ def __init__(self, layout: "Layout", style: StyleType = "") -> None:
57
+ self.layout = layout
58
+ self.style = style
59
+
60
+ def __rich_console__(
61
+ self, console: Console, options: ConsoleOptions
62
+ ) -> RenderResult:
63
+ width = options.max_width
64
+ height = options.height or options.size.height
65
+ layout = self.layout
66
+ title = (
67
+ f"{layout.name!r} ({width} x {height})"
68
+ if layout.name
69
+ else f"({width} x {height})"
70
+ )
71
+ yield Panel(
72
+ Align.center(Pretty(layout), vertical="middle"),
73
+ style=self.style,
74
+ title=self.highlighter(title),
75
+ border_style="blue",
76
+ height=height,
77
+ )
78
+
79
+
80
+ class Splitter(ABC):
81
+ """Base class for a splitter."""
82
+
83
+ name: str = ""
84
+
85
+ @abstractmethod
86
+ def get_tree_icon(self) -> str:
87
+ """Get the icon (emoji) used in layout.tree"""
88
+
89
+ @abstractmethod
90
+ def divide(
91
+ self, children: Sequence["Layout"], region: Region
92
+ ) -> Iterable[Tuple["Layout", Region]]:
93
+ """Divide a region amongst several child layouts.
94
+
95
+ Args:
96
+ children (Sequence(Layout)): A number of child layouts.
97
+ region (Region): A rectangular region to divide.
98
+ """
99
+
100
+
101
+ class RowSplitter(Splitter):
102
+ """Split a layout region in to rows."""
103
+
104
+ name = "row"
105
+
106
+ def get_tree_icon(self) -> str:
107
+ return "[layout.tree.row]⬌"
108
+
109
+ def divide(
110
+ self, children: Sequence["Layout"], region: Region
111
+ ) -> Iterable[Tuple["Layout", Region]]:
112
+ x, y, width, height = region
113
+ render_widths = ratio_resolve(width, children)
114
+ offset = 0
115
+ _Region = Region
116
+ for child, child_width in zip(children, render_widths):
117
+ yield child, _Region(x + offset, y, child_width, height)
118
+ offset += child_width
119
+
120
+
121
+ class ColumnSplitter(Splitter):
122
+ """Split a layout region in to columns."""
123
+
124
+ name = "column"
125
+
126
+ def get_tree_icon(self) -> str:
127
+ return "[layout.tree.column]⬍"
128
+
129
+ def divide(
130
+ self, children: Sequence["Layout"], region: Region
131
+ ) -> Iterable[Tuple["Layout", Region]]:
132
+ x, y, width, height = region
133
+ render_heights = ratio_resolve(height, children)
134
+ offset = 0
135
+ _Region = Region
136
+ for child, child_height in zip(children, render_heights):
137
+ yield child, _Region(x, y + offset, width, child_height)
138
+ offset += child_height
139
+
140
+
141
@rich_repr
class Layout:
    """A renderable to divide a fixed height in to rows or columns.

    Args:
        renderable (RenderableType, optional): Renderable content, or None for placeholder. Defaults to None.
        name (str, optional): Optional identifier for Layout. Defaults to None.
        size (int, optional): Optional fixed size of layout. Defaults to None.
        minimum_size (int, optional): Minimum size of layout. Defaults to 1.
        ratio (int, optional): Optional ratio for flexible layout. Defaults to 1.
        visible (bool, optional): Visibility of layout. Defaults to True.
    """

    # Registry mapping splitter names (as accepted by split()) to classes.
    splitters = {"row": RowSplitter, "column": ColumnSplitter}

    def __init__(
        self,
        renderable: Optional[RenderableType] = None,
        *,
        name: Optional[str] = None,
        size: Optional[int] = None,
        minimum_size: int = 1,
        ratio: int = 1,
        visible: bool = True,
    ) -> None:
        # Fall back to a placeholder when no content is supplied.
        self._renderable = renderable or _Placeholder(self)
        self.size = size
        self.minimum_size = minimum_size
        self.ratio = ratio
        self.name = name
        self.visible = visible
        # Default to a column splitter until split() replaces it.
        self.splitter: Splitter = self.splitters["column"]()
        self._children: List[Layout] = []
        self._render_map: RenderMap = {}
        # Serializes update()/refresh_screen()/__rich_console__ so a render
        # never sees a half-updated renderable or render map.
        self._lock = RLock()

    def __rich_repr__(self) -> Result:
        # (key, value, default) tuples per rich's __rich_repr__ protocol;
        # values equal to the default are omitted from the repr.
        yield "name", self.name, None
        yield "size", self.size, None
        yield "minimum_size", self.minimum_size, 1
        yield "ratio", self.ratio, 1

    @property
    def renderable(self) -> RenderableType:
        """Layout renderable."""
        # A split layout renders itself (its children); a leaf renders content.
        return self if self._children else self._renderable

    @property
    def children(self) -> List["Layout"]:
        """Gets (visible) layout children."""
        return [child for child in self._children if child.visible]

    @property
    def map(self) -> RenderMap:
        """Get a map of the last render."""
        return self._render_map

    def get(self, name: str) -> Optional["Layout"]:
        """Get a named layout, or None if it doesn't exist.

        Args:
            name (str): Name of layout.

        Returns:
            Optional[Layout]: Layout instance or None if no layout was found.
        """
        # Depth-first search over all children (including invisible ones).
        if self.name == name:
            return self
        else:
            for child in self._children:
                named_layout = child.get(name)
                if named_layout is not None:
                    return named_layout
        return None

    def __getitem__(self, name: str) -> "Layout":
        """Get a named sub-layout; raises KeyError if not found."""
        layout = self.get(name)
        if layout is None:
            raise KeyError(f"No layout with name {name!r}")
        return layout

    @property
    def tree(self) -> "Tree":
        """Get a tree renderable to show layout structure."""
        # Imported lazily to avoid a circular import at module load time.
        from pip._vendor.rich.styled import Styled
        from pip._vendor.rich.table import Table
        from pip._vendor.rich.tree import Tree

        def summary(layout: "Layout") -> Table:
            # One tree node: splitter icon + a Pretty repr of the layout,
            # dimmed when the layout is not visible.
            icon = layout.splitter.get_tree_icon()

            table = Table.grid(padding=(0, 1, 0, 0))

            text: RenderableType = (
                Pretty(layout) if layout.visible else Styled(Pretty(layout), "dim")
            )
            table.add_row(icon, text)
            _summary = table
            return _summary

        layout = self
        tree = Tree(
            summary(layout),
            guide_style=f"layout.tree.{layout.splitter.name}",
            highlight=True,
        )

        def recurse(tree: "Tree", layout: "Layout") -> None:
            # Mirror the layout hierarchy into the Tree renderable.
            for child in layout._children:
                recurse(
                    tree.add(
                        summary(child),
                        guide_style=f"layout.tree.{child.splitter.name}",
                    ),
                    child,
                )

        recurse(tree, self)
        return tree

    def split(
        self,
        *layouts: Union["Layout", RenderableType],
        splitter: Union[Splitter, str] = "column",
    ) -> None:
        """Split the layout in to multiple sub-layouts.

        Args:
            *layouts (Layout): Positional arguments should be (sub) Layout instances.
            splitter (Union[Splitter, str]): Splitter instance or name of splitter.

        Raises:
            NoSplitter: If *splitter* is a name not present in ``splitters``.
        """
        # Plain renderables are wrapped in a Layout automatically.
        _layouts = [
            layout if isinstance(layout, Layout) else Layout(layout)
            for layout in layouts
        ]
        try:
            self.splitter = (
                splitter
                if isinstance(splitter, Splitter)
                else self.splitters[splitter]()
            )
        except KeyError:
            raise NoSplitter(f"No splitter called {splitter!r}")
        # Replace children in place so existing references stay valid.
        self._children[:] = _layouts

    def add_split(self, *layouts: Union["Layout", RenderableType]) -> None:
        """Add a new layout(s) to existing split.

        Args:
            *layouts (Union[Layout, RenderableType]): Positional arguments should be renderables or (sub) Layout instances.

        """
        # Generator is fine here: extend() consumes it immediately.
        _layouts = (
            layout if isinstance(layout, Layout) else Layout(layout)
            for layout in layouts
        )
        self._children.extend(_layouts)

    def split_row(self, *layouts: Union["Layout", RenderableType]) -> None:
        """Split the layout in to a row (layouts side by side).

        Args:
            *layouts (Layout): Positional arguments should be (sub) Layout instances.
        """
        self.split(*layouts, splitter="row")

    def split_column(self, *layouts: Union["Layout", RenderableType]) -> None:
        """Split the layout in to a column (layouts stacked on top of each other).

        Args:
            *layouts (Layout): Positional arguments should be (sub) Layout instances.
        """
        self.split(*layouts, splitter="column")

    def unsplit(self) -> None:
        """Reset splits to initial state."""
        del self._children[:]

    def update(self, renderable: RenderableType) -> None:
        """Update renderable.

        Args:
            renderable (RenderableType): New renderable object.
        """
        with self._lock:
            self._renderable = renderable

    def refresh_screen(self, console: "Console", layout_name: str) -> None:
        """Refresh a sub-layout.

        Re-renders only the named sub-layout and blits it back at the region
        recorded by the last full render.

        Args:
            console (Console): Console instance where Layout is to be rendered.
            layout_name (str): Name of layout.
        """
        with self._lock:
            layout = self[layout_name]
            # Reuse the region from the previous render; only lines change.
            region, _lines = self._render_map[layout]
            (x, y, width, height) = region
            lines = console.render_lines(
                layout, console.options.update_dimensions(width, height)
            )
            self._render_map[layout] = LayoutRender(region, lines)
            console.update_screen_lines(lines, x, y)

    def _make_region_map(self, width: int, height: int) -> RegionMap:
        """Create a dict that maps layout on to Region."""
        # Iterative depth-first traversal: each node is recorded, then its
        # children (with their sub-regions from the splitter) are pushed.
        stack: List[Tuple[Layout, Region]] = [(self, Region(0, 0, width, height))]
        push = stack.append
        pop = stack.pop
        layout_regions: List[Tuple[Layout, Region]] = []
        append_layout_region = layout_regions.append
        while stack:
            append_layout_region(pop())
            layout, region = layout_regions[-1]
            children = layout.children
            if children:
                for child_and_region in layout.splitter.divide(children, region):
                    push(child_and_region)

        # Sort by the Region tuple itself for a stable, position-based order.
        region_map = {
            layout: region
            for layout, region in sorted(layout_regions, key=itemgetter(1))
        }
        return region_map

    def render(self, console: Console, options: ConsoleOptions) -> RenderMap:
        """Render the sub_layouts.

        Args:
            console (Console): Console instance.
            options (ConsoleOptions): Console options.

        Returns:
            RenderMap: A dict that maps Layout on to a tuple of Region, lines
        """
        render_width = options.max_width
        render_height = options.height or console.height
        region_map = self._make_region_map(render_width, render_height)
        # Only leaf layouts (no visible children) render actual content.
        layout_regions = [
            (layout, region)
            for layout, region in region_map.items()
            if not layout.children
        ]
        render_map: Dict["Layout", "LayoutRender"] = {}
        render_lines = console.render_lines
        update_dimensions = options.update_dimensions

        for layout, region in layout_regions:
            lines = render_lines(
                layout.renderable, update_dimensions(region.width, region.height)
            )
            render_map[layout] = LayoutRender(region, lines)
        return render_map

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        with self._lock:
            width = options.max_width or console.width
            height = options.height or console.height
            render_map = self.render(console, options.update_dimensions(width, height))
            self._render_map = render_map
            # Composite: one list of segments per screen row, filled by
            # copying each leaf's lines into its region's rows.
            layout_lines: List[List[Segment]] = [[] for _ in range(height)]
            _islice = islice
            for region, lines in render_map.values():
                _x, y, _layout_width, layout_height = region
                for row, line in zip(
                    _islice(layout_lines, y, y + layout_height), lines
                ):
                    row.extend(line)

            new_line = Segment.line()
            for layout_row in layout_lines:
                yield from layout_row
                yield new_line
416
+
417
+
418
if __name__ == "__main__":
    # Demo: build a nested layout and print it to the terminal.
    from pip._vendor.rich.console import Console

    console = Console()
    layout = Layout()

    # Top level: fixed-height header/footer around a flexible main area.
    layout.split_column(
        Layout(name="header", size=3),
        Layout(ratio=1, name="main"),
        Layout(size=10, name="footer"),
    )

    # Main area: side panel next to a body twice as wide.
    layout["main"].split_row(Layout(name="side"), Layout(name="body", ratio=2))

    layout["body"].split_row(Layout(name="content", ratio=2), Layout(name="s2"))

    layout["s2"].split_column(
        Layout(name="top"), Layout(name="middle"), Layout(name="bottom")
    )

    # Show the layout's own tree diagram in the upper side panel.
    layout["side"].split_column(Layout(layout.tree, name="left1"), Layout(name="left2"))

    layout["content"].update("foo")

    console.print(layout)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/progress_bar.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import lru_cache
3
+ from time import monotonic
4
+ from typing import Iterable, List, Optional
5
+
6
+ from .color import Color, blend_rgb
7
+ from .color_triplet import ColorTriplet
8
+ from .console import Console, ConsoleOptions, RenderResult
9
+ from .jupyter import JupyterMixin
10
+ from .measure import Measurement
11
+ from .segment import Segment
12
+ from .style import Style, StyleType
13
+
14
+ # Number of characters before 'pulse' animation repeats
15
+ PULSE_SIZE = 20
16
+
17
+
18
class ProgressBar(JupyterMixin):
    """Renders a (progress) bar. Used by rich.progress.

    Args:
        total (float, optional): Number of steps in the bar. Defaults to 100. Set to None to render a pulsing animation.
        completed (float, optional): Number of steps completed. Defaults to 0.
        width (int, optional): Width of the bar, or ``None`` for maximum width. Defaults to None.
        pulse (bool, optional): Enable pulse effect. Defaults to False. Will pulse if a None total was passed.
        style (StyleType, optional): Style for the bar background. Defaults to "bar.back".
        complete_style (StyleType, optional): Style for the completed bar. Defaults to "bar.complete".
        finished_style (StyleType, optional): Style for a finished bar. Defaults to "bar.finished".
        pulse_style (StyleType, optional): Style for pulsing bars. Defaults to "bar.pulse".
        animation_time (Optional[float], optional): Time in seconds to use for animation, or None to use system time.
    """

    def __init__(
        self,
        total: Optional[float] = 100.0,
        completed: float = 0,
        width: Optional[int] = None,
        pulse: bool = False,
        style: StyleType = "bar.back",
        complete_style: StyleType = "bar.complete",
        finished_style: StyleType = "bar.finished",
        pulse_style: StyleType = "bar.pulse",
        animation_time: Optional[float] = None,
    ):
        self.total = total
        self.completed = completed
        self.width = width
        self.pulse = pulse
        self.style = style
        self.complete_style = complete_style
        self.finished_style = finished_style
        self.pulse_style = pulse_style
        self.animation_time = animation_time

        # Set but not read in this class; presumably a legacy cache slot —
        # pulse segments are cached via lru_cache below. TODO confirm.
        self._pulse_segments: Optional[List[Segment]] = None

    def __repr__(self) -> str:
        return f"<Bar {self.completed!r} of {self.total!r}>"

    @property
    def percentage_completed(self) -> Optional[float]:
        """Calculate percentage complete."""
        if self.total is None:
            return None
        completed = (self.completed / self.total) * 100.0
        # NOTE(review): upstream clamps with min(100.0, ...); min(100, ...)
        # can return the int 100 when over-complete — harmless numerically.
        completed = min(100, max(0.0, completed))
        return completed

    # NOTE(review): lru_cache on an instance method keys on self and keeps
    # instances alive for the cache's lifetime; bounded here at 16 entries.
    @lru_cache(maxsize=16)
    def _get_pulse_segments(
        self,
        fore_style: Style,
        back_style: Style,
        color_system: str,
        no_color: bool,
        ascii: bool = False,
    ) -> List[Segment]:
        """Get a list of segments to render a pulse animation.

        Returns:
            List[Segment]: A list of segments, one segment per character.
        """
        bar = "-" if ascii else "━"
        segments: List[Segment] = []
        if color_system not in ("standard", "eight_bit", "truecolor") or no_color:
            # Limited/no color: half foreground, half background, no gradient.
            segments += [Segment(bar, fore_style)] * (PULSE_SIZE // 2)
            segments += [Segment(" " if no_color else bar, back_style)] * (
                PULSE_SIZE - (PULSE_SIZE // 2)
            )
            return segments

        append = segments.append
        fore_color = (
            fore_style.color.get_truecolor()
            if fore_style.color
            else ColorTriplet(255, 0, 255)
        )
        back_color = (
            back_style.color.get_truecolor()
            if back_style.color
            else ColorTriplet(0, 0, 0)
        )
        # Local aliases to keep lookups cheap inside the loop below.
        cos = math.cos
        pi = math.pi
        _Segment = Segment
        _Style = Style
        from_triplet = Color.from_triplet

        for index in range(PULSE_SIZE):
            position = index / PULSE_SIZE
            # Cosine fade 0..1 gives one smooth bright-dark cycle per PULSE_SIZE.
            fade = 0.5 + cos((position * pi * 2)) / 2.0
            color = blend_rgb(fore_color, back_color, cross_fade=fade)
            append(_Segment(bar, _Style(color=from_triplet(color))))
        return segments

    def update(self, completed: float, total: Optional[float] = None) -> None:
        """Update progress with new values.

        Args:
            completed (float): Number of steps completed.
            total (float, optional): Total number of steps, or ``None`` to not change. Defaults to None.
        """
        self.completed = completed
        self.total = total if total is not None else self.total

    def _render_pulse(
        self, console: Console, width: int, ascii: bool = False
    ) -> Iterable[Segment]:
        """Renders the pulse animation.

        Args:
            console (Console): Console instance.
            width (int): Width in characters of pulse animation.

        Returns:
            RenderResult: [description]

        Yields:
            Iterator[Segment]: Segments to render pulse
        """
        fore_style = console.get_style(self.pulse_style, default="white")
        back_style = console.get_style(self.style, default="black")

        pulse_segments = self._get_pulse_segments(
            fore_style, back_style, console.color_system, console.no_color, ascii=ascii
        )
        segment_count = len(pulse_segments)
        current_time = (
            monotonic() if self.animation_time is None else self.animation_time
        )
        # Tile the pattern wider than needed, then take a time-shifted window
        # so successive renders appear to scroll (15 cells per second).
        segments = pulse_segments * (int(width / segment_count) + 2)
        offset = int(-current_time * 15) % segment_count
        segments = segments[offset : offset + width]
        yield from segments

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        width = min(self.width or options.max_width, options.max_width)
        ascii = options.legacy_windows or options.ascii_only
        should_pulse = self.pulse or self.total is None
        if should_pulse:
            yield from self._render_pulse(console, width, ascii=ascii)
            return

        completed: Optional[float] = (
            min(self.total, max(0, self.completed)) if self.total is not None else None
        )

        # Progress is measured in half-character cells: "╸"/"╺" render half a
        # cell, giving twice the horizontal resolution of full "━" cells.
        bar = "-" if ascii else "━"
        half_bar_right = " " if ascii else "╸"
        half_bar_left = " " if ascii else "╺"
        complete_halves = (
            int(width * 2 * completed / self.total)
            if self.total and completed is not None
            else width * 2
        )
        bar_count = complete_halves // 2
        half_bar_count = complete_halves % 2
        style = console.get_style(self.style)
        is_finished = self.total is None or self.completed >= self.total
        complete_style = console.get_style(
            self.finished_style if is_finished else self.complete_style
        )
        _Segment = Segment
        if bar_count:
            yield _Segment(bar * bar_count, complete_style)
        if half_bar_count:
            yield _Segment(half_bar_right * half_bar_count, complete_style)

        if not console.no_color:
            remaining_bars = width - bar_count - half_bar_count
            if remaining_bars and console.color_system is not None:
                # Lead the background with a left half-cell so the completed
                # and remaining portions butt together cleanly.
                if not half_bar_count and bar_count:
                    yield _Segment(half_bar_left, style)
                    remaining_bars -= 1
                if remaining_bars:
                    yield _Segment(bar * remaining_bars, style)

    def __rich_measure__(
        self, console: Console, options: ConsoleOptions
    ) -> Measurement:
        # Fixed width when configured; otherwise flexible from 4 cells up.
        return (
            Measurement(self.width, self.width)
            if self.width is not None
            else Measurement(4, options.max_width)
        )
208
+
209
+
210
if __name__ == "__main__":  # pragma: no cover
    # Demo: animate a bar from 0 to 100% on a single console line.
    console = Console()
    bar = ProgressBar(width=50, total=100)

    import time

    console.show_cursor(False)
    for n in range(0, 101, 1):
        bar.update(n)
        console.print(bar)
        # Carriage return rewinds the cursor so the next frame overwrites.
        console.file.write("\r")
        time.sleep(0.05)
    console.show_cursor(True)
    console.print()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/syntax.py ADDED
@@ -0,0 +1,958 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import platform
3
+ import re
4
+ import sys
5
+ import textwrap
6
+ from abc import ABC, abstractmethod
7
+ from pathlib import Path
8
+ from typing import (
9
+ Any,
10
+ Dict,
11
+ Iterable,
12
+ List,
13
+ NamedTuple,
14
+ Optional,
15
+ Sequence,
16
+ Set,
17
+ Tuple,
18
+ Type,
19
+ Union,
20
+ )
21
+
22
+ from pip._vendor.pygments.lexer import Lexer
23
+ from pip._vendor.pygments.lexers import get_lexer_by_name, guess_lexer_for_filename
24
+ from pip._vendor.pygments.style import Style as PygmentsStyle
25
+ from pip._vendor.pygments.styles import get_style_by_name
26
+ from pip._vendor.pygments.token import (
27
+ Comment,
28
+ Error,
29
+ Generic,
30
+ Keyword,
31
+ Name,
32
+ Number,
33
+ Operator,
34
+ String,
35
+ Token,
36
+ Whitespace,
37
+ )
38
+ from pip._vendor.pygments.util import ClassNotFound
39
+
40
+ from pip._vendor.rich.containers import Lines
41
+ from pip._vendor.rich.padding import Padding, PaddingDimensions
42
+
43
+ from ._loop import loop_first
44
+ from .cells import cell_len
45
+ from .color import Color, blend_rgb
46
+ from .console import Console, ConsoleOptions, JustifyMethod, RenderResult
47
+ from .jupyter import JupyterMixin
48
+ from .measure import Measurement
49
+ from .segment import Segment, Segments
50
+ from .style import Style, StyleType
51
+ from .text import Text
52
+
53
# Pygments token types are tuples of strings, e.g. ("name", "function").
TokenType = Tuple[str, ...]

WINDOWS = platform.system() == "Windows"
DEFAULT_THEME = "monokai"

# The following styles are based on https://github.com/pygments/pygments/blob/master/pygments/formatters/terminal.py
# A few modifications were made

# Token-to-style tables for terminals limited to the named ANSI palette:
# one tuned for light backgrounds, one for dark backgrounds.
ANSI_LIGHT: Dict[TokenType, Style] = {
    Token: Style(),
    Whitespace: Style(color="white"),
    Comment: Style(dim=True),
    Comment.Preproc: Style(color="cyan"),
    Keyword: Style(color="blue"),
    Keyword.Type: Style(color="cyan"),
    Operator.Word: Style(color="magenta"),
    Name.Builtin: Style(color="cyan"),
    Name.Function: Style(color="green"),
    Name.Namespace: Style(color="cyan", underline=True),
    Name.Class: Style(color="green", underline=True),
    Name.Exception: Style(color="cyan"),
    Name.Decorator: Style(color="magenta", bold=True),
    Name.Variable: Style(color="red"),
    Name.Constant: Style(color="red"),
    Name.Attribute: Style(color="cyan"),
    Name.Tag: Style(color="bright_blue"),
    String: Style(color="yellow"),
    Number: Style(color="blue"),
    Generic.Deleted: Style(color="bright_red"),
    Generic.Inserted: Style(color="green"),
    Generic.Heading: Style(bold=True),
    Generic.Subheading: Style(color="magenta", bold=True),
    Generic.Prompt: Style(bold=True),
    Generic.Error: Style(color="bright_red"),
    Error: Style(color="red", underline=True),
}

ANSI_DARK: Dict[TokenType, Style] = {
    Token: Style(),
    Whitespace: Style(color="bright_black"),
    Comment: Style(dim=True),
    Comment.Preproc: Style(color="bright_cyan"),
    Keyword: Style(color="bright_blue"),
    Keyword.Type: Style(color="bright_cyan"),
    Operator.Word: Style(color="bright_magenta"),
    Name.Builtin: Style(color="bright_cyan"),
    Name.Function: Style(color="bright_green"),
    Name.Namespace: Style(color="bright_cyan", underline=True),
    Name.Class: Style(color="bright_green", underline=True),
    Name.Exception: Style(color="bright_cyan"),
    Name.Decorator: Style(color="bright_magenta", bold=True),
    Name.Variable: Style(color="bright_red"),
    Name.Constant: Style(color="bright_red"),
    Name.Attribute: Style(color="bright_cyan"),
    Name.Tag: Style(color="bright_blue"),
    String: Style(color="yellow"),
    Number: Style(color="bright_blue"),
    Generic.Deleted: Style(color="bright_red"),
    Generic.Inserted: Style(color="bright_green"),
    Generic.Heading: Style(bold=True),
    Generic.Subheading: Style(color="bright_magenta", bold=True),
    Generic.Prompt: Style(bold=True),
    Generic.Error: Style(color="bright_red"),
    Error: Style(color="red", underline=True),
}

# Rich's own built-in themes, selectable by name alongside Pygments styles.
RICH_SYNTAX_THEMES = {"ansi_light": ANSI_LIGHT, "ansi_dark": ANSI_DARK}
# Presumably padding (in cells) around the line-numbers column — used
# further down this module; confirm at the usage site.
NUMBERS_COLUMN_DEFAULT_PADDING = 2
+
122
+
123
class SyntaxTheme(ABC):
    """Abstract interface that concrete syntax themes must implement."""

    @abstractmethod
    def get_background_style(self) -> Style:
        """Return the style used for the code background."""
        raise NotImplementedError  # pragma: no cover

    @abstractmethod
    def get_style_for_token(self, token_type: TokenType) -> Style:
        """Return the style associated with a Pygments token type."""
        raise NotImplementedError  # pragma: no cover
135
+
136
+
137
class PygmentsSyntaxTheme(SyntaxTheme):
    """Adapter that exposes a Pygments style through the SyntaxTheme interface."""

    def __init__(self, theme: Union[str, Type[PygmentsStyle]]) -> None:
        # Converted styles are cached here, keyed by Pygments token type.
        self._style_cache: Dict[TokenType, Style] = {}
        if not isinstance(theme, str):
            self._pygments_style_class = theme
        else:
            try:
                self._pygments_style_class = get_style_by_name(theme)
            except ClassNotFound:
                # Unknown theme name: fall back to Pygments' default style.
                self._pygments_style_class = get_style_by_name("default")

        self._background_color = self._pygments_style_class.background_color
        self._background_style = Style(bgcolor=self._background_color)

    def get_style_for_token(self, token_type: TokenType) -> Style:
        """Convert a Pygments token style into a rich Style (cached)."""
        cached = self._style_cache.get(token_type)
        if cached is not None:
            return cached
        try:
            token_style = self._pygments_style_class.style_for_token(token_type)
        except KeyError:
            style = Style.null()
        else:
            foreground = token_style["color"]
            background = token_style["bgcolor"]
            style = Style(
                color="#" + foreground if foreground else "#000000",
                bgcolor="#" + background if background else self._background_color,
                bold=token_style["bold"],
                italic=token_style["italic"],
                underline=token_style["underline"],
            )
        self._style_cache[token_type] = style
        return style

    def get_background_style(self) -> Style:
        """Return the background style derived from the Pygments theme."""
        return self._background_style
177
+
178
+
179
class ANSISyntaxTheme(SyntaxTheme):
    """Syntax theme backed by a fixed map of token types to ANSI styles."""

    def __init__(self, style_map: Dict[TokenType, Style]) -> None:
        self.style_map = style_map
        self._missing_style = Style.null()
        self._background_style = Style.null()
        # Memoizes resolved lookups; each miss walks the token hierarchy.
        self._style_cache: Dict[TokenType, Style] = {}

    def get_style_for_token(self, token_type: TokenType) -> Style:
        """Resolve a token to a style, walking up the token hierarchy."""
        cached = self._style_cache.get(token_type)
        if cached is not None:
            return cached
        # Token types form a hierarchy; try the most specific prefix first,
        # e.g. ("foo", "bar", "baz") then ("foo", "bar") then ("foo",).
        lookup = self.style_map.get
        prefix = tuple(token_type)
        resolved = self._missing_style
        while prefix:
            candidate = lookup(prefix)
            if candidate is not None:
                resolved = candidate
                break
            prefix = prefix[:-1]
        self._style_cache[token_type] = resolved
        return resolved

    def get_background_style(self) -> Style:
        """ANSI themes have no background of their own (null style)."""
        return self._background_style
210
+
211
+
212
# A (line, column) position in the source: line is 1-based, column 0-based.
SyntaxPosition = Tuple[int, int]


class _SyntaxHighlightRange(NamedTuple):
    """
    A range to highlight in a Syntax object.
    `start` and `end` are 2-integers tuples, where the first integer is the line number
    (starting from 1) and the second integer is the column index (starting from 0).
    """

    # Style applied over the range.
    style: StyleType
    # Start position of the highlight.
    start: SyntaxPosition
    # End position of the highlight.
    end: SyntaxPosition
225
+
226
+
227
+ class Syntax(JupyterMixin):
228
+ """Construct a Syntax object to render syntax highlighted code.
229
+
230
+ Args:
231
+ code (str): Code to highlight.
232
+ lexer (Lexer | str): Lexer to use (see https://pygments.org/docs/lexers/)
233
+ theme (str, optional): Color theme, aka Pygments style (see https://pygments.org/docs/styles/#getting-a-list-of-available-styles). Defaults to "monokai".
234
+ dedent (bool, optional): Enable stripping of initial whitespace. Defaults to False.
235
+ line_numbers (bool, optional): Enable rendering of line numbers. Defaults to False.
236
+ start_line (int, optional): Starting number for line numbers. Defaults to 1.
237
+ line_range (Tuple[int | None, int | None], optional): If given should be a tuple of the start and end line to render.
238
+ A value of None in the tuple indicates the range is open in that direction.
239
+ highlight_lines (Set[int]): A set of line numbers to highlight.
240
+ code_width: Width of code to render (not including line numbers), or ``None`` to use all available width.
241
+ tab_size (int, optional): Size of tabs. Defaults to 4.
242
+ word_wrap (bool, optional): Enable word wrapping.
243
+ background_color (str, optional): Optional background color, or None to use theme color. Defaults to None.
244
+ indent_guides (bool, optional): Show indent guides. Defaults to False.
245
+ padding (PaddingDimensions): Padding to apply around the syntax. Defaults to 0 (no padding).
246
+ """
247
+
248
+ _pygments_style_class: Type[PygmentsStyle]
249
+ _theme: SyntaxTheme
250
+
251
+ @classmethod
252
+ def get_theme(cls, name: Union[str, SyntaxTheme]) -> SyntaxTheme:
253
+ """Get a syntax theme instance."""
254
+ if isinstance(name, SyntaxTheme):
255
+ return name
256
+ theme: SyntaxTheme
257
+ if name in RICH_SYNTAX_THEMES:
258
+ theme = ANSISyntaxTheme(RICH_SYNTAX_THEMES[name])
259
+ else:
260
+ theme = PygmentsSyntaxTheme(name)
261
+ return theme
262
+
263
+ def __init__(
264
+ self,
265
+ code: str,
266
+ lexer: Union[Lexer, str],
267
+ *,
268
+ theme: Union[str, SyntaxTheme] = DEFAULT_THEME,
269
+ dedent: bool = False,
270
+ line_numbers: bool = False,
271
+ start_line: int = 1,
272
+ line_range: Optional[Tuple[Optional[int], Optional[int]]] = None,
273
+ highlight_lines: Optional[Set[int]] = None,
274
+ code_width: Optional[int] = None,
275
+ tab_size: int = 4,
276
+ word_wrap: bool = False,
277
+ background_color: Optional[str] = None,
278
+ indent_guides: bool = False,
279
+ padding: PaddingDimensions = 0,
280
+ ) -> None:
281
+ self.code = code
282
+ self._lexer = lexer
283
+ self.dedent = dedent
284
+ self.line_numbers = line_numbers
285
+ self.start_line = start_line
286
+ self.line_range = line_range
287
+ self.highlight_lines = highlight_lines or set()
288
+ self.code_width = code_width
289
+ self.tab_size = tab_size
290
+ self.word_wrap = word_wrap
291
+ self.background_color = background_color
292
+ self.background_style = (
293
+ Style(bgcolor=background_color) if background_color else Style()
294
+ )
295
+ self.indent_guides = indent_guides
296
+ self.padding = padding
297
+
298
+ self._theme = self.get_theme(theme)
299
+ self._stylized_ranges: List[_SyntaxHighlightRange] = []
300
+
301
+ @classmethod
302
+ def from_path(
303
+ cls,
304
+ path: str,
305
+ encoding: str = "utf-8",
306
+ lexer: Optional[Union[Lexer, str]] = None,
307
+ theme: Union[str, SyntaxTheme] = DEFAULT_THEME,
308
+ dedent: bool = False,
309
+ line_numbers: bool = False,
310
+ line_range: Optional[Tuple[int, int]] = None,
311
+ start_line: int = 1,
312
+ highlight_lines: Optional[Set[int]] = None,
313
+ code_width: Optional[int] = None,
314
+ tab_size: int = 4,
315
+ word_wrap: bool = False,
316
+ background_color: Optional[str] = None,
317
+ indent_guides: bool = False,
318
+ padding: PaddingDimensions = 0,
319
+ ) -> "Syntax":
320
+ """Construct a Syntax object from a file.
321
+
322
+ Args:
323
+ path (str): Path to file to highlight.
324
+ encoding (str): Encoding of file.
325
+ lexer (str | Lexer, optional): Lexer to use. If None, lexer will be auto-detected from path/file content.
326
+ theme (str, optional): Color theme, aka Pygments style (see https://pygments.org/docs/styles/#getting-a-list-of-available-styles). Defaults to "emacs".
327
+ dedent (bool, optional): Enable stripping of initial whitespace. Defaults to True.
328
+ line_numbers (bool, optional): Enable rendering of line numbers. Defaults to False.
329
+ start_line (int, optional): Starting number for line numbers. Defaults to 1.
330
+ line_range (Tuple[int, int], optional): If given should be a tuple of the start and end line to render.
331
+ highlight_lines (Set[int]): A set of line numbers to highlight.
332
+ code_width: Width of code to render (not including line numbers), or ``None`` to use all available width.
333
+ tab_size (int, optional): Size of tabs. Defaults to 4.
334
+ word_wrap (bool, optional): Enable word wrapping of code.
335
+ background_color (str, optional): Optional background color, or None to use theme color. Defaults to None.
336
+ indent_guides (bool, optional): Show indent guides. Defaults to False.
337
+ padding (PaddingDimensions): Padding to apply around the syntax. Defaults to 0 (no padding).
338
+
339
+ Returns:
340
+ [Syntax]: A Syntax object that may be printed to the console
341
+ """
342
+ code = Path(path).read_text(encoding=encoding)
343
+
344
+ if not lexer:
345
+ lexer = cls.guess_lexer(path, code=code)
346
+
347
+ return cls(
348
+ code,
349
+ lexer,
350
+ theme=theme,
351
+ dedent=dedent,
352
+ line_numbers=line_numbers,
353
+ line_range=line_range,
354
+ start_line=start_line,
355
+ highlight_lines=highlight_lines,
356
+ code_width=code_width,
357
+ tab_size=tab_size,
358
+ word_wrap=word_wrap,
359
+ background_color=background_color,
360
+ indent_guides=indent_guides,
361
+ padding=padding,
362
+ )
363
+
364
+ @classmethod
365
+ def guess_lexer(cls, path: str, code: Optional[str] = None) -> str:
366
+ """Guess the alias of the Pygments lexer to use based on a path and an optional string of code.
367
+ If code is supplied, it will use a combination of the code and the filename to determine the
368
+ best lexer to use. For example, if the file is ``index.html`` and the file contains Django
369
+ templating syntax, then "html+django" will be returned. If the file is ``index.html``, and no
370
+ templating language is used, the "html" lexer will be used. If no string of code
371
+ is supplied, the lexer will be chosen based on the file extension..
372
+
373
+ Args:
374
+ path (AnyStr): The path to the file containing the code you wish to know the lexer for.
375
+ code (str, optional): Optional string of code that will be used as a fallback if no lexer
376
+ is found for the supplied path.
377
+
378
+ Returns:
379
+ str: The name of the Pygments lexer that best matches the supplied path/code.
380
+ """
381
+ lexer: Optional[Lexer] = None
382
+ lexer_name = "default"
383
+ if code:
384
+ try:
385
+ lexer = guess_lexer_for_filename(path, code)
386
+ except ClassNotFound:
387
+ pass
388
+
389
+ if not lexer:
390
+ try:
391
+ _, ext = os.path.splitext(path)
392
+ if ext:
393
+ extension = ext.lstrip(".").lower()
394
+ lexer = get_lexer_by_name(extension)
395
+ except ClassNotFound:
396
+ pass
397
+
398
+ if lexer:
399
+ if lexer.aliases:
400
+ lexer_name = lexer.aliases[0]
401
+ else:
402
+ lexer_name = lexer.name
403
+
404
+ return lexer_name
405
+
406
+ def _get_base_style(self) -> Style:
407
+ """Get the base style."""
408
+ default_style = self._theme.get_background_style() + self.background_style
409
+ return default_style
410
+
411
+ def _get_token_color(self, token_type: TokenType) -> Optional[Color]:
412
+ """Get a color (if any) for the given token.
413
+
414
+ Args:
415
+ token_type (TokenType): A token type tuple from Pygments.
416
+
417
+ Returns:
418
+ Optional[Color]: Color from theme, or None for no color.
419
+ """
420
+ style = self._theme.get_style_for_token(token_type)
421
+ return style.color
422
+
423
+ @property
424
+ def lexer(self) -> Optional[Lexer]:
425
+ """The lexer for this syntax, or None if no lexer was found.
426
+
427
+ Tries to find the lexer by name if a string was passed to the constructor.
428
+ """
429
+
430
+ if isinstance(self._lexer, Lexer):
431
+ return self._lexer
432
+ try:
433
+ return get_lexer_by_name(
434
+ self._lexer,
435
+ stripnl=False,
436
+ ensurenl=True,
437
+ tabsize=self.tab_size,
438
+ )
439
+ except ClassNotFound:
440
+ return None
441
+
442
+ @property
443
+ def default_lexer(self) -> Lexer:
444
+ """A Pygments Lexer to use if one is not specified or invalid."""
445
+ return get_lexer_by_name(
446
+ "text",
447
+ stripnl=False,
448
+ ensurenl=True,
449
+ tabsize=self.tab_size,
450
+ )
451
+
452
    def highlight(
        self,
        code: str,
        line_range: Optional[Tuple[Optional[int], Optional[int]]] = None,
    ) -> Text:
        """Highlight code and return a Text instance.

        Args:
            code (str): Code to highlight.
            line_range(Tuple[int, int], optional): Optional line range to highlight.

        Returns:
            Text: A text instance containing highlighted syntax.
        """

        base_style = self._get_base_style()
        # Transparent backgrounds keep the console's default justification;
        # otherwise left-justify so the background color fills each line.
        justify: JustifyMethod = (
            "default" if base_style.transparent_background else "left"
        )

        text = Text(
            justify=justify,
            style=base_style,
            tab_size=self.tab_size,
            no_wrap=not self.word_wrap,
        )
        # Local alias avoids repeated attribute lookups in the token loops below.
        _get_theme_style = self._theme.get_style_for_token

        lexer = self.lexer or self.default_lexer

        if lexer is None:
            # No lexer available: append the code unstyled.
            text.append(code)
        else:
            if line_range:
                # More complicated path to only stylize a portion of the code
                # This speeds up further operations as there are less spans to process
                line_start, line_end = line_range

                def line_tokenize() -> Iterable[Tuple[Any, str]]:
                    """Split tokens to one per line."""
                    assert lexer  # required to make MyPy happy - we know lexer is not None at this point

                    for token_type, token in lexer.get_tokens(code):
                        # A token may span several lines; emit one piece per
                        # line so the span loop can count lines reliably.
                        while token:
                            line_token, new_line, token = token.partition("\n")
                            yield token_type, line_token + new_line

                def tokens_to_spans() -> Iterable[Tuple[str, Optional[Style]]]:
                    """Convert tokens to spans."""
                    tokens = iter(line_tokenize())
                    line_no = 0
                    _line_start = line_start - 1 if line_start else 0

                    # Skip over tokens until line start
                    while line_no < _line_start:
                        try:
                            _token_type, token = next(tokens)
                        except StopIteration:
                            break
                        # Skipped tokens are yielded unstyled so character
                        # offsets in the resulting Text stay correct.
                        yield (token, None)
                        if token.endswith("\n"):
                            line_no += 1
                    # Generate spans until line end
                    for token_type, token in tokens:
                        yield (token, _get_theme_style(token_type))
                        if token.endswith("\n"):
                            line_no += 1
                        if line_end and line_no >= line_end:
                            break

                text.append_tokens(tokens_to_spans())

            else:
                # Fast path: style every token in the document.
                text.append_tokens(
                    (token, _get_theme_style(token_type))
                    for token_type, token in lexer.get_tokens(code)
                )
        if self.background_color is not None:
            text.stylize(f"on {self.background_color}")

        if self._stylized_ranges:
            self._apply_stylized_ranges(text)

        return text
536
+
537
+ def stylize_range(
538
+ self, style: StyleType, start: SyntaxPosition, end: SyntaxPosition
539
+ ) -> None:
540
+ """
541
+ Adds a custom style on a part of the code, that will be applied to the syntax display when it's rendered.
542
+ Line numbers are 1-based, while column indexes are 0-based.
543
+
544
+ Args:
545
+ style (StyleType): The style to apply.
546
+ start (Tuple[int, int]): The start of the range, in the form `[line number, column index]`.
547
+ end (Tuple[int, int]): The end of the range, in the form `[line number, column index]`.
548
+ """
549
+ self._stylized_ranges.append(_SyntaxHighlightRange(style, start, end))
550
+
551
+ def _get_line_numbers_color(self, blend: float = 0.3) -> Color:
552
+ background_style = self._theme.get_background_style() + self.background_style
553
+ background_color = background_style.bgcolor
554
+ if background_color is None or background_color.is_system_defined:
555
+ return Color.default()
556
+ foreground_color = self._get_token_color(Token.Text)
557
+ if foreground_color is None or foreground_color.is_system_defined:
558
+ return foreground_color or Color.default()
559
+ new_color = blend_rgb(
560
+ background_color.get_truecolor(),
561
+ foreground_color.get_truecolor(),
562
+ cross_fade=blend,
563
+ )
564
+ return Color.from_triplet(new_color)
565
+
566
+ @property
567
+ def _numbers_column_width(self) -> int:
568
+ """Get the number of characters used to render the numbers column."""
569
+ column_width = 0
570
+ if self.line_numbers:
571
+ column_width = (
572
+ len(str(self.start_line + self.code.count("\n")))
573
+ + NUMBERS_COLUMN_DEFAULT_PADDING
574
+ )
575
+ return column_width
576
+
577
+ def _get_number_styles(self, console: Console) -> Tuple[Style, Style, Style]:
578
+ """Get background, number, and highlight styles for line numbers."""
579
+ background_style = self._get_base_style()
580
+ if background_style.transparent_background:
581
+ return Style.null(), Style(dim=True), Style.null()
582
+ if console.color_system in ("256", "truecolor"):
583
+ number_style = Style.chain(
584
+ background_style,
585
+ self._theme.get_style_for_token(Token.Text),
586
+ Style(color=self._get_line_numbers_color()),
587
+ self.background_style,
588
+ )
589
+ highlight_number_style = Style.chain(
590
+ background_style,
591
+ self._theme.get_style_for_token(Token.Text),
592
+ Style(bold=True, color=self._get_line_numbers_color(0.9)),
593
+ self.background_style,
594
+ )
595
+ else:
596
+ number_style = background_style + Style(dim=True)
597
+ highlight_number_style = background_style + Style(dim=False)
598
+ return background_style, number_style, highlight_number_style
599
+
600
+ def __rich_measure__(
601
+ self, console: "Console", options: "ConsoleOptions"
602
+ ) -> "Measurement":
603
+ _, right, _, left = Padding.unpack(self.padding)
604
+ padding = left + right
605
+ if self.code_width is not None:
606
+ width = self.code_width + self._numbers_column_width + padding + 1
607
+ return Measurement(self._numbers_column_width, width)
608
+ lines = self.code.splitlines()
609
+ width = (
610
+ self._numbers_column_width
611
+ + padding
612
+ + (max(cell_len(line) for line in lines) if lines else 0)
613
+ )
614
+ if self.line_numbers:
615
+ width += 1
616
+ return Measurement(self._numbers_column_width, width)
617
+
618
+ def __rich_console__(
619
+ self, console: Console, options: ConsoleOptions
620
+ ) -> RenderResult:
621
+ segments = Segments(self._get_syntax(console, options))
622
+ if self.padding:
623
+ yield Padding(
624
+ segments, style=self._theme.get_background_style(), pad=self.padding
625
+ )
626
+ else:
627
+ yield segments
628
+
629
    def _get_syntax(
        self,
        console: Console,
        options: ConsoleOptions,
    ) -> Iterable[Segment]:
        """
        Get the Segments for the Syntax object, excluding any vertical/horizontal padding
        """
        transparent_background = self._get_base_style().transparent_background
        # Width available for the code itself: either the explicit code_width,
        # or the console width minus the line-number gutter (and separator).
        code_width = (
            (
                (options.max_width - self._numbers_column_width - 1)
                if self.line_numbers
                else options.max_width
            )
            if self.code_width is None
            else self.code_width
        )

        ends_on_nl, processed_code = self._process_code(self.code)
        text = self.highlight(processed_code, self.line_range)

        if not self.line_numbers and not self.word_wrap and not self.line_range:
            if not ends_on_nl:
                # _process_code appended a newline; strip it again for output.
                text.remove_suffix("\n")
            # Simple case of just rendering text
            style = (
                self._get_base_style()
                + self._theme.get_style_for_token(Comment)
                + Style(dim=True)
                + self.background_style
            )
            if self.indent_guides and not options.ascii_only:
                text = text.with_indent_guides(self.tab_size, style=style)
                text.overflow = "crop"
            if style.transparent_background:
                yield from console.render(
                    text, options=options.update(width=code_width)
                )
            else:
                # Opaque background: pad every line to full width.
                syntax_lines = console.render_lines(
                    text,
                    options.update(width=code_width, height=None, justify="left"),
                    style=self.background_style,
                    pad=True,
                    new_lines=True,
                )
                for syntax_line in syntax_lines:
                    yield from syntax_line
            return

        # Full path: line numbers, word wrap and/or a line range are involved.
        start_line, end_line = self.line_range or (None, None)
        line_offset = 0
        if start_line:
            line_offset = max(0, start_line - 1)
        lines: Union[List[Text], Lines] = text.split("\n", allow_blank=ends_on_nl)
        if self.line_range:
            if line_offset > len(lines):
                # Requested range starts beyond the end of the code: nothing to render.
                return
            lines = lines[line_offset:end_line]

        if self.indent_guides and not options.ascii_only:
            style = (
                self._get_base_style()
                + self._theme.get_style_for_token(Comment)
                + Style(dim=True)
                + self.background_style
            )
            # Re-join, add guides over the whole block, then re-split per line.
            lines = (
                Text("\n")
                .join(lines)
                .with_indent_guides(self.tab_size, style=style + Style(italic=False))
                .split("\n", allow_blank=True)
            )

        numbers_column_width = self._numbers_column_width
        render_options = options.update(width=code_width)

        highlight_line = self.highlight_lines.__contains__
        _Segment = Segment
        new_line = _Segment("\n")

        # Gutter marker shown next to highlighted lines.
        line_pointer = "> " if options.legacy_windows else "❱ "

        (
            background_style,
            number_style,
            highlight_number_style,
        ) = self._get_number_styles(console)

        for line_no, line in enumerate(lines, self.start_line + line_offset):
            if self.word_wrap:
                # One source line may render as several wrapped display lines.
                wrapped_lines = console.render_lines(
                    line,
                    render_options.update(height=None, justify="left"),
                    style=background_style,
                    pad=not transparent_background,
                )
            else:
                segments = list(line.render(console, end=""))
                if options.no_wrap:
                    wrapped_lines = [segments]
                else:
                    # Crop/pad the single display line to the render width.
                    wrapped_lines = [
                        _Segment.adjust_line_length(
                            segments,
                            render_options.max_width,
                            style=background_style,
                            pad=not transparent_background,
                        )
                    ]

            if self.line_numbers:
                wrapped_line_left_pad = _Segment(
                    " " * numbers_column_width + " ", background_style
                )
                for first, wrapped_line in loop_first(wrapped_lines):
                    if first:
                        line_column = str(line_no).rjust(numbers_column_width - 2) + " "
                        if highlight_line(line_no):
                            yield _Segment(line_pointer, Style(color="red"))
                            yield _Segment(line_column, highlight_number_style)
                        else:
                            yield _Segment(" ", highlight_number_style)
                            yield _Segment(line_column, number_style)
                    else:
                        # Continuation rows of a wrapped line get a blank gutter.
                        yield wrapped_line_left_pad
                    yield from wrapped_line
                    yield new_line
            else:
                for wrapped_line in wrapped_lines:
                    yield from wrapped_line
                    yield new_line
762
+
763
+ def _apply_stylized_ranges(self, text: Text) -> None:
764
+ """
765
+ Apply stylized ranges to a text instance,
766
+ using the given code to determine the right portion to apply the style to.
767
+
768
+ Args:
769
+ text (Text): Text instance to apply the style to.
770
+ """
771
+ code = text.plain
772
+ newlines_offsets = [
773
+ # Let's add outer boundaries at each side of the list:
774
+ 0,
775
+ # N.B. using "\n" here is much faster than using metacharacters such as "^" or "\Z":
776
+ *[
777
+ match.start() + 1
778
+ for match in re.finditer("\n", code, flags=re.MULTILINE)
779
+ ],
780
+ len(code) + 1,
781
+ ]
782
+
783
+ for stylized_range in self._stylized_ranges:
784
+ start = _get_code_index_for_syntax_position(
785
+ newlines_offsets, stylized_range.start
786
+ )
787
+ end = _get_code_index_for_syntax_position(
788
+ newlines_offsets, stylized_range.end
789
+ )
790
+ if start is not None and end is not None:
791
+ text.stylize(stylized_range.style, start, end)
792
+
793
+ def _process_code(self, code: str) -> Tuple[bool, str]:
794
+ """
795
+ Applies various processing to a raw code string
796
+ (normalises it so it always ends with a line return, dedents it if necessary, etc.)
797
+
798
+ Args:
799
+ code (str): The raw code string to process
800
+
801
+ Returns:
802
+ Tuple[bool, str]: the boolean indicates whether the raw code ends with a line return,
803
+ while the string is the processed code.
804
+ """
805
+ ends_on_nl = code.endswith("\n")
806
+ processed_code = code if ends_on_nl else code + "\n"
807
+ processed_code = (
808
+ textwrap.dedent(processed_code) if self.dedent else processed_code
809
+ )
810
+ processed_code = processed_code.expandtabs(self.tab_size)
811
+ return ends_on_nl, processed_code
812
+
813
+
814
def _get_code_index_for_syntax_position(
    newlines_offsets: Sequence[int], position: SyntaxPosition
) -> Optional[int]:
    """
    Returns the index of the code string for the given positions.

    Args:
        newlines_offsets (Sequence[int]): The offset of each newline character found in the code snippet.
        position (SyntaxPosition): The position to search for.

    Returns:
        Optional[int]: The index of the code string for this position, or `None`
            if the given position's line number is out of range (if it's the column that is out of range
            we silently clamp its value so that it reaches the end of the line)
    """
    line_number, column_index = position
    # Line numbers are 1-based; the offsets list carries a sentinel past the
    # final line, so a valid line needs entries at both line_number - 1 and
    # line_number. (Equivalent to the original two-clause bound check.)
    if line_number >= len(newlines_offsets):
        return None  # `line_number` is out of range
    line_start = newlines_offsets[line_number - 1]
    line_length = newlines_offsets[line_number] - line_start - 1
    # If `column_index` is out of range: let's silently clamp it:
    return line_start + min(line_length, column_index)
839
+
840
+
841
if __name__ == "__main__":  # pragma: no cover
    import argparse
    import sys

    # Command-line entry point: highlight a file (or stdin when PATH is "-")
    # and print it to the terminal.
    parser = argparse.ArgumentParser(
        description="Render syntax to the console with Rich"
    )
    parser.add_argument(
        "path",
        metavar="PATH",
        help="path to file, or - for stdin",
    )
    parser.add_argument(
        "-c",
        "--force-color",
        dest="force_color",
        action="store_true",
        default=None,
        help="force color for non-terminals",
    )
    parser.add_argument(
        "-i",
        "--indent-guides",
        dest="indent_guides",
        action="store_true",
        default=False,
        help="display indent guides",
    )
    parser.add_argument(
        "-l",
        "--line-numbers",
        dest="line_numbers",
        action="store_true",
        help="render line numbers",
    )
    parser.add_argument(
        "-w",
        "--width",
        type=int,
        dest="width",
        default=None,
        help="width of output (default will auto-detect)",
    )
    parser.add_argument(
        "-r",
        "--wrap",
        dest="word_wrap",
        action="store_true",
        default=False,
        help="word wrap long lines",
    )
    parser.add_argument(
        "-s",
        "--soft-wrap",
        action="store_true",
        dest="soft_wrap",
        default=False,
        help="enable soft wrapping mode",
    )
    parser.add_argument(
        "-t", "--theme", dest="theme", default="monokai", help="pygments theme"
    )
    parser.add_argument(
        "-b",
        "--background-color",
        dest="background_color",
        default=None,
        help="Override background color",
    )
    parser.add_argument(
        "-x",
        "--lexer",
        default=None,
        dest="lexer_name",
        help="Lexer name",
    )
    parser.add_argument(
        "-p", "--padding", type=int, default=0, dest="padding", help="Padding"
    )
    parser.add_argument(
        "--highlight-line",
        type=int,
        default=None,
        dest="highlight_line",
        help="The line number (not index!) to highlight",
    )
    args = parser.parse_args()

    from pip._vendor.rich.console import Console

    console = Console(force_terminal=args.force_color, width=args.width)

    if args.path == "-":
        # Source comes from stdin; construct Syntax directly from the text.
        code = sys.stdin.read()
        syntax = Syntax(
            code=code,
            lexer=args.lexer_name,
            line_numbers=args.line_numbers,
            word_wrap=args.word_wrap,
            theme=args.theme,
            background_color=args.background_color,
            indent_guides=args.indent_guides,
            padding=args.padding,
            highlight_lines={args.highlight_line},
        )
    else:
        # Source comes from a file; from_path auto-detects the lexer if
        # --lexer was not given.
        syntax = Syntax.from_path(
            args.path,
            lexer=args.lexer_name,
            line_numbers=args.line_numbers,
            word_wrap=args.word_wrap,
            theme=args.theme,
            background_color=args.background_color,
            indent_guides=args.indent_guides,
            padding=args.padding,
            highlight_lines={args.highlight_line},
        )
    console.print(syntax, soft_wrap=args.soft_wrap)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/tree.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Iterator, List, Optional, Tuple
2
+
3
+ from ._loop import loop_first, loop_last
4
+ from .console import Console, ConsoleOptions, RenderableType, RenderResult
5
+ from .jupyter import JupyterMixin
6
+ from .measure import Measurement
7
+ from .segment import Segment
8
+ from .style import Style, StyleStack, StyleType
9
+ from .styled import Styled
10
+
11
+
12
class Tree(JupyterMixin):
    """A renderable for a tree structure.

    Args:
        label (RenderableType): The renderable or str for the tree label.
        style (StyleType, optional): Style of this tree. Defaults to "tree".
        guide_style (StyleType, optional): Style of the guide lines. Defaults to "tree.line".
        expanded (bool, optional): Also display children. Defaults to True.
        highlight (bool, optional): Highlight renderable (if str). Defaults to False.
        hide_root (bool, optional): Hide the root node of the tree. Defaults to False.
    """

    def __init__(
        self,
        label: RenderableType,
        *,
        style: StyleType = "tree",
        guide_style: StyleType = "tree.line",
        expanded: bool = True,
        highlight: bool = False,
        hide_root: bool = False,
    ) -> None:
        # Label rendered for this node.
        self.label = label
        # Style applied to the label.
        self.style = style
        # Style applied to the guide lines connecting nodes.
        self.guide_style = guide_style
        # Child nodes; populated via add().
        self.children: List[Tree] = []
        # Whether children are rendered.
        self.expanded = expanded
        # Whether str labels get automatic highlighting.
        self.highlight = highlight
        # Whether the root node is hidden when rendering.
        self.hide_root = hide_root

    def add(
        self,
        label: RenderableType,
        *,
        style: Optional[StyleType] = None,
        guide_style: Optional[StyleType] = None,
        expanded: bool = True,
        highlight: Optional[bool] = False,
    ) -> "Tree":
        """Add a child tree.

        Args:
            label (RenderableType): The renderable or str for the tree label.
            style (StyleType, optional): Style of this tree. Defaults to "tree".
            guide_style (StyleType, optional): Style of the guide lines. Defaults to "tree.line".
            expanded (bool, optional): Also display children. Defaults to True.
            highlight (Optional[bool], optional): Highlight renderable (if str). Defaults to False.

        Returns:
            Tree: A new child Tree, which may be further modified.
        """
        # None for style/guide_style/highlight means "inherit from parent".
        node = Tree(
            label,
            style=self.style if style is None else style,
            guide_style=self.guide_style if guide_style is None else guide_style,
            expanded=expanded,
            highlight=self.highlight if highlight is None else highlight,
        )
        self.children.append(node)
        return node

    def __rich_console__(
        self, console: "Console", options: "ConsoleOptions"
    ) -> "RenderResult":
        # Iterative depth-first traversal: each stack entry is an iterator of
        # (is_last_child, node) pairs for one level of the tree.
        stack: List[Iterator[Tuple[bool, Tree]]] = []
        pop = stack.pop
        push = stack.append
        new_line = Segment.line()

        get_style = console.get_style
        null_style = Style.null()
        guide_style = get_style(self.guide_style, default="") or null_style
        # Indices into the guide-glyph tuples below.
        SPACE, CONTINUE, FORK, END = range(4)

        ASCII_GUIDES = ("    ", "|   ", "+-- ", "`-- ")
        TREE_GUIDES = [
            ("    ", "│   ", "├── ", "└── "),
            ("    ", "┃   ", "┣━━ ", "┗━━ "),
            ("    ", "║   ", "╠══ ", "╚══ "),
        ]
        _Segment = Segment

        def make_guide(index: int, style: Style) -> Segment:
            """Make a Segment for a level of the guide lines."""
            if options.ascii_only:
                line = ASCII_GUIDES[index]
            else:
                # The guide style's bold/underline2 attributes select the
                # heavy / double box-drawing glyph sets respectively.
                guide = 1 if style.bold else (2 if style.underline2 else 0)
                line = TREE_GUIDES[0 if options.legacy_windows else guide][index]
            return _Segment(line, style)

        levels: List[Segment] = [make_guide(CONTINUE, guide_style)]
        push(iter(loop_last([self])))

        guide_style_stack = StyleStack(get_style(self.guide_style))
        style_stack = StyleStack(get_style(self.style))
        # Strip bold/underline2 before rendering guides, since those
        # attributes are used above only to pick the glyph set.
        remove_guide_styles = Style(bold=False, underline2=False)

        depth = 0

        while stack:
            stack_node = pop()
            try:
                last, node = next(stack_node)
            except StopIteration:
                # Level exhausted: drop its guide and restore parent styles.
                levels.pop()
                if levels:
                    guide_style = levels[-1].style or null_style
                    levels[-1] = make_guide(FORK, guide_style)
                    guide_style_stack.pop()
                    style_stack.pop()
                continue
            push(stack_node)
            if last:
                # Last sibling gets the closing (END) guide glyph.
                levels[-1] = make_guide(END, levels[-1].style or null_style)

            guide_style = guide_style_stack.current + get_style(node.guide_style)
            style = style_stack.current + get_style(node.style)
            # Guides to the left of this node's label (root guides skipped,
            # plus one more level when the root itself is hidden).
            prefix = levels[(2 if self.hide_root else 1) :]
            renderable_lines = console.render_lines(
                Styled(node.label, style),
                options.update(
                    width=options.max_width
                    - sum(level.cell_length for level in prefix),
                    highlight=self.highlight,
                    height=None,
                ),
                pad=options.justify is not None,
            )

            if not (depth == 0 and self.hide_root):
                for first, line in loop_first(renderable_lines):
                    if prefix:
                        yield from _Segment.apply_style(
                            prefix,
                            style.background_style,
                            post_style=remove_guide_styles,
                        )
                    yield from line
                    yield new_line
                    if first and prefix:
                        # After the first rendered line of a multi-line label,
                        # the fork/end glyph becomes a plain continuation.
                        prefix[-1] = make_guide(
                            SPACE if last else CONTINUE, prefix[-1].style or null_style
                        )

            if node.expanded and node.children:
                # Descend: convert this node's guide to a continuation and
                # open a new level for its children.
                levels[-1] = make_guide(
                    SPACE if last else CONTINUE, levels[-1].style or null_style
                )
                levels.append(
                    make_guide(END if len(node.children) == 1 else FORK, guide_style)
                )
                style_stack.push(get_style(node.style))
                guide_style_stack.push(get_style(node.guide_style))
                push(iter(loop_last(node.children)))
                depth += 1

    def __rich_measure__(
        self, console: "Console", options: "ConsoleOptions"
    ) -> "Measurement":
        # Walk the expanded tree, tracking depth; each level adds a 4-cell
        # guide indent to the label's own measurement.
        stack: List[Iterator[Tree]] = [iter([self])]
        pop = stack.pop
        push = stack.append
        minimum = 0
        maximum = 0
        measure = Measurement.get
        level = 0
        while stack:
            iter_tree = pop()
            try:
                tree = next(iter_tree)
            except StopIteration:
                level -= 1
                continue
            push(iter_tree)
            min_measure, max_measure = measure(console, options, tree.label)
            indent = level * 4
            minimum = max(min_measure + indent, minimum)
            maximum = max(max_measure + indent, maximum)
            if tree.expanded and tree.children:
                push(iter(tree.children))
                level += 1
        return Measurement(minimum, maximum)
194
+
195
+
196
if __name__ == "__main__":  # pragma: no cover
    # Demo: build a small tree containing tables, syntax, markdown and panels
    # as node labels, then print it to the console.
    from pip._vendor.rich.console import Group
    from pip._vendor.rich.markdown import Markdown
    from pip._vendor.rich.panel import Panel
    from pip._vendor.rich.syntax import Syntax
    from pip._vendor.rich.table import Table

    table = Table(row_styles=["", "dim"])

    table.add_column("Released", style="cyan", no_wrap=True)
    table.add_column("Title", style="magenta")
    table.add_column("Box Office", justify="right", style="green")

    table.add_row("Dec 20, 2019", "Star Wars: The Rise of Skywalker", "$952,110,690")
    table.add_row("May 25, 2018", "Solo: A Star Wars Story", "$393,151,347")
    table.add_row("Dec 15, 2017", "Star Wars Ep. V111: The Last Jedi", "$1,332,539,889")
    table.add_row("Dec 16, 2016", "Rogue One: A Star Wars Story", "$1,332,439,889")

    code = """\
class Segment(NamedTuple):
    text: str = ""
    style: Optional[Style] = None
    is_control: bool = False
"""
    syntax = Syntax(code, "python", theme="monokai", line_numbers=True)

    markdown = Markdown(
        """\
### example.md
> Hello, World!
>
> Markdown _all_ the things
"""
    )

    # Root is hidden, so only its children appear in the output.
    root = Tree("🌲 [b green]Rich Tree", highlight=True, hide_root=True)

    node = root.add(":file_folder: Renderables", guide_style="red")
    simple_node = node.add(":file_folder: [bold yellow]Atomic", guide_style="uu green")
    simple_node.add(Group("📄 Syntax", syntax))
    simple_node.add(Group("📄 Markdown", Panel(markdown, border_style="green")))

    containers_node = node.add(
        ":file_folder: [bold magenta]Containers", guide_style="bold magenta"
    )
    containers_node.expanded = True
    panel = Panel.fit("Just a panel", border_style="red")
    containers_node.add(Group("📄 Panels", panel))

    containers_node.add(Group("📄 [b magenta]Table", table))

    console = Console()

    console.print(root)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pybind11-2.13.6.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ pip
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pybind11-2.13.6.dist-info/METADATA ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: pybind11
3
+ Version: 2.13.6
4
+ Summary: Seamless operability between C++11 and Python
5
+ Home-page: https://github.com/pybind/pybind11
6
+ Download-URL: https://github.com/pybind/pybind11/tarball/v2.13.6
7
+ Author: Wenzel Jakob
8
+ Author-email: wenzel.jakob@epfl.ch
9
+ License: BSD
10
+ Project-URL: Documentation, https://pybind11.readthedocs.io/
11
+ Project-URL: Bug Tracker, https://github.com/pybind/pybind11/issues
12
+ Project-URL: Discussions, https://github.com/pybind/pybind11/discussions
13
+ Project-URL: Changelog, https://pybind11.readthedocs.io/en/latest/changelog.html
14
+ Project-URL: Chat, https://gitter.im/pybind/Lobby
15
+ Keywords: C++11,Python bindings
16
+ Classifier: Development Status :: 5 - Production/Stable
17
+ Classifier: Intended Audience :: Developers
18
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
19
+ Classifier: Topic :: Utilities
20
+ Classifier: Programming Language :: C++
21
+ Classifier: Programming Language :: Python :: 3 :: Only
22
+ Classifier: Programming Language :: Python :: 3.7
23
+ Classifier: Programming Language :: Python :: 3.8
24
+ Classifier: Programming Language :: Python :: 3.9
25
+ Classifier: Programming Language :: Python :: 3.10
26
+ Classifier: Programming Language :: Python :: 3.11
27
+ Classifier: Programming Language :: Python :: 3.12
28
+ Classifier: Programming Language :: Python :: 3.13
29
+ Classifier: License :: OSI Approved :: BSD License
30
+ Classifier: Programming Language :: Python :: Implementation :: PyPy
31
+ Classifier: Programming Language :: Python :: Implementation :: CPython
32
+ Classifier: Programming Language :: C++
33
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
34
+ Requires-Python: >=3.7
35
+ Description-Content-Type: text/x-rst
36
+ License-File: LICENSE
37
+ Provides-Extra: global
38
+ Requires-Dist: pybind11-global==2.13.6; extra == "global"
39
+
40
+ .. figure:: https://github.com/pybind/pybind11/raw/master/docs/pybind11-logo.png
41
+ :alt: pybind11 logo
42
+
43
+ **pybind11 — Seamless operability between C++11 and Python**
44
+
45
+ |Latest Documentation Status| |Stable Documentation Status| |Gitter chat| |GitHub Discussions| |CI| |Build status|
46
+
47
+ |Repology| |PyPI package| |Conda-forge| |Python Versions|
48
+
49
+ `Setuptools example <https://github.com/pybind/python_example>`_
50
+ • `Scikit-build example <https://github.com/pybind/scikit_build_example>`_
51
+ • `CMake example <https://github.com/pybind/cmake_example>`_
52
+
53
+ .. start
54
+
55
+
56
+ **pybind11** is a lightweight header-only library that exposes C++ types
57
+ in Python and vice versa, mainly to create Python bindings of existing
58
+ C++ code. Its goals and syntax are similar to the excellent
59
+ `Boost.Python <http://www.boost.org/doc/libs/1_58_0/libs/python/doc/>`_
60
+ library by David Abrahams: to minimize boilerplate code in traditional
61
+ extension modules by inferring type information using compile-time
62
+ introspection.
63
+
64
+ The main issue with Boost.Python—and the reason for creating such a
65
+ similar project—is Boost. Boost is an enormously large and complex suite
66
+ of utility libraries that works with almost every C++ compiler in
67
+ existence. This compatibility has its cost: arcane template tricks and
68
+ workarounds are necessary to support the oldest and buggiest of compiler
69
+ specimens. Now that C++11-compatible compilers are widely available,
70
+ this heavy machinery has become an excessively large and unnecessary
71
+ dependency.
72
+
73
+ Think of this library as a tiny self-contained version of Boost.Python
74
+ with everything stripped away that isn't relevant for binding
75
+ generation. Without comments, the core header files only require ~4K
76
+ lines of code and depend on Python (3.7+, or PyPy) and the C++
77
+ standard library. This compact implementation was possible thanks to
78
+ some C++11 language features (specifically: tuples, lambda functions and
79
+ variadic templates). Since its creation, this library has grown beyond
80
+ Boost.Python in many ways, leading to dramatically simpler binding code in many
81
+ common situations.
82
+
83
+ Tutorial and reference documentation is provided at
84
+ `pybind11.readthedocs.io <https://pybind11.readthedocs.io/en/latest>`_.
85
+ A PDF version of the manual is available
86
+ `here <https://pybind11.readthedocs.io/_/downloads/en/latest/pdf/>`_.
87
+ And the source code is always available at
88
+ `github.com/pybind/pybind11 <https://github.com/pybind/pybind11>`_.
89
+
90
+
91
+ Core features
92
+ -------------
93
+
94
+
95
+ pybind11 can map the following core C++ features to Python:
96
+
97
+ - Functions accepting and returning custom data structures per value,
98
+ reference, or pointer
99
+ - Instance methods and static methods
100
+ - Overloaded functions
101
+ - Instance attributes and static attributes
102
+ - Arbitrary exception types
103
+ - Enumerations
104
+ - Callbacks
105
+ - Iterators and ranges
106
+ - Custom operators
107
+ - Single and multiple inheritance
108
+ - STL data structures
109
+ - Smart pointers with reference counting like ``std::shared_ptr``
110
+ - Internal references with correct reference counting
111
+ - C++ classes with virtual (and pure virtual) methods can be extended
112
+ in Python
113
+ - Integrated NumPy support (NumPy 2 requires pybind11 2.12+)
114
+
115
+ Goodies
116
+ -------
117
+
118
+ In addition to the core functionality, pybind11 provides some extra
119
+ goodies:
120
+
121
+ - Python 3.7+, and PyPy3 7.3 are supported with an implementation-agnostic
122
+ interface (pybind11 2.9 was the last version to support Python 2 and 3.5).
123
+
124
+ - It is possible to bind C++11 lambda functions with captured
125
+ variables. The lambda capture data is stored inside the resulting
126
+ Python function object.
127
+
128
+ - pybind11 uses C++11 move constructors and move assignment operators
129
+ whenever possible to efficiently transfer custom data types.
130
+
131
+ - It's easy to expose the internal storage of custom data types through
132
+ Python's buffer protocols. This is handy e.g. for fast conversion
133
+ between C++ matrix classes like Eigen and NumPy without expensive
134
+ copy operations.
135
+
136
+ - pybind11 can automatically vectorize functions so that they are
137
+ transparently applied to all entries of one or more NumPy array
138
+ arguments.
139
+
140
+ - Python's slice-based access and assignment operations can be
141
+ supported with just a few lines of code.
142
+
143
+ - Everything is contained in just a few header files; there is no need
144
+ to link against any additional libraries.
145
+
146
+ - Binaries are generally smaller by a factor of at least 2 compared to
147
+ equivalent bindings generated by Boost.Python. A recent pybind11
148
+ conversion of PyRosetta, an enormous Boost.Python binding project,
149
+ `reported <https://graylab.jhu.edu/Sergey/2016.RosettaCon/PyRosetta-4.pdf>`_
150
+ a binary size reduction of **5.4x** and compile time reduction by
151
+ **5.8x**.
152
+
153
+ - Function signatures are precomputed at compile time (using
154
+ ``constexpr``), leading to smaller binaries.
155
+
156
+ - With little extra effort, C++ types can be pickled and unpickled
157
+ similar to regular Python objects.
158
+
159
+ Supported compilers
160
+ -------------------
161
+
162
+ 1. Clang/LLVM 3.3 or newer (for Apple Xcode's clang, this is 5.0.0 or
163
+ newer)
164
+ 2. GCC 4.8 or newer
165
+ 3. Microsoft Visual Studio 2017 or newer
166
+ 4. Intel classic C++ compiler 18 or newer (ICC 20.2 tested in CI)
167
+ 5. Cygwin/GCC (previously tested on 2.5.1)
168
+ 6. NVCC (CUDA 11.0 tested in CI)
169
+ 7. NVIDIA PGI (20.9 tested in CI)
170
+
171
+ About
172
+ -----
173
+
174
+ This project was created by `Wenzel
175
+ Jakob <http://rgl.epfl.ch/people/wjakob>`_. Significant features and/or
176
+ improvements to the code were contributed by Jonas Adler, Lori A. Burns,
177
+ Sylvain Corlay, Eric Cousineau, Aaron Gokaslan, Ralf Grosse-Kunstleve, Trent Houliston, Axel
178
+ Huebl, @hulucc, Yannick Jadoul, Sergey Lyskov, Johan Mabille, Tomasz Miąsko,
179
+ Dean Moldovan, Ben Pritchard, Jason Rhinelander, Boris Schäling, Pim
180
+ Schellart, Henry Schreiner, Ivan Smirnov, Boris Staletic, and Patrick Stewart.
181
+
182
+ We thank Google for a generous financial contribution to the continuous
183
+ integration infrastructure used by this project.
184
+
185
+
186
+ Contributing
187
+ ~~~~~~~~~~~~
188
+
189
+ See the `contributing
190
+ guide <https://github.com/pybind/pybind11/blob/master/.github/CONTRIBUTING.md>`_
191
+ for information on building and contributing to pybind11.
192
+
193
+ License
194
+ ~~~~~~~
195
+
196
+ pybind11 is provided under a BSD-style license that can be found in the
197
+ `LICENSE <https://github.com/pybind/pybind11/blob/master/LICENSE>`_
198
+ file. By using, distributing, or contributing to this project, you agree
199
+ to the terms and conditions of this license.
200
+
201
+ .. |Latest Documentation Status| image:: https://readthedocs.org/projects/pybind11/badge?version=latest
202
+ :target: http://pybind11.readthedocs.org/en/latest
203
+ .. |Stable Documentation Status| image:: https://img.shields.io/badge/docs-stable-blue.svg
204
+ :target: http://pybind11.readthedocs.org/en/stable
205
+ .. |Gitter chat| image:: https://img.shields.io/gitter/room/gitterHQ/gitter.svg
206
+ :target: https://gitter.im/pybind/Lobby
207
+ .. |CI| image:: https://github.com/pybind/pybind11/workflows/CI/badge.svg
208
+ :target: https://github.com/pybind/pybind11/actions
209
+ .. |Build status| image:: https://ci.appveyor.com/api/projects/status/riaj54pn4h08xy40?svg=true
210
+ :target: https://ci.appveyor.com/project/wjakob/pybind11
211
+ .. |PyPI package| image:: https://img.shields.io/pypi/v/pybind11.svg
212
+ :target: https://pypi.org/project/pybind11/
213
+ .. |Conda-forge| image:: https://img.shields.io/conda/vn/conda-forge/pybind11.svg
214
+ :target: https://github.com/conda-forge/pybind11-feedstock
215
+ .. |Repology| image:: https://repology.org/badge/latest-versions/python:pybind11.svg
216
+ :target: https://repology.org/project/python:pybind11/versions
217
+ .. |Python Versions| image:: https://img.shields.io/pypi/pyversions/pybind11.svg
218
+ :target: https://pypi.org/project/pybind11/
219
+ .. |GitHub Discussions| image:: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github
220
+ :target: https://github.com/pybind/pybind11/discussions
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Re-export every fused ("intrinsic") module type so users can write
# `torch.ao.nn.intrinsic.ConvBn2d` etc. without importing `.modules` directly.
from .modules import *  # noqa: F403
from .modules.fused import _FusedModule  # noqa: F403

# # Subpackages
# from . import qat  # noqa: F403
# from . import quantized  # noqa: F403

# Public API: the fused module types produced by module fusion
# (conv+bn, conv+relu, linear+activation, bn+relu, conv+add combinations).
__all__ = [
    'ConvBn1d',
    'ConvBn2d',
    'ConvBn3d',
    'ConvBnReLU1d',
    'ConvBnReLU2d',
    'ConvBnReLU3d',
    'ConvReLU1d',
    'ConvReLU2d',
    'ConvReLU3d',
    'LinearReLU',
    'BNReLU2d',
    'BNReLU3d',
    'LinearBn1d',
    'LinearLeakyReLU',
    'LinearTanh',
    'ConvAdd2d',
    'ConvAddReLU2d',
]

# We are exposing all subpackages to the end-user.
# Because of possible inter-dependency, we want to avoid
# the cyclic imports, thus implementing lazy version
# as per https://peps.python.org/pep-0562/
def __getattr__(name):
    # Module-level __getattr__ (PEP 562): only consulted for names NOT already
    # bound by the wildcard import above, so in practice this is a fallback.
    # NOTE(review): __all__ lists class names, not submodule names, so this
    # would call import_module(".ConvBn1d", ...) — presumably unreachable when
    # the star import succeeded; confirm the intent upstream.
    if name in __all__:
        import importlib
        return importlib.import_module("." + name, __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-311.pyc ADDED
Binary file (8.46 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-311.pyc ADDED
Binary file (3.44 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Variables
2
+ from ._mappings import get_dynamic_sparse_quantized_mapping
3
+ from ._mappings import get_static_sparse_quantized_mapping
4
+
5
+ # Sparsifier
6
+ from .sparsifier.base_sparsifier import BaseSparsifier
7
+ from .sparsifier.weight_norm_sparsifier import WeightNormSparsifier
8
+ from .sparsifier.nearly_diagonal_sparsifier import NearlyDiagonalSparsifier
9
+
10
+ # Scheduler
11
+ from .scheduler.base_scheduler import BaseScheduler
12
+ from .scheduler.lambda_scheduler import LambdaSL
13
+ from .scheduler.cubic_scheduler import CubicSL
14
+
15
+ # Parametrizations
16
+ from .sparsifier.utils import FakeSparsity
17
+ from .sparsifier.utils import module_to_fqn
18
+ from .sparsifier.utils import fqn_to_module
19
+ from .sparsifier.utils import get_arg_info_from_tensor_fqn
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (253 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import wraps
2
+ import weakref
3
+ import abc
4
+ import warnings
5
+
6
+ from ..data_sparsifier import BaseDataSparsifier
7
+
8
+ __all__ = ['BaseDataScheduler']
9
+
10
+
11
class BaseDataScheduler:
    r"""
    The BaseDataScheduler is the abstract scheduler class specifically for the
    BaseDataSparsifier class. This class controls a specific hyperparameter of
    the sparsifier class and varies it across the training process (or across time).

    Args:
        data_sparsifier (instance of BaseDataSparsifier)
            Implemented class data sparsifier class wherein the update_mask is implemented
        schedule_param (str)
            A specific hyperparameter of the passed sparsifier that needs to be scheduled/varied
        last_epoch (int, default=-1)
            This is specifically passed when training needs to be resumed from a particular
            point.
        verbose (bool, default=False)
            Verbosity of the BaseDataScheduler

    The *get_schedule_param()* function needs to be implemented by the user.
    """
    def __init__(self, data_sparsifier, schedule_param: str, last_epoch=-1, verbose=False):
        # Attach sparsifier; fail fast on anything that is not a BaseDataSparsifier.
        if not isinstance(data_sparsifier, BaseDataSparsifier):
            raise TypeError('{} is not an instance of torch.ao.pruning.BaseDataSparsifier'.format(
                type(data_sparsifier).__name__))
        self.data_sparsifier = data_sparsifier
        self.schedule_param = schedule_param

        # Initialize epoch and base hyper-params: snapshot the scheduled
        # hyperparameter's starting value for every data group.
        self.base_param = {
            name: config.get(schedule_param, None)
            for name, config in self.data_sparsifier.data_groups.items()
        }

        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `scheduler.step()` is called after
        # `sparsifier.step()`
        def with_counter(method):
            # Wrap `data_sparsifier.step` so each call bumps `_step_count`,
            # letting `step()` below detect an out-of-order call sequence.
            if getattr(method, '_with_counter', False):
                # `sparsifier.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the sparsifier instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            # Drop the bound method so the only strong reference chain to the
            # sparsifier is the one the caller already holds.
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1  # type: ignore[union-attr]
                # Re-bind the plain function to the (still-alive) instance.
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True  # type: ignore[attr-defined]
            return wrapper

        self.data_sparsifier.step = with_counter(self.data_sparsifier.step)  # type: ignore[assignment]
        self.data_sparsifier._step_count = 0  # type: ignore[attr-defined]
        self._step_count: int = 0
        self.verbose = verbose

        # Housekeeping
        self._get_sp_called_within_step: bool = False  # sp -> schedule parameter
        # Run one step so data groups reflect `last_epoch + 1` immediately.
        self.step()

    @abc.abstractmethod
    def get_schedule_param(self):
        r"""
        Abstract method that needs to be implemented by the child class.
        The expected return type is a dictionary mapping name to schedule_param value.
        The returned values will be updated in sparsifier when the scheduler step() function
        is called.

        Example:
            >>> def get_schedule_param(self):
            ...     new_param = {}
            ...     for name in self.sparsifier.data_groups.keys():
            ...         new_param[name] = self.sparsifier.data_groups[name][self.schedule_param] * 0.5
            ...     return new_param

        When the step() function is called, the value in self.sparsifier.data_groups[name][self.schedule_param]
        would be halved
        """
        raise NotImplementedError

    def __repr__(self):
        # Human-readable summary: sparsifier plus the scheduled param's base values.
        format_string = self.__class__.__name__ + ' ('
        format_string += '\n'
        format_string += f'Data Sparsifier {self.data_sparsifier}\n'
        format_string += f'    {self.schedule_param}: {self.base_param}\n'
        format_string += ')'
        return format_string

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the sparsifier.

        Note:
            The scheduler class does not track the state of the data_sparsifier.
            Make sure to store the state of the sparsifier before storing the
            state of the scheduler
        """
        return {key: value for key, value in self.__dict__.items() if key != 'data_sparsifier'}

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Note:
            Remember to restore the state of the data_sparsifier before the scheduler.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_param(self):
        # Last value applied per data group (set at the end of step()).
        return self._last_param

    def step(self):
        """Advance one epoch: query get_schedule_param() and push the new
        values into the sparsifier's data groups."""
        # Raise warning if trying to call scheduler step before the sparsifier.
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.data_sparsifier.step, "_with_counter"):
                warnings.warn("Seems like `data_sparsifier.step()` has been overridden after sparsity scheduler "
                              "initialization. Please, make sure to call `data_sparsifier.step()` before "
                              "`scheduler.step()`.", UserWarning)

            # Just check if there were two first scheduler.step() calls before sparsifier.step()
            elif self.data_sparsifier._step_count < 1:  # type: ignore[attr-defined]
                warnings.warn("Detected call of `scheduler.step()` before `data_sparsifier.step()`. "
                              "You have to make sure you run the data_sparsifier.step() BEFORE any "
                              "calls to the scheduler.step().", UserWarning)
        self._step_count += 1

        # Context manager flagging that get_schedule_param() is being invoked
        # from within step() (mirrors torch.optim's LR-scheduler pattern).
        class _enable_get_sp_call:

            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_sp_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_sp_called_within_step = False

        with _enable_get_sp_call(self):
            self.last_epoch += 1
            updated_scheduler_params = self.get_schedule_param()

        # Write the new hyperparameter values back into each data group.
        for name, param in updated_scheduler_params.items():
            self.data_sparsifier.data_groups[name][self.schedule_param] = param
            if self.verbose:
                print(f"Adjusting {self.schedule_param} for group {name} to {param}")

        self._last_param = {
            name: config.get(self.schedule_param, None)
            for name, config in self.data_sparsifier.data_groups.items()
        }
        self.data_sparsifier.enable_mask_update = True