diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d124aebd4e4fd978c13bf8616b48b67924fdedee Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..249ce9b1157829d47d8fd833068de7e4da54cf1c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__init__.py @@ -0,0 +1,55 @@ +import threading + +import torch._C._lazy +from torch.utils._pytree import tree_flatten, tree_unflatten + +from .closure import add_step_closure, run_step_closures + + +def mark_step(device: str = "", wait=False): + """Triggers a mark step, which amounts to + - collecting a group of 'live' lazy tensors to index into the compilation cache + (lowering/compiling their IR graphs if not cached) + - kicking off execution of the compiled function + - (optionally, wait=True) waiting for cpu-side execution to complete (does not sync the accelerator) + """ + # TODO(whc) expand this to include backend hooks and align with XLA backend needs + torch._C._lazy._mark_step(device, [], wait=wait) + + run_step_closures() + + +def wait_device_ops(devices=None): + """Waits for all the async operations on the given devices to complete. + Args: + devices (string..., optional): The devices whose async ops need to be waited + for. If empty, all the local devices will be waited for. + """ + if devices is None: + devices = [] + torch._C._lazy._wait_device_ops(devices=devices) + + +def sync_multi(tensors, devices): + """ + Sync the list of lazy tensors so there IR get lowered for the activate backend + and the compiled computation graph get cached. + """ + torch._C._lazy._sync_multi(tensors, devices) + + +def get_tensor_id(tensor): + """Return a unique id of the lazy tensor maintained by LTC""" + return torch._C._lazy._get_tensor_id(tensor) + + +def to_cpu(tensors, devices=None): + devices = devices or ["lazy"] + + flattened, spec = tree_flatten(tensors) + sync_multi(flattened, devices) + return tree_unflatten([t.to("cpu") for t in flattened], spec) + + +def save(tensors, *args, **kwargs): + torch.save(to_cpu(tensors), *args, **kwargs) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/debug.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/debug.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6074014e53a9340d6ce31b3c6c9cfb119d3a4bcb Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/debug.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c72d6fd430030297954aec0100ab076d8003ebfc Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/metrics.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0854c6d97e3165188132d621c8e45584983a1476 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/metrics.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2ee5c69d2062c15f4398f97cbea63ffaec352c5 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/tensor_factory_functions.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/tensor_factory_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..47aa9c500466daadf282633d43f0335e0a8c0b70 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/tensor_factory_functions.py @@ -0,0 +1,48 @@ +import torch + +""" +tensor_factory_functions defines the list of torch functions that create tensors. +The list is grabbed by searching thru native_functions.yaml by the following +regular expression: + + cat native_functions.yaml | grep 'func:' | grep -v "Tensor.*->" | grep "[-]>.*Tensor" + +It's possible that new tensor factory functions are added making this list stale. +Use at your own risk or regenerate the list. +""" +tensor_factory_functions = ( + torch._cudnn_init_dropout_state, + torch.arange, + torch.bartlett_window, + torch.blackman_window, + torch._empty_affine_quantized, + torch.empty_strided, + torch.eye, + torch.full, + torch.from_file, + torch.hann_window, + torch.hamming_window, + torch.kaiser_window, + torch.linspace, + torch.logspace, + torch.ones, + torch.scalar_tensor, + torch.rand, + torch.randint, + torch.randn, + torch.randperm, + torch.range, + torch._efficientzerotensor, + torch.zeros, + torch.tril_indices, + torch.triu_indices, + # Note: the following functions match the regular expression search above but + # they are not available in the torch module. Comment out. + # torch._sparse_coo_tensor_with_dims, + # torch.fft_fftfreq, + # torch.fft_rfftfreq, +) + ( + # torch.tensor is special since it's not in native_functions.yaml + # add it separately + torch.tensor, +) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/linalg/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/linalg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b948e1eccc0b785eeccd000b39b9d68725b39e17 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/linalg/__init__.py @@ -0,0 +1,308 @@ +from functools import partial + +from typing import List, Optional, Tuple, Union + +import torch + +import torch._prims as prims + +import torch._prims_common as utils +import torch._refs as refs +import torch._refs.linalg as linalg +from torch import Tensor +from torch._prims_common import ( + check_fp_or_complex, + check_is_matrix, + Dim, + DimsType, + ELEMENTWISE_TYPE_PROMOTION_KIND, + IntLike, + NumberType, + TensorLikeType, +) +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + elementwise_type_promotion_wrapper, + out_wrapper, +) + + +__all__ = [ + "diagonal", + "matrix_norm", + "norm", + "svd", + "svdvals", + "vector_norm", + "vecdot", + "cross", +] + + +def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str): + """ + Checks related to the dtype kwarg in `linalg.*norm` functions + """ + if dtype is not None: + torch._check( + utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), + lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}", + ) + torch._check( + utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype), + lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format( + fn_name=fn_name, + d="complex" if utils.is_complex_dtype(x_dtype) else "real", + dtype=dtype, + ), + ) + torch._check( + utils.get_higher_dtype(dtype, x_dtype) == dtype, + lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible " + "without narrowing to the specified dtype ({dtype})", + ) + + +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition +from torch._decomp.decompositions import pw_cast_for_opmath + + +@register_decomposition(torch._ops.ops.aten.linalg_cross) +@out_wrapper() +@pw_cast_for_opmath +def cross(a: Tensor, b: Tensor, dim: int = -1): + torch._check( + a.ndim == b.ndim, + lambda: "linalg.cross: inputs must have the same number of dimensions.", + ) + torch._check( + a.size(dim) == 3 and b.size(dim) == 3, + lambda: f"linalg.cross: inputs dim {dim} must have length 3, got {a.size(dim)} and {b.size(dim)}", + ) + a, b = torch.broadcast_tensors(a, b) + dim = utils.canonicalize_dim(a.ndim, dim) + idx = torch.arange(3, device=a.device) + return a.index_select(dim, (idx + 1) % 3) * b.index_select( + dim, (idx + 2) % 3 + ) - a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3) + + +def diagonal( + input: TensorLikeType, + *, + offset: int = 0, + dim1: int = -2, + dim2: int = -1, +) -> TensorLikeType: + return torch.diagonal(input, offset=offset, dim1=dim1, dim2=dim2) + + +@register_decomposition(torch._ops.ops.aten.linalg_vector_norm) +@out_wrapper(exact_dtype=True) +def vector_norm( + x: TensorLikeType, + ord: Union[float, int] = 2, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> Tensor: + # Checks + check_fp_or_complex(x.dtype, "linalg.vector_norm") + + if isinstance(dim, Dim): + dim = [dim] # type: ignore[assignment] + + if x.numel() == 0 and (ord < 0.0 or ord == float("inf")): + torch._check( + dim is not None and len(dim) != 0, + lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor " + "because the operation does not have an identity", + ) + shape = x.shape + assert dim is not None # mypy does not seem to be able to see through check? + for d in dim: + torch._check( + shape[d] != 0, + lambda: f"linalg.vector_norm cannot compute the {ord} norm on the " + f"dimension {d} because this dimension is empty and the " + "operation does not have an identity", + ) + _check_norm_dtype(dtype, x.dtype, "linalg.vector_norm") + + computation_dtype, result_dtype = utils.reduction_dtypes( + x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype + ) + + to_result_dtype = partial(_maybe_convert_to_dtype, dtype=result_dtype) + + # Implementation + if ord == 0.0: + return torch.sum(torch.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype) + elif ord == float("inf"): + return to_result_dtype(torch.amax(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type] + elif ord == float("-inf"): + return to_result_dtype(torch.amin(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type] + else: + # From here on the computation dtype is important as the reduction is non-trivial + x = _maybe_convert_to_dtype(x, computation_dtype) # type: ignore[assignment] + reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim) + + is_ord_even = ord % 2 == 0 if isinstance(ord, IntLike) else ord % 2.0 == 0.0 + if not (is_ord_even and utils.is_float_dtype(x.dtype)): + x = torch.abs(x) + return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord)) # type: ignore[return-value] + + +def _backshift_permutation(dim0, dim1, ndim): + # Auxiliary function for matrix_norm + # Computes the permutation that moves the two given dimensions to the back + ret = [i for i in range(ndim) if i != dim0 and i != dim1] + ret.extend((dim0, dim1)) + return ret + + +def _inverse_permutation(perm): + # Given a permutation, returns its inverse. It's equivalent to argsort on an array + return [i for i, j in sorted(enumerate(perm), key=lambda i_j: i_j[1])] + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def matrix_norm( + A: TensorLikeType, + ord: Union[float, str] = "fro", + dim: DimsType = (-2, -1), + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # shape + check_is_matrix(A, "linalg.matrix_norm") + # dim + dim = utils.canonicalize_dims(A.ndim, dim) + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + torch._check( + len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}" + ) + torch._check( + dim[0] != dim[1], + lambda: "linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})", + ) + # dtype arg + _check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm") + + if isinstance(ord, str): + # ord + torch._check( + ord in ("fro", "nuc"), + lambda: "linalg.matrix_norm: Order {ord} not supported.", + ) + # dtype + check_fp_or_complex( + A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != "nuc" + ) + + if ord == "fro": + return vector_norm(A, 2, dim, keepdim, dtype=dtype) + else: # ord == "nuc" + if dtype is not None: + A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] + perm = _backshift_permutation(dim[0], dim[1], A.ndim) + result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim) + if keepdim: + inv_perm = _inverse_permutation(perm) + result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) + return result + else: + # ord + abs_ord = abs(ord) + torch._check( + abs_ord in (2, 1, float("inf")), + lambda: "linalg.matrix_norm: Order {ord} not supported.", + ) + # dtype + check_fp_or_complex( + A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != 2 + ) + + max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim) + + if abs_ord == 2.0: + if dtype is not None: + A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] + perm = _backshift_permutation(dim[0], dim[1], A.ndim) + result = max_min(svdvals(prims.transpose(A, perm)), dim=-1) + if keepdim: + inv_perm = _inverse_permutation(perm) + result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) + return result + else: # 1, -1, inf, -inf + dim0, dim1 = dim + if abs_ord == float("inf"): + dim0, dim1 = dim1, dim0 + if not keepdim and (dim0 < dim1): + dim1 -= 1 + return max_min( + vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1 + ) + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def norm( + A: TensorLikeType, + ord: Optional[Union[float, str]] = None, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + if dim is not None: + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + torch._check( + len(dim) in (1, 2), + lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}", + ) + elif ord is not None: + torch._check( + A.ndim in (1, 2), + lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D", + ) + + if ord is not None and ( + (dim is not None and len(dim) == 2) or (dim is None and A.ndim == 2) + ): + if dim is None: + dim = (0, 1) + return matrix_norm(A, ord, dim, keepdim, dtype=dtype) + else: + if ord is None: + ord = 2.0 + return vector_norm(A, ord, dim, keepdim, dtype=dtype) + + +# CompositeImplicitAutograd +@out_wrapper("U", "S", "Vh", exact_dtype=True) +def svd(A: TensorLikeType, full_matrices: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + return prims.svd(A, full_matrices=full_matrices) + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def svdvals(A: TensorLikeType) -> Tensor: + return svd(A, full_matrices=False)[1] + + +# CompositeImplicitAutograd +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("x", "y"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def vecdot(x: Tensor, y: Tensor, dim: int = -1) -> Tensor: + check_fp_or_complex(x.dtype, "linalg.vecdot") + return (x.conj() * y).sum(dim=dim) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/special/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/special/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..048de83506d2919fd858e871290871bb0f558289 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/special/__init__.py @@ -0,0 +1,236 @@ +import math +from typing import Optional, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs + +from torch import Tensor +from torch._decomp import register_decomposition +from torch._prims_common import ( + ELEMENTWISE_TYPE_PROMOTION_KIND, + Number, + NumberType, + TensorLike, + TensorLikeType, +) +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper, out_wrapper +from torch._refs import ( + _make_alias, + _make_elementwise_binary_reference, + _make_elementwise_unary_reference, +) + + +__all__ = [ + "bessel_j0", + "bessel_j1", + "entr", + "erfcx", + "expit", + "i0e", + "i1", + "i1e", + "log_ndtr", + "logit", + "log_softmax", + "multigammaln", + "ndtr", + "ndtri", + "softmax", + "spherical_bessel_j0", + "xlog1py", + "zeta", +] +aten = torch._ops.ops.aten + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def bessel_j0(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_j0(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def bessel_j1(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_j1(a) + + +@register_decomposition(aten.special_entr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def entr(a: TensorLikeType) -> TensorLikeType: + return torch.where( + torch.isnan(a), + a, + torch.where(a > 0, -a * torch.log(a), torch.where(a == 0, 0, -torch.inf)), + ) + + +@register_decomposition(aten.special_erfcx) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def erfcx(a: TensorLikeType) -> TensorLikeType: + return prims.erfcx(a) + + +# alias for sigmoid +expit = _make_alias(torch.sigmoid, "expit") + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i0e(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i0e(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i1(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i1(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i1e(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i1e(a) + + +@register_decomposition(aten.special_log_ndtr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def log_ndtr(a: TensorLikeType) -> TensorLikeType: + # Note: M_SQRT1_2 is the value of 1 / √2 + M_SQRT1_2 = 0.707106781186547524400844362104849039 + t = a * M_SQRT1_2 + return torch.where( + a < 1.0, + torch.log(torch.special.erfcx(-t) / 2) - t * t, + torch.log1p(-torch.erfc(t) / 2), + ) + + +@register_decomposition(aten.logit) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType: + if eps is None: + eps = -1.0 + lo = eps + hi = 1 - eps + self = torch.clamp(self, lo, hi) + return torch.log(torch.true_divide(self, torch.sub(1, self))) + + +@register_decomposition(aten.special_xlog1py) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): + torch._check( + isinstance(a, TensorLike) or isinstance(b, TensorLike), + lambda: 'Expected either argument a or b to be a Tensor"', + ) + + # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors. + if isinstance(a, TensorLike) and isinstance(b, Number): + b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(b, TensorLike) and isinstance(a, Number): + a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device) + + # mypy: expected "Tensor" + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log1p(b))) + return torch.where(torch.isnan(b), float("nan"), rhs) + + +@register_decomposition(aten.mvlgamma) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType: + c = 0.25 * p * (p - 1) * math.log(math.pi) + b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device) + return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c + + +@register_decomposition(aten.special_ndtr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def ndtr(a: TensorLikeType) -> TensorLikeType: + # Note: M_SQRT1_2 is the value of 1 / √2 + M_SQRT1_2 = 0.707106781186547524400844362104849039 + a_sqrt_2 = a * M_SQRT1_2 + return (1 + torch.erf(a_sqrt_2)) * 0.5 + + +@register_decomposition(aten.special_ndtri) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def ndtri(a: TensorLikeType) -> TensorLikeType: + return prims.ndtri(a) + + +# Forwarding alias: the special variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def log_softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# Forwarding alias: the special variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType: + return prims.spherical_bessel_j0(a) + + +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def zeta(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.zeta(a, b) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/nvtx.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/nvtx.py new file mode 100644 index 0000000000000000000000000000000000000000..4b902c0c6d4d76c6d584ed4d0ad1cc71a3f9cc6d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/nvtx.py @@ -0,0 +1,91 @@ +r"""This package adds support for NVIDIA Tools Extension (NVTX) used in profiling.""" + +from contextlib import contextmanager + +try: + from torch._C import _nvtx +except ImportError: + + class _NVTXStub: + @staticmethod + def _fail(*args, **kwargs): + raise RuntimeError( + "NVTX functions not installed. Are you sure you have a CUDA build?" + ) + + rangePushA = _fail + rangePop = _fail + markA = _fail + + _nvtx = _NVTXStub() # type: ignore[assignment] + +__all__ = ["range_push", "range_pop", "range_start", "range_end", "mark", "range"] + + +def range_push(msg): + """ + Push a range onto a stack of nested range span. Returns zero-based depth of the range that is started. + + Args: + msg (str): ASCII message to associate with range + """ + return _nvtx.rangePushA(msg) + + +def range_pop(): + """Pop a range off of a stack of nested range spans. Returns the zero-based depth of the range that is ended.""" + return _nvtx.rangePop() + + +def range_start(msg) -> int: + """ + Mark the start of a range with string message. It returns an unique handle + for this range to pass to the corresponding call to rangeEnd(). + + A key difference between this and range_push/range_pop is that the + range_start/range_end version supports range across threads (start on one + thread and end on another thread). + + Returns: A range handle (uint64_t) that can be passed to range_end(). + + Args: + msg (str): ASCII message to associate with the range. + """ + return _nvtx.rangeStartA(msg) + + +def range_end(range_id) -> None: + """ + Mark the end of a range for a given range_id. + + Args: + range_id (int): an unique handle for the start range. + """ + _nvtx.rangeEnd(range_id) + + +def mark(msg): + """ + Describe an instantaneous event that occurred at some point. + + Args: + msg (str): ASCII message to associate with the event. + """ + return _nvtx.markA(msg) + + +@contextmanager +def range(msg, *args, **kwargs): + """ + Context manager / decorator that pushes an NVTX range at the beginning + of its scope, and pops it at the end. If extra arguments are given, + they are passed as arguments to msg.format(). + + Args: + msg (str): message to associate with the range + """ + range_push(msg.format(*args, **kwargs)) + try: + yield + finally: + range_pop() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_pytree.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_pytree.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d27e469d1d620ce4cb3c9f13d5ebf9f512c202a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_pytree.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph_module.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..101e92f7b9323af59935e8054c522a8054274ec4 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph_module.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/immutable_collections.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/immutable_collections.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc2bb3227c0a89f5e38be2363211a0d053b673a3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/immutable_collections.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/proxy.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/proxy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0558e4fd44dc35ff655f7b7b0ecf72b9138f253 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/proxy.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/tensor_type.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/tensor_type.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b02e2a2ef2af181c62bf8545f6aa010cfb0a5ec Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/tensor_type.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/util.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a43d8f3ebbe060d8c7659b65a2dd924e34d2ce3b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/util.py @@ -0,0 +1,52 @@ +from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \ + BVar +from torch.fx.experimental.migrate_gradual_types.operation import op_leq + + +def gen_tvar(curr): + """ + Generate a tensor variable + :param curr: The current counter + :return: a tensor variable and the updated counter + """ + curr += 1 + return TVar(curr), curr + + +def gen_dvar(curr): + """ + Generate a dimension variable + :param curr: the current counter + :return: a dimension variable and an updated counter + """ + curr += 1 + return DVar(curr), curr + +def gen_bvar(curr): + """ + Generate a boolean variable + :param curr: the current counter + :return: a boolean variable and an updated counter + """ + curr += 1 + return BVar(curr), curr + +def gen_tensor_dims(n, curr): + """ + Generate a list of tensor dimensions + :param n: the number of dimensions + :param curr: the current counter + :return: a list of dimension variables and an updated counter + """ + dims = [] + for _ in range(n): + dvar, curr = gen_dvar(curr) + dims.append(dvar) + return dims, curr + + +def gen_nat_constraints(list_of_dims): + """ + Generate natural number constraints for dimensions + """ + return [BinConstraintD(0, d, op_leq) for d in list_of_dims] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0266d36a2ff1c9d5e5e68828bbe8b3c4d182182 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4932bd4fdfbd9f7935c9cef54b19e9e033c86718 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a30d02b6a986c0640ed8c32442d074e3e26218b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af0baa1b44aa34c9c343cbd9486cca7192d6adef Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab200aad19578e4999bd19890b2edaabe0670c20 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16a04405130c80a334e6ede4c58b3010056bb179 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f165d3d6a55575108077be4bda1231f4b8d8221 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c35f314e596854c48098e46f0ab8387ab6993eb Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_module.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0d247f617729052a62b751e8f53411abe143015 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_module.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6e228e5bce9c623dc4b18c53060b02dd3568c20 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61213deba34b01f8d7ddd8a87b24351c8815fa10 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..657b6a93014f428eece18ec896136c81bc3949f3 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__init__.py @@ -0,0 +1,2 @@ + +from . import pass_manager diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5a1bbb7cde32b33d28089417d8a52edc80e225a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e78b2a019e8758cb171ae2feb619bd3dd5fff678 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_base.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_base.py new file mode 100644 index 0000000000000000000000000000000000000000..dd699ea86cdecbe9f85af2b76b5503b3c8cbd0b5 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_base.py @@ -0,0 +1,75 @@ +import abc +from collections import namedtuple +from typing import Optional + +from torch.fx.graph_module import GraphModule +from torch.fx._compatibility import compatibility + + +__all__ = ['PassResult', 'PassBase'] + +@compatibility(is_backward_compatible=False) +class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): + """ + Result of a pass: + graph_module: The modified graph module + modified: A flag for if the pass has modified the graph module + """ + def __new__(cls, graph_module, modified): + return super().__new__(cls, graph_module, modified) + +@compatibility(is_backward_compatible=False) +class PassBase(abc.ABC): + """ + Base interface for implementing passes. + + It is required to implement the `call` function so that we can directly + pass instances of the Pass directly to the PassManager and call them as a + function. + + We can directly pass an instance of a class implementing this interface into + the PassManager's `passes` attribute. + """ + + def __call__(self, graph_module: GraphModule) -> Optional[PassResult]: + """ + Runs the precondition check, the pass itself, and the postcondition check. + """ + + self.requires(graph_module) + res = self.call(graph_module) + self.ensures(graph_module) + return res + + @abc.abstractmethod + def call(self, graph_module: GraphModule) -> Optional[PassResult]: + """ + The pass that is run through the given graph module. To implement a + pass, it is required to implement this function. + + Args: + graph_module: The graph module we will run a pass on + """ + pass + + def requires(self, graph_module: GraphModule) -> None: # noqa: B027 + """ + This function will be called before the pass is run and will check that + the given graph module contains the preconditions needed to run the + pass. It is not required to implement this function. + + Args: + graph_module: The graph module we will run checks on + """ + pass + + def ensures(self, graph_module: GraphModule) -> None: # noqa: B027 + """ + This function will be called after the pass is run and will check that + the given graph module contains the postconditions needed to run the + pass. It is not required to implement this function. + + Args: + graph_module: The graph module we will run checks on + """ + pass diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_manager.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..0adc75a1afd30a25df98bddfa6210e310b43fc92 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_manager.py @@ -0,0 +1,303 @@ +import inspect +import logging +from queue import Queue +from functools import wraps +from typing import Callable, Dict, List + +import torch.nn as nn +from torch.fx.graph_module import GraphModule +from torch.fx._compatibility import compatibility +from torch.fx.passes.infra.pass_base import PassResult + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +__all__ = ['pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager'] + +@compatibility(is_backward_compatible=False) +def pass_result_wrapper(fn: Callable) -> Callable: + """ + Wrapper for passes which currently do not return a PassResult. + This wrapper makes them return a PassResult containing the modified object + and True for the "modified" flag. + + Args: + fn (Callable[Module, Any]) + + Returns: + wrapped_fn (Callable[Module, PassResult]) + """ + if fn is None: + return None + + @wraps(fn) + def wrapped_fn(gm): + res = fn(gm) + if res is None: + return PassResult(gm, True) + if isinstance(res, PassResult): + return res + elif isinstance(res, nn.Module): + return PassResult(res, True) + + if not inspect.isfunction(fn): + wrapped_fn.__name__ = type(fn).__name__ + + return wrapped_fn + +def _validate_pass_schedule_constraint( + constraint: Callable[[Callable, Callable], bool], passes: List[Callable] +) -> None: + for i, a in enumerate(passes): + for j, b in enumerate(passes[i + 1 :]): + if constraint(a, b): + continue + raise RuntimeError( + f"pass schedule constraint violated. Expected {a} before {b}" + f" but found {a} at index {i} and {b} at index{j} in pass" + f" list." + ) + +def _topological_sort_passes( + passes: List[Callable], constraints: List[Callable] +) -> List[Callable]: + """ + Args + passes: Passes that we are ordering + constraints: Constraints applied on these passes + + Returns + A sorted list of callables and a boolean of if a circular dependency + existed + """ + if len(constraints) == 0: + return passes + + # Contruct a graph mapping nodes to a list of their users + graph: Dict[Callable, List[Callable]] = {p : [] for p in passes} + indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0) + candidates: Queue = Queue() + for a in passes: + for b in passes: + if a == b: + continue + + for constraint in constraints: + if not constraint(a, b): + graph[b].append(a) + indegree_map[a] += 1 + + if indegree_map[a] == 0: + candidates.put(a) + + visited: Dict[Callable, bool] = dict.fromkeys(passes, False) + sorted_passes: List[Callable] = [] + + while not candidates.empty(): + p = candidates.get() + sorted_passes.append(p) + visited[p] = True + + for n in graph[p]: + if not visited[n]: + indegree_map[n] -= 1 + if indegree_map[n] == 0: + candidates.put(n) + + # Check if there are unvisited nodes (aka cycles in the graph) + cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys())) + if len(cycle_passes) != 0: + error = f"Circular dependency detected within the following passes: {cycle_passes}" + raise RuntimeError(error) + + return sorted_passes + +@compatibility(is_backward_compatible=False) +def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable: + """ + Defines a partial order ('depends on' function) where `this` must occur + before `that`. + + For example, the following pass list and constraint list would be invalid. + ``` + passes = [pass_b, pass_a] + + constraints = [ + this_before_that_pass_constraint(pass_a, pass_b) + ] + ``` + + Args: + this (Callable): pass which should occur first + that (Callable): pass which should occur later + + Returns: + depends_on (Callable[[Object, Object], bool] + """ + + def depends_on(a: Callable, b: Callable): + if a == that and b == this: + return False + return True + + return depends_on + + +@compatibility(is_backward_compatible=False) +class PassManager: + """ + Construct a PassManager. + + Collects passes and constraints. This defines the pass schedule, manages + pass constraints and pass execution. + + Args: + passes (Optional[List[Callable]]): List of passes. A pass is a + callable which modifies an object and returns a PassResult + constraint (Optional[List[Callable]]): List of constraints. A + constraint is a callable which takes two passes (A, B) and returns + True if A depends on B and False otherwise. See implementation of + `this_before_that_pass_constraint` for example. + steps (int): Max number of times we run the passes (default = 1). + run_checks_after_each_pass (bool): Whether to run checks and linting + after each pass + suppress_check_failures (bool): Whether to raise errors when running + checks + """ + + passes: List[Callable[[nn.Module], PassResult]] + constraints: List[Callable[[Callable, Callable], bool]] + _validated: bool = False + steps: int = 1 + + def __init__( + self, + passes=None, + constraints=None, + steps=None, + run_checks_after_each_pass: bool = False, + suppress_check_failures: bool = False, + ): + self.passes = passes or [] + self.constraints = constraints or [] + if steps: + self.steps = steps + + self.run_checks_after_each_pass = run_checks_after_each_pass + self.suppress_check_failures = suppress_check_failures + + def add_pass(self, _pass: Callable): + """ + Adds a pass into the current list of passes. + """ + self.passes.append(_pass) + self._validated = False + + def add_constraint(self, constraint: Callable): + """ + Adds a constraint into the current list of constraints. + """ + self.constraints.append(constraint) + self._validated = False + + def validate_constraints(self): + """ + Validates that current pass schedule defined by `self.passes` is valid + according to all constraints in `self.constraints` + """ + if self._validated: + return + for constraint in self.constraints: + _validate_pass_schedule_constraint(constraint, self.passes) + self._validated = True + + def solve_constraints(self): + """ + Finds a valid traversal order based on the given constraints and orders + the passes based on this order. + + If a circular dependency exists between the constraints and steps = 1, + then we will raise an error because if steps != 1 this means that we + will re-run the passes, allowing for circular dependencies. + """ + self.passes = _topological_sort_passes(self.passes, self.constraints) + self._validated = True + + def add_checks(self, check: Callable) -> None: + """ + Adds a function which takes runs various checks on a given graph module. + This function is run before and after each pass if the + `run_checks_after_each_pass` flag is enabled. + """ + sig = inspect.signature(check) + + if len(list(sig.parameters.values())) != 1: + raise TypeError("PassManager check function should only take in one variable, a module") + + setattr(self, "check", check) # noqa: B010 + + def check(self, module: nn.Module) -> None: + pass + + def __call__(self, module: nn.Module) -> PassResult: + """ + Runs a list of passes in the order based on `self.passes` on the given + graph module. Each time a pass is run, checks and linting will be run on + the graph module if `run_checks_after_each_pass` is set. + + If the module is a graph module, we will run the list of passes until + the graph stops changing, or until `steps` number of times. + """ + # Order the passes based on the constraints + if not self._validated: + self.solve_constraints() + + # Check graph invariants + self.check(module) + + # Run the set of passes `steps` number of times or until the graph stops + # changing + overall_modified = False + for _ in range(self.steps): + modified = False + + # Run the set of passes on the graph module + for i, fn in enumerate(self.passes): + fn_name = fn.__name__ if inspect.isfunction(fn) else type(fn).__name__ + logger.debug("Running pass '%s'", fn_name) + + try: + res = fn(module) + + if not isinstance(res, PassResult) and not hasattr( + res, "graph_module" + ): + raise TypeError( + f"The result of the pass {fn_name} should be type PassResult." + + "Please wrap it with pass_result_wrapper()" + ) + module = res.graph_module + modified = modified or res.modified + + if isinstance(module, GraphModule): + logger.debug("Graph after pass '%s': %s", fn_name, module.graph) + module.recompile() + + # Check graph invariants + if self.run_checks_after_each_pass: + self.check(module) + + except Exception as e: + prev_pass_names = [ + p.__name__ if inspect.isfunction(p) else type(p).__name__ + for p in self.passes[:i] + ] + msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}" + raise Exception(msg) from e + + # If the graph no longer changes, then we can stop running these passes + overall_modified = overall_modified or modified + if not modified: + break + + return PassResult(module, overall_modified) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/reinplace.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/reinplace.py new file mode 100644 index 0000000000000000000000000000000000000000..6f6014b1c2aff40c74f86a845c9a6f2cfc1d5213 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/reinplace.py @@ -0,0 +1,675 @@ +import torch +from torch.fx import Node +from torch.fx._compatibility import compatibility +from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor +from torch.utils._pytree import tree_map_only +from torch.utils import _pytree as pytree +from torch.multiprocessing.reductions import StorageWeakRef + +import _operator +from enum import Enum +import itertools +from typing import Set, Dict +from collections import defaultdict + +__all__ = ['reinplace'] + +class _ViewType(Enum): + NonView = 0 + SingleOutputView = 1 + MultiOutputView = 2 + +def _is_view_op(tgt): + if tgt is not None and isinstance(tgt, torch._ops.OpOverload): + schema = tgt._schema + if len(schema.arguments) > 0: + first_arg = schema.arguments[0] + # check if op is a view + return first_arg.alias_info is not None and not first_arg.alias_info.is_write + +def _get_view_type(tgt) -> _ViewType: + if tgt is not None and isinstance(tgt, torch._ops.OpOverload): + schema = tgt._schema + if len(schema.arguments) > 0: + first_arg = schema.arguments[0] + # check if op is a view + if first_arg.alias_info is not None and not first_arg.alias_info.is_write: + # check if op is a multi-output view + if '*' in first_arg.alias_info.after_set: + return _ViewType.MultiOutputView + else: + return _ViewType.SingleOutputView + return _ViewType.NonView + + +# Stores a bunch of metadata related to functionalization each node. +# Relevant metadata: +# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTenors instead of Tensors) +# The fake tensor output from running the current node +# n.meta['view_of']: Node +# If the current node n is a view of some base tensor, the 'view_of' field tells us which +# view node was used to generate the current node (a view tensor). +# This information actually makes `fake_result` redundant, but we can use `fake_result` +# to sanity check that our aliasing information is correct. +@compatibility(is_backward_compatible=False) +class _FunctionalizationMetadataProp(torch.fx.Interpreter): + + def run_node(self, node: Node): + self.node_counter += 1 + result = super().run_node(node) + node.meta['fake_result'] = result + node.meta['node_idx'] = self.node_counter + + # (1) Update metadata with the list of nodes that are used by this node + # copy_() doesn't read from its first argument; it writes to it, overwriting previous data. + # We don't want to treat it as "being used as an input". + node_args = node.args + if node.target is torch.ops.aten.copy_.default: + node_args = node_args[1:] + + # (2) Update metadata to track aliasing information about view tensor nodes. + if node.op == 'call_function': + view_type = _get_view_type(node.target) + if view_type == _ViewType.SingleOutputView: + assert isinstance(node.args[0], Node) + node.meta['view_of'] = node.args[0] + elif view_type == _ViewType.MultiOutputView: + self.multi_output_view_nodes[node] = node.args[0] + + # Check if we returned a multi-output view, + # and we're now grabbing the individual views from the output. + # + # For multi-output views, we want to map each output view to the base, + # but this mapping involves two separate nodes in FX IR. + # e.g. "a, b = x_1.split(...)" becomes: + # %split_tensor : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {}) + # %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {}) + # %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {}) + # And we'd like to set: + # getitem1.meta['view_of'] = x_1 + elif node.target is _operator.getitem: + list_arg = node.args[0] + maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None) + if maybe_base_of_view is not None: + # Note: we could also track indexing info here for multi-output views. + # I don't think this metadata is strictly needed for de-functionalization. + assert isinstance(maybe_base_of_view, Node) + node.meta['view_of'] = maybe_base_of_view + + if 'view_of' in node.meta: + # We're linking the current node with its first argument as views. + # Assert here that this is actually the case, and their storages are the same. + assert isinstance(node.meta['fake_result'], FakeTensor) + assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor) + view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage()) + base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage()) + assert view_storage == base_storage + return result + + + + def propagate(self, *args): + self.multi_output_view_nodes = {} + self.node_counter = -1 + + with FakeTensorMode() as mode: + fake_args = [mode.from_tensor(a) for a in args] + return super().run(*fake_args) + +def _schemas_match(functional_schema, inplace_schema): + names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name + arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all( + a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments)) + # for the inplace op, its first argument should be mutable + assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write + # and its remaining arguments shouldn't be. + assert all(a.alias_info is None for a in inplace_schema.arguments[1:]) + return names_match and arg_types_match + +# TODO: this should be beefed up to be able to properly re-inplace with: +# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper) +# - out= ops (e.g. angle -> angle.out) +# TODO: we should also figure this info out using torchgen. +def _maybe_get_inplace_op(op): + # __module__ seems broken; it returns torch._ops.aten which doesn't exist + if not isinstance(op, torch._ops.OpOverload): + return None + # Some view ops have inplace variants (as_strided_, etc), + # but we do NOT want the reinplacing pass to directly add these into the program. + # (they'll require extra special handling, aren't aren't really useful for perf anyway) + if _is_view_op(op): + return None + op_namespace = op.__module__.split(".")[-1] + op_base_name = op.overloadpacket.__name__ + maybe_namespace_module = getattr(torch.ops, op_namespace) + maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None) + if maybe_inplace_op is None: + return None + + inplace_overloads = [ + getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads() + ] + inplace_overloads_with_matching_schemas = [ + f + for f in inplace_overloads + if _schemas_match(op._schema, f._schema) + ] + # Just because foo() and foo_() are both existing operators, + # They aren't guaranteed to have compatible schemas. + # For example, pow.Scalar(Scalar self, Tensor exponent) has no valid inplace variant, + # Even though several overloads of pow_ exist. + if len(inplace_overloads_with_matching_schemas) == 0: + return None + assert len(inplace_overloads_with_matching_schemas) == 1 + inplace_op = inplace_overloads_with_matching_schemas[0] + return inplace_op + +_VIEW_INVERSE_MAP = { + torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, + torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, + torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor, + torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default, +} + +# This function, given a set of set of (aliased) tensor nodes, +# Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index +# in the node ordering. +def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int): + def _add_if_tensor(x, set_): + if isinstance(x, FakeTensor): + set_.add(StorageWeakRef(x._typed_storage())) + + nodes_used_after = set() + for t in tensor_aliases: + # get all nodes that use the current alias + usage_nodes = t.users + for n in usage_nodes: + # We only care about usages after the current node + if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index: + continue + # We also don't care about intermediate view ops. + # They only matter if their output is then used elsewhere + # (either in an out-of-place op, or as an output to the function). + if n in tensor_aliases: + if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem: + continue + nodes_used_after.add(n) + return nodes_used_after + +# Given an op that we're trying to re-inplace, "b = foo(a)", +# And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)" +# Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF: +# If there are any aliases in the alias_set(a) that satisfy: +# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base" +# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata +# as "alias" +def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]: + def matching_view_metadata(a, b): + return a.size() == b.size() and \ + a.stride() == b.stride() and \ + a.storage_offset() == b.storage_offset() + + view_inverse_nodes = set() + # Go through them in node order, so we can see chains of view_scatter ops. + for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']): + if n.target not in _VIEW_INVERSE_MAP: + continue + base = n.args[0] + mutated_view = n.args[1] + assert isinstance(base, Node) + assert isinstance(base.meta['fake_result'], FakeTensor) + assert isinstance(mutated_view, Node) + assert isinstance(mutated_view.meta['fake_result'], FakeTensor) + # Check that this view_inverse op actually corresponds to taking doing the inverse + # of one of our existing self_alias nodes. + original_view = _VIEW_INVERSE_MAP[n.target] + for self_alias in self_aliases: + # We're looking for some alias of the self arg, "alias", + # that was created from some op `alias = foo(base, args...)` + # such that the current _scatter op "inverts" that foo call. + # We can check that by running the original op again, and checking that the strides match. + if 'view_of' not in self_alias.meta: + continue + self_alias_base = self_alias.meta['view_of'] + try: + # The we're trying to re-use the args from the view_scatter call inside of the corresponding + # view op, which might throw. This just indicates that view_scatter op isn't a valid inverse + # of the current alias we're looking at. + view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs) + expected_metadata = self_alias.meta['fake_result'] + # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace. + if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \ + matching_view_metadata(view_replay_metadata, expected_metadata): + view_inverse_nodes.add(n) + except Exception: + continue + + return view_inverse_nodes + + +@compatibility(is_backward_compatible=True) +def reinplace(gm, *sample_args): + """ + Given an fx.GraphModule, modifies it to perform "reinplacing", + mutating the nodes of the graph. + We look for out-of-place op call sites like `b = a.add(...)`, + and convert them to be inplace (`b = a.add_(...)`), + as long as the input to the current operator ("a") isn't re-used + anywhere later in the graph. + + This pass currently expects to operate on a **functional, ATen** graph. + This can be obtained by running `make_fx(functionalize(f))`. + + Sample inputs are needed to determine aliasing relationships of the inputs. + In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the + inputs to the program. + + Given a node "b = foo(a, args...) the algorithm for re-inplacing is as follows: + + (1) Perform some initial checks on the metadata of "a" and "args..." + that can disqualify them from being reinplaced. + + (1a) Check that the self argument we're attempting to reinplace + has acceptable dtype/size metadata to reinplace with. + + For example, if we have: + a = torch.ones(1) + b = torch.ones(10) + out = torch.add(a, b) + We can't turn that into + a.add_(b) + Because that would require resizing "a". + + Similarly, we can't convert torch.ge(a, b) into a.ge_(b), + because that would require changing a's dtype (from e.g. float32 to bool). + Note that in this specific example, we could technically do better.. + + If we see the pattern: + a_1 = a.ge(b) + a_2 = aten._to_copy(a_1, a.dtype) + Then we this should be valid to completely re-inplace + (this is exactly what functionalization will emit when it sees a.ge_(b)). + + This optimization is only really important for user programs + that directly use inplace comparison ops though. + + We also cannot re-inplace on tensors that have overlapping memory, + e.g. torch.ones(1).expand(4, 4).add_(1) + + (1b) Check if "a" is an alias of any of the program inputs. + + If it is, skip and move to the next node. + Inplace'ing an op that would cause it to mutate a program is not sound, + because that would be a side effect visible to the user. + + NOTE: there's a future optimization that we should make: + if "a" is a (alias of a) program input, but later in the program + there is a node that looks like "a.copy_(...)", + Then re-inplacing is ok to do - we are temporarily re-using a's buffer, + which will later be overwritten by the copy_() call. + + This will be an important optimization to have for programs that mutate + their inputs. It currently isn't implemented though. + + (1c) Check if "a" and "args..." alias + + For example, re-inplacing to create code like the below + isn't guaranteed to be sound: + + aten.mul_(a, a) + + (2) Check that "a" and all of its outstanding aliases are not used anywhere + later in the graph. If this is the case, then it's safe to re-inplace + to "b = foo_(a)". + + There are a few caveats to this, explained in more detail below: + (a) If "a" is used later as an argument to a view op, that is okay. + It's only a problem if "a" (or that view) is later passed + into a normal operator, or if it is returned as the program output. + (b) If "a" is a repeat argument in `foo()`, then don't reinplace. + Most ATen kernels don't make any guarantees that this is sound, + e.g. if you do aten.mul_(a, a). + So we'll just ban re-inplacing in this case. + It's only a problem if "a" (or that view) is later passed + (c) If "a" is used as an input into a view "inverse" / "scatter" + operator, it is potentially fine to re-inplace + (and remove that scatter operator from the graph). + See below for a more detailed example. + + NOTE: there is an optimization in this step that is crucial + to fully recovering performance from functionalization. + + Given this program: + def f(x): + a = torch.ops.aten.add(x, x) + b = torch.ops.aten.diagonal(a) + torch.ops.aten.fill_(b, 0) + return d + + Functionalization will emit the following: + def f(x): + a = torch.ops.aten.add(x, x) + b = torch.ops.aten.diagonal(a, 0, 1) + b_updated = torch.ops.aten.fill(b, 0) + a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1) + return a_updated + + Ordinarily, we would not be able to reinplace the fill, + because "b" aliases with "a" which is used by the diagonal_scatter call. + + "re-inplacing" is on the hook for figuring out that it is ok to + completely, the expensive diagonal_scatter call, if we re-inplace the add(). + + So, for every `alias in alias_set(a)`, instead of checking + that "alias" is not used anywhere later in the graph, + we check that + EITHER: + (a) alias is not used anywhere later in the graph + OR: + (b) alias is used exactly once later on in the graph, + in the following op: + + out = foo_scatter(alias, x, args...) + + where the following must hold: + (i) "foo_scatter" is the "inverse" operator for foo. + This only applies to "foo" ops that are view operators, + which view into a subset of the original tensor's memory. + In practice, there are ~4 operators where this applies: + diagonal -> diagonal_scatter + slice -> slice_scatter + select -> select_scatter + as_strided -> as_strided_scatter + (ii) "args..." are the same between the foo() and foo_scatter() calls. + + (3) Perform the actual re-inplacing on foo! + + (3b) is the common case, but special care is needed for {view}_scatter (3a) + + (3a) {view}_scatter ops. + + Consider this program: + a = torch.zeros(2, 2) + b = torch.ones(2) + a[0] = b + + Post functionalization, that will look like: + a = torch.zeros(2) + b = torch.ones(1) + a_updated = torch.select_scatter(a, b, 0, 0) + + In this case though, there is no "functional" op to re-inplace! + Instead, we'd like to directly remove toe select_scatter call. + We already know from (3) that this is valid, + because "a" has no later usages in the graph. + + We perform the re-inplacing on the {view}_scatter op like so + Before: + a_updated = torch.select_scatter(a, b, args...) + After: + a_slice = a.select(a, args...) + a_slice.copy_(b) + + (3b) Otherwise, replace the functional op with its inplace variant. + Before: + b = foo(a, args...) + After: + a.foo_(args...) + + (4) Finally, after converting either: + Before: + b = foo(a) + After: + foo_(a) + or + Before: + b = {slice}_scatter(a, mutated_slice, args...) + After: + slice = {slice}(a, args...) + slice.copy_(mutated_slice) + + We now need to find all later nodes that use "b" as an argument + and update them to take in "a" instead. + + Note that for the majority of inplace ops, this isn't actually necessary + (because most inplace ops return "self" as their output). + This isn't generally true for all mutable ops though, which is why + we need to actually replace all of the arguments. + + We also need to update our metadata of Dict[StorageWeakRef, Set[Node]], + That maps a given tensor storage to the set of all nodes that take in that storage + as an input. + Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused + together. + + (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them" + during step (3) get manually deleted from the graph. + Their outputs are no longer used, so technically standard DCE would be able + to do this, but we can no longer run FX's DCE pass now that we have mutable + ops in the graph. + """ + _FunctionalizationMetadataProp(gm).propagate(*sample_args) + + # Useful debug printing + # def _print(x): + # if isinstance(x, FakeTensor): + # print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}') + + # for n in gm.graph.nodes: + # print(n.format_node()) + # if hasattr(n, 'meta'): + # print(f'node_idx: {n.meta["node_idx"]}') + # if 'fake_result' in n.meta: + # tree_map(_print, n.meta['fake_result']) + # if 'view_of' in n.meta: + # print(f'view_of: {str(n.meta["view_of"])}') + # print() + + # We need to know which nodes correspond to inputs (or their aliases) + # so we know not to re-inplace them. + # NOTE: later, we'll need to add an optimization for fully recovering performance + # on programs that mutate inputs. + input_storages = { + StorageWeakRef( + node.meta['fake_result']._typed_storage() + ) for node in gm.graph.nodes if node.op == 'placeholder'} + + + # We also need to know for a given node, what are all of its aliasing nodes. + storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set) + for n in gm.graph.nodes: + if 'fake_result' in n.meta: + # Tree-mapping because some ops can return lists of tensors. + def _add_to_map(x): + if isinstance(x, FakeTensor): + storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n) + pytree.tree_map_(_add_to_map, n.meta['fake_result']) + + # inplace-ify functional ops, subject to the constraints written below. + all_later_view_inverse_nodes_to_delete = set() + for idx, node in enumerate(gm.graph.nodes): + if node.op == 'call_function': + + # Today, the re-inplace pass on directly acts on: + # - functional ops with an inplace variant + # - {view}_scatter ops that can be potentially removed from the graph. + # Both of these ops take in tensor first args, so filtering on this condition + # makes the later code simpler. + # We should revisit this at some point though, particularly when we also want + # the reinplacer to be able to handle out= and mutable operators + # and tensorlist first args (like `_foreach_` ops). + if not isinstance(node.target, torch._ops.OpOverload): + continue + if len(node.target._schema.arguments) < 1: + continue + if type(node.target._schema.arguments[0].type) != torch.TensorType: + continue + + # Step 1a: Check that the self argument we're attempting to reinplace + # has the same size/stride as the output. + # For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor) + # As it would require resizing scalar_tensor. + # (We could potentially swizzle this into larger_tensor.add_(scalar_tensor), + # this is probably an optimization to revisit later). + self_arg = node.args[0] + self_flattened = pytree.tree_leaves(self_arg.meta['fake_result']) + node_flattened = pytree.tree_leaves(node.meta['fake_result']) + self_has_wrong_metadata = False + if len(self_flattened) == len(node_flattened): + for self_meta, node_meta in zip(self_flattened, node_flattened): + if self_meta.numel() != node_meta.numel(): + self_has_wrong_metadata = True + if self_meta.dtype != node_meta.dtype: + self_has_wrong_metadata = True + # We also cannot re-inplace on tensors that have internal memory overlap. + # e.g. torch.ones(1).expand(4, 4).add_(1) + if torch._debug_has_internal_overlap(self_meta) == 1: + self_has_wrong_metadata = True + # Here, we (optimistically) assume that a.resize(b) is valid to re-inplace, + # Since users should never really be calling the functional "torch.ops.aten.resize" + # op directly in their programs. + if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default: + continue + + # Step 1b: ensure that the op we're trying to re-inplace isn't a program input + self_arg_name = self_arg.name + self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage()) + if self_arg_storage in input_storages: + # TODO: later, add the optimization for handling `copy_()` calls in the graph. + continue + if len([x for x in node.args if x is self_arg]) > 1: + # Step 1c: + # Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound, + # so we prevent re-inplacing in this case. + continue + + self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage()) + self_aliases = storage_to_nodes[self_arg_storage] + + # First, we find all later usages of any of the aliases of self_arg. + later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx']) + # Then, we check if any of those later usages are actually view_scatter ops + # that are safe to fully remove. + later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases) + + # Step 2: Check to see if the input to the op is re-used later in the graph. + # If not (same goes for its aliases), then this op is safe to re-in place. + # This is a slightly roundabout way to check that there are no later usages of the current self argument. + # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete) + can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0 + if not can_reinplace: + continue + + # Step 3a: Special handling for when we see *_scatter operators. + # When we see an operator like `b = torch.slice_scatter(a, ...)`, + # instead of trying to "inplace" it into a.slice_scatter_(..._), + # we would prefer to remove it from the graph entirely, + # and instead copy_() the slice directly into the larger tensor. + # See the description of the algorithm for a full example. + if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete: + view_op = _VIEW_INVERSE_MAP[node.target] + # Before: + # base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...) + # After: + # slice = torch.ops.aten.slice.default(base, args...) + # slice.copy_(mutated_slice) + with gm.graph.inserting_before(node): + mutated_slice_node = node.args[1] + remaining_slice_args = node.args[2:] + slice_node = gm.graph.create_node( + 'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs) + copy_node = gm.graph.create_node( + 'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {}) + # Add the slice_scatter node to our "nodes to delete" list. + all_later_view_inverse_nodes_to_delete.add(node) + + + else: + # Step 3b: Check to see if this operator has an inplace variant. + maybe_inplace_op = _maybe_get_inplace_op(node.target) + if maybe_inplace_op is None: + continue + # And if so, replace it with its inplace variant. + node.target = maybe_inplace_op + + # At this point, 'storage_to_nodes' will be stale. + # Now that we're inplacing `b = foo(a)`, we need to effectively + # union together the dict values for b and a's storage. + # Hmm... morally I think we also want to keep the `fake_result` metadata + # up to date here, but I'm not sure how easy it is to do. + # Maybe it's fine to wait until the end of the pass to update it. + curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage()) + storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage]) + storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage]) + + # Need to remember the view_scatter view nodes we found so we can remove them alter. + all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages) + + # Step 4: + # Now that we've replaced b = a.foo() with a.foo_(), + # We need to replace any later usages of "b" with "a" + for old in itertools.chain([node], later_view_inverse_node_usages): + new = old.args[0] + nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']] + for node_to_update in nodes_to_update: + new_args = [] + args = node_to_update.args + + def replace_arg(a): + if a == old: + return new + return a + + # First, replace usages of "b" with "a" + node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args) + node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs) + + # Second, update our storage_to_nodes data structure. + old_flattened_res = pytree.tree_leaves(old.meta['fake_result']) + node_flattened_res = pytree.tree_leaves(node_to_update.meta['fake_result']) + + old_res_storage = { + StorageWeakRef( + x._typed_storage() + ) for x in old_flattened_res if isinstance(x, FakeTensor)} + node_res_storage = { + StorageWeakRef( + x._typed_storage() + ) for x in node_flattened_res if isinstance(x, FakeTensor)} + + # This will happen if we're updating a view op, e.g. + # e.g. replacing + # x = view(old) + # x = view(new) + # When that happens, we need to make sure to keep our + # storage mapping up to date. + # + # We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor, + # or multiple tensors that all share the same storage. + # We can't just check equality because we might encounter FX nodes that return zero tensor outputs. + if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage: + new_flattened_res = pytree.tree_leaves(new.meta['fake_result']) + new_res_storage = { + StorageWeakRef( + x._typed_storage() + ) for x in new_flattened_res if isinstance(x, FakeTensor)} + assert len(new_res_storage) == 1 + (old_ref,) = old_res_storage + (new_ref,) = new_res_storage + (node_ref,) = node_res_storage + # Technically, "old_ref" and all its aliases will remain + # in our mapping. + # That should be fine though, since we deleted "old" + # from the graph at this point. + storage_to_nodes[node_ref].update(storage_to_nodes[new_ref]) + storage_to_nodes[new_ref].update(storage_to_nodes[node_ref]) + + # Step 4: delete any _scatter nodes that we de-functionalized + # Need to take care not to delete any of these nodes until after *all* modifications + # to the graph are finished. + for to_delete in all_later_view_inverse_nodes_to_delete: + gm.graph.erase_node(to_delete) + + + gm.recompile() + return gm diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..735f1030b355c06f946b52d38c6e2f5b00a6cda6 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/source_matcher_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/source_matcher_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2830f60d5eab1cc188826313b640ffcf0c00d94a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/source_matcher_utils.py @@ -0,0 +1,144 @@ +from dataclasses import dataclass, field +from torch.fx.graph import Graph +from torch.fx.node import Node +from torch.fx._compatibility import compatibility +from typing import Dict, List, Any, Type, Optional, Callable +import logging +import os + + +__all__ = ['get_source_partitions', 'check_subgraphs_connected', 'SourcePartition'] + +# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs +def _init_logger(): + logger = logging.getLogger(__name__) + + level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper() + logger.setLevel(level) + console = logging.StreamHandler() + formatter = logging.Formatter("%(filename)s > %(message)s") + console.setFormatter(formatter) + console.setLevel(level) + # add the handlers to the logger + logger.addHandler(console) + logger.propagate = False + return logger + +logger = _init_logger() + + +@compatibility(is_backward_compatible=False) +@dataclass +class SourcePartition: + # Nodes in a particular partition + nodes: List[Node] + + # The source these nodes decomposed from + source: Any + + # Nodes in the graph that are needed as inputs to the partition + input_nodes: List[Node] = field(default_factory=list) + + # Nodes in the partition that are being used by nodes outside of the + # partition + output_nodes: List[Node] = field(default_factory=list) + + # Parameters that are being used + params: List[Node] = field(default_factory=list) + + +@compatibility(is_backward_compatible=False) +def get_source_partitions( + graph: Graph, + wanted_sources: List[Any], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Dict[Any, List[SourcePartition]]: + """ + Args: + graph: The graph we want to partition + wanted_sources: List of sources of nodes that were decomposed from this + source. This can be a function (ex. torch.nn.functional.linear) or a + leaf module type (ex. torch.nn.Linear). + + Returns: + Dictionary mapping sources that were given to a list of SourcePartitions + that correspond to the list of nodes that were decomposed from the given + source. + """ + modules: Dict[Type, Dict[str, List[Node]]] = {} + + for node in graph.nodes: + # The metadata source_fn should contain a tuple of a unique name for the + # source, and the source function if the node is decomposed from a + # function, or the type of module if the node is decomposed from a leaf + # module + + if (source_fn_st := node.meta.get("source_fn_stack", None)) is None: + continue + + source_fn = source_fn_st[-1] + if source_fn[1] not in wanted_sources: + continue + + diff_modules = modules.setdefault(source_fn[1], {}) + partition = diff_modules.setdefault(source_fn[0], []) + partition.append(node) + + def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition: + input_nodes = set() + output_nodes = set() + params = set() + for node in nodes: + for arg in node.args: + if isinstance(arg, Node) and arg not in nodes: + input_nodes.add(arg) + + if node.op == "get_attr": + params.add(node) + + for user in node.users.keys(): + if user not in nodes: + output_nodes.add(node) + + return SourcePartition( + nodes, + module_type, + list(input_nodes), + list(output_nodes), + list(params), # type: ignore[arg-type] + ) + + ret: Dict[Type[Any], List[SourcePartition]] = {} + + if filter_fn: + # for each partition, we apply filter_fn to filter out all partitions that doesn't satisfy the + # filter condition + filtered_modules = {} + for tp, name_to_partition in modules.items(): + filtered_name_to_partition = { + name: partition + for name, partition in name_to_partition.items() + if all(map(filter_fn, partition)) + } + filtered_modules[tp] = filtered_name_to_partition + modules = filtered_modules + + for k, v in modules.items(): + ret[k] = [make_partition(partition, k) for partition in v.values()] + + return ret + + +@compatibility(is_backward_compatible=False) +def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool: + """ + Given two subgraphs A and B (in the form of a list of nodes), checks if + A has nodes connecting to at least one node in B -- aka there exists a node + in B that uses a node in A (not the other way around). + """ + + for node in reversed(subgraph1.nodes): + for user in node.users.keys(): + if user in subgraph2.nodes: + return True + return False diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8cbb1fb07ff885d5fc4d26667e5fb4a1670efb9e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__init__.py @@ -0,0 +1,78 @@ +"""torch.multiprocessing is a wrapper around the native :mod:`multiprocessing` module. + +It registers custom reducers, that use shared memory to provide shared +views on the same data in different processes. Once the tensor/storage is moved +to shared_memory (see :func:`~torch.Tensor.share_memory_`), it will be possible +to send it to other processes without making any copies. + +The API is 100% compatible with the original module - it's enough to change +``import multiprocessing`` to ``import torch.multiprocessing`` to have all the +tensors sent through the queues or shared via other mechanisms, moved to shared +memory. + +Because of the similarity of APIs we do not document most of this package +contents, and we recommend referring to very good docs of the original module. +""" +import multiprocessing +import sys + +import torch +from .reductions import init_reductions + +__all__ = ["set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies"] + + +from multiprocessing import * # noqa: F403 + + +__all__ += multiprocessing.__all__ # noqa: PLE0605 type: ignore[attr-defined] + + +# This call adds a Linux specific prctl(2) wrapper function to this module. +# See https://github.com/pytorch/pytorch/pull/14391 for more information. +torch._C._multiprocessing_init() + + +"""Add helper function to spawn N processes and wait for completion of any of +them. This depends `mp.get_context` which was added in Python 3.4.""" +from .spawn import ( + ProcessContext, + ProcessExitedException, + ProcessRaisedException, + spawn, + SpawnContext, + start_processes, +) + + +if sys.platform == "darwin" or sys.platform == "win32": + _sharing_strategy = "file_system" + _all_sharing_strategies = {"file_system"} +else: + _sharing_strategy = "file_descriptor" + _all_sharing_strategies = {"file_descriptor", "file_system"} + + +def set_sharing_strategy(new_strategy): + """Set the strategy for sharing CPU tensors. + + Args: + new_strategy (str): Name of the selected strategy. Should be one of + the values returned by :func:`get_all_sharing_strategies()`. + """ + global _sharing_strategy + assert new_strategy in _all_sharing_strategies + _sharing_strategy = new_strategy + + +def get_sharing_strategy(): + """Return the current strategy for sharing CPU tensors.""" + return _sharing_strategy + + +def get_all_sharing_strategies(): + """Return a set of sharing strategies supported on a current system.""" + return _all_sharing_strategies + + +init_reductions() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8254cb074cd23dede627b78ed46fa3b2a80f10c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc2bc9661492bb71819d8b821473b1c298f4f839 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35ab6616f9a73c3e0c78a055c6ebd1c00954d453 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/_atfork.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/_atfork.py new file mode 100644 index 0000000000000000000000000000000000000000..92a3280fee78b538230dfa63862c4681c1a5b186 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/_atfork.py @@ -0,0 +1,33 @@ +import sys + +__all__ = ["register_after_fork"] + +if sys.platform == "win32": + import multiprocessing.util as _util + + def _register(func): + def wrapper(arg): + func() + + _util.register_after_fork(_register, wrapper) + +else: + import os + + def _register(func): + os.register_at_fork(after_in_child=func) + + +def register_after_fork(func): + """Register a callable to be executed in the child process after a fork. + + Note: + In python < 3.7 this will only work with processes created using the + ``multiprocessing`` module. In python >= 3.7 it also works with + ``os.fork()``. + + Args: + func (function): Function taking no arguments to be called in the child after fork + + """ + _register(func) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fca8055ad253e48bed216dfc43a34a8f11a99913 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__init__.py @@ -0,0 +1,117 @@ +""" This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """ +import contextlib +from typing import List, Union +from warnings import warn + +from torch.backends.cuda import ( + can_use_efficient_attention, + can_use_flash_attention, + enable_flash_sdp, + enable_math_sdp, + enable_mem_efficient_sdp, + flash_sdp_enabled, + math_sdp_enabled, + mem_efficient_sdp_enabled, + SDPAParams, +) + +__all__: List[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"] + +# Note: [SDPA warnings] +# TODO: Consider using this for sdpa regardless of subclasses +# This only effects users of bias subclasses +# If this is set to True, we will warn the user if they are not using the fused kernels +# As well, it will raise warnings for all the reasons why the fused kernels can't be run. +# To set this to True, run +# torch.nn.attention.WARN_FOR_UNFUSED_KERNELS = True +WARN_FOR_UNFUSED_KERNELS = False + + +from torch._C import _SDPBackend as SDPBackend + +# Hacks for Sphinx documentation: +# https://stackoverflow.com/questions/38765577/overriding-sphinx-autodoc-alias-of-for-import-of-private-class +SDPBackend = SDPBackend +r"""An enum-like class that contains the different backends for scaled dot product attention. + This backend class is designed to be used with the sdpa_kernel context manager. + + The following Enums are available: + - ERROR: An error occurred when trying to determine the backend. + - MATH: The math backend for scaled dot product attention. + - FLASH_ATTENTION: The flash attention backend for scaled dot product attention. + - EFFICIENT_ATTENTION: The efficient attention backend for scaled dot product attention. + - CUDNN_ATTENTION: The cuDNN backend for scaled dot product attention. + + See :func:`torch.nn.attention.sdpa_kernel` for more details. + + .. warning:: This class is in beta and subject to change. +""" +SDPBackend.__module__ = __name__ +SDPBackend.__name__ = "SDPBackend" + + +def _raise_kernel_warnings(params: SDPAParams) -> None: + """ + If WARN_FOR_UNFUSED_KERNELS is set to True, this will raise warnings + for all the reasons why the fused kernels can't be run. If using subclasses + """ + if WARN_FOR_UNFUSED_KERNELS: + if not can_use_efficient_attention(params): + warn("Efficient attention can't be used because:") + can_use_efficient_attention(params, True) + if not can_use_flash_attention(params): + warn("Flash attention can't be used because:") + can_use_flash_attention(params, True) + + +@contextlib.contextmanager +def sdpa_kernel(backends: Union[List[SDPBackend], SDPBackend]): + r""" + Context manager to select which backend to use for scaled dot product attention. + + .. warning:: This function is beta and subject to change. + + Args: + backend (Union[List[SDPBackend], SDPBackend]): A backend or list of backends for scaled dot product attention. + + Example: + + .. code-block:: python + + from torch.nn.functional import scaled_dot_product_attention + from torch.nn.attention import SDPBackend, sdpa_kernel + # Only enable flash attention backend + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + scaled_dot_product_attention(...) + + # Enable the Math or Efficient attention backends + with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]): + scaled_dot_product_attention(...) + + This context manager can be used to select which backend to use for scaled dot product attention. + Upon exiting the context manager, the previous state of the flags will be restored, enabling all backends. + """ + assert isinstance( + backends, (list, SDPBackend) + ), "Backend must be an instance of SDPBackend or a list of SDPBackend instances" + + if isinstance(backends, SDPBackend): + backends = [backends] + + backends = set(backends) + previous_flash: bool = flash_sdp_enabled() + previous_mem_efficient: bool = mem_efficient_sdp_enabled() + previous_math: bool = math_sdp_enabled() + try: + enable_flash = SDPBackend.FLASH_ATTENTION in backends + enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION in backends + enable_math = SDPBackend.MATH in backends + + enable_flash_sdp(enable_flash) + enable_mem_efficient_sdp(enable_mem_efficient) + enable_math_sdp(enable_math) + yield {} + finally: + enable_flash_sdp(previous_flash) + enable_mem_efficient_sdp(previous_mem_efficient) + enable_math_sdp(previous_math) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e29e8caa987d280b4ffb0d41aa6ef9dbc4caded Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/bias.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/bias.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e2d3f47a483c2b586a8aa1f000a0b70f488f6d4 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/bias.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/bias.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/bias.py new file mode 100644 index 0000000000000000000000000000000000000000..d54ed8915789d4ac2cd9c328c95003e4c27e7e43 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/bias.py @@ -0,0 +1,353 @@ +"""Defines bias subclasses that work with scaled_dot_product_attention""" +from enum import auto, IntEnum +from typing import Optional +from warnings import warn + +import torch +from torch.backends.cuda import ( + can_use_efficient_attention, + can_use_flash_attention, + SDPAParams, +) +from torch.nn.attention import _raise_kernel_warnings +from torch.nn.attention._utils import ( + _calculate_scale, + _input_requires_grad, + _postprocess_flash_output, + _validate_sdpa_input, +) +from torch.nn.functional import scaled_dot_product_attention + +__all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"] + + +torch._dynamo.allow_in_graph(can_use_flash_attention) +torch._dynamo.allow_in_graph(can_use_efficient_attention) +torch._dynamo.allow_in_graph(SDPAParams) + + +class CausalVariant(IntEnum): + r""" + Enum for causal variants used in attention mechanisms. + + Defines two types of causal biases: + + `UPPER_LEFT`: Represents upper-left triangular bias for standard causal attention. + The equivalent pytorch code for constructing this bias is: + + .. code-block:: python + + torch.tril(torch.ones(size, dtype=torch.bool)) + + For instance, with `shape=(3,4)`, the materialized bias tensor will be: + + .. code-block:: text + + [[1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0]] + + + `LOWER_RIGHT`: Represents lower-right triangular bias, the include values are aligned to the lower + right corner of the matrix. + + The equivalent pytorch code for constructing this bias is: + + .. code-block:: python + + diagonal_offset = size[1] - size[0] + torch.tril( + torch.ones(size, dtype=torch.bool), + diagonal=diagonal_offset, + ) + + For instance, with `shape=(3,4)`, the materialized bias tensor will be: + + .. code-block:: text + + [[1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1]] + + Note that these variants are equivalent to each other when the sequence lengths of the query and key/value + tensors are equal since the triangular matrix is square. + + .. warning:: This enum is a prototype and subject to change. + """ + + UPPER_LEFT = auto() + LOWER_RIGHT = auto() + + +class CausalBias(torch.Tensor): + """ + A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum. + + This class is used for defining causal (triangular) attention biases. For construing the bias, there exist + two factory functions: :func:`causal_upper_left` and :func:`causal_lower_right`. + + Example: + + .. code-block:: python + + from torch.nn.attention.bias import causal_lower_right + + bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8 + + # Create a lower-right causal bias + attn_bias = causal_lower_right(seqlen_q, seqlen_kv) + + q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16) + + out = F.scaled_dot_product_attention(q, k, v, attn_bias) + + .. warning:: This class is a prototype and subject to change. + """ + + def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int): + """ + Initializes the CausalBias instance with a specified variant and sequence lengths. + + Args: + variant (CausalVariant): The type of causal bias to use (either UPPER_LEFT or LOWER_RIGHT). + seq_len_q (int): The sequence length of the query tensor. + seq_len_kv (int): The sequence length of the key/value tensor. + + Raises a warning if the LOWER_RIGHT variant is used with seq_len_q > seq_len_kv, as it may produce NaNs. + """ + assert isinstance(variant, CausalVariant) + self.variant = variant + self.seq_len_q = seq_len_q + self.seq_len_kv = seq_len_kv + if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT: + warn( + "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!" + ) + + def _upper_left(self, device: torch.device) -> torch.Tensor: + """Upper left causal bias""" + return torch.tril( + torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool) + ) + + def _lower_right(self, device: torch.device) -> torch.Tensor: + """Lower right causal bias""" + diagonal_offset = self.seq_len_kv - self.seq_len_q + return torch.tril( + torch.ones( + self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool + ), + diagonal=diagonal_offset, + ) + + def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor: + """ + Materializes the causal bias into a tensor form. + + Depending on the variant, this method generates either an upper-left or lower-right + triangular matrix to represent the causal bias. + + Args: + device (Optional[torch.device]): The device on which to create the tensor. Defaults to CPU. + + Returns: + torch.Tensor: The materialized bias tensor. + """ + if device is None: + device = torch.device("cpu") + if self.variant == CausalVariant.UPPER_LEFT: + return self._upper_left(device) + elif self.variant == CausalVariant.LOWER_RIGHT: + return self._lower_right(device) + + @staticmethod + def _dispatch( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: "CausalBias", + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + ) -> torch.Tensor: + r""" + Handles the logic for computing attention with the specified causal bias. + + Args: + query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`. + key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`. + value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`. + attn_mask (CausalBias): The type of causal attention to apply. + A boolean mask where a value of True indicates that the element *should* take part in attention. + A float mask of the same type as query, key, value that is added to the attention score. + dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied + is_causal (bool): If true, assumes upper left causal attention masking and errors if both attn_mask and is_causal + are set. + scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set + to :math:`\frac{1}{\sqrt{E}}`. + + Returns: + output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`. + + Raises: + ValueError: If the causal bias variant is not a CausalVariant type. + + """ + if is_causal: + raise ValueError("CausalBias should not be used with causal=True") + + if ( + attn_mask.seq_len_q == attn_mask.seq_len_kv + or attn_mask.variant == CausalVariant.UPPER_LEFT + ): + return scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=dropout_p, + is_causal=True, + scale=scale, + ) + elif attn_mask.variant == CausalVariant.LOWER_RIGHT: + _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale) + sdpa_params = SDPAParams(query, key, value, None, dropout_p, is_causal) + if can_use_flash_attention(sdpa_params): + needs_padding = query.size(-1) % 8 != 0 + og_head_size = query.size(-1) + og_scale = _calculate_scale(og_head_size, scale) + if needs_padding: + query = torch.nn.functional.pad(query, (0, 8 - query.size(-1) % 8)) + key = torch.nn.functional.pad(key, (0, 8 - key.size(-1) % 8)) + value = torch.nn.functional.pad(value, (0, 8 - value.size(-1) % 8)) + out = torch.ops.aten._scaled_dot_product_flash_attention( + query, + key, + value, + dropout_p, + is_causal=True, # TODO: Flash accepts causal = True and for this particular op it means lower right + return_debug_mask=False, + scale=og_scale, + )[0] + return _postprocess_flash_output(out, og_head_size) + if can_use_efficient_attention(sdpa_params): + compute_log_sumexp = False + if _input_requires_grad(query, key, value): + compute_log_sumexp = True + return torch.ops.aten._efficient_attention_forward( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + bias=None, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + dropout_p=dropout_p, + custom_mask_type=int(attn_mask.variant), + compute_log_sumexp=compute_log_sumexp, + scale=scale, + causal_diagonal=None, + seqlen_k=None, + )[0].transpose(1, 2) + else: + _raise_kernel_warnings(sdpa_params) + # We cant use efficient attention the only support for lower right is via materialization + return scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask._materialize(query.device), + dropout_p=dropout_p, + is_causal=False, + scale=scale, + ) + else: + raise ValueError( + f"CausalBias.variant must be a CausalVariant type, but found: {attn_mask.variant}" + ) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is an AttnBias""" + if kwargs is None: + kwargs = {} + if func != torch.nn.functional.scaled_dot_product_attention: + raise NotImplementedError( + "CausalBias only supports scaled_dot_product_attention" + ) + return cls._dispatch(*args, **kwargs) + + def __repr__(self): + return self._materialize().__repr__() + + +def causal_upper_left(*size) -> CausalBias: + """ + Creates an upper-left triangular causal bias. + + This function generates a upper-left triangular matrix to represent causal attention bias with a + diagonal offset set so that the inclusive values are aligned to the upper left corner of the matrix. + This equivalent to the `is_causal=True` argument in `scaled_dot_product_attention`. + + The equivalent pytorch code for constructing this bias is: + + .. code-block:: python + + torch.tril(torch.ones(size, dtype=torch.bool)) + + For instance, with `shape=(3,4)`, the materialized bias tensor will be: + + .. code-block:: text + + [[1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0]] + + Args: + size: The size of the bias matrix. + + Returns: + CausalBias: The UPPER_LEFT triangular causal bias variant. + """ + assert len(size) == 2, "causal_upper_left only supports 2D tensors" + seq_len_q, seq_len_kv = size + return CausalBias(CausalVariant.UPPER_LEFT, seq_len_q, seq_len_kv) + + +def causal_lower_right(*size) -> CausalBias: + """ + Creates a lower-right triangular causal bias. + + This function generates a lower-right triangular matrix to represent causal attention bias with a + diagonal offset set so that the inclusive values are aligned to the lower right corner of the matrix. + + The equivalent pytorch code for constructing this bias is: + + .. code-block:: python + + diagonal_offset = size[1] - size[0] + torch.tril( + torch.ones(size, dtype=torch.bool), + diagonal=diagonal_offset, + ) + + For instance, with `shape=(3,4)`, the materialized bias tensor will be: + + .. code-block:: text + + [[1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1]] + + Args: + size: The size of the bias matrix. + + Returns: + CausalBias: The LOWER_RIGHT triangular causal bias variant. + """ + assert len(size) == 2, "causal_lower_right only supports 2D tensors" + seq_len_q, seq_len_kv = size + return CausalBias(CausalVariant.LOWER_RIGHT, seq_len_q, seq_len_kv) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/common_types.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/common_types.py new file mode 100644 index 0000000000000000000000000000000000000000..884f739e27813a38e8364469bb9659c09ea410b3 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/common_types.py @@ -0,0 +1,42 @@ +from typing import TypeVar, Union, Tuple, Optional +from .. import Tensor + +# Create some useful type aliases + +# Template for arguments which can be supplied as a tuple, or which can be a scalar which PyTorch will internally +# broadcast to a tuple. +# Comes in several variants: A tuple of unknown size, and a fixed-size tuple for 1d, 2d, or 3d operations. +T = TypeVar('T') +_scalar_or_tuple_any_t = Union[T, Tuple[T, ...]] +_scalar_or_tuple_1_t = Union[T, Tuple[T]] +_scalar_or_tuple_2_t = Union[T, Tuple[T, T]] +_scalar_or_tuple_3_t = Union[T, Tuple[T, T, T]] +_scalar_or_tuple_4_t = Union[T, Tuple[T, T, T, T]] +_scalar_or_tuple_5_t = Union[T, Tuple[T, T, T, T, T]] +_scalar_or_tuple_6_t = Union[T, Tuple[T, T, T, T, T, T]] + +# For arguments which represent size parameters (eg, kernel size, padding) +_size_any_t = _scalar_or_tuple_any_t[int] +_size_1_t = _scalar_or_tuple_1_t[int] +_size_2_t = _scalar_or_tuple_2_t[int] +_size_3_t = _scalar_or_tuple_3_t[int] +_size_4_t = _scalar_or_tuple_4_t[int] +_size_5_t = _scalar_or_tuple_5_t[int] +_size_6_t = _scalar_or_tuple_6_t[int] + +# For arguments which represent optional size parameters (eg, adaptive pool parameters) +_size_any_opt_t = _scalar_or_tuple_any_t[Optional[int]] +_size_2_opt_t = _scalar_or_tuple_2_t[Optional[int]] +_size_3_opt_t = _scalar_or_tuple_3_t[Optional[int]] + +# For arguments that represent a ratio to adjust each dimension of an input with (eg, upsampling parameters) +_ratio_2_t = _scalar_or_tuple_2_t[float] +_ratio_3_t = _scalar_or_tuple_3_t[float] +_ratio_any_t = _scalar_or_tuple_any_t[float] + +_tensor_list_t = _scalar_or_tuple_any_t[Tensor] + +# For the return value of max pooling operations that may or may not return indices. +# With the proposed 'Literal' feature to Python typing, it might be possible to +# eventually eliminate this. +_maybe_indices_t = _scalar_or_tuple_2_t[Tensor] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/grad.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/grad.py new file mode 100644 index 0000000000000000000000000000000000000000..660c87fb4133c4bffc9fb711767306a1ccf7f5f1 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/grad.py @@ -0,0 +1,189 @@ +"""Gradient interface.""" + +import torch +from .modules.utils import _single, _pair, _triple + + +def conv1d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1): + r"""Compute the gradient of conv1d with respect to the input of the convolution. + + This is same as the 1D transposed convolution operator under the hood but requires + the shape of the gradient w.r.t. input to be specified explicitly. + + Args: + input_size : Shape of the input gradient tensor + weight: weight tensor (out_channels x in_channels/groups x kW) + grad_output : output gradient tensor (minibatch x out_channels x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(1, 1, 3, requires_grad=True) + >>> weight = torch.randn(1, 1, 1, requires_grad=True) + >>> output = F.conv1d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> grad_input = torch.autograd.grad(output, input, grad_output) + >>> F.grad.conv1d_input(input.shape, weight, grad_output) + + """ + input = grad_output.new_empty(1).expand(input_size) + + return torch.ops.aten.convolution_backward(grad_output, input, weight, None, + _single(stride), _single(padding), _single(dilation), + False, [0], groups, (True, False, False))[0] + + +def conv1d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1): + r"""Compute the gradient of conv1d with respect to the weight of the convolution. + + Args: + input: input tensor of shape (minibatch x in_channels x iW) + weight_size : Shape of the weight gradient tensor + grad_output : output gradient tensor (minibatch x out_channels x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(1, 1, 3, requires_grad=True) + >>> weight = torch.randn(1, 1, 1, requires_grad=True) + >>> output = F.conv1d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> # xdoctest: +SKIP + >>> grad_weight = torch.autograd.grad(output, filter, grad_output) + >>> F.grad.conv1d_weight(input, weight.shape, grad_output) + + """ + weight = grad_output.new_empty(1).expand(weight_size) + + return torch.ops.aten.convolution_backward(grad_output, input, weight, None, + _single(stride), _single(padding), _single(dilation), + False, [0], groups, (False, True, False))[1] + + +def conv2d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1): + r"""Compute the gradient of conv2d with respect to the input of the convolution. + + This is same as the 2D transposed convolution operator under the hood but requires + the shape of the gradient w.r.t. input to be specified explicitly. + + Args: + input_size : Shape of the input gradient tensor + weight: weight tensor (out_channels x in_channels/groups x kH x kW) + grad_output : output gradient tensor (minibatch x out_channels x oH x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(1, 1, 3, 3, requires_grad=True) + >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True) + >>> output = F.conv2d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> grad_input = torch.autograd.grad(output, input, grad_output) + >>> F.grad.conv2d_input(input.shape, weight, grad_output) + + """ + input = grad_output.new_empty(1).expand(input_size) + + return torch.ops.aten.convolution_backward(grad_output, input, weight, None, + _pair(stride), _pair(padding), _pair(dilation), + False, [0], groups, (True, False, False))[0] + + +def conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1): + r"""Compute the gradient of conv2d with respect to the weight of the convolution. + + Args: + input: input tensor of shape (minibatch x in_channels x iH x iW) + weight_size : Shape of the weight gradient tensor + grad_output : output gradient tensor (minibatch x out_channels x oH x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(1, 1, 3, 3, requires_grad=True) + >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True) + >>> output = F.conv2d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> # xdoctest: +SKIP + >>> grad_weight = torch.autograd.grad(output, filter, grad_output) + >>> F.grad.conv2d_weight(input, weight.shape, grad_output) + + """ + weight = grad_output.new_empty(1).expand(weight_size) + + return torch.ops.aten.convolution_backward(grad_output, input, weight, None, + _pair(stride), _pair(padding), _pair(dilation), + False, [0], groups, (False, True, False))[1] + + +def conv3d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1): + r"""Compute the gradient of conv3d with respect to the input of the convolution. + + This is same as the 3D transposed convolution operator under the hood but requires + the shape of the gradient w.r.t. input to be specified explicitly. + + Args: + input_size : Shape of the input gradient tensor + weight: weights tensor (out_channels x in_channels/groups x kT x kH x kW) + grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True) + >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True) + >>> output = F.conv3d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> grad_input = torch.autograd.grad(output, input, grad_output) + >>> F.grad.conv3d_input(input.shape, weight, grad_output) + + """ + input = grad_output.new_empty(1).expand(input_size) + + return torch.ops.aten.convolution_backward(grad_output, input, weight, None, + _triple(stride), _triple(padding), _triple(dilation), + False, [0], groups, (True, False, False))[0] + + +def conv3d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1): + r"""Compute the gradient of conv3d with respect to the weight of the convolution. + + Args: + input: input tensor of shape (minibatch x in_channels x iT x iH x iW) + weight_size : Shape of the weight gradient tensor + grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW) + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + + Examples:: + + >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True) + >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True) + >>> output = F.conv3d(input, weight) + >>> grad_output = torch.randn(output.shape) + >>> grad_weight = torch.autograd.grad(output, weight, grad_output) + >>> F.grad.conv3d_weight(input, weight.shape, grad_output) + + """ + weight = grad_output.new_empty(1).expand(weight_size) + + return torch.ops.aten.convolution_backward(grad_output, input, weight, None, + _triple(stride), _triple(padding), _triple(dilation), + False, [0], groups, (False, True, False))[1] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f44820c637e86c82c0f5d02919fa1c66803f21ac --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__init__.py @@ -0,0 +1,31 @@ +from .linear_relu import LinearReLU +from .linear_fused import LinearBn1d +from .conv_fused import ( + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + update_bn_stats, + freeze_bn_stats, +) + +__all__ = [ + "LinearReLU", + "LinearBn1d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "update_bn_stats", + "freeze_bn_stats", +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..496884d013e37763d4b358061c023b498e8a94b9 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..65c811cc5666d21fcbebeeb47f9efdc52d21375f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py @@ -0,0 +1,9 @@ +from torch.ao.nn.intrinsic.quantized import ConvReLU1d +from torch.ao.nn.intrinsic.quantized import ConvReLU2d +from torch.ao.nn.intrinsic.quantized import ConvReLU3d + +__all__ = [ + 'ConvReLU1d', + 'ConvReLU2d', + 'ConvReLU3d', +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/dropout.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/dropout.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a246b1af85d789099ac349bb5e570ed2650a749 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/dropout.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/padding.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/padding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e805693c4b92742ab8f1b87a80bd815031d7f541 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/padding.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/adaptive.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/adaptive.py new file mode 100644 index 0000000000000000000000000000000000000000..3d61e9d8f59aed12af20bd48a075d421ac90560a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/adaptive.py @@ -0,0 +1,312 @@ + +from collections import namedtuple + +import torch + +from torch import Tensor +from typing import List, Sequence + +from . import Sequential, ModuleList, Linear +from .module import Module +from ..functional import log_softmax + +__all__ = ['AdaptiveLogSoftmaxWithLoss'] + +_ASMoutput = namedtuple('_ASMoutput', ['output', 'loss']) + + +class AdaptiveLogSoftmaxWithLoss(Module): + r"""Efficient softmax approximation. + + As described in + `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin, + Moustapha Cissé, David Grangier, and Hervé Jégou + `__. + + Adaptive softmax is an approximate strategy for training models with large + output spaces. It is most effective when the label distribution is highly + imbalanced, for example in natural language modelling, where the word + frequency distribution approximately follows the `Zipf's law`_. + + Adaptive softmax partitions the labels into several clusters, according to + their frequency. These clusters may contain different number of targets + each. + Additionally, clusters containing less frequent labels assign lower + dimensional embeddings to those labels, which speeds up the computation. + For each minibatch, only clusters for which at least one target is + present are evaluated. + + The idea is that the clusters which are accessed frequently + (like the first one, containing most frequent labels), should also be cheap + to compute -- that is, contain a small number of assigned labels. + + We highly recommend taking a look at the original paper for more details. + + * :attr:`cutoffs` should be an ordered Sequence of integers sorted + in the increasing order. + It controls number of clusters and the partitioning of targets into + clusters. For example setting ``cutoffs = [10, 100, 1000]`` + means that first `10` targets will be assigned + to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be + assigned to the first cluster, and targets `101, 102, ..., 1000` will be + assigned to the second cluster, while targets + `1001, 1002, ..., n_classes - 1` will be assigned + to the last, third cluster. + + * :attr:`div_value` is used to compute the size of each additional cluster, + which is given as + :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`, + where :math:`idx` is the cluster index (with clusters + for less frequent words having larger indices, + and indices starting from :math:`1`). + + * :attr:`head_bias` if set to True, adds a bias term to the 'head' of the + adaptive softmax. See paper for details. Set to False in the official + implementation. + + .. warning:: + Labels passed as inputs to this module should be sorted according to + their frequency. This means that the most frequent label should be + represented by the index `0`, and the least frequent + label should be represented by the index `n_classes - 1`. + + .. note:: + This module returns a ``NamedTuple`` with ``output`` + and ``loss`` fields. See further documentation for details. + + .. note:: + To compute log-probabilities for all classes, the ``log_prob`` + method can be used. + + Args: + in_features (int): Number of features in the input tensor + n_classes (int): Number of classes in the dataset + cutoffs (Sequence): Cutoffs used to assign targets to their buckets + div_value (float, optional): value used as an exponent to compute sizes + of the clusters. Default: 4.0 + head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the + adaptive softmax. Default: ``False`` + + Returns: + ``NamedTuple`` with ``output`` and ``loss`` fields: + * **output** is a Tensor of size ``N`` containing computed target + log probabilities for each example + * **loss** is a Scalar representing the computed negative + log likelihood loss + + Shape: + - input: :math:`(N, \texttt{in\_features})` or :math:`(\texttt{in\_features})` + - target: :math:`(N)` or :math:`()` where each value satisfies :math:`0 <= \texttt{target[i]} <= \texttt{n\_classes}` + - output1: :math:`(N)` or :math:`()` + - output2: ``Scalar`` + + .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law + """ + + in_features: int + n_classes: int + cutoffs: List[int] + div_value: float + head_bias: bool + head: Linear + tail: ModuleList + + def __init__( + self, + in_features: int, + n_classes: int, + cutoffs: Sequence[int], + div_value: float = 4., + head_bias: bool = False, + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + + cutoffs = list(cutoffs) + + if (len(cutoffs) == 0): + raise ValueError("cutoffs should be a sequence of length larger than 0") + + if (cutoffs != sorted(cutoffs)) \ + or (min(cutoffs) <= 0) \ + or (max(cutoffs) > (n_classes - 1)) \ + or (len(set(cutoffs)) != len(cutoffs)) \ + or any(int(c) != c for c in cutoffs): + + raise ValueError("cutoffs should be a sequence of unique, positive " + "integers sorted in an increasing order, where " + "each value is between 1 and n_classes-1") + + self.in_features = in_features + self.n_classes = n_classes + self.cutoffs = cutoffs + [n_classes] + self.div_value = div_value + self.head_bias = head_bias + + self.shortlist_size = self.cutoffs[0] + self.n_clusters = len(self.cutoffs) - 1 + self.head_size = self.shortlist_size + self.n_clusters + + self.head = Linear(self.in_features, self.head_size, bias=self.head_bias, + **factory_kwargs) + self.tail = ModuleList() + + for i in range(self.n_clusters): + + hsz = int(self.in_features // (self.div_value ** (i + 1))) + osz = self.cutoffs[i + 1] - self.cutoffs[i] + + projection = Sequential( + Linear(self.in_features, hsz, bias=False, **factory_kwargs), + Linear(hsz, osz, bias=False, **factory_kwargs), + ) + + self.tail.append(projection) + + def reset_parameters(self) -> None: + self.head.reset_parameters() + for i2h, h2o in self.tail: + i2h.reset_parameters() + h2o.reset_parameters() + + def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput: + targ_dim = target_.dim() + + if targ_dim == 1: + if input_.size(0) != target_.size(0): + raise RuntimeError('Input and target should have the same size ' + 'in the batch dimension.') + if input_.dim() != 2: + raise RuntimeError('1D target tensor expects 2D input tensors, ' + 'but found inputs with size', input_.size()) + elif targ_dim == 0: + if input_.dim() != 1: + raise RuntimeError('0D target tensor expects 1D input tensors, ' + 'but found inputs with size', input_.size()) + else: + raise RuntimeError('0D or 1D target tensor expected, ' + 'multi-target not supported') + + is_batched = targ_dim > 0 + input = input_ if is_batched else input_.unsqueeze(0) + target = target_ if is_batched else target_.unsqueeze(0) + + used_rows = 0 + batch_size = target.size(0) + + output = input.new_zeros(batch_size) + gather_inds = target.new_empty(batch_size) + + cutoff_values = [0] + self.cutoffs + for i in range(len(cutoff_values) - 1): + + low_idx = cutoff_values[i] + high_idx = cutoff_values[i + 1] + + target_mask = (target >= low_idx) & (target < high_idx) + row_indices = target_mask.nonzero().squeeze() + + if row_indices.numel() == 0: + continue + + if i == 0: + gather_inds.index_copy_(0, row_indices, target[target_mask]) + + else: + relative_target = target[target_mask] - low_idx + input_subset = input.index_select(0, row_indices) + + cluster_output = self.tail[i - 1](input_subset) + cluster_index = self.shortlist_size + i - 1 + + gather_inds.index_fill_(0, row_indices, cluster_index) + cluster_logprob = log_softmax(cluster_output, dim=1) + local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1)) + output.index_copy_(0, row_indices, local_logprob.squeeze(1)) + + used_rows += row_indices.numel() + + if used_rows != batch_size: + raise RuntimeError(f"Target values should be in [0, {self.n_classes - 1}], " + f"but values in range [{target.min().item()}, {target.max().item()}] " + "were found. ") + + head_output = self.head(input) + head_logprob = log_softmax(head_output, dim=1) + output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze() + loss = (-output).mean() + + if not is_batched: + output = output.squeeze(0) + + return _ASMoutput(output, loss) + + def _get_full_log_prob(self, input, head_output): + """Given input tensor, and output of ``self.head``, compute the log of the full distribution.""" + out = input.new_empty((head_output.size(0), self.n_classes)) + head_logprob = log_softmax(head_output, dim=1) + + out[:, :self.shortlist_size] = head_logprob[:, :self.shortlist_size] + + for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])): + cluster_output = self.tail[i](input) + cluster_logprob = log_softmax(cluster_output, dim=1) + output_logprob = cluster_logprob + head_logprob[:, self.shortlist_size + i].unsqueeze(1) + + out[:, start_idx:stop_idx] = output_logprob + + return out + + def log_prob(self, input: Tensor) -> Tensor: + r"""Compute log probabilities for all :math:`\texttt{n\_classes}`. + + Args: + input (Tensor): a minibatch of examples + + Returns: + log-probabilities of for each class :math:`c` + in range :math:`0 <= c <= \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a + parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor. + + Shape: + - Input: :math:`(N, \texttt{in\_features})` + - Output: :math:`(N, \texttt{n\_classes})` + + """ + head_output = self.head(input) + return self._get_full_log_prob(input, head_output) + + def predict(self, input: Tensor) -> Tensor: + r"""Return the class with the highest probability for each example in the input minibatch. + + This is equivalent to ``self.log_prob(input).argmax(dim=1)``, but is more efficient in some cases. + + Args: + input (Tensor): a minibatch of examples + + Returns: + output (Tensor): a class with the highest probability for each example + + Shape: + - Input: :math:`(N, \texttt{in\_features})` + - Output: :math:`(N)` + """ + head_output = self.head(input) + output = torch.argmax(head_output, dim=1) + not_in_shortlist = (output >= self.shortlist_size) + all_in_shortlist = not (not_in_shortlist.any()) + + if all_in_shortlist: + return output + + elif not_in_shortlist.all(): + log_prob = self._get_full_log_prob(input, head_output) + return torch.argmax(log_prob, dim=1) + + else: + log_prob = self._get_full_log_prob(input[not_in_shortlist], + head_output[not_in_shortlist]) + output[not_in_shortlist] = torch.argmax(log_prob, dim=1) + return output diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/distance.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/distance.py new file mode 100644 index 0000000000000000000000000000000000000000..cbf98665799e3d3f6453e1ff4a5382375ea38b74 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/distance.py @@ -0,0 +1,89 @@ +from .module import Module +from .. import functional as F + +from torch import Tensor + +__all__ = ['PairwiseDistance', 'CosineSimilarity'] + +class PairwiseDistance(Module): + r""" + Computes the pairwise distance between input vectors, or between columns of input matrices. + + Distances are computed using ``p``-norm, with constant ``eps`` added to avoid division by zero + if ``p`` is negative, i.e.: + + .. math :: + \mathrm{dist}\left(x, y\right) = \left\Vert x-y + \epsilon e \right\Vert_p, + + where :math:`e` is the vector of ones and the ``p``-norm is given by. + + .. math :: + \Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}. + + Args: + p (real, optional): the norm degree. Can be negative. Default: 2 + eps (float, optional): Small value to avoid division by zero. + Default: 1e-6 + keepdim (bool, optional): Determines whether or not to keep the vector dimension. + Default: False + Shape: + - Input1: :math:`(N, D)` or :math:`(D)` where `N = batch dimension` and `D = vector dimension` + - Input2: :math:`(N, D)` or :math:`(D)`, same shape as the Input1 + - Output: :math:`(N)` or :math:`()` based on input dimension. + If :attr:`keepdim` is ``True``, then :math:`(N, 1)` or :math:`(1)` based on input dimension. + + Examples:: + >>> pdist = nn.PairwiseDistance(p=2) + >>> input1 = torch.randn(100, 128) + >>> input2 = torch.randn(100, 128) + >>> output = pdist(input1, input2) + """ + + __constants__ = ['norm', 'eps', 'keepdim'] + norm: float + eps: float + keepdim: bool + + def __init__(self, p: float = 2., eps: float = 1e-6, keepdim: bool = False) -> None: + super().__init__() + self.norm = p + self.eps = eps + self.keepdim = keepdim + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim) + + +class CosineSimilarity(Module): + r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along `dim`. + + .. math :: + \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}. + + Args: + dim (int, optional): Dimension where cosine similarity is computed. Default: 1 + eps (float, optional): Small value to avoid division by zero. + Default: 1e-8 + Shape: + - Input1: :math:`(\ast_1, D, \ast_2)` where D is at position `dim` + - Input2: :math:`(\ast_1, D, \ast_2)`, same number of dimensions as x1, matching x1 size at dimension `dim`, + and broadcastable with x1 at other dimensions. + - Output: :math:`(\ast_1, \ast_2)` + Examples:: + >>> input1 = torch.randn(100, 128) + >>> input2 = torch.randn(100, 128) + >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6) + >>> output = cos(input1, input2) + """ + + __constants__ = ['dim', 'eps'] + dim: int + eps: float + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + return F.cosine_similarity(x1, x2, self.dim, self.eps) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0708296d47a49f47f6eae0acc89bb556d9fe4ca --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__init__.py @@ -0,0 +1,14 @@ +from .parallel_apply import parallel_apply +from .replicate import replicate +from .data_parallel import DataParallel, data_parallel +from .scatter_gather import gather, scatter +from .distributed import DistributedDataParallel + +__all__ = ['replicate', 'scatter', 'parallel_apply', 'gather', 'data_parallel', + 'DataParallel', 'DistributedDataParallel'] + +def DistributedDataParallelCPU(*args, **kwargs): + import warnings + warnings.warn("torch.nn.parallel.DistributedDataParallelCPU is deprecated, " + "please use torch.nn.parallel.DistributedDataParallel instead.") + return DistributedDataParallel(*args, **kwargs) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/comm.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/comm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5218766296971bee68264e2b14f3e89b2cbd7439 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/comm.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85ade4294521ece4411606863e53ec1e5886d595 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b85f9c6201b2664a5d3c1b6305aac34d2631b55 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/replicate.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..016a6fbd0c40d510d4c123923e16d62514c71c45 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/replicate.py @@ -0,0 +1,186 @@ +import torch +from ..modules import Module +from . import comm +from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union, cast +from torch._utils import _get_device_index + +from collections import OrderedDict + +if TYPE_CHECKING: + import torch.jit + import torch.jit._state + +__all__ = ['replicate'] + +def _is_script_module(module: Module) -> bool: + import torch.jit + return isinstance(module, torch.jit.ScriptModule) + + +def _is_script_method(module: Module) -> bool: + import torch.jit + return isinstance(module, torch._C.ScriptMethod) + + +def _init_script_module() -> "torch.jit.ScriptModule": + import torch.jit + return torch.jit.ScriptModule() + + +def _is_jit_enabled() -> "torch.jit._state.EnabledProxy": + import torch.jit._state + return torch.jit._state._enabled + + +# Check if we can safely replicate the module. +# there are two types of module: +# 1. python modules +# 2. ScriptModule +# +# currently a module cannot be replicated properly if the descendants of +# any ScriptModule contains python module (type 1 above) +def _replicatable_module(module: Module, memo: Optional[Set[Module]] = None) -> bool: + + # module.modules() contains module itself as the first element + def descendant_modules(module: Module) -> Iterator[Module]: + gen = module.modules() + next(gen) + return gen + + if not _is_jit_enabled(): + return True + if memo is None: + memo = set() + + # memoize visited modules + memo.add(module) + if _is_script_module(module): + memo.update(descendant_modules(module)) + return all(_is_script_module(descendant) for + descendant in descendant_modules(module)) + + for child in module.children(): + # since any unreplicatable module will cause the check to return + # False early, visited modules here can be safely ignored. + if child in memo: + continue + if not _replicatable_module(child, memo): + return False + + return True + +def _broadcast_coalesced_reshape( + tensors: Sequence[torch.Tensor], + devices: Sequence[Union[int, torch.device]], + detach: bool = False, +) -> List[List[torch.Tensor]]: + from ._functions import Broadcast + if detach: + return comm.broadcast_coalesced(tensors, devices) + else: + # Use the autograd function to broadcast if not detach + if len(tensors) > 0: + tensor_copies = Broadcast.apply(devices, *tensors) + return [tensor_copies[i:i + len(tensors)] + for i in range(0, len(tensor_copies), len(tensors))] + else: + return [] + + +T = TypeVar("T", bound=Module) + + +def replicate( + network: T, + devices: Sequence[Union[int, torch.device]], + detach: bool = False, +) -> List[T]: + if not _replicatable_module(network): + raise RuntimeError("Cannot replicate network where python modules are " + "childrens of ScriptModule") + + if not devices: + return [] + + devices = [_get_device_index(x, True) for x in devices] + num_replicas = len(devices) + + params = list(network.parameters()) + param_indices = {param: idx for idx, param in enumerate(params)} + param_copies = _broadcast_coalesced_reshape(params, devices, detach) + + buffers = list(network.buffers()) + buffers_rg: List[torch.Tensor] = [] + buffers_not_rg: List[torch.Tensor] = [] + for buf in buffers: + if buf.requires_grad and not detach: + buffers_rg.append(buf) + else: + buffers_not_rg.append(buf) + + buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)} + buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)} + + buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach) + buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True) + + modules = list(network.modules()) + module_copies: List[List[Module]] = [[] for _ in devices] + module_indices: Dict[Module, int] = {} + + for i, module in enumerate(modules): + module_indices[module] = i + for j in range(num_replicas): + replica = module._replicate_for_data_parallel() + # This is a temporary fix for DDP. DDP needs to access the + # replicated model parameters. It used to do so through + # `mode.parameters()`. The fix added in #33907 for DP stops the + # `parameters()` API from exposing the replicated parameters. + # Hence, we add a `_former_parameters` dict here to support DDP. + replica._former_parameters = OrderedDict() + + module_copies[j].append(replica) + + for i, module in enumerate(modules): + for key, child in module._modules.items(): + if child is None: + for j in range(num_replicas): + replica = module_copies[j][i] + replica._modules[key] = None + else: + module_idx = module_indices[child] + for j in range(num_replicas): + replica = module_copies[j][i] + setattr(replica, key, module_copies[j][module_idx]) + for key, param in module._parameters.items(): + if param is None: + for j in range(num_replicas): + replica = module_copies[j][i] + replica._parameters[key] = None + else: + param_idx = param_indices[param] + for j in range(num_replicas): + replica = module_copies[j][i] + param_copy = param_copies[j][param_idx] + # parameters in replicas are no longer leaves, + # so setattr them as non-parameter attributes + setattr(replica, key, param_copy) + # expose the parameter for DDP + replica._former_parameters[key] = param_copy + for key, buf in module._buffers.items(): # type: ignore[assignment] + if buf is None: + for j in range(num_replicas): + replica = module_copies[j][i] + replica._buffers[key] = None + else: + if buf.requires_grad and not detach: + buffer_copies = buffer_copies_rg + buffer_idx = buffer_indices_rg[buf] + else: + buffer_copies = buffer_copies_not_rg + buffer_idx = buffer_indices_not_rg[buf] + for j in range(num_replicas): + replica = module_copies[j][i] + setattr(replica, key, buffer_copies[j][buffer_idx]) + + return [cast(T, module_copies[j][0]) for j in range(num_replicas)] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..19313b70c9527fabb4fd65b3e0a06989a573a1cb --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/__init__.py @@ -0,0 +1,18 @@ +# flake8: noqa: F401 +r"""QAT Dynamic Modules. + +This package is in the process of being deprecated. +Please, use `torch.ao.nn.qat.dynamic` instead. +""" +from . import dynamic # noqa: F403 +from . import modules # noqa: F403 +from .modules import * # noqa: F403 + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "Embedding", + "EmbeddingBag", +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/modules/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..620b32c07ec0fa2a897bce26d88f176ae321d071 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a1257b404b7346c6a96c4de3adb45c6e63564fac --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__init__.py @@ -0,0 +1,9 @@ +from torch.ao.nn.quantizable.modules.activation import MultiheadAttention +from torch.ao.nn.quantizable.modules.rnn import LSTM +from torch.ao.nn.quantizable.modules.rnn import LSTMCell + +__all__ = [ + 'LSTM', + 'LSTMCell', + 'MultiheadAttention', +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50203efea4aa2267c8739bc316ef4ea58138348a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a08565925088c02f3a77db9b229fefdec16dd780 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/sparse.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..e01f4e9b14897e051e15ed0de65a2772ffd46299 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/sparse.py @@ -0,0 +1,13 @@ +# flake8: noqa: F401 +r"""Quantized Reference Modules. + +This module is in the process of migration to +`torch/ao/nn/quantized/reference`, and is kept here for +compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/reference`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.reference.modules.sparse import Embedding +from torch.ao.nn.quantized.reference.modules.sparse import EmbeddingBag diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/rnn.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a0076d13bc4e3ee29e9b3e410171d20e8e9a65 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/rnn.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.rnn import LSTM diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a6f48ad935e460704cf241cffbdaf33a0b7a5b3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01fa633bbb1cd61d6df9a197ac01b0ae0e9afe1b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/prune.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/prune.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9091ae15eb7312dd80a7f72d8185b0dfc79c1d3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/prune.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5efed77b6c74c1b25046da8fc0d6b63a5ef29861 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..c7956a3a1b1f666708eefbec69d031af2da18592 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py @@ -0,0 +1,54 @@ +import torch +import torch.nn.functional as F +from .expanded_weights_impl import implements_per_sample_grads +from .expanded_weights_utils import standard_kwargs, forward_helper, set_grad_sample_if_exists + +from typing import List, Optional + +@implements_per_sample_grads(F.embedding) +class EmbeddingPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): + expanded_args, expanded_kwargs = standard_kwargs(kwarg_names, expanded_args_and_kwargs) + if len(expanded_args[0].shape) == 1: + raise RuntimeError(f"Expanded Weights needs an input with a batch size, got a 1D tensor, {expanded_args[0]}") + output = forward_helper(F.embedding, expanded_args, expanded_kwargs) + ctx.input, ctx.weight = expanded_args + ctx.padding_idx, ctx.scale_grad_by_freq = expanded_kwargs['padding_idx'], expanded_kwargs['scale_grad_by_freq'] + ctx.sparse = expanded_kwargs['sparse'] + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.input, ctx.weight + padding_idx, scale_grad_by_freq, sparse = ctx.padding_idx, ctx.scale_grad_by_freq, ctx.sparse + + def weight_per_sample_grad(weight): + batch_size = input.shape[0] + embedding_dim = weight.shape[1] + index = ( + input.unsqueeze(-1) + .expand(*input.shape, embedding_dim) + .reshape(batch_size, -1, embedding_dim) + ) + grad_sample = torch.zeros( + batch_size, *weight.shape, device=weight.device, dtype=grad_output.dtype + ) + return grad_sample.scatter_add_(1, index, grad_output.reshape(batch_size, -1, embedding_dim)) + + results: List[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + + if input.requires_grad: + bw_fn = torch.ops.aten.embedding_backward + results.append(bw_fn(grad_output, input, weight.shape[0], padding_idx, scale_grad_by_freq, sparse)) + else: + results.append(None) + + # weight doesn't compute batched gradients; no other arguments are differentiable (2 not saved from forward) + results = results + [None] * 6 + + # set grad_sample field for weight with per sample gradients + set_grad_sample_if_exists(weight, weight_per_sample_grad) + return tuple(results) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_impl.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..94e6041c6de5df13986ef329c8e13e0671326f54 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_impl.py @@ -0,0 +1,153 @@ +from contextlib import contextmanager + +import torch +import functools +from torch._decomp import decomposition_table + +from typing import Callable, Dict + +from torch.utils._pytree import tree_map_only + +HANDLED_FUNCTIONS: Dict[Callable, torch.autograd.Function] = {} + +aten = torch._ops.ops.aten +# __torch_function__ runs before the pydispatcher so we need to manually use the same +# decompositions indexed by their torch equivalent +expanded_weights_rnn_decomps = { + # func: (input_decomp, data_decomp) + torch.rnn_relu: (decomposition_table[aten.rnn_relu.input], decomposition_table[aten.rnn_relu.data]), + torch.rnn_tanh: (decomposition_table[aten.rnn_tanh.input], decomposition_table[aten.rnn_tanh.data]), + torch.lstm: (decomposition_table[aten.lstm.input], decomposition_table[aten.lstm.data]), + torch.gru: (decomposition_table[aten.gru.input], decomposition_table[aten.gru.data]), +} + +# all of the RNN decomps run linear with the batch dimension second, even if batch_first was set +@contextmanager +def batch_second(args, kwargs): + def set_batch_second(ew): + ew.set_batch_first(False) + + def reset_batch_first(ew): + ew.set_batch_first(True) + + tree_map_only(ExpandedWeight, set_batch_second, args) + tree_map_only(ExpandedWeight, set_batch_second, kwargs) + try: + yield + finally: + tree_map_only(ExpandedWeight, reset_batch_first, args) + tree_map_only(ExpandedWeight, reset_batch_first, kwargs) + +# to support packed sequences, we need to allow for smaller batches. Expanded weights represents the largest batch +@contextmanager +def allow_smaller_batches(args, kwargs): + def allow(ew): + ew.set_allow_smaller_batches(True) + + def reset(ew): + ew.set_allow_smaller_batches(False) + + tree_map_only(ExpandedWeight, allow, args) + tree_map_only(ExpandedWeight, allow, kwargs) + try: + yield + finally: + tree_map_only(ExpandedWeight, reset, args) + tree_map_only(ExpandedWeight, reset, kwargs) + +@contextmanager +def setup_rnn(use_input_variant, args, kwargs): + with batch_second(args, kwargs) if use_input_variant else allow_smaller_batches(args, kwargs): + yield + + +def implements_per_sample_grads(torch_function): + @functools.wraps(torch_function) + def decorator(autograd_func): + HANDLED_FUNCTIONS[torch_function] = autograd_func + return autograd_func + return decorator + +# ExpandedWeight represents a weight (parameter) Tensor that has an expanded +# batch dimension. Operations on the ExpandedWeight Tensor act exactly like +# those without an expanded batch dimension but a call to .backward() populates +# the original (unexpanded) tensor with per-sample-gradients for in the grad_sample field +# +# ExpandedWeight has a fallback that always fails since we cannot know what the batch +# dimension of the input tensor is and therefore cannot know if this is a valid call +# +# This is a __torch_function__ object but it could have also been a Tensor Extension +# with a dispatch key. +# +# Needs to be a tensor subclass to allow reparamaterization +class ExpandedWeight(torch.Tensor): + def __init__(self, orig_weight, batch_size, loss_reduction): + self.batch_size = batch_size + self.batch_first = True + self.allow_smaller_batches = False + self.orig_weight = orig_weight + self.loss_reduction = loss_reduction + + handled_functions = HANDLED_FUNCTIONS + + def __new__(cls, orig_weight, batch_size, loss_reduction): + if not isinstance(orig_weight, torch.Tensor): + raise RuntimeError(f"Can only make Expanded Weights of Tensors, got {type(orig_weight).__name__}") + if not orig_weight.requires_grad: + raise RuntimeError("Can only build ExpandedWeights objects of tensors that require_grad") + ret = torch.Tensor._make_subclass(cls, orig_weight, True) + return ret + + @classmethod + def __torch_function__(cls, func, _, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func in expanded_weights_rnn_decomps: + # in aten, choosing the input or data variants is done by parsing logic. This mimics some of that + decomp_opts = expanded_weights_rnn_decomps[func] + use_input_variant = isinstance(args[2], list) # data variant uses a list here + decomp = decomp_opts[0] if use_input_variant else decomp_opts[1] + + if decomp is not None: + with setup_rnn(use_input_variant, args, kwargs): + return decomp(*args, **kwargs) + if func == torch._cudnn_rnn_flatten_weight: + # since we aren't using the fused cuda kernels for RNNs, don't do this + return + if func in cls.handled_functions: + return cls.handled_functions[func].apply(tuple(kwargs.keys()), func, *(args + tuple(kwargs.values()))) + # We cannot use a fallback here because we do not know the batch dimension for any regular tensor inputs, + # i.e. torch.add(torch.Tensor, ExpandedWeight) + raise RuntimeError(f"Expanded Weights encountered but cannot handle function {func.__name__}") + + @property + def dtype(self): + return self.orig_weight.dtype + + @property + def data(self): + return self.orig_weight.data + + @property + def shape(self): + return self.orig_weight.shape + + @property + def device(self): + return self.orig_weight.device + + @property + def is_cuda(self): + return self.orig_weight.is_cuda + + def data_ptr(self): + return self.orig_weight.data_ptr() + + def get_device(self): + return self.orig_weight.get_device() + + def set_allow_smaller_batches(self, is_allow_smaller_batches): + self.allow_smaller_batches = is_allow_smaller_batches + + def set_batch_first(self, is_batch_first=True): + self.batch_first = is_batch_first diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/linear_expanded_weights.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/linear_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..c2cbae63f33651a0f44e287cb0fa6d5d4a25bc62 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/linear_expanded_weights.py @@ -0,0 +1,44 @@ +import torch +import torch.nn.functional as F +from .expanded_weights_impl import implements_per_sample_grads +from .expanded_weights_utils import \ + forward_helper, set_grad_sample_if_exists, unpack_expanded_weight_or_tensor, is_batch_first +from typing import List, Optional + +@implements_per_sample_grads(F.linear) +class LinearPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, _, __, *expanded_args_and_kwargs): + if len(expanded_args_and_kwargs[0].shape) <= 1: + raise RuntimeError("Input does not have a batch dimension. Expanded Weights expected input " + f"of at least rank 2, got of rank {len(expanded_args_and_kwargs[0].shape)}") + expanded_kwargs = {'bias': expanded_args_and_kwargs[2] if len(expanded_args_and_kwargs) == 3 else None} + expanded_args = expanded_args_and_kwargs[:2] + ctx.batch_first = is_batch_first(expanded_args_and_kwargs) + output = forward_helper(F.linear, expanded_args, expanded_kwargs) + ctx.args = expanded_args + ctx.kwargs = expanded_kwargs + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.args + bias = ctx.kwargs['bias'] + results: List[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg_names + results.append(None) # for op reference + + if input.requires_grad: + results.append(grad_output.matmul(unpack_expanded_weight_or_tensor(weight))) + else: + results.append(None) + results.extend([None] * 2) # weight and bias don't compute batched gradients + + if not ctx.batch_first: + grad_output = grad_output.transpose(0, 1) + input = input.transpose(0, 1) + + # weight and bias get their grad_sample fields set directly if they exist + set_grad_sample_if_exists(weight, lambda _: torch.einsum("n...i,n...j->nij", grad_output, input)) + set_grad_sample_if_exists(bias, lambda _: torch.einsum("n...k->nk", grad_output)) + return tuple(results)