diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ac132d9db588e95fb3ce327344081dcdd2e7d51 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__init__.py @@ -0,0 +1 @@ +from .cond import cond diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b71dc68988563b089e10eb7478f6e6bc70f3b5d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ad79f81d14ef5c59f33abfb1cac3dde407ea6ea Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0586e9f2b38f850ec0e9115f1225d143d4666f8d Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da5318a966124732ac9fe7bf6fd34bae3d6b3899 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3caeee1f4641f845b07d66db2c453abd0717b92f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8190058d0c962b71b8ea6a6f22ef133ad93b5837 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b153e4979d9a8cbad8d44018e13e72ec2253dc21 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4943130d973b45f645667e6116777fae910b6e38 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/auto_functionalize.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/auto_functionalize.py new file mode 100644 index 0000000000000000000000000000000000000000..55567ac4c99e29bfca2a9e8df0048924100ebe50 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/auto_functionalize.py @@ -0,0 +1,261 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils._pytree as pytree +from torch import Tensor +from torch._C import DispatchKey +from torch._ops import HigherOrderOperator +from torch._prims_common import clone_preserve_strides +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +# NOTE: [auto-functionalizing custom ops] +# Users may wish to torch.compile custom ops that mutate 
their inputs. +# torch.compile will automatically support this op without anyone needing +# to provide a functionalization kernel for it. Here's how. +# +# Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> () +# op. First, when FakeTensor sees this op: +# - If the schema says it returns nothing, we can generate a trivial +# FakeTensor rule for it (that returns nothing). +# - Otherwise, the user needs to provide a FakeTensor rule (abstract impl) +# +# Next, when Python FunctionalTensor sees the op, it will functionalize +# it by emitting a call to an auto_functionalize(op, ["x"], {"x": ...}) +# HOP and replacing the mutated inputs with corresponding outputs of this HOP. +# This HOP effectively runs the functional version of the op when +# called: it clones inputs that will be mutated, runs the op, and +# then returns (output, Tensors with the new values) + + +class AutoFunctionalized(HigherOrderOperator): + """auto_functionalized(_mutable_op, **kwargs) + + This HOP runs a "functional" version of _mutable_op. + + Concretely, it looks at all the arguments that are mutable through + _mutable_op's operator schema, clones those kwargs, runs + `out = _mutable_op(**kwargs)` with the cloned values, and then returns the + operator output concatenated with the cloned values that were mutated. + + We have some restrictions on `_mutable_op`. + See `can_auto_functionalize` for the restrictions. We can likely lift + many of these if users request it. + + The reason why _mutable_op is prefixed with an + underscore is to prevent collisions with kwarg names in **kwargs. 
+ """ + + def __init__(self): + super().__init__("auto_functionalized") + + def __call__( + self, + _mutable_op: torch._ops.OpOverload, + **kwargs: Dict[str, Any], + ) -> Tuple[Any, Tuple[Tensor, ...]]: + assert can_auto_functionalize(_mutable_op) + assert isinstance(kwargs, dict) + return super().__call__(_mutable_op, **kwargs) + + +auto_functionalized = AutoFunctionalized() + + +def can_auto_functionalize(op: torch._ops.OperatorBase) -> bool: + if not isinstance(op, torch._ops.OpOverload): + return False + + if torch._library.utils.is_builtin(op): + # We control the built-ins. These may (in rare cases) + # do input metadata mutation (which we have banned on custom ops) + return False + schema = op._schema + if not schema.is_mutable: + return False + schema = op._schema + + for arg in schema.arguments: + if arg.alias_info is None: + continue + if not arg.alias_info.is_write: + continue + if type(arg.type) is torch.TensorType: + continue + if ( + type(arg.type) is torch.OptionalType + and type(arg.type.getElementType()) is torch.TensorType + ): + continue + # Not yet supported: other Tensor types. This includes things like + # Tensor[], Tensor?[], Tensor[]?. + return False + + # The returns must not alias anything + for ret in schema.returns: + if ret.alias_info is None and type(ret.type) is torch.TensorType: + continue + # Not yet supported: List[Tensor] return. 
+ return False + return True + + +@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd) +def auto_functionalized_dense( + _mutable_op: torch._ops.OpOverload, + _only_clone_these_tensors: Optional[Tuple[str, ...]] = None, + **kwargs: Dict[str, Any], +) -> Tuple[Any, Tuple[Tensor, ...]]: + new_kwargs = dict(**kwargs) + result = [] + + _mutable_args_names = get_mutable_arg_names(_mutable_op) + for name in _mutable_args_names: + if ( + _only_clone_these_tensors is not None + and name not in _only_clone_these_tensors + ): + new_kwargs[name] = kwargs[name] + else: + new_kwargs[name] = ( + clone_preserve_strides(kwargs[name]) + if kwargs[name] is not None + else None + ) + result.append(new_kwargs[name]) + out = _mutable_op(**new_kwargs) + + if isinstance(out, tuple): + return (*out, *result) # type: ignore[return-value] + else: + return (out, *result) # type: ignore[return-value] + + +@auto_functionalized.py_impl(FakeTensorMode) +def auto_functionalized_fake( + mode, + _mutable_op: torch._ops.OpOverload, + **kwargs: Dict[str, Any], +) -> Tuple[Any, Tuple[Tensor, ...]]: + with mode: + result = auto_functionalized_dense(_mutable_op, **kwargs) + return result + + +@auto_functionalized.py_impl(ProxyTorchDispatchMode) +def auto_functionalized_proxy( + mode, + _mutable_op: torch._ops.OpOverload, + **kwargs: Dict[str, Any], +) -> Tuple[Any, Tuple[Tensor, ...]]: + if not mode.enable_tracing: + return auto_functionalized(_mutable_op, **kwargs) + + with disable_proxy_modes_tracing(): + out = auto_functionalized(_mutable_op, **kwargs) + + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + out_proxy = mode.tracer.create_proxy( + "call_function", + auto_functionalized, + (_mutable_op,), + proxy_kwargs, + ) + result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + return result + + +auto_functionalized.fallthrough(DispatchKey.AutogradCPU) +auto_functionalized.fallthrough(DispatchKey.AutogradCUDA) + + +def get_mutable_arg_names(op: 
torch._ops.OpOverload) -> List[str]: + """ + Returns the list of argument names that get mutated according to the + schema. + """ + mutable_args_names = [ + arg.name + for arg in op._schema.arguments + if arg.alias_info is not None and arg.alias_info.is_write + ] + return mutable_args_names + + +def do_auto_functionalize( + op: torch._ops.OpOverload, args: Tuple[Any, ...], kwargs: Dict[str, Any] +) -> Any: + """Functionalizes a call to op(*args, **kwargs) by emitting a call to + `outs = auto_functionalized(op, normalized_kwargs)` + and replacing the mutated (args, kwargs) with the corresponding outputs. + + The normalized_kwargs are just the (args, kwargs), but all in kwarg form. + This makes handling easier for the auto_functionalized HOP. + """ + from torch._subclasses.functional_tensor import PythonFunctionalizeAPI + + ctx = PythonFunctionalizeAPI() + + # All of the (args, kwargs), but all as kwargs. The names for the + # args come from the schema. This makes it easier for us to work with them. 
+ normalized_kwargs = {} + schema = op._schema + for idx, arg in enumerate(schema.arguments): + # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema + if arg.name in kwargs: + normalized_kwargs[arg.name] = kwargs[arg.name] + elif idx < len(args): + # if its out of bounds we don't need to do anything + # as it means the the optional arg was passed with its default + # value + normalized_kwargs[arg.name] = args[idx] + else: + normalized_kwargs[arg.name] = arg.default_value + + unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type] + with ctx.redispatch_to_next(): + unwrapped_outs = auto_functionalized( + op, **unwrapped_kwargs # type: ignore[arg-type] + ) + + # List of the name of args that get mutated (according to the schema) + mutable_args_names = get_mutable_arg_names(op) + + unwrapped_actual_out: Union[Any, Tuple[Any]] = unwrapped_outs[ + : -len(mutable_args_names) + ] + unwrapped_mutable_out = unwrapped_outs[-len(mutable_args_names) :] + + if len(op._schema.returns) == 0: + assert unwrapped_actual_out[0] is None + unwrapped_actual_out = None + elif len(op._schema.returns) == 1: + assert len(unwrapped_actual_out) == 1 + unwrapped_actual_out = unwrapped_actual_out[0] + else: + assert len(unwrapped_actual_out) == len(op._schema.returns) + + for name, unwrapped_out in zip(mutable_args_names, unwrapped_mutable_out): + # Can be None if input was `Tensor(a!)?` + if unwrapped_out is None: + continue + assert isinstance(unwrapped_out, torch.Tensor) + orig_arg = normalized_kwargs[name] + ctx.replace(orig_arg, unwrapped_out) + ctx.commit_update(orig_arg) + ctx.sync(orig_arg) + + return ctx.wrap_tensors(unwrapped_actual_out) # type: ignore[arg-type] + + +@auto_functionalized.py_functionalize_impl +def auto_functionalized_func(ctx, _mutable_op, **kwargs): + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + with ctx.redispatch_to_next(): + result = auto_functionalized(_mutable_op, **unwrapped_kwargs) + return 
ctx.wrap_tensors(result) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/cond.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/cond.py new file mode 100644 index 0000000000000000000000000000000000000000..ae4dba02bac1ec4f086b86e87dd8d6a6fafde40c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/cond.py @@ -0,0 +1,349 @@ +import torch +import torch._subclasses.functional_tensor + +import torch.utils._pytree as pytree + +from torch._C import DispatchKey +from torch._C._functorch import ( + _add_batch_dim, + get_unwrapped, + is_batchedtensor, + maybe_get_bdim, +) +from torch._functorch.utils import exposed_in + +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + _set_compilation_env, + autograd_not_implemented, + reenter_make_fx, + UnsupportedAliasMutationException, +) + +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.utils._python_dispatch import _get_current_dispatch_mode + + +@exposed_in("torch") +def cond(pred, true_fn, false_fn, operands): + r""" + Conditionally applies `true_fn` or `false_fn`. + + .. warning:: + `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and + doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch. + Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype + + `cond` is structured control flow operator. 
That is, it is like a Python if-statement, + but has restrictions on `true_fn`, `false_fn`, and `operands` that enable it to be + capturable using torch.compile and torch.export. + + Assuming the constraints on `cond`'s arguments are met, `cond` is equivalent to the following:: + + def cond(pred, true_branch, false_branch, operands): + if pred: + return true_branch(*operands) + else: + return false_branch(*operands) + + Args: + pred (Union[bool, torch.Tensor]): A boolean expression or a tensor with one element, + indicating which branch function to apply. + + true_fn (Callable): A callable function (a -> b) that is within the + scope that is being traced. + + false_fn (Callable): A callable function (a -> b) that is within the + scope that is being traced. The true branch and false branch must + have consistent input and outputs, meaning the inputs have to be + the same, and the outputs have to be the same type and shape. + + operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the true/false functions. + + Example:: + + def true_fn(x: torch.Tensor): + return x.cos() + def false_fn(x: torch.Tensor): + return x.sin() + return cond(x.shape[0] > 4, true_fn, false_fn, (x,)) + + Restrictions: + - The conditional statement (aka `pred`) must meet one of the following constraints: + + - It's a `torch.Tensor` with only one element, and torch.bool dtype + + - It's a boolean expression, e.g. `x.shape[0] > 10` or `x.dim() > 1 and x.shape[1] > 10` + + - The branch function (aka `true_fn`/`false_fn`) must meet all of the following constraints: + + - The function signature must match with operands. + + - The function must return a tensor with the same metadata, e.g. shape, + dtype, etc. + + - The function cannot have in-place mutations on inputs or global variables. + (Note: in-place tensor operations such as `add_` for intermediate results + are allowed in a branch) + + .. 
warning:: + Temporal Limitations: + + - `cond` only supports **inference** right now. Autograd will be supported in the future. + + - The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future. + + """ + + if torch.compiler.is_dynamo_compiling(): + return cond_op(pred, true_fn, false_fn, operands) + + def _validate_input(pred, true_fn, false_fn, operands): + if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)): + raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.") + + if isinstance(pred, torch.Tensor) and pred.numel() != 1: + raise RuntimeError( + f"Expected pred to be bool or single-element tensor, but got {pred}." + ) + + if not callable(true_fn) or not callable(false_fn): + raise RuntimeError("Expect both branches to be callbale.") + + if not isinstance(operands, (tuple, list)) or pytree.tree_any( + lambda t: not isinstance(t, torch.Tensor), operands + ): + raise RuntimeError( + "Expect operands to be a tuple of possibly nested dict/list/tuple that only" + f"consists of tensor leaves, but got {operands}." + ) + + _validate_input(pred, true_fn, false_fn, operands) + + if not torch._dynamo.is_dynamo_supported(): + raise RuntimeError("torch.cond requires dynamo support.") + + with _set_compilation_env(): + with torch._dynamo.utils.disable_cache_limit(): + return torch.compile(cond_op, backend="eager", fullgraph=True)( + pred, true_fn, false_fn, operands + ) + + +""" +We're going to define a `cond_op` operation. +In order to do this, we need implementations for each of the dispatch keys. 
+""" +cond_op = HigherOrderOperator("cond") + + +def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): + assert isinstance( + operands, (list, tuple) + ), "Cond operands must be a list or tuple of tensors" + assert all( + isinstance(o, torch.Tensor) for o in operands + ), "Cond operands must be a list of tensors" + + pre_dispatch = getattr(proxy_mode, "pre_dispatch", False) + + with disable_proxy_modes_tracing(): + true_graph = reenter_make_fx(true_fn, pre_dispatch)(*operands) + false_graph = reenter_make_fx(false_fn, pre_dispatch)(*operands) + + true_outs = [] + false_outs = [] + for node in true_graph.graph.nodes: + if node.op == "output": + true_outs.extend(node.args) + + for node in false_graph.graph.nodes: + if node.op == "output": + false_outs.extend(node.args) + + flat_true_outs = pytree.arg_tree_leaves(*true_outs) + flat_false_outs = pytree.arg_tree_leaves(*false_outs) + if len(flat_true_outs) != len(flat_false_outs): + raise torch._dynamo.exc.CondOpArgsMismatchError( + f"Expected to return same number of outputs but got:" + f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)" + f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)" + ) + + for i in range(0, len(flat_true_outs)): + true_out = flat_true_outs[i] + false_out = flat_false_outs[i] + if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]: + raise torch._dynamo.exc.CondOpArgsMismatchError( + f"Expected each tensor to have same metadata but got:" + f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}" + f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}" + ) + + # There are probably better ways - I know that create_arg has some self incrementing name + # magic to it, but since we explicitly have to get the name for register_module, + # I was not sure how to do that. This kinda simulates it. 
+ next_name = None + i = 0 + while not next_name: + candidate = f"true_graph_{i}" + if hasattr(proxy_mode.tracer.root, candidate): + i += 1 + else: + next_name = candidate + + true_name = next_name + false_name = f"false_graph_{i}" + assert not hasattr(proxy_mode.tracer.root, false_name) + + proxy_mode.tracer.root.register_module(true_name, true_graph) + proxy_mode.tracer.root.register_module(false_name, false_graph) + + args = (pred, true_graph, false_graph, operands) + + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="conditional" + ) + + # At this point, we're *guaranteed* that whether an output came from the + # true or false branch is indistinguishable. So, as this is just for tracing + # purposes, choose the true branch. + + # TODO: Uhh.... it shouldn't matter, but changing this to true_fn results in + # a FakeTensorMode error : + # `Current active mode not registered` + # TODO Sometimes the operands are not completely FakeTensor, something seems went wrong in + # dynamo? Because of that it runs real computation sometimes and re-triggering downstream dispatch keys. 
+ out = false_fn(*operands) + + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@cond_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def cond_op_dense(pred, true_fn, false_fn, operands): + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + if pred: + return true_fn(*operands) + else: + return false_fn(*operands) + + +cond_op.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(cond_op, deferred_error=True) +) + + +@cond_op.py_impl(ProxyTorchDispatchMode) +def inner(mode, pred, true_fn, false_fn, operands): + if mode.enable_tracing: + return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands) + else: + return cond_op(pred, true_fn, false_fn, operands) + + +@cond_op.py_impl(FakeTensorMode) +def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands): + with mode: + true_outs = true_fn(*operands) + flat_true_outs = pytree.tree_leaves(true_outs) + flat_false_outs = pytree.tree_leaves(false_fn(*operands)) + if len(flat_true_outs) != len(flat_false_outs): + raise RuntimeError("Unmatched number of outputs from cond() branches.") + + for true_out, false_out in zip(flat_true_outs, flat_false_outs): + true_meta = _extract_tensor_metadata(true_out) + false_meta = _extract_tensor_metadata(false_out) + if true_meta != false_meta: + raise torch._dynamo.exc.CondOpArgsMismatchError( + f"Expected each tensor to have same metadata but got:" + f"\n {true_fn.__name__} returns {true_meta}" + f"\n {false_fn.__name__} returns {false_meta}" + ) + return true_outs + + +@cond_op.py_functionalize_impl +def cond_func(ctx, pred, true_fn, false_fn, inputs): + unwrapped_inputs = ctx.unwrap_tensors(inputs) + unwrapped_pred = ctx.unwrap_tensors(pred) + with ctx.redispatch_to_next() as m: + functional_true = ctx.functionalize(true_fn) + functional_false = ctx.functionalize(false_fn) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + for branch in [functional_true, 
functional_false]: + if _has_potential_branch_input_mutation( + branch, unwrapped_inputs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "One of torch.cond branch might be modifying the input!" + ) + for branch in [true_fn, false_fn]: + if _has_potential_branch_input_alias( + branch, unwrapped_inputs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "One of torch.cond branch might be aliasing the input!" + ) + + cond_return = cond_op( + unwrapped_pred, functional_true, functional_false, unwrapped_inputs + ) + return ctx.wrap_tensors(cond_return) + + +@cond_op.py_impl(torch._C._functorch.TransformType.Vmap) +def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs): + assert isinstance( + inputs, (list, tuple) + ), "Cond inputs must be a list or tuple of tensors" + assert all( + isinstance(i, torch.Tensor) for i in inputs + ), "Cond inputs must be a list of tensors" + + pred_ = get_unwrapped(pred) if is_batchedtensor(pred) else pred + + # unbatched tensors are not vmapped + tensors, in_dims = zip( + *[ + (get_unwrapped(t), maybe_get_bdim(t)) if is_batchedtensor(t) else (t, None) + for t in inputs + ] + ) + + if is_batchedtensor(pred): + # prepend "pred" and vmap everything + tensors = (pred_,) + tensors + in_dims = (0,) + in_dims + + def fn(p, *args): + t = true_fn(*args) + f = false_fn(*args) + return torch.where(p, t[0], f[0]) + + with interpreter.lower(): + result = torch.vmap(fn, in_dims=in_dims)(*tensors) + + else: + # predicate is known at this stage and it is a boolean expression or a + # tensor with one element. 
+ true_fn = torch.vmap(true_fn, in_dims=in_dims) + false_fn = torch.vmap(false_fn, in_dims=in_dims) + + with interpreter.lower(): + result = cond_op(pred, true_fn, false_fn, tensors) + + if not isinstance(result, tuple): + result = (result,) + lvl = interpreter.level() + return tuple([_add_batch_dim(r, 0, lvl) for r in result]) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/effects.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/effects.py new file mode 100644 index 0000000000000000000000000000000000000000..08c49c964631e5794e2783c3351d5f10fae29a94 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/effects.py @@ -0,0 +1,204 @@ +from enum import Enum +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +class _EffectType(Enum): + ORDERED = "Ordered" + + +SIDE_EFFECTS: Dict[torch._ops.OpOverload, _EffectType] = { + torch.ops.aten._print.default: _EffectType.ORDERED, +} + + +class WithEffects(HigherOrderOperator): + """ + with_effects(token, op, args, kwargs) -> (new_token, op_results) + + This HOP helps ensure ordering between side effectful ops like prints or ops + using torchbind objects. This is needed to ensure a traced graph from + AOTAutograd is functional so that future optimization passes do not reorder + these operators. This is done through threading "effect tokens" through the + graph to enforce data dependence between side effectful ops. + + The tokens are basically dummy values (torch.tensor([])). We create a token + per "effect type", which are enumerated in the _EffectType enum. 
+ """ + + def __init__(self): + super().__init__("with_effects") + + def __call__( + self, + token, + op: torch._ops.OpOverload, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], + ) -> Tuple[Any, ...]: + assert isinstance(op, torch._ops.OpOverload) + assert not has_aliasing(op), "Ops with aliasing is not supported" + assert has_effects(op, args, kwargs) + assert isinstance(kwargs, dict) + return super().__call__(token, op, *args, **kwargs) + + +with_effects = WithEffects() + + +def has_aliasing(op: torch._ops.OpOverload): + for arg in op._schema.arguments: + if arg.alias_info is not None: + return True + for arg in op._schema.returns: + if arg.alias_info is not None: + return True + return False + + +def has_effects(op, args, kwargs) -> bool: + return ( + isinstance(op, torch._ops.OpOverload) + and not has_aliasing(op) + and get_effect_key(op, args, kwargs) is not None + ) + + +def get_effect_key(op, args, kwargs) -> Optional[_EffectType]: + if op in SIDE_EFFECTS: + return SIDE_EFFECTS[op] + + for arg in args: + if isinstance(arg, torch.ScriptObject): + return _EffectType.ORDERED + + return None + + +@with_effects.py_impl(DispatchKey.CompositeExplicitAutograd) +def with_effects_dense( + token: torch.Tensor, + op: torch._ops.OpOverload, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], +) -> Tuple[torch.Tensor, ...]: + out = op(*args, **kwargs) + new_token = torch.tensor([]) + if isinstance(out, tuple): + return (new_token, *out) + return (new_token, out) + + +@with_effects.py_impl(FakeTensorMode) +def with_effects_fake( + mode, + token: torch.Tensor, + op: torch._ops.OpOverload, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], +) -> Tuple[torch.Tensor, ...]: + with mode: + result = with_effects_dense(token, op, *args, **kwargs) + return result + + +@with_effects.py_impl(ProxyTorchDispatchMode) +def with_effects_proxy( + mode, + token: torch.Tensor, + op: torch._ops.OpOverload, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], +) -> 
Tuple[torch.Tensor, ...]: + if not mode.enable_tracing: + return with_effects(token, op, *args, **kwargs) + + with disable_proxy_modes_tracing(): + out = with_effects(token, op, *args, **kwargs) + + proxy_token = mode.tracer.unwrap_proxy(token) + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args) + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + + out_proxy = mode.tracer.create_proxy( + "call_function", + with_effects, + (proxy_token, op, *proxy_args), + proxy_kwargs, + ) + result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + return result + + +with_effects.fallthrough(DispatchKey.AutogradCPU) +with_effects.fallthrough(DispatchKey.AutogradCUDA) + + +def handle_effects( + allow_token_discovery: bool, + tokens: Dict[_EffectType, torch.Tensor], + op: torch._ops.OpOverload, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], +) -> Any: + """ + Args: + allow_token_discovery: Whether or not we are discovering tokens. If this + is true, we will create a token for every side effect type seen that + does not have a token assigned yet. If this is false, the tokens + should've all been created ahead of time, so we will error if there is + no token mapping to every effect type. + + tokens: Map of effect type to tokens. This is to chain operators of the + same effects together so that they do not get reordered in later + optimization passes. + """ + + # Get a token. We can't do `tokens.get(op, torch.tensor([]))` because + # this will create an empty tensor during proxy mode tracing if the token + # doesn't exist. But the tokens should always exist during proxy mode tracing. 
+ key = get_effect_key(op, args, kwargs) + assert key is not None + if key not in tokens: + assert allow_token_discovery, f"Could not find a token for effect {key}" + tokens[key] = torch.tensor([]) + token = tokens[key] + + from torch._subclasses.functional_tensor import PythonFunctionalizeAPI + + ctx = PythonFunctionalizeAPI() + + unwrapped_token = ctx.unwrap_tensors([token])[0] # type: ignore[arg-type] + unwrapped_args = ctx.unwrap_tensors(args) # type: ignore[arg-type] + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) # type: ignore[arg-type] + with ctx.redispatch_to_next(): + (new_token, *unwrapped_outs) = with_effects( + unwrapped_token, op, *unwrapped_args, **unwrapped_kwargs # type: ignore[arg-type] + ) + + if len(op._schema.returns) == 0: + assert unwrapped_outs[0] is None + unwrapped_outs = None # type: ignore[assignment] + elif len(op._schema.returns) == 1: + assert len(unwrapped_outs) == 1 + unwrapped_outs = unwrapped_outs[0] + else: + assert len(unwrapped_outs) == len(op._schema.returns) + + # Add the newly created token into the tokens map for a following call to + # use this token. 
+ wrapped_token = ctx.wrap_tensors(new_token) + assert isinstance(wrapped_token, torch.Tensor) + tokens[key] = wrapped_token + + return ctx.wrap_tensors(unwrapped_outs) # type: ignore[arg-type] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/map.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/map.py new file mode 100644 index 0000000000000000000000000000000000000000..76f4b89532c86a614a9f0179b32f2f380d0d1d73 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/map.py @@ -0,0 +1,358 @@ +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._dispatch.python import suspend_functionalization +from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun + +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + reenter_make_fx, + UnsupportedAliasMutationException, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch._subclasses.functional_tensor import ( + disable_functional_mode, + FunctionalTensor, +) +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + make_fx, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.multiprocessing.reductions import StorageWeakRef + + +# TODO: We add this to prevent dymamo from tracing into map_wrapper, +# remove the wrapper call when it's ready. 
+class MapWrapper(HigherOrderOperator): + def __call__(self, xs, *args): + return map_wrapper(xs, *args) + + +map = MapWrapper("map") +map_impl = HigherOrderOperator("map_impl") + +dummy_aot_config = AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, +) + + +def create_fw_bw_graph(f, num_mapped_args, *args): + mapped_xs = args[:num_mapped_args] + pos_args = args[num_mapped_args:] + + # Note: We create "clean" environments for make_fx by suspending all dispatch keys + # between Autograd and Python key. Currently, we only suspend functionalization but more can be + # added when required. Will encounter two problems if we don't suspend functionalization: + # + # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper, + # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching. + # However, it's the outside wrapper that tracer creates proxies for. This casuses tracer fail to + # fetch the proxy for the inputs and fail to capture any operations on them. + # + # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further + # wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer + # only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore, + # when creating the output node, it fails to associate the wrapped tensor with its proxy. + # Instead, it will create _tensor_constant as output. 
+ + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + + def _from_fun(t): + if isinstance(t, torch.Tensor): + if t.dtype != torch.bool: + return torch.empty_strided( + t.size(), + t.stride(), + dtype=t.dtype, + requires_grad=t.requires_grad, + ) + else: + # clone of a functional tensor produces a functional tensor + # but we want to avoid it so we clone a non-functional version + maybe_unfunc_t = t + if isinstance(t, FunctionalTensor): + torch._sync(t) + maybe_unfunc_t = from_fun(t) + elif torch._is_functional_tensor(t): + # need to handle both types of functionalization here: + # these are the tensors that came from the user, + # which could be either FunctionalTensorWrapper or FunctionalTensor + torch._sync(t) + maybe_unfunc_t = torch._from_functional_tensor(t) + return maybe_unfunc_t.clone() + return t + + unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs) + example_xs = _unstack_pytree(unwrapped_mapped_xs)[0] + + example_pos_args = [ + _from_fun(arg) if isinstance(arg, torch.Tensor) else arg + for arg in pos_args + ] + example_flat_out = pytree.tree_map( + _from_fun, f(*example_xs, *example_pos_args) + ) + if any( + not isinstance(out, torch.Tensor) + for out in example_flat_out + if out is not None + ): + raise RuntimeError( + "Expect outputs of map only contains tensors or None. " + f"Got types {[type(out) for out in example_flat_out]}." 
+ ) + example_grad = [_from_fun(out) for out in example_flat_out] + + fw_graph = make_fx(f)(*example_xs, *example_pos_args) + + def joint_f(*example_args): + joint_mapped_args = example_args[:joint_num_mapped] + args = example_args[joint_num_mapped:] + + mapped_input = joint_mapped_args[:num_mapped_args] + mapped_grads = joint_mapped_args[num_mapped_args:] + + def fw_with_masks(*args): + fw_out = f(*args) + return fw_out, [ + True + if isinstance(ret, torch.Tensor) and ret.requires_grad + else False + for ret in fw_out + ] + + joint = create_joint(fw_with_masks, aot_config=dummy_aot_config) + _, grads = joint( + list(mapped_input) + list(args), + [ + grad + for grad in mapped_grads + if grad is not None and grad.requires_grad + ], + ) + + # In order to keep map functional for backward graph, + # we clone outputs that are aliasing inputs + input_storage = { + StorageWeakRef(arg._typed_storage()) + for arg in example_args + if isinstance(arg, torch.Tensor) + } + + def maybe_clone(t): + if ( + isinstance(t, torch.Tensor) + and StorageWeakRef(t._typed_storage()) in input_storage + ): + return t.clone() + return t + + return pytree.tree_map(maybe_clone, grads) + + joint_num_mapped = len(example_grad) + len(example_xs) + joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args) + return fw_graph, joint_graph + + +def map_wrapper(f, xs, *args): + flat_xs, xs_spec = pytree.tree_flatten(xs) + if not all(isinstance(t, torch.Tensor) for t in flat_xs): + raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.") + + num_mapped_args = len(flat_xs) + shapes = [xs.shape for xs in flat_xs] + leading_dim_size = shapes[0][0] + if leading_dim_size == 0: + raise RuntimeError("Leading dimensions of mapped xs cannot be 0.") + + if any(cur_shape[0] != leading_dim_size for cur_shape in shapes): + raise RuntimeError( + f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}." 
+ ) + + out_spec = None + + def flat_fn(*flat_args): + xs = pytree.tree_unflatten(list(flat_args[:num_mapped_args]), xs_spec) + unflattened_out = f(xs, *flat_args[num_mapped_args:]) + flat_out, tmp_out_spec = pytree.tree_flatten(unflattened_out) + + nonlocal out_spec + out_spec = tmp_out_spec + return flat_out + + return pytree.tree_unflatten( + map_impl(flat_fn, flat_xs, args), out_spec # type: ignore[arg-type] + ) + + +class MapAutogradOp(torch.autograd.Function): + @staticmethod + def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args): + ctx.save_for_backward(*flat_args) + ctx._joint_graph = joint_graph + ctx._num_mapped_args = num_mapped_args + with torch._C._AutoDispatchBelowAutograd(): + return ( + *map_impl( + fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:] + ), + ) + + @staticmethod + def backward(ctx, *flat_grads): + fw_args = ctx.saved_tensors + fw_mapped_args = fw_args[: ctx._num_mapped_args] + pos_args = fw_args[ctx._num_mapped_args :] + + grads = map_impl( + ctx._joint_graph, + fw_mapped_args + flat_grads, + pos_args, + ) + return None, None, None, *grads + + +def trace_map(proxy_mode, func_overload, f, xs, pos_args): + leading_dim_size = xs[0].shape[0] + + example_input = _unstack_pytree(xs)[0] + body_graph = f + + pre_dispatch = getattr(proxy_mode, "pre_dispatch", False) + body_graph = reenter_make_fx(body_graph, pre_dispatch)(*example_input, *pos_args) + + next_name = None + i = 0 + while not next_name: + candidate = f"body_graph_{i}" + if hasattr(proxy_mode.tracer.root, candidate): + i += 1 + else: + next_name = candidate + + proxy_mode.tracer.root.register_module(next_name, body_graph) + + with disable_proxy_modes_tracing(): + example_outs = body_graph(*example_input, *pos_args) + + def expand_tensor(t): + if isinstance(t, torch.Tensor): + return t.expand(leading_dim_size, *t.shape) + return t + + expanded_outs = pytree.tree_map(expand_tensor, example_outs) + + node_args = (body_graph, list(xs), list(pos_args)) + 
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="map_impl" + ) + return track_tensor_tree( + expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer + ) + + +def _unstack_pytree(xs): + flat_xs, inspec = pytree.tree_flatten(xs) + if not all(isinstance(xs, torch.Tensor) for xs in flat_xs): + raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}") + + if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs): + raise RuntimeError( + f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}" + ) + + a = zip(*flat_xs) + + pytrees = [] + for tuple in a: + pytrees.append(pytree.tree_unflatten(tuple, inspec)) + return pytrees + + +def _stack_pytree(pytrees): + flat_out = [] + out_spec = None + for pt in pytrees: + flat_pt, out_spec = pytree.tree_flatten(pt) + flat_out.append(flat_pt) + assert out_spec is not None + b = zip(*flat_out) + stacked_out = [] + for leaves in b: + if all(isinstance(leaf, torch.Tensor) for leaf in leaves): + stacked_out.append(torch.stack(leaves)) + elif all(leaf is None for leaf in leaves): + # Backward graph can return None output when forward inputs doesn't require grad. + # When we eagerly execute backward graph, we need to call _stack_pytree on its output, + # therefore we need to deal with None output. 
+ stacked_out.append(None) # type: ignore[arg-type] + else: + raise RuntimeError(f"Cannot stack {leaves}.") + return pytree.tree_unflatten(stacked_out, out_spec) + + +@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd) +def map_dense(f, xs, pos_args): + pytrees = [] + for inp in _unstack_pytree(xs): + pytrees.append(f(*inp, *pos_args)) + return _stack_pytree(pytrees) + + +@map_impl.py_impl(DispatchKey.Autograd) +def map_autograd(f, xs, pos_args): + num_mapped_args = len(xs) + fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args) + flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args) + return flat_out + + +@map_impl.py_impl(ProxyTorchDispatchMode) +def map_proxy_torch_dispatch_mode(mode, f, xs, args): + if mode.enable_tracing: + return trace_map(mode, map_impl, f, xs, args) + else: + return map_impl(f, xs, args) + + +@map_impl.py_impl(FakeTensorMode) +def map_fake_tensor_mode(mode, f, xs, args): + with mode: + return map_dense(f, xs, args) + + +@map_impl.py_functionalize_impl +def map_functionalize(ctx, f, xs, pos_args): + unwrapped_xs = ctx.unwrap_tensors(xs) + unwrapped_args = ctx.unwrap_tensors(pos_args) + wrapped_fn = ctx.functionalize(f) + + with ctx.redispatch_to_next(): + with disable_proxy_modes_tracing(): + example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + if _has_potential_branch_input_mutation( + f, example_inputs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException("torch.map is mutating the input!") + + if _has_potential_branch_input_alias( + f, example_inputs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException("torch.map is aliasing the input!") + + map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args) + return ctx.wrap_tensors(map_return) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/strict_mode.py 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/strict_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..57e319230a4abc1bb2dec0763f098023a43dda40 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/strict_mode.py @@ -0,0 +1,100 @@ +import torch +import torch._subclasses.functional_tensor + +import torch.utils._pytree as pytree + +from torch._C import DispatchKey +from torch._functorch.utils import exposed_in + +from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + make_fx, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.utils._python_dispatch import _get_current_dispatch_mode + + +@exposed_in("torch") +def strict_mode(callable, operands): + if torch.compiler.is_dynamo_compiling(): + return strict_mode_op(callable, operands) + + with _set_compilation_env(): + with torch._dynamo.utils.disable_cache_limit(): + return torch.compile(strict_mode_op, backend="eager", fullgraph=True)( + callable, operands + ) + + +strict_mode_op = HigherOrderOperator("strict_mode") + + +@strict_mode_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def strict_mode_op_dense(callable, operands): + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + return callable(*operands) + + +strict_mode_op.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(strict_mode_op, deferred_error=True) +) + + +@strict_mode_op.py_impl(ProxyTorchDispatchMode) +def inner(mode, callable, operands): + if mode.enable_tracing: + return trace_strict_mode(mode, strict_mode_op, callable, operands) + else: + return strict_mode_op(callable, operands) + + +def trace_strict_mode(mode, strict_mode_op, callable, 
operands): + pre_dispatch = getattr(mode, "pre_dispatch", False) + + with disable_proxy_modes_tracing(): + graph = make_fx(callable, pre_dispatch=pre_dispatch)(*operands) + + next_name = None + i = 0 + while not next_name: + candidate = f"strict_graph_{i}" + if hasattr(mode.tracer.root, candidate): + i += 1 + else: + next_name = candidate + + graph_name = next_name + mode.tracer.root.register_module(graph_name, graph) + + args = (graph, operands) + + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args) + + out_proxy = mode.tracer.create_proxy( + "call_function", strict_mode_op, proxy_args, {}, name="strict_mode" + ) + + out = graph(*operands) + return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + + +@strict_mode_op.py_impl(FakeTensorMode) +def strict_mode_fake_tensor_mode(mode, callable, operands): + with mode: + true_outs = callable(*operands) + return true_outs + + +@strict_mode_op.py_functionalize_impl +def strict_mode_func(ctx, callable, inputs): + unwrapped_inputs = ctx.unwrap_tensors(inputs) + with ctx.redispatch_to_next(): + functional_callable = ctx.functionalize(callable) + + cond_return = strict_mode_op(functional_callable, unwrapped_inputs) + return ctx.wrap_tensors(cond_return) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/torchbind.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/torchbind.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca866ee3d8b99dbd4b73b9790845eff995b23e9 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/torchbind.py @@ -0,0 +1,94 @@ +from contextlib import contextmanager + +import torch +from torch._C import DispatchKey # @manual +from torch._functorch._aot_autograd.utils import KNOWN_TYPES +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator +from 
torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.fx.node import has_side_effect +from torch.utils import _pytree as pytree + +# The call_torchbind operator represents a method invocation on a torchbind +# object. The calling convention is: +# call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs) +# We do not expect users to write this operator directly. Instead it will be +# emitted by Dynamo when tracing encounters a torchbind object. +call_torchbind = HigherOrderOperator("call_torchbind") + +# Register this operator as side-effectful with FX. +# TODO: this is not really sufficient. While passes (hopefully) check +# Node.is_impure() and make good decisions, we also assume we can execute the +# graph as many times as we want without changing behavior, which is NOT true of +# ops that mutate torchbind object state. +has_side_effect(call_torchbind) + +_orig_scriptmethod_call = torch.ScriptMethod.__call__ + + +def torchbind_method_redispatch(self, *args, **kwargs): + if isinstance(self.raw_owner, torch.ScriptObject): + return call_torchbind(self.raw_owner, self.name, *args, **kwargs) + return _orig_scriptmethod_call(self, *args, **kwargs) + + +@contextmanager +def enable_torchbind_tracing(): + """Context manager that acts as a feature flag to enable torchbind tracing + behavior. Once torchbind tracing has been stabilized, we can remove this and + turn it always on. + """ + try: + KNOWN_TYPES.append(torch.ScriptObject) + torch.ScriptMethod.__call__ = torchbind_method_redispatch # type: ignore[method-assign] + yield + finally: + assert ( + KNOWN_TYPES.pop() is torch.ScriptObject + ), "Someone else messed with KNOWN_TYPES during tracing, exploding." 
+ torch.ScriptMethod.__call__ = _orig_scriptmethod_call # type: ignore[method-assign] + + +@call_torchbind.py_impl(DispatchKey.CompositeExplicitAutograd) +def call_torchbind_impl(obj, method, *args, **kwargs): + return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs) + + +@call_torchbind.py_impl(ProxyTorchDispatchMode) +def inner(mode, *args, **kwargs): + if mode.enable_tracing: + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args) + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + + out_proxy = mode.tracer.create_proxy( + "call_function", + call_torchbind, + proxy_args, + proxy_kwargs, + ) + out = call_torchbind_impl(*args, **kwargs) + + return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + else: + return call_torchbind(*args, **kwargs) + + +# TODO: currently we just run the C++ implementation with fake tensors. +# But we should make it possible to register a fake torchbind implementation. +@call_torchbind.py_impl(FakeTensorMode) +def call_torchbind_fake(mode, *args, **kwargs): + with mode: + return call_torchbind_impl(*args, **kwargs) + + +call_torchbind.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(call_torchbind, deferred_error=True) +) + + +@call_torchbind.py_functionalize_impl +def call_torchbind_func(ctx, *args, **kwargs): + args = ctx.unwrap_tensors(args) + with ctx.redispatch_to_next(): + return ctx.wrap_tensors(call_torchbind(*args, **kwargs)) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py new file mode 100644 index 0000000000000000000000000000000000000000..89b94561affddf51b5992ab30d52c9c930668a3e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py @@ -0,0 +1,842 @@ +import dataclasses +import logging +import threading 
+import warnings +from collections import defaultdict +from typing import Any, Dict, List, Optional, Union + +import torch.utils._pytree as pytree +from torch import Tensor +from torch._C import DispatchKey +from torch._ops import HigherOrderOperator +from torch._prims_common import clone_preserve_strides +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + +log = logging.getLogger("torch._dynamo") + + +############################################################################### +# Kernel Side Table + + +# We cannot put Triton Kernels into the FX graph as the graph nodes +# do not support arbitrary functions. +# Use a side table. +# We use two dicts so that fetching both the kernel and id are O(1) +class KernelSideTable: + id_to_kernel: Dict[int, Any] = dict() + kernel_to_id: Dict[Any, int] = dict() + lock = threading.Lock() + + # Returns index on the table + def add_kernel(self, kernel) -> int: + with self.lock: + if kernel in self.kernel_to_id: + return self.kernel_to_id[kernel] + + idx = len(self.id_to_kernel) + self.id_to_kernel[idx] = kernel + self.kernel_to_id[kernel] = idx + return idx + + # Returns the triton kernel at the given index + def get_kernel(self, idx: int): + # No need to lock here as fetching from dict is atomic + assert idx in self.id_to_kernel + return self.id_to_kernel[idx] + + # Resets the table (only meant to be used in unit tests) + # This is only safe assuming single threaded execution + def reset_table(self) -> None: + self.id_to_kernel = dict() + self.kernel_to_id = dict() + + +kernel_side_table = KernelSideTable() + + +############################################################################### +# Mutation Tracker + + +@dataclasses.dataclass(frozen=True) +class Param: + idx: int + + +@dataclasses.dataclass(frozen=True) +class Intermediate: + idx: int + + def fake(self): + return self.idx < 0 + 
+ +@dataclasses.dataclass(frozen=True) +class Op: + name: str + fn_call_name: Optional[str] + args: List[Union[Param, Intermediate]] + ret: Intermediate = dataclasses.field(repr=False) + + def __post_init__(self): + if self.name == "tt.call": + assert self.fn_call_name is not None + else: + assert self.fn_call_name is None + + +def generate_ttir(kernel, kwargs): + """ + Uses Triton's internal code generation to create TTIR + """ + from triton.compiler.compiler import ASTSource + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + + import torch + from torch._subclasses.fake_tensor import FakeTensor + + if isinstance(kernel, Autotuner): + if len(kernel.configs) > 0: + # If we are autotuning, then it doesn't matter which version gets + # picked for tracing purposes, so lets pick the first one + kwargs = {**kwargs, **kernel.configs[0].kwargs} + kernel = kernel.fn + + assert isinstance(kernel, JITFunction) + + if len(kwargs) != len(kernel.arg_names): + raise Exception("Incorrect number of arguments passed to kernel") + + # Replace all SymExprs with a regular value for TTIR generation + # Replace all FakeTensor with real tensors + # These replacements are needed for triton's type, key and config functions + ordered_args: Dict[str, Any] = {} + for name in kernel.arg_names: + a = kwargs[name] + if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool)): + ordered_args[name] = 2 + elif isinstance(a, FakeTensor): + ordered_args[name] = torch.empty(2, dtype=a.dtype) + else: + ordered_args[name] = a + + ordered_tensor_names = [ + name for name, arg in ordered_args.items() if isinstance(arg, Tensor) + ] + specialization = kernel._get_config(*ordered_args.values()) + constants = { + i: arg + for i, arg in enumerate(ordered_args.values()) + if not isinstance(arg, Tensor) + } + + # Build kernel signature -- doesn't include constexpr arguments. 
+ signature = { + i: kernel._type_of(kernel._key_of(arg)) + for i, arg in enumerate(ordered_args.values()) + if i not in kernel.constexprs + } + + def get_backend(): + from triton.compiler.backends.cuda import CUDABackend + from triton.runtime.driver import driver + + target = driver.get_current_target() + return CUDABackend(target) + + backend = get_backend() + + options = backend.parse_options(dict()) + # triton._C.libtriton.triton.ir.load_dialects(context) + # backend.load_dialects(context) + + src = ASTSource(kernel, signature, constants, specialization) + ttir_module = src.make_ir(options) + if not ttir_module.verify(): + raise Exception("Verification for TTIR module has failed") + + return ttir_module, ordered_tensor_names + + +def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]: + """ + Walk the `ttir_module` bottom up to mine the `functions` from + the structured MLIR entities representing the Triton kernel + (mlir::Operation, mlir::Block, mlir::Region). 
+ """ + functions: Dict[str, Dict[Intermediate, List[Op]]] = {} + + # block id --> op result (Intermediate) --> one or more ops + op_stack: Dict[int, Dict[Intermediate, List[Op]]] = defaultdict( + lambda: defaultdict(list) + ) + region_id_to_block_ids: Dict[int, List[int]] = defaultdict(list) + block_id_to_block_arg_ids: Dict[int, List[int]] = {} + replacements: Dict[int, Union[Intermediate, Param]] = {} + reindex_map: Dict[int, int] = {} + next_fake_intermediate = 0 + + def reindex(idx): + if idx not in reindex_map: + reindex_map[idx] = len(reindex_map) + return reindex_map[idx] + + def mlir_to_functions(op) -> None: + name: str = op.get_name() + if name == "builtin.module": + # this wraps all tt.func ops + return + + operand_ids: List[int] = [ + reindex(op.get_operand(i).id()) for i in range(op.get_num_operands()) + ] + result_ids: List[int] = [ + reindex(op.get_result(i).id()) for i in range(op.get_num_results()) + ] + + child_block_ids: List[int] = [] + for i in [op.get_region(i).id() for i in range(op.get_num_regions())]: + # as the walk is bottom-up, the region_id_to_block_ids[i] + # must be populated by the time we process the enclosing op + child_block_ids.extend(region_id_to_block_ids[i]) + + parent_block_id = -1 + parent_block = op.get_block() + if parent_block is not None: + parent_block_id = parent_block.id() + if parent_block_id not in block_id_to_block_arg_ids: + block_id_to_block_arg_ids[parent_block_id] = [] + for i in range(parent_block.get_num_arguments()): + block_id_to_block_arg_ids[parent_block_id].append( + reindex(parent_block.get_argument(i).id()), + ) + # the region info is collected via ops' parent blocks to be + # used later when the region's encloding op is traversed + parent_region = parent_block.get_parent() + if parent_region is not None: + region_id_to_block_ids[parent_region.id()].append(parent_block_id) + + nonlocal next_fake_intermediate + + if name == "tt.func": + # for function ops: gather and inline + # the ops from all child 
blocks + fn_ops = defaultdict(list) + for child_block_id in child_block_ids: + for result, block_fn_ops in op_stack.pop(child_block_id).items(): + for block_fn_op in block_fn_ops: + fn_ops[result].append(block_fn_op) + + # replace the corresponding Intermediates in the + # child op args with the function args (Params) + for i, idx in enumerate(block_id_to_block_arg_ids[child_block_ids[0]]): + replacements[idx] = Param(i) + + for fn_op_list in fn_ops.values(): + for fn_op in fn_op_list: + for i in range(len(fn_op.args)): + arg = fn_op.args[i] + if isinstance(arg, Intermediate) and arg.idx in replacements: + fn_op.args[i] = replacements[arg.idx] + + # next function capture starts + # with empty replacements + replacements.clear() + + fn_name = op.get_str_attr("sym_name") + functions[fn_name] = fn_ops + elif child_block_ids: + if name in ("scf.if", "scf.for", "scf.while"): + # for blocked control flow ops: inline the enclosed + # ops into the parent block + rewire the last op in + # each child block (yield) to return the scf result + yield_ops = [] + for block_id in child_block_ids: + # the block args used as operands of the ops in the block + # (and nested blocks inlined in the current block by now) + # are replaced by new fake Intermediates to avoid "this + # operand is not returned by anything other op in the fn" + # error in the downstream analysis + for idx in block_id_to_block_arg_ids[block_id]: + next_fake_intermediate -= 1 + replacements[idx] = Intermediate(next_fake_intermediate) + + if block_id in op_stack: + block_ops = op_stack.pop(block_id) + if not block_ops: + continue + last_ret, last_ops = block_ops.popitem() + if all(op.name == "scf.yield" for op in last_ops): + # if last_ops are scf.yield, treat them separately + yield_ops.extend(last_ops) + else: + # otherwise, return last_ops to the block + block_ops[last_ret] = last_ops + for op_result, child_ops in block_ops.items(): + op_stack[parent_block_id][op_result].extend(child_ops) + + scf_results = 
[Intermediate(idx) for idx in result_ids] + for scf_result in scf_results: + for yield_op in yield_ops: + op_stack[parent_block_id][scf_result].append(yield_op) + else: + # TODO(oulgen): add support for tt.reduce + raise Exception( + f"Unknown blocked function: {name}. Can't capture the TTIR." + ) + else: + callee = None + if name == "tt.call": + callee = op.get_flat_symbol_ref_attr("callee") + args: List[Union[Param, Intermediate]] = [ + Intermediate(operand) for operand in operand_ids + ] + block_ops = op_stack[parent_block_id] + if result_ids: + for result_id in result_ids: + res = Intermediate(result_id) + block_ops[res].append(Op(name, callee, args, res)) + else: + next_fake_intermediate -= 1 + fake_res = Intermediate(next_fake_intermediate) + block_ops[fake_res].append(Op(name, callee, args, fake_res)) + + ttir_module.walk(mlir_to_functions) + + return functions + + +def parse_ttir(ttir, kwargs): + """ + Given a Triton emitted TTIR text, this function lexes and parses the + code using a minimal grammar defined inside. During the lexing/parsing, + we drop any constant value and type information as they are not + necessary to us. + Being able to choose what we need makes this not a general purpose TTIR + parser which further makes parsing much simpler. + """ + # TODO(oulgen): + # - Support closures (e.g. "tt.reduce") + + try: + import lark # type: ignore[import-not-found] + from lark import Lark, Transformer, v_args + except ModuleNotFoundError: + warnings.warn( + "Using slow path for user-defined Triton kernels. `pip install lark` to fix this." 
+ ) + raise + + # Ops looks like one of the following forms: + # + # %14 = tt.addptr %13, %4 : tensor<4x!tt.ptr>, tensor<4xi32> + # tt.store %14, %12, %5 {cache = 1 : i32, evict = 1 : i32} : tensor<4xf32> + # %15 = "tt.atomic_rmw"(%14, %12, %5) <{atomic_rmw_op = 5 : i32, scope = 1 : i32, sem = 4 : i32}> : (tensor<4x!tt.ptr>, tensor<4xf32>, tensor<4xi1>) -> tensor<4xf32> # noqa: B950 + grammar = """ + start: (module_block | loc_line)+ + + loc_line: "#loc" /.+/ NEWLINE + + module_block: "module" "{" func_block+ "}" LOC + + func_block: "tt.func" ("public"|"private") FN_NAME "(" /.+/ NEWLINE stmt* "}" LOC -> process_func + + ?stmt: op | if | for | while | condition_stmt | label_stmt | cf_stmt + + if: [assign_lhs "="] "scf.if" args rest stmt* "}" "else" "{" stmt* "}" LOC -> process_if + for: [assign_lhs "="] "scf.for" args rest stmt* "}" divisibility_annot? LOC -> process_for + while: [assign_lhs "="] "scf.while" args rest stmt* "}" "do" "{" stmt* "}" LOC -> process_while + + condition_stmt: "scf.condition" "(" arg ")" args rest + label_stmt: LABEL ":" "// pred:" LABEL + | LABEL "(" /.+/ NEWLINE + cf_stmt: "cf" "." NAME /.+/ NEWLINE + + op: OP_NAME LOC + | [assign_lhs "="] OP_NAME [FN_NAME] args rest? -> process_op + + ?rest: (":" | "{" | "\\"" | "->" | "<" | "=") /.+/ NEWLINE + divisibility_annot: "{" "tt.divisibility_arg1" /[^}]+/ "}" + + args: | "(" ")" | "("? arg ("," arg)* ")"? + + ?arg: INTERMEDIATE + | INTERMEDIATE_CONSTANT + | CONSTANT + | PARAM + | "[" args "]" + | arg_with_index + + ?arg_with_index: arg "#" DIGIT+ + + ?assign_lhs: (INTERMEDIATE | INTERMEDIATE_CONSTANT) [":" DIGIT+] + + PARAM.5: "%arg" DIGIT+ + INTERMEDIATE.4: "%" DIGIT+ + INTERMEDIATE_CONSTANT.3: "%" NAME + CONSTANT: FLOAT | DIGIT+ | NAME ("<" DIGIT+ ">")? + LABEL: "^bb" DIGIT+ + + NAME: (LETTER | DIGIT | "_")+ + NON_CF_NAME: /(?!(cf))/ NAME + FN_NAME: "@" (NAME | ESCAPED_STRING) + OP_NAME: "\\""? NON_CF_NAME ("." NAME)+ "\\""? 
+ + LOC.5: "loc(#loc" DIGIT* ")" + + %import common.LETTER + %import common.DIGIT + %import common.WS + %import common.NEWLINE + %import common.ESCAPED_STRING + %import common.FLOAT + %ignore WS + """ + + next_fake_intermediate = 0 + + def convert(token): + if isinstance(token, lark.tree.Tree): + if token.data == "args": + res = [] + for a in token.children: + c = convert(a) + if isinstance(c, list): + res.extend(c) + else: + res.append(c) + return res + elif token.data in {"assign_lhs", "arg_with_index"}: + # Drop length/index qualifier + return convert(token.children[0]) + else: + raise AssertionError(f"Tree node with {token.data}") + + if token is None or ( + isinstance(token, lark.lexer.Token) + and token.type in ("CONSTANT", "INTERMEDIATE_CONSTANT") + ): + nonlocal next_fake_intermediate + next_fake_intermediate -= 1 + return Intermediate(next_fake_intermediate) + + assert isinstance(token, lark.lexer.Token) + + if token.type == "INTERMEDIATE": + return Intermediate(int(token.value[len("%") :])) + if token.type == "PARAM": + return Param(int(token.value[len("%arg") :])) + + raise AssertionError(f"{type(token.type)} => {token.value} invalid") + + # In alternative representation, function names are quoted. + # It should be possible to move this into the grammar alltogether. 
class MemoizeWithCycleCheck:
    """Memoizing decorator keyed on ``(fn_name, num_args)`` that also detects
    cyclic (recursive) invocation of the wrapped analysis function.

    A ``None`` entry in the cache marks a call that is currently in flight;
    re-entering with the same key before the first call finished means the
    TTIR call graph is cyclic, which is reported as an error.
    NOTE: this relies on the wrapped ``fn`` never legitimately returning
    ``None`` (the analysis below returns a list).
    """

    def __init__(self, fn):
        self.fn = fn
        self.reset()

    def __call__(self, functions, fn_name, num_args):
        key = (fn_name, num_args)
        if key not in self.cache:
            # Sentinel: mark this key as "in progress" before computing, so a
            # recursive call with the same key is detected below.
            self.cache[key] = None
            self.cache[key] = self.fn(functions, fn_name, num_args)
        result = self.cache[key]
        if result is None:
            # RuntimeError is a subclass of Exception, so existing callers
            # that catch Exception still work.
            raise RuntimeError("Recursion is not supported")
        return result

    def reset(self):
        # Drop all memoized results (and any stale in-progress sentinels).
        self.cache = {}
@MemoizeWithCycleCheck
def analyze_kernel_mutations(functions, fn_name, num_args):
    """
    Analyzes the graph to detect all sinks from a predefined list of sinks
    by using triton's MemWrite trait list. NOTE: What if triton exposed this?
    From each sink, it traverses the CFG backwards to identify all the input
    pointers that are mutated.

    Args:
        functions: mapping of TTIR function name -> {Intermediate: [Op]},
            as produced by ttir_to_functions / parse_ttir.
        fn_name: name of the function in ``functions`` to analyze.
        num_args: number of parameters of that function.

    Returns:
        List of ``num_args`` booleans; ``True`` at index ``i`` means parameter
        ``i`` may be mutated by the kernel.
    """
    # Name of mutation op to mutated parameter indices
    # List from Triton Github include/triton/Dialect/Triton/IR/TritonOps.td
    # All the OPs that have MemWrite trait.
    # What if Triton exposed this?
    MUTATION_OPS = {"tt.store": [0], "tt.atomic_cas": [0], "tt.atomic_rmw": [0]}
    # Ops that we want to bail out on
    UNKNOWN_OPS = {"tt.elementwise_inline_asm"}

    stack: List[Union[Param, Intermediate]] = []
    visited = set()
    ops = functions[fn_name]
    # Pass 1: seed the stack with every value that a mutating op writes to
    # (the "sinks"); calls recurse into the callee's own analysis.
    for op_list in ops.values():
        for op in op_list:
            if op.name in UNKNOWN_OPS:
                raise Exception(
                    f"ttir analysis hit an op we do not know how to analyze: {op.name}"
                )

            if op.name == "tt.call":
                assert op.fn_call_name in functions
                mutations = analyze_kernel_mutations(
                    functions, op.fn_call_name, len(op.args)
                )
                stack.extend(arg for arg, mutated in zip(op.args, mutations) if mutated)
            else:
                for idx in MUTATION_OPS.get(op.name, []):
                    stack.append(op.args[idx])

    # The following is an iterative DFS algorithm
    # Pass 2: walk backwards from each sink through the def chains; any Param
    # reached is an input pointer that may be mutated.
    mutated = [False] * num_args
    while stack:
        arg = stack.pop()
        if arg in visited:
            continue

        visited.add(arg)

        if isinstance(arg, Param):
            if arg.idx >= num_args:
                # This is an argument defined in the kernel, not passed in
                continue
            mutated[arg.idx] = True
        elif isinstance(arg, Intermediate) and not arg.fake():
            for op in ops[arg]:
                # Skip arguments to load
                # (a load's pointer operand is only read, never written)
                if op.name != "tt.load":
                    stack.extend(op.args)
    return mutated


def identify_mutated_tensors(kernel, kwargs):
    """
    Given a triton kernel and the arguments for this kernel, this function
    1) Retrieves the TTIR converted version of the kernel from Triton's API.
    2) Parses the TTIR and creates a control flow graph
    3) Analyzes the graph to detect all input tensor mutations

    Returns:
        List of kwarg names whose tensors may be mutated by the kernel.
        On any failure this conservatively falls back to "every tensor
        argument is mutated".
    """

    ttir_module = None
    functions = None
    try:
        from torch._dynamo import config

        if not config.optimize_user_defined_triton_kernels:
            raise Exception("optimize_user_defined_triton_kernels is False")

        ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)

        # extract functions from TTIR
        if hasattr(ttir_module, "walk"):
            # use MLIR bindings exposed by Triton code
            functions = ttir_to_functions(ttir_module)
        else:
            # parse string representation of Triton IR
            functions = parse_ttir(str(ttir_module), kwargs)

        assert functions is not None
        kernel_name = next(iter(functions.keys()))
        # Triton codegen modifies the name
        assert kernel.fn.__name__ in kernel_name
        # Reset the cache between top level invocations
        # The cache for analyze kernel mutations is mainly used for cycle
        # detection, so each top level invocation needs a clean cache
        analyze_kernel_mutations.reset()
        mutations = analyze_kernel_mutations(
            functions, kernel_name, len(ordered_tensor_names)
        )

        return [
            ordered_tensor_names[i] for i, mutated in enumerate(mutations) if mutated
        ]
    except Exception as e:
        # Deliberate best-effort fallback: the analysis is an optimization,
        # so any failure degrades to the safe answer rather than crashing.
        import traceback

        warnings.warn(
            "Encountered an exception in identify_mutated_tensors, "
            "assuming every input is mutated:\n"
            + "".join(
                traceback.TracebackException.from_exception(e).format()  # noqa: G001
            )
        )
        if ttir_module is not None:
            log.debug("TTIR:\n%s", str(ttir_module))
        if functions is not None:
            log.debug("functions:")
            for name, fn in functions.items():
                log.debug("===\t%s\t===", name)
                for ret, ops in fn.items():
                    log.debug("%s\t=>\t%s", ret, ops)
        return [key for key, value in kwargs.items() if isinstance(value, Tensor)]
class TritonKernelWrapperMutation(HigherOrderOperator):
    """HOP that runs a user-defined Triton kernel in place (mutating inputs)."""

    def __init__(self):
        super().__init__("triton_kernel_wrapper_mutation")


triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()


# Used for wrapping a Triton Kernel in a functional manner
class TritonKernelWrapperFunctional(HigherOrderOperator):
    """HOP that runs a user-defined Triton kernel functionally: mutated
    inputs are cloned first and the mutated clones are returned."""

    def __init__(self):
        super().__init__("triton_kernel_wrapper_functional")


triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()


@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_mutation_dense(*, kernel_idx, grid, kwargs):
    """Eager implementation: look the kernel up in the side table and launch it."""
    from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code

    kernel = kernel_side_table.get_kernel(kernel_idx)

    if len(grid) == 1:
        grid_fn = grid[0]
    else:
        # Multiple grids (one per autotuner config): codegen a grid-dispatch
        # function and exec it into a scratch namespace.
        fn_name, code = user_defined_kernel_grid_fn_code(
            kernel.fn.__name__, kernel.configs, grid
        )
        namespace: Dict[str, Any] = {}
        exec(code, namespace)
        grid_fn = namespace[fn_name]

    kernel[grid_fn](**kwargs)


@triton_kernel_wrapper_mutation.py_impl(FakeTensorMode)
def triton_kernel_wrapper_mutation_fake_tensor_mode(mode, *, kernel_idx, grid, kwargs):
    # The mutation HOP has no outputs; under fake tensors there is nothing to do.
    with mode:
        return None


def trace_triton_kernel_wrapper(proxy_mode, func_overload, node_args):
    """Run ``func_overload`` for real (proxy tracing disabled), then record a
    corresponding call_function node in the traced graph."""
    with disable_proxy_modes_tracing():
        out = func_overload(**node_args)

    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function",
        func_overload,
        (),
        proxy_args,
        name=func_overload.__name__ + "_proxy",
    )
    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
    mode, *, kernel_idx, grid, kwargs
):
    if mode.enable_tracing:
        trace_triton_kernel_wrapper(
            mode,
            triton_kernel_wrapper_mutation,
            {"kernel_idx": kernel_idx, "grid": grid, "kwargs": kwargs},
        )
    else:
        triton_kernel_wrapper_mutation(kernel_idx=kernel_idx, grid=grid, kwargs=kwargs)

    return None


@triton_kernel_wrapper_mutation.py_functionalize_impl
def triton_kernel_wrapper_mutation_functionalize(ctx, kernel_idx, grid, kwargs):
    """Functionalization: dispatch to the functional HOP, then propagate the
    mutated clones back onto the original (wrapped) input tensors."""
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    kernel = kernel_side_table.get_kernel(kernel_idx)
    # TODO(oulgen): Preexisting bug, if two kernel inputs are views of each
    # other, and one gets mutated in kernel, and later another gets mutated,
    # they are no longer equal. Fix this by graph breaking on this condition
    # earlier in dynamo.
    tensors_to_clone = identify_mutated_tensors(kernel, unwrapped_kwargs)
    with ctx.redispatch_to_next():
        unwrapped_outputs = triton_kernel_wrapper_functional(
            kernel_idx=kernel_idx,
            grid=grid,
            kwargs=unwrapped_kwargs,
            tensors_to_clone=tensors_to_clone,
        )

    assert set(unwrapped_outputs.keys()).issubset(set(kwargs.keys()))
    for key, output_arg in unwrapped_outputs.items():
        if not isinstance(output_arg, Tensor):
            continue
        input_arg = kwargs[key]
        assert isinstance(input_arg, Tensor)

        ctx.replace(input_arg, output_arg)
        # indicate that above replace is hidden from autograd
        ctx.mark_mutation_hidden_from_autograd(input_arg)
        ctx.commit_update(input_arg)
        ctx.sync(input_arg)
        # sync calls replace_ under the hood, so again indicate that
        # this indirect replace is hidden from autograd
        ctx.mark_mutation_hidden_from_autograd(input_arg)
    return None


@triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_functional_dense(
    *, kernel_idx, grid, kwargs, tensors_to_clone
):
    # TODO(oulgen): For performance reasons, we want to ensure that these
    # `clone_preserve_strides` calls are never executed at runtime
    # (inductor should always optimize them away).
    # Requires https://github.com/pytorch/pytorch/issues/109240
    kwargs = {
        key: (clone_preserve_strides(val) if key in tensors_to_clone else val)
        for key, val in kwargs.items()
    }
    triton_kernel_wrapper_mutation(kernel_idx=kernel_idx, grid=grid, kwargs=kwargs)
    return {key: val for key, val in kwargs.items() if key in tensors_to_clone}


@triton_kernel_wrapper_functional.py_impl(FakeTensorMode)
def triton_kernel_wrapper_functional_fake_tensor_mode(
    mode, *, kernel_idx, grid, kwargs, tensors_to_clone
):
    # TODO(oulgen): For performance reasons, we want to ensure that these
    # `clone_preserve_strides` calls are never executed at runtime
    # (inductor should always optimize them away).
    # Requires https://github.com/pytorch/pytorch/issues/109240
    with mode:
        return {
            key: clone_preserve_strides(val)
            for key, val in kwargs.items()
            if key in tensors_to_clone
        }


@triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
    mode, *, kernel_idx, grid, kwargs, tensors_to_clone
):
    if mode.enable_tracing:
        return trace_triton_kernel_wrapper(
            mode,
            triton_kernel_wrapper_functional,
            {
                "kernel_idx": kernel_idx,
                "grid": grid,
                "kwargs": kwargs,
                "tensors_to_clone": tensors_to_clone,
            },
        )
    else:
        return triton_kernel_wrapper_functional(
            kernel_idx=kernel_idx,
            grid=grid,
            kwargs=kwargs,
            tensors_to_clone=tensors_to_clone,
        )


@triton_kernel_wrapper_functional.py_functionalize_impl
def triton_kernel_wrapper_functional_functionalize(
    ctx, kernel_idx, grid, kwargs, tensors_to_clone
):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        outputs = triton_kernel_wrapper_functional(
            kernel_idx=kernel_idx,
            grid=grid,
            kwargs=unwrapped_kwargs,
            tensors_to_clone=tensors_to_clone,
        )
        return ctx.wrap_tensors(outputs)


triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCUDA)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCPU)

triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
# FIX: AutogradCUDA was previously registered twice; the duplicate is removed.
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCPU)
_has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + _set_compilation_env, + autograd_not_implemented, + reenter_make_fx, + UnsupportedAliasMutationException, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +class WhileLoopOp(HigherOrderOperator): + def __call__(self, cond_fn, body_fn, operands): + if not isinstance(cond_fn, torch.fx.GraphModule) or not isinstance( + body_fn, torch.fx.GraphModule + ): + raise RuntimeError( + "cond_fn and body_fn must be torch.fx.GraphModule, got " + f"{type(cond_fn)} and {type(body_fn)}" + ) + if not isinstance(operands, tuple): + raise RuntimeError("operands must be a tuple, got " f"{type(operands)}") + if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in operands): + raise RuntimeError( + "operands must be a tuple of tensors, ints, floats, or bools, got " + f"{operands}" + ) + return super().__call__(cond_fn, body_fn, operands) + + +while_loop_op = HigherOrderOperator("while_loop") + + +def while_loop(cond_fn, body_fn, operands): + r""" + Run body_fn(*operands) while cond_fn(*operands) returns a True scalar tensor. Returns the output of body_fn or + initial operands. + + .. warning:: + `torch.while_loop` is a prototype feature in PyTorch. It has limited support for input and output types and + doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch. + Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype + + `while_loop` is a structured control flow operator. It preserves the loop semantic across the torch.compile and torch.export. 
+ + `while_loop` is equivalent to the following: + + def while_loop(cond_fn, body_fn, operands): + val = operands + while cond_fn(*val): + val = body_fn(*val) + return val + + Args: + cond_fn (Callable): A callable function that returns a boolean Scalar tensor. + + body_fn (Callable): A callable function that takes the same inputs as `cond_fn` and returns a tuple of tensors + + operands (Tuple of possibly nested dict/list/tuple of tensors): A tuple of inputs to cond_fn and body_fn. It's also + the initial value of states that are carried across iterations. + + Example: + + def cond_fn(iter, x): + return iter.sum() < 10 + + def body_fn(iter, x): + return iter + 1, x.sin() + + while_loop(cond_fn, body_fn, (torch.zeros(1), torch.randn(3, 4))) + + Restrictions: + + - body_fn must return tensors with the same metadata (e.g.shape, dtype) as inputs. + + - body_fn and cond_fn must not in-place mutate the operands. A clone before the mutation is required. + + - body_fn and cond_fn must not mutate python varialbles (e.g. list/dict) created outside of the body_fn. + + - body_fn and cond_fn's output cannot aliase any of the inputs. A clone is required. + + .. warning:: + Temporal Limitations: + + - 'while_loop' only supports **inference** right now. Autograd will be supported in the future. + + """ + if torch.compiler.is_dynamo_compiling(): + return while_loop_op(cond_fn, body_fn, operands) + + def _validate_input(cond_fn, body_fn, operands): + if not callable(cond_fn) or not callable(body_fn): + raise RuntimeError("Expect cond_fn and body_fn to be callbale.") + + if not isinstance(operands, (tuple, list)) or pytree.tree_any( + lambda t: not isinstance(t, torch.Tensor), operands + ): + raise RuntimeError( + "Expect operands to be a tuple of possibly nested dict/list/tuple that only" + f"consists of tensor leaves, but got {operands}." 
+ ) + + _validate_input(cond_fn, body_fn, operands) + + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): + return torch.compile(while_loop_op, backend="eager", fullgraph=True)( + cond_fn, body_fn, operands + ) + + +@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def while_loop_dense(cond_fn, body_fn, operands): + init_val = operands + + def _is_boolean_scalar_tensor(pred): + return ( + isinstance(pred, torch.Tensor) + and pred.size() == torch.Size([]) + and pred.dtype == torch.bool + ) + + if not isinstance(operands, tuple): + raise RuntimeError(f"operands must be a tuple but got {type(operands)}") + + while pred := cond_fn(*init_val): + if not _is_boolean_scalar_tensor(pred): + raise RuntimeError( + f"cond_fn must return a boolean scalar tensor but got {pred}" + ) + out = body_fn(*init_val) + assert isinstance( + out, tuple + ), f"body_fn should return a tuple but got {type(out)}" + assert len(out) == len( + init_val + ), "body_fn should return the same number of elements as operands" + init_val = out + return init_val + + +while_loop_op.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(while_loop_op, deferred_error=True) +) + + +@while_loop_op.py_impl(ProxyTorchDispatchMode) +def while_loop_tracing(mode, cond_fn, body_fn, operands): + def _trace_while_loop(proxy_mode, while_loop_op, cond_fn, body_fn, operands): + pre_dispatch = getattr(proxy_mode, "pre_dispatch", False) + with disable_proxy_modes_tracing(): + cond_graph = reenter_make_fx(cond_fn, pre_dispatch)(*operands) + body_graph = reenter_make_fx(body_fn, pre_dispatch)(*operands) + + next_name = None + i = 0 + while not next_name: + candidate = f"while_loop_cond_graph_{i}" + if hasattr(proxy_mode.tracer.root, candidate): + i += 1 + else: + next_name = candidate + cond_graph_name = next_name + body_graph_name = f"while_loop_body_graph_{i}" + assert not hasattr(proxy_mode.tracer.root, body_graph_name) + + proxy_mode.tracer.root.register_module(cond_graph_name, 
@while_loop_op.py_impl(FakeTensorMode)
def while_loop_fake_tensor_mode(mode, cond_fn, body_fn, operands):
    # body_fn is required to return outputs with the same metadata as its
    # inputs (see while_loop's docstring), so a single application yields the
    # correct fake output for any number of iterations.
    return body_fn(*operands)


@while_loop_op.py_functionalize_impl
def while_loop_func(ctx, cond_fn, body_fn, operands):
    """Functionalization: reject branches that mutate or alias their inputs,
    then redispatch with functionalized cond/body on unwrapped operands."""
    unwrapped_operands = ctx.unwrap_tensors(operands)
    with ctx.redispatch_to_next() as m:
        functional_cond_fn = ctx.functionalize(cond_fn)
        functional_body_fn = ctx.functionalize(body_fn)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        for fn, fn_name in [
            (functional_cond_fn, "cond_fn"),
            (functional_body_fn, "body_fn"),
        ]:
            if _has_potential_branch_input_mutation(
                fn, unwrapped_operands, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    f"torch.while_loop's {fn_name} might be modifying the input!"
                )

        # FIX: this loop previously iterated only over the functions while
        # reusing the stale `fn_name` from the loop above, so the aliasing
        # error always blamed "body_fn". Iterate (fn, fn_name) pairs instead.
        for fn, fn_name in [
            (functional_cond_fn, "cond_fn"),
            (functional_body_fn, "body_fn"),
        ]:
            if _has_potential_branch_input_alias(
                fn, unwrapped_operands, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    f"torch.while_loop's {fn_name} might be aliasing the input!"
                )
        ret = while_loop_op(functional_cond_fn, functional_body_fn, unwrapped_operands)
        return ctx.wrap_tensors(ret)
import torch


def is_available():
    r"""Return True if this PyTorch build links against Intel oneMKL."""
    return torch._C.has_mkl


VERBOSE_OFF = 0
VERBOSE_ON = 1


class verbose:
    """
    Context manager that scopes oneMKL verbose output on demand.

    The ``MKL_VERBOSE`` environment variable enables verbose kernel timing
    messages for an entire run, which is usually far too noisy; typically one
    iteration of verbose output is enough to investigate a performance issue.
    This helper turns verbosing on only for the duration of the ``with`` block
    and switches it off again on exit.

    .. highlight:: python
    .. code-block:: python

        import torch
        model(data)
        with torch.backends.mkl.verbose(torch.backends.mkl.VERBOSE_ON):
            model(data)

    Args:
        level: Verbose level
            - ``VERBOSE_OFF``: Disable verbosing
            - ``VERBOSE_ON``: Enable verbosing
    """

    def __init__(self, enable):
        # Requested verbosity for this scope (VERBOSE_OFF or VERBOSE_ON).
        self.enable = enable

    def __enter__(self):
        if self.enable != VERBOSE_OFF:
            status = torch._C._verbose.mkl_set_verbose(self.enable)
            assert (
                status
            ), "Failed to set MKL into verbose mode. Please consider to disable this verbose scope."
            return self
        # NOTE(review): with VERBOSE_OFF the original returns None rather than
        # self; preserved here for identical behavior.
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Unconditionally switch verbosing back off; never swallow exceptions.
        torch._C._verbose.mkl_set_verbose(VERBOSE_OFF)
        return False
def is_available():
    r"""Return whether PyTorch is built with OpenMP support."""
    # The flag is baked into the torch binary at build time.
    has_openmp = torch._C.has_openmp
    return has_openmp
class context:
    '''
    Context object to wrap forward and backward passes when using
    distributed autograd. The ``context_id`` generated in the ``with``
    statement is required to uniquely identify a distributed backward pass
    on all workers. Each worker stores metadata associated with this
    ``context_id``, which is required to correctly execute a distributed
    autograd pass.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.distributed.autograd as dist_autograd
        >>> with dist_autograd.context() as context_id:
        >>>     t1 = torch.rand((3, 3), requires_grad=True)
        >>>     t2 = torch.rand((3, 3), requires_grad=True)
        >>>     loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
        >>>     dist_autograd.backward(context_id, [loss])
    '''

    def __enter__(self):
        # Allocate a fresh dist autograd context and hand its id to the caller.
        ctx = _new_context()
        self.autograd_context = ctx
        return ctx._context_id()

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Release the context on the way out, whether or not the body raised.
        _release_context(self.autograd_context._context_id())
import json
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Dict, Union, Optional

__all__ = ['EventSource', 'Event', 'NodeState', 'RdzvEvent']

# Values allowed in an event's metadata dict (JSON-friendly scalars).
EventMetadataValue = Union[str, int, float, bool, None]


class EventSource(str, Enum):
    """Known identifiers of the event producers."""

    AGENT = "AGENT"
    WORKER = "WORKER"


@dataclass
class Event:
    """
    The class represents the generic event that occurs during the torchelastic job execution.

    The event can be any kind of meaningful action.

    Args:
        name: event name.
        source: the event producer, e.g. agent or worker
        timestamp: timestamp in milliseconds when event occurred.
        metadata: additional data that is associated with the event.
    """

    name: str
    source: EventSource
    timestamp: int = 0
    metadata: Dict[str, EventMetadataValue] = field(default_factory=dict)

    def __str__(self):
        return self.serialize()

    @staticmethod
    def deserialize(data: Union[str, "Event"]) -> "Event":
        """Rebuild an Event from its JSON serialization (pass-through for
        Event instances).

        Raises:
            TypeError: if *data* is neither a str nor an Event.
        """
        if isinstance(data, Event):
            return data
        if isinstance(data, str):
            data_dict = json.loads(data)
            # str-mixin enums serialize as their member name; look it up back.
            data_dict["source"] = EventSource[data_dict["source"]]
            return Event(**data_dict)
        # FIX: previously fell through and implicitly returned None despite
        # the declared return type; fail loudly instead.
        raise TypeError(f"Cannot deserialize event from {type(data)!r}")

    def serialize(self) -> str:
        """Return this event as a JSON string."""
        return json.dumps(asdict(self))


class NodeState(str, Enum):
    """The states that a node can be in rendezvous."""

    INIT = "INIT"
    RUNNING = "RUNNING"
    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"


@dataclass
class RdzvEvent:
    """
    Dataclass to represent any rendezvous event.

    Args:
        name: Event name. (E.g. Current action being performed)
        run_id: The run id of the rendezvous
        message: The message describing the event
        hostname: Hostname of the node
        pid: The process id of the node
        node_state: The state of the node (INIT, RUNNING, SUCCEEDED, FAILED)
        master_endpoint: The master endpoint for the rendezvous store, if known
        rank: The rank of the node, if known
        local_id: The local_id of the node, if defined in dynamic_rendezvous.py
        error_trace: Error stack trace, if this is an error event.
    """

    name: str
    run_id: str
    message: str
    hostname: str
    pid: int
    node_state: NodeState
    master_endpoint: str = ""
    rank: Optional[int] = None
    local_id: Optional[int] = None
    error_trace: str = ""

    def __str__(self):
        return self.serialize()

    @staticmethod
    def deserialize(data: Union[str, "RdzvEvent"]) -> "RdzvEvent":
        """Rebuild an RdzvEvent from its JSON serialization (pass-through for
        RdzvEvent instances).

        Raises:
            TypeError: if *data* is neither a str nor an RdzvEvent.
        """
        if isinstance(data, RdzvEvent):
            return data
        if isinstance(data, str):
            data_dict = json.loads(data)
            data_dict["node_state"] = NodeState[data_dict["node_state"]]
            return RdzvEvent(**data_dict)
        # FIX: same silent-None fall-through as Event.deserialize; fail loudly.
        raise TypeError(f"Cannot deserialize rendezvous event from {type(data)!r}")

    def serialize(self) -> str:
        """Return this event as a JSON string."""
        return json.dumps(asdict(self))
+ +import logging +from typing import Dict + + +_log_handlers: Dict[str, logging.Handler] = { + "console": logging.StreamHandler(), + "dynamic_rendezvous": logging.NullHandler(), + "null": logging.NullHandler(), +} + + +def get_logging_handler(destination: str = "null") -> logging.Handler: + global _log_handlers + return _log_handlers[destination] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/metrics/api.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/metrics/api.py new file mode 100644 index 0000000000000000000000000000000000000000..1499943c78d24d0fdaac31526318c3067743c79c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/metrics/api.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import abc +import time +import warnings +from collections import namedtuple +from functools import wraps +from typing import Dict, Optional + +__all__ = ['MetricsConfig', 'MetricHandler', 'ConsoleMetricHandler', 'NullMetricHandler', 'MetricStream', + 'configure', 'getStream', 'prof', 'profile', 'put_metric', 'publish_metric', 'get_elapsed_time_ms', + 'MetricData'] + +MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"]) + + +class MetricsConfig: + __slots__ = ["params"] + + def __init__(self, params: Optional[Dict[str, str]] = None): + self.params = params + if self.params is None: + self.params = {} + + +class MetricHandler(abc.ABC): + @abc.abstractmethod + def emit(self, metric_data: MetricData): + pass + + +class ConsoleMetricHandler(MetricHandler): + def emit(self, metric_data: MetricData): + print( + f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}" + ) + + +class NullMetricHandler(MetricHandler): + def emit(self, metric_data: MetricData): + pass + + +class MetricStream: + def __init__(self, group_name: str, handler: MetricHandler): + self.group_name = group_name + self.handler = handler + + def add_value(self, metric_name: str, metric_value: int): + self.handler.emit( + MetricData(time.time(), self.group_name, metric_name, metric_value) + ) + + +_metrics_map: Dict[str, MetricHandler] = {} +_default_metrics_handler: MetricHandler = NullMetricHandler() + + +# pyre-fixme[9]: group has type `str`; used as `None`. +def configure(handler: MetricHandler, group: Optional[str] = None): + if group is None: + global _default_metrics_handler + # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used + # as `MetricHandler`. 
+ _default_metrics_handler = handler + else: + _metrics_map[group] = handler + + +def getStream(group: str): + if group in _metrics_map: + handler = _metrics_map[group] + else: + handler = _default_metrics_handler + return MetricStream(group, handler) + + +def _get_metric_name(fn): + qualname = fn.__qualname__ + split = qualname.split(".") + if len(split) == 1: + module = fn.__module__ + if module: + return module.split(".")[-1] + "." + split[0] + else: + return split[0] + else: + return qualname + + +def prof(fn=None, group: str = "torchelastic"): + r""" + @profile decorator publishes duration.ms, count, success, failure metrics for the function that it decorates. + + The metric name defaults to the qualified name (``class_name.def_name``) of the function. + If the function does not belong to a class, it uses the leaf module name instead. + + Usage + + :: + + @metrics.prof + def x(): + pass + + @metrics.prof(group="agent") + def y(): + pass + """ + + def wrap(f): + @wraps(f) + def wrapper(*args, **kwargs): + key = _get_metric_name(f) + try: + start = time.time() + result = f(*args, **kwargs) + put_metric(f"{key}.success", 1, group) + except Exception: + put_metric(f"{key}.failure", 1, group) + raise + finally: + put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group) # type: ignore[possibly-undefined] + return result + + return wrapper + + if fn: + return wrap(fn) + else: + return wrap + + +def profile(group=None): + """ + @profile decorator adds latency and success/failure metrics to any given function. 
+ + Usage + + :: + + @metrics.profile("my_metric_group") + def some_function(): + """ + warnings.warn("Deprecated, use @prof instead", DeprecationWarning) + + def wrap(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + start_time = time.time() + result = func(*args, **kwargs) + publish_metric(group, f"{func.__name__}.success", 1) + except Exception: + publish_metric(group, f"{func.__name__}.failure", 1) + raise + finally: + publish_metric( + group, + f"{func.__name__}.duration.ms", + get_elapsed_time_ms(start_time), # type: ignore[possibly-undefined] + ) + return result + + return wrapper + + return wrap + + +def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"): + """ + Publish a metric data point. + + Usage + + :: + + put_metric("metric_name", 1) + put_metric("metric_name", 1, "metric_group_name") + """ + getStream(metric_group).add_value(metric_name, metric_value) + + +def publish_metric(metric_group: str, metric_name: str, metric_value: int): + warnings.warn( + "Deprecated, use put_metric(metric_group)(metric_name, metric_value) instead" + ) + metric_stream = getStream(metric_group) + metric_stream.add_value(metric_name, metric_value) + + +def get_elapsed_time_ms(start_time_in_seconds: float): + """Return the elapsed time in millis from the given start time.""" + end_time = time.time() + return int((end_time - start_time_in_seconds) * 1000) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e53d242eb3635195e5269a5cb1b0a45517025e2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-311.pyc 
differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3e9a9df36146b553cde0010c58e28da0f3f0cc7d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py @@ -0,0 +1,375 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Each host in a distributed PyTorch job runs with a single TorchElastic agent, +and multiple workers (as children processes of the TorchElastic agent). +Since the workers are user-provided (your PyTorch script/job), TorchElastic +has a way to propagate errors on the trainers through the agent and up to the +scheduler, which ultimately informs the end-user about the state of the job +and applies any retry policies. + +TorchElastic categorizes errors into 3 categories: + ++----------------+----------------+--------------------------------------------------------------+ +| Category | Sub-Category | Description | ++================+================+==============================================================+ +| User Error | Input Error | invalid inputs to TorchElastic APIs (e.g. 
min > max nodes) | +| +----------------+--------------------------------------------------------------+ +| | Worker Failure | any failures on the worker child process | ++----------------+----------------+--------------------------------------------------------------+ +| Platform Error | n/a | failures caused by the agent | ++----------------+----------------+--------------------------------------------------------------+ +| Infra Error | n/a | failures outside the domain of the agent and workers | +| | | (e.g. host failures) | ++----------------+----------------+--------------------------------------------------------------+ + +All errors other than "Worker Failure" are either raised canonically from the +agent process or implicitly or explicitly crash the agent process. So the +standard language (python) provided exception handling strategies apply. + +Worker Failures are special because the exception/failure originates on a different +process from the agent so the error needs to be propagated inter-process +(e.g. the agent cannot simply ``try-catch`` an exception raised on the worker process). + +TorchElastic agents use :func:`torch.distributed.elastic.multiprocessing.start_processes` +to launch the workers which has a simple file based inter-process error propagation +built-in. + +Any function or binary entrypoint decorated with :func:`record` +will write uncaught exceptions (with the trace information) to a file specified by the +environment variable ``TORCHELASTIC_ERROR_FILE``. The parent process (e.g. agent) +sets this env var on each child it launches, then aggregates the error files for all +children, and propagates the one with the **smallest** timestamp (e.g. the **first** error). 
+""" + +import json +import os +import signal +import socket +import time +import warnings +from dataclasses import dataclass, field +from datetime import datetime +from functools import wraps +from string import Template +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar + +from torch.distributed.elastic.utils.logging import get_logger + +from .error_handler import ErrorHandler # noqa: F401 +from .handlers import get_error_handler # noqa: F401 + +__all__ = ["ProcessFailure", "ChildFailedError", "record", "ErrorHandler", "get_error_handler"] + +log = get_logger(__name__) + + +JSON = Dict + +_EMPTY_ERROR_DATA = {"message": ""} +_NOT_AVAILABLE = "" + +T = TypeVar("T") + + +@dataclass +class ProcessFailure: + """ + Represent the failed process result. When the worker process fails, it may record failure root cause into the file. + + Tries to read the failure timestamp from the provided ``error_file``, + if the ``error_file`` does not exist, the timestamp is the current + timestamp (seconds since epoch). + + The ``message`` field is a concise explanation of the failure. If + the error file exists then the message is obtained from the error file. + Otherwise one is generated based on the failure signature. + + .. note:: It is assumed that the ``error_file`` is written by + ``torch.distributed.elastic.multiprocessing.errors.error_handler.ErrorHandler``. + Otherwise the behavior is undefined. 
+ + """ + + local_rank: int + pid: int + exitcode: int + error_file: str + error_file_data: JSON = field(init=False) + message: str = field(init=False) + timestamp: int = field(init=False) + + def __post_init__(self): + self.error_file_data = _EMPTY_ERROR_DATA + if os.path.isfile(self.error_file): + try: + with open(self.error_file) as fp: + self.error_file_data = json.load(fp) + log.debug( + "User process failed with error data: %s", json.dumps(self.error_file_data, indent=2) + ) + self.message, self.timestamp = self._get_error_data( + self.error_file_data + ) + except Exception: + log.exception("Failed to parse reply file: %s", self.error_file) + raise + else: + self._set_no_reply_file() + + # make up an informative message if not already present + if not self.message: + # signals typically do not generate an error file message + if self.exitcode < 0: + self.message = ( + f"Signal {-self.exitcode} ({self.signal_name()})" + f" received by PID {self.pid}" + ) + else: + self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html" + + def _get_error_data(self, error_file_data: Dict[str, Any]) -> Tuple[str, int]: + message = error_file_data["message"] + if isinstance(message, str): + timestamp = int(error_file_data.get("timestamp", 0)) + else: + timestamp = int(message["extraInfo"]["timestamp"]) + return (message, timestamp) + + def _set_no_reply_file(self): + self.error_file = _NOT_AVAILABLE + self.error_file_data = _EMPTY_ERROR_DATA + self.message = "" + self.timestamp = int(time.time()) + + def signal_name(self) -> str: + if self.exitcode < 0: + # We don't want to kill the parent process trying to find the signal name. + # if the signal doesn't map to a known name, use not available. 
+ try: + return signal.Signals(-self.exitcode).name + except Exception: + return _NOT_AVAILABLE + else: + return _NOT_AVAILABLE + + def timestamp_isoformat(self): + """Return timestamp in ISO format (YYYY-MM-DD_HH:MM:SS).""" + return datetime.fromtimestamp(self.timestamp).isoformat(sep="_") + + +GlobalRank = int + +_FAILURE_FORMAT_TEMPLATE = """[${idx}]: + time : ${time} + host : ${hostname} + rank : ${rank} (local_rank: ${local_rank}) + exitcode : ${exitcode} (pid: ${pid}) + error_file: ${error_file} + traceback : ${message}""" + +# extra new lines before and after are intentional +_MSG_FORMAT_TEMPLATE = """ +${boarder} +${title} +${section} +Failures: +${other_failures} +${section} +Root Cause (first observed failure): +${root_failure} +${boarder}""" + + +class ChildFailedError(Exception): + """ + Special exception type that can be raised from a function annotated with the + ``@record`` decorator to have the child process' (root exception) propagate + up the stack as-is (e.g. without being wrapped in the parent's traceback). + + Useful in cases where the parent is a simple nanny process + and the child (worker) processes are actually doing meaningful compute. + In this case, errors typically occur on the child process as the parent + is not doing anything non-trivial, and child errors should be propagated + to the scheduler for accurate root cause diagnostics. + + .. note:: The propagation relies on error files rather than exception handling to + support both function and binary launches. + + Example: + :: + + # process tree on a host (container) + 0: scheduler-init-process: + |- 1: torchelastic_agent: + |- 2: trainer_0 (ok) + |- 3: trainer_1 (fail) -> error.json + |- ... + |- n+2: trainer_n (ok) + |- n+3: other processes + |- ... + + In the example above, trainer 1's failure (written into error.json) is + the root cause and should be reported to the scheduler's init process. 
+ The torchelastic agent raises a ``ChildFailedError("trainer", {1: "trainer_1/error.json"})`` + upon detecting trainer 1's failure which would propagate the contents + of trainer 1's error file to the scheduler's init process. + """ + + def __init__(self, name: str, failures: Dict[GlobalRank, ProcessFailure]): + self.name = name + self.failures = failures + assert ( + self.failures + ) # does not make sense to create a ChildFaileError with no failures + super().__init__(self.format_msg()) + + def get_first_failure(self) -> Tuple[GlobalRank, ProcessFailure]: + rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp) + return rank, self.failures[rank] + + def format_msg(self, boarder_delim="=", section_delim="-"): + title = f"{self.name} FAILED" + root_rank, root_failure = self.get_first_failure() + + root_failure_fmt: str = "" + other_failures_fmt: List[str] = [] + width = len(title) + for idx, (rank, failure) in enumerate(self.failures.items()): + fmt, w = self._format_failure(idx, rank, failure) + width = max(width, w) + if rank == root_rank: + root_failure_fmt = fmt + else: + other_failures_fmt.append(fmt) + + # upper boundary on width + width = min(width, 60) + + return Template(_MSG_FORMAT_TEMPLATE).substitute( + boarder=boarder_delim * width, + title=title, + section=section_delim * width, + root_failure=root_failure_fmt, + other_failures="\n".join(other_failures_fmt or [" "]), + ) + + def _format_failure( + self, idx: int, rank: int, failure: ProcessFailure + ) -> Tuple[str, int]: + + # failure.message is either a str (when the failure does not generate a traceback - e.g. signals) + # or a dict (json) of the form + # {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}} + # so the display logic is: + # 1. if failure.message is not a dict (it is a str) just show it as is + # 2. else try to get the traceback (py_callstack) + # 3. if the traceback is not there, use the message + # 4. 
if the message is not there show + msg = failure.message + if isinstance(failure.message, dict): + msg = ( + failure.message.get("extraInfo", {}) + .get("py_callstack", failure.message.get("message", "")) + .replace("\n", "\n ") # to properly indent the traceback + ) + + fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute( + idx=idx, + time=failure.timestamp_isoformat(), + hostname=socket.getfqdn(), + rank=rank, + local_rank=failure.local_rank, + exitcode=failure.exitcode, + pid=failure.pid, + error_file=failure.error_file, + message=msg, + ) + width = 0 + for line in fmt.split("\n"): + width = max(width, len(line)) + return fmt, width + + +def record( + fn: Callable[..., T], error_handler: Optional[ErrorHandler] = None +) -> Callable[..., T]: + """ + Syntactic sugar to record errors/exceptions that happened in the decorated + function using the provided ``error_handler``. + + Using this decorator is equivalent to: + + :: + + error_handler = get_error_handler() + error_handler.initialize() + try: + foobar() + except ChildFailedError as e: + _, failure = e.get_first_failure() + error_handler.dump_error_file(failure.error_file, failure.exitcode) + raise + except Exception as e: + error_handler.record(e) + raise + + .. important:: use this decorator once per process at the top level method, + typically this is the main method. + + Example + + :: + + @record + def main(): + pass + + if __name__=="__main__": + main() + + """ + if not error_handler: + error_handler = get_error_handler() + + def wrap(f): + @wraps(f) + def wrapper(*args, **kwargs): + assert error_handler is not None # assertion for mypy type checker + error_handler.initialize() + try: + return f(*args, **kwargs) + except SystemExit as se: + # For run_path based entrypoints, SystemExit with code = 0 will never exit. 
+ # Handling it here by returning a value: + if se.code == 0: + return None + else: + raise + except ChildFailedError as e: + rank, failure = e.get_first_failure() + if failure.error_file != _NOT_AVAILABLE: + error_handler.dump_error_file(failure.error_file, failure.exitcode) + else: + log.info( + ( + "local_rank %s FAILED with no error file." + " Decorate your entrypoint fn with @record for traceback info." + " See: https://pytorch.org/docs/stable/elastic/errors.html", + rank + ) + ) + raise + except Exception as e: + error_handler.record_exception(e) + raise + + return wrapper + + return wrap(fn) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..3071aef1711785602265a4dec81405b382444132 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+# Multiprocessing error-reporting module + + +from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler + +__all__ = ['get_error_handler'] + +def get_error_handler(): + return ErrorHandler() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45313bcb09a82047f1cecc595ab6ea260c645d0e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..beaff3418692161b11ccdde5d07c9bf11efd33a8 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..8d4477452a200edb881ae3573ff63db6c9f67e65 --- /dev/null +++ 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict, Tuple + +from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( + SubprocessHandler, +) + +__all__ = ["get_subprocess_handler"] + + +def get_subprocess_handler( + entrypoint: str, + args: Tuple, + env: Dict[str, str], + stdout: str, + stderr: str, + local_rank_id: int, +): + return SubprocessHandler( + entrypoint=entrypoint, + args=args, + env=env, + stdout=stdout, + stderr=stderr, + local_rank_id=local_rank_id, + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea4b2a46c4231dcec6f2b99af677b6979083b4b7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Expiration timers are set up on the same process as the agent and +used from your script to deal with stuck workers. When you go into +a code-block that has the potential to get stuck you can acquire +an expiration timer, which instructs the timer server to kill the +process if it does not release the timer by the self-imposed expiration +deadline. 
+ +Usage:: + + import torchelastic.timer as timer + import torchelastic.agent.server as agent + + def main(): + start_method = "spawn" + message_queue = mp.get_context(start_method).Queue() + server = timer.LocalTimerServer(message, max_interval=0.01) + server.start() # non-blocking + + spec = WorkerSpec( + fn=trainer_func, + args=(message_queue,), + ...) + agent = agent.LocalElasticAgent(spec, start_method) + agent.run() + + def trainer_func(message_queue): + timer.configure(timer.LocalTimerClient(message_queue)) + with timer.expires(after=60): # 60 second expiry + # do some work + +In the example above if ``trainer_func`` takes more than 60 seconds to +complete, then the worker process is killed and the agent retries the worker group. +""" + +from .api import TimerClient, TimerRequest, TimerServer, configure, expires # noqa: F401 +from .local_timer import LocalTimerClient, LocalTimerServer # noqa: F401 +from .file_based_local_timer import FileTimerClient, FileTimerServer, FileTimerRequest # noqa: F401 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e3bcd376131591a9d8cb4643c8b027325d22e0d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/local_timer.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/local_timer.py new file mode 100644 index 0000000000000000000000000000000000000000..05f467c807a5bc61bb0a3c6853cd17243636e1cb --- /dev/null +++ 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/local_timer.py @@ -0,0 +1,125 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import logging +import multiprocessing as mp +import os +import signal +import time +from queue import Empty +from typing import Any, Dict, List, Set, Tuple + +from .api import RequestQueue, TimerClient, TimerRequest, TimerServer + +__all__ = ['LocalTimerClient', 'MultiprocessingRequestQueue', 'LocalTimerServer'] + +log = logging.getLogger(__name__) + +class LocalTimerClient(TimerClient): + """ + Client side of ``LocalTimerServer``. This client is meant to be used + on the same host that the ``LocalTimerServer`` is running on and uses + pid to uniquely identify a worker. This is particularly useful in situations + where one spawns a subprocess (trainer) per GPU on a host with multiple + GPU devices. 
+ """ + + def __init__(self, mp_queue): + super().__init__() + self._mp_queue = mp_queue + + def acquire(self, scope_id, expiration_time): + pid = os.getpid() + acquire_request = TimerRequest(pid, scope_id, expiration_time) + self._mp_queue.put(acquire_request) + + def release(self, scope_id): + pid = os.getpid() + release_request = TimerRequest(pid, scope_id, -1) + self._mp_queue.put(release_request) + + +class MultiprocessingRequestQueue(RequestQueue): + """ + A ``RequestQueue`` backed by python ``multiprocessing.Queue`` + """ + + def __init__(self, mp_queue: mp.Queue): + super().__init__() + self._mp_queue = mp_queue + + def size(self) -> int: + return self._mp_queue.qsize() + + def get(self, size, timeout: float) -> List[TimerRequest]: + requests = [] + wait = timeout + for _ in range(0, size): + start = time.time() + + try: + r = self._mp_queue.get(block=True, timeout=wait) + except Empty: + break + + requests.append(r) + wait = wait - (time.time() - start) + if wait <= 0: + break + + return requests + + +class LocalTimerServer(TimerServer): + """ + Server that works with ``LocalTimerClient``. Clients are expected to be + subprocesses to the parent process that is running this server. Each host + in the job is expected to start its own timer server locally and each + server instance manages timers for local workers (running on processes + on the same host). 
+ """ + + def __init__( + self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True + ): + super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon) + self._timers: Dict[Tuple[Any, str], TimerRequest] = {} + + def register_timers(self, timer_requests: List[TimerRequest]) -> None: + for request in timer_requests: + pid = request.worker_id + scope_id = request.scope_id + expiration_time = request.expiration_time + + # negative expiration is a proxy for a release call + if expiration_time < 0: + self._timers.pop((pid, scope_id), None) + else: + self._timers[(pid, scope_id)] = request + + def clear_timers(self, worker_ids: Set[int]) -> None: + for (pid, scope_id) in list(self._timers.keys()): + if pid in worker_ids: + self._timers.pop((pid, scope_id)) + + def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]: + # pid -> [timer_requests...] + expired_timers: Dict[Any, List[TimerRequest]] = {} + for request in self._timers.values(): + if request.expiration_time <= deadline: + expired_scopes = expired_timers.setdefault(request.worker_id, []) + expired_scopes.append(request) + return expired_timers + + def _reap_worker(self, worker_id: int) -> bool: + try: + os.kill(worker_id, signal.SIGKILL) + return True + except ProcessLookupError: + log.info("Process with pid=%s does not exist. 
Skipping", worker_id) + return True + except Exception: + log.exception("Error terminating pid=%s", worker_id) + return False diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3ed1b42cbe1582f3b974ccbb1befc90637ba18e0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/__init__.py @@ -0,0 +1,4 @@ +import torch +if torch.distributed.rpc.is_available(): + from .api.remote_module import RemoteModule +from .functional import * # noqa: F403 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/api/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87c124c62774bd5a7dd930652d411aff690fe7d5 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/jit/templates/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/jit/templates/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a22da527735321351e324895aacbaa4ba91549b1 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/pipeline/sync/__pycache__/copy.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/pipeline/sync/__pycache__/copy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7da201e04c0b666db263151621175debd04d3bc9 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/pipeline/sync/__pycache__/copy.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/pipeline/sync/__pycache__/worker.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/pipeline/sync/__pycache__/worker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..671dcb5ac00b326b794ed4c87bdb1ccec3c50805 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/pipeline/sync/__pycache__/worker.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..0f70bf013b695d470f5d952fdd186533614735e6 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36fcfc95699bb3f41461ee689e3338a06b9662f1 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2058d5d49349178770615d0cc2eb627a7c2e02d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bb904db8cdcd88de02f02c6c24388dbf8c644d4 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/api.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/api.py new file mode 100644 index 0000000000000000000000000000000000000000..a8e2e5bc1bfdfd29d61048147286d0256de4c9e0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/api.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from typing import Dict, Union + +import torch +import torch.distributed._tensor.random as random +import torch.nn as nn +from torch.distributed._tensor import ( + DeviceMesh, +) +from torch.distributed._tensor.random import ( + is_rng_supported_mesh, + TensorParallelRNGTracker, +) +from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim +from torch.distributed.tensor.parallel.style import ( + ParallelStyle, +) + + +__all__ = [ + "parallelize_module", +] + + +def parallelize_module( # type: ignore[return] + module: nn.Module, + device_mesh: DeviceMesh, + parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]], +) -> nn.Module: + """ + Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan. + + We parallelize module or sub_modules based on a parallelize_plan. The parallelize_plan contains + :class:`ParallelStyle`, which indicates how user wants the module or sub_module + to be parallelized. + + User can also specify different parallel style per module fully qualified name (FQN). + + Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`, if you have a 2-D or N-D :class:`DeviceMesh`, + slice the DeviceMesh to a 1-D sub DeviceMesh first then pass to this API(i.e. ``device_mesh[\"tp\"]``) + + Args: + module (:class:`nn.Module`): + Module to be parallelized. + device_mesh (:class:`DeviceMesh`): + Object which describes the mesh topology + of devices for the DTensor. 
+ parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]): + The plan used to parallelize the module. It can be either a + :class:`ParallelStyle` object which contains how + we prepare input/output for Tensor Parallelism or it can be a + dict of module FQN and its corresponding :class:`ParallelStyle` object. + Return: + A :class:`nn.Module` object parallelized. + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel + >>> from torch.distributed.device_mesh import init_device_mesh + >>> + >>> # Define the module. + >>> m = Model(...) + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()}) + >>> + + .. note:: For complex module architecture like Attention, MLP layers, we recommend composing + different ParallelStyles together (i.e. ``ColwiseParallel`` and ``RowwiseParallel``) and pass + as a parallelize_plan, to achieves the desired sharding computation. + """ + torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module") + + _validate_tp_mesh_dim(device_mesh) + + # instantiate a TP RNG state tracker if it's not there + if is_rng_supported_mesh(device_mesh) and not isinstance( + random._rng_tracker, TensorParallelRNGTracker + ): + random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type) + # TODO: we should allow user to pass in the default seed from a config + random._rng_tracker._manual_seed(device_mesh, base_seed=1234) + # By default we execute random ops in non-tensor-parallel region. If users want + # to execute in tensor-parallel region, they can manually set this field to True + # after parallelizing the model. 
+ random._rng_tracker.distribute_region_enabled = False + + if isinstance(parallelize_plan, ParallelStyle): + return parallelize_plan._apply(module, device_mesh) + elif isinstance(parallelize_plan, dict): + for module_path, parallelize_style in parallelize_plan.items(): + sub_module = module.get_submodule(module_path) + parent_module = module + if "." in module_path: + parent_module_path = ".".join(module_path.split(".")[:-1]) + parent_module = module.get_submodule(parent_module_path) + module_path = module_path.split(".")[-1] + parent_module.register_module( # type: ignore[call-arg] # pyre-ignore[20] + module_path, + parallelize_module( # type: ignore[arg-type] + sub_module, device_mesh, parallelize_style # type: ignore[arg-type] # pyre-ignore[6] + ), + ) + return module + else: + raise RuntimeError( # pyre-ignore[7] + "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for" + f" parallelize_plan, {type(parallelize_plan)} found!" + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/dirichlet.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/dirichlet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6abe658b4b45eded053431e2d3aa06b89b5da49b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/dirichlet.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/exp_family.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/exp_family.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..751ebb810a7e57a061848ba4e039c5a168a8c2b9 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/exp_family.cpython-311.pyc 
differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/geometric.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/geometric.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce74f53da7b823607faee8e57e9bbcf640425948 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/geometric.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/independent.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/independent.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8751a021c45c677ab17d6278d3f9cb0ca1431c40 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/independent.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba1e675241436019a83ea096bae4a3425d81f4d1 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..c7a01162f3b1836552723f7d34850773777c9be7 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/uniform.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/uniform.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7746f1fd887b9ce378e574b161eb7943cb3efd65 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/uniform.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/weibull.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/weibull.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ce43d58b084299aeead53f39bfd9fe1fe991b65 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/weibull.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/wishart.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/wishart.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..102b65163f7ab5045c7d3f07c5363e882bd8fab9 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/__pycache__/wishart.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/binomial.py 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/binomial.py new file mode 100644 index 0000000000000000000000000000000000000000..9243da7b6bf4ccb503626ef02c1644c84961a716 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/binomial.py @@ -0,0 +1,165 @@ +import torch +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) + +__all__ = ["Binomial"] + + +def _clamp_by_zero(x): + # works like clamp(x, min=0) but has grad at 0 is 0.5 + return (x.clamp(min=0) + x - x.clamp(max=0)) / 2 + + +class Binomial(Distribution): + r""" + Creates a Binomial distribution parameterized by :attr:`total_count` and + either :attr:`probs` or :attr:`logits` (but not both). :attr:`total_count` must be + broadcastable with :attr:`probs`/:attr:`logits`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Binomial(100, torch.tensor([0 , .2, .8, 1])) + >>> x = m.sample() + tensor([ 0., 22., 71., 100.]) + + >>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8])) + >>> x = m.sample() + tensor([[ 4., 5.], + [ 7., 6.]]) + + Args: + total_count (int or Tensor): number of Bernoulli trials + probs (Tensor): Event probabilities + logits (Tensor): Event log-odds + """ + arg_constraints = { + "total_count": constraints.nonnegative_integer, + "probs": constraints.unit_interval, + "logits": constraints.real, + } + has_enumerate_support = True + + def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." 
+ ) + if probs is not None: + ( + self.total_count, + self.probs, + ) = broadcast_all(total_count, probs) + self.total_count = self.total_count.type_as(self.probs) + else: + ( + self.total_count, + self.logits, + ) = broadcast_all(total_count, logits) + self.total_count = self.total_count.type_as(self.logits) + + self._param = self.probs if probs is not None else self.logits + batch_shape = self._param.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Binomial, _instance) + batch_shape = torch.Size(batch_shape) + new.total_count = self.total_count.expand(batch_shape) + if "probs" in self.__dict__: + new.probs = self.probs.expand(batch_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.expand(batch_shape) + new._param = new.logits + super(Binomial, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + @constraints.dependent_property(is_discrete=True, event_dim=0) + def support(self): + return constraints.integer_interval(0, self.total_count) + + @property + def mean(self): + return self.total_count * self.probs + + @property + def mode(self): + return ((self.total_count + 1) * self.probs).floor().clamp(max=self.total_count) + + @property + def variance(self): + return self.total_count * self.probs * (1 - self.probs) + + @lazy_property + def logits(self): + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits, is_binary=True) + + @property + def param_shape(self): + return self._param.size() + + def sample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + with torch.no_grad(): + return torch.binomial( + self.total_count.expand(shape), self.probs.expand(shape) + ) + + def log_prob(self, value): 
+ if self._validate_args: + self._validate_sample(value) + log_factorial_n = torch.lgamma(self.total_count + 1) + log_factorial_k = torch.lgamma(value + 1) + log_factorial_nmk = torch.lgamma(self.total_count - value + 1) + # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p) + # (case logit < 0) = k * logit - n * log1p(e^logit) + # (case logit > 0) = k * logit - n * (log(p) - log(1 - p)) + n * log(p) + # = k * logit - n * logit - n * log1p(e^-logit) + # (merge two cases) = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|) + normalize_term = ( + self.total_count * _clamp_by_zero(self.logits) + + self.total_count * torch.log1p(torch.exp(-torch.abs(self.logits))) + - log_factorial_n + ) + return ( + value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term + ) + + def entropy(self): + total_count = int(self.total_count.max()) + if not self.total_count.min() == total_count: + raise NotImplementedError( + "Inhomogeneous total count not supported by `entropy`." + ) + + log_prob = self.log_prob(self.enumerate_support(False)) + return -(torch.exp(log_prob) * log_prob).sum(0) + + def enumerate_support(self, expand=True): + total_count = int(self.total_count.max()) + if not self.total_count.min() == total_count: + raise NotImplementedError( + "Inhomogeneous total count not supported by `enumerate_support`." 
+ ) + values = torch.arange( + 1 + total_count, dtype=self._param.dtype, device=self._param.device + ) + values = values.view((-1,) + (1,) * len(self._batch_shape)) + if expand: + values = values.expand((-1,) + self._batch_shape) + return values diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/chi2.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/chi2.py new file mode 100644 index 0000000000000000000000000000000000000000..16d0d6d60fbeb93544d21127c57f4bebcfb2bd74 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/chi2.py @@ -0,0 +1,33 @@ +from torch.distributions import constraints +from torch.distributions.gamma import Gamma + +__all__ = ["Chi2"] + + +class Chi2(Gamma): + r""" + Creates a Chi-squared distribution parameterized by shape parameter :attr:`df`. + This is exactly equivalent to ``Gamma(alpha=0.5*df, beta=0.5)`` + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Chi2(torch.tensor([1.0])) + >>> m.sample() # Chi2 distributed with shape df=1 + tensor([ 0.1046]) + + Args: + df (float or Tensor): shape parameter of the distribution + """ + arg_constraints = {"df": constraints.positive} + + def __init__(self, df, validate_args=None): + super().__init__(0.5 * df, 0.5, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Chi2, _instance) + return super().expand(batch_shape, new) + + @property + def df(self): + return self.concentration * 2 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/exponential.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/exponential.py new file mode 100644 index 0000000000000000000000000000000000000000..020b5215bbdb4ba6c216e3ddc70eca238df0ed44 --- /dev/null +++ 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/exponential.py @@ -0,0 +1,84 @@ +from numbers import Number + +import torch +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import broadcast_all + +__all__ = ["Exponential"] + + +class Exponential(ExponentialFamily): + r""" + Creates a Exponential distribution parameterized by :attr:`rate`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Exponential(torch.tensor([1.0])) + >>> m.sample() # Exponential distributed with rate=1 + tensor([ 0.1046]) + + Args: + rate (float or Tensor): rate = 1 / scale of the distribution + """ + arg_constraints = {"rate": constraints.positive} + support = constraints.nonnegative + has_rsample = True + _mean_carrier_measure = 0 + + @property + def mean(self): + return self.rate.reciprocal() + + @property + def mode(self): + return torch.zeros_like(self.rate) + + @property + def stddev(self): + return self.rate.reciprocal() + + @property + def variance(self): + return self.rate.pow(-2) + + def __init__(self, rate, validate_args=None): + (self.rate,) = broadcast_all(rate) + batch_shape = torch.Size() if isinstance(rate, Number) else self.rate.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Exponential, _instance) + batch_shape = torch.Size(batch_shape) + new.rate = self.rate.expand(batch_shape) + super(Exponential, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + return self.rate.new(shape).exponential_() / self.rate + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return self.rate.log() - self.rate * value + + def cdf(self, value): + if 
self._validate_args: + self._validate_sample(value) + return 1 - torch.exp(-self.rate * value) + + def icdf(self, value): + return -torch.log1p(-value) / self.rate + + def entropy(self): + return 1.0 - torch.log(self.rate) + + @property + def _natural_params(self): + return (-self.rate,) + + def _log_normalizer(self, x): + return -torch.log(-x) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/geometric.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/geometric.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf2f3dbacc67f3a4c8be53d2045523f9f9ec113 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/geometric.py @@ -0,0 +1,128 @@ +from numbers import Number + +import torch +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) +from torch.nn.functional import binary_cross_entropy_with_logits + +__all__ = ["Geometric"] + + +class Geometric(Distribution): + r""" + Creates a Geometric distribution parameterized by :attr:`probs`, + where :attr:`probs` is the probability of success of Bernoulli trials. + + .. math:: + + P(X=k) = (1-p)^{k} p, k = 0, 1, ... + + .. note:: + :func:`torch.distributions.geometric.Geometric` :math:`(k+1)`-th trial is the first success + hence draws samples in :math:`\{0, 1, \ldots\}`, whereas + :func:`torch.Tensor.geometric_` `k`-th trial is the first success hence draws samples in :math:`\{1, 2, \ldots\}`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Geometric(torch.tensor([0.3])) + >>> m.sample() # underlying Bernoulli has 30% chance 1; 70% chance 0 + tensor([ 2.]) + + Args: + probs (Number, Tensor): the probability of sampling `1`. 
Must be in range (0, 1] + logits (Number, Tensor): the log-odds of sampling `1`. + """ + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} + support = constraints.nonnegative_integer + + def __init__(self, probs=None, logits=None, validate_args=None): + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + (self.probs,) = broadcast_all(probs) + else: + (self.logits,) = broadcast_all(logits) + probs_or_logits = probs if probs is not None else logits + if isinstance(probs_or_logits, Number): + batch_shape = torch.Size() + else: + batch_shape = probs_or_logits.size() + super().__init__(batch_shape, validate_args=validate_args) + if self._validate_args and probs is not None: + # Add an extra check beyond unit_interval + value = self.probs + valid = value > 0 + if not valid.all(): + invalid_value = value.data[~valid] + raise ValueError( + "Expected parameter probs " + f"({type(value).__name__} of shape {tuple(value.shape)}) " + f"of distribution {repr(self)} " + f"to be positive but found invalid values:\n{invalid_value}" + ) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Geometric, _instance) + batch_shape = torch.Size(batch_shape) + if "probs" in self.__dict__: + new.probs = self.probs.expand(batch_shape) + if "logits" in self.__dict__: + new.logits = self.logits.expand(batch_shape) + super(Geometric, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + @property + def mean(self): + return 1.0 / self.probs - 1.0 + + @property + def mode(self): + return torch.zeros_like(self.probs) + + @property + def variance(self): + return (1.0 / self.probs - 1.0) / self.probs + + @lazy_property + def logits(self): + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits, is_binary=True) + 
+ def sample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + tiny = torch.finfo(self.probs.dtype).tiny + with torch.no_grad(): + if torch._C._get_tracing_state(): + # [JIT WORKAROUND] lack of support for .uniform_() + u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device) + u = u.clamp(min=tiny) + else: + u = self.probs.new(shape).uniform_(tiny, 1) + return (u.log() / (-self.probs).log1p()).floor() + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + value, probs = broadcast_all(value, self.probs) + probs = probs.clone(memory_format=torch.contiguous_format) + probs[(probs == 1) & (value == 0)] = 0 + return value * (-probs).log1p() + self.probs.log() + + def entropy(self): + return ( + binary_cross_entropy_with_logits(self.logits, self.probs, reduction="none") + / self.probs + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/laplace.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/laplace.py new file mode 100644 index 0000000000000000000000000000000000000000..7b830cc76f9b43149105d4c6d75560560117a18f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/laplace.py @@ -0,0 +1,94 @@ +from numbers import Number + +import torch +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import broadcast_all + +__all__ = ["Laplace"] + + +class Laplace(Distribution): + r""" + Creates a Laplace distribution parameterized by :attr:`loc` and :attr:`scale`. 
+ + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0])) + >>> m.sample() # Laplace distributed with loc=0, scale=1 + tensor([ 0.1046]) + + Args: + loc (float or Tensor): mean of the distribution + scale (float or Tensor): scale of the distribution + """ + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.real + has_rsample = True + + @property + def mean(self): + return self.loc + + @property + def mode(self): + return self.loc + + @property + def variance(self): + return 2 * self.scale.pow(2) + + @property + def stddev(self): + return (2**0.5) * self.scale + + def __init__(self, loc, scale, validate_args=None): + self.loc, self.scale = broadcast_all(loc, scale) + if isinstance(loc, Number) and isinstance(scale, Number): + batch_shape = torch.Size() + else: + batch_shape = self.loc.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Laplace, _instance) + batch_shape = torch.Size(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + super(Laplace, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + finfo = torch.finfo(self.loc.dtype) + if torch._C._get_tracing_state(): + # [JIT WORKAROUND] lack of support for .uniform_() + u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device) * 2 - 1 + return self.loc - self.scale * u.sign() * torch.log1p( + -u.abs().clamp(min=finfo.tiny) + ) + u = self.loc.new(shape).uniform_(finfo.eps - 1, 1) + # TODO: If we ever implement tensor.nextafter, below is what we want ideally. 
+ # u = self.loc.new(shape).uniform_(self.loc.nextafter(-.5, 0), .5) + return self.loc - self.scale * u.sign() * torch.log1p(-u.abs()) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return -torch.log(2 * self.scale) - torch.abs(value - self.loc) / self.scale + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return 0.5 - 0.5 * (value - self.loc).sign() * torch.expm1( + -(value - self.loc).abs() / self.scale + ) + + def icdf(self, value): + term = value - 0.5 + return self.loc - self.scale * (term).sign() * torch.log1p(-2 * term.abs()) + + def entropy(self): + return 1 + torch.log(2 * self.scale) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/log_normal.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/log_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..f6694cf9507f1b22e76fa3700e4f64240ac2ba99 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/log_normal.py @@ -0,0 +1,62 @@ +from torch.distributions import constraints +from torch.distributions.normal import Normal +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import ExpTransform + +__all__ = ["LogNormal"] + + +class LogNormal(TransformedDistribution): + r""" + Creates a log-normal distribution parameterized by + :attr:`loc` and :attr:`scale` where:: + + X ~ Normal(loc, scale) + Y = exp(X) ~ LogNormal(loc, scale) + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0])) + >>> m.sample() # log-normal distributed with mean=0 and stddev=1 + tensor([ 0.1046]) + + Args: + loc (float or Tensor): mean of log of distribution + scale (float or Tensor): standard deviation of log of the distribution + """ + arg_constraints = {"loc": 
constraints.real, "scale": constraints.positive} + support = constraints.positive + has_rsample = True + + def __init__(self, loc, scale, validate_args=None): + base_dist = Normal(loc, scale, validate_args=validate_args) + super().__init__(base_dist, ExpTransform(), validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(LogNormal, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def loc(self): + return self.base_dist.loc + + @property + def scale(self): + return self.base_dist.scale + + @property + def mean(self): + return (self.loc + self.scale.pow(2) / 2).exp() + + @property + def mode(self): + return (self.loc - self.scale.square()).exp() + + @property + def variance(self): + scale_sq = self.scale.pow(2) + return scale_sq.expm1() * (2 * self.loc + scale_sq).exp() + + def entropy(self): + return self.base_dist.entropy() + self.loc diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/multivariate_normal.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/multivariate_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..2784eeb214d5c59e1e3aa3ac21c7059f032f16de --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/multivariate_normal.py @@ -0,0 +1,262 @@ +import math + +import torch +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import _standard_normal, lazy_property + +__all__ = ["MultivariateNormal"] + + +def _batch_mv(bmat, bvec): + r""" + Performs a batched matrix-vector product, with compatible but different batch shapes. + + This function takes as input `bmat`, containing :math:`n \times n` matrices, and + `bvec`, containing length :math:`n` vectors. 
+ + Both `bmat` and `bvec` may have any number of leading dimensions, which correspond + to a batch shape. They are not necessarily assumed to have the same batch shape, + just ones which can be broadcasted. + """ + return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1) + + +def _batch_mahalanobis(bL, bx): + r""" + Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}` + for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`. + + Accepts batches for both bL and bx. They are not necessarily assumed to have the same batch + shape, but `bL` one should be able to broadcasted to `bx` one. + """ + n = bx.size(-1) + bx_batch_shape = bx.shape[:-1] + + # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n), + # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve + bx_batch_dims = len(bx_batch_shape) + bL_batch_dims = bL.dim() - 2 + outer_batch_dims = bx_batch_dims - bL_batch_dims + old_batch_dims = outer_batch_dims + bL_batch_dims + new_batch_dims = outer_batch_dims + 2 * bL_batch_dims + # Reshape bx with the shape (..., 1, i, j, 1, n) + bx_new_shape = bx.shape[:outer_batch_dims] + for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]): + bx_new_shape += (sx // sL, sL) + bx_new_shape += (n,) + bx = bx.reshape(bx_new_shape) + # Permute bx to make it have shape (..., 1, j, i, 1, n) + permute_dims = ( + list(range(outer_batch_dims)) + + list(range(outer_batch_dims, new_batch_dims, 2)) + + list(range(outer_batch_dims + 1, new_batch_dims, 2)) + + [new_batch_dims] + ) + bx = bx.permute(permute_dims) + + flat_L = bL.reshape(-1, n, n) # shape = b x n x n + flat_x = bx.reshape(-1, flat_L.size(0), n) # shape = c x b x n + flat_x_swap = flat_x.permute(1, 2, 0) # shape = b x n x c + M_swap = ( + torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2) + ) # shape = b x c + M = M_swap.t() # shape = c x b + + # Now we revert the above reshape and permute operators. 
+ permuted_M = M.reshape(bx.shape[:-1]) # shape = (..., 1, j, i, 1) + permute_inv_dims = list(range(outer_batch_dims)) + for i in range(bL_batch_dims): + permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i] + reshaped_M = permuted_M.permute(permute_inv_dims) # shape = (..., 1, i, j, 1) + return reshaped_M.reshape(bx_batch_shape) + + +def _precision_to_scale_tril(P): + # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril + Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1))) + L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1) + Id = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device) + L = torch.linalg.solve_triangular(L_inv, Id, upper=False) + return L + + +class MultivariateNormal(Distribution): + r""" + Creates a multivariate normal (also called Gaussian) distribution + parameterized by a mean vector and a covariance matrix. + + The multivariate normal distribution can be parameterized either + in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}` + or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}` + or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued + diagonal entries, such that + :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix + can be obtained via e.g. Cholesky decomposition of the covariance. 
+ + Example: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2)) + >>> m.sample() # normally distributed with mean=`[0,0]` and covariance_matrix=`I` + tensor([-0.2102, -0.5429]) + + Args: + loc (Tensor): mean of the distribution + covariance_matrix (Tensor): positive-definite covariance matrix + precision_matrix (Tensor): positive-definite precision matrix + scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal + + Note: + Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or + :attr:`scale_tril` can be specified. + + Using :attr:`scale_tril` will be more efficient: all computations internally + are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or + :attr:`precision_matrix` is passed instead, it is only used to compute + the corresponding lower triangular matrices using a Cholesky decomposition. + """ + arg_constraints = { + "loc": constraints.real_vector, + "covariance_matrix": constraints.positive_definite, + "precision_matrix": constraints.positive_definite, + "scale_tril": constraints.lower_cholesky, + } + support = constraints.real_vector + has_rsample = True + + def __init__( + self, + loc, + covariance_matrix=None, + precision_matrix=None, + scale_tril=None, + validate_args=None, + ): + if loc.dim() < 1: + raise ValueError("loc must be at least one-dimensional.") + if (covariance_matrix is not None) + (scale_tril is not None) + ( + precision_matrix is not None + ) != 1: + raise ValueError( + "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified." 
+ ) + + if scale_tril is not None: + if scale_tril.dim() < 2: + raise ValueError( + "scale_tril matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1]) + self.scale_tril = scale_tril.expand(batch_shape + (-1, -1)) + elif covariance_matrix is not None: + if covariance_matrix.dim() < 2: + raise ValueError( + "covariance_matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + batch_shape = torch.broadcast_shapes( + covariance_matrix.shape[:-2], loc.shape[:-1] + ) + self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1)) + else: + if precision_matrix.dim() < 2: + raise ValueError( + "precision_matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + batch_shape = torch.broadcast_shapes( + precision_matrix.shape[:-2], loc.shape[:-1] + ) + self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1)) + self.loc = loc.expand(batch_shape + (-1,)) + + event_shape = self.loc.shape[-1:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + if scale_tril is not None: + self._unbroadcasted_scale_tril = scale_tril + elif covariance_matrix is not None: + self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix) + else: # precision_matrix is not None + self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(MultivariateNormal, _instance) + batch_shape = torch.Size(batch_shape) + loc_shape = batch_shape + self.event_shape + cov_shape = batch_shape + self.event_shape + self.event_shape + new.loc = self.loc.expand(loc_shape) + new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril + if "covariance_matrix" in self.__dict__: + new.covariance_matrix = self.covariance_matrix.expand(cov_shape) + if "scale_tril" in self.__dict__: + 
new.scale_tril = self.scale_tril.expand(cov_shape) + if "precision_matrix" in self.__dict__: + new.precision_matrix = self.precision_matrix.expand(cov_shape) + super(MultivariateNormal, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + @lazy_property + def scale_tril(self): + return self._unbroadcasted_scale_tril.expand( + self._batch_shape + self._event_shape + self._event_shape + ) + + @lazy_property + def covariance_matrix(self): + return torch.matmul( + self._unbroadcasted_scale_tril, self._unbroadcasted_scale_tril.mT + ).expand(self._batch_shape + self._event_shape + self._event_shape) + + @lazy_property + def precision_matrix(self): + return torch.cholesky_inverse(self._unbroadcasted_scale_tril).expand( + self._batch_shape + self._event_shape + self._event_shape + ) + + @property + def mean(self): + return self.loc + + @property + def mode(self): + return self.loc + + @property + def variance(self): + return ( + self._unbroadcasted_scale_tril.pow(2) + .sum(-1) + .expand(self._batch_shape + self._event_shape) + ) + + def rsample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) + return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + diff = value - self.loc + M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff) + half_log_det = ( + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + ) + return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det + + def entropy(self): + half_log_det = ( + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + ) + H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det + if len(self._batch_shape) == 0: + return H + else: + return 
H.expand(self._batch_shape) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/relaxed_categorical.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/relaxed_categorical.py new file mode 100644 index 0000000000000000000000000000000000000000..245ab87aa2a75291d4d74d4845720b0bfa8fe935 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributions/relaxed_categorical.py @@ -0,0 +1,139 @@ +import torch +from torch.distributions import constraints +from torch.distributions.categorical import Categorical +from torch.distributions.distribution import Distribution +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import ExpTransform +from torch.distributions.utils import broadcast_all, clamp_probs + +__all__ = ["ExpRelaxedCategorical", "RelaxedOneHotCategorical"] + + +class ExpRelaxedCategorical(Distribution): + r""" + Creates a ExpRelaxedCategorical parameterized by + :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both). + Returns the log of a point in the simplex. Based on the interface to + :class:`OneHotCategorical`. + + Implementation based on [1]. + + See also: :func:`torch.distributions.OneHotCategorical` + + Args: + temperature (Tensor): relaxation temperature + probs (Tensor): event probabilities + logits (Tensor): unnormalized log probability for each event + + [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables + (Maddison et al, 2017) + + [2] Categorical Reparametrization with Gumbel-Softmax + (Jang et al, 2017) + """ + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + support = ( + constraints.real_vector + ) # The true support is actually a submanifold of this. 
+ has_rsample = True + + def __init__(self, temperature, probs=None, logits=None, validate_args=None): + self._categorical = Categorical(probs, logits) + self.temperature = temperature + batch_shape = self._categorical.batch_shape + event_shape = self._categorical.param_shape[-1:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(ExpRelaxedCategorical, _instance) + batch_shape = torch.Size(batch_shape) + new.temperature = self.temperature + new._categorical = self._categorical.expand(batch_shape) + super(ExpRelaxedCategorical, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._categorical._new(*args, **kwargs) + + @property + def param_shape(self): + return self._categorical.param_shape + + @property + def logits(self): + return self._categorical.logits + + @property + def probs(self): + return self._categorical.probs + + def rsample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + uniforms = clamp_probs( + torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device) + ) + gumbels = -((-(uniforms.log())).log()) + scores = (self.logits + gumbels) / self.temperature + return scores - scores.logsumexp(dim=-1, keepdim=True) + + def log_prob(self, value): + K = self._categorical._num_events + if self._validate_args: + self._validate_sample(value) + logits, value = broadcast_all(self.logits, value) + log_scale = torch.full_like( + self.temperature, float(K) + ).lgamma() - self.temperature.log().mul(-(K - 1)) + score = logits - value.mul(self.temperature) + score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1) + return score + log_scale + + +class RelaxedOneHotCategorical(TransformedDistribution): + r""" + Creates a RelaxedOneHotCategorical distribution parametrized by + 
:attr:`temperature`, and either :attr:`probs` or :attr:`logits`. + This is a relaxed version of the :class:`OneHotCategorical` distribution, so + its samples are on simplex, and are reparametrizable. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]), + ... torch.tensor([0.1, 0.2, 0.3, 0.4])) + >>> m.sample() + tensor([ 0.1294, 0.2324, 0.3859, 0.2523]) + + Args: + temperature (Tensor): relaxation temperature + probs (Tensor): event probabilities + logits (Tensor): unnormalized log probability for each event + """ + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + support = constraints.simplex + has_rsample = True + + def __init__(self, temperature, probs=None, logits=None, validate_args=None): + base_dist = ExpRelaxedCategorical( + temperature, probs, logits, validate_args=validate_args + ) + super().__init__(base_dist, ExpTransform(), validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(RelaxedOneHotCategorical, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def temperature(self): + return self.base_dist.temperature + + @property + def logits(self): + return self.base_dist.logits + + @property + def probs(self): + return self.base_dist.probs diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/func/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/func/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a572ed3373dab3959e223747e2618d1a7ed1ed14 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/func/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/masked/__pycache__/__init__.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/masked/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cee8ed508438527f42f255fdfe1436227eecabbd Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/masked/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/masked/_ops.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/masked/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..2a2ff3fd6f857507df133598b057a8e33a65fad0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/masked/_ops.py @@ -0,0 +1,1796 @@ + +import warnings + +# A workaround to support both TorchScript and MyPy: +from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union + +import torch +from torch import Tensor +from torch.masked import as_masked_tensor, is_masked_tensor, MaskedTensor +from . import _docs +from torch._prims_common import corresponding_real_dtype +from torch import sym_float + +if TYPE_CHECKING: + from torch.types import _dtype as DType + + DimOrDims = Optional[Union[int, Tuple[int], List[int]]] +else: + # The JIT doesn't understand Union, nor torch.dtype here + DType = int + DimOrDims = Optional[Tuple[int]] + + +__all__: List[str] = [] + +# All masked reduction/normalization operations have the same +# signatures. Here we introduce docstring templates that are applied +# to docstrings of reduction/normalization functions via +# _apply_docstring_templates decorator. + + +def _apply_docstring_templates(func): + """Decorator that applies docstring templates to function docstring + and returns the function instance. + """ + + doc_string = getattr(_docs, f"{func.__name__}_docstring", None) + if doc_string is None: + warnings.warn( + f"No documentation string available for {func.__name__}." 
+ " PyTorch team should run `python tools/update_masked_docs.py`" + " to generate the missing docstrings." + ) + else: + func.__doc__ = doc_string + + # Expose function as public symbol + __all__.append(func.__name__) + + return func + + +def _generate_docstring(func): + """A utility function called from tools/update_masked_docs.py + script to update the module torch.masked._docs.py + """ + docstring_templates = dict( + reduction_signature="""\ +{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""", + reduction_descr="""\ +Returns {operation name} of all the elements in the :attr:`input` +tensor along the given dimension(s) :attr:`dim` while the :attr:`input` +elements are masked out according to the boolean tensor +:attr:`mask`.""", + reduction_args="""\ +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of +size 1. Otherwise, :attr:`dim` is squeezed (see +:func:`torch.squeeze`), resulting in the output tensor having 1 (or +``len(dim)``) fewer dimension(s). + +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True +then the corresponding element in :attr:`input` tensor will be +included in {operation name} computation, otherwise the element is +ignored. + +When all elements of :attr:`input` along the given dimension +:attr:`dim` are ignored (fully masked-out), the corresponding element +of the output tensor will have undefined value: it may or may not +correspond to the identity value of {operation name} operation; the +choice may correspond to the value that leads to the most efficient +storage of :attr:`output` tensor. + +The mask of the output tensor can be computed as +``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim, +dtype=torch.bool)``. 
+ +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + {args_declarations} + +Keyword args: + {kwargs_declarations}""", + reduction_example="""\ +Example:: + + >>> input = {example_input} + >>> input + {indent_example_input} + >>> mask = {example_mask} + >>> mask + {indent_example_mask} + >>> {full_function_name}(input, {example_args}, mask=mask) + {indent_example_output} +""", + reduction_identity="""\ +The identity value of {operation name} operation, which is used to start the reduction, is ``{identity_int32}``.""", + reduction_identity_dtype="""\ +The identity value of {operation name} operation, which is used to start the +reduction, depends on input dtype. For instance, for float32, uint8, +and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_uint8}``, and ``{identity_int32}``, respectively.""", + normalization_signature="""\ +{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""", + normalization_descr="""\ +Returns {operation name} of all the slices in the :attr:`input` tensor +along :attr:`dim` while the :attr:`input` elements are masked out +according to the boolean tensor :attr:`mask`. + +{definition}""", + normalization_args="""\ +The boolean tensor :attr:`mask` defines the "validity" of +:attr:`input` tensor elements: if :attr:`mask` element is True then +the corresponding element in :attr:`input` tensor will be included in +{operation name} computation, otherwise the element is ignored. + +The values of masked-out elements of the output tensor have undefined +value: it may or may not be set to zero or nan; the choice may correspond to +the value that leads to the most efficient storage of :attr:`output` +tensor. 
+ +The mask of the {operation name} output tensor can be computed as +``torch.broadcast_to(mask, input.shape)``. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the :attr:`mask` +tensor must not be greater than of the :attr:`input` tensor. + +Args: + input (Tensor): the input tensor + {args_declarations} + +Keyword args: + {kwargs_declarations}""", + normalization_example="""\ +Example:: + + >>> input = {example_input} + >>> input + {indent_example_input} + >>> mask = {example_mask} + >>> mask + {indent_example_mask} + >>> {full_function_name}(input, {example_args}, mask=mask) + {indent_example_output} +""", + ) + + args_and_kwargs = dict( + # argument name sufficies separated by double underscore will + # be removed in the final documentation string. + sum=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + prod=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + cumsum=(("dim__as_int",), ("dtype=None", "mask=None")), + cumprod=(("dim__as_int",), ("dtype=None", "mask=None")), + amin=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + amax=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + argmin=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), + argmax=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), + mean=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + median=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), + norm=( + ( + "ord", + "dim", + ), + ("keepdim=False", "dtype=None", "mask=None"), + ), + var=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")), + std=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")), + logsumexp=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + softmax=(("dim__as_int",), ("dtype=None", "mask=None")), + log_softmax=(("dim__as_int",), ("dtype=None", "mask=None")), + softmin=(("dim__as_int",), 
("dtype=None", "mask=None")), + normalize=( + ( + "ord__required", + "dim__as_int", + ), + ("eps=1e-12", "dtype=None", "mask=None"), + ), + ) + + argument_declarations = dict( + dim="""\ +dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``.""", + dim__as_int="""\ +dim (int): the dimension along which {operation name} is computed.""", + ord="""\ +ord (int, float, optional): the order of vector norm. Default: 2. + See :func:`torch.linalg.vector_norm` for a list of supported norms.""", + ord__required="""\ +ord (int, float): the order of vector norm. Default: 2. + See :func:`torch.linalg.vector_norm` for a list of supported norms.""", + unbiased="""\ +unbiased (bool): when True, use Bessel’s correction, otherwise, compute + the uncorrected sample variance.""", + eps="""\ +eps (float, optional): small value to avoid division by zero. Default: {default}.""", + keepdim="""\ +keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: {default}.""", + dtype="""\ +dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: {default}.""", + mask="""\ +mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""", + ) + + definitions = dict( + softmax="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Softmax of i-th element in ``x`` is +defined as ``exp(x[i])/sum(exp(x))``.""", + log_softmax="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. 
LogSoftmax of i-th element in ``x`` is +defined as ``log(exp(x[i])/sum(exp(x)))``.""", + softmin="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Softmin of i-th element in ``x`` is +defined as ``exp(-x[i])/sum(exp(-x))``.""", + normalize="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Normalize of i-th element in ``x`` is +defined as ``x[i]/max(norm(x, p), eps)``.""", + cumsum="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is +defined as ``sum(x[:i])``.""", + cumprod="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is +defined as ``prod(x[:i])``.""", + ) + + reduction_names = dict( + sum="sum", + prod="product", + amax="maximum", + amin="minimum", + argmax="argmax", + argmin="argmin", + mean="mean", + median="median", + norm="norm", + var="variance", + std="standard_deviation", + logsumexp="logsumexp", + ) + + normalization_names = dict( + softmax="softmax", + log_softmax="log_softmax", + softmin="softmin", + normalize="normalize", + cumsum="cumulative_sum", + cumprod="cumulative_prod", + ) + + operation_names = {} + operation_names.update(reduction_names) + operation_names.update(normalization_names) + + # Default example data: + example_dim = 1 + example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]]) + example_mask = torch.tensor([[True, False, True], [False, False, False]]) + example_args: Tuple[Any, ...] 
+ if func.__name__ in {"norm", "normalize"}: + example_args = (2.0, example_dim) + example_input = example_input.to(dtype=torch.float32) + elif func.__name__ in {"var", "std"}: + example_args = (example_dim, False) + elif func.__name__ == "median": + example_args = (example_dim,) + example_input = example_input.to(dtype=torch.float32) + else: + example_args = (example_dim,) + + operation_args: Tuple[str, ...] + operation_kwargs: Tuple[str, ...] + operation_args, operation_kwargs = args_and_kwargs[func.__name__] + arg_declarations = [ + "\n ".join( + argument_declarations.get(a, f'{a.split("__", 1)[0]}: TBD.').splitlines() + ) + for a in operation_args + ] + kwarg_declarations = [ + "\n ".join( + argument_declarations.get( + a.split("=", 1)[0], f'{a.split("__", 1)[0]}: TBD.' + ) + .format(default=a.split("=", 1)[1]) + .splitlines() + ) + for a in operation_kwargs + ] + + if func.__name__ in reduction_names: + op_kind = "reduction" + doc_sections = ["signature", "descr", "identity", "args", "example"] + elif func.__name__ in normalization_names: + op_kind = "normalization" + doc_sections = ["signature", "descr", "args", "example"] + example_input = example_input.to(dtype=torch.float32) + else: + assert 0 # add function name to operation names dictionaries + example_output = func(example_input, *example_args, mask=example_mask) + + template_data = { + "function_name": func.__name__, + "full_function_name": func.__module__ + "." 
+ func.__name__, + "operation name": operation_names[func.__name__], + "operation_args": ", ".join(a.split("__", 1)[0] for a in operation_args), + "operation_kwargs": ", ".join(a.split("__", 1)[0] for a in operation_kwargs), + # one-line representation of a tensor: + "example_input": " ".join(str(example_input).split()), + "example_args": ", ".join(map(str, example_args)), + "example_mask": " ".join(str(example_mask).split()), + # multi-line representation of a tensor with indent + "indent_example_input": ("\n ").join(str(example_input).splitlines()), + "indent_example_mask": ("\n ").join(str(example_mask).splitlines()), + "indent_example_output": ("\n ").join(str(example_output).splitlines()), + } + + if func.__name__ in reduction_names: + template_data.update( + identity_uint8=_reduction_identity( + func.__name__, torch.tensor(0, dtype=torch.uint8) + ), + identity_int32=_reduction_identity( + func.__name__, torch.tensor(0, dtype=torch.int32) + ), + identity_float32=_reduction_identity( + func.__name__, torch.tensor(0, dtype=torch.float32) + ), + ) + if func.__name__ == "norm": + template_data.update( + identity_ord_ninf=_reduction_identity( + func.__name__, torch.tensor(0, dtype=torch.float32), float("-inf") + ) + ) + elif func.__name__ in normalization_names: + template_data.update(definition=definitions[func.__name__]) + else: + assert 0 # add function name to operation names dictionaries + template_data.update( + args_declarations=("\n ".join(arg_declarations)).format_map(template_data) + ) + template_data.update( + kwargs_declarations=("\n ".join(kwarg_declarations)).format_map( + template_data + ) + ) + + # Apply function name info to docstring templates: + templates = { + k: v.format_map(template_data) + for k, v in docstring_templates.items() + if k.startswith(op_kind) + } + templates.update( + (k, v.format_map(template_data) if isinstance(v, str) else v) + for k, v in template_data.items() + ) + + # Apply docstring templates to function doctring: + if 
func.__doc__ is None: + doc_template = "\n\n".join([f"{{{op_kind}_{sec}}}" for sec in doc_sections]) + else: + doc_template = func.__doc__ + return doc_template.format_map(templates) + + +def _reduction_identity(op_name: str, input: Tensor, *args): + """Return identity value as scalar tensor of a reduction operation on + given input, or None, if the identity value cannot be uniquely + defined for the given input. + + The identity value of the operation is defined as the initial + value to reduction operation that has a property ``op(op_identity, + value) == value`` for any value in the domain of the operation. + Or put it another way, including or excluding the identity value in + a list of operands will not change the reduction result. + + See https://github.com/pytorch/rfcs/pull/27 for more information. + + """ + dtype: DType = input.dtype + device = input.device + op_name = op_name.rsplit(".", 1)[-1] # lstrip module name when present + if op_name in {"sum", "cumsum"}: + return torch.tensor(0, dtype=dtype, device=device) + elif op_name in {"prod", "cumprod"}: + return torch.tensor(1, dtype=dtype, device=device) + elif op_name in {"amax", "argmax", "logsumexp"}: + if torch.is_floating_point(input): + return torch.tensor(-torch.inf, dtype=dtype, device=device) + elif torch.is_signed(input) or dtype == torch.uint8: + return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device) + elif op_name in {"amin", "argmin"}: + if torch.is_floating_point(input): + return torch.tensor(torch.inf, dtype=dtype, device=device) + elif torch.is_signed(input) or dtype == torch.uint8: + return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device) + elif op_name == "mean": + # Strictly speaking, the identity value of the mean operation + # is the mean of the input. Since the mean value depends on + # the dim argument and it may be a non-scalar tensor, we + # consider the identity value of the mean operation ambiguous. 
+ # Moreover, the mean value of empty input is undefined. + return None + elif op_name == "norm": + ord = args[0] if args else 2 + if ord == float("-inf"): + assert torch.is_floating_point(input), input.dtype + return torch.tensor(torch.inf, dtype=dtype, device=device) + return torch.tensor(0, dtype=dtype, device=device) + elif op_name == "median": + # We use NaN for now because the implementation is currently using torch.nanmedian + # and NaN is the identity for that function since it gets ignored + dtype = input.dtype if torch.is_floating_point(input) else torch.float + return torch.tensor(torch.nan, dtype=dtype, device=device) + elif op_name in {"var", "std"}: + return None + raise NotImplementedError(f"identity of {op_name} on {dtype} input") + + +def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]: + """Return dim argument as a tuple of sorted dim values.""" + dims: List[int] = [] + if dim == (): + # Currently, `dim=()` in reductions operations means "reduce + # over all dimensions" while in future, it will read "no + # reduce". See https://github.com/pytorch/pytorch/issues/29137 + # When gh-29137 is resolved, this if-block must be deleted. 
def _sparse_coo_flatten_indices(indices: Tensor, shape: tuple):
    """Flatten N-D sparse COO indices (one row per sparse dim) to 1-D
    linear indices using row-major (C-order) arithmetic."""
    # Flatten N-D indices to 1-D indices
    flat_indices = indices.new_zeros(indices.size(1))
    for d, sz in enumerate(shape):
        flat_indices.mul_(sz)
        flat_indices.add_(indices[d])
    return flat_indices


def _any(input: Tensor, dim: tuple, keepdim: bool):
    """Apply torch.any over each dim in `dim`, one dimension at a time."""
    # Support torch.any with tuple dim argument.
    # Workaround of https://github.com/pytorch/pytorch/issues/56586
    r = input
    # Reduce from the last dim first so earlier dim indices stay valid.
    for d in reversed(dim):
        r = r.any(dim=d, keepdim=keepdim)
    return r


def _sparse_coo_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """Sparse variant of torch.where. Supports sparse COO and hybrid sparse COO tensors.

    _sparse_coo_where implements the following invariant:

      _sparse_coo_where(mask, input, fill_value).to_dense(fill_value) ==
        torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))

    where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
    tensor, and `to_dense(fill_value)` is like `to_dense()` except
    that the unspecified elements are mapped to `fill_value` rather
    than to `0`.

    Returns a sparse COO tensor with the following features:

    - all specified elements correspond to masked-in elements that
      have the values of the input tensor. If there exists a masked-in
      element (as specified by mask) that is not specified in the
      input, in the result tensor, the corresponding element has value
      0. In the dense part of the sparse tensor, the masked-out
      elements are replaced with fill_value.

    - all unspecified elements correspond to masked-out elements.
    """

    assert input.layout == torch.sparse_coo
    assert mask.layout == input.layout
    assert mask.shape == input.shape
    assert mask.dense_dim() == input.dense_dim()  # TODO: eliminate this restriction

    input = input.coalesce()

    # For set operations on sparse tensor indices, we'll convert
    # multi-dimensional indices to 1-D indices for efficiency.
    input_flat_indices = _sparse_coo_flatten_indices(
        input.indices(), input.shape[: input.sparse_dim()]
    )
    mask_flat_indices = _sparse_coo_flatten_indices(
        mask.indices(), mask.shape[: mask.sparse_dim()]
    )

    # the set of mask flat indices that define masked-in elements:
    if mask.dense_dim() > 0:
        # for hybrid masks, a sparse element is masked-in when any of
        # its dense-part entries is True
        mask_values = _any(
            mask.values(), tuple(range(1, input.sparse_dim() + 1)), False
        )
    else:
        mask_values = mask.values()
    maskin_flat_indices = mask_flat_indices[mask_values.nonzero()[:, 0]]

    def intersection(i1, i2):
        # indices appearing in both i1 and i2 (inputs assumed duplicate-free)
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return union, torch.where(counts.gt(1))

    def minus(i1, i2):
        # indices of i1 that do not appear in i2
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return intersection(union[torch.where(counts.eq(1))], i1)

    def _apply(a):
        obj, w = a
        return obj[w]

    # the set of input flat indices of specified and masked-in elements:
    maskin_input_flat_indices = _apply(
        intersection(maskin_flat_indices, input_flat_indices)
    )
    _, w = intersection(input_flat_indices, maskin_input_flat_indices)

    # the indices and values of masked-in elements
    where_input_indices = input.indices()[(slice(None),) + w]
    where_input_values = input.values()[w]

    if mask.dense_dim() > 0:
        # apply mask to the dense part of the input values:
        _, w1 = intersection(mask_flat_indices, maskin_input_flat_indices)
        where_mask_values = mask.values()[w1]
        where_input_values = torch.where(
            where_mask_values, where_input_values, fill_value
        )

    # the set of flat indices of unspecified input and masked-in elements:
    maskin_zero_flat_indices = _apply(
        minus(maskin_flat_indices, maskin_input_flat_indices)
    )

    # the indices of masked-in zero elements
    _, w = intersection(mask_flat_indices, maskin_zero_flat_indices)
    where_zero_indices = mask.indices()[(slice(None),) + w]

    # construct result
    n = where_zero_indices.size(1)
    if n == 0:
        # the input is coalesced, hence input_flat_indices are ordered
        # and the result is guaranteed to be coalesced:
        result = torch.sparse_coo_tensor(
            where_input_indices, where_input_values, input.shape
        )
        return result._coalesced_(True)

    where_indices = torch.cat([where_input_indices, where_zero_indices], dim=1)
    where_values = torch.cat(
        [
            where_input_values,
            where_input_values.new_zeros((n,) + where_input_values.shape[1:]),
        ]
    )
    result = torch.sparse_coo_tensor(where_indices, where_values, input.shape)

    # appending zero elements leads to uncoalesced sparse tensor
    return result.coalesce()
def _sparse_coo_scatter_reduction_helper(
    op,
    mask_input: Tensor,
    dims: Tuple[int, ...],
    keepdim: bool,
    dtype: Optional[DType] = None,
) -> Tensor:
    """Reduce a sparse COO tensor over `dims` using scatter_reduce.

    `op` must be one of torch.sum/prod/amax/amin (identified by its
    __name__); dense dims are reduced first, then sparse dims are
    folded together via torch.unique + scatter_reduce_.
    """
    reduce = op.__name__
    valid_reductions = ["sum", "prod", "amax", "amin"]
    if reduce not in valid_reductions:
        raise ValueError(
            f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
        )

    output_dtype = dtype
    values, indices = mask_input._values(), mask_input._indices()
    input_dims = mask_input.dim()
    num_sparse_dims = mask_input.sparse_dim()
    reduced_sparse_dims = []
    retained_sparse_dims = []
    reduced_dense_dims = []

    # promote dtype if specified
    if values.dtype != output_dtype:
        values = values.to(output_dtype)

    if keepdim:
        output_shape = tuple(
            1 if i in dims else si for (i, si) in enumerate(mask_input.shape)
        )
    else:
        output_shape = tuple(
            si for (i, si) in enumerate(mask_input.shape) if i not in dims
        )

    # Split requested dims into sparse vs dense ones; dense dims are
    # shifted by +1 - num_sparse_dims because values has a leading nnz dim.
    for d in dims:
        if d >= input_dims:
            continue

        if d < num_sparse_dims:
            reduced_sparse_dims.append(d)
        else:
            reduced_dense_dims.append(d + 1 - num_sparse_dims)

    # Reduce dense dimensions
    if len(reduced_dense_dims) > 0:
        if reduce == "sum":
            new_values = values
            new_values = op(new_values, dim=reduced_dense_dims, keepdim=bool(keepdim))
        else:
            # FIXME: Implement reductions for dense dimensions for ops with non-zero reduction identities
            return NotImplemented
    else:
        new_values = values.clone()

    # Reduce sparse dimensions
    if len(reduced_sparse_dims) == num_sparse_dims:
        # all sparse dims reduced: result collapses to the dense part
        if reduce in {"amax", "amin"} and new_values.size(0) == 0:
            # IndexError: amax(): Expected reduction dim 0 to have non-zero size.
            # sum()/prod() return the reduction identity when dim has size 0 but amax()/amin() do not
            # See https://github.com/pytorch/pytorch/issues/61901
            new_values = _reduction_identity(reduce, new_values)
        else:
            new_values = op(new_values, dim=0)
        if keepdim:
            for _ in range(num_sparse_dims):
                new_values = new_values.unsqueeze(0)
        return new_values.to(dtype=output_dtype).to_sparse()
    else:
        new_indices = indices.clone()
        if keepdim:
            # zero out reduced sparse dimensions if keepdim = True
            # ensures that the call to torch.unique folds duplicated indices together while preserving the dimension
            new_indices[reduced_sparse_dims, :] = 0
        else:
            # remove reduced sparse dimensions if keepdim = False
            if len(reduced_sparse_dims) > 0:
                retained_sparse_dims = [
                    i
                    for i in range(num_sparse_dims)
                    if i not in set(reduced_sparse_dims)
                ]
                new_indices = new_indices.index_select(
                    0, torch.tensor(retained_sparse_dims).to(mask_input.device)
                )

        # Use scatter_reduce to reduce items in the new_values tensor that correspond to the same indices in new_indices
        if new_indices.numel() > 0:
            # lexsort indices and get index tensor for scatter reduction
            new_indices, inverse_indices = torch.unique(
                new_indices, return_inverse=True, dim=1
            )
            out_shape = list(new_values.shape)
            out_shape[0] = new_indices.shape[1]
            for _ in range(new_values.ndim - 1):
                inverse_indices = inverse_indices.unsqueeze(-1)
            scatter_indices = inverse_indices.expand(new_values.shape)
            # FIXME: temporary workaround for issue with bfloat16/float16 remove when acctype is implemented for scatter_reduce
            if output_dtype in {torch.bfloat16, torch.float16}:
                new_values = new_values.to(torch.float)
                out = new_values.new_empty(out_shape)
                new_values = out.scatter_reduce_(
                    0, scatter_indices, new_values, reduce=reduce, include_self=False
                )
                new_values = new_values.to(dtype=output_dtype)
            else:
                out = new_values.new_empty(out_shape)
                new_values = out.scatter_reduce_(
                    0, scatter_indices, new_values, reduce=reduce, include_self=False
                )

        return torch.sparse_coo_tensor(
            new_indices,
            new_values,
            output_shape,
            dtype=output_dtype,
            device=mask_input.device,
        )
def _sparse_csr_segment_reduction_helper(
    op,
    mask_input: Tensor,
    dims: Tuple[int, ...],
    keepdim: bool,
    dtype: Optional[DType] = None,
) -> Tensor:
    """Reduce a 2D sparse CSR tensor over `dims` (keepdim must be True).

    dim 0 is reduced by folding equal column indices with
    scatter_reduce_; dim 1 is reduced per-row with torch._segment_reduce
    using crow_indices as segment offsets.
    """
    # Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True
    # FIXME: when dense dimensions are implemented for CSR tensors
    assert (
        keepdim
    ), "reduction operations on CSR tensors with keepdim=False is unsupported"
    reduce = op.__name__
    valid_reductions = ["sum", "prod", "mean", "amax", "amin"]
    if reduce not in valid_reductions:
        raise ValueError(
            f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
        )
    device = mask_input.device
    output_dtype = dtype
    values, crow_indices, col_indices = (
        mask_input.values(),
        mask_input.crow_indices(),
        mask_input.col_indices(),
    )

    # promote dtype if specified
    if values.dtype != output_dtype:
        values = values.to(output_dtype)

    if len(dims) == 0:
        return mask_input
    if len(dims) == 1:
        if dims[0] == 0:
            # reduce along rows: fold values that share a column index
            new_col_indices, scatter_indices = torch.unique(
                col_indices, return_inverse=True
            )
            new_nnz = new_col_indices.shape[0]
            new_crow_indices = torch.tensor([0, new_nnz])
            new_values = values.new_empty(new_col_indices.shape)
            new_values.scatter_reduce_(
                0, scatter_indices, values, reduce, include_self=False
            )
            new_shape = [1, mask_input.size(1)]
        else:
            assert (
                dims[0] == 1
            ), "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1."
            # all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1
            # except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0
            new_crow_indices = torch.cat(
                (
                    crow_indices.new_zeros(1),
                    torch.cumsum(torch.diff(crow_indices) != 0, 0),
                ),
                0,
            )
            new_nnz = new_crow_indices[-1]
            new_col_indices = col_indices.new_zeros(new_nnz)
            new_values = torch._segment_reduce(values, reduce, offsets=crow_indices)  # type: ignore[attr-defined]
            new_shape = [mask_input.size(0), 1]
    else:
        assert len(dims) == 2
        # full reduction: result has at most one specified element
        nnz = min(1, values.numel())
        if nnz == 1:
            op_kwargs = {"keepdim": True, "dtype": output_dtype}
            # amax and amin do not support dtype kwarg
            if reduce in ["amax", "amin"]:
                del op_kwargs["dtype"]
            new_values = op(values, 0, **op_kwargs)
        else:
            new_values = torch.empty(0, dtype=output_dtype)
        new_col_indices = col_indices.new_zeros(nnz)
        new_crow_indices = torch.tensor([0, nnz])
        new_shape = [1, nnz]

    return torch.sparse_csr_tensor(
        new_crow_indices,
        new_col_indices,
        new_values,
        new_shape,
        dtype=output_dtype,
        device=device,
    )


def _sparse_csr_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """Sparse variant of torch.where. Supports sparse CSR tensors."""
    # TODO: implement sparse CSR specific where operator for efficiency
    return _sparse_coo_where(
        mask.to_sparse_coo(), input.to_sparse_coo(), fill_value
    ).to_sparse_csr()
Supports sparse CSR tensors.""" + # TODO: implement sparse CSR specific where operator for efficiency + return _sparse_coo_where( + mask.to_sparse_coo(), input.to_sparse_coo(), fill_value + ).to_sparse_csr() + + +def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor: + """torch.where with sparse inputs support. + + _where implements the following invariant: + + _where(mask, input, fill_value).to_dense(fill_value) == + torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value)) + + where `a == b` means `assertEqual(a, b)`, mask is boolean sparse + tensor, and `to_dense(fill_value)` is like `to_dense()` except + that the unspecified elements are mapped to `fill_value` rather + than to `0`. + + Returns a sparse tensor with the following features: + + - all specified elements correspond to masked-in elements that + have the values of the input tensor. If there exists a masked-in + element (as specified by mask) that is not specified in the + input, in the result tensor, the corresponding element has value + 0. In the dense part of the sparse tensor, the masked-out + elements are replaced with fill_value. + + - all unspecified elements correspond to masked-out elements. + """ + if mask.layout == torch.strided: + return torch.where(mask, input, fill_value) + elif mask.layout == torch.sparse_coo: + return _sparse_coo_where(mask, input, fill_value) + elif mask.layout == torch.sparse_csr: + return _sparse_csr_where(mask, input, fill_value) + else: + raise ValueError( + f"_where expects strided or sparse COO or sparse CSR tensor but got {mask.layout}" + ) + + +def _input_mask(input: Union[Tensor, MaskedTensor], *args, **kwargs) -> Tensor: + """Return canonical input mask. + + A canonical input mask is defined as a boolean mask tensor that + shape and layout matches with the shape and the layout of the + input. + + The canonical input mask is computed from the :attr:`mask` tensor + content to meet the following criteria: + + 1. 
The shape of the canonical input mask is the same as the shape + of :attr:`input` tensor. If the mask tensor has a smaller shape + than the shape of the :attr:`input`, broadcasting rules will be + applied. Downcasting of mask is not supported. + + 2. The layout of the canonical input mask is the same as the + layout of the :attr:`input` tensor. If the mask has different + layout, it will be converted to the expected layout. In the + case of sparse COO layout, the canonical input mask will be + coalesced. + + 3. The dtype of the canonical input mask is torch.bool. If the + mask dtype is not bool then it will be converted to bool dtype + using `.to(dtype=bool)` method call. + + 4. The elements of the canonical input mask have boolean values + copied from the content of the :attr:`mask` tensor (after + possible broadcasting and dtype conversion transforms). In + general, the sparsity pattern of the sparse canonical input + mask need not to be the same as the sparsity pattern of the + sparse :attr:`input` tensor. + + """ + if input.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}: + raise ValueError( + f"_input_mask expects strided or sparse COO or sparse CSR tensor but got {input.layout}" + ) + + mask = kwargs.get("mask") + + # default mask + if mask is None: + raise ValueError("_input_mask requires explicit mask") + + # mask shape must match with input shape + if mask.shape != input.shape: + if mask.ndim > input.ndim: + raise IndexError( + "_input_mask expected broadcastable mask (got mask dimensionality higher than of the input)" + ) + if mask.layout == torch.strided: + mask = torch.broadcast_to(mask.clone(), input.shape).to(dtype=torch.bool) + elif mask.layout == torch.sparse_coo: + mask = torch._sparse_broadcast_to(mask, input.shape) + else: + assert mask.layout == torch.sparse_csr + # Broadcasting of CSR tensors is not implemented. Working + # around by using COO layout. 
def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor:
    """Return output mask of masked operation applied to given arguments."""
    if callable(op):
        # classify the masked op by its name
        is_reduction = op.__name__ in {
            "sum",
            "prod",
            "amax",
            "amin",
            "argmax",
            "argmin",
            "mean",
            "median",
            "norm",
            "var",
            "std",
            "logsumexp",
        }
        is_normalization = op.__name__ in {
            "softmax",
            "log_softmax",
            "softmin",
            "normalize",
            "cumsum",
            "cumprod",
        }
        if is_reduction:
            if op.__name__ == "norm":
                if args:
                    args = args[1:]  # lstrip ord argument
            dim = args[0] if args else kwargs.get("dim")
            outmask = _input_mask(input, *args, **kwargs)
            keepdim = kwargs.get("keepdim", False)
            dim_ = _canonical_dim(dim, input.ndim)
            # a reduced output element is valid iff any contributing
            # input element was masked-in
            return _any(outmask, dim_, bool(keepdim))
        elif is_normalization:
            # normalizations are elementwise in shape: output mask == input mask
            return _input_mask(input, *args, **kwargs)
        else:
            raise ValueError(
                f"_output_mask expected masked operation (got callable {op.__module__}.{op.__name__})"
            )
    else:
        raise ValueError(
            f"_output_mask expected masked operation (got {type(op).__name__} object)"
        )


def _combine_input_and_mask(
    op, input: Union[MaskedTensor, Tensor], mask, *args
) -> Tensor:
    """Replace masked-out elements of `input` with the reduction
    identity of `op`, so a subsequent plain reduction ignores them.

    For MaskedTensor inputs the combination is routed through an
    autograd.Function so gradients flow back as a MaskedTensor.
    """
    def helper(input, mask):
        if mask is None:
            # mask=None means all-True mask: input is used as-is
            return input
        canonical_mask = _input_mask(input, mask=mask)
        if callable(op):
            # masked-out elements become the op's identity value
            fill_value = _reduction_identity(op.__name__, input, *args)
            return _where(canonical_mask, input, fill_value)
        else:
            raise ValueError(
                f"_combine_input_and_mask expected masked operation (got {type(op).__name__} object)"
            )

    class Combine(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, mask):
            """Return input with masked-out elements eliminated for the given operations."""
            ctx.save_for_backward(mask)

            if mask is not None:
                ctx.mark_non_differentiable(mask)

            return helper(input, mask)

        @staticmethod
        def backward(ctx, grad_output):
            (mask,) = ctx.saved_tensors
            grad_data = (
                grad_output.get_data() if is_masked_tensor(grad_output) else grad_output
            )
            # re-wrap the gradient with the saved mask; mask gets no gradient
            result = as_masked_tensor(grad_data, mask)
            return result, None

    return (
        Combine.apply(input.get_data(), input.get_mask())  # type: ignore[union-attr]
        if is_masked_tensor(input)
        else helper(input, mask)
    )


@_apply_docstring_templates
def sum(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        if input.layout == torch.sparse_csr:
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                # csr.to(dtype=torch.int64) is not implemented, so
                # using coo.to on input to ensure the promoted dtype
                input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
            else:
                dtype = input.dtype
        else:
            dtype = input.dtype
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                dtype = torch.int64
    dim_ = _canonical_dim(dim, input.ndim)
    # masked-out elements are replaced by sum's identity (0) first
    mask_input = _combine_input_and_mask(sum, input, mask)
    if mask_input.layout == torch.strided:
        return torch.sum(mask_input, dim_, bool(keepdim), dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        return _sparse_coo_scatter_reduction_helper(
            torch.sum, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        return torch._sparse_csr_sum(
            mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
        )
    else:
        raise ValueError(
            f"masked sum expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )
@_apply_docstring_templates
def prod(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        if input.layout == torch.sparse_csr:
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                # csr.to(dtype=torch.int64) is not implemented, so
                # using coo.to on input to ensure the promoted dtype
                input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
            else:
                dtype = input.dtype
        else:
            dtype = input.dtype
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                dtype = torch.int64
    dim_ = _canonical_dim(dim, input.ndim)
    # masked-out elements are replaced by prod's identity (1) first
    mask_input = _combine_input_and_mask(prod, input, mask)
    if mask_input.layout == torch.strided:
        # Workaround https://github.com/pytorch/pytorch/issues/56586
        # (torch.prod has no tuple-dim support): reduce one dim at a
        # time, from the last dim so earlier dim indices stay valid.
        result = mask_input
        result = result.to(dtype=dtype)
        for d in reversed(dim_):
            result = result.prod(dim=d, keepdim=bool(keepdim))
        return result
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch, the same issue arises for sparse_coo tensors
            raise ValueError(
                "masked prod expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.prod, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            # mask is None corresponds to all-True mask. The
            # unspecified elements in the CSR tensor correspond to
            # zero values. Hence, the prod reduction result is
            # automatically zero unless all elements are specified.
            # A semi-optimal way to take this into account is to use:
            #
            #   masked_prod(csr, ..., mask=None) == torch._sparse_csr_prod(csr, ...) * all(csr.nonzero(), ...)
            #
            # but that requires implementing `all` and `nonzero`
            # support for sparse csr tensors.
            raise ValueError(
                "masked prod expects explicit mask for sparse_csr tensor input"
            )
        return torch._sparse_csr_prod(
            mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
        )
    else:
        raise ValueError(
            f"masked prod expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def cumsum(
    input: Tensor,
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # cumsum shares sum's identity (0), hence `sum` is passed here
    mask_input = _combine_input_and_mask(sum, input, mask)
    if mask_input.layout == torch.strided:
        return torch.cumsum(mask_input, dim_, dtype=dtype).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked cumsum expects strided tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def cumprod(
    input: Tensor,
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # cumprod shares prod's identity (1), hence `prod` is passed here
    mask_input = _combine_input_and_mask(prod, input, mask)
    if mask_input.layout == torch.strided:
        return torch.cumprod(mask_input, dim_, dtype=dtype).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked cumprod expects strided tensor (got {mask_input.layout} tensor)"
        )
@_apply_docstring_templates
def amax(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

{reduction_identity_dtype}

{reduction_args}

{reduction_example}"""
    if dtype is None:
        dtype = input.dtype

    # masked-out elements are replaced by amax's identity (-inf / int min)
    mask_input = _combine_input_and_mask(amax, input, mask)
    dim_ = _canonical_dim(dim, mask_input.ndim)
    if mask_input.layout == torch.strided:
        return torch.amax(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            raise ValueError(
                "masked amax expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amax, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amax expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amax, mask_input, dim_, bool(keepdim), dtype
        )
    else:
        raise ValueError(
            f"masked amax expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def amin(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

{reduction_identity_dtype}

{reduction_args}

{reduction_example}"""
    if dtype is None:
        dtype = input.dtype

    # masked-out elements are replaced by amin's identity (+inf / int max)
    mask_input = _combine_input_and_mask(amin, input, mask)
    dim_ = _canonical_dim(dim, mask_input.ndim)
    if mask_input.layout == torch.strided:
        return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            # BUGFIX: the message previously said "masked amax" (copy-paste from amax above)
            raise ValueError(
                "masked amin expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amin expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    else:
        raise ValueError(
            f"masked amin expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def argmax(
    input: Union[Tensor, MaskedTensor],
    dim: Optional[int] = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    # masked-out elements get argmax's identity (-inf / int min), so
    # they never win the comparison
    mask_input = _combine_input_and_mask(argmax, input, mask)
    if mask_input.layout == torch.strided:
        return torch.argmax(mask_input, dim, bool(keepdim)).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked argmax expects strided tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def argmin(
    input: Union[Tensor, MaskedTensor],
    dim: Optional[int] = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    # masked-out elements get argmin's identity (+inf / int max), so
    # they never win the comparison
    mask_input = _combine_input_and_mask(argmin, input, mask)
    if mask_input.layout == torch.strided:
        return torch.argmin(mask_input, dim, bool(keepdim)).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked argmin expects strided tensor (got {mask_input.layout} tensor)"
        )
@_apply_docstring_templates
def mean(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

By definition, the identity value of a mean operation is the mean
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
mean is undefined. Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.

{reduction_args}

{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    if input.layout == torch.strided:
        # mean = masked total / masked-in element count; a zero count
        # yields nan via the division below
        if mask is None:
            # TODO: compute count analytically
            count = sum(
                torch.ones(input.shape, dtype=torch.int64, device=input.device),
                dim,
                keepdim=keepdim,
            )
            total = sum(input, dim, keepdim=keepdim, dtype=dtype)
        else:
            inmask = _input_mask(input, mask=mask)
            count = sum(
                inmask.new_ones(input.shape, dtype=torch.int64),
                dim,
                keepdim=keepdim,
                mask=inmask,
            )
            total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
        return total / count
    elif input.layout == torch.sparse_csr:
        mask_input = _combine_input_and_mask(mean, input, mask)
        dim_ = _canonical_dim(dim, mask_input.ndim)
        if mask is None:
            raise ValueError(
                "masked mean expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.mean, mask_input, dim_, bool(keepdim), dtype
        )
    else:
        raise ValueError(
            f"masked mean expects strided or sparse_csr tensor (got {input.layout} tensor)"
        )


@_apply_docstring_templates
def median(
    input: Union[Tensor, MaskedTensor],
    dim: int = -1,
    *,
    keepdim: bool = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:

    """\
{reduction_signature}
{reduction_descr}
By definition, the identity value of a median operation is the median
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
median is undefined. Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    is_float = torch.is_floating_point(input)
    # integral input is converted to float so masked-out elements can
    # be represented as NaN (median's identity for nanmedian)
    if not is_float:
        input = input.to(dtype=torch.float)
    mask_input = _combine_input_and_mask(median, input, mask)
    if mask_input.layout == torch.strided:
        output = torch.nanmedian(mask_input, dim_, keepdim).values
        if is_float:
            return output
        elif not is_float and not torch.isnan(output).any():
            # safe to convert back to the original integral dtype
            return output.to(dtype=dtype)
        else:
            # a NaN output element means a fully masked-out row, which
            # cannot be represented in an integral dtype
            raise ValueError(
                "masked median expects no fully masked out rows if dtype is not floating point"
            )
    else:
        raise ValueError(
            f"masked median expects strided tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def logsumexp(
    input: Tensor,
    dim: DimOrDims = None,
    *,
    keepdim: bool = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)
    # masked-out elements are replaced by logsumexp's identity (-inf)
    mask_input = _combine_input_and_mask(logsumexp, input, mask)
    if mask_input.layout == torch.strided:
        return torch.logsumexp(mask_input, dim_, keepdim=keepdim).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked logsumexp expects strided tensor (got {mask_input.layout} tensor)"
        )
dtype: Optional[DType] = None, + input_mask: Optional[Tensor] = None, + other_mask: Optional[Tensor] = None, +) -> Tensor: + """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor + +Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other` +tensor. The :attr:`input` elements are masked out according to the boolean tensor +:attr:`input_mask` and the attr:`other` elements are masked out according to the boolean tensor +:attr:`other_mask`. + +The shapes of a mask tensor and the tensor to be masked +don't need to match, but they must be :ref:`broadcastable +` and the dimensionality of the mask +tensor must not be greater than of the tensor to be masked. + +Args: + input (Tensor): the input tensor + other (Tensor): the second input tensor + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the output tensor is + casted to :attr:`dtype` after the operation is + performed. Default: None. + input_mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of :attr:`input` tensor elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + other_mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of :attr:`other` tensor elements. + Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``. 
+ +Example:: + + >>> input = torch.tensor([-100.0, -200, -300]) + >>> input + tensor([-100., -200., -300.]) + >>> other = torch.tensor([-1.0, -2, -3]) + >>> other + tensor([-1., -2., -3.]) + >>> mask = torch.tensor([True, False, True]) + >>> mask + tensor([ True, False, True]) + >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask) + tensor([-1., -inf, -3.]) +""" + if dtype is None: + dtype = input.dtype + if input.layout == torch.strided and other.layout == torch.strided: + mask_input = _combine_input_and_mask(logsumexp, input, input_mask) + mask_other = _combine_input_and_mask(logsumexp, other, other_mask) + return torch.logaddexp(mask_input, mask_other).to(dtype=dtype) + else: + raise ValueError( + f"masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)" + ) + + +@_apply_docstring_templates +def norm( + input: Union[Tensor, MaskedTensor], + ord: Optional[float] = 2.0, + dim: DimOrDims = None, + *, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} + +{reduction_descr} + +The identity value of norm operation, which is used to start the +reduction, is ``{identity_float32}``, except for ``ord=-inf`` it is +``{identity_ord_ninf}``. 
+ +{reduction_args} + +{reduction_example}""" + if dtype is None: + dtype = input.dtype + mask_input = _combine_input_and_mask(norm, input, mask, ord) + if mask_input.layout == torch.strided: + dim_ = _canonical_dim(dim, input.ndim) + return torch.linalg.vector_norm( + mask_input, ord, dim_, bool(keepdim), dtype=dtype + ) + else: + raise ValueError( + f"masked norm expects strided tensor (got {mask_input.layout} tensor)" + ) + + +def _std_var( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims, + unbiased: Optional[bool], + *, + correction_opt: Optional[Union[int, float]], + keepdim: Optional[bool], + dtype: Optional[DType], + mask: Optional[Tensor], + take_sqrt: Optional[bool], +) -> Tensor: + assert (unbiased is None or correction_opt is None), "Only one of unbiased and correction may be given" + correction = 1.0 + if unbiased is not None: + correction = 1.0 if unbiased else 0.0 + if correction_opt is not None: + correction = sym_float(correction_opt) + + if dtype is None: + dtype = input.dtype + if not (dtype.is_floating_point or dtype.is_complex): + dtype = torch.float32 + compute_dtype = dtype + if not (compute_dtype.is_floating_point or compute_dtype.is_complex): + compute_dtype = torch.float32 + if input.layout == torch.strided: + if mask is None: + # TODO: compute count analytically + count = sum( + torch.ones(input.shape, dtype=torch.int64, device=input.device), + dim, + keepdim=True, + ) + sample_total = sum(input, dim, keepdim=True, dtype=dtype) + else: + inmask = _input_mask(input, mask=mask) + count = sum( + inmask.new_ones(input.shape, dtype=torch.int64), + dim, + keepdim=True, + mask=inmask, + ) + sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask) + # TODO: replace torch.subtract/divide/square/maximum with + # masked subtract/divide/square/maximum when these will be + # available. 
+ sample_mean = torch.divide(sample_total, count) + x = torch.subtract(input, sample_mean) + if mask is None: + total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype) + else: + total = sum( + x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask # type: ignore[possibly-undefined] + ) + if not keepdim: + count = count.reshape(total.shape) + if correction != 0: + real_dtype = (corresponding_real_dtype(compute_dtype) + if compute_dtype.is_complex else compute_dtype) + count = count.to(real_dtype) + count = torch.subtract(count, correction) + count = torch.maximum(count, count.new_zeros([])) + output = torch.divide(total, count).to(dtype=dtype) + if take_sqrt: + output = torch.sqrt(output) + return output + else: + raise ValueError( + f"masked std/var expects strided tensor (got {input.layout} tensor)" + ) + + +@_apply_docstring_templates +def var( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + unbiased: Optional[bool] = None, + *, + correction: Optional[Union[int, float]] = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} +{reduction_descr} +The identity value of sample variance operation is undefined. The +elements of output tensor with strided layout, that correspond to +fully masked-out elements, have ``nan`` values. 
+{reduction_args} +{reduction_example}""" + return _std_var( + input=input, + dim=dim, + unbiased=unbiased, + correction_opt=correction, + keepdim=keepdim, + dtype=dtype, + mask=mask, + take_sqrt=False, + ) + + +@_apply_docstring_templates +def std( + input: Union[Tensor, MaskedTensor], + dim: DimOrDims = None, + unbiased: Optional[bool] = None, + *, + correction: Optional[int] = None, + keepdim: Optional[bool] = False, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + """\ +{reduction_signature} +{reduction_descr} +The identity value of sample standard deviation operation is undefined. The +elements of output tensor with strided layout, that correspond to +fully masked-out elements, have ``nan`` values. +{reduction_args} +{reduction_example}""" + return _std_var( + input=input, + dim=dim, + unbiased=unbiased, + correction_opt=correction, + keepdim=keepdim, + dtype=dtype, + mask=mask, + take_sqrt=True, + ) + + +@_apply_docstring_templates +def softmax( + input: Union[Tensor, MaskedTensor], + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + mask_input = _combine_input_and_mask(amax, input, mask) + if mask_input.layout == torch.strided: + return torch.nn.functional.softmax(mask_input, dim_, dtype=dtype) + else: + raise ValueError( + f"masked softmax expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def log_softmax( + input: Union[Tensor, MaskedTensor], + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + mask_input = _combine_input_and_mask(amax, input, mask) + if mask_input.layout == torch.strided: + return torch.nn.functional.log_softmax(mask_input, dim_, dtype=dtype) + else: + raise ValueError( + f"masked log_softmax expects 
strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def softmin( + input: Union[Tensor, MaskedTensor], + dim: int, + *, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + mask_input = _combine_input_and_mask(amin, input, mask) + if mask_input.layout == torch.strided: + return torch.nn.functional.softmin(mask_input, dim_, dtype=dtype) + else: + raise ValueError( + f"masked softmin expects strided tensor (got {mask_input.layout} tensor)" + ) + + +@_apply_docstring_templates +def normalize( + input: Union[Tensor, MaskedTensor], + ord: float, + dim: int, + *, + eps: float = 1e-12, + dtype: Optional[DType] = None, + mask: Optional[Tensor] = None, +) -> Tensor: + if dtype is None: + dtype = input.dtype + dim_ = _canonical_dim(dim, input.ndim)[0] + # TODO: eliminate mask_input as unnecessary when using masked divide. + mask_input = _combine_input_and_mask(sum, input, mask) + if mask_input.layout == torch.strided: + nrm_ = norm(input, ord, dim, keepdim=True, dtype=dtype, mask=mask) + # TODO: replace torch.maximum with masked maximum when available. + denom = torch.maximum(nrm_, nrm_.new_full([], eps)) + # TODO: replace torch.divide with masked divide when available. + return torch.divide(mask_input, denom) + else: + raise ValueError( + f"masked normalize expects strided tensor (got {mask_input.layout} tensor)" + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/masked/maskedtensor/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/masked/maskedtensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e38e03c87086cf50d031dd5591f64f65399d6ac1 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/masked/maskedtensor/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +# flake8: noqa + +from .binary import _apply_native_binary, _is_native_binary +from .core import is_masked_tensor, MaskedTensor +from .passthrough import _apply_pass_through_fn, _is_pass_through_fn +from .reductions import _apply_reduction, _is_reduction +from .unary import _apply_native_unary, _is_native_unary diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/masked/maskedtensor/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/masked/maskedtensor/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee9175c813909861b688f3d7d36bd9f3908b9fbb Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/masked/maskedtensor/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/masked/maskedtensor/_ops_refs.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/masked/maskedtensor/_ops_refs.py new file mode 100644 index 0000000000000000000000000000000000000000..81a890af5d65fdeac98635aa16aed03184bcd290 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/masked/maskedtensor/_ops_refs.py @@ -0,0 +1,477 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates + +from functools import partial +from typing import Callable, Any, Dict, TYPE_CHECKING +import torch + +if TYPE_CHECKING: + import torch._ops + +from .binary import ( + _apply_native_binary, + NATIVE_BINARY_FNS, + NATIVE_INPLACE_BINARY_FNS, +) +from .core import is_masked_tensor, MaskedTensor, _get_data, _masks_match, _maybe_get_mask +from .passthrough import ( + _apply_pass_through_fn, + PASSTHROUGH_FNS +) +from .reductions import ( + _apply_reduction, + NATIVE_REDUCE_FNS, + TORCH_REDUCE_FNS, + TENSOR_REDUCE_FNS, +) +from .unary import ( + _apply_native_unary, + NATIVE_UNARY_FNS, + NATIVE_INPLACE_UNARY_FNS, +) + + +__all__ = [] # type: ignore[var-annotated] + + +def _check_args_kwargs_length(args, kwargs, error_prefix, len_args=None, len_kwargs=None): + if len_args is not None and len_args != len(args): + raise ValueError(f"{error_prefix}: len(args) must be {len_args} but got {len(args)}") + if len_kwargs is not None and len_kwargs != len(kwargs): + raise ValueError(f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}") + + +class _MaskedContiguous(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + if not is_masked_tensor(input): + raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.") + + if input.is_contiguous(): + return input + + data = input.get_data() + mask = input.get_mask() + + return MaskedTensor(data.contiguous(), mask.contiguous()) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class _MaskedToDense(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + if not is_masked_tensor(input): + raise ValueError("MaskedToDense forward: input must be a MaskedTensor.") + + if input.layout == torch.strided: + return input + + ctx.layout = input.layout + data = input.get_data() + mask = input.get_mask() + + return MaskedTensor(data.to_dense(), mask.to_dense()) + + @staticmethod + def backward(ctx, grad_output): + layout = ctx.layout + + if 
layout == torch.sparse_coo: + return grad_output.to_sparse_coo() + elif layout == torch.sparse_csr: + return grad_output.to_sparse_csr() + elif layout == torch.strided: + return grad_output.to_dense() + raise ValueError("to_dense: Unsupported input layout: ", layout) + + +class _MaskedToSparse(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + if not is_masked_tensor(input): + raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.") + + # Following the convention from sparse tensors that to_sparse always means that we convert to sparse_coo + if input.layout == torch.sparse_coo: + return input + + data = input.get_data() + mask = input.get_mask() + sparse_mask = mask.to_sparse_coo().coalesce() + sparse_data = data.sparse_mask(sparse_mask) + + return MaskedTensor(sparse_data, sparse_mask) + + @staticmethod + def backward(ctx, grad_output): + return grad_output.to_dense() + + +class _MaskedToSparseCsr(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + if not is_masked_tensor(input): + raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.") + + if input._masked_data.ndim != 2: + raise ValueError(f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}") + + if input.layout == torch.sparse_csr: + return input + + data = input.get_data() + mask = input.get_mask() + sparse_mask = mask.to_sparse_csr() + sparse_data = data.sparse_mask(sparse_mask) + + return MaskedTensor(sparse_data, sparse_mask) + + @staticmethod + def backward(ctx, grad_output): + return grad_output.to_dense() + + +class _MaskedWhere(torch.autograd.Function): + @staticmethod + def forward(ctx, cond, self, other): + ctx.mark_non_differentiable(cond) + ctx.save_for_backward(cond) + return torch.ops.aten.where(cond, self, other) + + @staticmethod + def backward(ctx, grad_output): + (cond,) = ctx.saved_tensors + + def masked_out_like(mt): + return MaskedTensor(mt.get_data(), 
torch.zeros_like(mt.get_mask()).bool()) + + return ( + None, + torch.ops.aten.where(cond, grad_output, masked_out_like(grad_output)), + torch.ops.aten.where(cond, masked_out_like(grad_output), grad_output), + ) + + +_MASKEDTENSOR_FUNCTION_TABLE = {} + +_function_fn_apply_map = { + (tuple(NATIVE_REDUCE_FNS), tuple(TORCH_REDUCE_FNS), tuple(TENSOR_REDUCE_FNS)): _apply_reduction, +} + +for fn_map_list, apply_fn in _function_fn_apply_map.items(): + for fn_map in fn_map_list: + for fn in fn_map: + _MASKEDTENSOR_FUNCTION_TABLE[fn] = partial(apply_fn, fn) + + +def register_function_func(ops): + """ + Used for registering a new __torch_function__ function to MaskedTensor + Called via _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs) + + The code to register a new function looks like: + + @register_function_func(list_of_ops) + def foo(func, *args, **kwargs): + + """ + def wrapper(func): + for op in ops: + _MASKEDTENSOR_FUNCTION_TABLE[op] = partial(func, op) + return wrapper + + +@register_function_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS) +def _general_function_reductions(func, *args, **kwargs): + return _apply_reduction(func, *args, **kwargs) + + +@register_function_func([torch.Tensor.where, torch.where]) +def _function_where(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0) + return _MaskedWhere.apply(*args) + + +@register_function_func([torch.Tensor.contiguous]) +def _function_contiguous(func, *args, **kwargs): + return _MaskedContiguous.apply(args[0]) + + +@register_function_func([torch.Tensor.to_dense]) +def _function_to_dense(func, *args, **kwargs): + return _MaskedToDense.apply(args[0]) + + +@register_function_func([torch.Tensor.to_sparse]) +def _function_to_sparse(func, *args, **kwargs): + return _MaskedToSparse.apply(args[0]) + + +@register_function_func([torch.Tensor.to_sparse_csr]) +def _function_to_sparse_csr(func, *args, **kwargs): + return 
_MaskedToSparseCsr.apply(args[0]) + + +_MASKEDTENSOR_DISPATCH_TABLE: Dict["torch._ops.OpOverload", Callable[..., Any]] = {} + +def register_dispatch_func(aten_ops): + """ + Used for registering a new __torch_dispatch__ function to MaskedTensor + Called via _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs) + + The code to register a new function looks like: + + @register_dispatch_func(list_of_ops) + def foo(func, *args, **kwargs): + + """ + def wrapper(func): + for aten_op in aten_ops: + _MASKEDTENSOR_DISPATCH_TABLE[aten_op] = partial(func, aten_op) + return wrapper + + +@register_dispatch_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS) +def _general_reduction(func, *args, **kwargs): + return _apply_reduction(func, *args, **kwargs) + + +@register_dispatch_func(PASSTHROUGH_FNS) +def _general_passthrough(func, *args, **kwargs): + return _apply_pass_through_fn(func, *args, **kwargs) + + +@register_dispatch_func(NATIVE_UNARY_FNS + NATIVE_INPLACE_UNARY_FNS) +def _general_unary(func, *args, **kwargs): + return _apply_native_unary(func, *args, **kwargs) + + +@register_dispatch_func(NATIVE_BINARY_FNS + NATIVE_INPLACE_BINARY_FNS) +def _general_binary(func, *args, **kwargs): + return _apply_native_binary(func, *args, **kwargs) + + +@register_dispatch_func([torch.ops.aten.stride]) +def stride(func, *args, **kwargs): + return None + + +@register_dispatch_func([torch.ops.aten.sym_stride]) +def sym_stride(func, *args, **kwargs): + return None + + +@register_dispatch_func([torch.ops.prim.layout]) +def layout(func, *args, **kwargs): + return _get_data(args[0]).layout + + +@register_dispatch_func([torch.ops.aten.is_contiguous]) +def is_contiguous(func, *args, **kwargs): + data = _get_data(args[0]) + if data.is_sparse: + raise ValueError( + "MaskedTensors with sparse data do not have is_contiguous" + ) + return func(data, *args[1:], **kwargs) + + +@register_dispatch_func([torch.ops.aten.is_strides_like_format]) +def is_strides_like_format(func, *args, **kwargs): + 
data = _get_data(args[0]) + if data.is_sparse: + raise ValueError( + "MaskedTensors with sparse data do not have is_strides_like_format" + ) + return func(data, *args[1:], **kwargs) + + +@register_dispatch_func([torch.ops.aten.is_non_overlapping_and_dense]) +def is_non_overlapping_and_dense(func, *args, **kwargs): + data = _get_data(args[0]) + if data.is_sparse: + raise ValueError( + "MaskedTensors with sparse data do not have is_non_overlapping_and_dense" + ) + return func(data, *args[1:], **kwargs) + + +@register_dispatch_func([torch.ops.aten.contiguous]) +def contiguous(func, *args, **kwargs): + if _get_data(args[0]).is_sparse: + raise ValueError( + "MaskedTensors with sparse data do not have contiguous" + ) + return _MaskedContiguous.apply(args[0]) + + +@register_dispatch_func([torch.ops.aten.new_empty_strided]) +def new_empty_strided(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3) + data = _get_data(args[0]) + mask = _maybe_get_mask(args[0]) + if tuple(args[1]) != tuple(data.size()): + raise ValueError(f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()") + if tuple(args[2]) != tuple(data.stride()): + raise ValueError(f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()") + return MaskedTensor(func(data, args[1], args[2], **kwargs), mask) + + +@register_dispatch_func([torch.ops.aten._local_scalar_dense]) +def _local_scalar_dense(func, *args, **kwargs): + if not _maybe_get_mask(args[0]): + raise ValueError(f"__torch_dispatch__, {func}: expected a mask tensor") + return torch.ops.aten._local_scalar_dense(_get_data(args[0])) + + +@register_dispatch_func([torch.ops.aten.detach, torch.ops.aten.clone]) +def _apply_fn_on_data(func, *args, **kwargs): + return MaskedTensor(func(_get_data(args[0])), _maybe_get_mask(args[0])) + + +@register_dispatch_func([torch.ops.aten._to_copy]) +def _to_copy(func, *args, **kwargs): + new_data = func(_get_data(args[0]), 
*args[1:], **kwargs) + return MaskedTensor(new_data, _maybe_get_mask(args[0])) + + +@register_dispatch_func([torch.ops.aten._softmax]) +def _softmax(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0) + data = _get_data(args[0]) + mask = _maybe_get_mask(args[0]) + result_data = torch.ops.aten._masked_softmax(data, ~mask, args[1], 2) + return MaskedTensor(result_data, mask) + + +@register_dispatch_func([torch.ops.aten.ones_like]) +def ones_like(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1) + result_data = func(_get_data(args[0]), **kwargs) + return MaskedTensor(result_data, _maybe_get_mask(args[0])) + + +@register_dispatch_func([torch.ops.aten._softmax_backward_data]) +def _softmax_backward_data(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4) + grad, output, dim, input_dtype = args + if is_masked_tensor(grad) and is_masked_tensor(output): + if not _masks_match(grad, output): + raise ValueError("__torch_dispatch__, {func}: expected the masks of grad and output to match") + grad_data = _get_data(grad) + new_grad_data = torch.ops.aten._masked_softmax_backward( + grad_data, + _get_data(output), + ~_maybe_get_mask(grad), + dim % grad_data.ndim, + ) + res = MaskedTensor(new_grad_data, _maybe_get_mask(grad)) + return res + else: + raise ValueError(f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors") + + +@register_dispatch_func([torch.ops.aten.copy_]) +def copy_(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2) + if not _masks_match(_maybe_get_mask(args[0]), _maybe_get_mask(args[1])): + raise ValueError("args[0] mask and args[1] mask must match but do not") + func(_get_data(args[0]), _get_data(args[1])) + return args[0] + + +@register_dispatch_func([torch.ops.aten.where]) +def where(func, *args, 
**kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0) + if not torch.is_tensor(args[0]): + raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor") + mx = args[1] + my = args[2] + if not is_masked_tensor(mx): + mx = MaskedTensor(mx, torch.ones_like(mx, dtype=torch.bool)) + if not is_masked_tensor(my): + my = MaskedTensor(my, torch.ones_like(my, dtype=torch.bool)) + new_data = func(args[0], mx.get_data(), my.get_data()) + new_mask = func(args[0], mx.get_mask(), my.get_mask()) + return MaskedTensor(new_data, new_mask) + + +@register_dispatch_func([torch.ops.aten._to_sparse]) +def _to_sparse(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0) + if not torch.is_tensor(args[0]): + raise TypeError("__torch_dispatch__, {func}: expected args[0] to be a tensor") + mt = args[0] + if not is_masked_tensor(mt): + mt = MaskedTensor(mt, torch.ones_like(mt, dtype=torch.bool)) + if mt.is_sparse_coo(): + return mt + new_mask = func(_maybe_get_mask(args[0])).coalesce() + new_data = _get_data(args[0]).sparse_mask(new_mask) + return MaskedTensor(new_data, new_mask) + + +@register_dispatch_func([torch.ops.aten._to_sparse_csr]) +def _to_sparse_csr(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0) + if not torch.is_tensor(args[0]): + raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor") + mt = args[0] + if not is_masked_tensor(mt): + mt = MaskedTensor(mt, torch.ones_like(mt).bool()) + if mt.is_sparse_csr(): + return mt + new_mask = func(_maybe_get_mask(args[0])) + new_data = _get_data(args[0]).sparse_mask(new_mask) + return MaskedTensor(new_data, new_mask) + + +@register_dispatch_func([torch.ops.aten._to_dense]) +def _to_dense(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, 
len_kwargs=0) + if not torch.is_tensor(args[0]): + raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor") + mt = args[0] + if not is_masked_tensor(mt): + mt = MaskedTensor(mt, torch.ones_like(mt).bool()) + new_data = func(_get_data(args[0])) + new_mask = func(_maybe_get_mask(args[0])) + return MaskedTensor(new_data, new_mask) + + +@register_dispatch_func([torch.ops.aten._indices]) +def _indices(func, *args, **kwargs): + # Assumes data is sparse + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0) + data = _get_data(args[0]).indices() + return MaskedTensor(data, torch.ones_like(data).bool()) + + +@register_dispatch_func([torch.ops.aten._values]) +def _values(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0) + data = _get_data(args[0]).values() + return MaskedTensor(data, torch.ones_like(data).bool()) + + +@register_dispatch_func([torch.ops.aten._sparse_coo_tensor_with_dims_and_tensors]) +def _sparse_coo_tensor_with_dims_and_tensors(func, *args, **kwargs): + new_args = list(args) + if is_masked_tensor(args[-1]): + new_args[-1] = args[-1].get_data() + if is_masked_tensor(args[-2]): + new_args[-2] = args[-2].get_data() + + new_data = func(*new_args, **kwargs) + new_args[-1] = torch.ones_like(new_args[-1]) + new_mask = func(*new_args, **kwargs).bool() + + return MaskedTensor(new_data, new_mask) + + +@register_dispatch_func([torch.ops.aten.is_same_size]) +def is_same_size(func, *args, **kwargs): + _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2) + return _get_data(args[0]).is_same_size(_get_data(args[1])) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/mps/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/mps/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52cda4fb0c06c0b56c41cd031dea53dcaafc87ed 
--- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/mps/__init__.py @@ -0,0 +1,130 @@ +r""" +This package enables an interface for accessing MPS (Metal Performance Shaders) backend in Python. +Metal is Apple's API for programming metal GPU (graphics processor unit). Using MPS means that increased +performance can be achieved, by running work on the metal GPU(s). +See https://developer.apple.com/documentation/metalperformanceshaders for more details. +""" +import torch +from .. import Tensor + +_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False) +_default_mps_generator: torch._C.Generator = None # type: ignore[assignment] + + +# local helper function (not public or exported) +def _get_default_mps_generator() -> torch._C.Generator: + global _default_mps_generator + if _default_mps_generator is None: + _default_mps_generator = torch._C._mps_get_default_generator() + return _default_mps_generator + + +def synchronize() -> None: + r"""Waits for all kernels in all streams on a MPS device to complete.""" + return torch._C._mps_deviceSynchronize() + + +def get_rng_state() -> Tensor: + r"""Returns the random number generator state as a ByteTensor.""" + return _get_default_mps_generator().get_state() + + +def set_rng_state(new_state: Tensor) -> None: + r"""Sets the random number generator state. + + Args: + new_state (torch.ByteTensor): The desired state + """ + new_state_copy = new_state.clone(memory_format=torch.contiguous_format) + _get_default_mps_generator().set_state(new_state_copy) + + +def manual_seed(seed: int) -> None: + r"""Sets the seed for generating random numbers. + + Args: + seed (int): The desired seed. + """ + # the torch.mps.manual_seed() can be called from the global + # torch.manual_seed() in torch/random.py. 
So we need to make + # sure mps is available (otherwise we just return without + # erroring out) + if not torch._C._has_mps: + return + seed = int(seed) + _get_default_mps_generator().manual_seed(seed) + + +def seed() -> None: + r"""Sets the seed for generating random numbers to a random number.""" + _get_default_mps_generator().seed() + + +def empty_cache() -> None: + r"""Releases all unoccupied cached memory currently held by the caching + allocator so that those can be used in other GPU applications. + """ + torch._C._mps_emptyCache() + + +def set_per_process_memory_fraction(fraction) -> None: + r"""Set memory fraction for limiting process's memory allocation on MPS device. + The allowed value equals the fraction multiplied by recommended maximum device memory + (obtained from Metal API device.recommendedMaxWorkingSetSize). + If trying to allocate more than the allowed value in a process, it will raise an out of + memory error in allocator. + + Args: + fraction(float): Range: 0~2. Allowed memory equals total_memory * fraction. + + .. note:: + Passing 0 to fraction means unlimited allocations + (may cause system failure if out of memory). + Passing fraction greater than 1.0 allows limits beyond the value + returned from device.recommendedMaxWorkingSetSize. + """ + + if not isinstance(fraction, float): + raise TypeError("Invalid type for fraction argument, must be `float`") + if fraction < 0 or fraction > 2: + raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~2") + + torch._C._mps_setMemoryFraction(fraction) + + +def current_allocated_memory() -> int: + r"""Returns the current GPU memory occupied by tensors in bytes. + + .. note:: + The returned size does not include cached allocations in + memory pools of MPSAllocator. + """ + return torch._C._mps_currentAllocatedMemory() + + +def driver_allocated_memory() -> int: + r"""Returns total GPU memory allocated by Metal driver for the process in bytes. + + .. 
note:: + The returned size includes cached allocations in MPSAllocator pools + as well as allocations from MPS/MPSGraph frameworks. + """ + return torch._C._mps_driverAllocatedMemory() + + +from . import profiler +from .event import Event + +__all__ = [ + "get_rng_state", + "manual_seed", + "seed", + "set_rng_state", + "synchronize", + "empty_cache", + "set_per_process_memory_fraction", + "current_allocated_memory", + "driver_allocated_memory", + "Event", + "profiler", +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/mps/event.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/mps/event.py new file mode 100644 index 0000000000000000000000000000000000000000..a206b640ef4ad41c546564d3fa91ba257762c9c6 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/mps/event.py @@ -0,0 +1,45 @@ +import torch + + +class Event: + r"""Wrapper around an MPS event. + + MPS events are synchronization markers that can be used to monitor the + device's progress, to accurately measure timing, and to synchronize MPS streams. 
+ + Args: + enable_timing (bool, optional): indicates if the event should measure time + (default: ``False``) + """ + + def __init__(self, enable_timing=False): + self.__eventId = torch._C._mps_acquireEvent(enable_timing) + + def __del__(self): + # checks if torch._C is already destroyed + if hasattr(torch._C, "_mps_releaseEvent") and self.__eventId > 0: + torch._C._mps_releaseEvent(self.__eventId) + + def record(self): + r"""Records the event in the default stream.""" + torch._C._mps_recordEvent(self.__eventId) + + def wait(self): + r"""Makes all future work submitted to the default stream wait for this event.""" + torch._C._mps_waitForEvent(self.__eventId) + + def query(self): + r"""Returns True if all work currently captured by event has completed.""" + return torch._C._mps_queryEvent(self.__eventId) + + def synchronize(self): + r"""Waits until the completion of all work currently captured in this event. + This prevents the CPU thread from proceeding until the event completes. + """ + torch._C._mps_synchronizeEvent(self.__eventId) + + def elapsed_time(self, end_event): + r"""Returns the time elapsed in milliseconds after the event was + recorded and before the end_event was recorded. + """ + return torch._C._mps_elapsedTimeOfEvents(self.__eventId, end_event.__eventId)