koichi12 commited on
Commit
a378ef8
·
verified ·
1 Parent(s): ee1d2ef

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cuda.cpython-311.pyc +0 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-311.pyc +0 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-311.pyc +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py +258 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton.py +0 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-311.pyc +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-311.pyc +0 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-311.pyc +0 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/quantization.cpython-311.pyc +0 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-311.pyc +0 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-311.pyc +0 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-311.pyc +0 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-311.pyc +0 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-311.pyc +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-311.pyc +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-311.pyc +0 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-311.pyc +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-311.pyc +0 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py +218 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py +233 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__init__.py +1 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-311.pyc +0 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-311.pyc +0 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py +82 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py +1412 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/memory.cpython-311.pyc +0 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/_sanitizer.py +622 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-311.pyc +0 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/comm.py +18 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/memory.py +914 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_compatibility.cpython-311.pyc +0 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-311.pyc +0 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/annotate.cpython-311.pyc +0 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/config.cpython-311.pyc +0 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/node.cpython-311.pyc +0 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/traceback.cpython-311.pyc +0 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-311.pyc +0 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-311.pyc +0 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-311.pyc +0 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-311.pyc +0 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-311.pyc +0 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-311.pyc +0 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-311.pyc +0 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/cudagraphs.py +56 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-311.pyc +0 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-311.pyc +0 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-311.pyc +0 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/test_pass_manager.py +58 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__init__.py +1 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cuda.cpython-311.pyc ADDED
Binary file (18.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (229 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-311.pyc ADDED
Binary file (1.52 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import logging
3
+ import os
4
+ import sys
5
+ from dataclasses import dataclass
6
+ from typing import Any, List, Optional
7
+
8
+ import sympy
9
+
10
+ import torch
11
+
12
+ from ...codecache import cache_dir
13
+ from ...config import cuda as inductor_cuda_config
14
+ from ...ir import Layout
15
+ from .cuda_env import get_cuda_arch, get_cuda_version
16
+
17
+ log = logging.getLogger(__name__)
18
+
19
+
20
+ def _rename_cutlass_import(content: str, cutlass_modules: List[str]) -> str:
21
+ for cutlass_module in cutlass_modules:
22
+ content = content.replace(
23
+ f"from {cutlass_module} import ",
24
+ f"from cutlass_library.{cutlass_module} import ",
25
+ )
26
+ return content
27
+
28
+
29
def _gen_cutlass_file(
    file_name: str, cutlass_modules: List[str], src_dir: str, dst_dir: str
) -> None:
    """Copy *file_name* from *src_dir* to *dst_dir*, rewriting CUTLASS imports.

    The copied file has its ``from <module> import`` statements redirected to
    the ``cutlass_library`` package via :func:`_rename_cutlass_import`.
    The destination file is overwritten if it already exists.
    """
    src_path = os.path.abspath(os.path.join(src_dir, file_name))
    dst_path = os.path.abspath(os.path.join(dst_dir, file_name))
    with open(src_path) as src_file:
        rewritten = _rename_cutlass_import(src_file.read(), cutlass_modules)
    with open(dst_path, "w") as dst_file:
        dst_file.write(rewritten)
45
+
46
+
47
@functools.lru_cache(None)
def try_import_cutlass() -> bool:
    """Make the CUTLASS Python scripts importable as ``cutlass_library``.

    Symlinks the configured CUTLASS checkout's ``python/cutlass_library``
    directory into the inductor cache dir, adds that dir to ``sys.path``, and
    attempts the imports.  Returns True on success, False otherwise (missing
    repo or failed import).  The result is cached for the process lifetime.
    """
    # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
    # This is a temporary hack to avoid CUTLASS module naming conflicts.
    # TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.

    cutlass_py_full_path = os.path.abspath(
        os.path.join(inductor_cuda_config.cutlass_dir, "python/cutlass_library")
    )
    tmp_cutlass_py_full_path = os.path.abspath(
        os.path.join(cache_dir(), "torch_cutlass_library")
    )
    # The symlink that makes the checkout importable under the name
    # "cutlass_library".
    dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")

    if os.path.isdir(cutlass_py_full_path):
        if tmp_cutlass_py_full_path not in sys.path:
            if os.path.exists(dst_link):
                # A previous run created the link; verify it is still a symlink
                # pointing at the currently configured CUTLASS checkout.
                assert os.path.islink(
                    dst_link
                ), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
                assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
                    cutlass_py_full_path
                ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
            else:
                os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
                os.symlink(cutlass_py_full_path, dst_link)
            sys.path.append(tmp_cutlass_py_full_path)
        try:
            import cutlass_library.generator  # noqa: F401
            import cutlass_library.library  # noqa: F401
            import cutlass_library.manifest  # noqa: F401

            return True

        except ImportError as e:
            # Best-effort: a failed import just disables the CUTLASS backend.
            log.debug(
                "Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.",
                str(e),
            )
    else:
        log.debug(
            "Failed to import CUTLASS packages: CUTLASS repo does not exist: %s",
            cutlass_py_full_path,
        )
    return False
92
+
93
+
94
+ def _normalize_cuda_arch(arch: str) -> str:
95
+ if int(arch) >= 90:
96
+ return "90"
97
+ elif int(arch) >= 80:
98
+ return "80"
99
+ elif int(arch) >= 75:
100
+ return "75"
101
+ elif int(arch) >= 70:
102
+ return "70"
103
+ else:
104
+ raise NotImplementedError(f"Unsupported cuda arch: {arch}")
105
+
106
+
107
@dataclass
class CUTLASSArgs:
    """
    CUTLASS args used to initialize a CUTLASS Manifest.

    ``architectures`` and ``cuda_version`` are required and validated in
    ``__post_init__``; ``architectures`` is normalized to a supported tier
    ("70"/"75"/"80"/"90") there.
    """

    architectures: Optional[str] = None
    cuda_version: Optional[str] = None

    # NOTE: the un-annotated names below are plain class attributes shared by
    # all instances, not dataclass fields; they mirror generator options.
    operations = "all"
    build_dir = ""
    curr_build_dir = ""
    generator_target = ""
    kernels = "all"
    ignore_kernels = ""
    # TODO: these three look dead?
    kernel_filter_file: None = None
    selected_kernel_list: None = None
    interface_dir: None = None
    filter_by_cc = True
    disable_full_archs_compilation = False

    def __post_init__(self):
        # Both required fields must be provided before normalization.
        if self.architectures is None or self.cuda_version is None:
            raise RuntimeError(
                f"{self.architectures=} or {self.cuda_version=} is None!"
            )
        self.architectures = _normalize_cuda_arch(self.architectures)
135
+
136
+
137
@functools.lru_cache(None)
def _gen_ops_cached(arch, version) -> List[Any]:
    """Generate (and cache per process) all CUTLASS ops for *arch*/*version*.

    Returns an empty list when either value is None (detection failed);
    raises NotImplementedError when cutlass_library has no generator for the
    normalized arch.
    """
    # Note: Cache needs to be specific for cuda architecture and version

    # Import cutlass python scripts.
    assert try_import_cutlass()
    import cutlass_library.generator as cutlass_generator
    import cutlass_library.manifest as cutlass_manifest

    if arch is None or version is None:
        log.error(
            "Cannot detect cuda arch %s or cuda version %s. "
            "Will discard all cutlass ops. "
            "Please consider setting _inductor.cuda.arch and _inductor.cuda.version configs.",
            arch,
            version,
        )
        return list()
    arch = _normalize_cuda_arch(arch)
    args = CUTLASSArgs(architectures=arch, cuda_version=version)
    manifest = cutlass_manifest.Manifest(args)

    if arch == "90":
        # For SM90, SM80 kernels are generated into the manifest as well.
        cutlass_generator.GenerateSM90(manifest, args.cuda_version)
        cutlass_generator.GenerateSM80(manifest, args.cuda_version)
    else:
        try:
            # Dispatch by name, e.g. GenerateSM80 / GenerateSM75 / GenerateSM70.
            func = getattr(cutlass_generator, "GenerateSM" + arch)
            func(manifest, args.cuda_version)
        except AttributeError as e:
            raise NotImplementedError(
                "Arch " + arch + " is not supported by current cutlass lib."
            ) from e
    return manifest.operations
171
+
172
+
173
def gen_ops() -> List[Any]:
    """
    Generates all supported CUTLASS operations.

    Detects the current CUDA arch and toolkit version and delegates to the
    per-(arch, version) cache in :func:`_gen_ops_cached`.
    """
    return _gen_ops_cached(get_cuda_arch(), get_cuda_version())
180
+
181
+
182
def dtype_match(
    torch_dtype: Optional[torch.dtype],
    cutlass_dtype: "cutlass_library.library.DataType",  # type: ignore[name-defined]  # noqa: F821
) -> bool:
    """Return True when *cutlass_dtype* is a valid CUTLASS dtype for *torch_dtype*.

    ``torch.float`` accepts both ``f32`` and ``tf32``; ``torch.half`` maps to
    ``f16`` and ``torch.bfloat16`` to ``bf16``.  Any other torch dtype never
    matches.
    """
    # Import cutlass python scripts.
    assert try_import_cutlass()
    import cutlass_library

    data_type = cutlass_library.library.DataType
    if torch_dtype == torch.float:
        return cutlass_dtype in (data_type.f32, data_type.tf32)
    if torch_dtype == torch.half:
        return cutlass_dtype == data_type.f16
    if torch_dtype == torch.bfloat16:
        return cutlass_dtype == data_type.bf16
    return False
201
+
202
+
203
def get_accumulator_dtype(
    input_torch_dtypes: List[torch.dtype],
) -> Optional[torch.dtype]:
    """
    Given a list of input torch dtypes, returns the inferred accumulator torch dtype.

    All inputs must share a single dtype (RuntimeError otherwise).  Returns
    None for an empty list; raises NotImplementedError for unsupported dtypes.
    """
    if not input_torch_dtypes:
        return None
    first, *rest = input_torch_dtypes
    for other in rest:
        if other != first:
            raise RuntimeError(
                f"Unmatched input dtypes: torch_dtype={first!r}, dtype={other!r}"
            )
    if first == torch.half:
        # fp16 may accumulate in fp16 only when reduced-precision
        # reductions are globally allowed; otherwise fall back to fp32.
        if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction:
            return first
        return torch.float
    if first in (torch.bfloat16, torch.float):
        return torch.float
    raise NotImplementedError(f"Unsupported data type: {input_torch_dtypes=}")
224
+
225
+
226
def get_alignments(torch_dtype: torch.dtype) -> List[int]:
    """
    Returns all possible valid CUTLASS alignments in terms of the number of elements for a given dtype.
    CUTLASS gemm / conv SM80 APIs support 16 bytes max alignment, and 2 bytes min alignment.

    Alignments are returned largest-first; raises NotImplementedError for
    dtypes other than float/half/bfloat16.
    """
    if torch_dtype == torch.float:
        return [4, 2, 1]
    if torch_dtype in (torch.half, torch.bfloat16):
        return [8, 4, 2, 1]
    raise NotImplementedError(f"unsupported {torch_dtype=} for alignments")
238
+
239
+
240
def get_max_alignment(inductor_layout: Layout) -> int:
    """
    Returns the max alignment (in terms of number of elements) for a given Inductor Layout.

    Only the innermost dimension and the storage offset are checked; when
    either is symbolic, or no candidate alignment divides both, 1 is returned.
    """

    def _is_static(value) -> bool:
        # Divisibility can only be checked on concrete (non-symbolic) ints.
        return isinstance(value, (int, sympy.Integer))

    inner_size = inductor_layout.size[-1]
    offset = inductor_layout.offset
    if not (_is_static(inner_size) and _is_static(offset)):
        return 1
    return next(
        (
            candidate
            for candidate in get_alignments(inductor_layout.dtype)
            if int(inner_size) % candidate == 0 and int(offset) % candidate == 0
        ),
        1,
    )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton.py ADDED
The diff for this file is too large to render. See raw diff
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-311.pyc ADDED
Binary file (11.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-311.pyc ADDED
Binary file (36.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-311.pyc ADDED
Binary file (31.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/quantization.cpython-311.pyc ADDED
Binary file (66.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-311.pyc ADDED
Binary file (70.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-311.pyc ADDED
Binary file (14.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-311.pyc ADDED
Binary file (17.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-311.pyc ADDED
Binary file (19.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-311.pyc ADDED
Binary file (52.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-311.pyc ADDED
Binary file (14.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-311.pyc ADDED
Binary file (16.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-311.pyc ADDED
Binary file (14.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-311.pyc ADDED
Binary file (19.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-311.pyc ADDED
Binary file (19.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
# fp32 training pattern: forward softmax(Q @ K^T / inv_scale + attn_mask) @ V
# plus the joint backward graph ('tangents_1' is the incoming gradient).
# NOTE(review): auto-generated graph — comments only; nodes must not be edited.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
# Numerically-stable softmax: amax / sub / exp / sum / div.
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Backward section begins: 'tangents_1' is the upstream gradient.
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
# Outputs: attention output plus the three input gradients (trailing Nones
# are unmatched outputs of the replaced subgraph).
_sfdp_pattern_14_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None
])
94
+
95
+
96
# fp32 inference pattern: same forward graph as the training variant
# (softmax(Q @ K^T / inv_scale + attn_mask) @ V) but single-output and
# without the _users multiplicities needed for backward reuse.
# NOTE(review): auto-generated graph — comments only; nodes must not be edited.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
# Numerically-stable softmax: amax / sub / exp / sum / div.
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_14_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
122
+
123
+
124
# Reduced-precision (half) training pattern: same attention forward/backward
# as the fp32 variant, with prims.convert_element_type casts around the
# softmax (upcast before amax/exp, downcast after) and around the backward
# softmax gradient.
# NOTE(review): auto-generated graph — comments only; nodes must not be edited.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
# Cast before the numerically-stable softmax (amax / sub / exp / sum / div).
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
# Backward section begins: 'tangents_1' is the upstream gradient.
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
# Outputs: attention output plus the three input gradients (trailing Nones
# are unmatched outputs of the replaced subgraph).
_sfdp_pattern_14_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None
])
189
+
190
+
191
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
192
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
193
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
194
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
195
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
196
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
197
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
198
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
199
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
200
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
201
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
202
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
203
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
204
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
205
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
206
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
207
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
208
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
209
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
210
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
211
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
212
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
213
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
214
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
215
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
216
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
217
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
218
+ _sfdp_pattern_14_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python
7
+ # torchgen/fuse_attention_patterns/gen_attention_patterns.py
8
+
9
+ import torch
10
+ import torch._inductor
11
+
12
+ aten = torch.ops.aten
13
+ prims = torch.ops.prims
14
+
15
+ from torch._inductor.pattern_matcher import (
16
+ Arg,
17
+ CallFunction,
18
+ CallFunctionVarArgs,
19
+ CallMethod,
20
+ CallMethodVarArgs,
21
+ CallModule,
22
+ CallModuleVarArgs,
23
+ ExclusiveKeywordArg,
24
+ Ignored,
25
+ KeywordArg,
26
+ ListOf,
27
+ MultiOutputPattern,
28
+ PatternExpr,
29
+ RepeatedExpr,
30
+ _TargetArgsExpr,
31
+ _TargetExpr,
32
+ _TargetExprVarArgs,
33
+ )
34
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
35
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
36
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
37
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
38
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
39
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
40
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
41
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
42
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
43
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
44
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
45
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
46
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
47
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
48
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
49
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
50
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
51
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
52
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
53
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
54
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
55
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
56
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
57
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
58
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
59
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
60
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
61
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
62
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
63
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
64
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
65
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
66
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
67
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
68
+ view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
69
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
70
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
71
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
72
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
73
+ clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
74
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
75
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
76
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
77
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
78
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
79
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
80
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
81
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
82
+ div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored())
83
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
84
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
85
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
86
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
87
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
88
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
89
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
90
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
91
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
92
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
93
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
94
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
95
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
96
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
97
+ _sfdp_pattern_7_training = MultiOutputPattern([view_default_5,
98
+ permute_default_6,
99
+ permute_default_9,
100
+ permute_default_11,
101
+ None
102
+ ])
103
+
104
+
105
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
106
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
107
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
108
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
109
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
110
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
111
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
112
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
113
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
114
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
115
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
116
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
117
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
118
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
119
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
120
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
121
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
122
+ clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
123
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
124
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
125
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
126
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
127
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
128
+ clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
129
+ view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
130
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
131
+ _sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
132
+
133
+
134
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
135
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
136
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
137
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
138
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
139
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
140
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
141
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
142
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
143
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
144
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
145
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
146
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
147
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
148
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
149
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
150
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
151
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
152
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
153
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
154
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
155
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
156
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
157
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
158
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
159
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
160
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
161
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
162
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
163
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
164
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
165
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
166
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
167
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
168
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
169
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
170
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
171
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
172
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
173
+ clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
174
+ alias_default = CallFunction(aten.alias.default, div_Tensor_1)
175
+ alias_default_1 = CallFunction(aten.alias.default, alias_default)
176
+ alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
177
+ alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
178
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
179
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
180
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
181
+ sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
182
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
183
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored())
184
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
185
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
186
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
187
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
188
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
189
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
190
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
191
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
192
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
193
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
194
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
195
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
196
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
197
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
198
+ _sfdp_pattern_7_half_training = MultiOutputPattern([view_default_5,
199
+ permute_default_6,
200
+ permute_default_9,
201
+ permute_default_11,
202
+ None
203
+ ])
204
+
205
+
206
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
207
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
208
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
209
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
210
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
211
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
212
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
213
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
214
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
215
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
216
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
217
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
218
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
219
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
220
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
221
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
222
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
223
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
224
+ clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
225
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
226
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
227
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
228
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
229
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
230
+ clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
231
+ view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
232
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
233
+ _sfdp_pattern_7_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import mm, mm_common, mm_plus_mm, unpack_mixed_mm
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (349 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-311.pyc ADDED
Binary file (7.42 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List
3
+
4
+ from ..select_algorithm import autotune_select_algorithm, ChoiceCaller, TritonTemplate
5
+ from .mm_common import mm_args, mm_configs, mm_grid, mm_options
6
+
7
+ log = logging.getLogger(__name__)
8
+
9
+ uint4x2_mixed_mm_template = TritonTemplate(
10
+ name="uint4x2_mixed_mm",
11
+ grid=mm_grid,
12
+ source=r"""
13
+ {{def_kernel("A", "B")}}
14
+ M = {{size("A", 0)}}
15
+ N = {{size("B", 1)}}
16
+ K = {{size("A", 1)}}
17
+ stride_am = {{stride("A", 0)}}
18
+ stride_ak = {{stride("A", 1)}}
19
+ stride_bk = {{stride("B", 0)}}
20
+ stride_bn = {{stride("B", 1)}}
21
+
22
+ # based on triton.ops.matmul
23
+ pid = tl.program_id(0)
24
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
25
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
26
+
27
+ # re-order program ID for better L2 performance
28
+ width = GROUP_M * grid_n
29
+ group_id = pid // width
30
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
31
+ pid_m = group_id * GROUP_M + (pid % group_size)
32
+ pid_n = (pid % width) // (group_size)
33
+
34
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
35
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
36
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
37
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
38
+ rk = tl.arange(0, BLOCK_K)
39
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
40
+ B = B + (rk[:, None]//2 * stride_bk + rbn[None, :] * stride_bn)
41
+ b_shifts = 4*(rk%2)
42
+ b_subs = 8*(1-(rk%2))
43
+
44
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
45
+ for k in range(K, 0, -BLOCK_K):
46
+ if EVEN_K:
47
+ a = tl.load(A)
48
+ b = tl.load(B)
49
+ else:
50
+ a = tl.load(A, mask=rk[None, :] < k, other=0.)
51
+ b = tl.load(B, mask=rk[:, None] < k, other=0.)
52
+ b = ((b >> b_shifts[:, None]) & 0xF) - 8
53
+ b = b.to(B_PROLOGUE_CAST_TYPE)
54
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
55
+ A += BLOCK_K * stride_ak
56
+ B += BLOCK_K//2 * stride_bk
57
+
58
+ # rematerialize rm and rn to save registers
59
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
60
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
61
+ idx_m = rm[:, None]
62
+ idx_n = rn[None, :]
63
+ mask = (idx_m < M) & (idx_n < N)
64
+
65
+ # inductor generates a suffix
66
+ {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
67
+ """,
68
+ )
69
+
70
+
71
+ def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype):
72
+ m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True)
73
+ choices: List[ChoiceCaller] = []
74
+ b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
75
+ for config in mm_configs(m, n, k):
76
+ uint4x2_mixed_mm_template.maybe_append_choice(
77
+ choices,
78
+ input_nodes=(mat1, mat2),
79
+ layout=layout,
80
+ **mm_options(config, m, n, k, layout, b_prologue_cast_type),
81
+ )
82
+ return autotune_select_algorithm("uint4x2_mixed_mm", choices, [mat1, mat2], layout)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py ADDED
@@ -0,0 +1,1412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""
2
+ This package adds support for CUDA tensor types.
3
+
4
+ It implements the same function as CPU tensors, but they utilize
5
+ GPUs for computation.
6
+
7
+ It is lazily initialized, so you can always import it, and use
8
+ :func:`is_available()` to determine if your system supports CUDA.
9
+
10
+ :ref:`cuda-semantics` has more details about working with CUDA.
11
+ """
12
+
13
+
14
+ import contextlib
15
+ import importlib
16
+ import os
17
+ import sys
18
+ import threading
19
+ import traceback
20
+ import warnings
21
+ from functools import lru_cache
22
+ from typing import Any, Callable, cast, List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch._C
26
+ from torch.types import Device
27
+ from .. import device as _device
28
+ from .._utils import _dummy_type, _LazySeedTracker, classproperty
29
+ from ._utils import _get_device_index
30
+ from .graphs import (
31
+ CUDAGraph,
32
+ graph,
33
+ graph_pool_handle,
34
+ is_current_stream_capturing,
35
+ make_graphed_callables,
36
+ )
37
+ from .streams import Event, ExternalStream, Stream
38
+
39
+ try:
40
+ from torch._C import _cudart # type: ignore[attr-defined]
41
+ except ImportError:
42
+ _cudart = None
43
+
44
+ _initialized = False
45
+ _tls = threading.local()
46
+ _initialization_lock = threading.Lock()
47
+ _queued_calls: List[
48
+ Tuple[Callable[[], None], List[str]]
49
+ ] = [] # don't invoke these until initialization occurs
50
+ _is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
51
+ _device_t = Union[_device, str, int, None]
52
+
53
+ _HAS_PYNVML = False
54
+ _PYNVML_ERR = None
55
+ try:
56
+ import pynvml # type: ignore[import]
57
+
58
+ _HAS_PYNVML = True
59
+ except ImportError as err:
60
+ _PYNVML_ERR = err # sometimes a lib is installed but the import fails for some other reason, so we log the error for later
61
+
62
+ _lazy_seed_tracker = _LazySeedTracker()
63
+
64
+ # Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA
65
+ if hasattr(torch._C, "_CudaDeviceProperties"):
66
+ _CudaDeviceProperties = torch._C._CudaDeviceProperties
67
+ else:
68
+ _CudaDeviceProperties = _dummy_type("_CudaDeviceProperties") # type: ignore[assignment, misc]
69
+
70
+ if hasattr(torch._C, "_cuda_exchangeDevice"):
71
+ _exchange_device = torch._C._cuda_exchangeDevice
72
+ else:
73
+
74
+ def _exchange_device(device: int) -> int:
75
+ if device < 0:
76
+ return -1
77
+ raise RuntimeError("PyTorch was compiled without CUDA support")
78
+
79
+
80
+ if hasattr(torch._C, "_cuda_maybeExchangeDevice"):
81
+ _maybe_exchange_device = torch._C._cuda_maybeExchangeDevice
82
+ else:
83
+
84
+ def _maybe_exchange_device(device: int) -> int:
85
+ if device < 0:
86
+ return -1
87
+ raise RuntimeError("PyTorch was compiled without CUDA support")
88
+
89
+
90
+ has_half: bool = True
91
+ has_magma: bool = torch._C._has_magma
92
+
93
+ default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]
94
+
95
+
96
+ def _is_compiled() -> bool:
97
+ r"""Return true if compile with CUDA support."""
98
+ return hasattr(torch._C, "_cuda_getDeviceCount")
99
+
100
+
101
+ def _nvml_based_avail() -> bool:
102
+ return os.getenv("PYTORCH_NVML_BASED_CUDA_CHECK") == "1"
103
+
104
+
105
+ def is_available() -> bool:
106
+ r"""Return a bool indicating if CUDA is currently available."""
107
+ if not _is_compiled():
108
+ return False
109
+ if _nvml_based_avail():
110
+ # The user has set an env variable to request this availability check that attempts to avoid fork poisoning by
111
+ # using NVML at the cost of a weaker CUDA availability assessment. Note that if NVML discovery/initialization
112
+ # fails, this assessment falls back to the default CUDA Runtime API assessment (`cudaGetDeviceCount`)
113
+ return device_count() > 0
114
+ else:
115
+ # The default availability inspection never throws and returns 0 if the driver is missing or can't
116
+ # be initialized. This uses the CUDA Runtime API `cudaGetDeviceCount` which in turn initializes the CUDA Driver
117
+ # API via `cuInit`
118
+ return torch._C._cuda_getDeviceCount() > 0
119
+
120
+
121
+ def is_bf16_supported():
122
+ r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
123
+ # Check for ROCm, if true return true, no ROCM_VERSION check required,
124
+ # since it is supported on AMD GPU archs.
125
+ if torch.version.hip:
126
+ return True
127
+
128
+ device = torch.cuda.current_device()
129
+
130
+ # Check for CUDA version and device compute capability.
131
+ # This is a fast way to check for it.
132
+ cuda_version = torch.version.cuda
133
+ if (
134
+ cuda_version is not None
135
+ and int(cuda_version.split(".")[0]) >= 11
136
+ and torch.cuda.get_device_properties(device).major >= 8
137
+ ):
138
+ return True
139
+
140
+ # Finally try to create a bfloat16 device.
141
+ return _check_bf16_tensor_supported(device)
142
+
143
+
144
+ @lru_cache(maxsize=16)
145
+ def _check_bf16_tensor_supported(device: _device_t):
146
+ try:
147
+ torch.tensor([1.0], dtype=torch.bfloat16, device=device)
148
+ return True
149
+ except Exception:
150
+ return False
151
+
152
+
153
+ def _sleep(cycles):
154
+ torch._C._cuda_sleep(cycles)
155
+
156
+
157
+ def _check_capability():
158
+ incorrect_binary_warn = """
159
+ Found GPU%d %s which requires CUDA_VERSION >= %d to
160
+ work properly, but your PyTorch was compiled
161
+ with CUDA_VERSION %d. Please install the correct PyTorch binary
162
+ using instructions from https://pytorch.org
163
+ """
164
+
165
+ old_gpu_warn = """
166
+ Found GPU%d %s which is of cuda capability %d.%d.
167
+ PyTorch no longer supports this GPU because it is too old.
168
+ The minimum cuda capability supported by this library is %d.%d.
169
+ """
170
+
171
+ if torch.version.cuda is not None: # on ROCm we don't want this check
172
+ CUDA_VERSION = torch._C._cuda_getCompiledVersion()
173
+ for d in range(device_count()):
174
+ capability = get_device_capability(d)
175
+ major = capability[0]
176
+ minor = capability[1]
177
+ name = get_device_name(d)
178
+ current_arch = major * 10 + minor
179
+ min_arch = min(
180
+ (int(arch.split("_")[1]) for arch in torch.cuda.get_arch_list()),
181
+ default=35,
182
+ )
183
+ if current_arch < min_arch:
184
+ warnings.warn(
185
+ old_gpu_warn
186
+ % (d, name, major, minor, min_arch // 10, min_arch % 10)
187
+ )
188
+
189
+
190
+ def _check_cubins():
191
+ incompatible_device_warn = """
192
+ {} with CUDA capability sm_{} is not compatible with the current PyTorch installation.
193
+ The current PyTorch install supports CUDA capabilities {}.
194
+ If you want to use the {} GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/
195
+ """
196
+ if torch.version.cuda is None: # on ROCm we don't want this check
197
+ return
198
+ arch_list = get_arch_list()
199
+ if len(arch_list) == 0:
200
+ return
201
+ supported_sm = [int(arch.split("_")[1]) for arch in arch_list if "sm_" in arch]
202
+ for idx in range(device_count()):
203
+ cap_major, cap_minor = get_device_capability(idx)
204
+ # NVIDIA GPU compute architectures are backward compatible within major version
205
+ supported = any(sm // 10 == cap_major for sm in supported_sm)
206
+ if not supported:
207
+ device_name = get_device_name(idx)
208
+ capability = cap_major * 10 + cap_minor
209
+ warnings.warn(
210
+ incompatible_device_warn.format(
211
+ device_name, capability, " ".join(arch_list), device_name
212
+ )
213
+ )
214
+
215
+
216
+ def is_initialized():
217
+ r"""Return whether PyTorch's CUDA state has been initialized."""
218
+ return _initialized and not _is_in_bad_fork()
219
+
220
+
221
+ def _lazy_call(callable, **kwargs):
222
+ if is_initialized():
223
+ callable()
224
+ else:
225
+ # TODO(torch_deploy): this accesses linecache, which attempts to read the
226
+ # file system to get traceback info. Patch linecache or do something
227
+ # else here if this ends up being important.
228
+ global _lazy_seed_tracker
229
+ if kwargs.get("seed_all", False):
230
+ _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
231
+ elif kwargs.get("seed", False):
232
+ _lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
233
+ else:
234
+ # Don't store the actual traceback to avoid memory cycle
235
+ _queued_calls.append((callable, traceback.format_stack()))
236
+
237
+
238
+ _lazy_call(_check_capability)
239
+ _lazy_call(_check_cubins)
240
+
241
+
242
+ class DeferredCudaCallError(Exception):
243
+ pass
244
+
245
+
246
+ OutOfMemoryError = torch._C._OutOfMemoryError
247
+
248
+
249
+ def init():
250
+ r"""Initialize PyTorch's CUDA state.
251
+
252
+ You may need to call this explicitly if you are interacting with
253
+ PyTorch via its C API, as Python bindings for CUDA functionality
254
+ will not be available until this initialization takes place.
255
+ Ordinary users should not need this, as all of PyTorch's CUDA methods
256
+ automatically initialize CUDA state on-demand.
257
+
258
+ Does nothing if the CUDA state is already initialized.
259
+ """
260
+ _lazy_init()
261
+
262
+
263
+ def _lazy_init():
264
+ global _initialized, _queued_calls
265
+ if is_initialized() or hasattr(_tls, "is_initializing"):
266
+ return
267
+ with _initialization_lock:
268
+ # We be double-checked locking, boys! This is OK because
269
+ # the above test was GIL protected anyway. The inner test
270
+ # is for when a thread blocked on some other thread which was
271
+ # doing the initialization; when they get the lock, they will
272
+ # find there is nothing left to do.
273
+ if is_initialized():
274
+ return
275
+ # It is important to prevent other threads from entering _lazy_init
276
+ # immediately, while we are still guaranteed to have the GIL, because some
277
+ # of the C calls we make below will release the GIL
278
+ if _is_in_bad_fork():
279
+ raise RuntimeError(
280
+ "Cannot re-initialize CUDA in forked subprocess. To use CUDA with "
281
+ "multiprocessing, you must use the 'spawn' start method"
282
+ )
283
+ if not hasattr(torch._C, "_cuda_getDeviceCount"):
284
+ raise AssertionError("Torch not compiled with CUDA enabled")
285
+ if _cudart is None:
286
+ raise AssertionError(
287
+ "libcudart functions unavailable. It looks like you have a broken build?"
288
+ )
289
+ # This function throws if there's a driver initialization error, no GPUs
290
+ # are found or any other error occurs
291
+ if "CUDA_MODULE_LOADING" not in os.environ:
292
+ os.environ["CUDA_MODULE_LOADING"] = "LAZY"
293
+ torch._C._cuda_init()
294
+ # Some of the queued calls may reentrantly call _lazy_init();
295
+ # we need to just return without initializing in that case.
296
+ # However, we must not let any *other* threads in!
297
+ _tls.is_initializing = True
298
+
299
+ for calls in _lazy_seed_tracker.get_calls():
300
+ if calls:
301
+ _queued_calls.append(calls)
302
+
303
+ try:
304
+ for queued_call, orig_traceback in _queued_calls:
305
+ try:
306
+ queued_call()
307
+ except Exception as e:
308
+ msg = (
309
+ f"CUDA call failed lazily at initialization with error: {str(e)}\n\n"
310
+ f"CUDA call was originally invoked at:\n\n{''.join(orig_traceback)}"
311
+ )
312
+ raise DeferredCudaCallError(msg) from e
313
+ finally:
314
+ delattr(_tls, "is_initializing")
315
+ _initialized = True
316
+
317
+
318
+ def cudart():
319
+ _lazy_init()
320
+ return _cudart
321
+
322
+
323
+ class cudaStatus:
324
+ SUCCESS: int = 0
325
+ ERROR_NOT_READY: int = 34
326
+
327
+
328
+ class CudaError(RuntimeError):
329
+ def __init__(self, code: int) -> None:
330
+ msg = _cudart.cudaGetErrorString(_cudart.cudaError(code))
331
+ super().__init__(f"{msg} ({code})")
332
+
333
+
334
+ def check_error(res: int) -> None:
335
+ if res != _cudart.cudaError.success:
336
+ raise CudaError(res)
337
+
338
+
339
+ class _DeviceGuard:
340
+ def __init__(self, index: int):
341
+ self.idx = index
342
+ self.prev_idx = -1
343
+
344
+ def __enter__(self):
345
+ self.prev_idx = torch.cuda._exchange_device(self.idx)
346
+
347
+ def __exit__(self, type: Any, value: Any, traceback: Any):
348
+ self.idx = torch.cuda._maybe_exchange_device(self.prev_idx)
349
+ return False
350
+
351
+
352
+ class device:
353
+ r"""Context-manager that changes the selected device.
354
+
355
+ Args:
356
+ device (torch.device or int): device index to select. It's a no-op if
357
+ this argument is a negative integer or ``None``.
358
+ """
359
+
360
+ def __init__(self, device: Any):
361
+ self.idx = _get_device_index(device, optional=True)
362
+ self.prev_idx = -1
363
+
364
+ def __enter__(self):
365
+ self.prev_idx = torch.cuda._exchange_device(self.idx)
366
+
367
+ def __exit__(self, type: Any, value: Any, traceback: Any):
368
+ self.idx = torch.cuda._maybe_exchange_device(self.prev_idx)
369
+ return False
370
+
371
+
372
+ class device_of(device):
373
+ r"""Context-manager that changes the current device to that of given object.
374
+
375
+ You can use both tensors and storages as arguments. If a given object is
376
+ not allocated on a GPU, this is a no-op.
377
+
378
+ Args:
379
+ obj (Tensor or Storage): object allocated on the selected device.
380
+ """
381
+
382
+ def __init__(self, obj):
383
+ idx = obj.get_device() if obj.is_cuda else -1
384
+ super().__init__(idx)
385
+
386
+
387
+ def set_device(device: _device_t) -> None:
388
+ r"""Set the current device.
389
+
390
+ Usage of this function is discouraged in favor of :any:`device`. In most
391
+ cases it's better to use ``CUDA_VISIBLE_DEVICES`` environmental variable.
392
+
393
+ Args:
394
+ device (torch.device or int): selected device. This function is a no-op
395
+ if this argument is negative.
396
+ """
397
+ device = _get_device_index(device)
398
+ if device >= 0:
399
+ torch._C._cuda_setDevice(device)
400
+
401
+
402
+ def get_device_name(device: Optional[_device_t] = None) -> str:
403
+ r"""Get the name of a device.
404
+
405
+ Args:
406
+ device (torch.device or int, optional): device for which to return the
407
+ name. This function is a no-op if this argument is a negative
408
+ integer. It uses the current device, given by :func:`~torch.cuda.current_device`,
409
+ if :attr:`device` is ``None`` (default).
410
+
411
+ Returns:
412
+ str: the name of the device
413
+ """
414
+ return get_device_properties(device).name
415
+
416
+
417
+ def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]:
418
+ r"""Get the cuda capability of a device.
419
+
420
+ Args:
421
+ device (torch.device or int, optional): device for which to return the
422
+ device capability. This function is a no-op if this argument is
423
+ a negative integer. It uses the current device, given by
424
+ :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
425
+ (default).
426
+
427
+ Returns:
428
+ tuple(int, int): the major and minor cuda capability of the device
429
+ """
430
+ prop = get_device_properties(device)
431
+ return prop.major, prop.minor
432
+
433
+
434
+ def get_device_properties(device: _device_t) -> _CudaDeviceProperties:
435
+ r"""Get the properties of a device.
436
+
437
+ Args:
438
+ device (torch.device or int or str): device for which to return the
439
+ properties of the device.
440
+
441
+ Returns:
442
+ _CudaDeviceProperties: the properties of the device
443
+ """
444
+ _lazy_init() # will define _get_device_properties
445
+ device = _get_device_index(device, optional=True)
446
+ if device < 0 or device >= device_count():
447
+ raise AssertionError("Invalid device id")
448
+ return _get_device_properties(device) # type: ignore[name-defined]
449
+
450
+
451
+ def can_device_access_peer(device: _device_t, peer_device: _device_t) -> bool:
452
+ r"""Check if peer access between two devices is possible."""
453
+ _lazy_init()
454
+ device = _get_device_index(device, optional=True)
455
+ peer_device = _get_device_index(peer_device)
456
+ if device < 0 or device >= device_count():
457
+ raise AssertionError("Invalid device id")
458
+ if peer_device < 0 or peer_device >= device_count():
459
+ raise AssertionError("Invalid peer device id")
460
+ return torch._C._cuda_canDeviceAccessPeer(device, peer_device)
461
+
462
+
463
+ class StreamContext:
464
+ r"""Context-manager that selects a given stream.
465
+
466
+ All CUDA kernels queued within its context will be enqueued on a selected
467
+ stream.
468
+
469
+ Args:
470
+ Stream (Stream): selected stream. This manager is a no-op if it's
471
+ ``None``.
472
+ .. note:: Streams are per-device.
473
+ """
474
+ cur_stream: Optional["torch.cuda.Stream"]
475
+
476
+ def __init__(self, stream: Optional["torch.cuda.Stream"]):
477
+ self.stream = stream
478
+ self.idx = _get_device_index(None, True)
479
+ if not torch.jit.is_scripting():
480
+ if self.idx is None:
481
+ self.idx = -1
482
+
483
+ self.src_prev_stream = (
484
+ None if not torch.jit.is_scripting() else torch.cuda.default_stream(None)
485
+ )
486
+ self.dst_prev_stream = (
487
+ None if not torch.jit.is_scripting() else torch.cuda.default_stream(None)
488
+ )
489
+
490
+ def __enter__(self):
491
+ # Local cur_stream variable for type refinement
492
+ cur_stream = self.stream
493
+ # Return if stream is None or CUDA device not available
494
+ if cur_stream is None or self.idx == -1:
495
+ return
496
+ self.src_prev_stream = torch.cuda.current_stream(None)
497
+
498
+ # If the stream is not on the current device, then
499
+ # set the current stream on the device
500
+ if self.src_prev_stream.device != cur_stream.device:
501
+ with device(cur_stream.device):
502
+ self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device)
503
+ torch.cuda.set_stream(cur_stream)
504
+
505
+ def __exit__(self, type: Any, value: Any, traceback: Any):
506
+ # Local cur_stream variable for type refinement
507
+ cur_stream = self.stream
508
+ # If stream is None or no CUDA device available, return
509
+ if cur_stream is None or self.idx == -1:
510
+ return
511
+
512
+ # Reset the stream on the original device
513
+ # and destination device
514
+ if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr]
515
+ torch.cuda.set_stream(self.dst_prev_stream) # type: ignore[arg-type]
516
+ torch.cuda.set_stream(self.src_prev_stream) # type: ignore[arg-type]
517
+
518
+
519
+ def stream(stream: Optional["torch.cuda.Stream"]) -> StreamContext:
520
+ r"""Wrap around the Context-manager StreamContext that selects a given stream.
521
+
522
+ Arguments:
523
+ stream (Stream): selected stream. This manager is a no-op if it's
524
+ ``None``.
525
+ ..Note:: In eager mode stream is of type Stream class while in JIT it is
526
+ an object of the custom class ``torch.classes.cuda.Stream``.
527
+ """
528
+ return StreamContext(stream)
529
+
530
+
531
+ def _set_stream_by_id(stream_id, device_index, device_type):
532
+ r"""set stream specified by the stream id, device index and
533
+ device type
534
+
535
+ Args: stream_id (int): stream id in stream pool
536
+ device_index (int): device index in topo
537
+ device_type (int): enum device type
538
+ """
539
+ torch._C._cuda_setStream(
540
+ stream_id=stream_id,
541
+ device_index=device_index,
542
+ device_type=device_type,
543
+ )
544
+
545
+
546
+ def set_stream(stream: Stream):
547
+ r"""Set the current stream.This is a wrapper API to set the stream.
548
+ Usage of this function is discouraged in favor of the ``stream``
549
+ context manager.
550
+
551
+ Args:
552
+ stream (Stream): selected stream. This function is a no-op
553
+ if this argument is ``None``.
554
+ """
555
+ if stream is None:
556
+ return
557
+ _set_stream_by_id(
558
+ stream_id=stream.stream_id,
559
+ device_index=stream.device_index,
560
+ device_type=stream.device_type,
561
+ )
562
+
563
+
564
+ def _parse_visible_devices() -> Union[List[int], List[str]]:
565
+ r"""Parse CUDA_VISIBLE_DEVICES environment variable."""
566
+ var = os.getenv("CUDA_VISIBLE_DEVICES")
567
+ if var is None:
568
+ return list(range(64))
569
+
570
+ def _strtoul(s: str) -> int:
571
+ """Return -1 or positive integer sequence string starts with."""
572
+ if not s:
573
+ return -1
574
+ for idx, c in enumerate(s):
575
+ if not (c.isdigit() or (idx == 0 and c in "+-")):
576
+ break
577
+ if idx + 1 == len(s):
578
+ idx += 1
579
+ return int(s[:idx]) if idx > 0 else -1
580
+
581
+ def parse_list_with_prefix(lst: str, prefix: str) -> List[str]:
582
+ rcs: List[str] = []
583
+ for elem in lst.split(","):
584
+ # Repeated id results in empty set
585
+ if elem in rcs:
586
+ return cast(List[str], [])
587
+ # Anything other but prefix is ignored
588
+ if not elem.startswith(prefix):
589
+ break
590
+ rcs.append(elem)
591
+ return rcs
592
+
593
+ if var.startswith("GPU-"):
594
+ return parse_list_with_prefix(var, "GPU-")
595
+ if var.startswith("MIG-"):
596
+ return parse_list_with_prefix(var, "MIG-")
597
+ # CUDA_VISIBLE_DEVICES uses something like strtoul
598
+ # which makes `1gpu2,2ampere` is equivalent to `1,2`
599
+ rc: List[int] = []
600
+ for elem in var.split(","):
601
+ x = _strtoul(elem.strip())
602
+ # Repeated ordinal results in empty set
603
+ if x in rc:
604
+ return cast(List[int], [])
605
+ # Negative value aborts the sequence
606
+ if x < 0:
607
+ break
608
+ rc.append(x)
609
+ return rc
610
+
611
+
612
+ def _raw_device_count_nvml() -> int:
613
+ r"""Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed."""
614
+ from ctypes import byref, c_int, CDLL
615
+
616
+ nvml_h = CDLL("libnvidia-ml.so.1")
617
+ rc = nvml_h.nvmlInit()
618
+ if rc != 0:
619
+ warnings.warn("Can't initialize NVML")
620
+ return -1
621
+ dev_count = c_int(-1)
622
+ rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
623
+ if rc != 0:
624
+ warnings.warn("Can't get nvml device count")
625
+ return -1
626
+ del nvml_h
627
+ return dev_count.value
628
+
629
+
630
+ def _raw_device_uuid_nvml() -> Optional[List[str]]:
631
+ r"""Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed."""
632
+ from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer
633
+
634
+ nvml_h = CDLL("libnvidia-ml.so.1")
635
+ rc = nvml_h.nvmlInit()
636
+ if rc != 0:
637
+ warnings.warn("Can't initialize NVML")
638
+ return None
639
+ dev_count = c_int(-1)
640
+ rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
641
+ if rc != 0:
642
+ warnings.warn("Can't get nvml device count")
643
+ return None
644
+ uuids: List[str] = []
645
+ for idx in range(dev_count.value):
646
+ dev_id = c_void_p()
647
+ rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
648
+ if rc != 0:
649
+ warnings.warn("Can't get device handle")
650
+ return None
651
+ buf_len = 96
652
+ buf = create_string_buffer(buf_len)
653
+ rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
654
+ if rc != 0:
655
+ warnings.warn("Can't get device UUID")
656
+ return None
657
+ uuids.append(buf.raw.decode("ascii").strip("\0"))
658
+ del nvml_h
659
+ return uuids
660
+
661
+
662
+ def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]:
663
+ r"""Given the set of partial uuids and list of known uuids builds a set of ordinals excluding ambiguous partials IDs."""
664
+
665
+ def uuid_to_orinal(candidate: str, uuids: List[str]) -> int:
666
+ best_match = -1
667
+ for idx, uuid in enumerate(uuids):
668
+ if not uuid.startswith(candidate):
669
+ continue
670
+ # Ambiguous candidate
671
+ if best_match != -1:
672
+ return -1
673
+ best_match = idx
674
+ return best_match
675
+
676
+ rc: List[int] = []
677
+ for candidate in candidates:
678
+ idx = uuid_to_orinal(candidate, uuids)
679
+ # First invalid ordinal stops parsing
680
+ if idx < 0:
681
+ break
682
+ # Duplicates result in empty set
683
+ if idx in rc:
684
+ return cast(List[int], [])
685
+ rc.append(idx)
686
+ return rc
687
+
688
+
689
+ def _device_count_nvml() -> int:
690
+ r"""Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account.
691
+
692
+ Negative value is returned if NVML discovery or initialization has failed.
693
+ """
694
+ visible_devices = _parse_visible_devices()
695
+ if not visible_devices:
696
+ return 0
697
+ try:
698
+ if type(visible_devices[0]) is str:
699
+ # Skip MIG parsing
700
+ if visible_devices[0].startswith("MIG-"):
701
+ return -1
702
+ uuids = _raw_device_uuid_nvml()
703
+ if uuids is None:
704
+ return -1
705
+ visible_devices = _transform_uuid_to_ordinals(
706
+ cast(List[str], visible_devices), uuids
707
+ )
708
+ else:
709
+ raw_cnt = _raw_device_count_nvml()
710
+ if raw_cnt <= 0:
711
+ return raw_cnt
712
+ # Trim the list up to a maximum available device
713
+ for idx, val in enumerate(visible_devices):
714
+ if cast(int, val) >= raw_cnt:
715
+ return idx
716
+ except OSError:
717
+ return -1
718
+ except AttributeError:
719
+ return -1
720
+ return len(visible_devices)
721
+
722
+
723
+ def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int:
724
+ r"""Return the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account."""
725
+ idx = _get_device_index(device, optional=True)
726
+ visible_devices = _parse_visible_devices()
727
+ if type(visible_devices[0]) is str:
728
+ uuids = _raw_device_uuid_nvml()
729
+ if uuids is None:
730
+ raise RuntimeError("Can't get device UUIDs")
731
+ visible_devices = _transform_uuid_to_ordinals(
732
+ cast(List[str], visible_devices), uuids
733
+ )
734
+ visible_devices = cast(List[int], visible_devices)
735
+ if idx < 0 or idx >= len(visible_devices):
736
+ raise RuntimeError(
737
+ f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})"
738
+ )
739
+ return visible_devices[idx]
740
+
741
+
742
+ @lru_cache(maxsize=1)
743
+ def device_count() -> int:
744
+ r"""Return the number of GPUs available."""
745
+ if not _is_compiled():
746
+ return 0
747
+ # bypass _device_count_nvml() if rocm (not supported)
748
+ nvml_count = -1 if torch.version.hip else _device_count_nvml()
749
+ return torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count
750
+
751
+
752
+ def get_arch_list() -> List[str]:
753
+ r"""Return list CUDA architectures this library was compiled for."""
754
+ if not is_available():
755
+ return []
756
+ arch_flags = torch._C._cuda_getArchFlags()
757
+ if arch_flags is None:
758
+ return []
759
+ return arch_flags.split()
760
+
761
+
762
+ def get_gencode_flags() -> str:
763
+ r"""Return NVCC gencode flags this library was compiled with."""
764
+ arch_list = get_arch_list()
765
+ if len(arch_list) == 0:
766
+ return ""
767
+ arch_list_ = [arch.split("_") for arch in arch_list]
768
+ return " ".join(
769
+ [
770
+ f"-gencode compute=compute_{arch},code={kind}_{arch}"
771
+ for (kind, arch) in arch_list_
772
+ ]
773
+ )
774
+
775
+
776
+ def current_device() -> int:
777
+ r"""Return the index of a currently selected device."""
778
+ _lazy_init()
779
+ return torch._C._cuda_getDevice()
780
+
781
+
782
+ def synchronize(device: _device_t = None) -> None:
783
+ r"""Wait for all kernels in all streams on a CUDA device to complete.
784
+
785
+ Args:
786
+ device (torch.device or int, optional): device for which to synchronize.
787
+ It uses the current device, given by :func:`~torch.cuda.current_device`,
788
+ if :attr:`device` is ``None`` (default).
789
+ """
790
+ _lazy_init()
791
+ with torch.cuda.device(device):
792
+ return torch._C._cuda_synchronize()
793
+
794
+
795
+ def ipc_collect():
796
+ r"""Force collects GPU memory after it has been released by CUDA IPC.
797
+
798
+ .. note::
799
+ Checks if any sent CUDA tensors could be cleaned from the memory. Force
800
+ closes shared memory file used for reference counting if there is no
801
+ active counters. Useful when the producer process stopped actively sending
802
+ tensors and want to release unused memory.
803
+ """
804
+ _lazy_init()
805
+ return torch._C._cuda_ipc_collect()
806
+
807
+
808
+ def current_stream(device: Optional[_device_t] = None) -> Stream:
809
+ r"""Return the currently selected :class:`Stream` for a given device.
810
+
811
+ Args:
812
+ device (torch.device or int, optional): selected device. Returns
813
+ the currently selected :class:`Stream` for the current device, given
814
+ by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
815
+ (default).
816
+ """
817
+ _lazy_init()
818
+ streamdata = torch._C._cuda_getCurrentStream(
819
+ _get_device_index(device, optional=True)
820
+ )
821
+ return Stream(
822
+ stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
823
+ )
824
+
825
+
826
+ def default_stream(device: Optional[_device_t] = None) -> Stream:
827
+ r"""Return the default :class:`Stream` for a given device.
828
+
829
+ Args:
830
+ device (torch.device or int, optional): selected device. Returns
831
+ the default :class:`Stream` for the current device, given by
832
+ :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
833
+ (default).
834
+ """
835
+ _lazy_init()
836
+ streamdata = torch._C._cuda_getDefaultStream(
837
+ _get_device_index(device, optional=True)
838
+ )
839
+ return Stream(
840
+ stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
841
+ )
842
+
843
+
844
+ def current_blas_handle():
845
+ r"""Return cublasHandle_t pointer to current cuBLAS handle"""
846
+ _lazy_init()
847
+ return torch._C._cuda_getCurrentBlasHandle()
848
+
849
+
850
+ def set_sync_debug_mode(debug_mode: Union[int, str]) -> None:
851
+ r"""Set the debug mode for cuda synchronizing operations.
852
+
853
+ Args:
854
+ debug_mode(str or int): if "default" or 0, don't error or warn on synchronizing operations,
855
+ if "warn" or 1, warn on synchronizing operations, if "error" or 2, error out synchronizing operations.
856
+
857
+ Warning:
858
+ This is an experimental feature, and not all synchronizing operations will trigger warning or error. In
859
+ particular, operations in torch.distributed and torch.sparse namespaces are not covered yet.
860
+ """
861
+ _lazy_init()
862
+ if isinstance(debug_mode, str):
863
+ if debug_mode == "default":
864
+ debug_mode = 0
865
+ elif debug_mode == "warn":
866
+ debug_mode = 1
867
+ elif debug_mode == "error":
868
+ debug_mode = 2
869
+ else:
870
+ raise RuntimeError(
871
+ "invalid value of debug_mode, expected one of `default`, `warn`, `error`"
872
+ )
873
+
874
+ torch._C._cuda_set_sync_debug_mode(debug_mode)
875
+
876
+
877
+ def get_sync_debug_mode() -> int:
878
+ r"""Return current value of debug mode for cuda synchronizing operations."""
879
+ _lazy_init()
880
+ return torch._C._cuda_get_sync_debug_mode()
881
+
882
+
883
+ def _get_pynvml_handler(device: Optional[Union[Device, int]] = None):
884
+ if not _HAS_PYNVML:
885
+ raise ModuleNotFoundError(
886
+ "pynvml does not seem to be installed or it can't be imported."
887
+ ) from _PYNVML_ERR
888
+ from pynvml import NVMLError_DriverNotLoaded
889
+
890
+ try:
891
+ pynvml.nvmlInit()
892
+ except NVMLError_DriverNotLoaded as e:
893
+ raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
894
+
895
+ device = _get_nvml_device_index(device)
896
+ handle = pynvml.nvmlDeviceGetHandleByIndex(device)
897
+ return handle
898
+
899
+
900
+ def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
901
+ r"""Return the percent of time over the past sample period during which global (device)
902
+ memory was being read or written as given by `nvidia-smi`.
903
+
904
+ Args:
905
+ device (torch.device or int, optional): selected device. Returns
906
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
907
+ if :attr:`device` is ``None`` (default).
908
+
909
+ Warning: Each sample period may be between 1 second and 1/6 second,
910
+ depending on the product being queried.
911
+ """
912
+ handle = _get_pynvml_handler()
913
+
914
+ device = _get_nvml_device_index(device)
915
+ handle = pynvml.nvmlDeviceGetHandleByIndex(device)
916
+ return pynvml.nvmlDeviceGetUtilizationRates(handle).memory
917
+
918
+
919
+ def utilization(device: Optional[Union[Device, int]] = None) -> int:
920
+ r"""Return the percent of time over the past sample period during which one or
921
+ more kernels was executing on the GPU as given by `nvidia-smi`.
922
+
923
+ Args:
924
+ device (torch.device or int, optional): selected device. Returns
925
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
926
+ if :attr:`device` is ``None`` (default).
927
+
928
+ Warning: Each sample period may be between 1 second and 1/6 second,
929
+ depending on the product being queried.
930
+ """
931
+ handle = _get_pynvml_handler(device)
932
+ device = _get_nvml_device_index(device)
933
+ handle = pynvml.nvmlDeviceGetHandleByIndex(device)
934
+ return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu
935
+
936
+
937
+ def temperature(device: Optional[Union[Device, int]] = None) -> int:
938
+ r"""Return the average temperature of the GPU sensor in Degrees C (Centigrades).
939
+
940
+ The average temperature is computed based on past sample period as given by `nvidia-smi`.
941
+
942
+ Args:
943
+ device (torch.device or int, optional): selected device. Returns
944
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
945
+ if :attr:`device` is ``None`` (default).
946
+
947
+ Warning: Each sample period may be between 1 second and 1/6 second,
948
+ depending on the product being queried.
949
+ """
950
+ handle = _get_pynvml_handler(device)
951
+ # 0 refers to the temperature sensor for the GPU die.
952
+ return pynvml.nvmlDeviceGetTemperature(handle, 0)
953
+
954
+
955
+ def power_draw(device: Optional[Union[Device, int]] = None) -> int:
956
+ r"""Return the average power draw of the GPU sensor in mW (MilliWatts)
957
+ over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices.
958
+
959
+ Args:
960
+ device (torch.device or int, optional): selected device. Returns
961
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
962
+ if :attr:`device` is ``None`` (default).
963
+
964
+ Warning: Each sample period may be between 1 second and 1/6 second,
965
+ depending on the product being queried.
966
+ """
967
+ handle = _get_pynvml_handler(device)
968
+ return pynvml.nvmlDeviceGetPowerUsage(handle)
969
+
970
+
971
+ def clock_rate(device: Optional[Union[Device, int]] = None) -> int:
972
+ r"""Return the clock speed of the GPU SM in Hz Hertz over the past sample period as given by `nvidia-smi`.
973
+
974
+ Args:
975
+ device (torch.device or int, optional): selected device. Returns
976
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
977
+ if :attr:`device` is ``None`` (default).
978
+
979
+ Warning: Each sample period may be between 1 second and 1/6 second,
980
+ depending on the product being queried.
981
+ """
982
+ handle = _get_pynvml_handler(device)
983
+ return pynvml.nvmlDeviceGetClockInfo(handle, 1)
984
+
985
+
986
+ def _get_device(device: Union[int, str, torch.device]) -> torch.device:
987
+ r"""Return the torch.device type object from the passed in device.
988
+
989
+ Args:
990
+ device (torch.device or int): selected device.
991
+ """
992
+ if isinstance(device, str):
993
+ device = torch.device(device)
994
+ elif isinstance(device, int):
995
+ device = torch.device("cuda", device)
996
+ return device
997
+
998
+
999
def _get_generator(device: torch.device) -> torch._C.Generator:
    r"""Return the CUDA RNG generator object for *device*.

    Falls back to the current device when *device* carries no index.
    """
    index = device.index if device.index is not None else current_device()
    return torch.cuda.default_generators[index]
1009
+
1010
+
1011
def _set_rng_state_offset(
    offset: int, device: Union[int, str, torch.device] = "cuda"
) -> None:
    r"""Set the random number generator state offset of the specified GPU.

    Args:
        offset (int): The desired offset
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
    """
    target = _get_device(device)

    def _apply() -> None:
        # Deferred via _lazy_call so CUDA is only initialized when needed.
        _get_generator(target).set_offset(offset)

    _lazy_call(_apply)
1028
+
1029
+
1030
def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int:
    r"""Return the random number generator state offset of the specified GPU.

    Args:
        device (torch.device or int, optional): The device to return the RNG state offset of.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    generator = _get_generator(_get_device(device))
    return generator.get_offset()
1044
+
1045
+
1046
+ from .memory import * # noqa: F403
1047
+
1048
+
1049
+ from .random import * # noqa: F403
1050
+
1051
+ ################################################################################
1052
+ # Define Storage and Tensor classes
1053
+ ################################################################################
1054
+
1055
+
1056
@staticmethod  # type: ignore[misc]
def _lazy_new(cls, *args, **kwargs):
    # __new__ hook shared by the CUDA storage/tensor classes (installed as
    # _CudaBase.__new__ below): guarantees the CUDA runtime is initialized
    # before the first CUDA object is constructed.
    _lazy_init()
    # We may need to call lazy init again if we are a forked child
    # del _CudaBase.__new__
    return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
1062
+
1063
+
1064
class _CudaBase:
    # Mixin for the legacy CUDA storage/tensor types: marks instances as
    # CUDA-resident and routes construction through _lazy_new so the CUDA
    # runtime is initialized before the first allocation.
    is_cuda = True
    is_sparse = False

    def type(self, *args, **kwargs):
        # We could use a Protocol here to tell mypy that self has `get_device` method
        # but it is only available in the typing module on Python >= 3.8
        # or on typing_extensions module on Python >= 3.6
        # Switch to this object's device so the conversion lands on the right GPU.
        with device(self.get_device()):  # type: ignore[attr-defined]
            return super().type(*args, **kwargs)  # type: ignore[misc]

    __new__ = _lazy_new
1076
+
1077
+
1078
+ from torch.storage import _LegacyStorage, _warn_typed_storage_removal
1079
+
1080
+
1081
class _CudaLegacyStorage(_LegacyStorage):
    # Common base for the deprecated CUDA typed storages below; disables the
    # construction paths that are not supported for CUDA memory.

    @classmethod
    def from_buffer(cls, *args, **kwargs):
        # Emits the TypedStorage deprecation warning before rejecting the call.
        _warn_typed_storage_removal()
        raise RuntimeError("from_buffer: Not available for CUDA storage")

    @classmethod
    def _new_with_weak_ptr(cls, *args, **kwargs):
        raise RuntimeError("_new_with_weak_ptr: Not available for CUDA storage")

    @classmethod
    def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None):
        # File-backed shared memory only exists for CPU storages.
        raise RuntimeError("_new_shared_filename: Not available for CUDA storage")
1094
+
1095
+
1096
class ByteStorage(_CudaLegacyStorage):
    """Legacy CUDA storage with ``torch.uint8`` elements."""

    @classproperty
    def _dtype(self):
        return torch.uint8

    @classproperty
    def dtype(self):
        # Public accessor: warn about the TypedStorage deprecation first.
        _warn_typed_storage_removal()
        return self._dtype
1105
+
1106
+
1107
class DoubleStorage(_CudaLegacyStorage):
    """Legacy CUDA storage with ``torch.double`` elements."""

    @classproperty
    def _dtype(self):
        return torch.double

    @classproperty
    def dtype(self):
        # Public accessor: warn about the TypedStorage deprecation first.
        _warn_typed_storage_removal()
        return self._dtype
1116
+
1117
+
1118
class FloatStorage(_CudaLegacyStorage):
    """Legacy CUDA storage with ``torch.float`` elements."""

    @classproperty
    def _dtype(self):
        return torch.float

    @classproperty
    def dtype(self):
        # Public accessor: warn about the TypedStorage deprecation first.
        _warn_typed_storage_removal()
        return self._dtype
1127
+
1128
+
1129
class HalfStorage(_CudaLegacyStorage):
    """Legacy CUDA storage with ``torch.half`` elements."""

    @classproperty
    def _dtype(self):
        return torch.half

    @classproperty
    def dtype(self):
        # Public accessor: warn about the TypedStorage deprecation first.
        _warn_typed_storage_removal()
        return self._dtype
1138
+
1139
+
1140
class LongStorage(_CudaLegacyStorage):
    """Legacy CUDA storage with ``torch.long`` elements."""

    @classproperty
    def _dtype(self):
        return torch.long

    @classproperty
    def dtype(self):
        # Public accessor: warn about the TypedStorage deprecation first.
        _warn_typed_storage_removal()
        return self._dtype
1149
+
1150
+
1151
class IntStorage(_CudaLegacyStorage):
    """Legacy CUDA storage with ``torch.int`` elements."""

    @classproperty
    def _dtype(self):
        return torch.int

    @classproperty
    def dtype(self):
        # Public accessor: warn about the TypedStorage deprecation first.
        _warn_typed_storage_removal()
        return self._dtype
1160
+
1161
+
1162
class ShortStorage(_CudaLegacyStorage):
    """Legacy CUDA storage with ``torch.short`` elements."""

    @classproperty
    def _dtype(self):
        return torch.short

    @classproperty
    def dtype(self):
        # Public accessor: warn about the TypedStorage deprecation first.
        _warn_typed_storage_removal()
        return self._dtype
1171
+
1172
+
1173
class CharStorage(_CudaLegacyStorage):
    """Legacy CUDA storage with ``torch.int8`` elements."""

    @classproperty
    def _dtype(self):
        return torch.int8

    @classproperty
    def dtype(self):
        # Public accessor: warn about the TypedStorage deprecation first.
        _warn_typed_storage_removal()
        return self._dtype
1182
+
1183
+
1184
class BoolStorage(_CudaLegacyStorage):
    """Legacy CUDA storage with ``torch.bool`` elements."""

    @classproperty
    def _dtype(self):
        return torch.bool

    @classproperty
    def dtype(self):
        # Public accessor: warn about the TypedStorage deprecation first.
        _warn_typed_storage_removal()
        return self._dtype
1193
+
1194
+
1195
class BFloat16Storage(_CudaLegacyStorage):
    """Legacy CUDA storage with ``torch.bfloat16`` elements."""

    @classproperty
    def _dtype(self):
        return torch.bfloat16

    @classproperty
    def dtype(self):
        # Public accessor: warn about the TypedStorage deprecation first.
        _warn_typed_storage_removal()
        return self._dtype
1204
+
1205
+
1206
class ComplexDoubleStorage(_CudaLegacyStorage):
    """Legacy CUDA storage with ``torch.cdouble`` elements."""

    @classproperty
    def _dtype(self):
        return torch.cdouble

    @classproperty
    def dtype(self):
        # Public accessor: warn about the TypedStorage deprecation first.
        _warn_typed_storage_removal()
        return self._dtype
1215
+
1216
+
1217
class ComplexFloatStorage(_CudaLegacyStorage):
    """Legacy CUDA storage with ``torch.cfloat`` elements."""

    @classproperty
    def _dtype(self):
        return torch.cfloat

    @classproperty
    def dtype(self):
        # Public accessor: warn about the TypedStorage deprecation first.
        _warn_typed_storage_removal()
        return self._dtype
1226
+
1227
+
1228
# The legacy base classes are module-private; remove them from the namespace.
del _LegacyStorage
del _CudaLegacyStorage

# Register every CUDA typed-storage class with torch's global registry, in
# the same order the original explicit calls used.
for _storage in (
    DoubleStorage,
    FloatStorage,
    LongStorage,
    IntStorage,
    ShortStorage,
    CharStorage,
    ByteStorage,
    HalfStorage,
    BoolStorage,
    BFloat16Storage,
    ComplexDoubleStorage,
    ComplexFloatStorage,
):
    torch._storage_classes.add(_storage)
del _storage
1243
+
1244
+
1245
+ class _WrappedTritonKernel:
1246
+ """Just a simple wrapper to store some metadata for testing purposes."""
1247
+
1248
+ def __init__(self, kernel):
1249
+ self.kernel = kernel
1250
+ self.kernel_invoked = False
1251
+
1252
+ def __call__(self, *args, **kwargs):
1253
+ res = self.kernel(*args, **kwargs)
1254
+ self.kernel_invoked = True
1255
+ return res
1256
+
1257
+
1258
def _register_triton_kernels():
    # Registers Triton-backed sparse BSR matmul kernels as CUDA dispatch ops.
    # Called lazily (via _lazy_call below) so registration happens on first
    # CUDA initialization.

    # torch.deploy builds cannot register these ops; bail out early.
    if torch._running_with_deploy():
        return

    @_WrappedTritonKernel
    def kernel_impl(*args, **kwargs):
        # Import deferred: torch.sparse._triton_ops requires triton.
        from torch.sparse._triton_ops import bsr_dense_mm

        return bsr_dense_mm(*args, skip_checks=True, **kwargs)

    @_WrappedTritonKernel
    def addmm_kernel_impl(*args, **kwargs):
        from torch.sparse._triton_ops import bsr_dense_addmm

        return bsr_dense_addmm(*args, skip_checks=True, **kwargs)

    # Only register when the triton package is importable.
    has_triton = importlib.util.find_spec("triton") is not None
    if has_triton:
        torch._TritonLibrary.registerOp(
            "_triton_bsr_dense_mm_out",
            "_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
            kernel_impl,
            "SparseCsrCUDA",
        )

        torch._TritonLibrary.registerOp(
            "_triton_bsr_dense_addmm_out",
            (
                "_triton_bsr_dense_addmm_out(Tensor input, Tensor bsr, Tensor dense,"
                " *, Scalar beta, Scalar alpha, Tensor(a!) out) -> Tensor(a!)"
            ),
            addmm_kernel_impl,
            "SparseCsrCUDA",
        )
1292
+
1293
+
1294
+ _lazy_call(_register_triton_kernels)
1295
+
1296
+
1297
+ from . import amp, jiterator, nvtx, profiler, sparse
1298
+
1299
+ __all__ = [
1300
+ # Typed storage and tensors
1301
+ "BFloat16Storage",
1302
+ "BFloat16Tensor",
1303
+ "BoolStorage",
1304
+ "BoolTensor",
1305
+ "ByteStorage",
1306
+ "ByteTensor",
1307
+ "CharStorage",
1308
+ "CharTensor",
1309
+ "ComplexDoubleStorage",
1310
+ "ComplexFloatStorage",
1311
+ "DoubleStorage",
1312
+ "DoubleTensor",
1313
+ "FloatStorage",
1314
+ "FloatTensor",
1315
+ "HalfStorage",
1316
+ "HalfTensor",
1317
+ "IntStorage",
1318
+ "IntTensor",
1319
+ "LongStorage",
1320
+ "LongTensor",
1321
+ "ShortStorage",
1322
+ "ShortTensor",
1323
+ "CUDAGraph",
1324
+ "CudaError",
1325
+ "DeferredCudaCallError",
1326
+ "Event",
1327
+ "ExternalStream",
1328
+ "OutOfMemoryError",
1329
+ "Stream",
1330
+ "StreamContext",
1331
+ "amp",
1332
+ "caching_allocator_alloc",
1333
+ "caching_allocator_delete",
1334
+ "can_device_access_peer",
1335
+ "check_error",
1336
+ "cudaStatus",
1337
+ "cudart",
1338
+ "current_blas_handle",
1339
+ "current_device",
1340
+ "current_stream",
1341
+ "default_generators",
1342
+ "default_stream",
1343
+ "device",
1344
+ "device_count",
1345
+ "device_of",
1346
+ "empty_cache",
1347
+ "get_allocator_backend",
1348
+ "CUDAPluggableAllocator",
1349
+ "change_current_allocator",
1350
+ "get_arch_list",
1351
+ "get_device_capability",
1352
+ "get_device_name",
1353
+ "get_device_properties",
1354
+ "get_gencode_flags",
1355
+ "get_rng_state",
1356
+ "get_rng_state_all",
1357
+ "get_sync_debug_mode",
1358
+ "graph",
1359
+ "graph_pool_handle",
1360
+ "graphs",
1361
+ "has_half",
1362
+ "has_magma",
1363
+ "init",
1364
+ "initial_seed",
1365
+ "ipc_collect",
1366
+ "is_available",
1367
+ "is_bf16_supported",
1368
+ "is_current_stream_capturing",
1369
+ "is_initialized",
1370
+ "jiterator",
1371
+ "list_gpu_processes",
1372
+ "make_graphed_callables",
1373
+ "manual_seed",
1374
+ "manual_seed_all",
1375
+ "max_memory_allocated",
1376
+ "max_memory_cached",
1377
+ "max_memory_reserved",
1378
+ "mem_get_info",
1379
+ "memory",
1380
+ "memory_allocated",
1381
+ "memory_cached",
1382
+ "memory_reserved",
1383
+ "memory_snapshot",
1384
+ "memory_stats",
1385
+ "memory_stats_as_nested_dict",
1386
+ "memory_summary",
1387
+ "memory_usage",
1388
+ "temperature",
1389
+ "power_draw",
1390
+ "clock_rate",
1391
+ "nccl",
1392
+ "nvtx",
1393
+ "profiler",
1394
+ "random",
1395
+ "reset_accumulated_memory_stats",
1396
+ "reset_max_memory_allocated",
1397
+ "reset_max_memory_cached",
1398
+ "reset_peak_memory_stats",
1399
+ "seed",
1400
+ "seed_all",
1401
+ "set_device",
1402
+ "set_per_process_memory_fraction",
1403
+ "set_rng_state",
1404
+ "set_rng_state_all",
1405
+ "set_stream",
1406
+ "set_sync_debug_mode",
1407
+ "sparse",
1408
+ "stream",
1409
+ "streams",
1410
+ "synchronize",
1411
+ "utilization",
1412
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/memory.cpython-311.pyc ADDED
Binary file (44.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/_sanitizer.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""
2
+ This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels ran on different streams.
3
+
4
+ It stores information on accesses to tensors to determine if they are synchronized
5
+ or not. When enabled in a python program and a possible data race is detected, a
6
+ detailed warning will be printed and the program will exit.
7
+
8
+ It can be enabled either by importing this module and calling
9
+ :func:`enable_cuda_sanitizer()` or by exporting the ``TORCH_CUDA_SANITIZER``
10
+ environment variable.
11
+ """
12
+
13
+ import enum
14
+ import functools
15
+ import inspect
16
+ import io
17
+ import logging
18
+ import sys
19
+ import textwrap
20
+ import traceback
21
+ from dataclasses import dataclass, field
22
+ from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar
23
+
24
+ import torch
25
+ import torch.utils._cuda_trace as cuda_trace
26
+ from torch.utils import _pytree as pytree
27
+ from torch.utils._python_dispatch import TorchDispatchMode
28
+
29
+
30
+ DEFAULT_STREAM_ID = 0
31
+
32
+ TK = TypeVar("TK")
33
+ TVa = TypeVar("TVa")
34
+ TVb = TypeVar("TVb")
35
+
36
+ DataPtr = int
37
+ StreamId = int
38
+ EventId = int
39
+ SeqNum = int
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
class AccessType(enum.Enum):
    """Kind of tensor access performed by a kernel: read or write."""

    READ = enum.auto()
    WRITE = enum.auto()

    def __str__(self):
        # Rendered inside data-race reports, e.g. "reading from argument(s) ...".
        if self is AccessType.READ:
            return "reading from"
        return "writing to"
50
+
51
+
52
@dataclass
class Access:
    r"""Stores information about a single access to a tensor by a kernel.

    Args:
        type: either AccessType.READ or AccessType.WRITE.
        seq_num: the sequential number of the kernel performing the access.
        stream: the stream id of the stream executing the kernel.
        operator: the schema of the launched kernel, which lists the
            arguments and return type.
        aliases: the arguments in the schema this access corresponds to.
        is_output: Whether the tensor was an output of the kernel.
        stack_trace: the stack summary object captured during access.
    """

    type: AccessType
    seq_num: SeqNum
    stream: StreamId
    operator: str
    aliases: List[str]
    is_output: bool
    stack_trace: traceback.StackSummary
74
+
75
+
76
class SynchronizationError(Exception):
    """Base class for errors detected by CUDA Sanitizer."""
80
+
81
+
82
class UnsynchronizedAccessError(SynchronizationError):
    """Stores information about two unsynchronized accesses to one data pointer."""

    def __init__(
        self,
        data_ptr: DataPtr,
        allocation_stack_trace: Optional[traceback.StackSummary],
        current_access: Access,
        previous_access: Access,
    ):
        self.data_ptr = data_ptr
        self.allocation_stack_trace = allocation_stack_trace
        self.current_access = current_access
        self.previous_access = previous_access

    def __str__(self):
        # Builds the full human-readable race report.
        def format_access(access: Access):
            # Closure over `message`, which is bound by the `with` block below
            # before this helper is ever called.
            message.write(f"{access.operator}\n{access.type}")
            if access.aliases:
                message.write(" argument(s) " + ", ".join(access.aliases))
                if access.is_output:
                    message.write(", and to")
            if access.is_output:
                message.write(" the output")
            message.write(
                f"\nWith stack trace:\n{''.join(access.stack_trace.format())}\n"
            )

        with io.StringIO() as message:
            message.write(
                textwrap.dedent(
                    f"""\
                    ============================
                    CSAN detected a possible data race on tensor with data pointer {self.data_ptr}
                    Access by stream {self.current_access.stream} during kernel:
                    """
                )
            )
            format_access(self.current_access)

            message.write(
                f"Previous access by stream {self.previous_access.stream} during kernel:\n"
            )
            format_access(self.previous_access)

            # The allocation trace may be missing when the sanitizer was
            # enabled after the tensor was allocated.
            if self.allocation_stack_trace:
                message.write(
                    "Tensor was allocated with stack trace:\n"
                    f"{''.join(self.allocation_stack_trace.format())}"
                )
            else:
                message.write("Trace for tensor allocation not found.")
            return message.getvalue()
135
+
136
+
137
class CUDASanitizerErrors(Exception):
    """Aggregate exception wrapping every error reported by CUDA Sanitizer."""

    def __init__(self, errors: List[SynchronizationError]):
        self.errors = errors

    def __str__(self):
        count = len(self.errors)
        return f"detected {count} errors"
145
+
146
+
147
@dataclass
class TensorInfo:
    r"""Stores information about a single tensor and recent accesses to it.

    Args:
        allocation_stack_trace: the stack summary object captured during tensor
            allocation. Can be ``None`` if the allocation wasn't caught by CSAN.
        reads: list of read accesses to the tensor that were performed since
            the last write.
        write: the last write access to the tensor.
    """

    allocation_stack_trace: Optional[traceback.StackSummary]
    reads: List[Access] = field(default_factory=list)
    write: Optional[Access] = None
162
+
163
+
164
class _TensorsAccessed:
    """Tracks, per data pointer, the allocation trace and recent read/write accesses."""

    def __init__(self):
        # Maps a tensor's data pointer to its access bookkeeping.
        self.accesses: Dict[DataPtr, TensorInfo] = {}

    def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
        # Backfill a missing allocation event (sanitizer enabled late).
        if data_ptr not in self.accesses:
            logger.info(
                "Found tensor with pointer: %s, but no matching tensor "
                "allocation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.create_tensor(data_ptr, None)

    def ensure_tensor_does_not_exist(self, data_ptr: DataPtr) -> None:
        # Backfill a missing deallocation event before re-recording the pointer.
        if data_ptr in self.accesses:
            logger.info(
                "Found duplicate tensor allocation in the trace for tensor with "
                "pointer: %s. Assuming the trace for tensor deallocation "
                "wasn't caught and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.delete_tensor(data_ptr)

    def create_tensor(
        self, data_ptr: DataPtr, stack_trace: Optional[traceback.StackSummary]
    ) -> None:
        self.accesses[data_ptr] = TensorInfo(stack_trace)

    def delete_tensor(self, data_ptr: DataPtr) -> None:
        del self.accesses[data_ptr]

    def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool:
        # Idiom fix: `True if x else False` is just `bool(x)`.
        return bool(self.accesses[data_ptr].reads)

    def get_allocation_stack_trace(
        self, data_ptr: DataPtr
    ) -> Optional[traceback.StackSummary]:
        return self.accesses[data_ptr].allocation_stack_trace

    def get_write(self, data_ptr: DataPtr) -> Optional[Access]:
        return self.accesses[data_ptr].write

    def get_reads(self, data_ptr: DataPtr) -> List[Access]:
        return self.accesses[data_ptr].reads

    def add_read(self, data_ptr: DataPtr, access: Access) -> None:
        self.accesses[data_ptr].reads.append(access)

    def set_write(self, data_ptr: DataPtr, access: Access) -> None:
        # A write supersedes all reads recorded since the previous write.
        self.accesses[data_ptr].write = access
        self.accesses[data_ptr].reads = []
217
+
218
+
219
class StreamSynchronizations:
    # Vector-clock-style bookkeeping of which kernel sequence numbers each
    # stream is known to be ordered after, per stream/event/host.

    def __init__(self):
        # For each stream: the latest seq_num of every other stream it has
        # synchronized with (including itself).
        self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
        # Snapshot of a stream's sync state taken when an event was recorded.
        self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
        # What the host thread is known to be ordered after.
        self.host_sync_state: Dict[StreamId, SeqNum] = {}
        self.create_stream(DEFAULT_STREAM_ID)

    def _ensure_stream_exists(self, stream: StreamId) -> None:
        # Backfill stream creation if the sanitizer missed it.
        if stream not in self.current_sync_states:
            logger.info(
                "Found Stream with id: %s, but no matching stream "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                stream,
            )
            self.create_stream(stream)

    def _ensure_event_exists(self, event: EventId) -> None:
        # Backfill event creation if the sanitizer missed it.
        if event not in self.recorded_sync_states:
            logger.info(
                "Found Event with id: %s, but no matching event "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                event,
            )
            self.create_event(event)

    def _ensure_event_does_not_exist(self, event: EventId) -> None:
        # Backfill a missed event deletion before re-creating the id.
        if event in self.recorded_sync_states:
            logger.info(
                "Found duplicate event creation in the trace for event with "
                "id: %s. Assuming the trace for event deletion wasn't caught "
                "and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                event,
            )
            self.delete_event(event)

    def create_stream(self, stream: StreamId) -> None:
        if stream in self.current_sync_states:
            logger.info(
                "Found duplicate Stream creation in the trace for Stream with "
                "id: %s. PyTorch Streams are only created once, so this "
                "trace entry is ignored.",
                stream,
            )
        else:
            # A new stream starts synchronized with everything the host has seen.
            self.host_sync_state[stream] = 0
            self.current_sync_states[stream] = self.host_sync_state.copy()

    def create_event(self, event: EventId) -> None:
        self._ensure_event_does_not_exist(event)
        self.recorded_sync_states[event] = {}

    def delete_event(self, event: EventId) -> None:
        self._ensure_event_exists(event)
        del self.recorded_sync_states[event]

    def update_seq_num(self, stream: StreamId, seq_num: SeqNum) -> None:
        # Advance a stream's own clock after it launches a kernel.
        self._ensure_stream_exists(stream)
        self.current_sync_states[stream][stream] = seq_num

    def record_state(self, event: EventId, stream: StreamId) -> None:
        # cudaEventRecord: snapshot the stream's current sync state.
        self._ensure_event_exists(event)
        self._ensure_stream_exists(stream)
        self.recorded_sync_states[event] = self.current_sync_states[stream].copy()

    def _state_wait_for_other(
        self, state: Dict[StreamId, SeqNum], other: Dict[StreamId, SeqNum]
    ) -> None:
        # Pointwise max-merge: `state` becomes ordered after everything `other` is.
        for stream, seq_num in other.items():
            state[stream] = max(state.get(stream, -1), seq_num)

    def stream_wait_for_event(self, stream: StreamId, event: EventId) -> None:
        self._ensure_stream_exists(stream)
        self._ensure_event_exists(event)
        self._state_wait_for_other(
            self.current_sync_states[stream], self.recorded_sync_states[event]
        )

    def all_streams_wait_for_event(self, event: EventId) -> None:
        # cudaEventSynchronize: every stream and the host wait for the event.
        self._ensure_event_exists(event)
        for stream in self.current_sync_states.keys():
            self.stream_wait_for_event(stream, event)

        self._state_wait_for_other(
            self.host_sync_state, self.recorded_sync_states[event]
        )

    def all_streams_wait_for_stream(self, stream: StreamId) -> None:
        # cudaStreamSynchronize: every stream and the host wait for `stream`.
        self._ensure_stream_exists(stream)
        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.current_sync_states[stream])

        self._state_wait_for_other(
            self.host_sync_state, self.current_sync_states[stream]
        )

    def sync_all_streams(self) -> None:
        # cudaDeviceSynchronize: host catches up to every stream, then every
        # stream catches up to the host.
        for stream, state in self.current_sync_states.items():
            self.host_sync_state[stream] = state[stream]

        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.host_sync_state)

    def is_ordered_after(
        self, current_stream: StreamId, seq_num: SeqNum, other_stream: StreamId
    ) -> bool:
        # True if kernel `seq_num` on `other_stream` is known to happen-before
        # the current position of `current_stream`.
        self._ensure_stream_exists(current_stream)
        self._ensure_stream_exists(other_stream)
        return seq_num <= self.current_sync_states[current_stream].get(other_stream, -1)
330
+
331
+
332
class EventHandler:
    """Analyzes CSAN trace for synchronization errors.

    Stores information on each stream's synchronizations with other streams as well
    as tensor accesses to determine whether a given kernel launch might cause a
    data race.
    """

    def __init__(self):
        self.tensors_accessed = _TensorsAccessed()
        self.syncs = StreamSynchronizations()
        # Monotonically increasing id assigned to each kernel launch.
        self.seq_num: SeqNum = 0

    def _handle_kernel_launch(
        self,
        stream: StreamId,
        read_only: Set[DataPtr],
        read_write: Set[DataPtr],
        outputs: Set[DataPtr],
        operator: str,
        tensor_aliases: Dict[int, List[str]],
    ) -> List[SynchronizationError]:
        # NOTE(review): the caller (CUDASanitizerDispatchMode.__torch_dispatch__)
        # passes func._schema as `operator` despite the `str` annotation — confirm
        # the intended type.
        def check_conflict(
            data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access]
        ) -> None:
            # A race exists when the previous access is not ordered before the
            # current one by any recorded synchronization.
            if previous_access is None:
                return
            if not self.syncs.is_ordered_after(
                current_access.stream, previous_access.seq_num, previous_access.stream
            ):
                error_list.append(
                    UnsynchronizedAccessError(
                        data_ptr,
                        self.tensors_accessed.get_allocation_stack_trace(data_ptr),
                        current_access,
                        previous_access,
                    )
                )

        error_list: List[SynchronizationError] = []
        self.seq_num += 1
        self.syncs.update_seq_num(stream, self.seq_num)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # The stack trace generated in this way is in the inverse order, so it must be
        # reversed.
        stack_trace.reverse()

        # Reads only conflict with the last write.
        for data_ptr in read_only:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.READ,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            check_conflict(
                data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
            )
            self.tensors_accessed.add_read(data_ptr, current_access)

        # Writes conflict with all reads since the last write, or with the
        # last write itself when no reads intervened.
        for data_ptr in read_write:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.WRITE,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            if self.tensors_accessed.were_there_reads_since_last_write(data_ptr):
                for previous_access in self.tensors_accessed.get_reads(data_ptr):
                    check_conflict(data_ptr, current_access, previous_access)
            else:
                check_conflict(
                    data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
                )
            self.tensors_accessed.set_write(data_ptr, current_access)

        return error_list

    # The remaining handlers translate raw CUDA trace callbacks into updates
    # on the tensor-access and synchronization bookkeeping.
    def _handle_event_creation(self, event: EventId) -> None:
        self.syncs.create_event(event)

    def _handle_event_deletion(self, event: EventId) -> None:
        self.syncs.delete_event(event)

    def _handle_event_record(self, event: EventId, stream: StreamId) -> None:
        self.syncs.record_state(event, stream)

    def _handle_event_wait(self, event: EventId, stream: StreamId) -> None:
        self.syncs.stream_wait_for_event(stream, event)

    def _handle_memory_allocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_does_not_exist(data_ptr)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # The stack trace generated in this way is in the inverse order, so it must be
        # reversed.
        stack_trace.reverse()
        self.tensors_accessed.create_tensor(
            data_ptr,
            stack_trace,
        )

    def _handle_memory_deallocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_exists(data_ptr)
        self.tensors_accessed.delete_tensor(data_ptr)

    def _handle_stream_creation(self, stream: StreamId) -> None:
        self.syncs.create_stream(stream)

    def _handle_device_synchronization(self) -> None:
        self.syncs.sync_all_streams()

    def _handle_stream_synchronization(self, stream: StreamId) -> None:
        self.syncs.all_streams_wait_for_stream(stream)

    def _handle_event_synchronization(self, event: EventId) -> None:
        self.syncs.all_streams_wait_for_event(event)
459
+
460
+
461
def zip_by_key(a: Dict[TK, TVa], b: Dict[TK, TVb]) -> Iterator[Tuple[TK, TVa, TVb]]:
    """Yield ``(key, a[key], b[key])`` for every key present in both mappings.

    Iteration follows the order of *a*; keys missing from *b* are skipped.
    """
    for key in a:
        if key in b:
            yield key, a[key], b[key]
465
+
466
+
467
def zip_arguments(
    schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterator[Tuple[torch.Argument, Any]]:
    """Pair each supplied argument value with its schema argument.

    Positional values are matched by position; keyword values are matched by
    name against the remaining schema arguments.
    """
    positional = schema.arguments[: len(args)]
    by_name = {argument.name: argument for argument in schema.arguments[len(args) :]}

    yield from zip(positional, args)

    for _, argument, value in zip_by_key(by_name, kwargs):
        yield (argument, value)
477
+
478
+
479
class ArgumentHandler:
    # Collects, for one operator call, the set of CUDA data pointers that are
    # read, written, or produced as outputs, plus the schema-argument names
    # each pointer aliases.

    def __init__(self):
        self.dataptrs_read: Set[DataPtr] = set()
        self.dataptrs_written: Set[DataPtr] = set()
        self.tensor_aliases: Dict[DataPtr, List[str]] = dict()
        self.outputs: Set[DataPtr] = set()

    def _handle_argument(
        self,
        value: Any,
        is_write: bool,
        name: Optional[str] = None,
        is_output: bool = False,
    ) -> None:
        # Only CUDA tensors are tracked; everything else is ignored.
        if isinstance(value, torch.Tensor) and value.is_cuda:
            data_ptr = value.data_ptr()
            if is_write:
                self.dataptrs_written.add(data_ptr)
            else:
                self.dataptrs_read.add(data_ptr)

            self.tensor_aliases.setdefault(data_ptr, [])
            if name is not None:
                self.tensor_aliases[data_ptr].append(name)
            if is_output:
                self.outputs.add(data_ptr)

    def parse_inputs(
        self,
        schema: torch.FunctionSchema,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> None:
        # An argument counts as a write when its schema alias_info marks it
        # as mutated (e.g. `Tensor(a!)`).
        for argument, value in zip_arguments(schema, args, kwargs):
            is_write = argument.alias_info is not None and argument.alias_info.is_write
            pytree.tree_map_(
                functools.partial(
                    self._handle_argument, is_write=is_write, name=argument.name
                ),
                value,
            )

    def parse_outputs(self, outputs: Any) -> None:
        # Outputs are always treated as writes.
        pytree.tree_map_(
            functools.partial(self._handle_argument, is_write=True, is_output=True),
            outputs,
        )
526
+
527
+
528
class CUDASanitizerDispatchMode(TorchDispatchMode):
    # Dispatch mode that observes every operator call, feeds it to the
    # EventHandler, and raises when unsynchronized accesses are detected.

    def __init__(self):
        self.event_handler = EventHandler()
        # Hook the EventHandler into every low-level CUDA trace callback.
        torch._C._activate_cuda_trace()
        cuda_trace.register_callback_for_cuda_event_creation(
            self.event_handler._handle_event_creation
        )
        cuda_trace.register_callback_for_cuda_event_deletion(
            self.event_handler._handle_event_deletion
        )
        cuda_trace.register_callback_for_cuda_event_record(
            self.event_handler._handle_event_record
        )
        cuda_trace.register_callback_for_cuda_event_wait(
            self.event_handler._handle_event_wait
        )
        cuda_trace.register_callback_for_cuda_memory_allocation(
            self.event_handler._handle_memory_allocation
        )
        cuda_trace.register_callback_for_cuda_memory_deallocation(
            self.event_handler._handle_memory_deallocation
        )
        cuda_trace.register_callback_for_cuda_stream_creation(
            self.event_handler._handle_stream_creation
        )
        cuda_trace.register_callback_for_cuda_device_synchronization(
            self.event_handler._handle_device_synchronization
        )
        cuda_trace.register_callback_for_cuda_stream_synchronization(
            self.event_handler._handle_stream_synchronization
        )
        cuda_trace.register_callback_for_cuda_event_synchronization(
            self.event_handler._handle_event_synchronization
        )

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # Collect the CUDA data pointers touched by this call before and
        # after executing it.
        argument_handler = ArgumentHandler()
        argument_handler.parse_inputs(func._schema, args, kwargs)

        outputs = func(*args, **kwargs)

        argument_handler.parse_outputs(outputs)
        errors = self.event_handler._handle_kernel_launch(
            torch.cuda.current_stream().cuda_stream,
            # Read-only = read but never written by this kernel.
            argument_handler.dataptrs_read - argument_handler.dataptrs_written,
            argument_handler.dataptrs_written,
            argument_handler.outputs,
            func._schema,
            argument_handler.tensor_aliases,
        )
        if errors:
            # Print each detailed report, then raise the aggregate error.
            for error in errors:
                print(error, file=sys.stderr)
            raise CUDASanitizerErrors(errors)

        return outputs
587
+
588
+
589
class CUDASanitizer:
    """Manages the lifetime of a CUDASanitizer dispatch mode object.

    The CUDASanitizer class wraps the entering/exiting functions of the dispatch mode
    context manager in the enable function/destructor, respectively. This is to
    explicitly set the lifetime of the dispatch mode object to that of the application.
    This approach was deemed more elegant than using the atexit module.
    """

    def __init__(self):
        self.dispatch = CUDASanitizerDispatchMode()
        self.enabled = False

    def enable(self):
        # Enter the dispatch mode for the rest of the process lifetime.
        self.dispatch.__enter__()
        self.enabled = True

    def __del__(self):
        # Exit the mode only if it was ever entered.
        if self.enabled:
            self.dispatch.__exit__(None, None, None)
609
+
610
+
611
def enable_cuda_sanitizer():
    """Enable CUDA Sanitizer for the rest of the program's execution.

    Once enabled, the sanitizer analyzes low-level CUDA calls invoked by torch
    functions for synchronization errors. Every data race it finds is printed
    to standard error together with stack traces of the suspected causes. For
    best results, enable the sanitizer at the very beginning of the program.
    """
    cuda_sanitizer.enable()
620
+
621
+
622
+ cuda_sanitizer = CUDASanitizer()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-311.pyc ADDED
Binary file (1.46 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/comm.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # The functions here have been moved to torch.nn.parallel.comm
2
+ from torch.nn.parallel.comm import (
3
+ broadcast,
4
+ broadcast_coalesced,
5
+ gather,
6
+ reduce_add,
7
+ reduce_add_coalesced,
8
+ scatter,
9
+ )
10
+
11
+ __all__ = [
12
+ "broadcast",
13
+ "broadcast_coalesced",
14
+ "reduce_add",
15
+ "reduce_add_coalesced",
16
+ "scatter",
17
+ "gather",
18
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/memory.py ADDED
@@ -0,0 +1,914 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""This package adds support for device memory management implemented in CUDA."""
2
+
3
+ import collections
4
+ import contextlib
5
+ import ctypes
6
+ import pickle
7
+ import sys
8
+ import warnings
9
+ from inspect import signature
10
+
11
+ from typing import Any, Dict, Optional, Tuple, Union
12
+
13
+ import torch
14
+ from torch import _C
15
+
16
+ from torch.types import Device
17
+ from .._utils import _dummy_type
18
+ from . import _get_device_index, _get_nvml_device_index, _lazy_init, is_initialized
19
+
20
+ from ._memory_viz import memory as _memory, segments as _segments
21
+
22
+ __all__ = [
23
+ "caching_allocator_alloc",
24
+ "caching_allocator_delete",
25
+ "set_per_process_memory_fraction",
26
+ "empty_cache",
27
+ "memory_stats",
28
+ "memory_stats_as_nested_dict",
29
+ "reset_accumulated_memory_stats",
30
+ "reset_peak_memory_stats",
31
+ "reset_max_memory_allocated",
32
+ "reset_max_memory_cached",
33
+ "memory_allocated",
34
+ "max_memory_allocated",
35
+ "memory_reserved",
36
+ "max_memory_reserved",
37
+ "memory_cached",
38
+ "max_memory_cached",
39
+ "memory_snapshot",
40
+ "memory_summary",
41
+ "list_gpu_processes",
42
+ "mem_get_info",
43
+ "get_allocator_backend",
44
+ "CUDAPluggableAllocator",
45
+ "change_current_allocator",
46
+ ]
47
+
48
+
49
+ if not hasattr(torch._C, "_cuda_CUDAAllocator"):
50
+ # Define dummy base classes
51
+ torch._C.__dict__["_cuda_CUDAAllocator"] = _dummy_type("_cuda_CUDAAllocator")
52
+
53
+
54
+ def _host_allocator():
55
+ _lazy_init()
56
+ return torch._C._cuda_cudaHostAllocator()
57
+
58
+
59
+ @contextlib.contextmanager
60
+ def _free_mutex():
61
+ torch._C._cuda_lock_mutex()
62
+ try:
63
+ yield
64
+ finally:
65
+ torch._C._cuda_unlock_mutex()
66
+
67
+
68
+ def caching_allocator_alloc(size, device: Union[Device, int] = None, stream=None):
69
+ r"""Perform a memory allocation using the CUDA memory allocator.
70
+
71
+ Memory is allocated for a given device and a stream, this
72
+ function is intended to be used for interoperability with other
73
+ frameworks. Allocated memory is released through
74
+ :func:`~torch.cuda.caching_allocator_delete`.
75
+
76
+ Args:
77
+ size (int): number of bytes to be allocated.
78
+ device (torch.device or int, optional): selected device. If it is
79
+ ``None`` the default CUDA device is used.
80
+ stream (torch.cuda.Stream or int, optional): selected stream. If is ``None`` then
81
+ the default stream for the selected device is used.
82
+
83
+ .. note::
84
+ See :ref:`cuda-memory-management` for more details about GPU memory
85
+ management.
86
+ """
87
+ if device is None:
88
+ device = torch.cuda.current_device()
89
+ device = _get_device_index(device)
90
+ if stream is None:
91
+ stream = torch.cuda.current_stream(device)
92
+ if isinstance(stream, torch.cuda.streams.Stream):
93
+ stream = stream.cuda_stream
94
+ if not isinstance(stream, int):
95
+ raise TypeError(
96
+ "Invalid type for stream argument, must be "
97
+ "`torch.cuda.Stream` or `int` representing a pointer "
98
+ "to a existing stream"
99
+ )
100
+ with torch.cuda.device(device):
101
+ return torch._C._cuda_cudaCachingAllocator_raw_alloc(size, stream)
102
+
103
+
104
+ def caching_allocator_delete(mem_ptr):
105
+ r"""Delete memory allocated using the CUDA memory allocator.
106
+
107
+ Memory allocated with :func:`~torch.cuda.caching_allocator_alloc`.
108
+ is freed here. The associated device and stream are tracked inside
109
+ the allocator.
110
+
111
+ Args:
112
+ mem_ptr (int): memory address to be freed by the allocator.
113
+
114
+ .. note::
115
+ See :ref:`cuda-memory-management` for more details about GPU memory
116
+ management.
117
+ """
118
+ torch._C._cuda_cudaCachingAllocator_raw_delete(mem_ptr)
119
+
120
+
121
+ def set_per_process_memory_fraction(
122
+ fraction, device: Union[Device, int] = None
123
+ ) -> None:
124
+ r"""Set memory fraction for a process.
125
+
126
+ The fraction is used to limit an caching allocator to allocated memory on a CUDA device.
127
+ The allowed value equals the total visible memory multiplied fraction.
128
+ If trying to allocate more than the allowed value in a process, will raise an out of
129
+ memory error in allocator.
130
+
131
+ Args:
132
+ fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction.
133
+ device (torch.device or int, optional): selected device. If it is
134
+ ``None`` the default CUDA device is used.
135
+ .. note::
136
+ In general, the total available free memory is less than the total capacity.
137
+ """
138
+ _lazy_init()
139
+ if device is None:
140
+ device = torch.cuda.current_device()
141
+ device = _get_device_index(device)
142
+ if not isinstance(fraction, float):
143
+ raise TypeError("Invalid type for fraction argument, must be `float`")
144
+ if fraction < 0 or fraction > 1:
145
+ raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~1")
146
+
147
+ torch._C._cuda_setMemoryFraction(fraction, device)
148
+
149
+
150
+ def empty_cache() -> None:
151
+ r"""Release all unoccupied cached memory currently held by the caching
152
+ allocator so that those can be used in other GPU application and visible in
153
+ `nvidia-smi`.
154
+
155
+ .. note::
156
+ :func:`~torch.cuda.empty_cache` doesn't increase the amount of GPU
157
+ memory available for PyTorch. However, it may help reduce fragmentation
158
+ of GPU memory in certain cases. See :ref:`cuda-memory-management` for
159
+ more details about GPU memory management.
160
+ """
161
+ if is_initialized():
162
+ torch._C._cuda_emptyCache()
163
+
164
+
165
+ def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
166
+ r"""Return a dictionary of CUDA memory allocator statistics for a given device.
167
+
168
+ The return value of this function is a dictionary of statistics, each of
169
+ which is a non-negative integer.
170
+
171
+ Core statistics:
172
+
173
+ - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
174
+ number of allocation requests received by the memory allocator.
175
+ - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
176
+ amount of allocated memory.
177
+ - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
178
+ number of reserved segments from ``cudaMalloc()``.
179
+ - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
180
+ amount of reserved memory.
181
+ - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
182
+ number of active memory blocks.
183
+ - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
184
+ amount of active memory.
185
+ - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
186
+ number of inactive, non-releasable memory blocks.
187
+ - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
188
+ amount of inactive, non-releasable memory.
189
+
190
+ For these core statistics, values are broken down as follows.
191
+
192
+ Pool type:
193
+
194
+ - ``all``: combined statistics across all memory pools.
195
+ - ``large_pool``: statistics for the large allocation pool
196
+ (as of October 2019, for size >= 1MB allocations).
197
+ - ``small_pool``: statistics for the small allocation pool
198
+ (as of October 2019, for size < 1MB allocations).
199
+
200
+ Metric type:
201
+
202
+ - ``current``: current value of this metric.
203
+ - ``peak``: maximum value of this metric.
204
+ - ``allocated``: historical total increase in this metric.
205
+ - ``freed``: historical total decrease in this metric.
206
+
207
+ In addition to the core statistics, we also provide some simple event
208
+ counters:
209
+
210
+ - ``"num_alloc_retries"``: number of failed ``cudaMalloc`` calls that
211
+ result in a cache flush and retry.
212
+ - ``"num_ooms"``: number of out-of-memory errors thrown.
213
+
214
+ The caching allocator can be configured via ENV to not split blocks larger than a
215
+ defined size (see Memory Management section of the Cuda Semantics documentation).
216
+ This helps avoid memory fragmentation but may have a performance
217
+ penalty. Additional outputs to assist with tuning and evaluating impact:
218
+
219
+ - ``"max_split_size"``: blocks above this size will not be split.
220
+ - ``"oversize_allocations.{current,peak,allocated,freed}"``:
221
+ number of over-size allocation requests received by the memory allocator.
222
+ - ``"oversize_segments.{current,peak,allocated,freed}"``:
223
+ number of over-size reserved segments from ``cudaMalloc()``.
224
+
225
+ The caching allocator can be configured via ENV to round memory allocations in order
226
+ to reduce fragmentation. Sometimes the overhead from rounding can be higher than
227
+ the fragmentation it helps reduce. The following stat can be used to check if
228
+ rounding adds too much overhead:
229
+
230
+ - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
231
+ memory requested by client code, compare this with allocated_bytes to check if
232
+ allocation rounding adds too much overhead.
233
+
234
+ Args:
235
+ device (torch.device or int, optional): selected device. Returns
236
+ statistics for the current device, given by :func:`~torch.cuda.current_device`,
237
+ if :attr:`device` is ``None`` (default).
238
+
239
+ .. note::
240
+ See :ref:`cuda-memory-management` for more details about GPU memory
241
+ management.
242
+
243
+ .. note::
244
+ With :ref:`backend:cudaMallocAsync<cuda-memory-envvars>`, some stats are not
245
+ meaningful, and are always reported as zero.
246
+ """
247
+ result = []
248
+
249
+ def _recurse_add_to_result(prefix, obj):
250
+ if isinstance(obj, dict):
251
+ if len(prefix) > 0:
252
+ prefix += "."
253
+ for k, v in obj.items():
254
+ _recurse_add_to_result(prefix + k, v)
255
+ else:
256
+ result.append((prefix, obj))
257
+
258
+ stats = memory_stats_as_nested_dict(device=device)
259
+ _recurse_add_to_result("", stats)
260
+ result.sort()
261
+
262
+ return collections.OrderedDict(result)
263
+
264
+
265
+ def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]:
266
+ r"""Return the result of :func:`~torch.cuda.memory_stats` as a nested dictionary."""
267
+ if not is_initialized():
268
+ return {}
269
+ device = _get_device_index(device, optional=True)
270
+ return torch._C._cuda_memoryStats(device)
271
+
272
+
273
+ def reset_accumulated_memory_stats(device: Union[Device, int] = None) -> None:
274
+ r"""Reset the "accumulated" (historical) stats tracked by the CUDA memory allocator.
275
+
276
+ See :func:`~torch.cuda.memory_stats` for details. Accumulated stats correspond to
277
+ the `"allocated"` and `"freed"` keys in each individual stat dict, as well as
278
+ `"num_alloc_retries"` and `"num_ooms"`.
279
+
280
+ Args:
281
+ device (torch.device or int, optional): selected device. Returns
282
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
283
+ if :attr:`device` is ``None`` (default).
284
+
285
+ .. note::
286
+ See :ref:`cuda-memory-management` for more details about GPU memory
287
+ management.
288
+ """
289
+ device = _get_device_index(device, optional=True)
290
+ return torch._C._cuda_resetAccumulatedMemoryStats(device)
291
+
292
+
293
+ def reset_peak_memory_stats(device: Union[Device, int] = None) -> None:
294
+ r"""Reset the "peak" stats tracked by the CUDA memory allocator.
295
+
296
+ See :func:`~torch.cuda.memory_stats` for details. Peak stats correspond to the
297
+ `"peak"` key in each individual stat dict.
298
+
299
+ Args:
300
+ device (torch.device or int, optional): selected device. Returns
301
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
302
+ if :attr:`device` is ``None`` (default).
303
+
304
+ .. note::
305
+ See :ref:`cuda-memory-management` for more details about GPU memory
306
+ management.
307
+ """
308
+ device = _get_device_index(device, optional=True)
309
+ return torch._C._cuda_resetPeakMemoryStats(device)
310
+
311
+
312
+ def reset_max_memory_allocated(device: Union[Device, int] = None) -> None:
313
+ r"""Reset the starting point in tracking maximum GPU memory occupied by tensors for a given device.
314
+
315
+ See :func:`~torch.cuda.max_memory_allocated` for details.
316
+
317
+ Args:
318
+ device (torch.device or int, optional): selected device. Returns
319
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
320
+ if :attr:`device` is ``None`` (default).
321
+
322
+ .. warning::
323
+ This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
324
+ /all/ peak memory stats.
325
+
326
+ .. note::
327
+ See :ref:`cuda-memory-management` for more details about GPU memory
328
+ management.
329
+ """
330
+ warnings.warn(
331
+ "torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, "
332
+ "which resets /all/ peak memory stats.",
333
+ FutureWarning,
334
+ )
335
+ return reset_peak_memory_stats(device=device)
336
+
337
+
338
+ def reset_max_memory_cached(device: Union[Device, int] = None) -> None:
339
+ r"""Reset the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
340
+
341
+ See :func:`~torch.cuda.max_memory_cached` for details.
342
+
343
+ Args:
344
+ device (torch.device or int, optional): selected device. Returns
345
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
346
+ if :attr:`device` is ``None`` (default).
347
+
348
+ .. warning::
349
+ This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
350
+ /all/ peak memory stats.
351
+
352
+ .. note::
353
+ See :ref:`cuda-memory-management` for more details about GPU memory
354
+ management.
355
+ """
356
+ warnings.warn(
357
+ "torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, "
358
+ "which resets /all/ peak memory stats.",
359
+ FutureWarning,
360
+ )
361
+ return reset_peak_memory_stats(device=device)
362
+
363
+
364
+ def memory_allocated(device: Union[Device, int] = None) -> int:
365
+ r"""Return the current GPU memory occupied by tensors in bytes for a given device.
366
+
367
+ Args:
368
+ device (torch.device or int, optional): selected device. Returns
369
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
370
+ if :attr:`device` is ``None`` (default).
371
+
372
+ .. note::
373
+ This is likely less than the amount shown in `nvidia-smi` since some
374
+ unused memory can be held by the caching allocator and some context
375
+ needs to be created on GPU. See :ref:`cuda-memory-management` for more
376
+ details about GPU memory management.
377
+ """
378
+ return memory_stats(device=device).get("allocated_bytes.all.current", 0)
379
+
380
+
381
+ def max_memory_allocated(device: Union[Device, int] = None) -> int:
382
+ r"""Return the maximum GPU memory occupied by tensors in bytes for a given device.
383
+
384
+ By default, this returns the peak allocated memory since the beginning of
385
+ this program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to
386
+ reset the starting point in tracking this metric. For example, these two
387
+ functions can measure the peak allocated memory usage of each iteration in a
388
+ training loop.
389
+
390
+ Args:
391
+ device (torch.device or int, optional): selected device. Returns
392
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
393
+ if :attr:`device` is ``None`` (default).
394
+
395
+ .. note::
396
+ See :ref:`cuda-memory-management` for more details about GPU memory
397
+ management.
398
+ """
399
+ return memory_stats(device=device).get("allocated_bytes.all.peak", 0)
400
+
401
+
402
+ def memory_reserved(device: Union[Device, int] = None) -> int:
403
+ r"""Return the current GPU memory managed by the caching allocator in bytes for a given device.
404
+
405
+ Args:
406
+ device (torch.device or int, optional): selected device. Returns
407
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
408
+ if :attr:`device` is ``None`` (default).
409
+
410
+ .. note::
411
+ See :ref:`cuda-memory-management` for more details about GPU memory
412
+ management.
413
+ """
414
+ return memory_stats(device=device).get("reserved_bytes.all.current", 0)
415
+
416
+
417
+ def max_memory_reserved(device: Union[Device, int] = None) -> int:
418
+ r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device.
419
+
420
+ By default, this returns the peak cached memory since the beginning of this
421
+ program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to reset
422
+ the starting point in tracking this metric. For example, these two functions
423
+ can measure the peak cached memory amount of each iteration in a training
424
+ loop.
425
+
426
+ Args:
427
+ device (torch.device or int, optional): selected device. Returns
428
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
429
+ if :attr:`device` is ``None`` (default).
430
+
431
+ .. note::
432
+ See :ref:`cuda-memory-management` for more details about GPU memory
433
+ management.
434
+ """
435
+ return memory_stats(device=device).get("reserved_bytes.all.peak", 0)
436
+
437
+
438
+ def memory_cached(device: Union[Device, int] = None) -> int:
439
+ r"""Deprecated; see :func:`~torch.cuda.memory_reserved`."""
440
+ warnings.warn(
441
+ "torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved",
442
+ FutureWarning,
443
+ )
444
+ return memory_reserved(device=device)
445
+
446
+
447
+ def max_memory_cached(device: Union[Device, int] = None) -> int:
448
+ r"""Deprecated; see :func:`~torch.cuda.max_memory_reserved`."""
449
+ warnings.warn(
450
+ "torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved",
451
+ FutureWarning,
452
+ )
453
+ return max_memory_reserved(device=device)
454
+
455
+
456
+ def memory_snapshot():
457
+ r"""Return a snapshot of the CUDA memory allocator state across all devices.
458
+
459
+ Interpreting the output of this function requires familiarity with the
460
+ memory allocator internals.
461
+
462
+ .. note::
463
+ See :ref:`cuda-memory-management` for more details about GPU memory
464
+ management.
465
+ """
466
+ return torch._C._cuda_memorySnapshot()["segments"]
467
+
468
+
469
+ def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) -> str:
470
+ r"""Return a human-readable printout of the current memory allocator statistics for a given device.
471
+
472
+ This can be useful to display periodically during training, or when
473
+ handling out-of-memory exceptions.
474
+
475
+ Args:
476
+ device (torch.device or int, optional): selected device. Returns
477
+ printout for the current device, given by :func:`~torch.cuda.current_device`,
478
+ if :attr:`device` is ``None`` (default).
479
+ abbreviated (bool, optional): whether to return an abbreviated summary
480
+ (default: False).
481
+
482
+ .. note::
483
+ See :ref:`cuda-memory-management` for more details about GPU memory
484
+ management.
485
+ """
486
+ device = _get_device_index(device, optional=True)
487
+ stats = memory_stats(device=device)
488
+
489
+ def _format_size(sz, pref_sz):
490
+ prefixes = ["B ", "KiB", "MiB", "GiB", "TiB", "PiB"]
491
+ prefix = prefixes[0]
492
+ for new_prefix in prefixes[1:]:
493
+ if pref_sz < 768 * 1024:
494
+ break
495
+ prefix = new_prefix
496
+ sz //= 1024
497
+ pref_sz /= 1024
498
+ return f"{sz:6d} {prefix}"
499
+
500
+ def _format_count(cnt, pref_cnt):
501
+ prefixes = [" ", "K", "M"]
502
+ prefix = prefixes[0]
503
+ for new_prefix in prefixes[1:]:
504
+ if pref_cnt < 750 * 1000:
505
+ break
506
+ prefix = new_prefix
507
+ cnt //= 1000
508
+ pref_cnt /= 1000
509
+ return f"{cnt:7d} {prefix} "
510
+
511
+ metrics_to_display = [
512
+ ("allocated_bytes", "Allocated memory", _format_size),
513
+ ("active_bytes", "Active memory", _format_size),
514
+ ("requested_bytes", "Requested memory", _format_size),
515
+ ("reserved_bytes", "GPU reserved memory", _format_size),
516
+ ("inactive_split_bytes", "Non-releasable memory", _format_size),
517
+ ("allocation", "Allocations", _format_count),
518
+ ("active", "Active allocs", _format_count),
519
+ ("segment", "GPU reserved segments", _format_count),
520
+ ("inactive_split", "Non-releasable allocs", _format_count),
521
+ ]
522
+
523
+ lines = []
524
+ lines.append("=" * 75)
525
+ lines.append(" {_:16} PyTorch CUDA memory summary, device ID {device:<17d} ")
526
+ lines.append("-" * 75)
527
+ lines.append(
528
+ " {_:9} CUDA OOMs: {num_ooms:<12d} | {_:6} cudaMalloc retries: {num_alloc_retries:<8d} "
529
+ )
530
+ lines.append("=" * 75)
531
+ lines.append(
532
+ " Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed "
533
+ )
534
+
535
+ for metric_key, metric_name, formatter in metrics_to_display:
536
+ lines.append("-" * 75)
537
+ submetrics = [("all", metric_name)]
538
+ if not abbreviated:
539
+ submetrics.append(("large_pool", " from large pool"))
540
+ submetrics.append(("small_pool", " from small pool"))
541
+
542
+ current_prefval, peak_prefval, allocated_prefval, freed_prefval = (
543
+ None,
544
+ None,
545
+ None,
546
+ None,
547
+ )
548
+
549
+ for submetric_key, submetric_name in submetrics:
550
+ prefix = metric_key + "." + submetric_key + "."
551
+
552
+ current = stats[prefix + "current"]
553
+ peak = stats[prefix + "peak"]
554
+ allocated = stats[prefix + "allocated"]
555
+ freed = stats[prefix + "freed"]
556
+
557
+ if current_prefval is None:
558
+ current_prefval = current
559
+ peak_prefval = peak
560
+ allocated_prefval = allocated
561
+ freed_prefval = freed
562
+
563
+ lines.append(
564
+ " {:<21} | {} | {} | {} | {} ".format(
565
+ submetric_name,
566
+ formatter(current, current_prefval),
567
+ formatter(peak, peak_prefval),
568
+ formatter(allocated, allocated_prefval),
569
+ formatter(freed, freed_prefval),
570
+ ),
571
+ )
572
+
573
+ metrics_to_display = [
574
+ ("oversize_allocations", "Oversize allocations", _format_count),
575
+ ("oversize_segments", "Oversize GPU segments", _format_count),
576
+ ]
577
+
578
+ for metric_key, metric_name, formatter in metrics_to_display:
579
+ lines.append("-" * 75)
580
+
581
+ prefix = metric_key + "."
582
+
583
+ current = stats[prefix + "current"]
584
+ peak = stats[prefix + "peak"]
585
+ allocated = stats[prefix + "allocated"]
586
+ freed = stats[prefix + "freed"]
587
+
588
+ lines.append(
589
+ " {:<21} | {} | {} | {} | {} ".format(
590
+ metric_name,
591
+ formatter(current, current),
592
+ formatter(peak, peak),
593
+ formatter(allocated, allocated),
594
+ formatter(freed, freed),
595
+ ),
596
+ )
597
+
598
+ lines.append("=" * 75)
599
+
600
+ fmt_dict = {"_": "", "device": device}
601
+ for k, v in stats.items():
602
+ fmt_dict[k.replace(".", "-")] = v
603
+ return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n"
604
+
605
+
606
+ def list_gpu_processes(device: Union[Device, int] = None) -> str:
607
+ r"""Return a human-readable printout of the running processes and their GPU memory use for a given device.
608
+
609
+ This can be useful to display periodically during training, or when
610
+ handling out-of-memory exceptions.
611
+
612
+ Args:
613
+ device (torch.device or int, optional): selected device. Returns
614
+ printout for the current device, given by :func:`~torch.cuda.current_device`,
615
+ if :attr:`device` is ``None`` (default).
616
+ """
617
+ try:
618
+ import pynvml # type: ignore[import]
619
+ except ModuleNotFoundError:
620
+ return "pynvml module not found, please install pynvml"
621
+ from pynvml import NVMLError_DriverNotLoaded
622
+
623
+ try:
624
+ pynvml.nvmlInit()
625
+ except NVMLError_DriverNotLoaded:
626
+ return "cuda driver can't be loaded, is cuda enabled?"
627
+ device = _get_nvml_device_index(device)
628
+ handle = pynvml.nvmlDeviceGetHandleByIndex(device)
629
+ procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
630
+ lines = []
631
+ lines.append(f"GPU:{device}")
632
+ if len(procs) == 0:
633
+ lines.append("no processes are running")
634
+ for p in procs:
635
+ mem = p.usedGpuMemory / (1024 * 1024)
636
+ lines.append(f"process {p.pid:>10d} uses {mem:>12.3f} MB GPU memory")
637
+ return "\n".join(lines)
638
+
639
+
640
+ def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
641
+ r"""Return the global free and total GPU memory for a given device using cudaMemGetInfo.
642
+
643
+ Args:
644
+ device (torch.device or int, optional): selected device. Returns
645
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
646
+ if :attr:`device` is ``None`` (default).
647
+
648
+ .. note::
649
+ See :ref:`cuda-memory-management` for more
650
+ details about GPU memory management.
651
+ """
652
+ if device is None:
653
+ device = torch.cuda.current_device()
654
+ device = _get_device_index(device)
655
+ return torch.cuda.cudart().cudaMemGetInfo(device)
656
+
657
+
658
+ def _record_memory_history_legacy(
659
+ enabled: bool,
660
+ record_context=True,
661
+ trace_alloc_max_entries=1,
662
+ trace_alloc_record_context=False,
663
+ device: Union[Device, int] = None,
664
+ record_context_cpp=False,
665
+ ):
666
+ _C._cuda_record_memory_history_legacy(
667
+ enabled,
668
+ record_context,
669
+ trace_alloc_max_entries,
670
+ trace_alloc_record_context,
671
+ record_context_cpp,
672
+ )
673
+
674
+
675
+ def _record_memory_history(enabled="all", *args, **kwargs):
676
+ """Enable recording of stack traces associated with memory
677
+ allocations, so you can tell what allocated any piece of memory in
678
+ :func:`torch.cuda.memory._snapshot()`.
679
+
680
+ In addition too keeping stack traces with each current allocation and free,
681
+ this will also enable recording of a history of all alloc/free events.
682
+
683
+ Use :func:`torch.cuda.memory._snapshot()` to retrieve this information,
684
+ and the tools in `_memory_viz.py` to visualize snapshots.
685
+
686
+ The Python trace collection is fast (2us per trace), so you may consider
687
+ enabling this on production jobs if you anticipate ever having to debug
688
+ memory issues.
689
+
690
+ C++ trace collection is also fast (~50ns/frame), which for many typical programs
691
+ works out to ~2us per trace, but can vary depending on stack depth.
692
+
693
+ Args:
694
+ enabled (Literal[None, "state", "all"], optional):
695
+ `None`, disable recording memory history.
696
+ `"state"`, keep information for currenly allocated memory.
697
+ `"all"`, additionally keep a history of all alloc/free calls.
698
+ Defaults to "all".
699
+ context (Literal[None, "state", "alloc", "all"], optional):
700
+ `None`, Do not record any tracebacks.
701
+ `"state"`, Record tracebacks for currently allocated memory.
702
+ `"alloc"`, additionally keep tracebacks for alloc calls.
703
+ `"all"`, additionally keep tracebacks for free calls.
704
+ Defaults to "all".
705
+ stacks (Literal["python", "all"], optional):
706
+ `"python"`, include Python, TorchScript, and inductor frames in tracebacks
707
+ `"all"`, additionally include C++ frames
708
+ Defaults to "all".
709
+ max_entries (int, optional): Keep a maximum of `max_entries`
710
+ alloc/free events in the recorded history recorded.
711
+ """
712
+ if isinstance(enabled, bool):
713
+ return _record_memory_history_legacy(enabled, *args, **kwargs)
714
+ else:
715
+ return _record_memory_history_impl(enabled, *args, **kwargs)
716
+
717
+
718
+ def _record_memory_history_impl(
719
+ enabled: Optional[str] = "all",
720
+ context: Optional[str] = "all",
721
+ stacks: str = "all",
722
+ max_entries: int = sys.maxsize,
723
+ device: Union[Device, int] = None,
724
+ ):
725
+ _C._cuda_record_memory_history(enabled, context, stacks, max_entries)
726
+
727
+
728
+ _record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined]
729
+
730
+
731
+ def _snapshot(device: Union[Device, int] = None):
732
+ """Save a snapshot of CUDA memory state at the time it was called.
733
+
734
+ The state is represented as a dictionary with the following structure.
735
+
736
+ .. code-block:: python
737
+
738
+ class Snapshot(TypedDict):
739
+ segments : List[Segment]
740
+ device_traces: List[List[TraceEntry]]
741
+
742
+ class Segment(TypedDict):
743
+ # Segments are memory returned from a cudaMalloc call.
744
+ # The size of reserved memory is the sum of all Segments.
745
+ # Segments are cached and reused for future allocations.
746
+ # If the reuse is smaller than the segment, the segment
747
+ # is split into more then one Block.
748
+ # empty_cache() frees Segments that are entirely inactive.
749
+ address: int
750
+ total_size: int # cudaMalloc'd size of segment
751
+ stream: int
752
+ segment_type: Literal['small', 'large'] # 'large' (>1MB)
753
+ allocated_size: int # size of memory in use
754
+ active_size: int # size of memory in use or in active_awaiting_free state
755
+ blocks : List[Block]
756
+
757
+ class Block(TypedDict):
758
+ # A piece of memory returned from the allocator, or
759
+ # current cached but inactive.
760
+ size: int
761
+ requested_size: int # size requested during malloc, may be smaller than
762
+ # size due to rounding
763
+ address: int
764
+ state: Literal['active_allocated', # used by a tensor
765
+ 'active_awaiting_free', # waiting for another stream to finish using
766
+ # this, then it will become free
767
+ 'inactive',] # free for reuse
768
+ frames: List[Frame] # stack trace from where the allocation occurred
769
+
770
+ class Frame(TypedDict):
771
+ filename: str
772
+ line: int
773
+ name: str
774
+
775
+ class TraceEntry(TypedDict):
776
+ # When `torch.cuda.memory._record_memory_history()` is enabled,
777
+ # the snapshot will contain TraceEntry objects that record each
778
+ # action the allocator took.
779
+ action: Literal[
780
+ 'alloc' # memory allocated
781
+ 'free_requested', # the allocated received a call to free memory
782
+ 'free_completed', # the memory that was requested to be freed is now
783
+ # able to be used in future allocation calls
784
+ 'segment_alloc', # the caching allocator ask cudaMalloc for more memory
785
+ # and added it as a segment in its cache
786
+ 'segment_free', # the caching allocator called cudaFree to return memory
787
+ # to cuda possibly trying free up memory to
788
+ # allocate more segments or because empty_caches was called
789
+ 'oom', # the allocator threw an OOM exception. 'size' is
790
+ # the requested number of bytes that did not succeed
791
+ 'snapshot' # the allocator generated a memory snapshot
792
+ # useful to coorelate a previously taken
793
+ # snapshot with this trace
794
+ ]
795
+ addr: int # not present for OOM
796
+ frames: List[Frame]
797
+ size: int
798
+ stream: int
799
+ device_free: int # only present for OOM, the amount of
800
+ # memory cuda still reports to be free
801
+
802
+ Returns:
803
+ The Snapshot dictionary object
804
+ """
805
+ return _C._cuda_memorySnapshot()
806
+
807
+
808
+ def _dump_snapshot(filename="dump_snapshot.pickle"):
809
+ """
810
+ Save a pickled version of the `torch.memory._snapshot()` dictionary to a file.
811
+
812
+ This file can be opened by the interactive snapshot viewer at pytorch.org/memory_viz
813
+
814
+ Args:
815
+ filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle".
816
+ """
817
+ s = _snapshot()
818
+ with open(filename, "wb") as f:
819
+ pickle.dump(s, f)
820
+
821
+
822
+ def _save_segment_usage(filename="output.svg", snapshot=None):
823
+ if snapshot is None:
824
+ snapshot = _snapshot()
825
+ with open(filename, "w") as f:
826
+ f.write(_segments(snapshot))
827
+
828
+
829
+ def _save_memory_usage(filename="output.svg", snapshot=None):
830
+ if snapshot is None:
831
+ snapshot = _snapshot()
832
+ with open(filename, "w") as f:
833
+ f.write(_memory(snapshot))
834
+
835
+
836
+ def _set_allocator_settings(env: str):
837
+ return torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env)
838
+
839
+
840
+ def get_allocator_backend() -> str:
841
+ r"""Return a string describing the active allocator backend as set by
842
+ ``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are
843
+ ``native`` (PyTorch's native caching allocator) and `cudaMallocAsync``
844
+ (CUDA's built-in asynchronous allocator).
845
+
846
+ .. note::
847
+ See :ref:`cuda-memory-management` for details on choosing the allocator backend.
848
+ """
849
+ return torch._C._cuda_getAllocatorBackend()
850
+
851
+
852
+ class _CUDAAllocator:
853
+ r"""Wrapper over internal CUDA memory allocators."""
854
+
855
+ def __init__(self, allocator: torch._C._cuda_CUDAAllocator):
856
+ self._allocator = allocator
857
+
858
+ def allocator(self):
859
+ return self._allocator
860
+
861
+
862
+ class CUDAPluggableAllocator(_CUDAAllocator):
863
+ r"""CUDA memory allocator loaded from a so file."""
864
+
865
+ def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str):
866
+ r"""Memory allocators are compiled in .so files and loaded dynamically using ctypes.
867
+
868
+ To change the active allocator use the :func:`torch.memory.cuda.change_current_allocator` function.
869
+
870
+ Args:
871
+ path_to_so_file(str): Path in the filesystem to the `.so` file containing
872
+ the allocator functions
873
+ alloc_fn_name(str): Name of the function to perform the memory allocation
874
+ in the so file. The signature must be:
875
+ void* alloc_fn_name(ssize_t size, int device, cudaStream_t stream);
876
+ free_fn_name(str): Name of the function to perform the memory release
877
+ in the so file. The signature must be:
878
+ void free_fn_name(void* ptr, size_t size, cudaStream_t stream);
879
+
880
+ .. warning::
881
+ This is currently supported only in unix OSs
882
+
883
+ .. note::
884
+ See :ref:`cuda-memory-management` for details on creating and using a custom allocator
885
+ """
886
+ allocator = ctypes.CDLL(path_to_so_file)
887
+ alloc_fn = ctypes.cast(getattr(allocator, alloc_fn_name), ctypes.c_void_p).value
888
+ free_fn = ctypes.cast(getattr(allocator, free_fn_name), ctypes.c_void_p).value
889
+ assert alloc_fn is not None
890
+ assert free_fn is not None
891
+ self._allocator = torch._C._cuda_customAllocator(alloc_fn, free_fn)
892
+
893
+
894
+ def change_current_allocator(allocator: _CUDAAllocator) -> None:
895
+ r"""Change the currently used memory allocator to be the one provided.
896
+
897
+ If the current allocator has already been used/initialized, this function will error.
898
+
899
+
900
+ Args:
901
+ allocator (torch.cuda.memory._CUDAAllocator): allocator to be set as the active one.
902
+ .. note::
903
+ See :ref:`cuda-memory-management` for details on creating and using a custom allocator
904
+ """
905
+ torch._C._cuda_changeCurrentAllocator(allocator.allocator())
906
+
907
+
908
+ def _get_current_allocator() -> _CUDAAllocator:
909
+ r"""Return the allocator being currently used.
910
+
911
+ .. note::
912
+ See :ref:`cuda-memory-management` for details on creating and using a custom allocator
913
+ """
914
+ return _CUDAAllocator(torch._C._cuda_getAllocator())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_compatibility.cpython-311.pyc ADDED
Binary file (1.82 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-311.pyc ADDED
Binary file (54.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/annotate.cpython-311.pyc ADDED
Binary file (1.16 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/config.cpython-311.pyc ADDED
Binary file (262 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/node.cpython-311.pyc ADDED
Binary file (40.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/traceback.cpython-311.pyc ADDED
Binary file (4.44 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-311.pyc ADDED
Binary file (4.95 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-311.pyc ADDED
Binary file (34.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-311.pyc ADDED
Binary file (4.52 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-311.pyc ADDED
Binary file (12.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-311.pyc ADDED
Binary file (42 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-311.pyc ADDED
Binary file (12.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (225 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/cudagraphs.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
3
+ from torch.fx.passes.operator_support import OperatorSupport
4
+ from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
5
+ from torch.fx.passes.fake_tensor_prop import FakeTensorProp
6
+ from torch.utils import _pytree as pytree
7
+
8
+ import operator
9
+
10
+ class CudaGraphsSupport(OperatorSupport):
11
+ # TODO: why is submodules passed here
12
+ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
13
+ if node.op not in CALLABLE_NODE_OPS:
14
+ return False
15
+
16
+ if node.target in [torch.ops.aten.embedding_dense_backward.default]:
17
+ return False
18
+
19
+ if node.target in [operator.getitem]:
20
+ return True
21
+
22
+ found_not_cuda = False
23
+
24
+ def meta_fk(meta):
25
+ return meta["val"] if "val" in meta else meta["fake_result"]
26
+
27
+ def find_not_cuda(t):
28
+ nonlocal found_not_cuda
29
+ if isinstance(t, torch.Tensor) and t.device.type != 'cuda':
30
+ found_not_cuda = True
31
+
32
+ for n in node.all_input_nodes:
33
+ pytree.tree_map_(find_not_cuda, meta_fk(n.meta))
34
+
35
+ pytree.tree_map_(find_not_cuda, meta_fk(node.meta))
36
+
37
+ # NB: factory function is accounted for because the result would be
38
+ # cpu or cuda
39
+
40
+ return not found_not_cuda
41
+
42
+ def partition_cudagraphs(gm, inputs):
43
+ """
44
+ Partition an FX graph into sub-GraphModules that can be validly run under
45
+ CUDA graphs. For a subgraph to be runnable under CUDA, all of the operations
46
+ must involve CUDA tensors only/
47
+ """
48
+
49
+ FakeTensorProp(gm).propagate(*inputs)
50
+ supported_ops = CudaGraphsSupport()
51
+ # TODO: single node partition may be wrong due to the pessimization
52
+ # from copying in and out the data. Check in benchmarks, perhaps
53
+ partitioner = CapabilityBasedPartitioner(gm, supported_ops, allows_single_node_partition=True)
54
+ partitions = partitioner.propose_partitions()
55
+ fused_graph = partitioner.fuse_partitions(partitions)
56
+ return fused_graph
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-311.pyc ADDED
Binary file (3.96 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (222 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-311.pyc ADDED
Binary file (5.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/test_pass_manager.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+
3
+ from ..pass_manager import (
4
+ inplace_wrapper,
5
+ PassManager,
6
+ these_before_those_pass_constraint,
7
+ this_before_that_pass_constraint,
8
+ )
9
+
10
+
11
+ class TestPassManager(unittest.TestCase):
12
+ def test_pass_manager_builder(self) -> None:
13
+ passes = [lambda x: 2 * x for _ in range(10)]
14
+ pm = PassManager(passes)
15
+ pm.validate()
16
+
17
+ def test_this_before_that_pass_constraint(self) -> None:
18
+ passes = [lambda x: 2 * x for _ in range(10)]
19
+ pm = PassManager(passes)
20
+
21
+ # add unfulfillable constraint
22
+ pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0]))
23
+
24
+ self.assertRaises(RuntimeError, pm.validate)
25
+
26
+ def test_these_before_those_pass_constraint(self) -> None:
27
+ passes = [lambda x: 2 * x for _ in range(10)]
28
+ constraint = these_before_those_pass_constraint(passes[-1], passes[0])
29
+ pm = PassManager(
30
+ [inplace_wrapper(p) for p in passes]
31
+ )
32
+
33
+ # add unfulfillable constraint
34
+ pm.add_constraint(constraint)
35
+
36
+ self.assertRaises(RuntimeError, pm.validate)
37
+
38
+ def test_two_pass_managers(self) -> None:
39
+ """Make sure we can construct the PassManager twice and not share any
40
+ state between them"""
41
+
42
+ passes = [lambda x: 2 * x for _ in range(3)]
43
+ constraint = these_before_those_pass_constraint(passes[0], passes[1])
44
+ pm1 = PassManager()
45
+ for p in passes:
46
+ pm1.add_pass(p)
47
+ pm1.add_constraint(constraint)
48
+ output1 = pm1(1)
49
+ self.assertEqual(output1, 2 ** 3)
50
+
51
+ passes = [lambda x: 3 * x for _ in range(3)]
52
+ constraint = these_before_those_pass_constraint(passes[0], passes[1])
53
+ pm2 = PassManager()
54
+ for p in passes:
55
+ pm2.add_pass(p)
56
+ pm2.add_constraint(constraint)
57
+ output2 = pm2(1)
58
+ self.assertEqual(output2, 3 ** 3)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .common import lift_subgraph_as_module, HolderModule, compare_graphs