Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cuda.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py +258 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/quantization.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py +218 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py +233 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__init__.py +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py +82 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py +1412 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/memory.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/_sanitizer.py +622 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/comm.py +18 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/memory.py +914 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_compatibility.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/annotate.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/config.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/node.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/traceback.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/cudagraphs.py +56 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/test_pass_manager.py +58 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__init__.py +1 -0

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cuda.cpython-311.pyc
ADDED
Binary file (18.1 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (229 Bytes).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-311.pyc
ADDED
Binary file (1.52 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py
ADDED
@@ -0,0 +1,258 @@
import functools
import logging
import os
import sys
from dataclasses import dataclass
from typing import Any, List, Optional

import sympy

import torch

from ...codecache import cache_dir
from ...config import cuda as inductor_cuda_config
from ...ir import Layout
from .cuda_env import get_cuda_arch, get_cuda_version

log = logging.getLogger(__name__)


def _rename_cutlass_import(content: str, cutlass_modules: List[str]) -> str:
    for cutlass_module in cutlass_modules:
        content = content.replace(
            f"from {cutlass_module} import ",
            f"from cutlass_library.{cutlass_module} import ",
        )
    return content


def _gen_cutlass_file(
    file_name: str, cutlass_modules: List[str], src_dir: str, dst_dir: str
) -> None:
    orig_full_path = os.path.abspath(os.path.join(src_dir, file_name))
    text = ""
    with open(orig_full_path) as f:
        text = f.read()
    text = _rename_cutlass_import(text, cutlass_modules)
    dst_full_path = os.path.abspath(
        os.path.join(
            dst_dir,
            file_name,
        )
    )
    with open(dst_full_path, "w") as f:
        f.write(text)


@functools.lru_cache(None)
def try_import_cutlass() -> bool:
    # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
    # This is a temporary hack to avoid CUTLASS module naming conflicts.
    # TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.

    cutlass_py_full_path = os.path.abspath(
        os.path.join(inductor_cuda_config.cutlass_dir, "python/cutlass_library")
    )
    tmp_cutlass_py_full_path = os.path.abspath(
        os.path.join(cache_dir(), "torch_cutlass_library")
    )
    dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")

    if os.path.isdir(cutlass_py_full_path):
        if tmp_cutlass_py_full_path not in sys.path:
            if os.path.exists(dst_link):
                assert os.path.islink(
                    dst_link
                ), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
                assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
                    cutlass_py_full_path
                ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
            else:
                os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
                os.symlink(cutlass_py_full_path, dst_link)
            sys.path.append(tmp_cutlass_py_full_path)
        try:
            import cutlass_library.generator  # noqa: F401
            import cutlass_library.library  # noqa: F401
            import cutlass_library.manifest  # noqa: F401

            return True

        except ImportError as e:
            log.debug(
                "Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.",
                str(e),
            )
    else:
        log.debug(
            "Failed to import CUTLASS packages: CUTLASS repo does not exist: %s",
            cutlass_py_full_path,
        )
    return False


def _normalize_cuda_arch(arch: str) -> str:
    if int(arch) >= 90:
        return "90"
    elif int(arch) >= 80:
        return "80"
    elif int(arch) >= 75:
        return "75"
    elif int(arch) >= 70:
        return "70"
    else:
        raise NotImplementedError(f"Unsupported cuda arch: {arch}")


@dataclass
class CUTLASSArgs:
    """
    CUTLASS args used to initialize a CUTLASS Manifest.
    """

    architectures: Optional[str] = None
    cuda_version: Optional[str] = None

    operations = "all"
    build_dir = ""
    curr_build_dir = ""
    generator_target = ""
    kernels = "all"
    ignore_kernels = ""
    # TODO: these three look dead?
    kernel_filter_file: None = None
    selected_kernel_list: None = None
    interface_dir: None = None
    filter_by_cc = True
    disable_full_archs_compilation = False

    def __post_init__(self):
        if self.architectures is None or self.cuda_version is None:
            raise RuntimeError(
                f"{self.architectures=} or {self.cuda_version=} is None!"
            )
        self.architectures = _normalize_cuda_arch(self.architectures)


@functools.lru_cache(None)
def _gen_ops_cached(arch, version) -> List[Any]:
    # Note: Cache needs to be specific for cuda architecture and version

    # Import cutlass python scripts.
    assert try_import_cutlass()
    import cutlass_library.generator as cutlass_generator
    import cutlass_library.manifest as cutlass_manifest

    if arch is None or version is None:
        log.error(
            "Cannot detect cuda arch %s or cuda version %s. "
            "Will discard all cutlass ops. "
            "Please consider setting _inductor.cuda.arch and _inductor.cuda.version configs.",
            arch,
            version,
        )
        return list()
    arch = _normalize_cuda_arch(arch)
    args = CUTLASSArgs(architectures=arch, cuda_version=version)
    manifest = cutlass_manifest.Manifest(args)

    if arch == "90":
        cutlass_generator.GenerateSM90(manifest, args.cuda_version)
        cutlass_generator.GenerateSM80(manifest, args.cuda_version)
    else:
        try:
            func = getattr(cutlass_generator, "GenerateSM" + arch)
            func(manifest, args.cuda_version)
        except AttributeError as e:
            raise NotImplementedError(
                "Arch " + arch + " is not supported by current cutlass lib."
            ) from e
    return manifest.operations


def gen_ops() -> List[Any]:
    """
    Generates all supported CUTLASS operations.
    """
    arch = get_cuda_arch()
    version = get_cuda_version()
    return _gen_ops_cached(arch, version)


def dtype_match(
    torch_dtype: Optional[torch.dtype],
    cutlass_dtype: "cutlass_library.library.DataType",  # type: ignore[name-defined] # noqa: F821
) -> bool:
    # Import cutlass python scripts.
    assert try_import_cutlass()
    import cutlass_library

    if torch_dtype == torch.float:
        return (
            cutlass_dtype == cutlass_library.library.DataType.f32
            or cutlass_dtype == cutlass_library.library.DataType.tf32
        )
    elif torch_dtype == torch.half:
        return cutlass_dtype == cutlass_library.library.DataType.f16
    elif torch_dtype == torch.bfloat16:
        return cutlass_dtype == cutlass_library.library.DataType.bf16
    else:
        return False


def get_accumulator_dtype(
    input_torch_dtypes: List[torch.dtype],
) -> Optional[torch.dtype]:
    """
    Given a list of input torch dtypes, returns the inferred accumulator torch dtype.
    """

    if len(input_torch_dtypes) == 0:
        return None
    torch_dtype = input_torch_dtypes[0]
    for dtype in input_torch_dtypes[1:]:
        if torch_dtype != dtype:
            raise RuntimeError(f"Unmatched input dtypes: {torch_dtype=}, {dtype=}")
    if torch_dtype == torch.half:
        if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction:
            return torch_dtype
        else:
            return torch.float
    if torch_dtype in {torch.bfloat16, torch.float}:
        return torch.float
    raise NotImplementedError(f"Unsupported data type: {input_torch_dtypes=}")


def get_alignments(torch_dtype: torch.dtype) -> List[int]:
    """
    Returns all possible valid CUTLASS alignments in terms of the number of elements for a given dtype.
    CUTLASS gemm / conv SM80 APIs support 16 bytes max alignment, and 2 bytes min alignment.
    """

    if torch_dtype in (torch.half, torch.bfloat16):
        return [8, 4, 2, 1]
    elif torch_dtype == torch.float:
        return [4, 2, 1]
    else:
        raise NotImplementedError(f"unsupported {torch_dtype=} for alignments")


def get_max_alignment(inductor_layout: Layout) -> int:
    """
    Returns the max alignment (in terms of number of elements) for a given Inductor Layout.
    """

    dtype = inductor_layout.dtype
    size = inductor_layout.size
    offset = inductor_layout.offset

    def is_static_int(number):
        return isinstance(number, (int, sympy.Integer))

    if is_static_int(size[-1]) and is_static_int(offset):
        alignments = get_alignments(dtype)
        for alignment in alignments:
            if int(size[-1]) % alignment == 0 and int(offset) % alignment == 0:
                return alignment

    return 1
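
The alignment helpers at the end of this file (get_alignments / get_max_alignment) boil down to a small amount of arithmetic: pick the largest candidate alignment that divides both the innermost dimension and the storage offset. Below is a minimal standalone sketch of that logic using plain integers instead of an inductor Layout; the function names candidate_alignments and max_alignment are illustrative and not part of the file above.

# Standalone sketch mirroring get_alignments/get_max_alignment, without
# importing torch._inductor internals. Names here are illustrative only.
import torch

def candidate_alignments(dtype: torch.dtype) -> list[int]:
    # 16-byte max / element-size min, as in get_alignments above.
    if dtype in (torch.half, torch.bfloat16):  # 2-byte elements
        return [8, 4, 2, 1]
    if dtype == torch.float:                   # 4-byte elements
        return [4, 2, 1]
    raise NotImplementedError(f"unsupported {dtype=}")

def max_alignment(dtype: torch.dtype, last_dim: int, offset: int = 0) -> int:
    # Largest alignment dividing both the innermost size and the offset,
    # falling back to 1 (same rule as get_max_alignment above).
    for a in candidate_alignments(dtype):
        if last_dim % a == 0 and offset % a == 0:
            return a
    return 1

print(max_alignment(torch.half, 1024))  # 8 -> rows are 16-byte aligned
print(max_alignment(torch.float, 6))    # 2 -> only 8-byte alignment divides 6
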
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codegen/triton.py
ADDED
The diff for this file is too large to render.

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-311.pyc
ADDED
Binary file (11.1 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-311.pyc
ADDED
Binary file (36.8 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-311.pyc
ADDED
Binary file (31.2 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/quantization.cpython-311.pyc
ADDED
Binary file (66.5 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-311.pyc
ADDED
Binary file (70.8 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-311.pyc
ADDED
Binary file (14.2 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-311.pyc
ADDED
Binary file (17.2 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-311.pyc
ADDED
Binary file (19.9 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-311.pyc
ADDED
Binary file (52.1 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-311.pyc
ADDED
Binary file (14.2 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-311.pyc
ADDED
Binary file (16.1 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-311.pyc
ADDED
Binary file (14.6 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-311.pyc
ADDED
Binary file (19.3 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-311.pyc
ADDED
Binary file (19.3 kB).

tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py
ADDED
@@ -0,0 +1,218 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python
# torchgen/fuse_attention_patterns/gen_attention_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_14_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_14_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_14_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_14_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
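
For orientation, the graphs serialized above describe masked, scaled attention: Q·Kᵀ is divided by inv_scale, attn_mask is added, a softmax follows, and the result is multiplied with V (the training variants additionally encode the backward bmm chain). Below is a rough eager-mode sketch of the inference computation with made-up shapes; it only illustrates the math, not the traced graph or the exact permutes/views the pattern matcher sees.

# Eager-mode sketch of the computation _sfdp_pattern_14_inference corresponds to
# (bmm -> div by inv_scale -> add mask -> softmax -> bmm). Shapes are illustrative.
import torch

def sdpa_like(query, key, value, attn_mask, inv_scale):
    scores = torch.matmul(query, key.transpose(-2, -1)) / inv_scale
    attn = torch.softmax(scores + attn_mask, dim=-1)
    return torch.matmul(attn, value)

q = torch.randn(2, 4, 16, 8)
k = torch.randn(2, 4, 16, 8)
v = torch.randn(2, 4, 16, 8)
mask = torch.zeros(2, 4, 16, 16)
out = sdpa_like(q, k, v, mask, inv_scale=8 ** 0.5)
print(out.shape)  # torch.Size([2, 4, 16, 8])
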
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py
ADDED
@@ -0,0 +1,233 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python
# torchgen/fuse_attention_patterns/gen_attention_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored())
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_7_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored())
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_7_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 229 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 230 |
+
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
| 231 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
| 232 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 233 |
+
_sfdp_pattern_7_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
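
For orientation, the half-precision graphs serialized above are the traced form of the usual scaled-dot-product-attention computation, with the softmax carried out in float32. A rough pure-PyTorch sketch of what the inference pattern corresponds to follows; the permute dimensions, the scale divisor, and other layout details are assumptions here, since the serialized pattern hides them behind Ignored().

import torch

def sdpa_pattern_7_inference_sketch(query, key, value, inv_scale):
    # permute_default / permute_default_1 / permute_default_3: move heads in front (dims assumed)
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    # bmm_default + div_Tensor: scaled attention scores
    scores = torch.matmul(q, k.transpose(-2, -1)) / inv_scale
    # convert_element_type + amax/sub/exp/sum/div: numerically stable softmax in float32
    attn = scores.float().softmax(dim=-1).to(query.dtype)
    # bmm_default_1 + view_default: attention-weighted sum over values
    return torch.matmul(attn, v)
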
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__init__.py
ADDED
@@ -0,0 +1 @@
from . import mm, mm_common, mm_plus_mm, unpack_mixed_mm
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (349 Bytes).
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-311.pyc
ADDED
Binary file (7.42 kB).
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py
ADDED
@@ -0,0 +1,82 @@
import logging
from typing import List

from ..select_algorithm import autotune_select_algorithm, ChoiceCaller, TritonTemplate
from .mm_common import mm_args, mm_configs, mm_grid, mm_options

log = logging.getLogger(__name__)

uint4x2_mixed_mm_template = TritonTemplate(
    name="uint4x2_mixed_mm",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None]//2 * stride_bk + rbn[None, :] * stride_bn)
    b_shifts = 4*(rk%2)
    b_subs = 8*(1-(rk%2))

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        b = ((b >> b_shifts[:, None]) & 0xF) - 8
        b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K//2 * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)


def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype):
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True)
    choices: List[ChoiceCaller] = []
    b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
    for config in mm_configs(m, n, k):
        uint4x2_mixed_mm_template.maybe_append_choice(
            choices,
            input_nodes=(mat1, mat2),
            layout=layout,
            **mm_options(config, m, n, k, layout, b_prologue_cast_type),
        )
    return autotune_select_algorithm("uint4x2_mixed_mm", choices, [mat1, mat2], layout)
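
As a side note on the data layout the template above expects: mat2 is assumed to pack two unsigned 4-bit values per uint8 byte along the K dimension, each stored with an offset of 8, which is what b = ((b >> b_shifts[:, None]) & 0xF) - 8 undoes inside the loop (rk % 2 selects the nibble, rk // 2 selects the byte row). A minimal pure-PyTorch sketch of that unpacking convention; the helper name and shapes are illustrative, not part of the kernel:

import torch

def unpack_uint4x2(packed: torch.Tensor, dtype=torch.float16) -> torch.Tensor:
    # packed: (K//2, N) uint8. Logical row k lives in byte row k // 2:
    # low nibble for even k, high nibble for odd k, value stored as v + 8.
    lo = (packed & 0xF).to(torch.int16) - 8         # even logical rows
    hi = ((packed >> 4) & 0xF).to(torch.int16) - 8  # odd logical rows
    out = torch.empty(packed.shape[0] * 2, packed.shape[1], dtype=dtype)
    out[0::2] = lo.to(dtype)
    out[1::2] = hi.to(dtype)
    return out
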
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py
ADDED
@@ -0,0 +1,1412 @@
| 1 |
+
r"""
|
| 2 |
+
This package adds support for CUDA tensor types.
|
| 3 |
+
|
| 4 |
+
It implements the same function as CPU tensors, but they utilize
|
| 5 |
+
GPUs for computation.
|
| 6 |
+
|
| 7 |
+
It is lazily initialized, so you can always import it, and use
|
| 8 |
+
:func:`is_available()` to determine if your system supports CUDA.
|
| 9 |
+
|
| 10 |
+
:ref:`cuda-semantics` has more details about working with CUDA.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
import contextlib
|
| 15 |
+
import importlib
|
| 16 |
+
import os
|
| 17 |
+
import sys
|
| 18 |
+
import threading
|
| 19 |
+
import traceback
|
| 20 |
+
import warnings
|
| 21 |
+
from functools import lru_cache
|
| 22 |
+
from typing import Any, Callable, cast, List, Optional, Tuple, Union
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch._C
|
| 26 |
+
from torch.types import Device
|
| 27 |
+
from .. import device as _device
|
| 28 |
+
from .._utils import _dummy_type, _LazySeedTracker, classproperty
|
| 29 |
+
from ._utils import _get_device_index
|
| 30 |
+
from .graphs import (
|
| 31 |
+
CUDAGraph,
|
| 32 |
+
graph,
|
| 33 |
+
graph_pool_handle,
|
| 34 |
+
is_current_stream_capturing,
|
| 35 |
+
make_graphed_callables,
|
| 36 |
+
)
|
| 37 |
+
from .streams import Event, ExternalStream, Stream
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
from torch._C import _cudart # type: ignore[attr-defined]
|
| 41 |
+
except ImportError:
|
| 42 |
+
_cudart = None
|
| 43 |
+
|
| 44 |
+
_initialized = False
|
| 45 |
+
_tls = threading.local()
|
| 46 |
+
_initialization_lock = threading.Lock()
|
| 47 |
+
_queued_calls: List[
|
| 48 |
+
Tuple[Callable[[], None], List[str]]
|
| 49 |
+
] = [] # don't invoke these until initialization occurs
|
| 50 |
+
_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
|
| 51 |
+
_device_t = Union[_device, str, int, None]
|
| 52 |
+
|
| 53 |
+
_HAS_PYNVML = False
|
| 54 |
+
_PYNVML_ERR = None
|
| 55 |
+
try:
|
| 56 |
+
import pynvml # type: ignore[import]
|
| 57 |
+
|
| 58 |
+
_HAS_PYNVML = True
|
| 59 |
+
except ImportError as err:
|
| 60 |
+
_PYNVML_ERR = err # sometimes a lib is installed but the import fails for some other reason, so we log the error for later
|
| 61 |
+
|
| 62 |
+
_lazy_seed_tracker = _LazySeedTracker()
|
| 63 |
+
|
| 64 |
+
# Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA
|
| 65 |
+
if hasattr(torch._C, "_CudaDeviceProperties"):
|
| 66 |
+
_CudaDeviceProperties = torch._C._CudaDeviceProperties
|
| 67 |
+
else:
|
| 68 |
+
_CudaDeviceProperties = _dummy_type("_CudaDeviceProperties") # type: ignore[assignment, misc]
|
| 69 |
+
|
| 70 |
+
if hasattr(torch._C, "_cuda_exchangeDevice"):
|
| 71 |
+
_exchange_device = torch._C._cuda_exchangeDevice
|
| 72 |
+
else:
|
| 73 |
+
|
| 74 |
+
def _exchange_device(device: int) -> int:
|
| 75 |
+
if device < 0:
|
| 76 |
+
return -1
|
| 77 |
+
raise RuntimeError("PyTorch was compiled without CUDA support")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
if hasattr(torch._C, "_cuda_maybeExchangeDevice"):
|
| 81 |
+
_maybe_exchange_device = torch._C._cuda_maybeExchangeDevice
|
| 82 |
+
else:
|
| 83 |
+
|
| 84 |
+
def _maybe_exchange_device(device: int) -> int:
|
| 85 |
+
if device < 0:
|
| 86 |
+
return -1
|
| 87 |
+
raise RuntimeError("PyTorch was compiled without CUDA support")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
has_half: bool = True
|
| 91 |
+
has_magma: bool = torch._C._has_magma
|
| 92 |
+
|
| 93 |
+
default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _is_compiled() -> bool:
|
| 97 |
+
r"""Return true if compile with CUDA support."""
|
| 98 |
+
return hasattr(torch._C, "_cuda_getDeviceCount")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _nvml_based_avail() -> bool:
|
| 102 |
+
return os.getenv("PYTORCH_NVML_BASED_CUDA_CHECK") == "1"
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def is_available() -> bool:
|
| 106 |
+
r"""Return a bool indicating if CUDA is currently available."""
|
| 107 |
+
if not _is_compiled():
|
| 108 |
+
return False
|
| 109 |
+
if _nvml_based_avail():
|
| 110 |
+
# The user has set an env variable to request this availability check that attempts to avoid fork poisoning by
|
| 111 |
+
# using NVML at the cost of a weaker CUDA availability assessment. Note that if NVML discovery/initialization
|
| 112 |
+
# fails, this assessment falls back to the default CUDA Runtime API assessment (`cudaGetDeviceCount`)
|
| 113 |
+
return device_count() > 0
|
| 114 |
+
else:
|
| 115 |
+
# The default availability inspection never throws and returns 0 if the driver is missing or can't
|
| 116 |
+
# be initialized. This uses the CUDA Runtime API `cudaGetDeviceCount` which in turn initializes the CUDA Driver
|
| 117 |
+
# API via `cuInit`
|
| 118 |
+
return torch._C._cuda_getDeviceCount() > 0
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def is_bf16_supported():
|
| 122 |
+
r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
|
| 123 |
+
# Check for ROCm, if true return true, no ROCM_VERSION check required,
|
| 124 |
+
# since it is supported on AMD GPU archs.
|
| 125 |
+
if torch.version.hip:
|
| 126 |
+
return True
|
| 127 |
+
|
| 128 |
+
device = torch.cuda.current_device()
|
| 129 |
+
|
| 130 |
+
# Check for CUDA version and device compute capability.
|
| 131 |
+
# This is a fast way to check for it.
|
| 132 |
+
cuda_version = torch.version.cuda
|
| 133 |
+
if (
|
| 134 |
+
cuda_version is not None
|
| 135 |
+
and int(cuda_version.split(".")[0]) >= 11
|
| 136 |
+
and torch.cuda.get_device_properties(device).major >= 8
|
| 137 |
+
):
|
| 138 |
+
return True
|
| 139 |
+
|
| 140 |
+
# Finally try to create a bfloat16 device.
|
| 141 |
+
return _check_bf16_tensor_supported(device)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@lru_cache(maxsize=16)
|
| 145 |
+
def _check_bf16_tensor_supported(device: _device_t):
|
| 146 |
+
try:
|
| 147 |
+
torch.tensor([1.0], dtype=torch.bfloat16, device=device)
|
| 148 |
+
return True
|
| 149 |
+
except Exception:
|
| 150 |
+
return False
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _sleep(cycles):
|
| 154 |
+
torch._C._cuda_sleep(cycles)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _check_capability():
|
| 158 |
+
incorrect_binary_warn = """
|
| 159 |
+
Found GPU%d %s which requires CUDA_VERSION >= %d to
|
| 160 |
+
work properly, but your PyTorch was compiled
|
| 161 |
+
with CUDA_VERSION %d. Please install the correct PyTorch binary
|
| 162 |
+
using instructions from https://pytorch.org
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
old_gpu_warn = """
|
| 166 |
+
Found GPU%d %s which is of cuda capability %d.%d.
|
| 167 |
+
PyTorch no longer supports this GPU because it is too old.
|
| 168 |
+
The minimum cuda capability supported by this library is %d.%d.
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
if torch.version.cuda is not None: # on ROCm we don't want this check
|
| 172 |
+
CUDA_VERSION = torch._C._cuda_getCompiledVersion()
|
| 173 |
+
for d in range(device_count()):
|
| 174 |
+
capability = get_device_capability(d)
|
| 175 |
+
major = capability[0]
|
| 176 |
+
minor = capability[1]
|
| 177 |
+
name = get_device_name(d)
|
| 178 |
+
current_arch = major * 10 + minor
|
| 179 |
+
min_arch = min(
|
| 180 |
+
(int(arch.split("_")[1]) for arch in torch.cuda.get_arch_list()),
|
| 181 |
+
default=35,
|
| 182 |
+
)
|
| 183 |
+
if current_arch < min_arch:
|
| 184 |
+
warnings.warn(
|
| 185 |
+
old_gpu_warn
|
| 186 |
+
% (d, name, major, minor, min_arch // 10, min_arch % 10)
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _check_cubins():
|
| 191 |
+
incompatible_device_warn = """
|
| 192 |
+
{} with CUDA capability sm_{} is not compatible with the current PyTorch installation.
|
| 193 |
+
The current PyTorch install supports CUDA capabilities {}.
|
| 194 |
+
If you want to use the {} GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/
|
| 195 |
+
"""
|
| 196 |
+
if torch.version.cuda is None: # on ROCm we don't want this check
|
| 197 |
+
return
|
| 198 |
+
arch_list = get_arch_list()
|
| 199 |
+
if len(arch_list) == 0:
|
| 200 |
+
return
|
| 201 |
+
supported_sm = [int(arch.split("_")[1]) for arch in arch_list if "sm_" in arch]
|
| 202 |
+
for idx in range(device_count()):
|
| 203 |
+
cap_major, cap_minor = get_device_capability(idx)
|
| 204 |
+
# NVIDIA GPU compute architectures are backward compatible within major version
|
| 205 |
+
supported = any(sm // 10 == cap_major for sm in supported_sm)
|
| 206 |
+
if not supported:
|
| 207 |
+
device_name = get_device_name(idx)
|
| 208 |
+
capability = cap_major * 10 + cap_minor
|
| 209 |
+
warnings.warn(
|
| 210 |
+
incompatible_device_warn.format(
|
| 211 |
+
device_name, capability, " ".join(arch_list), device_name
|
| 212 |
+
)
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def is_initialized():
|
| 217 |
+
r"""Return whether PyTorch's CUDA state has been initialized."""
|
| 218 |
+
return _initialized and not _is_in_bad_fork()
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def _lazy_call(callable, **kwargs):
|
| 222 |
+
if is_initialized():
|
| 223 |
+
callable()
|
| 224 |
+
else:
|
| 225 |
+
# TODO(torch_deploy): this accesses linecache, which attempts to read the
|
| 226 |
+
# file system to get traceback info. Patch linecache or do something
|
| 227 |
+
# else here if this ends up being important.
|
| 228 |
+
global _lazy_seed_tracker
|
| 229 |
+
if kwargs.get("seed_all", False):
|
| 230 |
+
_lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
|
| 231 |
+
elif kwargs.get("seed", False):
|
| 232 |
+
_lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
|
| 233 |
+
else:
|
| 234 |
+
# Don't store the actual traceback to avoid memory cycle
|
| 235 |
+
_queued_calls.append((callable, traceback.format_stack()))
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
_lazy_call(_check_capability)
|
| 239 |
+
_lazy_call(_check_cubins)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class DeferredCudaCallError(Exception):
|
| 243 |
+
pass
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
OutOfMemoryError = torch._C._OutOfMemoryError
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def init():
|
| 250 |
+
r"""Initialize PyTorch's CUDA state.
|
| 251 |
+
|
| 252 |
+
You may need to call this explicitly if you are interacting with
|
| 253 |
+
PyTorch via its C API, as Python bindings for CUDA functionality
|
| 254 |
+
will not be available until this initialization takes place.
|
| 255 |
+
Ordinary users should not need this, as all of PyTorch's CUDA methods
|
| 256 |
+
automatically initialize CUDA state on-demand.
|
| 257 |
+
|
| 258 |
+
Does nothing if the CUDA state is already initialized.
|
| 259 |
+
"""
|
| 260 |
+
_lazy_init()
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def _lazy_init():
|
| 264 |
+
global _initialized, _queued_calls
|
| 265 |
+
if is_initialized() or hasattr(_tls, "is_initializing"):
|
| 266 |
+
return
|
| 267 |
+
with _initialization_lock:
|
| 268 |
+
# We be double-checked locking, boys! This is OK because
|
| 269 |
+
# the above test was GIL protected anyway. The inner test
|
| 270 |
+
# is for when a thread blocked on some other thread which was
|
| 271 |
+
# doing the initialization; when they get the lock, they will
|
| 272 |
+
# find there is nothing left to do.
|
| 273 |
+
if is_initialized():
|
| 274 |
+
return
|
| 275 |
+
# It is important to prevent other threads from entering _lazy_init
|
| 276 |
+
# immediately, while we are still guaranteed to have the GIL, because some
|
| 277 |
+
# of the C calls we make below will release the GIL
|
| 278 |
+
if _is_in_bad_fork():
|
| 279 |
+
raise RuntimeError(
|
| 280 |
+
"Cannot re-initialize CUDA in forked subprocess. To use CUDA with "
|
| 281 |
+
"multiprocessing, you must use the 'spawn' start method"
|
| 282 |
+
)
|
| 283 |
+
if not hasattr(torch._C, "_cuda_getDeviceCount"):
|
| 284 |
+
raise AssertionError("Torch not compiled with CUDA enabled")
|
| 285 |
+
if _cudart is None:
|
| 286 |
+
raise AssertionError(
|
| 287 |
+
"libcudart functions unavailable. It looks like you have a broken build?"
|
| 288 |
+
)
|
| 289 |
+
# This function throws if there's a driver initialization error, no GPUs
|
| 290 |
+
# are found or any other error occurs
|
| 291 |
+
if "CUDA_MODULE_LOADING" not in os.environ:
|
| 292 |
+
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
|
| 293 |
+
torch._C._cuda_init()
|
| 294 |
+
# Some of the queued calls may reentrantly call _lazy_init();
|
| 295 |
+
# we need to just return without initializing in that case.
|
| 296 |
+
# However, we must not let any *other* threads in!
|
| 297 |
+
_tls.is_initializing = True
|
| 298 |
+
|
| 299 |
+
for calls in _lazy_seed_tracker.get_calls():
|
| 300 |
+
if calls:
|
| 301 |
+
_queued_calls.append(calls)
|
| 302 |
+
|
| 303 |
+
try:
|
| 304 |
+
for queued_call, orig_traceback in _queued_calls:
|
| 305 |
+
try:
|
| 306 |
+
queued_call()
|
| 307 |
+
except Exception as e:
|
| 308 |
+
msg = (
|
| 309 |
+
f"CUDA call failed lazily at initialization with error: {str(e)}\n\n"
|
| 310 |
+
f"CUDA call was originally invoked at:\n\n{''.join(orig_traceback)}"
|
| 311 |
+
)
|
| 312 |
+
raise DeferredCudaCallError(msg) from e
|
| 313 |
+
finally:
|
| 314 |
+
delattr(_tls, "is_initializing")
|
| 315 |
+
_initialized = True
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def cudart():
|
| 319 |
+
_lazy_init()
|
| 320 |
+
return _cudart
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class cudaStatus:
|
| 324 |
+
SUCCESS: int = 0
|
| 325 |
+
ERROR_NOT_READY: int = 34
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
class CudaError(RuntimeError):
|
| 329 |
+
def __init__(self, code: int) -> None:
|
| 330 |
+
msg = _cudart.cudaGetErrorString(_cudart.cudaError(code))
|
| 331 |
+
super().__init__(f"{msg} ({code})")
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def check_error(res: int) -> None:
|
| 335 |
+
if res != _cudart.cudaError.success:
|
| 336 |
+
raise CudaError(res)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class _DeviceGuard:
|
| 340 |
+
def __init__(self, index: int):
|
| 341 |
+
self.idx = index
|
| 342 |
+
self.prev_idx = -1
|
| 343 |
+
|
| 344 |
+
def __enter__(self):
|
| 345 |
+
self.prev_idx = torch.cuda._exchange_device(self.idx)
|
| 346 |
+
|
| 347 |
+
def __exit__(self, type: Any, value: Any, traceback: Any):
|
| 348 |
+
self.idx = torch.cuda._maybe_exchange_device(self.prev_idx)
|
| 349 |
+
return False
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class device:
|
| 353 |
+
r"""Context-manager that changes the selected device.
|
| 354 |
+
|
| 355 |
+
Args:
|
| 356 |
+
device (torch.device or int): device index to select. It's a no-op if
|
| 357 |
+
this argument is a negative integer or ``None``.
|
| 358 |
+
"""
|
| 359 |
+
|
| 360 |
+
def __init__(self, device: Any):
|
| 361 |
+
self.idx = _get_device_index(device, optional=True)
|
| 362 |
+
self.prev_idx = -1
|
| 363 |
+
|
| 364 |
+
def __enter__(self):
|
| 365 |
+
self.prev_idx = torch.cuda._exchange_device(self.idx)
|
| 366 |
+
|
| 367 |
+
def __exit__(self, type: Any, value: Any, traceback: Any):
|
| 368 |
+
self.idx = torch.cuda._maybe_exchange_device(self.prev_idx)
|
| 369 |
+
return False
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class device_of(device):
|
| 373 |
+
r"""Context-manager that changes the current device to that of given object.
|
| 374 |
+
|
| 375 |
+
You can use both tensors and storages as arguments. If a given object is
|
| 376 |
+
not allocated on a GPU, this is a no-op.
|
| 377 |
+
|
| 378 |
+
Args:
|
| 379 |
+
obj (Tensor or Storage): object allocated on the selected device.
|
| 380 |
+
"""
|
| 381 |
+
|
| 382 |
+
def __init__(self, obj):
|
| 383 |
+
idx = obj.get_device() if obj.is_cuda else -1
|
| 384 |
+
super().__init__(idx)
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def set_device(device: _device_t) -> None:
|
| 388 |
+
r"""Set the current device.
|
| 389 |
+
|
| 390 |
+
Usage of this function is discouraged in favor of :any:`device`. In most
|
| 391 |
+
cases it's better to use ``CUDA_VISIBLE_DEVICES`` environmental variable.
|
| 392 |
+
|
| 393 |
+
Args:
|
| 394 |
+
device (torch.device or int): selected device. This function is a no-op
|
| 395 |
+
if this argument is negative.
|
| 396 |
+
"""
|
| 397 |
+
device = _get_device_index(device)
|
| 398 |
+
if device >= 0:
|
| 399 |
+
torch._C._cuda_setDevice(device)
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def get_device_name(device: Optional[_device_t] = None) -> str:
|
| 403 |
+
r"""Get the name of a device.
|
| 404 |
+
|
| 405 |
+
Args:
|
| 406 |
+
device (torch.device or int, optional): device for which to return the
|
| 407 |
+
name. This function is a no-op if this argument is a negative
|
| 408 |
+
integer. It uses the current device, given by :func:`~torch.cuda.current_device`,
|
| 409 |
+
if :attr:`device` is ``None`` (default).
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
str: the name of the device
|
| 413 |
+
"""
|
| 414 |
+
return get_device_properties(device).name
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]:
|
| 418 |
+
r"""Get the cuda capability of a device.
|
| 419 |
+
|
| 420 |
+
Args:
|
| 421 |
+
device (torch.device or int, optional): device for which to return the
|
| 422 |
+
device capability. This function is a no-op if this argument is
|
| 423 |
+
a negative integer. It uses the current device, given by
|
| 424 |
+
:func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
|
| 425 |
+
(default).
|
| 426 |
+
|
| 427 |
+
Returns:
|
| 428 |
+
tuple(int, int): the major and minor cuda capability of the device
|
| 429 |
+
"""
|
| 430 |
+
prop = get_device_properties(device)
|
| 431 |
+
return prop.major, prop.minor
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def get_device_properties(device: _device_t) -> _CudaDeviceProperties:
|
| 435 |
+
r"""Get the properties of a device.
|
| 436 |
+
|
| 437 |
+
Args:
|
| 438 |
+
device (torch.device or int or str): device for which to return the
|
| 439 |
+
properties of the device.
|
| 440 |
+
|
| 441 |
+
Returns:
|
| 442 |
+
_CudaDeviceProperties: the properties of the device
|
| 443 |
+
"""
|
| 444 |
+
_lazy_init() # will define _get_device_properties
|
| 445 |
+
device = _get_device_index(device, optional=True)
|
| 446 |
+
if device < 0 or device >= device_count():
|
| 447 |
+
raise AssertionError("Invalid device id")
|
| 448 |
+
return _get_device_properties(device) # type: ignore[name-defined]
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def can_device_access_peer(device: _device_t, peer_device: _device_t) -> bool:
|
| 452 |
+
r"""Check if peer access between two devices is possible."""
|
| 453 |
+
_lazy_init()
|
| 454 |
+
device = _get_device_index(device, optional=True)
|
| 455 |
+
peer_device = _get_device_index(peer_device)
|
| 456 |
+
if device < 0 or device >= device_count():
|
| 457 |
+
raise AssertionError("Invalid device id")
|
| 458 |
+
if peer_device < 0 or peer_device >= device_count():
|
| 459 |
+
raise AssertionError("Invalid peer device id")
|
| 460 |
+
return torch._C._cuda_canDeviceAccessPeer(device, peer_device)
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
class StreamContext:
|
| 464 |
+
r"""Context-manager that selects a given stream.
|
| 465 |
+
|
| 466 |
+
All CUDA kernels queued within its context will be enqueued on a selected
|
| 467 |
+
stream.
|
| 468 |
+
|
| 469 |
+
Args:
|
| 470 |
+
Stream (Stream): selected stream. This manager is a no-op if it's
|
| 471 |
+
``None``.
|
| 472 |
+
.. note:: Streams are per-device.
|
| 473 |
+
"""
|
| 474 |
+
cur_stream: Optional["torch.cuda.Stream"]
|
| 475 |
+
|
| 476 |
+
def __init__(self, stream: Optional["torch.cuda.Stream"]):
|
| 477 |
+
self.stream = stream
|
| 478 |
+
self.idx = _get_device_index(None, True)
|
| 479 |
+
if not torch.jit.is_scripting():
|
| 480 |
+
if self.idx is None:
|
| 481 |
+
self.idx = -1
|
| 482 |
+
|
| 483 |
+
self.src_prev_stream = (
|
| 484 |
+
None if not torch.jit.is_scripting() else torch.cuda.default_stream(None)
|
| 485 |
+
)
|
| 486 |
+
self.dst_prev_stream = (
|
| 487 |
+
None if not torch.jit.is_scripting() else torch.cuda.default_stream(None)
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
def __enter__(self):
|
| 491 |
+
# Local cur_stream variable for type refinement
|
| 492 |
+
cur_stream = self.stream
|
| 493 |
+
# Return if stream is None or CUDA device not available
|
| 494 |
+
if cur_stream is None or self.idx == -1:
|
| 495 |
+
return
|
| 496 |
+
self.src_prev_stream = torch.cuda.current_stream(None)
|
| 497 |
+
|
| 498 |
+
# If the stream is not on the current device, then
|
| 499 |
+
# set the current stream on the device
|
| 500 |
+
if self.src_prev_stream.device != cur_stream.device:
|
| 501 |
+
with device(cur_stream.device):
|
| 502 |
+
self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device)
|
| 503 |
+
torch.cuda.set_stream(cur_stream)
|
| 504 |
+
|
| 505 |
+
def __exit__(self, type: Any, value: Any, traceback: Any):
|
| 506 |
+
# Local cur_stream variable for type refinement
|
| 507 |
+
cur_stream = self.stream
|
| 508 |
+
# If stream is None or no CUDA device available, return
|
| 509 |
+
if cur_stream is None or self.idx == -1:
|
| 510 |
+
return
|
| 511 |
+
|
| 512 |
+
# Reset the stream on the original device
|
| 513 |
+
# and destination device
|
| 514 |
+
if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr]
|
| 515 |
+
torch.cuda.set_stream(self.dst_prev_stream) # type: ignore[arg-type]
|
| 516 |
+
torch.cuda.set_stream(self.src_prev_stream) # type: ignore[arg-type]
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def stream(stream: Optional["torch.cuda.Stream"]) -> StreamContext:
|
| 520 |
+
r"""Wrap around the Context-manager StreamContext that selects a given stream.
|
| 521 |
+
|
| 522 |
+
Arguments:
|
| 523 |
+
stream (Stream): selected stream. This manager is a no-op if it's
|
| 524 |
+
``None``.
|
| 525 |
+
..Note:: In eager mode stream is of type Stream class while in JIT it is
|
| 526 |
+
an object of the custom class ``torch.classes.cuda.Stream``.
|
| 527 |
+
"""
|
| 528 |
+
return StreamContext(stream)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def _set_stream_by_id(stream_id, device_index, device_type):
|
| 532 |
+
r"""set stream specified by the stream id, device index and
|
| 533 |
+
device type
|
| 534 |
+
|
| 535 |
+
Args: stream_id (int): stream id in stream pool
|
| 536 |
+
device_index (int): device index in topo
|
| 537 |
+
device_type (int): enum device type
|
| 538 |
+
"""
|
| 539 |
+
torch._C._cuda_setStream(
|
| 540 |
+
stream_id=stream_id,
|
| 541 |
+
device_index=device_index,
|
| 542 |
+
device_type=device_type,
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def set_stream(stream: Stream):
|
| 547 |
+
r"""Set the current stream.This is a wrapper API to set the stream.
|
| 548 |
+
Usage of this function is discouraged in favor of the ``stream``
|
| 549 |
+
context manager.
|
| 550 |
+
|
| 551 |
+
Args:
|
| 552 |
+
stream (Stream): selected stream. This function is a no-op
|
| 553 |
+
if this argument is ``None``.
|
| 554 |
+
"""
|
| 555 |
+
if stream is None:
|
| 556 |
+
return
|
| 557 |
+
_set_stream_by_id(
|
| 558 |
+
stream_id=stream.stream_id,
|
| 559 |
+
device_index=stream.device_index,
|
| 560 |
+
device_type=stream.device_type,
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def _parse_visible_devices() -> Union[List[int], List[str]]:
|
| 565 |
+
r"""Parse CUDA_VISIBLE_DEVICES environment variable."""
|
| 566 |
+
var = os.getenv("CUDA_VISIBLE_DEVICES")
|
| 567 |
+
if var is None:
|
| 568 |
+
return list(range(64))
|
| 569 |
+
|
| 570 |
+
def _strtoul(s: str) -> int:
|
| 571 |
+
"""Return -1 or positive integer sequence string starts with."""
|
| 572 |
+
if not s:
|
| 573 |
+
return -1
|
| 574 |
+
for idx, c in enumerate(s):
|
| 575 |
+
if not (c.isdigit() or (idx == 0 and c in "+-")):
|
| 576 |
+
break
|
| 577 |
+
if idx + 1 == len(s):
|
| 578 |
+
idx += 1
|
| 579 |
+
return int(s[:idx]) if idx > 0 else -1
|
| 580 |
+
|
| 581 |
+
def parse_list_with_prefix(lst: str, prefix: str) -> List[str]:
|
| 582 |
+
rcs: List[str] = []
|
| 583 |
+
for elem in lst.split(","):
|
| 584 |
+
# Repeated id results in empty set
|
| 585 |
+
if elem in rcs:
|
| 586 |
+
return cast(List[str], [])
|
| 587 |
+
# Anything other than the prefix is ignored
|
| 588 |
+
if not elem.startswith(prefix):
|
| 589 |
+
break
|
| 590 |
+
rcs.append(elem)
|
| 591 |
+
return rcs
|
| 592 |
+
|
| 593 |
+
if var.startswith("GPU-"):
|
| 594 |
+
return parse_list_with_prefix(var, "GPU-")
|
| 595 |
+
if var.startswith("MIG-"):
|
| 596 |
+
return parse_list_with_prefix(var, "MIG-")
|
| 597 |
+
# CUDA_VISIBLE_DEVICES uses something like strtoul
|
| 598 |
+
# which makes `1gpu2,2ampere` is equivalent to `1,2`
|
| 599 |
+
rc: List[int] = []
|
| 600 |
+
for elem in var.split(","):
|
| 601 |
+
x = _strtoul(elem.strip())
|
| 602 |
+
# Repeated ordinal results in empty set
|
| 603 |
+
if x in rc:
|
| 604 |
+
return cast(List[int], [])
|
| 605 |
+
# Negative value aborts the sequence
|
| 606 |
+
if x < 0:
|
| 607 |
+
break
|
| 608 |
+
rc.append(x)
|
| 609 |
+
return rc
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
def _raw_device_count_nvml() -> int:
|
| 613 |
+
r"""Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed."""
|
| 614 |
+
from ctypes import byref, c_int, CDLL
|
| 615 |
+
|
| 616 |
+
nvml_h = CDLL("libnvidia-ml.so.1")
|
| 617 |
+
rc = nvml_h.nvmlInit()
|
| 618 |
+
if rc != 0:
|
| 619 |
+
warnings.warn("Can't initialize NVML")
|
| 620 |
+
return -1
|
| 621 |
+
dev_count = c_int(-1)
|
| 622 |
+
rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
|
| 623 |
+
if rc != 0:
|
| 624 |
+
warnings.warn("Can't get nvml device count")
|
| 625 |
+
return -1
|
| 626 |
+
del nvml_h
|
| 627 |
+
return dev_count.value
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
def _raw_device_uuid_nvml() -> Optional[List[str]]:
|
| 631 |
+
r"""Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed."""
|
| 632 |
+
from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer
|
| 633 |
+
|
| 634 |
+
nvml_h = CDLL("libnvidia-ml.so.1")
|
| 635 |
+
rc = nvml_h.nvmlInit()
|
| 636 |
+
if rc != 0:
|
| 637 |
+
warnings.warn("Can't initialize NVML")
|
| 638 |
+
return None
|
| 639 |
+
dev_count = c_int(-1)
|
| 640 |
+
rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
|
| 641 |
+
if rc != 0:
|
| 642 |
+
warnings.warn("Can't get nvml device count")
|
| 643 |
+
return None
|
| 644 |
+
uuids: List[str] = []
|
| 645 |
+
for idx in range(dev_count.value):
|
| 646 |
+
dev_id = c_void_p()
|
| 647 |
+
rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
|
| 648 |
+
if rc != 0:
|
| 649 |
+
warnings.warn("Can't get device handle")
|
| 650 |
+
return None
|
| 651 |
+
buf_len = 96
|
| 652 |
+
buf = create_string_buffer(buf_len)
|
| 653 |
+
rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
|
| 654 |
+
if rc != 0:
|
| 655 |
+
warnings.warn("Can't get device UUID")
|
| 656 |
+
return None
|
| 657 |
+
uuids.append(buf.raw.decode("ascii").strip("\0"))
|
| 658 |
+
del nvml_h
|
| 659 |
+
return uuids
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]:
|
| 663 |
+
r"""Given the set of partial uuids and list of known uuids builds a set of ordinals excluding ambiguous partials IDs."""
|
| 664 |
+
|
| 665 |
+
def uuid_to_orinal(candidate: str, uuids: List[str]) -> int:
|
| 666 |
+
best_match = -1
|
| 667 |
+
for idx, uuid in enumerate(uuids):
|
| 668 |
+
if not uuid.startswith(candidate):
|
| 669 |
+
continue
|
| 670 |
+
# Ambiguous candidate
|
| 671 |
+
if best_match != -1:
|
| 672 |
+
return -1
|
| 673 |
+
best_match = idx
|
| 674 |
+
return best_match
|
| 675 |
+
|
| 676 |
+
rc: List[int] = []
|
| 677 |
+
for candidate in candidates:
|
| 678 |
+
idx = uuid_to_orinal(candidate, uuids)
|
| 679 |
+
# First invalid ordinal stops parsing
|
| 680 |
+
if idx < 0:
|
| 681 |
+
break
|
| 682 |
+
# Duplicates result in empty set
|
| 683 |
+
if idx in rc:
|
| 684 |
+
return cast(List[int], [])
|
| 685 |
+
rc.append(idx)
|
| 686 |
+
return rc
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
def _device_count_nvml() -> int:
|
| 690 |
+
r"""Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account.
|
| 691 |
+
|
| 692 |
+
Negative value is returned if NVML discovery or initialization has failed.
|
| 693 |
+
"""
|
| 694 |
+
visible_devices = _parse_visible_devices()
|
| 695 |
+
if not visible_devices:
|
| 696 |
+
return 0
|
| 697 |
+
try:
|
| 698 |
+
if type(visible_devices[0]) is str:
|
| 699 |
+
# Skip MIG parsing
|
| 700 |
+
if visible_devices[0].startswith("MIG-"):
|
| 701 |
+
return -1
|
| 702 |
+
uuids = _raw_device_uuid_nvml()
|
| 703 |
+
if uuids is None:
|
| 704 |
+
return -1
|
| 705 |
+
visible_devices = _transform_uuid_to_ordinals(
|
| 706 |
+
cast(List[str], visible_devices), uuids
|
| 707 |
+
)
|
| 708 |
+
else:
|
| 709 |
+
raw_cnt = _raw_device_count_nvml()
|
| 710 |
+
if raw_cnt <= 0:
|
| 711 |
+
return raw_cnt
|
| 712 |
+
# Trim the list up to a maximum available device
|
| 713 |
+
for idx, val in enumerate(visible_devices):
|
| 714 |
+
if cast(int, val) >= raw_cnt:
|
| 715 |
+
return idx
|
| 716 |
+
except OSError:
|
| 717 |
+
return -1
|
| 718 |
+
except AttributeError:
|
| 719 |
+
return -1
|
| 720 |
+
return len(visible_devices)
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int:
|
| 724 |
+
r"""Return the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account."""
|
| 725 |
+
idx = _get_device_index(device, optional=True)
|
| 726 |
+
visible_devices = _parse_visible_devices()
|
| 727 |
+
if type(visible_devices[0]) is str:
|
| 728 |
+
uuids = _raw_device_uuid_nvml()
|
| 729 |
+
if uuids is None:
|
| 730 |
+
raise RuntimeError("Can't get device UUIDs")
|
| 731 |
+
visible_devices = _transform_uuid_to_ordinals(
|
| 732 |
+
cast(List[str], visible_devices), uuids
|
| 733 |
+
)
|
| 734 |
+
visible_devices = cast(List[int], visible_devices)
|
| 735 |
+
if idx < 0 or idx >= len(visible_devices):
|
| 736 |
+
raise RuntimeError(
|
| 737 |
+
f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})"
|
| 738 |
+
)
|
| 739 |
+
return visible_devices[idx]
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
@lru_cache(maxsize=1)
|
| 743 |
+
def device_count() -> int:
|
| 744 |
+
r"""Return the number of GPUs available."""
|
| 745 |
+
if not _is_compiled():
|
| 746 |
+
return 0
|
| 747 |
+
# bypass _device_count_nvml() if rocm (not supported)
|
| 748 |
+
nvml_count = -1 if torch.version.hip else _device_count_nvml()
|
| 749 |
+
return torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
def get_arch_list() -> List[str]:
|
| 753 |
+
r"""Return list CUDA architectures this library was compiled for."""
|
| 754 |
+
if not is_available():
|
| 755 |
+
return []
|
| 756 |
+
arch_flags = torch._C._cuda_getArchFlags()
|
| 757 |
+
if arch_flags is None:
|
| 758 |
+
return []
|
| 759 |
+
return arch_flags.split()
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
def get_gencode_flags() -> str:
|
| 763 |
+
r"""Return NVCC gencode flags this library was compiled with."""
|
| 764 |
+
arch_list = get_arch_list()
|
| 765 |
+
if len(arch_list) == 0:
|
| 766 |
+
return ""
|
| 767 |
+
arch_list_ = [arch.split("_") for arch in arch_list]
|
| 768 |
+
return " ".join(
|
| 769 |
+
[
|
| 770 |
+
f"-gencode compute=compute_{arch},code={kind}_{arch}"
|
| 771 |
+
for (kind, arch) in arch_list_
|
| 772 |
+
]
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
def current_device() -> int:
|
| 777 |
+
r"""Return the index of a currently selected device."""
|
| 778 |
+
_lazy_init()
|
| 779 |
+
return torch._C._cuda_getDevice()
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
def synchronize(device: _device_t = None) -> None:
|
| 783 |
+
r"""Wait for all kernels in all streams on a CUDA device to complete.
|
| 784 |
+
|
| 785 |
+
Args:
|
| 786 |
+
device (torch.device or int, optional): device for which to synchronize.
|
| 787 |
+
It uses the current device, given by :func:`~torch.cuda.current_device`,
|
| 788 |
+
if :attr:`device` is ``None`` (default).
|
| 789 |
+
"""
|
| 790 |
+
_lazy_init()
|
| 791 |
+
with torch.cuda.device(device):
|
| 792 |
+
return torch._C._cuda_synchronize()
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
def ipc_collect():
|
| 796 |
+
r"""Force collects GPU memory after it has been released by CUDA IPC.
|
| 797 |
+
|
| 798 |
+
.. note::
|
| 799 |
+
Checks if any sent CUDA tensors could be cleaned from the memory. Force
|
| 800 |
+
closes shared memory file used for reference counting if there is no
|
| 801 |
+
active counters. Useful when the producer process stopped actively sending
|
| 802 |
+
tensors and want to release unused memory.
|
| 803 |
+
"""
|
| 804 |
+
_lazy_init()
|
| 805 |
+
return torch._C._cuda_ipc_collect()
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
def current_stream(device: Optional[_device_t] = None) -> Stream:
|
| 809 |
+
r"""Return the currently selected :class:`Stream` for a given device.
|
| 810 |
+
|
| 811 |
+
Args:
|
| 812 |
+
device (torch.device or int, optional): selected device. Returns
|
| 813 |
+
the currently selected :class:`Stream` for the current device, given
|
| 814 |
+
by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
|
| 815 |
+
(default).
|
| 816 |
+
"""
|
| 817 |
+
_lazy_init()
|
| 818 |
+
streamdata = torch._C._cuda_getCurrentStream(
|
| 819 |
+
_get_device_index(device, optional=True)
|
| 820 |
+
)
|
| 821 |
+
return Stream(
|
| 822 |
+
stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
|
| 823 |
+
)
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
def default_stream(device: Optional[_device_t] = None) -> Stream:
|
| 827 |
+
r"""Return the default :class:`Stream` for a given device.
|
| 828 |
+
|
| 829 |
+
Args:
|
| 830 |
+
device (torch.device or int, optional): selected device. Returns
|
| 831 |
+
the default :class:`Stream` for the current device, given by
|
| 832 |
+
:func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
|
| 833 |
+
(default).
|
| 834 |
+
"""
|
| 835 |
+
_lazy_init()
|
| 836 |
+
streamdata = torch._C._cuda_getDefaultStream(
|
| 837 |
+
_get_device_index(device, optional=True)
|
| 838 |
+
)
|
| 839 |
+
return Stream(
|
| 840 |
+
stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
|
| 841 |
+
)
|
| 842 |
+
|
| 843 |
+
|
| 844 |
+
def current_blas_handle():
|
| 845 |
+
r"""Return cublasHandle_t pointer to current cuBLAS handle"""
|
| 846 |
+
_lazy_init()
|
| 847 |
+
return torch._C._cuda_getCurrentBlasHandle()
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
def set_sync_debug_mode(debug_mode: Union[int, str]) -> None:
|
| 851 |
+
r"""Set the debug mode for cuda synchronizing operations.
|
| 852 |
+
|
| 853 |
+
Args:
|
| 854 |
+
debug_mode(str or int): if "default" or 0, don't error or warn on synchronizing operations,
|
| 855 |
+
if "warn" or 1, warn on synchronizing operations, if "error" or 2, error out synchronizing operations.
|
| 856 |
+
|
| 857 |
+
Warning:
|
| 858 |
+
This is an experimental feature, and not all synchronizing operations will trigger warning or error. In
|
| 859 |
+
particular, operations in torch.distributed and torch.sparse namespaces are not covered yet.
|
| 860 |
+
"""
|
| 861 |
+
_lazy_init()
|
| 862 |
+
if isinstance(debug_mode, str):
|
| 863 |
+
if debug_mode == "default":
|
| 864 |
+
debug_mode = 0
|
| 865 |
+
elif debug_mode == "warn":
|
| 866 |
+
debug_mode = 1
|
| 867 |
+
elif debug_mode == "error":
|
| 868 |
+
debug_mode = 2
|
| 869 |
+
else:
|
| 870 |
+
raise RuntimeError(
|
| 871 |
+
"invalid value of debug_mode, expected one of `default`, `warn`, `error`"
|
| 872 |
+
)
|
| 873 |
+
|
| 874 |
+
torch._C._cuda_set_sync_debug_mode(debug_mode)
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
def get_sync_debug_mode() -> int:
|
| 878 |
+
r"""Return current value of debug mode for cuda synchronizing operations."""
|
| 879 |
+
_lazy_init()
|
| 880 |
+
return torch._C._cuda_get_sync_debug_mode()
|
| 881 |
+
|
| 882 |
+
|
| 883 |
+
def _get_pynvml_handler(device: Optional[Union[Device, int]] = None):
|
| 884 |
+
if not _HAS_PYNVML:
|
| 885 |
+
raise ModuleNotFoundError(
|
| 886 |
+
"pynvml does not seem to be installed or it can't be imported."
|
| 887 |
+
        ) from _PYNVML_ERR
    from pynvml import NVMLError_DriverNotLoaded

    try:
        pynvml.nvmlInit()
    except NVMLError_DriverNotLoaded as e:
        raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e

    device = _get_nvml_device_index(device)
    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
    return handle


def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
    r"""Return the percent of time over the past sample period during which global (device)
    memory was being read or written as given by `nvidia-smi`.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    Warning: Each sample period may be between 1 second and 1/6 second,
    depending on the product being queried.
    """
    handle = _get_pynvml_handler()

    device = _get_nvml_device_index(device)
    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
    return pynvml.nvmlDeviceGetUtilizationRates(handle).memory


def utilization(device: Optional[Union[Device, int]] = None) -> int:
    r"""Return the percent of time over the past sample period during which one or
    more kernels was executing on the GPU as given by `nvidia-smi`.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    Warning: Each sample period may be between 1 second and 1/6 second,
    depending on the product being queried.
    """
    handle = _get_pynvml_handler(device)
    device = _get_nvml_device_index(device)
    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
    return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu


def temperature(device: Optional[Union[Device, int]] = None) -> int:
    r"""Return the average temperature of the GPU sensor in Degrees C (Centigrades).

    The average temperature is computed based on past sample period as given by `nvidia-smi`.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    Warning: Each sample period may be between 1 second and 1/6 second,
    depending on the product being queried.
    """
    handle = _get_pynvml_handler(device)
    # 0 refers to the temperature sensor for the GPU die.
    return pynvml.nvmlDeviceGetTemperature(handle, 0)


def power_draw(device: Optional[Union[Device, int]] = None) -> int:
    r"""Return the average power draw of the GPU sensor in mW (MilliWatts)
        over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    Warning: Each sample period may be between 1 second and 1/6 second,
    depending on the product being queried.
    """
    handle = _get_pynvml_handler(device)
    return pynvml.nvmlDeviceGetPowerUsage(handle)


def clock_rate(device: Optional[Union[Device, int]] = None) -> int:
    r"""Return the clock speed of the GPU SM in Hz Hertz over the past sample period as given by `nvidia-smi`.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    Warning: Each sample period may be between 1 second and 1/6 second,
    depending on the product being queried.
    """
    handle = _get_pynvml_handler(device)
    return pynvml.nvmlDeviceGetClockInfo(handle, 1)


def _get_device(device: Union[int, str, torch.device]) -> torch.device:
    r"""Return the torch.device type object from the passed in device.

    Args:
        device (torch.device or int): selected device.
    """
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)
    return device


def _get_generator(device: torch.device) -> torch._C.Generator:
    r"""Return the CUDA Generator object for the given device.

    Args:
        device (torch.device): selected device.
    """
    idx = device.index
    if idx is None:
        idx = current_device()
    return torch.cuda.default_generators[idx]


def _set_rng_state_offset(
    offset: int, device: Union[int, str, torch.device] = "cuda"
) -> None:
    r"""Set the random number generator state offset of the specified GPU.

    Args:
        offset (int): The desired offset
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
    """
    final_device = _get_device(device)

    def cb():
        default_generator = _get_generator(final_device)
        default_generator.set_offset(offset)

    _lazy_call(cb)


def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int:
    r"""Return the random number generator state offset of the specified GPU.

    Args:
        device (torch.device or int, optional): The device to return the RNG state offset of.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    final_device = _get_device(device)
    default_generator = _get_generator(final_device)
    return default_generator.get_offset()


from .memory import *  # noqa: F403


from .random import *  # noqa: F403

################################################################################
# Define Storage and Tensor classes
################################################################################


@staticmethod  # type: ignore[misc]
def _lazy_new(cls, *args, **kwargs):
    _lazy_init()
    # We may need to call lazy init again if we are a forked child
    # del _CudaBase.__new__
    return super(_CudaBase, cls).__new__(cls, *args, **kwargs)


class _CudaBase:
    is_cuda = True
    is_sparse = False

    def type(self, *args, **kwargs):
        # We could use a Protocol here to tell mypy that self has `get_device` method
        # but it is only available in the typing module on Python >= 3.8
        # or on typing_extensions module on Python >= 3.6
        with device(self.get_device()):  # type: ignore[attr-defined]
            return super().type(*args, **kwargs)  # type: ignore[misc]

    __new__ = _lazy_new


from torch.storage import _LegacyStorage, _warn_typed_storage_removal


class _CudaLegacyStorage(_LegacyStorage):
    @classmethod
    def from_buffer(cls, *args, **kwargs):
        _warn_typed_storage_removal()
        raise RuntimeError("from_buffer: Not available for CUDA storage")

    @classmethod
    def _new_with_weak_ptr(cls, *args, **kwargs):
        raise RuntimeError("_new_with_weak_ptr: Not available for CUDA storage")

    @classmethod
    def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None):
        raise RuntimeError("_new_shared_filename: Not available for CUDA storage")


class ByteStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.uint8


class DoubleStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.double


class FloatStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.float


class HalfStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.half


class LongStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.long


class IntStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.int


class ShortStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.short


class CharStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.int8


class BoolStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.bool


class BFloat16Storage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.bfloat16


class ComplexDoubleStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.cdouble


class ComplexFloatStorage(_CudaLegacyStorage):
    @classproperty
    def dtype(self):
        _warn_typed_storage_removal()
        return self._dtype

    @classproperty
    def _dtype(self):
        return torch.cfloat


del _LegacyStorage
del _CudaLegacyStorage

torch._storage_classes.add(DoubleStorage)
torch._storage_classes.add(FloatStorage)
torch._storage_classes.add(LongStorage)
torch._storage_classes.add(IntStorage)
torch._storage_classes.add(ShortStorage)
torch._storage_classes.add(CharStorage)
torch._storage_classes.add(ByteStorage)
torch._storage_classes.add(HalfStorage)
torch._storage_classes.add(BoolStorage)
torch._storage_classes.add(BFloat16Storage)
torch._storage_classes.add(ComplexDoubleStorage)
torch._storage_classes.add(ComplexFloatStorage)


class _WrappedTritonKernel:
    """Just a simple wrapper to store some metadata for testing purposes."""

    def __init__(self, kernel):
        self.kernel = kernel
        self.kernel_invoked = False

    def __call__(self, *args, **kwargs):
        res = self.kernel(*args, **kwargs)
        self.kernel_invoked = True
        return res


def _register_triton_kernels():
    if torch._running_with_deploy():
        return

    @_WrappedTritonKernel
    def kernel_impl(*args, **kwargs):
        from torch.sparse._triton_ops import bsr_dense_mm

        return bsr_dense_mm(*args, skip_checks=True, **kwargs)

    @_WrappedTritonKernel
    def addmm_kernel_impl(*args, **kwargs):
        from torch.sparse._triton_ops import bsr_dense_addmm

        return bsr_dense_addmm(*args, skip_checks=True, **kwargs)

    has_triton = importlib.util.find_spec("triton") is not None
    if has_triton:
        torch._TritonLibrary.registerOp(
            "_triton_bsr_dense_mm_out",
            "_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
            kernel_impl,
            "SparseCsrCUDA",
        )

        torch._TritonLibrary.registerOp(
            "_triton_bsr_dense_addmm_out",
            (
                "_triton_bsr_dense_addmm_out(Tensor input, Tensor bsr, Tensor dense,"
                " *, Scalar beta, Scalar alpha, Tensor(a!) out) -> Tensor(a!)"
            ),
            addmm_kernel_impl,
            "SparseCsrCUDA",
        )


_lazy_call(_register_triton_kernels)


from . import amp, jiterator, nvtx, profiler, sparse

__all__ = [
    # Typed storage and tensors
    "BFloat16Storage",
    "BFloat16Tensor",
    "BoolStorage",
    "BoolTensor",
    "ByteStorage",
    "ByteTensor",
    "CharStorage",
    "CharTensor",
    "ComplexDoubleStorage",
    "ComplexFloatStorage",
    "DoubleStorage",
    "DoubleTensor",
    "FloatStorage",
    "FloatTensor",
    "HalfStorage",
    "HalfTensor",
    "IntStorage",
    "IntTensor",
    "LongStorage",
    "LongTensor",
    "ShortStorage",
    "ShortTensor",
    "CUDAGraph",
    "CudaError",
    "DeferredCudaCallError",
    "Event",
    "ExternalStream",
    "OutOfMemoryError",
    "Stream",
    "StreamContext",
    "amp",
    "caching_allocator_alloc",
    "caching_allocator_delete",
    "can_device_access_peer",
    "check_error",
    "cudaStatus",
    "cudart",
    "current_blas_handle",
    "current_device",
    "current_stream",
    "default_generators",
    "default_stream",
    "device",
    "device_count",
    "device_of",
    "empty_cache",
    "get_allocator_backend",
    "CUDAPluggableAllocator",
    "change_current_allocator",
    "get_arch_list",
    "get_device_capability",
    "get_device_name",
    "get_device_properties",
    "get_gencode_flags",
    "get_rng_state",
    "get_rng_state_all",
    "get_sync_debug_mode",
    "graph",
    "graph_pool_handle",
    "graphs",
    "has_half",
    "has_magma",
    "init",
    "initial_seed",
    "ipc_collect",
    "is_available",
    "is_bf16_supported",
    "is_current_stream_capturing",
    "is_initialized",
    "jiterator",
    "list_gpu_processes",
    "make_graphed_callables",
    "manual_seed",
    "manual_seed_all",
    "max_memory_allocated",
    "max_memory_cached",
    "max_memory_reserved",
    "mem_get_info",
    "memory",
    "memory_allocated",
    "memory_cached",
    "memory_reserved",
    "memory_snapshot",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "memory_summary",
    "memory_usage",
    "temperature",
    "power_draw",
    "clock_rate",
    "nccl",
    "nvtx",
    "profiler",
    "random",
    "reset_accumulated_memory_stats",
    "reset_max_memory_allocated",
    "reset_max_memory_cached",
    "reset_peak_memory_stats",
    "seed",
    "seed_all",
    "set_device",
    "set_per_process_memory_fraction",
    "set_rng_state",
    "set_rng_state_all",
    "set_stream",
    "set_sync_debug_mode",
    "sparse",
    "stream",
    "streams",
    "synchronize",
    "utilization",
]
|
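The monitoring helpers listed above (`utilization`, `memory_usage`, `temperature`, `power_draw`, `clock_rate`) each resolve a pynvml device handle and return one sample as reported by NVML. The snippet below is a minimal polling sketch, not part of the uploaded file: it assumes pynvml is installed, at least one CUDA device is visible, and `poll_gpu_counters` plus the 1-second interval are illustrative choices.

import time

import torch


def poll_gpu_counters(device: int = 0, samples: int = 3) -> None:
    # Each call returns a single NVML sample for the selected device.
    for _ in range(samples):
        print(
            f"util={torch.cuda.utilization(device)}% "
            f"mem_busy={torch.cuda.memory_usage(device)}% "
            f"temp={torch.cuda.temperature(device)}C "
            f"power={torch.cuda.power_draw(device)}mW "
            f"sm_clock={torch.cuda.clock_rate(device)}"
        )
        time.sleep(1.0)  # arbitrary sampling interval for the sketch


if torch.cuda.is_available():
    poll_gpu_counters()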
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/memory.cpython-311.pyc
ADDED
|
Binary file (44.3 kB).
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/_sanitizer.py
ADDED
|
@@ -0,0 +1,622 @@
r"""
This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels ran on different streams.

It stores information on accesses to tensors to determine if they are synchronized
or not. When enabled in a python program and a possible data race is detected, a
detailed warning will be printed and the program will exit.

It can be enabled either by importing this module and calling
:func:`enable_cuda_sanitizer()` or by exporting the ``TORCH_CUDA_SANITIZER``
environment variable.
"""

import enum
import functools
import inspect
import io
import logging
import sys
import textwrap
import traceback
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar

import torch
import torch.utils._cuda_trace as cuda_trace
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode


DEFAULT_STREAM_ID = 0

TK = TypeVar("TK")
TVa = TypeVar("TVa")
TVb = TypeVar("TVb")

DataPtr = int
StreamId = int
EventId = int
SeqNum = int

logger = logging.getLogger(__name__)


class AccessType(enum.Enum):
    READ = enum.auto()
    WRITE = enum.auto()

    def __str__(self):
        return "reading from" if self is AccessType.READ else "writing to"


@dataclass
class Access:
    r"""Stores information about a single access to a tensor by a kernel.

    Args:
        type: either AccessType.READ or AccessType.Write.
        seq_num: the sequential number of the kernel performing the access.
        stream: the stream id of the stream executing the kernel.
        operator: the schema of the launched kernel, which lists the
            arguments and return type.
        aliases: the arguments in the schema this access corresponds to.
        is_output: Whether the tensor was an output of the kernel.
        stack_trace: the stack summary object captured during access.
    """

    type: AccessType
    seq_num: SeqNum
    stream: StreamId
    operator: str
    aliases: List[str]
    is_output: bool
    stack_trace: traceback.StackSummary


class SynchronizationError(Exception):
    """Base class for errors detected by CUDA Sanitizer."""

    pass


class UnsynchronizedAccessError(SynchronizationError):
    """Stores information about two unsynchronized accesses to one data pointer."""

    def __init__(
        self,
        data_ptr: DataPtr,
        allocation_stack_trace: Optional[traceback.StackSummary],
        current_access: Access,
        previous_access: Access,
    ):
        self.data_ptr = data_ptr
        self.allocation_stack_trace = allocation_stack_trace
        self.current_access = current_access
        self.previous_access = previous_access

    def __str__(self):
        def format_access(access: Access):
            message.write(f"{access.operator}\n{access.type}")
            if access.aliases:
                message.write(" argument(s) " + ", ".join(access.aliases))
                if access.is_output:
                    message.write(", and to")
            if access.is_output:
                message.write(" the output")
            message.write(
                f"\nWith stack trace:\n{''.join(access.stack_trace.format())}\n"
            )

        with io.StringIO() as message:
            message.write(
                textwrap.dedent(
                    f"""\
                    ============================
                    CSAN detected a possible data race on tensor with data pointer {self.data_ptr}
                    Access by stream {self.current_access.stream} during kernel:
                    """
                )
            )
            format_access(self.current_access)

            message.write(
                f"Previous access by stream {self.previous_access.stream} during kernel:\n"
            )
            format_access(self.previous_access)

            if self.allocation_stack_trace:
                message.write(
                    "Tensor was allocated with stack trace:\n"
                    f"{''.join(self.allocation_stack_trace.format())}"
                )
            else:
                message.write("Trace for tensor allocation not found.")
            return message.getvalue()


class CUDASanitizerErrors(Exception):
    """Wrapper class for errors reported by CUDA Sanitizer."""

    def __init__(self, errors: List[SynchronizationError]):
        self.errors = errors

    def __str__(self):
        return f"detected {len(self.errors)} errors"


@dataclass
class TensorInfo:
    r"""Stores information about a single tensor and recent accesses to it.

    Args:
        allocation_stack_trace: the stack summary object captured during tensor
            allocation. Can be ``None`` if the allocation wasn't caught by CSAN.
        reads: list of read accesses to the tensor that were performed since
            the last write.
        write: the last write access to the tensor.
    """

    allocation_stack_trace: Optional[traceback.StackSummary]
    reads: List[Access] = field(default_factory=list)
    write: Optional[Access] = None


class _TensorsAccessed:
    def __init__(self):
        self.accesses: Dict[DataPtr, TensorInfo] = {}

    def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
        if data_ptr not in self.accesses:
            logger.info(
                "Found tensor with pointer: %s, but no matching tensor "
                "allocation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.create_tensor(data_ptr, None)

    def ensure_tensor_does_not_exist(self, data_ptr: DataPtr) -> None:
        if data_ptr in self.accesses:
            logger.info(
                "Found duplicate tensor allocation in the trace for tensor with "
                "pointer: %s. Assuming the trace for tensor deallocation "
                "wasn't caught and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.delete_tensor(data_ptr)

    def create_tensor(
        self, data_ptr: DataPtr, stack_trace: Optional[traceback.StackSummary]
    ) -> None:
        self.accesses[data_ptr] = TensorInfo(stack_trace)

    def delete_tensor(self, data_ptr: DataPtr) -> None:
        del self.accesses[data_ptr]

    def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool:
        return True if self.accesses[data_ptr].reads else False

    def get_allocation_stack_trace(
        self, data_ptr: DataPtr
    ) -> Optional[traceback.StackSummary]:
        return self.accesses[data_ptr].allocation_stack_trace

    def get_write(self, data_ptr: DataPtr) -> Optional[Access]:
        return self.accesses[data_ptr].write

    def get_reads(self, data_ptr: DataPtr) -> List[Access]:
        return self.accesses[data_ptr].reads

    def add_read(self, data_ptr: DataPtr, access: Access) -> None:
        self.accesses[data_ptr].reads.append(access)

    def set_write(self, data_ptr: DataPtr, access: Access) -> None:
        self.accesses[data_ptr].write = access
        self.accesses[data_ptr].reads = []


class StreamSynchronizations:
    def __init__(self):
        self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
        self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
        self.host_sync_state: Dict[StreamId, SeqNum] = {}
        self.create_stream(DEFAULT_STREAM_ID)

    def _ensure_stream_exists(self, stream: StreamId) -> None:
        if stream not in self.current_sync_states:
            logger.info(
                "Found Stream with id: %s, but no matching stream "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                stream,
            )
            self.create_stream(stream)

    def _ensure_event_exists(self, event: EventId) -> None:
        if event not in self.recorded_sync_states:
            logger.info(
                "Found Event with id: %s, but no matching event "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                event,
            )
            self.create_event(event)

    def _ensure_event_does_not_exist(self, event: EventId) -> None:
        if event in self.recorded_sync_states:
            logger.info(
                "Found duplicate event creation in the trace for event with "
                "id: %s. Assuming the trace for event deletion wasn't caught "
                "and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                event,
            )
            self.delete_event(event)

    def create_stream(self, stream: StreamId) -> None:
        if stream in self.current_sync_states:
            logger.info(
                "Found duplicate Stream creation in the trace for Stream with "
                "id: %s. PyTorch Streams are only created once, so this "
                "trace entry is ignored.",
                stream,
            )
        else:
            self.host_sync_state[stream] = 0
            self.current_sync_states[stream] = self.host_sync_state.copy()

    def create_event(self, event: EventId) -> None:
        self._ensure_event_does_not_exist(event)
        self.recorded_sync_states[event] = {}

    def delete_event(self, event: EventId) -> None:
        self._ensure_event_exists(event)
        del self.recorded_sync_states[event]

    def update_seq_num(self, stream: StreamId, seq_num: SeqNum) -> None:
        self._ensure_stream_exists(stream)
        self.current_sync_states[stream][stream] = seq_num

    def record_state(self, event: EventId, stream: StreamId) -> None:
        self._ensure_event_exists(event)
        self._ensure_stream_exists(stream)
        self.recorded_sync_states[event] = self.current_sync_states[stream].copy()

    def _state_wait_for_other(
        self, state: Dict[StreamId, SeqNum], other: Dict[StreamId, SeqNum]
    ) -> None:
        for stream, seq_num in other.items():
            state[stream] = max(state.get(stream, -1), seq_num)

    def stream_wait_for_event(self, stream: StreamId, event: EventId) -> None:
        self._ensure_stream_exists(stream)
        self._ensure_event_exists(event)
        self._state_wait_for_other(
            self.current_sync_states[stream], self.recorded_sync_states[event]
        )

    def all_streams_wait_for_event(self, event: EventId) -> None:
        self._ensure_event_exists(event)
        for stream in self.current_sync_states.keys():
            self.stream_wait_for_event(stream, event)

        self._state_wait_for_other(
            self.host_sync_state, self.recorded_sync_states[event]
        )

    def all_streams_wait_for_stream(self, stream: StreamId) -> None:
        self._ensure_stream_exists(stream)
        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.current_sync_states[stream])

        self._state_wait_for_other(
            self.host_sync_state, self.current_sync_states[stream]
        )

    def sync_all_streams(self) -> None:
        for stream, state in self.current_sync_states.items():
            self.host_sync_state[stream] = state[stream]

        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.host_sync_state)

    def is_ordered_after(
        self, current_stream: StreamId, seq_num: SeqNum, other_stream: StreamId
    ) -> bool:
        self._ensure_stream_exists(current_stream)
        self._ensure_stream_exists(other_stream)
        return seq_num <= self.current_sync_states[current_stream].get(other_stream, -1)


class EventHandler:
    """Analyzes CSAN trace for synchronization errors.

    Stores information on each stream's synchronizations with other streams as well
    as tensor accesses to determine whether a given kernel launch might cause a
    data race.
    """

    def __init__(self):
        self.tensors_accessed = _TensorsAccessed()
        self.syncs = StreamSynchronizations()
        self.seq_num: SeqNum = 0

    def _handle_kernel_launch(
        self,
        stream: StreamId,
        read_only: Set[DataPtr],
        read_write: Set[DataPtr],
        outputs: Set[DataPtr],
        operator: str,
        tensor_aliases: Dict[int, List[str]],
    ) -> List[SynchronizationError]:
        def check_conflict(
            data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access]
        ) -> None:
            if previous_access is None:
                return
            if not self.syncs.is_ordered_after(
                current_access.stream, previous_access.seq_num, previous_access.stream
            ):
                error_list.append(
                    UnsynchronizedAccessError(
                        data_ptr,
                        self.tensors_accessed.get_allocation_stack_trace(data_ptr),
                        current_access,
                        previous_access,
                    )
                )

        error_list: List[SynchronizationError] = []
        self.seq_num += 1
        self.syncs.update_seq_num(stream, self.seq_num)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # The stack trace generated in this way is in the inverse order, so it must be
        # reversed.
        stack_trace.reverse()

        for data_ptr in read_only:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.READ,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            check_conflict(
                data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
            )
            self.tensors_accessed.add_read(data_ptr, current_access)

        for data_ptr in read_write:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.WRITE,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            if self.tensors_accessed.were_there_reads_since_last_write(data_ptr):
                for previous_access in self.tensors_accessed.get_reads(data_ptr):
                    check_conflict(data_ptr, current_access, previous_access)
            else:
                check_conflict(
                    data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
                )
            self.tensors_accessed.set_write(data_ptr, current_access)

        return error_list

    def _handle_event_creation(self, event: EventId) -> None:
        self.syncs.create_event(event)

    def _handle_event_deletion(self, event: EventId) -> None:
        self.syncs.delete_event(event)

    def _handle_event_record(self, event: EventId, stream: StreamId) -> None:
        self.syncs.record_state(event, stream)

    def _handle_event_wait(self, event: EventId, stream: StreamId) -> None:
        self.syncs.stream_wait_for_event(stream, event)

    def _handle_memory_allocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_does_not_exist(data_ptr)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # The stack trace generated in this way is in the inverse order, so it must be
        # reversed.
        stack_trace.reverse()
        self.tensors_accessed.create_tensor(
            data_ptr,
            stack_trace,
        )

    def _handle_memory_deallocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_exists(data_ptr)
        self.tensors_accessed.delete_tensor(data_ptr)

    def _handle_stream_creation(self, stream: StreamId) -> None:
        self.syncs.create_stream(stream)

    def _handle_device_synchronization(self) -> None:
        self.syncs.sync_all_streams()

    def _handle_stream_synchronization(self, stream: StreamId) -> None:
        self.syncs.all_streams_wait_for_stream(stream)

    def _handle_event_synchronization(self, event: EventId) -> None:
        self.syncs.all_streams_wait_for_event(event)


def zip_by_key(a: Dict[TK, TVa], b: Dict[TK, TVb]) -> Iterator[Tuple[TK, TVa, TVb]]:
    for arg, value in a.items():
        if arg in b:
            yield arg, value, b[arg]


def zip_arguments(
    schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterator[Tuple[torch.Argument, Any]]:
    schema_args = schema.arguments[: len(args)]
    schema_kwargs = {arg.name: arg for arg in schema.arguments[len(args) :]}

    yield from zip(schema_args, args)

    for _, argument, value in zip_by_key(schema_kwargs, kwargs):
        yield (argument, value)


class ArgumentHandler:
    def __init__(self):
        self.dataptrs_read: Set[DataPtr] = set()
        self.dataptrs_written: Set[DataPtr] = set()
        self.tensor_aliases: Dict[DataPtr, List[str]] = dict()
        self.outputs: Set[DataPtr] = set()

    def _handle_argument(
        self,
        value: Any,
        is_write: bool,
        name: Optional[str] = None,
        is_output: bool = False,
    ) -> None:
        if isinstance(value, torch.Tensor) and value.is_cuda:
            data_ptr = value.data_ptr()
            if is_write:
                self.dataptrs_written.add(data_ptr)
            else:
                self.dataptrs_read.add(data_ptr)

            self.tensor_aliases.setdefault(data_ptr, [])
            if name is not None:
                self.tensor_aliases[data_ptr].append(name)
            if is_output:
                self.outputs.add(data_ptr)

    def parse_inputs(
        self,
        schema: torch.FunctionSchema,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> None:
        for argument, value in zip_arguments(schema, args, kwargs):
            is_write = argument.alias_info is not None and argument.alias_info.is_write
            pytree.tree_map_(
                functools.partial(
                    self._handle_argument, is_write=is_write, name=argument.name
                ),
                value,
            )

    def parse_outputs(self, outputs: Any) -> None:
        pytree.tree_map_(
            functools.partial(self._handle_argument, is_write=True, is_output=True),
            outputs,
        )


class CUDASanitizerDispatchMode(TorchDispatchMode):
    def __init__(self):
        self.event_handler = EventHandler()
        torch._C._activate_cuda_trace()
        cuda_trace.register_callback_for_cuda_event_creation(
            self.event_handler._handle_event_creation
        )
        cuda_trace.register_callback_for_cuda_event_deletion(
            self.event_handler._handle_event_deletion
        )
        cuda_trace.register_callback_for_cuda_event_record(
            self.event_handler._handle_event_record
        )
        cuda_trace.register_callback_for_cuda_event_wait(
            self.event_handler._handle_event_wait
        )
        cuda_trace.register_callback_for_cuda_memory_allocation(
            self.event_handler._handle_memory_allocation
        )
        cuda_trace.register_callback_for_cuda_memory_deallocation(
            self.event_handler._handle_memory_deallocation
        )
        cuda_trace.register_callback_for_cuda_stream_creation(
            self.event_handler._handle_stream_creation
        )
        cuda_trace.register_callback_for_cuda_device_synchronization(
            self.event_handler._handle_device_synchronization
        )
        cuda_trace.register_callback_for_cuda_stream_synchronization(
            self.event_handler._handle_stream_synchronization
        )
        cuda_trace.register_callback_for_cuda_event_synchronization(
            self.event_handler._handle_event_synchronization
        )

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        argument_handler = ArgumentHandler()
        argument_handler.parse_inputs(func._schema, args, kwargs)

        outputs = func(*args, **kwargs)

        argument_handler.parse_outputs(outputs)
        errors = self.event_handler._handle_kernel_launch(
            torch.cuda.current_stream().cuda_stream,
            argument_handler.dataptrs_read - argument_handler.dataptrs_written,
            argument_handler.dataptrs_written,
            argument_handler.outputs,
            func._schema,
            argument_handler.tensor_aliases,
        )
        if errors:
            for error in errors:
                print(error, file=sys.stderr)
            raise CUDASanitizerErrors(errors)

        return outputs


class CUDASanitizer:
    """Manages the lifetime of a CUDASanitizer dispatch mode object.

    The CUDASanitizer class wraps the entering/exiting functions of the dispatch mode
    context manager in the enable function/destructor, respectively. This is to
    explicitly set the lifetime of the dispatch mode object to that of the application.
    This approach was deemed more elegant than using the atexit module.
    """

    def __init__(self):
        self.dispatch = CUDASanitizerDispatchMode()
        self.enabled = False

    def enable(self):
        self.dispatch.__enter__()
        self.enabled = True

    def __del__(self):
        if self.enabled:
            self.dispatch.__exit__(None, None, None)


def enable_cuda_sanitizer():
    """Enable CUDA Sanitizer.

    The sanitizer will begin to analyze low-level CUDA calls invoked by torch functions
    for synchronization errors. All data races found will be printed to the standard
    error output along with stack traces of suspected causes. For best results, the
    sanitizer should be enabled at the very beginning of the program.
    """
    cuda_sanitizer.enable()


cuda_sanitizer = CUDASanitizer()
|
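The module docstring above describes two ways to turn the sanitizer on: importing the module and calling enable_cuda_sanitizer(), or exporting TORCH_CUDA_SANITIZER. A minimal sketch of the import-based path follows; it is not part of the uploaded file, it assumes a CUDA device is available, and the two-stream access pattern is only an illustration of the kind of unsynchronized access the sanitizer is built to flag.

import torch
import torch.cuda._sanitizer as csan

csan.enable_cuda_sanitizer()

# A write on a side stream followed by a read on the default stream with no
# synchronization in between is the pattern CSAN reports. When it detects a
# race, it prints a report to stderr and raises CUDASanitizerErrors.
a = torch.zeros(8, device="cuda")
side = torch.cuda.Stream()
with torch.cuda.stream(side):
    a.add_(1)   # write performed on the side stream
b = a * 2       # read on the default stream without waiting for `side`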
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-311.pyc
ADDED
|
Binary file (1.46 kB).
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/comm.py
ADDED
|
@@ -0,0 +1,18 @@
# The functions here have been moved to torch.nn.parallel.comm
from torch.nn.parallel.comm import (
    broadcast,
    broadcast_coalesced,
    gather,
    reduce_add,
    reduce_add_coalesced,
    scatter,
)

__all__ = [
    "broadcast",
    "broadcast_coalesced",
    "reduce_add",
    "reduce_add_coalesced",
    "scatter",
    "gather",
]
|
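Since this file only re-exports the implementations from torch.nn.parallel.comm, both import paths resolve to the same functions. A small sketch, not part of the uploaded file and assuming at least two visible CUDA devices:

import torch
from torch.cuda.comm import broadcast, gather, scatter

t = torch.arange(8, dtype=torch.float32, device="cuda:0")
copies = broadcast(t, devices=[0, 1])            # one copy of `t` per listed device
chunks = scatter(t, devices=[0, 1])              # split `t` across the devices
restored = gather(chunks, dim=0, destination=0)  # concatenate back on device 0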
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/memory.py
ADDED
|
@@ -0,0 +1,914 @@
| 1 |
+
r"""This package adds support for device memory management implemented in CUDA."""
|
| 2 |
+
|
| 3 |
+
import collections
|
| 4 |
+
import contextlib
|
| 5 |
+
import ctypes
|
| 6 |
+
import pickle
|
| 7 |
+
import sys
|
| 8 |
+
import warnings
|
| 9 |
+
from inspect import signature
|
| 10 |
+
|
| 11 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from torch import _C
|
| 15 |
+
|
| 16 |
+
from torch.types import Device
|
| 17 |
+
from .._utils import _dummy_type
|
| 18 |
+
from . import _get_device_index, _get_nvml_device_index, _lazy_init, is_initialized
|
| 19 |
+
|
| 20 |
+
from ._memory_viz import memory as _memory, segments as _segments
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
"caching_allocator_alloc",
|
| 24 |
+
"caching_allocator_delete",
|
| 25 |
+
"set_per_process_memory_fraction",
|
| 26 |
+
"empty_cache",
|
| 27 |
+
"memory_stats",
|
| 28 |
+
"memory_stats_as_nested_dict",
|
| 29 |
+
"reset_accumulated_memory_stats",
|
| 30 |
+
"reset_peak_memory_stats",
|
| 31 |
+
"reset_max_memory_allocated",
|
| 32 |
+
"reset_max_memory_cached",
|
| 33 |
+
"memory_allocated",
|
| 34 |
+
"max_memory_allocated",
|
| 35 |
+
"memory_reserved",
|
| 36 |
+
"max_memory_reserved",
|
| 37 |
+
"memory_cached",
|
| 38 |
+
"max_memory_cached",
|
| 39 |
+
"memory_snapshot",
|
| 40 |
+
"memory_summary",
|
| 41 |
+
"list_gpu_processes",
|
| 42 |
+
"mem_get_info",
|
| 43 |
+
"get_allocator_backend",
|
| 44 |
+
"CUDAPluggableAllocator",
|
| 45 |
+
"change_current_allocator",
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if not hasattr(torch._C, "_cuda_CUDAAllocator"):
|
| 50 |
+
# Define dummy base classes
|
| 51 |
+
torch._C.__dict__["_cuda_CUDAAllocator"] = _dummy_type("_cuda_CUDAAllocator")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _host_allocator():
    _lazy_init()
    return torch._C._cuda_cudaHostAllocator()


@contextlib.contextmanager
def _free_mutex():
    torch._C._cuda_lock_mutex()
    try:
        yield
    finally:
        torch._C._cuda_unlock_mutex()


def caching_allocator_alloc(size, device: Union[Device, int] = None, stream=None):
    r"""Perform a memory allocation using the CUDA memory allocator.

    Memory is allocated for a given device and a stream; this
    function is intended to be used for interoperability with other
    frameworks. Allocated memory is released through
    :func:`~torch.cuda.caching_allocator_delete`.

    Args:
        size (int): number of bytes to be allocated.
        device (torch.device or int, optional): selected device. If it is
            ``None`` the default CUDA device is used.
        stream (torch.cuda.Stream or int, optional): selected stream. If it is ``None`` then
            the default stream for the selected device is used.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    if device is None:
        device = torch.cuda.current_device()
    device = _get_device_index(device)
    if stream is None:
        stream = torch.cuda.current_stream(device)
    if isinstance(stream, torch.cuda.streams.Stream):
        stream = stream.cuda_stream
    if not isinstance(stream, int):
        raise TypeError(
            "Invalid type for stream argument, must be "
            "`torch.cuda.Stream` or `int` representing a pointer "
            "to an existing stream"
        )
    with torch.cuda.device(device):
        return torch._C._cuda_cudaCachingAllocator_raw_alloc(size, stream)


def caching_allocator_delete(mem_ptr):
    r"""Delete memory allocated using the CUDA memory allocator.

    Memory allocated with :func:`~torch.cuda.caching_allocator_alloc`
    is freed here. The associated device and stream are tracked inside
    the allocator.

    Args:
        mem_ptr (int): memory address to be freed by the allocator.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    torch._C._cuda_cudaCachingAllocator_raw_delete(mem_ptr)


def set_per_process_memory_fraction(
    fraction, device: Union[Device, int] = None
) -> None:
    r"""Set memory fraction for a process.

    The fraction limits how much of a CUDA device's memory the caching allocator
    may use for this process. The allowed amount equals the total visible memory
    multiplied by the fraction. Trying to allocate more than the allowed value in
    a process raises an out of memory error in the allocator.

    Args:
        fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction.
        device (torch.device or int, optional): selected device. If it is
            ``None`` the default CUDA device is used.

    .. note::
        In general, the total available free memory is less than the total capacity.
    """
    _lazy_init()
    if device is None:
        device = torch.cuda.current_device()
    device = _get_device_index(device)
    if not isinstance(fraction, float):
        raise TypeError("Invalid type for fraction argument, must be `float`")
    if fraction < 0 or fraction > 1:
        raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~1")

    torch._C._cuda_setMemoryFraction(fraction, device)

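The two raw-allocator entry points above are easiest to see together. Below is a minimal sketch (not from this file) of capping the process's share of device 0 and lending a buffer to an external framework; the 80% fraction and 1 MiB size are arbitrary illustration values.

```python
import torch

if torch.cuda.is_available():
    # Limit this process to roughly 80% of device 0's memory.
    torch.cuda.set_per_process_memory_fraction(0.8, device=0)

    # Borrow 1 MiB from the caching allocator, e.g. to hand to another library
    # that expects a raw device pointer, then give it back.
    ptr = torch.cuda.caching_allocator_alloc(1024 * 1024, device=0)
    try:
        pass  # use `ptr` with the external framework here
    finally:
        torch.cuda.caching_allocator_delete(ptr)
```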
def empty_cache() -> None:
|
| 151 |
+
r"""Release all unoccupied cached memory currently held by the caching
|
| 152 |
+
allocator so that those can be used in other GPU application and visible in
|
| 153 |
+
`nvidia-smi`.
|
| 154 |
+
|
| 155 |
+
.. note::
|
| 156 |
+
:func:`~torch.cuda.empty_cache` doesn't increase the amount of GPU
|
| 157 |
+
memory available for PyTorch. However, it may help reduce fragmentation
|
| 158 |
+
of GPU memory in certain cases. See :ref:`cuda-memory-management` for
|
| 159 |
+
more details about GPU memory management.
|
| 160 |
+
"""
|
| 161 |
+
if is_initialized():
|
| 162 |
+
torch._C._cuda_emptyCache()
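As a rough illustration of the note above, the following sketch (exact numbers depend on allocator state and device) shows that deleting a tensor only returns its block to the cache, and `empty_cache()` is what hands the memory back to the driver:

```python
import torch

if torch.cuda.is_available():
    x = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    print("reserved with tensor alive:", torch.cuda.memory_reserved())

    del x
    # The segment is still cached, so nvidia-smi keeps attributing it to this process.
    print("reserved after del:        ", torch.cuda.memory_reserved())

    torch.cuda.empty_cache()
    # The cached segment has been released back to CUDA; reserved usually drops here.
    print("reserved after empty_cache:", torch.cuda.memory_reserved())
```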
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
|
| 166 |
+
r"""Return a dictionary of CUDA memory allocator statistics for a given device.
|
| 167 |
+
|
| 168 |
+
The return value of this function is a dictionary of statistics, each of
|
| 169 |
+
which is a non-negative integer.
|
| 170 |
+
|
| 171 |
+
Core statistics:
|
| 172 |
+
|
| 173 |
+
- ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
| 174 |
+
number of allocation requests received by the memory allocator.
|
| 175 |
+
- ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
| 176 |
+
amount of allocated memory.
|
| 177 |
+
- ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
| 178 |
+
number of reserved segments from ``cudaMalloc()``.
|
| 179 |
+
- ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
| 180 |
+
amount of reserved memory.
|
| 181 |
+
- ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
| 182 |
+
number of active memory blocks.
|
| 183 |
+
- ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
| 184 |
+
amount of active memory.
|
| 185 |
+
- ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
| 186 |
+
number of inactive, non-releasable memory blocks.
|
| 187 |
+
- ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
| 188 |
+
amount of inactive, non-releasable memory.
|
| 189 |
+
|
| 190 |
+
For these core statistics, values are broken down as follows.
|
| 191 |
+
|
| 192 |
+
Pool type:
|
| 193 |
+
|
| 194 |
+
- ``all``: combined statistics across all memory pools.
|
| 195 |
+
- ``large_pool``: statistics for the large allocation pool
|
| 196 |
+
(as of October 2019, for size >= 1MB allocations).
|
| 197 |
+
- ``small_pool``: statistics for the small allocation pool
|
| 198 |
+
(as of October 2019, for size < 1MB allocations).
|
| 199 |
+
|
| 200 |
+
Metric type:
|
| 201 |
+
|
| 202 |
+
- ``current``: current value of this metric.
|
| 203 |
+
- ``peak``: maximum value of this metric.
|
| 204 |
+
- ``allocated``: historical total increase in this metric.
|
| 205 |
+
- ``freed``: historical total decrease in this metric.
|
| 206 |
+
|
| 207 |
+
In addition to the core statistics, we also provide some simple event
|
| 208 |
+
counters:
|
| 209 |
+
|
| 210 |
+
- ``"num_alloc_retries"``: number of failed ``cudaMalloc`` calls that
|
| 211 |
+
result in a cache flush and retry.
|
| 212 |
+
- ``"num_ooms"``: number of out-of-memory errors thrown.
|
| 213 |
+
|
| 214 |
+
The caching allocator can be configured via ENV to not split blocks larger than a
|
| 215 |
+
defined size (see Memory Management section of the Cuda Semantics documentation).
|
| 216 |
+
This helps avoid memory fragmentation but may have a performance
|
| 217 |
+
penalty. Additional outputs to assist with tuning and evaluating impact:
|
| 218 |
+
|
| 219 |
+
- ``"max_split_size"``: blocks above this size will not be split.
|
| 220 |
+
- ``"oversize_allocations.{current,peak,allocated,freed}"``:
|
| 221 |
+
number of over-size allocation requests received by the memory allocator.
|
| 222 |
+
- ``"oversize_segments.{current,peak,allocated,freed}"``:
|
| 223 |
+
number of over-size reserved segments from ``cudaMalloc()``.
|
| 224 |
+
|
| 225 |
+
The caching allocator can be configured via ENV to round memory allocations in order
|
| 226 |
+
to reduce fragmentation. Sometimes the overhead from rounding can be higher than
|
| 227 |
+
the fragmentation it helps reduce. The following stat can be used to check if
|
| 228 |
+
rounding adds too much overhead:
|
| 229 |
+
|
| 230 |
+
- ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
|
| 231 |
+
memory requested by client code, compare this with allocated_bytes to check if
|
| 232 |
+
allocation rounding adds too much overhead.
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
device (torch.device or int, optional): selected device. Returns
|
| 236 |
+
statistics for the current device, given by :func:`~torch.cuda.current_device`,
|
| 237 |
+
if :attr:`device` is ``None`` (default).
|
| 238 |
+
|
| 239 |
+
.. note::
|
| 240 |
+
See :ref:`cuda-memory-management` for more details about GPU memory
|
| 241 |
+
management.
|
| 242 |
+
|
| 243 |
+
.. note::
|
| 244 |
+
With :ref:`backend:cudaMallocAsync<cuda-memory-envvars>`, some stats are not
|
| 245 |
+
meaningful, and are always reported as zero.
|
| 246 |
+
"""
|
| 247 |
+
result = []
|
| 248 |
+
|
| 249 |
+
def _recurse_add_to_result(prefix, obj):
|
| 250 |
+
if isinstance(obj, dict):
|
| 251 |
+
if len(prefix) > 0:
|
| 252 |
+
prefix += "."
|
| 253 |
+
for k, v in obj.items():
|
| 254 |
+
_recurse_add_to_result(prefix + k, v)
|
| 255 |
+
else:
|
| 256 |
+
result.append((prefix, obj))
|
| 257 |
+
|
| 258 |
+
stats = memory_stats_as_nested_dict(device=device)
|
| 259 |
+
_recurse_add_to_result("", stats)
|
| 260 |
+
result.sort()
|
| 261 |
+
|
| 262 |
+
return collections.OrderedDict(result)
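To make the flattened key structure concrete, here is a small usage sketch; the keys follow the ``"<stat>.<pool>.<metric>"`` pattern documented above, and the values are workload-dependent.

```python
import torch

if torch.cuda.is_available():
    stats = torch.cuda.memory_stats()
    print(stats.get("allocated_bytes.all.current", 0))  # bytes currently allocated
    print(stats.get("reserved_bytes.all.peak", 0))       # peak bytes reserved via cudaMalloc
    print(stats.get("num_alloc_retries", 0))             # cache-flush-and-retry events
```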
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]:
|
| 266 |
+
r"""Return the result of :func:`~torch.cuda.memory_stats` as a nested dictionary."""
|
| 267 |
+
if not is_initialized():
|
| 268 |
+
return {}
|
| 269 |
+
device = _get_device_index(device, optional=True)
|
| 270 |
+
return torch._C._cuda_memoryStats(device)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def reset_accumulated_memory_stats(device: Union[Device, int] = None) -> None:
|
| 274 |
+
r"""Reset the "accumulated" (historical) stats tracked by the CUDA memory allocator.
|
| 275 |
+
|
| 276 |
+
See :func:`~torch.cuda.memory_stats` for details. Accumulated stats correspond to
|
| 277 |
+
the `"allocated"` and `"freed"` keys in each individual stat dict, as well as
|
| 278 |
+
`"num_alloc_retries"` and `"num_ooms"`.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
device (torch.device or int, optional): selected device. Returns
|
| 282 |
+
statistic for the current device, given by :func:`~torch.cuda.current_device`,
|
| 283 |
+
if :attr:`device` is ``None`` (default).
|
| 284 |
+
|
| 285 |
+
.. note::
|
| 286 |
+
See :ref:`cuda-memory-management` for more details about GPU memory
|
| 287 |
+
management.
|
| 288 |
+
"""
|
| 289 |
+
device = _get_device_index(device, optional=True)
|
| 290 |
+
return torch._C._cuda_resetAccumulatedMemoryStats(device)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def reset_peak_memory_stats(device: Union[Device, int] = None) -> None:
|
| 294 |
+
r"""Reset the "peak" stats tracked by the CUDA memory allocator.
|
| 295 |
+
|
| 296 |
+
See :func:`~torch.cuda.memory_stats` for details. Peak stats correspond to the
|
| 297 |
+
`"peak"` key in each individual stat dict.
|
| 298 |
+
|
| 299 |
+
Args:
|
| 300 |
+
device (torch.device or int, optional): selected device. Returns
|
| 301 |
+
statistic for the current device, given by :func:`~torch.cuda.current_device`,
|
| 302 |
+
if :attr:`device` is ``None`` (default).
|
| 303 |
+
|
| 304 |
+
.. note::
|
| 305 |
+
See :ref:`cuda-memory-management` for more details about GPU memory
|
| 306 |
+
management.
|
| 307 |
+
"""
|
| 308 |
+
device = _get_device_index(device, optional=True)
|
| 309 |
+
return torch._C._cuda_resetPeakMemoryStats(device)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def reset_max_memory_allocated(device: Union[Device, int] = None) -> None:
|
| 313 |
+
r"""Reset the starting point in tracking maximum GPU memory occupied by tensors for a given device.
|
| 314 |
+
|
| 315 |
+
See :func:`~torch.cuda.max_memory_allocated` for details.
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
device (torch.device or int, optional): selected device. Returns
|
| 319 |
+
statistic for the current device, given by :func:`~torch.cuda.current_device`,
|
| 320 |
+
if :attr:`device` is ``None`` (default).
|
| 321 |
+
|
| 322 |
+
.. warning::
|
| 323 |
+
This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
|
| 324 |
+
/all/ peak memory stats.
|
| 325 |
+
|
| 326 |
+
.. note::
|
| 327 |
+
See :ref:`cuda-memory-management` for more details about GPU memory
|
| 328 |
+
management.
|
| 329 |
+
"""
|
| 330 |
+
warnings.warn(
|
| 331 |
+
"torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, "
|
| 332 |
+
"which resets /all/ peak memory stats.",
|
| 333 |
+
FutureWarning,
|
| 334 |
+
)
|
| 335 |
+
return reset_peak_memory_stats(device=device)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def reset_max_memory_cached(device: Union[Device, int] = None) -> None:
|
| 339 |
+
r"""Reset the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
|
| 340 |
+
|
| 341 |
+
See :func:`~torch.cuda.max_memory_cached` for details.
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
device (torch.device or int, optional): selected device. Returns
|
| 345 |
+
statistic for the current device, given by :func:`~torch.cuda.current_device`,
|
| 346 |
+
if :attr:`device` is ``None`` (default).
|
| 347 |
+
|
| 348 |
+
.. warning::
|
| 349 |
+
This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
|
| 350 |
+
/all/ peak memory stats.
|
| 351 |
+
|
| 352 |
+
.. note::
|
| 353 |
+
See :ref:`cuda-memory-management` for more details about GPU memory
|
| 354 |
+
management.
|
| 355 |
+
"""
|
| 356 |
+
warnings.warn(
|
| 357 |
+
"torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, "
|
| 358 |
+
"which resets /all/ peak memory stats.",
|
| 359 |
+
FutureWarning,
|
| 360 |
+
)
|
| 361 |
+
return reset_peak_memory_stats(device=device)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def memory_allocated(device: Union[Device, int] = None) -> int:
|
| 365 |
+
r"""Return the current GPU memory occupied by tensors in bytes for a given device.
|
| 366 |
+
|
| 367 |
+
Args:
|
| 368 |
+
device (torch.device or int, optional): selected device. Returns
|
| 369 |
+
statistic for the current device, given by :func:`~torch.cuda.current_device`,
|
| 370 |
+
if :attr:`device` is ``None`` (default).
|
| 371 |
+
|
| 372 |
+
.. note::
|
| 373 |
+
This is likely less than the amount shown in `nvidia-smi` since some
|
| 374 |
+
unused memory can be held by the caching allocator and some context
|
| 375 |
+
needs to be created on GPU. See :ref:`cuda-memory-management` for more
|
| 376 |
+
details about GPU memory management.
|
| 377 |
+
"""
|
| 378 |
+
return memory_stats(device=device).get("allocated_bytes.all.current", 0)
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def max_memory_allocated(device: Union[Device, int] = None) -> int:
|
| 382 |
+
r"""Return the maximum GPU memory occupied by tensors in bytes for a given device.
|
| 383 |
+
|
| 384 |
+
By default, this returns the peak allocated memory since the beginning of
|
| 385 |
+
this program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to
|
| 386 |
+
reset the starting point in tracking this metric. For example, these two
|
| 387 |
+
functions can measure the peak allocated memory usage of each iteration in a
|
| 388 |
+
training loop.
|
| 389 |
+
|
| 390 |
+
Args:
|
| 391 |
+
device (torch.device or int, optional): selected device. Returns
|
| 392 |
+
statistic for the current device, given by :func:`~torch.cuda.current_device`,
|
| 393 |
+
if :attr:`device` is ``None`` (default).
|
| 394 |
+
|
| 395 |
+
.. note::
|
| 396 |
+
See :ref:`cuda-memory-management` for more details about GPU memory
|
| 397 |
+
management.
|
| 398 |
+
"""
|
| 399 |
+
return memory_stats(device=device).get("allocated_bytes.all.peak", 0)
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def memory_reserved(device: Union[Device, int] = None) -> int:
|
| 403 |
+
r"""Return the current GPU memory managed by the caching allocator in bytes for a given device.
|
| 404 |
+
|
| 405 |
+
Args:
|
| 406 |
+
device (torch.device or int, optional): selected device. Returns
|
| 407 |
+
statistic for the current device, given by :func:`~torch.cuda.current_device`,
|
| 408 |
+
if :attr:`device` is ``None`` (default).
|
| 409 |
+
|
| 410 |
+
.. note::
|
| 411 |
+
See :ref:`cuda-memory-management` for more details about GPU memory
|
| 412 |
+
management.
|
| 413 |
+
"""
|
| 414 |
+
return memory_stats(device=device).get("reserved_bytes.all.current", 0)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def max_memory_reserved(device: Union[Device, int] = None) -> int:
|
| 418 |
+
r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device.
|
| 419 |
+
|
| 420 |
+
By default, this returns the peak cached memory since the beginning of this
|
| 421 |
+
program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to reset
|
| 422 |
+
the starting point in tracking this metric. For example, these two functions
|
| 423 |
+
can measure the peak cached memory amount of each iteration in a training
|
| 424 |
+
loop.
|
| 425 |
+
|
| 426 |
+
Args:
|
| 427 |
+
device (torch.device or int, optional): selected device. Returns
|
| 428 |
+
statistic for the current device, given by :func:`~torch.cuda.current_device`,
|
| 429 |
+
if :attr:`device` is ``None`` (default).
|
| 430 |
+
|
| 431 |
+
.. note::
|
| 432 |
+
See :ref:`cuda-memory-management` for more details about GPU memory
|
| 433 |
+
management.
|
| 434 |
+
"""
|
| 435 |
+
return memory_stats(device=device).get("reserved_bytes.all.peak", 0)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def memory_cached(device: Union[Device, int] = None) -> int:
|
| 439 |
+
r"""Deprecated; see :func:`~torch.cuda.memory_reserved`."""
|
| 440 |
+
warnings.warn(
|
| 441 |
+
"torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved",
|
| 442 |
+
FutureWarning,
|
| 443 |
+
)
|
| 444 |
+
return memory_reserved(device=device)
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def max_memory_cached(device: Union[Device, int] = None) -> int:
|
| 448 |
+
r"""Deprecated; see :func:`~torch.cuda.max_memory_reserved`."""
|
| 449 |
+
warnings.warn(
|
| 450 |
+
"torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved",
|
| 451 |
+
FutureWarning,
|
| 452 |
+
)
|
| 453 |
+
return max_memory_reserved(device=device)
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def memory_snapshot():
|
| 457 |
+
r"""Return a snapshot of the CUDA memory allocator state across all devices.
|
| 458 |
+
|
| 459 |
+
Interpreting the output of this function requires familiarity with the
|
| 460 |
+
memory allocator internals.
|
| 461 |
+
|
| 462 |
+
.. note::
|
| 463 |
+
See :ref:`cuda-memory-management` for more details about GPU memory
|
| 464 |
+
management.
|
| 465 |
+
"""
|
| 466 |
+
return torch._C._cuda_memorySnapshot()["segments"]
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) -> str:
|
| 470 |
+
r"""Return a human-readable printout of the current memory allocator statistics for a given device.
|
| 471 |
+
|
| 472 |
+
This can be useful to display periodically during training, or when
|
| 473 |
+
handling out-of-memory exceptions.
|
| 474 |
+
|
| 475 |
+
Args:
|
| 476 |
+
device (torch.device or int, optional): selected device. Returns
|
| 477 |
+
printout for the current device, given by :func:`~torch.cuda.current_device`,
|
| 478 |
+
if :attr:`device` is ``None`` (default).
|
| 479 |
+
abbreviated (bool, optional): whether to return an abbreviated summary
|
| 480 |
+
(default: False).
|
| 481 |
+
|
| 482 |
+
.. note::
|
| 483 |
+
See :ref:`cuda-memory-management` for more details about GPU memory
|
| 484 |
+
management.
|
| 485 |
+
"""
|
| 486 |
+
device = _get_device_index(device, optional=True)
|
| 487 |
+
stats = memory_stats(device=device)
|
| 488 |
+
|
| 489 |
+
def _format_size(sz, pref_sz):
|
| 490 |
+
prefixes = ["B ", "KiB", "MiB", "GiB", "TiB", "PiB"]
|
| 491 |
+
prefix = prefixes[0]
|
| 492 |
+
for new_prefix in prefixes[1:]:
|
| 493 |
+
if pref_sz < 768 * 1024:
|
| 494 |
+
break
|
| 495 |
+
prefix = new_prefix
|
| 496 |
+
sz //= 1024
|
| 497 |
+
pref_sz /= 1024
|
| 498 |
+
return f"{sz:6d} {prefix}"
|
| 499 |
+
|
| 500 |
+
def _format_count(cnt, pref_cnt):
|
| 501 |
+
prefixes = [" ", "K", "M"]
|
| 502 |
+
prefix = prefixes[0]
|
| 503 |
+
for new_prefix in prefixes[1:]:
|
| 504 |
+
if pref_cnt < 750 * 1000:
|
| 505 |
+
break
|
| 506 |
+
prefix = new_prefix
|
| 507 |
+
cnt //= 1000
|
| 508 |
+
pref_cnt /= 1000
|
| 509 |
+
return f"{cnt:7d} {prefix} "
|
| 510 |
+
|
| 511 |
+
metrics_to_display = [
|
| 512 |
+
("allocated_bytes", "Allocated memory", _format_size),
|
| 513 |
+
("active_bytes", "Active memory", _format_size),
|
| 514 |
+
("requested_bytes", "Requested memory", _format_size),
|
| 515 |
+
("reserved_bytes", "GPU reserved memory", _format_size),
|
| 516 |
+
("inactive_split_bytes", "Non-releasable memory", _format_size),
|
| 517 |
+
("allocation", "Allocations", _format_count),
|
| 518 |
+
("active", "Active allocs", _format_count),
|
| 519 |
+
("segment", "GPU reserved segments", _format_count),
|
| 520 |
+
("inactive_split", "Non-releasable allocs", _format_count),
|
| 521 |
+
]
|
| 522 |
+
|
| 523 |
+
lines = []
|
| 524 |
+
lines.append("=" * 75)
|
| 525 |
+
lines.append(" {_:16} PyTorch CUDA memory summary, device ID {device:<17d} ")
|
| 526 |
+
lines.append("-" * 75)
|
| 527 |
+
lines.append(
|
| 528 |
+
" {_:9} CUDA OOMs: {num_ooms:<12d} | {_:6} cudaMalloc retries: {num_alloc_retries:<8d} "
|
| 529 |
+
)
|
| 530 |
+
lines.append("=" * 75)
|
| 531 |
+
lines.append(
|
| 532 |
+
" Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed "
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
for metric_key, metric_name, formatter in metrics_to_display:
|
| 536 |
+
lines.append("-" * 75)
|
| 537 |
+
submetrics = [("all", metric_name)]
|
| 538 |
+
if not abbreviated:
|
| 539 |
+
submetrics.append(("large_pool", " from large pool"))
|
| 540 |
+
submetrics.append(("small_pool", " from small pool"))
|
| 541 |
+
|
| 542 |
+
current_prefval, peak_prefval, allocated_prefval, freed_prefval = (
|
| 543 |
+
None,
|
| 544 |
+
None,
|
| 545 |
+
None,
|
| 546 |
+
None,
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
for submetric_key, submetric_name in submetrics:
|
| 550 |
+
prefix = metric_key + "." + submetric_key + "."
|
| 551 |
+
|
| 552 |
+
current = stats[prefix + "current"]
|
| 553 |
+
peak = stats[prefix + "peak"]
|
| 554 |
+
allocated = stats[prefix + "allocated"]
|
| 555 |
+
freed = stats[prefix + "freed"]
|
| 556 |
+
|
| 557 |
+
if current_prefval is None:
|
| 558 |
+
current_prefval = current
|
| 559 |
+
peak_prefval = peak
|
| 560 |
+
allocated_prefval = allocated
|
| 561 |
+
freed_prefval = freed
|
| 562 |
+
|
| 563 |
+
lines.append(
|
| 564 |
+
" {:<21} | {} | {} | {} | {} ".format(
|
| 565 |
+
submetric_name,
|
| 566 |
+
formatter(current, current_prefval),
|
| 567 |
+
formatter(peak, peak_prefval),
|
| 568 |
+
formatter(allocated, allocated_prefval),
|
| 569 |
+
formatter(freed, freed_prefval),
|
| 570 |
+
),
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
metrics_to_display = [
|
| 574 |
+
("oversize_allocations", "Oversize allocations", _format_count),
|
| 575 |
+
("oversize_segments", "Oversize GPU segments", _format_count),
|
| 576 |
+
]
|
| 577 |
+
|
| 578 |
+
for metric_key, metric_name, formatter in metrics_to_display:
|
| 579 |
+
lines.append("-" * 75)
|
| 580 |
+
|
| 581 |
+
prefix = metric_key + "."
|
| 582 |
+
|
| 583 |
+
current = stats[prefix + "current"]
|
| 584 |
+
peak = stats[prefix + "peak"]
|
| 585 |
+
allocated = stats[prefix + "allocated"]
|
| 586 |
+
freed = stats[prefix + "freed"]
|
| 587 |
+
|
| 588 |
+
lines.append(
|
| 589 |
+
" {:<21} | {} | {} | {} | {} ".format(
|
| 590 |
+
metric_name,
|
| 591 |
+
formatter(current, current),
|
| 592 |
+
formatter(peak, peak),
|
| 593 |
+
formatter(allocated, allocated),
|
| 594 |
+
formatter(freed, freed),
|
| 595 |
+
),
|
| 596 |
+
)
|
| 597 |
+
|
| 598 |
+
lines.append("=" * 75)
|
| 599 |
+
|
| 600 |
+
fmt_dict = {"_": "", "device": device}
|
| 601 |
+
for k, v in stats.items():
|
| 602 |
+
fmt_dict[k.replace(".", "-")] = v
|
| 603 |
+
return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n"
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def list_gpu_processes(device: Union[Device, int] = None) -> str:
|
| 607 |
+
r"""Return a human-readable printout of the running processes and their GPU memory use for a given device.
|
| 608 |
+
|
| 609 |
+
This can be useful to display periodically during training, or when
|
| 610 |
+
handling out-of-memory exceptions.
|
| 611 |
+
|
| 612 |
+
Args:
|
| 613 |
+
device (torch.device or int, optional): selected device. Returns
|
| 614 |
+
printout for the current device, given by :func:`~torch.cuda.current_device`,
|
| 615 |
+
if :attr:`device` is ``None`` (default).
|
| 616 |
+
"""
|
| 617 |
+
try:
|
| 618 |
+
import pynvml # type: ignore[import]
|
| 619 |
+
except ModuleNotFoundError:
|
| 620 |
+
return "pynvml module not found, please install pynvml"
|
| 621 |
+
from pynvml import NVMLError_DriverNotLoaded
|
| 622 |
+
|
| 623 |
+
try:
|
| 624 |
+
pynvml.nvmlInit()
|
| 625 |
+
except NVMLError_DriverNotLoaded:
|
| 626 |
+
return "cuda driver can't be loaded, is cuda enabled?"
|
| 627 |
+
device = _get_nvml_device_index(device)
|
| 628 |
+
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
| 629 |
+
procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
|
| 630 |
+
lines = []
|
| 631 |
+
lines.append(f"GPU:{device}")
|
| 632 |
+
if len(procs) == 0:
|
| 633 |
+
lines.append("no processes are running")
|
| 634 |
+
for p in procs:
|
| 635 |
+
mem = p.usedGpuMemory / (1024 * 1024)
|
| 636 |
+
lines.append(f"process {p.pid:>10d} uses {mem:>12.3f} MB GPU memory")
|
| 637 |
+
return "\n".join(lines)
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
|
| 641 |
+
r"""Return the global free and total GPU memory for a given device using cudaMemGetInfo.
|
| 642 |
+
|
| 643 |
+
Args:
|
| 644 |
+
device (torch.device or int, optional): selected device. Returns
|
| 645 |
+
statistic for the current device, given by :func:`~torch.cuda.current_device`,
|
| 646 |
+
if :attr:`device` is ``None`` (default).
|
| 647 |
+
|
| 648 |
+
.. note::
|
| 649 |
+
See :ref:`cuda-memory-management` for more
|
| 650 |
+
details about GPU memory management.
|
| 651 |
+
"""
|
| 652 |
+
if device is None:
|
| 653 |
+
device = torch.cuda.current_device()
|
| 654 |
+
device = _get_device_index(device)
|
| 655 |
+
return torch.cuda.cudart().cudaMemGetInfo(device)
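A small usage sketch turning the returned ``(free, total)`` pair into a device-wide utilization figure (this counts memory used by all processes, not just the current one):

```python
import torch

if torch.cuda.is_available():
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    used_frac = 1 - free_bytes / total_bytes
    print(f"device-wide memory usage: {used_frac:.1%}")
```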
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
def _record_memory_history_legacy(
    enabled: bool,
    record_context=True,
    trace_alloc_max_entries=1,
    trace_alloc_record_context=False,
    device: Union[Device, int] = None,
    record_context_cpp=False,
):
    _C._cuda_record_memory_history_legacy(
        enabled,
        record_context,
        trace_alloc_max_entries,
        trace_alloc_record_context,
        record_context_cpp,
    )


def _record_memory_history(enabled="all", *args, **kwargs):
    """Enable recording of stack traces associated with memory
    allocations, so you can tell what allocated any piece of memory in
    :func:`torch.cuda.memory._snapshot()`.

    In addition to keeping stack traces with each current allocation and free,
    this will also enable recording of a history of all alloc/free events.

    Use :func:`torch.cuda.memory._snapshot()` to retrieve this information,
    and the tools in `_memory_viz.py` to visualize snapshots.

    The Python trace collection is fast (2us per trace), so you may consider
    enabling this on production jobs if you anticipate ever having to debug
    memory issues.

    C++ trace collection is also fast (~50ns/frame), which for many typical programs
    works out to ~2us per trace, but can vary depending on stack depth.

    Args:
        enabled (Literal[None, "state", "all"], optional):
            `None`, disable recording memory history.
            `"state"`, keep information for currently allocated memory.
            `"all"`, additionally keep a history of all alloc/free calls.
            Defaults to "all".
        context (Literal[None, "state", "alloc", "all"], optional):
            `None`, Do not record any tracebacks.
            `"state"`, Record tracebacks for currently allocated memory.
            `"alloc"`, additionally keep tracebacks for alloc calls.
            `"all"`, additionally keep tracebacks for free calls.
            Defaults to "all".
        stacks (Literal["python", "all"], optional):
            `"python"`, include Python, TorchScript, and inductor frames in tracebacks
            `"all"`, additionally include C++ frames
            Defaults to "all".
        max_entries (int, optional): Keep a maximum of `max_entries`
            alloc/free events in the recorded history.
    """
    if isinstance(enabled, bool):
        return _record_memory_history_legacy(enabled, *args, **kwargs)
    else:
        return _record_memory_history_impl(enabled, *args, **kwargs)


def _record_memory_history_impl(
    enabled: Optional[str] = "all",
    context: Optional[str] = "all",
    stacks: str = "all",
    max_entries: int = sys.maxsize,
    device: Union[Device, int] = None,
):
    _C._cuda_record_memory_history(enabled, context, stacks, max_entries)


_record_memory_history.__signature__ = signature(_record_memory_history_impl)  # type: ignore[attr-defined]
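These are private (underscore-prefixed) helpers, so the interface may change between releases; under that caveat, a typical debugging session sketch is:

```python
import torch

if torch.cuda.is_available():
    # Record stack traces and a bounded history of alloc/free events.
    torch.cuda.memory._record_memory_history(max_entries=100_000)

    # ... run the workload being debugged ...
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x
    del x, y

    # Write a pickle that the pytorch.org/memory_viz viewer can load.
    torch.cuda.memory._dump_snapshot("snapshot.pickle")
    torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```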
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
def _snapshot(device: Union[Device, int] = None):
|
| 732 |
+
"""Save a snapshot of CUDA memory state at the time it was called.
|
| 733 |
+
|
| 734 |
+
The state is represented as a dictionary with the following structure.
|
| 735 |
+
|
| 736 |
+
.. code-block:: python
|
| 737 |
+
|
| 738 |
+
class Snapshot(TypedDict):
|
| 739 |
+
segments : List[Segment]
|
| 740 |
+
device_traces: List[List[TraceEntry]]
|
| 741 |
+
|
| 742 |
+
class Segment(TypedDict):
|
| 743 |
+
# Segments are memory returned from a cudaMalloc call.
|
| 744 |
+
# The size of reserved memory is the sum of all Segments.
|
| 745 |
+
# Segments are cached and reused for future allocations.
|
| 746 |
+
# If the reuse is smaller than the segment, the segment
|
| 747 |
+
# is split into more than one Block.
|
| 748 |
+
# empty_cache() frees Segments that are entirely inactive.
|
| 749 |
+
address: int
|
| 750 |
+
total_size: int # cudaMalloc'd size of segment
|
| 751 |
+
stream: int
|
| 752 |
+
segment_type: Literal['small', 'large'] # 'large' (>1MB)
|
| 753 |
+
allocated_size: int # size of memory in use
|
| 754 |
+
active_size: int # size of memory in use or in active_awaiting_free state
|
| 755 |
+
blocks : List[Block]
|
| 756 |
+
|
| 757 |
+
class Block(TypedDict):
|
| 758 |
+
# A piece of memory returned from the allocator, or
|
| 759 |
+
# currently cached but inactive.
|
| 760 |
+
size: int
|
| 761 |
+
requested_size: int # size requested during malloc, may be smaller than
|
| 762 |
+
# size due to rounding
|
| 763 |
+
address: int
|
| 764 |
+
state: Literal['active_allocated', # used by a tensor
|
| 765 |
+
'active_awaiting_free', # waiting for another stream to finish using
|
| 766 |
+
# this, then it will become free
|
| 767 |
+
'inactive',] # free for reuse
|
| 768 |
+
frames: List[Frame] # stack trace from where the allocation occurred
|
| 769 |
+
|
| 770 |
+
class Frame(TypedDict):
|
| 771 |
+
filename: str
|
| 772 |
+
line: int
|
| 773 |
+
name: str
|
| 774 |
+
|
| 775 |
+
class TraceEntry(TypedDict):
|
| 776 |
+
# When `torch.cuda.memory._record_memory_history()` is enabled,
|
| 777 |
+
# the snapshot will contain TraceEntry objects that record each
|
| 778 |
+
# action the allocator took.
|
| 779 |
+
action: Literal[
|
| 780 |
+
'alloc' # memory allocated
|
| 781 |
+
'free_requested', # the allocator received a call to free memory
|
| 782 |
+
'free_completed', # the memory that was requested to be freed is now
|
| 783 |
+
# able to be used in future allocation calls
|
| 784 |
+
'segment_alloc', # the caching allocator asked cudaMalloc for more memory
|
| 785 |
+
# and added it as a segment in its cache
|
| 786 |
+
'segment_free', # the caching allocator called cudaFree to return memory
|
| 787 |
+
# to cuda possibly trying free up memory to
|
| 788 |
+
# allocate more segments or because empty_caches was called
|
| 789 |
+
'oom', # the allocator threw an OOM exception. 'size' is
|
| 790 |
+
# the requested number of bytes that did not succeed
|
| 791 |
+
'snapshot' # the allocator generated a memory snapshot
|
| 792 |
+
# useful to correlate a previously taken
|
| 793 |
+
# snapshot with this trace
|
| 794 |
+
]
|
| 795 |
+
addr: int # not present for OOM
|
| 796 |
+
frames: List[Frame]
|
| 797 |
+
size: int
|
| 798 |
+
stream: int
|
| 799 |
+
device_free: int # only present for OOM, the amount of
|
| 800 |
+
# memory cuda still reports to be free
|
| 801 |
+
|
| 802 |
+
Returns:
|
| 803 |
+
The Snapshot dictionary object
|
| 804 |
+
"""
|
| 805 |
+
return _C._cuda_memorySnapshot()
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
def _dump_snapshot(filename="dump_snapshot.pickle"):
|
| 809 |
+
"""
|
| 810 |
+
Save a pickled version of the `torch.cuda.memory._snapshot()` dictionary to a file.
|
| 811 |
+
|
| 812 |
+
This file can be opened by the interactive snapshot viewer at pytorch.org/memory_viz
|
| 813 |
+
|
| 814 |
+
Args:
|
| 815 |
+
filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle".
|
| 816 |
+
"""
|
| 817 |
+
s = _snapshot()
|
| 818 |
+
with open(filename, "wb") as f:
|
| 819 |
+
pickle.dump(s, f)
|
| 820 |
+
|
| 821 |
+
|
| 822 |
+
def _save_segment_usage(filename="output.svg", snapshot=None):
|
| 823 |
+
if snapshot is None:
|
| 824 |
+
snapshot = _snapshot()
|
| 825 |
+
with open(filename, "w") as f:
|
| 826 |
+
f.write(_segments(snapshot))
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
def _save_memory_usage(filename="output.svg", snapshot=None):
|
| 830 |
+
if snapshot is None:
|
| 831 |
+
snapshot = _snapshot()
|
| 832 |
+
with open(filename, "w") as f:
|
| 833 |
+
f.write(_memory(snapshot))
|
| 834 |
+
|
| 835 |
+
|
| 836 |
+
def _set_allocator_settings(env: str):
|
| 837 |
+
return torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env)
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
def get_allocator_backend() -> str:
    r"""Return a string describing the active allocator backend as set by
    ``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are
    ``native`` (PyTorch's native caching allocator) and ``cudaMallocAsync``
    (CUDA's built-in asynchronous allocator).

    .. note::
        See :ref:`cuda-memory-management` for details on choosing the allocator backend.
    """
    return torch._C._cuda_getAllocatorBackend()


class _CUDAAllocator:
    r"""Wrapper over internal CUDA memory allocators."""

    def __init__(self, allocator: torch._C._cuda_CUDAAllocator):
        self._allocator = allocator

    def allocator(self):
        return self._allocator


class CUDAPluggableAllocator(_CUDAAllocator):
    r"""CUDA memory allocator loaded from a .so file."""

    def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str):
        r"""Memory allocators are compiled in .so files and loaded dynamically using ctypes.

        To change the active allocator use the :func:`torch.cuda.memory.change_current_allocator` function.

        Args:
            path_to_so_file(str): Path in the filesystem to the `.so` file containing
                the allocator functions
            alloc_fn_name(str): Name of the function to perform the memory allocation
                in the .so file. The signature must be:
                void* alloc_fn_name(ssize_t size, int device, cudaStream_t stream);
            free_fn_name(str): Name of the function to perform the memory release
                in the .so file. The signature must be:
                void free_fn_name(void* ptr, size_t size, cudaStream_t stream);

        .. warning::
            This is currently supported only on UNIX-like operating systems.

        .. note::
            See :ref:`cuda-memory-management` for details on creating and using a custom allocator
        """
        allocator = ctypes.CDLL(path_to_so_file)
        alloc_fn = ctypes.cast(getattr(allocator, alloc_fn_name), ctypes.c_void_p).value
        free_fn = ctypes.cast(getattr(allocator, free_fn_name), ctypes.c_void_p).value
        assert alloc_fn is not None
        assert free_fn is not None
        self._allocator = torch._C._cuda_customAllocator(alloc_fn, free_fn)


def change_current_allocator(allocator: _CUDAAllocator) -> None:
    r"""Change the currently used memory allocator to be the one provided.

    If the current allocator has already been used/initialized, this function will error.

    Args:
        allocator (torch.cuda.memory._CUDAAllocator): allocator to be set as the active one.

    .. note::
        See :ref:`cuda-memory-management` for details on creating and using a custom allocator
    """
    torch._C._cuda_changeCurrentAllocator(allocator.allocator())


def _get_current_allocator() -> _CUDAAllocator:
    r"""Return the allocator being currently used.

    .. note::
        See :ref:`cuda-memory-management` for details on creating and using a custom allocator
    """
    return _CUDAAllocator(torch._C._cuda_getAllocator())
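A usage sketch for the pluggable allocator API; the shared-object path and the `my_malloc`/`my_free` symbol names are placeholders, not real artifacts, and the functions must follow the C signatures listed in the docstring above.

```python
import torch

# Hypothetical library compiled separately; replace the path and symbols with yours.
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
    "./my_alloc.so", "my_malloc", "my_free"
)

# Must be called before the default caching allocator has been used/initialized.
torch.cuda.memory.change_current_allocator(new_alloc)
```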
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_compatibility.cpython-311.pyc
ADDED
|
Binary file (1.82 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-311.pyc
ADDED
|
Binary file (54.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/annotate.cpython-311.pyc
ADDED
|
Binary file (1.16 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (262 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/node.cpython-311.pyc
ADDED
|
Binary file (40.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/traceback.cpython-311.pyc
ADDED
|
Binary file (4.44 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-311.pyc
ADDED
|
Binary file (4.95 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-311.pyc
ADDED
|
Binary file (34.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-311.pyc
ADDED
|
Binary file (4.52 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-311.pyc
ADDED
|
Binary file (12.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-311.pyc
ADDED
|
Binary file (42 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-311.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (225 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/cudagraphs.py
ADDED
|
@@ -0,0 +1,56 @@
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.utils import _pytree as pytree

import operator

class CudaGraphsSupport(OperatorSupport):
    # TODO: why is submodules passed here
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        if node.op not in CALLABLE_NODE_OPS:
            return False

        if node.target in [torch.ops.aten.embedding_dense_backward.default]:
            return False

        if node.target in [operator.getitem]:
            return True

        found_not_cuda = False

        def meta_fk(meta):
            return meta["val"] if "val" in meta else meta["fake_result"]

        def find_not_cuda(t):
            nonlocal found_not_cuda
            if isinstance(t, torch.Tensor) and t.device.type != 'cuda':
                found_not_cuda = True

        for n in node.all_input_nodes:
            pytree.tree_map_(find_not_cuda, meta_fk(n.meta))

        pytree.tree_map_(find_not_cuda, meta_fk(node.meta))

        # NB: factory function is accounted for because the result would be
        # cpu or cuda

        return not found_not_cuda

def partition_cudagraphs(gm, inputs):
    """
    Partition an FX graph into sub-GraphModules that can be validly run under
    CUDA graphs. For a subgraph to be runnable under CUDA graphs, all of the
    operations must involve CUDA tensors only.
    """

    FakeTensorProp(gm).propagate(*inputs)
    supported_ops = CudaGraphsSupport()
    # TODO: single node partition may be wrong due to the pessimization
    # from copying in and out the data. Check in benchmarks, perhaps
    partitioner = CapabilityBasedPartitioner(gm, supported_ops, allows_single_node_partition=True)
    partitions = partitioner.propose_partitions()
    fused_graph = partitioner.fuse_partitions(partitions)
    return fused_graph
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-311.pyc
ADDED
|
Binary file (3.96 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (222 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-311.pyc
ADDED
|
Binary file (5.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/test_pass_manager.py
ADDED
|
@@ -0,0 +1,58 @@
import unittest

from ..pass_manager import (
    inplace_wrapper,
    PassManager,
    these_before_those_pass_constraint,
    this_before_that_pass_constraint,
)


class TestPassManager(unittest.TestCase):
    def test_pass_manager_builder(self) -> None:
        passes = [lambda x: 2 * x for _ in range(10)]
        pm = PassManager(passes)
        pm.validate()

    def test_this_before_that_pass_constraint(self) -> None:
        passes = [lambda x: 2 * x for _ in range(10)]
        pm = PassManager(passes)

        # add unfulfillable constraint
        pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0]))

        self.assertRaises(RuntimeError, pm.validate)

    def test_these_before_those_pass_constraint(self) -> None:
        passes = [lambda x: 2 * x for _ in range(10)]
        constraint = these_before_those_pass_constraint(passes[-1], passes[0])
        pm = PassManager(
            [inplace_wrapper(p) for p in passes]
        )

        # add unfulfillable constraint
        pm.add_constraint(constraint)

        self.assertRaises(RuntimeError, pm.validate)

    def test_two_pass_managers(self) -> None:
        """Make sure we can construct the PassManager twice and not share any
        state between them"""

        passes = [lambda x: 2 * x for _ in range(3)]
        constraint = these_before_those_pass_constraint(passes[0], passes[1])
        pm1 = PassManager()
        for p in passes:
            pm1.add_pass(p)
        pm1.add_constraint(constraint)
        output1 = pm1(1)
        self.assertEqual(output1, 2 ** 3)

        passes = [lambda x: 3 * x for _ in range(3)]
        constraint = these_before_those_pass_constraint(passes[0], passes[1])
        pm2 = PassManager()
        for p in passes:
            pm2.add_pass(p)
        pm2.add_constraint(constraint)
        output2 = pm2(1)
        self.assertEqual(output2, 3 ** 3)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
from .common import lift_subgraph_as_module, HolderModule, compare_graphs