"""
Core graph building functionality for PyTorch's Dynamo system. This module contains
the essential components for constructing and managing FX graphs during compilation:

- OutputGraph: Manages the overall graph construction and compilation process. It owns
  a SubgraphTracer and handles graph compilation, execution, and state management.
  OutputGraph also manages features like graph deduplication, symbolic shape handling,
  and tracking of side effects.

- SubgraphTracer: Handles the actual FX graph construction by tracing Python code.
  It supports advanced features like higher-order operators through nested tracers,
  lifting of free variables, and handling of symbolic shapes.

The module supports key Dynamo features including:
- Higher-order operators through nested SubgraphTracers
- Graph deduplication for optimization
- Symbolic shape handling and propagation
- Side effect tracking and management
- Guard insertion and management
"""
import collections
import contextlib
import copy
import functools
import inspect
import itertools
import logging
import operator
import re
import sys
import traceback
import weakref
from dataclasses import dataclass, field as dc_field
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union

import sympy

import torch._guards
import torch._logging
import torch.distributed as dist
import torch.nn
import torch.utils._pytree as pytree
from torch import fx, Tensor
from torch._C._dynamo import guards
from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis
from torch._guards import (
    CompileContext,
    CompileId,
    GlobalContextCheckpointState,
    Source,
    tracing,
    TracingContext,
)
from torch._subclasses.fake_tensor import FakeTensor
from torch._utils_internal import signpost_event
from torch.fx._lazy_graph_module import _make_graph_module
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.symbolic_shapes import (
    free_symbols,
    guard_scalar,
    is_symbolic,
    ShapeEnv,
    Specialization,
)
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._ordered_set import OrderedSet
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from . import config, exc, logging as torchdynamo_logging, variables
from .backends.registry import CompiledFn, CompilerFn
from .bytecode_transformation import (
    create_call_function,
    create_instruction,
    create_load_const,
    Instruction,
    unique_id,
)
from .code_context import code_context
from .codegen import PyCodegen
from .current_scope_id import enter_new_scope
from .device_interface import get_interface_for_device
from .exc import (
    BackendCompilerFailed,
    exceptions_allowed_to_be_fallback,
    SkipFrame,
    unimplemented_v2,
    unimplemented_v2_with_warning,
)
from .graph_deduplication import apply_graph_deduplication
from .graph_region_tracker import GraphRegionTracker
from .guards import GuardBuilder, install_guard
from .mutation_guard import is_dynamic_nn_module
from .side_effects import AttributeMutationExisting, SideEffects
from .source import (
    AttrSource,
    BackwardStateSource,
    ConstantSource,
    GetItemSource,
    GlobalStateSource,
    is_constant_source,
    is_from_local_source,
    LocalSource,
    NumpyTensorSource,
    ParamBufferSource,
    ShapeEnvSource,
    SyntheticLocalSource,
    TensorProperty,
    TensorPropertySource,
)
from .utils import (
    _extract_tensor_dict,
    checkpoint_params,
    CleanupHook,
    clone_inputs,
    count_calls,
    counters,
    dynamo_timed,
    get_instruction_source_311,
    get_locals_to_steal,
    get_static_address_type,
    get_unique_name_wrt,
    graph_break_reasons,
    increment_op_count,
    istype,
    lazy_format_graph_code,
    LazyString,
    nn_module_proxy,
    same,
    set_example_value,
)
from .variables.base import VariableTracker
from .variables.builder import (
    BackwardStateGraphArg,
    GraphArg,
    TrackedFake,
    wrap_fx_proxy,
)
from .variables.ctx_manager import ContextWrappingVariable
from .variables.lists import BaseListVariable
from .variables.misc import CellVariable, NullVariable
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)
from .variables.torch_function import TensorWithTFOverrideVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslatorBase


log = logging.getLogger(__name__)
graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")

RootGuardManager = guards.RootGuardManager
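

# A VariableTracker is cached per (id(value), source) pair: the same Python
# object reached through two different sources must map to distinct trackers,
# since guards and codegen are keyed on the source.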
@dataclass(frozen=True)
class VariableTrackerCacheKey:
    vt_id: int
    source: Source


@dataclass(frozen=True)
class AliasingInfo:
    has_aliasing: bool
    msg: str


@dataclass(frozen=True)
class MutationInfo:
    has_mutation: bool
    msg: str


class VariableTrackerCache:
    def __init__(self):
        self.cache = {}

    def lookup(self, value, source):
        key = VariableTrackerCacheKey(id(value), source)
        if key not in self.cache:
            return None
        return self.cache[key]

    def add(self, value, source, vt):
        key = VariableTrackerCacheKey(id(value), source)
        self.cache[key] = vt

    def clone(self):
        new_cache = VariableTrackerCache()
        new_cache.cache.update(self.cache)
        return new_cache

    def clear(self):
        self.cache.clear()


@functools.cache
def _step_logger():
    return torchdynamo_logging.get_step_logger(log)


@dataclass
class GraphCompileReason:
    """Stores why a given output graph was compiled; i.e. what caused the graph break."""

    reason: str
    user_stack: list[traceback.FrameSummary]

    graph_break: bool = True

    def __post_init__(self):
        if self.graph_break:
            graph_break_reasons.append(self)
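

# Builds a closure that replays a recorded list of (fn, args, kwargs) RNG
# calls; compile_subgraph installs it (see random_calls / random_values_var
# below) so that traced random calls reproduce the same values at runtime.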
def _get_gen_rand_values_fn(random_calls):
    def _gen_rand_values():
        return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]

    return _gen_rand_values


class FakeRootModule(torch.nn.Module):
    """Trick the constructor of fx.GraphModule"""

    def __init__(self, nn_modules: dict[str, torch.nn.Module]):
        super().__init__()
        for k, v in nn_modules.items():
            setattr(self, k, v)

    def __repr__(self) -> str:
        return "FakeRootModule(...)"

    def add_nn_modules(self, nn_modules: dict[str, torch.nn.Module]):
        for k, v in nn_modules.items():
            setattr(self, k, v)
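

# Wraps a user compiler_fn so that, under config.verify_correctness, the
# candidate backend's output is checked against eager execution of the graph
# module on cloned inputs before the candidate is accepted.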
class WrapperBackend:
    def __init__(self, backend: CompilerFn):
        self.backend: CompilerFn = backend

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]):
        self.restore = checkpoint_params(gm)
        self.gm = gm
        copy_gm = copy.deepcopy(self.gm)
        self.candidate = self.backend(copy_gm, example_inputs)

        if self.candidate is None or self.candidate is self.gm.forward:
            return self.gm.forward

        if not config.verify_correctness:
            return self.candidate

        try:
            correct = self.gm.forward(*clone_inputs(example_inputs))
            result = self.candidate(*clone_inputs(example_inputs))

            if same(correct, result):
                return self.candidate

            raise RuntimeError(f"incorrect results of backend {self}")

        except Exception:
            log.exception("error in verify_correctness")
            raise
        finally:
            self.restore()


Scope = dict[str, object]


@dataclass
class OutputGraphGuardsState:
    """
    A base class containing fields that are considered "persistent" when we
    want to save all the important state for reconstructing guards in a different
    process. Normally we don't need to add states here, but we may have to when
    the information is needed to serialize the guards, so the fields here are
    supposed to be serializable as a requirement.
    """

    local_scope: Scope
    global_scope: Scope

    torch_function_mode_stack: list[torch.overrides.TorchFunctionMode]
    guard_on_key_order: set[Source]

    input_source_to_sizes_strides: dict[Source, dict[str, Any]]
    dual_level: int
    functorch_layers: list[torch._functorch.pyfunctorch.FuncTorchInterpreter]
    current_device: Optional[torch.device]

    export: bool = False
    export_constraints: bool = False

    _guards: Optional[torch._guards.GuardsSet] = None
    _aotautograd_guards: Optional[list[torch._guards.GuardEnvExpr]] = None

    @property
    def shape_env(self):
        raise AssertionError(f"shape_env shouldn't be accessed from {type(self)}")

    @property
    def guards(self):
        return self._guards

    @property
    def aotautograd_guards(self):
        return self._aotautograd_guards


@dataclass
class StackLocalsMetadata:
    """
    Stores metadata for a frame's stack and locals for the purposes of building resume functions
    """

    stack_null_idxes: list[int] = dc_field(default_factory=list)
    locals_null_keys: list[str] = dc_field(default_factory=list)
    stack_ctx_args: list[tuple[int, tuple[Any, ...]]] = dc_field(default_factory=list)
    stack_ctx_idxes_orig: list[int] = dc_field(default_factory=list)
    locals_ctx_args: list[tuple[str, tuple[Any, ...]]] = dc_field(default_factory=list)


class OutputGraph(OutputGraphGuardsState):
    """
    Wrapper class to hold outputs of InstructionTranslator. Mainly the
    generated fx.Graph.

    OutputGraph is 1:1 with a frame being processed. Each frame is associated
    with some root InstructionTranslator. When user code calls a function,
    we construct an InliningInstructionTranslator that continues to write into
    the root InstructionTranslator's OutputGraph.
    """

    side_effects: SideEffects

    def __init__(
        self,
        code_options: dict[str, Any],
        compiler_fn: Optional[CompilerFn],
        root_tx,
        export: bool,
        export_constraints,
        frame_state,
        local_scope: Scope,
        global_scope: Scope,
        f_code,
        torch_function_mode_stack,
        package,
    ):
        super().__init__(
            local_scope,
            global_scope,
            torch_function_mode_stack,
            guard_on_key_order=set(),
            input_source_to_sizes_strides={},
            dual_level=torch.autograd.forward_ad._current_level,
            functorch_layers=torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters(),
            current_device=torch.utils._device.CURRENT_DEVICE,
        )
        self.tracers = [SubgraphTracer(self, is_export=export)]

        self.input_source_to_var: dict[Source, VariableTracker] = {}
        self.export = export
        self.export_constraints = export_constraints
        self.frame_state = frame_state
        self.cleanup_hooks: list[Callable[[], Any]] = []
        self.compile_id: int = next(_compile_id_counter)
        self.installed_globals: set[str] = set()

        self.co_fields = {
            "co_name": f_code.co_name,
            "co_filename": f_code.co_filename,
            "co_firstlineno": f_code.co_firstlineno,
        }

        self.region_tracker = GraphRegionTracker()

        self.tracked_fakes: list[TrackedFake] = []

        shape_env = ShapeEnv(
            tracked_fakes=self.tracked_fakes,
            allow_scalar_outputs=config.capture_scalar_outputs,
            allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
            prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
            allow_complex_guards_as_runtime_asserts=config.allow_complex_guards_as_runtime_asserts,
            co_fields=self.co_fields,
        )

        import torch._functorch.config as _config

        with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
            fake_mode = torch._subclasses.FakeTensorMode(
                shape_env=shape_env,
                allow_non_fake_inputs=True if self.export else False,
                export=self.export,
            )
        self.tracing_context: TracingContext = TracingContext(fake_mode)
        self.tracing_context.traced_code.append(f_code)
        self.dynamo_compile_id: Optional[CompileId] = (
            CompileContext.current_compile_id()
        )
        self.init_ambient_guards()

        self.tracked_fakes_id_to_source: dict[int, list[Source]] = (
            collections.defaultdict(list)
        )
        self.param_name_to_source: Optional[dict[str, Source]] = {}
        self.side_effects = SideEffects(self)

        self.variable_tracker_cache = VariableTrackerCache()
        self.unique_var_id = itertools.count()
        self.code_options: dict[str, Any] = dict(code_options)
        self.output_instructions: list[Instruction] = []

        self.timestamp = 0

        self.register_finalizer_fns: list[Callable[[fx.GraphModule], None]] = []

        self.compiler_fn: Optional[CompilerFn] = compiler_fn
        self.root_tx = root_tx
        self.package = package

        self.source_to_user_stacks: dict[Source, list[traceback.StackSummary]] = {}

        self._current_tx: list[InstructionTranslatorBase] = []
        self.cleanups: list[CleanupHook] = []
        self.should_exit = False
        self.unspec_variable_map: dict[str, UnspecializedPythonVariable] = {}

        self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()

        self.has_user_defined_allowed_in_graph = False

        self.non_compliant_ops: set[torch._ops.OpOverload] = set({})

        self.compliant_custom_ops: set[torch._ops.OpOverload] = set({})

        self.save_global_state()

        self.dynamo_flat_name_to_original_fqn: dict[str, str] = {}

        self.random_calls: list[
            tuple[Callable[..., object], tuple[object, ...], dict[str, object]]
        ] = []
        self.random_values_var: Any = None

        self.pregraph_bytecode: list[Instruction] = []

        self.backward_state: dict[str, VariableTracker] = {}
        self.backward_state_proxy: Optional[torch.fx.Proxy] = None
        self.backward_state_var: Optional[str] = None

        self.name_of_builtins_dict_key_in_fglobals: str = (
            self.install_builtins_dict_in_fglobals()
        )

        self.compiler_trace_stack = contextlib.ExitStack()

        self.saved_tensors_hooks_subgraph_names: Optional[list[str]] = (
            self.maybe_install_saved_tensors_hooks_subgraphs()
        )

    def mark_bytecode_tracing_start(self):
        self.compiler_trace_stack.enter_context(
            dynamo_timed(
                "bytecode_tracing",
                log_pt2_compile_event=True,
            )
        )

    def mark_bytecode_tracing_stop(self):
        self.compiler_trace_stack.close()

    def install_builtins_dict_in_fglobals(self):
        f_builtins = self.global_scope["__builtins__"]
        if not isinstance(f_builtins, dict):
            f_builtins = f_builtins.__dict__
        return self.install_global("__builtins_dict__", f_builtins)

    def add_backward_state_hook(self, hook: VariableTracker, prefix="hook"):
        name = f"{prefix}{len(self.backward_state)}"
        assert name not in self.backward_state
        self.backward_state[name] = hook
        return name, self.get_backward_state_proxy()

    def get_backward_state_proxy(self):
        if self.backward_state_proxy is None:
            if self.export:
                unimplemented_v2(
                    gb_type="backward_state does not support export",
                    context="",
                    explanation="Compiled autograd doesn't work with `torch.export`.",
                    hints=[],
                )
            example_value = BackwardState()
            self.backward_state_proxy = self.root_tracer.create_graph_input(
                "dynamo_backward_state",
                type(example_value),
                example_value,
                source=BackwardStateSource(),
            )
            self.backward_state_proxy.node.meta["grapharg"] = BackwardStateGraphArg()
            self.backward_state_var = self.new_var()
        return self.backward_state_proxy
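
    # "Ambient" guards pin down global interpreter state that the traced graph
    # implicitly depends on (shape env, grad mode, determinism flags, default
    # device, torch_function state, and, when present, the functorch stack),
    # independent of any particular input source.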
    def init_ambient_guards(self):
        self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))

        self.guards.add(
            GlobalStateSource().make_guard(GuardBuilder.DETERMINISTIC_ALGORITHMS)
        )

        self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GRAD_MODE))

        self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE))

        self.guards.add(
            GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
        )

        ci = torch._C._functorch.peek_interpreter_stack()
        if ci is not None:
            self.guards.add(
                GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH)
            )
        if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:
            self.guards.add(
                GlobalStateSource().make_guard(
                    GuardBuilder.AUTOGRAD_SAVED_TENSORS_HOOKS
                )
            )

    def maybe_install_saved_tensors_hooks_subgraphs(self) -> Optional[list[str]]:
        if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
            return None

        get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks
        are_inline_hooks = (
            torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable
        )
        hooks = get_hooks()
        if not are_inline_hooks(hooks):
            return None

        pack_gm, unpack_gm = hooks
        pack_subgraph_name = self.install_subgraph(
            "saved_tensors_hooks_pack",
            torch.fx.GraphModule(self.nn_modules, pack_gm.graph),
        )
        unpack_subgraph_name = self.install_subgraph(
            "saved_tensors_hooks_unpack",
            torch.fx.GraphModule(self.nn_modules, unpack_gm.graph),
        )
        assert pack_subgraph_name == "saved_tensors_hooks_pack_0"
        assert unpack_subgraph_name == "saved_tensors_hooks_unpack_0"
        return [pack_subgraph_name, unpack_subgraph_name]

    def dump_guards_state(self):
        return OutputGraphGuardsState(
            local_scope=self.local_scope,
            global_scope=self.global_scope,
            torch_function_mode_stack=self.torch_function_mode_stack,
            guard_on_key_order=self.guard_on_key_order,
            input_source_to_sizes_strides=self.input_source_to_sizes_strides,
            dual_level=self.dual_level,
            functorch_layers=self.functorch_layers,
            current_device=self.current_device,
            export=self.export,
            export_constraints=self.export_constraints,
            _guards=self.guards,
            _aotautograd_guards=self.aotautograd_guards,
        )

    def synthetic_graph_input(self, fn, args):
        """
        call fn(*args) before the graph runs and turn the result into a fake input.
        """
        example_value = fn(*args)
        varname = self.new_var()
        cg = PyCodegen(self.root_tx)
        cg.add_push_null(
            lambda: cg.load_import_from(
                fn.__module__,
                fn.__name__,
            )
        )
        cg.foreach(map(variables.ConstantVariable.create, args))
        cg.call_function(len(args), False)
        cg.store(varname)
        self.pregraph_bytecode.extend(cg.get_instructions())
        source = SyntheticLocalSource(varname)
        result = VariableTracker.build(self.root_tx, example_value, source)
        result = result.realize()
        TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
            source
        )
        return result

    def add_cleanup_hook(self, fn: Callable[[], Any]):
        self.cleanup_hooks.append(fn)

    def call_cleanup_hooks(self):
        for hook in reversed(self.cleanup_hooks):
            hook()
        self.cleanup_hooks.clear()

    @property
    def root_tracer(self):
        return self.tracers[0]

    @property
    def current_tracer(self):
        return self.tracers[-1]

    def is_root_tracer(self):
        return len(self.tracers) == 1

    @property
    def graph(self):
        return self.current_tracer.graph

    @graph.setter
    def graph(self, value):
        self.current_tracer.graph = value

    @property
    def input_name_to_proxy(self):
        return self.current_tracer.input_name_to_proxy

    @property
    def real_value_cache(self):
        return self.current_tracer.real_value_cache

    @property
    def bound_symbols(self):
        return self.current_tracer.bound_symbols

    def create_proxy(self, *args, **kwargs):
        return self.current_tracer.create_proxy(*args, **kwargs)

    def create_node(self, *args, **kwargs):
        return self.current_tracer.create_node(*args, **kwargs)

    def remove_node(self, *args, **kwargs):
        return self.current_tracer.remove_node(*args, **kwargs)

    @contextlib.contextmanager
    def subtracer(self, source_target, prior_tracer):
        new_scope_ctx = enter_new_scope()
        try:
            if prior_tracer:
                assert prior_tracer.parent is self.current_tracer
            new_scope_ctx.__enter__()
            tracer = (
                prior_tracer
                if prior_tracer
                else SubgraphTracer(
                    self,
                    parent=self.current_tracer,
                    source_target=source_target,
                    is_export=self.current_tracer.is_export,
                )
            )
            self.tracers.append(tracer)
            yield tracer
        finally:
            new_scope_ctx.__exit__(None, None, None)
            self.tracers.pop()

    @property
    def output(self):
        return self

    @property
    def fake_mode(self):
        return self.tracing_context.fake_mode

    @property
    def shape_env(self):
        return self.tracing_context.fake_mode.shape_env

    @property
    def guards(self) -> torch._guards.GuardsSet:
        return self.tracing_context.guards_context.dynamo_guards

    @property
    def nn_modules(self) -> dict[str, Any]:
        return self.tracing_context.module_context.nn_modules

    @property
    def aotautograd_guards(self):
        return self.tracing_context.guards_context.aotautograd_guards

    def save_global_state(self, out=None):
        """
        Saves to out if it is provided. Else saves to the tracing context's global_state.
        """
        global_state = cast(
            dict[str, tuple[Callable[..., Any], bool]],
            (
                out
                if out is not None
                else self.tracing_context.global_context.global_state
            ),
        )

        global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())

        global_state["autocast_enabled"] = (
            functools.partial(torch.set_autocast_enabled, "cuda"),
            torch.is_autocast_enabled("cuda"),
        )
        global_state["autocast_cpu_enabled"] = (
            functools.partial(torch.set_autocast_enabled, "cpu"),
            torch.is_autocast_enabled("cpu"),
        )
        global_state["autocast_gpu_dtype"] = (
            functools.partial(torch.set_autocast_dtype, "cuda"),
            torch.get_autocast_dtype("cuda"),
        )
        global_state["autocast_cpu_dtype"] = (
            functools.partial(torch.set_autocast_dtype, "cpu"),
            torch.get_autocast_dtype("cpu"),
        )
        global_state["autocast_cache_enabled"] = (
            torch.set_autocast_cache_enabled,
            torch.is_autocast_cache_enabled(),
        )

    def push_tx(self, tx):
        self._current_tx.append(tx)

    def pop_tx(self):
        return self._current_tx.pop()

    @property
    def current_tx(self):
        return self.root_tx if not self._current_tx else self._current_tx[-1]

    def count_calls(self):
        return count_calls(self.graph)

    def is_empty_graph(self):
        return len(list(self.graph.nodes)) == 0

    def get_submodule(self, keys):
        assert keys
        obj: Union[torch.nn.Module, dict[str, torch.nn.Module]] = self.nn_modules
        for k in keys.split("."):
            if isinstance(obj, dict):
                obj = obj[k]
            else:
                obj = getattr(obj, k)
        return obj

    def new_var(self, name="tmp"):
        existing = set(self.code_options["co_varnames"])
        while True:
            var = f"{name}_{next(self.unique_var_id)}"
            if var not in existing:
                self.code_options["co_varnames"] += (var,)
                return var

    def update_co_names(self, name):
        """Ensure self.code_options.co_names contains name"""
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)
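
    # module_key_name normalizes source-derived names into valid attribute
    # names. Worked examples, checkable against the regexes below:
    #   module_key_name("G['relu']")  -> "relu"
    #   module_key_name("layers[0]")  -> "layers_0"
    #   module_key_name("0_weight")   -> "sub0_weight"  (must start alphabetic)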
    @staticmethod
    def module_key_name(*names):
        name = "_".join(map(str, names))
        name = re.sub(r"^[GL]\['?(.*?)'?\]$", r"\1", name)
        name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)
        name = re.sub(r"[^a-zA-Z0-9]", "_", name)
        if not name or not name[0].isalpha():
            name = "sub" + name
        return name

    def register_static_attr_and_return_proxy(
        self, attr_prefix: str, attr_value: Any
    ) -> fx.Proxy:
        attr_name = get_unique_name_wrt(attr_prefix, self.nn_modules)
        self.nn_modules[attr_name] = attr_value
        proxy = self.create_proxy("get_attr", attr_name, (), {})
        set_example_value(proxy.node, attr_value)
        return proxy

    def register_attr_or_module(
        self,
        target: Union[torch.nn.Module, torch.Tensor, Any],
        *names,
        **options,
    ):
        if is_dynamic_nn_module(target, self.export):
            return VariableTracker.build(self.current_tx, target, **options)

        options = dict(options)
        assert "source" in options
        source = options["source"]
        assert not isinstance(source, ParamBufferSource)

        if isinstance(target, torch.Tensor):
            tracer = self.current_tracer
            if not self.is_root_tracer():
                tracer = self.root_tracer

            def wrap_name(module_key):
                assert self.param_name_to_source is not None
                self.param_name_to_source[module_key] = source

                if target in self.root_tx.output.side_effects:
                    return self.root_tx.output.side_effects[target]

                if get_static_address_type(target) == "guarded" and not isinstance(
                    source, NumpyTensorSource
                ):
                    install_guard(source.make_guard(GuardBuilder.ID_MATCH))
                elif not is_constant_source(source):
                    install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))

                vt = wrap_fx_proxy(
                    self.root_tx,
                    tracer.create_proxy("get_attr", module_key, (), {}),
                    example_value=target,
                    **options,
                )

                vt = self.root_tx.output.side_effects.track_object_existing(target, vt)

                assert "tensor_dict" not in vt.proxy.node.meta
                vt.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(target)

                return vt

        elif isinstance(target, torch.nn.Module):
            assert isinstance(target, torch.nn.Module)

            if source:
                install_guard(source.make_guard(GuardBuilder.NN_MODULE))

                def wrap_name(module_key):
                    return NNModuleVariable(type(target), module_key, target, **options)

            else:

                def wrap_name(module_key):
                    return variables.UnspecializedNNModuleVariable(target, **options)

        elif isinstance(target, (torch.SymInt, torch.SymFloat)):

            def wrap_name(module_key):
                return SymNodeVariable.create(
                    self,
                    self.create_proxy("get_attr", module_key, (), {}),
                    sym_num=target,
                    **options,
                )

        else:

            def wrap_name(module_key):
                self.output.update_co_names(module_key)
                self.global_scope[module_key] = target
                return VariableTracker.build(
                    self,
                    target,
                    ConstantSource(source_name=module_key),
                )

        for k, v in self.nn_modules.items():
            if v is target:
                return wrap_name(k)

        name = OutputGraph.module_key_name(*names)
        name = get_unique_name_wrt(name, self.nn_modules, self.global_scope)
        self.nn_modules[name] = target
        if isinstance(target, torch.nn.Module):

            def register_leaf_name(leaf_name):
                assert self.param_name_to_source is not None
                new_source = ParamBufferSource(source, leaf_name)
                new_name = f"{name}.{leaf_name}"
                self.param_name_to_source[new_name] = new_source
                if isinstance(source, LocalSource):
                    self.dynamo_flat_name_to_original_fqn[
                        OutputGraph.module_key_name(new_source.name())
                    ] = leaf_name

            if hasattr(target, "_parameters"):
                for leaf_name, _ in target.named_parameters():
                    register_leaf_name(leaf_name)
            if hasattr(target, "_buffers"):
                for leaf_name, _ in target.named_buffers():
                    register_leaf_name(leaf_name)

        return wrap_name(name)
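
    # When a list local is "stolen" (cleared in place by the compiled function;
    # see get_locals_to_steal), any of its elements still live after the call
    # must be kept reachable. The method below emits bytecode equivalent to
    # `alias = stolen_list[idx]` ahead of the graph call and rewrites affected
    # sources to point at the alias instead.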
    def handle_aliases_for_stolen_lists(self, tx):
        maybe_gm = self.local_scope.get("self")
        stolen_list_names = get_locals_to_steal(maybe_gm)
        if not stolen_list_names:
            return [], {}

        alias_insts = []
        needs_alias: dict[str, list[VariableTracker]] = {}

        queue = [
            *tx.stack,
            *tx.symbolic_locals.values(),
            *self.side_effects.store_attr_mutations.keys(),
        ]

        while queue:
            x = queue.pop()
            if isinstance(x, BaseListVariable):
                assert isinstance(x.items, list)
                queue += x.items
                continue

            if not (
                (
                    x not in self.side_effects.store_attr_mutations
                    or isinstance(x.mutation_type, AttributeMutationExisting)
                )
                and isinstance(x.source, GetItemSource)
                and isinstance(x.source.base, LocalSource)
                and x.source.base.local_name in stolen_list_names
            ):
                continue

            stolen_name = x.source.base.local_name
            if stolen_name not in needs_alias:
                needs_alias[stolen_name] = []
            needs_alias[stolen_name].append(x)

        visited = {}
        overridden_sources: dict[Source, Source] = {}
        for arg in self.graphargs:
            if not (
                isinstance(arg._example, list)
                and isinstance(arg.source, LocalSource)
                and arg.source.local_name in needs_alias
            ):
                continue

            list_name = arg.source.local_name
            assert list_name in self.code_options["co_varnames"]
            for x in needs_alias[list_name]:
                if x.source in overridden_sources:
                    continue

                list_idx = x.source.index
                if list_idx not in visited:
                    alias_name = self.new_var(
                        f"{list_name}_ref"
                    )

                    visited[list_idx] = alias_name

                    alias_insts.extend(
                        [
                            create_instruction("LOAD_FAST", argval=list_name),
                            create_load_const(list_idx),
                            create_instruction("BINARY_SUBSCR"),
                            create_instruction("STORE_FAST", argval=alias_name),
                        ]
                    )

                old_source = x.source
                overridden_sources[old_source] = LocalSource(visited[list_idx])

        return alias_insts, overridden_sources

    def _get_stack_values_to_restore(self, tx, stack_pops):
        """
        Gets the stack + locals values belonging to tx that need to be restored.

        Also prunes dead tx locals and realizes all VTs in the tx's stack.

        NullVariables in stack/locals will NOT be restored, unless they are the top `stack_pops`
        elements of the stack - it is expected that the next instruction to run will pop the top
        `stack_pops` elements of the stack, so we should codegen NULLs.

        Returns:
            - stack_values: stack and locals values that need to be restored
            - restore_vars: names of locals corresponding to the locals part of `stack_values`
            - meta: locations of NULLs and ContextWrappingVariables in the stack/locals
              (ignores the top `stack_pops` values on the stack)
        """
        tx.prune_dead_locals()

        stack_values = []
        meta = StackLocalsMetadata()

        for i, value in enumerate(tx.stack):
            variables.LazyVariableTracker.realize_all(value)
            if len(tx.stack) - i <= stack_pops:
                stack_values.append(value)
                continue
            if isinstance(value, NullVariable):
                meta.stack_null_idxes.append(i)
            else:
                stack_values.append(value)
                if isinstance(value, ContextWrappingVariable):
                    target_values = (
                        () if value.target_values is None else tuple(value.target_values)
                    )
                    meta.stack_ctx_args.append((len(stack_values) - 1, target_values))
                    meta.stack_ctx_idxes_orig.append(i)

        restore_vars: list[str] = []
        val_to_names: dict[VariableTracker, list[str]] = {}

        for k, v in tx.symbolic_locals.items():
            if isinstance(v.source, LocalSource) and v.source.local_name == k:
                continue
            if isinstance(v, CellVariable) and v.local_name == k:
                continue

            if sys.version_info >= (3, 12):
                if type.__instancecheck__(NullVariable, v):
                    meta.locals_null_keys.append(k)
                    continue
            else:
                assert not type.__instancecheck__(NullVariable, v)
            if isinstance(v, ContextWrappingVariable):
                target_values = (
                    () if v.target_values is None else tuple(v.target_values)
                )
                meta.locals_ctx_args.append((k, target_values))
            if v not in val_to_names:
                val_to_names[v] = []
            val_to_names[v].append(k)
        for v in val_to_names.keys():
            restore_vars.extend(val_to_names[v])
            stack_values.extend([v] * len(val_to_names[v]))

        return stack_values, restore_vars, meta
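
    # Note on compile_subgraph below: when every root stack value is a distinct
    # plain TensorVariable and there are no side effects, debug locals, NULLs,
    # or backward state, the generated bytecode just calls the compiled graph
    # and UNPACK_SEQUENCEs the result; otherwise a two-pass PyCodegen run first
    # counts uses, then stores multiply-used values in temporaries.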
    def compile_subgraph(
        self,
        tx: "InstructionTranslatorBase",
        reason: GraphCompileReason,
        partial_convert=False,
        stack_pops=0,
    ):
        """
        Compiles the current subgraph, with inputs w.r.t. self.root_tx, and codegens:
        - Call the compiled subgraph
        - Apply side effects
        - Codegen stack and locals
        - Store the locals

        Python does not allow NULL to be an arg to a function, so we do not codegen NULLs on the stack,
        unless the value is one of the top `stack_pops` values on the stack (these values are expected to be
        popped immediately after this generated code). The prologue of the resume function is expected to restore
        any dropped NULLs.

        Returns stack indices and locals keys where we dropped NULLs, and where we found inactive context manager objects.
        """
        assert self.root_tx is not None

        assert self.root_tx is tx

        self.mark_bytecode_tracing_stop()

        self.partial_convert = partial_convert
        self.compile_subgraph_reason = reason
        self.should_exit = True

        log.debug("COMPILING GRAPH due to %s", reason)

        prefix_insts: list[Instruction] = []
        if sys.version_info >= (3, 11):
            for inst in tx.prefix_insts:
                if inst.opname == "MAKE_CELL":
                    prefix_insts.append(
                        create_instruction("MAKE_CELL", argval=inst.argval)
                    )
                elif inst.opname == "COPY_FREE_VARS":
                    prefix_insts.append(
                        create_instruction(
                            "COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"])
                        )
                    )
                else:
                    prefix_insts.append(copy.copy(inst))
        self.add_output_instructions(prefix_insts)

        assert not (self.pregraph_bytecode and self.export), (
            "export does not support pregraph_bytecode"
        )
        self.add_output_instructions(self.pregraph_bytecode)

        alias_insts, overridden_sources = self.handle_aliases_for_stolen_lists(
            self.root_tx
        )
        self.add_output_instructions(alias_insts)

        for block in reversed(self.root_tx.block_stack):
            block.exit(self.root_tx, is_graph_break=reason.graph_break)

        self.cleanup_graph()

        all_stack_values = []
        all_restore_vars = []
        all_stack_locals_metas = []
        cur_tx: Optional[InstructionTranslatorBase] = tx
        while True:
            assert cur_tx is not None
            assert all(block.can_restore() for block in cur_tx.block_stack)
            stack_values, restore_vars, meta = self._get_stack_values_to_restore(
                cur_tx, stack_pops
            )
            all_stack_values.append(stack_values)
            all_restore_vars.append(restore_vars)
            all_stack_locals_metas.append(meta)
            if cur_tx is self.root_tx:
                break
            cur_tx = cur_tx.parent

        nn_modules_proxies = {
            name: nn_module_proxy(mod) for name, mod in self.nn_modules.items()
        }
        root = FakeRootModule(nn_modules_proxies)

        from .decorators import disable

        if len(self.random_calls) > 0:
            random_calls_instructions = []
            self.random_values_var = self.new_var("random_values")
            rand_fn = disable(
                _get_gen_rand_values_fn(self.random_calls),
                reason="do not trace into Dynamo rng recovery function",
            )
            rand_fn_name = self.install_global("__gen_rand_values", rand_fn)
            codegen = PyCodegen(
                self.root_tx, root, overridden_sources=overridden_sources
            )
            random_calls_instructions.extend(
                codegen.load_function_name(rand_fn_name, True)
            )
            random_calls_instructions.extend(create_call_function(0, False))
            random_calls_instructions.append(
                codegen.create_store(self.random_values_var),
            )
            self.add_output_instructions(random_calls_instructions)

        graph_output_var = None
        stored_graph_output_var = False
        root_stack_values = all_stack_values[-1]
        if (
            self.root_tx is tx
            and root_stack_values
            and all(
                not isinstance(
                    v,
                    (
                        UnspecializedPythonVariable,
                        NumpyNdarrayVariable,
                        TensorWithTFOverrideVariable,
                    ),
                )
                and not (isinstance(v, SymNodeVariable) and v.python_type() is float)
                for v in root_stack_values
            )
            and all(isinstance(x, TensorVariable) for x in root_stack_values)
            and len(set(root_stack_values)) == len(root_stack_values)
            and self.side_effects.is_empty()
            and not tx.debug_locals
            and not self.backward_state
            and not all_stack_locals_metas[-1].stack_null_idxes
            and not all_stack_locals_metas[-1].locals_null_keys
        ):
            self.add_output_instructions(
                self.compile_and_call_fx_graph(
                    tx, list(reversed(root_stack_values)), root
                )
                + [create_instruction("UNPACK_SEQUENCE", arg=len(root_stack_values))]
            )
        else:
            graph_output_var = self.new_var("graph_out")

            stack_values_flat = [
                val for vals in reversed(all_stack_values) for val in vals
            ]
            pass1 = PyCodegen(
                self.root_tx,
                root,
                graph_output_var,
                overridden_sources=overridden_sources,
            )
            self.codegen_suffix(tx, stack_values_flat, pass1)

            tempvars = {}
            for val, count in pass1.uses.items():
                if count > 1 and not istype(val, (SyntheticLocalSource, LocalSource)):
                    tempvars[val] = None
            pass2 = PyCodegen(
                self.root_tx,
                root,
                graph_output_var,
                tempvars=tempvars,
                overridden_sources=overridden_sources,
            )
            self.codegen_suffix(tx, stack_values_flat, pass2)

            output = []
            if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
                output.extend(
                    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
                )

                if len(pass2.graph_outputs) != 0:
                    output.append(pass2.create_store(graph_output_var))
                    stored_graph_output_var = True
                else:
                    output.append(create_instruction("POP_TOP"))
            else:
                self.run_compiler_collective()
            self.add_output_instructions(output + pass2.get_instructions())

        local_restore_cg = PyCodegen(
            self.root_tx, overridden_sources=overridden_sources
        )

        self.add_output_instructions(
            [
                local_restore_cg.create_store(var)
                for var in reversed(all_restore_vars[-1])
            ]
        )

        if graph_output_var and stored_graph_output_var:
            self.add_output_instructions(
                [local_restore_cg.create_delete(graph_output_var)]
            )

        return all_stack_locals_metas

    def codegen_suffix(self, tx, stack_values, cg):
        self.side_effects.codegen_save_tempvars(cg)
        if self.backward_state:
            assert not self.export
            for name, val in self.backward_state.items():
                cg(val)
                cg.append_output(cg.create_load(self.backward_state_var))
                cg.store_attr(name)
        self.side_effects.codegen_hooks(cg)

        for debug_var, args in tx.debug_locals:
            cg.add_push_null(lambda: cg(debug_var))
            for arg in args:
                cg(arg)
            cg.extend_output(create_call_function(len(args), False))
            cg.extend_output([create_instruction("POP_TOP")])

        cg.restore_stack(stack_values, value_from_source=not tx.export)
        self.side_effects.codegen_update_mutated(cg)
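
    # Example of the pattern cleanup_graph removes: with grad initially
    # enabled, an adjacent pair `_set_grad_enabled(False); _set_grad_enabled(True)`
    # is a no-op sandwich (e.g. left behind by an empty `torch.no_grad()`
    # region), so both nodes are erased.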
    def cleanup_graph(self):
        """
        Remove "creation_timestamp" from node meta

        Remove this pattern from the graph:
            torch._C._set_grad_enabled(False)
            torch._C._set_grad_enabled(True)
        """
        assert self.should_exit
        nodes = list(self.graph.nodes)
        for node in nodes:
            node.meta.pop("creation_timestamp", None)

        grad_enabled = torch.is_grad_enabled()
        for node1, node2 in zip(nodes, nodes[1:]):
            if (
                node1.target is torch._C._set_grad_enabled
                and tuple(node1.args) == (not grad_enabled,)
                and not node1._erased
            ):
                grad_enabled = node1.args[0]
                if (
                    node2.target is torch._C._set_grad_enabled
                    and tuple(node2.args) == (not grad_enabled,)
                    and not node2._erased
                ):
                    grad_enabled = node2.args[0]
                    self.graph.erase_node(node1)
                    self.graph.erase_node(node2)

    def get_graph_sizes_structured(self):
        ret = {}
        for node in self.graph.nodes:
            example_value = node.meta.get("example_value", None)
            if isinstance(example_value, torch._subclasses.FakeTensor):
                size = example_value.size()
                ret[node.name] = [s if isinstance(s, int) else repr(s) for s in size]
        return ret

    def get_graph_sizes(self, name: str):
        graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n"
        graph_sizes_str += f"===== {name} =====\n"
        for node in self.graph.nodes:
            example_value = node.meta.get("example_value", None)
            if isinstance(example_value, torch._subclasses.FakeTensor):
                size = example_value.size()
                graph_sizes_str += f"{node.name}: {tuple(size)}\n"
                concrete_size = []
                has_symint = False
                for sz in size:
                    if isinstance(sz, int):
                        concrete_size.append(sz)
                    elif isinstance(sz, torch.SymInt):
                        has_symint = True
                        concrete_size.append(sz.node.hint)
                    else:
                        break
                else:
                    if has_symint:
                        graph_sizes_str += (
                            f"{node.name} (concrete): {tuple(concrete_size)}\n"
                        )
        return graph_sizes_str

    @contextlib.contextmanager
    def restore_global_state(self):
        """
        Momentarily restores the global state to what it was prior to tracing the current output
        """
        prior_global_state = self.tracing_context.global_context.copy_graphstate()
        current_global_state: dict[str, tuple[Any, bool]] = {}
        self.save_global_state(out=current_global_state)
        try:
            self.tracing_context.global_context.restore_graphstate(prior_global_state)
            yield
        finally:
            self.tracing_context.global_context.restore_graphstate(
                GlobalContextCheckpointState(current_global_state)
            )
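
    # run_compiler_collective: when compiling with a distributed compile
    # process group, each rank all-gathers its local compile-relevant state and
    # then restarts analysis (CompileCollectiveRestartAnalysis) so every rank
    # retraces with the same shared view of that state.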
    def run_compiler_collective(self):
        tx = self.root_tx
        assert tx is not None
        if (ds := tx.distributed_state) is not None and ds.all_states is None:
            compile_pg = ds.compile_pg
            log.info("compiler_collective %s", ds.local_state)
            torch._logging.trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": "compiler_collective",
                    "encoding": "string",
                },
                payload_fn=lambda: ds.local_state.render(),
            )
            device_types = compile_pg._device_types
            assert len(device_types) == 1, (
                "Expect only one device type but got {}".format("+".join(device_types))
            )
            with (
                get_interface_for_device(device_types.pop()).device(
                    compile_pg.rank() % torch.accelerator.device_count()
                ),
                dynamo_timed("compiler_collective", log_pt2_compile_event=True),
            ):
                all_states = [None] * compile_pg.size()
                dist.all_gather_object(all_states, ds.local_state, group=compile_pg)
                ds.all_states = all_states

            tx.speculation_log.clear()
            raise exc.CompileCollectiveRestartAnalysis

    def compile_and_call_fx_graph(self, tx, rv, root):
        """
        Generate code from self.graph and return the Instruction()s to
        call that generated code.

        Code is generated w.r.t. self.root_tx.
        tx is only used for preserving GraphModule metadata
        """
        with torch._guards.TracingContext.clear_frame():
            from .decorators import disable

            assert self.should_exit

            self.run_compiler_collective()

            name = unique_id("__compiled_fn", with_uuid=True)

            assert isinstance(rv, list)
            assert isinstance(root, FakeRootModule)

            output_node = self.create_node(
                "output",
                "output",
                (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),),
                {},
            )
            sub_gms = self.dedup_pass()
            root.add_nn_modules(sub_gms)

            self.current_tracer._maybe_preserve_original_meta(tx, output_node)
            if not config.do_not_emit_runtime_asserts:
                self.remove_unused_get_attr_nodes()
                insert_deferred_runtime_asserts(
                    fx.GraphModule(root, self.graph),
                    self.shape_env,
                    name,
                    export=self.export,
                )

            self.remove_unused_graphargs()
            ncalls = count_calls(self.graph)
            counters["stats"]["calls_captured"] += ncalls

            self.remove_tensorify_specialized_graphargs()

            self.real_value_cache.clear()

            gm = _make_graph_module(root, self.graph)

            if self.saved_tensors_hooks_subgraph_names:
                for subgraph_name in self.saved_tensors_hooks_subgraph_names:
                    setattr(gm, subgraph_name, getattr(root, subgraph_name))

            for register_finalizer in self.register_finalizer_fns:
                register_finalizer(gm)

            gm._backend_id = name
            gm.compile_subgraph_reason = self.compile_subgraph_reason
            gm.meta["dynamo_flat_name_to_original_fqn"] = (
                self.dynamo_flat_name_to_original_fqn.copy()
            )
            gm.meta["dynamo_compile_id"] = self.dynamo_compile_id

            graph_code_log.debug(
                "%s",
                lazy_format_graph_code(
                    name, gm, include_stride=True, include_device=True, colored=True
                ),
            )
            torch._logging.trace_structured(
                "dynamo_output_graph",
                lambda: {"sizes": self.get_graph_sizes_structured()},
                payload_fn=lambda: gm.print_readable(
                    print_output=False, include_stride=True, include_device=True
                ),
            )
            self.call_cleanup_hooks()
            old_fake_mode = self.tracing_context.fake_mode
            if not self.export:
                import torch._functorch.config as _config

                with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
                    backend_fake_mode = torch._subclasses.FakeTensorMode(
                        shape_env=old_fake_mode.shape_env,
                    )

                self.tracing_context.fake_mode = backend_fake_mode

            with self.restore_global_state():
                compiled_fn = self.call_user_compiler(gm, self.example_inputs())

            from torch.fx._lazy_graph_module import _LazyGraphModule

            if isinstance(compiled_fn, _LazyGraphModule) or (
                isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule)
                and compiled_fn.__name__ == "_lazy_forward"
            ):
                lazy_gm = (
                    compiled_fn
                    if isinstance(compiled_fn, _LazyGraphModule)
                    else compiled_fn.__self__
                )

                _LazyGraphModule.force_recompile(lazy_gm)

                if not isinstance(compiled_fn, _LazyGraphModule):
                    compiled_fn = lazy_gm.forward

            if self.package is not None:
                self.package.add_backend_id(name, compiled_fn)

            compiled_fn = disable(
                compiled_fn, reason="do not trace Dynamo-compiled graph"
            )

            counters["stats"]["unique_graphs"] += 1
            if specializations := old_fake_mode.shape_env.specializations:
                specialization_guards = []
                specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
                sources = [a.source for a in self.graphargs]
                for specialization in specializations:
                    source_index = sources.index(specialization.source)
                    check_fn_source = inspect.getsource(specialization.check_fn).strip()
                    check_fn = guards.LAMBDA_GUARD(
                        specialization.check_fn,
                        [check_fn_source],
                    )

                    log.debug(
                        "Compiling backend specialized graph with specialization=%s",
                        check_fn_source,
                    )

                    specialization_guards.append(
                        (
                            functools.partial(
                                lambda idx, args, check_fn=check_fn: check_fn(
                                    args[idx]
                                ),
                                source_index,
                            ),
                            specialization,
                        )
                    )

                @torch._dynamo.disable(reason="do not trace Dynamo-compiled graph")
                def specialized_dispatch(*args, **kwargs):
                    for check_fn, specialization in specialization_guards:
                        if check_fn(args):
                            if specialization in specialization_cache:
                                return specialization_cache[specialization](
                                    *args, **kwargs
                                )

                            with self.shape_env.patch_source_specialization(
                                specialization.source, specialization.check_fn
                            ):
                                gm.meta["specialization"] = specialization
                                example_inputs: list[Tensor] = list(args)
                                with tracing(self.tracing_context):
                                    specialization_cache[specialization] = (
                                        self.call_user_compiler(gm, example_inputs)
                                    )

                            return specialization_cache[specialization](*args, **kwargs)
                    return compiled_fn(*args, **kwargs)

                self.install_global_unsafe(name, specialized_dispatch)
            else:
                self.install_global_unsafe(name, compiled_fn)

            assert self.root_tx is not None
            cg = PyCodegen(self.root_tx)
            cg.make_call_generated_code(name)
            return cg.get_instructions()

    @property
    def placeholders(self) -> list[fx.Node]:
        return self.graph.find_nodes(op="placeholder")

    @property
    def graphargs(self) -> list[GraphArg]:
        return [node.meta["grapharg"] for node in self.placeholders]

    def call_user_compiler(
        self, gm: fx.GraphModule, example_inputs: list[Tensor]
    ) -> CompiledFn:
        with dynamo_timed(
            "OutputGraph.call_user_compiler",
            phase_name="backend_compile",
            log_pt2_compile_event=True,
            log_waitcounter=True,
            waitcounter_name_override="compile_aot_autograd",
            dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
        ):
            return self._call_user_compiler(gm, example_inputs)
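
    # _call_user_compiler classifies backend failures: a small allowlist of
    # exception types falls back to eager via a graph break with a warning
    # (unless the user forced ops into the graph, which makes the failure
    # hard), SkipFrame and restart signals propagate unchanged, and everything
    # else is wrapped in BackendCompilerFailed.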
    def _call_user_compiler(
        self, gm: fx.GraphModule, example_inputs: list[Tensor]
    ) -> CompiledFn:
        assert self.compiler_fn is not None
        tot = 0
        placeholders = []
        for node in gm.graph.nodes:
            if node.op in ("call_function", "call_method", "call_module"):
                tot += 1
            if node.op == "placeholder":
                placeholders.append(node)
        increment_op_count(tot)
        for pl in placeholders:
            if not hasattr(pl, "_dynamo_source"):
                arg = pl.meta["grapharg"]
                pl._dynamo_source = arg.source

        gm._param_name_to_source = self.param_name_to_source
        gm._source_to_user_stacks = self.source_to_user_stacks

        name = (
            self.compiler_fn.__name__
            if hasattr(self.compiler_fn, "__name__")
            else "<unknown compiler_fn>"
        )
        try:
            _step_logger()(logging.INFO, f"calling compiler function {name}")
            compiler_fn = self.compiler_fn
            if config.verify_correctness:
                compiler_fn = WrapperBackend(compiler_fn)
            compiled_fn = compiler_fn(gm, example_inputs)
            _step_logger()(logging.INFO, f"done compiler function {name}")
            assert callable(compiled_fn), "compiler_fn did not return callable"
        except (TensorifyScalarRestartAnalysis, ShortenTraceback):
            raise
        except exceptions_allowed_to_be_fallback as e:
            if self.has_user_defined_allowed_in_graph:
                raise BackendCompilerFailed(
                    self.compiler_fn, e, inspect.currentframe()
                ).with_traceback(e.__traceback__) from None
            unimplemented_v2_with_warning(
                e,
                self.root_tx.f_code,
                gb_type="Backend compiler exception",
                context=f"Backend: {name}\nException:{str(e)}\nTraceback:\n{self.root_tx.format_frame_summary()}",
                explanation=f"Backend compiler `{name}` failed with {str(e)}. Adding a graph break.",
                hints=[
                    "Report an issue to the backend compiler repo.",
                ],
            )
        except SkipFrame as e:
            raise e
        except Exception as e:
            raise BackendCompilerFailed(
                self.compiler_fn, e, inspect.currentframe()
            ).with_traceback(e.__traceback__) from None

        signpost_event(
            "dynamo",
            "OutputGraph.call_user_compiler",
            {
                **self.co_fields,
                "op_count": tot,
                "node_count": len(gm.graph.nodes),
                "input_count": len(placeholders),
            },
        )

        return compiled_fn

    def dedup_pass(self):
        if torch._dynamo.config.use_graph_deduplication:
            return apply_graph_deduplication(self)
        else:
            return {}

    def install_subgraph(self, name, sub_gm):
        next_name = get_unique_name_wrt(name, self.nn_modules, requires_suffix=True)
        sub_gm.__name__ = next_name
        sub_gm.torchdynamo_force_dynamic = False
        self.register_attr_or_module(sub_gm, next_name, source=None)
        return next_name

    def example_inputs(self) -> list[torch.Tensor]:
        result = [arg.example for arg in self.graphargs]
        return result

    def remove_unused_get_attr_nodes(self) -> None:
        for node in sorted(self.graph.find_nodes(op="get_attr"), reverse=True):
            if len(list(node.users)) == 0:
                self.remove_node(node)
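
    # remove_unused_graphargs prunes placeholders in two phases: placeholders
    # that bind a shape symbol are only collected for a recheck, since a symbol
    # that looks dead may still be needed by another input's size/stride
    # expressions, while other user-less placeholders are dropped immediately.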
    def remove_unused_graphargs(self) -> None:
        assert self.should_exit

        def is_static_true(b_node: fx.node.Argument):
            if b_node is True:
                return True
            if not isinstance(b_node, fx.Node):
                return False
            b = b_node.meta.get("example_value")
            if b is None:
                return False
            if b is True:
                return True
            if (
                isinstance(b, torch.SymBool)
                and (r := b.node.maybe_as_bool()) is not None
            ):
                return r
            return False

        def is_symnode_arg(a: fx.node.Argument):
            from torch.fx.experimental.sym_node import SymTypes

            if isinstance(a, (int, float, bool)):
                return True
            if isinstance(a, fx.Node):
                return isinstance(a.meta.get("example_value"), SymTypes)
            return False

        def is_symnode_compute_node(node):
            from torch.fx.experimental.sym_node import SymTypes

            if node.op != "call_function":
                return False
            if not isinstance(node.meta.get("example_value"), SymTypes):
                return False
            if not all(is_symnode_arg(a) for a in node.args):
                return False
            if not all(is_symnode_arg(a) for a in node.kwargs.values()):
                return False
            return True

        from torch.fx.experimental.symbolic_shapes import is_accessor_node

        for node in reversed(list(self.graph.nodes)):
            if len(list(node.users)) == 0:
                if (
                    node.op == "get_attr"
                    or (node.op == "call_function" and node.target is operator.getitem)
                    or (
                        node.op == "call_function"
                        and node.target is torch._check
                        and is_static_true(node.args[0])
                    )
                    or is_symnode_compute_node(node)
                    or is_accessor_node(node)
                ):
                    self.remove_node(node)

        def placeholder_binds_symbol(node):
            arg = node.meta["grapharg"]
            example = arg.example
            if isinstance(example, torch.SymInt) and isinstance(
                example.node.expr, sympy.Symbol
            ):
                return example.node.expr
            return None

        def remove_unused(node):
            log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name())
            del node.meta["grapharg"]
            self.remove_node(node)
            self.real_value_cache.pop(node, None)

        used_symbols: set[sympy.Symbol] = set()

        def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]):
            used_symbols |= free_symbols(fake)

        recheck_placeholders = []
        for node in self.placeholders:
            binds_symbol = placeholder_binds_symbol(node) is not None
            if binds_symbol:
                if not node.users:
                    recheck_placeholders.append(node)
            else:
                if not node.users and not isinstance(
                    node.meta["grapharg"], BackwardStateGraphArg
                ):
                    remove_unused(node)
                else:
                    arg = node.meta["grapharg"]
                    if isinstance(arg, BackwardStateGraphArg):
                        continue
                    if isinstance(node.meta["grapharg"].example, torch.ScriptObject):
                        real_script_obj = node.meta["grapharg"].example
                        fake_script_obj = node.meta["grapharg"].example_strong_ref
                        if not torch._library.fake_class_registry.tracing_with_real(
                            real_script_obj
                        ):
                            flat_dict = dict(real_script_obj.__obj_flatten__())
                            for attr in flat_dict.keys():
                                fake_attr_val = getattr(
                                    fake_script_obj.wrapped_obj, attr
                                )
                                pytree.tree_map_only(
                                    (torch.SymInt, torch.Tensor),
                                    lambda t: update_used_symbols(used_symbols, t),
                                    fake_attr_val,
                                )
                            continue
                    fake = (
arg.fake_tensor if arg.fake_tensor is not None else arg.example
|
|
|
)
|
|
|
update_used_symbols(used_symbols, fake)
|
|
|
|
|
|
|
|
|
for node in recheck_placeholders:
|
|
|
symbol = placeholder_binds_symbol(node)
|
|
|
if symbol is not None:
|
|
|
if symbol not in used_symbols:
|
|
|
remove_unused(node)
|
|
|
else:
|
|
|
|
|
|
used_symbols.remove(symbol)
|
|
|
|
|
|
def remove_tensorify_specialized_graphargs(self) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch._dynamo.symbolic_convert import TensorifyState
|
|
|
|
|
|
for node in self.graph.nodes:
|
|
|
example_value = node.meta.get("example_value")
|
|
|
if (
|
|
|
isinstance(example_value, FakeTensor)
|
|
|
and example_value.item_memo is not None
|
|
|
and hasattr(example_value.item_memo.node._expr, "name")
|
|
|
and all(u.target == "item" for u in node.users)
|
|
|
and TensorifyState.should_specialize(
|
|
|
|
|
|
example_value.item_memo.node._expr.name
|
|
|
)
|
|
|
):
|
|
|
for u in list(node.users):
|
|
|
u.replace_all_uses_with(guard_scalar(example_value.item_memo))
|
|
|
self.remove_node(u)
|
|
|
self.remove_node(node)
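# Schematically, this pass rewrites
#
#     x (FakeTensor whose item_memo is s0) --> x.item() --> uses of s0
#
# into the guarded concrete scalar guard_scalar(item_memo) at every use
# site, then erases both the item() call and the now-unused node.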
|
|
|
|
|
|
def add_output_instructions(self, prefix: list[Instruction]) -> None:
|
|
|
"""
|
|
|
We call this on the creation of a new compiled subgraph that is inserted
|
|
|
before user code.
|
|
|
"""
|
|
|
self.output_instructions.extend(prefix)
|
|
|
self.should_exit = True
|
|
|
|
|
|
def install_global_unsafe(self, name, value) -> None:
|
|
|
"""
|
|
|
WARNING: prefer the safer `install_global_by_id/install_global`.
|
|
|
torch.compile instances should be independent of each other;
|
|
|
one footgun is to have one instance depend on the existence of
|
|
|
a global installed by another instance. This can happen if we mangle
|
|
|
a global the same way across both instances.
|
|
|
"""
|
|
|
assert name not in self.installed_globals
|
|
|
self.installed_globals.add(name)
|
|
|
self.cleanups.append(CleanupHook.create(self.global_scope, name, value))
|
|
|
|
|
|
def install_global_by_id(self, prefix, value) -> str:
|
|
|
"""
|
|
|
Installs a global if it hasn't been installed already.
|
|
|
This is determined by the (prefix, id(value)) pair.
|
|
|
|
|
|
Returns the name of the newly installed global.
|
|
|
"""
|
|
|
|
|
|
|
|
|
name = f"{prefix}_{id(value)}_c{self.compile_id}"
|
|
|
if name in self.installed_globals:
|
|
|
return name
|
|
|
self.install_global_unsafe(name, value)
|
|
|
return name
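# Illustrative example (hypothetical values): with prefix="foo", a value
# whose id() is 140001, and compile id 0, this installs the global
# "foo_140001_c0"; a second call with the same (prefix, value) pair is a
# no-op that returns the same mangled name.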
|
|
|
|
|
|
def install_global(self, prefix, value) -> str:
|
|
|
"""
|
|
|
Installs a global, generating a unique name for it.
|
|
|
|
|
|
Returns the name of the newly installed global.
|
|
|
"""
|
|
|
|
|
|
name = unique_id(prefix)
|
|
|
self.install_global_unsafe(name, value)
|
|
|
return name
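# Unlike install_global_by_id, this always mints a fresh name via
# unique_id(prefix), so repeated calls with the same value install
# distinct globals.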
|
|
|
|
|
|
def cleanup(self) -> None:
|
|
|
|
|
|
|
|
|
self.root_tx = None
|
|
|
self.nn_modules.clear()
|
|
|
self.param_name_to_source = None
|
|
|
|
|
|
for node in self.graph.nodes:
|
|
|
if "grapharg" in node.meta:
|
|
|
del node.meta["grapharg"]
|
|
|
self.real_value_cache.clear()
|
|
|
self.input_name_to_proxy.clear()
|
|
|
self.side_effects.clear()
|
|
|
self.variable_tracker_cache.clear()
|
|
|
self.register_finalizer_fns.clear()
|
|
|
self.dynamo_flat_name_to_original_fqn.clear()
|
|
|
self.tracing_context.clear()
|
|
|
self.input_source_to_var.clear()
|
|
|
self.unspec_variable_map.clear()
|
|
|
self.backward_state.clear()
|
|
|
|
|
|
def add_graph_finalizer(
|
|
|
self, register_finalizer: Callable[[fx.GraphModule], None]
|
|
|
) -> None:
|
|
|
self.register_finalizer_fns.append(register_finalizer)
|
|
|
|
|
|
def example_value_from_input_node(self, node: torch.fx.Node):
|
|
|
"""Extract the non-fake example tensor"""
|
|
|
if node.op == "placeholder":
|
|
|
return node.meta["grapharg"].example
|
|
|
assert node.op == "get_attr"
|
|
|
return self.nn_modules[node.target]
|
|
|
|
|
|
|
|
|
err_epilogue = (
|
|
|
"With the current config, we will graph break "
|
|
|
"(and fall back to eager-mode PyTorch) on all ops "
|
|
|
"that have do not have the 'pt2_compliant_tag'. "
|
|
|
"Please see the following doc for how to mark this op as PT2 compliant "
|
|
|
"https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html"
|
|
|
)
|
|
|
|
|
|
|
|
|
def check_pt2_compliant_op(output_graph, kind, target, args, kwargs):
|
|
|
if kind != "call_function":
|
|
|
return
|
|
|
|
|
|
def encountered_compliant_op(target):
|
|
|
if target.namespace in {"prim", "prims", "aten"}:
|
|
|
return
|
|
|
output_graph.compliant_custom_ops.add(target)
|
|
|
|
|
|
def encountered_non_compliant_op(target, msg):
|
|
|
output_graph.non_compliant_ops.add(target)
|
|
|
if config.only_allow_pt2_compliant_ops:
|
|
|
unimplemented_v2(
|
|
|
gb_type="Encountered non-PT2-compliant op",
|
|
|
context="",
|
|
|
explanation=msg + " " + err_epilogue,
|
|
|
hints=[],
|
|
|
)
|
|
|
|
|
|
if isinstance(target, torch._ops.OpOverload):
|
|
|
if torch.Tag.pt2_compliant_tag in target.tags:
|
|
|
encountered_compliant_op(target)
|
|
|
return
|
|
|
encountered_non_compliant_op(
|
|
|
target,
|
|
|
f"Encountered the torch.ops.OpOverload {target} that is not PT2 compliant.",
|
|
|
)
|
|
|
return
|
|
|
|
|
|
if isinstance(target, torch._ops.OpOverloadPacket):
|
|
|
overloads = tuple(target.overloads())
|
|
|
|
|
|
|
|
|
if len(overloads) == 1:
|
|
|
op = getattr(target, overloads[0])
|
|
|
if torch.Tag.pt2_compliant_tag in op.tags:
|
|
|
encountered_compliant_op(op)
|
|
|
return
|
|
|
encountered_non_compliant_op(
|
|
|
op,
|
|
|
f"Encountered the non-overloaded "
|
|
|
f"torch.ops.OpOverloadPacket {target} "
|
|
|
f"that is not PT2 compliant. ",
|
|
|
)
|
|
|
return
|
|
|
|
|
|
args, kwargs = torch._dynamo.utils.get_fake_values_from_nodes(
|
|
|
output_graph.current_tx, (args, kwargs), False
|
|
|
)
|
|
|
try:
|
|
|
overload = torch._C._jit_resolve_packet(
|
|
|
target._qualified_op_name, *args, **kwargs
|
|
|
)
|
|
|
except RuntimeError as e:
|
|
|
unimplemented_v2(
|
|
|
gb_type="Error when attempting to resolve op packet",
|
|
|
context="",
|
|
|
explanation=str(e),
|
|
|
hints=[],
|
|
|
)
|
|
|
|
|
|
op = getattr(target, overload)
|
|
|
if torch.Tag.pt2_compliant_tag in op.tags:
|
|
|
encountered_compliant_op(op)
|
|
|
else:
|
|
|
encountered_non_compliant_op(
|
|
|
op,
|
|
|
f"Encountered the torch.ops.OpOverloadPacket {target} "
|
|
|
f"which resolves to the overload ({overload}) that is "
|
|
|
f"not PT2 compliant.",
|
|
|
)
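# For reference, a custom op opts in to this check by carrying
# torch.Tag.pt2_compliant_tag. A minimal sketch (exact torch.library API
# shapes vary across PyTorch versions; "mylib::foo" is hypothetical):
#
#     lib = torch.library.Library("mylib", "DEF")
#     lib.define("foo(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,))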
|
|
|
|
|
|
|
|
|
_compile_id_counter = itertools.count()
|
|
|
|
|
|
|
|
|
class LazyProxy:
|
|
|
def __init__(self, tracer, fn, *args, **kwargs):
|
|
|
self.tracer = tracer
|
|
|
self.fn = fn
|
|
|
self.args = args
|
|
|
self.kwargs = kwargs
|
|
|
|
|
|
def __call__(self):
|
|
|
return self.fn(*self.args, **self.kwargs)
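# LazyProxy defers node creation until the proxy is actually needed:
# SubgraphTracer.track_unbacked_symbols below wraps aten.sym_size.int /
# aten.sym_stride.int calls this way, so a size/stride node is only
# materialized if a later expression looks the symbol up. Sketch
# (`tracer` and `tensor_proxy` are hypothetical):
#
#     lazy = LazyProxy(tracer, tracer.create_proxy, "call_function",
#                      torch.ops.aten.sym_size.int, (tensor_proxy, 0), {})
#     proxy = lazy()  # the graph node is created only at this point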
|
|
|
|
|
|
|
|
|
class SubgraphTracer(fx.Tracer):
|
|
|
"""
|
|
|
Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer
|
|
|
and the separation of responsibilities is that SubgraphTracer is
|
|
|
responsible for building the graph while OutputGraph is responsible for
|
|
|
compiling and executing the graph.
|
|
|
"""
|
|
|
|
|
|
def __init__(self, output_graph, parent=None, is_export=False, source_target=None):
|
|
|
super().__init__()
|
|
|
self.output_graph = weakref.proxy(output_graph)
|
|
|
self.graph = torch.fx.Graph()
|
|
|
|
|
|
|
|
|
self.is_export = is_export
|
|
|
|
|
|
|
|
|
|
|
|
self.input_name_to_proxy: dict[str, fx.Proxy] = {}
|
|
|
|
|
|
self.real_value_cache: dict[fx.Node, torch.Tensor] = {}
|
|
|
|
|
|
|
|
|
self.parent = parent
|
|
|
self.source_target = source_target
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.lifted_freevars = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.bound_symbols: dict[sympy.Symbol, Union[torch.fx.Proxy, LazyProxy]] = {}
|
|
|
|
|
|
self.prev_inst = None
|
|
|
|
|
|
|
|
|
self.under_activation_checkpoint = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.allow_side_effects_under_checkpoint = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.unsafe_allow_externally_visible_side_effects = False
|
|
|
|
|
|
|
|
|
self.is_reconstructing_generator = False
|
|
|
|
|
|
self.debug_level: int = parent.debug_level + 1 if parent is not None else 0
|
|
|
|
|
|
self._cur_code = None
|
|
|
self._orig_gm_meta = None
|
|
|
self._orig_gm_lineno_map = None
|
|
|
self._orig_gm_firstlineno = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.parent is None:
|
|
|
self.source_fn_stack = []
|
|
|
else:
|
|
|
self.source_fn_stack = self.parent.source_fn_stack + [
|
|
|
(self.graph._target_to_str(source_target), source_target)
|
|
|
]
|
|
|
|
|
|
|
|
|
self._used_names: OrderedSet[str] = OrderedSet()
|
|
|
|
|
|
|
|
|
self._input_versions_at_beginning: list[int] = []
|
|
|
if torch.is_inference_mode_enabled():
|
|
|
raise RuntimeError(
|
|
|
"Inference mode is supposed to be disabled during compilation. Please open an issue."
|
|
|
)
|
|
|
|
|
|
|
|
|
def _maybe_preserve_original_meta(self, tx, node):
|
|
|
if (
|
|
|
self._orig_gm_meta
|
|
|
and self._orig_gm_lineno_map
|
|
|
and self._orig_gm_firstlineno
|
|
|
):
|
|
|
lineno = tx.current_instruction.starts_line
|
|
|
node_idx = None
|
|
|
if lineno is not None:
|
|
|
node_idx = self._orig_gm_lineno_map.get(
|
|
|
lineno - self._orig_gm_firstlineno, None
|
|
|
)
|
|
|
if node_idx is not None:
|
|
|
meta = self._orig_gm_meta[node_idx]
|
|
|
for field in fx.proxy._COPY_META_FIELDS:
|
|
|
if field in meta:
|
|
|
node.meta[field] = meta[field]
|
|
|
if "stack_trace" in meta:
|
|
|
node.meta["stack_trace"] = meta["stack_trace"]
|
|
|
|
|
|
def create_proxy(
|
|
|
self,
|
|
|
kind,
|
|
|
target,
|
|
|
args,
|
|
|
kwargs,
|
|
|
name=None,
|
|
|
type_expr=None,
|
|
|
proxy_factory_fn=None,
|
|
|
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.parent is not None:
|
|
|
flat_args, tree_spec = pytree.tree_flatten((args, kwargs))
|
|
|
new_flat_args = []
|
|
|
for arg in flat_args:
|
|
|
maybe_new_arg = self.maybe_lift_tracked_freevar_to_input(arg)
|
|
|
new_flat_args.append(maybe_new_arg)
|
|
|
|
|
|
args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec)
|
|
|
|
|
|
rv = super().create_proxy(
|
|
|
kind, target, args, kwargs, name, type_expr, proxy_factory_fn
|
|
|
)
|
|
|
|
|
|
|
|
|
tx = self.output_graph.current_tx
|
|
|
|
|
|
|
|
|
if sys.version_info >= (3, 11) and kind in (
|
|
|
"call_function",
|
|
|
"call_method",
|
|
|
"call_module",
|
|
|
):
|
|
|
cur_inst = tx.current_instruction
|
|
|
if (
|
|
|
cur_inst is not self.prev_inst
|
|
|
and cur_inst.positions is not None
|
|
|
and cur_inst.positions.lineno is not None
|
|
|
):
|
|
|
tx_code = tx.f_code
|
|
|
header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno)
|
|
|
|
|
|
def get_trace_call_log_str():
|
|
|
line = get_instruction_source_311(tx_code, cur_inst).rstrip()
|
|
|
return f"TRACE FX call {rv.node.name} from {header}\n{line}"
|
|
|
|
|
|
trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
|
|
|
self.prev_inst = cur_inst
|
|
|
|
|
|
|
|
|
is_retracing = False
|
|
|
if tx.f_code is not self._cur_code:
|
|
|
orig_graphmodule_maybe = code_context.get_context(tx.f_code).get(
|
|
|
"orig_graphmodule", lambda: None
|
|
|
)()
|
|
|
if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule):
|
|
|
is_retracing = True
|
|
|
self._orig_gm_meta = [
|
|
|
nd.meta for nd in orig_graphmodule_maybe.graph.nodes
|
|
|
]
|
|
|
self._orig_gm_lineno_map = orig_graphmodule_maybe._lineno_map
|
|
|
self._orig_gm_firstlineno = (
|
|
|
orig_graphmodule_maybe.forward.__code__.co_firstlineno
|
|
|
)
|
|
|
else:
|
|
|
self._orig_gm_meta = None
|
|
|
self._orig_gm_lineno_map = None
|
|
|
self._orig_gm_firstlineno = None
|
|
|
nn_module_stack = tx.nn_module_stack
|
|
|
if nn_module_stack:
|
|
|
rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
|
|
|
|
|
|
if kind in {"call_function", "call_method"}:
|
|
|
rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
|
|
|
(rv.node.name, target)
|
|
|
]
|
|
|
elif kind == "call_module":
|
|
|
if self.parent is not None:
|
|
|
|
|
|
unimplemented_v2(
|
|
|
gb_type="Invoking an nn.Module inside a higher order operator",
|
|
|
context=f"Higher order op name: {self.source_target}",
|
|
|
explanation="This is not supported.",
|
|
|
hints=[],
|
|
|
)
|
|
|
|
|
|
rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
|
|
|
(
|
|
|
rv.node.name,
|
|
|
next(
|
|
|
ty
|
|
|
for k, (_, ty) in rv.node.meta["nn_module_stack"].items()
|
|
|
if k.split("@")[0] == target
|
|
|
),
|
|
|
)
|
|
|
]
|
|
|
|
|
|
self._maybe_preserve_original_meta(tx, rv.node)
|
|
|
|
|
|
if not is_retracing:
|
|
|
if "nn_module_stack" not in rv.node.meta:
|
|
|
nn_module_stack = tx.nn_module_stack
|
|
|
if nn_module_stack:
|
|
|
rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
|
|
|
|
|
|
if "source_fn_stack" not in rv.node.meta:
|
|
|
if kind in {"call_function", "call_method"}:
|
|
|
rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
|
|
|
(rv.node.name, target)
|
|
|
]
|
|
|
elif kind == "call_module":
|
|
|
if self.parent is not None:
|
|
|
|
|
|
unimplemented_v2(
|
|
|
gb_type="Invoking an nn.Module inside a HigherOrderOperator",
|
|
|
context="",
|
|
|
explanation="This is not supported.",
|
|
|
hints=[],
|
|
|
)
|
|
|
|
|
|
rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
|
|
|
(
|
|
|
rv.node.name,
|
|
|
rv.node.meta["nn_module_stack"][target][1],
|
|
|
)
|
|
|
]
|
|
|
|
|
|
if "stack_trace" not in rv.node.meta:
|
|
|
frame_summaries: list[traceback.FrameSummary] = []
|
|
|
while tx:
|
|
|
|
|
|
|
|
|
if not tx.is_co_filename_from_nn_modules():
|
|
|
frame_summaries.append(tx.frame_summary())
|
|
|
tx = getattr(tx, "parent", None)
|
|
|
|
|
|
frame_summaries.reverse()
|
|
|
|
|
|
|
|
|
msgs = traceback.StackSummary.from_list(frame_summaries).format()
|
|
|
rv.node.stack_trace = "".join(msgs)
|
|
|
|
|
|
if (
|
|
|
torch._dynamo.config.use_graph_deduplication
|
|
|
or torch._dynamo.config.track_nodes_for_deduplication
|
|
|
):
|
|
|
self.output_graph.region_tracker.track_node(
|
|
|
self.output_graph.current_tx, rv.node
|
|
|
)
|
|
|
return rv
|
|
|
|
|
|
def create_node(
|
|
|
self, op, target, args=None, kwargs=None, name=None, type_expr=None
|
|
|
):
|
|
|
check_pt2_compliant_op(self.output_graph, op, target, args, kwargs)
|
|
|
if self.parent is not None:
|
|
|
flat_args = pytree.arg_tree_leaves(*args, **kwargs)
|
|
|
for arg in flat_args:
|
|
|
if not isinstance(arg, torch.fx.Node):
|
|
|
continue
|
|
|
assert arg.graph == self.graph, (
|
|
|
"create_node using arg not from this SubgraphTracer"
|
|
|
)
|
|
|
|
|
|
node = super().create_node(op, target, args, kwargs, name, type_expr)
|
|
|
node.meta["creation_timestamp"] = self.output_graph.timestamp
|
|
|
self._used_names.add(node.name)
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
|
|
def remove_node(self, node):
|
|
|
if len(node.users) > 0:
|
|
|
user_graph_nodes: list[torch.fx.Node] = []
|
|
|
for user in node.users.keys():
|
|
|
|
|
|
|
|
|
if user.graph != self.graph:
|
|
|
|
|
|
|
|
|
|
|
|
user_graph_nodes.extend(reversed(list(user.graph.nodes)))
|
|
|
for other_graph_node in user_graph_nodes:
|
|
|
other_graph_node.graph.erase_node(other_graph_node)
|
|
|
self.graph.erase_node(node)
|
|
|
self.input_name_to_proxy.pop(node.name, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_graph_input(
|
|
|
self, name, type_expr, example_value, before=False, source=None
|
|
|
):
|
|
|
if isinstance(example_value, torch.Tensor):
|
|
|
self._input_versions_at_beginning.append(example_value._version)
|
|
|
log.debug(
|
|
|
"create_graph_input %s %s %s at debug_level %s before=%s",
|
|
|
name,
|
|
|
source.name() if source is not None else "(none)",
|
|
|
example_value,
|
|
|
self.debug_level,
|
|
|
before,
|
|
|
)
|
|
|
if source is None:
|
|
|
assert self.parent is not None, (
|
|
|
f"you are required to provide a source for inputs {name} example_val {example_value} on the root tracer"
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.is_export and self.parent is None:
|
|
|
if not is_from_local_source(source, only_allow_input=True):
|
|
|
self.output_graph.source_to_user_stacks.setdefault(source, []).append(
|
|
|
TracingContext.extract_stack()
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
name = get_unique_name_wrt(name, self._used_names)
|
|
|
if self.input_name_to_proxy:
|
|
|
prev_name = next(reversed(self.input_name_to_proxy))
|
|
|
node = self.input_name_to_proxy[prev_name].node
|
|
|
if before:
|
|
|
ctx = self.graph.inserting_before(node)
|
|
|
else:
|
|
|
ctx = self.graph.inserting_after(node)
|
|
|
else:
|
|
|
ctx = self.graph.inserting_before(None)
|
|
|
with ctx:
|
|
|
proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
|
|
|
set_example_value(proxy.node, example_value)
|
|
|
if self.input_name_to_proxy and before:
|
|
|
k, v = self.input_name_to_proxy.popitem()
|
|
|
self.input_name_to_proxy[name] = proxy
|
|
|
self.input_name_to_proxy[k] = v
|
|
|
else:
|
|
|
self.input_name_to_proxy[name] = proxy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self._used_names.add(name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
is_strict_export = self.is_export
|
|
|
is_non_strict_export = torch.compiler.is_compiling()
|
|
|
if not is_strict_export and not is_non_strict_export:
|
|
|
if isinstance(example_value, torch.Tensor):
|
|
|
self._lift_basic_symbols(example_value, source)
|
|
|
elif isinstance(example_value, (list, tuple)):
|
|
|
for i, e in enumerate(example_value):
|
|
|
if not isinstance(e, torch.Tensor):
|
|
|
continue
|
|
|
|
|
|
e_source = None
|
|
|
if source:
|
|
|
e_source = GetItemSource(
|
|
|
base=source, index=i, index_is_slice=False
|
|
|
)
|
|
|
|
|
|
self._lift_basic_symbols(e, e_source)
|
|
|
|
|
|
|
|
|
if isinstance(example_value, torch.SymInt) and isinstance(
|
|
|
example_value.node.expr, sympy.Symbol
|
|
|
):
|
|
|
self.bound_symbols[example_value.node.expr] = proxy
|
|
|
return proxy
|
|
|
|
|
|
|
|
|
def lift_tracked_freevar_to_input(self, proxy):
|
|
|
|
|
|
|
|
|
assert self.parent is not None, (
|
|
|
"lift_tracked_freevar_to_input should not be called on root SubgraphTracer"
|
|
|
)
|
|
|
|
|
|
example_value = proxy.node.meta["example_value"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (
|
|
|
isinstance(example_value, torch.SymInt)
|
|
|
and example_value.node.expr in self.bound_symbols
|
|
|
):
|
|
|
return self.bound_symbols[example_value.node.expr]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if proxy in self.lifted_freevars:
|
|
|
return self.lifted_freevars[proxy]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if proxy.tracer != self.parent:
|
|
|
self.parent.lift_tracked_freevar_to_input(proxy)
|
|
|
|
|
|
example_value = proxy.node.meta["example_value"]
|
|
|
new_proxy = self.create_graph_input(
|
|
|
proxy.node.name, type(example_value), example_value
|
|
|
)
|
|
|
self.lifted_freevars[proxy] = new_proxy
|
|
|
return new_proxy
|
|
|
|
|
|
def maybe_lift_tracked_freevar_to_input(self, arg):
|
|
|
"""
|
|
|
If arg is a free variable, then lift it to be an input.
|
|
|
Returns the new lifted arg (if arg was a freevar), else the
|
|
|
original arg.
|
|
|
"""
|
|
|
if not isinstance(arg, torch.fx.Proxy):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(arg, slice):
|
|
|
return slice(
|
|
|
*(
|
|
|
self.maybe_lift_tracked_freevar_to_input(sub_arg)
|
|
|
for sub_arg in (arg.start, arg.stop, arg.step)
|
|
|
)
|
|
|
)
|
|
|
else:
|
|
|
return arg
|
|
|
elif arg.tracer == self:
|
|
|
return arg
|
|
|
return self.lift_tracked_freevar_to_input(arg)
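# Why lifting matters, schematically: while tracing the body of a
# higher-order op (e.g. torch.cond), a proxy captured from an enclosing
# tracer cannot be referenced directly, so it is promoted to an extra
# placeholder of the subgraph rather than kept as a closure:
#
#     def true_fn(x):
#         return x + outer  # `outer` gets lifted to an input of this subgraph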
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def track_unbacked_symbols(
|
|
|
self, example_value, e_proxy: Union[LazyProxy, torch.fx.Proxy]
|
|
|
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tracer = e_proxy.tracer
|
|
|
assert isinstance(tracer, SubgraphTracer)
|
|
|
|
|
|
def need_bind(s) -> bool:
|
|
|
from torch.fx.experimental.symbolic_shapes import is_symbolic
|
|
|
|
|
|
return (
|
|
|
is_symbolic(s)
|
|
|
and isinstance(s.node.expr, sympy.Symbol)
|
|
|
and s.node.shape_env.is_unbacked_symint(s.node.expr)
|
|
|
and s.node.expr not in self.bound_symbols
|
|
|
)
|
|
|
|
|
|
def _proxy_with_example_value(example_value, *args, **kwargs):
|
|
|
proxy = tracer.create_proxy(*args, **kwargs)
|
|
|
set_example_value(proxy.node, example_value)
|
|
|
return proxy
|
|
|
|
|
|
if isinstance(example_value, torch.Tensor):
|
|
|
for i, s in enumerate(example_value.size()):
|
|
|
if need_bind(s):
|
|
|
log.debug(
|
|
|
"_track_unbacked_symbols %s for %s.size()[%s] at debug_level %s",
|
|
|
s,
|
|
|
e_proxy,
|
|
|
i,
|
|
|
tracer.debug_level,
|
|
|
)
|
|
|
lazy_proxy = LazyProxy(
|
|
|
tracer,
|
|
|
_proxy_with_example_value,
|
|
|
s,
|
|
|
"call_function",
|
|
|
torch.ops.aten.sym_size.int,
|
|
|
(e_proxy, i),
|
|
|
{},
|
|
|
type_expr=type(s),
|
|
|
)
|
|
|
self.track_unbacked_symbols(s, lazy_proxy)
|
|
|
|
|
|
if example_value.layout is torch.strided:
|
|
|
for i, s in enumerate(example_value.stride()):
|
|
|
if need_bind(s):
|
|
|
log.debug(
|
|
|
"_track_unbacked_symbols %s for %s.stride()[%s] at debug_level %s",
|
|
|
s,
|
|
|
e_proxy,
|
|
|
i,
|
|
|
tracer.debug_level,
|
|
|
)
|
|
|
lazy_proxy = LazyProxy(
|
|
|
tracer,
|
|
|
_proxy_with_example_value,
|
|
|
s,
|
|
|
"call_function",
|
|
|
torch.ops.aten.sym_stride.int,
|
|
|
(e_proxy, i),
|
|
|
{},
|
|
|
type_expr=type(s),
|
|
|
)
|
|
|
self.track_unbacked_symbols(s, lazy_proxy)
|
|
|
|
|
|
elif example_value.layout is torch.sparse_coo:
|
|
|
self.track_unbacked_symbols(example_value._indices(), e_proxy)
|
|
|
self.track_unbacked_symbols(example_value._values(), e_proxy)
|
|
|
elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
|
|
self.track_unbacked_symbols(example_value.crow_indices(), e_proxy)
|
|
|
self.track_unbacked_symbols(example_value.col_indices(), e_proxy)
|
|
|
elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}:
|
|
|
self.track_unbacked_symbols(example_value.ccol_indices(), e_proxy)
|
|
|
self.track_unbacked_symbols(example_value.row_indices(), e_proxy)
|
|
|
if is_traceable_wrapper_subclass(example_value):
|
|
|
attrs, ctx = example_value.__tensor_flatten__()
|
|
|
for attr in attrs:
|
|
|
inner_t = getattr(example_value, attr)
|
|
|
self.track_unbacked_symbols(inner_t, getattr(e_proxy, attr))
|
|
|
elif isinstance(example_value, torch.SymInt):
|
|
|
|
|
|
if need_bind(example_value):
|
|
|
expr = example_value.node.expr
|
|
|
tracer.bound_symbols[expr] = e_proxy
|
|
|
|
|
|
|
|
|
def _lift_basic_symbols(
|
|
|
self, example_value: Union[torch.SymInt, torch.Tensor], src: Optional[Source]
|
|
|
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _lift_symbols_in_symint(
|
|
|
s: Union[int, torch.SymInt],
|
|
|
source: Optional[Source],
|
|
|
before: bool = False,
|
|
|
) -> None:
|
|
|
if not is_symbolic(s):
|
|
|
return
|
|
|
|
|
|
assert isinstance(s, torch.SymInt)
|
|
|
self_to_be_bound = self.lookup_unbound_symbols(s)
|
|
|
if len(self_to_be_bound) == 0:
|
|
|
return
|
|
|
|
|
|
|
|
|
if self.parent is not None:
|
|
|
|
|
|
self.parent._lift_basic_symbols(s, source)
|
|
|
for s0 in self_to_be_bound:
|
|
|
parent_proxy = self.parent.bound_symbols[s0]
|
|
|
example_val = parent_proxy.node.meta["example_value"]
|
|
|
assert isinstance(example_val, torch.SymInt)
|
|
|
ph = self.create_graph_input(
|
|
|
str(s0),
|
|
|
type(example_val),
|
|
|
example_val,
|
|
|
before=before,
|
|
|
source=source,
|
|
|
)
|
|
|
log.debug(
|
|
|
"_lift_symbols_in_symint %s from %s at debug_level %s",
|
|
|
s0,
|
|
|
source.name() if source is not None else "subgraph inputs",
|
|
|
self.debug_level,
|
|
|
)
|
|
|
self.lifted_freevars[parent_proxy] = ph
|
|
|
|
|
|
else:
|
|
|
assert len(self_to_be_bound) == 1, (
|
|
|
f"For root tracer, we only expect to bind basic symbols (compound symbols "
|
|
|
f"should be cached before) but got unbound symbols {self_to_be_bound} in {s}"
|
|
|
)
|
|
|
assert source is not None, (
|
|
|
f"Source of '{s}' is None when lifting it to input of top-level. If it's an unbacked symbol, "
|
|
|
"this could be because it's not tracked with lazy_bind_unbacked_symbols. "
|
|
|
f"Otherwise, should provide a source when create_graph_input for `{s}` at root tracer."
|
|
|
)
|
|
|
s0 = next(iter(self_to_be_bound))
|
|
|
ph = self.create_graph_input(
|
|
|
str(s0),
|
|
|
type(s),
|
|
|
s,
|
|
|
before=before,
|
|
|
source=source,
|
|
|
)
|
|
|
log.debug(
|
|
|
"_lift_symbols_in_symint %s from %s at debug_level %s",
|
|
|
s,
|
|
|
source.name() if source is not None else "subgraph inputs",
|
|
|
self.debug_level,
|
|
|
)
|
|
|
ph.node.meta["grapharg"] = GraphArg(
|
|
|
source,
|
|
|
s,
|
|
|
pass_arg_as_tensor=False,
|
|
|
fake_tensor=None,
|
|
|
is_tensor=False,
|
|
|
)
|
|
|
|
|
|
if isinstance(example_value, torch.Tensor):
|
|
|
for i, s in enumerate(example_value.size()):
|
|
|
_lift_symbols_in_symint(
|
|
|
s,
|
|
|
(
|
|
|
TensorPropertySource(src, TensorProperty.SIZE, i)
|
|
|
if src is not None
|
|
|
else None
|
|
|
),
|
|
|
before=True,
|
|
|
)
|
|
|
if example_value.layout is torch.strided:
|
|
|
for i, s in enumerate(example_value.stride()):
|
|
|
_lift_symbols_in_symint(
|
|
|
s,
|
|
|
(
|
|
|
TensorPropertySource(src, TensorProperty.STRIDE, i)
|
|
|
if src is not None
|
|
|
else None
|
|
|
),
|
|
|
before=True,
|
|
|
)
|
|
|
_lift_symbols_in_symint(
|
|
|
example_value.storage_offset(),
|
|
|
(
|
|
|
TensorPropertySource(src, TensorProperty.STORAGE_OFFSET)
|
|
|
if src is not None
|
|
|
else None
|
|
|
),
|
|
|
before=True,
|
|
|
)
|
|
|
elif example_value.layout is torch.sparse_coo:
|
|
|
self._lift_basic_symbols(example_value._indices(), src)
|
|
|
self._lift_basic_symbols(example_value._values(), src)
|
|
|
elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
|
|
self._lift_basic_symbols(example_value.crow_indices(), src)
|
|
|
self._lift_basic_symbols(example_value.col_indices(), src)
|
|
|
elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}:
|
|
|
self._lift_basic_symbols(example_value.ccol_indices(), src)
|
|
|
self._lift_basic_symbols(example_value.row_indices(), src)
|
|
|
if is_traceable_wrapper_subclass(example_value):
|
|
|
attrs, ctx = example_value.__tensor_flatten__()
|
|
|
for attr in attrs:
|
|
|
inner_t = getattr(example_value, attr)
|
|
|
self._lift_basic_symbols(
|
|
|
inner_t, AttrSource(src, attr) if src is not None else None
|
|
|
)
|
|
|
elif isinstance(example_value, torch.SymInt):
|
|
|
_lift_symbols_in_symint(
|
|
|
example_value,
|
|
|
src,
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def lookup_unbound_symbols(self, s: torch.SymInt) -> list[sympy.Symbol]:
|
|
|
free_symbols = s.node.expr.free_symbols
|
|
|
if len(free_symbols) == 0:
|
|
|
return []
|
|
|
|
|
|
to_be_bound = []
|
|
|
for s0 in free_symbols:
|
|
|
if s0 not in self.bound_symbols:
|
|
|
to_be_bound.append(s0)
|
|
|
continue
|
|
|
|
|
|
proxy = self.bound_symbols[s0]
|
|
|
if isinstance(proxy, LazyProxy):
|
|
|
proxy = proxy()
|
|
|
self.bound_symbols[s0] = proxy
|
|
|
assert isinstance(proxy, torch.fx.Proxy) and proxy.tracer is self, (
|
|
|
f"The proxy of symbol {s0} doesn't belong to current tracer."
|
|
|
)
|
|
|
|
|
|
return sorted(to_be_bound, key=lambda s: s.name)
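# Sorting by symbol name gives a deterministic binding order for the
# placeholders that callers subsequently create for these symbols.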
|
|
|
|
|
|
def has_input_mutation(self):
|
|
|
input_versions_at_beginning = self._input_versions_at_beginning
|
|
|
input_nodes = []
|
|
|
|
|
|
input_versions_at_end = []
|
|
|
for node in self.graph.nodes:
|
|
|
if node.op == "placeholder":
|
|
|
example_value = node.meta["example_value"]
|
|
|
if isinstance(example_value, torch.Tensor):
|
|
|
input_versions_at_end.append(example_value._version)
|
|
|
input_nodes.append(node)
|
|
|
else:
|
|
|
break
|
|
|
|
|
|
mutated_inputs = [
|
|
|
i
|
|
|
for i, (v1, v2) in enumerate(
|
|
|
zip(input_versions_at_beginning, input_versions_at_end)
|
|
|
)
|
|
|
if v1 != v2
|
|
|
]
|
|
|
|
|
|
if len(mutated_inputs):
|
|
|
mutated_nodes = [input_nodes[i] for i in mutated_inputs]
|
|
|
msg = f"Input mutation detected at {mutated_nodes}"
|
|
|
return MutationInfo(True, msg)
|
|
|
|
|
|
return MutationInfo(False, "")
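# The check above relies on the tensor version counter: every in-place
# mutation bumps Tensor._version. Standalone illustration:
#
#     t = torch.zeros(2)
#     v0 = t._version
#     t.add_(1)
#     assert t._version > v0  # the in-place op bumped the version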
|
|
|
|
|
|
def has_aliasing(self):
|
|
|
from torch._higher_order_ops.utils import _collect_fake_inputs
|
|
|
|
|
|
input_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
|
|
|
|
|
|
for node in self.graph.nodes:
|
|
|
if node.op == "placeholder":
|
|
|
example_value = _collect_fake_inputs([node])[0]
|
|
|
if isinstance(example_value, torch.Tensor):
|
|
|
storage = StorageWeakRef(example_value._typed_storage())
|
|
|
if storage in input_storages:
|
|
|
|
|
|
msg = f"Input-to-input aliasing detected at nodes {input_storages[storage]} and {node}"
|
|
|
return AliasingInfo(True, msg)
|
|
|
input_storages[storage] = node
|
|
|
else:
|
|
|
break
|
|
|
|
|
|
output_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
|
|
|
out_nodes = self.graph.find_nodes(op="output")[0]
|
|
|
for out_node in pytree.tree_leaves(out_nodes.args[0]):
|
|
|
if out_node:
|
|
|
example_value = _collect_fake_inputs([out_node])[0]
|
|
|
assert not isinstance(example_value, list)
|
|
|
if isinstance(example_value, torch.Tensor):
|
|
|
storage = StorageWeakRef(example_value._typed_storage())
|
|
|
if storage in output_storages:
|
|
|
|
|
|
msg = f"Output-to-output aliasing detected at nodes {output_storages[storage]} and {out_node}"
|
|
|
return AliasingInfo(True, msg)
|
|
|
output_storages[storage] = out_node
|
|
|
|
|
|
intersected_storages = input_storages.keys() & output_storages.keys()
|
|
|
if len(intersected_storages) > 0:
|
|
|
|
|
|
aliased = [
|
|
|
(input_storages[s], output_storages[s]) for s in intersected_storages
|
|
|
]
|
|
|
aliased = ", ".join([f"{i} and {o}" for i, o in aliased])
|
|
|
msg = f"Input-to-output aliasing detected at nodes {aliased}"
|
|
|
return AliasingInfo(True, msg)
|
|
|
|
|
|
return AliasingInfo(False, "")
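# Storage identity is what all three checks above compare. Minimal
# illustration:
#
#     a = torch.zeros(4)
#     b = a.view(2, 2)  # `b` aliases `a`
#     assert StorageWeakRef(a._typed_storage()) == StorageWeakRef(
#         b._typed_storage()
#     )  # views share one storage, so the weakrefs compare equal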