Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/__init__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/big_modeling.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/checkpointing.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/launchers.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/optimizer.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/ElementSoup.py +10 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/_difflib.py +2106 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/formfill.py +299 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/html5parser.py +260 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/__init__.pxd +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/config.pxd +3 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/relaxng.pxd +64 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/schematron.pxd +34 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/xpath.pxd +136 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/isoschematron/__init__.py +348 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/__pycache__/__init__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/__pycache__/compiler.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/__pycache__/driver.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/__init__.py +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/compiler.py +495 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/driver.c +504 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/driver.py +877 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/__init__.py +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py +553 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/driver.c +518 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/driver.py +764 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/cudaGL.h +608 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/cupti_pcsampling_util.h +402 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/driver_types.h +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/generated_cudaVDPAU_meta.h +46 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/nvperf_target.h +626 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/sm_32_atomic_functions.hpp +151 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/__init__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/math.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/random.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/standard.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/target_info.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/__init__.py +26 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/__pycache__/__init__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/__pycache__/libdevice.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__init__.py +16 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/__init__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/gdc.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/libdevice.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/utils.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/gdc.py +42 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/libdevice.py +1629 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/utils.py +109 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/hip/__init__.py +5 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/hip/__pycache__/__init__.cpython-312.pyc +0 -0
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.39 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/big_modeling.cpython-312.pyc
ADDED
|
Binary file (36.9 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/checkpointing.cpython-312.pyc
ADDED
|
Binary file (15.4 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/launchers.cpython-312.pyc
ADDED
|
Binary file (12.9 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/optimizer.cpython-312.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/ElementSoup.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__doc__ = """Legacy interface to the BeautifulSoup HTML parser.
|
| 2 |
+
"""
|
| 3 |
+
|
| 4 |
+
__all__ = ["parse", "convert_tree"]
|
| 5 |
+
|
| 6 |
+
from .soupparser import convert_tree, parse as _parse
|
| 7 |
+
|
| 8 |
+
def parse(file, beautifulsoup=None, makeelement=None):
|
| 9 |
+
root = _parse(file, beautifulsoup=beautifulsoup, makeelement=makeelement)
|
| 10 |
+
return root.getroot()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/_difflib.py
ADDED
|
@@ -0,0 +1,2106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from CPython 3.14b2+.
|
| 2 |
+
# cython: infer_types=True
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Module difflib -- helpers for computing deltas between objects.
|
| 6 |
+
|
| 7 |
+
Function get_close_matches(word, possibilities, n=3, cutoff=0.6):
|
| 8 |
+
Use SequenceMatcher to return list of the best "good enough" matches.
|
| 9 |
+
|
| 10 |
+
Function context_diff(a, b):
|
| 11 |
+
For two lists of strings, return a delta in context diff format.
|
| 12 |
+
|
| 13 |
+
Function ndiff(a, b):
|
| 14 |
+
Return a delta: the difference between `a` and `b` (lists of strings).
|
| 15 |
+
|
| 16 |
+
Function restore(delta, which):
|
| 17 |
+
Return one of the two sequences that generated an ndiff delta.
|
| 18 |
+
|
| 19 |
+
Function unified_diff(a, b):
|
| 20 |
+
For two lists of strings, return a delta in unified diff format.
|
| 21 |
+
|
| 22 |
+
Class SequenceMatcher:
|
| 23 |
+
A flexible class for comparing pairs of sequences of any type.
|
| 24 |
+
|
| 25 |
+
Class Differ:
|
| 26 |
+
For producing human-readable deltas from sequences of lines of text.
|
| 27 |
+
|
| 28 |
+
Class HtmlDiff:
|
| 29 |
+
For producing HTML side by side comparison with change highlights.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
import cython
|
| 34 |
+
except ImportError:
|
| 35 |
+
class fake_cython:
|
| 36 |
+
compiled = False
|
| 37 |
+
def cfunc(self, func): return func
|
| 38 |
+
def declare(self, _, value): return value
|
| 39 |
+
def __getattr__(self, type_name): return "object"
|
| 40 |
+
|
| 41 |
+
cython = fake_cython()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
__all__ = ['get_close_matches', 'ndiff', 'restore', 'SequenceMatcher',
|
| 45 |
+
'Differ','IS_CHARACTER_JUNK', 'IS_LINE_JUNK', 'context_diff',
|
| 46 |
+
'unified_diff', 'diff_bytes', 'HtmlDiff', 'Match']
|
| 47 |
+
|
| 48 |
+
from heapq import nlargest as _nlargest
|
| 49 |
+
from collections import namedtuple as _namedtuple
|
| 50 |
+
|
| 51 |
+
try:
|
| 52 |
+
from types import GenericAlias
|
| 53 |
+
except ImportError:
|
| 54 |
+
GenericAlias = None
|
| 55 |
+
|
| 56 |
+
Match = _namedtuple('Match', 'a b size')
|
| 57 |
+
|
| 58 |
+
def _calculate_ratio(matches, length):
|
| 59 |
+
if length:
|
| 60 |
+
return 2.0 * matches / length
|
| 61 |
+
return 1.0
|
| 62 |
+
|
| 63 |
+
class SequenceMatcher:
|
| 64 |
+
|
| 65 |
+
"""
|
| 66 |
+
SequenceMatcher is a flexible class for comparing pairs of sequences of
|
| 67 |
+
any type, so long as the sequence elements are hashable. The basic
|
| 68 |
+
algorithm predates, and is a little fancier than, an algorithm
|
| 69 |
+
published in the late 1980's by Ratcliff and Obershelp under the
|
| 70 |
+
hyperbolic name "gestalt pattern matching". The basic idea is to find
|
| 71 |
+
the longest contiguous matching subsequence that contains no "junk"
|
| 72 |
+
elements (R-O doesn't address junk). The same idea is then applied
|
| 73 |
+
recursively to the pieces of the sequences to the left and to the right
|
| 74 |
+
of the matching subsequence. This does not yield minimal edit
|
| 75 |
+
sequences, but does tend to yield matches that "look right" to people.
|
| 76 |
+
|
| 77 |
+
SequenceMatcher tries to compute a "human-friendly diff" between two
|
| 78 |
+
sequences. Unlike e.g. UNIX(tm) diff, the fundamental notion is the
|
| 79 |
+
longest *contiguous* & junk-free matching subsequence. That's what
|
| 80 |
+
catches peoples' eyes. The Windows(tm) windiff has another interesting
|
| 81 |
+
notion, pairing up elements that appear uniquely in each sequence.
|
| 82 |
+
That, and the method here, appear to yield more intuitive difference
|
| 83 |
+
reports than does diff. This method appears to be the least vulnerable
|
| 84 |
+
to syncing up on blocks of "junk lines", though (like blank lines in
|
| 85 |
+
ordinary text files, or maybe "<P>" lines in HTML files). That may be
|
| 86 |
+
because this is the only method of the 3 that has a *concept* of
|
| 87 |
+
"junk" <wink>.
|
| 88 |
+
|
| 89 |
+
Example, comparing two strings, and considering blanks to be "junk":
|
| 90 |
+
|
| 91 |
+
>>> s = SequenceMatcher(lambda x: x == " ",
|
| 92 |
+
... "private Thread currentThread;",
|
| 93 |
+
... "private volatile Thread currentThread;")
|
| 94 |
+
>>>
|
| 95 |
+
|
| 96 |
+
.ratio() returns a float in [0, 1], measuring the "similarity" of the
|
| 97 |
+
sequences. As a rule of thumb, a .ratio() value over 0.6 means the
|
| 98 |
+
sequences are close matches:
|
| 99 |
+
|
| 100 |
+
>>> print(round(s.ratio(), 3))
|
| 101 |
+
0.866
|
| 102 |
+
>>>
|
| 103 |
+
|
| 104 |
+
If you're only interested in where the sequences match,
|
| 105 |
+
.get_matching_blocks() is handy:
|
| 106 |
+
|
| 107 |
+
>>> for block in s.get_matching_blocks():
|
| 108 |
+
... print("a[%d] and b[%d] match for %d elements" % block)
|
| 109 |
+
a[0] and b[0] match for 8 elements
|
| 110 |
+
a[8] and b[17] match for 21 elements
|
| 111 |
+
a[29] and b[38] match for 0 elements
|
| 112 |
+
|
| 113 |
+
Note that the last tuple returned by .get_matching_blocks() is always a
|
| 114 |
+
dummy, (len(a), len(b), 0), and this is the only case in which the last
|
| 115 |
+
tuple element (number of elements matched) is 0.
|
| 116 |
+
|
| 117 |
+
If you want to know how to change the first sequence into the second,
|
| 118 |
+
use .get_opcodes():
|
| 119 |
+
|
| 120 |
+
>>> for opcode in s.get_opcodes():
|
| 121 |
+
... print("%6s a[%d:%d] b[%d:%d]" % opcode)
|
| 122 |
+
equal a[0:8] b[0:8]
|
| 123 |
+
insert a[8:8] b[8:17]
|
| 124 |
+
equal a[8:29] b[17:38]
|
| 125 |
+
|
| 126 |
+
See the Differ class for a fancy human-friendly file differencer, which
|
| 127 |
+
uses SequenceMatcher both to compare sequences of lines, and to compare
|
| 128 |
+
sequences of characters within similar (near-matching) lines.
|
| 129 |
+
|
| 130 |
+
See also function get_close_matches() in this module, which shows how
|
| 131 |
+
simple code building on SequenceMatcher can be used to do useful work.
|
| 132 |
+
|
| 133 |
+
Timing: Basic R-O is cubic time worst case and quadratic time expected
|
| 134 |
+
case. SequenceMatcher is quadratic time for the worst case and has
|
| 135 |
+
expected-case behavior dependent in a complicated way on how many
|
| 136 |
+
elements the sequences have in common; best case time is linear.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
def __init__(self, isjunk=None, a='', b='', autojunk=True):
|
| 140 |
+
"""Construct a SequenceMatcher.
|
| 141 |
+
|
| 142 |
+
Optional arg isjunk is None (the default), or a one-argument
|
| 143 |
+
function that takes a sequence element and returns true iff the
|
| 144 |
+
element is junk. None is equivalent to passing "lambda x: 0", i.e.
|
| 145 |
+
no elements are considered to be junk. For example, pass
|
| 146 |
+
lambda x: x in " \\t"
|
| 147 |
+
if you're comparing lines as sequences of characters, and don't
|
| 148 |
+
want to synch up on blanks or hard tabs.
|
| 149 |
+
|
| 150 |
+
Optional arg a is the first of two sequences to be compared. By
|
| 151 |
+
default, an empty string. The elements of a must be hashable. See
|
| 152 |
+
also .set_seqs() and .set_seq1().
|
| 153 |
+
|
| 154 |
+
Optional arg b is the second of two sequences to be compared. By
|
| 155 |
+
default, an empty string. The elements of b must be hashable. See
|
| 156 |
+
also .set_seqs() and .set_seq2().
|
| 157 |
+
|
| 158 |
+
Optional arg autojunk should be set to False to disable the
|
| 159 |
+
"automatic junk heuristic" that treats popular elements as junk
|
| 160 |
+
(see module documentation for more information).
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
# Members:
|
| 164 |
+
# a
|
| 165 |
+
# first sequence
|
| 166 |
+
# b
|
| 167 |
+
# second sequence; differences are computed as "what do
|
| 168 |
+
# we need to do to 'a' to change it into 'b'?"
|
| 169 |
+
# b2j
|
| 170 |
+
# for x in b, b2j[x] is a list of the indices (into b)
|
| 171 |
+
# at which x appears; junk and popular elements do not appear
|
| 172 |
+
# fullbcount
|
| 173 |
+
# for x in b, fullbcount[x] == the number of times x
|
| 174 |
+
# appears in b; only materialized if really needed (used
|
| 175 |
+
# only for computing quick_ratio())
|
| 176 |
+
# matching_blocks
|
| 177 |
+
# a list of (i, j, k) triples, where a[i:i+k] == b[j:j+k];
|
| 178 |
+
# ascending & non-overlapping in i and in j; terminated by
|
| 179 |
+
# a dummy (len(a), len(b), 0) sentinel
|
| 180 |
+
# opcodes
|
| 181 |
+
# a list of (tag, i1, i2, j1, j2) tuples, where tag is
|
| 182 |
+
# one of
|
| 183 |
+
# 'replace' a[i1:i2] should be replaced by b[j1:j2]
|
| 184 |
+
# 'delete' a[i1:i2] should be deleted
|
| 185 |
+
# 'insert' b[j1:j2] should be inserted
|
| 186 |
+
# 'equal' a[i1:i2] == b[j1:j2]
|
| 187 |
+
# isjunk
|
| 188 |
+
# a user-supplied function taking a sequence element and
|
| 189 |
+
# returning true iff the element is "junk" -- this has
|
| 190 |
+
# subtle but helpful effects on the algorithm, which I'll
|
| 191 |
+
# get around to writing up someday <0.9 wink>.
|
| 192 |
+
# DON'T USE! Only __chain_b uses this. Use "in self.bjunk".
|
| 193 |
+
# bjunk
|
| 194 |
+
# the items in b for which isjunk is True.
|
| 195 |
+
# bpopular
|
| 196 |
+
# nonjunk items in b treated as junk by the heuristic (if used).
|
| 197 |
+
|
| 198 |
+
self.isjunk = isjunk
|
| 199 |
+
self.a = self.b = None
|
| 200 |
+
self.autojunk = autojunk
|
| 201 |
+
self.set_seqs(a, b)
|
| 202 |
+
|
| 203 |
+
def set_seqs(self, a, b):
|
| 204 |
+
"""Set the two sequences to be compared.
|
| 205 |
+
|
| 206 |
+
>>> s = SequenceMatcher()
|
| 207 |
+
>>> s.set_seqs("abcd", "bcde")
|
| 208 |
+
>>> s.ratio()
|
| 209 |
+
0.75
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
self.set_seq1(a)
|
| 213 |
+
self.set_seq2(b)
|
| 214 |
+
|
| 215 |
+
def set_seq1(self, a):
|
| 216 |
+
"""Set the first sequence to be compared.
|
| 217 |
+
|
| 218 |
+
The second sequence to be compared is not changed.
|
| 219 |
+
|
| 220 |
+
>>> s = SequenceMatcher(None, "abcd", "bcde")
|
| 221 |
+
>>> s.ratio()
|
| 222 |
+
0.75
|
| 223 |
+
>>> s.set_seq1("bcde")
|
| 224 |
+
>>> s.ratio()
|
| 225 |
+
1.0
|
| 226 |
+
>>>
|
| 227 |
+
|
| 228 |
+
SequenceMatcher computes and caches detailed information about the
|
| 229 |
+
second sequence, so if you want to compare one sequence S against
|
| 230 |
+
many sequences, use .set_seq2(S) once and call .set_seq1(x)
|
| 231 |
+
repeatedly for each of the other sequences.
|
| 232 |
+
|
| 233 |
+
See also set_seqs() and set_seq2().
|
| 234 |
+
"""
|
| 235 |
+
|
| 236 |
+
if a is self.a:
|
| 237 |
+
return
|
| 238 |
+
self.a = a
|
| 239 |
+
self.matching_blocks = self.opcodes = None
|
| 240 |
+
|
| 241 |
+
def set_seq2(self, b):
|
| 242 |
+
"""Set the second sequence to be compared.
|
| 243 |
+
|
| 244 |
+
The first sequence to be compared is not changed.
|
| 245 |
+
|
| 246 |
+
>>> s = SequenceMatcher(None, "abcd", "bcde")
|
| 247 |
+
>>> s.ratio()
|
| 248 |
+
0.75
|
| 249 |
+
>>> s.set_seq2("abcd")
|
| 250 |
+
>>> s.ratio()
|
| 251 |
+
1.0
|
| 252 |
+
>>>
|
| 253 |
+
|
| 254 |
+
SequenceMatcher computes and caches detailed information about the
|
| 255 |
+
second sequence, so if you want to compare one sequence S against
|
| 256 |
+
many sequences, use .set_seq2(S) once and call .set_seq1(x)
|
| 257 |
+
repeatedly for each of the other sequences.
|
| 258 |
+
|
| 259 |
+
See also set_seqs() and set_seq1().
|
| 260 |
+
"""
|
| 261 |
+
|
| 262 |
+
if b is self.b:
|
| 263 |
+
return
|
| 264 |
+
self.b = b
|
| 265 |
+
self.matching_blocks = self.opcodes = None
|
| 266 |
+
self.fullbcount = None
|
| 267 |
+
self.__chain_b()
|
| 268 |
+
|
| 269 |
+
# For each element x in b, set b2j[x] to a list of the indices in
|
| 270 |
+
# b where x appears; the indices are in increasing order; note that
|
| 271 |
+
# the number of times x appears in b is len(b2j[x]) ...
|
| 272 |
+
# when self.isjunk is defined, junk elements don't show up in this
|
| 273 |
+
# map at all, which stops the central find_longest_match method
|
| 274 |
+
# from starting any matching block at a junk element ...
|
| 275 |
+
# b2j also does not contain entries for "popular" elements, meaning
|
| 276 |
+
# elements that account for more than 1 + 1% of the total elements, and
|
| 277 |
+
# when the sequence is reasonably large (>= 200 elements); this can
|
| 278 |
+
# be viewed as an adaptive notion of semi-junk, and yields an enormous
|
| 279 |
+
# speedup when, e.g., comparing program files with hundreds of
|
| 280 |
+
# instances of "return NULL;" ...
|
| 281 |
+
# note that this is only called when b changes; so for cross-product
|
| 282 |
+
# kinds of matches, it's best to call set_seq2 once, then set_seq1
|
| 283 |
+
# repeatedly
|
| 284 |
+
|
| 285 |
+
def __chain_b(self):
|
| 286 |
+
# Because isjunk is a user-defined (not C) function, and we test
|
| 287 |
+
# for junk a LOT, it's important to minimize the number of calls.
|
| 288 |
+
# Before the tricks described here, __chain_b was by far the most
|
| 289 |
+
# time-consuming routine in the whole module! If anyone sees
|
| 290 |
+
# Jim Roskind, thank him again for profile.py -- I never would
|
| 291 |
+
# have guessed that.
|
| 292 |
+
# The first trick is to build b2j ignoring the possibility
|
| 293 |
+
# of junk. I.e., we don't call isjunk at all yet. Throwing
|
| 294 |
+
# out the junk later is much cheaper than building b2j "right"
|
| 295 |
+
# from the start.
|
| 296 |
+
b = self.b
|
| 297 |
+
self.b2j = b2j = {}
|
| 298 |
+
|
| 299 |
+
for i, elt in enumerate(b):
|
| 300 |
+
indices = b2j.setdefault(elt, [])
|
| 301 |
+
indices.append(i)
|
| 302 |
+
|
| 303 |
+
# Purge junk elements
|
| 304 |
+
self.bjunk = junk = set()
|
| 305 |
+
isjunk = self.isjunk
|
| 306 |
+
if isjunk:
|
| 307 |
+
for elt in b2j.keys():
|
| 308 |
+
if isjunk(elt):
|
| 309 |
+
junk.add(elt)
|
| 310 |
+
for elt in junk: # separate loop avoids separate list of keys
|
| 311 |
+
del b2j[elt]
|
| 312 |
+
|
| 313 |
+
# Purge popular elements that are not junk
|
| 314 |
+
self.bpopular = popular = set()
|
| 315 |
+
n = len(b)
|
| 316 |
+
if self.autojunk and n >= 200:
|
| 317 |
+
ntest = n // 100 + 1
|
| 318 |
+
for elt, idxs in b2j.items():
|
| 319 |
+
if len(idxs) > ntest:
|
| 320 |
+
popular.add(elt)
|
| 321 |
+
for elt in popular: # ditto; as fast for 1% deletion
|
| 322 |
+
del b2j[elt]
|
| 323 |
+
|
| 324 |
+
def find_longest_match(self, alo=0, ahi_=None, blo=0, bhi_=None):
|
| 325 |
+
"""Find longest matching block in a[alo:ahi] and b[blo:bhi].
|
| 326 |
+
|
| 327 |
+
By default it will find the longest match in the entirety of a and b.
|
| 328 |
+
|
| 329 |
+
If isjunk is not defined:
|
| 330 |
+
|
| 331 |
+
Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where
|
| 332 |
+
alo <= i <= i+k <= ahi
|
| 333 |
+
blo <= j <= j+k <= bhi
|
| 334 |
+
and for all (i',j',k') meeting those conditions,
|
| 335 |
+
k >= k'
|
| 336 |
+
i <= i'
|
| 337 |
+
and if i == i', j <= j'
|
| 338 |
+
|
| 339 |
+
In other words, of all maximal matching blocks, return one that
|
| 340 |
+
starts earliest in a, and of all those maximal matching blocks that
|
| 341 |
+
start earliest in a, return the one that starts earliest in b.
|
| 342 |
+
|
| 343 |
+
>>> s = SequenceMatcher(None, " abcd", "abcd abcd")
|
| 344 |
+
>>> s.find_longest_match(0, 5, 0, 9)
|
| 345 |
+
Match(a=0, b=4, size=5)
|
| 346 |
+
|
| 347 |
+
If isjunk is defined, first the longest matching block is
|
| 348 |
+
determined as above, but with the additional restriction that no
|
| 349 |
+
junk element appears in the block. Then that block is extended as
|
| 350 |
+
far as possible by matching (only) junk elements on both sides. So
|
| 351 |
+
the resulting block never matches on junk except as identical junk
|
| 352 |
+
happens to be adjacent to an "interesting" match.
|
| 353 |
+
|
| 354 |
+
Here's the same example as before, but considering blanks to be
|
| 355 |
+
junk. That prevents " abcd" from matching the " abcd" at the tail
|
| 356 |
+
end of the second sequence directly. Instead only the "abcd" can
|
| 357 |
+
match, and matches the leftmost "abcd" in the second sequence:
|
| 358 |
+
|
| 359 |
+
>>> s = SequenceMatcher(lambda x: x==" ", " abcd", "abcd abcd")
|
| 360 |
+
>>> s.find_longest_match(0, 5, 0, 9)
|
| 361 |
+
Match(a=1, b=0, size=4)
|
| 362 |
+
|
| 363 |
+
If no blocks match, return (alo, blo, 0).
|
| 364 |
+
|
| 365 |
+
>>> s = SequenceMatcher(None, "ab", "c")
|
| 366 |
+
>>> s.find_longest_match(0, 2, 0, 1)
|
| 367 |
+
Match(a=0, b=0, size=0)
|
| 368 |
+
"""
|
| 369 |
+
|
| 370 |
+
# CAUTION: stripping common prefix or suffix would be incorrect.
|
| 371 |
+
# E.g.,
|
| 372 |
+
# ab
|
| 373 |
+
# acab
|
| 374 |
+
# Longest matching block is "ab", but if common prefix is
|
| 375 |
+
# stripped, it's "a" (tied with "b"). UNIX(tm) diff does so
|
| 376 |
+
# strip, so ends up claiming that ab is changed to acab by
|
| 377 |
+
# inserting "ca" in the middle. That's minimal but unintuitive:
|
| 378 |
+
# "it's obvious" that someone inserted "ac" at the front.
|
| 379 |
+
# Windiff ends up at the same place as diff, but by pairing up
|
| 380 |
+
# the unique 'b's and then matching the first two 'a's.
|
| 381 |
+
|
| 382 |
+
bjunk: set = self.bjunk
|
| 383 |
+
a, b, b2j = self.a, self.b, self.b2j
|
| 384 |
+
ahi = len(a) if ahi_ is None else ahi_
|
| 385 |
+
bhi = len(b) if bhi_ is None else bhi_
|
| 386 |
+
besti, bestj, bestsize = alo, blo, 0
|
| 387 |
+
# find longest junk-free match
|
| 388 |
+
# during an iteration of the loop, j2len[j] = length of longest
|
| 389 |
+
# junk-free match ending with a[i-1] and b[j]
|
| 390 |
+
j2len = {}
|
| 391 |
+
nothing = []
|
| 392 |
+
for i in range(alo, ahi):
|
| 393 |
+
# look at all instances of a[i] in b; note that because
|
| 394 |
+
# b2j has no junk keys, the loop is skipped if a[i] is junk
|
| 395 |
+
newj2len = {}
|
| 396 |
+
for j in b2j.get(a[i], nothing):
|
| 397 |
+
# a[i] matches b[j]
|
| 398 |
+
if j < blo:
|
| 399 |
+
continue
|
| 400 |
+
if j >= bhi:
|
| 401 |
+
break
|
| 402 |
+
k = newj2len[j] = j2len.get(j-1, 0) + 1
|
| 403 |
+
if k > bestsize:
|
| 404 |
+
besti, bestj, bestsize = i-k+1, j-k+1, k
|
| 405 |
+
j2len = newj2len
|
| 406 |
+
|
| 407 |
+
# Extend the best by non-junk elements on each end. In particular,
|
| 408 |
+
# "popular" non-junk elements aren't in b2j, which greatly speeds
|
| 409 |
+
# the inner loop above, but also means "the best" match so far
|
| 410 |
+
# doesn't contain any junk *or* popular non-junk elements.
|
| 411 |
+
while besti > alo and bestj > blo and \
|
| 412 |
+
b[bestj-1] not in bjunk and \
|
| 413 |
+
a[besti-1] == b[bestj-1]:
|
| 414 |
+
besti, bestj, bestsize = besti-1, bestj-1, bestsize+1
|
| 415 |
+
while besti+bestsize < ahi and bestj+bestsize < bhi and \
|
| 416 |
+
b[bestj+bestsize] not in bjunk and \
|
| 417 |
+
a[besti+bestsize] == b[bestj+bestsize]:
|
| 418 |
+
bestsize += 1
|
| 419 |
+
|
| 420 |
+
# Now that we have a wholly interesting match (albeit possibly
|
| 421 |
+
# empty!), we may as well suck up the matching junk on each
|
| 422 |
+
# side of it too. Can't think of a good reason not to, and it
|
| 423 |
+
# saves post-processing the (possibly considerable) expense of
|
| 424 |
+
# figuring out what to do with it. In the case of an empty
|
| 425 |
+
# interesting match, this is clearly the right thing to do,
|
| 426 |
+
# because no other kind of match is possible in the regions.
|
| 427 |
+
while besti > alo and bestj > blo and \
|
| 428 |
+
b[bestj-1] in bjunk and \
|
| 429 |
+
a[besti-1] == b[bestj-1]:
|
| 430 |
+
besti, bestj, bestsize = besti-1, bestj-1, bestsize+1
|
| 431 |
+
while besti+bestsize < ahi and bestj+bestsize < bhi and \
|
| 432 |
+
b[bestj+bestsize] in bjunk and \
|
| 433 |
+
a[besti+bestsize] == b[bestj+bestsize]:
|
| 434 |
+
bestsize = bestsize + 1
|
| 435 |
+
|
| 436 |
+
return Match(besti, bestj, bestsize)
|
| 437 |
+
|
| 438 |
+
def get_matching_blocks(self):
|
| 439 |
+
"""Return list of triples describing matching subsequences.
|
| 440 |
+
|
| 441 |
+
Each triple is of the form (i, j, n), and means that
|
| 442 |
+
a[i:i+n] == b[j:j+n]. The triples are monotonically increasing in
|
| 443 |
+
i and in j. New in Python 2.5, it's also guaranteed that if
|
| 444 |
+
(i, j, n) and (i', j', n') are adjacent triples in the list, and
|
| 445 |
+
the second is not the last triple in the list, then i+n != i' or
|
| 446 |
+
j+n != j'. IOW, adjacent triples never describe adjacent equal
|
| 447 |
+
blocks.
|
| 448 |
+
|
| 449 |
+
The last triple is a dummy, (len(a), len(b), 0), and is the only
|
| 450 |
+
triple with n==0.
|
| 451 |
+
|
| 452 |
+
>>> s = SequenceMatcher(None, "abxcd", "abcd")
|
| 453 |
+
>>> list(s.get_matching_blocks())
|
| 454 |
+
[Match(a=0, b=0, size=2), Match(a=3, b=2, size=2), Match(a=5, b=4, size=0)]
|
| 455 |
+
"""
|
| 456 |
+
|
| 457 |
+
if self.matching_blocks is not None:
|
| 458 |
+
return self.matching_blocks
|
| 459 |
+
la, lb = len(self.a), len(self.b)
|
| 460 |
+
|
| 461 |
+
# This is most naturally expressed as a recursive algorithm, but
|
| 462 |
+
# at least one user bumped into extreme use cases that exceeded
|
| 463 |
+
# the recursion limit on their box. So, now we maintain a list
|
| 464 |
+
# ('queue`) of blocks we still need to look at, and append partial
|
| 465 |
+
# results to `matching_blocks` in a loop; the matches are sorted
|
| 466 |
+
# at the end.
|
| 467 |
+
queue = [(0, la, 0, lb)]
|
| 468 |
+
matching_blocks = []
|
| 469 |
+
while queue:
|
| 470 |
+
alo, ahi, blo, bhi = queue.pop()
|
| 471 |
+
i, j, k = x = self.find_longest_match(alo, ahi, blo, bhi)
|
| 472 |
+
# a[alo:i] vs b[blo:j] unknown
|
| 473 |
+
# a[i:i+k] same as b[j:j+k]
|
| 474 |
+
# a[i+k:ahi] vs b[j+k:bhi] unknown
|
| 475 |
+
if k: # if k is 0, there was no matching block
|
| 476 |
+
matching_blocks.append(x)
|
| 477 |
+
if alo < i and blo < j:
|
| 478 |
+
queue.append((alo, i, blo, j))
|
| 479 |
+
if i+k < ahi and j+k < bhi:
|
| 480 |
+
queue.append((i+k, ahi, j+k, bhi))
|
| 481 |
+
matching_blocks.sort()
|
| 482 |
+
|
| 483 |
+
# It's possible that we have adjacent equal blocks in the
|
| 484 |
+
# matching_blocks list now. Starting with 2.5, this code was added
|
| 485 |
+
# to collapse them.
|
| 486 |
+
i1 = j1 = k1 = 0
|
| 487 |
+
non_adjacent = []
|
| 488 |
+
for i2, j2, k2 in matching_blocks:
|
| 489 |
+
# Is this block adjacent to i1, j1, k1?
|
| 490 |
+
if i1 + k1 == i2 and j1 + k1 == j2:
|
| 491 |
+
# Yes, so collapse them -- this just increases the length of
|
| 492 |
+
# the first block by the length of the second, and the first
|
| 493 |
+
# block so lengthened remains the block to compare against.
|
| 494 |
+
k1 += k2
|
| 495 |
+
else:
|
| 496 |
+
# Not adjacent. Remember the first block (k1==0 means it's
|
| 497 |
+
# the dummy we started with), and make the second block the
|
| 498 |
+
# new block to compare against.
|
| 499 |
+
if k1:
|
| 500 |
+
non_adjacent.append((i1, j1, k1))
|
| 501 |
+
i1, j1, k1 = i2, j2, k2
|
| 502 |
+
if k1:
|
| 503 |
+
non_adjacent.append((i1, j1, k1))
|
| 504 |
+
|
| 505 |
+
non_adjacent.append( (la, lb, 0) )
|
| 506 |
+
self.matching_blocks = list(map(Match._make, non_adjacent))
|
| 507 |
+
return self.matching_blocks
|
| 508 |
+
|
| 509 |
+
def get_opcodes(self):
|
| 510 |
+
"""Return list of 5-tuples describing how to turn a into b.
|
| 511 |
+
|
| 512 |
+
Each tuple is of the form (tag, i1, i2, j1, j2). The first tuple
|
| 513 |
+
has i1 == j1 == 0, and remaining tuples have i1 == the i2 from the
|
| 514 |
+
tuple preceding it, and likewise for j1 == the previous j2.
|
| 515 |
+
|
| 516 |
+
The tags are strings, with these meanings:
|
| 517 |
+
|
| 518 |
+
'replace': a[i1:i2] should be replaced by b[j1:j2]
|
| 519 |
+
'delete': a[i1:i2] should be deleted.
|
| 520 |
+
Note that j1==j2 in this case.
|
| 521 |
+
'insert': b[j1:j2] should be inserted at a[i1:i1].
|
| 522 |
+
Note that i1==i2 in this case.
|
| 523 |
+
'equal': a[i1:i2] == b[j1:j2]
|
| 524 |
+
|
| 525 |
+
>>> a = "qabxcd"
|
| 526 |
+
>>> b = "abycdf"
|
| 527 |
+
>>> s = SequenceMatcher(None, a, b)
|
| 528 |
+
>>> for tag, i1, i2, j1, j2 in s.get_opcodes():
|
| 529 |
+
... print(("%7s a[%d:%d] (%s) b[%d:%d] (%s)" %
|
| 530 |
+
... (tag, i1, i2, a[i1:i2], j1, j2, b[j1:j2])))
|
| 531 |
+
delete a[0:1] (q) b[0:0] ()
|
| 532 |
+
equal a[1:3] (ab) b[0:2] (ab)
|
| 533 |
+
replace a[3:4] (x) b[2:3] (y)
|
| 534 |
+
equal a[4:6] (cd) b[3:5] (cd)
|
| 535 |
+
insert a[6:6] () b[5:6] (f)
|
| 536 |
+
"""
|
| 537 |
+
|
| 538 |
+
if self.opcodes is not None:
|
| 539 |
+
return self.opcodes
|
| 540 |
+
i = j = 0
|
| 541 |
+
self.opcodes = answer = []
|
| 542 |
+
for ai, bj, size in self.get_matching_blocks():
|
| 543 |
+
# invariant: we've pumped out correct diffs to change
|
| 544 |
+
# a[:i] into b[:j], and the next matching block is
|
| 545 |
+
# a[ai:ai+size] == b[bj:bj+size]. So we need to pump
|
| 546 |
+
# out a diff to change a[i:ai] into b[j:bj], pump out
|
| 547 |
+
# the matching block, and move (i,j) beyond the match
|
| 548 |
+
tag = ''
|
| 549 |
+
if i < ai and j < bj:
|
| 550 |
+
tag = 'replace'
|
| 551 |
+
elif i < ai:
|
| 552 |
+
tag = 'delete'
|
| 553 |
+
elif j < bj:
|
| 554 |
+
tag = 'insert'
|
| 555 |
+
if tag:
|
| 556 |
+
answer.append( (tag, i, ai, j, bj) )
|
| 557 |
+
i, j = ai+size, bj+size
|
| 558 |
+
# the list of matching blocks is terminated by a
|
| 559 |
+
# sentinel with size 0
|
| 560 |
+
if size:
|
| 561 |
+
answer.append( ('equal', ai, i, bj, j) )
|
| 562 |
+
return answer
|
| 563 |
+
|
| 564 |
+
def get_grouped_opcodes(self, n=3):
|
| 565 |
+
""" Isolate change clusters by eliminating ranges with no changes.
|
| 566 |
+
|
| 567 |
+
Return a generator of groups with up to n lines of context.
|
| 568 |
+
Each group is in the same format as returned by get_opcodes().
|
| 569 |
+
|
| 570 |
+
>>> from pprint import pprint
|
| 571 |
+
>>> a = list(map(str, range(1,40)))
|
| 572 |
+
>>> b = a[:]
|
| 573 |
+
>>> b[8:8] = ['i'] # Make an insertion
|
| 574 |
+
>>> b[20] += 'x' # Make a replacement
|
| 575 |
+
>>> b[23:28] = [] # Make a deletion
|
| 576 |
+
>>> b[30] += 'y' # Make another replacement
|
| 577 |
+
>>> pprint(list(SequenceMatcher(None,a,b).get_grouped_opcodes()))
|
| 578 |
+
[[('equal', 5, 8, 5, 8), ('insert', 8, 8, 8, 9), ('equal', 8, 11, 9, 12)],
|
| 579 |
+
[('equal', 16, 19, 17, 20),
|
| 580 |
+
('replace', 19, 20, 20, 21),
|
| 581 |
+
('equal', 20, 22, 21, 23),
|
| 582 |
+
('delete', 22, 27, 23, 23),
|
| 583 |
+
('equal', 27, 30, 23, 26)],
|
| 584 |
+
[('equal', 31, 34, 27, 30),
|
| 585 |
+
('replace', 34, 35, 30, 31),
|
| 586 |
+
('equal', 35, 38, 31, 34)]]
|
| 587 |
+
"""
|
| 588 |
+
|
| 589 |
+
codes = self.get_opcodes()
|
| 590 |
+
if not codes:
|
| 591 |
+
codes = [("equal", 0, 1, 0, 1)]
|
| 592 |
+
# Fixup leading and trailing groups if they show no changes.
|
| 593 |
+
if codes[0][0] == 'equal':
|
| 594 |
+
tag, i1, i2, j1, j2 = codes[0]
|
| 595 |
+
codes[0] = tag, max(i1, i2-n), i2, max(j1, j2-n), j2
|
| 596 |
+
if codes[-1][0] == 'equal':
|
| 597 |
+
tag, i1, i2, j1, j2 = codes[-1]
|
| 598 |
+
codes[-1] = tag, i1, min(i2, i1+n), j1, min(j2, j1+n)
|
| 599 |
+
|
| 600 |
+
nn = n + n
|
| 601 |
+
group = []
|
| 602 |
+
for tag, i1, i2, j1, j2 in codes:
|
| 603 |
+
# End the current group and start a new one whenever
|
| 604 |
+
# there is a large range with no changes.
|
| 605 |
+
if tag == 'equal' and i2-i1 > nn:
|
| 606 |
+
group.append((tag, i1, min(i2, i1+n), j1, min(j2, j1+n)))
|
| 607 |
+
yield group
|
| 608 |
+
group = []
|
| 609 |
+
i1, j1 = max(i1, i2-n), max(j1, j2-n)
|
| 610 |
+
group.append((tag, i1, i2, j1 ,j2))
|
| 611 |
+
if group and not (len(group)==1 and group[0][0] == 'equal'):
|
| 612 |
+
yield group
|
| 613 |
+
|
| 614 |
+
def ratio(self):
|
| 615 |
+
"""Return a measure of the sequences' similarity (float in [0,1]).
|
| 616 |
+
|
| 617 |
+
Where T is the total number of elements in both sequences, and
|
| 618 |
+
M is the number of matches, this is 2.0*M / T.
|
| 619 |
+
Note that this is 1 if the sequences are identical, and 0 if
|
| 620 |
+
they have nothing in common.
|
| 621 |
+
|
| 622 |
+
.ratio() is expensive to compute if you haven't already computed
|
| 623 |
+
.get_matching_blocks() or .get_opcodes(), in which case you may
|
| 624 |
+
want to try .quick_ratio() or .real_quick_ratio() first to get an
|
| 625 |
+
upper bound.
|
| 626 |
+
|
| 627 |
+
>>> s = SequenceMatcher(None, "abcd", "bcde")
|
| 628 |
+
>>> s.ratio()
|
| 629 |
+
0.75
|
| 630 |
+
>>> s.quick_ratio()
|
| 631 |
+
0.75
|
| 632 |
+
>>> s.real_quick_ratio()
|
| 633 |
+
1.0
|
| 634 |
+
"""
|
| 635 |
+
|
| 636 |
+
matches: cython.Py_ssize_t
|
| 637 |
+
matches = sum(triple[-1] for triple in self.get_matching_blocks())
|
| 638 |
+
return _calculate_ratio(matches, len(self.a) + len(self.b))
|
| 639 |
+
|
| 640 |
+
def quick_ratio(self):
|
| 641 |
+
"""Return an upper bound on ratio() relatively quickly.
|
| 642 |
+
|
| 643 |
+
This isn't defined beyond that it is an upper bound on .ratio(), and
|
| 644 |
+
is faster to compute.
|
| 645 |
+
"""
|
| 646 |
+
|
| 647 |
+
# viewing a and b as multisets, set matches to the cardinality
|
| 648 |
+
# of their intersection; this counts the number of matches
|
| 649 |
+
# without regard to order, so is clearly an upper bound
|
| 650 |
+
if self.fullbcount is None:
|
| 651 |
+
self.fullbcount = fullbcount = {}
|
| 652 |
+
for elt in self.b:
|
| 653 |
+
fullbcount[elt] = fullbcount.get(elt, 0) + 1
|
| 654 |
+
fullbcount = self.fullbcount
|
| 655 |
+
# avail[x] is the number of times x appears in 'b' less the
|
| 656 |
+
# number of times we've seen it in 'a' so far ... kinda
|
| 657 |
+
avail = {}
|
| 658 |
+
matches: cython.Py_ssize_t
|
| 659 |
+
matches = 0
|
| 660 |
+
for elt in self.a:
|
| 661 |
+
if elt in avail:
|
| 662 |
+
numb = avail[elt]
|
| 663 |
+
else:
|
| 664 |
+
numb = fullbcount.get(elt, 0)
|
| 665 |
+
avail[elt] = numb - 1
|
| 666 |
+
if numb > 0:
|
| 667 |
+
matches = matches + 1
|
| 668 |
+
return _calculate_ratio(matches, len(self.a) + len(self.b))
|
| 669 |
+
|
| 670 |
+
def real_quick_ratio(self):
|
| 671 |
+
"""Return an upper bound on ratio() very quickly.
|
| 672 |
+
|
| 673 |
+
This isn't defined beyond that it is an upper bound on .ratio(), and
|
| 674 |
+
is faster to compute than either .ratio() or .quick_ratio().
|
| 675 |
+
"""
|
| 676 |
+
|
| 677 |
+
la, lb = len(self.a), len(self.b)
|
| 678 |
+
# can't have more matches than the number of elements in the
|
| 679 |
+
# shorter sequence
|
| 680 |
+
return _calculate_ratio(min(la, lb), la + lb)
|
| 681 |
+
|
| 682 |
+
if GenericAlias is not None:
|
| 683 |
+
__class_getitem__ = classmethod(GenericAlias)
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
def get_close_matches(word, possibilities, n=3, cutoff=0.6):
|
| 687 |
+
"""Use SequenceMatcher to return list of the best "good enough" matches.
|
| 688 |
+
|
| 689 |
+
word is a sequence for which close matches are desired (typically a
|
| 690 |
+
string).
|
| 691 |
+
|
| 692 |
+
possibilities is a list of sequences against which to match word
|
| 693 |
+
(typically a list of strings).
|
| 694 |
+
|
| 695 |
+
Optional arg n (default 3) is the maximum number of close matches to
|
| 696 |
+
return. n must be > 0.
|
| 697 |
+
|
| 698 |
+
Optional arg cutoff (default 0.6) is a float in [0, 1]. Possibilities
|
| 699 |
+
that don't score at least that similar to word are ignored.
|
| 700 |
+
|
| 701 |
+
The best (no more than n) matches among the possibilities are returned
|
| 702 |
+
in a list, sorted by similarity score, most similar first.
|
| 703 |
+
|
| 704 |
+
>>> get_close_matches("appel", ["ape", "apple", "peach", "puppy"])
|
| 705 |
+
['apple', 'ape']
|
| 706 |
+
>>> import keyword as _keyword
|
| 707 |
+
>>> get_close_matches("wheel", _keyword.kwlist)
|
| 708 |
+
['while']
|
| 709 |
+
>>> get_close_matches("Apple", _keyword.kwlist)
|
| 710 |
+
[]
|
| 711 |
+
>>> get_close_matches("accept", _keyword.kwlist)
|
| 712 |
+
['except']
|
| 713 |
+
"""
|
| 714 |
+
|
| 715 |
+
if not n > 0:
|
| 716 |
+
raise ValueError("n must be > 0: %r" % (n,))
|
| 717 |
+
if not 0.0 <= cutoff <= 1.0:
|
| 718 |
+
raise ValueError("cutoff must be in [0.0, 1.0]: %r" % (cutoff,))
|
| 719 |
+
result = []
|
| 720 |
+
s = SequenceMatcher()
|
| 721 |
+
s.set_seq2(word)
|
| 722 |
+
for x in possibilities:
|
| 723 |
+
s.set_seq1(x)
|
| 724 |
+
if s.real_quick_ratio() >= cutoff and \
|
| 725 |
+
s.quick_ratio() >= cutoff and \
|
| 726 |
+
s.ratio() >= cutoff:
|
| 727 |
+
result.append((s.ratio(), x))
|
| 728 |
+
|
| 729 |
+
# Move the best scorers to head of list
|
| 730 |
+
result = _nlargest(n, result)
|
| 731 |
+
# Strip scores for the best n matches
|
| 732 |
+
return [x for score, x in result]
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
def _keep_original_ws(s, tag_s):
|
| 736 |
+
"""Replace whitespace with the original whitespace characters in `s`"""
|
| 737 |
+
return ''.join(
|
| 738 |
+
c if tag_c == " " and c.isspace() else tag_c
|
| 739 |
+
for c, tag_c in zip(s, tag_s)
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
class Differ:
|
| 745 |
+
r"""
|
| 746 |
+
Differ is a class for comparing sequences of lines of text, and
|
| 747 |
+
producing human-readable differences or deltas. Differ uses
|
| 748 |
+
SequenceMatcher both to compare sequences of lines, and to compare
|
| 749 |
+
sequences of characters within similar (near-matching) lines.
|
| 750 |
+
|
| 751 |
+
Each line of a Differ delta begins with a two-letter code:
|
| 752 |
+
|
| 753 |
+
'- ' line unique to sequence 1
|
| 754 |
+
'+ ' line unique to sequence 2
|
| 755 |
+
' ' line common to both sequences
|
| 756 |
+
'? ' line not present in either input sequence
|
| 757 |
+
|
| 758 |
+
Lines beginning with '? ' attempt to guide the eye to intraline
|
| 759 |
+
differences, and were not present in either input sequence. These lines
|
| 760 |
+
can be confusing if the sequences contain tab characters.
|
| 761 |
+
|
| 762 |
+
Note that Differ makes no claim to produce a *minimal* diff. To the
|
| 763 |
+
contrary, minimal diffs are often counter-intuitive, because they synch
|
| 764 |
+
up anywhere possible, sometimes accidental matches 100 pages apart.
|
| 765 |
+
Restricting synch points to contiguous matches preserves some notion of
|
| 766 |
+
locality, at the occasional cost of producing a longer diff.
|
| 767 |
+
|
| 768 |
+
Example: Comparing two texts.
|
| 769 |
+
|
| 770 |
+
First we set up the texts, sequences of individual single-line strings
|
| 771 |
+
ending with newlines (such sequences can also be obtained from the
|
| 772 |
+
`readlines()` method of file-like objects):
|
| 773 |
+
|
| 774 |
+
>>> text1 = ''' 1. Beautiful is better than ugly.
|
| 775 |
+
... 2. Explicit is better than implicit.
|
| 776 |
+
... 3. Simple is better than complex.
|
| 777 |
+
... 4. Complex is better than complicated.
|
| 778 |
+
... '''.splitlines(keepends=True)
|
| 779 |
+
>>> len(text1)
|
| 780 |
+
4
|
| 781 |
+
>>> text1[0][-1]
|
| 782 |
+
'\n'
|
| 783 |
+
>>> text2 = ''' 1. Beautiful is better than ugly.
|
| 784 |
+
... 3. Simple is better than complex.
|
| 785 |
+
... 4. Complicated is better than complex.
|
| 786 |
+
... 5. Flat is better than nested.
|
| 787 |
+
... '''.splitlines(keepends=True)
|
| 788 |
+
|
| 789 |
+
Next we instantiate a Differ object:
|
| 790 |
+
|
| 791 |
+
>>> d = Differ()
|
| 792 |
+
|
| 793 |
+
Note that when instantiating a Differ object we may pass functions to
|
| 794 |
+
filter out line and character 'junk'. See Differ.__init__ for details.
|
| 795 |
+
|
| 796 |
+
Finally, we compare the two:
|
| 797 |
+
|
| 798 |
+
>>> result = list(d.compare(text1, text2))
|
| 799 |
+
|
| 800 |
+
'result' is a list of strings, so let's pretty-print it:
|
| 801 |
+
|
| 802 |
+
>>> from pprint import pprint as _pprint
|
| 803 |
+
>>> _pprint(result)
|
| 804 |
+
[' 1. Beautiful is better than ugly.\n',
|
| 805 |
+
'- 2. Explicit is better than implicit.\n',
|
| 806 |
+
'- 3. Simple is better than complex.\n',
|
| 807 |
+
'+ 3. Simple is better than complex.\n',
|
| 808 |
+
'? ++\n',
|
| 809 |
+
'- 4. Complex is better than complicated.\n',
|
| 810 |
+
'? ^ ---- ^\n',
|
| 811 |
+
'+ 4. Complicated is better than complex.\n',
|
| 812 |
+
'? ++++ ^ ^\n',
|
| 813 |
+
'+ 5. Flat is better than nested.\n']
|
| 814 |
+
|
| 815 |
+
As a single multi-line string it looks like this:
|
| 816 |
+
|
| 817 |
+
>>> print(''.join(result), end="")
|
| 818 |
+
1. Beautiful is better than ugly.
|
| 819 |
+
- 2. Explicit is better than implicit.
|
| 820 |
+
- 3. Simple is better than complex.
|
| 821 |
+
+ 3. Simple is better than complex.
|
| 822 |
+
? ++
|
| 823 |
+
- 4. Complex is better than complicated.
|
| 824 |
+
? ^ ---- ^
|
| 825 |
+
+ 4. Complicated is better than complex.
|
| 826 |
+
? ++++ ^ ^
|
| 827 |
+
+ 5. Flat is better than nested.
|
| 828 |
+
"""
|
| 829 |
+
|
| 830 |
+
def __init__(self, linejunk=None, charjunk=None):
|
| 831 |
+
"""
|
| 832 |
+
Construct a text differencer, with optional filters.
|
| 833 |
+
|
| 834 |
+
The two optional keyword parameters are for filter functions:
|
| 835 |
+
|
| 836 |
+
- `linejunk`: A function that should accept a single string argument,
|
| 837 |
+
and return true iff the string is junk. The module-level function
|
| 838 |
+
`IS_LINE_JUNK` may be used to filter out lines without visible
|
| 839 |
+
characters, except for at most one splat ('#'). It is recommended
|
| 840 |
+
to leave linejunk None; the underlying SequenceMatcher class has
|
| 841 |
+
an adaptive notion of "noise" lines that's better than any static
|
| 842 |
+
definition the author has ever been able to craft.
|
| 843 |
+
|
| 844 |
+
- `charjunk`: A function that should accept a string of length 1. The
|
| 845 |
+
module-level function `IS_CHARACTER_JUNK` may be used to filter out
|
| 846 |
+
whitespace characters (a blank or tab; **note**: bad idea to include
|
| 847 |
+
newline in this!). Use of IS_CHARACTER_JUNK is recommended.
|
| 848 |
+
"""
|
| 849 |
+
|
| 850 |
+
self.linejunk = linejunk
|
| 851 |
+
self.charjunk = charjunk
|
| 852 |
+
|
| 853 |
+
def compare(self, a, b):
|
| 854 |
+
r"""
|
| 855 |
+
Compare two sequences of lines; generate the resulting delta.
|
| 856 |
+
|
| 857 |
+
Each sequence must contain individual single-line strings ending with
|
| 858 |
+
newlines. Such sequences can be obtained from the `readlines()` method
|
| 859 |
+
of file-like objects. The delta generated also consists of newline-
|
| 860 |
+
terminated strings, ready to be printed as-is via the writelines()
|
| 861 |
+
method of a file-like object.
|
| 862 |
+
|
| 863 |
+
Example:
|
| 864 |
+
|
| 865 |
+
>>> print(''.join(Differ().compare('one\ntwo\nthree\n'.splitlines(True),
|
| 866 |
+
... 'ore\ntree\nemu\n'.splitlines(True))),
|
| 867 |
+
... end="")
|
| 868 |
+
- one
|
| 869 |
+
? ^
|
| 870 |
+
+ ore
|
| 871 |
+
? ^
|
| 872 |
+
- two
|
| 873 |
+
- three
|
| 874 |
+
? -
|
| 875 |
+
+ tree
|
| 876 |
+
+ emu
|
| 877 |
+
"""
|
| 878 |
+
|
| 879 |
+
cruncher = SequenceMatcher(self.linejunk, a, b)
|
| 880 |
+
for tag, alo, ahi, blo, bhi in cruncher.get_opcodes():
|
| 881 |
+
if tag == 'replace':
|
| 882 |
+
g = self._fancy_replace(a, alo, ahi, b, blo, bhi)
|
| 883 |
+
elif tag == 'delete':
|
| 884 |
+
g = self._dump('-', a, alo, ahi)
|
| 885 |
+
elif tag == 'insert':
|
| 886 |
+
g = self._dump('+', b, blo, bhi)
|
| 887 |
+
elif tag == 'equal':
|
| 888 |
+
g = self._dump(' ', a, alo, ahi)
|
| 889 |
+
else:
|
| 890 |
+
raise ValueError('unknown tag %r' % (tag,))
|
| 891 |
+
|
| 892 |
+
yield from g
|
| 893 |
+
|
| 894 |
+
def _dump(self, tag, x, lo, hi):
|
| 895 |
+
"""Generate comparison results for a same-tagged range."""
|
| 896 |
+
for i in range(lo, hi):
|
| 897 |
+
yield '%s %s' % (tag, x[i])
|
| 898 |
+
|
| 899 |
+
def _plain_replace(self, a, alo, ahi, b, blo, bhi):
|
| 900 |
+
assert alo < ahi and blo < bhi
|
| 901 |
+
# dump the shorter block first -- reduces the burden on short-term
|
| 902 |
+
# memory if the blocks are of very different sizes
|
| 903 |
+
if bhi - blo < ahi - alo:
|
| 904 |
+
first = self._dump('+', b, blo, bhi)
|
| 905 |
+
second = self._dump('-', a, alo, ahi)
|
| 906 |
+
else:
|
| 907 |
+
first = self._dump('-', a, alo, ahi)
|
| 908 |
+
second = self._dump('+', b, blo, bhi)
|
| 909 |
+
|
| 910 |
+
for g in first, second:
|
| 911 |
+
yield from g
|
| 912 |
+
|
| 913 |
+
def _fancy_replace(self, a, alo, ahi, b, blo, bhi):
|
| 914 |
+
r"""
|
| 915 |
+
When replacing one block of lines with another, search the blocks
|
| 916 |
+
for *similar* lines; the best-matching pair (if any) is used as a
|
| 917 |
+
synch point, and intraline difference marking is done on the
|
| 918 |
+
similar pair. Lots of work, but often worth it.
|
| 919 |
+
|
| 920 |
+
Example:
|
| 921 |
+
|
| 922 |
+
>>> d = Differ()
|
| 923 |
+
>>> results = d._fancy_replace(['abcDefghiJkl\n'], 0, 1,
|
| 924 |
+
... ['abcdefGhijkl\n'], 0, 1)
|
| 925 |
+
>>> print(''.join(results), end="")
|
| 926 |
+
- abcDefghiJkl
|
| 927 |
+
? ^ ^ ^
|
| 928 |
+
+ abcdefGhijkl
|
| 929 |
+
? ^ ^ ^
|
| 930 |
+
"""
|
| 931 |
+
# Don't synch up unless the lines have a similarity score above
|
| 932 |
+
# cutoff. Previously only the smallest pair was handled here,
|
| 933 |
+
# and if there are many pairs with the best ratio, recursion
|
| 934 |
+
# could grow very deep, and runtime cubic. See:
|
| 935 |
+
# https://github.com/python/cpython/issues/119105
|
| 936 |
+
#
|
| 937 |
+
# Later, more pathological cases prompted removing recursion
|
| 938 |
+
# entirely.
|
| 939 |
+
cutoff = 0.74999
|
| 940 |
+
cruncher = SequenceMatcher(self.charjunk)
|
| 941 |
+
crqr = cruncher.real_quick_ratio
|
| 942 |
+
cqr = cruncher.quick_ratio
|
| 943 |
+
cr = cruncher.ratio
|
| 944 |
+
|
| 945 |
+
WINDOW = 10
|
| 946 |
+
best_i = best_j = None
|
| 947 |
+
dump_i, dump_j = alo, blo # smallest indices not yet resolved
|
| 948 |
+
for j in range(blo, bhi):
|
| 949 |
+
cruncher.set_seq2(b[j])
|
| 950 |
+
# Search the corresponding i's within WINDOW for rhe highest
|
| 951 |
+
# ratio greater than `cutoff`.
|
| 952 |
+
aequiv = alo + (j - blo)
|
| 953 |
+
arange = range(max(aequiv - WINDOW, dump_i),
|
| 954 |
+
min(aequiv + WINDOW + 1, ahi))
|
| 955 |
+
if not arange: # likely exit if `a` is shorter than `b`
|
| 956 |
+
break
|
| 957 |
+
best_ratio = cutoff
|
| 958 |
+
for i in arange:
|
| 959 |
+
cruncher.set_seq1(a[i])
|
| 960 |
+
# Ordering by cheapest to most expensive ratio is very
|
| 961 |
+
# valuable, most often getting out early.
|
| 962 |
+
if (crqr() > best_ratio
|
| 963 |
+
and cqr() > best_ratio
|
| 964 |
+
and cr() > best_ratio):
|
| 965 |
+
best_i, best_j, best_ratio = i, j, cr()
|
| 966 |
+
|
| 967 |
+
if best_i is None:
|
| 968 |
+
# found nothing to synch on yet - move to next j
|
| 969 |
+
continue
|
| 970 |
+
|
| 971 |
+
# pump out straight replace from before this synch pair
|
| 972 |
+
yield from self._fancy_helper(a, dump_i, best_i,
|
| 973 |
+
b, dump_j, best_j)
|
| 974 |
+
# do intraline marking on the synch pair
|
| 975 |
+
aelt, belt = a[best_i], b[best_j]
|
| 976 |
+
if aelt != belt:
|
| 977 |
+
# pump out a '-', '?', '+', '?' quad for the synched lines
|
| 978 |
+
atags = btags = ""
|
| 979 |
+
cruncher.set_seqs(aelt, belt)
|
| 980 |
+
for tag, ai1, ai2, bj1, bj2 in cruncher.get_opcodes():
|
| 981 |
+
la, lb = ai2 - ai1, bj2 - bj1
|
| 982 |
+
if tag == 'replace':
|
| 983 |
+
atags += '^' * la
|
| 984 |
+
btags += '^' * lb
|
| 985 |
+
elif tag == 'delete':
|
| 986 |
+
atags += '-' * la
|
| 987 |
+
elif tag == 'insert':
|
| 988 |
+
btags += '+' * lb
|
| 989 |
+
elif tag == 'equal':
|
| 990 |
+
atags += ' ' * la
|
| 991 |
+
btags += ' ' * lb
|
| 992 |
+
else:
|
| 993 |
+
raise ValueError('unknown tag %r' % (tag,))
|
| 994 |
+
yield from self._qformat(aelt, belt, atags, btags)
|
| 995 |
+
else:
|
| 996 |
+
# the synch pair is identical
|
| 997 |
+
yield ' ' + aelt
|
| 998 |
+
dump_i, dump_j = best_i + 1, best_j + 1
|
| 999 |
+
best_i = best_j = None
|
| 1000 |
+
|
| 1001 |
+
# pump out straight replace from after the last synch pair
|
| 1002 |
+
yield from self._fancy_helper(a, dump_i, ahi,
|
| 1003 |
+
b, dump_j, bhi)
|
| 1004 |
+
|
| 1005 |
+
def _fancy_helper(self, a, alo, ahi, b, blo, bhi):
|
| 1006 |
+
g = []
|
| 1007 |
+
if alo < ahi:
|
| 1008 |
+
if blo < bhi:
|
| 1009 |
+
g = self._plain_replace(a, alo, ahi, b, blo, bhi)
|
| 1010 |
+
else:
|
| 1011 |
+
g = self._dump('-', a, alo, ahi)
|
| 1012 |
+
elif blo < bhi:
|
| 1013 |
+
g = self._dump('+', b, blo, bhi)
|
| 1014 |
+
|
| 1015 |
+
yield from g
|
| 1016 |
+
|
| 1017 |
+
def _qformat(self, aline, bline, atags, btags):
|
| 1018 |
+
r"""
|
| 1019 |
+
Format "?" output and deal with tabs.
|
| 1020 |
+
|
| 1021 |
+
Example:
|
| 1022 |
+
|
| 1023 |
+
>>> d = Differ()
|
| 1024 |
+
>>> results = d._qformat('\tabcDefghiJkl\n', '\tabcdefGhijkl\n',
|
| 1025 |
+
... ' ^ ^ ^ ', ' ^ ^ ^ ')
|
| 1026 |
+
>>> for line in results: print(repr(line))
|
| 1027 |
+
...
|
| 1028 |
+
'- \tabcDefghiJkl\n'
|
| 1029 |
+
'? \t ^ ^ ^\n'
|
| 1030 |
+
'+ \tabcdefGhijkl\n'
|
| 1031 |
+
'? \t ^ ^ ^\n'
|
| 1032 |
+
"""
|
| 1033 |
+
atags = _keep_original_ws(aline, atags).rstrip()
|
| 1034 |
+
btags = _keep_original_ws(bline, btags).rstrip()
|
| 1035 |
+
|
| 1036 |
+
yield "- " + aline
|
| 1037 |
+
if atags:
|
| 1038 |
+
yield f"? {atags}\n"
|
| 1039 |
+
|
| 1040 |
+
yield "+ " + bline
|
| 1041 |
+
if btags:
|
| 1042 |
+
yield f"? {btags}\n"
|
| 1043 |
+
|
| 1044 |
+
# With respect to junk, an earlier version of ndiff simply refused to
|
| 1045 |
+
# *start* a match with a junk element. The result was cases like this:
|
| 1046 |
+
# before: private Thread currentThread;
|
| 1047 |
+
# after: private volatile Thread currentThread;
|
| 1048 |
+
# If you consider whitespace to be junk, the longest contiguous match
|
| 1049 |
+
# not starting with junk is "e Thread currentThread". So ndiff reported
|
| 1050 |
+
# that "e volatil" was inserted between the 't' and the 'e' in "private".
|
| 1051 |
+
# While an accurate view, to people that's absurd. The current version
|
| 1052 |
+
# looks for matching blocks that are entirely junk-free, then extends the
|
| 1053 |
+
# longest one of those as far as possible but only with matching junk.
|
| 1054 |
+
# So now "currentThread" is matched, then extended to suck up the
|
| 1055 |
+
# preceding blank; then "private" is matched, and extended to suck up the
|
| 1056 |
+
# following blank; then "Thread" is matched; and finally ndiff reports
|
| 1057 |
+
# that "volatile " was inserted before "Thread". The only quibble
|
| 1058 |
+
# remaining is that perhaps it was really the case that " volatile"
|
| 1059 |
+
# was inserted after "private". I can live with that <wink>.
|
| 1060 |
+
|
| 1061 |
+
def IS_LINE_JUNK(line, pat=None):
|
| 1062 |
+
r"""
|
| 1063 |
+
Return True for ignorable line: if `line` is blank or contains a single '#'.
|
| 1064 |
+
|
| 1065 |
+
Examples:
|
| 1066 |
+
|
| 1067 |
+
>>> IS_LINE_JUNK('\n')
|
| 1068 |
+
True
|
| 1069 |
+
>>> IS_LINE_JUNK(' # \n')
|
| 1070 |
+
True
|
| 1071 |
+
>>> IS_LINE_JUNK('hello\n')
|
| 1072 |
+
False
|
| 1073 |
+
"""
|
| 1074 |
+
|
| 1075 |
+
if pat is None:
|
| 1076 |
+
# Default: match '#' or the empty string
|
| 1077 |
+
return line.strip() in '#'
|
| 1078 |
+
# Previous versions used the undocumented parameter 'pat' as a
|
| 1079 |
+
# match function. Retain this behaviour for compatibility.
|
| 1080 |
+
return pat(line) is not None
|
| 1081 |
+
|
| 1082 |
+
def IS_CHARACTER_JUNK(ch, ws=" \t"):
|
| 1083 |
+
r"""
|
| 1084 |
+
Return True for ignorable character: iff `ch` is a space or tab.
|
| 1085 |
+
|
| 1086 |
+
Examples:
|
| 1087 |
+
|
| 1088 |
+
>>> IS_CHARACTER_JUNK(' ')
|
| 1089 |
+
True
|
| 1090 |
+
>>> IS_CHARACTER_JUNK('\t')
|
| 1091 |
+
True
|
| 1092 |
+
>>> IS_CHARACTER_JUNK('\n')
|
| 1093 |
+
False
|
| 1094 |
+
>>> IS_CHARACTER_JUNK('x')
|
| 1095 |
+
False
|
| 1096 |
+
"""
|
| 1097 |
+
|
| 1098 |
+
return ch in ws
|
| 1099 |
+
|
| 1100 |
+
|
| 1101 |
+
########################################################################
|
| 1102 |
+
### Unified Diff
|
| 1103 |
+
########################################################################
|
| 1104 |
+
|
| 1105 |
+
def _format_range_unified(start, stop):
|
| 1106 |
+
'Convert range to the "ed" format'
|
| 1107 |
+
# Per the diff spec at http://www.unix.org/single_unix_specification/
|
| 1108 |
+
beginning = start + 1 # lines start numbering with one
|
| 1109 |
+
length = stop - start
|
| 1110 |
+
if length == 1:
|
| 1111 |
+
return '{}'.format(beginning)
|
| 1112 |
+
if not length:
|
| 1113 |
+
beginning -= 1 # empty ranges begin at line just before the range
|
| 1114 |
+
return '{},{}'.format(beginning, length)
|
| 1115 |
+
|
| 1116 |
+
def unified_diff(a, b, fromfile='', tofile='', fromfiledate='',
|
| 1117 |
+
tofiledate='', n=3, lineterm='\n'):
|
| 1118 |
+
r"""
|
| 1119 |
+
Compare two sequences of lines; generate the delta as a unified diff.
|
| 1120 |
+
|
| 1121 |
+
Unified diffs are a compact way of showing line changes and a few
|
| 1122 |
+
lines of context. The number of context lines is set by 'n' which
|
| 1123 |
+
defaults to three.
|
| 1124 |
+
|
| 1125 |
+
By default, the diff control lines (those with ---, +++, or @@) are
|
| 1126 |
+
created with a trailing newline. This is helpful so that inputs
|
| 1127 |
+
created from file.readlines() result in diffs that are suitable for
|
| 1128 |
+
file.writelines() since both the inputs and outputs have trailing
|
| 1129 |
+
newlines.
|
| 1130 |
+
|
| 1131 |
+
For inputs that do not have trailing newlines, set the lineterm
|
| 1132 |
+
argument to "" so that the output will be uniformly newline free.
|
| 1133 |
+
|
| 1134 |
+
The unidiff format normally has a header for filenames and modification
|
| 1135 |
+
times. Any or all of these may be specified using strings for
|
| 1136 |
+
'fromfile', 'tofile', 'fromfiledate', and 'tofiledate'.
|
| 1137 |
+
The modification times are normally expressed in the ISO 8601 format.
|
| 1138 |
+
|
| 1139 |
+
Example:
|
| 1140 |
+
|
| 1141 |
+
>>> for line in unified_diff('one two three four'.split(),
|
| 1142 |
+
... 'zero one tree four'.split(), 'Original', 'Current',
|
| 1143 |
+
... '2005-01-26 23:30:50', '2010-04-02 10:20:52',
|
| 1144 |
+
... lineterm=''):
|
| 1145 |
+
... print(line) # doctest: +NORMALIZE_WHITESPACE
|
| 1146 |
+
--- Original 2005-01-26 23:30:50
|
| 1147 |
+
+++ Current 2010-04-02 10:20:52
|
| 1148 |
+
@@ -1,4 +1,4 @@
|
| 1149 |
+
+zero
|
| 1150 |
+
one
|
| 1151 |
+
-two
|
| 1152 |
+
-three
|
| 1153 |
+
+tree
|
| 1154 |
+
four
|
| 1155 |
+
"""
|
| 1156 |
+
|
| 1157 |
+
_check_types(a, b, fromfile, tofile, fromfiledate, tofiledate, lineterm)
|
| 1158 |
+
started = False
|
| 1159 |
+
for group in SequenceMatcher(None,a,b).get_grouped_opcodes(n):
|
| 1160 |
+
if not started:
|
| 1161 |
+
started = True
|
| 1162 |
+
fromdate = '\t{}'.format(fromfiledate) if fromfiledate else ''
|
| 1163 |
+
todate = '\t{}'.format(tofiledate) if tofiledate else ''
|
| 1164 |
+
yield '--- {}{}{}'.format(fromfile, fromdate, lineterm)
|
| 1165 |
+
yield '+++ {}{}{}'.format(tofile, todate, lineterm)
|
| 1166 |
+
|
| 1167 |
+
first, last = group[0], group[-1]
|
| 1168 |
+
file1_range = _format_range_unified(first[1], last[2])
|
| 1169 |
+
file2_range = _format_range_unified(first[3], last[4])
|
| 1170 |
+
yield '@@ -{} +{} @@{}'.format(file1_range, file2_range, lineterm)
|
| 1171 |
+
|
| 1172 |
+
for tag, i1, i2, j1, j2 in group:
|
| 1173 |
+
if tag == 'equal':
|
| 1174 |
+
for line in a[i1:i2]:
|
| 1175 |
+
yield ' ' + line
|
| 1176 |
+
continue
|
| 1177 |
+
if tag in {'replace', 'delete'}:
|
| 1178 |
+
for line in a[i1:i2]:
|
| 1179 |
+
yield '-' + line
|
| 1180 |
+
if tag in {'replace', 'insert'}:
|
| 1181 |
+
for line in b[j1:j2]:
|
| 1182 |
+
yield '+' + line
|
| 1183 |
+
|
| 1184 |
+
|
| 1185 |
+
########################################################################
|
| 1186 |
+
### Context Diff
|
| 1187 |
+
########################################################################
|
| 1188 |
+
|
| 1189 |
+
def _format_range_context(start, stop):
|
| 1190 |
+
'Convert range to the "ed" format'
|
| 1191 |
+
# Per the diff spec at http://www.unix.org/single_unix_specification/
|
| 1192 |
+
beginning = start + 1 # lines start numbering with one
|
| 1193 |
+
length = stop - start
|
| 1194 |
+
if not length:
|
| 1195 |
+
beginning -= 1 # empty ranges begin at line just before the range
|
| 1196 |
+
if length <= 1:
|
| 1197 |
+
return '{}'.format(beginning)
|
| 1198 |
+
return '{},{}'.format(beginning, beginning + length - 1)
|
| 1199 |
+
|
| 1200 |
+
# See http://www.unix.org/single_unix_specification/
|
| 1201 |
+
def context_diff(a, b, fromfile='', tofile='',
|
| 1202 |
+
fromfiledate='', tofiledate='', n=3, lineterm='\n'):
|
| 1203 |
+
r"""
|
| 1204 |
+
Compare two sequences of lines; generate the delta as a context diff.
|
| 1205 |
+
|
| 1206 |
+
Context diffs are a compact way of showing line changes and a few
|
| 1207 |
+
lines of context. The number of context lines is set by 'n' which
|
| 1208 |
+
defaults to three.
|
| 1209 |
+
|
| 1210 |
+
By default, the diff control lines (those with *** or ---) are
|
| 1211 |
+
created with a trailing newline. This is helpful so that inputs
|
| 1212 |
+
created from file.readlines() result in diffs that are suitable for
|
| 1213 |
+
file.writelines() since both the inputs and outputs have trailing
|
| 1214 |
+
newlines.
|
| 1215 |
+
|
| 1216 |
+
For inputs that do not have trailing newlines, set the lineterm
|
| 1217 |
+
argument to "" so that the output will be uniformly newline free.
|
| 1218 |
+
|
| 1219 |
+
The context diff format normally has a header for filenames and
|
| 1220 |
+
modification times. Any or all of these may be specified using
|
| 1221 |
+
strings for 'fromfile', 'tofile', 'fromfiledate', and 'tofiledate'.
|
| 1222 |
+
The modification times are normally expressed in the ISO 8601 format.
|
| 1223 |
+
If not specified, the strings default to blanks.
|
| 1224 |
+
|
| 1225 |
+
Example:
|
| 1226 |
+
|
| 1227 |
+
>>> print(''.join(context_diff('one\ntwo\nthree\nfour\n'.splitlines(True),
|
| 1228 |
+
... 'zero\none\ntree\nfour\n'.splitlines(True), 'Original', 'Current')),
|
| 1229 |
+
... end="")
|
| 1230 |
+
*** Original
|
| 1231 |
+
--- Current
|
| 1232 |
+
***************
|
| 1233 |
+
*** 1,4 ****
|
| 1234 |
+
one
|
| 1235 |
+
! two
|
| 1236 |
+
! three
|
| 1237 |
+
four
|
| 1238 |
+
--- 1,4 ----
|
| 1239 |
+
+ zero
|
| 1240 |
+
one
|
| 1241 |
+
! tree
|
| 1242 |
+
four
|
| 1243 |
+
"""
|
| 1244 |
+
|
| 1245 |
+
_check_types(a, b, fromfile, tofile, fromfiledate, tofiledate, lineterm)
|
| 1246 |
+
prefix = dict(insert='+ ', delete='- ', replace='! ', equal=' ')
|
| 1247 |
+
started = False
|
| 1248 |
+
for group in SequenceMatcher(None,a,b).get_grouped_opcodes(n):
|
| 1249 |
+
if not started:
|
| 1250 |
+
started = True
|
| 1251 |
+
fromdate = '\t{}'.format(fromfiledate) if fromfiledate else ''
|
| 1252 |
+
todate = '\t{}'.format(tofiledate) if tofiledate else ''
|
| 1253 |
+
yield '*** {}{}{}'.format(fromfile, fromdate, lineterm)
|
| 1254 |
+
yield '--- {}{}{}'.format(tofile, todate, lineterm)
|
| 1255 |
+
|
| 1256 |
+
first, last = group[0], group[-1]
|
| 1257 |
+
yield '***************' + lineterm
|
| 1258 |
+
|
| 1259 |
+
file1_range = _format_range_context(first[1], last[2])
|
| 1260 |
+
yield '*** {} ****{}'.format(file1_range, lineterm)
|
| 1261 |
+
|
| 1262 |
+
if any(tag in {'replace', 'delete'} for tag, _, _, _, _ in group):
|
| 1263 |
+
for tag, i1, i2, _, _ in group:
|
| 1264 |
+
if tag != 'insert':
|
| 1265 |
+
for line in a[i1:i2]:
|
| 1266 |
+
yield prefix[tag] + line
|
| 1267 |
+
|
| 1268 |
+
file2_range = _format_range_context(first[3], last[4])
|
| 1269 |
+
yield '--- {} ----{}'.format(file2_range, lineterm)
|
| 1270 |
+
|
| 1271 |
+
if any(tag in {'replace', 'insert'} for tag, _, _, _, _ in group):
|
| 1272 |
+
for tag, _, _, j1, j2 in group:
|
| 1273 |
+
if tag != 'delete':
|
| 1274 |
+
for line in b[j1:j2]:
|
| 1275 |
+
yield prefix[tag] + line
|
| 1276 |
+
|
| 1277 |
+
def _check_types(a, b, *args):
|
| 1278 |
+
# Checking types is weird, but the alternative is garbled output when
|
| 1279 |
+
# someone passes mixed bytes and str to {unified,context}_diff(). E.g.
|
| 1280 |
+
# without this check, passing filenames as bytes results in output like
|
| 1281 |
+
# --- b'oldfile.txt'
|
| 1282 |
+
# +++ b'newfile.txt'
|
| 1283 |
+
# because of how str.format() incorporates bytes objects.
|
| 1284 |
+
if a and not isinstance(a[0], str):
|
| 1285 |
+
raise TypeError('lines to compare must be str, not %s (%r)' %
|
| 1286 |
+
(type(a[0]).__name__, a[0]))
|
| 1287 |
+
if b and not isinstance(b[0], str):
|
| 1288 |
+
raise TypeError('lines to compare must be str, not %s (%r)' %
|
| 1289 |
+
(type(b[0]).__name__, b[0]))
|
| 1290 |
+
if isinstance(a, str):
|
| 1291 |
+
raise TypeError('input must be a sequence of strings, not %s' %
|
| 1292 |
+
type(a).__name__)
|
| 1293 |
+
if isinstance(b, str):
|
| 1294 |
+
raise TypeError('input must be a sequence of strings, not %s' %
|
| 1295 |
+
type(b).__name__)
|
| 1296 |
+
for arg in args:
|
| 1297 |
+
if not isinstance(arg, str):
|
| 1298 |
+
raise TypeError('all arguments must be str, not: %r' % (arg,))
|
| 1299 |
+
|
| 1300 |
+
def diff_bytes(dfunc, a, b, fromfile=b'', tofile=b'',
|
| 1301 |
+
fromfiledate=b'', tofiledate=b'', n=3, lineterm=b'\n'):
|
| 1302 |
+
r"""
|
| 1303 |
+
Compare `a` and `b`, two sequences of lines represented as bytes rather
|
| 1304 |
+
than str. This is a wrapper for `dfunc`, which is typically either
|
| 1305 |
+
unified_diff() or context_diff(). Inputs are losslessly converted to
|
| 1306 |
+
strings so that `dfunc` only has to worry about strings, and encoded
|
| 1307 |
+
back to bytes on return. This is necessary to compare files with
|
| 1308 |
+
unknown or inconsistent encoding. All other inputs (except `n`) must be
|
| 1309 |
+
bytes rather than str.
|
| 1310 |
+
"""
|
| 1311 |
+
def decode(s):
|
| 1312 |
+
try:
|
| 1313 |
+
return s.decode('ascii', 'surrogateescape')
|
| 1314 |
+
except AttributeError as err:
|
| 1315 |
+
msg = ('all arguments must be bytes, not %s (%r)' %
|
| 1316 |
+
(type(s).__name__, s))
|
| 1317 |
+
raise TypeError(msg) from err
|
| 1318 |
+
a = list(map(decode, a))
|
| 1319 |
+
b = list(map(decode, b))
|
| 1320 |
+
fromfile = decode(fromfile)
|
| 1321 |
+
tofile = decode(tofile)
|
| 1322 |
+
fromfiledate = decode(fromfiledate)
|
| 1323 |
+
tofiledate = decode(tofiledate)
|
| 1324 |
+
lineterm = decode(lineterm)
|
| 1325 |
+
|
| 1326 |
+
lines = dfunc(a, b, fromfile, tofile, fromfiledate, tofiledate, n, lineterm)
|
| 1327 |
+
for line in lines:
|
| 1328 |
+
yield line.encode('ascii', 'surrogateescape')
|
| 1329 |
+
|
| 1330 |
+
def ndiff(a, b, linejunk=None, charjunk=IS_CHARACTER_JUNK):
|
| 1331 |
+
r"""
|
| 1332 |
+
Compare `a` and `b` (lists of strings); return a `Differ`-style delta.
|
| 1333 |
+
|
| 1334 |
+
Optional keyword parameters `linejunk` and `charjunk` are for filter
|
| 1335 |
+
functions, or can be None:
|
| 1336 |
+
|
| 1337 |
+
- linejunk: A function that should accept a single string argument and
|
| 1338 |
+
return true iff the string is junk. The default is None, and is
|
| 1339 |
+
recommended; the underlying SequenceMatcher class has an adaptive
|
| 1340 |
+
notion of "noise" lines.
|
| 1341 |
+
|
| 1342 |
+
- charjunk: A function that accepts a character (string of length
|
| 1343 |
+
1), and returns true iff the character is junk. The default is
|
| 1344 |
+
the module-level function IS_CHARACTER_JUNK, which filters out
|
| 1345 |
+
whitespace characters (a blank or tab; note: it's a bad idea to
|
| 1346 |
+
include newline in this!).
|
| 1347 |
+
|
| 1348 |
+
Tools/scripts/ndiff.py is a command-line front-end to this function.
|
| 1349 |
+
|
| 1350 |
+
Example:
|
| 1351 |
+
|
| 1352 |
+
>>> diff = ndiff('one\ntwo\nthree\n'.splitlines(keepends=True),
|
| 1353 |
+
... 'ore\ntree\nemu\n'.splitlines(keepends=True))
|
| 1354 |
+
>>> print(''.join(diff), end="")
|
| 1355 |
+
- one
|
| 1356 |
+
? ^
|
| 1357 |
+
+ ore
|
| 1358 |
+
? ^
|
| 1359 |
+
- two
|
| 1360 |
+
- three
|
| 1361 |
+
? -
|
| 1362 |
+
+ tree
|
| 1363 |
+
+ emu
|
| 1364 |
+
"""
|
| 1365 |
+
return Differ(linejunk, charjunk).compare(a, b)
|
| 1366 |
+
|
| 1367 |
+
def _mdiff(fromlines, tolines, context=None, linejunk=None,
|
| 1368 |
+
charjunk=IS_CHARACTER_JUNK):
|
| 1369 |
+
r"""Returns generator yielding marked up from/to side by side differences.
|
| 1370 |
+
|
| 1371 |
+
Arguments:
|
| 1372 |
+
fromlines -- list of text lines to compared to tolines
|
| 1373 |
+
tolines -- list of text lines to be compared to fromlines
|
| 1374 |
+
context -- number of context lines to display on each side of difference,
|
| 1375 |
+
if None, all from/to text lines will be generated.
|
| 1376 |
+
linejunk -- passed on to ndiff (see ndiff documentation)
|
| 1377 |
+
charjunk -- passed on to ndiff (see ndiff documentation)
|
| 1378 |
+
|
| 1379 |
+
This function returns an iterator which returns a tuple:
|
| 1380 |
+
(from line tuple, to line tuple, boolean flag)
|
| 1381 |
+
|
| 1382 |
+
from/to line tuple -- (line num, line text)
|
| 1383 |
+
line num -- integer or None (to indicate a context separation)
|
| 1384 |
+
line text -- original line text with following markers inserted:
|
| 1385 |
+
'\0+' -- marks start of added text
|
| 1386 |
+
'\0-' -- marks start of deleted text
|
| 1387 |
+
'\0^' -- marks start of changed text
|
| 1388 |
+
'\1' -- marks end of added/deleted/changed text
|
| 1389 |
+
|
| 1390 |
+
boolean flag -- None indicates context separation, True indicates
|
| 1391 |
+
either "from" or "to" line contains a change, otherwise False.
|
| 1392 |
+
|
| 1393 |
+
This function/iterator was originally developed to generate side by side
|
| 1394 |
+
file difference for making HTML pages (see HtmlDiff class for example
|
| 1395 |
+
usage).
|
| 1396 |
+
|
| 1397 |
+
Note, this function utilizes the ndiff function to generate the side by
|
| 1398 |
+
side difference markup. Optional ndiff arguments may be passed to this
|
| 1399 |
+
function and they in turn will be passed to ndiff.
|
| 1400 |
+
"""
|
| 1401 |
+
import re
|
| 1402 |
+
|
| 1403 |
+
# regular expression for finding intraline change indices
|
| 1404 |
+
change_re = re.compile(r'(\++|\-+|\^+)')
|
| 1405 |
+
|
| 1406 |
+
# create the difference iterator to generate the differences
|
| 1407 |
+
diff_lines_iterator = ndiff(fromlines,tolines,linejunk,charjunk)
|
| 1408 |
+
|
| 1409 |
+
def _make_line(lines, format_key, side, num_lines=[0,0]):
|
| 1410 |
+
"""Returns line of text with user's change markup and line formatting.
|
| 1411 |
+
|
| 1412 |
+
lines -- list of lines from the ndiff generator to produce a line of
|
| 1413 |
+
text from. When producing the line of text to return, the
|
| 1414 |
+
lines used are removed from this list.
|
| 1415 |
+
format_key -- '+' return first line in list with "add" markup around
|
| 1416 |
+
the entire line.
|
| 1417 |
+
'-' return first line in list with "delete" markup around
|
| 1418 |
+
the entire line.
|
| 1419 |
+
'?' return first line in list with add/delete/change
|
| 1420 |
+
intraline markup (indices obtained from second line)
|
| 1421 |
+
None return first line in list with no markup
|
| 1422 |
+
side -- indice into the num_lines list (0=from,1=to)
|
| 1423 |
+
num_lines -- from/to current line number. This is NOT intended to be a
|
| 1424 |
+
passed parameter. It is present as a keyword argument to
|
| 1425 |
+
maintain memory of the current line numbers between calls
|
| 1426 |
+
of this function.
|
| 1427 |
+
|
| 1428 |
+
Note, this function is purposefully not defined at the module scope so
|
| 1429 |
+
that data it needs from its parent function (within whose context it
|
| 1430 |
+
is defined) does not need to be of module scope.
|
| 1431 |
+
"""
|
| 1432 |
+
num_lines[side] += 1
|
| 1433 |
+
# Handle case where no user markup is to be added, just return line of
|
| 1434 |
+
# text with user's line format to allow for usage of the line number.
|
| 1435 |
+
if format_key is None:
|
| 1436 |
+
return (num_lines[side],lines.pop(0)[2:])
|
| 1437 |
+
# Handle case of intraline changes
|
| 1438 |
+
if format_key == '?':
|
| 1439 |
+
text, markers = lines.pop(0), lines.pop(0)
|
| 1440 |
+
# find intraline changes (store change type and indices in tuples)
|
| 1441 |
+
sub_info = []
|
| 1442 |
+
def record_sub_info(match_object,sub_info=sub_info):
|
| 1443 |
+
sub_info.append([match_object.group(1)[0],match_object.span()])
|
| 1444 |
+
return match_object.group(1)
|
| 1445 |
+
change_re.sub(record_sub_info,markers)
|
| 1446 |
+
# process each tuple inserting our special marks that won't be
|
| 1447 |
+
# noticed by an xml/html escaper.
|
| 1448 |
+
for key,(begin,end) in reversed(sub_info):
|
| 1449 |
+
text = text[0:begin]+'\0'+key+text[begin:end]+'\1'+text[end:]
|
| 1450 |
+
text = text[2:]
|
| 1451 |
+
# Handle case of add/delete entire line
|
| 1452 |
+
else:
|
| 1453 |
+
text = lines.pop(0)[2:]
|
| 1454 |
+
# if line of text is just a newline, insert a space so there is
|
| 1455 |
+
# something for the user to highlight and see.
|
| 1456 |
+
if not text:
|
| 1457 |
+
text = ' '
|
| 1458 |
+
# insert marks that won't be noticed by an xml/html escaper.
|
| 1459 |
+
text = '\0' + format_key + text + '\1'
|
| 1460 |
+
# Return line of text, first allow user's line formatter to do its
|
| 1461 |
+
# thing (such as adding the line number) then replace the special
|
| 1462 |
+
# marks with what the user's change markup.
|
| 1463 |
+
return (num_lines[side],text)
|
| 1464 |
+
|
| 1465 |
+
def _line_iterator():
|
| 1466 |
+
"""Yields from/to lines of text with a change indication.
|
| 1467 |
+
|
| 1468 |
+
This function is an iterator. It itself pulls lines from a
|
| 1469 |
+
differencing iterator, processes them and yields them. When it can
|
| 1470 |
+
it yields both a "from" and a "to" line, otherwise it will yield one
|
| 1471 |
+
or the other. In addition to yielding the lines of from/to text, a
|
| 1472 |
+
boolean flag is yielded to indicate if the text line(s) have
|
| 1473 |
+
differences in them.
|
| 1474 |
+
|
| 1475 |
+
Note, this function is purposefully not defined at the module scope so
|
| 1476 |
+
that data it needs from its parent function (within whose context it
|
| 1477 |
+
is defined) does not need to be of module scope.
|
| 1478 |
+
"""
|
| 1479 |
+
lines = []
|
| 1480 |
+
num_blanks_pending, num_blanks_to_yield = 0, 0
|
| 1481 |
+
while True:
|
| 1482 |
+
# Load up next 4 lines so we can look ahead, create strings which
|
| 1483 |
+
# are a concatenation of the first character of each of the 4 lines
|
| 1484 |
+
# so we can do some very readable comparisons.
|
| 1485 |
+
while len(lines) < 4:
|
| 1486 |
+
lines.append(next(diff_lines_iterator, 'X'))
|
| 1487 |
+
s = ''.join([line[0] for line in lines])
|
| 1488 |
+
if s.startswith('X'):
|
| 1489 |
+
# When no more lines, pump out any remaining blank lines so the
|
| 1490 |
+
# corresponding add/delete lines get a matching blank line so
|
| 1491 |
+
# all line pairs get yielded at the next level.
|
| 1492 |
+
num_blanks_to_yield = num_blanks_pending
|
| 1493 |
+
elif s.startswith('-?+?'):
|
| 1494 |
+
# simple intraline change
|
| 1495 |
+
yield _make_line(lines,'?',0), _make_line(lines,'?',1), True
|
| 1496 |
+
continue
|
| 1497 |
+
elif s.startswith('--++'):
|
| 1498 |
+
# in delete block, add block coming: we do NOT want to get
|
| 1499 |
+
# caught up on blank lines yet, just process the delete line
|
| 1500 |
+
num_blanks_pending -= 1
|
| 1501 |
+
yield _make_line(lines,'-',0), None, True
|
| 1502 |
+
continue
|
| 1503 |
+
elif s.startswith(('--?+', '--+', '- ')):
|
| 1504 |
+
# in delete block and see an intraline change or unchanged line
|
| 1505 |
+
# coming: yield the delete line and then blanks
|
| 1506 |
+
from_line,to_line = _make_line(lines,'-',0), None
|
| 1507 |
+
num_blanks_to_yield,num_blanks_pending = num_blanks_pending-1,0
|
| 1508 |
+
elif s.startswith('-+?'):
|
| 1509 |
+
# intraline change
|
| 1510 |
+
yield _make_line(lines,None,0), _make_line(lines,'?',1), True
|
| 1511 |
+
continue
|
| 1512 |
+
elif s.startswith('-?+'):
|
| 1513 |
+
# intraline change
|
| 1514 |
+
yield _make_line(lines,'?',0), _make_line(lines,None,1), True
|
| 1515 |
+
continue
|
| 1516 |
+
elif s.startswith('-'):
|
| 1517 |
+
# delete FROM line
|
| 1518 |
+
num_blanks_pending -= 1
|
| 1519 |
+
yield _make_line(lines,'-',0), None, True
|
| 1520 |
+
continue
|
| 1521 |
+
elif s.startswith('+--'):
|
| 1522 |
+
# in add block, delete block coming: we do NOT want to get
|
| 1523 |
+
# caught up on blank lines yet, just process the add line
|
| 1524 |
+
num_blanks_pending += 1
|
| 1525 |
+
yield None, _make_line(lines,'+',1), True
|
| 1526 |
+
continue
|
| 1527 |
+
elif s.startswith(('+ ', '+-')):
|
| 1528 |
+
# will be leaving an add block: yield blanks then add line
|
| 1529 |
+
from_line, to_line = None, _make_line(lines,'+',1)
|
| 1530 |
+
num_blanks_to_yield,num_blanks_pending = num_blanks_pending+1,0
|
| 1531 |
+
elif s.startswith('+'):
|
| 1532 |
+
# inside an add block, yield the add line
|
| 1533 |
+
num_blanks_pending += 1
|
| 1534 |
+
yield None, _make_line(lines,'+',1), True
|
| 1535 |
+
continue
|
| 1536 |
+
elif s.startswith(' '):
|
| 1537 |
+
# unchanged text, yield it to both sides
|
| 1538 |
+
yield _make_line(lines[:],None,0),_make_line(lines,None,1),False
|
| 1539 |
+
continue
|
| 1540 |
+
# Catch up on the blank lines so when we yield the next from/to
|
| 1541 |
+
# pair, they are lined up.
|
| 1542 |
+
while(num_blanks_to_yield < 0):
|
| 1543 |
+
num_blanks_to_yield += 1
|
| 1544 |
+
yield None,('','\n'),True
|
| 1545 |
+
while(num_blanks_to_yield > 0):
|
| 1546 |
+
num_blanks_to_yield -= 1
|
| 1547 |
+
yield ('','\n'),None,True
|
| 1548 |
+
if s.startswith('X'):
|
| 1549 |
+
return
|
| 1550 |
+
else:
|
| 1551 |
+
yield from_line,to_line,True
|
| 1552 |
+
|
| 1553 |
+
def _line_pair_iterator():
|
| 1554 |
+
"""Yields from/to lines of text with a change indication.
|
| 1555 |
+
|
| 1556 |
+
This function is an iterator. It itself pulls lines from the line
|
| 1557 |
+
iterator. Its difference from that iterator is that this function
|
| 1558 |
+
always yields a pair of from/to text lines (with the change
|
| 1559 |
+
indication). If necessary it will collect single from/to lines
|
| 1560 |
+
until it has a matching pair from/to pair to yield.
|
| 1561 |
+
|
| 1562 |
+
Note, this function is purposefully not defined at the module scope so
|
| 1563 |
+
that data it needs from its parent function (within whose context it
|
| 1564 |
+
is defined) does not need to be of module scope.
|
| 1565 |
+
"""
|
| 1566 |
+
line_iterator = _line_iterator()
|
| 1567 |
+
fromlines,tolines=[],[]
|
| 1568 |
+
while True:
|
| 1569 |
+
# Collecting lines of text until we have a from/to pair
|
| 1570 |
+
while (len(fromlines)==0 or len(tolines)==0):
|
| 1571 |
+
try:
|
| 1572 |
+
from_line, to_line, found_diff = next(line_iterator)
|
| 1573 |
+
except StopIteration:
|
| 1574 |
+
return
|
| 1575 |
+
if from_line is not None:
|
| 1576 |
+
fromlines.append((from_line,found_diff))
|
| 1577 |
+
if to_line is not None:
|
| 1578 |
+
tolines.append((to_line,found_diff))
|
| 1579 |
+
# Once we have a pair, remove them from the collection and yield it
|
| 1580 |
+
from_line, fromDiff = fromlines.pop(0)
|
| 1581 |
+
to_line, to_diff = tolines.pop(0)
|
| 1582 |
+
yield (from_line,to_line,fromDiff or to_diff)
|
| 1583 |
+
|
| 1584 |
+
# Handle case where user does not want context differencing, just yield
|
| 1585 |
+
# them up without doing anything else with them.
|
| 1586 |
+
line_pair_iterator = _line_pair_iterator()
|
| 1587 |
+
if context is None:
|
| 1588 |
+
yield from line_pair_iterator
|
| 1589 |
+
# Handle case where user wants context differencing. We must do some
|
| 1590 |
+
# storage of lines until we know for sure that they are to be yielded.
|
| 1591 |
+
else:
|
| 1592 |
+
context += 1
|
| 1593 |
+
lines_to_write = 0
|
| 1594 |
+
while True:
|
| 1595 |
+
# Store lines up until we find a difference, note use of a
|
| 1596 |
+
# circular queue because we only need to keep around what
|
| 1597 |
+
# we need for context.
|
| 1598 |
+
index, contextLines = 0, [None]*(context)
|
| 1599 |
+
found_diff = False
|
| 1600 |
+
while(found_diff is False):
|
| 1601 |
+
try:
|
| 1602 |
+
from_line, to_line, found_diff = next(line_pair_iterator)
|
| 1603 |
+
except StopIteration:
|
| 1604 |
+
return
|
| 1605 |
+
i = index % context
|
| 1606 |
+
contextLines[i] = (from_line, to_line, found_diff)
|
| 1607 |
+
index += 1
|
| 1608 |
+
# Yield lines that we have collected so far, but first yield
|
| 1609 |
+
# the user's separator.
|
| 1610 |
+
if index > context:
|
| 1611 |
+
yield None, None, None
|
| 1612 |
+
lines_to_write = context
|
| 1613 |
+
else:
|
| 1614 |
+
lines_to_write = index
|
| 1615 |
+
index = 0
|
| 1616 |
+
while(lines_to_write):
|
| 1617 |
+
i = index % context
|
| 1618 |
+
index += 1
|
| 1619 |
+
yield contextLines[i]
|
| 1620 |
+
lines_to_write -= 1
|
| 1621 |
+
# Now yield the context lines after the change
|
| 1622 |
+
lines_to_write = context-1
|
| 1623 |
+
try:
|
| 1624 |
+
while(lines_to_write):
|
| 1625 |
+
from_line, to_line, found_diff = next(line_pair_iterator)
|
| 1626 |
+
# If another change within the context, extend the context
|
| 1627 |
+
if found_diff:
|
| 1628 |
+
lines_to_write = context-1
|
| 1629 |
+
else:
|
| 1630 |
+
lines_to_write -= 1
|
| 1631 |
+
yield from_line, to_line, found_diff
|
| 1632 |
+
except StopIteration:
|
| 1633 |
+
# Catch exception from next() and return normally
|
| 1634 |
+
return
|
| 1635 |
+
|
| 1636 |
+
|
| 1637 |
+
_file_template = """
|
| 1638 |
+
<!DOCTYPE html>
|
| 1639 |
+
<html lang="en">
|
| 1640 |
+
<head>
|
| 1641 |
+
<meta charset="%(charset)s">
|
| 1642 |
+
<meta name="viewport" content="width=device-width, initial-scale=1">
|
| 1643 |
+
<title>Diff comparison</title>
|
| 1644 |
+
<style>%(styles)s
|
| 1645 |
+
</style>
|
| 1646 |
+
</head>
|
| 1647 |
+
|
| 1648 |
+
<body>
|
| 1649 |
+
%(table)s%(legend)s
|
| 1650 |
+
</body>
|
| 1651 |
+
|
| 1652 |
+
</html>"""
|
| 1653 |
+
|
| 1654 |
+
_styles = """
|
| 1655 |
+
:root {color-scheme: light dark}
|
| 1656 |
+
table.diff {
|
| 1657 |
+
font-family: Menlo, Consolas, Monaco, Liberation Mono, Lucida Console, monospace;
|
| 1658 |
+
border: medium;
|
| 1659 |
+
}
|
| 1660 |
+
.diff_header {
|
| 1661 |
+
background-color: #e0e0e0;
|
| 1662 |
+
font-weight: bold;
|
| 1663 |
+
}
|
| 1664 |
+
td.diff_header {
|
| 1665 |
+
text-align: right;
|
| 1666 |
+
padding: 0 8px;
|
| 1667 |
+
}
|
| 1668 |
+
.diff_next {
|
| 1669 |
+
background-color: #c0c0c0;
|
| 1670 |
+
padding: 4px 0;
|
| 1671 |
+
}
|
| 1672 |
+
.diff_add {background-color:palegreen}
|
| 1673 |
+
.diff_chg {background-color:#ffff77}
|
| 1674 |
+
.diff_sub {background-color:#ffaaaa}
|
| 1675 |
+
table.diff[summary="Legends"] {
|
| 1676 |
+
margin-top: 20px;
|
| 1677 |
+
border: 1px solid #ccc;
|
| 1678 |
+
}
|
| 1679 |
+
table.diff[summary="Legends"] th {
|
| 1680 |
+
background-color: #e0e0e0;
|
| 1681 |
+
padding: 4px 8px;
|
| 1682 |
+
}
|
| 1683 |
+
table.diff[summary="Legends"] td {
|
| 1684 |
+
padding: 4px 8px;
|
| 1685 |
+
}
|
| 1686 |
+
|
| 1687 |
+
@media (prefers-color-scheme: dark) {
|
| 1688 |
+
.diff_header {background-color:#666}
|
| 1689 |
+
.diff_next {background-color:#393939}
|
| 1690 |
+
.diff_add {background-color:darkgreen}
|
| 1691 |
+
.diff_chg {background-color:#847415}
|
| 1692 |
+
.diff_sub {background-color:darkred}
|
| 1693 |
+
table.diff[summary="Legends"] {border-color:#555}
|
| 1694 |
+
table.diff[summary="Legends"] th{background-color:#666}
|
| 1695 |
+
}"""
|
| 1696 |
+
|
| 1697 |
+
_table_template = """
|
| 1698 |
+
<table class="diff" id="difflib_chg_%(prefix)s_top"
|
| 1699 |
+
cellspacing="0" cellpadding="0" rules="groups" >
|
| 1700 |
+
<colgroup></colgroup> <colgroup></colgroup> <colgroup></colgroup>
|
| 1701 |
+
<colgroup></colgroup> <colgroup></colgroup> <colgroup></colgroup>
|
| 1702 |
+
%(header_row)s
|
| 1703 |
+
<tbody>
|
| 1704 |
+
%(data_rows)s </tbody>
|
| 1705 |
+
</table>"""
|
| 1706 |
+
|
| 1707 |
+
_legend = """
|
| 1708 |
+
<table class="diff" summary="Legends">
|
| 1709 |
+
<tr> <th colspan="2"> Legends </th> </tr>
|
| 1710 |
+
<tr> <td> <table border="" summary="Colors">
|
| 1711 |
+
<tr><th> Colors </th> </tr>
|
| 1712 |
+
<tr><td class="diff_add"> Added </td></tr>
|
| 1713 |
+
<tr><td class="diff_chg">Changed</td> </tr>
|
| 1714 |
+
<tr><td class="diff_sub">Deleted</td> </tr>
|
| 1715 |
+
</table></td>
|
| 1716 |
+
<td> <table border="" summary="Links">
|
| 1717 |
+
<tr><th colspan="2"> Links </th> </tr>
|
| 1718 |
+
<tr><td>(f)irst change</td> </tr>
|
| 1719 |
+
<tr><td>(n)ext change</td> </tr>
|
| 1720 |
+
<tr><td>(t)op</td> </tr>
|
| 1721 |
+
</table></td> </tr>
|
| 1722 |
+
</table>"""
|
| 1723 |
+
|
| 1724 |
+
class HtmlDiff(object):
|
| 1725 |
+
"""For producing HTML side by side comparison with change highlights.
|
| 1726 |
+
|
| 1727 |
+
This class can be used to create an HTML table (or a complete HTML file
|
| 1728 |
+
containing the table) showing a side by side, line by line comparison
|
| 1729 |
+
of text with inter-line and intra-line change highlights. The table can
|
| 1730 |
+
be generated in either full or contextual difference mode.
|
| 1731 |
+
|
| 1732 |
+
The following methods are provided for HTML generation:
|
| 1733 |
+
|
| 1734 |
+
make_table -- generates HTML for a single side by side table
|
| 1735 |
+
make_file -- generates complete HTML file with a single side by side table
|
| 1736 |
+
|
| 1737 |
+
See Doc/includes/diff.py for an example usage of this class.
|
| 1738 |
+
"""
|
| 1739 |
+
|
| 1740 |
+
_file_template = _file_template
|
| 1741 |
+
_styles = _styles
|
| 1742 |
+
_table_template = _table_template
|
| 1743 |
+
_legend = _legend
|
| 1744 |
+
_default_prefix = 0
|
| 1745 |
+
|
| 1746 |
+
def __init__(self,tabsize=8,wrapcolumn=None,linejunk=None,
|
| 1747 |
+
charjunk=IS_CHARACTER_JUNK):
|
| 1748 |
+
"""HtmlDiff instance initializer
|
| 1749 |
+
|
| 1750 |
+
Arguments:
|
| 1751 |
+
tabsize -- tab stop spacing, defaults to 8.
|
| 1752 |
+
wrapcolumn -- column number where lines are broken and wrapped,
|
| 1753 |
+
defaults to None where lines are not wrapped.
|
| 1754 |
+
linejunk,charjunk -- keyword arguments passed into ndiff() (used by
|
| 1755 |
+
HtmlDiff() to generate the side by side HTML differences). See
|
| 1756 |
+
ndiff() documentation for argument default values and descriptions.
|
| 1757 |
+
"""
|
| 1758 |
+
self._tabsize = tabsize
|
| 1759 |
+
self._wrapcolumn = wrapcolumn
|
| 1760 |
+
self._linejunk = linejunk
|
| 1761 |
+
self._charjunk = charjunk
|
| 1762 |
+
|
| 1763 |
+
def make_file(self, fromlines, tolines, fromdesc='', todesc='',
|
| 1764 |
+
context=False, numlines=5, *, charset='utf-8'):
|
| 1765 |
+
"""Returns HTML file of side by side comparison with change highlights
|
| 1766 |
+
|
| 1767 |
+
Arguments:
|
| 1768 |
+
fromlines -- list of "from" lines
|
| 1769 |
+
tolines -- list of "to" lines
|
| 1770 |
+
fromdesc -- "from" file column header string
|
| 1771 |
+
todesc -- "to" file column header string
|
| 1772 |
+
context -- set to True for contextual differences (defaults to False
|
| 1773 |
+
which shows full differences).
|
| 1774 |
+
numlines -- number of context lines. When context is set True,
|
| 1775 |
+
controls number of lines displayed before and after the change.
|
| 1776 |
+
When context is False, controls the number of lines to place
|
| 1777 |
+
the "next" link anchors before the next change (so click of
|
| 1778 |
+
"next" link jumps to just before the change).
|
| 1779 |
+
charset -- charset of the HTML document
|
| 1780 |
+
"""
|
| 1781 |
+
|
| 1782 |
+
return (self._file_template % dict(
|
| 1783 |
+
styles=self._styles,
|
| 1784 |
+
legend=self._legend,
|
| 1785 |
+
table=self.make_table(fromlines, tolines, fromdesc, todesc,
|
| 1786 |
+
context=context, numlines=numlines),
|
| 1787 |
+
charset=charset
|
| 1788 |
+
)).encode(charset, 'xmlcharrefreplace').decode(charset)
|
| 1789 |
+
|
| 1790 |
+
def _tab_newline_replace(self,fromlines,tolines):
|
| 1791 |
+
"""Returns from/to line lists with tabs expanded and newlines removed.
|
| 1792 |
+
|
| 1793 |
+
Instead of tab characters being replaced by the number of spaces
|
| 1794 |
+
needed to fill in to the next tab stop, this function will fill
|
| 1795 |
+
the space with tab characters. This is done so that the difference
|
| 1796 |
+
algorithms can identify changes in a file when tabs are replaced by
|
| 1797 |
+
spaces and vice versa. At the end of the HTML generation, the tab
|
| 1798 |
+
characters will be replaced with a nonbreakable space.
|
| 1799 |
+
"""
|
| 1800 |
+
def expand_tabs(line):
|
| 1801 |
+
# hide real spaces
|
| 1802 |
+
line = line.replace(' ','\0')
|
| 1803 |
+
# expand tabs into spaces
|
| 1804 |
+
line = line.expandtabs(self._tabsize)
|
| 1805 |
+
# replace spaces from expanded tabs back into tab characters
|
| 1806 |
+
# (we'll replace them with markup after we do differencing)
|
| 1807 |
+
line = line.replace(' ','\t')
|
| 1808 |
+
return line.replace('\0',' ').rstrip('\n')
|
| 1809 |
+
fromlines = [expand_tabs(line) for line in fromlines]
|
| 1810 |
+
tolines = [expand_tabs(line) for line in tolines]
|
| 1811 |
+
return fromlines,tolines
|
| 1812 |
+
|
| 1813 |
+
def _split_line(self,data_list,line_num,text):
|
| 1814 |
+
"""Builds list of text lines by splitting text lines at wrap point
|
| 1815 |
+
|
| 1816 |
+
This function will determine if the input text line needs to be
|
| 1817 |
+
wrapped (split) into separate lines. If so, the first wrap point
|
| 1818 |
+
will be determined and the first line appended to the output
|
| 1819 |
+
text line list. This function is used recursively to handle
|
| 1820 |
+
the second part of the split line to further split it.
|
| 1821 |
+
"""
|
| 1822 |
+
# if blank line or context separator, just add it to the output list
|
| 1823 |
+
if not line_num:
|
| 1824 |
+
data_list.append((line_num,text))
|
| 1825 |
+
return
|
| 1826 |
+
|
| 1827 |
+
# if line text doesn't need wrapping, just add it to the output list
|
| 1828 |
+
size = len(text)
|
| 1829 |
+
max = self._wrapcolumn
|
| 1830 |
+
if (size <= max) or ((size -(text.count('\0')*3)) <= max):
|
| 1831 |
+
data_list.append((line_num,text))
|
| 1832 |
+
return
|
| 1833 |
+
|
| 1834 |
+
# scan text looking for the wrap point, keeping track if the wrap
|
| 1835 |
+
# point is inside markers
|
| 1836 |
+
i = 0
|
| 1837 |
+
n = 0
|
| 1838 |
+
mark = ''
|
| 1839 |
+
while n < max and i < size:
|
| 1840 |
+
if text[i] == '\0':
|
| 1841 |
+
i += 1
|
| 1842 |
+
mark = text[i]
|
| 1843 |
+
i += 1
|
| 1844 |
+
elif text[i] == '\1':
|
| 1845 |
+
i += 1
|
| 1846 |
+
mark = ''
|
| 1847 |
+
else:
|
| 1848 |
+
i += 1
|
| 1849 |
+
n += 1
|
| 1850 |
+
|
| 1851 |
+
# wrap point is inside text, break it up into separate lines
|
| 1852 |
+
line1 = text[:i]
|
| 1853 |
+
line2 = text[i:]
|
| 1854 |
+
|
| 1855 |
+
# if wrap point is inside markers, place end marker at end of first
|
| 1856 |
+
# line and start marker at beginning of second line because each
|
| 1857 |
+
# line will have its own table tag markup around it.
|
| 1858 |
+
if mark:
|
| 1859 |
+
line1 = line1 + '\1'
|
| 1860 |
+
line2 = '\0' + mark + line2
|
| 1861 |
+
|
| 1862 |
+
# tack on first line onto the output list
|
| 1863 |
+
data_list.append((line_num,line1))
|
| 1864 |
+
|
| 1865 |
+
# use this routine again to wrap the remaining text
|
| 1866 |
+
self._split_line(data_list,'>',line2)
|
| 1867 |
+
|
| 1868 |
+
def _line_wrapper(self,diffs):
|
| 1869 |
+
"""Returns iterator that splits (wraps) mdiff text lines"""
|
| 1870 |
+
|
| 1871 |
+
# pull from/to data and flags from mdiff iterator
|
| 1872 |
+
for fromdata,todata,flag in diffs:
|
| 1873 |
+
# check for context separators and pass them through
|
| 1874 |
+
if flag is None:
|
| 1875 |
+
yield fromdata,todata,flag
|
| 1876 |
+
continue
|
| 1877 |
+
(fromline,fromtext),(toline,totext) = fromdata,todata
|
| 1878 |
+
# for each from/to line split it at the wrap column to form
|
| 1879 |
+
# list of text lines.
|
| 1880 |
+
fromlist,tolist = [],[]
|
| 1881 |
+
self._split_line(fromlist,fromline,fromtext)
|
| 1882 |
+
self._split_line(tolist,toline,totext)
|
| 1883 |
+
# yield from/to line in pairs inserting blank lines as
|
| 1884 |
+
# necessary when one side has more wrapped lines
|
| 1885 |
+
while fromlist or tolist:
|
| 1886 |
+
if fromlist:
|
| 1887 |
+
fromdata = fromlist.pop(0)
|
| 1888 |
+
else:
|
| 1889 |
+
fromdata = ('',' ')
|
| 1890 |
+
if tolist:
|
| 1891 |
+
todata = tolist.pop(0)
|
| 1892 |
+
else:
|
| 1893 |
+
todata = ('',' ')
|
| 1894 |
+
yield fromdata,todata,flag
|
| 1895 |
+
|
| 1896 |
+
def _collect_lines(self,diffs):
|
| 1897 |
+
"""Collects mdiff output into separate lists
|
| 1898 |
+
|
| 1899 |
+
Before storing the mdiff from/to data into a list, it is converted
|
| 1900 |
+
into a single line of text with HTML markup.
|
| 1901 |
+
"""
|
| 1902 |
+
|
| 1903 |
+
fromlist,tolist,flaglist = [],[],[]
|
| 1904 |
+
# pull from/to data and flags from mdiff style iterator
|
| 1905 |
+
for fromdata,todata,flag in diffs:
|
| 1906 |
+
try:
|
| 1907 |
+
# store HTML markup of the lines into the lists
|
| 1908 |
+
fromlist.append(self._format_line(0,flag,*fromdata))
|
| 1909 |
+
tolist.append(self._format_line(1,flag,*todata))
|
| 1910 |
+
except TypeError:
|
| 1911 |
+
# exceptions occur for lines where context separators go
|
| 1912 |
+
fromlist.append(None)
|
| 1913 |
+
tolist.append(None)
|
| 1914 |
+
flaglist.append(flag)
|
| 1915 |
+
return fromlist,tolist,flaglist
|
| 1916 |
+
|
| 1917 |
+
def _format_line(self,side,flag,linenum,text):
|
| 1918 |
+
"""Returns HTML markup of "from" / "to" text lines
|
| 1919 |
+
|
| 1920 |
+
side -- 0 or 1 indicating "from" or "to" text
|
| 1921 |
+
flag -- indicates if difference on line
|
| 1922 |
+
linenum -- line number (used for line number column)
|
| 1923 |
+
text -- line text to be marked up
|
| 1924 |
+
"""
|
| 1925 |
+
try:
|
| 1926 |
+
linenum = '%d' % linenum
|
| 1927 |
+
id = ' id="%s%s"' % (self._prefix[side],linenum)
|
| 1928 |
+
except TypeError:
|
| 1929 |
+
# handle blank lines where linenum is '>' or ''
|
| 1930 |
+
id = ''
|
| 1931 |
+
# replace those things that would get confused with HTML symbols
|
| 1932 |
+
text=text.replace("&","&").replace(">",">").replace("<","<")
|
| 1933 |
+
|
| 1934 |
+
# make space non-breakable so they don't get compressed or line wrapped
|
| 1935 |
+
text = text.replace(' ',' ').rstrip()
|
| 1936 |
+
|
| 1937 |
+
return '<td class="diff_header"%s>%s</td><td nowrap="nowrap">%s</td>' \
|
| 1938 |
+
% (id,linenum,text)
|
| 1939 |
+
|
| 1940 |
+
def _make_prefix(self):
|
| 1941 |
+
"""Create unique anchor prefixes"""
|
| 1942 |
+
|
| 1943 |
+
# Generate a unique anchor prefix so multiple tables
|
| 1944 |
+
# can exist on the same HTML page without conflicts.
|
| 1945 |
+
fromprefix = "from%d_" % HtmlDiff._default_prefix
|
| 1946 |
+
toprefix = "to%d_" % HtmlDiff._default_prefix
|
| 1947 |
+
HtmlDiff._default_prefix += 1
|
| 1948 |
+
# store prefixes so line format method has access
|
| 1949 |
+
self._prefix = [fromprefix,toprefix]
|
| 1950 |
+
|
| 1951 |
+
def _convert_flags(self,fromlist,tolist,flaglist,context,numlines):
|
| 1952 |
+
"""Makes list of "next" links"""
|
| 1953 |
+
|
| 1954 |
+
# all anchor names will be generated using the unique "to" prefix
|
| 1955 |
+
toprefix = self._prefix[1]
|
| 1956 |
+
|
| 1957 |
+
# process change flags, generating middle column of next anchors/links
|
| 1958 |
+
next_id = ['']*len(flaglist)
|
| 1959 |
+
next_href = ['']*len(flaglist)
|
| 1960 |
+
num_chg, in_change = 0, False
|
| 1961 |
+
last = 0
|
| 1962 |
+
for i,flag in enumerate(flaglist):
|
| 1963 |
+
if flag:
|
| 1964 |
+
if not in_change:
|
| 1965 |
+
in_change = True
|
| 1966 |
+
last = i
|
| 1967 |
+
# at the beginning of a change, drop an anchor a few lines
|
| 1968 |
+
# (the context lines) before the change for the previous
|
| 1969 |
+
# link
|
| 1970 |
+
i = max([0,i-numlines])
|
| 1971 |
+
next_id[i] = ' id="difflib_chg_%s_%d"' % (toprefix,num_chg)
|
| 1972 |
+
# at the beginning of a change, drop a link to the next
|
| 1973 |
+
# change
|
| 1974 |
+
num_chg += 1
|
| 1975 |
+
next_href[last] = '<a href="#difflib_chg_%s_%d">n</a>' % (
|
| 1976 |
+
toprefix,num_chg)
|
| 1977 |
+
else:
|
| 1978 |
+
in_change = False
|
| 1979 |
+
# check for cases where there is no content to avoid exceptions
|
| 1980 |
+
if not flaglist:
|
| 1981 |
+
flaglist = [False]
|
| 1982 |
+
next_id = ['']
|
| 1983 |
+
next_href = ['']
|
| 1984 |
+
last = 0
|
| 1985 |
+
if context:
|
| 1986 |
+
fromlist = ['<td></td><td> No Differences Found </td>']
|
| 1987 |
+
tolist = fromlist
|
| 1988 |
+
else:
|
| 1989 |
+
fromlist = tolist = ['<td></td><td> Empty File </td>']
|
| 1990 |
+
# if not a change on first line, drop a link
|
| 1991 |
+
if not flaglist[0]:
|
| 1992 |
+
next_href[0] = '<a href="#difflib_chg_%s_0">f</a>' % toprefix
|
| 1993 |
+
# redo the last link to link to the top
|
| 1994 |
+
next_href[last] = '<a href="#difflib_chg_%s_top">t</a>' % (toprefix)
|
| 1995 |
+
|
| 1996 |
+
return fromlist,tolist,flaglist,next_href,next_id
|
| 1997 |
+
|
| 1998 |
+
def make_table(self,fromlines,tolines,fromdesc='',todesc='',context=False,
|
| 1999 |
+
numlines=5):
|
| 2000 |
+
"""Returns HTML table of side by side comparison with change highlights
|
| 2001 |
+
|
| 2002 |
+
Arguments:
|
| 2003 |
+
fromlines -- list of "from" lines
|
| 2004 |
+
tolines -- list of "to" lines
|
| 2005 |
+
fromdesc -- "from" file column header string
|
| 2006 |
+
todesc -- "to" file column header string
|
| 2007 |
+
context -- set to True for contextual differences (defaults to False
|
| 2008 |
+
which shows full differences).
|
| 2009 |
+
numlines -- number of context lines. When context is set True,
|
| 2010 |
+
controls number of lines displayed before and after the change.
|
| 2011 |
+
When context is False, controls the number of lines to place
|
| 2012 |
+
the "next" link anchors before the next change (so click of
|
| 2013 |
+
"next" link jumps to just before the change).
|
| 2014 |
+
"""
|
| 2015 |
+
|
| 2016 |
+
# make unique anchor prefixes so that multiple tables may exist
|
| 2017 |
+
# on the same page without conflict.
|
| 2018 |
+
self._make_prefix()
|
| 2019 |
+
|
| 2020 |
+
# change tabs to spaces before it gets more difficult after we insert
|
| 2021 |
+
# markup
|
| 2022 |
+
fromlines,tolines = self._tab_newline_replace(fromlines,tolines)
|
| 2023 |
+
|
| 2024 |
+
# create diffs iterator which generates side by side from/to data
|
| 2025 |
+
if context:
|
| 2026 |
+
context_lines = numlines
|
| 2027 |
+
else:
|
| 2028 |
+
context_lines = None
|
| 2029 |
+
diffs = _mdiff(fromlines,tolines,context_lines,linejunk=self._linejunk,
|
| 2030 |
+
charjunk=self._charjunk)
|
| 2031 |
+
|
| 2032 |
+
# set up iterator to wrap lines that exceed desired width
|
| 2033 |
+
if self._wrapcolumn:
|
| 2034 |
+
diffs = self._line_wrapper(diffs)
|
| 2035 |
+
|
| 2036 |
+
# collect up from/to lines and flags into lists (also format the lines)
|
| 2037 |
+
fromlist,tolist,flaglist = self._collect_lines(diffs)
|
| 2038 |
+
|
| 2039 |
+
# process change flags, generating middle column of next anchors/links
|
| 2040 |
+
fromlist,tolist,flaglist,next_href,next_id = self._convert_flags(
|
| 2041 |
+
fromlist,tolist,flaglist,context,numlines)
|
| 2042 |
+
|
| 2043 |
+
s = []
|
| 2044 |
+
fmt = ' <tr><td class="diff_next"%s>%s</td>%s' + \
|
| 2045 |
+
'<td class="diff_next">%s</td>%s</tr>\n'
|
| 2046 |
+
for i in range(len(flaglist)):
|
| 2047 |
+
if flaglist[i] is None:
|
| 2048 |
+
# mdiff yields None on separator lines skip the bogus ones
|
| 2049 |
+
# generated for the first line
|
| 2050 |
+
if i > 0:
|
| 2051 |
+
s.append(' </tbody> \n <tbody>\n')
|
| 2052 |
+
else:
|
| 2053 |
+
s.append( fmt % (next_id[i],next_href[i],fromlist[i],
|
| 2054 |
+
next_href[i],tolist[i]))
|
| 2055 |
+
if fromdesc or todesc:
|
| 2056 |
+
header_row = '<thead><tr>%s%s%s%s</tr></thead>' % (
|
| 2057 |
+
'<th class="diff_next"><br /></th>',
|
| 2058 |
+
'<th colspan="2" class="diff_header">%s</th>' % fromdesc,
|
| 2059 |
+
'<th class="diff_next"><br /></th>',
|
| 2060 |
+
'<th colspan="2" class="diff_header">%s</th>' % todesc)
|
| 2061 |
+
else:
|
| 2062 |
+
header_row = ''
|
| 2063 |
+
|
| 2064 |
+
table = self._table_template % dict(
|
| 2065 |
+
data_rows=''.join(s),
|
| 2066 |
+
header_row=header_row,
|
| 2067 |
+
prefix=self._prefix[1])
|
| 2068 |
+
|
| 2069 |
+
return table.replace('\0+','<span class="diff_add">'). \
|
| 2070 |
+
replace('\0-','<span class="diff_sub">'). \
|
| 2071 |
+
replace('\0^','<span class="diff_chg">'). \
|
| 2072 |
+
replace('\1','</span>'). \
|
| 2073 |
+
replace('\t',' ')
|
| 2074 |
+
|
| 2075 |
+
|
| 2076 |
+
def restore(delta, which):
|
| 2077 |
+
r"""
|
| 2078 |
+
Generate one of the two sequences that generated a delta.
|
| 2079 |
+
|
| 2080 |
+
Given a `delta` produced by `Differ.compare()` or `ndiff()`, extract
|
| 2081 |
+
lines originating from file 1 or 2 (parameter `which`), stripping off line
|
| 2082 |
+
prefixes.
|
| 2083 |
+
|
| 2084 |
+
Examples:
|
| 2085 |
+
|
| 2086 |
+
>>> diff = ndiff('one\ntwo\nthree\n'.splitlines(keepends=True),
|
| 2087 |
+
... 'ore\ntree\nemu\n'.splitlines(keepends=True))
|
| 2088 |
+
>>> diff = list(diff)
|
| 2089 |
+
>>> print(''.join(restore(diff, 1)), end="")
|
| 2090 |
+
one
|
| 2091 |
+
two
|
| 2092 |
+
three
|
| 2093 |
+
>>> print(''.join(restore(diff, 2)), end="")
|
| 2094 |
+
ore
|
| 2095 |
+
tree
|
| 2096 |
+
emu
|
| 2097 |
+
"""
|
| 2098 |
+
try:
|
| 2099 |
+
tag = {1: "- ", 2: "+ "}[int(which)]
|
| 2100 |
+
except KeyError:
|
| 2101 |
+
raise ValueError('unknown delta choice (must be 1 or 2): %r'
|
| 2102 |
+
% which) from None
|
| 2103 |
+
prefixes = (" ", tag)
|
| 2104 |
+
for line in delta:
|
| 2105 |
+
if line[:2] in prefixes:
|
| 2106 |
+
yield line[2:]
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/formfill.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lxml.etree import XPath, ElementBase
|
| 2 |
+
from lxml.html import fromstring, XHTML_NAMESPACE
|
| 3 |
+
from lxml.html import _forms_xpath, _options_xpath, _nons, _transform_result
|
| 4 |
+
from lxml.html import defs
|
| 5 |
+
import copy
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
basestring
|
| 9 |
+
except NameError:
|
| 10 |
+
# Python 3
|
| 11 |
+
basestring = str
|
| 12 |
+
|
| 13 |
+
__all__ = ['FormNotFound', 'fill_form', 'fill_form_html',
|
| 14 |
+
'insert_errors', 'insert_errors_html',
|
| 15 |
+
'DefaultErrorCreator']
|
| 16 |
+
|
| 17 |
+
class FormNotFound(LookupError):
|
| 18 |
+
"""
|
| 19 |
+
Raised when no form can be found
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
_form_name_xpath = XPath('descendant-or-self::form[name=$name]|descendant-or-self::x:form[name=$name]', namespaces={'x':XHTML_NAMESPACE})
|
| 23 |
+
_input_xpath = XPath('|'.join(['descendant-or-self::'+_tag for _tag in ('input','select','textarea','x:input','x:select','x:textarea')]),
|
| 24 |
+
namespaces={'x':XHTML_NAMESPACE})
|
| 25 |
+
_label_for_xpath = XPath('//label[@for=$for_id]|//x:label[@for=$for_id]',
|
| 26 |
+
namespaces={'x':XHTML_NAMESPACE})
|
| 27 |
+
_name_xpath = XPath('descendant-or-self::*[@name=$name]')
|
| 28 |
+
|
| 29 |
+
def fill_form(
|
| 30 |
+
el,
|
| 31 |
+
values,
|
| 32 |
+
form_id=None,
|
| 33 |
+
form_index=None,
|
| 34 |
+
):
|
| 35 |
+
el = _find_form(el, form_id=form_id, form_index=form_index)
|
| 36 |
+
_fill_form(el, values)
|
| 37 |
+
|
| 38 |
+
def fill_form_html(html, values, form_id=None, form_index=None):
|
| 39 |
+
result_type = type(html)
|
| 40 |
+
if isinstance(html, basestring):
|
| 41 |
+
doc = fromstring(html)
|
| 42 |
+
else:
|
| 43 |
+
doc = copy.deepcopy(html)
|
| 44 |
+
fill_form(doc, values, form_id=form_id, form_index=form_index)
|
| 45 |
+
return _transform_result(result_type, doc)
|
| 46 |
+
|
| 47 |
+
def _fill_form(el, values):
|
| 48 |
+
counts = {}
|
| 49 |
+
if hasattr(values, 'mixed'):
|
| 50 |
+
# For Paste request parameters
|
| 51 |
+
values = values.mixed()
|
| 52 |
+
inputs = _input_xpath(el)
|
| 53 |
+
for input in inputs:
|
| 54 |
+
name = input.get('name')
|
| 55 |
+
if not name:
|
| 56 |
+
continue
|
| 57 |
+
if _takes_multiple(input):
|
| 58 |
+
value = values.get(name, [])
|
| 59 |
+
if not isinstance(value, (list, tuple)):
|
| 60 |
+
value = [value]
|
| 61 |
+
_fill_multiple(input, value)
|
| 62 |
+
elif name not in values:
|
| 63 |
+
continue
|
| 64 |
+
else:
|
| 65 |
+
index = counts.get(name, 0)
|
| 66 |
+
counts[name] = index + 1
|
| 67 |
+
value = values[name]
|
| 68 |
+
if isinstance(value, (list, tuple)):
|
| 69 |
+
try:
|
| 70 |
+
value = value[index]
|
| 71 |
+
except IndexError:
|
| 72 |
+
continue
|
| 73 |
+
elif index > 0:
|
| 74 |
+
continue
|
| 75 |
+
_fill_single(input, value)
|
| 76 |
+
|
| 77 |
+
def _takes_multiple(input):
|
| 78 |
+
if _nons(input.tag) == 'select' and input.get('multiple'):
|
| 79 |
+
# FIXME: multiple="0"?
|
| 80 |
+
return True
|
| 81 |
+
type = input.get('type', '').lower()
|
| 82 |
+
if type in ('radio', 'checkbox'):
|
| 83 |
+
return True
|
| 84 |
+
return False
|
| 85 |
+
|
| 86 |
+
def _fill_multiple(input, value):
|
| 87 |
+
type = input.get('type', '').lower()
|
| 88 |
+
if type == 'checkbox':
|
| 89 |
+
v = input.get('value')
|
| 90 |
+
if v is None:
|
| 91 |
+
if not value:
|
| 92 |
+
result = False
|
| 93 |
+
else:
|
| 94 |
+
result = value[0]
|
| 95 |
+
if isinstance(value, basestring):
|
| 96 |
+
# The only valid "on" value for an unnamed checkbox is 'on'
|
| 97 |
+
result = result == 'on'
|
| 98 |
+
_check(input, result)
|
| 99 |
+
else:
|
| 100 |
+
_check(input, v in value)
|
| 101 |
+
elif type == 'radio':
|
| 102 |
+
v = input.get('value')
|
| 103 |
+
_check(input, v in value)
|
| 104 |
+
else:
|
| 105 |
+
assert _nons(input.tag) == 'select'
|
| 106 |
+
for option in _options_xpath(input):
|
| 107 |
+
v = option.get('value')
|
| 108 |
+
if v is None:
|
| 109 |
+
# This seems to be the default, at least on IE
|
| 110 |
+
# FIXME: but I'm not sure
|
| 111 |
+
v = option.text_content()
|
| 112 |
+
_select(option, v in value)
|
| 113 |
+
|
| 114 |
+
def _check(el, check):
|
| 115 |
+
if check:
|
| 116 |
+
el.set('checked', '')
|
| 117 |
+
else:
|
| 118 |
+
if 'checked' in el.attrib:
|
| 119 |
+
del el.attrib['checked']
|
| 120 |
+
|
| 121 |
+
def _select(el, select):
|
| 122 |
+
if select:
|
| 123 |
+
el.set('selected', '')
|
| 124 |
+
else:
|
| 125 |
+
if 'selected' in el.attrib:
|
| 126 |
+
del el.attrib['selected']
|
| 127 |
+
|
| 128 |
+
def _fill_single(input, value):
|
| 129 |
+
if _nons(input.tag) == 'textarea':
|
| 130 |
+
input.text = value
|
| 131 |
+
else:
|
| 132 |
+
input.set('value', value)
|
| 133 |
+
|
| 134 |
+
def _find_form(el, form_id=None, form_index=None):
|
| 135 |
+
if form_id is None and form_index is None:
|
| 136 |
+
forms = _forms_xpath(el)
|
| 137 |
+
for form in forms:
|
| 138 |
+
return form
|
| 139 |
+
raise FormNotFound(
|
| 140 |
+
"No forms in page")
|
| 141 |
+
if form_id is not None:
|
| 142 |
+
form = el.get_element_by_id(form_id)
|
| 143 |
+
if form is not None:
|
| 144 |
+
return form
|
| 145 |
+
forms = _form_name_xpath(el, name=form_id)
|
| 146 |
+
if forms:
|
| 147 |
+
return forms[0]
|
| 148 |
+
else:
|
| 149 |
+
raise FormNotFound(
|
| 150 |
+
"No form with the name or id of %r (forms: %s)"
|
| 151 |
+
% (id, ', '.join(_find_form_ids(el))))
|
| 152 |
+
if form_index is not None:
|
| 153 |
+
forms = _forms_xpath(el)
|
| 154 |
+
try:
|
| 155 |
+
return forms[form_index]
|
| 156 |
+
except IndexError:
|
| 157 |
+
raise FormNotFound(
|
| 158 |
+
"There is no form with the index %r (%i forms found)"
|
| 159 |
+
% (form_index, len(forms)))
|
| 160 |
+
|
| 161 |
+
def _find_form_ids(el):
|
| 162 |
+
forms = _forms_xpath(el)
|
| 163 |
+
if not forms:
|
| 164 |
+
yield '(no forms)'
|
| 165 |
+
return
|
| 166 |
+
for index, form in enumerate(forms):
|
| 167 |
+
if form.get('id'):
|
| 168 |
+
if form.get('name'):
|
| 169 |
+
yield '%s or %s' % (form.get('id'),
|
| 170 |
+
form.get('name'))
|
| 171 |
+
else:
|
| 172 |
+
yield form.get('id')
|
| 173 |
+
elif form.get('name'):
|
| 174 |
+
yield form.get('name')
|
| 175 |
+
else:
|
| 176 |
+
yield '(unnamed form %s)' % index
|
| 177 |
+
|
| 178 |
+
############################################################
|
| 179 |
+
## Error filling
|
| 180 |
+
############################################################
|
| 181 |
+
|
| 182 |
+
class DefaultErrorCreator:
|
| 183 |
+
insert_before = True
|
| 184 |
+
block_inside = True
|
| 185 |
+
error_container_tag = 'div'
|
| 186 |
+
error_message_class = 'error-message'
|
| 187 |
+
error_block_class = 'error-block'
|
| 188 |
+
default_message = "Invalid"
|
| 189 |
+
|
| 190 |
+
def __init__(self, **kw):
|
| 191 |
+
for name, value in kw.items():
|
| 192 |
+
if not hasattr(self, name):
|
| 193 |
+
raise TypeError(
|
| 194 |
+
"Unexpected keyword argument: %s" % name)
|
| 195 |
+
setattr(self, name, value)
|
| 196 |
+
|
| 197 |
+
def __call__(self, el, is_block, message):
|
| 198 |
+
error_el = el.makeelement(self.error_container_tag)
|
| 199 |
+
if self.error_message_class:
|
| 200 |
+
error_el.set('class', self.error_message_class)
|
| 201 |
+
if is_block and self.error_block_class:
|
| 202 |
+
error_el.set('class', error_el.get('class', '')+' '+self.error_block_class)
|
| 203 |
+
if message is None or message == '':
|
| 204 |
+
message = self.default_message
|
| 205 |
+
if isinstance(message, ElementBase):
|
| 206 |
+
error_el.append(message)
|
| 207 |
+
else:
|
| 208 |
+
assert isinstance(message, basestring), (
|
| 209 |
+
"Bad message; should be a string or element: %r" % message)
|
| 210 |
+
error_el.text = message or self.default_message
|
| 211 |
+
if is_block and self.block_inside:
|
| 212 |
+
if self.insert_before:
|
| 213 |
+
error_el.tail = el.text
|
| 214 |
+
el.text = None
|
| 215 |
+
el.insert(0, error_el)
|
| 216 |
+
else:
|
| 217 |
+
el.append(error_el)
|
| 218 |
+
else:
|
| 219 |
+
parent = el.getparent()
|
| 220 |
+
pos = parent.index(el)
|
| 221 |
+
if self.insert_before:
|
| 222 |
+
parent.insert(pos, error_el)
|
| 223 |
+
else:
|
| 224 |
+
error_el.tail = el.tail
|
| 225 |
+
el.tail = None
|
| 226 |
+
parent.insert(pos+1, error_el)
|
| 227 |
+
|
| 228 |
+
default_error_creator = DefaultErrorCreator()
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def insert_errors(
|
| 232 |
+
el,
|
| 233 |
+
errors,
|
| 234 |
+
form_id=None,
|
| 235 |
+
form_index=None,
|
| 236 |
+
error_class="error",
|
| 237 |
+
error_creator=default_error_creator,
|
| 238 |
+
):
|
| 239 |
+
el = _find_form(el, form_id=form_id, form_index=form_index)
|
| 240 |
+
for name, error in errors.items():
|
| 241 |
+
if error is None:
|
| 242 |
+
continue
|
| 243 |
+
for error_el, message in _find_elements_for_name(el, name, error):
|
| 244 |
+
assert isinstance(message, (basestring, type(None), ElementBase)), (
|
| 245 |
+
"Bad message: %r" % message)
|
| 246 |
+
_insert_error(error_el, message, error_class, error_creator)
|
| 247 |
+
|
| 248 |
+
def insert_errors_html(html, values, **kw):
|
| 249 |
+
result_type = type(html)
|
| 250 |
+
if isinstance(html, basestring):
|
| 251 |
+
doc = fromstring(html)
|
| 252 |
+
else:
|
| 253 |
+
doc = copy.deepcopy(html)
|
| 254 |
+
insert_errors(doc, values, **kw)
|
| 255 |
+
return _transform_result(result_type, doc)
|
| 256 |
+
|
| 257 |
+
def _insert_error(el, error, error_class, error_creator):
|
| 258 |
+
if _nons(el.tag) in defs.empty_tags or _nons(el.tag) == 'textarea':
|
| 259 |
+
is_block = False
|
| 260 |
+
else:
|
| 261 |
+
is_block = True
|
| 262 |
+
if _nons(el.tag) != 'form' and error_class:
|
| 263 |
+
_add_class(el, error_class)
|
| 264 |
+
if el.get('id'):
|
| 265 |
+
labels = _label_for_xpath(el, for_id=el.get('id'))
|
| 266 |
+
if labels:
|
| 267 |
+
for label in labels:
|
| 268 |
+
_add_class(label, error_class)
|
| 269 |
+
error_creator(el, is_block, error)
|
| 270 |
+
|
| 271 |
+
def _add_class(el, class_name):
|
| 272 |
+
if el.get('class'):
|
| 273 |
+
el.set('class', el.get('class')+' '+class_name)
|
| 274 |
+
else:
|
| 275 |
+
el.set('class', class_name)
|
| 276 |
+
|
| 277 |
+
def _find_elements_for_name(form, name, error):
|
| 278 |
+
if name is None:
|
| 279 |
+
# An error for the entire form
|
| 280 |
+
yield form, error
|
| 281 |
+
return
|
| 282 |
+
if name.startswith('#'):
|
| 283 |
+
# By id
|
| 284 |
+
el = form.get_element_by_id(name[1:])
|
| 285 |
+
if el is not None:
|
| 286 |
+
yield el, error
|
| 287 |
+
return
|
| 288 |
+
els = _name_xpath(form, name=name)
|
| 289 |
+
if not els:
|
| 290 |
+
# FIXME: should this raise an exception?
|
| 291 |
+
return
|
| 292 |
+
if not isinstance(error, (list, tuple)):
|
| 293 |
+
yield els[0], error
|
| 294 |
+
return
|
| 295 |
+
# FIXME: if error is longer than els, should it raise an error?
|
| 296 |
+
for el, err in zip(els, error):
|
| 297 |
+
if err is None:
|
| 298 |
+
continue
|
| 299 |
+
yield el, err
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/html5parser.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
An interface to html5lib that mimics the lxml.html interface.
|
| 3 |
+
"""
|
| 4 |
+
import sys
|
| 5 |
+
import string
|
| 6 |
+
|
| 7 |
+
from html5lib import HTMLParser as _HTMLParser
|
| 8 |
+
from html5lib.treebuilders.etree_lxml import TreeBuilder
|
| 9 |
+
from lxml import etree
|
| 10 |
+
from lxml.html import Element, XHTML_NAMESPACE, _contains_block_level_tag
|
| 11 |
+
|
| 12 |
+
# python3 compatibility
|
| 13 |
+
try:
|
| 14 |
+
_strings = basestring
|
| 15 |
+
except NameError:
|
| 16 |
+
_strings = (bytes, str)
|
| 17 |
+
try:
|
| 18 |
+
from urllib2 import urlopen
|
| 19 |
+
except ImportError:
|
| 20 |
+
from urllib.request import urlopen
|
| 21 |
+
try:
|
| 22 |
+
from urlparse import urlparse
|
| 23 |
+
except ImportError:
|
| 24 |
+
from urllib.parse import urlparse
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class HTMLParser(_HTMLParser):
|
| 28 |
+
"""An html5lib HTML parser with lxml as tree."""
|
| 29 |
+
|
| 30 |
+
def __init__(self, strict=False, **kwargs):
|
| 31 |
+
_HTMLParser.__init__(self, strict=strict, tree=TreeBuilder, **kwargs)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
from html5lib import XHTMLParser as _XHTMLParser
|
| 36 |
+
except ImportError:
|
| 37 |
+
pass
|
| 38 |
+
else:
|
| 39 |
+
class XHTMLParser(_XHTMLParser):
|
| 40 |
+
"""An html5lib XHTML Parser with lxml as tree."""
|
| 41 |
+
|
| 42 |
+
def __init__(self, strict=False, **kwargs):
|
| 43 |
+
_XHTMLParser.__init__(self, strict=strict, tree=TreeBuilder, **kwargs)
|
| 44 |
+
|
| 45 |
+
xhtml_parser = XHTMLParser()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _find_tag(tree, tag):
|
| 49 |
+
elem = tree.find(tag)
|
| 50 |
+
if elem is not None:
|
| 51 |
+
return elem
|
| 52 |
+
return tree.find('{%s}%s' % (XHTML_NAMESPACE, tag))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def document_fromstring(html, guess_charset=None, parser=None):
|
| 56 |
+
"""
|
| 57 |
+
Parse a whole document into a string.
|
| 58 |
+
|
| 59 |
+
If `guess_charset` is true, or if the input is not Unicode but a
|
| 60 |
+
byte string, the `chardet` library will perform charset guessing
|
| 61 |
+
on the string.
|
| 62 |
+
"""
|
| 63 |
+
if not isinstance(html, _strings):
|
| 64 |
+
raise TypeError('string required')
|
| 65 |
+
|
| 66 |
+
if parser is None:
|
| 67 |
+
parser = html_parser
|
| 68 |
+
|
| 69 |
+
options = {}
|
| 70 |
+
if guess_charset is None and isinstance(html, bytes):
|
| 71 |
+
# html5lib does not accept useChardet as an argument, if it
|
| 72 |
+
# detected the html argument would produce unicode objects.
|
| 73 |
+
guess_charset = True
|
| 74 |
+
if guess_charset is not None:
|
| 75 |
+
options['useChardet'] = guess_charset
|
| 76 |
+
return parser.parse(html, **options).getroot()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def fragments_fromstring(html, no_leading_text=False,
|
| 80 |
+
guess_charset=None, parser=None):
|
| 81 |
+
"""Parses several HTML elements, returning a list of elements.
|
| 82 |
+
|
| 83 |
+
The first item in the list may be a string. If no_leading_text is true,
|
| 84 |
+
then it will be an error if there is leading text, and it will always be
|
| 85 |
+
a list of only elements.
|
| 86 |
+
|
| 87 |
+
If `guess_charset` is true, the `chardet` library will perform charset
|
| 88 |
+
guessing on the string.
|
| 89 |
+
"""
|
| 90 |
+
if not isinstance(html, _strings):
|
| 91 |
+
raise TypeError('string required')
|
| 92 |
+
|
| 93 |
+
if parser is None:
|
| 94 |
+
parser = html_parser
|
| 95 |
+
|
| 96 |
+
options = {}
|
| 97 |
+
if guess_charset is None and isinstance(html, bytes):
|
| 98 |
+
# html5lib does not accept useChardet as an argument, if it
|
| 99 |
+
# detected the html argument would produce unicode objects.
|
| 100 |
+
guess_charset = False
|
| 101 |
+
if guess_charset is not None:
|
| 102 |
+
options['useChardet'] = guess_charset
|
| 103 |
+
children = parser.parseFragment(html, 'div', **options)
|
| 104 |
+
if children and isinstance(children[0], _strings):
|
| 105 |
+
if no_leading_text:
|
| 106 |
+
if children[0].strip():
|
| 107 |
+
raise etree.ParserError('There is leading text: %r' %
|
| 108 |
+
children[0])
|
| 109 |
+
del children[0]
|
| 110 |
+
return children
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def fragment_fromstring(html, create_parent=False,
|
| 114 |
+
guess_charset=None, parser=None):
|
| 115 |
+
"""Parses a single HTML element; it is an error if there is more than
|
| 116 |
+
one element, or if anything but whitespace precedes or follows the
|
| 117 |
+
element.
|
| 118 |
+
|
| 119 |
+
If 'create_parent' is true (or is a tag name) then a parent node
|
| 120 |
+
will be created to encapsulate the HTML in a single element. In
|
| 121 |
+
this case, leading or trailing text is allowed.
|
| 122 |
+
|
| 123 |
+
If `guess_charset` is true, the `chardet` library will perform charset
|
| 124 |
+
guessing on the string.
|
| 125 |
+
"""
|
| 126 |
+
if not isinstance(html, _strings):
|
| 127 |
+
raise TypeError('string required')
|
| 128 |
+
|
| 129 |
+
accept_leading_text = bool(create_parent)
|
| 130 |
+
|
| 131 |
+
elements = fragments_fromstring(
|
| 132 |
+
html, guess_charset=guess_charset, parser=parser,
|
| 133 |
+
no_leading_text=not accept_leading_text)
|
| 134 |
+
|
| 135 |
+
if create_parent:
|
| 136 |
+
if not isinstance(create_parent, _strings):
|
| 137 |
+
create_parent = 'div'
|
| 138 |
+
new_root = Element(create_parent)
|
| 139 |
+
if elements:
|
| 140 |
+
if isinstance(elements[0], _strings):
|
| 141 |
+
new_root.text = elements[0]
|
| 142 |
+
del elements[0]
|
| 143 |
+
new_root.extend(elements)
|
| 144 |
+
return new_root
|
| 145 |
+
|
| 146 |
+
if not elements:
|
| 147 |
+
raise etree.ParserError('No elements found')
|
| 148 |
+
if len(elements) > 1:
|
| 149 |
+
raise etree.ParserError('Multiple elements found')
|
| 150 |
+
result = elements[0]
|
| 151 |
+
if result.tail and result.tail.strip():
|
| 152 |
+
raise etree.ParserError('Element followed by text: %r' % result.tail)
|
| 153 |
+
result.tail = None
|
| 154 |
+
return result
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def fromstring(html, guess_charset=None, parser=None):
|
| 158 |
+
"""Parse the html, returning a single element/document.
|
| 159 |
+
|
| 160 |
+
This tries to minimally parse the chunk of text, without knowing if it
|
| 161 |
+
is a fragment or a document.
|
| 162 |
+
|
| 163 |
+
'base_url' will set the document's base_url attribute (and the tree's
|
| 164 |
+
docinfo.URL)
|
| 165 |
+
|
| 166 |
+
If `guess_charset` is true, or if the input is not Unicode but a
|
| 167 |
+
byte string, the `chardet` library will perform charset guessing
|
| 168 |
+
on the string.
|
| 169 |
+
"""
|
| 170 |
+
if not isinstance(html, _strings):
|
| 171 |
+
raise TypeError('string required')
|
| 172 |
+
doc = document_fromstring(html, parser=parser,
|
| 173 |
+
guess_charset=guess_charset)
|
| 174 |
+
|
| 175 |
+
# document starts with doctype or <html>, full document!
|
| 176 |
+
start = html[:50]
|
| 177 |
+
if isinstance(start, bytes):
|
| 178 |
+
# Allow text comparison in python3.
|
| 179 |
+
# Decode as ascii, that also covers latin-1 and utf-8 for the
|
| 180 |
+
# characters we need.
|
| 181 |
+
start = start.decode('ascii', 'replace')
|
| 182 |
+
|
| 183 |
+
start = start.lstrip().lower()
|
| 184 |
+
if start.startswith('<html') or start.startswith('<!doctype'):
|
| 185 |
+
return doc
|
| 186 |
+
|
| 187 |
+
head = _find_tag(doc, 'head')
|
| 188 |
+
|
| 189 |
+
# if the head is not empty we have a full document
|
| 190 |
+
if len(head):
|
| 191 |
+
return doc
|
| 192 |
+
|
| 193 |
+
body = _find_tag(doc, 'body')
|
| 194 |
+
|
| 195 |
+
# The body has just one element, so it was probably a single
|
| 196 |
+
# element passed in
|
| 197 |
+
if (len(body) == 1 and (not body.text or not body.text.strip())
|
| 198 |
+
and (not body[-1].tail or not body[-1].tail.strip())):
|
| 199 |
+
return body[0]
|
| 200 |
+
|
| 201 |
+
# Now we have a body which represents a bunch of tags which have the
|
| 202 |
+
# content that was passed in. We will create a fake container, which
|
| 203 |
+
# is the body tag, except <body> implies too much structure.
|
| 204 |
+
if _contains_block_level_tag(body):
|
| 205 |
+
body.tag = 'div'
|
| 206 |
+
else:
|
| 207 |
+
body.tag = 'span'
|
| 208 |
+
return body
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def parse(filename_url_or_file, guess_charset=None, parser=None):
|
| 212 |
+
"""Parse a filename, URL, or file-like object into an HTML document
|
| 213 |
+
tree. Note: this returns a tree, not an element. Use
|
| 214 |
+
``parse(...).getroot()`` to get the document root.
|
| 215 |
+
|
| 216 |
+
If ``guess_charset`` is true, the ``useChardet`` option is passed into
|
| 217 |
+
html5lib to enable character detection. This option is on by default
|
| 218 |
+
when parsing from URLs, off by default when parsing from file(-like)
|
| 219 |
+
objects (which tend to return Unicode more often than not), and on by
|
| 220 |
+
default when parsing from a file path (which is read in binary mode).
|
| 221 |
+
"""
|
| 222 |
+
if parser is None:
|
| 223 |
+
parser = html_parser
|
| 224 |
+
if not isinstance(filename_url_or_file, _strings):
|
| 225 |
+
fp = filename_url_or_file
|
| 226 |
+
if guess_charset is None:
|
| 227 |
+
# assume that file-like objects return Unicode more often than bytes
|
| 228 |
+
guess_charset = False
|
| 229 |
+
elif _looks_like_url(filename_url_or_file):
|
| 230 |
+
fp = urlopen(filename_url_or_file)
|
| 231 |
+
if guess_charset is None:
|
| 232 |
+
# assume that URLs return bytes
|
| 233 |
+
guess_charset = True
|
| 234 |
+
else:
|
| 235 |
+
fp = open(filename_url_or_file, 'rb')
|
| 236 |
+
if guess_charset is None:
|
| 237 |
+
guess_charset = True
|
| 238 |
+
|
| 239 |
+
options = {}
|
| 240 |
+
# html5lib does not accept useChardet as an argument, if it
|
| 241 |
+
# detected the html argument would produce unicode objects.
|
| 242 |
+
if guess_charset:
|
| 243 |
+
options['useChardet'] = guess_charset
|
| 244 |
+
return parser.parse(fp, **options)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def _looks_like_url(str):
|
| 248 |
+
scheme = urlparse(str)[0]
|
| 249 |
+
if not scheme:
|
| 250 |
+
return False
|
| 251 |
+
elif (sys.platform == 'win32' and
|
| 252 |
+
scheme in string.ascii_letters
|
| 253 |
+
and len(scheme) == 1):
|
| 254 |
+
# looks like a 'normal' absolute path
|
| 255 |
+
return False
|
| 256 |
+
else:
|
| 257 |
+
return True
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
html_parser = HTMLParser()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/__init__.pxd
ADDED
|
File without changes
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/config.pxd
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cdef extern from "etree_defs.h":
|
| 2 |
+
cdef bint ENABLE_THREADING
|
| 3 |
+
cdef bint ENABLE_SCHEMATRON
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/relaxng.pxd
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lxml.includes.tree cimport xmlDoc
|
| 2 |
+
from lxml.includes.xmlerror cimport xmlStructuredErrorFunc
|
| 3 |
+
|
| 4 |
+
cdef extern from "libxml/relaxng.h" nogil:
|
| 5 |
+
ctypedef struct xmlRelaxNG
|
| 6 |
+
ctypedef struct xmlRelaxNGParserCtxt
|
| 7 |
+
|
| 8 |
+
ctypedef struct xmlRelaxNGValidCtxt
|
| 9 |
+
|
| 10 |
+
ctypedef enum xmlRelaxNGValidErr:
|
| 11 |
+
XML_RELAXNG_OK = 0
|
| 12 |
+
XML_RELAXNG_ERR_MEMORY = 1
|
| 13 |
+
XML_RELAXNG_ERR_TYPE = 2
|
| 14 |
+
XML_RELAXNG_ERR_TYPEVAL = 3
|
| 15 |
+
XML_RELAXNG_ERR_DUPID = 4
|
| 16 |
+
XML_RELAXNG_ERR_TYPECMP = 5
|
| 17 |
+
XML_RELAXNG_ERR_NOSTATE = 6
|
| 18 |
+
XML_RELAXNG_ERR_NODEFINE = 7
|
| 19 |
+
XML_RELAXNG_ERR_LISTEXTRA = 8
|
| 20 |
+
XML_RELAXNG_ERR_LISTEMPTY = 9
|
| 21 |
+
XML_RELAXNG_ERR_INTERNODATA = 10
|
| 22 |
+
XML_RELAXNG_ERR_INTERSEQ = 11
|
| 23 |
+
XML_RELAXNG_ERR_INTEREXTRA = 12
|
| 24 |
+
XML_RELAXNG_ERR_ELEMNAME = 13
|
| 25 |
+
XML_RELAXNG_ERR_ATTRNAME = 14
|
| 26 |
+
XML_RELAXNG_ERR_ELEMNONS = 15
|
| 27 |
+
XML_RELAXNG_ERR_ATTRNONS = 16
|
| 28 |
+
XML_RELAXNG_ERR_ELEMWRONGNS = 17
|
| 29 |
+
XML_RELAXNG_ERR_ATTRWRONGNS = 18
|
| 30 |
+
XML_RELAXNG_ERR_ELEMEXTRANS = 19
|
| 31 |
+
XML_RELAXNG_ERR_ATTREXTRANS = 20
|
| 32 |
+
XML_RELAXNG_ERR_ELEMNOTEMPTY = 21
|
| 33 |
+
XML_RELAXNG_ERR_NOELEM = 22
|
| 34 |
+
XML_RELAXNG_ERR_NOTELEM = 23
|
| 35 |
+
XML_RELAXNG_ERR_ATTRVALID = 24
|
| 36 |
+
XML_RELAXNG_ERR_CONTENTVALID = 25
|
| 37 |
+
XML_RELAXNG_ERR_EXTRACONTENT = 26
|
| 38 |
+
XML_RELAXNG_ERR_INVALIDATTR = 27
|
| 39 |
+
XML_RELAXNG_ERR_DATAELEM = 28
|
| 40 |
+
XML_RELAXNG_ERR_VALELEM = 29
|
| 41 |
+
XML_RELAXNG_ERR_LISTELEM = 30
|
| 42 |
+
XML_RELAXNG_ERR_DATATYPE = 31
|
| 43 |
+
XML_RELAXNG_ERR_VALUE = 32
|
| 44 |
+
XML_RELAXNG_ERR_LIST = 33
|
| 45 |
+
XML_RELAXNG_ERR_NOGRAMMAR = 34
|
| 46 |
+
XML_RELAXNG_ERR_EXTRADATA = 35
|
| 47 |
+
XML_RELAXNG_ERR_LACKDATA = 36
|
| 48 |
+
XML_RELAXNG_ERR_INTERNAL = 37
|
| 49 |
+
XML_RELAXNG_ERR_ELEMWRONG = 38
|
| 50 |
+
XML_RELAXNG_ERR_TEXTWRONG = 39
|
| 51 |
+
|
| 52 |
+
cdef xmlRelaxNGValidCtxt* xmlRelaxNGNewValidCtxt(xmlRelaxNG* schema)
|
| 53 |
+
cdef int xmlRelaxNGValidateDoc(xmlRelaxNGValidCtxt* ctxt, xmlDoc* doc)
|
| 54 |
+
cdef xmlRelaxNG* xmlRelaxNGParse(xmlRelaxNGParserCtxt* ctxt)
|
| 55 |
+
cdef xmlRelaxNGParserCtxt* xmlRelaxNGNewParserCtxt(char* URL)
|
| 56 |
+
cdef xmlRelaxNGParserCtxt* xmlRelaxNGNewDocParserCtxt(xmlDoc* doc)
|
| 57 |
+
cdef void xmlRelaxNGFree(xmlRelaxNG* schema)
|
| 58 |
+
cdef void xmlRelaxNGFreeParserCtxt(xmlRelaxNGParserCtxt* ctxt)
|
| 59 |
+
cdef void xmlRelaxNGFreeValidCtxt(xmlRelaxNGValidCtxt* ctxt)
|
| 60 |
+
|
| 61 |
+
cdef void xmlRelaxNGSetValidStructuredErrors(
|
| 62 |
+
xmlRelaxNGValidCtxt* ctxt, xmlStructuredErrorFunc serror, void *ctx)
|
| 63 |
+
cdef void xmlRelaxNGSetParserStructuredErrors(
|
| 64 |
+
xmlRelaxNGParserCtxt* ctxt, xmlStructuredErrorFunc serror, void *ctx)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/schematron.pxd
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lxml.includes cimport xmlerror
|
| 2 |
+
from lxml.includes.tree cimport xmlDoc
|
| 3 |
+
|
| 4 |
+
cdef extern from "libxml/schematron.h" nogil:
|
| 5 |
+
ctypedef struct xmlSchematron
|
| 6 |
+
ctypedef struct xmlSchematronParserCtxt
|
| 7 |
+
ctypedef struct xmlSchematronValidCtxt
|
| 8 |
+
|
| 9 |
+
ctypedef enum xmlSchematronValidOptions:
|
| 10 |
+
XML_SCHEMATRON_OUT_QUIET = 1 # quiet no report
|
| 11 |
+
XML_SCHEMATRON_OUT_TEXT = 2 # build a textual report
|
| 12 |
+
XML_SCHEMATRON_OUT_XML = 4 # output SVRL
|
| 13 |
+
XML_SCHEMATRON_OUT_ERROR = 8 # output via xmlStructuredErrorFunc
|
| 14 |
+
XML_SCHEMATRON_OUT_FILE = 256 # output to a file descriptor
|
| 15 |
+
XML_SCHEMATRON_OUT_BUFFER = 512 # output to a buffer
|
| 16 |
+
XML_SCHEMATRON_OUT_IO = 1024 # output to I/O mechanism
|
| 17 |
+
|
| 18 |
+
cdef xmlSchematronParserCtxt* xmlSchematronNewDocParserCtxt(
|
| 19 |
+
xmlDoc* doc)
|
| 20 |
+
cdef xmlSchematronParserCtxt* xmlSchematronNewParserCtxt(
|
| 21 |
+
char* filename) nogil
|
| 22 |
+
cdef xmlSchematronValidCtxt* xmlSchematronNewValidCtxt(
|
| 23 |
+
xmlSchematron* schema, int options)
|
| 24 |
+
|
| 25 |
+
cdef xmlSchematron* xmlSchematronParse(xmlSchematronParserCtxt* ctxt)
|
| 26 |
+
cdef int xmlSchematronValidateDoc(xmlSchematronValidCtxt* ctxt,
|
| 27 |
+
xmlDoc* instance)
|
| 28 |
+
|
| 29 |
+
cdef void xmlSchematronFreeParserCtxt(xmlSchematronParserCtxt* ctxt)
|
| 30 |
+
cdef void xmlSchematronFreeValidCtxt(xmlSchematronValidCtxt* ctxt)
|
| 31 |
+
cdef void xmlSchematronFree(xmlSchematron* schema)
|
| 32 |
+
cdef void xmlSchematronSetValidStructuredErrors(
|
| 33 |
+
xmlSchematronValidCtxt* ctxt,
|
| 34 |
+
xmlerror.xmlStructuredErrorFunc error_func, void *data)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/xpath.pxd
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lxml.includes cimport tree
|
| 2 |
+
from lxml.includes cimport xmlerror
|
| 3 |
+
|
| 4 |
+
from libc.string cimport const_char
|
| 5 |
+
from lxml.includes.tree cimport xmlChar, const_xmlChar
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
cdef extern from "libxml/xpath.h" nogil:
|
| 9 |
+
ctypedef enum xmlXPathObjectType:
|
| 10 |
+
XPATH_UNDEFINED = 0
|
| 11 |
+
XPATH_NODESET = 1
|
| 12 |
+
XPATH_BOOLEAN = 2
|
| 13 |
+
XPATH_NUMBER = 3
|
| 14 |
+
XPATH_STRING = 4
|
| 15 |
+
XPATH_POINT = 5
|
| 16 |
+
XPATH_RANGE = 6
|
| 17 |
+
XPATH_LOCATIONSET = 7
|
| 18 |
+
XPATH_USERS = 8
|
| 19 |
+
XPATH_XSLT_TREE = 9
|
| 20 |
+
|
| 21 |
+
ctypedef enum xmlXPathError:
|
| 22 |
+
XPATH_EXPRESSION_OK = 0
|
| 23 |
+
XPATH_NUMBER_ERROR = 1
|
| 24 |
+
XPATH_UNFINISHED_LITERAL_ERROR = 2
|
| 25 |
+
XPATH_START_LITERAL_ERROR = 3
|
| 26 |
+
XPATH_VARIABLE_REF_ERROR = 4
|
| 27 |
+
XPATH_UNDEF_VARIABLE_ERROR = 5
|
| 28 |
+
XPATH_INVALID_PREDICATE_ERROR = 6
|
| 29 |
+
XPATH_EXPR_ERROR = 7
|
| 30 |
+
XPATH_UNCLOSED_ERROR = 8
|
| 31 |
+
XPATH_UNKNOWN_FUNC_ERROR = 9
|
| 32 |
+
XPATH_INVALID_OPERAND = 10
|
| 33 |
+
XPATH_INVALID_TYPE = 11
|
| 34 |
+
XPATH_INVALID_ARITY = 12
|
| 35 |
+
XPATH_INVALID_CTXT_SIZE = 13
|
| 36 |
+
XPATH_INVALID_CTXT_POSITION = 14
|
| 37 |
+
XPATH_MEMORY_ERROR = 15
|
| 38 |
+
XPTR_SYNTAX_ERROR = 16
|
| 39 |
+
XPTR_RESOURCE_ERROR = 17
|
| 40 |
+
XPTR_SUB_RESOURCE_ERROR = 18
|
| 41 |
+
XPATH_UNDEF_PREFIX_ERROR = 19
|
| 42 |
+
XPATH_ENCODING_ERROR = 20
|
| 43 |
+
XPATH_INVALID_CHAR_ERROR = 21
|
| 44 |
+
XPATH_INVALID_CTXT = 22
|
| 45 |
+
|
| 46 |
+
ctypedef struct xmlNodeSet:
|
| 47 |
+
int nodeNr
|
| 48 |
+
int nodeMax
|
| 49 |
+
tree.xmlNode** nodeTab
|
| 50 |
+
|
| 51 |
+
ctypedef struct xmlXPathObject:
|
| 52 |
+
xmlXPathObjectType type
|
| 53 |
+
xmlNodeSet* nodesetval
|
| 54 |
+
bint boolval
|
| 55 |
+
double floatval
|
| 56 |
+
xmlChar* stringval
|
| 57 |
+
|
| 58 |
+
ctypedef struct xmlXPathContext:
|
| 59 |
+
tree.xmlDoc* doc
|
| 60 |
+
tree.xmlNode* node
|
| 61 |
+
tree.xmlDict* dict
|
| 62 |
+
tree.xmlHashTable* nsHash
|
| 63 |
+
const_xmlChar* function
|
| 64 |
+
const_xmlChar* functionURI
|
| 65 |
+
xmlerror.xmlStructuredErrorFunc error
|
| 66 |
+
xmlerror.xmlError lastError
|
| 67 |
+
void* userData
|
| 68 |
+
|
| 69 |
+
ctypedef struct xmlXPathParserContext:
|
| 70 |
+
xmlXPathContext* context
|
| 71 |
+
xmlXPathObject* value
|
| 72 |
+
tree.xmlNode* ancestor
|
| 73 |
+
int error
|
| 74 |
+
|
| 75 |
+
ctypedef struct xmlXPathCompExpr
|
| 76 |
+
|
| 77 |
+
ctypedef void (*xmlXPathFunction)(xmlXPathParserContext* ctxt, int nargs)
|
| 78 |
+
ctypedef xmlXPathFunction (*xmlXPathFuncLookupFunc)(void* ctxt,
|
| 79 |
+
const_xmlChar* name,
|
| 80 |
+
const_xmlChar* ns_uri)
|
| 81 |
+
|
| 82 |
+
cdef xmlXPathContext* xmlXPathNewContext(tree.xmlDoc* doc)
|
| 83 |
+
cdef xmlXPathObject* xmlXPathEvalExpression(const_xmlChar* str,
|
| 84 |
+
xmlXPathContext* ctxt)
|
| 85 |
+
cdef xmlXPathObject* xmlXPathCompiledEval(xmlXPathCompExpr* comp,
|
| 86 |
+
xmlXPathContext* ctxt)
|
| 87 |
+
cdef xmlXPathCompExpr* xmlXPathCompile(const_xmlChar* str)
|
| 88 |
+
cdef xmlXPathCompExpr* xmlXPathCtxtCompile(xmlXPathContext* ctxt,
|
| 89 |
+
const_xmlChar* str)
|
| 90 |
+
cdef void xmlXPathFreeContext(xmlXPathContext* ctxt)
|
| 91 |
+
cdef void xmlXPathFreeCompExpr(xmlXPathCompExpr* comp)
|
| 92 |
+
cdef void xmlXPathFreeObject(xmlXPathObject* obj)
|
| 93 |
+
cdef int xmlXPathRegisterNs(xmlXPathContext* ctxt,
|
| 94 |
+
const_xmlChar* prefix, const_xmlChar* ns_uri)
|
| 95 |
+
|
| 96 |
+
cdef xmlNodeSet* xmlXPathNodeSetCreate(tree.xmlNode* val)
|
| 97 |
+
cdef void xmlXPathFreeNodeSet(xmlNodeSet* val)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
cdef extern from "libxml/xpathInternals.h" nogil:
|
| 101 |
+
cdef int xmlXPathRegisterFunc(xmlXPathContext* ctxt,
|
| 102 |
+
const_xmlChar* name,
|
| 103 |
+
xmlXPathFunction f)
|
| 104 |
+
cdef int xmlXPathRegisterFuncNS(xmlXPathContext* ctxt,
|
| 105 |
+
const_xmlChar* name,
|
| 106 |
+
const_xmlChar* ns_uri,
|
| 107 |
+
xmlXPathFunction f)
|
| 108 |
+
cdef void xmlXPathRegisterFuncLookup(xmlXPathContext *ctxt,
|
| 109 |
+
xmlXPathFuncLookupFunc f,
|
| 110 |
+
void *funcCtxt)
|
| 111 |
+
cdef int xmlXPathRegisterVariable(xmlXPathContext *ctxt,
|
| 112 |
+
const_xmlChar* name,
|
| 113 |
+
xmlXPathObject* value)
|
| 114 |
+
cdef int xmlXPathRegisterVariableNS(xmlXPathContext *ctxt,
|
| 115 |
+
const_xmlChar* name,
|
| 116 |
+
const_xmlChar* ns_uri,
|
| 117 |
+
xmlXPathObject* value)
|
| 118 |
+
cdef void xmlXPathRegisteredVariablesCleanup(xmlXPathContext *ctxt)
|
| 119 |
+
cdef void xmlXPathRegisteredNsCleanup(xmlXPathContext *ctxt)
|
| 120 |
+
cdef xmlXPathObject* valuePop (xmlXPathParserContext *ctxt)
|
| 121 |
+
cdef int valuePush(xmlXPathParserContext* ctxt, xmlXPathObject *value)
|
| 122 |
+
|
| 123 |
+
cdef xmlXPathObject* xmlXPathNewCString(const_char *val)
|
| 124 |
+
cdef xmlXPathObject* xmlXPathWrapCString(const_char * val)
|
| 125 |
+
cdef xmlXPathObject* xmlXPathNewString(const_xmlChar *val)
|
| 126 |
+
cdef xmlXPathObject* xmlXPathWrapString(const_xmlChar * val)
|
| 127 |
+
cdef xmlXPathObject* xmlXPathNewFloat(double val)
|
| 128 |
+
cdef xmlXPathObject* xmlXPathNewBoolean(int val)
|
| 129 |
+
cdef xmlXPathObject* xmlXPathNewNodeSet(tree.xmlNode* val)
|
| 130 |
+
cdef xmlXPathObject* xmlXPathNewValueTree(tree.xmlNode* val)
|
| 131 |
+
cdef void xmlXPathNodeSetAdd(xmlNodeSet* cur,
|
| 132 |
+
tree.xmlNode* val)
|
| 133 |
+
cdef void xmlXPathNodeSetAddUnique(xmlNodeSet* cur,
|
| 134 |
+
tree.xmlNode* val)
|
| 135 |
+
cdef xmlXPathObject* xmlXPathWrapNodeSet(xmlNodeSet* val)
|
| 136 |
+
cdef void xmlXPathErr(xmlXPathParserContext* ctxt, int error)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/isoschematron/__init__.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""The ``lxml.isoschematron`` package implements ISO Schematron support on top
|
| 2 |
+
of the pure-xslt 'skeleton' implementation.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
import os.path
|
| 7 |
+
from lxml import etree as _etree # due to validator __init__ signature
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# some compat stuff, borrowed from lxml.html
|
| 11 |
+
try:
|
| 12 |
+
unicode
|
| 13 |
+
except NameError:
|
| 14 |
+
# Python 3
|
| 15 |
+
unicode = str
|
| 16 |
+
try:
|
| 17 |
+
basestring
|
| 18 |
+
except NameError:
|
| 19 |
+
# Python 3
|
| 20 |
+
basestring = str
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
__all__ = ['extract_xsd', 'extract_rng', 'iso_dsdl_include',
|
| 24 |
+
'iso_abstract_expand', 'iso_svrl_for_xslt1',
|
| 25 |
+
'svrl_validation_errors', 'schematron_schema_valid',
|
| 26 |
+
'stylesheet_params', 'Schematron']
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# some namespaces
|
| 30 |
+
#FIXME: Maybe lxml should provide a dedicated place for common namespace
|
| 31 |
+
#FIXME: definitions?
|
| 32 |
+
XML_SCHEMA_NS = "http://www.w3.org/2001/XMLSchema"
|
| 33 |
+
RELAXNG_NS = "http://relaxng.org/ns/structure/1.0"
|
| 34 |
+
SCHEMATRON_NS = "http://purl.oclc.org/dsdl/schematron"
|
| 35 |
+
SVRL_NS = "http://purl.oclc.org/dsdl/svrl"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# some helpers
|
| 39 |
+
_schematron_root = '{%s}schema' % SCHEMATRON_NS
|
| 40 |
+
_xml_schema_root = '{%s}schema' % XML_SCHEMA_NS
|
| 41 |
+
_resources_dir = os.path.join(os.path.dirname(__file__), 'resources')
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# the iso-schematron skeleton implementation steps aka xsl transformations
|
| 45 |
+
extract_xsd = _etree.XSLT(_etree.parse(
|
| 46 |
+
os.path.join(_resources_dir, 'xsl', 'XSD2Schtrn.xsl')))
|
| 47 |
+
extract_rng = _etree.XSLT(_etree.parse(
|
| 48 |
+
os.path.join(_resources_dir, 'xsl', 'RNG2Schtrn.xsl')))
|
| 49 |
+
iso_dsdl_include = _etree.XSLT(_etree.parse(
|
| 50 |
+
os.path.join(_resources_dir, 'xsl', 'iso-schematron-xslt1',
|
| 51 |
+
'iso_dsdl_include.xsl')))
|
| 52 |
+
iso_abstract_expand = _etree.XSLT(_etree.parse(
|
| 53 |
+
os.path.join(_resources_dir, 'xsl', 'iso-schematron-xslt1',
|
| 54 |
+
'iso_abstract_expand.xsl')))
|
| 55 |
+
iso_svrl_for_xslt1 = _etree.XSLT(_etree.parse(
|
| 56 |
+
os.path.join(_resources_dir,
|
| 57 |
+
'xsl', 'iso-schematron-xslt1', 'iso_svrl_for_xslt1.xsl')))
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# svrl result accessors
|
| 61 |
+
svrl_validation_errors = _etree.XPath(
|
| 62 |
+
'//svrl:failed-assert', namespaces={'svrl': SVRL_NS})
|
| 63 |
+
|
| 64 |
+
# RelaxNG validator for schematron schemas
|
| 65 |
+
schematron_schema_valid_supported = False
|
| 66 |
+
try:
|
| 67 |
+
schematron_schema_valid = _etree.RelaxNG(
|
| 68 |
+
file=os.path.join(_resources_dir, 'rng', 'iso-schematron.rng'))
|
| 69 |
+
schematron_schema_valid_supported = True
|
| 70 |
+
except _etree.RelaxNGParseError:
|
| 71 |
+
# Some distributions delete the file due to licensing issues.
|
| 72 |
+
def schematron_schema_valid(arg):
|
| 73 |
+
raise NotImplementedError("Validating the ISO schematron requires iso-schematron.rng")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def stylesheet_params(**kwargs):
|
| 77 |
+
"""Convert keyword args to a dictionary of stylesheet parameters.
|
| 78 |
+
XSL stylesheet parameters must be XPath expressions, i.e.:
|
| 79 |
+
|
| 80 |
+
* string expressions, like "'5'"
|
| 81 |
+
* simple (number) expressions, like "5"
|
| 82 |
+
* valid XPath expressions, like "/a/b/text()"
|
| 83 |
+
|
| 84 |
+
This function converts native Python keyword arguments to stylesheet
|
| 85 |
+
parameters following these rules:
|
| 86 |
+
If an arg is a string wrap it with XSLT.strparam().
|
| 87 |
+
If an arg is an XPath object use its path string.
|
| 88 |
+
If arg is None raise TypeError.
|
| 89 |
+
Else convert arg to string.
|
| 90 |
+
"""
|
| 91 |
+
result = {}
|
| 92 |
+
for key, val in kwargs.items():
|
| 93 |
+
if isinstance(val, basestring):
|
| 94 |
+
val = _etree.XSLT.strparam(val)
|
| 95 |
+
elif val is None:
|
| 96 |
+
raise TypeError('None not allowed as a stylesheet parameter')
|
| 97 |
+
elif not isinstance(val, _etree.XPath):
|
| 98 |
+
val = unicode(val)
|
| 99 |
+
result[key] = val
|
| 100 |
+
return result
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# helper function for use in Schematron __init__
|
| 104 |
+
def _stylesheet_param_dict(paramsDict, kwargsDict):
|
| 105 |
+
"""Return a copy of paramsDict, updated with kwargsDict entries, wrapped as
|
| 106 |
+
stylesheet arguments.
|
| 107 |
+
kwargsDict entries with a value of None are ignored.
|
| 108 |
+
"""
|
| 109 |
+
# beware of changing mutable default arg
|
| 110 |
+
paramsDict = dict(paramsDict)
|
| 111 |
+
for k, v in kwargsDict.items():
|
| 112 |
+
if v is not None: # None values do not override
|
| 113 |
+
paramsDict[k] = v
|
| 114 |
+
paramsDict = stylesheet_params(**paramsDict)
|
| 115 |
+
return paramsDict
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class Schematron(_etree._Validator):
|
| 119 |
+
"""An ISO Schematron validator.
|
| 120 |
+
|
| 121 |
+
Pass a root Element or an ElementTree to turn it into a validator.
|
| 122 |
+
Alternatively, pass a filename as keyword argument 'file' to parse from
|
| 123 |
+
the file system.
|
| 124 |
+
|
| 125 |
+
Schematron is a less well known, but very powerful schema language.
|
| 126 |
+
The main idea is to use the capabilities of XPath to put restrictions on
|
| 127 |
+
the structure and the content of XML documents.
|
| 128 |
+
|
| 129 |
+
The standard behaviour is to fail on ``failed-assert`` findings only
|
| 130 |
+
(``ASSERTS_ONLY``). To change this, you can either pass a report filter
|
| 131 |
+
function to the ``error_finder`` parameter (e.g. ``ASSERTS_AND_REPORTS``
|
| 132 |
+
or a custom ``XPath`` object), or subclass isoschematron.Schematron for
|
| 133 |
+
complete control of the validation process.
|
| 134 |
+
|
| 135 |
+
Built on the Schematron language 'reference' skeleton pure-xslt
|
| 136 |
+
implementation, the validator is created as an XSLT 1.0 stylesheet using
|
| 137 |
+
these steps:
|
| 138 |
+
|
| 139 |
+
0) (Extract from XML Schema or RelaxNG schema)
|
| 140 |
+
1) Process inclusions
|
| 141 |
+
2) Process abstract patterns
|
| 142 |
+
3) Compile the schematron schema to XSLT
|
| 143 |
+
|
| 144 |
+
The ``include`` and ``expand`` keyword arguments can be used to switch off
|
| 145 |
+
steps 1) and 2).
|
| 146 |
+
To set parameters for steps 1), 2) and 3) hand parameter dictionaries to the
|
| 147 |
+
keyword arguments ``include_params``, ``expand_params`` or
|
| 148 |
+
``compile_params``.
|
| 149 |
+
For convenience, the compile-step parameter ``phase`` is also exposed as a
|
| 150 |
+
keyword argument ``phase``. This takes precedence if the parameter is also
|
| 151 |
+
given in the parameter dictionary.
|
| 152 |
+
|
| 153 |
+
If ``store_schematron`` is set to True, the (included-and-expanded)
|
| 154 |
+
schematron document tree is stored and available through the ``schematron``
|
| 155 |
+
property.
|
| 156 |
+
If ``store_xslt`` is set to True, the validation XSLT document tree will be
|
| 157 |
+
stored and can be retrieved through the ``validator_xslt`` property.
|
| 158 |
+
With ``store_report`` set to True (default: False), the resulting validation
|
| 159 |
+
report document gets stored and can be accessed as the ``validation_report``
|
| 160 |
+
property.
|
| 161 |
+
|
| 162 |
+
If ``validate_schema`` is set to False, the validation of the schema file
|
| 163 |
+
itself is disabled. Validation happens by default after building the full
|
| 164 |
+
schema, unless the schema validation file cannot be found at import time,
|
| 165 |
+
in which case the validation gets disabled. Some lxml distributions exclude
|
| 166 |
+
this file due to licensing issues. ISO-Schematron validation can then still
|
| 167 |
+
be used normally, but the schemas themselves cannot be validated.
|
| 168 |
+
|
| 169 |
+
Here is a usage example::
|
| 170 |
+
|
| 171 |
+
>>> from lxml import etree
|
| 172 |
+
>>> from lxml.isoschematron import Schematron
|
| 173 |
+
|
| 174 |
+
>>> schematron = Schematron(etree.XML('''
|
| 175 |
+
... <schema xmlns="http://purl.oclc.org/dsdl/schematron" >
|
| 176 |
+
... <pattern id="id_only_attribute">
|
| 177 |
+
... <title>id is the only permitted attribute name</title>
|
| 178 |
+
... <rule context="*">
|
| 179 |
+
... <report test="@*[not(name()='id')]">Attribute
|
| 180 |
+
... <name path="@*[not(name()='id')]"/> is forbidden<name/>
|
| 181 |
+
... </report>
|
| 182 |
+
... </rule>
|
| 183 |
+
... </pattern>
|
| 184 |
+
... </schema>'''),
|
| 185 |
+
... error_finder=Schematron.ASSERTS_AND_REPORTS)
|
| 186 |
+
|
| 187 |
+
>>> xml = etree.XML('''
|
| 188 |
+
... <AAA name="aaa">
|
| 189 |
+
... <BBB id="bbb"/>
|
| 190 |
+
... <CCC color="ccc"/>
|
| 191 |
+
... </AAA>
|
| 192 |
+
... ''')
|
| 193 |
+
|
| 194 |
+
>>> schematron.validate(xml)
|
| 195 |
+
False
|
| 196 |
+
|
| 197 |
+
>>> xml = etree.XML('''
|
| 198 |
+
... <AAA id="aaa">
|
| 199 |
+
... <BBB id="bbb"/>
|
| 200 |
+
... <CCC/>
|
| 201 |
+
... </AAA>
|
| 202 |
+
... ''')
|
| 203 |
+
|
| 204 |
+
>>> schematron.validate(xml)
|
| 205 |
+
True
|
| 206 |
+
"""
|
| 207 |
+
|
| 208 |
+
# libxml2 error categorization for validation errors
|
| 209 |
+
_domain = _etree.ErrorDomains.SCHEMATRONV
|
| 210 |
+
_level = _etree.ErrorLevels.ERROR
|
| 211 |
+
_error_type = _etree.ErrorTypes.SCHEMATRONV_ASSERT
|
| 212 |
+
|
| 213 |
+
# convenience definitions for common behaviours
|
| 214 |
+
ASSERTS_ONLY = svrl_validation_errors # Default
|
| 215 |
+
ASSERTS_AND_REPORTS = _etree.XPath(
|
| 216 |
+
'//svrl:failed-assert | //svrl:successful-report',
|
| 217 |
+
namespaces={'svrl': SVRL_NS})
|
| 218 |
+
|
| 219 |
+
def _extract(self, element):
|
| 220 |
+
"""Extract embedded schematron schema from non-schematron host schema.
|
| 221 |
+
This method will only be called by __init__ if the given schema document
|
| 222 |
+
is not a schematron schema by itself.
|
| 223 |
+
Must return a schematron schema document tree or None.
|
| 224 |
+
"""
|
| 225 |
+
schematron = None
|
| 226 |
+
if element.tag == _xml_schema_root:
|
| 227 |
+
schematron = self._extract_xsd(element)
|
| 228 |
+
elif element.nsmap.get(element.prefix) == RELAXNG_NS:
|
| 229 |
+
# RelaxNG does not have a single unique root element
|
| 230 |
+
schematron = self._extract_rng(element)
|
| 231 |
+
return schematron
|
| 232 |
+
|
| 233 |
+
# customization points
|
| 234 |
+
# etree.XSLT objects that provide the extract, include, expand, compile
|
| 235 |
+
# steps
|
| 236 |
+
_extract_xsd = extract_xsd
|
| 237 |
+
_extract_rng = extract_rng
|
| 238 |
+
_include = iso_dsdl_include
|
| 239 |
+
_expand = iso_abstract_expand
|
| 240 |
+
_compile = iso_svrl_for_xslt1
|
| 241 |
+
|
| 242 |
+
# etree.xpath object that determines input document validity when applied to
|
| 243 |
+
# the svrl result report; must return a list of result elements (empty if
|
| 244 |
+
# valid)
|
| 245 |
+
_validation_errors = ASSERTS_ONLY
|
| 246 |
+
|
| 247 |
+
def __init__(self, etree=None, file=None, include=True, expand=True,
|
| 248 |
+
include_params={}, expand_params={}, compile_params={},
|
| 249 |
+
store_schematron=False, store_xslt=False, store_report=False,
|
| 250 |
+
phase=None, error_finder=ASSERTS_ONLY,
|
| 251 |
+
validate_schema=schematron_schema_valid_supported):
|
| 252 |
+
super().__init__()
|
| 253 |
+
|
| 254 |
+
self._store_report = store_report
|
| 255 |
+
self._schematron = None
|
| 256 |
+
self._validator_xslt = None
|
| 257 |
+
self._validation_report = None
|
| 258 |
+
if error_finder is not self.ASSERTS_ONLY:
|
| 259 |
+
self._validation_errors = error_finder
|
| 260 |
+
|
| 261 |
+
# parse schema document, may be a schematron schema or an XML Schema or
|
| 262 |
+
# a RelaxNG schema with embedded schematron rules
|
| 263 |
+
root = None
|
| 264 |
+
try:
|
| 265 |
+
if etree is not None:
|
| 266 |
+
if _etree.iselement(etree):
|
| 267 |
+
root = etree
|
| 268 |
+
else:
|
| 269 |
+
root = etree.getroot()
|
| 270 |
+
elif file is not None:
|
| 271 |
+
root = _etree.parse(file).getroot()
|
| 272 |
+
except Exception:
|
| 273 |
+
raise _etree.SchematronParseError(
|
| 274 |
+
"No tree or file given: %s" % sys.exc_info()[1])
|
| 275 |
+
if root is None:
|
| 276 |
+
raise ValueError("Empty tree")
|
| 277 |
+
if root.tag == _schematron_root:
|
| 278 |
+
schematron = root
|
| 279 |
+
else:
|
| 280 |
+
schematron = self._extract(root)
|
| 281 |
+
if schematron is None:
|
| 282 |
+
raise _etree.SchematronParseError(
|
| 283 |
+
"Document is not a schematron schema or schematron-extractable")
|
| 284 |
+
# perform the iso-schematron skeleton implementation steps to get a
|
| 285 |
+
# validating xslt
|
| 286 |
+
if include:
|
| 287 |
+
schematron = self._include(schematron, **include_params)
|
| 288 |
+
if expand:
|
| 289 |
+
schematron = self._expand(schematron, **expand_params)
|
| 290 |
+
if validate_schema and not schematron_schema_valid(schematron):
|
| 291 |
+
raise _etree.SchematronParseError(
|
| 292 |
+
"invalid schematron schema: %s" %
|
| 293 |
+
schematron_schema_valid.error_log)
|
| 294 |
+
if store_schematron:
|
| 295 |
+
self._schematron = schematron
|
| 296 |
+
# add new compile keyword args here if exposing them
|
| 297 |
+
compile_kwargs = {'phase': phase}
|
| 298 |
+
compile_params = _stylesheet_param_dict(compile_params, compile_kwargs)
|
| 299 |
+
validator_xslt = self._compile(schematron, **compile_params)
|
| 300 |
+
if store_xslt:
|
| 301 |
+
self._validator_xslt = validator_xslt
|
| 302 |
+
self._validator = _etree.XSLT(validator_xslt)
|
| 303 |
+
|
| 304 |
+
def __call__(self, etree):
|
| 305 |
+
"""Validate doc using Schematron.
|
| 306 |
+
|
| 307 |
+
Returns true if document is valid, false if not.
|
| 308 |
+
"""
|
| 309 |
+
self._clear_error_log()
|
| 310 |
+
result = self._validator(etree)
|
| 311 |
+
if self._store_report:
|
| 312 |
+
self._validation_report = result
|
| 313 |
+
errors = self._validation_errors(result)
|
| 314 |
+
if errors:
|
| 315 |
+
if _etree.iselement(etree):
|
| 316 |
+
fname = etree.getroottree().docinfo.URL or '<file>'
|
| 317 |
+
else:
|
| 318 |
+
fname = etree.docinfo.URL or '<file>'
|
| 319 |
+
for error in errors:
|
| 320 |
+
# Does svrl report the line number, anywhere? Don't think so.
|
| 321 |
+
self._append_log_message(
|
| 322 |
+
domain=self._domain, type=self._error_type,
|
| 323 |
+
level=self._level, line=0,
|
| 324 |
+
message=_etree.tostring(error, encoding='unicode'),
|
| 325 |
+
filename=fname)
|
| 326 |
+
return False
|
| 327 |
+
return True
|
| 328 |
+
|
| 329 |
+
@property
|
| 330 |
+
def schematron(self):
|
| 331 |
+
"""ISO-schematron schema document (None if object has been initialized
|
| 332 |
+
with store_schematron=False).
|
| 333 |
+
"""
|
| 334 |
+
return self._schematron
|
| 335 |
+
|
| 336 |
+
@property
|
| 337 |
+
def validator_xslt(self):
|
| 338 |
+
"""ISO-schematron skeleton implementation XSLT validator document (None
|
| 339 |
+
if object has been initialized with store_xslt=False).
|
| 340 |
+
"""
|
| 341 |
+
return self._validator_xslt
|
| 342 |
+
|
| 343 |
+
@property
|
| 344 |
+
def validation_report(self):
|
| 345 |
+
"""ISO-schematron validation result report (None if result-storing has
|
| 346 |
+
been turned off).
|
| 347 |
+
"""
|
| 348 |
+
return self._validation_report
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (3.8 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/__pycache__/compiler.cpython-312.pyc
ADDED
|
Binary file (4.62 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/__pycache__/driver.cpython-312.pyc
ADDED
|
Binary file (3.69 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/__init__.py
ADDED
|
File without changes
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/compiler.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from triton.backends.compiler import BaseBackend, GPUTarget, Language
|
| 2 |
+
from triton._C.libtriton import ir, passes, llvm, amd
|
| 3 |
+
from triton import knobs
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, Dict, Tuple
|
| 6 |
+
from types import ModuleType
|
| 7 |
+
import hashlib
|
| 8 |
+
import tempfile
|
| 9 |
+
import re
|
| 10 |
+
import functools
|
| 11 |
+
import warnings
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_min_dot_size(target: GPUTarget):
|
| 16 |
+
# We fallback to use FMA and cast arguments if certain configurations is
|
| 17 |
+
# not supported natively by matrix core units.
|
| 18 |
+
return lambda lhs_type, rhs_type: (1, 1, 1)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def is_pingpong_schedule_enabled(arch, use_async_copy):
|
| 22 |
+
return (arch == "gfx942" or (arch == "gfx950" and use_async_copy is True)
|
| 23 |
+
) if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def is_in_thread_transpose_enabled(arch):
|
| 27 |
+
return (arch == "gfx942") if knobs.amd.use_in_thread_transpose is None else knobs.amd.use_in_thread_transpose
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass(frozen=True)
|
| 31 |
+
class HIPOptions:
|
| 32 |
+
num_warps: int = 4
|
| 33 |
+
waves_per_eu: int = 0
|
| 34 |
+
num_stages: int = 2
|
| 35 |
+
num_ctas: int = 1
|
| 36 |
+
extern_libs: dict = None
|
| 37 |
+
debug: bool = False
|
| 38 |
+
sanitize_overflow: bool = True
|
| 39 |
+
arch: str = None
|
| 40 |
+
# We have native support for OCP fp8 variants since CDNA4/RDNA4. For earlier generations,
|
| 41 |
+
# we software emulate the support for them.
|
| 42 |
+
# UZ fp8 variants (fp8e4b8 and fp8e5b16) are natively supported for CDNA3. For other
|
| 43 |
+
# architectures they are software emulated.
|
| 44 |
+
supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e5b16", "fp8e4b8")
|
| 45 |
+
deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
|
| 46 |
+
default_dot_input_precision: str = "ieee"
|
| 47 |
+
allowed_dot_input_precisions: Tuple[str] = ("ieee", 'bf16x3', 'bf16x6')
|
| 48 |
+
enable_fp_fusion: bool = True
|
| 49 |
+
launch_cooperative_grid: bool = False
|
| 50 |
+
matrix_instr_nonkdim: int = 0
|
| 51 |
+
kpack: int = 1
|
| 52 |
+
allow_flush_denorm: bool = False
|
| 53 |
+
max_num_imprecise_acc_default: int = 0
|
| 54 |
+
backend_name: str = 'hip'
|
| 55 |
+
instrumentation_mode: str = ""
|
| 56 |
+
|
| 57 |
+
# The following option provides hints to the AMDGPU backend regarding instruction scheduling
|
| 58 |
+
# for all `tt.dot` operations in a kernel. The "none" variant preserves the default
|
| 59 |
+
# instruction scheduling of the AMDGPU backend which aims at maximizing occupancy.
|
| 60 |
+
# The option is experimental and may change at any time regarding its semantics and/or may
|
| 61 |
+
# be gone entirely anytime.
|
| 62 |
+
#
|
| 63 |
+
# Current experimental scheduling variants:
|
| 64 |
+
#
|
| 65 |
+
# attention: enables a bunch of optimizations for attention kernels, including:
|
| 66 |
+
# - iglp 2 and sched.barrier around it
|
| 67 |
+
# - sink-insts-to-avoid-spills flag to avoid register spills
|
| 68 |
+
# memory-bound-attention: enables custom scheduling strategy in llvm backend,
|
| 69 |
+
# This option targets special FA variant, which is memory bound and
|
| 70 |
+
# has a lot of elementwise operations from fused operand dequantizations.
|
| 71 |
+
# Note that this option is highly experimental,
|
| 72 |
+
# and will be removed as soon as default sceduler algorithm is fixed.
|
| 73 |
+
#
|
| 74 |
+
# Option allows to set multiple variants divided by commas:
|
| 75 |
+
# schedule_hint="attention,memory-bound-attention"
|
| 76 |
+
schedule_hint: str = 'none'
|
| 77 |
+
|
| 78 |
+
def __post_init__(self):
|
| 79 |
+
gfx_major = int(self.arch[3:-2]) # Drop "gfx" prefix and minor/patch number
|
| 80 |
+
warp_size = 32 if gfx_major >= 10 else 64
|
| 81 |
+
object.__setattr__(self, 'warp_size', warp_size)
|
| 82 |
+
assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
|
| 83 |
+
"num_warps must be a power of 2"
|
| 84 |
+
|
| 85 |
+
if (self.arch == 'gfx950') and (self.kpack != 1):
|
| 86 |
+
warnings.warn(
|
| 87 |
+
f"kpack is deprecated starting from gfx950 and will be removed in later releases. So for now kpack = {self.kpack} will be overwritten to 1 to make transitioning easier."
|
| 88 |
+
)
|
| 89 |
+
object.__setattr__(self, 'kpack', 1)
|
| 90 |
+
|
| 91 |
+
default_libdir = Path(__file__).parent / 'lib'
|
| 92 |
+
extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
|
| 93 |
+
for lib in ["ocml", "ockl"]:
|
| 94 |
+
extern_libs[lib] = str(default_libdir / f'{lib}.bc')
|
| 95 |
+
object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
|
| 96 |
+
|
| 97 |
+
def hash(self):
|
| 98 |
+
key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()])
|
| 99 |
+
return hashlib.sha256(key.encode("utf-8")).hexdigest()
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class HIPBackend(BaseBackend):
|
| 103 |
+
instrumentation = None
|
| 104 |
+
supports_native_tensor_specialization = False
|
| 105 |
+
|
| 106 |
+
@staticmethod
|
| 107 |
+
def supports_target(target: GPUTarget):
|
| 108 |
+
return target.backend == 'hip'
|
| 109 |
+
|
| 110 |
+
def __init__(self, target: GPUTarget) -> None:
|
| 111 |
+
super().__init__(target)
|
| 112 |
+
assert isinstance(target.arch, str)
|
| 113 |
+
self.binary_ext = "hsaco"
|
| 114 |
+
|
| 115 |
+
def get_target_name(self, options) -> str:
|
| 116 |
+
return f"hip:{options.arch}"
|
| 117 |
+
|
| 118 |
+
def parse_options(self, opts) -> Any:
|
| 119 |
+
args = {'arch': knobs.runtime.override_arch or self.target.arch}
|
| 120 |
+
|
| 121 |
+
if opts.get("num_ctas", 1) > 1 and not amd.supports_multi_cta_launch(self.target.arch):
|
| 122 |
+
raise ValueError(f"num_ctas > 1 not supported on {self.target.arch}")
|
| 123 |
+
|
| 124 |
+
# Enable XF32 (TF32) for CDNA3 GPUs
|
| 125 |
+
if self.target.arch == 'gfx942':
|
| 126 |
+
allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions)
|
| 127 |
+
allowed_dot_input_precisions.update({'tf32'})
|
| 128 |
+
args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions))
|
| 129 |
+
|
| 130 |
+
if "supported_fp8_dtypes" not in opts:
|
| 131 |
+
args["supported_fp8_dtypes"] = tuple(sorted(HIPOptions.supported_fp8_dtypes))
|
| 132 |
+
|
| 133 |
+
if self.target.arch == 'gfx950':
|
| 134 |
+
deprecated_fp8_dot_operand_dtypes = set(HIPOptions.deprecated_fp8_dot_operand_dtypes)
|
| 135 |
+
deprecated_fp8_dot_operand_dtypes.update({"fp8e5b16", "fp8e4b8"})
|
| 136 |
+
args["deprecated_fp8_dot_operand_dtypes"] = tuple(sorted(deprecated_fp8_dot_operand_dtypes))
|
| 137 |
+
|
| 138 |
+
if "enable_fp_fusion" not in opts:
|
| 139 |
+
args["enable_fp_fusion"] = knobs.language.default_fp_fusion
|
| 140 |
+
args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts and opts[k] is not None})
|
| 141 |
+
return HIPOptions(**args)
|
| 142 |
+
|
| 143 |
+
def pack_metadata(self, metadata):
|
| 144 |
+
return (
|
| 145 |
+
metadata.num_warps,
|
| 146 |
+
metadata.num_ctas,
|
| 147 |
+
metadata.shared,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
def get_codegen_implementation(self, options):
|
| 151 |
+
return {"min_dot_size": get_min_dot_size(self.target)}
|
| 152 |
+
|
| 153 |
+
def get_module_map(self) -> Dict[str, ModuleType]:
|
| 154 |
+
from triton.language.extra.hip import libdevice
|
| 155 |
+
|
| 156 |
+
return {"triton.language.extra.libdevice": libdevice}
|
| 157 |
+
|
| 158 |
+
def load_dialects(self, ctx):
|
| 159 |
+
amd.load_dialects(ctx)
|
| 160 |
+
if HIPBackend.instrumentation:
|
| 161 |
+
HIPBackend.instrumentation.load_dialects(ctx)
|
| 162 |
+
|
| 163 |
+
@staticmethod
|
| 164 |
+
def is_within_2gb(arg):
|
| 165 |
+
import torch
|
| 166 |
+
|
| 167 |
+
MAX_INT_32 = 2**31 - 1
|
| 168 |
+
if hasattr(arg, "ptr_range"):
|
| 169 |
+
return arg.ptr_range() <= MAX_INT_32
|
| 170 |
+
if isinstance(arg, torch.Tensor) and hasattr(arg, "untyped_storage"):
|
| 171 |
+
return arg.untyped_storage().size() <= MAX_INT_32
|
| 172 |
+
return False
|
| 173 |
+
|
| 174 |
+
@staticmethod
|
| 175 |
+
def parse_attr(desc):
|
| 176 |
+
ret = BaseBackend.parse_attr(desc)
|
| 177 |
+
if "S" in desc:
|
| 178 |
+
ret += [["tt.pointer_range", 32]]
|
| 179 |
+
return ret
|
| 180 |
+
|
| 181 |
+
@staticmethod
|
| 182 |
+
def get_tensor_specialization(arg, **kwargs):
|
| 183 |
+
ret = BaseBackend.get_tensor_specialization(arg, **kwargs)
|
| 184 |
+
if knobs.amd.use_buffer_ops and HIPBackend.is_within_2gb(arg):
|
| 185 |
+
ret += "S"
|
| 186 |
+
return ret
|
| 187 |
+
|
| 188 |
+
@staticmethod
|
| 189 |
+
def make_ttir(mod, metadata, options):
|
| 190 |
+
pm = ir.pass_manager(mod.context)
|
| 191 |
+
pm.enable_debug()
|
| 192 |
+
passes.common.add_inliner(pm)
|
| 193 |
+
passes.ttir.add_rewrite_tensor_pointer(pm)
|
| 194 |
+
passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
|
| 195 |
+
passes.common.add_canonicalizer(pm)
|
| 196 |
+
passes.ttir.add_combine(pm)
|
| 197 |
+
passes.ttir.add_reorder_broadcast(pm)
|
| 198 |
+
passes.common.add_cse(pm)
|
| 199 |
+
passes.ttir.add_triton_licm(pm)
|
| 200 |
+
passes.common.add_symbol_dce(pm)
|
| 201 |
+
passes.ttir.add_loop_unroll(pm)
|
| 202 |
+
pm.run(mod, 'make_ttir')
|
| 203 |
+
return mod
|
| 204 |
+
|
| 205 |
+
@staticmethod
|
| 206 |
+
def make_ttgir(mod, metadata, options):
|
| 207 |
+
pm = ir.pass_manager(mod.context)
|
| 208 |
+
pm.enable_debug()
|
| 209 |
+
passes.ttir.add_convert_to_ttgpuir(pm, f"hip:{options.arch}", options.num_warps, options.warp_size,
|
| 210 |
+
options.num_ctas)
|
| 211 |
+
pm.run(mod, 'make_ttgir_early')
|
| 212 |
+
pm = ir.pass_manager(mod.context)
|
| 213 |
+
pm.enable_debug()
|
| 214 |
+
emuTF32 = False
|
| 215 |
+
passes.ttgpuir.add_coalesce(pm)
|
| 216 |
+
passes.ttgpuir.add_f32_dot_tc(pm, emuTF32)
|
| 217 |
+
passes.ttgpuir.add_remove_layout_conversions(pm)
|
| 218 |
+
passes.ttgpuir.add_optimize_thread_locality(pm)
|
| 219 |
+
amd.passes.ttgpuir.add_accelerate_matmul(pm, options.arch, options.matrix_instr_nonkdim, options.kpack)
|
| 220 |
+
passes.ttgpuir.add_remove_layout_conversions(pm)
|
| 221 |
+
amd.passes.ttgpuir.add_optimize_epilogue(pm)
|
| 222 |
+
amd.passes.ttgpuir.add_optimize_dot_operands(pm, options.arch)
|
| 223 |
+
amd.passes.ttgpuir.add_hoist_layout_conversions(pm)
|
| 224 |
+
|
| 225 |
+
passes.ttgpuir.add_fuse_nested_loops(pm)
|
| 226 |
+
passes.common.add_canonicalizer(pm)
|
| 227 |
+
passes.ttir.add_triton_licm(pm)
|
| 228 |
+
passes.common.add_canonicalizer(pm)
|
| 229 |
+
|
| 230 |
+
use_async_copy = knobs.amd.use_async_copy
|
| 231 |
+
use_block_pingpong = is_pingpong_schedule_enabled(options.arch, use_async_copy)
|
| 232 |
+
|
| 233 |
+
amd.passes.ttgpuir.add_schedule_loops(pm, options.num_stages)
|
| 234 |
+
amd.passes.ttgpuir.add_pipeline(pm, use_async_copy, use_block_pingpong)
|
| 235 |
+
if use_async_copy:
|
| 236 |
+
amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
|
| 237 |
+
passes.common.add_canonicalizer(pm)
|
| 238 |
+
if options.schedule_hint.lower() != "none":
|
| 239 |
+
for hint in options.schedule_hint.split(","):
|
| 240 |
+
amd.passes.ttgpuir.insert_instruction_sched_hints(pm, hint)
|
| 241 |
+
passes.ttgpuir.add_remove_layout_conversions(pm)
|
| 242 |
+
passes.ttgpuir.add_reduce_data_duplication(pm)
|
| 243 |
+
if is_in_thread_transpose_enabled(options.arch):
|
| 244 |
+
amd.passes.ttgpuir.add_in_thread_transpose(pm)
|
| 245 |
+
passes.ttgpuir.add_remove_layout_conversions(pm)
|
| 246 |
+
amd.passes.ttgpuir.add_reorder_instructions(pm)
|
| 247 |
+
if use_block_pingpong and options.num_stages > 1:
|
| 248 |
+
amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)
|
| 249 |
+
|
| 250 |
+
if knobs.amd.use_buffer_ops:
|
| 251 |
+
amd.passes.ttgpuir.add_canonicalize_pointers(pm)
|
| 252 |
+
passes.common.add_canonicalizer(pm)
|
| 253 |
+
amd.passes.ttgpuir.add_convert_to_buffer_ops(
|
| 254 |
+
pm,
|
| 255 |
+
options.arch,
|
| 256 |
+
knobs.amd.use_buffer_atomics,
|
| 257 |
+
knobs.amd.buffer_ops_analyze_small_tensor_range,
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
amd.passes.ttgpuir.add_fold_true_cmpi(pm)
|
| 261 |
+
passes.common.add_canonicalizer(pm)
|
| 262 |
+
passes.common.add_cse(pm)
|
| 263 |
+
passes.common.add_symbol_dce(pm)
|
| 264 |
+
pm.run(mod, 'make_ttgir')
|
| 265 |
+
metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
|
| 266 |
+
return mod
|
| 267 |
+
|
| 268 |
+
@staticmethod
|
| 269 |
+
def gluon_to_ttgir(src, metadata, options):
|
| 270 |
+
mod = src
|
| 271 |
+
pm = ir.pass_manager(mod.context)
|
| 272 |
+
pm.enable_debug()
|
| 273 |
+
|
| 274 |
+
passes.gluon.add_inliner(pm)
|
| 275 |
+
passes.gluon.add_resolve_auto_encodings(pm)
|
| 276 |
+
passes.common.add_sccp(pm)
|
| 277 |
+
passes.ttir.add_loop_aware_cse(pm)
|
| 278 |
+
passes.gluon.add_canonicalizer(pm)
|
| 279 |
+
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
|
| 280 |
+
|
| 281 |
+
pm.run(mod, 'gluon_to_ttgir')
|
| 282 |
+
metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
|
| 283 |
+
return mod
|
| 284 |
+
|
| 285 |
+
@staticmethod
|
| 286 |
+
def make_llir(src, metadata, options):
|
| 287 |
+
mod = src
|
| 288 |
+
# TritonGPU -> LLVM-IR (MLIR)
|
| 289 |
+
pm = ir.pass_manager(mod.context)
|
| 290 |
+
pm.enable_debug()
|
| 291 |
+
amd.passes.ttgpuir.add_update_async_wait_count(pm, options.arch)
|
| 292 |
+
# custom_lds_size is an experimental parameter that defines amount of LDS available
|
| 293 |
+
# for one thread block. Measured in bytes.
|
| 294 |
+
#
|
| 295 |
+
# If custom_lds_size = 0, pass will consider all LDS is available for one threads block,
|
| 296 |
+
# LDS size is determined by provided arch name.
|
| 297 |
+
custom_lds_size = 0
|
| 298 |
+
amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size)
|
| 299 |
+
passes.convert.add_scf_to_cf(pm)
|
| 300 |
+
passes.gluon.add_inliner(pm)
|
| 301 |
+
passes.convert.add_index_to_llvmir(pm)
|
| 302 |
+
|
| 303 |
+
amd.passes.ttgpuir.add_allocate_shared_memory(pm)
|
| 304 |
+
# instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
|
| 305 |
+
if HIPBackend.instrumentation:
|
| 306 |
+
HIPBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
|
| 307 |
+
## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
|
| 308 |
+
## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
|
| 309 |
+
## of the value of kernel arg `allow_flush_denorm`.
|
| 310 |
+
## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output
|
| 311 |
+
## depends on the value of kernel arg `allow_flush_denorm`.
|
| 312 |
+
## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
|
| 313 |
+
## For now it is used as a controller for developers only.
|
| 314 |
+
__HIP_FTZ = True
|
| 315 |
+
amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
|
| 316 |
+
passes.common.add_canonicalizer(pm)
|
| 317 |
+
passes.common.add_cse(pm)
|
| 318 |
+
|
| 319 |
+
passes.convert.add_cf_to_llvmir(pm)
|
| 320 |
+
passes.convert.add_arith_to_llvmir(pm)
|
| 321 |
+
passes.common.add_canonicalizer(pm)
|
| 322 |
+
passes.common.add_cse(pm)
|
| 323 |
+
passes.common.add_symbol_dce(pm)
|
| 324 |
+
|
| 325 |
+
if options.schedule_hint.lower() != "none":
|
| 326 |
+
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
|
| 327 |
+
|
| 328 |
+
# This can not be moved below the di_scope pass
|
| 329 |
+
if HIPBackend.instrumentation:
|
| 330 |
+
HIPBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
|
| 331 |
+
|
| 332 |
+
if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables:
|
| 333 |
+
passes.llvmir.add_di_scope(pm)
|
| 334 |
+
|
| 335 |
+
amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
|
| 336 |
+
pm.run(mod, 'make_llir')
|
| 337 |
+
|
| 338 |
+
if knobs.compilation.dump_ir_extract_di_local_variables:
|
| 339 |
+
# comments below on why separate it
|
| 340 |
+
if not knobs.compilation.disable_line_info:
|
| 341 |
+
pm = ir.pass_manager(mod.context)
|
| 342 |
+
pm.enable_debug()
|
| 343 |
+
passes.llvmir.add_di_scope(pm)
|
| 344 |
+
pm.run(mod, 'make_llir.disable_line_info')
|
| 345 |
+
|
| 346 |
+
# insert dbg intrinsic with several DI Attribute including source
|
| 347 |
+
# var name and type info note: unknown reason for now, but this
|
| 348 |
+
# pass and add_di_scope has to be run separately, otherwise if we
|
| 349 |
+
# put them into previous pipline, it trigger a segmentfault without
|
| 350 |
+
# any error message; could be due to a bug in mlir or pybind11
|
| 351 |
+
pm = ir.pass_manager(mod.context)
|
| 352 |
+
pm.enable_debug()
|
| 353 |
+
passes.llvmir.add_di_local_variable(pm)
|
| 354 |
+
pm.run(mod, 'make_llir.dump_ir_extract_di_local_variables')
|
| 355 |
+
|
| 356 |
+
# LLVM-IR (MLIR) -> LLVM-IR (LLVM)
|
| 357 |
+
llvm.init_targets()
|
| 358 |
+
context = llvm.context()
|
| 359 |
+
llvm_mod = llvm.to_module(mod, context)
|
| 360 |
+
amd.attach_target_triple(llvm_mod)
|
| 361 |
+
target_features = ''
|
| 362 |
+
if knobs.compilation.enable_asan:
|
| 363 |
+
target_features = '+xnack'
|
| 364 |
+
llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, target_features)
|
| 365 |
+
|
| 366 |
+
# Set various control constants on the LLVM module so that device
|
| 367 |
+
# libraries can resolve references to them.
|
| 368 |
+
amd.set_isa_version(llvm_mod, options.arch)
|
| 369 |
+
amd.set_abi_version(llvm_mod, 500)
|
| 370 |
+
amd.set_bool_control_constant(llvm_mod, "__oclc_finite_only_opt", False)
|
| 371 |
+
amd.set_bool_control_constant(llvm_mod, "__oclc_correctly_rounded_sqrt32", True)
|
| 372 |
+
amd.set_bool_control_constant(llvm_mod, "__oclc_unsafe_math_opt", False)
|
| 373 |
+
amd.set_bool_control_constant(llvm_mod, "__oclc_wavefrontsize64", options.warp_size == 64)
|
| 374 |
+
|
| 375 |
+
# Set kernel attributes first given this may affect later optimizations.
|
| 376 |
+
fns = [fn for fn in llvm_mod.get_functions() if not fn.is_declaration()]
|
| 377 |
+
# The public kernel should be kernel 0.
|
| 378 |
+
fns[0].set_calling_conv(amd.CALLING_CONV_AMDGPU_KERNEL)
|
| 379 |
+
fns[0].add_fn_attr("amdgpu-flat-work-group-size", f"1,{options.num_warps*options.warp_size}")
|
| 380 |
+
if "memory-bound-attention" in options.schedule_hint.split(','):
|
| 381 |
+
fns[0].add_fn_attr("amdgpu-sched-strategy", "iterative-ilp")
|
| 382 |
+
fns[0].add_fn_attr("uniform-work-group-size", "true")
|
| 383 |
+
# LLVM AMDGPU backend supports the attribute "amdgpu-waves-per-eu"="<min>[, <max>]".
|
| 384 |
+
# This attribute may be attached to a kernel function definition and is an optimization hint.
|
| 385 |
+
# <min> parameter specifies the requested minimum number of waves per EU, and optional <max> parameter
|
| 386 |
+
# specifies the requested maximum number of waves per EU (must be >= <min> if specified).
|
| 387 |
+
# If <max> is omitted, then there is no restriction on the maximum number of waves per EU other than
|
| 388 |
+
# the one dictated by the hardware for which the kernel is compiled. Passing 0, 0 as <min>, <max>
|
| 389 |
+
# implies the default behavior (no limits).
|
| 390 |
+
# Specifying N, N forces LLVM to focus on a single register count, simplifies some heuristics
|
| 391 |
+
# and may improve scheduling.
|
| 392 |
+
fns[0].add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}, {options.waves_per_eu}")
|
| 393 |
+
denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee"
|
| 394 |
+
fns[0].add_fn_attr("denormal-fp-math-f32", denormal_mode)
|
| 395 |
+
if knobs.compilation.enable_asan:
|
| 396 |
+
fns[0].add_fn_target_feature("+xnack")
|
| 397 |
+
fns[0].add_fn_asan_attr()
|
| 398 |
+
|
| 399 |
+
# Hint the compiler that we'd like the firmware to set the kernel arguments
|
| 400 |
+
# to user SGPRs so that the kernel does not need to s_load its arguments
|
| 401 |
+
# from memory.
|
| 402 |
+
amd.set_all_fn_arg_inreg(fns[0])
|
| 403 |
+
|
| 404 |
+
if knobs.compilation.enable_asan:
|
| 405 |
+
default_libdir = Path(__file__).parent / 'lib'
|
| 406 |
+
paths = [
|
| 407 |
+
str(default_libdir / 'asanrtl.bc'),
|
| 408 |
+
str(default_libdir / "ocml.bc"),
|
| 409 |
+
str(default_libdir / "ockl.bc")
|
| 410 |
+
]
|
| 411 |
+
llvm.link_extern_libs(llvm_mod, paths)
|
| 412 |
+
elif options.extern_libs:
|
| 413 |
+
paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)]
|
| 414 |
+
if len(paths) > 0:
|
| 415 |
+
llvm.link_extern_libs(llvm_mod, paths)
|
| 416 |
+
|
| 417 |
+
llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)
|
| 418 |
+
|
| 419 |
+
# Architectures with architected SGPRs store the workgroup id in ttmp9 (X) and ttmp7 (Y[15:0], Z[31:16]).
|
| 420 |
+
# These attributes are used to determine if Z should be masked out when loading Y. They are inferred during
|
| 421 |
+
# optimize_module from calls to @llvm.amdgcn.workgroup.id.x/y/z(). We cannot rely on this because a
|
| 422 |
+
# dispatch dimensions might be used even if there is no program_id() call for it.
|
| 423 |
+
if amd.has_architected_sgprs(options.arch):
|
| 424 |
+
fns[0].remove_fn_attr("amdgpu-no-workgroup-id-x")
|
| 425 |
+
fns[0].remove_fn_attr("amdgpu-no-workgroup-id-y")
|
| 426 |
+
fns[0].remove_fn_attr("amdgpu-no-workgroup-id-z")
|
| 427 |
+
|
| 428 |
+
if knobs.amd.scalarize_packed_fops:
|
| 429 |
+
amd.add_scalarize_packed_fops_llvm_pass(fns[0])
|
| 430 |
+
|
| 431 |
+
# Get some metadata
|
| 432 |
+
metadata["shared"] = src.get_int_attr("ttg.shared")
|
| 433 |
+
metadata["profile_scratch_size"] = src.get_int_attr("ttg.profile_scratch_memory_size") or 0
|
| 434 |
+
metadata["profile_scratch_align"] = src.get_int_attr("ttg.profile_scratch_memory_alignment") or 1
|
| 435 |
+
|
| 436 |
+
amd.cleanup_bitcode_metadata(llvm_mod)
|
| 437 |
+
# Disable inlining of print related functions,
|
| 438 |
+
# because inlining of these function could slow down compilation significantly
|
| 439 |
+
amd.disable_print_inline(llvm_mod)
|
| 440 |
+
return str(llvm_mod)
|
| 441 |
+
|
| 442 |
+
@staticmethod
|
| 443 |
+
def make_amdgcn(src, metadata, options):
|
| 444 |
+
# Find kernel names (there should only be one)
|
| 445 |
+
# We get the name at the last possible step to accommodate `triton.compile`
|
| 446 |
+
# on user-provided LLVM
|
| 447 |
+
names = re.findall(r"define amdgpu_kernel void @([a-zA-Z_][a-zA-Z0-9_]*)", src)
|
| 448 |
+
assert len(names) == 1
|
| 449 |
+
metadata["name"] = names[0]
|
| 450 |
+
# llvm -> hsaco
|
| 451 |
+
flags = []
|
| 452 |
+
features = '-real-true16' if 'gfx11' in options.arch else ''
|
| 453 |
+
ir_hash = hashlib.sha256(src.encode("utf-8")).hexdigest()
|
| 454 |
+
dump_file_id = names[0] + '_' + ir_hash
|
| 455 |
+
_ = llvm.translate_to_mir(src, amd.TARGET_TRIPLE, options.arch, features, flags, options.enable_fp_fusion,
|
| 456 |
+
dump_file_id)
|
| 457 |
+
llvm.dump_sched_dag(src, amd.TARGET_TRIPLE, options.arch, features, flags, options.enable_fp_fusion,
|
| 458 |
+
dump_file_id)
|
| 459 |
+
amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, features, flags, options.enable_fp_fusion,
|
| 460 |
+
False)
|
| 461 |
+
if knobs.amd.dump_amdgcn:
|
| 462 |
+
print("// -----// AMDGCN Dump //----- //")
|
| 463 |
+
print(amdgcn)
|
| 464 |
+
return amdgcn
|
| 465 |
+
|
| 466 |
+
@staticmethod
|
| 467 |
+
def make_hsaco(src, metadata, options):
|
| 468 |
+
target_features = ''
|
| 469 |
+
if knobs.compilation.enable_asan:
|
| 470 |
+
target_features = '+xnack'
|
| 471 |
+
hsaco = amd.assemble_amdgcn(src, options.arch, target_features)
|
| 472 |
+
with tempfile.NamedTemporaryFile() as tmp_out:
|
| 473 |
+
with tempfile.NamedTemporaryFile() as tmp_in:
|
| 474 |
+
with open(tmp_in.name, "wb") as fd_in:
|
| 475 |
+
fd_in.write(hsaco)
|
| 476 |
+
amd.link_hsaco(tmp_in.name, tmp_out.name)
|
| 477 |
+
with open(tmp_out.name, "rb") as fd_out:
|
| 478 |
+
ret = fd_out.read()
|
| 479 |
+
return ret
|
| 480 |
+
|
| 481 |
+
def add_stages(self, stages, options, language):
|
| 482 |
+
if language == Language.TRITON:
|
| 483 |
+
stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
|
| 484 |
+
stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options)
|
| 485 |
+
elif language == Language.GLUON:
|
| 486 |
+
stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options)
|
| 487 |
+
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
|
| 488 |
+
stages["amdgcn"] = lambda src, metadata: self.make_amdgcn(src, metadata, options)
|
| 489 |
+
stages["hsaco"] = lambda src, metadata: self.make_hsaco(src, metadata, options)
|
| 490 |
+
if knobs.runtime.add_stages_inspection_hook is not None:
|
| 491 |
+
knobs.runtime.add_stages_inspection_hook(self, stages, options, language, None)
|
| 492 |
+
|
| 493 |
+
@functools.lru_cache()
|
| 494 |
+
def hash(self):
|
| 495 |
+
return f'{self.target}'
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/driver.c
ADDED
|
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#define __HIP_PLATFORM_AMD__
|
| 2 |
+
#include <hip/hip_runtime.h>
|
| 3 |
+
#include <hip/hip_runtime_api.h>
|
| 4 |
+
#define PY_SSIZE_T_CLEAN
|
| 5 |
+
#include <Python.h>
|
| 6 |
+
#include <dlfcn.h>
|
| 7 |
+
#include <stdbool.h>
|
| 8 |
+
#include <stdio.h>
|
| 9 |
+
#include <stdlib.h>
|
| 10 |
+
|
| 11 |
+
typedef struct {
|
| 12 |
+
uint32_t group0_0;
|
| 13 |
+
uint32_t group0_1;
|
| 14 |
+
uint32_t group0_2;
|
| 15 |
+
uint32_t group0_3;
|
| 16 |
+
uint32_t group1_0;
|
| 17 |
+
uint32_t group1_1;
|
| 18 |
+
uint32_t group1_2;
|
| 19 |
+
uint32_t group1_3;
|
| 20 |
+
uint32_t group1_4;
|
| 21 |
+
uint32_t group1_5;
|
| 22 |
+
uint32_t group1_6;
|
| 23 |
+
uint32_t group1_7;
|
| 24 |
+
} TDMDescriptor;
|
| 25 |
+
|
| 26 |
+
typedef struct {
|
| 27 |
+
PyObject_HEAD;
|
| 28 |
+
TDMDescriptor desc;
|
| 29 |
+
} PyTDMDescriptorObject;
|
| 30 |
+
|
| 31 |
+
static PyObject *PyTDMDescriptor_new(PyTypeObject *type, PyObject *args,
|
| 32 |
+
PyObject *kw) {
|
| 33 |
+
PyTDMDescriptorObject *self =
|
| 34 |
+
(PyTDMDescriptorObject *)type->tp_alloc(type, 0);
|
| 35 |
+
if (!self)
|
| 36 |
+
return NULL;
|
| 37 |
+
|
| 38 |
+
memset(&self->desc, 0, sizeof(self->desc));
|
| 39 |
+
return (PyObject *)self;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
static void PyTDMDescriptor_dealloc(PyTDMDescriptorObject *self) {
|
| 43 |
+
Py_TYPE(self)->tp_free((PyObject *)self);
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
static PyTypeObject PyTDMDescriptorType = {
|
| 47 |
+
PyVarObject_HEAD_INIT(NULL, 0).tp_name =
|
| 48 |
+
"triton.backends.amd.PyTDMDescriptor",
|
| 49 |
+
.tp_basicsize = sizeof(PyTDMDescriptorObject),
|
| 50 |
+
.tp_itemsize = 0,
|
| 51 |
+
.tp_flags = Py_TPFLAGS_DEFAULT,
|
| 52 |
+
.tp_doc = "PyObject for TDMDescriptor",
|
| 53 |
+
.tp_new = PyTDMDescriptor_new,
|
| 54 |
+
.tp_dealloc = (destructor)PyTDMDescriptor_dealloc,
|
| 55 |
+
};
|
| 56 |
+
|
| 57 |
+
// TODO: Both host-side and device-side TDM descriptor follow the same encoding
|
| 58 |
+
// format. Consider to add a common utility to remove duplicate code.
|
| 59 |
+
static bool encodeTDMDescriptor(TDMDescriptor *desc, int elementBitWidth,
|
| 60 |
+
uint32_t *blockSize, int numWarps,
|
| 61 |
+
int padInterval, int padAmount, uint32_t *shape,
|
| 62 |
+
uint32_t *strides, uint64_t globalAddress,
|
| 63 |
+
int rank) {
|
| 64 |
+
// NYI: TDM > 2D cases
|
| 65 |
+
if (rank != 2)
|
| 66 |
+
return false;
|
| 67 |
+
|
| 68 |
+
// Get warp distribution
|
| 69 |
+
uint32_t numWarpsDim0 = numWarps;
|
| 70 |
+
for (; numWarpsDim0 > blockSize[0]; numWarpsDim0 /= 2)
|
| 71 |
+
;
|
| 72 |
+
uint32_t numWarpsDim1 = numWarps / numWarpsDim0;
|
| 73 |
+
if (!(numWarpsDim0 > 0 && blockSize[1] % numWarpsDim1 == 0))
|
| 74 |
+
return false;
|
| 75 |
+
|
| 76 |
+
uint32_t blockSize0 = (blockSize[0] + numWarpsDim0 - 1) / numWarpsDim0;
|
| 77 |
+
uint32_t blockSize1 = (blockSize[1] + numWarpsDim1 - 1) / numWarpsDim1;
|
| 78 |
+
|
| 79 |
+
// group0 (128 bits / 4 dwords) effective bit encoding:
|
| 80 |
+
// [120:64]: global address
|
| 81 |
+
// [127:126]: type - currently always set to 0x2
|
| 82 |
+
desc->group0_2 = (uint32_t)(globalAddress & 0xFFFFFFFF);
|
| 83 |
+
desc->group0_3 = (uint32_t)((globalAddress >> 32) & 0x01FFFFFF);
|
| 84 |
+
desc->group0_3 |= (0x1 << 31);
|
| 85 |
+
|
| 86 |
+
// group1 (256 bits / 8 dwords) effective bit encoding:
|
| 87 |
+
// [17:16]: data size - log2(element size in bytes)
|
| 88 |
+
// [20]: enable padding
|
| 89 |
+
// [24:22]: pad interval - log2(pad interval in dwords) - 1
|
| 90 |
+
// [31:25]: pad amount - pad amount in dwords - 1
|
| 91 |
+
// [79:48]: tensor shape dim inner
|
| 92 |
+
// [111:80]: tensor shape dim outer
|
| 93 |
+
// [127:112]: block shape dim inner
|
| 94 |
+
// [143:128]: block shape dim outer
|
| 95 |
+
// [207:160]: tensor stride dim outer (we only use 32 bits)
|
| 96 |
+
int elementSizeInBytes = elementBitWidth / 8;
|
| 97 |
+
int dataSize = log2(elementSizeInBytes);
|
| 98 |
+
desc->group1_0 = (dataSize << 16);
|
| 99 |
+
int dwordSize = 32;
|
| 100 |
+
int padIntervalInDwords = padInterval * elementBitWidth / dwordSize;
|
| 101 |
+
int padAmountInDwords = padAmount * elementBitWidth / dwordSize;
|
| 102 |
+
if (padIntervalInDwords > 0 && padAmountInDwords > 0) {
|
| 103 |
+
int log2PadInterval = log2(padIntervalInDwords);
|
| 104 |
+
desc->group1_0 |= (1 << 20);
|
| 105 |
+
desc->group1_0 |= ((log2PadInterval - 1) << 22);
|
| 106 |
+
desc->group1_0 |= ((padAmountInDwords - 1) << 25);
|
| 107 |
+
}
|
| 108 |
+
desc->group1_1 = (shape[1] << 16);
|
| 109 |
+
desc->group1_2 = (shape[1] >> 16);
|
| 110 |
+
desc->group1_2 |= (shape[0] << 16);
|
| 111 |
+
desc->group1_3 = (shape[0] >> 16);
|
| 112 |
+
desc->group1_3 |= (blockSize1 << 16);
|
| 113 |
+
desc->group1_4 = (blockSize0 & 0xFFFF);
|
| 114 |
+
desc->group1_5 = strides[0];
|
| 115 |
+
|
| 116 |
+
return true;
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
// The list of paths to search for the HIP runtime library. The caller Python
|
| 120 |
+
// code should substitute the search path placeholder.
|
| 121 |
+
static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"};
|
| 122 |
+
|
| 123 |
+
// The list of HIP dynamic library symbols and their signature we are interested
|
| 124 |
+
// in this file.
|
| 125 |
+
// |FOR_EACH_ERR_FN| is a macro to process APIs that return hipError_t;
|
| 126 |
+
// |FOR_EACH_STR_FN| is a macro to process APIs that return const char *.
|
| 127 |
+
#define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \
|
| 128 |
+
FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError) \
|
| 129 |
+
FOR_EACH_ERR_FN(hipGetDeviceProperties, hipDeviceProp_t *prop, int deviceId) \
|
| 130 |
+
FOR_EACH_ERR_FN(hipModuleLoadDataEx, hipModule_t *module, const void *image, \
|
| 131 |
+
unsigned int numOptions, hipJitOption *options, \
|
| 132 |
+
void **optionValues) \
|
| 133 |
+
FOR_EACH_ERR_FN(hipModuleGetFunction, hipFunction_t *function, \
|
| 134 |
+
hipModule_t module, const char *kname) \
|
| 135 |
+
FOR_EACH_ERR_FN(hipFuncGetAttribute, int *, hipFunction_attribute attr, \
|
| 136 |
+
hipFunction_t function)
|
| 137 |
+
|
| 138 |
+
// HIP driver version format: HIP_VERSION_MAJOR * 10000000 + HIP_VERSION_MINOR *
|
| 139 |
+
// 100000 + HIP_VERSION_PATCH.
|
| 140 |
+
#define TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version) ((version) / 10000000)
|
| 141 |
+
#define TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version) \
|
| 142 |
+
(((version) % 10000000) / 100000)
|
| 143 |
+
#define TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version) ((version) % 100000)
|
| 144 |
+
#define TRITON_HIP_DRIVER_REQ_MAJOR_VERSION (6)
|
| 145 |
+
|
| 146 |
+
// #define TRITON_HIP_DRIVER_DBG_VERSION
|
| 147 |
+
#ifdef TRITON_HIP_DRIVER_DBG_VERSION
|
| 148 |
+
#define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff) \
|
| 149 |
+
do { \
|
| 150 |
+
snprintf(msgBuff, sizeof(msgBuff), "libamdhip64 version is: %d.%d.%d", \
|
| 151 |
+
TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version), \
|
| 152 |
+
TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version), \
|
| 153 |
+
TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version)); \
|
| 154 |
+
printf("%s\n", msgBuff); \
|
| 155 |
+
} while (0);
|
| 156 |
+
#else
|
| 157 |
+
#define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff) \
|
| 158 |
+
do { \
|
| 159 |
+
(void)msgBuff; \
|
| 160 |
+
(void)(version); \
|
| 161 |
+
} while (0);
|
| 162 |
+
#endif
|
| 163 |
+
|
| 164 |
+
#define TRITON_HIP_MSG_BUFF_SIZE (1024U)
|
| 165 |
+
|
| 166 |
+
// The HIP symbol table for holding resolved dynamic library symbols.
|
| 167 |
+
struct HIPSymbolTable {
|
| 168 |
+
#define DEFINE_EACH_ERR_FIELD(hipSymbolName, ...) \
|
| 169 |
+
hipError_t (*hipSymbolName)(__VA_ARGS__);
|
| 170 |
+
#define DEFINE_EACH_STR_FIELD(hipSymbolName, ...) \
|
| 171 |
+
const char *(*hipSymbolName)(__VA_ARGS__);
|
| 172 |
+
|
| 173 |
+
HIP_SYMBOL_LIST(DEFINE_EACH_ERR_FIELD, DEFINE_EACH_STR_FIELD)
|
| 174 |
+
};
|
| 175 |
+
|
| 176 |
+
static struct HIPSymbolTable hipSymbolTable;
|
| 177 |
+
|
| 178 |
+
static int checkDriverVersion(void *lib) {
|
| 179 |
+
int hipVersion = -1;
|
| 180 |
+
const char *error = NULL;
|
| 181 |
+
typedef hipError_t (*hipDriverGetVersion_fn)(int *driverVersion);
|
| 182 |
+
hipDriverGetVersion_fn hipDriverGetVersion;
|
| 183 |
+
dlerror(); // Clear existing errors
|
| 184 |
+
hipDriverGetVersion =
|
| 185 |
+
(hipDriverGetVersion_fn)dlsym(lib, "hipDriverGetVersion");
|
| 186 |
+
error = dlerror();
|
| 187 |
+
if (error) {
|
| 188 |
+
PyErr_SetString(PyExc_RuntimeError,
|
| 189 |
+
"cannot query 'hipDriverGetVersion' from libamdhip64.so");
|
| 190 |
+
dlclose(lib);
|
| 191 |
+
return -1;
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
(void)hipDriverGetVersion(&hipVersion);
|
| 195 |
+
char msgBuff[TRITON_HIP_MSG_BUFF_SIZE] = {0};
|
| 196 |
+
|
| 197 |
+
const int hipMajVersion = TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(hipVersion);
|
| 198 |
+
if (hipMajVersion < TRITON_HIP_DRIVER_REQ_MAJOR_VERSION) {
|
| 199 |
+
const int hipMinVersion =
|
| 200 |
+
TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(hipVersion);
|
| 201 |
+
const int hipPatchVersion =
|
| 202 |
+
TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(hipVersion);
|
| 203 |
+
snprintf(msgBuff, sizeof(msgBuff),
|
| 204 |
+
"libamdhip64 version %d.%d.%d is not supported! Required major "
|
| 205 |
+
"version is >=%d.",
|
| 206 |
+
hipMajVersion, hipMinVersion, hipPatchVersion,
|
| 207 |
+
TRITON_HIP_DRIVER_REQ_MAJOR_VERSION);
|
| 208 |
+
PyErr_SetString(PyExc_RuntimeError, msgBuff);
|
| 209 |
+
dlclose(lib);
|
| 210 |
+
return -1;
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
TRITON_HIP_DRIVER_LOG_VERSION(hipVersion, msgBuff);
|
| 214 |
+
|
| 215 |
+
return hipVersion;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
bool initSymbolTable() {
|
| 219 |
+
void *lib;
|
| 220 |
+
|
| 221 |
+
// Go through the list of search paths to dlopen the first HIP driver library.
|
| 222 |
+
int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
|
| 223 |
+
for (int i = 0; i < n; ++i) {
|
| 224 |
+
void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
|
| 225 |
+
if (handle) {
|
| 226 |
+
lib = handle;
|
| 227 |
+
// printf("[triton] chosen %s\n", hipLibSearchPaths[i]);
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
if (!lib) {
|
| 232 |
+
PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so");
|
| 233 |
+
return false;
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
int hipVersion = checkDriverVersion(lib);
|
| 237 |
+
if (hipVersion == -1)
|
| 238 |
+
return false;
|
| 239 |
+
|
| 240 |
+
const char *error = NULL;
|
| 241 |
+
typedef hipError_t (*hipGetProcAddress_fn)(
|
| 242 |
+
const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
|
| 243 |
+
hipDriverProcAddressQueryResult *symbolStatus);
|
| 244 |
+
hipGetProcAddress_fn hipGetProcAddress;
|
| 245 |
+
dlerror(); // Clear existing errors
|
| 246 |
+
|
| 247 |
+
*(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
|
| 248 |
+
error = dlerror();
|
| 249 |
+
if (error) {
|
| 250 |
+
PyErr_SetString(PyExc_RuntimeError,
|
| 251 |
+
"cannot query 'hipGetProcAddress' from libamdhip64.so");
|
| 252 |
+
dlclose(lib);
|
| 253 |
+
return false;
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
// Resolve all symbols we are interested in.
|
| 257 |
+
uint64_t hipFlags = 0;
|
| 258 |
+
hipDriverProcAddressQueryResult symbolStatus;
|
| 259 |
+
hipError_t status = hipSuccess;
|
| 260 |
+
#define QUERY_EACH_FN(hipSymbolName, ...) \
|
| 261 |
+
status = hipGetProcAddress(#hipSymbolName, \
|
| 262 |
+
(void **)&hipSymbolTable.hipSymbolName, \
|
| 263 |
+
hipVersion, hipFlags, &symbolStatus); \
|
| 264 |
+
if (status != hipSuccess) { \
|
| 265 |
+
PyErr_SetString(PyExc_RuntimeError, \
|
| 266 |
+
"cannot get address for '" #hipSymbolName \
|
| 267 |
+
"' from libamdhip64.so"); \
|
| 268 |
+
dlclose(lib); \
|
| 269 |
+
return false; \
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)
|
| 273 |
+
|
| 274 |
+
return true;
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
static inline void gpuAssert(hipError_t code, const char *file, int line) {
|
| 278 |
+
{
|
| 279 |
+
if (code != HIP_SUCCESS) {
|
| 280 |
+
{
|
| 281 |
+
const char *prefix = "Triton Error [HIP]: ";
|
| 282 |
+
const char *str = hipSymbolTable.hipGetErrorString(code);
|
| 283 |
+
char err[TRITON_HIP_MSG_BUFF_SIZE] = {0};
|
| 284 |
+
snprintf(err, sizeof(err), "%s Code: %d, Messsage: %s", prefix, code,
|
| 285 |
+
str);
|
| 286 |
+
PyGILState_STATE gil_state;
|
| 287 |
+
gil_state = PyGILState_Ensure();
|
| 288 |
+
PyErr_SetString(PyExc_RuntimeError, err);
|
| 289 |
+
PyGILState_Release(gil_state);
|
| 290 |
+
}
|
| 291 |
+
}
|
| 292 |
+
}
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
#define HIP_CHECK(ans) \
|
| 296 |
+
{ \
|
| 297 |
+
gpuAssert((ans), __FILE__, __LINE__); \
|
| 298 |
+
if (PyErr_Occurred()) \
|
| 299 |
+
return NULL; \
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
|
| 303 |
+
int device_id;
|
| 304 |
+
if (!PyArg_ParseTuple(args, "i", &device_id))
|
| 305 |
+
return NULL;
|
| 306 |
+
|
| 307 |
+
hipDeviceProp_t props;
|
| 308 |
+
HIP_CHECK(hipSymbolTable.hipGetDeviceProperties(&props, device_id));
|
| 309 |
+
|
| 310 |
+
// create a struct to hold device properties
|
| 311 |
+
return Py_BuildValue(
|
| 312 |
+
"{s:i, s:i, s:i, s:i, s:i, s:i, s:s, s:i, s:i}", "max_shared_mem",
|
| 313 |
+
props.sharedMemPerBlock, "max_num_regs", props.regsPerBlock,
|
| 314 |
+
"multiprocessor_count", props.multiProcessorCount, "sm_clock_rate",
|
| 315 |
+
props.clockRate, "mem_clock_rate", props.memoryClockRate, "mem_bus_width",
|
| 316 |
+
props.memoryBusWidth, "arch", props.gcnArchName, "warpSize",
|
| 317 |
+
props.warpSize, "max_threads_per_sm", props.maxThreadsPerMultiProcessor);
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
static PyObject *loadBinary(PyObject *self, PyObject *args) {
|
| 321 |
+
const char *name;
|
| 322 |
+
const char *data;
|
| 323 |
+
Py_ssize_t data_size;
|
| 324 |
+
int shared;
|
| 325 |
+
int device;
|
| 326 |
+
if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
|
| 327 |
+
&device)) {
|
| 328 |
+
return NULL;
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
// set HIP options
|
| 332 |
+
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
|
| 333 |
+
hipJitOptionErrorLogBuffer,
|
| 334 |
+
hipJitOptionInfoLogBufferSizeBytes,
|
| 335 |
+
hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
|
| 336 |
+
const unsigned int errbufsize = 8192;
|
| 337 |
+
const unsigned int logbufsize = 8192;
|
| 338 |
+
char _err[errbufsize];
|
| 339 |
+
char _log[logbufsize];
|
| 340 |
+
void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err,
|
| 341 |
+
(void *)(uintptr_t)logbufsize, (void *)_log, (void *)1};
|
| 342 |
+
|
| 343 |
+
// launch HIP Binary
|
| 344 |
+
hipModule_t mod;
|
| 345 |
+
hipFunction_t fun;
|
| 346 |
+
HIP_CHECK(hipSymbolTable.hipModuleLoadDataEx(&mod, data, 5, opt, optval))
|
| 347 |
+
HIP_CHECK(hipSymbolTable.hipModuleGetFunction(&fun, mod, name));
|
| 348 |
+
|
| 349 |
+
// get allocated registers and spilled registers from the function
|
| 350 |
+
int n_regs = 0;
|
| 351 |
+
int n_spills = 0;
|
| 352 |
+
int32_t n_max_threads = 0;
|
| 353 |
+
hipSymbolTable.hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, fun);
|
| 354 |
+
hipSymbolTable.hipFuncGetAttribute(&n_spills,
|
| 355 |
+
HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
|
| 356 |
+
hipSymbolTable.hipFuncGetAttribute(
|
| 357 |
+
&n_max_threads, HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun);
|
| 358 |
+
n_spills /= 4;
|
| 359 |
+
if (PyErr_Occurred()) {
|
| 360 |
+
return NULL;
|
| 361 |
+
}
|
| 362 |
+
return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
|
| 363 |
+
n_spills, n_max_threads);
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
static PyObject *createTDMDescriptor(PyObject *self, PyObject *args) {
|
| 367 |
+
int elementBitWidth;
|
| 368 |
+
PyObject *blockSize;
|
| 369 |
+
int numWarps;
|
| 370 |
+
int padInterval;
|
| 371 |
+
int padAmount;
|
| 372 |
+
PyObject *shape;
|
| 373 |
+
PyObject *strides;
|
| 374 |
+
unsigned long long globalAddress;
|
| 375 |
+
|
| 376 |
+
if (!PyArg_ParseTuple(args, "iOiiiOOK", &elementBitWidth, &blockSize,
|
| 377 |
+
&numWarps, &padInterval, &padAmount, &shape, &strides,
|
| 378 |
+
&globalAddress)) {
|
| 379 |
+
return NULL;
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
PyTDMDescriptorObject *descObj = (PyTDMDescriptorObject *)PyObject_CallObject(
|
| 383 |
+
(PyObject *)&PyTDMDescriptorType, NULL);
|
| 384 |
+
if (!descObj)
|
| 385 |
+
return NULL;
|
| 386 |
+
|
| 387 |
+
PyObject *blockSizeFast = NULL;
|
| 388 |
+
PyObject *shapeFast = NULL;
|
| 389 |
+
PyObject *stridesFast = NULL;
|
| 390 |
+
|
| 391 |
+
uint32_t blockSizeInt[2];
|
| 392 |
+
uint32_t shapeInt[2];
|
| 393 |
+
uint32_t stridesInt[2];
|
| 394 |
+
|
| 395 |
+
blockSizeFast = PySequence_Fast(blockSize, "blockSize must be a sequence");
|
| 396 |
+
if (!blockSizeFast)
|
| 397 |
+
goto cleanup;
|
| 398 |
+
int rank = PySequence_Fast_GET_SIZE(blockSizeFast);
|
| 399 |
+
if (rank != 2) {
|
| 400 |
+
PyErr_SetString(PyExc_RuntimeError, "rank must be 2");
|
| 401 |
+
goto cleanup;
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
for (int i = 0; i < rank; ++i) {
|
| 405 |
+
PyObject *item = PySequence_Fast_GET_ITEM(blockSizeFast, i);
|
| 406 |
+
if (!PyLong_Check(item)) {
|
| 407 |
+
PyErr_SetString(PyExc_TypeError, "block size must be an int");
|
| 408 |
+
goto cleanup;
|
| 409 |
+
}
|
| 410 |
+
blockSizeInt[i] = PyLong_AsLong(item);
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
shapeFast = PySequence_Fast(shape, "shape must be a sequence");
|
| 414 |
+
if (!shapeFast)
|
| 415 |
+
goto cleanup;
|
| 416 |
+
|
| 417 |
+
if (rank != PySequence_Fast_GET_SIZE(shapeFast)) {
|
| 418 |
+
PyErr_SetString(PyExc_RuntimeError, "rank mismatch");
|
| 419 |
+
goto cleanup;
|
| 420 |
+
}
|
| 421 |
+
for (int i = 0; i < rank; ++i) {
|
| 422 |
+
PyObject *item = PySequence_Fast_GET_ITEM(shapeFast, i);
|
| 423 |
+
if (!PyLong_Check(item)) {
|
| 424 |
+
PyErr_SetString(PyExc_TypeError, "shape must be an int");
|
| 425 |
+
goto cleanup;
|
| 426 |
+
}
|
| 427 |
+
shapeInt[i] = PyLong_AsLong(item);
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
stridesFast = PySequence_Fast(strides, "strides must be a sequence");
|
| 431 |
+
if (!stridesFast)
|
| 432 |
+
goto cleanup;
|
| 433 |
+
|
| 434 |
+
if (rank != PySequence_Fast_GET_SIZE(stridesFast)) {
|
| 435 |
+
PyErr_SetString(PyExc_RuntimeError, "rank mismatch");
|
| 436 |
+
goto cleanup;
|
| 437 |
+
}
|
| 438 |
+
for (int i = 0; i < rank; ++i) {
|
| 439 |
+
PyObject *item = PySequence_Fast_GET_ITEM(stridesFast, i);
|
| 440 |
+
if (!PyLong_Check(item)) {
|
| 441 |
+
PyErr_SetString(PyExc_TypeError, "shape must be an int");
|
| 442 |
+
goto cleanup;
|
| 443 |
+
}
|
| 444 |
+
stridesInt[i] = PyLong_AsLong(item);
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
Py_DECREF(blockSizeFast);
|
| 448 |
+
blockSizeFast = NULL;
|
| 449 |
+
Py_DECREF(shapeFast);
|
| 450 |
+
shapeFast = NULL;
|
| 451 |
+
Py_DECREF(stridesFast);
|
| 452 |
+
stridesFast = NULL;
|
| 453 |
+
|
| 454 |
+
bool success = encodeTDMDescriptor(
|
| 455 |
+
&descObj->desc, elementBitWidth, blockSizeInt, numWarps, padInterval,
|
| 456 |
+
padAmount, shapeInt, stridesInt, globalAddress, rank);
|
| 457 |
+
if (!success) {
|
| 458 |
+
PyErr_SetString(PyExc_RuntimeError, "Failed to encode TDM descriptor");
|
| 459 |
+
goto cleanup;
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
return (PyObject *)descObj;
|
| 463 |
+
|
| 464 |
+
cleanup:
|
| 465 |
+
Py_XDECREF(blockSizeFast);
|
| 466 |
+
Py_XDECREF(shapeFast);
|
| 467 |
+
Py_XDECREF(stridesFast);
|
| 468 |
+
Py_XDECREF(descObj);
|
| 469 |
+
return NULL;
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
static PyMethodDef ModuleMethods[] = {
|
| 473 |
+
{"load_binary", loadBinary, METH_VARARGS,
|
| 474 |
+
"Load provided hsaco into HIP driver"},
|
| 475 |
+
{"get_device_properties", getDeviceProperties, METH_VARARGS,
|
| 476 |
+
"Get the properties for a given device"},
|
| 477 |
+
{"create_tdm_descriptor", createTDMDescriptor, METH_VARARGS,
|
| 478 |
+
"create a host-side TDM descriptor"},
|
| 479 |
+
{NULL, NULL, 0, NULL} // sentinel
|
| 480 |
+
};
|
| 481 |
+
|
| 482 |
+
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "hip_utils",
|
| 483 |
+
NULL, // documentation
|
| 484 |
+
-1, // size
|
| 485 |
+
ModuleMethods};
|
| 486 |
+
|
| 487 |
+
PyMODINIT_FUNC PyInit_hip_utils(void) {
|
| 488 |
+
if (!initSymbolTable()) {
|
| 489 |
+
return NULL;
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
PyObject *m = PyModule_Create(&ModuleDef);
|
| 493 |
+
if (m == NULL) {
|
| 494 |
+
return NULL;
|
| 495 |
+
}
|
| 496 |
+
PyModule_AddFunctions(m, ModuleMethods);
|
| 497 |
+
|
| 498 |
+
if (PyType_Ready(&PyTDMDescriptorType) < 0)
|
| 499 |
+
return NULL;
|
| 500 |
+
Py_INCREF(&PyTDMDescriptorType);
|
| 501 |
+
PyModule_AddObject(m, "PyTDMDescriptor", (PyObject *)&PyTDMDescriptorType);
|
| 502 |
+
|
| 503 |
+
return m;
|
| 504 |
+
}
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/driver.py
ADDED
|
@@ -0,0 +1,877 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import os
|
| 3 |
+
import subprocess
|
| 4 |
+
import re
|
| 5 |
+
import triton
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from triton import knobs
|
| 8 |
+
from triton.backends.compiler import GPUTarget
|
| 9 |
+
from triton.backends.driver import GPUDriver
|
| 10 |
+
from triton.runtime import _allocation
|
| 11 |
+
from triton.runtime.build import compile_module_from_src
|
| 12 |
+
|
| 13 |
+
dirname = os.path.dirname(os.path.realpath(__file__))
|
| 14 |
+
include_dirs = [os.path.join(dirname, "include")]
|
| 15 |
+
PyTDMDescriptor = None
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _find_already_mmapped_dylib_on_linux(lib_name):
|
| 19 |
+
import platform
|
| 20 |
+
if platform.system() != 'Linux':
|
| 21 |
+
return None
|
| 22 |
+
|
| 23 |
+
# Use dl_iterate_phdr to walk through the list of shared libraries at runtime.
|
| 24 |
+
# See https://www.man7.org/linux/man-pages/man3/dl_iterate_phdr.3.html for details.
|
| 25 |
+
|
| 26 |
+
import ctypes
|
| 27 |
+
from ctypes import c_char, c_int, c_size_t, c_void_p, c_char_p, POINTER
|
| 28 |
+
|
| 29 |
+
class DlPhdrInfo(ctypes.Structure):
|
| 30 |
+
_fields_ = [
|
| 31 |
+
('dlpi_addr', c_void_p),
|
| 32 |
+
('dlpi_name', c_char_p),
|
| 33 |
+
# We don't care about the remaining fields.
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
# callback_t must use POINTER(c_char) to avoid copying.
|
| 37 |
+
callback_t = ctypes.CFUNCTYPE(c_int, POINTER(DlPhdrInfo), POINTER(c_size_t), POINTER(c_char))
|
| 38 |
+
|
| 39 |
+
# Load libc and get the dl_iterate_phdr symbol.
|
| 40 |
+
try:
|
| 41 |
+
dl_iterate_phdr = ctypes.CDLL('libc.so.6').dl_iterate_phdr
|
| 42 |
+
except Exception:
|
| 43 |
+
return None
|
| 44 |
+
# argtypes must use c_char_p to accept create_string_buffer.
|
| 45 |
+
dl_iterate_phdr.argtypes = [callback_t, c_char_p]
|
| 46 |
+
dl_iterate_phdr.restype = c_int
|
| 47 |
+
|
| 48 |
+
max_path_length = 4096
|
| 49 |
+
path = ctypes.create_string_buffer(max_path_length + 1)
|
| 50 |
+
|
| 51 |
+
# Define callback to get the loaded dylib path.
|
| 52 |
+
def callback(info, size, data):
|
| 53 |
+
dlpi_name = info.contents.dlpi_name
|
| 54 |
+
p = Path(os.fsdecode(dlpi_name))
|
| 55 |
+
if lib_name in p.name:
|
| 56 |
+
# Found the dylib; get its path.
|
| 57 |
+
ctypes.memmove(data, dlpi_name, min(max_path_length, len(dlpi_name)))
|
| 58 |
+
return 1
|
| 59 |
+
return 0
|
| 60 |
+
|
| 61 |
+
if dl_iterate_phdr(callback_t(callback), path):
|
| 62 |
+
return os.fsdecode(ctypes.string_at(path))
|
| 63 |
+
return None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@functools.lru_cache()
|
| 67 |
+
def _get_path_to_hip_runtime_dylib():
|
| 68 |
+
lib_name = "libamdhip64.so"
|
| 69 |
+
|
| 70 |
+
# If we are told explicitly what HIP runtime dynamic library to use, obey that.
|
| 71 |
+
if env_libhip_path := knobs.amd.libhip_path:
|
| 72 |
+
if env_libhip_path.endswith(lib_name) and os.path.exists(env_libhip_path):
|
| 73 |
+
return env_libhip_path
|
| 74 |
+
raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")
|
| 75 |
+
|
| 76 |
+
# If the shared object is already mmapped to address space, use it.
|
| 77 |
+
mmapped_path = _find_already_mmapped_dylib_on_linux(lib_name)
|
| 78 |
+
if mmapped_path:
|
| 79 |
+
if os.path.exists(mmapped_path):
|
| 80 |
+
return mmapped_path
|
| 81 |
+
raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}")
|
| 82 |
+
|
| 83 |
+
paths = []
|
| 84 |
+
|
| 85 |
+
# Check backend
|
| 86 |
+
local_lib = os.path.join(os.path.dirname(__file__), "lib", lib_name)
|
| 87 |
+
if os.path.exists(local_lib):
|
| 88 |
+
return local_lib
|
| 89 |
+
paths.append(local_lib)
|
| 90 |
+
|
| 91 |
+
import site
|
| 92 |
+
# First search the HIP runtime dynamic library packaged with PyTorch. It's very likely
|
| 93 |
+
# that we run Triton together with PyTorch. This makes sure we use the same dynamic
|
| 94 |
+
# library to avoid version mismatch.
|
| 95 |
+
site_packages = site.getsitepackages()
|
| 96 |
+
user_site = site.getusersitepackages()
|
| 97 |
+
if site.ENABLE_USER_SITE: # ENABLE_USER_SITE is initialized in getusersitepackages()
|
| 98 |
+
site_packages = [user_site] + site_packages
|
| 99 |
+
for path in site_packages:
|
| 100 |
+
path = os.path.join(path, "torch", "lib", lib_name)
|
| 101 |
+
if os.path.exists(path):
|
| 102 |
+
return path
|
| 103 |
+
paths.append(path)
|
| 104 |
+
|
| 105 |
+
# Then try to see if developer provides a HIP runtime dynamic library using LD_LIBARAY_PATH.
|
| 106 |
+
env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
|
| 107 |
+
if env_ld_library_path:
|
| 108 |
+
for d in env_ld_library_path.split(":"):
|
| 109 |
+
f = os.path.join(d, lib_name)
|
| 110 |
+
if os.path.exists(f):
|
| 111 |
+
return f
|
| 112 |
+
paths.append(f)
|
| 113 |
+
|
| 114 |
+
# HIP_PATH should point to HIP SDK root if set
|
| 115 |
+
env_hip_path = os.getenv("HIP_PATH")
|
| 116 |
+
if env_hip_path:
|
| 117 |
+
hip_lib_path = os.path.join(env_hip_path, "lib", lib_name)
|
| 118 |
+
if os.path.exists(hip_lib_path):
|
| 119 |
+
return hip_lib_path
|
| 120 |
+
paths.append(hip_lib_path)
|
| 121 |
+
|
| 122 |
+
# if available, `hipconfig --path` prints the HIP SDK root
|
| 123 |
+
try:
|
| 124 |
+
hip_root = subprocess.check_output(["hipconfig", "--path"]).decode().strip()
|
| 125 |
+
if hip_root:
|
| 126 |
+
hip_lib_path = os.path.join(hip_root, "lib", lib_name)
|
| 127 |
+
if os.path.exists(hip_lib_path):
|
| 128 |
+
return hip_lib_path
|
| 129 |
+
paths.append(hip_lib_path)
|
| 130 |
+
except (subprocess.CalledProcessError, FileNotFoundError):
|
| 131 |
+
# hipconfig may not be available
|
| 132 |
+
pass
|
| 133 |
+
|
| 134 |
+
# ROCm lib dir based on env var
|
| 135 |
+
env_rocm_path = os.getenv("ROCM_PATH")
|
| 136 |
+
if env_rocm_path:
|
| 137 |
+
rocm_lib_path = os.path.join(env_rocm_path, "lib", lib_name)
|
| 138 |
+
if os.path.exists(rocm_lib_path):
|
| 139 |
+
return rocm_lib_path
|
| 140 |
+
paths.append(rocm_lib_path)
|
| 141 |
+
|
| 142 |
+
# Afterwards try to search the loader dynamic library resolution paths.
|
| 143 |
+
libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
|
| 144 |
+
# each line looks like the following:
|
| 145 |
+
# libamdhip64.so.6 (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so.6
|
| 146 |
+
# libamdhip64.so (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so
|
| 147 |
+
locs = [line.split()[-1] for line in libs.splitlines() if line.strip().endswith(lib_name)]
|
| 148 |
+
for loc in locs:
|
| 149 |
+
if os.path.exists(loc):
|
| 150 |
+
return loc
|
| 151 |
+
paths.append(loc)
|
| 152 |
+
|
| 153 |
+
# As a last resort, guess if we have it in some common installation path.
|
| 154 |
+
common_install_path = os.path.join('/opt/rocm/lib/', lib_name)
|
| 155 |
+
if os.path.exists(common_install_path):
|
| 156 |
+
return common_install_path
|
| 157 |
+
paths.append(common_install_path)
|
| 158 |
+
|
| 159 |
+
raise RuntimeError(f"cannot locate {lib_name} after attempted paths {paths}")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class HIPUtils(object):
|
| 163 |
+
|
| 164 |
+
def __new__(cls):
|
| 165 |
+
if not hasattr(cls, "instance"):
|
| 166 |
+
cls.instance = super(HIPUtils, cls).__new__(cls)
|
| 167 |
+
return cls.instance
|
| 168 |
+
|
| 169 |
+
def __init__(self):
|
| 170 |
+
libhip_path = _get_path_to_hip_runtime_dylib()
|
| 171 |
+
src = Path(os.path.join(dirname, "driver.c")).read_text()
|
| 172 |
+
# Just do a simple search and replace here instead of templates or format strings.
|
| 173 |
+
# This way we don't need to escape-quote C code curly brackets and we can replace
|
| 174 |
+
# exactly once.
|
| 175 |
+
src = src.replace('/*py_libhip_search_path*/', libhip_path, 1)
|
| 176 |
+
mod = compile_module_from_src(src=src, name="hip_utils", include_dirs=include_dirs)
|
| 177 |
+
self.load_binary = mod.load_binary
|
| 178 |
+
self.get_device_properties = mod.get_device_properties
|
| 179 |
+
self.create_tdm_descriptor = mod.create_tdm_descriptor
|
| 180 |
+
global PyTDMDescriptor
|
| 181 |
+
PyTDMDescriptor = mod.PyTDMDescriptor
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# -------------------- Launcher ----------------------------
|
| 185 |
+
def ty_to_cpp(ty):
|
| 186 |
+
if ty.startswith('*'):
|
| 187 |
+
return "hipDeviceptr_t"
|
| 188 |
+
if ty == "tensordesc":
|
| 189 |
+
return "TDMDescriptor"
|
| 190 |
+
return {
|
| 191 |
+
"i1": "int8_t",
|
| 192 |
+
"i8": "int8_t",
|
| 193 |
+
"i16": "int16_t",
|
| 194 |
+
"i32": "int32_t",
|
| 195 |
+
"i64": "int64_t",
|
| 196 |
+
"u1": "uint8_t",
|
| 197 |
+
"u8": "uint8_t",
|
| 198 |
+
"u16": "uint16_t",
|
| 199 |
+
"u32": "uint32_t",
|
| 200 |
+
"u64": "uint64_t",
|
| 201 |
+
"fp16": "double",
|
| 202 |
+
"bf16": "double",
|
| 203 |
+
"fp32": "double",
|
| 204 |
+
"f32": "double",
|
| 205 |
+
"fp64": "double",
|
| 206 |
+
}[ty]
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
FLOAT_STORAGE_TYPE = {
|
| 210 |
+
"fp16": "uint16_t",
|
| 211 |
+
"bf16": "uint16_t",
|
| 212 |
+
"fp32": "uint32_t",
|
| 213 |
+
"f32": "uint32_t",
|
| 214 |
+
"fp64": "uint64_t",
|
| 215 |
+
}
|
| 216 |
+
FLOAT_PACK_FUNCTION = {
|
| 217 |
+
"fp16": "pack_fp16",
|
| 218 |
+
"bf16": "pack_bf16",
|
| 219 |
+
"fp32": "pack_fp32",
|
| 220 |
+
"f32": "pack_fp32",
|
| 221 |
+
"fp64": "pack_fp64",
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
_BASE_ARGS_FORMAT = "piiiKKOOOOO"
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def make_launcher(constants, signature, warp_size, tensordesc_meta):
|
| 228 |
+
|
| 229 |
+
def _expand_signature(signature):
|
| 230 |
+
output = []
|
| 231 |
+
tensordesc_idx = 0
|
| 232 |
+
for sig in signature:
|
| 233 |
+
if isinstance(sig, str) and sig.startswith("tensordesc"):
|
| 234 |
+
meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
|
| 235 |
+
tensordesc_idx += 1
|
| 236 |
+
|
| 237 |
+
match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
|
| 238 |
+
dtype = match.group(1)
|
| 239 |
+
shape = match.group(2)
|
| 240 |
+
ndim = shape.count(",") + 1
|
| 241 |
+
|
| 242 |
+
# If there is no descriptor's metadata, the descriptor has been decomposed to base pointer, shape and strides
|
| 243 |
+
if meta is None:
|
| 244 |
+
output.append("*" + dtype)
|
| 245 |
+
for _ in range(2 * ndim):
|
| 246 |
+
output.append("i64")
|
| 247 |
+
output.append("i1")
|
| 248 |
+
else:
|
| 249 |
+
output.append("tensordesc")
|
| 250 |
+
|
| 251 |
+
for _ in range(ndim):
|
| 252 |
+
output.append("i32")
|
| 253 |
+
for _ in range(ndim):
|
| 254 |
+
output.append("i64")
|
| 255 |
+
else:
|
| 256 |
+
output.append(sig)
|
| 257 |
+
|
| 258 |
+
return output
|
| 259 |
+
|
| 260 |
+
def _serialize_signature(sig):
|
| 261 |
+
if isinstance(sig, tuple):
|
| 262 |
+
return ','.join(map(_serialize_signature, sig))
|
| 263 |
+
return sig
|
| 264 |
+
|
| 265 |
+
def _extracted_type(ty):
|
| 266 |
+
if isinstance(ty, tuple):
|
| 267 |
+
val = ','.join(map(_extracted_type, ty))
|
| 268 |
+
return f"[{val}]"
|
| 269 |
+
if ty.startswith("*") or ty.startswith("tensordesc"):
|
| 270 |
+
return "PyObject*"
|
| 271 |
+
if ty == "constexpr":
|
| 272 |
+
return "PyObject*"
|
| 273 |
+
return ty_to_cpp(ty)
|
| 274 |
+
|
| 275 |
+
def format_of(ty):
|
| 276 |
+
if isinstance(ty, tuple):
|
| 277 |
+
val = ''.join(map(format_of, ty))
|
| 278 |
+
return f"({val})"
|
| 279 |
+
if ty.startswith("*") or ty.startswith("tensordesc"):
|
| 280 |
+
return "O"
|
| 281 |
+
if ty == "constexpr":
|
| 282 |
+
return "O"
|
| 283 |
+
return {
|
| 284 |
+
"double": "d",
|
| 285 |
+
"long": "l",
|
| 286 |
+
"int8_t": "b",
|
| 287 |
+
"int16_t": "h",
|
| 288 |
+
"int32_t": "i",
|
| 289 |
+
"int64_t": "L",
|
| 290 |
+
"uint8_t": "B",
|
| 291 |
+
"uint16_t": "H",
|
| 292 |
+
"uint32_t": "I",
|
| 293 |
+
"uint64_t": "K",
|
| 294 |
+
}[ty_to_cpp(ty)]
|
| 295 |
+
|
| 296 |
+
signature = {idx: s for idx, s in enumerate(_expand_signature(signature.values()))}
|
| 297 |
+
|
| 298 |
+
args_format = ''.join([format_of(ty) for ty in signature.values()])
|
| 299 |
+
format = _BASE_ARGS_FORMAT + args_format
|
| 300 |
+
signature = ','.join(map(_serialize_signature, signature.values()))
|
| 301 |
+
signature = list(filter(bool, signature.split(',')))
|
| 302 |
+
signature = {i: s for i, s in enumerate(signature)}
|
| 303 |
+
args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
|
| 304 |
+
# Record the end of regular arguments;
|
| 305 |
+
# subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
|
| 306 |
+
arg_decl_list = []
|
| 307 |
+
for i, ty in signature.items():
|
| 308 |
+
if ty == "constexpr":
|
| 309 |
+
continue
|
| 310 |
+
if ty in FLOAT_STORAGE_TYPE:
|
| 311 |
+
arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
|
| 312 |
+
else:
|
| 313 |
+
arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
|
| 314 |
+
arg_decls = ', '.join(arg_decl_list)
|
| 315 |
+
internal_args_list = []
|
| 316 |
+
for i, ty in signature.items():
|
| 317 |
+
if ty.startswith("*"):
|
| 318 |
+
internal_args_list.append(f"ptr_info{i}.dev_ptr")
|
| 319 |
+
elif ty.startswith("tensordesc"):
|
| 320 |
+
internal_args_list.append(f"*desc{i}")
|
| 321 |
+
elif ty in FLOAT_STORAGE_TYPE:
|
| 322 |
+
internal_args_list.append(f"_arg{i}_storage")
|
| 323 |
+
elif ty != "constexpr":
|
| 324 |
+
internal_args_list.append(f"_arg{i}")
|
| 325 |
+
|
| 326 |
+
newline = '\n '
|
| 327 |
+
ptr_decls = [
|
| 328 |
+
f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;"
|
| 329 |
+
for i, ty in signature.items()
|
| 330 |
+
if ty.startswith("*")
|
| 331 |
+
]
|
| 332 |
+
tensor_desc_decls = [
|
| 333 |
+
f"TDMDescriptor* desc{i} = getTDMDescriptor(_arg{i}, {i});" for i, ty in signature.items()
|
| 334 |
+
if ty.startswith("tensordesc")
|
| 335 |
+
]
|
| 336 |
+
float_storage_decls = [
|
| 337 |
+
f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
|
| 338 |
+
for i, ty in signature.items()
|
| 339 |
+
if ty in FLOAT_STORAGE_TYPE
|
| 340 |
+
]
|
| 341 |
+
|
| 342 |
+
libhip_path = _get_path_to_hip_runtime_dylib()
|
| 343 |
+
|
| 344 |
+
# generate glue code
|
| 345 |
+
params = list(range(len(signature)))
|
| 346 |
+
params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
|
| 347 |
+
params.append("&global_scratch")
|
| 348 |
+
params.append("&profile_scratch")
|
| 349 |
+
src = f"""
|
| 350 |
+
#define __HIP_PLATFORM_AMD__
|
| 351 |
+
#include <hip/hip_runtime.h>
|
| 352 |
+
#include <hip/hip_runtime_api.h>
|
| 353 |
+
#include <Python.h>
|
| 354 |
+
#include <dlfcn.h>
|
| 355 |
+
#include <stdbool.h>
|
| 356 |
+
#include <dlfcn.h>
|
| 357 |
+
|
| 358 |
+
typedef struct {{
|
| 359 |
+
uint32_t group0_0;
|
| 360 |
+
uint32_t group0_1;
|
| 361 |
+
uint32_t group0_2;
|
| 362 |
+
uint32_t group0_3;
|
| 363 |
+
uint32_t group1_0;
|
| 364 |
+
uint32_t group1_1;
|
| 365 |
+
uint32_t group1_2;
|
| 366 |
+
uint32_t group1_3;
|
| 367 |
+
uint32_t group1_4;
|
| 368 |
+
uint32_t group1_5;
|
| 369 |
+
uint32_t group1_6;
|
| 370 |
+
uint32_t group1_7;
|
| 371 |
+
}} TDMDescriptor;
|
| 372 |
+
|
| 373 |
+
typedef struct {{
|
| 374 |
+
PyObject_HEAD;
|
| 375 |
+
TDMDescriptor desc;
|
| 376 |
+
}} PyTDMDescriptorObject;
|
| 377 |
+
|
| 378 |
+
// The list of paths to search for the HIP runtime library. The caller Python
|
| 379 |
+
// code should substitute the search path placeholder.
|
| 380 |
+
static const char *hipLibSearchPaths[] = {{"{libhip_path}"}};
|
| 381 |
+
|
| 382 |
+
// The list of HIP dynamic library symbols and their signature we are interested
|
| 383 |
+
// in this file.
|
| 384 |
+
#define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \\
|
| 385 |
+
FOR_EACH_STR_FN(hipGetLastError, true) \\
|
| 386 |
+
FOR_EACH_STR_FN(hipGetErrorString, true, hipError_t hipError) \\
|
| 387 |
+
FOR_EACH_ERR_FN(hipDrvLaunchKernelEx, false, \\
|
| 388 |
+
const HIP_LAUNCH_CONFIG *config, \\
|
| 389 |
+
hipFunction_t f, \\
|
| 390 |
+
void **kernelParams, \\
|
| 391 |
+
void **extra) \\
|
| 392 |
+
FOR_EACH_ERR_FN(hipModuleLaunchKernel, true, hipFunction_t f, \\
|
| 393 |
+
unsigned int gridDimX, unsigned int gridDimY, \\
|
| 394 |
+
unsigned int gridDimZ, unsigned int blockDimX, \\
|
| 395 |
+
unsigned int blockDimY, unsigned int blockDimZ, \\
|
| 396 |
+
unsigned int sharedMemBytes, hipStream_t stream, \\
|
| 397 |
+
void **kernelParams, void **extra) \\
|
| 398 |
+
FOR_EACH_ERR_FN(hipModuleLaunchCooperativeKernel, true, hipFunction_t f, \\
|
| 399 |
+
unsigned int gridDimX, unsigned int gridDimY, \\
|
| 400 |
+
unsigned int gridDimZ, unsigned int blockDimX, \\
|
| 401 |
+
unsigned int blockDimY, unsigned int blockDimZ, \\
|
| 402 |
+
unsigned int sharedMemBytes, hipStream_t stream, \\
|
| 403 |
+
void **kernelParams, void **extra) \\
|
| 404 |
+
FOR_EACH_ERR_FN(hipPointerGetAttribute, true, void *data, \\
|
| 405 |
+
hipPointer_attribute attribute, hipDeviceptr_t ptr)
|
| 406 |
+
|
| 407 |
+
// The HIP symbol table for holding resolved dynamic library symbols.
|
| 408 |
+
struct HIPSymbolTable {{
|
| 409 |
+
#define DEFINE_EACH_ERR_FIELD(hipSymbolName, required, ...) \\
|
| 410 |
+
hipError_t (*hipSymbolName)(__VA_ARGS__);
|
| 411 |
+
#define DEFINE_EACH_STR_FIELD(hipSymbolName, required, ...) \\
|
| 412 |
+
const char *(*hipSymbolName)(__VA_ARGS__);
|
| 413 |
+
|
| 414 |
+
HIP_SYMBOL_LIST(DEFINE_EACH_ERR_FIELD, DEFINE_EACH_STR_FIELD)
|
| 415 |
+
}};
|
| 416 |
+
|
| 417 |
+
static struct HIPSymbolTable hipSymbolTable;
|
| 418 |
+
|
| 419 |
+
bool initSymbolTable() {{
|
| 420 |
+
// Use the HIP runtime library loaded into the existing process if it exits.
|
| 421 |
+
void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD);
|
| 422 |
+
|
| 423 |
+
// Otherwise, go through the list of search paths to dlopen the first HIP
|
| 424 |
+
// driver library.
|
| 425 |
+
if (!lib) {{
|
| 426 |
+
int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
|
| 427 |
+
for (int i = 0; i < n; ++i) {{
|
| 428 |
+
void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
|
| 429 |
+
if (handle) {{
|
| 430 |
+
lib = handle;
|
| 431 |
+
}}
|
| 432 |
+
}}
|
| 433 |
+
}}
|
| 434 |
+
if (!lib) {{
|
| 435 |
+
PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so");
|
| 436 |
+
return false;
|
| 437 |
+
}}
|
| 438 |
+
|
| 439 |
+
typedef hipError_t (*hipGetProcAddress_fn)(
|
| 440 |
+
const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
|
| 441 |
+
hipDriverProcAddressQueryResult *symbolStatus);
|
| 442 |
+
hipGetProcAddress_fn hipGetProcAddress;
|
| 443 |
+
dlerror(); // Clear existing errors
|
| 444 |
+
const char *error = NULL;
|
| 445 |
+
*(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
|
| 446 |
+
error = dlerror();
|
| 447 |
+
if (error) {{
|
| 448 |
+
PyErr_SetString(PyExc_RuntimeError,
|
| 449 |
+
"cannot query 'hipGetProcAddress' from libamdhip64.so");
|
| 450 |
+
dlclose(lib);
|
| 451 |
+
return false;
|
| 452 |
+
}}
|
| 453 |
+
|
| 454 |
+
// Resolve all symbols we are interested in.
|
| 455 |
+
int hipVersion = HIP_VERSION;
|
| 456 |
+
uint64_t hipFlags = 0;
|
| 457 |
+
hipDriverProcAddressQueryResult symbolStatus;
|
| 458 |
+
hipError_t status = hipSuccess;
|
| 459 |
+
#define QUERY_EACH_FN(hipSymbolName, required, ...) \
|
| 460 |
+
status = hipGetProcAddress(#hipSymbolName, \
|
| 461 |
+
(void **)&hipSymbolTable.hipSymbolName, \
|
| 462 |
+
hipVersion, hipFlags, &symbolStatus); \
|
| 463 |
+
if (required && status != hipSuccess) {{ \
|
| 464 |
+
PyErr_SetString(PyExc_RuntimeError, \
|
| 465 |
+
"cannot get address for '" #hipSymbolName \
|
| 466 |
+
"' from libamdhip64.so"); \
|
| 467 |
+
dlclose(lib); \
|
| 468 |
+
return false; \
|
| 469 |
+
}}
|
| 470 |
+
|
| 471 |
+
HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)
|
| 472 |
+
|
| 473 |
+
return true;
|
| 474 |
+
}}
|
| 475 |
+
|
| 476 |
+
static inline void gpuAssert(hipError_t code, const char *file, int line)
|
| 477 |
+
{{
|
| 478 |
+
if (code != HIP_SUCCESS)
|
| 479 |
+
{{
|
| 480 |
+
const char* prefix = "Triton Error [HIP]: ";
|
| 481 |
+
const char* str = hipSymbolTable.hipGetErrorString(code);
|
| 482 |
+
char err[1024] = {{0}};
|
| 483 |
+
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
|
| 484 |
+
PyErr_SetString(PyExc_RuntimeError, err);
|
| 485 |
+
}}
|
| 486 |
+
}}
|
| 487 |
+
|
| 488 |
+
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
| 489 |
+
|
| 490 |
+
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
|
| 491 |
+
if (gridX * gridY * gridZ == 0)
|
| 492 |
+
return;
|
| 493 |
+
hipDeviceptr_t global_scratch = 0;
|
| 494 |
+
void *params[] = {{ {', '.join(params)} }};
|
| 495 |
+
if(num_ctas > 1) {{
|
| 496 |
+
if (!hipSymbolTable.hipDrvLaunchKernelEx) {{
|
| 497 |
+
PyErr_SetString(PyExc_RuntimeError, "missing hipDrvLaunchKernelEx symbol; please update HIP runtime");
|
| 498 |
+
return;
|
| 499 |
+
}}
|
| 500 |
+
|
| 501 |
+
hipLaunchAttribute attributes[2];
|
| 502 |
+
// Attribute0: Cluster dimensions
|
| 503 |
+
attributes[0].id = 4;
|
| 504 |
+
int *cluster_dims = (int*)attributes[0].val.pad;
|
| 505 |
+
cluster_dims[0] = num_ctas;
|
| 506 |
+
cluster_dims[1] = 1;
|
| 507 |
+
cluster_dims[2] = 1;
|
| 508 |
+
// Attribute1: Cooperative launch
|
| 509 |
+
attributes[1].id = hipLaunchAttributeCooperative;
|
| 510 |
+
attributes[1].val.cooperative = launch_cooperative_grid;
|
| 511 |
+
|
| 512 |
+
HIP_LAUNCH_CONFIG config = {{
|
| 513 |
+
gridX * num_ctas, gridY, gridZ, // Grid size
|
| 514 |
+
{warp_size} * num_warps, 1, 1, // Block size
|
| 515 |
+
shared_memory, stream,
|
| 516 |
+
attributes, 2 // Number of attributes
|
| 517 |
+
}};
|
| 518 |
+
HIP_CHECK(hipSymbolTable.hipDrvLaunchKernelEx(&config, function, params, 0));
|
| 519 |
+
return;
|
| 520 |
+
}}
|
| 521 |
+
else if (launch_cooperative_grid) {{
|
| 522 |
+
HIP_CHECK(hipSymbolTable.hipModuleLaunchCooperativeKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
|
| 523 |
+
return;
|
| 524 |
+
}}
|
| 525 |
+
else {{
|
| 526 |
+
HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
|
| 527 |
+
}}
|
| 528 |
+
}}
|
| 529 |
+
|
| 530 |
+
typedef struct _DevicePtrInfo {{
|
| 531 |
+
hipDeviceptr_t dev_ptr;
|
| 532 |
+
bool valid;
|
| 533 |
+
}} DevicePtrInfo;
|
| 534 |
+
|
| 535 |
+
static PyObject* data_ptr_str = NULL;
|
| 536 |
+
static PyObject* py_tdm_descriptor_type = NULL;
|
| 537 |
+
|
| 538 |
+
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
| 539 |
+
DevicePtrInfo ptr_info;
|
| 540 |
+
hipError_t status = hipSuccess;
|
| 541 |
+
ptr_info.dev_ptr = 0;
|
| 542 |
+
ptr_info.valid = true;
|
| 543 |
+
if (PyLong_Check(obj)) {{
|
| 544 |
+
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
|
| 545 |
+
return ptr_info;
|
| 546 |
+
}}
|
| 547 |
+
if (obj == Py_None) {{
|
| 548 |
+
// valid nullptr
|
| 549 |
+
return ptr_info;
|
| 550 |
+
}}
|
| 551 |
+
PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
|
| 552 |
+
if (!ret) {{
|
| 553 |
+
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
| 554 |
+
ptr_info.valid = false;
|
| 555 |
+
goto cleanup;
|
| 556 |
+
}}
|
| 557 |
+
if (!PyLong_Check(ret)) {{
|
| 558 |
+
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
| 559 |
+
ptr_info.valid = false;
|
| 560 |
+
goto cleanup;
|
| 561 |
+
}}
|
| 562 |
+
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
|
| 563 |
+
if (!ptr_info.dev_ptr)
|
| 564 |
+
goto cleanup;
|
| 565 |
+
uint64_t dev_ptr;
|
| 566 |
+
status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
| 567 |
+
if (status == hipErrorInvalidValue) {{
|
| 568 |
+
PyErr_Format(PyExc_ValueError,
|
| 569 |
+
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
| 570 |
+
ptr_info.valid = false;
|
| 571 |
+
// Clear and ignore HIP error
|
| 572 |
+
(void)hipSymbolTable.hipGetLastError();
|
| 573 |
+
}}
|
| 574 |
+
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
|
| 575 |
+
cleanup:
|
| 576 |
+
Py_DECREF(ret);
|
| 577 |
+
return ptr_info;
|
| 578 |
+
}}
|
| 579 |
+
|
| 580 |
+
static inline TDMDescriptor* getTDMDescriptor(PyObject* obj, int idx) {{
|
| 581 |
+
if (Py_TYPE(obj) != (PyTypeObject*)py_tdm_descriptor_type) {{
|
| 582 |
+
PyErr_Format(PyExc_TypeError, "object must be of type PyTDMDescriptor, got %s", Py_TYPE(obj)->tp_name);
|
| 583 |
+
return NULL;
|
| 584 |
+
}}
|
| 585 |
+
|
| 586 |
+
TDMDescriptor* desc = &((PyTDMDescriptorObject*)obj)->desc;
|
| 587 |
+
return desc;
|
| 588 |
+
}}
|
| 589 |
+
|
| 590 |
+
static uint16_t pack_fp16(double f) {{
|
| 591 |
+
uint16_t result;
|
| 592 |
+
// from https://github.com/python/pythoncapi-compat/blob/5e317108f872c904eb726cb8d560dcadbdf88a72/pythoncapi_compat.h#L482-L492
|
| 593 |
+
#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
|
| 594 |
+
_PyFloat_Pack2(f, (unsigned char*)&result, 1);
|
| 595 |
+
#else
|
| 596 |
+
PyFloat_Pack2(f, (char*)&result, 1);
|
| 597 |
+
#endif
|
| 598 |
+
return result;
|
| 599 |
+
}}
|
| 600 |
+
|
| 601 |
+
static uint16_t pack_bf16(double f) {{
|
| 602 |
+
float f32 = (float)f;
|
| 603 |
+
uint32_t u32 = *(uint32_t*)&f32;
|
| 604 |
+
return (uint16_t)(u32 >> 16);
|
| 605 |
+
}}
|
| 606 |
+
|
| 607 |
+
static uint32_t pack_fp32(double f) {{
|
| 608 |
+
float f32 = (float)f;
|
| 609 |
+
return *(uint32_t*)&f32;
|
| 610 |
+
}}
|
| 611 |
+
|
| 612 |
+
static uint64_t pack_fp64(double f) {{
|
| 613 |
+
return *(uint64_t*)&f;
|
| 614 |
+
}}
|
| 615 |
+
|
| 616 |
+
static PyObject* launch(PyObject* self, PyObject* args) {{
|
| 617 |
+
int gridX, gridY, gridZ;
|
| 618 |
+
uint64_t _stream;
|
| 619 |
+
uint64_t _function;
|
| 620 |
+
int launch_cooperative_grid;
|
| 621 |
+
PyObject *profile_scratch_obj = NULL;
|
| 622 |
+
PyObject *launch_enter_hook = NULL;
|
| 623 |
+
PyObject *launch_exit_hook = NULL;
|
| 624 |
+
PyObject *kernel_metadata = NULL;
|
| 625 |
+
PyObject *launch_metadata = NULL;
|
| 626 |
+
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
| 627 |
+
if(!PyArg_ParseTuple(args, \"{format}\", &launch_cooperative_grid,
|
| 628 |
+
&gridX, &gridY, &gridZ, &_stream, &_function, &profile_scratch_obj,
|
| 629 |
+
&kernel_metadata, &launch_metadata,
|
| 630 |
+
&launch_enter_hook, &launch_exit_hook {args_list})) {{
|
| 631 |
+
return NULL;
|
| 632 |
+
}}
|
| 633 |
+
|
| 634 |
+
// extract kernel metadata
|
| 635 |
+
int num_warps, num_ctas, shared_memory;
|
| 636 |
+
if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
|
| 637 |
+
return NULL;
|
| 638 |
+
}}
|
| 639 |
+
// extract launch metadata
|
| 640 |
+
if (launch_enter_hook != Py_None){{
|
| 641 |
+
PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
|
| 642 |
+
if (!ret)
|
| 643 |
+
return NULL;
|
| 644 |
+
Py_DECREF(ret);
|
| 645 |
+
}}
|
| 646 |
+
|
| 647 |
+
hipDeviceptr_t profile_scratch = 0;
|
| 648 |
+
if (profile_scratch_obj != Py_None) {{
|
| 649 |
+
DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
|
| 650 |
+
if (!profile_scratch_info.valid) {{
|
| 651 |
+
return NULL;
|
| 652 |
+
}}
|
| 653 |
+
profile_scratch = profile_scratch_info.dev_ptr;
|
| 654 |
+
}}
|
| 655 |
+
|
| 656 |
+
// raise exception asap
|
| 657 |
+
{newline.join(tensor_desc_decls)}
|
| 658 |
+
{newline.join(ptr_decls)}
|
| 659 |
+
{newline.join(float_storage_decls)}
|
| 660 |
+
_launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
|
| 661 |
+
|
| 662 |
+
if(launch_exit_hook != Py_None){{
|
| 663 |
+
PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
|
| 664 |
+
if (!ret)
|
| 665 |
+
return NULL;
|
| 666 |
+
Py_DECREF(ret);
|
| 667 |
+
}}
|
| 668 |
+
|
| 669 |
+
if(PyErr_Occurred()) {{
|
| 670 |
+
return NULL;
|
| 671 |
+
}}
|
| 672 |
+
Py_RETURN_NONE;
|
| 673 |
+
}}
|
| 674 |
+
|
| 675 |
+
static PyMethodDef ModuleMethods[] = {{
|
| 676 |
+
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
|
| 677 |
+
{{NULL, NULL, 0, NULL}} // sentinel
|
| 678 |
+
}};
|
| 679 |
+
|
| 680 |
+
static struct PyModuleDef ModuleDef = {{
|
| 681 |
+
PyModuleDef_HEAD_INIT,
|
| 682 |
+
\"__triton_launcher\",
|
| 683 |
+
NULL, //documentation
|
| 684 |
+
-1, //size
|
| 685 |
+
ModuleMethods
|
| 686 |
+
}};
|
| 687 |
+
|
| 688 |
+
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
| 689 |
+
if (!initSymbolTable()) {{
|
| 690 |
+
return NULL;
|
| 691 |
+
}}
|
| 692 |
+
PyObject *m = PyModule_Create(&ModuleDef);
|
| 693 |
+
if(m == NULL) {{
|
| 694 |
+
return NULL;
|
| 695 |
+
}}
|
| 696 |
+
data_ptr_str = PyUnicode_InternFromString("data_ptr");
|
| 697 |
+
if(data_ptr_str == NULL) {{
|
| 698 |
+
return NULL;
|
| 699 |
+
}}
|
| 700 |
+
PyObject* driver_mod = PyImport_ImportModule("triton.backends.amd.driver");
|
| 701 |
+
if (driver_mod == NULL) {{
|
| 702 |
+
return NULL;
|
| 703 |
+
}}
|
| 704 |
+
py_tdm_descriptor_type = PyObject_GetAttrString(driver_mod, "PyTDMDescriptor");
|
| 705 |
+
if (py_tdm_descriptor_type == NULL) {{
|
| 706 |
+
return NULL;
|
| 707 |
+
}}
|
| 708 |
+
|
| 709 |
+
PyModule_AddFunctions(m, ModuleMethods);
|
| 710 |
+
return m;
|
| 711 |
+
}}
|
| 712 |
+
"""
|
| 713 |
+
return src
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
def make_tensordesc_arg(arg, kernel_metadata, tensordesc_metadata):
|
| 717 |
+
"""
|
| 718 |
+
Translate a tensor descriptor argument into the appropriate list of kernel
|
| 719 |
+
arguments. If `tensordesc_metadata` is provided, we will create a
|
| 720 |
+
TDMDescriptor object. Otherwise, we decompose the tensor descriptor into
|
| 721 |
+
base pointer, shape, strides, and padding flag. In both cases, we append the
|
| 722 |
+
shape and strides at the end to match the expected kernel signature.
|
| 723 |
+
"""
|
| 724 |
+
|
| 725 |
+
if tensordesc_metadata is None:
|
| 726 |
+
# Currently the host side tensor descriptors get decomposed in
|
| 727 |
+
# the frontend to tensor desc, shape, and strides. We have no
|
| 728 |
+
# way to use these shape and strides when processing tensor
|
| 729 |
+
# descriptors which is why we provide our own decomposition
|
| 730 |
+
# above. Sadly this means we have to pass the shape and strides
|
| 731 |
+
# twice.
|
| 732 |
+
return [arg.base, *arg.shape, *arg.strides, arg.padding == "nan", *arg.shape, *arg.strides]
|
| 733 |
+
|
| 734 |
+
shape = arg.shape
|
| 735 |
+
strides = arg.strides
|
| 736 |
+
base = arg.base.data_ptr()
|
| 737 |
+
|
| 738 |
+
assert "elem_bits" in tensordesc_metadata and "block_size" in tensordesc_metadata
|
| 739 |
+
elem_bits = tensordesc_metadata["elem_bits"]
|
| 740 |
+
block_size = tensordesc_metadata["block_size"]
|
| 741 |
+
pad_interval, pad_amount = 0, 0
|
| 742 |
+
interval_padding_pairs = tensordesc_metadata.get("interval_padding_pairs", [])
|
| 743 |
+
if interval_padding_pairs:
|
| 744 |
+
assert len(interval_padding_pairs) == 1 and len(interval_padding_pairs[0]) == 2
|
| 745 |
+
pad_interval, pad_amount = interval_padding_pairs[0]
|
| 746 |
+
num_warps = kernel_metadata[0]
|
| 747 |
+
|
| 748 |
+
driver = triton.runtime.driver.active
|
| 749 |
+
assert isinstance(driver, HIPDriver)
|
| 750 |
+
|
| 751 |
+
desc = driver.utils.create_tdm_descriptor(elem_bits, block_size, num_warps, pad_interval, pad_amount, shape,
|
| 752 |
+
strides, base)
|
| 753 |
+
|
| 754 |
+
return [desc, *shape, *strides]
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
def wrap_handle_tensordesc(launcher, signature, tensordesc_metadata):
|
| 758 |
+
"""
|
| 759 |
+
Wrap a kernel launcher function to handle tensor descriptor arguments.
|
| 760 |
+
Use the provided `tensordesc_metadata` to determine whether to create
|
| 761 |
+
TDMDescriptor objects or decompose the tensor descriptors.
|
| 762 |
+
|
| 763 |
+
Args:
|
| 764 |
+
launcher (callable): The original kernel launcher function.
|
| 765 |
+
signature (Dict[int, str]): The kernel signature mapping argument indices to types.
|
| 766 |
+
tensordesc_metadata (List[Dict] or None): The list of tensor descriptor metadata, following the order
|
| 767 |
+
of tensor descriptor arguments. If None, decompose tensor descriptors.
|
| 768 |
+
Returns:
|
| 769 |
+
launcher (callable): The wrapped kernel launcher function.
|
| 770 |
+
"""
|
| 771 |
+
|
| 772 |
+
has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
|
| 773 |
+
if not has_tensor_desc_arg:
|
| 774 |
+
return launcher
|
| 775 |
+
|
| 776 |
+
tensordesc_indices = set(
|
| 777 |
+
[i for i, sig in enumerate(signature.values()) if isinstance(sig, str) and sig.startswith("tensordesc")])
|
| 778 |
+
assert not tensordesc_metadata or len(tensordesc_metadata) == len(tensordesc_indices)
|
| 779 |
+
if not tensordesc_metadata:
|
| 780 |
+
tensordesc_metadata = [None] * len(tensordesc_indices)
|
| 781 |
+
|
| 782 |
+
def inner(*args):
|
| 783 |
+
meta_args = args[:len(_BASE_ARGS_FORMAT)]
|
| 784 |
+
raw_kernel_args = args[len(_BASE_ARGS_FORMAT):]
|
| 785 |
+
final_args = []
|
| 786 |
+
tensordesc_idx = 0
|
| 787 |
+
for i, arg in enumerate(raw_kernel_args):
|
| 788 |
+
if i in tensordesc_indices:
|
| 789 |
+
tensordesc_args = make_tensordesc_arg(arg, meta_args[7], # kernel_metadata
|
| 790 |
+
tensordesc_metadata[tensordesc_idx])
|
| 791 |
+
final_args.extend(tensordesc_args)
|
| 792 |
+
tensordesc_idx += 1
|
| 793 |
+
else:
|
| 794 |
+
final_args.append(arg)
|
| 795 |
+
return launcher(*meta_args, *final_args)
|
| 796 |
+
|
| 797 |
+
return inner
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
class HIPLauncher(object):
|
| 801 |
+
|
| 802 |
+
def __init__(self, src, metadata):
|
| 803 |
+
constants = src.constants if hasattr(src, "constants") else dict()
|
| 804 |
+
arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
|
| 805 |
+
constants = {arg_idx(idx): value for idx, value in constants.items()}
|
| 806 |
+
signature = {idx: value for idx, value in src.signature.items()}
|
| 807 |
+
tensordesc_meta = getattr(metadata, "tensordesc_meta", None)
|
| 808 |
+
src = make_launcher(constants, signature, metadata.warp_size, tensordesc_meta)
|
| 809 |
+
mod = compile_module_from_src(src=src, name="__triton_launcher", include_dirs=include_dirs)
|
| 810 |
+
self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta)
|
| 811 |
+
self.launch_cooperative_grid = metadata.launch_cooperative_grid
|
| 812 |
+
self.profile_scratch_size = metadata.profile_scratch_size
|
| 813 |
+
self.profile_scratch_align = metadata.profile_scratch_align
|
| 814 |
+
|
| 815 |
+
def __call__(self, gridX, gridY, gridZ, stream, function, *args):
|
| 816 |
+
|
| 817 |
+
def allocate_scratch(size, align, allocator):
|
| 818 |
+
if size > 0:
|
| 819 |
+
grid_size = gridX * gridY * gridZ
|
| 820 |
+
alloc_size = grid_size * size
|
| 821 |
+
alloc_fn = allocator.get()
|
| 822 |
+
return alloc_fn(alloc_size, align, stream)
|
| 823 |
+
return None
|
| 824 |
+
|
| 825 |
+
profile_scratch = allocate_scratch(self.profile_scratch_size, self.profile_scratch_align,
|
| 826 |
+
_allocation._profile_allocator)
|
| 827 |
+
|
| 828 |
+
self.launch(self.launch_cooperative_grid, gridX, gridY, gridZ, stream, function, profile_scratch, *args)
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
class HIPDriver(GPUDriver):
|
| 832 |
+
|
| 833 |
+
def __init__(self):
|
| 834 |
+
super().__init__()
|
| 835 |
+
self.utils = HIPUtils()
|
| 836 |
+
self.launcher_cls = HIPLauncher
|
| 837 |
+
|
| 838 |
+
def get_device_interface(self):
|
| 839 |
+
import torch
|
| 840 |
+
return torch.cuda
|
| 841 |
+
|
| 842 |
+
@staticmethod
|
| 843 |
+
def is_active():
|
| 844 |
+
try:
|
| 845 |
+
import torch
|
| 846 |
+
return torch.cuda.is_available() and (torch.version.hip is not None)
|
| 847 |
+
except ImportError:
|
| 848 |
+
return False
|
| 849 |
+
|
| 850 |
+
def map_python_to_cpp_type(self, ty: str) -> str:
|
| 851 |
+
return ty_to_cpp(ty)
|
| 852 |
+
|
| 853 |
+
def get_current_target(self):
|
| 854 |
+
device = self.get_current_device()
|
| 855 |
+
device_properties = self.utils.get_device_properties(device)
|
| 856 |
+
arch = knobs.runtime.override_arch or device_properties['arch']
|
| 857 |
+
warp_size = device_properties['warpSize']
|
| 858 |
+
return GPUTarget("hip", arch.split(':')[0], warp_size)
|
| 859 |
+
|
| 860 |
+
def get_active_torch_device(self):
|
| 861 |
+
import torch
|
| 862 |
+
# when using hip devices, the device string in pytorch is "cuda"
|
| 863 |
+
return torch.device("cuda", self.get_current_device())
|
| 864 |
+
|
| 865 |
+
def get_benchmarker(self):
|
| 866 |
+
from triton.testing import do_bench
|
| 867 |
+
return do_bench
|
| 868 |
+
|
| 869 |
+
def get_empty_cache_for_benchmark(self):
|
| 870 |
+
import torch
|
| 871 |
+
|
| 872 |
+
# It's the same as the Nvidia backend.
|
| 873 |
+
cache_size = 256 * 1024 * 1024
|
| 874 |
+
return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
|
| 875 |
+
|
| 876 |
+
def clear_cache(self, cache):
|
| 877 |
+
cache.zero_()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/__init__.py
ADDED
|
File without changes
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from triton.backends.compiler import BaseBackend, GPUTarget, Language
|
| 2 |
+
from triton._C.libtriton import ir, passes, llvm, nvidia
|
| 3 |
+
from triton import knobs
|
| 4 |
+
from triton.runtime.errors import PTXASError
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
import functools
|
| 8 |
+
from typing import Any, Dict, Tuple, Optional
|
| 9 |
+
from types import ModuleType
|
| 10 |
+
import hashlib
|
| 11 |
+
import re
|
| 12 |
+
import tempfile
|
| 13 |
+
import signal
|
| 14 |
+
import os
|
| 15 |
+
import subprocess
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def min_dot_size(target: GPUTarget):
|
| 20 |
+
|
| 21 |
+
def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]: # [m, n, k]
|
| 22 |
+
lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
|
| 23 |
+
rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
|
| 24 |
+
assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
|
| 25 |
+
# For small M/N the input we can still use tensorcores with padding.
|
| 26 |
+
if lhs_bitwidth == 8:
|
| 27 |
+
return (1, 1, 32)
|
| 28 |
+
else:
|
| 29 |
+
return (1, 1, 16)
|
| 30 |
+
|
| 31 |
+
return check_dot_compatibility
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_ptxas(arch: int) -> knobs.NvidiaTool:
|
| 35 |
+
return knobs.nvidia.ptxas_blackwell if arch >= 100 else knobs.nvidia.ptxas
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@functools.lru_cache()
|
| 39 |
+
def get_ptxas_version(arch: int = 80):
|
| 40 |
+
mock_ver = knobs.nvidia.mock_ptx_version
|
| 41 |
+
if mock_ver is not None:
|
| 42 |
+
return mock_ver # This is not really a version of ptxas, but it is good enough for testing
|
| 43 |
+
version = subprocess.check_output([get_ptxas(arch).path, "--version"]).decode("utf-8")
|
| 44 |
+
return version
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@functools.lru_cache()
|
| 48 |
+
def ptx_get_version(cuda_version) -> int:
|
| 49 |
+
'''
|
| 50 |
+
Get the highest PTX version supported by the current CUDA driver.
|
| 51 |
+
'''
|
| 52 |
+
assert isinstance(cuda_version, str)
|
| 53 |
+
major, minor = map(int, cuda_version.split('.'))
|
| 54 |
+
if major == 12:
|
| 55 |
+
if minor < 6:
|
| 56 |
+
return 80 + minor
|
| 57 |
+
else:
|
| 58 |
+
return 80 + minor - 1
|
| 59 |
+
if major == 11:
|
| 60 |
+
return 70 + minor
|
| 61 |
+
if major == 10:
|
| 62 |
+
return 63 + minor
|
| 63 |
+
|
| 64 |
+
if major >= 13:
|
| 65 |
+
base_ptx = 90
|
| 66 |
+
return base_ptx + (major - 13) * 10 + minor
|
| 67 |
+
|
| 68 |
+
raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get_ptx_version_from_options(options, arch: int):
|
| 72 |
+
ptx_version = options.ptx_version
|
| 73 |
+
if ptx_version is None:
|
| 74 |
+
cuda_version = get_ptxas(arch).version
|
| 75 |
+
ptx_version = ptx_get_version(cuda_version)
|
| 76 |
+
return ptx_version
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@functools.lru_cache()
|
| 80 |
+
def get_features(options, arch: int):
|
| 81 |
+
ptx_version = get_ptx_version_from_options(options, arch)
|
| 82 |
+
|
| 83 |
+
# PTX 8.6 is the max version supported by llvm c1188642.
|
| 84 |
+
#
|
| 85 |
+
# To check if a newer PTX version is supported, increase this value
|
| 86 |
+
# and run a test. If it's not supported, LLVM will print a warning
|
| 87 |
+
# like "+ptx8.4 is not a recognized feature for this target".
|
| 88 |
+
llvm_ptx_version = min(86, ptx_version)
|
| 89 |
+
features = f'+ptx{llvm_ptx_version}'
|
| 90 |
+
return features
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@functools.lru_cache(None)
|
| 94 |
+
def file_hash(path):
|
| 95 |
+
with open(path, "rb") as f:
|
| 96 |
+
return hashlib.sha256(f.read()).hexdigest()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def sm_arch_from_capability(capability: int):
|
| 100 |
+
# TODO: Handle non-"a" sms
|
| 101 |
+
suffix = "a" if capability >= 90 else ""
|
| 102 |
+
return f"sm_{capability}{suffix}"
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@dataclass(frozen=True)
|
| 106 |
+
class CUDAOptions:
|
| 107 |
+
num_warps: int = 4
|
| 108 |
+
num_ctas: int = 1
|
| 109 |
+
num_stages: int = 3
|
| 110 |
+
warp_size: int = 32
|
| 111 |
+
# maxnreg corresponds to the ptx parameter .maxnreg, which controls the
|
| 112 |
+
# maximum number of 32-bit registers used by one thread.
|
| 113 |
+
maxnreg: Optional[int] = None
|
| 114 |
+
ptx_version: int = None
|
| 115 |
+
ptx_options: Optional[str] = knobs.nvidia.ptxas_options
|
| 116 |
+
ir_override: Optional[str] = None # filename of a user-defined IR (*.{ttir|ttgir|llir|ptx})
|
| 117 |
+
enable_fp_fusion: bool = True
|
| 118 |
+
enable_reflect_ftz: bool = True # ftz in libdevice
|
| 119 |
+
launch_cooperative_grid: bool = False
|
| 120 |
+
launch_pdl: bool = False
|
| 121 |
+
supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15")
|
| 122 |
+
deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
|
| 123 |
+
default_dot_input_precision: str = "tf32"
|
| 124 |
+
allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee", 'bf16x3', 'bf16x6')
|
| 125 |
+
max_num_imprecise_acc_default: bool = None
|
| 126 |
+
extern_libs: dict = None
|
| 127 |
+
debug: bool = False
|
| 128 |
+
backend_name: str = 'cuda'
|
| 129 |
+
sanitize_overflow: bool = True
|
| 130 |
+
arch: str = None
|
| 131 |
+
instrumentation_mode: str = ""
|
| 132 |
+
|
| 133 |
+
def __post_init__(self):
|
| 134 |
+
default_libdir = Path(__file__).parent / 'lib'
|
| 135 |
+
extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
|
| 136 |
+
if not extern_libs.get('libdevice', None):
|
| 137 |
+
extern_libs['libdevice'] = knobs.nvidia.libdevice_path or str(default_libdir / 'libdevice.10.bc')
|
| 138 |
+
|
| 139 |
+
object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
|
| 140 |
+
assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
|
| 141 |
+
"num_warps must be a power of 2"
|
| 142 |
+
|
| 143 |
+
def hash(self):
|
| 144 |
+
hash_dict = dict(self.__dict__)
|
| 145 |
+
hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"]))
|
| 146 |
+
key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())])
|
| 147 |
+
return hashlib.sha256(key.encode("utf-8")).hexdigest()
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class CUDABackend(BaseBackend):
|
| 151 |
+
instrumentation = None
|
| 152 |
+
|
| 153 |
+
@staticmethod
|
| 154 |
+
def supports_target(target: GPUTarget):
|
| 155 |
+
return target.backend == 'cuda'
|
| 156 |
+
|
| 157 |
+
def _parse_arch(self, arch):
|
| 158 |
+
pattern = r"^sm(\d+)$"
|
| 159 |
+
match = re.fullmatch(pattern, arch)
|
| 160 |
+
if not match:
|
| 161 |
+
raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
|
| 162 |
+
return int(match.group(1))
|
| 163 |
+
|
| 164 |
+
def get_target_name(self, options) -> str:
|
| 165 |
+
capability = self._parse_arch(options.arch)
|
| 166 |
+
return f"cuda:{capability}"
|
| 167 |
+
|
| 168 |
+
def __init__(self, target: GPUTarget) -> None:
|
| 169 |
+
super().__init__(target)
|
| 170 |
+
self.binary_ext = "cubin"
|
| 171 |
+
|
| 172 |
+
def parse_options(self, opts) -> Any:
|
| 173 |
+
# Enable debug mode for ConSan, so device-side assertions are not optimized out
|
| 174 |
+
if "instrumentation_mode" in opts and opts["instrumentation_mode"] == "consan":
|
| 175 |
+
opts["debug"] = True
|
| 176 |
+
|
| 177 |
+
args = {'arch': knobs.runtime.override_arch or f"sm{self.target.arch}"}
|
| 178 |
+
args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None})
|
| 179 |
+
capability = int(self._parse_arch(args["arch"]))
|
| 180 |
+
|
| 181 |
+
if args.get("num_ctas", 1) > 1 and capability < 90:
|
| 182 |
+
raise ValueError((f"num_ctas > 1 requires NVIDIA SM90+ (Hopper). "
|
| 183 |
+
f"Current target is sm_{capability}. This configuration will fail. "
|
| 184 |
+
f"Please set num_ctas=1 or target an SM90+ GPU."))
|
| 185 |
+
|
| 186 |
+
if "supported_fp8_dtypes" not in args:
|
| 187 |
+
supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
|
| 188 |
+
if capability >= 89:
|
| 189 |
+
supported_fp8_dtypes.add("fp8e4nv")
|
| 190 |
+
args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
|
| 191 |
+
|
| 192 |
+
if "deprecated_fp8_dot_operand_dtypes" not in args:
|
| 193 |
+
if capability >= 90:
|
| 194 |
+
args["deprecated_fp8_dot_operand_dtypes"] = ("fp8e4b15", )
|
| 195 |
+
|
| 196 |
+
if "enable_fp_fusion" not in args:
|
| 197 |
+
args["enable_fp_fusion"] = knobs.language.default_fp_fusion
|
| 198 |
+
|
| 199 |
+
args["max_num_imprecise_acc_default"] = 2**30 if capability == 90 else 0
|
| 200 |
+
|
| 201 |
+
return CUDAOptions(**args)
|
| 202 |
+
|
| 203 |
+
def pack_metadata(self, metadata):
|
| 204 |
+
return (
|
| 205 |
+
metadata.num_warps,
|
| 206 |
+
metadata.num_ctas,
|
| 207 |
+
metadata.shared,
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
def get_codegen_implementation(self, options):
|
| 211 |
+
import triton.language.extra.cuda as cuda
|
| 212 |
+
capability = int(self._parse_arch(options.arch))
|
| 213 |
+
codegen_fns = {
|
| 214 |
+
"convert_custom_types":
|
| 215 |
+
cuda.convert_custom_float8_sm80 if capability >= 80 else cuda.convert_custom_float8_sm70, "min_dot_size":
|
| 216 |
+
min_dot_size(self.target)
|
| 217 |
+
}
|
| 218 |
+
return codegen_fns
|
| 219 |
+
|
| 220 |
+
def get_module_map(self) -> Dict[str, ModuleType]:
|
| 221 |
+
from triton.language.extra.cuda import libdevice
|
| 222 |
+
return {"triton.language.extra.libdevice": libdevice}
|
| 223 |
+
|
| 224 |
+
def load_dialects(self, ctx):
|
| 225 |
+
nvidia.load_dialects(ctx)
|
| 226 |
+
if CUDABackend.instrumentation:
|
| 227 |
+
CUDABackend.instrumentation.load_dialects(ctx)
|
| 228 |
+
|
| 229 |
+
@staticmethod
|
| 230 |
+
def make_ttir(mod, metadata, opt, capability):
|
| 231 |
+
pm = ir.pass_manager(mod.context)
|
| 232 |
+
pm.enable_debug()
|
| 233 |
+
passes.common.add_inliner(pm)
|
| 234 |
+
passes.ttir.add_rewrite_tensor_pointer(pm)
|
| 235 |
+
if capability // 10 < 9:
|
| 236 |
+
passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
|
| 237 |
+
passes.common.add_canonicalizer(pm)
|
| 238 |
+
passes.ttir.add_combine(pm)
|
| 239 |
+
passes.ttir.add_reorder_broadcast(pm)
|
| 240 |
+
passes.common.add_cse(pm)
|
| 241 |
+
passes.common.add_symbol_dce(pm)
|
| 242 |
+
passes.ttir.add_loop_unroll(pm)
|
| 243 |
+
pm.run(mod, 'make_ttir')
|
| 244 |
+
return mod
|
| 245 |
+
|
| 246 |
+
@staticmethod
|
| 247 |
+
def make_ttgir(mod, metadata, opt, capability):
|
| 248 |
+
# Set maxnreg on all kernels, if it was provided.
|
| 249 |
+
if opt.maxnreg is not None:
|
| 250 |
+
mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))
|
| 251 |
+
|
| 252 |
+
pm = ir.pass_manager(mod.context)
|
| 253 |
+
dump_enabled = pm.enable_debug()
|
| 254 |
+
emuTF32 = (capability // 10 >= 8)
|
| 255 |
+
passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
|
| 256 |
+
# optimize TTGIR
|
| 257 |
+
passes.ttgpuir.add_coalesce(pm)
|
| 258 |
+
passes.ttgpuir.add_f32_dot_tc(pm, emuTF32)
|
| 259 |
+
# TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
|
| 260 |
+
nvidia.passes.ttnvgpuir.add_plan_cta(pm)
|
| 261 |
+
passes.ttgpuir.add_remove_layout_conversions(pm)
|
| 262 |
+
passes.ttgpuir.add_optimize_thread_locality(pm)
|
| 263 |
+
passes.ttgpuir.add_accelerate_matmul(pm)
|
| 264 |
+
passes.ttgpuir.add_remove_layout_conversions(pm)
|
| 265 |
+
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
|
| 266 |
+
nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
|
| 267 |
+
passes.ttir.add_loop_aware_cse(pm)
|
| 268 |
+
if capability // 10 in [8, 9]:
|
| 269 |
+
passes.ttgpuir.add_fuse_nested_loops(pm)
|
| 270 |
+
passes.common.add_canonicalizer(pm)
|
| 271 |
+
passes.ttir.add_triton_licm(pm)
|
| 272 |
+
passes.common.add_canonicalizer(pm)
|
| 273 |
+
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
|
| 274 |
+
nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
|
| 275 |
+
passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
|
| 276 |
+
passes.ttgpuir.add_schedule_loops(pm)
|
| 277 |
+
passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
|
| 278 |
+
elif capability // 10 >= 10:
|
| 279 |
+
passes.ttgpuir.add_fuse_nested_loops(pm)
|
| 280 |
+
passes.common.add_canonicalizer(pm)
|
| 281 |
+
passes.ttir.add_triton_licm(pm)
|
| 282 |
+
passes.ttgpuir.add_optimize_accumulator_init(pm)
|
| 283 |
+
passes.ttgpuir.add_hoist_tmem_alloc(pm, False)
|
| 284 |
+
nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
|
| 285 |
+
passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
|
| 286 |
+
passes.ttgpuir.add_schedule_loops(pm)
|
| 287 |
+
passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
|
| 288 |
+
passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
|
| 289 |
+
passes.ttgpuir.add_optimize_partition_warps(pm)
|
| 290 |
+
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
|
| 291 |
+
# hoist again and allow hoisting out of if statements
|
| 292 |
+
passes.ttgpuir.add_hoist_tmem_alloc(pm, True)
|
| 293 |
+
nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
|
| 294 |
+
else:
|
| 295 |
+
passes.ttir.add_triton_licm(pm)
|
| 296 |
+
passes.common.add_canonicalizer(pm)
|
| 297 |
+
passes.ttir.add_loop_aware_cse(pm)
|
| 298 |
+
passes.ttgpuir.add_prefetch(pm)
|
| 299 |
+
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
|
| 300 |
+
passes.ttgpuir.add_coalesce_async_copy(pm)
|
| 301 |
+
nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
|
| 302 |
+
if capability // 10 >= 9:
|
| 303 |
+
nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
|
| 304 |
+
passes.ttgpuir.add_remove_layout_conversions(pm)
|
| 305 |
+
nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
|
| 306 |
+
passes.ttgpuir.add_reduce_data_duplication(pm)
|
| 307 |
+
passes.ttgpuir.add_reorder_instructions(pm)
|
| 308 |
+
passes.ttir.add_loop_aware_cse(pm)
|
| 309 |
+
passes.common.add_symbol_dce(pm)
|
| 310 |
+
nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
|
| 311 |
+
nvidia.passes.ttnvgpuir.add_lower_mma(pm)
|
| 312 |
+
passes.common.add_sccp(pm)
|
| 313 |
+
passes.common.add_cse(pm)
|
| 314 |
+
passes.common.add_canonicalizer(pm)
|
| 315 |
+
|
| 316 |
+
pm.run(mod, 'make_ttgir')
|
| 317 |
+
metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
|
| 318 |
+
return mod
|
| 319 |
+
|
| 320 |
+
def gluon_to_ttgir(self, src, metadata, options, capability):
|
| 321 |
+
mod = src
|
| 322 |
+
pm = ir.pass_manager(mod.context)
|
| 323 |
+
pm.enable_debug()
|
| 324 |
+
|
| 325 |
+
passes.gluon.add_inliner(pm)
|
| 326 |
+
passes.gluon.add_infer_coalesced_encodings(pm)
|
| 327 |
+
passes.gluon.add_resolve_auto_encodings(pm)
|
| 328 |
+
nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
|
| 329 |
+
passes.gluon.add_canonicalizer(pm)
|
| 330 |
+
passes.common.add_sccp(pm)
|
| 331 |
+
passes.ttir.add_loop_aware_cse(pm)
|
| 332 |
+
passes.gluon.add_canonicalizer(pm)
|
| 333 |
+
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
|
| 334 |
+
|
| 335 |
+
pm.run(mod, 'gluon_to_ttgir')
|
| 336 |
+
metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
|
| 337 |
+
return mod
|
| 338 |
+
|
| 339 |
+
def make_llir(self, src, metadata, options, capability):
|
| 340 |
+
ptx_version = get_ptx_version_from_options(options, self.target.arch)
|
| 341 |
+
|
| 342 |
+
mod = src
|
| 343 |
+
# TritonGPU -> LLVM-IR (MLIR)
|
| 344 |
+
pm = ir.pass_manager(mod.context)
|
| 345 |
+
pm.enable_debug()
|
| 346 |
+
|
| 347 |
+
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
|
| 348 |
+
passes.ttgpuir.add_allocate_warp_groups(pm)
|
| 349 |
+
passes.convert.add_scf_to_cf(pm)
|
| 350 |
+
passes.gluon.add_inliner(pm)
|
| 351 |
+
nvidia.passes.ttgpuir.add_allocate_shared_memory_nv(pm, capability, ptx_version)
|
| 352 |
+
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
|
| 353 |
+
nvidia.passes.ttnvgpuir.add_check_matmul_two_cta(pm)
|
| 354 |
+
if knobs.compilation.instrumentation_mode == "consan":
|
| 355 |
+
# Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
|
| 356 |
+
passes.ttgpuir.add_concurrency_sanitizer(pm)
|
| 357 |
+
passes.ttgpuir.add_allocate_global_scratch_memory(pm)
|
| 358 |
+
nvidia.passes.ttnvgpuir.add_proxy_fence_insertion(pm, capability)
|
| 359 |
+
# instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
|
| 360 |
+
if CUDABackend.instrumentation:
|
| 361 |
+
CUDABackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
|
| 362 |
+
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
|
| 363 |
+
passes.common.add_canonicalizer(pm)
|
| 364 |
+
passes.common.add_cse(pm)
|
| 365 |
+
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
|
| 366 |
+
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
|
| 367 |
+
passes.common.add_canonicalizer(pm)
|
| 368 |
+
passes.common.add_cse(pm)
|
| 369 |
+
passes.common.add_symbol_dce(pm)
|
| 370 |
+
passes.convert.add_nvvm_to_llvm(pm)
|
| 371 |
+
|
| 372 |
+
if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables:
|
| 373 |
+
passes.llvmir.add_di_scope(pm)
|
| 374 |
+
|
| 375 |
+
if CUDABackend.instrumentation:
|
| 376 |
+
CUDABackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
|
| 377 |
+
|
| 378 |
+
pm.run(mod, 'make_llir')
|
| 379 |
+
|
| 380 |
+
if knobs.compilation.dump_ir_extract_di_local_variables:
|
| 381 |
+
# comments below on why separate it
|
| 382 |
+
if not knobs.compilation.disable_line_info:
|
| 383 |
+
pm = ir.pass_manager(mod.context)
|
| 384 |
+
pm.enable_debug()
|
| 385 |
+
passes.llvmir.add_di_scope(pm)
|
| 386 |
+
pm.run(mod, 'make_llir.disable_line_info')
|
| 387 |
+
|
| 388 |
+
# insert dbg intrinsic with several DI Attribute including source
|
| 389 |
+
# var name and type info note: unknown reason for now, but this
|
| 390 |
+
# pass and add_di_scope has to be run separately, otherwise if we
|
| 391 |
+
# put them into previous pipline, it trigger a segmentfault without
|
| 392 |
+
# any error message; could be due to a bug in mlir or pybind11
|
| 393 |
+
pm = ir.pass_manager(mod.context)
|
| 394 |
+
pm.enable_debug()
|
| 395 |
+
passes.llvmir.add_di_local_variable(pm)
|
| 396 |
+
pm.run(mod, 'make_llir.dump_ir_extract_di_local_variables')
|
| 397 |
+
|
| 398 |
+
# LLVM-IR (MLIR) -> LLVM-IR (LLVM)
|
| 399 |
+
llvm.init_targets()
|
| 400 |
+
context = llvm.context()
|
| 401 |
+
if knobs.compilation.enable_asan:
|
| 402 |
+
raise RuntimeError(
|
| 403 |
+
"Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
|
| 404 |
+
llvm_mod = llvm.to_module(mod, context)
|
| 405 |
+
proc = sm_arch_from_capability(capability)
|
| 406 |
+
features = get_features(options, self.target.arch)
|
| 407 |
+
triple = 'nvptx64-nvidia-cuda'
|
| 408 |
+
nvidia.set_short_ptr()
|
| 409 |
+
llvm.attach_datalayout(llvm_mod, triple, proc, features)
|
| 410 |
+
if options.enable_reflect_ftz:
|
| 411 |
+
nvidia.set_nvvm_reflect_ftz(llvm_mod)
|
| 412 |
+
|
| 413 |
+
if options.extern_libs and nvidia.has_extern_deps(llvm_mod):
|
| 414 |
+
paths = [path for (name, path) in options.extern_libs]
|
| 415 |
+
llvm.link_extern_libs(llvm_mod, paths)
|
| 416 |
+
|
| 417 |
+
llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
|
| 418 |
+
|
| 419 |
+
# Get some metadata
|
| 420 |
+
# warp-specialization mutates num_warps
|
| 421 |
+
total_num_warps = src.get_int_attr("ttg.total-num-warps")
|
| 422 |
+
if total_num_warps is not None:
|
| 423 |
+
metadata["num_warps"] = total_num_warps
|
| 424 |
+
metadata["shared"] = src.get_int_attr("ttg.shared")
|
| 425 |
+
metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
|
| 426 |
+
metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
|
| 427 |
+
metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
|
| 428 |
+
metadata["profile_scratch_size"] = src.get_int_attr("ttg.profile_scratch_memory_size") or 0
|
| 429 |
+
metadata["profile_scratch_align"] = src.get_int_attr("ttg.profile_scratch_memory_alignment") or 1
|
| 430 |
+
ret = str(llvm_mod)
|
| 431 |
+
del llvm_mod
|
| 432 |
+
del context
|
| 433 |
+
return ret
|
| 434 |
+
|
| 435 |
+
def make_ptx(self, src, metadata, opt, capability):
|
| 436 |
+
ptx_version = get_ptx_version_from_options(opt, self.target.arch)
|
| 437 |
+
|
| 438 |
+
triple = 'nvptx64-nvidia-cuda'
|
| 439 |
+
proc = sm_arch_from_capability(capability)
|
| 440 |
+
features = get_features(opt, self.target.arch)
|
| 441 |
+
flags = ["nvptx-mad-wide-opt"]
|
| 442 |
+
ret = llvm.translate_to_asm(src, triple, proc, features, flags, opt.enable_fp_fusion, False)
|
| 443 |
+
# Find kernel names (there should only be one)
|
| 444 |
+
names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
|
| 445 |
+
assert len(names) == 1
|
| 446 |
+
metadata["name"] = names[0]
|
| 447 |
+
# post-process
|
| 448 |
+
ptx_version = f'{ptx_version//10}.{ptx_version%10}'
|
| 449 |
+
ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
|
| 450 |
+
ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
|
| 451 |
+
if not knobs.compilation.dump_ir_extract_di_local_variables:
|
| 452 |
+
# Remove the debug flag that prevents ptxas from optimizing the code
|
| 453 |
+
# Note: if this flag is removed, the source var name and type info will be lost when ptx was compiled into cubin
|
| 454 |
+
# and we may not be able to see them in cuda-gdb
|
| 455 |
+
ret = re.sub(r",\s*debug|debug,\s*", "", ret)
|
| 456 |
+
if knobs.nvidia.dump_nvptx:
|
| 457 |
+
print("// -----// NVPTX Dump //----- //")
|
| 458 |
+
print(ret)
|
| 459 |
+
return ret
|
| 460 |
+
|
| 461 |
+
def make_cubin(self, src, metadata, opt, capability):
|
| 462 |
+
ptxas = get_ptxas(self.target.arch).path
|
| 463 |
+
with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
|
| 464 |
+
tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
|
| 465 |
+
fsrc.write(src)
|
| 466 |
+
fsrc.flush()
|
| 467 |
+
fbin = fsrc.name + '.o'
|
| 468 |
+
|
| 469 |
+
debug_info = []
|
| 470 |
+
if knobs.compilation.disable_line_info:
|
| 471 |
+
# This option is ignored if used without -lineinfo
|
| 472 |
+
debug_info += ["-lineinfo", "-suppress-debug-info"]
|
| 473 |
+
elif knobs.nvidia.disable_ptxas_opt:
|
| 474 |
+
# Synthesize complete debug info
|
| 475 |
+
debug_info += ["-g"]
|
| 476 |
+
else:
|
| 477 |
+
# Only emit line info
|
| 478 |
+
debug_info += ["-lineinfo"]
|
| 479 |
+
|
| 480 |
+
fmad = [] if opt.enable_fp_fusion else ["--fmad=false"]
|
| 481 |
+
arch = sm_arch_from_capability(capability)
|
| 482 |
+
|
| 483 |
+
# Disable ptxas optimizations if requested
|
| 484 |
+
disable_opt = ['--opt-level', '0'] if knobs.nvidia.disable_ptxas_opt else []
|
| 485 |
+
|
| 486 |
+
# Accept more ptxas options if provided
|
| 487 |
+
ptx_extra_options = opt.ptx_options.split(" ") if opt.ptx_options else []
|
| 488 |
+
|
| 489 |
+
ptxas_cmd = [
|
| 490 |
+
ptxas, *debug_info, *fmad, '-v', *disable_opt, *ptx_extra_options, f'--gpu-name={arch}', fsrc.name,
|
| 491 |
+
'-o', fbin
|
| 492 |
+
]
|
| 493 |
+
try:
|
| 494 |
+
subprocess.run(ptxas_cmd, check=True, close_fds=False, stderr=flog)
|
| 495 |
+
if knobs.nvidia.dump_ptxas_log:
|
| 496 |
+
with open(flog.name) as log_file:
|
| 497 |
+
print(log_file.read())
|
| 498 |
+
|
| 499 |
+
if os.path.exists(fsrc.name):
|
| 500 |
+
os.remove(fsrc.name)
|
| 501 |
+
if os.path.exists(flog.name):
|
| 502 |
+
os.remove(flog.name)
|
| 503 |
+
except subprocess.CalledProcessError as e:
|
| 504 |
+
with open(flog.name) as log_file:
|
| 505 |
+
log = log_file.read()
|
| 506 |
+
if os.path.exists(flog.name):
|
| 507 |
+
os.remove(flog.name)
|
| 508 |
+
|
| 509 |
+
if e.returncode == 255:
|
| 510 |
+
error = 'Internal Triton PTX codegen error'
|
| 511 |
+
elif e.returncode == 128 + signal.SIGSEGV:
|
| 512 |
+
error = '`ptxas` raised SIGSEGV'
|
| 513 |
+
else:
|
| 514 |
+
error = f'`ptxas` failed with error code {e.returncode}'
|
| 515 |
+
|
| 516 |
+
error = (f"{error}\n"
|
| 517 |
+
f"`ptxas` stderr:\n{log}\n"
|
| 518 |
+
f'Repro command: {" ".join(ptxas_cmd)}\n')
|
| 519 |
+
|
| 520 |
+
print(f"""
|
| 521 |
+
|
| 522 |
+
================================================================
|
| 523 |
+
{error}
|
| 524 |
+
|
| 525 |
+
{src}
|
| 526 |
+
================================================================
|
| 527 |
+
please share the reproducer above with Triton project.
|
| 528 |
+
""")
|
| 529 |
+
raise PTXASError(error)
|
| 530 |
+
|
| 531 |
+
with open(fbin, 'rb') as f:
|
| 532 |
+
cubin = f.read()
|
| 533 |
+
if os.path.exists(fbin):
|
| 534 |
+
os.remove(fbin)
|
| 535 |
+
return cubin
|
| 536 |
+
|
| 537 |
+
def add_stages(self, stages, options, language):
|
| 538 |
+
capability = self._parse_arch(options.arch)
|
| 539 |
+
if language == Language.TRITON:
|
| 540 |
+
stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability)
|
| 541 |
+
stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
|
| 542 |
+
elif language == Language.GLUON:
|
| 543 |
+
stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options, capability)
|
| 544 |
+
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
|
| 545 |
+
stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch)
|
| 546 |
+
stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch)
|
| 547 |
+
if knobs.runtime.add_stages_inspection_hook is not None:
|
| 548 |
+
knobs.runtime.add_stages_inspection_hook(self, stages, options, language, capability)
|
| 549 |
+
|
| 550 |
+
@functools.lru_cache()
|
| 551 |
+
def hash(self):
|
| 552 |
+
version = get_ptxas_version(self.target.arch)
|
| 553 |
+
return f'{version}-{self.target.arch}'
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/driver.c
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "cuda.h"
|
| 2 |
+
#include <dlfcn.h>
|
| 3 |
+
#include <stdbool.h>
|
| 4 |
+
#include <stdio.h>
|
| 5 |
+
#include <stdlib.h>
|
| 6 |
+
#define PY_SSIZE_T_CLEAN
|
| 7 |
+
#include <Python.h>
|
| 8 |
+
|
| 9 |
+
typedef struct {
|
| 10 |
+
PyObject_HEAD;
|
| 11 |
+
_Alignas(128) CUtensorMap tensorMap;
|
| 12 |
+
} PyCUtensorMapObject;
|
| 13 |
+
|
| 14 |
+
// Raises a Python exception and returns false if code is not CUDA_SUCCESS.
|
| 15 |
+
static bool gpuAssert(CUresult code, const char *file, int line) {
|
| 16 |
+
if (code == CUDA_SUCCESS)
|
| 17 |
+
return true;
|
| 18 |
+
|
| 19 |
+
const char *prefix = "Triton Error [CUDA]: ";
|
| 20 |
+
const char *str;
|
| 21 |
+
cuGetErrorString(code, &str);
|
| 22 |
+
char err[1024] = {0};
|
| 23 |
+
strcat(err, prefix);
|
| 24 |
+
strcat(err, str);
|
| 25 |
+
PyGILState_STATE gil_state;
|
| 26 |
+
gil_state = PyGILState_Ensure();
|
| 27 |
+
PyErr_SetString(PyExc_RuntimeError, err);
|
| 28 |
+
PyGILState_Release(gil_state);
|
| 29 |
+
return false;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block.
|
| 33 |
+
#define CUDA_CHECK_AND_RETURN_NULL(ans) \
|
| 34 |
+
do { \
|
| 35 |
+
if (!gpuAssert((ans), __FILE__, __LINE__)) \
|
| 36 |
+
goto cleanup; \
|
| 37 |
+
} while (0)
|
| 38 |
+
|
| 39 |
+
// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
|
| 40 |
+
#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \
|
| 41 |
+
do { \
|
| 42 |
+
if (!gpuAssert((ans), __FILE__, __LINE__)) { \
|
| 43 |
+
PyEval_RestoreThread(_save); \
|
| 44 |
+
return NULL; \
|
| 45 |
+
} \
|
| 46 |
+
} while (0)
|
| 47 |
+
|
| 48 |
+
// Used to check if functions exist in old CUDA driver versions.
|
| 49 |
+
#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction) \
|
| 50 |
+
do { \
|
| 51 |
+
if ((funcPointer) == NULL) { \
|
| 52 |
+
(funcPointer) = (initializerFunction)(); \
|
| 53 |
+
if ((funcPointer) == NULL) { \
|
| 54 |
+
goto cleanup; \
|
| 55 |
+
} \
|
| 56 |
+
} \
|
| 57 |
+
} while (0)
|
| 58 |
+
|
| 59 |
+
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
|
| 60 |
+
int device_id;
|
| 61 |
+
if (!PyArg_ParseTuple(args, "i", &device_id))
|
| 62 |
+
return NULL;
|
| 63 |
+
// Get device handle
|
| 64 |
+
CUdevice device;
|
| 65 |
+
cuDeviceGet(&device, device_id);
|
| 66 |
+
|
| 67 |
+
// create a struct to hold device properties
|
| 68 |
+
int max_shared_mem;
|
| 69 |
+
int max_num_regs;
|
| 70 |
+
int multiprocessor_count;
|
| 71 |
+
int warp_size;
|
| 72 |
+
int sm_clock_rate;
|
| 73 |
+
int mem_clock_rate;
|
| 74 |
+
int mem_bus_width;
|
| 75 |
+
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
| 76 |
+
&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
| 77 |
+
device));
|
| 78 |
+
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
| 79 |
+
&max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device));
|
| 80 |
+
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
| 81 |
+
&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
|
| 82 |
+
CUDA_CHECK_AND_RETURN_NULL(
|
| 83 |
+
cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device));
|
| 84 |
+
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
| 85 |
+
&sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
|
| 86 |
+
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
| 87 |
+
&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
|
| 88 |
+
CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
|
| 89 |
+
&mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
|
| 90 |
+
|
| 91 |
+
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
|
| 92 |
+
max_shared_mem, "max_num_regs", max_num_regs,
|
| 93 |
+
"multiprocessor_count", multiprocessor_count, "warpSize",
|
| 94 |
+
warp_size, "sm_clock_rate", sm_clock_rate,
|
| 95 |
+
"mem_clock_rate", mem_clock_rate, "mem_bus_width",
|
| 96 |
+
mem_bus_width);
|
| 97 |
+
|
| 98 |
+
cleanup:
|
| 99 |
+
return NULL;
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
static PyObject *loadBinary(PyObject *self, PyObject *args) {
|
| 103 |
+
const char *name;
|
| 104 |
+
const char *data;
|
| 105 |
+
Py_ssize_t data_size;
|
| 106 |
+
int shared;
|
| 107 |
+
int device;
|
| 108 |
+
if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
|
| 109 |
+
&device)) {
|
| 110 |
+
return NULL;
|
| 111 |
+
}
|
| 112 |
+
CUfunction fun;
|
| 113 |
+
CUmodule mod;
|
| 114 |
+
int32_t n_regs = 0;
|
| 115 |
+
int32_t n_spills = 0;
|
| 116 |
+
int32_t n_max_threads = 0;
|
| 117 |
+
// create driver handles
|
| 118 |
+
CUcontext pctx = 0;
|
| 119 |
+
|
| 120 |
+
Py_BEGIN_ALLOW_THREADS;
|
| 121 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx));
|
| 122 |
+
if (!pctx) {
|
| 123 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
| 124 |
+
cuDevicePrimaryCtxRetain(&pctx, device));
|
| 125 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx));
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data));
|
| 129 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
| 130 |
+
cuModuleGetFunction(&fun, mod, name));
|
| 131 |
+
// get allocated registers and spilled registers from the function
|
| 132 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
| 133 |
+
cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
|
| 134 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
| 135 |
+
cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
|
| 136 |
+
n_spills /= 4;
|
| 137 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
|
| 138 |
+
&n_max_threads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));
|
| 139 |
+
// set dynamic shared memory if necessary
|
| 140 |
+
int shared_optin;
|
| 141 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
|
| 142 |
+
&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
| 143 |
+
device));
|
| 144 |
+
if (shared > 49152 && shared_optin > 49152) {
|
| 145 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
| 146 |
+
cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
|
| 147 |
+
int shared_total, shared_static;
|
| 148 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
|
| 149 |
+
&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
|
| 150 |
+
device));
|
| 151 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
|
| 152 |
+
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
|
| 153 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
| 154 |
+
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
| 155 |
+
shared_optin - shared_static));
|
| 156 |
+
}
|
| 157 |
+
Py_END_ALLOW_THREADS;
|
| 158 |
+
|
| 159 |
+
if (PyErr_Occurred()) {
|
| 160 |
+
return NULL;
|
| 161 |
+
}
|
| 162 |
+
return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
|
| 163 |
+
n_spills, n_max_threads);
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
|
| 167 |
+
int *numClusters, CUfunction func, const CUlaunchConfig *config);
|
| 168 |
+
|
| 169 |
+
typedef CUresult (*cuTensorMapEncodeTiled_t)(
|
| 170 |
+
CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType,
|
| 171 |
+
cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim,
|
| 172 |
+
const cuuint64_t *globalStrides, const cuuint32_t *boxDim,
|
| 173 |
+
const cuuint32_t *elementStrides, CUtensorMapInterleave interleave,
|
| 174 |
+
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
|
| 175 |
+
CUtensorMapFloatOOBfill oobFill);
|
| 176 |
+
|
| 177 |
+
#define defineGetFunctionHandle(name, symbolName) \
|
| 178 |
+
static symbolName##_t name() { \
|
| 179 |
+
/* Open the shared library */ \
|
| 180 |
+
void *libHandle = dlopen("libcuda.so.1", RTLD_LAZY); \
|
| 181 |
+
if (!libHandle) { \
|
| 182 |
+
PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); \
|
| 183 |
+
return NULL; \
|
| 184 |
+
} \
|
| 185 |
+
/* Clear any existing error */ \
|
| 186 |
+
dlerror(); \
|
| 187 |
+
symbolName##_t funcHandle = (symbolName##_t)dlsym(libHandle, #symbolName); \
|
| 188 |
+
/* Check for errors */ \
|
| 189 |
+
const char *err = dlerror(); \
|
| 190 |
+
if (err) { \
|
| 191 |
+
PyErr_SetString(PyExc_RuntimeError, \
|
| 192 |
+
"Failed to retrieve " #symbolName " from libcuda.so.1"); \
|
| 193 |
+
dlclose(libHandle); \
|
| 194 |
+
return NULL; \
|
| 195 |
+
} \
|
| 196 |
+
return funcHandle; \
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
|
| 200 |
+
cuOccupancyMaxActiveClusters);
|
| 201 |
+
|
| 202 |
+
defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
|
| 203 |
+
cuTensorMapEncodeTiled);
|
| 204 |
+
|
| 205 |
+
static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
|
| 206 |
+
int clusterDim = -1, maxActiveClusters = -1;
|
| 207 |
+
int shared = 0;
|
| 208 |
+
CUfunction func;
|
| 209 |
+
|
| 210 |
+
if (!PyArg_ParseTuple(args, "Kii", &func, &shared, &clusterDim)) {
|
| 211 |
+
return NULL;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
// Let each SM have one block
|
| 215 |
+
int maxActiveBlocks = 1;
|
| 216 |
+
Py_BEGIN_ALLOW_THREADS;
|
| 217 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute(
|
| 218 |
+
func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared));
|
| 219 |
+
Py_END_ALLOW_THREADS;
|
| 220 |
+
|
| 221 |
+
CUlaunchAttribute launchAttr[1];
|
| 222 |
+
launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
|
| 223 |
+
launchAttr[0].value.clusterDim.x = clusterDim;
|
| 224 |
+
launchAttr[0].value.clusterDim.y = 1;
|
| 225 |
+
launchAttr[0].value.clusterDim.z = 1;
|
| 226 |
+
CUlaunchConfig config;
|
| 227 |
+
config.gridDimX = clusterDim * maxActiveBlocks;
|
| 228 |
+
config.gridDimY = 1;
|
| 229 |
+
config.gridDimZ = 1;
|
| 230 |
+
config.blockDimX = 128;
|
| 231 |
+
config.blockDimY = 1;
|
| 232 |
+
config.blockDimZ = 1;
|
| 233 |
+
config.sharedMemBytes = shared;
|
| 234 |
+
config.hStream = 0;
|
| 235 |
+
config.numAttrs = 1;
|
| 236 |
+
config.attrs = launchAttr;
|
| 237 |
+
|
| 238 |
+
static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL;
|
| 239 |
+
INITIALIZE_FUNCTION_POINTER_IF_NULL(cuOccupancyMaxActiveClusters,
|
| 240 |
+
getCuOccupancyMaxActiveClustersHandle);
|
| 241 |
+
|
| 242 |
+
Py_BEGIN_ALLOW_THREADS;
|
| 243 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute(
|
| 244 |
+
func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
|
| 245 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
| 246 |
+
cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config));
|
| 247 |
+
Py_END_ALLOW_THREADS;
|
| 248 |
+
return PyLong_FromLong(maxActiveClusters);
|
| 249 |
+
|
| 250 |
+
cleanup:
|
| 251 |
+
return NULL;
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
|
| 255 |
+
long size;
|
| 256 |
+
if (!PyArg_ParseTuple(args, "l", &size)) {
|
| 257 |
+
return NULL;
|
| 258 |
+
}
|
| 259 |
+
if (size < 0) {
|
| 260 |
+
PyErr_SetString(PyExc_ValueError, "fifo size must be non-negative");
|
| 261 |
+
return NULL;
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
Py_BEGIN_ALLOW_THREADS;
|
| 265 |
+
|
| 266 |
+
// Ensure we have an active context.
|
| 267 |
+
CUcontext ctx = NULL;
|
| 268 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&ctx));
|
| 269 |
+
if (!ctx) {
|
| 270 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
| 271 |
+
cuDevicePrimaryCtxRetain(&ctx, /*device=*/0));
|
| 272 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(ctx));
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
// We can't set the fifo size after running a kernel that calls printf. This
|
| 276 |
+
// is true even if the set() call is a nop and the new size is the same as the
|
| 277 |
+
// old size.
|
| 278 |
+
//
|
| 279 |
+
// This is unfriendly, so check if the old size matches the new size, and skip
|
| 280 |
+
// the set() call if so.
|
| 281 |
+
size_t oldSize = 0;
|
| 282 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
| 283 |
+
cuCtxGetLimit(&oldSize, CU_LIMIT_PRINTF_FIFO_SIZE));
|
| 284 |
+
if (oldSize != size) {
|
| 285 |
+
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
|
| 286 |
+
cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, size));
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
Py_END_ALLOW_THREADS;
|
| 290 |
+
Py_RETURN_NONE;
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
static PyObject *PyCUtensorMap_alloc(PyTypeObject *type, Py_ssize_t n_items) {
|
| 294 |
+
PyCUtensorMapObject *self = NULL;
|
| 295 |
+
void *mem = NULL;
|
| 296 |
+
size_t size = type->tp_basicsize;
|
| 297 |
+
|
| 298 |
+
if (posix_memalign(&mem, 128, size) != 0) {
|
| 299 |
+
PyErr_NoMemory();
|
| 300 |
+
return NULL;
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
self = (PyCUtensorMapObject *)mem;
|
| 304 |
+
PyObject_INIT(self, type);
|
| 305 |
+
return (PyObject *)self;
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
static void PyCUtensorMap_dealloc(PyObject *self) {
|
| 309 |
+
Py_TYPE(self)->tp_free(self);
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
static void PyCUtensorMap_free(void *ptr) { free(ptr); }
|
| 313 |
+
|
| 314 |
+
// clang-format off
|
| 315 |
+
static PyTypeObject PyCUtensorMapType = {
|
| 316 |
+
PyVarObject_HEAD_INIT(NULL, 0)
|
| 317 |
+
.tp_name = "triton.backends.nvidia.PyCUtensorMap",
|
| 318 |
+
.tp_basicsize = sizeof(PyCUtensorMapObject),
|
| 319 |
+
.tp_itemsize = 0,
|
| 320 |
+
.tp_flags = Py_TPFLAGS_DEFAULT,
|
| 321 |
+
.tp_doc = "<PyCUtensorMap object>",
|
| 322 |
+
.tp_new = PyType_GenericNew,
|
| 323 |
+
.tp_alloc = PyCUtensorMap_alloc,
|
| 324 |
+
.tp_dealloc = (destructor)PyCUtensorMap_dealloc,
|
| 325 |
+
.tp_free = PyCUtensorMap_free,
|
| 326 |
+
};
|
| 327 |
+
// clang-format on
|
| 328 |
+
|
| 329 |
+
static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
|
| 330 |
+
unsigned long long global_address;
|
| 331 |
+
int swizzle;
|
| 332 |
+
int elemSize;
|
| 333 |
+
int elemType;
|
| 334 |
+
PyObject *blockSize;
|
| 335 |
+
PyObject *shape;
|
| 336 |
+
PyObject *strides;
|
| 337 |
+
int padding;
|
| 338 |
+
|
| 339 |
+
if (!PyArg_ParseTuple(args, "KiiiOOOi", &global_address, &swizzle, &elemSize,
|
| 340 |
+
&elemType, &blockSize, &shape, &strides, &padding)) {
|
| 341 |
+
return NULL;
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
PyCUtensorMapObject *desc = (PyCUtensorMapObject *)PyObject_CallObject(
|
| 345 |
+
(PyObject *)&PyCUtensorMapType, NULL);
|
| 346 |
+
if (!desc) {
|
| 347 |
+
return NULL;
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
PyObject *blockSizeFast = NULL;
|
| 351 |
+
PyObject *shapeFast = NULL;
|
| 352 |
+
PyObject *stridesFast = NULL;
|
| 353 |
+
|
| 354 |
+
uint32_t blockSizeInt[5];
|
| 355 |
+
uint64_t shapeInt[5];
|
| 356 |
+
uint64_t stridesLL[5];
|
| 357 |
+
|
| 358 |
+
blockSizeFast = PySequence_Fast(blockSize, "blockSize must be a sequence");
|
| 359 |
+
if (!blockSizeFast)
|
| 360 |
+
goto cleanup;
|
| 361 |
+
int rank = PySequence_Fast_GET_SIZE(blockSizeFast);
|
| 362 |
+
|
| 363 |
+
for (int i = 0; i < rank; ++i) {
|
| 364 |
+
PyObject *item = PySequence_Fast_GET_ITEM(blockSizeFast, i);
|
| 365 |
+
if (!PyLong_Check(item)) {
|
| 366 |
+
PyErr_SetString(PyExc_TypeError, "block size must be an int");
|
| 367 |
+
goto cleanup;
|
| 368 |
+
}
|
| 369 |
+
blockSizeInt[rank - i - 1] = PyLong_AsLongLong(item);
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
shapeFast = PySequence_Fast(shape, "shape must be a sequence");
|
| 373 |
+
if (!shapeFast)
|
| 374 |
+
goto cleanup;
|
| 375 |
+
|
| 376 |
+
if (rank != PySequence_Fast_GET_SIZE(shapeFast)) {
|
| 377 |
+
PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
|
| 378 |
+
goto cleanup;
|
| 379 |
+
}
|
| 380 |
+
for (int i = 0; i < rank; ++i) {
|
| 381 |
+
PyObject *item = PySequence_Fast_GET_ITEM(shapeFast, i);
|
| 382 |
+
if (!PyLong_Check(item)) {
|
| 383 |
+
PyErr_SetString(PyExc_TypeError, "shape must be an int");
|
| 384 |
+
goto cleanup;
|
| 385 |
+
}
|
| 386 |
+
shapeInt[rank - i - 1] = PyLong_AsLong(item);
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
stridesFast = PySequence_Fast(strides, "strides must be a sequence");
|
| 390 |
+
if (!stridesFast)
|
| 391 |
+
goto cleanup;
|
| 392 |
+
|
| 393 |
+
if (rank != PySequence_Fast_GET_SIZE(stridesFast)) {
|
| 394 |
+
PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
|
| 395 |
+
goto cleanup;
|
| 396 |
+
}
|
| 397 |
+
for (int i = 0; i + 1 < rank; ++i) {
|
| 398 |
+
PyObject *item = PySequence_Fast_GET_ITEM(stridesFast, i);
|
| 399 |
+
if (!PyLong_Check(item)) {
|
| 400 |
+
PyErr_SetString(PyExc_TypeError, "shape must be an int");
|
| 401 |
+
goto cleanup;
|
| 402 |
+
}
|
| 403 |
+
stridesLL[rank - i - 2] = elemSize * PyLong_AsLongLong(item);
|
| 404 |
+
}
|
| 405 |
+
stridesLL[rank - 1] =
|
| 406 |
+
shapeInt[rank - 1] * (rank == 1 ? elemSize : stridesLL[rank - 2]);
|
| 407 |
+
Py_DECREF(blockSizeFast);
|
| 408 |
+
blockSizeFast = NULL;
|
| 409 |
+
Py_DECREF(shapeFast);
|
| 410 |
+
shapeFast = NULL;
|
| 411 |
+
Py_DECREF(stridesFast);
|
| 412 |
+
stridesFast = NULL;
|
| 413 |
+
|
| 414 |
+
CUtensorMapFloatOOBfill fill =
|
| 415 |
+
(padding == 1) ? CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA
|
| 416 |
+
: CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
|
| 417 |
+
|
| 418 |
+
uint32_t elementStrides[5] = {1, 1, 1, 1, 1};
|
| 419 |
+
static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
|
| 420 |
+
INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
|
| 421 |
+
getCuTensorMapEncodeTiledHandle);
|
| 422 |
+
CUresult res = cuTensorMapEncodeTiled(
|
| 423 |
+
&desc->tensorMap, elemType, rank, (void *)global_address, shapeInt,
|
| 424 |
+
stridesLL, blockSizeInt, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
|
| 425 |
+
swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, fill);
|
| 426 |
+
if (res != CUDA_SUCCESS) {
|
| 427 |
+
const char *str;
|
| 428 |
+
cuGetErrorString(res, &str);
|
| 429 |
+
char err[4096] = {0};
|
| 430 |
+
size_t off = 0;
|
| 431 |
+
off += snprintf(
|
| 432 |
+
err + off, sizeof(err) - off,
|
| 433 |
+
"Triton Error [CUDA]: Failed to create tensor map descriptor: %s\n",
|
| 434 |
+
str ? str : "Unknown error");
|
| 435 |
+
off += snprintf(err + off, sizeof(err) - off,
|
| 436 |
+
"elemType=%d rank=%d global_address=0x%llx elemSize=%d "
|
| 437 |
+
"swizzle=%d padding=%d\n",
|
| 438 |
+
elemType, rank, (unsigned long long)global_address,
|
| 439 |
+
elemSize, swizzle, padding);
|
| 440 |
+
off += snprintf(err + off, sizeof(err) - off, "shape=[");
|
| 441 |
+
for (int i = 0; i < rank; ++i) {
|
| 442 |
+
off +=
|
| 443 |
+
snprintf(err + off, sizeof(err) - off, "%llu%s",
|
| 444 |
+
(unsigned long long)shapeInt[i], (i + 1 < rank) ? ", " : "");
|
| 445 |
+
}
|
| 446 |
+
off += snprintf(err + off, sizeof(err) - off, "]\n");
|
| 447 |
+
off += snprintf(err + off, sizeof(err) - off, "strides=[");
|
| 448 |
+
for (int i = 0; i < rank; ++i) {
|
| 449 |
+
off += snprintf(err + off, sizeof(err) - off, "%llu%s",
|
| 450 |
+
(unsigned long long)stridesLL[i],
|
| 451 |
+
(i + 1 < rank) ? ", " : "");
|
| 452 |
+
}
|
| 453 |
+
off += snprintf(err + off, sizeof(err) - off, "]\n");
|
| 454 |
+
off += snprintf(err + off, sizeof(err) - off, "blockSize=[");
|
| 455 |
+
for (int i = 0; i < rank; ++i) {
|
| 456 |
+
off += snprintf(err + off, sizeof(err) - off, "%u%s",
|
| 457 |
+
(unsigned)blockSizeInt[i], (i + 1 < rank) ? ", " : "");
|
| 458 |
+
}
|
| 459 |
+
off += snprintf(err + off, sizeof(err) - off, "] elementStrides=[");
|
| 460 |
+
for (int i = 0; i < rank; ++i) {
|
| 461 |
+
off += snprintf(err + off, sizeof(err) - off, "%u%s",
|
| 462 |
+
(unsigned)elementStrides[i], (i + 1 < rank) ? ", " : "");
|
| 463 |
+
}
|
| 464 |
+
off += snprintf(err + off, sizeof(err) - off, "]\n");
|
| 465 |
+
PyErr_SetString(PyExc_RuntimeError, err);
|
| 466 |
+
|
| 467 |
+
goto cleanup;
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
return (PyObject *)desc;
|
| 471 |
+
|
| 472 |
+
cleanup:
|
| 473 |
+
Py_XDECREF(blockSizeFast);
|
| 474 |
+
Py_XDECREF(shapeFast);
|
| 475 |
+
Py_XDECREF(stridesFast);
|
| 476 |
+
Py_XDECREF(desc);
|
| 477 |
+
return NULL;
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
static PyMethodDef ModuleMethods[] = {
|
| 481 |
+
{"load_binary", loadBinary, METH_VARARGS,
|
| 482 |
+
"Load provided cubin into CUDA driver"},
|
| 483 |
+
{"get_device_properties", getDeviceProperties, METH_VARARGS,
|
| 484 |
+
"Get the properties for a given device"},
|
| 485 |
+
{"cuOccupancyMaxActiveClusters", occupancyMaxActiveClusters, METH_VARARGS,
|
| 486 |
+
"Python interface for cuOccupancyMaxActiveClusters function"},
|
| 487 |
+
{"set_printf_fifo_size", setPrintfFifoSize, METH_VARARGS,
|
| 488 |
+
"Python interface for cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, x), which "
|
| 489 |
+
"controls how many bytes can be streamed from kernels before data starts "
|
| 490 |
+
"being dropped. This inherits all the limitations of this call; in "
|
| 491 |
+
"particular it's an error to change this value after launching any kernel "
|
| 492 |
+
"that calls printf()."},
|
| 493 |
+
{"fill_tma_descriptor", fillTMADescriptor, METH_VARARGS, "doc"},
|
| 494 |
+
|
| 495 |
+
{NULL, NULL, 0, NULL} // sentinel
|
| 496 |
+
};
|
| 497 |
+
|
| 498 |
+
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils",
|
| 499 |
+
NULL, // documentation
|
| 500 |
+
-1, // size
|
| 501 |
+
ModuleMethods};
|
| 502 |
+
|
| 503 |
+
PyMODINIT_FUNC PyInit_cuda_utils(void) {
|
| 504 |
+
if (PyType_Ready(&PyCUtensorMapType) < 0) {
|
| 505 |
+
return NULL;
|
| 506 |
+
}
|
| 507 |
+
|
| 508 |
+
PyObject *m = PyModule_Create(&ModuleDef);
|
| 509 |
+
if (m == NULL) {
|
| 510 |
+
return NULL;
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
PyModule_AddFunctions(m, ModuleMethods);
|
| 514 |
+
Py_INCREF(&PyCUtensorMapType);
|
| 515 |
+
PyModule_AddObject(m, "PyCUtensorMap", (PyObject *)&PyCUtensorMapType);
|
| 516 |
+
|
| 517 |
+
return m;
|
| 518 |
+
}
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/driver.py
ADDED
|
@@ -0,0 +1,764 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import os
|
| 3 |
+
import subprocess
|
| 4 |
+
import triton
|
| 5 |
+
import re
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from triton import knobs
|
| 8 |
+
from triton.runtime.build import compile_module_from_src
|
| 9 |
+
from triton.runtime import _allocation
|
| 10 |
+
from triton.backends.compiler import GPUTarget
|
| 11 |
+
from triton.backends.driver import GPUDriver
|
| 12 |
+
|
| 13 |
+
dirname = os.path.dirname(os.path.realpath(__file__))
|
| 14 |
+
include_dirs = [os.path.join(dirname, "include")]
|
| 15 |
+
libdevice_dir = os.path.join(dirname, "lib")
|
| 16 |
+
libraries = ['libcuda.so.1']
|
| 17 |
+
PyCUtensorMap = None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@functools.lru_cache()
|
| 21 |
+
def libcuda_dirs():
|
| 22 |
+
if env_libcuda_path := knobs.nvidia.libcuda_path:
|
| 23 |
+
return [env_libcuda_path]
|
| 24 |
+
|
| 25 |
+
libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
|
| 26 |
+
# each line looks like the following:
|
| 27 |
+
# libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
|
| 28 |
+
locs = [line.split()[-1] for line in libs.splitlines() if "libcuda.so.1" in line]
|
| 29 |
+
dirs = [os.path.dirname(loc) for loc in locs]
|
| 30 |
+
env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
|
| 31 |
+
if env_ld_library_path and not dirs:
|
| 32 |
+
dirs = [dir for dir in env_ld_library_path.split(":") if os.path.exists(os.path.join(dir, "libcuda.so.1"))]
|
| 33 |
+
msg = 'libcuda.so cannot found!\n'
|
| 34 |
+
if locs:
|
| 35 |
+
msg += 'Possible files are located at %s.' % str(locs)
|
| 36 |
+
msg += 'Please create a symlink of libcuda.so to any of the files.'
|
| 37 |
+
else:
|
| 38 |
+
msg += 'Please make sure GPU is set up and then run "/sbin/ldconfig"'
|
| 39 |
+
msg += ' (requires sudo) to refresh the linker cache.'
|
| 40 |
+
assert any(os.path.exists(os.path.join(path, 'libcuda.so.1')) for path in dirs), msg
|
| 41 |
+
return dirs
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@functools.lru_cache()
|
| 45 |
+
def library_dirs():
|
| 46 |
+
return [libdevice_dir, *libcuda_dirs()]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ------------------------
|
| 50 |
+
# Utils
|
| 51 |
+
# ------------------------
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class CudaUtils(object):
|
| 55 |
+
|
| 56 |
+
def __new__(cls):
|
| 57 |
+
if not hasattr(cls, "instance"):
|
| 58 |
+
cls.instance = super(CudaUtils, cls).__new__(cls)
|
| 59 |
+
return cls.instance
|
| 60 |
+
|
| 61 |
+
def __init__(self):
|
| 62 |
+
mod = compile_module_from_src(
|
| 63 |
+
src=Path(os.path.join(dirname, "driver.c")).read_text(),
|
| 64 |
+
name="cuda_utils",
|
| 65 |
+
library_dirs=library_dirs(),
|
| 66 |
+
include_dirs=include_dirs,
|
| 67 |
+
libraries=libraries,
|
| 68 |
+
)
|
| 69 |
+
global PyCUtensorMap
|
| 70 |
+
PyCUtensorMap = mod.PyCUtensorMap
|
| 71 |
+
self.load_binary = mod.load_binary
|
| 72 |
+
self.get_device_properties = mod.get_device_properties
|
| 73 |
+
self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
|
| 74 |
+
self.set_printf_fifo_size = mod.set_printf_fifo_size
|
| 75 |
+
self.fill_tma_descriptor = mod.fill_tma_descriptor
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ------------------------
|
| 79 |
+
# Launcher
|
| 80 |
+
# ------------------------
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def ty_to_cpp(ty):
|
| 84 |
+
if ty[0] == '*':
|
| 85 |
+
return "CUdeviceptr"
|
| 86 |
+
if ty.startswith("tensordesc"):
|
| 87 |
+
return "CUtensorMap"
|
| 88 |
+
return {
|
| 89 |
+
"i1": "int8_t",
|
| 90 |
+
"i8": "int8_t",
|
| 91 |
+
"i16": "int16_t",
|
| 92 |
+
"i32": "int32_t",
|
| 93 |
+
"i64": "int64_t",
|
| 94 |
+
"u1": "uint8_t",
|
| 95 |
+
"u8": "uint8_t",
|
| 96 |
+
"u16": "uint16_t",
|
| 97 |
+
"u32": "uint32_t",
|
| 98 |
+
"u64": "uint64_t",
|
| 99 |
+
"fp16": "double",
|
| 100 |
+
"bf16": "double",
|
| 101 |
+
"fp32": "double",
|
| 102 |
+
"f32": "double",
|
| 103 |
+
"fp64": "double",
|
| 104 |
+
"nvTmaDesc": "CUtensorMap",
|
| 105 |
+
}[ty]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
FLOAT_STORAGE_TYPE = {
|
| 109 |
+
"fp16": "uint16_t",
|
| 110 |
+
"bf16": "uint16_t",
|
| 111 |
+
"fp32": "uint32_t",
|
| 112 |
+
"f32": "uint32_t",
|
| 113 |
+
"fp64": "uint64_t",
|
| 114 |
+
}
|
| 115 |
+
FLOAT_PACK_FUNCTION = {
|
| 116 |
+
"fp16": "pack_fp16",
|
| 117 |
+
"bf16": "pack_bf16",
|
| 118 |
+
"fp32": "pack_fp32",
|
| 119 |
+
"f32": "pack_fp32",
|
| 120 |
+
"fp64": "pack_fp64",
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
_BASE_ARGS_FORMAT = "iiiKKppOOOOOO"
|
| 124 |
+
_BASE_ARGS_FORMAT_LEN = len(_BASE_ARGS_FORMAT)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def make_launcher(constants, signature, tensordesc_meta):
|
| 128 |
+
|
| 129 |
+
def _expand_signature(signature):
|
| 130 |
+
output = []
|
| 131 |
+
tensordesc_idx = 0
|
| 132 |
+
# Expand tensor descriptor arguments into either nvTmaDesc, shape and
|
| 133 |
+
# strides, or base pointer, shape and strides depending on whether the
|
| 134 |
+
# kernel was lowered to use the nvTmaDesc or not.
|
| 135 |
+
for sig in signature:
|
| 136 |
+
if isinstance(sig, str) and sig.startswith("tensordesc"):
|
| 137 |
+
meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
|
| 138 |
+
tensordesc_idx += 1
|
| 139 |
+
|
| 140 |
+
match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
|
| 141 |
+
dtype = match.group(1)
|
| 142 |
+
shape = match.group(2)
|
| 143 |
+
ndim = shape.count(",") + 1
|
| 144 |
+
|
| 145 |
+
if meta is None:
|
| 146 |
+
output.append("*" + dtype)
|
| 147 |
+
# Currently the host side tensor descriptors get passed in as a
|
| 148 |
+
# tensor desc, shape, and strides. We have no way to use these
|
| 149 |
+
# shape and strides when processing tensor descriptors which is
|
| 150 |
+
# why we provide our own decomposition above. Sadly this means
|
| 151 |
+
# we have to pass the shape and strides twice.
|
| 152 |
+
for _ in range(2 * ndim):
|
| 153 |
+
output.append("i64")
|
| 154 |
+
output.append("i1")
|
| 155 |
+
else:
|
| 156 |
+
output.append("nvTmaDesc")
|
| 157 |
+
|
| 158 |
+
for _ in range(ndim):
|
| 159 |
+
output.append("i32")
|
| 160 |
+
for _ in range(ndim):
|
| 161 |
+
output.append("i64")
|
| 162 |
+
else:
|
| 163 |
+
output.append(sig)
|
| 164 |
+
|
| 165 |
+
assert not tensordesc_meta or tensordesc_idx == len(tensordesc_meta)
|
| 166 |
+
return output
|
| 167 |
+
|
| 168 |
+
def _flatten_signature(sig, output):
|
| 169 |
+
# Flatten tuples
|
| 170 |
+
if isinstance(sig, tuple):
|
| 171 |
+
for x in sig:
|
| 172 |
+
_flatten_signature(x, output)
|
| 173 |
+
else:
|
| 174 |
+
output.append(sig)
|
| 175 |
+
|
| 176 |
+
def _extracted_type(ty):
|
| 177 |
+
if isinstance(ty, tuple):
|
| 178 |
+
val = ','.join(map(_extracted_type, ty))
|
| 179 |
+
return f"[{val}]"
|
| 180 |
+
if ty[0] == '*':
|
| 181 |
+
return "PyObject*"
|
| 182 |
+
if ty in ("constexpr", "nvTmaDesc"):
|
| 183 |
+
return "PyObject*"
|
| 184 |
+
return ty_to_cpp(ty)
|
| 185 |
+
|
| 186 |
+
def format_of(ty):
|
| 187 |
+
if isinstance(ty, tuple):
|
| 188 |
+
val = ''.join(map(format_of, ty))
|
| 189 |
+
return f"({val})"
|
| 190 |
+
if ty[0] == '*':
|
| 191 |
+
return "O"
|
| 192 |
+
if ty in ("constexpr", "nvTmaDesc"):
|
| 193 |
+
return "O"
|
| 194 |
+
if ty.startswith("tensordesc"):
|
| 195 |
+
return "O"
|
| 196 |
+
return {
|
| 197 |
+
"double": "d",
|
| 198 |
+
"long": "l",
|
| 199 |
+
"int8_t": "b",
|
| 200 |
+
"int16_t": "h",
|
| 201 |
+
"int32_t": "i",
|
| 202 |
+
"int64_t": "L",
|
| 203 |
+
"uint8_t": "B",
|
| 204 |
+
"uint16_t": "H",
|
| 205 |
+
"uint32_t": "I",
|
| 206 |
+
"uint64_t": "K",
|
| 207 |
+
}[ty_to_cpp(ty)]
|
| 208 |
+
|
| 209 |
+
expand_signature = _expand_signature(signature.values())
|
| 210 |
+
signature = {i: s for i, s in enumerate(expand_signature)}
|
| 211 |
+
|
| 212 |
+
args_format = ''.join([format_of(ty) for ty in signature.values()])
|
| 213 |
+
format = _BASE_ARGS_FORMAT + args_format
|
| 214 |
+
|
| 215 |
+
flat_signature = []
|
| 216 |
+
for sig in signature.values():
|
| 217 |
+
_flatten_signature(sig, flat_signature)
|
| 218 |
+
signature = {i: s for i, s in enumerate(flat_signature)}
|
| 219 |
+
args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
|
| 220 |
+
# Record the end of regular arguments;
|
| 221 |
+
# subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
|
| 222 |
+
arg_decl_list = []
|
| 223 |
+
for i, ty in signature.items():
|
| 224 |
+
if ty == "constexpr":
|
| 225 |
+
continue
|
| 226 |
+
if ty in FLOAT_STORAGE_TYPE:
|
| 227 |
+
arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
|
| 228 |
+
else:
|
| 229 |
+
arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
|
| 230 |
+
arg_decls = ', '.join(arg_decl_list)
|
| 231 |
+
internal_args_list = []
|
| 232 |
+
for i, ty in signature.items():
|
| 233 |
+
if ty[0] == "*":
|
| 234 |
+
internal_args_list.append(f"ptr_info{i}.dev_ptr")
|
| 235 |
+
elif ty in FLOAT_STORAGE_TYPE:
|
| 236 |
+
internal_args_list.append(f"_arg{i}_storage")
|
| 237 |
+
elif ty == "nvTmaDesc":
|
| 238 |
+
# Note: we have to dereference the pointer
|
| 239 |
+
internal_args_list.append(f"*tma_ptr{i}")
|
| 240 |
+
elif ty != "constexpr":
|
| 241 |
+
internal_args_list.append(f"_arg{i}")
|
| 242 |
+
params = range(len(signature))
|
| 243 |
+
|
| 244 |
+
# generate glue code
|
| 245 |
+
newline = '\n '
|
| 246 |
+
ptr_decls = [
|
| 247 |
+
f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;"
|
| 248 |
+
for i, ty in signature.items()
|
| 249 |
+
if ty[0] == "*"
|
| 250 |
+
]
|
| 251 |
+
tma_decls = [
|
| 252 |
+
f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items()
|
| 253 |
+
if ty == "nvTmaDesc"
|
| 254 |
+
]
|
| 255 |
+
float_storage_decls = [
|
| 256 |
+
f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
|
| 257 |
+
for i, ty in signature.items()
|
| 258 |
+
if ty in FLOAT_STORAGE_TYPE
|
| 259 |
+
]
|
| 260 |
+
params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
|
| 261 |
+
params.append("&global_scratch")
|
| 262 |
+
params.append("&profile_scratch")
|
| 263 |
+
src = f"""
|
| 264 |
+
#include \"cuda.h\"
|
| 265 |
+
#include <dlfcn.h>
|
| 266 |
+
#include <stdbool.h>
|
| 267 |
+
#include <stdlib.h>
|
| 268 |
+
#define PY_SSIZE_T_CLEAN
|
| 269 |
+
#include <Python.h>
|
| 270 |
+
|
| 271 |
+
typedef struct {{
|
| 272 |
+
PyObject_HEAD;
|
| 273 |
+
_Alignas(128) CUtensorMap tensorMap;
|
| 274 |
+
}} PyCUtensorMapObject;
|
| 275 |
+
|
| 276 |
+
static inline void gpuAssert(CUresult code, const char *file, int line)
|
| 277 |
+
{{
|
| 278 |
+
if (code != CUDA_SUCCESS)
|
| 279 |
+
{{
|
| 280 |
+
const char* prefix = "Triton Error [CUDA]: ";
|
| 281 |
+
const char* str;
|
| 282 |
+
cuGetErrorString(code, &str);
|
| 283 |
+
char err[1024] = {{0}};
|
| 284 |
+
strcat(err, prefix);
|
| 285 |
+
strcat(err, str);
|
| 286 |
+
PyGILState_STATE gil_state;
|
| 287 |
+
gil_state = PyGILState_Ensure();
|
| 288 |
+
PyErr_SetString(PyExc_RuntimeError, err);
|
| 289 |
+
PyGILState_Release(gil_state);
|
| 290 |
+
}}
|
| 291 |
+
}}
|
| 292 |
+
|
| 293 |
+
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
| 294 |
+
|
| 295 |
+
typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra);
|
| 296 |
+
|
| 297 |
+
static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
|
| 298 |
+
// Open the shared library
|
| 299 |
+
void* handle = dlopen("libcuda.so.1", RTLD_LAZY);
|
| 300 |
+
if (!handle) {{
|
| 301 |
+
PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1");
|
| 302 |
+
return NULL;
|
| 303 |
+
}}
|
| 304 |
+
// Clear any existing error
|
| 305 |
+
dlerror();
|
| 306 |
+
cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx");
|
| 307 |
+
// Check for errors
|
| 308 |
+
const char *dlsym_error = dlerror();
|
| 309 |
+
if (dlsym_error) {{
|
| 310 |
+
PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so.1");
|
| 311 |
+
return NULL;
|
| 312 |
+
}}
|
| 313 |
+
return cuLaunchKernelExHandle;
|
| 314 |
+
}}
|
| 315 |
+
|
| 316 |
+
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
|
| 317 |
+
void *params[] = {{ {', '.join(params)} }};
|
| 318 |
+
if (gridX*gridY*gridZ > 0) {{
|
| 319 |
+
// 4 attributes that we can currently pass maximum
|
| 320 |
+
CUlaunchAttribute launchAttr[4];
|
| 321 |
+
static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
|
| 322 |
+
if (cuLaunchKernelExHandle == NULL) {{
|
| 323 |
+
cuLaunchKernelExHandle = getLaunchKernelExHandle();
|
| 324 |
+
}}
|
| 325 |
+
CUlaunchConfig config;
|
| 326 |
+
config.gridDimX = gridX * num_ctas;
|
| 327 |
+
config.gridDimY = gridY;
|
| 328 |
+
config.gridDimZ = gridZ;
|
| 329 |
+
|
| 330 |
+
config.blockDimX = 32 * num_warps;
|
| 331 |
+
config.blockDimY = 1;
|
| 332 |
+
config.blockDimZ = 1;
|
| 333 |
+
config.sharedMemBytes = shared_memory;
|
| 334 |
+
config.hStream = stream;
|
| 335 |
+
config.attrs = launchAttr;
|
| 336 |
+
int num_attrs = 0;
|
| 337 |
+
|
| 338 |
+
if (launch_pdl != 0) {{
|
| 339 |
+
CUlaunchAttribute pdlAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION, .value = 1}};
|
| 340 |
+
launchAttr[num_attrs] = pdlAttr;
|
| 341 |
+
++num_attrs;
|
| 342 |
+
}}
|
| 343 |
+
|
| 344 |
+
if (launch_cooperative_grid != 0) {{
|
| 345 |
+
CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
|
| 346 |
+
launchAttr[num_attrs] = coopAttr;
|
| 347 |
+
++num_attrs;
|
| 348 |
+
}}
|
| 349 |
+
|
| 350 |
+
if (num_ctas != 1) {{
|
| 351 |
+
CUlaunchAttribute clusterAttr = {{}};
|
| 352 |
+
clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
|
| 353 |
+
clusterAttr.value.clusterDim.x = num_ctas;
|
| 354 |
+
clusterAttr.value.clusterDim.y = 1;
|
| 355 |
+
clusterAttr.value.clusterDim.z = 1;
|
| 356 |
+
launchAttr[num_attrs] = clusterAttr;
|
| 357 |
+
++num_attrs;
|
| 358 |
+
|
| 359 |
+
CUlaunchAttribute clusterSchedulingAttr = {{}};
|
| 360 |
+
clusterSchedulingAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
|
| 361 |
+
clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
|
| 362 |
+
launchAttr[num_attrs] = clusterSchedulingAttr;
|
| 363 |
+
++num_attrs;
|
| 364 |
+
}}
|
| 365 |
+
|
| 366 |
+
// num_ctas == 16 is non-portable. Does work for H100 and B200 tho
|
| 367 |
+
config.numAttrs = num_attrs;
|
| 368 |
+
if (num_ctas == 16) {{
|
| 369 |
+
CUDA_CHECK(cuFuncSetAttribute(
|
| 370 |
+
function,
|
| 371 |
+
CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED,
|
| 372 |
+
1
|
| 373 |
+
));
|
| 374 |
+
}}
|
| 375 |
+
|
| 376 |
+
CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
|
| 377 |
+
}}
|
| 378 |
+
}}
|
| 379 |
+
|
| 380 |
+
typedef struct _DevicePtrInfo {{
|
| 381 |
+
CUdeviceptr dev_ptr;
|
| 382 |
+
bool valid;
|
| 383 |
+
}} DevicePtrInfo;
|
| 384 |
+
|
| 385 |
+
static PyObject* data_ptr_str = NULL;
|
| 386 |
+
static PyObject* py_tensor_map_type = NULL;
|
| 387 |
+
|
| 388 |
+
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
| 389 |
+
DevicePtrInfo ptr_info;
|
| 390 |
+
ptr_info.dev_ptr = 0;
|
| 391 |
+
ptr_info.valid = true;
|
| 392 |
+
if (PyLong_Check(obj)) {{
|
| 393 |
+
ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
|
| 394 |
+
return ptr_info;
|
| 395 |
+
}}
|
| 396 |
+
if (obj == Py_None) {{
|
| 397 |
+
// valid nullptr
|
| 398 |
+
return ptr_info;
|
| 399 |
+
}}
|
| 400 |
+
PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
|
| 401 |
+
if (!ret) {{
|
| 402 |
+
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
| 403 |
+
ptr_info.valid = false;
|
| 404 |
+
goto cleanup;
|
| 405 |
+
}}
|
| 406 |
+
if (!PyLong_Check(ret)) {{
|
| 407 |
+
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
| 408 |
+
ptr_info.valid = false;
|
| 409 |
+
goto cleanup;
|
| 410 |
+
}}
|
| 411 |
+
ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
|
| 412 |
+
if(!ptr_info.dev_ptr)
|
| 413 |
+
return ptr_info;
|
| 414 |
+
uint64_t dev_ptr;
|
| 415 |
+
int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
| 416 |
+
if (status == CUDA_ERROR_INVALID_VALUE) {{
|
| 417 |
+
PyErr_Format(PyExc_ValueError,
|
| 418 |
+
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
| 419 |
+
ptr_info.valid = false;
|
| 420 |
+
}} else if (status != CUDA_SUCCESS) {{
|
| 421 |
+
CUDA_CHECK(status); // Catch any other cuda API errors
|
| 422 |
+
ptr_info.valid = false;
|
| 423 |
+
}}
|
| 424 |
+
ptr_info.dev_ptr = dev_ptr;
|
| 425 |
+
cleanup:
|
| 426 |
+
Py_XDECREF(ret);
|
| 427 |
+
return ptr_info;
|
| 428 |
+
|
| 429 |
+
}}
|
| 430 |
+
|
| 431 |
+
static inline CUtensorMap* getTmaDesc(PyObject *obj) {{
|
| 432 |
+
if (sizeof(CUtensorMap*) != 8) {{
|
| 433 |
+
PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation");
|
| 434 |
+
return NULL;
|
| 435 |
+
}}
|
| 436 |
+
|
| 437 |
+
if (Py_TYPE(obj) != (PyTypeObject*)py_tensor_map_type) {{
|
| 438 |
+
PyErr_Format(PyExc_TypeError, "object must be of type PyCUtensorMap, got %s", Py_TYPE(obj)->tp_name);
|
| 439 |
+
return NULL;
|
| 440 |
+
}}
|
| 441 |
+
|
| 442 |
+
CUtensorMap* map = &((PyCUtensorMapObject*)obj)->tensorMap;
|
| 443 |
+
uintptr_t align_128 = (uintptr_t)map & (128 - 1);
|
| 444 |
+
if (align_128 != 0) {{
|
| 445 |
+
PyErr_Format(PyExc_ValueError, "CUtensorMap must be aligned to 128B, but got (&map) mod 128 = %ld", align_128);
|
| 446 |
+
return NULL;
|
| 447 |
+
}}
|
| 448 |
+
return map;
|
| 449 |
+
}}
|
| 450 |
+
|
| 451 |
+
static void ensureCudaContext() {{
|
| 452 |
+
CUcontext pctx;
|
| 453 |
+
CUDA_CHECK(cuCtxGetCurrent(&pctx));
|
| 454 |
+
if (!pctx) {{
|
| 455 |
+
// Ensure device context.
|
| 456 |
+
CUdevice device;
|
| 457 |
+
CUDA_CHECK(cuDeviceGet(&device, 0));
|
| 458 |
+
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
|
| 459 |
+
CUDA_CHECK(cuCtxSetCurrent(pctx));
|
| 460 |
+
}}
|
| 461 |
+
}}
|
| 462 |
+
|
| 463 |
+
static uint16_t pack_fp16(double f) {{
|
| 464 |
+
uint16_t result;
|
| 465 |
+
// from https://github.com/python/pythoncapi-compat
|
| 466 |
+
#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
|
| 467 |
+
_PyFloat_Pack2(f, (unsigned char*)&result, 1);
|
| 468 |
+
#else
|
| 469 |
+
PyFloat_Pack2(f, (unsigned char*)&result, 1);
|
| 470 |
+
#endif
|
| 471 |
+
return result;
|
| 472 |
+
}}
|
| 473 |
+
|
| 474 |
+
static uint16_t pack_bf16(double f) {{
|
| 475 |
+
float f32 = (float)f;
|
| 476 |
+
uint32_t u32 = *(uint32_t*)&f32;
|
| 477 |
+
return (uint16_t)(u32 >> 16);
|
| 478 |
+
}}
|
| 479 |
+
|
| 480 |
+
static uint32_t pack_fp32(double f) {{
|
| 481 |
+
float f32 = (float)f;
|
| 482 |
+
return *(uint32_t*)&f32;
|
| 483 |
+
}}
|
| 484 |
+
|
| 485 |
+
static uint64_t pack_fp64(double f) {{
|
| 486 |
+
return *(uint64_t*)&f;
|
| 487 |
+
}}
|
| 488 |
+
|
| 489 |
+
static PyObject* launch(PyObject* self, PyObject* args) {{
|
| 490 |
+
// ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes
|
| 491 |
+
ensureCudaContext();
|
| 492 |
+
|
| 493 |
+
int gridX, gridY, gridZ;
|
| 494 |
+
uint64_t _stream;
|
| 495 |
+
uint64_t _function;
|
| 496 |
+
int launch_cooperative_grid;
|
| 497 |
+
int launch_pdl;
|
| 498 |
+
PyObject *launch_enter_hook = NULL;
|
| 499 |
+
PyObject *launch_exit_hook = NULL;
|
| 500 |
+
PyObject *kernel_metadata = NULL;
|
| 501 |
+
PyObject *launch_metadata = NULL;
|
| 502 |
+
PyObject *global_scratch_obj = NULL;
|
| 503 |
+
PyObject *profile_scratch_obj = NULL;
|
| 504 |
+
{newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
|
| 505 |
+
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
|
| 506 |
+
&_stream, &_function, &launch_cooperative_grid, &launch_pdl, &global_scratch_obj, &profile_scratch_obj,
|
| 507 |
+
&kernel_metadata, &launch_metadata,
|
| 508 |
+
&launch_enter_hook, &launch_exit_hook{args_list})) {{
|
| 509 |
+
return NULL;
|
| 510 |
+
}}
|
| 511 |
+
|
| 512 |
+
int num_warps, num_ctas, shared_memory;
|
| 513 |
+
if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
|
| 514 |
+
PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple");
|
| 515 |
+
return NULL;
|
| 516 |
+
}}
|
| 517 |
+
|
| 518 |
+
// extract launch metadata
|
| 519 |
+
if (launch_enter_hook != Py_None){{
|
| 520 |
+
PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
|
| 521 |
+
if (!ret)
|
| 522 |
+
return NULL;
|
| 523 |
+
Py_DECREF(ret);
|
| 524 |
+
}}
|
| 525 |
+
|
| 526 |
+
CUdeviceptr global_scratch = 0;
|
| 527 |
+
if (global_scratch_obj != Py_None) {{
|
| 528 |
+
DevicePtrInfo global_scratch_info = getPointer(global_scratch_obj, -1);
|
| 529 |
+
if (!global_scratch_info.valid) {{
|
| 530 |
+
return NULL;
|
| 531 |
+
}}
|
| 532 |
+
global_scratch = global_scratch_info.dev_ptr;
|
| 533 |
+
}}
|
| 534 |
+
|
| 535 |
+
CUdeviceptr profile_scratch = 0;
|
| 536 |
+
if (profile_scratch_obj != Py_None) {{
|
| 537 |
+
DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
|
| 538 |
+
if (!profile_scratch_info.valid) {{
|
| 539 |
+
return NULL;
|
| 540 |
+
}}
|
| 541 |
+
profile_scratch = profile_scratch_info.dev_ptr;
|
| 542 |
+
}}
|
| 543 |
+
|
| 544 |
+
// raise exception asap
|
| 545 |
+
{newline.join(ptr_decls)}
|
| 546 |
+
{newline.join(tma_decls)}
|
| 547 |
+
{newline.join(float_storage_decls)}
|
| 548 |
+
Py_BEGIN_ALLOW_THREADS;
|
| 549 |
+
_launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
|
| 550 |
+
Py_END_ALLOW_THREADS;
|
| 551 |
+
if (PyErr_Occurred()) {{
|
| 552 |
+
return NULL;
|
| 553 |
+
}}
|
| 554 |
+
|
| 555 |
+
if(launch_exit_hook != Py_None){{
|
| 556 |
+
PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
|
| 557 |
+
if (!ret)
|
| 558 |
+
return NULL;
|
| 559 |
+
Py_DECREF(ret);
|
| 560 |
+
}}
|
| 561 |
+
|
| 562 |
+
Py_RETURN_NONE;
|
| 563 |
+
}}
|
| 564 |
+
|
| 565 |
+
static PyMethodDef ModuleMethods[] = {{
|
| 566 |
+
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
|
| 567 |
+
{{NULL, NULL, 0, NULL}} // sentinel
|
| 568 |
+
}};
|
| 569 |
+
|
| 570 |
+
static struct PyModuleDef ModuleDef = {{
|
| 571 |
+
PyModuleDef_HEAD_INIT,
|
| 572 |
+
\"__triton_launcher\",
|
| 573 |
+
NULL, //documentation
|
| 574 |
+
-1, //size
|
| 575 |
+
ModuleMethods
|
| 576 |
+
}};
|
| 577 |
+
|
| 578 |
+
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
| 579 |
+
data_ptr_str = PyUnicode_InternFromString("data_ptr");
|
| 580 |
+
if(data_ptr_str == NULL) {{
|
| 581 |
+
return NULL;
|
| 582 |
+
}}
|
| 583 |
+
PyObject* driver_mod = PyImport_ImportModule("triton.backends.nvidia.driver");
|
| 584 |
+
if (driver_mod == NULL) {{
|
| 585 |
+
return NULL;
|
| 586 |
+
}}
|
| 587 |
+
py_tensor_map_type = PyObject_GetAttrString(driver_mod, "PyCUtensorMap");
|
| 588 |
+
if (py_tensor_map_type == NULL) {{
|
| 589 |
+
return NULL;
|
| 590 |
+
}}
|
| 591 |
+
|
| 592 |
+
PyObject *m = PyModule_Create(&ModuleDef);
|
| 593 |
+
if(m == NULL) {{
|
| 594 |
+
return NULL;
|
| 595 |
+
}}
|
| 596 |
+
PyModule_AddFunctions(m, ModuleMethods);
|
| 597 |
+
return m;
|
| 598 |
+
}}
|
| 599 |
+
"""
|
| 600 |
+
return src
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
# The TMA dtype enum values are slightly different on host vs device...
|
| 604 |
+
TMA_DTYPE_DEVICE_TO_HOST = dict((i, i) for i in range(16))
|
| 605 |
+
TMA_DTYPE_DEVICE_TO_HOST[8] = 10
|
| 606 |
+
TMA_DTYPE_DEVICE_TO_HOST[9] = 8
|
| 607 |
+
TMA_DTYPE_DEVICE_TO_HOST[10] = 9
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
def make_tensordesc_arg(arg, metadata):
|
| 611 |
+
if metadata is None:
|
| 612 |
+
# Currently the host side tensor descriptors get decomposed in
|
| 613 |
+
# the frontend to tensor desc, shape, and strides. We have no
|
| 614 |
+
# way to use these shape and strides when processing tensor
|
| 615 |
+
# descriptors which is why we provide our own decomposition
|
| 616 |
+
# above. Sadly this means we have to pass the shape and strides
|
| 617 |
+
# twice.
|
| 618 |
+
return [arg.base, *arg.shape, *arg.strides, arg.padding == "nan", *arg.shape, *arg.strides]
|
| 619 |
+
|
| 620 |
+
swizzle = metadata["swizzle"]
|
| 621 |
+
elem_size = metadata["elem_size"]
|
| 622 |
+
elem_type = metadata["elem_type"]
|
| 623 |
+
block_size = metadata["block_size"]
|
| 624 |
+
fp4_padded = metadata["fp4_padded"]
|
| 625 |
+
|
| 626 |
+
shape = arg.shape
|
| 627 |
+
strides = arg.strides
|
| 628 |
+
assert strides[-1] == 1
|
| 629 |
+
padding = 1 if arg.padding == "nan" else 0
|
| 630 |
+
|
| 631 |
+
if fp4_padded:
|
| 632 |
+
shape = list(shape)
|
| 633 |
+
shape[-1] *= 2
|
| 634 |
+
|
| 635 |
+
cu_tensor_map = triton.runtime.driver.active.utils.fill_tma_descriptor(
|
| 636 |
+
arg.base.data_ptr(),
|
| 637 |
+
swizzle,
|
| 638 |
+
elem_size,
|
| 639 |
+
TMA_DTYPE_DEVICE_TO_HOST[elem_type],
|
| 640 |
+
block_size,
|
| 641 |
+
shape,
|
| 642 |
+
strides,
|
| 643 |
+
padding,
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
return [cu_tensor_map, *shape, *strides]
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
def wrap_handle_tensordesc(launcher, signature, tensordesc_meta):
|
| 650 |
+
has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
|
| 651 |
+
if not has_tensor_desc_arg:
|
| 652 |
+
return launcher
|
| 653 |
+
|
| 654 |
+
tensordesc_indices = set(
|
| 655 |
+
[i for i, sig in enumerate(signature.values()) if isinstance(sig, str) and sig.startswith("tensordesc")])
|
| 656 |
+
assert not tensordesc_meta or len(tensordesc_meta) == len(tensordesc_indices)
|
| 657 |
+
if not tensordesc_meta:
|
| 658 |
+
tensordesc_meta = [None] * len(tensordesc_indices)
|
| 659 |
+
|
| 660 |
+
def inner(*args):
|
| 661 |
+
final_args = list(args[:_BASE_ARGS_FORMAT_LEN])
|
| 662 |
+
tensordesc_idx = 0
|
| 663 |
+
for i, arg in enumerate(args[_BASE_ARGS_FORMAT_LEN:]):
|
| 664 |
+
if i in tensordesc_indices:
|
| 665 |
+
final_args.extend(make_tensordesc_arg(arg, tensordesc_meta[tensordesc_idx]))
|
| 666 |
+
tensordesc_idx += 1
|
| 667 |
+
else:
|
| 668 |
+
final_args.append(arg)
|
| 669 |
+
return launcher(*final_args)
|
| 670 |
+
|
| 671 |
+
return inner
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
class CudaLauncher(object):
|
| 675 |
+
|
| 676 |
+
def __init__(self, src, metadata):
|
| 677 |
+
constants = src.constants if hasattr(src, "constants") else dict()
|
| 678 |
+
arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
|
| 679 |
+
constants = {arg_idx(idx): value for idx, value in constants.items()}
|
| 680 |
+
signature = {idx: value for idx, value in src.signature.items()}
|
| 681 |
+
tensordesc_meta = getattr(metadata, "tensordesc_meta", None)
|
| 682 |
+
src = make_launcher(constants, signature, tensordesc_meta)
|
| 683 |
+
mod = compile_module_from_src(
|
| 684 |
+
src=src,
|
| 685 |
+
name="__triton_launcher",
|
| 686 |
+
library_dirs=library_dirs(),
|
| 687 |
+
include_dirs=include_dirs,
|
| 688 |
+
libraries=libraries,
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
self.num_ctas = getattr(metadata, "num_ctas", 1)
|
| 692 |
+
self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta)
|
| 693 |
+
self.global_scratch_size = metadata.global_scratch_size
|
| 694 |
+
self.global_scratch_align = metadata.global_scratch_align
|
| 695 |
+
self.profile_scratch_size = metadata.profile_scratch_size
|
| 696 |
+
self.profile_scratch_align = metadata.profile_scratch_align
|
| 697 |
+
self.launch_cooperative_grid = metadata.launch_cooperative_grid
|
| 698 |
+
self.launch_pdl = metadata.launch_pdl
|
| 699 |
+
|
| 700 |
+
def __call__(self, gridX, gridY, gridZ, stream, function, *args):
|
| 701 |
+
|
| 702 |
+
def allocate_scratch(size, align, allocator):
|
| 703 |
+
if size > 0:
|
| 704 |
+
grid_size = gridX * gridY * gridZ
|
| 705 |
+
alloc_size = grid_size * self.num_ctas * size
|
| 706 |
+
alloc_fn = allocator.get()
|
| 707 |
+
return alloc_fn(alloc_size, align, stream)
|
| 708 |
+
return None
|
| 709 |
+
|
| 710 |
+
global_scratch = allocate_scratch(self.global_scratch_size, self.global_scratch_align, _allocation._allocator)
|
| 711 |
+
profile_scratch = allocate_scratch(self.profile_scratch_size, self.profile_scratch_align,
|
| 712 |
+
_allocation._profile_allocator)
|
| 713 |
+
self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
|
| 714 |
+
global_scratch, profile_scratch, *args)
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
class CudaDriver(GPUDriver):
|
| 718 |
+
|
| 719 |
+
def __init__(self):
|
| 720 |
+
self.utils = CudaUtils() # TODO: make static
|
| 721 |
+
self.launcher_cls = CudaLauncher
|
| 722 |
+
super().__init__()
|
| 723 |
+
|
| 724 |
+
def get_current_target(self):
|
| 725 |
+
device = self.get_current_device()
|
| 726 |
+
capability = self.get_device_capability(device)
|
| 727 |
+
capability = capability[0] * 10 + capability[1]
|
| 728 |
+
warp_size = 32
|
| 729 |
+
return GPUTarget("cuda", capability, warp_size)
|
| 730 |
+
|
| 731 |
+
def get_active_torch_device(self):
|
| 732 |
+
import torch
|
| 733 |
+
return torch.device("cuda", self.get_current_device())
|
| 734 |
+
|
| 735 |
+
def get_device_interface(self):
|
| 736 |
+
import torch
|
| 737 |
+
return torch.cuda
|
| 738 |
+
|
| 739 |
+
@staticmethod
|
| 740 |
+
def is_active():
|
| 741 |
+
try:
|
| 742 |
+
import torch
|
| 743 |
+
return torch.cuda.is_available() and (torch.version.hip is None)
|
| 744 |
+
except ImportError:
|
| 745 |
+
return False
|
| 746 |
+
|
| 747 |
+
def map_python_to_cpp_type(self, ty: str) -> str:
|
| 748 |
+
return ty_to_cpp(ty)
|
| 749 |
+
|
| 750 |
+
def get_benchmarker(self):
|
| 751 |
+
from triton.testing import do_bench
|
| 752 |
+
return do_bench
|
| 753 |
+
|
| 754 |
+
def get_empty_cache_for_benchmark(self):
|
| 755 |
+
import torch
|
| 756 |
+
|
| 757 |
+
# We maintain a buffer of 256 MB that we clear
|
| 758 |
+
# before each kernel call to make sure that the L2 cache
|
| 759 |
+
# doesn't contain any input data before the run
|
| 760 |
+
cache_size = 256 * 1024 * 1024
|
| 761 |
+
return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
|
| 762 |
+
|
| 763 |
+
def clear_cache(self, cache):
|
| 764 |
+
cache.zero_()
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/cudaGL.h
ADDED
|
@@ -0,0 +1,608 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 1993-2014 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
#ifndef CUDAGL_H
|
| 51 |
+
#define CUDAGL_H
|
| 52 |
+
|
| 53 |
+
#include <cuda.h>
|
| 54 |
+
#include <GL/gl.h>
|
| 55 |
+
|
| 56 |
+
#if defined(__CUDA_API_VERSION_INTERNAL) || defined(__DOXYGEN_ONLY__) || defined(CUDA_ENABLE_DEPRECATED)
|
| 57 |
+
#define __CUDA_DEPRECATED
|
| 58 |
+
#elif defined(_MSC_VER)
|
| 59 |
+
#define __CUDA_DEPRECATED __declspec(deprecated)
|
| 60 |
+
#elif defined(__GNUC__)
|
| 61 |
+
#define __CUDA_DEPRECATED __attribute__((deprecated))
|
| 62 |
+
#else
|
| 63 |
+
#define __CUDA_DEPRECATED
|
| 64 |
+
#endif
|
| 65 |
+
|
| 66 |
+
#ifdef CUDA_FORCE_API_VERSION
|
| 67 |
+
#error "CUDA_FORCE_API_VERSION is no longer supported."
|
| 68 |
+
#endif
|
| 69 |
+
|
| 70 |
+
#if defined(__CUDA_API_VERSION_INTERNAL) || defined(CUDA_API_PER_THREAD_DEFAULT_STREAM)
|
| 71 |
+
#define __CUDA_API_PER_THREAD_DEFAULT_STREAM
|
| 72 |
+
#define __CUDA_API_PTDS(api) api ## _ptds
|
| 73 |
+
#define __CUDA_API_PTSZ(api) api ## _ptsz
|
| 74 |
+
#else
|
| 75 |
+
#define __CUDA_API_PTDS(api) api
|
| 76 |
+
#define __CUDA_API_PTSZ(api) api
|
| 77 |
+
#endif
|
| 78 |
+
|
| 79 |
+
#define cuGLCtxCreate cuGLCtxCreate_v2
|
| 80 |
+
#define cuGLMapBufferObject __CUDA_API_PTDS(cuGLMapBufferObject_v2)
|
| 81 |
+
#define cuGLMapBufferObjectAsync __CUDA_API_PTSZ(cuGLMapBufferObjectAsync_v2)
|
| 82 |
+
#define cuGLGetDevices cuGLGetDevices_v2
|
| 83 |
+
|
| 84 |
+
#ifdef __cplusplus
|
| 85 |
+
extern "C" {
|
| 86 |
+
#endif
|
| 87 |
+
|
| 88 |
+
/**
|
| 89 |
+
* \file cudaGL.h
|
| 90 |
+
* \brief Header file for the OpenGL interoperability functions of the
|
| 91 |
+
* low-level CUDA driver application programming interface.
|
| 92 |
+
*/
|
| 93 |
+
|
| 94 |
+
/**
|
| 95 |
+
* \defgroup CUDA_GL OpenGL Interoperability
|
| 96 |
+
* \ingroup CUDA_DRIVER
|
| 97 |
+
*
|
| 98 |
+
* ___MANBRIEF___ OpenGL interoperability functions of the low-level CUDA
|
| 99 |
+
* driver API (___CURRENT_FILE___) ___ENDMANBRIEF___
|
| 100 |
+
*
|
| 101 |
+
* This section describes the OpenGL interoperability functions of the
|
| 102 |
+
* low-level CUDA driver application programming interface. Note that mapping
|
| 103 |
+
* of OpenGL resources is performed with the graphics API agnostic, resource
|
| 104 |
+
* mapping interface described in \ref CUDA_GRAPHICS "Graphics Interoperability".
|
| 105 |
+
*
|
| 106 |
+
* @{
|
| 107 |
+
*/
|
| 108 |
+
|
| 109 |
+
#if defined(_WIN32)
|
| 110 |
+
#if !defined(WGL_NV_gpu_affinity)
|
| 111 |
+
typedef void* HGPUNV;
|
| 112 |
+
#endif
|
| 113 |
+
#endif /* _WIN32 */
|
| 114 |
+
|
| 115 |
+
/**
|
| 116 |
+
* \brief Registers an OpenGL buffer object
|
| 117 |
+
*
|
| 118 |
+
* Registers the buffer object specified by \p buffer for access by
|
| 119 |
+
* CUDA. A handle to the registered object is returned as \p
|
| 120 |
+
* pCudaResource. The register flags \p Flags specify the intended usage,
|
| 121 |
+
* as follows:
|
| 122 |
+
*
|
| 123 |
+
* - ::CU_GRAPHICS_REGISTER_FLAGS_NONE: Specifies no hints about how this
|
| 124 |
+
* resource will be used. It is therefore assumed that this resource will be
|
| 125 |
+
* read from and written to by CUDA. This is the default value.
|
| 126 |
+
* - ::CU_GRAPHICS_REGISTER_FLAGS_READ_ONLY: Specifies that CUDA
|
| 127 |
+
* will not write to this resource.
|
| 128 |
+
* - ::CU_GRAPHICS_REGISTER_FLAGS_WRITE_DISCARD: Specifies that
|
| 129 |
+
* CUDA will not read from this resource and will write over the
|
| 130 |
+
* entire contents of the resource, so none of the data previously
|
| 131 |
+
* stored in the resource will be preserved.
|
| 132 |
+
*
|
| 133 |
+
* \param pCudaResource - Pointer to the returned object handle
|
| 134 |
+
* \param buffer - name of buffer object to be registered
|
| 135 |
+
* \param Flags - Register flags
|
| 136 |
+
*
|
| 137 |
+
* \return
|
| 138 |
+
* ::CUDA_SUCCESS,
|
| 139 |
+
* ::CUDA_ERROR_INVALID_HANDLE,
|
| 140 |
+
* ::CUDA_ERROR_ALREADY_MAPPED,
|
| 141 |
+
* ::CUDA_ERROR_INVALID_CONTEXT,
|
| 142 |
+
* ::CUDA_ERROR_OPERATING_SYSTEM
|
| 143 |
+
* \notefnerr
|
| 144 |
+
*
|
| 145 |
+
* \sa
|
| 146 |
+
* ::cuGraphicsUnregisterResource,
|
| 147 |
+
* ::cuGraphicsMapResources,
|
| 148 |
+
* ::cuGraphicsResourceGetMappedPointer,
|
| 149 |
+
* ::cudaGraphicsGLRegisterBuffer
|
| 150 |
+
*/
|
| 151 |
+
CUresult CUDAAPI cuGraphicsGLRegisterBuffer(CUgraphicsResource *pCudaResource, GLuint buffer, unsigned int Flags);
|
| 152 |
+
|
| 153 |
+
/**
|
| 154 |
+
* \brief Register an OpenGL texture or renderbuffer object
|
| 155 |
+
*
|
| 156 |
+
* Registers the texture or renderbuffer object specified by \p image for access by CUDA.
|
| 157 |
+
* A handle to the registered object is returned as \p pCudaResource.
|
| 158 |
+
*
|
| 159 |
+
* \p target must match the type of the object, and must be one of ::GL_TEXTURE_2D,
|
| 160 |
+
* ::GL_TEXTURE_RECTANGLE, ::GL_TEXTURE_CUBE_MAP, ::GL_TEXTURE_3D, ::GL_TEXTURE_2D_ARRAY,
|
| 161 |
+
* or ::GL_RENDERBUFFER.
|
| 162 |
+
*
|
| 163 |
+
* The register flags \p Flags specify the intended usage, as follows:
|
| 164 |
+
*
|
| 165 |
+
* - ::CU_GRAPHICS_REGISTER_FLAGS_NONE: Specifies no hints about how this
|
| 166 |
+
* resource will be used. It is therefore assumed that this resource will be
|
| 167 |
+
* read from and written to by CUDA. This is the default value.
|
| 168 |
+
* - ::CU_GRAPHICS_REGISTER_FLAGS_READ_ONLY: Specifies that CUDA
|
| 169 |
+
* will not write to this resource.
|
| 170 |
+
* - ::CU_GRAPHICS_REGISTER_FLAGS_WRITE_DISCARD: Specifies that
|
| 171 |
+
* CUDA will not read from this resource and will write over the
|
| 172 |
+
* entire contents of the resource, so none of the data previously
|
| 173 |
+
* stored in the resource will be preserved.
|
| 174 |
+
* - ::CU_GRAPHICS_REGISTER_FLAGS_SURFACE_LDST: Specifies that CUDA will
|
| 175 |
+
* bind this resource to a surface reference.
|
| 176 |
+
* - ::CU_GRAPHICS_REGISTER_FLAGS_TEXTURE_GATHER: Specifies that CUDA will perform
|
| 177 |
+
* texture gather operations on this resource.
|
| 178 |
+
*
|
| 179 |
+
* The following image formats are supported. For brevity's sake, the list is abbreviated.
|
| 180 |
+
* For ex., {GL_R, GL_RG} X {8, 16} would expand to the following 4 formats
|
| 181 |
+
* {GL_R8, GL_R16, GL_RG8, GL_RG16} :
|
| 182 |
+
* - GL_RED, GL_RG, GL_RGBA, GL_LUMINANCE, GL_ALPHA, GL_LUMINANCE_ALPHA, GL_INTENSITY
|
| 183 |
+
* - {GL_R, GL_RG, GL_RGBA} X {8, 16, 16F, 32F, 8UI, 16UI, 32UI, 8I, 16I, 32I}
|
| 184 |
+
* - {GL_LUMINANCE, GL_ALPHA, GL_LUMINANCE_ALPHA, GL_INTENSITY} X
|
| 185 |
+
* {8, 16, 16F_ARB, 32F_ARB, 8UI_EXT, 16UI_EXT, 32UI_EXT, 8I_EXT, 16I_EXT, 32I_EXT}
|
| 186 |
+
*
|
| 187 |
+
* The following image classes are currently disallowed:
|
| 188 |
+
* - Textures with borders
|
| 189 |
+
* - Multisampled renderbuffers
|
| 190 |
+
*
|
| 191 |
+
* \param pCudaResource - Pointer to the returned object handle
|
| 192 |
+
* \param image - name of texture or renderbuffer object to be registered
|
| 193 |
+
* \param target - Identifies the type of object specified by \p image
|
| 194 |
+
* \param Flags - Register flags
|
| 195 |
+
*
|
| 196 |
+
* \return
|
| 197 |
+
* ::CUDA_SUCCESS,
|
| 198 |
+
* ::CUDA_ERROR_INVALID_HANDLE,
|
| 199 |
+
* ::CUDA_ERROR_ALREADY_MAPPED,
|
| 200 |
+
* ::CUDA_ERROR_INVALID_CONTEXT,
|
| 201 |
+
* ::CUDA_ERROR_OPERATING_SYSTEM
|
| 202 |
+
* \notefnerr
|
| 203 |
+
*
|
| 204 |
+
* \sa
|
| 205 |
+
* ::cuGraphicsUnregisterResource,
|
| 206 |
+
* ::cuGraphicsMapResources,
|
| 207 |
+
* ::cuGraphicsSubResourceGetMappedArray,
|
| 208 |
+
* ::cudaGraphicsGLRegisterImage
|
| 209 |
+
*/
|
| 210 |
+
CUresult CUDAAPI cuGraphicsGLRegisterImage(CUgraphicsResource *pCudaResource, GLuint image, GLenum target, unsigned int Flags);
|
| 211 |
+
|
| 212 |
+
#ifdef _WIN32
|
| 213 |
+
/**
|
| 214 |
+
* \brief Gets the CUDA device associated with hGpu
|
| 215 |
+
*
|
| 216 |
+
* Returns in \p *pDevice the CUDA device associated with a \p hGpu, if
|
| 217 |
+
* applicable.
|
| 218 |
+
*
|
| 219 |
+
* \param pDevice - Device associated with hGpu
|
| 220 |
+
* \param hGpu - Handle to a GPU, as queried via ::WGL_NV_gpu_affinity()
|
| 221 |
+
*
|
| 222 |
+
* \return
|
| 223 |
+
* ::CUDA_SUCCESS,
|
| 224 |
+
* ::CUDA_ERROR_DEINITIALIZED,
|
| 225 |
+
* ::CUDA_ERROR_NOT_INITIALIZED,
|
| 226 |
+
* ::CUDA_ERROR_INVALID_CONTEXT,
|
| 227 |
+
* ::CUDA_ERROR_INVALID_VALUE
|
| 228 |
+
* \notefnerr
|
| 229 |
+
*
|
| 230 |
+
* \sa ::cuGLMapBufferObject,
|
| 231 |
+
* ::cuGLRegisterBufferObject, ::cuGLUnmapBufferObject,
|
| 232 |
+
* ::cuGLUnregisterBufferObject, ::cuGLUnmapBufferObjectAsync,
|
| 233 |
+
* ::cuGLSetBufferObjectMapFlags,
|
| 234 |
+
* ::cudaWGLGetDevice
|
| 235 |
+
*/
|
| 236 |
+
CUresult CUDAAPI cuWGLGetDevice(CUdevice *pDevice, HGPUNV hGpu);
|
| 237 |
+
#endif /* _WIN32 */
|
| 238 |
+
|
| 239 |
+
/**
|
| 240 |
+
* CUDA devices corresponding to an OpenGL device
|
| 241 |
+
*/
|
| 242 |
+
typedef enum CUGLDeviceList_enum {
|
| 243 |
+
CU_GL_DEVICE_LIST_ALL = 0x01, /**< The CUDA devices for all GPUs used by the current OpenGL context */
|
| 244 |
+
CU_GL_DEVICE_LIST_CURRENT_FRAME = 0x02, /**< The CUDA devices for the GPUs used by the current OpenGL context in its currently rendering frame */
|
| 245 |
+
CU_GL_DEVICE_LIST_NEXT_FRAME = 0x03, /**< The CUDA devices for the GPUs to be used by the current OpenGL context in the next frame */
|
| 246 |
+
} CUGLDeviceList;
|
| 247 |
+
|
| 248 |
+
/**
|
| 249 |
+
* \brief Gets the CUDA devices associated with the current OpenGL context
|
| 250 |
+
*
|
| 251 |
+
* Returns in \p *pCudaDeviceCount the number of CUDA-compatible devices
|
| 252 |
+
* corresponding to the current OpenGL context. Also returns in \p *pCudaDevices
|
| 253 |
+
* at most cudaDeviceCount of the CUDA-compatible devices corresponding to
|
| 254 |
+
* the current OpenGL context. If any of the GPUs being used by the current OpenGL
|
| 255 |
+
* context are not CUDA capable then the call will return CUDA_ERROR_NO_DEVICE.
|
| 256 |
+
*
|
| 257 |
+
* The \p deviceList argument may be any of the following:
|
| 258 |
+
* - ::CU_GL_DEVICE_LIST_ALL: Query all devices used by the current OpenGL context.
|
| 259 |
+
* - ::CU_GL_DEVICE_LIST_CURRENT_FRAME: Query the devices used by the current OpenGL context to
|
| 260 |
+
* render the current frame (in SLI).
|
| 261 |
+
* - ::CU_GL_DEVICE_LIST_NEXT_FRAME: Query the devices used by the current OpenGL context to
|
| 262 |
+
* render the next frame (in SLI). Note that this is a prediction, it can't be guaranteed that
|
| 263 |
+
* this is correct in all cases.
|
| 264 |
+
*
|
| 265 |
+
* \param pCudaDeviceCount - Returned number of CUDA devices.
|
| 266 |
+
* \param pCudaDevices - Returned CUDA devices.
|
| 267 |
+
* \param cudaDeviceCount - The size of the output device array pCudaDevices.
|
| 268 |
+
* \param deviceList - The set of devices to return.
|
| 269 |
+
*
|
| 270 |
+
* \return
|
| 271 |
+
* ::CUDA_SUCCESS,
|
| 272 |
+
* ::CUDA_ERROR_NO_DEVICE,
|
| 273 |
+
* ::CUDA_ERROR_INVALID_VALUE,
|
| 274 |
+
* ::CUDA_ERROR_INVALID_CONTEXT,
|
| 275 |
+
* ::CUDA_ERROR_INVALID_GRAPHICS_CONTEXT,
|
| 276 |
+
* ::CUDA_ERROR_OPERATING_SYSTEM
|
| 277 |
+
*
|
| 278 |
+
* \notefnerr
|
| 279 |
+
*
|
| 280 |
+
* \sa
|
| 281 |
+
* ::cuWGLGetDevice,
|
| 282 |
+
* ::cudaGLGetDevices
|
| 283 |
+
*/
|
| 284 |
+
CUresult CUDAAPI cuGLGetDevices(unsigned int *pCudaDeviceCount, CUdevice *pCudaDevices, unsigned int cudaDeviceCount, CUGLDeviceList deviceList);
|
| 285 |
+
|
| 286 |
+
/**
|
| 287 |
+
* \defgroup CUDA_GL_DEPRECATED OpenGL Interoperability [DEPRECATED]
|
| 288 |
+
*
|
| 289 |
+
* ___MANBRIEF___ deprecated OpenGL interoperability functions of the low-level
|
| 290 |
+
* CUDA driver API (___CURRENT_FILE___) ___ENDMANBRIEF___
|
| 291 |
+
*
|
| 292 |
+
* This section describes deprecated OpenGL interoperability functionality.
|
| 293 |
+
*
|
| 294 |
+
* @{
|
| 295 |
+
*/
|
| 296 |
+
|
| 297 |
+
/** Flags to map or unmap a resource */
|
| 298 |
+
typedef enum CUGLmap_flags_enum {
|
| 299 |
+
CU_GL_MAP_RESOURCE_FLAGS_NONE = 0x00,
|
| 300 |
+
CU_GL_MAP_RESOURCE_FLAGS_READ_ONLY = 0x01,
|
| 301 |
+
CU_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD = 0x02,
|
| 302 |
+
} CUGLmap_flags;
|
| 303 |
+
|
| 304 |
+
/**
|
| 305 |
+
* \brief Create a CUDA context for interoperability with OpenGL
|
| 306 |
+
*
|
| 307 |
+
* \deprecated This function is deprecated as of Cuda 5.0.
|
| 308 |
+
*
|
| 309 |
+
* This function is deprecated and should no longer be used. It is
|
| 310 |
+
* no longer necessary to associate a CUDA context with an OpenGL
|
| 311 |
+
* context in order to achieve maximum interoperability performance.
|
| 312 |
+
*
|
| 313 |
+
* \param pCtx - Returned CUDA context
|
| 314 |
+
* \param Flags - Options for CUDA context creation
|
| 315 |
+
* \param device - Device on which to create the context
|
| 316 |
+
*
|
| 317 |
+
* \return
|
| 318 |
+
* ::CUDA_SUCCESS,
|
| 319 |
+
* ::CUDA_ERROR_DEINITIALIZED,
|
| 320 |
+
* ::CUDA_ERROR_NOT_INITIALIZED,
|
| 321 |
+
* ::CUDA_ERROR_INVALID_CONTEXT,
|
| 322 |
+
* ::CUDA_ERROR_INVALID_VALUE,
|
| 323 |
+
* ::CUDA_ERROR_OUT_OF_MEMORY
|
| 324 |
+
* \notefnerr
|
| 325 |
+
*
|
| 326 |
+
* \sa ::cuCtxCreate, ::cuGLInit, ::cuGLMapBufferObject,
|
| 327 |
+
* ::cuGLRegisterBufferObject, ::cuGLUnmapBufferObject,
|
| 328 |
+
* ::cuGLUnregisterBufferObject, ::cuGLMapBufferObjectAsync,
|
| 329 |
+
* ::cuGLUnmapBufferObjectAsync, ::cuGLSetBufferObjectMapFlags,
|
| 330 |
+
* ::cuWGLGetDevice
|
| 331 |
+
*/
|
| 332 |
+
__CUDA_DEPRECATED CUresult CUDAAPI cuGLCtxCreate(CUcontext *pCtx, unsigned int Flags, CUdevice device );
|
| 333 |
+
|
| 334 |
+
/**
|
| 335 |
+
* \brief Initializes OpenGL interoperability
|
| 336 |
+
*
|
| 337 |
+
* \deprecated This function is deprecated as of Cuda 3.0.
|
| 338 |
+
*
|
| 339 |
+
* Initializes OpenGL interoperability. This function is deprecated
|
| 340 |
+
* and calling it is no longer required. It may fail if the needed
|
| 341 |
+
* OpenGL driver facilities are not available.
|
| 342 |
+
*
|
| 343 |
+
* \return
|
| 344 |
+
* ::CUDA_SUCCESS,
|
| 345 |
+
* ::CUDA_ERROR_DEINITIALIZED,
|
| 346 |
+
* ::CUDA_ERROR_NOT_INITIALIZED,
|
| 347 |
+
* ::CUDA_ERROR_INVALID_CONTEXT,
|
| 348 |
+
* ::CUDA_ERROR_UNKNOWN
|
| 349 |
+
* \notefnerr
|
| 350 |
+
*
|
| 351 |
+
* \sa ::cuGLMapBufferObject,
|
| 352 |
+
* ::cuGLRegisterBufferObject, ::cuGLUnmapBufferObject,
|
| 353 |
+
* ::cuGLUnregisterBufferObject, ::cuGLMapBufferObjectAsync,
|
| 354 |
+
* ::cuGLUnmapBufferObjectAsync, ::cuGLSetBufferObjectMapFlags,
|
| 355 |
+
* ::cuWGLGetDevice
|
| 356 |
+
*/
|
| 357 |
+
__CUDA_DEPRECATED CUresult CUDAAPI cuGLInit(void);
|
| 358 |
+
|
| 359 |
+
/**
|
| 360 |
+
* \brief Registers an OpenGL buffer object
|
| 361 |
+
*
|
| 362 |
+
* \deprecated This function is deprecated as of Cuda 3.0.
|
| 363 |
+
*
|
| 364 |
+
* Registers the buffer object specified by \p buffer for access by
|
| 365 |
+
* CUDA. This function must be called before CUDA can map the buffer
|
| 366 |
+
* object. There must be a valid OpenGL context bound to the current
|
| 367 |
+
* thread when this function is called, and the buffer name is
|
| 368 |
+
* resolved by that context.
|
| 369 |
+
*
|
| 370 |
+
* \param buffer - The name of the buffer object to register.
|
| 371 |
+
*
|
| 372 |
+
* \return
|
| 373 |
+
* ::CUDA_SUCCESS,
|
| 374 |
+
* ::CUDA_ERROR_DEINITIALIZED,
|
| 375 |
+
* ::CUDA_ERROR_NOT_INITIALIZED,
|
| 376 |
+
* ::CUDA_ERROR_INVALID_CONTEXT,
|
| 377 |
+
* ::CUDA_ERROR_ALREADY_MAPPED
|
| 378 |
+
* \notefnerr
|
| 379 |
+
*
|
| 380 |
+
* \sa ::cuGraphicsGLRegisterBuffer
|
| 381 |
+
*/
|
| 382 |
+
__CUDA_DEPRECATED CUresult CUDAAPI cuGLRegisterBufferObject(GLuint buffer);
|
| 383 |
+
|
| 384 |
+
/**
|
| 385 |
+
* \brief Maps an OpenGL buffer object
|
| 386 |
+
*
|
| 387 |
+
* \deprecated This function is deprecated as of Cuda 3.0.
|
| 388 |
+
*
|
| 389 |
+
* Maps the buffer object specified by \p buffer into the address space of the
|
| 390 |
+
* current CUDA context and returns in \p *dptr and \p *size the base pointer
|
| 391 |
+
* and size of the resulting mapping.
|
| 392 |
+
*
|
| 393 |
+
* There must be a valid OpenGL context bound to the current thread
|
| 394 |
+
* when this function is called. This must be the same context, or a
|
| 395 |
+
* member of the same shareGroup, as the context that was bound when
|
| 396 |
+
* the buffer was registered.
|
| 397 |
+
*
|
| 398 |
+
* All streams in the current CUDA context are synchronized with the
|
| 399 |
+
* current GL context.
|
| 400 |
+
*
|
| 401 |
+
* \param dptr - Returned mapped base pointer
|
| 402 |
+
* \param size - Returned size of mapping
|
| 403 |
+
* \param buffer - The name of the buffer object to map
|
| 404 |
+
*
|
| 405 |
+
* \return
|
| 406 |
+
* ::CUDA_SUCCESS,
|
| 407 |
+
* ::CUDA_ERROR_DEINITIALIZED,
|
| 408 |
+
* ::CUDA_ERROR_NOT_INITIALIZED,
|
| 409 |
+
* ::CUDA_ERROR_INVALID_CONTEXT,
|
| 410 |
+
* ::CUDA_ERROR_INVALID_VALUE,
|
| 411 |
+
* ::CUDA_ERROR_MAP_FAILED
|
| 412 |
+
* \notefnerr
|
| 413 |
+
*
|
| 414 |
+
* \sa ::cuGraphicsMapResources
|
| 415 |
+
*/
|
| 416 |
+
__CUDA_DEPRECATED CUresult CUDAAPI cuGLMapBufferObject(CUdeviceptr *dptr, size_t *size, GLuint buffer);
|
| 417 |
+
|
| 418 |
+
/**
|
| 419 |
+
* \brief Unmaps an OpenGL buffer object
|
| 420 |
+
*
|
| 421 |
+
* \deprecated This function is deprecated as of Cuda 3.0.
|
| 422 |
+
*
|
| 423 |
+
* Unmaps the buffer object specified by \p buffer for access by CUDA.
|
| 424 |
+
*
|
| 425 |
+
* There must be a valid OpenGL context bound to the current thread
|
| 426 |
+
* when this function is called. This must be the same context, or a
|
| 427 |
+
* member of the same shareGroup, as the context that was bound when
|
| 428 |
+
* the buffer was registered.
|
| 429 |
+
*
|
| 430 |
+
* All streams in the current CUDA context are synchronized with the
|
| 431 |
+
* current GL context.
|
| 432 |
+
*
|
| 433 |
+
* \param buffer - Buffer object to unmap
|
| 434 |
+
*
|
| 435 |
+
* \return
|
| 436 |
+
* ::CUDA_SUCCESS,
|
| 437 |
+
* ::CUDA_ERROR_DEINITIALIZED,
|
| 438 |
+
* ::CUDA_ERROR_NOT_INITIALIZED,
|
| 439 |
+
* ::CUDA_ERROR_INVALID_CONTEXT,
|
| 440 |
+
* ::CUDA_ERROR_INVALID_VALUE
|
| 441 |
+
* \notefnerr
|
| 442 |
+
*
|
| 443 |
+
* \sa ::cuGraphicsUnmapResources
|
| 444 |
+
*/
|
| 445 |
+
__CUDA_DEPRECATED CUresult CUDAAPI cuGLUnmapBufferObject(GLuint buffer);
|
| 446 |
+
|
| 447 |
+
/**
|
| 448 |
+
* \brief Unregister an OpenGL buffer object
|
| 449 |
+
*
|
| 450 |
+
* \deprecated This function is deprecated as of Cuda 3.0.
|
| 451 |
+
*
|
| 452 |
+
* Unregisters the buffer object specified by \p buffer. This
|
| 453 |
+
* releases any resources associated with the registered buffer.
|
| 454 |
+
* After this call, the buffer may no longer be mapped for access by
|
| 455 |
+
* CUDA.
|
| 456 |
+
*
|
| 457 |
+
* There must be a valid OpenGL context bound to the current thread
|
| 458 |
+
* when this function is called. This must be the same context, or a
|
| 459 |
+
* member of the same shareGroup, as the context that was bound when
|
| 460 |
+
* the buffer was registered.
|
| 461 |
+
*
|
| 462 |
+
* \param buffer - Name of the buffer object to unregister
|
| 463 |
+
*
|
| 464 |
+
* \return
|
| 465 |
+
* ::CUDA_SUCCESS,
|
| 466 |
+
* ::CUDA_ERROR_DEINITIALIZED,
|
| 467 |
+
* ::CUDA_ERROR_NOT_INITIALIZED,
|
| 468 |
+
* ::CUDA_ERROR_INVALID_CONTEXT,
|
| 469 |
+
* ::CUDA_ERROR_INVALID_VALUE
|
| 470 |
+
* \notefnerr
|
| 471 |
+
*
|
| 472 |
+
* \sa ::cuGraphicsUnregisterResource
|
| 473 |
+
*/
|
| 474 |
+
__CUDA_DEPRECATED CUresult CUDAAPI cuGLUnregisterBufferObject(GLuint buffer);
|
| 475 |
+
|
| 476 |
+
/**
|
| 477 |
+
* \brief Set the map flags for an OpenGL buffer object
|
| 478 |
+
*
|
| 479 |
+
* \deprecated This function is deprecated as of Cuda 3.0.
|
| 480 |
+
*
|
| 481 |
+
* Sets the map flags for the buffer object specified by \p buffer.
|
| 482 |
+
*
|
| 483 |
+
* Changes to \p Flags will take effect the next time \p buffer is mapped.
|
| 484 |
+
* The \p Flags argument may be any of the following:
|
| 485 |
+
* - ::CU_GL_MAP_RESOURCE_FLAGS_NONE: Specifies no hints about how this
|
| 486 |
+
* resource will be used. It is therefore assumed that this resource will be
|
| 487 |
+
* read from and written to by CUDA kernels. This is the default value.
|
| 488 |
+
* - ::CU_GL_MAP_RESOURCE_FLAGS_READ_ONLY: Specifies that CUDA kernels which
|
| 489 |
+
* access this resource will not write to this resource.
|
| 490 |
+
* - ::CU_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD: Specifies that CUDA kernels
|
| 491 |
+
* which access this resource will not read from this resource and will
|
| 492 |
+
* write over the entire contents of the resource, so none of the data
|
| 493 |
+
* previously stored in the resource will be preserved.
|
| 494 |
+
*
|
| 495 |
+
* If \p buffer has not been registered for use with CUDA, then
|
| 496 |
+
* ::CUDA_ERROR_INVALID_HANDLE is returned. If \p buffer is presently
|
| 497 |
+
* mapped for access by CUDA, then ::CUDA_ERROR_ALREADY_MAPPED is returned.
|
| 498 |
+
*
|
| 499 |
+
* There must be a valid OpenGL context bound to the current thread
|
| 500 |
+
* when this function is called. This must be the same context, or a
|
| 501 |
+
* member of the same shareGroup, as the context that was bound when
|
| 502 |
+
* the buffer was registered.
|
| 503 |
+
*
|
| 504 |
+
* \param buffer - Buffer object to unmap
|
| 505 |
+
* \param Flags - Map flags
|
| 506 |
+
*
|
| 507 |
+
* \return
|
| 508 |
+
* ::CUDA_SUCCESS,
|
| 509 |
+
* ::CUDA_ERROR_NOT_INITIALIZED,
|
| 510 |
+
* ::CUDA_ERROR_INVALID_HANDLE,
|
| 511 |
+
* ::CUDA_ERROR_ALREADY_MAPPED,
|
| 512 |
+
* ::CUDA_ERROR_INVALID_CONTEXT,
|
| 513 |
+
* \notefnerr
|
| 514 |
+
*
|
| 515 |
+
* \sa ::cuGraphicsResourceSetMapFlags
|
| 516 |
+
*/
|
| 517 |
+
__CUDA_DEPRECATED CUresult CUDAAPI cuGLSetBufferObjectMapFlags(GLuint buffer, unsigned int Flags);
|
| 518 |
+
|
| 519 |
+
/**
|
| 520 |
+
* \brief Maps an OpenGL buffer object
|
| 521 |
+
*
|
| 522 |
+
* \deprecated This function is deprecated as of Cuda 3.0.
|
| 523 |
+
*
|
| 524 |
+
* Maps the buffer object specified by \p buffer into the address space of the
|
| 525 |
+
* current CUDA context and returns in \p *dptr and \p *size the base pointer
|
| 526 |
+
* and size of the resulting mapping.
|
| 527 |
+
*
|
| 528 |
+
* There must be a valid OpenGL context bound to the current thread
|
| 529 |
+
* when this function is called. This must be the same context, or a
|
| 530 |
+
* member of the same shareGroup, as the context that was bound when
|
| 531 |
+
* the buffer was registered.
|
| 532 |
+
*
|
| 533 |
+
* Stream \p hStream in the current CUDA context is synchronized with
|
| 534 |
+
* the current GL context.
|
| 535 |
+
*
|
| 536 |
+
* \param dptr - Returned mapped base pointer
|
| 537 |
+
* \param size - Returned size of mapping
|
| 538 |
+
* \param buffer - The name of the buffer object to map
|
| 539 |
+
* \param hStream - Stream to synchronize
|
| 540 |
+
*
|
| 541 |
+
* \return
|
| 542 |
+
* ::CUDA_SUCCESS,
|
| 543 |
+
* ::CUDA_ERROR_DEINITIALIZED,
|
| 544 |
+
* ::CUDA_ERROR_NOT_INITIALIZED,
|
| 545 |
+
* ::CUDA_ERROR_INVALID_CONTEXT,
|
| 546 |
+
* ::CUDA_ERROR_INVALID_VALUE,
|
| 547 |
+
* ::CUDA_ERROR_MAP_FAILED
|
| 548 |
+
* \notefnerr
|
| 549 |
+
*
|
| 550 |
+
* \sa ::cuGraphicsMapResources
|
| 551 |
+
*/
|
| 552 |
+
__CUDA_DEPRECATED CUresult CUDAAPI cuGLMapBufferObjectAsync(CUdeviceptr *dptr, size_t *size, GLuint buffer, CUstream hStream);
|
| 553 |
+
|
| 554 |
+
/**
|
| 555 |
+
* \brief Unmaps an OpenGL buffer object
|
| 556 |
+
*
|
| 557 |
+
* \deprecated This function is deprecated as of Cuda 3.0.
|
| 558 |
+
*
|
| 559 |
+
* Unmaps the buffer object specified by \p buffer for access by CUDA.
|
| 560 |
+
*
|
| 561 |
+
* There must be a valid OpenGL context bound to the current thread
|
| 562 |
+
* when this function is called. This must be the same context, or a
|
| 563 |
+
* member of the same shareGroup, as the context that was bound when
|
| 564 |
+
* the buffer was registered.
|
| 565 |
+
*
|
| 566 |
+
* Stream \p hStream in the current CUDA context is synchronized with
|
| 567 |
+
* the current GL context.
|
| 568 |
+
*
|
| 569 |
+
* \param buffer - Name of the buffer object to unmap
|
| 570 |
+
* \param hStream - Stream to synchronize
|
| 571 |
+
*
|
| 572 |
+
* \return
|
| 573 |
+
* ::CUDA_SUCCESS,
|
| 574 |
+
* ::CUDA_ERROR_DEINITIALIZED,
|
| 575 |
+
* ::CUDA_ERROR_NOT_INITIALIZED,
|
| 576 |
+
* ::CUDA_ERROR_INVALID_CONTEXT,
|
| 577 |
+
* ::CUDA_ERROR_INVALID_VALUE
|
| 578 |
+
* \notefnerr
|
| 579 |
+
*
|
| 580 |
+
* \sa ::cuGraphicsUnmapResources
|
| 581 |
+
*/
|
| 582 |
+
__CUDA_DEPRECATED CUresult CUDAAPI cuGLUnmapBufferObjectAsync(GLuint buffer, CUstream hStream);
|
| 583 |
+
|
| 584 |
+
/** @} */ /* END CUDA_GL_DEPRECATED */
|
| 585 |
+
/** @} */ /* END CUDA_GL */
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
#if defined(__CUDA_API_VERSION_INTERNAL)
|
| 589 |
+
#undef cuGLCtxCreate
|
| 590 |
+
#undef cuGLMapBufferObject
|
| 591 |
+
#undef cuGLMapBufferObjectAsync
|
| 592 |
+
#undef cuGLGetDevices
|
| 593 |
+
|
| 594 |
+
CUresult CUDAAPI cuGLGetDevices(unsigned int *pCudaDeviceCount, CUdevice *pCudaDevices, unsigned int cudaDeviceCount, CUGLDeviceList deviceList);
|
| 595 |
+
CUresult CUDAAPI cuGLMapBufferObject_v2(CUdeviceptr *dptr, size_t *size, GLuint buffer);
|
| 596 |
+
CUresult CUDAAPI cuGLMapBufferObjectAsync_v2(CUdeviceptr *dptr, size_t *size, GLuint buffer, CUstream hStream);
|
| 597 |
+
CUresult CUDAAPI cuGLCtxCreate(CUcontext *pCtx, unsigned int Flags, CUdevice device );
|
| 598 |
+
CUresult CUDAAPI cuGLMapBufferObject(CUdeviceptr_v1 *dptr, unsigned int *size, GLuint buffer);
|
| 599 |
+
CUresult CUDAAPI cuGLMapBufferObjectAsync(CUdeviceptr_v1 *dptr, unsigned int *size, GLuint buffer, CUstream hStream);
|
| 600 |
+
#endif /* __CUDA_API_VERSION_INTERNAL */
|
| 601 |
+
|
| 602 |
+
#ifdef __cplusplus
|
| 603 |
+
};
|
| 604 |
+
#endif
|
| 605 |
+
|
| 606 |
+
#undef __CUDA_DEPRECATED
|
| 607 |
+
|
| 608 |
+
#endif
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/cupti_pcsampling_util.h
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(_CUPTI_PCSAMPLING_UTIL_H_)
|
| 2 |
+
#define _CUPTI_PCSAMPLING_UTIL_H_
|
| 3 |
+
|
| 4 |
+
#include <cupti_pcsampling.h>
|
| 5 |
+
#include <fstream>
|
| 6 |
+
|
| 7 |
+
#include <cupti_common.h>
|
| 8 |
+
|
| 9 |
+
#ifndef CUPTI_UTIL_STRUCT_SIZE
|
| 10 |
+
#define CUPTI_UTIL_STRUCT_SIZE(type_, lastfield_) (offsetof(type_, lastfield_) + sizeof(((type_*)0)->lastfield_))
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
#ifndef CHECK_PC_SAMPLING_STRUCT_FIELD_EXISTS
|
| 14 |
+
#define CHECK_PC_SAMPLING_STRUCT_FIELD_EXISTS(type, member, structSize) \
|
| 15 |
+
(offsetof(type, member) < structSize)
|
| 16 |
+
#endif
|
| 17 |
+
|
| 18 |
+
#if defined(__cplusplus)
|
| 19 |
+
extern "C" {
|
| 20 |
+
#endif
|
| 21 |
+
|
| 22 |
+
#if defined(__GNUC__)
|
| 23 |
+
#pragma GCC visibility push(default)
|
| 24 |
+
#endif
|
| 25 |
+
|
| 26 |
+
namespace CUPTI { namespace PcSamplingUtil {
|
| 27 |
+
|
| 28 |
+
/**
|
| 29 |
+
* \defgroup CUPTI_PCSAMPLING_UTILITY CUPTI PC Sampling Utility API
|
| 30 |
+
* Functions, types, and enums that implement the CUPTI PC Sampling Utility API.
|
| 31 |
+
* @{
|
| 32 |
+
*/
|
| 33 |
+
|
| 34 |
+
/**
|
| 35 |
+
* \brief Header info will be stored in file.
|
| 36 |
+
*/
|
| 37 |
+
typedef struct PACKED_ALIGNMENT {
|
| 38 |
+
/**
|
| 39 |
+
* Version of file format.
|
| 40 |
+
*/
|
| 41 |
+
uint32_t version;
|
| 42 |
+
/**
|
| 43 |
+
* Total number of buffers present in the file.
|
| 44 |
+
*/
|
| 45 |
+
uint32_t totalBuffers;
|
| 46 |
+
} Header;
|
| 47 |
+
|
| 48 |
+
/**
|
| 49 |
+
* \brief BufferInfo will be stored in the file for every buffer
|
| 50 |
+
* i.e for every call of UtilDumpPcSamplingBufferInFile() API.
|
| 51 |
+
*/
|
| 52 |
+
typedef struct PACKED_ALIGNMENT {
|
| 53 |
+
/**
|
| 54 |
+
* Total number of PC records.
|
| 55 |
+
*/
|
| 56 |
+
uint64_t recordCount;
|
| 57 |
+
/**
|
| 58 |
+
* Count of all stall reasons supported on the GPU
|
| 59 |
+
*/
|
| 60 |
+
size_t numStallReasons;
|
| 61 |
+
/**
|
| 62 |
+
* Total number of stall reasons in single record.
|
| 63 |
+
*/
|
| 64 |
+
uint64_t numSelectedStallReasons;
|
| 65 |
+
/**
|
| 66 |
+
* Buffer size in Bytes.
|
| 67 |
+
*/
|
| 68 |
+
uint64_t bufferByteSize;
|
| 69 |
+
} BufferInfo;
|
| 70 |
+
|
| 71 |
+
/**
|
| 72 |
+
* \brief All available stall reasons name and respective indexes
|
| 73 |
+
* will be stored in it.
|
| 74 |
+
*/
|
| 75 |
+
typedef struct PACKED_ALIGNMENT {
|
| 76 |
+
/**
|
| 77 |
+
* Number of all available stall reasons
|
| 78 |
+
*/
|
| 79 |
+
size_t numStallReasons;
|
| 80 |
+
/**
|
| 81 |
+
* Stall reasons names of all available stall reasons
|
| 82 |
+
*/
|
| 83 |
+
char **stallReasons;
|
| 84 |
+
/**
|
| 85 |
+
* Stall reason index of all available stall reasons
|
| 86 |
+
*/
|
| 87 |
+
uint32_t *stallReasonIndex;
|
| 88 |
+
} PcSamplingStallReasons;
|
| 89 |
+
|
| 90 |
+
/**
|
| 91 |
+
* \brief CUPTI PC sampling buffer types.
|
| 92 |
+
*
|
| 93 |
+
*/
|
| 94 |
+
typedef enum {
|
| 95 |
+
/**
|
| 96 |
+
* Invalid buffer type.
|
| 97 |
+
*/
|
| 98 |
+
PC_SAMPLING_BUFFER_INVALID = 0,
|
| 99 |
+
/**
|
| 100 |
+
* Refers to CUpti_PCSamplingData buffer.
|
| 101 |
+
*/
|
| 102 |
+
PC_SAMPLING_BUFFER_PC_TO_COUNTER_DATA = 1
|
| 103 |
+
} PcSamplingBufferType;
|
| 104 |
+
|
| 105 |
+
/**
|
| 106 |
+
* \brief CUPTI PC sampling utility API result codes.
|
| 107 |
+
*
|
| 108 |
+
* Error and result codes returned by CUPTI PC sampling utility API.
|
| 109 |
+
*/
|
| 110 |
+
typedef enum {
|
| 111 |
+
/**
|
| 112 |
+
* No error
|
| 113 |
+
*/
|
| 114 |
+
CUPTI_UTIL_SUCCESS = 0,
|
| 115 |
+
/**
|
| 116 |
+
* One or more of the parameters are invalid.
|
| 117 |
+
*/
|
| 118 |
+
CUPTI_UTIL_ERROR_INVALID_PARAMETER = 1,
|
| 119 |
+
/**
|
| 120 |
+
* Unable to create a new file
|
| 121 |
+
*/
|
| 122 |
+
CUPTI_UTIL_ERROR_UNABLE_TO_CREATE_FILE = 2,
|
| 123 |
+
/**
|
| 124 |
+
* Unable to open a file
|
| 125 |
+
*/
|
| 126 |
+
CUPTI_UTIL_ERROR_UNABLE_TO_OPEN_FILE = 3,
|
| 127 |
+
/**
|
| 128 |
+
* Read or write operation failed
|
| 129 |
+
*/
|
| 130 |
+
CUPTI_UTIL_ERROR_READ_WRITE_OPERATION_FAILED = 4,
|
| 131 |
+
/**
|
| 132 |
+
* Provided file handle is corrupted.
|
| 133 |
+
*/
|
| 134 |
+
CUPTI_UTIL_ERROR_FILE_HANDLE_CORRUPTED = 5,
|
| 135 |
+
/**
|
| 136 |
+
* seek operation failed.
|
| 137 |
+
*/
|
| 138 |
+
CUPTI_UTIL_ERROR_SEEK_OPERATION_FAILED = 6,
|
| 139 |
+
/**
|
| 140 |
+
* Unable to allocate enough memory to perform the requested
|
| 141 |
+
* operation.
|
| 142 |
+
*/
|
| 143 |
+
CUPTI_UTIL_ERROR_OUT_OF_MEMORY = 7,
|
| 144 |
+
/**
|
| 145 |
+
* An unknown internal error has occurred.
|
| 146 |
+
*/
|
| 147 |
+
CUPTI_UTIL_ERROR_UNKNOWN = 999,
|
| 148 |
+
CUPTI_UTIL_ERROR_FORCE_INT = 0x7fffffff
|
| 149 |
+
} CUptiUtilResult;
|
| 150 |
+
|
| 151 |
+
/**
|
| 152 |
+
* \brief Params for \ref CuptiUtilPutPcSampData
|
| 153 |
+
*/
|
| 154 |
+
typedef struct {
|
| 155 |
+
/**
|
| 156 |
+
* Size of the data structure i.e. CUpti_PCSamplingDisableParamsSize
|
| 157 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 158 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 159 |
+
*/
|
| 160 |
+
size_t size;
|
| 161 |
+
/**
|
| 162 |
+
* Type of buffer to store in file
|
| 163 |
+
*/
|
| 164 |
+
PcSamplingBufferType bufferType;
|
| 165 |
+
/**
|
| 166 |
+
* PC sampling buffer.
|
| 167 |
+
*/
|
| 168 |
+
void *pSamplingData;
|
| 169 |
+
/**
|
| 170 |
+
* Number of configured attributes
|
| 171 |
+
*/
|
| 172 |
+
size_t numAttributes;
|
| 173 |
+
/**
|
| 174 |
+
* Refer \ref CUpti_PCSamplingConfigurationInfo
|
| 175 |
+
* It is expected to provide configuration details of at least
|
| 176 |
+
* CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_STALL_REASON attribute.
|
| 177 |
+
*/
|
| 178 |
+
CUpti_PCSamplingConfigurationInfo *pPCSamplingConfigurationInfo;
|
| 179 |
+
/**
|
| 180 |
+
* Refer \ref PcSamplingStallReasons.
|
| 181 |
+
*/
|
| 182 |
+
PcSamplingStallReasons *pPcSamplingStallReasons;
|
| 183 |
+
/**
|
| 184 |
+
* File name to store buffer into it.
|
| 185 |
+
*/
|
| 186 |
+
const char* fileName;
|
| 187 |
+
} CUptiUtil_PutPcSampDataParams;
|
| 188 |
+
#define CUptiUtil_PutPcSampDataParamsSize CUPTI_UTIL_STRUCT_SIZE(CUptiUtil_PutPcSampDataParams, fileName)
|
| 189 |
+
|
| 190 |
+
/**
|
| 191 |
+
* \brief Dump PC sampling data into the file.
|
| 192 |
+
*
|
| 193 |
+
* This API can be called multiple times.
|
| 194 |
+
* It will append buffer in the file.
|
| 195 |
+
* For every buffer it will store BufferInfo
|
| 196 |
+
* so that before retrieving data it will help to allocate buffer
|
| 197 |
+
* to store retrieved data.
|
| 198 |
+
* This API creates file if file does not present.
|
| 199 |
+
* If stallReasonIndex or stallReasons pointer of \ref CUptiUtil_PutPcSampDataParams is NULL
|
| 200 |
+
* then stall reasons data will not be stored in file.
|
| 201 |
+
* It is expected to store all available stall reason data at least once to refer it during
|
| 202 |
+
* offline correlation.
|
| 203 |
+
*
|
| 204 |
+
* \retval CUPTI_UTIL_SUCCESS
|
| 205 |
+
* \retval CUPTI_UTIL_ERROR_INVALID_PARAMETER error out if buffer type is invalid
|
| 206 |
+
* or if either of pSamplingData, pParams pointer is NULL or stall reason configuration details not provided
|
| 207 |
+
* or filename is empty.
|
| 208 |
+
* \retval CUPTI_UTIL_ERROR_UNABLE_TO_CREATE_FILE
|
| 209 |
+
* \retval CUPTI_UTIL_ERROR_UNABLE_TO_OPEN_FILE
|
| 210 |
+
* \retval CUPTI_UTIL_ERROR_READ_WRITE_OPERATION_FAILED
|
| 211 |
+
*/
|
| 212 |
+
CUptiUtilResult CUPTIUTILAPI CuptiUtilPutPcSampData(CUptiUtil_PutPcSampDataParams *pParams);
|
| 213 |
+
|
| 214 |
+
/**
|
| 215 |
+
* \brief Params for \ref CuptiUtilGetHeaderData
|
| 216 |
+
*/
|
| 217 |
+
typedef struct {
|
| 218 |
+
/**
|
| 219 |
+
* Size of the data structure i.e. CUpti_PCSamplingDisableParamsSize
|
| 220 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 221 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 222 |
+
*/
|
| 223 |
+
size_t size;
|
| 224 |
+
/**
|
| 225 |
+
* File handle.
|
| 226 |
+
*/
|
| 227 |
+
std::ifstream *fileHandler;
|
| 228 |
+
/**
|
| 229 |
+
* Header Info.
|
| 230 |
+
*/
|
| 231 |
+
Header headerInfo;
|
| 232 |
+
|
| 233 |
+
} CUptiUtil_GetHeaderDataParams;
|
| 234 |
+
#define CUptiUtil_GetHeaderDataParamsSize CUPTI_UTIL_STRUCT_SIZE(CUptiUtil_GetHeaderDataParams, headerInfo)
|
| 235 |
+
|
| 236 |
+
/**
|
| 237 |
+
* \brief Get header data of file.
|
| 238 |
+
*
|
| 239 |
+
* This API must be called once initially while retrieving data from file.
|
| 240 |
+
* \ref Header structure, it gives info about total number
|
| 241 |
+
* of buffers present in the file.
|
| 242 |
+
*
|
| 243 |
+
* \retval CUPTI_UTIL_SUCCESS
|
| 244 |
+
* \retval CUPTI_UTIL_ERROR_INVALID_PARAMETER error out if either of pParam or fileHandle is NULL or param struct size is incorrect.
|
| 245 |
+
* \retval CUPTI_UTIL_ERROR_FILE_HANDLE_CORRUPTED file handle is not in good state to read data from file
|
| 246 |
+
* \retval CUPTI_UTIL_ERROR_READ_WRITE_OPERATION_FAILED failed to read data from file.
|
| 247 |
+
*/
|
| 248 |
+
CUptiUtilResult CUPTIUTILAPI CuptiUtilGetHeaderData(CUptiUtil_GetHeaderDataParams *pParams);
|
| 249 |
+
|
| 250 |
+
/**
|
| 251 |
+
* \brief Params for \ref CuptiUtilGetBufferInfo
|
| 252 |
+
*/
|
| 253 |
+
typedef struct {
|
| 254 |
+
/**
|
| 255 |
+
* Size of the data structure i.e. CUpti_PCSamplingDisableParamsSize
|
| 256 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 257 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 258 |
+
*/
|
| 259 |
+
size_t size;
|
| 260 |
+
/**
|
| 261 |
+
* File handle.
|
| 262 |
+
*/
|
| 263 |
+
std::ifstream *fileHandler;
|
| 264 |
+
/**
|
| 265 |
+
* Buffer Info.
|
| 266 |
+
*/
|
| 267 |
+
BufferInfo bufferInfoData;
|
| 268 |
+
} CUptiUtil_GetBufferInfoParams;
|
| 269 |
+
#define CUptiUtil_GetBufferInfoParamsSize CUPTI_UTIL_STRUCT_SIZE(CUptiUtil_GetBufferInfoParams, bufferInfoData)
|
| 270 |
+
|
| 271 |
+
/**
|
| 272 |
+
* \brief Get buffer info data of file.
|
| 273 |
+
*
|
| 274 |
+
* This API must be called every time before calling CuptiUtilGetPcSampData API.
|
| 275 |
+
* \ref BufferInfo structure, it gives info about recordCount and stallReasonCount
|
| 276 |
+
* of every record in the buffer. This will help to allocate exact buffer to retrieve data into it.
|
| 277 |
+
*
|
| 278 |
+
* \retval CUPTI_UTIL_SUCCESS
|
| 279 |
+
* \retval CUPTI_UTIL_ERROR_INVALID_PARAMETER error out if either of pParam or fileHandle is NULL or param struct size is incorrect.
|
| 280 |
+
* \retval CUPTI_UTIL_ERROR_FILE_HANDLE_CORRUPTED file handle is not in good state to read data from file.
|
| 281 |
+
* \retval CUPTI_UTIL_ERROR_READ_WRITE_OPERATION_FAILED failed to read data from file.
|
| 282 |
+
*/
|
| 283 |
+
CUptiUtilResult CUPTIUTILAPI CuptiUtilGetBufferInfo(CUptiUtil_GetBufferInfoParams *pParams);
|
| 284 |
+
|
| 285 |
+
/**
|
| 286 |
+
* \brief Params for \ref CuptiUtilGetPcSampData
|
| 287 |
+
*/
|
| 288 |
+
typedef struct {
|
| 289 |
+
/**
|
| 290 |
+
* Size of the data structure i.e. CUpti_PCSamplingDisableParamsSize
|
| 291 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 292 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 293 |
+
*/
|
| 294 |
+
size_t size;
|
| 295 |
+
/**
|
| 296 |
+
* File handle.
|
| 297 |
+
*/
|
| 298 |
+
std::ifstream *fileHandler;
|
| 299 |
+
/**
|
| 300 |
+
* Type of buffer to store in file
|
| 301 |
+
*/
|
| 302 |
+
PcSamplingBufferType bufferType;
|
| 303 |
+
/**
|
| 304 |
+
* Pointer to collected buffer info using \ref CuptiUtilGetBufferInfo
|
| 305 |
+
*/
|
| 306 |
+
BufferInfo *pBufferInfoData;
|
| 307 |
+
/**
|
| 308 |
+
* Pointer to allocated memory to store retrieved data from file.
|
| 309 |
+
*/
|
| 310 |
+
void *pSamplingData;
|
| 311 |
+
/**
|
| 312 |
+
* Number of configuration attributes
|
| 313 |
+
*/
|
| 314 |
+
size_t numAttributes;
|
| 315 |
+
/**
|
| 316 |
+
* Refer \ref CUpti_PCSamplingConfigurationInfo
|
| 317 |
+
*/
|
| 318 |
+
CUpti_PCSamplingConfigurationInfo *pPCSamplingConfigurationInfo;
|
| 319 |
+
/**
|
| 320 |
+
* Refer \ref PcSamplingStallReasons.
|
| 321 |
+
* For stallReasons field of \ref PcSamplingStallReasons it is expected to
|
| 322 |
+
* allocate memory for each string element of array.
|
| 323 |
+
*/
|
| 324 |
+
PcSamplingStallReasons *pPcSamplingStallReasons;
|
| 325 |
+
} CUptiUtil_GetPcSampDataParams;
|
| 326 |
+
#define CUptiUtil_GetPcSampDataParamsSize CUPTI_UTIL_STRUCT_SIZE(CUptiUtil_GetPcSampDataParams, pPcSamplingStallReasons)
|
| 327 |
+
|
| 328 |
+
/**
|
| 329 |
+
* \brief Retrieve PC sampling data from file into allocated buffer.
|
| 330 |
+
*
|
| 331 |
+
* This API must be called after CuptiUtilGetBufferInfo API.
|
| 332 |
+
* It will retrieve data from file into allocated buffer.
|
| 333 |
+
*
|
| 334 |
+
* \retval CUPTI_UTIL_SUCCESS
|
| 335 |
+
* \retval CUPTI_UTIL_ERROR_INVALID_PARAMETER error out if buffer type is invalid
|
| 336 |
+
* or if either of pSampData, pParams is NULL. If pPcSamplingStallReasons is not NULL then
|
| 337 |
+
* error out if either of stallReasonIndex, stallReasons or stallReasons array element pointer is NULL.
|
| 338 |
+
* or filename is empty.
|
| 339 |
+
* \retval CUPTI_UTIL_ERROR_READ_WRITE_OPERATION_FAILED
|
| 340 |
+
* \retval CUPTI_UTIL_ERROR_FILE_HANDLE_CORRUPTED file handle is not in good state to read data from file.
|
| 341 |
+
*/
|
| 342 |
+
CUptiUtilResult CUPTIUTILAPI CuptiUtilGetPcSampData(CUptiUtil_GetPcSampDataParams *pParams);
|
| 343 |
+
|
| 344 |
+
/**
|
| 345 |
+
* \brief Params for \ref CuptiUtilMergePcSampData
|
| 346 |
+
*/
|
| 347 |
+
typedef struct
|
| 348 |
+
{
|
| 349 |
+
/**
|
| 350 |
+
* Size of the data structure i.e. CUpti_PCSamplingDisableParamsSize
|
| 351 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 352 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 353 |
+
*/
|
| 354 |
+
size_t size;
|
| 355 |
+
/**
|
| 356 |
+
* Number of buffers to merge.
|
| 357 |
+
*/
|
| 358 |
+
size_t numberOfBuffers;
|
| 359 |
+
/**
|
| 360 |
+
* Pointer to array of buffers to merge
|
| 361 |
+
*/
|
| 362 |
+
CUpti_PCSamplingData *PcSampDataBuffer;
|
| 363 |
+
/**
|
| 364 |
+
* Pointer to array of merged buffers as per the range id.
|
| 365 |
+
*/
|
| 366 |
+
CUpti_PCSamplingData **MergedPcSampDataBuffers;
|
| 367 |
+
/**
|
| 368 |
+
* Number of merged buffers.
|
| 369 |
+
*/
|
| 370 |
+
size_t *numMergedBuffer;
|
| 371 |
+
} CUptiUtil_MergePcSampDataParams;
|
| 372 |
+
#define CUptiUtil_MergePcSampDataParamsSize CUPTI_UTIL_STRUCT_SIZE(CUptiUtil_MergePcSampDataParams, numMergedBuffer)
|
| 373 |
+
|
| 374 |
+
/**
|
| 375 |
+
* \brief Merge PC sampling data range id wise.
|
| 376 |
+
*
|
| 377 |
+
* This API merge PC sampling data range id wise.
|
| 378 |
+
* It allocates memory for merged data and fill data in it
|
| 379 |
+
* and provide buffer pointer in MergedPcSampDataBuffers field.
|
| 380 |
+
* It is expected from user to free merge data buffers after use.
|
| 381 |
+
*
|
| 382 |
+
* \retval CUPTI_UTIL_SUCCESS
|
| 383 |
+
* \retval CUPTI_UTIL_ERROR_INVALID_PARAMETER error out if param struct size is invalid
|
| 384 |
+
* or count of buffers to merge is invalid i.e less than 1
|
| 385 |
+
* or either of PcSampDataBuffer, MergedPcSampDataBuffers, numMergedBuffer is NULL
|
| 386 |
+
* \retval CUPTI_UTIL_ERROR_OUT_OF_MEMORY Unable to allocate memory for merged buffer.
|
| 387 |
+
*/
|
| 388 |
+
CUptiUtilResult CUPTIUTILAPI CuptiUtilMergePcSampData(CUptiUtil_MergePcSampDataParams *pParams);
|
| 389 |
+
|
| 390 |
+
/** @} */ /* END CUPTI_PCSAMPLING_UTILITY */
|
| 391 |
+
|
| 392 |
+
} }
|
| 393 |
+
|
| 394 |
+
#if defined(__GNUC__)
|
| 395 |
+
#pragma GCC visibility pop
|
| 396 |
+
#endif
|
| 397 |
+
|
| 398 |
+
#if defined(__cplusplus)
|
| 399 |
+
}
|
| 400 |
+
#endif
|
| 401 |
+
|
| 402 |
+
#endif
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/driver_types.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/generated_cudaVDPAU_meta.h
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// This file is generated. Any changes you make will be lost during the next clean build.
|
| 2 |
+
|
| 3 |
+
// Dependent includes
|
| 4 |
+
#include <vdpau/vdpau.h>
|
| 5 |
+
|
| 6 |
+
// CUDA public interface, for type definitions and cu* function prototypes
|
| 7 |
+
#include "cudaVDPAU.h"
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
// *************************************************************************
|
| 11 |
+
// Definitions of structs to hold parameters for each function
|
| 12 |
+
// *************************************************************************
|
| 13 |
+
|
| 14 |
+
typedef struct cuVDPAUGetDevice_params_st {
|
| 15 |
+
CUdevice *pDevice;
|
| 16 |
+
VdpDevice vdpDevice;
|
| 17 |
+
VdpGetProcAddress *vdpGetProcAddress;
|
| 18 |
+
} cuVDPAUGetDevice_params;
|
| 19 |
+
|
| 20 |
+
typedef struct cuVDPAUCtxCreate_v2_params_st {
|
| 21 |
+
CUcontext *pCtx;
|
| 22 |
+
unsigned int flags;
|
| 23 |
+
CUdevice device;
|
| 24 |
+
VdpDevice vdpDevice;
|
| 25 |
+
VdpGetProcAddress *vdpGetProcAddress;
|
| 26 |
+
} cuVDPAUCtxCreate_v2_params;
|
| 27 |
+
|
| 28 |
+
typedef struct cuGraphicsVDPAURegisterVideoSurface_params_st {
|
| 29 |
+
CUgraphicsResource *pCudaResource;
|
| 30 |
+
VdpVideoSurface vdpSurface;
|
| 31 |
+
unsigned int flags;
|
| 32 |
+
} cuGraphicsVDPAURegisterVideoSurface_params;
|
| 33 |
+
|
| 34 |
+
typedef struct cuGraphicsVDPAURegisterOutputSurface_params_st {
|
| 35 |
+
CUgraphicsResource *pCudaResource;
|
| 36 |
+
VdpOutputSurface vdpSurface;
|
| 37 |
+
unsigned int flags;
|
| 38 |
+
} cuGraphicsVDPAURegisterOutputSurface_params;
|
| 39 |
+
|
| 40 |
+
typedef struct cuVDPAUCtxCreate_params_st {
|
| 41 |
+
CUcontext *pCtx;
|
| 42 |
+
unsigned int flags;
|
| 43 |
+
CUdevice device;
|
| 44 |
+
VdpDevice vdpDevice;
|
| 45 |
+
VdpGetProcAddress *vdpGetProcAddress;
|
| 46 |
+
} cuVDPAUCtxCreate_params;
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/nvperf_target.h
ADDED
|
@@ -0,0 +1,626 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifndef NVPERF_TARGET_H
|
| 2 |
+
#define NVPERF_TARGET_H
|
| 3 |
+
|
| 4 |
+
/*
|
| 5 |
+
* Copyright 2014-2024 NVIDIA Corporation. All rights reserved.
|
| 6 |
+
*
|
| 7 |
+
* NOTICE TO USER:
|
| 8 |
+
*
|
| 9 |
+
* This source code is subject to NVIDIA ownership rights under U.S. and
|
| 10 |
+
* international Copyright laws.
|
| 11 |
+
*
|
| 12 |
+
* This software and the information contained herein is PROPRIETARY and
|
| 13 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and conditions
|
| 14 |
+
* of a form of NVIDIA software license agreement.
|
| 15 |
+
*
|
| 16 |
+
* NVIDIA MAKES NO REPRESENTATION ABOUT THE SUITABILITY OF THIS SOURCE
|
| 17 |
+
* CODE FOR ANY PURPOSE. IT IS PROVIDED "AS IS" WITHOUT EXPRESS OR
|
| 18 |
+
* IMPLIED WARRANTY OF ANY KIND. NVIDIA DISCLAIMS ALL WARRANTIES WITH
|
| 19 |
+
* REGARD TO THIS SOURCE CODE, INCLUDING ALL IMPLIED WARRANTIES OF
|
| 20 |
+
* MERCHANTABILITY, NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 21 |
+
* IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL,
|
| 22 |
+
* OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
|
| 23 |
+
* OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
|
| 24 |
+
* OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE
|
| 25 |
+
* OR PERFORMANCE OF THIS SOURCE CODE.
|
| 26 |
+
*
|
| 27 |
+
* U.S. Government End Users. This source code is a "commercial item" as
|
| 28 |
+
* that term is defined at 48 C.F.R. 2.101 (OCT 1995), consisting of
|
| 29 |
+
* "commercial computer software" and "commercial computer software
|
| 30 |
+
* documentation" as such terms are used in 48 C.F.R. 12.212 (SEPT 1995)
|
| 31 |
+
* and is provided to the U.S. Government only as a commercial end item.
|
| 32 |
+
* Consistent with 48 C.F.R.12.212 and 48 C.F.R. 227.7202-1 through
|
| 33 |
+
* 227.7202-4 (JUNE 1995), all U.S. Government End Users acquire the
|
| 34 |
+
* source code with only those rights set forth herein.
|
| 35 |
+
*
|
| 36 |
+
* Any use of this source code in individual and commercial software must
|
| 37 |
+
* include, in the user documentation and internal comments to the code,
|
| 38 |
+
* the above Disclaimer and U.S. Government End Users Notice.
|
| 39 |
+
*/
|
| 40 |
+
|
| 41 |
+
#include <stddef.h>
|
| 42 |
+
#include <stdint.h>
|
| 43 |
+
#include "nvperf_common.h"
|
| 44 |
+
|
| 45 |
+
#if defined(__GNUC__) && defined(NVPA_SHARED_LIB)
|
| 46 |
+
#pragma GCC visibility push(default)
|
| 47 |
+
#if !defined(NVPW_LOCAL)
|
| 48 |
+
#define NVPW_LOCAL __attribute__ ((visibility ("hidden")))
|
| 49 |
+
#endif
|
| 50 |
+
#else
|
| 51 |
+
#if !defined(NVPW_LOCAL)
|
| 52 |
+
#define NVPW_LOCAL
|
| 53 |
+
#endif
|
| 54 |
+
#endif
|
| 55 |
+
|
| 56 |
+
#ifdef __cplusplus
|
| 57 |
+
extern "C" {
|
| 58 |
+
#endif
|
| 59 |
+
|
| 60 |
+
/**
|
| 61 |
+
* @file nvperf_target.h
|
| 62 |
+
*/
|
| 63 |
+
|
| 64 |
+
#ifndef NVPW_GPU_ARCHITECTURE_SUPPORT_LEVEL_DEFINED
|
| 65 |
+
#define NVPW_GPU_ARCHITECTURE_SUPPORT_LEVEL_DEFINED
|
| 66 |
+
/// GPU architecture support level
|
| 67 |
+
typedef enum NVPW_GpuArchitectureSupportLevel
|
| 68 |
+
{
|
| 69 |
+
NVPW_GPU_ARCHITECTURE_SUPPORT_LEVEL_UNKNOWN = 0,
|
| 70 |
+
NVPW_GPU_ARCHITECTURE_SUPPORT_LEVEL_UNSUPPORTED,
|
| 71 |
+
NVPW_GPU_ARCHITECTURE_SUPPORT_LEVEL_SUPPORTED
|
| 72 |
+
} NVPW_GpuArchitectureSupportLevel;
|
| 73 |
+
#endif //NVPW_GPU_ARCHITECTURE_SUPPORT_LEVEL_DEFINED
|
| 74 |
+
|
| 75 |
+
#ifndef NVPW_SLI_SUPPORT_LEVEL_DEFINED
|
| 76 |
+
#define NVPW_SLI_SUPPORT_LEVEL_DEFINED
|
| 77 |
+
/// SLI configuration support level
|
| 78 |
+
typedef enum NVPW_SliSupportLevel
|
| 79 |
+
{
|
| 80 |
+
NVPW_SLI_SUPPORT_LEVEL_UNKNOWN = 0,
|
| 81 |
+
NVPW_SLI_SUPPORT_LEVEL_UNSUPPORTED,
|
| 82 |
+
/// Only Non-SLI configurations are supported.
|
| 83 |
+
NVPW_SLI_SUPPORT_LEVEL_SUPPORTED_NON_SLI_CONFIGURATION
|
| 84 |
+
} NVPW_SliSupportLevel;
|
| 85 |
+
#endif //NVPW_SLI_SUPPORT_LEVEL_DEFINED
|
| 86 |
+
|
| 87 |
+
#ifndef NVPW_VGPU_SUPPORT_LEVEL_DEFINED
|
| 88 |
+
#define NVPW_VGPU_SUPPORT_LEVEL_DEFINED
|
| 89 |
+
/// Virtualized GPU configuration support level
|
| 90 |
+
typedef enum NVPW_VGpuSupportLevel
|
| 91 |
+
{
|
| 92 |
+
NVPW_VGPU_SUPPORT_LEVEL_UNKNOWN = 0,
|
| 93 |
+
NVPW_VGPU_SUPPORT_LEVEL_UNSUPPORTED,
|
| 94 |
+
/// Supported but not allowed by system admin.
|
| 95 |
+
NVPW_VGPU_SUPPORT_LEVEL_SUPPORTED_DISALLOWED,
|
| 96 |
+
NVPW_VGPU_SUPPORT_LEVEL_SUPPORTED_ALLOWED,
|
| 97 |
+
NVPW_VGPU_SUPPORT_LEVEL_SUPPORTED_NON_VGPU_CONFIGURATION
|
| 98 |
+
} NVPW_VGpuSupportLevel;
|
| 99 |
+
#endif //NVPW_VGPU_SUPPORT_LEVEL_DEFINED
|
| 100 |
+
|
| 101 |
+
#ifndef NVPW_CONF_COMPUTE_SUPPORT_LEVEL_DEFINED
|
| 102 |
+
#define NVPW_CONF_COMPUTE_SUPPORT_LEVEL_DEFINED
|
| 103 |
+
/// Confidential Compute mode support level
|
| 104 |
+
typedef enum NVPW_ConfidentialComputeSupportLevel
|
| 105 |
+
{
|
| 106 |
+
NVPW_CONF_COMPUTE_SUPPORT_LEVEL_UNKNOWN = 0,
|
| 107 |
+
NVPW_CONF_COMPUTE_SUPPORT_LEVEL_UNSUPPORTED,
|
| 108 |
+
NVPW_CONF_COMPUTE_SUPPORT_LEVEL_SUPPORTED_NON_CONF_COMPUTE_CONFIGURATION,
|
| 109 |
+
NVPW_CONF_COMPUTE_SUPPORT_LEVEL_SUPPORTED_CONF_COMPUTE_DEVTOOLS_MODE
|
| 110 |
+
} NVPW_ConfidentialComputeSupportLevel;
|
| 111 |
+
#endif //NVPW_CONF_COMPUTE_SUPPORT_LEVEL_DEFINED
|
| 112 |
+
|
| 113 |
+
#ifndef NVPW_CMP_SUPPORT_LEVEL_DEFINED
|
| 114 |
+
#define NVPW_CMP_SUPPORT_LEVEL_DEFINED
|
| 115 |
+
/// CMP support level
|
| 116 |
+
typedef enum NVPW_CmpSupportLevel
|
| 117 |
+
{
|
| 118 |
+
NVPW_CMP_SUPPORT_LEVEL_UNKNOWN = 0,
|
| 119 |
+
NVPW_CMP_SUPPORT_LEVEL_UNSUPPORTED,
|
| 120 |
+
NVPW_CMP_SUPPORT_LEVEL_SUPPORTED_NON_CMP_CONFIGURATON
|
| 121 |
+
} NVPW_CmpSupportLevel;
|
| 122 |
+
#endif //NVPW_CMP_SUPPORT_LEVEL_DEFINED
|
| 123 |
+
|
| 124 |
+
#ifndef NVPW_WSL_SUPPORT_LEVEL_DEFINED
|
| 125 |
+
#define NVPW_WSL_SUPPORT_LEVEL_DEFINED
|
| 126 |
+
/// WSL support level
|
| 127 |
+
typedef enum NVPW_WslSupportLevel
|
| 128 |
+
{
|
| 129 |
+
NVPW_WSL_SUPPORT_LEVEL_UNKNOWN = 0,
|
| 130 |
+
NVPW_WSL_SUPPORT_LEVEL_UNSUPPORTED_INSUFFICIENT_DRIVER_VERSION,
|
| 131 |
+
NVPW_WSL_SUPPORT_LEVEL_SUPPORTED,
|
| 132 |
+
NVPW_WSL_SUPPORT_LEVEL_SUPPORTED_NON_WSL_CONFIGURATION
|
| 133 |
+
} NVPW_WslSupportLevel;
|
| 134 |
+
#endif //NVPW_WSL_SUPPORT_LEVEL_DEFINED
|
| 135 |
+
|
| 136 |
+
#ifndef NVPW_MIG_SUPPORT_LEVEL_DEFINED
|
| 137 |
+
#define NVPW_MIG_SUPPORT_LEVEL_DEFINED
|
| 138 |
+
/// MIG support level
|
| 139 |
+
typedef enum NVPW_MigSupportLevel
|
| 140 |
+
{
|
| 141 |
+
NVPW_MIG_SUPPORT_LEVEL_UNKNOWN = 0,
|
| 142 |
+
NVPW_MIG_SUPPORT_LEVEL_UNSUPPORTED,
|
| 143 |
+
NVPW_MIG_SUPPORT_LEVEL_SUPPORTED,
|
| 144 |
+
NVPW_MIG_SUPPORT_LEVEL_SUPPORTED_NON_MIG_CONFIGURATION
|
| 145 |
+
} NVPW_MigSupportLevel;
|
| 146 |
+
#endif //NVPW_MIG_SUPPORT_LEVEL_DEFINED
|
| 147 |
+
|
| 148 |
+
typedef struct NVPW_InitializeTarget_Params
|
| 149 |
+
{
|
| 150 |
+
/// [in]
|
| 151 |
+
size_t structSize;
|
| 152 |
+
/// [in] assign to NULL
|
| 153 |
+
void* pPriv;
|
| 154 |
+
} NVPW_InitializeTarget_Params;
|
| 155 |
+
#define NVPW_InitializeTarget_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_InitializeTarget_Params, pPriv)
|
| 156 |
+
|
| 157 |
+
/// Load the target library.
|
| 158 |
+
NVPA_Status NVPW_InitializeTarget(NVPW_InitializeTarget_Params* pParams);
|
| 159 |
+
|
| 160 |
+
typedef struct NVPW_GetDeviceCount_Params
|
| 161 |
+
{
|
| 162 |
+
/// [in]
|
| 163 |
+
size_t structSize;
|
| 164 |
+
/// [in] assign to NULL
|
| 165 |
+
void* pPriv;
|
| 166 |
+
size_t numDevices;
|
| 167 |
+
} NVPW_GetDeviceCount_Params;
|
| 168 |
+
#define NVPW_GetDeviceCount_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_GetDeviceCount_Params, numDevices)
|
| 169 |
+
|
| 170 |
+
NVPA_Status NVPW_GetDeviceCount(NVPW_GetDeviceCount_Params* pParams);
|
| 171 |
+
|
| 172 |
+
typedef struct NVPW_Device_GetNames_Params
|
| 173 |
+
{
|
| 174 |
+
/// [in]
|
| 175 |
+
size_t structSize;
|
| 176 |
+
/// [in] assign to NULL
|
| 177 |
+
void* pPriv;
|
| 178 |
+
size_t deviceIndex;
|
| 179 |
+
const char* pDeviceName;
|
| 180 |
+
const char* pChipName;
|
| 181 |
+
} NVPW_Device_GetNames_Params;
|
| 182 |
+
#define NVPW_Device_GetNames_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Device_GetNames_Params, pChipName)
|
| 183 |
+
|
| 184 |
+
NVPA_Status NVPW_Device_GetNames(NVPW_Device_GetNames_Params* pParams);
|
| 185 |
+
|
| 186 |
+
typedef struct NVPW_PciBusId
|
| 187 |
+
{
|
| 188 |
+
/// The PCI domain on which the device bus resides.
|
| 189 |
+
uint32_t domain;
|
| 190 |
+
/// The bus on which the device resides.
|
| 191 |
+
uint16_t bus;
|
| 192 |
+
/// device ID.
|
| 193 |
+
uint16_t device;
|
| 194 |
+
} NVPW_PciBusId;
|
| 195 |
+
#define NVPW_PciBusId_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_PciBusId, device)
|
| 196 |
+
|
| 197 |
+
typedef struct NVPW_Device_GetPciBusIds_Params
|
| 198 |
+
{
|
| 199 |
+
/// [in]
|
| 200 |
+
size_t structSize;
|
| 201 |
+
/// [in] assign to NULL
|
| 202 |
+
void* pPriv;
|
| 203 |
+
/// [in] caller-allocated array of NVPW_PciBusId, indexed by NVPW deviceIndex
|
| 204 |
+
NVPW_PciBusId* pBusIds;
|
| 205 |
+
/// [in] size of the pBusIDs array; use result from NVPW_GetDeviceCount
|
| 206 |
+
size_t numDevices;
|
| 207 |
+
} NVPW_Device_GetPciBusIds_Params;
|
| 208 |
+
#define NVPW_Device_GetPciBusIds_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Device_GetPciBusIds_Params, numDevices)
|
| 209 |
+
|
| 210 |
+
NVPA_Status NVPW_Device_GetPciBusIds(NVPW_Device_GetPciBusIds_Params* pParams);
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
#define NVPW_DEVICE_MIG_GPU_INSTANCE_ID_INVALID 0xFFFFFFFFu
|
| 214 |
+
#define NVPW_DEVICE_MIG_GPU_INSTANCE_ID_FULLCHIP 0xFFFFFFFEu
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
typedef struct NVPW_Device_GetMigAttributes_Params
|
| 218 |
+
{
|
| 219 |
+
/// [in]
|
| 220 |
+
size_t structSize;
|
| 221 |
+
/// [in] assign to NULL
|
| 222 |
+
void* pPriv;
|
| 223 |
+
/// [in]
|
| 224 |
+
size_t deviceIndex;
|
| 225 |
+
/// [out]
|
| 226 |
+
NVPA_Bool isMigPartition;
|
| 227 |
+
/// [out]
|
| 228 |
+
uint32_t gpuInstanceId;
|
| 229 |
+
/// [out]
|
| 230 |
+
uint32_t computeInstanceId;
|
| 231 |
+
} NVPW_Device_GetMigAttributes_Params;
|
| 232 |
+
#define NVPW_Device_GetMigAttributes_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Device_GetMigAttributes_Params, computeInstanceId)
|
| 233 |
+
|
| 234 |
+
NVPA_Status NVPW_Device_GetMigAttributes(NVPW_Device_GetMigAttributes_Params* pParams);
|
| 235 |
+
|
| 236 |
+
typedef struct NVPW_Adapter_GetDeviceIndex_Params
|
| 237 |
+
{
|
| 238 |
+
/// [in]
|
| 239 |
+
size_t structSize;
|
| 240 |
+
/// [in] assign to NULL
|
| 241 |
+
void* pPriv;
|
| 242 |
+
/// [in]
|
| 243 |
+
struct IDXGIAdapter* pAdapter;
|
| 244 |
+
/// [in]
|
| 245 |
+
size_t sliIndex;
|
| 246 |
+
/// [out]
|
| 247 |
+
size_t deviceIndex;
|
| 248 |
+
} NVPW_Adapter_GetDeviceIndex_Params;
|
| 249 |
+
#define NVPW_Adapter_GetDeviceIndex_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Adapter_GetDeviceIndex_Params, deviceIndex)
|
| 250 |
+
|
| 251 |
+
NVPA_Status NVPW_Adapter_GetDeviceIndex(NVPW_Adapter_GetDeviceIndex_Params* pParams);
|
| 252 |
+
|
| 253 |
+
typedef struct NVPW_CounterData_GetNumRanges_Params
|
| 254 |
+
{
|
| 255 |
+
/// [in]
|
| 256 |
+
size_t structSize;
|
| 257 |
+
/// [in] assign to NULL
|
| 258 |
+
void* pPriv;
|
| 259 |
+
const uint8_t* pCounterDataImage;
|
| 260 |
+
size_t numRanges;
|
| 261 |
+
} NVPW_CounterData_GetNumRanges_Params;
|
| 262 |
+
#define NVPW_CounterData_GetNumRanges_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_CounterData_GetNumRanges_Params, numRanges)
|
| 263 |
+
|
| 264 |
+
NVPA_Status NVPW_CounterData_GetNumRanges(NVPW_CounterData_GetNumRanges_Params* pParams);
|
| 265 |
+
|
| 266 |
+
typedef struct NVPW_CounterData_GetChipName_Params
|
| 267 |
+
{
|
| 268 |
+
/// [in]
|
| 269 |
+
size_t structSize;
|
| 270 |
+
/// [in] assign to NULL
|
| 271 |
+
void* pPriv;
|
| 272 |
+
/// [in]
|
| 273 |
+
const uint8_t* pCounterDataImage;
|
| 274 |
+
/// [in]
|
| 275 |
+
size_t counterDataImageSize;
|
| 276 |
+
/// [out]
|
| 277 |
+
const char* pChipName;
|
| 278 |
+
} NVPW_CounterData_GetChipName_Params;
|
| 279 |
+
#define NVPW_CounterData_GetChipName_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_CounterData_GetChipName_Params, pChipName)
|
| 280 |
+
|
| 281 |
+
NVPA_Status NVPW_CounterData_GetChipName(NVPW_CounterData_GetChipName_Params* pParams);
|
| 282 |
+
|
| 283 |
+
typedef struct NVPW_Config_GetNumPasses_Params
|
| 284 |
+
{
|
| 285 |
+
/// [in]
|
| 286 |
+
size_t structSize;
|
| 287 |
+
/// [in] assign to NULL
|
| 288 |
+
void* pPriv;
|
| 289 |
+
/// [in]
|
| 290 |
+
const uint8_t* pConfig;
|
| 291 |
+
/// [out]
|
| 292 |
+
size_t numPipelinedPasses;
|
| 293 |
+
/// [out]
|
| 294 |
+
size_t numIsolatedPasses;
|
| 295 |
+
} NVPW_Config_GetNumPasses_Params;
|
| 296 |
+
#define NVPW_Config_GetNumPasses_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Config_GetNumPasses_Params, numIsolatedPasses)
|
| 297 |
+
|
| 298 |
+
/// Total num passes = numPipelinedPasses + numIsolatedPasses * numNestingLevels
|
| 299 |
+
NVPA_Status NVPW_Config_GetNumPasses(NVPW_Config_GetNumPasses_Params* pParams);
|
| 300 |
+
|
| 301 |
+
typedef struct NVPW_Config_GetNumPasses_V2_Params
|
| 302 |
+
{
|
| 303 |
+
/// [in]
|
| 304 |
+
size_t structSize;
|
| 305 |
+
/// [in] assign to NULL
|
| 306 |
+
void* pPriv;
|
| 307 |
+
/// [in]
|
| 308 |
+
const uint8_t* pConfig;
|
| 309 |
+
/// [out]
|
| 310 |
+
size_t numPasses;
|
| 311 |
+
} NVPW_Config_GetNumPasses_V2_Params;
|
| 312 |
+
#define NVPW_Config_GetNumPasses_V2_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Config_GetNumPasses_V2_Params, numPasses)
|
| 313 |
+
|
| 314 |
+
/// Total num passes = numPasses * numNestingLevels
|
| 315 |
+
NVPA_Status NVPW_Config_GetNumPasses_V2(NVPW_Config_GetNumPasses_V2_Params* pParams);
|
| 316 |
+
|
| 317 |
+
#define NVPW_API_SET_CUDA_PROFILER 0x18209d0775b2f89dULL
|
| 318 |
+
|
| 319 |
+
#define NVPW_API_SET_D3D11_PROFILER 0xca55c6738445db2bULL
|
| 320 |
+
|
| 321 |
+
#define NVPW_API_SET_D3D12_PROFILER 0xc0c2d46dd7c7ad78ULL
|
| 322 |
+
|
| 323 |
+
#define NVPW_API_SET_EGL_PROFILER 0x3c3747dae1f9565cULL
|
| 324 |
+
|
| 325 |
+
#define NVPW_API_SET_GPU_PERIODICSAMPLER 0x9f4c2571fc0b2e8aULL
|
| 326 |
+
|
| 327 |
+
#define NVPW_API_SET_METRICSEVALUATOR 0x0368a8768d811af9ULL
|
| 328 |
+
|
| 329 |
+
#define NVPW_API_SET_METRICS_AD10X_COMP 0xbe57278e12cb5288ULL
|
| 330 |
+
|
| 331 |
+
#define NVPW_API_SET_METRICS_AD10X_GRFX 0x5cbf0774f81bf491ULL
|
| 332 |
+
|
| 333 |
+
#define NVPW_API_SET_METRICS_GA100_COMP 0x16b7d8c20d8b4915ULL
|
| 334 |
+
|
| 335 |
+
#define NVPW_API_SET_METRICS_GA100_GRFX 0xc94eaabec04a94faULL
|
| 336 |
+
|
| 337 |
+
#define NVPW_API_SET_METRICS_GA10X_COMP 0xb5d6391c2e299ab5ULL
|
| 338 |
+
|
| 339 |
+
#define NVPW_API_SET_METRICS_GA10X_GRFX 0x6ebc121178b5ce0bULL
|
| 340 |
+
|
| 341 |
+
#define NVPW_API_SET_METRICS_GV100_COMP 0x863705cc57919f72ULL
|
| 342 |
+
|
| 343 |
+
#define NVPW_API_SET_METRICS_GV100_GRFX 0x9900da75d164fecfULL
|
| 344 |
+
|
| 345 |
+
#define NVPW_API_SET_METRICS_GV11B_COMP 0xd3f79a859235848fULL
|
| 346 |
+
|
| 347 |
+
#define NVPW_API_SET_METRICS_GV11B_GRFX 0xeb8e26220106e227ULL
|
| 348 |
+
|
| 349 |
+
#define NVPW_API_SET_METRICS_TU10X_COMP 0x70f40be0afd35da8ULL
|
| 350 |
+
|
| 351 |
+
#define NVPW_API_SET_METRICS_TU10X_GRFX 0xdf219cb838db6968ULL
|
| 352 |
+
|
| 353 |
+
#define NVPW_API_SET_METRICS_TU11X_COMP 0xeb0069d7d0956678ULL
|
| 354 |
+
|
| 355 |
+
#define NVPW_API_SET_METRICS_TU11X_GRFX 0x0977d9342bd62743ULL
|
| 356 |
+
|
| 357 |
+
#define NVPW_API_SET_OPENGL_PROFILER 0xe4cd9ea40f2ee777ULL
|
| 358 |
+
|
| 359 |
+
#define NVPW_API_SET_VULKAN_PROFILER 0x8c56b6a03d779689ULL
|
| 360 |
+
|
| 361 |
+
#define NVPW_SDK_VERSION 0x1e128b6f001423fcULL
|
| 362 |
+
|
| 363 |
+
typedef struct NVPW_QueryVersionNumber_Params
|
| 364 |
+
{
|
| 365 |
+
/// [in]
|
| 366 |
+
size_t structSize;
|
| 367 |
+
/// [in] assign to NULL
|
| 368 |
+
void* pPriv;
|
| 369 |
+
/// [in]
|
| 370 |
+
uint64_t apiSet;
|
| 371 |
+
/// [out]
|
| 372 |
+
uint32_t major;
|
| 373 |
+
/// [out]
|
| 374 |
+
uint32_t minor;
|
| 375 |
+
/// [out]
|
| 376 |
+
uint32_t patch;
|
| 377 |
+
/// [out]
|
| 378 |
+
uint32_t relMajor;
|
| 379 |
+
/// [out]
|
| 380 |
+
uint32_t relMinor;
|
| 381 |
+
/// [out]
|
| 382 |
+
uint32_t relPatch;
|
| 383 |
+
} NVPW_QueryVersionNumber_Params;
|
| 384 |
+
#define NVPW_QueryVersionNumber_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_QueryVersionNumber_Params, relPatch)
|
| 385 |
+
|
| 386 |
+
/// Query version number of an API set
|
| 387 |
+
NVPA_Status NVPW_QueryVersionNumber(NVPW_QueryVersionNumber_Params* pParams);
|
| 388 |
+
|
| 389 |
+
typedef enum NVPW_Device_ClockStatus
|
| 390 |
+
{
|
| 391 |
+
/// clock status is unknown
|
| 392 |
+
NVPW_DEVICE_CLOCK_STATUS_UNKNOWN,
|
| 393 |
+
/// clocks are locked to rated tdp values - Deprecated, use NVPW_DEVICE_CLOCK_STATUS_LOCKED instead
|
| 394 |
+
NVPW_DEVICE_CLOCK_STATUS_LOCKED_TO_RATED_TDP,
|
| 395 |
+
/// clocks are not locked and can boost above rated tdp
|
| 396 |
+
NVPW_DEVICE_CLOCK_STATUS_BOOST_ENABLED,
|
| 397 |
+
/// clocks are not locked and will not go above rated tdp
|
| 398 |
+
NVPW_DEVICE_CLOCK_STATUS_BOOST_DISABLED,
|
| 399 |
+
/// clocks are locked
|
| 400 |
+
NVPW_DEVICE_CLOCK_STATUS_LOCKED,
|
| 401 |
+
/// clocks are not locked
|
| 402 |
+
NVPW_DEVICE_CLOCK_STATUS_UNLOCKED,
|
| 403 |
+
NVPW_DEVICE_CLOCK_STATUS__COUNT
|
| 404 |
+
} NVPW_Device_ClockStatus;
|
| 405 |
+
|
| 406 |
+
typedef enum NVPW_Device_ClockLevel
|
| 407 |
+
{
|
| 408 |
+
/// clock level is invalid
|
| 409 |
+
NVPW_DEVICE_CLOCK_LEVEL_INVALID,
|
| 410 |
+
/// clock level is at rated tdp
|
| 411 |
+
NVPW_DEVICE_CLOCK_LEVEL_RATED_TDP,
|
| 412 |
+
/// clock level is at turbo boost
|
| 413 |
+
NVPW_DEVICE_CLOCK_LEVEL_TURBO_BOOST,
|
| 414 |
+
NVPW_DEVICE_CLOCK_LEVEL__COUNT
|
| 415 |
+
} NVPW_Device_ClockLevel;
|
| 416 |
+
|
| 417 |
+
typedef struct NVPW_Device_GetClockStatus_Params
|
| 418 |
+
{
|
| 419 |
+
/// [in]
|
| 420 |
+
size_t structSize;
|
| 421 |
+
/// [in] assign to NULL
|
| 422 |
+
void* pPriv;
|
| 423 |
+
size_t deviceIndex;
|
| 424 |
+
/// [in]
|
| 425 |
+
NVPW_Device_ClockStatus clockStatus;
|
| 426 |
+
/// [in]
|
| 427 |
+
NVPW_Device_ClockLevel clockLevel;
|
| 428 |
+
} NVPW_Device_GetClockStatus_Params;
|
| 429 |
+
#define NVPW_Device_GetClockStatus_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Device_GetClockStatus_Params, clockLevel)
|
| 430 |
+
|
| 431 |
+
NVPA_Status NVPW_Device_GetClockStatus(NVPW_Device_GetClockStatus_Params* pParams);
|
| 432 |
+
|
| 433 |
+
typedef enum NVPW_Device_ClockSetting
|
| 434 |
+
{
|
| 435 |
+
/// invalid op, specify valid clocks operation during profiling
|
| 436 |
+
NVPW_DEVICE_CLOCK_SETTING_INVALID,
|
| 437 |
+
/// default to driver/application config (normally unlocked and not boosted, but could be unlocked boosted, or
|
| 438 |
+
/// locked to rated TDP)
|
| 439 |
+
NVPW_DEVICE_CLOCK_SETTING_DEFAULT,
|
| 440 |
+
/// lock clocks at rated tdp base values
|
| 441 |
+
NVPW_DEVICE_CLOCK_SETTING_LOCK_TO_RATED_TDP,
|
| 442 |
+
/// lock clocks at turbo boost values
|
| 443 |
+
NVPW_DEVICE_CLOCK_SETTING_LOCK_TO_TURBO_BOOST,
|
| 444 |
+
NVPW_DEVICE_CLOCK_SETTING__COUNT
|
| 445 |
+
} NVPW_Device_ClockSetting;
|
| 446 |
+
|
| 447 |
+
typedef struct NVPW_Device_SetClockSetting_Params
|
| 448 |
+
{
|
| 449 |
+
/// [in]
|
| 450 |
+
size_t structSize;
|
| 451 |
+
/// [in] assign to NULL
|
| 452 |
+
void* pPriv;
|
| 453 |
+
size_t deviceIndex;
|
| 454 |
+
/// [in]
|
| 455 |
+
NVPW_Device_ClockSetting clockSetting;
|
| 456 |
+
} NVPW_Device_SetClockSetting_Params;
|
| 457 |
+
#define NVPW_Device_SetClockSetting_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Device_SetClockSetting_Params, clockSetting)
|
| 458 |
+
|
| 459 |
+
NVPA_Status NVPW_Device_SetClockSetting(NVPW_Device_SetClockSetting_Params* pParams);
|
| 460 |
+
|
| 461 |
+
typedef struct NVPW_CounterData_GetRangeDescriptions_Params
|
| 462 |
+
{
|
| 463 |
+
/// [in]
|
| 464 |
+
size_t structSize;
|
| 465 |
+
/// [in] assign to NULL
|
| 466 |
+
void* pPriv;
|
| 467 |
+
const uint8_t* pCounterDataImage;
|
| 468 |
+
size_t rangeIndex;
|
| 469 |
+
/// [inout] Number of descriptions allocated in ppDescriptions
|
| 470 |
+
size_t numDescriptions;
|
| 471 |
+
const char** ppDescriptions;
|
| 472 |
+
} NVPW_CounterData_GetRangeDescriptions_Params;
|
| 473 |
+
#define NVPW_CounterData_GetRangeDescriptions_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_CounterData_GetRangeDescriptions_Params, ppDescriptions)
|
| 474 |
+
|
| 475 |
+
NVPA_Status NVPW_CounterData_GetRangeDescriptions(NVPW_CounterData_GetRangeDescriptions_Params* pParams);
|
| 476 |
+
|
| 477 |
+
typedef struct NVPW_Profiler_CounterData_GetRangeDescriptions_Params
|
| 478 |
+
{
|
| 479 |
+
/// [in]
|
| 480 |
+
size_t structSize;
|
| 481 |
+
/// [in] assign to NULL
|
| 482 |
+
void* pPriv;
|
| 483 |
+
const uint8_t* pCounterDataImage;
|
| 484 |
+
size_t rangeIndex;
|
| 485 |
+
/// [inout] Number of descriptions allocated in ppDescriptions
|
| 486 |
+
size_t numDescriptions;
|
| 487 |
+
const char** ppDescriptions;
|
| 488 |
+
} NVPW_Profiler_CounterData_GetRangeDescriptions_Params;
|
| 489 |
+
#define NVPW_Profiler_CounterData_GetRangeDescriptions_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Profiler_CounterData_GetRangeDescriptions_Params, ppDescriptions)
|
| 490 |
+
|
| 491 |
+
NVPA_Status NVPW_Profiler_CounterData_GetRangeDescriptions(NVPW_Profiler_CounterData_GetRangeDescriptions_Params* pParams);
|
| 492 |
+
|
| 493 |
+
#ifndef NVPW_PERIODIC_SAMPLER_COUNTER_DATA_APPEND_MODE_DEFINED
|
| 494 |
+
#define NVPW_PERIODIC_SAMPLER_COUNTER_DATA_APPEND_MODE_DEFINED
|
| 495 |
+
typedef enum NVPW_PeriodicSampler_CounterData_AppendMode
|
| 496 |
+
{
|
| 497 |
+
NVPW_PERIODIC_SAMPLER_COUNTER_DATA_APPEND_MODE_LINEAR = 0,
|
| 498 |
+
NVPW_PERIODIC_SAMPLER_COUNTER_DATA_APPEND_MODE_CIRCULAR = 1,
|
| 499 |
+
NVPW_PERIODIC_SAMPLER_COUNTER_DATA_APPEND_MODE__COUNT
|
| 500 |
+
} NVPW_PeriodicSampler_CounterData_AppendMode;
|
| 501 |
+
#endif //NVPW_PERIODIC_SAMPLER_COUNTER_DATA_APPEND_MODE_DEFINED
|
| 502 |
+
|
| 503 |
+
typedef struct NVPW_PeriodicSampler_CounterData_GetSampleTime_Params
|
| 504 |
+
{
|
| 505 |
+
/// [in]
|
| 506 |
+
size_t structSize;
|
| 507 |
+
/// [in] assign to NULL
|
| 508 |
+
void* pPriv;
|
| 509 |
+
/// [in]
|
| 510 |
+
const uint8_t* pCounterDataImage;
|
| 511 |
+
/// [in]
|
| 512 |
+
size_t rangeIndex;
|
| 513 |
+
/// [out]
|
| 514 |
+
uint64_t timestampStart;
|
| 515 |
+
/// [out]
|
| 516 |
+
uint64_t timestampEnd;
|
| 517 |
+
} NVPW_PeriodicSampler_CounterData_GetSampleTime_Params;
|
| 518 |
+
#define NVPW_PeriodicSampler_CounterData_GetSampleTime_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_PeriodicSampler_CounterData_GetSampleTime_Params, timestampEnd)
|
| 519 |
+
|
| 520 |
+
NVPA_Status NVPW_PeriodicSampler_CounterData_GetSampleTime(NVPW_PeriodicSampler_CounterData_GetSampleTime_Params* pParams);
|
| 521 |
+
|
| 522 |
+
typedef struct NVPW_PeriodicSampler_CounterData_TrimInPlace_Params
|
| 523 |
+
{
|
| 524 |
+
/// [in]
|
| 525 |
+
size_t structSize;
|
| 526 |
+
/// [in] assign to NULL
|
| 527 |
+
void* pPriv;
|
| 528 |
+
/// [in]
|
| 529 |
+
uint8_t* pCounterDataImage;
|
| 530 |
+
/// [in]
|
| 531 |
+
size_t counterDataImageSize;
|
| 532 |
+
/// [out]
|
| 533 |
+
size_t counterDataImageTrimmedSize;
|
| 534 |
+
} NVPW_PeriodicSampler_CounterData_TrimInPlace_Params;
|
| 535 |
+
#define NVPW_PeriodicSampler_CounterData_TrimInPlace_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_PeriodicSampler_CounterData_TrimInPlace_Params, counterDataImageTrimmedSize)
|
| 536 |
+
|
| 537 |
+
NVPA_Status NVPW_PeriodicSampler_CounterData_TrimInPlace(NVPW_PeriodicSampler_CounterData_TrimInPlace_Params* pParams);
|
| 538 |
+
|
| 539 |
+
typedef struct NVPW_PeriodicSampler_CounterData_GetInfo_Params
|
| 540 |
+
{
|
| 541 |
+
/// [in]
|
| 542 |
+
size_t structSize;
|
| 543 |
+
/// [in] assign to NULL
|
| 544 |
+
void* pPriv;
|
| 545 |
+
/// [in]
|
| 546 |
+
const uint8_t* pCounterDataImage;
|
| 547 |
+
/// [in]
|
| 548 |
+
size_t counterDataImageSize;
|
| 549 |
+
/// [out] total number of ranges in the counter data
|
| 550 |
+
size_t numTotalRanges;
|
| 551 |
+
/// [out] if in "linear" mode, this API returns the number of "populated" ranges; if it's in "circular" mode,
|
| 552 |
+
/// then it returns the last "populated" range index + 1, when there is no such range, it returns 0.
|
| 553 |
+
size_t numPopulatedRanges;
|
| 554 |
+
/// [out] if in "linear" mode, this API returns the number of "completed" ranges; if it's in "circular" mode,
|
| 555 |
+
/// then it returns the last "completed" range index + 1, when there is no such range, it returns 0.
|
| 556 |
+
size_t numCompletedRanges;
|
| 557 |
+
} NVPW_PeriodicSampler_CounterData_GetInfo_Params;
|
| 558 |
+
#define NVPW_PeriodicSampler_CounterData_GetInfo_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_PeriodicSampler_CounterData_GetInfo_Params, numCompletedRanges)
|
| 559 |
+
|
| 560 |
+
/// In periodic sampler, a range in counter data stores exactly one sample's data. For better performance, periodic
|
| 561 |
+
/// sampler may operate in an out-of-order fashion when populating sample data, i.e. it may not fully populate all
|
| 562 |
+
/// counters of a sample/range before starting to populate the next sample/range. As a result, we have two concepts
|
| 563 |
+
/// here, "populated" & "completed": a range is considered "populated" even if only partial counters have been
|
| 564 |
+
/// written; on the other hand, a range is only considered "completed" if all the collecting counters have been
|
| 565 |
+
/// written.
|
| 566 |
+
NVPA_Status NVPW_PeriodicSampler_CounterData_GetInfo(NVPW_PeriodicSampler_CounterData_GetInfo_Params* pParams);
|
| 567 |
+
|
| 568 |
+
typedef struct NVPW_PeriodicSampler_CounterData_GetTriggerCount_Params
|
| 569 |
+
{
|
| 570 |
+
/// [in]
|
| 571 |
+
size_t structSize;
|
| 572 |
+
/// [in] assign to NULL
|
| 573 |
+
void* pPriv;
|
| 574 |
+
/// [in]
|
| 575 |
+
const uint8_t* pCounterDataImage;
|
| 576 |
+
/// [in]
|
| 577 |
+
size_t counterDataImageSize;
|
| 578 |
+
/// [in]
|
| 579 |
+
size_t rangeIndex;
|
| 580 |
+
/// [out]
|
| 581 |
+
uint32_t triggerCount;
|
| 582 |
+
} NVPW_PeriodicSampler_CounterData_GetTriggerCount_Params;
|
| 583 |
+
#define NVPW_PeriodicSampler_CounterData_GetTriggerCount_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_PeriodicSampler_CounterData_GetTriggerCount_Params, triggerCount)
|
| 584 |
+
|
| 585 |
+
NVPA_Status NVPW_PeriodicSampler_CounterData_GetTriggerCount(NVPW_PeriodicSampler_CounterData_GetTriggerCount_Params* pParams);
|
| 586 |
+
|
| 587 |
+
typedef struct NVPW_PeriodicSampler_CounterData_IsDataComplete_Params
|
| 588 |
+
{
|
| 589 |
+
/// [in]
|
| 590 |
+
size_t structSize;
|
| 591 |
+
/// [in] assign to NULL
|
| 592 |
+
void* pPriv;
|
| 593 |
+
/// [in]
|
| 594 |
+
const uint8_t* pCounterDataImage;
|
| 595 |
+
/// [in]
|
| 596 |
+
size_t counterDataImageSize;
|
| 597 |
+
/// [in]
|
| 598 |
+
size_t rangeIndex;
|
| 599 |
+
/// [out]
|
| 600 |
+
NVPA_Bool isComplete;
|
| 601 |
+
} NVPW_PeriodicSampler_CounterData_IsDataComplete_Params;
|
| 602 |
+
#define NVPW_PeriodicSampler_CounterData_IsDataComplete_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_PeriodicSampler_CounterData_IsDataComplete_Params, isComplete)
|
| 603 |
+
|
| 604 |
+
/// Checks whether a given sample's data is complete. See also 'NVPW_PeriodicSampler_CounterData_GetInfo'
|
| 605 |
+
NVPA_Status NVPW_PeriodicSampler_CounterData_IsDataComplete(NVPW_PeriodicSampler_CounterData_IsDataComplete_Params* pParams);
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
typedef struct NVPW_TimestampReport
|
| 609 |
+
{
|
| 610 |
+
uint32_t payload;
|
| 611 |
+
uint8_t reserved0004[4];
|
| 612 |
+
uint64_t timestamp;
|
| 613 |
+
} NVPW_TimestampReport;
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
#ifdef __cplusplus
|
| 619 |
+
} // extern "C"
|
| 620 |
+
#endif
|
| 621 |
+
|
| 622 |
+
#if defined(__GNUC__) && defined(NVPA_SHARED_LIB)
|
| 623 |
+
#pragma GCC visibility pop
|
| 624 |
+
#endif
|
| 625 |
+
|
| 626 |
+
#endif // NVPERF_TARGET_H
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/sm_32_atomic_functions.hpp
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 1993-2023 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 35.235 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.35.235 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
#if !defined(__SM_32_ATOMIC_FUNCTIONS_HPP__)
|
| 51 |
+
#define __SM_32_ATOMIC_FUNCTIONS_HPP__
|
| 52 |
+
|
| 53 |
+
#ifdef __CUDA_ARCH__
|
| 54 |
+
extern "C"
|
| 55 |
+
{
|
| 56 |
+
extern __device__ __device_builtin__ long long __illAtomicMin(long long *address, long long val);
|
| 57 |
+
extern __device__ __device_builtin__ long long __illAtomicMax(long long *address, long long val);
|
| 58 |
+
extern __device__ __device_builtin__ long long __llAtomicAnd(long long *address, long long val);
|
| 59 |
+
extern __device__ __device_builtin__ long long __llAtomicOr(long long *address, long long val);
|
| 60 |
+
extern __device__ __device_builtin__ long long __llAtomicXor(long long *address, long long val);
|
| 61 |
+
extern __device__ __device_builtin__ unsigned long long __ullAtomicMin(unsigned long long *address, unsigned long long val);
|
| 62 |
+
extern __device__ __device_builtin__ unsigned long long __ullAtomicMax(unsigned long long *address, unsigned long long val);
|
| 63 |
+
extern __device__ __device_builtin__ unsigned long long __ullAtomicAnd(unsigned long long *address, unsigned long long val);
|
| 64 |
+
extern __device__ __device_builtin__ unsigned long long __ullAtomicOr (unsigned long long *address, unsigned long long val);
|
| 65 |
+
extern __device__ __device_builtin__ unsigned long long __ullAtomicXor(unsigned long long *address, unsigned long long val);
|
| 66 |
+
}
|
| 67 |
+
#endif /* __CUDA_ARCH__ */
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
#if defined(__CUDACC_RTC__)
|
| 71 |
+
#define __SM_32_ATOMIC_FUNCTIONS_DECL__ __device__
|
| 72 |
+
#else /* !__CUDACC_RTC__ */
|
| 73 |
+
#define __SM_32_ATOMIC_FUNCTIONS_DECL__ static __inline__ __device__
|
| 74 |
+
#endif /* __CUDACC_RTC__ */
|
| 75 |
+
|
| 76 |
+
#if defined(__cplusplus) && defined(__CUDACC__)
|
| 77 |
+
|
| 78 |
+
#if defined(_NVHPC_CUDA) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 320
|
| 79 |
+
|
| 80 |
+
/*******************************************************************************
|
| 81 |
+
* *
|
| 82 |
+
* *
|
| 83 |
+
* *
|
| 84 |
+
*******************************************************************************/
|
| 85 |
+
|
| 86 |
+
#include "cuda_runtime_api.h"
|
| 87 |
+
|
| 88 |
+
/*******************************************************************************
|
| 89 |
+
* *
|
| 90 |
+
* *
|
| 91 |
+
* *
|
| 92 |
+
*******************************************************************************/
|
| 93 |
+
|
| 94 |
+
__SM_32_ATOMIC_FUNCTIONS_DECL__ long long atomicMin(long long *address, long long val)
|
| 95 |
+
{
|
| 96 |
+
return __illAtomicMin(address, val);
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
__SM_32_ATOMIC_FUNCTIONS_DECL__ long long atomicMax(long long *address, long long val)
|
| 100 |
+
{
|
| 101 |
+
return __illAtomicMax(address, val);
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
__SM_32_ATOMIC_FUNCTIONS_DECL__ long long atomicAnd(long long *address, long long val)
|
| 105 |
+
{
|
| 106 |
+
return __llAtomicAnd(address, val);
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
__SM_32_ATOMIC_FUNCTIONS_DECL__ long long atomicOr(long long *address, long long val)
|
| 110 |
+
{
|
| 111 |
+
return __llAtomicOr(address, val);
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
__SM_32_ATOMIC_FUNCTIONS_DECL__ long long atomicXor(long long *address, long long val)
|
| 115 |
+
{
|
| 116 |
+
return __llAtomicXor(address, val);
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
__SM_32_ATOMIC_FUNCTIONS_DECL__ unsigned long long atomicMin(unsigned long long *address, unsigned long long val)
|
| 120 |
+
{
|
| 121 |
+
return __ullAtomicMin(address, val);
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
__SM_32_ATOMIC_FUNCTIONS_DECL__ unsigned long long atomicMax(unsigned long long *address, unsigned long long val)
|
| 125 |
+
{
|
| 126 |
+
return __ullAtomicMax(address, val);
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
__SM_32_ATOMIC_FUNCTIONS_DECL__ unsigned long long atomicAnd(unsigned long long *address, unsigned long long val)
|
| 130 |
+
{
|
| 131 |
+
return __ullAtomicAnd(address, val);
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
__SM_32_ATOMIC_FUNCTIONS_DECL__ unsigned long long atomicOr(unsigned long long *address, unsigned long long val)
|
| 135 |
+
{
|
| 136 |
+
return __ullAtomicOr(address, val);
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
__SM_32_ATOMIC_FUNCTIONS_DECL__ unsigned long long atomicXor(unsigned long long *address, unsigned long long val)
|
| 140 |
+
{
|
| 141 |
+
return __ullAtomicXor(address, val);
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
#endif /* _NVHPC_CUDA || !__CUDA_ARCH__ || __CUDA_ARCH__ >= 320 */
|
| 145 |
+
|
| 146 |
+
#endif /* __cplusplus && __CUDACC__ */
|
| 147 |
+
|
| 148 |
+
#undef __SM_32_ATOMIC_FUNCTIONS_DECL__
|
| 149 |
+
|
| 150 |
+
#endif /* !__SM_32_ATOMIC_FUNCTIONS_HPP__ */
|
| 151 |
+
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (7.17 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/math.cpython-312.pyc
ADDED
|
Binary file (14.6 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/random.cpython-312.pyc
ADDED
|
Binary file (9.64 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/standard.cpython-312.pyc
ADDED
|
Binary file (21.8 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/target_info.cpython-312.pyc
ADDED
|
Binary file (2.28 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pkgutil
|
| 2 |
+
from importlib.util import module_from_spec
|
| 3 |
+
from sys import modules
|
| 4 |
+
|
| 5 |
+
_backends = []
|
| 6 |
+
for module_finder, module_name, is_pkg in pkgutil.iter_modules(
|
| 7 |
+
__path__,
|
| 8 |
+
prefix=__name__ + ".",
|
| 9 |
+
):
|
| 10 |
+
# skip .py files (like libdevice.py)
|
| 11 |
+
if not is_pkg:
|
| 12 |
+
continue
|
| 13 |
+
|
| 14 |
+
# import backends (like cuda and hip) that are included during setup.py
|
| 15 |
+
spec = module_finder.find_spec(module_name)
|
| 16 |
+
if spec is None or spec.loader is None:
|
| 17 |
+
continue
|
| 18 |
+
module = module_from_spec(spec)
|
| 19 |
+
spec.loader.exec_module(module)
|
| 20 |
+
|
| 21 |
+
_backends.append(module_name)
|
| 22 |
+
modules[module_name] = module
|
| 23 |
+
|
| 24 |
+
__all__ = _backends
|
| 25 |
+
|
| 26 |
+
del _backends
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (911 Bytes). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/__pycache__/libdevice.cpython-312.pyc
ADDED
|
Binary file (20.4 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import libdevice
|
| 2 |
+
|
| 3 |
+
from .utils import (globaltimer, num_threads, num_warps, smid, convert_custom_float8_sm70, convert_custom_float8_sm80)
|
| 4 |
+
from .gdc import (gdc_launch_dependents, gdc_wait)
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"libdevice",
|
| 8 |
+
"globaltimer",
|
| 9 |
+
"num_threads",
|
| 10 |
+
"num_warps",
|
| 11 |
+
"smid",
|
| 12 |
+
"convert_custom_float8_sm70",
|
| 13 |
+
"convert_custom_float8_sm80",
|
| 14 |
+
"gdc_launch_dependents",
|
| 15 |
+
"gdc_wait",
|
| 16 |
+
]
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (585 Bytes). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/gdc.cpython-312.pyc
ADDED
|
Binary file (2.74 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/libdevice.cpython-312.pyc
ADDED
|
Binary file (89.2 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (5.61 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/gdc.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grid Dependency Control (GDC) is a mechanism used when enabling programmatic dependent launch to launch and
|
| 3 |
+
synchronize grids. These APIs expose GDC to the programmer.
|
| 4 |
+
|
| 5 |
+
Programmatic dependent launch is supported on SM90 (Hopper) and beyond.
|
| 6 |
+
For PTX reference on grid dependency control see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from triton.language import core
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@core.extern
|
| 13 |
+
def gdc_wait(_semantic=None):
|
| 14 |
+
"""
|
| 15 |
+
GDC wait is a blocking instruction that waits for all instructions in a prior kernel to complete before continuing.
|
| 16 |
+
This ensures all memory operations happening before the wait is visible to instructions after it,
|
| 17 |
+
e.g. if the prior kernel writes to address "x" the new values will be visible in this kernel after the wait.
|
| 18 |
+
|
| 19 |
+
This instruction is also safe to execute when programmatic dependent launch is disabled.
|
| 20 |
+
|
| 21 |
+
See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol for more details.
|
| 22 |
+
"""
|
| 23 |
+
core.inline_asm_elementwise("griddepcontrol.wait; // dummy $0", "=r", [], dtype=core.int32, is_pure=False, pack=1,
|
| 24 |
+
_semantic=_semantic)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@core.extern
|
| 28 |
+
def gdc_launch_dependents(_semantic=None):
|
| 29 |
+
"""
|
| 30 |
+
This operation when launched with programmatic dependent launch signals that
|
| 31 |
+
the next program may launch once all programs in the current kernel
|
| 32 |
+
call this function or complete.
|
| 33 |
+
|
| 34 |
+
Repeated calls to this function have no effect past the first call, and the first call should be
|
| 35 |
+
treated by the programmer as a hint to the runtime system to launch the next kernel.
|
| 36 |
+
|
| 37 |
+
This instruction is also safe to execute when programmatic dependent launch is disabled.
|
| 38 |
+
|
| 39 |
+
See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol for more details.
|
| 40 |
+
"""
|
| 41 |
+
core.inline_asm_elementwise("griddepcontrol.launch_dependents; // dummy $0", "=r", [], dtype=core.int32,
|
| 42 |
+
is_pure=False, pack=1, _semantic=_semantic)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/libdevice.py
ADDED
|
@@ -0,0 +1,1629 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from triton.language import core
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@core.extern
|
| 5 |
+
def clz(arg0, _semantic=None):
|
| 6 |
+
return core.extern_elementwise(
|
| 7 |
+
"", "", [arg0], {
|
| 8 |
+
(core.dtype("int32"), ): ("__nv_clz", core.dtype("int32")),
|
| 9 |
+
(core.dtype("int64"), ): ("__nv_clzll", core.dtype("int32")),
|
| 10 |
+
}, is_pure=True, _semantic=_semantic)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@core.extern
|
| 14 |
+
def popc(arg0, _semantic=None):
|
| 15 |
+
return core.extern_elementwise(
|
| 16 |
+
"", "", [arg0], {
|
| 17 |
+
(core.dtype("int32"), ): ("__nv_popc", core.dtype("int32")),
|
| 18 |
+
(core.dtype("int64"), ): ("__nv_popcll", core.dtype("int32")),
|
| 19 |
+
}, is_pure=True, _semantic=_semantic)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@core.extern
|
| 23 |
+
def byte_perm(arg0, arg1, arg2, _semantic=None):
|
| 24 |
+
return core.extern_elementwise("", "", [arg0, arg1, arg2], {
|
| 25 |
+
(core.dtype("int32"), core.dtype("int32"), core.dtype("int32")): ("__nv_byte_perm", core.dtype("int32")),
|
| 26 |
+
}, is_pure=True, _semantic=_semantic)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@core.extern
|
| 30 |
+
def mulhi(arg0, arg1, _semantic=None):
|
| 31 |
+
return core.extern_elementwise(
|
| 32 |
+
"", "", [arg0, arg1], {
|
| 33 |
+
(core.dtype("int32"), core.dtype("int32")): ("__nv_mulhi", core.dtype("int32")),
|
| 34 |
+
(core.dtype("uint32"), core.dtype("uint32")): ("__nv_umulhi", core.dtype("uint32")),
|
| 35 |
+
(core.dtype("int64"), core.dtype("int64")): ("__nv_mul64hi", core.dtype("int64")),
|
| 36 |
+
(core.dtype("uint64"), core.dtype("uint64")): ("__nv_umul64hi", core.dtype("uint64")),
|
| 37 |
+
}, is_pure=True, _semantic=_semantic)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@core.extern
|
| 41 |
+
def mul24(arg0, arg1, _semantic=None):
|
| 42 |
+
return core.extern_elementwise(
|
| 43 |
+
"", "", [arg0, arg1], {
|
| 44 |
+
(core.dtype("int32"), core.dtype("int32")): ("__nv_mul24", core.dtype("int32")),
|
| 45 |
+
(core.dtype("uint32"), core.dtype("uint32")): ("__nv_umul24", core.dtype("uint32")),
|
| 46 |
+
}, is_pure=True, _semantic=_semantic)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@core.extern
|
| 50 |
+
def brev(arg0, _semantic=None):
|
| 51 |
+
return core.extern_elementwise(
|
| 52 |
+
"", "", [arg0], {
|
| 53 |
+
(core.dtype("int32"), ): ("__nv_brev", core.dtype("int32")),
|
| 54 |
+
(core.dtype("int64"), ): ("__nv_brevll", core.dtype("int64")),
|
| 55 |
+
}, is_pure=True, _semantic=_semantic)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@core.extern
|
| 59 |
+
def sad(arg0, arg1, arg2, _semantic=None):
|
| 60 |
+
return core.extern_elementwise(
|
| 61 |
+
"", "", [arg0, arg1, arg2], {
|
| 62 |
+
(core.dtype("int32"), core.dtype("int32"), core.dtype("uint32")): ("__nv_sad", core.dtype("int32")),
|
| 63 |
+
(core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32")): ("__nv_usad", core.dtype("uint32")),
|
| 64 |
+
}, is_pure=True, _semantic=_semantic)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@core.extern
|
| 68 |
+
def abs(arg0, _semantic=None):
|
| 69 |
+
return core.extern_elementwise(
|
| 70 |
+
"", "", [arg0], {
|
| 71 |
+
(core.dtype("int32"), ): ("__nv_abs", core.dtype("int32")),
|
| 72 |
+
(core.dtype("int64"), ): ("__nv_llabs", core.dtype("int64")),
|
| 73 |
+
(core.dtype("fp32"), ): ("__nv_fabsf", core.dtype("fp32")),
|
| 74 |
+
(core.dtype("fp64"), ): ("__nv_fabs", core.dtype("fp64")),
|
| 75 |
+
}, is_pure=True, _semantic=_semantic)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@core.extern
|
| 79 |
+
def floor(arg0, _semantic=None):
|
| 80 |
+
return core.extern_elementwise(
|
| 81 |
+
"", "", [arg0], {
|
| 82 |
+
(core.dtype("fp32"), ): ("__nv_floorf", core.dtype("fp32")),
|
| 83 |
+
(core.dtype("fp64"), ): ("__nv_floor", core.dtype("fp64")),
|
| 84 |
+
}, is_pure=True, _semantic=_semantic)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@core.extern
|
| 88 |
+
def rcp64h(arg0, _semantic=None):
|
| 89 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 90 |
+
(core.dtype("fp64"), ): ("__nv_rcp64h", core.dtype("fp64")),
|
| 91 |
+
}, is_pure=True, _semantic=_semantic)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@core.extern
|
| 95 |
+
def rsqrt(arg0, _semantic=None):
|
| 96 |
+
return core.extern_elementwise(
|
| 97 |
+
"", "", [arg0], {
|
| 98 |
+
(core.dtype("fp32"), ): ("__nv_rsqrtf", core.dtype("fp32")),
|
| 99 |
+
(core.dtype("fp64"), ): ("__nv_rsqrt", core.dtype("fp64")),
|
| 100 |
+
}, is_pure=True, _semantic=_semantic)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@core.extern
|
| 104 |
+
def ceil(arg0, _semantic=None):
|
| 105 |
+
return core.extern_elementwise(
|
| 106 |
+
"", "", [arg0], {
|
| 107 |
+
(core.dtype("fp64"), ): ("__nv_ceil", core.dtype("fp64")),
|
| 108 |
+
(core.dtype("fp32"), ): ("__nv_ceilf", core.dtype("fp32")),
|
| 109 |
+
}, is_pure=True, _semantic=_semantic)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@core.extern
|
| 113 |
+
def trunc(arg0, _semantic=None):
|
| 114 |
+
return core.extern_elementwise(
|
| 115 |
+
"", "", [arg0], {
|
| 116 |
+
(core.dtype("fp64"), ): ("__nv_trunc", core.dtype("fp64")),
|
| 117 |
+
(core.dtype("fp32"), ): ("__nv_truncf", core.dtype("fp32")),
|
| 118 |
+
}, is_pure=True, _semantic=_semantic)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@core.extern
|
| 122 |
+
def exp2(arg0, _semantic=None):
|
| 123 |
+
return core.extern_elementwise(
|
| 124 |
+
"", "", [arg0], {
|
| 125 |
+
(core.dtype("fp32"), ): ("__nv_exp2f", core.dtype("fp32")),
|
| 126 |
+
(core.dtype("fp64"), ): ("__nv_exp2", core.dtype("fp64")),
|
| 127 |
+
}, is_pure=True, _semantic=_semantic)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@core.extern
|
| 131 |
+
def saturatef(arg0, _semantic=None):
|
| 132 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 133 |
+
(core.dtype("fp32"), ): ("__nv_saturatef", core.dtype("fp32")),
|
| 134 |
+
}, is_pure=True, _semantic=_semantic)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@core.extern
|
| 138 |
+
def fma_rn(arg0, arg1, arg2, _semantic=None):
|
| 139 |
+
return core.extern_elementwise(
|
| 140 |
+
"", "", [arg0, arg1, arg2], {
|
| 141 |
+
(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rn", core.dtype("fp32")),
|
| 142 |
+
(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rn", core.dtype("fp64")),
|
| 143 |
+
}, is_pure=True, _semantic=_semantic)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@core.extern
|
| 147 |
+
def fma_rz(arg0, arg1, arg2, _semantic=None):
|
| 148 |
+
return core.extern_elementwise(
|
| 149 |
+
"", "", [arg0, arg1, arg2], {
|
| 150 |
+
(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rz", core.dtype("fp32")),
|
| 151 |
+
(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rz", core.dtype("fp64")),
|
| 152 |
+
}, is_pure=True, _semantic=_semantic)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@core.extern
|
| 156 |
+
def fma_rd(arg0, arg1, arg2, _semantic=None):
|
| 157 |
+
return core.extern_elementwise(
|
| 158 |
+
"", "", [arg0, arg1, arg2], {
|
| 159 |
+
(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rd", core.dtype("fp32")),
|
| 160 |
+
(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rd", core.dtype("fp64")),
|
| 161 |
+
}, is_pure=True, _semantic=_semantic)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@core.extern
|
| 165 |
+
def fma_ru(arg0, arg1, arg2, _semantic=None):
|
| 166 |
+
return core.extern_elementwise(
|
| 167 |
+
"", "", [arg0, arg1, arg2], {
|
| 168 |
+
(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_ru", core.dtype("fp32")),
|
| 169 |
+
(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_ru", core.dtype("fp64")),
|
| 170 |
+
}, is_pure=True, _semantic=_semantic)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
@core.extern
|
| 174 |
+
def fast_dividef(arg0, arg1, _semantic=None):
|
| 175 |
+
return core.extern_elementwise("", "", [arg0, arg1], {
|
| 176 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_fdividef", core.dtype("fp32")),
|
| 177 |
+
}, is_pure=True, _semantic=_semantic)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
@core.extern
|
| 181 |
+
def div_rn(arg0, arg1, _semantic=None):
|
| 182 |
+
return core.extern_elementwise(
|
| 183 |
+
"", "", [arg0, arg1], {
|
| 184 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rn", core.dtype("fp32")),
|
| 185 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rn", core.dtype("fp64")),
|
| 186 |
+
}, is_pure=True, _semantic=_semantic)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@core.extern
|
| 190 |
+
def div_rz(arg0, arg1, _semantic=None):
|
| 191 |
+
return core.extern_elementwise(
|
| 192 |
+
"", "", [arg0, arg1], {
|
| 193 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rz", core.dtype("fp32")),
|
| 194 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rz", core.dtype("fp64")),
|
| 195 |
+
}, is_pure=True, _semantic=_semantic)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@core.extern
|
| 199 |
+
def div_rd(arg0, arg1, _semantic=None):
|
| 200 |
+
return core.extern_elementwise(
|
| 201 |
+
"", "", [arg0, arg1], {
|
| 202 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rd", core.dtype("fp32")),
|
| 203 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rd", core.dtype("fp64")),
|
| 204 |
+
}, is_pure=True, _semantic=_semantic)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
@core.extern
|
| 208 |
+
def div_ru(arg0, arg1, _semantic=None):
|
| 209 |
+
return core.extern_elementwise(
|
| 210 |
+
"", "", [arg0, arg1], {
|
| 211 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_ru", core.dtype("fp32")),
|
| 212 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_ru", core.dtype("fp64")),
|
| 213 |
+
}, is_pure=True, _semantic=_semantic)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@core.extern
|
| 217 |
+
def rcp_rn(arg0, _semantic=None):
|
| 218 |
+
return core.extern_elementwise(
|
| 219 |
+
"", "", [arg0], {
|
| 220 |
+
(core.dtype("fp32"), ): ("__nv_frcp_rn", core.dtype("fp32")),
|
| 221 |
+
(core.dtype("fp64"), ): ("__nv_drcp_rn", core.dtype("fp64")),
|
| 222 |
+
}, is_pure=True, _semantic=_semantic)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
@core.extern
|
| 226 |
+
def rcp_rz(arg0, _semantic=None):
|
| 227 |
+
return core.extern_elementwise(
|
| 228 |
+
"", "", [arg0], {
|
| 229 |
+
(core.dtype("fp32"), ): ("__nv_frcp_rz", core.dtype("fp32")),
|
| 230 |
+
(core.dtype("fp64"), ): ("__nv_drcp_rz", core.dtype("fp64")),
|
| 231 |
+
}, is_pure=True, _semantic=_semantic)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
@core.extern
|
| 235 |
+
def rcp_rd(arg0, _semantic=None):
|
| 236 |
+
return core.extern_elementwise(
|
| 237 |
+
"", "", [arg0], {
|
| 238 |
+
(core.dtype("fp32"), ): ("__nv_frcp_rd", core.dtype("fp32")),
|
| 239 |
+
(core.dtype("fp64"), ): ("__nv_drcp_rd", core.dtype("fp64")),
|
| 240 |
+
}, is_pure=True, _semantic=_semantic)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
@core.extern
|
| 244 |
+
def rcp_ru(arg0, _semantic=None):
|
| 245 |
+
return core.extern_elementwise(
|
| 246 |
+
"", "", [arg0], {
|
| 247 |
+
(core.dtype("fp32"), ): ("__nv_frcp_ru", core.dtype("fp32")),
|
| 248 |
+
(core.dtype("fp64"), ): ("__nv_drcp_ru", core.dtype("fp64")),
|
| 249 |
+
}, is_pure=True, _semantic=_semantic)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
@core.extern
|
| 253 |
+
def sqrt_rn(arg0, _semantic=None):
|
| 254 |
+
return core.extern_elementwise(
|
| 255 |
+
"", "", [arg0], {
|
| 256 |
+
(core.dtype("fp32"), ): ("__nv_fsqrt_rn", core.dtype("fp32")),
|
| 257 |
+
(core.dtype("fp64"), ): ("__nv_dsqrt_rn", core.dtype("fp64")),
|
| 258 |
+
}, is_pure=True, _semantic=_semantic)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
@core.extern
|
| 262 |
+
def sqrt_rz(arg0, _semantic=None):
|
| 263 |
+
return core.extern_elementwise(
|
| 264 |
+
"", "", [arg0], {
|
| 265 |
+
(core.dtype("fp32"), ): ("__nv_fsqrt_rz", core.dtype("fp32")),
|
| 266 |
+
(core.dtype("fp64"), ): ("__nv_dsqrt_rz", core.dtype("fp64")),
|
| 267 |
+
}, is_pure=True, _semantic=_semantic)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
@core.extern
|
| 271 |
+
def sqrt_rd(arg0, _semantic=None):
|
| 272 |
+
return core.extern_elementwise(
|
| 273 |
+
"", "", [arg0], {
|
| 274 |
+
(core.dtype("fp32"), ): ("__nv_fsqrt_rd", core.dtype("fp32")),
|
| 275 |
+
(core.dtype("fp64"), ): ("__nv_dsqrt_rd", core.dtype("fp64")),
|
| 276 |
+
}, is_pure=True, _semantic=_semantic)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
@core.extern
|
| 280 |
+
def sqrt_ru(arg0, _semantic=None):
|
| 281 |
+
return core.extern_elementwise(
|
| 282 |
+
"", "", [arg0], {
|
| 283 |
+
(core.dtype("fp32"), ): ("__nv_fsqrt_ru", core.dtype("fp32")),
|
| 284 |
+
(core.dtype("fp64"), ): ("__nv_dsqrt_ru", core.dtype("fp64")),
|
| 285 |
+
}, is_pure=True, _semantic=_semantic)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
@core.extern
|
| 289 |
+
def sqrt(arg0, _semantic=None):
|
| 290 |
+
return core.extern_elementwise(
|
| 291 |
+
"", "", [arg0], {
|
| 292 |
+
(core.dtype("fp32"), ): ("__nv_sqrtf", core.dtype("fp32")),
|
| 293 |
+
(core.dtype("fp64"), ): ("__nv_sqrt", core.dtype("fp64")),
|
| 294 |
+
}, is_pure=True, _semantic=_semantic)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
@core.extern
|
| 298 |
+
def add_rn(arg0, arg1, _semantic=None):
|
| 299 |
+
return core.extern_elementwise(
|
| 300 |
+
"", "", [arg0, arg1], {
|
| 301 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rn", core.dtype("fp64")),
|
| 302 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rn", core.dtype("fp32")),
|
| 303 |
+
}, is_pure=True, _semantic=_semantic)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
@core.extern
|
| 307 |
+
def add_rz(arg0, arg1, _semantic=None):
|
| 308 |
+
return core.extern_elementwise(
|
| 309 |
+
"", "", [arg0, arg1], {
|
| 310 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rz", core.dtype("fp64")),
|
| 311 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rz", core.dtype("fp32")),
|
| 312 |
+
}, is_pure=True, _semantic=_semantic)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@core.extern
|
| 316 |
+
def add_rd(arg0, arg1, _semantic=None):
|
| 317 |
+
return core.extern_elementwise(
|
| 318 |
+
"", "", [arg0, arg1], {
|
| 319 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rd", core.dtype("fp64")),
|
| 320 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rd", core.dtype("fp32")),
|
| 321 |
+
}, is_pure=True, _semantic=_semantic)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
@core.extern
|
| 325 |
+
def add_ru(arg0, arg1, _semantic=None):
|
| 326 |
+
return core.extern_elementwise(
|
| 327 |
+
"", "", [arg0, arg1], {
|
| 328 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_ru", core.dtype("fp64")),
|
| 329 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_ru", core.dtype("fp32")),
|
| 330 |
+
}, is_pure=True, _semantic=_semantic)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
@core.extern
|
| 334 |
+
def mul_rn(arg0, arg1, _semantic=None):
|
| 335 |
+
return core.extern_elementwise(
|
| 336 |
+
"", "", [arg0, arg1], {
|
| 337 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rn", core.dtype("fp64")),
|
| 338 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rn", core.dtype("fp32")),
|
| 339 |
+
}, is_pure=True, _semantic=_semantic)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
@core.extern
|
| 343 |
+
def mul_rz(arg0, arg1, _semantic=None):
|
| 344 |
+
return core.extern_elementwise(
|
| 345 |
+
"", "", [arg0, arg1], {
|
| 346 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rz", core.dtype("fp64")),
|
| 347 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rz", core.dtype("fp32")),
|
| 348 |
+
}, is_pure=True, _semantic=_semantic)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
@core.extern
|
| 352 |
+
def mul_rd(arg0, arg1, _semantic=None):
|
| 353 |
+
return core.extern_elementwise(
|
| 354 |
+
"", "", [arg0, arg1], {
|
| 355 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rd", core.dtype("fp64")),
|
| 356 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rd", core.dtype("fp32")),
|
| 357 |
+
}, is_pure=True, _semantic=_semantic)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
@core.extern
|
| 361 |
+
def mul_ru(arg0, arg1, _semantic=None):
|
| 362 |
+
return core.extern_elementwise(
|
| 363 |
+
"", "", [
|
| 364 |
+
arg0,
|
| 365 |
+
arg1,
|
| 366 |
+
], {
|
| 367 |
+
(
|
| 368 |
+
core.dtype("fp64"),
|
| 369 |
+
core.dtype("fp64"),
|
| 370 |
+
): ("__nv_dmul_ru", core.dtype("fp64")),
|
| 371 |
+
(
|
| 372 |
+
core.dtype("fp32"),
|
| 373 |
+
core.dtype("fp32"),
|
| 374 |
+
): ("__nv_fmul_ru", core.dtype("fp32")),
|
| 375 |
+
}, is_pure=True, _semantic=_semantic)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
@core.extern
|
| 379 |
+
def double2float_rn(arg0, _semantic=None):
|
| 380 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 381 |
+
(core.dtype("fp64"), ): ("__nv_double2float_rn", core.dtype("fp32")),
|
| 382 |
+
}, is_pure=True, _semantic=_semantic)
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
@core.extern
|
| 386 |
+
def double2float_rz(arg0, _semantic=None):
|
| 387 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 388 |
+
(core.dtype("fp64"), ): ("__nv_double2float_rz", core.dtype("fp32")),
|
| 389 |
+
}, is_pure=True, _semantic=_semantic)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
@core.extern
|
| 393 |
+
def double2float_rd(arg0, _semantic=None):
|
| 394 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 395 |
+
(core.dtype("fp64"), ): ("__nv_double2float_rd", core.dtype("fp32")),
|
| 396 |
+
}, is_pure=True, _semantic=_semantic)
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
@core.extern
|
| 400 |
+
def double2float_ru(arg0, _semantic=None):
|
| 401 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 402 |
+
(core.dtype("fp64"), ): ("__nv_double2float_ru", core.dtype("fp32")),
|
| 403 |
+
}, is_pure=True, _semantic=_semantic)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
@core.extern
|
| 407 |
+
def double2int_rn(arg0, _semantic=None):
|
| 408 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 409 |
+
(core.dtype("fp64"), ): ("__nv_double2int_rn", core.dtype("int32")),
|
| 410 |
+
}, is_pure=True, _semantic=_semantic)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
@core.extern
|
| 414 |
+
def double2int_rz(arg0, _semantic=None):
|
| 415 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 416 |
+
(core.dtype("fp64"), ): ("__nv_double2int_rz", core.dtype("int32")),
|
| 417 |
+
}, is_pure=True, _semantic=_semantic)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
@core.extern
|
| 421 |
+
def double2int_rd(arg0, _semantic=None):
|
| 422 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 423 |
+
(core.dtype("fp64"), ): ("__nv_double2int_rd", core.dtype("int32")),
|
| 424 |
+
}, is_pure=True, _semantic=_semantic)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
@core.extern
|
| 428 |
+
def double2int_ru(arg0, _semantic=None):
|
| 429 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 430 |
+
(core.dtype("fp64"), ): ("__nv_double2int_ru", core.dtype("int32")),
|
| 431 |
+
}, is_pure=True, _semantic=_semantic)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
@core.extern
|
| 435 |
+
def double2uint_rn(arg0, _semantic=None):
|
| 436 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 437 |
+
(core.dtype("fp64"), ): ("__nv_double2uint_rn", core.dtype("int32")),
|
| 438 |
+
}, is_pure=True, _semantic=_semantic)
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
@core.extern
|
| 442 |
+
def double2uint_rz(arg0, _semantic=None):
|
| 443 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 444 |
+
(core.dtype("fp64"), ): ("__nv_double2uint_rz", core.dtype("int32")),
|
| 445 |
+
}, is_pure=True, _semantic=_semantic)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
@core.extern
|
| 449 |
+
def double2uint_rd(arg0, _semantic=None):
|
| 450 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 451 |
+
(core.dtype("fp64"), ): ("__nv_double2uint_rd", core.dtype("int32")),
|
| 452 |
+
}, is_pure=True, _semantic=_semantic)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
@core.extern
|
| 456 |
+
def double2uint_ru(arg0, _semantic=None):
|
| 457 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 458 |
+
(core.dtype("fp64"), ): ("__nv_double2uint_ru", core.dtype("int32")),
|
| 459 |
+
}, is_pure=True, _semantic=_semantic)
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
@core.extern
|
| 463 |
+
def int2double_rn(arg0, _semantic=None):
|
| 464 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 465 |
+
(core.dtype("int32"), ): ("__nv_int2double_rn", core.dtype("fp64")),
|
| 466 |
+
}, is_pure=True, _semantic=_semantic)
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
@core.extern
|
| 470 |
+
def uint2double_rn(arg0, _semantic=None):
|
| 471 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 472 |
+
(core.dtype("uint32"), ): ("__nv_uint2double_rn", core.dtype("fp64")),
|
| 473 |
+
}, is_pure=True, _semantic=_semantic)
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
@core.extern
|
| 477 |
+
def float2int_rn(arg0, _semantic=None):
|
| 478 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 479 |
+
(core.dtype("fp32"), ): ("__nv_float2int_rn", core.dtype("int32")),
|
| 480 |
+
}, is_pure=True, _semantic=_semantic)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
@core.extern
|
| 484 |
+
def float2int_rz(arg0, _semantic=None):
|
| 485 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 486 |
+
(core.dtype("fp32"), ): ("__nv_float2int_rz", core.dtype("int32")),
|
| 487 |
+
}, is_pure=True, _semantic=_semantic)
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
@core.extern
|
| 491 |
+
def float2int_rd(arg0, _semantic=None):
|
| 492 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 493 |
+
(core.dtype("fp32"), ): ("__nv_float2int_rd", core.dtype("int32")),
|
| 494 |
+
}, is_pure=True, _semantic=_semantic)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
@core.extern
|
| 498 |
+
def float2int_ru(arg0, _semantic=None):
|
| 499 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 500 |
+
(core.dtype("fp32"), ): ("__nv_float2int_ru", core.dtype("int32")),
|
| 501 |
+
}, is_pure=True, _semantic=_semantic)
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
@core.extern
|
| 505 |
+
def float2uint_rn(arg0, _semantic=None):
|
| 506 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 507 |
+
(core.dtype("fp32"), ): ("__nv_float2uint_rn", core.dtype("int32")),
|
| 508 |
+
}, is_pure=True, _semantic=_semantic)
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
@core.extern
|
| 512 |
+
def float2uint_rz(arg0, _semantic=None):
|
| 513 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 514 |
+
(core.dtype("fp32"), ): ("__nv_float2uint_rz", core.dtype("int32")),
|
| 515 |
+
}, is_pure=True, _semantic=_semantic)
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
@core.extern
|
| 519 |
+
def float2uint_rd(arg0, _semantic=None):
|
| 520 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 521 |
+
(core.dtype("fp32"), ): ("__nv_float2uint_rd", core.dtype("int32")),
|
| 522 |
+
}, is_pure=True, _semantic=_semantic)
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
@core.extern
|
| 526 |
+
def float2uint_ru(arg0, _semantic=None):
|
| 527 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 528 |
+
(core.dtype("fp32"), ): ("__nv_float2uint_ru", core.dtype("int32")),
|
| 529 |
+
}, is_pure=True, _semantic=_semantic)
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
@core.extern
|
| 533 |
+
def int2float_rn(arg0, _semantic=None):
|
| 534 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 535 |
+
(core.dtype("int32"), ): ("__nv_int2float_rn", core.dtype("fp32")),
|
| 536 |
+
}, is_pure=True, _semantic=_semantic)
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
@core.extern
|
| 540 |
+
def int2float_rz(arg0, _semantic=None):
|
| 541 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 542 |
+
(core.dtype("int32"), ): ("__nv_int2float_rz", core.dtype("fp32")),
|
| 543 |
+
}, is_pure=True, _semantic=_semantic)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
@core.extern
|
| 547 |
+
def int2float_rd(arg0, _semantic=None):
|
| 548 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 549 |
+
(core.dtype("int32"), ): ("__nv_int2float_rd", core.dtype("fp32")),
|
| 550 |
+
}, is_pure=True, _semantic=_semantic)
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
@core.extern
|
| 554 |
+
def int2float_ru(arg0, _semantic=None):
|
| 555 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 556 |
+
(core.dtype("int32"), ): ("__nv_int2float_ru", core.dtype("fp32")),
|
| 557 |
+
}, is_pure=True, _semantic=_semantic)
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
@core.extern
|
| 561 |
+
def uint2float_rn(arg0, _semantic=None):
|
| 562 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 563 |
+
(core.dtype("uint32"), ): ("__nv_uint2float_rn", core.dtype("fp32")),
|
| 564 |
+
}, is_pure=True, _semantic=_semantic)
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
@core.extern
|
| 568 |
+
def uint2float_rz(arg0, _semantic=None):
|
| 569 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 570 |
+
(core.dtype("uint32"), ): ("__nv_uint2float_rz", core.dtype("fp32")),
|
| 571 |
+
}, is_pure=True, _semantic=_semantic)
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
@core.extern
|
| 575 |
+
def uint2float_rd(arg0, _semantic=None):
|
| 576 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 577 |
+
(core.dtype("uint32"), ): ("__nv_uint2float_rd", core.dtype("fp32")),
|
| 578 |
+
}, is_pure=True, _semantic=_semantic)
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
@core.extern
|
| 582 |
+
def uint2float_ru(arg0, _semantic=None):
|
| 583 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 584 |
+
(core.dtype("uint32"), ): ("__nv_uint2float_ru", core.dtype("fp32")),
|
| 585 |
+
}, is_pure=True, _semantic=_semantic)
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
@core.extern
|
| 589 |
+
def hiloint2double(arg0, arg1, _semantic=None):
|
| 590 |
+
return core.extern_elementwise("", "", [arg0, arg1], {
|
| 591 |
+
(core.dtype("int32"), core.dtype("int32")): ("__nv_hiloint2double", core.dtype("fp64")),
|
| 592 |
+
}, is_pure=True, _semantic=_semantic)
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
@core.extern
|
| 596 |
+
def double2loint(arg0, _semantic=None):
|
| 597 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 598 |
+
(core.dtype("fp64"), ): ("__nv_double2loint", core.dtype("int32")),
|
| 599 |
+
}, is_pure=True, _semantic=_semantic)
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
@core.extern
|
| 603 |
+
def double2hiint(arg0, _semantic=None):
|
| 604 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 605 |
+
(core.dtype("fp64"), ): ("__nv_double2hiint", core.dtype("int32")),
|
| 606 |
+
}, is_pure=True, _semantic=_semantic)
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
@core.extern
|
| 610 |
+
def float2ll_rn(arg0, _semantic=None):
|
| 611 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 612 |
+
(core.dtype("fp32"), ): ("__nv_float2ll_rn", core.dtype("int64")),
|
| 613 |
+
}, is_pure=True, _semantic=_semantic)
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
@core.extern
|
| 617 |
+
def float2ll_rz(arg0, _semantic=None):
|
| 618 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 619 |
+
(core.dtype("fp32"), ): ("__nv_float2ll_rz", core.dtype("int64")),
|
| 620 |
+
}, is_pure=True, _semantic=_semantic)
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
@core.extern
|
| 624 |
+
def float2ll_rd(arg0, _semantic=None):
|
| 625 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 626 |
+
(core.dtype("fp32"), ): ("__nv_float2ll_rd", core.dtype("int64")),
|
| 627 |
+
}, is_pure=True, _semantic=_semantic)
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
@core.extern
|
| 631 |
+
def float2ll_ru(arg0, _semantic=None):
|
| 632 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 633 |
+
(core.dtype("fp32"), ): ("__nv_float2ll_ru", core.dtype("int64")),
|
| 634 |
+
}, is_pure=True, _semantic=_semantic)
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
@core.extern
|
| 638 |
+
def float2ull_rn(arg0, _semantic=None):
|
| 639 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 640 |
+
(core.dtype("fp32"), ): ("__nv_float2ull_rn", core.dtype("int64")),
|
| 641 |
+
}, is_pure=True, _semantic=_semantic)
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
@core.extern
|
| 645 |
+
def float2ull_rz(arg0, _semantic=None):
|
| 646 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 647 |
+
(core.dtype("fp32"), ): ("__nv_float2ull_rz", core.dtype("int64")),
|
| 648 |
+
}, is_pure=True, _semantic=_semantic)
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
@core.extern
|
| 652 |
+
def float2ull_rd(arg0, _semantic=None):
|
| 653 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 654 |
+
(core.dtype("fp32"), ): ("__nv_float2ull_rd", core.dtype("int64")),
|
| 655 |
+
}, is_pure=True, _semantic=_semantic)
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
@core.extern
|
| 659 |
+
def float2ull_ru(arg0, _semantic=None):
|
| 660 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 661 |
+
(core.dtype("fp32"), ): ("__nv_float2ull_ru", core.dtype("int64")),
|
| 662 |
+
}, is_pure=True, _semantic=_semantic)
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
@core.extern
|
| 666 |
+
def double2ll_rn(arg0, _semantic=None):
|
| 667 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 668 |
+
(core.dtype("fp64"), ): ("__nv_double2ll_rn", core.dtype("int64")),
|
| 669 |
+
}, is_pure=True, _semantic=_semantic)
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
@core.extern
|
| 673 |
+
def double2ll_rz(arg0, _semantic=None):
|
| 674 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 675 |
+
(core.dtype("fp64"), ): ("__nv_double2ll_rz", core.dtype("int64")),
|
| 676 |
+
}, is_pure=True, _semantic=_semantic)
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
@core.extern
|
| 680 |
+
def double2ll_rd(arg0, _semantic=None):
|
| 681 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 682 |
+
(core.dtype("fp64"), ): ("__nv_double2ll_rd", core.dtype("int64")),
|
| 683 |
+
}, is_pure=True, _semantic=_semantic)
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
@core.extern
|
| 687 |
+
def double2ll_ru(arg0, _semantic=None):
|
| 688 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 689 |
+
(core.dtype("fp64"), ): ("__nv_double2ll_ru", core.dtype("int64")),
|
| 690 |
+
}, is_pure=True, _semantic=_semantic)
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
@core.extern
|
| 694 |
+
def double2ull_rn(arg0, _semantic=None):
|
| 695 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 696 |
+
(core.dtype("fp64"), ): ("__nv_double2ull_rn", core.dtype("int64")),
|
| 697 |
+
}, is_pure=True, _semantic=_semantic)
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
@core.extern
|
| 701 |
+
def double2ull_rz(arg0, _semantic=None):
|
| 702 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 703 |
+
(core.dtype("fp64"), ): ("__nv_double2ull_rz", core.dtype("int64")),
|
| 704 |
+
}, is_pure=True, _semantic=_semantic)
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
@core.extern
|
| 708 |
+
def double2ull_rd(arg0, _semantic=None):
|
| 709 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 710 |
+
(core.dtype("fp64"), ): ("__nv_double2ull_rd", core.dtype("int64")),
|
| 711 |
+
}, is_pure=True, _semantic=_semantic)
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
@core.extern
|
| 715 |
+
def double2ull_ru(arg0, _semantic=None):
|
| 716 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 717 |
+
(core.dtype("fp64"), ): ("__nv_double2ull_ru", core.dtype("int64")),
|
| 718 |
+
}, is_pure=True, _semantic=_semantic)
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
@core.extern
|
| 722 |
+
def ll2float_rn(arg0, _semantic=None):
|
| 723 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 724 |
+
(core.dtype("int64"), ): ("__nv_ll2float_rn", core.dtype("fp32")),
|
| 725 |
+
}, is_pure=True, _semantic=_semantic)
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
@core.extern
|
| 729 |
+
def ll2float_rz(arg0, _semantic=None):
|
| 730 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 731 |
+
(core.dtype("int64"), ): ("__nv_ll2float_rz", core.dtype("fp32")),
|
| 732 |
+
}, is_pure=True, _semantic=_semantic)
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
@core.extern
|
| 736 |
+
def ll2float_rd(arg0, _semantic=None):
|
| 737 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 738 |
+
(core.dtype("int64"), ): ("__nv_ll2float_rd", core.dtype("fp32")),
|
| 739 |
+
}, is_pure=True, _semantic=_semantic)
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
@core.extern
|
| 743 |
+
def ll2float_ru(arg0, _semantic=None):
|
| 744 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 745 |
+
(core.dtype("int64"), ): ("__nv_ll2float_ru", core.dtype("fp32")),
|
| 746 |
+
}, is_pure=True, _semantic=_semantic)
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
@core.extern
|
| 750 |
+
def ull2float_rn(arg0, _semantic=None):
|
| 751 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 752 |
+
(core.dtype("uint64"), ): ("__nv_ull2float_rn", core.dtype("fp32")),
|
| 753 |
+
}, is_pure=True, _semantic=_semantic)
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
@core.extern
|
| 757 |
+
def ull2float_rz(arg0, _semantic=None):
|
| 758 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 759 |
+
(core.dtype("uint64"), ): ("__nv_ull2float_rz", core.dtype("fp32")),
|
| 760 |
+
}, is_pure=True, _semantic=_semantic)
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
@core.extern
|
| 764 |
+
def ull2float_rd(arg0, _semantic=None):
|
| 765 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 766 |
+
(core.dtype("uint64"), ): ("__nv_ull2float_rd", core.dtype("fp32")),
|
| 767 |
+
}, is_pure=True, _semantic=_semantic)
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
@core.extern
|
| 771 |
+
def ull2float_ru(arg0, _semantic=None):
|
| 772 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 773 |
+
(core.dtype("uint64"), ): ("__nv_ull2float_ru", core.dtype("fp32")),
|
| 774 |
+
}, is_pure=True, _semantic=_semantic)
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
@core.extern
|
| 778 |
+
def ll2double_rn(arg0, _semantic=None):
|
| 779 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 780 |
+
(core.dtype("int64"), ): ("__nv_ll2double_rn", core.dtype("fp64")),
|
| 781 |
+
}, is_pure=True, _semantic=_semantic)
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
@core.extern
|
| 785 |
+
def ll2double_rz(arg0, _semantic=None):
|
| 786 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 787 |
+
(core.dtype("int64"), ): ("__nv_ll2double_rz", core.dtype("fp64")),
|
| 788 |
+
}, is_pure=True, _semantic=_semantic)
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
@core.extern
|
| 792 |
+
def ll2double_rd(arg0, _semantic=None):
|
| 793 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 794 |
+
(core.dtype("int64"), ): ("__nv_ll2double_rd", core.dtype("fp64")),
|
| 795 |
+
}, is_pure=True, _semantic=_semantic)
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
@core.extern
|
| 799 |
+
def ll2double_ru(arg0, _semantic=None):
|
| 800 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 801 |
+
(core.dtype("int64"), ): ("__nv_ll2double_ru", core.dtype("fp64")),
|
| 802 |
+
}, is_pure=True, _semantic=_semantic)
|
| 803 |
+
|
| 804 |
+
|
| 805 |
+
@core.extern
|
| 806 |
+
def ull2double_rn(arg0, _semantic=None):
|
| 807 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 808 |
+
(core.dtype("uint64"), ): ("__nv_ull2double_rn", core.dtype("fp64")),
|
| 809 |
+
}, is_pure=True, _semantic=_semantic)
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
@core.extern
|
| 813 |
+
def ull2double_rz(arg0, _semantic=None):
|
| 814 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 815 |
+
(core.dtype("uint64"), ): ("__nv_ull2double_rz", core.dtype("fp64")),
|
| 816 |
+
}, is_pure=True, _semantic=_semantic)
|
| 817 |
+
|
| 818 |
+
|
| 819 |
+
@core.extern
|
| 820 |
+
def ull2double_rd(arg0, _semantic=None):
|
| 821 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 822 |
+
(core.dtype("uint64"), ): ("__nv_ull2double_rd", core.dtype("fp64")),
|
| 823 |
+
}, is_pure=True, _semantic=_semantic)
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
@core.extern
|
| 827 |
+
def ull2double_ru(arg0, _semantic=None):
|
| 828 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 829 |
+
(core.dtype("uint64"), ): ("__nv_ull2double_ru", core.dtype("fp64")),
|
| 830 |
+
}, is_pure=True, _semantic=_semantic)
|
| 831 |
+
|
| 832 |
+
|
| 833 |
+
@core.extern
|
| 834 |
+
def int_as_float(arg0, _semantic=None):
|
| 835 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 836 |
+
(core.dtype("int32"), ): ("__nv_int_as_float", core.dtype("fp32")),
|
| 837 |
+
}, is_pure=True, _semantic=_semantic)
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
@core.extern
|
| 841 |
+
def float_as_int(arg0, _semantic=None):
|
| 842 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 843 |
+
(core.dtype("fp32"), ): ("__nv_float_as_int", core.dtype("int32")),
|
| 844 |
+
}, is_pure=True, _semantic=_semantic)
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
@core.extern
|
| 848 |
+
def uint_as_float(arg0, _semantic=None):
|
| 849 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 850 |
+
(core.dtype("uint32"), ): ("__nv_uint_as_float", core.dtype("fp32")),
|
| 851 |
+
}, is_pure=True, _semantic=_semantic)
|
| 852 |
+
|
| 853 |
+
|
| 854 |
+
@core.extern
|
| 855 |
+
def float_as_uint(arg0, _semantic=None):
|
| 856 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 857 |
+
(core.dtype("fp32"), ): ("__nv_float_as_uint", core.dtype("int32")),
|
| 858 |
+
}, is_pure=True, _semantic=_semantic)
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
@core.extern
|
| 862 |
+
def longlong_as_double(arg0, _semantic=None):
|
| 863 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 864 |
+
(core.dtype("int64"), ): ("__nv_longlong_as_double", core.dtype("fp64")),
|
| 865 |
+
}, is_pure=True, _semantic=_semantic)
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
@core.extern
|
| 869 |
+
def double_as_longlong(arg0, _semantic=None):
|
| 870 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 871 |
+
(core.dtype("fp64"), ): ("__nv_double_as_longlong", core.dtype("int64")),
|
| 872 |
+
}, is_pure=True, _semantic=_semantic)
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
@core.extern
|
| 876 |
+
def fast_sinf(arg0, _semantic=None):
|
| 877 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 878 |
+
(core.dtype("fp32"), ): ("__nv_fast_sinf", core.dtype("fp32")),
|
| 879 |
+
}, is_pure=True, _semantic=_semantic)
|
| 880 |
+
|
| 881 |
+
|
| 882 |
+
@core.extern
|
| 883 |
+
def fast_cosf(arg0, _semantic=None):
|
| 884 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 885 |
+
(core.dtype("fp32"), ): ("__nv_fast_cosf", core.dtype("fp32")),
|
| 886 |
+
}, is_pure=True, _semantic=_semantic)
|
| 887 |
+
|
| 888 |
+
|
| 889 |
+
@core.extern
|
| 890 |
+
def fast_log2f(arg0, _semantic=None):
|
| 891 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 892 |
+
(core.dtype("fp32"), ): ("__nv_fast_log2f", core.dtype("fp32")),
|
| 893 |
+
}, is_pure=True, _semantic=_semantic)
|
| 894 |
+
|
| 895 |
+
|
| 896 |
+
@core.extern
|
| 897 |
+
def fast_logf(arg0, _semantic=None):
|
| 898 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 899 |
+
(core.dtype("fp32"), ): ("__nv_fast_logf", core.dtype("fp32")),
|
| 900 |
+
}, is_pure=True, _semantic=_semantic)
|
| 901 |
+
|
| 902 |
+
|
| 903 |
+
@core.extern
|
| 904 |
+
def fast_expf(arg0, _semantic=None):
|
| 905 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 906 |
+
(core.dtype("fp32"), ): ("__nv_fast_expf", core.dtype("fp32")),
|
| 907 |
+
}, is_pure=True, _semantic=_semantic)
|
| 908 |
+
|
| 909 |
+
|
| 910 |
+
@core.extern
|
| 911 |
+
def fast_tanf(arg0, _semantic=None):
|
| 912 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 913 |
+
(core.dtype("fp32"), ): ("__nv_fast_tanf", core.dtype("fp32")),
|
| 914 |
+
}, is_pure=True, _semantic=_semantic)
|
| 915 |
+
|
| 916 |
+
|
| 917 |
+
@core.extern
|
| 918 |
+
def fast_exp10f(arg0, _semantic=None):
|
| 919 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 920 |
+
(core.dtype("fp32"), ): ("__nv_fast_exp10f", core.dtype("fp32")),
|
| 921 |
+
}, is_pure=True, _semantic=_semantic)
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
@core.extern
|
| 925 |
+
def fast_log10f(arg0, _semantic=None):
|
| 926 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 927 |
+
(core.dtype("fp32"), ): ("__nv_fast_log10f", core.dtype("fp32")),
|
| 928 |
+
}, is_pure=True, _semantic=_semantic)
|
| 929 |
+
|
| 930 |
+
|
| 931 |
+
@core.extern
|
| 932 |
+
def fast_powf(arg0, arg1, _semantic=None):
|
| 933 |
+
return core.extern_elementwise("", "", [arg0, arg1], {
|
| 934 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_powf", core.dtype("fp32")),
|
| 935 |
+
}, is_pure=True, _semantic=_semantic)
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
@core.extern
|
| 939 |
+
def hadd(arg0, arg1, _semantic=None):
|
| 940 |
+
return core.extern_elementwise(
|
| 941 |
+
"", "", [arg0, arg1], {
|
| 942 |
+
(core.dtype("int32"), core.dtype("int32")): ("__nv_hadd", core.dtype("int32")),
|
| 943 |
+
(core.dtype("uint32"), core.dtype("uint32")): ("__nv_uhadd", core.dtype("uint32")),
|
| 944 |
+
}, is_pure=True, _semantic=_semantic)
|
| 945 |
+
|
| 946 |
+
|
| 947 |
+
@core.extern
|
| 948 |
+
def rhadd(arg0, arg1, _semantic=None):
|
| 949 |
+
return core.extern_elementwise(
|
| 950 |
+
"", "", [arg0, arg1], {
|
| 951 |
+
(core.dtype("int32"), core.dtype("int32")): ("__nv_rhadd", core.dtype("int32")),
|
| 952 |
+
(core.dtype("uint32"), core.dtype("uint32")): ("__nv_urhadd", core.dtype("uint32")),
|
| 953 |
+
}, is_pure=True, _semantic=_semantic)
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
@core.extern
|
| 957 |
+
def sub_rn(arg0, arg1, _semantic=None):
|
| 958 |
+
return core.extern_elementwise(
|
| 959 |
+
"", "", [arg0, arg1], {
|
| 960 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rn", core.dtype("fp32")),
|
| 961 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rn", core.dtype("fp64")),
|
| 962 |
+
}, is_pure=True, _semantic=_semantic)
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
@core.extern
|
| 966 |
+
def sub_rz(arg0, arg1, _semantic=None):
|
| 967 |
+
return core.extern_elementwise(
|
| 968 |
+
"", "", [arg0, arg1], {
|
| 969 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rz", core.dtype("fp32")),
|
| 970 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rz", core.dtype("fp64")),
|
| 971 |
+
}, is_pure=True, _semantic=_semantic)
|
| 972 |
+
|
| 973 |
+
|
| 974 |
+
@core.extern
|
| 975 |
+
def sub_rd(arg0, arg1, _semantic=None):
|
| 976 |
+
return core.extern_elementwise(
|
| 977 |
+
"", "", [arg0, arg1], {
|
| 978 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rd", core.dtype("fp32")),
|
| 979 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rd", core.dtype("fp64")),
|
| 980 |
+
}, is_pure=True, _semantic=_semantic)
|
| 981 |
+
|
| 982 |
+
|
| 983 |
+
@core.extern
|
| 984 |
+
def sub_ru(arg0, arg1, _semantic=None):
|
| 985 |
+
return core.extern_elementwise(
|
| 986 |
+
"", "", [arg0, arg1], {
|
| 987 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_ru", core.dtype("fp32")),
|
| 988 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_ru", core.dtype("fp64")),
|
| 989 |
+
}, is_pure=True, _semantic=_semantic)
|
| 990 |
+
|
| 991 |
+
|
| 992 |
+
@core.extern
|
| 993 |
+
def rsqrt_rn(arg0, _semantic=None):
|
| 994 |
+
return core.extern_elementwise("", "", [
|
| 995 |
+
arg0,
|
| 996 |
+
], {
|
| 997 |
+
(core.dtype("fp32"), ): ("__nv_frsqrt_rn", core.dtype("fp32")),
|
| 998 |
+
}, is_pure=True, _semantic=_semantic)
|
| 999 |
+
|
| 1000 |
+
|
| 1001 |
+
@core.extern
|
| 1002 |
+
def ffs(arg0, _semantic=None):
|
| 1003 |
+
return core.extern_elementwise(
|
| 1004 |
+
"", "", [
|
| 1005 |
+
arg0,
|
| 1006 |
+
], {
|
| 1007 |
+
(core.dtype("int32"), ): ("__nv_ffs", core.dtype("int32")),
|
| 1008 |
+
(core.dtype("int64"), ): ("__nv_ffsll", core.dtype("int32")),
|
| 1009 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
@core.extern
|
| 1013 |
+
def rint(arg0, _semantic=None):
|
| 1014 |
+
return core.extern_elementwise(
|
| 1015 |
+
"", "", [
|
| 1016 |
+
arg0,
|
| 1017 |
+
], {
|
| 1018 |
+
(core.dtype("fp32"), ): ("__nv_rintf", core.dtype("fp32")),
|
| 1019 |
+
(core.dtype("fp64"), ): ("__nv_rint", core.dtype("fp64")),
|
| 1020 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1021 |
+
|
| 1022 |
+
|
| 1023 |
+
@core.extern
|
| 1024 |
+
def llrint(arg0, _semantic=None):
|
| 1025 |
+
return core.extern_elementwise(
|
| 1026 |
+
"", "", [
|
| 1027 |
+
arg0,
|
| 1028 |
+
], {
|
| 1029 |
+
(core.dtype("fp32"), ): ("__nv_llrintf", core.dtype("int64")),
|
| 1030 |
+
(core.dtype("fp64"), ): ("__nv_llrint", core.dtype("int64")),
|
| 1031 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1032 |
+
|
| 1033 |
+
|
| 1034 |
+
@core.extern
|
| 1035 |
+
def nearbyint(arg0, _semantic=None):
|
| 1036 |
+
return core.extern_elementwise(
|
| 1037 |
+
"", "", [
|
| 1038 |
+
arg0,
|
| 1039 |
+
], {
|
| 1040 |
+
(core.dtype("fp32"), ): ("__nv_nearbyintf", core.dtype("fp32")),
|
| 1041 |
+
(core.dtype("fp64"), ): ("__nv_nearbyint", core.dtype("fp64")),
|
| 1042 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1043 |
+
|
| 1044 |
+
|
| 1045 |
+
@core.extern
|
| 1046 |
+
def isnan(arg0, _semantic=None):
|
| 1047 |
+
return core.extern_elementwise(
|
| 1048 |
+
"", "", [
|
| 1049 |
+
arg0,
|
| 1050 |
+
], {
|
| 1051 |
+
(core.dtype("fp32"), ): ("__nv_isnanf", core.dtype("int32")),
|
| 1052 |
+
(core.dtype("fp64"), ): ("__nv_isnand", core.dtype("int32")),
|
| 1053 |
+
}, is_pure=True, _semantic=_semantic).to(core.int1, _semantic=_semantic)
|
| 1054 |
+
|
| 1055 |
+
|
| 1056 |
+
@core.extern
|
| 1057 |
+
def signbit(arg0, _semantic=None):
|
| 1058 |
+
return core.extern_elementwise(
|
| 1059 |
+
"", "", [
|
| 1060 |
+
arg0,
|
| 1061 |
+
], {
|
| 1062 |
+
(core.dtype("fp32"), ): ("__nv_signbitf", core.dtype("int32")),
|
| 1063 |
+
(core.dtype("fp64"), ): ("__nv_signbitd", core.dtype("int32")),
|
| 1064 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1065 |
+
|
| 1066 |
+
|
| 1067 |
+
@core.extern
|
| 1068 |
+
def copysign(arg0, arg1, _semantic=None):
|
| 1069 |
+
return core.extern_elementwise(
|
| 1070 |
+
"", "", [arg0, arg1], {
|
| 1071 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_copysignf", core.dtype("fp32")),
|
| 1072 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_copysign", core.dtype("fp64")),
|
| 1073 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1074 |
+
|
| 1075 |
+
|
| 1076 |
+
@core.extern
|
| 1077 |
+
def finitef(arg0, _semantic=None):
|
| 1078 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 1079 |
+
(core.dtype("fp32"), ): ("__nv_finitef", core.dtype("int32")),
|
| 1080 |
+
}, is_pure=True, _semantic=_semantic).to(core.int1, _semantic=_semantic)
|
| 1081 |
+
|
| 1082 |
+
|
| 1083 |
+
@core.extern
|
| 1084 |
+
def isinf(arg0, _semantic=None):
|
| 1085 |
+
return core.extern_elementwise(
|
| 1086 |
+
"", "", [arg0], {
|
| 1087 |
+
(core.dtype("fp32"), ): ("__nv_isinff", core.dtype("int32")),
|
| 1088 |
+
(core.dtype("fp64"), ): ("__nv_isinfd", core.dtype("int32")),
|
| 1089 |
+
}, is_pure=True, _semantic=_semantic).to(core.int1, _semantic=_semantic)
|
| 1090 |
+
|
| 1091 |
+
|
| 1092 |
+
@core.extern
|
| 1093 |
+
def nextafter(arg0, arg1, _semantic=None):
|
| 1094 |
+
return core.extern_elementwise(
|
| 1095 |
+
"", "", [arg0, arg1], {
|
| 1096 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_nextafterf", core.dtype("fp32")),
|
| 1097 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_nextafter", core.dtype("fp64")),
|
| 1098 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1099 |
+
|
| 1100 |
+
|
| 1101 |
+
@core.extern
|
| 1102 |
+
def sin(arg0, _semantic=None):
|
| 1103 |
+
return core.extern_elementwise(
|
| 1104 |
+
"", "", [arg0], {
|
| 1105 |
+
(core.dtype("fp32"), ): ("__nv_sinf", core.dtype("fp32")),
|
| 1106 |
+
(core.dtype("fp64"), ): ("__nv_sin", core.dtype("fp64")),
|
| 1107 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1108 |
+
|
| 1109 |
+
|
| 1110 |
+
@core.extern
|
| 1111 |
+
def cos(arg0, _semantic=None):
|
| 1112 |
+
return core.extern_elementwise(
|
| 1113 |
+
"", "", [arg0], {
|
| 1114 |
+
(core.dtype("fp32"), ): ("__nv_cosf", core.dtype("fp32")),
|
| 1115 |
+
(core.dtype("fp64"), ): ("__nv_cos", core.dtype("fp64")),
|
| 1116 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1117 |
+
|
| 1118 |
+
|
| 1119 |
+
@core.extern
|
| 1120 |
+
def sinpi(arg0, _semantic=None):
|
| 1121 |
+
return core.extern_elementwise(
|
| 1122 |
+
"", "", [arg0], {
|
| 1123 |
+
(core.dtype("fp32"), ): ("__nv_sinpif", core.dtype("fp32")),
|
| 1124 |
+
(core.dtype("fp64"), ): ("__nv_sinpi", core.dtype("fp64")),
|
| 1125 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1126 |
+
|
| 1127 |
+
|
| 1128 |
+
@core.extern
|
| 1129 |
+
def cospi(arg0, _semantic=None):
|
| 1130 |
+
return core.extern_elementwise(
|
| 1131 |
+
"", "", [arg0], {
|
| 1132 |
+
(core.dtype("fp32"), ): ("__nv_cospif", core.dtype("fp32")),
|
| 1133 |
+
(core.dtype("fp64"), ): ("__nv_cospi", core.dtype("fp64")),
|
| 1134 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1135 |
+
|
| 1136 |
+
|
| 1137 |
+
@core.extern
|
| 1138 |
+
def tan(arg0, _semantic=None):
|
| 1139 |
+
return core.extern_elementwise(
|
| 1140 |
+
"", "", [arg0], {
|
| 1141 |
+
(core.dtype("fp32"), ): ("__nv_tanf", core.dtype("fp32")),
|
| 1142 |
+
(core.dtype("fp64"), ): ("__nv_tan", core.dtype("fp64")),
|
| 1143 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1144 |
+
|
| 1145 |
+
|
| 1146 |
+
@core.extern
|
| 1147 |
+
def log2(arg0, _semantic=None):
|
| 1148 |
+
return core.extern_elementwise(
|
| 1149 |
+
"", "", [arg0], {
|
| 1150 |
+
(core.dtype("fp32"), ): ("__nv_log2f", core.dtype("fp32")),
|
| 1151 |
+
(core.dtype("fp64"), ): ("__nv_log2", core.dtype("fp64")),
|
| 1152 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1153 |
+
|
| 1154 |
+
|
| 1155 |
+
@core.extern
|
| 1156 |
+
def exp(arg0, _semantic=None):
|
| 1157 |
+
return core.extern_elementwise(
|
| 1158 |
+
"", "", [arg0], {
|
| 1159 |
+
(core.dtype("fp32"), ): ("__nv_expf", core.dtype("fp32")),
|
| 1160 |
+
(core.dtype("fp64"), ): ("__nv_exp", core.dtype("fp64")),
|
| 1161 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1162 |
+
|
| 1163 |
+
|
| 1164 |
+
@core.extern
|
| 1165 |
+
def exp10(arg0, _semantic=None):
|
| 1166 |
+
return core.extern_elementwise(
|
| 1167 |
+
"", "", [arg0], {
|
| 1168 |
+
(core.dtype("fp32"), ): ("__nv_exp10f", core.dtype("fp32")),
|
| 1169 |
+
(core.dtype("fp64"), ): ("__nv_exp10", core.dtype("fp64")),
|
| 1170 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1171 |
+
|
| 1172 |
+
|
| 1173 |
+
@core.extern
|
| 1174 |
+
def cosh(arg0, _semantic=None):
|
| 1175 |
+
return core.extern_elementwise(
|
| 1176 |
+
"", "", [arg0], {
|
| 1177 |
+
(core.dtype("fp32"), ): ("__nv_coshf", core.dtype("fp32")),
|
| 1178 |
+
(core.dtype("fp64"), ): ("__nv_cosh", core.dtype("fp64")),
|
| 1179 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1180 |
+
|
| 1181 |
+
|
| 1182 |
+
@core.extern
|
| 1183 |
+
def sinh(arg0, _semantic=None):
|
| 1184 |
+
return core.extern_elementwise(
|
| 1185 |
+
"", "", [arg0], {
|
| 1186 |
+
(core.dtype("fp32"), ): ("__nv_sinhf", core.dtype("fp32")),
|
| 1187 |
+
(core.dtype("fp64"), ): ("__nv_sinh", core.dtype("fp64")),
|
| 1188 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1189 |
+
|
| 1190 |
+
|
| 1191 |
+
@core.extern
|
| 1192 |
+
def tanh(arg0, _semantic=None):
|
| 1193 |
+
return core.extern_elementwise(
|
| 1194 |
+
"", "", [arg0], {
|
| 1195 |
+
(core.dtype("fp32"), ): ("__nv_tanhf", core.dtype("fp32")),
|
| 1196 |
+
(core.dtype("fp64"), ): ("__nv_tanh", core.dtype("fp64")),
|
| 1197 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1198 |
+
|
| 1199 |
+
|
| 1200 |
+
@core.extern
|
| 1201 |
+
def atan2(arg0, arg1, _semantic=None):
|
| 1202 |
+
return core.extern_elementwise(
|
| 1203 |
+
"", "", [arg0, arg1], {
|
| 1204 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_atan2f", core.dtype("fp32")),
|
| 1205 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_atan2", core.dtype("fp64")),
|
| 1206 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1207 |
+
|
| 1208 |
+
|
| 1209 |
+
@core.extern
|
| 1210 |
+
def atan(arg0, _semantic=None):
|
| 1211 |
+
return core.extern_elementwise(
|
| 1212 |
+
"", "", [arg0], {
|
| 1213 |
+
(core.dtype("fp32"), ): ("__nv_atanf", core.dtype("fp32")),
|
| 1214 |
+
(core.dtype("fp64"), ): ("__nv_atan", core.dtype("fp64")),
|
| 1215 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1216 |
+
|
| 1217 |
+
|
| 1218 |
+
@core.extern
|
| 1219 |
+
def asin(arg0, _semantic=None):
|
| 1220 |
+
return core.extern_elementwise(
|
| 1221 |
+
"", "", [arg0], {
|
| 1222 |
+
(core.dtype("fp32"), ): ("__nv_asinf", core.dtype("fp32")),
|
| 1223 |
+
(core.dtype("fp64"), ): ("__nv_asin", core.dtype("fp64")),
|
| 1224 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1225 |
+
|
| 1226 |
+
|
| 1227 |
+
@core.extern
|
| 1228 |
+
def acos(arg0, _semantic=None):
|
| 1229 |
+
return core.extern_elementwise(
|
| 1230 |
+
"", "", [arg0], {
|
| 1231 |
+
(core.dtype("fp32"), ): ("__nv_acosf", core.dtype("fp32")),
|
| 1232 |
+
(core.dtype("fp64"), ): ("__nv_acos", core.dtype("fp64")),
|
| 1233 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1234 |
+
|
| 1235 |
+
|
| 1236 |
+
@core.extern
|
| 1237 |
+
def log(arg0, _semantic=None):
|
| 1238 |
+
return core.extern_elementwise(
|
| 1239 |
+
"", "", [arg0], {
|
| 1240 |
+
(core.dtype("fp32"), ): ("__nv_logf", core.dtype("fp32")),
|
| 1241 |
+
(core.dtype("fp64"), ): ("__nv_log", core.dtype("fp64")),
|
| 1242 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1243 |
+
|
| 1244 |
+
|
| 1245 |
+
@core.extern
|
| 1246 |
+
def log10(arg0, _semantic=None):
|
| 1247 |
+
return core.extern_elementwise(
|
| 1248 |
+
"", "", [arg0], {
|
| 1249 |
+
(core.dtype("fp32"), ): ("__nv_log10f", core.dtype("fp32")),
|
| 1250 |
+
(core.dtype("fp64"), ): ("__nv_log10", core.dtype("fp64")),
|
| 1251 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1252 |
+
|
| 1253 |
+
|
| 1254 |
+
@core.extern
|
| 1255 |
+
def log1p(arg0, _semantic=None):
|
| 1256 |
+
return core.extern_elementwise(
|
| 1257 |
+
"", "", [arg0], {
|
| 1258 |
+
(core.dtype("fp32"), ): ("__nv_log1pf", core.dtype("fp32")),
|
| 1259 |
+
(core.dtype("fp64"), ): ("__nv_log1p", core.dtype("fp64")),
|
| 1260 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1261 |
+
|
| 1262 |
+
|
| 1263 |
+
@core.extern
|
| 1264 |
+
def acosh(arg0, _semantic=None):
|
| 1265 |
+
return core.extern_elementwise(
|
| 1266 |
+
"", "", [arg0], {
|
| 1267 |
+
(core.dtype("fp32"), ): ("__nv_acoshf", core.dtype("fp32")),
|
| 1268 |
+
(core.dtype("fp64"), ): ("__nv_acosh", core.dtype("fp64")),
|
| 1269 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1270 |
+
|
| 1271 |
+
|
| 1272 |
+
@core.extern
|
| 1273 |
+
def asinh(arg0, _semantic=None):
|
| 1274 |
+
return core.extern_elementwise(
|
| 1275 |
+
"", "", [arg0], {
|
| 1276 |
+
(core.dtype("fp32"), ): ("__nv_asinhf", core.dtype("fp32")),
|
| 1277 |
+
(core.dtype("fp64"), ): ("__nv_asinh", core.dtype("fp64")),
|
| 1278 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1279 |
+
|
| 1280 |
+
|
| 1281 |
+
@core.extern
|
| 1282 |
+
def atanh(arg0, _semantic=None):
|
| 1283 |
+
return core.extern_elementwise(
|
| 1284 |
+
"", "", [arg0], {
|
| 1285 |
+
(core.dtype("fp32"), ): ("__nv_atanhf", core.dtype("fp32")),
|
| 1286 |
+
(core.dtype("fp64"), ): ("__nv_atanh", core.dtype("fp64")),
|
| 1287 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1288 |
+
|
| 1289 |
+
|
| 1290 |
+
@core.extern
|
| 1291 |
+
def expm1(arg0, _semantic=None):
|
| 1292 |
+
return core.extern_elementwise(
|
| 1293 |
+
"", "", [arg0], {
|
| 1294 |
+
(core.dtype("fp32"), ): ("__nv_expm1f", core.dtype("fp32")),
|
| 1295 |
+
(core.dtype("fp64"), ): ("__nv_expm1", core.dtype("fp64")),
|
| 1296 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1297 |
+
|
| 1298 |
+
|
| 1299 |
+
@core.extern
|
| 1300 |
+
def hypot(arg0, arg1, _semantic=None):
|
| 1301 |
+
return core.extern_elementwise(
|
| 1302 |
+
"", "", [arg0, arg1], {
|
| 1303 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_hypotf", core.dtype("fp32")),
|
| 1304 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_hypot", core.dtype("fp64")),
|
| 1305 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1306 |
+
|
| 1307 |
+
|
| 1308 |
+
@core.extern
|
| 1309 |
+
def rhypot(arg0, arg1, _semantic=None):
|
| 1310 |
+
return core.extern_elementwise(
|
| 1311 |
+
"", "", [arg0, arg1], {
|
| 1312 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_rhypotf", core.dtype("fp32")),
|
| 1313 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_rhypot", core.dtype("fp64")),
|
| 1314 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1315 |
+
|
| 1316 |
+
|
| 1317 |
+
@core.extern
|
| 1318 |
+
def norm3d(arg0, arg1, arg2, _semantic=None):
|
| 1319 |
+
return core.extern_elementwise(
|
| 1320 |
+
"", "", [arg0, arg1, arg2], {
|
| 1321 |
+
(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_norm3df", core.dtype("fp32")),
|
| 1322 |
+
(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_norm3d", core.dtype("fp64")),
|
| 1323 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1324 |
+
|
| 1325 |
+
|
| 1326 |
+
@core.extern
|
| 1327 |
+
def rnorm3d(arg0, arg1, arg2, _semantic=None):
|
| 1328 |
+
return core.extern_elementwise(
|
| 1329 |
+
"", "", [arg0, arg1, arg2], {
|
| 1330 |
+
(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_rnorm3df", core.dtype("fp32")),
|
| 1331 |
+
(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_rnorm3d", core.dtype("fp64")),
|
| 1332 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1333 |
+
|
| 1334 |
+
|
| 1335 |
+
@core.extern
|
| 1336 |
+
def norm4d(arg0, arg1, arg2, arg3, _semantic=None):
|
| 1337 |
+
return core.extern_elementwise(
|
| 1338 |
+
"", "", [arg0, arg1, arg2, arg3], {
|
| 1339 |
+
(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")):
|
| 1340 |
+
("__nv_norm4df", core.dtype("fp32")),
|
| 1341 |
+
(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")):
|
| 1342 |
+
("__nv_norm4d", core.dtype("fp64")),
|
| 1343 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1344 |
+
|
| 1345 |
+
|
| 1346 |
+
@core.extern
|
| 1347 |
+
def rnorm4d(arg0, arg1, arg2, arg3, _semantic=None):
|
| 1348 |
+
return core.extern_elementwise(
|
| 1349 |
+
"", "", [arg0, arg1, arg2, arg3], {
|
| 1350 |
+
(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")):
|
| 1351 |
+
("__nv_rnorm4df", core.dtype("fp32")),
|
| 1352 |
+
(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")):
|
| 1353 |
+
("__nv_rnorm4d", core.dtype("fp64")),
|
| 1354 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1355 |
+
|
| 1356 |
+
|
| 1357 |
+
@core.extern
|
| 1358 |
+
def cbrt(arg0, _semantic=None):
|
| 1359 |
+
return core.extern_elementwise(
|
| 1360 |
+
"", "", [arg0], {
|
| 1361 |
+
(core.dtype("fp32"), ): ("__nv_cbrtf", core.dtype("fp32")),
|
| 1362 |
+
(core.dtype("fp64"), ): ("__nv_cbrt", core.dtype("fp64")),
|
| 1363 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1364 |
+
|
| 1365 |
+
|
| 1366 |
+
@core.extern
|
| 1367 |
+
def rcbrt(arg0, _semantic=None):
|
| 1368 |
+
return core.extern_elementwise(
|
| 1369 |
+
"", "", [arg0], {
|
| 1370 |
+
(core.dtype("fp32"), ): ("__nv_rcbrtf", core.dtype("fp32")),
|
| 1371 |
+
(core.dtype("fp64"), ): ("__nv_rcbrt", core.dtype("fp64")),
|
| 1372 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1373 |
+
|
| 1374 |
+
|
| 1375 |
+
@core.extern
|
| 1376 |
+
def j0(arg0, _semantic=None):
|
| 1377 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 1378 |
+
(core.dtype("fp32"), ): ("__nv_j0f", core.dtype("fp32")),
|
| 1379 |
+
(core.dtype("fp64"), ): ("__nv_j0", core.dtype("fp64")),
|
| 1380 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1381 |
+
|
| 1382 |
+
|
| 1383 |
+
@core.extern
|
| 1384 |
+
def j1(arg0, _semantic=None):
|
| 1385 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 1386 |
+
(core.dtype("fp32"), ): ("__nv_j1f", core.dtype("fp32")),
|
| 1387 |
+
(core.dtype("fp64"), ): ("__nv_j1", core.dtype("fp64")),
|
| 1388 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1389 |
+
|
| 1390 |
+
|
| 1391 |
+
@core.extern
|
| 1392 |
+
def y0(arg0, _semantic=None):
|
| 1393 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 1394 |
+
(core.dtype("fp32"), ): ("__nv_y0f", core.dtype("fp32")),
|
| 1395 |
+
(core.dtype("fp64"), ): ("__nv_y0", core.dtype("fp64")),
|
| 1396 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1397 |
+
|
| 1398 |
+
|
| 1399 |
+
@core.extern
|
| 1400 |
+
def y1(arg0, _semantic=None):
|
| 1401 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 1402 |
+
(core.dtype("fp32"), ): ("__nv_y1f", core.dtype("fp32")),
|
| 1403 |
+
(core.dtype("fp64"), ): ("__nv_y1", core.dtype("fp64")),
|
| 1404 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1405 |
+
|
| 1406 |
+
|
| 1407 |
+
@core.extern
|
| 1408 |
+
def yn(arg0, arg1, _semantic=None):
|
| 1409 |
+
return core.extern_elementwise(
|
| 1410 |
+
"", "", [arg0, arg1], {
|
| 1411 |
+
(core.dtype("int32"), core.dtype("fp32")): ("__nv_ynf", core.dtype("fp32")),
|
| 1412 |
+
(core.dtype("int32"), core.dtype("fp64")): ("__nv_yn", core.dtype("fp64")),
|
| 1413 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1414 |
+
|
| 1415 |
+
|
| 1416 |
+
@core.extern
|
| 1417 |
+
def jn(arg0, arg1, _semantic=None):
|
| 1418 |
+
return core.extern_elementwise(
|
| 1419 |
+
"", "", [arg0, arg1], {
|
| 1420 |
+
(core.dtype("int32"), core.dtype("fp32")): ("__nv_jnf", core.dtype("fp32")),
|
| 1421 |
+
(core.dtype("int32"), core.dtype("fp64")): ("__nv_jn", core.dtype("fp64")),
|
| 1422 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1423 |
+
|
| 1424 |
+
|
| 1425 |
+
@core.extern
|
| 1426 |
+
def cyl_bessel_i0(arg0, _semantic=None):
|
| 1427 |
+
return core.extern_elementwise(
|
| 1428 |
+
"", "", [arg0], {
|
| 1429 |
+
(core.dtype("fp32"), ): ("__nv_cyl_bessel_i0f", core.dtype("fp32")),
|
| 1430 |
+
(core.dtype("fp64"), ): ("__nv_cyl_bessel_i0", core.dtype("fp64")),
|
| 1431 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1432 |
+
|
| 1433 |
+
|
| 1434 |
+
@core.extern
|
| 1435 |
+
def cyl_bessel_i1(arg0, _semantic=None):
|
| 1436 |
+
return core.extern_elementwise(
|
| 1437 |
+
"", "", [arg0], {
|
| 1438 |
+
(core.dtype("fp32"), ): ("__nv_cyl_bessel_i1f", core.dtype("fp32")),
|
| 1439 |
+
(core.dtype("fp64"), ): ("__nv_cyl_bessel_i1", core.dtype("fp64")),
|
| 1440 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1441 |
+
|
| 1442 |
+
|
| 1443 |
+
@core.extern
|
| 1444 |
+
def erf(arg0, _semantic=None):
|
| 1445 |
+
return core.extern_elementwise(
|
| 1446 |
+
"", "", [arg0], {
|
| 1447 |
+
(core.dtype("fp32"), ): ("__nv_erff", core.dtype("fp32")),
|
| 1448 |
+
(core.dtype("fp64"), ): ("__nv_erf", core.dtype("fp64")),
|
| 1449 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1450 |
+
|
| 1451 |
+
|
| 1452 |
+
@core.extern
|
| 1453 |
+
def erfinv(arg0, _semantic=None):
|
| 1454 |
+
return core.extern_elementwise(
|
| 1455 |
+
"", "", [arg0], {
|
| 1456 |
+
(core.dtype("fp32"), ): ("__nv_erfinvf", core.dtype("fp32")),
|
| 1457 |
+
(core.dtype("fp64"), ): ("__nv_erfinv", core.dtype("fp64")),
|
| 1458 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1459 |
+
|
| 1460 |
+
|
| 1461 |
+
@core.extern
|
| 1462 |
+
def erfc(arg0, _semantic=None):
|
| 1463 |
+
return core.extern_elementwise(
|
| 1464 |
+
"", "", [arg0], {
|
| 1465 |
+
(core.dtype("fp32"), ): ("__nv_erfcf", core.dtype("fp32")),
|
| 1466 |
+
(core.dtype("fp64"), ): ("__nv_erfc", core.dtype("fp64")),
|
| 1467 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1468 |
+
|
| 1469 |
+
|
| 1470 |
+
@core.extern
|
| 1471 |
+
def erfcx(arg0, _semantic=None):
|
| 1472 |
+
return core.extern_elementwise(
|
| 1473 |
+
"", "", [arg0], {
|
| 1474 |
+
(core.dtype("fp32"), ): ("__nv_erfcxf", core.dtype("fp32")),
|
| 1475 |
+
(core.dtype("fp64"), ): ("__nv_erfcx", core.dtype("fp64")),
|
| 1476 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1477 |
+
|
| 1478 |
+
|
| 1479 |
+
@core.extern
|
| 1480 |
+
def erfcinv(arg0, _semantic=None):
|
| 1481 |
+
return core.extern_elementwise(
|
| 1482 |
+
"", "", [arg0], {
|
| 1483 |
+
(core.dtype("fp32"), ): ("__nv_erfcinvf", core.dtype("fp32")),
|
| 1484 |
+
(core.dtype("fp64"), ): ("__nv_erfcinv", core.dtype("fp64")),
|
| 1485 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1486 |
+
|
| 1487 |
+
|
| 1488 |
+
@core.extern
|
| 1489 |
+
def normcdfinv(arg0, _semantic=None):
|
| 1490 |
+
return core.extern_elementwise(
|
| 1491 |
+
"", "", [arg0], {
|
| 1492 |
+
(core.dtype("fp32"), ): ("__nv_normcdfinvf", core.dtype("fp32")),
|
| 1493 |
+
(core.dtype("fp64"), ): ("__nv_normcdfinv", core.dtype("fp64")),
|
| 1494 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1495 |
+
|
| 1496 |
+
|
| 1497 |
+
@core.extern
|
| 1498 |
+
def normcdf(arg0, _semantic=None):
|
| 1499 |
+
return core.extern_elementwise(
|
| 1500 |
+
"", "", [arg0], {
|
| 1501 |
+
(core.dtype("fp32"), ): ("__nv_normcdff", core.dtype("fp32")),
|
| 1502 |
+
(core.dtype("fp64"), ): ("__nv_normcdf", core.dtype("fp64")),
|
| 1503 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1504 |
+
|
| 1505 |
+
|
| 1506 |
+
@core.extern
|
| 1507 |
+
def lgamma(arg0, _semantic=None):
|
| 1508 |
+
return core.extern_elementwise(
|
| 1509 |
+
"", "", [arg0], {
|
| 1510 |
+
(core.dtype("fp32"), ): ("__nv_lgammaf", core.dtype("fp32")),
|
| 1511 |
+
(core.dtype("fp64"), ): ("__nv_lgamma", core.dtype("fp64")),
|
| 1512 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1513 |
+
|
| 1514 |
+
|
| 1515 |
+
@core.extern
|
| 1516 |
+
def ldexp(arg0, arg1, _semantic=None):
|
| 1517 |
+
return core.extern_elementwise(
|
| 1518 |
+
"", "", [arg0, arg1], {
|
| 1519 |
+
(core.dtype("fp32"), core.dtype("int32")): ("__nv_ldexpf", core.dtype("fp32")),
|
| 1520 |
+
(core.dtype("fp64"), core.dtype("int32")): ("__nv_ldexp", core.dtype("fp64")),
|
| 1521 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1522 |
+
|
| 1523 |
+
|
| 1524 |
+
@core.extern
|
| 1525 |
+
def scalbn(arg0, arg1, _semantic=None):
|
| 1526 |
+
return core.extern_elementwise(
|
| 1527 |
+
"", "", [arg0, arg1], {
|
| 1528 |
+
(core.dtype("fp32"), core.dtype("int32")): ("__nv_scalbnf", core.dtype("fp32")),
|
| 1529 |
+
(core.dtype("fp64"), core.dtype("int32")): ("__nv_scalbn", core.dtype("fp64")),
|
| 1530 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1531 |
+
|
| 1532 |
+
|
| 1533 |
+
@core.extern
|
| 1534 |
+
def fmod(arg0, arg1, _semantic=None):
|
| 1535 |
+
return core.extern_elementwise(
|
| 1536 |
+
"", "", [arg0, arg1], {
|
| 1537 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmodf", core.dtype("fp32")),
|
| 1538 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_fmod", core.dtype("fp64")),
|
| 1539 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1540 |
+
|
| 1541 |
+
|
| 1542 |
+
@core.extern
|
| 1543 |
+
def remainder(arg0, arg1, _semantic=None):
|
| 1544 |
+
return core.extern_elementwise(
|
| 1545 |
+
"", "", [arg0, arg1], {
|
| 1546 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_remainderf", core.dtype("fp32")),
|
| 1547 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_remainder", core.dtype("fp64")),
|
| 1548 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1549 |
+
|
| 1550 |
+
|
| 1551 |
+
@core.extern
|
| 1552 |
+
def fma(arg0, arg1, arg2, _semantic=None):
|
| 1553 |
+
return core.extern_elementwise(
|
| 1554 |
+
"", "", [arg0, arg1, arg2], {
|
| 1555 |
+
(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf", core.dtype("fp32")),
|
| 1556 |
+
(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma", core.dtype("fp64")),
|
| 1557 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1558 |
+
|
| 1559 |
+
|
| 1560 |
+
@core.extern
|
| 1561 |
+
def pow(arg0, arg1, _semantic=None):
|
| 1562 |
+
return core.extern_elementwise(
|
| 1563 |
+
"", "", [arg0, arg1], {
|
| 1564 |
+
(core.dtype("fp32"), core.dtype("int32")): ("__nv_powif", core.dtype("fp32")),
|
| 1565 |
+
(core.dtype("fp64"), core.dtype("int32")): ("__nv_powi", core.dtype("fp64")),
|
| 1566 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_powf", core.dtype("fp32")),
|
| 1567 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_pow", core.dtype("fp64")),
|
| 1568 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1569 |
+
|
| 1570 |
+
|
| 1571 |
+
@core.extern
|
| 1572 |
+
def tgamma(arg0, _semantic=None):
|
| 1573 |
+
return core.extern_elementwise(
|
| 1574 |
+
"", "", [arg0], {
|
| 1575 |
+
(core.dtype("fp32"), ): ("__nv_tgammaf", core.dtype("fp32")),
|
| 1576 |
+
(core.dtype("fp64"), ): ("__nv_tgamma", core.dtype("fp64")),
|
| 1577 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1578 |
+
|
| 1579 |
+
|
| 1580 |
+
@core.extern
|
| 1581 |
+
def round(arg0, _semantic=None):
|
| 1582 |
+
return core.extern_elementwise(
|
| 1583 |
+
"", "", [arg0], {
|
| 1584 |
+
(core.dtype("fp32"), ): ("__nv_roundf", core.dtype("fp32")),
|
| 1585 |
+
(core.dtype("fp64"), ): ("__nv_round", core.dtype("fp64")),
|
| 1586 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1587 |
+
|
| 1588 |
+
|
| 1589 |
+
@core.extern
|
| 1590 |
+
def llround(arg0, _semantic=None):
|
| 1591 |
+
return core.extern_elementwise(
|
| 1592 |
+
"", "", [arg0], {
|
| 1593 |
+
(core.dtype("fp32"), ): ("__nv_llroundf", core.dtype("int64")),
|
| 1594 |
+
(core.dtype("fp64"), ): ("__nv_llround", core.dtype("int64")),
|
| 1595 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1596 |
+
|
| 1597 |
+
|
| 1598 |
+
@core.extern
|
| 1599 |
+
def fdim(arg0, arg1, _semantic=None):
|
| 1600 |
+
return core.extern_elementwise(
|
| 1601 |
+
"", "", [arg0, arg1], {
|
| 1602 |
+
(core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdimf", core.dtype("fp32")),
|
| 1603 |
+
(core.dtype("fp64"), core.dtype("fp64")): ("__nv_fdim", core.dtype("fp64")),
|
| 1604 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1605 |
+
|
| 1606 |
+
|
| 1607 |
+
@core.extern
|
| 1608 |
+
def ilogb(arg0, _semantic=None):
|
| 1609 |
+
return core.extern_elementwise(
|
| 1610 |
+
"", "", [arg0], {
|
| 1611 |
+
(core.dtype("fp32"), ): ("__nv_ilogbf", core.dtype("int32")),
|
| 1612 |
+
(core.dtype("fp64"), ): ("__nv_ilogb", core.dtype("int32")),
|
| 1613 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1614 |
+
|
| 1615 |
+
|
| 1616 |
+
@core.extern
|
| 1617 |
+
def logb(arg0, _semantic=None):
|
| 1618 |
+
return core.extern_elementwise(
|
| 1619 |
+
"", "", [arg0], {
|
| 1620 |
+
(core.dtype("fp32"), ): ("__nv_logbf", core.dtype("fp32")),
|
| 1621 |
+
(core.dtype("fp64"), ): ("__nv_logb", core.dtype("fp64")),
|
| 1622 |
+
}, is_pure=True, _semantic=_semantic)
|
| 1623 |
+
|
| 1624 |
+
|
| 1625 |
+
@core.extern
|
| 1626 |
+
def isfinited(arg0, _semantic=None):
|
| 1627 |
+
return core.extern_elementwise("", "", [arg0], {
|
| 1628 |
+
(core.dtype("fp64"), ): ("__nv_isfinited", core.dtype("int32")),
|
| 1629 |
+
}, is_pure=True, _semantic=_semantic).to(core.int1, _semantic=_semantic)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/utils.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from triton.language import core
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@core.extern
|
| 5 |
+
def globaltimer(_semantic=None):
|
| 6 |
+
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1,
|
| 7 |
+
_semantic=_semantic)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@core.extern
|
| 11 |
+
def smid(_semantic=None):
|
| 12 |
+
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1,
|
| 13 |
+
_semantic=_semantic)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@core.builtin
|
| 17 |
+
def num_threads(_semantic=None):
|
| 18 |
+
return core.constexpr(_semantic.builder.options.num_warps * 32)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@core.builtin
|
| 22 |
+
def num_warps(_semantic=None):
|
| 23 |
+
return core.constexpr(_semantic.builder.options.num_warps)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ----- FP8E4M3B15 ------
|
| 27 |
+
# This data-type is a variant of the standard FP8E4M3 format.
|
| 28 |
+
# It was designed for fast software conversion to FP16 on
|
| 29 |
+
# nvidia GPUs that do not support it natively.
|
| 30 |
+
# This is the same format as FP8E4M3Nv, but:
|
| 31 |
+
# - the exponent bias is 15 instead of 7
|
| 32 |
+
# - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
|
| 33 |
+
@core.builtin
|
| 34 |
+
def convert_fp8e4b15_to_float16(arg, _semantic=None):
|
| 35 |
+
return core.inline_asm_elementwise(
|
| 36 |
+
"{ \n"
|
| 37 |
+
".reg .b32 a<2>, b<2>; \n"
|
| 38 |
+
"prmt.b32 a0, 0, $2, 0x5746; \n"
|
| 39 |
+
"and.b32 b0, a0, 0x7f007f00; \n"
|
| 40 |
+
"and.b32 b1, a0, 0x00ff00ff; \n"
|
| 41 |
+
"and.b32 a1, a0, 0x00800080; \n"
|
| 42 |
+
"shr.b32 b0, b0, 1; \n"
|
| 43 |
+
"add.u32 b1, b1, a1; \n"
|
| 44 |
+
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
|
| 45 |
+
"shl.b32 $1, b1, 7; \n"
|
| 46 |
+
"} \n", "=r,=r,r", [arg], dtype=core.float16, is_pure=True, pack=4,
|
| 47 |
+
_semantic=_semantic)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@core.builtin
|
| 51 |
+
def convert_float16_to_fp8e4b15(arg, has_minx2, _semantic=None):
|
| 52 |
+
asm = """{
|
| 53 |
+
.reg .pred p<4>;
|
| 54 |
+
.reg .b32 a<2>, b<2>;
|
| 55 |
+
.reg .b16 c<4>;
|
| 56 |
+
.reg .b16 max_val_f16;
|
| 57 |
+
.reg .b32 max_val_f16x2;
|
| 58 |
+
mov.b16 max_val_f16, 0x3F00;
|
| 59 |
+
mov.b32 max_val_f16x2, 0x3F003F00;
|
| 60 |
+
and.b32 a0, $1, 0x7fff7fff;
|
| 61 |
+
and.b32 a1, $2, 0x7fff7fff;"""
|
| 62 |
+
if has_minx2:
|
| 63 |
+
asm += """min.f16x2 a0, a0, max_val_f16x2;
|
| 64 |
+
min.f16x2 a1, a1, max_val_f16x2;"""
|
| 65 |
+
else:
|
| 66 |
+
asm += """setp.lt.f16x2 p0|p1, a0, max_val_f16x2;
|
| 67 |
+
setp.lt.f16x2 p2|p3, a1, max_val_f16x2;
|
| 68 |
+
mov.b32 {c0, c1}, a0;
|
| 69 |
+
mov.b32 {c2, c3}, a1;
|
| 70 |
+
selp.b16 c0, c0, max_val_f16, p0;
|
| 71 |
+
selp.b16 c1, c1, max_val_f16, p1;
|
| 72 |
+
selp.b16 c2, c2, max_val_f16, p2;
|
| 73 |
+
selp.b16 c3, c3, max_val_f16, p3;
|
| 74 |
+
mov.b32 a0, {c0, c1};
|
| 75 |
+
mov.b32 a1, {c2, c3};"""
|
| 76 |
+
asm += """mad.lo.u32 a0, a0, 2, 0x00800080;
|
| 77 |
+
mad.lo.u32 a1, a1, 2, 0x00800080;
|
| 78 |
+
lop3.b32 b0, $1, 0x80008000, a0, 0xea;
|
| 79 |
+
lop3.b32 b1, $2, 0x80008000, a1, 0xea;
|
| 80 |
+
prmt.b32 $0, b0, b1, 0x7531;
|
| 81 |
+
}"""
|
| 82 |
+
return core.inline_asm_elementwise(asm, "=r,r,r", [arg], dtype=core.float8e4b15, is_pure=True, pack=4,
|
| 83 |
+
_semantic=_semantic)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@core.builtin
|
| 87 |
+
def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _semantic=None):
|
| 88 |
+
if arg.type.scalar.is_fp8e4b15():
|
| 89 |
+
upcast_val = convert_fp8e4b15_to_float16(arg, _semantic=_semantic)
|
| 90 |
+
if dst_ty.scalar.is_fp32():
|
| 91 |
+
upcast_val = upcast_val.to(core.float32, _semantic=_semantic)
|
| 92 |
+
return upcast_val
|
| 93 |
+
|
| 94 |
+
assert arg.type.scalar.is_fp16() or arg.type.scalar.is_fp32()
|
| 95 |
+
downcast_val = arg
|
| 96 |
+
if arg.type.scalar.is_fp32():
|
| 97 |
+
downcast_val = downcast_val.to(core.float16, fp_downcast_rounding="rtz", _semantic=_semantic)
|
| 98 |
+
downcast_val = convert_float16_to_fp8e4b15(downcast_val, has_minx2=has_minx2, _semantic=_semantic)
|
| 99 |
+
return downcast_val
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@core.builtin
|
| 103 |
+
def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _semantic=None):
|
| 104 |
+
return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=True, _semantic=_semantic)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@core.builtin
|
| 108 |
+
def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _semantic=None):
|
| 109 |
+
return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=False, _semantic=_semantic)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/hip/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import libdevice
|
| 2 |
+
|
| 3 |
+
from .utils import memrealtime
|
| 4 |
+
|
| 5 |
+
__all__ = ["libdevice", "memrealtime"]
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/hip/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (347 Bytes). View file
|
|
|