BryanW committed on
Commit
a3277d7
·
verified ·
1 Parent(s): 95414f3

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/__init__.cpython-312.pyc +0 -0
  2. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/big_modeling.cpython-312.pyc +0 -0
  3. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/checkpointing.cpython-312.pyc +0 -0
  4. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/launchers.cpython-312.pyc +0 -0
  5. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/optimizer.cpython-312.pyc +0 -0
  6. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/ElementSoup.py +10 -0
  7. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/_difflib.py +2106 -0
  8. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/formfill.py +299 -0
  9. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/html5parser.py +260 -0
  10. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/__init__.pxd +0 -0
  11. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/config.pxd +3 -0
  12. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/relaxng.pxd +64 -0
  13. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/schematron.pxd +34 -0
  14. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/xpath.pxd +136 -0
  15. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/isoschematron/__init__.py +348 -0
  16. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/__pycache__/__init__.cpython-312.pyc +0 -0
  17. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/__pycache__/compiler.cpython-312.pyc +0 -0
  18. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/__pycache__/driver.cpython-312.pyc +0 -0
  19. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/__init__.py +0 -0
  20. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/compiler.py +495 -0
  21. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/driver.c +504 -0
  22. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/driver.py +877 -0
  23. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/__init__.py +0 -0
  24. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py +553 -0
  25. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/driver.c +518 -0
  26. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/driver.py +764 -0
  27. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/cudaGL.h +608 -0
  28. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/cupti_pcsampling_util.h +402 -0
  29. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/driver_types.h +0 -0
  30. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/generated_cudaVDPAU_meta.h +46 -0
  31. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/nvperf_target.h +626 -0
  32. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/sm_32_atomic_functions.hpp +151 -0
  33. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/__init__.cpython-312.pyc +0 -0
  34. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/math.cpython-312.pyc +0 -0
  35. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/random.cpython-312.pyc +0 -0
  36. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/standard.cpython-312.pyc +0 -0
  37. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/target_info.cpython-312.pyc +0 -0
  38. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/__init__.py +26 -0
  39. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/__pycache__/__init__.cpython-312.pyc +0 -0
  40. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/__pycache__/libdevice.cpython-312.pyc +0 -0
  41. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__init__.py +16 -0
  42. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/__init__.cpython-312.pyc +0 -0
  43. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/gdc.cpython-312.pyc +0 -0
  44. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/libdevice.cpython-312.pyc +0 -0
  45. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/utils.cpython-312.pyc +0 -0
  46. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/gdc.py +42 -0
  47. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/libdevice.py +1629 -0
  48. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/utils.py +109 -0
  49. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/hip/__init__.py +5 -0
  50. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/hip/__pycache__/__init__.cpython-312.pyc +0 -0
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.39 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/big_modeling.cpython-312.pyc ADDED
Binary file (36.9 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/checkpointing.cpython-312.pyc ADDED
Binary file (15.4 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/launchers.cpython-312.pyc ADDED
Binary file (12.9 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/accelerate/__pycache__/optimizer.cpython-312.pyc ADDED
Binary file (11.7 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/ElementSoup.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ __doc__ = """Legacy interface to the BeautifulSoup HTML parser.
2
+ """
3
+
4
+ __all__ = ["parse", "convert_tree"]
5
+
6
+ from .soupparser import convert_tree, parse as _parse
7
+
8
+ def parse(file, beautifulsoup=None, makeelement=None):
9
+ root = _parse(file, beautifulsoup=beautifulsoup, makeelement=makeelement)
10
+ return root.getroot()
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/_difflib.py ADDED
@@ -0,0 +1,2106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from CPython 3.14b2+.
2
+ # cython: infer_types=True
3
+
4
+ """
5
+ Module difflib -- helpers for computing deltas between objects.
6
+
7
+ Function get_close_matches(word, possibilities, n=3, cutoff=0.6):
8
+ Use SequenceMatcher to return list of the best "good enough" matches.
9
+
10
+ Function context_diff(a, b):
11
+ For two lists of strings, return a delta in context diff format.
12
+
13
+ Function ndiff(a, b):
14
+ Return a delta: the difference between `a` and `b` (lists of strings).
15
+
16
+ Function restore(delta, which):
17
+ Return one of the two sequences that generated an ndiff delta.
18
+
19
+ Function unified_diff(a, b):
20
+ For two lists of strings, return a delta in unified diff format.
21
+
22
+ Class SequenceMatcher:
23
+ A flexible class for comparing pairs of sequences of any type.
24
+
25
+ Class Differ:
26
+ For producing human-readable deltas from sequences of lines of text.
27
+
28
+ Class HtmlDiff:
29
+ For producing HTML side by side comparison with change highlights.
30
+ """
31
+
32
+ try:
33
+ import cython
34
+ except ImportError:
35
+ class fake_cython:
36
+ compiled = False
37
+ def cfunc(self, func): return func
38
+ def declare(self, _, value): return value
39
+ def __getattr__(self, type_name): return "object"
40
+
41
+ cython = fake_cython()
42
+
43
+
44
+ __all__ = ['get_close_matches', 'ndiff', 'restore', 'SequenceMatcher',
45
+ 'Differ','IS_CHARACTER_JUNK', 'IS_LINE_JUNK', 'context_diff',
46
+ 'unified_diff', 'diff_bytes', 'HtmlDiff', 'Match']
47
+
48
+ from heapq import nlargest as _nlargest
49
+ from collections import namedtuple as _namedtuple
50
+
51
+ try:
52
+ from types import GenericAlias
53
+ except ImportError:
54
+ GenericAlias = None
55
+
56
+ Match = _namedtuple('Match', 'a b size')
57
+
58
+ def _calculate_ratio(matches, length):
59
+ if length:
60
+ return 2.0 * matches / length
61
+ return 1.0
62
+
63
+ class SequenceMatcher:
64
+
65
+ """
66
+ SequenceMatcher is a flexible class for comparing pairs of sequences of
67
+ any type, so long as the sequence elements are hashable. The basic
68
+ algorithm predates, and is a little fancier than, an algorithm
69
+ published in the late 1980's by Ratcliff and Obershelp under the
70
+ hyperbolic name "gestalt pattern matching". The basic idea is to find
71
+ the longest contiguous matching subsequence that contains no "junk"
72
+ elements (R-O doesn't address junk). The same idea is then applied
73
+ recursively to the pieces of the sequences to the left and to the right
74
+ of the matching subsequence. This does not yield minimal edit
75
+ sequences, but does tend to yield matches that "look right" to people.
76
+
77
+ SequenceMatcher tries to compute a "human-friendly diff" between two
78
+ sequences. Unlike e.g. UNIX(tm) diff, the fundamental notion is the
79
+ longest *contiguous* & junk-free matching subsequence. That's what
80
+ catches peoples' eyes. The Windows(tm) windiff has another interesting
81
+ notion, pairing up elements that appear uniquely in each sequence.
82
+ That, and the method here, appear to yield more intuitive difference
83
+ reports than does diff. This method appears to be the least vulnerable
84
+ to syncing up on blocks of "junk lines", though (like blank lines in
85
+ ordinary text files, or maybe "<P>" lines in HTML files). That may be
86
+ because this is the only method of the 3 that has a *concept* of
87
+ "junk" <wink>.
88
+
89
+ Example, comparing two strings, and considering blanks to be "junk":
90
+
91
+ >>> s = SequenceMatcher(lambda x: x == " ",
92
+ ... "private Thread currentThread;",
93
+ ... "private volatile Thread currentThread;")
94
+ >>>
95
+
96
+ .ratio() returns a float in [0, 1], measuring the "similarity" of the
97
+ sequences. As a rule of thumb, a .ratio() value over 0.6 means the
98
+ sequences are close matches:
99
+
100
+ >>> print(round(s.ratio(), 3))
101
+ 0.866
102
+ >>>
103
+
104
+ If you're only interested in where the sequences match,
105
+ .get_matching_blocks() is handy:
106
+
107
+ >>> for block in s.get_matching_blocks():
108
+ ... print("a[%d] and b[%d] match for %d elements" % block)
109
+ a[0] and b[0] match for 8 elements
110
+ a[8] and b[17] match for 21 elements
111
+ a[29] and b[38] match for 0 elements
112
+
113
+ Note that the last tuple returned by .get_matching_blocks() is always a
114
+ dummy, (len(a), len(b), 0), and this is the only case in which the last
115
+ tuple element (number of elements matched) is 0.
116
+
117
+ If you want to know how to change the first sequence into the second,
118
+ use .get_opcodes():
119
+
120
+ >>> for opcode in s.get_opcodes():
121
+ ... print("%6s a[%d:%d] b[%d:%d]" % opcode)
122
+ equal a[0:8] b[0:8]
123
+ insert a[8:8] b[8:17]
124
+ equal a[8:29] b[17:38]
125
+
126
+ See the Differ class for a fancy human-friendly file differencer, which
127
+ uses SequenceMatcher both to compare sequences of lines, and to compare
128
+ sequences of characters within similar (near-matching) lines.
129
+
130
+ See also function get_close_matches() in this module, which shows how
131
+ simple code building on SequenceMatcher can be used to do useful work.
132
+
133
+ Timing: Basic R-O is cubic time worst case and quadratic time expected
134
+ case. SequenceMatcher is quadratic time for the worst case and has
135
+ expected-case behavior dependent in a complicated way on how many
136
+ elements the sequences have in common; best case time is linear.
137
+ """
138
+
139
+ def __init__(self, isjunk=None, a='', b='', autojunk=True):
140
+ """Construct a SequenceMatcher.
141
+
142
+ Optional arg isjunk is None (the default), or a one-argument
143
+ function that takes a sequence element and returns true iff the
144
+ element is junk. None is equivalent to passing "lambda x: 0", i.e.
145
+ no elements are considered to be junk. For example, pass
146
+ lambda x: x in " \\t"
147
+ if you're comparing lines as sequences of characters, and don't
148
+ want to synch up on blanks or hard tabs.
149
+
150
+ Optional arg a is the first of two sequences to be compared. By
151
+ default, an empty string. The elements of a must be hashable. See
152
+ also .set_seqs() and .set_seq1().
153
+
154
+ Optional arg b is the second of two sequences to be compared. By
155
+ default, an empty string. The elements of b must be hashable. See
156
+ also .set_seqs() and .set_seq2().
157
+
158
+ Optional arg autojunk should be set to False to disable the
159
+ "automatic junk heuristic" that treats popular elements as junk
160
+ (see module documentation for more information).
161
+ """
162
+
163
+ # Members:
164
+ # a
165
+ # first sequence
166
+ # b
167
+ # second sequence; differences are computed as "what do
168
+ # we need to do to 'a' to change it into 'b'?"
169
+ # b2j
170
+ # for x in b, b2j[x] is a list of the indices (into b)
171
+ # at which x appears; junk and popular elements do not appear
172
+ # fullbcount
173
+ # for x in b, fullbcount[x] == the number of times x
174
+ # appears in b; only materialized if really needed (used
175
+ # only for computing quick_ratio())
176
+ # matching_blocks
177
+ # a list of (i, j, k) triples, where a[i:i+k] == b[j:j+k];
178
+ # ascending & non-overlapping in i and in j; terminated by
179
+ # a dummy (len(a), len(b), 0) sentinel
180
+ # opcodes
181
+ # a list of (tag, i1, i2, j1, j2) tuples, where tag is
182
+ # one of
183
+ # 'replace' a[i1:i2] should be replaced by b[j1:j2]
184
+ # 'delete' a[i1:i2] should be deleted
185
+ # 'insert' b[j1:j2] should be inserted
186
+ # 'equal' a[i1:i2] == b[j1:j2]
187
+ # isjunk
188
+ # a user-supplied function taking a sequence element and
189
+ # returning true iff the element is "junk" -- this has
190
+ # subtle but helpful effects on the algorithm, which I'll
191
+ # get around to writing up someday <0.9 wink>.
192
+ # DON'T USE! Only __chain_b uses this. Use "in self.bjunk".
193
+ # bjunk
194
+ # the items in b for which isjunk is True.
195
+ # bpopular
196
+ # nonjunk items in b treated as junk by the heuristic (if used).
197
+
198
+ self.isjunk = isjunk
199
+ self.a = self.b = None
200
+ self.autojunk = autojunk
201
+ self.set_seqs(a, b)
202
+
203
+ def set_seqs(self, a, b):
204
+ """Set the two sequences to be compared.
205
+
206
+ >>> s = SequenceMatcher()
207
+ >>> s.set_seqs("abcd", "bcde")
208
+ >>> s.ratio()
209
+ 0.75
210
+ """
211
+
212
+ self.set_seq1(a)
213
+ self.set_seq2(b)
214
+
215
+ def set_seq1(self, a):
216
+ """Set the first sequence to be compared.
217
+
218
+ The second sequence to be compared is not changed.
219
+
220
+ >>> s = SequenceMatcher(None, "abcd", "bcde")
221
+ >>> s.ratio()
222
+ 0.75
223
+ >>> s.set_seq1("bcde")
224
+ >>> s.ratio()
225
+ 1.0
226
+ >>>
227
+
228
+ SequenceMatcher computes and caches detailed information about the
229
+ second sequence, so if you want to compare one sequence S against
230
+ many sequences, use .set_seq2(S) once and call .set_seq1(x)
231
+ repeatedly for each of the other sequences.
232
+
233
+ See also set_seqs() and set_seq2().
234
+ """
235
+
236
+ if a is self.a:
237
+ return
238
+ self.a = a
239
+ self.matching_blocks = self.opcodes = None
240
+
241
+ def set_seq2(self, b):
242
+ """Set the second sequence to be compared.
243
+
244
+ The first sequence to be compared is not changed.
245
+
246
+ >>> s = SequenceMatcher(None, "abcd", "bcde")
247
+ >>> s.ratio()
248
+ 0.75
249
+ >>> s.set_seq2("abcd")
250
+ >>> s.ratio()
251
+ 1.0
252
+ >>>
253
+
254
+ SequenceMatcher computes and caches detailed information about the
255
+ second sequence, so if you want to compare one sequence S against
256
+ many sequences, use .set_seq2(S) once and call .set_seq1(x)
257
+ repeatedly for each of the other sequences.
258
+
259
+ See also set_seqs() and set_seq1().
260
+ """
261
+
262
+ if b is self.b:
263
+ return
264
+ self.b = b
265
+ self.matching_blocks = self.opcodes = None
266
+ self.fullbcount = None
267
+ self.__chain_b()
268
+
269
+ # For each element x in b, set b2j[x] to a list of the indices in
270
+ # b where x appears; the indices are in increasing order; note that
271
+ # the number of times x appears in b is len(b2j[x]) ...
272
+ # when self.isjunk is defined, junk elements don't show up in this
273
+ # map at all, which stops the central find_longest_match method
274
+ # from starting any matching block at a junk element ...
275
+ # b2j also does not contain entries for "popular" elements, meaning
276
+ # elements that account for more than 1 + 1% of the total elements, and
277
+ # when the sequence is reasonably large (>= 200 elements); this can
278
+ # be viewed as an adaptive notion of semi-junk, and yields an enormous
279
+ # speedup when, e.g., comparing program files with hundreds of
280
+ # instances of "return NULL;" ...
281
+ # note that this is only called when b changes; so for cross-product
282
+ # kinds of matches, it's best to call set_seq2 once, then set_seq1
283
+ # repeatedly
284
+
285
+ def __chain_b(self):
286
+ # Because isjunk is a user-defined (not C) function, and we test
287
+ # for junk a LOT, it's important to minimize the number of calls.
288
+ # Before the tricks described here, __chain_b was by far the most
289
+ # time-consuming routine in the whole module! If anyone sees
290
+ # Jim Roskind, thank him again for profile.py -- I never would
291
+ # have guessed that.
292
+ # The first trick is to build b2j ignoring the possibility
293
+ # of junk. I.e., we don't call isjunk at all yet. Throwing
294
+ # out the junk later is much cheaper than building b2j "right"
295
+ # from the start.
296
+ b = self.b
297
+ self.b2j = b2j = {}
298
+
299
+ for i, elt in enumerate(b):
300
+ indices = b2j.setdefault(elt, [])
301
+ indices.append(i)
302
+
303
+ # Purge junk elements
304
+ self.bjunk = junk = set()
305
+ isjunk = self.isjunk
306
+ if isjunk:
307
+ for elt in b2j.keys():
308
+ if isjunk(elt):
309
+ junk.add(elt)
310
+ for elt in junk: # separate loop avoids separate list of keys
311
+ del b2j[elt]
312
+
313
+ # Purge popular elements that are not junk
314
+ self.bpopular = popular = set()
315
+ n = len(b)
316
+ if self.autojunk and n >= 200:
317
+ ntest = n // 100 + 1
318
+ for elt, idxs in b2j.items():
319
+ if len(idxs) > ntest:
320
+ popular.add(elt)
321
+ for elt in popular: # ditto; as fast for 1% deletion
322
+ del b2j[elt]
323
+
324
+ def find_longest_match(self, alo=0, ahi_=None, blo=0, bhi_=None):
325
+ """Find longest matching block in a[alo:ahi] and b[blo:bhi].
326
+
327
+ By default it will find the longest match in the entirety of a and b.
328
+
329
+ If isjunk is not defined:
330
+
331
+ Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where
332
+ alo <= i <= i+k <= ahi
333
+ blo <= j <= j+k <= bhi
334
+ and for all (i',j',k') meeting those conditions,
335
+ k >= k'
336
+ i <= i'
337
+ and if i == i', j <= j'
338
+
339
+ In other words, of all maximal matching blocks, return one that
340
+ starts earliest in a, and of all those maximal matching blocks that
341
+ start earliest in a, return the one that starts earliest in b.
342
+
343
+ >>> s = SequenceMatcher(None, " abcd", "abcd abcd")
344
+ >>> s.find_longest_match(0, 5, 0, 9)
345
+ Match(a=0, b=4, size=5)
346
+
347
+ If isjunk is defined, first the longest matching block is
348
+ determined as above, but with the additional restriction that no
349
+ junk element appears in the block. Then that block is extended as
350
+ far as possible by matching (only) junk elements on both sides. So
351
+ the resulting block never matches on junk except as identical junk
352
+ happens to be adjacent to an "interesting" match.
353
+
354
+ Here's the same example as before, but considering blanks to be
355
+ junk. That prevents " abcd" from matching the " abcd" at the tail
356
+ end of the second sequence directly. Instead only the "abcd" can
357
+ match, and matches the leftmost "abcd" in the second sequence:
358
+
359
+ >>> s = SequenceMatcher(lambda x: x==" ", " abcd", "abcd abcd")
360
+ >>> s.find_longest_match(0, 5, 0, 9)
361
+ Match(a=1, b=0, size=4)
362
+
363
+ If no blocks match, return (alo, blo, 0).
364
+
365
+ >>> s = SequenceMatcher(None, "ab", "c")
366
+ >>> s.find_longest_match(0, 2, 0, 1)
367
+ Match(a=0, b=0, size=0)
368
+ """
369
+
370
+ # CAUTION: stripping common prefix or suffix would be incorrect.
371
+ # E.g.,
372
+ # ab
373
+ # acab
374
+ # Longest matching block is "ab", but if common prefix is
375
+ # stripped, it's "a" (tied with "b"). UNIX(tm) diff does so
376
+ # strip, so ends up claiming that ab is changed to acab by
377
+ # inserting "ca" in the middle. That's minimal but unintuitive:
378
+ # "it's obvious" that someone inserted "ac" at the front.
379
+ # Windiff ends up at the same place as diff, but by pairing up
380
+ # the unique 'b's and then matching the first two 'a's.
381
+
382
+ bjunk: set = self.bjunk
383
+ a, b, b2j = self.a, self.b, self.b2j
384
+ ahi = len(a) if ahi_ is None else ahi_
385
+ bhi = len(b) if bhi_ is None else bhi_
386
+ besti, bestj, bestsize = alo, blo, 0
387
+ # find longest junk-free match
388
+ # during an iteration of the loop, j2len[j] = length of longest
389
+ # junk-free match ending with a[i-1] and b[j]
390
+ j2len = {}
391
+ nothing = []
392
+ for i in range(alo, ahi):
393
+ # look at all instances of a[i] in b; note that because
394
+ # b2j has no junk keys, the loop is skipped if a[i] is junk
395
+ newj2len = {}
396
+ for j in b2j.get(a[i], nothing):
397
+ # a[i] matches b[j]
398
+ if j < blo:
399
+ continue
400
+ if j >= bhi:
401
+ break
402
+ k = newj2len[j] = j2len.get(j-1, 0) + 1
403
+ if k > bestsize:
404
+ besti, bestj, bestsize = i-k+1, j-k+1, k
405
+ j2len = newj2len
406
+
407
+ # Extend the best by non-junk elements on each end. In particular,
408
+ # "popular" non-junk elements aren't in b2j, which greatly speeds
409
+ # the inner loop above, but also means "the best" match so far
410
+ # doesn't contain any junk *or* popular non-junk elements.
411
+ while besti > alo and bestj > blo and \
412
+ b[bestj-1] not in bjunk and \
413
+ a[besti-1] == b[bestj-1]:
414
+ besti, bestj, bestsize = besti-1, bestj-1, bestsize+1
415
+ while besti+bestsize < ahi and bestj+bestsize < bhi and \
416
+ b[bestj+bestsize] not in bjunk and \
417
+ a[besti+bestsize] == b[bestj+bestsize]:
418
+ bestsize += 1
419
+
420
+ # Now that we have a wholly interesting match (albeit possibly
421
+ # empty!), we may as well suck up the matching junk on each
422
+ # side of it too. Can't think of a good reason not to, and it
423
+ # saves post-processing the (possibly considerable) expense of
424
+ # figuring out what to do with it. In the case of an empty
425
+ # interesting match, this is clearly the right thing to do,
426
+ # because no other kind of match is possible in the regions.
427
+ while besti > alo and bestj > blo and \
428
+ b[bestj-1] in bjunk and \
429
+ a[besti-1] == b[bestj-1]:
430
+ besti, bestj, bestsize = besti-1, bestj-1, bestsize+1
431
+ while besti+bestsize < ahi and bestj+bestsize < bhi and \
432
+ b[bestj+bestsize] in bjunk and \
433
+ a[besti+bestsize] == b[bestj+bestsize]:
434
+ bestsize = bestsize + 1
435
+
436
+ return Match(besti, bestj, bestsize)
437
+
438
+ def get_matching_blocks(self):
439
+ """Return list of triples describing matching subsequences.
440
+
441
+ Each triple is of the form (i, j, n), and means that
442
+ a[i:i+n] == b[j:j+n]. The triples are monotonically increasing in
443
+ i and in j. New in Python 2.5, it's also guaranteed that if
444
+ (i, j, n) and (i', j', n') are adjacent triples in the list, and
445
+ the second is not the last triple in the list, then i+n != i' or
446
+ j+n != j'. IOW, adjacent triples never describe adjacent equal
447
+ blocks.
448
+
449
+ The last triple is a dummy, (len(a), len(b), 0), and is the only
450
+ triple with n==0.
451
+
452
+ >>> s = SequenceMatcher(None, "abxcd", "abcd")
453
+ >>> list(s.get_matching_blocks())
454
+ [Match(a=0, b=0, size=2), Match(a=3, b=2, size=2), Match(a=5, b=4, size=0)]
455
+ """
456
+
457
+ if self.matching_blocks is not None:
458
+ return self.matching_blocks
459
+ la, lb = len(self.a), len(self.b)
460
+
461
+ # This is most naturally expressed as a recursive algorithm, but
462
+ # at least one user bumped into extreme use cases that exceeded
463
+ # the recursion limit on their box. So, now we maintain a list
464
+ # (`queue`) of blocks we still need to look at, and append partial
465
+ # results to `matching_blocks` in a loop; the matches are sorted
466
+ # at the end.
467
+ queue = [(0, la, 0, lb)]
468
+ matching_blocks = []
469
+ while queue:
470
+ alo, ahi, blo, bhi = queue.pop()
471
+ i, j, k = x = self.find_longest_match(alo, ahi, blo, bhi)
472
+ # a[alo:i] vs b[blo:j] unknown
473
+ # a[i:i+k] same as b[j:j+k]
474
+ # a[i+k:ahi] vs b[j+k:bhi] unknown
475
+ if k: # if k is 0, there was no matching block
476
+ matching_blocks.append(x)
477
+ if alo < i and blo < j:
478
+ queue.append((alo, i, blo, j))
479
+ if i+k < ahi and j+k < bhi:
480
+ queue.append((i+k, ahi, j+k, bhi))
481
+ matching_blocks.sort()
482
+
483
+ # It's possible that we have adjacent equal blocks in the
484
+ # matching_blocks list now. Starting with 2.5, this code was added
485
+ # to collapse them.
486
+ i1 = j1 = k1 = 0
487
+ non_adjacent = []
488
+ for i2, j2, k2 in matching_blocks:
489
+ # Is this block adjacent to i1, j1, k1?
490
+ if i1 + k1 == i2 and j1 + k1 == j2:
491
+ # Yes, so collapse them -- this just increases the length of
492
+ # the first block by the length of the second, and the first
493
+ # block so lengthened remains the block to compare against.
494
+ k1 += k2
495
+ else:
496
+ # Not adjacent. Remember the first block (k1==0 means it's
497
+ # the dummy we started with), and make the second block the
498
+ # new block to compare against.
499
+ if k1:
500
+ non_adjacent.append((i1, j1, k1))
501
+ i1, j1, k1 = i2, j2, k2
502
+ if k1:
503
+ non_adjacent.append((i1, j1, k1))
504
+
505
+ non_adjacent.append( (la, lb, 0) )
506
+ self.matching_blocks = list(map(Match._make, non_adjacent))
507
+ return self.matching_blocks
508
+
509
def get_opcodes(self):
    """Return the edit script that turns a into b, as a list of 5-tuples.

    Every entry has the shape (tag, i1, i2, j1, j2).  The first entry
    has i1 == j1 == 0; each following entry starts where the previous
    one stopped (its i1 is the previous i2, its j1 the previous j2).

    tag is one of these strings:

    'replace':  a[i1:i2] should be replaced by b[j1:j2]
    'delete':   a[i1:i2] should be deleted.
                Note that j1==j2 in this case.
    'insert':   b[j1:j2] should be inserted at a[i1:i1].
                Note that i1==i2 in this case.
    'equal':    a[i1:i2] == b[j1:j2]

    The list is built once, stored on self.opcodes and returned as-is
    on later calls.

    >>> a = "qabxcd"
    >>> b = "abycdf"
    >>> s = SequenceMatcher(None, a, b)
    >>> for tag, i1, i2, j1, j2 in s.get_opcodes():
    ...    print(("%7s a[%d:%d] (%s) b[%d:%d] (%s)" %
    ...           (tag, i1, i2, a[i1:i2], j1, j2, b[j1:j2])))
     delete a[0:1] (q) b[0:0] ()
      equal a[1:3] (ab) b[0:2] (ab)
    replace a[3:4] (x) b[2:3] (y)
      equal a[4:6] (cd) b[3:5] (cd)
     insert a[6:6] () b[5:6] (f)
    """

    cached = self.opcodes
    if cached is not None:
        return cached

    # (a_pos, b_pos) is how far both sequences have been accounted for.
    a_pos = b_pos = 0
    self.opcodes = script = []
    for a_start, b_start, length in self.get_matching_blocks():
        # Everything between the current position and the next matching
        # block is a difference; classify it by which side is non-empty.
        a_has_gap = a_pos < a_start
        b_has_gap = b_pos < b_start
        if a_has_gap and b_has_gap:
            script.append(('replace', a_pos, a_start, b_pos, b_start))
        elif a_has_gap:
            script.append(('delete', a_pos, a_start, b_pos, b_start))
        elif b_has_gap:
            script.append(('insert', a_pos, a_start, b_pos, b_start))

        # Step over the matching block itself.
        a_pos, b_pos = a_start + length, b_start + length
        # The terminating sentinel block has length 0 and produces no
        # 'equal' entry.
        if length:
            script.append(('equal', a_start, a_pos, b_start, b_pos))
    return script
563
+
564
def get_grouped_opcodes(self, n=3):
    """Isolate change clusters by eliminating ranges with no changes.

    Return a generator of groups with up to n lines of context.
    Each group is in the same format as returned by get_opcodes().

    The opcode list is copied before the leading and trailing 'equal'
    entries are trimmed to n lines, so the list that get_opcodes()
    keeps and returns on later calls is left untouched.

    >>> from pprint import pprint
    >>> a = list(map(str, range(1,40)))
    >>> b = a[:]
    >>> b[8:8] = ['i']     # Make an insertion
    >>> b[20] += 'x'       # Make a replacement
    >>> b[23:28] = []      # Make a deletion
    >>> b[30] += 'y'       # Make another replacement
    >>> pprint(list(SequenceMatcher(None,a,b).get_grouped_opcodes()))
    [[('equal', 5, 8, 5, 8), ('insert', 8, 8, 8, 9), ('equal', 8, 11, 9, 12)],
     [('equal', 16, 19, 17, 20),
      ('replace', 19, 20, 20, 21),
      ('equal', 20, 22, 21, 23),
      ('delete', 22, 27, 23, 23),
      ('equal', 27, 30, 23, 26)],
     [('equal', 31, 34, 27, 30),
      ('replace', 34, 35, 30, 31),
      ('equal', 35, 38, 31, 34)]]
    """

    # Work on a shallow copy: the fix-ups below assign into the list, and
    # get_opcodes() returns the very list object it stores on the matcher.
    codes = list(self.get_opcodes())
    if not codes:
        # Placeholder so the indexing below is valid; a lone 'equal'
        # group is never yielded (see the final check).
        codes = [("equal", 0, 1, 0, 1)]
    # Fixup leading and trailing groups if they show no changes.
    if codes[0][0] == 'equal':
        tag, i1, i2, j1, j2 = codes[0]
        codes[0] = tag, max(i1, i2-n), i2, max(j1, j2-n), j2
    if codes[-1][0] == 'equal':
        tag, i1, i2, j1, j2 = codes[-1]
        codes[-1] = tag, i1, min(i2, i1+n), j1, min(j2, j1+n)

    nn = n + n
    group = []
    for tag, i1, i2, j1, j2 in codes:
        # End the current group and start a new one whenever
        # there is a large range with no changes.
        if tag == 'equal' and i2-i1 > nn:
            # Keep n trailing context lines in the group being closed ...
            group.append((tag, i1, min(i2, i1+n), j1, min(j2, j1+n)))
            yield group
            group = []
            # ... and n leading context lines for the group that follows.
            i1, j1 = max(i1, i2-n), max(j1, j2-n)
        group.append((tag, i1, i2, j1 ,j2))
    # A group made of a single 'equal' entry carries no change: drop it.
    if group and not (len(group)==1 and group[0][0] == 'equal'):
        yield group
613
+
614
def ratio(self):
    """Return a measure of the sequences' similarity (float in [0,1]).

    With T the total number of elements in both sequences and M the
    number of matched elements, the result is 2.0*M / T: 1 for
    identical sequences, 0 when they share nothing.

    Computing .ratio() is expensive unless .get_matching_blocks() or
    .get_opcodes() has already been called; .quick_ratio() and
    .real_quick_ratio() give cheaper upper bounds to try first.

    >>> s = SequenceMatcher(None, "abcd", "bcde")
    >>> s.ratio()
    0.75
    >>> s.quick_ratio()
    0.75
    >>> s.real_quick_ratio()
    1.0
    """

    matches: cython.Py_ssize_t
    matches = 0
    # The last field of every matching block is its length.
    for block in self.get_matching_blocks():
        matches += block[-1]
    total = len(self.a) + len(self.b)
    return _calculate_ratio(matches, total)
639
+
640
def quick_ratio(self):
    """Return an upper bound on ratio() relatively quickly.

    Nothing is promised beyond the result being an upper bound on
    .ratio() that is cheaper to obtain.
    """

    # Treat a and b as multisets and count the size of their
    # intersection.  Order is ignored, so this can only over-estimate
    # the number of real matches.
    if self.fullbcount is None:
        # Element -> occurrence count for b, built once and kept on self.
        self.fullbcount = counts = {}
        for item in self.b:
            counts[item] = counts.get(item, 0) + 1
    fullbcount = self.fullbcount

    # remaining[x]: occurrences of x in b minus the ones already paired
    # with an element of a (may drop below zero).
    remaining = {}
    matches: cython.Py_ssize_t
    matches = 0
    for item in self.a:
        if item in remaining:
            left = remaining[item]
        else:
            left = fullbcount.get(item, 0)
        remaining[item] = left - 1
        if left > 0:
            matches = matches + 1
    total = len(self.a) + len(self.b)
    return _calculate_ratio(matches, total)
669
+
670
def real_quick_ratio(self):
    """Return an upper bound on ratio() very quickly.

    Nothing is promised beyond the result being an upper bound on
    .ratio() that is cheaper to obtain than .ratio() or .quick_ratio().
    """

    len_a = len(self.a)
    len_b = len(self.b)
    # The shorter sequence caps the number of elements that can match.
    most_possible = min(len_a, len_b)
    return _calculate_ratio(most_possible, len_a + len_b)
681
+
682
# Support subscripting the class (``cls[...]``) by delegating to
# GenericAlias; skipped when GenericAlias is None.
# NOTE(review): GenericAlias is bound outside this view, presumably by a
# guarded import near the top of the module -- confirm.
if GenericAlias is not None:
    __class_getitem__ = classmethod(GenericAlias)
684
+
685
+
686
def get_close_matches(word, possibilities, n=3, cutoff=0.6):
    """Use SequenceMatcher to return list of the best "good enough" matches.

    word is the sequence to find close matches for (typically a string).

    possibilities is a list of sequences to compare word against
    (typically a list of strings).

    Optional arg n (default 3) caps the number of close matches
    returned.  n must be > 0.

    Optional arg cutoff (default 0.6) is a float in [0, 1].  Candidates
    scoring below it are ignored.

    The best (at most n) matches are returned in a list ordered by
    similarity score, most similar first.

    Raises ValueError when n or cutoff is out of range.

    >>> get_close_matches("appel", ["ape", "apple", "peach", "puppy"])
    ['apple', 'ape']
    >>> import keyword as _keyword
    >>> get_close_matches("wheel", _keyword.kwlist)
    ['while']
    >>> get_close_matches("Apple", _keyword.kwlist)
    []
    >>> get_close_matches("accept", _keyword.kwlist)
    ['except']
    """

    if not n > 0:
        raise ValueError("n must be > 0: %r" % (n,))
    if not 0.0 <= cutoff <= 1.0:
        raise ValueError("cutoff must be in [0.0, 1.0]: %r" % (cutoff,))

    matcher = SequenceMatcher()
    matcher.set_seq2(word)
    scored = []
    for candidate in possibilities:
        matcher.set_seq1(candidate)
        # Cheapest bound first; each check can rule the candidate out
        # before the next, more expensive one runs.
        if not matcher.real_quick_ratio() >= cutoff:
            continue
        if not matcher.quick_ratio() >= cutoff:
            continue
        if not matcher.ratio() >= cutoff:
            continue
        scored.append((matcher.ratio(), candidate))

    # Keep the n highest (score, candidate) pairs, best first ...
    best = _nlargest(n, scored)
    # ... and return the candidates without their scores.
    return [candidate for _score, candidate in best]
733
+
734
+
735
+ def _keep_original_ws(s, tag_s):
736
+ """Replace whitespace with the original whitespace characters in `s`"""
737
+ return ''.join(
738
+ c if tag_c == " " and c.isspace() else tag_c
739
+ for c, tag_c in zip(s, tag_s)
740
+ )
741
+
742
+
743
+
744
class Differ:
    r"""
    Differ is a class for comparing sequences of lines of text, and
    producing human-readable differences or deltas.  Differ uses
    SequenceMatcher both to compare sequences of lines, and to compare
    sequences of characters within similar (near-matching) lines.

    Each line of a Differ delta begins with a two-letter code:

        '- '    line unique to sequence 1
        '+ '    line unique to sequence 2
        '  '    line common to both sequences
        '? '    line not present in either input sequence

    Lines beginning with '? ' attempt to guide the eye to intraline
    differences, and were not present in either input sequence.  These lines
    can be confusing if the sequences contain tab characters.

    Note that Differ makes no claim to produce a *minimal* diff.  To the
    contrary, minimal diffs are often counter-intuitive, because they synch
    up anywhere possible, sometimes accidental matches 100 pages apart.
    Restricting synch points to contiguous matches preserves some notion of
    locality, at the occasional cost of producing a longer diff.

    Example: Comparing two texts.

    First we set up the texts, sequences of individual single-line strings
    ending with newlines (such sequences can also be obtained from the
    `readlines()` method of file-like objects):

    >>> text1 = '''  1. Beautiful is better than ugly.
    ...   2. Explicit is better than implicit.
    ...   3. Simple is better than complex.
    ...   4. Complex is better than complicated.
    ... '''.splitlines(keepends=True)
    >>> len(text1)
    4
    >>> text1[0][-1]
    '\n'
    >>> text2 = '''  1. Beautiful is better than ugly.
    ...   3.   Simple is better than complex.
    ...   4. Complicated is better than complex.
    ...   5. Flat is better than nested.
    ... '''.splitlines(keepends=True)

    Next we instantiate a Differ object:

    >>> d = Differ()

    Note that when instantiating a Differ object we may pass functions to
    filter out line and character 'junk'.  See Differ.__init__ for details.

    Finally, we compare the two:

    >>> result = list(d.compare(text1, text2))

    'result' is a list of strings, so let's pretty-print it:

    >>> from pprint import pprint as _pprint
    >>> _pprint(result)
    ['    1. Beautiful is better than ugly.\n',
     '-   2. Explicit is better than implicit.\n',
     '-   3. Simple is better than complex.\n',
     '+   3.   Simple is better than complex.\n',
     '?     ++\n',
     '-   4. Complex is better than complicated.\n',
     '?            ^                     ---- ^\n',
     '+   4. Complicated is better than complex.\n',
     '?           ++++ ^                      ^\n',
     '+   5. Flat is better than nested.\n']

    As a single multi-line string it looks like this:

    >>> print(''.join(result), end="")
        1. Beautiful is better than ugly.
    -   2. Explicit is better than implicit.
    -   3. Simple is better than complex.
    +   3.   Simple is better than complex.
    ?     ++
    -   4. Complex is better than complicated.
    ?            ^                     ---- ^
    +   4. Complicated is better than complex.
    ?           ++++ ^                      ^
    +   5. Flat is better than nested.
    """

    def __init__(self, linejunk=None, charjunk=None):
        """
        Construct a text differencer, with optional filters.

        The two optional keyword parameters are for filter functions:

        - `linejunk`: A function that should accept a single string argument,
          and return true iff the string is junk. The module-level function
          `IS_LINE_JUNK` may be used to filter out lines without visible
          characters, except for at most one splat ('#').  It is recommended
          to leave linejunk None; the underlying SequenceMatcher class has
          an adaptive notion of "noise" lines that's better than any static
          definition the author has ever been able to craft.

        - `charjunk`: A function that should accept a string of length 1. The
          module-level function `IS_CHARACTER_JUNK` may be used to filter out
          whitespace characters (a blank or tab; **note**: bad idea to include
          newline in this!).  Use of IS_CHARACTER_JUNK is recommended.
        """

        self.linejunk = linejunk
        self.charjunk = charjunk

    def compare(self, a, b):
        r"""
        Compare two sequences of lines; generate the resulting delta.

        Each sequence must contain individual single-line strings ending with
        newlines. Such sequences can be obtained from the `readlines()` method
        of file-like objects.  The delta generated also consists of newline-
        terminated strings, ready to be printed as-is via the writelines()
        method of a file-like object.

        Example:

        >>> print(''.join(Differ().compare('one\ntwo\nthree\n'.splitlines(True),
        ...                                'ore\ntree\nemu\n'.splitlines(True))),
        ...       end="")
        - one
        ?  ^
        + ore
        ?  ^
        - two
        - three
        ?  -
        + tree
        + emu
        """

        # Line-level comparison; each opcode is rendered by one helper.
        cruncher = SequenceMatcher(self.linejunk, a, b)
        for tag, alo, ahi, blo, bhi in cruncher.get_opcodes():
            if tag == 'replace':
                g = self._fancy_replace(a, alo, ahi, b, blo, bhi)
            elif tag == 'delete':
                g = self._dump('-', a, alo, ahi)
            elif tag == 'insert':
                g = self._dump('+', b, blo, bhi)
            elif tag == 'equal':
                g = self._dump(' ', a, alo, ahi)
            else:
                raise ValueError('unknown tag %r' % (tag,))

            yield from g

    def _dump(self, tag, x, lo, hi):
        """Generate comparison results for a same-tagged range."""
        # tag plus the separating blank forms the two-character line code.
        for i in range(lo, hi):
            yield '%s %s' % (tag, x[i])

    def _plain_replace(self, a, alo, ahi, b, blo, bhi):
        """Generate a[alo:ahi] as '-' lines and b[blo:bhi] as '+' lines.

        Both ranges must be non-empty.  The shorter range comes first.
        """
        assert alo < ahi and blo < bhi
        # dump the shorter block first -- reduces the burden on short-term
        # memory if the blocks are of very different sizes
        if bhi - blo < ahi - alo:
            first  = self._dump('+', b, blo, bhi)
            second = self._dump('-', a, alo, ahi)
        else:
            first  = self._dump('-', a, alo, ahi)
            second = self._dump('+', b, blo, bhi)

        for g in first, second:
            yield from g

    def _fancy_replace(self, a, alo, ahi, b, blo, bhi):
        r"""
        When replacing one block of lines with another, search the blocks
        for *similar* lines; the best-matching pair (if any) is used as a
        synch point, and intraline difference marking is done on the
        similar pair. Lots of work, but often worth it.

        Example:

        >>> d = Differ()
        >>> results = d._fancy_replace(['abcDefghiJkl\n'], 0, 1,
        ...                            ['abcdefGhijkl\n'], 0, 1)
        >>> print(''.join(results), end="")
        - abcDefghiJkl
        ?    ^  ^  ^
        + abcdefGhijkl
        ?    ^  ^  ^
        """
        # Don't synch up unless the lines have a similarity score above
        # cutoff. Previously only the smallest pair was handled here,
        # and if there are many pairs with the best ratio, recursion
        # could grow very deep, and runtime cubic. See:
        # https://github.com/python/cpython/issues/119105
        #
        # Later, more pathological cases prompted removing recursion
        # entirely.
        cutoff = 0.74999
        cruncher = SequenceMatcher(self.charjunk)
        crqr = cruncher.real_quick_ratio
        cqr = cruncher.quick_ratio
        cr = cruncher.ratio

        WINDOW = 10
        best_i = best_j = None
        dump_i, dump_j = alo, blo # smallest indices not yet resolved
        for j in range(blo, bhi):
            cruncher.set_seq2(b[j])
            # Search the corresponding i's within WINDOW for the highest
            # ratio greater than `cutoff`.
            aequiv = alo + (j - blo)
            arange = range(max(aequiv - WINDOW, dump_i),
                           min(aequiv + WINDOW + 1, ahi))
            if not arange: # likely exit if `a` is shorter than `b`
                break
            best_ratio = cutoff
            for i in arange:
                cruncher.set_seq1(a[i])
                # Ordering by cheapest to most expensive ratio is very
                # valuable, most often getting out early.
                if (crqr() > best_ratio
                      and cqr() > best_ratio
                      and cr() > best_ratio):
                    best_i, best_j, best_ratio = i, j, cr()

            if best_i is None:
                # found nothing to synch on yet - move to next j
                continue

            # pump out straight replace from before this synch pair
            yield from self._fancy_helper(a, dump_i, best_i,
                                          b, dump_j, best_j)
            # do intraline marking on the synch pair
            aelt, belt = a[best_i], b[best_j]
            if aelt != belt:
                # pump out a '-', '?', '+', '?' quad for the synched lines
                atags = btags = ""
                cruncher.set_seqs(aelt, belt)
                for tag, ai1, ai2, bj1, bj2 in cruncher.get_opcodes():
                    la, lb = ai2 - ai1, bj2 - bj1
                    if tag == 'replace':
                        atags += '^' * la
                        btags += '^' * lb
                    elif tag == 'delete':
                        atags += '-' * la
                    elif tag == 'insert':
                        btags += '+' * lb
                    elif tag == 'equal':
                        atags += ' ' * la
                        btags += ' ' * lb
                    else:
                        raise ValueError('unknown tag %r' % (tag,))
                yield from self._qformat(aelt, belt, atags, btags)
            else:
                # the synch pair is identical: emit it with the same
                # two-character "common line" code that _dump(' ', ...)
                # produces, so every delta line keeps a two-letter prefix
                yield '  ' + aelt
            dump_i, dump_j = best_i + 1, best_j + 1
            best_i = best_j = None

        # pump out straight replace from after the last synch pair
        yield from self._fancy_helper(a, dump_i, ahi,
                                      b, dump_j, bhi)

    def _fancy_helper(self, a, alo, ahi, b, blo, bhi):
        """Generate the lines of an unsynched stretch.

        Either range may be empty: both non-empty gives a plain replace,
        only `a` gives '-' lines, only `b` gives '+' lines, neither gives
        nothing.
        """
        g = []
        if alo < ahi:
            if blo < bhi:
                g = self._plain_replace(a, alo, ahi, b, blo, bhi)
            else:
                g = self._dump('-', a, alo, ahi)
        elif blo < bhi:
            g = self._dump('+', b, blo, bhi)

        yield from g

    def _qformat(self, aline, bline, atags, btags):
        r"""
        Format "?" output and deal with tabs.

        Example:

        >>> d = Differ()
        >>> results = d._qformat('\tabcDefghiJkl\n', '\tabcdefGhijkl\n',
        ...                      '  ^ ^  ^      ', '  ^ ^  ^      ')
        >>> for line in results: print(repr(line))
        ...
        '- \tabcDefghiJkl\n'
        '? \t ^ ^  ^\n'
        '+ \tabcdefGhijkl\n'
        '? \t ^ ^  ^\n'
        """
        # Copy the lines' own whitespace into the blank tag positions and
        # drop trailing whitespace; an empty result means no '?' line.
        atags = _keep_original_ws(aline, atags).rstrip()
        btags = _keep_original_ws(bline, btags).rstrip()

        yield "- " + aline
        if atags:
            yield f"? {atags}\n"

        yield "+ " + bline
        if btags:
            yield f"? {btags}\n"
1043
+
1044
+ # With respect to junk, an earlier version of ndiff simply refused to
1045
+ # *start* a match with a junk element. The result was cases like this:
1046
+ # before: private Thread currentThread;
1047
+ # after: private volatile Thread currentThread;
1048
+ # If you consider whitespace to be junk, the longest contiguous match
1049
+ # not starting with junk is "e Thread currentThread". So ndiff reported
1050
+ # that "e volatil" was inserted between the 't' and the 'e' in "private".
1051
+ # While an accurate view, to people that's absurd. The current version
1052
+ # looks for matching blocks that are entirely junk-free, then extends the
1053
+ # longest one of those as far as possible but only with matching junk.
1054
+ # So now "currentThread" is matched, then extended to suck up the
1055
+ # preceding blank; then "private" is matched, and extended to suck up the
1056
+ # following blank; then "Thread" is matched; and finally ndiff reports
1057
+ # that "volatile " was inserted before "Thread". The only quibble
1058
+ # remaining is that perhaps it was really the case that " volatile"
1059
+ # was inserted after "private". I can live with that <wink>.
1060
+
1061
def IS_LINE_JUNK(line, pat=None):
    r"""
    Return True for ignorable line: if `line` is blank or contains a single '#'.

    When `pat` is given it is called with `line`, and the line counts as
    junk whenever that call returns something other than None.

    Examples:

    >>> IS_LINE_JUNK('\n')
    True
    >>> IS_LINE_JUNK('  #   \n')
    True
    >>> IS_LINE_JUNK('hello\n')
    False
    """

    if pat is not None:
        # Undocumented compatibility path: earlier versions accepted a
        # match function through 'pat'.
        return pat(line) is not None
    # Default rule: after stripping, the line must be empty or exactly
    # '#' -- both are substrings of '#'.
    stripped = line.strip()
    return stripped in '#'
1081
+
1082
def IS_CHARACTER_JUNK(ch, ws=" \t"):
    r"""
    Return True for ignorable character: iff `ch` is a space or tab.

    `ws` holds the characters treated as junk; the test performed is
    simply ``ch in ws``.

    Examples:

    >>> IS_CHARACTER_JUNK(' ')
    True
    >>> IS_CHARACTER_JUNK('\t')
    True
    >>> IS_CHARACTER_JUNK('\n')
    False
    >>> IS_CHARACTER_JUNK('x')
    False
    """

    is_junk = ch in ws
    return is_junk
1099
+
1100
+
1101
+ ########################################################################
1102
+ ### Unified Diff
1103
+ ########################################################################
1104
+
1105
+ def _format_range_unified(start, stop):
1106
+ 'Convert range to the "ed" format'
1107
+ # Per the diff spec at http://www.unix.org/single_unix_specification/
1108
+ beginning = start + 1 # lines start numbering with one
1109
+ length = stop - start
1110
+ if length == 1:
1111
+ return '{}'.format(beginning)
1112
+ if not length:
1113
+ beginning -= 1 # empty ranges begin at line just before the range
1114
+ return '{},{}'.format(beginning, length)
1115
+
1116
def unified_diff(a, b, fromfile='', tofile='', fromfiledate='',
                 tofiledate='', n=3, lineterm='\n'):
    r"""
    Compare two sequences of lines; generate the delta as a unified diff.

    A unified diff shows the changed lines together with a few lines of
    surrounding context.  'n' sets the number of context lines and
    defaults to three.

    The control lines (those with ---, +++, or @@) end with 'lineterm',
    a newline by default.  That way inputs taken from file.readlines()
    give a diff that can go straight to file.writelines(), since inputs
    and outputs all carry trailing newlines.

    For inputs without trailing newlines, pass lineterm="" so that the
    output is uniformly newline free.

    The header normally names both files and gives their modification
    times.  Any or all of these may be supplied as strings through
    'fromfile', 'tofile', 'fromfiledate', and 'tofiledate'.
    The modification times are normally expressed in the ISO 8601 format.

    Example:

    >>> for line in unified_diff('one two three four'.split(),
    ...             'zero one tree four'.split(), 'Original', 'Current',
    ...             '2005-01-26 23:30:50', '2010-04-02 10:20:52',
    ...             lineterm=''):
    ...     print(line)                 # doctest: +NORMALIZE_WHITESPACE
    --- Original        2005-01-26 23:30:50
    +++ Current         2010-04-02 10:20:52
    @@ -1,4 +1,4 @@
    +zero
     one
    -two
    -three
    +tree
     four
    """

    _check_types(a, b, fromfile, tofile, fromfiledate, tofiledate, lineterm)
    header_pending = True
    for group in SequenceMatcher(None,a,b).get_grouped_opcodes(n):
        if header_pending:
            # The file header is emitted once, just before the first hunk,
            # so inputs without differences produce no output at all.
            header_pending = False
            fromdate = '\t{}'.format(fromfiledate) if fromfiledate else ''
            todate = '\t{}'.format(tofiledate) if tofiledate else ''
            yield '--- {}{}{}'.format(fromfile, fromdate, lineterm)
            yield '+++ {}{}{}'.format(tofile, todate, lineterm)

        # The hunk spans from the start of its first opcode to the end of
        # its last one, on both sides.
        head, tail = group[0], group[-1]
        range_a = _format_range_unified(head[1], tail[2])
        range_b = _format_range_unified(head[3], tail[4])
        yield '@@ -{} +{} @@{}'.format(range_a, range_b, lineterm)

        for tag, i1, i2, j1, j2 in group:
            if tag == 'equal':
                for line in a[i1:i2]:
                    yield ' ' + line
            else:
                # 'replace' contributes to both halves: removals first,
                # then additions.
                if tag in {'replace', 'delete'}:
                    for line in a[i1:i2]:
                        yield '-' + line
                if tag in {'replace', 'insert'}:
                    for line in b[j1:j2]:
                        yield '+' + line
1183
+
1184
+
1185
+ ########################################################################
1186
+ ### Context Diff
1187
+ ########################################################################
1188
+
1189
+ def _format_range_context(start, stop):
1190
+ 'Convert range to the "ed" format'
1191
+ # Per the diff spec at http://www.unix.org/single_unix_specification/
1192
+ beginning = start + 1 # lines start numbering with one
1193
+ length = stop - start
1194
+ if not length:
1195
+ beginning -= 1 # empty ranges begin at line just before the range
1196
+ if length <= 1:
1197
+ return '{}'.format(beginning)
1198
+ return '{},{}'.format(beginning, beginning + length - 1)
1199
+
1200
+ # See http://www.unix.org/single_unix_specification/
1201
def context_diff(a, b, fromfile='', tofile='',
                 fromfiledate='', tofiledate='', n=3, lineterm='\n'):
    r"""
    Compare two sequences of lines; generate the delta as a context diff.

    Context diffs are a compact way of showing line changes and a few
    lines of context.  The number of context lines is set by 'n' which
    defaults to three.

    By default, the diff control lines (those with *** or ---) are
    created with a trailing newline.  This is helpful so that inputs
    created from file.readlines() result in diffs that are suitable for
    file.writelines() since both the inputs and outputs have trailing
    newlines.

    For inputs that do not have trailing newlines, set the lineterm
    argument to "" so that the output will be uniformly newline free.

    The context diff format normally has a header for filenames and
    modification times.  Any or all of these may be specified using
    strings for 'fromfile', 'tofile', 'fromfiledate', and 'tofiledate'.
    The modification times are normally expressed in the ISO 8601 format.
    If not specified, the strings default to blanks.

    Every content line starts with a two-character prefix: '+ ' for an
    inserted line, '- ' for a deleted line, '! ' for a changed line and
    two blanks for an unchanged context line.

    Example:

    >>> print(''.join(context_diff('one\ntwo\nthree\nfour\n'.splitlines(True),
    ...       'zero\none\ntree\nfour\n'.splitlines(True), 'Original', 'Current')),
    ...       end="")
    *** Original
    --- Current
    ***************
    *** 1,4 ****
      one
    ! two
    ! three
      four
    --- 1,4 ----
    + zero
      one
    ! tree
      four
    """

    _check_types(a, b, fromfile, tofile, fromfiledate, tofiledate, lineterm)
    # All four prefixes are two characters wide so that changed and
    # unchanged lines stay column-aligned in the output.
    prefix = dict(insert='+ ', delete='- ', replace='! ', equal='  ')
    started = False
    for group in SequenceMatcher(None,a,b).get_grouped_opcodes(n):
        if not started:
            # The file header is emitted once, just before the first hunk.
            started = True
            fromdate = '\t{}'.format(fromfiledate) if fromfiledate else ''
            todate = '\t{}'.format(tofiledate) if tofiledate else ''
            yield '*** {}{}{}'.format(fromfile, fromdate, lineterm)
            yield '--- {}{}{}'.format(tofile, todate, lineterm)

        first, last = group[0], group[-1]
        yield '***************' + lineterm

        file1_range = _format_range_context(first[1], last[2])
        yield '*** {} ****{}'.format(file1_range, lineterm)

        # The "from" half is listed only when the hunk removes or changes
        # something; pure insertions show just its range line.
        if any(tag in {'replace', 'delete'} for tag, _, _, _, _ in group):
            for tag, i1, i2, _, _ in group:
                if tag != 'insert':
                    for line in a[i1:i2]:
                        yield prefix[tag] + line

        file2_range = _format_range_context(first[3], last[4])
        yield '--- {} ----{}'.format(file2_range, lineterm)

        # Likewise the "to" half is listed only when the hunk adds or
        # changes something.
        if any(tag in {'replace', 'insert'} for tag, _, _, _, _ in group):
            for tag, _, _, j1, j2 in group:
                if tag != 'delete':
                    for line in b[j1:j2]:
                        yield prefix[tag] + line
1276
+
1277
+ def _check_types(a, b, *args):
1278
+ # Checking types is weird, but the alternative is garbled output when
1279
+ # someone passes mixed bytes and str to {unified,context}_diff(). E.g.
1280
+ # without this check, passing filenames as bytes results in output like
1281
+ # --- b'oldfile.txt'
1282
+ # +++ b'newfile.txt'
1283
+ # because of how str.format() incorporates bytes objects.
1284
+ if a and not isinstance(a[0], str):
1285
+ raise TypeError('lines to compare must be str, not %s (%r)' %
1286
+ (type(a[0]).__name__, a[0]))
1287
+ if b and not isinstance(b[0], str):
1288
+ raise TypeError('lines to compare must be str, not %s (%r)' %
1289
+ (type(b[0]).__name__, b[0]))
1290
+ if isinstance(a, str):
1291
+ raise TypeError('input must be a sequence of strings, not %s' %
1292
+ type(a).__name__)
1293
+ if isinstance(b, str):
1294
+ raise TypeError('input must be a sequence of strings, not %s' %
1295
+ type(b).__name__)
1296
+ for arg in args:
1297
+ if not isinstance(arg, str):
1298
+ raise TypeError('all arguments must be str, not: %r' % (arg,))
1299
+
1300
def diff_bytes(dfunc, a, b, fromfile=b'', tofile=b'',
               fromfiledate=b'', tofiledate=b'', n=3, lineterm=b'\n'):
    r"""
    Compare `a` and `b`, two sequences of lines given as bytes instead
    of str.  `dfunc` does the real work and is typically unified_diff()
    or context_diff().  Every bytes input is converted losslessly to str
    before `dfunc` sees it, and each line `dfunc` produces is encoded
    back to bytes, which makes it possible to compare files whose
    encoding is unknown or inconsistent.  All other inputs (except `n`)
    must be bytes rather than str; anything else raises TypeError.
    """
    def decode(s):
        # 'surrogateescape' keeps non-ASCII bytes recoverable on the way
        # back; an object without .decode() is not bytes.
        try:
            return s.decode('ascii', 'surrogateescape')
        except AttributeError as err:
            msg = ('all arguments must be bytes, not %s (%r)' %
                   (type(s).__name__, s))
            raise TypeError(msg) from err

    text_a = [decode(raw) for raw in a]
    text_b = [decode(raw) for raw in b]
    fromfile = decode(fromfile)
    tofile = decode(tofile)
    fromfiledate = decode(fromfiledate)
    tofiledate = decode(tofiledate)
    lineterm = decode(lineterm)

    produced = dfunc(text_a, text_b, fromfile, tofile,
                     fromfiledate, tofiledate, n, lineterm)
    for text_line in produced:
        yield text_line.encode('ascii', 'surrogateescape')
1329
+
1330
def ndiff(a, b, linejunk=None, charjunk=IS_CHARACTER_JUNK):
    r"""
    Compare `a` and `b` (lists of strings); return a `Differ`-style delta.

    Optional keyword parameters `linejunk` and `charjunk` are filter
    functions, or None:

    - linejunk: A function taking a single string argument and returning
      true iff the string is junk.  The default, None, is recommended;
      the underlying SequenceMatcher class has an adaptive notion of
      "noise" lines.

    - charjunk: A function taking a character (string of length 1) and
      returning true iff the character is junk.  The default is the
      module-level function IS_CHARACTER_JUNK, which filters out
      whitespace characters (a blank or tab; note: it's a bad idea to
      include newline in this!).

    Tools/scripts/ndiff.py is a command-line front-end to this function.

    Example:

    >>> diff = ndiff('one\ntwo\nthree\n'.splitlines(keepends=True),
    ...              'ore\ntree\nemu\n'.splitlines(keepends=True))
    >>> print(''.join(diff), end="")
    - one
    ?  ^
    + ore
    ?  ^
    - two
    - three
    ?  -
    + tree
    + emu
    """
    differ = Differ(linejunk, charjunk)
    return differ.compare(a, b)
1366
+
1367
def _mdiff(fromlines, tolines, context=None, linejunk=None,
           charjunk=IS_CHARACTER_JUNK):
    r"""Returns generator yielding marked up from/to side by side differences.

    Arguments:
    fromlines -- list of text lines to be compared to tolines
    tolines -- list of text lines to be compared to fromlines
    context -- number of context lines to display on each side of difference,
               if None, all from/to text lines will be generated.
    linejunk -- passed on to ndiff (see ndiff documentation)
    charjunk -- passed on to ndiff (see ndiff documentation)

    This function returns an iterator which returns a tuple:
    (from line tuple, to line tuple, boolean flag)

    from/to line tuple -- (line num, line text)
        line num -- integer or None (to indicate a context separation)
        line text -- original line text with following markers inserted:
            '\0+' -- marks start of added text
            '\0-' -- marks start of deleted text
            '\0^' -- marks start of changed text
            '\1' -- marks end of added/deleted/changed text

    boolean flag -- None indicates context separation, True indicates
        either "from" or "to" line contains a change, otherwise False.

    The HtmlDiff class consumes this iterator to build its side by side
    tables.  The markup is derived from the output of ndiff(); the optional
    ndiff arguments given here are forwarded to it unchanged.
    """
    import re
    from collections import deque

    # regular expression for finding intraline change indices
    change_re = re.compile(r'(\++|\-+|\^+)')

    # create the difference iterator to generate the differences
    diff_lines_iterator = ndiff(fromlines, tolines, linejunk, charjunk)

    # Current from/to line numbers, shared by every _make_line() call made
    # on behalf of this particular _mdiff() invocation (0=from, 1=to).
    num_lines = [0, 0]

    def _make_line(lines, format_key, side):
        """Returns line of text with user's change markup and line formatting.

        lines -- list of lines from the ndiff generator to produce a line of
                 text from.  When producing the line of text to return, the
                 lines used are removed from this list.
        format_key -- '+' return first line in list with "add" markup around
                          the entire line.
                      '-' return first line in list with "delete" markup around
                          the entire line.
                      '?' return first line in list with add/delete/change
                          intraline markup (indices obtained from second line)
                      None return first line in list with no markup
        side -- index into the num_lines list (0=from,1=to)

        The running line numbers live in the enclosing function's
        ``num_lines`` list, so they persist between calls without being
        visible at module scope.
        """
        num_lines[side] += 1
        # Handle case where no user markup is to be added, just return line of
        # text with user's line format to allow for usage of the line number.
        if format_key is None:
            return (num_lines[side], lines.pop(0)[2:])
        # Handle case of intraline changes
        if format_key == '?':
            text, markers = lines.pop(0), lines.pop(0)
            # find intraline changes (store change type and indices in tuples)
            sub_info = []

            def record_sub_info(match_object):
                sub_info.append([match_object.group(1)[0],
                                 match_object.span()])
                return match_object.group(1)

            # sub() is used only to visit every match; its result is unused.
            change_re.sub(record_sub_info, markers)
            # process each tuple inserting our special marks that won't be
            # noticed by an xml/html escaper.  Working right-to-left keeps
            # the earlier indices valid while text grows.
            for key, (begin, end) in reversed(sub_info):
                text = (text[0:begin] + '\0' + key + text[begin:end]
                        + '\1' + text[end:])
            text = text[2:]
        # Handle case of add/delete entire line
        else:
            text = lines.pop(0)[2:]
            # if line of text is just a newline, insert a space so there is
            # something for the user to highlight and see.
            if not text:
                text = ' '
            # insert marks that won't be noticed by an xml/html escaper.
            text = '\0' + format_key + text + '\1'
        return (num_lines[side], text)

    def _line_iterator():
        """Yields from/to lines of text with a change indication.

        This function is an iterator.  It itself pulls lines from a
        differencing iterator, processes them and yields them.  When it can
        it yields both a "from" and a "to" line, otherwise it will yield one
        or the other.  In addition to yielding the lines of from/to text, a
        boolean flag is yielded to indicate if the text line(s) have
        differences in them.
        """
        lines = []
        num_blanks_pending, num_blanks_to_yield = 0, 0
        while True:
            # Load up next 4 lines so we can look ahead, create strings which
            # are a concatenation of the first character of each of the 4 lines
            # so we can do some very readable comparisons.
            while len(lines) < 4:
                lines.append(next(diff_lines_iterator, 'X'))
            s = ''.join([line[0] for line in lines])
            if s.startswith('X'):
                # When no more lines, pump out any remaining blank lines so the
                # corresponding add/delete lines get a matching blank line so
                # all line pairs get yielded at the next level.
                num_blanks_to_yield = num_blanks_pending
            elif s.startswith('-?+?'):
                # simple intraline change
                yield _make_line(lines, '?', 0), _make_line(lines, '?', 1), True
                continue
            elif s.startswith('--++'):
                # in delete block, add block coming: we do NOT want to get
                # caught up on blank lines yet, just process the delete line
                num_blanks_pending -= 1
                yield _make_line(lines, '-', 0), None, True
                continue
            elif s.startswith(('--?+', '--+', '- ')):
                # in delete block and see an intraline change or unchanged line
                # coming: yield the delete line and then blanks
                from_line, to_line = _make_line(lines, '-', 0), None
                num_blanks_to_yield, num_blanks_pending = num_blanks_pending-1, 0
            elif s.startswith('-+?'):
                # intraline change
                yield _make_line(lines, None, 0), _make_line(lines, '?', 1), True
                continue
            elif s.startswith('-?+'):
                # intraline change
                yield _make_line(lines, '?', 0), _make_line(lines, None, 1), True
                continue
            elif s.startswith('-'):
                # delete FROM line
                num_blanks_pending -= 1
                yield _make_line(lines, '-', 0), None, True
                continue
            elif s.startswith('+--'):
                # in add block, delete block coming: we do NOT want to get
                # caught up on blank lines yet, just process the add line
                num_blanks_pending += 1
                yield None, _make_line(lines, '+', 1), True
                continue
            elif s.startswith(('+ ', '+-')):
                # will be leaving an add block: yield blanks then add line
                from_line, to_line = None, _make_line(lines, '+', 1)
                num_blanks_to_yield, num_blanks_pending = num_blanks_pending+1, 0
            elif s.startswith('+'):
                # inside an add block, yield the add line
                num_blanks_pending += 1
                yield None, _make_line(lines, '+', 1), True
                continue
            elif s.startswith(' '):
                # unchanged text, yield it to both sides; the "from" call
                # works on a copy so the line is only consumed once.
                yield (_make_line(lines[:], None, 0),
                       _make_line(lines, None, 1), False)
                continue
            # Catch up on the blank lines so when we yield the next from/to
            # pair, they are lined up.
            while num_blanks_to_yield < 0:
                num_blanks_to_yield += 1
                yield None, ('', '\n'), True
            while num_blanks_to_yield > 0:
                num_blanks_to_yield -= 1
                yield ('', '\n'), None, True
            if s.startswith('X'):
                return
            else:
                yield from_line, to_line, True

    def _line_pair_iterator():
        """Yields from/to lines of text with a change indication.

        This function is an iterator.  It itself pulls lines from the line
        iterator.  Its difference from that iterator is that this function
        always yields a pair of from/to text lines (with the change
        indication).  If necessary it will collect single from/to lines
        until it has a matching from/to pair to yield.
        """
        line_iterator = _line_iterator()
        # deques give O(1) removal from the front; a long run of one-sided
        # lines can make these queues grow before the other side catches up.
        from_queue, to_queue = deque(), deque()
        while True:
            # Collecting lines of text until we have a from/to pair
            while not from_queue or not to_queue:
                try:
                    from_line, to_line, found_diff = next(line_iterator)
                except StopIteration:
                    return
                if from_line is not None:
                    from_queue.append((from_line, found_diff))
                if to_line is not None:
                    to_queue.append((to_line, found_diff))
            # Once we have a pair, remove them from the collection and yield it
            from_line, from_diff = from_queue.popleft()
            to_line, to_diff = to_queue.popleft()
            yield (from_line, to_line, from_diff or to_diff)

    # Handle case where user does not want context differencing, just yield
    # them up without doing anything else with them.
    line_pair_iterator = _line_pair_iterator()
    if context is None:
        yield from line_pair_iterator
    # Handle case where user wants context differencing.  We must do some
    # storage of lines until we know for sure that they are to be yielded.
    else:
        context += 1
        lines_to_write = 0
        while True:
            # Store lines up until we find a difference, note use of a
            # circular queue because we only need to keep around what
            # we need for context.
            index, context_lines = 0, [None] * context
            found_diff = False
            while found_diff is False:
                try:
                    from_line, to_line, found_diff = next(line_pair_iterator)
                except StopIteration:
                    return
                i = index % context
                context_lines[i] = (from_line, to_line, found_diff)
                index += 1
            # Yield lines that we have collected so far, but first yield
            # the user's separator.
            if index > context:
                yield None, None, None
                lines_to_write = context
            else:
                lines_to_write = index
                index = 0
            while lines_to_write:
                i = index % context
                index += 1
                yield context_lines[i]
                lines_to_write -= 1
            # Now yield the context lines after the change
            lines_to_write = context - 1
            try:
                while lines_to_write:
                    from_line, to_line, found_diff = next(line_pair_iterator)
                    # If another change within the context, extend the context
                    if found_diff:
                        lines_to_write = context - 1
                    else:
                        lines_to_write -= 1
                    yield from_line, to_line, found_diff
            except StopIteration:
                # Catch exception from next() and return normally
                return
1635
+
1636
+
1637
# Skeleton of a complete HTML page.  HtmlDiff.make_file() fills it with
# %-style mapping substitution using the keys charset, styles, table and
# legend.
+ _file_template = """
1638
+ <!DOCTYPE html>
1639
+ <html lang="en">
1640
+ <head>
1641
+ <meta charset="%(charset)s">
1642
+ <meta name="viewport" content="width=device-width, initial-scale=1">
1643
+ <title>Diff comparison</title>
1644
+ <style>%(styles)s
1645
+ </style>
1646
+ </head>
1647
+
1648
+ <body>
1649
+ %(table)s%(legend)s
1650
+ </body>
1651
+
1652
+ </html>"""
1653
+
1654
# CSS substituted into the page's <style> element.  It styles the classes
# used by the generated markup (diff_header, diff_next, diff_add, diff_chg,
# diff_sub) and supplies alternative colours under a dark colour scheme.
+ _styles = """
1655
+ :root {color-scheme: light dark}
1656
+ table.diff {
1657
+ font-family: Menlo, Consolas, Monaco, Liberation Mono, Lucida Console, monospace;
1658
+ border: medium;
1659
+ }
1660
+ .diff_header {
1661
+ background-color: #e0e0e0;
1662
+ font-weight: bold;
1663
+ }
1664
+ td.diff_header {
1665
+ text-align: right;
1666
+ padding: 0 8px;
1667
+ }
1668
+ .diff_next {
1669
+ background-color: #c0c0c0;
1670
+ padding: 4px 0;
1671
+ }
1672
+ .diff_add {background-color:palegreen}
1673
+ .diff_chg {background-color:#ffff77}
1674
+ .diff_sub {background-color:#ffaaaa}
1675
+ table.diff[summary="Legends"] {
1676
+ margin-top: 20px;
1677
+ border: 1px solid #ccc;
1678
+ }
1679
+ table.diff[summary="Legends"] th {
1680
+ background-color: #e0e0e0;
1681
+ padding: 4px 8px;
1682
+ }
1683
+ table.diff[summary="Legends"] td {
1684
+ padding: 4px 8px;
1685
+ }
1686
+
1687
+ @media (prefers-color-scheme: dark) {
1688
+ .diff_header {background-color:#666}
1689
+ .diff_next {background-color:#393939}
1690
+ .diff_add {background-color:darkgreen}
1691
+ .diff_chg {background-color:#847415}
1692
+ .diff_sub {background-color:darkred}
1693
+ table.diff[summary="Legends"] {border-color:#555}
1694
+ table.diff[summary="Legends"] th{background-color:#666}
1695
+ }"""
1696
+
1697
# Markup of one side by side comparison table.  HtmlDiff.make_table() fills
# it with the keys prefix (anchor prefix used in the table id), header_row
# and data_rows.
+ _table_template = """
1698
+ <table class="diff" id="difflib_chg_%(prefix)s_top"
1699
+ cellspacing="0" cellpadding="0" rules="groups" >
1700
+ <colgroup></colgroup> <colgroup></colgroup> <colgroup></colgroup>
1701
+ <colgroup></colgroup> <colgroup></colgroup> <colgroup></colgroup>
1702
+ %(header_row)s
1703
+ <tbody>
1704
+ %(data_rows)s </tbody>
1705
+ </table>"""
1706
+
1707
# Static legend table (no substitution keys) describing the highlight
# colours and the f/n/t navigation links; HtmlDiff.make_file() places it
# after the comparison table.
+ _legend = """
1708
+ <table class="diff" summary="Legends">
1709
+ <tr> <th colspan="2"> Legends </th> </tr>
1710
+ <tr> <td> <table border="" summary="Colors">
1711
+ <tr><th> Colors </th> </tr>
1712
+ <tr><td class="diff_add">&nbsp;Added&nbsp;</td></tr>
1713
+ <tr><td class="diff_chg">Changed</td> </tr>
1714
+ <tr><td class="diff_sub">Deleted</td> </tr>
1715
+ </table></td>
1716
+ <td> <table border="" summary="Links">
1717
+ <tr><th colspan="2"> Links </th> </tr>
1718
+ <tr><td>(f)irst change</td> </tr>
1719
+ <tr><td>(n)ext change</td> </tr>
1720
+ <tr><td>(t)op</td> </tr>
1721
+ </table></td> </tr>
1722
+ </table>"""
1723
+
1724
class HtmlDiff(object):
    """For producing HTML side by side comparison with change highlights.

    This class can be used to create an HTML table (or a complete HTML file
    containing the table) showing a side by side, line by line comparison
    of text with inter-line and intra-line change highlights.  The table can
    be generated in either full or contextual difference mode.

    The following methods are provided for HTML generation:

    make_table -- generates HTML for a single side by side table
    make_file -- generates complete HTML file with a single side by side table
    """

    # Templates are class attributes so that a subclass can override them.
    _file_template = _file_template
    _styles = _styles
    _table_template = _table_template
    _legend = _legend
    # Counter shared by all instances; gives every table unique anchor names.
    _default_prefix = 0

    def __init__(self, tabsize=8, wrapcolumn=None, linejunk=None,
                 charjunk=IS_CHARACTER_JUNK):
        """HtmlDiff instance initializer

        Arguments:
        tabsize -- tab stop spacing, defaults to 8.
        wrapcolumn -- column number where lines are broken and wrapped,
            defaults to None where lines are not wrapped.
        linejunk,charjunk -- keyword arguments passed on to the differencing
            iterator used to generate the side by side HTML differences.
        """
        self._tabsize = tabsize
        self._wrapcolumn = wrapcolumn
        self._linejunk = linejunk
        self._charjunk = charjunk

    def make_file(self, fromlines, tolines, fromdesc='', todesc='',
                  context=False, numlines=5, *, charset='utf-8'):
        """Returns HTML file of side by side comparison with change highlights

        Arguments:
        fromlines -- list of "from" lines
        tolines -- list of "to" lines
        fromdesc -- "from" file column header string
        todesc -- "to" file column header string
        context -- set to True for contextual differences (defaults to False
            which shows full differences).
        numlines -- number of context lines.  When context is set True,
            controls number of lines displayed before and after the change.
            When context is False, controls the number of lines to place
            the "next" link anchors before the next change (so click of
            "next" link jumps to just before the change).
        charset -- charset of the HTML document

        Characters that cannot be encoded in `charset` are replaced by XML
        character references in the returned string.
        """
        page = self._file_template % dict(
            styles=self._styles,
            legend=self._legend,
            table=self.make_table(fromlines, tolines, fromdesc, todesc,
                                  context=context, numlines=numlines),
            charset=charset
        )
        return page.encode(charset, 'xmlcharrefreplace').decode(charset)

    def _tab_newline_replace(self, fromlines, tolines):
        """Returns from/to line lists with tabs expanded and newlines removed.

        Instead of tab characters being replaced by the number of spaces
        needed to fill in to the next tab stop, this function will fill
        the space with tab characters.  This is done so that the difference
        algorithms can identify changes in a file when tabs are replaced by
        spaces and vice versa.  At the end of the HTML generation, the tab
        characters will be replaced with a nonbreakable space.
        """
        def expand_tabs(line):
            # hide real spaces
            line = line.replace(' ', '\0')
            # expand tabs into spaces
            line = line.expandtabs(self._tabsize)
            # replace spaces from expanded tabs back into tab characters
            # (we'll replace them with markup after we do differencing)
            line = line.replace(' ', '\t')
            return line.replace('\0', ' ').rstrip('\n')
        fromlines = [expand_tabs(line) for line in fromlines]
        tolines = [expand_tabs(line) for line in tolines]
        return fromlines, tolines

    def _split_line(self, data_list, line_num, text):
        """Builds list of text lines by splitting text lines at wrap point

        Appends (line_num, text) tuples to data_list.  If the text is wider
        than the wrap column it is cut into several pieces; the first piece
        keeps line_num and every continuation piece is numbered '>'.

        The pieces are produced in a loop rather than by recursion so that
        a very long line combined with a small wrap column cannot exceed
        the interpreter's recursion limit.

        Raises ValueError if the wrap column is too small for any
        character to fit on a line.
        """
        wrap_at = self._wrapcolumn
        while True:
            # if blank line or context separator, just add it to the output
            # list
            if not line_num:
                data_list.append((line_num, text))
                return

            # if line text doesn't need wrapping, just add it to the output
            # list (each marked region costs three extra characters)
            size = len(text)
            if (size <= wrap_at) or ((size - (text.count('\0') * 3)) <= wrap_at):
                data_list.append((line_num, text))
                return

            # scan text looking for the wrap point, keeping track if the wrap
            # point is inside markers
            i = 0
            n = 0
            mark = ''
            while n < wrap_at and i < size:
                if text[i] == '\0':
                    i += 1
                    mark = text[i]
                    i += 1
                elif text[i] == '\1':
                    i += 1
                    mark = ''
                else:
                    i += 1
                    n += 1

            # No character was consumed: splitting would never terminate.
            if i == 0:
                raise ValueError('wrapcolumn must be a positive number')

            # wrap point is inside text, break it up into separate lines
            line1 = text[:i]
            line2 = text[i:]

            # if wrap point is inside markers, place end marker at end of
            # first line and start marker at beginning of second line because
            # each line will have its own table tag markup around it.
            if mark:
                line1 = line1 + '\1'
                line2 = '\0' + mark + line2

            # tack on first line onto the output list
            data_list.append((line_num, line1))

            # go around again to wrap the remaining text
            line_num, text = '>', line2

    def _line_wrapper(self, diffs):
        """Returns iterator that splits (wraps) mdiff text lines"""

        # pull from/to data and flags from mdiff iterator
        for fromdata, todata, flag in diffs:
            # check for context separators and pass them through
            if flag is None:
                yield fromdata, todata, flag
                continue
            (fromline, fromtext), (toline, totext) = fromdata, todata
            # for each from/to line split it at the wrap column to form
            # list of text lines.
            fromlist, tolist = [], []
            self._split_line(fromlist, fromline, fromtext)
            self._split_line(tolist, toline, totext)
            # yield from/to line in pairs inserting blank lines as
            # necessary when one side has more wrapped lines
            while fromlist or tolist:
                if fromlist:
                    fromdata = fromlist.pop(0)
                else:
                    fromdata = ('', ' ')
                if tolist:
                    todata = tolist.pop(0)
                else:
                    todata = ('', ' ')
                yield fromdata, todata, flag

    def _collect_lines(self, diffs):
        """Collects mdiff output into separate lists

        Before storing the mdiff from/to data into a list, it is converted
        into a single line of text with HTML markup.  Context separators
        (flag is None) are stored as None in both line lists.
        """

        fromlist, tolist, flaglist = [], [], []
        # pull from/to data and flags from mdiff style iterator
        for fromdata, todata, flag in diffs:
            if flag is None:
                # context separator: there is no line data to format
                fromlist.append(None)
                tolist.append(None)
            else:
                # store HTML markup of the lines into the lists
                fromlist.append(self._format_line(0, flag, *fromdata))
                tolist.append(self._format_line(1, flag, *todata))
            flaglist.append(flag)
        return fromlist, tolist, flaglist

    def _format_line(self, side, flag, linenum, text):
        """Returns HTML markup of "from" / "to" text lines

        side -- 0 or 1 indicating "from" or "to" text
        flag -- indicates if difference on line
        linenum -- line number (used for line number column)
        text -- line text to be marked up
        """
        try:
            linenum = '%d' % linenum
            id_attr = ' id="%s%s"' % (self._prefix[side], linenum)
        except TypeError:
            # handle blank lines where linenum is '>' or ''
            id_attr = ''
        # replace those things that would get confused with HTML symbols
        text = text.replace("&", "&amp;").replace(">", "&gt;").replace("<", "&lt;")

        # make space non-breakable so they don't get compressed or line wrapped
        text = text.replace(' ', '&nbsp;').rstrip()

        return '<td class="diff_header"%s>%s</td><td nowrap="nowrap">%s</td>' \
               % (id_attr, linenum, text)

    def _make_prefix(self):
        """Create unique anchor prefixes"""

        # Generate a unique anchor prefix so multiple tables
        # can exist on the same HTML page without conflicts.
        fromprefix = "from%d_" % HtmlDiff._default_prefix
        toprefix = "to%d_" % HtmlDiff._default_prefix
        HtmlDiff._default_prefix += 1
        # store prefixes so line format method has access
        self._prefix = [fromprefix, toprefix]

    def _convert_flags(self, fromlist, tolist, flaglist, context, numlines):
        """Makes list of "next" links

        Returns the tuple (fromlist, tolist, flaglist, next_href, next_id).
        When flaglist is empty the three lists are replaced by a single
        placeholder row.
        """

        # all anchor names will be generated using the unique "to" prefix
        toprefix = self._prefix[1]

        # process change flags, generating middle column of next anchors/links
        next_id = [''] * len(flaglist)
        next_href = [''] * len(flaglist)
        num_chg, in_change = 0, False
        last = 0
        for row, flag in enumerate(flaglist):
            if not flag:
                in_change = False
                continue
            if in_change:
                continue
            in_change = True
            last = row
            # at the beginning of a change, drop an anchor a few lines
            # (the context lines) before the change for the previous
            # link
            anchor_row = max(0, row - numlines)
            next_id[anchor_row] = ' id="difflib_chg_%s_%d"' % (toprefix, num_chg)
            # at the beginning of a change, drop a link to the next
            # change
            num_chg += 1
            next_href[last] = '<a href="#difflib_chg_%s_%d">n</a>' % (
                toprefix, num_chg)
        # check for cases where there is no content to avoid exceptions
        if not flaglist:
            flaglist = [False]
            next_id = ['']
            next_href = ['']
            last = 0
            if context:
                fromlist = ['<td></td><td>&nbsp;No Differences Found&nbsp;</td>']
                tolist = fromlist
            else:
                fromlist = tolist = ['<td></td><td>&nbsp;Empty File&nbsp;</td>']
        # if not a change on first line, drop a link
        if not flaglist[0]:
            next_href[0] = '<a href="#difflib_chg_%s_0">f</a>' % toprefix
        # redo the last link to link to the top
        next_href[last] = '<a href="#difflib_chg_%s_top">t</a>' % (toprefix)

        return fromlist, tolist, flaglist, next_href, next_id

    def make_table(self, fromlines, tolines, fromdesc='', todesc='',
                   context=False, numlines=5):
        """Returns HTML table of side by side comparison with change highlights

        Arguments:
        fromlines -- list of "from" lines
        tolines -- list of "to" lines
        fromdesc -- "from" file column header string
        todesc -- "to" file column header string
        context -- set to True for contextual differences (defaults to False
            which shows full differences).
        numlines -- number of context lines.  When context is set True,
            controls number of lines displayed before and after the change.
            When context is False, controls the number of lines to place
            the "next" link anchors before the next change (so click of
            "next" link jumps to just before the change).

        Note: fromdesc and todesc are inserted into the header row as given,
        without HTML escaping; escape them first if they may contain
        untrusted text.
        """

        # make unique anchor prefixes so that multiple tables may exist
        # on the same page without conflict.
        self._make_prefix()

        # change tabs to spaces before it gets more difficult after we insert
        # markup
        fromlines, tolines = self._tab_newline_replace(fromlines, tolines)

        # create diffs iterator which generates side by side from/to data
        if context:
            context_lines = numlines
        else:
            context_lines = None
        diffs = _mdiff(fromlines, tolines, context_lines,
                       linejunk=self._linejunk, charjunk=self._charjunk)

        # set up iterator to wrap lines that exceed desired width
        if self._wrapcolumn:
            diffs = self._line_wrapper(diffs)

        # collect up from/to lines and flags into lists (also format the lines)
        fromlist, tolist, flaglist = self._collect_lines(diffs)

        # process change flags, generating middle column of next anchors/links
        fromlist, tolist, flaglist, next_href, next_id = self._convert_flags(
            fromlist, tolist, flaglist, context, numlines)

        rows = []
        fmt = ' <tr><td class="diff_next"%s>%s</td>%s' + \
              '<td class="diff_next">%s</td>%s</tr>\n'
        for i, flag in enumerate(flaglist):
            if flag is None:
                # mdiff yields None on separator lines skip the bogus ones
                # generated for the first line
                if i > 0:
                    rows.append(' </tbody> \n <tbody>\n')
            else:
                rows.append(fmt % (next_id[i], next_href[i], fromlist[i],
                                   next_href[i], tolist[i]))
        if fromdesc or todesc:
            header_row = '<thead><tr>%s%s%s%s</tr></thead>' % (
                '<th class="diff_next"><br /></th>',
                '<th colspan="2" class="diff_header">%s</th>' % fromdesc,
                '<th class="diff_next"><br /></th>',
                '<th colspan="2" class="diff_header">%s</th>' % todesc)
        else:
            header_row = ''

        table = self._table_template % dict(
            data_rows=''.join(rows),
            header_row=header_row,
            prefix=self._prefix[1])

        # turn the internal change markers and tab placeholders into HTML
        return table.replace('\0+', '<span class="diff_add">'). \
                     replace('\0-', '<span class="diff_sub">'). \
                     replace('\0^', '<span class="diff_chg">'). \
                     replace('\1', '</span>'). \
                     replace('\t', '&nbsp;')
2074
+
2075
+
2076
def restore(delta, which):
    r"""
    Generate one of the two sequences that generated a delta.

    Given a `delta` produced by `Differ.compare()` or `ndiff()`, extract
    lines originating from file 1 or 2 (parameter `which`), stripping off line
    prefixes.

    Lines common to both files (two-space prefix) are always included; of
    the remaining lines only those carrying the prefix of the requested
    file ("- " for 1, "+ " for 2) are kept.

    Raises ValueError if `which` is not 1 or 2.

    Examples:

    >>> diff = ndiff('one\ntwo\nthree\n'.splitlines(keepends=True),
    ...              'ore\ntree\nemu\n'.splitlines(keepends=True))
    >>> diff = list(diff)
    >>> print(''.join(restore(diff, 1)), end="")
    one
    two
    three
    >>> print(''.join(restore(diff, 2)), end="")
    ore
    tree
    emu
    """
    try:
        tag = {1: "- ", 2: "+ "}[int(which)]
    except KeyError:
        raise ValueError('unknown delta choice (must be 1 or 2): %r'
                         % which) from None
    # Every prefix is exactly two characters wide, matching the two
    # characters sliced off each line below.
    prefixes = ("  ", tag)
    for line in delta:
        if line[:2] in prefixes:
            yield line[2:]
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/formfill.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lxml.etree import XPath, ElementBase
2
+ from lxml.html import fromstring, XHTML_NAMESPACE
3
+ from lxml.html import _forms_xpath, _options_xpath, _nons, _transform_result
4
+ from lxml.html import defs
5
+ import copy
6
+
7
# Compatibility shim: if the name ``basestring`` is not defined (the lookup
# raises NameError), bind it to ``str`` so the isinstance() checks in this
# module keep working.
+ try:
8
+ basestring
9
+ except NameError:
10
+ # Python 3
11
+ basestring = str
12
+
13
# Public names of this module.  NOTE(review): insert_errors,
# insert_errors_html and DefaultErrorCreator are not defined in the part of
# the file shown here; presumably they appear further down -- confirm.
+ __all__ = ['FormNotFound', 'fill_form', 'fill_form_html',
14
+ 'insert_errors', 'insert_errors_html',
15
+ 'DefaultErrorCreator']
16
+
17
class FormNotFound(LookupError):
    """Error raised when a form lookup does not find any matching form."""
21
+
22
# Finds <form> elements (plain or XHTML-namespaced) whose ``name``
# *attribute* equals $name.  The attribute axis (``@name``) is required:
# a bare ``name`` would test a child element called <name> instead.
_form_name_xpath = XPath(
    'descendant-or-self::form[@name=$name]'
    '|descendant-or-self::x:form[@name=$name]',
    namespaces={'x': XHTML_NAMESPACE})
# Finds every fillable control (input, select, textarea), with or without
# the XHTML namespace, at or below the context node.
_input_xpath = XPath(
    '|'.join(['descendant-or-self::' + _tag
              for _tag in ('input', 'select', 'textarea',
                           'x:input', 'x:select', 'x:textarea')]),
    namespaces={'x': XHTML_NAMESPACE})
# Finds <label> elements anywhere in the document whose ``for`` attribute
# equals $for_id.
_label_for_xpath = XPath('//label[@for=$for_id]|//x:label[@for=$for_id]',
                         namespaces={'x': XHTML_NAMESPACE})
# Finds any element at or below the context node whose ``name`` attribute
# equals $name.
_name_xpath = XPath('descendant-or-self::*[@name=$name]')
28
+
29
def fill_form(el, values, form_id=None, form_index=None):
    """Fill in the controls of one form found in `el`, modifying it in place.

    el -- element to search for the form
    values -- mapping of control names to the values to fill in
    form_id, form_index -- passed on to the form lookup to choose which
        form is filled
    """
    form = _find_form(el, form_id=form_id, form_index=form_index)
    _fill_form(form, values)
37
+
38
def fill_form_html(html, values, form_id=None, form_index=None):
    """Fill in a form and return the result in the same kind as `html`.

    A string is parsed into a new document; any other input is deep-copied,
    so the object passed in is not modified.  The filled document is handed
    to _transform_result() together with the type of the original input.
    """
    result_type = type(html)
    doc = fromstring(html) if isinstance(html, basestring) else copy.deepcopy(html)
    fill_form(doc, values, form_id=form_id, form_index=form_index)
    return _transform_result(result_type, doc)
46
+
47
def _fill_form(el, values):
    """Fill every named control below `el` from the `values` mapping.

    Controls that take multiple values (see _takes_multiple) are always
    processed, with an empty list when their name is missing from `values`.
    Other controls are skipped unless their name is present.  When several
    single-valued controls share a name and the value is a list or tuple,
    the n-th control receives the n-th item; with a scalar value only the
    first control is filled.
    """
    if hasattr(values, 'mixed'):
        # For Paste request parameters
        values = values.mixed()
    seen = {}
    for field in _input_xpath(el):
        name = field.get('name')
        if not name:
            continue
        if _takes_multiple(field):
            chosen = values.get(name, [])
            if not isinstance(chosen, (list, tuple)):
                chosen = [chosen]
            _fill_multiple(field, chosen)
            continue
        if name not in values:
            continue
        # position of this control among the same-named ones seen so far
        position = seen.get(name, 0)
        seen[name] = position + 1
        value = values[name]
        if isinstance(value, (list, tuple)):
            if position >= len(value):
                # more controls than supplied values: leave the rest alone
                continue
            value = value[position]
        elif position > 0:
            continue
        _fill_single(field, value)
76
+
77
def _takes_multiple(input):
    """Return True if the control is filled from a list of values.

    That is the case for a <select> carrying a ``multiple`` attribute and
    for inputs of type ``radio`` or ``checkbox``.
    """
    multi_select = _nons(input.tag) == 'select' and input.get('multiple')
    if multi_select:
        # FIXME: multiple="0"?
        return True
    return input.get('type', '').lower() in ('radio', 'checkbox')
85
+
86
def _fill_multiple(input, value):
    """Set the checked/selected state of a multi-valued control.

    ``value`` is the list (or tuple) of submitted values for the control's
    name.

    * checkbox with a ``value`` attribute: checked iff that value is listed.
    * checkbox without a ``value`` attribute: unchecked for an empty list;
      otherwise decided by the first item - a string checks the box only if
      it is ``'on'``, any other object is used for its truthiness.
    * radio: checked iff its ``value`` attribute is listed.
    * anything else must be a <select>; each option is selected iff its
      ``value`` (or, lacking one, its text content) is listed.
    """
    input_type = input.get('type', '').lower()
    if input_type == 'checkbox':
        v = input.get('value')
        if v is None:
            if not value:
                result = False
            else:
                result = value[0]
                # Test the extracted item, not the list: the list is never
                # a string, so the 'on' rule previously never applied.
                if isinstance(result, basestring):
                    # The only valid "on" value for an unnamed checkbox is 'on'
                    result = result == 'on'
            _check(input, result)
        else:
            _check(input, v in value)
    elif input_type == 'radio':
        v = input.get('value')
        _check(input, v in value)
    else:
        assert _nons(input.tag) == 'select'
        for option in _options_xpath(input):
            v = option.get('value')
            if v is None:
                # This seems to be the default, at least on IE
                # FIXME: but I'm not sure
                v = option.text_content()
            _select(option, v in value)
113
+
114
+ def _check(el, check):
115
+ if check:
116
+ el.set('checked', '')
117
+ else:
118
+ if 'checked' in el.attrib:
119
+ del el.attrib['checked']
120
+
121
+ def _select(el, select):
122
+ if select:
123
+ el.set('selected', '')
124
+ else:
125
+ if 'selected' in el.attrib:
126
+ del el.attrib['selected']
127
+
128
def _fill_single(input, value):
    """Store ``value`` in a single-valued control.

    A <textarea> keeps its value as element text; every other control gets
    it in the ``value`` attribute.
    """
    if _nons(input.tag) != 'textarea':
        input.set('value', value)
        return
    input.text = value
133
+
134
def _find_form(el, form_id=None, form_index=None):
    """Return the form inside ``el`` selected by id/name or by position.

    * Neither argument given: the first form found by ``_forms_xpath``.
    * ``form_id`` given: the element with that id, else the first form
      whose name matches (via ``_form_name_xpath``).
    * Otherwise ``form_index`` is used to index the list of forms.

    Raises FormNotFound when nothing matches.
    """
    if form_id is None and form_index is None:
        forms = _forms_xpath(el)
        if forms:
            return forms[0]
        raise FormNotFound(
            "No forms in page")
    if form_id is not None:
        # Pass a default: without one, get_element_by_id raises KeyError for
        # an unknown id and the lookup by name below is never attempted.
        form = el.get_element_by_id(form_id, None)
        if form is not None:
            return form
        forms = _form_name_xpath(el, name=form_id)
        if forms:
            return forms[0]
        # Report the requested id (the message used to format builtin ``id``).
        raise FormNotFound(
            "No form with the name or id of %r (forms: %s)"
            % (form_id, ', '.join(_find_form_ids(el))))
    # Only form_index is set at this point.
    forms = _forms_xpath(el)
    try:
        return forms[form_index]
    except IndexError:
        raise FormNotFound(
            "There is no form with the index %r (%i forms found)"
            % (form_index, len(forms)))
160
+
161
def _find_form_ids(el):
    """Yield a human-readable label for each form found below ``el``.

    The label is "<id> or <name>", the id alone, the name alone, or
    "(unnamed form <index>)".  When there are no forms at all the single
    label "(no forms)" is yielded.
    """
    forms = _forms_xpath(el)
    if not forms:
        yield '(no forms)'
        return
    for index, form in enumerate(forms):
        form_id = form.get('id')
        form_name = form.get('name')
        if form_id and form_name:
            yield '%s or %s' % (form_id, form_name)
        elif form_id:
            yield form_id
        elif form_name:
            yield form_name
        else:
            yield '(unnamed form %s)' % index
177
+
178
+ ############################################################
179
+ ## Error filling
180
+ ############################################################
181
+
182
class DefaultErrorCreator:
    """Callable that builds an error element and attaches it next to a field.

    Behaviour is configured through the class attributes below; each can be
    overridden per instance by passing it as a keyword argument.
    """

    # Put the error before the field/content (True) or after it (False).
    insert_before = True
    # For block elements, place the error inside the element itself.
    block_inside = True
    # Tag name of the generated error element.
    error_container_tag = 'div'
    # CSS class given to every error element (skipped when empty).
    error_message_class = 'error-message'
    # Extra CSS class appended when the target is a block element.
    error_block_class = 'error-block'
    # Text used when the message is None or empty.
    default_message = "Invalid"

    def __init__(self, **kw):
        """Override class-level settings; unknown names raise TypeError."""
        for option, setting in kw.items():
            if not hasattr(self, option):
                raise TypeError(
                    "Unexpected keyword argument: %s" % option)
            setattr(self, option, setting)

    def __call__(self, el, is_block, message):
        """Create the error element for ``message`` and insert it at ``el``.

        ``message`` may be a string, an element, or None/'' (replaced by
        ``default_message``).  Returns None; the tree is modified in place.
        """
        error_el = el.makeelement(self.error_container_tag)
        if self.error_message_class:
            error_el.set('class', self.error_message_class)
        if is_block and self.error_block_class:
            error_el.set('class', error_el.get('class', '')+' '+self.error_block_class)
        if message is None or message == '':
            message = self.default_message
        if isinstance(message, ElementBase):
            error_el.append(message)
        else:
            assert isinstance(message, basestring), (
                "Bad message; should be a string or element: %r" % message)
            error_el.text = message or self.default_message

        if is_block and self.block_inside:
            if self.insert_before:
                # The element's leading text must follow the new first child.
                error_el.tail = el.text
                el.text = None
                el.insert(0, error_el)
            else:
                el.append(error_el)
            return

        # Otherwise the error becomes a sibling of ``el``.
        parent = el.getparent()
        pos = parent.index(el)
        if self.insert_before:
            parent.insert(pos, error_el)
            return
        # Inserting after: the text that followed ``el`` now follows the error.
        error_el.tail = el.tail
        el.tail = None
        parent.insert(pos+1, error_el)


# Shared instance used as the default ``error_creator``.
default_error_creator = DefaultErrorCreator()
229
+
230
+
231
def insert_errors(el, errors, form_id=None, form_index=None,
                  error_class="error", error_creator=default_error_creator):
    """Insert error messages into the form found in ``el``, in place.

    ``errors`` maps a field reference (as understood by
    ``_find_elements_for_name``) to an error or a list of errors; entries
    whose error is None are skipped.  Each resolved (element, message) pair
    is handed to ``_insert_error`` together with ``error_class`` and
    ``error_creator``.
    """
    form = _find_form(el, form_id=form_id, form_index=form_index)
    for field_ref, error in errors.items():
        if error is None:
            continue
        for target, message in _find_elements_for_name(form, field_ref, error):
            assert isinstance(message, (basestring, type(None), ElementBase)), (
                "Bad message: %r" % message)
            _insert_error(target, message, error_class, error_creator)
247
+
248
def insert_errors_html(html, values, **kw):
    """Insert errors and return the result in the same type as ``html``.

    A string is parsed with ``fromstring``; any other input is deep-copied,
    so the caller's object is never modified.  ``values`` and ``kw`` are
    forwarded to ``insert_errors``.
    """
    result_type = type(html)
    doc = fromstring(html) if isinstance(html, basestring) else copy.deepcopy(html)
    insert_errors(doc, values, **kw)
    return _transform_result(result_type, doc)
256
+
257
def _insert_error(el, error, error_class, error_creator):
    """Mark ``el`` (and its labels) as erroneous and attach the message.

    * Empty tags (per ``defs.empty_tags``) and <textarea> are treated as
      non-block elements; everything else as block.
    * When ``error_class`` is truthy it is added to ``el`` (unless ``el`` is
      the <form> itself) and to every <label for=...> that points at
      ``el``'s id.
    * ``error_creator(el, is_block, error)`` then inserts the message.
    """
    tag = _nons(el.tag)
    is_block = not (tag in defs.empty_tags or tag == 'textarea')
    if error_class:
        if tag != 'form':
            _add_class(el, error_class)
        # Labels are guarded by ``error_class`` as well: without the guard a
        # None class reached _add_class and failed on string concatenation.
        el_id = el.get('id')
        if el_id:
            for label in _label_for_xpath(el, for_id=el_id):
                _add_class(label, error_class)
    error_creator(el, is_block, error)
270
+
271
+ def _add_class(el, class_name):
272
+ if el.get('class'):
273
+ el.set('class', el.get('class')+' '+class_name)
274
+ else:
275
+ el.set('class', class_name)
276
+
277
+ def _find_elements_for_name(form, name, error):
278
+ if name is None:
279
+ # An error for the entire form
280
+ yield form, error
281
+ return
282
+ if name.startswith('#'):
283
+ # By id
284
+ el = form.get_element_by_id(name[1:])
285
+ if el is not None:
286
+ yield el, error
287
+ return
288
+ els = _name_xpath(form, name=name)
289
+ if not els:
290
+ # FIXME: should this raise an exception?
291
+ return
292
+ if not isinstance(error, (list, tuple)):
293
+ yield els[0], error
294
+ return
295
+ # FIXME: if error is longer than els, should it raise an error?
296
+ for el, err in zip(els, error):
297
+ if err is None:
298
+ continue
299
+ yield el, err
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/html/html5parser.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ An interface to html5lib that mimics the lxml.html interface.
3
+ """
4
+ import sys
5
+ import string
6
+
7
+ from html5lib import HTMLParser as _HTMLParser
8
+ from html5lib.treebuilders.etree_lxml import TreeBuilder
9
+ from lxml import etree
10
+ from lxml.html import Element, XHTML_NAMESPACE, _contains_block_level_tag
11
+
12
+ # python3 compatibility
13
+ try:
14
+ _strings = basestring
15
+ except NameError:
16
+ _strings = (bytes, str)
17
+ try:
18
+ from urllib2 import urlopen
19
+ except ImportError:
20
+ from urllib.request import urlopen
21
+ try:
22
+ from urlparse import urlparse
23
+ except ImportError:
24
+ from urllib.parse import urlparse
25
+
26
+
27
class HTMLParser(_HTMLParser):
    """html5lib HTML parser preconfigured to build lxml trees."""

    def __init__(self, strict=False, **kwargs):
        """Forward all options to html5lib, forcing the lxml ``TreeBuilder``."""
        super(HTMLParser, self).__init__(
            strict=strict, tree=TreeBuilder, **kwargs)
32
+
33
+
34
# XHTML support is optional: when html5lib does not export ``XHTMLParser``
# the import fails and neither the class nor ``xhtml_parser`` is defined.
try:
    from html5lib import XHTMLParser as _XHTMLParser
except ImportError:
    pass
else:
    class XHTMLParser(_XHTMLParser):
        """html5lib XHTML parser preconfigured to build lxml trees."""

        def __init__(self, strict=False, **kwargs):
            """Forward all options to html5lib, forcing the lxml ``TreeBuilder``."""
            super(XHTMLParser, self).__init__(
                strict=strict, tree=TreeBuilder, **kwargs)

    # Module-level XHTML parser instance.
    xhtml_parser = XHTMLParser()
46
+
47
+
48
+ def _find_tag(tree, tag):
49
+ elem = tree.find(tag)
50
+ if elem is not None:
51
+ return elem
52
+ return tree.find('{%s}%s' % (XHTML_NAMESPACE, tag))
53
+
54
+
55
def document_fromstring(html, guess_charset=None, parser=None):
    """Parse a whole document from a string and return its root element.

    ``parser`` defaults to the module-level ``html_parser``.

    ``guess_charset`` is passed to html5lib as ``useChardet``.  When left as
    None it is switched on for byte-string input and omitted entirely for
    other input.

    Raises TypeError if ``html`` is not a string.
    """
    if not isinstance(html, _strings):
        raise TypeError('string required')

    if parser is None:
        parser = html_parser

    if guess_charset is None and isinstance(html, bytes):
        # html5lib does not accept useChardet as an argument, if it
        # detected the html argument would produce unicode objects.
        guess_charset = True
    options = {} if guess_charset is None else {'useChardet': guess_charset}
    return parser.parse(html, **options).getroot()
77
+
78
+
79
def fragments_fromstring(html, no_leading_text=False,
                         guess_charset=None, parser=None):
    """Parse several HTML elements and return them as a list.

    The first list item may be a string holding leading text.  With
    ``no_leading_text`` true, non-whitespace leading text raises
    ``etree.ParserError`` and whitespace-only leading text is dropped, so
    the result contains elements only.

    ``parser`` defaults to the module-level ``html_parser``.
    ``guess_charset`` is passed to html5lib as ``useChardet``; when left as
    None it is set to False for byte-string input and omitted otherwise.

    Raises TypeError if ``html`` is not a string.
    """
    if not isinstance(html, _strings):
        raise TypeError('string required')

    if parser is None:
        parser = html_parser

    if guess_charset is None and isinstance(html, bytes):
        # html5lib does not accept useChardet as an argument, if it
        # detected the html argument would produce unicode objects.
        guess_charset = False
    options = {} if guess_charset is None else {'useChardet': guess_charset}

    children = parser.parseFragment(html, 'div', **options)
    has_leading_text = children and isinstance(children[0], _strings)
    if has_leading_text and no_leading_text:
        if children[0].strip():
            raise etree.ParserError('There is leading text: %r' %
                                    children[0])
        del children[0]
    return children
111
+
112
+
113
def fragment_fromstring(html, create_parent=False,
                        guess_charset=None, parser=None):
    """Parse a single HTML element.

    Without ``create_parent`` the input must contain exactly one element,
    surrounded by nothing but whitespace; otherwise ``etree.ParserError`` is
    raised.

    With a true ``create_parent`` the parsed content (leading text
    included) is wrapped in a new element, named by ``create_parent`` if it
    is a string and ``div`` otherwise, and that wrapper is returned.

    ``guess_charset`` and ``parser`` are forwarded to
    ``fragments_fromstring``.  Raises TypeError if ``html`` is not a string.
    """
    if not isinstance(html, _strings):
        raise TypeError('string required')

    # Leading text is only acceptable when it can live inside a wrapper.
    elements = fragments_fromstring(
        html, guess_charset=guess_charset, parser=parser,
        no_leading_text=not bool(create_parent))

    if create_parent:
        parent_tag = create_parent if isinstance(create_parent, _strings) else 'div'
        new_root = Element(parent_tag)
        if elements:
            if isinstance(elements[0], _strings):
                new_root.text = elements[0]
                del elements[0]
            new_root.extend(elements)
        return new_root

    if not elements:
        raise etree.ParserError('No elements found')
    if len(elements) > 1:
        raise etree.ParserError('Multiple elements found')
    only = elements[0]
    if only.tail and only.tail.strip():
        raise etree.ParserError('Element followed by text: %r' % only.tail)
    only.tail = None
    return only
155
+
156
+
157
def fromstring(html, guess_charset=None, parser=None):
    """Parse ``html`` and return a single element or document root.

    The input is always parsed as a full document first; the result is then
    reduced to what was most likely passed in:

    * input starting with ``<html`` or ``<!doctype`` (case-insensitive,
      within the first 50 characters): the document root;
    * a non-empty <head>: the document root;
    * a <body> holding exactly one element and no other text: that element;
    * otherwise the <body> itself, renamed to ``div`` if it contains a
      block-level tag and to ``span`` if not.

    ``guess_charset`` and ``parser`` are forwarded to
    ``document_fromstring``.  Raises TypeError if ``html`` is not a string.
    """
    if not isinstance(html, _strings):
        raise TypeError('string required')
    doc = document_fromstring(html, parser=parser,
                              guess_charset=guess_charset)

    # document starts with doctype or <html>, full document!
    start = html[:50]
    if isinstance(start, bytes):
        # Allow text comparison in python3.
        # Decode as ascii, that also covers latin-1 and utf-8 for the
        # characters we need.
        start = start.decode('ascii', 'replace')
    start = start.lstrip().lower()
    if start.startswith(('<html', '<!doctype')):
        return doc

    # if the head is not empty we have a full document
    head = _find_tag(doc, 'head')
    if len(head):
        return doc

    body = _find_tag(doc, 'body')

    # The body has just one element, so it was probably a single
    # element passed in
    no_leading_text = not body.text or not body.text.strip()
    if (len(body) == 1 and no_leading_text
            and (not body[-1].tail or not body[-1].tail.strip())):
        return body[0]

    # Now we have a body which represents a bunch of tags which have the
    # content that was passed in.  We will create a fake container, which
    # is the body tag, except <body> implies too much structure.
    body.tag = 'div' if _contains_block_level_tag(body) else 'span'
    return body
209
+
210
+
211
def parse(filename_url_or_file, guess_charset=None, parser=None):
    """Parse a filename, URL, or file-like object into an HTML document
    tree.  Note: this returns a tree, not an element.  Use
    ``parse(...).getroot()`` to get the document root.

    If ``guess_charset`` is true, the ``useChardet`` option is passed into
    html5lib to enable character detection.  This option is on by default
    when parsing from URLs, off by default when parsing from file(-like)
    objects (which tend to return Unicode more often than not), and on by
    default when parsing from a file path (which is read in binary mode).

    ``parser`` defaults to the module-level ``html_parser``.  A stream that
    this function opens itself (file path or URL) is closed before
    returning; a file-like object supplied by the caller is left open.
    """
    if parser is None:
        parser = html_parser

    opened_here = False  # only close what this function opened
    if not isinstance(filename_url_or_file, _strings):
        fp = filename_url_or_file
        if guess_charset is None:
            # assume that file-like objects return Unicode more often than bytes
            guess_charset = False
    elif _looks_like_url(filename_url_or_file):
        fp = urlopen(filename_url_or_file)
        opened_here = True
        if guess_charset is None:
            # assume that URLs return bytes
            guess_charset = True
    else:
        fp = open(filename_url_or_file, 'rb')
        opened_here = True
        if guess_charset is None:
            guess_charset = True

    options = {}
    # html5lib does not accept useChardet as an argument, if it
    # detected the html argument would produce unicode objects.
    if guess_charset:
        options['useChardet'] = guess_charset
    try:
        return parser.parse(fp, **options)
    finally:
        # Previously the handle was leaked until garbage collection.
        if opened_here:
            fp.close()
245
+
246
+
247
+ def _looks_like_url(str):
248
+ scheme = urlparse(str)[0]
249
+ if not scheme:
250
+ return False
251
+ elif (sys.platform == 'win32' and
252
+ scheme in string.ascii_letters
253
+ and len(scheme) == 1):
254
+ # looks like a 'normal' absolute path
255
+ return False
256
+ else:
257
+ return True
258
+
259
+
260
# Module-level default parser: the parsing functions in this module fall
# back to it whenever they are called with ``parser=None``.
html_parser = HTMLParser()
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/__init__.pxd ADDED
File without changes
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/config.pxd ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ cdef extern from "etree_defs.h":
2
+ cdef bint ENABLE_THREADING
3
+ cdef bint ENABLE_SCHEMATRON
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/relaxng.pxd ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lxml.includes.tree cimport xmlDoc
2
+ from lxml.includes.xmlerror cimport xmlStructuredErrorFunc
3
+
4
+ cdef extern from "libxml/relaxng.h" nogil:
5
+ ctypedef struct xmlRelaxNG
6
+ ctypedef struct xmlRelaxNGParserCtxt
7
+
8
+ ctypedef struct xmlRelaxNGValidCtxt
9
+
10
+ ctypedef enum xmlRelaxNGValidErr:
11
+ XML_RELAXNG_OK = 0
12
+ XML_RELAXNG_ERR_MEMORY = 1
13
+ XML_RELAXNG_ERR_TYPE = 2
14
+ XML_RELAXNG_ERR_TYPEVAL = 3
15
+ XML_RELAXNG_ERR_DUPID = 4
16
+ XML_RELAXNG_ERR_TYPECMP = 5
17
+ XML_RELAXNG_ERR_NOSTATE = 6
18
+ XML_RELAXNG_ERR_NODEFINE = 7
19
+ XML_RELAXNG_ERR_LISTEXTRA = 8
20
+ XML_RELAXNG_ERR_LISTEMPTY = 9
21
+ XML_RELAXNG_ERR_INTERNODATA = 10
22
+ XML_RELAXNG_ERR_INTERSEQ = 11
23
+ XML_RELAXNG_ERR_INTEREXTRA = 12
24
+ XML_RELAXNG_ERR_ELEMNAME = 13
25
+ XML_RELAXNG_ERR_ATTRNAME = 14
26
+ XML_RELAXNG_ERR_ELEMNONS = 15
27
+ XML_RELAXNG_ERR_ATTRNONS = 16
28
+ XML_RELAXNG_ERR_ELEMWRONGNS = 17
29
+ XML_RELAXNG_ERR_ATTRWRONGNS = 18
30
+ XML_RELAXNG_ERR_ELEMEXTRANS = 19
31
+ XML_RELAXNG_ERR_ATTREXTRANS = 20
32
+ XML_RELAXNG_ERR_ELEMNOTEMPTY = 21
33
+ XML_RELAXNG_ERR_NOELEM = 22
34
+ XML_RELAXNG_ERR_NOTELEM = 23
35
+ XML_RELAXNG_ERR_ATTRVALID = 24
36
+ XML_RELAXNG_ERR_CONTENTVALID = 25
37
+ XML_RELAXNG_ERR_EXTRACONTENT = 26
38
+ XML_RELAXNG_ERR_INVALIDATTR = 27
39
+ XML_RELAXNG_ERR_DATAELEM = 28
40
+ XML_RELAXNG_ERR_VALELEM = 29
41
+ XML_RELAXNG_ERR_LISTELEM = 30
42
+ XML_RELAXNG_ERR_DATATYPE = 31
43
+ XML_RELAXNG_ERR_VALUE = 32
44
+ XML_RELAXNG_ERR_LIST = 33
45
+ XML_RELAXNG_ERR_NOGRAMMAR = 34
46
+ XML_RELAXNG_ERR_EXTRADATA = 35
47
+ XML_RELAXNG_ERR_LACKDATA = 36
48
+ XML_RELAXNG_ERR_INTERNAL = 37
49
+ XML_RELAXNG_ERR_ELEMWRONG = 38
50
+ XML_RELAXNG_ERR_TEXTWRONG = 39
51
+
52
+ cdef xmlRelaxNGValidCtxt* xmlRelaxNGNewValidCtxt(xmlRelaxNG* schema)
53
+ cdef int xmlRelaxNGValidateDoc(xmlRelaxNGValidCtxt* ctxt, xmlDoc* doc)
54
+ cdef xmlRelaxNG* xmlRelaxNGParse(xmlRelaxNGParserCtxt* ctxt)
55
+ cdef xmlRelaxNGParserCtxt* xmlRelaxNGNewParserCtxt(char* URL)
56
+ cdef xmlRelaxNGParserCtxt* xmlRelaxNGNewDocParserCtxt(xmlDoc* doc)
57
+ cdef void xmlRelaxNGFree(xmlRelaxNG* schema)
58
+ cdef void xmlRelaxNGFreeParserCtxt(xmlRelaxNGParserCtxt* ctxt)
59
+ cdef void xmlRelaxNGFreeValidCtxt(xmlRelaxNGValidCtxt* ctxt)
60
+
61
+ cdef void xmlRelaxNGSetValidStructuredErrors(
62
+ xmlRelaxNGValidCtxt* ctxt, xmlStructuredErrorFunc serror, void *ctx)
63
+ cdef void xmlRelaxNGSetParserStructuredErrors(
64
+ xmlRelaxNGParserCtxt* ctxt, xmlStructuredErrorFunc serror, void *ctx)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/schematron.pxd ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lxml.includes cimport xmlerror
2
+ from lxml.includes.tree cimport xmlDoc
3
+
4
+ cdef extern from "libxml/schematron.h" nogil:
5
+ ctypedef struct xmlSchematron
6
+ ctypedef struct xmlSchematronParserCtxt
7
+ ctypedef struct xmlSchematronValidCtxt
8
+
9
+ ctypedef enum xmlSchematronValidOptions:
10
+ XML_SCHEMATRON_OUT_QUIET = 1 # quiet no report
11
+ XML_SCHEMATRON_OUT_TEXT = 2 # build a textual report
12
+ XML_SCHEMATRON_OUT_XML = 4 # output SVRL
13
+ XML_SCHEMATRON_OUT_ERROR = 8 # output via xmlStructuredErrorFunc
14
+ XML_SCHEMATRON_OUT_FILE = 256 # output to a file descriptor
15
+ XML_SCHEMATRON_OUT_BUFFER = 512 # output to a buffer
16
+ XML_SCHEMATRON_OUT_IO = 1024 # output to I/O mechanism
17
+
18
+ cdef xmlSchematronParserCtxt* xmlSchematronNewDocParserCtxt(
19
+ xmlDoc* doc)
20
+ cdef xmlSchematronParserCtxt* xmlSchematronNewParserCtxt(
21
+ char* filename) nogil
22
+ cdef xmlSchematronValidCtxt* xmlSchematronNewValidCtxt(
23
+ xmlSchematron* schema, int options)
24
+
25
+ cdef xmlSchematron* xmlSchematronParse(xmlSchematronParserCtxt* ctxt)
26
+ cdef int xmlSchematronValidateDoc(xmlSchematronValidCtxt* ctxt,
27
+ xmlDoc* instance)
28
+
29
+ cdef void xmlSchematronFreeParserCtxt(xmlSchematronParserCtxt* ctxt)
30
+ cdef void xmlSchematronFreeValidCtxt(xmlSchematronValidCtxt* ctxt)
31
+ cdef void xmlSchematronFree(xmlSchematron* schema)
32
+ cdef void xmlSchematronSetValidStructuredErrors(
33
+ xmlSchematronValidCtxt* ctxt,
34
+ xmlerror.xmlStructuredErrorFunc error_func, void *data)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/includes/xpath.pxd ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lxml.includes cimport tree
2
+ from lxml.includes cimport xmlerror
3
+
4
+ from libc.string cimport const_char
5
+ from lxml.includes.tree cimport xmlChar, const_xmlChar
6
+
7
+
8
+ cdef extern from "libxml/xpath.h" nogil:
9
+ ctypedef enum xmlXPathObjectType:
10
+ XPATH_UNDEFINED = 0
11
+ XPATH_NODESET = 1
12
+ XPATH_BOOLEAN = 2
13
+ XPATH_NUMBER = 3
14
+ XPATH_STRING = 4
15
+ XPATH_POINT = 5
16
+ XPATH_RANGE = 6
17
+ XPATH_LOCATIONSET = 7
18
+ XPATH_USERS = 8
19
+ XPATH_XSLT_TREE = 9
20
+
21
+ ctypedef enum xmlXPathError:
22
+ XPATH_EXPRESSION_OK = 0
23
+ XPATH_NUMBER_ERROR = 1
24
+ XPATH_UNFINISHED_LITERAL_ERROR = 2
25
+ XPATH_START_LITERAL_ERROR = 3
26
+ XPATH_VARIABLE_REF_ERROR = 4
27
+ XPATH_UNDEF_VARIABLE_ERROR = 5
28
+ XPATH_INVALID_PREDICATE_ERROR = 6
29
+ XPATH_EXPR_ERROR = 7
30
+ XPATH_UNCLOSED_ERROR = 8
31
+ XPATH_UNKNOWN_FUNC_ERROR = 9
32
+ XPATH_INVALID_OPERAND = 10
33
+ XPATH_INVALID_TYPE = 11
34
+ XPATH_INVALID_ARITY = 12
35
+ XPATH_INVALID_CTXT_SIZE = 13
36
+ XPATH_INVALID_CTXT_POSITION = 14
37
+ XPATH_MEMORY_ERROR = 15
38
+ XPTR_SYNTAX_ERROR = 16
39
+ XPTR_RESOURCE_ERROR = 17
40
+ XPTR_SUB_RESOURCE_ERROR = 18
41
+ XPATH_UNDEF_PREFIX_ERROR = 19
42
+ XPATH_ENCODING_ERROR = 20
43
+ XPATH_INVALID_CHAR_ERROR = 21
44
+ XPATH_INVALID_CTXT = 22
45
+
46
+ ctypedef struct xmlNodeSet:
47
+ int nodeNr
48
+ int nodeMax
49
+ tree.xmlNode** nodeTab
50
+
51
+ ctypedef struct xmlXPathObject:
52
+ xmlXPathObjectType type
53
+ xmlNodeSet* nodesetval
54
+ bint boolval
55
+ double floatval
56
+ xmlChar* stringval
57
+
58
+ ctypedef struct xmlXPathContext:
59
+ tree.xmlDoc* doc
60
+ tree.xmlNode* node
61
+ tree.xmlDict* dict
62
+ tree.xmlHashTable* nsHash
63
+ const_xmlChar* function
64
+ const_xmlChar* functionURI
65
+ xmlerror.xmlStructuredErrorFunc error
66
+ xmlerror.xmlError lastError
67
+ void* userData
68
+
69
+ ctypedef struct xmlXPathParserContext:
70
+ xmlXPathContext* context
71
+ xmlXPathObject* value
72
+ tree.xmlNode* ancestor
73
+ int error
74
+
75
+ ctypedef struct xmlXPathCompExpr
76
+
77
+ ctypedef void (*xmlXPathFunction)(xmlXPathParserContext* ctxt, int nargs)
78
+ ctypedef xmlXPathFunction (*xmlXPathFuncLookupFunc)(void* ctxt,
79
+ const_xmlChar* name,
80
+ const_xmlChar* ns_uri)
81
+
82
+ cdef xmlXPathContext* xmlXPathNewContext(tree.xmlDoc* doc)
83
+ cdef xmlXPathObject* xmlXPathEvalExpression(const_xmlChar* str,
84
+ xmlXPathContext* ctxt)
85
+ cdef xmlXPathObject* xmlXPathCompiledEval(xmlXPathCompExpr* comp,
86
+ xmlXPathContext* ctxt)
87
+ cdef xmlXPathCompExpr* xmlXPathCompile(const_xmlChar* str)
88
+ cdef xmlXPathCompExpr* xmlXPathCtxtCompile(xmlXPathContext* ctxt,
89
+ const_xmlChar* str)
90
+ cdef void xmlXPathFreeContext(xmlXPathContext* ctxt)
91
+ cdef void xmlXPathFreeCompExpr(xmlXPathCompExpr* comp)
92
+ cdef void xmlXPathFreeObject(xmlXPathObject* obj)
93
+ cdef int xmlXPathRegisterNs(xmlXPathContext* ctxt,
94
+ const_xmlChar* prefix, const_xmlChar* ns_uri)
95
+
96
+ cdef xmlNodeSet* xmlXPathNodeSetCreate(tree.xmlNode* val)
97
+ cdef void xmlXPathFreeNodeSet(xmlNodeSet* val)
98
+
99
+
100
+ cdef extern from "libxml/xpathInternals.h" nogil:
101
+ cdef int xmlXPathRegisterFunc(xmlXPathContext* ctxt,
102
+ const_xmlChar* name,
103
+ xmlXPathFunction f)
104
+ cdef int xmlXPathRegisterFuncNS(xmlXPathContext* ctxt,
105
+ const_xmlChar* name,
106
+ const_xmlChar* ns_uri,
107
+ xmlXPathFunction f)
108
+ cdef void xmlXPathRegisterFuncLookup(xmlXPathContext *ctxt,
109
+ xmlXPathFuncLookupFunc f,
110
+ void *funcCtxt)
111
+ cdef int xmlXPathRegisterVariable(xmlXPathContext *ctxt,
112
+ const_xmlChar* name,
113
+ xmlXPathObject* value)
114
+ cdef int xmlXPathRegisterVariableNS(xmlXPathContext *ctxt,
115
+ const_xmlChar* name,
116
+ const_xmlChar* ns_uri,
117
+ xmlXPathObject* value)
118
+ cdef void xmlXPathRegisteredVariablesCleanup(xmlXPathContext *ctxt)
119
+ cdef void xmlXPathRegisteredNsCleanup(xmlXPathContext *ctxt)
120
+ cdef xmlXPathObject* valuePop (xmlXPathParserContext *ctxt)
121
+ cdef int valuePush(xmlXPathParserContext* ctxt, xmlXPathObject *value)
122
+
123
+ cdef xmlXPathObject* xmlXPathNewCString(const_char *val)
124
+ cdef xmlXPathObject* xmlXPathWrapCString(const_char * val)
125
+ cdef xmlXPathObject* xmlXPathNewString(const_xmlChar *val)
126
+ cdef xmlXPathObject* xmlXPathWrapString(const_xmlChar * val)
127
+ cdef xmlXPathObject* xmlXPathNewFloat(double val)
128
+ cdef xmlXPathObject* xmlXPathNewBoolean(int val)
129
+ cdef xmlXPathObject* xmlXPathNewNodeSet(tree.xmlNode* val)
130
+ cdef xmlXPathObject* xmlXPathNewValueTree(tree.xmlNode* val)
131
+ cdef void xmlXPathNodeSetAdd(xmlNodeSet* cur,
132
+ tree.xmlNode* val)
133
+ cdef void xmlXPathNodeSetAddUnique(xmlNodeSet* cur,
134
+ tree.xmlNode* val)
135
+ cdef xmlXPathObject* xmlXPathWrapNodeSet(xmlNodeSet* val)
136
+ cdef void xmlXPathErr(xmlXPathParserContext* ctxt, int error)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/isoschematron/__init__.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """The ``lxml.isoschematron`` package implements ISO Schematron support on top
2
+ of the pure-xslt 'skeleton' implementation.
3
+ """
4
+
5
+ import sys
6
+ import os.path
7
+ from lxml import etree as _etree # due to validator __init__ signature
8
+
9
+
10
# some compat stuff, borrowed from lxml.html
# Bind the Python 2 names ``unicode`` and ``basestring`` to ``str`` when they
# do not exist, so the code below can use them on either major version.
try:
    unicode
except NameError:
    # Python 3
    unicode = str
try:
    basestring
except NameError:
    # Python 3
    basestring = str


# Public names exported by ``from lxml.isoschematron import *``.
__all__ = ['extract_xsd', 'extract_rng', 'iso_dsdl_include',
           'iso_abstract_expand', 'iso_svrl_for_xslt1',
           'svrl_validation_errors', 'schematron_schema_valid',
           'stylesheet_params', 'Schematron']


# some namespaces
#FIXME: Maybe lxml should provide a dedicated place for common namespace
#FIXME: definitions?
XML_SCHEMA_NS = "http://www.w3.org/2001/XMLSchema"
RELAXNG_NS = "http://relaxng.org/ns/structure/1.0"
SCHEMATRON_NS = "http://purl.oclc.org/dsdl/schematron"
SVRL_NS = "http://purl.oclc.org/dsdl/svrl"


# some helpers
# Namespace-qualified tag names of the schema root elements, and the
# directory (next to this file) that holds the bundled stylesheets/schemas.
_schematron_root = '{%s}schema' % SCHEMATRON_NS
_xml_schema_root = '{%s}schema' % XML_SCHEMA_NS
_resources_dir = os.path.join(os.path.dirname(__file__), 'resources')


# the iso-schematron skeleton implementation steps aka xsl transformations
# Each name is an XSLT transformer compiled at import time from a stylesheet
# below ``resources/xsl``.
extract_xsd = _etree.XSLT(_etree.parse(
    os.path.join(_resources_dir, 'xsl', 'XSD2Schtrn.xsl')))
extract_rng = _etree.XSLT(_etree.parse(
    os.path.join(_resources_dir, 'xsl', 'RNG2Schtrn.xsl')))
iso_dsdl_include = _etree.XSLT(_etree.parse(
    os.path.join(_resources_dir, 'xsl', 'iso-schematron-xslt1',
                 'iso_dsdl_include.xsl')))
iso_abstract_expand = _etree.XSLT(_etree.parse(
    os.path.join(_resources_dir, 'xsl', 'iso-schematron-xslt1',
                 'iso_abstract_expand.xsl')))
iso_svrl_for_xslt1 = _etree.XSLT(_etree.parse(
    os.path.join(_resources_dir,
                 'xsl', 'iso-schematron-xslt1', 'iso_svrl_for_xslt1.xsl')))


# svrl result accessors
# Selects every ``svrl:failed-assert`` element of a validation report.
svrl_validation_errors = _etree.XPath(
    '//svrl:failed-assert', namespaces={'svrl': SVRL_NS})
63
+
64
# RelaxNG validator for schematron schemas
# ``schematron_schema_valid_supported`` records whether the bundled
# ``iso-schematron.rng`` could be loaded.  When loading fails with
# RelaxNGParseError, ``schematron_schema_valid`` becomes a stub that raises
# NotImplementedError and the flag stays False.
schematron_schema_valid_supported = False
try:
    schematron_schema_valid = _etree.RelaxNG(
        file=os.path.join(_resources_dir, 'rng', 'iso-schematron.rng'))
    schematron_schema_valid_supported = True
except _etree.RelaxNGParseError:
    # Some distributions delete the file due to licensing issues.
    def schematron_schema_valid(arg):
        raise NotImplementedError("Validating the ISO schematron requires iso-schematron.rng")
74
+
75
+
76
def stylesheet_params(**kwargs):
    """Convert keyword args to a dictionary of stylesheet parameters.

    XSL stylesheet parameters must be XPath expressions, i.e.:

    * string expressions, like "'5'"
    * simple (number) expressions, like "5"
    * valid XPath expressions, like "/a/b/text()"

    Each keyword argument is converted as follows:

    * a string is wrapped with ``XSLT.strparam()``;
    * an ``XPath`` object is passed through unchanged;
    * None raises TypeError;
    * anything else is converted to a (unicode) string.
    """
    params = {}
    for name, value in kwargs.items():
        if value is None:
            raise TypeError('None not allowed as a stylesheet parameter')
        if isinstance(value, basestring):
            value = _etree.XSLT.strparam(value)
        elif not isinstance(value, _etree.XPath):
            value = unicode(value)
        params[name] = value
    return params
101
+
102
+
103
+ # helper function for use in Schematron __init__
104
def _stylesheet_param_dict(paramsDict, kwargsDict):
    """Merge two parameter dicts and wrap the result as stylesheet arguments.

    Works on a copy of ``paramsDict`` (the caller's dict, possibly a shared
    default, is never modified).  Entries of ``kwargsDict`` override it,
    except those whose value is None, which are ignored.  The merged dict is
    converted with ``stylesheet_params``.
    """
    merged = dict(paramsDict)
    merged.update(
        (key, value) for key, value in kwargsDict.items()
        if value is not None)  # None values do not override
    return stylesheet_params(**merged)
116
+
117
+
118
class Schematron(_etree._Validator):
    """An ISO Schematron validator.

    Construct it from a root Element or an ElementTree.  Alternatively, pass
    a filename as keyword argument 'file' to parse the schema from the file
    system.

    Schematron is a less well known, but very powerful schema language.
    The main idea is to use the capabilities of XPath to put restrictions on
    the structure and the content of XML documents.

    By default, only ``failed-assert`` findings make a document invalid
    (``ASSERTS_ONLY``).  To change this, either pass a report filter to the
    ``error_finder`` parameter (e.g. ``ASSERTS_AND_REPORTS`` or a custom
    ``XPath`` object), or subclass isoschematron.Schematron for complete
    control of the validation process.

    Built on the Schematron language 'reference' skeleton pure-xslt
    implementation, the validator is created as an XSLT 1.0 stylesheet using
    these steps:

     0) (Extract from XML Schema or RelaxNG schema)
     1) Process inclusions
     2) Process abstract patterns
     3) Compile the schematron schema to XSLT

    The ``include`` and ``expand`` keyword arguments can be used to switch
    off steps 1) and 2).
    To set parameters for steps 1), 2) and 3) hand parameter dictionaries to
    the keyword arguments ``include_params``, ``expand_params`` or
    ``compile_params``.
    For convenience, the compile-step parameter ``phase`` is also exposed as
    a keyword argument ``phase``.  This takes precedence if the parameter is
    also given in the parameter dictionary.

    If ``store_schematron`` is set to True, the (included-and-expanded)
    schematron document tree is stored and available through the
    ``schematron`` property.
    If ``store_xslt`` is set to True, the validation XSLT document tree will
    be stored and can be retrieved through the ``validator_xslt`` property.
    With ``store_report`` set to True (default: False), the resulting
    validation report document gets stored and can be accessed as the
    ``validation_report`` property.

    If ``validate_schema`` is set to False, the validation of the schema
    file itself is disabled.  Validation happens by default after building
    the full schema, unless the schema validation file cannot be found at
    import time, in which case the validation gets disabled.  Some lxml
    distributions exclude this file due to licensing issues.  ISO-Schematron
    validation can then still be used normally, but the schemas themselves
    cannot be validated.

    Here is a usage example::

      >>> from lxml import etree
      >>> from lxml.isoschematron import Schematron

      >>> schematron = Schematron(etree.XML('''
      ... <schema xmlns="http://purl.oclc.org/dsdl/schematron" >
      ...   <pattern id="id_only_attribute">
      ...     <title>id is the only permitted attribute name</title>
      ...     <rule context="*">
      ...       <report test="@*[not(name()='id')]">Attribute
      ...         <name path="@*[not(name()='id')]"/> is forbidden<name/>
      ...       </report>
      ...     </rule>
      ...   </pattern>
      ... </schema>'''),
      ... error_finder=Schematron.ASSERTS_AND_REPORTS)

      >>> xml = etree.XML('''
      ... <AAA name="aaa">
      ...   <BBB id="bbb"/>
      ...   <CCC color="ccc"/>
      ... </AAA>
      ... ''')

      >>> schematron.validate(xml)
      False

      >>> xml = etree.XML('''
      ... <AAA id="aaa">
      ...   <BBB id="bbb"/>
      ...   <CCC/>
      ... </AAA>
      ... ''')

      >>> schematron.validate(xml)
      True
    """

    # how findings are categorized when they are appended to the error log
    _domain = _etree.ErrorDomains.SCHEMATRONV
    _level = _etree.ErrorLevels.ERROR
    _error_type = _etree.ErrorTypes.SCHEMATRONV_ASSERT

    # ready-made report filters for the ``error_finder`` argument
    ASSERTS_ONLY = svrl_validation_errors  # Default
    ASSERTS_AND_REPORTS = _etree.XPath(
        '//svrl:failed-assert | //svrl:successful-report',
        namespaces={'svrl': SVRL_NS})

    def _extract(self, element):
        """Pull an embedded schematron schema out of a host schema.

        Only called by __init__ when the given schema document is not a
        schematron schema itself.  Returns a schematron schema document
        tree, or None if the host document is neither an XML Schema nor a
        RelaxNG schema.
        """
        if element.tag == _xml_schema_root:
            return self._extract_xsd(element)
        # RelaxNG does not have a single unique root element, so the
        # namespace of the root is checked instead of its tag
        if element.nsmap.get(element.prefix) == RELAXNG_NS:
            return self._extract_rng(element)
        return None

    # customization points
    # etree.XSLT objects that provide the extract, include, expand, compile
    # steps
    _extract_xsd = extract_xsd
    _extract_rng = extract_rng
    _include = iso_dsdl_include
    _expand = iso_abstract_expand
    _compile = iso_svrl_for_xslt1

    # etree.xpath object that determines input document validity when applied
    # to the svrl result report; must return a list of result elements (empty
    # if valid)
    _validation_errors = ASSERTS_ONLY

    def __init__(self, etree=None, file=None, include=True, expand=True,
                 include_params={}, expand_params={}, compile_params={},
                 store_schematron=False, store_xslt=False, store_report=False,
                 phase=None, error_finder=ASSERTS_ONLY,
                 validate_schema=schematron_schema_valid_supported):
        """Build the validating XSLT from a schema tree or schema file.

        Raises SchematronParseError if the input cannot be read, is not a
        schematron (or schematron-extractable) document, or fails schema
        validation; raises ValueError if neither input yields a root element.
        The dict defaults are only read (unpacked or copied), never mutated.
        """
        super().__init__()

        self._store_report = store_report
        self._schematron = None
        self._validator_xslt = None
        self._validation_report = None
        # only shadow the class-level filter when a custom one was requested
        if error_finder is not self.ASSERTS_ONLY:
            self._validation_errors = error_finder

        # parse schema document, may be a schematron schema or an XML Schema
        # or a RelaxNG schema with embedded schematron rules
        root = None
        try:
            if etree is not None:
                root = etree if _etree.iselement(etree) else etree.getroot()
            elif file is not None:
                root = _etree.parse(file).getroot()
        except Exception as exc:
            raise _etree.SchematronParseError(
                "No tree or file given: %s" % exc)
        if root is None:
            raise ValueError("Empty tree")

        schema_doc = root if root.tag == _schematron_root \
            else self._extract(root)
        if schema_doc is None:
            raise _etree.SchematronParseError(
                "Document is not a schematron schema or schematron-extractable")

        # perform the iso-schematron skeleton implementation steps to get a
        # validating xslt
        if include:
            schema_doc = self._include(schema_doc, **include_params)
        if expand:
            schema_doc = self._expand(schema_doc, **expand_params)
        if validate_schema and not schematron_schema_valid(schema_doc):
            raise _etree.SchematronParseError(
                "invalid schematron schema: %s" %
                schematron_schema_valid.error_log)
        if store_schematron:
            self._schematron = schema_doc

        # add new compile keyword args to this dict when exposing them
        compile_params = _stylesheet_param_dict(
            compile_params, {'phase': phase})
        compiled = self._compile(schema_doc, **compile_params)
        if store_xslt:
            self._validator_xslt = compiled
        self._validator = _etree.XSLT(compiled)

    def __call__(self, etree):
        """Validate doc using Schematron.

        Returns true if document is valid, false if not.  Each finding
        selected by the error finder is appended to the error log.
        """
        self._clear_error_log()
        report = self._validator(etree)
        if self._store_report:
            self._validation_report = report

        findings = self._validation_errors(report)
        if not findings:
            return True

        # an element has to go through its tree to reach the document info
        tree = etree.getroottree() if _etree.iselement(etree) else etree
        source_name = tree.docinfo.URL or '<file>'
        for finding in findings:
            # Does svrl report the line number, anywhere? Don't think so.
            self._append_log_message(
                domain=self._domain, type=self._error_type,
                level=self._level, line=0,
                message=_etree.tostring(finding, encoding='unicode'),
                filename=source_name)
        return False

    @property
    def schematron(self):
        """ISO-schematron schema document (None if object has been initialized
        with store_schematron=False).
        """
        return self._schematron

    @property
    def validator_xslt(self):
        """ISO-schematron skeleton implementation XSLT validator document (None
        if object has been initialized with store_xslt=False).
        """
        return self._validator_xslt

    @property
    def validation_report(self):
        """ISO-schematron validation result report (None if result-storing has
        been turned off).
        """
        return self._validation_report
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (3.8 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/__pycache__/compiler.cpython-312.pyc ADDED
Binary file (4.62 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/__pycache__/driver.cpython-312.pyc ADDED
Binary file (3.69 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/__init__.py ADDED
File without changes
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/compiler.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from triton.backends.compiler import BaseBackend, GPUTarget, Language
2
+ from triton._C.libtriton import ir, passes, llvm, amd
3
+ from triton import knobs
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict, Tuple
6
+ from types import ModuleType
7
+ import hashlib
8
+ import tempfile
9
+ import re
10
+ import functools
11
+ import warnings
12
+ from pathlib import Path
13
+
14
+
15
def get_min_dot_size(target: GPUTarget):
    """Return a callable giving the minimum dot size for two operand types.

    The returned callable ignores both operand types and always answers
    ``(1, 1, 1)``: configurations that the matrix core units do not support
    natively fall back to FMA with cast arguments, so no minimum applies.
    ``target`` is accepted but not consulted.
    """

    def min_dot_size(lhs_type, rhs_type):
        return (1, 1, 1)

    return min_dot_size
19
+
20
+
21
def is_pingpong_schedule_enabled(arch, use_async_copy):
    """Decide whether the block ping-pong schedule is used.

    An explicit ``knobs.amd.use_block_pingpong`` setting (anything other
    than None) is returned as is.  Otherwise the schedule is on for
    ``gfx942``, and for ``gfx950`` only when ``use_async_copy`` is exactly
    ``True``.
    """
    forced = knobs.amd.use_block_pingpong
    if forced is not None:
        return forced
    if arch == "gfx942":
        return True
    return arch == "gfx950" and use_async_copy is True
24
+
25
+
26
def is_in_thread_transpose_enabled(arch):
    """Decide whether the in-thread transpose pass is used.

    An explicit ``knobs.amd.use_in_thread_transpose`` setting (anything
    other than None) is returned as is; otherwise the pass is on only for
    ``gfx942``.
    """
    forced = knobs.amd.use_in_thread_transpose
    if forced is not None:
        return forced
    return arch == "gfx942"
28
+
29
+
30
@dataclass(frozen=True)
class HIPOptions:
    """Frozen bundle of compilation options for the HIP backend.

    ``arch`` must be supplied (a string such as ``"gfx942"``); its declared
    default of None only exists so that the field can follow other
    defaulted fields.  ``__post_init__`` normalizes the instance:

    * ``warp_size`` (not a declared field) is derived from ``arch``: 32 when
      the gfx major number is at least 10, otherwise 64.
    * ``kpack`` is forced to 1 on ``gfx950``, with a warning.
    * ``extern_libs`` becomes a tuple of ``(name, path)`` pairs in which
      ``ocml`` and ``ockl`` always point into the ``lib`` directory next to
      this module.

    Raises:
        TypeError: if ``arch`` is None.
        AssertionError: if ``num_warps`` is not a positive power of two.
    """
    num_warps: int = 4
    waves_per_eu: int = 0
    num_stages: int = 2
    num_ctas: int = 1
    extern_libs: dict = None
    debug: bool = False
    sanitize_overflow: bool = True
    arch: str = None
    # We have native support for OCP fp8 variants since CDNA4/RDNA4. For earlier generations,
    # we software emulate the support for them.
    # UZ fp8 variants (fp8e4b8 and fp8e5b16) are natively supported for CDNA3. For other
    # architectures they are software emulated.
    supported_fp8_dtypes: Tuple[str, ...] = ("fp8e4nv", "fp8e5", "fp8e5b16", "fp8e4b8")
    deprecated_fp8_dot_operand_dtypes: Tuple[str, ...] = ()
    default_dot_input_precision: str = "ieee"
    allowed_dot_input_precisions: Tuple[str, ...] = ("ieee", 'bf16x3', 'bf16x6')
    enable_fp_fusion: bool = True
    launch_cooperative_grid: bool = False
    matrix_instr_nonkdim: int = 0
    kpack: int = 1
    allow_flush_denorm: bool = False
    max_num_imprecise_acc_default: int = 0
    backend_name: str = 'hip'
    instrumentation_mode: str = ""

    # The following option provides hints to the AMDGPU backend regarding instruction scheduling
    # for all `tt.dot` operations in a kernel. The "none" variant preserves the default
    # instruction scheduling of the AMDGPU backend which aims at maximizing occupancy.
    # The option is experimental and may change at any time regarding its semantics and/or may
    # be gone entirely anytime.
    #
    # Current experimental scheduling variants:
    #
    # attention: enables a bunch of optimizations for attention kernels, including:
    #            - iglp 2 and sched.barrier around it
    #            - sink-insts-to-avoid-spills flag to avoid register spills
    # memory-bound-attention: enables custom scheduling strategy in llvm backend,
    #            This option targets special FA variant, which is memory bound and
    #            has a lot of elementwise operations from fused operand dequantizations.
    #            Note that this option is highly experimental,
    #            and will be removed as soon as default scheduler algorithm is fixed.
    #
    # Option allows to set multiple variants divided by commas:
    # schedule_hint="attention,memory-bound-attention"
    schedule_hint: str = 'none'

    def __post_init__(self):
        """Validate the options and derive ``warp_size``, ``kpack``, ``extern_libs``."""
        if self.arch is None:
            # Slicing None below would raise an opaque "'NoneType' object is
            # not subscriptable"; keep the exception type, name the cause.
            raise TypeError("HIPOptions requires an arch string such as 'gfx942', got None")
        gfx_major = int(self.arch[3:-2])  # Drop "gfx" prefix and minor/patch number
        warp_size = 32 if gfx_major >= 10 else 64
        # The dataclass is frozen, so every write goes through object.__setattr__.
        object.__setattr__(self, 'warp_size', warp_size)
        assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
            "num_warps must be a power of 2"

        if (self.arch == 'gfx950') and (self.kpack != 1):
            warnings.warn(
                f"kpack is deprecated starting from gfx950 and will be removed in later releases. So for now kpack = {self.kpack} will be overwritten to 1 to make transitioning easier."
            )
            object.__setattr__(self, 'kpack', 1)

        default_libdir = Path(__file__).parent / 'lib'
        extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
        # The bundled device libraries always win over caller-supplied paths.
        for lib in ["ocml", "ockl"]:
            extern_libs[lib] = str(default_libdir / f'{lib}.bc')
        object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))

    def hash(self):
        """Return the SHA-256 hex digest of all instance attributes.

        The key is built from ``self.__dict__``, so it covers the declared
        fields as well as the derived ``warp_size``.
        """
        key = '_'.join(f'{name}-{val}' for name, val in self.__dict__.items())
        return hashlib.sha256(key.encode("utf-8")).hexdigest()
100
+
101
+
102
class HIPBackend(BaseBackend):
    """Compiler backend that registers the HIP compilation stages.

    The stages are installed by ``add_stages`` and chained as
    ttir/ttgir -> llir -> amdgcn -> hsaco; each ``make_*`` static method
    implements one of them.
    """
    # Optional instrumentation hook object; when set, it gets to load its
    # dialects and patch the pass managers in make_llir.
    instrumentation = None
    supports_native_tensor_specialization = False

    @staticmethod
    def supports_target(target: GPUTarget):
        """Return True when ``target.backend`` is ``'hip'``."""
        return target.backend == 'hip'

    def __init__(self, target: GPUTarget) -> None:
        """Store the target and set the binary extension to ``"hsaco"``."""
        super().__init__(target)
        assert isinstance(target.arch, str)
        self.binary_ext = "hsaco"

    def get_target_name(self, options) -> str:
        """Return ``"hip:<arch>"`` for the arch recorded in ``options``."""
        return f"hip:{options.arch}"

    def parse_options(self, opts) -> Any:
        """Build a ``HIPOptions`` instance from the user option mapping.

        Raises ValueError when more than one CTA is requested and
        ``amd.supports_multi_cta_launch`` rejects the target arch.
        Entries of ``opts`` that name a ``HIPOptions`` field and are not None
        override the values computed here.
        """
        # NOTE(review): the arch placed in the options honours the override
        # knob, but the feature checks below look at self.target.arch -
        # confirm this asymmetry is intended.
        args = {'arch': knobs.runtime.override_arch or self.target.arch}

        if opts.get("num_ctas", 1) > 1 and not amd.supports_multi_cta_launch(self.target.arch):
            raise ValueError(f"num_ctas > 1 not supported on {self.target.arch}")

        # Enable XF32 (TF32) for CDNA3 GPUs
        if self.target.arch == 'gfx942':
            allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions)
            allowed_dot_input_precisions.update({'tf32'})
            args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions))

        if "supported_fp8_dtypes" not in opts:
            args["supported_fp8_dtypes"] = tuple(sorted(HIPOptions.supported_fp8_dtypes))

        if self.target.arch == 'gfx950':
            deprecated_fp8_dot_operand_dtypes = set(HIPOptions.deprecated_fp8_dot_operand_dtypes)
            deprecated_fp8_dot_operand_dtypes.update({"fp8e5b16", "fp8e4b8"})
            args["deprecated_fp8_dot_operand_dtypes"] = tuple(sorted(deprecated_fp8_dot_operand_dtypes))

        if "enable_fp_fusion" not in opts:
            args["enable_fp_fusion"] = knobs.language.default_fp_fusion
        # Explicit, non-None user options take precedence over the defaults above.
        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts and opts[k] is not None})
        return HIPOptions(**args)

    def pack_metadata(self, metadata):
        """Return ``(num_warps, num_ctas, shared)`` taken from ``metadata``."""
        return (
            metadata.num_warps,
            metadata.num_ctas,
            metadata.shared,
        )

    def get_codegen_implementation(self, options):
        """Return the codegen hooks; only ``"min_dot_size"`` is provided."""
        return {"min_dot_size": get_min_dot_size(self.target)}

    def get_module_map(self) -> Dict[str, ModuleType]:
        """Map the generic libdevice module name to the HIP implementation."""
        from triton.language.extra.hip import libdevice

        return {"triton.language.extra.libdevice": libdevice}

    def load_dialects(self, ctx):
        """Load the AMD dialects, plus the instrumentation dialects if set."""
        amd.load_dialects(ctx)
        if HIPBackend.instrumentation:
            HIPBackend.instrumentation.load_dialects(ctx)

    @staticmethod
    def is_within_2gb(arg):
        """Return True when the size reported for ``arg`` fits in a signed 32-bit int.

        Uses ``arg.ptr_range()`` when available, otherwise the untyped
        storage size of a torch tensor; any other argument yields False.
        """
        import torch

        MAX_INT_32 = 2**31 - 1
        if hasattr(arg, "ptr_range"):
            return arg.ptr_range() <= MAX_INT_32
        if isinstance(arg, torch.Tensor) and hasattr(arg, "untyped_storage"):
            return arg.untyped_storage().size() <= MAX_INT_32
        return False

    @staticmethod
    def parse_attr(desc):
        """Extend the base attributes with a 32-bit pointer range for "S" descriptors."""
        ret = BaseBackend.parse_attr(desc)
        if "S" in desc:
            ret += [["tt.pointer_range", 32]]
        return ret

    @staticmethod
    def get_tensor_specialization(arg, **kwargs):
        """Append "S" to the base specialization when buffer ops are enabled
        and ``arg`` passes ``is_within_2gb``."""
        ret = BaseBackend.get_tensor_specialization(arg, **kwargs)
        if knobs.amd.use_buffer_ops and HIPBackend.is_within_2gb(arg):
            ret += "S"
        return ret

    @staticmethod
    def make_ttir(mod, metadata, options):
        """Run the 'make_ttir' pass pipeline on ``mod`` in place and return it."""
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_combine(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.ttir.add_triton_licm(pm)
        passes.common.add_symbol_dce(pm)
        passes.ttir.add_loop_unroll(pm)
        pm.run(mod, 'make_ttir')
        return mod

    @staticmethod
    def make_ttgir(mod, metadata, options):
        """Convert ``mod`` to TritonGPU IR and run the AMD optimization passes.

        Runs two pass managers ('make_ttgir_early' for the conversion, then
        'make_ttgir'), stores ``mod.get_tensordesc_metadata()`` under
        ``metadata["tensordesc_meta"]`` and returns ``mod``.
        """
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.ttir.add_convert_to_ttgpuir(pm, f"hip:{options.arch}", options.num_warps, options.warp_size,
                                           options.num_ctas)
        pm.run(mod, 'make_ttgir_early')
        # Fresh pass manager for everything that follows the conversion.
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        emuTF32 = False
        passes.ttgpuir.add_coalesce(pm)
        passes.ttgpuir.add_f32_dot_tc(pm, emuTF32)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_thread_locality(pm)
        amd.passes.ttgpuir.add_accelerate_matmul(pm, options.arch, options.matrix_instr_nonkdim, options.kpack)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        amd.passes.ttgpuir.add_optimize_epilogue(pm)
        amd.passes.ttgpuir.add_optimize_dot_operands(pm, options.arch)
        amd.passes.ttgpuir.add_hoist_layout_conversions(pm)

        passes.ttgpuir.add_fuse_nested_loops(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_triton_licm(pm)
        passes.common.add_canonicalizer(pm)

        use_async_copy = knobs.amd.use_async_copy
        use_block_pingpong = is_pingpong_schedule_enabled(options.arch, use_async_copy)

        amd.passes.ttgpuir.add_schedule_loops(pm, options.num_stages)
        amd.passes.ttgpuir.add_pipeline(pm, use_async_copy, use_block_pingpong)
        if use_async_copy:
            amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
        passes.common.add_canonicalizer(pm)
        # schedule_hint may carry several comma-separated variants; one hint
        # insertion pass is added per variant.
        if options.schedule_hint.lower() != "none":
            for hint in options.schedule_hint.split(","):
                amd.passes.ttgpuir.insert_instruction_sched_hints(pm, hint)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_reduce_data_duplication(pm)
        if is_in_thread_transpose_enabled(options.arch):
            amd.passes.ttgpuir.add_in_thread_transpose(pm)
            passes.ttgpuir.add_remove_layout_conversions(pm)
        amd.passes.ttgpuir.add_reorder_instructions(pm)
        if use_block_pingpong and options.num_stages > 1:
            amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)

        if knobs.amd.use_buffer_ops:
            amd.passes.ttgpuir.add_canonicalize_pointers(pm)
            passes.common.add_canonicalizer(pm)
            amd.passes.ttgpuir.add_convert_to_buffer_ops(
                pm,
                options.arch,
                knobs.amd.use_buffer_atomics,
                knobs.amd.buffer_ops_analyze_small_tensor_range,
            )

        amd.passes.ttgpuir.add_fold_true_cmpi(pm)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        pm.run(mod, 'make_ttgir')
        metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
        return mod

    @staticmethod
    def gluon_to_ttgir(src, metadata, options):
        """Run the 'gluon_to_ttgir' pass pipeline on a Gluon module.

        Stores ``mod.get_tensordesc_metadata()`` under
        ``metadata["tensordesc_meta"]`` and returns the module.
        """
        mod = src
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()

        passes.gluon.add_inliner(pm)
        passes.gluon.add_resolve_auto_encodings(pm)
        passes.common.add_sccp(pm)
        passes.ttir.add_loop_aware_cse(pm)
        passes.gluon.add_canonicalizer(pm)
        passes.ttgpuir.add_combine_tensor_select_and_if(pm)

        pm.run(mod, 'gluon_to_ttgir')
        metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
        return mod

    @staticmethod
    def make_llir(src, metadata, options):
        """Lower TritonGPU IR to textual LLVM IR.

        Lowers the MLIR module to the LLVM dialect, translates it to an LLVM
        module, attaches target information and kernel attributes to the
        first defined function, links the external libraries, optimizes at
        O3 and returns ``str(llvm_mod)``.  Fills ``metadata["shared"]``,
        ``metadata["profile_scratch_size"]`` and
        ``metadata["profile_scratch_align"]`` from module attributes.
        """
        mod = src
        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        amd.passes.ttgpuir.add_update_async_wait_count(pm, options.arch)
        # custom_lds_size is an experimental parameter that defines amount of LDS available
        # for one thread block. Measured in bytes.
        #
        # If custom_lds_size = 0, pass will consider all LDS is available for one threads block,
        # LDS size is determined by provided arch name.
        custom_lds_size = 0
        amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size)
        passes.convert.add_scf_to_cf(pm)
        passes.gluon.add_inliner(pm)
        passes.convert.add_index_to_llvmir(pm)

        amd.passes.ttgpuir.add_allocate_shared_memory(pm)
        # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
        if HIPBackend.instrumentation:
            HIPBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
        ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
        ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
        ##    of the value of kernel arg `allow_flush_denorm`.
        ## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output
        ##    depends on the value of kernel arg `allow_flush_denorm`.
        ## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
        ##    For now it is used as a controller for developers only.
        __HIP_FTZ = True
        amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)

        passes.convert.add_cf_to_llvmir(pm)
        passes.convert.add_arith_to_llvmir(pm)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)

        if options.schedule_hint.lower() != "none":
            amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)

        # This can not be moved below the di_scope pass
        if HIPBackend.instrumentation:
            HIPBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)

        if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables:
            passes.llvmir.add_di_scope(pm)

        amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
        pm.run(mod, 'make_llir')

        if knobs.compilation.dump_ir_extract_di_local_variables:
            # comments below on why separate it
            if not knobs.compilation.disable_line_info:
                pm = ir.pass_manager(mod.context)
                pm.enable_debug()
                passes.llvmir.add_di_scope(pm)
                pm.run(mod, 'make_llir.disable_line_info')

            # insert dbg intrinsic with several DI Attribute including source
            # var name and type info note: unknown reason for now, but this
            # pass and add_di_scope has to be run separately, otherwise if we
            # put them into previous pipeline, it trigger a segmentation fault
            # without any error message; could be due to a bug in mlir or pybind11
            pm = ir.pass_manager(mod.context)
            pm.enable_debug()
            passes.llvmir.add_di_local_variable(pm)
            pm.run(mod, 'make_llir.dump_ir_extract_di_local_variables')

        # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
        llvm.init_targets()
        context = llvm.context()
        llvm_mod = llvm.to_module(mod, context)
        amd.attach_target_triple(llvm_mod)
        target_features = ''
        if knobs.compilation.enable_asan:
            target_features = '+xnack'
        llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, target_features)

        # Set various control constants on the LLVM module so that device
        # libraries can resolve references to them.
        amd.set_isa_version(llvm_mod, options.arch)
        amd.set_abi_version(llvm_mod, 500)
        amd.set_bool_control_constant(llvm_mod, "__oclc_finite_only_opt", False)
        amd.set_bool_control_constant(llvm_mod, "__oclc_correctly_rounded_sqrt32", True)
        amd.set_bool_control_constant(llvm_mod, "__oclc_unsafe_math_opt", False)
        amd.set_bool_control_constant(llvm_mod, "__oclc_wavefrontsize64", options.warp_size == 64)

        # Set kernel attributes first given this may affect later optimizations.
        fns = [fn for fn in llvm_mod.get_functions() if not fn.is_declaration()]
        # The public kernel should be kernel 0.
        fns[0].set_calling_conv(amd.CALLING_CONV_AMDGPU_KERNEL)
        fns[0].add_fn_attr("amdgpu-flat-work-group-size", f"1,{options.num_warps*options.warp_size}")
        if "memory-bound-attention" in options.schedule_hint.split(','):
            fns[0].add_fn_attr("amdgpu-sched-strategy", "iterative-ilp")
        fns[0].add_fn_attr("uniform-work-group-size", "true")
        # LLVM AMDGPU backend supports the attribute "amdgpu-waves-per-eu"="<min>[, <max>]".
        # This attribute may be attached to a kernel function definition and is an optimization hint.
        # <min> parameter specifies the requested minimum number of waves per EU, and optional <max> parameter
        # specifies the requested maximum number of waves per EU (must be >= <min> if specified).
        # If <max> is omitted, then there is no restriction on the maximum number of waves per EU other than
        # the one dictated by the hardware for which the kernel is compiled. Passing 0, 0 as <min>, <max>
        # implies the default behavior (no limits).
        # Specifying N, N forces LLVM to focus on a single register count, simplifies some heuristics
        # and may improve scheduling.
        fns[0].add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}, {options.waves_per_eu}")
        denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee"
        fns[0].add_fn_attr("denormal-fp-math-f32", denormal_mode)
        if knobs.compilation.enable_asan:
            fns[0].add_fn_target_feature("+xnack")
            fns[0].add_fn_asan_attr()

        # Hint the compiler that we'd like the firmware to set the kernel arguments
        # to user SGPRs so that the kernel does not need to s_load its arguments
        # from memory.
        amd.set_all_fn_arg_inreg(fns[0])

        # With ASAN the three bundled bitcode libraries are linked
        # unconditionally; otherwise only the extern libs the module needs.
        if knobs.compilation.enable_asan:
            default_libdir = Path(__file__).parent / 'lib'
            paths = [
                str(default_libdir / 'asanrtl.bc'),
                str(default_libdir / "ocml.bc"),
                str(default_libdir / "ockl.bc")
            ]
            llvm.link_extern_libs(llvm_mod, paths)
        elif options.extern_libs:
            paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)]
            if len(paths) > 0:
                llvm.link_extern_libs(llvm_mod, paths)

        llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)

        # Architectures with architected SGPRs store the workgroup id in ttmp9 (X) and ttmp7 (Y[15:0], Z[31:16]).
        # These attributes are used to determine if Z should be masked out when loading Y. They are inferred during
        # optimize_module from calls to @llvm.amdgcn.workgroup.id.x/y/z(). We cannot rely on this because a
        # dispatch dimensions might be used even if there is no program_id() call for it.
        if amd.has_architected_sgprs(options.arch):
            fns[0].remove_fn_attr("amdgpu-no-workgroup-id-x")
            fns[0].remove_fn_attr("amdgpu-no-workgroup-id-y")
            fns[0].remove_fn_attr("amdgpu-no-workgroup-id-z")

        if knobs.amd.scalarize_packed_fops:
            amd.add_scalarize_packed_fops_llvm_pass(fns[0])

        # Get some metadata
        metadata["shared"] = src.get_int_attr("ttg.shared")
        metadata["profile_scratch_size"] = src.get_int_attr("ttg.profile_scratch_memory_size") or 0
        metadata["profile_scratch_align"] = src.get_int_attr("ttg.profile_scratch_memory_alignment") or 1

        amd.cleanup_bitcode_metadata(llvm_mod)
        # Disable inlining of print related functions,
        # because inlining of these function could slow down compilation significantly
        amd.disable_print_inline(llvm_mod)
        return str(llvm_mod)

    @staticmethod
    def make_amdgcn(src, metadata, options):
        """Translate textual LLVM IR to AMDGCN assembly and return it.

        Records the single ``amdgpu_kernel`` name found in ``src`` under
        ``metadata["name"]``; asserts that exactly one such kernel exists.
        Prints the assembly when ``knobs.amd.dump_amdgcn`` is set.
        """
        # Find kernel names (there should only be one)
        # We get the name at the last possible step to accommodate `triton.compile`
        # on user-provided LLVM
        names = re.findall(r"define amdgpu_kernel void @([a-zA-Z_][a-zA-Z0-9_]*)", src)
        assert len(names) == 1
        metadata["name"] = names[0]
        # llvm -> hsaco
        flags = []
        features = '-real-true16' if 'gfx11' in options.arch else ''
        # Kernel name plus IR hash identifies the dump files of this kernel.
        ir_hash = hashlib.sha256(src.encode("utf-8")).hexdigest()
        dump_file_id = names[0] + '_' + ir_hash
        _ = llvm.translate_to_mir(src, amd.TARGET_TRIPLE, options.arch, features, flags, options.enable_fp_fusion,
                                  dump_file_id)
        llvm.dump_sched_dag(src, amd.TARGET_TRIPLE, options.arch, features, flags, options.enable_fp_fusion,
                            dump_file_id)
        amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, features, flags, options.enable_fp_fusion,
                                       False)
        if knobs.amd.dump_amdgcn:
            print("// -----// AMDGCN Dump //----- //")
            print(amdgcn)
        return amdgcn

    @staticmethod
    def make_hsaco(src, metadata, options):
        """Assemble AMDGCN source, link it and return the linked file's bytes.

        The assembled object is written to one temporary file and linked
        into a second one with ``amd.link_hsaco``.
        """
        target_features = ''
        if knobs.compilation.enable_asan:
            target_features = '+xnack'
        hsaco = amd.assemble_amdgcn(src, options.arch, target_features)
        # The temporary files are reopened by name, since the linker takes paths.
        with tempfile.NamedTemporaryFile() as tmp_out:
            with tempfile.NamedTemporaryFile() as tmp_in:
                with open(tmp_in.name, "wb") as fd_in:
                    fd_in.write(hsaco)
                amd.link_hsaco(tmp_in.name, tmp_out.name)
            with open(tmp_out.name, "rb") as fd_out:
                ret = fd_out.read()
        return ret

    def add_stages(self, stages, options, language):
        """Register the compilation stages for ``language`` in ``stages``.

        Triton sources get "ttir" and "ttgir"; Gluon sources get only
        "ttgir".  "llir", "amdgcn" and "hsaco" are always added.  The
        optional inspection hook is called last with the populated mapping.
        """
        if language == Language.TRITON:
            stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
            stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options)
        elif language == Language.GLUON:
            stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options)
        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
        stages["amdgcn"] = lambda src, metadata: self.make_amdgcn(src, metadata, options)
        stages["hsaco"] = lambda src, metadata: self.make_hsaco(src, metadata, options)
        if knobs.runtime.add_stages_inspection_hook is not None:
            knobs.runtime.add_stages_inspection_hook(self, stages, options, language, None)

    # NOTE(review): lru_cache on an instance method keys on ``self`` and keeps
    # every backend instance alive for the cache's lifetime - confirm this is
    # acceptable here.
    @functools.lru_cache()
    def hash(self):
        """Return the string form of the target (memoized per instance)."""
        return f'{self.target}'
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/driver.c ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #define __HIP_PLATFORM_AMD__
2
+ #include <hip/hip_runtime.h>
3
+ #include <hip/hip_runtime_api.h>
4
+ #define PY_SSIZE_T_CLEAN
5
+ #include <Python.h>
6
+ #include <dlfcn.h>
7
+ #include <stdbool.h>
8
+ #include <stdio.h>
9
+ #include <stdlib.h>
10
+
11
// Raw TDM descriptor as twelve 32-bit words: group0 is 4 dwords (128 bits)
// and group1 is 8 dwords (256 bits). The field order is the binary layout
// and must not change.
typedef struct {
  // group0
  uint32_t group0_0;
  uint32_t group0_1;
  uint32_t group0_2;
  uint32_t group0_3;
  // group1
  uint32_t group1_0;
  uint32_t group1_1;
  uint32_t group1_2;
  uint32_t group1_3;
  uint32_t group1_4;
  uint32_t group1_5;
  uint32_t group1_6;
  uint32_t group1_7;
} TDMDescriptor;
25
+
26
+ typedef struct {
27
+ PyObject_HEAD;
28
+ TDMDescriptor desc;
29
+ } PyTDMDescriptorObject;
30
+
31
+ static PyObject *PyTDMDescriptor_new(PyTypeObject *type, PyObject *args,
32
+ PyObject *kw) {
33
+ PyTDMDescriptorObject *self =
34
+ (PyTDMDescriptorObject *)type->tp_alloc(type, 0);
35
+ if (!self)
36
+ return NULL;
37
+
38
+ memset(&self->desc, 0, sizeof(self->desc));
39
+ return (PyObject *)self;
40
+ }
41
+
42
+ static void PyTDMDescriptor_dealloc(PyTDMDescriptorObject *self) {
43
+ Py_TYPE(self)->tp_free((PyObject *)self);
44
+ }
45
+
46
+ static PyTypeObject PyTDMDescriptorType = {
47
+ PyVarObject_HEAD_INIT(NULL, 0).tp_name =
48
+ "triton.backends.amd.PyTDMDescriptor",
49
+ .tp_basicsize = sizeof(PyTDMDescriptorObject),
50
+ .tp_itemsize = 0,
51
+ .tp_flags = Py_TPFLAGS_DEFAULT,
52
+ .tp_doc = "PyObject for TDMDescriptor",
53
+ .tp_new = PyTDMDescriptor_new,
54
+ .tp_dealloc = (destructor)PyTDMDescriptor_dealloc,
55
+ };
56
+
57
+ // TODO: Both host-side and device-side TDM descriptor follow the same encoding
58
+ // format. Consider to add a common utility to remove duplicate code.
59
+ static bool encodeTDMDescriptor(TDMDescriptor *desc, int elementBitWidth,
60
+ uint32_t *blockSize, int numWarps,
61
+ int padInterval, int padAmount, uint32_t *shape,
62
+ uint32_t *strides, uint64_t globalAddress,
63
+ int rank) {
64
+ // NYI: TDM > 2D cases
65
+ if (rank != 2)
66
+ return false;
67
+
68
+ // Get warp distribution
69
+ uint32_t numWarpsDim0 = numWarps;
70
+ for (; numWarpsDim0 > blockSize[0]; numWarpsDim0 /= 2)
71
+ ;
72
+ uint32_t numWarpsDim1 = numWarps / numWarpsDim0;
73
+ if (!(numWarpsDim0 > 0 && blockSize[1] % numWarpsDim1 == 0))
74
+ return false;
75
+
76
+ uint32_t blockSize0 = (blockSize[0] + numWarpsDim0 - 1) / numWarpsDim0;
77
+ uint32_t blockSize1 = (blockSize[1] + numWarpsDim1 - 1) / numWarpsDim1;
78
+
79
+ // group0 (128 bits / 4 dwords) effective bit encoding:
80
+ // [120:64]: global address
81
+ // [127:126]: type - currently always set to 0x2
82
+ desc->group0_2 = (uint32_t)(globalAddress & 0xFFFFFFFF);
83
+ desc->group0_3 = (uint32_t)((globalAddress >> 32) & 0x01FFFFFF);
84
+ desc->group0_3 |= (0x1 << 31);
85
+
86
+ // group1 (256 bits / 8 dwords) effective bit encoding:
87
+ // [17:16]: data size - log2(element size in bytes)
88
+ // [20]: enable padding
89
+ // [24:22]: pad interval - log2(pad interval in dwords) - 1
90
+ // [31:25]: pad amount - pad amount in dwords - 1
91
+ // [79:48]: tensor shape dim inner
92
+ // [111:80]: tensor shape dim outer
93
+ // [127:112]: block shape dim inner
94
+ // [143:128]: block shape dim outer
95
+ // [207:160]: tensor stride dim outer (we only use 32 bits)
96
+ int elementSizeInBytes = elementBitWidth / 8;
97
+ int dataSize = log2(elementSizeInBytes);
98
+ desc->group1_0 = (dataSize << 16);
99
+ int dwordSize = 32;
100
+ int padIntervalInDwords = padInterval * elementBitWidth / dwordSize;
101
+ int padAmountInDwords = padAmount * elementBitWidth / dwordSize;
102
+ if (padIntervalInDwords > 0 && padAmountInDwords > 0) {
103
+ int log2PadInterval = log2(padIntervalInDwords);
104
+ desc->group1_0 |= (1 << 20);
105
+ desc->group1_0 |= ((log2PadInterval - 1) << 22);
106
+ desc->group1_0 |= ((padAmountInDwords - 1) << 25);
107
+ }
108
+ desc->group1_1 = (shape[1] << 16);
109
+ desc->group1_2 = (shape[1] >> 16);
110
+ desc->group1_2 |= (shape[0] << 16);
111
+ desc->group1_3 = (shape[0] >> 16);
112
+ desc->group1_3 |= (blockSize1 << 16);
113
+ desc->group1_4 = (blockSize0 & 0xFFFF);
114
+ desc->group1_5 = strides[0];
115
+
116
+ return true;
117
+ }
118
+
119
+ // The list of paths to search for the HIP runtime library. The caller Python
120
+ // code should substitute the search path placeholder.
121
+ static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"};
122
+
123
+ // The list of HIP dynamic library symbols and their signature we are interested
124
+ // in this file.
125
+ // |FOR_EACH_ERR_FN| is a macro to process APIs that return hipError_t;
126
+ // |FOR_EACH_STR_FN| is a macro to process APIs that return const char *.
127
+ #define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \
128
+ FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError) \
129
+ FOR_EACH_ERR_FN(hipGetDeviceProperties, hipDeviceProp_t *prop, int deviceId) \
130
+ FOR_EACH_ERR_FN(hipModuleLoadDataEx, hipModule_t *module, const void *image, \
131
+ unsigned int numOptions, hipJitOption *options, \
132
+ void **optionValues) \
133
+ FOR_EACH_ERR_FN(hipModuleGetFunction, hipFunction_t *function, \
134
+ hipModule_t module, const char *kname) \
135
+ FOR_EACH_ERR_FN(hipFuncGetAttribute, int *, hipFunction_attribute attr, \
136
+ hipFunction_t function)
137
+
138
+ // HIP driver version format: HIP_VERSION_MAJOR * 10000000 + HIP_VERSION_MINOR *
139
+ // 100000 + HIP_VERSION_PATCH.
140
+ #define TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version) ((version) / 10000000)
141
+ #define TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version) \
142
+ (((version) % 10000000) / 100000)
143
+ #define TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version) ((version) % 100000)
144
+ #define TRITON_HIP_DRIVER_REQ_MAJOR_VERSION (6)
145
+
146
// Define TRITON_HIP_DRIVER_DBG_VERSION to print the detected libamdhip64
// version. `msgBuff` must be a char array (sizeof is applied to it).
// The macros deliberately do not end in ';' so that an invocation followed by
// ';' forms exactly one statement (safe inside if/else without braces).
// #define TRITON_HIP_DRIVER_DBG_VERSION
#ifdef TRITON_HIP_DRIVER_DBG_VERSION
#define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff)                        \
  do {                                                                         \
    snprintf(msgBuff, sizeof(msgBuff), "libamdhip64 version is: %d.%d.%d",     \
             TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version),                 \
             TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version),                 \
             TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version));                \
    printf("%s\n", msgBuff);                                                   \
  } while (0)
#else
#define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff)                        \
  do {                                                                         \
    (void)msgBuff;                                                             \
    (void)(version);                                                           \
  } while (0)
#endif
163
+
164
+ #define TRITON_HIP_MSG_BUFF_SIZE (1024U)
165
+
166
+ // The HIP symbol table for holding resolved dynamic library symbols.
167
+ struct HIPSymbolTable {
168
+ #define DEFINE_EACH_ERR_FIELD(hipSymbolName, ...) \
169
+ hipError_t (*hipSymbolName)(__VA_ARGS__);
170
+ #define DEFINE_EACH_STR_FIELD(hipSymbolName, ...) \
171
+ const char *(*hipSymbolName)(__VA_ARGS__);
172
+
173
+ HIP_SYMBOL_LIST(DEFINE_EACH_ERR_FIELD, DEFINE_EACH_STR_FIELD)
174
+ };
175
+
176
+ static struct HIPSymbolTable hipSymbolTable;
177
+
178
+ static int checkDriverVersion(void *lib) {
179
+ int hipVersion = -1;
180
+ const char *error = NULL;
181
+ typedef hipError_t (*hipDriverGetVersion_fn)(int *driverVersion);
182
+ hipDriverGetVersion_fn hipDriverGetVersion;
183
+ dlerror(); // Clear existing errors
184
+ hipDriverGetVersion =
185
+ (hipDriverGetVersion_fn)dlsym(lib, "hipDriverGetVersion");
186
+ error = dlerror();
187
+ if (error) {
188
+ PyErr_SetString(PyExc_RuntimeError,
189
+ "cannot query 'hipDriverGetVersion' from libamdhip64.so");
190
+ dlclose(lib);
191
+ return -1;
192
+ }
193
+
194
+ (void)hipDriverGetVersion(&hipVersion);
195
+ char msgBuff[TRITON_HIP_MSG_BUFF_SIZE] = {0};
196
+
197
+ const int hipMajVersion = TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(hipVersion);
198
+ if (hipMajVersion < TRITON_HIP_DRIVER_REQ_MAJOR_VERSION) {
199
+ const int hipMinVersion =
200
+ TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(hipVersion);
201
+ const int hipPatchVersion =
202
+ TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(hipVersion);
203
+ snprintf(msgBuff, sizeof(msgBuff),
204
+ "libamdhip64 version %d.%d.%d is not supported! Required major "
205
+ "version is >=%d.",
206
+ hipMajVersion, hipMinVersion, hipPatchVersion,
207
+ TRITON_HIP_DRIVER_REQ_MAJOR_VERSION);
208
+ PyErr_SetString(PyExc_RuntimeError, msgBuff);
209
+ dlclose(lib);
210
+ return -1;
211
+ }
212
+
213
+ TRITON_HIP_DRIVER_LOG_VERSION(hipVersion, msgBuff);
214
+
215
+ return hipVersion;
216
+ }
217
+
218
+ bool initSymbolTable() {
219
+ void *lib;
220
+
221
+ // Go through the list of search paths to dlopen the first HIP driver library.
222
+ int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
223
+ for (int i = 0; i < n; ++i) {
224
+ void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
225
+ if (handle) {
226
+ lib = handle;
227
+ // printf("[triton] chosen %s\n", hipLibSearchPaths[i]);
228
+ }
229
+ }
230
+
231
+ if (!lib) {
232
+ PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so");
233
+ return false;
234
+ }
235
+
236
+ int hipVersion = checkDriverVersion(lib);
237
+ if (hipVersion == -1)
238
+ return false;
239
+
240
+ const char *error = NULL;
241
+ typedef hipError_t (*hipGetProcAddress_fn)(
242
+ const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
243
+ hipDriverProcAddressQueryResult *symbolStatus);
244
+ hipGetProcAddress_fn hipGetProcAddress;
245
+ dlerror(); // Clear existing errors
246
+
247
+ *(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
248
+ error = dlerror();
249
+ if (error) {
250
+ PyErr_SetString(PyExc_RuntimeError,
251
+ "cannot query 'hipGetProcAddress' from libamdhip64.so");
252
+ dlclose(lib);
253
+ return false;
254
+ }
255
+
256
+ // Resolve all symbols we are interested in.
257
+ uint64_t hipFlags = 0;
258
+ hipDriverProcAddressQueryResult symbolStatus;
259
+ hipError_t status = hipSuccess;
260
+ #define QUERY_EACH_FN(hipSymbolName, ...) \
261
+ status = hipGetProcAddress(#hipSymbolName, \
262
+ (void **)&hipSymbolTable.hipSymbolName, \
263
+ hipVersion, hipFlags, &symbolStatus); \
264
+ if (status != hipSuccess) { \
265
+ PyErr_SetString(PyExc_RuntimeError, \
266
+ "cannot get address for '" #hipSymbolName \
267
+ "' from libamdhip64.so"); \
268
+ dlclose(lib); \
269
+ return false; \
270
+ }
271
+
272
+ HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)
273
+
274
+ return true;
275
+ }
276
+
277
+ static inline void gpuAssert(hipError_t code, const char *file, int line) {
278
+ {
279
+ if (code != HIP_SUCCESS) {
280
+ {
281
+ const char *prefix = "Triton Error [HIP]: ";
282
+ const char *str = hipSymbolTable.hipGetErrorString(code);
283
+ char err[TRITON_HIP_MSG_BUFF_SIZE] = {0};
284
+ snprintf(err, sizeof(err), "%s Code: %d, Messsage: %s", prefix, code,
285
+ str);
286
+ PyGILState_STATE gil_state;
287
+ gil_state = PyGILState_Ensure();
288
+ PyErr_SetString(PyExc_RuntimeError, err);
289
+ PyGILState_Release(gil_state);
290
+ }
291
+ }
292
+ }
293
+ }
294
+
295
+ #define HIP_CHECK(ans) \
296
+ { \
297
+ gpuAssert((ans), __FILE__, __LINE__); \
298
+ if (PyErr_Occurred()) \
299
+ return NULL; \
300
+ }
301
+
302
+ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
303
+ int device_id;
304
+ if (!PyArg_ParseTuple(args, "i", &device_id))
305
+ return NULL;
306
+
307
+ hipDeviceProp_t props;
308
+ HIP_CHECK(hipSymbolTable.hipGetDeviceProperties(&props, device_id));
309
+
310
+ // create a struct to hold device properties
311
+ return Py_BuildValue(
312
+ "{s:i, s:i, s:i, s:i, s:i, s:i, s:s, s:i, s:i}", "max_shared_mem",
313
+ props.sharedMemPerBlock, "max_num_regs", props.regsPerBlock,
314
+ "multiprocessor_count", props.multiProcessorCount, "sm_clock_rate",
315
+ props.clockRate, "mem_clock_rate", props.memoryClockRate, "mem_bus_width",
316
+ props.memoryBusWidth, "arch", props.gcnArchName, "warpSize",
317
+ props.warpSize, "max_threads_per_sm", props.maxThreadsPerMultiProcessor);
318
+ }
319
+
320
+ static PyObject *loadBinary(PyObject *self, PyObject *args) {
321
+ const char *name;
322
+ const char *data;
323
+ Py_ssize_t data_size;
324
+ int shared;
325
+ int device;
326
+ if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
327
+ &device)) {
328
+ return NULL;
329
+ }
330
+
331
+ // set HIP options
332
+ hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
333
+ hipJitOptionErrorLogBuffer,
334
+ hipJitOptionInfoLogBufferSizeBytes,
335
+ hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
336
+ const unsigned int errbufsize = 8192;
337
+ const unsigned int logbufsize = 8192;
338
+ char _err[errbufsize];
339
+ char _log[logbufsize];
340
+ void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err,
341
+ (void *)(uintptr_t)logbufsize, (void *)_log, (void *)1};
342
+
343
+ // launch HIP Binary
344
+ hipModule_t mod;
345
+ hipFunction_t fun;
346
+ HIP_CHECK(hipSymbolTable.hipModuleLoadDataEx(&mod, data, 5, opt, optval))
347
+ HIP_CHECK(hipSymbolTable.hipModuleGetFunction(&fun, mod, name));
348
+
349
+ // get allocated registers and spilled registers from the function
350
+ int n_regs = 0;
351
+ int n_spills = 0;
352
+ int32_t n_max_threads = 0;
353
+ hipSymbolTable.hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, fun);
354
+ hipSymbolTable.hipFuncGetAttribute(&n_spills,
355
+ HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
356
+ hipSymbolTable.hipFuncGetAttribute(
357
+ &n_max_threads, HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun);
358
+ n_spills /= 4;
359
+ if (PyErr_Occurred()) {
360
+ return NULL;
361
+ }
362
+ return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
363
+ n_spills, n_max_threads);
364
+ }
365
+
366
+ static PyObject *createTDMDescriptor(PyObject *self, PyObject *args) {
367
+ int elementBitWidth;
368
+ PyObject *blockSize;
369
+ int numWarps;
370
+ int padInterval;
371
+ int padAmount;
372
+ PyObject *shape;
373
+ PyObject *strides;
374
+ unsigned long long globalAddress;
375
+
376
+ if (!PyArg_ParseTuple(args, "iOiiiOOK", &elementBitWidth, &blockSize,
377
+ &numWarps, &padInterval, &padAmount, &shape, &strides,
378
+ &globalAddress)) {
379
+ return NULL;
380
+ }
381
+
382
+ PyTDMDescriptorObject *descObj = (PyTDMDescriptorObject *)PyObject_CallObject(
383
+ (PyObject *)&PyTDMDescriptorType, NULL);
384
+ if (!descObj)
385
+ return NULL;
386
+
387
+ PyObject *blockSizeFast = NULL;
388
+ PyObject *shapeFast = NULL;
389
+ PyObject *stridesFast = NULL;
390
+
391
+ uint32_t blockSizeInt[2];
392
+ uint32_t shapeInt[2];
393
+ uint32_t stridesInt[2];
394
+
395
+ blockSizeFast = PySequence_Fast(blockSize, "blockSize must be a sequence");
396
+ if (!blockSizeFast)
397
+ goto cleanup;
398
+ int rank = PySequence_Fast_GET_SIZE(blockSizeFast);
399
+ if (rank != 2) {
400
+ PyErr_SetString(PyExc_RuntimeError, "rank must be 2");
401
+ goto cleanup;
402
+ }
403
+
404
+ for (int i = 0; i < rank; ++i) {
405
+ PyObject *item = PySequence_Fast_GET_ITEM(blockSizeFast, i);
406
+ if (!PyLong_Check(item)) {
407
+ PyErr_SetString(PyExc_TypeError, "block size must be an int");
408
+ goto cleanup;
409
+ }
410
+ blockSizeInt[i] = PyLong_AsLong(item);
411
+ }
412
+
413
+ shapeFast = PySequence_Fast(shape, "shape must be a sequence");
414
+ if (!shapeFast)
415
+ goto cleanup;
416
+
417
+ if (rank != PySequence_Fast_GET_SIZE(shapeFast)) {
418
+ PyErr_SetString(PyExc_RuntimeError, "rank mismatch");
419
+ goto cleanup;
420
+ }
421
+ for (int i = 0; i < rank; ++i) {
422
+ PyObject *item = PySequence_Fast_GET_ITEM(shapeFast, i);
423
+ if (!PyLong_Check(item)) {
424
+ PyErr_SetString(PyExc_TypeError, "shape must be an int");
425
+ goto cleanup;
426
+ }
427
+ shapeInt[i] = PyLong_AsLong(item);
428
+ }
429
+
430
+ stridesFast = PySequence_Fast(strides, "strides must be a sequence");
431
+ if (!stridesFast)
432
+ goto cleanup;
433
+
434
+ if (rank != PySequence_Fast_GET_SIZE(stridesFast)) {
435
+ PyErr_SetString(PyExc_RuntimeError, "rank mismatch");
436
+ goto cleanup;
437
+ }
438
+ for (int i = 0; i < rank; ++i) {
439
+ PyObject *item = PySequence_Fast_GET_ITEM(stridesFast, i);
440
+ if (!PyLong_Check(item)) {
441
+ PyErr_SetString(PyExc_TypeError, "shape must be an int");
442
+ goto cleanup;
443
+ }
444
+ stridesInt[i] = PyLong_AsLong(item);
445
+ }
446
+
447
+ Py_DECREF(blockSizeFast);
448
+ blockSizeFast = NULL;
449
+ Py_DECREF(shapeFast);
450
+ shapeFast = NULL;
451
+ Py_DECREF(stridesFast);
452
+ stridesFast = NULL;
453
+
454
+ bool success = encodeTDMDescriptor(
455
+ &descObj->desc, elementBitWidth, blockSizeInt, numWarps, padInterval,
456
+ padAmount, shapeInt, stridesInt, globalAddress, rank);
457
+ if (!success) {
458
+ PyErr_SetString(PyExc_RuntimeError, "Failed to encode TDM descriptor");
459
+ goto cleanup;
460
+ }
461
+
462
+ return (PyObject *)descObj;
463
+
464
+ cleanup:
465
+ Py_XDECREF(blockSizeFast);
466
+ Py_XDECREF(shapeFast);
467
+ Py_XDECREF(stridesFast);
468
+ Py_XDECREF(descObj);
469
+ return NULL;
470
+ }
471
+
472
+ static PyMethodDef ModuleMethods[] = {
473
+ {"load_binary", loadBinary, METH_VARARGS,
474
+ "Load provided hsaco into HIP driver"},
475
+ {"get_device_properties", getDeviceProperties, METH_VARARGS,
476
+ "Get the properties for a given device"},
477
+ {"create_tdm_descriptor", createTDMDescriptor, METH_VARARGS,
478
+ "create a host-side TDM descriptor"},
479
+ {NULL, NULL, 0, NULL} // sentinel
480
+ };
481
+
482
+ static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "hip_utils",
483
+ NULL, // documentation
484
+ -1, // size
485
+ ModuleMethods};
486
+
487
+ PyMODINIT_FUNC PyInit_hip_utils(void) {
488
+ if (!initSymbolTable()) {
489
+ return NULL;
490
+ }
491
+
492
+ PyObject *m = PyModule_Create(&ModuleDef);
493
+ if (m == NULL) {
494
+ return NULL;
495
+ }
496
+ PyModule_AddFunctions(m, ModuleMethods);
497
+
498
+ if (PyType_Ready(&PyTDMDescriptorType) < 0)
499
+ return NULL;
500
+ Py_INCREF(&PyTDMDescriptorType);
501
+ PyModule_AddObject(m, "PyTDMDescriptor", (PyObject *)&PyTDMDescriptorType);
502
+
503
+ return m;
504
+ }
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/amd/driver.py ADDED
@@ -0,0 +1,877 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import os
3
+ import subprocess
4
+ import re
5
+ import triton
6
+ from pathlib import Path
7
+ from triton import knobs
8
+ from triton.backends.compiler import GPUTarget
9
+ from triton.backends.driver import GPUDriver
10
+ from triton.runtime import _allocation
11
+ from triton.runtime.build import compile_module_from_src
12
+
13
+ dirname = os.path.dirname(os.path.realpath(__file__))
14
+ include_dirs = [os.path.join(dirname, "include")]
15
+ PyTDMDescriptor = None
16
+
17
+
18
+ def _find_already_mmapped_dylib_on_linux(lib_name):
19
+ import platform
20
+ if platform.system() != 'Linux':
21
+ return None
22
+
23
+ # Use dl_iterate_phdr to walk through the list of shared libraries at runtime.
24
+ # See https://www.man7.org/linux/man-pages/man3/dl_iterate_phdr.3.html for details.
25
+
26
+ import ctypes
27
+ from ctypes import c_char, c_int, c_size_t, c_void_p, c_char_p, POINTER
28
+
29
+ class DlPhdrInfo(ctypes.Structure):
30
+ _fields_ = [
31
+ ('dlpi_addr', c_void_p),
32
+ ('dlpi_name', c_char_p),
33
+ # We don't care about the remaining fields.
34
+ ]
35
+
36
+ # callback_t must use POINTER(c_char) to avoid copying.
37
+ callback_t = ctypes.CFUNCTYPE(c_int, POINTER(DlPhdrInfo), POINTER(c_size_t), POINTER(c_char))
38
+
39
+ # Load libc and get the dl_iterate_phdr symbol.
40
+ try:
41
+ dl_iterate_phdr = ctypes.CDLL('libc.so.6').dl_iterate_phdr
42
+ except Exception:
43
+ return None
44
+ # argtypes must use c_char_p to accept create_string_buffer.
45
+ dl_iterate_phdr.argtypes = [callback_t, c_char_p]
46
+ dl_iterate_phdr.restype = c_int
47
+
48
+ max_path_length = 4096
49
+ path = ctypes.create_string_buffer(max_path_length + 1)
50
+
51
+ # Define callback to get the loaded dylib path.
52
+ def callback(info, size, data):
53
+ dlpi_name = info.contents.dlpi_name
54
+ p = Path(os.fsdecode(dlpi_name))
55
+ if lib_name in p.name:
56
+ # Found the dylib; get its path.
57
+ ctypes.memmove(data, dlpi_name, min(max_path_length, len(dlpi_name)))
58
+ return 1
59
+ return 0
60
+
61
+ if dl_iterate_phdr(callback_t(callback), path):
62
+ return os.fsdecode(ctypes.string_at(path))
63
+ return None
64
+
65
+
66
@functools.lru_cache()
def _get_path_to_hip_runtime_dylib():
    """Locate ``libamdhip64.so`` and return its path.

    Search order: the explicit ``knobs.amd.libhip_path`` override, a copy
    already mapped into this process, the backend's own ``lib`` directory,
    PyTorch's ``torch/lib`` in the site-packages directories,
    ``LD_LIBRARY_PATH``, ``HIP_PATH``, ``hipconfig --path``, ``ROCM_PATH``,
    the ``ldconfig`` cache and finally ``/opt/rocm/lib``.

    Raises:
        RuntimeError: if the explicit override or the mapped path is invalid,
            or if the library cannot be found; in the latter case the message
            lists every path that was tried.
    """
    lib_name = "libamdhip64.so"

    # If we are told explicitly what HIP runtime dynamic library to use, obey that.
    if env_libhip_path := knobs.amd.libhip_path:
        if env_libhip_path.endswith(lib_name) and os.path.exists(env_libhip_path):
            return env_libhip_path
        raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")

    # If the shared object is already mmapped to address space, use it.
    mmapped_path = _find_already_mmapped_dylib_on_linux(lib_name)
    if mmapped_path:
        if os.path.exists(mmapped_path):
            return mmapped_path
        raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}")

    # Every candidate that was checked, reported if nothing is found.
    paths = []

    # Check backend
    local_lib = os.path.join(os.path.dirname(__file__), "lib", lib_name)
    if os.path.exists(local_lib):
        return local_lib
    paths.append(local_lib)

    import site
    # First search the HIP runtime dynamic library packaged with PyTorch. It's very likely
    # that we run Triton together with PyTorch. This makes sure we use the same dynamic
    # library to avoid version mismatch.
    site_packages = site.getsitepackages()
    user_site = site.getusersitepackages()
    if site.ENABLE_USER_SITE:  # ENABLE_USER_SITE is initialized in getusersitepackages()
        site_packages = [user_site] + site_packages
    for path in site_packages:
        path = os.path.join(path, "torch", "lib", lib_name)
        if os.path.exists(path):
            return path
        paths.append(path)

    # Then try to see if developer provides a HIP runtime dynamic library using LD_LIBRARY_PATH.
    env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
    if env_ld_library_path:
        for d in env_ld_library_path.split(":"):
            f = os.path.join(d, lib_name)
            if os.path.exists(f):
                return f
            paths.append(f)

    # HIP_PATH should point to HIP SDK root if set
    env_hip_path = os.getenv("HIP_PATH")
    if env_hip_path:
        hip_lib_path = os.path.join(env_hip_path, "lib", lib_name)
        if os.path.exists(hip_lib_path):
            return hip_lib_path
        paths.append(hip_lib_path)

    # if available, `hipconfig --path` prints the HIP SDK root
    try:
        hip_root = subprocess.check_output(["hipconfig", "--path"]).decode().strip()
        if hip_root:
            hip_lib_path = os.path.join(hip_root, "lib", lib_name)
            if os.path.exists(hip_lib_path):
                return hip_lib_path
            paths.append(hip_lib_path)
    except (subprocess.CalledProcessError, FileNotFoundError):
        # hipconfig may not be available
        pass

    # ROCm lib dir based on env var
    env_rocm_path = os.getenv("ROCM_PATH")
    if env_rocm_path:
        rocm_lib_path = os.path.join(env_rocm_path, "lib", lib_name)
        if os.path.exists(rocm_lib_path):
            return rocm_lib_path
        paths.append(rocm_lib_path)

    # Afterwards try to search the loader dynamic library resolution paths.
    # ldconfig may be missing or fail; treat that as "nothing found" so that the
    # last-resort path and the descriptive error below are still reached.
    try:
        libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
    except (subprocess.CalledProcessError, OSError):
        libs = ""
    # each line looks like the following:
    # libamdhip64.so.6 (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so.6
    # libamdhip64.so (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so
    locs = [line.split()[-1] for line in libs.splitlines() if line.strip().endswith(lib_name)]
    for loc in locs:
        if os.path.exists(loc):
            return loc
        paths.append(loc)

    # As a last resort, guess if we have it in some common installation path.
    common_install_path = os.path.join('/opt/rocm/lib/', lib_name)
    if os.path.exists(common_install_path):
        return common_install_path
    paths.append(common_install_path)

    raise RuntimeError(f"cannot locate {lib_name} after attempted paths {paths}")
160
+
161
+
162
class HIPUtils(object):
    """Process-wide singleton exposing the helpers compiled from ``driver.c``.

    Constructing it compiles the C source (with the HIP library path
    substituted in) via ``compile_module_from_src`` and binds ``load_binary``,
    ``get_device_properties`` and ``create_tdm_descriptor`` on the instance.
    It also publishes the module's ``PyTDMDescriptor`` type through the
    module-level ``PyTDMDescriptor`` global.
    """

    def __new__(cls):
        # Create the shared instance on first use and reuse it afterwards.
        if not hasattr(cls, "instance"):
            cls.instance = super().__new__(cls)
        return cls.instance

    def __init__(self):
        global PyTDMDescriptor

        hip_library = _get_path_to_hip_runtime_dylib()
        c_source = Path(os.path.join(dirname, "driver.c")).read_text()
        # Just do a simple search and replace here instead of templates or format strings.
        # This way we don't need to escape-quote C code curly brackets and we can replace
        # exactly once.
        c_source = c_source.replace('/*py_libhip_search_path*/', hip_library, 1)
        module = compile_module_from_src(src=c_source, name="hip_utils", include_dirs=include_dirs)

        self.load_binary = module.load_binary
        self.get_device_properties = module.get_device_properties
        self.create_tdm_descriptor = module.create_tdm_descriptor
        PyTDMDescriptor = module.PyTDMDescriptor
182
+
183
+
184
+ # -------------------- Launcher ----------------------------
185
def ty_to_cpp(ty):
    """Map a signature type string to the C type used for its launcher argument.

    Pointer types (leading ``*``) map to ``hipDeviceptr_t`` and ``tensordesc``
    maps to ``TDMDescriptor``. All floating-point types map to ``double``.
    Raises ``KeyError`` for any other unknown type string.
    """
    if ty.startswith('*'):
        return "hipDeviceptr_t"
    if ty == "tensordesc":
        return "TDMDescriptor"

    scalar_types = {
        # signed integers (i1 is carried in an int8_t)
        "i1": "int8_t",
        "i8": "int8_t",
        "i16": "int16_t",
        "i32": "int32_t",
        "i64": "int64_t",
        # unsigned integers (u1 is carried in a uint8_t)
        "u1": "uint8_t",
        "u8": "uint8_t",
        "u16": "uint16_t",
        "u32": "uint32_t",
        "u64": "uint64_t",
        # floating point
        "fp16": "double",
        "bf16": "double",
        "fp32": "double",
        "f32": "double",
        "fp64": "double",
    }
    return scalar_types[ty]
207
+
208
+
209
+ FLOAT_STORAGE_TYPE = {
210
+ "fp16": "uint16_t",
211
+ "bf16": "uint16_t",
212
+ "fp32": "uint32_t",
213
+ "f32": "uint32_t",
214
+ "fp64": "uint64_t",
215
+ }
216
+ FLOAT_PACK_FUNCTION = {
217
+ "fp16": "pack_fp16",
218
+ "bf16": "pack_bf16",
219
+ "fp32": "pack_fp32",
220
+ "f32": "pack_fp32",
221
+ "fp64": "pack_fp64",
222
+ }
223
+
224
+ _BASE_ARGS_FORMAT = "piiiKKOOOOO"
225
+
226
+
227
+ def make_launcher(constants, signature, warp_size, tensordesc_meta):
228
+
229
+ def _expand_signature(signature):
230
+ output = []
231
+ tensordesc_idx = 0
232
+ for sig in signature:
233
+ if isinstance(sig, str) and sig.startswith("tensordesc"):
234
+ meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
235
+ tensordesc_idx += 1
236
+
237
+ match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", sig)
238
+ dtype = match.group(1)
239
+ shape = match.group(2)
240
+ ndim = shape.count(",") + 1
241
+
242
+ # If there is no descriptor's metadata, the descriptor has been decomposed to base pointer, shape and strides
243
+ if meta is None:
244
+ output.append("*" + dtype)
245
+ for _ in range(2 * ndim):
246
+ output.append("i64")
247
+ output.append("i1")
248
+ else:
249
+ output.append("tensordesc")
250
+
251
+ for _ in range(ndim):
252
+ output.append("i32")
253
+ for _ in range(ndim):
254
+ output.append("i64")
255
+ else:
256
+ output.append(sig)
257
+
258
+ return output
259
+
260
+ def _serialize_signature(sig):
261
+ if isinstance(sig, tuple):
262
+ return ','.join(map(_serialize_signature, sig))
263
+ return sig
264
+
265
+ def _extracted_type(ty):
266
+ if isinstance(ty, tuple):
267
+ val = ','.join(map(_extracted_type, ty))
268
+ return f"[{val}]"
269
+ if ty.startswith("*") or ty.startswith("tensordesc"):
270
+ return "PyObject*"
271
+ if ty == "constexpr":
272
+ return "PyObject*"
273
+ return ty_to_cpp(ty)
274
+
275
+ def format_of(ty):
276
+ if isinstance(ty, tuple):
277
+ val = ''.join(map(format_of, ty))
278
+ return f"({val})"
279
+ if ty.startswith("*") or ty.startswith("tensordesc"):
280
+ return "O"
281
+ if ty == "constexpr":
282
+ return "O"
283
+ return {
284
+ "double": "d",
285
+ "long": "l",
286
+ "int8_t": "b",
287
+ "int16_t": "h",
288
+ "int32_t": "i",
289
+ "int64_t": "L",
290
+ "uint8_t": "B",
291
+ "uint16_t": "H",
292
+ "uint32_t": "I",
293
+ "uint64_t": "K",
294
+ }[ty_to_cpp(ty)]
295
+
296
+ signature = {idx: s for idx, s in enumerate(_expand_signature(signature.values()))}
297
+
298
+ args_format = ''.join([format_of(ty) for ty in signature.values()])
299
+ format = _BASE_ARGS_FORMAT + args_format
300
+ signature = ','.join(map(_serialize_signature, signature.values()))
301
+ signature = list(filter(bool, signature.split(',')))
302
+ signature = {i: s for i, s in enumerate(signature)}
303
+ args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
304
+ # Record the end of regular arguments;
305
+ # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
306
+ arg_decl_list = []
307
+ for i, ty in signature.items():
308
+ if ty == "constexpr":
309
+ continue
310
+ if ty in FLOAT_STORAGE_TYPE:
311
+ arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
312
+ else:
313
+ arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
314
+ arg_decls = ', '.join(arg_decl_list)
315
+ internal_args_list = []
316
+ for i, ty in signature.items():
317
+ if ty.startswith("*"):
318
+ internal_args_list.append(f"ptr_info{i}.dev_ptr")
319
+ elif ty.startswith("tensordesc"):
320
+ internal_args_list.append(f"*desc{i}")
321
+ elif ty in FLOAT_STORAGE_TYPE:
322
+ internal_args_list.append(f"_arg{i}_storage")
323
+ elif ty != "constexpr":
324
+ internal_args_list.append(f"_arg{i}")
325
+
326
+ newline = '\n '
327
+ ptr_decls = [
328
+ f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;"
329
+ for i, ty in signature.items()
330
+ if ty.startswith("*")
331
+ ]
332
+ tensor_desc_decls = [
333
+ f"TDMDescriptor* desc{i} = getTDMDescriptor(_arg{i}, {i});" for i, ty in signature.items()
334
+ if ty.startswith("tensordesc")
335
+ ]
336
+ float_storage_decls = [
337
+ f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
338
+ for i, ty in signature.items()
339
+ if ty in FLOAT_STORAGE_TYPE
340
+ ]
341
+
342
+ libhip_path = _get_path_to_hip_runtime_dylib()
343
+
344
+ # generate glue code
345
+ params = list(range(len(signature)))
346
+ params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
347
+ params.append("&global_scratch")
348
+ params.append("&profile_scratch")
349
+ src = f"""
350
+ #define __HIP_PLATFORM_AMD__
351
+ #include <hip/hip_runtime.h>
352
+ #include <hip/hip_runtime_api.h>
353
+ #include <Python.h>
354
+ #include <dlfcn.h>
355
+ #include <stdbool.h>
356
+ #include <dlfcn.h>
357
+
358
+ typedef struct {{
359
+ uint32_t group0_0;
360
+ uint32_t group0_1;
361
+ uint32_t group0_2;
362
+ uint32_t group0_3;
363
+ uint32_t group1_0;
364
+ uint32_t group1_1;
365
+ uint32_t group1_2;
366
+ uint32_t group1_3;
367
+ uint32_t group1_4;
368
+ uint32_t group1_5;
369
+ uint32_t group1_6;
370
+ uint32_t group1_7;
371
+ }} TDMDescriptor;
372
+
373
+ typedef struct {{
374
+ PyObject_HEAD;
375
+ TDMDescriptor desc;
376
+ }} PyTDMDescriptorObject;
377
+
378
+ // The list of paths to search for the HIP runtime library. The caller Python
379
+ // code should substitute the search path placeholder.
380
+ static const char *hipLibSearchPaths[] = {{"{libhip_path}"}};
381
+
382
+ // The list of HIP dynamic library symbols and their signature we are interested
383
+ // in this file.
384
+ #define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \\
385
+ FOR_EACH_STR_FN(hipGetLastError, true) \\
386
+ FOR_EACH_STR_FN(hipGetErrorString, true, hipError_t hipError) \\
387
+ FOR_EACH_ERR_FN(hipDrvLaunchKernelEx, false, \\
388
+ const HIP_LAUNCH_CONFIG *config, \\
389
+ hipFunction_t f, \\
390
+ void **kernelParams, \\
391
+ void **extra) \\
392
+ FOR_EACH_ERR_FN(hipModuleLaunchKernel, true, hipFunction_t f, \\
393
+ unsigned int gridDimX, unsigned int gridDimY, \\
394
+ unsigned int gridDimZ, unsigned int blockDimX, \\
395
+ unsigned int blockDimY, unsigned int blockDimZ, \\
396
+ unsigned int sharedMemBytes, hipStream_t stream, \\
397
+ void **kernelParams, void **extra) \\
398
+ FOR_EACH_ERR_FN(hipModuleLaunchCooperativeKernel, true, hipFunction_t f, \\
399
+ unsigned int gridDimX, unsigned int gridDimY, \\
400
+ unsigned int gridDimZ, unsigned int blockDimX, \\
401
+ unsigned int blockDimY, unsigned int blockDimZ, \\
402
+ unsigned int sharedMemBytes, hipStream_t stream, \\
403
+ void **kernelParams, void **extra) \\
404
+ FOR_EACH_ERR_FN(hipPointerGetAttribute, true, void *data, \\
405
+ hipPointer_attribute attribute, hipDeviceptr_t ptr)
406
+
407
+ // The HIP symbol table for holding resolved dynamic library symbols.
408
+ struct HIPSymbolTable {{
409
+ #define DEFINE_EACH_ERR_FIELD(hipSymbolName, required, ...) \\
410
+ hipError_t (*hipSymbolName)(__VA_ARGS__);
411
+ #define DEFINE_EACH_STR_FIELD(hipSymbolName, required, ...) \\
412
+ const char *(*hipSymbolName)(__VA_ARGS__);
413
+
414
+ HIP_SYMBOL_LIST(DEFINE_EACH_ERR_FIELD, DEFINE_EACH_STR_FIELD)
415
+ }};
416
+
417
+ static struct HIPSymbolTable hipSymbolTable;
418
+
419
+ bool initSymbolTable() {{
420
+ // Use the HIP runtime library loaded into the existing process if it exits.
421
+ void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD);
422
+
423
+ // Otherwise, go through the list of search paths to dlopen the first HIP
424
+ // driver library.
425
+ if (!lib) {{
426
+ int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
427
+ for (int i = 0; i < n; ++i) {{
428
+ void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
429
+ if (handle) {{
430
+ lib = handle;
431
+ }}
432
+ }}
433
+ }}
434
+ if (!lib) {{
435
+ PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so");
436
+ return false;
437
+ }}
438
+
439
+ typedef hipError_t (*hipGetProcAddress_fn)(
440
+ const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
441
+ hipDriverProcAddressQueryResult *symbolStatus);
442
+ hipGetProcAddress_fn hipGetProcAddress;
443
+ dlerror(); // Clear existing errors
444
+ const char *error = NULL;
445
+ *(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
446
+ error = dlerror();
447
+ if (error) {{
448
+ PyErr_SetString(PyExc_RuntimeError,
449
+ "cannot query 'hipGetProcAddress' from libamdhip64.so");
450
+ dlclose(lib);
451
+ return false;
452
+ }}
453
+
454
+ // Resolve all symbols we are interested in.
455
+ int hipVersion = HIP_VERSION;
456
+ uint64_t hipFlags = 0;
457
+ hipDriverProcAddressQueryResult symbolStatus;
458
+ hipError_t status = hipSuccess;
459
+ #define QUERY_EACH_FN(hipSymbolName, required, ...) \
460
+ status = hipGetProcAddress(#hipSymbolName, \
461
+ (void **)&hipSymbolTable.hipSymbolName, \
462
+ hipVersion, hipFlags, &symbolStatus); \
463
+ if (required && status != hipSuccess) {{ \
464
+ PyErr_SetString(PyExc_RuntimeError, \
465
+ "cannot get address for '" #hipSymbolName \
466
+ "' from libamdhip64.so"); \
467
+ dlclose(lib); \
468
+ return false; \
469
+ }}
470
+
471
+ HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)
472
+
473
+ return true;
474
+ }}
475
+
476
+ static inline void gpuAssert(hipError_t code, const char *file, int line)
477
+ {{
478
+ if (code != HIP_SUCCESS)
479
+ {{
480
+ const char* prefix = "Triton Error [HIP]: ";
481
+ const char* str = hipSymbolTable.hipGetErrorString(code);
482
+ char err[1024] = {{0}};
483
+ snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
484
+ PyErr_SetString(PyExc_RuntimeError, err);
485
+ }}
486
+ }}
487
+
488
+ #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
489
+
490
+ static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
491
+ if (gridX * gridY * gridZ == 0)
492
+ return;
493
+ hipDeviceptr_t global_scratch = 0;
494
+ void *params[] = {{ {', '.join(params)} }};
495
+ if(num_ctas > 1) {{
496
+ if (!hipSymbolTable.hipDrvLaunchKernelEx) {{
497
+ PyErr_SetString(PyExc_RuntimeError, "missing hipDrvLaunchKernelEx symbol; please update HIP runtime");
498
+ return;
499
+ }}
500
+
501
+ hipLaunchAttribute attributes[2];
502
+ // Attribute0: Cluster dimensions
503
+ attributes[0].id = 4;
504
+ int *cluster_dims = (int*)attributes[0].val.pad;
505
+ cluster_dims[0] = num_ctas;
506
+ cluster_dims[1] = 1;
507
+ cluster_dims[2] = 1;
508
+ // Attribute1: Cooperative launch
509
+ attributes[1].id = hipLaunchAttributeCooperative;
510
+ attributes[1].val.cooperative = launch_cooperative_grid;
511
+
512
+ HIP_LAUNCH_CONFIG config = {{
513
+ gridX * num_ctas, gridY, gridZ, // Grid size
514
+ {warp_size} * num_warps, 1, 1, // Block size
515
+ shared_memory, stream,
516
+ attributes, 2 // Number of attributes
517
+ }};
518
+ HIP_CHECK(hipSymbolTable.hipDrvLaunchKernelEx(&config, function, params, 0));
519
+ return;
520
+ }}
521
+ else if (launch_cooperative_grid) {{
522
+ HIP_CHECK(hipSymbolTable.hipModuleLaunchCooperativeKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
523
+ return;
524
+ }}
525
+ else {{
526
+ HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
527
+ }}
528
+ }}
529
+
530
+ typedef struct _DevicePtrInfo {{
531
+ hipDeviceptr_t dev_ptr;
532
+ bool valid;
533
+ }} DevicePtrInfo;
534
+
535
+ static PyObject* data_ptr_str = NULL;
536
+ static PyObject* py_tdm_descriptor_type = NULL;
537
+
538
+ static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
539
+ DevicePtrInfo ptr_info;
540
+ hipError_t status = hipSuccess;
541
+ ptr_info.dev_ptr = 0;
542
+ ptr_info.valid = true;
543
+ if (PyLong_Check(obj)) {{
544
+ ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
545
+ return ptr_info;
546
+ }}
547
+ if (obj == Py_None) {{
548
+ // valid nullptr
549
+ return ptr_info;
550
+ }}
551
+ PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
552
+ if (!ret) {{
553
+ PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
554
+ ptr_info.valid = false;
555
+ goto cleanup;
556
+ }}
557
+ if (!PyLong_Check(ret)) {{
558
+ PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
559
+ ptr_info.valid = false;
560
+ goto cleanup;
561
+ }}
562
+ ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
563
+ if (!ptr_info.dev_ptr)
564
+ goto cleanup;
565
+ uint64_t dev_ptr;
566
+ status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
567
+ if (status == hipErrorInvalidValue) {{
568
+ PyErr_Format(PyExc_ValueError,
569
+ "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
570
+ ptr_info.valid = false;
571
+ // Clear and ignore HIP error
572
+ (void)hipSymbolTable.hipGetLastError();
573
+ }}
574
+ ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
575
+ cleanup:
576
+ Py_DECREF(ret);
577
+ return ptr_info;
578
+ }}
579
+
580
+ static inline TDMDescriptor* getTDMDescriptor(PyObject* obj, int idx) {{
581
+ if (Py_TYPE(obj) != (PyTypeObject*)py_tdm_descriptor_type) {{
582
+ PyErr_Format(PyExc_TypeError, "object must be of type PyTDMDescriptor, got %s", Py_TYPE(obj)->tp_name);
583
+ return NULL;
584
+ }}
585
+
586
+ TDMDescriptor* desc = &((PyTDMDescriptorObject*)obj)->desc;
587
+ return desc;
588
+ }}
589
+
590
+ static uint16_t pack_fp16(double f) {{
591
+ uint16_t result;
592
+ // from https://github.com/python/pythoncapi-compat/blob/5e317108f872c904eb726cb8d560dcadbdf88a72/pythoncapi_compat.h#L482-L492
593
+ #if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
594
+ _PyFloat_Pack2(f, (unsigned char*)&result, 1);
595
+ #else
596
+ PyFloat_Pack2(f, (char*)&result, 1);
597
+ #endif
598
+ return result;
599
+ }}
600
+
601
+ static uint16_t pack_bf16(double f) {{
602
+ float f32 = (float)f;
603
+ uint32_t u32 = *(uint32_t*)&f32;
604
+ return (uint16_t)(u32 >> 16);
605
+ }}
606
+
607
+ static uint32_t pack_fp32(double f) {{
608
+ float f32 = (float)f;
609
+ return *(uint32_t*)&f32;
610
+ }}
611
+
612
+ static uint64_t pack_fp64(double f) {{
613
+ return *(uint64_t*)&f;
614
+ }}
615
+
616
+ static PyObject* launch(PyObject* self, PyObject* args) {{
617
+ int gridX, gridY, gridZ;
618
+ uint64_t _stream;
619
+ uint64_t _function;
620
+ int launch_cooperative_grid;
621
+ PyObject *profile_scratch_obj = NULL;
622
+ PyObject *launch_enter_hook = NULL;
623
+ PyObject *launch_exit_hook = NULL;
624
+ PyObject *kernel_metadata = NULL;
625
+ PyObject *launch_metadata = NULL;
626
+ {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
627
+ if(!PyArg_ParseTuple(args, \"{format}\", &launch_cooperative_grid,
628
+ &gridX, &gridY, &gridZ, &_stream, &_function, &profile_scratch_obj,
629
+ &kernel_metadata, &launch_metadata,
630
+ &launch_enter_hook, &launch_exit_hook {args_list})) {{
631
+ return NULL;
632
+ }}
633
+
634
+ // extract kernel metadata
635
+ int num_warps, num_ctas, shared_memory;
636
+ if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
637
+ return NULL;
638
+ }}
639
+ // extract launch metadata
640
+ if (launch_enter_hook != Py_None){{
641
+ PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
642
+ if (!ret)
643
+ return NULL;
644
+ Py_DECREF(ret);
645
+ }}
646
+
647
+ hipDeviceptr_t profile_scratch = 0;
648
+ if (profile_scratch_obj != Py_None) {{
649
+ DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
650
+ if (!profile_scratch_info.valid) {{
651
+ return NULL;
652
+ }}
653
+ profile_scratch = profile_scratch_info.dev_ptr;
654
+ }}
655
+
656
+ // raise exception asap
657
+ {newline.join(tensor_desc_decls)}
658
+ {newline.join(ptr_decls)}
659
+ {newline.join(float_storage_decls)}
660
+ _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
661
+
662
+ if(launch_exit_hook != Py_None){{
663
+ PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
664
+ if (!ret)
665
+ return NULL;
666
+ Py_DECREF(ret);
667
+ }}
668
+
669
+ if(PyErr_Occurred()) {{
670
+ return NULL;
671
+ }}
672
+ Py_RETURN_NONE;
673
+ }}
674
+
675
+ static PyMethodDef ModuleMethods[] = {{
676
+ {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
677
+ {{NULL, NULL, 0, NULL}} // sentinel
678
+ }};
679
+
680
+ static struct PyModuleDef ModuleDef = {{
681
+ PyModuleDef_HEAD_INIT,
682
+ \"__triton_launcher\",
683
+ NULL, //documentation
684
+ -1, //size
685
+ ModuleMethods
686
+ }};
687
+
688
+ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
689
+ if (!initSymbolTable()) {{
690
+ return NULL;
691
+ }}
692
+ PyObject *m = PyModule_Create(&ModuleDef);
693
+ if(m == NULL) {{
694
+ return NULL;
695
+ }}
696
+ data_ptr_str = PyUnicode_InternFromString("data_ptr");
697
+ if(data_ptr_str == NULL) {{
698
+ return NULL;
699
+ }}
700
+ PyObject* driver_mod = PyImport_ImportModule("triton.backends.amd.driver");
701
+ if (driver_mod == NULL) {{
702
+ return NULL;
703
+ }}
704
+ py_tdm_descriptor_type = PyObject_GetAttrString(driver_mod, "PyTDMDescriptor");
705
+ if (py_tdm_descriptor_type == NULL) {{
706
+ return NULL;
707
+ }}
708
+
709
+ PyModule_AddFunctions(m, ModuleMethods);
710
+ return m;
711
+ }}
712
+ """
713
+ return src
714
+
715
+
716
def make_tensordesc_arg(arg, kernel_metadata, tensordesc_metadata):
    """
    Expand one tensor descriptor argument into the flat list of launcher
    arguments.

    Args:
        arg: Host-side tensor descriptor. Its `base`, `shape`, `strides` and
            (when decomposing) `padding` attributes are read.
        kernel_metadata: Packed kernel metadata; element 0 is used as the
            number of warps.
        tensordesc_metadata (Dict or None): Metadata for this descriptor. When
            None the descriptor is decomposed; otherwise a TDM descriptor
            object is created through the active HIP driver.
    Returns:
        list: Either `[base, *shape, *strides, is_nan_padding, *shape, *strides]`
        (decomposed form) or `[desc, *shape, *strides]` (TDM form).
    """
    if tensordesc_metadata is None:
        # The frontend already decomposes host-side descriptors into
        # (descriptor, shape, strides), but those values are not usable while
        # processing the descriptor itself, so the decomposition is repeated
        # here. As a consequence shape and strides are passed twice.
        decomposed = [arg.base]
        decomposed.extend(arg.shape)
        decomposed.extend(arg.strides)
        decomposed.append(arg.padding == "nan")
        decomposed.extend(arg.shape)
        decomposed.extend(arg.strides)
        return decomposed

    dims, dim_strides = arg.shape, arg.strides
    base_ptr = arg.base.data_ptr()

    assert "elem_bits" in tensordesc_metadata and "block_size" in tensordesc_metadata
    elem_bits = tensordesc_metadata["elem_bits"]
    block_size = tensordesc_metadata["block_size"]

    # At most one (interval, amount) padding pair is supported; no pair means
    # no padding.
    padding_pairs = tensordesc_metadata.get("interval_padding_pairs", [])
    if padding_pairs:
        assert len(padding_pairs) == 1 and len(padding_pairs[0]) == 2
        pad_interval, pad_amount = padding_pairs[0]
    else:
        pad_interval, pad_amount = 0, 0
    num_warps = kernel_metadata[0]

    active_driver = triton.runtime.driver.active
    assert isinstance(active_driver, HIPDriver)

    tdm_desc = active_driver.utils.create_tdm_descriptor(elem_bits, block_size, num_warps, pad_interval, pad_amount,
                                                         dims, dim_strides, base_ptr)

    return [tdm_desc, *dims, *dim_strides]
755
+
756
+
757
def wrap_handle_tensordesc(launcher, signature, tensordesc_metadata):
    """
    Return a launcher that expands tensor descriptor arguments before calling
    the original `launcher`.

    When the signature contains no tensor descriptor entry, `launcher` itself
    is returned unchanged.

    Args:
        launcher (callable): The original kernel launcher function.
        signature (Dict[int, str]): The kernel signature mapping argument indices to types.
        tensordesc_metadata (List[Dict] or None): Metadata for each tensor descriptor argument, in the
            order those arguments appear. If empty or None, tensor descriptors are decomposed.
    Returns:
        launcher (callable): The wrapped kernel launcher function.
    """

    def _is_tensordesc(sig):
        return isinstance(sig, str) and sig.startswith("tensordesc")

    # Positions (within the kernel arguments) that hold tensor descriptors.
    tensordesc_indices = {pos for pos, sig in enumerate(signature.values()) if _is_tensordesc(sig)}
    if not tensordesc_indices:
        return launcher

    assert not tensordesc_metadata or len(tensordesc_metadata) == len(tensordesc_indices)
    if not tensordesc_metadata:
        tensordesc_metadata = [None] * len(tensordesc_indices)

    def inner(*args):
        # The leading arguments are launcher bookkeeping, one per character of
        # the base format string; the rest are the kernel arguments.
        split = len(_BASE_ARGS_FORMAT)
        meta_args = args[:split]
        expanded = []
        seen = 0
        for pos, value in enumerate(args[split:]):
            if pos not in tensordesc_indices:
                expanded.append(value)
                continue
            # meta_args[7] is the packed kernel metadata.
            expanded += make_tensordesc_arg(value, meta_args[7], tensordesc_metadata[seen])
            seen += 1
        return launcher(*meta_args, *expanded)

    return inner
798
+
799
+
800
class HIPLauncher(object):
    """Builds the compiled launcher module for one kernel and invokes it."""

    def __init__(self, src, metadata):
        """
        Generate and compile the launcher source for `src`.

        Reads `src.constants` (if present), `src.signature` and
        `src.fn.arg_names`, plus `warp_size`, `launch_cooperative_grid`,
        `profile_scratch_size`, `profile_scratch_align` and the optional
        `tensordesc_meta` attribute from `metadata`.
        """
        raw_constants = getattr(src, "constants", dict())

        def _as_index(key):
            # Constants keyed by argument name are re-keyed by a 1-tuple
            # holding the argument position; other keys pass through.
            if isinstance(key, str):
                return (src.fn.arg_names.index(key), )
            return key

        constants = {_as_index(key): value for key, value in raw_constants.items()}
        signature = dict(src.signature.items())
        tensordesc_meta = getattr(metadata, "tensordesc_meta", None)
        launcher_src = make_launcher(constants, signature, metadata.warp_size, tensordesc_meta)
        mod = compile_module_from_src(src=launcher_src, name="__triton_launcher", include_dirs=include_dirs)
        self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta)
        self.launch_cooperative_grid = metadata.launch_cooperative_grid
        self.profile_scratch_size = metadata.profile_scratch_size
        self.profile_scratch_align = metadata.profile_scratch_align

    def __call__(self, gridX, gridY, gridZ, stream, function, *args):
        """
        Launch `function` on `stream` over the given grid.

        A profile scratch buffer of `grid size * profile_scratch_size` is
        requested from the profile allocator when the per-program size is
        positive; otherwise None is passed to the launcher.
        """
        scratch_size = self.profile_scratch_size
        scratch_align = self.profile_scratch_align
        allocator = _allocation._profile_allocator

        profile_scratch = None
        if scratch_size > 0:
            total_size = gridX * gridY * gridZ * scratch_size
            profile_scratch = allocator.get()(total_size, scratch_align, stream)

        self.launch(self.launch_cooperative_grid, gridX, gridY, gridZ, stream, function, profile_scratch, *args)
829
+
830
+
831
class HIPDriver(GPUDriver):
    """GPU driver implementation for the HIP backend."""

    def __init__(self):
        super().__init__()
        self.utils = HIPUtils()
        self.launcher_cls = HIPLauncher

    def get_device_interface(self):
        """Return the `torch.cuda` module as the device interface."""
        import torch
        return torch.cuda

    @staticmethod
    def is_active():
        """Return whether torch is importable, reports an available device and is a HIP build."""
        try:
            import torch
            active = torch.cuda.is_available() and (torch.version.hip is not None)
            return active
        except ImportError:
            return False

    def map_python_to_cpp_type(self, ty: str) -> str:
        """Translate a signature type string via `ty_to_cpp`."""
        return ty_to_cpp(ty)

    def get_current_target(self):
        """Build the `GPUTarget` for the current device, honoring an arch override knob."""
        props = self.utils.get_device_properties(self.get_current_device())
        arch = knobs.runtime.override_arch or props['arch']
        warp_size = props['warpSize']
        # Only the part of the arch string before the first ':' is kept.
        base_arch = arch.partition(':')[0]
        return GPUTarget("hip", base_arch, warp_size)

    def get_active_torch_device(self):
        """Return the torch device object for the current device."""
        import torch
        # when using hip devices, the device string in pytorch is "cuda"
        current = self.get_current_device()
        return torch.device("cuda", current)

    def get_benchmarker(self):
        """Return `triton.testing.do_bench`."""
        from triton.testing import do_bench
        return do_bench

    def get_empty_cache_for_benchmark(self):
        """Allocate an uninitialized int tensor used as a cache-flush buffer."""
        import torch

        # It's the same as the Nvidia backend.
        cache_size = 256 * 1024 * 1024
        num_elements = int(cache_size // 4)
        return torch.empty(num_elements, dtype=torch.int, device='cuda')

    def clear_cache(self, cache):
        """Zero the given cache buffer in place."""
        cache.zero_()
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/__init__.py ADDED
File without changes
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py ADDED
@@ -0,0 +1,553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from triton.backends.compiler import BaseBackend, GPUTarget, Language
2
+ from triton._C.libtriton import ir, passes, llvm, nvidia
3
+ from triton import knobs
4
+ from triton.runtime.errors import PTXASError
5
+
6
+ from dataclasses import dataclass
7
+ import functools
8
+ from typing import Any, Dict, Tuple, Optional
9
+ from types import ModuleType
10
+ import hashlib
11
+ import re
12
+ import tempfile
13
+ import signal
14
+ import os
15
+ import subprocess
16
+ from pathlib import Path
17
+
18
+
19
def min_dot_size(target: GPUTarget):
    """Return a callable that reports the minimum (m, n, k) dot shape for two operand types.

    `target` is accepted but not consulted by the returned callable.
    """

    def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]:  # [m, n, k]
        lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
        rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
        assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
        # For small M/N inputs we can still use tensorcores with padding, so
        # only the K dimension has a real minimum.
        min_k = 32 if lhs_bitwidth == 8 else 16
        return (1, 1, min_k)

    return check_dot_compatibility
32
+
33
+
34
def get_ptxas(arch: int) -> knobs.NvidiaTool:
    """Select the ptxas tool knob for `arch`: the Blackwell one for arch >= 100, else the default."""
    if arch >= 100:
        return knobs.nvidia.ptxas_blackwell
    return knobs.nvidia.ptxas
36
+
37
+
38
@functools.lru_cache()
def get_ptxas_version(arch: int = 80):
    """Return the text printed by `ptxas --version` for `arch` (cached per arch).

    When the mock PTX version knob is set, that value is returned instead and
    ptxas is not invoked.
    """
    mock_ver = knobs.nvidia.mock_ptx_version
    if mock_ver is None:
        raw_output = subprocess.check_output([get_ptxas(arch).path, "--version"])
        return raw_output.decode("utf-8")
    # This is not really a version of ptxas, but it is good enough for testing
    return mock_ver
45
+
46
+
47
@functools.lru_cache()
def ptx_get_version(cuda_version) -> int:
    '''
    Get the highest PTX version supported by the current CUDA driver.

    `cuda_version` is a "major.minor" string; the result is the PTX version
    encoded as an integer (e.g. 84 for PTX 8.4). Raises RuntimeError for
    CUDA major versions below 10.
    '''
    assert isinstance(cuda_version, str)
    major, minor = map(int, cuda_version.split('.'))

    if major >= 13:
        # CUDA 13.0 maps to PTX 9.0; each later major adds 10.
        return 90 + (major - 13) * 10 + minor
    if major == 12:
        # From CUDA 12.6 on, the PTX minor trails the CUDA minor by one.
        lag = 0 if minor < 6 else 1
        return 80 + minor - lag
    if major == 11:
        return 70 + minor
    if major == 10:
        return 63 + minor

    raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version)
69
+
70
+
71
def get_ptx_version_from_options(options, arch: int):
    """Return `options.ptx_version`, or derive it from the ptxas tool version for `arch` when unset."""
    explicit_version = options.ptx_version
    if explicit_version is not None:
        return explicit_version
    cuda_version = get_ptxas(arch).version
    return ptx_get_version(cuda_version)
77
+
78
+
79
@functools.lru_cache()
def get_features(options, arch: int):
    """Return the LLVM feature string (`+ptxNN`) for the PTX version implied by `options` and `arch`."""
    requested_version = get_ptx_version_from_options(options, arch)

    # PTX 8.6 is the max version supported by llvm c1188642.
    #
    # To check if a newer PTX version is supported, increase this value
    # and run a test. If it's not supported, LLVM will print a warning
    # like "+ptx8.4 is not a recognized feature for this target".
    llvm_ptx_version = requested_version if requested_version < 86 else 86
    return f'+ptx{llvm_ptx_version}'
91
+
92
+
93
@functools.lru_cache(None)
def file_hash(path):
    """Return the SHA-256 hex digest of the file at `path`; results are cached per path."""
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        digest.update(stream.read())
    return digest.hexdigest()
97
+
98
+
99
def sm_arch_from_capability(capability: int):
    """Format a compute capability as an `sm_NN` arch name, with an `a` suffix from 90 upward."""
    # TODO: Handle non-"a" sms
    arch_name = f"sm_{capability}"
    if capability >= 90:
        arch_name += "a"
    return arch_name
103
+
104
+
105
@dataclass(frozen=True)
class CUDAOptions:
    """Immutable set of compilation options for the CUDA backend."""
    num_warps: int = 4
    num_ctas: int = 1
    num_stages: int = 3
    warp_size: int = 32
    # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
    # maximum number of 32-bit registers used by one thread.
    maxnreg: Optional[int] = None
    ptx_version: int = None
    ptx_options: Optional[str] = knobs.nvidia.ptxas_options
    ir_override: Optional[str] = None  # filename of a user-defined IR (*.{ttir|ttgir|llir|ptx})
    enable_fp_fusion: bool = True
    enable_reflect_ftz: bool = True  # ftz in libdevice
    launch_cooperative_grid: bool = False
    launch_pdl: bool = False
    supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15")
    deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
    default_dot_input_precision: str = "tf32"
    allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee", 'bf16x3', 'bf16x6')
    max_num_imprecise_acc_default: bool = None
    extern_libs: dict = None
    debug: bool = False
    backend_name: str = 'cuda'
    sanitize_overflow: bool = True
    arch: str = None
    instrumentation_mode: str = ""

    def __post_init__(self):
        """Fill in the default libdevice path, freeze `extern_libs` as a tuple of items, and validate `num_warps`."""
        libs = dict(self.extern_libs) if self.extern_libs is not None else {}
        if not libs.get('libdevice', None):
            bundled_libdevice = Path(__file__).parent / 'lib' / 'libdevice.10.bc'
            libs['libdevice'] = knobs.nvidia.libdevice_path or str(bundled_libdevice)

        # The dataclass is frozen, so the attribute has to be set through
        # object.__setattr__.
        object.__setattr__(self, 'extern_libs', tuple(libs.items()))
        is_power_of_two = (self.num_warps & (self.num_warps - 1)) == 0
        assert self.num_warps > 0 and is_power_of_two, \
               "num_warps must be a power of 2"

    def hash(self):
        """Return a SHA-256 hex digest over all option values.

        External libraries contribute the hash of their file contents rather
        than their paths.
        """
        lib_hashes = tuple((name, file_hash(path)) for name, path in sorted(self.__dict__["extern_libs"]))
        hash_dict = {**self.__dict__, "extern_libs": lib_hashes}
        parts = [f"{name}-{val}" for name, val in sorted(hash_dict.items())]
        key = "_".join(parts)
        return hashlib.sha256(key.encode("utf-8")).hexdigest()
148
+
149
+
150
+ class CUDABackend(BaseBackend):
151
+ instrumentation = None
152
+
153
+ @staticmethod
154
+ def supports_target(target: GPUTarget):
155
+ return target.backend == 'cuda'
156
+
157
+ def _parse_arch(self, arch):
158
+ pattern = r"^sm(\d+)$"
159
+ match = re.fullmatch(pattern, arch)
160
+ if not match:
161
+ raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
162
+ return int(match.group(1))
163
+
164
def get_target_name(self, options) -> str:
    """Return the target name `cuda:<capability>` parsed from `options.arch`."""
    return "cuda:{}".format(self._parse_arch(options.arch))
167
+
168
def __init__(self, target: GPUTarget) -> None:
    """Initialize the base backend with `target` and set the binary extension to "cubin"."""
    super().__init__(target)
    # File extension used for the final compiled artifact.
    self.binary_ext = "cubin"
171
+
172
def parse_options(self, opts) -> Any:
    """Build a `CUDAOptions` from the raw option mapping `opts`.

    Note that `opts` is modified in place when ConSan instrumentation is
    requested (its "debug" entry is set to True).

    Raises:
        ValueError: if more than one CTA is requested for a capability below 90.
    """
    # Enable debug mode for ConSan, so device-side assertions are not optimized out
    if "instrumentation_mode" in opts and opts["instrumentation_mode"] == "consan":
        opts["debug"] = True

    args = {'arch': knobs.runtime.override_arch or f"sm{self.target.arch}"}
    # Copy over every known option that was supplied with a non-None value.
    for field_name in CUDAOptions.__dataclass_fields__.keys():
        if field_name in opts and opts[field_name] is not None:
            args[field_name] = opts[field_name]
    capability = int(self._parse_arch(args["arch"]))

    if args.get("num_ctas", 1) > 1 and capability < 90:
        raise ValueError((f"num_ctas > 1 requires NVIDIA SM90+ (Hopper). "
                          f"Current target is sm_{capability}. This configuration will fail. "
                          f"Please set num_ctas=1 or target an SM90+ GPU."))

    if "supported_fp8_dtypes" not in args:
        fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
        if capability >= 89:
            fp8_dtypes.add("fp8e4nv")
        args["supported_fp8_dtypes"] = tuple(sorted(fp8_dtypes))

    if "deprecated_fp8_dot_operand_dtypes" not in args and capability >= 90:
        args["deprecated_fp8_dot_operand_dtypes"] = ("fp8e4b15", )

    if "enable_fp_fusion" not in args:
        args["enable_fp_fusion"] = knobs.language.default_fp_fusion

    # This value is always overwritten, regardless of what `opts` supplied.
    if capability == 90:
        args["max_num_imprecise_acc_default"] = 2**30
    else:
        args["max_num_imprecise_acc_default"] = 0

    return CUDAOptions(**args)
202
+
203
def pack_metadata(self, metadata):
    """Return `(num_warps, num_ctas, shared)` read from `metadata`."""
    num_warps = metadata.num_warps
    num_ctas = metadata.num_ctas
    shared = metadata.shared
    return (num_warps, num_ctas, shared)
209
+
210
def get_codegen_implementation(self, options):
    """Return the code-generation hooks (`convert_custom_types`, `min_dot_size`) for `options.arch`."""
    import triton.language.extra.cuda as cuda
    capability = int(self._parse_arch(options.arch))
    # The float8 conversion helper is chosen by capability: sm80 variant from 80 upward.
    if capability >= 80:
        convert_fn = cuda.convert_custom_float8_sm80
    else:
        convert_fn = cuda.convert_custom_float8_sm70
    return {
        "convert_custom_types": convert_fn,
        "min_dot_size": min_dot_size(self.target),
    }
219
+
220
def get_module_map(self) -> Dict[str, ModuleType]:
    """Map the generic libdevice module name to the CUDA libdevice module."""
    from triton.language.extra.cuda import libdevice
    module_map = {}
    module_map["triton.language.extra.libdevice"] = libdevice
    return module_map
223
+
224
def load_dialects(self, ctx):
    """Load the NVIDIA dialects into `ctx`, then those of the instrumentation hook if one is set."""
    nvidia.load_dialects(ctx)
    instrumentation = CUDABackend.instrumentation
    if instrumentation:
        instrumentation.load_dialects(ctx)
228
+
229
@staticmethod
def make_ttir(mod, metadata, opt, capability):
    """Run the TTIR pass pipeline on `mod` and return the same `mod` object.

    `metadata` and `opt` are accepted but not read here.
    """
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()

    pipeline = [passes.common.add_inliner, passes.ttir.add_rewrite_tensor_pointer]
    # Below capability 90, tensor descriptors are additionally rewritten to pointers.
    if capability // 10 < 9:
        pipeline.append(passes.ttir.add_rewrite_tensor_descriptor_to_pointer)
    pipeline += [
        passes.common.add_canonicalizer,
        passes.ttir.add_combine,
        passes.ttir.add_reorder_broadcast,
        passes.common.add_cse,
        passes.common.add_symbol_dce,
        passes.ttir.add_loop_unroll,
    ]
    # Passes are registered in list order.
    for add_pass in pipeline:
        add_pass(pm)

    pm.run(mod, 'make_ttir')
    return mod
245
+
246
@staticmethod
def make_ttgir(mod, metadata, opt, capability):
    """Convert ``mod`` to TritonGPU IR, run the optimisation passes and return ``mod``.

    Reads ``opt.maxnreg``, ``opt.num_warps``, ``opt.num_ctas`` and
    ``opt.num_stages``. After the pipeline runs, stores
    ``mod.get_tensordesc_metadata()`` under ``metadata["tensordesc_meta"]``.
    The loop-scheduling section differs by ``capability // 10``: one branch
    for 8 and 9, one for >= 10, and a minimal fallback otherwise.
    """
    # Set maxnreg on all kernels, if it was provided.
    if opt.maxnreg is not None:
        mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))

    pm = ir.pass_manager(mod.context)
    # The value returned by enable_debug() is forwarded to the passes that take a dump flag.
    dump_enabled = pm.enable_debug()
    emuTF32 = (capability // 10 >= 8)
    # NOTE(review): the literal 32 is presumably the threads-per-warp count; confirm
    # against add_convert_to_ttgpuir.
    passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
    # optimize TTGIR
    passes.ttgpuir.add_coalesce(pm)
    passes.ttgpuir.add_f32_dot_tc(pm, emuTF32)
    # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
    nvidia.passes.ttnvgpuir.add_plan_cta(pm)
    passes.ttgpuir.add_remove_layout_conversions(pm)
    passes.ttgpuir.add_optimize_thread_locality(pm)
    passes.ttgpuir.add_accelerate_matmul(pm)
    passes.ttgpuir.add_remove_layout_conversions(pm)
    passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
    nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
    passes.ttir.add_loop_aware_cse(pm)
    if capability // 10 in [8, 9]:
        passes.ttgpuir.add_fuse_nested_loops(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_triton_licm(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
        nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
        passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
        passes.ttgpuir.add_schedule_loops(pm)
        passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
    elif capability // 10 >= 10:
        passes.ttgpuir.add_fuse_nested_loops(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_triton_licm(pm)
        passes.ttgpuir.add_optimize_accumulator_init(pm)
        passes.ttgpuir.add_hoist_tmem_alloc(pm, False)
        nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
        passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
        passes.ttgpuir.add_schedule_loops(pm)
        passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
        passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
        passes.ttgpuir.add_optimize_partition_warps(pm)
        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
        # hoist again and allow hoisting out of if statements
        passes.ttgpuir.add_hoist_tmem_alloc(pm, True)
        nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
    else:
        passes.ttir.add_triton_licm(pm)
    passes.common.add_canonicalizer(pm)
    passes.ttir.add_loop_aware_cse(pm)
    passes.ttgpuir.add_prefetch(pm)
    passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
    passes.ttgpuir.add_coalesce_async_copy(pm)
    nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
    # TMA lowering is only added for major version 9 and above.
    if capability // 10 >= 9:
        nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
    passes.ttgpuir.add_remove_layout_conversions(pm)
    nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
    passes.ttgpuir.add_reduce_data_duplication(pm)
    passes.ttgpuir.add_reorder_instructions(pm)
    passes.ttir.add_loop_aware_cse(pm)
    passes.common.add_symbol_dce(pm)
    nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
    nvidia.passes.ttnvgpuir.add_lower_mma(pm)
    passes.common.add_sccp(pm)
    passes.common.add_cse(pm)
    passes.common.add_canonicalizer(pm)

    pm.run(mod, 'make_ttgir')
    metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
    return mod
319
+
320
def gluon_to_ttgir(self, src, metadata, options, capability):
    """Run the Gluon pass pipeline on ``src`` and return the same module.

    ``options`` and ``capability`` are accepted but not read in this stage.
    After the pipeline runs, ``mod.get_tensordesc_metadata()`` is stored under
    ``metadata["tensordesc_meta"]``.
    """
    mod = src
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()

    passes.gluon.add_inliner(pm)
    passes.gluon.add_infer_coalesced_encodings(pm)
    passes.gluon.add_resolve_auto_encodings(pm)
    nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
    passes.gluon.add_canonicalizer(pm)
    passes.common.add_sccp(pm)
    passes.ttir.add_loop_aware_cse(pm)
    passes.gluon.add_canonicalizer(pm)
    passes.ttgpuir.add_combine_tensor_select_and_if(pm)

    pm.run(mod, 'gluon_to_ttgir')
    metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
    return mod
338
+
339
def make_llir(self, src, metadata, options, capability):
    """Lower a TritonGPU module to LLVM IR and return it as a string.

    Steps: run the TritonGPU -> LLVM-dialect pass pipeline on ``src``,
    translate the result to an LLVM module, attach the NVPTX data layout,
    optionally link external libraries, and optimise at O3. Several module
    attributes are copied into ``metadata`` along the way.

    Raises:
        RuntimeError: if ``knobs.compilation.enable_asan`` is set.
    """
    ptx_version = get_ptx_version_from_options(options, self.target.arch)

    mod = src
    # TritonGPU -> LLVM-IR (MLIR)
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()

    passes.ttgpuir.add_combine_tensor_select_and_if(pm)
    passes.ttgpuir.add_allocate_warp_groups(pm)
    passes.convert.add_scf_to_cf(pm)
    passes.gluon.add_inliner(pm)
    nvidia.passes.ttgpuir.add_allocate_shared_memory_nv(pm, capability, ptx_version)
    nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
    nvidia.passes.ttnvgpuir.add_check_matmul_two_cta(pm)
    if knobs.compilation.instrumentation_mode == "consan":
        # Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
        passes.ttgpuir.add_concurrency_sanitizer(pm)
    passes.ttgpuir.add_allocate_global_scratch_memory(pm)
    nvidia.passes.ttnvgpuir.add_proxy_fence_insertion(pm, capability)
    # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
    if CUDABackend.instrumentation:
        CUDABackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
    nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
    passes.common.add_canonicalizer(pm)
    passes.common.add_cse(pm)
    nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
    nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
    passes.common.add_canonicalizer(pm)
    passes.common.add_cse(pm)
    passes.common.add_symbol_dce(pm)
    passes.convert.add_nvvm_to_llvm(pm)

    # add_di_scope joins the main pipeline only when local-variable extraction is
    # off; otherwise it runs in its own pass manager below.
    if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables:
        passes.llvmir.add_di_scope(pm)

    if CUDABackend.instrumentation:
        CUDABackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)

    pm.run(mod, 'make_llir')

    if knobs.compilation.dump_ir_extract_di_local_variables:
        # comments below on why separate it
        if not knobs.compilation.disable_line_info:
            pm = ir.pass_manager(mod.context)
            pm.enable_debug()
            passes.llvmir.add_di_scope(pm)
            pm.run(mod, 'make_llir.disable_line_info')

        # Insert dbg intrinsics carrying several DI attributes, including the
        # source variable name and type info. Note: for a reason not yet
        # understood, this pass and add_di_scope have to be run separately;
        # putting them into the previous pipeline triggers a segmentation
        # fault without any error message. It could be due to a bug in MLIR
        # or pybind11.
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.llvmir.add_di_local_variable(pm)
        pm.run(mod, 'make_llir.dump_ir_extract_di_local_variables')

    # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
    llvm.init_targets()
    context = llvm.context()
    if knobs.compilation.enable_asan:
        raise RuntimeError(
            "Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
    llvm_mod = llvm.to_module(mod, context)
    proc = sm_arch_from_capability(capability)
    features = get_features(options, self.target.arch)
    triple = 'nvptx64-nvidia-cuda'
    nvidia.set_short_ptr()
    llvm.attach_datalayout(llvm_mod, triple, proc, features)
    if options.enable_reflect_ftz:
        nvidia.set_nvvm_reflect_ftz(llvm_mod)

    # Link external libraries only when some were given and the module references them.
    if options.extern_libs and nvidia.has_extern_deps(llvm_mod):
        paths = [path for (name, path) in options.extern_libs]
        llvm.link_extern_libs(llvm_mod, paths)

    llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)

    # Get some metadata
    # warp-specialization mutates num_warps
    total_num_warps = src.get_int_attr("ttg.total-num-warps")
    if total_num_warps is not None:
        metadata["num_warps"] = total_num_warps
    metadata["shared"] = src.get_int_attr("ttg.shared")
    metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
    metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
    metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
    # A falsy attribute value falls back to size 0 / alignment 1.
    metadata["profile_scratch_size"] = src.get_int_attr("ttg.profile_scratch_memory_size") or 0
    metadata["profile_scratch_align"] = src.get_int_attr("ttg.profile_scratch_memory_alignment") or 1
    ret = str(llvm_mod)
    # Drop the module before the context it was created in.
    del llvm_mod
    del context
    return ret
434
+
435
def make_ptx(self, src, metadata, opt, capability):
    """Translate LLVM IR text to PTX assembly and return the post-processed text.

    Records the single kernel entry name in ``metadata["name"]``, rewrites the
    ``.version`` and ``.target`` directives, strips the ``debug`` flag unless
    local-variable debug info was requested, and prints the PTX when
    ``knobs.nvidia.dump_nvptx`` is set.
    """
    version_number = get_ptx_version_from_options(opt, self.target.arch)

    target_triple = 'nvptx64-nvidia-cuda'
    processor = sm_arch_from_capability(capability)
    feature_string = get_features(opt, self.target.arch)
    llc_flags = ["nvptx-mad-wide-opt"]
    ptx = llvm.translate_to_asm(src, target_triple, processor, feature_string, llc_flags, opt.enable_fp_fusion,
                                False)

    # Find kernel names (there should only be one)
    entry_names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ptx)
    assert len(entry_names) == 1
    metadata["name"] = entry_names[0]

    # post-process
    major, minor = version_number // 10, version_number % 10
    version_text = f'{major}.{minor}'
    ptx = re.sub(r'\.version \d+\.\d+', f'.version {version_text}', ptx, flags=re.MULTILINE)
    ptx = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ptx, flags=re.MULTILINE)

    keep_debug_flag = knobs.compilation.dump_ir_extract_di_local_variables
    if not keep_debug_flag:
        # Remove the debug flag that prevents ptxas from optimizing the code.
        # Note: once the flag is removed, source variable names and type info are
        # lost when the PTX is compiled into a cubin, so cuda-gdb may not show them.
        ptx = re.sub(r",\s*debug|debug,\s*", "", ptx)

    if knobs.nvidia.dump_nvptx:
        print("// -----// NVPTX Dump //----- //")
        print(ptx)
    return ptx
460
+
461
def make_cubin(self, src, metadata, opt, capability):
    """Assemble PTX text into a cubin by invoking ``ptxas`` and return its bytes.

    The PTX is written to a temporary ``.ptx`` file and ptxas stderr is
    captured in a temporary ``.log`` file. On success both temporary files and
    the produced object file are removed. On failure the log file is removed,
    the ``.ptx`` file is left in place (the printed repro command refers to
    it), and a diagnostic is printed.

    Args:
        src: PTX source text.
        metadata: accepted but not read in this stage.
        opt: options object; ``enable_fp_fusion`` and ``ptx_options`` are read.
        capability: value handed to ``sm_arch_from_capability`` to form
            ``--gpu-name``.

    Raises:
        PTXASError: if ptxas exits with a non-zero status.
    """
    ptxas = get_ptxas(self.target.arch).path
    with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
        tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
        fsrc.write(src)
        fsrc.flush()
        fbin = fsrc.name + '.o'

        debug_info = []
        if knobs.compilation.disable_line_info:
            # This option is ignored if used without -lineinfo
            debug_info += ["-lineinfo", "-suppress-debug-info"]
        elif knobs.nvidia.disable_ptxas_opt:
            # Synthesize complete debug info
            debug_info += ["-g"]
        else:
            # Only emit line info
            debug_info += ["-lineinfo"]

        fmad = [] if opt.enable_fp_fusion else ["--fmad=false"]
        arch = sm_arch_from_capability(capability)

        # Disable ptxas optimizations if requested
        disable_opt = ['--opt-level', '0'] if knobs.nvidia.disable_ptxas_opt else []

        # Accept more ptxas options if provided. Split on runs of whitespace:
        # split(" ") would turn doubled, leading or trailing spaces into
        # empty-string arguments, which would then be passed on to ptxas.
        ptx_extra_options = opt.ptx_options.split() if opt.ptx_options else []

        ptxas_cmd = [
            ptxas, *debug_info, *fmad, '-v', *disable_opt, *ptx_extra_options, f'--gpu-name={arch}', fsrc.name,
            '-o', fbin
        ]
        try:
            subprocess.run(ptxas_cmd, check=True, close_fds=False, stderr=flog)
            if knobs.nvidia.dump_ptxas_log:
                with open(flog.name) as log_file:
                    print(log_file.read())

            if os.path.exists(fsrc.name):
                os.remove(fsrc.name)
            if os.path.exists(flog.name):
                os.remove(flog.name)
        except subprocess.CalledProcessError as e:
            with open(flog.name) as log_file:
                log = log_file.read()
            if os.path.exists(flog.name):
                os.remove(flog.name)

            if e.returncode == 255:
                error = 'Internal Triton PTX codegen error'
            elif e.returncode == 128 + signal.SIGSEGV:
                error = '`ptxas` raised SIGSEGV'
            else:
                error = f'`ptxas` failed with error code {e.returncode}'

            error = (f"{error}\n"
                     f"`ptxas` stderr:\n{log}\n"
                     f'Repro command: {" ".join(ptxas_cmd)}\n')

            print(f"""

================================================================
{error}

{src}
================================================================
please share the reproducer above with Triton project.
""")
            raise PTXASError(error)

        with open(fbin, 'rb') as f:
            cubin = f.read()
        if os.path.exists(fbin):
            os.remove(fbin)
    return cubin
536
+
537
def add_stages(self, stages, options, language):
    """Register this backend's compilation stages in ``stages``.

    Each entry is a callable taking ``(src, metadata)``. Triton sources get
    ``ttir`` and ``ttgir`` stages; Gluon sources get a single ``ttgir`` stage.
    ``llir``, ``ptx`` and ``cubin`` are always added. If
    ``knobs.runtime.add_stages_inspection_hook`` is set, it is called last
    with the populated mapping.
    """
    capability = self._parse_arch(options.arch)
    if language == Language.TRITON:
        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability)
        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
    elif language == Language.GLUON:
        stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options, capability)
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
    # The ptx and cubin stages receive self.target.arch (looked up when the stage
    # runs) rather than the capability parsed from options.arch.
    stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch)
    stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch)
    if knobs.runtime.add_stages_inspection_hook is not None:
        knobs.runtime.add_stages_inspection_hook(self, stages, options, language, capability)
549
+
550
@functools.lru_cache()
def hash(self):
    """Return ``"<ptxas version>-<target arch>"`` for this backend; the result is memoised."""
    arch = self.target.arch
    ptxas_version = get_ptxas_version(arch)
    return f'{ptxas_version}-{arch}'
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/driver.c ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "cuda.h"
2
+ #include <dlfcn.h>
3
+ #include <stdbool.h>
4
+ #include <stdio.h>
5
+ #include <stdlib.h>
6
+ #define PY_SSIZE_T_CLEAN
7
+ #include <Python.h>
8
+
9
+ typedef struct {
10
+ PyObject_HEAD;
11
+ _Alignas(128) CUtensorMap tensorMap;
12
+ } PyCUtensorMapObject;
13
+
14
+ // Raises a Python exception and returns false if code is not CUDA_SUCCESS.
15
+ static bool gpuAssert(CUresult code, const char *file, int line) {
16
+ if (code == CUDA_SUCCESS)
17
+ return true;
18
+
19
+ const char *prefix = "Triton Error [CUDA]: ";
20
+ const char *str;
21
+ cuGetErrorString(code, &str);
22
+ char err[1024] = {0};
23
+ strcat(err, prefix);
24
+ strcat(err, str);
25
+ PyGILState_STATE gil_state;
26
+ gil_state = PyGILState_Ensure();
27
+ PyErr_SetString(PyExc_RuntimeError, err);
28
+ PyGILState_Release(gil_state);
29
+ return false;
30
+ }
31
+
32
// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block.
// On failure gpuAssert sets the Python exception and control jumps to a
// `cleanup:` label, which the enclosing function must define.
#define CUDA_CHECK_AND_RETURN_NULL(ans)                                        \
  do {                                                                         \
    if (!gpuAssert((ans), __FILE__, __LINE__))                                 \
      goto cleanup;                                                            \
  } while (0)
38
+
39
// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
// On failure gpuAssert sets the Python exception, the thread state saved in
// `_save` (the variable introduced by Py_BEGIN_ALLOW_THREADS) is restored, and
// the enclosing function returns NULL directly.
#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans)                          \
  do {                                                                         \
    if (!gpuAssert((ans), __FILE__, __LINE__)) {                               \
      PyEval_RestoreThread(_save);                                             \
      return NULL;                                                             \
    }                                                                          \
  } while (0)
47
+
48
// Used to check if functions exist in old CUDA driver versions.
// If funcPointer is NULL it is assigned the result of initializerFunction();
// if it is still NULL afterwards, control jumps to a `cleanup:` label, which
// the enclosing function must define.
#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction)  \
  do {                                                                         \
    if ((funcPointer) == NULL) {                                               \
      (funcPointer) = (initializerFunction)();                                 \
      if ((funcPointer) == NULL) {                                             \
        goto cleanup;                                                          \
      }                                                                        \
    }                                                                          \
  } while (0)
58
+
59
+ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
60
+ int device_id;
61
+ if (!PyArg_ParseTuple(args, "i", &device_id))
62
+ return NULL;
63
+ // Get device handle
64
+ CUdevice device;
65
+ cuDeviceGet(&device, device_id);
66
+
67
+ // create a struct to hold device properties
68
+ int max_shared_mem;
69
+ int max_num_regs;
70
+ int multiprocessor_count;
71
+ int warp_size;
72
+ int sm_clock_rate;
73
+ int mem_clock_rate;
74
+ int mem_bus_width;
75
+ CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
76
+ &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
77
+ device));
78
+ CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
79
+ &max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device));
80
+ CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
81
+ &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
82
+ CUDA_CHECK_AND_RETURN_NULL(
83
+ cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device));
84
+ CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
85
+ &sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
86
+ CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
87
+ &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
88
+ CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
89
+ &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
90
+
91
+ return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
92
+ max_shared_mem, "max_num_regs", max_num_regs,
93
+ "multiprocessor_count", multiprocessor_count, "warpSize",
94
+ warp_size, "sm_clock_rate", sm_clock_rate,
95
+ "mem_clock_rate", mem_clock_rate, "mem_bus_width",
96
+ mem_bus_width);
97
+
98
+ cleanup:
99
+ return NULL;
100
+ }
101
+
102
// Python entry point: load_binary(name, data, shared, device).
// Loads the module image in `data`, looks up the function `name` in it and
// returns the tuple (module handle, function handle, n_regs, n_spills,
// n_max_threads). The driver calls run with the GIL released; any failure
// returns NULL with a RuntimeError set.
static PyObject *loadBinary(PyObject *self, PyObject *args) {
  const char *name;
  const char *data;
  Py_ssize_t data_size;
  int shared;
  int device;
  if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
                        &device)) {
    return NULL;
  }
  CUfunction fun;
  CUmodule mod;
  int32_t n_regs = 0;
  int32_t n_spills = 0;
  int32_t n_max_threads = 0;
  // create driver handles
  CUcontext pctx = 0;

  Py_BEGIN_ALLOW_THREADS;
  // If the thread has no current context, retain the device's primary context
  // and make it current.
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx));
  if (!pctx) {
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
        cuDevicePrimaryCtxRetain(&pctx, device));
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx));
  }

  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data));
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
      cuModuleGetFunction(&fun, mod, name));
  // get allocated registers and spilled registers from the function
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
      cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
      cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
  // The attribute is in bytes; dividing by 4 reports it in 4-byte units.
  n_spills /= 4;
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
      &n_max_threads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));
  // set dynamic shared memory if necessary
  int shared_optin;
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
      &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
      device));
  // Only when both the request and the opt-in maximum exceed 49152 bytes is the
  // function's dynamic shared-memory limit raised.
  if (shared > 49152 && shared_optin > 49152) {
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
        cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
    // NOTE(review): shared_total is queried but not used afterwards.
    int shared_total, shared_static;
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
        &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
        device));
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
        &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
    // The new limit is the opt-in maximum minus the function's static usage.
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
        cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                           shared_optin - shared_static));
  }
  Py_END_ALLOW_THREADS;

  if (PyErr_Occurred()) {
    return NULL;
  }
  return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
                       n_spills, n_max_threads);
}
165
+
166
// Function-pointer type for cuOccupancyMaxActiveClusters; the pointer is
// looked up at run time by the handle getter defined below.
typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
    int *numClusters, CUfunction func, const CUlaunchConfig *config);

// Function-pointer type for cuTensorMapEncodeTiled; the pointer is looked up
// at run time by the handle getter defined below.
typedef CUresult (*cuTensorMapEncodeTiled_t)(
    CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType,
    cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim,
    const cuuint64_t *globalStrides, const cuuint32_t *boxDim,
    const cuuint32_t *elementStrides, CUtensorMapInterleave interleave,
    CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
    CUtensorMapFloatOOBfill oobFill);
176
+
177
// Defines a static function `name()` that opens libcuda.so.1 with dlopen and
// returns the address of `symbolName` as a `symbolName##_t` pointer. On
// failure it sets a Python RuntimeError and returns NULL. On success the
// library handle is not closed, so the returned pointer stays usable.
#define defineGetFunctionHandle(name, symbolName)                              \
  static symbolName##_t name() {                                               \
    /* Open the shared library */                                              \
    void *libHandle = dlopen("libcuda.so.1", RTLD_LAZY);                       \
    if (!libHandle) {                                                          \
      PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1");      \
      return NULL;                                                             \
    }                                                                          \
    /* Clear any existing error */                                             \
    dlerror();                                                                 \
    symbolName##_t funcHandle = (symbolName##_t)dlsym(libHandle, #symbolName); \
    /* Check for errors */                                                     \
    const char *err = dlerror();                                               \
    if (err) {                                                                 \
      PyErr_SetString(PyExc_RuntimeError,                                      \
                      "Failed to retrieve " #symbolName " from libcuda.so.1"); \
      dlclose(libHandle);                                                      \
      return NULL;                                                             \
    }                                                                          \
    return funcHandle;                                                         \
  }

// Getter for the cuOccupancyMaxActiveClusters entry point.
defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
                        cuOccupancyMaxActiveClusters);

// Getter for the cuTensorMapEncodeTiled entry point.
defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
                        cuTensorMapEncodeTiled);
204
+
205
// Python entry point: cuOccupancyMaxActiveClusters(func, shared, clusterDim).
// Sets the function's maximum dynamic shared-memory size to `shared`, builds a
// launch configuration whose cluster dimension is (clusterDim, 1, 1), and
// returns the cluster count reported by the driver as a Python int. Returns
// NULL with an exception set on failure.
static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
  int clusterDim = -1, maxActiveClusters = -1;
  int shared = 0;
  CUfunction func;

  // "K" reads the function handle as an unsigned 64-bit integer.
  if (!PyArg_ParseTuple(args, "Kii", &func, &shared, &clusterDim)) {
    return NULL;
  }

  // Let each SM have one block
  int maxActiveBlocks = 1;
  Py_BEGIN_ALLOW_THREADS;
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute(
      func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared));
  Py_END_ALLOW_THREADS;

  CUlaunchAttribute launchAttr[1];
  launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
  launchAttr[0].value.clusterDim.x = clusterDim;
  launchAttr[0].value.clusterDim.y = 1;
  launchAttr[0].value.clusterDim.z = 1;
  CUlaunchConfig config;
  config.gridDimX = clusterDim * maxActiveBlocks;
  config.gridDimY = 1;
  config.gridDimZ = 1;
  config.blockDimX = 128;
  config.blockDimY = 1;
  config.blockDimZ = 1;
  config.sharedMemBytes = shared;
  config.hStream = 0;
  config.numAttrs = 1;
  config.attrs = launchAttr;

  // Resolved on first use and cached in a function-local static; jumps to
  // cleanup (exception already set) if the symbol cannot be found.
  static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL;
  INITIALIZE_FUNCTION_POINTER_IF_NULL(cuOccupancyMaxActiveClusters,
                                      getCuOccupancyMaxActiveClustersHandle);

  Py_BEGIN_ALLOW_THREADS;
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute(
      func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
      cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config));
  Py_END_ALLOW_THREADS;
  return PyLong_FromLong(maxActiveClusters);

cleanup:
  return NULL;
}
253
+
254
// Python entry point: set_printf_fifo_size(size).
// Sets CU_LIMIT_PRINTF_FIFO_SIZE on the current context when it differs from
// the current value. Raises ValueError for a negative size and RuntimeError if
// a driver call fails; returns None on success.
static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
  long size;
  if (!PyArg_ParseTuple(args, "l", &size)) {
    return NULL;
  }
  if (size < 0) {
    PyErr_SetString(PyExc_ValueError, "fifo size must be non-negative");
    return NULL;
  }

  Py_BEGIN_ALLOW_THREADS;

  // Ensure we have an active context. If there is none, the primary context of
  // device 0 is retained and made current.
  CUcontext ctx = NULL;
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&ctx));
  if (!ctx) {
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
        cuDevicePrimaryCtxRetain(&ctx, /*device=*/0));
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(ctx));
  }

  // We can't set the fifo size after running a kernel that calls printf.  This
  // is true even if the set() call is a nop and the new size is the same as the
  // old size.
  //
  // This is unfriendly, so check if the old size matches the new size, and skip
  // the set() call if so.
  size_t oldSize = 0;
  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
      cuCtxGetLimit(&oldSize, CU_LIMIT_PRINTF_FIFO_SIZE));
  if (oldSize != size) {
    CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
        cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, size));
  }

  Py_END_ALLOW_THREADS;
  Py_RETURN_NONE;
}
292
+
293
+ static PyObject *PyCUtensorMap_alloc(PyTypeObject *type, Py_ssize_t n_items) {
294
+ PyCUtensorMapObject *self = NULL;
295
+ void *mem = NULL;
296
+ size_t size = type->tp_basicsize;
297
+
298
+ if (posix_memalign(&mem, 128, size) != 0) {
299
+ PyErr_NoMemory();
300
+ return NULL;
301
+ }
302
+
303
+ self = (PyCUtensorMapObject *)mem;
304
+ PyObject_INIT(self, type);
305
+ return (PyObject *)self;
306
+ }
307
+
308
// tp_dealloc slot: hands the object's memory to the type's tp_free slot.
static void PyCUtensorMap_dealloc(PyObject *self) {
  Py_TYPE(self)->tp_free(self);
}
311
+
312
// tp_free slot: releases memory with free(), pairing with the posix_memalign
// call in PyCUtensorMap_alloc.
static void PyCUtensorMap_free(void *ptr) { free(ptr); }
313
+
314
// Type object for PyCUtensorMap. Instances are created with PyType_GenericNew
// and use the custom alloc/dealloc/free slots defined above.
// clang-format off
static PyTypeObject PyCUtensorMapType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    .tp_name = "triton.backends.nvidia.PyCUtensorMap",
    .tp_basicsize = sizeof(PyCUtensorMapObject),
    .tp_itemsize = 0,
    .tp_flags = Py_TPFLAGS_DEFAULT,
    .tp_doc = "<PyCUtensorMap object>",
    .tp_new = PyType_GenericNew,
    .tp_alloc = PyCUtensorMap_alloc,
    .tp_dealloc = (destructor)PyCUtensorMap_dealloc,
    .tp_free = PyCUtensorMap_free,
};
// clang-format on
328
+
329
+ static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
330
+ unsigned long long global_address;
331
+ int swizzle;
332
+ int elemSize;
333
+ int elemType;
334
+ PyObject *blockSize;
335
+ PyObject *shape;
336
+ PyObject *strides;
337
+ int padding;
338
+
339
+ if (!PyArg_ParseTuple(args, "KiiiOOOi", &global_address, &swizzle, &elemSize,
340
+ &elemType, &blockSize, &shape, &strides, &padding)) {
341
+ return NULL;
342
+ }
343
+
344
+ PyCUtensorMapObject *desc = (PyCUtensorMapObject *)PyObject_CallObject(
345
+ (PyObject *)&PyCUtensorMapType, NULL);
346
+ if (!desc) {
347
+ return NULL;
348
+ }
349
+
350
+ PyObject *blockSizeFast = NULL;
351
+ PyObject *shapeFast = NULL;
352
+ PyObject *stridesFast = NULL;
353
+
354
+ uint32_t blockSizeInt[5];
355
+ uint64_t shapeInt[5];
356
+ uint64_t stridesLL[5];
357
+
358
+ blockSizeFast = PySequence_Fast(blockSize, "blockSize must be a sequence");
359
+ if (!blockSizeFast)
360
+ goto cleanup;
361
+ int rank = PySequence_Fast_GET_SIZE(blockSizeFast);
362
+
363
+ for (int i = 0; i < rank; ++i) {
364
+ PyObject *item = PySequence_Fast_GET_ITEM(blockSizeFast, i);
365
+ if (!PyLong_Check(item)) {
366
+ PyErr_SetString(PyExc_TypeError, "block size must be an int");
367
+ goto cleanup;
368
+ }
369
+ blockSizeInt[rank - i - 1] = PyLong_AsLongLong(item);
370
+ }
371
+
372
+ shapeFast = PySequence_Fast(shape, "shape must be a sequence");
373
+ if (!shapeFast)
374
+ goto cleanup;
375
+
376
+ if (rank != PySequence_Fast_GET_SIZE(shapeFast)) {
377
+ PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
378
+ goto cleanup;
379
+ }
380
+ for (int i = 0; i < rank; ++i) {
381
+ PyObject *item = PySequence_Fast_GET_ITEM(shapeFast, i);
382
+ if (!PyLong_Check(item)) {
383
+ PyErr_SetString(PyExc_TypeError, "shape must be an int");
384
+ goto cleanup;
385
+ }
386
+ shapeInt[rank - i - 1] = PyLong_AsLong(item);
387
+ }
388
+
389
+ stridesFast = PySequence_Fast(strides, "strides must be a sequence");
390
+ if (!stridesFast)
391
+ goto cleanup;
392
+
393
+ if (rank != PySequence_Fast_GET_SIZE(stridesFast)) {
394
+ PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
395
+ goto cleanup;
396
+ }
397
+ for (int i = 0; i + 1 < rank; ++i) {
398
+ PyObject *item = PySequence_Fast_GET_ITEM(stridesFast, i);
399
+ if (!PyLong_Check(item)) {
400
+ PyErr_SetString(PyExc_TypeError, "shape must be an int");
401
+ goto cleanup;
402
+ }
403
+ stridesLL[rank - i - 2] = elemSize * PyLong_AsLongLong(item);
404
+ }
405
+ stridesLL[rank - 1] =
406
+ shapeInt[rank - 1] * (rank == 1 ? elemSize : stridesLL[rank - 2]);
407
+ Py_DECREF(blockSizeFast);
408
+ blockSizeFast = NULL;
409
+ Py_DECREF(shapeFast);
410
+ shapeFast = NULL;
411
+ Py_DECREF(stridesFast);
412
+ stridesFast = NULL;
413
+
414
+ CUtensorMapFloatOOBfill fill =
415
+ (padding == 1) ? CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA
416
+ : CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
417
+
418
+ uint32_t elementStrides[5] = {1, 1, 1, 1, 1};
419
+ static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
420
+ INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
421
+ getCuTensorMapEncodeTiledHandle);
422
+ CUresult res = cuTensorMapEncodeTiled(
423
+ &desc->tensorMap, elemType, rank, (void *)global_address, shapeInt,
424
+ stridesLL, blockSizeInt, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
425
+ swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, fill);
426
+ if (res != CUDA_SUCCESS) {
427
+ const char *str;
428
+ cuGetErrorString(res, &str);
429
+ char err[4096] = {0};
430
+ size_t off = 0;
431
+ off += snprintf(
432
+ err + off, sizeof(err) - off,
433
+ "Triton Error [CUDA]: Failed to create tensor map descriptor: %s\n",
434
+ str ? str : "Unknown error");
435
+ off += snprintf(err + off, sizeof(err) - off,
436
+ "elemType=%d rank=%d global_address=0x%llx elemSize=%d "
437
+ "swizzle=%d padding=%d\n",
438
+ elemType, rank, (unsigned long long)global_address,
439
+ elemSize, swizzle, padding);
440
+ off += snprintf(err + off, sizeof(err) - off, "shape=[");
441
+ for (int i = 0; i < rank; ++i) {
442
+ off +=
443
+ snprintf(err + off, sizeof(err) - off, "%llu%s",
444
+ (unsigned long long)shapeInt[i], (i + 1 < rank) ? ", " : "");
445
+ }
446
+ off += snprintf(err + off, sizeof(err) - off, "]\n");
447
+ off += snprintf(err + off, sizeof(err) - off, "strides=[");
448
+ for (int i = 0; i < rank; ++i) {
449
+ off += snprintf(err + off, sizeof(err) - off, "%llu%s",
450
+ (unsigned long long)stridesLL[i],
451
+ (i + 1 < rank) ? ", " : "");
452
+ }
453
+ off += snprintf(err + off, sizeof(err) - off, "]\n");
454
+ off += snprintf(err + off, sizeof(err) - off, "blockSize=[");
455
+ for (int i = 0; i < rank; ++i) {
456
+ off += snprintf(err + off, sizeof(err) - off, "%u%s",
457
+ (unsigned)blockSizeInt[i], (i + 1 < rank) ? ", " : "");
458
+ }
459
+ off += snprintf(err + off, sizeof(err) - off, "] elementStrides=[");
460
+ for (int i = 0; i < rank; ++i) {
461
+ off += snprintf(err + off, sizeof(err) - off, "%u%s",
462
+ (unsigned)elementStrides[i], (i + 1 < rank) ? ", " : "");
463
+ }
464
+ off += snprintf(err + off, sizeof(err) - off, "]\n");
465
+ PyErr_SetString(PyExc_RuntimeError, err);
466
+
467
+ goto cleanup;
468
+ }
469
+
470
+ return (PyObject *)desc;
471
+
472
+ cleanup:
473
+ Py_XDECREF(blockSizeFast);
474
+ Py_XDECREF(shapeFast);
475
+ Py_XDECREF(stridesFast);
476
+ Py_XDECREF(desc);
477
+ return NULL;
478
+ }
479
+
480
// Method table for the cuda_utils module: Python-visible name, C
// implementation, calling convention and docstring for each function.
static PyMethodDef ModuleMethods[] = {
    {"load_binary", loadBinary, METH_VARARGS,
     "Load provided cubin into CUDA driver"},
    {"get_device_properties", getDeviceProperties, METH_VARARGS,
     "Get the properties for a given device"},
    {"cuOccupancyMaxActiveClusters", occupancyMaxActiveClusters, METH_VARARGS,
     "Python interface for cuOccupancyMaxActiveClusters function"},
    {"set_printf_fifo_size", setPrintfFifoSize, METH_VARARGS,
     "Python interface for cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, x), which "
     "controls how many bytes can be streamed from kernels before data starts "
     "being dropped.  This inherits all the limitations of this call; in "
     "particular it's an error to change this value after launching any kernel "
     "that calls printf()."},
    {"fill_tma_descriptor", fillTMADescriptor, METH_VARARGS, "doc"},

    {NULL, NULL, 0, NULL} // sentinel
};
497
+
498
// Module definition for "cuda_utils": no module docstring, size -1, and the
// method table above.
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils",
                                       NULL, // documentation
                                       -1,   // size
                                       ModuleMethods};
502
+
503
+ PyMODINIT_FUNC PyInit_cuda_utils(void) {
504
+ if (PyType_Ready(&PyCUtensorMapType) < 0) {
505
+ return NULL;
506
+ }
507
+
508
+ PyObject *m = PyModule_Create(&ModuleDef);
509
+ if (m == NULL) {
510
+ return NULL;
511
+ }
512
+
513
+ PyModule_AddFunctions(m, ModuleMethods);
514
+ Py_INCREF(&PyCUtensorMapType);
515
+ PyModule_AddObject(m, "PyCUtensorMap", (PyObject *)&PyCUtensorMapType);
516
+
517
+ return m;
518
+ }
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/driver.py ADDED
@@ -0,0 +1,764 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import os
3
+ import subprocess
4
+ import triton
5
+ import re
6
+ from pathlib import Path
7
+ from triton import knobs
8
+ from triton.runtime.build import compile_module_from_src
9
+ from triton.runtime import _allocation
10
+ from triton.backends.compiler import GPUTarget
11
+ from triton.backends.driver import GPUDriver
12
+
13
# Directory containing this driver module; the bundled headers, libdevice and
# driver.c are resolved relative to it.
dirname = os.path.split(os.path.realpath(__file__))[0]
# Header search path handed to the C compiler.
include_dirs = [os.path.join(dirname, "include")]
# Library directory shipped next to this module.
libdevice_dir = os.path.join(dirname, "lib")
# Libraries the generated extension modules are linked against.
libraries = ['libcuda.so.1']
# Filled in by CudaUtils.__init__ once the cuda_utils extension is built.
PyCUtensorMap = None
18
+
19
+
20
@functools.lru_cache()
def libcuda_dirs():
    """Return the directories in which ``libcuda.so.1`` can be found.

    Lookup order:

    1. ``knobs.nvidia.libcuda_path`` if set (returned as-is, unchecked).
    2. Every ``libcuda.so.1`` entry reported by ``/sbin/ldconfig -p``.
    3. If ldconfig reported nothing, the ``LD_LIBRARY_PATH`` entries that
       actually contain ``libcuda.so.1``.

    Raises:
        AssertionError: if none of the candidate directories contains
            ``libcuda.so.1``.  Raised explicitly (not via ``assert``) so the
            check still runs under ``python -O``.
    """
    if env_libcuda_path := knobs.nvidia.libcuda_path:
        return [env_libcuda_path]

    libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(errors="ignore")
    # each line looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    locs = [line.split()[-1] for line in libs.splitlines() if "libcuda.so.1" in line]
    dirs = [os.path.dirname(loc) for loc in locs]
    env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
    if env_ld_library_path and not dirs:
        dirs = [
            path for path in env_ld_library_path.split(":")
            if os.path.exists(os.path.join(path, "libcuda.so.1"))
        ]

    if not any(os.path.exists(os.path.join(path, 'libcuda.so.1')) for path in dirs):
        msg = 'libcuda.so cannot be found!\n'
        if locs:
            msg += 'Possible files are located at %s. ' % str(locs)
            msg += 'Please create a symlink of libcuda.so to any of the files.'
        else:
            msg += 'Please make sure GPU is set up and then run "/sbin/ldconfig"'
            msg += ' (requires sudo) to refresh the linker cache.'
        raise AssertionError(msg)
    return dirs
42
+
43
+
44
@functools.lru_cache()
def library_dirs():
    """Return the library search path: the bundled lib dir, then the libcuda dirs."""
    search_path = [libdevice_dir]
    search_path.extend(libcuda_dirs())
    return search_path
47
+
48
+
49
+ # ------------------------
50
+ # Utils
51
+ # ------------------------
52
+
53
+
54
class CudaUtils(object):
    """Process-wide handle on the compiled ``cuda_utils`` C extension.

    ``__new__`` hands back one shared instance stored on the class.
    ``__init__`` builds ``driver.c`` through ``compile_module_from_src`` and
    re-exports the extension's entry points as instance attributes.
    """

    # Names copied verbatim from the compiled module onto the instance.
    _EXPORTED = (
        "load_binary",
        "get_device_properties",
        "cuOccupancyMaxActiveClusters",
        "set_printf_fifo_size",
        "fill_tma_descriptor",
    )

    def __new__(cls):
        if hasattr(cls, "instance"):
            return cls.instance
        cls.instance = super().__new__(cls)
        return cls.instance

    def __init__(self):
        driver_source = Path(os.path.join(dirname, "driver.c")).read_text()
        mod = compile_module_from_src(
            src=driver_source,
            name="cuda_utils",
            library_dirs=library_dirs(),
            include_dirs=include_dirs,
            libraries=libraries,
        )
        # Publish the extension type at module level; the generated launcher
        # looks it up by name on this module when it is imported.
        global PyCUtensorMap
        PyCUtensorMap = mod.PyCUtensorMap
        for attr_name in self._EXPORTED:
            setattr(self, attr_name, getattr(mod, attr_name))
76
+
77
+
78
+ # ------------------------
79
+ # Launcher
80
+ # ------------------------
81
+
82
+
83
def ty_to_cpp(ty):
    """Translate a Triton signature type string into the C type name used by the launcher.

    Pointer types (leading ``*``) become ``CUdeviceptr`` and tensor
    descriptors become ``CUtensorMap``.  Integer types map to fixed-width C
    integers (1-bit values are carried in 8 bits) and every floating-point
    type maps to ``double``.

    Raises:
        KeyError: if ``ty`` is not a known scalar type.
    """
    if ty[0] == '*':
        return "CUdeviceptr"
    if ty.startswith("tensordesc"):
        return "CUtensorMap"

    scalar_types = {"i1": "int8_t", "u1": "uint8_t", "nvTmaDesc": "CUtensorMap"}
    for bits in (8, 16, 32, 64):
        scalar_types[f"i{bits}"] = f"int{bits}_t"
        scalar_types[f"u{bits}"] = f"uint{bits}_t"
    for float_name in ("fp16", "bf16", "fp32", "f32", "fp64"):
        scalar_types[float_name] = "double"
    return scalar_types[ty]
106
+
107
+
108
# Unsigned integer C type that carries the raw bits of each float type when it
# is handed to the kernel.
FLOAT_STORAGE_TYPE = dict(
    fp16="uint16_t",
    bf16="uint16_t",
    fp32="uint32_t",
    f32="uint32_t",
    fp64="uint64_t",
)
# Name of the generated C helper that packs a double into that storage type.
FLOAT_PACK_FUNCTION = dict(
    fp16="pack_fp16",
    bf16="pack_bf16",
    fp32="pack_fp32",
    f32="pack_fp32",
    fp64="pack_fp64",
)

# PyArg_ParseTuple format of the fixed leading launcher arguments: three grid
# ints, stream and function handles, two launch flags, then six objects
# (two scratch buffers, kernel/launch metadata, enter/exit hooks).
_BASE_ARGS_FORMAT = "iiiKKppOOOOOO"
_BASE_ARGS_FORMAT_LEN = len(_BASE_ARGS_FORMAT)
125
+
126
+
127
def make_launcher(constants, signature, tensordesc_meta):
    """Generate the C source of the ``__triton_launcher`` extension for one kernel signature.

    Args:
        constants: accepted for interface compatibility; not read here.
        signature: mapping whose values are the kernel argument types
            (type strings, or tuples of them).  ``tensordesc<...>`` entries
            are expanded into their launch-time components.
        tensordesc_meta: optional sequence with one entry per tensor
            descriptor argument; ``None``/empty (or a ``None`` entry) selects
            the "base pointer + shape + strides" decomposition instead of an
            ``nvTmaDesc``.

    Returns:
        str: C source defining a module with a single ``launch`` function.
    """

    def _expand_signature(signature):
        output = []
        tensordesc_idx = 0
        # Expand tensor descriptor arguments into either nvTmaDesc, shape and
        # strides, or base pointer, shape and strides depending on whether the
        # kernel was lowered to use the nvTmaDesc or not.
        for sig in signature:
            if isinstance(sig, str) and sig.startswith("tensordesc"):
                meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
                tensordesc_idx += 1

                match = re.match(r"tensordesc<([^[>]*)\[([^]]*)\]", sig)
                dtype = match.group(1)
                shape = match.group(2)
                ndim = shape.count(",") + 1

                if meta is None:
                    output.append("*" + dtype)
                    # Currently the host side tensor descriptors get passed in as a
                    # tensor desc, shape, and strides. We have no way to use these
                    # shape and strides when processing tensor descriptors which is
                    # why we provide our own decomposition above. Sadly this means
                    # we have to pass the shape and strides twice.
                    for _ in range(2 * ndim):
                        output.append("i64")
                    output.append("i1")
                else:
                    output.append("nvTmaDesc")

                for _ in range(ndim):
                    output.append("i32")
                for _ in range(ndim):
                    output.append("i64")
            else:
                output.append(sig)

        assert not tensordesc_meta or tensordesc_idx == len(tensordesc_meta)
        return output

    def _flatten_signature(sig, output):
        # Flatten tuples
        if isinstance(sig, tuple):
            for x in sig:
                _flatten_signature(x, output)
        else:
            output.append(sig)

    def _extracted_type(ty):
        # C type of the local variable PyArg_ParseTuple writes the argument into.
        if isinstance(ty, tuple):
            val = ','.join(map(_extracted_type, ty))
            return f"[{val}]"
        if ty[0] == '*':
            return "PyObject*"
        if ty in ("constexpr", "nvTmaDesc"):
            return "PyObject*"
        return ty_to_cpp(ty)

    def format_of(ty):
        # PyArg_ParseTuple format unit(s) for one (possibly nested) argument.
        if isinstance(ty, tuple):
            val = ''.join(map(format_of, ty))
            return f"({val})"
        if ty[0] == '*':
            return "O"
        if ty in ("constexpr", "nvTmaDesc"):
            return "O"
        if ty.startswith("tensordesc"):
            return "O"
        return {
            "double": "d",
            "long": "l",
            "int8_t": "b",
            "int16_t": "h",
            "int32_t": "i",
            "int64_t": "L",
            "uint8_t": "B",
            "uint16_t": "H",
            "uint32_t": "I",
            "uint64_t": "K",
        }[ty_to_cpp(ty)]

    expand_signature = _expand_signature(signature.values())
    signature = {i: s for i, s in enumerate(expand_signature)}

    # The parse format keeps tuple nesting; everything after it works on the
    # flattened signature.
    args_format = ''.join([format_of(ty) for ty in signature.values()])
    parse_format = _BASE_ARGS_FORMAT + args_format

    flat_signature = []
    for sig in signature.values():
        _flatten_signature(sig, flat_signature)
    signature = {i: s for i, s in enumerate(flat_signature)}
    args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
    # Record the end of regular arguments;
    # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
    arg_decl_list = []
    for i, ty in signature.items():
        if ty == "constexpr":
            continue
        if ty in FLOAT_STORAGE_TYPE:
            arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
        else:
            arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
    arg_decls = ', '.join(arg_decl_list)
    internal_args_list = []
    for i, ty in signature.items():
        if ty[0] == "*":
            internal_args_list.append(f"ptr_info{i}.dev_ptr")
        elif ty in FLOAT_STORAGE_TYPE:
            internal_args_list.append(f"_arg{i}_storage")
        elif ty == "nvTmaDesc":
            # Note: we have to dereference the pointer
            internal_args_list.append(f"*tma_ptr{i}")
        elif ty != "constexpr":
            internal_args_list.append(f"_arg{i}")

    # generate glue code
    newline = '\n '
    ptr_decls = [
        f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;"
        for i, ty in signature.items()
        if ty[0] == "*"
    ]
    tma_decls = [
        f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items()
        if ty == "nvTmaDesc"
    ]
    float_storage_decls = [
        f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
        for i, ty in signature.items()
        if ty in FLOAT_STORAGE_TYPE
    ]
    params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
    params.append("&global_scratch")
    params.append("&profile_scratch")

    # Pre-rendered fragments spliced into the C template below.
    kernel_params = ', '.join(params)
    launch_params = ', ' + arg_decls if len(arg_decls) > 0 else ''
    launch_call_args = ', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''
    arg_var_decls = newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])

    src = f"""
#include \"cuda.h\"
#include <dlfcn.h>
#include <stdbool.h>
#include <stdlib.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>

typedef struct {{
  PyObject_HEAD;
  _Alignas(128) CUtensorMap tensorMap;
}} PyCUtensorMapObject;

static inline void gpuAssert(CUresult code, const char *file, int line)
{{
   if (code != CUDA_SUCCESS)
   {{
      const char* prefix = "Triton Error [CUDA]: ";
      const char* str;
      cuGetErrorString(code, &str);
      char err[1024] = {{0}};
      strcat(err, prefix);
      strcat(err, str);
      PyGILState_STATE gil_state;
      gil_state = PyGILState_Ensure();
      PyErr_SetString(PyExc_RuntimeError, err);
      PyGILState_Release(gil_state);
   }}
}}

#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra);

// Sets a RuntimeError while holding the GIL. Needed because the symbol lookup
// below runs from _launch, which executes with the GIL released.
static void setRuntimeErrorWithGIL(const char *msg) {{
  PyGILState_STATE gil_state = PyGILState_Ensure();
  PyErr_SetString(PyExc_RuntimeError, msg);
  PyGILState_Release(gil_state);
}}

static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
  // Open the shared library
  void* handle = dlopen("libcuda.so.1", RTLD_LAZY);
  if (!handle) {{
    setRuntimeErrorWithGIL("Failed to open libcuda.so.1");
    return NULL;
  }}
  // Clear any existing error
  dlerror();
  cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx");
  // Check for errors
  const char *dlsym_error = dlerror();
  if (dlsym_error) {{
    setRuntimeErrorWithGIL("Failed to retrieve cuLaunchKernelEx from libcuda.so.1");
    return NULL;
  }}
  return cuLaunchKernelExHandle;
}}

static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch, CUdeviceptr profile_scratch{launch_params}) {{
  void *params[] = {{ {kernel_params} }};
  if (gridX*gridY*gridZ > 0) {{
    // 4 attributes that we can currently pass maximum
    CUlaunchAttribute launchAttr[4];
    static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
    if (cuLaunchKernelExHandle == NULL) {{
      cuLaunchKernelExHandle = getLaunchKernelExHandle();
      if (cuLaunchKernelExHandle == NULL) {{
        // A Python error has been set; the caller reports it instead of
        // calling through a NULL function pointer.
        return;
      }}
    }}
    CUlaunchConfig config;
    config.gridDimX = gridX * num_ctas;
    config.gridDimY = gridY;
    config.gridDimZ = gridZ;

    config.blockDimX = 32 * num_warps;
    config.blockDimY = 1;
    config.blockDimZ = 1;
    config.sharedMemBytes = shared_memory;
    config.hStream = stream;
    config.attrs = launchAttr;
    int num_attrs = 0;

    if (launch_pdl != 0) {{
      CUlaunchAttribute pdlAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION, .value = 1}};
      launchAttr[num_attrs] = pdlAttr;
      ++num_attrs;
    }}

    if (launch_cooperative_grid != 0) {{
      CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
      launchAttr[num_attrs] = coopAttr;
      ++num_attrs;
    }}

    if (num_ctas != 1) {{
      CUlaunchAttribute clusterAttr = {{}};
      clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
      clusterAttr.value.clusterDim.x = num_ctas;
      clusterAttr.value.clusterDim.y = 1;
      clusterAttr.value.clusterDim.z = 1;
      launchAttr[num_attrs] = clusterAttr;
      ++num_attrs;

      CUlaunchAttribute clusterSchedulingAttr = {{}};
      clusterSchedulingAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
      clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
      launchAttr[num_attrs] = clusterSchedulingAttr;
      ++num_attrs;
    }}

    // num_ctas == 16 is non-portable. Does work for H100 and B200 tho
    config.numAttrs = num_attrs;
    if (num_ctas == 16) {{
      CUDA_CHECK(cuFuncSetAttribute(
        function,
        CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED,
        1
      ));
    }}

    CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
  }}
}}

typedef struct _DevicePtrInfo {{
    CUdeviceptr dev_ptr;
    bool valid;
}} DevicePtrInfo;

static PyObject* data_ptr_str = NULL;
static PyObject* py_tensor_map_type = NULL;

static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
  DevicePtrInfo ptr_info;
  ptr_info.dev_ptr = 0;
  ptr_info.valid = true;
  if (PyLong_Check(obj)) {{
    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
    return ptr_info;
  }}
  if (obj == Py_None) {{
    // valid nullptr
    return ptr_info;
  }}
  PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
  if (!ret) {{
    PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
    ptr_info.valid = false;
    goto cleanup;
  }}
  if (!PyLong_Check(ret)) {{
    PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
    ptr_info.valid = false;
    goto cleanup;
  }}
  ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
  if(!ptr_info.dev_ptr)
    goto cleanup; // null data pointer: still valid, but release ret first
  uint64_t dev_ptr;
  int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
  if (status == CUDA_ERROR_INVALID_VALUE) {{
    PyErr_Format(PyExc_ValueError,
                 "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
    ptr_info.valid = false;
  }} else if (status != CUDA_SUCCESS) {{
    CUDA_CHECK(status); // Catch any other cuda API errors
    ptr_info.valid = false;
  }}
  ptr_info.dev_ptr = dev_ptr;
cleanup:
  Py_XDECREF(ret);
  return ptr_info;

}}

static inline CUtensorMap* getTmaDesc(PyObject *obj) {{
  if (sizeof(CUtensorMap*) != 8) {{
    PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation");
    return NULL;
  }}

  if (Py_TYPE(obj) != (PyTypeObject*)py_tensor_map_type) {{
    PyErr_Format(PyExc_TypeError, "object must be of type PyCUtensorMap, got %s", Py_TYPE(obj)->tp_name);
    return NULL;
  }}

  CUtensorMap* map = &((PyCUtensorMapObject*)obj)->tensorMap;
  uintptr_t align_128 = (uintptr_t)map & (128 - 1);
  if (align_128 != 0) {{
    PyErr_Format(PyExc_ValueError, "CUtensorMap must be aligned to 128B, but got (&map) mod 128 = %ld", align_128);
    return NULL;
  }}
  return map;
}}

static void ensureCudaContext() {{
  CUcontext pctx;
  CUDA_CHECK(cuCtxGetCurrent(&pctx));
  if (!pctx) {{
    // Ensure device context.
    CUdevice device;
    CUDA_CHECK(cuDeviceGet(&device, 0));
    CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
    CUDA_CHECK(cuCtxSetCurrent(pctx));
  }}
}}

static uint16_t pack_fp16(double f) {{
    uint16_t result;
    // from https://github.com/python/pythoncapi-compat
#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
    _PyFloat_Pack2(f, (unsigned char*)&result, 1);
#else
    PyFloat_Pack2(f, (unsigned char*)&result, 1);
#endif
    return result;
}}

static uint16_t pack_bf16(double f) {{
    float f32 = (float)f;
    uint32_t u32 = *(uint32_t*)&f32;
    return (uint16_t)(u32 >> 16);
}}

static uint32_t pack_fp32(double f) {{
    float f32 = (float)f;
    return *(uint32_t*)&f32;
}}

static uint64_t pack_fp64(double f) {{
    return *(uint64_t*)&f;
}}

static PyObject* launch(PyObject* self, PyObject* args) {{
  // ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes
  ensureCudaContext();

  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
  int launch_cooperative_grid;
  int launch_pdl;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *kernel_metadata = NULL;
  PyObject *launch_metadata = NULL;
  PyObject *global_scratch_obj = NULL;
  PyObject *profile_scratch_obj = NULL;
 {arg_var_decls}
  if(!PyArg_ParseTuple(args, \"{parse_format}\", &gridX, &gridY, &gridZ,
                                           &_stream, &_function, &launch_cooperative_grid, &launch_pdl, &global_scratch_obj, &profile_scratch_obj,
                                           &kernel_metadata, &launch_metadata,
                                           &launch_enter_hook, &launch_exit_hook{args_list})) {{
    return NULL;
  }}

  int num_warps, num_ctas, shared_memory;
  if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &num_warps, &num_ctas, &shared_memory)) {{
    PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple");
    return NULL;
  }}

  // extract launch metadata
  if (launch_enter_hook != Py_None){{
    PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
    if (!ret)
      return NULL;
    Py_DECREF(ret);
  }}

  CUdeviceptr global_scratch = 0;
  if (global_scratch_obj != Py_None) {{
    DevicePtrInfo global_scratch_info = getPointer(global_scratch_obj, -1);
    if (!global_scratch_info.valid) {{
      return NULL;
    }}
    global_scratch = global_scratch_info.dev_ptr;
  }}

  CUdeviceptr profile_scratch = 0;
  if (profile_scratch_obj != Py_None) {{
    DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
    if (!profile_scratch_info.valid) {{
      return NULL;
    }}
    profile_scratch = profile_scratch_info.dev_ptr;
  }}

  // raise exception asap
 {newline.join(ptr_decls)}
 {newline.join(tma_decls)}
 {newline.join(float_storage_decls)}
  Py_BEGIN_ALLOW_THREADS;
  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch, profile_scratch{launch_call_args});
  Py_END_ALLOW_THREADS;
  if (PyErr_Occurred()) {{
    return NULL;
  }}

  if(launch_exit_hook != Py_None){{
    PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
    if (!ret)
      return NULL;
    Py_DECREF(ret);
  }}

  Py_RETURN_NONE;
}}

static PyMethodDef ModuleMethods[] = {{
  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
  {{NULL, NULL, 0, NULL}} // sentinel
}};

static struct PyModuleDef ModuleDef = {{
  PyModuleDef_HEAD_INIT,
  \"__triton_launcher\",
  NULL, //documentation
  -1, //size
  ModuleMethods
}};

PyMODINIT_FUNC PyInit___triton_launcher(void) {{
  data_ptr_str = PyUnicode_InternFromString("data_ptr");
  if(data_ptr_str == NULL) {{
    return NULL;
  }}
  PyObject* driver_mod = PyImport_ImportModule("triton.backends.nvidia.driver");
  if (driver_mod == NULL) {{
    return NULL;
  }}
  py_tensor_map_type = PyObject_GetAttrString(driver_mod, "PyCUtensorMap");
  // The module reference is only needed for the attribute lookup.
  Py_DECREF(driver_mod);
  if (py_tensor_map_type == NULL) {{
    return NULL;
  }}

  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {{
    return NULL;
  }}
  PyModule_AddFunctions(m, ModuleMethods);
  return m;
}}
"""
    return src
601
+
602
+
603
# The TMA dtype enum values are slightly different on host vs device...
# Identity mapping for 0..15, except that device values 8, 9 and 10 map to
# host values 10, 8 and 9 respectively.
TMA_DTYPE_DEVICE_TO_HOST = {device_value: device_value for device_value in range(16)}
TMA_DTYPE_DEVICE_TO_HOST.update({8: 10, 9: 8, 10: 9})
608
+
609
+
610
def make_tensordesc_arg(arg, metadata):
    """Expand one host tensor descriptor into the flat argument list the launcher expects.

    Without ``metadata`` the descriptor is decomposed into base pointer,
    shape, strides, a NaN-padding flag, and shape/strides a second time.
    With ``metadata`` a TMA descriptor is built via ``fill_tma_descriptor``
    and returned followed by shape and strides.
    """
    if metadata is None:
        # Currently the host side tensor descriptors get decomposed in
        # the frontend to tensor desc, shape, and strides. We have no
        # way to use these shape and strides when processing tensor
        # descriptors which is why we provide our own decomposition
        # above. Sadly this means we have to pass the shape and strides
        # twice.
        expanded = [arg.base]
        expanded.extend(arg.shape)
        expanded.extend(arg.strides)
        expanded.append(arg.padding == "nan")
        expanded.extend(arg.shape)
        expanded.extend(arg.strides)
        return expanded

    swizzle, elem_size, elem_type, block_size, fp4_padded = [
        metadata[key] for key in ("swizzle", "elem_size", "elem_type", "block_size", "fp4_padded")
    ]

    shape = arg.shape
    strides = arg.strides
    assert strides[-1] == 1
    if arg.padding == "nan":
        padding = 1
    else:
        padding = 0

    if fp4_padded:
        # Work on a copy so the descriptor's own shape is left untouched.
        shape = list(shape)
        shape[-1] *= 2

    fill_tma_descriptor = triton.runtime.driver.active.utils.fill_tma_descriptor
    cu_tensor_map = fill_tma_descriptor(
        arg.base.data_ptr(),
        swizzle,
        elem_size,
        TMA_DTYPE_DEVICE_TO_HOST[elem_type],
        block_size,
        shape,
        strides,
        padding,
    )

    return [cu_tensor_map, *shape, *strides]
647
+
648
+
649
def wrap_handle_tensordesc(launcher, signature, tensordesc_meta):
    """Wrap ``launcher`` so tensor descriptor arguments are expanded before the call.

    Returns ``launcher`` unchanged when the signature has no ``tensordesc``
    entry; otherwise returns a wrapper that replaces each such argument with
    the list produced by ``make_tensordesc_arg``.
    """

    def _is_tensordesc(sig):
        return isinstance(sig, str) and sig.startswith("tensordesc")

    # Positions (within the kernel arguments) that hold tensor descriptors.
    tensordesc_indices = {i for i, sig in enumerate(signature.values()) if _is_tensordesc(sig)}
    if not tensordesc_indices:
        return launcher

    assert not tensordesc_meta or len(tensordesc_meta) == len(tensordesc_indices)
    if not tensordesc_meta:
        tensordesc_meta = [None] * len(tensordesc_indices)

    def inner(*args):
        # The fixed leading launcher arguments pass through untouched.
        final_args = list(args[:_BASE_ARGS_FORMAT_LEN])
        remaining_meta = iter(tensordesc_meta)
        for position, kernel_arg in enumerate(args[_BASE_ARGS_FORMAT_LEN:]):
            if position not in tensordesc_indices:
                final_args.append(kernel_arg)
                continue
            final_args.extend(make_tensordesc_arg(kernel_arg, next(remaining_meta)))
        return launcher(*final_args)

    return inner
672
+
673
+
674
class CudaLauncher(object):
    """Compiles and wraps the launcher extension for one kernel signature."""

    def __init__(self, src, metadata):
        """Build the launcher module for ``src`` and cache launch settings from ``metadata``."""

        def _as_index_key(key):
            # Constants keyed by argument name are re-keyed by a 1-tuple
            # holding the argument's position.
            if isinstance(key, str):
                return (src.fn.arg_names.index(key), )
            return key

        constants = getattr(src, "constants", dict())
        constants = {_as_index_key(key): value for key, value in constants.items()}
        signature = dict(src.signature.items())
        tensordesc_meta = getattr(metadata, "tensordesc_meta", None)
        launcher_source = make_launcher(constants, signature, tensordesc_meta)
        mod = compile_module_from_src(
            src=launcher_source,
            name="__triton_launcher",
            library_dirs=library_dirs(),
            include_dirs=include_dirs,
            libraries=libraries,
        )

        self.num_ctas = getattr(metadata, "num_ctas", 1)
        self.launch = wrap_handle_tensordesc(mod.launch, signature, tensordesc_meta)
        self.global_scratch_size = metadata.global_scratch_size
        self.global_scratch_align = metadata.global_scratch_align
        self.profile_scratch_size = metadata.profile_scratch_size
        self.profile_scratch_align = metadata.profile_scratch_align
        self.launch_cooperative_grid = metadata.launch_cooperative_grid
        self.launch_pdl = metadata.launch_pdl

    def __call__(self, gridX, gridY, gridZ, stream, function, *args):
        """Allocate any scratch buffers and invoke the compiled launcher."""

        def allocate_scratch(size, align, allocator):
            if not size > 0:
                return None
            # One scratch region of `size` per CTA across the whole grid.
            total_ctas = gridX * gridY * gridZ * self.num_ctas
            return allocator.get()(total_ctas * size, align, stream)

        global_scratch = allocate_scratch(
            self.global_scratch_size,
            self.global_scratch_align,
            _allocation._allocator,
        )
        profile_scratch = allocate_scratch(
            self.profile_scratch_size,
            self.profile_scratch_align,
            _allocation._profile_allocator,
        )
        self.launch(
            gridX, gridY, gridZ, stream, function,
            self.launch_cooperative_grid, self.launch_pdl,
            global_scratch, profile_scratch,
            *args,
        )
715
+
716
+
717
class CudaDriver(GPUDriver):
    """GPUDriver implementation for the CUDA backend."""

    def __init__(self):
        self.utils = CudaUtils()  # TODO: make static
        self.launcher_cls = CudaLauncher
        super().__init__()

    def get_current_target(self):
        """Return the GPUTarget for the current device (capability as major*10 + minor)."""
        cc = self.get_device_capability(self.get_current_device())
        return GPUTarget("cuda", cc[0] * 10 + cc[1], 32)

    def get_active_torch_device(self):
        """Return the torch.device for the current CUDA device."""
        import torch
        return torch.device("cuda", self.get_current_device())

    def get_device_interface(self):
        """Return the torch.cuda module."""
        import torch
        return torch.cuda

    @staticmethod
    def is_active():
        """Return True when torch is importable, CUDA is available and torch is not a HIP build."""
        try:
            import torch
            return torch.cuda.is_available() and (torch.version.hip is None)
        except ImportError:
            return False

    def map_python_to_cpp_type(self, ty: str) -> str:
        """Translate a Triton type string to its C type via ty_to_cpp."""
        return ty_to_cpp(ty)

    def get_benchmarker(self):
        """Return triton.testing.do_bench."""
        from triton.testing import do_bench
        return do_bench

    def get_empty_cache_for_benchmark(self):
        """Allocate the int buffer that benchmarks zero to flush the L2 cache."""
        import torch

        # We maintain a buffer of 256 MB that we clear
        # before each kernel call to make sure that the L2 cache
        # doesn't contain any input data before the run
        cache_bytes = 256 * 1024 * 1024
        num_elements = int(cache_bytes // 4)
        return torch.empty(num_elements, dtype=torch.int, device='cuda')

    def clear_cache(self, cache):
        """Zero the cache buffer in place."""
        cache.zero_()
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/cudaGL.h ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 1993-2014 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ #ifndef CUDAGL_H
51
+ #define CUDAGL_H
52
+
53
+ #include <cuda.h>
54
+ #include <GL/gl.h>
55
+
56
+ #if defined(__CUDA_API_VERSION_INTERNAL) || defined(__DOXYGEN_ONLY__) || defined(CUDA_ENABLE_DEPRECATED)
57
+ #define __CUDA_DEPRECATED
58
+ #elif defined(_MSC_VER)
59
+ #define __CUDA_DEPRECATED __declspec(deprecated)
60
+ #elif defined(__GNUC__)
61
+ #define __CUDA_DEPRECATED __attribute__((deprecated))
62
+ #else
63
+ #define __CUDA_DEPRECATED
64
+ #endif
65
+
66
+ #ifdef CUDA_FORCE_API_VERSION
67
+ #error "CUDA_FORCE_API_VERSION is no longer supported."
68
+ #endif
69
+
70
+ #if defined(__CUDA_API_VERSION_INTERNAL) || defined(CUDA_API_PER_THREAD_DEFAULT_STREAM)
71
+ #define __CUDA_API_PER_THREAD_DEFAULT_STREAM
72
+ #define __CUDA_API_PTDS(api) api ## _ptds
73
+ #define __CUDA_API_PTSZ(api) api ## _ptsz
74
+ #else
75
+ #define __CUDA_API_PTDS(api) api
76
+ #define __CUDA_API_PTSZ(api) api
77
+ #endif
78
+
79
+ #define cuGLCtxCreate cuGLCtxCreate_v2
80
+ #define cuGLMapBufferObject __CUDA_API_PTDS(cuGLMapBufferObject_v2)
81
+ #define cuGLMapBufferObjectAsync __CUDA_API_PTSZ(cuGLMapBufferObjectAsync_v2)
82
+ #define cuGLGetDevices cuGLGetDevices_v2
83
+
84
+ #ifdef __cplusplus
85
+ extern "C" {
86
+ #endif
87
+
88
+ /**
89
+ * \file cudaGL.h
90
+ * \brief Header file for the OpenGL interoperability functions of the
91
+ * low-level CUDA driver application programming interface.
92
+ */
93
+
94
+ /**
95
+ * \defgroup CUDA_GL OpenGL Interoperability
96
+ * \ingroup CUDA_DRIVER
97
+ *
98
+ * ___MANBRIEF___ OpenGL interoperability functions of the low-level CUDA
99
+ * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___
100
+ *
101
+ * This section describes the OpenGL interoperability functions of the
102
+ * low-level CUDA driver application programming interface. Note that mapping
103
+ * of OpenGL resources is performed with the graphics API agnostic, resource
104
+ * mapping interface described in \ref CUDA_GRAPHICS "Graphics Interoperability".
105
+ *
106
+ * @{
107
+ */
108
+
109
+ #if defined(_WIN32)
110
+ #if !defined(WGL_NV_gpu_affinity)
111
+ typedef void* HGPUNV;
112
+ #endif
113
+ #endif /* _WIN32 */
114
+
115
+ /**
116
+ * \brief Registers an OpenGL buffer object
117
+ *
118
+ * Registers the buffer object specified by \p buffer for access by
119
+ * CUDA. A handle to the registered object is returned as \p
120
+ * pCudaResource. The register flags \p Flags specify the intended usage,
121
+ * as follows:
122
+ *
123
+ * - ::CU_GRAPHICS_REGISTER_FLAGS_NONE: Specifies no hints about how this
124
+ * resource will be used. It is therefore assumed that this resource will be
125
+ * read from and written to by CUDA. This is the default value.
126
+ * - ::CU_GRAPHICS_REGISTER_FLAGS_READ_ONLY: Specifies that CUDA
127
+ * will not write to this resource.
128
+ * - ::CU_GRAPHICS_REGISTER_FLAGS_WRITE_DISCARD: Specifies that
129
+ * CUDA will not read from this resource and will write over the
130
+ * entire contents of the resource, so none of the data previously
131
+ * stored in the resource will be preserved.
132
+ *
133
+ * \param pCudaResource - Pointer to the returned object handle
134
+ * \param buffer - name of buffer object to be registered
135
+ * \param Flags - Register flags
136
+ *
137
+ * \return
138
+ * ::CUDA_SUCCESS,
139
+ * ::CUDA_ERROR_INVALID_HANDLE,
140
+ * ::CUDA_ERROR_ALREADY_MAPPED,
141
+ * ::CUDA_ERROR_INVALID_CONTEXT,
142
+ * ::CUDA_ERROR_OPERATING_SYSTEM
143
+ * \notefnerr
144
+ *
145
+ * \sa
146
+ * ::cuGraphicsUnregisterResource,
147
+ * ::cuGraphicsMapResources,
148
+ * ::cuGraphicsResourceGetMappedPointer,
149
+ * ::cudaGraphicsGLRegisterBuffer
150
+ */
151
+ CUresult CUDAAPI cuGraphicsGLRegisterBuffer(CUgraphicsResource *pCudaResource, GLuint buffer, unsigned int Flags);
152
+
153
+ /**
154
+ * \brief Register an OpenGL texture or renderbuffer object
155
+ *
156
+ * Registers the texture or renderbuffer object specified by \p image for access by CUDA.
157
+ * A handle to the registered object is returned as \p pCudaResource.
158
+ *
159
+ * \p target must match the type of the object, and must be one of ::GL_TEXTURE_2D,
160
+ * ::GL_TEXTURE_RECTANGLE, ::GL_TEXTURE_CUBE_MAP, ::GL_TEXTURE_3D, ::GL_TEXTURE_2D_ARRAY,
161
+ * or ::GL_RENDERBUFFER.
162
+ *
163
+ * The register flags \p Flags specify the intended usage, as follows:
164
+ *
165
+ * - ::CU_GRAPHICS_REGISTER_FLAGS_NONE: Specifies no hints about how this
166
+ * resource will be used. It is therefore assumed that this resource will be
167
+ * read from and written to by CUDA. This is the default value.
168
+ * - ::CU_GRAPHICS_REGISTER_FLAGS_READ_ONLY: Specifies that CUDA
169
+ * will not write to this resource.
170
+ * - ::CU_GRAPHICS_REGISTER_FLAGS_WRITE_DISCARD: Specifies that
171
+ * CUDA will not read from this resource and will write over the
172
+ * entire contents of the resource, so none of the data previously
173
+ * stored in the resource will be preserved.
174
+ * - ::CU_GRAPHICS_REGISTER_FLAGS_SURFACE_LDST: Specifies that CUDA will
175
+ * bind this resource to a surface reference.
176
+ * - ::CU_GRAPHICS_REGISTER_FLAGS_TEXTURE_GATHER: Specifies that CUDA will perform
177
+ * texture gather operations on this resource.
178
+ *
179
+ * The following image formats are supported. For brevity's sake, the list is abbreviated.
180
+ * For ex., {GL_R, GL_RG} X {8, 16} would expand to the following 4 formats
181
+ * {GL_R8, GL_R16, GL_RG8, GL_RG16} :
182
+ * - GL_RED, GL_RG, GL_RGBA, GL_LUMINANCE, GL_ALPHA, GL_LUMINANCE_ALPHA, GL_INTENSITY
183
+ * - {GL_R, GL_RG, GL_RGBA} X {8, 16, 16F, 32F, 8UI, 16UI, 32UI, 8I, 16I, 32I}
184
+ * - {GL_LUMINANCE, GL_ALPHA, GL_LUMINANCE_ALPHA, GL_INTENSITY} X
185
+ * {8, 16, 16F_ARB, 32F_ARB, 8UI_EXT, 16UI_EXT, 32UI_EXT, 8I_EXT, 16I_EXT, 32I_EXT}
186
+ *
187
+ * The following image classes are currently disallowed:
188
+ * - Textures with borders
189
+ * - Multisampled renderbuffers
190
+ *
191
+ * \param pCudaResource - Pointer to the returned object handle
192
+ * \param image - name of texture or renderbuffer object to be registered
193
+ * \param target - Identifies the type of object specified by \p image
194
+ * \param Flags - Register flags
195
+ *
196
+ * \return
197
+ * ::CUDA_SUCCESS,
198
+ * ::CUDA_ERROR_INVALID_HANDLE,
199
+ * ::CUDA_ERROR_ALREADY_MAPPED,
200
+ * ::CUDA_ERROR_INVALID_CONTEXT,
201
+ * ::CUDA_ERROR_OPERATING_SYSTEM
202
+ * \notefnerr
203
+ *
204
+ * \sa
205
+ * ::cuGraphicsUnregisterResource,
206
+ * ::cuGraphicsMapResources,
207
+ * ::cuGraphicsSubResourceGetMappedArray,
208
+ * ::cudaGraphicsGLRegisterImage
209
+ */
210
+ CUresult CUDAAPI cuGraphicsGLRegisterImage(CUgraphicsResource *pCudaResource, GLuint image, GLenum target, unsigned int Flags);
211
+
212
+ #ifdef _WIN32
213
+ /**
214
+ * \brief Gets the CUDA device associated with hGpu
215
+ *
216
+ * Returns in \p *pDevice the CUDA device associated with a \p hGpu, if
217
+ * applicable.
218
+ *
219
+ * \param pDevice - Device associated with hGpu
220
+ * \param hGpu - Handle to a GPU, as queried via ::WGL_NV_gpu_affinity()
221
+ *
222
+ * \return
223
+ * ::CUDA_SUCCESS,
224
+ * ::CUDA_ERROR_DEINITIALIZED,
225
+ * ::CUDA_ERROR_NOT_INITIALIZED,
226
+ * ::CUDA_ERROR_INVALID_CONTEXT,
227
+ * ::CUDA_ERROR_INVALID_VALUE
228
+ * \notefnerr
229
+ *
230
+ * \sa ::cuGLMapBufferObject,
231
+ * ::cuGLRegisterBufferObject, ::cuGLUnmapBufferObject,
232
+ * ::cuGLUnregisterBufferObject, ::cuGLUnmapBufferObjectAsync,
233
+ * ::cuGLSetBufferObjectMapFlags,
234
+ * ::cudaWGLGetDevice
235
+ */
236
+ CUresult CUDAAPI cuWGLGetDevice(CUdevice *pDevice, HGPUNV hGpu);
237
+ #endif /* _WIN32 */
238
+
239
+ /**
240
+ * CUDA devices corresponding to an OpenGL device
241
+ */
242
+ typedef enum CUGLDeviceList_enum {
243
+ CU_GL_DEVICE_LIST_ALL = 0x01, /**< The CUDA devices for all GPUs used by the current OpenGL context */
244
+ CU_GL_DEVICE_LIST_CURRENT_FRAME = 0x02, /**< The CUDA devices for the GPUs used by the current OpenGL context in its currently rendering frame */
245
+ CU_GL_DEVICE_LIST_NEXT_FRAME = 0x03, /**< The CUDA devices for the GPUs to be used by the current OpenGL context in the next frame */
246
+ } CUGLDeviceList;
247
+
248
+ /**
249
+ * \brief Gets the CUDA devices associated with the current OpenGL context
250
+ *
251
+ * Returns in \p *pCudaDeviceCount the number of CUDA-compatible devices
252
+ * corresponding to the current OpenGL context. Also returns in \p *pCudaDevices
253
+ * at most cudaDeviceCount of the CUDA-compatible devices corresponding to
254
+ * the current OpenGL context. If any of the GPUs being used by the current OpenGL
255
+ * context are not CUDA capable then the call will return CUDA_ERROR_NO_DEVICE.
256
+ *
257
+ * The \p deviceList argument may be any of the following:
258
+ * - ::CU_GL_DEVICE_LIST_ALL: Query all devices used by the current OpenGL context.
259
+ * - ::CU_GL_DEVICE_LIST_CURRENT_FRAME: Query the devices used by the current OpenGL context to
260
+ * render the current frame (in SLI).
261
+ * - ::CU_GL_DEVICE_LIST_NEXT_FRAME: Query the devices used by the current OpenGL context to
262
+ * render the next frame (in SLI). Note that this is a prediction, it can't be guaranteed that
263
+ * this is correct in all cases.
264
+ *
265
+ * \param pCudaDeviceCount - Returned number of CUDA devices.
266
+ * \param pCudaDevices - Returned CUDA devices.
267
+ * \param cudaDeviceCount - The size of the output device array pCudaDevices.
268
+ * \param deviceList - The set of devices to return.
269
+ *
270
+ * \return
271
+ * ::CUDA_SUCCESS,
272
+ * ::CUDA_ERROR_NO_DEVICE,
273
+ * ::CUDA_ERROR_INVALID_VALUE,
274
+ * ::CUDA_ERROR_INVALID_CONTEXT,
275
+ * ::CUDA_ERROR_INVALID_GRAPHICS_CONTEXT,
276
+ * ::CUDA_ERROR_OPERATING_SYSTEM
277
+ *
278
+ * \notefnerr
279
+ *
280
+ * \sa
281
+ * ::cuWGLGetDevice,
282
+ * ::cudaGLGetDevices
283
+ */
284
+ CUresult CUDAAPI cuGLGetDevices(unsigned int *pCudaDeviceCount, CUdevice *pCudaDevices, unsigned int cudaDeviceCount, CUGLDeviceList deviceList);
285
+
286
+ /**
287
+ * \defgroup CUDA_GL_DEPRECATED OpenGL Interoperability [DEPRECATED]
288
+ *
289
+ * ___MANBRIEF___ deprecated OpenGL interoperability functions of the low-level
290
+ * CUDA driver API (___CURRENT_FILE___) ___ENDMANBRIEF___
291
+ *
292
+ * This section describes deprecated OpenGL interoperability functionality.
293
+ *
294
+ * @{
295
+ */
296
+
297
+ /** Flags to map or unmap a resource */
298
+ typedef enum CUGLmap_flags_enum {
299
+ CU_GL_MAP_RESOURCE_FLAGS_NONE = 0x00,
300
+ CU_GL_MAP_RESOURCE_FLAGS_READ_ONLY = 0x01,
301
+ CU_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD = 0x02,
302
+ } CUGLmap_flags;
303
+
304
+ /**
305
+ * \brief Create a CUDA context for interoperability with OpenGL
306
+ *
307
+ * \deprecated This function is deprecated as of Cuda 5.0.
308
+ *
309
+ * This function is deprecated and should no longer be used. It is
310
+ * no longer necessary to associate a CUDA context with an OpenGL
311
+ * context in order to achieve maximum interoperability performance.
312
+ *
313
+ * \param pCtx - Returned CUDA context
314
+ * \param Flags - Options for CUDA context creation
315
+ * \param device - Device on which to create the context
316
+ *
317
+ * \return
318
+ * ::CUDA_SUCCESS,
319
+ * ::CUDA_ERROR_DEINITIALIZED,
320
+ * ::CUDA_ERROR_NOT_INITIALIZED,
321
+ * ::CUDA_ERROR_INVALID_CONTEXT,
322
+ * ::CUDA_ERROR_INVALID_VALUE,
323
+ * ::CUDA_ERROR_OUT_OF_MEMORY
324
+ * \notefnerr
325
+ *
326
+ * \sa ::cuCtxCreate, ::cuGLInit, ::cuGLMapBufferObject,
327
+ * ::cuGLRegisterBufferObject, ::cuGLUnmapBufferObject,
328
+ * ::cuGLUnregisterBufferObject, ::cuGLMapBufferObjectAsync,
329
+ * ::cuGLUnmapBufferObjectAsync, ::cuGLSetBufferObjectMapFlags,
330
+ * ::cuWGLGetDevice
331
+ */
332
+ __CUDA_DEPRECATED CUresult CUDAAPI cuGLCtxCreate(CUcontext *pCtx, unsigned int Flags, CUdevice device );
333
+
334
+ /**
335
+ * \brief Initializes OpenGL interoperability
336
+ *
337
+ * \deprecated This function is deprecated as of Cuda 3.0.
338
+ *
339
+ * Initializes OpenGL interoperability. This function is deprecated
340
+ * and calling it is no longer required. It may fail if the needed
341
+ * OpenGL driver facilities are not available.
342
+ *
343
+ * \return
344
+ * ::CUDA_SUCCESS,
345
+ * ::CUDA_ERROR_DEINITIALIZED,
346
+ * ::CUDA_ERROR_NOT_INITIALIZED,
347
+ * ::CUDA_ERROR_INVALID_CONTEXT,
348
+ * ::CUDA_ERROR_UNKNOWN
349
+ * \notefnerr
350
+ *
351
+ * \sa ::cuGLMapBufferObject,
352
+ * ::cuGLRegisterBufferObject, ::cuGLUnmapBufferObject,
353
+ * ::cuGLUnregisterBufferObject, ::cuGLMapBufferObjectAsync,
354
+ * ::cuGLUnmapBufferObjectAsync, ::cuGLSetBufferObjectMapFlags,
355
+ * ::cuWGLGetDevice
356
+ */
357
+ __CUDA_DEPRECATED CUresult CUDAAPI cuGLInit(void);
358
+
359
+ /**
360
+ * \brief Registers an OpenGL buffer object
361
+ *
362
+ * \deprecated This function is deprecated as of Cuda 3.0.
363
+ *
364
+ * Registers the buffer object specified by \p buffer for access by
365
+ * CUDA. This function must be called before CUDA can map the buffer
366
+ * object. There must be a valid OpenGL context bound to the current
367
+ * thread when this function is called, and the buffer name is
368
+ * resolved by that context.
369
+ *
370
+ * \param buffer - The name of the buffer object to register.
371
+ *
372
+ * \return
373
+ * ::CUDA_SUCCESS,
374
+ * ::CUDA_ERROR_DEINITIALIZED,
375
+ * ::CUDA_ERROR_NOT_INITIALIZED,
376
+ * ::CUDA_ERROR_INVALID_CONTEXT,
377
+ * ::CUDA_ERROR_ALREADY_MAPPED
378
+ * \notefnerr
379
+ *
380
+ * \sa ::cuGraphicsGLRegisterBuffer
381
+ */
382
+ __CUDA_DEPRECATED CUresult CUDAAPI cuGLRegisterBufferObject(GLuint buffer);
383
+
384
+ /**
385
+ * \brief Maps an OpenGL buffer object
386
+ *
387
+ * \deprecated This function is deprecated as of Cuda 3.0.
388
+ *
389
+ * Maps the buffer object specified by \p buffer into the address space of the
390
+ * current CUDA context and returns in \p *dptr and \p *size the base pointer
391
+ * and size of the resulting mapping.
392
+ *
393
+ * There must be a valid OpenGL context bound to the current thread
394
+ * when this function is called. This must be the same context, or a
395
+ * member of the same shareGroup, as the context that was bound when
396
+ * the buffer was registered.
397
+ *
398
+ * All streams in the current CUDA context are synchronized with the
399
+ * current GL context.
400
+ *
401
+ * \param dptr - Returned mapped base pointer
402
+ * \param size - Returned size of mapping
403
+ * \param buffer - The name of the buffer object to map
404
+ *
405
+ * \return
406
+ * ::CUDA_SUCCESS,
407
+ * ::CUDA_ERROR_DEINITIALIZED,
408
+ * ::CUDA_ERROR_NOT_INITIALIZED,
409
+ * ::CUDA_ERROR_INVALID_CONTEXT,
410
+ * ::CUDA_ERROR_INVALID_VALUE,
411
+ * ::CUDA_ERROR_MAP_FAILED
412
+ * \notefnerr
413
+ *
414
+ * \sa ::cuGraphicsMapResources
415
+ */
416
+ __CUDA_DEPRECATED CUresult CUDAAPI cuGLMapBufferObject(CUdeviceptr *dptr, size_t *size, GLuint buffer);
417
+
418
+ /**
419
+ * \brief Unmaps an OpenGL buffer object
420
+ *
421
+ * \deprecated This function is deprecated as of Cuda 3.0.
422
+ *
423
+ * Unmaps the buffer object specified by \p buffer for access by CUDA.
424
+ *
425
+ * There must be a valid OpenGL context bound to the current thread
426
+ * when this function is called. This must be the same context, or a
427
+ * member of the same shareGroup, as the context that was bound when
428
+ * the buffer was registered.
429
+ *
430
+ * All streams in the current CUDA context are synchronized with the
431
+ * current GL context.
432
+ *
433
+ * \param buffer - Buffer object to unmap
434
+ *
435
+ * \return
436
+ * ::CUDA_SUCCESS,
437
+ * ::CUDA_ERROR_DEINITIALIZED,
438
+ * ::CUDA_ERROR_NOT_INITIALIZED,
439
+ * ::CUDA_ERROR_INVALID_CONTEXT,
440
+ * ::CUDA_ERROR_INVALID_VALUE
441
+ * \notefnerr
442
+ *
443
+ * \sa ::cuGraphicsUnmapResources
444
+ */
445
+ __CUDA_DEPRECATED CUresult CUDAAPI cuGLUnmapBufferObject(GLuint buffer);
446
+
447
+ /**
448
+ * \brief Unregister an OpenGL buffer object
449
+ *
450
+ * \deprecated This function is deprecated as of Cuda 3.0.
451
+ *
452
+ * Unregisters the buffer object specified by \p buffer. This
453
+ * releases any resources associated with the registered buffer.
454
+ * After this call, the buffer may no longer be mapped for access by
455
+ * CUDA.
456
+ *
457
+ * There must be a valid OpenGL context bound to the current thread
458
+ * when this function is called. This must be the same context, or a
459
+ * member of the same shareGroup, as the context that was bound when
460
+ * the buffer was registered.
461
+ *
462
+ * \param buffer - Name of the buffer object to unregister
463
+ *
464
+ * \return
465
+ * ::CUDA_SUCCESS,
466
+ * ::CUDA_ERROR_DEINITIALIZED,
467
+ * ::CUDA_ERROR_NOT_INITIALIZED,
468
+ * ::CUDA_ERROR_INVALID_CONTEXT,
469
+ * ::CUDA_ERROR_INVALID_VALUE
470
+ * \notefnerr
471
+ *
472
+ * \sa ::cuGraphicsUnregisterResource
473
+ */
474
+ __CUDA_DEPRECATED CUresult CUDAAPI cuGLUnregisterBufferObject(GLuint buffer);
475
+
476
+ /**
477
+ * \brief Set the map flags for an OpenGL buffer object
478
+ *
479
+ * \deprecated This function is deprecated as of Cuda 3.0.
480
+ *
481
+ * Sets the map flags for the buffer object specified by \p buffer.
482
+ *
483
+ * Changes to \p Flags will take effect the next time \p buffer is mapped.
484
+ * The \p Flags argument may be any of the following:
485
+ * - ::CU_GL_MAP_RESOURCE_FLAGS_NONE: Specifies no hints about how this
486
+ * resource will be used. It is therefore assumed that this resource will be
487
+ * read from and written to by CUDA kernels. This is the default value.
488
+ * - ::CU_GL_MAP_RESOURCE_FLAGS_READ_ONLY: Specifies that CUDA kernels which
489
+ * access this resource will not write to this resource.
490
+ * - ::CU_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD: Specifies that CUDA kernels
491
+ * which access this resource will not read from this resource and will
492
+ * write over the entire contents of the resource, so none of the data
493
+ * previously stored in the resource will be preserved.
494
+ *
495
+ * If \p buffer has not been registered for use with CUDA, then
496
+ * ::CUDA_ERROR_INVALID_HANDLE is returned. If \p buffer is presently
497
+ * mapped for access by CUDA, then ::CUDA_ERROR_ALREADY_MAPPED is returned.
498
+ *
499
+ * There must be a valid OpenGL context bound to the current thread
500
+ * when this function is called. This must be the same context, or a
501
+ * member of the same shareGroup, as the context that was bound when
502
+ * the buffer was registered.
503
+ *
504
+ * \param buffer - Buffer object whose map flags are to be set
505
+ * \param Flags - Map flags
506
+ *
507
+ * \return
508
+ * ::CUDA_SUCCESS,
509
+ * ::CUDA_ERROR_NOT_INITIALIZED,
510
+ * ::CUDA_ERROR_INVALID_HANDLE,
511
+ * ::CUDA_ERROR_ALREADY_MAPPED,
512
+ * ::CUDA_ERROR_INVALID_CONTEXT
513
+ * \notefnerr
514
+ *
515
+ * \sa ::cuGraphicsResourceSetMapFlags
516
+ */
517
+ __CUDA_DEPRECATED CUresult CUDAAPI cuGLSetBufferObjectMapFlags(GLuint buffer, unsigned int Flags);
518
+
519
+ /**
520
+ * \brief Maps an OpenGL buffer object
521
+ *
522
+ * \deprecated This function is deprecated as of Cuda 3.0.
523
+ *
524
+ * Maps the buffer object specified by \p buffer into the address space of the
525
+ * current CUDA context and returns in \p *dptr and \p *size the base pointer
526
+ * and size of the resulting mapping.
527
+ *
528
+ * There must be a valid OpenGL context bound to the current thread
529
+ * when this function is called. This must be the same context, or a
530
+ * member of the same shareGroup, as the context that was bound when
531
+ * the buffer was registered.
532
+ *
533
+ * Stream \p hStream in the current CUDA context is synchronized with
534
+ * the current GL context.
535
+ *
536
+ * \param dptr - Returned mapped base pointer
537
+ * \param size - Returned size of mapping
538
+ * \param buffer - The name of the buffer object to map
539
+ * \param hStream - Stream to synchronize
540
+ *
541
+ * \return
542
+ * ::CUDA_SUCCESS,
543
+ * ::CUDA_ERROR_DEINITIALIZED,
544
+ * ::CUDA_ERROR_NOT_INITIALIZED,
545
+ * ::CUDA_ERROR_INVALID_CONTEXT,
546
+ * ::CUDA_ERROR_INVALID_VALUE,
547
+ * ::CUDA_ERROR_MAP_FAILED
548
+ * \notefnerr
549
+ *
550
+ * \sa ::cuGraphicsMapResources
551
+ */
552
+ __CUDA_DEPRECATED CUresult CUDAAPI cuGLMapBufferObjectAsync(CUdeviceptr *dptr, size_t *size, GLuint buffer, CUstream hStream);
553
+
554
+ /**
555
+ * \brief Unmaps an OpenGL buffer object
556
+ *
557
+ * \deprecated This function is deprecated as of Cuda 3.0.
558
+ *
559
+ * Unmaps the buffer object specified by \p buffer for access by CUDA.
560
+ *
561
+ * There must be a valid OpenGL context bound to the current thread
562
+ * when this function is called. This must be the same context, or a
563
+ * member of the same shareGroup, as the context that was bound when
564
+ * the buffer was registered.
565
+ *
566
+ * Stream \p hStream in the current CUDA context is synchronized with
567
+ * the current GL context.
568
+ *
569
+ * \param buffer - Name of the buffer object to unmap
570
+ * \param hStream - Stream to synchronize
571
+ *
572
+ * \return
573
+ * ::CUDA_SUCCESS,
574
+ * ::CUDA_ERROR_DEINITIALIZED,
575
+ * ::CUDA_ERROR_NOT_INITIALIZED,
576
+ * ::CUDA_ERROR_INVALID_CONTEXT,
577
+ * ::CUDA_ERROR_INVALID_VALUE
578
+ * \notefnerr
579
+ *
580
+ * \sa ::cuGraphicsUnmapResources
581
+ */
582
+ __CUDA_DEPRECATED CUresult CUDAAPI cuGLUnmapBufferObjectAsync(GLuint buffer, CUstream hStream);
583
+
584
+ /** @} */ /* END CUDA_GL_DEPRECATED */
585
+ /** @} */ /* END CUDA_GL */
586
+
587
+
588
+ #if defined(__CUDA_API_VERSION_INTERNAL)
589
+ #undef cuGLCtxCreate
590
+ #undef cuGLMapBufferObject
591
+ #undef cuGLMapBufferObjectAsync
592
+ #undef cuGLGetDevices
593
+
594
+ CUresult CUDAAPI cuGLGetDevices(unsigned int *pCudaDeviceCount, CUdevice *pCudaDevices, unsigned int cudaDeviceCount, CUGLDeviceList deviceList);
595
+ CUresult CUDAAPI cuGLMapBufferObject_v2(CUdeviceptr *dptr, size_t *size, GLuint buffer);
596
+ CUresult CUDAAPI cuGLMapBufferObjectAsync_v2(CUdeviceptr *dptr, size_t *size, GLuint buffer, CUstream hStream);
597
+ CUresult CUDAAPI cuGLCtxCreate(CUcontext *pCtx, unsigned int Flags, CUdevice device );
598
+ CUresult CUDAAPI cuGLMapBufferObject(CUdeviceptr_v1 *dptr, unsigned int *size, GLuint buffer);
599
+ CUresult CUDAAPI cuGLMapBufferObjectAsync(CUdeviceptr_v1 *dptr, unsigned int *size, GLuint buffer, CUstream hStream);
600
+ #endif /* __CUDA_API_VERSION_INTERNAL */
601
+
602
+ #ifdef __cplusplus
603
+ };
604
+ #endif
605
+
606
+ #undef __CUDA_DEPRECATED
607
+
608
+ #endif
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/cupti_pcsampling_util.h ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(_CUPTI_PCSAMPLING_UTIL_H_)
2
+ #define _CUPTI_PCSAMPLING_UTIL_H_
3
+
4
+ #include <cupti_pcsampling.h>
5
+ #include <fstream>
6
+
7
+ #include <cupti_common.h>
8
+
9
+ #ifndef CUPTI_UTIL_STRUCT_SIZE
10
+ #define CUPTI_UTIL_STRUCT_SIZE(type_, lastfield_) (offsetof(type_, lastfield_) + sizeof(((type_*)0)->lastfield_))
11
+ #endif
12
+
13
+ #ifndef CHECK_PC_SAMPLING_STRUCT_FIELD_EXISTS
14
+ #define CHECK_PC_SAMPLING_STRUCT_FIELD_EXISTS(type, member, structSize) \
15
+ (offsetof(type, member) < structSize)
16
+ #endif
17
+
18
+ #if defined(__cplusplus)
19
+ extern "C" {
20
+ #endif
21
+
22
+ #if defined(__GNUC__)
23
+ #pragma GCC visibility push(default)
24
+ #endif
25
+
26
+ namespace CUPTI { namespace PcSamplingUtil {
27
+
28
+ /**
29
+ * \defgroup CUPTI_PCSAMPLING_UTILITY CUPTI PC Sampling Utility API
30
+ * Functions, types, and enums that implement the CUPTI PC Sampling Utility API.
31
+ * @{
32
+ */
33
+
34
+ /**
35
+ * \brief Header info will be stored in file.
36
+ */
37
+ typedef struct PACKED_ALIGNMENT {
38
+ /**
39
+ * Version of file format.
40
+ */
41
+ uint32_t version;
42
+ /**
43
+ * Total number of buffers present in the file.
44
+ */
45
+ uint32_t totalBuffers;
46
+ } Header;
47
+
48
+ /**
49
+ * \brief BufferInfo will be stored in the file for every buffer
50
+ * i.e. for every call of the CuptiUtilPutPcSampData() API.
51
+ */
52
+ typedef struct PACKED_ALIGNMENT {
53
+ /**
54
+ * Total number of PC records.
55
+ */
56
+ uint64_t recordCount;
57
+ /**
58
+ * Count of all stall reasons supported on the GPU
59
+ */
60
+ size_t numStallReasons;
61
+ /**
62
+ * Total number of stall reasons in single record.
63
+ */
64
+ uint64_t numSelectedStallReasons;
65
+ /**
66
+ * Buffer size in Bytes.
67
+ */
68
+ uint64_t bufferByteSize;
69
+ } BufferInfo;
70
+
71
+ /**
72
+ * \brief All available stall reasons name and respective indexes
73
+ * will be stored in it.
74
+ */
75
+ typedef struct PACKED_ALIGNMENT {
76
+ /**
77
+ * Number of all available stall reasons
78
+ */
79
+ size_t numStallReasons;
80
+ /**
81
+ * Stall reasons names of all available stall reasons
82
+ */
83
+ char **stallReasons;
84
+ /**
85
+ * Stall reason index of all available stall reasons
86
+ */
87
+ uint32_t *stallReasonIndex;
88
+ } PcSamplingStallReasons;
89
+
90
+ /**
91
+ * \brief CUPTI PC sampling buffer types.
92
+ *
93
+ */
94
+ typedef enum {
95
+ /**
96
+ * Invalid buffer type.
97
+ */
98
+ PC_SAMPLING_BUFFER_INVALID = 0,
99
+ /**
100
+ * Refers to CUpti_PCSamplingData buffer.
101
+ */
102
+ PC_SAMPLING_BUFFER_PC_TO_COUNTER_DATA = 1
103
+ } PcSamplingBufferType;
104
+
105
+ /**
106
+ * \brief CUPTI PC sampling utility API result codes.
107
+ *
108
+ * Error and result codes returned by CUPTI PC sampling utility API.
109
+ */
110
+ typedef enum {
111
+ /**
112
+ * No error
113
+ */
114
+ CUPTI_UTIL_SUCCESS = 0,
115
+ /**
116
+ * One or more of the parameters are invalid.
117
+ */
118
+ CUPTI_UTIL_ERROR_INVALID_PARAMETER = 1,
119
+ /**
120
+ * Unable to create a new file
121
+ */
122
+ CUPTI_UTIL_ERROR_UNABLE_TO_CREATE_FILE = 2,
123
+ /**
124
+ * Unable to open a file
125
+ */
126
+ CUPTI_UTIL_ERROR_UNABLE_TO_OPEN_FILE = 3,
127
+ /**
128
+ * Read or write operation failed
129
+ */
130
+ CUPTI_UTIL_ERROR_READ_WRITE_OPERATION_FAILED = 4,
131
+ /**
132
+ * Provided file handle is corrupted.
133
+ */
134
+ CUPTI_UTIL_ERROR_FILE_HANDLE_CORRUPTED = 5,
135
+ /**
136
+ * seek operation failed.
137
+ */
138
+ CUPTI_UTIL_ERROR_SEEK_OPERATION_FAILED = 6,
139
+ /**
140
+ * Unable to allocate enough memory to perform the requested
141
+ * operation.
142
+ */
143
+ CUPTI_UTIL_ERROR_OUT_OF_MEMORY = 7,
144
+ /**
145
+ * An unknown internal error has occurred.
146
+ */
147
+ CUPTI_UTIL_ERROR_UNKNOWN = 999,
148
+ CUPTI_UTIL_ERROR_FORCE_INT = 0x7fffffff
149
+ } CUptiUtilResult;
150
+
151
+ /**
152
+ * \brief Params for \ref CuptiUtilPutPcSampData
153
+ */
154
+ typedef struct {
155
+ /**
156
+ * Size of the data structure i.e. CUptiUtil_PutPcSampDataParamsSize
157
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
158
+ * available in the structure. Used to preserve backward compatibility.
159
+ */
160
+ size_t size;
161
+ /**
162
+ * Type of buffer to store in file
163
+ */
164
+ PcSamplingBufferType bufferType;
165
+ /**
166
+ * PC sampling buffer.
167
+ */
168
+ void *pSamplingData;
169
+ /**
170
+ * Number of configured attributes
171
+ */
172
+ size_t numAttributes;
173
+ /**
174
+ * Refer \ref CUpti_PCSamplingConfigurationInfo
175
+ * It is expected to provide configuration details of at least
176
+ * CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_STALL_REASON attribute.
177
+ */
178
+ CUpti_PCSamplingConfigurationInfo *pPCSamplingConfigurationInfo;
179
+ /**
180
+ * Refer \ref PcSamplingStallReasons.
181
+ */
182
+ PcSamplingStallReasons *pPcSamplingStallReasons;
183
+ /**
184
+ * File name to store buffer into it.
185
+ */
186
+ const char* fileName;
187
+ } CUptiUtil_PutPcSampDataParams;
188
+ #define CUptiUtil_PutPcSampDataParamsSize CUPTI_UTIL_STRUCT_SIZE(CUptiUtil_PutPcSampDataParams, fileName)
189
+
190
+ /**
191
+ * \brief Dump PC sampling data into the file.
192
+ *
193
+ * This API can be called multiple times.
194
+ * It will append buffer in the file.
195
+ * For every buffer it will store BufferInfo
196
+ * so that before retrieving data it will help to allocate buffer
197
+ * to store retrieved data.
198
+ * This API creates the file if it does not already exist.
199
+ * If stallReasonIndex or stallReasons pointer of \ref CUptiUtil_PutPcSampDataParams is NULL
200
+ * then stall reasons data will not be stored in file.
201
+ * It is expected to store all available stall reason data at least once to refer it during
202
+ * offline correlation.
203
+ *
204
+ * \retval CUPTI_UTIL_SUCCESS
205
+ * \retval CUPTI_UTIL_ERROR_INVALID_PARAMETER error out if buffer type is invalid
206
+ * or if either of pSamplingData, pParams pointer is NULL or stall reason configuration details not provided
207
+ * or filename is empty.
208
+ * \retval CUPTI_UTIL_ERROR_UNABLE_TO_CREATE_FILE
209
+ * \retval CUPTI_UTIL_ERROR_UNABLE_TO_OPEN_FILE
210
+ * \retval CUPTI_UTIL_ERROR_READ_WRITE_OPERATION_FAILED
211
+ */
212
+ CUptiUtilResult CUPTIUTILAPI CuptiUtilPutPcSampData(CUptiUtil_PutPcSampDataParams *pParams);
213
+
214
+ /**
215
+ * \brief Params for \ref CuptiUtilGetHeaderData
216
+ */
217
+ typedef struct {
218
+ /**
219
+ * Size of the data structure i.e. CUptiUtil_GetHeaderDataParamsSize
220
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
221
+ * available in the structure. Used to preserve backward compatibility.
222
+ */
223
+ size_t size;
224
+ /**
225
+ * File handle.
226
+ */
227
+ std::ifstream *fileHandler;
228
+ /**
229
+ * Header Info.
230
+ */
231
+ Header headerInfo;
232
+
233
+ } CUptiUtil_GetHeaderDataParams;
234
+ #define CUptiUtil_GetHeaderDataParamsSize CUPTI_UTIL_STRUCT_SIZE(CUptiUtil_GetHeaderDataParams, headerInfo)
235
+
236
+ /**
237
+ * \brief Get header data of file.
238
+ *
239
+ * This API must be called once initially while retrieving data from file.
240
+ * \ref Header structure, it gives info about total number
241
+ * of buffers present in the file.
242
+ *
243
+ * \retval CUPTI_UTIL_SUCCESS
244
+ * \retval CUPTI_UTIL_ERROR_INVALID_PARAMETER error out if either of pParam or fileHandle is NULL or param struct size is incorrect.
245
+ * \retval CUPTI_UTIL_ERROR_FILE_HANDLE_CORRUPTED file handle is not in good state to read data from file
246
+ * \retval CUPTI_UTIL_ERROR_READ_WRITE_OPERATION_FAILED failed to read data from file.
247
+ */
248
+ CUptiUtilResult CUPTIUTILAPI CuptiUtilGetHeaderData(CUptiUtil_GetHeaderDataParams *pParams);
249
+
250
+ /**
251
+ * \brief Params for \ref CuptiUtilGetBufferInfo
252
+ */
253
+ typedef struct {
254
+ /**
255
+ * Size of the data structure i.e. CUptiUtil_GetBufferInfoParamsSize
256
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
257
+ * available in the structure. Used to preserve backward compatibility.
258
+ */
259
+ size_t size;
260
+ /**
261
+ * File handle.
262
+ */
263
+ std::ifstream *fileHandler;
264
+ /**
265
+ * Buffer Info.
266
+ */
267
+ BufferInfo bufferInfoData;
268
+ } CUptiUtil_GetBufferInfoParams;
269
+ #define CUptiUtil_GetBufferInfoParamsSize CUPTI_UTIL_STRUCT_SIZE(CUptiUtil_GetBufferInfoParams, bufferInfoData)
270
+
271
+ /**
272
+ * \brief Get buffer info data of file.
273
+ *
274
+ * This API must be called every time before calling CuptiUtilGetPcSampData API.
275
+ * The \ref BufferInfo structure gives info about recordCount and numSelectedStallReasons
276
+ * of every record in the buffer. This will help to allocate exact buffer to retrieve data into it.
277
+ *
278
+ * \retval CUPTI_UTIL_SUCCESS
279
+ * \retval CUPTI_UTIL_ERROR_INVALID_PARAMETER error out if either pParams or fileHandler is NULL, or if the param struct size is incorrect.
280
+ * \retval CUPTI_UTIL_ERROR_FILE_HANDLE_CORRUPTED file handle is not in good state to read data from file.
281
+ * \retval CUPTI_UTIL_ERROR_READ_WRITE_OPERATION_FAILED failed to read data from file.
282
+ */
283
+ CUptiUtilResult CUPTIUTILAPI CuptiUtilGetBufferInfo(CUptiUtil_GetBufferInfoParams *pParams);
284
+
285
+ /**
286
+ * \brief Params for \ref CuptiUtilGetPcSampData
287
+ */
288
+ typedef struct {
289
+ /**
290
+ * Size of the data structure i.e. CUptiUtil_GetPcSampDataParamsSize
291
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
292
+ * available in the structure. Used to preserve backward compatibility.
293
+ */
294
+ size_t size;
295
+ /**
296
+ * File handle.
297
+ */
298
+ std::ifstream *fileHandler;
299
+ /**
300
+ * Type of buffer stored in the file
301
+ */
302
+ PcSamplingBufferType bufferType;
303
+ /**
304
+ * Pointer to collected buffer info using \ref CuptiUtilGetBufferInfo
305
+ */
306
+ BufferInfo *pBufferInfoData;
307
+ /**
308
+ * Pointer to allocated memory to store retrieved data from file.
309
+ */
310
+ void *pSamplingData;
311
+ /**
312
+ * Number of configuration attributes
313
+ */
314
+ size_t numAttributes;
315
+ /**
316
+ * Refer \ref CUpti_PCSamplingConfigurationInfo
317
+ */
318
+ CUpti_PCSamplingConfigurationInfo *pPCSamplingConfigurationInfo;
319
+ /**
320
+ * Refer \ref PcSamplingStallReasons.
321
+ * For stallReasons field of \ref PcSamplingStallReasons it is expected to
322
+ * allocate memory for each string element of array.
323
+ */
324
+ PcSamplingStallReasons *pPcSamplingStallReasons;
325
+ } CUptiUtil_GetPcSampDataParams;
326
+ #define CUptiUtil_GetPcSampDataParamsSize CUPTI_UTIL_STRUCT_SIZE(CUptiUtil_GetPcSampDataParams, pPcSamplingStallReasons)
327
+
328
+ /**
329
+ * \brief Retrieve PC sampling data from file into allocated buffer.
330
+ *
331
+ * This API must be called after CuptiUtilGetBufferInfo API.
332
+ * It will retrieve data from file into allocated buffer.
333
+ *
334
+ * \retval CUPTI_UTIL_SUCCESS
335
+ * \retval CUPTI_UTIL_ERROR_INVALID_PARAMETER error out if buffer type is invalid
336
+ * or if either of pSamplingData, pParams is NULL. If pPcSamplingStallReasons is not NULL then
337
+ * error out if either of stallReasonIndex, stallReasons or stallReasons array element pointer is NULL.
338
+ * or filename is empty.
339
+ * \retval CUPTI_UTIL_ERROR_READ_WRITE_OPERATION_FAILED
340
+ * \retval CUPTI_UTIL_ERROR_FILE_HANDLE_CORRUPTED file handle is not in good state to read data from file.
341
+ */
342
+ CUptiUtilResult CUPTIUTILAPI CuptiUtilGetPcSampData(CUptiUtil_GetPcSampDataParams *pParams);
343
+
344
+ /**
345
+ * \brief Params for \ref CuptiUtilMergePcSampData
346
+ */
347
+ typedef struct
348
+ {
349
+ /**
350
+ * Size of the data structure i.e. CUptiUtil_MergePcSampDataParamsSize
351
+ * CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
352
+ * available in the structure. Used to preserve backward compatibility.
353
+ */
354
+ size_t size;
355
+ /**
356
+ * Number of buffers to merge.
357
+ */
358
+ size_t numberOfBuffers;
359
+ /**
360
+ * Pointer to array of buffers to merge
361
+ */
362
+ CUpti_PCSamplingData *PcSampDataBuffer;
363
+ /**
364
+ * Pointer to array of merged buffers as per the range id.
365
+ */
366
+ CUpti_PCSamplingData **MergedPcSampDataBuffers;
367
+ /**
368
+ * Number of merged buffers.
369
+ */
370
+ size_t *numMergedBuffer;
371
+ } CUptiUtil_MergePcSampDataParams;
372
+ #define CUptiUtil_MergePcSampDataParamsSize CUPTI_UTIL_STRUCT_SIZE(CUptiUtil_MergePcSampDataParams, numMergedBuffer)
373
+
374
+ /**
375
+ * \brief Merge PC sampling data range id wise.
376
+ *
377
+ * This API merges PC sampling data range id wise.
378
+ * It allocates memory for the merged data, fills the data into it,
379
+ * and provides the buffer pointer in the MergedPcSampDataBuffers field.
380
+ * The user is expected to free the merged data buffers after use.
381
+ *
382
+ * \retval CUPTI_UTIL_SUCCESS
383
+ * \retval CUPTI_UTIL_ERROR_INVALID_PARAMETER error out if param struct size is invalid
384
+ * or count of buffers to merge is invalid i.e. less than 1
385
+ * or either of PcSampDataBuffer, MergedPcSampDataBuffers, numMergedBuffer is NULL
386
+ * \retval CUPTI_UTIL_ERROR_OUT_OF_MEMORY Unable to allocate memory for merged buffer.
387
+ */
388
+ CUptiUtilResult CUPTIUTILAPI CuptiUtilMergePcSampData(CUptiUtil_MergePcSampDataParams *pParams);
389
+
390
+ /** @} */ /* END CUPTI_PCSAMPLING_UTILITY */
391
+
392
+ } }
393
+
394
+ #if defined(__GNUC__)
395
+ #pragma GCC visibility pop
396
+ #endif
397
+
398
+ #if defined(__cplusplus)
399
+ }
400
+ #endif
401
+
402
+ #endif
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/driver_types.h ADDED
The diff for this file is too large to render. See raw diff
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/generated_cudaVDPAU_meta.h ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // This file is generated. Any changes you make will be lost during the next clean build.
2
+
3
+ // Dependent includes
4
+ #include <vdpau/vdpau.h>
5
+
6
+ // CUDA public interface, for type definitions and cu* function prototypes
7
+ #include "cudaVDPAU.h"
8
+
9
+
10
+ // *************************************************************************
11
+ // Definitions of structs to hold parameters for each function
12
+ // *************************************************************************
13
+
14
+ typedef struct cuVDPAUGetDevice_params_st {
15
+ CUdevice *pDevice;
16
+ VdpDevice vdpDevice;
17
+ VdpGetProcAddress *vdpGetProcAddress;
18
+ } cuVDPAUGetDevice_params;
19
+
20
+ typedef struct cuVDPAUCtxCreate_v2_params_st {
21
+ CUcontext *pCtx;
22
+ unsigned int flags;
23
+ CUdevice device;
24
+ VdpDevice vdpDevice;
25
+ VdpGetProcAddress *vdpGetProcAddress;
26
+ } cuVDPAUCtxCreate_v2_params;
27
+
28
+ typedef struct cuGraphicsVDPAURegisterVideoSurface_params_st {
29
+ CUgraphicsResource *pCudaResource;
30
+ VdpVideoSurface vdpSurface;
31
+ unsigned int flags;
32
+ } cuGraphicsVDPAURegisterVideoSurface_params;
33
+
34
+ typedef struct cuGraphicsVDPAURegisterOutputSurface_params_st {
35
+ CUgraphicsResource *pCudaResource;
36
+ VdpOutputSurface vdpSurface;
37
+ unsigned int flags;
38
+ } cuGraphicsVDPAURegisterOutputSurface_params;
39
+
40
+ typedef struct cuVDPAUCtxCreate_params_st {
41
+ CUcontext *pCtx;
42
+ unsigned int flags;
43
+ CUdevice device;
44
+ VdpDevice vdpDevice;
45
+ VdpGetProcAddress *vdpGetProcAddress;
46
+ } cuVDPAUCtxCreate_params;
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/nvperf_target.h ADDED
@@ -0,0 +1,626 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ifndef NVPERF_TARGET_H
2
+ #define NVPERF_TARGET_H
3
+
4
+ /*
5
+ * Copyright 2014-2024 NVIDIA Corporation. All rights reserved.
6
+ *
7
+ * NOTICE TO USER:
8
+ *
9
+ * This source code is subject to NVIDIA ownership rights under U.S. and
10
+ * international Copyright laws.
11
+ *
12
+ * This software and the information contained herein is PROPRIETARY and
13
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and conditions
14
+ * of a form of NVIDIA software license agreement.
15
+ *
16
+ * NVIDIA MAKES NO REPRESENTATION ABOUT THE SUITABILITY OF THIS SOURCE
17
+ * CODE FOR ANY PURPOSE. IT IS PROVIDED "AS IS" WITHOUT EXPRESS OR
18
+ * IMPLIED WARRANTY OF ANY KIND. NVIDIA DISCLAIMS ALL WARRANTIES WITH
19
+ * REGARD TO THIS SOURCE CODE, INCLUDING ALL IMPLIED WARRANTIES OF
20
+ * MERCHANTABILITY, NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
21
+ * IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL,
22
+ * OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
23
+ * OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
24
+ * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE
25
+ * OR PERFORMANCE OF THIS SOURCE CODE.
26
+ *
27
+ * U.S. Government End Users. This source code is a "commercial item" as
28
+ * that term is defined at 48 C.F.R. 2.101 (OCT 1995), consisting of
29
+ * "commercial computer software" and "commercial computer software
30
+ * documentation" as such terms are used in 48 C.F.R. 12.212 (SEPT 1995)
31
+ * and is provided to the U.S. Government only as a commercial end item.
32
+ * Consistent with 48 C.F.R.12.212 and 48 C.F.R. 227.7202-1 through
33
+ * 227.7202-4 (JUNE 1995), all U.S. Government End Users acquire the
34
+ * source code with only those rights set forth herein.
35
+ *
36
+ * Any use of this source code in individual and commercial software must
37
+ * include, in the user documentation and internal comments to the code,
38
+ * the above Disclaimer and U.S. Government End Users Notice.
39
+ */
40
+
41
+ #include <stddef.h>
42
+ #include <stdint.h>
43
+ #include "nvperf_common.h"
44
+
45
+ #if defined(__GNUC__) && defined(NVPA_SHARED_LIB)
46
+ #pragma GCC visibility push(default)
47
+ #if !defined(NVPW_LOCAL)
48
+ #define NVPW_LOCAL __attribute__ ((visibility ("hidden")))
49
+ #endif
50
+ #else
51
+ #if !defined(NVPW_LOCAL)
52
+ #define NVPW_LOCAL
53
+ #endif
54
+ #endif
55
+
56
+ #ifdef __cplusplus
57
+ extern "C" {
58
+ #endif
59
+
60
+ /**
61
+ * @file nvperf_target.h
62
+ */
63
+
64
+ #ifndef NVPW_GPU_ARCHITECTURE_SUPPORT_LEVEL_DEFINED
65
+ #define NVPW_GPU_ARCHITECTURE_SUPPORT_LEVEL_DEFINED
66
+ /// GPU architecture support level
67
+ typedef enum NVPW_GpuArchitectureSupportLevel
68
+ {
69
+ NVPW_GPU_ARCHITECTURE_SUPPORT_LEVEL_UNKNOWN = 0,
70
+ NVPW_GPU_ARCHITECTURE_SUPPORT_LEVEL_UNSUPPORTED,
71
+ NVPW_GPU_ARCHITECTURE_SUPPORT_LEVEL_SUPPORTED
72
+ } NVPW_GpuArchitectureSupportLevel;
73
+ #endif //NVPW_GPU_ARCHITECTURE_SUPPORT_LEVEL_DEFINED
74
+
75
+ #ifndef NVPW_SLI_SUPPORT_LEVEL_DEFINED
76
+ #define NVPW_SLI_SUPPORT_LEVEL_DEFINED
77
+ /// SLI configuration support level
78
+ typedef enum NVPW_SliSupportLevel
79
+ {
80
+ NVPW_SLI_SUPPORT_LEVEL_UNKNOWN = 0,
81
+ NVPW_SLI_SUPPORT_LEVEL_UNSUPPORTED,
82
+ /// Only Non-SLI configurations are supported.
83
+ NVPW_SLI_SUPPORT_LEVEL_SUPPORTED_NON_SLI_CONFIGURATION
84
+ } NVPW_SliSupportLevel;
85
+ #endif //NVPW_SLI_SUPPORT_LEVEL_DEFINED
86
+
87
+ #ifndef NVPW_VGPU_SUPPORT_LEVEL_DEFINED
88
+ #define NVPW_VGPU_SUPPORT_LEVEL_DEFINED
89
+ /// Virtualized GPU configuration support level
90
+ typedef enum NVPW_VGpuSupportLevel
91
+ {
92
+ NVPW_VGPU_SUPPORT_LEVEL_UNKNOWN = 0,
93
+ NVPW_VGPU_SUPPORT_LEVEL_UNSUPPORTED,
94
+ /// Supported but not allowed by system admin.
95
+ NVPW_VGPU_SUPPORT_LEVEL_SUPPORTED_DISALLOWED,
96
+ NVPW_VGPU_SUPPORT_LEVEL_SUPPORTED_ALLOWED,
97
+ NVPW_VGPU_SUPPORT_LEVEL_SUPPORTED_NON_VGPU_CONFIGURATION
98
+ } NVPW_VGpuSupportLevel;
99
+ #endif //NVPW_VGPU_SUPPORT_LEVEL_DEFINED
100
+
101
+ #ifndef NVPW_CONF_COMPUTE_SUPPORT_LEVEL_DEFINED
102
+ #define NVPW_CONF_COMPUTE_SUPPORT_LEVEL_DEFINED
103
+ /// Confidential Compute mode support level
104
+ typedef enum NVPW_ConfidentialComputeSupportLevel
105
+ {
106
+ NVPW_CONF_COMPUTE_SUPPORT_LEVEL_UNKNOWN = 0,
107
+ NVPW_CONF_COMPUTE_SUPPORT_LEVEL_UNSUPPORTED,
108
+ NVPW_CONF_COMPUTE_SUPPORT_LEVEL_SUPPORTED_NON_CONF_COMPUTE_CONFIGURATION,
109
+ NVPW_CONF_COMPUTE_SUPPORT_LEVEL_SUPPORTED_CONF_COMPUTE_DEVTOOLS_MODE
110
+ } NVPW_ConfidentialComputeSupportLevel;
111
+ #endif //NVPW_CONF_COMPUTE_SUPPORT_LEVEL_DEFINED
112
+
113
+ #ifndef NVPW_CMP_SUPPORT_LEVEL_DEFINED
114
+ #define NVPW_CMP_SUPPORT_LEVEL_DEFINED
115
+ /// CMP support level
116
+ typedef enum NVPW_CmpSupportLevel
117
+ {
118
+ NVPW_CMP_SUPPORT_LEVEL_UNKNOWN = 0,
119
+ NVPW_CMP_SUPPORT_LEVEL_UNSUPPORTED,
120
+ NVPW_CMP_SUPPORT_LEVEL_SUPPORTED_NON_CMP_CONFIGURATON
121
+ } NVPW_CmpSupportLevel;
122
+ #endif //NVPW_CMP_SUPPORT_LEVEL_DEFINED
123
+
124
+ #ifndef NVPW_WSL_SUPPORT_LEVEL_DEFINED
125
+ #define NVPW_WSL_SUPPORT_LEVEL_DEFINED
126
+ /// WSL support level
127
+ typedef enum NVPW_WslSupportLevel
128
+ {
129
+ NVPW_WSL_SUPPORT_LEVEL_UNKNOWN = 0,
130
+ NVPW_WSL_SUPPORT_LEVEL_UNSUPPORTED_INSUFFICIENT_DRIVER_VERSION,
131
+ NVPW_WSL_SUPPORT_LEVEL_SUPPORTED,
132
+ NVPW_WSL_SUPPORT_LEVEL_SUPPORTED_NON_WSL_CONFIGURATION
133
+ } NVPW_WslSupportLevel;
134
+ #endif //NVPW_WSL_SUPPORT_LEVEL_DEFINED
135
+
136
+ #ifndef NVPW_MIG_SUPPORT_LEVEL_DEFINED
137
+ #define NVPW_MIG_SUPPORT_LEVEL_DEFINED
138
+ /// MIG support level
139
+ typedef enum NVPW_MigSupportLevel
140
+ {
141
+ NVPW_MIG_SUPPORT_LEVEL_UNKNOWN = 0,
142
+ NVPW_MIG_SUPPORT_LEVEL_UNSUPPORTED,
143
+ NVPW_MIG_SUPPORT_LEVEL_SUPPORTED,
144
+ NVPW_MIG_SUPPORT_LEVEL_SUPPORTED_NON_MIG_CONFIGURATION
145
+ } NVPW_MigSupportLevel;
146
+ #endif //NVPW_MIG_SUPPORT_LEVEL_DEFINED
147
+
148
+ typedef struct NVPW_InitializeTarget_Params
149
+ {
150
+ /// [in]
151
+ size_t structSize;
152
+ /// [in] assign to NULL
153
+ void* pPriv;
154
+ } NVPW_InitializeTarget_Params;
155
+ #define NVPW_InitializeTarget_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_InitializeTarget_Params, pPriv)
156
+
157
+ /// Load the target library.
158
+ NVPA_Status NVPW_InitializeTarget(NVPW_InitializeTarget_Params* pParams);
159
+
160
+ typedef struct NVPW_GetDeviceCount_Params
161
+ {
162
+ /// [in]
163
+ size_t structSize;
164
+ /// [in] assign to NULL
165
+ void* pPriv;
166
+ size_t numDevices;
167
+ } NVPW_GetDeviceCount_Params;
168
+ #define NVPW_GetDeviceCount_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_GetDeviceCount_Params, numDevices)
169
+
170
+ NVPA_Status NVPW_GetDeviceCount(NVPW_GetDeviceCount_Params* pParams);
171
+
172
+ typedef struct NVPW_Device_GetNames_Params
173
+ {
174
+ /// [in]
175
+ size_t structSize;
176
+ /// [in] assign to NULL
177
+ void* pPriv;
178
+ size_t deviceIndex;
179
+ const char* pDeviceName;
180
+ const char* pChipName;
181
+ } NVPW_Device_GetNames_Params;
182
+ #define NVPW_Device_GetNames_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Device_GetNames_Params, pChipName)
183
+
184
+ NVPA_Status NVPW_Device_GetNames(NVPW_Device_GetNames_Params* pParams);
185
+
186
+ typedef struct NVPW_PciBusId
187
+ {
188
+ /// The PCI domain on which the device bus resides.
189
+ uint32_t domain;
190
+ /// The bus on which the device resides.
191
+ uint16_t bus;
192
+ /// device ID.
193
+ uint16_t device;
194
+ } NVPW_PciBusId;
195
+ #define NVPW_PciBusId_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_PciBusId, device)
196
+
197
+ typedef struct NVPW_Device_GetPciBusIds_Params
198
+ {
199
+ /// [in]
200
+ size_t structSize;
201
+ /// [in] assign to NULL
202
+ void* pPriv;
203
+ /// [in] caller-allocated array of NVPW_PciBusId, indexed by NVPW deviceIndex
204
+ NVPW_PciBusId* pBusIds;
205
+ /// [in] size of the pBusIDs array; use result from NVPW_GetDeviceCount
206
+ size_t numDevices;
207
+ } NVPW_Device_GetPciBusIds_Params;
208
+ #define NVPW_Device_GetPciBusIds_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Device_GetPciBusIds_Params, numDevices)
209
+
210
+ NVPA_Status NVPW_Device_GetPciBusIds(NVPW_Device_GetPciBusIds_Params* pParams);
211
+
212
+
213
+ #define NVPW_DEVICE_MIG_GPU_INSTANCE_ID_INVALID 0xFFFFFFFFu
214
+ #define NVPW_DEVICE_MIG_GPU_INSTANCE_ID_FULLCHIP 0xFFFFFFFEu
215
+
216
+
217
+ typedef struct NVPW_Device_GetMigAttributes_Params
218
+ {
219
+ /// [in]
220
+ size_t structSize;
221
+ /// [in] assign to NULL
222
+ void* pPriv;
223
+ /// [in]
224
+ size_t deviceIndex;
225
+ /// [out]
226
+ NVPA_Bool isMigPartition;
227
+ /// [out]
228
+ uint32_t gpuInstanceId;
229
+ /// [out]
230
+ uint32_t computeInstanceId;
231
+ } NVPW_Device_GetMigAttributes_Params;
232
+ #define NVPW_Device_GetMigAttributes_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Device_GetMigAttributes_Params, computeInstanceId)
233
+
234
+ NVPA_Status NVPW_Device_GetMigAttributes(NVPW_Device_GetMigAttributes_Params* pParams);
235
+
236
+ typedef struct NVPW_Adapter_GetDeviceIndex_Params
237
+ {
238
+ /// [in]
239
+ size_t structSize;
240
+ /// [in] assign to NULL
241
+ void* pPriv;
242
+ /// [in]
243
+ struct IDXGIAdapter* pAdapter;
244
+ /// [in]
245
+ size_t sliIndex;
246
+ /// [out]
247
+ size_t deviceIndex;
248
+ } NVPW_Adapter_GetDeviceIndex_Params;
249
+ #define NVPW_Adapter_GetDeviceIndex_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Adapter_GetDeviceIndex_Params, deviceIndex)
250
+
251
+ NVPA_Status NVPW_Adapter_GetDeviceIndex(NVPW_Adapter_GetDeviceIndex_Params* pParams);
252
+
253
+ typedef struct NVPW_CounterData_GetNumRanges_Params
254
+ {
255
+ /// [in]
256
+ size_t structSize;
257
+ /// [in] assign to NULL
258
+ void* pPriv;
259
+ const uint8_t* pCounterDataImage;
260
+ size_t numRanges;
261
+ } NVPW_CounterData_GetNumRanges_Params;
262
+ #define NVPW_CounterData_GetNumRanges_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_CounterData_GetNumRanges_Params, numRanges)
263
+
264
+ NVPA_Status NVPW_CounterData_GetNumRanges(NVPW_CounterData_GetNumRanges_Params* pParams);
265
+
266
+ typedef struct NVPW_CounterData_GetChipName_Params
267
+ {
268
+ /// [in]
269
+ size_t structSize;
270
+ /// [in] assign to NULL
271
+ void* pPriv;
272
+ /// [in]
273
+ const uint8_t* pCounterDataImage;
274
+ /// [in]
275
+ size_t counterDataImageSize;
276
+ /// [out]
277
+ const char* pChipName;
278
+ } NVPW_CounterData_GetChipName_Params;
279
+ #define NVPW_CounterData_GetChipName_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_CounterData_GetChipName_Params, pChipName)
280
+
281
+ NVPA_Status NVPW_CounterData_GetChipName(NVPW_CounterData_GetChipName_Params* pParams);
282
+
283
+ typedef struct NVPW_Config_GetNumPasses_Params
284
+ {
285
+ /// [in]
286
+ size_t structSize;
287
+ /// [in] assign to NULL
288
+ void* pPriv;
289
+ /// [in]
290
+ const uint8_t* pConfig;
291
+ /// [out]
292
+ size_t numPipelinedPasses;
293
+ /// [out]
294
+ size_t numIsolatedPasses;
295
+ } NVPW_Config_GetNumPasses_Params;
296
+ #define NVPW_Config_GetNumPasses_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Config_GetNumPasses_Params, numIsolatedPasses)
297
+
298
+ /// Total num passes = numPipelinedPasses + numIsolatedPasses * numNestingLevels
299
+ NVPA_Status NVPW_Config_GetNumPasses(NVPW_Config_GetNumPasses_Params* pParams);
300
+
301
+ typedef struct NVPW_Config_GetNumPasses_V2_Params
302
+ {
303
+ /// [in]
304
+ size_t structSize;
305
+ /// [in] assign to NULL
306
+ void* pPriv;
307
+ /// [in]
308
+ const uint8_t* pConfig;
309
+ /// [out]
310
+ size_t numPasses;
311
+ } NVPW_Config_GetNumPasses_V2_Params;
312
+ #define NVPW_Config_GetNumPasses_V2_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Config_GetNumPasses_V2_Params, numPasses)
313
+
314
+ /// Total num passes = numPasses * numNestingLevels
315
+ NVPA_Status NVPW_Config_GetNumPasses_V2(NVPW_Config_GetNumPasses_V2_Params* pParams);
316
+
317
+ #define NVPW_API_SET_CUDA_PROFILER 0x18209d0775b2f89dULL
318
+
319
+ #define NVPW_API_SET_D3D11_PROFILER 0xca55c6738445db2bULL
320
+
321
+ #define NVPW_API_SET_D3D12_PROFILER 0xc0c2d46dd7c7ad78ULL
322
+
323
+ #define NVPW_API_SET_EGL_PROFILER 0x3c3747dae1f9565cULL
324
+
325
+ #define NVPW_API_SET_GPU_PERIODICSAMPLER 0x9f4c2571fc0b2e8aULL
326
+
327
+ #define NVPW_API_SET_METRICSEVALUATOR 0x0368a8768d811af9ULL
328
+
329
+ #define NVPW_API_SET_METRICS_AD10X_COMP 0xbe57278e12cb5288ULL
330
+
331
+ #define NVPW_API_SET_METRICS_AD10X_GRFX 0x5cbf0774f81bf491ULL
332
+
333
+ #define NVPW_API_SET_METRICS_GA100_COMP 0x16b7d8c20d8b4915ULL
334
+
335
+ #define NVPW_API_SET_METRICS_GA100_GRFX 0xc94eaabec04a94faULL
336
+
337
+ #define NVPW_API_SET_METRICS_GA10X_COMP 0xb5d6391c2e299ab5ULL
338
+
339
+ #define NVPW_API_SET_METRICS_GA10X_GRFX 0x6ebc121178b5ce0bULL
340
+
341
+ #define NVPW_API_SET_METRICS_GV100_COMP 0x863705cc57919f72ULL
342
+
343
+ #define NVPW_API_SET_METRICS_GV100_GRFX 0x9900da75d164fecfULL
344
+
345
+ #define NVPW_API_SET_METRICS_GV11B_COMP 0xd3f79a859235848fULL
346
+
347
+ #define NVPW_API_SET_METRICS_GV11B_GRFX 0xeb8e26220106e227ULL
348
+
349
+ #define NVPW_API_SET_METRICS_TU10X_COMP 0x70f40be0afd35da8ULL
350
+
351
+ #define NVPW_API_SET_METRICS_TU10X_GRFX 0xdf219cb838db6968ULL
352
+
353
+ #define NVPW_API_SET_METRICS_TU11X_COMP 0xeb0069d7d0956678ULL
354
+
355
+ #define NVPW_API_SET_METRICS_TU11X_GRFX 0x0977d9342bd62743ULL
356
+
357
+ #define NVPW_API_SET_OPENGL_PROFILER 0xe4cd9ea40f2ee777ULL
358
+
359
+ #define NVPW_API_SET_VULKAN_PROFILER 0x8c56b6a03d779689ULL
360
+
361
+ #define NVPW_SDK_VERSION 0x1e128b6f001423fcULL
362
+
363
+ typedef struct NVPW_QueryVersionNumber_Params
364
+ {
365
+ /// [in]
366
+ size_t structSize;
367
+ /// [in] assign to NULL
368
+ void* pPriv;
369
+ /// [in]
370
+ uint64_t apiSet;
371
+ /// [out]
372
+ uint32_t major;
373
+ /// [out]
374
+ uint32_t minor;
375
+ /// [out]
376
+ uint32_t patch;
377
+ /// [out]
378
+ uint32_t relMajor;
379
+ /// [out]
380
+ uint32_t relMinor;
381
+ /// [out]
382
+ uint32_t relPatch;
383
+ } NVPW_QueryVersionNumber_Params;
384
+ #define NVPW_QueryVersionNumber_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_QueryVersionNumber_Params, relPatch)
385
+
386
+ /// Query version number of an API set
387
+ NVPA_Status NVPW_QueryVersionNumber(NVPW_QueryVersionNumber_Params* pParams);
388
+
389
+ typedef enum NVPW_Device_ClockStatus
390
+ {
391
+ /// clock status is unknown
392
+ NVPW_DEVICE_CLOCK_STATUS_UNKNOWN,
393
+ /// clocks are locked to rated tdp values - Deprecated, use NVPW_DEVICE_CLOCK_STATUS_LOCKED instead
394
+ NVPW_DEVICE_CLOCK_STATUS_LOCKED_TO_RATED_TDP,
395
+ /// clocks are not locked and can boost above rated tdp
396
+ NVPW_DEVICE_CLOCK_STATUS_BOOST_ENABLED,
397
+ /// clocks are not locked and will not go above rated tdp
398
+ NVPW_DEVICE_CLOCK_STATUS_BOOST_DISABLED,
399
+ /// clocks are locked
400
+ NVPW_DEVICE_CLOCK_STATUS_LOCKED,
401
+ /// clocks are not locked
402
+ NVPW_DEVICE_CLOCK_STATUS_UNLOCKED,
403
+ NVPW_DEVICE_CLOCK_STATUS__COUNT
404
+ } NVPW_Device_ClockStatus;
405
+
406
+ typedef enum NVPW_Device_ClockLevel
407
+ {
408
+ /// clock level is invalid
409
+ NVPW_DEVICE_CLOCK_LEVEL_INVALID,
410
+ /// clock level is at rated tdp
411
+ NVPW_DEVICE_CLOCK_LEVEL_RATED_TDP,
412
+ /// clock level is at turbo boost
413
+ NVPW_DEVICE_CLOCK_LEVEL_TURBO_BOOST,
414
+ NVPW_DEVICE_CLOCK_LEVEL__COUNT
415
+ } NVPW_Device_ClockLevel;
416
+
417
+ typedef struct NVPW_Device_GetClockStatus_Params
418
+ {
419
+ /// [in]
420
+ size_t structSize;
421
+ /// [in] assign to NULL
422
+ void* pPriv;
423
+ size_t deviceIndex;
424
+ /// [in]
425
+ NVPW_Device_ClockStatus clockStatus;
426
+ /// [in]
427
+ NVPW_Device_ClockLevel clockLevel;
428
+ } NVPW_Device_GetClockStatus_Params;
429
+ #define NVPW_Device_GetClockStatus_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Device_GetClockStatus_Params, clockLevel)
430
+
431
+ NVPA_Status NVPW_Device_GetClockStatus(NVPW_Device_GetClockStatus_Params* pParams);
432
+
433
+ typedef enum NVPW_Device_ClockSetting
434
+ {
435
+ /// invalid op, specify valid clocks operation during profiling
436
+ NVPW_DEVICE_CLOCK_SETTING_INVALID,
437
+ /// default to driver/application config (normally unlocked and not boosted, but could be unlocked boosted, or
438
+ /// locked to rated TDP)
439
+ NVPW_DEVICE_CLOCK_SETTING_DEFAULT,
440
+ /// lock clocks at rated tdp base values
441
+ NVPW_DEVICE_CLOCK_SETTING_LOCK_TO_RATED_TDP,
442
+ /// lock clocks at turbo boost values
443
+ NVPW_DEVICE_CLOCK_SETTING_LOCK_TO_TURBO_BOOST,
444
+ NVPW_DEVICE_CLOCK_SETTING__COUNT
445
+ } NVPW_Device_ClockSetting;
446
+
447
+ typedef struct NVPW_Device_SetClockSetting_Params
448
+ {
449
+ /// [in]
450
+ size_t structSize;
451
+ /// [in] assign to NULL
452
+ void* pPriv;
453
+ size_t deviceIndex;
454
+ /// [in]
455
+ NVPW_Device_ClockSetting clockSetting;
456
+ } NVPW_Device_SetClockSetting_Params;
457
+ #define NVPW_Device_SetClockSetting_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Device_SetClockSetting_Params, clockSetting)
458
+
459
+ NVPA_Status NVPW_Device_SetClockSetting(NVPW_Device_SetClockSetting_Params* pParams);
460
+
461
+ typedef struct NVPW_CounterData_GetRangeDescriptions_Params
462
+ {
463
+ /// [in]
464
+ size_t structSize;
465
+ /// [in] assign to NULL
466
+ void* pPriv;
467
+ const uint8_t* pCounterDataImage;
468
+ size_t rangeIndex;
469
+ /// [inout] Number of descriptions allocated in ppDescriptions
470
+ size_t numDescriptions;
471
+ const char** ppDescriptions;
472
+ } NVPW_CounterData_GetRangeDescriptions_Params;
473
+ #define NVPW_CounterData_GetRangeDescriptions_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_CounterData_GetRangeDescriptions_Params, ppDescriptions)
474
+
475
+ NVPA_Status NVPW_CounterData_GetRangeDescriptions(NVPW_CounterData_GetRangeDescriptions_Params* pParams);
476
+
477
+ typedef struct NVPW_Profiler_CounterData_GetRangeDescriptions_Params
478
+ {
479
+ /// [in]
480
+ size_t structSize;
481
+ /// [in] assign to NULL
482
+ void* pPriv;
483
+ const uint8_t* pCounterDataImage;
484
+ size_t rangeIndex;
485
+ /// [inout] Number of descriptions allocated in ppDescriptions
486
+ size_t numDescriptions;
487
+ const char** ppDescriptions;
488
+ } NVPW_Profiler_CounterData_GetRangeDescriptions_Params;
489
+ #define NVPW_Profiler_CounterData_GetRangeDescriptions_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_Profiler_CounterData_GetRangeDescriptions_Params, ppDescriptions)
490
+
491
+ NVPA_Status NVPW_Profiler_CounterData_GetRangeDescriptions(NVPW_Profiler_CounterData_GetRangeDescriptions_Params* pParams);
492
+
493
+ #ifndef NVPW_PERIODIC_SAMPLER_COUNTER_DATA_APPEND_MODE_DEFINED
494
+ #define NVPW_PERIODIC_SAMPLER_COUNTER_DATA_APPEND_MODE_DEFINED
495
+ typedef enum NVPW_PeriodicSampler_CounterData_AppendMode
496
+ {
497
+ NVPW_PERIODIC_SAMPLER_COUNTER_DATA_APPEND_MODE_LINEAR = 0,
498
+ NVPW_PERIODIC_SAMPLER_COUNTER_DATA_APPEND_MODE_CIRCULAR = 1,
499
+ NVPW_PERIODIC_SAMPLER_COUNTER_DATA_APPEND_MODE__COUNT
500
+ } NVPW_PeriodicSampler_CounterData_AppendMode;
501
+ #endif //NVPW_PERIODIC_SAMPLER_COUNTER_DATA_APPEND_MODE_DEFINED
502
+
503
+ typedef struct NVPW_PeriodicSampler_CounterData_GetSampleTime_Params
504
+ {
505
+ /// [in]
506
+ size_t structSize;
507
+ /// [in] assign to NULL
508
+ void* pPriv;
509
+ /// [in]
510
+ const uint8_t* pCounterDataImage;
511
+ /// [in]
512
+ size_t rangeIndex;
513
+ /// [out]
514
+ uint64_t timestampStart;
515
+ /// [out]
516
+ uint64_t timestampEnd;
517
+ } NVPW_PeriodicSampler_CounterData_GetSampleTime_Params;
518
+ #define NVPW_PeriodicSampler_CounterData_GetSampleTime_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_PeriodicSampler_CounterData_GetSampleTime_Params, timestampEnd)
519
+
520
+ NVPA_Status NVPW_PeriodicSampler_CounterData_GetSampleTime(NVPW_PeriodicSampler_CounterData_GetSampleTime_Params* pParams);
521
+
522
+ typedef struct NVPW_PeriodicSampler_CounterData_TrimInPlace_Params
523
+ {
524
+ /// [in]
525
+ size_t structSize;
526
+ /// [in] assign to NULL
527
+ void* pPriv;
528
+ /// [in]
529
+ uint8_t* pCounterDataImage;
530
+ /// [in]
531
+ size_t counterDataImageSize;
532
+ /// [out]
533
+ size_t counterDataImageTrimmedSize;
534
+ } NVPW_PeriodicSampler_CounterData_TrimInPlace_Params;
535
+ #define NVPW_PeriodicSampler_CounterData_TrimInPlace_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_PeriodicSampler_CounterData_TrimInPlace_Params, counterDataImageTrimmedSize)
536
+
537
+ NVPA_Status NVPW_PeriodicSampler_CounterData_TrimInPlace(NVPW_PeriodicSampler_CounterData_TrimInPlace_Params* pParams);
538
+
539
+ typedef struct NVPW_PeriodicSampler_CounterData_GetInfo_Params
540
+ {
541
+ /// [in]
542
+ size_t structSize;
543
+ /// [in] assign to NULL
544
+ void* pPriv;
545
+ /// [in]
546
+ const uint8_t* pCounterDataImage;
547
+ /// [in]
548
+ size_t counterDataImageSize;
549
+ /// [out] total number of ranges in the counter data
550
+ size_t numTotalRanges;
551
+ /// [out] if in "linear" mode, this API returns the number of "populated" ranges; if it's in "circular" mode,
552
+ /// then it returns the last "populated" range index + 1, when there is no such range, it returns 0.
553
+ size_t numPopulatedRanges;
554
+ /// [out] if in "linear" mode, this API returns the number of "completed" ranges; if it's in "circular" mode,
555
+ /// then it returns the last "completed" range index + 1, when there is no such range, it returns 0.
556
+ size_t numCompletedRanges;
557
+ } NVPW_PeriodicSampler_CounterData_GetInfo_Params;
558
+ #define NVPW_PeriodicSampler_CounterData_GetInfo_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_PeriodicSampler_CounterData_GetInfo_Params, numCompletedRanges)
559
+
560
+ /// In periodic sampler, a range in counter data stores exactly one sample's data. For better performance, periodic
561
+ /// sampler may operate in an out-of-order fashion when populating sample data, i.e. it may not fully populate all
562
+ /// counters of a sample/range before starting to populate the next sample/range. As a result, we have two concepts
563
+ /// here, "populated" & "completed": a range is considered "populated" even if only partial counters have been
564
+ /// written; on the other hand, a range is only considered "completed" if all the collecting counters have been
565
+ /// written.
566
+ NVPA_Status NVPW_PeriodicSampler_CounterData_GetInfo(NVPW_PeriodicSampler_CounterData_GetInfo_Params* pParams);
567
+
568
+ typedef struct NVPW_PeriodicSampler_CounterData_GetTriggerCount_Params
569
+ {
570
+ /// [in]
571
+ size_t structSize;
572
+ /// [in] assign to NULL
573
+ void* pPriv;
574
+ /// [in]
575
+ const uint8_t* pCounterDataImage;
576
+ /// [in]
577
+ size_t counterDataImageSize;
578
+ /// [in]
579
+ size_t rangeIndex;
580
+ /// [out]
581
+ uint32_t triggerCount;
582
+ } NVPW_PeriodicSampler_CounterData_GetTriggerCount_Params;
583
+ #define NVPW_PeriodicSampler_CounterData_GetTriggerCount_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_PeriodicSampler_CounterData_GetTriggerCount_Params, triggerCount)
584
+
585
+ NVPA_Status NVPW_PeriodicSampler_CounterData_GetTriggerCount(NVPW_PeriodicSampler_CounterData_GetTriggerCount_Params* pParams);
586
+
587
+ typedef struct NVPW_PeriodicSampler_CounterData_IsDataComplete_Params
588
+ {
589
+ /// [in]
590
+ size_t structSize;
591
+ /// [in] assign to NULL
592
+ void* pPriv;
593
+ /// [in]
594
+ const uint8_t* pCounterDataImage;
595
+ /// [in]
596
+ size_t counterDataImageSize;
597
+ /// [in]
598
+ size_t rangeIndex;
599
+ /// [out]
600
+ NVPA_Bool isComplete;
601
+ } NVPW_PeriodicSampler_CounterData_IsDataComplete_Params;
602
+ #define NVPW_PeriodicSampler_CounterData_IsDataComplete_Params_STRUCT_SIZE NVPA_STRUCT_SIZE(NVPW_PeriodicSampler_CounterData_IsDataComplete_Params, isComplete)
603
+
604
+ /// Checks whether a given sample's data is complete. See also 'NVPW_PeriodicSampler_CounterData_GetInfo'
605
+ NVPA_Status NVPW_PeriodicSampler_CounterData_IsDataComplete(NVPW_PeriodicSampler_CounterData_IsDataComplete_Params* pParams);
606
+
607
+
608
+ typedef struct NVPW_TimestampReport
609
+ {
610
+ uint32_t payload;
611
+ uint8_t reserved0004[4];
612
+ uint64_t timestamp;
613
+ } NVPW_TimestampReport;
614
+
615
+
616
+
617
+
618
+ #ifdef __cplusplus
619
+ } // extern "C"
620
+ #endif
621
+
622
+ #if defined(__GNUC__) && defined(NVPA_SHARED_LIB)
623
+ #pragma GCC visibility pop
624
+ #endif
625
+
626
+ #endif // NVPERF_TARGET_H
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/backends/nvidia/include/sm_32_atomic_functions.hpp ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 1993-2023 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NOTICE TO LICENSEE:
5
+ *
6
+ * This source code and/or documentation ("Licensed Deliverables") are
7
+ * subject to NVIDIA intellectual property rights under U.S. and
8
+ * international Copyright laws.
9
+ *
10
+ * These Licensed Deliverables contained herein is PROPRIETARY and
11
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12
+ * conditions of a form of NVIDIA software license agreement by and
13
+ * between NVIDIA and Licensee ("License Agreement") or electronically
14
+ * accepted by Licensee. Notwithstanding any terms or conditions to
15
+ * the contrary in the License Agreement, reproduction or disclosure
16
+ * of the Licensed Deliverables to any third party without the express
17
+ * written consent of NVIDIA is prohibited.
18
+ *
19
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32
+ * OF THESE LICENSED DELIVERABLES.
33
+ *
34
+ * U.S. Government End Users. These Licensed Deliverables are a
35
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36
+ * 1995), consisting of "commercial computer software" and "commercial
37
+ * computer software documentation" as such terms are used in 48
38
+ * C.F.R. 35.235 (SEPT 1995) and is provided to the U.S. Government
39
+ * only as a commercial end item. Consistent with 48 C.F.R.35.235 and
40
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41
+ * U.S. Government End Users acquire the Licensed Deliverables with
42
+ * only those rights set forth herein.
43
+ *
44
+ * Any use of the Licensed Deliverables in individual and commercial
45
+ * software must include, in the user documentation and internal
46
+ * comments to the code, the above Disclaimer and U.S. Government End
47
+ * Users Notice.
48
+ */
49
+
50
+ #if !defined(__SM_32_ATOMIC_FUNCTIONS_HPP__)
51
+ #define __SM_32_ATOMIC_FUNCTIONS_HPP__
52
+
53
+ #ifdef __CUDA_ARCH__
54
+ extern "C"
55
+ {
56
+ extern __device__ __device_builtin__ long long __illAtomicMin(long long *address, long long val);
57
+ extern __device__ __device_builtin__ long long __illAtomicMax(long long *address, long long val);
58
+ extern __device__ __device_builtin__ long long __llAtomicAnd(long long *address, long long val);
59
+ extern __device__ __device_builtin__ long long __llAtomicOr(long long *address, long long val);
60
+ extern __device__ __device_builtin__ long long __llAtomicXor(long long *address, long long val);
61
+ extern __device__ __device_builtin__ unsigned long long __ullAtomicMin(unsigned long long *address, unsigned long long val);
62
+ extern __device__ __device_builtin__ unsigned long long __ullAtomicMax(unsigned long long *address, unsigned long long val);
63
+ extern __device__ __device_builtin__ unsigned long long __ullAtomicAnd(unsigned long long *address, unsigned long long val);
64
+ extern __device__ __device_builtin__ unsigned long long __ullAtomicOr (unsigned long long *address, unsigned long long val);
65
+ extern __device__ __device_builtin__ unsigned long long __ullAtomicXor(unsigned long long *address, unsigned long long val);
66
+ }
67
+ #endif /* __CUDA_ARCH__ */
68
+
69
+
70
+ #if defined(__CUDACC_RTC__)
71
+ #define __SM_32_ATOMIC_FUNCTIONS_DECL__ __device__
72
+ #else /* !__CUDACC_RTC__ */
73
+ #define __SM_32_ATOMIC_FUNCTIONS_DECL__ static __inline__ __device__
74
+ #endif /* __CUDACC_RTC__ */
75
+
76
+ #if defined(__cplusplus) && defined(__CUDACC__)
77
+
78
+ #if defined(_NVHPC_CUDA) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 320
79
+
80
+ /*******************************************************************************
81
+ * *
82
+ * *
83
+ * *
84
+ *******************************************************************************/
85
+
86
+ #include "cuda_runtime_api.h"
87
+
88
+ /*******************************************************************************
89
+ * *
90
+ * *
91
+ * *
92
+ *******************************************************************************/
93
+
94
+ __SM_32_ATOMIC_FUNCTIONS_DECL__ long long atomicMin(long long *address, long long val)
95
+ {
96
+ return __illAtomicMin(address, val);
97
+ }
98
+
99
+ __SM_32_ATOMIC_FUNCTIONS_DECL__ long long atomicMax(long long *address, long long val)
100
+ {
101
+ return __illAtomicMax(address, val);
102
+ }
103
+
104
+ __SM_32_ATOMIC_FUNCTIONS_DECL__ long long atomicAnd(long long *address, long long val)
105
+ {
106
+ return __llAtomicAnd(address, val);
107
+ }
108
+
109
+ __SM_32_ATOMIC_FUNCTIONS_DECL__ long long atomicOr(long long *address, long long val)
110
+ {
111
+ return __llAtomicOr(address, val);
112
+ }
113
+
114
+ __SM_32_ATOMIC_FUNCTIONS_DECL__ long long atomicXor(long long *address, long long val)
115
+ {
116
+ return __llAtomicXor(address, val);
117
+ }
118
+
119
+ __SM_32_ATOMIC_FUNCTIONS_DECL__ unsigned long long atomicMin(unsigned long long *address, unsigned long long val)
120
+ {
121
+ return __ullAtomicMin(address, val);
122
+ }
123
+
124
+ __SM_32_ATOMIC_FUNCTIONS_DECL__ unsigned long long atomicMax(unsigned long long *address, unsigned long long val)
125
+ {
126
+ return __ullAtomicMax(address, val);
127
+ }
128
+
129
+ __SM_32_ATOMIC_FUNCTIONS_DECL__ unsigned long long atomicAnd(unsigned long long *address, unsigned long long val)
130
+ {
131
+ return __ullAtomicAnd(address, val);
132
+ }
133
+
134
+ __SM_32_ATOMIC_FUNCTIONS_DECL__ unsigned long long atomicOr(unsigned long long *address, unsigned long long val)
135
+ {
136
+ return __ullAtomicOr(address, val);
137
+ }
138
+
139
+ __SM_32_ATOMIC_FUNCTIONS_DECL__ unsigned long long atomicXor(unsigned long long *address, unsigned long long val)
140
+ {
141
+ return __ullAtomicXor(address, val);
142
+ }
143
+
144
+ #endif /* _NVHPC_CUDA || !__CUDA_ARCH__ || __CUDA_ARCH__ >= 320 */
145
+
146
+ #endif /* __cplusplus && __CUDACC__ */
147
+
148
+ #undef __SM_32_ATOMIC_FUNCTIONS_DECL__
149
+
150
+ #endif /* !__SM_32_ATOMIC_FUNCTIONS_HPP__ */
151
+
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (7.17 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/math.cpython-312.pyc ADDED
Binary file (14.6 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/random.cpython-312.pyc ADDED
Binary file (9.64 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/standard.cpython-312.pyc ADDED
Binary file (21.8 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/__pycache__/target_info.cpython-312.pyc ADDED
Binary file (2.28 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pkgutil
2
+ from importlib.util import module_from_spec
3
+ from sys import modules
4
+
5
+ _backends = []
6
+ for module_finder, module_name, is_pkg in pkgutil.iter_modules(
7
+ __path__,
8
+ prefix=__name__ + ".",
9
+ ):
10
+ # skip .py files (like libdevice.py)
11
+ if not is_pkg:
12
+ continue
13
+
14
+ # import backends (like cuda and hip) that are included during setup.py
15
+ spec = module_finder.find_spec(module_name)
16
+ if spec is None or spec.loader is None:
17
+ continue
18
+ module = module_from_spec(spec)
19
+ spec.loader.exec_module(module)
20
+
21
+ _backends.append(module_name)
22
+ modules[module_name] = module
23
+
24
+ __all__ = _backends
25
+
26
+ del _backends
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (911 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/__pycache__/libdevice.cpython-312.pyc ADDED
Binary file (20.4 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import libdevice
2
+
3
+ from .utils import (globaltimer, num_threads, num_warps, smid, convert_custom_float8_sm70, convert_custom_float8_sm80)
4
+ from .gdc import (gdc_launch_dependents, gdc_wait)
5
+
6
+ __all__ = [
7
+ "libdevice",
8
+ "globaltimer",
9
+ "num_threads",
10
+ "num_warps",
11
+ "smid",
12
+ "convert_custom_float8_sm70",
13
+ "convert_custom_float8_sm80",
14
+ "gdc_launch_dependents",
15
+ "gdc_wait",
16
+ ]
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (585 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/gdc.cpython-312.pyc ADDED
Binary file (2.74 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/libdevice.cpython-312.pyc ADDED
Binary file (89.2 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/__pycache__/utils.cpython-312.pyc ADDED
Binary file (5.61 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/gdc.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grid Dependency Control (GDC) is a mechanism used when enabling programmatic dependent launch to launch and
3
+ synchronize grids. These APIs expose GDC to the programmer.
4
+
5
+ Programmatic dependent launch is supported on SM90 (Hopper) and beyond.
6
+ For PTX reference on grid dependency control see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol.
7
+ """
8
+
9
+ from triton.language import core
10
+
11
+
12
+ @core.extern
13
+ def gdc_wait(_semantic=None):
14
+ """
15
+ GDC wait is a blocking instruction that waits for all instructions in a prior kernel to complete before continuing.
16
+ This ensures all memory operations happening before the wait is visible to instructions after it,
17
+ e.g. if the prior kernel writes to address "x" the new values will be visible in this kernel after the wait.
18
+
19
+ This instruction is also safe to execute when programmatic dependent launch is disabled.
20
+
21
+ See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol for more details.
22
+ """
23
+ core.inline_asm_elementwise("griddepcontrol.wait; // dummy $0", "=r", [], dtype=core.int32, is_pure=False, pack=1,
24
+ _semantic=_semantic)
25
+
26
+
27
+ @core.extern
28
+ def gdc_launch_dependents(_semantic=None):
29
+ """
30
+ This operation when launched with programmatic dependent launch signals that
31
+ the next program may launch once all programs in the current kernel
32
+ call this function or complete.
33
+
34
+ Repeated calls to this function have no effect past the first call, and the first call should be
35
+ treated by the programmer as a hint to the runtime system to launch the next kernel.
36
+
37
+ This instruction is also safe to execute when programmatic dependent launch is disabled.
38
+
39
+ See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol for more details.
40
+ """
41
+ core.inline_asm_elementwise("griddepcontrol.launch_dependents; // dummy $0", "=r", [], dtype=core.int32,
42
+ is_pure=False, pack=1, _semantic=_semantic)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/libdevice.py ADDED
@@ -0,0 +1,1629 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from triton.language import core
2
+
3
+
4
+ @core.extern
5
+ def clz(arg0, _semantic=None):
6
+ return core.extern_elementwise(
7
+ "", "", [arg0], {
8
+ (core.dtype("int32"), ): ("__nv_clz", core.dtype("int32")),
9
+ (core.dtype("int64"), ): ("__nv_clzll", core.dtype("int32")),
10
+ }, is_pure=True, _semantic=_semantic)
11
+
12
+
13
+ @core.extern
14
+ def popc(arg0, _semantic=None):
15
+ return core.extern_elementwise(
16
+ "", "", [arg0], {
17
+ (core.dtype("int32"), ): ("__nv_popc", core.dtype("int32")),
18
+ (core.dtype("int64"), ): ("__nv_popcll", core.dtype("int32")),
19
+ }, is_pure=True, _semantic=_semantic)
20
+
21
+
22
+ @core.extern
23
+ def byte_perm(arg0, arg1, arg2, _semantic=None):
24
+ return core.extern_elementwise("", "", [arg0, arg1, arg2], {
25
+ (core.dtype("int32"), core.dtype("int32"), core.dtype("int32")): ("__nv_byte_perm", core.dtype("int32")),
26
+ }, is_pure=True, _semantic=_semantic)
27
+
28
+
29
+ @core.extern
30
+ def mulhi(arg0, arg1, _semantic=None):
31
+ return core.extern_elementwise(
32
+ "", "", [arg0, arg1], {
33
+ (core.dtype("int32"), core.dtype("int32")): ("__nv_mulhi", core.dtype("int32")),
34
+ (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umulhi", core.dtype("uint32")),
35
+ (core.dtype("int64"), core.dtype("int64")): ("__nv_mul64hi", core.dtype("int64")),
36
+ (core.dtype("uint64"), core.dtype("uint64")): ("__nv_umul64hi", core.dtype("uint64")),
37
+ }, is_pure=True, _semantic=_semantic)
38
+
39
+
40
+ @core.extern
41
+ def mul24(arg0, arg1, _semantic=None):
42
+ return core.extern_elementwise(
43
+ "", "", [arg0, arg1], {
44
+ (core.dtype("int32"), core.dtype("int32")): ("__nv_mul24", core.dtype("int32")),
45
+ (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umul24", core.dtype("uint32")),
46
+ }, is_pure=True, _semantic=_semantic)
47
+
48
+
49
+ @core.extern
50
+ def brev(arg0, _semantic=None):
51
+ return core.extern_elementwise(
52
+ "", "", [arg0], {
53
+ (core.dtype("int32"), ): ("__nv_brev", core.dtype("int32")),
54
+ (core.dtype("int64"), ): ("__nv_brevll", core.dtype("int64")),
55
+ }, is_pure=True, _semantic=_semantic)
56
+
57
+
58
+ @core.extern
59
+ def sad(arg0, arg1, arg2, _semantic=None):
60
+ return core.extern_elementwise(
61
+ "", "", [arg0, arg1, arg2], {
62
+ (core.dtype("int32"), core.dtype("int32"), core.dtype("uint32")): ("__nv_sad", core.dtype("int32")),
63
+ (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32")): ("__nv_usad", core.dtype("uint32")),
64
+ }, is_pure=True, _semantic=_semantic)
65
+
66
+
67
+ @core.extern
68
+ def abs(arg0, _semantic=None):
69
+ return core.extern_elementwise(
70
+ "", "", [arg0], {
71
+ (core.dtype("int32"), ): ("__nv_abs", core.dtype("int32")),
72
+ (core.dtype("int64"), ): ("__nv_llabs", core.dtype("int64")),
73
+ (core.dtype("fp32"), ): ("__nv_fabsf", core.dtype("fp32")),
74
+ (core.dtype("fp64"), ): ("__nv_fabs", core.dtype("fp64")),
75
+ }, is_pure=True, _semantic=_semantic)
76
+
77
+
78
+ @core.extern
79
+ def floor(arg0, _semantic=None):
80
+ return core.extern_elementwise(
81
+ "", "", [arg0], {
82
+ (core.dtype("fp32"), ): ("__nv_floorf", core.dtype("fp32")),
83
+ (core.dtype("fp64"), ): ("__nv_floor", core.dtype("fp64")),
84
+ }, is_pure=True, _semantic=_semantic)
85
+
86
+
87
+ @core.extern
88
+ def rcp64h(arg0, _semantic=None):
89
+ return core.extern_elementwise("", "", [arg0], {
90
+ (core.dtype("fp64"), ): ("__nv_rcp64h", core.dtype("fp64")),
91
+ }, is_pure=True, _semantic=_semantic)
92
+
93
+
94
+ @core.extern
95
+ def rsqrt(arg0, _semantic=None):
96
+ return core.extern_elementwise(
97
+ "", "", [arg0], {
98
+ (core.dtype("fp32"), ): ("__nv_rsqrtf", core.dtype("fp32")),
99
+ (core.dtype("fp64"), ): ("__nv_rsqrt", core.dtype("fp64")),
100
+ }, is_pure=True, _semantic=_semantic)
101
+
102
+
103
+ @core.extern
104
+ def ceil(arg0, _semantic=None):
105
+ return core.extern_elementwise(
106
+ "", "", [arg0], {
107
+ (core.dtype("fp64"), ): ("__nv_ceil", core.dtype("fp64")),
108
+ (core.dtype("fp32"), ): ("__nv_ceilf", core.dtype("fp32")),
109
+ }, is_pure=True, _semantic=_semantic)
110
+
111
+
112
+ @core.extern
113
+ def trunc(arg0, _semantic=None):
114
+ return core.extern_elementwise(
115
+ "", "", [arg0], {
116
+ (core.dtype("fp64"), ): ("__nv_trunc", core.dtype("fp64")),
117
+ (core.dtype("fp32"), ): ("__nv_truncf", core.dtype("fp32")),
118
+ }, is_pure=True, _semantic=_semantic)
119
+
120
+
121
+ @core.extern
122
+ def exp2(arg0, _semantic=None):
123
+ return core.extern_elementwise(
124
+ "", "", [arg0], {
125
+ (core.dtype("fp32"), ): ("__nv_exp2f", core.dtype("fp32")),
126
+ (core.dtype("fp64"), ): ("__nv_exp2", core.dtype("fp64")),
127
+ }, is_pure=True, _semantic=_semantic)
128
+
129
+
130
+ @core.extern
131
+ def saturatef(arg0, _semantic=None):
132
+ return core.extern_elementwise("", "", [arg0], {
133
+ (core.dtype("fp32"), ): ("__nv_saturatef", core.dtype("fp32")),
134
+ }, is_pure=True, _semantic=_semantic)
135
+
136
+
137
+ @core.extern
138
+ def fma_rn(arg0, arg1, arg2, _semantic=None):
139
+ return core.extern_elementwise(
140
+ "", "", [arg0, arg1, arg2], {
141
+ (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rn", core.dtype("fp32")),
142
+ (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rn", core.dtype("fp64")),
143
+ }, is_pure=True, _semantic=_semantic)
144
+
145
+
146
+ @core.extern
147
+ def fma_rz(arg0, arg1, arg2, _semantic=None):
148
+ return core.extern_elementwise(
149
+ "", "", [arg0, arg1, arg2], {
150
+ (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rz", core.dtype("fp32")),
151
+ (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rz", core.dtype("fp64")),
152
+ }, is_pure=True, _semantic=_semantic)
153
+
154
+
155
+ @core.extern
156
+ def fma_rd(arg0, arg1, arg2, _semantic=None):
157
+ return core.extern_elementwise(
158
+ "", "", [arg0, arg1, arg2], {
159
+ (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rd", core.dtype("fp32")),
160
+ (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rd", core.dtype("fp64")),
161
+ }, is_pure=True, _semantic=_semantic)
162
+
163
+
164
+ @core.extern
165
+ def fma_ru(arg0, arg1, arg2, _semantic=None):
166
+ return core.extern_elementwise(
167
+ "", "", [arg0, arg1, arg2], {
168
+ (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_ru", core.dtype("fp32")),
169
+ (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_ru", core.dtype("fp64")),
170
+ }, is_pure=True, _semantic=_semantic)
171
+
172
+
173
+ @core.extern
174
+ def fast_dividef(arg0, arg1, _semantic=None):
175
+ return core.extern_elementwise("", "", [arg0, arg1], {
176
+ (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_fdividef", core.dtype("fp32")),
177
+ }, is_pure=True, _semantic=_semantic)
178
+
179
+
180
+ @core.extern
181
+ def div_rn(arg0, arg1, _semantic=None):
182
+ return core.extern_elementwise(
183
+ "", "", [arg0, arg1], {
184
+ (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rn", core.dtype("fp32")),
185
+ (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rn", core.dtype("fp64")),
186
+ }, is_pure=True, _semantic=_semantic)
187
+
188
+
189
+ @core.extern
190
+ def div_rz(arg0, arg1, _semantic=None):
191
+ return core.extern_elementwise(
192
+ "", "", [arg0, arg1], {
193
+ (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rz", core.dtype("fp32")),
194
+ (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rz", core.dtype("fp64")),
195
+ }, is_pure=True, _semantic=_semantic)
196
+
197
+
198
+ @core.extern
199
+ def div_rd(arg0, arg1, _semantic=None):
200
+ return core.extern_elementwise(
201
+ "", "", [arg0, arg1], {
202
+ (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rd", core.dtype("fp32")),
203
+ (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rd", core.dtype("fp64")),
204
+ }, is_pure=True, _semantic=_semantic)
205
+
206
+
207
+ @core.extern
208
+ def div_ru(arg0, arg1, _semantic=None):
209
+ return core.extern_elementwise(
210
+ "", "", [arg0, arg1], {
211
+ (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_ru", core.dtype("fp32")),
212
+ (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_ru", core.dtype("fp64")),
213
+ }, is_pure=True, _semantic=_semantic)
214
+
215
+
216
+ @core.extern
217
+ def rcp_rn(arg0, _semantic=None):
218
+ return core.extern_elementwise(
219
+ "", "", [arg0], {
220
+ (core.dtype("fp32"), ): ("__nv_frcp_rn", core.dtype("fp32")),
221
+ (core.dtype("fp64"), ): ("__nv_drcp_rn", core.dtype("fp64")),
222
+ }, is_pure=True, _semantic=_semantic)
223
+
224
+
225
+ @core.extern
226
+ def rcp_rz(arg0, _semantic=None):
227
+ return core.extern_elementwise(
228
+ "", "", [arg0], {
229
+ (core.dtype("fp32"), ): ("__nv_frcp_rz", core.dtype("fp32")),
230
+ (core.dtype("fp64"), ): ("__nv_drcp_rz", core.dtype("fp64")),
231
+ }, is_pure=True, _semantic=_semantic)
232
+
233
+
234
+ @core.extern
235
+ def rcp_rd(arg0, _semantic=None):
236
+ return core.extern_elementwise(
237
+ "", "", [arg0], {
238
+ (core.dtype("fp32"), ): ("__nv_frcp_rd", core.dtype("fp32")),
239
+ (core.dtype("fp64"), ): ("__nv_drcp_rd", core.dtype("fp64")),
240
+ }, is_pure=True, _semantic=_semantic)
241
+
242
+
243
+ @core.extern
244
+ def rcp_ru(arg0, _semantic=None):
245
+ return core.extern_elementwise(
246
+ "", "", [arg0], {
247
+ (core.dtype("fp32"), ): ("__nv_frcp_ru", core.dtype("fp32")),
248
+ (core.dtype("fp64"), ): ("__nv_drcp_ru", core.dtype("fp64")),
249
+ }, is_pure=True, _semantic=_semantic)
250
+
251
+
252
+ @core.extern
253
+ def sqrt_rn(arg0, _semantic=None):
254
+ return core.extern_elementwise(
255
+ "", "", [arg0], {
256
+ (core.dtype("fp32"), ): ("__nv_fsqrt_rn", core.dtype("fp32")),
257
+ (core.dtype("fp64"), ): ("__nv_dsqrt_rn", core.dtype("fp64")),
258
+ }, is_pure=True, _semantic=_semantic)
259
+
260
+
261
+ @core.extern
262
+ def sqrt_rz(arg0, _semantic=None):
263
+ return core.extern_elementwise(
264
+ "", "", [arg0], {
265
+ (core.dtype("fp32"), ): ("__nv_fsqrt_rz", core.dtype("fp32")),
266
+ (core.dtype("fp64"), ): ("__nv_dsqrt_rz", core.dtype("fp64")),
267
+ }, is_pure=True, _semantic=_semantic)
268
+
269
+
270
+ @core.extern
271
+ def sqrt_rd(arg0, _semantic=None):
272
+ return core.extern_elementwise(
273
+ "", "", [arg0], {
274
+ (core.dtype("fp32"), ): ("__nv_fsqrt_rd", core.dtype("fp32")),
275
+ (core.dtype("fp64"), ): ("__nv_dsqrt_rd", core.dtype("fp64")),
276
+ }, is_pure=True, _semantic=_semantic)
277
+
278
+
279
+ @core.extern
280
+ def sqrt_ru(arg0, _semantic=None):
281
+ return core.extern_elementwise(
282
+ "", "", [arg0], {
283
+ (core.dtype("fp32"), ): ("__nv_fsqrt_ru", core.dtype("fp32")),
284
+ (core.dtype("fp64"), ): ("__nv_dsqrt_ru", core.dtype("fp64")),
285
+ }, is_pure=True, _semantic=_semantic)
286
+
287
+
288
+ @core.extern
289
+ def sqrt(arg0, _semantic=None):
290
+ return core.extern_elementwise(
291
+ "", "", [arg0], {
292
+ (core.dtype("fp32"), ): ("__nv_sqrtf", core.dtype("fp32")),
293
+ (core.dtype("fp64"), ): ("__nv_sqrt", core.dtype("fp64")),
294
+ }, is_pure=True, _semantic=_semantic)
295
+
296
+
297
+ @core.extern
298
+ def add_rn(arg0, arg1, _semantic=None):
299
+ return core.extern_elementwise(
300
+ "", "", [arg0, arg1], {
301
+ (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rn", core.dtype("fp64")),
302
+ (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rn", core.dtype("fp32")),
303
+ }, is_pure=True, _semantic=_semantic)
304
+
305
+
306
+ @core.extern
307
+ def add_rz(arg0, arg1, _semantic=None):
308
+ return core.extern_elementwise(
309
+ "", "", [arg0, arg1], {
310
+ (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rz", core.dtype("fp64")),
311
+ (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rz", core.dtype("fp32")),
312
+ }, is_pure=True, _semantic=_semantic)
313
+
314
+
315
+ @core.extern
316
+ def add_rd(arg0, arg1, _semantic=None):
317
+ return core.extern_elementwise(
318
+ "", "", [arg0, arg1], {
319
+ (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rd", core.dtype("fp64")),
320
+ (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rd", core.dtype("fp32")),
321
+ }, is_pure=True, _semantic=_semantic)
322
+
323
+
324
+ @core.extern
325
+ def add_ru(arg0, arg1, _semantic=None):
326
+ return core.extern_elementwise(
327
+ "", "", [arg0, arg1], {
328
+ (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_ru", core.dtype("fp64")),
329
+ (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_ru", core.dtype("fp32")),
330
+ }, is_pure=True, _semantic=_semantic)
331
+
332
+
333
+ @core.extern
334
+ def mul_rn(arg0, arg1, _semantic=None):
335
+ return core.extern_elementwise(
336
+ "", "", [arg0, arg1], {
337
+ (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rn", core.dtype("fp64")),
338
+ (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rn", core.dtype("fp32")),
339
+ }, is_pure=True, _semantic=_semantic)
340
+
341
+
342
+ @core.extern
343
+ def mul_rz(arg0, arg1, _semantic=None):
344
+ return core.extern_elementwise(
345
+ "", "", [arg0, arg1], {
346
+ (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rz", core.dtype("fp64")),
347
+ (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rz", core.dtype("fp32")),
348
+ }, is_pure=True, _semantic=_semantic)
349
+
350
+
351
+ @core.extern
352
+ def mul_rd(arg0, arg1, _semantic=None):
353
+ return core.extern_elementwise(
354
+ "", "", [arg0, arg1], {
355
+ (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rd", core.dtype("fp64")),
356
+ (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rd", core.dtype("fp32")),
357
+ }, is_pure=True, _semantic=_semantic)
358
+
359
+
360
+ @core.extern
361
+ def mul_ru(arg0, arg1, _semantic=None):
362
+ return core.extern_elementwise(
363
+ "", "", [
364
+ arg0,
365
+ arg1,
366
+ ], {
367
+ (
368
+ core.dtype("fp64"),
369
+ core.dtype("fp64"),
370
+ ): ("__nv_dmul_ru", core.dtype("fp64")),
371
+ (
372
+ core.dtype("fp32"),
373
+ core.dtype("fp32"),
374
+ ): ("__nv_fmul_ru", core.dtype("fp32")),
375
+ }, is_pure=True, _semantic=_semantic)
376
+
377
+
378
+ @core.extern
379
+ def double2float_rn(arg0, _semantic=None):
380
+ return core.extern_elementwise("", "", [arg0], {
381
+ (core.dtype("fp64"), ): ("__nv_double2float_rn", core.dtype("fp32")),
382
+ }, is_pure=True, _semantic=_semantic)
383
+
384
+
385
+ @core.extern
386
+ def double2float_rz(arg0, _semantic=None):
387
+ return core.extern_elementwise("", "", [arg0], {
388
+ (core.dtype("fp64"), ): ("__nv_double2float_rz", core.dtype("fp32")),
389
+ }, is_pure=True, _semantic=_semantic)
390
+
391
+
392
+ @core.extern
393
+ def double2float_rd(arg0, _semantic=None):
394
+ return core.extern_elementwise("", "", [arg0], {
395
+ (core.dtype("fp64"), ): ("__nv_double2float_rd", core.dtype("fp32")),
396
+ }, is_pure=True, _semantic=_semantic)
397
+
398
+
399
+ @core.extern
400
+ def double2float_ru(arg0, _semantic=None):
401
+ return core.extern_elementwise("", "", [arg0], {
402
+ (core.dtype("fp64"), ): ("__nv_double2float_ru", core.dtype("fp32")),
403
+ }, is_pure=True, _semantic=_semantic)
404
+
405
+
406
+ @core.extern
407
+ def double2int_rn(arg0, _semantic=None):
408
+ return core.extern_elementwise("", "", [arg0], {
409
+ (core.dtype("fp64"), ): ("__nv_double2int_rn", core.dtype("int32")),
410
+ }, is_pure=True, _semantic=_semantic)
411
+
412
+
413
+ @core.extern
414
+ def double2int_rz(arg0, _semantic=None):
415
+ return core.extern_elementwise("", "", [arg0], {
416
+ (core.dtype("fp64"), ): ("__nv_double2int_rz", core.dtype("int32")),
417
+ }, is_pure=True, _semantic=_semantic)
418
+
419
+
420
@core.extern
def double2int_rd(arg0, _semantic=None):
    """Apply extern ``__nv_double2int_rd`` elementwise: fp64 argument, int32 result."""
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2int_rd", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
425
+
426
+
427
@core.extern
def double2int_ru(arg0, _semantic=None):
    """Apply extern ``__nv_double2int_ru`` elementwise: fp64 argument, int32 result."""
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2int_ru", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
432
+
433
+
434
@core.extern
def double2uint_rn(arg0, _semantic=None):
    """Apply extern ``__nv_double2uint_rn`` elementwise: fp64 argument, int32 result."""
    # NOTE(review): declared result dtype is signed int32 despite the "uint" name - confirm intended.
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2uint_rn", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
439
+
440
+
441
@core.extern
def double2uint_rz(arg0, _semantic=None):
    """Apply extern ``__nv_double2uint_rz`` elementwise: fp64 argument, int32 result."""
    # NOTE(review): declared result dtype is signed int32 despite the "uint" name - confirm intended.
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2uint_rz", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
446
+
447
+
448
@core.extern
def double2uint_rd(arg0, _semantic=None):
    """Apply extern ``__nv_double2uint_rd`` elementwise: fp64 argument, int32 result."""
    # NOTE(review): declared result dtype is signed int32 despite the "uint" name - confirm intended.
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2uint_rd", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
453
+
454
+
455
@core.extern
def double2uint_ru(arg0, _semantic=None):
    """Apply extern ``__nv_double2uint_ru`` elementwise: fp64 argument, int32 result."""
    # NOTE(review): declared result dtype is signed int32 despite the "uint" name - confirm intended.
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2uint_ru", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
460
+
461
+
462
@core.extern
def int2double_rn(arg0, _semantic=None):
    """Apply extern ``__nv_int2double_rn`` elementwise: int32 argument, fp64 result."""
    signatures = {
        (core.dtype("int32"), ): ("__nv_int2double_rn", core.dtype("fp64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
467
+
468
+
469
@core.extern
def uint2double_rn(arg0, _semantic=None):
    """Apply extern ``__nv_uint2double_rn`` elementwise: uint32 argument, fp64 result."""
    signatures = {
        (core.dtype("uint32"), ): ("__nv_uint2double_rn", core.dtype("fp64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
474
+
475
+
476
@core.extern
def float2int_rn(arg0, _semantic=None):
    """Apply extern ``__nv_float2int_rn`` elementwise: fp32 argument, int32 result."""
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2int_rn", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
481
+
482
+
483
@core.extern
def float2int_rz(arg0, _semantic=None):
    """Apply extern ``__nv_float2int_rz`` elementwise: fp32 argument, int32 result."""
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2int_rz", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
488
+
489
+
490
@core.extern
def float2int_rd(arg0, _semantic=None):
    """Apply extern ``__nv_float2int_rd`` elementwise: fp32 argument, int32 result."""
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2int_rd", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
495
+
496
+
497
@core.extern
def float2int_ru(arg0, _semantic=None):
    """Apply extern ``__nv_float2int_ru`` elementwise: fp32 argument, int32 result."""
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2int_ru", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
502
+
503
+
504
@core.extern
def float2uint_rn(arg0, _semantic=None):
    """Apply extern ``__nv_float2uint_rn`` elementwise: fp32 argument, int32 result."""
    # NOTE(review): declared result dtype is signed int32 despite the "uint" name - confirm intended.
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2uint_rn", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
509
+
510
+
511
@core.extern
def float2uint_rz(arg0, _semantic=None):
    """Apply extern ``__nv_float2uint_rz`` elementwise: fp32 argument, int32 result."""
    # NOTE(review): declared result dtype is signed int32 despite the "uint" name - confirm intended.
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2uint_rz", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
516
+
517
+
518
@core.extern
def float2uint_rd(arg0, _semantic=None):
    """Apply extern ``__nv_float2uint_rd`` elementwise: fp32 argument, int32 result."""
    # NOTE(review): declared result dtype is signed int32 despite the "uint" name - confirm intended.
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2uint_rd", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
523
+
524
+
525
@core.extern
def float2uint_ru(arg0, _semantic=None):
    """Apply extern ``__nv_float2uint_ru`` elementwise: fp32 argument, int32 result."""
    # NOTE(review): declared result dtype is signed int32 despite the "uint" name - confirm intended.
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2uint_ru", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
530
+
531
+
532
@core.extern
def int2float_rn(arg0, _semantic=None):
    """Apply extern ``__nv_int2float_rn`` elementwise: int32 argument, fp32 result."""
    signatures = {
        (core.dtype("int32"), ): ("__nv_int2float_rn", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
537
+
538
+
539
@core.extern
def int2float_rz(arg0, _semantic=None):
    """Apply extern ``__nv_int2float_rz`` elementwise: int32 argument, fp32 result."""
    signatures = {
        (core.dtype("int32"), ): ("__nv_int2float_rz", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
544
+
545
+
546
@core.extern
def int2float_rd(arg0, _semantic=None):
    """Apply extern ``__nv_int2float_rd`` elementwise: int32 argument, fp32 result."""
    signatures = {
        (core.dtype("int32"), ): ("__nv_int2float_rd", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
551
+
552
+
553
@core.extern
def int2float_ru(arg0, _semantic=None):
    """Apply extern ``__nv_int2float_ru`` elementwise: int32 argument, fp32 result."""
    signatures = {
        (core.dtype("int32"), ): ("__nv_int2float_ru", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
558
+
559
+
560
@core.extern
def uint2float_rn(arg0, _semantic=None):
    """Apply extern ``__nv_uint2float_rn`` elementwise: uint32 argument, fp32 result."""
    signatures = {
        (core.dtype("uint32"), ): ("__nv_uint2float_rn", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
565
+
566
+
567
@core.extern
def uint2float_rz(arg0, _semantic=None):
    """Apply extern ``__nv_uint2float_rz`` elementwise: uint32 argument, fp32 result."""
    signatures = {
        (core.dtype("uint32"), ): ("__nv_uint2float_rz", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
572
+
573
+
574
@core.extern
def uint2float_rd(arg0, _semantic=None):
    """Apply extern ``__nv_uint2float_rd`` elementwise: uint32 argument, fp32 result."""
    signatures = {
        (core.dtype("uint32"), ): ("__nv_uint2float_rd", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
579
+
580
+
581
@core.extern
def uint2float_ru(arg0, _semantic=None):
    """Apply extern ``__nv_uint2float_ru`` elementwise: uint32 argument, fp32 result."""
    signatures = {
        (core.dtype("uint32"), ): ("__nv_uint2float_ru", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
586
+
587
+
588
@core.extern
def hiloint2double(arg0, arg1, _semantic=None):
    """Apply extern ``__nv_hiloint2double`` elementwise: two int32 arguments, fp64 result."""
    i32 = core.dtype("int32")
    signatures = {
        (i32, i32): ("__nv_hiloint2double", core.dtype("fp64")),
    }
    return core.extern_elementwise("", "", [arg0, arg1], signatures, is_pure=True, _semantic=_semantic)
593
+
594
+
595
@core.extern
def double2loint(arg0, _semantic=None):
    """Apply extern ``__nv_double2loint`` elementwise: fp64 argument, int32 result."""
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2loint", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
600
+
601
+
602
@core.extern
def double2hiint(arg0, _semantic=None):
    """Apply extern ``__nv_double2hiint`` elementwise: fp64 argument, int32 result."""
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2hiint", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
607
+
608
+
609
@core.extern
def float2ll_rn(arg0, _semantic=None):
    """Apply extern ``__nv_float2ll_rn`` elementwise: fp32 argument, int64 result."""
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2ll_rn", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
614
+
615
+
616
@core.extern
def float2ll_rz(arg0, _semantic=None):
    """Apply extern ``__nv_float2ll_rz`` elementwise: fp32 argument, int64 result."""
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2ll_rz", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
621
+
622
+
623
@core.extern
def float2ll_rd(arg0, _semantic=None):
    """Apply extern ``__nv_float2ll_rd`` elementwise: fp32 argument, int64 result."""
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2ll_rd", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
628
+
629
+
630
@core.extern
def float2ll_ru(arg0, _semantic=None):
    """Apply extern ``__nv_float2ll_ru`` elementwise: fp32 argument, int64 result."""
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2ll_ru", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
635
+
636
+
637
@core.extern
def float2ull_rn(arg0, _semantic=None):
    """Apply extern ``__nv_float2ull_rn`` elementwise: fp32 argument, int64 result."""
    # NOTE(review): declared result dtype is signed int64 despite the "ull" name - confirm intended.
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2ull_rn", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
642
+
643
+
644
@core.extern
def float2ull_rz(arg0, _semantic=None):
    """Apply extern ``__nv_float2ull_rz`` elementwise: fp32 argument, int64 result."""
    # NOTE(review): declared result dtype is signed int64 despite the "ull" name - confirm intended.
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2ull_rz", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
649
+
650
+
651
@core.extern
def float2ull_rd(arg0, _semantic=None):
    """Apply extern ``__nv_float2ull_rd`` elementwise: fp32 argument, int64 result."""
    # NOTE(review): declared result dtype is signed int64 despite the "ull" name - confirm intended.
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2ull_rd", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
656
+
657
+
658
@core.extern
def float2ull_ru(arg0, _semantic=None):
    """Apply extern ``__nv_float2ull_ru`` elementwise: fp32 argument, int64 result."""
    # NOTE(review): declared result dtype is signed int64 despite the "ull" name - confirm intended.
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float2ull_ru", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
663
+
664
+
665
@core.extern
def double2ll_rn(arg0, _semantic=None):
    """Apply extern ``__nv_double2ll_rn`` elementwise: fp64 argument, int64 result."""
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2ll_rn", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
670
+
671
+
672
@core.extern
def double2ll_rz(arg0, _semantic=None):
    """Apply extern ``__nv_double2ll_rz`` elementwise: fp64 argument, int64 result."""
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2ll_rz", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
677
+
678
+
679
@core.extern
def double2ll_rd(arg0, _semantic=None):
    """Apply extern ``__nv_double2ll_rd`` elementwise: fp64 argument, int64 result."""
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2ll_rd", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
684
+
685
+
686
@core.extern
def double2ll_ru(arg0, _semantic=None):
    """Apply extern ``__nv_double2ll_ru`` elementwise: fp64 argument, int64 result."""
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2ll_ru", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
691
+
692
+
693
@core.extern
def double2ull_rn(arg0, _semantic=None):
    """Apply extern ``__nv_double2ull_rn`` elementwise: fp64 argument, int64 result."""
    # NOTE(review): declared result dtype is signed int64 despite the "ull" name - confirm intended.
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2ull_rn", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
698
+
699
+
700
@core.extern
def double2ull_rz(arg0, _semantic=None):
    """Apply extern ``__nv_double2ull_rz`` elementwise: fp64 argument, int64 result."""
    # NOTE(review): declared result dtype is signed int64 despite the "ull" name - confirm intended.
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2ull_rz", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
705
+
706
+
707
@core.extern
def double2ull_rd(arg0, _semantic=None):
    """Apply extern ``__nv_double2ull_rd`` elementwise: fp64 argument, int64 result."""
    # NOTE(review): declared result dtype is signed int64 despite the "ull" name - confirm intended.
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2ull_rd", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
712
+
713
+
714
@core.extern
def double2ull_ru(arg0, _semantic=None):
    """Apply extern ``__nv_double2ull_ru`` elementwise: fp64 argument, int64 result."""
    # NOTE(review): declared result dtype is signed int64 despite the "ull" name - confirm intended.
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double2ull_ru", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
719
+
720
+
721
@core.extern
def ll2float_rn(arg0, _semantic=None):
    """Apply extern ``__nv_ll2float_rn`` elementwise: int64 argument, fp32 result."""
    signatures = {
        (core.dtype("int64"), ): ("__nv_ll2float_rn", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
726
+
727
+
728
@core.extern
def ll2float_rz(arg0, _semantic=None):
    """Apply extern ``__nv_ll2float_rz`` elementwise: int64 argument, fp32 result."""
    signatures = {
        (core.dtype("int64"), ): ("__nv_ll2float_rz", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
733
+
734
+
735
@core.extern
def ll2float_rd(arg0, _semantic=None):
    """Apply extern ``__nv_ll2float_rd`` elementwise: int64 argument, fp32 result."""
    signatures = {
        (core.dtype("int64"), ): ("__nv_ll2float_rd", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
740
+
741
+
742
@core.extern
def ll2float_ru(arg0, _semantic=None):
    """Apply extern ``__nv_ll2float_ru`` elementwise: int64 argument, fp32 result."""
    signatures = {
        (core.dtype("int64"), ): ("__nv_ll2float_ru", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
747
+
748
+
749
@core.extern
def ull2float_rn(arg0, _semantic=None):
    """Apply extern ``__nv_ull2float_rn`` elementwise: uint64 argument, fp32 result."""
    signatures = {
        (core.dtype("uint64"), ): ("__nv_ull2float_rn", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
754
+
755
+
756
@core.extern
def ull2float_rz(arg0, _semantic=None):
    """Apply extern ``__nv_ull2float_rz`` elementwise: uint64 argument, fp32 result."""
    signatures = {
        (core.dtype("uint64"), ): ("__nv_ull2float_rz", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
761
+
762
+
763
@core.extern
def ull2float_rd(arg0, _semantic=None):
    """Apply extern ``__nv_ull2float_rd`` elementwise: uint64 argument, fp32 result."""
    signatures = {
        (core.dtype("uint64"), ): ("__nv_ull2float_rd", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
768
+
769
+
770
@core.extern
def ull2float_ru(arg0, _semantic=None):
    """Apply extern ``__nv_ull2float_ru`` elementwise: uint64 argument, fp32 result."""
    signatures = {
        (core.dtype("uint64"), ): ("__nv_ull2float_ru", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
775
+
776
+
777
@core.extern
def ll2double_rn(arg0, _semantic=None):
    """Apply extern ``__nv_ll2double_rn`` elementwise: int64 argument, fp64 result."""
    signatures = {
        (core.dtype("int64"), ): ("__nv_ll2double_rn", core.dtype("fp64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
782
+
783
+
784
@core.extern
def ll2double_rz(arg0, _semantic=None):
    """Apply extern ``__nv_ll2double_rz`` elementwise: int64 argument, fp64 result."""
    signatures = {
        (core.dtype("int64"), ): ("__nv_ll2double_rz", core.dtype("fp64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
789
+
790
+
791
@core.extern
def ll2double_rd(arg0, _semantic=None):
    """Apply extern ``__nv_ll2double_rd`` elementwise: int64 argument, fp64 result."""
    signatures = {
        (core.dtype("int64"), ): ("__nv_ll2double_rd", core.dtype("fp64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
796
+
797
+
798
@core.extern
def ll2double_ru(arg0, _semantic=None):
    """Apply extern ``__nv_ll2double_ru`` elementwise: int64 argument, fp64 result."""
    signatures = {
        (core.dtype("int64"), ): ("__nv_ll2double_ru", core.dtype("fp64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
803
+
804
+
805
@core.extern
def ull2double_rn(arg0, _semantic=None):
    """Apply extern ``__nv_ull2double_rn`` elementwise: uint64 argument, fp64 result."""
    signatures = {
        (core.dtype("uint64"), ): ("__nv_ull2double_rn", core.dtype("fp64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
810
+
811
+
812
@core.extern
def ull2double_rz(arg0, _semantic=None):
    """Apply extern ``__nv_ull2double_rz`` elementwise: uint64 argument, fp64 result."""
    signatures = {
        (core.dtype("uint64"), ): ("__nv_ull2double_rz", core.dtype("fp64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
817
+
818
+
819
@core.extern
def ull2double_rd(arg0, _semantic=None):
    """Apply extern ``__nv_ull2double_rd`` elementwise: uint64 argument, fp64 result."""
    signatures = {
        (core.dtype("uint64"), ): ("__nv_ull2double_rd", core.dtype("fp64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
824
+
825
+
826
@core.extern
def ull2double_ru(arg0, _semantic=None):
    """Apply extern ``__nv_ull2double_ru`` elementwise: uint64 argument, fp64 result."""
    signatures = {
        (core.dtype("uint64"), ): ("__nv_ull2double_ru", core.dtype("fp64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
831
+
832
+
833
@core.extern
def int_as_float(arg0, _semantic=None):
    """Apply extern ``__nv_int_as_float`` elementwise: int32 argument, fp32 result."""
    signatures = {
        (core.dtype("int32"), ): ("__nv_int_as_float", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
838
+
839
+
840
@core.extern
def float_as_int(arg0, _semantic=None):
    """Apply extern ``__nv_float_as_int`` elementwise: fp32 argument, int32 result."""
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float_as_int", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
845
+
846
+
847
@core.extern
def uint_as_float(arg0, _semantic=None):
    """Apply extern ``__nv_uint_as_float`` elementwise: uint32 argument, fp32 result."""
    signatures = {
        (core.dtype("uint32"), ): ("__nv_uint_as_float", core.dtype("fp32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
852
+
853
+
854
@core.extern
def float_as_uint(arg0, _semantic=None):
    """Apply extern ``__nv_float_as_uint`` elementwise: fp32 argument, int32 result."""
    # NOTE(review): declared result dtype is signed int32 despite the "uint" name - confirm intended.
    signatures = {
        (core.dtype("fp32"), ): ("__nv_float_as_uint", core.dtype("int32")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
859
+
860
+
861
@core.extern
def longlong_as_double(arg0, _semantic=None):
    """Apply extern ``__nv_longlong_as_double`` elementwise: int64 argument, fp64 result."""
    signatures = {
        (core.dtype("int64"), ): ("__nv_longlong_as_double", core.dtype("fp64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
866
+
867
+
868
@core.extern
def double_as_longlong(arg0, _semantic=None):
    """Apply extern ``__nv_double_as_longlong`` elementwise: fp64 argument, int64 result."""
    signatures = {
        (core.dtype("fp64"), ): ("__nv_double_as_longlong", core.dtype("int64")),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
873
+
874
+
875
@core.extern
def fast_sinf(arg0, _semantic=None):
    """Apply extern ``__nv_fast_sinf`` elementwise: fp32 argument, fp32 result."""
    f32 = core.dtype("fp32")
    signatures = {(f32, ): ("__nv_fast_sinf", f32)}
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
880
+
881
+
882
@core.extern
def fast_cosf(arg0, _semantic=None):
    """Apply extern ``__nv_fast_cosf`` elementwise: fp32 argument, fp32 result."""
    f32 = core.dtype("fp32")
    signatures = {(f32, ): ("__nv_fast_cosf", f32)}
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
887
+
888
+
889
@core.extern
def fast_log2f(arg0, _semantic=None):
    """Apply extern ``__nv_fast_log2f`` elementwise: fp32 argument, fp32 result."""
    f32 = core.dtype("fp32")
    signatures = {(f32, ): ("__nv_fast_log2f", f32)}
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
894
+
895
+
896
@core.extern
def fast_logf(arg0, _semantic=None):
    """Apply extern ``__nv_fast_logf`` elementwise: fp32 argument, fp32 result."""
    f32 = core.dtype("fp32")
    signatures = {(f32, ): ("__nv_fast_logf", f32)}
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
901
+
902
+
903
@core.extern
def fast_expf(arg0, _semantic=None):
    """Apply extern ``__nv_fast_expf`` elementwise: fp32 argument, fp32 result."""
    f32 = core.dtype("fp32")
    signatures = {(f32, ): ("__nv_fast_expf", f32)}
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
908
+
909
+
910
@core.extern
def fast_tanf(arg0, _semantic=None):
    """Apply extern ``__nv_fast_tanf`` elementwise: fp32 argument, fp32 result."""
    f32 = core.dtype("fp32")
    signatures = {(f32, ): ("__nv_fast_tanf", f32)}
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
915
+
916
+
917
@core.extern
def fast_exp10f(arg0, _semantic=None):
    """Apply extern ``__nv_fast_exp10f`` elementwise: fp32 argument, fp32 result."""
    f32 = core.dtype("fp32")
    signatures = {(f32, ): ("__nv_fast_exp10f", f32)}
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
922
+
923
+
924
@core.extern
def fast_log10f(arg0, _semantic=None):
    """Apply extern ``__nv_fast_log10f`` elementwise: fp32 argument, fp32 result."""
    f32 = core.dtype("fp32")
    signatures = {(f32, ): ("__nv_fast_log10f", f32)}
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
929
+
930
+
931
@core.extern
def fast_powf(arg0, arg1, _semantic=None):
    """Apply extern ``__nv_fast_powf`` elementwise: two fp32 arguments, fp32 result."""
    f32 = core.dtype("fp32")
    signatures = {(f32, f32): ("__nv_fast_powf", f32)}
    return core.extern_elementwise("", "", [arg0, arg1], signatures, is_pure=True, _semantic=_semantic)
936
+
937
+
938
@core.extern
def hadd(arg0, arg1, _semantic=None):
    """Apply extern ``__nv_hadd`` (int32 pair) or ``__nv_uhadd`` (uint32 pair) elementwise; result dtype matches."""
    i32, u32 = core.dtype("int32"), core.dtype("uint32")
    signatures = {
        (i32, i32): ("__nv_hadd", i32),
        (u32, u32): ("__nv_uhadd", u32),
    }
    return core.extern_elementwise("", "", [arg0, arg1], signatures, is_pure=True, _semantic=_semantic)
945
+
946
+
947
@core.extern
def rhadd(arg0, arg1, _semantic=None):
    """Apply extern ``__nv_rhadd`` (int32 pair) or ``__nv_urhadd`` (uint32 pair) elementwise; result dtype matches."""
    i32, u32 = core.dtype("int32"), core.dtype("uint32")
    signatures = {
        (i32, i32): ("__nv_rhadd", i32),
        (u32, u32): ("__nv_urhadd", u32),
    }
    return core.extern_elementwise("", "", [arg0, arg1], signatures, is_pure=True, _semantic=_semantic)
954
+
955
+
956
@core.extern
def sub_rn(arg0, arg1, _semantic=None):
    """Apply extern ``__nv_fsub_rn`` (fp32 pair) or ``__nv_dsub_rn`` (fp64 pair) elementwise; result dtype matches."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, f32): ("__nv_fsub_rn", f32),
        (f64, f64): ("__nv_dsub_rn", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], signatures, is_pure=True, _semantic=_semantic)
963
+
964
+
965
@core.extern
def sub_rz(arg0, arg1, _semantic=None):
    """Apply extern ``__nv_fsub_rz`` (fp32 pair) or ``__nv_dsub_rz`` (fp64 pair) elementwise; result dtype matches."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, f32): ("__nv_fsub_rz", f32),
        (f64, f64): ("__nv_dsub_rz", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], signatures, is_pure=True, _semantic=_semantic)
972
+
973
+
974
@core.extern
def sub_rd(arg0, arg1, _semantic=None):
    """Apply extern ``__nv_fsub_rd`` (fp32 pair) or ``__nv_dsub_rd`` (fp64 pair) elementwise; result dtype matches."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, f32): ("__nv_fsub_rd", f32),
        (f64, f64): ("__nv_dsub_rd", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], signatures, is_pure=True, _semantic=_semantic)
981
+
982
+
983
@core.extern
def sub_ru(arg0, arg1, _semantic=None):
    """Apply extern ``__nv_fsub_ru`` (fp32 pair) or ``__nv_dsub_ru`` (fp64 pair) elementwise; result dtype matches."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, f32): ("__nv_fsub_ru", f32),
        (f64, f64): ("__nv_dsub_ru", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], signatures, is_pure=True, _semantic=_semantic)
990
+
991
+
992
@core.extern
def rsqrt_rn(arg0, _semantic=None):
    """Apply extern ``__nv_frsqrt_rn`` elementwise: fp32 argument, fp32 result."""
    f32 = core.dtype("fp32")
    signatures = {(f32, ): ("__nv_frsqrt_rn", f32)}
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
999
+
1000
+
1001
@core.extern
def ffs(arg0, _semantic=None):
    """Apply extern ``__nv_ffs`` (int32) or ``__nv_ffsll`` (int64) elementwise; result is int32 in both cases."""
    i32 = core.dtype("int32")
    signatures = {
        (i32, ): ("__nv_ffs", i32),
        (core.dtype("int64"), ): ("__nv_ffsll", i32),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1010
+
1011
+
1012
@core.extern
def rint(arg0, _semantic=None):
    """Apply extern ``__nv_rintf`` (fp32) or ``__nv_rint`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_rintf", f32),
        (f64, ): ("__nv_rint", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1021
+
1022
+
1023
@core.extern
def llrint(arg0, _semantic=None):
    """Apply extern ``__nv_llrintf`` (fp32) or ``__nv_llrint`` (fp64) elementwise; result is int64 in both cases."""
    i64 = core.dtype("int64")
    signatures = {
        (core.dtype("fp32"), ): ("__nv_llrintf", i64),
        (core.dtype("fp64"), ): ("__nv_llrint", i64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1032
+
1033
+
1034
@core.extern
def nearbyint(arg0, _semantic=None):
    """Apply extern ``__nv_nearbyintf`` (fp32) or ``__nv_nearbyint`` (fp64) elementwise; result dtype matches."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_nearbyintf", f32),
        (f64, ): ("__nv_nearbyint", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1043
+
1044
+
1045
@core.extern
def isnan(arg0, _semantic=None):
    """Apply extern ``__nv_isnanf`` (fp32) or ``__nv_isnand`` (fp64) elementwise and cast the int32 result to int1."""
    i32 = core.dtype("int32")
    signatures = {
        (core.dtype("fp32"), ): ("__nv_isnanf", i32),
        (core.dtype("fp64"), ): ("__nv_isnand", i32),
    }
    flags = core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
    return flags.to(core.int1, _semantic=_semantic)
1054
+
1055
+
1056
@core.extern
def signbit(arg0, _semantic=None):
    """Apply extern ``__nv_signbitf`` (fp32) or ``__nv_signbitd`` (fp64) elementwise; result is int32."""
    # NOTE(review): unlike isnan/isinf/finitef in this file, the result is not cast to int1 - confirm intended.
    i32 = core.dtype("int32")
    signatures = {
        (core.dtype("fp32"), ): ("__nv_signbitf", i32),
        (core.dtype("fp64"), ): ("__nv_signbitd", i32),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1065
+
1066
+
1067
@core.extern
def copysign(arg0, arg1, _semantic=None):
    """Apply extern ``__nv_copysignf`` (fp32 pair) or ``__nv_copysign`` (fp64 pair) elementwise; result dtype matches."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, f32): ("__nv_copysignf", f32),
        (f64, f64): ("__nv_copysign", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], signatures, is_pure=True, _semantic=_semantic)
1074
+
1075
+
1076
@core.extern
def finitef(arg0, _semantic=None):
    """Apply extern ``__nv_finitef`` elementwise to an fp32 argument and cast the int32 result to int1."""
    signatures = {
        (core.dtype("fp32"), ): ("__nv_finitef", core.dtype("int32")),
    }
    flags = core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
    return flags.to(core.int1, _semantic=_semantic)
1081
+
1082
+
1083
@core.extern
def isinf(arg0, _semantic=None):
    """Apply extern ``__nv_isinff`` (fp32) or ``__nv_isinfd`` (fp64) elementwise and cast the int32 result to int1."""
    i32 = core.dtype("int32")
    signatures = {
        (core.dtype("fp32"), ): ("__nv_isinff", i32),
        (core.dtype("fp64"), ): ("__nv_isinfd", i32),
    }
    flags = core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
    return flags.to(core.int1, _semantic=_semantic)
1090
+
1091
+
1092
@core.extern
def nextafter(arg0, arg1, _semantic=None):
    """Apply extern ``__nv_nextafterf`` (fp32 pair) or ``__nv_nextafter`` (fp64 pair) elementwise; result dtype matches."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, f32): ("__nv_nextafterf", f32),
        (f64, f64): ("__nv_nextafter", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], signatures, is_pure=True, _semantic=_semantic)
1099
+
1100
+
1101
@core.extern
def sin(arg0, _semantic=None):
    """Apply extern ``__nv_sinf`` (fp32) or ``__nv_sin`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_sinf", f32),
        (f64, ): ("__nv_sin", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1108
+
1109
+
1110
@core.extern
def cos(arg0, _semantic=None):
    """Apply extern ``__nv_cosf`` (fp32) or ``__nv_cos`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_cosf", f32),
        (f64, ): ("__nv_cos", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1117
+
1118
+
1119
@core.extern
def sinpi(arg0, _semantic=None):
    """Apply extern ``__nv_sinpif`` (fp32) or ``__nv_sinpi`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_sinpif", f32),
        (f64, ): ("__nv_sinpi", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1126
+
1127
+
1128
@core.extern
def cospi(arg0, _semantic=None):
    """Apply extern ``__nv_cospif`` (fp32) or ``__nv_cospi`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_cospif", f32),
        (f64, ): ("__nv_cospi", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1135
+
1136
+
1137
@core.extern
def tan(arg0, _semantic=None):
    """Apply extern ``__nv_tanf`` (fp32) or ``__nv_tan`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_tanf", f32),
        (f64, ): ("__nv_tan", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1144
+
1145
+
1146
@core.extern
def log2(arg0, _semantic=None):
    """Apply extern ``__nv_log2f`` (fp32) or ``__nv_log2`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_log2f", f32),
        (f64, ): ("__nv_log2", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1153
+
1154
+
1155
@core.extern
def exp(arg0, _semantic=None):
    """Apply extern ``__nv_expf`` (fp32) or ``__nv_exp`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_expf", f32),
        (f64, ): ("__nv_exp", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1162
+
1163
+
1164
@core.extern
def exp10(arg0, _semantic=None):
    """Apply extern ``__nv_exp10f`` (fp32) or ``__nv_exp10`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_exp10f", f32),
        (f64, ): ("__nv_exp10", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1171
+
1172
+
1173
@core.extern
def cosh(arg0, _semantic=None):
    """Apply extern ``__nv_coshf`` (fp32) or ``__nv_cosh`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_coshf", f32),
        (f64, ): ("__nv_cosh", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1180
+
1181
+
1182
@core.extern
def sinh(arg0, _semantic=None):
    """Apply extern ``__nv_sinhf`` (fp32) or ``__nv_sinh`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_sinhf", f32),
        (f64, ): ("__nv_sinh", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1189
+
1190
+
1191
@core.extern
def tanh(arg0, _semantic=None):
    """Apply extern ``__nv_tanhf`` (fp32) or ``__nv_tanh`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_tanhf", f32),
        (f64, ): ("__nv_tanh", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1198
+
1199
+
1200
@core.extern
def atan2(arg0, arg1, _semantic=None):
    """Apply extern ``__nv_atan2f`` (fp32 pair) or ``__nv_atan2`` (fp64 pair) elementwise; result dtype matches."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, f32): ("__nv_atan2f", f32),
        (f64, f64): ("__nv_atan2", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], signatures, is_pure=True, _semantic=_semantic)
1207
+
1208
+
1209
@core.extern
def atan(arg0, _semantic=None):
    """Apply extern ``__nv_atanf`` (fp32) or ``__nv_atan`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_atanf", f32),
        (f64, ): ("__nv_atan", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1216
+
1217
+
1218
@core.extern
def asin(arg0, _semantic=None):
    """Apply extern ``__nv_asinf`` (fp32) or ``__nv_asin`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_asinf", f32),
        (f64, ): ("__nv_asin", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1225
+
1226
+
1227
@core.extern
def acos(arg0, _semantic=None):
    """Apply extern ``__nv_acosf`` (fp32) or ``__nv_acos`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_acosf", f32),
        (f64, ): ("__nv_acos", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1234
+
1235
+
1236
@core.extern
def log(arg0, _semantic=None):
    """Apply extern ``__nv_logf`` (fp32) or ``__nv_log`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_logf", f32),
        (f64, ): ("__nv_log", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1243
+
1244
+
1245
@core.extern
def log10(arg0, _semantic=None):
    """Apply extern ``__nv_log10f`` (fp32) or ``__nv_log10`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_log10f", f32),
        (f64, ): ("__nv_log10", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1252
+
1253
+
1254
@core.extern
def log1p(arg0, _semantic=None):
    """Apply extern ``__nv_log1pf`` (fp32) or ``__nv_log1p`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_log1pf", f32),
        (f64, ): ("__nv_log1p", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1261
+
1262
+
1263
@core.extern
def acosh(arg0, _semantic=None):
    """Apply extern ``__nv_acoshf`` (fp32) or ``__nv_acosh`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_acoshf", f32),
        (f64, ): ("__nv_acosh", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1270
+
1271
+
1272
@core.extern
def asinh(arg0, _semantic=None):
    """Apply extern ``__nv_asinhf`` (fp32) or ``__nv_asinh`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_asinhf", f32),
        (f64, ): ("__nv_asinh", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1279
+
1280
+
1281
@core.extern
def atanh(arg0, _semantic=None):
    """Apply extern ``__nv_atanhf`` (fp32) or ``__nv_atanh`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_atanhf", f32),
        (f64, ): ("__nv_atanh", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1288
+
1289
+
1290
@core.extern
def expm1(arg0, _semantic=None):
    """Apply extern ``__nv_expm1f`` (fp32) or ``__nv_expm1`` (fp64) elementwise; result dtype matches the input."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, ): ("__nv_expm1f", f32),
        (f64, ): ("__nv_expm1", f64),
    }
    return core.extern_elementwise("", "", [arg0], signatures, is_pure=True, _semantic=_semantic)
1297
+
1298
+
1299
@core.extern
def hypot(arg0, arg1, _semantic=None):
    """Apply extern ``__nv_hypotf`` (fp32 pair) or ``__nv_hypot`` (fp64 pair) elementwise; result dtype matches."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, f32): ("__nv_hypotf", f32),
        (f64, f64): ("__nv_hypot", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], signatures, is_pure=True, _semantic=_semantic)
1306
+
1307
+
1308
@core.extern
def rhypot(arg0, arg1, _semantic=None):
    """Apply extern ``__nv_rhypotf`` (fp32 pair) or ``__nv_rhypot`` (fp64 pair) elementwise; result dtype matches."""
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    signatures = {
        (f32, f32): ("__nv_rhypotf", f32),
        (f64, f64): ("__nv_rhypot", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], signatures, is_pure=True, _semantic=_semantic)
1315
+
1316
+
1317
@core.extern
def norm3d(arg0, arg1, arg2, _semantic=None):
    """Three-operand extern call: fp32 -> ``__nv_norm3df``, fp64 -> ``__nv_norm3d``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # All three operands must share the same floating-point dtype.
    table = {
        (f32, f32, f32): ("__nv_norm3df", f32),
        (f64, f64, f64): ("__nv_norm3d", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1, arg2], table, is_pure=True, _semantic=_semantic)
1324
+
1325
+
1326
@core.extern
def rnorm3d(arg0, arg1, arg2, _semantic=None):
    """Three-operand extern call: fp32 -> ``__nv_rnorm3df``, fp64 -> ``__nv_rnorm3d``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # All three operands must share the same floating-point dtype.
    table = {
        (f32, f32, f32): ("__nv_rnorm3df", f32),
        (f64, f64, f64): ("__nv_rnorm3d", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1, arg2], table, is_pure=True, _semantic=_semantic)
1333
+
1334
+
1335
@core.extern
def norm4d(arg0, arg1, arg2, arg3, _semantic=None):
    """Four-operand extern call: fp32 -> ``__nv_norm4df``, fp64 -> ``__nv_norm4d``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # All four operands must share the same floating-point dtype.
    table = {
        (f32, f32, f32, f32): ("__nv_norm4df", f32),
        (f64, f64, f64, f64): ("__nv_norm4d", f64),
    }
    operands = [arg0, arg1, arg2, arg3]
    return core.extern_elementwise("", "", operands, table, is_pure=True, _semantic=_semantic)
1344
+
1345
+
1346
@core.extern
def rnorm4d(arg0, arg1, arg2, arg3, _semantic=None):
    """Four-operand extern call: fp32 -> ``__nv_rnorm4df``, fp64 -> ``__nv_rnorm4d``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # All four operands must share the same floating-point dtype.
    table = {
        (f32, f32, f32, f32): ("__nv_rnorm4df", f32),
        (f64, f64, f64, f64): ("__nv_rnorm4d", f64),
    }
    operands = [arg0, arg1, arg2, arg3]
    return core.extern_elementwise("", "", operands, table, is_pure=True, _semantic=_semantic)
1355
+
1356
+
1357
@core.extern
def cbrt(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_cbrtf``, fp64 -> ``__nv_cbrt``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_cbrtf", f32),
        (f64, ): ("__nv_cbrt", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1364
+
1365
+
1366
@core.extern
def rcbrt(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_rcbrtf``, fp64 -> ``__nv_rcbrt``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_rcbrtf", f32),
        (f64, ): ("__nv_rcbrt", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1373
+
1374
+
1375
@core.extern
def j0(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_j0f``, fp64 -> ``__nv_j0``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_j0f", f32),
        (f64, ): ("__nv_j0", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1381
+
1382
+
1383
@core.extern
def j1(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_j1f``, fp64 -> ``__nv_j1``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_j1f", f32),
        (f64, ): ("__nv_j1", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1389
+
1390
+
1391
@core.extern
def y0(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_y0f``, fp64 -> ``__nv_y0``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_y0f", f32),
        (f64, ): ("__nv_y0", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1397
+
1398
+
1399
@core.extern
def y1(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_y1f``, fp64 -> ``__nv_y1``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_y1f", f32),
        (f64, ): ("__nv_y1", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1405
+
1406
+
1407
@core.extern
def yn(arg0, arg1, _semantic=None):
    """Extern call taking (int32, float): fp32 -> ``__nv_ynf``, fp64 -> ``__nv_yn``."""
    i32 = core.dtype("int32")
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # The first operand is int32 in both entries; the result dtype follows the second.
    table = {
        (i32, f32): ("__nv_ynf", f32),
        (i32, f64): ("__nv_yn", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], table, is_pure=True, _semantic=_semantic)
1414
+
1415
+
1416
@core.extern
def jn(arg0, arg1, _semantic=None):
    """Extern call taking (int32, float): fp32 -> ``__nv_jnf``, fp64 -> ``__nv_jn``."""
    i32 = core.dtype("int32")
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # The first operand is int32 in both entries; the result dtype follows the second.
    table = {
        (i32, f32): ("__nv_jnf", f32),
        (i32, f64): ("__nv_jn", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], table, is_pure=True, _semantic=_semantic)
1423
+
1424
+
1425
@core.extern
def cyl_bessel_i0(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_cyl_bessel_i0f``, fp64 -> ``__nv_cyl_bessel_i0``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_cyl_bessel_i0f", f32),
        (f64, ): ("__nv_cyl_bessel_i0", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1432
+
1433
+
1434
@core.extern
def cyl_bessel_i1(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_cyl_bessel_i1f``, fp64 -> ``__nv_cyl_bessel_i1``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_cyl_bessel_i1f", f32),
        (f64, ): ("__nv_cyl_bessel_i1", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1441
+
1442
+
1443
@core.extern
def erf(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_erff``, fp64 -> ``__nv_erf``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_erff", f32),
        (f64, ): ("__nv_erf", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1450
+
1451
+
1452
@core.extern
def erfinv(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_erfinvf``, fp64 -> ``__nv_erfinv``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_erfinvf", f32),
        (f64, ): ("__nv_erfinv", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1459
+
1460
+
1461
@core.extern
def erfc(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_erfcf``, fp64 -> ``__nv_erfc``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_erfcf", f32),
        (f64, ): ("__nv_erfc", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1468
+
1469
+
1470
@core.extern
def erfcx(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_erfcxf``, fp64 -> ``__nv_erfcx``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_erfcxf", f32),
        (f64, ): ("__nv_erfcx", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1477
+
1478
+
1479
@core.extern
def erfcinv(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_erfcinvf``, fp64 -> ``__nv_erfcinv``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_erfcinvf", f32),
        (f64, ): ("__nv_erfcinv", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1486
+
1487
+
1488
@core.extern
def normcdfinv(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_normcdfinvf``, fp64 -> ``__nv_normcdfinv``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_normcdfinvf", f32),
        (f64, ): ("__nv_normcdfinv", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1495
+
1496
+
1497
@core.extern
def normcdf(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_normcdff``, fp64 -> ``__nv_normcdf``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_normcdff", f32),
        (f64, ): ("__nv_normcdf", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1504
+
1505
+
1506
@core.extern
def lgamma(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_lgammaf``, fp64 -> ``__nv_lgamma``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_lgammaf", f32),
        (f64, ): ("__nv_lgamma", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1513
+
1514
+
1515
@core.extern
def ldexp(arg0, arg1, _semantic=None):
    """Extern call taking (float, int32): fp32 -> ``__nv_ldexpf``, fp64 -> ``__nv_ldexp``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    i32 = core.dtype("int32")
    # The second operand is int32 in both entries; the result dtype follows the first.
    table = {
        (f32, i32): ("__nv_ldexpf", f32),
        (f64, i32): ("__nv_ldexp", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], table, is_pure=True, _semantic=_semantic)
1522
+
1523
+
1524
@core.extern
def scalbn(arg0, arg1, _semantic=None):
    """Extern call taking (float, int32): fp32 -> ``__nv_scalbnf``, fp64 -> ``__nv_scalbn``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    i32 = core.dtype("int32")
    # The second operand is int32 in both entries; the result dtype follows the first.
    table = {
        (f32, i32): ("__nv_scalbnf", f32),
        (f64, i32): ("__nv_scalbn", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], table, is_pure=True, _semantic=_semantic)
1531
+
1532
+
1533
@core.extern
def fmod(arg0, arg1, _semantic=None):
    """Two-operand extern call: fp32 pair -> ``__nv_fmodf``, fp64 pair -> ``__nv_fmod``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Both operands must share the same floating-point dtype to match an entry.
    table = {
        (f32, f32): ("__nv_fmodf", f32),
        (f64, f64): ("__nv_fmod", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], table, is_pure=True, _semantic=_semantic)
1540
+
1541
+
1542
@core.extern
def remainder(arg0, arg1, _semantic=None):
    """Two-operand extern call: fp32 pair -> ``__nv_remainderf``, fp64 pair -> ``__nv_remainder``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Both operands must share the same floating-point dtype to match an entry.
    table = {
        (f32, f32): ("__nv_remainderf", f32),
        (f64, f64): ("__nv_remainder", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], table, is_pure=True, _semantic=_semantic)
1549
+
1550
+
1551
@core.extern
def fma(arg0, arg1, arg2, _semantic=None):
    """Three-operand extern call: fp32 -> ``__nv_fmaf``, fp64 -> ``__nv_fma``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # All three operands must share the same floating-point dtype.
    table = {
        (f32, f32, f32): ("__nv_fmaf", f32),
        (f64, f64, f64): ("__nv_fma", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1, arg2], table, is_pure=True, _semantic=_semantic)
1558
+
1559
+
1560
@core.extern
def pow(arg0, arg1, _semantic=None):
    """Two-operand extern call with separate entries for int32 and float second operands.

    (fp32, int32) -> ``__nv_powif``, (fp64, int32) -> ``__nv_powi``,
    (fp32, fp32) -> ``__nv_powf``, (fp64, fp64) -> ``__nv_pow``.
    """
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    i32 = core.dtype("int32")
    # The result dtype always matches the first operand's dtype.
    table = {
        (f32, i32): ("__nv_powif", f32),
        (f64, i32): ("__nv_powi", f64),
        (f32, f32): ("__nv_powf", f32),
        (f64, f64): ("__nv_pow", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], table, is_pure=True, _semantic=_semantic)
1569
+
1570
+
1571
@core.extern
def tgamma(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_tgammaf``, fp64 -> ``__nv_tgamma``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_tgammaf", f32),
        (f64, ): ("__nv_tgamma", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1578
+
1579
+
1580
@core.extern
def round(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_roundf``, fp64 -> ``__nv_round``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_roundf", f32),
        (f64, ): ("__nv_round", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1587
+
1588
+
1589
@core.extern
def llround(arg0, _semantic=None):
    """Element-wise extern call returning int64: fp32 -> ``__nv_llroundf``, fp64 -> ``__nv_llround``."""
    i64 = core.dtype("int64")
    # Both entries produce int64 regardless of the input precision.
    table = {
        (core.dtype("fp32"), ): ("__nv_llroundf", i64),
        (core.dtype("fp64"), ): ("__nv_llround", i64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1596
+
1597
+
1598
@core.extern
def fdim(arg0, arg1, _semantic=None):
    """Two-operand extern call: fp32 pair -> ``__nv_fdimf``, fp64 pair -> ``__nv_fdim``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Both operands must share the same floating-point dtype to match an entry.
    table = {
        (f32, f32): ("__nv_fdimf", f32),
        (f64, f64): ("__nv_fdim", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], table, is_pure=True, _semantic=_semantic)
1605
+
1606
+
1607
@core.extern
def ilogb(arg0, _semantic=None):
    """Element-wise extern call returning int32: fp32 -> ``__nv_ilogbf``, fp64 -> ``__nv_ilogb``."""
    i32 = core.dtype("int32")
    # Both entries produce int32 regardless of the input precision.
    table = {
        (core.dtype("fp32"), ): ("__nv_ilogbf", i32),
        (core.dtype("fp64"), ): ("__nv_ilogb", i32),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1614
+
1615
+
1616
@core.extern
def logb(arg0, _semantic=None):
    """Element-wise extern call: fp32 -> ``__nv_logbf``, fp64 -> ``__nv_logb``."""
    f32 = core.dtype("fp32")
    f64 = core.dtype("fp64")
    # Maps the argument dtype tuple to (extern symbol, result dtype).
    table = {
        (f32, ): ("__nv_logbf", f32),
        (f64, ): ("__nv_logb", f64),
    }
    return core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
1623
+
1624
+
1625
@core.extern
def isfinited(arg0, _semantic=None):
    """Call ``__nv_isfinited`` on an fp64 operand and cast its int32 result to ``core.int1``.

    Only an fp64 entry is registered in the dispatch table.
    """
    table = {
        (core.dtype("fp64"), ): ("__nv_isfinited", core.dtype("int32")),
    }
    raw = core.extern_elementwise("", "", [arg0], table, is_pure=True, _semantic=_semantic)
    # The extern returns int32; narrow it to a 1-bit value for callers.
    return raw.to(core.int1, _semantic=_semantic)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/cuda/utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from triton.language import core
2
+
3
+
4
@core.extern
def globaltimer(_semantic=None):
    """Emit inline asm that moves ``%globaltimer`` into an int64 result.

    The op is registered with ``is_pure=False`` and takes no input operands.
    """
    ptx = "mov.u64 $0, %globaltimer;"
    return core.inline_asm_elementwise(
        ptx,
        "=l",
        [],
        dtype=core.int64,
        is_pure=False,
        pack=1,
        _semantic=_semantic,
    )
8
+
9
+
10
@core.extern
def smid(_semantic=None):
    """Emit inline asm that moves ``%smid`` into an int32 result.

    The op is registered with ``is_pure=True`` and takes no input operands.
    """
    ptx = "mov.u32 $0, %smid;"
    return core.inline_asm_elementwise(
        ptx,
        "=r",
        [],
        dtype=core.int32,
        is_pure=True,
        pack=1,
        _semantic=_semantic,
    )
14
+
15
+
16
@core.builtin
def num_threads(_semantic=None):
    """Return ``num_warps * 32`` (from the builder options) wrapped in ``core.constexpr``."""
    warps = _semantic.builder.options.num_warps
    return core.constexpr(warps * 32)
19
+
20
+
21
@core.builtin
def num_warps(_semantic=None):
    """Return the builder option ``num_warps`` wrapped in ``core.constexpr``."""
    options = _semantic.builder.options
    return core.constexpr(options.num_warps)
24
+
25
+
26
+ # ----- FP8E4M3B15 ------
27
+ # This data-type is a variant of the standard FP8E4M3 format.
28
+ # It was designed for fast software conversion to FP16 on
29
+ # nvidia GPUs that do not support it natively.
30
+ # This is the same format as FP8E4M3Nv, but:
31
+ # - the exponent bias is 15 instead of 7
32
+ # - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
33
@core.builtin
def convert_fp8e4b15_to_float16(arg, _semantic=None):
    """Convert ``arg`` to ``core.float16`` with an inline-asm sequence, four elements per pack.

    The asm reads one 32-bit input register (``$2``) and writes two 32-bit
    output registers (``$0``, ``$1``), per the ``"=r,=r,r"`` constraint string.
    """
    # Adjacent literals are concatenated into one asm string; keep the order intact.
    ptx = ("{ \n"
           ".reg .b32 a<2>, b<2>; \n"
           "prmt.b32 a0, 0, $2, 0x5746; \n"
           "and.b32 b0, a0, 0x7f007f00; \n"
           "and.b32 b1, a0, 0x00ff00ff; \n"
           "and.b32 a1, a0, 0x00800080; \n"
           "shr.b32 b0, b0, 1; \n"
           "add.u32 b1, b1, a1; \n"
           "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
           "shl.b32 $1, b1, 7; \n"
           "} \n")
    constraints = "=r,=r,r"
    return core.inline_asm_elementwise(ptx, constraints, [arg], dtype=core.float16, is_pure=True, pack=4,
                                       _semantic=_semantic)
48
+
49
+
50
@core.builtin
def convert_float16_to_fp8e4b15(arg, has_minx2, _semantic=None):
    """Convert ``arg`` to ``core.float8e4b15`` with an inline-asm sequence, four elements per pack.

    The asm is assembled from three pieces: a shared prologue, a clamp step
    chosen by ``has_minx2``, and a shared epilogue. When ``has_minx2`` is
    truthy the clamp uses ``min.f16x2``; otherwise it uses a
    ``setp``/``selp`` sequence against the same ``max_val_f16`` constants.
    """
    prologue = """{
.reg .pred p<4>;
.reg .b32 a<2>, b<2>;
.reg .b16 c<4>;
.reg .b16 max_val_f16;
.reg .b32 max_val_f16x2;
mov.b16 max_val_f16, 0x3F00;
mov.b32 max_val_f16x2, 0x3F003F00;
and.b32 a0, $1, 0x7fff7fff;
and.b32 a1, $2, 0x7fff7fff;"""
    clamp_with_min = """min.f16x2 a0, a0, max_val_f16x2;
min.f16x2 a1, a1, max_val_f16x2;"""
    clamp_with_select = """setp.lt.f16x2 p0|p1, a0, max_val_f16x2;
setp.lt.f16x2 p2|p3, a1, max_val_f16x2;
mov.b32 {c0, c1}, a0;
mov.b32 {c2, c3}, a1;
selp.b16 c0, c0, max_val_f16, p0;
selp.b16 c1, c1, max_val_f16, p1;
selp.b16 c2, c2, max_val_f16, p2;
selp.b16 c3, c3, max_val_f16, p3;
mov.b32 a0, {c0, c1};
mov.b32 a1, {c2, c3};"""
    epilogue = """mad.lo.u32 a0, a0, 2, 0x00800080;
mad.lo.u32 a1, a1, 2, 0x00800080;
lop3.b32 b0, $1, 0x80008000, a0, 0xea;
lop3.b32 b1, $2, 0x80008000, a1, 0xea;
prmt.b32 $0, b0, b1, 0x7531;
}"""
    # Pieces are joined with no separator, exactly as the incremental `+=` build did.
    clamp = clamp_with_min if has_minx2 else clamp_with_select
    asm = prologue + clamp + epilogue
    return core.inline_asm_elementwise(asm, "=r,r,r", [arg], dtype=core.float8e4b15, is_pure=True, pack=4,
                                       _semantic=_semantic)
84
+
85
+
86
@core.builtin
def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _semantic=None):
    """Convert between fp8e4b15 and fp16/fp32 in either direction.

    If ``arg`` is fp8e4b15 it is widened to float16, and further to float32
    when ``dst_ty`` is fp32. Otherwise ``arg`` must be fp16 or fp32 (asserted);
    fp32 inputs are first narrowed to float16 with "rtz" rounding, then the
    float16 value is converted to fp8e4b15.

    ``fp_downcast_rounding`` is accepted but not read by this function.
    ``has_minx2`` is forwarded to ``convert_float16_to_fp8e4b15``.
    """
    src_scalar = arg.type.scalar

    # Upcast path: fp8e4b15 -> fp16 (-> fp32 when requested).
    if src_scalar.is_fp8e4b15():
        widened = convert_fp8e4b15_to_float16(arg, _semantic=_semantic)
        if not dst_ty.scalar.is_fp32():
            return widened
        return widened.to(core.float32, _semantic=_semantic)

    # Downcast path: fp32 -> fp16 (rtz) when needed, then fp16 -> fp8e4b15.
    assert src_scalar.is_fp16() or src_scalar.is_fp32()
    if src_scalar.is_fp32():
        half = arg.to(core.float16, fp_downcast_rounding="rtz", _semantic=_semantic)
    else:
        half = arg
    return convert_float16_to_fp8e4b15(half, has_minx2=has_minx2, _semantic=_semantic)
100
+
101
+
102
@core.builtin
def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _semantic=None):
    """Call ``convert_custom_float8`` with ``has_minx2=True``."""
    converted = convert_custom_float8(
        arg,
        dst_ty,
        fp_downcast_rounding,
        has_minx2=True,
        _semantic=_semantic,
    )
    return converted
105
+
106
+
107
@core.builtin
def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _semantic=None):
    """Call ``convert_custom_float8`` with ``has_minx2=False``."""
    converted = convert_custom_float8(
        arg,
        dst_ty,
        fp_downcast_rounding,
        has_minx2=False,
        _semantic=_semantic,
    )
    return converted
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/hip/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from . import libdevice
2
+
3
+ from .utils import memrealtime
4
+
5
+ __all__ = ["libdevice", "memrealtime"]
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/triton/language/extra/hip/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (347 Bytes). View file