diff --git a/.gitattributes b/.gitattributes index b7cfc37993554251b8c74a209a8c242aacfa2b61..1636cdfcaf1b223a53daf8e2a83182392142f306 100644 --- a/.gitattributes +++ b/.gitattributes @@ -123,3 +123,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libopenblas-r0-f650aae0.3.3.so filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 b/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 new file mode 100644 index 0000000000000000000000000000000000000000..7e10ed9de449c49d91e04a91ab513efe25586e80 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:94fab98c15040558c3c80f2c1a2f5fda9baa72afc39a88bdcc82185f49d241c3 +size 86326864 diff --git a/.venv/lib/python3.11/site-packages/torch/_export/converter.py b/.venv/lib/python3.11/site-packages/torch/_export/converter.py new file mode 100644 index 0000000000000000000000000000000000000000..b45d7849b29ae04ff1e77a812b0ccf86a90a4b0d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/converter.py @@ -0,0 +1,1584 @@ +# mypy: allow-untyped-defs +import builtins +import logging +import operator +import typing +import warnings +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union + +import torch +import torch.export._trace 
+from torch import _C +from torch._export.passes.replace_quantized_ops_with_standard_ops_pass import ( + replace_quantized_ops_with_standard_ops, +) +from torch.export.exported_program import ExportedProgram +from torch.export.graph_signature import ( + ConstantArgument, + CustomObjArgument, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + TensorArgument, +) +from torch.fx import subgraph_rewriter + + +log = logging.getLogger(__name__) + + +def _get_param_count_list(method_graph, args_params): + param_count_list = [] + for input_, arg_params_ in zip(method_graph.inputs(), args_params): + if "PackedParams" in str(input_.type()): + in_vars, _ = torch.jit._flatten(arg_params_) + param_count_list.append(len(in_vars)) + else: + param_count_list.append(arg_params_ is not None) + + return param_count_list + + +def _trace_and_get_graph_from_model(model, args): + # A basic sanity check: make sure the state_dict keys are the same + # before and after running the model. Fail fast! + orig_state_dict_keys = torch.jit._unique_state_dict(model).keys() + + # Disable Autocast cache because it replaces kernel's weight and bias + # by (undesired) constants. + # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665 + prev_autocast_cache_enabled = torch.is_autocast_cache_enabled() + torch.set_autocast_cache_enabled(False) + trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( + model, + args, + strict=False, + _force_outplace=False, + _return_inputs_states=True, + ) + torch.set_autocast_cache_enabled(prev_autocast_cache_enabled) + + if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys(): + raise RuntimeError( + "state_dict changed after running the tracer; " + "something weird is happening in your model!" 
+ ) + + return trace_graph, torch_out + + +def _create_jit_graph( + model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any] +) -> Tuple[torch.Graph, List["_C.IValue"], Any, Optional[torch.ScriptModule]]: + if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)): + flattened_args = tuple(torch.jit._flatten(tuple(args))[0]) + torch_out = None + + if isinstance(model, torch.jit.ScriptModule): + try: + graph = model.forward.graph # type: ignore[attr-defined] + except AttributeError as e: + raise RuntimeError("'forward' method must be a script method") from e + _C._jit_pass_onnx_function_substitution(graph) + freezed_module = _C._freeze_module( + typing.cast(_C.ScriptModule, model._c), preserveParameters=True + ) + module, params = _C._jit_onnx_list_model_parameters(freezed_module) + method_graph = module._get_method("forward").graph + args_params = tuple(args) + tuple(params) + param_count_list = _get_param_count_list(method_graph, args_params) + in_vars, _ = torch.jit._flatten(args_params) + graph = _C._propagate_and_assign_input_shapes( + method_graph, tuple(in_vars), param_count_list, False, False + ) + return graph, params, torch_out, module + + # torch.jit.ScriptFunction + params = [] + graph = model.graph + _C._jit_pass_onnx_function_substitution(graph) + param_count_list = _get_param_count_list(graph, args) + graph = _C._propagate_and_assign_input_shapes( + graph, flattened_args, param_count_list, False, False + ) + return graph, params, torch_out, None + + graph, torch_out = _trace_and_get_graph_from_model(model, args) + _C._jit_pass_onnx_lint(graph) + state_dict = torch.jit._unique_state_dict(model) + params = list(state_dict.values()) + graph_inputs = list(graph.inputs()) + user_input_num = len(graph_inputs) - len(state_dict) + param_names = list(state_dict.keys()) + for i, inp in enumerate(graph_inputs): + if i >= user_input_num: + inp.setDebugName(param_names[i - user_input_num]) + 
_C._jit_pass_onnx_function_substitution(graph) + return graph, params, torch_out, None + + +def list_add(a, b): + return a + b + + +def list_append(container, element): + return container + [element] + + +def execute_subgraph_from_prim_loop( + subgraph, iter_idx, len_loop_local_arguments, *args, **kwargs +): + """ + subgraph: GraphModule from sub-block. + iter_idx: The index of interation. + len_loop_local_arguments: The number of loop local arguments in args. + """ + + # Loop local variables. TS graph create those as inputs because their values + # are updated inside the loop. + loop_local_args = args[:len_loop_local_arguments] + # Global variables that are not passed in as inputs to the loop sub-blocks + # but are directly used. Most of time, their values are not updated, but + # the only exception is when there are some operations that perform inplace + # updates. + global_args = args[len_loop_local_arguments:] + return subgraph(*global_args, iter_idx, *loop_local_args, **kwargs) + + +def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule): + def pattern(im, dim, scale): + sym_size_int = torch.ops.aten.sym_size.int(im, dim) + scalar_tensor = torch.ops.aten.scalar_tensor(sym_size_int) + div_scalar_mode = torch.ops.aten.div.Scalar_mode( + scalar_tensor, scale, rounding_mode="trunc" + ) + int_tensor = torch.ops.aten.Int.Tensor(div_scalar_mode) + return int_tensor + + def replacement(im, dim, scale): + sym_size_int = torch.ops.aten.sym_size.int(im, dim) + return sym_size_int // scale + + replaced_patterns = subgraph_rewriter.replace_pattern(gm, pattern, replacement) + + +def is_valid_for_codegen(name): + if len(name) == 0: + raise RuntimeError("Empty argument name for codegen") + if name[0].isdigit(): + return False + return True + + +def normalize_name(name: str, prefix: str = "rename") -> str: + name = name.replace(".", "_") + if is_valid_for_codegen(name): + return name + return f"{prefix}_{name}" + + +def ir_name_to_func_name(name: str) -> str: + """prim::If 
-> convert_prim_If""" + name_list = name.split("::") + return "convert_" + "_".join(name_list) + + +def get_node_as_placeholder_or_get_attr(fx_graph, name, is_top_level_graph): + if is_top_level_graph: + return fx_graph.get_attr(name) + return fx_graph.placeholder(name) + + +_TORCH_DTYPE_TO_ENUM = { + torch.uint8: 0, + torch.int8: 1, + torch.int16: 2, + torch.int32: 3, + torch.int64: 4, + torch.float16: 5, + torch.float32: 6, + torch.float64: 7, + torch.complex32: 8, + torch.complex64: 9, + torch.complex128: 10, + torch.bool: 11, + torch.qint8: 12, + torch.quint8: 13, + torch.bfloat16: 15, +} + +_TORCH_ENUM_TO_DTYPE = {value: key for key, value in _TORCH_DTYPE_TO_ENUM.items()} + + +def get_dtype_as_int(tensor): + """ + prim::dtype has the signature "Tensor a) -> int", where it gets the dtype of + the tensor and returns the integer corresponding to this dtype based on the + enum in ScalarType.h + """ + dtype = tensor.dtype + if dtype not in _TORCH_DTYPE_TO_ENUM: + raise RuntimeError(f"Unsupported dtype {dtype}") + return _TORCH_DTYPE_TO_ENUM[dtype] + + +# Those operators will be automatically populated to a instance method +# of TS2FXGraphConverter with name convert__(). +# Please check __init__ for method population implementations. +kind_to_standard_operators = { + "prim::max": builtins.max, + "prim::min": builtins.min, + "prim::TupleIndex": operator.getitem, + "aten::__is__": operator.is_, + "aten::__isnot__": operator.is_not, + "aten::__not__": operator.not_, + "aten::__contains__": operator.contains, + "prim::dtype": get_dtype_as_int, + "aten::len": len, + # Mapping from specialized op to its symbolic counterpart. + # They currently do not have any other overrides. 
+ "aten::numel": torch.ops.aten.sym_numel, + "aten::size": torch.ops.aten.sym_size, + "aten::storage_offset": torch.ops.aten.sym_storage_offset, + "aten::stride": torch.ops.aten.sym_stride, +} + + +def get_ir_value_parent_name_and_attr_name(node): + irv_parent_name, irv_name = node.input().debugName(), node.output().debugName() + attr_name = node.s("name") + return irv_name, irv_parent_name, attr_name + + +def construct_fqn(ir, ref_map, name_map): + name_list = [] + while ir in ref_map: + name_list.append(name_map[ir]) + ir = ref_map[ir] + return ".".join(reversed(name_list)) + + +def get_block_to_lifted_attrs(graph: torch._C.Graph) -> Dict[torch._C.Block, Set[str]]: + """ + Perform two passes to get a mapping of blocks to a set of FQNs of its lifted attributes. + When a graph has control flow, the graph will be divided into multiple blocks. We want to convert + each block to a graph which will be passed into torch.cond. A restriction for torch.cond is that model + parameters/buffers are expected to be lifted as inputs to the subgraphs. Before converting the model, + we will run this pass which will: + 1. Figure out which params/buffers are used within blocks through tracing the GetAttr calls. + 2. Process the graph bottom up to find the lifted attributes of each block by taking the union + of the attributes used in the current block, and the lifted attributes of all its child blocks. + + Returns: + A mapping of blocks to a set of FQNs of its lifted attributes. + """ + + # A map from a block to its expected to be lifted arguments. + blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]] = {} + + # Reference map stores the input (i.e., src) and output (i.e., dest) IR of a + # GetAttr node. By traversing this reference map, we can figure out the + # full IR aliasing pass and figure out the FQN of an attribute. 
+ # E.g., %2 = GetAttr(linear)[%1] --> node_to_parent_map["%2"] = "%1" + node_to_parent_map: Dict[str, str] = {} + + # Used for reconstructing the FQN of an attribute based on the reference map. + # In nutshell, for each GetAttr call, GetAttr(input IR, attribute name) -> output IR + # This name map stores which attribute name is called for a src IR --> dest IR action. + # E.g., %2 = GetAttr(linear)[%1] --> node_to_attr_name["%2"] = "linear" + node_to_attr_name: Dict[str, str] = {} + + def _dfs_get_attr_dependency(entry): + """ + First DFS path to construct reference map and name map. + """ + for node in entry.nodes(): + if node.kind() == "prim::GetAttr": + ( + irv_name, + irv_parent_name, + attr_name, + ) = get_ir_value_parent_name_and_attr_name(node) + node_to_parent_map[irv_name] = irv_parent_name + node_to_attr_name[irv_name] = attr_name + for block in node.blocks(): + _dfs_get_attr_dependency(block) + + def _map_blocks_to_lifted_attrs(entry): + """ + Walk the graph in a bottom-up fashion to build the expected to be + lifted arguments for each block. + """ + arguments: Set[str] = set() + for node in entry.nodes(): + for block in node.blocks(): + # Recursively build. + arguments = arguments.union(_map_blocks_to_lifted_attrs(block)) + if node.kind() == "prim::GetAttr": + irv_name = node.output().debugName() + # Skip for intermediate GetAttr, which will anyway not result a FQN. + # E.g., node_to_parent_name: {"%3": "%2", "%2": "%1"} + # node_to_attr_name: {"%3": "weight", "%2": "linear", "%1": "self"} + # There is only one FQN %3-->%2-->%1: self.linear.weight + # %2-->%1 is not a FQN: self.linear + if irv_name not in set(node_to_parent_map.values()): + arguments.add( + construct_fqn(irv_name, node_to_parent_map, node_to_attr_name) + ) + if not isinstance(entry, torch._C.Graph): # Skip the top level. 
+ blocks_to_lifted_attrs[entry] = arguments + return arguments + + _dfs_get_attr_dependency(graph) + _map_blocks_to_lifted_attrs(graph) + + return blocks_to_lifted_attrs + + +def get_attribute_fqn_from_ts_node( + name_to_attribute_fqn: Dict[str, str], node: torch._C.Node +) -> str: + def get_attr(name: str): + if name in name_to_attribute_fqn: + return name_to_attribute_fqn[name] + else: + raise ValueError(f"Attribute {name} not found") + + if node.kind() == "prim::SetAttr": + input_name = next(node.inputs()).debugName() + elif node.kind() == "prim::GetAttr": + input_name = node.input().debugName() + else: + raise RuntimeError( + f"Unexpected node kind when getting attribute fqn. node: {node} " + ) + + attr_name = node.s("name") + root_attr_name = get_attr(input_name) + attr_fqn = f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name + + return attr_fqn + + +def get_op_overload(node: torch._C.Node): + schema_str = node.schema() + assert schema_str != "(no schema)", f"got empty schema for {node}" + schema: torch._C.FunctionSchema = torch._C.parse_schema(schema_str) + ns, op_name = str(schema.name).split("::") + override = schema.overload_name + + try: + op_overload_mod = getattr(torch.ops, ns) + op_overload_packet = getattr(op_overload_mod, op_name) + if override: + op_overload = getattr(op_overload_packet, override) + else: + op_overload = op_overload_packet.default + except Exception as e: + raise RuntimeError( + f"Unable to find operator {node.kind()} with schema {node.schema()}" + ) from e + + return op_overload + + +class TS2FXGraphConverter: + def __init__( + self, + ts_graph: Union[torch._C.Graph, torch._C.Block], + name_to_param: Dict[str, torch.Tensor], + name_to_buffer: Dict[str, torch.Tensor], + blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]], + name_to_non_tensor_attribute: Dict[str, Any], + name_to_constant: Dict[str, Any], + ): + self.ts_graph = ts_graph + self.name_to_param = name_to_param + self.name_to_buffer = name_to_buffer + + 
self.fx_graph: torch.fx.Graph = torch.fx.Graph() + self.input_specs: List[InputSpec] = [] + self.output_specs: List[OutputSpec] = [] + + self.name_to_node: Dict[ + str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]] + ] = {} + self.name_to_constant: Dict[str, Any] = name_to_constant + + # Mapping from torchscript node output name to attribute fully qualified name + self.name_to_attribute_fqn: Dict[str, str] = {} + + # Mapping from fully qualified name to real values or a fx graph node + # During convert, this represents the current value of a non-tensor attribute + # One use case is: + # def forward(self, x): + # c1 = self.count + # self.count += 1 + # c2 = self.count + # return x + c1 + c2 + self.name_to_non_tensor_attribute_node: Dict[str, Any] = {} + + # Mapping from fully qualified name to initial real values inputs + # We separate it from self.name_to_non_tensor_attribute_node since + # we need initial real value input when we construct fx.GraphModule + self.name_to_non_tensor_attribute: Dict[str, Any] = name_to_non_tensor_attribute + + self.subgraphs: Dict[str, torch.fx.GraphModule] = {} + + self.blocks_to_lifted_attrs = blocks_to_lifted_attrs + + # Populate methods for the standard operators. + for k in kind_to_standard_operators.keys(): + handler_func_name = ir_name_to_func_name(k) + # Create an indirect function call: + # convert__ --> lambda node: _convert_standard_operator(node) + setattr( + self, + handler_func_name, + lambda node: self._convert_standard_operators(node), + ) + + # This stores a list of return results that do not appear in the original TS + # graph's outputs. The reason we maintain this is because some operations in the sub-block + # might have inplace updates to the variable defined in the parent fx graph. After + # the execution of that sub-block, the variable defined in the parent fx graph also + # needs to be updated. 
+ self.name_update_from_subblock_to_parent: Set[str] = set() + + def _is_get_attr_node(self, fqn): + return ( + fqn in self.name_to_buffer + or fqn in self.name_to_param + or ( + fqn in self.name_to_constant + and isinstance(self.name_to_constant[fqn], torch.ScriptObject) + ) + ) + + def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: List[str]): + subgraph_nodes, subgraph_converters = [], [] + for block in node.blocks(): + subgraph_converter = TS2FXGraphConverter( + block, + self.name_to_param, + self.name_to_buffer, + self.blocks_to_lifted_attrs, + {}, + self.name_to_constant, + ) + subgraph_converter.name_to_attribute_fqn = self.name_to_attribute_fqn + + for block_arg in arguments: + normalized_block_arg_name = normalize_name(block_arg) + placeholder_node = subgraph_converter.fx_graph.placeholder( + normalized_block_arg_name + ) + subgraph_converter.name_to_node[block_arg] = placeholder_node + + subgraph = subgraph_converter.convert() + subgraph_name = self.add_subgraph(subgraph) + subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name)) + subgraph_converters.append(subgraph_converter) + return subgraph_nodes, subgraph_converters + + def _identify_inputs_as_arguments(self, entry): + """ + Identify inputs from the innermost sub-block. This is needed + for nested sub-blocks when the input is hidden in the nested sub-block. + E.g., example IR of input is hidden in the nested sub-block. + Graph[x.1] + %1 = ... + Block[] + Block[x.1] + %2 = x.1 ... 
+ """ + arguments: Set[str] = set() + for block in entry.blocks(): + for block_node in block.nodes(): + for block_node_in in block_node.inputs(): + if ( + block_node_in.debugName() in self.name_to_node + and block_node_in.debugName() not in self.name_to_attribute_fqn + ): + arguments.add(block_node_in.debugName()) + arguments = arguments.union( + self._identify_inputs_as_arguments(block_node) + ) + return arguments + + def is_top_level_graph(self): + return isinstance(self.ts_graph, torch._C.Graph) + + def add_subgraph(self, subgraph) -> str: + name = f"subgraph_{len(self.subgraphs)}" + self.subgraphs[name] = subgraph + return name + + def get_args_kwargs(self, node: torch._C.Node, schema): + args = [] + kwargs = {} + for input, schema_arg in zip(node.inputs(), schema.arguments): + if schema_arg.kwarg_only: + kwargs[schema_arg.name] = self.get_fx_value_by_ir_value(input) + else: + args.append(self.get_fx_value_by_ir_value(input)) + + return tuple(args), kwargs + + def get_fx_value_by_ir_value(self, value: torch._C.Value): + value_name = value.debugName() + + if value_name in self.name_to_node: + input_node = self.name_to_node[value_name] + return input_node + elif value_name in self.name_to_constant: + if isinstance(self.name_to_constant[value_name], torch.ScriptObject): + return self.fx_graph.get_attr(value_name) + return self.name_to_constant[value_name] + else: + raise ValueError(f"Input {value_name} not found") + + def get_fx_value_by_fqn(self, name): + if name in self.name_to_node: + fx_node = self.name_to_node[name] + elif name in self.name_to_constant: + fx_node = self.name_to_constant[name] + elif name in self.name_to_non_tensor_attribute_node: + fx_node = self.name_to_non_tensor_attribute_node[name] + elif name in self.name_to_non_tensor_attribute: + fx_node = self.name_to_non_tensor_attribute[name] + else: + raise ValueError(f"Attribute {name} not found") + return fx_node + + def convert(self) -> torch.fx.GraphModule: + self.convert_graph_inputs() + + for 
node in self.ts_graph.nodes(): + self.convert_node(node) + + self.convert_graph_outputs() + + # Pass parameter and buffer to the root for lookup. + gm = torch.fx.GraphModule( + { + **self.subgraphs, + **self.name_to_param, + **self.name_to_buffer, + **self.name_to_non_tensor_attribute, + **self.name_to_constant, + }, + self.fx_graph, + ) + + inplace_optimize_sym_size_div(gm) + + gm.graph.lint() + + return gm + + def convert_graph_inputs(self): + for graph_input in self.ts_graph.inputs(): + name = graph_input.debugName() + + if name in self.name_to_param: + normalized_name = normalize_name(name) + self.input_specs.append( + InputSpec( + InputKind.PARAMETER, + arg=TensorArgument(name=normalized_name), + target=name, + ) + ) + fx_node = get_node_as_placeholder_or_get_attr( + self.fx_graph, name, self.is_top_level_graph() + ) + elif name in self.name_to_buffer: + normalized_name = normalize_name(name) + self.input_specs.append( + InputSpec( + InputKind.BUFFER, + arg=TensorArgument(name=normalized_name), + target=name, + persistent=True, + ) + ) + fx_node = get_node_as_placeholder_or_get_attr( + self.fx_graph, name, self.is_top_level_graph() + ) + elif name in self.name_to_constant: + assert isinstance( + self.name_to_constant[name], torch.ScriptObject + ), "Input conversion only handles ScriptObject" + normalized_name = normalize_name(name) + self.input_specs.append( + InputSpec( + InputKind.CUSTOM_OBJ, + arg=CustomObjArgument( + name=normalized_name, class_fqn=normalized_name + ), + target=name, + persistent=False, + ) + ) + fx_node = get_node_as_placeholder_or_get_attr( + self.fx_graph, name, self.is_top_level_graph() + ) + elif isinstance(graph_input.type(), torch.ClassType): + # Directly skip inputs that are ScriptObject but not used in the graph. 
+ continue + else: + normalized_name = normalize_name(name, prefix="input") + self.input_specs.append( + InputSpec( + InputKind.USER_INPUT, + arg=TensorArgument(name=normalized_name), + target=name, + ) + ) + fx_node = self.fx_graph.placeholder(normalized_name) + + self.name_to_node[name] = fx_node + + def convert_aten_Float(self, node: torch._C.Node): + def to_float_tensor(t): + return t.to(dtype=torch.float).item() + + inp_list = [ + self.get_fx_value_by_ir_value(inp) for inp in node.inputs() + ] # noqa: C416 + fx_node = self.fx_graph.call_function( + to_float_tensor, + tuple(inp_list), + ) + self.name_to_node[node.output().debugName()] = fx_node + + def convert_aten_tensor(self, node: torch._C.Node): + """aten::tensor creates a constant tensor ad-hoc --> GetAttr""" + args, kwargs = self.get_args_kwargs(node, torch.ops.aten.tensor.default._schema) + + for k in kwargs: + if k == "requires_grad": + kwargs[k] = bool(kwargs[k]) # 0 -> False, 1 -> True + + to_tensor = ( + torch.tensor + if all(isinstance(a, int) for a in args) + else torch._refs.tensor + ) + + def target(*args, **kwargs): + if "dtype" in kwargs and kwargs["dtype"] is not None: + kwargs["dtype"] = _TORCH_ENUM_TO_DTYPE[kwargs["dtype"]] + return to_tensor(*args, **kwargs) + + # def to_dynamic_tensor(*args, **kwargs): + # if "dtype" in kwargs and kwargs["dtype"] is not None: + # kwargs["dtype"] = _TORCH_ENUM_TO_DTYPE[kwargs["dtype"]] + # return torch._refs.tensor(*args, **kwargs) + + output_name = node.output().debugName() + fx_node = self.fx_graph.call_function(target, args, kwargs) + self.name_to_node[output_name] = fx_node + + def convert_aten_append(self, node: torch._C.Node): + # special handle python list append: "aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)" + + # inplace append to the list!! 
This is kinda crazy, as we are inplace mutating the list + # This makes the converter "non-functional", and the result depends on the order of the nodes being converter + # In a sense, the converter now becomes an stateful interpreter + warnings.warn( + "Converting aten::append.t, which is a inplace mutation of the list. " + "This makes the converter non-functional: the result depends on the order of the append nodes being converter!" + ) + + args = tuple(self.get_fx_value_by_ir_value(inp) for inp in node.inputs()) + fx_node = self.fx_graph.call_function(list_append, args) + self.name_to_node[node.output().debugName()] = fx_node + + # inplace mutate arg[0], which is the python list + self.name_to_node[node.inputsAt(0).debugName()] = fx_node + + # Variables that need to be updated to parent module. + if not self.is_top_level_graph() and args[0].op == "placeholder": + self.name_update_from_subblock_to_parent.add(node.inputsAt(0).debugName()) + + def convert_prim_Constant(self, node: torch._C.Node): + name = node.output().debugName() + + value: Any = None + if node.hasAttribute("value"): + constant_kind = node.kindOf("value") + if constant_kind == "i": + value = node.i("value") + elif constant_kind == "f": + value = node.f("value") + elif constant_kind == "s": + value = node.s("value") + elif constant_kind == "t": + alias_name = ( + f"lifted_tensor_{name}" # Follow naming convention from EP tracing. 
+ ) + fx_node = self.fx_graph.get_attr(alias_name) + self.name_to_node[name] = fx_node + name, value = alias_name, node.t("value") + elif constant_kind == "ival": + value = node.ival("value") + else: + raise ValueError(f"Unsupported constant type: {node.kindOf('value')}") + else: + value = None + + self.name_to_constant[name] = value + + def convert_prim_CallMethod(self, node: torch._C.Node): + inp_list = [ + self.get_fx_value_by_ir_value(inp) for inp in node.inputs() + ] # noqa: C416 + fx_node = self.fx_graph.call_method( + node.s("name"), + tuple(inp_list), + ) + self.name_to_node[node.output().debugName()] = fx_node + + def convert_prim_device(self, node: torch._C.Node): + input_type = node.input().type() + if input_type.isSubtypeOf(torch._C.TensorType.get()): + device = input_type.device() # type: ignore[attr-defined] + output_name = node.output().debugName() + self.name_to_constant[output_name] = device + else: + raise ValueError(f"Unsupported JitType ({input_type}) when get device") + + def convert_prim_GetAttr(self, node: torch._C.Node): + # Build fully qulified name + attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node) + output_name = node.output().debugName() + self.name_to_attribute_fqn[output_name] = attr_fqn + + if self.is_top_level_graph(): + if self._is_get_attr_node(attr_fqn): + # We insert a get_attr node due to two reasons. + # First, ts graph does not lift tensor constants as input nodes. So + # tensor constants may be ignored by in convert_graph_inputs(). + # Second, attr_fqn may have been written to via SetAttr. Two + # GetAttr may give different values. 
+ self.name_to_node[output_name] = self.fx_graph.get_attr(attr_fqn) + else: + if attr_fqn not in self.name_to_non_tensor_attribute_node: + self.name_to_non_tensor_attribute_node[ + attr_fqn + ] = self.name_to_non_tensor_attribute[attr_fqn] + self.name_to_node[output_name] = self.name_to_non_tensor_attribute_node[ + attr_fqn + ] + else: + # Special support for if blocks which do not allow SetAttr TorchScript + # node and get_attr FX Graph Node. + if self._is_get_attr_node(attr_fqn): + self.name_to_node[output_name] = self.name_to_node[attr_fqn] + + def convert_prim_SetAttr(self, node: torch._C.Node): + attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node) + attr_value = tuple(node.inputs())[1] + ts_graph_tensor_input = self.get_fx_value_by_ir_value(attr_value) + if self._is_get_attr_node(attr_fqn): + fx_attr_node = self.fx_graph.get_attr(attr_fqn) + self.fx_graph.call_function( + torch.Tensor.copy_, (fx_attr_node, ts_graph_tensor_input) + ) + else: + self.name_to_non_tensor_attribute_node[attr_fqn] = ts_graph_tensor_input + + def convert_call_function_op(self, node: torch._C.Node): + target = get_op_overload(node) + + args, kwargs = self.get_args_kwargs(node, target._schema) + + fx_node = self.fx_graph.call_function(target, args, kwargs) + + # TODO: covnert sourceRange() into stack_trace + # fx_node.meta["stack_trace"] = node.sourceRange() + + if node.outputsSize() == 1: + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + else: + for i, outp in enumerate(node.outputs()): + output_name = outp.debugName() + next_fx_node = self.fx_graph.call_function( + operator.getitem, (fx_node, i) + ) + self.name_to_node[output_name] = next_fx_node + + def convert_prim_TupleConstruct(self, node: torch._C.Node): + self._convert_prim_iterator(node) + + def convert_prim_ListConstruct(self, node: torch._C.Node): + self._convert_prim_iterator(node) + + def _convert_prim_iterator(self, node: torch._C.Node): + output_list = [] + for 
inp in node.inputs(): + output_list.append(self.get_fx_value_by_ir_value(inp)) + + output_name = node.output().debugName() + self.name_to_node[output_name] = output_list + + def convert_prim_DictConstruct(self, node: torch._C.Node): + output_dict = {} + k, v = None, None + for i, inp in enumerate(node.inputs()): + # We assume key value are stored in pair in the DictConstruct. + # The first element is the key and the following is the value. + if i % 2 == 0: + k = self.get_fx_value_by_ir_value(inp) + else: + v = self.get_fx_value_by_ir_value(inp) + assert ( + k is not None and v is not None + ), "DictConstruct has an empty key value pair." + output_dict[k] = v + k, v = None, None + + assert ( + k is None and v is None + ), "DictConstruct has an odd number of elements (violating our assumption)." + + output_name = node.output().debugName() + self.name_to_node[output_name] = output_dict + + def convert_prim_ListUnpack(self, node: torch._C.Node): + self._convert_prim_unpack_iterator(node) + + def convert_prim_TupleUnpack(self, node: torch._C.Node): + self._convert_prim_unpack_iterator(node) + + def _convert_prim_unpack_iterator(self, node: torch._C.Node): + # Single input and multiple outputs for unpacking. 
+ for i, outp in enumerate(node.outputs()): + outp_name = outp.debugName() + inp = self.get_fx_value_by_ir_value(node.input()) + fx_node = self.fx_graph.call_function(operator.getitem, (inp, i)) + self.name_to_node[outp_name] = fx_node + + def convert_aten_Int(self, node: torch._C.Node): + # converts aten::Int as aten._to_copy + aten::_local_scalar_dense + target = torch.ops.aten._to_copy.default + args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs()) + to_copy_node = self.fx_graph.call_function(target, args, {"dtype": torch.int32}) + + fx_node = self.fx_graph.call_function( + torch.ops.aten._local_scalar_dense.default, (to_copy_node,) + ) + + # TODO: covnert sourceRange() into stack_trace + # fx_node.meta["stack_trace"] = node.sourceRange() + + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_prim_NumToTensor(self, node: torch._C.Node): + # Converts prim::NumToTensor as aten.scalar_tensor. + # prim::NumToTensor IRs are currently triggered by: + # .size() https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L950 + # .numel() https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L971 + # For both of those APIs, torch.jit.trace implicitly sets the output tensor type + # to be LongTensor. 
+ target = torch.ops.aten.scalar_tensor + args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs()) + + fx_node = self.fx_graph.call_function(target, args, {"dtype": torch.long}) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_prim_CreateObject(self, node: torch._C.Node): + output_name = node.output().debugName() + self.name_to_attribute_fqn[output_name] = "" + + def convert_aten__convolution(self, node: torch._C.Node): + # converts aten::_convolution as aten.convolution, since aten::_convolution + # doesn't have a meta function + target = torch.ops.aten.convolution.default + args, kwargs = self.get_args_kwargs(node, target._schema) + + fx_node = self.fx_graph.call_function(target, args, kwargs) + + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_aten_div(self, node: torch._C.Node): + target = get_op_overload(node) + schema = target._schema + + args, kwargs = self.get_args_kwargs(node, schema) + + # converts aten::div.Tensor_mode(x, tensor_constant) + # as aten.div.Scalar_mode(x, tensor_constant.item()) + if schema.overload_name == "Tensor_mode": + arg1_name = args[1].name + if arg1_name in self.name_to_constant and isinstance( + self.name_to_constant[arg1_name], torch.Tensor + ): + tensor_constant = self.name_to_constant[arg1_name] + if tensor_constant.numel() == 1: + updated_args = list(args) + updated_args[1] = self.name_to_constant[arg1_name].item() + + fx_node = self.fx_graph.call_function( + torch.ops.aten.div.Scalar_mode, + tuple(updated_args), + kwargs, + ) + + # TODO: covnert sourceRange() into stack_trace + # fx_node.meta["stack_trace"] = node.sourceRange() + + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + return + + self.convert_call_function_op(node) + + def convert_aten___getitem__(self, node: torch._C.Node): + input_container, index = tuple( + self.get_fx_value_by_ir_value(input) for input 
in node.inputs() + ) + fx_node = self.fx_graph.call_function( + operator.getitem, (input_container, index) + ) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_aten_to(self, node: torch._C.Node): + target = get_op_overload(node) + args, kwargs = self.get_args_kwargs(node, target._schema) + + # special handle aten.to.dtype and aten.to.prim_dtype followed by inplace_mutation_op + # coz aten.to + inplace_mutation_op pattern would trigger + # "cannot mutate tensors with frozen storage" functionalization error. + # To work around the issue, we override the copy to be True, so that the output + # is for sure not an alias of input + if target == torch.ops.aten.to.dtype or target == torch.ops.aten.to.prim_dtype: + user_nodes = [use.user for use in node.output().uses()] + user_targets = [ + get_op_overload(user_node) + for user_node in user_nodes + if user_node.schema() != "(no schema)" + ] + has_mutable_target = any( + target._schema.is_mutable for target in user_targets + ) + + if has_mutable_target: + assert len(args) >= 4 + new_args = list(args) + new_args[3] = True # copy, override to True + fx_node = self.fx_graph.call_function( + torch.ops.aten.to.dtype, tuple(new_args) + ) + # temp hack to work around the issue https://github.com/pytorch/pytorch/issues/131679 + # When this issue is fixed, the clone node would be no longer needed + clone_node = self.fx_graph.call_function( + torch.ops.aten.clone.default, (fx_node,) + ) + output_name = node.output().debugName() + self.name_to_node[output_name] = clone_node + return + + self.convert_call_function_op(node) + + def convert_aten_add(self, node: torch._C.Node): + if node.schema() == "(no schema)": + if isinstance(node.inputsAt(0).type(), torch.ListType) and isinstance( + node.inputsAt(1).type(), torch.ListType + ): + target = torch.ops.aten.add.t + else: + raise RuntimeError(f"unable to determind the target for {node}") + else: + target = get_op_overload(node) + + if 
target == torch.ops.aten.add.t: + # special handle python list/tuple add: "aten::add.t(t[] a, t[] b) -> t[]" for + # RuntimeError: aten::add() Expected a value of type 'List[t]' for argument 'a' but instead found type 'immutable_list'. + args, kwargs = self.get_args_kwargs(node, target._schema) + output_name = node.output().debugName() + self.name_to_node[output_name] = self.fx_graph.call_function(list_add, args) + else: + self.convert_call_function_op(node) + + def _check_prim_loop_support(self, node): + inputs = list(node.inputs()) + + # TODO: (1/N) stage. + if inputs[0].debugName() not in self.name_to_constant: + raise RuntimeError( + "prim::Loop currently cannot run with dynamic value of number of iterations." + ) + + # Make sure the condition is not updated in the subblock. + subblock = next(node.blocks()) + condition_output_name = next(subblock.outputs()).debugName() + for node in subblock.nodes(): + if ( + node.outputsSize() == 1 + and node.output().debugName() == condition_output_name + ): + raise RuntimeError( + "prim::Loop currently cannot run with dynamic value of condition." + ) + if node.outputsSize() >= 2: + for outp in node.outputs(): + if outp.debugName() == condition_output_name: + raise RuntimeError( + "prim::Loop currently cannot run with dynamic value of condition." + ) + + def convert_prim_Loop(self, node: torch._C.Node): + inputs = list(node.inputs()) + self._check_prim_loop_support(node) + + num_iterations = self.get_fx_value_by_ir_value(inputs[0]) + + # Find inputs. + loop_local_arguments = [inp.debugName() for inp in inputs[2:]] + + global_arguments = self._identify_inputs_as_arguments(node) + + # Lift parameters as inputs. 
+ for block in node.blocks(): + global_arguments = global_arguments.union( + self.blocks_to_lifted_attrs[block] + ) + + global_arguments = list(global_arguments) + + subgraph_nodes, subgraph_converters = self._convert_block_to_subgraph( + node, global_arguments + ) + + assert len(subgraph_nodes) == 1 + subgraph_converter = subgraph_converters[0] + if not self.is_top_level_graph(): + self.name_update_from_subblock_to_parent = ( + self.name_update_from_subblock_to_parent.union( + subgraph_converter.name_update_from_subblock_to_parent + ) + ) + + fx_block_args = [ + self.get_fx_value_by_fqn(name) + for name in loop_local_arguments + global_arguments + ] + for iter_idx in range(num_iterations): + loop_node = self.fx_graph.call_function( + execute_subgraph_from_prim_loop, + # Check execute_node function for the expected arguments order. + ( + subgraph_nodes[0], + iter_idx, + len(loop_local_arguments), + *fx_block_args, + ), + {}, + ) + + # Update the value of loop local variables. + if node.outputsSize() >= 1: + for i, outp in enumerate(node.outputs()): + output_name = outp.debugName() + self.name_to_node[output_name] = self.fx_graph.call_function( + operator.getitem, + ( + loop_node, + i + 1, + ), # + 1 because the 0th element is the condition. + ) + fx_block_args[i] = self.name_to_node[output_name] + + # Update the value of global variables, whose values are modified inplace. + for i, name in enumerate( + subgraph_converter.name_update_from_subblock_to_parent + ): + self.name_to_node[name] = self.fx_graph.call_function( + operator.getitem, + ( + loop_node, + i + node.outputsSize() + 1, + ), # + 1 because the 0th element is the condition. 
+ ) + global_argument_index = global_arguments.index(name) + fx_block_args[ + i + node.outputsSize() + global_argument_index + ] = self.name_to_node[name] + + def _check_set_attr_in_if_block(self, if_node: torch._C.Node): + for block in if_node.blocks(): + for node in block.nodes(): + if node.kind() == "prim::SetAttr": + raise RuntimeError( + "During converting prim::If to torch.cond, found prim::SetAttr op" + " which is not supported yet. Please file an issue if you come " + "across this error." + ) + + def convert_prim_If(self, node: torch._C.Node): + self._check_set_attr_in_if_block(node) + + inputs = list(node.inputs()) + assert len(inputs) == 1 + predicate = self.get_fx_value_by_ir_value(inputs[0]) + + # Find inputs. + arguments = self._identify_inputs_as_arguments(node) + + # Lift parameters as inputs. + for block in node.blocks(): + arguments = arguments.union(self.blocks_to_lifted_attrs[block]) + + arguments = list(arguments) + subgraph_nodes, _ = self._convert_block_to_subgraph(node, arguments) + + assert len(subgraph_nodes) == 2 + + fx_block_args = [self.get_fx_value_by_fqn(name) for name in arguments] + + args = ( + predicate, + subgraph_nodes[0], + subgraph_nodes[1], + tuple(fx_block_args), + ) + + cond_node = self.fx_graph.call_function(torch.cond, args, {}) + + # prim::If can also have zero output. + if node.outputsSize() == 1: + output_name = node.output().debugName() + self.name_to_node[output_name] = cond_node + elif node.outputsSize() > 1: + for i, output in enumerate(node.outputs()): + output_name = output.debugName() + getitem = self.fx_graph.call_function(operator.getitem, (cond_node, i)) + self.name_to_node[output_name] = getitem + + def convert_aten_Bool(self, node: torch._C.Node): + self._convert_as_noop(node) + + def convert_prim_Enter(self, node: torch._C.Node): + # export generally treats prim::Enter as noop + # The only context manager export supports is aten::enable_grad. 
+ # Unfortunately, TorchScript does not support aten::enable_grad yet. + # TODO: support aten::enable_grad in both TorchScript and Converter. + return + + def convert_prim_Exit(self, node: torch._C.Node): + # export treats prim::Exit as noop + return + + def _convert_as_noop(self, node: torch._C.Node): + # Converts the node as a no-op by mapping its output node as arg[0] + + target = get_op_overload(node) + schema = target._schema + + args, kwargs = self.get_args_kwargs(node, schema) + + output_name = node.output().debugName() + self.name_to_node[output_name] = args[0] + + def convert_profiler__record_function_exit(self, node: torch._C.Node): + # _record_function_exit has side effect so we keep it in fx.graph + # currently, _record_function_enter_new and _record_function_exit are + # discarded during `retrace_as_exported_program`. + target = torch.ops.profiler._record_function_exit + args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs()) + self.fx_graph.call_function(target, args) + + def convert_prim_tolist(self, node: torch._C.Node): + # prim::tolist cannot be supported by `_convert_standard_operators` + # since it requires call_method instead of call_function. + target = "tolist" + args = (self.get_fx_value_by_ir_value(next(node.inputs())),) + fx_node = self.fx_graph.call_method(target, args) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_prim_Uninitialized(self, node: torch._C.Node): + # `prim::Uninitialized` is inserted by the compiler when it can prove + # the value will never be used. It can be introduced by exceptions, + # breaks, continues, and returns. + # So we add a dummy constant to the graph. 
+ output_name = node.output().debugName() + self.name_to_constant[output_name] = torch.Tensor() + + def _convert_standard_operators(self, node: torch._C.Node): + target = kind_to_standard_operators[node.kind()] + args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs()) + fx_node = self.fx_graph.call_function(target, args) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_node(self, node: torch._C.Node): + node_kind = node.kind() + + # Get handler based on namespace and operator name. + # Provide a default node handler as well in case we don't find + # matching converter for that. + handler_func_name = ir_name_to_func_name(node_kind) + handler_func = getattr(self, handler_func_name, self.convert_call_function_op) + + # str calls print function implemented in CPP. To avoid repeating + # the entire logic here, we simply keep first line from node string (getting rid + # of sub-blocks IR prints). + node_str = "".join(str(node).split("\n")[:1]) + log.debug("[%s] converts [%s]", handler_func.__name__, node_str) + try: + handler_func(node) + except Exception as e: + raise RuntimeError(f"TS2EPConverter failed for node {node_kind}") from e + + def convert_graph_outputs(self): + args = [] + outp_name_list = [outp.debugName() for outp in self.ts_graph.outputs()] + list( + self.name_update_from_subblock_to_parent + ) + for output_name in outp_name_list: + if output_name in self.name_to_node: + fx_node = self.name_to_node[output_name] + # TODO: Revisit this later after HigherOrderOp design changes. + # Currently, we cannot directly return input as output. 
+ if ( + not self.is_top_level_graph() + and isinstance(fx_node, torch.fx.Node) + and fx_node.op == "placeholder" + ): + fx_node = self.fx_graph.call_function(torch.clone, (fx_node,)) + args.append(fx_node) + self.output_specs.append( + OutputSpec( + OutputKind.USER_OUTPUT, + arg=TensorArgument(name=output_name), + target=output_name, + ) + ) + elif output_name in self.name_to_constant: + args.append(self.name_to_constant[output_name]) + self.output_specs.append( + OutputSpec( + OutputKind.USER_OUTPUT, + arg=ConstantArgument( + name=output_name, value=self.name_to_constant[output_name] + ), + target=output_name, + ) + ) + else: + raise ValueError(f"Output {output_name} not found") + + if len(args) == 0: + # Sub-block of prim::If can have zero output. + self.fx_graph.output([]) + elif len(args) == 1: + self.fx_graph.output( + args[0] + ) # Get rid of an extra list wrapped around final output. + elif len(args) > 1: + self.fx_graph.output( + args + ) # For prim::Loop and prim::If with multiple outputs. + else: + # Sub-block of prim::Loop can have multiple outputs. + self.fx_graph.output(args) + + +class ExplainTS2FXGraphConverter(TS2FXGraphConverter): + """ + Run TS2FXGraphConverter in an explain mode. It collects all failed operators conversions + and provide that information to users. In order to collect all failed conversions, it + also mocks some internal attributes (e.g., name_to_node). + """ + + class _DictMock(dict): + def __init__(self, dict_data, mock_value): + super().__init__(dict_data) + self.mock_value = mock_value + + def __getitem__(self, key): + # If the original dictionary has the key, return its value. + # Otherwise, return the mock value. 
+ if not super().__contains__(key): + return self.mock_value + return super().__getitem__(key) + + def __contains__(self, key): + return True + + def __init__( + self, + ts_graph: Union[torch._C.Graph, torch._C.Block], + name_to_param: Dict[str, torch.Tensor], + name_to_buffer: Dict[str, torch.Tensor], + blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]], + name_to_non_tensor_attribute: Dict[str, Any], + name_to_constant: Dict[str, Any], + ): + super().__init__( + ts_graph, + name_to_param, + name_to_buffer, + blocks_to_lifted_attrs, + name_to_non_tensor_attribute, + name_to_constant, + ) + + # Data to keep track of unsupported nodes. + self.unsupported_node_list: List[torch._C.Node] = [] + + # Add mock to needed attributes. + self.name_to_node = ExplainTS2FXGraphConverter._DictMock( + self.name_to_node, + # Dummy node. + torch.fx.Node( + None, # type: ignore[arg-type] + "mock", + "call_function", + lambda: None, + (), + {}, + ), + ) + + def explain(self): + self.convert_graph_inputs() + for node in self.ts_graph.nodes(): + self.convert_node(node) + self.convert_graph_outputs() + + def convert_node(self, node): + try: + super().convert_node(node) + except Exception as e: + self.unsupported_node_list.append(node) + + +@contextmanager +def disable_logging(log): + disabled = log.disabled + log.disabled = True + try: + yield + finally: + log.disabled = disabled + + +class TS2EPConverter: + # TorchScript model to ExportedProgram converter + def __init__( + self, + ts_model: Union[torch.jit.ScriptModule, torch.jit.ScriptFunction], + sample_args: Tuple[Any, ...], + sample_kwargs: Optional[Dict[str, Any]] = None, + ): + self.ts_model = ts_model + self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args) + + self.sample_args = sample_args + self.sample_kwargs = sample_kwargs + + self.name_to_param: Dict[str, torch.Tensor] = {} + self.name_to_buffer: Dict[str, torch.Tensor] = {} + param_list = ( + list(self.ts_model.parameters()) + if not 
isinstance(self.ts_model, torch._C.ScriptFunction) + else [] + ) + if not isinstance(self.ts_model, torch._C.ScriptFunction): + for k, tensor in self.ts_model.state_dict().items(): # type: ignore[union-attr] + # Check if tensor belongs to any parameter. + if any( + (tensor == param).all() + for param in param_list + if tensor.shape == param.shape + ): + self.name_to_param[k] = tensor + else: + self.name_to_buffer[k] = tensor + + self.name_to_non_tensor_attributes: Dict[str, Any] = {} + self.name_to_constant: Dict[str, Any] = {} + + self.lift_get_attr() + + def convert(self) -> ExportedProgram: + log.info( + """ +TS2EPConverter logging starts from here. + +INFO: (TORCH_LOGS="export" ) + * Log TorchScript IR. + +DEBUG: (TORCH_LOGS="+export" ), additionally + * Log conversion IR by IR in a format of [] converts []. + """ + ) + log.info("TorchScript graph\n\n%s\n", self.ts_graph) + + blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph) + + graph_converter = TS2FXGraphConverter( + self.ts_graph, + self.name_to_param, + self.name_to_buffer, + blocks_to_lifted_attrs, + self.name_to_non_tensor_attributes, + self.name_to_constant, + ) + gm = graph_converter.convert() + + # Post-proccessing step to deal with quantized operators. + replace_quantized_ops_with_standard_ops(gm) + log.info("GraphModule: %s", gm.print_readable(print_output=False)) + + ep = self.retrace_as_exported_program( + gm, + graph_converter.name_to_constant, + ) + log.info("%s", ep) + + # Post-processing step to ensure ExportedProgram has the same state_dict as + # the original TorchScript model. Throw warnings for additionally populated + # state_dict entries. + if not isinstance(self.ts_model, torch._C.ScriptFunction): + for k, tensor in self.ts_model.state_dict().items(): # type: ignore[union-attr] + if k not in ep.state_dict: + warnings.warn( + f"Manually populate {k} into state_dict ExportedProgram, but it is never used by the ExportedProgram." 
+ ) + ep.state_dict[k] = tensor + + return ep + + @disable_logging(log) + def explain(self, print_output=True): + blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph) + + graph_converter = ExplainTS2FXGraphConverter( + self.ts_graph, + self.name_to_param, + self.name_to_buffer, + blocks_to_lifted_attrs, + self.name_to_non_tensor_attributes, + self.name_to_constant, + ) + graph_converter.explain() + if len(graph_converter.unsupported_node_list) > 0: + explain_str = "Unsupported nodes are found in the following list:" + for i, n in enumerate(graph_converter.unsupported_node_list): + node_str = "".join(str(n).split("\n")[:1]) + explain_str += f"\n\n {i}. {n.kind()} [{node_str}]" + else: + explain_str = "Success!" + if print_output: + print(explain_str) + return explain_str + + def retrace_as_exported_program( + self, + gm: torch.fx.GraphModule, + name_to_constant: Dict[str, Any], + ): + # TODO: adjust input orders to match GraphSignature convention + ep = torch.export._trace._export( + gm, + self.sample_args, + strict=False, + pre_dispatch=True, + ) + + # Post-processing to make sure the ExportedProgram states are correct. + # Because during conversion, we set tensor constants as GetAttr, + # retracing cannot recognize them as tensor constants but instead + # treat them as buffers. We need to set them again here. + ep._constants.update( + { + k: v + for k, v in name_to_constant.items() + if isinstance(v, (torch.Tensor, torch.ScriptObject)) + } + ) + for k in name_to_constant: + ep.state_dict.pop(k, None) + + for i, spec in enumerate(ep.graph_signature.input_specs): + # Mark as constant tensors for erroneously traced buffers. 
+ if spec.kind == InputKind.BUFFER and spec.target in name_to_constant: + assert isinstance( + name_to_constant[spec.target], torch.Tensor + ), f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer" + spec.kind = InputKind.CONSTANT_TENSOR + ep.verifier().check(ep) + + return ep + + def lift_get_attr(self): + # This function lifts multiple data types. + + # 1. Tensor constants attributes (e.g., self.data = torch.tensor([2,3])) + # to buffers. Currently, when there are tensor constants, export + # would error and ask users to register tensor constants as buffers. + # Since it is hard to manually do so for TorchScript models + # (e.g., source code is missing), this function automatically + # lifts tensor constants to be buffers. + + # 2. ScriptObbject to constant. It will then be converted to getattr in + # in the fx graph. + # + # This function should happen in TS2EPConverter instead of + # TS2FXGraphConverter since it gets attributes from self.ts_model + # which is not accessable in TS2FXGraphConverter. It is similar to where + # we collect self.name_to_param and self.name_to_buffer. 
+ name_to_attribute_fqn: Dict[str, str] = {} + + def get_attr(fqn: str): + name = fqn.split(".") + v = self.ts_model + for n in name: + v = getattr(v, n) + return v + + def get_fqn(node: torch._C.Node): + attr_name = node.s("name") + input_name = node.input().debugName() + root_attr_name = name_to_attribute_fqn[input_name] + attr_fqn = f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name + return attr_fqn + + def _dfs_get_attr(block): + for node in block.nodes(): + if node.kind() == "prim::CreateObject": + output_name = node.output().debugName() + name_to_attribute_fqn[output_name] = "" + + if node.kind() == "prim::GetAttr": + attr_fqn = get_fqn(node) + value = get_attr(attr_fqn) + output_name = node.output().debugName() + name_to_attribute_fqn[output_name] = attr_fqn + if isinstance(value, torch.Tensor): + if attr_fqn not in self.name_to_buffer: + # Lift tensor constants to be a buffer + self.name_to_buffer[attr_fqn] = value + elif isinstance(value, torch.ScriptObject): + if attr_fqn not in self.name_to_constant: + self.name_to_constant[attr_fqn] = value + else: + self.name_to_non_tensor_attributes[attr_fqn] = value + + for subblock in node.blocks(): + _dfs_get_attr(subblock) + + _dfs_get_attr(self.ts_graph) diff --git a/.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py b/.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5c7331659d26a00dc68e0d169a70328cec251c2b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py @@ -0,0 +1,523 @@ +# mypy: allow-untyped-defs +import contextlib +import inspect +import logging +from collections import defaultdict +from typing import Any, Callable, Dict, List, Tuple, TYPE_CHECKING, Union + +import torch +import torch.utils._pytree as pytree +from torch._dynamo.source import ( + AttrSource, + GetItemSource, + LocalSource, + TensorProperty, + TensorPropertySource, +) +from 
torch._dynamo.variables.builder import TrackedFake +from torch._export.passes.add_runtime_assertions_for_constraints_pass import InputDim +from torch._export.passes.lift_constants_pass import ConstantAttrMap +from torch._guards import Source +from torch._library.fake_class_registry import FakeScriptObject +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.export import Constraint +from torch.export.dynamic_shapes import ( + _check_dynamic_shapes, + _combine_args, + _DimHint, + _process_dynamic_shapes, + _transform_shapes_for_default_dynamic, + _tree_map_with_path, +) +from torch.export.graph_signature import CustomObjArgument +from torch.fx.experimental import _config as config +from torch.fx.experimental.symbolic_shapes import ( + _find_user_code_frame, + _suggest_fixes_for_data_dependent_error_non_strict, + ConstraintViolationError, + DimDynamic, + EqualityConstraint, + GuardOnDataDependentSymNode, + ShapeEnv, + StatelessSymbolicContext, + ValueRanges, +) +from torch.utils._pytree import ( + GetAttrKey, + KeyPath, + MappingKey, + SequenceKey, + tree_map_with_path, +) + + +if TYPE_CHECKING: + from sympy import Symbol + + +log = logging.getLogger(__name__) + + +def key_path_to_source(kp: KeyPath) -> Source: + """ + Given a key path, return the source for the key path. 
+ """ + source: Source = LocalSource("args") + for k in kp: + if isinstance(k, SequenceKey): + source = GetItemSource(source, k.idx) + elif isinstance(k, MappingKey): + source = GetItemSource(source, k.key) + elif isinstance(k, GetAttrKey): + source = AttrSource(source, k.name) + else: + raise ValueError(f"Unknown KeyEntry {k}") + + return source + + +def _is_constant_argument(t): + return t is None or isinstance(t, (int, float, bool, str)) + + +def fakify( + mode: FakeTensorMode, + kp: KeyPath, + t: Any, + t_constraints: Dict[int, Dict[int, Constraint]], + sources: Dict[Tuple[int, int], List[Source]], +): + source = key_path_to_source(kp) + if _is_constant_argument(t) or isinstance(t, torch.ScriptObject): + return t + + if not isinstance(t, torch.Tensor): + raise ValueError(f"Unsupported input type {type(t)}") + n_dims = len(t.shape) + symbolic_context = StatelessSymbolicContext( + dynamic_sizes=[DimDynamic.DYNAMIC] * n_dims, + constraint_sizes=[None] * n_dims, + ) + t_id = id(t) + assert mode.shape_env is not None + if t_id in t_constraints: + for i, constraint in t_constraints[t_id].items(): + symbolic_context.constraint_sizes[i] = constraint.constraint_range + src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i) + sources[(t_id, i)].append(src) + mode.shape_env.source_name_to_debug_name[src.name()] = constraint.name # type: ignore[assignment] + fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context) + mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context)) # type: ignore[union-attr] + return fake + + +def make_fake_inputs( + nn_module, + args, + kwargs, + dynamic_shapes, + _is_torch_jit_trace=False, + allow_complex_guards_as_runtime_asserts=False, +): + """ + Given an nn module, example inputs, and constraints, return a new fake mode, + fake inputs created in that mode whose dynamic shape dimensions are constrained + by the given ranges, and sources for pairs of dynamic shape dimensions that are 
+ constrained to be equal. + """ + # TODO(avik): refactor Dynamo to avoid duplication of the following code + # between non-strict and strict. + # Specifically, here (non-strict) we do the following pre-tracing steps: + # - Fakify inputs. + # - Process input shape equalities. + # In strict, these steps are spread across multiple files: + # - output_graph.py fakifies inputs. + # - [post-tracing] guards.py processes input shape equalities. + + combined_args = _combine_args(nn_module, args, kwargs) + _check_dynamic_shapes(combined_args, dynamic_shapes) + transformed_dynamic_shapes = _transform_shapes_for_default_dynamic( + combined_args, dynamic_shapes + ) + constraints = _process_dynamic_shapes(combined_args, transformed_dynamic_shapes) + t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict) + for constraint in constraints: + t_constraints[constraint.t_id][constraint.dim] = constraint + + context = torch._guards.TracingContext.try_get() + if context is not None: + # This occurs when we are exporting within dynamo. There already exists + # a toplevel TracingContext with a fake mode, so we do not want to + # create another fake mode. 
+ fake_mode = context.fake_mode + elif not _is_torch_jit_trace: + code = nn_module.forward.__code__ + co_fields = { + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + } + fake_mode = FakeTensorMode( + shape_env=ShapeEnv( + tracked_fakes=[], + co_fields=co_fields, + prefer_deferred_runtime_asserts_over_guards=True, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + ), + allow_non_fake_inputs=True, + export=True, + ) + else: + fake_mode = FakeTensorMode( + shape_env=ShapeEnv( + tracked_fakes=[], + prefer_deferred_runtime_asserts_over_guards=True, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + ), + allow_non_fake_inputs=True, + ) + if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None: + raise ValueError( + "Detected fake_mode does not have a shape_env with tracked fakes. " + "If you constructed the module under a FakeTensorMode, " + "please initialize it like: FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))" + ) + + with fake_mode: + # FIXME(ycao) ScriptMethod doesn't have signature, I am using an empty one to unblock + if not _is_torch_jit_trace: + original_signature = inspect.signature(nn_module.forward) + else: + original_signature = None + sources: Dict[Tuple[int, int], List[Source]] = defaultdict(list) + fake_args, fake_kwargs = tree_map_with_path( + lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources), + (args, kwargs), + ) + + names: Dict[str, Tuple[int, int]] = {} + source_pairs: List[Tuple[Source, Source]] = [] + derived_equalities: List[Tuple[Source, Union[Source, Symbol], Callable]] = [] + phantom_symbols: Dict[str, Symbol] = {} + for constraint in constraints: + torch.export.dynamic_shapes._process_equalities( + constraint, + lambda t_id, dim: sources[(t_id, dim)], + fake_mode.shape_env, + names, + source_pairs, + derived_equalities, + phantom_symbols, + ) + + equalities_inputs = 
EqualityConstraint( + source_pairs=source_pairs, + derived_equalities=derived_equalities, + phantom_symbols=list(phantom_symbols.values()), + warn_only=False, + ) + return ( + fake_mode, + fake_args, + fake_kwargs, + equalities_inputs, + original_signature, + transformed_dynamic_shapes, + ) + + +def _flatten_dynamic_shapes( + combined_args: Dict[str, Any], + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]], +) -> List[Any]: + flat_shapes = [] + + def _tree_map_helper(path, t, shape): + nonlocal flat_shapes + flat_shapes.append(shape) + + _tree_map_with_path(_tree_map_helper, combined_args, dynamic_shapes) + return flat_shapes + + +def produce_guards_and_solve_constraints( + fake_mode: FakeTensorMode, + gm: torch.fx.GraphModule, + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], + equalities_inputs: EqualityConstraint, + original_signature: inspect.Signature, + _is_torch_jit_trace=False, +): + """ + Given a fake mode, sources pairs corresponding to equal dynamic shape dimensions, + and a graph module, produce guards on the fake mode's shape env (raising constraint + violations if any), solve (to suggest simplifications or fixes). + Dynamo already performs this, so this is for non-strict mode. 
+ + Additional inputs: + equalities_inputs: the equality constraints to use for guards + original_signature: the signature of the forward method + """ + shape_env = fake_mode.shape_env + assert shape_env is not None + assert shape_env.tracked_fakes is not None + + placeholders = [tf.fake for tf in shape_env.tracked_fakes] + sources = [tf.source for tf in shape_env.tracked_fakes] + input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes] + constraint_violation_error = None + try: + shape_env.produce_guards( + placeholders, + sources, + input_contexts=input_contexts, + equalities_inputs=equalities_inputs, + ignore_static=False, + ) + except ConstraintViolationError as e: + constraint_violation_error = e + + shape_env.frozen = True + dim_constraints = shape_env.dim_constraints + if dim_constraints is None: + # Expected when shape_env.produce_guards throws an early constraint violation error. + # There is nothing to solve for in this case. + # TODO(avik): Maybe record the constraint violation error instead and replay later? 
+ assert constraint_violation_error + raise constraint_violation_error + dim_constraints.solve() + forced_specializations = dim_constraints.forced_specializations() + if not _is_torch_jit_trace: + msg = dim_constraints.prettify_results( + original_signature, + dynamic_shapes, + constraint_violation_error, + forced_specializations, + ) + else: + # FIXME(ycao): This is a hack to get around missing signature from ScriptMethod + msg = "dummy constraint violation message" + if constraint_violation_error: + constraint_violation_error.args = (constraint_violation_error.args[0] + msg,) + elif forced_specializations: + constraint_violation_error = ConstraintViolationError(msg) + if constraint_violation_error: + raise constraint_violation_error + + +def make_constraints( + fake_mode: FakeTensorMode, + gm: torch.fx.GraphModule, + combined_args: Dict[str, Any], + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], + num_lifted_inputs: int, +): + """ + Given a fake mode's shape env and user-specified dynamic shapes, + return the resulting range constraints and equality constraints. + + Additional args: + num_lifted_inputs: the number of non-user-input placeholder nodes in the graph + (used only to enumerate the user-input nodes) + """ + + shape_env = fake_mode.shape_env + assert shape_env is not None + inline_constraints = gm.meta.get("inline_constraints", []) + range_constraints = { + symbol: inline_constraints[symbol] for symbol in inline_constraints + } + if not dynamic_shapes: + return range_constraints + + # get individual dynamic shapes spec for each input + if not isinstance(dynamic_shapes, dict): + assert isinstance(dynamic_shapes, (tuple, list)) + combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] + flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes) + + # check number of shapes vs. 
number of inputs + num_placeholders = [node.op == "placeholder" for node in gm.graph.nodes].count(True) + assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs + + input_dims = defaultdict(list) + free_symbols = set() + for input_index, node in enumerate(gm.graph.nodes): + if input_index < num_lifted_inputs or node.op != "placeholder": + continue + if _is_constant_argument(node.meta["val"]) or isinstance( + node.meta["val"], CustomObjArgument + ): + continue + shape_spec = flat_dynamic_shapes[input_index - num_lifted_inputs] + for i, d in enumerate(node.meta["val"].shape): + if isinstance(d, torch.SymInt) and not d.node.expr.is_number: + # Look up the range constraint for the symbol corresponding to this shape dimension + # and store it indexed by the symbolic expression corresponding to it. + # NOTE(avik): Use node._expr instead of node.expr for the lookup here because + # we want the symbol, not its replacement, which could be an expression. Maybe + # there's a better way to do this, e.g., by (re)computing value ranges for expressions? + dim = shape_spec[i] if shape_spec else None + if dim is None or isinstance(dim, _DimHint): + range_constraints[d.node.expr] = shape_env.var_to_range[ + d.node._expr + ] + else: + range_constraints[d.node.expr] = ValueRanges( + lower=dim.min, upper=dim.max + ) + input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i)) + free_symbols.update(d.node.expr.free_symbols) + + for symbol in free_symbols: + if symbol not in range_constraints: + # Placeholders can have symbolic shapes that are derived expressions. + # The above code will record direct range constraints for them + # so that we can do runtime assertions. In addition, for serde checks + # we want to record range constraints for their root symbols. 
+ range_constraints[symbol] = shape_env.var_to_range[symbol] + + return range_constraints + + +def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap: + """Search the module hierarchy, gathering up all tensor and ScriptObject constants. + + Returns a dictionary mapping hash(value) to the name of the constant. We + have to abuse `hash` here unfortunately, see: [ScriptObject hash]. + """ + constants = ConstantAttrMap() + buffers_parameters = set(m.buffers()) + buffers_parameters.update(m.parameters()) + + def inner(m: torch.nn.Module, prefix_atoms: List[str], constants): + for k, v in m.__dict__.items(): + if isinstance( + v, + ( + torch.Tensor, + torch.ScriptObject, + FakeScriptObject, + ), + ): + if v in buffers_parameters: + # filter out buffers and parameters, leaving only constants + continue + + fqn = ".".join(prefix_atoms + [k]) + constants.add(v, fqn) + for k, v in m.named_children(): + inner(v, prefix_atoms + [k], constants) + + inner(m, [], constants) + return constants + + +@contextlib.contextmanager +def _fakify_script_objects( + mod: torch.nn.Module, + args: Tuple[Any], + kwargs: Dict[Any, Any], + fake_mode: torch._subclasses.fake_tensor.FakeTensorMode, +): + # This context manager is used to fakify script objects into FakeScriptObject. + # Inputs: + # mod: the module to be exported, it (and its recursive submodules)'s script object attrs haven't been fakified. + # args, kwargs: the args and kwargs inputs for mod, script object inputs haven't been fakified. + # fake_mode: the fake mode to be used for fakifying script objects. It's the same mode that fakify input tensors. + # + # Returns: + # mod: the patched module, its (and its recursive submodules) script object attrs have been fakified. + # fake_args, fake_kwargs: new fakified args and kwargs. + # Script object inputs have been fakified. Don't touch the tensors. + # fake_constant_attrs: a new map from FakeScriptObject to the fqn of the original script object. 
+ # fake_to_real: a mapping between FakeScriptObject and the original script object in order to un-do the patching. + + constant_attrs: ConstantAttrMap = _gather_constant_attrs(mod) + assert not any( + isinstance(obj, FakeScriptObject) for obj in constant_attrs.values() + ), "Mod shouldn't contain any FakeScriptObject." + assert not pytree.tree_any( + lambda obj: isinstance(obj, FakeScriptObject), (args, kwargs) + ), "args and kwargs shouldn't contain any FakeScriptObject." + + patched_attr = {} + fake_constant_attrs = ConstantAttrMap() + fake_to_real = {} + + def _maybe_fakify_obj(obj): + fake_obj = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, obj) + fake_to_real[fake_obj] = obj + return fake_obj + + def _leaf_mod_and_attr( + mod: torch.nn.Module, attr_fqn: str + ) -> Tuple[torch.nn.Module, str]: + *prefix_attr, last_attr = attr_fqn.split(".") + cur_mod = mod + for attr in prefix_attr: + cur_mod = getattr(cur_mod, attr) + return cur_mod, last_attr + + try: + for obj, fqns in constant_attrs.items(): + if isinstance(obj, torch.ScriptObject): + fake_script_obj = _maybe_fakify_obj(obj) + for fqn in fqns: + cur_mod, attr = _leaf_mod_and_attr(mod, fqn) + assert obj is getattr(cur_mod, attr) + setattr(cur_mod, attr, fake_script_obj) + fake_constant_attrs.add(fake_script_obj, fqn) + patched_attr[fqn] = obj + else: + for fqn in fqns: + fake_constant_attrs.add(obj, fqn) + + fake_args, fake_kwargs = pytree.tree_map_only( + torch.ScriptObject, _maybe_fakify_obj, (args, kwargs) + ) + yield (mod, fake_args, fake_kwargs, fake_constant_attrs, fake_to_real) + finally: + for fqn, orig_obj in patched_attr.items(): + cur_mod, attr = _leaf_mod_and_attr(mod, fqn) + setattr(cur_mod, attr, orig_obj) + + +class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode): + """ + 1. Handles data-dependent errors raised by torch function calls in non-strict. + + Any data-dependent error is due to some condition on unbacked symints + that cannot be resolved. 
A mechanical way of fixing the error is to use + a torch._check() call to assert either that condition or its negation. + The handler suggests these options as code and points to the location + of the torch function call that raised the error as part of the error + message shown to the user, who can then simply select and copy-paste + a suggested fix at that location. + + NOTE: Not all data-dependent errors are raised by torch function calls. + In particular, conditions on unbacked symints can appear outside such + calls, and as such are not handled here. + + 2. Handles line-of-code logging for each torch function call in non-strict. + + Usage: TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC=1 TORCH_LOGS="+export" ... + """ + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc: + frame = _find_user_code_frame() + if frame is not None: + log.debug( + "%s called at %s:%s in %s", + func.__qualname__, + frame.f_code.co_filename, + frame.f_lineno, + frame.f_code.co_name, + ) + try: + return func(*args, **kwargs) + except GuardOnDataDependentSymNode as e: + _suggest_fixes_for_data_dependent_error_non_strict(e) + raise diff --git a/.venv/lib/python3.11/site-packages/torch/_export/pass_base.py b/.venv/lib/python3.11/site-packages/torch/_export/pass_base.py new file mode 100644 index 0000000000000000000000000000000000000000..55612c98ce8d51d95999f0f4e124f3479070deb1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/pass_base.py @@ -0,0 +1,441 @@ +# mypy: allow-untyped-defs +import operator +import traceback +import typing +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import torch +from functorch.experimental.control_flow import _unstack_pytree +from torch import fx +from torch._dispatch.python import enable_python_dispatcher +from torch._export.pass_infra.node_metadata import 
NodeMetadata +from torch._export.pass_infra.proxy_value import ProxyValue +from torch._subclasses import FakeTensor, UnsupportedFakeTensorException +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx import traceback as fx_traceback +from torch.fx.experimental.proxy_tensor import PythonKeyTracer +from torch.fx.graph import CodeGen +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata +from torch.utils import _pytree as pytree +from torch.fx.experimental.symbolic_shapes import PropagateUnbackedSymInts, compute_unbacked_bindings + + +__all__ = ["_ExportPassBaseDeprecatedDoNotUse"] + + +Argument = Any +Value = Any +Fn = Callable[..., Any] +PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] + + +_TORCH_SYM_OPS: Set[Callable] = { + torch.sym_int, + torch.sym_float, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_not, + torch.sym_sqrt, +} + + +class ExportPassBaseError(RuntimeError): + pass + + +class _ExportPassBaseDeprecatedDoNotUse(PassBase): + """ + Interpreter-based pass class to help users maintain the IR spec while writing + transformations. 
+ """ + + @staticmethod + def _create_dummy_node_metadata(): + return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))}) + + + class ExportTracer(PythonKeyTracer): + def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None: + super().__init__() + self.callback = callback + self.root = torch.nn.Module() + self.graph = torch.fx.Graph() + self.graph.set_codegen(codegen) + self.tensor_attrs: Dict[str, torch.Tensor] = {} # type: ignore[assignment] + self.fake_tensor_mode: Optional[FakeTensorMode] = None + self.submodules: Dict[torch.nn.Module, str] = {} + + def trace(self) -> None: # type: ignore[override] + raise ExportPassBaseError("ExportTracer doesn't support trace().") + + def create_arg(self, a: Argument) -> torch.fx.Node: + if isinstance(a, torch.nn.Module): + if a not in self.submodules: + name_submodule = f"submodule_{len(self.submodules)}" + self.root.add_module(name_submodule, a) + self.submodules[a] = name_submodule + elif isinstance(a, FakeTensor): + if not hasattr(a, "constant") or a.constant is None: + raise ExportPassBaseError(f"Cannot add {a} to graph.") + a = a.constant + node = super().create_arg(a) + if ( + isinstance(a, torch.Tensor) + and isinstance(node, torch.fx.Node) + and node.op == "get_attr" + ): + self.set_metadata(node, a) + self.callback.on_attr(ProxyValue(a, node)) + return node + + def set_metadata( + self, node: torch.fx.Node, value: Argument, + ) -> None: + # propagate the fake tensor or sym nodes + def make_val( + x: Argument, + ) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]: + if isinstance(x, FakeTensor): + return x + elif isinstance(x, torch.Tensor): + if x.is_quantized: + # TODO (tmanlaibaatar) properly support Quantized FakeTensor + x = torch.dequantize(x) + + try: + assert self.fake_tensor_mode is not None + # TODO we should allocate static shapes + # for param/buffer values + if isinstance(x, torch.nn.Parameter): + 
fake_tensor = self.fake_tensor_mode.from_tensor( + x, static_shapes=True + ) + else: + fake_tensor = self.fake_tensor_mode.from_tensor(x) + except UnsupportedFakeTensorException: + # TODO: This is just a workaround to get over the + # x.as_subclass error + print( + "Fakeifying a Tensor subclass is not supported \ + right now. Instead a TensorMetadata is used." + ) + fake_tensor = None + return fake_tensor + elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)): + return x + else: + return None + + node.meta["val"] = pytree.tree_map(make_val, value) + + # Set the tensor_metadata for values that do not have a corresponding FakeTensor + def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]: + if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor): + if x.is_quantized: + # TODO (tmanlaibaatar) properly support Quantized FakeTensor + x = torch.dequantize(x) + + try: + assert self.fake_tensor_mode is not None + _ = self.fake_tensor_mode.from_tensor(x) + tensor_meta = None + except UnsupportedFakeTensorException: + # TODO: This is just a workaround to get over the + # x.as_subclass error + tensor_meta = _extract_tensor_metadata(x) + return tensor_meta + else: + return None + + node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value) + + class ExportInterpreter(fx.Interpreter): + def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None: + super().__init__(gm) + self.callback = callback + self.node: torch.fx.Node = next(iter(gm.graph.nodes)) + + def placeholder( + self, + target: str, # type: ignore[override] + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> ProxyValue: + arg = super().placeholder(target, args, kwargs) + return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta)) + + def output( + self, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> ProxyValue: + return 
self.callback.output(args[0], NodeMetadata(self.node.meta)).data + + def call_function( + self, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> ProxyValue: + meta = NodeMetadata(self.node.meta) + + if target == operator.getitem: + value, key = args + return self.callback.call_getitem(value, key, meta) + elif getattr(target, "__module__", None) in {"_operator", "math"}: + assert callable(target) + return self.callback.call_sym(target, args, meta) + elif target in _TORCH_SYM_OPS: + assert callable(target) + return self.callback.call_sym(target, args, meta) + elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)): + return self.callback.call_operator( + target, + args, + kwargs, + meta, + ) + elif target == torch.ops.higher_order.cond: + pred, true_fn, false_fn, inputs = args + return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta) + elif target == torch.ops.higher_order.map_impl: + f, mapped_args, operands = args # type: ignore[assignment] + return self.callback.call_map(f, mapped_args, operands, meta) + # For other unregistered HigherOrderOps, just interpret them blindly + elif isinstance(target, torch._ops.HigherOrderOperator): + return self.callback._fx( + "call_function", + target, + args, + kwargs, + meta, + ) + else: + raise ExportPassBaseError(f"Unsupported target type: {target}") + + def get_attr( + self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override] + ) -> Argument: + return super().get_attr(target, args, kwargs) + + def call_module( + self, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> None: + raise ExportPassBaseError("call_module is not supported.") + + def call_method( + self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override] + ) -> None: + raise ExportPassBaseError("call_method is not supported.") + + def run_node(self, 
n: torch.fx.Node) -> Argument: + self.node = n + self.callback.node_debug_str = n.format_node() + return super().run_node(n) + + def __init__(self) -> None: + self.interpreter = PropagateUnbackedSymInts( + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + ) + self.tracer = self.ExportTracer(self, CodeGen()) + self.fake_tensor_mode: Optional[FakeTensorMode] = None + self._initialized = True + self.node_debug_str: typing.Optional[str] = None + + def _fx( + self, + kind: str, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + args_data, kwargs_data = pytree.tree_map_only( + ProxyValue, lambda x: x.data, (args, kwargs) + ) + res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data) + args_proxy, kwargs_proxy = pytree.tree_map_only( + ProxyValue, lambda x: x.proxy, (args, kwargs) + ) + + name = None + if isinstance(target, torch._ops.OpOverload): + name = self.tracer.graph._target_to_str(target.overloadpacket.__name__) + + res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name) + res_proxy.node.meta.update(meta.data) + if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env): + if symbol_to_path := compute_unbacked_bindings(shape_env, res_data): + res_proxy.node.meta["unbacked_bindings"] = symbol_to_path + self.tracer.set_metadata(res_proxy.node, res_data) + return ProxyValue(res_data, res_proxy) + + def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]: + # TODO(angelayi): Update this with what we decide to do for metadata in + # the exported graph module + if (args := graph_module.meta.get("args", None)) is not None: + return list(args) + + def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]: + if "val" in node.meta: + fake = node.meta["val"] + if hasattr(fake, "constant") and fake.constant is not None: + return fake.constant + return fake + elif tensor_meta := 
node.meta.get("tensor_meta"): + assert self.fake_tensor_mode is not None + return FakeTensor( + self.fake_tensor_mode, + torch.empty( + tensor_meta.shape, + dtype=tensor_meta.dtype, + device="meta", + requires_grad=tensor_meta.requires_grad, + memory_format=tensor_meta.memory_format, + ), + torch.device("cpu"), + ) + elif len(node.users) == 0: + return None + raise ExportPassBaseError( + f"Cannot construct an input for graph module: {graph_module}.", + ) + + return [ + extract_input(node) + for node in graph_module.graph.nodes + if node.op == "placeholder" + ] + + def on_attr(self, attr: ProxyValue) -> None: + pass + + def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue: + arg_proxy = self.tracer.create_proxy("placeholder", name, (), {}) + arg_proxy.node.meta = meta.data + self.tracer.set_metadata(arg_proxy.node, arg) + return ProxyValue(arg, arg_proxy) + + def call_operator( + self, + op, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + return self._fx("call_function", op, args, kwargs, meta) + + def call_sym( + self, + target: Fn, + args: Tuple[Argument, ...], + meta: NodeMetadata, + ) -> ProxyValue: + return self._fx("call_function", target, args, {}, meta) + + def call_cond( + self, + pred: ProxyValue, + true_fn: torch.fx.GraphModule, + false_fn: torch.fx.GraphModule, + inputs: List[Argument], + meta: NodeMetadata, + ) -> ProxyValue: + true_branch = self.call_submodule(true_fn, tuple(inputs)) + false_branch = self.call_submodule(false_fn, tuple(inputs)) + assert true_branch is not None + assert false_branch is not None + return self._fx( + "call_function", + torch.ops.higher_order.cond, + (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)), + {}, + meta, + ) + + def call_map( + self, + f: torch.fx.GraphModule, + mapped_args: List[ProxyValue], + operands: List[ProxyValue], + meta: NodeMetadata, + ) -> ProxyValue: + xs = _unstack_pytree([arg.data for arg in 
mapped_args])[0] + f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands])) + assert f_branch is not None + return self._fx( + "call_function", + torch.ops.higher_order.map_impl, + (f_branch.graph_module, mapped_args, operands), + {}, + meta, + ) + + def call_getitem( + self, value: ProxyValue, key: int, meta: NodeMetadata + ) -> ProxyValue: + return self._fx("call_function", operator.getitem, (value, key), {}, meta) + + def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue: + return self._fx("output", "output", (results,), {}, meta) + + def call_submodule( + self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...] + ) -> PassResult: + prev_tracer, self.tracer = self.tracer, self.ExportTracer( + self, graph_module.graph._codegen + ) + self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode + interpreter = self.ExportInterpreter(self, graph_module) + prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( # type: ignore[assignment] + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + ) + inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs) + with fx_traceback.preserve_node_meta(): + interpreter.run(*inputs_data) + + new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph) + + self.tracer = prev_tracer + self.interpreter = prev_interpreter + return PassResult( + new_graph_module, + True, + ) + + def call(self, graph_module: fx.GraphModule) -> PassResult: + if not getattr(self, "_initialized", False): + raise ExportPassBaseError( + "ExportPass is not initialized with __init__().", + ) + + inputs = self.inputs(graph_module) + + fake_tensor_mode = None + for i in inputs: + if isinstance(i, FakeTensor): + assert ( + fake_tensor_mode is None or fake_tensor_mode is i.fake_mode + ), "Multiple fake tensor mode detected." 
+ fake_tensor_mode = i.fake_mode + if fake_tensor_mode is None: + self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True) + fake_tensor_mode = nullcontext() # type: ignore[assignment] + dispatcher_mode = nullcontext() # type: ignore[assignment] + else: + fake_tensor_mode.allow_non_fake_inputs = True + self.tracer.fake_tensor_mode = fake_tensor_mode + dispatcher_mode = enable_python_dispatcher() # type: ignore[assignment] + self.fake_tensor_mode = self.tracer.fake_tensor_mode + + with fake_tensor_mode, dispatcher_mode: # type: ignore[assignment, union-attr] + result = self.call_submodule(graph_module, tuple(inputs)) + + return result diff --git a/.venv/lib/python3.11/site-packages/torch/_export/tools.py b/.venv/lib/python3.11/site-packages/torch/_export/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..a4b96f909d1642f888546d1068d31d1a5f4ee9f1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/tools.py @@ -0,0 +1,146 @@ +# mypy: allow-untyped-defs +import logging +import warnings +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +import torch.export +import torch.export._trace +from torch._utils_internal import log_export_usage + + +log = logging.getLogger(__name__) + +__all__ = ["report_exportability"] + + +def _generate_inputs_for_submodules( + model: torch.nn.Module, + target_submodules: Iterable[str], + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, +) -> Dict[str, Tuple[Any, Any]]: + """ + Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this + function doesn't work. + + Args: + model: root model. + inputs: inputs to the root model. + target_submodules: submodules that we want to generate inputs for. + + Returns: + A dict that maps from submodule name to its inputs. 
+ """ + kwargs = kwargs or {} + + handles = [] + results = {} + submodule_to_names = {mod: name for name, mod in model.named_modules()} + + def pre_forward(module, module_args, module_kwargs): + results[submodule_to_names[module]] = (module_args, module_kwargs) + + try: + for name, mod in model.named_modules(): + if name in target_submodules: + handles.append( + mod.register_forward_pre_hook(pre_forward, with_kwargs=True) + ) + model(*args, **kwargs) + except Exception as e: + warnings.warn( + f"Failed to generate submodule inputs because of the following error:\n{e}" + ) + finally: + for h in handles: + h.remove() + return results + + +def report_exportability( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + *, + strict: bool = True, + pre_dispatch: bool = False, +) -> Dict[str, Optional[Exception]]: + """ + Report exportability issues for a module in one-shot. + + Args: + mod: root module. + args: args to the root module. + kwargs: kwargs to the root module. + Returns: + A dict that maps from submodule name to the exception that was raised when trying to export it. + `None` means the module is exportable without issue. 
+ Sample output: + { + '': UnsupportedOperatorException(func=), + 'submod_1': UnsupportedOperatorException(func=), + 'submod_2': None + } + """ + + log_export_usage(event="export.report_exportability") + + kwargs = kwargs or {} + + all_submod_names = [name for name, _ in mod.named_modules() if name != ""] + submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs) + + tried_module_types = set() + report: Dict[str, Optional[Exception]] = {} + + def try_export(module, module_name, args, kwargs): + nonlocal submod_inputs, report, strict, pre_dispatch, tried_module_types + + if type(module) in tried_module_types: + return + tried_module_types.add(type(module)) + + if args is not None or kwargs is not None: + try: + torch.export._trace._export( + module, + args, + kwargs, + strict=strict, + pre_dispatch=pre_dispatch, + ) + report[module_name] = None + log.info("Successfully exported `%s`", module_name) + return + except Exception as e: + short_msg = repr(e).split("\n")[0] + log.warning( + "Failed exporting `%s` with exception: %s", module_name, short_msg + ) + report[module_name] = e + + for name, submod in module.named_children(): + sub_module_name = name if module_name == "" else f"{module_name}.{name}" + + submod_args, submod_kwargs = submod_inputs.get( + sub_module_name, (None, None) + ) + + try_export(submod, sub_module_name, submod_args, submod_kwargs) + + return + + try_export(mod, "", args, kwargs) + + unique_issues = set() + for exception in report.values(): + if exception is not None: + key = repr(exception).split("\\n")[0] + unique_issues.add(key) + + log.warning("Found %d export issues:", len(unique_issues)) + for issue in unique_issues: + log.warning(issue) + + return report diff --git a/.venv/lib/python3.11/site-packages/torch/_export/verifier.py b/.venv/lib/python3.11/site-packages/torch/_export/verifier.py new file mode 100644 index 0000000000000000000000000000000000000000..68c5bcaae39af69f5527e8dcf0e08ed49bad4563 --- /dev/null 
+++ b/.venv/lib/python3.11/site-packages/torch/_export/verifier.py @@ -0,0 +1,456 @@ +# mypy: allow-untyped-defs +import inspect +import math +import operator +from collections.abc import Iterable +from typing import Any, Dict, final, List, Tuple, Type, TYPE_CHECKING + +import torch +from torch._ops import HigherOrderOperator, OpOverload +from torch._subclasses.fake_tensor import FakeTensor +from torch.export.graph_signature import ( + CustomObjArgument, + InputKind, + SymIntArgument, + TensorArgument, + TokenArgument, +) +from torch.fx import GraphModule + +if TYPE_CHECKING: + from torch.export.exported_program import ExportedProgram + +class SpecViolationError(Exception): + pass + + +def is_functional(op: OpOverload) -> bool: + return not op._schema.is_mutable + + +def _check_has_fake_tensor(node: torch.fx.Node) -> None: + # TODO(angelayi): remove this in favor of _check_val + return _check_val(node) + + +def _check_val(node: torch.fx.Node) -> None: + from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt + + def _check_correct_val(val): + if val is None: + return True + elif isinstance(val, (int, bool, str, float)): + return True + elif isinstance(val, (torch.memory_format, torch.dtype, torch.device, torch.layout)): + return True + elif isinstance(val, (FakeTensor, torch.Tensor)): # TODO(zhxchen17) Remove Tensor. 
+ return True + elif isinstance(val, (SymInt, SymFloat, SymBool)): + return True + elif isinstance(val, CustomObjArgument): + return True + elif isinstance(val, Iterable): + return all(_check_correct_val(x) for x in val) + return False + + def _no_returns(op): + if not isinstance(op, OpOverload): + return False + return len(op._schema.returns) == 0 + + if "val" not in node.meta: + if node.op == "call_function" and _no_returns(node.target): + return + raise SpecViolationError(f"Node.meta {node.name} is missing val field.") + + val = node.meta["val"] + if not _check_correct_val(val): + raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}") + + +def _check_torch_fn(node: torch.fx.Node) -> None: + torch_fn = node.meta.get("torch_fn") + if torch_fn is None: + raise SpecViolationError(f"Unable to find torch_fn metadata for node {node.name}") + if ( + not isinstance(torch_fn, tuple) and + isinstance(torch_fn[0], str) and + isinstance(torch_fn[1], str) + ): + raise SpecViolationError(f"Node.meta {node.name} has invalid torch_fn field {torch_fn}") + +class _VerifierMeta(type): + _registry: Dict[str, Type['Verifier']] = {} + + def __new__(metacls, name, bases, attrs): + if bases: + if "check" in attrs or "_check_graph_module" in attrs: + raise SyntaxError("Overriding method check is not allowed.") + assert "dialect" in attrs and attrs["dialect"] != "ATEN" + else: + assert "check" in attrs + assert "_check_graph_module" in attrs + assert attrs["dialect"] == "ATEN" + + assert isinstance(attrs["dialect"], str) + ret = type.__new__(metacls, name, bases, attrs) + metacls._registry[attrs["dialect"]] = ret # type: ignore[assignment] + return ret + +def getattr_recursive(obj: Any, target: str) -> Any: + target_atoms = target.split('.') + attr_itr = obj + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") + attr_itr = getattr(attr_itr, atom) + return 
attr_itr + + +class Verifier(metaclass=_VerifierMeta): + dialect = "ATEN" + + def allowed_builtin_ops(self) -> List: + return [ + operator.getitem, + operator.add, + operator.mul, + operator.sub, + operator.truediv, + operator.ge, + operator.le, + operator.gt, + operator.lt, + operator.eq, + operator.ne, + operator.floordiv, + operator.mod, + operator.and_, + operator.or_, + operator.not_, + operator.pow, + operator.neg, + operator.abs, + math.ceil, + math.floor, + math.trunc, + ] + + def allowed_op_types(self) -> Tuple[Type[Any], ...]: + return (OpOverload, HigherOrderOperator) + + def allowed_getattr_types(self) -> Tuple[Type[Any], ...]: + return (torch.fx.GraphModule,) + + def check_valid_op(self, op): + pass + + def check_additional(self, gm: GraphModule) -> None: + """ + Additional checks that are specific to some dialects. + """ + + @final + def check(self, ep: "ExportedProgram") -> None: + self._check_graph_module(ep.graph_module) + _verify_exported_program_module_call_graph(ep) + _verify_exported_program_signature(ep) + + @final + def _check_graph_module(self, gm: torch.fx.GraphModule) -> None: + def _allowed_getattr_types() -> Tuple[Type[Any], ...]: + ret = self.allowed_getattr_types() + assert not any(t is object for t in ret) + return ret + + def _check_valid_op(op) -> None: + def _allowed_builtin_ops() -> List: + ret = self.allowed_builtin_ops() + assert all(inspect.isbuiltin(op) for op in ret) + return ret + + def _allowed_op_types() -> Tuple[Type[Any], ...]: + ret = self.allowed_op_types() + assert not any(t is object for t in ret) + return ret + + # TODO Remove this allowlist. + _allowed_torch_functions = ( + torch.autograd.grad_mode.set_grad_enabled, + torch.sym_int, + torch.sym_float, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_not, + torch.sym_sqrt, + # TODO (tmanlaibaatar) + # Predispatch export is able to contain autograd ops. 
+ # These will be modeled as HOO later + torch._C._set_grad_enabled, + ) + + if not isinstance(op, _allowed_op_types()): + if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions: + raise SpecViolationError( + f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n" + f"Valid builtin ops: {_allowed_builtin_ops()}" + f"Valid torch functions: {_allowed_torch_functions}" + ) + + if isinstance(op, OpOverload): + # All ops functional + # TODO (tmanlaibaatar) more proper way is needed here + if self.dialect != "TRAINING" and not is_functional(op): + raise SpecViolationError( + f"operator '{op}' is not functional" + ) + self.check_valid_op(op) + + for mod in gm.modules(): + if not isinstance(mod, torch.fx.GraphModule): + continue + + mod.graph.lint() + for node in mod.graph.nodes: + # TODO(T140410192): should have fake tensor for all dialects + if node.op in {"call_module", "call_method"}: + raise SpecViolationError( + f"call_module is not valid: got a class '{node.target}' ", + ) + + elif node.op == "call_function": + _check_val(node) + + _check_valid_op(node.target) + + elif node.op == "get_attr": + if not isinstance(node.target, str): + raise SpecViolationError( + f"Expected get_attr target to be string, but got {type(node.target)}" + ) + + attr = getattr_recursive(mod, node.target) + if isinstance(attr, torch.nn.Module): + def _is_type(name, ty): + return isinstance(getattr(attr, name, None), ty) + if type(attr).__name__ == "LoweredBackendModule": + if _is_type("backend_id", str) \ + and _is_type("processed_bytes", bytes) \ + and _is_type("compile_specs", list) \ + and hasattr(attr, "original_module"): + continue + else: + backend_id = getattr(attr, "backend_id", None) + processed_bytes = getattr(attr, "processed_bytes", None) + compile_specs = getattr(attr, "compile_specs", None) + raise SpecViolationError( + f"Invalid get_attr type {type(attr)}. 
\n" + f"LoweredBackendModule fields: " + f"backend_id(str) : {type(backend_id)}, " + f"processed_bytes(bytes) : {type(processed_bytes)}, " + f"compile_specs(list) : {type(compile_specs)}" + ) + + if not isinstance(attr, _allowed_getattr_types()): + raise SpecViolationError( + f"Invalid get_attr type {type(attr)}. \n" + f"Valid get_attr types: {_allowed_getattr_types()}" + ) + + + elif node.op == "placeholder": + _check_val(node) + # TODO(zhxchen17) + # elif node.op == "output": + # _check_flattened_outputs() + + self.check_additional(gm) + + +class TrainingIRVerifier(Verifier): + dialect = "TRAINING" + + +def _verify_exported_program_module_call_graph(exported_program) -> None: + module_call_graph = exported_program.module_call_graph + nodes = { + node.name for node in exported_program.graph.nodes + } + for entry in module_call_graph: + if entry.signature is not None: + for arg in entry.signature.inputs: + if arg.name and arg.name not in nodes: + raise SpecViolationError( + f"Input {arg.name} does not exist in the graph." + ) + for arg in entry.signature.outputs: + if arg.name and arg.name not in nodes: + raise SpecViolationError( + f"Output {arg.name} does not exist in the graph." 
+ ) + + +def _verify_exported_program_signature(exported_program) -> None: + # Check ExportedProgram signature matches + gs = exported_program.graph_signature + + # Check every node in the signature exists in the graph + input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"] + + if len(input_node_names) != len(gs.input_specs): + raise SpecViolationError( + f"Number of graph inputs ({len(input_node_names)}) " + f"does not match number of inputs in the graph signature ({len(gs.input_specs)})" + ) + + for input_spec, node in zip(gs.input_specs, input_node_names): + if isinstance(input_spec.arg, (TensorArgument, SymIntArgument)): + if input_spec.arg.name != node: + raise SpecViolationError( + f"Input spec name {input_spec.arg.name} does not match node name {node}" + ) + + if input_spec.kind == InputKind.USER_INPUT: + continue + + elif input_spec.kind == InputKind.PARAMETER: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Parameter {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + param = input_spec.target + if param not in exported_program.state_dict: + raise SpecViolationError( + f"Parameter {param} is not in the state dict." + ) + + if not isinstance(exported_program.state_dict[param], torch.nn.Parameter): + raise SpecViolationError( + f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter." + ) + + elif input_spec.kind == InputKind.BUFFER: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Buffer {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." 
+ ) + + buffer = input_spec.target + if input_spec.persistent is None: + raise SpecViolationError( + f"Buffer {buffer} is missing a persistence flag" + ) + + if input_spec.persistent is True and buffer not in exported_program.state_dict: + raise SpecViolationError( + f"Buffer {buffer} is not in the state dict." + ) + + if input_spec.persistent is False and buffer in exported_program.state_dict: + raise SpecViolationError( + f"Non-persistent buffer {buffer} is in the state dict, it should not be." + ) + elif input_spec.kind == InputKind.CONSTANT_TENSOR: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + tensor_const = input_spec.target + if tensor_const not in exported_program.constants: + raise SpecViolationError( + f"Constant tensor {tensor_const} is not in the constants dictionary." + ) + elif input_spec.kind == InputKind.CUSTOM_OBJ: + if not isinstance(input_spec.arg, CustomObjArgument): + raise SpecViolationError( + f"Custom object {input_spec.name} is not a custom object argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + custom_obj = input_spec.target + if custom_obj not in exported_program.constants: + raise SpecViolationError( + f"Custom object {custom_obj} is not in the constants dictionary." + ) + elif input_spec.kind == InputKind.TOKEN: + if not isinstance(input_spec.arg, TokenArgument): + raise SpecViolationError( + f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + else: + raise SpecViolationError( + f"Unknown InputKind {input_spec.kind}." 
+ ) + + # Check outputs + output_node = list(exported_program.graph.nodes)[-1] + assert output_node.op == "output" + output_nodes = [ + arg.name if isinstance(arg, torch.fx.Node) else arg + for arg in output_node.args[0] + ] + + if len(output_nodes) != len(gs.output_specs): + raise SpecViolationError( + f"Number of output nodes {len(output_nodes)} is different " + "Than the number of outputs specified by the graph signature: \n" + f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n" + f"Number of user outputs: {len(gs.user_outputs)}. \n" + ) + + num_tokens = len(gs.output_tokens) + end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens + mutate_nodes: List[str] = output_nodes[num_tokens:end] + user_output_nodes = output_nodes[end:end + len(gs.user_outputs)] + + for mutation_node in mutate_nodes: + if mutation_node in gs.buffers_to_mutate: + if gs.buffers_to_mutate[mutation_node] not in gs.buffers: + raise SpecViolationError( + f"Buffer output {mutation_node} does not point to a buffer that exists. \n" + f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n" + f"Buffer nodes available: {gs.buffers} \n" + ) + elif mutation_node in gs.user_inputs_to_mutate: + if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs: + raise SpecViolationError( + f"User input output {mutation_node} does not point to a user input that exists. \n" + f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n" + f"User input nodes available: {gs.user_inputs} \n") + else: + raise SpecViolationError( + f"Mutation node {mutation_node} is neither a buffer nor a user input. 
" + f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}" + ) + + for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs): + if user_output_node != user_output_name: + raise SpecViolationError( + f"User output {user_output_node} is not in the correct " + "order or is not found in the " + f"exported program's user_output list: {gs.user_outputs}. " + ) + + +def load_verifier(dialect: str) -> Type[Verifier]: + if dialect == "ATEN" or dialect == "": + return _VerifierMeta._registry.get(dialect, Verifier) + return _VerifierMeta._registry[dialect] diff --git a/.venv/lib/python3.11/site-packages/torch/_export/wrappers.py b/.venv/lib/python3.11/site-packages/torch/_export/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..d57ff46de41c8f5961a859d3d1e2871984929b8d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/wrappers.py @@ -0,0 +1,121 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager + +import torch +import torch._custom_ops +from torch._C import DispatchKey +from torch._higher_order_ops.strict_mode import strict_mode +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.utils import _pytree as pytree + + +class ExportTracepoint(HigherOrderOperator): + def __init__(self): + super().__init__("_export_tracepoint") + + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + + +_export_tracepoint = ExportTracepoint() + + +@_export_tracepoint.py_impl(ProxyTorchDispatchMode) +def export_tracepoint_dispatch_mode(mode, *args, **kwargs): + p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs)) + proxy = mode.tracer.create_proxy( + "call_function", _export_tracepoint, p_args, p_kwargs 
+ ) + return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer) + + +@_export_tracepoint.py_impl(FakeTensorMode) +def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs): + with mode: + return args + + +@_export_tracepoint.py_functionalize_impl +def export_tracepoint_functional(ctx, *args, **kwargs): + unwrapped_args = ctx.unwrap_tensors(args) + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + + with ctx.redispatch_to_next(): + out = _export_tracepoint(*unwrapped_args, **unwrapped_kwargs) + return ctx.wrap_tensors(out) + + +_export_tracepoint.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(_export_tracepoint, deferred_error=True) +) + + +@_export_tracepoint.py_impl(DispatchKey.CPU) +def export_tracepoint_cpu(*args, **kwargs): + return args + + +def _wrap_submodule(mod, path, module_call_specs): + assert isinstance(mod, torch.nn.Module) + assert path != "" + submodule = mod + for name in path.split("."): + if not hasattr(submodule, name): + raise RuntimeError(f"Couldn't find submodule at path {path}") + submodule = getattr(submodule, name) + + def update_module_call_signatures(path, in_spec, out_spec): + if path in module_call_specs: + assert module_call_specs[path]["in_spec"] == in_spec + assert module_call_specs[path]["out_spec"] == out_spec + module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec} + + def check_flattened(flat_args): + for a in flat_args: + if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None): + raise AssertionError( + f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}" + ) + + def pre_hook(module, args, kwargs): + flat_args, in_spec = pytree.tree_flatten((args, kwargs)) + check_flattened(flat_args) + flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path) + args, kwargs = pytree.tree_unflatten(flat_args, in_spec) + return args, kwargs + + def post_hook(module, args, kwargs, res): + _, in_spec = pytree.tree_flatten((args, kwargs)) + 
flat_res, out_spec = pytree.tree_flatten(res) + check_flattened(flat_res) + flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path) + update_module_call_signatures(path, in_spec, out_spec) + return pytree.tree_unflatten(flat_res, out_spec) + + pre_handle = submodule.register_forward_pre_hook(pre_hook, with_kwargs=True) + post_handle = submodule.register_forward_hook(post_hook, with_kwargs=True) + return pre_handle, post_handle + + +@contextmanager +def _wrap_submodules(f, preserve_signature, module_call_signatures): + handles = [] + + try: + for path in preserve_signature: + handles.extend(_wrap_submodule(f, path, module_call_signatures)) + yield + finally: + for handle in handles: + handle.remove() + + +def _mark_strict_experimental(cls): + def call(self, *args): + return strict_mode(self, args) + + cls.__call__ = call + return cls diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/__init__.py b/.venv/lib/python3.11/site-packages/torch/_lazy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d90efa40e58841a11a25569ca6722b791894999 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_lazy/__init__.py @@ -0,0 +1,55 @@ +# mypy: allow-untyped-defs + +import torch._C._lazy +from torch.utils._pytree import tree_flatten, tree_unflatten + +from .closure import add_step_closure, run_step_closures + + +def mark_step(device: str = "", wait=False): + """Triggers a mark step, which amounts to + - collecting a group of 'live' lazy tensors to index into the compilation cache + (lowering/compiling their IR graphs if not cached) + - kicking off execution of the compiled function + - (optionally, wait=True) waiting for cpu-side execution to complete (does not sync the accelerator) + """ + # TODO(whc) expand this to include backend hooks and align with XLA backend needs + torch._C._lazy._mark_step(device, [], wait=wait) + + run_step_closures() + + +def wait_device_ops(devices=None): + """Waits for all the async 
operations on the given devices to complete. + Args: + devices (string..., optional): The devices whose async ops need to be waited + for. If empty, all the local devices will be waited for. + """ + if devices is None: + devices = [] + torch._C._lazy._wait_device_ops(devices=devices) + + +def sync_multi(tensors, devices): + """ + Sync the list of lazy tensors so there IR get lowered for the activate backend + and the compiled computation graph get cached. + """ + torch._C._lazy._sync_multi(tensors, devices) + + +def get_tensor_id(tensor): + """Return a unique id of the lazy tensor maintained by LTC""" + return torch._C._lazy._get_tensor_id(tensor) + + +def to_cpu(tensors, devices=None): + devices = devices or ["lazy"] + + flattened, spec = tree_flatten(tensors) + sync_multi(flattened, devices) + return tree_unflatten([t.to("cpu") for t in flattened], spec) + + +def save(tensors, *args, **kwargs): + torch.save(to_cpu(tensors), *args, **kwargs) diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..faff16110f1ba77ea580b919ad14aaaa019f3324 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e40d967dd6ec18fb021840fc221ae5f072360966 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/computation.py b/.venv/lib/python3.11/site-packages/torch/_lazy/computation.py new file mode 100644 index 
0000000000000000000000000000000000000000..17a61e36cb9f2a46461d14caa3c1a3ff6e8c9094 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_lazy/computation.py @@ -0,0 +1,27 @@ +# mypy: allow-untyped-defs +import torch._C._lazy +import torch._C._lazy_ts_backend + + +def get_tensors_ts_device_data_node(tensors): + """Return tensor ids and eager tensors for DeviceData nodes in the + IR for the passed in lazy tensors. + + TODO: This API is currently ts backend specific. We are working on + generalizing it to all backends including XLA. + """ + return torch._C._lazy_ts_backend._get_tensors_ts_device_data_node(tensors) + + +def get_graph_hash(tensors): + """Return the graph hash for the passed in lazy tensors""" + return torch._C._lazy._get_graph_hash(tensors) + + +def run_cached_graph(hash_str, graph_inputs): + """Running the cached computation graph with the given inputs + + TODO: This API is currently ts backend specific. We are working on + generalizing it to all backends including XLA. 
+ """ + return torch._C._lazy_ts_backend._run_cached_graph(hash_str, graph_inputs) diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/config.py b/.venv/lib/python3.11/site-packages/torch/_lazy/config.py new file mode 100644 index 0000000000000000000000000000000000000000..f7ebca12de7fc44c27a2b3ae7c2ed1c7d8097c99 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_lazy/config.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch._C._lazy + + +def get_force_fallback(): + """Get the config used to force LTC fallback""" + return torch._C._lazy._get_force_fallback() + + +def set_force_fallback(configval): + """Set the config used to force LTC fallback""" + torch._C._lazy._set_force_fallback(configval) + + +def set_reuse_ir(val: bool): + """Set the config to reuse IR nodes for faster tracing""" + torch._C._lazy._set_reuse_ir(val) diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/debug.py b/.venv/lib/python3.11/site-packages/torch/_lazy/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..84534fb232509f0c9bbe722820bd1ae649d53e07 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_lazy/debug.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch._C._lazy + + +def render_ir_graph(tensors): + """Return a text dump of the LTC IR graph in dot format for the tensors. + The text can be processed by tools like dot to be rendered in pdf,png etc.""" + return torch._C._lazy._get_tensors_dot(tensors) + + +def dump_ir(tensors, ir_format): + """Return a dump of the tensors in the specified format. 
+ Valid format are + - text: for LTC IR + - backend: for the activate backend IR + """ + if ir_format == "text": + return torch._C._lazy._get_tensors_text(tensors) + elif ir_format == "backend": + return torch._C._lazy._get_tensors_backend(tensors) + else: + raise RuntimeError(f"Unrecognized IR format: {ir_format}") diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/device_context.py b/.venv/lib/python3.11/site-packages/torch/_lazy/device_context.py new file mode 100644 index 0000000000000000000000000000000000000000..e09fdab3f7458cc6a410a1736b89e68b4a4eef17 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_lazy/device_context.py @@ -0,0 +1,26 @@ +# mypy: allow-untyped-defs +import threading +from typing import Any, Dict + +import torch._C._lazy + + +class DeviceContext: + _CONTEXTS: Dict[str, Any] = {} + _CONTEXTS_LOCK = threading.Lock() + + def __init__(self, device): + self.device = device + + +def get_device_context(device=None): + if device is None: + device = torch._C._lazy._get_default_device_type() + else: + device = str(device) + with DeviceContext._CONTEXTS_LOCK: + devctx = DeviceContext._CONTEXTS.get(device, None) + if devctx is None: + devctx = DeviceContext(device) + DeviceContext._CONTEXTS[device] = devctx + return devctx diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/extract_compiled_graph.py b/.venv/lib/python3.11/site-packages/torch/_lazy/extract_compiled_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..f46eea4eee9b79033aa22ce2bcc77ba9f650c622 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_lazy/extract_compiled_graph.py @@ -0,0 +1,225 @@ +# mypy: allow-untyped-defs +import copy +import dataclasses +import itertools +import os +from typing import Any, Callable, Dict, List + +import torch +import torch._lazy as lazy +import torch._lazy.metrics as metrics +from torch import fx +from torch._lazy import computation, debug as lazy_debug +from torch._lazy.tensor_factory_functions 
import tensor_factory_functions + + +debug = os.environ.get("debug_extract_compiled_graph") is not None + + +@dataclasses.dataclass +class GraphInputMatcher: + """ + The GraphInputMatcher class setup the graph inputs for future calls after lazy tracing. + Specifically, those graph inputs corresponding to method parameters should be replaced with the + arguments for the current call. + + tensor_id_to_arg_idx maps the tensor id to the parameter index. + graph_input_tensor_ids, graph_input_ivalues list the tensor_id and ivalue for each of the + TS/XLA graph inputs. + """ + + tensor_id_to_arg_idx: Dict[int, int] + graph_input_tensor_ids: List[int] + # there are 2 categories of graph_input_tensors. + # Category 1: those whose id are not found in tensor_id_to_arg_idx. These are + # most likely const tensors and we can get its content from graph_input_tensors + # Category 2: those whose id are found in tensor_id_to_arg_idx. We should get + # the tensor from method arguments + graph_input_ivalues: List[Any] + + # get the real graph input tensors + def __call__(self, args): + real_input = [] + for tensor_id, traced_ivalue in zip( + self.graph_input_tensor_ids, self.graph_input_ivalues + ): + arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None) + if arg_idx is None: + inp = traced_ivalue + else: + inp = args[arg_idx] + real_input.append(inp) + return real_input + + +class ReturnValueHandler: + r""" + When ltc_sync_multi is called on multi tensors, the compiled graph + will contain output only for unique tensors - if a tensor appears multiple + times in the input to _ltc_sync_multi, only the first occurance matters. + + However from python level, we still expect multi tensors returned with duplciation + even if the TS graph dedup the output. e.g. for method: + + def forward(self, a): + return a, a + + the TS graph captured by LTC will return a single tensor, but Python method expects 2. 
+ + This class dedup the lazy tensors first to get the index that will be used + to duplicate the eager tensors later. + """ + + def __init__(self, lazy_out_list): + self.index: List[List[int]] = [] + self.total_count = len(lazy_out_list) + + tensor_id_to_idx: Dict[int, int] = {} + for dup_idx, lazy_tensor in enumerate(lazy_out_list): + uniq_idx = tensor_id_to_idx.get(id(lazy_tensor), None) + if uniq_idx is not None: + self.index[uniq_idx].append(dup_idx) + else: + uniq_idx = len(self.index) + self.index.append([dup_idx]) + tensor_id_to_idx[id(lazy_tensor)] = uniq_idx + + def duplicate_eager_tensors(self, eager_tensor_list): + duplicated_list = [None] * self.total_count + assert len(eager_tensor_list) == len(self.index) + + for uniq_idx, eager_tensor in enumerate(eager_tensor_list): + for dup_idx in self.index[uniq_idx]: + duplicated_list[dup_idx] = eager_tensor + return duplicated_list + + +def force_lazy_device(model: fx.GraphModule): + """ + Factory methods in a Fx graph may create tensors for a specific eager devices. + If we take no actions, those eager tensors will be mixed with lazy tensors and + cause crash. This method overwrite those eager device to lazy device. + """ + + def tolazydevice(dev): + if isinstance(dev, torch.device): + return torch.device("lazy", index=dev.index) + return dev + + def hasDeviceArg(args, kwargs): + return any( + isinstance(arg, torch.device) + for arg in itertools.chain(args, kwargs.values()) + ) + + for nd in model.graph.nodes: + nd.args = tuple(tolazydevice(arg) for arg in nd.args) + nd.kwargs = {k: tolazydevice(v) for k, v in nd.kwargs.items()} + + # For torchbench like yolov3, hf_Bart, dynamo generates Fx graph that return + # eager tensors on the default device + # (check https://gist.github.com/shunting314/eabdf6c769c59bc384469717b8f9bb7f for yolove, + # and https://gist.github.com/shunting314/8d5e2d9348a3258959d3954186c48814 for hf_Bart). 
+ # To force those tensors on the lazy device, we can not simply override + # the device argument since there is no explicit device argument. + # What we are doing here is, for the list of covered tensor factory methods + # we add a lazy device argument explicity. + # + # TODO: This solution is no ideal since we may miss some factory methods. In future + # when we support lazy mode, this method can be replaced by that. + if nd.target in tensor_factory_functions and not hasDeviceArg( + nd.args, nd.kwargs + ): + kwargs = dict(nd.kwargs) # nd.kwargs is immutable. make a mutable copy. + kwargs["device"] = torch.device("lazy") + nd.kwargs = kwargs + + model.recompile() + + +def get_fallback_ops(): + fallback_ops = [] + for opname in metrics.counter_names(): + if "aten::" not in opname: + continue + val = int(metrics.counter_value(opname)) + if val > 0: + fallback_ops.append(f"{opname}={val}") + + return fallback_ops + + +def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable: + """ + Optimize an eager model with LTC and returns a wrapper to execute the + compiled graph directly without retracing. It depends on other mechanisms + like TorchDynamo guards to guarantee the returned wrapper is only called + when it's safe. 
+ """ + lazy_args = [arg.to(device="lazy") for arg in example_inputs] + args_tensor_ids = [lazy.get_tensor_id(lazy_arg) for lazy_arg in lazy_args] + tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)} + lazy_model = copy.deepcopy(model).to(device=torch.device("lazy")) + force_lazy_device(lazy_model) + + # This line executes lazy tracing and enable us extracting compiled graph later + metrics.reset() + lazy_out = lazy_model(*lazy_args) + fallback_ops = get_fallback_ops() + metrics.reset() + + if len(fallback_ops) > 0: + raise RuntimeError( + f"Fail to extact the compiled graph because of fallback: {','.join(fallback_ops)}" + ) + + if not isinstance(lazy_out, (tuple, list)): + lazy_out = (lazy_out,) + + args_and_out = tuple(lazy_args) + tuple(lazy_out) + return_value_handler = ReturnValueHandler(args_and_out) + if debug: + print("Fx code:\n", model.code) + print("LTC IR:", lazy_debug.dump_ir(args_and_out, "text")) + + # TODO: this part is TS backend specific for now and will be generalized to + # support XLA + ( + graph_input_tensor_ids, + graph_input_ivalues, + ) = computation.get_tensors_ts_device_data_node(args_and_out) + assert len(graph_input_tensor_ids) == len(graph_input_ivalues) + graph_input_matcher = GraphInputMatcher( + tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues + ) + + graph_hash = computation.get_graph_hash(args_and_out) + + if debug: + print("graph_hash", graph_hash) + print(f"args_tensor_ids {args_tensor_ids}") + print("tensor ids from device data:", graph_input_tensor_ids) + + # sync the list of output tensors so the computation graph for these + # tensors will be cached. Those computation graphs can be retrieved + # by graph hash later. 
+ lazy.sync_multi(args_and_out, []) + + def optimized_mod(*args): + if len(args_and_out) == 0: + return () + graph_input = graph_input_matcher(args) + res = return_value_handler.duplicate_eager_tensors( + computation.run_cached_graph(graph_hash, graph_input) + ) + + assert len(res) == len(args_and_out) + for i, arg in enumerate(args): + # only copy those tensors that get inplace updated + if arg is not res[i]: + arg.copy_(res[i]) + + # skip the args + return res[len(args) :] + + return optimized_mod diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/metrics.py b/.venv/lib/python3.11/site-packages/torch/_lazy/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..a77981feb90dbd74eb0a31ae86fe661a758a494a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_lazy/metrics.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch._C._lazy + + +def reset(): + """Resets all metric counters.""" + torch._C._lazy._reset_metrics() + + +def counter_names(): + """Retrieves all the currently active counter names.""" + return torch._C._lazy._counter_names() + + +def counter_value(name: str): + """Return the value of the counter with the speficied name""" + return torch._C._lazy._counter_value(name) + + +def metrics_report(): + """Return the combined (lazy core and backend) metric report""" + return torch._C._lazy._metrics_report() diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/ts_backend.py b/.venv/lib/python3.11/site-packages/torch/_lazy/ts_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..5c6ce13746e913db8e27081b8b0dcf8f4e0d4c88 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_lazy/ts_backend.py @@ -0,0 +1,7 @@ +# mypy: allow-untyped-defs +import torch._C._lazy_ts_backend + + +def init(): + """Initializes the lazy Torchscript backend""" + torch._C._lazy_ts_backend._init() diff --git a/.venv/lib/python3.11/site-packages/torch/multiprocessing/_atfork.py 
b/.venv/lib/python3.11/site-packages/torch/multiprocessing/_atfork.py new file mode 100644 index 0000000000000000000000000000000000000000..ac4a97c10c07ae680765b0f362ef33c4bfb2308b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/multiprocessing/_atfork.py @@ -0,0 +1,35 @@ +# mypy: allow-untyped-defs +import sys + + +__all__ = ["register_after_fork"] + +if sys.platform == "win32": + import multiprocessing.util as _util + + def _register(func): + def wrapper(arg): + func() + + _util.register_after_fork(_register, wrapper) + +else: + import os + + def _register(func): + os.register_at_fork(after_in_child=func) + + +def register_after_fork(func): + """Register a callable to be executed in the child process after a fork. + + Note: + In python < 3.7 this will only work with processes created using the + ``multiprocessing`` module. In python >= 3.7 it also works with + ``os.fork()``. + + Args: + func (function): Function taking no arguments to be called in the child after fork + + """ + _register(func) diff --git a/.venv/lib/python3.11/site-packages/torch/multiprocessing/pool.py b/.venv/lib/python3.11/site-packages/torch/multiprocessing/pool.py new file mode 100644 index 0000000000000000000000000000000000000000..6915203566469cfaf7170d87894ce03cc8348dd5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/multiprocessing/pool.py @@ -0,0 +1,52 @@ +import multiprocessing.pool +import multiprocessing.util as util + +from .queue import SimpleQueue + + +def clean_worker(*args, **kwargs): + import gc + + multiprocessing.pool.worker(*args, **kwargs) + # Regular multiprocessing workers don't fully clean up after themselves, + # so we have to explicitly trigger garbage collection to make sure that all + # destructors are called... + gc.collect() + + +class Pool(multiprocessing.pool.Pool): + """Pool implementation which uses our version of SimpleQueue. + + This lets us pass tensors in shared memory across processes instead of + serializing the underlying data. 
+ """ + + def _setup_queues(self): + self._inqueue = SimpleQueue() + self._outqueue = SimpleQueue() + self._quick_put = self._inqueue._writer.send + self._quick_get = self._outqueue._reader.recv + + def _repopulate_pool(self): + """Increase the number of pool processes to the specified number. + + Bring the number of pool processes up to the specified number, for use after + reaping workers which have exited. + """ + for i in range(self._processes - len(self._pool)): + # changed worker -> clean_worker + args = ( + self._inqueue, + self._outqueue, + self._initializer, + self._initargs, + self._maxtasksperchild, + ) + if hasattr(self, "_wrap_exception"): + args += (self._wrap_exception,) + w = self.Process(target=clean_worker, args=args) + self._pool.append(w) + w.name = w.name.replace("Process", "PoolWorker") + w.daemon = True + w.start() + util.debug("added worker") diff --git a/.venv/lib/python3.11/site-packages/torch/multiprocessing/queue.py b/.venv/lib/python3.11/site-packages/torch/multiprocessing/queue.py new file mode 100644 index 0000000000000000000000000000000000000000..876bf8d0e7459b60a41b59b0a093608e515ba455 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/multiprocessing/queue.py @@ -0,0 +1,43 @@ +# mypy: allow-untyped-defs +import io +import multiprocessing.queues +import pickle +from multiprocessing.reduction import ForkingPickler + + +class ConnectionWrapper: + """Proxy class for _multiprocessing.Connection which uses ForkingPickler for object serialization.""" + + def __init__(self, conn): + self.conn = conn + + def send(self, obj): + buf = io.BytesIO() + ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj) + self.send_bytes(buf.getvalue()) + + def recv(self): + buf = self.recv_bytes() + return pickle.loads(buf) + + def __getattr__(self, name): + if "conn" in self.__dict__: + return getattr(self.conn, name) + raise AttributeError(f"'{type(self).__name__}' object has no attribute 'conn'") + + +class Queue(multiprocessing.queues.Queue): + 
def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._reader: ConnectionWrapper = ConnectionWrapper(self._reader) + self._writer: ConnectionWrapper = ConnectionWrapper(self._writer) + self._send = self._writer.send + self._recv = self._reader.recv + + +class SimpleQueue(multiprocessing.queues.SimpleQueue): + def _make_methods(self): + if not isinstance(self._reader, ConnectionWrapper): + self._reader: ConnectionWrapper = ConnectionWrapper(self._reader) + self._writer: ConnectionWrapper = ConnectionWrapper(self._writer) + super()._make_methods() # type: ignore[misc] diff --git a/.venv/lib/python3.11/site-packages/torch/multiprocessing/reductions.py b/.venv/lib/python3.11/site-packages/torch/multiprocessing/reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..fa0818571a93c0e9809c4638446e7ebdb15bd87e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/multiprocessing/reductions.py @@ -0,0 +1,647 @@ +# mypy: allow-untyped-defs +import multiprocessing +import os +import threading +from multiprocessing.reduction import ForkingPickler +from multiprocessing.util import register_after_fork +from typing import Union + +import torch +from torch._namedtensor_internals import check_serializing_named_tensor + + +try: + # Early load resource_sharer to prevent a partially initialized instance + # from being inherited in a forked child process. The reduce_storage method + # requires this module indirectly through DupFd(). The built-in mp.Queue + # class pickles arguments in a background thread which may overlap with the + # fork. + import multiprocessing.resource_sharer +except ImportError: + pass + + +class StorageWeakRef: + r"""A weak reference to a Storage. + + The cdata member is a Python number containing the integer representation of + the Storage pointer. 
+ """ + + __slots__ = ["cdata", "_free_weak_ref"] + + def __init__(self, storage): + self.cdata = storage._weak_ref() + # Save a direct reference to _free_weak_ref because the `torch` module + # might be cleared during Python shutdown before this module is cleared. + self._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined] + + @classmethod + def from_weakref(cls, cdata): + instance = cls.__new__(cls) + instance.cdata = cdata + instance._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined] + return instance + + def expired(self): + return torch.Storage._expired(self.cdata) # type: ignore[attr-defined] + + def __del__(self): + self._free_weak_ref(self.cdata) + + def __hash__(self): + return self.cdata + + def __eq__(self, other): + if id(self) == id(other): + return True + return self.cdata == other.cdata + + +class SharedCache(dict): + """Dictionary from multiprocessing handles to StorageWeakRef.""" + + def __init__(self) -> None: + # free_dead_references() is called if the len exceeds the current + # limit. The limit scales with the number of remaining live objects. + self.limit = 128 + # `fork` inherits lock state, so in case we fork when the lock is held, + # we register a function to reset the lock to a new object to avoid + # possible deadlocks, following python multiprocessing library design. 
+ self._after_fork() + register_after_fork(self, SharedCache._after_fork) + + def _after_fork(self): + self.lock = threading.Lock() + + def get(self, key): + with self.lock: + return dict.get(self, key) + + def __setitem__(self, key, storage_ref): + with self.lock: + dict.__setitem__(self, key, storage_ref) + if len(self) > self.limit: + self.free_dead_references() + + def free_dead_references(self): + live = 0 + for key, storage_ref in list(self.items()): + if storage_ref.expired(): + del self[key] + else: + live += 1 + self.limit = max(128, live * 2) + + +# mapping from handles to StorageWeakRef objects +shared_cache = SharedCache() + + +def rebuild_event(device, handle): + return torch.cuda.Event.from_ipc_handle(device, handle) + + +def reduce_event(event): + handle = event.ipc_handle() + return (rebuild_event, (event.device, handle)) + + +def rebuild_tensor(cls, storage, metadata): + storage_offset, size, stride, requires_grad = metadata + t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) + if cls == torch.nn.parameter.Parameter: + # we have to pass requires_grad into constructor, rather than set it as an + # attribute later, because it's an important check for Integer Tensors to + # have requires_grad=False (or else they raise an error) + t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) + else: + t.requires_grad = requires_grad + return t + + +def rebuild_meta_tensor( + tensor_cls, + tensor_size, + tensor_stride, + tensor_offset, + dtype, + storage_size_bytes, + requires_grad, +): + untyped_storage = torch.UntypedStorage(storage_size_bytes, device="meta") + + typed_storage = torch.TypedStorage( + wrap_storage=untyped_storage, dtype=dtype, _internal=True + ) + + t = torch._utils._rebuild_tensor( + typed_storage, + tensor_offset, + tensor_size, + tensor_stride, + ) + + if tensor_cls == torch.nn.parameter.Parameter: + # It is crucial for integer tensors to receive + # the requires_grad=False as an argument in the constructor 
+ t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) + else: + t.requires_grad = requires_grad + + return t + + +def rebuild_cuda_tensor( + tensor_cls, + tensor_size, + tensor_stride, + tensor_offset, + storage_cls, + dtype, + storage_device, + storage_handle, + storage_size_bytes, + storage_offset_bytes, + requires_grad, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, +): + # If storage_handle is None, storage points to nullptr. + if storage_handle is None or storage_size_bytes == 0: + storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True) + else: + storage = storage_from_cache( + storage_cls, (storage_handle, storage_offset_bytes) + ) + if storage is None: + torch.cuda._lazy_init() + storage = storage_cls._new_shared_cuda( + storage_device, + storage_handle, + storage_size_bytes, + storage_offset_bytes, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ) + shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef( + storage + ) + else: + # We already ref counting this Storage, but producer needs new ref-counters to be released. 
+ storage_cls._release_ipc_counter( + ref_counter_handle, ref_counter_offset, device=storage_device + ) + + _storage = ( + storage + if isinstance(storage, torch.UntypedStorage) + else storage._untyped_storage + ) + + t = torch._utils._rebuild_tensor( + torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True), + tensor_offset, + tensor_size, + tensor_stride, + ) + + if tensor_cls == torch.nn.parameter.Parameter: + # It is crucial for integer tensors to receive + # the requires_grad=False as an argument in the constructor + t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad) + else: + t.requires_grad = requires_grad + + return t + + +def reduce_tensor(tensor): + if tensor.requires_grad and not tensor.is_leaf: + raise RuntimeError( + "Cowardly refusing to serialize non-leaf tensor which requires_grad, " + "since autograd does not support crossing process boundaries. " + "If you just want to transfer the data, call detach() on the tensor " + "before serializing (e.g., putting it on the queue)." + ) + + check_serializing_named_tensor(tensor) + torch.utils.hooks.warn_if_has_hooks(tensor) + + # Note [CUDA IPC and the caching allocator] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # When you send a CUDA tensor over IPC, you might expect that you will + # get out the same storage from the other end. However, the CUDA caching + # allocator makes it difficult to preserve this invariant. Consider + # the following situation: a tensor of size 0x100 points to offset 0x20 of + # a storage at 0xA100 of size 0x100. (For simplicity, all of these + # sizes are given in bytes). HOWEVER, with the caching allocator, this storage + # might be part of a larger cudaMalloc allocation 0xA000 of size 0x4000. + # + # When we want to send this CUDA tensor over IPC, we must send the + # *entire* cudaMalloc allocation, i.e., the 0xA000 region, not just + # the storage 0xA100 (because that is what CUDA supports). 
So, on the + # other end, there simply isn't any way to say, "Wait, you gave me + # a bigger region (0xA000) than the one I wanted (0xA100)". + # + # OK, so if you sent the cudaMalloc allocation, can you just wrap that up as + # one storage itself? No, because this cudaMalloc allocation might contain + # storages of mixed types: float, bytes, double... If you make the entire + # allocation a single storage of a type A, we'll hit an error when constructing + # a tensor of type B on the storage. + # + # cudaIpcMemHandle is an identifier to access the sender cudaMalloc allocation on the + # receiver side. However, cudaIpcMemHandles from each device in a given process may + # only be opened by one context per device per other process. + # If we open and close a memory handle multiples times in a process, CUDA is allowed + # to give it a different address; similarly, once we close the memory, we're not + # allowed to access it(and the storage/tensor built on top of it), even if it is + # still live in the original process. As we cannot make a cudaMalloc allocation + # to a single storage in one go, this requires us to cache the device pointer for + # each cudaIpcMemHandle on C++ side to reconstruct types of storages, while keep + # the old ones alives. + # See [https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html] + # + # This is fine, because all we need to do is to save our position in the allocation, + # and reconstruct storage and tensor from it. + # 0xA000 -> -------CUDA Allocation------ + # | | + # | | + # | | + # | | + # 0xA100 -> --------storage1 begin------ + # | | + # 0xA120 -> --------tensor1 begin ------ + # | | + # | | + # | | + # | | + # | | + # 0xA160 -> --------tensor1 end--------- + # | | + # | | + # | | + # 0xA200 -> --------storage1 end-------- + # | | + # 0xE000 -> --------CUDA allocation----- + # + # To send tensor1, the following info are required from sender to receiver for + # storage recontruction. + # 1. 
cudaIpcMemHandle of 0xA000(which can be mapped to a basePtr in receiver process). + # basePtr may not be exactly 0xA000 since it's a different process. + # 2. offset(0xA100) of storage1 in the CUDA allocation. + # 3. size of storage1(0x100). + # + # On receiver side: + # 1. Get the devPtr of the MemHandle to access the memory, reconstruct a storage + # of the same type using (basePtr, offset, size). + # 2. we can reconstruct the tensor on top of the reconstructed storage + # Tensor(size=0x040, offset=0x020, storage=Storage(data=basePtr+0xA100, size=0x0100)) + # + # This strategy has a few implications: + # + # 1. When we serialize a CUDA tensor for IPC, we cannot do it all in one + # go (non-compositionally), and this requires to have a global map + # memHandle -> devPtr for each process. + # + # 2. We MUST NOT let the new IPC tensor be resizable. Originally, a resize + # of the storage beyond 0x100 would merely have caused us to do a + # reallocation. You don't really want to do this, but if you did, + # all that would happen is that you would lose IPC sharing. But if + # you do this in the new world, we will happily let you write out of + # bounds of your "allocation", clobbering unrelated data in the cached + # allocator block. BAD! + # + # By the way, in old versions of PyTorch, we supported this situation + # natively using a "storage view", which permitted multiple storages to be + # views on each other. But this was the *only* use of storage views, so we + # eliminated it so that we could just use tensor views to implement the same + # thing. 
+ # + + # TODO: Handle distinguishing between subclass and non-subclass versions of NT better + # https://github.com/pytorch/pytorch/issues/110543 + from torch.nested._internal.nested_tensor import NestedTensor + + if tensor.is_nested and not isinstance(tensor, NestedTensor): + return reduce_nested_tensor(tensor) + + if tensor.layout in { + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_bsr, + torch.sparse_csc, + torch.sparse_bsc, + }: + return reduce_sparse_tensor(tensor) + + storage = tensor._typed_storage() + + if storage._untyped_storage.device.type == "cuda": + ( + device, + handle, + storage_size_bytes, + storage_offset_bytes, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ) = storage._share_cuda_() + tensor_offset = tensor.storage_offset() + shared_cache[handle] = StorageWeakRef(storage) + # _backward_hooks purposely omitted here, see + # Note [Don't serialize hooks] + return ( + rebuild_cuda_tensor, + ( + type(tensor), + tensor.size(), + tensor.stride(), + tensor_offset, # tensor offset in its storage + type(storage), + tensor.dtype, + device, + handle, # identifier which CUDA allocation is the storage in. 
+ storage_size_bytes, # size(in bytes) of the storage + storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation + tensor.requires_grad, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ), + ) + elif storage._untyped_storage.device.type == "meta": + return ( + rebuild_meta_tensor, + ( + type(tensor), + tensor.size(), + tensor.stride(), + tensor.storage_offset(), + tensor.dtype, + tensor.untyped_storage().size(), + tensor.requires_grad, + ), + ) + + # _backward_hooks purposely omitted here, see Note [Don't serialize hooks] + metadata = ( + tensor.storage_offset(), + tensor.size(), + tensor.stride(), + tensor.requires_grad, + ) + return (rebuild_tensor, (type(tensor), storage, metadata)) + + +def rebuild_nested_tensor( + rebuild_buffer_func, + rebuild_buffer_args, + rebuild_sizes_func, + rebuild_sizes_args, + rebuild_strides_func, + rebuild_strides_args, + rebuild_offsets_func, + rebuild_offsets_args, +): + buffer = rebuild_buffer_func(*rebuild_buffer_args) + sizes = rebuild_sizes_func(*rebuild_sizes_args) + strides = rebuild_strides_func(*rebuild_strides_args) + offsets = rebuild_offsets_func(*rebuild_offsets_args) + return torch._nested_view_from_buffer_copy(buffer, sizes, strides, offsets) + + +def reduce_nested_tensor(nt): + rebuild_buffer_func, rebuild_buffer_args = reduce_tensor(nt.values()) + rebuild_sizes_func, rebuild_sizes_args = reduce_tensor(nt._nested_tensor_size()) + rebuild_strides_func, rebuild_strides_args = reduce_tensor( + nt._nested_tensor_strides() + ) + rebuild_offsets_func, rebuild_offsets_args = reduce_tensor( + nt._nested_tensor_storage_offsets() + ) + + return ( + rebuild_nested_tensor, + ( + rebuild_buffer_func, + rebuild_buffer_args, + rebuild_sizes_func, + rebuild_sizes_args, + rebuild_strides_func, + rebuild_strides_args, + rebuild_offsets_func, + rebuild_offsets_args, + ), + ) + + +def rebuild_sparse_coo_tensor( + rebuild_indices_func, + rebuild_indices_args, + 
rebuild_values_func, + rebuild_values_args, + shape, + is_coalesced, +): + indices = rebuild_indices_func(*rebuild_indices_args) + values = rebuild_values_func(*rebuild_values_args) + return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced) + + +def rebuild_sparse_compressed_tensor( + rebuild_compressed_indices_func, + rebuild_compressed_indices_args, + rebuild_plain_indices_func, + rebuild_plain_indices_args, + rebuild_values_func, + rebuild_values_args, + shape, + layout, +): + compressed_indices = rebuild_compressed_indices_func( + *rebuild_compressed_indices_args + ) + plain_indices = rebuild_plain_indices_func(*rebuild_plain_indices_args) + values = rebuild_values_func(*rebuild_values_args) + return torch.sparse_compressed_tensor( + compressed_indices, plain_indices, values, shape, layout=layout + ) + + +def reduce_sparse_tensor(sparse): + if sparse.layout is torch.sparse_coo: + rebuild_indices_func, rebuild_indices_args = reduce_tensor(sparse._indices()) + rebuild_values_func, rebuild_values_args = reduce_tensor(sparse._values()) + return ( + rebuild_sparse_coo_tensor, + ( + rebuild_indices_func, + rebuild_indices_args, + rebuild_values_func, + rebuild_values_args, + sparse.shape, + sparse.is_coalesced(), + ), + ) + else: + if sparse.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices = sparse.crow_indices() + plain_indices = sparse.col_indices() + elif sparse.layout in {torch.sparse_csc, torch.sparse_bsc}: + compressed_indices = sparse.ccol_indices() + plain_indices = sparse.row_indices() + else: + raise NotImplementedError(sparse.layout) + ( + rebuild_compressed_indices_func, + rebuild_compressed_indices_args, + ) = reduce_tensor(compressed_indices) + rebuild_plain_indices_func, rebuild_plain_indices_args = reduce_tensor( + plain_indices + ) + rebuild_values_func, rebuild_values_args = reduce_tensor(sparse.values()) + return ( + rebuild_sparse_compressed_tensor, + ( + rebuild_compressed_indices_func, + 
rebuild_compressed_indices_args, + rebuild_plain_indices_func, + rebuild_plain_indices_args, + rebuild_values_func, + rebuild_values_args, + sparse.shape, + sparse.layout, + ), + ) + + +def fd_id(fd): + # Returns a tuple which uniquely identifies a file descriptor. In Mac OS, + # this doesn't work with shared memory handles, which is why we don't + # support the "file_descriptor" sharing method on that platform. + stat = os.fstat(fd) + return (stat.st_ino, stat.st_dev) + + +def storage_from_cache(cls, key): + storage_ref = shared_cache.get(key) + if storage_ref is None: + return None + return torch.UntypedStorage._new_with_weak_ptr(storage_ref.cdata) + + +def rebuild_storage_fd(cls, df, size): + fd = df.detach() + try: + storage = storage_from_cache(cls, fd_id(fd)) + if storage is not None: + return storage + storage = cls._new_shared_fd_cpu(fd, size) + shared_cache[fd_id(fd)] = StorageWeakRef(storage) + return storage + finally: + os.close(fd) + + +def rebuild_storage_filename(cls, manager, handle, size, dtype=None): + storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache( + cls, handle + ) + if storage is not None: + return storage._shared_decref() + if dtype is None: + storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size) + else: + byte_size = size * torch._utils._element_size(dtype) + untyped_storage: torch.UntypedStorage = ( + torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size) + ) + storage = torch.TypedStorage( + wrap_storage=untyped_storage, dtype=dtype, _internal=True + ) + shared_cache[handle] = StorageWeakRef(storage) + return storage._shared_decref() + + +def rebuild_storage_empty(cls): + return cls() + + +def rebuild_typed_storage(storage, dtype): + return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True) + + +# Use for torch.storage.TypedStorage +def reduce_typed_storage(storage): + return (rebuild_typed_storage, (storage._untyped_storage, 
storage.dtype)) + + +def rebuild_typed_storage_child(storage, storage_type): + return storage_type(wrap_storage=storage, _internal=True) + + +# Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage +def reduce_typed_storage_child(storage): + return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage))) + + +def reduce_storage(storage): + from . import get_sharing_strategy + + if storage.is_cuda: + raise RuntimeError( + "Cannot pickle CUDA storage; try pickling a CUDA tensor instead" + ) + elif storage.device.type == "meta": + raise RuntimeError( + "Cannot pickle meta storage; try pickling a meta tensor instead" + ) + elif get_sharing_strategy() == "file_system": + metadata = storage._share_filename_cpu_() + cache_key = metadata[1] + rebuild = rebuild_storage_filename + if isinstance(storage, torch.TypedStorage): + metadata += (storage.dtype,) + storage._shared_incref() + elif storage.size() == 0: + # This is special cased because Empty tensors + # (with size 0) cannot be mmapped. + return (rebuild_storage_empty, (type(storage),)) + else: + fd, size = storage._share_fd_cpu_() + df = multiprocessing.reduction.DupFd(fd) + cache_key = fd_id(fd) + metadata = (df, size) + rebuild = rebuild_storage_fd # type: ignore[assignment] + + shared_cache[cache_key] = StorageWeakRef(storage) + return (rebuild, (type(storage),) + metadata) + + +def init_reductions(): + ForkingPickler.register(torch.cuda.Event, reduce_event) + + for t in torch._storage_classes: + if t.__name__ == "UntypedStorage": + ForkingPickler.register(t, reduce_storage) + else: + ForkingPickler.register(t, reduce_typed_storage_child) + + ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage) + + for t in torch._tensor_classes: + ForkingPickler.register(t, reduce_tensor) + + # TODO: Maybe this should be in tensor_classes? 
:) + ForkingPickler.register(torch.Tensor, reduce_tensor) + + from torch.nn.parameter import Parameter + + ForkingPickler.register(Parameter, reduce_tensor) diff --git a/.venv/lib/python3.11/site-packages/torch/multiprocessing/spawn.py b/.venv/lib/python3.11/site-packages/torch/multiprocessing/spawn.py new file mode 100644 index 0000000000000000000000000000000000000000..74bdde0fd97b20355686fc49fdb50a8fe02c5006 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/multiprocessing/spawn.py @@ -0,0 +1,328 @@ +# mypy: allow-untyped-defs +import logging +import multiprocessing +import multiprocessing.connection +import os +import pickle +import signal +import sys +import tempfile +import time +import warnings +from concurrent.futures import as_completed, ThreadPoolExecutor +from typing import Optional + +from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined] + + +ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START" + +log = logging.getLogger(__name__) + +__all__ = [ + "ProcessContext", + "ProcessException", + "ProcessExitedException", + "ProcessRaisedException", + "spawn", + "SpawnContext", + "start_processes", +] + + +class ProcessException(Exception): + __slots__ = ["error_index", "error_pid"] + + def __init__(self, msg: str, error_index: int, pid: int): + super().__init__(msg) + self.msg = msg + self.error_index = error_index + self.pid = pid + + def __reduce__(self): + return type(self), (self.msg, self.error_index, self.pid) + + +class ProcessRaisedException(ProcessException): + """Exception raised when a process failed due to an exception raised by the code.""" + + def __init__( + self, + msg: str, + error_index: int, + error_pid: int, + ): + super().__init__(msg, error_index, error_pid) + + +class ProcessExitedException(ProcessException): + """Exception raised when a process failed due to signal or exited with a specific code.""" + + __slots__ = ["exit_code"] + + def __init__( + self, + msg: str, + error_index: int, + error_pid: int, + exit_code: 
int, + signal_name: Optional[str] = None, + ): + super().__init__(msg, error_index, error_pid) + self.exit_code = exit_code + self.signal_name = signal_name + + def __reduce__(self): + return ( + type(self), + (self.msg, self.error_index, self.pid, self.exit_code, self.signal_name), + ) + + +def _wrap(fn, i, args, error_file): + # prctl(2) is a Linux specific system call. + # On other systems the following function call has no effect. + # This is set to ensure that non-daemonic child processes can + # terminate if their parent terminates before they do. + _prctl_pr_set_pdeathsig(signal.SIGINT) + + try: + fn(i, *args) + except KeyboardInterrupt: + pass # SIGINT; Killed by parent, do nothing + except Exception: + # Propagate exception to parent process, keeping original traceback + import traceback + + with open(error_file, "wb") as fh: + pickle.dump(traceback.format_exc(), fh) + sys.exit(1) + + +class ProcessContext: + def __init__(self, processes, error_files): + self.error_files = error_files + self.processes = processes + self.sentinels = { + process.sentinel: index for index, process in enumerate(processes) + } + + def pids(self): + return [int(process.pid) for process in self.processes] + + def join(self, timeout=None): + r"""Join one or more processes within spawn context. + + Attempt to join one or more processes in this spawn context. + If one of them exited with a non-zero exit status, this function + kills the remaining processes and raises an exception with the cause + of the first process exiting. + + Returns ``True`` if all processes have been joined successfully, + ``False`` if there are more processes that need to be joined. + + Args: + timeout (float): Wait this long before giving up on waiting. + """ + # Ensure this function can be called even when we're done. + if len(self.sentinels) == 0: + return True + + # Wait for any process to fail or all of them to succeed. 
+ ready = multiprocessing.connection.wait( + self.sentinels.keys(), + timeout=timeout, + ) + + error_index = None + for sentinel in ready: + index = self.sentinels.pop(sentinel) + process = self.processes[index] + process.join() + if process.exitcode != 0: + error_index = index + break + + # Return if there was no error. + if error_index is None: + # Return whether or not all processes have been joined. + return len(self.sentinels) == 0 + + # Assume failure. Terminate processes that are still alive. + # Try SIGTERM then SIGKILL if the process isn't going down. + # The reason is related to python signal handling is limited + # to main thread and if that is in c/c++ land and stuck it won't + # to handle it. We have seen processes getting stuck not handling + # SIGTERM for the above reason. + timeout: int = 30 + for process in self.processes: + if process.is_alive(): + log.warning("Terminating process %s via signal SIGTERM", process.pid) + process.terminate() + end = time.monotonic() + timeout + for process in self.processes: + time_to_wait = max(0, end - time.monotonic()) + process.join(time_to_wait) + for process in self.processes: + if process.is_alive(): + log.warning( + "Unable to shutdown process %s via SIGTERM , forcefully exiting via SIGKILL", + process.pid, + ) + process.kill() + process.join() + + # The file will only be created if the process crashed. 
+ failed_process = self.processes[error_index] + if not os.access(self.error_files[error_index], os.R_OK): + exitcode = self.processes[error_index].exitcode + if exitcode < 0: + try: + name = signal.Signals(-exitcode).name + except ValueError: + name = f"" + raise ProcessExitedException( + "process %d terminated with signal %s" % (error_index, name), + error_index=error_index, + error_pid=failed_process.pid, + exit_code=exitcode, + signal_name=name, + ) + else: + raise ProcessExitedException( + "process %d terminated with exit code %d" % (error_index, exitcode), + error_index=error_index, + error_pid=failed_process.pid, + exit_code=exitcode, + ) + + with open(self.error_files[error_index], "rb") as fh: + original_trace = pickle.load(fh) + msg = "\n\n-- Process %d terminated with the following error:\n" % error_index + msg += original_trace + raise ProcessRaisedException(msg, error_index, failed_process.pid) + + +class SpawnContext(ProcessContext): + def __init__(self, processes, error_files): + warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.") + super().__init__(processes, error_files) + + +# Note: [start_processes] +# mp.start_processes handles both start_method='spawn' and 'fork'. It's supposed to be a +# more generalized API than mp.spawn. Currently we only document mp.spawn as it's the +# CUDA compatible start_method. However, in environments like Ipython notebooks, 'fork' +# works better than 'spawn'. Every helper function we created for mp.spawn is indeed +# general enough, and backends like XLA can reuse them in Colab notebooks as well. +# Currently we only add this API first, we can consider adding it to documentation as +# needed in the future. +def start_processes( + fn, + args=(), + nprocs=1, + join=True, + daemon=False, + start_method="spawn", +): + # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010), + # this func will start processes in parallel if start_method is 'forkserver'. 
+ # Please opt in to this perf optimization by setting env var (TORCH_MP_PARALLEL_START) to 1. + # todo: investigate why spawn does not work with threadpool and raises SIGINT + if ( + start_method == "forkserver" + and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1" + ): + log.info("Starting processes in parallel.") + start_parallel = True + else: + # Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start + start_parallel = False + + mp = multiprocessing.get_context(start_method) + error_files = [None] * nprocs + processes = [None] * nprocs + + def start_process(i): + # Each process is assigned a file to write tracebacks to. We + # use the file being non-empty to indicate an exception + # occurred (vs an expected shutdown). Note: this previously + # used a multiprocessing.Queue but that can be prone to + # deadlocks, so we went with a simpler solution for a one-shot + # message between processes. + tf = tempfile.NamedTemporaryFile( + prefix="pytorch-errorfile-", suffix=".pickle", delete=False + ) + tf.close() + os.unlink(tf.name) + process = mp.Process( + target=_wrap, + args=(fn, i, args, tf.name), + daemon=daemon, + ) + process.start() + return i, process, tf.name + + if not start_parallel: + for i in range(nprocs): + idx, process, tf_name = start_process(i) + error_files[idx] = tf_name + processes[idx] = process + else: + with ThreadPoolExecutor(max_workers=nprocs) as executor: + futures = [executor.submit(start_process, i) for i in range(nprocs)] + for fut in as_completed(futures): + idx, process, tf_name = fut.result() + # idx and process rank needs to be the same. + error_files[idx] = tf_name + processes[idx] = process + context = ProcessContext(processes, error_files) + if not join: + return context + + # Loop on join until it returns True or raises an exception. 
+ while not context.join(): + pass + + +def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"): + r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``. + + If one of the processes exits with a non-zero exit status, the + remaining processes are killed and an exception is raised with the + cause of termination. In the case an exception was caught in the + child process, it is forwarded and its traceback is included in + the exception raised in the parent process. + + Args: + fn (function): Function is called as the entrypoint of the + spawned process. This function must be defined at the top + level of a module so it can be pickled and spawned. This + is a requirement imposed by multiprocessing. + + The function is called as ``fn(i, *args)``, where ``i`` is + the process index and ``args`` is the passed through tuple + of arguments. + + args (tuple): Arguments passed to ``fn``. + nprocs (int): Number of processes to spawn. + join (bool): Perform a blocking join on all processes. + daemon (bool): The spawned processes' daemon flag. If set to True, + daemonic processes will be created. + start_method (str): (deprecated) this method will always use ``spawn`` + as the start method. To use a different start method + use ``start_processes()``. 
+ + Returns: + None if ``join`` is ``True``, + :class:`~ProcessContext` if ``join`` is ``False`` + + """ + if start_method != "spawn": + msg = ( + f"This method only supports start_method=spawn (got: {start_method}).\n" + "To use a different start_method use:\n\t\t" + " torch.multiprocessing.start_processes(...)" + ) + warnings.warn(msg, FutureWarning, stacklevel=2) + return start_processes(fn, args, nprocs, join, daemon, start_method="spawn") diff --git a/.venv/lib/python3.11/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d271856b95210a60966a8a9a97e8cb109af29506 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__init__.py b/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0e0ca1743566174b7abfd3663dfa90b744ba56f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__init__.py @@ -0,0 +1,9 @@ +from torch.ao.nn.quantizable.modules.activation import MultiheadAttention +from torch.ao.nn.quantizable.modules.rnn import LSTM, LSTMCell + + +__all__ = [ + "LSTM", + "LSTMCell", + "MultiheadAttention", +] diff --git a/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07d2128ee59eb64088a38b3805d3aa3121f8c202 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..441d839d5783d8e12a8f92c139cc6519d13ca4ec Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8967b541b96aad6d073c86bd1053ca02594e9c10 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7cfa9867c629d9ce173ed96e8071d53a2a96e5c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..249b4b7a08be6d21fc3d0196191d233a7e27cbe9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47ddcb331eac5aed2e858b171354bc46ac8d1d3a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77b2a068c7d066d22763ae6cafd80af2117c2df6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/__init__.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e4dcc773691966eca4e36133247a459da2035701 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/__init__.py @@ -0,0 +1,39 @@ +from . 
import parametrizations, rnn, stateless +from .clip_grad import clip_grad_norm, clip_grad_norm_, clip_grad_value_ +from .convert_parameters import parameters_to_vector, vector_to_parameters +from .fusion import ( + fuse_conv_bn_eval, + fuse_conv_bn_weights, + fuse_linear_bn_eval, + fuse_linear_bn_weights, +) +from .init import skip_init +from .memory_format import ( + convert_conv2d_weight_memory_format, + convert_conv3d_weight_memory_format, +) +from .spectral_norm import remove_spectral_norm, spectral_norm +from .weight_norm import remove_weight_norm, weight_norm + + +__all__ = [ + "clip_grad_norm", + "clip_grad_norm_", + "clip_grad_value_", + "convert_conv2d_weight_memory_format", + "convert_conv3d_weight_memory_format", + "fuse_conv_bn_eval", + "fuse_conv_bn_weights", + "fuse_linear_bn_eval", + "fuse_linear_bn_weights", + "parameters_to_vector", + "parametrizations", + "remove_spectral_norm", + "remove_weight_norm", + "rnn", + "skip_init", + "spectral_norm", + "stateless", + "vector_to_parameters", + "weight_norm", +] diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d049c78a99747d28346f5f085af87552054474b6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..783a91c4f4b103fe2296bfe0299fa1309a01f3bf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d35891f123efc748c70599cf2b0717c746fd63f6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17840707dd2aba17a1ceb9cc796134a922f28799 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3473c2f66980f07638fc28147c0bad19d367c9a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d05c60e565666337f54e9871d0630d4340c10136 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/fusion.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/fusion.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..acac605cc85234ffc8825daf437cc8464fbdb483 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/fusion.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/init.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/init.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e98690ba89a0da457e4dc969360adf85a7c5aab Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/init.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8726a3032427099ab6a62c81c4419f9a2c5f492d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8574be4f6bf3a93105257087a72fd91f43430369 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbcdc0eef7de0780fbcffbc60c44f62b499d51df Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/prune.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/prune.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88e5906a177948cf77014a42f856bc59dec995aa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/prune.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/rnn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/rnn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96a837ceef0eafaf5c07534b01f435a0ed692808 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/rnn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86ee11f65446782d76d221375f3303f02604d400 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/stateless.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/stateless.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9537802bf300c5b6a80959ec3b2e5e2f58e37fdc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/stateless.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0359182c631a359c775add49b4551edbe4517b0b Binary files /dev/null 
and b/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_deprecation_utils.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/_deprecation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..15d560ab5bc371b3e23da564f0b0bc23116f5468 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/_deprecation_utils.py @@ -0,0 +1,54 @@ +# mypy: allow-untyped-defs +import importlib +import warnings +from typing import Callable, List + + +_MESSAGE_TEMPLATE = ( + r"Usage of '{old_location}' is deprecated; please use '{new_location}' instead." +) + + +def lazy_deprecated_import( + all: List[str], + old_module: str, + new_module: str, +) -> Callable: + r"""Import utility to lazily import deprecated packages / modules / functional. + + The old_module and new_module are also used in the deprecation warning defined + by the `_MESSAGE_TEMPLATE`. + + Args: + all: The list of the functions that are imported. Generally, the module's + __all__ list of the module. + old_module: Old module location + new_module: New module location / Migrated location + + Returns: + Callable to assign to the `__getattr__` + + Usage: + + # In the `torch/nn/quantized/functional.py` + from torch.nn.utils._deprecation_utils import lazy_deprecated_import + _MIGRATED_TO = "torch.ao.nn.quantized.functional" + __getattr__ = lazy_deprecated_import( + all=__all__, + old_module=__name__, + new_module=_MIGRATED_TO) + """ + warning_message = _MESSAGE_TEMPLATE.format( + old_location=old_module, new_location=new_module + ) + + def getattr_dunder(name): + if name in all: + # We are using the "RuntimeWarning" to make sure it is not + # ignored by default. 
+ warnings.warn(warning_message, RuntimeWarning) + package = importlib.import_module(new_module) + return getattr(package, name) + raise AttributeError(f"Module {new_module!r} has no attribute {name!r}.") + + return getattr_dunder diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__init__.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a0eaf86bdbeacfe7d4e7cbd50daf11385955d7d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__init__.py @@ -0,0 +1,10 @@ +from .conv_expanded_weights import ConvPerSampleGrad +from .embedding_expanded_weights import EmbeddingPerSampleGrad +from .expanded_weights_impl import ExpandedWeight +from .group_norm_expanded_weights import GroupNormPerSampleGrad +from .instance_norm_expanded_weights import InstanceNormPerSampleGrad +from .layer_norm_expanded_weights import LayerNormPerSampleGrad +from .linear_expanded_weights import LinearPerSampleGrad + + +__all__ = ["ExpandedWeight"] diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc44d5913a5a56d2a81dafcc79c7457d3f1072ae Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90cc0595dbdb7a052376a62f238fbb10405e802f Binary files 
/dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cda7fabaa56aa4d9d5e44827cc375af21beb600b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/conv_expanded_weights.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/conv_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..aee3a75e70f45542451b81525f8b5ae7e01b7a77 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/conv_expanded_weights.py @@ -0,0 +1,68 @@ +# mypy: allow-untyped-defs +import torch +import torch.nn.functional as F + +from .conv_utils import ( + conv_args_and_kwargs, + conv_backward, + conv_input_for_string_padding, + conv_picker, +) +from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads +from .expanded_weights_utils import forward_helper + + +@implements_per_sample_grads(F.conv1d) +@implements_per_sample_grads(F.conv2d) +@implements_per_sample_grads(F.conv3d) +class ConvPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, kwarg_names, conv_fn, *expanded_args_and_kwargs): + expanded_args, expanded_kwargs = conv_args_and_kwargs( + kwarg_names, expanded_args_and_kwargs + ) + orig_input = expanded_args[0] + was_same_padding = expanded_kwargs["padding"] == "same" + + if isinstance(expanded_kwargs["padding"], str): + # if padding is a string, we'll do the necessary padding 
(slowly) using F.pad + kernel_size = expanded_args[1].shape[2:] + padding, dilation = expanded_kwargs["padding"], expanded_kwargs["dilation"] + input = conv_input_for_string_padding( + conv_fn, padding, expanded_args[0], dilation, kernel_size + ) + expanded_args = (input, expanded_args[1]) + # since we've already done the padding, don't need any more + expanded_kwargs["padding"] = 0 + + output = forward_helper(conv_fn, expanded_args, expanded_kwargs) + input, weight = expanded_args + batched_dim_size = conv_picker(conv_fn, 3, 4, 5) + if input.dim() != batched_dim_size: + raise RuntimeError( + f"Expanded Weights only support convolution with batched input, got {conv_fn} with an" + f"unbatched input of dim {input.dim()}, expected input of dim {batched_dim_size}" + ) + + ctx.conv_fn = conv_fn + + ctx.batch_size = orig_input.shape[0] + ctx.input_required_grad = orig_input.requires_grad + ctx.orig_input_shape = orig_input.shape + ctx.was_same_padding = was_same_padding + ctx.stride, ctx.padding = expanded_kwargs["stride"], expanded_kwargs["padding"] + ctx.dilation, ctx.groups = ( + expanded_kwargs["dilation"], + expanded_kwargs["groups"], + ) + + if isinstance(weight, ExpandedWeight): + ctx.input = input + ctx.weight = weight + ctx.bias = expanded_kwargs["bias"] + + return output + + @staticmethod + def backward(ctx, grad_output): + return conv_backward(ctx.conv_fn, ctx, grad_output) diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/conv_utils.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/conv_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f0ecbe9bdcc0a741c7a6fc81f1f88a45b0ba4369 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/conv_utils.py @@ -0,0 +1,353 @@ +# mypy: allow-untyped-defs +from typing import List, Optional + +import numpy as np + +import torch +import torch.nn.functional as F + +from .expanded_weights_utils import ( + 
set_grad_sample_if_exists, + unpack_expanded_weight_or_tensor, +) + + +THRESHOLD = 32 + + +def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt): + if func == F.conv1d: + return conv1dOpt + if func == F.conv2d: + return conv2dOpt + else: + assert func == F.conv3d + return conv3dOpt + + +def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs): + args = expanded_args_and_kwargs[: len(expanded_args_and_kwargs) - len(kwarg_names)] + kwargs = expanded_args_and_kwargs[ + len(expanded_args_and_kwargs) - len(kwarg_names) : + ] + kwargs = dict(zip(kwarg_names, kwargs)) + + return conv_normalizer(*args, **kwargs) + + +def conv_normalizer( + input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, +): + return (input, weight), { + "bias": bias, + "stride": stride, + "padding": padding, + "dilation": dilation, + "groups": groups, + } + + +def conv_input_for_string_padding(func, padding_style, input, dilation, kernel_size): + if padding_style == "valid": + return input + else: + padding = int_padding_for_string_padding( + func, padding_style, dilation, kernel_size + ) + return F.pad(input, padding) + + +def int_padding_for_string_padding(func, padding_style, dilation, kernel_size): + def get_dilation(i): + return dilation[i] if isinstance(dilation, tuple) else dilation + + if padding_style == "same": + padding: List[int] = [] + # F.pad needs the padding in reverse order from what conv expects + for i in range(conv_picker(func, 0, 1, 2), -1, -1): + padding += conv_padding_for_same(get_dilation(i), kernel_size[i]) + return padding + elif padding_style == "valid": + return conv_picker(func, 2, 4, 6) * (0,) + else: + raise RuntimeError( + f"got padding type of {padding_style}, only accept 'same' or 'valid'" + ) + + +def conv_padding_for_same(dilation, kernel_size): + total_pad = dilation * (kernel_size - 1) + left_pad = total_pad // 2 + right_pad = total_pad - left_pad + return left_pad, right_pad + + +def conv_backward(func, ctx, grad_output): + 
def weight_grad_sample(weight): + if batch_size < THRESHOLD and groups == 1: + return conv_group_weight_grad_sample( + ctx.input, + grad_output, + weight_shape, + stride, + padding, + dilation, + batch_size, + func, + ) + else: + return conv_unfold_weight_grad_sample( + ctx.input, + grad_output, + weight_shape, + kernel_size, + stride, + padding, + dilation, + groups, + func, + ) + + def expand(param): + if isinstance(param, int): + return conv_picker(func, (param,), (param, param), (param, param, param)) + else: + return param + + def calc_total_padding(func, was_same, padding, dilation, kernel_size): + if was_same: + all_padding = int_padding_for_string_padding( + func, "same", dilation, kernel_size + ) + # F.pad needs the padding in reverse order from what conv expects + total_padding = tuple( + all_padding[i] + all_padding[i - 1] + for i in range(len(all_padding) - 1, -1, -2) + ) + return total_padding + else: + return tuple(2 * pad for pad in padding) + + weight_shape = ctx.weight.shape + stride, padding, dilation, groups = ( + expand(ctx.stride), + expand(ctx.padding), + expand(ctx.dilation), + ctx.groups, + ) + + kernel_size = [] + for i in range(2, conv_picker(func, 3, 4, 5)): + kernel_size.append(weight_shape[i]) + + batch_size = ctx.batch_size + results: List[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + + # "same" padding may give uneven padding on either side so we need to separate the "padding" attr and total padding + total_padding = calc_total_padding( + func, ctx.was_same_padding, padding, dilation, kernel_size + ) + + if ctx.input_required_grad: + output_padding = [] + input_dims = conv_picker(func, 1, 2, 3) + for i in range(input_dims): + input_dim = ctx.orig_input_shape[2 + i] + output_padding.append( + ( + total_padding[i] + + input_dim + - (kernel_size[i] * dilation[i] - dilation[i] + 1) + ) + % stride[i] + ) + weight_ = unpack_expanded_weight_or_tensor(ctx.weight) + 
transpose_func = conv_picker( + func, F.conv_transpose1d, F.conv_transpose2d, F.conv_transpose3d + ) + out = transpose_func( + grad_output, + weight_, + None, + stride, + padding, + tuple(output_padding), + groups, + dilation, + ) + + if ctx.was_same_padding: + for i in range(len(total_padding)): + out = torch.narrow( + out, 2 + i, total_padding[i] // 2, ctx.orig_input_shape[2 + i] + ) + + results.append(out) + else: + results.append(None) + # weight and bias don't compute batched gradients; no other arguments are differentiable + results = results + [None] * 6 + + # set grad_sample field for weight and bias with per sample gradients + set_grad_sample_if_exists(ctx.weight, weight_grad_sample) + set_grad_sample_if_exists( + ctx.bias, lambda _: grad_output.reshape(*grad_output.shape[:2], -1).sum(dim=2) + ) + return tuple(results) + + +def conv_unfold_weight_grad_sample( + input, + grad_output, + weight_shape, + kernel_size, + stride, + padding, + dilation, + groups, + func, +): + n = input.shape[0] + in_channels = input.shape[1] + + unfold_func = conv_picker( + func, + lambda: F.unfold( + input.unsqueeze(-2), + kernel_size=(1, kernel_size[0]), + dilation=(1, dilation[0]), + padding=(0, padding[0]), + stride=(1, stride[0]), + ), + lambda: F.unfold( + input, kernel_size, dilation=dilation, padding=padding, stride=stride + ), + lambda: unfold3d(input, kernel_size, padding, stride, dilation), + ) + + input = unfold_func() + grad_output = grad_output.reshape(n, -1, input.shape[-1]) + + # n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz + weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input) + # rearrange the above tensor and extract diagonals. 
+ weight_grad_sample = weight_grad_sample.view( + n, + groups, + -1, + groups, + int(in_channels / groups), + np.prod(kernel_size), + ) + weight_grad_sample = torch.einsum( + "ngrg...->ngr...", weight_grad_sample + ).contiguous() + shape = [n] + list(weight_shape) + weight_grad_sample = weight_grad_sample.view(shape) + return weight_grad_sample + + +def conv_group_weight_grad_sample( + input, + grad_output, + weight_shape, + stride, + padding, + dilation, + batch_size, + func, +): + I = input.shape[1] + O = grad_output.shape[1] + + input_ = input.transpose(0, 1) + grad_output_ = grad_output.view( + grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:] + ) + + weight_grad_sample = func( + input_, + grad_output_, + None, + stride=dilation, + padding=padding, + dilation=stride, + groups=batch_size, + ) + input_dims = conv_picker(func, 3, 4, 5) + for i in range(2, input_dims): + weight_grad_sample = weight_grad_sample.narrow(i, 0, weight_shape[i]) + weight_grad_sample = weight_grad_sample.view( + I, batch_size, O, *weight_grad_sample.shape[2:] + ) + weight_grad_sample = weight_grad_sample.movedim(0, 2) + return weight_grad_sample + + +def unfold3d( + tensor, + kernel_size, + padding, + stride, + dilation, +): + r""" + Extract sliding local blocks from an batched input tensor. + + :class:`torch.nn.Unfold` only supports 4D inputs (batched image-like tensors). + This method implements the same action for 5D inputs + Args: + tensor: An input tensor of shape ``(B, C, D, H, W)``. + kernel_size: the size of the sliding blocks + padding: implicit zero padding to be added on both sides of input + stride: the stride of the sliding blocks in the input spatial dimensions + dilation: the spacing between the kernel points. + Returns: + A tensor of shape ``(B, C * np.prod(kernel_size), L)``, where L - output spatial dimensions. 
+ See :class:`torch.nn.Unfold` for more details + Example: + >>> # xdoctest: +SKIP + >>> B, C, D, H, W = 3, 4, 5, 6, 7 + >>> tensor = torch.arange(1, B * C * D * H * W + 1.).view(B, C, D, H, W) + >>> unfold3d(tensor, kernel_size=2, padding=0, stride=1).shape + torch.Size([3, 32, 120]) + """ + if len(tensor.shape) != 5: + raise ValueError( + f"Input tensor must be of the shape [B, C, D, H, W]. Got{tensor.shape}" + ) + + if dilation != (1, 1, 1): + raise NotImplementedError(f"dilation={dilation} not supported.") + + batch_size, channels, _, _, _ = tensor.shape + + # Input shape: (B, C, D, H, W) + tensor = F.pad( + tensor, (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0]) + ) + # Output shape: (B, C, D+2*padding[2], H+2*padding[1], W+2*padding[0]) + + tensor = tensor.unfold(dimension=2, size=kernel_size[0], step=stride[0]) + tensor = tensor.unfold(dimension=3, size=kernel_size[1], step=stride[1]) + tensor = tensor.unfold(dimension=4, size=kernel_size[2], step=stride[2]) + # Output shape: (B, C, D_out, H_out, W_out, kernel_size[0], kernel_size[1], kernel_size[2]) + # For D_out, H_out, W_out definitions see :class:`torch.nn.Unfold` + + tensor = tensor.permute(0, 2, 3, 4, 1, 5, 6, 7) + # Output shape: (B, D_out, H_out, W_out, C, kernel_size[0], kernel_size[1], kernel_size[2]) + + tensor = tensor.reshape(batch_size, -1, channels * np.prod(kernel_size)).transpose( + 1, 2 + ) + # Output shape: (B, D_out * H_out * W_out, C * kernel_size[0] * kernel_size[1] * kernel_size[2] + + return tensor diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..3db3371379d111898abe2aec2751d907cef188bf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py @@ -0,0 +1,83 @@ +# mypy: 
allow-untyped-defs +from typing import List, Optional + +import torch +import torch.nn.functional as F + +from .expanded_weights_impl import implements_per_sample_grads +from .expanded_weights_utils import ( + forward_helper, + set_grad_sample_if_exists, + standard_kwargs, +) + + +@implements_per_sample_grads(F.embedding) +class EmbeddingPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): + expanded_args, expanded_kwargs = standard_kwargs( + kwarg_names, expanded_args_and_kwargs + ) + if len(expanded_args[0].shape) == 1: + raise RuntimeError( + f"Expanded Weights needs an input with a batch size, got a 1D tensor, {expanded_args[0]}" + ) + output = forward_helper(F.embedding, expanded_args, expanded_kwargs) + ctx.input, ctx.weight = expanded_args + ctx.padding_idx, ctx.scale_grad_by_freq = ( + expanded_kwargs["padding_idx"], + expanded_kwargs["scale_grad_by_freq"], + ) + ctx.sparse = expanded_kwargs["sparse"] + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.input, ctx.weight + padding_idx, scale_grad_by_freq, sparse = ( + ctx.padding_idx, + ctx.scale_grad_by_freq, + ctx.sparse, + ) + + def weight_per_sample_grad(weight): + batch_size = input.shape[0] + embedding_dim = weight.shape[1] + index = ( + input.unsqueeze(-1) + .expand(*input.shape, embedding_dim) + .reshape(batch_size, -1, embedding_dim) + ) + grad_sample = torch.zeros( + batch_size, *weight.shape, device=weight.device, dtype=grad_output.dtype + ) + return grad_sample.scatter_add_( + 1, index, grad_output.reshape(batch_size, -1, embedding_dim) + ) + + results: List[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + + if input.requires_grad: + bw_fn = torch.ops.aten.embedding_backward + results.append( + bw_fn( + grad_output, + input, + weight.shape[0], + padding_idx, + scale_grad_by_freq, + sparse, + ) + ) + else: + results.append(None) + 
+ # weight doesn't compute batched gradients; no other arguments are differentiable (2 not saved from forward) + results = results + [None] * 6 + + # set grad_sample field for weight with per sample gradients + set_grad_sample_if_exists(weight, weight_per_sample_grad) + return tuple(results) diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_impl.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..83b38cc57eb75c21bb44f15d51f6252c70ca4fea --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_impl.py @@ -0,0 +1,182 @@ +# mypy: allow-untyped-defs +import functools +from contextlib import contextmanager +from typing import Callable, Dict + +import torch +from torch._decomp import decomposition_table +from torch.utils._pytree import tree_map_only + + +HANDLED_FUNCTIONS: Dict[Callable, torch.autograd.Function] = {} + +aten = torch._ops.ops.aten +# __torch_function__ runs before the pydispatcher so we need to manually use the same +# decompositions indexed by their torch equivalent +expanded_weights_rnn_decomps = { + # func: (input_decomp, data_decomp) + torch.rnn_relu: ( + decomposition_table[aten.rnn_relu.input], + decomposition_table[aten.rnn_relu.data], + ), + torch.rnn_tanh: ( + decomposition_table[aten.rnn_tanh.input], + decomposition_table[aten.rnn_tanh.data], + ), + torch.lstm: ( + decomposition_table[aten.lstm.input], + decomposition_table[aten.lstm.data], + ), + torch.gru: ( + decomposition_table[aten.gru.input], + decomposition_table[aten.gru.data], + ), +} + + +# all of the RNN decomps run linear with the batch dimension second, even if batch_first was set +@contextmanager +def batch_second(args, kwargs): + def set_batch_second(ew): + ew.set_batch_first(False) + + def reset_batch_first(ew): + ew.set_batch_first(True) + + tree_map_only(ExpandedWeight, 
set_batch_second, args) + tree_map_only(ExpandedWeight, set_batch_second, kwargs) + try: + yield + finally: + tree_map_only(ExpandedWeight, reset_batch_first, args) + tree_map_only(ExpandedWeight, reset_batch_first, kwargs) + + +# to support packed sequences, we need to allow for smaller batches. Expanded weights represents the largest batch +@contextmanager +def allow_smaller_batches(args, kwargs): + def allow(ew): + ew.set_allow_smaller_batches(True) + + def reset(ew): + ew.set_allow_smaller_batches(False) + + tree_map_only(ExpandedWeight, allow, args) + tree_map_only(ExpandedWeight, allow, kwargs) + try: + yield + finally: + tree_map_only(ExpandedWeight, reset, args) + tree_map_only(ExpandedWeight, reset, kwargs) + + +@contextmanager +def setup_rnn(use_input_variant, args, kwargs): + with batch_second(args, kwargs) if use_input_variant else allow_smaller_batches( + args, kwargs + ): + yield + + +def implements_per_sample_grads(torch_function): + @functools.wraps(torch_function) + def decorator(autograd_func): + HANDLED_FUNCTIONS[torch_function] = autograd_func + return autograd_func + + return decorator + + +# ExpandedWeight represents a weight (parameter) Tensor that has an expanded +# batch dimension. Operations on the ExpandedWeight Tensor act exactly like +# those without an expanded batch dimension but a call to .backward() populates +# the original (unexpanded) tensor with per-sample-gradients for in the grad_sample field +# +# ExpandedWeight has a fallback that always fails since we cannot know what the batch +# dimension of the input tensor is and therefore cannot know if this is a valid call +# +# This is a __torch_function__ object but it could have also been a Tensor Extension +# with a dispatch key. 
+# +# Needs to be a tensor subclass to allow reparamaterization +class ExpandedWeight(torch.Tensor): + def __init__(self, orig_weight, batch_size, loss_reduction): + self.batch_size = batch_size + self.batch_first = True + self.allow_smaller_batches = False + self.orig_weight = orig_weight + self.loss_reduction = loss_reduction + + handled_functions = HANDLED_FUNCTIONS + + def __new__(cls, orig_weight, batch_size, loss_reduction): + if not isinstance(orig_weight, torch.Tensor): + raise RuntimeError( + f"Can only make Expanded Weights of Tensors, got {type(orig_weight).__name__}" + ) + if not orig_weight.requires_grad: + raise RuntimeError( + "Can only build ExpandedWeights objects of tensors that require_grad" + ) + ret = torch.Tensor._make_subclass(cls, orig_weight, True) + return ret + + @classmethod + def __torch_function__(cls, func, _, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func in expanded_weights_rnn_decomps: + # in aten, choosing the input or data variants is done by parsing logic. This mimics some of that + decomp_opts = expanded_weights_rnn_decomps[func] + use_input_variant = isinstance( + args[2], list + ) # data variant uses a list here + decomp = decomp_opts[0] if use_input_variant else decomp_opts[1] + + if decomp is not None: + with setup_rnn(use_input_variant, args, kwargs): + return decomp(*args, **kwargs) + if func == torch._cudnn_rnn_flatten_weight: + # since we aren't using the fused cuda kernels for RNNs, don't do this + return + if func in cls.handled_functions: + return cls.handled_functions[func].apply( + tuple(kwargs.keys()), func, *(args + tuple(kwargs.values())) + ) + # We cannot use a fallback here because we do not know the batch dimension for any regular tensor inputs, + # i.e. 
torch.add(torch.Tensor, ExpandedWeight) + raise RuntimeError( + f"Expanded Weights encountered but cannot handle function {func.__name__}" + ) + + @property + def dtype(self): + return self.orig_weight.dtype + + @property + def data(self): + return self.orig_weight.data + + @property + def shape(self): + return self.orig_weight.shape + + @property + def device(self): + return self.orig_weight.device + + @property + def is_cuda(self): + return self.orig_weight.is_cuda + + def data_ptr(self): + return self.orig_weight.data_ptr() + + def get_device(self): + return self.orig_weight.get_device() + + def set_allow_smaller_batches(self, is_allow_smaller_batches): + self.allow_smaller_batches = is_allow_smaller_batches + + def set_batch_first(self, is_batch_first=True): + self.batch_first = is_batch_first diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_utils.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1249adfd7594b5d0def11c168d247398a42bee7e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_utils.py @@ -0,0 +1,188 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch + +from .expanded_weights_impl import ExpandedWeight + + +def is_batch_first(expanded_args_and_kwargs): + batch_first = None + for arg in expanded_args_and_kwargs: + if not isinstance(arg, ExpandedWeight): + continue + + if not batch_first: + batch_first = arg.batch_first + elif arg.batch_first != batch_first: + raise RuntimeError( + "Got conflicting batch_first arguments in the same layer" + ) + return batch_first + + +def standard_kwargs(kwarg_names, expanded_args): + r"""Separate args and kwargs from `__torch_function__`s that standardize kwargs. 
+ + Most `__torch_function__`s standardize the kwargs that they give, so this will separate + the args and kwargs they pass. Functions that don't are linear and convND. + """ + kwarg_values = expanded_args[len(expanded_args) - len(kwarg_names) :] + expanded_args_without_kwargs = expanded_args[ + : len(expanded_args) - len(kwarg_names) + ] + expanded_kwargs = dict(zip(kwarg_names, kwarg_values)) + return expanded_args_without_kwargs, expanded_kwargs + + +def forward_helper(func, expanded_args, expanded_kwargs): + r"""Compute the forward pass for a function that has expanded weight(s) passed to it. + + It will run the forward pass where all ExpandedWeights are their original + weight. It runs checks on the given arguments and detaches the outputs. + + .. note:: First argument in :attr:`expanded_args` must be the input with the batch + dimension as the first element of the shape + + .. note:: :attr:`func` must return a Tensor or tuple of Tensors + + Args: + func: The function to be called + expanded_args: Arguments to be passed to :attr:`func`. Will include arguments + that need to be unpacked because they are ExpandedWeights + expanded_kwargs: Keyword arguments to be passed to :attr:`func`. + Similar to :attr:`expanded_args`. + """ + unexpanded_args, unexpanded_kwargs = _check_and_unexpand_args( + func, expanded_args, expanded_kwargs + ) + return func(*unexpanded_args, **unexpanded_kwargs) + + +def _check_and_unexpand_args(func, expanded_args, expanded_kwargs): + # input must be the first argument passed + input = expanded_args[0] + if isinstance(input, ExpandedWeight): + raise RuntimeError( + "Expanded Weights do not support inputs that are also ExpandedWeights. 
" + f"Input must be a Tensor, got {type(input).__name__} in function {func.__name__}" + ) + if not isinstance(input, torch.Tensor): + raise RuntimeError( + "Expanded Weights requires a Tensor as the first input to get the batch dimension, " + f"got {type(input).__name__} in function {func.__name__}" + ) + if len(input.shape) == 0: + raise RuntimeError( + f"Expanded Weights requires a batch dimension but got an input of size 0 in function {func.__name__}" + ) + if input.shape[0] == 0: + raise RuntimeError( + "0 is not a valid batch size for Expanded Weights but got input tensor of " + f"{input} in function {func.__name__}" + ) + for arg in expanded_args + tuple(expanded_kwargs.values()): + if not isinstance(arg, ExpandedWeight): + continue + batch_size = input.shape[0] if arg.batch_first else input.shape[1] + if (arg.allow_smaller_batches and batch_size > arg.batch_size) or ( + not arg.allow_smaller_batches and arg.batch_size != batch_size + ): + raise RuntimeError( + "Expected ExpandedWeights to have batch size matching input but got " + f"input batch size of {batch_size} with ExpandedWeight of batch size {arg.batch_size}" + ) + + loss_reduction: Optional[str] = None + for arg in expanded_args + tuple(expanded_kwargs.values()): + if isinstance(arg, ExpandedWeight): + if loss_reduction is None: + loss_reduction = arg.loss_reduction + elif loss_reduction != arg.loss_reduction: + raise RuntimeError( + "Expected ExpandedWeights to all have the same loss_reduction argument but got one" + f"with {loss_reduction} and one with {arg.loss_reduction}" + ) + + unexpanded_args = tuple( + arg.orig_weight if isinstance(arg, ExpandedWeight) else arg + for arg in expanded_args + ) + unexpanded_kwargs = { + name: arg.orig_weight if isinstance(arg, ExpandedWeight) else arg + for (name, arg) in expanded_kwargs.items() + } + return unexpanded_args, unexpanded_kwargs + + +def maybe_scale_by_batch_size(grad_sample, expanded_weight): + if expanded_weight.loss_reduction == "mean": + return 
grad_sample * expanded_weight.batch_size + else: + return grad_sample + + +def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn): + unpacked = unpack_expanded_weight_or_tensor(maybe_expanded_weight) + if isinstance(maybe_expanded_weight, ExpandedWeight): + grad_sample_contribution = maybe_scale_by_batch_size( + per_sample_grad_fn(unpacked), maybe_expanded_weight + ) + + if maybe_expanded_weight.batch_size > grad_sample_contribution.shape[0]: + # this only passes the other checks if the arg allows smaller batch sizes + intermediate = torch.zeros( + maybe_expanded_weight.batch_size, + *grad_sample_contribution.shape[1:], + dtype=grad_sample_contribution.dtype, + device=grad_sample_contribution.device, + ) + intermediate[: grad_sample_contribution.shape[0]] = grad_sample_contribution + grad_sample_contribution = intermediate + + if hasattr(unpacked, "grad_sample") and unpacked.grad_sample is not None: + unpacked.grad_sample = unpacked.grad_sample + grad_sample_contribution + else: + unpacked.grad_sample = grad_sample_contribution + + +def unpack_expanded_weight_or_tensor(maybe_expanded_weight, func=lambda x: x): + if isinstance(maybe_expanded_weight, ExpandedWeight): + orig_weight = maybe_expanded_weight.orig_weight + return func(orig_weight) + elif ( + isinstance(maybe_expanded_weight, torch.Tensor) + and not maybe_expanded_weight.requires_grad + ): + return func(maybe_expanded_weight) + elif isinstance(maybe_expanded_weight, torch.Tensor): + raise RuntimeError( + "ExpandedWeights currently does not support a mixture of ExpandedWeight parameters " + "and normal Parameters. Please file and issue with pytorch/pytorch" + ) + + +def sum_over_all_but_batch_and_last_n( + tensor: torch.Tensor, + n_dims: int, +) -> torch.Tensor: + r""" + Calculate the sum over all dimensions, except the first (batch dimension), and excluding the last n_dims. + + This function will ignore the first dimension and it will + not aggregate over the last n_dims dimensions. 
+ Args: + tensor: An input tensor of shape ``(B, ..., X[n_dims-1])``. + n_dims: Number of dimensions to keep. + Example: + >>> tensor = torch.ones(1, 2, 3, 4, 5) + >>> sum_over_all_but_batch_and_last_n(tensor, n_dims=2).shape + torch.Size([1, 4, 5]) + Returns: + A tensor of shape ``(B, ..., X[n_dims-1])`` + """ + if tensor.dim() == n_dims + 1: + return tensor + else: + dims = list(range(1, tensor.dim() - n_dims)) + return tensor.sum(dim=dims) diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..ef9197a827bf774a8be7a6350a314b7720f85264 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +import operator +from functools import reduce +from typing import List, Optional + +import torch +import torch.nn.functional as F + +from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads +from .expanded_weights_utils import ( + forward_helper, + set_grad_sample_if_exists, + standard_kwargs, + unpack_expanded_weight_or_tensor, +) + + +@implements_per_sample_grads(F.group_norm) +class GroupNormPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): + expanded_args, expanded_kwargs = standard_kwargs( + kwarg_names, expanded_args_and_kwargs + ) + input, num_groups = expanded_args + N = input.shape[0] + C = input.shape[1] + HxW = reduce(operator.mul, input.shape[2:], 1) + weight, bias, eps = ( + expanded_kwargs["weight"], + expanded_kwargs["bias"], + expanded_kwargs["eps"], + ) + output, mean, rstd = forward_helper( + torch.native_group_norm, + (input, weight, bias, N, C, HxW, num_groups, eps), + {}, + ) + ctx.input, ctx.num_groups = input, num_groups + ctx.weight, 
ctx.eps = weight, eps + ctx.mean, ctx.rstd = mean, rstd + if isinstance(bias, ExpandedWeight): + ctx.bias = bias + if input.requires_grad and isinstance(weight, ExpandedWeight): + ctx.weight = weight + return output + + @staticmethod + def backward(ctx, grad_output): + input, num_groups = ctx.input, ctx.num_groups + weight, bias, eps = ctx.weight, ctx.bias, ctx.eps + mean, rstd = ctx.mean, ctx.rstd + + results: List[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + + if input.requires_grad: + weight_c = unpack_expanded_weight_or_tensor( + weight, lambda t: t.contiguous() + ) + input_c = input.contiguous() + grad_output_c = ( + grad_output.contiguous() if grad_output is not None else None + ) + N = input.shape[0] + C = input.shape[1] + HxW = 1 + for s in input.shape[2:]: + HxW *= s + bw_fn = torch.ops.aten.native_group_norm_backward + results.append( + bw_fn( + grad_output_c, + input_c, + mean, + rstd, + weight_c, + N, + C, + HxW, + num_groups, + (True, False, False), + )[0] + ) + else: + results.append(None) + + # weight and bias don't compute batched gradients; no other arguments are differentiable + results = results + [None] * 4 + + # set grad_sample field for weight and bias with per sample gradients + if hasattr(ctx, "weight"): + set_grad_sample_if_exists( + weight, + lambda _: torch.einsum( + "ni...->ni", F.group_norm(input, num_groups, eps=eps) * grad_output + ), + ) + if hasattr(ctx, "bias"): + set_grad_sample_if_exists( + bias, lambda _: torch.einsum("ni...->ni", grad_output) + ) + return tuple(results) diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..3929bfa9f2f6d84c071fbc9f354250a2724c556f --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py @@ -0,0 +1,100 @@ +# mypy: allow-untyped-defs +from functools import partial +from typing import List, Optional + +import torch +import torch.nn.functional as F + +from .expanded_weights_impl import implements_per_sample_grads +from .expanded_weights_utils import ( + forward_helper, + set_grad_sample_if_exists, + standard_kwargs, + unpack_expanded_weight_or_tensor, +) + + +@implements_per_sample_grads(F.instance_norm) +class InstanceNormPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): + instance_norm = partial(torch.instance_norm, cudnn_enabled=True) + expanded_args, expanded_kwargs = standard_kwargs( + kwarg_names, expanded_args_and_kwargs + ) + output = forward_helper(instance_norm, expanded_args, expanded_kwargs) + ctx.input = expanded_args[0] + ctx.running_mean, ctx.running_var = ( + expanded_kwargs["running_mean"], + expanded_kwargs["running_var"], + ) + ctx.weight, ctx.bias, ctx.eps = ( + expanded_kwargs["weight"], + expanded_kwargs["bias"], + expanded_kwargs["eps"], + ) + return output + + @staticmethod + def backward(ctx, grad_output): + input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var + weight, bias, eps = ctx.weight, ctx.bias, ctx.eps + + results: List[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + if input.requires_grad: + b = input.shape[0] + c = input.shape[1] + new_shape = (1, b * c, *input.shape[2:]) + + weight_ = unpack_expanded_weight_or_tensor( + weight, lambda orig_weight: orig_weight.repeat(b) + ) + running_mean_ = running_mean.repeat(b) if running_mean is not None else None + running_var_ = running_var.repeat(b) if running_var is not None else None + input_reshaped = input.contiguous().view(new_shape) + grad_output_reshaped = grad_output.contiguous().view(new_shape) + mean = 
torch.mean( + input_reshaped, (0,) + tuple(range(2, input.dim())), False + ) + var = torch.var( + input_reshaped, + (0,) + tuple(range(2, input.dim())), + keepdim=False, + unbiased=False, + ) + rstd = 1 / torch.sqrt(var + eps) + + # must use native batch norm since it supports all inputs. This may have used cuda or openmi during the forward but + # it didn't save the metadata, so we don't know during the backward + res = torch.ops.aten.native_batch_norm_backward( + grad_output_reshaped, + input_reshaped, + weight_, + running_mean_, + running_var_, + mean, + rstd, + True, + eps, + (True, False, False), + ) + results.append(res[0].reshape(input.shape)) + else: + results.append(None) + + # weight and bias don't compute batched gradients; no other arguments are differentiable (2 are not saved from the forward) + results = results + [None] * 7 + + # set grad_sample field for weight and bias with per sample gradients + set_grad_sample_if_exists( + weight, + lambda _: torch.einsum( + "ni...->ni", F.instance_norm(input, eps=eps) * grad_output + ), + ) + set_grad_sample_if_exists( + bias, lambda _: torch.einsum("ni...->ni", grad_output) + ) + return tuple(results) diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..8f529665092dc35a557881c3bce7da57fdb1db48 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py @@ -0,0 +1,87 @@ +# mypy: allow-untyped-defs +from typing import List, Optional + +import torch +import torch.nn.functional as F + +from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads +from .expanded_weights_utils import ( + forward_helper, + set_grad_sample_if_exists, + standard_kwargs, + sum_over_all_but_batch_and_last_n, + 
unpack_expanded_weight_or_tensor, +) + + +@implements_per_sample_grads(F.layer_norm) +class LayerNormPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): + expanded_args, expanded_kwargs = standard_kwargs( + kwarg_names, expanded_args_and_kwargs + ) + input = expanded_args[0] + normalized_shape = expanded_args[1] + if len(input.shape) <= len(normalized_shape): + raise RuntimeError( + "Expanded Weights: Layer norm should not normalize over batch dimension for per sample gradient" + f"computations but got that normalized shape, {normalized_shape}, matched input shape." + ) + output, mean, rstd = forward_helper( + torch.native_layer_norm, expanded_args, expanded_kwargs + ) + ctx.args = expanded_args + + if input.requires_grad or isinstance(expanded_kwargs["weight"], ExpandedWeight): + ctx.weight = expanded_kwargs["weight"] + if input.requires_grad or isinstance(expanded_kwargs["bias"], ExpandedWeight): + ctx.bias = expanded_kwargs["bias"] + ctx.eps = expanded_kwargs["eps"] + ctx.mean, ctx.rstd = mean, rstd + return output + + @staticmethod + def backward(ctx, grad_output): + def weight_per_sample_grad(weight): + return sum_over_all_but_batch_and_last_n( + F.layer_norm(input, normalized_shape, eps=ctx.eps) * grad_output, + weight.dim(), + ) + + input, normalized_shape = ctx.args + mean, rstd = ctx.mean, ctx.rstd + + results: List[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + if input.requires_grad: + weight_ = unpack_expanded_weight_or_tensor(ctx.weight) + bias_ = unpack_expanded_weight_or_tensor(ctx.bias) + results.append( + torch.ops.aten.native_layer_norm_backward( + grad_output, + input, + normalized_shape, + mean, + rstd, + weight_, + bias_, + (True, False, False), + )[0] + ) + else: + results.append(None) + + # weight and bias don't compute batched gradients; no other arguments are differentiable + results = results + [None] * 4 + 
+ # set grad_sample field for weight and bias with per sample gradients + if hasattr(ctx, "weight"): + set_grad_sample_if_exists(ctx.weight, weight_per_sample_grad) + if hasattr(ctx, "bias"): + set_grad_sample_if_exists( + ctx.bias, + lambda bias: sum_over_all_but_batch_and_last_n(grad_output, bias.dim()), + ) + return tuple(results) diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/linear_expanded_weights.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/linear_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..cdd11428f2b977978693a8e7dbd3b1ce0a8d1125 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/linear_expanded_weights.py @@ -0,0 +1,62 @@ +# mypy: allow-untyped-defs +from typing import List, Optional + +import torch +import torch.nn.functional as F + +from .expanded_weights_impl import implements_per_sample_grads +from .expanded_weights_utils import ( + forward_helper, + is_batch_first, + set_grad_sample_if_exists, + unpack_expanded_weight_or_tensor, +) + + +@implements_per_sample_grads(F.linear) +class LinearPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, _, __, *expanded_args_and_kwargs): + if len(expanded_args_and_kwargs[0].shape) <= 1: + raise RuntimeError( + "Input does not have a batch dimension. 
Expanded Weights expected input " + f"of at least rank 2, got of rank {len(expanded_args_and_kwargs[0].shape)}" + ) + expanded_kwargs = { + "bias": expanded_args_and_kwargs[2] + if len(expanded_args_and_kwargs) == 3 + else None + } + expanded_args = expanded_args_and_kwargs[:2] + ctx.batch_first = is_batch_first(expanded_args_and_kwargs) + output = forward_helper(F.linear, expanded_args, expanded_kwargs) + ctx.args = expanded_args + ctx.kwargs = expanded_kwargs + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.args + bias = ctx.kwargs["bias"] + results: List[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg_names + results.append(None) # for op reference + + if input.requires_grad: + results.append(grad_output.matmul(unpack_expanded_weight_or_tensor(weight))) + else: + results.append(None) + results.extend([None] * 2) # weight and bias don't compute batched gradients + + if not ctx.batch_first: + grad_output = grad_output.transpose(0, 1) + input = input.transpose(0, 1) + + # weight and bias get their grad_sample fields set directly if they exist + set_grad_sample_if_exists( + weight, lambda _: torch.einsum("n...i,n...j->nij", grad_output, input) + ) + set_grad_sample_if_exists( + bias, lambda _: torch.einsum("n...k->nk", grad_output) + ) + return tuple(results) diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_named_member_accessor.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/_named_member_accessor.py new file mode 100644 index 0000000000000000000000000000000000000000..f1f5a117e685d2c32cda21d5d2562f424c5c44d3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/_named_member_accessor.py @@ -0,0 +1,372 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Dict, Iterable, List, Tuple + +import torch + + +_MISSING: torch.Tensor = object() # type: ignore[assignment] + + +def set_tensor(module: "torch.nn.Module", name: str, tensor: torch.Tensor) -> None: + if not isinstance(module, torch.nn.Module): + raise TypeError(f"{module} is not an instance of torch.nn.Module") + if not isinstance(tensor, torch.Tensor) and tensor is not None: + raise TypeError(f"{tensor} is not an instance of torch.Tensor") + if "." in name: + raise KeyError('tensor name can\'t contain "."') + if name == "": + raise KeyError('tensor name can\'t be empty string ""') + if name in module._parameters: + module._parameters[name] = tensor # type: ignore[assignment] + elif name in module._buffers: + module._buffers[name] = tensor + else: + setattr(module, name, tensor) + + +def swap_tensor( + module: "torch.nn.Module", + name: str, + tensor: torch.Tensor, + allow_missing: bool = False, +) -> torch.Tensor: + if not isinstance(module, torch.nn.Module): + raise TypeError(f"{module} is not an instance of torch.nn.Module") + if ( + tensor is not _MISSING + and not isinstance(tensor, torch.Tensor) + and tensor is not None + ): + raise TypeError(f"{tensor} is not an instance of torch.Tensor") + if "." 
in name: + raise KeyError('tensor name can\'t contain "."') + if name == "": + raise KeyError('tensor name can\'t be empty string ""') + + orig_tensor: torch.Tensor + if name in module._parameters: + orig_tensor = module._parameters[name] # type: ignore[assignment] + if tensor is not _MISSING: + module._parameters[name] = tensor # type: ignore[assignment] + else: + del module._parameters[name] + elif name in module._buffers: + orig_tensor = module._buffers[name] # type: ignore[assignment] + if tensor is not _MISSING: + module._buffers[name] = tensor + else: + del module._buffers[name] + else: + if hasattr(module, name): + orig_tensor = getattr(module, name) + else: + if not allow_missing: + raise AttributeError(f"{module._get_name()} has no attribute `{name}`") + orig_tensor = _MISSING + if ( + orig_tensor is not _MISSING + and not isinstance(orig_tensor, torch.Tensor) + and orig_tensor is not None + ): + raise TypeError( + f"attribute `{name}`: {orig_tensor} is not an instance of torch.Tensor" + ) + if tensor is not _MISSING: + setattr(module, name, tensor) + elif hasattr(module, name): + delattr(module, name) + return orig_tensor + + +def swap_submodule( + module: "torch.nn.Module", + name: str, + submodule: "torch.nn.Module", +) -> "torch.nn.Module": + if not isinstance(module, torch.nn.Module): + raise TypeError(f"{module} is not an instance of torch.nn.Module") + if not isinstance(submodule, torch.nn.Module): + raise TypeError(f"{submodule} is not an instance of torch.nn.Module") + if "." 
in name: + raise KeyError('submodule name can\'t contain "."') + if name == "": + raise KeyError('submodule name can\'t be empty string ""') + if name not in module._modules: + raise KeyError(f"submodule {name} does not exist") + + orig_submodule = module._modules[name] + if not isinstance(orig_submodule, torch.nn.Module): + raise TypeError(f"{name} attribute is not an instance of torch.nn.Module") + module._modules[name] = submodule + return orig_submodule + + +class NamedMemberAccessor: + """ + A class that provides a way to access the submodules and parameters/buffers of a module. + + It provides caching mechanism to speed up submodule lookups. + This is useful for functional programming to manipulate the module state. + """ + + def __init__(self, module: "torch.nn.Module") -> None: + self.module = module + self.memo: Dict[str, torch.nn.Module] = {} + + # Nested attribute access + + def get_submodule(self, name: str) -> "torch.nn.Module": + """ + Return the submodule specified by the given path. + + For example, to get the submodule mod.layer1.conv1, + use accessor.get_submodule("layer1.conv1") + + Compare to mod.get_submodule("layer1.conv1"), this method will cache the + intermediate submodule access to speed up future lookups. 
+ """ + if not name: + return self.module + + if name in self.memo: + return self.memo[name] + else: + prefix, dot, attr = name.rpartition(".") + if dot: + module = self.get_submodule(prefix) + else: + module = self.module + try: + submodule = getattr(module, attr) + except AttributeError as ex: + raise AttributeError( + f"{module._get_name()} has no attribute `{attr}`" + ) from ex + if not isinstance(submodule, torch.nn.Module): + raise TypeError( # noqa: B904 + f"submodule `{name}`: {submodule} is not an instance of torch.nn.Module" + ) + self.memo[name] = submodule + return submodule + + def swap_submodule(self, path: str, value: "torch.nn.Module") -> "torch.nn.Module": + """ + Swap the submodule specified by the given ``path`` to ``value``. + + For example, to swap the attribute mod.layer1.conv1 use + ``accessor.swap_submodule("layer1.conv1", conv2)``. + """ + prefix, _, attr = path.rpartition(".") + return swap_submodule(self.get_submodule(prefix), attr, value) + + def get_tensor(self, name: str) -> torch.Tensor: + """ + Get the tensor specified by the given path to value. + + For example, to get the attribute mod.layer1.conv1.weight, + use accessor.get_tensor('layer1.conv1.weight') + + Compare to mod.get_parameter("layer1.conv1.weight"), this method will + cache the intermediate submodule access to speed up future lookups. + """ + prefix, _, attr = name.rpartition(".") + submodule = self.get_submodule(prefix) + try: + tensor = getattr(submodule, attr) + except AttributeError as ex: + raise AttributeError( + f"{submodule._get_name()} has no attribute `{name}`" + ) from ex + if not isinstance(tensor, torch.Tensor) and tensor is not None: + raise TypeError(f"{tensor} is not an instance of torch.Tensor") + return tensor # type: ignore[return-value] + + def set_tensor(self, name: str, value: torch.Tensor) -> None: + """ + Set the attribute specified by the given path to value. 
+ + For example, to set the attribute mod.layer1.conv1.weight, + use accessor.set_tensor("layer1.conv1.weight", value) + """ + prefix, _, attr = name.rpartition(".") + set_tensor(self.get_submodule(prefix), attr, value) + + def del_tensor(self, name: str) -> None: + """ + Delete the attribute specified by the given path. + + For example, to delete the attribute mod.layer1.conv1.weight, + use accessor.del_tensor("layer1.conv1.weight") + """ + prefix, _, attr = name.rpartition(".") + submodule = self.get_submodule(prefix) + try: + delattr(submodule, attr) + except AttributeError as ex: + raise AttributeError( + f"{submodule._get_name()} has no attribute `{name}`" + ) from ex + + def swap_tensor( + self, name: str, value: torch.Tensor, allow_missing: bool = False + ) -> torch.Tensor: + """ + Swap the attribute specified by the given path to value. + + For example, to swap the attribute mod.layer1.conv1.weight, + use accessor.swap_tensor("layer1.conv1.weight", value) + """ + prefix, _, attr = name.rpartition(".") + return swap_tensor( + self.get_submodule(prefix), attr, value, allow_missing=allow_missing + ) + + # Batched operations + + def get_tensors(self, names: Iterable[str]) -> List[torch.Tensor]: + """ + Get the tensors specified by the given paths. + + For example, to get the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.get_tensors(["layer1.conv1.weight", + "layer1.conv1.bias"]) + """ + return [self.get_tensor(name) for name in names] + + def set_tensors(self, names: Iterable[str], values: Iterable[torch.Tensor]) -> None: + """ + Set the attributes specified by the given paths to values. 
+ + For example, to set the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.set_tensors(["layer1.conv1.weight", + "layer1.conv1.bias"], [weight, bias]) + """ + if not isinstance(names, (list, tuple)): + names = list(names) + if not isinstance(values, (list, tuple)): + values = list(values) + assert len(names) == len(values), "names and values must have the same length" + + for name, value in zip(names, values): + self.set_tensor(name, value) + + def set_tensors_dict(self, named_tensors: Dict[str, torch.Tensor]) -> None: + """ + Set the attributes specified by the given paths to values. + + For example, to set the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.set_tensors_dict({ + "layer1.conv1.weight": weight, + "layer1.conv1.bias": bias, + }) + """ + for name, value in named_tensors.items(): + self.set_tensor(name, value) + + def del_tensors(self, names: Iterable[str]) -> None: + """ + Delete the attributes specified by the given paths. + + For example, to delete the attributes mod.layer1.conv1.weight and + mod.layer1.conv1.bias, use accessor.del_tensors(["layer1.conv1.weight", + "layer1.conv1.bias"]) + """ + for name in names: + self.del_tensor(name) + + def swap_tensors( + self, + names: Iterable[str], + values: Iterable[torch.Tensor], + allow_missing: bool = False, + ) -> List[torch.Tensor]: + """ + Swap the attributes specified by the given paths to values. 
    def swap_tensors_dict(
        self, named_tensors: Dict[str, torch.Tensor], allow_missing: bool = False
    ) -> Tuple[Dict[str, torch.Tensor], List[str]]:
        """
        Swap the attributes specified by the given paths to values.

        Returns a tuple of (dict of original tensors keyed by path, list of
        paths that were missing). The operation is all-or-nothing: if any swap
        raises, or if any path is missing while ``allow_missing`` is False,
        every swap performed so far is rolled back before the error propagates.

        For example, to swap the attributes mod.layer1.conv1.weight and
        mod.layer1.conv1.bias, use accessor.swap_tensors_dict({
            "layer1.conv1.weight": weight,
            "layer1.conv1.bias": bias,
        })
        """
        orig_named_tensors = {}
        missing_keys = []
        try:
            for name, tensor in named_tensors.items():
                # Always swap with allow_missing=True here so a missing key is
                # recorded instead of aborting mid-way; missing entries come
                # back as the _MISSING sentinel.
                orig_tensor = self.swap_tensor(name, tensor, allow_missing=True)
                if orig_tensor is _MISSING:
                    missing_keys.append(name)
                orig_named_tensors[name] = orig_tensor
        except Exception:
            # Swap back if any exception occurs
            # NOTE(review): rollback may pass _MISSING back into swap_tensor —
            # presumably swap_tensor treats that as "restore to absent"; confirm.
            for name, orig_tensor in orig_named_tensors.items():
                self.swap_tensor(name, orig_tensor, allow_missing=True)
            raise
        if missing_keys and not allow_missing:
            # Swap back if any key is missing when allow_missing is False
            for name, orig_tensor in orig_named_tensors.items():
                self.swap_tensor(name, orig_tensor, allow_missing=True)
            raise RuntimeError(f"Missing key(s): {', '.join(map(repr, missing_keys))}.")
        return orig_named_tensors, missing_keys
keys - valid_keys + return sorted(missing_keys), sorted(unexpected_keys) + + # Shortcut methods + + def named_parameters( + self, + remove_duplicate: bool = True, + ) -> Iterable[Tuple[str, torch.Tensor]]: + """Iterate over all the parameters in the module.""" + yield from self.module.named_parameters(remove_duplicate=remove_duplicate) + + def named_buffers( + self, + remove_duplicate: bool = True, + ) -> Iterable[Tuple[str, torch.Tensor]]: + """Iterate over all the buffers in the module.""" + yield from self.module.named_buffers(remove_duplicate=remove_duplicate) + + def named_tensors( + self, + remove_duplicate: bool = True, + ) -> Iterable[Tuple[str, torch.Tensor]]: + """Iterate over all the tensors in the module.""" + yield from self.module.named_parameters(remove_duplicate=remove_duplicate) + yield from self.module.named_buffers(remove_duplicate=remove_duplicate) + + def named_modules( + self, + remove_duplicate: bool = True, + ) -> Iterable[Tuple[str, "torch.nn.Module"]]: + """Iterate over all the modules in the module.""" + yield from self.module.named_modules(remove_duplicate=remove_duplicate) diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/_per_sample_grad.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/_per_sample_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..eeb6e1eeaf3c04b986f2b7d29a89761c66b7e6cd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/_per_sample_grad.py @@ -0,0 +1,124 @@ +# mypy: allow-untyped-defs +import functools + +import torch +from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight +from torch.utils import _pytree as pytree + + +# dependency on `functional_call` means that this can't be exposed in utils +# without creating circular dependency +def call_for_per_sample_grads( + module, + *, + batch_size=None, + loss_reduction="sum", + batch_first=True, +): + r""" + Return a forward function for a module, populating grad_sample with per 
sample gradients on backward invocation. + + Args: + module: The ``nn.Module`` to get per sample gradients with respect to. All trainable + parameters will compute per sample gradients, located in a ``grad_sample`` + field when ``backward`` is invoked + batch_size: The batch size of the input. If None is passed, all tensor arguments in args and kwargs must have + the same batch size, which is the size of the first dimension. Otherwise, it must be passed manually. + Default: None + loss_reduction: Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation. If + "mean", per sample gradients will be scaled by the batch size to offset the crossbatch interaction from + running mean across a batch. Must be "mean" or "sum". Default: "sum" + batch_first: Indicates if the batch dimension is the first dimension. If True, the batch dimension is the first + dimension. If False, it's the second dimension. Default: True. + + Examples:: + >>> # xdoctest: +SKIP + >>> model = nn.Linear(4, 3) + >>> batched_input = torch.randn(5, 4) # batch size of 5 + >>> res = call_for_per_sample_grads(model)(batched_input).sum() + >>> res.backward() + >>> assert model.weight.shape == (3, 4) + >>> assert model.weight.grad_sample.shape == (5, 3, 4) + >>> assert model.weight.grad is None + >>> assert model.bias.shape == (3,) + >>> assert model.bias.grad_sample.shape == (5, 3) + >>> assert model.bias.grad is None + + An example using "mean" loss reduction. The grad_sample fields will be scaled by batch_size from what they would be + if we ran the same code with loss_reduction="sum". This is because the mean at the end will scale all + grad_outputs by 1 / batch_size from cross batch interaction. 
+ >>> model = nn.Linear(4, 3) + >>> batched_input = torch.randn(5, 4) # batch size of 5 + >>> res = call_for_per_sample_grads(model, 5, loss_reduction="mean")(batched_input).mean() + >>> res.backward() + + Note:: + Does not work with any `nn.RNN`, including `nn.GRU` or `nn.LSTM`. Please use custom + rewrites that wrap an `nn.Linear` module. See Opacus for an example + """ + + def maybe_build_expanded_weight(og_tensor, batch_size): + if og_tensor.requires_grad: + return ExpandedWeight(og_tensor, batch_size, loss_reduction) + else: + return og_tensor + + def compute_batch_size(*args, **kwargs): + args_and_kwargs = pytree.arg_tree_leaves(*args, **kwargs) + batch_size = None + for arg in args_and_kwargs: + if not isinstance(arg, torch.Tensor): + continue + + arg_batch_size = arg.shape[0] if batch_first else arg.shape[1] + if batch_size is not None and batch_size != arg_batch_size: + raise RuntimeError( + "When computing batch size, found at least one input with batch size " + f"{batch_size} and one with batch size {arg_batch_size}. Please specify it " + "explicitly using the batch size kwarg in call_for_per_sample_grads" + ) + batch_size = arg_batch_size + if batch_size is None: + raise RuntimeError( + "Unable to find a tensor in the passed args and kwargs. They may not be pytree-able " + "and so ExpandedWeights cannot compute the batch size from the inputs. 
Please specify " + "it explicitly" + ) + return batch_size + + if loss_reduction not in ["sum", "mean"]: + raise RuntimeError( + f"Expected loss_reduction argument to be sum or mean, got {loss_reduction}" + ) + + if not isinstance(module, torch.nn.Module): + raise RuntimeError( + f"Module passed must be nn.Module, got {type(module).__name__}" + ) + if not (batch_size is None or isinstance(batch_size, int)): + raise RuntimeError( + f"Batch size passed must be None or an integer, got {type(batch_size).__name__}" + ) + if batch_size is not None and batch_size < 1: + raise RuntimeError(f"Batch size must be positive, got {batch_size}") + for weight in module.parameters(): + if hasattr(weight, "grad_sample") and weight.grad_sample is not None: # type: ignore[attr-defined] + raise RuntimeError( + "Current Expanded Weights accumulates the gradients, which will be incorrect for multiple " + f"calls without clearing gradients. Please clear out the grad_sample parameter of {weight} or " + "post an issue to pytorch/pytorch to prioritize correct behavior" + ) + + @functools.wraps(module.forward) + def wrapper(*args, **kwargs): + wrapper_batch_size = batch_size + if wrapper_batch_size is None: + wrapper_batch_size = compute_batch_size(*args, **kwargs) + + params = { + name: maybe_build_expanded_weight(value, wrapper_batch_size) + for (name, value) in module.named_parameters() + } + return torch.func.functional_call(module, params, args, kwargs) + + return wrapper diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..ea895b9c959889433472b158e33565b5983286ae --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py @@ -0,0 +1,189 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +from typing import cast, Dict, Iterable, List, Optional, Tuple, Union +from 
def _no_grad(func):
    """
    Wrap *func* so its body executes inside ``torch.no_grad()``.

    This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
    clip_grad_norm_ and clip_grad_value_ themselves.
    """

    def _no_grad_wrapper(*args, **kwargs):
        with torch.no_grad():
            return func(*args, **kwargs)

    # Copy name/docstring/signature metadata so the wrapped function still
    # introspects like the original (equivalent to @functools.wraps).
    functools.update_wrapper(_no_grad_wrapper, func)
    return _no_grad_wrapper
+ """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.0) + first_device = grads[0].device + grouped_grads: Dict[ + Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]] + ] = _group_tensors_by_device_and_dtype( + [grads] + ) # type: ignore[assignment] + + norms: List[Tensor] = [] + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + norms.extend(torch._foreach_norm(device_grads, norm_type)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads]) + + total_norm = torch.linalg.vector_norm( + torch.stack([norm.to(first_device) for norm in norms]), norm_type + ) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. 
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + return total_norm + + +@deprecated( + "`torch.nn.utils.clip_grad_norm` is now deprecated " + "in favor of `torch.nn.utils.clip_grad_norm_`.", + category=FutureWarning, +) +def clip_grad_norm( + parameters: _tensor_or_tensors, + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + .. warning:: + This method is now deprecated in favor of + :func:`torch.nn.utils.clip_grad_norm_`. + """ + return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach) + + +@_no_grad +def clip_grad_value_( + parameters: _tensor_or_tensors, + clip_value: float, + foreach: Optional[bool] = None, +) -> None: + r"""Clip the gradients of an iterable of parameters at specified value. + + Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + clip_value (float): maximum allowed value of the gradients. + The gradients are clipped in the range + :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]` + foreach (bool): use the faster foreach-based implementation + If ``None``, use the foreach implementation for CUDA and CPU native tensors and + silently fall back to the slow implementation for other device types. 
@_no_grad
def clip_grad_value_(
    parameters: _tensor_or_tensors,
    clip_value: float,
    foreach: Optional[bool] = None,
) -> None:
    r"""Clip the gradients of an iterable of parameters at specified value.

    Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        clip_value (float): maximum allowed value of the gradients.
            The gradients are clipped in the range
            :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
        foreach (bool): use the faster foreach-based implementation
            If ``None``, use the foreach implementation for CUDA and CPU native tensors and
            silently fall back to the slow implementation for other device types.
            Default: ``None``
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    clip_value = float(clip_value)

    grads = [p.grad for p in parameters if p.grad is not None]
    # Group by (device, dtype) so each group can use fused foreach clamps.
    grouped_grads = _group_tensors_by_device_and_dtype([grads])

    for (device, _), ([grads], _) in grouped_grads.items():  # type: ignore[assignment]
        if (
            foreach is None
            and _has_foreach_support(cast(List[Tensor], grads), device=device)
        ) or (foreach and _device_has_foreach_support(device)):
            # Two fused clamps implement clamp(-clip_value, clip_value).
            torch._foreach_clamp_min_(cast(List[Tensor], grads), -clip_value)
            torch._foreach_clamp_max_(cast(List[Tensor], grads), clip_value)
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            # Slow path: in-place clamp per gradient tensor.
            for grad in grads:
                cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value)
def vector_to_parameters(vec: torch.Tensor, parameters: Iterable[torch.Tensor]) -> None:
    r"""Copy consecutive slices of a vector into an iterable of parameters.

    Args:
        vec (Tensor): a single vector holding the concatenated, flattened
            parameter values of a model.
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
            parameters of a model.
    """
    if not isinstance(vec, torch.Tensor):
        raise TypeError(f"expected torch.Tensor, but got: {torch.typename(vec)}")

    # All parameters must live on the same device; the flag tracks it.
    device_flag = None
    offset = 0
    for param in parameters:
        device_flag = _check_param_device(param, device_flag)

        count = param.numel()
        # Carve out this parameter's slice, reshape it, and install the data.
        param.data = vec[offset : offset + count].view_as(param).data
        offset += count
def _check_param_device(param: torch.Tensor, old_param_device: Optional[int]) -> int:
    r"""Check if the parameters are located on the same device.

    Currently, the conversion between model parameters and single vector form is not supported
    for multiple allocations, e.g. parameters in different GPUs/PrivateUse1s, or mixture of CPU/GPU/PrivateUse1.

    Args:
        param ([Tensor]): a Tensor of a parameter of a model
        old_param_device (int): the device where the first parameter of a
                                model is allocated.

    Returns:
        old_param_device (int): device index of the first parameter seen
            (``-1`` encodes "CPU / unindexed device").

    Raises:
        TypeError: if *param* lives on a different device than the first one.
    """
    # Meet the first parameter
    support_device_types = ["cuda", torch._C._get_privateuse1_backend_name()]
    if old_param_device is None:
        # First call: record the device index, or -1 for CPU-like devices.
        old_param_device = (
            param.get_device() if param.device.type in support_device_types else -1
        )
    else:
        warn = False
        if (
            param.device.type in support_device_types
        ):  # Check if in same GPU/PrivateUse1
            warn = param.get_device() != old_param_device
        else:  # Check if in CPU
            warn = old_param_device != -1
        if warn:
            raise TypeError(
                "Found two parameters on different devices, "
                "this is currently not supported."
            )
    return old_param_device
def fuse_conv_bn_eval(
    conv: ConvT,
    bn: torch.nn.modules.batchnorm._BatchNorm,
    transpose: bool = False,
) -> ConvT:
    r"""Fold an eval-mode BatchNorm into a copy of the given convolution.

    Args:
        conv (torch.nn.modules.conv._ConvNd): A convolutional module.
        bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.
        transpose (bool, optional): If True, transpose the convolutional weight. Defaults to False.

    Returns:
        torch.nn.modules.conv._ConvNd: A new convolutional module with the
        BatchNorm folded into its weight and bias; the inputs are not mutated.

    .. note::
        Both ``conv`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
    """
    assert not (conv.training or bn.training), "Fusion only for eval!"
    merged = copy.deepcopy(conv)

    assert bn.running_mean is not None and bn.running_var is not None
    merged.weight, merged.bias = fuse_conv_bn_weights(
        merged.weight,
        merged.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
        transpose,
    )

    return merged
+ """ + conv_weight_dtype = conv_w.dtype + conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype + if conv_b is None: + conv_b = torch.zeros_like(bn_rm) + if bn_w is None: + bn_w = torch.ones_like(bn_rm) + if bn_b is None: + bn_b = torch.zeros_like(bn_rm) + bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) + + if transpose: + shape = [1, -1] + [1] * (len(conv_w.shape) - 2) + else: + shape = [-1, 1] + [1] * (len(conv_w.shape) - 2) + + fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to( + dtype=conv_weight_dtype + ) + fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to( + dtype=conv_bias_dtype + ) + + return ( + torch.nn.Parameter(fused_conv_w, conv_w.requires_grad), + torch.nn.Parameter(fused_conv_b, conv_b.requires_grad), + ) + + +def fuse_linear_bn_eval( + linear: LinearT, + bn: torch.nn.modules.batchnorm._BatchNorm, +) -> LinearT: + r"""Fuse a linear module and a BatchNorm module into a single, new linear module. + + Args: + linear (torch.nn.Linear): A Linear module. + bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module. + + Returns: + torch.nn.Linear: The fused linear module. + + .. note:: + Both ``linear`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed. + """ + assert not (linear.training or bn.training), "Fusion only for eval!" + fused_linear = copy.deepcopy(linear) + + """ + Linear-BN needs to be fused while preserving the shapes of linear weight/bias. + To preserve the shapes of linear weight/bias, the channel dim of bn needs to be broadcastable with the last dim of linear, + because bn operates over the channel dim, (N, C_in, H, W) while linear operates over the last dim, (*, H_in). + To be broadcastable, the number of features in bn and + the number of output features from linear must satisfy the following condition: + 1. they are equal, or + 2. 
def fuse_linear_bn_eval(
    linear: LinearT,
    bn: torch.nn.modules.batchnorm._BatchNorm,
) -> LinearT:
    r"""Fold an eval-mode BatchNorm into a copy of the given linear module.

    Args:
        linear (torch.nn.Linear): A Linear module.
        bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.

    Returns:
        torch.nn.Linear: A new linear module with the BatchNorm folded in;
        the inputs are not mutated.

    .. note::
        Both ``linear`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
    """
    assert not (linear.training or bn.training), "Fusion only for eval!"
    merged = copy.deepcopy(linear)

    # The linear weight/bias shapes are preserved, so BN's channel dim must be
    # broadcastable with linear's output features: bn operates over the
    # channel dim (N, C, H, W) while linear operates over the last dim
    # (*, H_in). That requires either equal feature counts or a single BN
    # feature; anything else cannot be folded.
    assert (
        linear.out_features == bn.num_features or bn.num_features == 1
    ), "To fuse, linear.out_features == bn.num_features or bn.num_features == 1"

    assert bn.running_mean is not None and bn.running_var is not None
    merged.weight, merged.bias = fuse_linear_bn_weights(
        merged.weight,
        merged.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
    )

    return merged
+ """ + linear_weight_dtype = linear_w.dtype + linear_bias_dtype = linear_b.dtype if linear_b is not None else linear_weight_dtype + if linear_b is None: + linear_b = torch.zeros_like(bn_rm) + bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps) + + fused_w = linear_w * bn_scale.unsqueeze(-1).to(dtype=linear_weight_dtype) + fused_b = ((linear_b - bn_rm) * bn_scale + bn_b).to(dtype=linear_bias_dtype) + + return torch.nn.Parameter(fused_w, linear_w.requires_grad), torch.nn.Parameter( + fused_b, linear_b.requires_grad + ) diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/init.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/init.py new file mode 100644 index 0000000000000000000000000000000000000000..10fa03b7c01c2eac7e474ef55f433e4704e6c778 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/init.py @@ -0,0 +1,55 @@ +# mypy: allow-untyped-defs +import inspect + +import torch + + +def skip_init(module_cls, *args, **kwargs): + r""" + Given a module class object and args / kwargs, instantiate the module without initializing parameters / buffers. + + This can be useful if initialization is slow or if custom initialization will + be performed, making the default initialization unnecessary. There are some caveats to this, due to + the way this function is implemented: + + 1. The module must accept a `device` arg in its constructor that is passed to any parameters + or buffers created during construction. + + 2. The module must not perform any computation on parameters in its constructor except + initialization (i.e. functions from :mod:`torch.nn.init`). + + If these conditions are satisfied, the module can be instantiated with parameter / buffer values + uninitialized, as if having been created using :func:`torch.empty`. 
def skip_init(module_cls, *args, **kwargs):
    r"""
    Instantiate a module class without running parameter/buffer initialization.

    Useful when initialization is slow or will be overwritten by custom code.
    Two caveats follow from the implementation (construct on the ``meta``
    device, then materialize with :meth:`~torch.nn.Module.to_empty`):

    1. The module must accept a `device` arg in its constructor that is passed to any parameters
       or buffers created during construction.

    2. The module must not perform any computation on parameters in its constructor except
       initialization (i.e. functions from :mod:`torch.nn.init`).

    If these conditions are satisfied, the module can be instantiated with parameter / buffer values
    uninitialized, as if having been created using :func:`torch.empty`.

    Args:
        module_cls: Class object; should be a subclass of :class:`torch.nn.Module`
        args: args to pass to the module's constructor
        kwargs: kwargs to pass to the module's constructor

    Returns:
        Instantiated module with uninitialized parameters / buffers

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> import torch
        >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
        >>> m.weight
        Parameter containing:
        tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
                requires_grad=True)
    """
    if not issubclass(module_cls, torch.nn.Module):
        raise RuntimeError(f"Expected a Module; got {module_cls}")
    if "device" not in inspect.signature(module_cls).parameters:
        raise RuntimeError("Module must support a 'device' arg to skip initialization")

    # Build on the meta device (no real storage is allocated), then
    # materialize empty storage on the device the caller asked for.
    target_device = kwargs.pop("device", "cpu")
    kwargs["device"] = "meta"
    return module_cls(*args, **kwargs).to_empty(device=target_device)
def convert_conv2d_weight_memory_format(module, memory_format):
    r"""Convert the ``memory_format`` of every ``nn.Conv2d`` / ``nn.ConvTranspose2d``
    weight found in ``module`` (recursively, including ``module`` itself).

    Only the weight's memory layout changes; dimension semantics are untouched.
    This facilitates NHWC (channels_last) convolution kernels, which give a
    considerable speed up for fp16 data on CUDA devices with compute
    capability >= 7.0.

    .. note::
        This is deliberately narrower than
        ``model.to(memory_format=torch.channels_last)``, which touches every
        layer with a 4d weight whether or not it benefits. Converting only
        convolution weights ensures fast conv kernels are used while layers
        that are agnostic to ``memory_format`` are left alone; inputs are
        permuted lazily at the first convolution they reach and stay in that
        layout until a channels-last-incompatible layer forces them back to
        contiguous.

    Args:
        module (nn.Module): ``nn.Conv2d`` & ``nn.ConvTranspose2d`` or container
            ``nn.Module``
        memory_format: user specified ``memory_format``,
            e.g. ``torch.channels_last`` or ``torch.contiguous_format``

    Returns:
        The original module with updated ``nn.Conv2d``

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
        >>> input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda")
        >>> model = nn.Sequential(
        >>>     nn.Conv2d(8, 4, 3)).cuda().half()
        >>> model = nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)
        >>> out = model(input)
    """
    # TODO: expand this to `_ConvNd` when channels_last support is extended
    # beyond only 4d tensors.
    if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        converted = (
            module.weight.detach().clone().contiguous(memory_format=memory_format)
        )
        module.weight.data = converted.resize_(
            converted.size(), memory_format=memory_format
        )
    for submodule in module.children():
        convert_conv2d_weight_memory_format(submodule, memory_format)
    return module
note:: + Calling ``model.to(memory_format=torch.channels_last_3d)`` is more aggressive + than the utility function ``convert_conv3d_weight_memory_format``. Any + layer with 4d weight will be affected by ``model.to``, which does not + necessarily benefit from conversion to specified ``memory_format``. + One place we are confident in is that NDHWC(channels_last_3d) conversion for + convolution in cuDNN, as it is beneficial to run convolution in NDHWC, + even in cases where we have to apply permutation to input tensors. + + Hence our strategy here is to convert only the weight of convolution to + channels_last_3d. This ensures that; + 1. Fast convolution kernels will be used, the benefit of which could + outweigh overhead of permutation (if input is not in the same format). + 2. No unnecessary permutations are applied on layers that do not benefit + from memory_format conversion. + + The optimal case is that, layers between convolution layers are channels + last compatible. Input tensor would be permuted to channels last when it + encounters the first convolution layer and stay in that memory format. + Hence following convolutions will not need to permute its input tensor. + + In case where a channels last incompatible layer is between convolution + layers, we need to permute the input tensor back to contiguous format + for that layer. The input tensor will go through the remaining layers in + contiguous format and be permuted to channels last when it encounters + another convolution layer. There's no point in propagating that + permutation to an earlier layer, as most layers are quite agnostic to + ``memory_format``. + + This claim might change when PyTorch supports fusion of permutation, as + there might have been a better spot to fuse the permutation other than + immediately before a convolution. + + Args: + module (nn.Module): ``nn.Conv3d`` & ``nn.ConvTranspose3d`` or container + ``nn.Module`` + memory_format: user specified ``memory_format``, + e.g. 
``torch.channels_last`` or ``torch.contiguous_format`` + + Returns: + The original module with updated ``nn.Conv3d`` + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG) + >>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda") + >>> model = nn.Sequential( + >>> nn.Conv3d(8, 4, 3)).cuda().half() + >>> # This is identical to: + >>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) + >>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) + >>> out = model(input) + """ + + # TODO: expand this to `_ConvNd` when channels_last support is extended + # beyond only 4d tensors. + if isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)): + weight_data = ( + module.weight.detach().clone().contiguous(memory_format=memory_format) + ) + module.weight.data = weight_data.resize_( + weight_data.size(), memory_format=memory_format + ) + for child in module.children(): + convert_conv3d_weight_memory_format(child, memory_format) + return module diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/parametrizations.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/parametrizations.py new file mode 100644 index 0000000000000000000000000000000000000000..5a371af995b68a40cf4288f9f7fbd69b87da8083 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/parametrizations.py @@ -0,0 +1,628 @@ +# mypy: allow-untyped-defs +from enum import auto, Enum +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn.modules import Module +from torch.nn.utils import parametrize + + +__all__ = ["orthogonal", "spectral_norm", "weight_norm"] + + +def _is_orthogonal(Q, eps=None): + n, k = Q.size(-2), Q.size(-1) + Id = torch.eye(k, dtype=Q.dtype, device=Q.device) + # A reasonable eps, but not too large + eps = 10.0 * n * torch.finfo(Q.dtype).eps + return 
torch.allclose(Q.mH @ Q, Id, atol=eps) + + +def _make_orthogonal(A): + """Assume that A is a tall matrix. + + Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative. + """ + X, tau = torch.geqrf(A) + Q = torch.linalg.householder_product(X, tau) + # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs + Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2) + return Q + + +class _OrthMaps(Enum): + matrix_exp = auto() + cayley = auto() + householder = auto() + + +class _Orthogonal(Module): + base: Tensor + + def __init__( + self, weight, orthogonal_map: _OrthMaps, *, use_trivialization=True + ) -> None: + super().__init__() + + # Note [Householder complex] + # For complex tensors, it is not possible to compute the tensor `tau` necessary for + # linalg.householder_product from the reflectors. + # To see this, note that the reflectors have a shape like: + # 0 0 0 + # * 0 0 + # * * 0 + # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters + # to parametrize the unitary matrices. Saving tau on its own does not work either, because + # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise + # them as independent tensors we would not maintain the constraint + # An equivalent reasoning holds for rectangular matrices + if weight.is_complex() and orthogonal_map == _OrthMaps.householder: + raise ValueError( + "The householder parametrization does not support complex tensors." 
+ ) + + self.shape = weight.shape + self.orthogonal_map = orthogonal_map + if use_trivialization: + self.register_buffer("base", None) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + n, k = X.size(-2), X.size(-1) + transposed = n < k + if transposed: + X = X.mT + n, k = k, n + # Here n > k and X is a tall matrix + if ( + self.orthogonal_map == _OrthMaps.matrix_exp + or self.orthogonal_map == _OrthMaps.cayley + ): + # We just need n x k - k(k-1)/2 parameters + X = X.tril() + if n != k: + # Embed into a square matrix + X = torch.cat( + [X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1 + ) + A = X - X.mH + # A is skew-symmetric (or skew-hermitian) + if self.orthogonal_map == _OrthMaps.matrix_exp: + Q = torch.matrix_exp(A) + elif self.orthogonal_map == _OrthMaps.cayley: + # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1} + Id = torch.eye(n, dtype=A.dtype, device=A.device) + Q = torch.linalg.solve( + torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5) + ) + # Q is now orthogonal (or unitary) of size (..., n, n) + if n != k: + Q = Q[..., :k] + # Q is now the size of the X (albeit perhaps transposed) + else: + # X is real here, as we do not support householder with complex numbers + A = X.tril(diagonal=-1) + tau = 2.0 / (1.0 + (A * A).sum(dim=-2)) + Q = torch.linalg.householder_product(A, tau) + # The diagonal of X is 1's and -1's + # We do not want to differentiate through this or update the diagonal of X hence the casting + Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2) + + if hasattr(self, "base"): + Q = self.base @ Q + if transposed: + Q = Q.mT + return Q # type: ignore[possibly-undefined] + + @torch.autograd.no_grad() + def right_inverse(self, Q: torch.Tensor) -> torch.Tensor: + if Q.shape != self.shape: + raise ValueError( + f"Expected a matrix or batch of matrices of shape {self.shape}. " + f"Got a tensor of shape {Q.shape}." 
+ ) + + Q_init = Q + n, k = Q.size(-2), Q.size(-1) + transpose = n < k + if transpose: + Q = Q.mT + n, k = k, n + + # We always make sure to always copy Q in every path + if not hasattr(self, "base"): + # Note [right_inverse expm cayley] + # If we do not have use_trivialization=True, we just implement the inverse of the forward + # map for the Householder. To see why, think that for the Cayley map, + # we would need to find the matrix X \in R^{n x k} such that: + # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1) + # A = Y - Y.mH + # cayley(A)[:, :k] + # gives the original tensor. It is not clear how to do this. + # Perhaps via some algebraic manipulation involving the QR like that of + # Corollary 2.2 in Edelman, Arias and Smith? + if ( + self.orthogonal_map == _OrthMaps.cayley + or self.orthogonal_map == _OrthMaps.matrix_exp + ): + raise NotImplementedError( + "It is not possible to assign to the matrix exponential " + "or the Cayley parametrizations when use_trivialization=False." + ) + + # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition. + # Here Q is always real because we do not support householder and complex matrices. 
+ # See note [Householder complex] + A, tau = torch.geqrf(Q) + # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could + # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition + # The diagonal of Q is the diagonal of R from the qr decomposition + A.diagonal(dim1=-2, dim2=-1).sign_() + # Equality with zero is ok because LAPACK returns exactly zero when it does not want + # to use a particular reflection + A.diagonal(dim1=-2, dim2=-1)[tau == 0.0] *= -1 + return A.mT if transpose else A + else: + if n == k: + # We check whether Q is orthogonal + if not _is_orthogonal(Q): + Q = _make_orthogonal(Q) + else: # Is orthogonal + Q = Q.clone() + else: + # Complete Q into a full n x n orthogonal matrix + N = torch.randn( + *(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device + ) + Q = torch.cat([Q, N], dim=-1) + Q = _make_orthogonal(Q) + self.base = Q + + # It is necessary to return the -Id, as we use the diagonal for the + # Householder parametrization. Using -Id makes: + # householder(torch.zeros(m,n)) == torch.eye(m,n) + # Poor man's version of eye_like + neg_Id = torch.zeros_like(Q_init) + neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.0) + return neg_Id + + +def orthogonal( + module: Module, + name: str = "weight", + orthogonal_map: Optional[str] = None, + *, + use_trivialization: bool = True, +) -> Module: + r"""Apply an orthogonal or unitary parametrization to a matrix or a batch of matrices. + + Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized + matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as + + .. 
math:: + + \begin{align*} + Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\ + QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n} + \end{align*} + + where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex + and the transpose when :math:`Q` is real-valued, and + :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. + In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n` + and orthonormal rows otherwise. + + If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`. + + The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` in terms of the original tensor: + + - ``"matrix_exp"``/``"cayley"``: + the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_ + :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric + :math:`A` to give an orthogonal matrix. + - ``"householder"``: computes a product of Householder reflectors + (:func:`~torch.linalg.householder_product`). + + ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than + ``"householder"``, but they are slower to compute for very thin or very wide matrices. + + If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework", + where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under + ``module.parametrizations.weight[0].base``. This helps the + convergence of the parametrized layer at the expense of some extra memory use. + See `Trivializations for Gradient-Based Optimization on Manifolds`_ . 
+ + Initial value of :math:`Q`: + If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value + of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case) + and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`). + Same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``. + Otherwise, the initial value is the result of the composition of all the registered + parametrizations applied to the original tensor. + + .. note:: + This function is implemented using the parametrization functionality + in :func:`~torch.nn.utils.parametrize.register_parametrization`. + + + .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map + .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501 + + Args: + module (nn.Module): module on which to register the parametrization. + name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``. + orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``. + Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise. + use_trivialization (bool, optional): whether to use the dynamic trivialization framework. + Default: ``True``. 
+ + Returns: + The original module with an orthogonal parametrization registered to the specified + weight + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> orth_linear = orthogonal(nn.Linear(20, 40)) + >>> orth_linear + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _Orthogonal() + ) + ) + ) + >>> # xdoctest: +IGNORE_WANT + >>> Q = orth_linear.weight + >>> torch.dist(Q.T @ Q, torch.eye(20)) + tensor(4.9332e-07) + """ + weight = getattr(module, name, None) + if not isinstance(weight, Tensor): + raise ValueError( + f"Module '{module}' has no parameter or buffer with name '{name}'" + ) + + # We could implement this for 1-dim tensors as the maps on the sphere + # but I believe it'd bite more people than it'd help + if weight.ndim < 2: + raise ValueError( + "Expected a matrix or batch of matrices. " + f"Got a tensor of {weight.ndim} dimensions." + ) + + if orthogonal_map is None: + orthogonal_map = ( + "matrix_exp" + if weight.size(-2) == weight.size(-1) or weight.is_complex() + else "householder" + ) + + orth_enum = getattr(_OrthMaps, orthogonal_map, None) + if orth_enum is None: + raise ValueError( + 'orthogonal_map has to be one of "matrix_exp", "cayley", "householder". 
' + f"Got: {orthogonal_map}" + ) + orth = _Orthogonal(weight, orth_enum, use_trivialization=use_trivialization) + parametrize.register_parametrization(module, name, orth, unsafe=True) + return module + + +class _WeightNorm(Module): + def __init__( + self, + dim: Optional[int] = 0, + ) -> None: + super().__init__() + if dim is None: + dim = -1 + self.dim = dim + + def forward(self, weight_g, weight_v): + return torch._weight_norm(weight_v, weight_g, self.dim) + + def right_inverse(self, weight): + weight_g = torch.norm_except_dim(weight, 2, self.dim) + weight_v = weight + + return weight_g, weight_v + + +def weight_norm(module: Module, name: str = "weight", dim: int = 0): + r"""Apply weight normalization to a parameter in the given module. + + .. math:: + \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} + + Weight normalization is a reparameterization that decouples the magnitude + of a weight tensor from its direction. This replaces the parameter specified + by :attr:`name` with two parameters: one specifying the magnitude + and one specifying the direction. + + By default, with ``dim=0``, the norm is computed independently per output + channel/plane. To compute a norm over the entire weight tensor, use + ``dim=None``. 
+ + See https://arxiv.org/abs/1602.07868 + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + dim (int, optional): dimension over which to compute the norm + + Returns: + The original module with the weight norm hook + + Example:: + + >>> m = weight_norm(nn.Linear(20, 40), name='weight') + >>> m + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _WeightNorm() + ) + ) + ) + >>> m.parametrizations.weight.original0.size() + torch.Size([40, 1]) + >>> m.parametrizations.weight.original1.size() + torch.Size([40, 20]) + + """ + _weight_norm = _WeightNorm(dim) + parametrize.register_parametrization(module, name, _weight_norm, unsafe=True) + + def _weight_norm_compat_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + g_key = f"{prefix}{name}_g" + v_key = f"{prefix}{name}_v" + if g_key in state_dict and v_key in state_dict: + original0 = state_dict.pop(g_key) + original1 = state_dict.pop(v_key) + state_dict[f"{prefix}parametrizations.{name}.original0"] = original0 + state_dict[f"{prefix}parametrizations.{name}.original1"] = original1 + + module._register_load_state_dict_pre_hook(_weight_norm_compat_hook) + return module + + +class _SpectralNorm(Module): + def __init__( + self, + weight: torch.Tensor, + n_power_iterations: int = 1, + dim: int = 0, + eps: float = 1e-12, + ) -> None: + super().__init__() + ndim = weight.ndim + if dim >= ndim or dim < -ndim: + raise IndexError( + "Dimension out of range (expected to be in range of " + f"[-{ndim}, {ndim - 1}] but got {dim})" + ) + + if n_power_iterations <= 0: + raise ValueError( + "Expected n_power_iterations to be positive, but " + f"got n_power_iterations={n_power_iterations}" + ) + self.dim = dim if dim >= 0 else dim + ndim + self.eps = eps + if ndim > 1: + # For ndim == 1 we do not need to approximate anything (see 
_SpectralNorm.forward) + self.n_power_iterations = n_power_iterations + weight_mat = self._reshape_weight_to_matrix(weight) + h, w = weight_mat.size() + + u = weight_mat.new_empty(h).normal_(0, 1) + v = weight_mat.new_empty(w).normal_(0, 1) + self.register_buffer("_u", F.normalize(u, dim=0, eps=self.eps)) + self.register_buffer("_v", F.normalize(v, dim=0, eps=self.eps)) + + # Start with u, v initialized to some reasonable values by performing a number + # of iterations of the power method + self._power_method(weight_mat, 15) + + def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor: + # Precondition + assert weight.ndim > 1 + + if self.dim != 0: + # permute dim to front + weight = weight.permute( + self.dim, *(d for d in range(weight.dim()) if d != self.dim) + ) + + return weight.flatten(1) + + @torch.autograd.no_grad() + def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None: + # See original note at torch/nn/utils/spectral_norm.py + # NB: If `do_power_iteration` is set, the `u` and `v` vectors are + # updated in power iteration **in-place**. This is very important + # because in `DataParallel` forward, the vectors (being buffers) are + # broadcast from the parallelized module to each module replica, + # which is a new module object created on the fly. And each replica + # runs its own spectral norm power iteration. So simply assigning + # the updated vectors to the module this function runs on will cause + # the update to be lost forever. And the next time the parallelized + # module is replicated, the same randomly initialized vectors are + # broadcast and used! + # + # Therefore, to make the change propagate back, we rely on two + # important behaviors (also enforced via tests): + # 1. `DataParallel` doesn't clone storage if the broadcast tensor + # is already on correct device; and it makes sure that the + # parallelized module is already on `device[0]`. + # 2. 
If the out tensor in `out=` kwarg has correct shape, it will + # just fill in the values. + # Therefore, since the same power iteration is performed on all + # devices, simply updating the tensors in-place will make sure that + # the module replica on `device[0]` will update the _u vector on the + # parallelized module (by shared storage). + # + # However, after we update `u` and `v` in-place, we need to **clone** + # them before using them to normalize the weight. This is to support + # backproping through two forward passes, e.g., the common pattern in + # GAN training: loss = D(real) - D(fake). Otherwise, engine will + # complain that variables needed to do backward for the first forward + # (i.e., the `u` and `v` vectors) are changed in the second forward. + + # Precondition + assert weight_mat.ndim > 1 + + for _ in range(n_power_iterations): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. 
+ self._u = F.normalize( + torch.mv(weight_mat, self._v), # type: ignore[has-type] + dim=0, + eps=self.eps, + out=self._u, # type: ignore[has-type] + ) + self._v = F.normalize( + torch.mv(weight_mat.H, self._u), # type: ignore[has-type] + dim=0, + eps=self.eps, + out=self._v, # type: ignore[has-type] + ) + + def forward(self, weight: torch.Tensor) -> torch.Tensor: + if weight.ndim == 1: + # Faster and more exact path, no need to approximate anything + return F.normalize(weight, dim=0, eps=self.eps) + else: + weight_mat = self._reshape_weight_to_matrix(weight) + if self.training: + self._power_method(weight_mat, self.n_power_iterations) + # See above on why we need to clone + u = self._u.clone(memory_format=torch.contiguous_format) + v = self._v.clone(memory_format=torch.contiguous_format) + # The proper way of computing this should be through F.bilinear, but + # it seems to have some efficiency issues: + # https://github.com/pytorch/pytorch/issues/58093 + sigma = torch.vdot(u, torch.mv(weight_mat, v)) + return weight / sigma + + def right_inverse(self, value: torch.Tensor) -> torch.Tensor: + # we may want to assert here that the passed value already + # satisfies constraints + return value + + +def spectral_norm( + module: Module, + name: str = "weight", + n_power_iterations: int = 1, + eps: float = 1e-12, + dim: Optional[int] = None, +) -> Module: + r"""Apply spectral normalization to a parameter in the given module. + + .. math:: + \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, + \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + + When applied on a vector, it simplifies to + + .. math:: + \mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2} + + Spectral normalization stabilizes the training of discriminators (critics) + in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant + of the model. 
:math:`\sigma` is approximated performing one iteration of the + `power method`_ every time the weight is accessed. If the dimension of the + weight tensor is greater than 2, it is reshaped to 2D in power iteration + method to get spectral norm. + + + See `Spectral Normalization for Generative Adversarial Networks`_ . + + .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 + + .. note:: + This function is implemented using the parametrization functionality + in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a + reimplementation of :func:`torch.nn.utils.spectral_norm`. + + .. note:: + When this constraint is registered, the singular vectors associated to the largest + singular value are estimated rather than sampled at random. These are then updated + performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor + is accessed with the module on `training` mode. + + .. note:: + If the `_SpectralNorm` module, i.e., `module.parametrization.weight[idx]`, + is in training mode on removal, it will perform another power iteration. + If you'd like to avoid this iteration, set the module to eval mode + before its removal. + + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter. Default: ``"weight"``. + n_power_iterations (int, optional): number of power iterations to + calculate spectral norm. Default: ``1``. + eps (float, optional): epsilon for numerical stability in + calculating norms. Default: ``1e-12``. + dim (int, optional): dimension corresponding to number of outputs. 
+ Default: ``0``, except for modules that are instances of + ConvTranspose{1,2,3}d, when it is ``1`` + + Returns: + The original module with a new parametrization registered to the specified + weight + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> snm = spectral_norm(nn.Linear(20, 40)) + >>> snm + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _SpectralNorm() + ) + ) + ) + >>> torch.linalg.matrix_norm(snm.weight, 2) + tensor(1.0081, grad_fn=) + """ + weight = getattr(module, name, None) + if not isinstance(weight, Tensor): + raise ValueError( + f"Module '{module}' has no parameter or buffer with name '{name}'" + ) + + if dim is None: + if isinstance( + module, + ( + torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, + ), + ): + dim = 1 + else: + dim = 0 + parametrize.register_parametrization( + module, name, _SpectralNorm(weight, n_power_iterations, dim, eps) + ) + return module diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/parametrize.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/parametrize.py new file mode 100644 index 0000000000000000000000000000000000000000..d4946604bcb3925b01f5e5abe86b34d9126e6240 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/parametrize.py @@ -0,0 +1,819 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import collections +import copyreg +from contextlib import contextmanager +from copy import deepcopy +from typing import Dict, Optional, Sequence, Tuple, Union + +import torch +from torch import Tensor +from torch.__future__ import get_swap_module_params_on_conversion +from torch.nn.modules.container import Module, ModuleDict, ModuleList +from torch.nn.parameter import Parameter +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + +__all__ = [ + "cached", 
+ "ParametrizationList", + "register_parametrization", + "is_parametrized", + "remove_parametrizations", + "type_before_parametrizations", + "transfer_parametrizations_and_params", +] + +_cache_enabled = 0 +_cache: Dict[Tuple[int, str], Optional[Tensor]] = {} + + +@contextmanager +def cached(): + r"""Context manager that enables the caching system within parametrizations registered with :func:`register_parametrization`. + + The value of the parametrized objects is computed and cached the first time + they are required when this context manager is active. The cached values are + discarded when leaving the context manager. + + This is useful when using a parametrized parameter more than once in the forward pass. + An example of this is when parametrizing the recurrent kernel of an RNN or when + sharing weights. + + The simplest way to activate the cache is by wrapping the forward pass of the neural network + + .. code-block:: python + + import torch.nn.utils.parametrize as P + ... + with P.cached(): + output = model(inputs) + + in training and evaluation. One may also wrap the parts of the modules that use + several times the parametrized tensors. For example, the loop of an RNN with a + parametrized recurrent kernel: + + .. 
code-block:: python + + with P.cached(): + for x in xs: + out_rnn = self.rnn_cell(x, out_rnn) + """ + global _cache + global _cache_enabled + _cache_enabled += 1 + try: + yield + finally: + _cache_enabled -= 1 + if not _cache_enabled: + _cache = {} + + +def _register_parameter_or_buffer(module, name, X): + if isinstance(X, Parameter): + module.register_parameter(name, X) + else: + module.register_buffer(name, X) + + +def _maybe_set(dest: Tensor, src: Tensor) -> None: + should_swap = ( + get_swap_module_params_on_conversion() or is_traceable_wrapper_subclass(dest) + ) + if should_swap: + if isinstance(dest, Parameter) and not isinstance(src, Parameter): + src = Parameter(src, requires_grad=dest.requires_grad) + torch.utils.swap_tensors(dest, src) + else: + dest.set_(src) # type: ignore[call-overload] + + +class ParametrizationList(ModuleList): + r"""A sequential container that holds and manages the original parameters or buffers of a parametrized :class:`torch.nn.Module`. + + It is the type of ``module.parametrizations[tensor_name]`` when ``module[tensor_name]`` + has been parametrized with :func:`register_parametrization`. + + If the first registered parametrization has a ``right_inverse`` that returns one tensor or + does not have a ``right_inverse`` (in which case we assume that ``right_inverse`` is the identity), + it will hold the tensor under the name ``original``. + If it has a ``right_inverse`` that returns more than one tensor, these will be registered as + ``original0``, ``original1``, ... + + .. warning:: + This class is used internally by :func:`register_parametrization`. It is documented + here for completeness. It shall not be instantiated by the user. + + Args: + modules (sequence): sequence of modules representing the parametrizations + original (Parameter or Tensor): parameter or buffer that is parametrized + unsafe (bool): a boolean flag that denotes whether the parametrization + may change the dtype and shape of the tensor. 
    # ``original`` exists when the first parametrization takes a single tensor;
    # with several inputs the attributes are ``original0``, ``original1``, ...
    original: Tensor
    # Whether the registration-time consistency checks were skipped.
    unsafe: bool

    def __init__(
        self,
        modules: Sequence[Module],
        original: Union[Tensor, Parameter],
        unsafe: bool = False,
    ) -> None:
        """Initialize from the parametrization ``modules`` and the unparametrized ``original`` tensor.

        Runs ``right_inverse`` of the chain on ``original`` (when implemented),
        registers the resulting tensor(s) as ``original`` / ``original{i}`` and,
        unless ``unsafe`` is set, checks that the round trip through the
        parametrization preserves the dtype and shape of ``original``.
        """
        # We require this because we need to treat differently the first parametrization
        # This should never throw, unless this class is used from the outside
        if len(modules) == 0:
            raise ValueError("ParametrizationList requires one or more modules.")

        super().__init__(modules)
        self.unsafe = unsafe

        # In plain words:
        # module.weight must keep its dtype and shape.
        # Furthermore, if there is no right_inverse or the right_inverse returns a tensor,
        # this should be of the same dtype as the original tensor
        #
        # We check that the following invariants hold:
        #    X = module.weight
        #    Y = param.right_inverse(X)
        #    assert isinstance(Y, Tensor) or
        #           (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y))
        #    Z = param(Y) if isinstance(Y, Tensor) else param(*Y)
        #    # Consistency checks
        #    assert X.dtype == Z.dtype and X.shape == Z.shape
        #    # If it has one input, this allows to be able to use set_ to be able to
        #    # move data to/from the original tensor without changing its id (which is what the
        #    # optimizer uses to track parameters)
        #    if isinstance(Y, Tensor)
        #      assert X.dtype == Y.dtype
        # Below we use original = X, new = Y

        original_shape = original.shape
        original_dtype = original.dtype

        # Compute new
        with torch.no_grad():
            new = original
            for module in reversed(self):  # type: ignore[call-overload]
                if hasattr(module, "right_inverse"):
                    try:
                        new = module.right_inverse(new)
                    except NotImplementedError:
                        pass
                # else, or if it throws, we assume that right_inverse is the identity

        if not isinstance(new, Tensor) and not isinstance(
            new, collections.abc.Sequence
        ):
            raise ValueError(
                "'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
                f"Got {type(new).__name__}"
            )

        # Set the number of original tensors
        self.is_tensor = isinstance(new, Tensor)
        self.ntensors = 1 if self.is_tensor else len(new)

        # Register the tensor(s)
        if self.is_tensor:
            if original.dtype != new.dtype:
                raise ValueError(
                    "When `right_inverse` outputs one tensor, it may not change the dtype.\n"
                    f"original.dtype: {original.dtype}\n"
                    f"right_inverse(original).dtype: {new.dtype}"
                )
            # Set the original to original so that the user does not need to re-register the parameter
            # manually in the optimiser
            with torch.no_grad():
                _maybe_set(original, new)
            _register_parameter_or_buffer(self, "original", original)
        else:
            for i, originali in enumerate(new):
                if not isinstance(originali, Tensor):
                    raise ValueError(
                        "'right_inverse' must return a Tensor or a Sequence of tensors "
                        "(list, tuple...). "
                        f"Got element {i} of the sequence with type {type(originali).__name__}."
                    )

                # If the original tensor was a Parameter that required grad, we expect the user to
                # add the new parameters to the optimizer after registering the parametrization
                # (this is documented)
                if isinstance(original, Parameter):
                    originali = Parameter(originali, original.requires_grad)
                originali.requires_grad_(original.requires_grad)
                _register_parameter_or_buffer(self, f"original{i}", originali)

        if not self.unsafe:
            # Consistency checks:
            # Since f : A -> B, right_inverse : B -> A, Z and original should live in B
            # Z = forward(right_inverse(original))
            Z = self()
            if not isinstance(Z, Tensor):
                raise ValueError(
                    f"A parametrization must return a tensor. Got {type(Z).__name__}."
                )
            if Z.dtype != original_dtype:
                raise ValueError(
                    "Registering a parametrization may not change the dtype of the tensor, unless `unsafe` flag is enabled.\n"
                    f"unparametrized dtype: {original_dtype}\n"
                    f"parametrized dtype: {Z.dtype}"
                )
            if Z.shape != original_shape:
                raise ValueError(
                    "Registering a parametrization may not change the shape of the tensor, unless `unsafe` flag is enabled.\n"
                    f"unparametrized shape: {original_shape}\n"
                    f"parametrized shape: {Z.shape}"
                )
    def right_inverse(self, value: Tensor) -> None:
        r"""Call the ``right_inverse`` methods of the parametrizations in the inverse registration order.

        Then, it stores the result in ``self.original`` if ``right_inverse`` outputs one tensor
        or in ``self.original0``, ``self.original1``, ... if it outputs several.

        Args:
            value (Tensor): Value to which initialize the module
        """
        # All the exceptions in this function should almost never throw.
        # They could throw if, for example, right_inverse function returns a different
        # dtype when given a different input, which should most likely be caused by a
        # bug in the user's code
        with torch.no_grad():
            # See https://github.com/pytorch/pytorch/issues/53103
            for module in reversed(self):  # type: ignore[call-overload]
                if hasattr(module, "right_inverse"):
                    value = module.right_inverse(value)
                else:
                    raise RuntimeError(
                        f"parametrization {type(module).__name__} does not implement "
                        "right_inverse."
                    )
            if self.is_tensor:
                # These exceptions should only throw when a right_inverse function does not
                # return the same dtype for every input, which should most likely be caused by a bug
                if not isinstance(value, Tensor):
                    raise ValueError(
                        f"`right_inverse` should return a tensor. Got {type(value).__name__}"
                    )
                if value.dtype != self.original.dtype:
                    raise ValueError(
                        f"The tensor returned by `right_inverse` has dtype {value.dtype} "
                        f"while `original` has dtype {self.original.dtype}"
                    )
                # We know that the result is going to have the same dtype
                _maybe_set(self.original, value)
            else:
                if not isinstance(value, collections.abc.Sequence):
                    raise ValueError(
                        "'right_inverse' must return a sequence of tensors. "
                        f"Got {type(value).__name__}."
                    )
                if len(value) != self.ntensors:
                    raise ValueError(
                        "'right_inverse' must return a sequence of tensors of length "
                        f"{self.ntensors}. Got a sequence of length {len(value)}."
                    )
                for i, tensor in enumerate(value):
                    original_i = getattr(self, f"original{i}")
                    if not isinstance(tensor, Tensor):
                        raise ValueError(
                            f"`right_inverse` must return a sequence of tensors. "
                            f"Got element {i} of type {type(tensor).__name__}"
                        )
                    if original_i.dtype != tensor.dtype:
                        raise ValueError(
                            f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} "
                            f"while `original{i}` has dtype {original_i.dtype}"
                        )
                    # In-place update keeps the stored tensor's id stable for optimizers.
                    _maybe_set(original_i, tensor)

    def forward(self) -> Tensor:
        """Evaluate the whole parametrization chain on the stored original tensor(s)."""
        if torch.jit.is_scripting():
            raise RuntimeError("Parametrization is not working with scripting.")
        # Unpack the originals for the first parametrization
        if self.is_tensor:
            x = self[0](self.original)
        else:
            originals = (getattr(self, f"original{i}") for i in range(self.ntensors))
            x = self[0](*originals)
        # It's not possible to call self[1:] here, so we have to be a bit more cryptic
        # Also we want to skip all non-integer keys
        curr_idx = 1
        while hasattr(self, str(curr_idx)):
            x = self[curr_idx](x)
            curr_idx += 1
        return x
def _inject_new_class(module: Module) -> None:
    r"""Set up a module to be parametrized.

    This works by substituting the class of the module by a class
    that extends it to be able to inject a property

    Args:
        module (nn.Module): module into which to inject the property
    """
    cls = module.__class__

    def default_deepcopy(self, memo):
        # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class.
        obj = memo.get(id(self), None)
        if obj is not None:
            return obj
        replica = self.__new__(self.__class__)
        memo[id(self)] = replica
        replica.__dict__ = deepcopy(self.__dict__, memo)
        # Also save all slots if they exist.
        slots_to_save = copyreg._slotnames(self.__class__)  # type: ignore[attr-defined]
        for slot in slots_to_save:
            if hasattr(self, slot):
                setattr(replica, slot, deepcopy(getattr(self, slot), memo))
        return replica

    def getstate(self):
        raise RuntimeError(
            "Serialization of parametrized modules is only "
            "supported through state_dict(). See:\n"
            "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
            "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
        )

    dct = {"__getstate__": getstate}
    # We don't allow serialization of parametrized modules but should still allow deepcopying.
    # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists.
    if not hasattr(cls, "__deepcopy__"):
        dct["__deepcopy__"] = default_deepcopy  # type: ignore[assignment]

    # Dynamically create a subclass of the module's class; properties for the
    # parametrized tensors are later attached to this subclass by _inject_property.
    param_cls = type(
        f"Parametrized{cls.__name__}",
        (cls,),
        dct,
    )

    module.__class__ = param_cls
def _inject_property(module: Module, tensor_name: str) -> None:
    r"""Injects a property into module[tensor_name].

    It assumes that the class in the module has already been modified from its
    original one using _inject_new_class and that the tensor under :attr:`tensor_name`
    has already been moved out

    Args:
        module (nn.Module): module into which to inject the property
        tensor_name (str): name of the name of the property to create
    """
    # We check the precondition.
    # This should never fire if register_parametrization is correctly implemented
    assert not hasattr(module, tensor_name)

    @torch.jit.unused
    def get_cached_parametrization(parametrization) -> Tensor:
        # Memoize the parametrization's output while the `cached()` context is
        # active; keyed by module identity and tensor name.
        global _cache
        key = (id(module), tensor_name)
        tensor = _cache.get(key)
        if tensor is None:
            tensor = parametrization()
            _cache[key] = tensor
        return tensor

    def get_parametrized(self) -> Tensor:
        if torch.jit.is_scripting():
            raise RuntimeError("Parametrization is not working with scripting.")
        parametrization = self.parametrizations[tensor_name]
        if _cache_enabled:
            if torch.jit.is_scripting():
                # Scripting
                raise RuntimeError(
                    "Caching is not implemented for scripting. "
                    "Either disable caching or avoid scripting."
                )
            elif torch._C._get_tracing_state() is not None:
                # Tracing
                raise RuntimeError(
                    "Cannot trace a model while caching parametrizations."
                )
            else:
                return get_cached_parametrization(parametrization)
        else:
            # If caching is not active, this function just evaluates the parametrization
            return parametrization()

    def set_original(self, value: Tensor) -> None:
        # Assigning to the property routes through right_inverse so the stored
        # original tensor(s) are updated consistently.
        if torch.jit.is_scripting():
            raise RuntimeError("Parametrization is not working with scripting.")
        self.parametrizations[tensor_name].right_inverse(value)

    setattr(module.__class__, tensor_name, property(get_parametrized, set_original))
def register_parametrization(
    module: Module,
    tensor_name: str,
    parametrization: Module,
    *,
    unsafe: bool = False,
) -> Module:
    r"""Register a parametrization to a tensor in a module.

    Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``,
    the module will return the parametrized version ``parametrization(module.weight)``.
    If the original tensor requires a gradient, the backward pass will differentiate
    through :attr:`parametrization`, and the optimizer will update the tensor accordingly.

    The first time that a module registers a parametrization, this function will add an attribute
    ``parametrizations`` to the module of type :class:`~ParametrizationList`.

    The list of parametrizations on the tensor ``weight`` will be accessible under
    ``module.parametrizations.weight``.

    The original tensor will be accessible under
    ``module.parametrizations.weight.original``.

    Parametrizations may be concatenated by registering several parametrizations
    on the same attribute.

    The training mode of a registered parametrization is updated on registration
    to match the training mode of the host module

    Parametrized parameters and buffers have an inbuilt caching system that can be activated
    using the context manager :func:`cached`.

    A :attr:`parametrization` may optionally implement a method with signature

    .. code-block:: python

        def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]

    This method is called on the unparametrized tensor when the first parametrization
    is registered to compute the initial value of the original tensor.
    If this method is not implemented, the original tensor will be just the unparametrized tensor.

    If all the parametrizations registered on a tensor implement `right_inverse` it is possible
    to initialize a parametrized tensor by assigning to it, as shown in the example below.

    It is possible for the first parametrization to depend on several inputs.
    This may be implemented returning a tuple of tensors from ``right_inverse``
    (see the example implementation of a ``RankOne`` parametrization below).

    In this case, the unconstrained tensors are also located under ``module.parametrizations.weight``
    with names ``original0``, ``original1``,...

    .. note::

        If unsafe=False (default) both the forward and right_inverse methods will be called
        once to perform a number of consistency checks.
        If unsafe=True, then right_inverse will be called if the tensor is not parametrized,
        and nothing will be called otherwise.

    .. note::

        In most situations, ``right_inverse`` will be a function such that
        ``forward(right_inverse(X)) == X`` (see
        `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
        Sometimes, when the parametrization is not surjective, it may be reasonable
        to relax this.

    .. warning::

        If a parametrization depends on several inputs, :func:`~register_parametrization`
        will register a number of new parameters. If such parametrization is registered
        after the optimizer is created, these new parameters will need to be added manually
        to the optimizer. See :meth:`torch.optim.Optimizer.add_param_group`.

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (str): name of the parameter or buffer on which to register
            the parametrization
        parametrization (nn.Module): the parametrization to register
    Keyword args:
        unsafe (bool): a boolean flag that denotes whether the parametrization
            may change the dtype and shape of the tensor. Default: `False`
            Warning: the parametrization is not checked for consistency upon registration.
            Enable this flag at your own risk.

    Raises:
        ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name`

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> import torch
        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>>
        >>> class Symmetric(nn.Module):
        >>>     def forward(self, X):
        >>>         return X.triu() + X.triu(1).T  # Return a symmetric matrix
        >>>
        >>>     def right_inverse(self, A):
        >>>         return A.triu()
        >>>
        >>> m = nn.Linear(5, 5)
        >>> P.register_parametrization(m, "weight", Symmetric())
        >>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
        True
        >>> A = torch.rand(5, 5)
        >>> A = A + A.T  # A is now symmetric
        >>> m.weight = A  # Initialize the weight to be the symmetric matrix A
        >>> print(torch.allclose(m.weight, A))
        True

        >>> class RankOne(nn.Module):
        >>>     def forward(self, x, y):
        >>>         # Form a rank 1 matrix multiplying two vectors
        >>>         return x.unsqueeze(-1) @ y.unsqueeze(-2)
        >>>
        >>>     def right_inverse(self, Z):
        >>>         # Project Z onto the rank 1 matrices
        >>>         U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
        >>>         # Return rescaled singular vectors
        >>>         s0_sqrt = S[0].sqrt().unsqueeze(-1)
        >>>         return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
        >>>
        >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
        >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
        1

    """
    # Keep the parametrization's train/eval mode in sync with the host module.
    parametrization.train(module.training)
    if is_parametrized(module, tensor_name):
        # Correctness checks.
        # If A is the space of tensors with shape and dtype equal to module.weight
        # we check that parametrization.forward and parametrization.right_inverse are
        # functions from A to A
        if not unsafe:
            Y = getattr(module, tensor_name)
            X = parametrization(Y)
            if not isinstance(X, Tensor):
                raise ValueError(
                    f"A parametrization must return a tensor. Got {type(X).__name__}."
                )
            if X.dtype != Y.dtype:
                raise ValueError(
                    "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"module.{tensor_name}.dtype: {Y.dtype}\n"
                    f"parametrization(module.{tensor_name}).dtype: {X.dtype}"
                )
            if X.shape != Y.shape:
                raise ValueError(
                    "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"module.{tensor_name}.shape: {Y.shape}\n"
                    f"parametrization(module.{tensor_name}).shape: {X.shape}"
                )
            if hasattr(parametrization, "right_inverse"):
                try:
                    Z = parametrization.right_inverse(X)  # type: ignore[operator]
                except NotImplementedError:
                    pass
                else:
                    if not isinstance(Z, Tensor):
                        raise ValueError(
                            f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}"
                        )
                    if Z.dtype != Y.dtype:
                        raise ValueError(
                            "The tensor returned by parametrization.right_inverse must have the same dtype "
                            f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
                            f"module.{tensor_name}.dtype: {Y.dtype}\n"
                            f"returned dtype: {Z.dtype}"
                        )
                    if Z.shape != Y.shape:
                        raise ValueError(
                            "The tensor returned by parametrization.right_inverse must have the same shape "
                            f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
                            f"module.{tensor_name}.shape: {Y.shape}\n"
                            f"returned shape: {Z.shape}"
                        )
            # else right_inverse is assumed to be the identity

        # add the new parametrization to the parametrization list
        assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
        module.parametrizations[tensor_name].append(parametrization)
        # If unsafe was True in previous parametrization, keep it enabled
        module.parametrizations[tensor_name].unsafe |= unsafe  # type: ignore[index, union-attr]
    elif tensor_name in module._buffers or tensor_name in module._parameters:
        # Set the parametrization mechanism
        # Fetch the original buffer or parameter
        original = getattr(module, tensor_name)
        # We create this early to check for possible errors
        parametrizations = ParametrizationList(
            [parametrization], original, unsafe=unsafe
        )
        # Delete the previous parameter or buffer
        delattr(module, tensor_name)
        # If this is the first parametrization registered on the module,
        # we prepare the module to inject the property
        if not is_parametrized(module):
            # Change the class
            _inject_new_class(module)
            # Inject a ``ModuleDict`` into the instance under module.parametrizations
            module.parametrizations = ModuleDict()
        # Add a property into the class
        _inject_property(module, tensor_name)
        # Add a ParametrizationList
        assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
        module.parametrizations[tensor_name] = parametrizations
    else:
        raise ValueError(
            f"Module '{module}' does not have a parameter, a buffer, or a "
            f"parametrized element with name '{tensor_name}'"
        )
    return module
def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool:
    r"""Determine if a module has a parametrization.

    Args:
        module (nn.Module): module to query
        tensor_name (str, optional): name of the parameter in the module
            Default: ``None``
    Returns:
        ``True`` if :attr:`module` has a parametrization for the parameter named :attr:`tensor_name`,
        or if it has any parametrization when :attr:`tensor_name` is ``None``;
        otherwise ``False``
    """
    registry = getattr(module, "parametrizations", None)
    # Covers both "attribute missing" (getattr returned None) and an unrelated
    # attribute that merely happens to be called ``parametrizations``.
    if not isinstance(registry, ModuleDict):
        return False
    # Without a tensor_name we ask whether *any* tensor is parametrized.
    if tensor_name is None:
        return len(registry) > 0
    return tensor_name in registry
def remove_parametrizations(
    module: Module,
    tensor_name: str,
    leave_parametrized: bool = True,
) -> Module:
    r"""Remove the parametrizations on a tensor in a module.

    - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to
      its current output. In this case, the parametrization shall not change the ``dtype``
      of the tensor.
    - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
      the unparametrised tensor in ``module.parametrizations[tensor_name].original``.
      This is only possible when the parametrization depends on just one tensor.

    Args:
        module (nn.Module): module from which remove the parametrization
        tensor_name (str): name of the parametrization to be removed
        leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized.
            Default: ``True``

    Returns:
        Module: module

    Raises:
        ValueError: if ``module[tensor_name]`` is not parametrized
        ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors
    """
    if not is_parametrized(module, tensor_name):
        raise ValueError(
            f"Module {module} does not have a parametrization on {tensor_name}"
        )

    # Fetch the original tensor
    assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
    parametrizations = module.parametrizations[tensor_name]
    if parametrizations.is_tensor:
        original = parametrizations.original
        if leave_parametrized:
            with torch.no_grad():
                t = getattr(module, tensor_name)
            # We know they have the same dtype because we have checked this when registering the
            # parametrizations. As such, we can use set_
            # We do this so that the parameter does not to change the id()
            # This way the user does not need to update the optimizer
            with torch.no_grad():
                if type(original) is torch.Tensor:
                    _maybe_set(original, t)
                else:
                    try:
                        _maybe_set(original, t)
                    except RuntimeError as e:
                        # TODO: Fix this for tensor subclasses that are parameters:
                        # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().
                        # BUGFIX: the original message string concatenation was missing
                        # separators ("subclass.Alternatively", "pathEither",
                        # "implementationfor") — spacing/punctuation restored.
                        raise RuntimeError(
                            "Calling remove_parametrizations() with leave_parametrized=True "
                            "for a parameter that is an instance of a tensor subclass requires "
                            "set_() to be implemented correctly for the tensor subclass. "
                            "Alternatively, one can opt into the swap_tensors path. "
                            "Either set leave_parametrized=False or provide a working implementation "
                            "for set_() in the tensor subclass or set "
                            "torch.__future__.set_swap_module_params_on_conversion(True)."
                        ) from e
    else:
        if leave_parametrized:
            # We cannot use no_grad because we need to know whether one or more
            # original tensors required grad
            t = getattr(module, tensor_name)
            # We'll have to trust the user to add it to the optimizer
            original = Parameter(t) if t.requires_grad else t
        else:
            raise ValueError(
                "Cannot leave unparametrized (`leave_parametrized=False`) a tensor "
                "that is parametrized in terms of a sequence of tensors."
            )

    # Delete the property that manages the parametrization
    delattr(module.__class__, tensor_name)
    # Delete the ParametrizationList
    del module.parametrizations[tensor_name]

    # Restore the parameter / buffer into the main class
    _register_parameter_or_buffer(module, tensor_name, original)

    # Roll back the parametrized class if no other buffer or parameter
    # is currently parametrized in this class
    if not is_parametrized(module):
        delattr(module, "parametrizations")
        # Restore class
        orig_cls = module.__class__.__bases__[0]
        module.__class__ = orig_cls
    return module
def type_before_parametrizations(module: Module) -> type:
    r"""Return the module type before parametrizations were applied and if not, then it returns the module type.

    Args:
        module (nn.Module): module to get type of
    """
    # A parametrized module had its class swapped for a dynamically created
    # subclass; the pre-parametrization class is its (only) direct base.
    return module.__class__.__bases__[0] if is_parametrized(module) else type(module)


def transfer_parametrizations_and_params(
    from_module: Module,
    to_module: Module,
    tensor_name: Optional[str] = None,
) -> Module:
    r"""Transfer parametrizations and the parameters they parametrize from :attr:`from_module` to :attr:`to_module`.

    If :attr:`tensor_name` is specified, only transfers the specified parameter, otherwise
    transfers all parametrized parameters. If those parameters do not exist in to_module, it will create them.
    Does nothing if from_module is not parametrized.

    Args:
        from_module (nn.Module): module to transfer from
        to_module (nn.Module): module to transfer to
        tensor_name (str, optional): parameter to transfer

    Returns:
        Module: to_module
    """
    if is_parametrized(from_module):
        assert isinstance(from_module.parametrizations, ModuleDict)  # for mypy

        # Either every parametrized tensor, or just the requested one.
        names: Union[list, ModuleDict] = (
            from_module.parametrizations if tensor_name is None else [tensor_name]
        )

        assert hasattr(names, "__iter__")  # for mypy
        for name in names:
            # Create the destination parameter when it does not exist yet.
            if not hasattr(to_module, name):
                setattr(to_module, name, Parameter(getattr(from_module, name)))

            # Re-register every parametrization of this tensor on the destination.
            for parametrization in from_module.parametrizations[name]:
                register_parametrization(to_module, name, parametrization)
            assert isinstance(to_module.parametrizations, ModuleDict)  # for mypy

            # Copy the unparametrized value(s): stored either as a single
            # ``original`` or as a numbered series ``original0``, ``original1``, ...
            src = from_module.parametrizations[name]
            dst = to_module.parametrizations[name]
            if hasattr(src, "original"):
                dst.original = src.original
            else:
                idx = 0
                while hasattr(src, f"original{idx}"):
                    setattr(dst, f"original{idx}", getattr(src, f"original{idx}"))
                    idx += 1

    return to_module
# mypy: allow-untyped-defs
r"""Pruning methods."""
import numbers
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Tuple

import torch


class BasePruningMethod(ABC):
    r"""Abstract base class for creation of new pruning techniques.

    Provides a skeleton for customization requiring the overriding of methods
    such as :meth:`compute_mask` and :meth:`apply`.
    """

    # Name of the tensor this method acts on; assigned inside ``apply()``.
    _tensor_name: str

    def __call__(self, module, inputs):
        r"""Multiply the mask into original tensor and store the result.

        Multiplies the mask (stored in ``module[name + '_mask']``)
        into the original tensor (stored in ``module[name + '_orig']``)
        and stores the result into ``module[name]`` by using :meth:`apply_mask`.

        Args:
            module (nn.Module): module containing the tensor to prune
            inputs: not used.
        """
        setattr(module, self._tensor_name, self.apply_mask(module))

    @abstractmethod
    def compute_mask(self, t, default_mask):
        r"""Compute and returns a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a random mask to
        apply on top of the ``default_mask`` according to the specific pruning
        method recipe.

        Args:
            t (torch.Tensor): tensor representing the importance scores of the
                parameter to prune.
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``
        """

    def apply_mask(self, module):
        r"""Simply handles the multiplication between the parameter being pruned and the generated mask.

        Fetches the mask and the original tensor from the module
        and returns the pruned version of the tensor.

        Args:
            module (nn.Module): module containing the tensor to prune

        Returns:
            pruned_tensor (torch.Tensor): pruned version of the input tensor
        """
        # to carry out the multiplication, the mask needs to have been computed,
        # so the pruning method must know what tensor it's operating on
        assert (
            self._tensor_name is not None
        ), f"Module {module} has to be pruned"  # this gets set in apply()
        mask = getattr(module, self._tensor_name + "_mask")
        orig = getattr(module, self._tensor_name + "_orig")
        # Cast the mask to the parameter dtype so the elementwise product is valid.
        pruned_tensor = mask.to(dtype=orig.dtype) * orig
        return pruned_tensor
    @classmethod
    def apply(cls, module, name, *args, importance_scores=None, **kwargs):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            args: arguments passed on to a subclass of
                :class:`BasePruningMethod`
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the
                corresponding elements in the parameter being pruned.
                If unspecified or None, the parameter will be used in its place.
            kwargs: keyword arguments passed on to a subclass of a
                :class:`BasePruningMethod`
        """

        def _get_composite_method(cls, module, name, *args, **kwargs):
            # Check if a pruning method has already been applied to
            # `module[name]`. If so, store that in `old_method`.
            old_method = None
            found = 0
            # there should technically be only 1 hook with hook.name == name
            # assert this using `found`
            hooks_to_remove = []
            for k, hook in module._forward_pre_hooks.items():
                # if it exists, take existing thing, remove hook, then
                # go through normal thing
                if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
                    old_method = hook
                    hooks_to_remove.append(k)
                    found += 1
            assert (
                found <= 1
            ), f"Avoid adding multiple pruning hooks to the\
                same tensor {name} of module {module}. Use a PruningContainer."

            for k in hooks_to_remove:
                del module._forward_pre_hooks[k]

            # Apply the new pruning method, either from scratch or on top of
            # the previous one.
            method = cls(*args, **kwargs)  # new pruning
            # Have the pruning method remember what tensor it's been applied to
            method._tensor_name = name

            # combine `methods` with `old_method`, if `old_method` exists
            if old_method is not None:  # meaning that there was a hook
                # if the hook is already a pruning container, just add the
                # new pruning method to the container
                if isinstance(old_method, PruningContainer):
                    old_method.add_pruning_method(method)
                    method = old_method  # rename old_method --> method

                # if the hook is simply a single pruning method, create a
                # container, add the old pruning method and the new one
                elif isinstance(old_method, BasePruningMethod):
                    container = PruningContainer(old_method)
                    # Have the pruning method remember the name of its tensor
                    # setattr(container, '_tensor_name', name)
                    container.add_pruning_method(method)
                    method = container  # rename container --> method
            return method

        method = _get_composite_method(cls, module, name, *args, **kwargs)
        # at this point we have no forward_pre_hooks but we could have an
        # active reparametrization of the tensor if another pruning method
        # had been applied (in which case `method` would be a PruningContainer
        # and not a simple pruning method).

        # Pruning is to be applied to the module's tensor named `name`,
        # starting from the state it is found in prior to this iteration of
        # pruning. The pruning mask is calculated based on importances scores.

        orig = getattr(module, name)
        if importance_scores is not None:
            assert (
                importance_scores.shape == orig.shape
            ), f"importance_scores should have the same shape as parameter {name} of {module}"
        else:
            importance_scores = orig

        # If this is the first time pruning is applied, take care of moving
        # the original tensor to a new parameter called name + '_orig' and
        # and deleting the original parameter
        if not isinstance(method, PruningContainer):
            # copy `module[name]` to `module[name + '_orig']`
            module.register_parameter(name + "_orig", orig)
            # temporarily delete `module[name]`
            del module._parameters[name]
            default_mask = torch.ones_like(orig)  # temp
        # If this is not the first time pruning is applied, all of the above
        # has been done before in a previous pruning iteration, so we're good
        # to go
        else:
            default_mask = (
                getattr(module, name + "_mask")
                .detach()
                .clone(memory_format=torch.contiguous_format)
            )

        # Use try/except because if anything goes wrong with the mask
        # computation etc., you'd want to roll back.
        try:
            # get the final mask, computed according to the specific method
            mask = method.compute_mask(importance_scores, default_mask=default_mask)
            # reparameterize by saving mask to `module[name + '_mask']`...
            module.register_buffer(name + "_mask", mask)
            # ... and the new pruned tensor to `module[name]`
            setattr(module, name, method.apply_mask(module))
            # associate the pruning method to the module via a hook to
            # compute the function before every forward() (compile by run)
            module.register_forward_pre_hook(method)

        except Exception as e:
            # Roll back the reparametrization so the module is left untouched.
            if not isinstance(method, PruningContainer):
                orig = getattr(module, name + "_orig")
                module.register_parameter(name, orig)
                del module._parameters[name + "_orig"]
            raise e

        return method
+ + According to the pruning rule specified in :meth:`compute_mask`. + + Args: + t (torch.Tensor): tensor to prune (of same dimensions as + ``default_mask``). + importance_scores (torch.Tensor): tensor of importance scores (of + same shape as ``t``) used to compute mask for pruning ``t``. + The values in this tensor indicate the importance of the + corresponding elements in the ``t`` that is being pruned. + If unspecified or None, the tensor ``t`` will be used in its place. + default_mask (torch.Tensor, optional): mask from previous pruning + iteration, if any. To be considered when determining what + portion of the tensor that pruning should act on. If None, + default to a mask of ones. + + Returns: + pruned version of tensor ``t``. + """ + if importance_scores is not None: + assert ( + importance_scores.shape == t.shape + ), "importance_scores should have the same shape as tensor t" + else: + importance_scores = t + default_mask = default_mask if default_mask is not None else torch.ones_like(t) + return t * self.compute_mask(importance_scores, default_mask=default_mask) + + def remove(self, module): + r"""Remove the pruning reparameterization from a module. + + The pruned parameter named ``name`` remains permanently pruned, + and the parameter named ``name+'_orig'`` is removed from the parameter list. + Similarly, the buffer named ``name+'_mask'`` is removed from the buffers. + + Note: + Pruning itself is NOT undone or reversed! 
+ """ + # before removing pruning from a tensor, it has to have been applied + assert ( + self._tensor_name is not None + ), f"Module {module} has to be pruned before pruning can be removed" # this gets set in apply() + + # to update module[name] to latest trained weights + weight = self.apply_mask(module) # masked weights + + # delete and reset + if hasattr(module, self._tensor_name): + delattr(module, self._tensor_name) + orig = module._parameters[self._tensor_name + "_orig"] + orig.data = weight.data + del module._parameters[self._tensor_name + "_orig"] + del module._buffers[self._tensor_name + "_mask"] + setattr(module, self._tensor_name, orig) + + +class PruningContainer(BasePruningMethod): + """Container holding a sequence of pruning methods for iterative pruning. + + Keeps track of the order in which pruning methods are applied and handles + combining successive pruning calls. + + Accepts as argument an instance of a BasePruningMethod or an iterable of + them. + """ + + def __init__(self, *args): + self._pruning_methods: Tuple[BasePruningMethod, ...] = () + if not isinstance(args, Iterable): # only 1 item + self._tensor_name = args._tensor_name + self.add_pruning_method(args) + elif len(args) == 1: # only 1 item in a tuple + self._tensor_name = args[0]._tensor_name + self.add_pruning_method(args[0]) + else: # manual construction from list or other iterable (or no args) + for method in args: + self.add_pruning_method(method) + + def add_pruning_method(self, method): + r"""Add a child pruning ``method`` to the container. + + Args: + method (subclass of BasePruningMethod): child pruning method + to be added to the container. 
+ """ + # check that we're adding a pruning method to the container + if not isinstance(method, BasePruningMethod) and method is not None: + raise TypeError(f"{type(method)} is not a BasePruningMethod subclass") + elif method is not None and self._tensor_name != method._tensor_name: + raise ValueError( + "Can only add pruning methods acting on " + f"the parameter named '{self._tensor_name}' to PruningContainer {self}." + + f" Found '{method._tensor_name}'" + ) + # if all checks passed, add to _pruning_methods tuple + self._pruning_methods += (method,) # type: ignore[operator] + + def __len__(self): + return len(self._pruning_methods) + + def __iter__(self): + return iter(self._pruning_methods) + + def __getitem__(self, idx): + return self._pruning_methods[idx] + + def compute_mask(self, t, default_mask): + r"""Apply the latest ``method`` by computing the new partial masks and returning its combination with the ``default_mask``. + + The new partial mask should be computed on the entries or channels + that were not zeroed out by the ``default_mask``. + Which portions of the tensor ``t`` the new mask will be calculated from + depends on the ``PRUNING_TYPE`` (handled by the type handler): + + * for 'unstructured', the mask will be computed from the raveled + list of nonmasked entries; + + * for 'structured', the mask will be computed from the nonmasked + channels in the tensor; + + * for 'global', the mask will be computed across all entries. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + (of same dimensions as ``default_mask``). + default_mask (torch.Tensor): mask from previous pruning iteration. + + Returns: + mask (torch.Tensor): new mask that combines the effects + of the ``default_mask`` and the new mask from the current + pruning ``method`` (of same dimensions as ``default_mask`` and + ``t``). + """ + + def _combine_masks(method, t, mask): + r"""Combine the masks from all pruning methods and returns a new mask. 
+ + Args: + method (a BasePruningMethod subclass): pruning method + currently being applied. + t (torch.Tensor): tensor representing the parameter to prune + (of same dimensions as mask). + mask (torch.Tensor): mask from previous pruning iteration + + Returns: + new_mask (torch.Tensor): new mask that combines the effects + of the old mask and the new mask from the current + pruning method (of same dimensions as mask and t). + """ + new_mask = mask # start off from existing mask + new_mask = new_mask.to(dtype=t.dtype) + + # compute a slice of t onto which the new pruning method will operate + if method.PRUNING_TYPE == "unstructured": + # prune entries of t where the mask is 1 + slc = mask == 1 + + # for struct pruning, exclude channels that have already been + # entirely pruned + elif method.PRUNING_TYPE == "structured": + if not hasattr(method, "dim"): + raise AttributeError( + "Pruning methods of PRUNING_TYPE " + '"structured" need to have the attribute `dim` defined.' + ) + + # find the channels to keep by removing the ones that have been + # zeroed out already (i.e. where sum(entries) == 0) + n_dims = t.dim() # "is this a 2D tensor? 3D? ..." + dim = method.dim + # convert negative indexing + if dim < 0: + dim = n_dims + dim + # if dim is still negative after subtracting it from n_dims + if dim < 0: + raise IndexError( + f"Index is out of bounds for tensor with dimensions {n_dims}" + ) + # find channels along dim = dim that aren't already tots 0ed out + keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0 + # create slice to identify what to prune + slc = [slice(None)] * n_dims + slc[dim] = keep_channel + + elif method.PRUNING_TYPE == "global": + n_dims = len(t.shape) # "is this a 2D tensor? 3D? ..." 
+ slc = [slice(None)] * n_dims + + else: + raise ValueError(f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}") + + # compute the new mask on the unpruned slice of the tensor t + partial_mask = method.compute_mask(t[slc], default_mask=mask[slc]) + new_mask[slc] = partial_mask.to(dtype=new_mask.dtype) + + return new_mask + + method = self._pruning_methods[-1] + mask = _combine_masks(method, t, default_mask) + return mask + + +class Identity(BasePruningMethod): + r"""Utility pruning method that does not prune any units but generates the pruning parametrization with a mask of ones.""" + + PRUNING_TYPE = "unstructured" + + def compute_mask(self, t, default_mask): + mask = default_mask + return mask + + @classmethod + def apply(cls, module, name): + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + """ + return super().apply(module, name) + + +class RandomUnstructured(BasePruningMethod): + r"""Prune (currently unpruned) units in a tensor at random. + + Args: + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. 
+ """ + + PRUNING_TYPE = "unstructured" + + def __init__(self, amount): + # Check range of validity of pruning amount + _validate_pruning_amount_init(amount) + self.amount = amount + + def compute_mask(self, t, default_mask): + # Check that the amount of units to prune is not > than the number of + # parameters in t + tensor_size = t.nelement() + # Compute number of units to prune: amount if int, + # else amount * tensor_size + nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) + # This should raise an error if the number of units to prune is larger + # than the number of units in the tensor + _validate_pruning_amount(nparams_toprune, tensor_size) + + mask = default_mask.clone(memory_format=torch.contiguous_format) + + if nparams_toprune != 0: # k=0 not supported by torch.kthvalue + prob = torch.rand_like(t) + topk = torch.topk(prob.view(-1), k=nparams_toprune) + mask.view(-1)[topk.indices] = 0 + + return mask + + @classmethod + def apply(cls, module, name, amount): + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + """ + return super().apply(module, name, amount=amount) + + +class L1Unstructured(BasePruningMethod): + r"""Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm. + + Args: + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. 
            If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        # Check that the amount of units to prune is not > than the number of
        # parameters in t
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # k=0 not supported by torch.kthvalue
            # largest=True --> top k; largest=False --> bottom k
            # Prune the smallest k
            topk = torch.topk(torch.abs(t).view(-1), k=nparams_toprune, largest=False)
            # topk will have .indices and .values
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount, importance_scores=None):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            importance_scores (torch.Tensor): tensor of importance scores (of same
                shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the corresponding
                elements in the parameter being pruned.
                If unspecified or None, the module parameter will be used in its place.
        """
        return super().apply(
            module, name, amount=amount, importance_scores=importance_scores
        )


class RandomStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor at random.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, dim=-1):
        # Check range of validity of amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and returns a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a random mask to
        apply on top of the ``default_mask`` by randomly zeroing out channels
        along the specified dim of the tensor.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``

        Raises:
            IndexError: if ``self.dim >= len(t.shape)``
        """
        # Check that tensor has structure (i.e.
        # more than 1 dimension) such
        # that the concept of "channels" makes sense
        _validate_structured_pruning(t)

        # Check that self.dim is a valid dim to index t, else raise IndexError
        _validate_pruning_dim(t, self.dim)

        # Check that the amount of channels to prune is not > than the number of
        # channels in t along the dim to prune
        tensor_size = t.shape[self.dim]
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # Compute binary mask by initializing it to all 0s and then filling in
        # 1s wherever topk.indices indicates, along self.dim.
        # mask has the same shape as tensor t
        def make_mask(t, dim, nchannels, nchannels_toprune):
            # generate a random number in [0, 1] to associate to each channel
            # NOTE(review): torch.rand here creates a CPU tensor regardless of
            # t's device; the boolean channel_mask is used only for indexing —
            # confirm behavior on non-CPU tensors.
            prob = torch.rand(nchannels)
            # generate mask for each channel by 0ing out the channels that
            # got assigned the k = nchannels_toprune lowest values in prob
            # (with continuous random scores, ties have probability ~0)
            threshold = torch.kthvalue(prob, k=nchannels_toprune).values
            channel_mask = prob > threshold

            mask = torch.zeros_like(t)
            slc = [slice(None)] * len(t.shape)
            slc[dim] = channel_mask
            mask[slc] = 1
            return mask

        if nparams_toprune == 0:  # k=0 not supported by torch.kthvalue
            mask = default_mask
        else:
            # apply the new structured mask on top of prior (potentially
            # unstructured) mask
            mask = make_mask(t, self.dim, tensor_size, nparams_toprune)
            mask *= default_mask.to(dtype=mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, amount, dim=-1):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            dim (int, optional): index of the dim along which we define
                channels to prune. Default: -1.
        """
        return super().apply(module, name, amount=amount, dim=dim)


class LnStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor based on their L\ ``n``-norm.

    Args:
        amount (int or float): quantity of channels to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, n, dim=-1):
        # Check range of validity of amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.n = n
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and returns a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a mask to apply on
        top of the ``default_mask`` by zeroing out the channels along the
        specified dim with the lowest L\ ``n``-norm.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``

        Raises:
            IndexError: if ``self.dim >= len(t.shape)``
        """
        # Check that tensor has structure (i.e. more than 1 dimension) such
        # that the concept of "channels" makes sense
        _validate_structured_pruning(t)
        # Check that self.dim is a valid dim to index t, else raise IndexError
        _validate_pruning_dim(t, self.dim)

        # Check that the amount of channels to prune is not > than the number of
        # channels in t along the dim to prune
        tensor_size = t.shape[self.dim]
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        nparams_tokeep = tensor_size - nparams_toprune
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # Structured pruning prunes entire channels so we need to know the
        # L_n norm along each channel to then find the topk based on this
        # metric
        norm = _compute_norm(t, self.n, self.dim)
        # largest=True --> top k; largest=False --> bottom k
        # Keep the largest k channels along dim=self.dim
        topk = torch.topk(norm, k=nparams_tokeep, largest=True)
        # topk will have .indices and .values

        # Compute binary mask by initializing it to all 0s and then filling in
        # 1s wherever topk.indices indicates, along self.dim.
        # mask has the same shape as tensor t
        def make_mask(t, dim, indices):
            # init mask to 0
            mask = torch.zeros_like(t)
            # e.g.: slc = [None, None, None], if len(t.shape) = 3
            slc = [slice(None)] * len(t.shape)
            # replace a None at position=dim with indices
            # e.g.: slc = [None, None, [0, 2, 3]] if dim=2 & indices=[0,2,3]
            slc[dim] = indices
            # use slc to slice mask and replace all its entries with 1s
            # e.g.: mask[:, :, [0, 2, 3]] = 1
            mask[slc] = 1
            return mask

        if nparams_toprune == 0:  # k=0 not supported by torch.kthvalue
            mask = default_mask
        else:
            mask = make_mask(t, self.dim, topk.indices)
            # combine with the prior mask so previously-pruned entries stay 0
            mask *= default_mask.to(dtype=mask.dtype)

        return mask

    @classmethod
    def apply(cls, module, name, amount, n, dim, importance_scores=None):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
                entries for argument ``p`` in :func:`torch.norm`.
            dim (int): index of the dim along which we define channels to
                prune.
            importance_scores (torch.Tensor): tensor of importance scores (of same
                shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the corresponding
                elements in the parameter being pruned.
                If unspecified or None, the module parameter will be used in its place.
        """
        return super().apply(
            module,
            name,
            amount=amount,
            n=n,
            dim=dim,
            importance_scores=importance_scores,
        )


class CustomFromMask(BasePruningMethod):
    # Applies a user-supplied, pre-computed mask verbatim.
    PRUNING_TYPE = "global"

    def __init__(self, mask):
        self.mask = mask

    def compute_mask(self, t, default_mask):
        assert default_mask.shape == self.mask.shape
        # elementwise AND of the prior mask and the user-supplied mask
        mask = default_mask * self.mask.to(dtype=default_mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, mask):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
        """
        return super().apply(module, name, mask=mask)


def identity(module, name):
    r"""Apply pruning reparametrization without pruning any units.

    Applies pruning reparametrization to the tensor corresponding to the
    parameter called ``name`` in ``module`` without actually pruning any
    units. Modifies module in place (and also return the modified module)
    by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Note:
        The mask is a tensor of ones.

    Args:
        module (nn.Module): module containing the tensor to prune.
        name (str): parameter name within ``module`` on which pruning
            will act.

    Returns:
        module (nn.Module): modified (i.e.
            pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.identity(nn.Linear(2, 3), 'bias')
        >>> print(m.bias_mask)
        tensor([1., 1., 1.])
    """
    Identity.apply(module, name)
    return module


def random_unstructured(module, name, amount):
    r"""Prune tensor by removing random (currently unpruned) units.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) units
    selected at random.
    Modifies module in place (and also return the modified module) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
        >>> torch.sum(m.weight_mask == 0)
        tensor(1)

    """
    RandomUnstructured.apply(module, name, amount)
    return module


def l1_unstructured(module, name, amount, importance_scores=None):
    r"""Prune tensor by removing units with the lowest L1-norm.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified `amount` of (currently unpruned) units with the
    lowest L1-norm.
    Modifies module in place (and also return the modified module)
    by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        importance_scores (torch.Tensor): tensor of importance scores (of same
            shape as module parameter) used to compute mask for pruning.
            The values in this tensor indicate the importance of the corresponding
            elements in the parameter being pruned.
            If unspecified or None, the module parameter will be used in its place.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.l1_unstructured(nn.Linear(2, 3), 'weight', amount=0.2)
        >>> m.state_dict().keys()
        odict_keys(['bias', 'weight_orig', 'weight_mask'])
    """
    L1Unstructured.apply(
        module, name, amount=amount, importance_scores=importance_scores
    )
    return module


def random_structured(module, name, amount, dim):
    r"""Prune tensor by removing random channels along the specified dimension.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) channels
    along the specified ``dim`` selected at random.
    Modifies module in place (and also return the modified module)
    by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        dim (int): index of the dim along which we define channels to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_structured(
        ...     nn.Linear(5, 3), 'weight', amount=3, dim=1
        ... )
        >>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0))
        >>> print(columns_pruned)
        3
    """
    RandomStructured.apply(module, name, amount, dim)
    return module


def ln_structured(module, name, amount, n, dim, importance_scores=None):
    r"""Prune tensor by removing channels with the lowest L\ ``n``-norm along the specified dimension.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) channels
    along the specified ``dim`` with the lowest L\ ``n``-norm.
    Modifies module in place (and also return the modified module)
    by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.
+ + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid + entries for argument ``p`` in :func:`torch.norm`. + dim (int): index of the dim along which we define channels to prune. + importance_scores (torch.Tensor): tensor of importance scores (of same + shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the corresponding + elements in the parameter being pruned. + If unspecified or None, the module parameter will be used in its place. + + Returns: + module (nn.Module): modified (i.e. pruned) version of the input module + + Examples: + >>> from torch.nn.utils import prune + >>> m = prune.ln_structured( + ... nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf') + ... ) + """ + LnStructured.apply( + module, name, amount, n, dim, importance_scores=importance_scores + ) + return module + + +def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs): + r""" + Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``. + + Modifies modules in place by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Args: + parameters (Iterable of (module, name) tuples): parameters of + the model to prune in a global fashion, i.e. 
by aggregating all + weights prior to deciding which ones to prune. module must be of + type :class:`nn.Module`, and name must be a string. + pruning_method (function): a valid pruning function from this module, + or a custom one implemented by the user that satisfies the + implementation guidelines and has ``PRUNING_TYPE='unstructured'``. + importance_scores (dict): a dictionary mapping (module, name) tuples to + the corresponding parameter's importance scores tensor. The tensor + should be the same shape as the parameter, and is used for computing + mask for pruning. + If unspecified or None, the parameter will be used in place of its + importance scores. + kwargs: other keyword arguments such as: + amount (int or float): quantity of parameters to prune across the + specified parameters. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + + Raises: + TypeError: if ``PRUNING_TYPE != 'unstructured'`` + + Note: + Since global structured pruning doesn't make much sense unless the + norm is normalized by the size of the parameter, we now limit the + scope of global pruning to unstructured methods. + + Examples: + >>> from torch.nn.utils import prune + >>> from collections import OrderedDict + >>> net = nn.Sequential(OrderedDict([ + ... ('first', nn.Linear(10, 4)), + ... ('second', nn.Linear(4, 1)), + ... ])) + >>> parameters_to_prune = ( + ... (net.first, 'weight'), + ... (net.second, 'weight'), + ... ) + >>> prune.global_unstructured( + ... parameters_to_prune, + ... pruning_method=prune.L1Unstructured, + ... amount=10, + ... 
) + >>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0)) + tensor(10) + + """ + # ensure parameters is a list or generator of tuples + if not isinstance(parameters, Iterable): + raise TypeError("global_unstructured(): parameters is not an Iterable") + + importance_scores = importance_scores if importance_scores is not None else {} + if not isinstance(importance_scores, dict): + raise TypeError("global_unstructured(): importance_scores must be of type dict") + + # flatten importance scores to consider them all at once in global pruning + relevant_importance_scores = torch.nn.utils.parameters_to_vector( + [ + importance_scores.get((module, name), getattr(module, name)) + for (module, name) in parameters + ] + ) + # similarly, flatten the masks (if they exist), or use a flattened vector + # of 1s of the same dimensions as t + default_mask = torch.nn.utils.parameters_to_vector( + [ + getattr(module, name + "_mask", torch.ones_like(getattr(module, name))) + for (module, name) in parameters + ] + ) + + # use the canonical pruning methods to compute the new mask, even if the + # parameter is now a flattened out version of `parameters` + container = PruningContainer() + container._tensor_name = "temp" # to make it match that of `method` + method = pruning_method(**kwargs) + method._tensor_name = "temp" # to make it match that of `container` + if method.PRUNING_TYPE != "unstructured": + raise TypeError( + 'Only "unstructured" PRUNING_TYPE supported for ' + f"the `pruning_method`. 
Found method {pruning_method} of type {method.PRUNING_TYPE}" + ) + + container.add_pruning_method(method) + + # use the `compute_mask` method from `PruningContainer` to combine the + # mask computed by the new method with the pre-existing mask + final_mask = container.compute_mask(relevant_importance_scores, default_mask) + + # Pointer for slicing the mask to match the shape of each parameter + pointer = 0 + for module, name in parameters: + param = getattr(module, name) + # The length of the parameter + num_param = param.numel() + # Slice the mask, reshape it + param_mask = final_mask[pointer : pointer + num_param].view_as(param) + # Assign the correct pre-computed mask to each parameter and add it + # to the forward_pre_hooks like any other pruning method + custom_from_mask(module, name, mask=param_mask) + + # Increment the pointer to continue slicing the final_mask + pointer += num_param + + +def custom_from_mask(module, name, mask): + r"""Prune tensor corresponding to parameter called ``name`` in ``module`` by applying the pre-computed mask in ``mask``. + + Modifies module in place (and also return the modified module) by: + + 1) adding a named buffer called ``name+'_mask'`` corresponding to the + binary mask applied to the parameter ``name`` by the pruning method. + 2) replacing the parameter ``name`` by its pruned version, while the + original (unpruned) parameter is stored in a new parameter named + ``name+'_orig'``. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + mask (Tensor): binary mask to be applied to the parameter. + + Returns: + module (nn.Module): modified (i.e. pruned) version of the input module + + Examples: + >>> from torch.nn.utils import prune + >>> m = prune.custom_from_mask( + ... nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0]) + ... 
) + >>> print(m.bias_mask) + tensor([0., 1., 0.]) + + """ + CustomFromMask.apply(module, name, mask) + return module + + +def remove(module, name): + r"""Remove the pruning reparameterization from a module and the pruning method from the forward hook. + + The pruned parameter named ``name`` remains permanently pruned, and the parameter + named ``name+'_orig'`` is removed from the parameter list. Similarly, + the buffer named ``name+'_mask'`` is removed from the buffers. + + Note: + Pruning itself is NOT undone or reversed! + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + + Examples: + >>> m = random_unstructured(nn.Linear(5, 7), name='weight', amount=0.2) + >>> m = remove(m, name='weight') + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + return module + + raise ValueError( + f"Parameter '{name}' of module {module} has to be pruned before pruning can be removed" + ) + + +def is_pruned(module): + r"""Check if a module is pruned by looking for pruning pre-hooks. + + Check whether ``module`` is pruned by looking for + ``forward_pre_hooks`` in its modules that inherit from the + :class:`BasePruningMethod`. + + Args: + module (nn.Module): object that is either pruned or unpruned + + Returns: + binary answer to whether ``module`` is pruned. + + Examples: + >>> from torch.nn.utils import prune + >>> m = nn.Linear(5, 7) + >>> print(prune.is_pruned(m)) + False + >>> prune.random_unstructured(m, name='weight', amount=0.2) + >>> print(prune.is_pruned(m)) + True + """ + for _, submodule in module.named_modules(): + for hook in submodule._forward_pre_hooks.values(): + if isinstance(hook, BasePruningMethod): + return True + return False + + +def _validate_pruning_amount_init(amount): + r"""Validate helper to check the range of amount at init. 
+ + Args: + amount (int or float): quantity of parameters to prune. + If float, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If int, it represents the + absolute number of parameters to prune. + + Raises: + ValueError: if amount is a float not in [0, 1], or if it's a negative + integer. + TypeError: if amount is neither a float nor an integer. + + Note: + This does not take into account the number of parameters in the + tensor to be pruned, which is known only at prune. + """ + if not isinstance(amount, numbers.Real): + raise TypeError(f"Invalid type for amount: {amount}. Must be int or float.") + + if (isinstance(amount, numbers.Integral) and amount < 0) or ( + not isinstance(amount, numbers.Integral) # so it's a float + and (float(amount) > 1.0 or float(amount) < 0.0) + ): + raise ValueError( + f"amount={amount} should either be a float in the range [0, 1] or a non-negative integer" + ) + + +def _validate_pruning_amount(amount, tensor_size): + r"""Validate that the pruning amount is meaningful wrt to the size of the data. + + Validation helper to check that the amount of parameters to prune + is meaningful wrt to the size of the data (`tensor_size`). + + Args: + amount (int or float): quantity of parameters to prune. + If float, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If int, it represents the + absolute number of parameters to prune. + tensor_size (int): absolute number of parameters in the tensor + to prune. + """ + # TODO: consider removing this check and allowing users to specify + # a number of units to prune that is greater than the number of units + # left to prune. In this case, the tensor will just be fully pruned. 
+ + if isinstance(amount, numbers.Integral) and amount > tensor_size: + raise ValueError( + f"amount={amount} should be smaller than the number of parameters to prune={tensor_size}" + ) + + +def _validate_structured_pruning(t): + r"""Validate that the tensor to be pruned is at least 2-Dimensional. + + Validation helper to check that the tensor to be pruned is multi- + dimensional, such that the concept of "channels" is well-defined. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + + Raises: + ValueError: if the tensor `t` is not at least 2D. + """ + shape = t.shape + if len(shape) <= 1: + raise ValueError( + "Structured pruning can only be applied to " + "multidimensional tensors. Found tensor of shape " + f"{shape} with {len(shape)} dims" + ) + + +def _compute_nparams_toprune(amount, tensor_size): + r"""Convert the pruning amount from a percentage to absolute value. + + Since amount can be expressed either in absolute value or as a + percentage of the number of units/channels in a tensor, this utility + function converts the percentage to absolute value to standardize + the handling of pruning. + + Args: + amount (int or float): quantity of parameters to prune. + If float, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If int, it represents the + absolute number of parameters to prune. + tensor_size (int): absolute number of parameters in the tensor + to prune. + + Returns: + int: the number of units to prune in the tensor + """ + # incorrect type already checked in _validate_pruning_amount_init + if isinstance(amount, numbers.Integral): + return amount + else: + return round(amount * tensor_size) + + +def _validate_pruning_dim(t, dim): + r"""Validate that the pruning dimension is within the bounds of the tensor dimension. 
+ + Args: + t (torch.Tensor): tensor representing the parameter to prune + dim (int): index of the dim along which we define channels to prune + """ + if dim >= t.dim(): + raise IndexError(f"Invalid index {dim} for tensor of size {t.shape}") + + +def _compute_norm(t, n, dim): + r"""Compute the L_n-norm of a tensor along all dimensions except for the specified dimension. + + The L_n-norm will be computed across all entries in tensor `t` along all dimension + except for the one identified by dim. + Example: if `t` is of shape, say, 3x2x4 and dim=2 (the last dim), + then norm will have Size [4], and each entry will represent the + `L_n`-norm computed using the 3x2=6 entries for each of the 4 channels. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid + entries for argument p in torch.norm + dim (int): dim identifying the channels to prune + + Returns: + norm (torch.Tensor): L_n norm computed across all dimensions except + for `dim`. By construction, `norm.shape = t.shape[-1]`. 
+ """ + # dims = all axes, except for the one identified by `dim` + dims = list(range(t.dim())) + # convert negative indexing + if dim < 0: + dim = dims[dim] + dims.remove(dim) + + norm = torch.norm(t, p=n, dim=dims) + return norm diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/rnn.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..13fa6324833196170abe5d273aa8145a693b0e57 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/rnn.py @@ -0,0 +1,599 @@ +import warnings +from collections.abc import Iterable +from typing import ( + Any, + Callable, + List, + NamedTuple, + Optional, + overload, + Tuple, + TypeVar, + Union, +) +from typing_extensions import Self + +import torch +from torch import _VF, Tensor + + +__all__ = [ + "PackedSequence", + "invert_permutation", + "pack_padded_sequence", + "pad_packed_sequence", + "pad_sequence", + "unpad_sequence", + "pack_sequence", + "unpack_sequence", +] + +_T = TypeVar("_T") +_R = TypeVar("_R") + + +class PackedSequence_(NamedTuple): + data: torch.Tensor + batch_sizes: torch.Tensor + sorted_indices: Optional[torch.Tensor] + unsorted_indices: Optional[torch.Tensor] + + +def bind(optional: Optional[_T], fn: Callable[[_T], _R]) -> Optional[_R]: + if optional is None: + return None + return fn(optional) + + +class PackedSequence(PackedSequence_): + r"""Holds the data and list of :attr:`batch_sizes` of a packed sequence. + + All RNN modules accept packed sequences as inputs. + + Note: + Instances of this class should never be created manually. They are meant + to be instantiated by functions like :func:`pack_padded_sequence`. + + Batch sizes represent the number elements at each sequence step in + the batch, not the varying sequence lengths passed to + :func:`pack_padded_sequence`. For instance, given data ``abc`` and ``x`` + the :class:`PackedSequence` would contain data ``axbc`` with + ``batch_sizes=[2,1,1]``. 
class PackedSequence(PackedSequence_):
    r"""Holds the data and list of :attr:`batch_sizes` of a packed sequence.

    All RNN modules accept packed sequences as inputs.

    Note:
        Instances of this class should never be created manually. They are meant
        to be instantiated by functions like :func:`pack_padded_sequence`.

        Batch sizes represent the number of elements at each sequence step in
        the batch, not the varying sequence lengths passed to
        :func:`pack_padded_sequence`. For instance, given data ``abc`` and ``x``
        the :class:`PackedSequence` would contain data ``axbc`` with
        ``batch_sizes=[2,1,1]``.

    Attributes:
        data (Tensor): Tensor containing packed sequence
        batch_sizes (Tensor): Tensor of integers holding
            information about the batch size at each sequence step
        sorted_indices (Tensor, optional): Tensor of integers holding how this
            :class:`PackedSequence` is constructed from sequences.
        unsorted_indices (Tensor, optional): Tensor of integers holding how this
            to recover the original sequences with correct order.

    .. note::
        :attr:`data` can be on arbitrary device and of arbitrary dtype.
        :attr:`sorted_indices` and :attr:`unsorted_indices` must be ``torch.int64``
        tensors on the same device as :attr:`data`.

        However, :attr:`batch_sizes` should always be a CPU ``torch.int64`` tensor.

        This invariant is maintained throughout :class:`PackedSequence` class,
        and all functions that construct a :class:`PackedSequence` in PyTorch
        (i.e., they only pass in tensors conforming to this constraint).
    """

    def __new__(
        cls,
        data: Tensor,
        batch_sizes: Optional[Tensor] = None,
        sorted_indices: Optional[Tensor] = None,
        unsorted_indices: Optional[Tensor] = None,
    ) -> Self:
        # Argument normalization is delegated to _packed_sequence_init_args so
        # the same logic is shared with the TorchScript-friendly helpers.
        return super().__new__(
            cls,
            *_packed_sequence_init_args(
                data, batch_sizes, sorted_indices, unsorted_indices
            ),
        )

    # NOTE [ device and dtype of a PackedSequence ]
    #
    # See the note above in doc string (starting with ":attr:`data` can be on
    # arbitrary device...").
    def pin_memory(self) -> Self:
        # Why not convert `batch_sizes`?
        # See NOTE [ device and dtype of a PackedSequence ]
        return type(self)(
            self.data.pin_memory(),
            self.batch_sizes,
            bind(self.sorted_indices, lambda t: t.pin_memory()),
            bind(self.unsorted_indices, lambda t: t.pin_memory()),
        )

    @overload
    def to(
        self,
        dtype: torch.dtype,
        non_blocking: bool = ...,
        copy: bool = ...,
    ) -> Self:
        ...

    @overload
    def to(
        self,
        device: Optional[Union[str, torch.device, int]] = ...,
        dtype: Optional[torch.dtype] = ...,
        non_blocking: bool = ...,
        copy: bool = ...,
    ) -> Self:
        ...

    @overload
    def to(
        self,
        other: Tensor,
        non_blocking: bool = ...,
        copy: bool = ...,
    ) -> Self:
        ...

    def to(self, *args: Any, **kwargs: Any) -> Self:
        r"""Perform dtype and/or device conversion on `self.data`.

        It has similar signature as :meth:`torch.Tensor.to`, except optional
        arguments like `non_blocking` and `copy` should be passed as kwargs,
        not args, or they will not apply to the index tensors.

        .. note::

            If the ``self.data`` Tensor already has the correct :class:`torch.dtype`
            and :class:`torch.device`, then ``self`` is returned.
            Otherwise, returns a copy with the desired configuration.
        """
        # Why not convert `batch_sizes`?
        # See NOTE [ device and dtype of a PackedSequence ]
        data = self.data.to(*args, **kwargs)
        if data is self.data:
            # No conversion happened; preserve identity semantics.
            return self
        else:
            # Does not forward device or dtype arg/kwargs, device is set from data.device
            kwargs = dict(
                filter(lambda t: t[0] != "device" and t[0] != "dtype", kwargs.items())
            )
            sorted_indices = bind(
                self.sorted_indices, lambda t: t.to(data.device, **kwargs)
            )
            unsorted_indices = bind(
                self.unsorted_indices, lambda t: t.to(data.device, **kwargs)
            )
            return type(self)(data, self.batch_sizes, sorted_indices, unsorted_indices)

    def cuda(self, *args: Any, **kwargs: Any) -> Self:
        # Tests to see if 'cuda' should be added to kwargs
        # (probe with an empty tensor so no real data is moved).
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
            *args, **kwargs
        )
        if ex.is_cuda:
            return self.to(*args, **kwargs)
        kwargs["device"] = "cuda"
        return self.to(*args, **kwargs)

    def cpu(self, *args: Any, **kwargs: Any) -> Self:
        # Same probing trick as cuda(): only force device="cpu" when the
        # user-supplied args would not already land on CPU.
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
            *args, **kwargs
        )
        if ex.device.type == "cpu":
            return self.to(*args, **kwargs)
        kwargs["device"] = "cpu"
        return self.to(*args, **kwargs)

    def double(self) -> Self:
        return self.to(dtype=torch.double)

    def float(self) -> Self:
        return self.to(dtype=torch.float)

    def half(self) -> Self:
        return self.to(dtype=torch.half)

    def long(self) -> Self:
        return self.to(dtype=torch.long)

    def int(self) -> Self:
        return self.to(dtype=torch.int)

    def short(self) -> Self:
        return self.to(dtype=torch.short)

    def char(self) -> Self:
        return self.to(dtype=torch.int8)

    def byte(self) -> Self:
        return self.to(dtype=torch.uint8)

    @property
    def is_cuda(self) -> bool:
        r"""Return true if `self.data` stored on a gpu."""
        return self.data.is_cuda

    def is_pinned(self) -> bool:
        r"""Return true if `self.data` stored on in pinned memory."""
        return self.data.is_pinned()
" + "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence" + ) + return data, batch_sizes, sorted_indices, unsorted_indices + + # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)` + else: + assert isinstance(data, (list, tuple)) and len(data) == 2 + return data[0], data[1], sorted_indices, unsorted_indices + + +def _packed_sequence_init( + data: Tensor, + batch_sizes: Optional[Tensor] = None, + sorted_indices: Optional[Tensor] = None, + unsorted_indices: Optional[Tensor] = None, +) -> PackedSequence: + data, batch_sizes, sorted_indices, unsorted_indices = _packed_sequence_init_args( + data, batch_sizes, sorted_indices, unsorted_indices + ) + return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices) + + +def invert_permutation(permutation: Optional[Tensor]) -> Optional[Tensor]: + if permutation is None: + return None + output = torch.empty_like(permutation, memory_format=torch.legacy_contiguous_format) + output.scatter_( + 0, permutation, torch.arange(0, permutation.numel(), device=permutation.device) + ) + return output + + +def pack_padded_sequence( + input: Tensor, + lengths: Union[Tensor, List[int]], + batch_first: bool = False, + enforce_sorted: bool = True, +) -> PackedSequence: + r"""Packs a Tensor containing padded sequences of variable length. + + :attr:`input` can be of size ``T x B x *`` (if :attr:`batch_first` is ``False``) + or ``B x T x *`` (if :attr:`batch_first` is ``True``) where ``T`` is the length + of the longest sequence, ``B`` is the batch size, and ``*`` is any number of dimensions + (including 0). + + For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is + ``True``, the sequences should be sorted by length in a decreasing order, i.e. + ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest + one. `enforce_sorted = True` is only necessary for ONNX export. 
def pack_padded_sequence(
    input: Tensor,
    lengths: Union[Tensor, List[int]],
    batch_first: bool = False,
    enforce_sorted: bool = True,
) -> PackedSequence:
    r"""Packs a Tensor containing padded sequences of variable length.

    :attr:`input` can be of size ``T x B x *`` (if :attr:`batch_first` is
    ``False``) or ``B x T x *`` (if ``True``), where ``T`` is the length of
    the longest sequence, ``B`` the batch size, and ``*`` any number of
    dimensions (including 0).

    For unsorted sequences, use ``enforce_sorted=False``. If
    :attr:`enforce_sorted` is ``True``, the sequences must already be sorted
    by decreasing length; that is only necessary for ONNX export.

    Note:
        Any input with at least two dimensions is accepted, e.g. labels to be
        used with the RNN output for a loss. The packed data can be recovered
        from the returned :class:`PackedSequence` via its ``.data`` attribute.

    Args:
        input (Tensor): padded batch of variable length sequences.
        lengths (Tensor or list(int)): list of sequence lengths of each batch
            element (must be on the CPU if provided as a tensor).
        batch_first (bool, optional): if ``True``, the input is expected in
            ``B x T x *`` format, ``T x B x *`` otherwise.
        enforce_sorted (bool, optional): if ``True``, the input is expected to
            contain sequences sorted by length in a decreasing order. If
            ``False``, the input will get sorted unconditionally. Default: ``True``.

    Returns:
        a :class:`PackedSequence` object
    """
    if isinstance(lengths, torch.Tensor):
        lengths = lengths.to(dtype=torch.int64)
    else:
        if torch._C._get_tracing_state():
            # Python-list lengths become constants under tracing; warn loudly.
            warnings.warn(
                "pack_padded_sequence has been called with a Python list of "
                "sequence lengths. The tracer cannot track the data flow of Python "
                "values, and it will treat them as constants, likely rendering "
                "the trace incorrect for any other combination of lengths.",
                stacklevel=2,
            )
        lengths = torch.as_tensor(lengths, dtype=torch.int64, device="cpu")

    sorted_indices = None
    if not enforce_sorted:
        # Sort by decreasing length and remember the order so the caller's
        # original batch ordering can be restored after unpacking.
        lengths, sorted_indices = torch.sort(lengths, descending=True)
        sorted_indices = sorted_indices.to(input.device)
        batch_dim = 0 if batch_first else 1
        input = input.index_select(batch_dim, sorted_indices)

    data, batch_sizes = _VF._pack_padded_sequence(input, lengths, batch_first)
    return _packed_sequence_init(data, batch_sizes, sorted_indices, None)
def pad_packed_sequence(
    sequence: PackedSequence,
    batch_first: bool = False,
    padding_value: float = 0.0,
    total_length: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Pad a packed batch of variable length sequences.

    Inverse operation to :func:`pack_padded_sequence`. The returned Tensor's
    data is of size ``T x B x *`` (or ``B x T x *`` if :attr:`batch_first` is
    ``True``), where ``T`` is the longest sequence length and ``B`` the batch
    size.

    .. note::
        :attr:`total_length` is useful to implement the
        ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
        :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.

    Args:
        sequence (PackedSequence): batch to pad
        batch_first (bool, optional): if ``True``, the output will be in
            ``B x T x *`` format, ``T x B x *`` otherwise.
        padding_value (float, optional): values for padded elements.
        total_length (int, optional): if not ``None``, the output will be padded to
            have length :attr:`total_length`. This method will throw :class:`ValueError`
            if :attr:`total_length` is less than the max sequence length in
            :attr:`sequence`.

    Returns:
        Tuple of Tensor containing the padded sequence, and a Tensor
        containing the list of lengths of each sequence in the batch.
        Batch elements are re-ordered back into the order they had when passed
        to ``pack_padded_sequence`` or ``pack_sequence``.
    """
    max_seq_length = sequence.batch_sizes.size(0)
    if total_length is not None:
        if total_length < max_seq_length:
            raise ValueError(
                "Expected total_length to be at least the length "
                "of the longest sequence in input, but got "
                f"total_length={total_length} and max sequence length being {max_seq_length}"
            )
        max_seq_length = total_length

    padded_output, lengths = _VF._pad_packed_sequence(
        sequence.data, sequence.batch_sizes, batch_first, padding_value, max_seq_length
    )

    unsorted_indices = sequence.unsorted_indices
    if unsorted_indices is None:
        return padded_output, lengths

    # Restore the original (pre-sort) batch order recorded at pack time.
    batch_dim = 0 if batch_first else 1
    return (
        padded_output.index_select(batch_dim, unsorted_indices),
        lengths[unsorted_indices.cpu()],
    )
+def pad_sequence( + sequences: Union[Tensor, List[Tensor]], + batch_first: bool = False, + padding_value: float = 0.0, + padding_side: str = "right", +) -> Tensor: + r"""Pad a list of variable length Tensors with :attr:`padding_value`. + + ``pad_sequence`` stacks a list of Tensors along a new dimension, and pads them + to equal length. :attr:`sequences` can be list of sequences with size ``L x *``, + where `L` is length of the sequence and ``*`` is any number of dimensions + (including 0). If :attr:`batch_first` is ``False``, the output is of size + ``T x B x *``, and ``B x T x *`` otherwise, where ``B`` is the batch size + (the number of elements in :attr:`sequences`), ``T`` is the length of the longest + sequence. + + Example: + >>> from torch.nn.utils.rnn import pad_sequence + >>> a = torch.ones(25, 300) + >>> b = torch.ones(22, 300) + >>> c = torch.ones(15, 300) + >>> pad_sequence([a, b, c]).size() + torch.Size([25, 3, 300]) + + Note: + This function returns a Tensor of size ``T x B x *`` or ``B x T x *`` + where `T` is the length of the longest sequence. This function assumes + trailing dimensions and type of all the Tensors in sequences are same. + + Args: + sequences (list[Tensor]): list of variable length sequences. + batch_first (bool, optional): if ``True``, the output will be in ``B x T x *`` + format, ``T x B x *`` otherwise. + padding_value (float, optional): value for padded elements. Default: 0. + padding_side (str, optional): the side to pad the sequences on. + Default: "right". + + Returns: + Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``. 
+ Tensor of size ``B x T x *`` otherwise + """ + if not (torch.jit.is_tracing() or torch.jit.is_scripting()): + # JIT doesn't support `Iterable` + if not isinstance(sequences, Iterable): + msg = ( + "pad_sequence: Expected iterable for input sequences, but got arg of type: " + f"{type(sequences)}" + ) + raise RuntimeError(msg) + + # In JIT context this leads to, + # RuntimeError: cannot statically infer the expected size of a list in this context + sequences = tuple(sequences) # type: ignore[assignment] + else: + # For JIT, we only support Union[Tensor, Tuple[Tensor]] + if isinstance(sequences, torch.Tensor): + sequences = sequences.unbind(0) # type: ignore[assignment] + + # assuming trailing dimensions and type of all the Tensors + # in sequences are same and fetching those from sequences[0] + return torch._C._nn.pad_sequence( + sequences, batch_first, padding_value, padding_side # type: ignore[arg-type] + ) + + +def unpad_sequence( + padded_sequences: Tensor, + lengths: Tensor, + batch_first: bool = False, +) -> List[Tensor]: + r"""Unpad padded Tensor into a list of variable length Tensors. + + ``unpad_sequence`` unstacks padded Tensor into a list of variable length Tensors. + + Example: + >>> from torch.nn.utils.rnn import pad_sequence, unpad_sequence + >>> a = torch.ones(25, 300) + >>> b = torch.ones(22, 300) + >>> c = torch.ones(15, 300) + >>> sequences = [a, b, c] + >>> padded_sequences = pad_sequence(sequences) + >>> lengths = torch.as_tensor([v.size(0) for v in sequences]) + >>> unpadded_sequences = unpad_sequence(padded_sequences, lengths) + >>> torch.allclose(sequences[0], unpadded_sequences[0]) + True + >>> torch.allclose(sequences[1], unpadded_sequences[1]) + True + >>> torch.allclose(sequences[2], unpadded_sequences[2]) + True + + Args: + padded_sequences (Tensor): padded sequences. + lengths (Tensor): length of original (unpadded) sequences. + batch_first (bool, optional): whether batch dimension first or not. Default: False. 
def unpad_sequence(
    padded_sequences: Tensor,
    lengths: Tensor,
    batch_first: bool = False,
) -> List[Tensor]:
    r"""Unpad padded Tensor into a list of variable length Tensors.

    ``unpad_sequence`` unstacks a padded Tensor into a list of variable
    length Tensors, the inverse of :func:`pad_sequence`.

    Args:
        padded_sequences (Tensor): padded sequences.
        lengths (Tensor): length of original (unpadded) sequences.
        batch_first (bool, optional): whether batch dimension first or not. Default: False.

    Returns:
        a list of :class:`Tensor` objects
    """
    # Bug fix: the previous implementation called the in-place
    # ``transpose_(0, 1)`` here, silently mutating the caller's tensor
    # whenever ``batch_first=False``. ``transpose`` returns a cheap view and
    # leaves the input untouched; the returned slices are identical.
    if not batch_first:
        padded_sequences = padded_sequences.transpose(0, 1)

    unpadded_sequences = []
    max_length = padded_sequences.shape[1]
    # Boolean mask `idx < length` selects the first `length` steps of each row.
    idx = torch.arange(max_length, device=lengths.device)

    for seq, length in zip(padded_sequences, lengths):
        mask = idx < length
        unpadded_sequences.append(seq[mask])

    return unpadded_sequences


def pack_sequence(
    sequences: List[Tensor],
    enforce_sorted: bool = True,
) -> PackedSequence:
    r"""Packs a list of variable length Tensors.

    Consecutive call of the next functions: ``pad_sequence``, ``pack_padded_sequence``.

    ``sequences`` should be a list of Tensors of size ``L x *``, where ``L``
    is the length of a sequence and ``*`` is any number of trailing
    dimensions, including zero.

    Args:
        sequences (list[Tensor]): A list of sequences of decreasing length.
        enforce_sorted (bool, optional): if ``True``, checks that the input
            contains sequences sorted by length in a decreasing order. If
            ``False``, this condition is not checked. Default: ``True``.
            (``True`` is only necessary for ONNX export.)

    Returns:
        a :class:`PackedSequence` object
    """
    lengths = torch.as_tensor([v.size(0) for v in sequences])
    return pack_padded_sequence(
        pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted
    )


def unpack_sequence(packed_sequences: PackedSequence) -> List[Tensor]:
    r"""Unpack PackedSequence into a list of variable length Tensors.

    Inverse of :func:`pack_sequence`.

    Args:
        packed_sequences (PackedSequence): A PackedSequence object.

    Returns:
        a list of :class:`Tensor` objects
    """
    padded_sequences, lengths = pad_packed_sequence(packed_sequences, batch_first=True)
    return unpad_sequence(padded_sequences, lengths, batch_first=True)
+ + ``packed_sequences`` should be a PackedSequence object. + + Example: + >>> from torch.nn.utils.rnn import pack_sequence, unpack_sequence + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5]) + >>> c = torch.tensor([6]) + >>> sequences = [a, b, c] + >>> print(sequences) + [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])] + >>> packed_sequences = pack_sequence(sequences) + >>> print(packed_sequences) + PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None) + >>> unpacked_sequences = unpack_sequence(packed_sequences) + >>> print(unpacked_sequences) + [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])] + + Args: + packed_sequences (PackedSequence): A PackedSequence object. + + Returns: + a list of :class:`Tensor` objects + """ + padded_sequences, lengths = pad_packed_sequence(packed_sequences, batch_first=True) + unpacked_sequences = unpad_sequence(padded_sequences, lengths, batch_first=True) + return unpacked_sequences diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/spectral_norm.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/spectral_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..3474a127a0b4946fa57e9d4ec6b23ab44e2706ff --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/spectral_norm.py @@ -0,0 +1,366 @@ +# mypy: allow-untyped-defs +"""Spectral Normalization from https://arxiv.org/abs/1802.05957.""" +from typing import Any, Optional, TypeVar + +import torch +import torch.nn.functional as F +from torch.nn.modules import Module + + +__all__ = [ + "SpectralNorm", + "SpectralNormLoadStateDictPreHook", + "SpectralNormStateDictHook", + "spectral_norm", + "remove_spectral_norm", +] + + +class SpectralNorm: + # Invariant before and after each forward call: + # u = F.normalize(W @ v) + # NB: At initialization, this invariant is not enforced + + _version: int = 1 + # At version 1: + # made `W` not a buffer, + # added `v` as a 
    # Per-instance configuration set in __init__ (version 1 of this hook also
    # stores `v` as a buffer and recomputes `W` in eval mode — see _version
    # comment on the class).
    name: str
    dim: int
    n_power_iterations: int
    eps: float

    def __init__(
        self,
        name: str = "weight",
        n_power_iterations: int = 1,
        dim: int = 0,
        eps: float = 1e-12,
    ) -> None:
        """Configure the hook; validates that n_power_iterations is positive."""
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError(
                "Expected n_power_iterations to be positive, but "
                f"got n_power_iterations={n_power_iterations}"
            )
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
        """Flatten ``weight`` to 2-D with ``self.dim`` as the leading axis."""
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(
                self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim]
            )
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)

    def compute_weight(self, module: Module, do_power_iteration: bool) -> torch.Tensor:
        """Return the spectrally normalized weight, optionally updating u/v."""
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        # updated in power iteration **in-place**. This is very important
        # because in `DataParallel` forward, the vectors (being buffers) are
        # broadcast from the parallelized module to each module replica,
        # which is a new module object created on the fly. And each replica
        # runs its own spectral norm power iteration. So simply assigning
        # the updated vectors to the module this function runs on will cause
        # the update to be lost forever. And the next time the parallelized
        # module is replicated, the same randomly initialized vectors are
        # broadcast and used!
        #
        # Therefore, to make the change propagate back, we rely on two
        # important behaviors (also enforced via tests):
        # 1. `DataParallel` doesn't clone storage if the broadcast tensor
        #    is already on correct device; and it makes sure that the
        #    parallelized module is already on `device[0]`.
        # 2. If the out tensor in `out=` kwarg has correct shape, it will
        #    just fill in the values.
        # Therefore, since the same power iteration is performed on all
        # devices, simply updating the tensors in-place will make sure that
        # the module replica on `device[0]` will update the _u vector on the
        # parallelized module (by shared storage).
        #
        # However, after we update `u` and `v` in-place, we need to **clone**
        # them before using them to normalize the weight. This is to support
        # backproping through two forward passes, e.g., the common pattern in
        # GAN training: loss = D(real) - D(fake). Otherwise, engine will
        # complain that variables needed to do backward for the first forward
        # (i.e., the `u` and `v` vectors) are changed in the second forward.
        weight = getattr(module, self.name + "_orig")
        u = getattr(module, self.name + "_u")
        v = getattr(module, self.name + "_v")
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
                    # are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
                    v = F.normalize(
                        torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v
                    )
                    u = F.normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u)
                if self.n_power_iterations > 0:
                    # See above on why we need to clone
                    u = u.clone(memory_format=torch.contiguous_format)
                    v = v.clone(memory_format=torch.contiguous_format)

        # sigma approximates the largest singular value of the weight matrix.
        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module: Module) -> None:
        """Re-register the normalized weight as a plain parameter and drop state."""
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + "_u")
        delattr(module, self.name + "_v")
        delattr(module, self.name + "_orig")
        module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))

    def __call__(self, module: Module, inputs: Any) -> None:
        # Forward pre-hook: refresh the normalized weight before each forward;
        # power iteration only runs in training mode.
        setattr(
            module,
            self.name,
            self.compute_weight(module, do_power_iteration=module.training),
        )

    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` s.t. `u = F.normalize(W @ v)`
        # (the invariant at top of this class) and `u @ W @ v = sigma`.
        # This uses pinverse in case W^T W is not invertible.
+ v = torch.linalg.multi_dot( + [weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)] + ).squeeze(1) + return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v))) + + @staticmethod + def apply( + module: Module, name: str, n_power_iterations: int, dim: int, eps: float + ) -> "SpectralNorm": + for hook in module._forward_pre_hooks.values(): + if isinstance(hook, SpectralNorm) and hook.name == name: + raise RuntimeError( + f"Cannot register two spectral_norm hooks on the same parameter {name}" + ) + + fn = SpectralNorm(name, n_power_iterations, dim, eps) + weight = module._parameters[name] + if weight is None: + raise ValueError( + f"`SpectralNorm` cannot be applied as parameter `{name}` is None" + ) + if isinstance(weight, torch.nn.parameter.UninitializedParameter): + raise ValueError( + "The module passed to `SpectralNorm` can't have uninitialized parameters. " + "Make sure to run the dummy forward before applying spectral normalization" + ) + + with torch.no_grad(): + weight_mat = fn.reshape_weight_to_matrix(weight) + + h, w = weight_mat.size() + # randomly initialize `u` and `v` + u = F.normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps) + v = F.normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps) + + delattr(module, fn.name) + module.register_parameter(fn.name + "_orig", weight) + # We still need to assign weight back as fn.name because all sorts of + # things may assume that it exists, e.g., when initializing weights. + # However, we can't directly assign as it could be an nn.Parameter and + # gets added as a parameter. Instead, we register weight.data as a plain + # attribute. 
+ setattr(module, fn.name, weight.data) + module.register_buffer(fn.name + "_u", u) + module.register_buffer(fn.name + "_v", v) + + module.register_forward_pre_hook(fn) + module._register_state_dict_hook(SpectralNormStateDictHook(fn)) + module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn)) + return fn + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormLoadStateDictPreHook: + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn) -> None: + self.fn = fn + + # For state_dict with version None, (assuming that it has gone through at + # least one training forward), we have + # + # u = F.normalize(W_orig @ v) + # W = W_orig / sigma, where sigma = u @ W_orig @ v + # + # To compute `v`, we solve `W_orig @ x = u`, and let + # v = x / (u @ W_orig @ x) * (W / W_orig). + def __call__( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) -> None: + fn = self.fn + version = local_metadata.get("spectral_norm", {}).get( + fn.name + ".version", None + ) + if version is None or version < 1: + weight_key = prefix + fn.name + if ( + version is None + and all(weight_key + s in state_dict for s in ("_orig", "_u", "_v")) + and weight_key not in state_dict + ): + # Detect if it is the updated state dict and just missing metadata. + # This could happen if the users are crafting a state dict themselves, + # so we just pretend that this is the newest. 
+ return + has_missing_keys = False + for suffix in ("_orig", "", "_u"): + key = weight_key + suffix + if key not in state_dict: + has_missing_keys = True + if strict: + missing_keys.append(key) + if has_missing_keys: + return + with torch.no_grad(): + weight_orig = state_dict[weight_key + "_orig"] + weight = state_dict.pop(weight_key) + sigma = (weight_orig / weight).mean() + weight_mat = fn.reshape_weight_to_matrix(weight_orig) + u = state_dict[weight_key + "_u"] + v = fn._solve_v_and_rescale(weight_mat, u, sigma) + state_dict[weight_key + "_v"] = v + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormStateDictHook: + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn) -> None: + self.fn = fn + + def __call__(self, module, state_dict, prefix, local_metadata) -> None: + if "spectral_norm" not in local_metadata: + local_metadata["spectral_norm"] = {} + key = self.fn.name + ".version" + if key in local_metadata["spectral_norm"]: + raise RuntimeError(f"Unexpected key in metadata['spectral_norm']: {key}") + local_metadata["spectral_norm"][key] = self.fn._version + + +T_module = TypeVar("T_module", bound=Module) + + +def spectral_norm( + module: T_module, + name: str = "weight", + n_power_iterations: int = 1, + eps: float = 1e-12, + dim: Optional[int] = None, +) -> T_module: + r"""Apply spectral normalization to a parameter in the given module. + + .. math:: + \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, + \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + + Spectral normalization stabilizes the training of discriminators (critics) + in Generative Adversarial Networks (GANs) by rescaling the weight tensor + with spectral norm :math:`\sigma` of the weight matrix calculated using + power iteration method. 
If the dimension of the weight tensor is greater + than 2, it is reshaped to 2D in power iteration method to get spectral + norm. This is implemented via a hook that calculates spectral norm and + rescales weight before every :meth:`~Module.forward` call. + + See `Spectral Normalization for Generative Adversarial Networks`_ . + + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 + + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter + n_power_iterations (int, optional): number of power iterations to + calculate spectral norm + eps (float, optional): epsilon for numerical stability in + calculating norms + dim (int, optional): dimension corresponding to number of outputs, + the default is ``0``, except for modules that are instances of + ConvTranspose{1,2,3}d, when it is ``1`` + + Returns: + The original module with the spectral norm hook + + .. note:: + This function has been reimplemented as + :func:`torch.nn.utils.parametrizations.spectral_norm` using the new + parametrization functionality in + :func:`torch.nn.utils.parametrize.register_parametrization`. Please use + the newer version. This function will be deprecated in a future version + of PyTorch. + + Example:: + + >>> m = spectral_norm(nn.Linear(20, 40)) + >>> m + Linear(in_features=20, out_features=40, bias=True) + >>> m.weight_u.size() + torch.Size([40]) + + """ + if dim is None: + if isinstance( + module, + ( + torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, + ), + ): + dim = 1 + else: + dim = 0 + SpectralNorm.apply(module, name, n_power_iterations, dim, eps) + return module + + +def remove_spectral_norm(module: T_module, name: str = "weight") -> T_module: + r"""Remove the spectral normalization reparameterization from a module. 
+ + Args: + module (Module): containing module + name (str, optional): name of weight parameter + + Example: + >>> m = spectral_norm(nn.Linear(40, 10)) + >>> remove_spectral_norm(m) + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + break + else: + raise ValueError(f"spectral_norm of '{name}' not found in {module}") + + for k, hook in module._state_dict_hooks.items(): + if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name: + del module._state_dict_hooks[k] + break + + for k, hook in module._load_state_dict_pre_hooks.items(): + if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name: + del module._load_state_dict_pre_hooks[k] + break + + return module diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/stateless.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/stateless.py new file mode 100644 index 0000000000000000000000000000000000000000..69994f315f77e68083ca5e669c9e36dc13bdf900 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/stateless.py @@ -0,0 +1,298 @@ +# mypy: allow-untyped-defs +from typing import Any, Dict, Optional, Set, Tuple, Union +from typing_extensions import deprecated + +import torch +from torch import Tensor +from torch.nn.utils._named_member_accessor import NamedMemberAccessor + + +__all__ = ["functional_call"] + + +def _untie_named_tensors_map( + module: "torch.nn.Module", + parameters_and_buffers: Dict[str, Tensor], +) -> Dict[str, Tensor]: + """ + Unties all tied tensors in the module to parameters_and_buffers. + + This function returns a new untied_parameters_and_buffers dictionary and leave the original + untied_parameters_and_buffers dictionary unchanged. It adds new (missing) keys for tied tensors + in the module to untied_parameters_and_buffers. 
The value of the new key is the user-given value + in the original parameters_and_buffers dictionary. + + If there are more than one user-given values for the same tied tensor, it will raise an error. + + For example, if the module has two tied weights self.foo and self.tied_foo and the user passes + {'foo': foo_value, ...}, this will return {'foo': foo_value, 'tied_foo': foo_value, ...}. If the + user passes {'foo': foo_value, 'tied_foo': tied_foo_value, ...}, it will raise an error. If the + user passes {'foo': foo_value, 'tied_foo': foo_value, ...}, it will not raise an error. + + Args: + module (torch.nn.Module): the module to determine which tensors are tied. + parameters_and_buffers (Dict[str, Tensor]): a map of {name: tensor} for reparamaterizing the module. + + Returns: + A new untied version of the parameters_and_buffers dictionary. + + Raises: + ValueError: if there are more than one user-given values for the same tied tensor. + """ + # A map of {name: tensor} for all tensors (including tied ones) in the module. + all_named_tensors: Dict[str, Tensor] = {} + all_named_tensors.update(module.named_parameters(remove_duplicate=False)) + all_named_tensors.update(module.named_buffers(remove_duplicate=False)) + + # A map of {tensor: set(all_tied_names)} for all tensor names in the module. + tensor_to_tied_names_map: Dict[Tensor, Set[str]] = {} + for name, tensor in all_named_tensors.items(): + if tensor not in tensor_to_tied_names_map: + tensor_to_tied_names_map[tensor] = set() + tensor_to_tied_names_map[tensor].add(name) + + # A map of {tied_name: set(all_tied_names)} for all tensor names in the module. + # If a name is not tied, it will not be in this map. + tied_names_map: Dict[str, Set[str]] = {} + for tied_names in tensor_to_tied_names_map.values(): + if len(tied_names) > 1: + for tied_name in tied_names: + tied_names_map[tied_name] = tied_names + + # Make sure the user didn't pass multiple values for the same tied tensor. 
+ given_names = set(parameters_and_buffers.keys()) + # same as given_names.intersection(tied_names_map.keys()) but dynamo can't + # handle that + given_names_for_tied_tensors: set[str] = set() + for name in given_names: + if name in tied_names_map: + given_names_for_tied_tensors.add(name) + + for given_name in given_names_for_tied_tensors: + tied_names = tied_names_map[given_name] + if ( + # Detect if there are multiple keys present for the same tied tensor. + len(tied_names.intersection(given_names_for_tied_tensors)) > 1 + # Only raise an error if the user passed multiple values for the same tied tensor. + # If all given values are the same, don't raise. + and len({parameters_and_buffers[tied_name] for tied_name in tied_names}) + != 1 + ): + raise ValueError( + f"functional_call got multiple values for keys {sorted(tied_names)}, " + f"which are tied. Consider using tie_weights=False" + ) + + # Untie the given named tensor map + # Make a copy for not modifying the original dict + untied_parameters_and_buffers = parameters_and_buffers.copy() + for given_name in given_names_for_tied_tensors: + for tied_name in tied_names_map[given_name]: + untied_parameters_and_buffers[tied_name] = parameters_and_buffers[ + given_name + ] + return untied_parameters_and_buffers + + +class _ReparametrizeModule: + def __init__( + self, + module: "torch.nn.Module", + parameters_and_buffers: Dict[str, Tensor], + tie_weights: bool = False, + strict: bool = False, + stack_weights: bool = False, + ): + self.parameters_and_buffers = parameters_and_buffers + self.stack_weights = stack_weights + + if tie_weights: + self.untied_parameters_and_buffers = _untie_named_tensors_map( + module, parameters_and_buffers + ) + else: + self.untied_parameters_and_buffers = parameters_and_buffers + + self.accessor = NamedMemberAccessor(module) + if strict: + missing_keys, unexpected_keys = self.accessor.check_keys( + self.untied_parameters_and_buffers + ) + error_msgs = [] + if len(unexpected_keys) > 0: + 
error_msgs.append( + f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}." + ) + if len(missing_keys) > 0: + error_msgs.append( + f"Missing key(s): {', '.join(map(repr, missing_keys))}." + ) + if len(error_msgs) > 0: + raise RuntimeError( + "Error(s) in reparametrizing for {}:\n\t{}".format( + module._get_name(), "\n\t".join(error_msgs) + ) + ) + + def __enter__(self): + self.orig_parameters_and_buffers, _ = self.accessor.swap_tensors_dict( + self.untied_parameters_and_buffers, allow_missing=True + ) + + def __exit__(self, exception_type, exception_value, traceback): + if self.stack_weights: + # When stacking is enabled, we will restore the weights in LIFO order. + self.orig_parameters_and_buffers = dict( + reversed(self.orig_parameters_and_buffers.items()) + ) + new_parameters_and_buffers, _ = self.accessor.swap_tensors_dict( + self.orig_parameters_and_buffers, allow_missing=True + ) + # Sometimes the module is not completely stateless and has some in-place modifications on + # the _parameters and _buffers dictionaries. + # Write the changed parameters and buffers back to the original dict. + self.parameters_and_buffers.update( + { + k: new_parameters_and_buffers[k] + for k in self.parameters_and_buffers + if k in new_parameters_and_buffers + } + ) + + +def _reparametrize_module( + module: "torch.nn.Module", + parameters_and_buffers: Dict[str, Tensor], + *, + tie_weights: bool = False, + strict: bool = False, + stack_weights: bool = False, +) -> _ReparametrizeModule: + return _ReparametrizeModule( + module, + parameters_and_buffers, + tie_weights=tie_weights, + strict=strict, + stack_weights=stack_weights, + ) + + +@deprecated( + "`torch.nn.utils.stateless.functional_call` is deprecated as of PyTorch 2.0 " + "and will be removed in a future version of PyTorch. 
" + "Please use `torch.func.functional_call` instead which is a drop-in replacement.", + category=FutureWarning, +) +def functional_call( + module: "torch.nn.Module", + parameters_and_buffers: Dict[str, Tensor], + args: Union[Any, Tuple], + kwargs: Optional[Dict[str, Any]] = None, + *, + tie_weights: bool = True, + strict: bool = False, +): + r"""Perform a functional call on the module by replacing the module parameters and buffers with the provided ones. + + .. warning:: + + This API is deprecated as of PyTorch 2.0 and will be removed in a future + version of PyTorch. Please use :func:`torch.func.functional_call` instead, + which is a drop-in replacement for this API. + + .. note:: If the module has active parametrizations, passing a value in the + :attr:`parameters_and_buffers` argument with the name set to the regular parameter + name will completely disable the parametrization. + If you want to apply the parametrization function to the value passed + please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``. + + .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected + in the `parameters_and_buffers` input. + + Example:: + + >>> a = {'foo': torch.zeros(())} + >>> # xdoctest: +SKIP + >>> mod = Foo() # does self.foo = self.foo + 1 + >>> print(mod.foo) # tensor(0.) + >>> functional_call(mod, a, torch.ones(())) + >>> print(mod.foo) # tensor(0.) + >>> print(a['foo']) # tensor(1.) + + .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the + tie_weights flag. + + Example:: + + >>> a = {'foo': torch.zeros(())} + >>> # xdoctest: +SKIP + >>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied + >>> print(mod.foo) # tensor(1.) + >>> mod(torch.zeros(())) # tensor(2.) + >>> functional_call(mod, a, torch.zeros(())) # tensor(0.) 
since it will change self.foo_tied too + >>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated + >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())} + >>> functional_call(mod, new_a, torch.zeros()) # tensor(0.) + + Args: + module (torch.nn.Module): the module to call + parameters_and_buffers (dict of str and Tensor): the parameters that will be used in + the module call. + args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument. + kwargs (dict): keyword arguments to be passed to the module call + tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as + tied in the reparamaterized version. Therefore, if True and different values are passed for the tied + parameters and buffers, it will error. If False, it will not respect the originally tied parameters and + buffers unless the values passed for both weights are the same. Default: True. + strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and + buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will + error. Default: False. + + Returns: + Any: the result of calling ``module``. 
+ """ + return _functional_call( + module, + parameters_and_buffers, + args, + kwargs, + tie_weights=tie_weights, + strict=strict, + ) + + +def _functional_call( + module: "torch.nn.Module", + parameters_and_buffers: Dict[str, Tensor], + args: Union[Any, Tuple], + kwargs: Optional[Dict[str, Any]] = None, + *, + tie_weights: bool = True, + strict: bool = False, +): + # TODO allow kwargs such as unsafe and others for parametrization + if ( + torch.jit.is_tracing() + or torch.jit.is_scripting() + or isinstance( + module, + ( + torch.jit.RecursiveScriptModule, + torch.jit.ScriptModule, + torch.jit.ScriptFunction, + ), + ) + ): + raise RuntimeError("The stateless API can't be used with Jitted modules") + if isinstance(module, torch.nn.DataParallel): + raise RuntimeError( + "The stateless API can't be used with nn.DataParallel module" + ) + if kwargs is None: + kwargs = {} + if not isinstance(args, tuple): + args = (args,) + with _reparametrize_module( + module, parameters_and_buffers, tie_weights=tie_weights, strict=strict + ): + return module(*args, **kwargs) diff --git a/.venv/lib/python3.11/site-packages/torch/nn/utils/weight_norm.py b/.venv/lib/python3.11/site-packages/torch/nn/utils/weight_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..0eb51d4df132693075cf246621001729d186c8ea --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/nn/utils/weight_norm.py @@ -0,0 +1,164 @@ +# mypy: allow-untyped-defs +r"""Weight Normalization from https://arxiv.org/abs/1602.07868.""" +from typing import Any, TypeVar +from typing_extensions import deprecated + +from torch import _weight_norm, norm_except_dim +from torch.nn.modules import Module +from torch.nn.parameter import Parameter, UninitializedParameter + + +__all__ = ["WeightNorm", "weight_norm", "remove_weight_norm"] + + +class WeightNorm: + name: str + dim: int + + def __init__(self, name: str, dim: int) -> None: + if dim is None: + dim = -1 + self.name = name + self.dim = dim + + # TODO 
Make return type more specific + def compute_weight(self, module: Module) -> Any: + g = getattr(module, self.name + "_g") + v = getattr(module, self.name + "_v") + return _weight_norm(v, g, self.dim) + + @staticmethod + @deprecated( + "`torch.nn.utils.weight_norm` is deprecated " + "in favor of `torch.nn.utils.parametrizations.weight_norm`.", + category=FutureWarning, + ) + def apply(module, name: str, dim: int) -> "WeightNorm": + for hook in module._forward_pre_hooks.values(): + if isinstance(hook, WeightNorm) and hook.name == name: + raise RuntimeError( + f"Cannot register two weight_norm hooks on the same parameter {name}" + ) + + if dim is None: + dim = -1 + + fn = WeightNorm(name, dim) + + weight = getattr(module, name) + if isinstance(weight, UninitializedParameter): + raise ValueError( + "The module passed to `WeightNorm` can't have uninitialized parameters. " + "Make sure to run the dummy forward before applying weight normalization" + ) + # remove w from parameter list + del module._parameters[name] + + # add g and v as new parameters and express w as g/||v|| * v + module.register_parameter( + name + "_g", Parameter(norm_except_dim(weight, 2, dim).data) + ) + module.register_parameter(name + "_v", Parameter(weight.data)) + setattr(module, name, fn.compute_weight(module)) + + # recompute weight before every forward() + module.register_forward_pre_hook(fn) + + return fn + + def remove(self, module: Module) -> None: + weight = self.compute_weight(module) + delattr(module, self.name) + del module._parameters[self.name + "_g"] + del module._parameters[self.name + "_v"] + setattr(module, self.name, Parameter(weight.data)) + + def __call__(self, module: Module, inputs: Any) -> None: + setattr(module, self.name, self.compute_weight(module)) + + +T_module = TypeVar("T_module", bound=Module) + + +def weight_norm(module: T_module, name: str = "weight", dim: int = 0) -> T_module: + r"""Apply weight normalization to a parameter in the given module. + + .. 
math:: + \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} + + Weight normalization is a reparameterization that decouples the magnitude + of a weight tensor from its direction. This replaces the parameter specified + by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude + (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``). + Weight normalization is implemented via a hook that recomputes the weight + tensor from the magnitude and direction before every :meth:`~Module.forward` + call. + + By default, with ``dim=0``, the norm is computed independently per output + channel/plane. To compute a norm over the entire weight tensor, use + ``dim=None``. + + See https://arxiv.org/abs/1602.07868 + + .. warning:: + + This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm` + which uses the modern parametrization API. The new ``weight_norm`` is compatible + with ``state_dict`` generated from old ``weight_norm``. + + Migration guide: + + * The magnitude (``weight_g``) and direction (``weight_v``) are now expressed + as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1`` + respectively. If this is bothering you, please comment on + https://github.com/pytorch/pytorch/issues/102999 + + * To remove the weight normalization reparametrization, use + :func:`torch.nn.utils.parametrize.remove_parametrizations`. + + * The weight is no longer recomputed once at module forward; instead, it will + be recomputed on every access. To restore the old behavior, use + :func:`torch.nn.utils.parametrize.cached` before invoking the module + in question. 
+ + Args: + module (Module): containing module + name (str, optional): name of weight parameter + dim (int, optional): dimension over which to compute the norm + + Returns: + The original module with the weight norm hook + + Example:: + + >>> m = weight_norm(nn.Linear(20, 40), name='weight') + >>> m + Linear(in_features=20, out_features=40, bias=True) + >>> m.weight_g.size() + torch.Size([40, 1]) + >>> m.weight_v.size() + torch.Size([40, 20]) + + """ + WeightNorm.apply(module, name, dim) + return module + + +def remove_weight_norm(module: T_module, name: str = "weight") -> T_module: + r"""Remove the weight normalization reparameterization from a module. + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + + Example: + >>> m = weight_norm(nn.Linear(20, 40)) + >>> remove_weight_norm(m) + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, WeightNorm) and hook.name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + return module + + raise ValueError(f"weight_norm of '{name}' not found in {module}") diff --git a/.venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so b/.venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..adf86faf1f0967c1d1eff20870e0dacf1b7f1ecf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63250073d32dd31af3dd82ca6825aadbaf4562c9c4cd9336adb86f91f3d4e143 +size 275990736