import copy
import inspect
import math
import warnings
from collections.abc import Sequence
from itertools import chain
from typing import Any, Optional

import sympy

import torch
import torch.utils._pytree as pytree
from torch._export.non_strict_utils import (
    _enter_enable_graph_inputs_of_type_nn_module,
    _exit_enable_graph_inputs_of_type_nn_module,
    _get_graph_inputs_of_type_nn_module,
)
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
    _convert_range_to_int,
)
from torch._export.utils import _check_input_constraints_for_graph
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.traceback import NodeSource, NodeSourceAction
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.value_ranges import ValueRanges

from ._remove_effect_tokens_pass import _remove_effect_tokens
from ._tree_utils import reorder_kwargs
from .exported_program import (
    ExportedProgram,
    ExportGraphSignature,
    InputKind,
    OutputKind,
)


def eq_spec(self: pytree.TreeSpec, other: pytree.TreeSpec) -> bool:
    """
    Refinement of TreeSpec.__eq__ where, e.g., torch.Size(...) matches tuple(...).
    See _pytree_subclasses_that_lose_info in proxy_tensor.py for more details.
    """

    def _normalize_type(t):
        return str(_pytree_subclasses_that_lose_info.get(t, t))

    def _match_normalized_structure(a, b):
        if a is b:
            return True
        if _normalize_type(a.type) != _normalize_type(b.type):
            return False
        if a.context != b.context:
            return False
        if len(a.children_specs) != len(b.children_specs):
            return False
        return all(
            _match_normalized_structure(a, b)
            for a, b in zip(a.children_specs, b.children_specs)
        )

    return _match_normalized_structure(self, other)
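

# A minimal sketch (hypothetical helper, not part of this module's API) of what
# eq_spec normalizes: fx's immutable_list flattens to a spec whose type differs
# from list, so plain __eq__ rejects the comparison while eq_spec accepts it,
# assuming immutable_list appears in _pytree_subclasses_that_lose_info.
def _example_eq_spec():
    from torch.fx.immutable_collections import immutable_list

    _, spec_a = pytree.tree_flatten(immutable_list([1, 2]))
    _, spec_b = pytree.tree_flatten([1, 2])
    assert spec_a != spec_b  # plain __eq__ distinguishes the spec types
    assert eq_spec(spec_a, spec_b)  # normalized comparison matches them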


def _check_inputs_match(args, kwargs, in_spec: pytree.TreeSpec) -> list:
    reordered_kwargs = reorder_kwargs(kwargs, in_spec)
    flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
        (args, reordered_kwargs)
    )

    if not eq_spec(received_spec, in_spec):
        raise ValueError(
            "Trying to flatten user inputs with exported input tree spec: \n"
            f"{in_spec}\n"
            "but actually got inputs with tree spec of: \n"
            f"{received_spec}.\n"
            "Please check that the inputs have the same number and type of "
            "args and kwargs as the ones you used when tracing."
        )

    return flat_args_with_path
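

# Illustrative sketch (hypothetical helper): _check_inputs_match returns the
# flattened (path, value) pairs when the runtime inputs match the exported
# in_spec, and raises ValueError otherwise.
def _example_check_inputs_match():
    _, in_spec = pytree.tree_flatten(((1, 2), {}))
    flat = _check_inputs_match((3, 4), {}, in_spec)
    assert [value for _, value in flat] == [3, 4]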


def _convert_guards_code_to_fn(
    guards_code: list[str],
    paths_of_placeholders: list[pytree.KeyPath],
):
    """
    Generates a guards function given guards code and paths of placeholders.
    We assume that, based on source information,
    - the tracer generates the guards code
    - the input spec generates the paths of placeholders.

    Example:

    Suppose we are given the guards code "L['z']['k'].size()[1] == 3"
    and we are given that ['z']['k'] is the path of placeholder #2.
    Then we will generate:
    ```
    torch._assert(
        args[2].size()[1] == 3,
        "Guard failed: z['k'].size()[1] == 3",
    )
    ```

    FAQ: Why do we generate code based on (flattened) args instead of
    the original (unflattened) inputs? Because using the original inputs
    would require inserting an additional pytree.unflatten call in our graph.

    FAQ: Why do we not emit RuntimeError on guard failure as we used to?
    Because it is inconvenient :/, get used to AssertionError instead.
    """

    import ast

    from torch.fx.experimental.symbolic_shapes import SYMPY_INTERP

    actual_guards_code = []
    shadow_guards_code = []
    for c in guards_code:
        a, s = c, c
        for idx, path in enumerate(paths_of_placeholders):
            # actual: rewrite L['z']['k']... as args[2]... for execution
            a = a.replace("L" + pytree.keystr(path), f"args[{idx}]")
            # shadow: rewrite L['z']['k']... as z['k']... for error messages
            s = s.replace(
                "L" + pytree.keystr(path),
                path[0].key + pytree.keystr(path[1:]),
            )
        actual_guards_code.append(a)
        shadow_guards_code.append(s.replace("\n", ""))

    # Emit one torch._assert per guard; the shadow code (normalized via an
    # ast parse/unparse round trip) appears in the failure message.
    code_str = "def _(*args):\n"
    for actual, shadow in zip(actual_guards_code, shadow_guards_code):
        _shadow = ast.unparse(ast.parse(shadow, mode="eval"))
        code_str += f'    torch._assert({actual}, "Guard failed: {_shadow}")\n'
    code_str += "    return\n"

    # The generated code references torch._assert and possibly math.inf, and
    # guards may call sympy-interpreted functions, so all of these must be in
    # scope when we exec.
    namespace = {"torch": torch, "math": math, **SYMPY_INTERP}
    exec(code_str, namespace)

    # Wrap the compiled function in a module so it can be installed as a
    # submodule and called from the graph. Marking it impure prevents the
    # call from being eliminated as dead code.
    guards_fn = GuardsFn()
    guards_fn.forward = torch._dynamo.dont_skip_tracing(namespace["_"])
    guards_fn._is_impure = True
    return guards_fn
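

# Minimal sketch (hypothetical helper and values) of the generated guards
# function: a single size guard on a mapping input "x" compiles into one
# torch._assert.
def _example_convert_guards():
    path = (pytree.MappingKey("x"),)
    guards_fn = _convert_guards_code_to_fn(["L['x'].size()[0] == 3"], [path])
    guards_fn(torch.randn(3))  # passes
    # guards_fn(torch.randn(4)) would raise: "Guard failed: x.size()[0] == 3"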


@torch._dynamo.disable
def _check_input_constraints_for_module(self, args, kwargs):
    flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
    _check_input_constraints_for_graph(
        self.graph.find_nodes(op="placeholder"),
        flat_args_with_path,
        self.range_constraints,
    )


def _check_input_constraints_pre_hook(self, args, kwargs):
    # Skip all input checks when validation is explicitly turned off.
    if not self.validate_inputs:
        return

    # When a guards function has been installed (see
    # _convert_guards_code_to_fn), it already checks input constraints inside
    # the graph, so here we only need to check that the input structure
    # matches the exported in_spec.
    if hasattr(self, "_guards_fn"):
        _check_inputs_match(args, kwargs, self._in_spec)
        return

    _check_input_constraints_for_module(self, args, kwargs)


def _unlift_inputs_as_getattr(
    gm: torch.fx.GraphModule,
    lifted_inputs: Sequence[Optional[str]],
) -> tuple[dict[str, torch.fx.Node], dict[str, torch.fx.Node]]:
    """
    Unlift inputs referring to params/buffers/constants as getattr nodes in
    the graph.
    """
    unlifted_name_to_node = {}
    input_name_to_node = {}

    placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
    assert len(lifted_inputs) == len(placeholder_nodes)
    for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs):
        if lifted_node is None:
            input_name_to_node[input_node.name] = input_node
        else:
            with gm.graph.inserting_after(input_node):
                # get_attr warns if the attribute does not yet exist on the
                # module; it is safe to suppress the warning here because the
                # attribute is populated later.
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    getattr_node = gm.graph.get_attr(lifted_node)
                input_node.replace_all_uses_with(getattr_node)
                metadata = input_node.meta
                gm.graph.erase_node(input_node)
                getattr_node.meta = metadata
                getattr_node.meta["from_node"] = [
                    NodeSource(
                        input_node,
                        "ExportedProgram.module().unlift()",
                        [NodeSourceAction.CREATE, NodeSourceAction.REPLACE],
                    )
                ]
                unlifted_name_to_node[lifted_node] = getattr_node

    return unlifted_name_to_node, input_name_to_node
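

# Minimal sketch (hypothetical graph) of unlifting: placeholder "p_w", which
# stands for the lifted buffer "w", is rewritten into a get_attr node, while
# the user input "x" is left alone.
def _example_unlift_inputs_as_getattr():
    g = torch.fx.Graph()
    p_w = g.placeholder("p_w")
    x = g.placeholder("x")
    g.output(g.call_function(torch.ops.aten.add.Tensor, (p_w, x)))
    root = torch.nn.Module()
    root.register_buffer("w", torch.randn(2))
    gm = torch.fx.GraphModule(root, g)
    unlifted, user_inputs = _unlift_inputs_as_getattr(gm, ["w", None])
    assert set(unlifted) == {"w"} and set(user_inputs) == {"x"}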


def _insert_copy_for_mutations(
    gm: torch.fx.GraphModule,
    mutated_outputs: Sequence[Optional[str]],
    unlifted_name_to_node: dict[str, torch.fx.Node],
    input_name_to_node: dict[str, torch.fx.Node],
) -> None:
    """
    Find all the buffers and inputs that were mutated and insert copy_
    operators to reflect mutations.
    """
    output_node = gm.graph.output_node()
    outputs = pytree.tree_flatten(output_node.args)[0]
    assert len(outputs) == len(mutated_outputs)

    user_output_nodes = []
    return_nodes_to_copy = {}
    for return_node, mutated_node_name in zip(outputs, mutated_outputs):
        if mutated_node_name is None:
            user_output_nodes.append(return_node)
            continue

        if mutated_node_name in unlifted_name_to_node:
            mutated_node = unlifted_name_to_node[mutated_node_name]
        elif mutated_node_name in input_name_to_node:
            mutated_node = input_name_to_node[mutated_node_name]
        else:
            raise RuntimeError(
                f"Could not find {mutated_node_name} in either buffer or input nodes"
            )

        with gm.graph.inserting_before(output_node):
            copy_node = gm.graph.call_function(
                torch.ops.aten.copy_.default, (mutated_node, return_node)
            )
            return_nodes_to_copy[return_node] = copy_node

    output_args = tuple(
        return_nodes_to_copy[node] if node in return_nodes_to_copy else node
        for node in user_output_nodes
    )
    with gm.graph.inserting_before(output_node):
        # Only user outputs are returned; mutation outputs are dropped after
        # being written back via copy_.
        new_output = gm.graph.output(output_args)
        output_node.replace_all_uses_with(new_output)
        gm.graph.erase_node(output_node)
        new_output.name = output_node.name
        new_output.meta.update(output_node.meta)
        new_output.meta["from_node"] = [
            NodeSource(
                output_node,
                "ExportedProgram.module().unlift()",
                [NodeSourceAction.CREATE, NodeSourceAction.REPLACE],
            )
        ]


def _get_codegen(
    in_spec: pytree.TreeSpec,
    out_spec: Optional[pytree.TreeSpec],
    forward_arg_names: Optional[list[str]] = None,
) -> _PyTreeCodeGen:
    """
    Create the codegen for the graph module based on the in/out specs.
    """
    if forward_arg_names:
        names = forward_arg_names
    elif (
        in_spec.type == tuple
        and in_spec.num_children == 2
        and in_spec.children_specs[0].type == tuple
        and in_spec.children_specs[1].type == dict
    ):
        # in_spec is the (args, kwargs) pair: number the positional args and
        # reuse the kwarg names from the dict context.
        names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
        names.extend(in_spec.children_specs[1].context)
    else:
        names = [f"arg_{i}" for i in range(in_spec.num_children)]

    return _PyTreeCodeGen(
        _PyTreeInfo(
            names,
            in_spec,
            out_spec,
        )
    )
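

# Minimal sketch (hypothetical helper): for an (args, kwargs) in_spec, the
# generated signature numbers positional inputs and keeps kwarg names.
def _example_get_codegen():
    _, in_spec = pytree.tree_flatten(((1, 2), {"scale": 3.0}))
    codegen = _get_codegen(in_spec, out_spec=None)
    assert codegen.pytree_info.orig_args == ["arg_0", "arg_1", "scale"]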


def _unlift(
    gm: torch.fx.GraphModule,
    lifted_inputs: Sequence[Optional[str]],
    mutated_outputs: Sequence[Optional[str]],
    in_spec: pytree.TreeSpec,
    out_spec: Optional[pytree.TreeSpec],
    forward_arg_names: Optional[list[str]] = None,
):
    """
    Args:
        lifted_inputs: A list matching the graph module's input nodes. For
        an input node that is referring to a lifted parameter/buffer, this
        list will contain the fqn of the corresponding attribute. Otherwise,
        this list will contain None. This is used to unlift the lifted
        parameters as get_attr nodes.

        mutated_outputs: A list matching the graph module's output nodes. For
        an output node that is referring to a mutated buffer or user input,
        this list will contain the name of the corresponding buffer or user
        input that needs to be mutated. Otherwise, this list will contain
        None. This is used to re-insert an inplace copy_ operator to copy the
        mutated values back to the original node.
    """
    unlifted_name_to_node, input_name_to_node = _unlift_inputs_as_getattr(
        gm, lifted_inputs
    )
    _insert_copy_for_mutations(
        gm, mutated_outputs, unlifted_name_to_node, input_name_to_node
    )
    gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names)
    gm.graph.lint()
    gm.recompile()
    return gm


def _register_attrs_to_new_gm(
    new_gm: torch.fx.GraphModule,
    graph_signature: ExportGraphSignature,
    state_dict: dict[str, Any],
    constants: dict[str, Any],
) -> None:
    non_persistent_buffers = set(graph_signature.non_persistent_buffers)
    for name in graph_signature.buffers:
        if name in non_persistent_buffers:
            persistent = False
            value = constants[name]
        else:
            persistent = True
            value = state_dict[name]
        _assign_attr(
            value, new_gm, name, attr_kind=_AttrKind.BUFFER, persistent=persistent
        )
    for name in graph_signature.parameters:
        value = state_dict[name]
        _assign_attr(
            value,
            new_gm,
            name,
            attr_kind=_AttrKind.PARAMETER,
        )

    # Lifted custom objects and tensor constants live in the constants table
    # rather than the state dict; register them as constant attributes.
    for name in chain(
        graph_signature.lifted_custom_objs, graph_signature.lifted_tensor_constants
    ):
        value = constants[name]
        _assign_attr(
            value,
            new_gm,
            name,
            attr_kind=_AttrKind.CONSTANT,
        )


class _StatefulGraphModuleFactory(type):
    """
    Metaclass that ensures a private constructor for _StatefulGraphModule.
    """

    def __call__(cls, *args, **kwargs):
        raise TypeError(
            f"{cls.__module__}.{cls.__qualname__} has no public constructor."
        )

    def _create(cls, root, graph, range_constraints=None):
        return super().__call__(
            root,
            graph,
            range_constraints=range_constraints,
        )


class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory):
    def __init__(self, root, graph, range_constraints=None):
        super().__init__(root, graph)
        # Clients can set validate_inputs to False to make the
        # input-constraint pre-hook a no-op.
        self.range_constraints = range_constraints or []
        self.validate_inputs = True


def _create_stateful_graph_module(
    plain_graph_module: torch.fx.GraphModule,
    range_constraints,
    ep: ExportedProgram,
) -> _StatefulGraphModule:
    stateful_gm = _StatefulGraphModule._create(
        plain_graph_module,
        plain_graph_module.graph,
        range_constraints=range_constraints,
    )

    module_types = _get_graph_inputs_of_type_nn_module(ep.example_inputs)
    stateful_gm.register_forward_pre_hook(
        lambda *args, **kwargs: _enter_enable_graph_inputs_of_type_nn_module(
            module_types
        )
    )
    stateful_gm.register_forward_pre_hook(
        _check_input_constraints_pre_hook, with_kwargs=True
    )

    stateful_gm.register_forward_hook(
        lambda *args, **kwargs: _exit_enable_graph_inputs_of_type_nn_module(
            module_types
        ),
        always_call=True,
    )

    # The GraphModule constructor registers lifted tensor constants as
    # persistent buffers. Fix them up to be constant attributes again,
    # detaching any that (incorrectly) require grad.
    original_tensor_to_detached_tensor = {}

    for constant_fqn in ep.graph_signature.lifted_tensor_constants:
        # Sometimes a constant requires gradient, which is likely a user
        # error, e.g. an nn.Parameter that was never registered on the
        # module. Detach it so the unlifted module treats it as a constant.
        buffer = stateful_gm.get_buffer(constant_fqn)
        if buffer.requires_grad:
            warnings.warn(
                f"A model attribute `{constant_fqn}` requires gradient, "
                f"but it's not properly registered as a parameter. "
                f"torch.export will detach it and treat it as a constant tensor, "
                f"but please register it as a parameter instead."
            )
            detached_buffer = buffer.detach()
            original_tensor_to_detached_tensor[buffer] = detached_buffer
            buffer = detached_buffer
        *prefix, field = constant_fqn.rsplit(".")
        submod = torch.fx.graph_module._get_attr_via_attr_list(stateful_gm, prefix)
        delattr(submod, field)
        _assign_attr(buffer, stateful_gm, constant_fqn, attr_kind=_AttrKind.CONSTANT)

    # Constants are not preserved by the GraphModule constructor; re-assign
    # any that are still missing.
    for const_name, value in ep.constants.items():
        if not torch.fx.graph_module._has_attr(stateful_gm, const_name):
            if isinstance(value, torch.Tensor):
                if value.requires_grad:
                    warnings.warn(
                        f"A model attribute `{const_name}` requires gradient, "
                        f"but it's not properly registered as a parameter. "
                        f"torch.export will detach it and treat it as a constant tensor, "
                        f"but please register it as a parameter instead."
                    )
                    if value in original_tensor_to_detached_tensor:
                        value = original_tensor_to_detached_tensor[value]
                    else:
                        detached_value = value.detach()
                        original_tensor_to_detached_tensor[value] = detached_value
                        value = detached_value
            _assign_attr(
                value,
                stateful_gm,
                const_name,
                attr_kind=_AttrKind.CONSTANT,
            )

    # Non-persistent buffers are also not preserved by the GraphModule
    # constructor; re-register them from the plain graph module.
    for buffer in ep.graph_signature.non_persistent_buffers:
        _assign_attr(
            plain_graph_module.get_buffer(buffer),
            stateful_gm,
            buffer,
            attr_kind=_AttrKind.BUFFER,
            persistent=False,
        )

    return stateful_gm


def _get_input_paths(example_inputs, signature):
    """
    Generate paths of placeholders, needed for generating the guards function.

    NOTE: Here we make use of the example inputs used for export as well as
    the signature of the unlifted graph module (not preserved by export).
    """
    args, kwargs = example_inputs
    ctx = signature.bind(*args, **kwargs).arguments
    flat_example_inputs_with_paths = pytree.tree_leaves_with_path(ctx)
    return [path for path, _ in flat_example_inputs_with_paths]
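

# Minimal sketch (hypothetical function f): binding example inputs to the
# forward signature yields one pytree path per placeholder, keyed by
# parameter name, matching the L['name'] sources used in guards code.
def _example_get_input_paths():
    def f(x, y):
        return x + y

    example_inputs = ((torch.randn(2),), {"y": torch.randn(2)})
    paths = _get_input_paths(example_inputs, inspect.signature(f))
    assert [pytree.keystr(p) for p in paths] == ["['x']", "['y']"]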


def _get_input_guards_for_graph(
    placeholders: list[torch.fx.Node],
    range_constraints: dict[sympy.Symbol, ValueRanges],
    paths_for_placeholders: list[pytree.KeyPath],
):
    """
    Guards generated by the tracer include conditions observed in code, but
    do not include some additional checks we typically do in export.
    For example, when dynamic shapes get specialized, are specified to be
    within a range, or are specified to be in some equational relation,
    corresponding input validation is done within a pre_hook, specifically,
    `_check_input_constraints_for_graph`.

    Here we generate guards corresponding to the checks that happen in
    `_check_input_constraints_for_graph`, and add them to the guards already
    generated by the tracer. In the future, it may be worthwhile to separate
    them so that we can allow clients to turn off one but not the other.
    (Looking at you, AOTI.)

    NOTE: We should eventually reconcile this logic with `build_guards` that
    is used by AOT Precompile.
    """

    deferred_expressions = []
    new_guards_code = []
    sources: dict[sympy.Expr, str] = {}

    def handle_symint(expr, src):
        if len(expr.free_symbols) == 1:
            # A single-symbol expression may define its symbol in terms of a
            # source; solve it later, once all sources have been seen.
            deferred_expressions.append((src, expr))
        if expr in sources:
            # The same expression was already seen at another source, so the
            # two sources must be equal.
            orig_src = sources[expr]
            new_guards_code.append(f"{src} == {orig_src}")
        else:
            sources[expr] = src
            # Emit range guards, skipping the default lower bound of 2 and
            # the unbounded upper limit.
            min_val, max_val = _convert_range_to_int(range_constraints[expr])
            if min_val > 2:
                new_guards_code.append(f"{src} >= {min_val}")
            if max_val < math.inf:
                new_guards_code.append(f"{src} <= {max_val}")

    for placeholder, path in zip(placeholders, paths_for_placeholders):
        src = "L" + pytree.keystr(path)
        meta = placeholder.meta["val"]
        # Non-tensor inputs specialize to their traced values.
        if isinstance(meta, int):
            new_guards_code.append(f"{src} == {meta}")
        elif isinstance(meta, float):
            if meta == math.inf:
                new_guards_code.append(f"{src} == math.inf")
            elif meta == -math.inf:
                new_guards_code.append(f"{src} == -math.inf")
            else:
                new_guards_code.append(f"{src} == {meta}")
        elif isinstance(meta, str):
            new_guards_code.append(f"{src} == '{meta}'")
        elif isinstance(meta, torch.SymInt) and meta.node.expr in range_constraints:
            # Symbolic scalar input.
            handle_symint(meta.node.expr, src)
        elif isinstance(meta, torch.Tensor):
            for i, dim in enumerate(meta.shape):
                src = "L" + pytree.keystr(path) + f".size()[{i}]"
                if isinstance(dim, int):
                    # Static dim: specialize.
                    new_guards_code.append(f"{src} == {dim}")
                elif (
                    isinstance(dim, torch.SymInt) and dim.node.expr in range_constraints
                ):
                    # Dynamic dim: emit range/equality guards.
                    handle_symint(dim.node.expr, src)

    unification_map: dict[sympy.Symbol, sympy.Expr] = {}
    py_printer = torch.utils._sympy.printers.PythonPrinter()

    # Solve deferred single-symbol expressions against their sources.
    for src, expr in deferred_expressions:
        # expr has exactly one free symbol by construction
        symbol = next(iter(expr.free_symbols))
        if symbol in sources:
            # The symbol is directly backed by a source, so the tracer has
            # already generated any guards on it.
            continue

        if symbol in unification_map:
            # The symbol already has a definition: substitute it and emit an
            # equality guard relating this source to that definition.
            substitution = expr.subs(unification_map)
            new_guards_code.append(
                py_printer.doprint(sympy.Eq(substitution, sympy.Symbol(src)))
            )
        else:
            # First time we see this symbol: try to define it in terms of
            # the source.
            solution = try_solve(sympy.Eq(expr, sympy.Symbol(src)), symbol)
            if solution is not None:
                definition = solution[1]
                unification_map[symbol] = definition

    return new_guards_code
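

# Minimal sketch (hypothetical placeholder): a static tensor input yields one
# size-equality guard per dimension.
def _example_input_guards():
    g = torch.fx.Graph()
    p = g.placeholder("x")
    p.meta["val"] = torch.randn(4, 8)
    guards = _get_input_guards_for_graph([p], {}, [(pytree.MappingKey("x"),)])
    assert guards == ["L['x'].size()[0] == 4", "L['x'].size()[1] == 8"]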


def _unlift_exported_program_lifted_states(
    ep: ExportedProgram, check_guards=True
) -> torch.fx.GraphModule:
    # ExecuTorch does not expect a guards function in the unlifted module;
    # skip guard generation when executorch code appears in the call stack.
    frame = inspect.currentframe()
    while frame is not None:
        if "executorch" in frame.f_code.co_filename:
            check_guards = False
            break
        frame = frame.f_back

    # Effect tokens are internal bookkeeping; strip them before unlifting.
    if ep.verifiers[0].dialect != "TRAINING":
        ep = _remove_effect_tokens(ep)

    new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
    _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
    forward_arg_names = (
        sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None
    )
    lifted_inputs: list[Optional[str]] = [
        (
            in_spec.target
            if in_spec.kind
            in (
                InputKind.BUFFER,
                InputKind.CONSTANT_TENSOR,
                InputKind.PARAMETER,
                InputKind.CUSTOM_OBJ,
            )
            else None
        )
        for in_spec in ep.graph_signature.input_specs
    ]

    mutated_outputs: list[Optional[str]] = [
        (
            out_spec.target
            if out_spec.kind
            in (
                OutputKind.BUFFER_MUTATION,
                OutputKind.USER_INPUT_MUTATION,
                OutputKind.PARAMETER_MUTATION,
            )
            else None
        )
        for out_spec in ep.graph_signature.output_specs
    ]

    source_node_dict = {
        node.name: node for node in ep.graph.nodes if node.op != "placeholder"
    }
    # Placeholder node names may change during deepcopy, so key them by target.
    placeholder_source_node_dict = {
        node.target: node for node in ep.graph.nodes if node.op == "placeholder"
    }
    for node in new_gm.graph.nodes:
        source_node = None
        if node.op == "placeholder":
            source_node = placeholder_source_node_dict.get(node.target)
        else:
            source_node = source_node_dict.get(node.name)
        node.meta["from_node"] = [
            NodeSource(
                source_node,
                "ExportedProgram.module()",
                NodeSourceAction.CREATE,
            )
        ]

    assert ep.call_spec.in_spec is not None
    new_gm = _unlift(
        new_gm,
        lifted_inputs,
        mutated_outputs,
        ep.call_spec.in_spec,
        ep.call_spec.out_spec,
        forward_arg_names=forward_arg_names,
    )
    unlift_gm = _create_stateful_graph_module(new_gm, ep.range_constraints, ep)
    unlift_gm.meta.update(ep.graph_module.meta)

    # Generate a guards function from the exported guards plus input
    # constraints, install it as a submodule, and call it on the placeholders
    # at the top of the graph.
    graph = unlift_gm.graph
    placeholders = graph.find_nodes(op="placeholder")
    if check_guards and placeholders and ep.example_inputs:
        input_paths = _get_input_paths(
            ep.example_inputs,
            inspect.signature(unlift_gm.forward),
        )
        guards_code = _get_input_guards_for_graph(
            placeholders, ep.range_constraints, input_paths
        )
        guards_code.extend(ep._guards_code)
        unlift_gm._guards_fn = _convert_guards_code_to_fn(guards_code, input_paths)

        root_nn_module_stack = torch.fx._utils.first_call_function_nn_module_stack(
            graph
        )
        with graph.inserting_after(placeholders[-1]):
            node = graph.call_module("_guards_fn", tuple(placeholders))
            node.meta["nn_module_stack"] = root_nn_module_stack

        unlift_gm.recompile()

    return unlift_gm


class GuardsFn(torch.nn.Module):
    """
    Module class for guard functions.
    """

    def forward(self, *args):
        pass
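

# End-to-end sketch (hypothetical module M): ExportedProgram.module() routes
# through _unlift_exported_program_lifted_states, so parameters become real
# attributes again and inputs are guard-checked on every call.
def _example_unlift_exported_program():
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.nn.Parameter(torch.randn(3))

        def forward(self, x):
            return x + self.w

    ep = torch.export.export(M(), (torch.randn(3),))
    m = ep.module()  # calls _unlift_exported_program_lifted_states(ep)
    assert isinstance(m.get_parameter("w"), torch.nn.Parameter)
    m(torch.randn(3))  # runs with input checks; wrong shapes would fail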