koichi12 commited on
Commit
e278978
·
verified ·
1 Parent(s): d716663

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/torch/cuda/__init__.py +1661 -0
  3. .venv/lib/python3.11/site-packages/torch/cuda/__pycache__/memory.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/torch/cuda/_gpu_trace.py +75 -0
  5. .venv/lib/python3.11/site-packages/torch/cuda/_memory_viz.py +632 -0
  6. .venv/lib/python3.11/site-packages/torch/cuda/_sanitizer.py +621 -0
  7. .venv/lib/python3.11/site-packages/torch/cuda/_utils.py +38 -0
  8. .venv/lib/python3.11/site-packages/torch/cuda/comm.py +19 -0
  9. .venv/lib/python3.11/site-packages/torch/cuda/error.py +0 -0
  10. .venv/lib/python3.11/site-packages/torch/cuda/gds.py +129 -0
  11. .venv/lib/python3.11/site-packages/torch/cuda/graphs.py +491 -0
  12. .venv/lib/python3.11/site-packages/torch/cuda/jiterator.py +187 -0
  13. .venv/lib/python3.11/site-packages/torch/cuda/memory.py +1041 -0
  14. .venv/lib/python3.11/site-packages/torch/cuda/nccl.py +151 -0
  15. .venv/lib/python3.11/site-packages/torch/cuda/nvtx.py +93 -0
  16. .venv/lib/python3.11/site-packages/torch/cuda/profiler.py +86 -0
  17. .venv/lib/python3.11/site-packages/torch/cuda/random.py +182 -0
  18. .venv/lib/python3.11/site-packages/torch/cuda/sparse.py +1 -0
  19. .venv/lib/python3.11/site-packages/torch/cuda/streams.py +242 -0
  20. .venv/lib/python3.11/site-packages/torch/cuda/tunable.py +242 -0
  21. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/__init__.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/_compatibility.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/_pytree.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/_utils.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/annotate.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/config.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph_module.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/immutable_collections.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/interpreter.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/node.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/operator_schemas.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/proxy.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/tensor_type.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/torch/fx/__pycache__/traceback.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/__init__.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -127,3 +127,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
127
  .venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text
128
  .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text
129
  .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
127
  .venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text
128
  .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text
129
  .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
130
+ .venv/lib/python3.11/site-packages/torch/nn/__pycache__/functional.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/torch/cuda/__init__.py ADDED
@@ -0,0 +1,1661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ r"""
3
+ This package adds support for CUDA tensor types.
4
+
5
+ It implements the same function as CPU tensors, but they utilize
6
+ GPUs for computation.
7
+
8
+ It is lazily initialized, so you can always import it, and use
9
+ :func:`is_available()` to determine if your system supports CUDA.
10
+
11
+ :ref:`cuda-semantics` has more details about working with CUDA.
12
+ """
13
+
14
+ import importlib
15
+ import os
16
+ import threading
17
+ import traceback
18
+ import warnings
19
+ from functools import lru_cache
20
+ from typing import Any, Callable, cast, List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch._C
24
+ from torch import device as _device
25
+ from torch._utils import _dummy_type, _LazySeedTracker, classproperty
26
+ from torch.types import Device
27
+
28
+ from . import gds
29
+ from ._utils import _get_device_index
30
+ from .graphs import (
31
+ CUDAGraph,
32
+ graph,
33
+ graph_pool_handle,
34
+ is_current_stream_capturing,
35
+ make_graphed_callables,
36
+ )
37
+ from .streams import Event, ExternalStream, Stream
38
+
39
+
40
+ try:
41
+ from torch._C import _cudart # type: ignore[attr-defined]
42
+ except ImportError:
43
+ _cudart = None
44
+
45
+ _initialized = False
46
+ _tls = threading.local()
47
+ _initialization_lock = threading.Lock()
48
+ _queued_calls: List[
49
+ Tuple[Callable[[], None], List[str]]
50
+ ] = [] # don't invoke these until initialization occurs
51
+ _is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
52
+ _device_t = Union[_device, str, int, None]
53
+
54
+ _HAS_PYNVML = False
55
+ _PYNVML_ERR = None
56
+ try:
57
+ from torch import version as _version
58
+
59
+ try:
60
+ if not _version.hip:
61
+ import pynvml # type: ignore[import]
62
+ else:
63
+ import amdsmi # type: ignore[import]
64
+
65
+ _HAS_PYNVML = True
66
+ except ModuleNotFoundError:
67
+ pass
68
+ finally:
69
+ del _version
70
+ except ImportError as err:
71
+ _PYNVML_ERR = err # sometimes a lib is installed but the import fails for some other reason, so we log the error for later
72
+
73
+ _lazy_seed_tracker = _LazySeedTracker()
74
+
75
+ # Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA
76
+ if hasattr(torch._C, "_CudaDeviceProperties"):
77
+ _CudaDeviceProperties = torch._C._CudaDeviceProperties
78
+ else:
79
+ _CudaDeviceProperties = _dummy_type("_CudaDeviceProperties") # type: ignore[assignment, misc]
80
+
81
+ if hasattr(torch._C, "_cuda_exchangeDevice"):
82
+ _exchange_device = torch._C._cuda_exchangeDevice
83
+ else:
84
+
85
+ def _exchange_device(device: int) -> int:
86
+ if device < 0:
87
+ return -1
88
+ raise RuntimeError("PyTorch was compiled without CUDA support")
89
+
90
+
91
+ if hasattr(torch._C, "_cuda_maybeExchangeDevice"):
92
+ _maybe_exchange_device = torch._C._cuda_maybeExchangeDevice
93
+ else:
94
+
95
+ def _maybe_exchange_device(device: int) -> int:
96
+ if device < 0:
97
+ return -1
98
+ raise RuntimeError("PyTorch was compiled without CUDA support")
99
+
100
+
101
+ has_half: bool = True
102
+ has_magma: bool = torch._C._has_magma
103
+
104
+ default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]
105
+
106
+
107
+ def _is_compiled() -> bool:
108
+ r"""Return true if compile with CUDA support."""
109
+ return hasattr(torch._C, "_cuda_getDeviceCount")
110
+
111
+
112
+ def _nvml_based_avail() -> bool:
113
+ return os.getenv("PYTORCH_NVML_BASED_CUDA_CHECK") == "1"
114
+
115
+
116
+ def is_available() -> bool:
117
+ r"""Return a bool indicating if CUDA is currently available."""
118
+ if not _is_compiled():
119
+ return False
120
+ if _nvml_based_avail():
121
+ # The user has set an env variable to request this availability check that attempts to avoid fork poisoning by
122
+ # using NVML at the cost of a weaker CUDA availability assessment. Note that if NVML discovery/initialization
123
+ # fails, this assessment falls back to the default CUDA Runtime API assessment (`cudaGetDeviceCount`)
124
+ return device_count() > 0
125
+ else:
126
+ # The default availability inspection never throws and returns 0 if the driver is missing or can't
127
+ # be initialized. This uses the CUDA Runtime API `cudaGetDeviceCount` which in turn initializes the CUDA Driver
128
+ # API via `cuInit`
129
+ return torch._C._cuda_getDeviceCount() > 0
130
+
131
+
132
+ def is_bf16_supported(including_emulation: bool = True):
133
+ r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
134
+ # Check for ROCm, if true return true, no ROCM_VERSION check required,
135
+ # since it is supported on AMD GPU archs.
136
+ if torch.version.hip:
137
+ return True
138
+
139
+ # If CUDA is not available, than it does not support bf16 either
140
+ if not is_available():
141
+ return False
142
+
143
+ device = torch.cuda.current_device()
144
+
145
+ # Check for CUDA version and device compute capability.
146
+ # This is a fast way to check for it.
147
+ cuda_version = torch.version.cuda
148
+ if (
149
+ cuda_version is not None
150
+ and int(cuda_version.split(".")[0]) >= 11
151
+ and torch.cuda.get_device_properties(device).major >= 8
152
+ ):
153
+ return True
154
+
155
+ if not including_emulation:
156
+ return False
157
+
158
+ # Finally try to create a bfloat16 device.
159
+ return _check_bf16_tensor_supported(device)
160
+
161
+
162
+ @lru_cache(maxsize=16)
163
+ def _check_bf16_tensor_supported(device: _device_t):
164
+ try:
165
+ torch.tensor([1.0], dtype=torch.bfloat16, device=device)
166
+ return True
167
+ except Exception:
168
+ return False
169
+
170
+
171
+ def _sleep(cycles):
172
+ torch._C._cuda_sleep(cycles)
173
+
174
+
175
+ def _extract_arch_version(arch_string: str):
176
+ """Extracts the architecture string from a CUDA version"""
177
+ base = arch_string.split("_")[1]
178
+ if base.endswith("a"):
179
+ base = base[:-1]
180
+ return int(base)
181
+
182
+
183
+ def _check_capability():
184
+ incorrect_binary_warn = """
185
+ Found GPU%d %s which requires CUDA_VERSION >= %d to
186
+ work properly, but your PyTorch was compiled
187
+ with CUDA_VERSION %d. Please install the correct PyTorch binary
188
+ using instructions from https://pytorch.org
189
+ """
190
+
191
+ old_gpu_warn = """
192
+ Found GPU%d %s which is of cuda capability %d.%d.
193
+ PyTorch no longer supports this GPU because it is too old.
194
+ The minimum cuda capability supported by this library is %d.%d.
195
+ """
196
+
197
+ if torch.version.cuda is not None: # on ROCm we don't want this check
198
+ CUDA_VERSION = torch._C._cuda_getCompiledVersion()
199
+ for d in range(device_count()):
200
+ capability = get_device_capability(d)
201
+ major = capability[0]
202
+ minor = capability[1]
203
+ name = get_device_name(d)
204
+ current_arch = major * 10 + minor
205
+ min_arch = min(
206
+ (_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()),
207
+ default=35,
208
+ )
209
+ if current_arch < min_arch:
210
+ warnings.warn(
211
+ old_gpu_warn
212
+ % (d, name, major, minor, min_arch // 10, min_arch % 10)
213
+ )
214
+
215
+
216
+ def _check_cubins():
217
+ incompatible_device_warn = """
218
+ {} with CUDA capability sm_{} is not compatible with the current PyTorch installation.
219
+ The current PyTorch install supports CUDA capabilities {}.
220
+ If you want to use the {} GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/
221
+ """
222
+ if torch.version.cuda is None: # on ROCm we don't want this check
223
+ return
224
+ arch_list = get_arch_list()
225
+ if len(arch_list) == 0:
226
+ return
227
+ supported_sm = [_extract_arch_version(arch) for arch in arch_list if "sm_" in arch]
228
+ for idx in range(device_count()):
229
+ cap_major, cap_minor = get_device_capability(idx)
230
+ # NVIDIA GPU compute architectures are backward compatible within major version
231
+ supported = any(sm // 10 == cap_major for sm in supported_sm)
232
+ if not supported:
233
+ device_name = get_device_name(idx)
234
+ capability = cap_major * 10 + cap_minor
235
+ warnings.warn(
236
+ incompatible_device_warn.format(
237
+ device_name, capability, " ".join(arch_list), device_name
238
+ )
239
+ )
240
+
241
+
242
+ def is_initialized():
243
+ r"""Return whether PyTorch's CUDA state has been initialized."""
244
+ return _initialized and not _is_in_bad_fork()
245
+
246
+
247
+ def _lazy_call(callable, **kwargs):
248
+ if is_initialized():
249
+ callable()
250
+ else:
251
+ # TODO(torch_deploy): this accesses linecache, which attempts to read the
252
+ # file system to get traceback info. Patch linecache or do something
253
+ # else here if this ends up being important.
254
+ global _lazy_seed_tracker
255
+ if kwargs.get("seed_all", False):
256
+ _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
257
+ elif kwargs.get("seed", False):
258
+ _lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
259
+ else:
260
+ # Don't store the actual traceback to avoid memory cycle
261
+ _queued_calls.append((callable, traceback.format_stack()))
262
+
263
+
264
+ _lazy_call(_check_capability)
265
+ _lazy_call(_check_cubins)
266
+
267
+
268
+ class DeferredCudaCallError(Exception):
269
+ pass
270
+
271
+
272
+ OutOfMemoryError = torch._C.OutOfMemoryError
273
+
274
+
275
+ def init():
276
+ r"""Initialize PyTorch's CUDA state.
277
+
278
+ You may need to call this explicitly if you are interacting with
279
+ PyTorch via its C API, as Python bindings for CUDA functionality
280
+ will not be available until this initialization takes place.
281
+ Ordinary users should not need this, as all of PyTorch's CUDA methods
282
+ automatically initialize CUDA state on-demand.
283
+
284
+ Does nothing if the CUDA state is already initialized.
285
+ """
286
+ _lazy_init()
287
+
288
+
289
+ def _lazy_init():
290
+ global _initialized, _queued_calls
291
+ if is_initialized() or hasattr(_tls, "is_initializing"):
292
+ return
293
+ with _initialization_lock:
294
+ # We be double-checked locking, boys! This is OK because
295
+ # the above test was GIL protected anyway. The inner test
296
+ # is for when a thread blocked on some other thread which was
297
+ # doing the initialization; when they get the lock, they will
298
+ # find there is nothing left to do.
299
+ if is_initialized():
300
+ return
301
+ # It is important to prevent other threads from entering _lazy_init
302
+ # immediately, while we are still guaranteed to have the GIL, because some
303
+ # of the C calls we make below will release the GIL
304
+ if _is_in_bad_fork():
305
+ raise RuntimeError(
306
+ "Cannot re-initialize CUDA in forked subprocess. To use CUDA with "
307
+ "multiprocessing, you must use the 'spawn' start method"
308
+ )
309
+ if not hasattr(torch._C, "_cuda_getDeviceCount"):
310
+ raise AssertionError("Torch not compiled with CUDA enabled")
311
+ if _cudart is None:
312
+ raise AssertionError(
313
+ "libcudart functions unavailable. It looks like you have a broken build?"
314
+ )
315
+ # This function throws if there's a driver initialization error, no GPUs
316
+ # are found or any other error occurs
317
+ if "CUDA_MODULE_LOADING" not in os.environ:
318
+ os.environ["CUDA_MODULE_LOADING"] = "LAZY"
319
+ torch._C._cuda_init()
320
+ # Some of the queued calls may reentrantly call _lazy_init();
321
+ # we need to just return without initializing in that case.
322
+ # However, we must not let any *other* threads in!
323
+ _tls.is_initializing = True
324
+
325
+ for calls in _lazy_seed_tracker.get_calls():
326
+ if calls:
327
+ _queued_calls.append(calls)
328
+
329
+ try:
330
+ for queued_call, orig_traceback in _queued_calls:
331
+ try:
332
+ queued_call()
333
+ except Exception as e:
334
+ msg = (
335
+ f"CUDA call failed lazily at initialization with error: {str(e)}\n\n"
336
+ f"CUDA call was originally invoked at:\n\n{''.join(orig_traceback)}"
337
+ )
338
+ raise DeferredCudaCallError(msg) from e
339
+ finally:
340
+ delattr(_tls, "is_initializing")
341
+ _initialized = True
342
+
343
+
344
+ def cudart():
345
+ r"""Retrieves the CUDA runtime API module.
346
+
347
+
348
+ This function initializes the CUDA runtime environment if it is not already
349
+ initialized and returns the CUDA runtime API module (_cudart). The CUDA
350
+ runtime API module provides access to various CUDA runtime functions.
351
+
352
+ Args:
353
+ ``None``
354
+
355
+ Returns:
356
+ module: The CUDA runtime API module (_cudart).
357
+
358
+ Raises:
359
+ RuntimeError: If CUDA cannot be re-initialized in a forked subprocess.
360
+ AssertionError: If PyTorch is not compiled with CUDA support or if libcudart functions are unavailable.
361
+
362
+ Example of CUDA operations with profiling:
363
+ >>> import torch
364
+ >>> from torch.cuda import cudart, check_error
365
+ >>> import os
366
+ >>>
367
+ >>> os.environ['CUDA_PROFILE'] = '1'
368
+ >>>
369
+ >>> def perform_cuda_operations_with_streams():
370
+ >>> stream = torch.cuda.Stream()
371
+ >>> with torch.cuda.stream(stream):
372
+ >>> x = torch.randn(100, 100, device='cuda')
373
+ >>> y = torch.randn(100, 100, device='cuda')
374
+ >>> z = torch.mul(x, y)
375
+ >>> return z
376
+ >>>
377
+ >>> torch.cuda.synchronize()
378
+ >>> print("====== Start nsys profiling ======")
379
+ >>> check_error(cudart().cudaProfilerStart())
380
+ >>> with torch.autograd.profiler.emit_nvtx():
381
+ >>> result = perform_cuda_operations_with_streams()
382
+ >>> print("CUDA operations completed.")
383
+ >>> check_error(torch.cuda.cudart().cudaProfilerStop())
384
+ >>> print("====== End nsys profiling ======")
385
+
386
+ To run this example and save the profiling information, execute:
387
+ >>> $ nvprof --profile-from-start off --csv --print-summary -o trace_name.prof -f -- python cudart_test.py
388
+
389
+ This command profiles the CUDA operations in the provided script and saves
390
+ the profiling information to a file named `trace_name.prof`.
391
+ The `--profile-from-start off` option ensures that profiling starts only
392
+ after the `cudaProfilerStart` call in the script.
393
+ The `--csv` and `--print-summary` options format the profiling output as a
394
+ CSV file and print a summary, respectively.
395
+ The `-o` option specifies the output file name, and the `-f` option forces the
396
+ overwrite of the output file if it already exists.
397
+ """
398
+ _lazy_init()
399
+ return _cudart
400
+
401
+
402
+ class cudaStatus:
403
+ SUCCESS: int = 0
404
+ ERROR_NOT_READY: int = 34
405
+
406
+
407
+ class CudaError(RuntimeError):
408
+ def __init__(self, code: int) -> None:
409
+ msg = _cudart.cudaGetErrorString(_cudart.cudaError(code))
410
+ super().__init__(f"{msg} ({code})")
411
+
412
+
413
+ def check_error(res: int) -> None:
414
+ if res != _cudart.cudaError.success:
415
+ raise CudaError(res)
416
+
417
+
418
+ class _DeviceGuard:
419
+ def __init__(self, index: int):
420
+ self.idx = index
421
+ self.prev_idx = -1
422
+
423
+ def __enter__(self):
424
+ self.prev_idx = torch.cuda._exchange_device(self.idx)
425
+
426
+ def __exit__(self, type: Any, value: Any, traceback: Any):
427
+ self.idx = torch.cuda._maybe_exchange_device(self.prev_idx)
428
+ return False
429
+
430
+
431
+ class device:
432
+ r"""Context-manager that changes the selected device.
433
+
434
+ Args:
435
+ device (torch.device or int): device index to select. It's a no-op if
436
+ this argument is a negative integer or ``None``.
437
+ """
438
+
439
+ def __init__(self, device: Any):
440
+ self.idx = _get_device_index(device, optional=True)
441
+ self.prev_idx = -1
442
+
443
+ def __enter__(self):
444
+ self.prev_idx = torch.cuda._exchange_device(self.idx)
445
+
446
+ def __exit__(self, type: Any, value: Any, traceback: Any):
447
+ self.idx = torch.cuda._maybe_exchange_device(self.prev_idx)
448
+ return False
449
+
450
+
451
+ class device_of(device):
452
+ r"""Context-manager that changes the current device to that of given object.
453
+
454
+ You can use both tensors and storages as arguments. If a given object is
455
+ not allocated on a GPU, this is a no-op.
456
+
457
+ Args:
458
+ obj (Tensor or Storage): object allocated on the selected device.
459
+ """
460
+
461
+ def __init__(self, obj):
462
+ idx = obj.get_device() if obj.is_cuda else -1
463
+ super().__init__(idx)
464
+
465
+
466
+ def set_device(device: _device_t) -> None:
467
+ r"""Set the current device.
468
+
469
+ Usage of this function is discouraged in favor of :any:`device`. In most
470
+ cases it's better to use ``CUDA_VISIBLE_DEVICES`` environmental variable.
471
+
472
+ Args:
473
+ device (torch.device or int): selected device. This function is a no-op
474
+ if this argument is negative.
475
+ """
476
+ device = _get_device_index(device)
477
+ if device >= 0:
478
+ torch._C._cuda_setDevice(device)
479
+
480
+
481
+ def get_device_name(device: Optional[_device_t] = None) -> str:
482
+ r"""Get the name of a device.
483
+
484
+ Args:
485
+ device (torch.device or int or str, optional): device for which to return the
486
+ name. This function is a no-op if this argument is a negative
487
+ integer. It uses the current device, given by :func:`~torch.cuda.current_device`,
488
+ if :attr:`device` is ``None`` (default).
489
+
490
+ Returns:
491
+ str: the name of the device
492
+ """
493
+ return get_device_properties(device).name
494
+
495
+
496
+ def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]:
497
+ r"""Get the cuda capability of a device.
498
+
499
+ Args:
500
+ device (torch.device or int or str, optional): device for which to return the
501
+ device capability. This function is a no-op if this argument is
502
+ a negative integer. It uses the current device, given by
503
+ :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
504
+ (default).
505
+
506
+ Returns:
507
+ tuple(int, int): the major and minor cuda capability of the device
508
+ """
509
+ prop = get_device_properties(device)
510
+ return prop.major, prop.minor
511
+
512
+
513
+ def get_device_properties(device: _device_t) -> _CudaDeviceProperties:
514
+ r"""Get the properties of a device.
515
+
516
+ Args:
517
+ device (torch.device or int or str): device for which to return the
518
+ properties of the device.
519
+
520
+ Returns:
521
+ _CudaDeviceProperties: the properties of the device
522
+ """
523
+ _lazy_init() # will define _get_device_properties
524
+ device = _get_device_index(device, optional=True)
525
+ if device < 0 or device >= device_count():
526
+ raise AssertionError("Invalid device id")
527
+ return _get_device_properties(device) # type: ignore[name-defined]
528
+
529
+
530
+ def can_device_access_peer(device: _device_t, peer_device: _device_t) -> bool:
531
+ r"""Check if peer access between two devices is possible."""
532
+ _lazy_init()
533
+ device = _get_device_index(device, optional=True)
534
+ peer_device = _get_device_index(peer_device)
535
+ if device < 0 or device >= device_count():
536
+ raise AssertionError("Invalid device id")
537
+ if peer_device < 0 or peer_device >= device_count():
538
+ raise AssertionError("Invalid peer device id")
539
+ return torch._C._cuda_canDeviceAccessPeer(device, peer_device)
540
+
541
+
542
+ class StreamContext:
543
+ r"""Context-manager that selects a given stream.
544
+
545
+ All CUDA kernels queued within its context will be enqueued on a selected
546
+ stream.
547
+
548
+ Args:
549
+ Stream (Stream): selected stream. This manager is a no-op if it's
550
+ ``None``.
551
+ .. note:: Streams are per-device.
552
+ """
553
+ cur_stream: Optional["torch.cuda.Stream"]
554
+
555
+ def __init__(self, stream: Optional["torch.cuda.Stream"]):
556
+ self.stream = stream
557
+ self.idx = _get_device_index(None, True)
558
+ if not torch.jit.is_scripting():
559
+ if self.idx is None:
560
+ self.idx = -1
561
+
562
+ self.src_prev_stream = (
563
+ None if not torch.jit.is_scripting() else torch.cuda.default_stream(None)
564
+ )
565
+ self.dst_prev_stream = (
566
+ None if not torch.jit.is_scripting() else torch.cuda.default_stream(None)
567
+ )
568
+
569
+ def __enter__(self):
570
+ # Local cur_stream variable for type refinement
571
+ cur_stream = self.stream
572
+ # Return if stream is None or CUDA device not available
573
+ if cur_stream is None or self.idx == -1:
574
+ return
575
+ self.src_prev_stream = torch.cuda.current_stream(None)
576
+
577
+ # If the stream is not on the current device, then
578
+ # set the current stream on the device
579
+ if self.src_prev_stream.device != cur_stream.device:
580
+ with device(cur_stream.device):
581
+ self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device)
582
+ torch.cuda.set_stream(cur_stream)
583
+
584
+ def __exit__(self, type: Any, value: Any, traceback: Any):
585
+ # Local cur_stream variable for type refinement
586
+ cur_stream = self.stream
587
+ # If stream is None or no CUDA device available, return
588
+ if cur_stream is None or self.idx == -1:
589
+ return
590
+
591
+ # Reset the stream on the original device
592
+ # and destination device
593
+ if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr]
594
+ torch.cuda.set_stream(self.dst_prev_stream) # type: ignore[arg-type]
595
+ torch.cuda.set_stream(self.src_prev_stream) # type: ignore[arg-type]
596
+
597
+
598
+ def stream(stream: Optional["torch.cuda.Stream"]) -> StreamContext:
599
+ r"""Wrap around the Context-manager StreamContext that selects a given stream.
600
+
601
+ Arguments:
602
+ stream (Stream): selected stream. This manager is a no-op if it's
603
+ ``None``.
604
+ ..Note:: In eager mode stream is of type Stream class while in JIT it is
605
+ an object of the custom class ``torch.classes.cuda.Stream``.
606
+ """
607
+ return StreamContext(stream)
608
+
609
+
610
+ def _set_stream_by_id(stream_id, device_index, device_type):
611
+ r"""set stream specified by the stream id, device index and
612
+ device type
613
+
614
+ Args: stream_id (int): stream id in stream pool
615
+ device_index (int): device index in topo
616
+ device_type (int): enum device type
617
+ """
618
+ torch._C._cuda_setStream(
619
+ stream_id=stream_id,
620
+ device_index=device_index,
621
+ device_type=device_type,
622
+ )
623
+
624
+
625
+ def set_stream(stream: Stream):
626
+ r"""Set the current stream.This is a wrapper API to set the stream.
627
+ Usage of this function is discouraged in favor of the ``stream``
628
+ context manager.
629
+
630
+ Args:
631
+ stream (Stream): selected stream. This function is a no-op
632
+ if this argument is ``None``.
633
+ """
634
+ if stream is None:
635
+ return
636
+ _set_stream_by_id(
637
+ stream_id=stream.stream_id,
638
+ device_index=stream.device_index,
639
+ device_type=stream.device_type,
640
+ )
641
+
642
+
643
+ def _parse_visible_devices() -> Union[List[int], List[str]]:
644
+ r"""Parse CUDA_VISIBLE_DEVICES environment variable."""
645
+ var = os.getenv("CUDA_VISIBLE_DEVICES")
646
+
647
+ if torch.version.hip:
648
+ hip_devices = os.getenv("HIP_VISIBLE_DEVICES")
649
+ if hip_devices is not None:
650
+ var = hip_devices
651
+
652
+ if var is None:
653
+ return list(range(64))
654
+
655
+ def _strtoul(s: str) -> int:
656
+ """Return -1 or positive integer sequence string starts with."""
657
+ if not s:
658
+ return -1
659
+ for idx, c in enumerate(s):
660
+ if not (c.isdigit() or (idx == 0 and c in "+-")):
661
+ break
662
+ if idx + 1 == len(s):
663
+ idx += 1
664
+ return int(s[:idx]) if idx > 0 else -1
665
+
666
+ def parse_list_with_prefix(lst: str, prefix: str) -> List[str]:
667
+ rcs: List[str] = []
668
+ for elem in lst.split(","):
669
+ # Repeated id results in empty set
670
+ if elem in rcs:
671
+ return cast(List[str], [])
672
+ # Anything other but prefix is ignored
673
+ if not elem.startswith(prefix):
674
+ break
675
+ rcs.append(elem)
676
+ return rcs
677
+
678
+ if var.startswith("GPU-"):
679
+ return parse_list_with_prefix(var, "GPU-")
680
+ if var.startswith("MIG-"):
681
+ return parse_list_with_prefix(var, "MIG-")
682
+ # CUDA_VISIBLE_DEVICES uses something like strtoul
683
+ # which makes `1gpu2,2ampere` is equivalent to `1,2`
684
+ rc: List[int] = []
685
+ for elem in var.split(","):
686
+ x = _strtoul(elem.strip())
687
+ # Repeated ordinal results in empty set
688
+ if x in rc:
689
+ return cast(List[int], [])
690
+ # Negative value aborts the sequence
691
+ if x < 0:
692
+ break
693
+ rc.append(x)
694
+ return rc
695
+
696
+
697
+ def _raw_device_count_amdsmi() -> int:
698
+ if not _HAS_PYNVML: # If amdsmi is not available
699
+ return -1
700
+ try:
701
+ amdsmi.amdsmi_init()
702
+ except amdsmi.AmdSmiException as e:
703
+ warnings.warn(f"Can't initialize amdsmi - Error code: {e.err_code}")
704
+ return -1
705
+ socket_handles = amdsmi.amdsmi_get_processor_handles()
706
+ return len(socket_handles)
707
+
708
+
709
+ def _raw_device_count_nvml() -> int:
710
+ r"""Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed."""
711
+ from ctypes import byref, c_int, CDLL
712
+
713
+ nvml_h = CDLL("libnvidia-ml.so.1")
714
+ rc = nvml_h.nvmlInit()
715
+ if rc != 0:
716
+ warnings.warn("Can't initialize NVML")
717
+ return -1
718
+ dev_count = c_int(-1)
719
+ rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
720
+ if rc != 0:
721
+ warnings.warn("Can't get nvml device count")
722
+ return -1
723
+ del nvml_h
724
+ return dev_count.value
725
+
726
+
727
+ def _raw_device_uuid_amdsmi() -> Optional[List[str]]:
728
+ from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer
729
+
730
+ if not _HAS_PYNVML: # If amdsmi is not available
731
+ return None
732
+ try:
733
+ amdsmi.amdsmi_init()
734
+ except amdsmi.AmdSmiException:
735
+ warnings.warn("Can't initialize amdsmi")
736
+ return None
737
+ try:
738
+ socket_handles = amdsmi.amdsmi_get_processor_handles()
739
+ dev_count = len(socket_handles)
740
+ except amdsmi.AmdSmiException:
741
+ warnings.warn("Can't get amdsmi device count")
742
+ return None
743
+ uuids: List[str] = []
744
+ for idx in range(dev_count):
745
+ try:
746
+ handler = amdsmi.amdsmi_get_processor_handles()[idx]
747
+ except amdsmi.AmdSmiException:
748
+ warnings.warn("Cannot get amd device handler")
749
+ return None
750
+ try:
751
+ uuid = amdsmi.amdsmi_get_gpu_device_uuid(handler)
752
+ except amdsmi.AmdSmiException:
753
+ warnings.warn("Cannot get uuid for amd device")
754
+ return None
755
+ uuids.append(str(uuid))
756
+ return uuids
757
+
758
+
759
+ def _raw_device_uuid_nvml() -> Optional[List[str]]:
760
+ r"""Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed."""
761
+ from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer
762
+
763
+ nvml_h = CDLL("libnvidia-ml.so.1")
764
+ rc = nvml_h.nvmlInit()
765
+ if rc != 0:
766
+ warnings.warn("Can't initialize NVML")
767
+ return None
768
+ dev_count = c_int(-1)
769
+ rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
770
+ if rc != 0:
771
+ warnings.warn("Can't get nvml device count")
772
+ return None
773
+ uuids: List[str] = []
774
+ for idx in range(dev_count.value):
775
+ dev_id = c_void_p()
776
+ rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
777
+ if rc != 0:
778
+ warnings.warn("Can't get device handle")
779
+ return None
780
+ buf_len = 96
781
+ buf = create_string_buffer(buf_len)
782
+ rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
783
+ if rc != 0:
784
+ warnings.warn("Can't get device UUID")
785
+ return None
786
+ uuids.append(buf.raw.decode("ascii").strip("\0"))
787
+ del nvml_h
788
+ return uuids
789
+
790
+
791
+ def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]:
792
+ r"""Given the set of partial uuids and list of known uuids builds a set of ordinals excluding ambiguous partials IDs."""
793
+
794
+ def uuid_to_orinal(candidate: str, uuids: List[str]) -> int:
795
+ best_match = -1
796
+ for idx, uuid in enumerate(uuids):
797
+ if not uuid.startswith(candidate):
798
+ continue
799
+ # Ambiguous candidate
800
+ if best_match != -1:
801
+ return -1
802
+ best_match = idx
803
+ return best_match
804
+
805
+ rc: List[int] = []
806
+ for candidate in candidates:
807
+ idx = uuid_to_orinal(candidate, uuids)
808
+ # First invalid ordinal stops parsing
809
+ if idx < 0:
810
+ break
811
+ # Duplicates result in empty set
812
+ if idx in rc:
813
+ return cast(List[int], [])
814
+ rc.append(idx)
815
+ return rc
816
+
817
+
818
+ def _device_count_amdsmi() -> int:
819
+ visible_devices = _parse_visible_devices()
820
+ if not visible_devices:
821
+ return 0
822
+ try:
823
+ if type(visible_devices[0]) is str:
824
+ return -1
825
+ else:
826
+ raw_cnt = _raw_device_count_amdsmi()
827
+ if raw_cnt <= 0:
828
+ return raw_cnt
829
+ # Trim the list up to a maximum available device
830
+ for idx, val in enumerate(visible_devices):
831
+ if cast(int, val) >= raw_cnt:
832
+ return idx
833
+ except OSError:
834
+ return -1
835
+ except AttributeError:
836
+ return -1
837
+ return len(visible_devices)
838
+
839
+
840
+ def _device_count_nvml() -> int:
841
+ r"""Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account.
842
+
843
+ Negative value is returned if NVML discovery or initialization has failed.
844
+ """
845
+ visible_devices = _parse_visible_devices()
846
+ if not visible_devices:
847
+ return 0
848
+ try:
849
+ if type(visible_devices[0]) is str:
850
+ # Skip MIG parsing
851
+ if visible_devices[0].startswith("MIG-"):
852
+ return -1
853
+ uuids = _raw_device_uuid_nvml()
854
+ if uuids is None:
855
+ return -1
856
+ visible_devices = _transform_uuid_to_ordinals(
857
+ cast(List[str], visible_devices), uuids
858
+ )
859
+ else:
860
+ raw_cnt = _raw_device_count_nvml()
861
+ if raw_cnt <= 0:
862
+ return raw_cnt
863
+ # Trim the list up to a maximum available device
864
+ for idx, val in enumerate(visible_devices):
865
+ if cast(int, val) >= raw_cnt:
866
+ return idx
867
+ except OSError:
868
+ return -1
869
+ except AttributeError:
870
+ return -1
871
+ return len(visible_devices)
872
+
873
+
874
+ def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int:
875
+ r"""Return the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account."""
876
+ idx = _get_device_index(device, optional=True)
877
+ visible_devices = _parse_visible_devices()
878
+ if type(visible_devices[0]) is str:
879
+ uuids = _raw_device_uuid_nvml()
880
+ if uuids is None:
881
+ raise RuntimeError("Can't get device UUIDs")
882
+ visible_devices = _transform_uuid_to_ordinals(
883
+ cast(List[str], visible_devices), uuids
884
+ )
885
+ visible_devices = cast(List[int], visible_devices)
886
+ if idx < 0 or idx >= len(visible_devices):
887
+ raise RuntimeError(
888
+ f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})"
889
+ )
890
+ return visible_devices[idx]
891
+
892
+
893
+ _cached_device_count: Optional[int] = None
894
+
895
+
896
+ def device_count() -> int:
897
+ r"""Return the number of GPUs available."""
898
+ global _cached_device_count
899
+ if not _is_compiled():
900
+ return 0
901
+ if _cached_device_count is not None:
902
+ return _cached_device_count
903
+ # bypass _device_count_nvml() if rocm (not supported)
904
+ nvml_count = _device_count_amdsmi() if torch.version.hip else _device_count_nvml()
905
+ r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count
906
+ # NB: Do not cache the device count prior to CUDA initialization, because
907
+ # the number of devices can change due to changes to CUDA_VISIBLE_DEVICES
908
+ # setting prior to CUDA initialization.
909
+ if _initialized:
910
+ _cached_device_count = r
911
+ return r
912
+
913
+
914
+ def get_arch_list() -> List[str]:
915
+ r"""Return list CUDA architectures this library was compiled for."""
916
+ if not is_available():
917
+ return []
918
+ arch_flags = torch._C._cuda_getArchFlags()
919
+ if arch_flags is None:
920
+ return []
921
+ return arch_flags.split()
922
+
923
+
924
+ def get_gencode_flags() -> str:
925
+ r"""Return NVCC gencode flags this library was compiled with."""
926
+ arch_list = get_arch_list()
927
+ if len(arch_list) == 0:
928
+ return ""
929
+ arch_list_ = [arch.split("_") for arch in arch_list]
930
+ return " ".join(
931
+ [
932
+ f"-gencode compute=compute_{arch},code={kind}_{arch}"
933
+ for (kind, arch) in arch_list_
934
+ ]
935
+ )
936
+
937
+
938
+ def current_device() -> int:
939
+ r"""Return the index of a currently selected device."""
940
+ _lazy_init()
941
+ return torch._C._cuda_getDevice()
942
+
943
+
944
+ def synchronize(device: _device_t = None) -> None:
945
+ r"""Wait for all kernels in all streams on a CUDA device to complete.
946
+
947
+ Args:
948
+ device (torch.device or int, optional): device for which to synchronize.
949
+ It uses the current device, given by :func:`~torch.cuda.current_device`,
950
+ if :attr:`device` is ``None`` (default).
951
+ """
952
+ _lazy_init()
953
+ with torch.cuda.device(device):
954
+ return torch._C._cuda_synchronize()
955
+
956
+
957
+ def ipc_collect():
958
+ r"""Force collects GPU memory after it has been released by CUDA IPC.
959
+
960
+ .. note::
961
+ Checks if any sent CUDA tensors could be cleaned from the memory. Force
962
+ closes shared memory file used for reference counting if there is no
963
+ active counters. Useful when the producer process stopped actively sending
964
+ tensors and want to release unused memory.
965
+ """
966
+ _lazy_init()
967
+ return torch._C._cuda_ipc_collect()
968
+
969
+
970
+ def current_stream(device: Optional[_device_t] = None) -> Stream:
971
+ r"""Return the currently selected :class:`Stream` for a given device.
972
+
973
+ Args:
974
+ device (torch.device or int, optional): selected device. Returns
975
+ the currently selected :class:`Stream` for the current device, given
976
+ by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
977
+ (default).
978
+ """
979
+ _lazy_init()
980
+ streamdata = torch._C._cuda_getCurrentStream(
981
+ _get_device_index(device, optional=True)
982
+ )
983
+ return Stream(
984
+ stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
985
+ )
986
+
987
+
988
+ def default_stream(device: Optional[_device_t] = None) -> Stream:
989
+ r"""Return the default :class:`Stream` for a given device.
990
+
991
+ Args:
992
+ device (torch.device or int, optional): selected device. Returns
993
+ the default :class:`Stream` for the current device, given by
994
+ :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
995
+ (default).
996
+ """
997
+ _lazy_init()
998
+ streamdata = torch._C._cuda_getDefaultStream(
999
+ _get_device_index(device, optional=True)
1000
+ )
1001
+ return Stream(
1002
+ stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
1003
+ )
1004
+
1005
+
1006
+ def current_blas_handle():
1007
+ r"""Return cublasHandle_t pointer to current cuBLAS handle"""
1008
+ _lazy_init()
1009
+ return torch._C._cuda_getCurrentBlasHandle()
1010
+
1011
+
1012
+ def set_sync_debug_mode(debug_mode: Union[int, str]) -> None:
1013
+ r"""Set the debug mode for cuda synchronizing operations.
1014
+
1015
+ Args:
1016
+ debug_mode(str or int): if "default" or 0, don't error or warn on synchronizing operations,
1017
+ if "warn" or 1, warn on synchronizing operations, if "error" or 2, error out synchronizing operations.
1018
+
1019
+ Warning:
1020
+ This is an experimental feature, and not all synchronizing operations will trigger warning or error. In
1021
+ particular, operations in torch.distributed and torch.sparse namespaces are not covered yet.
1022
+ """
1023
+ _lazy_init()
1024
+ if isinstance(debug_mode, str):
1025
+ if debug_mode == "default":
1026
+ debug_mode = 0
1027
+ elif debug_mode == "warn":
1028
+ debug_mode = 1
1029
+ elif debug_mode == "error":
1030
+ debug_mode = 2
1031
+ else:
1032
+ raise RuntimeError(
1033
+ "invalid value of debug_mode, expected one of `default`, `warn`, `error`"
1034
+ )
1035
+
1036
+ torch._C._cuda_set_sync_debug_mode(debug_mode)
1037
+
1038
+
1039
+ def get_sync_debug_mode() -> int:
1040
+ r"""Return current value of debug mode for cuda synchronizing operations."""
1041
+ _lazy_init()
1042
+ return torch._C._cuda_get_sync_debug_mode()
1043
+
1044
+
1045
+ def _get_pynvml_handler(device: Optional[Union[Device, int]] = None):
1046
+ if not _HAS_PYNVML:
1047
+ raise ModuleNotFoundError(
1048
+ "pynvml does not seem to be installed or it can't be imported."
1049
+ ) from _PYNVML_ERR
1050
+ from pynvml import NVMLError_DriverNotLoaded
1051
+
1052
+ try:
1053
+ pynvml.nvmlInit()
1054
+ except NVMLError_DriverNotLoaded as e:
1055
+ raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
1056
+
1057
+ device = _get_nvml_device_index(device)
1058
+ handle = pynvml.nvmlDeviceGetHandleByIndex(device)
1059
+ return handle
1060
+
1061
+
1062
+ def _get_amdsmi_handler(device: Optional[Union[Device, int]] = None):
1063
+ if not _HAS_PYNVML:
1064
+ raise ModuleNotFoundError(
1065
+ "amdsmi does not seem to be installed or it can't be imported."
1066
+ ) from _PYNVML_ERR
1067
+ try:
1068
+ amdsmi.amdsmi_init()
1069
+ except amdsmi.AmdSmiException as e:
1070
+ raise RuntimeError(
1071
+ "amdsmi driver can't be loaded, requires >=ROCm5.6 installation"
1072
+ ) from e
1073
+ device = _get_amdsmi_device_index(device)
1074
+ handle = amdsmi.amdsmi_get_processor_handles()[device]
1075
+ return handle
1076
+
1077
+
1078
+ def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int:
1079
+ r"""Return the amdsmi index of the device, taking visible_devices into account."""
1080
+ idx = _get_device_index(device, optional=True)
1081
+ visible_devices = _parse_visible_devices()
1082
+ if type(visible_devices[0]) is str:
1083
+ raise RuntimeError("HIP_VISIBLE_DEVICES should be indices and not strings")
1084
+ idx_map = dict(enumerate(cast(List[int], visible_devices)))
1085
+ if idx not in idx_map:
1086
+ raise RuntimeError(
1087
+ f"device {idx} is not visible (HIP_VISIBLE_DEVICES={visible_devices})"
1088
+ )
1089
+ return idx_map[idx]
1090
+
1091
+
1092
+ def _get_amdsmi_memory_usage(device: Optional[Union[Device, int]] = None) -> int:
1093
+ handle = _get_amdsmi_handler()
1094
+ device = _get_amdsmi_device_index(device)
1095
+ return amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"]
1096
+
1097
+
1098
+ def _get_amdsmi_utilization(device: Optional[Union[Device, int]] = None) -> int:
1099
+ handle = _get_amdsmi_handler()
1100
+ device = _get_amdsmi_device_index(device)
1101
+ handle = amdsmi.amdsmi_get_processor_handles()[device]
1102
+ return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"]
1103
+
1104
+
1105
+ def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int:
1106
+ handle = _get_amdsmi_handler(device)
1107
+ return amdsmi.amdsmi_get_temp_metric(
1108
+ handle,
1109
+ amdsmi.AmdSmiTemperatureType.JUNCTION,
1110
+ amdsmi.AmdSmiTemperatureMetric.CURRENT,
1111
+ )
1112
+
1113
+
1114
+ def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int:
1115
+ handle = _get_amdsmi_handler(device)
1116
+ socket_power = amdsmi.amdsmi_get_power_info(handle)["average_socket_power"]
1117
+ if socket_power != "N/A":
1118
+ return socket_power
1119
+ else:
1120
+ return amdsmi.amdsmi_get_power_info(handle)["current_socket_power"]
1121
+
1122
+
1123
+ def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int:
1124
+ handle = _get_amdsmi_handler(device)
1125
+ clock_info = amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX)
1126
+ if "cur_clk" in clock_info: # ROCm 6.2 deprecation
1127
+ return clock_info["cur_clk"]
1128
+ else:
1129
+ return clock_info["clk"]
1130
+
1131
+
1132
+ def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
1133
+ r"""Return the percent of time over the past sample period during which global (device)
1134
+ memory was being read or written as given by `nvidia-smi`.
1135
+
1136
+ Args:
1137
+ device (torch.device or int, optional): selected device. Returns
1138
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
1139
+ if :attr:`device` is ``None`` (default).
1140
+
1141
+ Warning: Each sample period may be between 1 second and 1/6 second,
1142
+ depending on the product being queried.
1143
+ """
1144
+ if not torch.version.hip:
1145
+ handle = _get_pynvml_handler()
1146
+ device = _get_nvml_device_index(device)
1147
+ handle = pynvml.nvmlDeviceGetHandleByIndex(device)
1148
+ return pynvml.nvmlDeviceGetUtilizationRates(handle).memory
1149
+ else:
1150
+ return _get_amdsmi_memory_usage(device)
1151
+
1152
+
1153
+ def utilization(device: Optional[Union[Device, int]] = None) -> int:
1154
+ r"""Return the percent of time over the past sample period during which one or
1155
+ more kernels was executing on the GPU as given by `nvidia-smi`.
1156
+
1157
+ Args:
1158
+ device (torch.device or int, optional): selected device. Returns
1159
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
1160
+ if :attr:`device` is ``None`` (default).
1161
+
1162
+ Warning: Each sample period may be between 1 second and 1/6 second,
1163
+ depending on the product being queried.
1164
+ """
1165
+ if not torch.version.hip:
1166
+ handle = _get_pynvml_handler(device)
1167
+ device = _get_nvml_device_index(device)
1168
+ handle = pynvml.nvmlDeviceGetHandleByIndex(device)
1169
+ return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu
1170
+ else:
1171
+ return _get_amdsmi_utilization(device)
1172
+
1173
+
1174
+ def temperature(device: Optional[Union[Device, int]] = None) -> int:
1175
+ r"""Return the average temperature of the GPU sensor in Degrees C (Centigrades).
1176
+
1177
+ The average temperature is computed based on past sample period as given by `nvidia-smi`.
1178
+
1179
+ Args:
1180
+ device (torch.device or int, optional): selected device. Returns
1181
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
1182
+ if :attr:`device` is ``None`` (default).
1183
+
1184
+ Warning: Each sample period may be between 1 second and 1/6 second,
1185
+ depending on the product being queried.
1186
+ """
1187
+ if not torch.version.hip:
1188
+ handle = _get_pynvml_handler(device)
1189
+ # 0 refers to the temperature sensor for the GPU die.
1190
+ return pynvml.nvmlDeviceGetTemperature(handle, 0)
1191
+ else:
1192
+ return _get_amdsmi_temperature(device)
1193
+
1194
+
1195
+ def power_draw(device: Optional[Union[Device, int]] = None) -> int:
1196
+ r"""Return the average power draw of the GPU sensor in mW (MilliWatts)
1197
+ over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices.
1198
+
1199
+ Args:
1200
+ device (torch.device or int, optional): selected device. Returns
1201
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
1202
+ if :attr:`device` is ``None`` (default).
1203
+
1204
+ Warning: Each sample period may be between 1 second and 1/6 second,
1205
+ depending on the product being queried.
1206
+ """
1207
+ if not torch.version.hip:
1208
+ handle = _get_pynvml_handler(device)
1209
+ return pynvml.nvmlDeviceGetPowerUsage(handle)
1210
+ else:
1211
+ return _get_amdsmi_power_draw(device)
1212
+
1213
+
1214
+ def clock_rate(device: Optional[Union[Device, int]] = None) -> int:
1215
+ r"""Return the clock speed of the GPU SM in Hz Hertz over the past sample period as given by `nvidia-smi`.
1216
+
1217
+ Args:
1218
+ device (torch.device or int, optional): selected device. Returns
1219
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
1220
+ if :attr:`device` is ``None`` (default).
1221
+
1222
+ Warning: Each sample period may be between 1 second and 1/6 second,
1223
+ depending on the product being queried.
1224
+ """
1225
+ if not torch.version.hip:
1226
+ handle = _get_pynvml_handler(device)
1227
+ return pynvml.nvmlDeviceGetClockInfo(handle, 1)
1228
+ else:
1229
+ return _get_amdsmi_clock_rate(device)
1230
+
1231
+
1232
+ def _get_device(device: Union[int, str, torch.device]) -> torch.device:
1233
+ r"""Return the torch.device type object from the passed in device.
1234
+
1235
+ Args:
1236
+ device (torch.device or int): selected device.
1237
+ """
1238
+ if isinstance(device, str):
1239
+ device = torch.device(device)
1240
+ elif isinstance(device, int):
1241
+ device = torch.device("cuda", device)
1242
+ return device
1243
+
1244
+
1245
+ def _get_generator(device: torch.device) -> torch._C.Generator:
1246
+ r"""Return the CUDA Generator object for the given device.
1247
+
1248
+ Args:
1249
+ device (torch.device): selected device.
1250
+ """
1251
+ idx = device.index
1252
+ if idx is None:
1253
+ idx = current_device()
1254
+ return torch.cuda.default_generators[idx]
1255
+
1256
+
1257
+ def _set_rng_state_offset(
1258
+ offset: int, device: Union[int, str, torch.device] = "cuda"
1259
+ ) -> None:
1260
+ r"""Set the random number generator state offset of the specified GPU.
1261
+
1262
+ Args:
1263
+ offset (int): The desired offset
1264
+ device (torch.device or int, optional): The device to set the RNG state.
1265
+ Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
1266
+ """
1267
+ final_device = _get_device(device)
1268
+
1269
+ def cb():
1270
+ default_generator = _get_generator(final_device)
1271
+ default_generator.set_offset(offset)
1272
+
1273
+ _lazy_call(cb)
1274
+
1275
+
1276
+ def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int:
1277
+ r"""Return the random number generator state offset of the specified GPU.
1278
+
1279
+ Args:
1280
+ device (torch.device or int, optional): The device to return the RNG state offset of.
1281
+ Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
1282
+
1283
+ .. warning::
1284
+ This function eagerly initializes CUDA.
1285
+ """
1286
+ _lazy_init()
1287
+ final_device = _get_device(device)
1288
+ default_generator = _get_generator(final_device)
1289
+ return default_generator.get_offset()
1290
+
1291
+
1292
+ from .memory import * # noqa: F403
1293
+ from .random import * # noqa: F403
1294
+
1295
+
1296
+ ################################################################################
1297
+ # Define Storage and Tensor classes
1298
+ ################################################################################
1299
+
1300
+
1301
+ @staticmethod # type: ignore[misc]
1302
+ def _lazy_new(cls, *args, **kwargs):
1303
+ _lazy_init()
1304
+ # We may need to call lazy init again if we are a forked child
1305
+ # del _CudaBase.__new__
1306
+ return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
1307
+
1308
+
1309
+ class _CudaBase:
1310
+ is_cuda = True
1311
+ is_sparse = False
1312
+
1313
+ def type(self, *args, **kwargs):
1314
+ # We could use a Protocol here to tell mypy that self has `get_device` method
1315
+ # but it is only available in the typing module on Python >= 3.8
1316
+ # or on typing_extensions module on Python >= 3.6
1317
+ with device(self.get_device()): # type: ignore[attr-defined]
1318
+ return super().type(*args, **kwargs) # type: ignore[misc]
1319
+
1320
+ __new__ = _lazy_new
1321
+
1322
+
1323
+ from torch.storage import _LegacyStorage, _warn_typed_storage_removal
1324
+
1325
+
1326
+ class _CudaLegacyStorage(_LegacyStorage):
1327
+ @classmethod
1328
+ def from_buffer(cls, *args, **kwargs):
1329
+ _warn_typed_storage_removal()
1330
+ raise RuntimeError("from_buffer: Not available for CUDA storage")
1331
+
1332
+ @classmethod
1333
+ def _new_with_weak_ptr(cls, *args, **kwargs):
1334
+ raise RuntimeError("_new_with_weak_ptr: Not available for CUDA storage")
1335
+
1336
+ @classmethod
1337
+ def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None):
1338
+ raise RuntimeError("_new_shared_filename: Not available for CUDA storage")
1339
+
1340
+
1341
+ class ByteStorage(_CudaLegacyStorage):
1342
+ @classproperty
1343
+ def dtype(self):
1344
+ _warn_typed_storage_removal()
1345
+ return self._dtype
1346
+
1347
+ @classproperty
1348
+ def _dtype(self):
1349
+ return torch.uint8
1350
+
1351
+
1352
+ class DoubleStorage(_CudaLegacyStorage):
1353
+ @classproperty
1354
+ def dtype(self):
1355
+ _warn_typed_storage_removal()
1356
+ return self._dtype
1357
+
1358
+ @classproperty
1359
+ def _dtype(self):
1360
+ return torch.double
1361
+
1362
+
1363
+ class FloatStorage(_CudaLegacyStorage):
1364
+ @classproperty
1365
+ def dtype(self):
1366
+ _warn_typed_storage_removal()
1367
+ return self._dtype
1368
+
1369
+ @classproperty
1370
+ def _dtype(self):
1371
+ return torch.float
1372
+
1373
+
1374
+ class HalfStorage(_CudaLegacyStorage):
1375
+ @classproperty
1376
+ def dtype(self):
1377
+ _warn_typed_storage_removal()
1378
+ return self._dtype
1379
+
1380
+ @classproperty
1381
+ def _dtype(self):
1382
+ return torch.half
1383
+
1384
+
1385
+ class LongStorage(_CudaLegacyStorage):
1386
+ @classproperty
1387
+ def dtype(self):
1388
+ _warn_typed_storage_removal()
1389
+ return self._dtype
1390
+
1391
+ @classproperty
1392
+ def _dtype(self):
1393
+ return torch.long
1394
+
1395
+
1396
+ class IntStorage(_CudaLegacyStorage):
1397
+ @classproperty
1398
+ def dtype(self):
1399
+ _warn_typed_storage_removal()
1400
+ return self._dtype
1401
+
1402
+ @classproperty
1403
+ def _dtype(self):
1404
+ return torch.int
1405
+
1406
+
1407
+ class ShortStorage(_CudaLegacyStorage):
1408
+ @classproperty
1409
+ def dtype(self):
1410
+ _warn_typed_storage_removal()
1411
+ return self._dtype
1412
+
1413
+ @classproperty
1414
+ def _dtype(self):
1415
+ return torch.short
1416
+
1417
+
1418
+ class CharStorage(_CudaLegacyStorage):
1419
+ @classproperty
1420
+ def dtype(self):
1421
+ _warn_typed_storage_removal()
1422
+ return self._dtype
1423
+
1424
+ @classproperty
1425
+ def _dtype(self):
1426
+ return torch.int8
1427
+
1428
+
1429
+ class BoolStorage(_CudaLegacyStorage):
1430
+ @classproperty
1431
+ def dtype(self):
1432
+ _warn_typed_storage_removal()
1433
+ return self._dtype
1434
+
1435
+ @classproperty
1436
+ def _dtype(self):
1437
+ return torch.bool
1438
+
1439
+
1440
+ class BFloat16Storage(_CudaLegacyStorage):
1441
+ @classproperty
1442
+ def dtype(self):
1443
+ _warn_typed_storage_removal()
1444
+ return self._dtype
1445
+
1446
+ @classproperty
1447
+ def _dtype(self):
1448
+ return torch.bfloat16
1449
+
1450
+
1451
+ class ComplexDoubleStorage(_CudaLegacyStorage):
1452
+ @classproperty
1453
+ def dtype(self):
1454
+ _warn_typed_storage_removal()
1455
+ return self._dtype
1456
+
1457
+ @classproperty
1458
+ def _dtype(self):
1459
+ return torch.cdouble
1460
+
1461
+
1462
+ class ComplexFloatStorage(_CudaLegacyStorage):
1463
+ @classproperty
1464
+ def dtype(self):
1465
+ _warn_typed_storage_removal()
1466
+ return self._dtype
1467
+
1468
+ @classproperty
1469
+ def _dtype(self):
1470
+ return torch.cfloat
1471
+
1472
+
1473
+ del _LegacyStorage
1474
+ del _CudaLegacyStorage
1475
+
1476
+ torch._storage_classes.add(DoubleStorage)
1477
+ torch._storage_classes.add(FloatStorage)
1478
+ torch._storage_classes.add(LongStorage)
1479
+ torch._storage_classes.add(IntStorage)
1480
+ torch._storage_classes.add(ShortStorage)
1481
+ torch._storage_classes.add(CharStorage)
1482
+ torch._storage_classes.add(ByteStorage)
1483
+ torch._storage_classes.add(HalfStorage)
1484
+ torch._storage_classes.add(BoolStorage)
1485
+ torch._storage_classes.add(BFloat16Storage)
1486
+ torch._storage_classes.add(ComplexDoubleStorage)
1487
+ torch._storage_classes.add(ComplexFloatStorage)
1488
+
1489
+
1490
+ class _WrappedTritonKernel:
1491
+ """Just a simple wrapper to store some metadata for testing purposes."""
1492
+
1493
+ def __init__(self, kernel):
1494
+ self.kernel = kernel
1495
+ self.kernel_invoked = False
1496
+
1497
+ def __call__(self, *args, **kwargs):
1498
+ res = self.kernel(*args, **kwargs)
1499
+ self.kernel_invoked = True
1500
+ return res
1501
+
1502
+
1503
+ def _register_triton_kernels():
1504
+ if torch._running_with_deploy():
1505
+ return
1506
+
1507
+ @_WrappedTritonKernel
1508
+ def kernel_impl(*args, **kwargs):
1509
+ from torch.sparse._triton_ops import bsr_dense_mm
1510
+
1511
+ return bsr_dense_mm(*args, skip_checks=True, **kwargs)
1512
+
1513
+ @_WrappedTritonKernel
1514
+ def addmm_kernel_impl(*args, **kwargs):
1515
+ from torch.sparse._triton_ops import bsr_dense_addmm
1516
+
1517
+ return bsr_dense_addmm(*args, skip_checks=True, **kwargs)
1518
+
1519
+ has_triton = importlib.util.find_spec("triton") is not None
1520
+ if has_triton:
1521
+ torch._TritonLibrary.registerOp(
1522
+ "_triton_bsr_dense_mm_out",
1523
+ "_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
1524
+ kernel_impl,
1525
+ "SparseCsrCUDA",
1526
+ )
1527
+
1528
+ torch._TritonLibrary.registerOp(
1529
+ "_triton_bsr_dense_addmm_out",
1530
+ (
1531
+ "_triton_bsr_dense_addmm_out(Tensor input, Tensor bsr, Tensor dense,"
1532
+ " *, Scalar beta, Scalar alpha, Tensor(a!) out) -> Tensor(a!)"
1533
+ ),
1534
+ addmm_kernel_impl,
1535
+ "SparseCsrCUDA",
1536
+ )
1537
+
1538
+
1539
+ _lazy_call(_register_triton_kernels)
1540
+
1541
+
1542
+ from . import amp, jiterator, nvtx, profiler, sparse, tunable
1543
+
1544
+
1545
+ __all__ = [
1546
+ # Typed storage and tensors
1547
+ "BFloat16Storage",
1548
+ "BFloat16Tensor",
1549
+ "BoolStorage",
1550
+ "BoolTensor",
1551
+ "ByteStorage",
1552
+ "ByteTensor",
1553
+ "CharStorage",
1554
+ "CharTensor",
1555
+ "ComplexDoubleStorage",
1556
+ "ComplexFloatStorage",
1557
+ "DoubleStorage",
1558
+ "DoubleTensor",
1559
+ "FloatStorage",
1560
+ "FloatTensor",
1561
+ "HalfStorage",
1562
+ "HalfTensor",
1563
+ "IntStorage",
1564
+ "IntTensor",
1565
+ "LongStorage",
1566
+ "LongTensor",
1567
+ "ShortStorage",
1568
+ "ShortTensor",
1569
+ "CUDAGraph",
1570
+ "CudaError",
1571
+ "DeferredCudaCallError",
1572
+ "Event",
1573
+ "ExternalStream",
1574
+ "Stream",
1575
+ "StreamContext",
1576
+ "amp",
1577
+ "caching_allocator_alloc",
1578
+ "caching_allocator_delete",
1579
+ "can_device_access_peer",
1580
+ "check_error",
1581
+ "cudaStatus",
1582
+ "cudart",
1583
+ "current_blas_handle",
1584
+ "current_device",
1585
+ "current_stream",
1586
+ "default_generators",
1587
+ "default_stream",
1588
+ "device",
1589
+ "device_count",
1590
+ "device_of",
1591
+ "empty_cache",
1592
+ "get_allocator_backend",
1593
+ "CUDAPluggableAllocator",
1594
+ "change_current_allocator",
1595
+ "get_arch_list",
1596
+ "get_device_capability",
1597
+ "get_device_name",
1598
+ "get_device_properties",
1599
+ "get_gencode_flags",
1600
+ "get_rng_state",
1601
+ "get_rng_state_all",
1602
+ "get_sync_debug_mode",
1603
+ "graph",
1604
+ "graph_pool_handle",
1605
+ "graphs",
1606
+ "has_half",
1607
+ "has_magma",
1608
+ "init",
1609
+ "initial_seed",
1610
+ "ipc_collect",
1611
+ "is_available",
1612
+ "is_bf16_supported",
1613
+ "is_current_stream_capturing",
1614
+ "is_initialized",
1615
+ "jiterator",
1616
+ "list_gpu_processes",
1617
+ "make_graphed_callables",
1618
+ "manual_seed",
1619
+ "manual_seed_all",
1620
+ "max_memory_allocated",
1621
+ "max_memory_cached",
1622
+ "max_memory_reserved",
1623
+ "mem_get_info",
1624
+ "memory",
1625
+ "memory_allocated",
1626
+ "memory_cached",
1627
+ "memory_reserved",
1628
+ "memory_snapshot",
1629
+ "memory_stats",
1630
+ "memory_stats_as_nested_dict",
1631
+ "memory_summary",
1632
+ "memory_usage",
1633
+ "MemPool",
1634
+ "MemPoolContext",
1635
+ "use_mem_pool",
1636
+ "temperature",
1637
+ "power_draw",
1638
+ "clock_rate",
1639
+ "nccl",
1640
+ "nvtx",
1641
+ "profiler",
1642
+ "random",
1643
+ "reset_accumulated_memory_stats",
1644
+ "reset_max_memory_allocated",
1645
+ "reset_max_memory_cached",
1646
+ "reset_peak_memory_stats",
1647
+ "seed",
1648
+ "seed_all",
1649
+ "set_device",
1650
+ "set_per_process_memory_fraction",
1651
+ "set_rng_state",
1652
+ "set_rng_state_all",
1653
+ "set_stream",
1654
+ "set_sync_debug_mode",
1655
+ "sparse",
1656
+ "stream",
1657
+ "streams",
1658
+ "synchronize",
1659
+ "tunable",
1660
+ "utilization",
1661
+ ]
.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/memory.cpython-311.pyc ADDED
Binary file (50.9 kB). View file
 
.venv/lib/python3.11/site-packages/torch/cuda/_gpu_trace.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+
3
+ from torch._utils import CallbackRegistry
4
+
5
+
6
+ EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
7
+ "CUDA event creation"
8
+ )
9
+ EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
10
+ "CUDA event deletion"
11
+ )
12
+ EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
13
+ "CUDA event record"
14
+ )
15
+ EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
16
+ "CUDA event wait"
17
+ )
18
+ MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
19
+ "CUDA memory allocation"
20
+ )
21
+ MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
22
+ "CUDA memory deallocation"
23
+ )
24
+ StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
25
+ "CUDA stream creation"
26
+ )
27
+ DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
28
+ "CUDA device synchronization"
29
+ )
30
+ StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
31
+ "CUDA stream synchronization"
32
+ )
33
+ EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
34
+ "CUDA event synchronization"
35
+ )
36
+
37
+
38
+ def register_callback_for_event_creation(cb: Callable[[int], None]) -> None:
39
+ EventCreationCallbacks.add_callback(cb)
40
+
41
+
42
+ def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None:
43
+ EventDeletionCallbacks.add_callback(cb)
44
+
45
+
46
+ def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None:
47
+ EventRecordCallbacks.add_callback(cb)
48
+
49
+
50
+ def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None:
51
+ EventWaitCallbacks.add_callback(cb)
52
+
53
+
54
+ def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None:
55
+ MemoryAllocationCallbacks.add_callback(cb)
56
+
57
+
58
+ def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None:
59
+ MemoryDeallocationCallbacks.add_callback(cb)
60
+
61
+
62
+ def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None:
63
+ StreamCreationCallbacks.add_callback(cb)
64
+
65
+
66
+ def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None:
67
+ DeviceSynchronizationCallbacks.add_callback(cb)
68
+
69
+
70
+ def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None:
71
+ StreamSynchronizationCallbacks.add_callback(cb)
72
+
73
+
74
+ def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None:
75
+ EventSynchronizationCallbacks.add_callback(cb)
.venv/lib/python3.11/site-packages/torch/cuda/_memory_viz.py ADDED
@@ -0,0 +1,632 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import pickle
3
+ import sys
4
+ import os
5
+ import io
6
+ import subprocess
7
+ import json
8
+ from functools import lru_cache
9
+ from typing import Any
10
+ from itertools import groupby
11
+ import base64
12
+ import warnings
13
+ import operator
14
+
15
+ cache = lru_cache(None)
16
+
17
+ __all__ = ["format_flamegraph", "segments", "memory", "compare"]
18
+
19
+ def _frame_fmt(f, full_filename=False):
20
+ i = f['line']
21
+ fname = f['filename']
22
+ if not full_filename:
23
+ fname = fname.split('/')[-1]
24
+ func = f['name']
25
+ return f'{fname}:{i}:{func}'
26
+
27
+ @cache
28
+ def _frame_filter(name, filename):
29
+ omit_functions = [
30
+ "unwind::unwind",
31
+ "CapturedTraceback::gather",
32
+ "gather_with_cpp",
33
+ "_start",
34
+ "__libc_start_main",
35
+ "PyEval_",
36
+ "PyObject_",
37
+ "PyFunction_",
38
+ ]
39
+ omit_filenames = [
40
+ "core/boxing",
41
+ "/Register",
42
+ "/Redispatch",
43
+ "pythonrun.c",
44
+ "Modules/main.c",
45
+ "Objects/call.c",
46
+ "Objects/methodobject.c",
47
+ "pycore_ceval.h",
48
+ "ceval.c",
49
+ "cpython/abstract.h",
50
+ ]
51
+ for of in omit_functions:
52
+ if of in name:
53
+ return False
54
+ for of in omit_filenames:
55
+ if of in filename:
56
+ return False
57
+ return True
58
+
59
+ def _frames_fmt(frames, full_filename=False, reverse=False):
60
+ if reverse:
61
+ frames = reversed(frames)
62
+ return [_frame_fmt(f, full_filename) for f in frames if _frame_filter(f['name'], f['filename'])]
63
+
64
+ def _block_extra_legacy(b):
65
+ if 'history' in b:
66
+ frames = b['history'][0].get('frames', [])
67
+ real_size = b['history'][0]['real_size']
68
+ else:
69
+ real_size = b.get('requested_size', b['size'])
70
+ frames = []
71
+ return frames, real_size
72
+
73
+ def _block_extra(b):
74
+ if 'frames' not in b:
75
+ # old snapshot format made it more complicated to get frames/allocated size
76
+ return _block_extra_legacy(b)
77
+ return b['frames'], b['requested_size']
78
+
79
+ def format_flamegraph(flamegraph_lines, flamegraph_script=None):
80
+ if flamegraph_script is None:
81
+ flamegraph_script = f'/tmp/{os.getuid()}_flamegraph.pl'
82
+ if not os.path.exists(flamegraph_script):
83
+ import urllib.request
84
+ print(f"Downloading flamegraph.pl to: {flamegraph_script}")
85
+ urllib.request.urlretrieve(
86
+ 'https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl', flamegraph_script)
87
+ subprocess.check_call(['chmod', '+x', flamegraph_script])
88
+ args = [flamegraph_script, '--countname', 'bytes']
89
+ p = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='utf-8')
90
+ assert p.stdin is not None
91
+ assert p.stdout is not None
92
+ p.stdin.write(flamegraph_lines)
93
+ p.stdin.close()
94
+ result = p.stdout.read()
95
+ p.stdout.close()
96
+ p.wait()
97
+ assert p.wait() == 0
98
+ return result
99
+
100
+ def _write_blocks(f, prefix, blocks):
101
+ def frames_fragment(frames):
102
+ if not frames:
103
+ return "<non-python>"
104
+ return ';'.join(_frames_fmt(frames, reverse=True))
105
+ for b in blocks:
106
+ if 'history' not in b:
107
+ frames, accounted_for_size = _block_extra(b)
108
+ f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n')
109
+ else:
110
+ accounted_for_size = 0
111
+ for h in b['history']:
112
+ sz = h['real_size']
113
+ accounted_for_size += sz
114
+ if 'frames' in h:
115
+ frames = h['frames']
116
+ f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n')
117
+ else:
118
+ f.write(f'{prefix};{b["state"]};<no-context> {sz}\n')
119
+ gaps = b['size'] - accounted_for_size
120
+ if gaps:
121
+ f.write(f'{prefix};{b["state"]};<gaps> {gaps}\n')
122
+
123
+ def segments(snapshot, format_flamegraph=format_flamegraph):
124
+ f = io.StringIO()
125
+ for seg in snapshot['segments']:
126
+ prefix = f'stream_{seg["stream"]};seg_{seg["address"]}'
127
+ _write_blocks(f, prefix, seg['blocks'])
128
+ return format_flamegraph(f.getvalue())
129
+
130
+ def memory(snapshot, format_flamegraph=format_flamegraph):
131
+ f = io.StringIO()
132
+ for seg in snapshot['segments']:
133
+ prefix = f'stream_{seg["stream"]}'
134
+ _write_blocks(f, prefix, seg['blocks'])
135
+ return format_flamegraph(f.getvalue())
136
+
137
+ def compare(before, after, format_flamegraph=format_flamegraph):
138
+ def _seg_key(seg):
139
+ return (seg['address'], seg['total_size'])
140
+
141
+ def _seg_info(seg):
142
+ return f'stream_{seg["stream"]};seg_{seg["address"]}'
143
+
144
+ f = io.StringIO()
145
+
146
+ before_segs = {_seg_key(seg) for seg in before}
147
+ after_segs = {_seg_key(seg) for seg in after}
148
+
149
+ print(f'only_before = {[a for a, _ in (before_segs - after_segs)]}')
150
+ print(f'only_after = {[a for a, _ in (after_segs - before_segs)]}')
151
+
152
+ for seg in before:
153
+ if _seg_key(seg) not in after_segs:
154
+ _write_blocks(f, f'only_before;{_seg_info(seg)}', seg['blocks'])
155
+
156
+ for seg in after:
157
+ if _seg_key(seg) not in before_segs:
158
+ _write_blocks(f, f'only_after;{_seg_info(seg)}', seg['blocks'])
159
+
160
+ return format_flamegraph(f.getvalue())
161
+
162
+ def _format_size(num):
163
+ # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size
164
+ for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
165
+ if abs(num) < 1024.0:
166
+ return f"{num:3.1f}{unit}B"
167
+ num /= 1024.0
168
+ return f"{num:.1f}YiB"
169
+
170
+ class Bytes:
171
+ def __init__(self, value):
172
+ self.value = value
173
+
174
+ def __add__(self, rhs):
175
+ return Bytes(self.value + rhs)
176
+
177
+ def __repr__(self):
178
+ return _format_size(self.value)
179
+
180
+ def calc_active(seg):
181
+ return sum(b['size'] for b in seg['blocks'] if b['state'] == 'active_allocated')
182
+
183
+ def _report_free(free_external, free_internal):
184
+ total = free_external + free_internal
185
+ suffix = ''
186
+ if total != 0:
187
+ pct = (free_internal / total) * 100
188
+ suffix = f' ({pct:.1f}% internal)'
189
+ return f'{Bytes(total)}{suffix}'
190
+
191
+ PAGE_SIZE = 1024 * 1024 * 20
192
+ legend = f"""\
193
+
194
+ Legend:
195
+ [a ] - a segment in the allocator
196
+ ^-- a page {Bytes(PAGE_SIZE)} of memory in the segment
197
+ a-z: pages filled with a single block's content
198
+ ' ': page is completely free
199
+ *: page if completely full with multiple blocks
200
+ 0-9: page is partially full with tensors of multiple blocks (9 == 90% full)
201
+ (X% internal) - of the free memory, X% is free because we rounded the size of the allocation.
202
+ """
203
+
204
+ def segsum(data):
205
+ r"""Visually reports how the allocator has filled its segments.
206
+
207
+ This printout can help debug fragmentation issues since free fragments
208
+ will appear as gaps in this printout. The amount of free space is reported
209
+ for each segment.
210
+ We distinguish between internal free memory which occurs because the
211
+ allocator rounds the allocation size, and external free memory, which are
212
+ the gaps between allocations in a segment.
213
+ Args:
214
+ data: snapshot dictionary created from _snapshot()
215
+ """
216
+ segments = []
217
+ out = io.StringIO()
218
+ out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n")
219
+ total_reserved = 0
220
+ total_allocated = 0
221
+ free_external = 0
222
+ free_internal = 0
223
+ for seg in sorted(data['segments'], key=lambda x: (x['total_size'], calc_active(x))):
224
+ total_reserved += seg['total_size']
225
+
226
+ seg_free_external = 0
227
+ seg_free_internal = 0
228
+ seg_allocated = 0
229
+ all_ranges = []
230
+ boffset = 0
231
+ for b in seg['blocks']:
232
+ active = b['state'] == 'active_allocated'
233
+ if active:
234
+ _, allocated_size = _block_extra(b)
235
+ all_ranges.append((boffset, allocated_size, True))
236
+ seg_allocated += allocated_size
237
+ seg_free_internal += b['size'] - allocated_size
238
+ else:
239
+ seg_free_external += b['size']
240
+
241
+ boffset += b['size']
242
+
243
+ total_allocated += seg_allocated
244
+ free_external += seg_free_external
245
+ free_internal += seg_free_internal
246
+
247
+ nseg = (seg['total_size'] - 1) // PAGE_SIZE + 1
248
+ occupied = [' ' for _ in range(nseg)]
249
+ frac = [0.0 for _ in range(nseg)]
250
+ active_size = 0
251
+ for i, (start_, size, active) in enumerate(all_ranges):
252
+ active_size += size
253
+ finish_ = (start_ + size)
254
+ start = start_ // PAGE_SIZE
255
+ finish = (finish_ - 1) // PAGE_SIZE + 1
256
+ m = chr(ord('a' if active else 'A') + (i % 26))
257
+ for j in range(start, finish):
258
+ s = max(start_, j * PAGE_SIZE)
259
+ e = min(finish_, (j + 1) * PAGE_SIZE)
260
+ frac[j] += (e - s) / PAGE_SIZE
261
+ if occupied[j] != ' ':
262
+ occupied[j] = '0123456789*'[int(frac[j] * 10)]
263
+ else:
264
+ occupied[j] = m
265
+ stream = '' if seg['stream'] == 0 else f', stream_{seg["stream"]}'
266
+ body = ''.join(occupied)
267
+ assert seg_free_external + seg_free_internal + seg_allocated == seg['total_size']
268
+ stream = f' stream_{seg["stream"]}' if seg['stream'] != 0 else ''
269
+ if seg['total_size'] >= PAGE_SIZE:
270
+ out.write(f'[{body}] {Bytes(seg["total_size"])} allocated, '
271
+ f'{_report_free(seg_free_external, seg_free_internal)} free{stream}\n')
272
+ out.write(f'segments: {len(data["segments"])}\n')
273
+ out.write(f'total_reserved: {Bytes(total_reserved)}\n')
274
+ out.write(f'total_allocated: {Bytes(total_allocated)}\n')
275
+ internal_external = f' ({Bytes(free_internal)} internal + {Bytes(free_external)} external)' if free_internal else ''
276
+ out.write(f'total_free: {_report_free(free_external, free_internal)}\n')
277
+ out.write(legend)
278
+ assert free_internal + free_external + total_allocated == total_reserved
279
+ return out.getvalue()
280
+
281
+ def trace(data):
282
+ out = io.StringIO()
283
+
284
+ def format(entries):
285
+ segment_intervals : list = []
286
+ segment_addr_to_name = {}
287
+ allocation_addr_to_name = {}
288
+
289
+ free_names : list = []
290
+ next_name = 0
291
+
292
+ def _name():
293
+ nonlocal next_name
294
+ if free_names:
295
+ return free_names.pop()
296
+ r, m = next_name // 26, next_name % 26
297
+ next_name += 1
298
+ return f'{chr(ord("a") + m)}{"" if r == 0 else r}'
299
+
300
+ def find_segment(addr):
301
+ for name, saddr, size in segment_intervals:
302
+ if addr >= saddr and addr < saddr + size:
303
+ return name, saddr
304
+ for i, seg in enumerate(data['segments']):
305
+ saddr = seg['address']
306
+ size = seg['allocated_size']
307
+ if addr >= saddr and addr < saddr + size:
308
+ return f'seg_{i}', saddr
309
+ return None, None
310
+ count = 0
311
+ out.write(f'{len(entries)} entries\n')
312
+
313
+
314
+ total_reserved = 0
315
+ for seg in data['segments']:
316
+ total_reserved += seg['total_size']
317
+
318
+ for count, e in enumerate(entries):
319
+ if e['action'] == 'alloc':
320
+ addr, size = e['addr'], e['size']
321
+ n = _name()
322
+ seg_name, seg_addr = find_segment(addr)
323
+ if seg_name is None:
324
+ seg_name = "MEM"
325
+ offset = addr
326
+ else:
327
+ offset = addr - seg_addr
328
+ out.write(f'{n} = {seg_name}[{offset}:{Bytes(size)}]\n')
329
+ allocation_addr_to_name[addr] = (n, size, count)
330
+ count += size
331
+ elif e['action'] == 'free_requested':
332
+ addr, size = e['addr'], e['size']
333
+ name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
334
+ out.write(f'del {name} # {Bytes(size)}\n')
335
+ elif e['action'] == 'free_completed':
336
+ addr, size = e['addr'], e['size']
337
+ count -= size
338
+ name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
339
+ out.write(f'# free completed for {name} {Bytes(size)}\n')
340
+ if name in allocation_addr_to_name:
341
+ free_names.append(name)
342
+ del allocation_addr_to_name[name]
343
+ elif e['action'] == 'segment_alloc':
344
+ addr, size = e['addr'], e['size']
345
+ name = _name()
346
+ out.write(f'{name} = cudaMalloc({addr}, {Bytes(size)})\n')
347
+ segment_intervals.append((name, addr, size))
348
+ segment_addr_to_name[addr] = name
349
+ elif e['action'] == 'segment_free':
350
+ addr, size = e['addr'], e['size']
351
+ name = segment_addr_to_name.get(addr, addr)
352
+ out.write(f'cudaFree({name}) # {Bytes(size)}\n')
353
+ if name in segment_addr_to_name:
354
+ free_names.append(name)
355
+ del segment_addr_to_name[name]
356
+ elif e['action'] == 'oom':
357
+ size = e['size']
358
+ free = e['device_free']
359
+ out.write(f'raise OutOfMemoryError # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n')
360
+ else:
361
+ out.write(f'{e}\n')
362
+ out.write(f"TOTAL MEM: {Bytes(count)}")
363
+ for i, d in enumerate(data['device_traces']):
364
+ if d:
365
+ out.write(f'Device {i} ----------------\n')
366
+ format(d)
367
+ return out.getvalue()
368
+
369
+
370
+ _memory_viz_template = r"""
371
+ <!DOCTYPE html>
372
+ <html>
373
+ <head>
374
+ </head>
375
+ <body>
376
+ <script type="module">
377
+ import {add_local_files} from "https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/utils/viz/MemoryViz.js"
378
+ const local_files = $SNAPSHOT
379
+ add_local_files(local_files, $VIZ_KIND)
380
+ </script>
381
+ </body>
382
+ """
383
+
384
+ def _format_viz(data, viz_kind, device):
385
+ if device is not None:
386
+ warnings.warn(
387
+ 'device argument is deprecated, plots now contain all device',
388
+ FutureWarning,
389
+ stacklevel=3,
390
+ )
391
+ buffer = pickle.dumps(data)
392
+ buffer += b'\x00' * (3 - len(buffer) % 3)
393
+ # Encode the buffer with base64
394
+ encoded_buffer = base64.b64encode(buffer).decode('utf-8')
395
+
396
+ json_format = json.dumps([{"name": 'snapshot.pickle', "base64": encoded_buffer}])
397
+ return _memory_viz_template.replace('$VIZ_KIND', repr(viz_kind)) \
398
+ .replace('$SNAPSHOT', json_format)
399
+
400
+ def trace_plot(data, device=None, plot_segments=False):
401
+ """Generate a visualization over time of the memory usage recorded by the trace as an html file.
402
+
403
+ Args:
404
+ data: Memory snapshot as generated from torch.cuda.memory._snapshot()
405
+ device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
406
+ plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
407
+ Defaults to False.
408
+
409
+ Returns:
410
+ str: HTML of visualization
411
+ """
412
+ return _format_viz(data, 'Active Memory Timeline' if not plot_segments else 'Active Cached Memory Timeline', device)
413
+
414
+
415
+ def _profile_to_snapshot(profile):
416
+ import torch
417
+ from torch.profiler._memory_profiler import Action, TensorKey
418
+ from torch._C._profiler import _EventType
419
+ memory_profile = profile._memory_profile()
420
+
421
+ allocation_stacks = {}
422
+ for event in memory_profile._op_tree.sorted_nodes:
423
+ if event.tag == _EventType.Allocation:
424
+ parent = event.parent
425
+ python_parents = []
426
+ while parent:
427
+ if parent.tag in (_EventType.PyCall, _EventType.PyCCall):
428
+ python_parents.append(parent)
429
+ parent = parent.parent
430
+ key = TensorKey.from_allocation(event.extra_fields)
431
+
432
+ # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor)
433
+ # key will be None. I should add some way to identify these, I just haven't yet.
434
+ if key and event.extra_fields.alloc_size > 0:
435
+ allocation_stacks[key] = python_parents
436
+
437
+
438
+ device_count = torch.cuda.device_count()
439
+ snapshot = {
440
+ 'device_traces': [[] for _ in range(device_count + 1)],
441
+ 'segments': [{'device': device,
442
+ 'address': None,
443
+ 'total_size': 0,
444
+ 'stream': 0,
445
+ 'blocks': []} for device in range(device_count + 1)]
446
+ }
447
+
448
+ def to_device(device):
449
+ if device.type == 'cuda':
450
+ return device.index
451
+ else:
452
+ return device_count
453
+
454
+ def allocate(size, tensor_key, version, during_trace=True):
455
+ device = to_device(tensor_key.device)
456
+ addr = tensor_key.storage.ptr
457
+
458
+ seg = snapshot['segments'][device] # type: ignore[index]
459
+ if seg['address'] is None or seg['address'] > addr:
460
+ seg['address'] = addr
461
+ seg['total_size'] = max(seg['total_size'], addr + size) # record max addr for now, we will make it the size later
462
+ category = memory_profile._categories.get(tensor_key, version)
463
+ category = category.name.lower() if category is not None else "unknown"
464
+ stack = allocation_stacks.get(tensor_key, ())
465
+ stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack]
466
+ r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category}
467
+ if during_trace:
468
+ snapshot['device_traces'][device].append(r) # type: ignore[index]
469
+ return r
470
+
471
+ def free(alloc, device):
472
+ for e in ('free_requested', 'free_completed'):
473
+ snapshot['device_traces'][device].append({'action': e, # type: ignore[index]
474
+ 'addr': alloc['addr'],
475
+ 'size': alloc['size'],
476
+ 'stream': 0,
477
+ 'frames': alloc['frames']})
478
+
479
+ kv_to_elem = {}
480
+
481
+
482
+
483
+ # create the device trace
484
+ for time, action, (tensor_key, version), size in memory_profile.timeline:
485
+ if not isinstance(tensor_key, TensorKey):
486
+ continue
487
+ if action == Action.CREATE:
488
+ kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version)
489
+ elif action == Action.DESTROY:
490
+ free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
491
+ elif action == Action.INCREMENT_VERSION:
492
+ free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
493
+ kv_to_elem[(tensor_key, version + 1)] = allocate(size, tensor_key, version + 1)
494
+ elif action == Action.PREEXISTING:
495
+ kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version, during_trace=False)
496
+
497
+
498
+ # create the final snapshot state
499
+ blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames'])
500
+ for (tensor_key, version), event in kv_to_elem.items()]
501
+ for device, blocks in groupby(sorted(blocks_at_end), key=operator.itemgetter(0)):
502
+ seg = snapshot['segments'][device] # type: ignore[index]
503
+ last_addr = seg['address']
504
+ for _, addr, size, frames in blocks:
505
+ if last_addr < addr:
506
+ seg['blocks'].append({'size': addr - last_addr, 'state': 'inactive'})
507
+ seg['blocks'].append({'size': size, 'state': 'active_allocated', 'requested_size': size, 'frames': frames})
508
+ last_addr = addr + size
509
+ if last_addr < seg['total_size']:
510
+ seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'})
511
+
512
+ snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']] # type: ignore[attr-defined]
513
+ for seg in snapshot['segments']: # type: ignore[attr-defined, name-defined, no-redef]
514
+ seg['total_size'] -= seg['address']
515
+ if not seg['blocks']:
516
+ seg['blocks'].append({'size': seg['total_size'], 'state': 'inactive'})
517
+
518
+ return snapshot
519
+
520
+ def profile_plot(profile, device=None):
521
+ """Generate a visualization over time of the memory usage recorded by kineto memory profiling as an html file.
522
+
523
+ Args:
524
+ profile: profile as generated by `torch.profiler.profile(profile_memory=True)`
525
+ device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
526
+
527
+ Returns:
528
+ str: HTML of visualization
529
+ """
530
+ snapshot = _profile_to_snapshot(profile)
531
+ return _format_viz(snapshot, 'Active Memory Timeline', device)
532
+
533
+
534
+ def segment_plot(data: Any, device=None):
535
+ return _format_viz(data, 'Allocator State History', device)
536
+
537
+ if __name__ == "__main__":
538
+ import os.path
539
+ thedir = os.path.realpath(os.path.dirname(__file__))
540
+ if thedir in sys.path:
541
+ # otherwise we find cuda/random.py as random...
542
+ sys.path.remove(thedir)
543
+ import argparse
544
+
545
+ fn_name = 'torch.cuda.memory._snapshot()'
546
+ pickled = f'pickled memory statistics from {fn_name}'
547
+ parser = argparse.ArgumentParser(description=f'Visualize memory dumps produced by {fn_name}')
548
+
549
+ subparsers = parser.add_subparsers(dest='action')
550
+
551
+ def _output(p):
552
+ p.add_argument('-o', '--output', default='output.svg', help='flamegraph svg (default: output.svg)')
553
+
554
+ description = 'Prints overall allocation statistics and a visualization of how the allocators segments are currently filled.'
555
+ stats_a = subparsers.add_parser('stats', description=description)
556
+ stats_a.add_argument('input', help=pickled)
557
+
558
+ description = 'Prints buffer of the most recent allocation events embedded in the snapshot in a Pythonic style.'
559
+ trace_a = subparsers.add_parser('trace', description=description)
560
+ trace_a.add_argument('input', help=pickled)
561
+
562
+ description = 'Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)'
563
+ segments_a = subparsers.add_parser('segments', description=description)
564
+ segments_a.add_argument('input', help=pickled)
565
+ _output(segments_a)
566
+
567
+ description = "Generate a flamegraph the program locations contributing to CUDA memory usage."
568
+ memory_a = subparsers.add_parser('memory', description=description)
569
+ memory_a.add_argument('input', help=pickled)
570
+ _output(memory_a)
571
+
572
+ description = 'Generate a flamegraph that shows segments (aka blocks) that have been added ' \
573
+ 'or removed between two different memorys snapshots.'
574
+ compare_a = subparsers.add_parser('compare', description=description)
575
+ compare_a.add_argument('before', help=pickled)
576
+ compare_a.add_argument('after', help=pickled)
577
+ _output(compare_a)
578
+
579
+ plots = (
580
+ ("trace_plot", "Generate a visualization over time of the memory usage recorded by the trace as an html file."),
581
+ ("segment_plot", "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.")
582
+ )
583
+ for cmd, description in plots:
584
+ trace_plot_a = subparsers.add_parser(cmd, description=description)
585
+ trace_plot_a.add_argument('input', help=pickled)
586
+ help = 'visualize trace from this device (default: chooses the only device with trace info or errors)'
587
+ trace_plot_a.add_argument('-d', '--device', type=int, default=None, help=help)
588
+ help = 'path to save the visualization(default: output.html)'
589
+ trace_plot_a.add_argument('-o', '--output', default='output.html', help=help)
590
+ if cmd == "trace_plot":
591
+ help = 'visualize change to segments rather than individual allocations'
592
+ trace_plot_a.add_argument('-s', '--segments', action='store_true', help=help)
593
+
594
+
595
+ args = parser.parse_args()
596
+
597
+ def _read(name):
598
+ if name == '-':
599
+ f = sys.stdin.buffer
600
+ else:
601
+ f = open(name, 'rb')
602
+ data = pickle.load(f)
603
+ if isinstance(data, list): # segments only...
604
+ data = {'segments': data, 'traces': []}
605
+ return data
606
+
607
+ def _write(name, data):
608
+ with open(name, 'w') as f:
609
+ f.write(data)
610
+
611
+ if args.action == 'segments':
612
+ data = _read(args.input)
613
+ _write(args.output, segments(data))
614
+ elif args.action == 'memory':
615
+ data = _read(args.input)
616
+ _write(args.output, memory(data))
617
+ elif args.action == 'stats':
618
+ data = _read(args.input)
619
+ print(segsum(data))
620
+ elif args.action == 'trace':
621
+ data = _read(args.input)
622
+ print(trace(data))
623
+ elif args.action == 'compare':
624
+ before = _read(args.before)
625
+ after = _read(args.after)
626
+ _write(args.output, compare(before, after))
627
+ elif args.action == 'trace_plot':
628
+ data = _read(args.input)
629
+ _write(args.output, trace_plot(data, device=args.device, plot_segments=args.segments))
630
+ elif args.action == 'segment_plot':
631
+ data = _read(args.input)
632
+ _write(args.output, segment_plot(data, device=args.device))
.venv/lib/python3.11/site-packages/torch/cuda/_sanitizer.py ADDED
@@ -0,0 +1,621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ r"""
3
+ This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels ran on different streams.
4
+
5
+ It stores information on accesses to tensors to determine if they are synchronized
6
+ or not. When enabled in a python program and a possible data race is detected, a
7
+ detailed warning will be printed and the program will exit.
8
+
9
+ It can be enabled either by importing this module and calling
10
+ :func:`enable_cuda_sanitizer()` or by exporting the ``TORCH_CUDA_SANITIZER``
11
+ environment variable.
12
+ """
13
+
14
+ import enum
15
+ import functools
16
+ import inspect
17
+ import io
18
+ import logging
19
+ import sys
20
+ import textwrap
21
+ import traceback
22
+ from dataclasses import dataclass, field
23
+ from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar
24
+
25
+ import torch
26
+ import torch.cuda._gpu_trace as gpu_trace
27
+ from torch.utils import _pytree as pytree
28
+ from torch.utils._python_dispatch import TorchDispatchMode
29
+
30
+
31
+ DEFAULT_STREAM_ID = 0
32
+
33
+ TK = TypeVar("TK")
34
+ TVa = TypeVar("TVa")
35
+ TVb = TypeVar("TVb")
36
+
37
+ DataPtr = int
38
+ StreamId = int
39
+ EventId = int
40
+ SeqNum = int
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
+ class AccessType(enum.Enum):
46
+ READ = enum.auto()
47
+ WRITE = enum.auto()
48
+
49
+ def __str__(self):
50
+ return "reading from" if self is AccessType.READ else "writing to"
51
+
52
+
53
+ @dataclass
54
+ class Access:
55
+ r"""Stores information about a single access to a tensor by a kernel.
56
+
57
+ Args:
58
+ type: either AccessType.READ or AccessType.Write.
59
+ seq_num: the sequential number of the kernel performing the access.
60
+ stream: the stream id of the stream executing the kernel.
61
+ operator: the schema of the launched kernel, which lists the
62
+ arguments and return type.
63
+ aliases: the arguments in the schema this access corresponds to.
64
+ is_output: Whether the tensor was an output of the kernel.
65
+ stack_trace: the stack summary object captured during access.
66
+ """
67
+
68
+ type: AccessType
69
+ seq_num: SeqNum
70
+ stream: StreamId
71
+ operator: str
72
+ aliases: List[str]
73
+ is_output: bool
74
+ stack_trace: traceback.StackSummary
75
+
76
+
77
+ class SynchronizationError(Exception):
78
+ """Base class for errors detected by CUDA Sanitizer."""
79
+
80
+
81
class UnsynchronizedAccessError(SynchronizationError):
    """Stores information about two unsynchronized accesses to one data pointer."""

    def __init__(
        self,
        data_ptr: DataPtr,
        allocation_stack_trace: Optional[traceback.StackSummary],
        current_access: Access,
        previous_access: Access,
    ):
        self.data_ptr = data_ptr
        self.allocation_stack_trace = allocation_stack_trace
        self.current_access = current_access
        self.previous_access = previous_access

    def __str__(self):
        # Render a human-readable report: the racing access, the earlier
        # access it conflicts with, and (when available) the stack trace of
        # the tensor's allocation.
        def format_access(access: Access):
            message.write(f"{access.operator}\n{access.type}")
            if access.aliases:
                message.write(" argument(s) " + ", ".join(access.aliases))
                if access.is_output:
                    message.write(", and to")
            if access.is_output:
                message.write(" the output")
            message.write(
                f"\nWith stack trace:\n{''.join(access.stack_trace.format())}\n"
            )

        with io.StringIO() as message:
            message.write(
                textwrap.dedent(
                    f"""\
                    ============================
                    CSAN detected a possible data race on tensor with data pointer {self.data_ptr}
                    Access by stream {self.current_access.stream} during kernel:
                    """
                )
            )
            format_access(self.current_access)

            message.write(
                f"Previous access by stream {self.previous_access.stream} during kernel:\n"
            )
            format_access(self.previous_access)

            if self.allocation_stack_trace:
                message.write(
                    "Tensor was allocated with stack trace:\n"
                    f"{''.join(self.allocation_stack_trace.format())}"
                )
            else:
                # Allocation predates the sanitizer being enabled.
                message.write("Trace for tensor allocation not found.")
            return message.getvalue()
134
+
135
+
136
class CUDASanitizerErrors(Exception):
    """Wrapper class for errors reported by CUDA Sanitizer."""

    def __init__(self, errors: List[SynchronizationError]):
        # Keep the full list so callers can inspect each detected race.
        self.errors = errors

    def __str__(self):
        count = len(self.errors)
        return f"detected {count} errors"
144
+
145
+
146
@dataclass
class TensorInfo:
    r"""Stores information about a single tensor and recent accesses to it.

    Args:
        allocation_stack_trace: the stack summary object captured during tensor
            allocation. Can be ``None`` if the allocation wasn't caught by CSAN.
        reads: list of read accesses to the tensor that were performed since
            the last write.
        write: the last write access to the tensor.
    """

    allocation_stack_trace: Optional[traceback.StackSummary]
    reads: List[Access] = field(default_factory=list)
    write: Optional[Access] = None
161
+
162
+
163
class _TensorsAccessed:
    """Tracks, per data pointer, the allocation trace and recent accesses.

    The sanitizer may be enabled mid-program, so the ``ensure_*`` methods
    defensively backfill missing allocation/deallocation trace events instead
    of failing on unknown pointers.
    """

    def __init__(self) -> None:
        self.accesses: Dict[DataPtr, TensorInfo] = {}

    def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
        # Backfill an allocation event if tracing started after the tensor
        # was created.
        if data_ptr not in self.accesses:
            logger.info(
                "Found tensor with pointer: %s, but no matching tensor "
                "allocation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.create_tensor(data_ptr, None)

    def ensure_tensor_does_not_exist(self, data_ptr: DataPtr) -> None:
        # A pointer being allocated twice implies its deallocation went
        # untraced; backfill the deallocation.
        if data_ptr in self.accesses:
            logger.info(
                "Found duplicate tensor allocation in the trace for tensor with "
                "pointer: %s. Assuming the trace for tensor deallocation "
                "wasn't caught and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.delete_tensor(data_ptr)

    def create_tensor(
        self, data_ptr: DataPtr, stack_trace: Optional[traceback.StackSummary]
    ) -> None:
        self.accesses[data_ptr] = TensorInfo(stack_trace)

    def delete_tensor(self, data_ptr: DataPtr) -> None:
        del self.accesses[data_ptr]

    def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool:
        # Idiomatic truthiness check (was `True if ... else False`).
        return bool(self.accesses[data_ptr].reads)

    def get_allocation_stack_trace(
        self, data_ptr: DataPtr
    ) -> Optional[traceback.StackSummary]:
        return self.accesses[data_ptr].allocation_stack_trace

    def get_write(self, data_ptr: DataPtr) -> Optional[Access]:
        return self.accesses[data_ptr].write

    def get_reads(self, data_ptr: DataPtr) -> List[Access]:
        return self.accesses[data_ptr].reads

    def add_read(self, data_ptr: DataPtr, access: Access) -> None:
        self.accesses[data_ptr].reads.append(access)

    def set_write(self, data_ptr: DataPtr, access: Access) -> None:
        # A write invalidates the read history: future conflicts are checked
        # against this write only. Rebinding (not clearing) preserves any
        # list previously handed out by get_reads().
        self.accesses[data_ptr].write = access
        self.accesses[data_ptr].reads = []
216
+
217
+
218
class StreamSynchronizations:
    """Models the happens-before relation between CUDA streams.

    Each stream carries a sync state: a mapping from stream id to the latest
    kernel sequence number on that stream which is known to have happened
    before any future kernel on this stream (a vector clock). Events store
    snapshots of these states; waiting merges states element-wise with max.
    ``host_sync_state`` plays the role of a virtual stream representing
    host-side synchronization.
    """

    def __init__(self) -> None:
        self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
        self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
        self.host_sync_state: Dict[StreamId, SeqNum] = {}
        self.create_stream(DEFAULT_STREAM_ID)

    def _ensure_stream_exists(self, stream: StreamId) -> None:
        # Backfill stream creation if tracing began after the stream was made.
        if stream not in self.current_sync_states:
            logger.info(
                "Found Stream with id: %s, but no matching stream "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                stream,
            )
            self.create_stream(stream)

    def _ensure_event_exists(self, event: EventId) -> None:
        # Backfill event creation if tracing began after the event was made.
        if event not in self.recorded_sync_states:
            logger.info(
                "Found Event with id: %s, but no matching event "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                event,
            )
            self.create_event(event)

    def _ensure_event_does_not_exist(self, event: EventId) -> None:
        # An event id created twice implies its deletion went untraced.
        if event in self.recorded_sync_states:
            logger.info(
                "Found duplicate event creation in the trace for event with "
                "id: %s. Assuming the trace for event deletion wasn't caught "
                "and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                event,
            )
            self.delete_event(event)

    def create_stream(self, stream: StreamId) -> None:
        if stream in self.current_sync_states:
            logger.info(
                "Found duplicate Stream creation in the trace for Stream with "
                "id: %s. PyTorch Streams are only created once, so this "
                "trace entry is ignored.",
                stream,
            )
        else:
            # A new stream starts synchronized with everything the host has
            # synchronized with so far.
            self.host_sync_state[stream] = 0
            self.current_sync_states[stream] = self.host_sync_state.copy()

    def create_event(self, event: EventId) -> None:
        self._ensure_event_does_not_exist(event)
        self.recorded_sync_states[event] = {}

    def delete_event(self, event: EventId) -> None:
        self._ensure_event_exists(event)
        del self.recorded_sync_states[event]

    def update_seq_num(self, stream: StreamId, seq_num: SeqNum) -> None:
        # A stream trivially happens-after its own kernels.
        self._ensure_stream_exists(stream)
        self.current_sync_states[stream][stream] = seq_num

    def record_state(self, event: EventId, stream: StreamId) -> None:
        # Snapshot the stream's sync state into the event (event.record()).
        self._ensure_event_exists(event)
        self._ensure_stream_exists(stream)
        self.recorded_sync_states[event] = self.current_sync_states[stream].copy()

    def _state_wait_for_other(
        self, state: Dict[StreamId, SeqNum], other: Dict[StreamId, SeqNum]
    ) -> None:
        # Element-wise max merge: after waiting, ``state`` happens-after
        # everything ``other`` happens-after.
        for stream, seq_num in other.items():
            state[stream] = max(state.get(stream, -1), seq_num)

    def stream_wait_for_event(self, stream: StreamId, event: EventId) -> None:
        self._ensure_stream_exists(stream)
        self._ensure_event_exists(event)
        self._state_wait_for_other(
            self.current_sync_states[stream], self.recorded_sync_states[event]
        )

    def all_streams_wait_for_event(self, event: EventId) -> None:
        # Models event.synchronize(): every stream, and the host, waits.
        self._ensure_event_exists(event)
        for stream in self.current_sync_states.keys():
            self.stream_wait_for_event(stream, event)

        self._state_wait_for_other(
            self.host_sync_state, self.recorded_sync_states[event]
        )

    def all_streams_wait_for_stream(self, stream: StreamId) -> None:
        # Models stream.synchronize().
        self._ensure_stream_exists(stream)
        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.current_sync_states[stream])

        self._state_wait_for_other(
            self.host_sync_state, self.current_sync_states[stream]
        )

    def sync_all_streams(self) -> None:
        # Models torch.cuda.synchronize(): every stream catches up with every
        # other stream's latest kernel via the host state.
        for stream, state in self.current_sync_states.items():
            self.host_sync_state[stream] = state[stream]

        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.host_sync_state)

    def is_ordered_after(
        self, current_stream: StreamId, seq_num: SeqNum, other_stream: StreamId
    ) -> bool:
        # True iff kernel ``seq_num`` on ``other_stream`` is known to have
        # happened before the current point of ``current_stream``.
        self._ensure_stream_exists(current_stream)
        self._ensure_stream_exists(other_stream)
        return seq_num <= self.current_sync_states[current_stream].get(other_stream, -1)
329
+
330
+
331
class EventHandler:
    """Analyzes CSAN trace for synchronization errors.

    Stores information on each stream's synchronizations with other streams as well
    as tensor accesses to determine whether a given kernel launch might cause a
    data race.
    """

    def __init__(self) -> None:
        self.tensors_accessed = _TensorsAccessed()
        self.syncs = StreamSynchronizations()
        # Monotonically increasing id assigned to each kernel launch.
        self.seq_num: SeqNum = 0

    def _handle_kernel_launch(
        self,
        stream: StreamId,
        read_only: Set[DataPtr],
        read_write: Set[DataPtr],
        outputs: Set[DataPtr],
        operator: str,
        tensor_aliases: Dict[int, List[str]],
    ) -> List[SynchronizationError]:
        # NOTE(review): ``operator`` is annotated ``str`` but the dispatch
        # mode passes ``func._schema``; it is only formatted into report
        # messages, so both work -- confirm the intended type.
        def check_conflict(
            data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access]
        ) -> None:
            # Two accesses race unless the previous one is ordered before the
            # current stream's present position.
            if previous_access is None:
                return
            if not self.syncs.is_ordered_after(
                current_access.stream, previous_access.seq_num, previous_access.stream
            ):
                error_list.append(
                    UnsynchronizedAccessError(
                        data_ptr,
                        self.tensors_accessed.get_allocation_stack_trace(data_ptr),
                        current_access,
                        previous_access,
                    )
                )

        error_list: List[SynchronizationError] = []
        self.seq_num += 1
        self.syncs.update_seq_num(stream, self.seq_num)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # The stack trace generated in this way is in the inverse order, so it must be
        # reversed.
        stack_trace.reverse()

        # Reads only conflict with the last write.
        for data_ptr in read_only:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.READ,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            check_conflict(
                data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
            )
            self.tensors_accessed.add_read(data_ptr, current_access)

        # Writes conflict with every read since the last write, or -- if there
        # were none -- with the last write itself.
        for data_ptr in read_write:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.WRITE,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            if self.tensors_accessed.were_there_reads_since_last_write(data_ptr):
                for previous_access in self.tensors_accessed.get_reads(data_ptr):
                    check_conflict(data_ptr, current_access, previous_access)
            else:
                check_conflict(
                    data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
                )
            self.tensors_accessed.set_write(data_ptr, current_access)

        return error_list

    def _handle_event_creation(self, event: EventId) -> None:
        self.syncs.create_event(event)

    def _handle_event_deletion(self, event: EventId) -> None:
        self.syncs.delete_event(event)

    def _handle_event_record(self, event: EventId, stream: StreamId) -> None:
        self.syncs.record_state(event, stream)

    def _handle_event_wait(self, event: EventId, stream: StreamId) -> None:
        self.syncs.stream_wait_for_event(stream, event)

    def _handle_memory_allocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_does_not_exist(data_ptr)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # The stack trace generated in this way is in the inverse order, so it must be
        # reversed.
        stack_trace.reverse()
        self.tensors_accessed.create_tensor(
            data_ptr,
            stack_trace,
        )

    def _handle_memory_deallocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_exists(data_ptr)
        self.tensors_accessed.delete_tensor(data_ptr)

    def _handle_stream_creation(self, stream: StreamId) -> None:
        self.syncs.create_stream(stream)

    def _handle_device_synchronization(self) -> None:
        self.syncs.sync_all_streams()

    def _handle_stream_synchronization(self, stream: StreamId) -> None:
        self.syncs.all_streams_wait_for_stream(stream)

    def _handle_event_synchronization(self, event: EventId) -> None:
        self.syncs.all_streams_wait_for_event(event)
458
+
459
+
460
def zip_by_key(a: Dict[TK, TVa], b: Dict[TK, TVb]) -> Iterator[Tuple[TK, TVa, TVb]]:
    """Yield ``(key, a[key], b[key])`` for every key present in both dicts.

    Iteration order follows ``a``; keys missing from ``b`` are skipped.
    """
    for key, left in a.items():
        if key in b:
            yield key, left, b[key]
464
+
465
+
466
def zip_arguments(
    schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterator[Tuple[torch.Argument, Any]]:
    """Pair each supplied argument value with its schema ``Argument``.

    Positional args are matched by position; keyword args are matched by
    name. Keyword arguments absent from the schema are silently skipped.
    """
    num_positional = len(args)
    positional_schema = schema.arguments[:num_positional]
    keyword_schema = {arg.name: arg for arg in schema.arguments[num_positional:]}

    yield from zip(positional_schema, args)

    for _, argument, value in zip_by_key(keyword_schema, kwargs):
        yield (argument, value)
476
+
477
+
478
class ArgumentHandler:
    """Collects the CUDA data pointers touched by a single operator call.

    After ``parse_inputs`` and ``parse_outputs`` have run, the sets below
    describe which tensors were read, written, and produced as outputs,
    together with the schema argument names aliasing each pointer.
    """

    def __init__(self) -> None:
        self.dataptrs_read: Set[DataPtr] = set()
        self.dataptrs_written: Set[DataPtr] = set()
        self.tensor_aliases: Dict[DataPtr, List[str]] = {}
        self.outputs: Set[DataPtr] = set()

    def _handle_argument(
        self,
        value: Any,
        is_write: bool,
        name: Optional[str] = None,
        is_output: bool = False,
    ) -> None:
        # Only CUDA tensors participate in stream-race analysis.
        if not isinstance(value, torch.Tensor) or not value.is_cuda:
            return
        data_ptr = value.data_ptr()
        target = self.dataptrs_written if is_write else self.dataptrs_read
        target.add(data_ptr)

        aliases = self.tensor_aliases.setdefault(data_ptr, [])
        if name is not None:
            aliases.append(name)
        if is_output:
            self.outputs.add(data_ptr)

    def parse_inputs(
        self,
        schema: torch.FunctionSchema,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> None:
        # Walk every (possibly nested) argument; mutability comes from the
        # schema's alias_info.
        for argument, value in zip_arguments(schema, args, kwargs):
            is_write = argument.alias_info is not None and argument.alias_info.is_write
            pytree.tree_map_(
                functools.partial(
                    self._handle_argument, is_write=is_write, name=argument.name
                ),
                value,
            )

    def parse_outputs(self, outputs: Any) -> None:
        # Outputs are always treated as writes.
        pytree.tree_map_(
            functools.partial(self._handle_argument, is_write=True, is_output=True),
            outputs,
        )
525
+
526
+
527
class CUDASanitizerDispatchMode(TorchDispatchMode):
    """TorchDispatchMode that checks every dispatched op for stream races.

    On construction it activates PyTorch's GPU tracing and registers an
    :class:`EventHandler` callback for every traced CUDA event: event and
    stream lifecycle, memory (de)allocation, and the synchronization calls.
    """

    def __init__(self) -> None:
        self.event_handler = EventHandler()
        torch._C._activate_gpu_trace()
        gpu_trace.register_callback_for_event_creation(
            self.event_handler._handle_event_creation
        )
        gpu_trace.register_callback_for_event_deletion(
            self.event_handler._handle_event_deletion
        )
        gpu_trace.register_callback_for_event_record(
            self.event_handler._handle_event_record
        )
        gpu_trace.register_callback_for_event_wait(
            self.event_handler._handle_event_wait
        )
        gpu_trace.register_callback_for_memory_allocation(
            self.event_handler._handle_memory_allocation
        )
        gpu_trace.register_callback_for_memory_deallocation(
            self.event_handler._handle_memory_deallocation
        )
        gpu_trace.register_callback_for_stream_creation(
            self.event_handler._handle_stream_creation
        )
        gpu_trace.register_callback_for_device_synchronization(
            self.event_handler._handle_device_synchronization
        )
        gpu_trace.register_callback_for_stream_synchronization(
            self.event_handler._handle_stream_synchronization
        )
        gpu_trace.register_callback_for_event_synchronization(
            self.event_handler._handle_event_synchronization
        )

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # Collect the CUDA pointers this op reads/writes (inputs before,
        # outputs after running it), then check the launch for races.
        argument_handler = ArgumentHandler()
        argument_handler.parse_inputs(func._schema, args, kwargs)

        outputs = func(*args, **kwargs)

        argument_handler.parse_outputs(outputs)
        errors = self.event_handler._handle_kernel_launch(
            torch.cuda.current_stream().cuda_stream,
            argument_handler.dataptrs_read - argument_handler.dataptrs_written,
            argument_handler.dataptrs_written,
            argument_handler.outputs,
            func._schema,
            argument_handler.tensor_aliases,
        )
        if errors:
            # Print full reports before raising the summary exception.
            for error in errors:
                print(error, file=sys.stderr)
            raise CUDASanitizerErrors(errors)

        return outputs
586
+
587
+
588
class CUDASanitizer:
    """Manages the lifetime of a CUDASanitizer dispatch mode object.

    The CUDASanitizer class wraps the entering/exiting functions of the dispatch mode
    context manager in the enable function/destructor, respectively. This is to
    explicitly set the lifetime of the dispatch mode object to that of the application.
    This approach was deemed more elegant than using the atexit module.
    """

    def __init__(self) -> None:
        self.dispatch = CUDASanitizerDispatchMode()
        self.enabled = False

    def enable(self):
        # Enter the dispatch mode; it stays active for the process lifetime.
        self.dispatch.__enter__()
        self.enabled = True

    def __del__(self):
        # Exit the dispatch mode only if it was actually entered.
        if self.enabled:
            self.dispatch.__exit__(None, None, None)
608
+
609
+
610
def enable_cuda_sanitizer():
    """Enable CUDA Sanitizer.

    The sanitizer will begin to analyze low-level CUDA calls invoked by torch functions
    for synchronization errors. All data races found will be printed to the standard
    error output along with stack traces of suspected causes. For best results, the
    sanitizer should be enabled at the very beginning of the program.
    """
    # Delegates to the module-level singleton so the dispatch mode's lifetime
    # spans the rest of the process.
    cuda_sanitizer.enable()
619
+
620
+
621
+ cuda_sanitizer = CUDASanitizer()
.venv/lib/python3.11/site-packages/torch/cuda/_utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ import torch
4
+
5
+ # The _get_device_index has been moved to torch.utils._get_device_index
6
+ from torch._utils import _get_device_index as _torch_get_device_index
7
+
8
+
9
+ def _get_device_index(
10
+ device: Any, optional: bool = False, allow_cpu: bool = False
11
+ ) -> int:
12
+ r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``.
13
+
14
+ If :attr:`device` is a torch.device object, returns the device index if it
15
+ is a CUDA device. Note that for a CUDA device without a specified index,
16
+ i.e., ``torch.device('cuda')``, this will return the current default CUDA
17
+ device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
18
+ CPU devices will be accepted and ``-1`` will be returned in this case.
19
+
20
+ If :attr:`device` is a Python integer, it is returned as is.
21
+
22
+ If :attr:`device` is ``None``, this will return the current default CUDA
23
+ device if :attr:`optional` is ``True``.
24
+ """
25
+ if isinstance(device, int):
26
+ return device
27
+ if isinstance(device, str):
28
+ device = torch.device(device)
29
+ if isinstance(device, torch.device):
30
+ if allow_cpu:
31
+ if device.type not in ["cuda", "cpu"]:
32
+ raise ValueError(f"Expected a cuda or cpu device, but got: {device}")
33
+ elif device.type != "cuda":
34
+ raise ValueError(f"Expected a cuda device, but got: {device}")
35
+ if not torch.jit.is_scripting():
36
+ if isinstance(device, torch.cuda.device):
37
+ return device.idx
38
+ return _torch_get_device_index(device, optional, allow_cpu)
.venv/lib/python3.11/site-packages/torch/cuda/comm.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # The functions here have been moved to torch.nn.parallel.comm
2
+ from torch.nn.parallel.comm import (
3
+ broadcast,
4
+ broadcast_coalesced,
5
+ gather,
6
+ reduce_add,
7
+ reduce_add_coalesced,
8
+ scatter,
9
+ )
10
+
11
+
12
+ __all__ = [
13
+ "broadcast",
14
+ "broadcast_coalesced",
15
+ "reduce_add",
16
+ "reduce_add_coalesced",
17
+ "scatter",
18
+ "gather",
19
+ ]
.venv/lib/python3.11/site-packages/torch/cuda/error.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torch/cuda/gds.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from typing import Callable, List, Optional
4
+
5
+ import torch
6
+ from torch.types import Storage
7
+
8
+
9
+ __all__: List[str] = []
10
+
11
+
12
+ def _dummy_fn(name: str) -> Callable:
13
+ def fn(*args, **kwargs): # type: ignore[no-untyped-def]
14
+ raise RuntimeError(f"torch._C.{name} is not supported on this platform")
15
+
16
+ return fn
17
+
18
+
19
# If this build of torch._C lacks the GDS bindings, install dummies that
# raise a helpful RuntimeError when called. The asserts verify the bindings
# are either all present or all absent (a partial set would be a build bug).
if not hasattr(torch._C, "_gds_register_buffer"):
    assert not hasattr(torch._C, "_gds_deregister_buffer")
    assert not hasattr(torch._C, "_gds_register_handle")
    assert not hasattr(torch._C, "_gds_deregister_handle")
    assert not hasattr(torch._C, "_gds_load_storage")
    assert not hasattr(torch._C, "_gds_save_storage")
    # Define functions
    torch._C.__dict__["_gds_register_buffer"] = _dummy_fn("_gds_register_buffer")
    torch._C.__dict__["_gds_deregister_buffer"] = _dummy_fn("_gds_deregister_buffer")
    torch._C.__dict__["_gds_register_handle"] = _dummy_fn("_gds_register_handle")
    torch._C.__dict__["_gds_deregister_handle"] = _dummy_fn("_gds_deregister_handle")
    torch._C.__dict__["_gds_load_storage"] = _dummy_fn("_gds_load_storage")
    torch._C.__dict__["_gds_save_storage"] = _dummy_fn("_gds_save_storage")
32
+
33
+
34
def _gds_register_buffer(s: Storage) -> None:
    """Registers a buffer.

    Thin wrapper around ``torch._C._gds_register_buffer``.

    Args:
        s (Storage): Buffer to register.
    """
    torch._C._gds_register_buffer(s)
41
+
42
+
43
def _gds_deregister_buffer(s: Storage) -> None:
    """Deregisters a previously registered buffer.

    Thin wrapper around ``torch._C._gds_deregister_buffer``.

    Args:
        s (Storage): Buffer to deregister.
    """
    torch._C._gds_deregister_buffer(s)
50
+
51
+
52
class _GdsFile:
    r"""Wrapper around cuFile.

    cuFile is a file-like interface to the GPUDirect Storage (GDS) API.

    Args:
        filename (str): Name of the file to open.
        flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will
            be added automatically.

    .. _CUDA GPUDirect Storage Documentation:
        https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api
    """

    def __init__(self, filename: str, flags: int):
        # O_DIRECT is unavailable on Windows, and GDS itself is Linux-only.
        if sys.platform == "win32":
            raise RuntimeError("GdsFile is not supported on this platform.")
        self.filename = filename
        self.flags = flags
        # GDS requires unbuffered (direct) I/O on the underlying descriptor.
        self.fd = os.open(filename, flags | os.O_DIRECT)
        self.handle: Optional[int] = None
        self.register_handle()

    def __del__(self) -> None:
        # Deregister from the cuFile driver before closing the descriptor.
        if self.handle is not None:
            self.deregister_handle()
        os.close(self.fd)

    def register_handle(self) -> None:
        """Registers file descriptor to cuFile Driver.

        This is a wrapper around ``cuFileHandleRegister``.
        """
        assert (
            self.handle is None
        ), "Cannot register a handle that is already registered."
        self.handle = torch._C._gds_register_handle(self.fd)

    def deregister_handle(self) -> None:
        """Deregisters file descriptor from cuFile Driver.

        This is a wrapper around ``cuFileHandleDeregister``.
        """
        assert (
            self.handle is not None
        ), "Cannot deregister a handle that is not registered."
        torch._C._gds_deregister_handle(self.handle)
        self.handle = None

    def load_storage(self, storage: Storage, offset: int = 0) -> None:
        """Loads data from the file into the storage.

        This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
        will be loaded from the file at ``offset`` into the storage.

        Args:
            storage (Storage): Storage to load data into.
            offset (int, optional): Offset into the file to start loading from. (Default: 0)
        """
        assert (
            self.handle is not None
        ), "Cannot load data from a file that is not registered."
        torch._C._gds_load_storage(self.handle, storage, offset)

    def save_storage(self, storage: Storage, offset: int = 0) -> None:
        """Saves data from the storage into the file.

        This is a wrapper around ``cuFileWrite``. All bytes of the storage
        will be written to the file at ``offset``.

        Args:
            storage (Storage): Storage to save data from.
            offset (int, optional): Offset into the file to start saving to. (Default: 0)
        """
        assert (
            self.handle is not None
        ), "Cannot save data to a file that is not registered."
        torch._C._gds_save_storage(self.handle, storage, offset)
.venv/lib/python3.11/site-packages/torch/cuda/graphs.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import gc
3
+ import typing
4
+
5
+ import torch
6
+
7
+ from .._utils import _dummy_type
8
+
9
+
10
+ if not hasattr(torch._C, "_CudaStreamBase"):
11
+ # Define dummy base classes
12
+ torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
13
+ torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle")
14
+ torch._C.__dict__["_cuda_isCurrentStreamCapturing"] = _dummy_type(
15
+ "_cuda_isCurrentStreamCapturing"
16
+ )
17
+
18
+ from torch._C import ( # noqa: F401
19
+ _cuda_isCurrentStreamCapturing,
20
+ _CUDAGraph,
21
+ _graph_pool_handle,
22
+ )
23
+
24
+
25
def is_current_stream_capturing():
    r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise.

    If a CUDA context does not exist on the current device, returns False without initializing the context.
    """
    # Thin wrapper over the C++ binding; exists so Sphinx can render docs.
    return _cuda_isCurrentStreamCapturing()
31
+
32
+
33
+ # Python shim helps Sphinx process docstrings more reliably.
34
# Python shim helps Sphinx process docstrings more reliably.
def graph_pool_handle():
    r"""Return an opaque token representing the id of a graph memory pool.

    See :ref:`Graph memory management<graph-memory-management>`.

    .. warning::
        This API is in beta and may change in future releases.
    """
    return _graph_pool_handle()
43
+
44
+
45
+ # Python shim helps Sphinx process docstrings more reliably.
46
class CUDAGraph(torch._C._CUDAGraph):
    r"""Wrapper around a CUDA graph.

    Thin Python shim over :class:`torch._C._CUDAGraph`; every method simply
    delegates to the C++ implementation (the shim helps Sphinx render docs).

    .. warning::
        This API is in beta and may change in future releases.
    """

    def __new__(cls):
        return super().__new__(cls)

    def capture_begin(self, pool=None, capture_error_mode="global"):
        r"""Begin capturing CUDA work on the current stream.

        Typically, you shouldn't call ``capture_begin`` yourself.
        Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
        which call ``capture_begin`` internally.

        Arguments:
            pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
                :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
                with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
            capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
                Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
                may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
                actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
                unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
        """  # noqa: B950
        super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)

    def capture_end(self):
        r"""End CUDA graph capture on the current stream.

        After ``capture_end``, ``replay`` may be called on this instance.

        Typically, you shouldn't call ``capture_end`` yourself.
        Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
        which call ``capture_end`` internally.
        """
        super().capture_end()

    def replay(self):
        r"""Replay the CUDA work captured by this graph."""
        super().replay()

    def reset(self):
        r"""Delete the graph currently held by this instance."""
        super().reset()

    def pool(self):
        r"""Return an opaque token representing the id of this graph's memory pool.

        This id can optionally be passed to another graph's ``capture_begin``,
        which hints the other graph may share the same memory pool.
        """
        return super().pool()

    def enable_debug_mode(self):
        r"""Enable debugging mode for CUDAGraph.debug_dump."""
        return super().enable_debug_mode()

    def debug_dump(self, debug_path):
        r"""
        Arguments:
            debug_path (required): Path to dump the graph to.

        Calls a debugging function to dump the graph if the debugging is
        enabled via CUDAGraph.enable_debug_mode()
        """
        return super().debug_dump(debug_path)
115
+
116
+
117
+ class graph:
118
+ r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay.
119
+
120
+ See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
121
+ detailed use, and constraints.
122
+
123
+ Arguments:
124
+ cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
125
+ pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
126
+ :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting this graph's capture
127
+ may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
128
+ stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
129
+ If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
130
+ capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
131
+ Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
132
+ may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
133
+ actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting
134
+ unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
135
+
136
+ .. note::
137
+ For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
138
+ used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.
139
+
140
+ .. warning::
141
+ This API is in beta and may change in future releases.
142
+
143
+ .. _cudaStreamCaptureMode:
144
+ https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
145
+ """ # noqa: B950
146
+
147
+ default_capture_stream: typing.Optional["torch.cuda.Stream"] = None
148
+
149
+ def __init__(
150
+ self,
151
+ cuda_graph,
152
+ pool=None,
153
+ stream=None,
154
+ capture_error_mode: str = "global",
155
+ ):
156
+ # Lazy-init of default_capture_stream helps avoid circular-import errors.
157
+ # Not thread safe, but graphs already have the general (explicitly documented)
158
+ # restriction that only one capture may be underway at a time in the process.
159
+ if self.__class__.default_capture_stream is None:
160
+ self.__class__.default_capture_stream = torch.cuda.Stream()
161
+
162
+ self.pool = () if pool is None else (pool,)
163
+ self.capture_stream = (
164
+ stream if stream is not None else self.__class__.default_capture_stream
165
+ )
166
+ assert self.capture_stream is not None
167
+ self.stream_ctx = torch.cuda.stream(self.capture_stream)
168
+ self.cuda_graph = cuda_graph
169
+ self.capture_error_mode = capture_error_mode
170
+
171
+ def __enter__(self):
172
+ # Free as much memory as we can for the graph
173
+ torch.cuda.synchronize()
174
+ gc.collect()
175
+ torch.cuda.empty_cache()
176
+
177
+ # Stackoverflow seems comfortable with this pattern
178
+ # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
179
+ self.stream_ctx.__enter__()
180
+
181
+ self.cuda_graph.capture_begin(
182
+ *self.pool, capture_error_mode=self.capture_error_mode
183
+ )
184
+
185
+ def __exit__(self, exc_type, exc_value, traceback):
186
+ self.cuda_graph.capture_end()
187
+ self.stream_ctx.__exit__(exc_type, exc_value, traceback)
188
+ # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
189
+
190
+
191
+ def make_graphed_callables(
192
+ callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
193
+ ):
194
+ r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and returns graphed versions.
195
+
196
+ Each graphed callable's forward pass runs its source callable's
197
+ forward CUDA work as a CUDA graph inside a single autograd node.
198
+
199
+ The graphed callable's forward pass also appends
200
+ a backward node to the autograd graph. During backward, this node runs the
201
+ callable's backward work as a CUDA graph.
202
+
203
+ Therefore, each graphed callable should be a drop-in replacement for its source callable
204
+ in an autograd-enabled training loop.
205
+
206
+ See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.
207
+
208
+ If you pass a tuple of several callables, their captures will use the same memory pool.
209
+ See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.
210
+
211
+ Arguments:
212
+ callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
213
+ See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
214
+ is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order
215
+ they'll run in the live workload.
216
+ sample_args (tuple of Tensors, or tuple of tuples of Tensors): Samples args for each callable.
217
+ If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
218
+ If a tuple of callables was passed, ``sample_args`` must be tuple of tuples of argument Tensors.
219
+ num_warmup_iters (int): The number of warmup iterations. Currently, ``DataDistributedParallel`` needs
220
+ 11 iterations for warm up. Default: ``3``.
221
+ allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
222
+ (and therefore their grad is always zero) is an error. Defaults to False.
223
+ pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
224
+ :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
225
+ with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
226
+ .. note::
227
+ The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
228
+ that's expected for the corresponding real input in the training loop.
229
+
230
+ .. warning::
231
+ This API is in beta and may change in future releases.
232
+
233
+ .. warning::
234
+ ``sample_args`` for each callable must contain only Tensors. Other types are not allowed.
235
+
236
+ .. warning::
237
+ Returned callables do not support higher order differentiation (e.g., double backward).
238
+
239
+ .. warning::
240
+ In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
241
+ may be trainable. Buffers must have ``requires_grad=False``.
242
+
243
+ .. warning::
244
+ After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
245
+ you may not add or remove any of that Module's parameters or buffers.
246
+
247
+ .. warning::
248
+ :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
249
+ registered on them at the time they are passed. However, registering hooks on modules *after* passing them
250
+ through :func:`~torch.cuda.make_graphed_callables` is allowed.
251
+
252
+ .. warning::
253
+ When running a graphed callable, you must pass its arguments in the same order and format
254
+ they appeared in that callable's ``sample_args``.
255
+
256
+ .. warning::
257
+ The automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with disabled
258
+ caching. The context manager `torch.cuda.amp.autocast()` must have `cache_enabled=False`.
259
+ """
260
+ if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
261
+ raise RuntimeError(
262
+ "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
263
+ )
264
+
265
+ just_one_callable = False
266
+
267
+ if not isinstance(callables, tuple):
268
+ just_one_callable = True
269
+ callables = (callables,)
270
+ sample_args = (sample_args,)
271
+
272
+ flatten_sample_args = []
273
+
274
+ for c, args in zip(callables, sample_args):
275
+ if isinstance(c, torch.nn.Module):
276
+ assert (
277
+ len(c._backward_hooks) == 0
278
+ and len(c._forward_hooks) == 0
279
+ and len(c._forward_pre_hooks) == 0
280
+ ), (
281
+ "Modules must not have hooks registered at the time they are passed. However, registering hooks "
282
+ + "on modules after passing them through make_graphed_callables is allowed."
283
+ )
284
+ assert all(b.requires_grad is False for b in c.buffers()), (
285
+ "In any :class:`~torch.nn.Module` passed to "
286
+ + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have "
287
+ + "``requires_grad=False``."
288
+ )
289
+ flatten_arg = torch.utils._pytree.arg_tree_leaves(*args)
290
+ flatten_sample_args.append(tuple(flatten_arg))
291
+ assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
292
+ "In the beta API, sample_args "
293
+ + "for each callable must contain only Tensors. Other types are not allowed."
294
+ )
295
+
296
+ # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
297
+ # passes to forward (ie, its sample_args) AND the module's parameter attributes.
298
+ per_callable_len_user_args = [len(args) for args in flatten_sample_args]
299
+ per_callable_module_params = [
300
+ tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
301
+ for c in callables
302
+ ]
303
+ per_callable_static_input_surfaces = [
304
+ flatten_sample_args[i] + per_callable_module_params[i]
305
+ for i in range(len(callables))
306
+ ]
307
+
308
+ fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
309
+ bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
310
+
311
+ mempool = graph_pool_handle() if pool is None else pool
312
+
313
+ # Warmup
314
+ # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
315
+ # from ending up in any captures.
316
+ torch.cuda.synchronize()
317
+ with torch.cuda.stream(torch.cuda.Stream()):
318
+ for func, args, static_input_surface in zip(
319
+ callables, sample_args, per_callable_static_input_surfaces
320
+ ):
321
+ grad_inputs, outputs, outputs_grad = None, None, None
322
+ for _ in range(num_warmup_iters):
323
+ outputs = torch.utils._pytree.tree_leaves(func(*args))
324
+ outputs_grad = tuple(o for o in outputs if o.requires_grad)
325
+ if len(outputs_grad) > 0:
326
+ grad_inputs = torch.autograd.grad(
327
+ outputs=outputs_grad,
328
+ inputs=tuple(
329
+ i for i in static_input_surface if i.requires_grad
330
+ ),
331
+ grad_outputs=tuple(
332
+ torch.empty_like(o) for o in outputs if o.requires_grad
333
+ ),
334
+ only_inputs=True,
335
+ allow_unused=allow_unused_input,
336
+ )
337
+ for v in [outputs, outputs_grad, grad_inputs]:
338
+ del v
339
+
340
+ torch.cuda.synchronize()
341
+
342
+ # All captures here share a mempool. To avoid replays corrupting each other's memory,
343
+ # the safest approach is to capture all passes in the same order they'll run:
344
+ # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
345
+
346
+ # Capture forward graphs
347
+ per_callable_static_outputs = []
348
+ per_callable_output_unflatten_spec = []
349
+ for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
350
+ with torch.cuda.graph(fwd_graph, pool=mempool):
351
+ outputs = func(*args)
352
+
353
+ flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs)
354
+ per_callable_static_outputs.append(tuple(flatten_outputs))
355
+ per_callable_output_unflatten_spec.append(spec)
356
+
357
+ # Capture backward graphs in reverse order
358
+ per_callable_static_grad_outputs = []
359
+ per_callable_static_grad_inputs = []
360
+ for static_input_surface, static_outputs, bwd_graph, module_params in zip(
361
+ reversed(per_callable_static_input_surfaces),
362
+ reversed(per_callable_static_outputs),
363
+ reversed(bwd_graphs),
364
+ reversed(per_callable_module_params),
365
+ ):
366
+ # For now, assumes all static_outputs require grad
367
+ # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
368
+ static_grad_outputs = tuple(
369
+ torch.empty_like(o) if o.requires_grad else None for o in static_outputs
370
+ )
371
+
372
+ outputs_grad = tuple(o for o in static_outputs if o.requires_grad)
373
+ grad_inputs = None
374
+ if len(outputs_grad) > 0:
375
+ with torch.cuda.graph(bwd_graph, pool=mempool):
376
+ grad_inputs = torch.autograd.grad(
377
+ outputs=outputs_grad,
378
+ inputs=tuple(i for i in static_input_surface if i.requires_grad),
379
+ grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
380
+ only_inputs=True,
381
+ allow_unused=allow_unused_input,
382
+ )
383
+
384
+ # Constructs a tuple suitable for returning from Graphed.backward:
385
+ # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
386
+ # I couldn't think of a slick one-liner for this pattern.
387
+ static_grad_inputs = []
388
+ grad_idx = 0
389
+ for arg in static_input_surface:
390
+ if arg.requires_grad and grad_inputs is not None:
391
+ static_grad_inputs.append(grad_inputs[grad_idx])
392
+ grad_idx += 1
393
+ else:
394
+ static_grad_inputs.append(None) # type: ignore[arg-type]
395
+ static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment]
396
+
397
+ per_callable_static_grad_outputs.append(static_grad_outputs)
398
+ per_callable_static_grad_inputs.append(static_grad_inputs)
399
+
400
+ # Reverses the most recent two lists
401
+ per_callable_static_grad_outputs.reverse()
402
+ per_callable_static_grad_inputs.reverse()
403
+ # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
404
+
405
+ def make_graphed_autograd_function(
406
+ fwd_graph,
407
+ bwd_graph,
408
+ module_params,
409
+ len_user_args,
410
+ output_unflatten_spec,
411
+ static_input_surface,
412
+ static_outputs,
413
+ static_grad_outputs,
414
+ static_grad_inputs,
415
+ ):
416
+ class Graphed(torch.autograd.Function):
417
+ @staticmethod
418
+ def forward(ctx, *inputs):
419
+ # At this stage, only the user args may (potentially) be new tensors.
420
+ for i in range(len_user_args):
421
+ if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
422
+ static_input_surface[i].copy_(inputs[i])
423
+ fwd_graph.replay()
424
+ assert isinstance(static_outputs, tuple)
425
+ return tuple(o.detach() for o in static_outputs)
426
+
427
+ @staticmethod
428
+ @torch.autograd.function.once_differentiable
429
+ def backward(ctx, *grads):
430
+ assert len(grads) == len(static_grad_outputs)
431
+ for g, grad in zip(static_grad_outputs, grads):
432
+ if g is not None:
433
+ # don't copy if autograd gods have been kind and the
434
+ # incoming grad is already in the right place
435
+ if g.data_ptr() != grad.data_ptr():
436
+ g.copy_(grad)
437
+ bwd_graph.replay()
438
+
439
+ # Input args that didn't require grad expect a None gradient.
440
+ assert isinstance(static_grad_inputs, tuple)
441
+ return tuple(
442
+ b.detach() if b is not None else b for b in static_grad_inputs
443
+ )
444
+
445
+ def functionalized(*user_args):
446
+ # Runs the autograd function with inputs == all inputs to the graph that might require grad
447
+ # (explicit user args + module parameters)
448
+ # Assumes module params didn't change since capture.
449
+ flatten_user_args = torch.utils._pytree.arg_tree_leaves(*user_args)
450
+ out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
451
+ return torch.utils._pytree.tree_unflatten(out, output_unflatten_spec)
452
+
453
+ return functionalized
454
+
455
+ # Put together the final graphed callables
456
+ ret = []
457
+ for i, func in enumerate(callables):
458
+ graphed = make_graphed_autograd_function(
459
+ fwd_graphs[i],
460
+ bwd_graphs[i],
461
+ per_callable_module_params[i],
462
+ per_callable_len_user_args[i],
463
+ per_callable_output_unflatten_spec[i],
464
+ per_callable_static_input_surfaces[i],
465
+ per_callable_static_outputs[i],
466
+ per_callable_static_grad_outputs[i],
467
+ per_callable_static_grad_inputs[i],
468
+ )
469
+
470
+ if isinstance(func, torch.nn.Module):
471
+
472
+ def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
473
+ def new_fwd(*user_args):
474
+ # If the module's training-or-eval state matches what we graphed,
475
+ # run the graph, otherwise run the original forward method
476
+ if func.training == graph_training_state:
477
+ return graphed(*user_args)
478
+ else:
479
+ return orig_fwd(*user_args)
480
+
481
+ return new_fwd
482
+
483
+ func.forward = make_graphed_forward(func, func.training, graphed, func.forward) # type: ignore[assignment]
484
+ ret.append(func)
485
+ else:
486
+ ret.append(graphed)
487
+
488
+ if just_one_callable:
489
+ return ret[0]
490
+
491
+ return tuple(ret)
.venv/lib/python3.11/site-packages/torch/cuda/jiterator.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import re
3
+ from typing import Callable, List
4
+
5
+ import torch
6
+ from torch import Tensor
7
+
8
+
9
+ __all__: List[str] = []
10
+
11
+
12
+ class _CodeParser:
13
+ def __init__(self, code_string: str):
14
+ optional_ws = r"\s*"
15
+ required_ws = r"\s+"
16
+ template_params = r"(?P<template_params>\<.+\>)"
17
+ return_type = r"(?P<return_type>\w+)"
18
+ function_name = r"(?P<function_name>\w+)"
19
+ function_params = r"(?P<function_params>\(.+\))"
20
+ function_body = r"(?P<function_body>\{.+\})"
21
+
22
+ pattern = (
23
+ optional_ws
24
+ + "template"
25
+ + optional_ws
26
+ + template_params
27
+ + optional_ws
28
+ + return_type
29
+ + required_ws
30
+ + function_name
31
+ + optional_ws
32
+ + function_params
33
+ + optional_ws
34
+ + function_body
35
+ + optional_ws
36
+ )
37
+
38
+ result = re.match(
39
+ pattern, code_string, re.DOTALL
40
+ ) # DOTALL for matching multiline
41
+
42
+ if result is None:
43
+ raise Exception( # noqa: TRY002
44
+ f"Couldn't parse code, please check correctness:\n {code_string}"
45
+ )
46
+
47
+ self.template_params = result["template_params"]
48
+ self.return_type = result["return_type"]
49
+ self.function_name = result["function_name"]
50
+ self.function_params = result["function_params"]
51
+ self.function_body = result["function_body"]
52
+
53
+
54
+ class _JittedFunction:
55
+ def __init__(
56
+ self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
57
+ ):
58
+ self.code_string = code_string
59
+
60
+ assert (
61
+ return_by_ref or num_outputs == 1
62
+ ), "Return by value only works for single output. "
63
+ self.return_by_ref = return_by_ref
64
+ self.num_outputs = num_outputs
65
+
66
+ parsed_code = _CodeParser(code_string)
67
+ self.kernel_name = parsed_code.function_name
68
+
69
+ self.kwargs_dict = kwargs
70
+ self.is_cuda_available = torch.cuda.is_available()
71
+
72
+ def __call__(self, *tensors: Tensor, **kwargs):
73
+ # Jiterator follow torch.cuda's lazy initialization behavior
74
+ # Defer checking cuda's availability at the function invocation time
75
+ assert (
76
+ self.is_cuda_available
77
+ ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available."
78
+
79
+ assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."
80
+
81
+ expanded_kwargs = self.kwargs_dict.copy()
82
+ for key, value in kwargs.items():
83
+ if key in self.kwargs_dict:
84
+ expanded_kwargs[key] = value
85
+ else:
86
+ raise KeyError(f"{key} is not declared in function definition")
87
+
88
+ return torch._C._cuda_jiterator_compile_and_launch_kernel(
89
+ self.code_string,
90
+ self.kernel_name,
91
+ self.return_by_ref,
92
+ self.num_outputs,
93
+ tensors,
94
+ expanded_kwargs,
95
+ )
96
+
97
+
98
+ def _create_jit_fn(code_string: str, **kwargs) -> Callable:
99
+ """
100
+ Create a jiterator-generated cuda kernel for an elementwise op.
101
+
102
+ The code string has to be a valid CUDA function that describes the computation for a single element. The code
103
+ string has to follow the c++ template pattern, as shown in the example below. This function will be inlined
104
+ into elementwise kernel template, and compiled on the fly. Compiled kernel will be cached in memory, as well as
105
+ local temp dir.
106
+
107
+ Jiterator-generated kernels accepts noncontiguous tensors, and supports broadcasting and type promotion.
108
+
109
+ Args:
110
+ code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value.
111
+ kwargs (Dict, optional): Keyword arguments for generated function
112
+
113
+ Example::
114
+
115
+ code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
116
+ jitted_fn = create_jit_fn(code_string, alpha=1.0)
117
+ a = torch.rand(3, device='cuda')
118
+ b = torch.rand(3, device='cuda')
119
+ # invoke jitted function like a regular python function
120
+ result = jitted_fn(a, b, alpha=3.14)
121
+
122
+ code_string also allows multiple function definitions, and the last function will be treated as the entry function.
123
+
124
+ Example::
125
+
126
+ code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
127
+ code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
128
+ jitted_fn = create_jit_fn(code_string, val=0.0)
129
+ a = torch.rand(3, device='cuda')
130
+ b = torch.rand(3, device='cuda')
131
+ # invoke jitted function like a regular python function
132
+ result = jitted_fn(a, b) # using default val=0.0
133
+
134
+ Jiterator can be used together with python registration to override an operator's cuda kernel.
135
+ Following example is overriding gelu's cuda kernel with relu.
136
+
137
+ Example::
138
+
139
+ code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
140
+ my_gelu = create_jit_fn(code_string)
141
+ my_lib = torch.library.Library("aten", "IMPL")
142
+ my_lib.impl('aten::gelu', my_gelu, "CUDA")
143
+ # torch.nn.GELU and torch.nn.function.gelu are now overridden
144
+ a = torch.rand(3, device='cuda')
145
+ torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))
146
+
147
+ .. warning::
148
+ This API is in beta and may change in future releases.
149
+
150
+ .. warning::
151
+ This API only supports up to 8 inputs and 1 output
152
+
153
+ .. warning::
154
+ All input tensors must live in CUDA device
155
+ """
156
+ return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs)
157
+
158
+
159
+ def _create_multi_output_jit_fn(
160
+ code_string: str, num_outputs: int, **kwargs
161
+ ) -> Callable:
162
+ """
163
+ Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs.
164
+
165
+ Args:
166
+ code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return value by reference.
167
+ num_outputs(int): number of outputs return by the kernel
168
+ kwargs (Dict, optional): Keyword arguments for generated function
169
+
170
+ Example::
171
+
172
+ code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
173
+ jitted_fn = create_jit_fn(code_string, alpha=1.0)
174
+ a = torch.rand(3, device='cuda')
175
+ b = torch.rand(3, device='cuda')
176
+ # invoke jitted function like a regular python function
177
+ result = jitted_fn(a, b, alpha=3.14)
178
+
179
+ .. warning::
180
+ This API is in beta and may change in future releases.
181
+
182
+ .. warning::
183
+ This API only supports up to 8 inputs and 8 outputs
184
+ """
185
+ return _JittedFunction(
186
+ code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs
187
+ )
.venv/lib/python3.11/site-packages/torch/cuda/memory.py ADDED
@@ -0,0 +1,1041 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ r"""This package adds support for device memory management implemented in CUDA."""
3
+
4
+ import collections
5
+ import contextlib
6
+ import ctypes
7
+ import pickle
8
+ import sys
9
+ import warnings
10
+ from inspect import signature
11
+ from typing import Any, Dict, Optional, Tuple, Union
12
+ from typing_extensions import deprecated
13
+
14
+ import torch
15
+ from torch import _C
16
+ from torch._utils import _dummy_type
17
+ from torch.types import Device
18
+
19
+ from . import (
20
+ _get_amdsmi_device_index,
21
+ _get_device_index,
22
+ _get_nvml_device_index,
23
+ _lazy_init,
24
+ is_initialized,
25
+ )
26
+ from ._memory_viz import memory as _memory, segments as _segments
27
+
28
+
29
+ __all__ = [
30
+ "caching_allocator_alloc",
31
+ "caching_allocator_delete",
32
+ "set_per_process_memory_fraction",
33
+ "empty_cache",
34
+ "memory_stats",
35
+ "memory_stats_as_nested_dict",
36
+ "reset_accumulated_memory_stats",
37
+ "reset_peak_memory_stats",
38
+ "reset_max_memory_allocated",
39
+ "reset_max_memory_cached",
40
+ "memory_allocated",
41
+ "max_memory_allocated",
42
+ "memory_reserved",
43
+ "max_memory_reserved",
44
+ "memory_cached",
45
+ "max_memory_cached",
46
+ "memory_snapshot",
47
+ "memory_summary",
48
+ "list_gpu_processes",
49
+ "mem_get_info",
50
+ "get_allocator_backend",
51
+ "CUDAPluggableAllocator",
52
+ "change_current_allocator",
53
+ "MemPool",
54
+ "MemPoolContext",
55
+ "use_mem_pool",
56
+ ]
57
+
58
+
59
+ if not hasattr(torch._C, "_cuda_CUDAAllocator"):
60
+ # Define dummy base classes
61
+ torch._C.__dict__["_cuda_CUDAAllocator"] = _dummy_type("_cuda_CUDAAllocator")
62
+
63
+
64
+ if not hasattr(torch._C, "_MemPool"):
65
+ # Define dummy base classes
66
+ torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool")
67
+ torch._C.__dict__["_MemPoolContext"] = _dummy_type("_MemPoolContext")
68
+ torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
69
+ "_cuda_beginAllocateToPool"
70
+ )
71
+ torch._C.__dict__["_cuda_endAllocateCurrentStreamToPool"] = _dummy_type(
72
+ "_cuda_endAllocateCurrentStreamToPool"
73
+ )
74
+
75
+ from torch._C import ( # noqa: F401
76
+ _cuda_beginAllocateToPool,
77
+ _cuda_CUDAAllocator,
78
+ _cuda_endAllocateCurrentStreamToPool,
79
+ _MemPool,
80
+ _MemPoolContext,
81
+ )
82
+
83
+
84
+ def _host_allocator():
85
+ _lazy_init()
86
+ return torch._C._cuda_cudaHostAllocator()
87
+
88
+
89
+ @contextlib.contextmanager
90
+ def _free_mutex():
91
+ torch._C._cuda_lock_mutex()
92
+ try:
93
+ yield
94
+ finally:
95
+ torch._C._cuda_unlock_mutex()
96
+
97
+
98
+ def caching_allocator_alloc(size, device: Union[Device, int] = None, stream=None):
99
+ r"""Perform a memory allocation using the CUDA memory allocator.
100
+
101
+ Memory is allocated for a given device and a stream, this
102
+ function is intended to be used for interoperability with other
103
+ frameworks. Allocated memory is released through
104
+ :func:`~torch.cuda.caching_allocator_delete`.
105
+
106
+ Args:
107
+ size (int): number of bytes to be allocated.
108
+ device (torch.device or int, optional): selected device. If it is
109
+ ``None`` the default CUDA device is used.
110
+ stream (torch.cuda.Stream or int, optional): selected stream. If is ``None`` then
111
+ the default stream for the selected device is used.
112
+
113
+ .. note::
114
+ See :ref:`cuda-memory-management` for more details about GPU memory
115
+ management.
116
+ """
117
+ if device is None:
118
+ device = torch.cuda.current_device()
119
+ device = _get_device_index(device)
120
+ if stream is None:
121
+ stream = torch.cuda.current_stream(device)
122
+ if isinstance(stream, torch.cuda.streams.Stream):
123
+ stream = stream.cuda_stream
124
+ if not isinstance(stream, int):
125
+ raise TypeError(
126
+ "Invalid type for stream argument, must be "
127
+ "`torch.cuda.Stream` or `int` representing a pointer "
128
+ "to a existing stream"
129
+ )
130
+ with torch.cuda.device(device):
131
+ return torch._C._cuda_cudaCachingAllocator_raw_alloc(size, stream)
132
+
133
+
134
+ def caching_allocator_delete(mem_ptr):
135
+ r"""Delete memory allocated using the CUDA memory allocator.
136
+
137
+ Memory allocated with :func:`~torch.cuda.caching_allocator_alloc`.
138
+ is freed here. The associated device and stream are tracked inside
139
+ the allocator.
140
+
141
+ Args:
142
+ mem_ptr (int): memory address to be freed by the allocator.
143
+
144
+ .. note::
145
+ See :ref:`cuda-memory-management` for more details about GPU memory
146
+ management.
147
+ """
148
+ torch._C._cuda_cudaCachingAllocator_raw_delete(mem_ptr)
149
+
150
+
151
+ def set_per_process_memory_fraction(
152
+ fraction, device: Union[Device, int] = None
153
+ ) -> None:
154
+ r"""Set memory fraction for a process.
155
+
156
+ The fraction is used to limit an caching allocator to allocated memory on a CUDA device.
157
+ The allowed value equals the total visible memory multiplied fraction.
158
+ If trying to allocate more than the allowed value in a process, will raise an out of
159
+ memory error in allocator.
160
+
161
+ Args:
162
+ fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction.
163
+ device (torch.device or int, optional): selected device. If it is
164
+ ``None`` the default CUDA device is used.
165
+ .. note::
166
+ In general, the total available free memory is less than the total capacity.
167
+ """
168
+ _lazy_init()
169
+ if device is None:
170
+ device = torch.cuda.current_device()
171
+ device = _get_device_index(device)
172
+ if not isinstance(fraction, float):
173
+ raise TypeError("Invalid type for fraction argument, must be `float`")
174
+ if fraction < 0 or fraction > 1:
175
+ raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~1")
176
+
177
+ torch._C._cuda_setMemoryFraction(fraction, device)
178
+
179
+
180
+ def empty_cache() -> None:
181
+ r"""Release all unoccupied cached memory currently held by the caching
182
+ allocator so that those can be used in other GPU application and visible in
183
+ `nvidia-smi`.
184
+
185
+ .. note::
186
+ :func:`~torch.cuda.empty_cache` doesn't increase the amount of GPU
187
+ memory available for PyTorch. However, it may help reduce fragmentation
188
+ of GPU memory in certain cases. See :ref:`cuda-memory-management` for
189
+ more details about GPU memory management.
190
+ """
191
+ if is_initialized():
192
+ torch._C._cuda_emptyCache()
193
+
194
+
195
+ def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
196
+ r"""Return a dictionary of CUDA memory allocator statistics for a given device.
197
+
198
+ The return value of this function is a dictionary of statistics, each of
199
+ which is a non-negative integer.
200
+
201
+ Core statistics:
202
+
203
+ - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
204
+ number of allocation requests received by the memory allocator.
205
+ - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
206
+ amount of allocated memory.
207
+ - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
208
+ number of reserved segments from ``cudaMalloc()``.
209
+ - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
210
+ amount of reserved memory.
211
+ - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
212
+ number of active memory blocks.
213
+ - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
214
+ amount of active memory.
215
+ - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
216
+ number of inactive, non-releasable memory blocks.
217
+ - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
218
+ amount of inactive, non-releasable memory.
219
+
220
+ For these core statistics, values are broken down as follows.
221
+
222
+ Pool type:
223
+
224
+ - ``all``: combined statistics across all memory pools.
225
+ - ``large_pool``: statistics for the large allocation pool
226
+ (as of October 2019, for size >= 1MB allocations).
227
+ - ``small_pool``: statistics for the small allocation pool
228
+ (as of October 2019, for size < 1MB allocations).
229
+
230
+ Metric type:
231
+
232
+ - ``current``: current value of this metric.
233
+ - ``peak``: maximum value of this metric.
234
+ - ``allocated``: historical total increase in this metric.
235
+ - ``freed``: historical total decrease in this metric.
236
+
237
+ In addition to the core statistics, we also provide some simple event
238
+ counters:
239
+
240
+ - ``"num_alloc_retries"``: number of failed ``cudaMalloc`` calls that
241
+ result in a cache flush and retry.
242
+ - ``"num_ooms"``: number of out-of-memory errors thrown.
243
+ - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls.
244
+ - ``"num_device_alloc"``: number of CUDA allocation calls. This includes both
245
+ cuMemMap and cudaMalloc.
246
+ - ``"num_device_free"``: number of CUDA free calls. This includes both cuMemUnmap
247
+ and cudaFree.
248
+
249
+ The caching allocator can be configured via ENV to not split blocks larger than a
250
+ defined size (see Memory Management section of the Cuda Semantics documentation).
251
+ This helps avoid memory fragmentation but may have a performance
252
+ penalty. Additional outputs to assist with tuning and evaluating impact:
253
+
254
+ - ``"max_split_size"``: blocks above this size will not be split.
255
+ - ``"oversize_allocations.{current,peak,allocated,freed}"``:
256
+ number of over-size allocation requests received by the memory allocator.
257
+ - ``"oversize_segments.{current,peak,allocated,freed}"``:
258
+ number of over-size reserved segments from ``cudaMalloc()``.
259
+
260
+ The caching allocator can be configured via ENV to round memory allocations in order
261
+ to reduce fragmentation. Sometimes the overhead from rounding can be higher than
262
+ the fragmentation it helps reduce. The following stat can be used to check if
263
+ rounding adds too much overhead:
264
+
265
+ - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
266
+ memory requested by client code, compare this with allocated_bytes to check if
267
+ allocation rounding adds too much overhead.
268
+
269
+ Args:
270
+ device (torch.device or int, optional): selected device. Returns
271
+ statistics for the current device, given by :func:`~torch.cuda.current_device`,
272
+ if :attr:`device` is ``None`` (default).
273
+
274
+ .. note::
275
+ See :ref:`cuda-memory-management` for more details about GPU memory
276
+ management.
277
+
278
+ .. note::
279
+ With :ref:`backend:cudaMallocAsync<cuda-memory-envvars>`, some stats are not
280
+ meaningful, and are always reported as zero.
281
+ """
282
+ result = []
283
+
284
+ def _recurse_add_to_result(prefix, obj):
285
+ if isinstance(obj, dict):
286
+ if len(prefix) > 0:
287
+ prefix += "."
288
+ for k, v in obj.items():
289
+ _recurse_add_to_result(prefix + k, v)
290
+ else:
291
+ result.append((prefix, obj))
292
+
293
+ stats = memory_stats_as_nested_dict(device=device)
294
+ _recurse_add_to_result("", stats)
295
+ result.sort()
296
+
297
+ return collections.OrderedDict(result)
298
+
299
+
300
+ def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]:
301
+ r"""Return the result of :func:`~torch.cuda.memory_stats` as a nested dictionary."""
302
+ if not is_initialized():
303
+ return {}
304
+ device = _get_device_index(device, optional=True)
305
+ return torch._C._cuda_memoryStats(device)
306
+
307
+
308
+ def reset_accumulated_memory_stats(device: Union[Device, int] = None) -> None:
309
+ r"""Reset the "accumulated" (historical) stats tracked by the CUDA memory allocator.
310
+
311
+ See :func:`~torch.cuda.memory_stats` for details. Accumulated stats correspond to
312
+ the `"allocated"` and `"freed"` keys in each individual stat dict, as well as
313
+ `"num_alloc_retries"` and `"num_ooms"`.
314
+
315
+ Args:
316
+ device (torch.device or int, optional): selected device. Returns
317
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
318
+ if :attr:`device` is ``None`` (default).
319
+
320
+ .. note::
321
+ See :ref:`cuda-memory-management` for more details about GPU memory
322
+ management.
323
+ """
324
+ device = _get_device_index(device, optional=True)
325
+ return torch._C._cuda_resetAccumulatedMemoryStats(device)
326
+
327
+
328
+ def reset_peak_memory_stats(device: Union[Device, int] = None) -> None:
329
+ r"""Reset the "peak" stats tracked by the CUDA memory allocator.
330
+
331
+ See :func:`~torch.cuda.memory_stats` for details. Peak stats correspond to the
332
+ `"peak"` key in each individual stat dict.
333
+
334
+ Args:
335
+ device (torch.device or int, optional): selected device. Returns
336
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
337
+ if :attr:`device` is ``None`` (default).
338
+
339
+ .. note::
340
+ See :ref:`cuda-memory-management` for more details about GPU memory
341
+ management.
342
+ """
343
+ device = _get_device_index(device, optional=True)
344
+ return torch._C._cuda_resetPeakMemoryStats(device)
345
+
346
+
347
+ def reset_max_memory_allocated(device: Union[Device, int] = None) -> None:
348
+ r"""Reset the starting point in tracking maximum GPU memory occupied by tensors for a given device.
349
+
350
+ See :func:`~torch.cuda.max_memory_allocated` for details.
351
+
352
+ Args:
353
+ device (torch.device or int, optional): selected device. Returns
354
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
355
+ if :attr:`device` is ``None`` (default).
356
+
357
+ .. warning::
358
+ This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
359
+ /all/ peak memory stats.
360
+
361
+ .. note::
362
+ See :ref:`cuda-memory-management` for more details about GPU memory
363
+ management.
364
+ """
365
+ warnings.warn(
366
+ "torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, "
367
+ "which resets /all/ peak memory stats.",
368
+ FutureWarning,
369
+ )
370
+ return reset_peak_memory_stats(device=device)
371
+
372
+
373
+ def reset_max_memory_cached(device: Union[Device, int] = None) -> None:
374
+ r"""Reset the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
375
+
376
+ See :func:`~torch.cuda.max_memory_cached` for details.
377
+
378
+ Args:
379
+ device (torch.device or int, optional): selected device. Returns
380
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
381
+ if :attr:`device` is ``None`` (default).
382
+
383
+ .. warning::
384
+ This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
385
+ /all/ peak memory stats.
386
+
387
+ .. note::
388
+ See :ref:`cuda-memory-management` for more details about GPU memory
389
+ management.
390
+ """
391
+ warnings.warn(
392
+ "torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, "
393
+ "which resets /all/ peak memory stats.",
394
+ FutureWarning,
395
+ )
396
+ return reset_peak_memory_stats(device=device)
397
+
398
+
399
+ def memory_allocated(device: Union[Device, int] = None) -> int:
400
+ r"""Return the current GPU memory occupied by tensors in bytes for a given device.
401
+
402
+ Args:
403
+ device (torch.device or int, optional): selected device. Returns
404
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
405
+ if :attr:`device` is ``None`` (default).
406
+
407
+ .. note::
408
+ This is likely less than the amount shown in `nvidia-smi` since some
409
+ unused memory can be held by the caching allocator and some context
410
+ needs to be created on GPU. See :ref:`cuda-memory-management` for more
411
+ details about GPU memory management.
412
+ """
413
+ return memory_stats(device=device).get("allocated_bytes.all.current", 0)
414
+
415
+
416
+ def max_memory_allocated(device: Union[Device, int] = None) -> int:
417
+ r"""Return the maximum GPU memory occupied by tensors in bytes for a given device.
418
+
419
+ By default, this returns the peak allocated memory since the beginning of
420
+ this program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to
421
+ reset the starting point in tracking this metric. For example, these two
422
+ functions can measure the peak allocated memory usage of each iteration in a
423
+ training loop.
424
+
425
+ Args:
426
+ device (torch.device or int, optional): selected device. Returns
427
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
428
+ if :attr:`device` is ``None`` (default).
429
+
430
+ .. note::
431
+ See :ref:`cuda-memory-management` for more details about GPU memory
432
+ management.
433
+ """
434
+ return memory_stats(device=device).get("allocated_bytes.all.peak", 0)
435
+
436
+
437
+ def memory_reserved(device: Union[Device, int] = None) -> int:
438
+ r"""Return the current GPU memory managed by the caching allocator in bytes for a given device.
439
+
440
+ Args:
441
+ device (torch.device or int, optional): selected device. Returns
442
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
443
+ if :attr:`device` is ``None`` (default).
444
+
445
+ .. note::
446
+ See :ref:`cuda-memory-management` for more details about GPU memory
447
+ management.
448
+ """
449
+ return memory_stats(device=device).get("reserved_bytes.all.current", 0)
450
+
451
+
452
+ def max_memory_reserved(device: Union[Device, int] = None) -> int:
453
+ r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device.
454
+
455
+ By default, this returns the peak cached memory since the beginning of this
456
+ program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to reset
457
+ the starting point in tracking this metric. For example, these two functions
458
+ can measure the peak cached memory amount of each iteration in a training
459
+ loop.
460
+
461
+ Args:
462
+ device (torch.device or int, optional): selected device. Returns
463
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
464
+ if :attr:`device` is ``None`` (default).
465
+
466
+ .. note::
467
+ See :ref:`cuda-memory-management` for more details about GPU memory
468
+ management.
469
+ """
470
+ return memory_stats(device=device).get("reserved_bytes.all.peak", 0)
471
+
472
+
473
+ @deprecated(
474
+ "`torch.cuda.memory_cached` has been renamed to `torch.cuda.memory_reserved`",
475
+ category=FutureWarning,
476
+ )
477
+ def memory_cached(device: Union[Device, int] = None) -> int:
478
+ r"""Deprecated; see :func:`~torch.cuda.memory_reserved`."""
479
+ return memory_reserved(device=device)
480
+
481
+
482
+ @deprecated(
483
+ "`torch.cuda.max_memory_cached` has been renamed to `torch.cuda.max_memory_reserved`",
484
+ category=FutureWarning,
485
+ )
486
+ def max_memory_cached(device: Union[Device, int] = None) -> int:
487
+ r"""Deprecated; see :func:`~torch.cuda.max_memory_reserved`."""
488
+ return max_memory_reserved(device=device)
489
+
490
+
491
+ def memory_snapshot():
492
+ r"""Return a snapshot of the CUDA memory allocator state across all devices.
493
+
494
+ Interpreting the output of this function requires familiarity with the
495
+ memory allocator internals.
496
+
497
+ .. note::
498
+ See :ref:`cuda-memory-management` for more details about GPU memory
499
+ management.
500
+ """
501
+ return torch._C._cuda_memorySnapshot()["segments"]
502
+
503
+
504
+ def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) -> str:
505
+ r"""Return a human-readable printout of the current memory allocator statistics for a given device.
506
+
507
+ This can be useful to display periodically during training, or when
508
+ handling out-of-memory exceptions.
509
+
510
+ Args:
511
+ device (torch.device or int, optional): selected device. Returns
512
+ printout for the current device, given by :func:`~torch.cuda.current_device`,
513
+ if :attr:`device` is ``None`` (default).
514
+ abbreviated (bool, optional): whether to return an abbreviated summary
515
+ (default: False).
516
+
517
+ .. note::
518
+ See :ref:`cuda-memory-management` for more details about GPU memory
519
+ management.
520
+ """
521
+ device = _get_device_index(device, optional=True)
522
+ stats = memory_stats(device=device)
523
+
524
+ def _format_size(sz, pref_sz):
525
+ prefixes = ["B ", "KiB", "MiB", "GiB", "TiB", "PiB"]
526
+ prefix = prefixes[0]
527
+ for new_prefix in prefixes[1:]:
528
+ if pref_sz < 768 * 1024:
529
+ break
530
+ prefix = new_prefix
531
+ sz //= 1024
532
+ pref_sz /= 1024
533
+ return f"{sz:6d} {prefix}"
534
+
535
+ def _format_count(cnt, pref_cnt):
536
+ prefixes = [" ", "K", "M"]
537
+ prefix = prefixes[0]
538
+ for new_prefix in prefixes[1:]:
539
+ if pref_cnt < 750 * 1000:
540
+ break
541
+ prefix = new_prefix
542
+ cnt //= 1000
543
+ pref_cnt /= 1000
544
+ return f"{cnt:7d} {prefix} "
545
+
546
+ metrics_to_display = [
547
+ ("allocated_bytes", "Allocated memory", _format_size),
548
+ ("active_bytes", "Active memory", _format_size),
549
+ ("requested_bytes", "Requested memory", _format_size),
550
+ ("reserved_bytes", "GPU reserved memory", _format_size),
551
+ ("inactive_split_bytes", "Non-releasable memory", _format_size),
552
+ ("allocation", "Allocations", _format_count),
553
+ ("active", "Active allocs", _format_count),
554
+ ("segment", "GPU reserved segments", _format_count),
555
+ ("inactive_split", "Non-releasable allocs", _format_count),
556
+ ]
557
+
558
+ lines = []
559
+ lines.append("=" * 75)
560
+ lines.append(" {_:16} PyTorch CUDA memory summary, device ID {device:<17d} ")
561
+ lines.append("-" * 75)
562
+ lines.append(
563
+ " {_:9} CUDA OOMs: {num_ooms:<12d} | {_:6} cudaMalloc retries: {num_alloc_retries:<8d} "
564
+ )
565
+ lines.append("=" * 75)
566
+ lines.append(
567
+ " Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed "
568
+ )
569
+
570
+ for metric_key, metric_name, formatter in metrics_to_display:
571
+ lines.append("-" * 75)
572
+ submetrics = [("all", metric_name)]
573
+ if not abbreviated:
574
+ submetrics.append(("large_pool", " from large pool"))
575
+ submetrics.append(("small_pool", " from small pool"))
576
+
577
+ current_prefval, peak_prefval, allocated_prefval, freed_prefval = (
578
+ None,
579
+ None,
580
+ None,
581
+ None,
582
+ )
583
+
584
+ for submetric_key, submetric_name in submetrics:
585
+ prefix = metric_key + "." + submetric_key + "."
586
+
587
+ current = stats[prefix + "current"]
588
+ peak = stats[prefix + "peak"]
589
+ allocated = stats[prefix + "allocated"]
590
+ freed = stats[prefix + "freed"]
591
+
592
+ if current_prefval is None:
593
+ current_prefval = current
594
+ peak_prefval = peak
595
+ allocated_prefval = allocated
596
+ freed_prefval = freed
597
+
598
+ lines.append(
599
+ f" {submetric_name:<21} | {formatter(current, current_prefval)} | {formatter(peak, peak_prefval)} | "
600
+ f"{formatter(allocated, allocated_prefval)} | {formatter(freed, freed_prefval)} ",
601
+ )
602
+
603
+ metrics_to_display = [
604
+ ("oversize_allocations", "Oversize allocations", _format_count),
605
+ ("oversize_segments", "Oversize GPU segments", _format_count),
606
+ ]
607
+
608
+ for metric_key, metric_name, formatter in metrics_to_display:
609
+ lines.append("-" * 75)
610
+
611
+ prefix = metric_key + "."
612
+
613
+ current = stats[prefix + "current"]
614
+ peak = stats[prefix + "peak"]
615
+ allocated = stats[prefix + "allocated"]
616
+ freed = stats[prefix + "freed"]
617
+
618
+ lines.append(
619
+ f" {metric_name:<21} | {formatter(current, current)} | {formatter(peak, peak)} | "
620
+ f"{formatter(allocated, allocated)} | {formatter(freed, freed)} ",
621
+ )
622
+
623
+ lines.append("=" * 75)
624
+
625
+ fmt_dict = {"_": "", "device": device}
626
+ for k, v in stats.items():
627
+ fmt_dict[k.replace(".", "-")] = v
628
+ return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n"
629
+
630
+
631
+ def list_gpu_processes(device: Union[Device, int] = None) -> str:
632
+ r"""Return a human-readable printout of the running processes and their GPU memory use for a given device.
633
+
634
+ This can be useful to display periodically during training, or when
635
+ handling out-of-memory exceptions.
636
+
637
+ Args:
638
+ device (torch.device or int, optional): selected device. Returns
639
+ printout for the current device, given by :func:`~torch.cuda.current_device`,
640
+ if :attr:`device` is ``None`` (default).
641
+ """
642
+ if not torch.version.hip:
643
+ try:
644
+ import pynvml # type: ignore[import]
645
+ except ModuleNotFoundError:
646
+ return "pynvml module not found, please install pynvml"
647
+ from pynvml import NVMLError_DriverNotLoaded
648
+
649
+ try:
650
+ pynvml.nvmlInit()
651
+ except NVMLError_DriverNotLoaded:
652
+ return "cuda driver can't be loaded, is cuda enabled?"
653
+
654
+ device = _get_nvml_device_index(device)
655
+ handle = pynvml.nvmlDeviceGetHandleByIndex(device)
656
+ procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
657
+ else:
658
+ try:
659
+ import amdsmi # type: ignore[import]
660
+ except ModuleNotFoundError:
661
+ return "amdsmi module not found, please install amdsmi"
662
+ try:
663
+ amdsmi.amdsmi_init() # type: ignore[attr-defined]
664
+ except amdsmi.AmdSmiException: # type: ignore[attr-defined]
665
+ return "amdsmi driver can't be loaded, is ROCm installed?"
666
+
667
+ device = _get_amdsmi_device_index(device)
668
+
669
+ try:
670
+ handle = amdsmi.amdsmi_get_processor_handles()[device] # type: ignore[attr-defined]
671
+ procs = amdsmi.amdsmi_get_gpu_process_list(handle) # type: ignore[attr-defined]
672
+ except amdsmi.AmdSmiException: # type: ignore[attr-defined]
673
+ return "amdsmi cannot list processes from other users"
674
+
675
+ lines = []
676
+ lines.append(f"GPU:{device}")
677
+ if len(procs) == 0:
678
+ lines.append("no processes are running")
679
+ for p in procs:
680
+ if not torch.version.hip:
681
+ mem = p.usedGpuMemory / (1024 * 1024)
682
+ pid = p.pid
683
+ else:
684
+ try:
685
+ proc_info = amdsmi.amdsmi_get_gpu_process_info(handle, p) # type: ignore[possibly-undefined]
686
+ except AttributeError:
687
+ # https://github.com/ROCm/amdsmi/commit/c551c3caedbd903ba828e7fdffa5b56d475a15e7
688
+ # is a BC-breaking change that removes amdsmi_get_gpu_process_info API from amdsmi
689
+ proc_info = p
690
+ mem = proc_info["memory_usage"]["vram_mem"] / (1024 * 1024)
691
+ pid = proc_info["pid"]
692
+ lines.append(f"process {pid:>10d} uses {mem:>12.3f} MB GPU memory")
693
+ return "\n".join(lines)
694
+
695
+
696
+ def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
697
+ r"""Return the global free and total GPU memory for a given device using cudaMemGetInfo.
698
+
699
+ Args:
700
+ device (torch.device or int or str, optional): selected device. Returns
701
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
702
+ if :attr:`device` is ``None`` (default) or if the device index is not specified.
703
+
704
+ .. note::
705
+ See :ref:`cuda-memory-management` for more
706
+ details about GPU memory management.
707
+ """
708
+ if device is None:
709
+ device = torch.cuda.current_device()
710
+ # optional=True allows `device = torch.device('cuda')` for which device.index is None
711
+ device = _get_device_index(device, optional=True)
712
+ return torch.cuda.cudart().cudaMemGetInfo(device)
713
+
714
+
715
+ def _record_memory_history_legacy(
716
+ enabled: bool,
717
+ record_context=True,
718
+ trace_alloc_max_entries=1,
719
+ trace_alloc_record_context=False,
720
+ device: Union[Device, int] = None,
721
+ record_context_cpp=False,
722
+ ):
723
+ _C._cuda_record_memory_history_legacy(
724
+ enabled,
725
+ record_context,
726
+ trace_alloc_max_entries,
727
+ trace_alloc_record_context,
728
+ record_context_cpp,
729
+ )
730
+
731
+
732
+ def _record_memory_history(enabled="all", *args, **kwargs):
733
+ """Enable recording of stack traces associated with memory
734
+ allocations, so you can tell what allocated any piece of memory in
735
+ :func:`torch.cuda.memory._snapshot()`.
736
+
737
+ In addition too keeping stack traces with each current allocation and free,
738
+ this will also enable recording of a history of all alloc/free events.
739
+
740
+ Use :func:`torch.cuda.memory._snapshot()` to retrieve this information,
741
+ and the tools in `_memory_viz.py` to visualize snapshots.
742
+
743
+ The Python trace collection is fast (2us per trace), so you may consider
744
+ enabling this on production jobs if you anticipate ever having to debug
745
+ memory issues.
746
+
747
+ C++ trace collection is also fast (~50ns/frame), which for many typical programs
748
+ works out to ~2us per trace, but can vary depending on stack depth.
749
+
750
+ Args:
751
+ enabled (Literal[None, "state", "all"], optional):
752
+ `None`, disable recording memory history.
753
+ `"state"`, keep information for currenly allocated memory.
754
+ `"all"`, additionally keep a history of all alloc/free calls.
755
+ Defaults to "all".
756
+ context (Literal[None, "state", "alloc", "all"], optional):
757
+ `None`, Do not record any tracebacks.
758
+ `"state"`, Record tracebacks for currently allocated memory.
759
+ `"alloc"`, additionally keep tracebacks for alloc calls.
760
+ `"all"`, additionally keep tracebacks for free calls.
761
+ Defaults to "all".
762
+ stacks (Literal["python", "all"], optional):
763
+ `"python"`, include Python, TorchScript, and inductor frames in tracebacks
764
+ `"all"`, additionally include C++ frames
765
+ Defaults to "all".
766
+ max_entries (int, optional): Keep a maximum of `max_entries`
767
+ alloc/free events in the recorded history recorded.
768
+ """
769
+ if isinstance(enabled, bool):
770
+ return _record_memory_history_legacy(enabled, *args, **kwargs)
771
+ else:
772
+ return _record_memory_history_impl(enabled, *args, **kwargs)
773
+
774
+
775
+ def _record_memory_history_impl(
776
+ enabled: Optional[str] = "all",
777
+ context: Optional[str] = "all",
778
+ stacks: str = "all",
779
+ max_entries: int = sys.maxsize,
780
+ device: Union[Device, int] = None,
781
+ ):
782
+ _C._cuda_record_memory_history(enabled, context, stacks, max_entries)
783
+
784
+
785
+ _record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined]
786
+
787
+
788
+ def _snapshot(device: Union[Device, int] = None):
789
+ """Save a snapshot of CUDA memory state at the time it was called.
790
+
791
+ The state is represented as a dictionary with the following structure.
792
+
793
+ .. code-block:: python
794
+
795
+ class Snapshot(TypedDict):
796
+ segments : List[Segment]
797
+ device_traces: List[List[TraceEntry]]
798
+
799
+ class Segment(TypedDict):
800
+ # Segments are memory returned from a cudaMalloc call.
801
+ # The size of reserved memory is the sum of all Segments.
802
+ # Segments are cached and reused for future allocations.
803
+ # If the reuse is smaller than the segment, the segment
804
+ # is split into more then one Block.
805
+ # empty_cache() frees Segments that are entirely inactive.
806
+ address: int
807
+ total_size: int # cudaMalloc'd size of segment
808
+ stream: int
809
+ segment_type: Literal['small', 'large'] # 'large' (>1MB)
810
+ allocated_size: int # size of memory in use
811
+ active_size: int # size of memory in use or in active_awaiting_free state
812
+ blocks : List[Block]
813
+
814
+ class Block(TypedDict):
815
+ # A piece of memory returned from the allocator, or
816
+ # current cached but inactive.
817
+ size: int
818
+ requested_size: int # size requested during malloc, may be smaller than
819
+ # size due to rounding
820
+ address: int
821
+ state: Literal['active_allocated', # used by a tensor
822
+ 'active_awaiting_free', # waiting for another stream to finish using
823
+ # this, then it will become free
824
+ 'inactive',] # free for reuse
825
+ frames: List[Frame] # stack trace from where the allocation occurred
826
+
827
+ class Frame(TypedDict):
828
+ filename: str
829
+ line: int
830
+ name: str
831
+
832
+ class TraceEntry(TypedDict):
833
+ # When `torch.cuda.memory._record_memory_history()` is enabled,
834
+ # the snapshot will contain TraceEntry objects that record each
835
+ # action the allocator took.
836
+ action: Literal[
837
+ 'alloc' # memory allocated
838
+ 'free_requested', # the allocated received a call to free memory
839
+ 'free_completed', # the memory that was requested to be freed is now
840
+ # able to be used in future allocation calls
841
+ 'segment_alloc', # the caching allocator ask cudaMalloc for more memory
842
+ # and added it as a segment in its cache
843
+ 'segment_free', # the caching allocator called cudaFree to return memory
844
+ # to cuda possibly trying free up memory to
845
+ # allocate more segments or because empty_caches was called
846
+ 'oom', # the allocator threw an OOM exception. 'size' is
847
+ # the requested number of bytes that did not succeed
848
+ 'snapshot' # the allocator generated a memory snapshot
849
+ # useful to coorelate a previously taken
850
+ # snapshot with this trace
851
+ ]
852
+ addr: int # not present for OOM
853
+ frames: List[Frame]
854
+ size: int
855
+ stream: int
856
+ device_free: int # only present for OOM, the amount of
857
+ # memory cuda still reports to be free
858
+
859
+ Returns:
860
+ The Snapshot dictionary object
861
+ """
862
+ return _C._cuda_memorySnapshot()
863
+
864
+
865
+ def _dump_snapshot(filename="dump_snapshot.pickle"):
866
+ """
867
+ Save a pickled version of the `torch.memory._snapshot()` dictionary to a file.
868
+
869
+ This file can be opened by the interactive snapshot viewer at pytorch.org/memory_viz
870
+
871
+ Args:
872
+ filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle".
873
+ """
874
+ s = _snapshot()
875
+ with open(filename, "wb") as f:
876
+ pickle.dump(s, f)
877
+
878
+
879
+ def _save_segment_usage(filename="output.svg", snapshot=None):
880
+ if snapshot is None:
881
+ snapshot = _snapshot()
882
+ with open(filename, "w") as f:
883
+ f.write(_segments(snapshot))
884
+
885
+
886
+ def _save_memory_usage(filename="output.svg", snapshot=None):
887
+ if snapshot is None:
888
+ snapshot = _snapshot()
889
+ with open(filename, "w") as f:
890
+ f.write(_memory(snapshot))
891
+
892
+
893
+ def _set_allocator_settings(env: str):
894
+ return torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env)
895
+
896
+
897
+ def get_allocator_backend() -> str:
898
+ r"""Return a string describing the active allocator backend as set by
899
+ ``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are
900
+ ``native`` (PyTorch's native caching allocator) and `cudaMallocAsync``
901
+ (CUDA's built-in asynchronous allocator).
902
+
903
+ .. note::
904
+ See :ref:`cuda-memory-management` for details on choosing the allocator backend.
905
+ """
906
+ return torch._C._cuda_getAllocatorBackend()
907
+
908
+
909
+ class _CUDAAllocator:
910
+ r"""Wrapper over internal CUDA memory allocators."""
911
+
912
+ def __init__(self, allocator: torch._C._cuda_CUDAAllocator):
913
+ self._allocator = allocator
914
+
915
+ def allocator(self):
916
+ return self._allocator
917
+
918
+
919
+ class CUDAPluggableAllocator(_CUDAAllocator):
920
+ r"""CUDA memory allocator loaded from a so file."""
921
+
922
+ def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str):
923
+ r"""Memory allocators are compiled in .so files and loaded dynamically using ctypes.
924
+
925
+ To change the active allocator use the :func:`torch.memory.cuda.change_current_allocator` function.
926
+
927
+ Args:
928
+ path_to_so_file(str): Path in the filesystem to the `.so` file containing
929
+ the allocator functions
930
+ alloc_fn_name(str): Name of the function to perform the memory allocation
931
+ in the so file. The signature must be:
932
+ void* alloc_fn_name(ssize_t size, int device, cudaStream_t stream);
933
+ free_fn_name(str): Name of the function to perform the memory release
934
+ in the so file. The signature must be:
935
+ void free_fn_name(void* ptr, size_t size, cudaStream_t stream);
936
+
937
+ .. warning::
938
+ This is currently supported only in unix OSs
939
+
940
+ .. note::
941
+ See :ref:`cuda-memory-management` for details on creating and using a custom allocator
942
+ """
943
+ allocator = ctypes.CDLL(path_to_so_file)
944
+ alloc_fn = ctypes.cast(getattr(allocator, alloc_fn_name), ctypes.c_void_p).value
945
+ free_fn = ctypes.cast(getattr(allocator, free_fn_name), ctypes.c_void_p).value
946
+ assert alloc_fn is not None
947
+ assert free_fn is not None
948
+ self._allocator = torch._C._cuda_customAllocator(alloc_fn, free_fn)
949
+
950
+
951
+ def change_current_allocator(allocator: _CUDAAllocator) -> None:
952
+ r"""Change the currently used memory allocator to be the one provided.
953
+
954
+ If the current allocator has already been used/initialized, this function will error.
955
+
956
+
957
+ Args:
958
+ allocator (torch.cuda.memory._CUDAAllocator): allocator to be set as the active one.
959
+ .. note::
960
+ See :ref:`cuda-memory-management` for details on creating and using a custom allocator
961
+ """
962
+ torch._C._cuda_changeCurrentAllocator(allocator.allocator())
963
+
964
+
965
+ def _get_current_allocator() -> _CUDAAllocator:
966
+ r"""Return the allocator being currently used.
967
+
968
+ .. note::
969
+ See :ref:`cuda-memory-management` for details on creating and using a custom allocator
970
+ """
971
+ return _CUDAAllocator(torch._C._cuda_getAllocator())
972
+
973
+
974
+ class MemPool(_MemPool):
975
+ r"""MemPool represents a pool of memory in a caching allocator. Currently,
976
+ it's just the ID of the pool object maintained in the CUDACachingAllocator.
977
+
978
+ Args:
979
+ allocator(torch._C._cuda_CUDAAllocator, optional): a
980
+ torch._C._cuda_CUDAAllocator object that can be used to
981
+ define how memory gets allocated in the pool. If :attr:`allocator`
982
+ is ``None`` (default), memory allocation follows the default/
983
+ current configuration of the CUDACachingAllocator.
984
+
985
+ """
986
+
987
+ def __init__(self, allocator: Optional[_cuda_CUDAAllocator] = None):
988
+ super().__init__(allocator, True)
989
+
990
+ @property
991
+ def id(self) -> Tuple[int, int]:
992
+ r"""Returns the ID of this pool as a tuple of two ints."""
993
+ return super().id
994
+
995
+ @property
996
+ def allocator(self) -> Optional[_cuda_CUDAAllocator]:
997
+ r"""Returns the allocator this MemPool routes allocations to"""
998
+ return super().allocator
999
+
1000
+
1001
+ class MemPoolContext(_MemPoolContext):
1002
+ r"""MemPoolContext holds the currently active pool and stashes the previous
1003
+ pool. On deletion it makes the previous pool active.
1004
+
1005
+ Args:
1006
+ pool(torch.cuda.MemPool): a MemPool object to be made active so that
1007
+ allocations route to this pool.
1008
+
1009
+ """
1010
+
1011
+ def __init__(self, pool: MemPool):
1012
+ super().__init__(pool)
1013
+
1014
+ @staticmethod
1015
+ def active_pool() -> Optional[_MemPool]:
1016
+ r"""Returns the active MemPool"""
1017
+ return _MemPoolContext.active_pool()
1018
+
1019
+
1020
+ @contextlib.contextmanager
1021
+ def use_mem_pool(pool: MemPool, device: Union[Device, int] = None):
1022
+ r"""A context manager that routes allocations to a given pool.
1023
+
1024
+ Args:
1025
+ pool(torch.cuda.MemPool): a MemPool object to be made active so that
1026
+ allocations route to this pool.
1027
+ device (torch.device or int, optional): selected device. Uses MemPool on
1028
+ the current device, given by :func:`~torch.cuda.current_device`,
1029
+ if :attr:`device` is ``None`` (default).
1030
+
1031
+ """
1032
+ ctx = MemPoolContext(pool)
1033
+ device_index = (
1034
+ torch.cuda.current_device() if device is None else _get_device_index(device)
1035
+ )
1036
+ _cuda_beginAllocateToPool(device_index, pool.id)
1037
+ try:
1038
+ yield
1039
+ finally:
1040
+ _cuda_endAllocateCurrentStreamToPool(device_index, pool.id)
1041
+ del ctx
.venv/lib/python3.11/site-packages/torch/cuda/nccl.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import collections
3
+ import warnings
4
+ from typing import Optional, Sequence, Union
5
+
6
+ import torch.cuda
7
+
8
+
9
+ __all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"]
10
+
11
+ SUM = 0 # ncclRedOp_t
12
+
13
+
14
+ def is_available(tensors):
15
+ if not hasattr(torch._C, "_nccl_all_reduce"):
16
+ warnings.warn("PyTorch is not compiled with NCCL support")
17
+ return False
18
+
19
+ devices = set()
20
+ for tensor in tensors:
21
+ if tensor.is_sparse:
22
+ return False
23
+ if not tensor.is_contiguous():
24
+ return False
25
+ if not tensor.is_cuda:
26
+ return False
27
+ device = tensor.get_device()
28
+ if device in devices:
29
+ return False
30
+ devices.add(device)
31
+
32
+ return True
33
+
34
+
35
+ def version():
36
+ """
37
+ Returns the version of the NCCL.
38
+
39
+
40
+ This function returns a tuple containing the major, minor, and patch version numbers of the NCCL.
41
+ The suffix is also included in the tuple if a version suffix exists.
42
+ Returns:
43
+ tuple: The version information of the NCCL.
44
+ """
45
+ ver = torch._C._nccl_version()
46
+ major = ver >> 32
47
+ minor = (ver >> 16) & 65535
48
+ patch = ver & 65535
49
+ suffix = torch._C._nccl_version_suffix().decode("utf-8")
50
+ if suffix == "":
51
+ return (major, minor, patch)
52
+ else:
53
+ return (major, minor, patch, suffix)
54
+
55
+
56
+ def unique_id():
57
+ return torch._C._nccl_unique_id()
58
+
59
+
60
+ def init_rank(num_ranks, uid, rank):
61
+ return torch._C._nccl_init_rank(num_ranks, uid, rank)
62
+
63
+
64
+ def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
65
+ if not isinstance(inputs, collections.abc.Container) or isinstance(
66
+ inputs, torch.Tensor
67
+ ):
68
+ raise TypeError("Inputs should be a collection of tensors")
69
+
70
+
71
+ def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
72
+ _check_sequence_type(inputs)
73
+ if outputs is None:
74
+ outputs = inputs
75
+ _check_sequence_type(outputs)
76
+ torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)
77
+
78
+
79
+ # `output` used to be `outputs`, taking in a list of tensors. So we have two
80
+ # arguments for BC reasons.
81
+ def reduce(
82
+ inputs: Sequence[torch.Tensor],
83
+ output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
84
+ root: int = 0,
85
+ op: int = SUM,
86
+ streams: Optional[Sequence[torch.cuda.Stream]] = None,
87
+ comms=None,
88
+ *,
89
+ outputs: Optional[Sequence[torch.Tensor]] = None,
90
+ ) -> None:
91
+ _check_sequence_type(inputs)
92
+ _output: torch.Tensor
93
+ if outputs is not None:
94
+ if output is not None:
95
+ raise ValueError(
96
+ "'output' and 'outputs' can not be both specified. 'outputs' is deprecated in "
97
+ "favor of 'output', taking in a single output tensor. The signature of reduce is: "
98
+ "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)."
99
+ )
100
+ else:
101
+ warnings.warn(
102
+ "`nccl.reduce` with an output tensor list is deprecated. "
103
+ "Please specify a single output tensor with argument 'output' instead instead.",
104
+ FutureWarning,
105
+ stacklevel=2,
106
+ )
107
+ _output = outputs[root]
108
+ elif not isinstance(output, torch.Tensor) and isinstance(
109
+ output, collections.abc.Sequence
110
+ ):
111
+ # User called old API with positional arguments of list of output tensors.
112
+ warnings.warn(
113
+ "nccl.reduce with an output tensor list is deprecated. "
114
+ "Please specify a single output tensor.",
115
+ FutureWarning,
116
+ stacklevel=2,
117
+ )
118
+ _output = output[root]
119
+ else:
120
+ _output = inputs[root] if output is None else output
121
+ torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)
122
+
123
+
124
+ def broadcast(
125
+ inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None
126
+ ) -> None:
127
+ _check_sequence_type(inputs)
128
+ torch._C._nccl_broadcast(inputs, root, streams, comms)
129
+
130
+
131
+ def all_gather(
132
+ inputs: Sequence[torch.Tensor],
133
+ outputs: Sequence[torch.Tensor],
134
+ streams=None,
135
+ comms=None,
136
+ ) -> None:
137
+ _check_sequence_type(inputs)
138
+ _check_sequence_type(outputs)
139
+ torch._C._nccl_all_gather(inputs, outputs, streams, comms)
140
+
141
+
142
+ def reduce_scatter(
143
+ inputs: Sequence[torch.Tensor],
144
+ outputs: Sequence[torch.Tensor],
145
+ op: int = SUM,
146
+ streams=None,
147
+ comms=None,
148
+ ) -> None:
149
+ _check_sequence_type(inputs)
150
+ _check_sequence_type(outputs)
151
+ torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
.venv/lib/python3.11/site-packages/torch/cuda/nvtx.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ r"""This package adds support for NVIDIA Tools Extension (NVTX) used in profiling."""
3
+
4
+ from contextlib import contextmanager
5
+
6
+
7
+ try:
8
+ from torch._C import _nvtx
9
+ except ImportError:
10
+
11
+ class _NVTXStub:
12
+ @staticmethod
13
+ def _fail(*args, **kwargs):
14
+ raise RuntimeError(
15
+ "NVTX functions not installed. Are you sure you have a CUDA build?"
16
+ )
17
+
18
+ rangePushA = _fail
19
+ rangePop = _fail
20
+ markA = _fail
21
+
22
+ _nvtx = _NVTXStub() # type: ignore[assignment]
23
+
24
+ __all__ = ["range_push", "range_pop", "range_start", "range_end", "mark", "range"]
25
+
26
+
27
+ def range_push(msg):
28
+ """
29
+ Push a range onto a stack of nested range span. Returns zero-based depth of the range that is started.
30
+
31
+ Args:
32
+ msg (str): ASCII message to associate with range
33
+ """
34
+ return _nvtx.rangePushA(msg)
35
+
36
+
37
+ def range_pop():
38
+ """Pop a range off of a stack of nested range spans. Returns the zero-based depth of the range that is ended."""
39
+ return _nvtx.rangePop()
40
+
41
+
42
+ def range_start(msg) -> int:
43
+ """
44
+ Mark the start of a range with string message. It returns an unique handle
45
+ for this range to pass to the corresponding call to rangeEnd().
46
+
47
+ A key difference between this and range_push/range_pop is that the
48
+ range_start/range_end version supports range across threads (start on one
49
+ thread and end on another thread).
50
+
51
+ Returns: A range handle (uint64_t) that can be passed to range_end().
52
+
53
+ Args:
54
+ msg (str): ASCII message to associate with the range.
55
+ """
56
+ return _nvtx.rangeStartA(msg)
57
+
58
+
59
+ def range_end(range_id) -> None:
60
+ """
61
+ Mark the end of a range for a given range_id.
62
+
63
+ Args:
64
+ range_id (int): an unique handle for the start range.
65
+ """
66
+ _nvtx.rangeEnd(range_id)
67
+
68
+
69
+ def mark(msg):
70
+ """
71
+ Describe an instantaneous event that occurred at some point.
72
+
73
+ Args:
74
+ msg (str): ASCII message to associate with the event.
75
+ """
76
+ return _nvtx.markA(msg)
77
+
78
+
79
+ @contextmanager
80
+ def range(msg, *args, **kwargs):
81
+ """
82
+ Context manager / decorator that pushes an NVTX range at the beginning
83
+ of its scope, and pops it at the end. If extra arguments are given,
84
+ they are passed as arguments to msg.format().
85
+
86
+ Args:
87
+ msg (str): message to associate with the range
88
+ """
89
+ range_push(msg.format(*args, **kwargs))
90
+ try:
91
+ yield
92
+ finally:
93
+ range_pop()
.venv/lib/python3.11/site-packages/torch/cuda/profiler.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+ import tempfile
4
+
5
+ import torch
6
+
7
+ from . import check_error, cudart
8
+
9
+
10
+ __all__ = ["init", "start", "stop", "profile"]
11
+
12
+ DEFAULT_FLAGS = [
13
+ "gpustarttimestamp",
14
+ "gpuendtimestamp",
15
+ "gridsize3d",
16
+ "threadblocksize",
17
+ "streamid",
18
+ "enableonstart 0",
19
+ "conckerneltrace",
20
+ ]
21
+
22
+
23
+ def init(output_file, flags=None, output_mode="key_value"):
24
+ rt = cudart()
25
+ if not hasattr(rt, "cudaOutputMode"):
26
+ raise AssertionError("HIP does not support profiler initialization!")
27
+ if (
28
+ hasattr(torch.version, "cuda")
29
+ and torch.version.cuda is not None
30
+ and int(torch.version.cuda.split(".")[0]) >= 12
31
+ ):
32
+ # Check https://github.com/pytorch/pytorch/pull/91118
33
+ # cudaProfilerInitialize is no longer needed after CUDA 12
34
+ raise AssertionError("CUDA12+ does not need profiler initialization!")
35
+ flags = DEFAULT_FLAGS if flags is None else flags
36
+ if output_mode == "key_value":
37
+ output_mode_enum = rt.cudaOutputMode.KeyValuePair
38
+ elif output_mode == "csv":
39
+ output_mode_enum = rt.cudaOutputMode.CSV
40
+ else:
41
+ raise RuntimeError(
42
+ "supported CUDA profiler output modes are: key_value and csv"
43
+ )
44
+ with tempfile.NamedTemporaryFile(delete=True) as f:
45
+ f.write(b"\n".join(f.encode("ascii") for f in flags))
46
+ f.flush()
47
+ check_error(rt.cudaProfilerInitialize(f.name, output_file, output_mode_enum))
48
+
49
+
50
+ def start():
51
+ r"""Starts cuda profiler data collection.
52
+
53
+ .. warning::
54
+ Raises CudaError in case of it is unable to start the profiler.
55
+ """
56
+ check_error(cudart().cudaProfilerStart())
57
+
58
+
59
+ def stop():
60
+ r"""Stops cuda profiler data collection.
61
+
62
+ .. warning::
63
+ Raises CudaError in case of it is unable to stop the profiler.
64
+ """
65
+ check_error(cudart().cudaProfilerStop())
66
+
67
+
68
+ @contextlib.contextmanager
69
+ def profile():
70
+ """
71
+ Enable profiling.
72
+
73
+ Context Manager to enabling profile collection by the active profiling tool from CUDA backend.
74
+ Example:
75
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
76
+ >>> import torch
77
+ >>> model = torch.nn.Linear(20, 30).cuda()
78
+ >>> inputs = torch.randn(128, 20).cuda()
79
+ >>> with torch.cuda.profiler.profile() as prof:
80
+ ... model(inputs)
81
+ """
82
+ try:
83
+ start()
84
+ yield
85
+ finally:
86
+ stop()
.venv/lib/python3.11/site-packages/torch/cuda/random.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from typing import Iterable, List, Union
3
+
4
+ import torch
5
+ from torch import Tensor
6
+
7
+ from . import _lazy_call, _lazy_init, current_device, device_count
8
+
9
+
10
+ __all__ = [
11
+ "get_rng_state",
12
+ "get_rng_state_all",
13
+ "set_rng_state",
14
+ "set_rng_state_all",
15
+ "manual_seed",
16
+ "manual_seed_all",
17
+ "seed",
18
+ "seed_all",
19
+ "initial_seed",
20
+ ]
21
+
22
+
23
+ def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
24
+ r"""Return the random number generator state of the specified GPU as a ByteTensor.
25
+
26
+ Args:
27
+ device (torch.device or int, optional): The device to return the RNG state of.
28
+ Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
29
+
30
+ .. warning::
31
+ This function eagerly initializes CUDA.
32
+ """
33
+ _lazy_init()
34
+ if isinstance(device, str):
35
+ device = torch.device(device)
36
+ elif isinstance(device, int):
37
+ device = torch.device("cuda", device)
38
+ idx = device.index
39
+ if idx is None:
40
+ idx = current_device()
41
+ default_generator = torch.cuda.default_generators[idx]
42
+ return default_generator.get_state()
43
+
44
+
45
+ def get_rng_state_all() -> List[Tensor]:
46
+ r"""Return a list of ByteTensor representing the random number states of all devices."""
47
+ results = []
48
+ for i in range(device_count()):
49
+ results.append(get_rng_state(i))
50
+ return results
51
+
52
+
53
+ def set_rng_state(
54
+ new_state: Tensor, device: Union[int, str, torch.device] = "cuda"
55
+ ) -> None:
56
+ r"""Set the random number generator state of the specified GPU.
57
+
58
+ Args:
59
+ new_state (torch.ByteTensor): The desired state
60
+ device (torch.device or int, optional): The device to set the RNG state.
61
+ Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
62
+ """
63
+ with torch._C._DisableFuncTorch():
64
+ new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
65
+ if isinstance(device, str):
66
+ device = torch.device(device)
67
+ elif isinstance(device, int):
68
+ device = torch.device("cuda", device)
69
+
70
+ def cb():
71
+ idx = device.index
72
+ if idx is None:
73
+ idx = current_device()
74
+ default_generator = torch.cuda.default_generators[idx]
75
+ default_generator.set_state(new_state_copy)
76
+
77
+ _lazy_call(cb)
78
+
79
+
80
+ def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
81
+ r"""Set the random number generator state of all devices.
82
+
83
+ Args:
84
+ new_states (Iterable of torch.ByteTensor): The desired state for each device.
85
+ """
86
+ for i, state in enumerate(new_states):
87
+ set_rng_state(state, i)
88
+
89
+
90
+ def manual_seed(seed: int) -> None:
91
+ r"""Set the seed for generating random numbers for the current GPU.
92
+
93
+ It's safe to call this function if CUDA is not available; in that
94
+ case, it is silently ignored.
95
+
96
+ Args:
97
+ seed (int): The desired seed.
98
+
99
+ .. warning::
100
+ If you are working with a multi-GPU model, this function is insufficient
101
+ to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
102
+ """
103
+ seed = int(seed)
104
+
105
+ def cb():
106
+ idx = current_device()
107
+ default_generator = torch.cuda.default_generators[idx]
108
+ default_generator.manual_seed(seed)
109
+
110
+ _lazy_call(cb, seed=True)
111
+
112
+
113
+ def manual_seed_all(seed: int) -> None:
114
+ r"""Set the seed for generating random numbers on all GPUs.
115
+
116
+ It's safe to call this function if CUDA is not available; in that
117
+ case, it is silently ignored.
118
+
119
+ Args:
120
+ seed (int): The desired seed.
121
+ """
122
+ seed = int(seed)
123
+
124
+ def cb():
125
+ for i in range(device_count()):
126
+ default_generator = torch.cuda.default_generators[i]
127
+ default_generator.manual_seed(seed)
128
+
129
+ _lazy_call(cb, seed_all=True)
130
+
131
+
132
+ def seed() -> None:
133
+ r"""Set the seed for generating random numbers to a random number for the current GPU.
134
+
135
+ It's safe to call this function if CUDA is not available; in that
136
+ case, it is silently ignored.
137
+
138
+ .. warning::
139
+ If you are working with a multi-GPU model, this function will only initialize
140
+ the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
141
+ """
142
+
143
+ def cb():
144
+ idx = current_device()
145
+ default_generator = torch.cuda.default_generators[idx]
146
+ default_generator.seed()
147
+
148
+ _lazy_call(cb)
149
+
150
+
151
+ def seed_all() -> None:
152
+ r"""Set the seed for generating random numbers to a random number on all GPUs.
153
+
154
+ It's safe to call this function if CUDA is not available; in that
155
+ case, it is silently ignored.
156
+ """
157
+
158
+ def cb():
159
+ random_seed = 0
160
+ seeded = False
161
+ for i in range(device_count()):
162
+ default_generator = torch.cuda.default_generators[i]
163
+ if not seeded:
164
+ default_generator.seed()
165
+ random_seed = default_generator.initial_seed()
166
+ seeded = True
167
+ else:
168
+ default_generator.manual_seed(random_seed)
169
+
170
+ _lazy_call(cb)
171
+
172
+
173
+ def initial_seed() -> int:
174
+ r"""Return the current random seed of the current GPU.
175
+
176
+ .. warning::
177
+ This function eagerly initializes CUDA.
178
+ """
179
+ _lazy_init()
180
+ idx = current_device()
181
+ default_generator = torch.cuda.default_generators[idx]
182
+ return default_generator.initial_seed()
.venv/lib/python3.11/site-packages/torch/cuda/sparse.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # The Tensor classes are added to this module by python_tensor.cpp
.venv/lib/python3.11/site-packages/torch/cuda/streams.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import ctypes
3
+
4
+ import torch
5
+ from torch._streambase import _EventBase, _StreamBase
6
+ from torch._utils import _dummy_type
7
+
8
+
9
+ if not hasattr(torch._C, "_CudaStreamBase"):
10
+ # Define dummy base classes
11
+ torch._C.__dict__["_CudaStreamBase"] = _dummy_type("_CudaStreamBase")
12
+ torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")
13
+
14
+
15
+ class Stream(torch._C._CudaStreamBase, _StreamBase):
16
+ r"""Wrapper around a CUDA stream.
17
+
18
+ A CUDA stream is a linear sequence of execution that belongs to a specific
19
+ device, independent from other streams. See :ref:`cuda-semantics` for
20
+ details.
21
+
22
+ Args:
23
+ device(torch.device or int, optional): a device on which to allocate
24
+ the stream. If :attr:`device` is ``None`` (default) or a negative
25
+ integer, this will use the current device.
26
+ priority(int, optional): priority of the stream, should be 0 or
27
+ negative, where negative numbers indicate higher priority. By default,
28
+ streams have priority 0.
29
+
30
+ """
31
+
32
+ def __new__(cls, device=None, priority=0, **kwargs):
33
+ # setting device manager is expensive, so we avoid it unless necessary
34
+ if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
35
+ return super().__new__(cls, priority=priority, **kwargs)
36
+ else:
37
+ with torch.cuda.device(device):
38
+ return super().__new__(cls, priority=priority, **kwargs)
39
+
40
+ def wait_event(self, event) -> None:
41
+ r"""Make all future work submitted to the stream wait for an event.
42
+
43
+ Args:
44
+ event (torch.cuda.Event): an event to wait for.
45
+
46
+ .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
47
+ `CUDA Stream documentation`_ for more info.
48
+
49
+ This function returns without waiting for :attr:`event`: only future
50
+ operations are affected.
51
+
52
+ .. _CUDA Stream documentation:
53
+ https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
54
+ """
55
+ event.wait(self)
56
+
57
+ def wait_stream(self, stream) -> None:
58
+ r"""Synchronize with another stream.
59
+
60
+ All future work submitted to this stream will wait until all kernels
61
+ submitted to a given stream at the time of call complete.
62
+
63
+ Args:
64
+ stream (Stream): a stream to synchronize.
65
+
66
+ .. note:: This function returns without waiting for currently enqueued
67
+ kernels in :attr:`stream`: only future operations are affected.
68
+ """
69
+ self.wait_event(stream.record_event())
70
+
71
+ def record_event(self, event=None):
72
+ r"""Record an event.
73
+
74
+ Args:
75
+ event (torch.cuda.Event, optional): event to record. If not given, a new one
76
+ will be allocated.
77
+
78
+ Returns:
79
+ Recorded event.
80
+ """
81
+ if event is None:
82
+ event = Event()
83
+ event.record(self)
84
+ return event
85
+
86
+ def query(self) -> bool:
87
+ r"""Check if all the work submitted has been completed.
88
+
89
+ Returns:
90
+ A boolean indicating if all kernels in this stream are completed.
91
+ """
92
+ return super().query()
93
+
94
+ def synchronize(self) -> None:
95
+ r"""Wait for all the kernels in this stream to complete.
96
+
97
+ .. note:: This is a wrapper around ``cudaStreamSynchronize()``: see
98
+ `CUDA Stream documentation`_ for more info.
99
+ """
100
+ super().synchronize()
101
+
102
+ @property
103
+ def _as_parameter_(self):
104
+ return ctypes.c_void_p(self.cuda_stream)
105
+
106
+ def __eq__(self, o) -> bool:
107
+ if isinstance(o, Stream):
108
+ return super().__eq__(o)
109
+ return False
110
+
111
+ def __hash__(self):
112
+ return hash((self.cuda_stream, self.device))
113
+
114
+ def __repr__(self):
115
+ return f"<torch.cuda.Stream device={self.device} cuda_stream={self.cuda_stream:#x}>"
116
+
117
+
118
+ class ExternalStream(Stream):
119
+ r"""Wrapper around an externally allocated CUDA stream.
120
+
121
+ This class is used to wrap streams allocated in other libraries in order
122
+ to facilitate data exchange and multi-library interactions.
123
+
124
+ .. note:: This class doesn't manage the stream life-cycle, it is the user
125
+ responsibility to keep the referenced stream alive while this class is
126
+ being used.
127
+
128
+ Args:
129
+ stream_ptr(int): Integer representation of the `cudaStream_t` value.
130
+ allocated externally.
131
+ device(torch.device or int, optional): the device where the stream
132
+ was originally allocated. If device is specified incorrectly,
133
+ subsequent launches using this stream may fail.
134
+ """
135
+
136
+ def __new__(cls, stream_ptr, device=None, **kwargs):
137
+ with torch.cuda.device(device):
138
+ return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)
139
+
140
+
141
+ class Event(torch._C._CudaEventBase, _EventBase):
142
+ r"""Wrapper around a CUDA event.
143
+
144
+ CUDA events are synchronization markers that can be used to monitor the
145
+ device's progress, to accurately measure timing, and to synchronize CUDA
146
+ streams.
147
+
148
+ The underlying CUDA events are lazily initialized when the event is first
149
+ recorded or exported to another process. After creation, only streams on the
150
+ same device may record the event. However, streams on any device can wait on
151
+ the event.
152
+
153
+ Args:
154
+ enable_timing (bool, optional): indicates if the event should measure time
155
+ (default: ``False``)
156
+ blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
157
+ interprocess (bool): if ``True``, the event can be shared between processes
158
+ (default: ``False``)
159
+
160
+ .. _CUDA Event Documentation:
161
+ https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
162
+ """
163
+
164
+ def __new__(cls, enable_timing=False, blocking=False, interprocess=False):
165
+ return super().__new__(
166
+ cls,
167
+ enable_timing=enable_timing,
168
+ blocking=blocking,
169
+ interprocess=interprocess,
170
+ )
171
+
172
+ @classmethod
173
+ def from_ipc_handle(cls, device, handle):
174
+ r"""Reconstruct an event from an IPC handle on the given device."""
175
+ return super().from_ipc_handle(device, handle)
176
+
177
+ def record(self, stream=None):
178
+ r"""Record the event in a given stream.
179
+
180
+ Uses ``torch.cuda.current_stream()`` if no stream is specified. The
181
+ stream's device must match the event's device.
182
+ """
183
+ if stream is None:
184
+ stream = torch.cuda.current_stream()
185
+ super().record(stream)
186
+
187
+ def wait(self, stream=None) -> None:
188
+ r"""Make all future work submitted to the given stream wait for this event.
189
+
190
+ Use ``torch.cuda.current_stream()`` if no stream is specified.
191
+
192
+ .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
193
+ `CUDA Event documentation`_ for more info.
194
+ """
195
+ if stream is None:
196
+ stream = torch.cuda.current_stream()
197
+ super().wait(stream)
198
+
199
+ def query(self):
200
+ r"""Check if all work currently captured by event has completed.
201
+
202
+ Returns:
203
+ A boolean indicating if all work currently captured by event has
204
+ completed.
205
+ """
206
+ return super().query()
207
+
208
+ def elapsed_time(self, end_event):
209
+ r"""Return the time elapsed.
210
+
211
+ Time reported in milliseconds after the event was recorded and
212
+ before the end_event was recorded.
213
+ """
214
+ return super().elapsed_time(end_event)
215
+
216
+ def synchronize(self) -> None:
217
+ r"""Wait for the event to complete.
218
+
219
+ Waits until the completion of all work currently captured in this event.
220
+ This prevents the CPU thread from proceeding until the event completes.
221
+
222
+ .. note:: This is a wrapper around ``cudaEventSynchronize()``: see
223
+ `CUDA Event documentation`_ for more info.
224
+ """
225
+ super().synchronize()
226
+
227
+ def ipc_handle(self):
228
+ r"""Return an IPC handle of this event.
229
+
230
+ If not recorded yet, the event will use the current device.
231
+ """
232
+ return super().ipc_handle()
233
+
234
+ @property
235
+ def _as_parameter_(self):
236
+ return ctypes.c_void_p(self.cuda_event)
237
+
238
+ def __repr__(self) -> str:
239
+ if self.cuda_event:
240
+ return f"<torch.cuda.Event {self._as_parameter_.value:#x}>"
241
+ else:
242
+ return "<torch.cuda.Event uninitialized>"
.venv/lib/python3.11/site-packages/torch/cuda/tunable.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""
2
+ This module exposes a TunableOp interface.
3
+
4
+ Some operations, such as GEMMs, could be implemented using more than one library
5
+ or more than one technique. For example, a GEMM could be implemented for CUDA or
6
+ ROCm using either the blas or blasLt libraries. Further, ROCm's rocblas and
7
+ hipblaslt libraries allow the user to query for all possible algorithms and then
8
+ choose one. How does one know which implementation is the fastest and should be
9
+ chosen? That's what TunableOp provides.
10
+
11
+ Enabling TunableOp and Tuning Separately
12
+ ========================================
13
+
14
+ The TunableOp feature is enabled separately from enabling the tuning phase
15
+ itself. Enabling TunableOp means that PyTorch will replace any standard
16
+ operators with their Tunable implementations. Any call to a TunableOp first
17
+ checks whether it has already been tuned for the given operator inputs. If so,
18
+ it will immediately call the tuned operation; no further tuning will take place
19
+ even when the tuning setting is enabled. Instead if no tuning result is found,
20
+ and tuning is enabled, the TunableOp will benchmark every registered
21
+ implementation of that operator for the given set of inputs and select the
22
+ fastest.
23
+
24
+ File Input and Output
25
+ =====================
26
+
27
+ The first time any TunableOp is invoked, the internal database of tuned
28
+ operations will be prepared by attempting to read the results from the given
29
+ file. The default filename is 'tunableop_results.csv'. To support tuning when
30
+ multiple GPUs are used across multiple processes, the GPU device ordinal is
31
+ automatically inserted into the filename to avoid multiple processes overwriting
32
+ the same file.
33
+
34
+ If tuning is enabled and new tunings are discovered during the course of your
35
+ workload, it will also write out to this same filename with all tunings, both
36
+ the ones it read in at startup as well as the new ones found at runtime. This
37
+ can be used, for example, to build up a tunings file across many workloads by
38
+ reusing the same file. The output file is automatically created when the
39
+ application terminates. This behavior can be controlled by the C++ and Python
40
+ APIs but not the environment variables.
41
+
42
+ Assuming you specified a filename, you'll end up with a CSV file with contents
43
+ like so::
44
+
45
+ Validator,PT_VERSION,2.2.0
46
+ Validator,ROCM_VERSION,6.0.0.0-12969-1544e39
47
+ Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7
48
+ Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty
49
+ GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262
50
+ GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033
51
+
52
+ Note the "Validator" lines. If you change a library version, or ROCm version, or
53
+ PyTorch version, TunableOp will detect this and reject the tunings file because
54
+ the prior tunings are likely affected by other software changes.
55
+
56
+ The remaining lines are the tuned solutions for each TunableOp encountered
57
+ during your execution. Each line consists of 4 comma-separated fields: operator
58
+ name, operator parameters, solution name, and average execution time. The
59
+ execution time is an optional field. The CSV file can be edited, but with
60
+ caution. For example, the solution name (field 3) can be changed to "Default"
61
+ and it will fall back to the original PyTorch untuned implementation. Or, in the
62
+ case of ROCm's hipBLAS or hipBLASLt libraries, if you know the specific solution
63
+ index you can override the solution that TunableOp selected by replacing the
64
+ value. The operator name and parameters (fields 1 and 2) are internally named
65
+ and should not be modified. In the case of GemmTunableOp, field 1 indicates the
66
+ datatype and whether the inputs are transposed (T) or not (N) and field 2
67
+ indicates the M, N, K input shapes.
68
+
69
+ There is an option to enable verbose output but it is only recommended for
70
+ debugging purposes. This will produce a lot of diagnostic messages but may be
71
+ useful to see if TunableOp is being used at all. Otherwise, TunableOp is
72
+ completely silent, besides file output, unless there is a warning or error
73
+ during its use. The verbose option is only available by setting the environment
74
+ variable PYTORCH_TUNABLEOP_VERBOSE=1.
75
+
76
+ A Note on Tuning Behavior
77
+ =========================
78
+
79
+ Tuning an operator consists of iterating through the list of registered
80
+ implementations and profiling each one. The profile is established by running a
81
+ single implementation in a loop multiple times and taking the average execution
82
+ time.
83
+
84
+ By default, each possible solution for a given operator will be run for either
85
+ 100 iterations or as many iterations that can be run within 30ms, whichever is
86
+ smaller, and its average execution will be calculated. The fastest solution
87
+ among all that were successfully profiled will be chosen. A profile might fail
88
+ if the given solution doesn't achieve the same accuracy as the default
89
+ implementation or if the solution returns an error code.
90
+
91
+ Current Tunable Operators
92
+ =========================
93
+
94
+ TunableGemm for ROCm
95
+ --------------------
96
+
97
+ Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of
98
+ PyTorch will function correctly when using TunableOp but the only solution
99
+ available to CUDA builds is the 'Default' implementation i.e. the original
100
+ cuBLAS default, now called through TunableOp. Any call to at::cuda::blas::gemm()
101
+ or ::bgemm() will be routed through TunableOp when enabled. Calling gemm() for a
102
+ given set of input arguments (transa, transb, m, n, k) will attempt to use the
103
+ fastest available implementation across both rocblas and hipblaslt.
104
+
105
+ Tuning Context
106
+ ==============
107
+
108
+ The behavior of TunableOp is currently manipulated through environment
109
+ variables, the C++ interface of at::cuda::tunable::getTuningContext(), or the
110
+ torch.cuda.tunable python interfaces that wrap the C++ TuningContext. The
111
+ environment variables take precedence over any setting you manipulate using the
112
+ C++ or Python APIs.
113
+
114
+ """
115
+ from typing import Optional, Tuple
116
+
117
+ import torch
118
+
119
+
120
# Public API of torch.cuda.tunable; each entry wraps one torch._C binding
# onto the C++ TuningContext.
__all__ = [
    "enable",
    "is_enabled",
    "tuning_enable",
    "tuning_is_enabled",
    "set_max_tuning_duration",
    "get_max_tuning_duration",
    "set_max_tuning_iterations",
    "get_max_tuning_iterations",
    "set_filename",
    "get_filename",
    "get_results",
    "get_validators",
    "write_file_on_exit",
    "write_file",
    "read_file",
]
137
+
138
+
139
def enable(val: bool = True) -> None:
    r"""Globally switch all TunableOp implementations on or off."""
    torch._C._cuda_tunableop_enable(val)  # type: ignore[attr-defined]
142
+
143
+
144
def is_enabled() -> bool:
    r"""Report whether the TunableOp feature is currently enabled."""
    return torch._C._cuda_tunableop_is_enabled()  # type: ignore[attr-defined]
147
+
148
+
149
def tuning_enable(val: bool = True) -> None:
    r"""Turn the tuning phase of TunableOp on or off.

    While enabled, a lookup miss for an operator/input combination triggers
    a tuning run and the winning entry is recorded.
    """
    torch._C._cuda_tunableop_tuning_enable(val)  # type: ignore[attr-defined]
156
+
157
+
158
def tuning_is_enabled() -> bool:
    r"""Report whether TunableOp implementations may currently be tuned."""
    return torch._C._cuda_tunableop_tuning_is_enabled()  # type: ignore[attr-defined]
161
+
162
+
163
def set_max_tuning_duration(duration: int) -> None:
    r"""Limit the time spent tuning a single solution to ``duration`` milliseconds.

    When both a duration cap and an iteration cap are set, the stricter of
    the two is honored; at least one tuning iteration always runs.
    """
    torch._C._cuda_tunableop_set_max_tuning_duration(duration)  # type: ignore[attr-defined]
170
+
171
+
172
def get_max_tuning_duration() -> int:
    r"""Return the maximum time (ms) allowed for tuning a given solution."""
    return torch._C._cuda_tunableop_get_max_tuning_duration()  # type: ignore[attr-defined]
175
+
176
+
177
def set_max_tuning_iterations(iterations: int) -> None:
    r"""Limit the number of iterations spent tuning a single solution.

    When both a duration cap and an iteration cap are set, the stricter of
    the two is honored; at least one tuning iteration always runs.
    """
    torch._C._cuda_tunableop_set_max_tuning_iterations(iterations)  # type: ignore[attr-defined]
184
+
185
+
186
def get_max_tuning_iterations() -> int:
    r"""Return the maximum number of iterations allowed for tuning a solution."""
    return torch._C._cuda_tunableop_get_max_tuning_iterations()  # type: ignore[attr-defined]
189
+
190
+
191
def set_filename(filename: str, insert_device_ordinal: bool = False) -> None:
    r"""Set the filename to use for input/output of tuning results.

    If :attr:`insert_device_ordinal` is ``True`` then the current device ordinal
    will be added to the given filename automatically. This can be used in a
    1-process-per-gpu scenario to ensure all processes write to a separate file.
    """
    torch._C._cuda_tunableop_set_filename(filename, insert_device_ordinal)  # type: ignore[attr-defined]
199
+
200
+
201
def get_filename() -> str:
    r"""Return the filename currently used for tuning results."""
    return torch._C._cuda_tunableop_get_filename()  # type: ignore[attr-defined]
204
+
205
+
206
def get_results() -> Tuple[str, str, str, float]:
    r"""Return every TunableOp result currently held by the tuning context."""
    return torch._C._cuda_tunableop_get_results()  # type: ignore[attr-defined]
209
+
210
+
211
def get_validators() -> Tuple[str, str]:
    r"""Return the TunableOp validator entries (library/version pairs)."""
    return torch._C._cuda_tunableop_get_validators()  # type: ignore[attr-defined]
214
+
215
+
216
def write_file_on_exit(val: bool) -> None:
    r"""Control whether results are flushed to disk when the Tuning Context is destroyed.

    This is useful as a final flush of your results to disk if your
    application terminates as a result of normal operation or an error.
    Results can also be flushed explicitly at any time by calling
    ``write_file()``.
    """
    torch._C._cuda_tunableop_write_file_on_exit(val)  # type: ignore[attr-defined]
223
+
224
+
225
def write_file(filename: Optional[str] = None) -> bool:
    r"""Write results to a CSV file.

    Falls back to ``get_filename()`` when :attr:`filename` is omitted.
    """
    target = get_filename() if filename is None else filename
    return torch._C._cuda_tunableop_write_file(target)  # type: ignore[attr-defined]
233
+
234
+
235
def read_file(filename: Optional[str] = None) -> bool:
    r"""Read results from a TunableOp CSV file.

    Falls back to ``get_filename()`` when :attr:`filename` is omitted.
    """
    target = get_filename() if filename is None else filename
    return torch._C._cuda_tunableop_read_file(target)  # type: ignore[attr-defined]
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (4.29 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_compatibility.cpython-311.pyc ADDED
Binary file (2.04 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-311.pyc ADDED
Binary file (9.05 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_pytree.cpython-311.pyc ADDED
Binary file (5.86 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-311.pyc ADDED
Binary file (58.9 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_utils.cpython-311.pyc ADDED
Binary file (3 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/annotate.cpython-311.pyc ADDED
Binary file (1.56 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/config.cpython-311.pyc ADDED
Binary file (234 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph.cpython-311.pyc ADDED
Binary file (95 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph_module.cpython-311.pyc ADDED
Binary file (44.3 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/immutable_collections.cpython-311.pyc ADDED
Binary file (4.57 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/interpreter.cpython-311.pyc ADDED
Binary file (29.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/node.cpython-311.pyc ADDED
Binary file (43.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/operator_schemas.cpython-311.pyc ADDED
Binary file (24.1 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/proxy.cpython-311.pyc ADDED
Binary file (33.9 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-311.pyc ADDED
Binary file (15.9 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/tensor_type.cpython-311.pyc ADDED
Binary file (5.79 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/traceback.cpython-311.pyc ADDED
Binary file (4.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (806 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-311.pyc ADDED
Binary file (2.29 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-311.pyc ADDED
Binary file (4.93 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-311.pyc ADDED
Binary file (21.9 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-311.pyc ADDED
Binary file (5.92 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-311.pyc ADDED
Binary file (4.97 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-311.pyc ADDED
Binary file (43 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-311.pyc ADDED
Binary file (11.4 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-311.pyc ADDED
Binary file (4.49 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-311.pyc ADDED
Binary file (10.4 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-311.pyc ADDED
Binary file (30.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-311.pyc ADDED
Binary file (25.7 kB). View file