Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__init__.py +55 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/debug.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/metrics.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/tensor_factory_functions.py +48 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/linalg/__init__.py +308 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/special/__init__.py +236 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/nvtx.py +91 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_pytree.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph_module.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/immutable_collections.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/proxy.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/tensor_type.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/util.py +52 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_module.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__init__.py +2 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_base.py +75 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_manager.py +303 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/reinplace.py +675 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/source_matcher_utils.py +144 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__init__.py +78 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/_atfork.py +33 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__init__.py +117 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/bias.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/bias.py +353 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/common_types.py +42 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/grad.py +189 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/__init__.py +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__init__.py +31 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (217 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__init__.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
|
| 3 |
+
import torch._C._lazy
|
| 4 |
+
from torch.utils._pytree import tree_flatten, tree_unflatten
|
| 5 |
+
|
| 6 |
+
from .closure import add_step_closure, run_step_closures
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def mark_step(device: str = "", wait=False):
|
| 10 |
+
"""Triggers a mark step, which amounts to
|
| 11 |
+
- collecting a group of 'live' lazy tensors to index into the compilation cache
|
| 12 |
+
(lowering/compiling their IR graphs if not cached)
|
| 13 |
+
- kicking off execution of the compiled function
|
| 14 |
+
- (optionally, wait=True) waiting for cpu-side execution to complete (does not sync the accelerator)
|
| 15 |
+
"""
|
| 16 |
+
# TODO(whc) expand this to include backend hooks and align with XLA backend needs
|
| 17 |
+
torch._C._lazy._mark_step(device, [], wait=wait)
|
| 18 |
+
|
| 19 |
+
run_step_closures()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def wait_device_ops(devices=None):
|
| 23 |
+
"""Waits for all the async operations on the given devices to complete.
|
| 24 |
+
Args:
|
| 25 |
+
devices (string..., optional): The devices whose async ops need to be waited
|
| 26 |
+
for. If empty, all the local devices will be waited for.
|
| 27 |
+
"""
|
| 28 |
+
if devices is None:
|
| 29 |
+
devices = []
|
| 30 |
+
torch._C._lazy._wait_device_ops(devices=devices)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def sync_multi(tensors, devices):
|
| 34 |
+
"""
|
| 35 |
+
Sync the list of lazy tensors so there IR get lowered for the activate backend
|
| 36 |
+
and the compiled computation graph get cached.
|
| 37 |
+
"""
|
| 38 |
+
torch._C._lazy._sync_multi(tensors, devices)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_tensor_id(tensor):
|
| 42 |
+
"""Return a unique id of the lazy tensor maintained by LTC"""
|
| 43 |
+
return torch._C._lazy._get_tensor_id(tensor)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def to_cpu(tensors, devices=None):
|
| 47 |
+
devices = devices or ["lazy"]
|
| 48 |
+
|
| 49 |
+
flattened, spec = tree_flatten(tensors)
|
| 50 |
+
sync_multi(flattened, devices)
|
| 51 |
+
return tree_unflatten([t.to("cpu") for t in flattened], spec)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def save(tensors, *args, **kwargs):
|
| 55 |
+
torch.save(to_cpu(tensors), *args, **kwargs)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/debug.cpython-311.pyc
ADDED
|
Binary file (1.33 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-311.pyc
ADDED
|
Binary file (887 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/metrics.cpython-311.pyc
ADDED
|
Binary file (1.42 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-311.pyc
ADDED
|
Binary file (1.09 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/tensor_factory_functions.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
tensor_factory_functions defines the list of torch functions that create tensors.
|
| 5 |
+
The list is grabbed by searching thru native_functions.yaml by the following
|
| 6 |
+
regular expression:
|
| 7 |
+
|
| 8 |
+
cat native_functions.yaml | grep 'func:' | grep -v "Tensor.*->" | grep "[-]>.*Tensor"
|
| 9 |
+
|
| 10 |
+
It's possible that new tensor factory functions are added making this list stale.
|
| 11 |
+
Use at your own risk or regenerate the list.
|
| 12 |
+
"""
|
| 13 |
+
tensor_factory_functions = (
|
| 14 |
+
torch._cudnn_init_dropout_state,
|
| 15 |
+
torch.arange,
|
| 16 |
+
torch.bartlett_window,
|
| 17 |
+
torch.blackman_window,
|
| 18 |
+
torch._empty_affine_quantized,
|
| 19 |
+
torch.empty_strided,
|
| 20 |
+
torch.eye,
|
| 21 |
+
torch.full,
|
| 22 |
+
torch.from_file,
|
| 23 |
+
torch.hann_window,
|
| 24 |
+
torch.hamming_window,
|
| 25 |
+
torch.kaiser_window,
|
| 26 |
+
torch.linspace,
|
| 27 |
+
torch.logspace,
|
| 28 |
+
torch.ones,
|
| 29 |
+
torch.scalar_tensor,
|
| 30 |
+
torch.rand,
|
| 31 |
+
torch.randint,
|
| 32 |
+
torch.randn,
|
| 33 |
+
torch.randperm,
|
| 34 |
+
torch.range,
|
| 35 |
+
torch._efficientzerotensor,
|
| 36 |
+
torch.zeros,
|
| 37 |
+
torch.tril_indices,
|
| 38 |
+
torch.triu_indices,
|
| 39 |
+
# Note: the following functions match the regular expression search above but
|
| 40 |
+
# they are not available in the torch module. Comment out.
|
| 41 |
+
# torch._sparse_coo_tensor_with_dims,
|
| 42 |
+
# torch.fft_fftfreq,
|
| 43 |
+
# torch.fft_rfftfreq,
|
| 44 |
+
) + (
|
| 45 |
+
# torch.tensor is special since it's not in native_functions.yaml
|
| 46 |
+
# add it separately
|
| 47 |
+
torch.tensor,
|
| 48 |
+
)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/linalg/__init__.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
from typing import List, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
import torch._prims as prims
|
| 8 |
+
|
| 9 |
+
import torch._prims_common as utils
|
| 10 |
+
import torch._refs as refs
|
| 11 |
+
import torch._refs.linalg as linalg
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
from torch._prims_common import (
|
| 14 |
+
check_fp_or_complex,
|
| 15 |
+
check_is_matrix,
|
| 16 |
+
Dim,
|
| 17 |
+
DimsType,
|
| 18 |
+
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
| 19 |
+
IntLike,
|
| 20 |
+
NumberType,
|
| 21 |
+
TensorLikeType,
|
| 22 |
+
)
|
| 23 |
+
from torch._prims_common.wrappers import (
|
| 24 |
+
_maybe_convert_to_dtype,
|
| 25 |
+
elementwise_type_promotion_wrapper,
|
| 26 |
+
out_wrapper,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
__all__ = [
|
| 31 |
+
"diagonal",
|
| 32 |
+
"matrix_norm",
|
| 33 |
+
"norm",
|
| 34 |
+
"svd",
|
| 35 |
+
"svdvals",
|
| 36 |
+
"vector_norm",
|
| 37 |
+
"vecdot",
|
| 38 |
+
"cross",
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str):
|
| 43 |
+
"""
|
| 44 |
+
Checks related to the dtype kwarg in `linalg.*norm` functions
|
| 45 |
+
"""
|
| 46 |
+
if dtype is not None:
|
| 47 |
+
torch._check(
|
| 48 |
+
utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
|
| 49 |
+
lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
|
| 50 |
+
)
|
| 51 |
+
torch._check(
|
| 52 |
+
utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype),
|
| 53 |
+
lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
|
| 54 |
+
fn_name=fn_name,
|
| 55 |
+
d="complex" if utils.is_complex_dtype(x_dtype) else "real",
|
| 56 |
+
dtype=dtype,
|
| 57 |
+
),
|
| 58 |
+
)
|
| 59 |
+
torch._check(
|
| 60 |
+
utils.get_higher_dtype(dtype, x_dtype) == dtype,
|
| 61 |
+
lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
|
| 62 |
+
"without narrowing to the specified dtype ({dtype})",
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# Utilities should come BEFORE this import
|
| 67 |
+
from torch._decomp import register_decomposition
|
| 68 |
+
from torch._decomp.decompositions import pw_cast_for_opmath
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@register_decomposition(torch._ops.ops.aten.linalg_cross)
|
| 72 |
+
@out_wrapper()
|
| 73 |
+
@pw_cast_for_opmath
|
| 74 |
+
def cross(a: Tensor, b: Tensor, dim: int = -1):
|
| 75 |
+
torch._check(
|
| 76 |
+
a.ndim == b.ndim,
|
| 77 |
+
lambda: "linalg.cross: inputs must have the same number of dimensions.",
|
| 78 |
+
)
|
| 79 |
+
torch._check(
|
| 80 |
+
a.size(dim) == 3 and b.size(dim) == 3,
|
| 81 |
+
lambda: f"linalg.cross: inputs dim {dim} must have length 3, got {a.size(dim)} and {b.size(dim)}",
|
| 82 |
+
)
|
| 83 |
+
a, b = torch.broadcast_tensors(a, b)
|
| 84 |
+
dim = utils.canonicalize_dim(a.ndim, dim)
|
| 85 |
+
idx = torch.arange(3, device=a.device)
|
| 86 |
+
return a.index_select(dim, (idx + 1) % 3) * b.index_select(
|
| 87 |
+
dim, (idx + 2) % 3
|
| 88 |
+
) - a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def diagonal(
|
| 92 |
+
input: TensorLikeType,
|
| 93 |
+
*,
|
| 94 |
+
offset: int = 0,
|
| 95 |
+
dim1: int = -2,
|
| 96 |
+
dim2: int = -1,
|
| 97 |
+
) -> TensorLikeType:
|
| 98 |
+
return torch.diagonal(input, offset=offset, dim1=dim1, dim2=dim2)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@register_decomposition(torch._ops.ops.aten.linalg_vector_norm)
|
| 102 |
+
@out_wrapper(exact_dtype=True)
|
| 103 |
+
def vector_norm(
|
| 104 |
+
x: TensorLikeType,
|
| 105 |
+
ord: Union[float, int] = 2,
|
| 106 |
+
dim: Optional[DimsType] = None,
|
| 107 |
+
keepdim: bool = False,
|
| 108 |
+
*,
|
| 109 |
+
dtype: Optional[torch.dtype] = None,
|
| 110 |
+
) -> Tensor:
|
| 111 |
+
# Checks
|
| 112 |
+
check_fp_or_complex(x.dtype, "linalg.vector_norm")
|
| 113 |
+
|
| 114 |
+
if isinstance(dim, Dim):
|
| 115 |
+
dim = [dim] # type: ignore[assignment]
|
| 116 |
+
|
| 117 |
+
if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
|
| 118 |
+
torch._check(
|
| 119 |
+
dim is not None and len(dim) != 0,
|
| 120 |
+
lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
|
| 121 |
+
"because the operation does not have an identity",
|
| 122 |
+
)
|
| 123 |
+
shape = x.shape
|
| 124 |
+
assert dim is not None # mypy does not seem to be able to see through check?
|
| 125 |
+
for d in dim:
|
| 126 |
+
torch._check(
|
| 127 |
+
shape[d] != 0,
|
| 128 |
+
lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
|
| 129 |
+
f"dimension {d} because this dimension is empty and the "
|
| 130 |
+
"operation does not have an identity",
|
| 131 |
+
)
|
| 132 |
+
_check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")
|
| 133 |
+
|
| 134 |
+
computation_dtype, result_dtype = utils.reduction_dtypes(
|
| 135 |
+
x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
to_result_dtype = partial(_maybe_convert_to_dtype, dtype=result_dtype)
|
| 139 |
+
|
| 140 |
+
# Implementation
|
| 141 |
+
if ord == 0.0:
|
| 142 |
+
return torch.sum(torch.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype)
|
| 143 |
+
elif ord == float("inf"):
|
| 144 |
+
return to_result_dtype(torch.amax(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type]
|
| 145 |
+
elif ord == float("-inf"):
|
| 146 |
+
return to_result_dtype(torch.amin(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type]
|
| 147 |
+
else:
|
| 148 |
+
# From here on the computation dtype is important as the reduction is non-trivial
|
| 149 |
+
x = _maybe_convert_to_dtype(x, computation_dtype) # type: ignore[assignment]
|
| 150 |
+
reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim)
|
| 151 |
+
|
| 152 |
+
is_ord_even = ord % 2 == 0 if isinstance(ord, IntLike) else ord % 2.0 == 0.0
|
| 153 |
+
if not (is_ord_even and utils.is_float_dtype(x.dtype)):
|
| 154 |
+
x = torch.abs(x)
|
| 155 |
+
return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord)) # type: ignore[return-value]
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _backshift_permutation(dim0, dim1, ndim):
|
| 159 |
+
# Auxiliary function for matrix_norm
|
| 160 |
+
# Computes the permutation that moves the two given dimensions to the back
|
| 161 |
+
ret = [i for i in range(ndim) if i != dim0 and i != dim1]
|
| 162 |
+
ret.extend((dim0, dim1))
|
| 163 |
+
return ret
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def _inverse_permutation(perm):
|
| 167 |
+
# Given a permutation, returns its inverse. It's equivalent to argsort on an array
|
| 168 |
+
return [i for i, j in sorted(enumerate(perm), key=lambda i_j: i_j[1])]
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# CompositeImplicitAutograd
|
| 172 |
+
@out_wrapper(exact_dtype=True)
|
| 173 |
+
def matrix_norm(
|
| 174 |
+
A: TensorLikeType,
|
| 175 |
+
ord: Union[float, str] = "fro",
|
| 176 |
+
dim: DimsType = (-2, -1),
|
| 177 |
+
keepdim: bool = False,
|
| 178 |
+
*,
|
| 179 |
+
dtype: Optional[torch.dtype] = None,
|
| 180 |
+
) -> TensorLikeType:
|
| 181 |
+
# shape
|
| 182 |
+
check_is_matrix(A, "linalg.matrix_norm")
|
| 183 |
+
# dim
|
| 184 |
+
dim = utils.canonicalize_dims(A.ndim, dim)
|
| 185 |
+
if isinstance(dim, Dim):
|
| 186 |
+
dim = (dim,) # type: ignore[assignment]
|
| 187 |
+
torch._check(
|
| 188 |
+
len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
|
| 189 |
+
)
|
| 190 |
+
torch._check(
|
| 191 |
+
dim[0] != dim[1],
|
| 192 |
+
lambda: "linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
|
| 193 |
+
)
|
| 194 |
+
# dtype arg
|
| 195 |
+
_check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm")
|
| 196 |
+
|
| 197 |
+
if isinstance(ord, str):
|
| 198 |
+
# ord
|
| 199 |
+
torch._check(
|
| 200 |
+
ord in ("fro", "nuc"),
|
| 201 |
+
lambda: "linalg.matrix_norm: Order {ord} not supported.",
|
| 202 |
+
)
|
| 203 |
+
# dtype
|
| 204 |
+
check_fp_or_complex(
|
| 205 |
+
A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != "nuc"
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
if ord == "fro":
|
| 209 |
+
return vector_norm(A, 2, dim, keepdim, dtype=dtype)
|
| 210 |
+
else: # ord == "nuc"
|
| 211 |
+
if dtype is not None:
|
| 212 |
+
A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment]
|
| 213 |
+
perm = _backshift_permutation(dim[0], dim[1], A.ndim)
|
| 214 |
+
result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim)
|
| 215 |
+
if keepdim:
|
| 216 |
+
inv_perm = _inverse_permutation(perm)
|
| 217 |
+
result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
|
| 218 |
+
return result
|
| 219 |
+
else:
|
| 220 |
+
# ord
|
| 221 |
+
abs_ord = abs(ord)
|
| 222 |
+
torch._check(
|
| 223 |
+
abs_ord in (2, 1, float("inf")),
|
| 224 |
+
lambda: "linalg.matrix_norm: Order {ord} not supported.",
|
| 225 |
+
)
|
| 226 |
+
# dtype
|
| 227 |
+
check_fp_or_complex(
|
| 228 |
+
A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != 2
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim)
|
| 232 |
+
|
| 233 |
+
if abs_ord == 2.0:
|
| 234 |
+
if dtype is not None:
|
| 235 |
+
A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment]
|
| 236 |
+
perm = _backshift_permutation(dim[0], dim[1], A.ndim)
|
| 237 |
+
result = max_min(svdvals(prims.transpose(A, perm)), dim=-1)
|
| 238 |
+
if keepdim:
|
| 239 |
+
inv_perm = _inverse_permutation(perm)
|
| 240 |
+
result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
|
| 241 |
+
return result
|
| 242 |
+
else: # 1, -1, inf, -inf
|
| 243 |
+
dim0, dim1 = dim
|
| 244 |
+
if abs_ord == float("inf"):
|
| 245 |
+
dim0, dim1 = dim1, dim0
|
| 246 |
+
if not keepdim and (dim0 < dim1):
|
| 247 |
+
dim1 -= 1
|
| 248 |
+
return max_min(
|
| 249 |
+
vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
# CompositeImplicitAutograd
|
| 254 |
+
@out_wrapper(exact_dtype=True)
|
| 255 |
+
def norm(
|
| 256 |
+
A: TensorLikeType,
|
| 257 |
+
ord: Optional[Union[float, str]] = None,
|
| 258 |
+
dim: Optional[DimsType] = None,
|
| 259 |
+
keepdim: bool = False,
|
| 260 |
+
*,
|
| 261 |
+
dtype: Optional[torch.dtype] = None,
|
| 262 |
+
) -> TensorLikeType:
|
| 263 |
+
if dim is not None:
|
| 264 |
+
if isinstance(dim, Dim):
|
| 265 |
+
dim = (dim,) # type: ignore[assignment]
|
| 266 |
+
torch._check(
|
| 267 |
+
len(dim) in (1, 2),
|
| 268 |
+
lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
|
| 269 |
+
)
|
| 270 |
+
elif ord is not None:
|
| 271 |
+
torch._check(
|
| 272 |
+
A.ndim in (1, 2),
|
| 273 |
+
lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
if ord is not None and (
|
| 277 |
+
(dim is not None and len(dim) == 2) or (dim is None and A.ndim == 2)
|
| 278 |
+
):
|
| 279 |
+
if dim is None:
|
| 280 |
+
dim = (0, 1)
|
| 281 |
+
return matrix_norm(A, ord, dim, keepdim, dtype=dtype)
|
| 282 |
+
else:
|
| 283 |
+
if ord is None:
|
| 284 |
+
ord = 2.0
|
| 285 |
+
return vector_norm(A, ord, dim, keepdim, dtype=dtype)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# CompositeImplicitAutograd
|
| 289 |
+
@out_wrapper("U", "S", "Vh", exact_dtype=True)
|
| 290 |
+
def svd(A: TensorLikeType, full_matrices: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
|
| 291 |
+
return prims.svd(A, full_matrices=full_matrices)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# CompositeImplicitAutograd
|
| 295 |
+
@out_wrapper(exact_dtype=True)
|
| 296 |
+
def svdvals(A: TensorLikeType) -> Tensor:
|
| 297 |
+
return svd(A, full_matrices=False)[1]
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
# CompositeImplicitAutograd
|
| 301 |
+
@out_wrapper()
|
| 302 |
+
@elementwise_type_promotion_wrapper(
|
| 303 |
+
type_promoting_args=("x", "y"),
|
| 304 |
+
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
| 305 |
+
)
|
| 306 |
+
def vecdot(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
|
| 307 |
+
check_fp_or_complex(x.dtype, "linalg.vecdot")
|
| 308 |
+
return (x.conj() * y).sum(dim=dim)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/special/__init__.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch._prims as prims
|
| 6 |
+
import torch._prims_common as utils
|
| 7 |
+
import torch._refs as refs
|
| 8 |
+
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
from torch._decomp import register_decomposition
|
| 11 |
+
from torch._prims_common import (
|
| 12 |
+
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
| 13 |
+
Number,
|
| 14 |
+
NumberType,
|
| 15 |
+
TensorLike,
|
| 16 |
+
TensorLikeType,
|
| 17 |
+
)
|
| 18 |
+
from torch._prims_common.wrappers import elementwise_type_promotion_wrapper, out_wrapper
|
| 19 |
+
from torch._refs import (
|
| 20 |
+
_make_alias,
|
| 21 |
+
_make_elementwise_binary_reference,
|
| 22 |
+
_make_elementwise_unary_reference,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
__all__ = [
|
| 27 |
+
"bessel_j0",
|
| 28 |
+
"bessel_j1",
|
| 29 |
+
"entr",
|
| 30 |
+
"erfcx",
|
| 31 |
+
"expit",
|
| 32 |
+
"i0e",
|
| 33 |
+
"i1",
|
| 34 |
+
"i1e",
|
| 35 |
+
"log_ndtr",
|
| 36 |
+
"logit",
|
| 37 |
+
"log_softmax",
|
| 38 |
+
"multigammaln",
|
| 39 |
+
"ndtr",
|
| 40 |
+
"ndtri",
|
| 41 |
+
"softmax",
|
| 42 |
+
"spherical_bessel_j0",
|
| 43 |
+
"xlog1py",
|
| 44 |
+
"zeta",
|
| 45 |
+
]
|
| 46 |
+
aten = torch._ops.ops.aten
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@_make_elementwise_unary_reference(
|
| 50 |
+
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
| 51 |
+
)
|
| 52 |
+
def bessel_j0(a: TensorLikeType) -> TensorLikeType:
|
| 53 |
+
return prims.bessel_j0(a)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@_make_elementwise_unary_reference(
|
| 57 |
+
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
| 58 |
+
)
|
| 59 |
+
def bessel_j1(a: TensorLikeType) -> TensorLikeType:
|
| 60 |
+
return prims.bessel_j1(a)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@register_decomposition(aten.special_entr)
|
| 64 |
+
@out_wrapper()
|
| 65 |
+
@elementwise_type_promotion_wrapper(
|
| 66 |
+
type_promoting_args=("a",),
|
| 67 |
+
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
| 68 |
+
)
|
| 69 |
+
def entr(a: TensorLikeType) -> TensorLikeType:
|
| 70 |
+
return torch.where(
|
| 71 |
+
torch.isnan(a),
|
| 72 |
+
a,
|
| 73 |
+
torch.where(a > 0, -a * torch.log(a), torch.where(a == 0, 0, -torch.inf)),
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@register_decomposition(aten.special_erfcx)
|
| 78 |
+
@out_wrapper()
|
| 79 |
+
@elementwise_type_promotion_wrapper(
|
| 80 |
+
type_promoting_args=("a",),
|
| 81 |
+
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
| 82 |
+
)
|
| 83 |
+
def erfcx(a: TensorLikeType) -> TensorLikeType:
|
| 84 |
+
return prims.erfcx(a)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# alias for sigmoid
|
| 88 |
+
expit = _make_alias(torch.sigmoid, "expit")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@_make_elementwise_unary_reference(
|
| 92 |
+
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
| 93 |
+
)
|
| 94 |
+
def i0e(a: TensorLikeType) -> TensorLikeType:
|
| 95 |
+
return prims.bessel_i0e(a)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@_make_elementwise_unary_reference(
|
| 99 |
+
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
| 100 |
+
)
|
| 101 |
+
def i1(a: TensorLikeType) -> TensorLikeType:
|
| 102 |
+
return prims.bessel_i1(a)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@_make_elementwise_unary_reference(
|
| 106 |
+
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
| 107 |
+
)
|
| 108 |
+
def i1e(a: TensorLikeType) -> TensorLikeType:
|
| 109 |
+
return prims.bessel_i1e(a)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@register_decomposition(aten.special_log_ndtr)
|
| 113 |
+
@out_wrapper()
|
| 114 |
+
@elementwise_type_promotion_wrapper(
|
| 115 |
+
type_promoting_args=("a",),
|
| 116 |
+
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
| 117 |
+
)
|
| 118 |
+
def log_ndtr(a: TensorLikeType) -> TensorLikeType:
|
| 119 |
+
# Note: M_SQRT1_2 is the value of 1 / √2
|
| 120 |
+
M_SQRT1_2 = 0.707106781186547524400844362104849039
|
| 121 |
+
t = a * M_SQRT1_2
|
| 122 |
+
return torch.where(
|
| 123 |
+
a < 1.0,
|
| 124 |
+
torch.log(torch.special.erfcx(-t) / 2) - t * t,
|
| 125 |
+
torch.log1p(-torch.erfc(t) / 2),
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@register_decomposition(aten.logit)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
    """Log-odds: log(self / (1 - self)), with the input clamped to [eps, 1 - eps]."""
    # eps = -1.0 is a sentinel: clamp(self, -1.0, 2.0) leaves any value in
    # [0, 1] untouched, so "no eps" behaves like an unclamped logit.
    if eps is None:
        eps = -1.0
    clamped = torch.clamp(self, eps, 1 - eps)
    return torch.log(torch.true_divide(clamped, torch.sub(1, clamped)))
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@register_decomposition(aten.special_xlog1py)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
    """
    Compute a * log1p(b) elementwise, with two special-value conventions:
    the result is 0 wherever a == 0 (even if log1p(b) would be +/-inf),
    and NaN wherever b is NaN.
    """
    torch._check(
        isinstance(a, TensorLike) or isinstance(b, TensorLike),
        # Fix: the message previously ended with a stray double-quote
        # ('... to be a Tensor"').
        lambda: "Expected either argument a or b to be a Tensor",
    )

    # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
    if isinstance(a, TensorLike) and isinstance(b, Number):
        b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif isinstance(b, TensorLike) and isinstance(a, Number):
        a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device)

    # mypy: expected "Tensor"
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)
    # a == 0 forces the product to 0 regardless of b ...
    rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log1p(b)))
    # ... except that NaN in b always propagates.
    return torch.where(torch.isnan(b), float("nan"), rhs)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@register_decomposition(aten.mvlgamma)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType:
    """Log of the multivariate gamma function of order p, computed elementwise as sum of lgamma terms plus a pi-dependent constant."""
    # log of the pi^{p(p-1)/4} prefactor of the multivariate gamma function.
    c = 0.25 * p * (p - 1) * math.log(math.pi)
    # Offsets (1-p)/2, (2-p)/2, ..., 0 broadcast along a new trailing dim.
    b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device)
    return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
@register_decomposition(aten.special_ndtr)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def ndtr(a: TensorLikeType) -> TensorLikeType:
    """Standard-normal CDF: 0.5 * (1 + erf(a / sqrt(2)))."""
    # Note: M_SQRT1_2 is the value of 1 / √2, written out so no runtime sqrt
    # is needed.
    M_SQRT1_2 = 0.707106781186547524400844362104849039
    return (1 + torch.erf(a * M_SQRT1_2)) * 0.5
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
@register_decomposition(aten.special_ndtri)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def ndtri(a: TensorLikeType) -> TensorLikeType:
    """Reference for torch.special.ndtri (inverse of ndtr); delegates to the ndtri prim."""
    return prims.ndtri(a)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# Forwarding alias: the special variant doesn't support the out kwarg
# CompositeImplicitAutograd - don't register decomp
def log_softmax(
    a: TensorLikeType,
    dim: int,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """torch.special.log_softmax: thin forwarder to torch.log_softmax."""
    return torch.log_softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# Forwarding alias: the special variant doesn't support the out kwarg
# CompositeImplicitAutograd - don't register decomp
def softmax(
    a: TensorLikeType,
    dim: int,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """torch.special.softmax: thin forwarder to torch.softmax."""
    return torch.softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType:
    """Reference for torch.special.spherical_bessel_j0; delegates to the prim (integer inputs promote to float)."""
    return prims.spherical_bessel_j0(a)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
@_make_elementwise_binary_reference(
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def zeta(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for torch.special.zeta; delegates to the zeta prim (integer inputs promote to float)."""
    return prims.zeta(a, b)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/nvtx.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r"""This package adds support for NVIDIA Tools Extension (NVTX) used in profiling."""
|
| 2 |
+
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from torch._C import _nvtx
|
| 7 |
+
except ImportError:
|
| 8 |
+
|
| 9 |
+
class _NVTXStub:
|
| 10 |
+
@staticmethod
|
| 11 |
+
def _fail(*args, **kwargs):
|
| 12 |
+
raise RuntimeError(
|
| 13 |
+
"NVTX functions not installed. Are you sure you have a CUDA build?"
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
rangePushA = _fail
|
| 17 |
+
rangePop = _fail
|
| 18 |
+
markA = _fail
|
| 19 |
+
|
| 20 |
+
_nvtx = _NVTXStub() # type: ignore[assignment]
|
| 21 |
+
|
| 22 |
+
__all__ = ["range_push", "range_pop", "range_start", "range_end", "mark", "range"]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def range_push(msg):
    """
    Push a range onto the stack of nested range spans.

    Args:
        msg (str): ASCII message to associate with range

    Returns:
        Zero-based depth of the range that is started.
    """
    depth = _nvtx.rangePushA(msg)
    return depth
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def range_pop():
    """
    Pop a range off of the stack of nested range spans.

    Returns:
        Zero-based depth of the range that is ended.
    """
    depth = _nvtx.rangePop()
    return depth
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def range_start(msg) -> int:
    """
    Mark the start of a range with string message. It returns a unique handle
    for this range to pass to the corresponding call to range_end().

    A key difference between this and range_push/range_pop is that the
    range_start/range_end version supports range across threads (start on one
    thread and end on another thread).

    Returns: A range handle (uint64_t) that can be passed to range_end().

    Args:
        msg (str): ASCII message to associate with the range.
    """
    return _nvtx.rangeStartA(msg)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def range_end(range_id) -> None:
    """
    Mark the end of a range for a given range_id.

    Args:
        range_id (int): the unique handle returned by range_start().
    """
    _nvtx.rangeEnd(range_id)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def mark(msg):
    """
    Describe an instantaneous event that occurred at some point.

    Args:
        msg (str): ASCII message to associate with the event.
    """
    result = _nvtx.markA(msg)
    return result
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@contextmanager
def range(msg, *args, **kwargs):
    """
    Context manager / decorator that pushes an NVTX range at the beginning
    of its scope, and pops it at the end. If extra arguments are given,
    they are passed as arguments to msg.format().

    Args:
        msg (str): message to associate with the range
    """
    formatted = msg.format(*args, **kwargs)
    range_push(formatted)
    try:
        yield
    finally:
        # Always close the range, even if the wrapped body raised.
        range_pop()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_pytree.cpython-311.pyc
ADDED
|
Binary file (5.88 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph_module.cpython-311.pyc
ADDED
|
Binary file (41.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/immutable_collections.cpython-311.pyc
ADDED
|
Binary file (4.59 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/proxy.cpython-311.pyc
ADDED
|
Binary file (32.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/tensor_type.cpython-311.pyc
ADDED
|
Binary file (5.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/util.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \
|
| 2 |
+
BVar
|
| 3 |
+
from torch.fx.experimental.migrate_gradual_types.operation import op_leq
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def gen_tvar(curr):
    """
    Generate a fresh tensor variable.

    :param curr: The current counter
    :return: a tensor variable and the updated counter
    """
    next_id = curr + 1
    return TVar(next_id), next_id
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def gen_dvar(curr):
    """
    Generate a fresh dimension variable.

    :param curr: the current counter
    :return: a dimension variable and an updated counter
    """
    next_id = curr + 1
    return DVar(next_id), next_id
|
| 24 |
+
|
| 25 |
+
def gen_bvar(curr):
    """
    Generate a fresh boolean variable.

    :param curr: the current counter
    :return: a boolean variable and an updated counter
    """
    next_id = curr + 1
    return BVar(next_id), next_id
|
| 33 |
+
|
| 34 |
+
def gen_tensor_dims(n, curr):
    """
    Generate a list of fresh tensor dimension variables.

    :param n: the number of dimensions
    :param curr: the current counter
    :return: a list of dimension variables and an updated counter
    """
    new_dims = []
    remaining = n
    while remaining > 0:
        dim_var, curr = gen_dvar(curr)
        new_dims.append(dim_var)
        remaining -= 1
    return new_dims, curr
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def gen_nat_constraints(list_of_dims):
    """
    Generate non-negativity ("natural number") constraints, 0 <= d, for each
    dimension variable in the given list.
    """
    constraints = []
    for dim in list_of_dims:
        constraints.append(BinConstraintD(0, dim, op_leq))
    return constraints
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-311.pyc
ADDED
|
Binary file (5.28 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-311.pyc
ADDED
|
Binary file (4.45 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (781 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-311.pyc
ADDED
|
Binary file (2.32 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-311.pyc
ADDED
|
Binary file (21.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-311.pyc
ADDED
|
Binary file (5.94 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-311.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-311.pyc
ADDED
|
Binary file (9.07 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_module.cpython-311.pyc
ADDED
|
Binary file (25 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-311.pyc
ADDED
|
Binary file (3.27 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-311.pyc
ADDED
|
Binary file (6.11 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from . import pass_manager
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-311.pyc
ADDED
|
Binary file (17.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-311.pyc
ADDED
|
Binary file (14.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_base.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from collections import namedtuple
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
from torch.fx.graph_module import GraphModule
|
| 6 |
+
from torch.fx._compatibility import compatibility
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
__all__ = ['PassResult', 'PassBase']
|
| 10 |
+
|
| 11 |
+
@compatibility(is_backward_compatible=False)
class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
    """
    Result of a pass:
        graph_module: The modified graph module
        modified: A flag for if the pass has modified the graph module
    """
    # __new__ is redefined with the same fields so the class exposes an
    # explicit, documented constructor signature instead of the generic
    # namedtuple one.
    def __new__(cls, graph_module, modified):
        return super().__new__(cls, graph_module, modified)
|
| 20 |
+
|
| 21 |
+
@compatibility(is_backward_compatible=False)
class PassBase(abc.ABC):
    """
    Base interface for implementing passes.

    It is required to implement the `call` function so that we can directly
    pass instances of the Pass directly to the PassManager and call them as a
    function.

    We can directly pass an instance of a class implementing this interface into
    the PassManager's `passes` attribute.
    """

    def __call__(self, graph_module: GraphModule) -> Optional[PassResult]:
        """
        Runs the precondition check, the pass itself, and the postcondition check.
        """

        # requires()/ensures() are no-ops unless a subclass overrides them.
        self.requires(graph_module)
        res = self.call(graph_module)
        self.ensures(graph_module)
        return res

    @abc.abstractmethod
    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
        """
        The pass that is run through the given graph module. To implement a
        pass, it is required to implement this function.

        Args:
            graph_module: The graph module we will run a pass on
        """
        pass

    def requires(self, graph_module: GraphModule) -> None:  # noqa: B027
        """
        This function will be called before the pass is run and will check that
        the given graph module contains the preconditions needed to run the
        pass. It is not required to implement this function.

        Args:
            graph_module: The graph module we will run checks on
        """
        pass

    def ensures(self, graph_module: GraphModule) -> None:  # noqa: B027
        """
        This function will be called after the pass is run and will check that
        the given graph module contains the postconditions needed to run the
        pass. It is not required to implement this function.

        Args:
            graph_module: The graph module we will run checks on
        """
        pass
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_manager.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import logging
|
| 3 |
+
from queue import Queue
|
| 4 |
+
from functools import wraps
|
| 5 |
+
from typing import Callable, Dict, List
|
| 6 |
+
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.fx.graph_module import GraphModule
|
| 9 |
+
from torch.fx._compatibility import compatibility
|
| 10 |
+
from torch.fx.passes.infra.pass_base import PassResult
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
logger.setLevel(logging.WARNING)
|
| 14 |
+
|
| 15 |
+
__all__ = ['pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager']
|
| 16 |
+
|
| 17 |
+
@compatibility(is_backward_compatible=False)
def pass_result_wrapper(fn: Callable) -> Callable:
    """
    Wrapper for passes which currently do not return a PassResult.
    This wrapper makes them return a PassResult containing the modified object
    and True for the "modified" flag.

    Args:
        fn (Callable[Module, Any])

    Returns:
        wrapped_fn (Callable[Module, PassResult])
    """
    if fn is None:
        return None

    @wraps(fn)
    def wrapped_fn(gm):
        res = fn(gm)
        if res is None:
            return PassResult(gm, True)
        if isinstance(res, PassResult):
            return res
        elif isinstance(res, nn.Module):
            return PassResult(res, True)
        # NOTE(review): any other result type falls through and implicitly
        # returns None — presumably unreachable for supported passes; confirm.

    if not inspect.isfunction(fn):
        # Callable pass objects don't get a useful __name__ from @wraps, so
        # report their class name instead.
        wrapped_fn.__name__ = type(fn).__name__

    return wrapped_fn
|
| 47 |
+
|
| 48 |
+
def _validate_pass_schedule_constraint(
|
| 49 |
+
constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
|
| 50 |
+
) -> None:
|
| 51 |
+
for i, a in enumerate(passes):
|
| 52 |
+
for j, b in enumerate(passes[i + 1 :]):
|
| 53 |
+
if constraint(a, b):
|
| 54 |
+
continue
|
| 55 |
+
raise RuntimeError(
|
| 56 |
+
f"pass schedule constraint violated. Expected {a} before {b}"
|
| 57 |
+
f" but found {a} at index {i} and {b} at index{j} in pass"
|
| 58 |
+
f" list."
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
def _topological_sort_passes(
    passes: List[Callable], constraints: List[Callable]
) -> List[Callable]:
    """
    Topologically sort the passes so every constraint is satisfied
    (Kahn's algorithm).

    Args
        passes: Passes that we are ordering
        constraints: Constraints applied on these passes

    Returns
        A sorted list of callables

    Raises
        RuntimeError: if the constraints imply a circular dependency.
    """
    if len(constraints) == 0:
        return passes

    # Construct a graph mapping nodes to a list of their users
    graph: Dict[Callable, List[Callable]] = {p : [] for p in passes}
    indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
    candidates: Queue = Queue()
    for a in passes:
        for b in passes:
            if a == b:
                continue

            # constraint(a, b) == False means b must run before a: add an
            # edge b -> a and bump a's in-degree.
            for constraint in constraints:
                if not constraint(a, b):
                    graph[b].append(a)
                    indegree_map[a] += 1

        # Nodes with no prerequisites seed the worklist.
        if indegree_map[a] == 0:
            candidates.put(a)

    visited: Dict[Callable, bool] = dict.fromkeys(passes, False)
    sorted_passes: List[Callable] = []

    while not candidates.empty():
        p = candidates.get()
        sorted_passes.append(p)
        visited[p] = True

        for n in graph[p]:
            if not visited[n]:
                indegree_map[n] -= 1
                if indegree_map[n] == 0:
                    candidates.put(n)

    # Check if there are unvisited nodes (aka cycles in the graph)
    cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys()))
    if len(cycle_passes) != 0:
        error = f"Circular dependency detected within the following passes: {cycle_passes}"
        raise RuntimeError(error)

    return sorted_passes
|
| 114 |
+
|
| 115 |
+
@compatibility(is_backward_compatible=False)
def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable:
    """
    Defines a partial order ('depends on' function) where `this` must occur
    before `that`.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [pass_b, pass_a]

    constraints = [
        this_before_that_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        this (Callable): pass which should occur first
        that (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool]
    """

    def depends_on(a: Callable, b: Callable):
        # Only the reversed ordering (`that` scheduled before `this`) is
        # rejected; every other pair is acceptable.
        return not (a == that and b == this)

    return depends_on
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@compatibility(is_backward_compatible=False)
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): List of passes. A pass is a
            callable which modifies an object and returns a PassResult
        constraint (Optional[List[Callable]]): List of constraints. A
            constraint is a callable which takes two passes (A, B) and returns
            True if A depends on B and False otherwise. See implementation of
            `this_before_that_pass_constraint` for example.
        steps (int): Max number of times we run the passes (default = 1).
        run_checks_after_each_pass (bool): Whether to run checks and linting
            after each pass
        suppress_check_failures (bool): Whether to raise errors when running
            checks
    """

    passes: List[Callable[[nn.Module], PassResult]]
    constraints: List[Callable[[Callable, Callable], bool]]
    _validated: bool = False
    steps: int = 1

    def __init__(
        self,
        passes=None,
        constraints=None,
        steps=None,
        run_checks_after_each_pass: bool = False,
        suppress_check_failures: bool = False,
    ):
        self.passes = passes or []
        self.constraints = constraints or []
        # NOTE: a falsy steps (None or 0) keeps the class default of 1.
        if steps:
            self.steps = steps

        self.run_checks_after_each_pass = run_checks_after_each_pass
        self.suppress_check_failures = suppress_check_failures

    def add_pass(self, _pass: Callable):
        """
        Adds a pass into the current list of passes.
        """
        self.passes.append(_pass)
        # Schedule must be re-validated/re-sorted after any mutation.
        self._validated = False

    def add_constraint(self, constraint: Callable):
        """
        Adds a constraint into the current list of constraints.
        """
        self.constraints.append(constraint)
        self._validated = False

    def validate_constraints(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def solve_constraints(self):
        """
        Finds a valid traversal order based on the given constraints and orders
        the passes based on this order.

        If a circular dependency exists between the constraints and steps = 1,
        then we will raise an error because if steps != 1 this means that we
        will re-run the passes, allowing for circular dependencies.
        """
        self.passes = _topological_sort_passes(self.passes, self.constraints)
        self._validated = True

    def add_checks(self, check: Callable) -> None:
        """
        Adds a function which runs various checks on a given graph module.
        This function is run before and after each pass if the
        `run_checks_after_each_pass` flag is enabled.
        """
        sig = inspect.signature(check)

        if len(list(sig.parameters.values())) != 1:
            raise TypeError("PassManager check function should only take in one variable, a module")

        # Replaces the default no-op `check` on this instance.
        setattr(self, "check", check)  # noqa: B010

    def check(self, module: nn.Module) -> None:
        # Default check is a no-op; replaced via add_checks().
        pass

    def __call__(self, module: nn.Module) -> PassResult:
        """
        Runs a list of passes in the order based on `self.passes` on the given
        graph module. Each time a pass is run, checks and linting will be run on
        the graph module if `run_checks_after_each_pass` is set.

        If the module is a graph module, we will run the list of passes until
        the graph stops changing, or until `steps` number of times.
        """
        # Order the passes based on the constraints
        if not self._validated:
            self.solve_constraints()

        # Check graph invariants
        self.check(module)

        # Run the set of passes `steps` number of times or until the graph stops
        # changing
        overall_modified = False
        for _ in range(self.steps):
            modified = False

            # Run the set of passes on the graph module
            for i, fn in enumerate(self.passes):
                fn_name = fn.__name__ if inspect.isfunction(fn) else type(fn).__name__
                logger.debug("Running pass '%s'", fn_name)

                try:
                    res = fn(module)

                    if not isinstance(res, PassResult) and not hasattr(
                        res, "graph_module"
                    ):
                        raise TypeError(
                            f"The result of the pass {fn_name} should be type PassResult."
                            + "Please wrap it with pass_result_wrapper()"
                        )
                    module = res.graph_module
                    modified = modified or res.modified

                    if isinstance(module, GraphModule):
                        logger.debug("Graph after pass '%s': %s", fn_name, module.graph)
                        module.recompile()

                    # Check graph invariants
                    if self.run_checks_after_each_pass:
                        self.check(module)

                except Exception as e:
                    # Re-raise with the list of passes that already ran, to
                    # make the failing pass easier to locate.
                    prev_pass_names = [
                        p.__name__ if inspect.isfunction(p) else type(p).__name__
                        for p in self.passes[:i]
                    ]
                    msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}"
                    raise Exception(msg) from e

            # If the graph no longer changes, then we can stop running these passes
            overall_modified = overall_modified or modified
            if not modified:
                break

        return PassResult(module, overall_modified)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/reinplace.py
ADDED
|
@@ -0,0 +1,675 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.fx import Node
|
| 3 |
+
from torch.fx._compatibility import compatibility
|
| 4 |
+
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
|
| 5 |
+
from torch.utils._pytree import tree_map_only
|
| 6 |
+
from torch.utils import _pytree as pytree
|
| 7 |
+
from torch.multiprocessing.reductions import StorageWeakRef
|
| 8 |
+
|
| 9 |
+
import _operator
|
| 10 |
+
from enum import Enum
|
| 11 |
+
import itertools
|
| 12 |
+
from typing import Set, Dict
|
| 13 |
+
from collections import defaultdict
|
| 14 |
+
|
| 15 |
+
__all__ = ['reinplace']
|
| 16 |
+
|
| 17 |
+
class _ViewType(Enum):
|
| 18 |
+
NonView = 0
|
| 19 |
+
SingleOutputView = 1
|
| 20 |
+
MultiOutputView = 2
|
| 21 |
+
|
| 22 |
+
def _is_view_op(tgt):
|
| 23 |
+
if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
|
| 24 |
+
schema = tgt._schema
|
| 25 |
+
if len(schema.arguments) > 0:
|
| 26 |
+
first_arg = schema.arguments[0]
|
| 27 |
+
# check if op is a view
|
| 28 |
+
return first_arg.alias_info is not None and not first_arg.alias_info.is_write
|
| 29 |
+
|
| 30 |
+
def _get_view_type(tgt) -> _ViewType:
|
| 31 |
+
if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
|
| 32 |
+
schema = tgt._schema
|
| 33 |
+
if len(schema.arguments) > 0:
|
| 34 |
+
first_arg = schema.arguments[0]
|
| 35 |
+
# check if op is a view
|
| 36 |
+
if first_arg.alias_info is not None and not first_arg.alias_info.is_write:
|
| 37 |
+
# check if op is a multi-output view
|
| 38 |
+
if '*' in first_arg.alias_info.after_set:
|
| 39 |
+
return _ViewType.MultiOutputView
|
| 40 |
+
else:
|
| 41 |
+
return _ViewType.SingleOutputView
|
| 42 |
+
return _ViewType.NonView
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Stores a bunch of metadata related to functionalization each node.
|
| 46 |
+
# Relevant metadata:
|
| 47 |
+
# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTenors instead of Tensors)
|
| 48 |
+
# The fake tensor output from running the current node
|
| 49 |
+
# n.meta['view_of']: Node
|
| 50 |
+
# If the current node n is a view of some base tensor, the 'view_of' field tells us which
|
| 51 |
+
# view node was used to generate the current node (a view tensor).
|
| 52 |
+
# This information actually makes `fake_result` redundant, but we can use `fake_result`
|
| 53 |
+
# to sanity check that our aliasing information is correct.
|
| 54 |
+
@compatibility(is_backward_compatible=False)
|
| 55 |
+
class _FunctionalizationMetadataProp(torch.fx.Interpreter):
|
| 56 |
+
|
| 57 |
+
def run_node(self, node: Node):
|
| 58 |
+
self.node_counter += 1
|
| 59 |
+
result = super().run_node(node)
|
| 60 |
+
node.meta['fake_result'] = result
|
| 61 |
+
node.meta['node_idx'] = self.node_counter
|
| 62 |
+
|
| 63 |
+
# (1) Update metadata with the list of nodes that are used by this node
|
| 64 |
+
# copy_() doesn't read from its first argument; it writes to it, overwriting previous data.
|
| 65 |
+
# We don't want to treat it as "being used as an input".
|
| 66 |
+
node_args = node.args
|
| 67 |
+
if node.target is torch.ops.aten.copy_.default:
|
| 68 |
+
node_args = node_args[1:]
|
| 69 |
+
|
| 70 |
+
# (2) Update metadata to track aliasing information about view tensor nodes.
|
| 71 |
+
if node.op == 'call_function':
|
| 72 |
+
view_type = _get_view_type(node.target)
|
| 73 |
+
if view_type == _ViewType.SingleOutputView:
|
| 74 |
+
assert isinstance(node.args[0], Node)
|
| 75 |
+
node.meta['view_of'] = node.args[0]
|
| 76 |
+
elif view_type == _ViewType.MultiOutputView:
|
| 77 |
+
self.multi_output_view_nodes[node] = node.args[0]
|
| 78 |
+
|
| 79 |
+
# Check if we returned a multi-output view,
|
| 80 |
+
# and we're now grabbing the individual views from the output.
|
| 81 |
+
#
|
| 82 |
+
# For multi-output views, we want to map each output view to the base,
|
| 83 |
+
# but this mapping involves two separate nodes in FX IR.
|
| 84 |
+
# e.g. "a, b = x_1.split(...)" becomes:
|
| 85 |
+
# %split_tensor : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {})
|
| 86 |
+
# %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {})
|
| 87 |
+
# %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {})
|
| 88 |
+
# And we'd like to set:
|
| 89 |
+
# getitem1.meta['view_of'] = x_1
|
| 90 |
+
elif node.target is _operator.getitem:
|
| 91 |
+
list_arg = node.args[0]
|
| 92 |
+
maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None)
|
| 93 |
+
if maybe_base_of_view is not None:
|
| 94 |
+
# Note: we could also track indexing info here for multi-output views.
|
| 95 |
+
# I don't think this metadata is strictly needed for de-functionalization.
|
| 96 |
+
assert isinstance(maybe_base_of_view, Node)
|
| 97 |
+
node.meta['view_of'] = maybe_base_of_view
|
| 98 |
+
|
| 99 |
+
if 'view_of' in node.meta:
|
| 100 |
+
# We're linking the current node with its first argument as views.
|
| 101 |
+
# Assert here that this is actually the case, and their storages are the same.
|
| 102 |
+
assert isinstance(node.meta['fake_result'], FakeTensor)
|
| 103 |
+
assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor)
|
| 104 |
+
view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
|
| 105 |
+
base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage())
|
| 106 |
+
assert view_storage == base_storage
|
| 107 |
+
return result
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def propagate(self, *args):
|
| 112 |
+
self.multi_output_view_nodes = {}
|
| 113 |
+
self.node_counter = -1
|
| 114 |
+
|
| 115 |
+
with FakeTensorMode() as mode:
|
| 116 |
+
fake_args = [mode.from_tensor(a) for a in args]
|
| 117 |
+
return super().run(*fake_args)
|
| 118 |
+
|
| 119 |
+
def _schemas_match(functional_schema, inplace_schema):
|
| 120 |
+
names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name
|
| 121 |
+
arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all(
|
| 122 |
+
a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments))
|
| 123 |
+
# for the inplace op, its first argument should be mutable
|
| 124 |
+
assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write
|
| 125 |
+
# and its remaining arguments shouldn't be.
|
| 126 |
+
assert all(a.alias_info is None for a in inplace_schema.arguments[1:])
|
| 127 |
+
return names_match and arg_types_match
|
| 128 |
+
|
| 129 |
+
# TODO: this should be beefed up to be able to properly re-inplace with:
|
| 130 |
+
# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper)
|
| 131 |
+
# - out= ops (e.g. angle -> angle.out)
|
| 132 |
+
# TODO: we should also figure this info out using torchgen.
|
| 133 |
+
def _maybe_get_inplace_op(op):
|
| 134 |
+
# __module__ seems broken; it returns torch._ops.aten which doesn't exist
|
| 135 |
+
if not isinstance(op, torch._ops.OpOverload):
|
| 136 |
+
return None
|
| 137 |
+
# Some view ops have inplace variants (as_strided_, etc),
|
| 138 |
+
# but we do NOT want the reinplacing pass to directly add these into the program.
|
| 139 |
+
# (they'll require extra special handling, aren't aren't really useful for perf anyway)
|
| 140 |
+
if _is_view_op(op):
|
| 141 |
+
return None
|
| 142 |
+
op_namespace = op.__module__.split(".")[-1]
|
| 143 |
+
op_base_name = op.overloadpacket.__name__
|
| 144 |
+
maybe_namespace_module = getattr(torch.ops, op_namespace)
|
| 145 |
+
maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None)
|
| 146 |
+
if maybe_inplace_op is None:
|
| 147 |
+
return None
|
| 148 |
+
|
| 149 |
+
inplace_overloads = [
|
| 150 |
+
getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads()
|
| 151 |
+
]
|
| 152 |
+
inplace_overloads_with_matching_schemas = [
|
| 153 |
+
f
|
| 154 |
+
for f in inplace_overloads
|
| 155 |
+
if _schemas_match(op._schema, f._schema)
|
| 156 |
+
]
|
| 157 |
+
# Just because foo() and foo_() are both existing operators,
|
| 158 |
+
# They aren't guaranteed to have compatible schemas.
|
| 159 |
+
# For example, pow.Scalar(Scalar self, Tensor exponent) has no valid inplace variant,
|
| 160 |
+
# Even though several overloads of pow_ exist.
|
| 161 |
+
if len(inplace_overloads_with_matching_schemas) == 0:
|
| 162 |
+
return None
|
| 163 |
+
assert len(inplace_overloads_with_matching_schemas) == 1
|
| 164 |
+
inplace_op = inplace_overloads_with_matching_schemas[0]
|
| 165 |
+
return inplace_op
|
| 166 |
+
|
| 167 |
+
_VIEW_INVERSE_MAP = {
|
| 168 |
+
torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
|
| 169 |
+
torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
|
| 170 |
+
torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
|
| 171 |
+
torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
# This function, given a set of set of (aliased) tensor nodes,
|
| 175 |
+
# Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index
|
| 176 |
+
# in the node ordering.
|
| 177 |
+
def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
|
| 178 |
+
def _add_if_tensor(x, set_):
|
| 179 |
+
if isinstance(x, FakeTensor):
|
| 180 |
+
set_.add(StorageWeakRef(x._typed_storage()))
|
| 181 |
+
|
| 182 |
+
nodes_used_after = set()
|
| 183 |
+
for t in tensor_aliases:
|
| 184 |
+
# get all nodes that use the current alias
|
| 185 |
+
usage_nodes = t.users
|
| 186 |
+
for n in usage_nodes:
|
| 187 |
+
# We only care about usages after the current node
|
| 188 |
+
if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index:
|
| 189 |
+
continue
|
| 190 |
+
# We also don't care about intermediate view ops.
|
| 191 |
+
# They only matter if their output is then used elsewhere
|
| 192 |
+
# (either in an out-of-place op, or as an output to the function).
|
| 193 |
+
if n in tensor_aliases:
|
| 194 |
+
if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem:
|
| 195 |
+
continue
|
| 196 |
+
nodes_used_after.add(n)
|
| 197 |
+
return nodes_used_after
|
| 198 |
+
|
| 199 |
+
# Given an op that we're trying to re-inplace, "b = foo(a)",
|
| 200 |
+
# And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)"
|
| 201 |
+
# Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF:
|
| 202 |
+
# If there are any aliases in the alias_set(a) that satisfy:
|
| 203 |
+
# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base"
|
| 204 |
+
# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata
|
| 205 |
+
# as "alias"
|
| 206 |
+
def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]:
|
| 207 |
+
def matching_view_metadata(a, b):
|
| 208 |
+
return a.size() == b.size() and \
|
| 209 |
+
a.stride() == b.stride() and \
|
| 210 |
+
a.storage_offset() == b.storage_offset()
|
| 211 |
+
|
| 212 |
+
view_inverse_nodes = set()
|
| 213 |
+
# Go through them in node order, so we can see chains of view_scatter ops.
|
| 214 |
+
for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']):
|
| 215 |
+
if n.target not in _VIEW_INVERSE_MAP:
|
| 216 |
+
continue
|
| 217 |
+
base = n.args[0]
|
| 218 |
+
mutated_view = n.args[1]
|
| 219 |
+
assert isinstance(base, Node)
|
| 220 |
+
assert isinstance(base.meta['fake_result'], FakeTensor)
|
| 221 |
+
assert isinstance(mutated_view, Node)
|
| 222 |
+
assert isinstance(mutated_view.meta['fake_result'], FakeTensor)
|
| 223 |
+
# Check that this view_inverse op actually corresponds to taking doing the inverse
|
| 224 |
+
# of one of our existing self_alias nodes.
|
| 225 |
+
original_view = _VIEW_INVERSE_MAP[n.target]
|
| 226 |
+
for self_alias in self_aliases:
|
| 227 |
+
# We're looking for some alias of the self arg, "alias",
|
| 228 |
+
# that was created from some op `alias = foo(base, args...)`
|
| 229 |
+
# such that the current _scatter op "inverts" that foo call.
|
| 230 |
+
# We can check that by running the original op again, and checking that the strides match.
|
| 231 |
+
if 'view_of' not in self_alias.meta:
|
| 232 |
+
continue
|
| 233 |
+
self_alias_base = self_alias.meta['view_of']
|
| 234 |
+
try:
|
| 235 |
+
# The we're trying to re-use the args from the view_scatter call inside of the corresponding
|
| 236 |
+
# view op, which might throw. This just indicates that view_scatter op isn't a valid inverse
|
| 237 |
+
# of the current alias we're looking at.
|
| 238 |
+
view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs)
|
| 239 |
+
expected_metadata = self_alias.meta['fake_result']
|
| 240 |
+
# If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace.
|
| 241 |
+
if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \
|
| 242 |
+
matching_view_metadata(view_replay_metadata, expected_metadata):
|
| 243 |
+
view_inverse_nodes.add(n)
|
| 244 |
+
except Exception:
|
| 245 |
+
continue
|
| 246 |
+
|
| 247 |
+
return view_inverse_nodes
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
@compatibility(is_backward_compatible=True)
|
| 251 |
+
def reinplace(gm, *sample_args):
|
| 252 |
+
"""
|
| 253 |
+
Given an fx.GraphModule, modifies it to perform "reinplacing",
|
| 254 |
+
mutating the nodes of the graph.
|
| 255 |
+
We look for out-of-place op call sites like `b = a.add(...)`,
|
| 256 |
+
and convert them to be inplace (`b = a.add_(...)`),
|
| 257 |
+
as long as the input to the current operator ("a") isn't re-used
|
| 258 |
+
anywhere later in the graph.
|
| 259 |
+
|
| 260 |
+
This pass currently expects to operate on a **functional, ATen** graph.
|
| 261 |
+
This can be obtained by running `make_fx(functionalize(f))`.
|
| 262 |
+
|
| 263 |
+
Sample inputs are needed to determine aliasing relationships of the inputs.
|
| 264 |
+
In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the
|
| 265 |
+
inputs to the program.
|
| 266 |
+
|
| 267 |
+
Given a node "b = foo(a, args...) the algorithm for re-inplacing is as follows:
|
| 268 |
+
|
| 269 |
+
(1) Perform some initial checks on the metadata of "a" and "args..."
|
| 270 |
+
that can disqualify them from being reinplaced.
|
| 271 |
+
|
| 272 |
+
(1a) Check that the self argument we're attempting to reinplace
|
| 273 |
+
has acceptable dtype/size metadata to reinplace with.
|
| 274 |
+
|
| 275 |
+
For example, if we have:
|
| 276 |
+
a = torch.ones(1)
|
| 277 |
+
b = torch.ones(10)
|
| 278 |
+
out = torch.add(a, b)
|
| 279 |
+
We can't turn that into
|
| 280 |
+
a.add_(b)
|
| 281 |
+
Because that would require resizing "a".
|
| 282 |
+
|
| 283 |
+
Similarly, we can't convert torch.ge(a, b) into a.ge_(b),
|
| 284 |
+
because that would require changing a's dtype (from e.g. float32 to bool).
|
| 285 |
+
Note that in this specific example, we could technically do better..
|
| 286 |
+
|
| 287 |
+
If we see the pattern:
|
| 288 |
+
a_1 = a.ge(b)
|
| 289 |
+
a_2 = aten._to_copy(a_1, a.dtype)
|
| 290 |
+
Then we this should be valid to completely re-inplace
|
| 291 |
+
(this is exactly what functionalization will emit when it sees a.ge_(b)).
|
| 292 |
+
|
| 293 |
+
This optimization is only really important for user programs
|
| 294 |
+
that directly use inplace comparison ops though.
|
| 295 |
+
|
| 296 |
+
We also cannot re-inplace on tensors that have overlapping memory,
|
| 297 |
+
e.g. torch.ones(1).expand(4, 4).add_(1)
|
| 298 |
+
|
| 299 |
+
(1b) Check if "a" is an alias of any of the program inputs.
|
| 300 |
+
|
| 301 |
+
If it is, skip and move to the next node.
|
| 302 |
+
Inplace'ing an op that would cause it to mutate a program is not sound,
|
| 303 |
+
because that would be a side effect visible to the user.
|
| 304 |
+
|
| 305 |
+
NOTE: there's a future optimization that we should make:
|
| 306 |
+
if "a" is a (alias of a) program input, but later in the program
|
| 307 |
+
there is a node that looks like "a.copy_(...)",
|
| 308 |
+
Then re-inplacing is ok to do - we are temporarily re-using a's buffer,
|
| 309 |
+
which will later be overwritten by the copy_() call.
|
| 310 |
+
|
| 311 |
+
This will be an important optimization to have for programs that mutate
|
| 312 |
+
their inputs. It currently isn't implemented though.
|
| 313 |
+
|
| 314 |
+
(1c) Check if "a" and "args..." alias
|
| 315 |
+
|
| 316 |
+
For example, re-inplacing to create code like the below
|
| 317 |
+
isn't guaranteed to be sound:
|
| 318 |
+
|
| 319 |
+
aten.mul_(a, a)
|
| 320 |
+
|
| 321 |
+
(2) Check that "a" and all of its outstanding aliases are not used anywhere
|
| 322 |
+
later in the graph. If this is the case, then it's safe to re-inplace
|
| 323 |
+
to "b = foo_(a)".
|
| 324 |
+
|
| 325 |
+
There are a few caveats to this, explained in more detail below:
|
| 326 |
+
(a) If "a" is used later as an argument to a view op, that is okay.
|
| 327 |
+
It's only a problem if "a" (or that view) is later passed
|
| 328 |
+
into a normal operator, or if it is returned as the program output.
|
| 329 |
+
(b) If "a" is a repeat argument in `foo()`, then don't reinplace.
|
| 330 |
+
Most ATen kernels don't make any guarantees that this is sound,
|
| 331 |
+
e.g. if you do aten.mul_(a, a).
|
| 332 |
+
So we'll just ban re-inplacing in this case.
|
| 333 |
+
It's only a problem if "a" (or that view) is later passed
|
| 334 |
+
(c) If "a" is used as an input into a view "inverse" / "scatter"
|
| 335 |
+
operator, it is potentially fine to re-inplace
|
| 336 |
+
(and remove that scatter operator from the graph).
|
| 337 |
+
See below for a more detailed example.
|
| 338 |
+
|
| 339 |
+
NOTE: there is an optimization in this step that is crucial
|
| 340 |
+
to fully recovering performance from functionalization.
|
| 341 |
+
|
| 342 |
+
Given this program:
|
| 343 |
+
def f(x):
|
| 344 |
+
a = torch.ops.aten.add(x, x)
|
| 345 |
+
b = torch.ops.aten.diagonal(a)
|
| 346 |
+
torch.ops.aten.fill_(b, 0)
|
| 347 |
+
return d
|
| 348 |
+
|
| 349 |
+
Functionalization will emit the following:
|
| 350 |
+
def f(x):
|
| 351 |
+
a = torch.ops.aten.add(x, x)
|
| 352 |
+
b = torch.ops.aten.diagonal(a, 0, 1)
|
| 353 |
+
b_updated = torch.ops.aten.fill(b, 0)
|
| 354 |
+
a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1)
|
| 355 |
+
return a_updated
|
| 356 |
+
|
| 357 |
+
Ordinarily, we would not be able to reinplace the fill,
|
| 358 |
+
because "b" aliases with "a" which is used by the diagonal_scatter call.
|
| 359 |
+
|
| 360 |
+
"re-inplacing" is on the hook for figuring out that it is ok to
|
| 361 |
+
completely, the expensive diagonal_scatter call, if we re-inplace the add().
|
| 362 |
+
|
| 363 |
+
So, for every `alias in alias_set(a)`, instead of checking
|
| 364 |
+
that "alias" is not used anywhere later in the graph,
|
| 365 |
+
we check that
|
| 366 |
+
EITHER:
|
| 367 |
+
(a) alias is not used anywhere later in the graph
|
| 368 |
+
OR:
|
| 369 |
+
(b) alias is used exactly once later on in the graph,
|
| 370 |
+
in the following op:
|
| 371 |
+
|
| 372 |
+
out = foo_scatter(alias, x, args...)
|
| 373 |
+
|
| 374 |
+
where the following must hold:
|
| 375 |
+
(i) "foo_scatter" is the "inverse" operator for foo.
|
| 376 |
+
This only applies to "foo" ops that are view operators,
|
| 377 |
+
which view into a subset of the original tensor's memory.
|
| 378 |
+
In practice, there are ~4 operators where this applies:
|
| 379 |
+
diagonal -> diagonal_scatter
|
| 380 |
+
slice -> slice_scatter
|
| 381 |
+
select -> select_scatter
|
| 382 |
+
as_strided -> as_strided_scatter
|
| 383 |
+
(ii) "args..." are the same between the foo() and foo_scatter() calls.
|
| 384 |
+
|
| 385 |
+
(3) Perform the actual re-inplacing on foo!
|
| 386 |
+
|
| 387 |
+
(3b) is the common case, but special care is needed for {view}_scatter (3a)
|
| 388 |
+
|
| 389 |
+
(3a) {view}_scatter ops.
|
| 390 |
+
|
| 391 |
+
Consider this program:
|
| 392 |
+
a = torch.zeros(2, 2)
|
| 393 |
+
b = torch.ones(2)
|
| 394 |
+
a[0] = b
|
| 395 |
+
|
| 396 |
+
Post functionalization, that will look like:
|
| 397 |
+
a = torch.zeros(2)
|
| 398 |
+
b = torch.ones(1)
|
| 399 |
+
a_updated = torch.select_scatter(a, b, 0, 0)
|
| 400 |
+
|
| 401 |
+
In this case though, there is no "functional" op to re-inplace!
|
| 402 |
+
Instead, we'd like to directly remove toe select_scatter call.
|
| 403 |
+
We already know from (3) that this is valid,
|
| 404 |
+
because "a" has no later usages in the graph.
|
| 405 |
+
|
| 406 |
+
We perform the re-inplacing on the {view}_scatter op like so
|
| 407 |
+
Before:
|
| 408 |
+
a_updated = torch.select_scatter(a, b, args...)
|
| 409 |
+
After:
|
| 410 |
+
a_slice = a.select(a, args...)
|
| 411 |
+
a_slice.copy_(b)
|
| 412 |
+
|
| 413 |
+
(3b) Otherwise, replace the functional op with its inplace variant.
|
| 414 |
+
Before:
|
| 415 |
+
b = foo(a, args...)
|
| 416 |
+
After:
|
| 417 |
+
a.foo_(args...)
|
| 418 |
+
|
| 419 |
+
(4) Finally, after converting either:
|
| 420 |
+
Before:
|
| 421 |
+
b = foo(a)
|
| 422 |
+
After:
|
| 423 |
+
foo_(a)
|
| 424 |
+
or
|
| 425 |
+
Before:
|
| 426 |
+
b = {slice}_scatter(a, mutated_slice, args...)
|
| 427 |
+
After:
|
| 428 |
+
slice = {slice}(a, args...)
|
| 429 |
+
slice.copy_(mutated_slice)
|
| 430 |
+
|
| 431 |
+
We now need to find all later nodes that use "b" as an argument
|
| 432 |
+
and update them to take in "a" instead.
|
| 433 |
+
|
| 434 |
+
Note that for the majority of inplace ops, this isn't actually necessary
|
| 435 |
+
(because most inplace ops return "self" as their output).
|
| 436 |
+
This isn't generally true for all mutable ops though, which is why
|
| 437 |
+
we need to actually replace all of the arguments.
|
| 438 |
+
|
| 439 |
+
We also need to update our metadata of Dict[StorageWeakRef, Set[Node]],
|
| 440 |
+
That maps a given tensor storage to the set of all nodes that take in that storage
|
| 441 |
+
as an input.
|
| 442 |
+
Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused
|
| 443 |
+
together.
|
| 444 |
+
|
| 445 |
+
(5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them"
|
| 446 |
+
during step (3) get manually deleted from the graph.
|
| 447 |
+
Their outputs are no longer used, so technically standard DCE would be able
|
| 448 |
+
to do this, but we can no longer run FX's DCE pass now that we have mutable
|
| 449 |
+
ops in the graph.
|
| 450 |
+
"""
|
| 451 |
+
_FunctionalizationMetadataProp(gm).propagate(*sample_args)
|
| 452 |
+
|
| 453 |
+
# Useful debug printing
|
| 454 |
+
# def _print(x):
|
| 455 |
+
# if isinstance(x, FakeTensor):
|
| 456 |
+
# print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}')
|
| 457 |
+
|
| 458 |
+
# for n in gm.graph.nodes:
|
| 459 |
+
# print(n.format_node())
|
| 460 |
+
# if hasattr(n, 'meta'):
|
| 461 |
+
# print(f'node_idx: {n.meta["node_idx"]}')
|
| 462 |
+
# if 'fake_result' in n.meta:
|
| 463 |
+
# tree_map(_print, n.meta['fake_result'])
|
| 464 |
+
# if 'view_of' in n.meta:
|
| 465 |
+
# print(f'view_of: {str(n.meta["view_of"])}')
|
| 466 |
+
# print()
|
| 467 |
+
|
| 468 |
+
# We need to know which nodes correspond to inputs (or their aliases)
|
| 469 |
+
# so we know not to re-inplace them.
|
| 470 |
+
# NOTE: later, we'll need to add an optimization for fully recovering performance
|
| 471 |
+
# on programs that mutate inputs.
|
| 472 |
+
input_storages = {
|
| 473 |
+
StorageWeakRef(
|
| 474 |
+
node.meta['fake_result']._typed_storage()
|
| 475 |
+
) for node in gm.graph.nodes if node.op == 'placeholder'}
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
# We also need to know for a given node, what are all of its aliasing nodes.
|
| 479 |
+
storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)
|
| 480 |
+
for n in gm.graph.nodes:
|
| 481 |
+
if 'fake_result' in n.meta:
|
| 482 |
+
# Tree-mapping because some ops can return lists of tensors.
|
| 483 |
+
def _add_to_map(x):
|
| 484 |
+
if isinstance(x, FakeTensor):
|
| 485 |
+
storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n)
|
| 486 |
+
pytree.tree_map_(_add_to_map, n.meta['fake_result'])
|
| 487 |
+
|
| 488 |
+
# inplace-ify functional ops, subject to the constraints written below.
|
| 489 |
+
all_later_view_inverse_nodes_to_delete = set()
|
| 490 |
+
for idx, node in enumerate(gm.graph.nodes):
|
| 491 |
+
if node.op == 'call_function':
|
| 492 |
+
|
| 493 |
+
# Today, the re-inplace pass on directly acts on:
|
| 494 |
+
# - functional ops with an inplace variant
|
| 495 |
+
# - {view}_scatter ops that can be potentially removed from the graph.
|
| 496 |
+
# Both of these ops take in tensor first args, so filtering on this condition
|
| 497 |
+
# makes the later code simpler.
|
| 498 |
+
# We should revisit this at some point though, particularly when we also want
|
| 499 |
+
# the reinplacer to be able to handle out= and mutable operators
|
| 500 |
+
# and tensorlist first args (like `_foreach_` ops).
|
| 501 |
+
if not isinstance(node.target, torch._ops.OpOverload):
|
| 502 |
+
continue
|
| 503 |
+
if len(node.target._schema.arguments) < 1:
|
| 504 |
+
continue
|
| 505 |
+
if type(node.target._schema.arguments[0].type) != torch.TensorType:
|
| 506 |
+
continue
|
| 507 |
+
|
| 508 |
+
# Step 1a: Check that the self argument we're attempting to reinplace
|
| 509 |
+
# has the same size/stride as the output.
|
| 510 |
+
# For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor)
|
| 511 |
+
# As it would require resizing scalar_tensor.
|
| 512 |
+
# (We could potentially swizzle this into larger_tensor.add_(scalar_tensor),
|
| 513 |
+
# this is probably an optimization to revisit later).
|
| 514 |
+
self_arg = node.args[0]
|
| 515 |
+
self_flattened = pytree.tree_leaves(self_arg.meta['fake_result'])
|
| 516 |
+
node_flattened = pytree.tree_leaves(node.meta['fake_result'])
|
| 517 |
+
self_has_wrong_metadata = False
|
| 518 |
+
if len(self_flattened) == len(node_flattened):
|
| 519 |
+
for self_meta, node_meta in zip(self_flattened, node_flattened):
|
| 520 |
+
if self_meta.numel() != node_meta.numel():
|
| 521 |
+
self_has_wrong_metadata = True
|
| 522 |
+
if self_meta.dtype != node_meta.dtype:
|
| 523 |
+
self_has_wrong_metadata = True
|
| 524 |
+
# We also cannot re-inplace on tensors that have internal memory overlap.
|
| 525 |
+
# e.g. torch.ones(1).expand(4, 4).add_(1)
|
| 526 |
+
if torch._debug_has_internal_overlap(self_meta) == 1:
|
| 527 |
+
self_has_wrong_metadata = True
|
| 528 |
+
# Here, we (optimistically) assume that a.resize(b) is valid to re-inplace,
|
| 529 |
+
# Since users should never really be calling the functional "torch.ops.aten.resize"
|
| 530 |
+
# op directly in their programs.
|
| 531 |
+
if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default:
|
| 532 |
+
continue
|
| 533 |
+
|
| 534 |
+
# Step 1b: ensure that the op we're trying to re-inplace isn't a program input
|
| 535 |
+
self_arg_name = self_arg.name
|
| 536 |
+
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
|
| 537 |
+
if self_arg_storage in input_storages:
|
| 538 |
+
# TODO: later, add the optimization for handling `copy_()` calls in the graph.
|
| 539 |
+
continue
|
| 540 |
+
if len([x for x in node.args if x is self_arg]) > 1:
|
| 541 |
+
# Step 1c:
|
| 542 |
+
# Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound,
|
| 543 |
+
# so we prevent re-inplacing in this case.
|
| 544 |
+
continue
|
| 545 |
+
|
| 546 |
+
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
|
| 547 |
+
self_aliases = storage_to_nodes[self_arg_storage]
|
| 548 |
+
|
| 549 |
+
# First, we find all later usages of any of the aliases of self_arg.
|
| 550 |
+
later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx'])
|
| 551 |
+
# Then, we check if any of those later usages are actually view_scatter ops
|
| 552 |
+
# that are safe to fully remove.
|
| 553 |
+
later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases)
|
| 554 |
+
|
| 555 |
+
# Step 2: Check to see if the input to the op is re-used later in the graph.
|
| 556 |
+
# If not (same goes for its aliases), then this op is safe to re-in place.
|
| 557 |
+
# This is a slightly roundabout way to check that there are no later usages of the current self argument.
|
| 558 |
+
# (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete)
|
| 559 |
+
can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0
|
| 560 |
+
if not can_reinplace:
|
| 561 |
+
continue
|
| 562 |
+
|
| 563 |
+
# Step 3a: Special handling for when we see *_scatter operators.
|
| 564 |
+
# When we see an operator like `b = torch.slice_scatter(a, ...)`,
|
| 565 |
+
# instead of trying to "inplace" it into a.slice_scatter_(..._),
|
| 566 |
+
# we would prefer to remove it from the graph entirely,
|
| 567 |
+
# and instead copy_() the slice directly into the larger tensor.
|
| 568 |
+
# See the description of the algorithm for a full example.
|
| 569 |
+
if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete:
|
| 570 |
+
view_op = _VIEW_INVERSE_MAP[node.target]
|
| 571 |
+
# Before:
|
| 572 |
+
# base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...)
|
| 573 |
+
# After:
|
| 574 |
+
# slice = torch.ops.aten.slice.default(base, args...)
|
| 575 |
+
# slice.copy_(mutated_slice)
|
| 576 |
+
with gm.graph.inserting_before(node):
|
| 577 |
+
mutated_slice_node = node.args[1]
|
| 578 |
+
remaining_slice_args = node.args[2:]
|
| 579 |
+
slice_node = gm.graph.create_node(
|
| 580 |
+
'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs)
|
| 581 |
+
copy_node = gm.graph.create_node(
|
| 582 |
+
'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {})
|
| 583 |
+
# Add the slice_scatter node to our "nodes to delete" list.
|
| 584 |
+
all_later_view_inverse_nodes_to_delete.add(node)
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
else:
|
| 588 |
+
# Step 3b: Check to see if this operator has an inplace variant.
|
| 589 |
+
maybe_inplace_op = _maybe_get_inplace_op(node.target)
|
| 590 |
+
if maybe_inplace_op is None:
|
| 591 |
+
continue
|
| 592 |
+
# And if so, replace it with its inplace variant.
|
| 593 |
+
node.target = maybe_inplace_op
|
| 594 |
+
|
| 595 |
+
# At this point, 'storage_to_nodes' will be stale.
|
| 596 |
+
# Now that we're inplacing `b = foo(a)`, we need to effectively
|
| 597 |
+
# union together the dict values for b and a's storage.
|
| 598 |
+
# Hmm... morally I think we also want to keep the `fake_result` metadata
|
| 599 |
+
# up to date here, but I'm not sure how easy it is to do.
|
| 600 |
+
# Maybe it's fine to wait until the end of the pass to update it.
|
| 601 |
+
curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
|
| 602 |
+
storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage])
|
| 603 |
+
storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage])
|
| 604 |
+
|
| 605 |
+
# Need to remember the view_scatter view nodes we found so we can remove them alter.
|
| 606 |
+
all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages)
|
| 607 |
+
|
| 608 |
+
# Step 4:
|
| 609 |
+
# Now that we've replaced b = a.foo() with a.foo_(),
|
| 610 |
+
# We need to replace any later usages of "b" with "a"
|
| 611 |
+
for old in itertools.chain([node], later_view_inverse_node_usages):
|
| 612 |
+
new = old.args[0]
|
| 613 |
+
nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']]
|
| 614 |
+
for node_to_update in nodes_to_update:
|
| 615 |
+
new_args = []
|
| 616 |
+
args = node_to_update.args
|
| 617 |
+
|
| 618 |
+
def replace_arg(a):
|
| 619 |
+
if a == old:
|
| 620 |
+
return new
|
| 621 |
+
return a
|
| 622 |
+
|
| 623 |
+
# First, replace usages of "b" with "a"
|
| 624 |
+
node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args)
|
| 625 |
+
node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs)
|
| 626 |
+
|
| 627 |
+
# Second, update our storage_to_nodes data structure.
|
| 628 |
+
old_flattened_res = pytree.tree_leaves(old.meta['fake_result'])
|
| 629 |
+
node_flattened_res = pytree.tree_leaves(node_to_update.meta['fake_result'])
|
| 630 |
+
|
| 631 |
+
old_res_storage = {
|
| 632 |
+
StorageWeakRef(
|
| 633 |
+
x._typed_storage()
|
| 634 |
+
) for x in old_flattened_res if isinstance(x, FakeTensor)}
|
| 635 |
+
node_res_storage = {
|
| 636 |
+
StorageWeakRef(
|
| 637 |
+
x._typed_storage()
|
| 638 |
+
) for x in node_flattened_res if isinstance(x, FakeTensor)}
|
| 639 |
+
|
| 640 |
+
# This will happen if we're updating a view op, e.g.
|
| 641 |
+
# e.g. replacing
|
| 642 |
+
# x = view(old)
|
| 643 |
+
# x = view(new)
|
| 644 |
+
# When that happens, we need to make sure to keep our
|
| 645 |
+
# storage mapping up to date.
|
| 646 |
+
#
|
| 647 |
+
# We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor,
|
| 648 |
+
# or multiple tensors that all share the same storage.
|
| 649 |
+
# We can't just check equality because we might encounter FX nodes that return zero tensor outputs.
|
| 650 |
+
if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage:
|
| 651 |
+
new_flattened_res = pytree.tree_leaves(new.meta['fake_result'])
|
| 652 |
+
new_res_storage = {
|
| 653 |
+
StorageWeakRef(
|
| 654 |
+
x._typed_storage()
|
| 655 |
+
) for x in new_flattened_res if isinstance(x, FakeTensor)}
|
| 656 |
+
assert len(new_res_storage) == 1
|
| 657 |
+
(old_ref,) = old_res_storage
|
| 658 |
+
(new_ref,) = new_res_storage
|
| 659 |
+
(node_ref,) = node_res_storage
|
| 660 |
+
# Technically, "old_ref" and all its aliases will remain
|
| 661 |
+
# in our mapping.
|
| 662 |
+
# That should be fine though, since we deleted "old"
|
| 663 |
+
# from the graph at this point.
|
| 664 |
+
storage_to_nodes[node_ref].update(storage_to_nodes[new_ref])
|
| 665 |
+
storage_to_nodes[new_ref].update(storage_to_nodes[node_ref])
|
| 666 |
+
|
| 667 |
+
# Step 4: delete any _scatter nodes that we de-functionalized
|
| 668 |
+
# Need to take care not to delete any of these nodes until after *all* modifications
|
| 669 |
+
# to the graph are finished.
|
| 670 |
+
for to_delete in all_later_view_inverse_nodes_to_delete:
|
| 671 |
+
gm.graph.erase_node(to_delete)
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
gm.recompile()
|
| 675 |
+
return gm
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-311.pyc
ADDED
|
Binary file (22.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/source_matcher_utils.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from torch.fx.graph import Graph
|
| 3 |
+
from torch.fx.node import Node
|
| 4 |
+
from torch.fx._compatibility import compatibility
|
| 5 |
+
from typing import Dict, List, Any, Type, Optional, Callable
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
__all__ = ['get_source_partitions', 'check_subgraphs_connected', 'SourcePartition']
|
| 11 |
+
|
| 12 |
+
# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
|
| 13 |
+
def _init_logger():
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
|
| 17 |
+
logger.setLevel(level)
|
| 18 |
+
console = logging.StreamHandler()
|
| 19 |
+
formatter = logging.Formatter("%(filename)s > %(message)s")
|
| 20 |
+
console.setFormatter(formatter)
|
| 21 |
+
console.setLevel(level)
|
| 22 |
+
# add the handlers to the logger
|
| 23 |
+
logger.addHandler(console)
|
| 24 |
+
logger.propagate = False
|
| 25 |
+
return logger
|
| 26 |
+
|
| 27 |
+
logger = _init_logger()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@compatibility(is_backward_compatible=False)
|
| 31 |
+
@dataclass
|
| 32 |
+
class SourcePartition:
|
| 33 |
+
# Nodes in a particular partition
|
| 34 |
+
nodes: List[Node]
|
| 35 |
+
|
| 36 |
+
# The source these nodes decomposed from
|
| 37 |
+
source: Any
|
| 38 |
+
|
| 39 |
+
# Nodes in the graph that are needed as inputs to the partition
|
| 40 |
+
input_nodes: List[Node] = field(default_factory=list)
|
| 41 |
+
|
| 42 |
+
# Nodes in the partition that are being used by nodes outside of the
|
| 43 |
+
# partition
|
| 44 |
+
output_nodes: List[Node] = field(default_factory=list)
|
| 45 |
+
|
| 46 |
+
# Parameters that are being used
|
| 47 |
+
params: List[Node] = field(default_factory=list)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@compatibility(is_backward_compatible=False)
|
| 51 |
+
def get_source_partitions(
|
| 52 |
+
graph: Graph,
|
| 53 |
+
wanted_sources: List[Any],
|
| 54 |
+
filter_fn: Optional[Callable[[Node], bool]] = None,
|
| 55 |
+
) -> Dict[Any, List[SourcePartition]]:
|
| 56 |
+
"""
|
| 57 |
+
Args:
|
| 58 |
+
graph: The graph we want to partition
|
| 59 |
+
wanted_sources: List of sources of nodes that were decomposed from this
|
| 60 |
+
source. This can be a function (ex. torch.nn.functional.linear) or a
|
| 61 |
+
leaf module type (ex. torch.nn.Linear).
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
Dictionary mapping sources that were given to a list of SourcePartitions
|
| 65 |
+
that correspond to the list of nodes that were decomposed from the given
|
| 66 |
+
source.
|
| 67 |
+
"""
|
| 68 |
+
modules: Dict[Type, Dict[str, List[Node]]] = {}
|
| 69 |
+
|
| 70 |
+
for node in graph.nodes:
|
| 71 |
+
# The metadata source_fn should contain a tuple of a unique name for the
|
| 72 |
+
# source, and the source function if the node is decomposed from a
|
| 73 |
+
# function, or the type of module if the node is decomposed from a leaf
|
| 74 |
+
# module
|
| 75 |
+
|
| 76 |
+
if (source_fn_st := node.meta.get("source_fn_stack", None)) is None:
|
| 77 |
+
continue
|
| 78 |
+
|
| 79 |
+
source_fn = source_fn_st[-1]
|
| 80 |
+
if source_fn[1] not in wanted_sources:
|
| 81 |
+
continue
|
| 82 |
+
|
| 83 |
+
diff_modules = modules.setdefault(source_fn[1], {})
|
| 84 |
+
partition = diff_modules.setdefault(source_fn[0], [])
|
| 85 |
+
partition.append(node)
|
| 86 |
+
|
| 87 |
+
def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition:
|
| 88 |
+
input_nodes = set()
|
| 89 |
+
output_nodes = set()
|
| 90 |
+
params = set()
|
| 91 |
+
for node in nodes:
|
| 92 |
+
for arg in node.args:
|
| 93 |
+
if isinstance(arg, Node) and arg not in nodes:
|
| 94 |
+
input_nodes.add(arg)
|
| 95 |
+
|
| 96 |
+
if node.op == "get_attr":
|
| 97 |
+
params.add(node)
|
| 98 |
+
|
| 99 |
+
for user in node.users.keys():
|
| 100 |
+
if user not in nodes:
|
| 101 |
+
output_nodes.add(node)
|
| 102 |
+
|
| 103 |
+
return SourcePartition(
|
| 104 |
+
nodes,
|
| 105 |
+
module_type,
|
| 106 |
+
list(input_nodes),
|
| 107 |
+
list(output_nodes),
|
| 108 |
+
list(params), # type: ignore[arg-type]
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
ret: Dict[Type[Any], List[SourcePartition]] = {}
|
| 112 |
+
|
| 113 |
+
if filter_fn:
|
| 114 |
+
# for each partition, we apply filter_fn to filter out all partitions that doesn't satisfy the
|
| 115 |
+
# filter condition
|
| 116 |
+
filtered_modules = {}
|
| 117 |
+
for tp, name_to_partition in modules.items():
|
| 118 |
+
filtered_name_to_partition = {
|
| 119 |
+
name: partition
|
| 120 |
+
for name, partition in name_to_partition.items()
|
| 121 |
+
if all(map(filter_fn, partition))
|
| 122 |
+
}
|
| 123 |
+
filtered_modules[tp] = filtered_name_to_partition
|
| 124 |
+
modules = filtered_modules
|
| 125 |
+
|
| 126 |
+
for k, v in modules.items():
|
| 127 |
+
ret[k] = [make_partition(partition, k) for partition in v.values()]
|
| 128 |
+
|
| 129 |
+
return ret
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@compatibility(is_backward_compatible=False)
|
| 133 |
+
def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool:
|
| 134 |
+
"""
|
| 135 |
+
Given two subgraphs A and B (in the form of a list of nodes), checks if
|
| 136 |
+
A has nodes connecting to at least one node in B -- aka there exists a node
|
| 137 |
+
in B that uses a node in A (not the other way around).
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
for node in reversed(subgraph1.nodes):
|
| 141 |
+
for user in node.users.keys():
|
| 142 |
+
if user in subgraph2.nodes:
|
| 143 |
+
return True
|
| 144 |
+
return False
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__init__.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""torch.multiprocessing is a wrapper around the native :mod:`multiprocessing` module.
|
| 2 |
+
|
| 3 |
+
It registers custom reducers, that use shared memory to provide shared
|
| 4 |
+
views on the same data in different processes. Once the tensor/storage is moved
|
| 5 |
+
to shared_memory (see :func:`~torch.Tensor.share_memory_`), it will be possible
|
| 6 |
+
to send it to other processes without making any copies.
|
| 7 |
+
|
| 8 |
+
The API is 100% compatible with the original module - it's enough to change
|
| 9 |
+
``import multiprocessing`` to ``import torch.multiprocessing`` to have all the
|
| 10 |
+
tensors sent through the queues or shared via other mechanisms, moved to shared
|
| 11 |
+
memory.
|
| 12 |
+
|
| 13 |
+
Because of the similarity of APIs we do not document most of this package
|
| 14 |
+
contents, and we recommend referring to very good docs of the original module.
|
| 15 |
+
"""
|
| 16 |
+
import multiprocessing
|
| 17 |
+
import sys
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from .reductions import init_reductions
|
| 21 |
+
|
| 22 |
+
__all__ = ["set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies"]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
from multiprocessing import * # noqa: F403
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
__all__ += multiprocessing.__all__ # noqa: PLE0605 type: ignore[attr-defined]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# This call adds a Linux specific prctl(2) wrapper function to this module.
|
| 32 |
+
# See https://github.com/pytorch/pytorch/pull/14391 for more information.
|
| 33 |
+
torch._C._multiprocessing_init()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
"""Add helper function to spawn N processes and wait for completion of any of
|
| 37 |
+
them. This depends `mp.get_context` which was added in Python 3.4."""
|
| 38 |
+
from .spawn import (
|
| 39 |
+
ProcessContext,
|
| 40 |
+
ProcessExitedException,
|
| 41 |
+
ProcessRaisedException,
|
| 42 |
+
spawn,
|
| 43 |
+
SpawnContext,
|
| 44 |
+
start_processes,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if sys.platform == "darwin" or sys.platform == "win32":
|
| 49 |
+
_sharing_strategy = "file_system"
|
| 50 |
+
_all_sharing_strategies = {"file_system"}
|
| 51 |
+
else:
|
| 52 |
+
_sharing_strategy = "file_descriptor"
|
| 53 |
+
_all_sharing_strategies = {"file_descriptor", "file_system"}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def set_sharing_strategy(new_strategy):
|
| 57 |
+
"""Set the strategy for sharing CPU tensors.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
new_strategy (str): Name of the selected strategy. Should be one of
|
| 61 |
+
the values returned by :func:`get_all_sharing_strategies()`.
|
| 62 |
+
"""
|
| 63 |
+
global _sharing_strategy
|
| 64 |
+
assert new_strategy in _all_sharing_strategies
|
| 65 |
+
_sharing_strategy = new_strategy
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_sharing_strategy():
|
| 69 |
+
"""Return the current strategy for sharing CPU tensors."""
|
| 70 |
+
return _sharing_strategy
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_all_sharing_strategies():
|
| 74 |
+
"""Return a set of sharing strategies supported on a current system."""
|
| 75 |
+
return _all_sharing_strategies
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
init_reductions()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (2.72 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-311.pyc
ADDED
|
Binary file (19.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-311.pyc
ADDED
|
Binary file (13.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/_atfork.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
__all__ = ["register_after_fork"]
|
| 4 |
+
|
| 5 |
+
if sys.platform == "win32":
|
| 6 |
+
import multiprocessing.util as _util
|
| 7 |
+
|
| 8 |
+
def _register(func):
|
| 9 |
+
def wrapper(arg):
|
| 10 |
+
func()
|
| 11 |
+
|
| 12 |
+
_util.register_after_fork(_register, wrapper)
|
| 13 |
+
|
| 14 |
+
else:
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
def _register(func):
|
| 18 |
+
os.register_at_fork(after_in_child=func)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def register_after_fork(func):
|
| 22 |
+
"""Register a callable to be executed in the child process after a fork.
|
| 23 |
+
|
| 24 |
+
Note:
|
| 25 |
+
In python < 3.7 this will only work with processes created using the
|
| 26 |
+
``multiprocessing`` module. In python >= 3.7 it also works with
|
| 27 |
+
``os.fork()``.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
func (function): Function taking no arguments to be called in the child after fork
|
| 31 |
+
|
| 32 |
+
"""
|
| 33 |
+
_register(func)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__init__.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """
|
| 2 |
+
import contextlib
|
| 3 |
+
from typing import List, Union
|
| 4 |
+
from warnings import warn
|
| 5 |
+
|
| 6 |
+
from torch.backends.cuda import (
|
| 7 |
+
can_use_efficient_attention,
|
| 8 |
+
can_use_flash_attention,
|
| 9 |
+
enable_flash_sdp,
|
| 10 |
+
enable_math_sdp,
|
| 11 |
+
enable_mem_efficient_sdp,
|
| 12 |
+
flash_sdp_enabled,
|
| 13 |
+
math_sdp_enabled,
|
| 14 |
+
mem_efficient_sdp_enabled,
|
| 15 |
+
SDPAParams,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
__all__: List[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]
|
| 19 |
+
|
| 20 |
+
# Note: [SDPA warnings]
|
| 21 |
+
# TODO: Consider using this for sdpa regardless of subclasses
|
| 22 |
+
# This only effects users of bias subclasses
|
| 23 |
+
# If this is set to True, we will warn the user if they are not using the fused kernels
|
| 24 |
+
# As well, it will raise warnings for all the reasons why the fused kernels can't be run.
|
| 25 |
+
# To set this to True, run
|
| 26 |
+
# torch.nn.attention.WARN_FOR_UNFUSED_KERNELS = True
|
| 27 |
+
WARN_FOR_UNFUSED_KERNELS = False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
from torch._C import _SDPBackend as SDPBackend
|
| 31 |
+
|
| 32 |
+
# Hacks for Sphinx documentation:
|
| 33 |
+
# https://stackoverflow.com/questions/38765577/overriding-sphinx-autodoc-alias-of-for-import-of-private-class
|
| 34 |
+
SDPBackend = SDPBackend
|
| 35 |
+
r"""An enum-like class that contains the different backends for scaled dot product attention.
|
| 36 |
+
This backend class is designed to be used with the sdpa_kernel context manager.
|
| 37 |
+
|
| 38 |
+
The following Enums are available:
|
| 39 |
+
- ERROR: An error occurred when trying to determine the backend.
|
| 40 |
+
- MATH: The math backend for scaled dot product attention.
|
| 41 |
+
- FLASH_ATTENTION: The flash attention backend for scaled dot product attention.
|
| 42 |
+
- EFFICIENT_ATTENTION: The efficient attention backend for scaled dot product attention.
|
| 43 |
+
- CUDNN_ATTENTION: The cuDNN backend for scaled dot product attention.
|
| 44 |
+
|
| 45 |
+
See :func:`torch.nn.attention.sdpa_kernel` for more details.
|
| 46 |
+
|
| 47 |
+
.. warning:: This class is in beta and subject to change.
|
| 48 |
+
"""
|
| 49 |
+
SDPBackend.__module__ = __name__
|
| 50 |
+
SDPBackend.__name__ = "SDPBackend"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _raise_kernel_warnings(params: SDPAParams) -> None:
|
| 54 |
+
"""
|
| 55 |
+
If WARN_FOR_UNFUSED_KERNELS is set to True, this will raise warnings
|
| 56 |
+
for all the reasons why the fused kernels can't be run. If using subclasses
|
| 57 |
+
"""
|
| 58 |
+
if WARN_FOR_UNFUSED_KERNELS:
|
| 59 |
+
if not can_use_efficient_attention(params):
|
| 60 |
+
warn("Efficient attention can't be used because:")
|
| 61 |
+
can_use_efficient_attention(params, True)
|
| 62 |
+
if not can_use_flash_attention(params):
|
| 63 |
+
warn("Flash attention can't be used because:")
|
| 64 |
+
can_use_flash_attention(params, True)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@contextlib.contextmanager
|
| 68 |
+
def sdpa_kernel(backends: Union[List[SDPBackend], SDPBackend]):
|
| 69 |
+
r"""
|
| 70 |
+
Context manager to select which backend to use for scaled dot product attention.
|
| 71 |
+
|
| 72 |
+
.. warning:: This function is beta and subject to change.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
backend (Union[List[SDPBackend], SDPBackend]): A backend or list of backends for scaled dot product attention.
|
| 76 |
+
|
| 77 |
+
Example:
|
| 78 |
+
|
| 79 |
+
.. code-block:: python
|
| 80 |
+
|
| 81 |
+
from torch.nn.functional import scaled_dot_product_attention
|
| 82 |
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
| 83 |
+
# Only enable flash attention backend
|
| 84 |
+
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
|
| 85 |
+
scaled_dot_product_attention(...)
|
| 86 |
+
|
| 87 |
+
# Enable the Math or Efficient attention backends
|
| 88 |
+
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
|
| 89 |
+
scaled_dot_product_attention(...)
|
| 90 |
+
|
| 91 |
+
This context manager can be used to select which backend to use for scaled dot product attention.
|
| 92 |
+
Upon exiting the context manager, the previous state of the flags will be restored, enabling all backends.
|
| 93 |
+
"""
|
| 94 |
+
assert isinstance(
|
| 95 |
+
backends, (list, SDPBackend)
|
| 96 |
+
), "Backend must be an instance of SDPBackend or a list of SDPBackend instances"
|
| 97 |
+
|
| 98 |
+
if isinstance(backends, SDPBackend):
|
| 99 |
+
backends = [backends]
|
| 100 |
+
|
| 101 |
+
backends = set(backends)
|
| 102 |
+
previous_flash: bool = flash_sdp_enabled()
|
| 103 |
+
previous_mem_efficient: bool = mem_efficient_sdp_enabled()
|
| 104 |
+
previous_math: bool = math_sdp_enabled()
|
| 105 |
+
try:
|
| 106 |
+
enable_flash = SDPBackend.FLASH_ATTENTION in backends
|
| 107 |
+
enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION in backends
|
| 108 |
+
enable_math = SDPBackend.MATH in backends
|
| 109 |
+
|
| 110 |
+
enable_flash_sdp(enable_flash)
|
| 111 |
+
enable_mem_efficient_sdp(enable_mem_efficient)
|
| 112 |
+
enable_math_sdp(enable_math)
|
| 113 |
+
yield {}
|
| 114 |
+
finally:
|
| 115 |
+
enable_flash_sdp(previous_flash)
|
| 116 |
+
enable_mem_efficient_sdp(previous_mem_efficient)
|
| 117 |
+
enable_math_sdp(previous_math)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (4.52 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/bias.cpython-311.pyc
ADDED
|
Binary file (15.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/bias.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Defines bias subclasses that work with scaled_dot_product_attention"""
|
| 2 |
+
from enum import auto, IntEnum
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from warnings import warn
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch.backends.cuda import (
|
| 8 |
+
can_use_efficient_attention,
|
| 9 |
+
can_use_flash_attention,
|
| 10 |
+
SDPAParams,
|
| 11 |
+
)
|
| 12 |
+
from torch.nn.attention import _raise_kernel_warnings
|
| 13 |
+
from torch.nn.attention._utils import (
|
| 14 |
+
_calculate_scale,
|
| 15 |
+
_input_requires_grad,
|
| 16 |
+
_postprocess_flash_output,
|
| 17 |
+
_validate_sdpa_input,
|
| 18 |
+
)
|
| 19 |
+
from torch.nn.functional import scaled_dot_product_attention
|
| 20 |
+
|
| 21 |
+
__all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"]


# Register these with dynamo so torch.compile can trace through the
# kernel-capability checks instead of treating them as graph breaks.
torch._dynamo.allow_in_graph(can_use_flash_attention)
torch._dynamo.allow_in_graph(can_use_efficient_attention)
torch._dynamo.allow_in_graph(SDPAParams)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class CausalVariant(IntEnum):
    r"""
    Enum for causal variants used in attention mechanisms.

    Two triangular mask layouts are supported:

    `UPPER_LEFT`: the allowed (True) entries hug the upper-left corner of the
    matrix — standard causal attention. The equivalent pytorch construction is:

    .. code-block:: python

        torch.tril(torch.ones(size, dtype=torch.bool))

    For instance, with `shape=(3,4)` the materialized bias tensor is:

    .. code-block:: text

        [[1, 0, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 0]]

    `LOWER_RIGHT`: the allowed entries are aligned to the lower-right corner
    of the matrix. The equivalent pytorch construction is:

    .. code-block:: python

        diagonal_offset = size[1] - size[0]
        torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=diagonal_offset,
        )

    For instance, with `shape=(3,4)` the materialized bias tensor is:

    .. code-block:: text

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]]

    The two variants coincide whenever the query and key/value sequence
    lengths are equal, since the triangular matrix is then square.

    .. warning:: This enum is a prototype and subject to change.
    """

    # Explicit values (identical to what ``auto()`` would assign) because the
    # integer value is forwarded to fused attention kernels as a mask-type
    # code elsewhere in this module.
    UPPER_LEFT = 1
    LOWER_RIGHT = 2
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class CausalBias(torch.Tensor):
    """
    A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum.

    This class is used for defining causal (triangular) attention biases. For construing the bias, there exist
    two factory functions: :func:`causal_upper_left` and :func:`causal_lower_right`.

    Example:

    .. code-block:: python

        from torch.nn.attention.bias import causal_lower_right

        bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8

        # Create a lower-right causal bias
        attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

        q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
        v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

        out = F.scaled_dot_product_attention(q, k, v, attn_bias)

    .. warning:: This class is a prototype and subject to change.
    """

    # NOTE(review): subclasses torch.Tensor but stores only mask metadata
    # (variant and the two sequence lengths); the dense boolean mask is
    # produced lazily by _materialize().

    def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int):
        """
        Initializes the CausalBias instance with a specified variant and sequence lengths.

        Args:
            variant (CausalVariant): The type of causal bias to use (either UPPER_LEFT or LOWER_RIGHT).
            seq_len_q (int): The sequence length of the query tensor.
            seq_len_kv (int): The sequence length of the key/value tensor.

        Raises a warning if the LOWER_RIGHT variant is used with seq_len_q > seq_len_kv, as it may produce NaNs.
        """
        assert isinstance(variant, CausalVariant)
        self.variant = variant
        self.seq_len_q = seq_len_q
        self.seq_len_kv = seq_len_kv
        # A lower-right mask with more queries than keys leaves the leading
        # query rows fully masked, which yields NaNs out of softmax.
        if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT:
            warn(
                "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!"
            )

    def _upper_left(self, device: torch.device) -> torch.Tensor:
        """Upper left causal bias"""
        return torch.tril(
            torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool)
        )

    def _lower_right(self, device: torch.device) -> torch.Tensor:
        """Lower right causal bias"""
        # Shifting the diagonal right-aligns the triangle when kv is longer
        # than q (and left-shifts it when q is longer).
        diagonal_offset = self.seq_len_kv - self.seq_len_q
        return torch.tril(
            torch.ones(
                self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool
            ),
            diagonal=diagonal_offset,
        )

    def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Materializes the causal bias into a tensor form.

        Depending on the variant, this method generates either an upper-left or lower-right
        triangular matrix to represent the causal bias.

        Args:
            device (Optional[torch.device]): The device on which to create the tensor. Defaults to CPU.

        Returns:
            torch.Tensor: The materialized bias tensor.
        """
        if device is None:
            device = torch.device("cpu")
        if self.variant == CausalVariant.UPPER_LEFT:
            return self._upper_left(device)
        elif self.variant == CausalVariant.LOWER_RIGHT:
            return self._lower_right(device)
        # NOTE(review): implicitly returns None for any other variant;
        # unreachable in practice because __init__ asserts the enum type.

    @staticmethod
    def _dispatch(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: "CausalBias",
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
    ) -> torch.Tensor:
        r"""
        Handles the logic for computing attention with the specified causal bias.

        Args:
            query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
            key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
            value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
            attn_mask (CausalBias): The type of causal attention to apply.
                A boolean mask where a value of True indicates that the element *should* take part in attention.
                A float mask of the same type as query, key, value that is added to the attention score.
            dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
            is_causal (bool): If true, assumes upper left causal attention masking and errors if both attn_mask and is_causal
                are set.
            scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
                to :math:`\frac{1}{\sqrt{E}}`.

        Returns:
            output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.

        Raises:
            ValueError: If the causal bias variant is not a CausalVariant type.

        """
        # The causal structure is already encoded in attn_mask; passing
        # is_causal=True on top of it would double-apply masking.
        if is_causal:
            raise ValueError("CausalBias should not be used with causal=True")

        if (
            attn_mask.seq_len_q == attn_mask.seq_len_kv
            or attn_mask.variant == CausalVariant.UPPER_LEFT
        ):
            # Square masks and upper-left masks are exactly what the stock
            # is_causal=True fast path computes, so defer to it.
            return scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True,
                scale=scale,
            )
        elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
            _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
            sdpa_params = SDPAParams(query, key, value, None, dropout_p, is_causal)
            if can_use_flash_attention(sdpa_params):
                # Flash path pads the head dim up to a multiple of 8 and
                # slices the padding back off afterwards.
                needs_padding = query.size(-1) % 8 != 0
                og_head_size = query.size(-1)
                # Scale is derived from the *original* head size so padding
                # does not change the softmax temperature.
                og_scale = _calculate_scale(og_head_size, scale)
                if needs_padding:
                    query = torch.nn.functional.pad(query, (0, 8 - query.size(-1) % 8))
                    key = torch.nn.functional.pad(key, (0, 8 - key.size(-1) % 8))
                    value = torch.nn.functional.pad(value, (0, 8 - value.size(-1) % 8))
                out = torch.ops.aten._scaled_dot_product_flash_attention(
                    query,
                    key,
                    value,
                    dropout_p,
                    is_causal=True,  # TODO: Flash accepts causal = True and for this particular op it means lower right
                    return_debug_mask=False,
                    scale=og_scale,
                )[0]
                return _postprocess_flash_output(out, og_head_size)
            if can_use_efficient_attention(sdpa_params):
                compute_log_sumexp = False
                if _input_requires_grad(query, key, value):
                    compute_log_sumexp = True
                # Dims 1 and 2 are swapped before and after the kernel call —
                # presumably to match the kernel's expected layout; confirm
                # against the _efficient_attention_forward schema.
                return torch.ops.aten._efficient_attention_forward(
                    query.transpose(1, 2),
                    key.transpose(1, 2),
                    value.transpose(1, 2),
                    bias=None,
                    cu_seqlens_q=None,
                    cu_seqlens_k=None,
                    max_seqlen_q=None,
                    max_seqlen_k=None,
                    dropout_p=dropout_p,
                    # The enum's integer value (1 or 2) doubles as the
                    # kernel's mask-type code — presumably kept in sync with
                    # the aten custom_mask_type encoding; verify on upgrade.
                    custom_mask_type=int(attn_mask.variant),
                    compute_log_sumexp=compute_log_sumexp,
                    scale=scale,
                    causal_diagonal=None,
                    seqlen_k=None,
                )[0].transpose(1, 2)
            else:
                _raise_kernel_warnings(sdpa_params)
                # We cant use efficient attention the only support for lower right is via materialization
                return scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=attn_mask._materialize(query.device),
                    dropout_p=dropout_p,
                    is_causal=False,
                    scale=scale,
                )
        else:
            raise ValueError(
                f"CausalBias.variant must be a CausalVariant type, but found: {attn_mask.variant}"
            )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is an AttnBias"""
        if kwargs is None:
            kwargs = {}
        # Only SDPA is intercepted; any other torch op on this subclass is
        # rejected rather than silently operating on an empty tensor.
        if func != torch.nn.functional.scaled_dot_product_attention:
            raise NotImplementedError(
                "CausalBias only supports scaled_dot_product_attention"
            )
        return cls._dispatch(*args, **kwargs)

    def __repr__(self) -> str:
        # Render the materialized (CPU) mask for debugging.
        return self._materialize().__repr__()
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def causal_upper_left(*size) -> CausalBias:
    """
    Creates an upper-left triangular causal bias.

    The allowed entries are anchored to the upper-left corner of the matrix,
    which is exactly the masking performed by the ``is_causal=True`` argument
    of ``scaled_dot_product_attention``. The dense equivalent is::

        torch.tril(torch.ones(size, dtype=torch.bool))

    For instance, with ``shape=(3,4)`` the materialized bias tensor is::

        [[1, 0, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 0]]

    Args:
        size: The size of the bias matrix.

    Returns:
        CausalBias: The UPPER_LEFT triangular causal bias variant.
    """
    assert len(size) == 2, "causal_upper_left only supports 2D tensors"
    num_queries, num_keys = size
    return CausalBias(CausalVariant.UPPER_LEFT, num_queries, num_keys)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def causal_lower_right(*size) -> CausalBias:
    """
    Creates a lower-right triangular causal bias.

    The allowed entries are aligned to the lower-right corner of the matrix
    via a diagonal offset. The dense equivalent is::

        diagonal_offset = size[1] - size[0]
        torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=diagonal_offset,
        )

    For instance, with ``shape=(3,4)`` the materialized bias tensor is::

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]]

    Args:
        size: The size of the bias matrix.

    Returns:
        CausalBias: The LOWER_RIGHT triangular causal bias variant.
    """
    assert len(size) == 2, "causal_lower_right only supports 2D tensors"
    num_queries, num_keys = size
    return CausalBias(CausalVariant.LOWER_RIGHT, num_queries, num_keys)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/common_types.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TypeVar, Union, Tuple, Optional
from .. import Tensor

# Create some useful type aliases

# Template for arguments which can be supplied as a tuple, or which can be a scalar which PyTorch will internally
# broadcast to a tuple.
# Comes in several variants: A tuple of unknown size, and a fixed-size tuple for 1d, 2d, or 3d operations.
T = TypeVar('T')
_scalar_or_tuple_any_t = Union[T, Tuple[T, ...]]
_scalar_or_tuple_1_t = Union[T, Tuple[T]]
_scalar_or_tuple_2_t = Union[T, Tuple[T, T]]
_scalar_or_tuple_3_t = Union[T, Tuple[T, T, T]]
_scalar_or_tuple_4_t = Union[T, Tuple[T, T, T, T]]
_scalar_or_tuple_5_t = Union[T, Tuple[T, T, T, T, T]]
_scalar_or_tuple_6_t = Union[T, Tuple[T, T, T, T, T, T]]

# For arguments which represent size parameters (eg, kernel size, padding)
_size_any_t = _scalar_or_tuple_any_t[int]
_size_1_t = _scalar_or_tuple_1_t[int]
_size_2_t = _scalar_or_tuple_2_t[int]
_size_3_t = _scalar_or_tuple_3_t[int]
_size_4_t = _scalar_or_tuple_4_t[int]
_size_5_t = _scalar_or_tuple_5_t[int]
_size_6_t = _scalar_or_tuple_6_t[int]

# For arguments which represent optional size parameters (eg, adaptive pool parameters)
_size_any_opt_t = _scalar_or_tuple_any_t[Optional[int]]
_size_2_opt_t = _scalar_or_tuple_2_t[Optional[int]]
_size_3_opt_t = _scalar_or_tuple_3_t[Optional[int]]

# For arguments that represent a ratio to adjust each dimension of an input with (eg, upsampling parameters)
_ratio_2_t = _scalar_or_tuple_2_t[float]
_ratio_3_t = _scalar_or_tuple_3_t[float]
_ratio_any_t = _scalar_or_tuple_any_t[float]

# A single tensor, or a tuple of tensors.
_tensor_list_t = _scalar_or_tuple_any_t[Tensor]

# For the return value of max pooling operations that may or may not return indices.
# With the proposed 'Literal' feature to Python typing, it might be possible to
# eventually eliminate this.
_maybe_indices_t = _scalar_or_tuple_2_t[Tensor]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/grad.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gradient interface."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from .modules.utils import _single, _pair, _triple
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def conv1d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""Compute the gradient of conv1d with respect to the input of the convolution.

    This is same as the 1D transposed convolution operator under the hood but requires
    the shape of the gradient w.r.t. input to be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv1d_input(input.shape, weight, grad_output)

    """
    # Zero-storage stand-in that carries only the requested input shape; the
    # buffer is uninitialized and presumably only its metadata is consumed by
    # the backward op (never its values) — it exists to describe the shape.
    input_placeholder = grad_output.new_empty(1).expand(input_size)
    grads = torch.ops.aten.convolution_backward(
        grad_output,
        input_placeholder,
        weight,
        None,                  # bias_sizes: no bias gradient requested
        _single(stride),
        _single(padding),
        _single(dilation),
        False,                 # transposed convolution: no
        [0],                   # output_padding
        groups,
        (True, False, False),  # compute only grad w.r.t. input
    )
    return grads[0]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def conv1d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""Compute the gradient of conv1d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv1d_weight(input, weight.shape, grad_output)

    """
    # Zero-storage stand-in that carries only the requested weight shape; the
    # buffer is uninitialized and presumably only its metadata is consumed by
    # the backward op (never its values).
    weight_placeholder = grad_output.new_empty(1).expand(weight_size)
    grads = torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight_placeholder,
        None,                  # bias_sizes: no bias gradient requested
        _single(stride),
        _single(padding),
        _single(dilation),
        False,                 # transposed convolution: no
        [0],                   # output_padding
        groups,
        (False, True, False),  # compute only grad w.r.t. weight
    )
    return grads[1]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def conv2d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""Compute the gradient of conv2d with respect to the input of the convolution.

    This is same as the 2D transposed convolution operator under the hood but requires
    the shape of the gradient w.r.t. input to be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kH x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv2d_input(input.shape, weight, grad_output)

    """
    # Zero-storage stand-in that carries only the requested input shape; the
    # buffer is uninitialized and presumably only its metadata is consumed by
    # the backward op (never its values).
    input_placeholder = grad_output.new_empty(1).expand(input_size)
    grads = torch.ops.aten.convolution_backward(
        grad_output,
        input_placeholder,
        weight,
        None,                  # bias_sizes: no bias gradient requested
        _pair(stride),
        _pair(padding),
        _pair(dilation),
        False,                 # transposed convolution: no
        [0],                   # output_padding
        groups,
        (True, False, False),  # compute only grad w.r.t. input
    )
    return grads[0]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""Compute the gradient of conv2d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv2d_weight(input, weight.shape, grad_output)

    """
    # Zero-storage stand-in that carries only the requested weight shape; the
    # buffer is uninitialized and presumably only its metadata is consumed by
    # the backward op (never its values).
    weight_placeholder = grad_output.new_empty(1).expand(weight_size)
    grads = torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight_placeholder,
        None,                  # bias_sizes: no bias gradient requested
        _pair(stride),
        _pair(padding),
        _pair(dilation),
        False,                 # transposed convolution: no
        [0],                   # output_padding
        groups,
        (False, True, False),  # compute only grad w.r.t. weight
    )
    return grads[1]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def conv3d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""Compute the gradient of conv3d with respect to the input of the convolution.

    This is same as the 3D transposed convolution operator under the hood but requires
    the shape of the gradient w.r.t. input to be specified explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weights tensor (out_channels x in_channels/groups x kT x kH x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
        >>> output = F.conv3d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv3d_input(input.shape, weight, grad_output)

    """
    # Zero-storage stand-in that carries only the requested input shape; the
    # buffer is uninitialized and presumably only its metadata is consumed by
    # the backward op (never its values).
    input_placeholder = grad_output.new_empty(1).expand(input_size)
    grads = torch.ops.aten.convolution_backward(
        grad_output,
        input_placeholder,
        weight,
        None,                  # bias_sizes: no bias gradient requested
        _triple(stride),
        _triple(padding),
        _triple(dilation),
        False,                 # transposed convolution: no
        [0],                   # output_padding
        groups,
        (True, False, False),  # compute only grad w.r.t. input
    )
    return grads[0]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def conv3d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""Compute the gradient of conv3d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iT x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
        >>> output = F.conv3d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv3d_weight(input, weight.shape, grad_output)

    """
    # Zero-storage stand-in that carries only the requested weight shape; the
    # buffer is uninitialized and presumably only its metadata is consumed by
    # the backward op (never its values).
    weight_placeholder = grad_output.new_empty(1).expand(weight_size)
    grads = torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight_placeholder,
        None,                  # bias_sizes: no bias gradient requested
        _triple(stride),
        _triple(padding),
        _triple(dilation),
        False,                 # transposed convolution: no
        [0],                   # output_padding
        groups,
        (False, True, False),  # compute only grad w.r.t. weight
    )
    return grads[1]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .modules import * # noqa: F403
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .linear_relu import LinearReLU
|
| 2 |
+
from .linear_fused import LinearBn1d
|
| 3 |
+
from .conv_fused import (
|
| 4 |
+
ConvBn1d,
|
| 5 |
+
ConvBn2d,
|
| 6 |
+
ConvBn3d,
|
| 7 |
+
ConvBnReLU1d,
|
| 8 |
+
ConvBnReLU2d,
|
| 9 |
+
ConvBnReLU3d,
|
| 10 |
+
ConvReLU1d,
|
| 11 |
+
ConvReLU2d,
|
| 12 |
+
ConvReLU3d,
|
| 13 |
+
update_bn_stats,
|
| 14 |
+
freeze_bn_stats,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
# Explicit public API of this package; every name is re-exported from the
# fused-module imports above.
__all__ = [
    "LinearReLU",
    "LinearBn1d",
    "ConvReLU1d",
    "ConvReLU2d",
    "ConvReLU3d",
    "ConvBn1d",
    "ConvBn2d",
    "ConvBn3d",
    "ConvBnReLU1d",
    "ConvBnReLU2d",
    "ConvBnReLU3d",
    "update_bn_stats",
    "freeze_bn_stats",
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (846 Bytes). View file
|
|
|