Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/torch/cuda/__init__.py +1661 -0
- .venv/lib/python3.11/site-packages/torch/cuda/__pycache__/memory.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/cuda/_gpu_trace.py +75 -0
- .venv/lib/python3.11/site-packages/torch/cuda/_memory_viz.py +632 -0
- .venv/lib/python3.11/site-packages/torch/cuda/_sanitizer.py +621 -0
- .venv/lib/python3.11/site-packages/torch/cuda/_utils.py +38 -0
- .venv/lib/python3.11/site-packages/torch/cuda/comm.py +19 -0
- .venv/lib/python3.11/site-packages/torch/cuda/error.py +0 -0
- .venv/lib/python3.11/site-packages/torch/cuda/gds.py +129 -0
- .venv/lib/python3.11/site-packages/torch/cuda/graphs.py +491 -0
- .venv/lib/python3.11/site-packages/torch/cuda/jiterator.py +187 -0
- .venv/lib/python3.11/site-packages/torch/cuda/memory.py +1041 -0
- .venv/lib/python3.11/site-packages/torch/cuda/nccl.py +151 -0
- .venv/lib/python3.11/site-packages/torch/cuda/nvtx.py +93 -0
- .venv/lib/python3.11/site-packages/torch/cuda/profiler.py +86 -0
- .venv/lib/python3.11/site-packages/torch/cuda/random.py +182 -0
- .venv/lib/python3.11/site-packages/torch/cuda/sparse.py +1 -0
- .venv/lib/python3.11/site-packages/torch/cuda/streams.py +242 -0
- .venv/lib/python3.11/site-packages/torch/cuda/tunable.py +242 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/_compatibility.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/_pytree.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/annotate.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/immutable_collections.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/interpreter.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/node.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/operator_schemas.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/proxy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/tensor_type.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__pycache__/traceback.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-311.pyc +0 -0
.gitattributes
CHANGED
|
@@ -127,3 +127,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 127 |
.venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 128 |
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text
|
| 129 |
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 127 |
.venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 128 |
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text
|
| 129 |
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 130 |
+
.venv/lib/python3.11/site-packages/torch/nn/__pycache__/functional.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/torch/cuda/__init__.py
ADDED
|
@@ -0,0 +1,1661 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
r"""
|
| 3 |
+
This package adds support for CUDA tensor types.
|
| 4 |
+
|
| 5 |
+
It implements the same function as CPU tensors, but they utilize
|
| 6 |
+
GPUs for computation.
|
| 7 |
+
|
| 8 |
+
It is lazily initialized, so you can always import it, and use
|
| 9 |
+
:func:`is_available()` to determine if your system supports CUDA.
|
| 10 |
+
|
| 11 |
+
:ref:`cuda-semantics` has more details about working with CUDA.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import importlib
|
| 15 |
+
import os
|
| 16 |
+
import threading
|
| 17 |
+
import traceback
|
| 18 |
+
import warnings
|
| 19 |
+
from functools import lru_cache
|
| 20 |
+
from typing import Any, Callable, cast, List, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch._C
|
| 24 |
+
from torch import device as _device
|
| 25 |
+
from torch._utils import _dummy_type, _LazySeedTracker, classproperty
|
| 26 |
+
from torch.types import Device
|
| 27 |
+
|
| 28 |
+
from . import gds
|
| 29 |
+
from ._utils import _get_device_index
|
| 30 |
+
from .graphs import (
|
| 31 |
+
CUDAGraph,
|
| 32 |
+
graph,
|
| 33 |
+
graph_pool_handle,
|
| 34 |
+
is_current_stream_capturing,
|
| 35 |
+
make_graphed_callables,
|
| 36 |
+
)
|
| 37 |
+
from .streams import Event, ExternalStream, Stream
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
try:
|
| 41 |
+
from torch._C import _cudart # type: ignore[attr-defined]
|
| 42 |
+
except ImportError:
|
| 43 |
+
_cudart = None
|
| 44 |
+
|
| 45 |
+
_initialized = False
|
| 46 |
+
_tls = threading.local()
|
| 47 |
+
_initialization_lock = threading.Lock()
|
| 48 |
+
_queued_calls: List[
|
| 49 |
+
Tuple[Callable[[], None], List[str]]
|
| 50 |
+
] = [] # don't invoke these until initialization occurs
|
| 51 |
+
_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
|
| 52 |
+
_device_t = Union[_device, str, int, None]
|
| 53 |
+
|
| 54 |
+
_HAS_PYNVML = False
|
| 55 |
+
_PYNVML_ERR = None
|
| 56 |
+
try:
|
| 57 |
+
from torch import version as _version
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
if not _version.hip:
|
| 61 |
+
import pynvml # type: ignore[import]
|
| 62 |
+
else:
|
| 63 |
+
import amdsmi # type: ignore[import]
|
| 64 |
+
|
| 65 |
+
_HAS_PYNVML = True
|
| 66 |
+
except ModuleNotFoundError:
|
| 67 |
+
pass
|
| 68 |
+
finally:
|
| 69 |
+
del _version
|
| 70 |
+
except ImportError as err:
|
| 71 |
+
_PYNVML_ERR = err # sometimes a lib is installed but the import fails for some other reason, so we log the error for later
|
| 72 |
+
|
| 73 |
+
_lazy_seed_tracker = _LazySeedTracker()
|
| 74 |
+
|
| 75 |
+
# Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA
|
| 76 |
+
if hasattr(torch._C, "_CudaDeviceProperties"):
|
| 77 |
+
_CudaDeviceProperties = torch._C._CudaDeviceProperties
|
| 78 |
+
else:
|
| 79 |
+
_CudaDeviceProperties = _dummy_type("_CudaDeviceProperties") # type: ignore[assignment, misc]
|
| 80 |
+
|
| 81 |
+
if hasattr(torch._C, "_cuda_exchangeDevice"):
|
| 82 |
+
_exchange_device = torch._C._cuda_exchangeDevice
|
| 83 |
+
else:
|
| 84 |
+
|
| 85 |
+
def _exchange_device(device: int) -> int:
|
| 86 |
+
if device < 0:
|
| 87 |
+
return -1
|
| 88 |
+
raise RuntimeError("PyTorch was compiled without CUDA support")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if hasattr(torch._C, "_cuda_maybeExchangeDevice"):
|
| 92 |
+
_maybe_exchange_device = torch._C._cuda_maybeExchangeDevice
|
| 93 |
+
else:
|
| 94 |
+
|
| 95 |
+
def _maybe_exchange_device(device: int) -> int:
|
| 96 |
+
if device < 0:
|
| 97 |
+
return -1
|
| 98 |
+
raise RuntimeError("PyTorch was compiled without CUDA support")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
has_half: bool = True
|
| 102 |
+
has_magma: bool = torch._C._has_magma
|
| 103 |
+
|
| 104 |
+
default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _is_compiled() -> bool:
|
| 108 |
+
r"""Return true if compile with CUDA support."""
|
| 109 |
+
return hasattr(torch._C, "_cuda_getDeviceCount")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _nvml_based_avail() -> bool:
|
| 113 |
+
return os.getenv("PYTORCH_NVML_BASED_CUDA_CHECK") == "1"
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def is_available() -> bool:
|
| 117 |
+
r"""Return a bool indicating if CUDA is currently available."""
|
| 118 |
+
if not _is_compiled():
|
| 119 |
+
return False
|
| 120 |
+
if _nvml_based_avail():
|
| 121 |
+
# The user has set an env variable to request this availability check that attempts to avoid fork poisoning by
|
| 122 |
+
# using NVML at the cost of a weaker CUDA availability assessment. Note that if NVML discovery/initialization
|
| 123 |
+
# fails, this assessment falls back to the default CUDA Runtime API assessment (`cudaGetDeviceCount`)
|
| 124 |
+
return device_count() > 0
|
| 125 |
+
else:
|
| 126 |
+
# The default availability inspection never throws and returns 0 if the driver is missing or can't
|
| 127 |
+
# be initialized. This uses the CUDA Runtime API `cudaGetDeviceCount` which in turn initializes the CUDA Driver
|
| 128 |
+
# API via `cuInit`
|
| 129 |
+
return torch._C._cuda_getDeviceCount() > 0
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def is_bf16_supported(including_emulation: bool = True):
|
| 133 |
+
r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
|
| 134 |
+
# Check for ROCm, if true return true, no ROCM_VERSION check required,
|
| 135 |
+
# since it is supported on AMD GPU archs.
|
| 136 |
+
if torch.version.hip:
|
| 137 |
+
return True
|
| 138 |
+
|
| 139 |
+
# If CUDA is not available, than it does not support bf16 either
|
| 140 |
+
if not is_available():
|
| 141 |
+
return False
|
| 142 |
+
|
| 143 |
+
device = torch.cuda.current_device()
|
| 144 |
+
|
| 145 |
+
# Check for CUDA version and device compute capability.
|
| 146 |
+
# This is a fast way to check for it.
|
| 147 |
+
cuda_version = torch.version.cuda
|
| 148 |
+
if (
|
| 149 |
+
cuda_version is not None
|
| 150 |
+
and int(cuda_version.split(".")[0]) >= 11
|
| 151 |
+
and torch.cuda.get_device_properties(device).major >= 8
|
| 152 |
+
):
|
| 153 |
+
return True
|
| 154 |
+
|
| 155 |
+
if not including_emulation:
|
| 156 |
+
return False
|
| 157 |
+
|
| 158 |
+
# Finally try to create a bfloat16 device.
|
| 159 |
+
return _check_bf16_tensor_supported(device)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
@lru_cache(maxsize=16)
|
| 163 |
+
def _check_bf16_tensor_supported(device: _device_t):
|
| 164 |
+
try:
|
| 165 |
+
torch.tensor([1.0], dtype=torch.bfloat16, device=device)
|
| 166 |
+
return True
|
| 167 |
+
except Exception:
|
| 168 |
+
return False
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _sleep(cycles):
|
| 172 |
+
torch._C._cuda_sleep(cycles)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _extract_arch_version(arch_string: str):
|
| 176 |
+
"""Extracts the architecture string from a CUDA version"""
|
| 177 |
+
base = arch_string.split("_")[1]
|
| 178 |
+
if base.endswith("a"):
|
| 179 |
+
base = base[:-1]
|
| 180 |
+
return int(base)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _check_capability():
|
| 184 |
+
incorrect_binary_warn = """
|
| 185 |
+
Found GPU%d %s which requires CUDA_VERSION >= %d to
|
| 186 |
+
work properly, but your PyTorch was compiled
|
| 187 |
+
with CUDA_VERSION %d. Please install the correct PyTorch binary
|
| 188 |
+
using instructions from https://pytorch.org
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
old_gpu_warn = """
|
| 192 |
+
Found GPU%d %s which is of cuda capability %d.%d.
|
| 193 |
+
PyTorch no longer supports this GPU because it is too old.
|
| 194 |
+
The minimum cuda capability supported by this library is %d.%d.
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
if torch.version.cuda is not None: # on ROCm we don't want this check
|
| 198 |
+
CUDA_VERSION = torch._C._cuda_getCompiledVersion()
|
| 199 |
+
for d in range(device_count()):
|
| 200 |
+
capability = get_device_capability(d)
|
| 201 |
+
major = capability[0]
|
| 202 |
+
minor = capability[1]
|
| 203 |
+
name = get_device_name(d)
|
| 204 |
+
current_arch = major * 10 + minor
|
| 205 |
+
min_arch = min(
|
| 206 |
+
(_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()),
|
| 207 |
+
default=35,
|
| 208 |
+
)
|
| 209 |
+
if current_arch < min_arch:
|
| 210 |
+
warnings.warn(
|
| 211 |
+
old_gpu_warn
|
| 212 |
+
% (d, name, major, minor, min_arch // 10, min_arch % 10)
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def _check_cubins():
|
| 217 |
+
incompatible_device_warn = """
|
| 218 |
+
{} with CUDA capability sm_{} is not compatible with the current PyTorch installation.
|
| 219 |
+
The current PyTorch install supports CUDA capabilities {}.
|
| 220 |
+
If you want to use the {} GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/
|
| 221 |
+
"""
|
| 222 |
+
if torch.version.cuda is None: # on ROCm we don't want this check
|
| 223 |
+
return
|
| 224 |
+
arch_list = get_arch_list()
|
| 225 |
+
if len(arch_list) == 0:
|
| 226 |
+
return
|
| 227 |
+
supported_sm = [_extract_arch_version(arch) for arch in arch_list if "sm_" in arch]
|
| 228 |
+
for idx in range(device_count()):
|
| 229 |
+
cap_major, cap_minor = get_device_capability(idx)
|
| 230 |
+
# NVIDIA GPU compute architectures are backward compatible within major version
|
| 231 |
+
supported = any(sm // 10 == cap_major for sm in supported_sm)
|
| 232 |
+
if not supported:
|
| 233 |
+
device_name = get_device_name(idx)
|
| 234 |
+
capability = cap_major * 10 + cap_minor
|
| 235 |
+
warnings.warn(
|
| 236 |
+
incompatible_device_warn.format(
|
| 237 |
+
device_name, capability, " ".join(arch_list), device_name
|
| 238 |
+
)
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def is_initialized():
|
| 243 |
+
r"""Return whether PyTorch's CUDA state has been initialized."""
|
| 244 |
+
return _initialized and not _is_in_bad_fork()
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def _lazy_call(callable, **kwargs):
|
| 248 |
+
if is_initialized():
|
| 249 |
+
callable()
|
| 250 |
+
else:
|
| 251 |
+
# TODO(torch_deploy): this accesses linecache, which attempts to read the
|
| 252 |
+
# file system to get traceback info. Patch linecache or do something
|
| 253 |
+
# else here if this ends up being important.
|
| 254 |
+
global _lazy_seed_tracker
|
| 255 |
+
if kwargs.get("seed_all", False):
|
| 256 |
+
_lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
|
| 257 |
+
elif kwargs.get("seed", False):
|
| 258 |
+
_lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
|
| 259 |
+
else:
|
| 260 |
+
# Don't store the actual traceback to avoid memory cycle
|
| 261 |
+
_queued_calls.append((callable, traceback.format_stack()))
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
_lazy_call(_check_capability)
|
| 265 |
+
_lazy_call(_check_cubins)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class DeferredCudaCallError(Exception):
|
| 269 |
+
pass
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
OutOfMemoryError = torch._C.OutOfMemoryError
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def init():
|
| 276 |
+
r"""Initialize PyTorch's CUDA state.
|
| 277 |
+
|
| 278 |
+
You may need to call this explicitly if you are interacting with
|
| 279 |
+
PyTorch via its C API, as Python bindings for CUDA functionality
|
| 280 |
+
will not be available until this initialization takes place.
|
| 281 |
+
Ordinary users should not need this, as all of PyTorch's CUDA methods
|
| 282 |
+
automatically initialize CUDA state on-demand.
|
| 283 |
+
|
| 284 |
+
Does nothing if the CUDA state is already initialized.
|
| 285 |
+
"""
|
| 286 |
+
_lazy_init()
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def _lazy_init():
|
| 290 |
+
global _initialized, _queued_calls
|
| 291 |
+
if is_initialized() or hasattr(_tls, "is_initializing"):
|
| 292 |
+
return
|
| 293 |
+
with _initialization_lock:
|
| 294 |
+
# We be double-checked locking, boys! This is OK because
|
| 295 |
+
# the above test was GIL protected anyway. The inner test
|
| 296 |
+
# is for when a thread blocked on some other thread which was
|
| 297 |
+
# doing the initialization; when they get the lock, they will
|
| 298 |
+
# find there is nothing left to do.
|
| 299 |
+
if is_initialized():
|
| 300 |
+
return
|
| 301 |
+
# It is important to prevent other threads from entering _lazy_init
|
| 302 |
+
# immediately, while we are still guaranteed to have the GIL, because some
|
| 303 |
+
# of the C calls we make below will release the GIL
|
| 304 |
+
if _is_in_bad_fork():
|
| 305 |
+
raise RuntimeError(
|
| 306 |
+
"Cannot re-initialize CUDA in forked subprocess. To use CUDA with "
|
| 307 |
+
"multiprocessing, you must use the 'spawn' start method"
|
| 308 |
+
)
|
| 309 |
+
if not hasattr(torch._C, "_cuda_getDeviceCount"):
|
| 310 |
+
raise AssertionError("Torch not compiled with CUDA enabled")
|
| 311 |
+
if _cudart is None:
|
| 312 |
+
raise AssertionError(
|
| 313 |
+
"libcudart functions unavailable. It looks like you have a broken build?"
|
| 314 |
+
)
|
| 315 |
+
# This function throws if there's a driver initialization error, no GPUs
|
| 316 |
+
# are found or any other error occurs
|
| 317 |
+
if "CUDA_MODULE_LOADING" not in os.environ:
|
| 318 |
+
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
|
| 319 |
+
torch._C._cuda_init()
|
| 320 |
+
# Some of the queued calls may reentrantly call _lazy_init();
|
| 321 |
+
# we need to just return without initializing in that case.
|
| 322 |
+
# However, we must not let any *other* threads in!
|
| 323 |
+
_tls.is_initializing = True
|
| 324 |
+
|
| 325 |
+
for calls in _lazy_seed_tracker.get_calls():
|
| 326 |
+
if calls:
|
| 327 |
+
_queued_calls.append(calls)
|
| 328 |
+
|
| 329 |
+
try:
|
| 330 |
+
for queued_call, orig_traceback in _queued_calls:
|
| 331 |
+
try:
|
| 332 |
+
queued_call()
|
| 333 |
+
except Exception as e:
|
| 334 |
+
msg = (
|
| 335 |
+
f"CUDA call failed lazily at initialization with error: {str(e)}\n\n"
|
| 336 |
+
f"CUDA call was originally invoked at:\n\n{''.join(orig_traceback)}"
|
| 337 |
+
)
|
| 338 |
+
raise DeferredCudaCallError(msg) from e
|
| 339 |
+
finally:
|
| 340 |
+
delattr(_tls, "is_initializing")
|
| 341 |
+
_initialized = True
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def cudart():
|
| 345 |
+
r"""Retrieves the CUDA runtime API module.
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
This function initializes the CUDA runtime environment if it is not already
|
| 349 |
+
initialized and returns the CUDA runtime API module (_cudart). The CUDA
|
| 350 |
+
runtime API module provides access to various CUDA runtime functions.
|
| 351 |
+
|
| 352 |
+
Args:
|
| 353 |
+
``None``
|
| 354 |
+
|
| 355 |
+
Returns:
|
| 356 |
+
module: The CUDA runtime API module (_cudart).
|
| 357 |
+
|
| 358 |
+
Raises:
|
| 359 |
+
RuntimeError: If CUDA cannot be re-initialized in a forked subprocess.
|
| 360 |
+
AssertionError: If PyTorch is not compiled with CUDA support or if libcudart functions are unavailable.
|
| 361 |
+
|
| 362 |
+
Example of CUDA operations with profiling:
|
| 363 |
+
>>> import torch
|
| 364 |
+
>>> from torch.cuda import cudart, check_error
|
| 365 |
+
>>> import os
|
| 366 |
+
>>>
|
| 367 |
+
>>> os.environ['CUDA_PROFILE'] = '1'
|
| 368 |
+
>>>
|
| 369 |
+
>>> def perform_cuda_operations_with_streams():
|
| 370 |
+
>>> stream = torch.cuda.Stream()
|
| 371 |
+
>>> with torch.cuda.stream(stream):
|
| 372 |
+
>>> x = torch.randn(100, 100, device='cuda')
|
| 373 |
+
>>> y = torch.randn(100, 100, device='cuda')
|
| 374 |
+
>>> z = torch.mul(x, y)
|
| 375 |
+
>>> return z
|
| 376 |
+
>>>
|
| 377 |
+
>>> torch.cuda.synchronize()
|
| 378 |
+
>>> print("====== Start nsys profiling ======")
|
| 379 |
+
>>> check_error(cudart().cudaProfilerStart())
|
| 380 |
+
>>> with torch.autograd.profiler.emit_nvtx():
|
| 381 |
+
>>> result = perform_cuda_operations_with_streams()
|
| 382 |
+
>>> print("CUDA operations completed.")
|
| 383 |
+
>>> check_error(torch.cuda.cudart().cudaProfilerStop())
|
| 384 |
+
>>> print("====== End nsys profiling ======")
|
| 385 |
+
|
| 386 |
+
To run this example and save the profiling information, execute:
|
| 387 |
+
>>> $ nvprof --profile-from-start off --csv --print-summary -o trace_name.prof -f -- python cudart_test.py
|
| 388 |
+
|
| 389 |
+
This command profiles the CUDA operations in the provided script and saves
|
| 390 |
+
the profiling information to a file named `trace_name.prof`.
|
| 391 |
+
The `--profile-from-start off` option ensures that profiling starts only
|
| 392 |
+
after the `cudaProfilerStart` call in the script.
|
| 393 |
+
The `--csv` and `--print-summary` options format the profiling output as a
|
| 394 |
+
CSV file and print a summary, respectively.
|
| 395 |
+
The `-o` option specifies the output file name, and the `-f` option forces the
|
| 396 |
+
overwrite of the output file if it already exists.
|
| 397 |
+
"""
|
| 398 |
+
_lazy_init()
|
| 399 |
+
return _cudart
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class cudaStatus:
    r"""Named integer status codes used by this module."""

    # Plain class attributes; the class is only used as a namespace.
    SUCCESS: int = 0
    ERROR_NOT_READY: int = 34
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class CudaError(RuntimeError):
    r"""``RuntimeError`` carrying the runtime's description of a status code."""

    def __init__(self, code: int) -> None:
        # Ask the cudart bindings for the text that belongs to ``code`` and
        # keep the raw number in the message as well.
        description = _cudart.cudaGetErrorString(_cudart.cudaError(code))
        message = f"{description} ({code})"
        super().__init__(message)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def check_error(res: int) -> None:
    r"""Raise :class:`CudaError` when ``res`` differs from ``cudaError.success``.

    Args:
        res (int): status code to check.

    Raises:
        CudaError: if ``res`` is not the success code.
    """
    failed = res != _cudart.cudaError.success
    if failed:
        raise CudaError(res)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
class _DeviceGuard:
    r"""Context manager that exchanges the current device for ``index``.

    On entry the index returned by ``torch.cuda._exchange_device`` is stored
    in ``prev_idx``; on exit that index is handed to
    ``torch.cuda._maybe_exchange_device``.
    """

    def __init__(self, index: int):
        # -1 is a placeholder until __enter__ records the real previous index.
        self.prev_idx = -1
        self.idx = index

    def __enter__(self):
        previous = torch.cuda._exchange_device(self.idx)
        self.prev_idx = previous

    def __exit__(self, type: Any, value: Any, traceback: Any):
        swapped_out = torch.cuda._maybe_exchange_device(self.prev_idx)
        self.idx = swapped_out
        # Returning False lets any exception from the ``with`` body propagate.
        return False
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
class device:
    r"""Context-manager that changes the selected device.

    Args:
        device (torch.device or int): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """

    def __init__(self, device: Any):
        # -1 is a placeholder until __enter__ records the real previous index.
        self.prev_idx = -1
        self.idx = _get_device_index(device, optional=True)

    def __enter__(self):
        previous = torch.cuda._exchange_device(self.idx)
        self.prev_idx = previous

    def __exit__(self, type: Any, value: Any, traceback: Any):
        swapped_out = torch.cuda._maybe_exchange_device(self.prev_idx)
        self.idx = swapped_out
        # Returning False lets any exception from the ``with`` body propagate.
        return False
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
class device_of(device):
    r"""Context-manager that changes the current device to that of given object.

    You can use both tensors and storages as arguments. If a given object is
    not allocated on a GPU, this is a no-op.

    Args:
        obj (Tensor or Storage): object allocated on the selected device.
    """

    def __init__(self, obj):
        # Objects that are not on CUDA map to -1.
        if obj.is_cuda:
            target = obj.get_device()
        else:
            target = -1
        super().__init__(target)
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def set_device(device: _device_t) -> None:
    r"""Set the current device.

    Usage of this function is discouraged in favor of :any:`device`. In most
    cases it's better to use ``CUDA_VISIBLE_DEVICES`` environmental variable.

    Args:
        device (torch.device or int): selected device. This function is a no-op
            if this argument is negative.
    """
    index = _get_device_index(device)
    if index < 0:
        # Negative indices are ignored.
        return
    torch._C._cuda_setDevice(index)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def get_device_name(device: Optional[_device_t] = None) -> str:
    r"""Get the name of a device.

    Args:
        device (torch.device or int or str, optional): device for which to return the
            name. It uses the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    Returns:
        str: the name of the device

    Raises:
        AssertionError: propagated from :func:`get_device_properties` when the
            resolved index is negative or not below :func:`device_count`.
    """
    properties = get_device_properties(device)
    return properties.name
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]:
    r"""Get the cuda capability of a device.

    Args:
        device (torch.device or int or str, optional): device for which to return the
            device capability. It uses the current device, given by
            :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
            (default).

    Returns:
        tuple(int, int): the major and minor cuda capability of the device

    Raises:
        AssertionError: propagated from :func:`get_device_properties` when the
            resolved index is negative or not below :func:`device_count`.
    """
    properties = get_device_properties(device)
    major, minor = properties.major, properties.minor
    return major, minor
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def get_device_properties(device: _device_t) -> _CudaDeviceProperties:
    r"""Get the properties of a device.

    Args:
        device (torch.device or int or str): device for which to return the
            properties of the device.

    Returns:
        _CudaDeviceProperties: the properties of the device

    Raises:
        AssertionError: if the resolved index is negative or not below
            :func:`device_count`.
    """
    _lazy_init()  # will define _get_device_properties
    index = _get_device_index(device, optional=True)
    # The chained comparison short-circuits, so device_count() is only
    # consulted for non-negative indices.
    if not 0 <= index < device_count():
        raise AssertionError("Invalid device id")
    return _get_device_properties(index)  # type: ignore[name-defined]
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def can_device_access_peer(device: _device_t, peer_device: _device_t) -> bool:
    r"""Check if peer access between two devices is possible.

    Args:
        device (torch.device or int or str): first device; ``None`` is accepted
            because the index is resolved with ``optional=True``.
        peer_device (torch.device or int or str): second device.

    Returns:
        bool: result of ``torch._C._cuda_canDeviceAccessPeer`` for the pair.

    Raises:
        AssertionError: if either resolved index is negative or not below
            :func:`device_count`.
    """
    _lazy_init()
    src_index = _get_device_index(device, optional=True)
    peer_index = _get_device_index(peer_device)
    if not 0 <= src_index < device_count():
        raise AssertionError("Invalid device id")
    if not 0 <= peer_index < device_count():
        raise AssertionError("Invalid peer device id")
    return torch._C._cuda_canDeviceAccessPeer(src_index, peer_index)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
class StreamContext:
    r"""Context-manager that selects a given stream.

    All CUDA kernels queued within its context will be enqueued on a selected
    stream.

    Args:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.

    .. note:: Streams are per-device.
    """
    cur_stream: Optional["torch.cuda.Stream"]

    def __init__(self, stream: Optional["torch.cuda.Stream"]):
        self.stream = stream
        self.idx = _get_device_index(None, True)
        # Outside of scripting, a missing index is normalised to -1, which
        # __enter__/__exit__ treat as "do nothing".
        if not torch.jit.is_scripting():
            if self.idx is None:
                self.idx = -1

        # In eager mode both slots start as None; under scripting they are
        # seeded with the default stream instead.
        self.src_prev_stream = (
            None if not torch.jit.is_scripting() else torch.cuda.default_stream(None)
        )
        self.dst_prev_stream = (
            None if not torch.jit.is_scripting() else torch.cuda.default_stream(None)
        )

    def __enter__(self):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # Return if stream is None or CUDA device not available
        if cur_stream is None or self.idx == -1:
            return
        # Remember the stream that is current right now so __exit__ can restore it.
        self.src_prev_stream = torch.cuda.current_stream(None)

        # If the stream is not on the current device, then
        # set the current stream on the device
        if self.src_prev_stream.device != cur_stream.device:
            with device(cur_stream.device):
                # Also remember the stream that was current on the target device.
                self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device)
        torch.cuda.set_stream(cur_stream)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # If stream is None or no CUDA device available, return
        if cur_stream is None or self.idx == -1:
            return

        # Reset the stream on the original device
        # and destination device
        if self.src_prev_stream.device != cur_stream.device:  # type: ignore[union-attr]
            torch.cuda.set_stream(self.dst_prev_stream)  # type: ignore[arg-type]
        torch.cuda.set_stream(self.src_prev_stream)  # type: ignore[arg-type]
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def stream(stream: Optional["torch.cuda.Stream"]) -> StreamContext:
    r"""Wrap around the Context-manager StreamContext that selects a given stream.

    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.

    .. note::
        In eager mode stream is of type Stream class while in JIT it is
        an object of the custom class ``torch.classes.cuda.Stream``.
    """
    context = StreamContext(stream)
    return context
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
def _set_stream_by_id(stream_id, device_index, device_type):
    r"""Set the stream specified by the stream id, device index and device type.

    Args:
        stream_id (int): stream id in stream pool
        device_index (int): device index in topo
        device_type (int): enum device type
    """
    # All three values are forwarded by keyword to the C binding.
    torch._C._cuda_setStream(
        stream_id=stream_id, device_index=device_index, device_type=device_type
    )
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
def set_stream(stream: Stream):
    r"""Set the current stream. This is a wrapper API to set the stream.

    Usage of this function is discouraged in favor of the ``stream``
    context manager.

    Args:
        stream (Stream): selected stream. This function is a no-op
            if this argument is ``None``.
    """
    if stream is not None:
        _set_stream_by_id(
            stream_id=stream.stream_id,
            device_index=stream.device_index,
            device_type=stream.device_type,
        )
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
def _parse_visible_devices() -> Union[List[int], List[str]]:
    r"""Parse CUDA_VISIBLE_DEVICES environment variable.

    When ``torch.version.hip`` is truthy and ``HIP_VISIBLE_DEVICES`` is set,
    that variable is parsed instead.

    Returns:
        ``list(range(64))`` when no variable is set; a list of identifier
        strings when the value starts with ``GPU-`` or ``MIG-``; otherwise a
        list of integer ordinals. Parsing stops at the first invalid entry,
        and a repeated entry yields an empty list.
    """
    var = os.getenv("CUDA_VISIBLE_DEVICES")

    if torch.version.hip:
        hip_devices = os.getenv("HIP_VISIBLE_DEVICES")
        if hip_devices is not None:
            var = hip_devices

    if var is None:
        return list(range(64))

    def _strtoul(s: str) -> int:
        """Return the integer ``s`` starts with, or -1 if there is none."""
        end = 0
        for pos, ch in enumerate(s):
            if not (ch.isdigit() or (pos == 0 and ch in "+-")):
                break
            end = pos + 1
        # A lone sign ("-", "+x") or a character that isdigit() accepts but
        # int() rejects has no usable numeric prefix; report it as invalid
        # instead of letting ValueError escape.
        try:
            return int(s[:end])
        except ValueError:
            return -1

    def parse_list_with_prefix(lst: str, prefix: str) -> List[str]:
        rcs: List[str] = []
        for elem in lst.split(","):
            # Repeated id results in empty set
            if elem in rcs:
                return cast(List[str], [])
            # Anything other but prefix is ignored
            if not elem.startswith(prefix):
                break
            rcs.append(elem)
        return rcs

    if var.startswith("GPU-"):
        return parse_list_with_prefix(var, "GPU-")
    if var.startswith("MIG-"):
        return parse_list_with_prefix(var, "MIG-")
    # CUDA_VISIBLE_DEVICES uses something like strtoul
    # which makes `1gpu2,2ampere` is equivalent to `1,2`
    rc: List[int] = []
    for elem in var.split(","):
        x = _strtoul(elem.strip())
        # Repeated ordinal results in empty set
        if x in rc:
            return cast(List[int], [])
        # Negative value aborts the sequence
        if x < 0:
            break
        rc.append(x)
    return rc
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
def _raw_device_count_amdsmi() -> int:
    r"""Return the number of processor handles reported by amdsmi, or -1.

    -1 is returned when the ``_HAS_PYNVML`` flag is false or when
    ``amdsmi_init`` raises ``AmdSmiException`` (a warning is emitted then).
    """
    if not _HAS_PYNVML:  # If amdsmi is not available
        return -1
    try:
        amdsmi.amdsmi_init()
    except amdsmi.AmdSmiException as e:
        warnings.warn(f"Can't initialize amdsmi - Error code: {e.err_code}")
        return -1
    handles = amdsmi.amdsmi_get_processor_handles()
    return len(handles)
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
def _raw_device_count_nvml() -> int:
    r"""Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed."""
    from ctypes import byref, c_int, CDLL

    nvml_h = CDLL("libnvidia-ml.so.1")
    # Both NVML entry points are treated as failed on any non-zero return code.
    if nvml_h.nvmlInit() != 0:
        warnings.warn("Can't initialize NVML")
        return -1
    dev_count = c_int(-1)
    if nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) != 0:
        warnings.warn("Can't get nvml device count")
        return -1
    del nvml_h
    return dev_count.value
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
def _raw_device_uuid_amdsmi() -> Optional[List[str]]:
    r"""Return the list of device UUIDs reported by amdsmi, or None on failure.

    None is returned (after a warning, except for the availability check) when
    the ``_HAS_PYNVML`` flag is false, when amdsmi cannot be initialised, when
    the processor handles cannot be listed, or when a UUID query fails.
    """
    if not _HAS_PYNVML:  # If amdsmi is not available
        return None
    try:
        amdsmi.amdsmi_init()
    except amdsmi.AmdSmiException:
        warnings.warn("Can't initialize amdsmi")
        return None
    try:
        handles = amdsmi.amdsmi_get_processor_handles()
    except amdsmi.AmdSmiException:
        warnings.warn("Can't get amdsmi device count")
        return None
    uuids: List[str] = []
    # Walk the handle list fetched above instead of re-querying it per device.
    for handle in handles:
        try:
            uuid = amdsmi.amdsmi_get_gpu_device_uuid(handle)
        except amdsmi.AmdSmiException:
            warnings.warn("Cannot get uuid for amd device")
            return None
        uuids.append(str(uuid))
    return uuids
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
def _raw_device_uuid_nvml() -> Optional[List[str]]:
    r"""Return list of device UUID as reported by NVML or None if NVML discovery/initialization failed."""
    from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer

    nvml_h = CDLL("libnvidia-ml.so.1")
    # Every NVML call below is treated as failed on a non-zero return code.
    if nvml_h.nvmlInit() != 0:
        warnings.warn("Can't initialize NVML")
        return None
    dev_count = c_int(-1)
    if nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) != 0:
        warnings.warn("Can't get nvml device count")
        return None
    uuids: List[str] = []
    for ordinal in range(dev_count.value):
        dev_id = c_void_p()
        if nvml_h.nvmlDeviceGetHandleByIndex_v2(ordinal, byref(dev_id)) != 0:
            warnings.warn("Can't get device handle")
            return None
        buf_len = 96
        buf = create_string_buffer(buf_len)
        if nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len) != 0:
            warnings.warn("Can't get device UUID")
            return None
        # Drop the NUL padding of the fixed-size buffer.
        text = buf.raw.decode("ascii").strip("\0")
        uuids.append(text)
    del nvml_h
    return uuids
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]:
    r"""Map partial UUIDs to positions in ``uuids``.

    A candidate matches a known UUID when the UUID starts with it. Processing
    stops at the first candidate with zero or several matches; a candidate
    that resolves to an already collected position yields an empty list.
    """

    def _unique_prefix_match(prefix: str) -> int:
        # Position of the single UUID starting with ``prefix``, else -1.
        hits = [pos for pos, known in enumerate(uuids) if known.startswith(prefix)]
        return hits[0] if len(hits) == 1 else -1

    ordinals: List[int] = []
    for candidate in candidates:
        ordinal = _unique_prefix_match(candidate)
        # First invalid ordinal stops parsing
        if ordinal < 0:
            break
        # Duplicates result in empty set
        if ordinal in ordinals:
            return cast(List[int], [])
        ordinals.append(ordinal)
    return ordinals
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
def _device_count_amdsmi() -> int:
    r"""Return the device count from amdsmi, limited by the visible-devices list.

    Returns 0 when the visible list is empty, and -1 when the list holds
    strings or when ``OSError``/``AttributeError`` is raised along the way.
    A non-positive raw count from amdsmi is returned unchanged.
    """
    visible = _parse_visible_devices()
    if not visible:
        return 0
    try:
        if type(visible[0]) is str:
            return -1
        available = _raw_device_count_amdsmi()
        if available <= 0:
            return available
        # Trim the list up to a maximum available device
        for position, ordinal in enumerate(visible):
            if cast(int, ordinal) >= available:
                return position
    except (OSError, AttributeError):
        return -1
    return len(visible)
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
def _device_count_nvml() -> int:
    r"""Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account.

    Negative value is returned if NVML discovery or initialization has failed.
    """
    visible = _parse_visible_devices()
    if not visible:
        return 0
    try:
        if type(visible[0]) is str:
            # Skip MIG parsing
            if visible[0].startswith("MIG-"):
                return -1
            known_uuids = _raw_device_uuid_nvml()
            if known_uuids is None:
                return -1
            visible = _transform_uuid_to_ordinals(cast(List[str], visible), known_uuids)
        else:
            available = _raw_device_count_nvml()
            if available <= 0:
                return available
            # Trim the list up to a maximum available device
            for position, ordinal in enumerate(visible):
                if cast(int, ordinal) >= available:
                    return position
    except (OSError, AttributeError):
        return -1
    return len(visible)
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int:
    r"""Return the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account.

    Args:
        device: device to translate; resolved with
            ``_get_device_index(device, optional=True)``.

    Returns:
        int: the entry of the visible-devices list at the resolved index.

    Raises:
        RuntimeError: if UUIDs are needed but cannot be read, or if the
            resolved index is not inside the visible-devices list (including
            the case where that list is empty).
    """
    idx = _get_device_index(device, optional=True)
    visible_devices = _parse_visible_devices()
    # Guard the [0] access: an empty list must fall through to the
    # "not visible" error below rather than raise IndexError.
    if visible_devices and type(visible_devices[0]) is str:
        uuids = _raw_device_uuid_nvml()
        if uuids is None:
            raise RuntimeError("Can't get device UUIDs")
        visible_devices = _transform_uuid_to_ordinals(
            cast(List[str], visible_devices), uuids
        )
    visible_devices = cast(List[int], visible_devices)
    if idx < 0 or idx >= len(visible_devices):
        raise RuntimeError(
            f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})"
        )
    return visible_devices[idx]
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
# Result of device_count(); only filled in once ``_initialized`` is true.
_cached_device_count: Optional[int] = None


def device_count() -> int:
    r"""Return the number of GPUs available."""
    global _cached_device_count
    if not _is_compiled():
        return 0
    if _cached_device_count is not None:
        return _cached_device_count
    # bypass _device_count_nvml() if rocm (not supported)
    if torch.version.hip:
        probed = _device_count_amdsmi()
    else:
        probed = _device_count_nvml()
    # A negative probe result means "unknown"; ask the C binding instead.
    if probed < 0:
        count = torch._C._cuda_getDeviceCount()
    else:
        count = probed
    # NB: Do not cache the device count prior to CUDA initialization, because
    # the number of devices can change due to changes to CUDA_VISIBLE_DEVICES
    # setting prior to CUDA initialization.
    if _initialized:
        _cached_device_count = count
    return count
|
| 912 |
+
|
| 913 |
+
|
| 914 |
+
def get_arch_list() -> List[str]:
    r"""Return list CUDA architectures this library was compiled for.

    An empty list is returned when :func:`is_available` is false or when the
    binding reports no architecture flags.
    """
    if not is_available():
        return []
    flags = torch._C._cuda_getArchFlags()
    # The flags arrive as one whitespace-separated string, or None.
    return [] if flags is None else flags.split()
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
def get_gencode_flags() -> str:
    r"""Return NVCC gencode flags this library was compiled with.

    Each entry of :func:`get_arch_list` is split on ``_`` into a kind and an
    architecture; an empty string is returned when that list is empty.
    """
    archs = get_arch_list()
    if not archs:
        return ""
    flags = []
    for entry in archs:
        kind, arch = entry.split("_")
        flags.append(f"-gencode compute=compute_{arch},code={kind}_{arch}")
    return " ".join(flags)
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
def current_device() -> int:
    r"""Return the index of a currently selected device."""
    _lazy_init()
    index = torch._C._cuda_getDevice()
    return index
|
| 942 |
+
|
| 943 |
+
|
| 944 |
+
def synchronize(device: _device_t = None) -> None:
    r"""Wait for all kernels in all streams on a CUDA device to complete.

    Args:
        device (torch.device or int, optional): device for which to synchronize.
            It uses the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    _lazy_init()
    # The synchronize call runs inside the device context manager for ``device``.
    target = torch.cuda.device(device)
    with target:
        return torch._C._cuda_synchronize()
|
| 955 |
+
|
| 956 |
+
|
| 957 |
+
def ipc_collect():
    r"""Force collects GPU memory after it has been released by CUDA IPC.

    .. note::
        Checks if any sent CUDA tensors could be cleaned from the memory. Force
        closes shared memory file used for reference counting if there is no
        active counters. Useful when the producer process stopped actively sending
        tensors and want to release unused memory.
    """
    _lazy_init()
    outcome = torch._C._cuda_ipc_collect()
    return outcome
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
def current_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the currently selected :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the currently selected :class:`Stream` for the current device, given
            by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
            (default).
    """
    _lazy_init()
    index = _get_device_index(device, optional=True)
    streamdata = torch._C._cuda_getCurrentStream(index)
    # The binding's result is read positionally: id, device index, device type.
    stream_id = streamdata[0]
    device_index = streamdata[1]
    device_type = streamdata[2]
    return Stream(
        stream_id=stream_id, device_index=device_index, device_type=device_type
    )
|
| 986 |
+
|
| 987 |
+
|
| 988 |
+
def default_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the default :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the default :class:`Stream` for the current device, given by
            :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
            (default).
    """
    _lazy_init()
    index = _get_device_index(device, optional=True)
    streamdata = torch._C._cuda_getDefaultStream(index)
    # The binding's result is read positionally: id, device index, device type.
    stream_id = streamdata[0]
    device_index = streamdata[1]
    device_type = streamdata[2]
    return Stream(
        stream_id=stream_id, device_index=device_index, device_type=device_type
    )
|
| 1004 |
+
|
| 1005 |
+
|
| 1006 |
+
def current_blas_handle():
    r"""Return cublasHandle_t pointer to current cuBLAS handle."""
    _lazy_init()
    handle = torch._C._cuda_getCurrentBlasHandle()
    return handle
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
def set_sync_debug_mode(debug_mode: Union[int, str]) -> None:
    r"""Set the debug mode for cuda synchronizing operations.

    Args:
        debug_mode(str or int): if "default" or 0, don't error or warn on synchronizing operations,
            if "warn" or 1, warn on synchronizing operations, if "error" or 2, error out synchronizing operations.

    Raises:
        RuntimeError: if ``debug_mode`` is a string other than ``default``,
            ``warn`` or ``error``.

    Warning:
        This is an experimental feature, and not all synchronizing operations will trigger warning or error. In
        particular, operations in torch.distributed and torch.sparse namespaces are not covered yet.
    """
    _lazy_init()
    if isinstance(debug_mode, str):
        # String names are translated to the integer the binding receives.
        named_modes = {"default": 0, "warn": 1, "error": 2}
        if debug_mode not in named_modes:
            raise RuntimeError(
                "invalid value of debug_mode, expected one of `default`, `warn`, `error`"
            )
        debug_mode = named_modes[debug_mode]

    torch._C._cuda_set_sync_debug_mode(debug_mode)
|
| 1037 |
+
|
| 1038 |
+
|
| 1039 |
+
def get_sync_debug_mode() -> int:
    r"""Return current value of debug mode for cuda synchronizing operations."""
    _lazy_init()
    mode = torch._C._cuda_get_sync_debug_mode()
    return mode
|
| 1043 |
+
|
| 1044 |
+
|
| 1045 |
+
def _get_pynvml_handler(device: Optional[Union[Device, int]] = None):
    r"""Return the pynvml device handle for ``device``.

    Raises:
        ModuleNotFoundError: if the ``_HAS_PYNVML`` flag is false.
        RuntimeError: if ``nvmlInit`` reports that the driver is not loaded.
    """
    if not _HAS_PYNVML:
        raise ModuleNotFoundError(
            "pynvml does not seem to be installed or it can't be imported."
        ) from _PYNVML_ERR
    from pynvml import NVMLError_DriverNotLoaded

    try:
        pynvml.nvmlInit()
    except NVMLError_DriverNotLoaded as e:
        raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e

    # Translate the torch-level device into the index NVML expects.
    nvml_index = _get_nvml_device_index(device)
    return pynvml.nvmlDeviceGetHandleByIndex(nvml_index)
|
| 1060 |
+
|
| 1061 |
+
|
| 1062 |
+
def _get_amdsmi_handler(device: Optional[Union[Device, int]] = None):
    r"""Return the amdsmi processor handle for ``device``.

    Raises:
        ModuleNotFoundError: if the ``_HAS_PYNVML`` flag is false.
        RuntimeError: if ``amdsmi_init`` raises ``AmdSmiException``.
    """
    if not _HAS_PYNVML:
        raise ModuleNotFoundError(
            "amdsmi does not seem to be installed or it can't be imported."
        ) from _PYNVML_ERR
    try:
        amdsmi.amdsmi_init()
    except amdsmi.AmdSmiException as e:
        raise RuntimeError(
            "amdsmi driver can't be loaded, requires >=ROCm5.6 installation"
        ) from e
    # Translate the torch-level device into a position in amdsmi's handle list.
    amdsmi_index = _get_amdsmi_device_index(device)
    handles = amdsmi.amdsmi_get_processor_handles()
    return handles[amdsmi_index]
|
| 1076 |
+
|
| 1077 |
+
|
| 1078 |
+
def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int:
    r"""Return the amdsmi index of the device, taking visible_devices into account.

    Args:
        device: device to translate; resolved with
            ``_get_device_index(device, optional=True)``.

    Returns:
        int: the entry of the visible-devices list at the resolved index.

    Raises:
        RuntimeError: if the visible-devices list holds strings, or if the
            resolved index is not inside that list (including the case where
            the list is empty).
    """
    idx = _get_device_index(device, optional=True)
    visible_devices = _parse_visible_devices()
    # Guard the [0] access: an empty list must fall through to the
    # "not visible" error below rather than raise IndexError.
    if visible_devices and type(visible_devices[0]) is str:
        raise RuntimeError("HIP_VISIBLE_DEVICES should be indices and not strings")
    idx_map = dict(enumerate(cast(List[int], visible_devices)))
    if idx not in idx_map:
        raise RuntimeError(
            f"device {idx} is not visible (HIP_VISIBLE_DEVICES={visible_devices})"
        )
    return idx_map[idx]
|
| 1090 |
+
|
| 1091 |
+
|
| 1092 |
+
def _get_amdsmi_memory_usage(device: Optional[Union[Device, int]] = None) -> int:
    r"""Return the ``vram_used`` field amdsmi reports for ``device``.

    The handle is resolved for the requested device; previously the handle
    was fetched without ``device``, so the current device was always queried.
    """
    handle = _get_amdsmi_handler(device)
    return amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"]
|
| 1096 |
+
|
| 1097 |
+
|
| 1098 |
+
def _get_amdsmi_utilization(device: Optional[Union[Device, int]] = None) -> int:
    r"""Return the ``gfx_activity`` field amdsmi reports for ``device``."""
    # One lookup for the requested device; _get_amdsmi_handler already
    # initialises amdsmi and resolves the index.
    handle = _get_amdsmi_handler(device)
    return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"]
|
| 1103 |
+
|
| 1104 |
+
|
| 1105 |
+
def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int:
    r"""Return amdsmi's current junction temperature metric for ``device``."""
    handle = _get_amdsmi_handler(device)
    sensor = amdsmi.AmdSmiTemperatureType.JUNCTION
    metric = amdsmi.AmdSmiTemperatureMetric.CURRENT
    return amdsmi.amdsmi_get_temp_metric(handle, sensor, metric)
|
| 1112 |
+
|
| 1113 |
+
|
| 1114 |
+
def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int:
    r"""Return the socket power amdsmi reports for ``device``.

    ``average_socket_power`` is preferred; when it is the string ``"N/A"``
    the ``current_socket_power`` field of the same reading is returned.
    """
    handle = _get_amdsmi_handler(device)
    # Query once and reuse the reading for the fallback field.
    power_info = amdsmi.amdsmi_get_power_info(handle)
    average = power_info["average_socket_power"]
    if average != "N/A":
        return average
    return power_info["current_socket_power"]
|
| 1121 |
+
|
| 1122 |
+
|
| 1123 |
+
def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int:
    r"""Return the GFX clock value amdsmi reports for ``device``.

    ``cur_clk`` is used when the reading has that key, ``clk`` otherwise.
    """
    handle = _get_amdsmi_handler(device)
    clock_info = amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX)
    # ROCm 6.2 deprecation
    key = "cur_clk" if "cur_clk" in clock_info else "clk"
    return clock_info[key]
|
| 1130 |
+
|
| 1131 |
+
|
| 1132 |
+
def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
    r"""Return the percent of time over the past sample period during which global (device)
    memory was being read or written as given by `nvidia-smi`.

    On HIP builds the value comes from ``_get_amdsmi_memory_usage`` instead.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    Warning: Each sample period may be between 1 second and 1/6 second,
    depending on the product being queried.
    """
    if torch.version.hip:
        return _get_amdsmi_memory_usage(device)
    # Resolve the handle once, for the requested device.
    handle = _get_pynvml_handler(device)
    return pynvml.nvmlDeviceGetUtilizationRates(handle).memory
|
| 1151 |
+
|
| 1152 |
+
|
| 1153 |
+
def utilization(device: Optional[Union[Device, int]] = None) -> int:
    r"""Return the percent of time over the past sample period during which one or
    more kernels was executing on the GPU as given by `nvidia-smi`.

    On HIP builds the value comes from ``_get_amdsmi_utilization`` instead.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    Warning: Each sample period may be between 1 second and 1/6 second,
    depending on the product being queried.
    """
    if torch.version.hip:
        return _get_amdsmi_utilization(device)
    # _get_pynvml_handler already resolves the NVML index and handle,
    # so no second lookup is needed.
    handle = _get_pynvml_handler(device)
    return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu
|
| 1172 |
+
|
| 1173 |
+
|
| 1174 |
+
def temperature(device: Optional[Union[Device, int]] = None) -> int:
    r"""Return the average temperature of the GPU sensor in Degrees C (Centigrades).

    The average temperature is computed based on past sample period as given by `nvidia-smi`.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    Warning: Each sample period may be between 1 second and 1/6 second,
    depending on the product being queried.
    """
    if torch.version.hip:
        # ROCm build: delegate to the amdsmi-backed helper.
        return _get_amdsmi_temperature(device)
    nvml_handle = _get_pynvml_handler(device)
    # 0 refers to the temperature sensor for the GPU die.
    gpu_die_sensor = 0
    return pynvml.nvmlDeviceGetTemperature(nvml_handle, gpu_die_sensor)
|
| 1193 |
+
|
| 1194 |
+
|
| 1195 |
+
def power_draw(device: Optional[Union[Device, int]] = None) -> int:
    r"""Return the average power draw of the GPU sensor in mW (MilliWatts)
    over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    Warning: Each sample period may be between 1 second and 1/6 second,
    depending on the product being queried.
    """
    if torch.version.hip:
        # ROCm build: delegate to the amdsmi-backed helper.
        return _get_amdsmi_power_draw(device)
    nvml_handle = _get_pynvml_handler(device)
    return pynvml.nvmlDeviceGetPowerUsage(nvml_handle)
|
| 1212 |
+
|
| 1213 |
+
|
| 1214 |
+
def clock_rate(device: Optional[Union[Device, int]] = None) -> int:
    r"""Return the clock speed of the GPU SM in MHz (megahertz) over the past sample period as given by `nvidia-smi`.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    Returns:
        int: the value NVML reports for the SM clock (NVML expresses clocks in
        MHz). On ROCm (HIP) builds the value comes from
        ``_get_amdsmi_clock_rate``.
        NOTE(review): the unit of the ROCm path is presumably MHz as well;
        confirm against that helper.

    Warning: Each sample period may be between 1 second and 1/6 second,
    depending on the product being queried.
    """
    if torch.version.hip:
        return _get_amdsmi_clock_rate(device)
    handle = _get_pynvml_handler(device)
    # 1 is NVML's clock-type id for the SM clock domain (NVML_CLOCK_SM).
    return pynvml.nvmlDeviceGetClockInfo(handle, 1)
|
| 1230 |
+
|
| 1231 |
+
|
| 1232 |
+
def _get_device(device: Union[int, str, torch.device]) -> torch.device:
    r"""Normalize ``device`` into a :class:`torch.device`.

    Args:
        device (torch.device, str or int): a string is parsed by
            ``torch.device``; an int is taken as a CUDA device index; any
            other value is returned unchanged.
    """
    if isinstance(device, str):
        return torch.device(device)
    if isinstance(device, int):
        return torch.device("cuda", device)
    return device
|
| 1243 |
+
|
| 1244 |
+
|
| 1245 |
+
def _get_generator(device: torch.device) -> torch._C.Generator:
    r"""Look up the default CUDA generator that belongs to ``device``.

    Args:
        device (torch.device): selected device. When it carries no index,
            the index returned by ``current_device()`` is used instead.
    """
    ordinal = device.index
    ordinal = current_device() if ordinal is None else ordinal
    return torch.cuda.default_generators[ordinal]
|
| 1255 |
+
|
| 1256 |
+
|
| 1257 |
+
def _set_rng_state_offset(
    offset: int, device: Union[int, str, torch.device] = "cuda"
) -> None:
    r"""Set the random number generator state offset of the specified GPU.

    The update is not applied directly: it is wrapped in a callback that is
    handed to ``_lazy_call``.

    Args:
        offset (int): The desired offset
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
    """
    target = _get_device(device)

    def _apply_offset():
        # The generator is resolved when the callback runs, not when it is queued.
        _get_generator(target).set_offset(offset)

    _lazy_call(_apply_offset)
|
| 1274 |
+
|
| 1275 |
+
|
| 1276 |
+
def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int:
    r"""Return the random number generator state offset of the specified GPU.

    Args:
        device (torch.device or int, optional): The device to return the RNG state offset of.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    target = _get_device(device)
    return _get_generator(target).get_offset()
|
| 1290 |
+
|
| 1291 |
+
|
| 1292 |
+
from .memory import * # noqa: F403
|
| 1293 |
+
from .random import * # noqa: F403
|
| 1294 |
+
|
| 1295 |
+
|
| 1296 |
+
################################################################################
|
| 1297 |
+
# Define Storage and Tensor classes
|
| 1298 |
+
################################################################################
|
| 1299 |
+
|
| 1300 |
+
|
| 1301 |
+
@staticmethod  # type: ignore[misc]
def _lazy_new(cls, *args, **kwargs):
    """Run ``_lazy_init()`` and then allocate the instance.

    Allocation is forwarded to the ``__new__`` that follows ``_CudaBase`` in
    the MRO of ``cls``.
    """
    _lazy_init()
    # We may need to call lazy init again if we are a forked child
    # del _CudaBase.__new__
    parent_new = super(_CudaBase, cls).__new__
    return parent_new(cls, *args, **kwargs)
|
| 1307 |
+
|
| 1308 |
+
|
| 1309 |
+
class _CudaBase:
    """Base mixin that marks a type as CUDA (``is_cuda``) and dense (``is_sparse`` False).

    Instance creation goes through ``_lazy_new``.
    """

    is_cuda = True
    is_sparse = False

    def type(self, *args, **kwargs):
        """Call the parent ``type`` inside a ``device`` context for this object's device."""
        # We could use a Protocol here to tell mypy that self has `get_device` method
        # but it is only available in the typing module on Python >= 3.8
        # or on typing_extensions module on Python >= 3.6
        own_device = self.get_device()  # type: ignore[attr-defined]
        with device(own_device):
            return super().type(*args, **kwargs)  # type: ignore[misc]

    __new__ = _lazy_new
|
| 1321 |
+
|
| 1322 |
+
|
| 1323 |
+
from torch.storage import _LegacyStorage, _warn_typed_storage_removal
|
| 1324 |
+
|
| 1325 |
+
|
| 1326 |
+
class _CudaLegacyStorage(_LegacyStorage):
    """Common base of the typed CUDA storage classes.

    The three constructors below are overridden so that they always raise
    ``RuntimeError``.
    """

    @classmethod
    def from_buffer(cls, *args, **kwargs):
        """Always raise; ``_warn_typed_storage_removal()`` is called first."""
        _warn_typed_storage_removal()
        raise RuntimeError("from_buffer: Not available for CUDA storage")

    @classmethod
    def _new_with_weak_ptr(cls, *args, **kwargs):
        """Always raise ``RuntimeError``."""
        raise RuntimeError("_new_with_weak_ptr: Not available for CUDA storage")

    @classmethod
    def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None):
        """Always raise ``RuntimeError``; the arguments are ignored."""
        raise RuntimeError("_new_shared_filename: Not available for CUDA storage")
|
| 1339 |
+
|
| 1340 |
+
|
| 1341 |
+
class ByteStorage(_CudaLegacyStorage):
    """Typed CUDA storage class whose dtype is ``torch.uint8``."""

    @classproperty
    def _dtype(cls):
        # Internal accessor: returns the dtype without calling the warning helper.
        return torch.uint8

    @classproperty
    def dtype(cls):
        # Public accessor: calls _warn_typed_storage_removal() before answering.
        _warn_typed_storage_removal()
        return cls._dtype
|
| 1350 |
+
|
| 1351 |
+
|
| 1352 |
+
class DoubleStorage(_CudaLegacyStorage):
    """Typed CUDA storage class whose dtype is ``torch.double``."""

    @classproperty
    def _dtype(cls):
        # Internal accessor: returns the dtype without calling the warning helper.
        return torch.double

    @classproperty
    def dtype(cls):
        # Public accessor: calls _warn_typed_storage_removal() before answering.
        _warn_typed_storage_removal()
        return cls._dtype
|
| 1361 |
+
|
| 1362 |
+
|
| 1363 |
+
class FloatStorage(_CudaLegacyStorage):
    """Typed CUDA storage class whose dtype is ``torch.float``."""

    @classproperty
    def _dtype(cls):
        # Internal accessor: returns the dtype without calling the warning helper.
        return torch.float

    @classproperty
    def dtype(cls):
        # Public accessor: calls _warn_typed_storage_removal() before answering.
        _warn_typed_storage_removal()
        return cls._dtype
|
| 1372 |
+
|
| 1373 |
+
|
| 1374 |
+
class HalfStorage(_CudaLegacyStorage):
    """Typed CUDA storage class whose dtype is ``torch.half``."""

    @classproperty
    def _dtype(cls):
        # Internal accessor: returns the dtype without calling the warning helper.
        return torch.half

    @classproperty
    def dtype(cls):
        # Public accessor: calls _warn_typed_storage_removal() before answering.
        _warn_typed_storage_removal()
        return cls._dtype
|
| 1383 |
+
|
| 1384 |
+
|
| 1385 |
+
class LongStorage(_CudaLegacyStorage):
    """Typed CUDA storage class whose dtype is ``torch.long``."""

    @classproperty
    def _dtype(cls):
        # Internal accessor: returns the dtype without calling the warning helper.
        return torch.long

    @classproperty
    def dtype(cls):
        # Public accessor: calls _warn_typed_storage_removal() before answering.
        _warn_typed_storage_removal()
        return cls._dtype
|
| 1394 |
+
|
| 1395 |
+
|
| 1396 |
+
class IntStorage(_CudaLegacyStorage):
    """Typed CUDA storage class whose dtype is ``torch.int``."""

    @classproperty
    def _dtype(cls):
        # Internal accessor: returns the dtype without calling the warning helper.
        return torch.int

    @classproperty
    def dtype(cls):
        # Public accessor: calls _warn_typed_storage_removal() before answering.
        _warn_typed_storage_removal()
        return cls._dtype
|
| 1405 |
+
|
| 1406 |
+
|
| 1407 |
+
class ShortStorage(_CudaLegacyStorage):
    """Typed CUDA storage class whose dtype is ``torch.short``."""

    @classproperty
    def _dtype(cls):
        # Internal accessor: returns the dtype without calling the warning helper.
        return torch.short

    @classproperty
    def dtype(cls):
        # Public accessor: calls _warn_typed_storage_removal() before answering.
        _warn_typed_storage_removal()
        return cls._dtype
|
| 1416 |
+
|
| 1417 |
+
|
| 1418 |
+
class CharStorage(_CudaLegacyStorage):
    """Typed CUDA storage class whose dtype is ``torch.int8``."""

    @classproperty
    def _dtype(cls):
        # Internal accessor: returns the dtype without calling the warning helper.
        return torch.int8

    @classproperty
    def dtype(cls):
        # Public accessor: calls _warn_typed_storage_removal() before answering.
        _warn_typed_storage_removal()
        return cls._dtype
|
| 1427 |
+
|
| 1428 |
+
|
| 1429 |
+
class BoolStorage(_CudaLegacyStorage):
    """Typed CUDA storage class whose dtype is ``torch.bool``."""

    @classproperty
    def _dtype(cls):
        # Internal accessor: returns the dtype without calling the warning helper.
        return torch.bool

    @classproperty
    def dtype(cls):
        # Public accessor: calls _warn_typed_storage_removal() before answering.
        _warn_typed_storage_removal()
        return cls._dtype
|
| 1438 |
+
|
| 1439 |
+
|
| 1440 |
+
class BFloat16Storage(_CudaLegacyStorage):
    """Typed CUDA storage class whose dtype is ``torch.bfloat16``."""

    @classproperty
    def _dtype(cls):
        # Internal accessor: returns the dtype without calling the warning helper.
        return torch.bfloat16

    @classproperty
    def dtype(cls):
        # Public accessor: calls _warn_typed_storage_removal() before answering.
        _warn_typed_storage_removal()
        return cls._dtype
|
| 1449 |
+
|
| 1450 |
+
|
| 1451 |
+
class ComplexDoubleStorage(_CudaLegacyStorage):
    """Typed CUDA storage class whose dtype is ``torch.cdouble``."""

    @classproperty
    def _dtype(cls):
        # Internal accessor: returns the dtype without calling the warning helper.
        return torch.cdouble

    @classproperty
    def dtype(cls):
        # Public accessor: calls _warn_typed_storage_removal() before answering.
        _warn_typed_storage_removal()
        return cls._dtype
|
| 1460 |
+
|
| 1461 |
+
|
| 1462 |
+
class ComplexFloatStorage(_CudaLegacyStorage):
    """Typed CUDA storage class whose dtype is ``torch.cfloat``."""

    @classproperty
    def _dtype(cls):
        # Internal accessor: returns the dtype without calling the warning helper.
        return torch.cfloat

    @classproperty
    def dtype(cls):
        # Public accessor: calls _warn_typed_storage_removal() before answering.
        _warn_typed_storage_removal()
        return cls._dtype
|
| 1471 |
+
|
| 1472 |
+
|
| 1473 |
+
del _LegacyStorage
|
| 1474 |
+
del _CudaLegacyStorage
|
| 1475 |
+
|
| 1476 |
+
torch._storage_classes.add(DoubleStorage)
|
| 1477 |
+
torch._storage_classes.add(FloatStorage)
|
| 1478 |
+
torch._storage_classes.add(LongStorage)
|
| 1479 |
+
torch._storage_classes.add(IntStorage)
|
| 1480 |
+
torch._storage_classes.add(ShortStorage)
|
| 1481 |
+
torch._storage_classes.add(CharStorage)
|
| 1482 |
+
torch._storage_classes.add(ByteStorage)
|
| 1483 |
+
torch._storage_classes.add(HalfStorage)
|
| 1484 |
+
torch._storage_classes.add(BoolStorage)
|
| 1485 |
+
torch._storage_classes.add(BFloat16Storage)
|
| 1486 |
+
torch._storage_classes.add(ComplexDoubleStorage)
|
| 1487 |
+
torch._storage_classes.add(ComplexFloatStorage)
|
| 1488 |
+
|
| 1489 |
+
|
| 1490 |
+
class _WrappedTritonKernel:
    """Just a simple wrapper to store some metadata for testing purposes."""

    def __init__(self, kernel):
        # `kernel_invoked` stays False until a call to the wrapper has returned.
        self.kernel_invoked = False
        self.kernel = kernel

    def __call__(self, *args, **kwargs):
        """Forward the call to the wrapped kernel and record that it ran."""
        outcome = self.kernel(*args, **kwargs)
        # Set only after the kernel returned, so a raising kernel leaves the flag untouched.
        self.kernel_invoked = True
        return outcome
|
| 1501 |
+
|
| 1502 |
+
|
| 1503 |
+
def _register_triton_kernels():
    """Register the Triton-backed BSR ops with ``torch._TritonLibrary``.

    Does nothing when running under deploy, and registers nothing when no
    ``triton`` module spec can be found. Both ops are registered for the
    ``"SparseCsrCUDA"`` dispatch key.
    """
    if torch._running_with_deploy():
        return

    @_WrappedTritonKernel
    def kernel_impl(*args, **kwargs):
        # Imported at call time, inside the wrapper body.
        from torch.sparse._triton_ops import bsr_dense_mm

        return bsr_dense_mm(*args, skip_checks=True, **kwargs)

    @_WrappedTritonKernel
    def addmm_kernel_impl(*args, **kwargs):
        # Imported at call time, inside the wrapper body.
        from torch.sparse._triton_ops import bsr_dense_addmm

        return bsr_dense_addmm(*args, skip_checks=True, **kwargs)

    if importlib.util.find_spec("triton") is None:
        return

    registrations = (
        (
            "_triton_bsr_dense_mm_out",
            "_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
            kernel_impl,
        ),
        (
            "_triton_bsr_dense_addmm_out",
            (
                "_triton_bsr_dense_addmm_out(Tensor input, Tensor bsr, Tensor dense,"
                " *, Scalar beta, Scalar alpha, Tensor(a!) out) -> Tensor(a!)"
            ),
            addmm_kernel_impl,
        ),
    )
    for op_name, schema, impl in registrations:
        torch._TritonLibrary.registerOp(op_name, schema, impl, "SparseCsrCUDA")
|
| 1537 |
+
|
| 1538 |
+
|
| 1539 |
+
_lazy_call(_register_triton_kernels)
|
| 1540 |
+
|
| 1541 |
+
|
| 1542 |
+
from . import amp, jiterator, nvtx, profiler, sparse, tunable
|
| 1543 |
+
|
| 1544 |
+
|
| 1545 |
+
__all__ = [
|
| 1546 |
+
# Typed storage and tensors
|
| 1547 |
+
"BFloat16Storage",
|
| 1548 |
+
"BFloat16Tensor",
|
| 1549 |
+
"BoolStorage",
|
| 1550 |
+
"BoolTensor",
|
| 1551 |
+
"ByteStorage",
|
| 1552 |
+
"ByteTensor",
|
| 1553 |
+
"CharStorage",
|
| 1554 |
+
"CharTensor",
|
| 1555 |
+
"ComplexDoubleStorage",
|
| 1556 |
+
"ComplexFloatStorage",
|
| 1557 |
+
"DoubleStorage",
|
| 1558 |
+
"DoubleTensor",
|
| 1559 |
+
"FloatStorage",
|
| 1560 |
+
"FloatTensor",
|
| 1561 |
+
"HalfStorage",
|
| 1562 |
+
"HalfTensor",
|
| 1563 |
+
"IntStorage",
|
| 1564 |
+
"IntTensor",
|
| 1565 |
+
"LongStorage",
|
| 1566 |
+
"LongTensor",
|
| 1567 |
+
"ShortStorage",
|
| 1568 |
+
"ShortTensor",
|
| 1569 |
+
"CUDAGraph",
|
| 1570 |
+
"CudaError",
|
| 1571 |
+
"DeferredCudaCallError",
|
| 1572 |
+
"Event",
|
| 1573 |
+
"ExternalStream",
|
| 1574 |
+
"Stream",
|
| 1575 |
+
"StreamContext",
|
| 1576 |
+
"amp",
|
| 1577 |
+
"caching_allocator_alloc",
|
| 1578 |
+
"caching_allocator_delete",
|
| 1579 |
+
"can_device_access_peer",
|
| 1580 |
+
"check_error",
|
| 1581 |
+
"cudaStatus",
|
| 1582 |
+
"cudart",
|
| 1583 |
+
"current_blas_handle",
|
| 1584 |
+
"current_device",
|
| 1585 |
+
"current_stream",
|
| 1586 |
+
"default_generators",
|
| 1587 |
+
"default_stream",
|
| 1588 |
+
"device",
|
| 1589 |
+
"device_count",
|
| 1590 |
+
"device_of",
|
| 1591 |
+
"empty_cache",
|
| 1592 |
+
"get_allocator_backend",
|
| 1593 |
+
"CUDAPluggableAllocator",
|
| 1594 |
+
"change_current_allocator",
|
| 1595 |
+
"get_arch_list",
|
| 1596 |
+
"get_device_capability",
|
| 1597 |
+
"get_device_name",
|
| 1598 |
+
"get_device_properties",
|
| 1599 |
+
"get_gencode_flags",
|
| 1600 |
+
"get_rng_state",
|
| 1601 |
+
"get_rng_state_all",
|
| 1602 |
+
"get_sync_debug_mode",
|
| 1603 |
+
"graph",
|
| 1604 |
+
"graph_pool_handle",
|
| 1605 |
+
"graphs",
|
| 1606 |
+
"has_half",
|
| 1607 |
+
"has_magma",
|
| 1608 |
+
"init",
|
| 1609 |
+
"initial_seed",
|
| 1610 |
+
"ipc_collect",
|
| 1611 |
+
"is_available",
|
| 1612 |
+
"is_bf16_supported",
|
| 1613 |
+
"is_current_stream_capturing",
|
| 1614 |
+
"is_initialized",
|
| 1615 |
+
"jiterator",
|
| 1616 |
+
"list_gpu_processes",
|
| 1617 |
+
"make_graphed_callables",
|
| 1618 |
+
"manual_seed",
|
| 1619 |
+
"manual_seed_all",
|
| 1620 |
+
"max_memory_allocated",
|
| 1621 |
+
"max_memory_cached",
|
| 1622 |
+
"max_memory_reserved",
|
| 1623 |
+
"mem_get_info",
|
| 1624 |
+
"memory",
|
| 1625 |
+
"memory_allocated",
|
| 1626 |
+
"memory_cached",
|
| 1627 |
+
"memory_reserved",
|
| 1628 |
+
"memory_snapshot",
|
| 1629 |
+
"memory_stats",
|
| 1630 |
+
"memory_stats_as_nested_dict",
|
| 1631 |
+
"memory_summary",
|
| 1632 |
+
"memory_usage",
|
| 1633 |
+
"MemPool",
|
| 1634 |
+
"MemPoolContext",
|
| 1635 |
+
"use_mem_pool",
|
| 1636 |
+
"temperature",
|
| 1637 |
+
"power_draw",
|
| 1638 |
+
"clock_rate",
|
| 1639 |
+
"nccl",
|
| 1640 |
+
"nvtx",
|
| 1641 |
+
"profiler",
|
| 1642 |
+
"random",
|
| 1643 |
+
"reset_accumulated_memory_stats",
|
| 1644 |
+
"reset_max_memory_allocated",
|
| 1645 |
+
"reset_max_memory_cached",
|
| 1646 |
+
"reset_peak_memory_stats",
|
| 1647 |
+
"seed",
|
| 1648 |
+
"seed_all",
|
| 1649 |
+
"set_device",
|
| 1650 |
+
"set_per_process_memory_fraction",
|
| 1651 |
+
"set_rng_state",
|
| 1652 |
+
"set_rng_state_all",
|
| 1653 |
+
"set_stream",
|
| 1654 |
+
"set_sync_debug_mode",
|
| 1655 |
+
"sparse",
|
| 1656 |
+
"stream",
|
| 1657 |
+
"streams",
|
| 1658 |
+
"synchronize",
|
| 1659 |
+
"tunable",
|
| 1660 |
+
"utilization",
|
| 1661 |
+
]
|
.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/memory.cpython-311.pyc
ADDED
|
Binary file (50.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/cuda/_gpu_trace.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable
|
| 2 |
+
|
| 3 |
+
from torch._utils import CallbackRegistry
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
| 7 |
+
"CUDA event creation"
|
| 8 |
+
)
|
| 9 |
+
EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
| 10 |
+
"CUDA event deletion"
|
| 11 |
+
)
|
| 12 |
+
EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
|
| 13 |
+
"CUDA event record"
|
| 14 |
+
)
|
| 15 |
+
EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
|
| 16 |
+
"CUDA event wait"
|
| 17 |
+
)
|
| 18 |
+
MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
| 19 |
+
"CUDA memory allocation"
|
| 20 |
+
)
|
| 21 |
+
MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
| 22 |
+
"CUDA memory deallocation"
|
| 23 |
+
)
|
| 24 |
+
StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
| 25 |
+
"CUDA stream creation"
|
| 26 |
+
)
|
| 27 |
+
DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
|
| 28 |
+
"CUDA device synchronization"
|
| 29 |
+
)
|
| 30 |
+
StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
| 31 |
+
"CUDA stream synchronization"
|
| 32 |
+
)
|
| 33 |
+
EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
| 34 |
+
"CUDA event synchronization"
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def register_callback_for_event_creation(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the ``EventCreationCallbacks`` registry ("CUDA event creation").

    Args:
        cb: callable taking one ``int`` and returning ``None``.
            NOTE(review): the int is presumably an event identifier; confirm
            against the code that fires this registry.
    """
    EventCreationCallbacks.add_callback(cb)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the ``EventDeletionCallbacks`` registry ("CUDA event deletion").

    Args:
        cb: callable taking one ``int`` and returning ``None``.
            NOTE(review): the int is presumably an event identifier; confirm
            against the code that fires this registry.
    """
    EventDeletionCallbacks.add_callback(cb)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None:
    """Add ``cb`` to the ``EventRecordCallbacks`` registry ("CUDA event record").

    Args:
        cb: callable taking two ``int`` arguments and returning ``None``.
            NOTE(review): presumably (event, stream) identifiers; confirm
            against the code that fires this registry.
    """
    EventRecordCallbacks.add_callback(cb)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None:
    """Add ``cb`` to the ``EventWaitCallbacks`` registry ("CUDA event wait").

    Args:
        cb: callable taking two ``int`` arguments and returning ``None``.
            NOTE(review): presumably (event, stream) identifiers; confirm
            against the code that fires this registry.
    """
    EventWaitCallbacks.add_callback(cb)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the ``MemoryAllocationCallbacks`` registry ("CUDA memory allocation").

    Args:
        cb: callable taking one ``int`` and returning ``None``.
            NOTE(review): the int is presumably the allocation's pointer
            value; confirm against the code that fires this registry.
    """
    MemoryAllocationCallbacks.add_callback(cb)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the ``MemoryDeallocationCallbacks`` registry ("CUDA memory deallocation").

    Args:
        cb: callable taking one ``int`` and returning ``None``.
            NOTE(review): the int is presumably the freed pointer value;
            confirm against the code that fires this registry.
    """
    MemoryDeallocationCallbacks.add_callback(cb)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the ``StreamCreationCallbacks`` registry ("CUDA stream creation").

    Args:
        cb: callable taking one ``int`` and returning ``None``.
            NOTE(review): the int is presumably a stream identifier; confirm
            against the code that fires this registry.
    """
    StreamCreationCallbacks.add_callback(cb)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None:
    """Add ``cb`` to the ``DeviceSynchronizationCallbacks`` registry ("CUDA device synchronization").

    Args:
        cb: callable taking no arguments and returning ``None``.
    """
    DeviceSynchronizationCallbacks.add_callback(cb)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the ``StreamSynchronizationCallbacks`` registry ("CUDA stream synchronization").

    Args:
        cb: callable taking one ``int`` and returning ``None``.
            NOTE(review): the int is presumably a stream identifier; confirm
            against the code that fires this registry.
    """
    StreamSynchronizationCallbacks.add_callback(cb)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the ``EventSynchronizationCallbacks`` registry ("CUDA event synchronization").

    Args:
        cb: callable taking one ``int`` and returning ``None``.
            NOTE(review): the int is presumably an event identifier; confirm
            against the code that fires this registry.
    """
    EventSynchronizationCallbacks.add_callback(cb)
|
.venv/lib/python3.11/site-packages/torch/cuda/_memory_viz.py
ADDED
|
@@ -0,0 +1,632 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import pickle
|
| 3 |
+
import sys
|
| 4 |
+
import os
|
| 5 |
+
import io
|
| 6 |
+
import subprocess
|
| 7 |
+
import json
|
| 8 |
+
from functools import lru_cache
|
| 9 |
+
from typing import Any
|
| 10 |
+
from itertools import groupby
|
| 11 |
+
import base64
|
| 12 |
+
import warnings
|
| 13 |
+
import operator
|
| 14 |
+
|
| 15 |
+
cache = lru_cache(None)
|
| 16 |
+
|
| 17 |
+
__all__ = ["format_flamegraph", "segments", "memory", "compare"]
|
| 18 |
+
|
| 19 |
+
def _frame_fmt(f, full_filename=False):
    """Format one frame record as ``"<filename>:<line>:<name>"``.

    Args:
        f: mapping with ``'line'``, ``'filename'`` and ``'name'`` entries.
        full_filename: when false, only the text after the last ``'/'`` of
            the filename is kept.
    """
    line_no = f['line']
    path = f['filename']
    if not full_filename:
        path = path.rsplit('/', 1)[-1]
    func_name = f['name']
    return f'{path}:{line_no}:{func_name}'
|
| 26 |
+
|
| 27 |
+
@cache
def _frame_filter(name, filename):
    """Decide whether a frame should be kept when formatting a stack.

    Returns ``False`` when ``name`` contains any of the omitted function
    fragments or ``filename`` contains any of the omitted path fragments,
    and ``True`` otherwise. Results are memoized by the ``cache`` decorator.
    """
    omit_functions = (
        "unwind::unwind",
        "CapturedTraceback::gather",
        "gather_with_cpp",
        "_start",
        "__libc_start_main",
        "PyEval_",
        "PyObject_",
        "PyFunction_",
    )
    omit_filenames = (
        "core/boxing",
        "/Register",
        "/Redispatch",
        "pythonrun.c",
        "Modules/main.c",
        "Objects/call.c",
        "Objects/methodobject.c",
        "pycore_ceval.h",
        "ceval.c",
        "cpython/abstract.h",
    )
    # Function-name fragments are checked before filename fragments.
    if any(fragment in name for fragment in omit_functions):
        return False
    return not any(fragment in filename for fragment in omit_filenames)
|
| 58 |
+
|
| 59 |
+
def _frames_fmt(frames, full_filename=False, reverse=False):
    """Format every frame that passes ``_frame_filter``.

    Args:
        frames: sequence of frame mappings (``'name'``, ``'filename'``, ``'line'``).
        full_filename: forwarded to ``_frame_fmt``.
        reverse: when true, frames are processed in reverse order.

    Returns:
        list of formatted strings, one for each frame that was kept.
    """
    ordered = reversed(frames) if reverse else frames
    rendered = []
    for frame in ordered:
        if _frame_filter(frame['name'], frame['filename']):
            rendered.append(_frame_fmt(frame, full_filename))
    return rendered
|
| 63 |
+
|
| 64 |
+
def _block_extra_legacy(b):
    """Extract ``(frames, real_size)`` from an old-format block mapping.

    With a ``'history'`` entry, both values come from its first element
    (``frames`` defaults to ``[]``). Without it, ``frames`` is ``[]`` and the
    size is ``'requested_size'``, falling back to ``'size'``.
    """
    if 'history' not in b:
        # b['size'] is evaluated eagerly, so it must be present even when
        # 'requested_size' is.
        return [], b.get('requested_size', b['size'])
    frames = b['history'][0].get('frames', [])
    real_size = b['history'][0]['real_size']
    return frames, real_size
|
| 72 |
+
|
| 73 |
+
def _block_extra(b):
    """Return ``(frames, requested_size)`` for a block mapping.

    Blocks that have no ``'frames'`` key are handed to ``_block_extra_legacy``.
    """
    if 'frames' in b:
        return b['frames'], b['requested_size']
    # old snapshot format made it more complicated to get frames/allocated size
    return _block_extra_legacy(b)
|
| 78 |
+
|
| 79 |
+
def format_flamegraph(flamegraph_lines, flamegraph_script=None):
    """Pipe ``flamegraph_lines`` through a flamegraph script and return its output.

    Args:
        flamegraph_lines (str): text written to the script's stdin.
        flamegraph_script (str, optional): path of the executable to run.
            Defaults to ``/tmp/<uid>_flamegraph.pl``; if that path does not
            exist, ``flamegraph.pl`` is downloaded there and made executable.

    Returns:
        str: everything the script wrote to stdout.

    Raises:
        AssertionError: if the script exits with a non-zero status.
    """
    if flamegraph_script is None:
        flamegraph_script = f'/tmp/{os.getuid()}_flamegraph.pl'
    if not os.path.exists(flamegraph_script):
        # NOTE(security): this downloads a script to a predictable /tmp path
        # and later executes it; left as-is to keep the existing behavior.
        import urllib.request
        print(f"Downloading flamegraph.pl to: {flamegraph_script}")
        urllib.request.urlretrieve(
            'https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl', flamegraph_script)
        subprocess.check_call(['chmod', '+x', flamegraph_script])
    args = [flamegraph_script, '--countname', 'bytes']
    # communicate() feeds stdin and drains stdout concurrently. Writing all of
    # stdin first and reading afterwards can deadlock once the child blocks on
    # a full stdout pipe. The context manager closes the pipes on every path.
    with subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='utf-8') as p:
        result, _ = p.communicate(flamegraph_lines)
    assert p.returncode == 0
    return result
|
| 99 |
+
|
| 100 |
+
def _write_blocks(f, prefix, blocks):
    """Write one flamegraph input line per block (or per history entry) to ``f``.

    Each line has the form ``"<prefix>;<state>;<stack> <size>"``.

    Args:
        f: object with a ``write`` method.
        prefix: text placed at the start of every line.
        blocks: iterable of block mappings. Blocks with a ``'history'`` list
            produce one line per entry, plus a ``<gaps>`` line when the
            entries do not add up to the block's ``'size'``.
    """
    def frames_fragment(frames):
        if frames:
            return ';'.join(_frames_fmt(frames, reverse=True))
        return "<non-python>"

    for block in blocks:
        if 'history' not in block:
            frames, size = _block_extra(block)
            f.write(f'{prefix};{block["state"]};{frames_fragment(frames)} {size}\n')
            continue

        covered = 0
        for entry in block['history']:
            entry_size = entry['real_size']
            covered += entry_size
            if 'frames' in entry:
                f.write(f'{prefix};{block["state"]};{frames_fragment(entry["frames"])} {entry_size}\n')
            else:
                f.write(f'{prefix};{block["state"]};<no-context> {entry_size}\n')
        # Whatever part of the block the history entries do not cover.
        uncovered = block['size'] - covered
        if uncovered:
            f.write(f'{prefix};{block["state"]};<gaps> {uncovered}\n')
|
| 122 |
+
|
| 123 |
+
def segments(snapshot, format_flamegraph=format_flamegraph):
    """Build a flamegraph of the snapshot's blocks, grouped by stream and segment.

    Args:
        snapshot: mapping with a ``'segments'`` list; each segment provides
            ``'stream'``, ``'address'`` and ``'blocks'``.
        format_flamegraph: callable that turns the collected lines into the
            final result.
    """
    buf = io.StringIO()
    for seg in snapshot['segments']:
        label = f'stream_{seg["stream"]};seg_{seg["address"]}'
        _write_blocks(buf, label, seg['blocks'])
    collected = buf.getvalue()
    return format_flamegraph(collected)
|
| 129 |
+
|
| 130 |
+
def memory(snapshot, format_flamegraph=format_flamegraph):
    """Build a flamegraph of the snapshot's blocks, grouped by stream only.

    Args:
        snapshot: mapping with a ``'segments'`` list; each segment provides
            ``'stream'`` and ``'blocks'``.
        format_flamegraph: callable that turns the collected lines into the
            final result.
    """
    buf = io.StringIO()
    for seg in snapshot['segments']:
        label = f'stream_{seg["stream"]}'
        _write_blocks(buf, label, seg['blocks'])
    collected = buf.getvalue()
    return format_flamegraph(collected)
|
| 136 |
+
|
| 137 |
+
def compare(before, after, format_flamegraph=format_flamegraph):
    """Build a flamegraph of the segments present in only one of two segment lists.

    Segments are matched on ``(address, total_size)``. The addresses unique to
    each side are printed, then the blocks of those segments are written under
    an ``only_before`` or ``only_after`` prefix.

    Args:
        before: iterable of segment mappings.
        after: iterable of segment mappings.
        format_flamegraph: callable that turns the collected lines into the
            final result.
    """
    def identity(seg):
        return (seg['address'], seg['total_size'])

    def describe(seg):
        return f'stream_{seg["stream"]};seg_{seg["address"]}'

    out = io.StringIO()

    keys_before = set(map(identity, before))
    keys_after = set(map(identity, after))

    print(f'only_before = {[a for a, _ in (keys_before - keys_after)]}')
    print(f'only_after = {[a for a, _ in (keys_after - keys_before)]}')

    sides = (
        ('only_before', before, keys_after),
        ('only_after', after, keys_before),
    )
    for tag, segs, other_keys in sides:
        for seg in segs:
            if identity(seg) not in other_keys:
                _write_blocks(out, f'{tag};{describe(seg)}', seg['blocks'])

    return format_flamegraph(out.getvalue())
|
| 161 |
+
|
| 162 |
+
def _format_size(num):
    """Render a byte count with binary (1024-based) unit prefixes.

    The value is divided by 1024 until its magnitude is below 1024, and is
    printed with one decimal; anything beyond ``Zi`` is reported in ``YiB``.
    """
    # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size
    scaled = num
    for prefix in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
        if abs(scaled) < 1024.0:
            return f"{scaled:3.1f}{prefix}B"
        scaled /= 1024.0
    return f"{scaled:.1f}YiB"
|
| 169 |
+
|
| 170 |
+
class Bytes:
    """Wrapper around a byte count whose ``repr`` is the human-readable size."""

    def __init__(self, value):
        self.value = value

    def __add__(self, rhs):
        # `rhs` is added to the raw value as-is; the sum is wrapped again.
        total = self.value + rhs
        return Bytes(total)

    def __repr__(self):
        return _format_size(self.value)
|
| 179 |
+
|
| 180 |
+
def calc_active(seg):
    """Return the summed ``'size'`` of the segment's ``'active_allocated'`` blocks."""
    active_bytes = 0
    for block in seg['blocks']:
        if block['state'] == 'active_allocated':
            active_bytes += block['size']
    return active_bytes
|
| 182 |
+
|
| 183 |
+
def _report_free(free_external, free_internal):
    """Describe the free memory and, when non-zero, its internal share.

    Args:
        free_external: free bytes counted as external.
        free_internal: free bytes counted as internal.

    Returns:
        The formatted total, followed by `` (X% internal)`` unless the total
        is zero.
    """
    total = free_external + free_internal
    if total == 0:
        # No free memory at all: a percentage would divide by zero.
        return f'{Bytes(total)}'
    pct = (free_internal / total) * 100
    return f'{Bytes(total)} ({pct:.1f}% internal)'
|
| 190 |
+
|
| 191 |
+
PAGE_SIZE = 1024 * 1024 * 20
|
| 192 |
+
legend = f"""\
|
| 193 |
+
|
| 194 |
+
Legend:
|
| 195 |
+
[a ] - a segment in the allocator
|
| 196 |
+
^-- a page {Bytes(PAGE_SIZE)} of memory in the segment
|
| 197 |
+
a-z: pages filled with a single block's content
|
| 198 |
+
' ': page is completely free
|
| 199 |
+
*: page if completely full with multiple blocks
|
| 200 |
+
0-9: page is partially full with tensors of multiple blocks (9 == 90% full)
|
| 201 |
+
(X% internal) - of the free memory, X% is free because we rounded the size of the allocation.
|
| 202 |
+
"""
|
| 203 |
+
|
| 204 |
+
def segsum(data):
    r"""Visually reports how the allocator has filled its segments.

    This printout can help debug fragmentation issues since free fragments
    will appear as gaps in this printout. The amount of free space is reported
    for each segment.
    We distinguish between internal free memory which occurs because the
    allocator rounds the allocation size, and external free memory, which are
    the gaps between allocations in a segment.

    Only segments of at least ``PAGE_SIZE`` bytes get a line of their own; the
    totals at the end cover every segment.

    Args:
        data: snapshot dictionary created from _snapshot()

    Returns:
        str: the textual report, ending with the legend.
    """
    out = io.StringIO()
    out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n")
    total_reserved = 0
    total_allocated = 0
    free_external = 0
    free_internal = 0
    for seg in sorted(data['segments'], key=lambda x: (x['total_size'], calc_active(x))):
        total_reserved += seg['total_size']

        seg_free_external = 0
        seg_free_internal = 0
        seg_allocated = 0
        all_ranges = []
        boffset = 0  # running offset of the current block inside the segment
        for b in seg['blocks']:
            active = b['state'] == 'active_allocated'
            if active:
                _, allocated_size = _block_extra(b)
                all_ranges.append((boffset, allocated_size, True))
                seg_allocated += allocated_size
                # The part of the block beyond the allocated size is
                # internal free memory.
                seg_free_internal += b['size'] - allocated_size
            else:
                seg_free_external += b['size']

            boffset += b['size']

        total_allocated += seg_allocated
        free_external += seg_free_external
        free_internal += seg_free_internal

        # Number of pages needed to cover the segment (ceiling division).
        nseg = (seg['total_size'] - 1) // PAGE_SIZE + 1
        occupied = [' ' for _ in range(nseg)]
        frac = [0.0 for _ in range(nseg)]
        for i, (start_, size, active) in enumerate(all_ranges):
            finish_ = (start_ + size)
            start = start_ // PAGE_SIZE
            finish = (finish_ - 1) // PAGE_SIZE + 1
            m = chr(ord('a' if active else 'A') + (i % 26))
            for j in range(start, finish):
                # Portion of page j that this range covers.
                s = max(start_, j * PAGE_SIZE)
                e = min(finish_, (j + 1) * PAGE_SIZE)
                frac[j] += (e - s) / PAGE_SIZE
                if occupied[j] != ' ':
                    # Page shared by several blocks: show the fill level as a
                    # digit, or '*' once it is completely full.
                    occupied[j] = '0123456789*'[int(frac[j] * 10)]
                else:
                    occupied[j] = m
        body = ''.join(occupied)
        assert seg_free_external + seg_free_internal + seg_allocated == seg['total_size']
        stream = f' stream_{seg["stream"]}' if seg['stream'] != 0 else ''
        if seg['total_size'] >= PAGE_SIZE:
            out.write(f'[{body}] {Bytes(seg["total_size"])} allocated, '
                      f'{_report_free(seg_free_external, seg_free_internal)} free{stream}\n')
    out.write(f'segments: {len(data["segments"])}\n')
    out.write(f'total_reserved: {Bytes(total_reserved)}\n')
    out.write(f'total_allocated: {Bytes(total_allocated)}\n')
    out.write(f'total_free: {_report_free(free_external, free_internal)}\n')
    out.write(legend)
    assert free_internal + free_external + total_allocated == total_reserved
    return out.getvalue()
|
| 280 |
+
|
| 281 |
+
def trace(data):
    """Render the recorded allocator events of each device as pseudo-Python text.

    Allocations and segments are given short variable-like names
    (``a``, ``b``, ..., ``a1``, ...). A name is handed out again once the
    allocation or segment that carried it has been freed.

    Args:
        data: snapshot dictionary providing ``'device_traces'`` (one list of
            event dicts per device) and ``'segments'``.

    Returns:
        str: the formatted trace of every device that has events.
    """
    out = io.StringIO()

    def _format_device(entries):
        segment_intervals: list = []
        segment_addr_to_name = {}
        allocation_addr_to_name = {}

        free_names: list = []
        next_name = 0

        def _name():
            # Prefer a recycled name; otherwise mint the next one.
            nonlocal next_name
            if free_names:
                return free_names.pop()
            r, m = next_name // 26, next_name % 26
            next_name += 1
            return f'{chr(ord("a") + m)}{"" if r == 0 else r}'

        def find_segment(addr):
            # Segments created during the trace take precedence over the
            # segments listed in the snapshot.
            for name, saddr, size in segment_intervals:
                if saddr <= addr < saddr + size:
                    return name, saddr
            for i, seg in enumerate(data['segments']):
                saddr = seg['address']
                # The address range of a segment is its total size, not just
                # the part that is currently allocated.
                size = seg['total_size']
                if saddr <= addr < saddr + size:
                    return f'seg_{i}', saddr
            return None, None

        out.write(f'{len(entries)} entries\n')

        # Running total of allocated bytes. It is kept separate from the
        # loop index, which enumerate() reassigns on every iteration.
        total_mem = 0
        for index, e in enumerate(entries):
            action = e['action']
            if action == 'alloc':
                addr, size = e['addr'], e['size']
                n = _name()
                seg_name, seg_addr = find_segment(addr)
                if seg_name is None:
                    seg_name = "MEM"
                    offset = addr
                else:
                    offset = addr - seg_addr
                out.write(f'{n} = {seg_name}[{offset}:{Bytes(size)}]\n')
                allocation_addr_to_name[addr] = (n, size, index)
                total_mem += size
            elif action == 'free_requested':
                addr, size = e['addr'], e['size']
                name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
                out.write(f'del {name} # {Bytes(size)}\n')
            elif action == 'free_completed':
                addr, size = e['addr'], e['size']
                total_mem -= size
                name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
                out.write(f'# free completed for {name} {Bytes(size)}\n')
                # The table is keyed by address, so look up the address (not
                # the name) when releasing the name for reuse.
                if addr in allocation_addr_to_name:
                    free_names.append(name)
                    del allocation_addr_to_name[addr]
            elif action == 'segment_alloc':
                addr, size = e['addr'], e['size']
                name = _name()
                out.write(f'{name} = cudaMalloc({addr}, {Bytes(size)})\n')
                segment_intervals.append((name, addr, size))
                segment_addr_to_name[addr] = name
            elif action == 'segment_free':
                addr, size = e['addr'], e['size']
                name = segment_addr_to_name.get(addr, addr)
                out.write(f'cudaFree({name}) # {Bytes(size)}\n')
                if addr in segment_addr_to_name:
                    free_names.append(name)
                    del segment_addr_to_name[addr]
                    # Drop the interval too, so later allocations are not
                    # attributed to a freed segment or its recycled name.
                    segment_intervals[:] = [
                        interval for interval in segment_intervals if interval[1] != addr
                    ]
            elif action == 'oom':
                size = e['size']
                free = e['device_free']
                out.write(f'raise OutOfMemoryError # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n')
            else:
                out.write(f'{e}\n')
        # Terminate the line so the next device header starts on its own line.
        out.write(f"TOTAL MEM: {Bytes(total_mem)}\n")

    for i, d in enumerate(data['device_traces']):
        if d:
            out.write(f'Device {i} ----------------\n')
            _format_device(d)
    return out.getvalue()
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
_memory_viz_template = r"""
|
| 371 |
+
<!DOCTYPE html>
|
| 372 |
+
<html>
|
| 373 |
+
<head>
|
| 374 |
+
</head>
|
| 375 |
+
<body>
|
| 376 |
+
<script type="module">
|
| 377 |
+
import {add_local_files} from "https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/utils/viz/MemoryViz.js"
|
| 378 |
+
const local_files = $SNAPSHOT
|
| 379 |
+
add_local_files(local_files, $VIZ_KIND)
|
| 380 |
+
</script>
|
| 381 |
+
</body>
|
| 382 |
+
"""
|
| 383 |
+
|
| 384 |
+
def _format_viz(data, viz_kind, device):
    """Embed a pickled snapshot into the HTML visualization template.

    Args:
        data: object to pickle and embed.
        viz_kind: value whose ``repr`` replaces ``$VIZ_KIND`` in the template.
        device: deprecated; any value other than ``None`` only triggers a
            ``FutureWarning``.

    Returns:
        str: the template with ``$VIZ_KIND`` and ``$SNAPSHOT`` substituted.
    """
    if device is not None:
        warnings.warn(
            'device argument is deprecated, plots now contain all device',
            FutureWarning,
            stacklevel=3,
        )
    payload = pickle.dumps(data)
    # Append zero bytes so the length becomes a multiple of three.
    padding = 3 - len(payload) % 3
    payload = payload + b'\x00' * padding
    # Encode the buffer with base64
    encoded = base64.b64encode(payload).decode('utf-8')

    snapshot_json = json.dumps([{"name": 'snapshot.pickle', "base64": encoded}])
    html = _memory_viz_template.replace('$VIZ_KIND', repr(viz_kind))
    return html.replace('$SNAPSHOT', snapshot_json)
|
| 399 |
+
|
| 400 |
+
def trace_plot(data, device=None, plot_segments=False):
    """Generate a visualization over time of the memory usage recorded by the trace as an html file.

    Args:
        data: Memory snapshot as generated from torch.cuda.memory._snapshot()
        device (torch.device, optional): Deprecated; passing a value only
            results in a ``FutureWarning`` from the formatting helper.
        plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
                                        Defaults to False.

    Returns:
        str: HTML of visualization
    """
    if plot_segments:
        viz_kind = 'Active Cached Memory Timeline'
    else:
        viz_kind = 'Active Memory Timeline'
    return _format_viz(data, viz_kind, device)
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def _profile_to_snapshot(profile):
    """Convert a memory profile into a snapshot-style dictionary.

    The result has one ``'device_traces'`` list and one synthetic segment per
    CUDA device plus one extra slot (index ``device_count``) that collects
    every non-CUDA device. Segments that end up without blocks are dropped.

    Args:
        profile: object exposing ``_memory_profile()``
            (presumably a ``torch.profiler.profile`` run with memory
            profiling enabled -- confirm against callers).

    Returns:
        dict: ``{'device_traces': [...], 'segments': [...]}``.
    """
    import torch
    from torch.profiler._memory_profiler import Action, TensorKey
    from torch._C._profiler import _EventType
    memory_profile = profile._memory_profile()

    # Map each allocation to the chain of Python calls enclosing it.
    allocation_stacks = {}
    for event in memory_profile._op_tree.sorted_nodes:
        if event.tag == _EventType.Allocation:
            parent = event.parent
            python_parents = []
            while parent:
                if parent.tag in (_EventType.PyCall, _EventType.PyCCall):
                    python_parents.append(parent)
                parent = parent.parent
            key = TensorKey.from_allocation(event.extra_fields)

            # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor)
            #              key will be None. I should add some way to identify these, I just haven't yet.
            if key and event.extra_fields.alloc_size > 0:
                allocation_stacks[key] = python_parents

    device_count = torch.cuda.device_count()
    snapshot = {
        'device_traces': [[] for _ in range(device_count + 1)],
        'segments': [{'device': device,
                      'address': None,
                      'total_size': 0,
                      'stream': 0,
                      'blocks': []} for device in range(device_count + 1)]
    }

    def to_device(device):
        # CUDA devices keep their index; everything else shares the last slot.
        if device.type == 'cuda':
            return device.index
        else:
            return device_count

    def allocate(size, tensor_key, version, during_trace=True):
        device = to_device(tensor_key.device)
        addr = tensor_key.storage.ptr

        seg = snapshot['segments'][device]  # type: ignore[index]
        if seg['address'] is None or seg['address'] > addr:
            seg['address'] = addr
        seg['total_size'] = max(seg['total_size'], addr + size)  # record max addr for now, we will make it the size later
        category = memory_profile._categories.get(tensor_key, version)
        category = category.name.lower() if category is not None else "unknown"
        stack = allocation_stacks.get(tensor_key, ())
        stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack]
        r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category}
        if during_trace:
            snapshot['device_traces'][device].append(r)  # type: ignore[index]
        return r

    def free(alloc, device):
        for e in ('free_requested', 'free_completed'):
            snapshot['device_traces'][device].append({'action': e,  # type: ignore[index]
                                                      'addr': alloc['addr'],
                                                      'size': alloc['size'],
                                                      'stream': 0,
                                                      'frames': alloc['frames']})

    kv_to_elem = {}

    # create the device trace
    for _time, action, (tensor_key, version), size in memory_profile.timeline:
        if not isinstance(tensor_key, TensorKey):
            continue
        if action == Action.CREATE:
            kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version)
        elif action == Action.DESTROY:
            free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
        elif action == Action.INCREMENT_VERSION:
            free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
            kv_to_elem[(tensor_key, version + 1)] = allocate(size, tensor_key, version + 1)
        elif action == Action.PREEXISTING:
            kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version, during_trace=False)

    # create the final snapshot state
    blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames'])
                     for (tensor_key, version), event in kv_to_elem.items()]
    # Sort on (device, addr, size) only: letting tuple comparison fall through
    # to the frames (lists of dicts) raises TypeError when the first three
    # fields tie.
    ordered_blocks = sorted(blocks_at_end, key=operator.itemgetter(0, 1, 2))
    for device, blocks in groupby(ordered_blocks, key=operator.itemgetter(0)):
        seg = snapshot['segments'][device]  # type: ignore[index]
        last_addr = seg['address']
        for _, addr, size, frames in blocks:
            if last_addr < addr:
                # Gap between the previous block and this one.
                seg['blocks'].append({'size': addr - last_addr, 'state': 'inactive'})
            seg['blocks'].append({'size': size, 'state': 'active_allocated', 'requested_size': size, 'frames': frames})
            last_addr = addr + size
        if last_addr < seg['total_size']:
            seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'})

    # Only segments with blocks survive, so no empty-segment handling is
    # needed below.
    snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']]  # type: ignore[attr-defined]
    for seg in snapshot['segments']:  # type: ignore[attr-defined, name-defined, no-redef]
        # total_size held the highest end address so far; turn it into a size.
        seg['total_size'] -= seg['address']

    return snapshot
|
| 519 |
+
|
| 520 |
+
def profile_plot(profile, device=None):
    """Generate a visualization over time of the memory usage recorded by kineto memory profiling as an html file.

    Args:
        profile: profile as generated by `torch.profiler.profile(profile_memory=True)`
        device (torch.device, optional): Deprecated; passing a value only
            results in a ``FutureWarning`` from the formatting helper.

    Returns:
        str: HTML of visualization
    """
    return _format_viz(_profile_to_snapshot(profile), 'Active Memory Timeline', device)
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
def segment_plot(data: Any, device=None):
    """Generate the 'Allocator State History' visualization as HTML.

    Args:
        data: snapshot object that is pickled into the page.
        device: Deprecated; passing a value only results in a
            ``FutureWarning`` from the formatting helper.

    Returns:
        str: HTML of visualization
    """
    viz_kind = 'Allocator State History'
    return _format_viz(data, viz_kind, device)
|
| 536 |
+
|
| 537 |
+
if __name__ == "__main__":
    # Command line front end: each sub-command loads one or two pickled
    # snapshots and either prints a report or writes a visualization file.
    import os.path
    thedir = os.path.realpath(os.path.dirname(__file__))
    if thedir in sys.path:
        # otherwise we find cuda/random.py as random...
        sys.path.remove(thedir)
    import argparse

    fn_name = 'torch.cuda.memory._snapshot()'
    pickled = f'pickled memory statistics from {fn_name}'
    parser = argparse.ArgumentParser(description=f'Visualize memory dumps produced by {fn_name}')

    subparsers = parser.add_subparsers(dest='action')

    def _output(p):
        p.add_argument('-o', '--output', default='output.svg', help='flamegraph svg (default: output.svg)')

    description = 'Prints overall allocation statistics and a visualization of how the allocators segments are currently filled.'
    stats_a = subparsers.add_parser('stats', description=description)
    stats_a.add_argument('input', help=pickled)

    description = 'Prints buffer of the most recent allocation events embedded in the snapshot in a Pythonic style.'
    trace_a = subparsers.add_parser('trace', description=description)
    trace_a.add_argument('input', help=pickled)

    description = 'Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)'
    segments_a = subparsers.add_parser('segments', description=description)
    segments_a.add_argument('input', help=pickled)
    _output(segments_a)

    description = "Generate a flamegraph the program locations contributing to CUDA memory usage."
    memory_a = subparsers.add_parser('memory', description=description)
    memory_a.add_argument('input', help=pickled)
    _output(memory_a)

    description = 'Generate a flamegraph that shows segments (aka blocks) that have been added ' \
        'or removed between two different memorys snapshots.'
    compare_a = subparsers.add_parser('compare', description=description)
    compare_a.add_argument('before', help=pickled)
    compare_a.add_argument('after', help=pickled)
    _output(compare_a)

    plots = (
        ("trace_plot", "Generate a visualization over time of the memory usage recorded by the trace as an html file."),
        ("segment_plot", "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.")
    )
    for cmd, description in plots:
        trace_plot_a = subparsers.add_parser(cmd, description=description)
        trace_plot_a.add_argument('input', help=pickled)
        help_text = 'visualize trace from this device (default: chooses the only device with trace info or errors)'
        trace_plot_a.add_argument('-d', '--device', type=int, default=None, help=help_text)
        help_text = 'path to save the visualization(default: output.html)'
        trace_plot_a.add_argument('-o', '--output', default='output.html', help=help_text)
        if cmd == "trace_plot":
            help_text = 'visualize change to segments rather than individual allocations'
            trace_plot_a.add_argument('-s', '--segments', action='store_true', help=help_text)

    args = parser.parse_args()

    def _read(name):
        # Load a pickled snapshot from a file, or from stdin when name is '-'.
        # NOTE: unpickling can execute arbitrary code; only open snapshots
        # that come from a trusted source.
        if name == '-':
            data = pickle.load(sys.stdin.buffer)
        else:
            # Close the file as soon as it has been read.
            with open(name, 'rb') as f:
                data = pickle.load(f)
        if isinstance(data, list):  # segments only...
            # 'device_traces' is the key the trace report reads; keep the
            # older 'traces' key as well.
            data = {'segments': data, 'traces': [], 'device_traces': []}
        return data

    def _write(name, data):
        with open(name, 'w') as f:
            f.write(data)

    if args.action == 'segments':
        data = _read(args.input)
        _write(args.output, segments(data))
    elif args.action == 'memory':
        data = _read(args.input)
        _write(args.output, memory(data))
    elif args.action == 'stats':
        data = _read(args.input)
        print(segsum(data))
    elif args.action == 'trace':
        data = _read(args.input)
        print(trace(data))
    elif args.action == 'compare':
        before = _read(args.before)
        after = _read(args.after)
        # _read always returns a snapshot dict; compare works on the segment
        # lists themselves.
        _write(args.output, compare(before['segments'], after['segments']))
    elif args.action == 'trace_plot':
        data = _read(args.input)
        _write(args.output, trace_plot(data, device=args.device, plot_segments=args.segments))
    elif args.action == 'segment_plot':
        data = _read(args.input)
        _write(args.output, segment_plot(data, device=args.device))
|
.venv/lib/python3.11/site-packages/torch/cuda/_sanitizer.py
ADDED
|
@@ -0,0 +1,621 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
r"""
|
| 3 |
+
This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels ran on different streams.
|
| 4 |
+
|
| 5 |
+
It stores information on accesses to tensors to determine if they are synchronized
|
| 6 |
+
or not. When enabled in a python program and a possible data race is detected, a
|
| 7 |
+
detailed warning will be printed and the program will exit.
|
| 8 |
+
|
| 9 |
+
It can be enabled either by importing this module and calling
|
| 10 |
+
:func:`enable_cuda_sanitizer()` or by exporting the ``TORCH_CUDA_SANITIZER``
|
| 11 |
+
environment variable.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import enum
|
| 15 |
+
import functools
|
| 16 |
+
import inspect
|
| 17 |
+
import io
|
| 18 |
+
import logging
|
| 19 |
+
import sys
|
| 20 |
+
import textwrap
|
| 21 |
+
import traceback
|
| 22 |
+
from dataclasses import dataclass, field
|
| 23 |
+
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch.cuda._gpu_trace as gpu_trace
|
| 27 |
+
from torch.utils import _pytree as pytree
|
| 28 |
+
from torch.utils._python_dispatch import TorchDispatchMode
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
DEFAULT_STREAM_ID = 0
|
| 32 |
+
|
| 33 |
+
TK = TypeVar("TK")
|
| 34 |
+
TVa = TypeVar("TVa")
|
| 35 |
+
TVb = TypeVar("TVb")
|
| 36 |
+
|
| 37 |
+
DataPtr = int
|
| 38 |
+
StreamId = int
|
| 39 |
+
EventId = int
|
| 40 |
+
SeqNum = int
|
| 41 |
+
|
| 42 |
+
logger = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class AccessType(enum.Enum):
    """Kind of access a kernel performs on a tensor.

    ``str()`` yields the phrase used in error messages: ``"reading from"``
    for ``READ`` and ``"writing to"`` for ``WRITE``.
    """

    READ = enum.auto()
    WRITE = enum.auto()

    def __str__(self):
        if self is AccessType.READ:
            return "reading from"
        return "writing to"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass
class Access:
    r"""Stores information about a single access to a tensor by a kernel.

    Args:
        type: either AccessType.READ or AccessType.WRITE.
        seq_num: the sequential number of the kernel performing the access.
        stream: the stream id of the stream executing the kernel.
        operator: the schema of the launched kernel, which lists the
            arguments and return type.
        aliases: the arguments in the schema this access corresponds to.
        is_output: Whether the tensor was an output of the kernel.
        stack_trace: the stack summary object captured during access.
    """

    type: AccessType
    seq_num: SeqNum
    stream: StreamId
    operator: str
    aliases: List[str]
    is_output: bool
    stack_trace: traceback.StackSummary
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class SynchronizationError(Exception):
    """Base class for errors detected by CUDA Sanitizer.

    It adds no behavior of its own; concrete error kinds derive from it.
    """
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class UnsynchronizedAccessError(SynchronizationError):
    """Stores information about two unsynchronized accesses to one data pointer.

    ``str()`` of the error is a multi-line report describing the current
    access, the previous access and, when available, the stack trace of the
    tensor's allocation.
    """

    def __init__(
        self,
        data_ptr: DataPtr,
        allocation_stack_trace: Optional[traceback.StackSummary],
        current_access: Access,
        previous_access: Access,
    ):
        self.data_ptr = data_ptr
        self.allocation_stack_trace = allocation_stack_trace
        self.current_access = current_access
        self.previous_access = previous_access

    def __str__(self):
        def describe(access: Access) -> str:
            # Kernel schema, how it touched the tensor, then its stack trace.
            pieces = [f"{access.operator}\n{access.type}"]
            if access.aliases:
                pieces.append(" argument(s) " + ", ".join(access.aliases))
                if access.is_output:
                    pieces.append(", and to")
            if access.is_output:
                pieces.append(" the output")
            pieces.append(
                f"\nWith stack trace:\n{''.join(access.stack_trace.format())}\n"
            )
            return "".join(pieces)

        header = textwrap.dedent(
            f"""\
            ============================
            CSAN detected a possible data race on tensor with data pointer {self.data_ptr}
            Access by stream {self.current_access.stream} during kernel:
            """
        )
        sections = [
            header,
            describe(self.current_access),
            f"Previous access by stream {self.previous_access.stream} during kernel:\n",
            describe(self.previous_access),
        ]

        if self.allocation_stack_trace:
            sections.append(
                "Tensor was allocated with stack trace:\n"
                f"{''.join(self.allocation_stack_trace.format())}"
            )
        else:
            sections.append("Trace for tensor allocation not found.")
        return "".join(sections)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class CUDASanitizerErrors(Exception):
    """Wrapper class for errors reported by CUDA Sanitizer.

    The individual errors are kept in ``errors``; ``str()`` only reports how
    many there are.
    """

    def __init__(self, errors: List[SynchronizationError]):
        self.errors = errors

    def __str__(self):
        error_count = len(self.errors)
        return f"detected {error_count} errors"
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@dataclass
class TensorInfo:
    r"""Stores information about a single tensor and recent accesses to it.

    Args:
        allocation_stack_trace: the stack summary object captured during tensor
            allocation. Can be ``None`` if the allocation wasn't caught by CSAN.
        reads: list of read accesses to the tensor that were performed since
            the last write. Defaults to a fresh empty list per instance.
        write: the last write access to the tensor, or ``None`` (the default)
            if no write has been recorded.
    """

    allocation_stack_trace: Optional[traceback.StackSummary]
    reads: List[Access] = field(default_factory=list)
    write: Optional[Access] = None
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class _TensorsAccessed:
    """Bookkeeping of tracked tensors, keyed by data pointer.

    Every tracked pointer maps to a ``TensorInfo`` holding the allocation
    stack trace, the reads since the last write and the last write itself.
    """

    def __init__(self) -> None:
        self.accesses: Dict[DataPtr, TensorInfo] = {}

    def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
        """Start tracking ``data_ptr`` (without a stack trace) if it is unknown."""
        if data_ptr in self.accesses:
            return
        logger.info(
            "Found tensor with pointer: %s, but no matching tensor "
            "allocation in the trace. Backfilling the trace now. "
            "Perhaps the sanitizer was enabled after some torch operations?",
            data_ptr,
        )
        self.create_tensor(data_ptr, None)

    def ensure_tensor_does_not_exist(self, data_ptr: DataPtr) -> None:
        """Stop tracking ``data_ptr`` if it is currently known."""
        if data_ptr not in self.accesses:
            return
        logger.info(
            "Found duplicate tensor allocation in the trace for tensor with "
            "pointer: %s. Assuming the trace for tensor deallocation "
            "wasn't caught and backfilling it now. "
            "Perhaps the sanitizer was enabled after some torch operations?",
            data_ptr,
        )
        self.delete_tensor(data_ptr)

    def create_tensor(
        self, data_ptr: DataPtr, stack_trace: Optional[traceback.StackSummary]
    ) -> None:
        """Register ``data_ptr`` with a fresh, empty access record."""
        info = TensorInfo(stack_trace)
        self.accesses[data_ptr] = info

    def delete_tensor(self, data_ptr: DataPtr) -> None:
        """Forget ``data_ptr``; raises ``KeyError`` if it is not tracked."""
        del self.accesses[data_ptr]

    def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool:
        """Return whether any read has been recorded since the last write."""
        return bool(self.accesses[data_ptr].reads)

    def get_allocation_stack_trace(
        self, data_ptr: DataPtr
    ) -> Optional[traceback.StackSummary]:
        """Return the stack trace stored when the tensor was registered."""
        info = self.accesses[data_ptr]
        return info.allocation_stack_trace

    def get_write(self, data_ptr: DataPtr) -> Optional[Access]:
        """Return the last recorded write, or ``None`` if there is none."""
        info = self.accesses[data_ptr]
        return info.write

    def get_reads(self, data_ptr: DataPtr) -> List[Access]:
        """Return the list of reads recorded since the last write."""
        info = self.accesses[data_ptr]
        return info.reads

    def add_read(self, data_ptr: DataPtr, access: Access) -> None:
        """Append ``access`` to the reads of ``data_ptr``."""
        self.accesses[data_ptr].reads.append(access)

    def set_write(self, data_ptr: DataPtr, access: Access) -> None:
        """Record ``access`` as the last write and reset the read list."""
        info = self.accesses[data_ptr]
        info.write = access
        info.reads = []
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class StreamSynchronizations:
    """Tracks, per stream, the latest sequence number of every stream it is synchronized with.

    ``current_sync_states[s][t]`` is the highest sequence number of stream ``t``
    that stream ``s`` is known to be ordered after. ``recorded_sync_states``
    holds the snapshot stored by each event, and ``host_sync_state`` is the
    equivalent view for the host.
    """

    def __init__(self) -> None:
        self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
        self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
        self.host_sync_state: Dict[StreamId, SeqNum] = {}
        self.create_stream(DEFAULT_STREAM_ID)

    def _ensure_stream_exists(self, stream: StreamId) -> None:
        """Backfill ``stream`` if its creation was never observed."""
        if stream in self.current_sync_states:
            return
        logger.info(
            "Found Stream with id: %s, but no matching stream "
            "creation in the trace. Backfilling the trace now. "
            "Perhaps the sanitizer was enabled after some torch operations?",
            stream,
        )
        self.create_stream(stream)

    def _ensure_event_exists(self, event: EventId) -> None:
        """Backfill ``event`` if its creation was never observed."""
        if event in self.recorded_sync_states:
            return
        logger.info(
            "Found Event with id: %s, but no matching event "
            "creation in the trace. Backfilling the trace now. "
            "Perhaps the sanitizer was enabled after some torch operations?",
            event,
        )
        self.create_event(event)

    def _ensure_event_does_not_exist(self, event: EventId) -> None:
        """Backfill the deletion of ``event`` if it is still registered."""
        if event not in self.recorded_sync_states:
            return
        logger.info(
            "Found duplicate event creation in the trace for event with "
            "id: %s. Assuming the trace for event deletion wasn't caught "
            "and backfilling it now. "
            "Perhaps the sanitizer was enabled after some torch operations?",
            event,
        )
        self.delete_event(event)

    def create_stream(self, stream: StreamId) -> None:
        """Register ``stream``; a duplicate creation is logged and ignored."""
        if stream not in self.current_sync_states:
            self.host_sync_state[stream] = 0
            # The new stream starts out with a copy of the host's view.
            self.current_sync_states[stream] = self.host_sync_state.copy()
            return
        logger.info(
            "Found duplicate Stream creation in the trace for Stream with "
            "id: %s. PyTorch Streams are only created once, so this "
            "trace entry is ignored.",
            stream,
        )

    def create_event(self, event: EventId) -> None:
        """Register ``event`` with an empty recorded state."""
        self._ensure_event_does_not_exist(event)
        self.recorded_sync_states[event] = {}

    def delete_event(self, event: EventId) -> None:
        """Forget ``event`` and its recorded state."""
        self._ensure_event_exists(event)
        del self.recorded_sync_states[event]

    def update_seq_num(self, stream: StreamId, seq_num: SeqNum) -> None:
        """Set the sequence number ``stream`` has reached on itself."""
        self._ensure_stream_exists(stream)
        self.current_sync_states[stream][stream] = seq_num

    def record_state(self, event: EventId, stream: StreamId) -> None:
        """Store a snapshot of ``stream``'s current state in ``event``."""
        self._ensure_event_exists(event)
        self._ensure_stream_exists(stream)
        snapshot = self.current_sync_states[stream].copy()
        self.recorded_sync_states[event] = snapshot

    def _state_wait_for_other(
        self, state: Dict[StreamId, SeqNum], other: Dict[StreamId, SeqNum]
    ) -> None:
        """Merge ``other`` into ``state`` in place, keeping the per-stream maximum."""
        merged = {
            stream: max(state.get(stream, -1), seq_num)
            for stream, seq_num in other.items()
        }
        state.update(merged)

    def stream_wait_for_event(self, stream: StreamId, event: EventId) -> None:
        """Make ``stream`` ordered after everything recorded in ``event``."""
        self._ensure_stream_exists(stream)
        self._ensure_event_exists(event)
        target = self.current_sync_states[stream]
        self._state_wait_for_other(target, self.recorded_sync_states[event])

    def all_streams_wait_for_event(self, event: EventId) -> None:
        """Make every stream and the host ordered after ``event``."""
        self._ensure_event_exists(event)
        for stream in self.current_sync_states:
            self.stream_wait_for_event(stream, event)

        recorded = self.recorded_sync_states[event]
        self._state_wait_for_other(self.host_sync_state, recorded)

    def all_streams_wait_for_stream(self, stream: StreamId) -> None:
        """Make every stream and the host ordered after ``stream``'s current state."""
        self._ensure_stream_exists(stream)
        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.current_sync_states[stream])

        source = self.current_sync_states[stream]
        self._state_wait_for_other(self.host_sync_state, source)

    def sync_all_streams(self) -> None:
        """Bring the host up to each stream's own sequence number, then every stream up to the host."""
        self.host_sync_state.update(
            {stream: state[stream] for stream, state in self.current_sync_states.items()}
        )

        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.host_sync_state)

    def is_ordered_after(
        self, current_stream: StreamId, seq_num: SeqNum, other_stream: StreamId
    ) -> bool:
        """Return whether ``current_stream`` is synchronized with ``other_stream`` up to ``seq_num``."""
        self._ensure_stream_exists(current_stream)
        self._ensure_stream_exists(other_stream)
        # -1 means current_stream has never synchronized with other_stream.
        latest_seen = self.current_sync_states[current_stream].get(other_stream, -1)
        return seq_num <= latest_seen
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class EventHandler:
    """Analyzes the CSAN trace for synchronization errors.

    Keeps each stream's synchronization state together with the recorded tensor
    accesses, and uses both to decide whether a kernel launch might cause a
    data race.
    """

    def __init__(self) -> None:
        self.tensors_accessed = _TensorsAccessed()
        self.syncs = StreamSynchronizations()
        self.seq_num: SeqNum = 0

    def _handle_kernel_launch(
        self,
        stream: StreamId,
        read_only: Set[DataPtr],
        read_write: Set[DataPtr],
        outputs: Set[DataPtr],
        operator: str,
        tensor_aliases: Dict[int, List[str]],
    ) -> List[SynchronizationError]:
        """Record the accesses of one kernel launch and return any conflicts found.

        Args:
            stream: id of the stream the kernel was launched on.
            read_only: data pointers the kernel only reads.
            read_write: data pointers the kernel writes.
            outputs: data pointers that are outputs of the operator.
            operator: description of the operator, stored on each ``Access``.
            tensor_aliases: argument names under which each data pointer was passed.

        Returns:
            One ``UnsynchronizedAccessError`` per earlier access that the
            current stream is not ordered after.
        """
        tracked = self.tensors_accessed
        error_list: List[SynchronizationError] = []

        def check_conflict(
            data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access]
        ) -> None:
            if previous_access is None:
                return
            synchronized = self.syncs.is_ordered_after(
                current_access.stream, previous_access.seq_num, previous_access.stream
            )
            if synchronized:
                return
            error_list.append(
                UnsynchronizedAccessError(
                    data_ptr,
                    tracked.get_allocation_stack_trace(data_ptr),
                    current_access,
                    previous_access,
                )
            )

        def make_access(kind: AccessType, data_ptr: DataPtr) -> Access:
            return Access(
                kind,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )

        # Every launch gets a fresh sequence number, recorded on its own stream.
        self.seq_num += 1
        self.syncs.update_seq_num(stream, self.seq_num)

        # walk_stack yields frames innermost-first, so the summary is reversed below.
        # currentframe() must be evaluated here, in this method's own frame.
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        stack_trace.reverse()

        for data_ptr in read_only:
            tracked.ensure_tensor_exists(data_ptr)
            current_access = make_access(AccessType.READ, data_ptr)
            # A read only conflicts with the last write.
            check_conflict(data_ptr, current_access, tracked.get_write(data_ptr))
            tracked.add_read(data_ptr, current_access)

        for data_ptr in read_write:
            tracked.ensure_tensor_exists(data_ptr)
            current_access = make_access(AccessType.WRITE, data_ptr)
            # A write is checked against the reads made since the last write,
            # or against that last write when no reads were recorded.
            if tracked.were_there_reads_since_last_write(data_ptr):
                earlier_accesses = tracked.get_reads(data_ptr)
            else:
                earlier_accesses = [tracked.get_write(data_ptr)]
            for previous_access in earlier_accesses:
                check_conflict(data_ptr, current_access, previous_access)
            tracked.set_write(data_ptr, current_access)

        return error_list

    def _handle_event_creation(self, event: EventId) -> None:
        """Trace callback: an event was created."""
        self.syncs.create_event(event)

    def _handle_event_deletion(self, event: EventId) -> None:
        """Trace callback: an event was deleted."""
        self.syncs.delete_event(event)

    def _handle_event_record(self, event: EventId, stream: StreamId) -> None:
        """Trace callback: ``event`` was recorded on ``stream``."""
        self.syncs.record_state(event, stream)

    def _handle_event_wait(self, event: EventId, stream: StreamId) -> None:
        """Trace callback: ``stream`` waits for ``event``."""
        self.syncs.stream_wait_for_event(stream, event)

    def _handle_memory_allocation(self, data_ptr: DataPtr) -> None:
        """Trace callback: memory at ``data_ptr`` was allocated; stores the allocation stack."""
        self.tensors_accessed.ensure_tensor_does_not_exist(data_ptr)
        # walk_stack yields frames innermost-first, so the summary is reversed.
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        stack_trace.reverse()
        self.tensors_accessed.create_tensor(data_ptr, stack_trace)

    def _handle_memory_deallocation(self, data_ptr: DataPtr) -> None:
        """Trace callback: memory at ``data_ptr`` was freed."""
        self.tensors_accessed.ensure_tensor_exists(data_ptr)
        self.tensors_accessed.delete_tensor(data_ptr)

    def _handle_stream_creation(self, stream: StreamId) -> None:
        """Trace callback: a stream was created."""
        self.syncs.create_stream(stream)

    def _handle_device_synchronization(self) -> None:
        """Trace callback: the whole device was synchronized."""
        self.syncs.sync_all_streams()

    def _handle_stream_synchronization(self, stream: StreamId) -> None:
        """Trace callback: ``stream`` was synchronized."""
        self.syncs.all_streams_wait_for_stream(stream)

    def _handle_event_synchronization(self, event: EventId) -> None:
        """Trace callback: ``event`` was synchronized."""
        self.syncs.all_streams_wait_for_event(event)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def zip_by_key(a: Dict[TK, TVa], b: Dict[TK, TVb]) -> Iterator[Tuple[TK, TVa, TVb]]:
    """Yield ``(key, a[key], b[key])`` for each key of ``a`` also present in ``b``, in ``a``'s order."""
    yield from ((key, left, b[key]) for key, left in a.items() if key in b)
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def zip_arguments(
    schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterator[Tuple[torch.Argument, Any]]:
    """Pair each passed value with the schema argument it binds to.

    Positional values are matched to the leading schema arguments in order.
    The remaining schema arguments are matched by name against ``kwargs``;
    schema arguments that were not passed are skipped.
    """
    positional_count = len(args)
    positional_params = schema.arguments[:positional_count]
    keyword_params = {
        param.name: param for param in schema.arguments[positional_count:]
    }

    yield from zip(positional_params, args)

    for name, param in keyword_params.items():
        if name in kwargs:
            yield (param, kwargs[name])
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
class ArgumentHandler:
|
| 479 |
+
def __init__(self) -> None:
|
| 480 |
+
self.dataptrs_read: Set[DataPtr] = set()
|
| 481 |
+
self.dataptrs_written: Set[DataPtr] = set()
|
| 482 |
+
self.tensor_aliases: Dict[DataPtr, List[str]] = {}
|
| 483 |
+
self.outputs: Set[DataPtr] = set()
|
| 484 |
+
|
| 485 |
+
def _handle_argument(
|
| 486 |
+
self,
|
| 487 |
+
value: Any,
|
| 488 |
+
is_write: bool,
|
| 489 |
+
name: Optional[str] = None,
|
| 490 |
+
is_output: bool = False,
|
| 491 |
+
) -> None:
|
| 492 |
+
if isinstance(value, torch.Tensor) and value.is_cuda:
|
| 493 |
+
data_ptr = value.data_ptr()
|
| 494 |
+
if is_write:
|
| 495 |
+
self.dataptrs_written.add(data_ptr)
|
| 496 |
+
else:
|
| 497 |
+
self.dataptrs_read.add(data_ptr)
|
| 498 |
+
|
| 499 |
+
self.tensor_aliases.setdefault(data_ptr, [])
|
| 500 |
+
if name is not None:
|
| 501 |
+
self.tensor_aliases[data_ptr].append(name)
|
| 502 |
+
if is_output:
|
| 503 |
+
self.outputs.add(data_ptr)
|
| 504 |
+
|
| 505 |
+
def parse_inputs(
|
| 506 |
+
self,
|
| 507 |
+
schema: torch.FunctionSchema,
|
| 508 |
+
args: Tuple[Any, ...],
|
| 509 |
+
kwargs: Dict[str, Any],
|
| 510 |
+
) -> None:
|
| 511 |
+
for argument, value in zip_arguments(schema, args, kwargs):
|
| 512 |
+
is_write = argument.alias_info is not None and argument.alias_info.is_write
|
| 513 |
+
pytree.tree_map_(
|
| 514 |
+
functools.partial(
|
| 515 |
+
self._handle_argument, is_write=is_write, name=argument.name
|
| 516 |
+
),
|
| 517 |
+
value,
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
def parse_outputs(self, outputs: Any) -> None:
|
| 521 |
+
pytree.tree_map_(
|
| 522 |
+
functools.partial(self._handle_argument, is_write=True, is_output=True),
|
| 523 |
+
outputs,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
class CUDASanitizerDispatchMode(TorchDispatchMode):
    """Dispatch mode that feeds every operator call and GPU trace event to an ``EventHandler``."""

    def __init__(self) -> None:
        self.event_handler = EventHandler()
        torch._C._activate_gpu_trace()

        handler = self.event_handler
        # (registration function, callback) pairs, registered in this order.
        registrations = (
            (
                gpu_trace.register_callback_for_event_creation,
                handler._handle_event_creation,
            ),
            (
                gpu_trace.register_callback_for_event_deletion,
                handler._handle_event_deletion,
            ),
            (
                gpu_trace.register_callback_for_event_record,
                handler._handle_event_record,
            ),
            (
                gpu_trace.register_callback_for_event_wait,
                handler._handle_event_wait,
            ),
            (
                gpu_trace.register_callback_for_memory_allocation,
                handler._handle_memory_allocation,
            ),
            (
                gpu_trace.register_callback_for_memory_deallocation,
                handler._handle_memory_deallocation,
            ),
            (
                gpu_trace.register_callback_for_stream_creation,
                handler._handle_stream_creation,
            ),
            (
                gpu_trace.register_callback_for_device_synchronization,
                handler._handle_device_synchronization,
            ),
            (
                gpu_trace.register_callback_for_stream_synchronization,
                handler._handle_stream_synchronization,
            ),
            (
                gpu_trace.register_callback_for_event_synchronization,
                handler._handle_event_synchronization,
            ),
        )
        for register, callback in registrations:
            register(callback)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func``, then check its tensor accesses for unsynchronized conflicts.

        Any errors found are printed to standard error and then raised together
        as ``CUDASanitizerErrors``; otherwise the outputs of ``func`` are returned.
        """
        kwargs = {} if kwargs is None else kwargs

        argument_handler = ArgumentHandler()
        argument_handler.parse_inputs(func._schema, args, kwargs)

        outputs = func(*args, **kwargs)

        argument_handler.parse_outputs(outputs)
        written = argument_handler.dataptrs_written
        errors = self.event_handler._handle_kernel_launch(
            torch.cuda.current_stream().cuda_stream,
            # Pointers that are both read and written count as written only.
            argument_handler.dataptrs_read - written,
            written,
            argument_handler.outputs,
            func._schema,
            argument_handler.tensor_aliases,
        )
        if not errors:
            return outputs

        for error in errors:
            print(error, file=sys.stderr)
        raise CUDASanitizerErrors(errors)
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
class CUDASanitizer:
    """Manages the lifetime of a CUDASanitizer dispatch mode object.

    The CUDASanitizer class wraps the entering/exiting functions of the dispatch mode
    context manager in the enable function/destructor, respectively. This is to
    explicitly set the lifetime of the dispatch mode object to that of the application.
    This approach was deemed more elegant than using the atexit module.
    """

    def __init__(self) -> None:
        self.dispatch = CUDASanitizerDispatchMode()
        self.enabled = False

    def enable(self):
        """Enter the dispatch mode and mark the sanitizer as enabled."""
        self.dispatch.__enter__()
        self.enabled = True

    def __del__(self):
        """Exit the dispatch mode if it was entered and the interpreter is still alive."""
        # ``enabled`` is missing when __init__ raised before assigning it;
        # treat that the same as "never enabled" instead of raising here.
        if not getattr(self, "enabled", False):
            return
        # During interpreter shutdown module globals (including ``sys``) may
        # already be cleared and the dispatch machinery torn down, so exiting
        # the mode at that point is skipped.
        if sys is None or sys.is_finalizing():
            return
        # Clear the flag first so the mode is exited at most once.
        self.enabled = False
        self.dispatch.__exit__(None, None, None)
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
def enable_cuda_sanitizer():
    """Enable CUDA Sanitizer.

    Once enabled, the sanitizer analyzes the low-level CUDA calls made by torch
    functions for synchronization errors. Every data race it finds is printed to
    the standard error output together with stack traces of the suspected causes.
    For best results, enable the sanitizer at the very beginning of the program.
    """
    sanitizer = cuda_sanitizer
    sanitizer.enable()
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
# Module-level sanitizer instance; enable_cuda_sanitizer() calls enable() on it.
cuda_sanitizer = CUDASanitizer()
|
.venv/lib/python3.11/site-packages/torch/cuda/_utils.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
# The _get_device_index has been moved to torch._utils._get_device_index
|
| 6 |
+
from torch._utils import _get_device_index as _torch_get_device_index
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _get_device_index(
|
| 10 |
+
device: Any, optional: bool = False, allow_cpu: bool = False
|
| 11 |
+
) -> int:
|
| 12 |
+
r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``.
|
| 13 |
+
|
| 14 |
+
If :attr:`device` is a torch.device object, returns the device index if it
|
| 15 |
+
is a CUDA device. Note that for a CUDA device without a specified index,
|
| 16 |
+
i.e., ``torch.device('cuda')``, this will return the current default CUDA
|
| 17 |
+
device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
|
| 18 |
+
CPU devices will be accepted and ``-1`` will be returned in this case.
|
| 19 |
+
|
| 20 |
+
If :attr:`device` is a Python integer, it is returned as is.
|
| 21 |
+
|
| 22 |
+
If :attr:`device` is ``None``, this will return the current default CUDA
|
| 23 |
+
device if :attr:`optional` is ``True``.
|
| 24 |
+
"""
|
| 25 |
+
if isinstance(device, int):
|
| 26 |
+
return device
|
| 27 |
+
if isinstance(device, str):
|
| 28 |
+
device = torch.device(device)
|
| 29 |
+
if isinstance(device, torch.device):
|
| 30 |
+
if allow_cpu:
|
| 31 |
+
if device.type not in ["cuda", "cpu"]:
|
| 32 |
+
raise ValueError(f"Expected a cuda or cpu device, but got: {device}")
|
| 33 |
+
elif device.type != "cuda":
|
| 34 |
+
raise ValueError(f"Expected a cuda device, but got: {device}")
|
| 35 |
+
if not torch.jit.is_scripting():
|
| 36 |
+
if isinstance(device, torch.cuda.device):
|
| 37 |
+
return device.idx
|
| 38 |
+
return _torch_get_device_index(device, optional, allow_cpu)
|
.venv/lib/python3.11/site-packages/torch/cuda/comm.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# The functions here have been moved to torch.nn.parallel.comm
|
| 2 |
+
from torch.nn.parallel.comm import (
|
| 3 |
+
broadcast,
|
| 4 |
+
broadcast_coalesced,
|
| 5 |
+
gather,
|
| 6 |
+
reduce_add,
|
| 7 |
+
reduce_add_coalesced,
|
| 8 |
+
scatter,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"broadcast",
|
| 14 |
+
"broadcast_coalesced",
|
| 15 |
+
"reduce_add",
|
| 16 |
+
"reduce_add_coalesced",
|
| 17 |
+
"scatter",
|
| 18 |
+
"gather",
|
| 19 |
+
]
|
.venv/lib/python3.11/site-packages/torch/cuda/error.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torch/cuda/gds.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from typing import Callable, List, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.types import Storage
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
__all__: List[str] = []
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _dummy_fn(name: str) -> Callable:
|
| 13 |
+
def fn(*args, **kwargs): # type: ignore[no-untyped-def]
|
| 14 |
+
raise RuntimeError(f"torch._C.{name} is not supported on this platform")
|
| 15 |
+
|
| 16 |
+
return fn
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
if not hasattr(torch._C, "_gds_register_buffer"):
    # The bindings are expected to be all present or all absent.
    for _gds_name in (
        "_gds_deregister_buffer",
        "_gds_register_handle",
        "_gds_deregister_handle",
        "_gds_load_storage",
        "_gds_save_storage",
    ):
        assert not hasattr(torch._C, _gds_name)
    # Define functions
    for _gds_name in (
        "_gds_register_buffer",
        "_gds_deregister_buffer",
        "_gds_register_handle",
        "_gds_deregister_handle",
        "_gds_load_storage",
        "_gds_save_storage",
    ):
        torch._C.__dict__[_gds_name] = _dummy_fn(_gds_name)
    # Do not leave the loop variable behind as a module attribute.
    del _gds_name
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _gds_register_buffer(s: Storage) -> None:
    """Register a buffer.

    Forwards to ``torch._C._gds_register_buffer``.

    Args:
        s (Storage): Buffer to register.
    """
    register = torch._C._gds_register_buffer
    register(s)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _gds_deregister_buffer(s: Storage) -> None:
    """Deregister a buffer.

    Forwards to ``torch._C._gds_deregister_buffer``.

    Args:
        s (Storage): Buffer to deregister.
    """
    torch._C._gds_deregister_buffer(s)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class _GdsFile:
|
| 53 |
+
r"""Wrapper around cuFile.
|
| 54 |
+
|
| 55 |
+
cuFile is a file-like interface to the GPUDirect Storage (GDS) API.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
filename (str): Name of the file to open.
|
| 59 |
+
flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will
|
| 60 |
+
be added automatically.
|
| 61 |
+
|
| 62 |
+
.. _CUDA GPUDirect Storage Documentation:
|
| 63 |
+
https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(self, filename: str, flags: int):
|
| 67 |
+
if sys.platform == "win32":
|
| 68 |
+
raise RuntimeError("GdsFile is not supported on this platform.")
|
| 69 |
+
self.filename = filename
|
| 70 |
+
self.flags = flags
|
| 71 |
+
self.fd = os.open(filename, flags | os.O_DIRECT)
|
| 72 |
+
self.handle: Optional[int] = None
|
| 73 |
+
self.register_handle()
|
| 74 |
+
|
| 75 |
+
def __del__(self) -> None:
|
| 76 |
+
if self.handle is not None:
|
| 77 |
+
self.deregister_handle()
|
| 78 |
+
os.close(self.fd)
|
| 79 |
+
|
| 80 |
+
def register_handle(self) -> None:
|
| 81 |
+
"""Registers file descriptor to cuFile Driver.
|
| 82 |
+
|
| 83 |
+
This is a wrapper around ``cuFileHandleRegister``.
|
| 84 |
+
"""
|
| 85 |
+
assert (
|
| 86 |
+
self.handle is None
|
| 87 |
+
), "Cannot register a handle that is already registered."
|
| 88 |
+
self.handle = torch._C._gds_register_handle(self.fd)
|
| 89 |
+
|
| 90 |
+
def deregister_handle(self) -> None:
|
| 91 |
+
"""Deregisters file descriptor from cuFile Driver.
|
| 92 |
+
|
| 93 |
+
This is a wrapper around ``cuFileHandleDeregister``.
|
| 94 |
+
"""
|
| 95 |
+
assert (
|
| 96 |
+
self.handle is not None
|
| 97 |
+
), "Cannot deregister a handle that is not registered."
|
| 98 |
+
torch._C._gds_deregister_handle(self.handle)
|
| 99 |
+
self.handle = None
|
| 100 |
+
|
| 101 |
+
def load_storage(self, storage: Storage, offset: int = 0) -> None:
|
| 102 |
+
"""Loads data from the file into the storage.
|
| 103 |
+
|
| 104 |
+
This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
|
| 105 |
+
will be loaded from the file at ``offset`` into the storage.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
storage (Storage): Storage to load data into.
|
| 109 |
+
offset (int, optional): Offset into the file to start loading from. (Default: 0)
|
| 110 |
+
"""
|
| 111 |
+
assert (
|
| 112 |
+
self.handle is not None
|
| 113 |
+
), "Cannot load data from a file that is not registered."
|
| 114 |
+
torch._C._gds_load_storage(self.handle, storage, offset)
|
| 115 |
+
|
| 116 |
+
def save_storage(self, storage: Storage, offset: int = 0) -> None:
|
| 117 |
+
"""Saves data from the storage into the file.
|
| 118 |
+
|
| 119 |
+
This is a wrapper around ``cuFileWrite``. All bytes of the storage
|
| 120 |
+
will be written to the file at ``offset``.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
storage (Storage): Storage to save data from.
|
| 124 |
+
offset (int, optional): Offset into the file to start saving to. (Default: 0)
|
| 125 |
+
"""
|
| 126 |
+
assert (
|
| 127 |
+
self.handle is not None
|
| 128 |
+
), "Cannot save data to a file that is not registered."
|
| 129 |
+
torch._C._gds_save_storage(self.handle, storage, offset)
|
.venv/lib/python3.11/site-packages/torch/cuda/graphs.py
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import gc
|
| 3 |
+
import typing
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from .._utils import _dummy_type
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
if not hasattr(torch._C, "_CudaStreamBase"):
    # Define dummy base classes
    for _placeholder_name in (
        "_CUDAGraph",
        "_graph_pool_handle",
        "_cuda_isCurrentStreamCapturing",
    ):
        torch._C.__dict__[_placeholder_name] = _dummy_type(_placeholder_name)
    # Do not leave the loop variable behind as a module attribute.
    del _placeholder_name
|
| 17 |
+
|
| 18 |
+
from torch._C import ( # noqa: F401
|
| 19 |
+
_cuda_isCurrentStreamCapturing,
|
| 20 |
+
_CUDAGraph,
|
| 21 |
+
_graph_pool_handle,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def is_current_stream_capturing():
    r"""Return whether CUDA graph capture is underway on the current CUDA stream.

    Returns True while a capture is in progress and False otherwise. If a CUDA
    context does not exist on the current device, returns False without
    initializing the context.
    """
    capturing = _cuda_isCurrentStreamCapturing()
    return capturing
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Python shim helps Sphinx process docstrings more reliably.
|
| 34 |
+
def graph_pool_handle():
    r"""Return an opaque token representing the id of a graph memory pool.

    See :ref:`Graph memory management<graph-memory-management>`.

    .. warning::
        This API is in beta and may change in future releases.
    """
    token = _graph_pool_handle()
    return token
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Python shim helps Sphinx process docstrings more reliably.
|
| 46 |
+
class CUDAGraph(torch._C._CUDAGraph):
    r"""Python-level handle around a CUDA graph.

    Every method forwards to the ``torch._C._CUDAGraph`` base class; the Python
    layer exists so the methods carry docstrings.

    .. warning::
        This API is in beta and may change in future releases.
    """

    def __new__(cls):
        instance = super().__new__(cls)
        return instance

    def capture_begin(self, pool=None, capture_error_mode="global"):
        r"""Start capturing CUDA work on the current stream.

        You normally should not call ``capture_begin`` directly.
        Prefer :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
        both of which call ``capture_begin`` internally.

        Arguments:
            pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
                :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
                with the indicated pool.  See :ref:`Graph memory management<graph-memory-management>`.
            capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
                Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
                may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
                actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
                unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
        """  # noqa: B950
        forwarded = {"pool": pool, "capture_error_mode": capture_error_mode}
        super().capture_begin(**forwarded)

    def capture_end(self):
        r"""Finish CUDA graph capture on the current stream.

        Once ``capture_end`` has run, ``replay`` may be called on this instance.

        You normally should not call ``capture_end`` directly.
        Prefer :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
        both of which call ``capture_end`` internally.
        """
        super().capture_end()

    def replay(self):
        r"""Run again the CUDA work this graph captured."""
        super().replay()

    def reset(self):
        r"""Drop the graph this instance currently holds."""
        super().reset()

    def pool(self):
        r"""Return an opaque token that stands for the id of this graph's memory pool.

        The token may be handed to another graph's ``capture_begin`` as a hint
        that the other graph may share the same memory pool.
        """
        pool_token = super().pool()
        return pool_token

    def enable_debug_mode(self):
        r"""Turn on the debugging mode that CUDAGraph.debug_dump relies on."""
        result = super().enable_debug_mode()
        return result

    def debug_dump(self, debug_path):
        r"""Dump the graph through the debugging function.

        The dump is produced if debugging was enabled via
        CUDAGraph.enable_debug_mode().

        Arguments:
            debug_path (required): Path to dump the graph to.
        """
        result = super().debug_dump(debug_path)
        return result
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class graph:
    r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay.

    See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
    detailed use, and constraints.

    Arguments:
        cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
        pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
            :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting this graph's capture
            may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
        stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
            If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
        capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
            Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
            may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
            actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting
            unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_

    .. note::
        For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
        used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.

    .. warning::
        This API is in beta and may change in future releases.
    """  # noqa: B950

    # Side stream shared by every capture that does not pass an explicit ``stream``;
    # created on first use in ``__init__``.
    default_capture_stream: typing.Optional["torch.cuda.Stream"] = None

    def __init__(
        self,
        cuda_graph,
        pool=None,
        stream=None,
        capture_error_mode: str = "global",
    ):
        # Lazy-init of default_capture_stream helps avoid circular-import errors.
        # Not thread safe, but graphs already have the general (explicitly documented)
        # restriction that only one capture may be underway at a time in the process.
        if self.__class__.default_capture_stream is None:
            self.__class__.default_capture_stream = torch.cuda.Stream()

        # Stored as a 0- or 1-tuple so it can be splatted into capture_begin().
        self.pool = () if pool is None else (pool,)
        self.capture_stream = (
            stream if stream is not None else self.__class__.default_capture_stream
        )
        assert self.capture_stream is not None
        self.stream_ctx = torch.cuda.stream(self.capture_stream)
        self.cuda_graph = cuda_graph
        self.capture_error_mode = capture_error_mode

    def __enter__(self):
        # Free as much memory as we can for the graph
        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()

        # Stackoverflow seems comfortable with this pattern
        # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
        self.stream_ctx.__enter__()

        try:
            self.cuda_graph.capture_begin(
                *self.pool, capture_error_mode=self.capture_error_mode
            )
        except BaseException as exc:
            # Python does not call __exit__ when __enter__ raises, so leave the
            # stream context here; otherwise the capture stream would stay current.
            self.stream_ctx.__exit__(type(exc), exc, exc.__traceback__)
            raise

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            self.cuda_graph.capture_end()
        finally:
            # Always leave the stream context, even if capture_end() raised.
            self.stream_ctx.__exit__(exc_type, exc_value, traceback)
        # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def make_graphed_callables(
|
| 192 |
+
callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
|
| 193 |
+
):
|
| 194 |
+
r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and returns graphed versions.
|
| 195 |
+
|
| 196 |
+
Each graphed callable's forward pass runs its source callable's
|
| 197 |
+
forward CUDA work as a CUDA graph inside a single autograd node.
|
| 198 |
+
|
| 199 |
+
The graphed callable's forward pass also appends
|
| 200 |
+
a backward node to the autograd graph. During backward, this node runs the
|
| 201 |
+
callable's backward work as a CUDA graph.
|
| 202 |
+
|
| 203 |
+
Therefore, each graphed callable should be a drop-in replacement for its source callable
|
| 204 |
+
in an autograd-enabled training loop.
|
| 205 |
+
|
| 206 |
+
See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.
|
| 207 |
+
|
| 208 |
+
If you pass a tuple of several callables, their captures will use the same memory pool.
|
| 209 |
+
See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.
|
| 210 |
+
|
| 211 |
+
Arguments:
|
| 212 |
+
callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
|
| 213 |
+
See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
|
| 214 |
+
is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order
|
| 215 |
+
they'll run in the live workload.
|
| 216 |
+
sample_args (tuple of Tensors, or tuple of tuples of Tensors): Samples args for each callable.
|
| 217 |
+
If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
|
| 218 |
+
If a tuple of callables was passed, ``sample_args`` must be tuple of tuples of argument Tensors.
|
| 219 |
+
num_warmup_iters (int): The number of warmup iterations. Currently, ``DataDistributedParallel`` needs
|
| 220 |
+
11 iterations for warm up. Default: ``3``.
|
| 221 |
+
allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
|
| 222 |
+
(and therefore their grad is always zero) is an error. Defaults to False.
|
| 223 |
+
pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
|
| 224 |
+
:meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
|
| 225 |
+
with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
|
| 226 |
+
.. note::
|
| 227 |
+
The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
|
| 228 |
+
that's expected for the corresponding real input in the training loop.
|
| 229 |
+
|
| 230 |
+
.. warning::
|
| 231 |
+
This API is in beta and may change in future releases.
|
| 232 |
+
|
| 233 |
+
.. warning::
|
| 234 |
+
``sample_args`` for each callable must contain only Tensors. Other types are not allowed.
|
| 235 |
+
|
| 236 |
+
.. warning::
|
| 237 |
+
Returned callables do not support higher order differentiation (e.g., double backward).
|
| 238 |
+
|
| 239 |
+
.. warning::
|
| 240 |
+
In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
|
| 241 |
+
may be trainable. Buffers must have ``requires_grad=False``.
|
| 242 |
+
|
| 243 |
+
.. warning::
|
| 244 |
+
After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
|
| 245 |
+
you may not add or remove any of that Module's parameters or buffers.
|
| 246 |
+
|
| 247 |
+
.. warning::
|
| 248 |
+
:class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
|
| 249 |
+
registered on them at the time they are passed. However, registering hooks on modules *after* passing them
|
| 250 |
+
through :func:`~torch.cuda.make_graphed_callables` is allowed.
|
| 251 |
+
|
| 252 |
+
.. warning::
|
| 253 |
+
When running a graphed callable, you must pass its arguments in the same order and format
|
| 254 |
+
they appeared in that callable's ``sample_args``.
|
| 255 |
+
|
| 256 |
+
.. warning::
|
| 257 |
+
The automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with disabled
|
| 258 |
+
caching. The context manager `torch.cuda.amp.autocast()` must have `cache_enabled=False`.
|
| 259 |
+
"""
|
| 260 |
+
if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
|
| 261 |
+
raise RuntimeError(
|
| 262 |
+
"make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
just_one_callable = False
|
| 266 |
+
|
| 267 |
+
if not isinstance(callables, tuple):
|
| 268 |
+
just_one_callable = True
|
| 269 |
+
callables = (callables,)
|
| 270 |
+
sample_args = (sample_args,)
|
| 271 |
+
|
| 272 |
+
flatten_sample_args = []
|
| 273 |
+
|
| 274 |
+
for c, args in zip(callables, sample_args):
|
| 275 |
+
if isinstance(c, torch.nn.Module):
|
| 276 |
+
assert (
|
| 277 |
+
len(c._backward_hooks) == 0
|
| 278 |
+
and len(c._forward_hooks) == 0
|
| 279 |
+
and len(c._forward_pre_hooks) == 0
|
| 280 |
+
), (
|
| 281 |
+
"Modules must not have hooks registered at the time they are passed. However, registering hooks "
|
| 282 |
+
+ "on modules after passing them through make_graphed_callables is allowed."
|
| 283 |
+
)
|
| 284 |
+
assert all(b.requires_grad is False for b in c.buffers()), (
|
| 285 |
+
"In any :class:`~torch.nn.Module` passed to "
|
| 286 |
+
+ ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have "
|
| 287 |
+
+ "``requires_grad=False``."
|
| 288 |
+
)
|
| 289 |
+
flatten_arg = torch.utils._pytree.arg_tree_leaves(*args)
|
| 290 |
+
flatten_sample_args.append(tuple(flatten_arg))
|
| 291 |
+
assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
|
| 292 |
+
"In the beta API, sample_args "
|
| 293 |
+
+ "for each callable must contain only Tensors. Other types are not allowed."
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
|
| 297 |
+
# passes to forward (ie, its sample_args) AND the module's parameter attributes.
|
| 298 |
+
per_callable_len_user_args = [len(args) for args in flatten_sample_args]
|
| 299 |
+
per_callable_module_params = [
|
| 300 |
+
tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
|
| 301 |
+
for c in callables
|
| 302 |
+
]
|
| 303 |
+
per_callable_static_input_surfaces = [
|
| 304 |
+
flatten_sample_args[i] + per_callable_module_params[i]
|
| 305 |
+
for i in range(len(callables))
|
| 306 |
+
]
|
| 307 |
+
|
| 308 |
+
fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
|
| 309 |
+
bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
|
| 310 |
+
|
| 311 |
+
mempool = graph_pool_handle() if pool is None else pool
|
| 312 |
+
|
| 313 |
+
# Warmup
|
| 314 |
+
# Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
|
| 315 |
+
# from ending up in any captures.
|
| 316 |
+
torch.cuda.synchronize()
|
| 317 |
+
with torch.cuda.stream(torch.cuda.Stream()):
|
| 318 |
+
for func, args, static_input_surface in zip(
|
| 319 |
+
callables, sample_args, per_callable_static_input_surfaces
|
| 320 |
+
):
|
| 321 |
+
grad_inputs, outputs, outputs_grad = None, None, None
|
| 322 |
+
for _ in range(num_warmup_iters):
|
| 323 |
+
outputs = torch.utils._pytree.tree_leaves(func(*args))
|
| 324 |
+
outputs_grad = tuple(o for o in outputs if o.requires_grad)
|
| 325 |
+
if len(outputs_grad) > 0:
|
| 326 |
+
grad_inputs = torch.autograd.grad(
|
| 327 |
+
outputs=outputs_grad,
|
| 328 |
+
inputs=tuple(
|
| 329 |
+
i for i in static_input_surface if i.requires_grad
|
| 330 |
+
),
|
| 331 |
+
grad_outputs=tuple(
|
| 332 |
+
torch.empty_like(o) for o in outputs if o.requires_grad
|
| 333 |
+
),
|
| 334 |
+
only_inputs=True,
|
| 335 |
+
allow_unused=allow_unused_input,
|
| 336 |
+
)
|
| 337 |
+
for v in [outputs, outputs_grad, grad_inputs]:
|
| 338 |
+
del v
|
| 339 |
+
|
| 340 |
+
torch.cuda.synchronize()
|
| 341 |
+
|
| 342 |
+
# All captures here share a mempool. To avoid replays corrupting each other's memory,
|
| 343 |
+
# the safest approach is to capture all passes in the same order they'll run:
|
| 344 |
+
# fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
|
| 345 |
+
|
| 346 |
+
# Capture forward graphs
|
| 347 |
+
per_callable_static_outputs = []
|
| 348 |
+
per_callable_output_unflatten_spec = []
|
| 349 |
+
for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
|
| 350 |
+
with torch.cuda.graph(fwd_graph, pool=mempool):
|
| 351 |
+
outputs = func(*args)
|
| 352 |
+
|
| 353 |
+
flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs)
|
| 354 |
+
per_callable_static_outputs.append(tuple(flatten_outputs))
|
| 355 |
+
per_callable_output_unflatten_spec.append(spec)
|
| 356 |
+
|
| 357 |
+
# Capture backward graphs in reverse order
|
| 358 |
+
per_callable_static_grad_outputs = []
|
| 359 |
+
per_callable_static_grad_inputs = []
|
| 360 |
+
for static_input_surface, static_outputs, bwd_graph, module_params in zip(
|
| 361 |
+
reversed(per_callable_static_input_surfaces),
|
| 362 |
+
reversed(per_callable_static_outputs),
|
| 363 |
+
reversed(bwd_graphs),
|
| 364 |
+
reversed(per_callable_module_params),
|
| 365 |
+
):
|
| 366 |
+
# For now, assumes all static_outputs require grad
|
| 367 |
+
# assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
|
| 368 |
+
static_grad_outputs = tuple(
|
| 369 |
+
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
outputs_grad = tuple(o for o in static_outputs if o.requires_grad)
|
| 373 |
+
grad_inputs = None
|
| 374 |
+
if len(outputs_grad) > 0:
|
| 375 |
+
with torch.cuda.graph(bwd_graph, pool=mempool):
|
| 376 |
+
grad_inputs = torch.autograd.grad(
|
| 377 |
+
outputs=outputs_grad,
|
| 378 |
+
inputs=tuple(i for i in static_input_surface if i.requires_grad),
|
| 379 |
+
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
|
| 380 |
+
only_inputs=True,
|
| 381 |
+
allow_unused=allow_unused_input,
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
# Constructs a tuple suitable for returning from Graphed.backward:
|
| 385 |
+
# Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
|
| 386 |
+
# I couldn't think of a slick one-liner for this pattern.
|
| 387 |
+
static_grad_inputs = []
|
| 388 |
+
grad_idx = 0
|
| 389 |
+
for arg in static_input_surface:
|
| 390 |
+
if arg.requires_grad and grad_inputs is not None:
|
| 391 |
+
static_grad_inputs.append(grad_inputs[grad_idx])
|
| 392 |
+
grad_idx += 1
|
| 393 |
+
else:
|
| 394 |
+
static_grad_inputs.append(None) # type: ignore[arg-type]
|
| 395 |
+
static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment]
|
| 396 |
+
|
| 397 |
+
per_callable_static_grad_outputs.append(static_grad_outputs)
|
| 398 |
+
per_callable_static_grad_inputs.append(static_grad_inputs)
|
| 399 |
+
|
| 400 |
+
# Reverses the most recent two lists
|
| 401 |
+
per_callable_static_grad_outputs.reverse()
|
| 402 |
+
per_callable_static_grad_inputs.reverse()
|
| 403 |
+
# Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
|
| 404 |
+
|
| 405 |
+
def make_graphed_autograd_function(
|
| 406 |
+
fwd_graph,
|
| 407 |
+
bwd_graph,
|
| 408 |
+
module_params,
|
| 409 |
+
len_user_args,
|
| 410 |
+
output_unflatten_spec,
|
| 411 |
+
static_input_surface,
|
| 412 |
+
static_outputs,
|
| 413 |
+
static_grad_outputs,
|
| 414 |
+
static_grad_inputs,
|
| 415 |
+
):
|
| 416 |
+
class Graphed(torch.autograd.Function):
|
| 417 |
+
@staticmethod
|
| 418 |
+
def forward(ctx, *inputs):
|
| 419 |
+
# At this stage, only the user args may (potentially) be new tensors.
|
| 420 |
+
for i in range(len_user_args):
|
| 421 |
+
if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
|
| 422 |
+
static_input_surface[i].copy_(inputs[i])
|
| 423 |
+
fwd_graph.replay()
|
| 424 |
+
assert isinstance(static_outputs, tuple)
|
| 425 |
+
return tuple(o.detach() for o in static_outputs)
|
| 426 |
+
|
| 427 |
+
@staticmethod
|
| 428 |
+
@torch.autograd.function.once_differentiable
|
| 429 |
+
def backward(ctx, *grads):
|
| 430 |
+
assert len(grads) == len(static_grad_outputs)
|
| 431 |
+
for g, grad in zip(static_grad_outputs, grads):
|
| 432 |
+
if g is not None:
|
| 433 |
+
# don't copy if autograd gods have been kind and the
|
| 434 |
+
# incoming grad is already in the right place
|
| 435 |
+
if g.data_ptr() != grad.data_ptr():
|
| 436 |
+
g.copy_(grad)
|
| 437 |
+
bwd_graph.replay()
|
| 438 |
+
|
| 439 |
+
# Input args that didn't require grad expect a None gradient.
|
| 440 |
+
assert isinstance(static_grad_inputs, tuple)
|
| 441 |
+
return tuple(
|
| 442 |
+
b.detach() if b is not None else b for b in static_grad_inputs
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
def functionalized(*user_args):
|
| 446 |
+
# Runs the autograd function with inputs == all inputs to the graph that might require grad
|
| 447 |
+
# (explicit user args + module parameters)
|
| 448 |
+
# Assumes module params didn't change since capture.
|
| 449 |
+
flatten_user_args = torch.utils._pytree.arg_tree_leaves(*user_args)
|
| 450 |
+
out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
|
| 451 |
+
return torch.utils._pytree.tree_unflatten(out, output_unflatten_spec)
|
| 452 |
+
|
| 453 |
+
return functionalized
|
| 454 |
+
|
| 455 |
+
# Put together the final graphed callables
|
| 456 |
+
ret = []
|
| 457 |
+
for i, func in enumerate(callables):
|
| 458 |
+
graphed = make_graphed_autograd_function(
|
| 459 |
+
fwd_graphs[i],
|
| 460 |
+
bwd_graphs[i],
|
| 461 |
+
per_callable_module_params[i],
|
| 462 |
+
per_callable_len_user_args[i],
|
| 463 |
+
per_callable_output_unflatten_spec[i],
|
| 464 |
+
per_callable_static_input_surfaces[i],
|
| 465 |
+
per_callable_static_outputs[i],
|
| 466 |
+
per_callable_static_grad_outputs[i],
|
| 467 |
+
per_callable_static_grad_inputs[i],
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
if isinstance(func, torch.nn.Module):
|
| 471 |
+
|
| 472 |
+
def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
|
| 473 |
+
def new_fwd(*user_args):
|
| 474 |
+
# If the module's training-or-eval state matches what we graphed,
|
| 475 |
+
# run the graph, otherwise run the original forward method
|
| 476 |
+
if func.training == graph_training_state:
|
| 477 |
+
return graphed(*user_args)
|
| 478 |
+
else:
|
| 479 |
+
return orig_fwd(*user_args)
|
| 480 |
+
|
| 481 |
+
return new_fwd
|
| 482 |
+
|
| 483 |
+
func.forward = make_graphed_forward(func, func.training, graphed, func.forward) # type: ignore[assignment]
|
| 484 |
+
ret.append(func)
|
| 485 |
+
else:
|
| 486 |
+
ret.append(graphed)
|
| 487 |
+
|
| 488 |
+
if just_one_callable:
|
| 489 |
+
return ret[0]
|
| 490 |
+
|
| 491 |
+
return tuple(ret)
|
.venv/lib/python3.11/site-packages/torch/cuda/jiterator.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import re
|
| 3 |
+
from typing import Callable, List
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
__all__: List[str] = []
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class _CodeParser:
    """Split a templated function definition held in a string into named parts.

    On a successful match the instance carries ``template_params``,
    ``return_type``, ``function_name``, ``function_params`` and
    ``function_body``, each holding the text of the like-named regex group.

    Raises:
        Exception: if ``code_string`` does not match the expected
            ``template <...> ret name(...) {...}`` layout.
    """

    def __init__(self, code_string: str):
        maybe_space = r"\s*"
        some_space = r"\s+"

        # Fragments in the order they must appear in the code string.
        fragments = [
            maybe_space,
            "template",
            maybe_space,
            r"(?P<template_params>\<.+\>)",
            maybe_space,
            r"(?P<return_type>\w+)",
            some_space,
            r"(?P<function_name>\w+)",
            maybe_space,
            r"(?P<function_params>\(.+\))",
            maybe_space,
            r"(?P<function_body>\{.+\})",
            maybe_space,
        ]

        # DOTALL so that "." also spans newlines in multi-line code strings.
        matched = re.match("".join(fragments), code_string, re.DOTALL)
        if matched is None:
            raise Exception(  # noqa: TRY002
                f"Couldn't parse code, please check correctness:\n {code_string}"
            )

        # The pattern's named groups are exactly the attributes to expose.
        for part_name, part_text in matched.groupdict().items():
            setattr(self, part_name, part_text)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class _JittedFunction:
    """Callable that launches a jiterator kernel described by a code string.

    Keyword arguments given at construction become the defaults for the
    kernel's extra arguments; a call may override them but may not introduce
    names that were not declared at construction.
    """

    def __init__(
        self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
    ):
        self.code_string = code_string

        assert (
            return_by_ref or num_outputs == 1
        ), "Return by value only works for single output. "
        self.return_by_ref = return_by_ref
        self.num_outputs = num_outputs

        # The kernel is addressed by the function name found in the code string.
        self.kernel_name = _CodeParser(code_string).function_name

        self.kwargs_dict = kwargs
        self.is_cuda_available = torch.cuda.is_available()

    def __call__(self, *tensors: Tensor, **kwargs):
        # Jiterator follow torch.cuda's lazy initialization behavior
        # Defer checking cuda's availability at the function invocation time
        assert (
            self.is_cuda_available
        ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available."

        assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."

        # Reject any override that was not declared at construction time.
        for name in kwargs:
            if name not in self.kwargs_dict:
                raise KeyError(f"{name} is not declared in function definition")

        # Construction-time defaults, overridden by the call-time values.
        launch_kwargs = {**self.kwargs_dict, **kwargs}

        return torch._C._cuda_jiterator_compile_and_launch_kernel(
            self.code_string,
            self.kernel_name,
            self.return_by_ref,
            self.num_outputs,
            tensors,
            launch_kwargs,
        )
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _create_jit_fn(code_string: str, **kwargs) -> Callable:
    """
    Build a jiterator-generated cuda kernel for an elementwise op with a single output.

    The code string has to be a valid CUDA function that describes the computation for a single element. The code
    string has to follow the c++ template pattern, as shown in the example below. This function will be inlined
    into elementwise kernel template, and compiled on the fly. Compiled kernel will be cached in memory, as well as
    local temp dir.

    Jiterator-generated kernels accepts noncontiguous tensors, and supports broadcasting and type promotion.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value.
        kwargs (Dict, optional): Keyword arguments for generated function

    Returns:
        A callable wrapping the kernel, configured to return by value with exactly one output.

    Example::

        code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
        jitted_fn = create_jit_fn(code_string, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    code_string also allows multiple function definitions, and the last function will be treated as the entry function.

    Example::

        code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
        code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
        jitted_fn = create_jit_fn(code_string, val=0.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b)  # using default val=0.0

    Jiterator can be used together with python registration to override an operator's cuda kernel.
    Following example is overriding gelu's cuda kernel with relu.

    Example::

        code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
        my_gelu = create_jit_fn(code_string)
        my_lib = torch.library.Library("aten", "IMPL")
        my_lib.impl('aten::gelu', my_gelu, "CUDA")
        # torch.nn.GELU and torch.nn.function.gelu are now overridden
        a = torch.rand(3, device='cuda')
        torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 1 output

    .. warning::
        All input tensors must live in CUDA device
    """
    # By-value kernels always produce exactly one output.
    single_output_by_value = {"return_by_ref": False, "num_outputs": 1}
    return _JittedFunction(code_string, **single_output_by_value, **kwargs)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _create_multi_output_jit_fn(
    code_string: str, num_outputs: int, **kwargs
) -> Callable:
    """
    Build a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return value by reference.
        num_outputs(int): number of outputs return by the kernel
        kwargs (Dict, optional): Keyword arguments for generated function

    Returns:
        A callable wrapping the kernel, configured to return by reference with ``num_outputs`` outputs.

    Example::

        code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
        jitted_fn = _create_multi_output_jit_fn(code_string, num_outputs=1, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 8 outputs
    """
    by_reference = {"return_by_ref": True, "num_outputs": num_outputs}
    return _JittedFunction(code_string, **by_reference, **kwargs)
|
.venv/lib/python3.11/site-packages/torch/cuda/memory.py
ADDED
|
@@ -0,0 +1,1041 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
r"""This package adds support for device memory management implemented in CUDA."""
|
| 3 |
+
|
| 4 |
+
import collections
|
| 5 |
+
import contextlib
|
| 6 |
+
import ctypes
|
| 7 |
+
import pickle
|
| 8 |
+
import sys
|
| 9 |
+
import warnings
|
| 10 |
+
from inspect import signature
|
| 11 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 12 |
+
from typing_extensions import deprecated
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch import _C
|
| 16 |
+
from torch._utils import _dummy_type
|
| 17 |
+
from torch.types import Device
|
| 18 |
+
|
| 19 |
+
from . import (
|
| 20 |
+
_get_amdsmi_device_index,
|
| 21 |
+
_get_device_index,
|
| 22 |
+
_get_nvml_device_index,
|
| 23 |
+
_lazy_init,
|
| 24 |
+
is_initialized,
|
| 25 |
+
)
|
| 26 |
+
from ._memory_viz import memory as _memory, segments as _segments
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
__all__ = [
|
| 30 |
+
"caching_allocator_alloc",
|
| 31 |
+
"caching_allocator_delete",
|
| 32 |
+
"set_per_process_memory_fraction",
|
| 33 |
+
"empty_cache",
|
| 34 |
+
"memory_stats",
|
| 35 |
+
"memory_stats_as_nested_dict",
|
| 36 |
+
"reset_accumulated_memory_stats",
|
| 37 |
+
"reset_peak_memory_stats",
|
| 38 |
+
"reset_max_memory_allocated",
|
| 39 |
+
"reset_max_memory_cached",
|
| 40 |
+
"memory_allocated",
|
| 41 |
+
"max_memory_allocated",
|
| 42 |
+
"memory_reserved",
|
| 43 |
+
"max_memory_reserved",
|
| 44 |
+
"memory_cached",
|
| 45 |
+
"max_memory_cached",
|
| 46 |
+
"memory_snapshot",
|
| 47 |
+
"memory_summary",
|
| 48 |
+
"list_gpu_processes",
|
| 49 |
+
"mem_get_info",
|
| 50 |
+
"get_allocator_backend",
|
| 51 |
+
"CUDAPluggableAllocator",
|
| 52 |
+
"change_current_allocator",
|
| 53 |
+
"MemPool",
|
| 54 |
+
"MemPoolContext",
|
| 55 |
+
"use_mem_pool",
|
| 56 |
+
]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
if not hasattr(torch._C, "_cuda_CUDAAllocator"):
|
| 60 |
+
# Define dummy base classes
|
| 61 |
+
torch._C.__dict__["_cuda_CUDAAllocator"] = _dummy_type("_cuda_CUDAAllocator")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if not hasattr(torch._C, "_MemPool"):
|
| 65 |
+
# Define dummy base classes
|
| 66 |
+
torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool")
|
| 67 |
+
torch._C.__dict__["_MemPoolContext"] = _dummy_type("_MemPoolContext")
|
| 68 |
+
torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
|
| 69 |
+
"_cuda_beginAllocateToPool"
|
| 70 |
+
)
|
| 71 |
+
torch._C.__dict__["_cuda_endAllocateCurrentStreamToPool"] = _dummy_type(
|
| 72 |
+
"_cuda_endAllocateCurrentStreamToPool"
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
from torch._C import ( # noqa: F401
|
| 76 |
+
_cuda_beginAllocateToPool,
|
| 77 |
+
_cuda_CUDAAllocator,
|
| 78 |
+
_cuda_endAllocateCurrentStreamToPool,
|
| 79 |
+
_MemPool,
|
| 80 |
+
_MemPoolContext,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _host_allocator():
    """Return the object produced by ``torch._C._cuda_cudaHostAllocator()``.

    ``_lazy_init()`` is called first so CUDA state is set up before the
    binding is queried.
    """
    _lazy_init()
    host_allocator = torch._C._cuda_cudaHostAllocator()
    return host_allocator
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@contextlib.contextmanager
def _free_mutex():
    """Context manager that holds the mutex exposed through ``torch._C``.

    ``torch._C._cuda_lock_mutex()`` is called on entry and
    ``torch._C._cuda_unlock_mutex()`` on exit; the unlock runs even when the
    body of the ``with`` statement raises.
    """
    bindings = torch._C
    bindings._cuda_lock_mutex()
    try:
        yield
    finally:
        # Always release, otherwise an exception in the body would leave the
        # mutex locked.
        bindings._cuda_unlock_mutex()
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def caching_allocator_alloc(size, device: Union[Device, int] = None, stream=None):
    r"""Allocate memory through the CUDA caching allocator.

    The allocation is made for a specific device and stream. This function is
    meant for interoperability with other frameworks; memory obtained here is
    released with :func:`~torch.cuda.caching_allocator_delete`.

    Args:
        size (int): number of bytes to be allocated.
        device (torch.device or int, optional): selected device. If it is
            ``None`` the default CUDA device is used.
        stream (torch.cuda.Stream or int, optional): selected stream. If it is
            ``None`` then the default stream for the selected device is used.

    Returns:
        The value returned by the allocator's raw-alloc binding, to be passed
        to :func:`~torch.cuda.caching_allocator_delete` later.

    Raises:
        TypeError: if ``stream`` is neither a ``torch.cuda.Stream`` nor an ``int``.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    device_index = _get_device_index(
        torch.cuda.current_device() if device is None else device
    )

    raw_stream = torch.cuda.current_stream(device_index) if stream is None else stream
    if isinstance(raw_stream, torch.cuda.streams.Stream):
        # The binding expects the raw stream handle, not the Python wrapper.
        raw_stream = raw_stream.cuda_stream
    if not isinstance(raw_stream, int):
        raise TypeError(
            "Invalid type for stream argument, must be "
            "`torch.cuda.Stream` or `int` representing a pointer "
            "to a existing stream"
        )

    # Allocate with the requested device selected as the current one.
    with torch.cuda.device(device_index):
        return torch._C._cuda_cudaCachingAllocator_raw_alloc(size, raw_stream)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def caching_allocator_delete(mem_ptr):
    r"""Free memory that was allocated through the CUDA caching allocator.

    This is the counterpart of :func:`~torch.cuda.caching_allocator_alloc`.
    The device and stream associated with the allocation are tracked inside
    the allocator, so only the address has to be supplied.

    Args:
        mem_ptr (int): memory address to be freed by the allocator.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    raw_delete = torch._C._cuda_cudaCachingAllocator_raw_delete
    raw_delete(mem_ptr)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def set_per_process_memory_fraction(
    fraction, device: Union[Device, int] = None
) -> None:
    r"""Limit the share of a CUDA device's memory this process may allocate.

    The caching allocator is restricted to ``total visible memory * fraction``
    on the selected device. Attempting to allocate beyond that limit raises an
    out of memory error in the allocator.

    Args:
        fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction.
        device (torch.device or int, optional): selected device. If it is
            ``None`` the default CUDA device is used.

    Raises:
        TypeError: if ``fraction`` is not a ``float``.
        ValueError: if ``fraction`` is below 0 or above 1.

    .. note::
        In general, the total available free memory is less than the total capacity.
    """
    _lazy_init()
    device_index = _get_device_index(
        torch.cuda.current_device() if device is None else device
    )

    if not isinstance(fraction, float):
        raise TypeError("Invalid type for fraction argument, must be `float`")
    # Written as two comparisons on purpose: a NaN satisfies neither and is
    # therefore passed through, exactly as before.
    if fraction > 1 or fraction < 0:
        raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~1")

    torch._C._cuda_setMemoryFraction(fraction, device_index)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def empty_cache() -> None:
    r"""Release all unoccupied cached memory held by the caching allocator.

    The released memory becomes usable by other GPU applications and visible
    in `nvidia-smi`. Nothing happens if CUDA has not been initialized yet.

    .. note::
        :func:`~torch.cuda.empty_cache` doesn't increase the amount of GPU
        memory available for PyTorch. However, it may help reduce fragmentation
        of GPU memory in certain cases. See :ref:`cuda-memory-management` for
        more details about GPU memory management.
    """
    if not is_initialized():
        return
    torch._C._cuda_emptyCache()
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
    r"""Return a dictionary of CUDA memory allocator statistics for a given device.

    The result is a flat, key-sorted dictionary obtained by flattening
    :func:`~torch.cuda.memory_stats_as_nested_dict`: nested keys are joined
    with ``"."``. Each statistic is a non-negative integer.

    Core statistics:

    - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of allocation requests received by the memory allocator.
    - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of allocated memory.
    - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of reserved segments from ``cudaMalloc()``.
    - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of reserved memory.
    - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of active memory blocks.
    - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of active memory.
    - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of inactive, non-releasable memory blocks.
    - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of inactive, non-releasable memory.

    For these core statistics, values are broken down as follows.

    Pool type:

    - ``all``: combined statistics across all memory pools.
    - ``large_pool``: statistics for the large allocation pool
      (as of October 2019, for size >= 1MB allocations).
    - ``small_pool``: statistics for the small allocation pool
      (as of October 2019, for size < 1MB allocations).

    Metric type:

    - ``current``: current value of this metric.
    - ``peak``: maximum value of this metric.
    - ``allocated``: historical total increase in this metric.
    - ``freed``: historical total decrease in this metric.

    In addition to the core statistics, some simple event counters are
    provided:

    - ``"num_alloc_retries"``: number of failed ``cudaMalloc`` calls that
      result in a cache flush and retry.
    - ``"num_ooms"``: number of out-of-memory errors thrown.
    - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls.
    - ``"num_device_alloc"``: number of CUDA allocation calls. This includes both
      cuMemMap and cudaMalloc.
    - ``"num_device_free"``: number of CUDA free calls. This includes both cuMemUnmap
      and cudaFree.

    The caching allocator can be configured via ENV to not split blocks larger than a
    defined size (see Memory Management section of the Cuda Semantics documentation).
    This helps avoid memory fragmentation but may have a performance
    penalty. Additional outputs to assist with tuning and evaluating impact:

    - ``"max_split_size"``: blocks above this size will not be split.
    - ``"oversize_allocations.{current,peak,allocated,freed}"``:
      number of over-size allocation requests received by the memory allocator.
    - ``"oversize_segments.{current,peak,allocated,freed}"``:
      number of over-size reserved segments from ``cudaMalloc()``.

    The caching allocator can be configured via ENV to round memory allocations in order
    to reduce fragmentation. Sometimes the overhead from rounding can be higher than
    the fragmentation it helps reduce. The following stat can be used to check if
    rounding adds too much overhead:

    - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      memory requested by client code, compare this with allocated_bytes to check if
      allocation rounding adds too much overhead.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistics for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.

    .. note::
        With :ref:`backend:cudaMallocAsync<cuda-memory-envvars>`, some stats are not
        meaningful, and are always reported as zero.
    """

    def _flatten(key, value):
        # Leaves are emitted as (dotted_key, value); dicts are descended into.
        if not isinstance(value, dict):
            yield key, value
            return
        # Only put a separator in front of child keys below the top level.
        stem = key + "." if key else key
        for child_key, child_value in value.items():
            yield from _flatten(stem + child_key, child_value)

    nested_stats = memory_stats_as_nested_dict(device=device)
    return collections.OrderedDict(sorted(_flatten("", nested_stats)))
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]:
    r"""Return the result of :func:`~torch.cuda.memory_stats` as a nested dictionary.

    An empty dictionary is returned when CUDA has not been initialized; in
    that case the allocator is not queried at all.
    """
    if is_initialized():
        device_index = _get_device_index(device, optional=True)
        return torch._C._cuda_memoryStats(device_index)
    return {}
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def reset_accumulated_memory_stats(device: Union[Device, int] = None) -> None:
    r"""Reset the "accumulated" (historical) stats tracked by the CUDA memory allocator.

    See :func:`~torch.cuda.memory_stats` for details. Accumulated stats correspond to
    the `"allocated"` and `"freed"` keys in each individual stat dict, as well as
    `"num_alloc_retries"` and `"num_ooms"`.

    Args:
        device (torch.device or int, optional): selected device. The stats of
            the current device, given by :func:`~torch.cuda.current_device`,
            are reset if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    device_index = _get_device_index(device, optional=True)
    reset_result = torch._C._cuda_resetAccumulatedMemoryStats(device_index)
    return reset_result
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def reset_peak_memory_stats(device: Union[Device, int] = None) -> None:
    r"""Reset the "peak" stats tracked by the CUDA memory allocator.

    See :func:`~torch.cuda.memory_stats` for details. Peak stats correspond to the
    `"peak"` key in each individual stat dict.

    Args:
        device (torch.device or int, optional): selected device. The stats of
            the current device, given by :func:`~torch.cuda.current_device`,
            are reset if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    device_index = _get_device_index(device, optional=True)
    reset_result = torch._C._cuda_resetPeakMemoryStats(device_index)
    return reset_result
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def reset_max_memory_allocated(device: Union[Device, int] = None) -> None:
    r"""Reset the starting point in tracking maximum GPU memory occupied by tensors for a given device.

    See :func:`~torch.cuda.max_memory_allocated` for details.

    Args:
        device (torch.device or int, optional): selected device. The current
            device, given by :func:`~torch.cuda.current_device`, is used
            if :attr:`device` is ``None`` (default).

    .. warning::
        This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
        /all/ peak memory stats. A ``FutureWarning`` saying so is emitted on every call.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    notice = (
        "torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, "
        "which resets /all/ peak memory stats."
    )
    warnings.warn(notice, FutureWarning)
    return reset_peak_memory_stats(device=device)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def reset_max_memory_cached(device: Union[Device, int] = None) -> None:
    r"""Reset the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.

    See :func:`~torch.cuda.max_memory_cached` for details.

    Args:
        device (torch.device or int, optional): selected device. The current
            device, given by :func:`~torch.cuda.current_device`, is used
            if :attr:`device` is ``None`` (default).

    .. warning::
        This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
        /all/ peak memory stats. A ``FutureWarning`` saying so is emitted on every call.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    notice = (
        "torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, "
        "which resets /all/ peak memory stats."
    )
    warnings.warn(notice, FutureWarning)
    return reset_peak_memory_stats(device=device)
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def memory_allocated(device: Union[Device, int] = None) -> int:
    r"""Return the current GPU memory occupied by tensors in bytes for a given device.

    The value is the ``"allocated_bytes.all.current"`` entry of
    :func:`~torch.cuda.memory_stats`, or ``0`` when that entry is missing.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        This is likely less than the amount shown in `nvidia-smi` since some
        unused memory can be held by the caching allocator and some context
        needs to be created on GPU. See :ref:`cuda-memory-management` for more
        details about GPU memory management.
    """
    device_stats = memory_stats(device=device)
    return device_stats.get("allocated_bytes.all.current", 0)
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def max_memory_allocated(device: Union[Device, int] = None) -> int:
    r"""Return the maximum GPU memory occupied by tensors in bytes for a given device.

    By default, this is the peak allocated memory since the beginning of
    this program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to
    reset the starting point in tracking this metric. For example, these two
    functions can measure the peak allocated memory usage of each iteration in a
    training loop.

    The value is the ``"allocated_bytes.all.peak"`` entry of
    :func:`~torch.cuda.memory_stats`, or ``0`` when that entry is missing.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    device_stats = memory_stats(device=device)
    return device_stats.get("allocated_bytes.all.peak", 0)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def memory_reserved(device: Union[Device, int] = None) -> int:
    r"""Return the current GPU memory managed by the caching allocator in bytes for a given device.

    The value is the ``"reserved_bytes.all.current"`` entry of
    :func:`~torch.cuda.memory_stats`, or ``0`` when that entry is missing.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    device_stats = memory_stats(device=device)
    return device_stats.get("reserved_bytes.all.current", 0)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def max_memory_reserved(device: Union[Device, int] = None) -> int:
    r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device.

    By default, this is the peak cached memory since the beginning of this
    program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to reset
    the starting point in tracking this metric. For example, these two functions
    can measure the peak cached memory amount of each iteration in a training
    loop.

    The value is the ``"reserved_bytes.all.peak"`` entry of
    :func:`~torch.cuda.memory_stats`, or ``0`` when that entry is missing.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    device_stats = memory_stats(device=device)
    return device_stats.get("reserved_bytes.all.peak", 0)
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
@deprecated(
    "`torch.cuda.memory_cached` has been renamed to `torch.cuda.memory_reserved`",
    category=FutureWarning,
)
def memory_cached(device: Union[Device, int] = None) -> int:
    r"""Deprecated; see :func:`~torch.cuda.memory_reserved`.

    Forwards ``device`` to :func:`~torch.cuda.memory_reserved` and returns its
    result unchanged.
    """
    reserved_bytes = memory_reserved(device=device)
    return reserved_bytes
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
@deprecated(
    "`torch.cuda.max_memory_cached` has been renamed to `torch.cuda.max_memory_reserved`",
    category=FutureWarning,
)
def max_memory_cached(device: Union[Device, int] = None) -> int:
    r"""Deprecated; see :func:`~torch.cuda.max_memory_reserved`.

    Forwards ``device`` to :func:`~torch.cuda.max_memory_reserved` and returns
    its result unchanged.
    """
    peak_reserved_bytes = max_memory_reserved(device=device)
    return peak_reserved_bytes
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def memory_snapshot():
    r"""Return a snapshot of the CUDA memory allocator state across all devices.

    Only the ``"segments"`` entry of the allocator's snapshot is returned.
    Interpreting the output of this function requires familiarity with the
    memory allocator internals.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    allocator_snapshot = torch._C._cuda_memorySnapshot()
    return allocator_snapshot["segments"]
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) -> str:
    r"""Return a human-readable table of the memory allocator statistics for a device.

    This can be useful to display periodically during training, or when
    handling out-of-memory exceptions.

    Args:
        device (torch.device or int, optional): selected device. Returns
            printout for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).
        abbreviated (bool, optional): when ``True`` only the aggregate ("all")
            row of each pooled metric is printed; the large-pool / small-pool
            breakdown is omitted (default: False).

    Returns:
        str: the rendered table, every row wrapped in ``|`` and terminated by a newline.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    device = _get_device_index(device, optional=True)
    stats = memory_stats(device=device)

    def _scale(value, reference, units, step, cutoff):
        # Move up one unit at a time for as long as the *reference* value is
        # still at or above the cutoff. Rows of one metric share a reference,
        # so they all end up in the same unit.
        unit = units[0]
        for larger_unit in units[1:]:
            if reference < cutoff:
                break
            unit = larger_unit
            value //= step
            reference /= step
        return value, unit

    def _format_size(sz, pref_sz):
        scaled, unit = _scale(
            sz, pref_sz, ["B ", "KiB", "MiB", "GiB", "TiB", "PiB"], 1024, 768 * 1024
        )
        return f"{scaled:6d} {unit}"

    def _format_count(cnt, pref_cnt):
        scaled, unit = _scale(cnt, pref_cnt, [" ", "K", "M"], 1000, 750 * 1000)
        return f"{scaled:7d} {unit} "

    def _row(label, formatter, values, references):
        # One table row: the label followed by the four formatted statistics.
        cells = [f"{label:<21}"]
        cells.extend(formatter(v, ref) for v, ref in zip(values, references))
        return " " + " | ".join(cells) + " "

    heavy_rule = "=" * 75
    light_rule = "-" * 75
    stat_fields = ("current", "peak", "allocated", "freed")

    # The brace placeholders in the header rows are filled by the single
    # ``str.format`` call at the very end.
    lines = [
        heavy_rule,
        " {_:16} PyTorch CUDA memory summary, device ID {device:<17d} ",
        light_rule,
        " {_:9} CUDA OOMs: {num_ooms:<12d} | {_:6} cudaMalloc retries: {num_alloc_retries:<8d} ",
        heavy_rule,
        " Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed ",
    ]

    pooled_metrics = [
        ("allocated_bytes", "Allocated memory", _format_size),
        ("active_bytes", "Active memory", _format_size),
        ("requested_bytes", "Requested memory", _format_size),
        ("reserved_bytes", "GPU reserved memory", _format_size),
        ("inactive_split_bytes", "Non-releasable memory", _format_size),
        ("allocation", "Allocations", _format_count),
        ("active", "Active allocs", _format_count),
        ("segment", "GPU reserved segments", _format_count),
        ("inactive_split", "Non-releasable allocs", _format_count),
    ]

    for metric_key, metric_name, formatter in pooled_metrics:
        lines.append(light_rule)
        submetrics = [("all", metric_name)]
        if not abbreviated:
            submetrics += [
                ("large_pool", " from large pool"),
                ("small_pool", " from small pool"),
            ]

        # The first ("all") row supplies the reference values that pick the
        # display unit for every row of this metric.
        references = None
        for submetric_key, submetric_name in submetrics:
            values = [
                stats[metric_key + "." + submetric_key + "." + field]
                for field in stat_fields
            ]
            if references is None:
                references = values
            lines.append(_row(submetric_name, formatter, values, references))

    oversize_metrics = [
        ("oversize_allocations", "Oversize allocations", _format_count),
        ("oversize_segments", "Oversize GPU segments", _format_count),
    ]

    for metric_key, metric_name, formatter in oversize_metrics:
        lines.append(light_rule)
        values = [stats[metric_key + "." + field] for field in stat_fields]
        # These metrics have no pool breakdown, so each value is its own reference.
        lines.append(_row(metric_name, formatter, values, values))

    lines.append(heavy_rule)

    fmt_dict = {"_": "", "device": device}
    fmt_dict.update((key.replace(".", "-"), value) for key, value in stats.items())
    table = "|\n|".join(lines).format(**fmt_dict)
    return "|" + table + "|\n"
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
def list_gpu_processes(device: Union[Device, int] = None) -> str:
    r"""Return a human-readable printout of the running processes and their GPU memory use.

    This can be useful to display periodically during training, or when
    handling out-of-memory exceptions.

    When the required management library (``pynvml`` on CUDA builds,
    ``amdsmi`` on HIP builds) is missing or cannot be initialised, a short
    explanatory message is returned instead of the process list.

    Args:
        device (torch.device or int, optional): selected device. Returns
            printout for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    bytes_per_mb = 1024 * 1024
    using_hip = torch.version.hip

    if using_hip:
        try:
            import amdsmi  # type: ignore[import]
        except ModuleNotFoundError:
            return "amdsmi module not found, please install amdsmi"
        try:
            amdsmi.amdsmi_init()  # type: ignore[attr-defined]
        except amdsmi.AmdSmiException:  # type: ignore[attr-defined]
            return "amdsmi driver can't be loaded, is ROCm installed?"

        device = _get_amdsmi_device_index(device)

        try:
            handle = amdsmi.amdsmi_get_processor_handles()[device]  # type: ignore[attr-defined]
            procs = amdsmi.amdsmi_get_gpu_process_list(handle)  # type: ignore[attr-defined]
        except amdsmi.AmdSmiException:  # type: ignore[attr-defined]
            return "amdsmi cannot list processes from other users"
    else:
        try:
            import pynvml  # type: ignore[import]
        except ModuleNotFoundError:
            return "pynvml module not found, please install pynvml"
        from pynvml import NVMLError_DriverNotLoaded

        try:
            pynvml.nvmlInit()
        except NVMLError_DriverNotLoaded:
            return "cuda driver can't be loaded, is cuda enabled?"

        device = _get_nvml_device_index(device)
        handle = pynvml.nvmlDeviceGetHandleByIndex(device)
        procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)

    report = [f"GPU:{device}"]
    if len(procs) == 0:
        report.append("no processes are running")
    for p in procs:
        if using_hip:
            try:
                proc_info = amdsmi.amdsmi_get_gpu_process_info(handle, p)  # type: ignore[possibly-undefined]
            except AttributeError:
                # https://github.com/ROCm/amdsmi/commit/c551c3caedbd903ba828e7fdffa5b56d475a15e7
                # is a BC-breaking change that removes amdsmi_get_gpu_process_info API from amdsmi;
                # in that case the list entry itself is used as the process info.
                proc_info = p
            mem = proc_info["memory_usage"]["vram_mem"] / bytes_per_mb
            pid = proc_info["pid"]
        else:
            mem = p.usedGpuMemory / bytes_per_mb
            pid = p.pid
        report.append(f"process {pid:>10d} uses {mem:>12.3f} MB GPU memory")
    return "\n".join(report)
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
    r"""Return the global free and total GPU memory for a given device using cudaMemGetInfo.

    Args:
        device (torch.device or int or str, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default) or if the device index is not specified.

    .. note::
        See :ref:`cuda-memory-management` for more
        details about GPU memory management.
    """
    requested = torch.cuda.current_device() if device is None else device
    # optional=True allows `device = torch.device('cuda')` for which device.index is None
    index = _get_device_index(requested, optional=True)
    runtime = torch.cuda.cudart()
    return runtime.cudaMemGetInfo(index)
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
def _record_memory_history_legacy(
    enabled: bool,
    record_context=True,
    trace_alloc_max_entries=1,
    trace_alloc_record_context=False,
    device: Union[Device, int] = None,
    record_context_cpp=False,
):
    """Forward the boolean-style (legacy) recording options to the C++ binding.

    ``device`` is accepted but is not passed on to
    ``_C._cuda_record_memory_history_legacy``.
    """
    legacy_options = (
        enabled,
        record_context,
        trace_alloc_max_entries,
        trace_alloc_record_context,
        record_context_cpp,
    )
    _C._cuda_record_memory_history_legacy(*legacy_options)
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
def _record_memory_history(enabled="all", *args, **kwargs):
    """Enable recording of stack traces associated with memory
    allocations, so you can tell what allocated any piece of memory in
    :func:`torch.cuda.memory._snapshot()`.

    In addition to keeping stack traces with each current allocation and free,
    this will also enable recording of a history of all alloc/free events.

    Use :func:`torch.cuda.memory._snapshot()` to retrieve this information,
    and the tools in `_memory_viz.py` to visualize snapshots.

    The Python trace collection is fast (2us per trace), so you may consider
    enabling this on production jobs if you anticipate ever having to debug
    memory issues.

    C++ trace collection is also fast (~50ns/frame), which for many typical programs
    works out to ~2us per trace, but can vary depending on stack depth.

    A ``bool`` value for ``enabled`` selects the legacy implementation; any
    other value is handled by the current one.

    Args:
        enabled (Literal[None, "state", "all"], optional):
            `None`, disable recording memory history.
            `"state"`, keep information for currently allocated memory.
            `"all"`, additionally keep a history of all alloc/free calls.
            Defaults to "all".
        context (Literal[None, "state", "alloc", "all"], optional):
            `None`, Do not record any tracebacks.
            `"state"`, Record tracebacks for currently allocated memory.
            `"alloc"`, additionally keep tracebacks for alloc calls.
            `"all"`, additionally keep tracebacks for free calls.
            Defaults to "all".
        stacks (Literal["python", "all"], optional):
            `"python"`, include Python, TorchScript, and inductor frames in tracebacks
            `"all"`, additionally include C++ frames
            Defaults to "all".
        max_entries (int, optional): Keep a maximum of `max_entries`
            alloc/free events in the recorded history.
    """
    handler = (
        _record_memory_history_legacy
        if isinstance(enabled, bool)
        else _record_memory_history_impl
    )
    return handler(enabled, *args, **kwargs)
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
def _record_memory_history_impl(
    enabled: Optional[str] = "all",
    context: Optional[str] = "all",
    stacks: str = "all",
    max_entries: int = sys.maxsize,
    device: Union[Device, int] = None,
):
    """Forward the string-style recording options to the C++ binding.

    ``device`` is accepted but is not passed on to
    ``_C._cuda_record_memory_history``.
    """
    recording_options = (enabled, context, stacks, max_entries)
    _C._cuda_record_memory_history(*recording_options)
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
# The public entry point only declares ``(enabled, *args, **kwargs)``; attach
# the implementation's signature so introspection reports the real keyword
# parameters (``signature`` is presumably ``inspect.signature`` - confirm
# against the file's imports).
_record_memory_history.__signature__ = signature(_record_memory_history_impl)  # type: ignore[attr-defined]
|
| 786 |
+
|
| 787 |
+
|
| 788 |
+
def _snapshot(device: Union[Device, int] = None):
    """Save a snapshot of CUDA memory state at the time it was called.

    The state is represented as a dictionary with the following structure.

    .. code-block:: python

        class Snapshot(TypedDict):
            segments : List[Segment]
            device_traces: List[List[TraceEntry]]

        class Segment(TypedDict):
            # Segments are memory returned from a cudaMalloc call.
            # The size of reserved memory is the sum of all Segments.
            # Segments are cached and reused for future allocations.
            # If the reuse is smaller than the segment, the segment
            # is split into more than one Block.
            # empty_cache() frees Segments that are entirely inactive.
            address: int
            total_size: int #  cudaMalloc'd size of segment
            stream: int
            segment_type: Literal['small', 'large'] # 'large' (>1MB)
            allocated_size: int # size of memory in use
            active_size: int # size of memory in use or in active_awaiting_free state
            blocks : List[Block]

        class Block(TypedDict):
            # A piece of memory returned from the allocator, or
            # currently cached but inactive.
            size: int
            requested_size: int # size requested during malloc, may be smaller than
                                # size due to rounding
            address: int
            state: Literal['active_allocated', # used by a tensor
                        'active_awaiting_free', # waiting for another stream to finish using
                                                # this, then it will become free
                        'inactive',] # free for reuse
            frames: List[Frame] # stack trace from where the allocation occurred

        class Frame(TypedDict):
                filename: str
                line: int
                name: str

        class TraceEntry(TypedDict):
            # When `torch.cuda.memory._record_memory_history()` is enabled,
            # the snapshot will contain TraceEntry objects that record each
            # action the allocator took.
            action: Literal[
            'alloc',  # memory allocated
            'free_requested', # the allocator received a call to free memory
            'free_completed', # the memory that was requested to be freed is now
                            # able to be used in future allocation calls
            'segment_alloc', # the caching allocator asked cudaMalloc for more memory
                            # and added it as a segment in its cache
            'segment_free',  # the caching allocator called cudaFree to return memory
                            # to cuda, possibly trying to free up memory to
                            # allocate more segments or because empty_cache was called
            'oom',          # the allocator threw an OOM exception. 'size' is
                            # the requested number of bytes that did not succeed
            'snapshot'      # the allocator generated a memory snapshot
                            # useful to correlate a previously taken
                            # snapshot with this trace
            ]
            addr: int # not present for OOM
            frames: List[Frame]
            size: int
            stream: int
            device_free: int # only present for OOM, the amount of
                            # memory cuda still reports to be free

    Note that ``device`` is accepted but is not passed to the underlying
    ``_C._cuda_memorySnapshot()`` call.

    Returns:
        The Snapshot dictionary object
    """
    snapshot_state = _C._cuda_memorySnapshot()
    return snapshot_state
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
def _dump_snapshot(filename="dump_snapshot.pickle"):
    """
    Save a pickled version of the `torch.cuda.memory._snapshot()` dictionary to a file.

    This file can be opened by the interactive snapshot viewer at pytorch.org/memory_viz

    Args:
        filename (str, optional): Name of the file to create.  Defaults to "dump_snapshot.pickle".
    """
    # Take the snapshot first; the output file is only opened afterwards.
    snapshot_state = _snapshot()
    with open(filename, "wb") as sink:
        pickle.dump(snapshot_state, sink)
|
| 877 |
+
|
| 878 |
+
|
| 879 |
+
def _save_segment_usage(filename="output.svg", snapshot=None):
    """Write the output of ``_segments(snapshot)`` to ``filename``.

    Args:
        filename (str, optional): file to write, opened in text mode. Defaults to "output.svg".
        snapshot (optional): snapshot to render; a fresh ``_snapshot()`` is
            taken when this is ``None`` (default).
    """
    state = _snapshot() if snapshot is None else snapshot
    with open(filename, "w") as sink:
        rendered = _segments(state)
        sink.write(rendered)
|
| 884 |
+
|
| 885 |
+
|
| 886 |
+
def _save_memory_usage(filename="output.svg", snapshot=None):
    """Write the output of ``_memory(snapshot)`` to ``filename``.

    Args:
        filename (str, optional): file to write, opened in text mode. Defaults to "output.svg".
        snapshot (optional): snapshot to render; a fresh ``_snapshot()`` is
            taken when this is ``None`` (default).
    """
    state = _snapshot() if snapshot is None else snapshot
    with open(filename, "w") as sink:
        rendered = _memory(state)
        sink.write(rendered)
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
def _set_allocator_settings(env: str):
    """Pass the settings string ``env`` to the CUDA caching allocator and return the binding's result."""
    outcome = torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env)
    return outcome
|
| 895 |
+
|
| 896 |
+
|
| 897 |
+
def get_allocator_backend() -> str:
    r"""Return a string describing the active allocator backend as set by
    ``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are
    ``native`` (PyTorch's native caching allocator) and ``cudaMallocAsync``
    (CUDA's built-in asynchronous allocator).

    .. note::
        See :ref:`cuda-memory-management` for details on choosing the allocator backend.
    """
    backend_name = torch._C._cuda_getAllocatorBackend()
    return backend_name
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
class _CUDAAllocator:
    r"""Wrapper over internal CUDA memory allocators."""

    def __init__(self, allocator: torch._C._cuda_CUDAAllocator):
        # Keep a reference to the wrapped allocator object.
        self._allocator = allocator

    def allocator(self):
        """Return the wrapped allocator object."""
        wrapped = self._allocator
        return wrapped
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
class CUDAPluggableAllocator(_CUDAAllocator):
    r"""CUDA memory allocator loaded from a so file."""

    def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str):
        r"""Memory allocators are compiled in .so files and loaded dynamically using ctypes.

        To change the active allocator use the :func:`torch.cuda.memory.change_current_allocator` function.

        Args:
            path_to_so_file(str): Path in the filesystem to the `.so` file containing
                the allocator functions
            alloc_fn_name(str): Name of the function to perform the memory allocation
                in the so file. The signature must be:
                void* alloc_fn_name(ssize_t size, int device, cudaStream_t stream);
            free_fn_name(str): Name of the function to perform the memory release
                in the so file. The signature must be:
                void free_fn_name(void* ptr, size_t size, cudaStream_t stream);

        .. warning::
            This is currently supported only in unix OSs

        .. note::
            See :ref:`cuda-memory-management` for details on creating and using a custom allocator
        """
        library = ctypes.CDLL(path_to_so_file)

        def _address_of(symbol_name):
            # Resolve the exported symbol and reduce it to its raw address.
            symbol = getattr(library, symbol_name)
            return ctypes.cast(symbol, ctypes.c_void_p).value

        alloc_address = _address_of(alloc_fn_name)
        free_address = _address_of(free_fn_name)
        assert alloc_address is not None
        assert free_address is not None
        self._allocator = torch._C._cuda_customAllocator(alloc_address, free_address)
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
def change_current_allocator(allocator: _CUDAAllocator) -> None:
    r"""Change the currently used memory allocator to be the one provided.

    If the current allocator has already been used/initialized, this function will error.

    Args:
        allocator (torch.cuda.memory._CUDAAllocator): allocator to be set as the active one.

    .. note::
        See :ref:`cuda-memory-management` for details on creating and using a custom allocator
    """
    # Unwrap the Python-side wrapper before handing it to the binding.
    raw_allocator = allocator.allocator()
    torch._C._cuda_changeCurrentAllocator(raw_allocator)
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
def _get_current_allocator() -> _CUDAAllocator:
    r"""Return the allocator being currently used, wrapped in a ``_CUDAAllocator``.

    .. note::
        See :ref:`cuda-memory-management` for details on creating and using a custom allocator
    """
    raw_allocator = torch._C._cuda_getAllocator()
    return _CUDAAllocator(raw_allocator)
|
| 972 |
+
|
| 973 |
+
|
| 974 |
+
class MemPool(_MemPool):
    r"""MemPool represents a pool of memory in a caching allocator. Currently,
    it's just the ID of the pool object maintained in the CUDACachingAllocator.

    Args:
        allocator(torch._C._cuda_CUDAAllocator, optional): a
            torch._C._cuda_CUDAAllocator object that can be used to
            define how memory gets allocated in the pool. If :attr:`allocator`
            is ``None`` (default), memory allocation follows the default/
            current configuration of the CUDACachingAllocator.

    """

    def __init__(self, allocator: Optional[_cuda_CUDAAllocator] = None):
        # NOTE(review): the meaning of the second positional argument (True)
        # is defined by the _MemPool base class, which is not visible here.
        super().__init__(allocator, True)

    @property
    def id(self) -> Tuple[int, int]:
        r"""Returns the ID of this pool as a tuple of two ints."""
        pool_id = super().id
        return pool_id

    @property
    def allocator(self) -> Optional[_cuda_CUDAAllocator]:
        r"""Returns the allocator this MemPool routes allocations to"""
        routed_allocator = super().allocator
        return routed_allocator
|
| 999 |
+
|
| 1000 |
+
|
| 1001 |
+
class MemPoolContext(_MemPoolContext):
    r"""MemPoolContext holds the currently active pool and stashes the previous
    pool. On deletion it makes the previous pool active.

    Args:
        pool(torch.cuda.MemPool): a MemPool object to be made active so that
            allocations route to this pool.

    """

    def __init__(self, pool: MemPool):
        super().__init__(pool)

    @staticmethod
    def active_pool() -> Optional[_MemPool]:
        r"""Returns the active MemPool"""
        current_pool = _MemPoolContext.active_pool()
        return current_pool
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
+
@contextlib.contextmanager
def use_mem_pool(pool: MemPool, device: Union[Device, int] = None):
    r"""A context manager that routes allocations to a given pool.

    Args:
        pool(torch.cuda.MemPool): a MemPool object to be made active so that
            allocations route to this pool.
        device (torch.device or int, optional): selected device. Uses MemPool on
            the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    """
    # Held for the whole ``with`` body and explicitly dropped on exit.
    active_context = MemPoolContext(pool)
    if device is None:
        device_index = torch.cuda.current_device()
    else:
        device_index = _get_device_index(device)
    _cuda_beginAllocateToPool(device_index, pool.id)
    try:
        yield
    finally:
        # Always stop routing to the pool, even if the body raised.
        _cuda_endAllocateCurrentStreamToPool(device_index, pool.id)
        del active_context
|
.venv/lib/python3.11/site-packages/torch/cuda/nccl.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import collections
|
| 3 |
+
import warnings
|
| 4 |
+
from typing import Optional, Sequence, Union
|
| 5 |
+
|
| 6 |
+
import torch.cuda
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
__all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"]
|
| 10 |
+
|
| 11 |
+
SUM = 0 # ncclRedOp_t
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def is_available(tensors):
    """Return whether NCCL can be used for the given tensors.

    The answer is ``False`` (with a warning) when this build exposes no
    ``torch._C._nccl_all_reduce``. It is also ``False`` when any tensor is
    sparse, non-contiguous or not on a CUDA device, or when two tensors
    report the same device; otherwise ``True``.
    """
    if not hasattr(torch._C, "_nccl_all_reduce"):
        warnings.warn("PyTorch is not compiled with NCCL support")
        return False

    seen_devices = set()
    for candidate in tensors:
        if candidate.is_sparse or not candidate.is_contiguous() or not candidate.is_cuda:
            return False
        index = candidate.get_device()
        if index in seen_devices:
            # Two tensors on the same device are not supported.
            return False
        seen_devices.add(index)

    return True
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def version():
    """
    Returns the version of the NCCL.

    This function returns a tuple containing the major, minor, and patch version numbers of the NCCL.
    The suffix is also included in the tuple if a version suffix exists.

    Returns:
        tuple: The version information of the NCCL.
    """
    packed = torch._C._nccl_version()
    # Layout of the packed integer: major above bit 32, then two 16-bit fields.
    major = packed >> 32
    minor = (packed >> 16) & 0xFFFF
    patch = packed & 0xFFFF
    suffix = torch._C._nccl_version_suffix().decode("utf-8")
    numeric = (major, minor, patch)
    return numeric if suffix == "" else numeric + (suffix,)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def unique_id():
    """Return the result of ``torch._C._nccl_unique_id()``."""
    new_id = torch._C._nccl_unique_id()
    return new_id
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def init_rank(num_ranks, uid, rank):
    """Return the result of ``torch._C._nccl_init_rank(num_ranks, uid, rank)``."""
    communicator = torch._C._nccl_init_rank(num_ranks, uid, rank)
    return communicator
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
    """Raise ``TypeError`` unless ``inputs`` is a container that is not itself a tensor."""
    is_container = isinstance(inputs, collections.abc.Container)
    if is_container and not isinstance(inputs, torch.Tensor):
        return
    raise TypeError("Inputs should be a collection of tensors")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
    """Run ``torch._C._nccl_all_reduce`` on ``inputs``.

    When ``outputs`` is ``None`` the inputs are also used as the outputs.
    Both collections are validated with ``_check_sequence_type`` first.
    """
    _check_sequence_type(inputs)
    destinations = inputs if outputs is None else outputs
    _check_sequence_type(destinations)
    torch._C._nccl_all_reduce(inputs, destinations, op, streams, comms)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# `output` used to be `outputs`, taking in a list of tensors. So we have two
|
| 80 |
+
# arguments for BC reasons.
|
| 81 |
+
def reduce(
    inputs: Sequence[torch.Tensor],
    output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
    root: int = 0,
    op: int = SUM,
    streams: Optional[Sequence[torch.cuda.Stream]] = None,
    comms=None,
    *,
    outputs: Optional[Sequence[torch.Tensor]] = None,
) -> None:
    """Run ``torch._C._nccl_reduce`` on ``inputs``, writing into one output tensor.

    Args:
        inputs: collection of input tensors.
        output: the output tensor. For backward compatibility a sequence of
            tensors is still accepted (deprecated); ``output[root]`` is then
            used. When ``None``, ``inputs[root]`` receives the result.
        root: index used to pick the output tensor from ``inputs``,
            ``output`` or ``outputs`` when a sequence is involved; also
            forwarded to the binding.
        op: reduction operation forwarded to the binding.
        streams: forwarded to the binding.
        comms: forwarded to the binding.
        outputs: deprecated keyword-only alias taking a list of output
            tensors; ``outputs[root]`` is used.

    Raises:
        TypeError: if ``inputs`` is not a collection of tensors.
        ValueError: if both ``output`` and ``outputs`` are given.
    """
    _check_sequence_type(inputs)
    _output: torch.Tensor
    if outputs is not None:
        if output is not None:
            raise ValueError(
                "'output' and 'outputs' can not be both specified. 'outputs' is deprecated in "
                "favor of 'output', taking in a single output tensor. The signature of reduce is: "
                "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)."
            )
        else:
            warnings.warn(
                "`nccl.reduce` with an output tensor list is deprecated. "
                "Please specify a single output tensor with argument 'output' instead.",
                FutureWarning,
                stacklevel=2,
            )
            _output = outputs[root]
    elif not isinstance(output, torch.Tensor) and isinstance(
        output, collections.abc.Sequence
    ):
        # User called old API with positional arguments of list of output tensors.
        warnings.warn(
            "nccl.reduce with an output tensor list is deprecated. "
            "Please specify a single output tensor.",
            FutureWarning,
            stacklevel=2,
        )
        _output = output[root]
    else:
        _output = inputs[root] if output is None else output
    torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def broadcast(
    inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None
) -> None:
    """Validate ``inputs`` and run ``torch._C._nccl_broadcast`` with the given ``root``."""
    _check_sequence_type(inputs)
    call_args = (inputs, root, streams, comms)
    torch._C._nccl_broadcast(*call_args)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def all_gather(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    streams=None,
    comms=None,
) -> None:
    """Validate both tensor collections and run ``torch._C._nccl_all_gather``."""
    # Inputs are checked before outputs, so an invalid ``inputs`` is reported first.
    for tensor_collection in (inputs, outputs):
        _check_sequence_type(tensor_collection)
    torch._C._nccl_all_gather(inputs, outputs, streams, comms)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def reduce_scatter(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    op: int = SUM,
    streams=None,
    comms=None,
) -> None:
    """Validate both tensor collections and run ``torch._C._nccl_reduce_scatter``."""
    # Inputs are checked before outputs, so an invalid ``inputs`` is reported first.
    for tensor_collection in (inputs, outputs):
        _check_sequence_type(tensor_collection)
    torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
|
.venv/lib/python3.11/site-packages/torch/cuda/nvtx.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
r"""This package adds support for NVIDIA Tools Extension (NVTX) used in profiling."""
|
| 3 |
+
|
| 4 |
+
from contextlib import contextmanager
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
try:
    from torch._C import _nvtx
except ImportError:

    class _NVTXStub:
        """Stand-in for ``torch._C._nvtx`` when that module cannot be imported.

        Every NVTX entry point used by this module resolves to ``_fail``, so
        calling any of them raises the same explanatory ``RuntimeError``.
        """

        @staticmethod
        def _fail(*args, **kwargs):
            raise RuntimeError(
                "NVTX functions not installed. Are you sure you have a CUDA build?"
            )

        rangePushA = _fail
        rangePop = _fail
        # range_start()/range_end() call these two; without them the stub
        # would raise an opaque AttributeError instead of the message above.
        rangeStartA = _fail
        rangeEnd = _fail
        markA = _fail

    _nvtx = _NVTXStub()  # type: ignore[assignment]
|
| 23 |
+
|
| 24 |
+
__all__ = ["range_push", "range_pop", "range_start", "range_end", "mark", "range"]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def range_push(msg):
    """
    Push a range onto a stack of nested range spans. Returns the zero-based depth of the range that is started.

    Args:
        msg (str): ASCII message to associate with range
    """
    depth = _nvtx.rangePushA(msg)
    return depth
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def range_pop():
    """Pop a range off of a stack of nested range spans. Returns the zero-based depth of the range that is ended."""
    depth = _nvtx.rangePop()
    return depth
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def range_start(msg) -> int:
    """
    Mark the start of a range with string message. It returns a unique handle
    for this range to pass to the corresponding call to range_end().

    A key difference between this and range_push/range_pop is that the
    range_start/range_end version supports range across threads (start on one
    thread and end on another thread).

    Returns: A range handle (uint64_t) that can be passed to range_end().

    Args:
        msg (str): ASCII message to associate with the range.
    """
    handle = _nvtx.rangeStartA(msg)
    return handle
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def range_end(range_id) -> None:
    """
    Mark the end of a range for a given range_id.

    Args:
        range_id (int): a unique handle for the start range.
    """
    # The binding's result is deliberately not returned.
    _nvtx.rangeEnd(range_id)
    return None
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def mark(msg):
    """
    Record an instantaneous event at the current point in time.

    Args:
        msg (str): ASCII message to associate with the event.

    Returns:
        Whatever the underlying ``markA`` call returns.
    """
    result = _nvtx.markA(msg)
    return result
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@contextmanager
def range(msg, *args, **kwargs):
    """
    Context manager / decorator wrapping its scope in an NVTX range.

    A range is pushed on entry and popped on exit, including when the
    wrapped scope raises. Any extra positional or keyword arguments are
    forwarded to ``msg.format()`` to build the final message.

    Args:
        msg (str): message to associate with the range
    """
    formatted_msg = msg.format(*args, **kwargs)
    # Push before entering the try so a failed push never triggers a pop.
    range_push(formatted_msg)
    try:
        yield
    finally:
        range_pop()
|
.venv/lib/python3.11/site-packages/torch/cuda/profiler.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import contextlib
|
| 3 |
+
import tempfile
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from . import check_error, cudart
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
__all__ = ["init", "start", "stop", "profile"]
|
| 11 |
+
|
| 12 |
+
DEFAULT_FLAGS = [
|
| 13 |
+
"gpustarttimestamp",
|
| 14 |
+
"gpuendtimestamp",
|
| 15 |
+
"gridsize3d",
|
| 16 |
+
"threadblocksize",
|
| 17 |
+
"streamid",
|
| 18 |
+
"enableonstart 0",
|
| 19 |
+
"conckerneltrace",
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def init(output_file, flags=None, output_mode="key_value"):
    """Initialize the CUDA profiler through ``cudaProfilerInitialize``.

    The flags are written, one per line, to a temporary config file whose
    name is handed to the runtime together with ``output_file`` and the
    selected output mode.

    Args:
        output_file: forwarded to ``cudaProfilerInitialize`` as its output argument.
        flags: iterable of ASCII-encodable strings; ``DEFAULT_FLAGS`` is used
            when ``None``.
        output_mode (str): either ``"key_value"`` or ``"csv"``.

    Raises:
        AssertionError: if the runtime has no ``cudaOutputMode`` (HIP), or if
            the CUDA major version is 12 or newer.
        RuntimeError: if ``output_mode`` is not one of the supported values.
    """
    rt = cudart()
    if not hasattr(rt, "cudaOutputMode"):
        raise AssertionError("HIP does not support profiler initialization!")

    cuda_version = getattr(torch.version, "cuda", None)
    if cuda_version is not None and int(cuda_version.split(".")[0]) >= 12:
        # Check https://github.com/pytorch/pytorch/pull/91118
        # cudaProfilerInitialize is no longer needed after CUDA 12
        raise AssertionError("CUDA12+ does not need profiler initialization!")

    if flags is None:
        flags = DEFAULT_FLAGS

    if output_mode == "key_value":
        output_mode_enum = rt.cudaOutputMode.KeyValuePair
    elif output_mode == "csv":
        output_mode_enum = rt.cudaOutputMode.CSV
    else:
        raise RuntimeError(
            "supported CUDA profiler output modes are: key_value and csv"
        )

    with tempfile.NamedTemporaryFile(delete=True) as config_file:
        encoded_flags = [flag.encode("ascii") for flag in flags]
        config_file.write(b"\n".join(encoded_flags))
        # Flush so the runtime sees the full config when it opens the file by name.
        config_file.flush()
        status = rt.cudaProfilerInitialize(
            config_file.name, output_file, output_mode_enum
        )
        check_error(status)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def start():
    r"""Start CUDA profiler data collection.

    .. warning::
        Raises CudaError if it is unable to start the profiler.
    """
    status = cudart().cudaProfilerStart()
    check_error(status)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def stop():
    r"""Stop CUDA profiler data collection.

    .. warning::
        Raises CudaError if it is unable to stop the profiler.
    """
    status = cudart().cudaProfilerStop()
    check_error(status)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@contextlib.contextmanager
def profile():
    """
    Enable profiling.

    Context manager that enables profile collection by the active profiling
    tool from the CUDA backend. Collection is started on entry and stopped on
    exit, including when the wrapped block raises. The context manager yields
    nothing, so there is no value to bind with ``as``.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> model = torch.nn.Linear(20, 30).cuda()
        >>> inputs = torch.randn(128, 20).cuda()
        >>> with torch.cuda.profiler.profile():
        ...     model(inputs)
    """
    # Start outside the try block: if starting fails there is nothing to
    # stop, and calling stop() from `finally` could raise a second error on
    # top of the real failure.
    start()
    try:
        yield
    finally:
        stop()
|
.venv/lib/python3.11/site-packages/torch/cuda/random.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from typing import Iterable, List, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
from . import _lazy_call, _lazy_init, current_device, device_count
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"get_rng_state",
|
| 12 |
+
"get_rng_state_all",
|
| 13 |
+
"set_rng_state",
|
| 14 |
+
"set_rng_state_all",
|
| 15 |
+
"manual_seed",
|
| 16 |
+
"manual_seed_all",
|
| 17 |
+
"seed",
|
| 18 |
+
"seed_all",
|
| 19 |
+
"initial_seed",
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
    r"""Return the random number generator state of the specified GPU as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()

    # Normalize string and integer specifications to a torch.device.
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    # A device without an explicit index refers to the current device.
    device_index = device.index
    if device_index is None:
        device_index = current_device()

    return torch.cuda.default_generators[device_index].get_state()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_rng_state_all() -> List[Tensor]:
    r"""Return a list of ByteTensor holding the RNG state of every device, in device order."""
    return [get_rng_state(device_index) for device_index in range(device_count())]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "cuda"
) -> None:
    r"""Set the random number generator state of the specified GPU.

    The state tensor is cloned immediately; applying it to the generator is
    handed to ``_lazy_call``.

    Args:
        new_state (torch.ByteTensor): The desired state
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
    """
    # Snapshot the state now so later changes to `new_state` are not picked up.
    with torch._C._DisableFuncTorch():
        state_snapshot = new_state.clone(memory_format=torch.contiguous_format)

    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    def _apply_state():
        # The index is resolved when the callback runs, not when it is queued.
        device_index = device.index
        if device_index is None:
            device_index = current_device()
        torch.cuda.default_generators[device_index].set_state(state_snapshot)

    _lazy_call(_apply_state)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Set the random number generator state of all devices.

    The position of each state in ``new_states`` is used as its device index.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired state for each device.
    """
    for device_index, device_state in enumerate(new_states):
        set_rng_state(device_state, device_index)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def manual_seed(seed: int) -> None:
    r"""Set the seed for generating random numbers for the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.

    .. warning::
        If you are working with a multi-GPU model, this function is insufficient
        to get determinism.  To seed all GPUs, use :func:`manual_seed_all`.
    """
    seed_value = int(seed)

    def _seed_current_device():
        generator = torch.cuda.default_generators[current_device()]
        generator.manual_seed(seed_value)

    _lazy_call(_seed_current_device, seed=True)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def manual_seed_all(seed: int) -> None:
    r"""Set the seed for generating random numbers on all GPUs.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.
    """
    seed_value = int(seed)

    def _seed_every_device():
        for device_index in range(device_count()):
            torch.cuda.default_generators[device_index].manual_seed(seed_value)

    _lazy_call(_seed_every_device, seed_all=True)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def seed() -> None:
    r"""Set the seed for generating random numbers to a random number for the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    .. warning::
        If you are working with a multi-GPU model, this function will only initialize
        the seed on one GPU.  To initialize all GPUs, use :func:`seed_all`.
    """

    def _reseed_current_device():
        generator = torch.cuda.default_generators[current_device()]
        generator.seed()

    _lazy_call(_reseed_current_device)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def seed_all() -> None:
    r"""Set the seed for generating random numbers to a random number on all GPUs.

    The first device draws a fresh random seed; every other device is then
    seeded with that same value.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.
    """

    def _reseed_every_device():
        shared_seed = None
        for device_index in range(device_count()):
            generator = torch.cuda.default_generators[device_index]
            if shared_seed is None:
                # First device: pick a new random seed and remember it.
                generator.seed()
                shared_seed = generator.initial_seed()
            else:
                generator.manual_seed(shared_seed)

    _lazy_call(_reseed_every_device)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def initial_seed() -> int:
    r"""Return the current random seed of the current GPU.

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    generator = torch.cuda.default_generators[current_device()]
    return generator.initial_seed()
|
.venv/lib/python3.11/site-packages/torch/cuda/sparse.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# The Tensor classes are added to this module by python_tensor.cpp
|
.venv/lib/python3.11/site-packages/torch/cuda/streams.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import ctypes
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._streambase import _EventBase, _StreamBase
|
| 6 |
+
from torch._utils import _dummy_type
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
if not hasattr(torch._C, "_CudaStreamBase"):
|
| 10 |
+
# Define dummy base classes
|
| 11 |
+
torch._C.__dict__["_CudaStreamBase"] = _dummy_type("_CudaStreamBase")
|
| 12 |
+
torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Stream(torch._C._CudaStreamBase, _StreamBase):
    r"""Wrapper around a CUDA stream.

    A CUDA stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams.  See :ref:`cuda-semantics` for
    details.

    Args:
        device(torch.device or int, optional): a device on which to allocate
            the stream. If :attr:`device` is ``None`` (default) or a negative
            integer, this will use the current device.
        priority(int, optional): priority of the stream, should be 0 or
            negative, where negative numbers indicate higher priority. By default,
            streams have priority 0.

    """

    def __new__(cls, device=None, priority=0, **kwargs):
        # Entering a device context is expensive, so it is skipped when no
        # device was requested or when the stream is fully identified by
        # the `stream_id` / `device_index` keyword arguments.
        identified_by_kwargs = "stream_id" in kwargs and "device_index" in kwargs
        if device is not None and not identified_by_kwargs:
            with torch.cuda.device(device):
                return super().__new__(cls, priority=priority, **kwargs)
        return super().__new__(cls, priority=priority, **kwargs)

    def wait_event(self, event) -> None:
        r"""Make all future work submitted to the stream wait for an event.

        Args:
            event (torch.cuda.Event): an event to wait for.

        .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
           `CUDA Stream documentation`_ for more info.

           This function returns without waiting for :attr:`event`: only future
           operations are affected.

        .. _CUDA Stream documentation:
           https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
        """
        event.wait(self)

    def wait_stream(self, stream) -> None:
        r"""Synchronize with another stream.

        All future work submitted to this stream will wait until all kernels
        submitted to a given stream at the time of call complete.

        Args:
            stream (Stream): a stream to synchronize.

        .. note:: This function returns without waiting for currently enqueued
           kernels in :attr:`stream`: only future operations are affected.
        """
        # Record a marker on the other stream, then make this stream wait on it.
        marker = stream.record_event()
        self.wait_event(marker)

    def record_event(self, event=None):
        r"""Record an event.

        Args:
            event (torch.cuda.Event, optional): event to record. If not given, a new one
                will be allocated.

        Returns:
            Recorded event.
        """
        recorded = Event() if event is None else event
        recorded.record(self)
        return recorded

    def query(self) -> bool:
        r"""Check if all the work submitted has been completed.

        Returns:
            A boolean indicating if all kernels in this stream are completed.
        """
        completed = super().query()
        return completed

    def synchronize(self) -> None:
        r"""Wait for all the kernels in this stream to complete.

        .. note:: This is a wrapper around ``cudaStreamSynchronize()``: see
           `CUDA Stream documentation`_ for more info.
        """
        super().synchronize()

    @property
    def _as_parameter_(self):
        # Exposes the stream handle as a void pointer for ctypes calls.
        return ctypes.c_void_p(self.cuda_stream)

    def __eq__(self, o) -> bool:
        # Anything that is not a Stream compares unequal.
        if not isinstance(o, Stream):
            return False
        return super().__eq__(o)

    def __hash__(self):
        identity = (self.cuda_stream, self.device)
        return hash(identity)

    def __repr__(self):
        return f"<torch.cuda.Stream device={self.device} cuda_stream={self.cuda_stream:#x}>"
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class ExternalStream(Stream):
    r"""Wrapper around an externally allocated CUDA stream.

    This class is used to wrap streams allocated in other libraries in order
    to facilitate data exchange and multi-library interactions.

    .. note:: This class doesn't manage the stream life-cycle, it is the user
       responsibility to keep the referenced stream alive while this class is
       being used.

    Args:
        stream_ptr(int): Integer representation of the `cudaStream_t` value
            allocated externally.
        device(torch.device or int, optional): the device where the stream
            was originally allocated. If device is specified incorrectly,
            subsequent launches using this stream may fail.
    """

    def __new__(cls, stream_ptr, device=None, **kwargs):
        # The wrapper is created inside the requested device's context.
        device_context = torch.cuda.device(device)
        with device_context:
            return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class Event(torch._C._CudaEventBase, _EventBase):
    r"""Wrapper around a CUDA event.

    CUDA events are synchronization markers that can be used to monitor the
    device's progress, to accurately measure timing, and to synchronize CUDA
    streams.

    The underlying CUDA events are lazily initialized when the event is first
    recorded or exported to another process. After creation, only streams on the
    same device may record the event. However, streams on any device can wait on
    the event.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
        blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
        interprocess (bool): if ``True``, the event can be shared between processes
            (default: ``False``)

    .. _CUDA Event Documentation:
       https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
    """

    def __new__(cls, enable_timing=False, blocking=False, interprocess=False):
        options = {
            "enable_timing": enable_timing,
            "blocking": blocking,
            "interprocess": interprocess,
        }
        return super().__new__(cls, **options)

    @classmethod
    def from_ipc_handle(cls, device, handle):
        r"""Reconstruct an event from an IPC handle on the given device."""
        event = super().from_ipc_handle(device, handle)
        return event

    def record(self, stream=None):
        r"""Record the event in a given stream.

        Uses ``torch.cuda.current_stream()`` if no stream is specified. The
        stream's device must match the event's device.
        """
        target = torch.cuda.current_stream() if stream is None else stream
        super().record(target)

    def wait(self, stream=None) -> None:
        r"""Make all future work submitted to the given stream wait for this event.

        Use ``torch.cuda.current_stream()`` if no stream is specified.

        .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
            `CUDA Event documentation`_ for more info.
        """
        target = torch.cuda.current_stream() if stream is None else stream
        super().wait(target)

    def query(self):
        r"""Check if all work currently captured by event has completed.

        Returns:
            A boolean indicating if all work currently captured by event has
            completed.
        """
        completed = super().query()
        return completed

    def elapsed_time(self, end_event):
        r"""Return the time elapsed.

        Time reported in milliseconds after the event was recorded and
        before the end_event was recorded.
        """
        elapsed_ms = super().elapsed_time(end_event)
        return elapsed_ms

    def synchronize(self) -> None:
        r"""Wait for the event to complete.

        Waits until the completion of all work currently captured in this event.
        This prevents the CPU thread from proceeding until the event completes.

         .. note:: This is a wrapper around ``cudaEventSynchronize()``: see
            `CUDA Event documentation`_ for more info.
        """
        super().synchronize()

    def ipc_handle(self):
        r"""Return an IPC handle of this event.

        If not recorded yet, the event will use the current device.
        """
        handle = super().ipc_handle()
        return handle

    @property
    def _as_parameter_(self):
        # Exposes the event handle as a void pointer for ctypes calls.
        return ctypes.c_void_p(self.cuda_event)

    def __repr__(self) -> str:
        # A falsy handle means the underlying event has not been created yet.
        if not self.cuda_event:
            return "<torch.cuda.Event uninitialized>"
        return f"<torch.cuda.Event {self._as_parameter_.value:#x}>"
|
.venv/lib/python3.11/site-packages/torch/cuda/tunable.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r"""
|
| 2 |
+
This module exposes a TunableOp interface.
|
| 3 |
+
|
| 4 |
+
Some operations, such as GEMMs, could be implemented using more than one library
|
| 5 |
+
or more than one technique. For example, a GEMM could be implemented for CUDA or
|
| 6 |
+
ROCm using either the blas or blasLt libraries. Further, ROCm's rocblas and
|
| 7 |
+
hipblaslt libraries allow the user to query for all possible algorithms and then
|
| 8 |
+
choose one. How does one know which implementation is the fastest and should be
|
| 9 |
+
chosen? That's what TunableOp provides.
|
| 10 |
+
|
| 11 |
+
Enabling TunableOp and Tuning Separately
|
| 12 |
+
========================================
|
| 13 |
+
|
| 14 |
+
The TunableOp feature is enabled separately from enabling the tuning phase
|
| 15 |
+
itself. Enabling TunableOp means that PyTorch will replace any standard
|
| 16 |
+
operators with their Tunable implementations. Any call to a TunableOp first
|
| 17 |
+
checks whether it has already been tuned for the given operator inputs. If so,
|
| 18 |
+
it will immediately call the tuned operation; no further tuning will take place
|
| 19 |
+
even when the tuning setting is enabled. Instead if no tuning result is found,
|
| 20 |
+
and tuning is enabled, the TunableOp will benchmark every registered
|
| 21 |
+
implementation of that operator for the given set of inputs and select the
|
| 22 |
+
fastest.
|
| 23 |
+
|
| 24 |
+
File Input and Output
|
| 25 |
+
=====================
|
| 26 |
+
|
| 27 |
+
The first time any TunableOp is invoked, the internal database of tuned
|
| 28 |
+
operations will be prepared by attempting to read the results from the given
|
| 29 |
+
file. The default filename is 'tunableop_results.csv'. To support tuning when
|
| 30 |
+
multiple GPUs are used across multiple processes, the GPU device ordinal is
|
| 31 |
+
automatically inserted into the filename to avoid multiple processes overwriting
|
| 32 |
+
the same file.
|
| 33 |
+
|
| 34 |
+
If tuning is enabled and new tunings are discovered during the course of your
|
| 35 |
+
workload, it will also write out to this same filename with all tunings, both
|
| 36 |
+
the ones it read in at startup as well as the new ones found at runtime. This
|
| 37 |
+
can be used, for example, to build up a tunings file across many workloads by
|
| 38 |
+
reusing the same file. The output file is automatically created when the
|
| 39 |
+
application terminates. This behavior can be controlled by the C++ and Python
|
| 40 |
+
APIs but not the environment variables.
|
| 41 |
+
|
| 42 |
+
Assuming you specified a filename, you'll end up with a CSV file with contents
|
| 43 |
+
like so::
|
| 44 |
+
|
| 45 |
+
Validator,PT_VERSION,2.2.0
|
| 46 |
+
Validator,ROCM_VERSION,6.0.0.0-12969-1544e39
|
| 47 |
+
Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7
|
| 48 |
+
Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty
|
| 49 |
+
GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262
|
| 50 |
+
GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033
|
| 51 |
+
|
| 52 |
+
Note the "Validator" lines. If you change a library version, or ROCm version, or
|
| 53 |
+
PyTorch version, TunableOp will detect this and reject the tunings file because
|
| 54 |
+
the prior tunings are likely affected by other software changes.
|
| 55 |
+
|
| 56 |
+
The remaining lines are the tuned solutions for each TunableOp encountered
|
| 57 |
+
during your execution. Each line consists of 4 comma-separated fields: operator
|
| 58 |
+
name, operator parameters, solution name, and average execution time. The
|
| 59 |
+
execution time is an optional field. The CSV file can be edited, but with
|
| 60 |
+
caution. For example, the solution name (field 3) can be changed to "Default"
|
| 61 |
+
and it will fall back to the original PyTorch untuned implementation. Or, in the
|
| 62 |
+
case of ROCm's hipBLAS or hipBLASLt libraries, if you know the specific solution
|
| 63 |
+
index you can override the solution that TunableOp selected by replacing the
|
| 64 |
+
value. The operator name and parameters (fields 1 and 2) are internally named
|
| 65 |
+
and should not be modified. In the case of GemmTunableOp, field 1 indicates the
|
| 66 |
+
datatype and whether the inputs are transposed (T) or not (N) and field 2
|
| 67 |
+
indicates the M, N, K input shapes.
|
| 68 |
+
|
| 69 |
+
There is an option to enable verbose output but it is only recommended for
|
| 70 |
+
debugging purposes. This will produce a lot of diagnostic messages but may be
|
| 71 |
+
useful to see if TunableOp is being used at all. Otherwise, TunableOp is
|
| 72 |
+
completely silent, besides file output, unless there is a warning or error
|
| 73 |
+
during its use. The verbose option is only available by setting the environment
|
| 74 |
+
variable PYTORCH_TUNABLEOP_VERBOSE=1.
|
| 75 |
+
|
| 76 |
+
A Note on Tuning Behavior
|
| 77 |
+
=========================
|
| 78 |
+
|
| 79 |
+
Tuning an operator consists of iterating through the list of registered
|
| 80 |
+
implementations and profiling each one. The profile is established by running a
|
| 81 |
+
single implementation in a loop multiple times and taking the average execution
|
| 82 |
+
time.
|
| 83 |
+
|
| 84 |
+
By default, each possible solution for a given operator will be run for either
|
| 85 |
+
100 iterations or as many iterations that can be run within 30ms, whichever is
|
| 86 |
+
smaller, and its average execution will be calculated. The fastest solution
|
| 87 |
+
among all that were successfully profiled will be chosen. A profile might fail
|
| 88 |
+
if the given solution doesn't achieve the same accuracy as the default
|
| 89 |
+
implementation or if the solution returns an error code.
|
| 90 |
+
|
| 91 |
+
Current Tunable Operators
|
| 92 |
+
=========================
|
| 93 |
+
|
| 94 |
+
TunableGemm for ROCm
|
| 95 |
+
--------------------
|
| 96 |
+
|
| 97 |
+
Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of
|
| 98 |
+
PyTorch will function correctly when using TunableOp but the only solution
|
| 99 |
+
available to CUDA builds is the 'Default' implementation i.e. the original
|
| 100 |
+
cuBLAS default, now called through TunableOp. Any call to at::cuda::blas::gemm()
|
| 101 |
+
or ::bgemm() will be routed through TunableOp when enabled. Calling gemm() for a
|
| 102 |
+
given set of input arguments (transa, transb, m, n, k) will attempt to use the
|
| 103 |
+
fastest available implementation across both rocblas and hipblaslt.
|
| 104 |
+
|
| 105 |
+
Tuning Context
|
| 106 |
+
==============
|
| 107 |
+
|
| 108 |
+
The behavior of TunableOp is currently manipulated through environment
|
| 109 |
+
variables, the C++ interface of at::cuda::tunable::getTuningContext(), or the
|
| 110 |
+
torch.cuda.tunable python interfaces that wrap the C++ TuningContext. The
|
| 111 |
+
environment variables take precedence over any setting you manipulate using the
|
| 112 |
+
C++ or Python APIs.
|
| 113 |
+
|
| 114 |
+
"""
|
| 115 |
+
from typing import Optional, Tuple
|
| 116 |
+
|
| 117 |
+
import torch
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
__all__ = [
|
| 121 |
+
"enable",
|
| 122 |
+
"is_enabled",
|
| 123 |
+
"tuning_enable",
|
| 124 |
+
"tuning_is_enabled",
|
| 125 |
+
"set_max_tuning_duration",
|
| 126 |
+
"get_max_tuning_duration",
|
| 127 |
+
"set_max_tuning_iterations",
|
| 128 |
+
"get_max_tuning_iterations",
|
| 129 |
+
"set_filename",
|
| 130 |
+
"get_filename",
|
| 131 |
+
"get_results",
|
| 132 |
+
"get_validators",
|
| 133 |
+
"write_file_on_exit",
|
| 134 |
+
"write_file",
|
| 135 |
+
"read_file",
|
| 136 |
+
]
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def enable(val: bool = True) -> None:
|
| 140 |
+
r"""This is the big on/off switch for all TunableOp implementations."""
|
| 141 |
+
torch._C._cuda_tunableop_enable(val) # type: ignore[attr-defined]
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def is_enabled() -> bool:
|
| 145 |
+
r"""Returns whether the TunableOp feature is enabled."""
|
| 146 |
+
return torch._C._cuda_tunableop_is_enabled() # type: ignore[attr-defined]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def tuning_enable(val: bool = True) -> None:
|
| 150 |
+
r"""Enable tuning of TunableOp implementations.
|
| 151 |
+
|
| 152 |
+
When enabled, if a tuned entry isn't found, run the tuning step and record
|
| 153 |
+
the entry.
|
| 154 |
+
"""
|
| 155 |
+
torch._C._cuda_tunableop_tuning_enable(val) # type: ignore[attr-defined]
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def tuning_is_enabled() -> bool:
|
| 159 |
+
r"""Returns whether TunableOp implementations can be tuned."""
|
| 160 |
+
return torch._C._cuda_tunableop_tuning_is_enabled() # type: ignore[attr-defined]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def set_max_tuning_duration(duration: int) -> None:
|
| 164 |
+
r"""Set max time in milliseconds to spend tuning a given solution.
|
| 165 |
+
|
| 166 |
+
If both max tuning duration and iterations are set, the smaller of the two
|
| 167 |
+
will be honored. At minimum 1 tuning iteration will always be run.
|
| 168 |
+
"""
|
| 169 |
+
torch._C._cuda_tunableop_set_max_tuning_duration(duration) # type: ignore[attr-defined]
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def get_max_tuning_duration() -> int:
|
| 173 |
+
r"""Get max time to spend tuning a given solution."""
|
| 174 |
+
return torch._C._cuda_tunableop_get_max_tuning_duration() # type: ignore[attr-defined]
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def set_max_tuning_iterations(iterations: int) -> None:
|
| 178 |
+
r"""Set max number of iterations to spend tuning a given solution.
|
| 179 |
+
|
| 180 |
+
If both max tuning duration and iterations are set, the smaller of the two
|
| 181 |
+
will be honored. At minimum 1 tuning iteration will always be run.
|
| 182 |
+
"""
|
| 183 |
+
torch._C._cuda_tunableop_set_max_tuning_iterations(iterations) # type: ignore[attr-defined]
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def get_max_tuning_iterations() -> int:
|
| 187 |
+
r"""Get max iterations to spend tuning a given solution."""
|
| 188 |
+
return torch._C._cuda_tunableop_get_max_tuning_iterations() # type: ignore[attr-defined]
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def set_filename(filename: str, insert_device_ordinal: bool = False) -> None:
|
| 192 |
+
r"""Set the filename to use for input/output of tuning results.
|
| 193 |
+
|
| 194 |
+
If :attr:`insert_device_ordinal` is ``True`` then the current device ordinal
|
| 195 |
+
will be added to the given filename automatically. This can be used in a
|
| 196 |
+
1-process-per-gpu scenario to ensure each process writes to a separate file.
|
| 197 |
+
"""
|
| 198 |
+
torch._C._cuda_tunableop_set_filename(filename, insert_device_ordinal) # type: ignore[attr-defined]
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def get_filename() -> str:
|
| 202 |
+
r"""Get the results filename."""
|
| 203 |
+
return torch._C._cuda_tunableop_get_filename() # type: ignore[attr-defined]
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def get_results() -> Tuple[str, str, str, float]:
|
| 207 |
+
r"""Return all TunableOp results."""
|
| 208 |
+
return torch._C._cuda_tunableop_get_results() # type: ignore[attr-defined]
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def get_validators() -> Tuple[str, str]:
|
| 212 |
+
r"""Return the TunableOp validators."""
|
| 213 |
+
return torch._C._cuda_tunableop_get_validators() # type: ignore[attr-defined]
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def write_file_on_exit(val: bool) -> None:
|
| 217 |
+
r"""During Tuning Context destruction, write file to disk.
|
| 218 |
+
|
| 219 |
+
This is useful as a final flush of your results to disk if your application
|
| 220 |
+
terminates as a result of normal operation or an error. Manual flushing of
|
| 221 |
+
your results can be achieved by manually calling ``write_file()``."""
|
| 222 |
+
torch._C._cuda_tunableop_write_file_on_exit(val) # type: ignore[attr-defined]
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def write_file(filename: Optional[str] = None) -> bool:
|
| 226 |
+
r"""Write results to a CSV file.
|
| 227 |
+
|
| 228 |
+
If :attr:`filename` is not given, ``get_filename()`` is called.
|
| 229 |
+
"""
|
| 230 |
+
if filename is None:
|
| 231 |
+
filename = get_filename()
|
| 232 |
+
return torch._C._cuda_tunableop_write_file(filename) # type: ignore[attr-defined]
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def read_file(filename: Optional[str] = None) -> bool:
|
| 236 |
+
r"""Read results from a TunableOp CSV file.
|
| 237 |
+
|
| 238 |
+
If :attr:`filename` is not given, ``get_filename()`` is called.
|
| 239 |
+
"""
|
| 240 |
+
if filename is None:
|
| 241 |
+
filename = get_filename()
|
| 242 |
+
return torch._C._cuda_tunableop_read_file(filename) # type: ignore[attr-defined]
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (4.29 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_compatibility.cpython-311.pyc
ADDED
|
Binary file (2.04 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-311.pyc
ADDED
|
Binary file (9.05 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_pytree.cpython-311.pyc
ADDED
|
Binary file (5.86 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-311.pyc
ADDED
|
Binary file (58.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_utils.cpython-311.pyc
ADDED
|
Binary file (3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/annotate.cpython-311.pyc
ADDED
|
Binary file (1.56 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (234 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph.cpython-311.pyc
ADDED
|
Binary file (95 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph_module.cpython-311.pyc
ADDED
|
Binary file (44.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/immutable_collections.cpython-311.pyc
ADDED
|
Binary file (4.57 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/interpreter.cpython-311.pyc
ADDED
|
Binary file (29.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/node.cpython-311.pyc
ADDED
|
Binary file (43.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/operator_schemas.cpython-311.pyc
ADDED
|
Binary file (24.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/proxy.cpython-311.pyc
ADDED
|
Binary file (33.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-311.pyc
ADDED
|
Binary file (15.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/tensor_type.cpython-311.pyc
ADDED
|
Binary file (5.79 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__pycache__/traceback.cpython-311.pyc
ADDED
|
Binary file (4.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (806 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-311.pyc
ADDED
|
Binary file (2.29 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-311.pyc
ADDED
|
Binary file (4.93 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-311.pyc
ADDED
|
Binary file (21.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-311.pyc
ADDED
|
Binary file (5.92 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-311.pyc
ADDED
|
Binary file (4.97 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-311.pyc
ADDED
|
Binary file (43 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-311.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-311.pyc
ADDED
|
Binary file (4.49 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-311.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-311.pyc
ADDED
|
Binary file (30.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-311.pyc
ADDED
|
Binary file (25.7 kB). View file
|
|
|