| """ |
| Core guard system for Dynamo that detects when compiled code needs to be recompiled due to |
| changes in program state. Guards are conditions that must remain true for previously-compiled |
| code to be valid for reuse. |
| |
| This module provides the infrastructure for creating, managing and checking guards, including: |
| - Guard creation and composition |
| - Guard state management and invalidation |
| - Guard checking and failure handling |
| - Utilities for guard optimization and debugging |
| - Integration with Dynamo's compilation caching |
| |
| The guard system is critical for Dynamo's ability to efficiently reuse compiled code while |
| maintaining correctness by detecting when recompilation is necessary due to changes in |
| program state, tensor properties, or control flow. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import ast |
| import builtins |
| import collections |
| import dataclasses |
| import enum |
| import functools |
| import importlib |
| import inspect |
| import io |
| import logging |
| import math |
| import pickle |
| import sys |
| import textwrap |
| import traceback |
| import types |
| import warnings |
| import weakref |
| from contextlib import contextmanager |
| from copy import deepcopy |
| from inspect import currentframe |
| from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union |
|
|
|
|
| try: |
| from typing import LiteralString |
| except ImportError: |
| from typing_extensions import LiteralString |
|
|
| from typing_extensions import TypeAliasType, TypeVar |
| from weakref import ReferenceType |
|
|
| import torch |
| import torch.overrides |
| import torch.utils._device |
| from torch._C._dynamo.eval_frame import code_framelocals_names |
| from torch._C._dynamo.guards import ( |
| check_obj_id, |
| check_type_id, |
| ClosureGuardAccessor, |
| CodeGuardAccessor, |
| dict_version, |
| DictGetItemGuardAccessor, |
| DictGuardManager, |
| FuncDefaultsGuardAccessor, |
| FuncKwDefaultsGuardAccessor, |
| GetAttrGuardAccessor, |
| GetGenericDictGuardAccessor, |
| GuardAccessor, |
| GuardDebugInfo, |
| GuardManager, |
| install_no_tensor_aliasing_guard, |
| install_object_aliasing_guard, |
| install_storage_overlapping_guard, |
| install_symbolic_shape_guard, |
| LeafGuard, |
| profile_guard_manager, |
| RelationalGuard, |
| RootGuardManager, |
| TupleGetItemGuardAccessor, |
| TypeDictGuardAccessor, |
| TypeGuardAccessor, |
| TypeMROGuardAccessor, |
| ) |
| from torch._dynamo.source import ( |
| get_global_source_name, |
| get_local_source_name, |
| IndexedSource, |
| is_from_flatten_script_object_source, |
| is_from_local_source, |
| is_from_optimizer_source, |
| is_from_skip_guard_source, |
| is_from_unspecialized_builtin_nn_module_source, |
| TensorProperty, |
| TensorPropertySource, |
| ) |
| from torch._dynamo.utils import CompileEventLogger, get_metrics_context |
| from torch._guards import ( |
| CompileContext, |
| CompileId, |
| DuplicateInputs, |
| Guard, |
| GuardBuilderBase, |
| GuardEnvExpr, |
| GuardSource, |
| Source, |
| StorageOverlap, |
| ) |
| from torch._inductor.utils import IndentedBuffer |
| from torch._logging import structured |
| from torch._utils_internal import justknobs_check |
| from torch.fx.experimental.symbolic_shapes import ( |
| _CppShapeGuardsHelper, |
| _ShapeGuardsHelper, |
| EqualityConstraint, |
| is_symbolic, |
| SYMPY_INTERP, |
| ) |
| from torch.utils import _pytree as pytree |
| from torch.utils._ordered_set import OrderedSet |
| from torch.utils._traceback import format_frame, report_compile_source_on_error |
| from torch.utils.weak import TensorWeakRef |
|
|
| from . import config, convert_frame, exc |
| from .eval_frame import set_guard_error_hook |
| from .source import ( |
| AttrProxySource, |
| AttrSource, |
| CallFunctionNoArgsSource, |
| CallMethodItemSource, |
| ChainedSource, |
| ClosureSource, |
| CodeSource, |
| ConstantSource, |
| ConstDictKeySource, |
| DataclassFieldsSource, |
| DefaultsSource, |
| DictGetItemSource, |
| DictSubclassGetItemSource, |
| FlattenScriptObjectSource, |
| FloatTensorSource, |
| FSDPNNModuleSource, |
| GenericAttrSource, |
| GetItemSource, |
| GlobalSource, |
| GlobalStateSource, |
| GlobalWeakRefSource, |
| GradSource, |
| ListGetItemSource, |
| LocalSource, |
| NamedTupleFieldsSource, |
| NNModuleSource, |
| NonSerializableSetGetItemSource, |
| NumpyTensorSource, |
| OptimizerSource, |
| ScriptObjectQualifiedNameSource, |
| ShapeEnvSource, |
| SubclassAttrListSource, |
| TorchFunctionModeStackSource, |
| TorchSource, |
| TupleIteratorGetItemSource, |
| TypeDictSource, |
| TypeMROSource, |
| TypeSource, |
| UnspecializedBuiltinNNModuleSource, |
| UnspecializedNNModuleSource, |
| UnspecializedParamBufferSource, |
| WeakRefCallSource, |
| ) |
| from .types import ( |
| CacheEntry, |
| DynamoFrameType, |
| ExtraState, |
| GuardedCode, |
| GuardFail, |
| GuardFilterEntry, |
| GuardFn, |
| ) |
| from .utils import ( |
| builtin_dict_keys, |
| common_constant_types, |
| dataclass_fields, |
| dict_keys, |
| get_custom_getattr, |
| get_torch_function_mode_stack, |
| get_torch_function_mode_stack_at, |
| guard_failures, |
| istype, |
| key_is_id, |
| key_to_id, |
| normalize_range_iter, |
| orig_code_map, |
| tensor_always_has_static_shape, |
| tuple_iterator_getitem, |
| tuple_iterator_len, |
| unpatched_nn_module_getattr, |
| verify_guard_fn_signature, |
| ) |
|
|
|
|
| guard_manager_testing_hook_fn: Optional[Callable[[Any, Any, Any], Any]] = None |
|
|
| try: |
| import numpy as np |
| except ModuleNotFoundError: |
| np = None |
|
|
|
|
| if TYPE_CHECKING: |
| from collections.abc import Generator, KeysView, Sequence |
|
|
| from sympy import Symbol |
|
|
| from torch._C import DispatchKeySet |
| from torch._dynamo.output_graph import OutputGraph, OutputGraphGuardsState |
|
|
| T = TypeVar("T") |
| log = logging.getLogger(__name__) |
| guards_log = torch._logging.getArtifactLogger(__name__, "guards") |
| recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles") |
| recompiles_verbose_log = torch._logging.getArtifactLogger( |
| __name__, "recompiles_verbose" |
| ) |
| verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards") |
|
|
|
|
# Dunder attribute names consulted by GuardManagerWrapper.find_tag_safe_roots:
# attribute accessors (and tuple sources ending) with one of these names are
# accepted as tag safe when config.assume_dunder_attributes_remain_unchanged
# is enabled.
dunder_attrs_assumed_constants = (
    "__defaults__",
    "__kwdefaults__",
    "__code__",
    "__closure__",
    "__annotations__",
    "__func__",
    "__mro__",
)
|
|
|
|
class IndentedBufferWithPrefix(IndentedBuffer):
    """IndentedBuffer that draws nesting with "| " columns and "+- " markers."""

    def prefix(self) -> str:
        """Return one "| " marker per indentation column (indent * tabwidth)."""
        columns = self._indent * self.tabwidth
        return "| " * columns

    def writeline(self, line: str, skip_prefix: bool = False) -> None:
        """Write ``line``; unless ``skip_prefix`` is set, prepend "+- " to it."""
        text = line if skip_prefix else "+- " + line
        super().writeline(text)
|
|
|
|
class GuardManagerWrapper:
    """
    Python-side holder of a guard tree.

    The wrapper owns the RootGuardManager (``self.root``). An instance is kept
    in the Dynamo cache entry so that the cache entry can reach the root guard
    manager and call check_nopybind directly from C++.
    """

    def __init__(self, root: Optional[RootGuardManager] = None) -> None:
        """Wrap ``root``; a fresh RootGuardManager is created when it is None."""
        self.root = RootGuardManager() if root is None else root

        # Filled in later by prepare_diff_guard_manager().
        self.diff_guard_root: Optional[RootGuardManager] = None
        self.closure_vars: Optional[dict[str, Any]] = None
        self.args: Optional[list[str]] = None
        # Extended by populate_code_parts_for_debugging().
        self.code_parts: list[str] = []
        self.verbose_code_parts: Optional[list[str]] = None
        self.global_scope: Optional[dict[str, Any]] = None
        self.guard_fail_fn: Optional[Callable[[GuardFail], None]] = None
        self.cache_entry: Optional[CacheEntry] = None
        self.extra_state: Optional[ExtraState] = None
        self.id_matched_objs: dict[str, ReferenceType[object]] = {}
        self.no_tensor_aliasing_sources: list[str] = []

        # Relational guards already emitted while rendering the tree in __str__.
        self.printed_relational_guards: set[RelationalGuard] = set()

        # Sources selected by collect_diff_guard_sources().
        self.diff_guard_sources: OrderedSet[str] = OrderedSet()

    @contextmanager
    def _preserve_printed_relational_guards(self) -> Generator[None, None, None]:
        """Run the body with an empty printed-guards set and empty it again after."""
        self.printed_relational_guards = set()
        try:
            yield
        finally:
            self.printed_relational_guards = set()

    def collect_diff_guard_sources(self) -> OrderedSet[str]:
        """
        Walk the guard tree and record the sources of "diff guard" nodes.

        A node qualifies when its source is already in ``diff_guard_sources``,
        when its ``fail_count()`` is positive, or when any node below it
        qualifies. The (mutated) ``diff_guard_sources`` set is returned.
        """
        recorded = self.diff_guard_sources

        def qualifies_on_its_own(node: GuardManager) -> bool:
            return node.get_source() in recorded or node.fail_count() > 0

        def visit_dict_manager(node: DictGuardManager) -> bool:
            selected = qualifies_on_its_own(node)
            for _index, (key_mgr, val_mgr) in sorted(
                node.get_key_value_managers().items()
            ):
                # Both sides are always visited: the walk records sources as
                # a side effect, so it must not be short-circuited.
                key_selected = visit(key_mgr)
                val_selected = visit(val_mgr)
                selected = selected or key_selected or val_selected

            if selected:
                recorded.add(node.get_source())
            return selected

        def visit_manager(node: GuardManager) -> bool:
            assert not isinstance(node, DictGuardManager)

            selected = qualifies_on_its_own(node)
            for child_mgr in node.get_child_managers():
                # Visit first, then combine, so every child is traversed.
                child_selected = visit(child_mgr)
                selected = selected or child_selected

            if selected:
                recorded.add(node.get_source())
            return selected

        def visit(node: GuardManager) -> bool:
            if node is None:
                return False
            handler = (
                visit_dict_manager
                if isinstance(node, DictGuardManager)
                else visit_manager
            )
            return handler(node)

        visit(self.root)
        return self.diff_guard_sources

    def finalize(self) -> None:
        """Optionally mark tag safe roots, then build the diff guard manager."""
        recursive_tags_enabled = (
            config.use_recursive_dict_tags_for_guards
            and justknobs_check("pytorch/compiler:use_recursive_dict_tags_for_guards")
        )
        if recursive_tags_enabled:
            self.find_tag_safe_roots()
        self.prepare_diff_guard_manager()

    def prepare_diff_guard_manager(self) -> None:
        """Collect the diff guard sources and clone the matching sub-tree."""
        self.collect_diff_guard_sources()
        self.populate_diff_guard_manager()

    def find_tag_safe_roots(self) -> None:
        """
        Identify ``tag safe nodes`` and ``tag safe roots`` within the guard tree.

        tag safe node
        -------------
        A *tag safe node* is a ``GuardManager`` that this method marks via
        ``mark_tag_safe()``. The cases handled below are:

        1. Immutable value - ``is_guarded_value_immutable()`` is true. A tensor
           is only marked when the manager has no accessors and no object
           aliasing guard, so that symbolic guards still run.
        2. Nested tag safe dictionary - a ``dict`` whose keys and values are
           all tag safe (checked recursively), so the whole nested structure
           can be skipped once its identity tag matches.
        3. ``nn.Module`` - every accessor is ``GetGenericDictGuardAccessor`` or
           ``TypeGuardAccessor`` and every child is tag safe.
        4. Functions, methods, staticmethods and classmethods, and tuples whose
           source ends with an assumed-constant dunder attribute - only when
           ``config.assume_dunder_attributes_remain_unchanged`` is set.
        5. Cells reached only through ``cell_contents`` and types reached only
           through their type dict / MRO accessors.

        For a tag safe node, verifying the identity/tag of the top-level
        dictionary is enough to guarantee the entire subtree is unchanged,
        which enables a fast-path guard check.

        tag safe root
        -------------
        A ``tag safe root`` is a tag safe node whose parent is not tag safe. If
        a tag safe root's dictionary tag matches, the whole subtree beneath it
        is skipped. The guarded object must support weakref (its weakref is
        saved on first invocation and validated later - see the Recursive dict
        tag matching note). ``dict`` objects do NOT support weakref, so only
        nn module guard managers are marked as tag safe roots for now.

        Algorithm
        ---------
        Post-order traversal: children are classified first, tag safety is then
        propagated upward, and a fully tag safe nn.Module subtree replaces the
        roots found beneath it with the module node itself.
        """

        def check_tag_safety(
            node: GuardManager, accepted_accessors: tuple[type[GuardAccessor], ...]
        ) -> bool:
            # True when every (accessor, child) pair uses an accepted accessor
            # type and leads to a child that is already tag safe.
            pairs = zip(node.get_accessors(), node.get_child_managers())
            return all(
                isinstance(accessor, accepted_accessors) and mgr.is_tag_safe()
                for accessor, mgr in pairs
            )

        def visit_dict_manager(node: DictGuardManager) -> list[GuardManager]:
            assert issubclass(node.get_type_of_guarded_value(), dict)

            entries = sorted(node.get_key_value_managers().items())

            # First pass: classify every key/value subtree. Only value
            # subtrees contribute candidate roots.
            roots_below: list[GuardManager] = []
            for _index, (key_mgr, val_mgr) in entries:
                if key_mgr is not None:
                    visit(key_mgr)
                if val_mgr is not None:
                    roots_below.extend(visit(val_mgr))

            # Second pass: the dict is tag safe only if all present key and
            # value managers are.
            all_entries_safe = True
            for _index, (key_mgr, val_mgr) in entries:
                if key_mgr:
                    all_entries_safe &= key_mgr.is_tag_safe()
                if val_mgr:
                    all_entries_safe &= val_mgr.is_tag_safe()

            if all_entries_safe:
                node.mark_tag_safe()
            return roots_below

        def visit_manager(node: GuardManager) -> list[GuardManager]:
            assert not isinstance(node, DictGuardManager)

            # Post-order: children first.
            roots_below: list[GuardManager] = []
            for child_mgr in node.get_child_managers():
                roots_below.extend(visit(child_mgr))

            guarded_type = node.get_type_of_guarded_value()

            if node.is_guarded_value_immutable():
                if not issubclass(guarded_type, torch.Tensor):
                    node.mark_tag_safe()
                elif node.has_no_accessors() and not node.has_object_aliasing_guard():
                    # Tensors qualify only without accessors/aliasing guards.
                    node.mark_tag_safe()
            elif issubclass(guarded_type, dict):
                if check_tag_safety(node, (DictGetItemGuardAccessor,)):
                    node.mark_tag_safe()
            elif issubclass(guarded_type, torch.nn.Module):
                if check_tag_safety(
                    node, (GetGenericDictGuardAccessor, TypeGuardAccessor)
                ):
                    node.mark_tag_safe()
                    # This node becomes the candidate root; the roots found
                    # in its subtree are discarded.
                    return [node]
            elif (
                guarded_type
                in (
                    types.FunctionType,
                    types.MethodType,
                    staticmethod,
                    classmethod,
                )
                and config.assume_dunder_attributes_remain_unchanged
            ):
                children_safe = check_tag_safety(
                    node,
                    (
                        CodeGuardAccessor,
                        ClosureGuardAccessor,
                        FuncDefaultsGuardAccessor,
                        FuncKwDefaultsGuardAccessor,
                        GetAttrGuardAccessor,
                    ),
                )
                # Plain attribute accessors are accepted only for the dunder
                # attributes that are assumed to stay constant.
                only_constant_dunders = all(
                    accessor.get_attr_name() in dunder_attrs_assumed_constants
                    for accessor in node.get_accessors()
                    if isinstance(accessor, GetAttrGuardAccessor)
                )
                if children_safe and only_constant_dunders:
                    node.mark_tag_safe()
            elif issubclass(guarded_type, types.CellType):
                children_safe = check_tag_safety(node, (GetAttrGuardAccessor,))
                only_cell_contents = all(
                    isinstance(accessor, GetAttrGuardAccessor)
                    and accessor.get_attr_name() == "cell_contents"
                    for accessor in node.get_accessors()
                )
                if children_safe and only_cell_contents:
                    node.mark_tag_safe()
            elif (
                issubclass(guarded_type, tuple)
                and node.get_source().endswith(dunder_attrs_assumed_constants)
                and config.assume_dunder_attributes_remain_unchanged
            ):
                if check_tag_safety(node, (TupleGetItemGuardAccessor,)):
                    node.mark_tag_safe()
            elif issubclass(guarded_type, type):
                if check_tag_safety(
                    node, (TypeDictGuardAccessor, TypeMROGuardAccessor)
                ):
                    node.mark_tag_safe()

            return roots_below

        def visit(node: GuardManager) -> list[GuardManager]:
            if node is None:
                return []
            if isinstance(node, DictGuardManager):
                return visit_dict_manager(node)
            return visit_manager(node)

        # Only nn.Module candidates are promoted to tag safe roots.
        for candidate in visit(self.root):
            if issubclass(candidate.get_type_of_guarded_value(), torch.nn.Module):
                candidate.mark_tag_safe_root()

    def populate_diff_guard_manager(self) -> None:
        """Clone the tree restricted to ``diff_guard_sources`` into ``diff_guard_root``."""
        self.diff_guard_root = self.clone_with_chosen_sources(self.diff_guard_sources)

        # An attached cache entry is told to pick up the new diff guard root.
        if self.cache_entry:
            self.cache_entry.update_diff_guard_root_manager()

    def clone_with_chosen_sources(
        self, chosen_sources: OrderedSet[str]
    ) -> RootGuardManager:
        """Clone ``self.root`` using a filter that accepts sources in ``chosen_sources``."""
        return self.root.clone_manager(
            lambda node_mgr: node_mgr.get_source() in chosen_sources
        )

    def get_guard_lines(self, guard: LeafGuard) -> list[str]:
        """Return the guard's verbose code parts, each prefixed with its class name."""
        label = guard.__class__.__name__ + ": "
        return [label + part for part in guard.verbose_code_parts()]

    def get_manager_line(
        self, guard_manager: GuardManager, accessor_str: Optional[str] = None
    ) -> str:
        """Describe a manager: class, source, optional accessor, type and tag flags."""
        source = guard_manager.get_source()
        fields = [guard_manager.__class__.__name__ + ": source=" + source]
        if accessor_str:
            fields.append(accessor_str)
        fields.append(f"type={guard_manager.get_type_of_guarded_value()}")
        fields.append(
            f"tag_safe=({guard_manager.is_tag_safe()}, {guard_manager.is_tag_safe_root()})"
        )
        return ", ".join(fields)

    def construct_dict_manager_string(
        self, mgr: DictGuardManager, body: IndentedBufferWithPrefix
    ) -> None:
        """Render the key/value managers of ``mgr`` into ``body``, ordered by index."""
        for idx, (key_mgr, val_mgr) in sorted(mgr.get_key_value_managers().items()):
            body.writeline(f"KeyValueManager pair at index={idx}")
            with body.indent():
                for role, sub_mgr in (
                    ("KeyManager", key_mgr),
                    ("ValueManager", val_mgr),
                ):
                    if sub_mgr:
                        body.writeline(f"{role}: {self.get_manager_line(sub_mgr)}")
                        self.construct_manager_string(sub_mgr, body)

    def construct_manager_string(
        self, mgr: GuardManager, body: IndentedBufferWithPrefix
    ) -> None:
        """Render ``mgr``'s leaf guards and, recursively, its children into ``body``."""
        with body.indent():
            for guard in mgr.get_leaf_guards():
                is_relational = isinstance(guard, RelationalGuard)
                if is_relational and guard in self.printed_relational_guards:
                    # Already printed in full elsewhere: only name it here.
                    body.writelines([guard.__class__.__name__])
                    continue
                if is_relational:
                    self.printed_relational_guards.add(guard)
                body.writelines(self.get_guard_lines(guard))

            # Dict managers additionally carry per-index key/value managers.
            if isinstance(mgr, DictGuardManager):
                self.construct_dict_manager_string(mgr, body)

            # Children reached through accessors (any kind of manager).
            children = zip(mgr.get_accessors(), mgr.get_child_managers())
            for accessor, child_mgr in children:
                accessor_desc = f"accessed_by={accessor.repr()}"
                body.writeline(self.get_manager_line(child_mgr, accessor_desc))
                self.construct_manager_string(child_mgr, body)

    def __str__(self) -> str:
        """Return a textual rendering of the whole guard tree."""
        with self._preserve_printed_relational_guards():
            body = IndentedBufferWithPrefix()
            body.tabwidth = 1
            for header in ("", "TREE_GUARD_MANAGER:"):
                body.writeline(header, skip_prefix=True)
            body.writeline("RootGuardManager")
            self.construct_manager_string(self.root, body)
            # Epilogue lambda guards are listed when the root exposes them.
            if hasattr(self.root, "get_epilogue_lambda_guards"):
                for guard in self.root.get_epilogue_lambda_guards():
                    body.writelines(self.get_guard_lines(guard))
            return body.getvalue()

    def check(self, x: Any) -> bool:
        """Delegate to ``self.root.check``."""
        return self.root.check(x)

    def check_verbose(self, x: Any) -> GuardDebugInfo:
        """Delegate to ``self.root.check_verbose``."""
        return self.root.check_verbose(x)

    def populate_code_parts_for_debugging(self) -> None:
        """
        Append every leaf guard's code parts to ``self.code_parts``.

        Each verbose code part is cut at its first "#" and right-stripped.
        A relational guard contributes only the first time it is encountered.
        """
        relational_guards_seen = set()

        def strip_trailing_comment(leaf_guard: LeafGuard) -> list[str]:
            return [
                verbose_part.split("#")[0].rstrip()
                for verbose_part in leaf_guard.verbose_code_parts()
            ]

        def visit(mgr: GuardManager) -> None:
            for guard in mgr.get_leaf_guards():
                is_relational = isinstance(guard, RelationalGuard)
                if is_relational and guard in relational_guards_seen:
                    continue
                self.code_parts.extend(strip_trailing_comment(guard))
                if is_relational:
                    relational_guards_seen.add(guard)

            for child_mgr in mgr.get_child_managers():
                visit(child_mgr)

        visit(self.root)
|
|
|
|
| def from_numpy(a: Any) -> torch.Tensor: |
| |
| |
| |
| with torch.overrides._enable_torch_function(): |
| return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a |
|
|
|
|
| |
| @functools.cache |
| def uninteresting_files() -> set[str]: |
| import torch._dynamo.external_utils |
| import torch._dynamo.polyfills |
|
|
| mods = [torch._dynamo.external_utils, torch._dynamo.polyfills] |
|
|
| from torch._dynamo.polyfills.loader import POLYFILLED_MODULES |
|
|
| mods.extend(POLYFILLED_MODULES) |
|
|
| return {inspect.getfile(m) for m in mods} |
|
|
|
|
# Lazily-built mapping returned by _get_closure_vars(); None until first use.
_CLOSURE_VARS: Optional[dict[str, object]] = None


def _get_closure_vars() -> dict[str, object]:
    """
    Return the shared name -> helper mapping, building it on first call.

    The same dict object is returned on every call.
    """
    global _CLOSURE_VARS
    if _CLOSURE_VARS is not None:
        return _CLOSURE_VARS

    numpy_isnan = None if np is None else np.isnan
    _CLOSURE_VARS = {
        "___check_type_id": check_type_id,
        "___check_obj_id": check_obj_id,
        "___odict_getitem": collections.OrderedDict.__getitem__,
        "___key_to_id": key_to_id,
        "___dict_version": dict_version,
        # Note the swapped argument order: (key, dict).
        "___dict_contains": lambda a, b: dict.__contains__(b, a),
        "___tuple_iterator_len": tuple_iterator_len,
        "___normalize_range_iter": normalize_range_iter,
        "___tuple_iterator_getitem": tuple_iterator_getitem,
        "___dataclass_fields": dataclass_fields,
        "___namedtuple_fields": lambda x: x._fields,
        "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
        "__math_isnan": math.isnan,
        "__numpy_isnan": numpy_isnan,
        "inf": float("inf"),
        "__load_module": importlib.import_module,
        "utils_device": torch.utils._device,
        "device": torch.device,
        "___from_numpy": from_numpy,
        "___as_tensor": torch._as_tensor_fullprec,
        "torch": torch,
        "inspect": inspect,
    }
    return _CLOSURE_VARS
|
|
|
|
| def _ast_unparse(node: ast.AST) -> str: |
| return ast.unparse(node).replace("\n", "") |
|
|
|
|
| strip_function_call = torch._C._dynamo.strip_function_call |
|
|
|
|
def get_verbose_code_part(code_part: str, guard: Optional[Guard]) -> str:
    """
    Pad ``code_part`` to 60 columns and append a trailing frame comment.

    With a user stack, the innermost frame whose file is not listed by
    ``uninteresting_files()`` is used. Otherwise the last entry of the
    guard's stack summary is used, or "<unknown>" when the summary is empty.
    Nothing is appended when ``guard`` is None or carries no stack.
    """
    extra = ""
    if guard is None:
        pass
    elif guard.user_stack:
        frame = next(
            (
                fs
                for fs in reversed(guard.user_stack)
                if fs.filename not in uninteresting_files()
            ),
            None,
        )
        if frame is not None:
            extra = f" # {format_frame(frame, line=True)}"
            if len(extra) > 1024:
                # The rendering with the source line is too long; fall back
                # to the frame description without it.
                extra = f" # {format_frame(frame)}"
    elif guard.stack:
        summary = guard.stack.summary()
        extra = (
            f" # {format_frame(summary[-1])}" if len(summary) > 0 else " # <unknown>"
        )
    return f"{code_part:<60}{extra}"
|
|
|
|
def get_verbose_code_parts(
    code_parts: Union[str, list[str]],
    guard: Optional[Guard],
    recompile_hint: Optional[str] = None,
) -> list[str]:
    """
    Return the verbose form of one or more code parts.

    A non-list ``code_parts`` is treated as a single part. When
    ``recompile_hint`` is truthy, " (HINT: <hint>)" is appended to every part.
    """
    parts = code_parts if isinstance(code_parts, list) else [code_parts]
    hint_suffix = f" (HINT: {recompile_hint})" if recompile_hint else ""
    return [get_verbose_code_part(part, guard) + hint_suffix for part in parts]
|
|
|
|
| def convert_int_to_concrete_values(dim: Any) -> Optional[int]: |
| if dim is None: |
| return None |
| if not is_symbolic(dim): |
| return dim |
| else: |
| assert isinstance(dim, torch.SymInt) |
| return dim.node.maybe_as_int() |
|
|
|
|
| def convert_to_concrete_values(size_or_stride: list[Any]) -> list[Optional[int]]: |
| return [convert_int_to_concrete_values(dim) for dim in size_or_stride] |
|
|
|
|
def get_tensor_guard_code_part(
    value: torch.Tensor,
    name: str,
    sizes: list[Optional[int]],
    strides: list[Optional[int]],
    pytype: type,
    dispatch_keys: DispatchKeySet,
) -> str:
    """
    Build the ``check_tensor(...)`` description string for a tensor guard.

    The reported dispatch key set is ``dispatch_keys`` plus the thread-local
    include set, minus the thread-local exclude set. Dtype, device index and
    requires_grad are read from ``value``.
    """
    tls_include = torch._C._dispatch_tls_local_include_set()
    tls_exclude = torch._C._dispatch_tls_local_exclude_set()
    dispatch_key = (dispatch_keys | tls_include) - tls_exclude
    return (
        f"check_tensor({name}, {pytype.__qualname__}, {dispatch_key}, {value.dtype}, "
        f"device={value.device.index}, requires_grad={value.requires_grad}, "
        f"size={sizes}, stride={strides})"
    )
|
|
|
|
def get_key_index(dct: dict[Any, Any], key: Any) -> int:
    """
    Return the position of ``key`` among the keys of ``dct``.

    Keys are enumerated with ``builtin_dict_keys``. ``ValueError`` is raised
    (by ``list.index``) when the key is absent.
    """
    ordered_keys = list(builtin_dict_keys(dct))
    return ordered_keys.index(key)
|
|
|
|
def get_key_index_source(source: Any, index: Any) -> str:
    """Return the expression text that selects the ``index``-th key of ``source``."""
    return "list(dict.keys({}))[{}]".format(source, index)
|
|
|
|
def raise_local_type_error(obj: Any) -> NoReturn:
    """
    Raise ``TypeError`` for an object whose class is defined in a local scope.

    Raises:
        TypeError: always; the message names the type and the object.
    """
    message = "".join(
        (
            f"Type {type(obj)} for object {obj} cannot be saved ",
            "into torch.compile() package since it's defined in local scope. ",
            "Please define the class at global scope (top level of a module).",
        )
    )
    raise TypeError(message)
|
|
|
|
def should_optimize_getattr_on_nn_module(value: Any) -> bool:
    """
    Decide whether attribute access on ``value`` may use the nn.Module fast path.

    Only ``torch.nn.Module`` instances qualify, and then only if
    ``config.inline_inbuilt_nn_modules`` is set or the module's custom getattr
    is ``unpatched_nn_module_getattr``.
    """
    if not isinstance(value, torch.nn.Module):
        return False
    return (
        config.inline_inbuilt_nn_modules
        or get_custom_getattr(value) is unpatched_nn_module_getattr
    )
|
|
|
|
@dataclasses.dataclass(frozen=True)
class NNModuleAttrAccessorInfo:
    """Immutable record describing how an nn.Module attribute is accessed."""

    # Whether the attribute can be reached through the module's __dict__.
    present_in_generic_dict: bool = False

    # NOTE(review): presumably the first-level __dict__ key - either the
    # attribute name itself or one of _parameters/_buffers/_modules; confirm
    # against the code that fills this in.
    l1_key: Optional[str] = None

    # NOTE(review): presumably the parameter/buffer/submodule name looked up
    # under l1_key; confirm against the code that fills this in.
    l2_key: Optional[str] = None
|
|
|
|
def getitem_on_dict_manager(
    source: Union[DictGetItemSource, DictSubclassGetItemSource],
    base_guard_manager: DictGuardManager,
    base_example_value: Any,
    example_value: Any,
    guard_manager_enum: GuardManagerType,
) -> GuardManager:
    """
    Return the value manager of a DictGuardManager for a dict-getitem source.

    The key position comes from the ``ConstDictKeySource`` index when the
    source carries one; otherwise it is looked up in ``base_example_value``
    and an equality guard on the key is installed on the key manager.
    """
    base_source_name = source.base.name()
    index_is_const_key = isinstance(source.index, ConstDictKeySource)

    if index_is_const_key:
        index = source.index.index
    else:
        assert isinstance(base_example_value, dict)
        index = get_key_index(base_example_value, source.index)

    key_source = get_key_index_source(base_source_name, index)

    # int/str keys are written literally into the value source; any other key
    # is referenced through its positional key expression.
    key_example_value = list(builtin_dict_keys(base_example_value))[index]
    if isinstance(key_example_value, (int, str)):
        value_source = f"{base_source_name}[{key_example_value!r}]"
    else:
        value_source = f"{base_source_name}[{key_source}]"

    if not index_is_const_key:
        key_manager = base_guard_manager.get_key_manager(
            index=index,
            source=key_source,
            example_value=source.index,
            guard_manager_enum=GuardManagerType.GUARD_MANAGER,
        )
        key_manager.add_equals_match_guard(
            source.index, [f"{key_source} == {key_example_value!r}"]
        )

    return base_guard_manager.get_value_manager(
        index=index,
        source=value_source,
        example_value=example_value,
        guard_manager_enum=guard_manager_enum,
    )
|
|
|
|
def match_on_id_for_tensor(guard: Guard) -> bool:
    """
    Tell whether a tensor guard should match on object identity.

    Numpy tensor sources never do. Specialized nn module guards always do.
    Otherwise the source must be a dict key and must not be a ``GradSource``.
    """
    origin = guard.originating_source
    if not isinstance(origin, NumpyTensorSource):
        if guard.is_specialized_nn_module():
            return True
        return origin.is_dict_key() and not isinstance(origin, GradSource)
    return False
|
|
|
|
| |
| |
@dataclasses.dataclass
class GuardCodeList:
    """A list of guard code strings paired with the guard they belong to."""

    # Code strings for the guard.
    code_list: list[str]
    # The Guard the code strings were produced for.
    guard: Guard
|
|
|
|
class GuardManagerType(enum.Enum):
    """Kind of guard manager to request for a guarded value."""

    # Regular GuardManager.
    GUARD_MANAGER = 1
    # DictGuardManager.
    DICT_GUARD_MANAGER = 2
|
|
|
|
| @functools.cache |
| def code_framelocals_names_reversed_cached(code: types.CodeType) -> list[str]: |
| return list(reversed(code_framelocals_names(code))) |
|
|
|
|
| class GuardBuilder(GuardBuilderBase): |
def __init__(
    self,
    f_code: types.CodeType,
    id_ref: Callable[[object, str], int],
    source_ref: Callable[[Source], str],
    lookup_weakrefs: Callable[[object], Optional[weakref.ref[object]]],
    local_scope: dict[str, object],
    global_scope: dict[str, object],
    guard_manager: GuardManagerWrapper,
    check_fn_manager: CheckFunctionManager,
    save_guards: bool = False,
    runtime_global_scope: Optional[dict[str, object]] = None,
) -> None:
    """
    Set up the builder's scopes, caches and bookkeeping containers.

    ``runtime_global_scope`` falls back to ``global_scope`` when it is falsy.
    ``check_fn_manager.output_graph`` must not be None.
    """
    self.f_code = f_code
    self.id_ref = id_ref
    self.source_ref = source_ref
    self.lookup_weakrefs = lookup_weakrefs

    # Name-resolution scope: "L" for locals, "G" for globals, plus a private
    # copy of the builtins.
    self.scope: dict[str, dict[str, object]] = {"L": local_scope, "G": global_scope}
    self.runtime_global_scope = runtime_global_scope or global_scope
    self.scope["__builtins__"] = builtins.__dict__.copy()

    # Modules imported through torch.package are exposed under a sanitized
    # name, both inside the builtins copy and at the top level of the scope.
    imported_modules = torch.package.package_importer._package_imported_modules
    for raw_name, package_module in imported_modules.items():
        scope_name = (
            raw_name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
        )
        self.scope["__builtins__"][scope_name] = package_module
        self.scope[scope_name] = package_module
    self.guard_manager = guard_manager

    self.argnames: list[str] = []
    self.code: list[GuardCodeList] = []
    self.shape_env_code: list[GuardCodeList] = []

    # Names and managers collected for the no-tensor-aliasing guard.
    self.no_tensor_aliasing_names: list[str] = []
    self.no_tensor_aliasing_guard_managers: list[GuardManager] = []

    self.check_fn_manager: CheckFunctionManager = check_fn_manager

    # ids of the objects named by output_graph.guard_on_key_order sources.
    # The set is created before self.get() is called, as in the loop form.
    self.key_order_guarded_dict_ids = set()
    assert self.check_fn_manager.output_graph is not None
    self.key_order_guarded_dict_ids.update(
        id(self.get(source.name()))
        for source in self.check_fn_manager.output_graph.guard_on_key_order
    )

    self.id_matched_objs: dict[str, ReferenceType[object]] = {}

    self._cached_guard_managers: dict[str, GuardManager] = {}
    self._cached_duplicate_input_guards: set[tuple[str, str]] = set()
    self.object_aliasing_guard_codes: list[tuple[str, str]] = []
    self.save_guards = save_guards
    self.guard_nn_modules = config.guard_nn_modules and justknobs_check(
        "pytorch/compiler:guard_nn_modules"
    )
    self.already_guarded_not_present_in_generic_dict: OrderedSet[
        tuple[str, str]
    ] = OrderedSet()
|
|
def guard_on_dict_keys_and_ignore_order(
    self, example_value: dict[Any, Any], guard: Guard
) -> None:
    """
    Create a getitem child manager for every key of ``example_value``.

    Raises:
        NotImplementedError: if the guard's manager is a DictGuardManager.
    """
    dict_mgr = self.get_guard_manager(guard)
    if isinstance(dict_mgr, DictGuardManager):
        raise NotImplementedError(
            "Not expecting a DictGuardManager. Seems like Dynamo incorrectly "
            f"added the dict to tx.output.guard_on_key_order for {guard.name}"
        )

    origin = guard.originating_source
    dict_source = origin.name()

    for key in builtin_dict_keys(example_value):
        item_value = example_value[key]
        item_source = DictGetItemSource(origin, index=key)
        # NOTE(review): the whole dict (not the per-key value) is passed as the
        # example value here - confirm this is intended.
        item_manager_enum = self.get_guard_manager_type(item_source, example_value)
        dict_mgr.dict_getitem_manager(
            key=key,
            source=f"{dict_source}[{key!r}]",
            example_value=item_value,
            guard_manager_enum=item_manager_enum,
        )
|
|
def guard_on_dict_keys_and_order(self, value: dict[Any, Any], guard: Guard) -> None:
    """
    Install one key guard per position of ``value``.

    Keys for which ``key_is_id`` holds get an id-match guard; every other key
    gets an equality guard.

    Raises:
        NotImplementedError: if the guard's manager is not a DictGuardManager.
    """
    dict_mgr = self.get_guard_manager(guard)
    if not isinstance(dict_mgr, DictGuardManager):
        raise NotImplementedError(
            "Expecting a DictGuardManager. Seems like Dynamo forgot "
            f"to set the right guard manager enum for {guard.name}"
        )
    assert isinstance(dict_mgr, DictGuardManager)

    for position, key in enumerate(builtin_dict_keys(value)):
        key_source = get_key_index_source(guard.name, position)
        key_manager = dict_mgr.get_key_manager(
            index=position,
            source=key_source,
            example_value=key,
            guard_manager_enum=GuardManagerType.GUARD_MANAGER,
        )
        if not key_is_id(key):
            # Compared by value.
            verbose_parts = get_verbose_code_parts(f"{key_source} == {key!r}", guard)
            key_manager.add_equals_match_guard(key, verbose_parts)
            continue

        # Compared by identity.
        id_val = self.id_ref(key, key_source)
        verbose_parts = get_verbose_code_parts(
            f"__check_obj_id({key_source}, {id_val})", guard
        )
        key_manager.add_id_match_guard(id_val, verbose_parts)
|
|
@staticmethod
def _get_generic_dict_manager_example_value(example_value: Any) -> Optional[Any]:
    """Pick the example value handed to a generic ``__dict__`` manager.

    When ``config.issue_3_13_0_warning`` is enabled and the interpreter
    version is at least 3.13 but below 3.13.1, a ``RuntimeWarning`` is issued
    and ``None`` is returned. In every other case ``example_value`` is
    returned unchanged.
    """
    affected_interpreter = (
        config.issue_3_13_0_warning
        and (3, 13) <= sys.version_info < (3, 13, 1)
    )
    if not affected_interpreter:
        return example_value

    warnings.warn(
        "Guards may run slower on Python 3.13.0. Consider upgrading to Python 3.13.1+.",
        RuntimeWarning,
    )
    return None
|
|
def getattr_on_nn_module(
    self,
    source: AttrSource,
    base_guard_manager: GuardManager,
    base_example_value: Any,
    example_value: Any,
    base_source_name: str,
    source_name: str,
    guard_manager_enum: GuardManagerType,
) -> GuardManager:
    """
    This tries to avoid calling the expensive nn module custom getattr method by
    checking if the attribute is accessible via __dict__. For attributes that
    are not accessible via __dict__ (like descriptors), we fallback to
    PyObject_GetAttr.

    There are two cases that we optimize for
    1) attributes present directly in __dict__, e.g training.
    2) parameters/buffers/modules - they can be accessed via _parameters,
    _buffers, _modules keys in __dict__. For example, mod.linear can be
    accessed as mod.__dict__["_parameters"]["linear"]

    The most common and expensive case for nn module guards is of type
    mod.submod1.submod2.submod3.training. We avoid the python getattr of nn
    modules by going through the __dict__.

    Args:
        source: Attribute source being guarded; ``source.member`` is the
            attribute name and ``source.base`` the owning object's source.
        base_guard_manager: Guard manager of the owning object.
        base_example_value: The owning object; its ``__dict__`` and class MRO
            are inspected.
        example_value: The attribute value observed at trace time.
        base_source_name: Source string of the owning object.
        source_name: Source string of the attribute.
        guard_manager_enum: Manager type used on the plain getattr fallback.

    Returns:
        The guard manager for the attribute: a getattr manager on the
        fallback path, otherwise the manager reached through ``__dict__``
        (one or two dict lookups deep).
    """

    def getitem_on_dict_mgr(
        mgr: GuardManager,
        key: Any,
        source_name: str,
        base_example_value: Any,
        example_value: Any,
        guard_manager_enum: GuardManagerType,
    ) -> GuardManager:
        # Helper: obtain the child manager for `key` of the dict guarded by
        # `mgr`, handling both key-order-aware and plain managers.
        if isinstance(mgr, DictGuardManager):
            # Key-order-aware manager: entries are addressed by the key's
            # position inside the dict rather than by the key itself.
            index = get_key_index(base_example_value, key)

            # Install the key manager and pin the key at that position with
            # an equality guard.
            key_source = f"list(dict.keys({source_name}))[{index!r}]"
            mgr.get_key_manager(
                index=index,
                source=key_source,
                example_value=key,
                guard_manager_enum=GuardManagerType.GUARD_MANAGER,
            ).add_equals_match_guard(key, [f"{key_source} == {key!r}"])

            # Install and return the value manager for the same position.
            return mgr.get_value_manager(
                index=index,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        else:
            return mgr.dict_getitem_manager(
                key=key,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )

    attr_name = source.member
    mod_dict = base_example_value.__dict__

    # Names defined on any class in the MRO; consulted below before routing
    # an attribute through `_modules`.
    all_class_attribute_names: set[str] = set()
    for x in inspect.getmro(base_example_value.__class__):
        all_class_attribute_names.update(x.__dict__.keys())

    # Default: the attribute is not reachable through __dict__.
    accessor_info = NNModuleAttrAccessorInfo(False, None, None)

    # Decide how the attribute is reached: directly in __dict__ (one key), or
    # through one of the _parameters/_buffers/_modules containers (two keys).
    if attr_name in mod_dict:
        accessor_info = NNModuleAttrAccessorInfo(True, attr_name, None)
    elif "_parameters" in mod_dict and attr_name in mod_dict["_parameters"]:
        accessor_info = NNModuleAttrAccessorInfo(True, "_parameters", attr_name)
    elif "_buffers" in mod_dict and attr_name in mod_dict["_buffers"]:
        accessor_info = NNModuleAttrAccessorInfo(True, "_buffers", attr_name)
    elif (
        attr_name not in all_class_attribute_names
        and "_modules" in mod_dict
        and attr_name in mod_dict["_modules"]
    ):
        # `_modules` is only used when no class in the MRO defines a
        # same-named attribute.
        accessor_info = NNModuleAttrAccessorInfo(True, "_modules", attr_name)

    if not accessor_info.present_in_generic_dict:
        # Not reachable through __dict__: fall back to a regular getattr
        # manager on the base manager.
        return base_guard_manager.getattr_manager(
            attr=source.member,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    else:
        assert accessor_info.l1_key
        l1_key = accessor_info.l1_key
        l2_key = accessor_info.l2_key

        # Source strings, example values and manager types for the first
        # (l1) and optional second (l2) dict lookup.
        mod_dict_source = f"{base_source_name}.__dict__"
        l1_source_name = l2_source_name = None
        l1_value = l2_value = None
        l1_guard_manager_enum = l2_guard_manager_enum = None
        if l2_key:
            l1_source = AttrSource(source.base, l1_key)
            l1_source_name = l1_source.name()
            l1_value = mod_dict[l1_key]
            # The manager type of the container is decided by
            # get_guard_manager_type, i.e. it only becomes key-order-aware
            # when that container is registered for key-order guarding.
            l1_guard_manager_enum = self.get_guard_manager_type(l1_source, l1_value)

            l2_source_name = source_name
            l2_value = example_value
            l2_guard_manager_enum = self.get_guard_manager_type(
                source, example_value
            )
        else:
            l1_source_name = source_name
            l1_value = example_value
            l1_guard_manager_enum = self.get_guard_manager_type(
                source, example_value
            )

        # Accessor for the module's __dict__; always a plain GUARD_MANAGER,
        # so key order of __dict__ itself is not guarded.
        mod_generic_dict_manager = base_guard_manager.get_generic_dict_manager(
            source=mod_dict_source,
            example_value=self._get_generic_dict_manager_example_value(mod_dict),
            guard_manager_enum=GuardManagerType.GUARD_MANAGER,
        )

        # First lookup: __dict__[l1_key].
        l1_mgr = getitem_on_dict_mgr(
            mgr=mod_generic_dict_manager,
            key=l1_key,
            source_name=l1_source_name,
            base_example_value=mod_dict,
            example_value=l1_value,
            guard_manager_enum=l1_guard_manager_enum,
        )

        # Second lookup (containers only): __dict__[l1_key][l2_key].
        if l2_key:
            assert l2_source_name is not None and l2_guard_manager_enum is not None
            return getitem_on_dict_mgr(
                mgr=l1_mgr,
                key=l2_key,
                source_name=l2_source_name,
                base_example_value=l1_value,
                example_value=l2_value,
                guard_manager_enum=l2_guard_manager_enum,
            )
        return l1_mgr
|
|
def requires_key_order_guarding(self, source: Source) -> bool:
    """Tell whether the object behind ``source`` needs key-order guards.

    A source with an empty name never does. Otherwise the object is resolved
    with ``self.get`` and its ``id`` is looked up in
    ``self.key_order_guarded_dict_ids``.
    """
    name = source.name()
    return name != "" and id(self.get(name)) in self.key_order_guarded_dict_ids
|
|
def get_guard_manager_type(
    self,
    source: Source,
    example_value: Optional[
        Union[KeysView[Any], set[Any], frozenset[Any], dict[Any, Any]]
    ],
) -> GuardManagerType:
    """Choose the guard manager type for ``source``.

    Sources that do not require key-order guarding get ``GUARD_MANAGER``.
    For those that do, ``dict_keys`` and ``dict`` values get
    ``DICT_GUARD_MANAGER`` while ``set``/``frozenset`` values stay on
    ``GUARD_MANAGER``. Any other value type trips the assertion.
    """
    if not self.requires_key_order_guarding(source):
        return GuardManagerType.GUARD_MANAGER

    if isinstance(example_value, dict_keys):
        return GuardManagerType.DICT_GUARD_MANAGER

    if isinstance(example_value, (set, frozenset)):
        # Sets can be flagged for key-order guarding too, but they keep the
        # plain manager.
        return GuardManagerType.GUARD_MANAGER

    assert isinstance(example_value, dict)
    return GuardManagerType.DICT_GUARD_MANAGER
|
|
def manager_guards_on_keys(self, mgr_enum: GuardManagerType) -> bool:
    """Return True when ``mgr_enum`` is ``GuardManagerType.DICT_GUARD_MANAGER``."""
    is_dict_manager = mgr_enum == GuardManagerType.DICT_GUARD_MANAGER
    return is_dict_manager
|
|
def get_global_guard_manager(self) -> GuardManager:
    """Return the manager for the globals dict, obtained from the root manager.

    The manager is requested with source string ``"G"``, the runtime global
    scope as ``f_globals`` and ``self.scope["G"]`` as example value, always
    as a plain ``GUARD_MANAGER``.
    """
    root = self.guard_manager.root
    make_globals_manager = root.globals_dict_manager
    return make_globals_manager(
        f_globals=self.runtime_global_scope,
        source="G",
        example_value=self.scope["G"],
        guard_manager_enum=GuardManagerType.GUARD_MANAGER,
    )
|
|
def get_guard_manager_from_source(self, source: Source) -> GuardManager:
    """Return the guard manager for ``source``, building it on first use.

    Managers are memoized in ``self._cached_guard_managers`` under
    ``source.name()``. For chained sources the base manager is resolved
    recursively first and the child manager is obtained from it through the
    accessor that matches the exact source type.

    Args:
        source: The source whose guard manager is requested.

    Returns:
        The guard manager for ``source``. ``GlobalStateSource`` and
        ``ShapeEnvSource`` map directly to the root manager and are returned
        without being cached.

    Raises:
        AssertionError: If no builder exists for the source type, or a
            ``ConstDictKeySource`` is used on a base manager that is not a
            ``DictGuardManager``.
        RuntimeError: If a dict item source carries a ``ConstDictKeySource``
            index while the base manager is not a ``DictGuardManager``.
    """
    root_guard_manager = self.guard_manager.root

    example_value = None
    source_name = source.name()

    # Fast path: a manager was already built for this (non-empty) name.
    if source_name != "" and source_name in self._cached_guard_managers:
        return self._cached_guard_managers[source_name]

    # An empty name cannot be evaluated; example_value stays None then.
    if source_name != "":
        example_value = self.get(source_name)

    guard_manager_enum = self.get_guard_manager_type(source, example_value)

    # Resolve base-manager information for chained sources.
    base_source_name = None
    base_example_value = None
    base_guard_manager = None
    base_guard_manager_enum = GuardManagerType.GUARD_MANAGER
    if isinstance(source, ChainedSource):
        base_source_name = source.base.name()
        base_example_value = self.get(base_source_name)
        base_guard_manager = self.get_guard_manager_from_source(source.base)
        base_guard_manager_enum = self.get_guard_manager_type(
            source.base, base_example_value
        )

    # Dispatch on the exact source type (istype, not isinstance), except for
    # ConstDictKeySource further below, which uses isinstance.
    if istype(source, LocalSource):
        # Locals are addressed by (name, index) in the frame's locals. The
        # index is computed from a reversed name list, so when a name occurs
        # more than once the LAST occurrence in the original order is used.
        # NOTE(review): presumably duplicates arise when a name is both a
        # local and a cell on some Python versions -- confirm.
        framelocals_names_reversed = code_framelocals_names_reversed_cached(
            self.f_code
        )
        framelocals_idx = (
            len(framelocals_names_reversed)
            - framelocals_names_reversed.index(source.local_name)
            - 1
        )
        out = root_guard_manager.framelocals_manager(
            key=(source.local_name, framelocals_idx),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, GlobalSource):
        # Globals are reached through the globals-dict manager with a plain
        # dict item access (that manager is requested as GUARD_MANAGER, see
        # get_global_guard_manager).
        out = self.get_global_guard_manager().dict_getitem_manager(
            key=source.global_name,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, GlobalWeakRefSource):
        out = self.get_global_guard_manager().global_weakref_manager(
            global_name=source.global_name,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, GlobalStateSource):
        # No dedicated manager: the root manager is returned as-is (and not
        # cached).
        return root_guard_manager
    elif istype(source, ShapeEnvSource):
        return root_guard_manager
    elif istype(source, TypeSource):
        assert base_guard_manager
        out = base_guard_manager.type_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, TypeDictSource):
        assert base_guard_manager
        out = base_guard_manager.type_dict_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, TypeMROSource):
        assert base_guard_manager
        out = base_guard_manager.type_mro_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(
        source,
        (
            OptimizerSource,
            NNModuleSource,
            UnspecializedNNModuleSource,
            UnspecializedBuiltinNNModuleSource,
            FSDPNNModuleSource,
        ),
    ):
        # These wrapper sources reuse the base manager unchanged.
        assert base_guard_manager
        out = base_guard_manager
    elif istype(source, TorchSource):
        out = root_guard_manager.lambda_manager(
            python_lambda=lambda _: torch,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, TorchFunctionModeStackSource):
        out = root_guard_manager.lambda_manager(
            python_lambda=lambda _: get_torch_function_mode_stack_at(
                source._get_index()
            ),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, GradSource):
        assert base_guard_manager
        out = base_guard_manager.grad_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, GenericAttrSource):
        assert base_guard_manager
        out = base_guard_manager.generic_getattr_manager(
            attr=source.member,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, (AttrSource, UnspecializedParamBufferSource)):
        assert base_guard_manager
        assert isinstance(source, AttrSource)
        # Take the __dict__-based route when the helper says the base value
        # qualifies; otherwise use a regular getattr manager.
        if should_optimize_getattr_on_nn_module(base_example_value):
            assert base_source_name
            out = self.getattr_on_nn_module(
                source,
                base_guard_manager,
                base_example_value,
                example_value,
                base_source_name,
                source_name,
                guard_manager_enum,
            )
        else:
            out = base_guard_manager.getattr_manager(
                attr=source.member,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
    elif istype(source, (DictGetItemSource, DictSubclassGetItemSource)):
        assert base_guard_manager
        assert isinstance(base_example_value, (dict, collections.OrderedDict))
        assert isinstance(source, (DictGetItemSource, DictSubclassGetItemSource))
        if isinstance(base_guard_manager, DictGuardManager):
            # Key-order-aware base: delegate to the shared helper.
            assert self.manager_guards_on_keys(base_guard_manager_enum)
            out = getitem_on_dict_manager(
                source,
                base_guard_manager,
                base_example_value,
                example_value,
                guard_manager_enum,
            )
        else:
            # Plain base manager: the index must be a raw key, not a
            # ConstDictKeySource.
            if isinstance(source.index, ConstDictKeySource):
                raise RuntimeError(
                    "Expecting clean index here. Likely Dynamo forgot to mark"
                    " a dict as guard_on_key_order"
                )
            out = base_guard_manager.dict_getitem_manager(
                key=source.index,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
    elif istype(source, TensorPropertySource):
        # The accessor name is derived from the property enum's name, e.g.
        # "tensor_property_<prop>_manager".
        out = getattr(
            base_guard_manager,
            f"tensor_property_{source.prop.name.lower()}_manager",
        )(
            idx=source.idx,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, IndexedSource):
        assert base_guard_manager

        out = base_guard_manager.indexed_manager(
            idx=source.idx,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, ListGetItemSource):
        assert base_guard_manager
        out = base_guard_manager.list_getitem_manager(
            key=source.index,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, GetItemSource):
        assert base_guard_manager
        assert not isinstance(
            base_example_value, (dict, collections.OrderedDict)
        ), "Use DictGetItemSource"
        # Non-slice indexing into list/tuple gets a specialized accessor;
        # everything else (including slices) uses the generic getitem one.
        if isinstance(base_example_value, list) and not source.index_is_slice:
            out = base_guard_manager.list_getitem_manager(
                key=source.index,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif isinstance(base_example_value, tuple) and not source.index_is_slice:
            out = base_guard_manager.tuple_getitem_manager(
                key=source.index,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        else:
            index = source.index
            if source.index_is_slice:
                index = source.unpack_slice()
            out = base_guard_manager.getitem_manager(
                key=index,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
    elif istype(source, DefaultsSource):
        assert base_guard_manager
        assert base_source_name
        assert callable(base_example_value)
        if not source.is_kw:
            # Positional defaults: __defaults__ indexed through getitem.
            out = base_guard_manager.func_defaults_manager(
                source=base_source_name,
                example_value=base_example_value.__defaults__,
                guard_manager_enum=GuardManagerType.GUARD_MANAGER,
            ).getitem_manager(
                key=source.idx_key,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        else:
            # Keyword-only defaults: __kwdefaults__ accessed by key.
            kwdefaults = base_example_value.__kwdefaults__
            assert base_source_name is not None
            kw_source = base_source_name + ".__kwdefaults__"

            # Requested as a plain GUARD_MANAGER, so key order of
            # __kwdefaults__ is not guarded.
            dict_mgr = base_guard_manager.func_kwdefaults_manager(
                source=kw_source,
                example_value=kwdefaults,
                guard_manager_enum=GuardManagerType.GUARD_MANAGER,
            )
            assert not isinstance(dict_mgr, DictGuardManager)

            out = dict_mgr.dict_getitem_manager(
                key=source.idx_key,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
    elif istype(source, NumpyTensorSource):
        assert base_guard_manager
        out = base_guard_manager.lambda_manager(
            python_lambda=from_numpy,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, SubclassAttrListSource):
        assert base_guard_manager
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: x.__tensor_flatten__()[0],
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, FlattenScriptObjectSource):
        assert base_guard_manager
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: x.__obj_flatten__(),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, ScriptObjectQualifiedNameSource):
        assert base_guard_manager
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: x._type().qualified_name(),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, AttrProxySource):
        assert base_guard_manager
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: x.get_base(),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, CallMethodItemSource):
        assert base_guard_manager
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: x.item(),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, FloatTensorSource):
        assert base_guard_manager
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: torch._as_tensor_fullprec(x),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, TupleIteratorGetItemSource):
        assert base_guard_manager
        out = base_guard_manager.tuple_iterator_getitem_manager(
            index=source.index,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif isinstance(source, ConstDictKeySource):
        # Note: isinstance (not istype) here, so subclasses are accepted too.
        if not isinstance(base_guard_manager, DictGuardManager):
            raise AssertionError(
                "ConstDictKeySource can only work on DictGuardManager"
            )
        out = base_guard_manager.get_key_manager(
            index=source.index,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, NonSerializableSetGetItemSource):
        assert base_guard_manager
        out = base_guard_manager.set_getitem_manager(
            index=source.index,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, WeakRefCallSource):
        assert base_guard_manager
        out = base_guard_manager.weakref_call_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, CallFunctionNoArgsSource):
        assert base_guard_manager
        out = base_guard_manager.call_function_no_args_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, DataclassFieldsSource):
        assert base_guard_manager
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: dataclass_fields(x),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, NamedTupleFieldsSource):
        assert base_guard_manager
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: x._fields,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, CodeSource):
        assert base_guard_manager
        out = base_guard_manager.code_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, ClosureSource):
        assert base_guard_manager
        out = base_guard_manager.closure_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    else:
        raise AssertionError(
            f"missing guard manager builder {source} - {source.name()}"
        )

    # Memoize for later lookups. Unlike the cache read at the top, this write
    # is not skipped for an empty source name.
    self._cached_guard_managers[source.name()] = out
    return out
|
|
def get_guard_manager(self, guard: Guard) -> GuardManager:
    """Return the guard manager that belongs to ``guard``'s originating source."""
    origin = guard.originating_source
    return self.get_guard_manager_from_source(origin)
|
|
def add_python_lambda_leaf_guard_to_root(
    self,
    code_parts: list[str],
    verbose_code_parts: list[str],
    closure_vars: Optional[dict[str, object]] = None,
    is_epilogue: bool = True,
) -> None:
    """Compile ``code_parts`` into a Python function and attach it to the root.

    The source text is produced by ``build_guard_function`` and executed with
    ``G`` (taken from ``self.scope``) as its only global. The resulting
    ``___make_guard_fn`` factory is called with the closure variable values
    to obtain the guard function.

    Args:
        code_parts: Guard expressions to compile.
        verbose_code_parts: Descriptions passed along when the guard is
            registered.
        closure_vars: Names/values made available to the guard function;
            defaults to ``_get_closure_vars()``.
        is_epilogue: When true the guard is registered through
            ``add_epilogue_lambda_guard``, otherwise through
            ``add_lambda_guard``.
    """
    if closure_vars is None:
        closure_vars = _get_closure_vars()

    # The factory takes one positional parameter per closure variable.
    factory_params = ", ".join(closure_vars.keys())
    _guard_body, pycode = build_guard_function(code_parts, factory_params)

    exec_locals: dict[str, Any] = {}
    exec_globals = {"G": self.scope["G"]}
    guards_log.debug("Python shape guard function:\n%s", pycode)
    exec(pycode, exec_globals, exec_locals)

    make_guard_fn = exec_locals["___make_guard_fn"]
    guard_fn = make_guard_fn(*closure_vars.values())

    root = self.guard_manager.root
    if not is_epilogue:
        root.add_lambda_guard(guard_fn, verbose_code_parts)
        return
    root.add_epilogue_lambda_guard(guard_fn, verbose_code_parts)
|
|
| |
| |
| |
| |
| |
| |
def get(self, name: str, closure_vars: Optional[dict[str, Any]] = None) -> Any:
    """Evaluate the source expression ``name`` and return the resulting object.

    The expression is evaluated with ``self.scope`` as globals and
    ``closure_vars`` (``_get_closure_vars()`` when omitted) as locals.
    """
    local_ns = _get_closure_vars() if closure_vars is None else closure_vars
    return eval(name, self.scope, local_ns)
|
|
| |
| |
| |
| |
| |
def arg_ref(self, guard: Union[str, Guard]) -> str:
    """Register the base variable of a guard name and return the name.

    ``guard`` may be a ``Guard`` (its ``name`` is used) or the name itself.
    The base, obtained with ``strip_function_call``, is appended to
    ``self.argnames`` when it is not there yet and
    ``torch._C._dynamo.is_valid_var_name`` returns a truthy result; a result
    of ``2`` additionally logs a warning.
    """
    name: str = guard if isinstance(guard, str) else guard.name
    base = strip_function_call(name)
    if base in self.argnames:
        return name

    validity = torch._C._dynamo.is_valid_var_name(base)
    if not validity:
        return name

    if validity == 2:
        log.warning("invalid var name: %s", guard)
    self.argnames.append(base)
    return name
|
|
def _guard_on_attribute(
    self,
    guard: Guard,
    attr_name: str,
    guard_fn: Callable[[GuardBuilderBase, Guard], Any],
) -> None:
    """Create and install a guard on attribute ``attr_name`` of ``guard``'s source.

    ``"__code__"`` is wrapped in a ``CodeSource``; every other attribute in an
    ``AttrSource``. The derived guard reuses the stack information of
    ``guard`` and is created against this builder.
    """
    origin = guard.originating_source
    attr_source = (
        CodeSource(origin)
        if attr_name == "__code__"
        else AttrSource(origin, attr_name)
    )
    # Carry over the stack info of the parent guard.
    derived_guard = Guard(
        attr_source, guard_fn, stack=guard.stack, user_stack=guard.user_stack
    )
    derived_guard.create(self)
|
|
| |
def HASATTR(self, guard: Guard) -> None:
    """Guard on whether an attribute exists on its owner.

    ``NNModuleSource`` wrappers are unwrapped first. ``CodeSource`` needs no
    guard and returns immediately. When the attribute exists at trace time an
    attribute manager is installed on the owner's manager (through
    ``getattr_on_nn_module`` when the owner qualifies); when it does not, a
    no-hasattr guard is added instead.
    """
    source = guard.originating_source
    if isinstance(source, NNModuleSource):
        source = source.base
    if isinstance(source, CodeSource):
        # Nothing to install for a code source.
        return
    assert isinstance(source, AttrSource), f"invalid source {guard.name}"

    owner_source = source.base
    owner_name = owner_source.name()
    attr = source.member

    ref = self.arg_ref(owner_name)
    present = hasattr(self.get(owner_name), attr)
    code = (
        f"hasattr({ref}, {attr!r})" if present else f"not hasattr({ref}, {attr!r})"
    )
    self._set_guard_export_info(
        guard, [code], provided_guarded_object=self.get(owner_name)
    )

    owner_manager = self.get_guard_manager_from_source(owner_source)
    if not present:
        owner_manager.add_no_hasattr_guard(attr, get_verbose_code_parts(code, guard))
        return

    # The attribute exists: requesting its manager is all that is done here,
    # no separate leaf guard is added.
    example_value = self.get(source.name())
    owner_value = self.get(owner_name)
    guard_manager_enum = self.get_guard_manager_type(source, example_value)

    if not should_optimize_getattr_on_nn_module(owner_value):
        owner_manager.getattr_manager(
            attr=attr,
            source=guard.name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
        return

    # Owner qualifies for the __dict__-based lookup.
    self.getattr_on_nn_module(
        source,
        owner_manager,
        owner_value,
        example_value,
        owner_name,
        source.name(),
        guard_manager_enum,
    )
|
|
def NOT_PRESENT_IN_GENERIC_DICT(
    self, guard: Guard, attr: Optional[Any] = None
) -> None:
    """Guard that ``attr`` is absent from the guarded object's ``__dict__``.

    Each ``(ref, attr)`` pair is guarded at most once; pairs already handled
    are tracked in ``self.already_guarded_not_present_in_generic_dict``.
    """
    assert attr is not None
    ref = self.arg_ref(guard)
    val = self.get(guard.name)

    base_manager = self.get_guard_manager(guard)

    seen = self.already_guarded_not_present_in_generic_dict
    dedup_key = (ref, attr)
    if dedup_key in seen:
        return

    generic_dict_manager = base_manager.get_generic_dict_manager(
        source=f"{guard.name}.__dict__",
        example_value=self._get_generic_dict_manager_example_value(val.__dict__),
        guard_manager_enum=GuardManagerType.GUARD_MANAGER,
    )

    code = f"not ___dict_contains({attr!r}, {ref}.__dict__)"
    verbose_parts = get_verbose_code_parts(code, guard)
    # `False` requests the "does not contain" form of the guard.
    generic_dict_manager.add_dict_contains_guard(False, attr, verbose_parts)
    seen.add(dedup_key)
|
|
def TYPE_MATCH(self, guard: Guard) -> None:
    """Guard on the exact type of the guarded value, compared by type id.

    For a ``FakeTensor`` with a truthy ``pytype`` that type is used instead
    of ``type(value)``.
    """
    value = self.get(guard.name)
    t = type(value)
    if isinstance(value, torch._subclasses.FakeTensor) and value.pytype:
        t = value.pytype

    # A qualified name that differs from the plain name means the type is
    # nested inside another scope; such guards are flagged unserializable.
    if t.__qualname__ != t.__name__:
        guard._unserializable = True

    type_id = self.id_ref(t, f"type({guard.name})")
    code = f"___check_type_id({self.arg_ref(guard)}, {type_id})"
    self._set_guard_export_info(guard, [code])

    manager = self.get_guard_manager(guard)
    manager.add_type_match_guard(type_id, get_verbose_code_parts(code, guard))
|
|
def DICT_VERSION(self, guard: Guard) -> None:
    """Guard on the version of the guarded dict as reported by ``dict_version``."""
    ref = self.arg_ref(guard)
    guarded_dict = self.get(guard.name)
    current_version = dict_version(self.get(guard.name))
    code = f"___dict_version({ref}) == {current_version}"
    self._set_guard_export_info(guard, [code])

    # The manager receives the dict itself rather than the version number.
    manager = self.get_guard_manager(guard)
    manager.add_dict_version_guard(
        guarded_dict, get_verbose_code_parts(code, guard)
    )
|
|
def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool) -> None:
    """Guard that the guarded dict contains ``key`` (or not, when ``invert``)."""
    dict_ref = self.arg_ref(guard)

    if invert:
        prefix = "not "
    else:
        prefix = ""
    code = f"{prefix}___dict_contains({key!r}, {dict_ref})"
    self._set_guard_export_info(guard, [code])

    manager = self.get_guard_manager(guard)
    should_contain = not invert
    manager.add_dict_contains_guard(
        should_contain, key, get_verbose_code_parts(code, guard)
    )
|
|
def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool) -> None:
    """Guard that the guarded set contains ``key`` (or not, when ``invert``).

    Args:
        guard: Guard whose source refers to the set.
        key: Item whose membership is guarded.
        invert: When true, the guard requires ``key`` to be absent.
    """
    set_ref = self.arg_ref(guard)
    item = key
    contains = not invert

    # Mirror DICT_CONTAINS: the reported guard text must say "not ..." for an
    # inverted guard, otherwise export info and failure messages describe
    # the opposite of what is checked.
    maybe_not = "not " if invert else ""
    code = f"{maybe_not}set.__contains__({set_ref}, {item!r})"

    self._set_guard_export_info(guard, [code])

    self.get_guard_manager(guard).add_set_contains_guard(
        contains, item, get_verbose_code_parts(code, guard)
    )
|
|
def BOOL_MATCH(self, guard: Guard) -> None:
    """Guard that the value is exactly ``True`` or exactly ``False``.

    The value observed at trace time must be a ``bool``; a true-match or
    false-match guard is installed accordingly.
    """
    ref = self.arg_ref(guard)
    val = self.get(guard.name)
    assert istype(val, bool)
    code = [f"{ref} == {val!r}"]
    self._set_guard_export_info(guard, code)

    manager = self.get_guard_manager(guard)
    if not val:
        manager.add_false_match_guard(get_verbose_code_parts(code, guard))
        return
    manager.add_true_match_guard(get_verbose_code_parts(code, guard))
|
|
def NONE_MATCH(self, guard: Guard) -> None:
    """Guard that the value is ``None``; it must be ``None`` at trace time."""
    ref = self.arg_ref(guard)
    observed = self.get(guard.name)
    assert observed is None
    code = [f"{ref} is None"]
    self._set_guard_export_info(guard, code)

    manager = self.get_guard_manager(guard)
    manager.add_none_match_guard(get_verbose_code_parts(code, guard))
|
|
def ID_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
    """Install an identity guard; forwards directly to ``id_match_unchecked``."""
    outcome = self.id_match_unchecked(guard, recompile_hint)
    return outcome
|
|
def id_match_unchecked(
    self, guard: Guard, recompile_hint: Optional[str] = None
) -> None:
    """Install an identity (``id``) guard on the guarded object.

    A guard on a ``TypeSource`` is rewritten into a ``TYPE_MATCH`` guard on
    the source's base. For local ``torch.nn.Module`` objects the weak id
    found by ``lookup_weakrefs`` is additionally recorded in
    ``self.id_matched_objs`` under the local's name.
    """
    origin = guard.originating_source
    if isinstance(origin, TypeSource):
        # id(type(x)) is what TYPE_MATCH checks, so reuse it.
        type_guard = Guard(origin.base, GuardBuilder.TYPE_MATCH)
        return self.TYPE_MATCH(type_guard)

    ref = self.arg_ref(guard)
    val = self.get(guard.name)
    id_val = self.id_ref(val, guard.name)
    code = f"___check_obj_id({ref}, {id_val})"
    self._set_guard_export_info(guard, [code], provided_func_name="ID_MATCH")

    manager = self.get_guard_manager(guard)
    manager.add_id_match_guard(
        id_val, get_verbose_code_parts(code, guard, recompile_hint)
    )

    # Remember id-matched local nn.Module objects.
    if not isinstance(guard.originating_source, LocalSource):
        return
    if not isinstance(val, torch.nn.Module):
        return
    local_name = guard.originating_source.local_name
    weak_id = self.lookup_weakrefs(val)
    if weak_id is not None:
        self.id_matched_objs[local_name] = weak_id
|
|
def NOT_NONE_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None:
    """Guard that the value is not ``None``.

    The value observed at trace time must be a ``torch.Tensor``. ``value``
    is accepted but not used.
    """
    ref = self.arg_ref(guard)
    observed = self.get(guard.name)
    assert isinstance(observed, torch.Tensor)
    code = f"{ref} is not None"
    self._set_guard_export_info(guard, [code])

    manager = self.get_guard_manager(guard)
    manager.add_not_none_guard(get_verbose_code_parts(code, guard))
|
|
def DISPATCH_KEY_SET_MATCH(self, guard: Guard) -> None:
    """Guard on the guarded ``torch._C.DispatchKeySet`` value."""
    ref = self.arg_ref(guard)
    key_set = self.get(guard.name)
    assert isinstance(key_set, torch._C.DispatchKeySet)
    code_parts = f"{ref}.raw_repr() == {key_set!r}.raw_repr()"

    manager = self.get_guard_manager(guard)
    manager.add_dispatch_key_set_guard(
        key_set, get_verbose_code_parts(code_parts, guard)
    )
|
|
def NAME_MATCH(self, guard: Guard) -> None:
    """Guard on the ``__name__`` attribute of the guarded object by equality."""
    attribute = "__name__"
    self._guard_on_attribute(guard, attribute, GuardBuilder.EQUALS_MATCH)
|
|
def DUAL_LEVEL(self, guard: Guard) -> None:
    """Guard that the forward-AD level equals the one recorded in the graph.

    The expected level is read from the output graph at guard-build time and
    compared with ``torch.autograd.forward_ad._current_level`` by a lambda
    guard on the root manager.
    """
    output_graph = self.check_fn_manager.output_graph
    assert output_graph is not None
    dual_level = output_graph.dual_level
    code = [f"torch.autograd.forward_ad._current_level == {dual_level}"]
    self._set_guard_export_info(guard, code)

    forward_ad = torch.autograd.forward_ad

    def fn(x: Any) -> bool:
        # The argument supplied by the guard manager is ignored.
        return forward_ad._current_level == dual_level

    root = self.guard_manager.root
    root.add_lambda_guard(fn, get_verbose_code_parts(code, guard))
|
|
def FUNCTORCH_STACK_MATCH(self, guard: Guard) -> None:
    """Guard that the functorch state matches the one recorded in the graph.

    The states of the output graph's functorch layers are captured at
    guard-build time and later compared through
    ``torch._functorch.pyfunctorch.compare_functorch_state``.
    """
    output_graph = self.check_fn_manager.output_graph
    assert output_graph is not None

    states = []
    for layer in output_graph.functorch_layers:
        states.append(layer.get_state())

    code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"]
    self._set_guard_export_info(guard, code)

    compare_fn = torch._functorch.pyfunctorch.compare_functorch_state

    def fn(x: Any) -> bool:
        # The argument supplied by the guard manager is ignored.
        return compare_fn(states)

    root = self.guard_manager.root
    root.add_lambda_guard(fn, get_verbose_code_parts(code, guard))
|
|
def AUTOGRAD_SAVED_TENSORS_HOOKS(self, guard: Guard) -> None:
    """Guard that the top saved-tensors hooks are unchanged.

    The hooks are summarised as the tuple of their ``id`` values, or ``None``
    when they are not inlineable. The summary taken at guard-build time must
    equal the one computed when the guard runs.
    """
    aot_utils = torch._functorch._aot_autograd.utils
    get_hooks = aot_utils.top_saved_tensors_hooks
    are_inline_hooks = aot_utils.saved_tensors_hooks_are_inlineable

    def hooks_ids_fn(
        hooks: tuple[Callable[[torch.Tensor], Any], Callable[[Any], torch.Tensor]],
    ) -> Optional[tuple[int, ...]]:
        if not are_inline_hooks(hooks):
            return None

        # Unpacking requires `hooks` to hold exactly two entries.
        pack_hook, unpack_hook = hooks
        return tuple(id(hook) for hook in hooks)

    guard_hooks_ids = hooks_ids_fn(get_hooks())

    code = [
        f"torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == {guard_hooks_ids}"
    ]
    self._set_guard_export_info(guard, code)

    def fn(x: Any) -> bool:
        # The argument supplied by the guard manager is ignored.
        return guard_hooks_ids == hooks_ids_fn(get_hooks())

    root = self.guard_manager.root
    root.add_lambda_guard(fn, get_verbose_code_parts(code, guard))
|
|
def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard) -> None:
    """Guard on the metadata returned by ``__tensor_flatten__`` of a subclass.

    A deep copy of the second element of ``__tensor_flatten__()`` is taken at
    guard-build time. At check time the candidate's metadata is compared
    with it, either through the value's ``__metadata_guard__`` hook (when
    present) or by plain equality.
    """
    value = self.get(guard.name)
    flattened = self.get(guard.name).__tensor_flatten__()
    original_metadata = deepcopy(flattened[1])

    if not hasattr(value, "__metadata_guard__"):

        def metadata_checker(x: Any) -> bool:
            return x.__tensor_flatten__()[1] == original_metadata

    else:
        verify_guard_fn_signature(value)

        def metadata_checker(x: Any) -> bool:
            return value.__metadata_guard__(
                original_metadata, x.__tensor_flatten__()[1]
            )

    global_name = f"___check_metadata_{id(metadata_checker)}_c{CompileContext.current_compile_id()}"
    manager = self.get_guard_manager(guard)
    manager.add_lambda_guard(
        metadata_checker, get_verbose_code_parts(global_name, guard)
    )
|
|
def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
    """Guard that the value compares equal to the one seen at trace time.

    Only a fixed set of types is accepted (see ``ok_types`` below, plus
    classes registered as pytree constants). NaN values, which never compare
    equal to themselves, get a type guard plus an ``isnan`` lambda guard
    instead of an equality guard.

    Args:
        guard: Guard whose source refers to the value.
        recompile_hint: Optional hint appended to every verbose code part of
            the equality guard.

    Raises:
        AssertionError: If the value's type is not supported.
    """
    ref = self.arg_ref(guard)
    val = self.get(guard.name)
    if np:
        np_types: tuple[type[Any], ...] = (
            np.int8,
            np.int16,
            np.int32,
            np.int64,
            np.uint8,
            np.uint16,
            np.uint32,
            np.uint64,
            np.float16,
            np.float32,
            np.float64,
        )
    else:
        np_types = ()

    ok_mutable_types = (list, set)

    ok_types = tuple(
        common_constant_types
        | {
            type,
            tuple,
            frozenset,
            slice,
            range,
            dict_keys,
            torch.Size,
            *np_types,
            *ok_mutable_types,
        }
    )

    if torch.distributed.is_available():
        from torch.distributed.device_mesh import DeviceMesh
        from torch.distributed.tensor.placement_types import (
            _StridedShard,
            Partial,
            Replicate,
            Shard,
        )

        ok_types = ok_types + (
            Shard,
            Replicate,
            Partial,
            DeviceMesh,
            _StridedShard,
        )

    from torch.export.dynamic_shapes import _IntWrapper

    ok_types = ok_types + (_IntWrapper,)

    import torch.utils._pytree as pytree

    assert istype(val, ok_types) or pytree.is_constant_class(type(val)), (
        f"Unexpected type {type(val)}"
    )

    # float("nan") == float("nan") is False, so an equality guard could
    # never pass; guard on the type and on "is NaN" instead.
    if istype(val, float) and math.isnan(val):
        self.TYPE_MATCH(guard)
        code = [f"__math_isnan({ref})"]
        self._set_guard_export_info(guard, code)

        self.get_guard_manager(guard).add_lambda_guard(
            _get_closure_vars()["__math_isnan"],
            get_verbose_code_parts(code, guard),
        )
        return

    # Same special case for complex NaN, which relies on numpy. numpy is
    # optional here (see the `if np:` check above), so only take this path
    # when it is available instead of failing with an AttributeError on
    # `np.isnan`. Without numpy a complex NaN falls through to the plain
    # equality guard below.
    if np and istype(val, complex) and np.isnan(val):
        self.TYPE_MATCH(guard)
        code = [f"__numpy_isnan({ref})"]
        self._set_guard_export_info(guard, code)

        self.get_guard_manager(guard).add_lambda_guard(
            _get_closure_vars()["__numpy_isnan"],
            get_verbose_code_parts(code, guard),
        )
        return

    # Regular case: plain equality against the traced value.
    code = [f"{ref} == {val!r}"]
    if istype(val, ok_mutable_types):
        # Snapshot mutable containers so that later in-place changes to the
        # traced object do not alter the value the guard compares against.
        val = deepcopy(val)

    verbose_code_parts = get_verbose_code_parts(code, guard)
    if recompile_hint:
        verbose_code_parts = [
            f"{part} (HINT: {recompile_hint})" for part in verbose_code_parts
        ]

    self.get_guard_manager(guard).add_equals_match_guard(val, verbose_code_parts)
    self._set_guard_export_info(guard, code)
    return
|
|
def CONSTANT_MATCH(self, guard: Guard) -> None:
    """Dispatch a constant guard to the builder that fits the value's type.

    Exact ``bool`` goes to BOOL_MATCH, ``None`` to NONE_MATCH, code objects to
    ID_MATCH, and everything else to EQUALS_MATCH.
    """
    val = self.get(guard.name)
    if istype(val, bool):
        builder = self.BOOL_MATCH
    elif val is None:
        builder = self.NONE_MATCH
    elif istype(val, types.CodeType):
        builder = self.ID_MATCH
    else:
        builder = self.EQUALS_MATCH
    builder(guard)
|
|
def NN_MODULE(self, guard: Guard) -> None:
    """Guard an nn.Module by identity and, when applicable, by ``training``.

    An ID_MATCH is always installed first. If the module has no ``training``
    attribute, ``exc.unimplemented_v2`` is called. Otherwise a CONSTANT_MATCH
    on ``training`` is added unless ``self.guard_nn_modules`` is set.
    """
    self.ID_MATCH(guard, "[inline-inbuilt-nn-modules-candidate]")
    val = self.get(guard.name)

    if not hasattr(val, "training"):
        exc.unimplemented_v2(
            gb_type="Attempted to guard on uninitialized nn.Module",
            context="",
            explanation="Attempted to setup an NN_MODULE guard on uninitialized "
            f"nn.Module subclass `{type(val)}`.",
            hints=[
                "Ensure the `nn.Module` subclass instance has called `super().__init__()`.",
            ],
        )
        return

    assert istype(val.training, bool)
    if not self.guard_nn_modules:
        # Only guard on `training` here when guard_nn_modules is off.
        self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH)
|
|
def FUNCTION_MATCH(self, guard: Guard) -> None:
    """Guard a function object (e.g. ``torch.add`` or a user function) by identity."""
    # Function guards are delegated unchanged to ID_MATCH.
    result = self.ID_MATCH(guard)
    return result
|
|
def CLOSURE_MATCH(self, guard: Guard) -> None:
    """Guard a closure through its ``__code__`` object.

    For an exact ``types.FunctionType`` value, installs HASATTR and
    FUNCTION_MATCH guards on the ``__code__`` attribute. Any other value is
    guarded with FUNCTION_MATCH on the object itself.
    """
    val = self.get(guard.name)
    is_plain_function = type(val) == types.FunctionType and hasattr(val, "__code__")
    if not is_plain_function:
        self.FUNCTION_MATCH(guard)
        return

    for attr_guard in (GuardBuilder.HASATTR, GuardBuilder.FUNCTION_MATCH):
        self._guard_on_attribute(guard, "__code__", attr_guard)
|
|
def BUILTIN_MATCH(self, guard: Guard) -> None:
    """Guard a builtin by identity.

    When ``self.save_guards`` is set, the index of a ``DictGetItemSource``
    origin is recorded in ``check_fn_manager.used_builtin_vars`` and the guard
    is installed through ``id_match_unchecked``. Otherwise this is a plain
    ID_MATCH.
    """
    if not self.save_guards:
        return self.ID_MATCH(guard)

    # Remember which builtin variables the saved guards depend on.
    if isinstance(guard.originating_source, DictGetItemSource):
        used = self.check_fn_manager.used_builtin_vars
        used.add(guard.originating_source.index)
    return self.id_match_unchecked(guard)
|
|
def SEQUENCE_LENGTH(self, guard: Guard) -> None:
    """Guard on ``len()`` of the guarded container.

    Non-dict values also get a TYPE_MATCH. Dicts use the dict-specific length
    guard; everything else uses the generic length guard.
    """
    ref = self.arg_ref(guard)
    value = self.get(guard.name)
    is_dict = isinstance(value, dict)

    if not is_dict:
        self.TYPE_MATCH(guard)

    length = len(value)
    # An empty container is expressed as a truthiness test in the exported code.
    code = [f"not {ref}" if length == 0 else f"len({ref}) == {length}"]

    self._set_guard_export_info(guard, code)
    manager = self.get_guard_manager(guard)
    verbose_parts = get_verbose_code_parts(code, guard)
    if is_dict:
        manager.add_dict_length_check_guard(length, verbose_parts)
    else:
        manager.add_length_check_guard(length, verbose_parts)
|
|
def TUPLE_ITERATOR_LEN(self, guard: Guard) -> None:
    """Guard on the length of a tuple iterator and on the iterator's type.

    The length comes from ``tuple_iterator_len`` and the type is referenced by
    id via ``self.id_ref``.
    """
    ref = self.arg_ref(guard)
    value = self.get(guard.name)

    # Compute once; used in both the exported code and the installed guard.
    length = tuple_iterator_len(value)

    code = [f"___tuple_iterator_len({ref}) == {length}"]
    self._set_guard_export_info(guard, code)

    obj_id = self.id_ref(type(value), f"type({guard.name})")

    self.get_guard_manager(guard).add_tuple_iterator_length_guard(
        length, obj_id, get_verbose_code_parts(code, guard)
    )
|
|
def RANGE_ITERATOR_MATCH(self, guard: Guard) -> None:
    """Guard on a range iterator's normalized (start, stop, step) and its type.

    ``normalize_range_iter`` supplies the triple; the iterator type is
    referenced by id via ``self.id_ref``.
    """
    ref = self.arg_ref(guard)
    value = self.get(guard.name)

    normalized_range_iter = normalize_range_iter(value)
    code = [f"___normalize_range_iter({ref}) == {normalized_range_iter}"]
    self._set_guard_export_info(guard, code)

    obj_id = self.id_ref(type(value), f"type({guard.name})")

    start, stop, step = normalized_range_iter
    self.get_guard_manager(guard).add_range_iterator_match_guard(
        start, stop, step, obj_id, get_verbose_code_parts(code, guard)
    )
|
|
| |
def DUPLICATE_INPUT(self, guard: Guard, source_b: Source) -> None:
    """Guard that two sources refer to the very same object (``b is a``).

    When ``self.save_guards`` is set, the local/global name behind
    ``source_b`` is recorded on the check-function manager first. Nothing is
    installed when either side comes from an optimizer source or when the
    pair, in either order, was already guarded.
    """
    if self.save_guards:
        local_name = get_local_source_name(source_b)
        if local_name:
            self.check_fn_manager.additional_used_local_vars.add(local_name)
        global_name = get_global_source_name(source_b)
        if global_name:
            self.check_fn_manager.additional_used_global_vars.add(global_name)

    ref_a = self.arg_ref(guard)
    ref_b = self.arg_ref(source_b.name())

    if is_from_optimizer_source(
        guard.originating_source
    ) or is_from_optimizer_source(source_b):
        return

    # Skip pairs that already have an aliasing guard.
    seen_pairs = self._cached_duplicate_input_guards
    if (ref_a, ref_b) in seen_pairs:
        return

    # Register both orders so the mirrored request is also recognised.
    seen_pairs.add((ref_a, ref_b))
    seen_pairs.add((ref_b, ref_a))

    code = [f"{ref_b} is {ref_a}"]
    self._set_guard_export_info(guard, code)

    if not config.use_lamba_guard_for_object_aliasing:
        install_object_aliasing_guard(
            self.get_guard_manager(guard),
            self.get_guard_manager_from_source(source_b),
            get_verbose_code_parts(code, guard),
        )
        return

    # Lambda-guard mode: only collect the code here; this method does not
    # install anything itself in this branch.
    aliasing_code = code[0]
    verbose_aliasing_code = get_verbose_code_parts(aliasing_code, guard)[0]
    self.object_aliasing_guard_codes.append((aliasing_code, verbose_aliasing_code))
|
|
def WEAKREF_ALIVE(self, guard: Guard) -> None:
    """Guard that the guarded reference is not ``None``."""
    ref = self.arg_ref(guard)
    code = [f"{ref} is not None"]

    self._set_guard_export_info(guard, code)
    manager = self.get_guard_manager(guard)
    manager.add_not_none_guard(get_verbose_code_parts(code, guard))
|
|
def MAPPING_KEYS_CHECK(self, guard: Guard) -> None:
    """Guard on the key order of types.MappingProxyType object"""
    ref = self.arg_ref(guard)
    value = self.get(guard.name)

    code = [f"list({ref}.keys()) == {list(value.keys())}"]
    self._set_guard_export_info(guard, code)
    # Pass verbose code parts, as every other guard builder here does, so the
    # failure report for this guard goes through the same formatting helper
    # instead of carrying the bare expression.
    self.get_guard_manager(guard).add_mapping_keys_guard(
        value, get_verbose_code_parts(code, guard)
    )
|
|
def DICT_KEYS_MATCH(self, guard: Guard) -> None:
    """Guard that a dict still has the same keys.

    ``torch.utils._pytree.SUPPORTED_NODES`` is handled with a DICT_VERSION
    guard instead. Every other dict gets a SEQUENCE_LENGTH guard plus a key
    guard that is order-sensitive only when
    ``requires_key_order_guarding`` says so for the guard's source.
    """
    ref = self.arg_ref(guard)
    value = self.get(guard.name)

    if value is torch.utils._pytree.SUPPORTED_NODES:
        self.DICT_VERSION(guard)
        return

    self.SEQUENCE_LENGTH(guard)

    # The exported code calls dict.keys on the object explicitly.
    expected_keys = list(builtin_dict_keys(value))
    code = [f"list(dict.keys({ref})) == {expected_keys!r}"]
    self._set_guard_export_info(guard, code)

    order_matters = self.requires_key_order_guarding(guard.originating_source)
    if order_matters:
        self.guard_on_dict_keys_and_order(value, guard)
    else:
        self.guard_on_dict_keys_and_ignore_order(value, guard)
|
|
def EMPTY_NN_MODULE_HOOKS_DICT(self, guard: Guard) -> None:
    """Length-guard an nn.Module hooks dict unless hook guards are disabled.

    With ``config.skip_nnmodule_hook_guards`` set, no guard is installed.
    """
    if not config.skip_nnmodule_hook_guards:
        self.SEQUENCE_LENGTH(guard)
|
|
def GRAD_MODE(self, guard: Guard) -> None:
    """No-op: this builder installs nothing for GRAD_MODE."""
    return None
|
|
def DETERMINISTIC_ALGORITHMS(self, guard: Guard) -> None:
    """No-op: this builder installs nothing for DETERMINISTIC_ALGORITHMS."""
    return None
|
|
def TORCH_FUNCTION_STATE(self, guard: Guard) -> None:
    """No-op: this builder installs nothing for TORCH_FUNCTION_STATE."""
    return None
|
|
def FSDP_TRAINING_STATE(self, guard: Guard) -> None:
    """No-op: this builder installs nothing for FSDP_TRAINING_STATE."""
    return None
|
|
def DEFAULT_DEVICE(self, guard: Guard) -> None:
    """Guard on CURRENT_DEVICE per torch.utils._device.

    The expected device is the output graph's ``current_device``. The guard's
    source must be ``GuardSource.GLOBAL``.
    """
    assert guard.source is GuardSource.GLOBAL

    output_graph = self.check_fn_manager.output_graph
    assert output_graph is not None
    expected_device = output_graph.current_device
    code = [f"utils_device.CURRENT_DEVICE == {expected_device!r}"]
    self._set_guard_export_info(guard, code)

    manager = self.get_guard_manager(guard)
    manager.add_default_device_guard(get_verbose_code_parts(code, guard))
|
|
def SHAPE_ENV(self, guard: Guard) -> None:
    """Install the symbolic-shape guards for the frame.

    The guard expressions are taken from
    ``check_fn_manager.shape_code_parts`` when present, otherwise produced by
    ``shape_env.produce_guards_verbose``. With ``self.save_guards`` set they
    are stored back on the check-function manager. Installation first tries a
    compiled C++ guard function and otherwise falls back to one Python lambda
    guard on the root.
    """
    from torch._dynamo.output_graph import OutputGraph

    assert guard.name == ""
    output_graph = self.check_fn_manager.output_graph
    assert output_graph is not None

    saved_parts = self.check_fn_manager.shape_code_parts
    if saved_parts is not None:
        # Reuse previously generated code parts instead of regenerating them.
        python_code_parts = saved_parts.python_code_parts
        verbose_code_parts = saved_parts.verbose_code_parts
        if saved_parts.cpp_code_parts is not None:
            cpp_code_parts = saved_parts.cpp_code_parts
        python_fallback = saved_parts.python_fallback
    else:
        # Generate the guards from the ShapeEnv and the tracked fake tensors.
        assert isinstance(output_graph, OutputGraph)
        fs = output_graph.tracked_fakes
        input_contexts = [a.symbolic_context for a in fs]

        def get_sources(t_id: int, dim: int) -> list[Source]:
            # One size-property source per source tracked for this tensor id.
            return [
                TensorPropertySource(source, TensorProperty.SIZE, dim)
                for source in output_graph.tracked_fakes_id_to_source[t_id]
            ]

        assert output_graph.shape_env is not None
        if not output_graph.export_constraints:
            equalities_inputs = None
        else:
            names: dict[str, tuple[int, int]] = {}
            source_pairs: list[tuple[Source, Source]] = []
            derived_equalities: list[
                tuple[Source, Union[Source, Symbol], Callable]
            ] = []
            phantom_symbols: dict[str, Symbol] = {}
            relaxed_sources: set[Source] = set()
            for constraint in output_graph.export_constraints:
                if constraint.t_id in output_graph.tracked_fakes_id_to_source:
                    torch.export.dynamic_shapes._process_equalities(
                        constraint,
                        get_sources,
                        output_graph.shape_env,
                        names,
                        source_pairs,
                        derived_equalities,
                        phantom_symbols,
                        relaxed_sources,
                    )
                else:
                    log.warning("Untracked tensor used in export constraints")
            equalities_inputs = EqualityConstraint(
                source_pairs=source_pairs,
                derived_equalities=derived_equalities,
                phantom_symbols=list(phantom_symbols.values()),
                relaxed_sources=relaxed_sources,
                warn_only=False,
            )

        def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]:
            return output_graph.shape_env.produce_guards_verbose(
                [a.fake for a in fs],
                [a.source for a in fs],
                input_contexts=input_contexts,
                equalities_inputs=equalities_inputs,
                source_ref=self.source_ref,
                # Static guards are ignored unless exporting.
                ignore_static=(not output_graph.export),
                langs=langs,
            )

        python_only_langs = ("python", "verbose_python")
        if config.enable_cpp_symbolic_shape_guards:
            try:
                python_code_parts, verbose_code_parts, cpp_code_parts = (
                    _get_code_parts(("python", "verbose_python", "cpp"))
                )
                python_fallback = False
            except OverflowError:
                # C++ generation overflowed: regenerate the Python parts only.
                python_fallback = True
                python_code_parts, verbose_code_parts = _get_code_parts(
                    python_only_langs
                )
        else:
            python_fallback = True
            python_code_parts, verbose_code_parts = _get_code_parts(
                python_only_langs
            )

        # The ShapeEnv is frozen here except when exporting.
        if not output_graph.export:
            output_graph.shape_env.freeze()

    if self.save_guards:
        # `cpp_code_parts` may be unbound on some paths above, so look it up
        # through locals() rather than referencing it directly.
        maybe_cpp_code_parts = locals().get("cpp_code_parts")
        assert maybe_cpp_code_parts is None or isinstance(
            maybe_cpp_code_parts, _CppShapeGuardsHelper
        )
        if maybe_cpp_code_parts is None:
            maybe_shape_env_sources = []
        else:
            maybe_shape_env_sources = list(
                maybe_cpp_code_parts.source_to_symbol.keys()
            )
        self.check_fn_manager.shape_code_parts = ShapeCodeParts(
            python_code_parts=python_code_parts,
            verbose_code_parts=verbose_code_parts,
            cpp_code_parts=maybe_cpp_code_parts,
            python_fallback=python_fallback,
            shape_env_sources=maybe_shape_env_sources,
        )

    for code in python_code_parts.exprs:
        self._set_guard_export_info(guard, [code])

    # Publish the verbose guard expressions on the compile context, if any.
    if compile_context := CompileContext.try_get():
        compile_context.shape_env_guards.extend(verbose_code_parts.exprs)

    int_source_to_symbol = []
    float_source_to_symbol = []

    if not python_fallback:
        assert cpp_code_parts
        code_parts = cpp_code_parts.exprs
        source_to_symbol = cpp_code_parts.source_to_symbol

        if not code_parts:
            return

        # Split sources into int- and float-valued; anything else forces the
        # Python fallback.
        for source, symbol in source_to_symbol.items():
            if isinstance(source, ConstantSource):
                python_fallback = True
                continue
            example_value = self.get(
                source.name(),
                closure_vars={**SYMPY_INTERP, **_get_closure_vars()},
            )
            if isinstance(example_value, int):
                int_source_to_symbol.append((source, symbol))
            elif isinstance(example_value, float):
                float_source_to_symbol.append((source, symbol))
            else:
                python_fallback = True

    if not python_fallback:
        import ctypes

        from torch._inductor.codecache import CppCodeCache

        assert cpp_code_parts
        code_parts = cpp_code_parts.exprs

        # Ints first, then floats: this fixes the index of each source.
        source_to_symbol = dict(int_source_to_symbol + float_source_to_symbol)
        try:
            guard_managers = [
                self.get_guard_manager_from_source(IndexedSource(source, i))
                for i, source in enumerate(source_to_symbol)
            ]

            int_symbols_str = ", ".join(
                f"{symbol} = int_values[{i}]"
                for i, (_, symbol) in enumerate(int_source_to_symbol)
            )
            float_symbols_str = ", ".join(
                f"{symbol} = float_values[{i}]"
                for i, (_, symbol) in enumerate(float_source_to_symbol)
            )

            if int_symbols_str:
                int_symbols_str = f"int64_t {int_symbols_str};"
            if float_symbols_str:
                float_symbols_str = f"double {float_symbols_str};"

            func_str = textwrap.dedent(
                f"""
            #include <algorithm>
            #include <cstdint>
            #include <cmath>
            #include <c10/util/generic_math.h>

            #if defined(_MSC_VER)
            # define EXTERN_DLL_EXPORT extern "C" __declspec(dllexport)
            #else
            # define EXTERN_DLL_EXPORT extern "C"
            #endif

            EXTERN_DLL_EXPORT int8_t guard(int64_t *int_values, double *float_values) {{
              {int_symbols_str}
              {float_symbols_str}
              return ({") && (".join(code_parts)});
            }}
            """
            )
            guards_log.debug(
                "C++ shape guard function: %s %s",
                func_str,
                verbose_code_parts.exprs,
            )
            clib = CppCodeCache.load(func_str)
            cguard = ctypes.cast(clib.guard, ctypes.c_void_p).value
            assert cguard
        except torch._inductor.exc.InvalidCxxCompiler:
            # Fall through to the Python lambda guard below.
            pass
        else:
            install_symbolic_shape_guard(
                guard_managers,
                len(int_source_to_symbol),
                len(float_source_to_symbol),
                cguard,
                clib,
                verbose_code_parts.exprs,
            )
            return

    # Python fallback: all shape expressions in one lambda guard on the root.
    if python_code_parts.exprs:
        self.add_python_lambda_leaf_guard_to_root(
            python_code_parts.exprs,
            verbose_code_parts.exprs,
            closure_vars={**SYMPY_INTERP, **_get_closure_vars()},
        )
|
|
def TENSOR_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None:
    """Install the guards for a tensor input.

    Does nothing for FSDP modules when
    ``config._unsafe_skip_fsdp_module_guards`` is set, and uses a plain
    ID_MATCH when ``match_on_id_for_tensor`` says so. Otherwise:

    * in export mode, a TYPE_MATCH is installed and code for dtype, device,
      requires_grad and ndimension() is recorded;
    * in non-export mode, a tensor-match guard with the recorded size and
      stride is installed;
    * for tensors that are not always static, a guard on
      ``_dynamo_dynamic_indices`` is added.

    Args:
        guard: The guard being built.
        value: Optional tensor (or ``TensorWeakRef``) to guard; when ``None``
            the value is looked up from ``guard.name``.
    """
    if config._unsafe_skip_fsdp_module_guards and guard.is_fsdp_module():
        return

    if match_on_id_for_tensor(guard):
        self.ID_MATCH(guard)
        return

    if isinstance(value, TensorWeakRef):
        value = value()

    if value is None:
        value = self.get(guard.name)

    pytype = type(value)
    dispatch_keys = torch._C._dispatch_keys(value)
    if isinstance(value, torch._subclasses.FakeTensor):
        # A fake tensor's own pytype/dispatch_keys take precedence when set.
        if value.pytype is not None:
            pytype = value.pytype
        if value.dispatch_keys is not None:
            dispatch_keys = value.dispatch_keys

    assert isinstance(value, torch.Tensor)

    if config.log_compilation_metrics and isinstance(value, torch.nn.Parameter):
        metrics_context = get_metrics_context()
        metrics_context.increment("param_numel", value.numel())
        metrics_context.increment("param_bytes", value.nbytes)
        metrics_context.increment("param_count", 1)

    tensor_name = self.arg_ref(guard)
    code: list[str] = []
    assert self.check_fn_manager.output_graph is not None
    if self.check_fn_manager.output_graph.export:
        # Export: a type guard is installed and the property checks are
        # recorded as code.
        self.TYPE_MATCH(guard)
        for term in ("dtype", "device", "requires_grad", "ndimension()"):
            real_value = self.get(tensor_name + "." + term)
            if istype(real_value, (torch.device, torch.dtype)):
                # Devices and dtypes are compared through their str() form.
                code.append(f"str({tensor_name}.{term}) == {str(real_value)!r}")
            else:
                code.append(f"{tensor_name}.{term} == {real_value}")
    else:
        guard_manager = self.get_guard_manager(guard)

        skip_aliasing_guard = (
            config.skip_no_tensor_aliasing_guards_on_parameters
            and (
                istype(value, torch.nn.Parameter)
                or is_from_unspecialized_builtin_nn_module_source(
                    guard.originating_source
                )
            )
        )
        if not skip_aliasing_guard and not isinstance(
            guard.originating_source, NumpyTensorSource
        ):
            # Recorded for a no-tensor-aliasing check; it is not installed
            # in this method.
            self.no_tensor_aliasing_names.append(tensor_name)
            self.no_tensor_aliasing_guard_managers.append(guard_manager)

        output_graph = self.check_fn_manager.output_graph
        metadata = output_graph.input_source_to_sizes_strides[
            guard.originating_source
        ]
        size = convert_to_concrete_values(metadata["size"])
        stride = convert_to_concrete_values(metadata["stride"])

        verbose_code_parts = get_verbose_code_parts(
            get_tensor_guard_code_part(
                value,
                tensor_name,
                size,
                stride,
                pytype,
                dispatch_keys,
            ),
            guard,
        )
        guard_manager.add_tensor_match_guard(
            value,
            size,
            stride,
            tensor_name,
            verbose_code_parts,
            pytype,
            dispatch_keys,
        )

        # Non-parameter tensors are added to the diff guard sources.
        if not isinstance(value, torch.nn.Parameter):
            self.guard_manager.diff_guard_sources.add(guard.name)

    assert guard.source is not None
    static, _reason = tensor_always_has_static_shape(
        value, is_tensor=True, tensor_source=guard.originating_source
    )

    if not static:
        if hasattr(value, "_dynamo_dynamic_indices"):
            # The runtime indices must be a subset of those seen at compile time.
            dynamic_indices = value._dynamo_dynamic_indices
            code_part = f"(({tensor_name}._dynamo_dynamic_indices.issubset({dynamic_indices})) if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)"
            code.append(code_part)
            self.get_guard_manager(guard).add_dynamic_indices_guard(
                dynamic_indices, get_verbose_code_parts(code_part, guard)
            )
        else:
            # Compiled without dynamic indices: the attribute must stay absent.
            code_part = (
                f"hasattr({tensor_name}, '_dynamo_dynamic_indices') == False"
            )
            code.append(code_part)
            self.get_guard_manager(guard).add_no_hasattr_guard(
                "_dynamo_dynamic_indices",
                get_verbose_code_parts(code_part, guard),
            )
    if len(code) > 0:
        self._set_guard_export_info(guard, code)
|
|
| |
def _set_guard_export_info(
    self,
    guard: Guard,
    code_list: list[str],
    provided_guarded_object: Optional[Any] = None,
    provided_func_name: Optional[str] = None,
) -> None:
    """Record export information (builder name, object type, code) on ``guard``.

    Args:
        guard: Guard that receives the info through ``guard.set_export_info``.
        code_list: Code strings describing the guard.
        provided_guarded_object: Object to describe; when ``None`` the object
            is looked up from ``guard.name`` (or left ``None`` if the name is
            empty).
        provided_func_name: Name of the guard-builder method; when omitted the
            name of the calling function is used, so this must be called
            directly from that method.

    Raises:
        AssertionError: If the resolved function name is not defined on
            ``self``'s class, or if no caller frame is available.
    """
    cur_frame = currentframe()
    assert cur_frame is not None
    caller = cur_frame.f_back
    del cur_frame
    assert caller is not None
    func_name = provided_func_name or caller.f_code.co_name
    # Drop the frame reference promptly so it is not kept alive longer than needed.
    del caller

    # The recorded name must be a method defined directly on this class.
    assert func_name in self.__class__.__dict__, (
        f"_set_guard_export_info must be called from a method of "
        f"{self.__class__.__name__}. Called from {func_name}"
    )

    if provided_guarded_object is None:
        name = guard.name
        guarded_object = None if not name else self.get(name)
    else:
        guarded_object = provided_guarded_object

    guarded_object_type = (
        weakref.ref(type(guarded_object)) if guarded_object is not None else None
    )
    obj_ref = None
    # A non-zero __weakrefoffset__ on the class means instances can be
    # weakly referenced.
    supports_weakref = (
        getattr(guarded_object.__class__, "__weakrefoffset__", 0) != 0
    )
    # Enums, tuples and weakref proxies are never weak-referenced here.
    if supports_weakref and not isinstance(
        guarded_object, (enum.Enum, tuple, weakref.ProxyTypes)
    ):
        obj_ref = weakref.ref(guarded_object)

    guard.set_export_info(
        func_name,
        guarded_object_type,
        code_list,
        obj_ref,
    )
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
class PyExprCSEPass:
    """Common-subexpression elimination over Python expression strings.

    Usage is two-phase: ``count`` tallies how often each attribute access,
    call or subscript occurs across a list of expressions; ``replace`` then
    rewrites one expression, binding every sub-expression counted more than
    ``USE_THRESHOLD`` times to a generated variable.
    """

    # An expression is replaced only when its count exceeds this value.
    USE_THRESHOLD = 1

    # AST node kinds this pass counts and replaces.
    ALLOWED_NODE_TYPES = (ast.Attribute, ast.Call, ast.Subscript)

    @dataclasses.dataclass
    class Config:
        # Unparsed expression text -> number of occurrences seen by `count`.
        expr_count: dict[str, int]
        # Unparsed expression text -> variable name it was bound to.
        expr_to_name: dict[str, str]

    class ExprCounter(ast.NodeVisitor):
        """Visitor that tallies occurrences of the allowed node types."""

        def __init__(self, config: PyExprCSEPass.Config) -> None:
            self._config = config

        def visit(self, node: ast.AST) -> None:
            if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES):
                key = _ast_unparse(node)
                self._config.expr_count[key] += 1
            super().visit(node)

    class Replacer(ast.NodeTransformer):
        """Transformer that swaps repeated expressions for variable names."""

        def __init__(
            self,
            config: PyExprCSEPass.Config,
            gen_name: Callable[[], str],
        ) -> None:
            super().__init__()
            self._config = config
            self._gen_name = gen_name
            # Assignment statements that must run before the rewritten expression.
            self.preface: list[str] = []

        def visit(self, node: ast.AST) -> Any:
            if not isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES):
                return super().visit(node)

            expr = _ast_unparse(node)
            if self._config.expr_count[expr] <= PyExprCSEPass.USE_THRESHOLD:
                # Not used often enough: keep it, but still rewrite children.
                return super().visit(node)

            known_names = self._config.expr_to_name
            if expr in known_names:
                var_name = known_names[expr]
            else:
                # Rewrite the children first so inner repeated expressions are
                # bound (and named) before this one.
                rewritten = super().visit(node)
                rewritten_src = _ast_unparse(rewritten)
                var_name = self._gen_name()
                self.preface.append(f"{var_name} = {rewritten_src}")
                known_names[expr] = var_name
            return ast.Name(var_name, ast.Load())

    def __init__(self) -> None:
        self._counter = 0
        self._config = self.Config(
            expr_count=collections.defaultdict(lambda: 0), expr_to_name={}
        )

    def _new_var(self, prefix: str = "_var") -> str:
        """Return ``prefix`` followed by the next counter value."""
        index = self._counter
        self._counter = index + 1
        return f"{prefix}{index}"

    def count(self, exprs: list[str]) -> None:
        """Tally sub-expression occurrences in ``exprs``; re-raises SyntaxError."""
        counter = self.ExprCounter(self._config)
        for source in exprs:
            try:
                counter.visit(ast.parse(source))
            except SyntaxError as ex:
                log.exception(
                    "Failed to visit expr at line %s.\n%s", ex.lineno, source
                )
                raise

    def replace(self, expr: str) -> tuple[list[str], str]:
        """Rewrite ``expr``; return (new assignment statements, rewritten expr)."""
        replacer = self.Replacer(self._config, self._new_var)
        rewritten = replacer.visit(ast.parse(expr))
        return replacer.preface, _ast_unparse(rewritten)
|
|
|
|
def must_add_nn_module_guards(guard: Guard) -> bool:
    """Tell whether ``guard`` must be installed even for nn.Module sources.

    True when the guard originates from a ``DefaultsSource``; otherwise the
    result of ``config.guard_nn_modules_using_dict_tags`` combined with the
    guard having been created by ``GuardBuilder.NN_MODULE``.
    """
    if isinstance(guard.originating_source, DefaultsSource):
        return True
    return (
        config.guard_nn_modules_using_dict_tags
        and guard.create_fn is GuardBuilder.NN_MODULE
    )
|
|
|
|
class DeletedGuardManagerWrapper(GuardManagerWrapper):
    """Guard manager wrapper that carries the reason it was invalidated."""

    def __init__(self, reason: str) -> None:
        super().__init__()
        # Text describing why this guard manager was invalidated.
        self.invalidation_reason = reason

    def populate_diff_guard_manager(self) -> None:
        """Clear the diff guard root instead of building one."""
        self.diff_guard_root = None
|
|
|
|
@dataclasses.dataclass
class ShapeCodeParts:
    """Generated shape-guard code in its Python, verbose and C++ forms."""

    # Python guard expressions.
    python_code_parts: _ShapeGuardsHelper
    # Verbose counterparts of the Python expressions.
    verbose_code_parts: _ShapeGuardsHelper
    # C++ guard expressions, or None when none were generated.
    cpp_code_parts: Optional[_CppShapeGuardsHelper]
    # Whether the Python guard path is to be used.
    python_fallback: bool
    # Sources that appear in the C++ guard's source-to-symbol mapping.
    shape_env_sources: list[Source]
|
|
|
|
@dataclasses.dataclass
class GuardsState:
    """Bundle of the output-graph guard state and the shape-guard code parts."""

    # Guard-related state taken from the output graph.
    output_graph: OutputGraphGuardsState
    # Generated shape-guard code, or None when there is none.
    shape_code_parts: Optional[ShapeCodeParts]
|
|
|
|
class _Missing:
    """Placeholder type; the pickler reduces unserializable objects to it."""
|
|
|
|
class GuardsStatePickler(pickle.Pickler):
    """Pickler with custom reductions for objects found in guard state.

    ``reducer_override`` maps tensors, modules, dispatch key sets and a few
    other objects to ``_unpickle_*`` classmethods on this class. Some objects
    that cannot be saved are replaced by ``_Missing``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.fake_mode = torch._subclasses.FakeTensorMode()
        self.tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter()

    @classmethod
    def _unpickle_module(cls, state: Any) -> torch.nn.Module:
        """Rebuild a plain ``torch.nn.Module`` and restore ``state`` onto it."""
        module = torch.nn.Module()
        module.__setstate__(state)
        return module

    @classmethod
    def _unpickle_tensor(
        cls,
        meta_tensor: torch.Tensor,
        device: torch.device,
        pytype: type,
        dispatch_keys_raw: int,
        grad: torch.Tensor,
    ) -> torch.Tensor:
        """Rebuild a tensor from its meta tensor, device, type and dispatch keys."""
        mode = torch._subclasses.FakeTensorMode()
        converter = torch._subclasses.fake_tensor.FakeTensorConverter()
        dispatch_keys = torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw)
        rebuilt = converter.from_meta_and_device(
            mode,
            meta_tensor,
            device,
            pytype,
            dispatch_keys,
        )
        rebuilt.grad = grad
        return rebuilt

    @classmethod
    def _unpickle_traceable_wrapper_subclass(
        cls,
        meta_tensor: torch.Tensor,
        device: torch.device,
        pytype: type,
        dispatch_keys_raw: int,
        ctx: Any,
        inner_data: list[tuple[str, Callable[..., Any], tuple[Any, ...]]],
    ) -> torch.Tensor:
        """Rebuild a wrapper-subclass tensor from its inner tensors and context."""
        # Each entry is (attribute name, unpickle function, its arguments).
        inner_tensors = {
            attr: rebuild(*rebuild_args) for attr, rebuild, rebuild_args in inner_data
        }

        outer_size, outer_stride = meta_tensor.shape, meta_tensor.stride()
        out = type(meta_tensor).__tensor_unflatten__(
            inner_tensors, ctx, outer_size, outer_stride
        )
        out.pytype = pytype
        out.dispatch_keys = torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw)
        return out

    @classmethod
    def _unpickle_python_module(cls, alias: str) -> types.ModuleType:
        """Import and return the module named ``alias``."""
        return importlib.import_module(alias)

    @classmethod
    def _unpickle_dispatch_key_set(cls, raw_repr: int) -> torch._C.DispatchKeySet:
        """Rebuild a ``DispatchKeySet`` from its raw integer representation."""
        return torch._C.DispatchKeySet.from_raw_repr(raw_repr)

    @classmethod
    def _unpickle_functorch_interpreter(
        cls, json: bytes
    ) -> torch._C._functorch.CInterpreter:
        """Rebuild a functorch ``CInterpreter`` from its serialized form."""
        return torch._C._functorch.CInterpreter.deserialize(json)

    @classmethod
    def _unpickle_mapping_proxy(
        cls, d: dict[Any, Any]
    ) -> types.MappingProxyType[Any, Any]:
        """Wrap ``d`` in a ``types.MappingProxyType``."""
        return types.MappingProxyType(d)

    @classmethod
    def _unpickle_c_op(cls, name: str) -> Any:
        """Look up ``name`` on ``torch.ops._C``."""
        return getattr(torch.ops._C, name)

    def reducer_override(
        self, obj: Any
    ) -> Union[tuple[Callable[..., Any], tuple[Any, ...]], Any]:
        """Return a custom reduction for ``obj`` or ``NotImplemented``.

        Raises:
            RuntimeError: If ``obj`` is a ``torch.SymInt``.
            torch._dynamo.exc.PackageError: If no branch returned a reduction
                and ``obj``'s type has a ``__qualname__`` different from its
                ``__name__``.
        """
        import sympy

        handlers = type(self)

        if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
            from torch.utils._python_dispatch import is_traceable_wrapper_subclass

            if is_traceable_wrapper_subclass(obj):
                # Collect (attr name, unpickle function, arguments) per inner
                # tensor by reducing each one recursively.
                inner_data = []
                attrs, ctx = obj.__tensor_flatten__()
                for attr in attrs:
                    inner = getattr(obj, attr)
                    func, args_tuple = self.reducer_override(inner)
                    inner_data.append((attr, func, args_tuple))

                return handlers._unpickle_traceable_wrapper_subclass, (
                    torch.empty_like(obj, device="meta"),
                    obj.device,
                    type(obj),
                    torch._C._dispatch_keys(obj).raw_repr(),
                    ctx,
                    inner_data,
                )

            # Plain tensor: keep only a meta copy plus what is needed to rebuild.
            return handlers._unpickle_tensor, (
                torch.empty_like(obj, device="meta", requires_grad=obj.requires_grad),
                obj.device,
                type(obj),
                torch._C._dispatch_keys(obj).raw_repr(),
                obj.grad,
            )

        elif isinstance(obj, torch.nn.Module):
            if type(obj).__qualname__ == type(obj).__name__:
                return NotImplemented
            if obj.__class__.__getstate__ == torch.nn.Module.__getstate__:
                return handlers._unpickle_module, (obj.__getstate__(),)

        elif inspect.ismodule(obj):
            return handlers._unpickle_python_module, (obj.__name__,)

        elif isinstance(obj, torch._C.DispatchKeySet):
            return handlers._unpickle_dispatch_key_set, (obj.raw_repr(),)

        elif isinstance(obj, torch._C._functorch.CInterpreter):
            return handlers._unpickle_functorch_interpreter, (obj.serialize(),)

        elif (
            inspect.isclass(obj)
            and issubclass(obj, sympy.Function)
            and hasattr(obj, "_torch_handler_name")
        ):
            assert hasattr(obj, "_torch_unpickler")
            return obj._torch_unpickler, (obj._torch_handler_name,)

        elif isinstance(obj, torch.SymInt):
            raise RuntimeError(f"Cannot serialize SymInt {obj} (node: {obj.node})")

        elif isinstance(obj, types.MappingProxyType):
            return handlers._unpickle_mapping_proxy, (obj.copy(),)

        elif isinstance(
            obj, torch._ops.OpOverloadPacket
        ) and obj._qualified_op_name.startswith("_C::"):
            return handlers._unpickle_c_op, (obj.__name__,)

        elif (
            obj.__class__.__module__ == "builtins"
            and obj.__class__.__name__ == "PyCapsule"
        ):
            # PyCapsule objects are replaced by a placeholder.
            return _Missing, ()

        elif isinstance(obj, types.CodeType):
            # Code objects are replaced by a placeholder.
            return _Missing, ()

        elif inspect.isfunction(obj) and (obj.__code__.co_flags & inspect.CO_NESTED):
            # Nested functions (CO_NESTED set) are replaced by a placeholder.
            assert obj.__qualname__ != obj.__name__
            return _Missing, ()

        # Reached by objects not handled above and by modules that returned nothing.
        if type(obj).__qualname__ != type(obj).__name__:
            raise torch._dynamo.exc.PackageError(
                f"Type {type(obj)} for object {obj} cannot be saved "
                + "into torch.compile() package since it's defined in local scope. "
                + "Please define the class at global scope (top level of a module)."
            )

        return NotImplemented
|
|
|
|
def pickle_guards_state(state: GuardsState) -> bytes:
    """Serialize ``state`` with ``GuardsStatePickler`` and return the bytes.

    Raises:
        torch._dynamo.exc.PackageError: If dumping raises ``AttributeError``;
            the original exception is chained as the cause.
    """
    with io.BytesIO() as stream:
        pickler = GuardsStatePickler(stream)
        try:
            pickler.dump(state)
        except AttributeError as err:
            raise torch._dynamo.exc.PackageError(str(err)) from err
        return stream.getvalue()
|
|
|
|
| |
| |
| |
| |
| |
| class CheckFunctionManager: |
| def __init__( |
| self, |
| f_code: types.CodeType, |
| output_graph: OutputGraphGuardsState, |
| cache_entry: Optional[CacheEntry] = None, |
| guard_fail_fn: Optional[Callable[[GuardFail], None]] = None, |
| guard_filter_fn: Optional[ |
| Callable[[list[GuardFilterEntry]], list[bool]] |
| ] = None, |
| shape_code_parts: Optional[ShapeCodeParts] = None, |
| runtime_global_scope: Optional[dict[str, Any]] = None, |
| save_guards: bool = False, |
| strict_error: bool = False, |
| ): |
| guards = output_graph.guards if output_graph else None |
| self._weakrefs: dict[int, ReferenceType[object]] = {} |
|
|
| existing_diff_guard_sources = ( |
| update_diff_guard_managers_for_existing_cache_entries(cache_entry) |
| ) |
| self.output_graph: Optional[OutputGraphGuardsState] = output_graph |
| assert self.output_graph is not None |
|
|
| |
| self.shape_code_parts = shape_code_parts |
|
|
| |
| |
| self.torch_function_mode_stack = ( |
| output_graph.torch_function_mode_stack if output_graph else None |
| ) |
| self.used_builtin_vars: OrderedSet[str] = OrderedSet() |
| self.additional_used_local_vars: OrderedSet[str] = OrderedSet() |
| self.additional_used_global_vars: OrderedSet[str] = OrderedSet() |
| self.runtime_global_scope = runtime_global_scope |
|
|
| if not justknobs_check("pytorch/compiler:guard_nn_modules"): |
| log.warning("guard_nn_modules is turned off using justknobs killswitch") |
|
|
| |
| if torch._dynamo.config.caching_precompile: |
| _guard_filter_fn = guard_filter_fn or (lambda gs: [True for g in gs]) |
|
|
| def guard_filter_fn(guards: list[GuardFilterEntry]) -> list[bool]: |
| ret = [] |
| for keep, g in zip(_guard_filter_fn(guards), guards): |
| if not keep: |
| ret.append(False) |
| elif ( |
| g.guard_type in ("ID_MATCH", "CLOSURE_MATCH", "WEAKREF_ALIVE") |
| or "ID_MATCH" in g.derived_guard_types |
| ): |
| log.warning( |
| "%s guard on %s is dropped with caching_precompile=True.", |
| g.guard_type, |
| g.orig_guard.name, |
| ) |
| ret.append(False) |
| else: |
| ret.append(True) |
| return ret |
|
|
| sorted_guards = sorted(guards or (), key=Guard.sort_key) |
|
|
| if guard_filter_fn: |
| |
| |
| builder, guard_manager = self.build_guards( |
| sorted_guards, existing_diff_guard_sources, f_code, output_graph, False |
| ) |
|
|
| def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry: |
| MISSING = object() |
| name = strip_local_scope(guard.name) |
| if name == "": |
| has_value = False |
| value = MISSING |
| else: |
| try: |
| |
| |
| |
| |
| value = builder.get(guard.name) |
| has_value = True |
| except: |
| value = MISSING |
| has_value = False |
| is_global = get_global_source_name(guard.originating_source) is not None |
| return GuardFilterEntry( |
| name=name, |
| has_value=has_value, |
| value=value, |
| guard_type=guard.create_fn_name(), |
| derived_guard_types=( |
| tuple(guard.guard_types) if guard.guard_types else () |
| ), |
| is_global=is_global, |
| orig_guard=guard, |
| ) |
|
|
| filter_results = guard_filter_fn( |
| [make_guard_filter_entry(guard) for guard in sorted_guards] |
| ) |
| assert len(filter_results) == len(sorted_guards) |
| assert all(type(x) == bool for x in filter_results) |
| sorted_guards = [ |
| guard for i, guard in enumerate(sorted_guards) if filter_results[i] |
| ] |
|
|
| |
| builder, guard_manager = self.build_guards( |
| sorted_guards, |
| existing_diff_guard_sources, |
| f_code, |
| output_graph, |
| save_guards, |
| ) |
|
|
| self.guard_manager = guard_manager |
| self.compile_check_fn(builder, sorted_guards, guard_fail_fn) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| self.guard_manager.id_matched_objs = builder.id_matched_objs |
|
|
| guards_log.debug("%s", self.guard_manager) |
| self.guard_manager.id_matched_objs = builder.id_matched_objs |
|
|
| |
| |
| |
| |
| latency = 0.0 |
|
|
| if not output_graph.skip_guards_check and not output_graph.export: |
| if not self.guard_manager.check(output_graph.local_scope): |
| reasons = get_guard_fail_reason_helper( |
| self.guard_manager, |
| output_graph.local_scope, |
| CompileContext.current_compile_id(), |
| ) |
| raise AssertionError(f"Guard check failed: {reasons}") |
|
|
| if guard_manager_testing_hook_fn is not None: |
| guard_manager_testing_hook_fn( |
| self.guard_manager, output_graph.local_scope, builder |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| latency = profile_guard_manager( |
| self.guard_manager.root, output_graph.local_scope, 1 |
| ) |
| guards_log.debug("Guard eval latency = %s us", f"{latency:.2f}") |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| CompileEventLogger.increment_toplevel("guard_latency_us", int(latency)) |
|
|
| self.guards_state: Optional[bytes] = None |
| if save_guards: |
| from torch._dynamo.output_graph import OutputGraph |
|
|
| assert isinstance(self.output_graph, OutputGraph) |
| try: |
| self.guards_state = self.serialize_guards( |
| builder, sorted_guards, self.output_graph |
| ) |
| except exc.PackageError as e: |
| if torch._dynamo.config.strict_precompile or strict_error: |
| raise e |
| self.output_graph.bypass_package( |
| f"Guard evaluation failed: {str(e)}", |
| traceback=traceback.format_exc().split("\n"), |
| ) |
|
|
| |
| torch._logging.trace_structured( |
| "dynamo_cpp_guards_str", |
| payload_fn=lambda: f"{self.guard_manager}\nGuard latency = {latency:.2f} us", |
| ) |
| |
| |
| |
| |
| |
| |
| |
| |
| self._weakrefs.clear() |
| self.output_graph = None |
|
|
# Guard kinds that serialize_guards() refuses to serialize. A guard whose own
# type (guard.create_fn_name()) or any of whose derived guard types appears in
# this tuple makes serialize_guards() raise PackageError.
UNSUPPORTED_SERIALIZATION_GUARD_TYPES: tuple[LiteralString, ...] = (
    "DICT_VERSION",
    "NN_MODULE",
    "ID_MATCH",
    "FUNCTION_MATCH",
    "CLOSURE_MATCH",
    "WEAKREF_ALIVE",
)
|
|
def serialize_guards(
    self,
    builder: GuardBuilder,
    sorted_guards: list[Guard],
    output_graph: OutputGraph,
) -> bytes:
    """
    Pickle the guard state required to reconstruct ``sorted_guards``.

    The guards are first validated: any guard of an unsupported kind aborts
    serialization. The dumped output-graph state is then narrowed to the
    local, global and builtin variables that the guards actually reference
    before being pickled.

    Args:
        builder: builder the guards were created with; only consulted to
            fetch the offending object when a type guard is unserializable.
        sorted_guards: the guards to serialize.
        output_graph: graph whose guard state is dumped and pruned.

    Returns:
        The result of ``pickle_guards_state`` on the pruned state.

    Raises:
        torch._dynamo.exc.PackageError: if a guard's type, or one of its
            derived guard types, is listed in
            ``UNSUPPORTED_SERIALIZATION_GUARD_TYPES``.
    """
    unsupported = CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES

    # Validation pass: bail out before doing any work if a guard cannot be
    # serialized.
    for guard in sorted_guards:
        guard_type = guard.create_fn_name()
        if guard_type in ("TYPE_MATCH", "BUILTIN_MATCH"):
            if guard._unserializable:
                # The object is only looked up on the failure path.
                raise_local_type_error(builder.get(guard.name))
            continue
        if guard_type in unsupported:
            raise torch._dynamo.exc.PackageError(
                f"{guard_type} guard cannot be serialized."
            )
        derived_guard_types = tuple(guard.guard_types) if guard.guard_types else ()
        # Report only the first unsupported derived guard type.
        failed = next((t for t in derived_guard_types if t in unsupported), None)
        if failed:
            raise torch._dynamo.exc.PackageError(
                f"{failed} guard cannot be serialized."
            )

    builtins_dict_name = output_graph.name_of_builtins_dict_key_in_fglobals
    used_global_vars = set()
    used_local_vars = set()

    def prune_variable(source: Source) -> None:
        # Record which global / local variable a source is rooted at.
        if name := get_global_source_name(source):
            assert isinstance(name, str)
            # The builtins dict entry is rebuilt separately below from
            # self.used_builtin_vars, so it is not recorded here.
            if name != builtins_dict_name:
                used_global_vars.add(name)
        elif name := get_local_source_name(source):
            assert isinstance(name, str)
            used_local_vars.add(name)

    output_graph_guards_state = output_graph.dump_guards_state()

    for guard in sorted_guards:
        origin = guard.originating_source
        if isinstance(origin, ShapeEnvSource):
            # Shape-env guards contribute every source used by the shape code.
            assert self.shape_code_parts
            guard_sources = self.shape_code_parts.shape_env_sources
        else:
            guard_sources = (origin,)
        for source in guard_sources:
            prune_variable(source)

    for source in output_graph.guard_on_key_order:
        prune_variable(source)

    def normalize_create_fn(x: Callable[..., None]) -> Callable[..., None]:
        # Rebuild a functools.partial with weak references among its
        # arguments replaced by the objects they currently point at.
        if not isinstance(x, functools.partial):
            return x

        def _ref(value: Any) -> Any:
            if isinstance(value, (TensorWeakRef, weakref.ref)):
                return value()
            return value

        positional = tuple(_ref(arg) for arg in x.args)
        keyword = {key: _ref(val) for key, val in x.keywords.items()}
        return functools.partial(x.func, *positional, **keyword)

    dumped_globals = output_graph_guards_state.global_scope
    global_scope_state = {
        key: val
        for key, val in dumped_globals.items()
        if key in used_global_vars or key in self.additional_used_global_vars
    }
    # Keep only the builtins that guards were recorded as using.
    global_scope_state[builtins_dict_name] = {
        key: val
        for key, val in dumped_globals[builtins_dict_name].items()
        if key in self.used_builtin_vars
    }

    local_scope_state = {
        key: val
        for key, val in output_graph_guards_state.local_scope.items()
        if key in used_local_vars or key in self.additional_used_local_vars
    }

    # Copies of the guards without their weak references and with
    # normalized create_fn callables.
    stripped_guards = {
        dataclasses.replace(
            guard,
            obj_weakref=None,
            guarded_class_weakref=None,
            create_fn=normalize_create_fn(guard.create_fn),
        )
        for guard in sorted_guards
    }

    output_graph_guards_state = dataclasses.replace(
        output_graph_guards_state,
        local_scope=local_scope_state,
        global_scope=global_scope_state,
        _guards=torch._guards.GuardsSet(stripped_guards),
        input_source_to_sizes_strides=pytree.tree_map(
            convert_int_to_concrete_values,
            output_graph_guards_state.input_source_to_sizes_strides,
        ),
        skip_guards_check=True,
    )

    return pickle_guards_state(
        GuardsState(
            output_graph=output_graph_guards_state,
            shape_code_parts=self.shape_code_parts,
        )
    )
|
|
def build_guards(
    self,
    sorted_guards: list[Guard],
    existing_diff_guard_sources: OrderedSet[str],
    f_code: types.CodeType,
    output_graph: OutputGraphGuardsState,
    save_guards: bool,
) -> tuple[GuardBuilder, GuardManagerWrapper]:
    """
    Create a fresh guard manager and populate it from ``sorted_guards``.

    Args:
        sorted_guards: guards to install, in the order they are created.
        existing_diff_guard_sources: stored on the new manager as its
            ``diff_guard_sources``.
        f_code: code object handed to the ``GuardBuilder``.
        output_graph: supplies the local and global scopes for the builder.
        save_guards: forwarded to the ``GuardBuilder``.

    Returns:
        The ``(builder, guard_manager)`` pair.
    """
    guard_manager = GuardManagerWrapper()
    guard_manager.diff_guard_sources = existing_diff_guard_sources

    # Bound to a weak reference to the builder further down; source_ref
    # reaches the builder only through it.
    w_builder = None

    def source_ref(source: Source) -> str:
        if source.guard_source() is GuardSource.CONSTANT:
            # Constants are referred to by name, without going through the builder.
            return source.name()
        assert w_builder
        r_builder = w_builder()
        assert r_builder is not None
        return r_builder.arg_ref(source.name())

    builder = GuardBuilder(
        f_code,
        self.id_ref,
        source_ref,
        self.lookup_weakrefs,
        output_graph.local_scope,
        output_graph.global_scope,
        guard_manager,
        self,
        save_guards,
        runtime_global_scope=self.runtime_global_scope,
    )

    def cleanup_builder(weak_b: weakref.ref[GuardBuilder]) -> None:
        # Weakref callback: drop the builder's scope if it can still be reached.
        b = weak_b()
        if b:
            b.scope = None

    w_builder = weakref.ref(builder, cleanup_builder)

    guard_on_nn_modules = config.guard_nn_modules and justknobs_check(
        "pytorch/compiler:guard_nn_modules"
    )

    for guard in sorted_guards:
        # With nn-module guarding off, guards on specialized nn modules are
        # dropped, except those whose name mentions __defaults__ or
        # __kwdefaults__, and (unless config.skip_nnmodule_hook_guards is
        # set) those whose name mentions hooks.
        skip_guard = (
            not guard_on_nn_modules
            and guard.is_specialized_nn_module()
            and "__defaults__" not in guard.name
            and "__kwdefaults__" not in guard.name
            and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name)
        )
        if not skip_guard:
            guard.create(builder)

    return builder, guard_manager
|
|
def compile_check_fn(
    self,
    builder: GuardBuilder,
    guards_out: list[Guard],
    guard_fail_fn: Optional[Callable[[GuardFail], None]],
) -> None:
    """
    Finish constructing ``self.guard_manager`` from the state gathered by ``builder``.

    Installs the global-state, torch-function-mode-stack, tensor-aliasing,
    object-aliasing and storage-overlap guards, logs every guard code part,
    finalizes the guard manager and attaches the metadata (closure vars,
    argument names, verbose code parts, global scope, failure callback)
    that is later read back from the manager.

    Args:
        builder: the GuardBuilder whose guards were already created.
        guards_out: not referenced in this method's body.
        guard_fail_fn: stored on the guard manager as ``guard_fail_fn``.
    """
    # NOTE(review): if builder.argnames is a list, += extends it in place, so
    # builder.argnames also gains the extra entry - confirm this is intended.
    largs = builder.argnames
    largs += ["**___kwargs_ignored"]

    guards_log.debug("GUARDS:")

    # Filled by add_code_part below. Every call in this method passes
    # log_only=True, so both lists are expected to stay empty (see the
    # assert further down).
    code_parts = []
    verbose_code_parts = []
    # Deferred payload builders for the structured "dynamo_guards" trace.
    structured_guard_fns: list[Callable[[], dict[str, Any]]] = []

    assert self.torch_function_mode_stack is not None
    torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard(
        self.torch_function_mode_stack
    )

    # Tag the root guard manager with the current compile id.
    self.guard_manager.root.attach_compile_id(
        str(CompileContext.current_compile_id())
    )

    # Install the global-state guard on the root.
    assert self.output_graph is not None
    global_state = self.output_graph.global_state_guard
    self.guard_manager.root.add_global_state_guard(
        global_state, ["___check_global_state()"]
    )

    # Install the torch-function-mode-stack guard on the root.
    self.guard_manager.root.add_torch_function_mode_stack_guard(
        self.torch_function_mode_stack,
        ["___check_torch_function_mode_stack()"],
    )
    # Drop this object's reference to the mode stack now that the guard is installed.
    self.torch_function_mode_stack = None

    def add_code_part(
        code_part: str, guard: Optional[Guard], log_only: bool = False
    ) -> None:
        # Log one guard expression (plain, structured and, when enabled,
        # verbose with stacks); record it in code_parts unless log_only.
        verbose_code_part = get_verbose_code_part(code_part, guard)
        guards_log.debug("%s", verbose_code_part)

        # The payload is built lazily, only if the structured trace is emitted.
        structured_guard_fns.append(
            lambda: {
                "code": code_part,
                "stack": (
                    structured.from_traceback(guard.stack.summary())
                    if guard and guard.stack
                    else None
                ),
                "user_stack": (
                    structured.from_traceback(guard.user_stack)
                    if guard and guard.user_stack
                    else None
                ),
            }
        )

        # Stack formatting is only done when verbose guard logging is on.
        if verbose_guards_log.isEnabledFor(logging.DEBUG):
            maybe_stack = ""
            maybe_user_stack = ""
            if guard is not None:
                if guard.stack:
                    maybe_stack = f"\nStack:\n{''.join(guard.stack.format())}"
                if guard.user_stack:
                    maybe_user_stack = (
                        f"\nUser stack:\n{''.join(guard.user_stack.format())}"
                    )
            verbose_guards_log.debug(
                "Guard: %s%s%s",
                code_part,
                maybe_stack,
                maybe_user_stack,
            )

        if not log_only:
            code_parts.append(code_part)
            verbose_code_parts.append(verbose_code_part)

    # Log each distinct code string produced by the builder exactly once.
    seen = set()
    for gcl in builder.code:
        for code in gcl.code_list:
            if code not in seen:
                # Logged only (log_only=True); nothing is appended to
                # code_parts here.
                add_code_part(code, gcl.guard, True)
                seen.add(code)

    no_tensor_aliasing_names = builder.no_tensor_aliasing_names
    # Never reassigned in this method, so the matching closure_vars entries
    # below are always None.
    check_tensors_fn = None
    check_tensors_verbose_fn = None

    if len(no_tensor_aliasing_names) > 1:
        # A no-aliasing guard is only installed when there are at least two
        # tensor names.
        install_no_tensor_aliasing_guard(
            builder.no_tensor_aliasing_guard_managers,
            no_tensor_aliasing_names,
            ["check_no_aliasing(" + ", ".join(no_tensor_aliasing_names) + ")"],
        )

    # When enabled by config, all object-aliasing guard codes collected by
    # the builder are installed together as one Python lambda leaf guard on
    # the root. Each collected entry is a (code_part, verbose_code_part)
    # pair, unzipped here into two lists.
    if (
        config.use_lamba_guard_for_object_aliasing
        and builder.object_aliasing_guard_codes
    ):
        aliasing_code_parts, aliasing_verbose_code_parts = map(
            list, zip(*builder.object_aliasing_guard_codes)
        )
        builder.add_python_lambda_leaf_guard_to_root(
            aliasing_code_parts, aliasing_verbose_code_parts
        )

    aotautograd_guards: list[GuardEnvExpr] = (
        self.output_graph.aotautograd_guards if self.output_graph else []
    )

    # Install the guards requested by AOTAutograd: duplicate-input
    # (object identity) guards and storage-overlap guards. Any other
    # GuardEnvExpr kind is an error.
    for guard in aotautograd_guards:
        if isinstance(guard, DuplicateInputs):
            source_a = guard.input_source_a
            source_b = guard.input_source_b
            code_part = f"{source_a.name()} is {source_b.name()}"
            install_object_aliasing_guard(
                builder.get_guard_manager_from_source(source_a),
                builder.get_guard_manager_from_source(source_b),
                [code_part],
            )
            add_code_part(code_part, None, True)
        elif isinstance(guard, StorageOverlap):
            overlapping_guard_managers = [
                builder.get_guard_manager_from_source(s)
                for s in guard.overlapping_sources
            ]
            non_overlapping_guard_managers = [
                builder.get_guard_manager_from_source(s)
                for s in guard.non_overlapping_sources
            ]
            code_part = (
                """check_overlapping("""
                f"""overlapping=[{", ".join(s.name() for s in guard.overlapping_sources)}], """
                f"""non_overlapping=[{", ".join(s.name() for s in guard.non_overlapping_sources)}])"""
            )
            install_storage_overlapping_guard(
                overlapping_guard_managers,
                non_overlapping_guard_managers,
                [code_part],
            )
            add_code_part(code_part, None, True)
        else:
            raise RuntimeError(f"Unknown GuardEnvExpr: {guard}")

    # Shape-env code is logged here; unlike builder.code above, it is not
    # de-duplicated.
    for gcl in builder.shape_env_code:
        for code in gcl.code_list:
            # Logged only (log_only=True); no guard is installed by this
            # loop.
            add_code_part(code, gcl.guard, True)

    # All guard code parts have been seen; emit the structured trace.
    if structured_guard_fns:
        torch._logging.trace_structured(
            "dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns]
        )

    if convert_frame.initial_global_state is None:
        # No initial global state was recorded: fall back to a fresh
        # GlobalStateGuard for the closure variable below.
        global_state = convert_frame.GlobalStateGuard()
    # Names stored on the guard manager as closure_vars (see below).
    closure_vars = {
        "___check_tensors": check_tensors_fn,
        "___check_tensors_verbose": check_tensors_verbose_fn,
        "___check_global_state": global_state.check,
        "___check_torch_function_mode_stack": torch_function_mode_stack_check_fn,
        **SYMPY_INTERP,
        **_get_closure_vars(),
    }

    self.guard_manager.finalize()

    globals_for_guard_fn = {"G": builder.scope["G"]}
    # Every add_code_part call above was log-only, so no code part may have
    # been collected for Python-side evaluation.
    assert len(code_parts) == 0

    self.guard_manager.closure_vars = closure_vars
    self.guard_manager.args = largs
    self.guard_manager.populate_code_parts_for_debugging()
    self.guard_manager.verbose_code_parts = verbose_code_parts
    # Only the "G" entry of the builder scope is kept, under the same key.
    self.guard_manager.global_scope = globals_for_guard_fn
    self.guard_manager.guard_fail_fn = guard_fail_fn
    # Initialised to None here; invalidate() is a no-op until both are set
    # to real objects elsewhere.
    self.guard_manager.cache_entry = None
    self.guard_manager.extra_state = None
    self.guard_manager.no_tensor_aliasing_sources = no_tensor_aliasing_names
|
|
def invalidate(self, obj_str: str) -> None:
    """
    Invalidate the cache entry guarded by this manager.

    Registered in ``id_ref`` as a finalizer for id-matched objects, so it
    runs when such an object is deallocated. Nothing happens unless a live
    (not already deleted) guard manager exists and has both ``cache_entry``
    and ``extra_state`` set. Otherwise the cache entry is invalidated
    through ``extra_state`` and ``self.guard_manager`` is replaced by a
    ``DeletedGuardManagerWrapper`` carrying the reason.

    Args:
        obj_str: description of the deallocated object, used in the reason.
    """
    if not hasattr(self, "guard_manager"):
        return
    if isinstance(self.guard_manager, DeletedGuardManagerWrapper):
        return

    cache_entry = self.guard_manager.cache_entry
    if cache_entry is None:
        return
    extra_state = self.guard_manager.extra_state
    if extra_state is None:
        return

    assert isinstance(cache_entry, CacheEntry)
    assert isinstance(extra_state, ExtraState)
    reason = f"Cache line invalidated because {obj_str} got deallocated"
    deleted_guard_manager = DeletedGuardManagerWrapper(reason)
    extra_state.invalidate(cache_entry, deleted_guard_manager)
    self.guard_manager = deleted_guard_manager
|
|
def id_ref(self, obj: object, obj_str: str) -> int:
    """
    Track ``obj`` weakly and return ``id(obj)``.

    On first sight of an object, a weak reference to it is stored in
    ``self._weakrefs`` and a finalizer is registered that calls
    ``self.invalidate(obj_str=obj_str)``. Objects that do not support weak
    references are silently left untracked.
    """
    obj_id = id(obj)
    if obj_id in self._weakrefs:
        return obj_id

    try:
        self._weakrefs[obj_id] = weakref.ref(obj)
        # A finalizer is used rather than a callback on the weak reference
        # stored above: weakref.finalize objects stay alive on their own
        # until the referent dies.
        weakref.finalize(obj, functools.partial(self.invalidate, obj_str=obj_str))
    except TypeError:
        # obj cannot be weakly referenced.
        pass
    return obj_id
|
|
def lookup_weakrefs(self, obj: object) -> Optional[weakref.ref[object]]:
    """Return the weak reference that ``id_ref`` stored for ``obj``, or None if there is none."""
    return self._weakrefs.get(id(obj))
|
|
|
|
def build_guard_function(code_parts: list[str], closure_args: str) -> tuple[str, str]:
    """
    Generate Python source for a guard function from ``code_parts``.

    Each expression becomes an ``if not (...): return False`` check, after
    being rewritten by ``PyExprCSEPass``. If the CSE counting pass hits a
    ``RecursionError``, expressions are emitted unchanged.

    Args:
        code_parts: guard expressions, checked in order.
        closure_args: parameter list text for the generated
            ``___make_guard_fn`` factory.

    Returns:
        ``(guard_body, make_guard_fn)``: the source of the checks alone, and
        the source of a ``___make_guard_fn`` factory that defines and
        returns ``guard(L)``.
    """
    from torch._inductor.utils import IndentedBuffer

    csepass = PyExprCSEPass()
    try:
        csepass.count(code_parts)
    except RecursionError:
        cse_available = False
    else:
        cse_available = True

    def replace(expr: str) -> tuple[list[str], str]:
        # Falls back to the identity rewrite when counting failed.
        if cse_available:
            return csepass.replace(expr)
        return [], expr

    # One early-exit check per expression, preceded by whatever preface
    # lines the rewrite produced for it.
    checks = IndentedBuffer()
    for part in code_parts:
        preface, rewritten = replace(part)
        checks.writelines(preface)
        checks.writeline(f"if not ({rewritten}):")
        with checks.indent():
            checks.writeline("return False")

    # Wrap the checks into the guard function itself.
    guard_fn = IndentedBuffer()
    guard_fn.writeline("def guard(L):")
    with guard_fn.indent():
        guard_fn.splice(checks)
        guard_fn.writeline("return True")

    # Wrap the guard function into a factory taking the closure arguments.
    factory = IndentedBuffer()
    factory.writeline(f"def ___make_guard_fn({closure_args}):")
    with factory.indent():
        factory.splice(guard_fn)
        factory.writeline("return guard")

    return checks.getvalue(), factory.getvalue()
|
|
|
|
| def is_recompiles_enabled() -> bool: |
| return torch._logging._internal.log_state.is_artifact_enabled("recompiles") |
|
|
|
|
| def is_recompiles_verbose_enabled() -> bool: |
| return torch._logging._internal.log_state.is_artifact_enabled("recompiles_verbose") |
|
|
|
|
| |
def make_torch_function_mode_stack_guard(
    initial_stack: list[torch.overrides.TorchFunctionMode],
) -> Callable[[], bool]:
    """
    Build a check that the torch function mode stack still has the same shape.

    The types of the modes in ``initial_stack`` are captured up front. The
    returned callable reports True only when the current stack has the same
    length and the same mode type at every position.
    """
    expected_types = [type(mode) for mode in initial_stack]

    def check_torch_function_mode_stack() -> bool:
        cur_stack = get_torch_function_mode_stack()
        if len(cur_stack) != len(expected_types):
            return False
        return not any(
            ty != type(mode) for ty, mode in zip(expected_types, cur_stack)
        )

    return check_torch_function_mode_stack
|
|
|
|
# Named alias for a string-keyed mapping of objects; used to annotate the
# `scope` argument of recompilation_reason_for_no_tensor_aliasing_guard.
Scope = TypeAliasType("Scope", dict[str, object])
|
|
|
|
def recompilation_reason_for_no_tensor_aliasing_guard(
    guard_manager: GuardManagerWrapper, scope: Scope
) -> list[str]:
    """
    Describe which no-aliasing sources currently refer to the same object.

    Every source expression in ``guard_manager.no_tensor_aliasing_sources``
    is evaluated against a copy of the manager's global scope and ``scope``.
    Sources whose results share an ``id`` are grouped together.

    Returns:
        A single-element list holding the "Duplicate tensors found" message.
    """
    assert guard_manager.global_scope is not None
    eval_globals = dict(guard_manager.global_scope)
    sources_by_id = collections.defaultdict(list)
    for tensor_source in guard_manager.no_tensor_aliasing_sources:
        eval_globals["__compile_source__"] = tensor_source
        sources_by_id[id(eval(tensor_source, eval_globals, scope))].append(
            tensor_source
        )

    # Only groups with more than one source are duplicates.
    duplicate_tensors = [
        f"{sources}" for sources in sources_by_id.values() if len(sources) > 1
    ]

    reason = ", ".join(duplicate_tensors)
    return [f"Duplicate tensors found: {reason}"]
|
|
|
|
def strip_local_scope(s: str) -> str:
    """
    Rewrite every ``L['name']`` / ``L["name"]`` lookup in ``s`` as just ``name``.

    Whitespace just inside the brackets is tolerated. The result is meant
    for user-friendly recompilation messages.
    """
    import re

    local_lookup = re.compile(r"L\[\s*['\"](.*?)['\"]\s*\]")
    return local_lookup.sub(r"\1", s)
|
|
|
|
def get_guard_fail_reason_helper(
    guard_manager: GuardManagerWrapper,
    f_locals: dict[str, object],
    compile_id: Optional[CompileId],
) -> str:
    """
    Return a description of why `guard_manager` fails for `f_locals`.

    The guard manager is re-run in verbose mode. A single returned code
    part is either the tensor-aliasing failure marker or the reason
    itself. Several code parts are evaluated one by one to find the
    failing expression(s). Unless verbose recompile logging is enabled,
    only the first failure is reported.

    The returned string is prefixed with `compile_id` and has its
    `L[...]` lookups stripped.
    """
    assert guard_manager.global_scope is not None
    assert guard_manager.closure_vars is not None
    scope = {"L": f_locals, "G": guard_manager.global_scope["G"]}
    scope.update(guard_manager.closure_vars)

    reasons: list[str] = []
    no_tensor_aliasing_check_failed = False
    verbose_code_parts: list[str] = []

    guard_debug_info = guard_manager.check_verbose(f_locals)
    if not guard_debug_info.result:
        verbose_code_parts = guard_debug_info.verbose_code_parts
        if len(verbose_code_parts) == 1:
            if "Duplicate tensor found" in verbose_code_parts[0]:
                no_tensor_aliasing_check_failed = True
            else:
                # The single entry already is the reason; nothing to evaluate.
                reasons = verbose_code_parts
                verbose_code_parts = []

    if no_tensor_aliasing_check_failed:
        reasons = recompilation_reason_for_no_tensor_aliasing_guard(
            guard_manager, scope
        )
    else:
        for part in verbose_code_parts:
            eval_globals = dict(guard_manager.global_scope)
            eval_globals["__compile_source__"] = part
            with report_compile_source_on_error():
                try:
                    fail_reason = eval(part, eval_globals, scope)
                except Exception:
                    # In verbose mode a part that cannot be evaluated is
                    # skipped; otherwise the error propagates.
                    if not is_recompiles_verbose_enabled():
                        raise
                    continue

            # A plain False result is reported as the code that failed; a
            # string result is used as the reason directly.
            if fail_reason is False:
                fail_reason = part
            if isinstance(fail_reason, str):
                reasons.append(fail_reason)
                if not is_recompiles_verbose_enabled():
                    break

    reason_str = f"{compile_id}: " + "; ".join(reasons)
    return strip_local_scope(reason_str)
|
|
|
|
def get_guard_fail_reason(
    guard_manager: GuardManagerWrapper,
    code: types.CodeType,
    f_locals: dict[str, object],
    compile_id: CompileId,
    skip_logging: bool = False,
) -> str:
    """
    Return why `guard_manager` failed, recording the failure unless told not to.

    For an invalidated (deleted) guard manager the stored invalidation
    reason is returned directly. Otherwise the reason is computed and,
    unless `skip_logging` is set, appended to `guard_failures` for the
    original code object and passed to the manager's `guard_fail_fn`
    callback, if any. Exceptions raised by that callback are logged, not
    propagated.
    """
    if isinstance(guard_manager, DeletedGuardManagerWrapper):
        return f"{compile_id}: {guard_manager.invalidation_reason}"

    reason_str = get_guard_fail_reason_helper(guard_manager, f_locals, compile_id)
    if not skip_logging:
        guard_failures[orig_code_map[code]].append(reason_str)

        try:
            callback = guard_manager.guard_fail_fn
            if callback is not None:
                callback(
                    GuardFail(reason_str or "unknown reason", orig_code_map[code])
                )
        except Exception:
            log.exception(
                "Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval",
            )

    return reason_str
|
|
|
|
def get_and_maybe_log_recompilation_reasons(
    cache_entry: Optional[CacheEntry],
    frame: DynamoFrameType,
    skip_logging: bool = False,
) -> list[str]:
    """
    Collect the guard failure reason of every entry in the cache chain.

    Starting at `cache_entry` and following `.next`, each entry's guard
    manager is asked why it failed for `frame.f_locals`. Unless
    `skip_logging` is set, the reasons are logged when `recompiles` or
    `recompiles_verbose` logging is enabled and are always emitted as a
    structured trace artifact.

    Raises:
        exc.RecompileError: if `config.error_on_recompile` is enabled
            (and `skip_logging` is not set).
    """
    reasons = []
    entry = cache_entry
    while entry is not None:
        reason = get_guard_fail_reason(
            entry.guard_manager,
            entry.code,
            frame.f_locals,
            entry.compile_id,
            skip_logging,
        )
        if reason:
            reasons.append(reason)
        entry = entry.next

    code = frame.f_code

    if skip_logging:
        return reasons

    do_recompiles_log = is_recompiles_enabled() or is_recompiles_verbose_enabled()

    if do_recompiles_log or config.error_on_recompile:
        verbose = is_recompiles_verbose_enabled()
        if verbose:
            # One labelled section per failing cache entry.
            sections = [
                f"guard {i} failures:\n" + textwrap.indent(reason, "- ")
                for i, reason in enumerate(reasons)
            ]
            failures = "\n\n".join(sections)
        else:
            failures = textwrap.indent("\n".join(reasons), "- ")
        guard_failure_details = (
            f"triggered by the following guard failure(s):\n{failures}"
        )
        message = (
            f"Recompiling function {code.co_name} in {code.co_filename}:{code.co_firstlineno}\n"
            f"{textwrap.indent(guard_failure_details, ' ')}"
        )
        if do_recompiles_log:
            target_log = recompiles_verbose_log if verbose else recompiles_log
            target_log.debug(message)
        if config.error_on_recompile:
            raise exc.RecompileError(message)

    torch._logging.trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "recompile_reasons",
            "encoding": "json",
        },
        payload_fn=lambda: reasons,
    )

    return reasons
|
|
|
|
| def update_diff_guard_managers_for_existing_cache_entries( |
| cache_entry: Optional[CacheEntry], |
| ) -> OrderedSet[str]: |
| first_cache_entry = cache_entry |
|
|
| |
| |
| |
| acc_diff_guard_sources: OrderedSet[str] = OrderedSet() |
| while cache_entry is not None: |
| acc_diff_guard_sources.update( |
| cache_entry.guard_manager.collect_diff_guard_sources() |
| ) |
| cache_entry = cache_entry.next |
|
|
| |
| |
| cache_entry = first_cache_entry |
| while cache_entry is not None: |
| cache_entry.guard_manager.diff_guard_sources = acc_diff_guard_sources |
| cache_entry.guard_manager.populate_diff_guard_manager() |
| cache_entry = cache_entry.next |
|
|
| |
| return acc_diff_guard_sources |
|
|
|
|
| def guard_error_hook( |
| guard_manager: GuardFn, |
| code: types.CodeType, |
| f_locals: dict[str, object], |
| index: int, |
| last: bool, |
| ) -> None: |
| print( |
| f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}" |
| ) |
| print("lambda " + ", ".join(guard_manager.args) + ":") |
| print(" ", " and\n ".join(guard_manager.code_parts)) |
|
|
| print(guard_manager) |
|
|
| local_scope = {"L": f_locals, **guard_manager.closure_vars} |
| for guard in guard_manager.code_parts: |
| try: |
| eval(guard, guard_manager.global_scope, local_scope) |
| except: |
| print(f"Malformed guard:\n{guard}") |
|
|
|
|
# Register guard_error_hook at import time.
# NOTE(review): presumably the hook is invoked when evaluating a guard raises;
# confirm against set_guard_error_hook.
set_guard_error_hook(guard_error_hook)
|
|
|
|
def unique(seq: Sequence[T]) -> Generator[T, None, None]:
    """Yield the items of `seq` in their original order, skipping repeats."""
    emitted = set()
    for item in seq:
        if item in emitted:
            continue
        yield item
        emitted.add(item)
|
|
|
|
def make_dupe_guard(
    obj_source: Source, dupe_source: Source
) -> Optional[functools.partial[Any]]:
    """
    Build a DUPLICATE_INPUT guard factory for two sources naming the same object.

    Returns None when `dupe_source` is falsy or equal to `obj_source`, and
    also when exactly one of the two sources is a local source. Otherwise
    returns `GuardBuilder.DUPLICATE_INPUT` partially applied with
    `source_b=dupe_source`.

    Raises:
        exc.UnsafeScriptObjectError: if either source comes from a
            flattened script object.
    """
    if dupe_source and dupe_source != obj_source:
        dupe_is_local = is_from_local_source(dupe_source)
        obj_is_local = is_from_local_source(obj_source)

        aliases_script_object = is_from_flatten_script_object_source(
            dupe_source
        ) or is_from_flatten_script_object_source(obj_source)
        if aliases_script_object:
            raise exc.UnsafeScriptObjectError(
                f"{obj_source.name()} is aliasing {dupe_source.name()}. This is not supported."
                f" Please do a clone for corresponding input."
            )

        # The guard is only produced when both sources are local or both
        # are non-local.
        if dupe_is_local == obj_is_local:
            return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source)

    return None
|
|
|
|
def install_guard(*guards: Guard, skip: int = 0) -> None:
    """
    Register dynamo guards on the current tracing context.

    Guards whose originating source is a skip-guard source are ignored.

    Args:
        guards: guard(s) to add
        skip: number of stack frames to ignore for debug stack trace
    """
    from torch._guards import TracingContext

    # Debug stacks are only collected when one of the guard logs is at DEBUG.
    collect_debug_stack = any(
        logger.isEnabledFor(logging.DEBUG)
        for logger in (guards_log, verbose_guards_log)
    )
    add = TracingContext.get().guards_context.dynamo_guards.add
    for guard in guards:
        assert isinstance(guard, Guard)

        if not is_from_skip_guard_source(guard.originating_source):
            # +1 so that this function's own frame is skipped as well.
            add(guard, collect_debug_stack=collect_debug_stack, skip=skip + 1)
|
|